Files
Iliyan Angelov 6b247e5b9f Updates
2025-09-19 11:58:53 +03:00

294 lines
13 KiB
Python

import os
import posixpath
import re

from django.urls import URLPattern, URLResolver
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.settings import api_settings

from drf_spectacular.drainage import (
    add_trace_message, error, get_override, reset_generator_stats, warn,
)
from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
    ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, get_class,
    is_versioning_supported, modify_for_versioning, normalize_result_object,
    operation_matches_version, process_webhooks, sanitize_result_object,
)
from drf_spectacular.settings import spectacular_settings


class EndpointEnumerator(BaseEndpointEnumerator):
    """
    DRF endpoint enumerator extended for drf-spectacular.

    Compared to DRF, every endpoint additionally carries its raw ``path_regex``.
    The endpoint list is passed through ``PREPROCESSING_HOOKS``, de-duplicated
    per (path, method) and ordered according to ``SORT_OPERATIONS``.
    """

    def get_api_endpoints(self, patterns=None, prefix=''):
        """
        Return a list of ``(path, path_regex, method, callback)`` tuples.

        Each entry of ``PREPROCESSING_HOOKS`` is called as ``hook(endpoints=...)``
        and its return value replaces the endpoint list. Only the first endpoint
        for any given ``(path, method)`` pair is kept. The result is sorted with
        ``SORT_OPERATIONS`` when it is callable, with ``alpha_operation_sorter``
        when it is otherwise truthy, and returned in discovery order when falsy.
        """
        api_endpoints = self._get_api_endpoints(patterns, prefix)

        for hook in spectacular_settings.PREPROCESSING_HOOKS:
            api_endpoints = hook(endpoints=api_endpoints)

        # keep only the first occurrence of every (path, method) combination
        api_endpoints_deduplicated = {}
        for path, path_regex, method, callback in api_endpoints:
            if (path, method) not in api_endpoints_deduplicated:
                api_endpoints_deduplicated[path, method] = (path, path_regex, method, callback)
        api_endpoints = list(api_endpoints_deduplicated.values())

        if callable(spectacular_settings.SORT_OPERATIONS):
            return sorted(api_endpoints, key=spectacular_settings.SORT_OPERATIONS)
        elif spectacular_settings.SORT_OPERATIONS:
            return sorted(api_endpoints, key=alpha_operation_sorter)
        else:
            return api_endpoints

    def get_path_from_regex(self, path_regex):
        """
        Convert a URL regex into a path via DRF, then unescape literal dots.

        DRF's regex stripping leaves escaped dots (a backslash followed by ``.``)
        in the resulting path; they are turned back into plain ``.`` here.
        """
        path = super().get_path_from_regex(path_regex)
        # bugfix oversight in DRF regex stripping
        path = path.replace('\\.', '.')
        return path

    def _get_api_endpoints(self, patterns, prefix):
        """
        Return a list of all available API endpoints by inspecting the URL conf.

        The only difference to the DRF version is that ``path_regex`` is passed
        through as part of every endpoint tuple. ``URLResolver`` entries are
        descended into recursively, with their regex accumulated as prefix.
        """
        if patterns is None:
            patterns = self.patterns

        api_endpoints = []

        for pattern in patterns:
            path_regex = prefix + str(pattern.pattern)
            if isinstance(pattern, URLPattern):
                path = self.get_path_from_regex(path_regex)
                callback = pattern.callback
                if self.should_include_endpoint(path, callback):
                    for method in self.get_allowed_methods(callback):
                        endpoint = (path, path_regex, method, callback)
                        api_endpoints.append(endpoint)

            elif isinstance(pattern, URLResolver):
                nested_endpoints = self._get_api_endpoints(
                    patterns=pattern.url_patterns,
                    prefix=path_regex
                )
                api_endpoints.extend(nested_endpoints)

        return api_endpoints

    def get_allowed_methods(self, callback):
        """
        Return the upper-cased HTTP methods that should be documented for ``callback``.

        For callbacks with ``actions`` (viewsets), these are the bound action
        methods that also appear in ``http_method_names``. That list comes from
        the callback's ``initkwargs`` if present there, else from the view class.
        For other views, the view class is instantiated (forwarding
        ``http_method_names`` from ``initkwargs`` if present) and its
        ``allowed_methods`` are used. OPTIONS, HEAD, TRACE and CONNECT are always
        excluded.
        """
        if hasattr(callback, 'actions'):
            if 'http_method_names' in callback.initkwargs:
                http_method_names = set(callback.initkwargs['http_method_names'])
            else:
                http_method_names = set(callback.cls.http_method_names)

            # Iterate the actions in their own order instead of over a set
            # intersection. The iteration order of a set of strings depends on
            # hash randomization, which made the method order (and thus the
            # schema when SORT_OPERATIONS is disabled) vary between runs.
            # dict.fromkeys() keeps the de-duplication the set used to provide.
            methods = [
                method.upper() for method in dict.fromkeys(callback.actions)
                if method in http_method_names
            ]
        else:
            # pass to constructor allowed method names to get valid ones
            kwargs = {}
            if 'http_method_names' in callback.initkwargs:
                kwargs['http_method_names'] = callback.initkwargs['http_method_names']

            methods = callback.cls(**kwargs).allowed_methods

        return [
            method for method in methods
            if method not in ('OPTIONS', 'HEAD', 'TRACE', 'CONNECT')
        ]


