Updates
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
__all__ = [
|
||||
'django_oauth_toolkit',
|
||||
'djangorestframework_camel_case',
|
||||
'rest_auth',
|
||||
'rest_framework',
|
||||
'rest_polymorphic',
|
||||
'rest_framework_dataclasses',
|
||||
'rest_framework_jwt',
|
||||
'rest_framework_simplejwt',
|
||||
'django_filters',
|
||||
'rest_framework_recursive',
|
||||
'rest_framework_gis',
|
||||
'pydantic',
|
||||
'knox_auth_token',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,305 @@
|
||||
from django.db import models
|
||||
|
||||
from drf_spectacular.drainage import add_trace_message, get_override, has_override, warn
|
||||
from drf_spectacular.extensions import OpenApiFilterExtension
|
||||
from drf_spectacular.plumbing import (
|
||||
build_array_type, build_basic_type, build_choice_description_list, build_parameter_type,
|
||||
follow_field_source, force_instance, get_manager, get_type_hints, get_view_model, is_basic_type,
|
||||
is_field,
|
||||
)
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
_NoHint = object()
|
||||
|
||||
|
||||
class DjangoFilterExtension(OpenApiFilterExtension):
|
||||
"""
|
||||
Extensions that specifically deals with ``django-filter`` fields. The introspection
|
||||
attempts to estimate the underlying model field types to generate filter types.
|
||||
|
||||
However, there are under-specified filter fields for which heuristics need to be performed.
|
||||
This serves as an explicit list of all partially-handled filter fields:
|
||||
|
||||
- ``AllValuesFilter``: skip choices to prevent DB query
|
||||
- ``AllValuesMultipleFilter``: skip choices to prevent DB query, multi handled though
|
||||
- ``ChoiceFilter``: enum handled, type under-specified
|
||||
- ``DateRangeFilter``: N/A
|
||||
- ``LookupChoiceFilter``: N/A
|
||||
- ``ModelChoiceFilter``: enum handled
|
||||
- ``ModelMultipleChoiceFilter``: enum, multi handled
|
||||
- ``MultipleChoiceFilter``: enum, multi handled
|
||||
- ``RangeFilter``: min/max handled, type under-specified
|
||||
- ``TypedChoiceFilter``: enum handled
|
||||
- ``TypedMultipleChoiceFilter``: enum, multi handled
|
||||
|
||||
In case of warnings or incorrect filter types, you can manually override the underlying
|
||||
field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a
|
||||
filter method for your filter field, you can attach ``extend_schema_field`` to that filter
|
||||
method.
|
||||
|
||||
.. code-block::
|
||||
|
||||
class SomeFilter(FilterSet):
|
||||
some_field = extend_schema_field(OpenApiTypes.NUMBER)(
|
||||
RangeFilter(field_name='some_manually_annotated_field_in_qs')
|
||||
)
|
||||
|
||||
"""
|
||||
target_class = 'django_filters.rest_framework.DjangoFilterBackend'
|
||||
match_subclasses = True
|
||||
|
||||
def get_schema_operation_parameters(self, auto_schema, *args, **kwargs):
|
||||
model = get_view_model(auto_schema.view)
|
||||
if not model:
|
||||
return []
|
||||
|
||||
filterset_class = self.target.get_filterset_class(auto_schema.view, get_manager(model).none())
|
||||
if not filterset_class:
|
||||
return []
|
||||
|
||||
result = []
|
||||
with add_trace_message(filterset_class):
|
||||
for field_name, filter_field in filterset_class.base_filters.items():
|
||||
result += self.resolve_filter_field(
|
||||
auto_schema, model, filterset_class, field_name, filter_field
|
||||
)
|
||||
return result
|
||||
|
||||
def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, filter_field):
|
||||
from django_filters import filters
|
||||
|
||||
unambiguous_mapping = {
|
||||
filters.CharFilter: OpenApiTypes.STR,
|
||||
filters.BooleanFilter: OpenApiTypes.BOOL,
|
||||
filters.DateFilter: OpenApiTypes.DATE,
|
||||
filters.DateTimeFilter: OpenApiTypes.DATETIME,
|
||||
filters.IsoDateTimeFilter: OpenApiTypes.DATETIME,
|
||||
filters.TimeFilter: OpenApiTypes.TIME,
|
||||
filters.UUIDFilter: OpenApiTypes.UUID,
|
||||
filters.DurationFilter: OpenApiTypes.DURATION,
|
||||
filters.OrderingFilter: OpenApiTypes.STR,
|
||||
filters.TimeRangeFilter: OpenApiTypes.TIME,
|
||||
filters.DateFromToRangeFilter: OpenApiTypes.DATE,
|
||||
filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
|
||||
filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
|
||||
}
|
||||
filter_method = self._get_filter_method(filterset_class, filter_field)
|
||||
filter_method_hint = self._get_filter_method_hint(filter_method)
|
||||
filter_choices = self._get_explicit_filter_choices(filter_field)
|
||||
schema_from_override = False
|
||||
|
||||
if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
|
||||
schema_from_override = True
|
||||
annotation = (
|
||||
get_override(filter_field, 'field') or get_override(filter_method, 'field')
|
||||
)
|
||||
if is_basic_type(annotation):
|
||||
schema = build_basic_type(annotation)
|
||||
elif isinstance(annotation, dict):
|
||||
# allow injecting raw schema via @extend_schema_field decorator
|
||||
schema = annotation.copy()
|
||||
elif is_field(annotation):
|
||||
schema = auto_schema._map_serializer_field(force_instance(annotation), "request")
|
||||
else:
|
||||
warn(
|
||||
f"Unsupported annotation {annotation} on filter field {field_name}. defaulting to string."
|
||||
)
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
elif filter_method_hint is not _NoHint:
|
||||
if is_basic_type(filter_method_hint):
|
||||
schema = build_basic_type(filter_method_hint)
|
||||
else:
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
elif isinstance(filter_field, tuple(unambiguous_mapping)):
|
||||
for cls in filter_field.__class__.__mro__:
|
||||
if cls in unambiguous_mapping:
|
||||
schema = build_basic_type(unambiguous_mapping[cls])
|
||||
break
|
||||
elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)):
|
||||
# NumberField is underspecified by itself. try to find the
|
||||
# type that makes the most sense or default to generic NUMBER
|
||||
model_field = self._get_model_field(filter_field, model)
|
||||
if isinstance(model_field, (models.IntegerField, models.AutoField)):
|
||||
schema = build_basic_type(OpenApiTypes.INT)
|
||||
elif isinstance(model_field, models.FloatField):
|
||||
schema = build_basic_type(OpenApiTypes.FLOAT)
|
||||
elif isinstance(model_field, models.DecimalField):
|
||||
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
|
||||
else:
|
||||
schema = build_basic_type(OpenApiTypes.NUMBER)
|
||||
elif isinstance(filter_field, (filters.ChoiceFilter, filters.MultipleChoiceFilter)):
|
||||
try:
|
||||
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
|
||||
except Exception:
|
||||
if filter_choices and is_basic_type(type(filter_choices[0])):
|
||||
# fallback to type guessing from first choice element
|
||||
schema = build_basic_type(type(filter_choices[0]))
|
||||
else:
|
||||
warn(
|
||||
f'Unable to guess choice types from values, filter method\'s type hint '
|
||||
f'or find "{field_name}" in model. Defaulting to string.'
|
||||
)
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
else:
|
||||
# the last resort is to look up the type via the model or queryset field
|
||||
# and emit a warning if we were unsuccessful.
|
||||
try:
|
||||
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
|
||||
except Exception as exc: # pragma: no cover
|
||||
warn(
|
||||
f'Exception raised while trying resolve model field for django-filter '
|
||||
f'field "{field_name}". Defaulting to string (Exception: {exc})'
|
||||
)
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
|
||||
# primary keys are usually non-editable (readOnly=True) and map_model_field correctly
|
||||
# signals that attribute. however this does not apply in this context.
|
||||
schema.pop('readOnly', None)
|
||||
# enrich schema with additional info from filter_field
|
||||
enum = schema.pop('enum', None)
|
||||
# explicit filter choices may disable enum retrieved from model
|
||||
if not schema_from_override and filter_choices is not None:
|
||||
enum = filter_choices
|
||||
|
||||
description = schema.pop('description', None)
|
||||
if not schema_from_override:
|
||||
description = self._get_field_description(filter_field, description)
|
||||
|
||||
# parameter style variations based on filter base class
|
||||
if isinstance(filter_field, filters.BaseCSVFilter):
|
||||
schema = build_array_type(schema)
|
||||
field_names = [field_name]
|
||||
explode = False
|
||||
style = 'form'
|
||||
elif isinstance(filter_field, filters.MultipleChoiceFilter):
|
||||
schema = build_array_type(schema)
|
||||
field_names = [field_name]
|
||||
explode = True
|
||||
style = 'form'
|
||||
elif isinstance(filter_field, (filters.RangeFilter, filters.NumericRangeFilter)):
|
||||
try:
|
||||
suffixes = filter_field.field_class.widget.suffixes
|
||||
except AttributeError:
|
||||
suffixes = ['min', 'max']
|
||||
field_names = [
|
||||
f'{field_name}_{suffix}' if suffix else field_name for suffix in suffixes
|
||||
]
|
||||
explode = None
|
||||
style = None
|
||||
else:
|
||||
field_names = [field_name]
|
||||
explode = None
|
||||
style = None
|
||||
|
||||
return [
|
||||
build_parameter_type(
|
||||
name=field_name,
|
||||
required=filter_field.extra['required'],
|
||||
location=OpenApiParameter.QUERY,
|
||||
description=description,
|
||||
schema=schema,
|
||||
enum=enum,
|
||||
explode=explode,
|
||||
style=style
|
||||
)
|
||||
for field_name in field_names
|
||||
]
|
||||
|
||||
def _get_filter_method(self, filterset_class, filter_field):
|
||||
if callable(filter_field.method):
|
||||
return filter_field.method
|
||||
elif isinstance(filter_field.method, str):
|
||||
return getattr(filterset_class, filter_field.method)
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_filter_method_hint(self, filter_method):
|
||||
try:
|
||||
return get_type_hints(filter_method)['value']
|
||||
except: # noqa: E722
|
||||
return _NoHint
|
||||
|
||||
def _get_explicit_filter_choices(self, filter_field):
|
||||
if 'choices' not in filter_field.extra:
|
||||
return None
|
||||
elif callable(filter_field.extra['choices']):
|
||||
# choices function may utilize the DB, so refrain from actually calling it.
|
||||
return []
|
||||
else:
|
||||
return [c for c, _ in filter_field.extra['choices']]
|
||||
|
||||
def _get_model_field(self, filter_field, model):
|
||||
if not filter_field.field_name:
|
||||
return None
|
||||
path = filter_field.field_name.split('__')
|
||||
return follow_field_source(model, path, emit_warnings=False)
|
||||
|
||||
def _get_schema_from_model_field(self, auto_schema, filter_field, model):
|
||||
# Has potential to throw exceptions. Needs to be wrapped in try/except!
|
||||
#
|
||||
# first search for the field in the model as this has the least amount of
|
||||
# potential side effects. Only after that fails, attempt to call
|
||||
# get_queryset() to check for potential query annotations.
|
||||
model_field = self._get_model_field(filter_field, model)
|
||||
|
||||
# this is a cross feature between rest-framework-gis and django-filter. Regular
|
||||
# behavior needs to be sidestepped as the model information is lost down the line.
|
||||
# TODO for now this will be just a string to cover WKT, WKB, and urlencoded GeoJSON
|
||||
# build_geo_schema(model_field) would yield the correct result
|
||||
if self._is_gis(model_field):
|
||||
return build_basic_type(OpenApiTypes.STR)
|
||||
|
||||
if not isinstance(model_field, models.Field):
|
||||
qs = auto_schema.view.get_queryset()
|
||||
model_field = qs.query.annotations[filter_field.field_name].field
|
||||
return auto_schema._map_model_field(model_field, direction=None)
|
||||
|
||||
def _get_field_description(self, filter_field, description):
|
||||
# Try to improve description beyond auto-generated model description
|
||||
if filter_field.extra.get('help_text', None):
|
||||
description = filter_field.extra['help_text']
|
||||
elif filter_field.label is not None:
|
||||
description = filter_field.label
|
||||
|
||||
choices = filter_field.extra.get('choices')
|
||||
if choices and callable(choices):
|
||||
# remove auto-generated enum list, since choices come from a callable
|
||||
if '\n\n*' in (description or ''):
|
||||
description, _, _ = description.partition('\n\n*')
|
||||
elif (description or '').startswith('* `'):
|
||||
description = ''
|
||||
return description
|
||||
|
||||
choice_description = ''
|
||||
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION and choices and not callable(choices):
|
||||
choice_description = build_choice_description_list(choices)
|
||||
|
||||
if not choices:
|
||||
return description
|
||||
|
||||
if not description:
|
||||
return choice_description
|
||||
|
||||
if '\n\n*' in description:
|
||||
description, _, _ = description.partition('\n\n*')
|
||||
return description + '\n\n' + choice_description
|
||||
|
||||
if description.startswith('* `'):
|
||||
return choice_description
|
||||
|
||||
return description + '\n\n' + choice_description
|
||||
|
||||
@classmethod
|
||||
def _is_gis(cls, field):
|
||||
if not getattr(cls, '_has_gis', True):
|
||||
return False
|
||||
try:
|
||||
from django.contrib.gis.db.models import GeometryField
|
||||
from rest_framework_gis.filters import GeometryFilter
|
||||
|
||||
return isinstance(field, (GeometryField, GeometryFilter))
|
||||
except: # noqa
|
||||
cls._has_gis = False
|
||||
return False
|
||||
@@ -0,0 +1,49 @@
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
|
||||
|
||||
class DjangoOAuthToolkitScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'oauth2_provider.contrib.rest_framework.OAuth2Authentication'
|
||||
name = 'oauth2'
|
||||
|
||||
def get_security_requirement(self, auto_schema):
|
||||
from oauth2_provider.contrib.rest_framework import (
|
||||
IsAuthenticatedOrTokenHasScope, TokenHasScope, TokenMatchesOASRequirements,
|
||||
)
|
||||
view = auto_schema.view
|
||||
request = view.request
|
||||
|
||||
for permission in auto_schema.view.get_permissions():
|
||||
if isinstance(permission, TokenMatchesOASRequirements):
|
||||
alt_scopes = permission.get_required_alternate_scopes(request, view)
|
||||
alt_scopes = alt_scopes.get(auto_schema.method, [])
|
||||
return [{self.name: group} for group in alt_scopes]
|
||||
if isinstance(permission, IsAuthenticatedOrTokenHasScope):
|
||||
return {self.name: TokenHasScope().get_scopes(request, view)}
|
||||
if isinstance(permission, TokenHasScope):
|
||||
# catch-all for subclasses of TokenHasScope like TokenHasReadWriteScope
|
||||
return {self.name: permission.get_scopes(request, view)}
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
from oauth2_provider.scopes import get_scopes_backend
|
||||
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
|
||||
flows = {}
|
||||
for flow_type in spectacular_settings.OAUTH2_FLOWS:
|
||||
flows[flow_type] = {}
|
||||
if flow_type in ('implicit', 'authorizationCode'):
|
||||
flows[flow_type]['authorizationUrl'] = spectacular_settings.OAUTH2_AUTHORIZATION_URL
|
||||
if flow_type in ('password', 'clientCredentials', 'authorizationCode'):
|
||||
flows[flow_type]['tokenUrl'] = spectacular_settings.OAUTH2_TOKEN_URL
|
||||
if spectacular_settings.OAUTH2_REFRESH_URL:
|
||||
flows[flow_type]['refreshUrl'] = spectacular_settings.OAUTH2_REFRESH_URL
|
||||
if spectacular_settings.OAUTH2_SCOPES:
|
||||
flows[flow_type]['scopes'] = spectacular_settings.OAUTH2_SCOPES
|
||||
else:
|
||||
scope_backend = get_scopes_backend()
|
||||
flows[flow_type]['scopes'] = scope_backend.get_all_scopes()
|
||||
|
||||
return {
|
||||
'type': 'oauth2',
|
||||
'flows': flows
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
|
||||
def camelize_serializer_fields(result, generator, request, public):
|
||||
from django.conf import settings
|
||||
from djangorestframework_camel_case.settings import api_settings
|
||||
from djangorestframework_camel_case.util import camelize_re, underscore_to_camel
|
||||
|
||||
# prunes subtrees from camelization based on owning field name
|
||||
ignore_fields = api_settings.JSON_UNDERSCOREIZE.get("ignore_fields") or ()
|
||||
# ignore certain field names while camelizing
|
||||
ignore_keys = api_settings.JSON_UNDERSCOREIZE.get("ignore_keys") or ()
|
||||
|
||||
def has_middleware_installed():
|
||||
try:
|
||||
from djangorestframework_camel_case.middleware import CamelCaseMiddleWare
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
for middleware in [import_string(m) for m in settings.MIDDLEWARE]:
|
||||
try:
|
||||
if issubclass(middleware, CamelCaseMiddleWare):
|
||||
return True
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
def camelize_str(key: str) -> str:
|
||||
new_key = re.sub(camelize_re, underscore_to_camel, key) if "_" in key else key
|
||||
if key in ignore_keys or new_key in ignore_keys:
|
||||
return key
|
||||
return new_key
|
||||
|
||||
def camelize_component(schema: dict, name: Optional[str] = None) -> dict:
|
||||
if name is not None and (name in ignore_fields or camelize_str(name) in ignore_fields):
|
||||
return schema
|
||||
elif schema.get('type') == 'object':
|
||||
if 'properties' in schema:
|
||||
schema['properties'] = {
|
||||
camelize_str(field_name): camelize_component(field_schema, field_name)
|
||||
for field_name, field_schema in schema['properties'].items()
|
||||
}
|
||||
if 'required' in schema:
|
||||
schema['required'] = [camelize_str(field) for field in schema['required']]
|
||||
elif schema.get('type') == 'array':
|
||||
camelize_component(schema['items'])
|
||||
return schema
|
||||
|
||||
for (_, component_type), component in generator.registry._components.items():
|
||||
if component_type == 'schemas':
|
||||
camelize_component(component.schema)
|
||||
|
||||
if has_middleware_installed():
|
||||
for url_schema in result["paths"].values():
|
||||
for method_schema in url_schema.values():
|
||||
for parameter in method_schema.get("parameters", []):
|
||||
parameter["name"] = camelize_str(parameter["name"])
|
||||
|
||||
# inplace modification of components also affect result dict, so regeneration is not necessary
|
||||
return result
|
||||
@@ -0,0 +1,13 @@
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from drf_spectacular.plumbing import build_bearer_security_scheme_object
|
||||
|
||||
|
||||
class KnoxTokenScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'knox.auth.TokenAuthentication'
|
||||
name = 'knoxApiToken'
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
return build_bearer_security_scheme_object(
|
||||
header_name='Authorization',
|
||||
token_prefix=self.target.authenticate_header(""),
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
from drf_spectacular.drainage import set_override, warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import ResolvedComponent, build_basic_type
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
|
||||
|
||||
class PydanticExtension(OpenApiSerializerExtension):
|
||||
"""
|
||||
Allows using pydantic models on @extend_schema(request=..., response=...) to
|
||||
describe your API.
|
||||
|
||||
We only have partial support for pydantic's version of dataclass, due to the way they
|
||||
are designed. The outermost class (the @extend_schema argument) has to be a subclass
|
||||
of pydantic.BaseModel. Inside this outermost BaseModel, any combination of dataclass
|
||||
and BaseModel can be used.
|
||||
"""
|
||||
|
||||
target_class = "pydantic.BaseModel"
|
||||
match_subclasses = True
|
||||
|
||||
def get_name(self, auto_schema, direction):
|
||||
# due to the fact that it is complicated to pull out every field member BaseModel class
|
||||
# of the entry model, we simply use the class name as string for object. This hack may
|
||||
# create false positive warnings, so turn it off. However, this may suppress correct
|
||||
# warnings involving the entry class.
|
||||
# TODO suppression may be migrated to new ComponentIdentity system
|
||||
set_override(self.target, 'suppress_collision_warning', True)
|
||||
return self.target.__name__
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
# let pydantic generate a JSON schema
|
||||
try:
|
||||
from pydantic.json_schema import model_json_schema
|
||||
except ImportError:
|
||||
warn("Only pydantic >= 2 is supported. defaulting to generic object.")
|
||||
return build_basic_type(OpenApiTypes.OBJECT)
|
||||
|
||||
schema = model_json_schema(self.target, ref_template="#/components/schemas/{model}", mode="serialization")
|
||||
|
||||
# pull out potential sub-schemas and put them into component section
|
||||
for sub_name, sub_schema in schema.pop("$defs", {}).items():
|
||||
component = ResolvedComponent(
|
||||
name=sub_name,
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=sub_name,
|
||||
schema=sub_schema,
|
||||
)
|
||||
auto_schema.registry.register_on_missing(component)
|
||||
|
||||
return schema
|
||||
@@ -0,0 +1,173 @@
|
||||
from django.conf import settings
|
||||
from django.utils.version import get_version_tuple
|
||||
from rest_framework import serializers
|
||||
|
||||
from drf_spectacular.contrib.rest_framework_simplejwt import (
|
||||
SimpleJWTScheme, TokenRefreshSerializerExtension,
|
||||
)
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiViewExtension
|
||||
from drf_spectacular.utils import extend_schema
|
||||
|
||||
|
||||
def get_dj_rest_auth_setting(class_name, setting_name):
|
||||
from dj_rest_auth.__version__ import __version__
|
||||
|
||||
if get_version_tuple(__version__) < (3, 0, 0):
|
||||
from dj_rest_auth import app_settings
|
||||
|
||||
return getattr(app_settings, class_name)
|
||||
else:
|
||||
from dj_rest_auth.app_settings import api_settings
|
||||
|
||||
return getattr(api_settings, setting_name)
|
||||
|
||||
|
||||
def get_token_serializer_class():
|
||||
from dj_rest_auth.__version__ import __version__
|
||||
|
||||
if get_version_tuple(__version__) < (3, 0, 0):
|
||||
use_jwt = getattr(settings, 'REST_USE_JWT', False)
|
||||
else:
|
||||
from dj_rest_auth.app_settings import api_settings
|
||||
|
||||
use_jwt = api_settings.USE_JWT
|
||||
|
||||
if use_jwt:
|
||||
return get_dj_rest_auth_setting('JWTSerializer', 'JWT_SERIALIZER')
|
||||
else:
|
||||
return get_dj_rest_auth_setting('TokenSerializer', 'TOKEN_SERIALIZER')
|
||||
|
||||
|
||||
class RestAuthDetailSerializer(serializers.Serializer):
|
||||
detail = serializers.CharField(read_only=True, required=False)
|
||||
|
||||
|
||||
class RestAuthDefaultResponseView(OpenApiViewExtension):
|
||||
def view_replacement(self):
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(responses=RestAuthDetailSerializer)
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class RestAuthLoginView(OpenApiViewExtension):
|
||||
target_class = 'dj_rest_auth.views.LoginView'
|
||||
|
||||
def view_replacement(self):
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(responses=get_token_serializer_class())
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class RestAuthLogoutView(OpenApiViewExtension):
|
||||
target_class = 'dj_rest_auth.views.LogoutView'
|
||||
|
||||
def view_replacement(self):
|
||||
if getattr(settings, 'ACCOUNT_LOGOUT_ON_GET', None):
|
||||
get_schema_params = {'responses': RestAuthDetailSerializer}
|
||||
else:
|
||||
get_schema_params = {'exclude': True}
|
||||
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(**get_schema_params)
|
||||
def get(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
@extend_schema(request=None, responses=RestAuthDetailSerializer)
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class RestAuthPasswordChangeView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.views.PasswordChangeView'
|
||||
|
||||
|
||||
class RestAuthPasswordResetView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.views.PasswordResetView'
|
||||
|
||||
|
||||
class RestAuthPasswordResetConfirmView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.views.PasswordResetConfirmView'
|
||||
|
||||
|
||||
class RestAuthVerifyEmailView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.registration.views.VerifyEmailView'
|
||||
optional = True
|
||||
|
||||
|
||||
class RestAuthResendEmailVerificationView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.registration.views.ResendEmailVerificationView'
|
||||
optional = True
|
||||
|
||||
|
||||
class RestAuthJWTSerializer(OpenApiSerializerExtension):
|
||||
target_class = 'dj_rest_auth.serializers.JWTSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
class Fixed(self.target_class):
|
||||
user = get_dj_rest_auth_setting('UserDetailsSerializer', 'USER_DETAILS_SERIALIZER')()
|
||||
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class CookieTokenRefreshSerializerExtension(TokenRefreshSerializerExtension):
|
||||
target_class = 'dj_rest_auth.jwt_auth.CookieTokenRefreshSerializer'
|
||||
optional = True
|
||||
|
||||
def get_name(self):
|
||||
return 'TokenRefresh'
|
||||
|
||||
|
||||
class RestAuthRegisterView(OpenApiViewExtension):
|
||||
target_class = 'dj_rest_auth.registration.views.RegisterView'
|
||||
optional = True
|
||||
|
||||
def view_replacement(self):
|
||||
from allauth.account.app_settings import EMAIL_VERIFICATION, EmailVerificationMethod
|
||||
|
||||
if EMAIL_VERIFICATION == EmailVerificationMethod.MANDATORY:
|
||||
response_serializer = RestAuthDetailSerializer
|
||||
else:
|
||||
response_serializer = get_token_serializer_class()
|
||||
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(responses=response_serializer)
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class SimpleJWTCookieScheme(SimpleJWTScheme):
|
||||
target_class = 'dj_rest_auth.jwt_auth.JWTCookieAuthentication'
|
||||
optional = True
|
||||
name = ['jwtHeaderAuth', 'jwtCookieAuth'] # type: ignore
|
||||
|
||||
def get_security_requirement(self, auto_schema):
|
||||
return [{name: []} for name in self.name]
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
cookie_name = get_dj_rest_auth_setting('JWT_AUTH_COOKIE', 'JWT_AUTH_COOKIE')
|
||||
if not cookie_name:
|
||||
cookie_name = 'jwt-auth'
|
||||
warn(
|
||||
f'"JWT_AUTH_COOKIE" setting required for JWTCookieAuthentication. '
|
||||
f'defaulting to {cookie_name}'
|
||||
)
|
||||
|
||||
return [
|
||||
super().get_security_definition(auto_schema), # JWT from header
|
||||
{
|
||||
'type': 'apiKey',
|
||||
'in': 'cookie',
|
||||
'name': cookie_name,
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
from drf_spectacular.extensions import OpenApiViewExtension
|
||||
|
||||
|
||||
class ObtainAuthTokenView(OpenApiViewExtension):
|
||||
target_class = 'rest_framework.authtoken.views.ObtainAuthToken'
|
||||
match_subclasses = True
|
||||
|
||||
def view_replacement(self):
|
||||
"""
|
||||
Prior to DRF 3.12.0, usage of ObtainAuthToken resulted in AssertionError
|
||||
|
||||
Incompatible AutoSchema used on View "ObtainAuthToken". Is DRF's DEFAULT_SCHEMA_CLASS ...
|
||||
|
||||
This is because DRF had a bug which made it NOT honor DEFAULT_SCHEMA_CLASS and instead
|
||||
injected an unsolicited coreschema class for this view and this view only. This extension
|
||||
fixes the view before the wrong schema class is used.
|
||||
|
||||
Bug in DRF that was fixed in later versions:
|
||||
https://github.com/encode/django-rest-framework/blob/4121b01b912668c049b26194a9a107c27a332429/rest_framework/authtoken/views.py#L16
|
||||
"""
|
||||
from rest_framework import VERSION
|
||||
|
||||
from drf_spectacular.openapi import AutoSchema
|
||||
|
||||
# no intervention needed
|
||||
if VERSION >= '3.12':
|
||||
return self.target
|
||||
|
||||
class FixedObtainAuthToken(self.target):
|
||||
schema = AutoSchema()
|
||||
|
||||
return FixedObtainAuthToken
|
||||
@@ -0,0 +1,36 @@
|
||||
from typing import Any
|
||||
|
||||
from drf_spectacular.drainage import get_override, has_override
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import ComponentIdentity, get_doc
|
||||
from drf_spectacular.utils import Direction
|
||||
|
||||
|
||||
class OpenApiDataclassSerializerExtensions(OpenApiSerializerExtension):
|
||||
target_class = "rest_framework_dataclasses.serializers.DataclassSerializer"
|
||||
match_subclasses = True
|
||||
|
||||
def get_name(self):
|
||||
"""Use the dataclass name in the schema, instead of the serializer prefix (which can be just Dataclass)."""
|
||||
if has_override(self.target, 'component_name'):
|
||||
return get_override(self.target, 'component_name')
|
||||
if getattr(getattr(self.target, 'Meta', None), 'ref_name', None) is not None:
|
||||
return self.target.Meta.ref_name
|
||||
if has_override(self.target.dataclass_definition.dataclass_type, 'component_name'):
|
||||
return get_override(self.target.dataclass_definition.dataclass_type, 'component_name')
|
||||
return self.target.dataclass_definition.dataclass_type.__name__
|
||||
|
||||
def get_identity(self, auto_schema, direction: Direction) -> Any:
|
||||
return ComponentIdentity(self.target.dataclass_definition.dataclass_type)
|
||||
|
||||
def strip_library_doc(self, schema):
|
||||
"""Strip the DataclassSerializer library documentation from the schema."""
|
||||
from rest_framework_dataclasses.serializers import DataclassSerializer
|
||||
if 'description' in schema and schema['description'] == get_doc(DataclassSerializer):
|
||||
del schema['description']
|
||||
return schema
|
||||
|
||||
def map_serializer(self, auto_schema, direction: Direction):
|
||||
""""Generate the schema for a DataclassSerializer."""
|
||||
schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
|
||||
return self.strip_library_doc(schema)
|
||||
@@ -0,0 +1,219 @@
|
||||
from rest_framework.utils.model_meta import get_field_info
|
||||
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.plumbing import (
|
||||
ResolvedComponent, build_array_type, build_object_type, follow_field_source, get_doc,
|
||||
)
|
||||
|
||||
|
||||
def build_point_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "number", "format": "float"},
|
||||
"example": [12.9721, 77.5933],
|
||||
"minItems": 2,
|
||||
"maxItems": 3,
|
||||
}
|
||||
|
||||
|
||||
def build_linestring_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": build_point_schema(),
|
||||
"example": [[22.4707, 70.0577], [12.9721, 77.5933]],
|
||||
"minItems": 2,
|
||||
}
|
||||
|
||||
|
||||
def build_polygon_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {**build_linestring_schema(), "minItems": 4},
|
||||
"example": [
|
||||
[
|
||||
[0.0, 0.0],
|
||||
[0.0, 50.0],
|
||||
[50.0, 50.0],
|
||||
[50.0, 0.0],
|
||||
[0.0, 0.0],
|
||||
],
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_geo_container_schema(name, coords):
|
||||
return build_object_type(
|
||||
properties={
|
||||
"type": {"type": "string", "enum": [name]},
|
||||
"coordinates": coords,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def build_point_geo_schema():
|
||||
return build_geo_container_schema("Point", build_point_schema())
|
||||
|
||||
|
||||
def build_linestring_geo_schema():
|
||||
return build_geo_container_schema("LineString", build_linestring_schema())
|
||||
|
||||
|
||||
def build_polygon_geo_schema():
|
||||
return build_geo_container_schema("Polygon", build_polygon_schema())
|
||||
|
||||
|
||||
def build_geometry_geo_schema():
|
||||
return {
|
||||
'oneOf': [
|
||||
build_point_geo_schema(),
|
||||
build_linestring_geo_schema(),
|
||||
build_polygon_geo_schema(),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_bbox_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"minItems": 4,
|
||||
"maxItems": 4,
|
||||
"example": [12.9721, 77.5933, 12.9721, 77.5933],
|
||||
}
|
||||
|
||||
|
||||
def build_geo_schema(model_field):
|
||||
from django.contrib.gis.db import models
|
||||
|
||||
if isinstance(model_field, models.PointField):
|
||||
return build_point_geo_schema()
|
||||
elif isinstance(model_field, models.LineStringField):
|
||||
return build_linestring_geo_schema()
|
||||
elif isinstance(model_field, models.PolygonField):
|
||||
return build_polygon_geo_schema()
|
||||
elif isinstance(model_field, models.MultiPointField):
|
||||
return build_geo_container_schema(
|
||||
"MultiPoint", build_array_type(build_point_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.MultiLineStringField):
|
||||
return build_geo_container_schema(
|
||||
"MultiLineString", build_array_type(build_linestring_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.MultiPolygonField):
|
||||
return build_geo_container_schema(
|
||||
"MultiPolygon", build_array_type(build_polygon_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.GeometryCollectionField):
|
||||
return build_geo_container_schema(
|
||||
"GeometryCollection", build_array_type(build_geometry_geo_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.GeometryField):
|
||||
return build_geometry_geo_schema()
|
||||
else:
|
||||
warn("Encountered unknown GIS geometry field")
|
||||
return {}
|
||||
|
||||
|
||||
def map_geo_field(serializer, geo_field_name):
|
||||
from rest_framework_gis.fields import GeometrySerializerMethodField
|
||||
|
||||
field = serializer.fields[geo_field_name]
|
||||
if isinstance(field, GeometrySerializerMethodField):
|
||||
warn("Geometry generation for GeometrySerializerMethodField is not supported.")
|
||||
return {}
|
||||
model_field = get_field_info(serializer.Meta.model).fields[geo_field_name]
|
||||
return build_geo_schema(model_field)
|
||||
|
||||
|
||||
def _inject_enum_collision_fix(collection):
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
if not collection and 'GisFeatureEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
|
||||
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureEnum'] = ('Feature',)
|
||||
if collection and 'GisFeatureCollectionEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
|
||||
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureCollectionEnum'] = ('FeatureCollection',)
|
||||
|
||||
|
||||
class GeoFeatureModelSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_gis.serializers.GeoFeatureModelSerializer'
|
||||
match_subclasses = True
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
_inject_enum_collision_fix(collection=False)
|
||||
|
||||
base_schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
|
||||
return self.map_geo_feature_model_serializer(self.target, base_schema)
|
||||
|
||||
def map_geo_feature_model_serializer(self, serializer, base_schema):
|
||||
from rest_framework_gis.serializers import GeoFeatureModelSerializer
|
||||
|
||||
geo_properties = {
|
||||
"type": {"type": "string", "enum": ["Feature"]}
|
||||
}
|
||||
if serializer.Meta.id_field:
|
||||
geo_properties["id"] = base_schema["properties"].pop(serializer.Meta.id_field)
|
||||
|
||||
geo_properties["geometry"] = map_geo_field(serializer, serializer.Meta.geo_field)
|
||||
base_schema["properties"].pop(serializer.Meta.geo_field)
|
||||
|
||||
if serializer.Meta.auto_bbox or serializer.Meta.bbox_geo_field:
|
||||
geo_properties["bbox"] = build_bbox_schema()
|
||||
base_schema["properties"].pop(serializer.Meta.bbox_geo_field, None)
|
||||
|
||||
# only expose if description comes from the user
|
||||
description = base_schema.pop('description', None)
|
||||
if description == get_doc(GeoFeatureModelSerializer):
|
||||
description = None
|
||||
|
||||
# ignore this aspect for now
|
||||
base_schema.pop('required', None)
|
||||
|
||||
# nest remaining fields under property "properties"
|
||||
geo_properties["properties"] = base_schema
|
||||
|
||||
return build_object_type(
|
||||
properties=geo_properties,
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
class GeoFeatureModelListSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_gis.serializers.GeoFeatureModelListSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
_inject_enum_collision_fix(collection=True)
|
||||
|
||||
# build/retrieve feature component generated by GeoFeatureModelSerializerExtension.
|
||||
# wrap the ref in the special list structure and build another component based on that.
|
||||
feature_component = auto_schema.resolve_serializer(self.target.child, direction)
|
||||
collection_schema = build_object_type(
|
||||
properties={
|
||||
"type": {"type": "string", "enum": ["FeatureCollection"]},
|
||||
"features": build_array_type(feature_component.ref)
|
||||
}
|
||||
)
|
||||
list_component = ResolvedComponent(
|
||||
name=f'{feature_component.name}List',
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=self.target.child,
|
||||
schema=collection_schema
|
||||
)
|
||||
auto_schema.registry.register_on_missing(list_component)
|
||||
return list_component.ref
|
||||
|
||||
|
||||
class GeometryFieldExtension(OpenApiSerializerFieldExtension):
|
||||
target_class = 'rest_framework_gis.fields.GeometryField'
|
||||
match_subclasses = True
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
# running this extension for GeoFeatureModelSerializer's geo_field is superfluous
|
||||
# as above extension already handles that individually. We run it anyway because
|
||||
# robustly checking the proper condition is harder.
|
||||
try:
|
||||
model = self.target.parent.Meta.model
|
||||
model_field = follow_field_source(model, self.target.source.split('.'))
|
||||
return build_geo_schema(model_field)
|
||||
except: # noqa: E722
|
||||
warn(f'Encountered an issue resolving field {self.target}. defaulting to generic object.')
|
||||
return {}
|
||||
@@ -0,0 +1,16 @@
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from drf_spectacular.plumbing import build_bearer_security_scheme_object
|
||||
|
||||
|
||||
class JWTScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'rest_framework_jwt.authentication.JSONWebTokenAuthentication'
|
||||
name = 'jwtAuth'
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
from rest_framework_jwt.settings import api_settings
|
||||
|
||||
return build_bearer_security_scheme_object(
|
||||
header_name='AUTHORIZATION',
|
||||
token_prefix=api_settings.JWT_AUTH_HEADER_PREFIX,
|
||||
bearer_format='JWT'
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.plumbing import build_array_type, is_list_serializer
|
||||
|
||||
|
||||
class RecursiveFieldExtension(OpenApiSerializerFieldExtension):
|
||||
target_class = "rest_framework_recursive.fields.RecursiveField"
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
proxied = self.target.proxied
|
||||
|
||||
if is_list_serializer(proxied):
|
||||
component = auto_schema.resolve_serializer(proxied.child, direction)
|
||||
return build_array_type(component.ref)
|
||||
|
||||
component = auto_schema.resolve_serializer(proxied, direction)
|
||||
return component.ref
|
||||
@@ -0,0 +1,87 @@
|
||||
from rest_framework import serializers
|
||||
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension, OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import build_bearer_security_scheme_object
|
||||
from drf_spectacular.utils import inline_serializer
|
||||
|
||||
|
||||
class TokenObtainPairSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenObtainPairSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
Fixed = inline_serializer('Fixed', fields={
|
||||
self.target_class.username_field: serializers.CharField(write_only=True),
|
||||
'password': serializers.CharField(write_only=True),
|
||||
'access': serializers.CharField(read_only=True),
|
||||
'refresh': serializers.CharField(read_only=True),
|
||||
})
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class TokenObtainSlidingSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
Fixed = inline_serializer('Fixed', fields={
|
||||
self.target_class.username_field: serializers.CharField(write_only=True),
|
||||
'password': serializers.CharField(write_only=True),
|
||||
'token': serializers.CharField(read_only=True),
|
||||
})
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class TokenRefreshSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenRefreshSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
from rest_framework_simplejwt.settings import api_settings
|
||||
|
||||
if api_settings.ROTATE_REFRESH_TOKENS:
|
||||
class Fixed(serializers.Serializer):
|
||||
access = serializers.CharField(read_only=True)
|
||||
refresh = serializers.CharField()
|
||||
else:
|
||||
class Fixed(serializers.Serializer):
|
||||
access = serializers.CharField(read_only=True)
|
||||
refresh = serializers.CharField(write_only=True)
|
||||
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class TokenVerifySerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenVerifySerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
Fixed = inline_serializer('Fixed', fields={
|
||||
'token': serializers.CharField(write_only=True),
|
||||
})
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class SimpleJWTScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'rest_framework_simplejwt.authentication.JWTAuthentication'
|
||||
name = 'jwtAuth'
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
from rest_framework_simplejwt.settings import api_settings
|
||||
|
||||
if len(api_settings.AUTH_HEADER_TYPES) > 1:
|
||||
warn(
|
||||
f'OpenAPI3 can only have one "bearerFormat". JWT Settings specify '
|
||||
f'{api_settings.AUTH_HEADER_TYPES}. Using the first one.'
|
||||
)
|
||||
|
||||
return build_bearer_security_scheme_object(
|
||||
header_name=getattr(api_settings, 'AUTH_HEADER_NAME', 'HTTP_AUTHORIZATION'),
|
||||
token_prefix=api_settings.AUTH_HEADER_TYPES[0],
|
||||
bearer_format='JWT'
|
||||
)
|
||||
|
||||
|
||||
class SimpleJWTTokenUserScheme(SimpleJWTScheme):
|
||||
target_class = 'rest_framework_simplejwt.authentication.JWTTokenUserAuthentication'
|
||||
|
||||
|
||||
class SimpleJWTStatelessUserScheme(SimpleJWTScheme):
|
||||
target_class = "rest_framework_simplejwt.authentication.JWTStatelessUserAuthentication"
|
||||
@@ -0,0 +1,81 @@
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import (
|
||||
ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type,
|
||||
is_patched_serializer,
|
||||
)
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
|
||||
|
||||
class PolymorphicSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_polymorphic.serializers.PolymorphicSerializer'
|
||||
match_subclasses = True
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
sub_components = []
|
||||
serializer = self.target
|
||||
|
||||
for sub_model in serializer.model_serializer_mapping:
|
||||
sub_serializer = serializer._get_serializer_from_model_or_instance(sub_model)
|
||||
sub_serializer.partial = serializer.partial
|
||||
resource_type = serializer.to_resource_type(sub_model)
|
||||
component = auto_schema.resolve_serializer(sub_serializer, direction)
|
||||
if not component:
|
||||
# rebuild a virtual schema-less component to model empty serializers
|
||||
component = ResolvedComponent(
|
||||
name=auto_schema._get_serializer_name(sub_serializer, direction),
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=ComponentIdentity('virtual')
|
||||
)
|
||||
typed_component = self.build_typed_component(
|
||||
auto_schema=auto_schema,
|
||||
component=component,
|
||||
resource_type_field_name=serializer.resource_type_field_name,
|
||||
patched=is_patched_serializer(sub_serializer, direction)
|
||||
)
|
||||
sub_components.append((resource_type, typed_component.ref))
|
||||
|
||||
if not resource_type:
|
||||
warn(
|
||||
f'discriminator mapping key is empty for {sub_serializer.__class__}. '
|
||||
f'this might lead to code generation issues.'
|
||||
)
|
||||
|
||||
one_of_list = []
|
||||
for _, ref in sub_components:
|
||||
if ref not in one_of_list:
|
||||
one_of_list.append(ref)
|
||||
|
||||
return {
|
||||
'oneOf': one_of_list,
|
||||
'discriminator': {
|
||||
'propertyName': serializer.resource_type_field_name,
|
||||
'mapping': {resource_type: ref['$ref'] for resource_type, ref in sub_components},
|
||||
}
|
||||
}
|
||||
|
||||
def build_typed_component(self, auto_schema, component, resource_type_field_name, patched):
|
||||
if spectacular_settings.COMPONENT_SPLIT_REQUEST and component.name.endswith('Request'):
|
||||
typed_component_name = component.name[:-len('Request')] + 'TypedRequest'
|
||||
else:
|
||||
typed_component_name = f'{component.name}Typed'
|
||||
|
||||
resource_type_schema = build_object_type(
|
||||
properties={resource_type_field_name: build_basic_type(OpenApiTypes.STR)},
|
||||
required=None if patched else [resource_type_field_name]
|
||||
)
|
||||
# if sub-serializer has an empty schema, only expose the resource_type field part
|
||||
if component.schema:
|
||||
schema = {'allOf': [resource_type_schema, component.ref]}
|
||||
else:
|
||||
schema = resource_type_schema
|
||||
|
||||
component_typed = ResolvedComponent(
|
||||
name=typed_component_name,
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=component.object,
|
||||
schema=schema,
|
||||
)
|
||||
auto_schema.registry.register_on_missing(component_typed)
|
||||
return component_typed
|
||||
Reference in New Issue
Block a user