This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,6 @@
import django
__version__ = '0.28.0'
if django.VERSION < (3, 2):
default_app_config = 'drf_spectacular.apps.SpectacularConfig'

View File

@@ -0,0 +1,9 @@
from django.apps import AppConfig
class SpectacularConfig(AppConfig):
name = 'drf_spectacular'
verbose_name = "drf-spectacular"
def ready(self):
import drf_spectacular.checks # noqa: F401

View File

@@ -0,0 +1,42 @@
from django.conf import settings
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from drf_spectacular.plumbing import build_bearer_security_scheme_object
class SessionScheme(OpenApiAuthenticationExtension):
target_class = 'rest_framework.authentication.SessionAuthentication'
name = 'cookieAuth'
priority = -1
def get_security_definition(self, auto_schema):
return {
'type': 'apiKey',
'in': 'cookie',
'name': settings.SESSION_COOKIE_NAME,
}
class BasicScheme(OpenApiAuthenticationExtension):
target_class = 'rest_framework.authentication.BasicAuthentication'
name = 'basicAuth'
priority = -1
def get_security_definition(self, auto_schema):
return {
'type': 'http',
'scheme': 'basic',
}
class TokenScheme(OpenApiAuthenticationExtension):
target_class = 'rest_framework.authentication.TokenAuthentication'
name = 'tokenAuth'
match_subclasses = True
priority = -1
def get_security_definition(self, auto_schema):
return build_bearer_security_scheme_object(
header_name='Authorization',
token_prefix=self.target.keyword,
)

View File

@@ -0,0 +1,26 @@
from django.core.checks import Error, Warning, register
@register(deploy=True)
def schema_check(app_configs, **kwargs):
""" Perform dummy generation and emit warnings/errors as part of Django's check framework """
from drf_spectacular.drainage import GENERATOR_STATS
from drf_spectacular.settings import spectacular_settings
if not spectacular_settings.ENABLE_DJANGO_DEPLOY_CHECK:
return []
errors = []
try:
with GENERATOR_STATS.silence():
spectacular_settings.DEFAULT_GENERATOR_CLASS().get_schema(request=None, public=True)
except Exception as exc:
errors.append(
Error(f'Schema generation threw exception "{exc}"', id='drf_spectacular.E001')
)
if GENERATOR_STATS:
for w in GENERATOR_STATS._warn_cache:
errors.append(Warning(w, id='drf_spectacular.W001'))
for e in GENERATOR_STATS._error_cache:
errors.append(Warning(e, id='drf_spectacular.W002'))
return errors

View File

@@ -0,0 +1,15 @@
__all__ = [
'django_oauth_toolkit',
'djangorestframework_camel_case',
'rest_auth',
'rest_framework',
'rest_polymorphic',
'rest_framework_dataclasses',
'rest_framework_jwt',
'rest_framework_simplejwt',
'django_filters',
'rest_framework_recursive',
'rest_framework_gis',
'pydantic',
'knox_auth_token',
]

View File

@@ -0,0 +1,305 @@
from django.db import models
from drf_spectacular.drainage import add_trace_message, get_override, has_override, warn
from drf_spectacular.extensions import OpenApiFilterExtension
from drf_spectacular.plumbing import (
build_array_type, build_basic_type, build_choice_description_list, build_parameter_type,
follow_field_source, force_instance, get_manager, get_type_hints, get_view_model, is_basic_type,
is_field,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
_NoHint = object()
class DjangoFilterExtension(OpenApiFilterExtension):
"""
Extensions that specifically deals with ``django-filter`` fields. The introspection
attempts to estimate the underlying model field types to generate filter types.
However, there are under-specified filter fields for which heuristics need to be performed.
This serves as an explicit list of all partially-handled filter fields:
- ``AllValuesFilter``: skip choices to prevent DB query
- ``AllValuesMultipleFilter``: skip choices to prevent DB query, multi handled though
- ``ChoiceFilter``: enum handled, type under-specified
- ``DateRangeFilter``: N/A
- ``LookupChoiceFilter``: N/A
- ``ModelChoiceFilter``: enum handled
- ``ModelMultipleChoiceFilter``: enum, multi handled
- ``MultipleChoiceFilter``: enum, multi handled
- ``RangeFilter``: min/max handled, type under-specified
- ``TypedChoiceFilter``: enum handled
- ``TypedMultipleChoiceFilter``: enum, multi handled
In case of warnings or incorrect filter types, you can manually override the underlying
field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a
filter method for your filter field, you can attach ``extend_schema_field`` to that filter
method.
.. code-block::
class SomeFilter(FilterSet):
some_field = extend_schema_field(OpenApiTypes.NUMBER)(
RangeFilter(field_name='some_manually_annotated_field_in_qs')
)
"""
target_class = 'django_filters.rest_framework.DjangoFilterBackend'
match_subclasses = True
def get_schema_operation_parameters(self, auto_schema, *args, **kwargs):
model = get_view_model(auto_schema.view)
if not model:
return []
filterset_class = self.target.get_filterset_class(auto_schema.view, get_manager(model).none())
if not filterset_class:
return []
result = []
with add_trace_message(filterset_class):
for field_name, filter_field in filterset_class.base_filters.items():
result += self.resolve_filter_field(
auto_schema, model, filterset_class, field_name, filter_field
)
return result
def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, filter_field):
from django_filters import filters
unambiguous_mapping = {
filters.CharFilter: OpenApiTypes.STR,
filters.BooleanFilter: OpenApiTypes.BOOL,
filters.DateFilter: OpenApiTypes.DATE,
filters.DateTimeFilter: OpenApiTypes.DATETIME,
filters.IsoDateTimeFilter: OpenApiTypes.DATETIME,
filters.TimeFilter: OpenApiTypes.TIME,
filters.UUIDFilter: OpenApiTypes.UUID,
filters.DurationFilter: OpenApiTypes.DURATION,
filters.OrderingFilter: OpenApiTypes.STR,
filters.TimeRangeFilter: OpenApiTypes.TIME,
filters.DateFromToRangeFilter: OpenApiTypes.DATE,
filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
}
filter_method = self._get_filter_method(filterset_class, filter_field)
filter_method_hint = self._get_filter_method_hint(filter_method)
filter_choices = self._get_explicit_filter_choices(filter_field)
schema_from_override = False
if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
schema_from_override = True
annotation = (
get_override(filter_field, 'field') or get_override(filter_method, 'field')
)
if is_basic_type(annotation):
schema = build_basic_type(annotation)
elif isinstance(annotation, dict):
# allow injecting raw schema via @extend_schema_field decorator
schema = annotation.copy()
elif is_field(annotation):
schema = auto_schema._map_serializer_field(force_instance(annotation), "request")
else:
warn(
f"Unsupported annotation {annotation} on filter field {field_name}. defaulting to string."
)
schema = build_basic_type(OpenApiTypes.STR)
elif filter_method_hint is not _NoHint:
if is_basic_type(filter_method_hint):
schema = build_basic_type(filter_method_hint)
else:
schema = build_basic_type(OpenApiTypes.STR)
elif isinstance(filter_field, tuple(unambiguous_mapping)):
for cls in filter_field.__class__.__mro__:
if cls in unambiguous_mapping:
schema = build_basic_type(unambiguous_mapping[cls])
break
elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)):
# NumberField is underspecified by itself. try to find the
# type that makes the most sense or default to generic NUMBER
model_field = self._get_model_field(filter_field, model)
if isinstance(model_field, (models.IntegerField, models.AutoField)):
schema = build_basic_type(OpenApiTypes.INT)
elif isinstance(model_field, models.FloatField):
schema = build_basic_type(OpenApiTypes.FLOAT)
elif isinstance(model_field, models.DecimalField):
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
else:
schema = build_basic_type(OpenApiTypes.NUMBER)
elif isinstance(filter_field, (filters.ChoiceFilter, filters.MultipleChoiceFilter)):
try:
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
except Exception:
if filter_choices and is_basic_type(type(filter_choices[0])):
# fallback to type guessing from first choice element
schema = build_basic_type(type(filter_choices[0]))
else:
warn(
f'Unable to guess choice types from values, filter method\'s type hint '
f'or find "{field_name}" in model. Defaulting to string.'
)
schema = build_basic_type(OpenApiTypes.STR)
else:
# the last resort is to look up the type via the model or queryset field
# and emit a warning if we were unsuccessful.
try:
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
except Exception as exc: # pragma: no cover
warn(
f'Exception raised while trying resolve model field for django-filter '
f'field "{field_name}". Defaulting to string (Exception: {exc})'
)
schema = build_basic_type(OpenApiTypes.STR)
# primary keys are usually non-editable (readOnly=True) and map_model_field correctly
# signals that attribute. however this does not apply in this context.
schema.pop('readOnly', None)
# enrich schema with additional info from filter_field
enum = schema.pop('enum', None)
# explicit filter choices may disable enum retrieved from model
if not schema_from_override and filter_choices is not None:
enum = filter_choices
description = schema.pop('description', None)
if not schema_from_override:
description = self._get_field_description(filter_field, description)
# parameter style variations based on filter base class
if isinstance(filter_field, filters.BaseCSVFilter):
schema = build_array_type(schema)
field_names = [field_name]
explode = False
style = 'form'
elif isinstance(filter_field, filters.MultipleChoiceFilter):
schema = build_array_type(schema)
field_names = [field_name]
explode = True
style = 'form'
elif isinstance(filter_field, (filters.RangeFilter, filters.NumericRangeFilter)):
try:
suffixes = filter_field.field_class.widget.suffixes
except AttributeError:
suffixes = ['min', 'max']
field_names = [
f'{field_name}_{suffix}' if suffix else field_name for suffix in suffixes
]
explode = None
style = None
else:
field_names = [field_name]
explode = None
style = None
return [
build_parameter_type(
name=field_name,
required=filter_field.extra['required'],
location=OpenApiParameter.QUERY,
description=description,
schema=schema,
enum=enum,
explode=explode,
style=style
)
for field_name in field_names
]
def _get_filter_method(self, filterset_class, filter_field):
if callable(filter_field.method):
return filter_field.method
elif isinstance(filter_field.method, str):
return getattr(filterset_class, filter_field.method)
else:
return None
def _get_filter_method_hint(self, filter_method):
try:
return get_type_hints(filter_method)['value']
except: # noqa: E722
return _NoHint
def _get_explicit_filter_choices(self, filter_field):
if 'choices' not in filter_field.extra:
return None
elif callable(filter_field.extra['choices']):
# choices function may utilize the DB, so refrain from actually calling it.
return []
else:
return [c for c, _ in filter_field.extra['choices']]
def _get_model_field(self, filter_field, model):
if not filter_field.field_name:
return None
path = filter_field.field_name.split('__')
return follow_field_source(model, path, emit_warnings=False)
def _get_schema_from_model_field(self, auto_schema, filter_field, model):
# Has potential to throw exceptions. Needs to be wrapped in try/except!
#
# first search for the field in the model as this has the least amount of
# potential side effects. Only after that fails, attempt to call
# get_queryset() to check for potential query annotations.
model_field = self._get_model_field(filter_field, model)
# this is a cross feature between rest-framework-gis and django-filter. Regular
# behavior needs to be sidestepped as the model information is lost down the line.
# TODO for now this will be just a string to cover WKT, WKB, and urlencoded GeoJSON
# build_geo_schema(model_field) would yield the correct result
if self._is_gis(model_field):
return build_basic_type(OpenApiTypes.STR)
if not isinstance(model_field, models.Field):
qs = auto_schema.view.get_queryset()
model_field = qs.query.annotations[filter_field.field_name].field
return auto_schema._map_model_field(model_field, direction=None)
def _get_field_description(self, filter_field, description):
# Try to improve description beyond auto-generated model description
if filter_field.extra.get('help_text', None):
description = filter_field.extra['help_text']
elif filter_field.label is not None:
description = filter_field.label
choices = filter_field.extra.get('choices')
if choices and callable(choices):
# remove auto-generated enum list, since choices come from a callable
if '\n\n*' in (description or ''):
description, _, _ = description.partition('\n\n*')
elif (description or '').startswith('* `'):
description = ''
return description
choice_description = ''
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION and choices and not callable(choices):
choice_description = build_choice_description_list(choices)
if not choices:
return description
if not description:
return choice_description
if '\n\n*' in description:
description, _, _ = description.partition('\n\n*')
return description + '\n\n' + choice_description
if description.startswith('* `'):
return choice_description
return description + '\n\n' + choice_description
@classmethod
def _is_gis(cls, field):
if not getattr(cls, '_has_gis', True):
return False
try:
from django.contrib.gis.db.models import GeometryField
from rest_framework_gis.filters import GeometryFilter
return isinstance(field, (GeometryField, GeometryFilter))
except: # noqa
cls._has_gis = False
return False

View File

@@ -0,0 +1,49 @@
from drf_spectacular.extensions import OpenApiAuthenticationExtension
class DjangoOAuthToolkitScheme(OpenApiAuthenticationExtension):
target_class = 'oauth2_provider.contrib.rest_framework.OAuth2Authentication'
name = 'oauth2'
def get_security_requirement(self, auto_schema):
from oauth2_provider.contrib.rest_framework import (
IsAuthenticatedOrTokenHasScope, TokenHasScope, TokenMatchesOASRequirements,
)
view = auto_schema.view
request = view.request
for permission in auto_schema.view.get_permissions():
if isinstance(permission, TokenMatchesOASRequirements):
alt_scopes = permission.get_required_alternate_scopes(request, view)
alt_scopes = alt_scopes.get(auto_schema.method, [])
return [{self.name: group} for group in alt_scopes]
if isinstance(permission, IsAuthenticatedOrTokenHasScope):
return {self.name: TokenHasScope().get_scopes(request, view)}
if isinstance(permission, TokenHasScope):
# catch-all for subclasses of TokenHasScope like TokenHasReadWriteScope
return {self.name: permission.get_scopes(request, view)}
def get_security_definition(self, auto_schema):
from oauth2_provider.scopes import get_scopes_backend
from drf_spectacular.settings import spectacular_settings
flows = {}
for flow_type in spectacular_settings.OAUTH2_FLOWS:
flows[flow_type] = {}
if flow_type in ('implicit', 'authorizationCode'):
flows[flow_type]['authorizationUrl'] = spectacular_settings.OAUTH2_AUTHORIZATION_URL
if flow_type in ('password', 'clientCredentials', 'authorizationCode'):
flows[flow_type]['tokenUrl'] = spectacular_settings.OAUTH2_TOKEN_URL
if spectacular_settings.OAUTH2_REFRESH_URL:
flows[flow_type]['refreshUrl'] = spectacular_settings.OAUTH2_REFRESH_URL
if spectacular_settings.OAUTH2_SCOPES:
flows[flow_type]['scopes'] = spectacular_settings.OAUTH2_SCOPES
else:
scope_backend = get_scopes_backend()
flows[flow_type]['scopes'] = scope_backend.get_all_scopes()
return {
'type': 'oauth2',
'flows': flows
}

View File

@@ -0,0 +1,62 @@
import re
from typing import Optional
from django.utils.module_loading import import_string
def camelize_serializer_fields(result, generator, request, public):
from django.conf import settings
from djangorestframework_camel_case.settings import api_settings
from djangorestframework_camel_case.util import camelize_re, underscore_to_camel
# prunes subtrees from camelization based on owning field name
ignore_fields = api_settings.JSON_UNDERSCOREIZE.get("ignore_fields") or ()
# ignore certain field names while camelizing
ignore_keys = api_settings.JSON_UNDERSCOREIZE.get("ignore_keys") or ()
def has_middleware_installed():
try:
from djangorestframework_camel_case.middleware import CamelCaseMiddleWare
except ImportError:
return False
for middleware in [import_string(m) for m in settings.MIDDLEWARE]:
try:
if issubclass(middleware, CamelCaseMiddleWare):
return True
except TypeError:
pass
def camelize_str(key: str) -> str:
new_key = re.sub(camelize_re, underscore_to_camel, key) if "_" in key else key
if key in ignore_keys or new_key in ignore_keys:
return key
return new_key
def camelize_component(schema: dict, name: Optional[str] = None) -> dict:
if name is not None and (name in ignore_fields or camelize_str(name) in ignore_fields):
return schema
elif schema.get('type') == 'object':
if 'properties' in schema:
schema['properties'] = {
camelize_str(field_name): camelize_component(field_schema, field_name)
for field_name, field_schema in schema['properties'].items()
}
if 'required' in schema:
schema['required'] = [camelize_str(field) for field in schema['required']]
elif schema.get('type') == 'array':
camelize_component(schema['items'])
return schema
for (_, component_type), component in generator.registry._components.items():
if component_type == 'schemas':
camelize_component(component.schema)
if has_middleware_installed():
for url_schema in result["paths"].values():
for method_schema in url_schema.values():
for parameter in method_schema.get("parameters", []):
parameter["name"] = camelize_str(parameter["name"])
# inplace modification of components also affect result dict, so regeneration is not necessary
return result

View File

@@ -0,0 +1,13 @@
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from drf_spectacular.plumbing import build_bearer_security_scheme_object
class KnoxTokenScheme(OpenApiAuthenticationExtension):
target_class = 'knox.auth.TokenAuthentication'
name = 'knoxApiToken'
def get_security_definition(self, auto_schema):
return build_bearer_security_scheme_object(
header_name='Authorization',
token_prefix=self.target.authenticate_header(""),
)

View File

