This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,15 @@
__all__ = [
'django_oauth_toolkit',
'djangorestframework_camel_case',
'rest_auth',
'rest_framework',
'rest_polymorphic',
'rest_framework_dataclasses',
'rest_framework_jwt',
'rest_framework_simplejwt',
'django_filters',
'rest_framework_recursive',
'rest_framework_gis',
'pydantic',
'knox_auth_token',
]

View File

@@ -0,0 +1,305 @@
from django.db import models
from drf_spectacular.drainage import add_trace_message, get_override, has_override, warn
from drf_spectacular.extensions import OpenApiFilterExtension
from drf_spectacular.plumbing import (
build_array_type, build_basic_type, build_choice_description_list, build_parameter_type,
follow_field_source, force_instance, get_manager, get_type_hints, get_view_model, is_basic_type,
is_field,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
_NoHint = object()
class DjangoFilterExtension(OpenApiFilterExtension):
"""
Extensions that specifically deals with ``django-filter`` fields. The introspection
attempts to estimate the underlying model field types to generate filter types.
However, there are under-specified filter fields for which heuristics need to be performed.
This serves as an explicit list of all partially-handled filter fields:
- ``AllValuesFilter``: skip choices to prevent DB query
- ``AllValuesMultipleFilter``: skip choices to prevent DB query, multi handled though
- ``ChoiceFilter``: enum handled, type under-specified
- ``DateRangeFilter``: N/A
- ``LookupChoiceFilter``: N/A
- ``ModelChoiceFilter``: enum handled
- ``ModelMultipleChoiceFilter``: enum, multi handled
- ``MultipleChoiceFilter``: enum, multi handled
- ``RangeFilter``: min/max handled, type under-specified
- ``TypedChoiceFilter``: enum handled
- ``TypedMultipleChoiceFilter``: enum, multi handled
In case of warnings or incorrect filter types, you can manually override the underlying
field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a
filter method for your filter field, you can attach ``extend_schema_field`` to that filter
method.
.. code-block::
class SomeFilter(FilterSet):
some_field = extend_schema_field(OpenApiTypes.NUMBER)(
RangeFilter(field_name='some_manually_annotated_field_in_qs')
)
"""
target_class = 'django_filters.rest_framework.DjangoFilterBackend'
match_subclasses = True
def get_schema_operation_parameters(self, auto_schema, *args, **kwargs):
model = get_view_model(auto_schema.view)
if not model:
return []
filterset_class = self.target.get_filterset_class(auto_schema.view, get_manager(model).none())
if not filterset_class:
return []
result = []
with add_trace_message(filterset_class):
for field_name, filter_field in filterset_class.base_filters.items():
result += self.resolve_filter_field(
auto_schema, model, filterset_class, field_name, filter_field
)
return result
def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, filter_field):
from django_filters import filters
unambiguous_mapping = {
filters.CharFilter: OpenApiTypes.STR,
filters.BooleanFilter: OpenApiTypes.BOOL,
filters.DateFilter: OpenApiTypes.DATE,
filters.DateTimeFilter: OpenApiTypes.DATETIME,
filters.IsoDateTimeFilter: OpenApiTypes.DATETIME,
filters.TimeFilter: OpenApiTypes.TIME,
filters.UUIDFilter: OpenApiTypes.UUID,
filters.DurationFilter: OpenApiTypes.DURATION,
filters.OrderingFilter: OpenApiTypes.STR,
filters.TimeRangeFilter: OpenApiTypes.TIME,
filters.DateFromToRangeFilter: OpenApiTypes.DATE,
filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
}
filter_method = self._get_filter_method(filterset_class, filter_field)
filter_method_hint = self._get_filter_method_hint(filter_method)
filter_choices = self._get_explicit_filter_choices(filter_field)
schema_from_override = False
if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
schema_from_override = True
annotation = (
get_override(filter_field, 'field') or get_override(filter_method, 'field')
)
if is_basic_type(annotation):
schema = build_basic_type(annotation)
elif isinstance(annotation, dict):
# allow injecting raw schema via @extend_schema_field decorator
schema = annotation.copy()
elif is_field(annotation):
schema = auto_schema._map_serializer_field(force_instance(annotation), "request")
else:
warn(
f"Unsupported annotation {annotation} on filter field {field_name}. defaulting to string."
)
schema = build_basic_type(OpenApiTypes.STR)
elif filter_method_hint is not _NoHint:
if is_basic_type(filter_method_hint):
schema = build_basic_type(filter_method_hint)
else:
schema = build_basic_type(OpenApiTypes.STR)
elif isinstance(filter_field, tuple(unambiguous_mapping)):
for cls in filter_field.__class__.__mro__:
if cls in unambiguous_mapping:
schema = build_basic_type(unambiguous_mapping[cls])
break
elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)):
# NumberField is underspecified by itself. try to find the
# type that makes the most sense or default to generic NUMBER
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
schema = build_basic_type(OpenApiTypes.NUMBER)
elif isinstance(filter_field, (filters.ChoiceFilter, filters.MultipleChoiceFilter)):
try:
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
except Exception:
if filter_choices and is_basic_type(type(filter_choices[0])):
# fallback to type guessing from first choice element
schema = build_basic_type(type(filter_choices[0]))
else:
warn(
f'Unable to guess choice types from values, filter method\'s type hint '
f'or find "{field_name}" in model. Defaulting to string.'
)
schema = build_basic_type(OpenApiTypes.STR)
else:
# the last resort is to look up the type via the model or queryset field
# and emit a warning if we were unsuccessful.
try:
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
except Exception as exc: # pragma: no cover
warn(
f'Exception raised while trying resolve model field for django-filter '
f'field "{field_name}". Defaulting to string (Exception: {exc})'
)
schema = build_basic_type(OpenApiTypes.STR)
# primary keys are usually non-editable (readOnly=True) and map_model_field correctly
# signals that attribute. however this does not apply in this context.
schema.pop('readOnly', None)
# enrich schema with additional info from filter_field
enum = schema.pop('enum', None)
# explicit filter choices may disable enum retrieved from model
if not schema_from_override and filter_choices is not None:
enum = filter_choices
description = schema.pop('description', None)
if not schema_from_override:
description = self._get_field_description(filter_field, description)
# parameter style variations based on filter base class
if isinstance(filter_field, filters.BaseCSVFilter):
schema = build_array_type(schema)
field_names = [field_name]
explode = False
style = 'form'
elif isinstance(filter_field, filters.MultipleChoiceFilter):
schema = build_array_type(schema)
field_names = [field_name]
explode = True
style = 'form'
elif isinstance(filter_field, (filters.RangeFilter, filters.NumericRangeFilter)):
try:
suffixes = filter_field.field_class.widget.suffixes
except AttributeError:
suffixes = ['min', 'max']
field_names = [
f'{field_name}_{suffix}' if suffix else field_name for suffix in suffixes
]
explode = None
style = None
else:
field_names = [field_name]
explode = None
style = None
return [
build_parameter_type(
name=field_name,
required=filter_field.extra['required'],
location=OpenApiParameter.QUERY,
description=description,
schema=schema,
enum=enum,
explode=explode,
style=style
)
for field_name in field_names
]
def _get_filter_method(self, filterset_class, filter_field):
if callable(filter_field.method):
return filter_field.method
elif isinstance(filter_field.method, str):
return getattr(filterset_class, filter_field.method)
else:
return None
def _get_filter_method_hint(self, filter_method):
try:
return get_type_hints(filter_method)['value']
except: # noqa: E722
return _NoHint
def _get_explicit_filter_choices(self, filter_field):
if 'choices' not in filter_field.extra:
return None
elif callable(filter_field.extra['choices']):
# choices function may utilize the DB, so refrain from actually calling it.
return []
else:
return [c for c, _ in filter_field.extra['choices']]
def _get_model_field(self, filter_field, model):
if not filter_field.field_name:
return None
path = filter_field.field_name.split('__')
return follow_field_source(model, path, emit_warnings=False)
def _get_schema_from_model_field(self, auto_schema, filter_field, model):
# Has potential to throw exceptions. Needs to be wrapped in try/except!
#
# first search for the field in the model as this has the least amount of
# potential side effects. Only after that fails, attempt to call
# get_queryset() to check for potential query annotations.
model_field = self._get_model_field(filter_field, model)
# this is a cross feature between rest-framework-gis and django-filter. Regular
# behavior needs to be sidestepped as the model information is lost down the line.
# TODO for now this will be just a string to cover WKT, WKB, and urlencoded GeoJSON
# build_geo_schema(model_field) would yield the correct result
if self._is_gis(model_field):
return build_basic_type(OpenApiTypes.STR)
if not isinstance(model_field, models.Field):
qs = auto_schema.view.get_queryset()
model_field = qs.query.annotations[filter_field.field_name].field
return auto_schema._map_model_field(model_field, direction=None)
def _get_field_description(self, filter_field, description):
# Try to improve description beyond auto-generated model description
if filter_field.extra.get('help_text', None):
description = filter_field.extra['help_text']
elif filter_field.label is not None:
description = filter_field.label
choices = filter_field.extra.get('choices')
if choices and callable(choices):
# remove auto-generated enum list, since choices come from a callable
if '\n\n*' in (description or ''):
description, _, _ = description.partition('\n\n*')
elif (description or '').startswith('* `'):
description = ''
return description
choice_description = ''
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION and choices and not callable(choices):
choice_description = build_choice_description_list(choices)
if not choices:
return description
if not description:
return choice_description
if '\n\n*' in description:
description, _, _ = description.partition('\n\n*')
return description + '\n\n' + choice_description
if description.startswith('* `'):
return choice_description
return description + '\n\n' + choice_description
@classmethod
def _is_gis(cls, field):
if not getattr(cls, '_has_gis', True):
return False
try:
from django.contrib.gis.db.models import GeometryField
from rest_framework_gis.filters import GeometryFilter
return isinstance(field, (GeometryField, GeometryFilter))
except: # noqa
cls._has_gis = False
return False

