294 lines
13 KiB
Python
294 lines
13 KiB
Python
import os
|
|
import re
|
|
|
|
from django.urls import URLPattern, URLResolver
|
|
from rest_framework import views, viewsets
|
|
from rest_framework.schemas.generators import BaseSchemaGenerator
|
|
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
|
|
from rest_framework.settings import api_settings
|
|
|
|
from drf_spectacular.drainage import (
|
|
add_trace_message, error, get_override, reset_generator_stats, warn,
|
|
)
|
|
from drf_spectacular.extensions import OpenApiViewExtension
|
|
from drf_spectacular.openapi import AutoSchema
|
|
from drf_spectacular.plumbing import (
|
|
ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, get_class,
|
|
is_versioning_supported, modify_for_versioning, normalize_result_object,
|
|
operation_matches_version, process_webhooks, sanitize_result_object,
|
|
)
|
|
from drf_spectacular.settings import spectacular_settings
|
|
|
|
|
|
class EndpointEnumerator(BaseEndpointEnumerator):
|
|
def get_api_endpoints(self, patterns=None, prefix=''):
|
|
api_endpoints = self._get_api_endpoints(patterns, prefix)
|
|
|
|
for hook in spectacular_settings.PREPROCESSING_HOOKS:
|
|
api_endpoints = hook(endpoints=api_endpoints)
|
|
|
|
api_endpoints_deduplicated = {}
|
|
for path, path_regex, method, callback in api_endpoints:
|
|
if (path, method) not in api_endpoints_deduplicated:
|
|
api_endpoints_deduplicated[path, method] = (path, path_regex, method, callback)
|
|
|
|
api_endpoints = list(api_endpoints_deduplicated.values())
|
|
|
|
if callable(spectacular_settings.SORT_OPERATIONS):
|
|
return sorted(api_endpoints, key=spectacular_settings.SORT_OPERATIONS)
|
|
elif spectacular_settings.SORT_OPERATIONS:
|
|
return sorted(api_endpoints, key=alpha_operation_sorter)
|
|
else:
|
|
return api_endpoints
|
|
|
|
def get_path_from_regex(self, path_regex):
|
|
path = super().get_path_from_regex(path_regex)
|
|
# bugfix oversight in DRF regex stripping
|
|
path = path.replace('\\.', '.')
|
|
return path
|
|
|
|
def _get_api_endpoints(self, patterns, prefix):
|
|
"""
|
|
Return a list of all available API endpoints by inspecting the URL conf.
|
|
Only modification the DRF version is passing through the path_regex.
|
|
"""
|
|
if patterns is None:
|
|
patterns = self.patterns
|
|
|
|
api_endpoints = []
|
|
|
|
for pattern in patterns:
|
|
path_regex = prefix + str(pattern.pattern)
|
|
if isinstance(pattern, URLPattern):
|
|
path = self.get_path_from_regex(path_regex)
|
|
callback = pattern.callback
|
|
if self.should_include_endpoint(path, callback):
|
|
for method in self.get_allowed_methods(callback):
|
|
endpoint = (path, path_regex, method, callback)
|
|
api_endpoints.append(endpoint)
|
|
|
|
elif isinstance(pattern, URLResolver):
|
|
nested_endpoints = self._get_api_endpoints(
|
|
patterns=pattern.url_patterns,
|
|
prefix=path_regex
|
|
)
|
|
api_endpoints.extend(nested_endpoints)
|
|
|
|
return api_endpoints
|
|
|
|
def get_allowed_methods(self, callback):
|
|
if hasattr(callback, 'actions'):
|
|
actions = set(callback.actions)
|
|
if 'http_method_names' in callback.initkwargs:
|
|
http_method_names = set(callback.initkwargs['http_method_names'])
|
|
else:
|
|
http_method_names = set(callback.cls.http_method_names)
|
|
|
|
methods = [method.upper() for method in actions & http_method_names]
|
|
else:
|
|
# pass to constructor allowed method names to get valid ones
|
|
kwargs = {}
|
|
if 'http_method_names' in callback.initkwargs:
|
|
kwargs['http_method_names'] = callback.initkwargs['http_method_names']
|
|
|
|
methods = callback.cls(**kwargs).allowed_methods
|
|
|
|
return [
|
|
method for method in methods
|
|
if method not in ('OPTIONS', 'HEAD', 'TRACE', 'CONNECT')
|
|
]
|
|
|
|
|
|
class SchemaGenerator(BaseSchemaGenerator):
|
|
endpoint_inspector_cls = EndpointEnumerator
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
self.registry = ComponentRegistry()
|
|
self.api_version = kwargs.pop('api_version', None)
|
|
self.inspector = None
|
|
super().__init__(*args, **kwargs)
|
|
|
|
def coerce_path(self, path, method, view):
|
|
"""
|
|
Customized coerce_path which also considers the `_pk` suffix in URL paths
|
|
of nested routers.
|
|
"""
|
|
path = super().coerce_path(path, method, view) # take care of {pk}
|
|
if spectacular_settings.SCHEMA_COERCE_PATH_PK_SUFFIX:
|
|
path = re.sub(pattern=r'{(\w+)_pk}', repl=r'{\1_id}', string=path)
|
|
return path
|
|
|
|
def create_view(self, callback, method, request=None):
|
|
"""
|
|
customized create_view which is called when all routes are traversed. part of this
|
|
is instantiating views with default params. in case of custom routes (@action) the
|
|
custom AutoSchema is injected properly through 'initkwargs' on view. However, when
|
|
decorating plain views like retrieve, this initialization logic is not running.
|
|
Therefore forcefully set the schema if @extend_schema decorator was used.
|
|
"""
|
|
override_view = OpenApiViewExtension.get_match(callback.cls)
|
|
if override_view:
|
|
original_cls = callback.cls
|
|
callback.cls = override_view.view_replacement()
|
|
|
|
# we refrain from passing request and deal with it ourselves in parse()
|
|
view = super().create_view(callback, method, None)
|
|
|
|
# drf-yasg compatibility feature. makes the view aware that we are running
|
|
# schema generation and not a real request.
|
|
view.swagger_fake_view = True
|
|
|
|
# callback.cls is hosted in urlpatterns and is therefore not an ephemeral modification.
|
|
# restore after view creation so potential revisits have a clean state as basis.
|
|
if override_view:
|
|
callback.cls = original_cls
|
|
|
|
if isinstance(view, viewsets.ViewSetMixin):
|
|
action = getattr(view, view.action)
|
|
elif isinstance(view, views.APIView):
|
|
action = getattr(view, method.lower())
|
|
else:
|
|
error(
|
|
'Using not supported View class. Class must be derived from APIView '
|
|
'or any of its subclasses like GenericApiView, GenericViewSet.'
|
|
)
|
|
return view
|
|
|
|
action_schema = getattr(action, 'kwargs', {}).get('schema', None)
|
|
if not action_schema:
|
|
# there is no method/action customized schema so we are done here.
|
|
return view
|
|
|
|
# action_schema is either a class or instance. when @extend_schema is used, it
|
|
# is always a class to prevent the weakref reverse "schema.view" bug for multi
|
|
# annotations. The bug is prevented by delaying the instantiation of the schema
|
|
# class until create_view (here) and not doing it immediately in @extend_schema.
|
|
action_schema_class = get_class(action_schema)
|
|
view_schema_class = get_class(callback.cls.schema)
|
|
|
|
if not issubclass(action_schema_class, view_schema_class):
|
|
# this handles the case of having a manually set custom AutoSchema on the
|
|
# view together with extend_schema. In most cases, the decorator mechanics
|
|
# prevent extend_schema from having access to the view's schema class. So
|
|
# extend_schema is forced to use DEFAULT_SCHEMA_CLASS as fallback base class
|
|
# instead of the correct base class set in view. We remedy this chicken-egg
|
|
# problem here by rearranging the class hierarchy.
|
|
mro = tuple(
|
|
cls for cls in action_schema_class.__mro__
|
|
if cls not in api_settings.DEFAULT_SCHEMA_CLASS.__mro__
|
|
) + view_schema_class.__mro__
|
|
action_schema_class = type('ExtendedRearrangedSchema', mro, {})
|
|
|
|
view.schema = action_schema_class()
|
|
return view
|
|
|
|
def _initialise_endpoints(self):
|
|
if self.endpoints is None:
|
|
self.inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
|
|
self.endpoints = self.inspector.get_api_endpoints()
|
|
|
|
def _get_paths_and_endpoints(self):
|
|
"""
|
|
Generate (path, method, view) given (path, method, callback) for paths.
|
|
"""
|
|
view_endpoints = []
|
|
for path, path_regex, method, callback in self.endpoints:
|
|
view = self.create_view(callback, method)
|
|
path = self.coerce_path(path, method, view)
|
|
view_endpoints.append((path, path_regex, method, view))
|
|
|
|
return view_endpoints
|
|
|
|
def parse(self, input_request, public):
|
|
""" Iterate endpoints generating per method path operations. """
|
|
result = {}
|
|
self._initialise_endpoints()
|
|
endpoints = self._get_paths_and_endpoints()
|
|
|
|
if spectacular_settings.SCHEMA_PATH_PREFIX is None:
|
|
# estimate common path prefix if none was given. only use it if we encountered more
|
|
# than one view to prevent emission of erroneous and unnecessary fallback names.
|
|
non_trivial_prefix = len(set([view.__class__ for _, _, _, view in endpoints])) > 1
|
|
if non_trivial_prefix:
|
|
path_prefix = os.path.commonpath([path for path, _, _, _ in endpoints])
|
|
path_prefix = re.escape(path_prefix) # guard for RE special chars in path
|
|
else:
|
|
path_prefix = '/'
|
|
else:
|
|
path_prefix = spectacular_settings.SCHEMA_PATH_PREFIX
|
|
if not path_prefix.startswith('^'):
|
|
path_prefix = '^' + path_prefix # make sure regex only matches from the start
|
|
|
|
for path, path_regex, method, view in endpoints:
|
|
# emit queued up warnings/error that happened prior to generation (decoration)
|
|
for w in get_override(view, 'warnings', []):
|
|
warn(w)
|
|
for e in get_override(view, 'errors', []):
|
|
error(e)
|
|
|
|
view.request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request)
|
|
|
|
if not (public or self.has_view_permissions(path, method, view)):
|
|
continue
|
|
|
|
if view.versioning_class and not is_versioning_supported(view.versioning_class):
|
|
warn(
|
|
f'using unsupported versioning class "{view.versioning_class}". view will be '
|
|
f'processed as unversioned view.'
|
|
)
|
|
elif view.versioning_class:
|
|
version = (
|
|
self.api_version # explicit version from CLI, SpecView or SpecView request
|
|
or view.versioning_class.default_version # fallback
|
|
)
|
|
if not version:
|
|
continue
|
|
path = modify_for_versioning(self.inspector.patterns, method, path, view, version)
|
|
if not operation_matches_version(view, version):
|
|
continue
|
|
|
|
assert isinstance(view.schema, AutoSchema), (
|
|
f'Incompatible AutoSchema used on View {view.__class__}. Is DRF\'s '
|
|
f'DEFAULT_SCHEMA_CLASS pointing to "drf_spectacular.openapi.AutoSchema" '
|
|
f'or any other drf-spectacular compatible AutoSchema?'
|
|
)
|
|
with add_trace_message(getattr(view, '__class__', view)):
|
|
operation = view.schema.get_operation(
|
|
path, path_regex, path_prefix, method, self.registry
|
|
)
|
|
|
|
# operation was manually removed via @extend_schema
|
|
if not operation:
|
|
continue
|
|
|
|
if spectacular_settings.SCHEMA_PATH_PREFIX_TRIM:
|
|
path = re.sub(pattern=path_prefix, repl='', string=path, flags=re.IGNORECASE)
|
|
|
|
if spectacular_settings.SCHEMA_PATH_PREFIX_INSERT:
|
|
path = spectacular_settings.SCHEMA_PATH_PREFIX_INSERT + path
|
|
|
|
if not path.startswith('/'):
|
|
path = '/' + path
|
|
|
|
if spectacular_settings.CAMELIZE_NAMES:
|
|
path, operation = camelize_operation(path, operation)
|
|
|
|
result.setdefault(path, {})
|
|
result[path][method.lower()] = operation
|
|
|
|
return result
|
|
|
|
def get_schema(self, request=None, public=False):
|
|
""" Generate a OpenAPI schema. """
|
|
reset_generator_stats()
|
|
result = build_root_object(
|
|
paths=self.parse(request, public),
|
|
components=self.registry.build(spectacular_settings.APPEND_COMPONENTS),
|
|
webhooks=process_webhooks(spectacular_settings.WEBHOOKS, self.registry),
|
|
version=self.api_version or getattr(request, 'version', None),
|
|
)
|
|
for hook in spectacular_settings.POSTPROCESSING_HOOKS:
|
|
result = hook(result=result, generator=self, request=request, public=public)
|
|
|
|
return sanitize_result_object(normalize_result_object(result))
|