@@ -0,0 +1,50 @@
from drf_spectacular.drainage import set_override, warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import ResolvedComponent, build_basic_type
from drf_spectacular.types import OpenApiTypes
class PydanticExtension(OpenApiSerializerExtension):
"""
Allows using pydantic models on @extend_schema(request=..., response=...) to
describe your API.
We only have partial support for pydantic's version of dataclass, due to the way they
are designed. The outermost class (the @extend_schema argument) has to be a subclass
of pydantic.BaseModel. Inside this outermost BaseModel, any combination of dataclass
and BaseModel can be used.
"""
target_class = "pydantic.BaseModel"
match_subclasses = True
def get_name(self, auto_schema, direction):
# due to the fact that it is complicated to pull out every field member BaseModel class
# of the entry model, we simply use the class name as string for object. This hack may
# create false positive warnings, so turn it off. However, this may suppress correct
# warnings involving the entry class.
# TODO suppression may be migrated to new ComponentIdentity system
set_override(self.target, 'suppress_collision_warning', True)
return self.target.__name__
def map_serializer(self, auto_schema, direction):
# let pydantic generate a JSON schema
try:
from pydantic.json_schema import model_json_schema
except ImportError:
warn("Only pydantic >= 2 is supported. defaulting to generic object.")
return build_basic_type(OpenApiTypes.OBJECT)
schema = model_json_schema(self.target, ref_template="#/components/schemas/{model}", mode="serialization")
# pull out potential sub-schemas and put them into component section
for sub_name, sub_schema in schema.pop("$defs", {}).items():
component = ResolvedComponent(
name=sub_name,
type=ResolvedComponent.SCHEMA,
object=sub_name,
schema=sub_schema,
)
auto_schema.registry.register_on_missing(component)
return schema

View File

@@ -0,0 +1,173 @@
from django.conf import settings
from django.utils.version import get_version_tuple
from rest_framework import serializers
from drf_spectacular.contrib.rest_framework_simplejwt import (
SimpleJWTScheme, TokenRefreshSerializerExtension,
)
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiViewExtension
from drf_spectacular.utils import extend_schema
def get_dj_rest_auth_setting(class_name, setting_name):
from dj_rest_auth.__version__ import __version__
if get_version_tuple(__version__) < (3, 0, 0):
from dj_rest_auth import app_settings
return getattr(app_settings, class_name)
else:
from dj_rest_auth.app_settings import api_settings
return getattr(api_settings, setting_name)
def get_token_serializer_class():
from dj_rest_auth.__version__ import __version__
if get_version_tuple(__version__) < (3, 0, 0):
use_jwt = getattr(settings, 'REST_USE_JWT', False)
else:
from dj_rest_auth.app_settings import api_settings
use_jwt = api_settings.USE_JWT
if use_jwt:
return get_dj_rest_auth_setting('JWTSerializer', 'JWT_SERIALIZER')
else:
return get_dj_rest_auth_setting('TokenSerializer', 'TOKEN_SERIALIZER')
class RestAuthDetailSerializer(serializers.Serializer):
detail = serializers.CharField(read_only=True, required=False)
class RestAuthDefaultResponseView(OpenApiViewExtension):
def view_replacement(self):
class Fixed(self.target_class):
@extend_schema(responses=RestAuthDetailSerializer)
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class RestAuthLoginView(OpenApiViewExtension):
target_class = 'dj_rest_auth.views.LoginView'
def view_replacement(self):
class Fixed(self.target_class):
@extend_schema(responses=get_token_serializer_class())
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class RestAuthLogoutView(OpenApiViewExtension):
target_class = 'dj_rest_auth.views.LogoutView'
def view_replacement(self):
if getattr(settings, 'ACCOUNT_LOGOUT_ON_GET', None):
get_schema_params = {'responses': RestAuthDetailSerializer}
else:
get_schema_params = {'exclude': True}
class Fixed(self.target_class):
@extend_schema(**get_schema_params)
def get(self, request, *args, **kwargs):
pass # pragma: no cover
@extend_schema(request=None, responses=RestAuthDetailSerializer)
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class RestAuthPasswordChangeView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.views.PasswordChangeView'
class RestAuthPasswordResetView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.views.PasswordResetView'
class RestAuthPasswordResetConfirmView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.views.PasswordResetConfirmView'
class RestAuthVerifyEmailView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.registration.views.VerifyEmailView'
optional = True
class RestAuthResendEmailVerificationView(RestAuthDefaultResponseView):
target_class = 'dj_rest_auth.registration.views.ResendEmailVerificationView'
optional = True
class RestAuthJWTSerializer(OpenApiSerializerExtension):
target_class = 'dj_rest_auth.serializers.JWTSerializer'
def map_serializer(self, auto_schema, direction):
class Fixed(self.target_class):
user = get_dj_rest_auth_setting('UserDetailsSerializer', 'USER_DETAILS_SERIALIZER')()
return auto_schema._map_serializer(Fixed, direction)
class CookieTokenRefreshSerializerExtension(TokenRefreshSerializerExtension):
target_class = 'dj_rest_auth.jwt_auth.CookieTokenRefreshSerializer'
optional = True
def get_name(self):
return 'TokenRefresh'
class RestAuthRegisterView(OpenApiViewExtension):
target_class = 'dj_rest_auth.registration.views.RegisterView'
optional = True
def view_replacement(self):
from allauth.account.app_settings import EMAIL_VERIFICATION, EmailVerificationMethod
if EMAIL_VERIFICATION == EmailVerificationMethod.MANDATORY:
response_serializer = RestAuthDetailSerializer
else:
response_serializer = get_token_serializer_class()
class Fixed(self.target_class):
@extend_schema(responses=response_serializer)
def post(self, request, *args, **kwargs):
pass # pragma: no cover
return Fixed
class SimpleJWTCookieScheme(SimpleJWTScheme):
target_class = 'dj_rest_auth.jwt_auth.JWTCookieAuthentication'
optional = True
name = ['jwtHeaderAuth', 'jwtCookieAuth'] # type: ignore
def get_security_requirement(self, auto_schema):
return [{name: []} for name in self.name]
def get_security_definition(self, auto_schema):
cookie_name = get_dj_rest_auth_setting('JWT_AUTH_COOKIE', 'JWT_AUTH_COOKIE')
if not cookie_name:
cookie_name = 'jwt-auth'
warn(
f'"JWT_AUTH_COOKIE" setting required for JWTCookieAuthentication. '
f'defaulting to {cookie_name}'
)
return [
super().get_security_definition(auto_schema), # JWT from header
{
'type': 'apiKey',
'in': 'cookie',
'name': cookie_name,
}
]

View File

@@ -0,0 +1,32 @@
from drf_spectacular.extensions import OpenApiViewExtension
class ObtainAuthTokenView(OpenApiViewExtension):
target_class = 'rest_framework.authtoken.views.ObtainAuthToken'
match_subclasses = True
def view_replacement(self):
"""
Prior to DRF 3.12.0, usage of ObtainAuthToken resulted in AssertionError
Incompatible AutoSchema used on View "ObtainAuthToken". Is DRF's DEFAULT_SCHEMA_CLASS ...
This is because DRF had a bug which made it NOT honor DEFAULT_SCHEMA_CLASS and instead
injected an unsolicited coreschema class for this view and this view only. This extension
fixes the view before the wrong schema class is used.
Bug in DRF that was fixed in later versions:
https://github.com/encode/django-rest-framework/blob/4121b01b912668c049b26194a9a107c27a332429/rest_framework/authtoken/views.py#L16
"""
from rest_framework import VERSION
from drf_spectacular.openapi import AutoSchema
# no intervention needed
if VERSION >= '3.12':
return self.target
class FixedObtainAuthToken(self.target):
schema = AutoSchema()
return FixedObtainAuthToken

View File

@@ -0,0 +1,36 @@
from typing import Any
from drf_spectacular.drainage import get_override, has_override
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import ComponentIdentity, get_doc
from drf_spectacular.utils import Direction
class OpenApiDataclassSerializerExtensions(OpenApiSerializerExtension):
target_class = "rest_framework_dataclasses.serializers.DataclassSerializer"
match_subclasses = True
def get_name(self):
"""Use the dataclass name in the schema, instead of the serializer prefix (which can be just Dataclass)."""
if has_override(self.target, 'component_name'):
return get_override(self.target, 'component_name')
if getattr(getattr(self.target, 'Meta', None), 'ref_name', None) is not None:
return self.target.Meta.ref_name
if has_override(self.target.dataclass_definition.dataclass_type, 'component_name'):
return get_override(self.target.dataclass_definition.dataclass_type, 'component_name')
return self.target.dataclass_definition.dataclass_type.__name__
def get_identity(self, auto_schema, direction: Direction) -> Any:
return ComponentIdentity(self.target.dataclass_definition.dataclass_type)
def strip_library_doc(self, schema):
"""Strip the DataclassSerializer library documentation from the schema."""
from rest_framework_dataclasses.serializers import DataclassSerializer
if 'description' in schema and schema['description'] == get_doc(DataclassSerializer):
del schema['description']
return schema
def map_serializer(self, auto_schema, direction: Direction):
""""Generate the schema for a DataclassSerializer."""
schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
return self.strip_library_doc(schema)

View File

@@ -0,0 +1,219 @@
from rest_framework.utils.model_meta import get_field_info
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import (
ResolvedComponent, build_array_type, build_object_type, follow_field_source, get_doc,
)
def build_point_schema():
return {
"type": "array",
"items": {"type": "number", "format": "float"},
"example": [12.9721, 77.5933],
"minItems": 2,
"maxItems": 3,
}
def build_linestring_schema():
return {
"type": "array",
"items": build_point_schema(),
"example": [[22.4707, 70.0577], [12.9721, 77.5933]],
"minItems": 2,
}
def build_polygon_schema():
return {
"type": "array",
"items": {**build_linestring_schema(), "minItems": 4},
"example": [
[
[0.0, 0.0],
[0.0, 50.0],
[50.0, 50.0],
[50.0, 0.0],
[0.0, 0.0],
],
]
}
def build_geo_container_schema(name, coords):
return build_object_type(
properties={
"type": {"type": "string", "enum": [name]},
"coordinates": coords,
}
)
def build_point_geo_schema():
return build_geo_container_schema("Point", build_point_schema())
def build_linestring_geo_schema():
return build_geo_container_schema("LineString", build_linestring_schema())
def build_polygon_geo_schema():
return build_geo_container_schema("Polygon", build_polygon_schema())
def build_geometry_geo_schema():
return {
'oneOf': [
build_point_geo_schema(),
build_linestring_geo_schema(),
build_polygon_geo_schema(),
]
}
def build_bbox_schema():
return {
"type": "array",
"items": {"type": "number"},
"minItems": 4,
"maxItems": 4,
"example": [12.9721, 77.5933, 12.9721, 77.5933],
}
def build_geo_schema(model_field):
from django.contrib.gis.db import models
if isinstance(model_field, models.PointField):
return build_point_geo_schema()
elif isinstance(model_field, models.LineStringField):
return build_linestring_geo_schema()
elif isinstance(model_field, models.PolygonField):
return build_polygon_geo_schema()
elif isinstance(model_field, models.MultiPointField):
return build_geo_container_schema(
"MultiPoint", build_array_type(build_point_schema())
)
elif isinstance(model_field, models.MultiLineStringField):
return build_geo_container_schema(
"MultiLineString", build_array_type(build_linestring_schema())
)
elif isinstance(model_field, models.MultiPolygonField):
return build_geo_container_schema(
"MultiPolygon", build_array_type(build_polygon_schema())
)
elif isinstance(model_field, models.GeometryCollectionField):
return build_geo_container_schema(
"GeometryCollection", build_array_type(build_geometry_geo_schema())
)
elif isinstance(model_field, models.GeometryField):
return build_geometry_geo_schema()
else:
warn("Encountered unknown GIS geometry field")
return {}
def map_geo_field(serializer, geo_field_name):
from rest_framework_gis.fields import GeometrySerializerMethodField
field = serializer.fields[geo_field_name]
if isinstance(field, GeometrySerializerMethodField):
warn("Geometry generation for GeometrySerializerMethodField is not supported.")
return {}
model_field = get_field_info(serializer.Meta.model).fields[geo_field_name]
return build_geo_schema(model_field)
def _inject_enum_collision_fix(collection):
from drf_spectacular.settings import spectacular_settings
if not collection and 'GisFeatureEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureEnum'] = ('Feature',)
if collection and 'GisFeatureCollectionEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureCollectionEnum'] = ('FeatureCollection',)
class GeoFeatureModelSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_gis.serializers.GeoFeatureModelSerializer'
match_subclasses = True
def map_serializer(self, auto_schema, direction):
_inject_enum_collision_fix(collection=False)
base_schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
return self.map_geo_feature_model_serializer(self.target, base_schema)
def map_geo_feature_model_serializer(self, serializer, base_schema):
from rest_framework_gis.serializers import GeoFeatureModelSerializer
geo_properties = {
"type": {"type": "string", "enum": ["Feature"]}
}
if serializer.Meta.id_field:
geo_properties["id"] = base_schema["properties"].pop(serializer.Meta.id_field)
geo_properties["geometry"] = map_geo_field(serializer, serializer.Meta.geo_field)
base_schema["properties"].pop(serializer.Meta.geo_field)
if serializer.Meta.auto_bbox or serializer.Meta.bbox_geo_field:
geo_properties["bbox"] = build_bbox_schema()
base_schema["properties"].pop(serializer.Meta.bbox_geo_field, None)
# only expose if description comes from the user
description = base_schema.pop('description', None)
if description == get_doc(GeoFeatureModelSerializer):
description = None
# ignore this aspect for now
base_schema.pop('required', None)
# nest remaining fields under property "properties"
geo_properties["properties"] = base_schema
return build_object_type(
properties=geo_properties,
description=description,
)
class GeoFeatureModelListSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_gis.serializers.GeoFeatureModelListSerializer'
def map_serializer(self, auto_schema, direction):
_inject_enum_collision_fix(collection=True)
# build/retrieve feature component generated by GeoFeatureModelSerializerExtension.
# wrap the ref in the special list structure and build another component based on that.
feature_component = auto_schema.resolve_serializer(self.target.child, direction)
collection_schema = build_object_type(
properties={
"type": {"type": "string", "enum": ["FeatureCollection"]},
"features": build_array_type(feature_component.ref)
}
)
list_component = ResolvedComponent(
name=f'{feature_component.name}List',
type=ResolvedComponent.SCHEMA,
object=self.target.child,
schema=collection_schema
)
auto_schema.registry.register_on_missing(list_component)
return list_component.ref
class GeometryFieldExtension(OpenApiSerializerFieldExtension):
target_class = 'rest_framework_gis.fields.GeometryField'
match_subclasses = True
def map_serializer_field(self, auto_schema, direction):
# running this extension for GeoFeatureModelSerializer's geo_field is superfluous
# as above extension already handles that individually. We run it anyway because
# robustly checking the proper condition is harder.
try:
model = self.target.parent.Meta.model
model_field = follow_field_source(model, self.target.source.split('.'))
return build_geo_schema(model_field)
except: # noqa: E722
warn(f'Encountered an issue resolving field {self.target}. defaulting to generic object.')
return {}

View File

@@ -0,0 +1,16 @@
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from drf_spectacular.plumbing import build_bearer_security_scheme_object
class JWTScheme(OpenApiAuthenticationExtension):
target_class = 'rest_framework_jwt.authentication.JSONWebTokenAuthentication'
name = 'jwtAuth'
def get_security_definition(self, auto_schema):
from rest_framework_jwt.settings import api_settings
return build_bearer_security_scheme_object(
header_name='AUTHORIZATION',
token_prefix=api_settings.JWT_AUTH_HEADER_PREFIX,
bearer_format='JWT'
)

View File

@@ -0,0 +1,16 @@
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import build_array_type, is_list_serializer
class RecursiveFieldExtension(OpenApiSerializerFieldExtension):
target_class = "rest_framework_recursive.fields.RecursiveField"
def map_serializer_field(self, auto_schema, direction):
proxied = self.target.proxied
if is_list_serializer(proxied):
component = auto_schema.resolve_serializer(proxied.child, direction)
return build_array_type(component.ref)
component = auto_schema.resolve_serializer(proxied, direction)
return component.ref

View File