View File

@@ -0,0 +1,49 @@
from drf_spectacular.extensions import OpenApiAuthenticationExtension
class DjangoOAuthToolkitScheme(OpenApiAuthenticationExtension):
target_class = 'oauth2_provider.contrib.rest_framework.OAuth2Authentication'
name = 'oauth2'
def get_security_requirement(self, auto_schema):
from oauth2_provider.contrib.rest_framework import (
IsAuthenticatedOrTokenHasScope, TokenHasScope, TokenMatchesOASRequirements,
)
view = auto_schema.view
request = view.request
for permission in auto_schema.view.get_permissions():
if isinstance(permission, TokenMatchesOASRequirements):
alt_scopes = permission.get_required_alternate_scopes(request, view)
alt_scopes = alt_scopes.get(auto_schema.method, [])
return [{self.name: group} for group in alt_scopes]
if isinstance(permission, IsAuthenticatedOrTokenHasScope):
return {self.name: TokenHasScope().get_scopes(request, view)}
if isinstance(permission, TokenHasScope):
# catch-all for subclasses of TokenHasScope like TokenHasReadWriteScope
return {self.name: permission.get_scopes(request, view)}
def get_security_definition(self, auto_schema):
from oauth2_provider.scopes import get_scopes_backend
from drf_spectacular.settings import spectacular_settings
flows = {}
for flow_type in spectacular_settings.OAUTH2_FLOWS:
flows[flow_type] = {}
if flow_type in ('implicit', 'authorizationCode'):
flows[flow_type]['authorizationUrl'] = spectacular_settings.OAUTH2_AUTHORIZATION_URL
if flow_type in ('password', 'clientCredentials', 'authorizationCode'):
flows[flow_type]['tokenUrl'] = spectacular_settings.OAUTH2_TOKEN_URL
if spectacular_settings.OAUTH2_REFRESH_URL:
flows[flow_type]['refreshUrl'] = spectacular_settings.OAUTH2_REFRESH_URL
if spectacular_settings.OAUTH2_SCOPES:
flows[flow_type]['scopes'] = spectacular_settings.OAUTH2_SCOPES
else:
scope_backend = get_scopes_backend()
flows[flow_type]['scopes'] = scope_backend.get_all_scopes()
return {
'type': 'oauth2',
'flows': flows
}