class SchemaGenerator(BaseSchemaGenerator):
    """
    drf-spectacular schema generator built on DRF's ``BaseSchemaGenerator``.

    Components encountered while generating operations are collected in
    ``self.registry``. An optional ``api_version`` keyword restricts generation
    of versioned views to that version.
    """
    endpoint_inspector_cls = EndpointEnumerator

    def __init__(self, *args, **kwargs):
        self.registry = ComponentRegistry()
        # explicit version; parse() falls back to the versioning class default
        self.api_version = kwargs.pop('api_version', None)
        self.inspector = None
        super().__init__(*args, **kwargs)

    def coerce_path(self, path, method, view):
        """
        Customized coerce_path which also considers the `_pk` suffix in URL paths
        of nested routers.

        When ``SCHEMA_COERCE_PATH_PK_SUFFIX`` is enabled, ``{<name>_pk}``
        parameters are renamed to ``{<name>_id}``.
        """
        path = super().coerce_path(path, method, view)  # take care of {pk}
        if spectacular_settings.SCHEMA_COERCE_PATH_PK_SUFFIX:
            path = re.sub(pattern=r'{(\w+)_pk}', repl=r'{\1_id}', string=path)
        return path

    def create_view(self, callback, method, request=None):
        """
        Customized create_view which is called when all routes are traversed.

        Part of this is instantiating views with default params. For custom routes
        (@action) the custom AutoSchema is injected properly through 'initkwargs'
        on the view. However, when decorating plain views like retrieve, this
        initialization logic does not run. Therefore the schema is forcefully set
        here if the @extend_schema decorator was used.

        ``request`` is ignored; the mock request is attached later in parse().
        """
        override_view = OpenApiViewExtension.get_match(callback.cls)
        if override_view:
            original_cls = callback.cls
            callback.cls = override_view.view_replacement()

        # callback.cls is hosted in urlpatterns and is therefore not an ephemeral
        # modification. Restore it after view creation -- also when view creation
        # raises -- so potential revisits have a clean state as basis.
        try:
            # we refrain from passing request and deal with it ourselves in parse()
            view = super().create_view(callback, method, None)
        finally:
            if override_view:
                callback.cls = original_cls

        # drf-yasg compatibility feature. makes the view aware that we are running
        # schema generation and not a real request.
        view.swagger_fake_view = True

        if isinstance(view, viewsets.ViewSetMixin):
            action = getattr(view, view.action)
        elif isinstance(view, views.APIView):
            action = getattr(view, method.lower())
        else:
            error(
                'Using not supported View class. Class must be derived from APIView '
                'or any of its subclasses like GenericApiView, GenericViewSet.'
            )
            return view

        action_schema = getattr(action, 'kwargs', {}).get('schema', None)
        if not action_schema:
            # there is no method/action customized schema so we are done here.
            return view

        # action_schema is either a class or instance. when @extend_schema is used, it
        # is always a class to prevent the weakref reverse "schema.view" bug for multi
        # annotations. The bug is prevented by delaying the instantiation of the schema
        # class until create_view (here) and not doing it immediately in @extend_schema.
        action_schema_class = get_class(action_schema)
        view_schema_class = get_class(callback.cls.schema)

        if not issubclass(action_schema_class, view_schema_class):
            # this handles the case of having a manually set custom AutoSchema on the
            # view together with extend_schema. In most cases, the decorator mechanics
            # prevent extend_schema from having access to the view's schema class. So
            # extend_schema is forced to use DEFAULT_SCHEMA_CLASS as fallback base class
            # instead of the correct base class set in view. We remedy this chicken-egg
            # problem here by rearranging the class hierarchy.
            mro = tuple(
                cls for cls in action_schema_class.__mro__
                if cls not in api_settings.DEFAULT_SCHEMA_CLASS.__mro__
            ) + view_schema_class.__mro__
            action_schema_class = type('ExtendedRearrangedSchema', mro, {})

        view.schema = action_schema_class()
        return view

    def _initialise_endpoints(self):
        """Enumerate endpoints once; later calls reuse ``self.endpoints``."""
        if self.endpoints is None:
            self.inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
            self.endpoints = self.inspector.get_api_endpoints()

    def _get_paths_and_endpoints(self):
        """
        Generate (path, path_regex, method, view) given (path, path_regex, method, callback).
        """
        view_endpoints = []
        for path, path_regex, method, callback in self.endpoints:
            view = self.create_view(callback, method)
            path = self.coerce_path(path, method, view)
            view_endpoints.append((path, path_regex, method, view))

        return view_endpoints

    def parse(self, input_request, public):
        """
        Iterate endpoints generating per method path operations.

        Returns a dict mapping the final path to ``{lower-case method: operation}``.
        Endpoints are skipped when the request lacks permission (unless
        ``public``), when they do not match the selected API version, or when
        their operation was removed via @extend_schema.
        """
        result = {}
        self._initialise_endpoints()
        endpoints = self._get_paths_and_endpoints()

        if spectacular_settings.SCHEMA_PATH_PREFIX is None:
            # estimate common path prefix if none was given. only use it if we encountered more
            # than one view to prevent emission of erroneous and unnecessary fallback names.
            non_trivial_prefix = len({view.__class__ for _, _, _, view in endpoints}) > 1
            if non_trivial_prefix:
                # URL paths are always '/'-separated. os.path would be ntpath on Windows
                # and join the result with '\', so the prefix would never match.
                path_prefix = posixpath.commonpath([path for path, _, _, _ in endpoints])
                path_prefix = re.escape(path_prefix)  # guard for RE special chars in path
            else:
                path_prefix = '/'
        else:
            path_prefix = spectacular_settings.SCHEMA_PATH_PREFIX
        if not path_prefix.startswith('^'):
            path_prefix = '^' + path_prefix  # make sure regex only matches from the start

        for path, path_regex, method, view in endpoints:
            # emit queued up warnings/error that happened prior to generation (decoration)
            for w in get_override(view, 'warnings', []):
                warn(w)
            for e in get_override(view, 'errors', []):
                error(e)

            view.request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request)

            if not (public or self.has_view_permissions(path, method, view)):
                continue

            if view.versioning_class and not is_versioning_supported(view.versioning_class):
                warn(
                    f'using unsupported versioning class "{view.versioning_class}". view will be '
                    f'processed as unversioned view.'
                )
            elif view.versioning_class:
                version = (
                    self.api_version  # explicit version from CLI, SpecView or SpecView request
                    or view.versioning_class.default_version  # fallback
                )
                if not version:
                    continue
                path = modify_for_versioning(self.inspector.patterns, method, path, view, version)
                if not operation_matches_version(view, version):
                    continue

            assert isinstance(view.schema, AutoSchema), (
                f'Incompatible AutoSchema used on View {view.__class__}. Is DRF\'s '
                f'DEFAULT_SCHEMA_CLASS pointing to "drf_spectacular.openapi.AutoSchema" '
                f'or any other drf-spectacular compatible AutoSchema?'
            )
            with add_trace_message(getattr(view, '__class__', view)):
                operation = view.schema.get_operation(
                    path, path_regex, path_prefix, method, self.registry
                )

            # operation was manually removed via @extend_schema
            if not operation:
                continue

            if spectacular_settings.SCHEMA_PATH_PREFIX_TRIM:
                path = re.sub(pattern=path_prefix, repl='', string=path, flags=re.IGNORECASE)

            if spectacular_settings.SCHEMA_PATH_PREFIX_INSERT:
                path = spectacular_settings.SCHEMA_PATH_PREFIX_INSERT + path

            if not path.startswith('/'):
                path = '/' + path

            if spectacular_settings.CAMELIZE_NAMES:
                path, operation = camelize_operation(path, operation)

            result.setdefault(path, {})
            result[path][method.lower()] = operation

        return result

    def get_schema(self, request=None, public=False):
        """
        Generate a OpenAPI schema.

        Builds the root object from the parsed paths, the registry components and
        the configured webhooks. It then applies every ``POSTPROCESSING_HOOKS``
        entry and returns the normalized, sanitized result.
        """
        reset_generator_stats()
        result = build_root_object(
            paths=self.parse(request, public),
            components=self.registry.build(spectacular_settings.APPEND_COMPONENTS),
            webhooks=process_webhooks(spectacular_settings.WEBHOOKS, self.registry),
            version=self.api_version or getattr(request, 'version', None),
        )
        for hook in spectacular_settings.POSTPROCESSING_HOOKS:
            result = hook(result=result, generator=self, request=request, public=public)

        return sanitize_result_object(normalize_result_object(result))