@@ -0,0 +1,87 @@
from rest_framework import serializers
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiAuthenticationExtension, OpenApiSerializerExtension
from drf_spectacular.plumbing import build_bearer_security_scheme_object
from drf_spectacular.utils import inline_serializer
class TokenObtainPairSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenObtainPairSerializer'
def map_serializer(self, auto_schema, direction):
Fixed = inline_serializer('Fixed', fields={
self.target_class.username_field: serializers.CharField(write_only=True),
'password': serializers.CharField(write_only=True),
'access': serializers.CharField(read_only=True),
'refresh': serializers.CharField(read_only=True),
})
return auto_schema._map_serializer(Fixed, direction)
class TokenObtainSlidingSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer'
def map_serializer(self, auto_schema, direction):
Fixed = inline_serializer('Fixed', fields={
self.target_class.username_field: serializers.CharField(write_only=True),
'password': serializers.CharField(write_only=True),
'token': serializers.CharField(read_only=True),
})
return auto_schema._map_serializer(Fixed, direction)
class TokenRefreshSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenRefreshSerializer'
def map_serializer(self, auto_schema, direction):
from rest_framework_simplejwt.settings import api_settings
if api_settings.ROTATE_REFRESH_TOKENS:
class Fixed(serializers.Serializer):
access = serializers.CharField(read_only=True)
refresh = serializers.CharField()
else:
class Fixed(serializers.Serializer):
access = serializers.CharField(read_only=True)
refresh = serializers.CharField(write_only=True)
return auto_schema._map_serializer(Fixed, direction)
class TokenVerifySerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_simplejwt.serializers.TokenVerifySerializer'
def map_serializer(self, auto_schema, direction):
Fixed = inline_serializer('Fixed', fields={
'token': serializers.CharField(write_only=True),
})
return auto_schema._map_serializer(Fixed, direction)
class SimpleJWTScheme(OpenApiAuthenticationExtension):
target_class = 'rest_framework_simplejwt.authentication.JWTAuthentication'
name = 'jwtAuth'
def get_security_definition(self, auto_schema):
from rest_framework_simplejwt.settings import api_settings
if len(api_settings.AUTH_HEADER_TYPES) > 1:
warn(
f'OpenAPI3 can only have one "bearerFormat". JWT Settings specify '
f'{api_settings.AUTH_HEADER_TYPES}. Using the first one.'
)
return build_bearer_security_scheme_object(
header_name=getattr(api_settings, 'AUTH_HEADER_NAME', 'HTTP_AUTHORIZATION'),
token_prefix=api_settings.AUTH_HEADER_TYPES[0],
bearer_format='JWT'
)
class SimpleJWTTokenUserScheme(SimpleJWTScheme):
target_class = 'rest_framework_simplejwt.authentication.JWTTokenUserAuthentication'
class SimpleJWTStatelessUserScheme(SimpleJWTScheme):
target_class = "rest_framework_simplejwt.authentication.JWTStatelessUserAuthentication"

View File