View File

@@ -0,0 +1,62 @@
import re
from typing import Optional
from django.utils.module_loading import import_string
def camelize_serializer_fields(result, generator, request, public):
from django.conf import settings
from djangorestframework_camel_case.settings import api_settings
from djangorestframework_camel_case.util import camelize_re, underscore_to_camel
# prunes subtrees from camelization based on owning field name
ignore_fields = api_settings.JSON_UNDERSCOREIZE.get("ignore_fields") or ()
# ignore certain field names while camelizing
ignore_keys = api_settings.JSON_UNDERSCOREIZE.get("ignore_keys") or ()
def has_middleware_installed():
try:
from djangorestframework_camel_case.middleware import CamelCaseMiddleWare
except ImportError:
return False
for middleware in [import_string(m) for m in settings.MIDDLEWARE]:
try:
if issubclass(middleware, CamelCaseMiddleWare):
return True
except TypeError:
pass
def camelize_str(key: str) -> str:
new_key = re.sub(camelize_re, underscore_to_camel, key) if "_" in key else key
if key in ignore_keys or new_key in ignore_keys:
return key
return new_key
def camelize_component(schema: dict, name: Optional[str] = None) -> dict:
if name is not None and (name in ignore_fields or camelize_str(name) in ignore_fields):
return schema
elif schema.get('type') == 'object':
if 'properties' in schema:
schema['properties'] = {
camelize_str(field_name): camelize_component(field_schema, field_name)
for field_name, field_schema in schema['properties'].items()
}
if 'required' in schema:
schema['required'] = [camelize_str(field) for field in schema['required']]
elif schema.get('type') == 'array':
camelize_component(schema['items'])
return schema
for (_, component_type), component in generator.registry._components.items():
if component_type == 'schemas':
camelize_component(component.schema)
if has_middleware_installed():
for url_schema in result["paths"].values():
for method_schema in url_schema.values():
for parameter in method_schema.get("parameters", []):
parameter["name"] = camelize_str(parameter["name"])
# inplace modification of components also affect result dict, so regeneration is not necessary
return result

