GNXSOFT.COM

This commit is contained in:
Iliyan Angelov
2025-09-26 00:15:37 +03:00
commit fe26b7cca4
16323 changed files with 2011881 additions and 0 deletions

View File

@@ -0,0 +1,58 @@
"""
rest_framework.schemas
schemas:
__init__.py
generators.py # Top-down schema generation
inspectors.py # Per-endpoint view introspection
utils.py # Shared helper functions
views.py # Houses `SchemaView`, `APIView` subclass.
We expose a minimal "public" API directly from `schemas`. This covers the
basic use-cases:
from rest_framework.schemas import (
AutoSchema,
ManualSchema,
get_schema_view,
SchemaGenerator,
)
Other access should target the submodules directly
"""
from rest_framework.settings import api_settings
from . import coreapi, openapi
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
from .inspectors import DefaultSchema # noqa
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
version=None):
"""
Return a schema view.
"""
if generator_class is None:
if coreapi.is_enabled():
generator_class = coreapi.SchemaGenerator
else:
generator_class = openapi.SchemaGenerator
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns, version=version
)
# Avoid import cycle on APIView
from .views import SchemaView
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,
public=public,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)

View File

@@ -0,0 +1,616 @@
import warnings
from collections import Counter, OrderedDict
from urllib import parse
from django.db import models
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
for i, c in enumerate(s1):
if c != s2[i]:
common = s1[:i]
break
return '/' + '/'.join(common)
def is_custom_action(action):
return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
}
def distribute_links(obj):
for key, value in obj.items():
distribute_links(value)
for preferred_key, link in obj.links:
key = obj.get_available_key(preferred_key)
obj[key] = link
INSERT_INTO_COLLISION_FMT = """
Schema Naming Collision.
coreapi.Link for URL path {value_url} cannot be inserted into schema.
Position conflicts with coreapi.Link for URL path {target_url}.
Attempted to insert link with keys: {keys}.
Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()`
to customise schema structure.
"""
class LinkNode(OrderedDict):
def __init__(self):
self.links = []
self.methods_counter = Counter()
super().__init__()
def get_available_key(self, preferred_key):
if preferred_key not in self:
return preferred_key
while True:
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val)
if key not in self:
return key
def insert_into(target, keys, value):
"""
Nested dictionary insertion.
>>> example = {}
>>> insert_into(example, ['a', 'b', 'c'], 123)
>>> example
LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}})))
"""
for key in keys[:-1]:
if key not in target:
target[key] = LinkNode()
target = target[key]
try:
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url,
target_url=target.url,
keys=keys
)
raise ValueError(msg)
class SchemaGenerator(BaseSchemaGenerator):
"""
Original CoreAPI version.
"""
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
super().__init__(title, url, description, patterns, urlconf)
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
def get_links(self, request=None):
"""
Return a dictionary containing all the links that should be
included in the API schema.
"""
links = LinkNode()
paths, view_endpoints = self._get_paths_and_endpoints(request)
# Only generate the path prefix for paths that will be included
if not paths:
return None
prefix = self.determine_path_prefix(paths)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
return links
def get_schema(self, request=None, public=False):
"""
Generate a `coreapi.Document` representing the API schema.
"""
self._initialise_endpoints()
links = self.get_links(None if public else request)
if not links:
return None
url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
)
# Method for generating the link layout....
def get_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
mapped_methods = {
# Don't count head mapping, e.g. not part of the schema
method for method in view.action_map if method != 'head'
}
if len(mapped_methods) > 1:
action = self.default_mapping[method.lower()]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
return named_path_components + [action]
else:
return named_path_components[:-1] + [action]
if action in self.coerce_method_names:
action = self.coerce_method_names[action]
# Default action, eg "/users/", "/users/{pk}/"
return named_path_components + [action]
def determine_path_prefix(self, paths):
"""
Given a list of all paths, return the common prefix which should be
discounted when generating a schema structure.
This will be the longest common string that does not include that last
component of the URL, or the last component before a path parameter.
For example:
/api/v1/users/
/api/v1/users/{pk}/
The path prefix is '/api/v1'
"""
prefixes = []
for path in paths:
components = path.strip('/').split('/')
initial_components = []
for component in components:
if '{' in component:
break
initial_components.append(component)
prefix = '/'.join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
return '/'
prefixes.append('/' + prefix + '/')
return common_path(prefixes)
# View Inspectors #
def field_to_schema(field):
title = force_str(field.label) if field.label else ''
description = force_str(field.help_text) if field.help_text else ''
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.DictField):
return coreschema.Object(
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
items=related_field_schema,
title=title,
description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
return schema_cls(title=title, description=description)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)
class AutoSchema(ViewInspector):
"""
Default inspector for APIView
Responsible for per-view introspection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
* `manual_fields`: list of `coreapi.Field` instances that
will be added to auto-generated fields, overwriting on `Field.name`
"""
super().__init__()
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
def get_link(self, path, method, base_url):
"""
Generate `coreapi.Link` for self.view, path and method.
This is the main _public_ access point.
Parameters:
* path: Route path for view from URLConf.
* method: The HTTP request method.
* base_url: The project "mount point" as given to SchemaGenerator
"""
fields = self.get_path_fields(path, method)
fields += self.get_serializer_fields(path, method)
fields += self.get_pagination_fields(path, method)
fields += self.get_filter_fields(path, method)
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
)
def get_path_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
schema_cls = coreschema.String
kwargs = {}
if model is not None:
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.verbose_name:
title = force_str(model_field.verbose_name)
if model_field is not None and model_field.help_text:
description = force_str(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
)
fields.append(field)
return fields
def get_serializer_fields(self, path, method):
"""
Return a list of `coreapi.Field` instances corresponding to any
request body input, as determined by the serializer class.
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
return []
if not hasattr(view, 'get_serializer'):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
required=True,
schema=coreschema.Array()
)
]
if not isinstance(serializer, serializers.Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(
name=field.field_name,
location='form',
required=required,
schema=field_to_schema(field)
)
fields.append(field)
return fields
def get_pagination_fields(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
if not pagination:
return []
paginator = view.pagination_class()
return paginator.get_schema_fields(view)
def _allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
Override to adjust behaviour for your view.
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_filter_fields(self, path, method):
if not self._allows_filters(path, method):
return []
fields = []
for filter_backend in self.view.filter_backends:
fields += filter_backend().get_schema_fields(self.view)
return fields
def get_manual_fields(self, path, method):
return self._manual_fields
@staticmethod
def update_fields(fields, update_with):
"""
Update list of coreapi.Field instances, overwriting on `Field.name`.
Utility function to handle replacing coreapi.Field fields
from a list by name. Used to handle `manual_fields`.
Parameters:
* `fields`: list of `coreapi.Field` instances to update
* `update_with: list of `coreapi.Field` instances to add or replace.
"""
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
for f in update_with:
by_name[f.name] = f
fields = list(by_name.values())
return fields
def get_encoding(self, path, method):
"""
Return the 'encoding' parameter to use for a given endpoint.
"""
view = self.view
# Core API supports the following request encodings over HTTP...
supported_media_types = {
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
}
parser_classes = getattr(view, 'parser_classes', [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
return None
class ManualSchema(ViewInspector):
"""
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description='', encoding=None):
"""
Parameters:
* `fields`: list of `coreapi.Field` instances.
* `description`: String description for view. Optional.
"""
super().__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
path = path[1:]
return coreapi.Link(
url=parse.urljoin(base_url, path),
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
description=self._description
)
def is_enabled():
"""Is CoreAPI Mode enabled?"""
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)

View File

@@ -0,0 +1,239 @@
"""
generators.py # Top-down schema generation
See schemas.__init__.py for package overview.
"""
import re
from importlib import import_module
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.urls import URLPattern, URLResolver
from rest_framework import exceptions
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import _get_pk
def get_pk_name(model):
meta = model._meta.concrete_model._meta
return _get_pk(meta).name
def is_api_view(callback):
"""
Return `True` if the given view callback is a REST framework view/viewset.
"""
# Avoid import cycle on APIView
from rest_framework.views import APIView
cls = getattr(callback, 'cls', None)
return (cls is not None) and issubclass(cls, APIView)
def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
'GET': 0,
'POST': 1,
'PUT': 2,
'PATCH': 3,
'DELETE': 4
}.get(method, 5)
return (method_priority,)
_PATH_PARAMETER_COMPONENT_RE = re.compile(
r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
)
class EndpointEnumerator:
"""
A class to determine the available API endpoints that a project exposes.
"""
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
# Use the default Django URL conf
urlconf = settings.ROOT_URLCONF
# Load the given URLconf module
if isinstance(urlconf, str):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
self.patterns = patterns
def get_api_endpoints(self, patterns=None, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + str(pattern.pattern)
if isinstance(pattern, URLPattern):
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
endpoint = (path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, URLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
return sorted(api_endpoints, key=endpoint_ordering)
def get_path_from_regex(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
# ???: Would it be feasible to adjust this such that we generate the
# path, plus the kwargs, plus the type from the convertor, such that we
# could feed that straight into the parameter schema object?
path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
return re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g<parameter>}', path)
def should_include_endpoint(self, path, callback):
"""
Return `True` if the given endpoint should be included.
"""
if not is_api_view(callback):
return False # Ignore anything except REST framework views.
if callback.cls.schema is None:
return False
if 'schema' in callback.initkwargs:
if callback.initkwargs['schema'] is None:
return False
if path.endswith('.{format}') or path.endswith('.{format}/'):
return False # Ignore .json style URLs.
return True
def get_allowed_methods(self, callback):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
actions = set(callback.actions)
http_method_names = set(callback.cls.http_method_names)
methods = [method.upper() for method in actions & http_method_names]
else:
methods = callback.cls().allowed_methods
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class BaseSchemaGenerator:
endpoint_inspector_cls = EndpointEnumerator
# 'pk' isn't great as an externally exposed name for an identifier,
# so by default we prefer to use the actual model field name for schemas.
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
if url and not url.endswith('/'):
url += '/'
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
self.patterns = patterns
self.urlconf = urlconf
self.title = title
self.description = description
self.version = version
self.url = url
self.endpoints = None
def _initialise_endpoints(self):
if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = inspector.get_api_endpoints()
def _get_paths_and_endpoints(self, request):
"""
Generate (path, method, view) given (path, method, callback) for paths.
"""
paths = []
view_endpoints = []
for path, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
path = self.coerce_path(path, method, view)
paths.append(path)
view_endpoints.append((path, method, view))
return paths, view_endpoints
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
actions = getattr(callback, 'actions', None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
else:
view.action = actions.get(method.lower())
if request is not None:
view.request = clone_request(request, method)
return view
def coerce_path(self, path, method, view):
"""
Coerce {pk} path arguments into the name of the model field,
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_schema(self, request=None, public=False):
raise NotImplementedError(".get_schema() must be implemented in subclasses.")
def has_view_permissions(self, path, method, view):
"""
Return `True` if the incoming request has the correct view permissions.
"""
if view.request is None:
return True
try:
view.check_permissions(view.request)
except (exceptions.APIException, Http404, PermissionDenied):
return False
return True

View File

@@ -0,0 +1,125 @@
"""
inspectors.py # Per-endpoint view introspection
See schemas.__init__.py for package overview.
"""
import re
from weakref import WeakKeyDictionary
from django.utils.encoding import smart_str
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
class ViewInspector:
"""
Descriptor class on APIView.
Provide subclass for per-view schema generation
"""
# Used in _get_description_section()
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
def __init__(self):
self.instance_schemas = WeakKeyDictionary()
def __get__(self, instance, owner):
"""
Enables `ViewInspector` as a Python _Descriptor_.
This is how `view.schema` knows about `view`.
`__get__` is called when the descriptor is accessed on the owner.
(That will be when view.schema is called in our case.)
`owner` is always the owner class. (An APIView, or subclass for us.)
`instance` is the view instance or `None` if accessed from the class,
rather than an instance.
See: https://docs.python.org/3/howto/descriptor.html for info on
descriptor usage.
"""
if instance in self.instance_schemas:
return self.instance_schemas[instance]
self.view = instance
return self
def __set__(self, instance, other):
self.instance_schemas[instance] = other
if other is not None:
other.view = instance
@property
def view(self):
"""View property."""
assert self._view is not None, (
"Schema generation REQUIRES a view instance. (Hint: you accessed "
"`schema` from the view class rather than an instance.)"
)
return self._view
@view.setter
def view(self, value):
self._view = value
@view.deleter
def view(self):
self._view = None
def get_description(self, path, method):
"""
Determine a path description.
This will be based on the method docstring if one exists,
or else the class docstring.
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_str(method_docstring)))
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()),
view.get_view_description())
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
for line in lines:
if self.header_regex.match(line):
current_section, separator, lead = line.partition(':')
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
if header in sections:
return sections[header].strip()
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
def __get__(self, instance, owner):
result = super().__get__(instance, owner)
if not isinstance(result, DefaultSchema):
return result
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
assert issubclass(inspector_class, ViewInspector), (
"DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
)
inspector = inspector_class()
inspector.view = instance
return inspector

View File

@@ -0,0 +1,722 @@
import re
import warnings
from collections import OrderedDict
from decimal import Decimal
from operator import attrgetter
from urllib.parse import urljoin
from django.core.validators import (
DecimalValidator, EmailValidator, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, RegexValidator, URLValidator
)
from django.db import models
from django.utils.encoding import force_str
from rest_framework import (
RemovedInDRF315Warning, exceptions, renderers, serializers
)
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.settings import api_settings
from .generators import BaseSchemaGenerator
from .inspectors import ViewInspector
from .utils import get_pk_description, is_list_view
class SchemaGenerator(BaseSchemaGenerator):
def get_info(self):
# Title and version are required by openapi specification 3.x
info = {
'title': self.title or '',
'version': self.version or ''
}
if self.description is not None:
info['description'] = self.description
return info
def check_duplicate_operation_id(self, paths):
ids = {}
for route in paths:
for method in paths[route]:
if 'operationId' not in paths[route][method]:
continue
operation_id = paths[route][method]['operationId']
if operation_id in ids:
warnings.warn(
'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n'
'\tRoute: {route1}, Method: {method1}\n'
'\tRoute: {route2}, Method: {method2}\n'
'\tAn operationId has to be unique across your schema. Your schema may not work in other tools.'
.format(
route1=ids[operation_id]['route'],
method1=ids[operation_id]['method'],
route2=route,
method2=method,
operation_id=operation_id
)
)
ids[operation_id] = {
'route': route,
'method': method
}
def get_schema(self, request=None, public=False):
"""
Generate a OpenAPI schema.
"""
self._initialise_endpoints()
components_schemas = {}
# Iterate endpoints generating per method path operations.
paths = {}
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
operation = view.schema.get_operation(path, method)
components = view.schema.get_components(path, method)
for k in components.keys():
if k not in components_schemas:
continue
if components_schemas[k] == components[k]:
continue
warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k))
components_schemas.update(components)
# Normalise path for any provided mount url.
if path.startswith('/'):
path = path[1:]
path = urljoin(self.url or '/', path)
paths.setdefault(path, {})
paths[path][method.lower()] = operation
self.check_duplicate_operation_id(paths)
# Compile final schema.
schema = {
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
}
if len(components_schemas) > 0:
schema['components'] = {
'schemas': components_schemas
}
return schema
# View Inspectors
class AutoSchema(ViewInspector):
def __init__(self, tags=None, operation_id_base=None, component_name=None):
"""
:param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name.
:param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name.
"""
if tags and not all(isinstance(tag, str) for tag in tags):
raise ValueError('tags must be a list or tuple of string.')
self._tags = tags
self.operation_id_base = operation_id_base
self.component_name = component_name
super().__init__()
request_media_types = []
response_media_types = []
method_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partialUpdate',
'delete': 'destroy',
}
def get_operation(self, path, method):
operation = {}
operation['operationId'] = self.get_operation_id(path, method)
operation['description'] = self.get_description(path, method)
parameters = []
parameters += self.get_path_parameters(path, method)
parameters += self.get_pagination_parameters(path, method)
parameters += self.get_filter_parameters(path, method)
operation['parameters'] = parameters
request_body = self.get_request_body(path, method)
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self.get_responses(path, method)
operation['tags'] = self.get_tags(path, method)
return operation
def get_component_name(self, serializer):
"""
Compute the component's name from the serializer.
Raise an exception if the serializer's class name is "Serializer" (case-insensitive).
"""
if self.component_name is not None:
return self.component_name
# use the serializer's class name as the component name.
component_name = serializer.__class__.__name__
# We remove the "serializer" string from the class name.
pattern = re.compile("serializer", re.IGNORECASE)
component_name = pattern.sub("", component_name)
if component_name == "":
raise Exception(
'"{}" is an invalid class name for schema generation. '
'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"'
.format(serializer.__class__.__name__)
)
return component_name
def get_components(self, path, method):
"""
Return components with their properties from the serializer.
"""
if method.lower() == 'delete':
return {}
request_serializer = self.get_request_serializer(path, method)
response_serializer = self.get_response_serializer(path, method)
components = {}
if isinstance(request_serializer, serializers.Serializer):
component_name = self.get_component_name(request_serializer)
content = self.map_serializer(request_serializer)
components.setdefault(component_name, content)
if isinstance(response_serializer, serializers.Serializer):
component_name = self.get_component_name(response_serializer)
content = self.map_serializer(response_serializer)
components.setdefault(component_name, content)
return components
def _to_camel_case(self, snake_str):
components = snake_str.split('_')
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
return components[0] + ''.join(x.title() for x in components[1:])
def get_operation_id_base(self, path, method, action):
"""
Compute the base part for operation ID from the model, serializer or view name.
"""
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
if self.operation_id_base is not None:
name = self.operation_id_base
# Try to deduce the ID from the view's model
elif model is not None:
name = model.__name__
# Try with the serializer class name
elif self.get_serializer(path, method) is not None:
name = self.get_serializer(path, method).__class__.__name__
if name.endswith('Serializer'):
name = name[:-10]
# Fallback to the view name
else:
name = self.view.__class__.__name__
if name.endswith('APIView'):
name = name[:-7]
elif name.endswith('View'):
name = name[:-4]
# Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly
# comes at the end of the name
if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ...
name = name[:-len(action)]
if action == 'list' and not name.endswith('s'): # listThings instead of listThing
name += 's'
return name
def get_operation_id(self, path, method):
"""
Compute an operation ID from the view type and get_operation_id_base method.
"""
method_name = getattr(self.view, 'action', method.lower())
if is_list_view(path, method, self.view):
action = 'list'
elif method_name not in self.method_mapping:
action = self._to_camel_case(method_name)
else:
action = self.method_mapping[method.lower()]
name = self.get_operation_id_base(path, method, action)
return action + name
def get_path_parameters(self, path, method):
"""
Return a list of parameters from templated path variables.
"""
assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.'
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
parameters = []
for variable in uritemplate.variables(path):
description = ''
if model is not None: # TODO: test this.
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception:
model_field = None
if model_field is not None and model_field.help_text:
description = force_str(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
parameter = {
"name": variable,
"in": "path",
"required": True,
"description": description,
'schema': {
'type': 'string', # TODO: integer, pattern, ...
},
}
parameters.append(parameter)
return parameters
def get_filter_parameters(self, path, method):
if not self.allows_filters(path, method):
return []
parameters = []
for filter_backend in self.view.filter_backends:
parameters += filter_backend().get_schema_operation_parameters(self.view)
return parameters
def allows_filters(self, path, method):
"""
Determine whether to include filter Fields in schema.
Default implementation looks for ModelViewSet or GenericAPIView
actions/methods that cause filtering on the default implementation.
"""
if getattr(self.view, 'filter_backends', None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_pagination_parameters(self, path, method):
view = self.view
if not is_list_view(path, method, view):
return []
paginator = self.get_paginator()
if not paginator:
return []
return paginator.get_schema_operation_parameters(view)
def map_choicefield(self, field):
choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates
if all(isinstance(choice, bool) for choice in choices):
type = 'boolean'
elif all(isinstance(choice, int) for choice in choices):
type = 'integer'
elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer`
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
type = 'number'
elif all(isinstance(choice, str) for choice in choices):
type = 'string'
else:
type = None
mapping = {
# The value of `enum` keyword MUST be an array and SHOULD be unique.
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20
'enum': choices
}
# If We figured out `type` then and only then we should set it. It must be a string.
# Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type
# It is optional but it can not be null.
# Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21
if type:
mapping['type'] = type
return mapping
def map_field(self, field):
# Nested Serializers, `many` or not.
if isinstance(field, serializers.ListSerializer):
return {
'type': 'array',
'items': self.map_serializer(field.child)
}
if isinstance(field, serializers.Serializer):
data = self.map_serializer(field)
data['type'] = 'object'
return data
# Related fields.
if isinstance(field, serializers.ManyRelatedField):
return {
'type': 'array',
'items': self.map_field(field.child_relation)
}
if isinstance(field, serializers.PrimaryKeyRelatedField):
model = getattr(field.queryset, 'model', None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
return {'type': 'integer'}
# ChoiceFields (single and multiple).
# Q:
# - Is 'type' required?
# - can we determine the TYPE of a choicefield?
if isinstance(field, serializers.MultipleChoiceField):
return {
'type': 'array',
'items': self.map_choicefield(field)
}
if isinstance(field, serializers.ChoiceField):
return self.map_choicefield(field)
# ListField.
if isinstance(field, serializers.ListField):
mapping = {
'type': 'array',
'items': {},
}
if not isinstance(field.child, _UnvalidatedField):
mapping['items'] = self.map_field(field.child)
return mapping
# DateField and DateTimeField type is string
if isinstance(field, serializers.DateField):
return {
'type': 'string',
'format': 'date',
}
if isinstance(field, serializers.DateTimeField):
return {
'type': 'string',
'format': 'date-time',
}
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
# see also: https://swagger.io/docs/specification/data-models/data-types/#string
if isinstance(field, serializers.EmailField):
return {
'type': 'string',
'format': 'email'
}
if isinstance(field, serializers.URLField):
return {
'type': 'string',
'format': 'uri'
}
if isinstance(field, serializers.UUIDField):
return {
'type': 'string',
'format': 'uuid'
}
if isinstance(field, serializers.IPAddressField):
content = {
'type': 'string',
}
if field.protocol != 'both':
content['format'] = field.protocol
return content
if isinstance(field, serializers.DecimalField):
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
content = {
'type': 'string',
'format': 'decimal',
}
else:
content = {
'type': 'number'
}
if field.decimal_places:
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
if field.max_whole_digits:
content['maximum'] = int(field.max_whole_digits * '9') + 1
content['minimum'] = -content['maximum']
self._map_min_max(field, content)
return content
if isinstance(field, serializers.FloatField):
content = {
'type': 'number',
}
self._map_min_max(field, content)
return content
if isinstance(field, serializers.IntegerField):
content = {
'type': 'integer'
}
self._map_min_max(field, content)
# 2147483647 is max for int32_size, so we use int64 for format
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
content['format'] = 'int64'
return content
if isinstance(field, serializers.FileField):
return {
'type': 'string',
'format': 'binary'
}
# Simplest cases, default to 'string' type:
FIELD_CLASS_SCHEMA_TYPE = {
serializers.BooleanField: 'boolean',
serializers.JSONField: 'object',
serializers.DictField: 'object',
serializers.HStoreField: 'object',
}
return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}
def _map_min_max(self, field, content):
if field.max_value:
content['maximum'] = field.max_value
if field.min_value:
content['minimum'] = field.min_value
def map_serializer(self, serializer):
# Assuming we have a valid serializer instance.
required = []
properties = {}
for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField):
continue
if field.required:
required.append(field.field_name)
schema = self.map_field(field)
if field.read_only:
schema['readOnly'] = True
if field.write_only:
schema['writeOnly'] = True
if field.allow_null:
schema['nullable'] = True
if field.default is not None and field.default != empty and not callable(field.default):
schema['default'] = field.default
if field.help_text:
schema['description'] = str(field.help_text)
self.map_field_validators(field, schema)
properties[field.field_name] = schema
result = {
'type': 'object',
'properties': properties
}
if required:
result['required'] = required
return result
def map_field_validators(self, field, schema):
"""
map field validators
"""
for v in field.validators:
# "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
if isinstance(v, EmailValidator):
schema['format'] = 'email'
if isinstance(v, URLValidator):
schema['format'] = 'uri'
if isinstance(v, RegexValidator):
# In Python, the token \Z does what \z does in other engines.
# https://stackoverflow.com/questions/53283160
schema['pattern'] = v.regex.pattern.replace('\\Z', '\\z')
elif isinstance(v, MaxLengthValidator):
attr_name = 'maxLength'
if isinstance(field, serializers.ListField):
attr_name = 'maxItems'
schema[attr_name] = v.limit_value
elif isinstance(v, MinLengthValidator):
attr_name = 'minLength'
if isinstance(field, serializers.ListField):
attr_name = 'minItems'
schema[attr_name] = v.limit_value
elif isinstance(v, MaxValueValidator):
schema['maximum'] = v.limit_value
elif isinstance(v, MinValueValidator):
schema['minimum'] = v.limit_value
elif isinstance(v, DecimalValidator) and \
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
if v.decimal_places:
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
if v.max_digits:
digits = v.max_digits
if v.decimal_places is not None and v.decimal_places > 0:
digits -= v.decimal_places
schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum']
def get_paginator(self):
pagination_class = getattr(self.view, 'pagination_class', None)
if pagination_class:
return pagination_class()
return None
def map_parsers(self, path, method):
return list(map(attrgetter('media_type'), self.view.parser_classes))
def map_renderers(self, path, method):
media_types = []
for renderer in self.view.renderer_classes:
# BrowsableAPIRenderer not relevant to OpenAPI spec
if issubclass(renderer, renderers.BrowsableAPIRenderer):
continue
media_types.append(renderer.media_type)
return media_types
def get_serializer(self, path, method):
view = self.view
if not hasattr(view, 'get_serializer'):
return None
try:
return view.get_serializer()
except exceptions.APIException:
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
return None
def get_request_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
handling request body.
"""
return self.get_serializer(path, method)
def get_response_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
populating response data.
"""
return self.get_serializer(path, method)
def get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
def get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'):
return {}
self.request_media_types = self.map_parsers(path, method)
serializer = self.get_request_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self.get_reference(serializer)
return {
'content': {
ct: {'schema': item_schema}
for ct in self.request_media_types
}
}
def get_responses(self, path, method):
if method == 'DELETE':
return {
'204': {
'description': ''
}
}
self.response_media_types = self.map_renderers(path, method)
serializer = self.get_response_serializer(path, method)
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self.get_reference(serializer)
if is_list_view(path, method, self.view):
response_schema = {
'type': 'array',
'items': item_schema,
}
paginator = self.get_paginator()
if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema)
else:
response_schema = item_schema
status_code = '201' if method == 'POST' else '200'
return {
status_code: {
'content': {
ct: {'schema': response_schema}
for ct in self.response_media_types
},
# description is a mandatory property,
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
# TODO: put something meaningful into it
'description': ""
}
}
def get_tags(self, path, method):
# If user have specified tags, use them.
if self._tags:
return self._tags
# First element of a specific path could be valid tag. This is a fallback solution.
# PUT, PATCH, GET(Retrieve), DELETE: /user_profile/{id}/ tags = [user-profile]
# POST, GET(List): /user_profile/ tags = [user-profile]
if path.startswith('/'):
path = path[1:]
return [path.split('/')[0].replace('_', '-')]
def _get_reference(self, serializer):
warnings.warn(
"Method `_get_reference()` has been renamed to `get_reference()`. "
"The old name will be removed in DRF v3.15.",
RemovedInDRF315Warning, stacklevel=2
)
return self.get_reference(serializer)