@@ -0,0 +1,81 @@
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import (
ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type,
is_patched_serializer,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
class PolymorphicSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_polymorphic.serializers.PolymorphicSerializer'
match_subclasses = True
def map_serializer(self, auto_schema, direction):
sub_components = []
serializer = self.target
for sub_model in serializer.model_serializer_mapping:
sub_serializer = serializer._get_serializer_from_model_or_instance(sub_model)
sub_serializer.partial = serializer.partial
resource_type = serializer.to_resource_type(sub_model)
component = auto_schema.resolve_serializer(sub_serializer, direction)
if not component:
# rebuild a virtual schema-less component to model empty serializers
component = ResolvedComponent(
name=auto_schema._get_serializer_name(sub_serializer, direction),
type=ResolvedComponent.SCHEMA,
object=ComponentIdentity('virtual')
)
typed_component = self.build_typed_component(
auto_schema=auto_schema,
component=component,
resource_type_field_name=serializer.resource_type_field_name,
patched=is_patched_serializer(sub_serializer, direction)
)
sub_components.append((resource_type, typed_component.ref))
if not resource_type:
warn(
f'discriminator mapping key is empty for {sub_serializer.__class__}. '
f'this might lead to code generation issues.'
)
one_of_list = []
for _, ref in sub_components:
if ref not in one_of_list:
one_of_list.append(ref)
return {
'oneOf': one_of_list,
'discriminator': {
'propertyName': serializer.resource_type_field_name,
'mapping': {resource_type: ref['$ref'] for resource_type, ref in sub_components},
}
}
def build_typed_component(self, auto_schema, component, resource_type_field_name, patched):
if spectacular_settings.COMPONENT_SPLIT_REQUEST and component.name.endswith('Request'):
typed_component_name = component.name[:-len('Request')] + 'TypedRequest'
else:
typed_component_name = f'{component.name}Typed'
resource_type_schema = build_object_type(
properties={resource_type_field_name: build_basic_type(OpenApiTypes.STR)},
required=None if patched else [resource_type_field_name]
)
# if sub-serializer has an empty schema, only expose the resource_type field part
if component.schema:
schema = {'allOf': [resource_type_schema, component.ref]}
else:
schema = resource_type_schema
component_typed = ResolvedComponent(
name=typed_component_name,
type=ResolvedComponent.SCHEMA,
object=component.object,
schema=schema,
)
auto_schema.registry.register_on_missing(component_typed)
return component_typed

View File

@@ -0,0 +1,220 @@
import contextlib
import functools
import inspect
import sys
from collections import defaultdict
from typing import Any, Callable, DefaultDict, List, Optional, Tuple, TypeVar
if sys.version_info >= (3, 8):
from typing import ( # type: ignore[attr-defined] # noqa: F401
Final, Literal, TypedDict, _TypedDictMeta,
)
else:
from typing_extensions import Final, Literal, TypedDict, _TypedDictMeta # noqa: F401
if sys.version_info >= (3, 10):
from typing import TypeGuard # noqa: F401
else:
from typing_extensions import TypeGuard # noqa: F401
F = TypeVar('F', bound=Callable[..., Any])
class GeneratorStats:
_warn_cache: DefaultDict[str, int] = defaultdict(int)
_error_cache: DefaultDict[str, int] = defaultdict(int)
_traces: List[Tuple[Optional[str], Optional[str], str]] = []
_trace_lineno = False
_blue = ''
_red = ''
_yellow = ''
_clear = ''
def __getattr__(self, name):
if 'silent' not in self.__dict__:
from drf_spectacular.settings import spectacular_settings
self.silent = spectacular_settings.DISABLE_ERRORS_AND_WARNINGS
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
def __bool__(self):
return bool(self._warn_cache or self._error_cache)
@contextlib.contextmanager
def silence(self):
self.silent, tmp = True, self.silent
try:
yield
finally:
self.silent = tmp
def reset(self) -> None:
self._warn_cache.clear()
self._error_cache.clear()
def enable_color(self) -> None:
self._blue = '\033[0;34m'
self._red = '\033[0;31m'
self._yellow = '\033[0;33m'
self._clear = '\033[0m'
def enable_trace_lineno(self) -> None:
self._trace_lineno = True
def _get_current_trace(self) -> Tuple[Optional[str], str]:
source_locations = [t for t in self._traces if t[0]]
if source_locations:
sourcefile, lineno, _ = source_locations[-1]
source_location = f'{sourcefile}:{lineno}' if lineno else sourcefile
else:
source_location = ''
breadcrumbs = ' > '.join(t[2] for t in self._traces)
return source_location, breadcrumbs
def emit(self, msg: str, severity: str) -> None:
assert severity in ['warning', 'error']
cache = self._warn_cache if severity == 'warning' else self._error_cache
source_location, breadcrumbs = self._get_current_trace()
prefix = f'{self._blue}{source_location}: ' if source_location else ''
prefix += self._yellow if severity == 'warning' else self._red
prefix += f'{severity.capitalize()}'
prefix += f' [{breadcrumbs}]: ' if breadcrumbs else ': '
msg = prefix + self._clear + str(msg)
if not self.silent and msg not in cache:
print(msg, file=sys.stderr)
cache[msg] += 1
def emit_summary(self) -> None:
if not self.silent and (self._warn_cache or self._error_cache):
print(
f'\nSchema generation summary:\n'
f'Warnings: {sum(self._warn_cache.values())} ({len(self._warn_cache)} unique)\n'
f'Errors: {sum(self._error_cache.values())} ({len(self._error_cache)} unique)\n',
file=sys.stderr
)
GENERATOR_STATS = GeneratorStats()
def warn(msg: str, delayed: Any = None) -> None:
if delayed:
warnings = get_override(delayed, 'warnings', [])
warnings.append(msg)
set_override(delayed, 'warnings', warnings)
else:
GENERATOR_STATS.emit(msg, 'warning')
def error(msg: str, delayed: Any = None) -> None:
if delayed:
errors = get_override(delayed, 'errors', [])
errors.append(msg)
set_override(delayed, 'errors', errors)
else:
GENERATOR_STATS.emit(msg, 'error')
def reset_generator_stats() -> None:
GENERATOR_STATS.reset()
@contextlib.contextmanager
def add_trace_message(obj):
"""
Adds a message to be used as a prefix when emitting warnings and errors.
"""
sourcefile, lineno = _get_source_location(obj)
GENERATOR_STATS._traces.append((sourcefile, lineno, obj.__name__))
yield
GENERATOR_STATS._traces.pop()
@functools.lru_cache(maxsize=1000)
def _get_source_location(obj):
try:
sourcefile = inspect.getsourcefile(obj)
except: # noqa: E722
sourcefile = None
try:
# This is a rather expensive operation. Only do it when explicitly enabled (CLI)
# and cache results to speed up some recurring objects like serializers.
lineno = inspect.getsourcelines(obj)[1] if GENERATOR_STATS._trace_lineno else None
except: # noqa: E722
lineno = None
return sourcefile, lineno
def has_override(obj: Any, prop: str) -> bool:
if isinstance(obj, functools.partial):
obj = obj.func
if not hasattr(obj, '_spectacular_annotation'):
return False
if prop not in obj._spectacular_annotation:
return False
return True
def get_override(obj: Any, prop: str, default: Any = None) -> Any:
if isinstance(obj, functools.partial):
obj = obj.func
if not has_override(obj, prop):
return default
return obj._spectacular_annotation[prop]
def set_override(obj: Any, prop: str, value: Any) -> Any:
if not hasattr(obj, '_spectacular_annotation'):
obj._spectacular_annotation = {}
elif '_spectacular_annotation' not in obj.__dict__:
obj._spectacular_annotation = obj._spectacular_annotation.copy()
obj._spectacular_annotation[prop] = value
return obj
def get_view_method_names(view, schema=None) -> List[str]:
schema = schema or view.schema
return [
item for item in dir(view) if callable(getattr(view, item, None)) and (
item in view.http_method_names
or item in schema.method_mapping.values()
or item == 'list'
or hasattr(getattr(view, item, None), 'mapping')
)
]
def isolate_view_method(view, method_name):
"""
Prevent modifying a view method which is derived from other views. Changes to
a derived method would leak into the view where the method originated from.
Break derivation by wrapping the method and explicitly setting it on the view.
"""
method = getattr(view, method_name)
# no isolation is required if the view method is not derived.
# @api_view is a special case that also breaks isolation. It proxies all view
# methods through a single handler function, which then also requires isolation.
if method_name in view.__dict__ and method.__name__ != 'handler':
return method
@functools.wraps(method)
def wrapped_method(self, request, *args, **kwargs):
return method(self, request, *args, **kwargs)
# wraps() will only create a shallow copy of method.__dict__. Updates to "kwargs"
# via @extend_schema would leak to the original method. Isolate by creating a copy.
if hasattr(method, 'kwargs'):
wrapped_method.kwargs = method.kwargs.copy()
setattr(view, method_name, wrapped_method)
return wrapped_method
def cache(user_function: F) -> F:
""" simple polyfill for python < 3.9 """
return functools.lru_cache(maxsize=None)(user_function) # type: ignore

View File

@@ -0,0 +1,143 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
from drf_spectacular.plumbing import OpenApiGeneratorExtension
from drf_spectacular.utils import Direction
if TYPE_CHECKING:
from rest_framework.views import APIView
from drf_spectacular.openapi import AutoSchema
_SchemaType = Dict[str, Any]
class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthenticationExtension']):
"""
Extension for specifying authentication schemes.
The common use-case usually consists of setting a ``name`` string and returning a dict from
``get_security_definition``. To model a group of headers that go together, set a list
of names and return a corresponding list of definitions from ``get_security_definition``.
The view class is available via ``auto_schema.view``, while the original authentication class
can be accessed via ``self.target``. If you want to override an included extension, be sure to
set a higher matching priority by setting the class attribute ``priority = 1`` or higher.
get_security_requirement is expected to return a dict with security object names as keys and a
scope list as value (usually just []). More than one key in the dict means that each entry is
required (AND). If you need alternate variations (OR), return a list of those dicts instead.
``get_security_definition()`` is expected to return a valid `OpenAPI security scheme object
<https://spec.openapis.org/oas/v3.0.3#security-scheme-object>`_
"""
_registry: List[Type['OpenApiAuthenticationExtension']] = []
name: Union[str, List[str]]
def get_security_requirement(
self, auto_schema: 'AutoSchema'
) -> Union[Dict[str, List[Any]], List[Dict[str, List[Any]]]]:
assert self.name, 'name(s) must be specified'
if isinstance(self.name, str):
return {self.name: []}
else:
return {name: [] for name in self.name}
@abstractmethod
def get_security_definition(self, auto_schema: 'AutoSchema') -> Union[_SchemaType, List[_SchemaType]]:
pass # pragma: no cover
class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExtension']):
"""
Extension for replacing an insufficient or specifying an unknown Serializer schema.
The existing implementation of ``map_serializer()`` will generate the same result
as *drf-spectacular* would. Either augment or replace the generated schema. The
view instance is available via ``auto_schema.view``, while the original serializer
can be accessed via ``self.target``.
``map_serializer()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schema-object>`_.
"""
_registry: List[Type['OpenApiSerializerExtension']] = []
def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[str]:
""" return str for overriding default name extraction """
return None
def get_identity(self, auto_schema: 'AutoSchema', direction: Direction) -> Any:
""" return anything to compare instances of target. Target will be used by default. """
return None
def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer mapping """
return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True)
class OpenApiSerializerFieldExtension(OpenApiGeneratorExtension['OpenApiSerializerFieldExtension']):
"""
Extension for replacing an insufficient or specifying an unknown SerializerField schema.
To augment the default schema, you can get what *drf-spectacular* would generate with
``auto_schema._map_serializer_field(self.target, direction, bypass_extensions=True)``.
and edit the returned schema at your discretion. Beware that this may still emit
warnings, in which case manual construction is advisable.
``map_serializer_field()`` is expected to return a valid `OpenAPI schema object
<https://spec.openapis.org/oas/v3.0.3#schema-object>`_.
"""
_registry: List[Type['OpenApiSerializerFieldExtension']] = []
def get_name(self) -> Optional[str]:
""" return str for breaking out field schema into separate named component """
return None
@abstractmethod
def map_serializer_field(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer field mapping """
pass # pragma: no cover
class OpenApiViewExtension(OpenApiGeneratorExtension['OpenApiViewExtension']):
"""
Extension for replacing discovered views with a more schema-appropriate/annotated version.
``view_replacement()`` is expected to return a subclass of ``APIView`` (which includes
``ViewSet`` et al.). The discovered original view instance can be accessed with
``self.target`` and be subclassed if desired.
"""
_registry: List[Type['OpenApiViewExtension']] = []
@classmethod
def _load_class(cls):
super()._load_class()
# special case @api_view: view class is nested in the cls attr of the function object
if hasattr(cls.target_class, 'cls'):
cls.target_class = cls.target_class.cls
@abstractmethod
def view_replacement(self) -> 'Type[APIView]':
pass # pragma: no cover
class OpenApiFilterExtension(OpenApiGeneratorExtension['OpenApiFilterExtension']):
"""
Extension for specifying a list of filter parameters for a given ``FilterBackend``.
The original filter class object can be accessed via ``self.target``. The attached view
is accessible via ``auto_schema.view``.
``get_schema_operation_parameters()`` is expected to return either an empty list or a list
of valid raw `OpenAPI parameter objects
<https://spec.openapis.org/oas/v3.0.3#parameter-object>`_.
Using ``drf_spectacular.plumbing.build_parameter_type`` is recommended to generate
the appropriate raw dict objects.
"""
_registry: List[Type['OpenApiFilterExtension']] = []
@abstractmethod
def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[_SchemaType]:
pass # pragma: no cover

View File

@@ -0,0 +1,293 @@
import os
import re
from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.settings import api_settings
from drf_spectacular.drainage import (
add_trace_message, error, get_override, reset_generator_stats, warn,
)
from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, get_class,
is_versioning_supported, modify_for_versioning, normalize_result_object,
operation_matches_version, process_webhooks, sanitize_result_object,
)
from drf_spectacular.settings import spectacular_settings
class EndpointEnumerator(BaseEndpointEnumerator):
def get_api_endpoints(self, patterns=None, prefix=''):
api_endpoints = self._get_api_endpoints(patterns, prefix)
for hook in spectacular_settings.PREPROCESSING_HOOKS:
api_endpoints = hook(endpoints=api_endpoints)
api_endpoints_deduplicated = {}
for path, path_regex, method, callback in api_endpoints:
if (path, method) not in api_endpoints_deduplicated:
api_endpoints_deduplicated[path, method] = (path, path_regex, method, callback)
api_endpoints = list(api_endpoints_deduplicated.values())
if callable(spectacular_settings.SORT_OPERATIONS):
return sorted(api_endpoints, key=spectacular_settings.SORT_OPERATIONS)
elif spectacular_settings.SORT_OPERATIONS:
return sorted(api_endpoints, key=alpha_operation_sorter)
else:
return api_endpoints
def get_path_from_regex(self, path_regex):
path = super().get_path_from_regex(path_regex)
# bugfix oversight in DRF regex stripping
path = path.replace('\\.', '.')
return path
def _get_api_endpoints(self, patterns, prefix):
"""
Return a list of all available API endpoints by inspecting the URL conf.
Only modification the DRF version is passing through the path_regex.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + str(pattern.pattern)
if isinstance(pattern, URLPattern):
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
endpoint = (path, path_regex, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, URLResolver):
nested_endpoints = self._get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
return api_endpoints
def get_allowed_methods(self, callback):
if hasattr(callback, 'actions'):
actions = set(callback.actions)
if 'http_method_names' in callback.initkwargs:
http_method_names = set(callback.initkwargs['http_method_names'])
else:
http_method_names = set(callback.cls.http_method_names)
methods = [method.upper() for method in actions & http_method_names]
else:
# pass to constructor allowed method names to get valid ones
kwargs = {}
if 'http_method_names' in callback.initkwargs:
kwargs['http_method_names'] = callback.initkwargs['http_method_names']
methods = callback.cls(**kwargs).allowed_methods
return [
method for method in methods
if method not in ('OPTIONS', 'HEAD', 'TRACE', 'CONNECT')
]
class SchemaGenerator(BaseSchemaGenerator):
endpoint_inspector_cls = EndpointEnumerator
def __init__(self, *args, **kwargs):
self.registry = ComponentRegistry()
self.api_version = kwargs.pop('api_version', None)
self.inspector = None
super().__init__(*args, **kwargs)
def coerce_path(self, path, method, view):
"""
Customized coerce_path which also considers the `_pk` suffix in URL paths
of nested routers.
"""
path = super().coerce_path(path, method, view) # take care of {pk}
if spectacular_settings.SCHEMA_COERCE_PATH_PK_SUFFIX:
path = re.sub(pattern=r'{(\w+)_pk}', repl=r'{\1_id}', string=path)
return path
def create_view(self, callback, method, request=None):
"""
customized create_view which is called when all routes are traversed. part of this
is instantiating views with default params. in case of custom routes (@action) the
custom AutoSchema is injected properly through 'initkwargs' on view. However, when
decorating plain views like retrieve, this initialization logic is not running.
Therefore forcefully set the schema if @extend_schema decorator was used.
"""
override_view = OpenApiViewExtension.get_match(callback.cls)
if override_view:
original_cls = callback.cls
callback.cls = override_view.view_replacement()
# we refrain from passing request and deal with it ourselves in parse()
view = super().create_view(callback, method, None)
# drf-yasg compatibility feature. makes the view aware that we are running
# schema generation and not a real request.
view.swagger_fake_view = True
# callback.cls is hosted in urlpatterns and is therefore not an ephemeral modification.
# restore after view creation so potential revisits have a clean state as basis.
if override_view:
callback.cls = original_cls
if isinstance(view, viewsets.ViewSetMixin):
action = getattr(view, view.action)
elif isinstance(view, views.APIView):
action = getattr(view, method.lower())
else:
error(
'Using not supported View class. Class must be derived from APIView '
'or any of its subclasses like GenericApiView, GenericViewSet.'
)
return view
action_schema = getattr(action, 'kwargs', {}).get('schema', None)
if not action_schema:
# there is no method/action customized schema so we are done here.
return view
# action_schema is either a class or instance. when @extend_schema is used, it
# is always a class to prevent the weakref reverse "schema.view" bug for multi
# annotations. The bug is prevented by delaying the instantiation of the schema
# class until create_view (here) and not doing it immediately in @extend_schema.
action_schema_class = get_class(action_schema)
view_schema_class = get_class(callback.cls.schema)
if not issubclass(action_schema_class, view_schema_class):
# this handles the case of having a manually set custom AutoSchema on the
# view together with extend_schema. In most cases, the decorator mechanics
# prevent extend_schema from having access to the view's schema class. So
# extend_schema is forced to use DEFAULT_SCHEMA_CLASS as fallback base class
# instead of the correct base class set in view. We remedy this chicken-egg
# problem here by rearranging the class hierarchy.
mro = tuple(
cls for cls in action_schema_class.__mro__
if cls not in api_settings.DEFAULT_SCHEMA_CLASS.__mro__
) + view_schema_class.__mro__
action_schema_class = type('ExtendedRearrangedSchema', mro, {})
view.schema = action_schema_class()
return view
def _initialise_endpoints(self):
if self.endpoints is None:
self.inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = self.inspector.get_api_endpoints()
def _get_paths_and_endpoints(self):
"""
Generate (path, method, view) given (path, method, callback) for paths.
"""
view_endpoints = []
for path, path_regex, method, callback in self.endpoints:
view = self.create_view(callback, method)
path = self.coerce_path(path, method, view)
view_endpoints.append((path, path_regex, method, view))
return view_endpoints
def parse(self, input_request, public):
""" Iterate endpoints generating per method path operations. """
result = {}
self._initialise_endpoints()
endpoints = self._get_paths_and_endpoints()
if spectacular_settings.SCHEMA_PATH_PREFIX is None:
# estimate common path prefix if none was given. only use it if we encountered more
# than one view to prevent emission of erroneous and unnecessary fallback names.
non_trivial_prefix = len(set([view.__class__ for _, _, _, view in endpoints])) > 1
if non_trivial_prefix:
path_prefix = os.path.commonpath([path for path, _, _, _ in endpoints])
path_prefix = re.escape(path_prefix) # guard for RE special chars in path
else:
path_prefix = '/'
else:
path_prefix = spectacular_settings.SCHEMA_PATH_PREFIX
if not path_prefix.startswith('^'):
path_prefix = '^' + path_prefix # make sure regex only matches from the start
for path, path_regex, method, view in endpoints:
# emit queued up warnings/error that happened prior to generation (decoration)
for w in get_override(view, 'warnings', []):
warn(w)
for e in get_override(view, 'errors', []):
error(e)
view.request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request)
if not (public or self.has_view_permissions(path, method, view)):
continue
if view.versioning_class and not is_versioning_supported(view.versioning_class):
warn(
f'using unsupported versioning class "{view.versioning_class}". view will be '
f'processed as unversioned view.'
)
elif view.versioning_class:
version = (
self.api_version # explicit version from CLI, SpecView or SpecView request
or view.versioning_class.default_version # fallback
)
if not version:
continue
path = modify_for_versioning(self.inspector.patterns, method, path, view, version)
if not operation_matches_version(view, version):
continue
assert isinstance(view.schema, AutoSchema), (
f'Incompatible AutoSchema used on View {view.__class__}. Is DRF\'s '
f'DEFAULT_SCHEMA_CLASS pointing to "drf_spectacular.openapi.AutoSchema" '
f'or any other drf-spectacular compatible AutoSchema?'
)
with add_trace_message(getattr(view, '__class__', view)):
operation = view.schema.get_operation(
path, path_regex, path_prefix, method, self.registry
)
# operation was manually removed via @extend_schema
if not operation:
continue
if spectacular_settings.SCHEMA_PATH_PREFIX_TRIM:
path = re.sub(pattern=path_prefix, repl='', string=path, flags=re.IGNORECASE)
if spectacular_settings.SCHEMA_PATH_PREFIX_INSERT:
path = spectacular_settings.SCHEMA_PATH_PREFIX_INSERT + path
if not path.startswith('/'):
path = '/' + path
if spectacular_settings.CAMELIZE_NAMES:
path, operation = camelize_operation(path, operation)
result.setdefault(path, {})
result[path][method.lower()] = operation
return result
def get_schema(self, request=None, public=False):
""" Generate a OpenAPI schema. """
reset_generator_stats()
result = build_root_object(
paths=self.parse(request, public),
components=self.registry.build(spectacular_settings.APPEND_COMPONENTS),
webhooks=process_webhooks(spectacular_settings.WEBHOOKS, self.registry),
version=self.api_version or getattr(request, 'version', None),
)
for hook in spectacular_settings.POSTPROCESSING_HOOKS:
result = hook(result=result, generator=self, request=request, public=public)
return sanitize_result_object(normalize_result_object(result))

View File

@@ -0,0 +1,42 @@
from django.utils.module_loading import import_string
def lazy_serializer(path: str):
""" simulate initiated object but actually load class and init on first usage """
class LazySerializer:
def __init__(self, *args, **kwargs):
self.lazy_args, self.lazy_kwargs, self.lazy_obj = args, kwargs, None
def __getattr__(self, item):
if not self.lazy_obj:
self.lazy_obj = import_string(path)(*self.lazy_args, **self.lazy_kwargs)
return getattr(self.lazy_obj, item)
@property # type: ignore
def __class__(self):
return self.__getattr__('__class__')
@property
def __dict__(self):
return self.__getattr__('__dict__')
def __str__(self):
return self.__getattr__('__str__')()
def __repr__(self):
return self.__getattr__('__repr__')()
return LazySerializer
def forced_singular_serializer(serializer_class):
from drf_spectacular.drainage import set_override
from drf_spectacular.utils import extend_schema_serializer
patched_serializer_class = type(serializer_class.__name__, (serializer_class,), {})
extend_schema_serializer(many=False)(patched_serializer_class)
set_override(patched_serializer_class, 'suppress_collision_warning', True)
return patched_serializer_class

View File

@@ -0,0 +1,210 @@
import re
from collections import defaultdict
from inflection import camelize
from rest_framework.settings import api_settings
from drf_spectacular.drainage import warn
from drf_spectacular.plumbing import (
ResolvedComponent, list_hash, load_enum_name_overrides, safe_ref,
)
from drf_spectacular.settings import spectacular_settings
def postprocess_schema_enums(result, generator, **kwargs):
"""
simple replacement of Enum/Choices that globally share the same name and have
the same choices. Aids client generation to not generate a separate enum for
every occurrence. only takes effect when replacement is guaranteed to be correct.
"""
def iter_prop_containers(schema, component_name=None):
if not component_name:
for component_name, schema in schema.items():
if spectacular_settings.COMPONENT_SPLIT_PATCH:
component_name = re.sub('^Patched(.+)', r'\1', component_name)
if spectacular_settings.COMPONENT_SPLIT_REQUEST:
component_name = re.sub('(.+)Request$', r'\1', component_name)
yield from iter_prop_containers(schema, component_name)
elif isinstance(schema, list):
for item in schema:
yield from iter_prop_containers(item, component_name)
elif isinstance(schema, dict):
if schema.get('properties'):
yield component_name, schema['properties']
yield from iter_prop_containers(schema.get('oneOf', []), component_name)
yield from iter_prop_containers(schema.get('allOf', []), component_name)
yield from iter_prop_containers(schema.get('anyOf', []), component_name)
def create_enum_component(name, schema):
component = ResolvedComponent(
name=name,
type=ResolvedComponent.SCHEMA,
schema=schema,
object=name,
)
generator.registry.register_on_missing(component)
return component
def extract_hash(schema):
if 'x-spec-enum-id' in schema:
# try to use the injected enum hash first as it generated from (name, value) tuples,
# which prevents collisions on choice sets only differing in labels not values.
return schema['x-spec-enum-id']
else:
# fall back to actual list hashing when we encounter enums not generated by us.
# remove blank/null entry for hashing. will be reconstructed in the last step
return list_hash([(i, i) for i in schema['enum'] if i not in ('', None)])
schemas = result.get('components', {}).get('schemas', {})
overrides = load_enum_name_overrides()
prop_hash_mapping = defaultdict(set)
hash_name_mapping = defaultdict(set)
# collect all enums, their names and choice sets
for component_name, props in iter_prop_containers(schemas):
for prop_name, prop_schema in props.items():
if prop_schema.get('type') == 'array':
prop_schema = prop_schema.get('items', {})
if 'enum' not in prop_schema:
continue
prop_enum_cleaned_hash = extract_hash(prop_schema)
prop_hash_mapping[prop_name].add(prop_enum_cleaned_hash)
hash_name_mapping[prop_enum_cleaned_hash].add((component_name, prop_name))
# get the suffix to be used for enums from settings
enum_suffix = spectacular_settings.ENUM_SUFFIX
# traverse all enum properties and generate a name for the choice set. naming collisions
# are resolved and a warning is emitted. giving a choice set multiple names is technically
# correct but potentially unwanted. also emit a warning there to make the user aware.
enum_name_mapping = {}
for prop_name, prop_hash_set in prop_hash_mapping.items():
for prop_hash in prop_hash_set:
if prop_hash in overrides:
enum_name = overrides[prop_hash]
elif len(prop_hash_set) == 1:
# prop_name has been used exclusively for one choice set (best case)
enum_name = f'{camelize(prop_name)}{enum_suffix}'
elif len(hash_name_mapping[prop_hash]) == 1:
# prop_name has multiple choice sets, but each one limited to one component only
component_name, _ = next(iter(hash_name_mapping[prop_hash]))
enum_name = f'{camelize(component_name)}{camelize(prop_name)}{enum_suffix}'
else:
enum_name = f'{camelize(prop_name)}{prop_hash[:3].capitalize()}{enum_suffix}'
warn(
f'enum naming encountered a non-optimally resolvable collision for fields '
f'named "{prop_name}". The same name has been used for multiple choice sets '
f'in multiple components. The collision was resolved with "{enum_name}". '
f'add an entry to ENUM_NAME_OVERRIDES to fix the naming.'
)
if enum_name_mapping.get(prop_hash, enum_name) != enum_name:
warn(
f'encountered multiple names for the same choice set ({enum_name}). This '
f'may be unwanted even though the generated schema is technically correct. '
f'Add an entry to ENUM_NAME_OVERRIDES to fix the naming.'
)
del enum_name_mapping[prop_hash]
else:
enum_name_mapping[prop_hash] = enum_name
enum_name_mapping[(prop_hash, prop_name)] = enum_name
# replace all enum occurrences with a enum schema component. cut out the
# enum, replace it with a reference and add a corresponding component.
for _, props in iter_prop_containers(schemas):
for prop_name, prop_schema in props.items():
is_array = prop_schema.get('type') == 'array'
if is_array:
prop_schema = prop_schema.get('items', {})
if 'enum' not in prop_schema:
continue
prop_enum_original_list = prop_schema['enum']
prop_schema['enum'] = [i for i in prop_schema['enum'] if i not in ['', None]]
prop_hash = extract_hash(prop_schema)
# when choice sets are reused under multiple names, the generated name cannot be
# resolved from the hash alone. fall back to prop_name and hash for resolution.
enum_name = enum_name_mapping.get(prop_hash) or enum_name_mapping[prop_hash, prop_name]
# split property into remaining property and enum component parts
enum_schema = {k: v for k, v in prop_schema.items() if k in ['type', 'enum']}
prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum', 'x-spec-enum-id']}
# separate actual description from name-value tuples
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION:
if prop_schema.get('description', '').startswith('*'):
enum_schema['description'] = prop_schema.pop('description')
elif '\n\n*' in prop_schema.get('description', ''):
_, _, post = prop_schema['description'].partition('\n\n*')
enum_schema['description'] = '*' + post
components = [
create_enum_component(enum_name, schema=enum_schema)
]
if spectacular_settings.ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE:
if '' in prop_enum_original_list:
components.append(create_enum_component(f'Blank{enum_suffix}', schema={'enum': ['']}))
if None in prop_enum_original_list:
if spectacular_settings.OAS_VERSION.startswith('3.1'):
components.append(create_enum_component(f'Null{enum_suffix}', schema={'type': 'null'}))
else:
components.append(create_enum_component(f'Null{enum_suffix}', schema={'enum': [None]}))
# undo OAS 3.1 type list NULL construction as we cover this in a separate component already
if spectacular_settings.OAS_VERSION.startswith('3.1') and isinstance(enum_schema['type'], list):
enum_schema['type'] = [t for t in enum_schema['type'] if t != 'null'][0]
if len(components) == 1:
prop_schema.update(components[0].ref)
else:
prop_schema.update({'oneOf': [c.ref for c in components]})
if is_array:
props[prop_name]['items'] = safe_ref(prop_schema)
else:
props[prop_name] = safe_ref(prop_schema)
# sort again with additional components
result['components'] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS)
# remove remaining ids that were not part of this hook (operation parameters mainly)
postprocess_schema_enum_id_removal(result, generator)
return result
def postprocess_schema_enum_id_removal(result, generator, **kwargs):
"""
Iterative modifying approach to scanning the whole schema and removing the
temporary helper ids that allowed us to distinguish similar enums.
"""
def clean(sub_result):
if isinstance(sub_result, dict):
for key in list(sub_result):
if key == 'x-spec-enum-id':
del sub_result['x-spec-enum-id']
else:
clean(sub_result[key])
elif isinstance(sub_result, (list, tuple)):
for item in sub_result:
clean(item)
clean(result)
return result
def preprocess_exclude_path_format(endpoints, **kwargs):
"""
preprocessing hook that filters out {format} suffixed paths, in case
format_suffix_patterns is used and {format} path params are unwanted.
"""
format_path = f'{{{api_settings.FORMAT_SUFFIX_KWARG}}}'
return [
(path, path_regex, method, callback)
for path, path_regex, method, callback in endpoints
if not (path.endswith(format_path) or path.endswith(format_path + '/'))
]

View File

@@ -0,0 +1,98 @@
from textwrap import dedent
from django.core.management.base import BaseCommand, CommandError
from django.utils import translation
from django.utils.module_loading import import_string
from drf_spectacular.drainage import GENERATOR_STATS
from drf_spectacular.renderers import OpenApiJsonRenderer, OpenApiYamlRenderer
from drf_spectacular.settings import patched_settings, spectacular_settings
from drf_spectacular.validation import validate_schema
class SchemaGenerationError(CommandError):
pass
class SchemaValidationError(CommandError):
pass
class Command(BaseCommand):
help = dedent("""
Generate a spectacular OpenAPI3-compliant schema for your API.
The warnings serve as a indicator for where your API could not be properly
resolved. @extend_schema and @extend_schema_field are your friends.
The spec should be valid in any case. If not, please open an issue
on github: https://github.com/tfranzel/drf-spectacular/issues
Remember to configure your APIs meta data like servers, version, url,
documentation and so on in your SPECTACULAR_SETTINGS."
""")
def add_arguments(self, parser):
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
parser.add_argument('--urlconf', dest="urlconf", default=None, type=str)
parser.add_argument('--generator-class', dest="generator_class", default=None, type=str)
parser.add_argument('--file', dest="file", default=None, type=str)
parser.add_argument('--fail-on-warn', dest="fail_on_warn", default=False, action='store_true')
parser.add_argument('--validate', dest="validate", default=False, action='store_true')
parser.add_argument('--api-version', dest="api_version", default=None, type=str)
parser.add_argument('--lang', dest="lang", default=None, type=str)
parser.add_argument('--color', dest="color", default=False, action='store_true')
parser.add_argument('--custom-settings', dest="custom_settings", default=None, type=str)
def handle(self, *args, **options):
if options['generator_class']:
generator_class = import_string(options['generator_class'])
else:
generator_class = spectacular_settings.DEFAULT_GENERATOR_CLASS
GENERATOR_STATS.enable_trace_lineno()
if options['color']:
GENERATOR_STATS.enable_color()
generator = generator_class(
urlconf=options['urlconf'],
api_version=options['api_version'],
)
if options['custom_settings']:
custom_settings = import_string(options['custom_settings'])
else:
custom_settings = None
with patched_settings(custom_settings):
if options['lang']:
with translation.override(options['lang']):
schema = generator.get_schema(request=None, public=True)
else:
schema = generator.get_schema(request=None, public=True)
GENERATOR_STATS.emit_summary()
if options['fail_on_warn'] and GENERATOR_STATS:
raise SchemaGenerationError('Failing as requested due to warnings')
if options['validate']:
try:
validate_schema(schema)
except Exception as e:
raise SchemaValidationError(e)
renderer = self.get_renderer(options['format'])
output = renderer.render(schema, renderer_context={})
if options['file']:
with open(options['file'], 'wb') as f:
f.write(output)
else:
self.stdout.write(output.decode())
def get_renderer(self, format):
renderer_cls = {
'openapi': OpenApiYamlRenderer,
'openapi-json': OpenApiJsonRenderer,
}[format]
return renderer_cls()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,82 @@
from collections import OrderedDict
from datetime import time, timedelta
from decimal import Decimal
from uuid import UUID
import yaml
from django.utils.safestring import SafeString
from rest_framework.exceptions import ErrorDetail
from rest_framework.renderers import BaseRenderer, JSONRenderer
class OpenApiYamlRenderer(BaseRenderer):
media_type = 'application/vnd.oai.openapi'
format = 'yaml'
def render(self, data, accepted_media_type=None, renderer_context=None):
# disable yaml advanced feature 'alias' for clean, portable, and readable output
class Dumper(yaml.SafeDumper):
def ignore_aliases(self, data):
return True
def error_detail_representer(dumper, data):
return dumper.represent_dict({'string': str(data), 'code': data.code})
Dumper.add_representer(ErrorDetail, error_detail_representer)
def multiline_str_representer(dumper, data):
scalar = dumper.represent_str(data)
scalar.style = '|' if '\n' in data else None
return scalar
Dumper.add_representer(str, multiline_str_representer)
def decimal_representer(dumper, data):
# prevent emitting "!! float" tags on fractionless decimals
value = f'{data:f}'
if '.' in value:
return dumper.represent_scalar('tag:yaml.org,2002:float', value)
else:
return dumper.represent_scalar('tag:yaml.org,2002:int', value)
Dumper.add_representer(Decimal, decimal_representer)
def timedelta_representer(dumper, data):
return dumper.represent_str(str(data.total_seconds()))
Dumper.add_representer(timedelta, timedelta_representer)
def time_representer(dumper, data):
return dumper.represent_str(data.isoformat())
Dumper.add_representer(time, time_representer)
def uuid_representer(dumper, data):
return dumper.represent_str(str(data))
Dumper.add_representer(UUID, uuid_representer)
def safestring_representer(dumper, data):
return dumper.represent_str(data)
Dumper.add_representer(SafeString, safestring_representer)
def ordereddict_representer(dumper, data):
return dumper.represent_dict(dict(data))
Dumper.add_representer(OrderedDict, ordereddict_representer)
return yaml.dump(
data,
default_flow_style=False,
sort_keys=False,
allow_unicode=True,
Dumper=Dumper
).encode('utf-8')
class OpenApiYamlRenderer2(OpenApiYamlRenderer):
media_type = 'application/yaml'
class OpenApiJsonRenderer(JSONRenderer):
media_type = 'application/vnd.oai.openapi+json'
def get_indent(self, accepted_media_type, renderer_context):
return super().get_indent(accepted_media_type, renderer_context) or 4
class OpenApiJsonRenderer2(OpenApiJsonRenderer):
media_type = 'application/json'

View File

@@ -0,0 +1,94 @@
from drf_spectacular.drainage import error, warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import force_instance, is_list_serializer, is_serializer
class PolymorphicProxySerializerExtension(OpenApiSerializerExtension):
target_class = 'drf_spectacular.utils.PolymorphicProxySerializer'
priority = -1
def get_name(self):
return self.target.component_name
def map_serializer(self, auto_schema, direction):
""" custom handling for @extend_schema's injection of PolymorphicProxySerializer """
if isinstance(self.target.serializers, dict):
sub_components = self._get_explicit_sub_components(auto_schema, direction)
else:
sub_components = self._get_implicit_sub_components(auto_schema, direction)
if not self._has_discriminator():
return {'oneOf': [schema for _, schema in sub_components]}
else:
one_of_list = []
for _, schema in sub_components:
if schema not in one_of_list:
one_of_list.append(schema)
return {
'oneOf': one_of_list,
'discriminator': {
'propertyName': self.target.resource_type_field_name,
'mapping': {resource_type: schema['$ref'] for resource_type, schema in sub_components}
}
}
def _get_implicit_sub_components(self, auto_schema, direction):
sub_components = []
for sub_serializer in self.target.serializers:
sub_serializer = self._prep_serializer(sub_serializer)
(resolved_name, resolved_schema) = self._process_serializer(auto_schema, sub_serializer, direction)
if not resolved_schema:
continue
if not self._has_discriminator():
sub_components.append((None, resolved_schema))
else:
try:
discriminator_field = sub_serializer.fields[self.target.resource_type_field_name]
resource_type = discriminator_field.to_representation(None)
except: # noqa: E722
warn(
f'sub-serializer {resolved_name} of {self.target.component_name} '
f'must contain the discriminator field "{self.target.resource_type_field_name}". '
f'defaulting to sub-serializer name, but schema will likely not match the API.'
)
resource_type = resolved_name
sub_components.append((resource_type, resolved_schema))
return sub_components
def _get_explicit_sub_components(self, auto_schema, direction):
sub_components = []
for resource_type, sub_serializer in self.target.serializers.items():
sub_serializer = self._prep_serializer(sub_serializer)
(_, resolved_schema) = self._process_serializer(auto_schema, sub_serializer, direction)
if resolved_schema:
sub_components.append((resource_type, resolved_schema))
return sub_components
def _has_discriminator(self):
return self.target.resource_type_field_name is not None
def _prep_serializer(self, serializer):
serializer = force_instance(serializer)
serializer.partial = self.target.partial
return serializer
def _process_serializer(self, auto_schema, serializer, direction):
if is_list_serializer(serializer):
if self._has_discriminator() or self.target._many is not False:
warn(
"To control sub-serializer's \"many\" attribute, following usage pattern is necessary: "
"PolymorphicProxySerializer(serializers=[...], resource_type_field_name=None, "
"many=False). Ignoring serializer as it is not processable in this configuration."
)
return None, None
schema = auto_schema._unwrap_list_serializer(serializer, direction)
return None, schema
elif is_serializer(serializer):
resolved = auto_schema.resolve_serializer(serializer, direction)
return (resolved.name, resolved.ref) if resolved else (None, None)
else:
error("PolymorphicProxySerializer's serializer argument contained unknown objects.")
return None, None