View File

@@ -0,0 +1,13 @@
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from drf_spectacular.plumbing import build_bearer_security_scheme_object
class KnoxTokenScheme(OpenApiAuthenticationExtension):
target_class = 'knox.auth.TokenAuthentication'
name = 'knoxApiToken'
def get_security_definition(self, auto_schema):
return build_bearer_security_scheme_object(
header_name='Authorization',
token_prefix=self.target.authenticate_header(""),
)

View File

@@ -0,0 +1,50 @@
from drf_spectacular.drainage import set_override, warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import ResolvedComponent, build_basic_type
from drf_spectacular.types import OpenApiTypes
class PydanticExtension(OpenApiSerializerExtension):
"""
Allows using pydantic models on @extend_schema(request=..., response=...) to
describe your API.
We only have partial support for pydantic's version of dataclass, due to the way they
are designed. The outermost class (the @extend_schema argument) has to be a subclass
of pydantic.BaseModel. Inside this outermost BaseModel, any combination of dataclass
and BaseModel can be used.
"""
target_class = "pydantic.BaseModel"
match_subclasses = True
def get_name(self, auto_schema, direction):
# due to the fact that it is complicated to pull out every field member BaseModel class
# of the entry model, we simply use the class name as string for object. This hack may
# create false positive warnings, so turn it off. However, this may suppress correct
# warnings involving the entry class.
# TODO suppression may be migrated to new ComponentIdentity system
set_override(self.target, 'suppress_collision_warning', True)
return self.target.__name__
def map_serializer(self, auto_schema, direction):
# let pydantic generate a JSON schema
try:
from pydantic.json_schema import model_json_schema
except ImportError:
warn("Only pydantic >= 2 is supported. defaulting to generic object.")
return build_basic_type(OpenApiTypes.OBJECT)
schema = model_json_schema(self.target, ref_template="#/components/schemas/{model}", mode="serialization")
# pull out potential sub-schemas and put them into component section
for sub_name, sub_schema in schema.pop("$defs", {}).items():
component = ResolvedComponent(
name=sub_name,
type=ResolvedComponent.SCHEMA,
object=sub_name,
schema=sub_schema,
)
auto_schema.registry.register_on_missing(component)
return schema

View File