View File

@@ -0,0 +1,41 @@
"""
utils.py # Shared helper functions
See schemas.__init__.py for package overview.
"""
from django.db import models
from django.utils.translation import gettext_lazy as _
from rest_framework.mixins import RetrieveModelMixin
def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
if method.lower() != 'get':
return False
if isinstance(view, RetrieveModelMixin):
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
return True
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
else:
value_type = _('unique value')
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
)

View File

@@ -0,0 +1,48 @@
"""
views.py # Houses `SchemaView`, `APIView` subclass.
See schemas.__init__.py for package overview.
"""
from rest_framework import exceptions, renderers
from rest_framework.response import Response
from rest_framework.schemas import coreapi
from rest_framework.settings import api_settings
from rest_framework.views import APIView
class SchemaView(APIView):
_ignore_model_permissions = True
schema = None # exclude from schema
renderer_classes = None
schema_generator = None
public = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.renderer_classes is None:
if coreapi.is_enabled():
self.renderer_classes = [
renderers.CoreAPIOpenAPIRenderer,
renderers.CoreJSONRenderer
]
else:
self.renderer_classes = [
renderers.OpenAPIRenderer,
renderers.JSONOpenAPIRenderer,
]
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes += [renderers.BrowsableAPIRenderer]
def get(self, request, *args, **kwargs):
schema = self.schema_generator.get_schema(request, self.public)
if schema is None:
raise exceptions.PermissionDenied()
return Response(schema)
def handle_exception(self, exc):
# Schema renderers do not render exceptions, so re-perform content
# negotiation with default renderers.
self.renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
neg = self.perform_content_negotiation(self.request, force=True)
self.request.accepted_renderer, self.request.accepted_media_type = neg
return super().handle_exception(exc)