View File

@@ -0,0 +1,288 @@
from contextlib import contextmanager
from typing import Any, Dict
from django.conf import settings
from rest_framework.settings import APISettings, perform_import
SPECTACULAR_DEFAULTS: Dict[str, Any] = {
# A regex specifying the common denominator for all operation paths. If
# SCHEMA_PATH_PREFIX is set to None, drf-spectacular will attempt to estimate
# a common prefix. Use '' to disable.
# Mainly used for tag extraction, where paths like '/api/v1/albums' with
# a SCHEMA_PATH_PREFIX regex '/api/v[0-9]' would yield the tag 'albums'.
'SCHEMA_PATH_PREFIX': None,
# Remove matching SCHEMA_PATH_PREFIX from operation path. Usually used in
# conjunction with appended prefixes in SERVERS.
'SCHEMA_PATH_PREFIX_TRIM': False,
# Insert a manual path prefix to the operation path, e.g. '/service/backend'.
# Use this for example to align paths when the API is mounted as a sub-resource
# behind a proxy and Django is not aware of that. Alternatively, prefixes can
# also specified via SERVERS, but this makes the operation path more explicit.
'SCHEMA_PATH_PREFIX_INSERT': '',
# Coercion of {pk} to {id} is controlled by SCHEMA_COERCE_PATH_PK. Additionally,
# some libraries (e.g. drf-nested-routers) use "_pk" suffixed path variables.
# This setting globally coerces path variables like "{user_pk}" to "{user_id}".
'SCHEMA_COERCE_PATH_PK_SUFFIX': False,
# Schema generation parameters to influence how components are constructed.
# Some schema features might not translate well to your target.
# Demultiplexing/modifying components might help alleviate those issues.
'DEFAULT_GENERATOR_CLASS': 'drf_spectacular.generators.SchemaGenerator',
# Create separate components for PATCH endpoints (without required list)
'COMPONENT_SPLIT_PATCH': True,
# Split components into request and response parts where appropriate
# This setting is highly recommended to achieve the most accurate API
# description, however it comes at the cost of having more components.
'COMPONENT_SPLIT_REQUEST': False,
# Aid client generator targets that have trouble with read-only properties.
'COMPONENT_NO_READ_ONLY_REQUIRED': False,
# Adds "minLength: 1" to fields that do not allow blank strings. Deactivated
# by default because serializers do not strictly enforce this on responses and
# so "minLength: 1" may not always accurately describe API behavior.
# Gets implicitly enabled by COMPONENT_SPLIT_REQUEST, because this can be
# accurately modeled when request and response components are separated.
'ENFORCE_NON_BLANK_FIELDS': False,
# This version string will end up the in schema header. The default OpenAPI
# version is 3.0.3, which is heavily tested. We now also support 3.1.0,
# which contains the same features and a few mandatory, but minor changes.
'OAS_VERSION': '3.0.3',
# Configuration for serving a schema subset with SpectacularAPIView
'SERVE_URLCONF': None,
# complete public schema or a subset based on the requesting user
'SERVE_PUBLIC': True,
# include schema endpoint into schema
'SERVE_INCLUDE_SCHEMA': True,
# list of authentication/permission classes for spectacular's views.
'SERVE_PERMISSIONS': ['rest_framework.permissions.AllowAny'],
# None will default to DRF's AUTHENTICATION_CLASSES
'SERVE_AUTHENTICATION': None,
# Dictionary of general configuration to pass to the SwaggerUI({ ... })
# https://swagger.io/docs/open-source-tools/swagger-ui/usage/configuration/
# The settings are serialized with json.dumps(). If you need customized JS, use a
# string instead. The string must then contain valid JS and is passed unchanged.
'SWAGGER_UI_SETTINGS': {
'deepLinking': True,
},
# Initialize SwaggerUI with additional OAuth2 configuration.
# https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
'SWAGGER_UI_OAUTH2_CONFIG': {},
# Dictionary of general configuration to pass to the Redoc.init({ ... })
# https://redocly.com/docs/redoc/config/#functional-settings
# The settings are serialized with json.dumps(). If you need customized JS, use a
# string instead. The string must then contain valid JS and is passed unchanged.
'REDOC_UI_SETTINGS': {},
# CDNs for swagger and redoc. You can change the version or even host your
# own depending on your requirements. For self-hosting, have a look at
# the sidecar option in the README.
'SWAGGER_UI_DIST': 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@latest',
'SWAGGER_UI_FAVICON_HREF': 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@latest/favicon-32x32.png',
'REDOC_DIST': 'https://cdn.jsdelivr.net/npm/redoc@latest',
# Append OpenAPI objects to path and components in addition to the generated objects
'APPEND_PATHS': {},
'APPEND_COMPONENTS': {},
# STRONGLY DISCOURAGED (with the exception for the djangorestframework-api-key library)
# please don't use this anymore as it has tricky implications that
# are hard to get right. For authentication, OpenApiAuthenticationExtension are
# strongly preferred because they are more robust and easy to write.
# However if used, the list of methods is appended to every endpoint in the schema!
'SECURITY': [],
# Postprocessing functions that run at the end of schema generation.
# must satisfy interface result = hook(generator, request, public, result)
'POSTPROCESSING_HOOKS': [
'drf_spectacular.hooks.postprocess_schema_enums'
],
# Preprocessing functions that run before schema generation.
# must satisfy interface result = hook(endpoints=result) where result
# is a list of Tuples (path, path_regex, method, callback).
# Example: 'drf_spectacular.hooks.preprocess_exclude_path_format'
'PREPROCESSING_HOOKS': [],
# Determines how operations should be sorted. If you intend to do sorting with a
# PREPROCESSING_HOOKS, be sure to disable this setting. If configured, the sorting
# is applied after the PREPROCESSING_HOOKS. Accepts either
# True (drf-spectacular's alpha-sorter), False, or a callable for sort's key arg.
'SORT_OPERATIONS': True,
# enum name overrides. dict with keys "YourEnum" and their choice values "field.choices"
# e.g. {'SomeEnum': ['A', 'B'], 'OtherEnum': 'import.path.to.choices'}
'ENUM_NAME_OVERRIDES': {},
# Adds "blank" and "null" enum choices where appropriate. disable on client generation issues
'ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE': True,
# Add/Append a list of (``choice value`` - choice name) to the enum description string.
'ENUM_GENERATE_CHOICE_DESCRIPTION': True,
# Optional suffix for generated enum.
# e.g. {'ENUM_SUFFIX': "Type"} would produce an enum name 'StatusType'.
'ENUM_SUFFIX': 'Enum',
# function that returns a list of all classes that should be excluded from doc string extraction
'GET_LIB_DOC_EXCLUDES': 'drf_spectacular.plumbing.get_lib_doc_excludes',
# Function that returns a mocked request for view processing. For CLI usage
# original_request will be None.
# interface: request = build_mock_request(method, path, view, original_request, **kwargs)
'GET_MOCK_REQUEST': 'drf_spectacular.plumbing.build_mock_request',
# Camelize names like "operationId" and path parameter names
# Camelization of the operation schema itself requires the addition of
# 'drf_spectacular.contrib.djangorestframework_camel_case.camelize_serializer_fields'
# to POSTPROCESSING_HOOKS. Please note that the hook depends on
# ``djangorestframework_camel_case``, while CAMELIZE_NAMES itself does not.
'CAMELIZE_NAMES': False,
# Changes the location of the action/method on the generated OperationId. For example,
# "POST": "group_person_list", "group_person_create"
# "PRE": "list_group_person", "create_group_person"
'OPERATION_ID_METHOD_POSITION': 'POST',
# Determines if and how free-form 'additionalProperties' should be emitted in the schema. Some
# code generator targets are sensitive to this. None disables generic 'additionalProperties'.
# allowed values are 'dict', 'bool', None
'GENERIC_ADDITIONAL_PROPERTIES': 'dict',
# Path converter schema overrides (e.g. <int:foo>). Can be used to either modify default
# behavior or provide a schema for custom converters registered with register_converter(...).
# Takes converter labels as keys and either basic python types, OpenApiType, or raw schemas
# as values. Example: {'aint': OpenApiTypes.INT, 'bint': str, 'cint': {'type': ...}}
'PATH_CONVERTER_OVERRIDES': {},
# Determines whether operation parameters should be sorted alphanumerically or just in
# the order they arrived. Accepts either True, False, or a callable for sort's key arg.
'SORT_OPERATION_PARAMETERS': True,
# @extend_schema allows to specify status codes besides 200. This functionality is usually used
# to describe error responses, which rarely make use of list mechanics. Therefore, we suppress
# listing (pagination and filtering) on non-2XX status codes by default. Toggle this to enable
# list responses with ListSerializers/many=True irrespective of the status code.
'ENABLE_LIST_MECHANICS_ON_NON_2XX': False,
# This setting allows you to deviate from the default manager by accessing a different model
# property. We use "objects" by default for compatibility reasons. Using "_default_manager"
# will likely fix most issues, though you are free to choose any name.
"DEFAULT_QUERY_MANAGER": 'objects',
# Controls which authentication methods are exposed in the schema. If not None, will hide
# authentication classes that are not contained in the whitelist. Use full import paths
# like ['rest_framework.authentication.TokenAuthentication', ...].
# Empty list ([]) will hide all authentication methods. The default None will show all.
'AUTHENTICATION_WHITELIST': None,
# Controls which parsers are exposed in the schema. Works analog to AUTHENTICATION_WHITELIST.
# List of allowed parsers or None to allow all.
'PARSER_WHITELIST': None,
# Controls which renderers are exposed in the schema. Works analog to AUTHENTICATION_WHITELIST.
# rest_framework.renderers.BrowsableAPIRenderer is ignored by default if whitelist is None
'RENDERER_WHITELIST': None,
# Option for turning off error and warn messages
'DISABLE_ERRORS_AND_WARNINGS': False,
# Runs exemplary schema generation and emits warnings as part of "./manage.py check --deploy"
'ENABLE_DJANGO_DEPLOY_CHECK': True,
# General schema metadata. Refer to spec for valid inputs
# https://spec.openapis.org/oas/v3.0.3#openapi-object
'TITLE': '',
'DESCRIPTION': '',
'TOS': None,
# Optional: MAY contain "name", "url", "email"
'CONTACT': {},
# Optional: MUST contain "name", MAY contain URL
'LICENSE': {},
# Statically set schema version. May also be an empty string. When used together with
# view versioning, will become '0.0.0 (v2)' for 'v2' versioned requests.
# Set VERSION to None if only the request version should be rendered.
'VERSION': '0.0.0',
# Optional list of servers.
# Each entry MUST contain "url", MAY contain "description", "variables"
# e.g. [{'url': 'https://example.com/v1', 'description': 'Text'}, ...]
'SERVERS': [],
# Tags defined in the global scope
'TAGS': [],
# Optional: List of OpenAPI 3.1 webhooks. Each entry should be an import path to an
# OpenApiWebhook instance.
'WEBHOOKS': [],
# Optional: MUST contain 'url', may contain "description"
'EXTERNAL_DOCS': {},
# Arbitrary specification extensions attached to the schema's info object.
# https://swagger.io/specification/#specification-extensions
'EXTENSIONS_INFO': {},
# Arbitrary specification extensions attached to the schema's root object.
# https://swagger.io/specification/#specification-extensions
'EXTENSIONS_ROOT': {},
# Oauth2 related settings. used for example by django-oauth2-toolkit.
# https://spec.openapis.org/oas/v3.0.3#oauth-flows-object
'OAUTH2_FLOWS': [],
'OAUTH2_AUTHORIZATION_URL': None,
'OAUTH2_TOKEN_URL': None,
'OAUTH2_REFRESH_URL': None,
'OAUTH2_SCOPES': None,
}
IMPORT_STRINGS = [
'DEFAULT_GENERATOR_CLASS',
'SERVE_AUTHENTICATION',
'SERVE_PERMISSIONS',
'POSTPROCESSING_HOOKS',
'PREPROCESSING_HOOKS',
'GET_LIB_DOC_EXCLUDES',
'GET_MOCK_REQUEST',
'SORT_OPERATIONS',
'SORT_OPERATION_PARAMETERS',
'AUTHENTICATION_WHITELIST',
'RENDERER_WHITELIST',
'PARSER_WHITELIST',
'WEBHOOKS',
]
class SpectacularSettings(APISettings):
_original_settings: Dict[str, Any] = {}
def apply_patches(self, patches):
for attr, val in patches.items():
if attr.startswith('SERVE_') or attr == 'DEFAULT_GENERATOR_CLASS':
raise AttributeError(
f'{attr} not allowed in custom_settings. use dedicated parameter instead.'
)
if attr in self.import_strings:
val = perform_import(val, attr)
# load and store original value, then override __dict__ entry
self._original_settings[attr] = getattr(self, attr)
setattr(self, attr, val)
def clear_patches(self):
for attr, orig_val in self._original_settings.items():
setattr(self, attr, orig_val)
self._original_settings = {}
spectacular_settings = SpectacularSettings(
user_settings=getattr(settings, 'SPECTACULAR_SETTINGS', {}), # type: ignore
defaults=SPECTACULAR_DEFAULTS, # type: ignore
import_strings=IMPORT_STRINGS,
)
@contextmanager
def patched_settings(patches):
""" temporarily patch the global spectacular settings (or do nothing) """
if not patches:
yield
else:
try:
spectacular_settings.apply_patches(patches)
yield
finally:
spectacular_settings.clear_patches()