@@ -0,0 +1,173 @@
from django.conf import settings
from django.utils.version import get_version_tuple
from rest_framework import serializers
from drf_spectacular.contrib.rest_framework_simplejwt import (
SimpleJWTScheme, TokenRefreshSerializerExtension,
)
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiViewExtension
from drf_spectacular.utils import extend_schema
def get_dj_rest_auth_setting(class_name, setting_name):
from dj_rest_auth.__version__ import __version__
if get_version_tuple(__version__) < (3, 0, 0):
from dj_rest_auth import app_settings
return getattr(app_settings, class_name)
else:
from dj_rest_auth.app_settings import api_settings
return getattr(api_settings, setting_name)
def get_token_serializer_class():
from dj_rest_auth.__version__ import __version__
if get_version_tuple(__version__) < (3, 0, 0):
use_jwt = getattr(settings, 'REST_USE_JWT', False)
else:
from dj_rest_auth.app_settings import api_settings
use_jwt = api_settings.USE_JWT
if use_jwt:
return get_dj_rest_auth_setting('JWTSerializer', 'JWT_SERIALIZER')
else:
return get_dj_rest_auth_setting('TokenSerializer', 'TOKEN_SERIALIZER')
class RestAuthDetailSerializer(serializers.Serializer):
detail = serializers.CharField(read_only=True, required=False)
class RestAuthDefaultResponseView(OpenApiViewExtension):
def view_replacement(self):
class Fixed(self.target_class):
@extend_schema(responses=RestAuthDetailSerializer)
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class RestAuthLoginView(OpenApiViewExtension):
target_class = 'dj_rest_auth.views.LoginView'
def view_replacement(self):
class Fixed(self.target_class):
@extend_schema(responses=get_token_serializer_class())
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class RestAuthLogoutView(OpenApiViewExtension):
target_class = 'dj_rest_auth.views.LogoutView'
def view_replacement(self):
if getattr(settings, 'ACCOUNT_LOGOUT_ON_GET', None):
get_schema_params = {'responses': RestAuthDetailSerializer}
else:
get_schema_params = {'exclude': True}
class Fixed(self.target_class):
@extend_schema(**get_schema_params)
def get(self, request, *args, **kwargs):
pass # pragma: no cover
@extend_schema(request=None, responses=RestAuthDetailSerializer)
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class RestAuthPasswordChangeView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.views.PasswordChangeView'
class RestAuthPasswordResetView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.views.PasswordResetView'
class RestAuthPasswordResetConfirmView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.views.PasswordResetConfirmView'
class RestAuthVerifyEmailView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.registration.views.VerifyEmailView'
optional = True
class RestAuthResendEmailVerificationView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.registration.views.ResendEmailVerificationView'
optional = True
class RestAuthJWTSerializer(OpenApiSerializerExtension):
target_class = 'dj_rest_auth.serializers.JWTSerializer'
def map_serializer(self, auto_schema, direction):
class Fixed(self.target_class):
user = get_dj_rest_auth_setting('UserDetailsSerializer', 'USER_DETAILS_SERIALIZER')()
return auto_schema._map_serializer(Fixed, direction)
class CookieTokenRefreshSerializerExtension(TokenRefreshSerializerExtension):
target_class = 'dj_rest_auth.jwt_auth.CookieTokenRefreshSerializer'
optional = True
def get_name(self):
return 'TokenRefresh'
class RestAuthRegisterView(OpenApiViewExtension):
target_class = 'dj_rest_auth.registration.views.RegisterView'
optional = True
def view_replacement(self):
from allauth.account.app_settings import EMAIL_VERIFICATION, EmailVerificationMethod
if EMAIL_VERIFICATION == EmailVerificationMethod.MANDATORY:
response_serializer = RestAuthDetailSerializer
else:
response_serializer = get_token_serializer_class()
class Fixed(self.target_class):
@extend_schema(responses=response_serializer)
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class SimpleJWTCookieScheme(SimpleJWTScheme):
target_class = 'dj_rest_auth.jwt_auth.JWTCookieAuthentication'
optional = True
name = ['jwtHeaderAuth', 'jwtCookieAuth'] # type: ignore
def get_security_requirement(self, auto_schema):
return [{name: []} for name in self.name]
def get_security_definition(self, auto_schema):
cookie_name = get_dj_rest_auth_setting('JWT_AUTH_COOKIE', 'JWT_AUTH_COOKIE')
if not cookie_name:
cookie_name = 'jwt-auth'
warn(
f'"JWT_AUTH_COOKIE" setting required for JWTCookieAuthentication. '
f'defaulting to {cookie_name}'
)
return [
super().get_security_definition(auto_schema), # JWT from header
{
'type': 'apiKey',
'in': 'cookie',
'name': cookie_name,
}
]

View File

@@ -0,0 +1,32 @@
from drf_spectacular.extensions import OpenApiViewExtension
class ObtainAuthTokenView(OpenApiViewExtension):
target_class = 'rest_framework.authtoken.views.ObtainAuthToken'
match_subclasses = True
def view_replacement(self):
"""
Prior to DRF 3.12.0, usage of ObtainAuthToken resulted in AssertionError
Incompatible AutoSchema used on View "ObtainAuthToken". Is DRF's DEFAULT_SCHEMA_CLASS ...
This is because DRF had a bug which made it NOT honor DEFAULT_SCHEMA_CLASS and instead
injected an unsolicited coreschema class for this view and this view only. This extension
fixes the view before the wrong schema class is used.
Bug in DRF that was fixed in later versions:
https://github.com/encode/django-rest-framework/blob/4121b01b912668c049b26194a9a107c27a332429/rest_framework/authtoken/views.py#L16
"""
from rest_framework import VERSION
from drf_spectacular.openapi import AutoSchema
# no intervention needed
if VERSION >= '3.12':
return self.target
class FixedObtainAuthToken(self.target):
schema = AutoSchema()
return FixedObtainAuthToken

View File