View File

@@ -0,0 +1,32 @@
<!DOCTYPE html>
<html>
<head>
{% block head %}
<title>{{ title|default:"Redoc" }}</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Montserrat:wght@300;400;700&family=Roboto:wght@300;400;700&display=swap">
<style>
{# Redoc doesn't change outer page styles. #}
body { margin: 0; padding: 0; }
</style>
{% endblock head %}
</head>
<body>
{% block body %}
{% if settings %}
<div id="redoc-container"></div>
<script src="{{ redoc_standalone }}"></script>
<script>
const redocSettings = {{ settings|safe }};
Redoc.init("{{ schema_url }}", redocSettings, document.getElementById('redoc-container'))
</script>
{% else %}
<redoc spec-url="{{ schema_url }}"></redoc>
<script src="{{ redoc_standalone }}"></script>
{% endif %}
{% endblock body %}
</body>
</html>

View File

@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html>
<head>
{% block head %}
<title>{{ title|default:"Swagger" }}</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
{% if favicon_href %}<link rel="icon" href="{{ favicon_href }}">{% endif %}
<link rel="stylesheet" href="{{ swagger_ui_css }}">
<style>
html { box-sizing: border-box; overflow-y: scroll; }
*, *:after, *:before { box-sizing: inherit; }
body { background: #fafafa; margin: 0; }
</style>
{% endblock head %}
</head>
<body>
{% block body %}
<div id="swagger-ui"></div>
<script src="{{ swagger_ui_bundle }}"></script>
<script src="{{ swagger_ui_standalone }}"></script>
{% if script_url %}
<script src="{{ script_url }}"></script>
{% else %}
<script>
{% include template_name_js %}
</script>
{% endif %}
{% endblock %}
</body>
</html>

View File

@@ -0,0 +1,126 @@
"use strict";
const swaggerSettings = {{ settings|safe }};
const schemaAuthNames = {{ schema_auth_names|safe }};
let schemaAuthFailed = false;
const plugins = [];
const reloadSchemaOnAuthChange = () => {
return {
statePlugins: {
auth: {
wrapActions: {
authorizeOauth2:(ori) => (...args) => {
schemaAuthFailed = false;
setTimeout(() => ui.specActions.download());
return ori(...args);
},
authorize: (ori) => (...args) => {
schemaAuthFailed = false;
setTimeout(() => ui.specActions.download());
return ori(...args);
},
logout: (ori) => (...args) => {
schemaAuthFailed = false;
setTimeout(() => ui.specActions.download());
return ori(...args);
},
},
},
},
};
};
if (schemaAuthNames.length > 0) {
plugins.push(reloadSchemaOnAuthChange);
}
const uiInitialized = () => {
try {
ui;
return true;
} catch {
return false;
}
};
const isSchemaUrl = (url) => {
if (!uiInitialized()) {
return false;
}
return url === new URL(ui.getConfigs().url, document.baseURI).href;
};
const responseInterceptor = (response, ...args) => {
if (!response.ok && isSchemaUrl(response.url)) {
console.warn("schema request received '" + response.status + "'. disabling credentials for schema till logout.");
if (!schemaAuthFailed) {
// only retry once to prevent endless loop.
schemaAuthFailed = true;
setTimeout(() => ui.specActions.download());
}
}
return response;
};
const injectAuthCredentials = (request) => {
let authorized;
if (uiInitialized()) {
const state = ui.getState().get("auth").get("authorized");
if (state !== undefined && Object.keys(state.toJS()).length !== 0) {
authorized = state.toJS();
}
} else if (![undefined, "{}"].includes(localStorage.authorized)) {
authorized = JSON.parse(localStorage.authorized);
}
if (authorized === undefined) {
return;
}
for (const authName of schemaAuthNames) {
const authDef = authorized[authName];
if (authDef === undefined || authDef.schema === undefined) {
continue;
}
if (authDef.schema.type === "http" && authDef.schema.scheme === "bearer") {
request.headers["Authorization"] = "Bearer " + authDef.value;
return;
} else if (authDef.schema.type === "http" && authDef.schema.scheme === "basic") {
request.headers["Authorization"] = "Basic " + btoa(authDef.value.username + ":" + authDef.value.password);
return;
} else if (authDef.schema.type === "apiKey" && authDef.schema.in === "header") {
request.headers[authDef.schema.name] = authDef.value;
return;
} else if (authDef.schema.type === "oauth2" && authDef.token.token_type === "Bearer") {
request.headers["Authorization"] = `Bearer ${authDef.token.access_token}`;
return;
}
}
};
const requestInterceptor = (request, ...args) => {
if (request.loadSpec && schemaAuthNames.length > 0 && !schemaAuthFailed) {
try {
injectAuthCredentials(request);
} catch (e) {
console.error("schema auth injection failed with error: ", e);
}
}
// selectively omit adding headers to mitigate CORS issues.
if (!["GET", undefined].includes(request.method) && request.credentials === "same-origin") {
request.headers["{{ csrf_header_name }}"] = "{{ csrf_token }}";
}
return request;
};
const ui = SwaggerUIBundle({
url: "{{ schema_url|escapejs }}",
dom_id: "#swagger-ui",
presets: [SwaggerUIBundle.presets.apis],
plugins,
layout: "BaseLayout",
requestInterceptor,
responseInterceptor,
...swaggerSettings,
});
{% if oauth2_config %}ui.initOAuth({{ oauth2_config|safe }});{% endif %}

View File

@@ -0,0 +1,177 @@
import enum
import typing
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from ipaddress import IPv4Address, IPv6Address
from uuid import UUID
_KnownPythonTypes = typing.Type[typing.Union[
str, float, bool, bytes, int, dict, UUID, Decimal, datetime, date, time,
timedelta, IPv4Address, IPv6Address,
]]
class OpenApiTypes(enum.Enum):
"""
Basic types known to the OpenAPI specification or at least common format extension of it.
- Use ``BYTE`` for base64-encoded data wrapped in a string
- Use ``BINARY`` for raw binary data
- Use ``OBJECT`` for arbitrary free-form object (usually a :py:class:`dict`)
"""
#: Converted to ``{"type": "number"}``.
NUMBER = enum.auto()
#: Converted to ``{"type": "number", "format": "float"}``.
#: Equivalent to :py:class:`float`.
FLOAT = enum.auto()
#: Converted to ``{"type": "number", "format": "double"}``.
DOUBLE = enum.auto()
#: Converted to ``{"type": "boolean"}``.
#: Equivalent to :py:class:`bool`.
BOOL = enum.auto()
#: Converted to ``{"type": "string"}``.
#: Equivalent to :py:class:`str`.
STR = enum.auto()
#: Converted to ``{"type": "string", "format": "byte"}``.
#: Use this for base64-encoded data wrapped in a string.
BYTE = enum.auto()
#: Converted to ``{"type": "string", "format": "binary"}``.
#: Equivalent to :py:class:`bytes`.
#: Use this for raw binary data.
BINARY = enum.auto()
#: Converted to ``{"type": "string", "format": "password"}``.
PASSWORD = enum.auto()
#: Converted to ``{"type": "integer"}``.
#: Equivalent to :py:class:`int`.
INT = enum.auto()
#: Converted to ``{"type": "integer", "format": "int32"}``.
INT32 = enum.auto()
#: Converted to ``{"type": "integer", "format": "int64"}``.
INT64 = enum.auto()
#: Converted to ``{"type": "string", "format": "uuid"}``.
#: Equivalent to :py:class:`~uuid.UUID`.
UUID = enum.auto()
#: Converted to ``{"type": "string", "format": "uri"}``.
URI = enum.auto()
#: Converted to ``{"type": "string", "format": "uri-reference"}``.
URI_REF = enum.auto()
#: Converted to ``{"type": "string", "format": "uri-template"}``.
URI_TPL = enum.auto()
#: Converted to ``{"type": "string", "format": "iri"}``.
IRI = enum.auto()
#: Converted to ``{"type": "string", "format": "iri-reference"}``.
IRI_REF = enum.auto()
#: Converted to ``{"type": "string", "format": "ipv4"}``.
#: Equivalent to :py:class:`~ipaddress.IPv4Address`.
IP4 = enum.auto()
#: Converted to ``{"type": "string", "format": "ipv6"}``.
#: Equivalent to :py:class:`~ipaddress.IPv6Address`.
IP6 = enum.auto()
#: Converted to ``{"type": "string", "format": "hostname"}``.
HOSTNAME = enum.auto()
#: Converted to ``{"type": "string", "format": "idn-hostname"}``.
IDN_HOSTNAME = enum.auto()
#: Converted to ``{"type": "number", "format": "double"}``.
#: The same as :py:attr:`~drf_spectacular.types.OpenApiTypes.DOUBLE`.
#: Equivalent to :py:class:`~decimal.Decimal`.
DECIMAL = enum.auto()
#: Converted to ``{"type": "string", "format": "date-time"}``.
#: Equivalent to :py:class:`~datetime.datetime`.
DATETIME = enum.auto()
#: Converted to ``{"type": "string", "format": "date"}``.
#: Equivalent to :py:class:`~datetime.date`.
DATE = enum.auto()
#: Converted to ``{"type": "string", "format": "time"}``.
#: Equivalent to :py:class:`~datetime.time`.
TIME = enum.auto()
#: Converted to ``{"type": "string", "format": "duration"}``.
#: Equivalent to :py:class:`~datetime.timedelta`.
#: Expressed according to ISO 8601.
DURATION = enum.auto()
#: Converted to ``{"type": "string", "format": "email"}``.
EMAIL = enum.auto()
#: Converted to ``{"type": "string", "format": "idn-email"}``.
IDN_EMAIL = enum.auto()
#: Converted to ``{"type": "string", "format": "json-pointer"}``.
JSON_PTR = enum.auto()
#: Converted to ``{"type": "string", "format": "relative-json-pointer"}``.
JSON_PTR_REL = enum.auto()
#: Converted to ``{"type": "string", "format": "regex"}``.
REGEX = enum.auto()
#: Converted to ``{"type": "object", ...}``.
#: Use this for arbitrary free-form objects (usually a :py:class:`dict`).
#: The ``additionalProperties`` item is added depending on the ``GENERIC_ADDITIONAL_PROPERTIES`` setting.
OBJECT = enum.auto()
#: Equivalent to :py:data:`None`.
#: This signals that the request or response is empty.
NONE = enum.auto()
#: Converted to ``{}`` which sets no type and format.
#: Equivalent to :py:class:`typing.Any`.
ANY = enum.auto()
# make a copy with dict() before modifying returned dict
OPENAPI_TYPE_MAPPING = {
OpenApiTypes.NUMBER: {'type': 'number'},
OpenApiTypes.FLOAT: {'type': 'number', 'format': 'float'},
OpenApiTypes.DOUBLE: {'type': 'number', 'format': 'double'},
OpenApiTypes.BOOL: {'type': 'boolean'},
OpenApiTypes.STR: {'type': 'string'},
OpenApiTypes.BYTE: {'type': 'string', 'format': 'byte'},
OpenApiTypes.BINARY: {'type': 'string', 'format': 'binary'},
OpenApiTypes.PASSWORD: {'type': 'string', 'format': 'password'},
OpenApiTypes.INT: {'type': 'integer'},
OpenApiTypes.INT32: {'type': 'integer', 'format': 'int32'},
OpenApiTypes.INT64: {'type': 'integer', 'format': 'int64'},
OpenApiTypes.UUID: {'type': 'string', 'format': 'uuid'},
OpenApiTypes.URI: {'type': 'string', 'format': 'uri'},
OpenApiTypes.URI_REF: {'type': 'string', 'format': 'uri-reference'},
OpenApiTypes.URI_TPL: {'type': 'string', 'format': 'uri-template'},
OpenApiTypes.IRI: {'type': 'string', 'format': 'iri'},
OpenApiTypes.IRI_REF: {'type': 'string', 'format': 'iri-reference'},
OpenApiTypes.IP4: {'type': 'string', 'format': 'ipv4'},
OpenApiTypes.IP6: {'type': 'string', 'format': 'ipv6'},
OpenApiTypes.HOSTNAME: {'type': 'string', 'format': 'hostname'},
OpenApiTypes.IDN_HOSTNAME: {'type': 'string', 'format': 'idn-hostname'},
OpenApiTypes.DECIMAL: {'type': 'number', 'format': 'double'},
OpenApiTypes.DATETIME: {'type': 'string', 'format': 'date-time'},
OpenApiTypes.DATE: {'type': 'string', 'format': 'date'},
OpenApiTypes.TIME: {'type': 'string', 'format': 'time'},
OpenApiTypes.DURATION: {'type': 'string', 'format': 'duration'}, # ISO 8601
OpenApiTypes.EMAIL: {'type': 'string', 'format': 'email'},
OpenApiTypes.IDN_EMAIL: {'type': 'string', 'format': 'idn-email'},
OpenApiTypes.JSON_PTR: {'type': 'string', 'format': 'json-pointer'},
OpenApiTypes.JSON_PTR_REL: {'type': 'string', 'format': 'relative-json-pointer'},
OpenApiTypes.REGEX: {'type': 'string', 'format': 'regex'},
OpenApiTypes.ANY: {},
OpenApiTypes.NONE: None,
# OpenApiTypes.OBJECT is inserted at runtime due to dependency on settings
}
PYTHON_TYPE_MAPPING = {
str: OpenApiTypes.STR,
float: OpenApiTypes.DOUBLE,
bool: OpenApiTypes.BOOL,
bytes: OpenApiTypes.BINARY,
int: OpenApiTypes.INT,
UUID: OpenApiTypes.UUID,
Decimal: OpenApiTypes.DECIMAL,
datetime: OpenApiTypes.DATETIME,
date: OpenApiTypes.DATE,
time: OpenApiTypes.TIME,
timedelta: OpenApiTypes.DURATION,
IPv4Address: OpenApiTypes.IP4,
IPv6Address: OpenApiTypes.IP6,
dict: OpenApiTypes.OBJECT,
typing.Any: OpenApiTypes.ANY,
None: OpenApiTypes.NONE,
}
DJANGO_PATH_CONVERTER_MAPPING = {
'int': OpenApiTypes.INT,
'path': OpenApiTypes.STR,
'slug': OpenApiTypes.STR,
'str': OpenApiTypes.STR,
'uuid': OpenApiTypes.UUID,
'drf_format_suffix': OpenApiTypes.STR,
}

View File

@@ -0,0 +1,693 @@
import inspect
import sys
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
from django.utils.functional import Promise
# direct import due to https://github.com/microsoft/pyright/issues/3025
if sys.version_info >= (3, 8):
from typing import Final, Literal
else:
from typing_extensions import Final, Literal
from rest_framework.fields import Field, empty
from rest_framework.serializers import ListSerializer, Serializer
from rest_framework.settings import api_settings
from drf_spectacular.drainage import (
error, get_view_method_names, isolate_view_method, set_override, warn,
)
from drf_spectacular.types import OpenApiTypes, _KnownPythonTypes
_ListSerializerType = Union[ListSerializer, Type[ListSerializer]]
_SerializerType = Union[Serializer, Type[Serializer]]
_FieldType = Union[Field, Type[Field]]
_ParameterLocationType = Literal['query', 'path', 'header', 'cookie']
_StrOrPromise = Union[str, Promise]
_SchemaType = Dict[str, Any]
Direction = Literal['request', 'response']
class PolymorphicProxySerializer(Serializer):
"""
This class is to be used with :func:`@extend_schema <.extend_schema>` to
signal a request/response might be polymorphic (accepts/returns data
possibly from different serializers). Usage usually looks like this:
.. code-block::
@extend_schema(
request=PolymorphicProxySerializer(
component_name='MetaPerson',
serializers=[
LegalPersonSerializer, NaturalPersonSerializer,
],
resource_type_field_name='person_type',
)
)
def create(self, request, *args, **kwargs):
return Response(...)
**Beware** that this is not a real serializer and it will raise an AssertionError
if used in that way. It **cannot** be used in views as ``serializer_class``
or as field in an actual serializer. It is solely meant for annotation purposes.
Also make sure that each sub-serializer has a field named after the value of
``resource_type_field_name`` (discriminator field). Generated clients will likely
depend on the existence of this field.
Setting ``resource_type_field_name`` to ``None`` will remove the discriminator
altogether. This may be useful in certain situations, but will most likely break
client generation. Another use-case is explicit control over sub-serializer's ``many``
attribute. To explicitly control this aspect, you need disable the discriminator with
``resource_type_field_name=None`` as well as disable automatic list handling with
``many=False``.
It is **strongly** recommended to pass the ``Serializers`` as **list**,
and by that let *drf-spectacular* retrieve the field and handle the mapping
automatically. In special circumstances, the field may not available when
*drf-spectacular* processes the serializer. In those cases you can explicitly state
the mapping with ``{'legal': LegalPersonSerializer, ...}``, but it is then your
responsibility to have a valid mapping.
It is also permissible to provide a callable with no parameters for ``serializers``,
such as a lambda that will return an appropriate list or dict when evaluated.
"""
def __init__(
self,
component_name: str,
serializers: Union[
Sequence[_SerializerType],
Dict[str, _SerializerType],
Callable[[], Sequence[_SerializerType]],
Callable[[], Dict[str, _SerializerType]]
],
resource_type_field_name: Optional[str],
many: Optional[bool] = None,
**kwargs
):
self.component_name = component_name
self.serializers = serializers
self.resource_type_field_name = resource_type_field_name
if self._many is False: # type: ignore[attr-defined]
set_override(self, 'many', False)
# retain kwargs in context for potential anonymous re-init with many=True
kwargs.setdefault('context', {}).update({
'component_name': component_name,
'serializers': serializers,
'resource_type_field_name': resource_type_field_name
})
super().__init__(**kwargs)
def __new__(cls, *args, **kwargs):
many = kwargs.pop('many', None)
if many is True:
context = kwargs.get('context', {})
for arg in ['component_name', 'serializers', 'resource_type_field_name']:
if arg in context:
kwargs[arg] = context.pop(arg) # re-apply retained args
instance = cls.many_init(*args, **kwargs)
else:
instance = super().__new__(cls, *args, **kwargs)
instance._many = many
return instance
@property
def serializers(self):
if callable(self._serializers):
self._serializers = self._serializers()
return self._serializers
@serializers.setter
def serializers(self, value):
self._serializers = value
@property
def data(self):
self._trap()
def to_internal_value(self, data):
self._trap()
def to_representation(self, instance):
self._trap()
def _trap(self):
raise AssertionError(
"PolymorphicProxySerializer is an annotation helper and not supposed to "
"be used for real requests. See documentation for correct usage."
)
class OpenApiSchemaBase:
pass
class OpenApiExample(OpenApiSchemaBase):
"""
Helper class to document a API parameter / request body / response body
with a concrete example value.
It is recommended to provide a singular example value, since pagination
and list responses are handled by drf-spectacular.
The example will be attached to the operation object where appropriate,
i.e. where the given ``media_type``, ``status_code`` and modifiers match.
Example that do not match any scenario are ignored.
- media_type will default to 'application/json' unless implicitly specified
through :class:`.OpenApiResponse`
- status_codes will default to [200, 201] unless implicitly specified
through :class:`.OpenApiResponse`
"""
def __init__(
self,
name: str,
value: Any = empty,
external_value: str = '',
summary: _StrOrPromise = '',
description: _StrOrPromise = '',
request_only: bool = False,
response_only: bool = False,
parameter_only: Optional[Tuple[str, _ParameterLocationType]] = None,
media_type: Optional[str] = None,
status_codes: Optional[Sequence[Union[str, int]]] = None,
):
self.name = name
self.summary = summary
self.description = description
self.value = value
self.external_value = external_value
self.request_only = request_only
self.response_only = response_only
self.parameter_only = parameter_only
self.media_type = media_type
self.status_codes = status_codes
class OpenApiParameter(OpenApiSchemaBase):
"""
Helper class to document request query/path/header/cookie parameters.
Can also be used to document response headers.
Please note that not all arguments apply to all ``location``/``type``/direction
variations, e.g. path parameters are ``required=True`` by definition.
For valid ``style`` choices please consult the
`OpenAPI specification <https://swagger.io/specification/#style-values>`_.
"""
QUERY: Final = 'query'
PATH: Final = 'path'
HEADER: Final = 'header'
COOKIE: Final = 'cookie'
def __init__(
self,
name: str,
type: Union[_SerializerType, _KnownPythonTypes, OpenApiTypes, _SchemaType] = str,
location: _ParameterLocationType = QUERY,
required: bool = False,
description: _StrOrPromise = '',
enum: Optional[Sequence[Any]] = None,
pattern: Optional[str] = None,
deprecated: bool = False,
style: Optional[str] = None,
explode: Optional[bool] = None,
default: Any = None,
allow_blank: bool = True,
many: Optional[bool] = None,
examples: Optional[Sequence[OpenApiExample]] = None,
extensions: Optional[Dict[str, Any]] = None,
exclude: bool = False,
response: Union[bool, Sequence[Union[int, str]]] = False,
):
self.name = name
self.type = type
self.location = location
self.required = required
self.description = description
self.enum = enum
self.pattern = pattern
self.deprecated = deprecated
self.style = style
self.explode = explode
self.default = default
self.allow_blank = allow_blank
self.many = many
self.examples = examples or []
self.extensions = extensions
self.exclude = exclude
self.response = response
class OpenApiResponse(OpenApiSchemaBase):
"""
Helper class to bundle a response object (``Serializer``, ``OpenApiType``,
raw schema, etc) together with a response object description and/or examples.
Examples can alternatively be provided via :func:`@extend_schema <.extend_schema>`.
This class is especially helpful for explicitly describing status codes on a
"Response Object" level.
"""
def __init__(
self,
response: Any = None,
description: _StrOrPromise = '',
examples: Optional[Sequence[OpenApiExample]] = None
):
self.response = response
self.description = description
self.examples = examples or []
class OpenApiRequest(OpenApiSchemaBase):
"""
Helper class to combine a request object (``Serializer``, ``OpenApiType``,
raw schema, etc.) together with an encoding object and/or examples.
Examples can alternatively be provided via :func:`@extend_schema <.extend_schema>`.
This class is especially helpful for customizing value encoding for
``application/x-www-form-urlencoded`` and ``multipart/*``. The encoding parameter
takes a dictionary with field names as keys and encoding objects as values.
Refer to the `specification <https://swagger.io/specification/#encoding-object>`_
on how to build a valid encoding object.
"""
def __init__(
self,
request: Any = None,
encoding: Optional[Dict[str, Dict[str, Any]]] = None,
examples: Optional[Sequence[OpenApiExample]] = None,
):
self.request = request
self.encoding = encoding
self.examples = examples or []
F = TypeVar('F', bound=Callable[..., Any])
class OpenApiCallback(OpenApiSchemaBase):
"""
Helper class to bundle a callback definition. This specifies a view on the callee's
side, effectively stating the expectations on the receiving end. Please note that this
particular :func:`@extend_schema <.extend_schema>` instance operates from the perspective
of the callback origin, which means that ``request`` specifies the outgoing request.
For convenience sake, we assume the callback sends ``application/json`` and return a ``200``.
If that is not sufficient, you can use ``request`` and ``responses`` overloads just as you
normally would.
:param name: Name under which the this callback is listed in the schema.
:param path: Path on which the callback operation is performed. To reference request
body contents, please refer to OpenAPI specification's
`key expressions <https://swagger.io/specification/#key-expression>`_ for valid choices.
:param decorator: :func:`@extend_schema <.extend_schema>` decorator that specifies the receiving
endpoint. In this special context the allowed parameters are ``requests``, ``responses``,
``summary``, ``description``, ``deprecated``.
"""
def __init__(
self,
name: _StrOrPromise,
path: str,
decorator: Union[Callable[[F], F], Dict[str, Callable[[F], F]], Dict[str, Any]],
):
self.name = name
self.path = path
self.decorator = decorator
class OpenApiWebhook(OpenApiSchemaBase):
"""
Helper class to document webhook definitions. A webhook specifies a possible out-of-band
request initiated by the API provider and the expected responses from the consumer.
Please note that this particular :func:`@extend_schema <.extend_schema>` instance operates
from the perspective of the webhook origin, which means that ``request`` specifies the
outgoing request.
For convenience sake, we assume the API provider sends a POST request with a body of type
``application/json`` and the receiver responds with ``200`` if the event was successfully
received.
:param name: Name under which this webhook is listed in the schema.
:param decorator: :func:`@extend_schema <.extend_schema>` decorator that specifies the receiving
endpoint. In this special context the allowed parameters are ``requests``, ``responses``,
``summary``, ``description``, ``deprecated``.
"""
def __init__(
self,
name: _StrOrPromise,
decorator: Union[Callable[[F], F], Dict[str, Callable[[F], F]], Dict[str, Any]],
):
self.name = name
self.decorator = decorator
def extend_schema(
operation_id: Optional[str] = None,
parameters: Optional[Sequence[Union[OpenApiParameter, _SerializerType]]] = None,
request: Any = empty,
responses: Any = empty,
auth: Optional[Sequence[str]] = None,
description: Optional[_StrOrPromise] = None,
summary: Optional[_StrOrPromise] = None,
deprecated: Optional[bool] = None,
tags: Optional[Sequence[str]] = None,
filters: Optional[bool] = None,
exclude: Optional[bool] = None,
operation: Optional[_SchemaType] = None,
methods: Optional[Sequence[str]] = None,
versions: Optional[Sequence[str]] = None,
examples: Optional[Sequence[OpenApiExample]] = None,
extensions: Optional[Dict[str, Any]] = None,
callbacks: Optional[Sequence[OpenApiCallback]] = None,
external_docs: Optional[Union[Dict[str, str], str]] = None,
) -> Callable[[F], F]:
"""
Decorator mainly for the "view" method kind. Partially or completely overrides
what would be otherwise generated by drf-spectacular.
:param operation_id: replaces the auto-generated operation_id. make sure there
are no naming collisions.
:param parameters: list of additional or replacement parameters added to the
auto-discovered fields.
:param responses: replaces the discovered Serializer. Takes a variety of
inputs that can be used individually or combined
- ``Serializer`` class
- ``Serializer`` instance (e.g. ``Serializer(many=True)`` for listings)
- basic types or instances of ``OpenApiTypes``
- :class:`.OpenApiResponse` for bundling any of the other choices together with
either a dedicated response description and/or examples.
- :class:`.PolymorphicProxySerializer` for signaling that
the operation may yield data from different serializers depending
on the circumstances.
- ``dict`` with status codes as keys and one of the above as values.
Additionally in this case, it is also possible to provide a raw schema dict
as value.
- ``dict`` with tuples (status_code, media_type) as keys and one of the above
as values. Additionally in this case, it is also possible to provide a raw
schema dict as value.
:param request: replaces the discovered ``Serializer``. Takes a variety of inputs
- ``Serializer`` class/instance
- basic types or instances of ``OpenApiTypes``
- :class:`.PolymorphicProxySerializer` for signaling that the operation
accepts a set of different types of objects.
- ``dict`` with media_type as keys and one of the above as values. Additionally, in
this case, it is also possible to provide a raw schema dict as value.
:param auth: replace discovered auth with explicit list of auth methods
:param description: replaces discovered doc strings
:param summary: an optional short summary of the description
:param deprecated: mark operation as deprecated
:param tags: override default list of tags
:param filters: ignore list detection and forcefully enable/disable filter discovery
:param exclude: set True to exclude operation from schema
:param operation: manually override what auto-discovery would generate. you must
provide a OpenAPI3-compliant dictionary that gets directly translated to YAML.
:param methods: scope extend_schema to specific methods. matches all by default.
:param versions: scope extend_schema to specific API version. matches all by default.
:param examples: attach request/response examples to the operation
:param extensions: specification extensions, e.g. ``x-badges``, ``x-code-samples``, etc.
:param callbacks: associate callbacks with this endpoint
:param external_docs: Link external documentation. Provide a dict with an "url" key and
optionally a "description" key. For convenience, if only a string is given it is
treated as the URL.
:return:
"""
if methods is not None:
methods = [method.upper() for method in methods]
def decorator(f):
BaseSchema = (
# explicit manually set schema or previous view annotation
getattr(f, 'schema', None)
# previously set schema with @extend_schema on views methods
or getattr(f, 'kwargs', {}).get('schema', None)
# previously set schema with @extend_schema on @api_view
or getattr(getattr(f, 'cls', None), 'kwargs', {}).get('schema', None)
# the default
or api_settings.DEFAULT_SCHEMA_CLASS
)
if not inspect.isclass(BaseSchema):
BaseSchema = BaseSchema.__class__
def is_in_scope(ext_schema):
version, _ = ext_schema.view.determine_version(
ext_schema.view.request,
**ext_schema.view.kwargs
)
version_scope = versions is None or version in versions
method_scope = methods is None or ext_schema.method in methods
return method_scope and version_scope
class ExtendedSchema(BaseSchema):
def get_operation(self, path, path_regex, path_prefix, method, registry):
self.method = method.upper()
if operation is not None and is_in_scope(self):
return operation
return super().get_operation(path, path_regex, path_prefix, method, registry)
def is_excluded(self):
if exclude is not None and is_in_scope(self):
return exclude
return super().is_excluded()
def get_operation_id(self):
if operation_id and is_in_scope(self):
return operation_id
return super().get_operation_id()
def get_override_parameters(self):
if parameters and is_in_scope(self):
return super().get_override_parameters() + parameters
return super().get_override_parameters()
def get_auth(self):
if auth is not None and is_in_scope(self):
return auth
return super().get_auth()
def get_examples(self):
if examples and is_in_scope(self):
return super().get_examples() + examples
return super().get_examples()
def get_request_serializer(self):
if request is not empty and is_in_scope(self):
return request
return super().get_request_serializer()
def get_response_serializers(self):
if responses is not empty and is_in_scope(self):
return responses
return super().get_response_serializers()
def get_description(self):
if description and is_in_scope(self):
return description
return super().get_description()
def get_summary(self):
if summary and is_in_scope(self):
return str(summary)
return super().get_summary()
def is_deprecated(self):
if deprecated and is_in_scope(self):
return deprecated
return super().is_deprecated()
def get_tags(self):
if tags is not None and is_in_scope(self):
return tags
return super().get_tags()
def get_extensions(self):
if extensions and is_in_scope(self):
return extensions
return super().get_extensions()
def get_filter_backends(self):
if filters is not None and is_in_scope(self):
return getattr(self.view, 'filter_backends', []) if filters else []
return super().get_filter_backends()
def get_callbacks(self):
if callbacks is not None and is_in_scope(self):
return callbacks
return super().get_callbacks()
def get_external_docs(self):
if external_docs is not None and is_in_scope(self):
return external_docs
return super().get_external_docs()
if inspect.isclass(f):
# either direct decoration of views, or unpacked @api_view from OpenApiViewExtension
if operation_id is not None or operation is not None:
error(
f'using @extend_schema on viewset class {f.__name__} with parameters '
f'operation_id or operation will most likely result in a broken schema.',
delayed=f,
)
# reorder schema class MRO so that view method annotation takes precedence
# over view class annotation. only relevant if there is a method annotation
for view_method_name in get_view_method_names(view=f, schema=BaseSchema):
if 'schema' not in getattr(getattr(f, view_method_name), 'kwargs', {}):
continue
view_method = isolate_view_method(f, view_method_name)
view_method.kwargs['schema'] = type(
'ExtendedMetaSchema', (view_method.kwargs['schema'], ExtendedSchema), {}
)
# persist schema on class to provide annotation to derived view methods.
# the second purpose is to serve as base for view multi-annotation
f.schema = ExtendedSchema()
return f
elif callable(f) and hasattr(f, 'cls'):
# 'cls' attr signals that as_view() was called, which only applies to @api_view.
# keep a "unused" schema reference at root level for multi annotation convenience.
setattr(f.cls, 'kwargs', {'schema': ExtendedSchema})
# set schema on method kwargs context to emulate regular view behaviour.
for method in f.cls.http_method_names:
setattr(getattr(f.cls, method), 'kwargs', {'schema': ExtendedSchema})
return f
elif callable(f):
# custom actions have kwargs in their context, others don't. create it so our create_view
# implementation can overwrite the default schema
if not hasattr(f, 'kwargs'):
f.kwargs = {}
# this simulates what @action is actually doing. somewhere along the line in this process
# the schema is picked up from kwargs and used. it's involved my dear friends.
# use class instead of instance due to descriptor weakref reverse collisions
f.kwargs['schema'] = ExtendedSchema
return f
else:
return f
return decorator
def extend_schema_field(
field: Union[_SerializerType, _FieldType, OpenApiTypes, _SchemaType, _KnownPythonTypes],
component_name: Optional[str] = None
) -> Callable[[F], F]:
"""
Decorator for the "field" kind. Can be used with ``SerializerMethodField`` (annotate the actual
method) or with custom ``serializers.Field`` implementations.
If your custom serializer field base class is already the desired type, decoration is not necessary.
To override the discovered base class type, you can decorate your custom field class.
Always takes precedence over other mechanisms (e.g. type hints, auto-discovery).
:param field: accepts a ``Serializer``, :class:`~.types.OpenApiTypes` or raw ``dict``
:param component_name: signals that the field should be broken out as separate component
"""
def decorator(f):
set_override(f, 'field', field)
set_override(f, 'field_component_name', component_name)
return f
return decorator
def extend_schema_serializer(
many: Optional[bool] = None,
exclude_fields: Optional[Sequence[str]] = None,
deprecate_fields: Optional[Sequence[str]] = None,
examples: Optional[Sequence[OpenApiExample]] = None,
extensions: Optional[Dict[str, Any]] = None,
component_name: Optional[str] = None,
) -> Callable[[F], F]:
"""
Decorator for the "serializer" kind. Intended for overriding default serializer behaviour that
cannot be influenced through :func:`@extend_schema <.extend_schema>`.
:param many: override how serializer is initialized. Mainly used to coerce the list view detection
heuristic to acknowledge a non-list serializer.
:param exclude_fields: fields to ignore while processing the serializer. only affects the
schema. fields will still be exposed through the API.
:param deprecate_fields: fields to mark as deprecated while processing the serializer.
:param examples: define example data to serializer.
:param extensions: specification extensions, e.g. ``x-is-dynamic``, etc.
:param component_name: override default class name extraction.
"""
def decorator(klass):
if many is not None:
set_override(klass, 'many', many)
if exclude_fields:
set_override(klass, 'exclude_fields', exclude_fields)
if deprecate_fields:
set_override(klass, 'deprecate_fields', deprecate_fields)
if examples:
set_override(klass, 'examples', examples)
if extensions:
set_override(klass, 'extensions', extensions)
if component_name:
set_override(klass, 'component_name', component_name)
return klass
return decorator
def extend_schema_view(**kwargs) -> Callable[[F], F]:
"""
Convenience decorator for the "view" kind. Intended for annotating derived view methods that
are are not directly present in the view (usually methods like ``list`` or ``retrieve``).
Spares you from overriding methods like ``list``, only to perform a super call in the body
so that you have have something to attach :func:`@extend_schema <.extend_schema>` to.
This decorator also takes care of safely attaching annotations to derived view methods,
preventing leakage into unrelated views.
This decorator also supports custom DRF ``@action`` with the method name as the key.
:param kwargs: method names as argument names and :func:`@extend_schema <.extend_schema>`
calls as values
"""
def decorator(view):
# special case for @api_view. redirect decoration to enclosed WrappedAPIView
if callable(view) and hasattr(view, 'cls'):
extend_schema_view(**kwargs)(view.cls)
return view
available_view_methods = get_view_method_names(view)
for method_name, method_decorator in kwargs.items():
if method_name not in available_view_methods:
warn(
f'@extend_schema_view argument "{method_name}" was not found on view '
f'{view.__name__}. method override for "{method_name}" will be ignored.',
delayed=view
)
continue
# the context of derived methods must not be altered, as it belongs to the
# other view. create a new context so the schema can be safely stored in the
# wrapped_method. view methods that are not derived can be safely altered.
if hasattr(method_decorator, '__iter__'):
for sub_method_decorator in method_decorator:
sub_method_decorator(isolate_view_method(view, method_name))
else:
method_decorator(isolate_view_method(view, method_name))
return view
return decorator
def inline_serializer(name: str, fields: Dict[str, Field], **kwargs) -> Serializer:
"""
A helper function to create an inline serializer. Primary use is with
:func:`@extend_schema <.extend_schema>`, where one needs an implicit one-off
serializer that is not reflected in an actual class.
:param name: name of the
:param fields: dict with field names as keys and serializer fields as values
:param kwargs: optional kwargs for serializer initialization
"""
serializer_class = type(name, (Serializer,), fields)
return serializer_class(**kwargs)

View File

@@ -0,0 +1,34 @@
import json
import os
import jsonschema
def validate_schema(api_schema):
"""
Validate generated API schema against OpenAPI 3.0.X json schema specification.
Note: On conflict, the written specification always wins over the json schema.
OpenApi3 schema specification taken from:
https://github.com/OAI/OpenAPI-Specification/blob/master/schemas/v3.0/schema.json
https://github.com/OAI/OpenAPI-Specification/blob/9dff244e5708fbe16e768738f4f17cf3fddf4066/schemas/v3.0/schema.json
https://github.com/OAI/OpenAPI-Specification/blob/main/schemas/v3.1/schema.json
https://github.com/OAI/OpenAPI-Specification/blob/9dff244e5708fbe16e768738f4f17cf3fddf4066/schemas/v3.1/schema.json
"""
if api_schema['openapi'].startswith("3.0"):
schema_spec_path = os.path.join(os.path.dirname(__file__), 'openapi_3_0_schema.json')
elif api_schema['openapi'].startswith("3.1"):
schema_spec_path = os.path.join(os.path.dirname(__file__), 'openapi_3_1_schema.json')
else:
raise RuntimeError('No validation specification available') # pragma: no cover
with open(schema_spec_path) as fh:
openapi3_schema_spec = json.load(fh)
# coerce any remnants of objects to basic types
from drf_spectacular.renderers import OpenApiJsonRenderer
api_schema = json.loads(OpenApiJsonRenderer().render(api_schema))
jsonschema.validate(instance=api_schema, schema=openapi3_schema_spec)

View File

@@ -0,0 +1,288 @@
import json
from collections import namedtuple
from importlib import import_module
from typing import Any, Dict, List, Optional, Type
from django.conf import settings
from django.templatetags.static import static
from django.utils import translation
from django.utils.translation import gettext_lazy as _
from django.views.generic import RedirectView
from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.settings import api_settings
from rest_framework.views import APIView
from drf_spectacular.generators import SchemaGenerator
from drf_spectacular.plumbing import get_relative_url, set_query_parameters
from drf_spectacular.renderers import (
OpenApiJsonRenderer, OpenApiJsonRenderer2, OpenApiYamlRenderer, OpenApiYamlRenderer2,
)
from drf_spectacular.settings import patched_settings, spectacular_settings
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
if spectacular_settings.SERVE_INCLUDE_SCHEMA:
SCHEMA_KWARGS: Dict[str, Any] = {'responses': {200: OpenApiTypes.OBJECT}}
if settings.USE_I18N:
SCHEMA_KWARGS['parameters'] = [
OpenApiParameter(
'lang', str, OpenApiParameter.QUERY, enum=list(dict(settings.LANGUAGES).keys())
)
]
else:
SCHEMA_KWARGS = {'exclude': True}
if spectacular_settings.SERVE_AUTHENTICATION is not None:
AUTHENTICATION_CLASSES = spectacular_settings.SERVE_AUTHENTICATION
else:
AUTHENTICATION_CLASSES = api_settings.DEFAULT_AUTHENTICATION_CLASSES
class SpectacularAPIView(APIView):
__doc__ = _("""
OpenApi3 schema for this API. Format can be selected via content negotiation.
- YAML: application/vnd.oai.openapi
- JSON: application/vnd.oai.openapi+json
""") # type: ignore
renderer_classes = [
OpenApiYamlRenderer, OpenApiYamlRenderer2, OpenApiJsonRenderer, OpenApiJsonRenderer2
]
permission_classes = spectacular_settings.SERVE_PERMISSIONS
authentication_classes = AUTHENTICATION_CLASSES
generator_class: Type[SchemaGenerator] = spectacular_settings.DEFAULT_GENERATOR_CLASS
serve_public: bool = spectacular_settings.SERVE_PUBLIC
urlconf: Optional[str] = spectacular_settings.SERVE_URLCONF
api_version: Optional[str] = None
custom_settings: Optional[Dict[str, Any]] = None
patterns: Optional[List[Any]] = None
@extend_schema(**SCHEMA_KWARGS)
def get(self, request, *args, **kwargs):
# special handling of custom urlconf parameter
if isinstance(self.urlconf, list) or isinstance(self.urlconf, tuple):
ModuleWrapper = namedtuple('ModuleWrapper', ['urlpatterns'])
if all(isinstance(i, str) for i in self.urlconf):
# list of import string for urlconf
patterns = []
for item in self.urlconf:
url = import_module(item)
patterns += url.urlpatterns
self.urlconf = ModuleWrapper(tuple(patterns))
else:
# explicitly resolved urlconf
self.urlconf = ModuleWrapper(tuple(self.urlconf))
with patched_settings(self.custom_settings):
if settings.USE_I18N and request.GET.get('lang'):
with translation.override(request.GET.get('lang')):
return self._get_schema_response(request)
else:
return self._get_schema_response(request)
def _get_schema_response(self, request):
# version specified as parameter to the view always takes precedence. after
# that we try to source version through the schema view's own versioning_class.
version = self.api_version or request.version or self._get_version_parameter(request)
generator = self.generator_class(urlconf=self.urlconf, api_version=version, patterns=self.patterns)
return Response(
data=generator.get_schema(request=request, public=self.serve_public),
headers={"Content-Disposition": f'inline; filename="{self._get_filename(request, version)}"'}
)
def _get_filename(self, request, version):
return "{title}{version}.{suffix}".format(
title=spectacular_settings.TITLE or 'schema',
version=f' ({version})' if version else '',
suffix=self.perform_content_negotiation(request, force=True)[0].format
)
def _get_version_parameter(self, request):
version = request.GET.get('version')
if not api_settings.ALLOWED_VERSIONS or version in api_settings.ALLOWED_VERSIONS:
return version
return None
class SpectacularYAMLAPIView(SpectacularAPIView):
renderer_classes = [OpenApiYamlRenderer, OpenApiYamlRenderer2]
class SpectacularJSONAPIView(SpectacularAPIView):
renderer_classes = [OpenApiJsonRenderer, OpenApiJsonRenderer2]
def _get_sidecar_url(filepath):
return static(f'drf_spectacular_sidecar/{filepath}')
class SpectacularSwaggerView(APIView):
renderer_classes = [TemplateHTMLRenderer]
permission_classes = spectacular_settings.SERVE_PERMISSIONS
authentication_classes = AUTHENTICATION_CLASSES
url_name: str = 'schema'
url: Optional[str] = None
template_name: str = 'drf_spectacular/swagger_ui.html'
template_name_js: str = 'drf_spectacular/swagger_ui.js'
title: str = spectacular_settings.TITLE
@extend_schema(exclude=True)
def get(self, request, *args, **kwargs):
return Response(
data={
'title': self.title,
'swagger_ui_css': self._swagger_ui_resource('swagger-ui.css'),
'swagger_ui_bundle': self._swagger_ui_resource('swagger-ui-bundle.js'),
'swagger_ui_standalone': self._swagger_ui_resource('swagger-ui-standalone-preset.js'),
'favicon_href': self._swagger_ui_favicon(),
'schema_url': self._get_schema_url(request),
'settings': self._dump(spectacular_settings.SWAGGER_UI_SETTINGS),
'oauth2_config': self._dump(spectacular_settings.SWAGGER_UI_OAUTH2_CONFIG),
'template_name_js': self.template_name_js,
'script_url': None,
'csrf_header_name': self._get_csrf_header_name(),
'schema_auth_names': self._dump(self._get_schema_auth_names()),
},
template_name=self.template_name,
headers={
"Cross-Origin-Opener-Policy": "unsafe-none",
}
)
def _dump(self, data):
return data if isinstance(data, str) else json.dumps(data, indent=2)
def _get_schema_url(self, request):
schema_url = self.url or get_relative_url(reverse(self.url_name, request=request))
return set_query_parameters(
url=schema_url,
lang=request.GET.get('lang'),
version=request.GET.get('version')
)
def _get_csrf_header_name(self):
csrf_header_name = settings.CSRF_HEADER_NAME
if csrf_header_name.startswith('HTTP_'):
csrf_header_name = csrf_header_name[5:]
return csrf_header_name.replace('_', '-')
def _get_schema_auth_names(self):
from drf_spectacular.extensions import OpenApiAuthenticationExtension
if spectacular_settings.SERVE_PUBLIC:
return []
auth_extensions = [
OpenApiAuthenticationExtension.get_match(klass)
for klass in self.authentication_classes
]
return [auth.name for auth in auth_extensions if auth]
@staticmethod
def _swagger_ui_resource(filename):
if spectacular_settings.SWAGGER_UI_DIST == 'SIDECAR':
return _get_sidecar_url(f'swagger-ui-dist/{filename}')
return f'{spectacular_settings.SWAGGER_UI_DIST}/{filename}'
@staticmethod
def _swagger_ui_favicon():
if spectacular_settings.SWAGGER_UI_FAVICON_HREF == 'SIDECAR':
return _get_sidecar_url('swagger-ui-dist/favicon-32x32.png')
return spectacular_settings.SWAGGER_UI_FAVICON_HREF
class SpectacularSwaggerSplitView(SpectacularSwaggerView):
"""
Alternate Swagger UI implementation that separates the html request from the
javascript request to cater to web servers with stricter CSP policies.
"""
url_self: Optional[str] = None
@extend_schema(exclude=True)
def get(self, request, *args, **kwargs):
if request.GET.get('script') is not None:
return Response(
data={
'schema_url': self._get_schema_url(request),
'settings': self._dump(spectacular_settings.SWAGGER_UI_SETTINGS),
'oauth2_config': self._dump(spectacular_settings.SWAGGER_UI_OAUTH2_CONFIG),
'csrf_header_name': self._get_csrf_header_name(),
'schema_auth_names': self._dump(self._get_schema_auth_names()),
},
template_name=self.template_name_js,
content_type='application/javascript',
)
else:
script_url = self.url_self or request.get_full_path()
return Response(
data={
'title': self.title,
'swagger_ui_css': self._swagger_ui_resource('swagger-ui.css'),
'swagger_ui_bundle': self._swagger_ui_resource('swagger-ui-bundle.js'),
'swagger_ui_standalone': self._swagger_ui_resource('swagger-ui-standalone-preset.js'),
'favicon_href': self._swagger_ui_favicon(),
'script_url': set_query_parameters(
url=script_url,
lang=request.GET.get('lang'),
script='' # signal to deliver init script
)
},
template_name=self.template_name,
)
class SpectacularRedocView(APIView):
renderer_classes = [TemplateHTMLRenderer]
permission_classes = spectacular_settings.SERVE_PERMISSIONS
authentication_classes = AUTHENTICATION_CLASSES
url_name: str = 'schema'
url: Optional[str] = None
template_name: str = 'drf_spectacular/redoc.html'
title: Optional[str] = spectacular_settings.TITLE
@extend_schema(exclude=True)
def get(self, request, *args, **kwargs):
return Response(
data={
'title': self.title,
'redoc_standalone': self._redoc_standalone(),
'schema_url': self._get_schema_url(request),
'settings': self._dump(spectacular_settings.REDOC_UI_SETTINGS),
},
template_name=self.template_name
)
def _dump(self, data):
if not data:
return None
elif isinstance(data, str):
return data
else:
return json.dumps(data, indent=2)
@staticmethod
def _redoc_standalone():
if spectacular_settings.REDOC_DIST == 'SIDECAR':
return _get_sidecar_url('redoc/bundles/redoc.standalone.js')
return f'{spectacular_settings.REDOC_DIST}/bundles/redoc.standalone.js'
def _get_schema_url(self, request):
schema_url = self.url or get_relative_url(reverse(self.url_name, request=request))
return set_query_parameters(
url=schema_url,
lang=request.GET.get('lang'),
version=request.GET.get('version')
)
class SpectacularSwaggerOauthRedirectView(RedirectView):
"""
A view that serves the SwaggerUI oauth2-redirect.html file so that SwaggerUI can authenticate itself using Oauth2
This view should be served as ``./oauth2-redirect.html`` relative to the SwaggerUI itself.
If that is not possible, this views absolute url can also be set via the
``SPECTACULAR_SETTINGS.SWAGGER_UI_SETTINGS.oauth2RedirectUrl`` django settings.
"""
def get_redirect_url(self, *args, **kwargs):
return _get_sidecar_url("swagger-ui-dist/oauth2-redirect.html") + "?" + self.request.GET.urlencode()