@@ -0,0 +1,36 @@
from typing import Any
from drf_spectacular.drainage import get_override, has_override
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import ComponentIdentity, get_doc
from drf_spectacular.utils import Direction
class OpenApiDataclassSerializerExtensions(OpenApiSerializerExtension):
target_class = "rest_framework_dataclasses.serializers.DataclassSerializer"
match_subclasses = True
def get_name(self):
"""Use the dataclass name in the schema, instead of the serializer prefix (which can be just Dataclass)."""
if has_override(self.target, 'component_name'):
return get_override(self.target, 'component_name')
if getattr(getattr(self.target, 'Meta', None), 'ref_name', None) is not None:
return self.target.Meta.ref_name
if has_override(self.target.dataclass_definition.dataclass_type, 'component_name'):
return get_override(self.target.dataclass_definition.dataclass_type, 'component_name')
return self.target.dataclass_definition.dataclass_type.__name__
def get_identity(self, auto_schema, direction: Direction) -> Any:
return ComponentIdentity(self.target.dataclass_definition.dataclass_type)
def strip_library_doc(self, schema):
"""Strip the DataclassSerializer library documentation from the schema."""
from rest_framework_dataclasses.serializers import DataclassSerializer
if 'description' in schema and schema['description'] == get_doc(DataclassSerializer):
del schema['description']
return schema
def map_serializer(self, auto_schema, direction: Direction):
""""Generate the schema for a DataclassSerializer."""
schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
return self.strip_library_doc(schema)

View File

@@ -0,0 +1,219 @@
from rest_framework.utils.model_meta import get_field_info
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import (
ResolvedComponent, build_array_type, build_object_type, follow_field_source, get_doc,
)
def build_point_schema():
return {
"type": "array",
"items": {"type": "number", "format": "float"},
"example": [12.9721, 77.5933],
"minItems": 2,
"maxItems": 3,
}
def build_linestring_schema():
return {
"type": "array",
"items": build_point_schema(),
"example": [[22.4707, 70.0577], [12.9721, 77.5933]],
"minItems": 2,
}
def build_polygon_schema():
return {
"type": "array",
"items": {**build_linestring_schema(), "minItems": 4},
"example": [
[
[0.0, 0.0],
[0.0, 50.0],
[50.0, 50.0],
[50.0, 0.0],
[0.0, 0.0],
],
]
}
def build_geo_container_schema(name, coords):
return build_object_type(
properties={
"type": {"type": "string", "enum": [name]},
"coordinates": coords,
}
)
def build_point_geo_schema():
return build_geo_container_schema("Point", build_point_schema())
def build_linestring_geo_schema():
return build_geo_container_schema("LineString", build_linestring_schema())
def build_polygon_geo_schema():
return build_geo_container_schema("Polygon", build_polygon_schema())
def build_geometry_geo_schema():
return {
'oneOf': [
build_point_geo_schema(),
build_linestring_geo_schema(),
build_polygon_geo_schema(),
]
}
def build_bbox_schema():
return {
"type": "array",
"items": {"type": "number"},
"minItems": 4,
"maxItems": 4,
"example": [12.9721, 77.5933, 12.9721, 77.5933],
}
def build_geo_schema(model_field):
from django.contrib.gis.db import models
if isinstance(model_field, models.PointField):
return build_point_geo_schema()
elif isinstance(model_field, models.LineStringField):
return build_linestring_geo_schema()
elif isinstance(model_field, models.PolygonField):
return build_polygon_geo_schema()
elif isinstance(model_field, models.MultiPointField):
return build_geo_container_schema(
"MultiPoint", build_array_type(build_point_schema())
)
elif isinstance(model_field, models.MultiLineStringField):
return build_geo_container_schema(
"MultiLineString", build_array_type(build_linestring_schema())
)
elif isinstance(model_field, models.MultiPolygonField):
return build_geo_container_schema(
"MultiPolygon", build_array_type(build_polygon_schema())
)
elif isinstance(model_field, models.GeometryCollectionField):
return build_geo_container_schema(
"GeometryCollection", build_array_type(build_geometry_geo_schema())
)
elif isinstance(model_field, models.GeometryField):
return build_geometry_geo_schema()
else:
warn("Encountered unknown GIS geometry field")
return {}
def map_geo_field(serializer, geo_field_name):
from rest_framework_gis.fields import GeometrySerializerMethodField
field = serializer.fields[geo_field_name]
if isinstance(field, GeometrySerializerMethodField):
warn("Geometry generation for GeometrySerializerMethodField is not supported.")
return {}
model_field = get_field_info(serializer.Meta.model).fields[geo_field_name]
return build_geo_schema(model_field)
def _inject_enum_collision_fix(collection):
from drf_spectacular.settings import spectacular_settings
if not collection and 'GisFeatureEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureEnum'] = ('Feature',)
if collection and 'GisFeatureCollectionEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureCollectionEnum'] = ('FeatureCollection',)
class GeoFeatureModelSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_gis.serializers.GeoFeatureModelSerializer'
match_subclasses = True
def map_serializer(self, auto_schema, direction):
_inject_enum_collision_fix(collection=False)
base_schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
return self.map_geo_feature_model_serializer(self.target, base_schema)
def map_geo_feature_model_serializer(self, serializer, base_schema):
from rest_framework_gis.serializers import GeoFeatureModelSerializer
geo_properties = {
"type": {"type": "string", "enum": ["Feature"]}
}
if serializer.Meta.id_field:
geo_properties["id"] = base_schema["properties"].pop(serializer.Meta.id_field)
geo_properties["geometry"] = map_geo_field(serializer, serializer.Meta.geo_field)
base_schema["properties"].pop(serializer.Meta.geo_field)
if serializer.Meta.auto_bbox or serializer.Meta.bbox_geo_field:
geo_properties["bbox"] = build_bbox_schema()
base_schema["properties"].pop(serializer.Meta.bbox_geo_field, None)
# only expose if description comes from the user
description = base_schema.pop('description', None)
if description == get_doc(GeoFeatureModelSerializer):
description = None
# ignore this aspect for now
base_schema.pop('required', None)
# nest remaining fields under property "properties"
geo_properties["properties"] = base_schema
return build_object_type(
properties=geo_properties,
description=description,
)
class GeoFeatureModelListSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_gis.serializers.GeoFeatureModelListSerializer'
def map_serializer(self, auto_schema, direction):
_inject_enum_collision_fix(collection=True)
# build/retrieve feature component generated by GeoFeatureModelSerializerExtension.
# wrap the ref in the special list structure and build another component based on that.
feature_component = auto_schema.resolve_serializer(self.target.child, direction)
collection_schema = build_object_type(
properties={
"type": {"type": "string", "enum": ["FeatureCollection"]},
"features": build_array_type(feature_component.ref)
}
)
list_component = ResolvedComponent(
name=f'{feature_component.name}List',
type=ResolvedComponent.SCHEMA,
object=self.target.child,
schema=collection_schema
)
auto_schema.registry.register_on_missing(list_component)
return list_component.ref
class GeometryFieldExtension(OpenApiSerializerFieldExtension):
target_class = 'rest_framework_gis.fields.GeometryField'
match_subclasses = True
def map_serializer_field(self, auto_schema, direction):
# running this extension for GeoFeatureModelSerializer's geo_field is superfluous
# as above extension already handles that individually. We run it anyway because
# robustly checking the proper condition is harder.
try:
model = self.target.parent.Meta.model
model_field = follow_field_source(model, self.target.source.split('.'))
return build_geo_schema(model_field)
except: # noqa: E722
warn(f'Encountered an issue resolving field {self.target}. defaulting to generic object.')
return {}

View File

@@ -0,0 +1,16 @@
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from drf_spectacular.plumbing import build_bearer_security_scheme_object
class JWTScheme(OpenApiAuthenticationExtension):
target_class = 'rest_framework_jwt.authentication.JSONWebTokenAuthentication'
name = 'jwtAuth'
def get_security_definition(self, auto_schema):
from rest_framework_jwt.settings import api_settings
return build_bearer_security_scheme_object(
header_name='AUTHORIZATION',
token_prefix=api_settings.JWT_AUTH_HEADER_PREFIX,
bearer_format='JWT'
)

View File

@@ -0,0 +1,16 @@
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import build_array_type, is_list_serializer
class RecursiveFieldExtension(OpenApiSerializerFieldExtension):
target_class = "rest_framework_recursive.fields.RecursiveField"
def map_serializer_field(self, auto_schema, direction):
proxied = self.target.proxied
if is_list_serializer(proxied):
component = auto_schema.resolve_serializer(proxied.child, direction)
return build_array_type(component.ref)
component = auto_schema.resolve_serializer(proxied, direction)
return component.ref

View File

@@ -0,0 +1,87 @@
from rest_framework import serializers
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiAuthenticationExtension, OpenApiSerializerExtension
from drf_spectacular.plumbing import build_bearer_security_scheme_object
from drf_spectacular.utils import inline_serializer
class TokenObtainPairSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenObtainPairSerializer'
def map_serializer(self, auto_schema, direction):
Fixed = inline_serializer('Fixed', fields={
self.target_class.username_field: serializers.CharField(write_only=True),
'password': serializers.CharField(write_only=True),
'access': serializers.CharField(read_only=True),
'refresh': serializers.CharField(read_only=True),
})
return auto_schema._map_serializer(Fixed, direction)
class TokenObtainSlidingSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer'
def map_serializer(self, auto_schema, direction):
Fixed = inline_serializer('Fixed', fields={
self.target_class.username_field: serializers.CharField(write_only=True),
'password': serializers.CharField(write_only=True),
'token': serializers.CharField(read_only=True),
})
return auto_schema._map_serializer(Fixed, direction)
class TokenRefreshSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenRefreshSerializer'
def map_serializer(self, auto_schema, direction):
from rest_framework_simplejwt.settings import api_settings
if api_settings.ROTATE_REFRESH_TOKENS:
class Fixed(serializers.Serializer):
access = serializers.CharField(read_only=True)
refresh = serializers.CharField()
else:
class Fixed(serializers.Serializer):
access = serializers.CharField(read_only=True)
refresh = serializers.CharField(write_only=True)
return auto_schema._map_serializer(Fixed, direction)
class TokenVerifySerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenVerifySerializer'
def map_serializer(self, auto_schema, direction):
Fixed = inline_serializer('Fixed', fields={
'token': serializers.CharField(write_only=True),
})
return auto_schema._map_serializer(Fixed, direction)
class SimpleJWTScheme(OpenApiAuthenticationExtension):
target_class = 'rest_framework_simplejwt.authentication.JWTAuthentication'
name = 'jwtAuth'
def get_security_definition(self, auto_schema):
from rest_framework_simplejwt.settings import api_settings
if len(api_settings.AUTH_HEADER_TYPES) > 1:
warn(
f'OpenAPI3 can only have one "bearerFormat". JWT Settings specify '
f'{api_settings.AUTH_HEADER_TYPES}. Using the first one.'
)
return build_bearer_security_scheme_object(
header_name=getattr(api_settings, 'AUTH_HEADER_NAME', 'HTTP_AUTHORIZATION'),
token_prefix=api_settings.AUTH_HEADER_TYPES[0],
bearer_format='JWT'
)
class SimpleJWTTokenUserScheme(SimpleJWTScheme):
target_class = 'rest_framework_simplejwt.authentication.JWTTokenUserAuthentication'
class SimpleJWTStatelessUserScheme(SimpleJWTScheme):
target_class = "rest_framework_simplejwt.authentication.JWTStatelessUserAuthentication"

View File

@@ -0,0 +1,81 @@
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import (
ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type,
is_patched_serializer,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
class PolymorphicSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_polymorphic.serializers.PolymorphicSerializer'
match_subclasses = True
def map_serializer(self, auto_schema, direction):
sub_components = []
serializer = self.target
for sub_model in serializer.model_serializer_mapping:
sub_serializer = serializer._get_serializer_from_model_or_instance(sub_model)
sub_serializer.partial = serializer.partial
resource_type = serializer.to_resource_type(sub_model)
component = auto_schema.resolve_serializer(sub_serializer, direction)
if not component:
# rebuild a virtual schema-less component to model empty serializers
component = ResolvedComponent(
name=auto_schema._get_serializer_name(sub_serializer, direction),
type=ResolvedComponent.SCHEMA,
object=ComponentIdentity('virtual')
)
typed_component = self.build_typed_component(
auto_schema=auto_schema,
component=component,
resource_type_field_name=serializer.resource_type_field_name,
patched=is_patched_serializer(sub_serializer, direction)
)
sub_components.append((resource_type, typed_component.ref))
if not resource_type:
warn(
f'discriminator mapping key is empty for {sub_serializer.__class__}. '
f'this might lead to code generation issues.'
)
one_of_list = []
for _, ref in sub_components:
if ref not in one_of_list:
one_of_list.append(ref)
return {
'oneOf': one_of_list,
'discriminator': {
'propertyName': serializer.resource_type_field_name,
'mapping': {resource_type: ref['$ref'] for resource_type, ref in sub_components},
}
}
def build_typed_component(self, auto_schema, component, resource_type_field_name, patched):
if spectacular_settings.COMPONENT_SPLIT_REQUEST and component.name.endswith('Request'):
typed_component_name = component.name[:-len('Request')] + 'TypedRequest'
else:
typed_component_name = f'{component.name}Typed'
resource_type_schema = build_object_type(
properties={resource_type_field_name: build_basic_type(OpenApiTypes.STR)},
required=None if patched else [resource_type_field_name]
)
# if sub-serializer has an empty schema, only expose the resource_type field part
if component.schema:
schema = {'allOf': [resource_type_schema, component.ref]}
else:
schema = resource_type_schema
component_typed = ResolvedComponent(
name=typed_component_name,
type=ResolvedComponent.SCHEMA,
object=component.object,
schema=schema,
)
auto_schema.registry.register_on_missing(component_typed)
return component_typed