GNXSOFT.COM
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
from ..app_settings import swagger_settings
|
||||
from .base import (
|
||||
BaseInspector, FieldInspector, FilterInspector, NotHandled, PaginatorInspector, SerializerInspector, ViewInspector
|
||||
)
|
||||
from .field import (
|
||||
CamelCaseJSONFilter, ChoiceFieldInspector, DictFieldInspector, FileFieldInspector, HiddenFieldInspector,
|
||||
InlineSerializerInspector, JSONFieldInspector, RecursiveFieldInspector, ReferencingSerializerInspector,
|
||||
RelatedFieldInspector, SerializerMethodFieldInspector, SimpleFieldInspector, StringDefaultFieldInspector
|
||||
)
|
||||
from .query import DrfAPICompatInspector, CoreAPICompatInspector, DjangoRestResponsePagination
|
||||
from .view import SwaggerAutoSchema
|
||||
|
||||
# these settings must be accessed only after defining/importing all the classes in this module to avoid ImportErrors
|
||||
ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS
|
||||
ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS
|
||||
ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS
|
||||
|
||||
__all__ = [
|
||||
# base inspectors
|
||||
'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector',
|
||||
|
||||
# filter and pagination inspectors
|
||||
'DrfAPICompatInspector', 'CoreAPICompatInspector', 'DjangoRestResponsePagination',
|
||||
|
||||
# field inspectors
|
||||
'InlineSerializerInspector', 'RecursiveFieldInspector', 'ReferencingSerializerInspector', 'RelatedFieldInspector',
|
||||
'SimpleFieldInspector', 'FileFieldInspector', 'ChoiceFieldInspector', 'DictFieldInspector', 'JSONFieldInspector',
|
||||
'StringDefaultFieldInspector', 'CamelCaseJSONFilter', 'HiddenFieldInspector', 'SerializerMethodFieldInspector',
|
||||
|
||||
# view inspectors
|
||||
'SwaggerAutoSchema',
|
||||
|
||||
# module constants
|
||||
'NotHandled',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,462 @@
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from .. import openapi
|
||||
from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view
|
||||
|
||||
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
|
||||
NotHandled = object()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_callable_method(cls_or_instance, method_name):
|
||||
method = getattr(cls_or_instance, method_name)
|
||||
if inspect.ismethod(method) and getattr(method, '__self__', None):
|
||||
# bound classmethod or instance method
|
||||
return method, True
|
||||
|
||||
from inspect import getattr_static
|
||||
return method, isinstance(getattr_static(cls_or_instance, method_name, None), staticmethod)
|
||||
|
||||
|
||||
def call_view_method(view, method_name, fallback_attr=None, default=None):
|
||||
"""Call a view method which might throw an exception. If an exception is thrown, log an informative error message
|
||||
and return the value of fallback_attr, or default if not present. The method must be callable without any arguments
|
||||
except cls or self.
|
||||
|
||||
:param view: view class or instance; if a class is passed, instance methods won't be called
|
||||
:type view: rest_framework.views.APIView or type[rest_framework.views.APIView]
|
||||
:param str method_name: name of a method on the view
|
||||
:param str fallback_attr: name of an attribute on the view to fall back on, if calling the method fails
|
||||
:param default: default value if all else fails
|
||||
:return: view method's return value, or value of view's fallback_attr, or default
|
||||
:rtype: any or None
|
||||
"""
|
||||
if hasattr(view, method_name):
|
||||
try:
|
||||
view_method, is_callabale = is_callable_method(view, method_name)
|
||||
if is_callabale:
|
||||
return view_method()
|
||||
except Exception: # pragma: no cover
|
||||
logger.warning("view's %s raised exception during schema generation; use "
|
||||
"`getattr(self, 'swagger_fake_view', False)` to detect and short-circuit this",
|
||||
type(view).__name__, exc_info=True)
|
||||
|
||||
if fallback_attr and hasattr(view, fallback_attr):
|
||||
return getattr(view, fallback_attr)
|
||||
|
||||
return default
|
||||
|
||||
|
||||
class BaseInspector:
|
||||
def __init__(self, view, path, method, components, request):
|
||||
"""
|
||||
:param rest_framework.views.APIView view: the view associated with this endpoint
|
||||
:param str path: the path component of the operation URL
|
||||
:param str method: the http method of the operation
|
||||
:param openapi.ReferenceResolver components: referenceable components
|
||||
:param rest_framework.request.Request request: the request made against the schema view; can be None
|
||||
"""
|
||||
self.view = view
|
||||
self.path = path
|
||||
self.method = method
|
||||
self.components = components
|
||||
self.request = request
|
||||
|
||||
def process_result(self, result, method_name, obj, **kwargs):
|
||||
"""After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors
|
||||
that were probed get the chance to alter the result, in reverse order. The inspector that handled the object
|
||||
is the first to receive a ``process_result`` call with the object it just returned.
|
||||
|
||||
This behavior is similar to the Django request/response middleware processing.
|
||||
|
||||
If this inspector has no post-processing to do, it should just ``return result`` (the default implementation).
|
||||
|
||||
:param result: the return value of the winning inspector, or ``None`` if no inspector handled the object
|
||||
:param str method_name: name of the method that was called on the inspector
|
||||
:param obj: first argument passed to inspector method
|
||||
:param kwargs: additional arguments passed to inspector method
|
||||
:return:
|
||||
"""
|
||||
return result
|
||||
|
||||
def probe_inspectors(self, inspectors, method_name, obj, initkwargs=None, **kwargs):
|
||||
"""Probe a list of inspectors with a given object. The first inspector in the list to return a value that
|
||||
is not :data:`.NotHandled` wins.
|
||||
|
||||
:param list[type[BaseInspector]] inspectors: list of inspectors to probe
|
||||
:param str method_name: name of the target method on the inspector
|
||||
:param obj: first argument to inspector method
|
||||
:param dict initkwargs: extra kwargs for instantiating inspector class
|
||||
:param kwargs: additional arguments to inspector method
|
||||
:return: the return value of the winning inspector, or ``None`` if no inspector handled the object
|
||||
"""
|
||||
initkwargs = initkwargs or {}
|
||||
tried_inspectors = []
|
||||
|
||||
for inspector in inspectors:
|
||||
assert inspect.isclass(inspector), "inspector must be a class, not an object"
|
||||
assert issubclass(inspector, BaseInspector), "inspectors must subclass BaseInspector"
|
||||
|
||||
inspector = inspector(self.view, self.path, self.method, self.components, self.request, **initkwargs)
|
||||
tried_inspectors.append(inspector)
|
||||
method = getattr(inspector, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
|
||||
result = method(obj, **kwargs)
|
||||
if result is not NotHandled:
|
||||
break
|
||||
else: # pragma: no cover
|
||||
logger.warning("%s ignored because no inspector in %s handled it (operation: %s)",
|
||||
obj, inspectors, method_name)
|
||||
result = None
|
||||
|
||||
for inspector in reversed(tried_inspectors):
|
||||
result = inspector.process_result(result, method_name, obj, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
def get_renderer_classes(self):
|
||||
"""Get the renderer classes of this view by calling `get_renderers`.
|
||||
|
||||
:return: renderer classes
|
||||
:rtype: list[type[rest_framework.renderers.BaseRenderer]]
|
||||
"""
|
||||
return get_object_classes(call_view_method(self.view, 'get_renderers', 'renderer_classes', []))
|
||||
|
||||
def get_parser_classes(self):
|
||||
"""Get the parser classes of this view by calling `get_parsers`.
|
||||
|
||||
:return: parser classes
|
||||
:rtype: list[type[rest_framework.parsers.BaseParser]]
|
||||
"""
|
||||
return get_object_classes(call_view_method(self.view, 'get_parsers', 'parser_classes', []))
|
||||
|
||||
|
||||
class PaginatorInspector(BaseInspector):
|
||||
"""Base inspector for paginators.
|
||||
|
||||
Responsible for determining extra query parameters and response structure added by given paginators.
|
||||
"""
|
||||
|
||||
def get_paginator_parameters(self, paginator):
|
||||
"""Get the pagination parameters for a single paginator **instance**.
|
||||
|
||||
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `paginator`.
|
||||
|
||||
:param BasePagination paginator: the paginator
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
return NotHandled
|
||||
|
||||
def get_paginated_response(self, paginator, response_schema):
|
||||
"""Add appropriate paging fields to a response :class:`.Schema`.
|
||||
|
||||
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `paginator`.
|
||||
|
||||
:param BasePagination paginator: the paginator
|
||||
:param openapi.Schema response_schema: the response schema that must be paged.
|
||||
:rtype: openapi.Schema
|
||||
"""
|
||||
return NotHandled
|
||||
|
||||
|
||||
class FilterInspector(BaseInspector):
|
||||
"""Base inspector for filter backends.
|
||||
|
||||
Responsible for determining extra query parameters added by given filter backends.
|
||||
"""
|
||||
|
||||
def get_filter_parameters(self, filter_backend):
|
||||
"""Get the filter parameters for a single filter backend **instance**.
|
||||
|
||||
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `filter_backend`.
|
||||
|
||||
:param BaseFilterBackend filter_backend: the filter backend
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
return NotHandled
|
||||
|
||||
|
||||
class FieldInspector(BaseInspector):
|
||||
"""Base inspector for serializers and serializer fields. """
|
||||
|
||||
def __init__(self, view, path, method, components, request, field_inspectors):
|
||||
super(FieldInspector, self).__init__(view, path, method, components, request)
|
||||
self.field_inspectors = field_inspectors
|
||||
|
||||
def add_manual_fields(self, serializer_or_field, schema):
|
||||
"""Set fields from the ``swagger_schema_fields`` attribute on the Meta class. This method is called
|
||||
only for serializers or fields that are converted into ``openapi.Schema`` objects.
|
||||
|
||||
:param serializer_or_field: serializer or field instance
|
||||
:param openapi.Schema schema: the schema object to be modified in-place
|
||||
"""
|
||||
meta = getattr(serializer_or_field, 'Meta', None)
|
||||
swagger_schema_fields = getattr(meta, 'swagger_schema_fields', {})
|
||||
if swagger_schema_fields:
|
||||
for attr, val in swagger_schema_fields.items():
|
||||
setattr(schema, attr, val)
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
"""Convert a drf Serializer or Field instance into a Swagger object.
|
||||
|
||||
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `field`.
|
||||
|
||||
:param rest_framework.serializers.Field field: the source field
|
||||
:param type[openapi.SwaggerDict] swagger_object_type: should be one of Schema, Parameter, Items
|
||||
:param bool use_references: if False, forces all objects to be declared inline
|
||||
instead of by referencing other components
|
||||
:param kwargs: extra attributes for constructing the object;
|
||||
if swagger_object_type is Parameter, ``name`` and ``in_`` should be provided
|
||||
:return: the swagger object
|
||||
:rtype: openapi.Parameter or openapi.Items or openapi.Schema or openapi.SchemaRef
|
||||
"""
|
||||
return NotHandled
|
||||
|
||||
def probe_field_inspectors(self, field, swagger_object_type, use_references, **kwargs):
|
||||
"""Helper method for recursively probing `field_inspectors` to handle a given field.
|
||||
|
||||
All arguments are the same as :meth:`.field_to_swagger_object`.
|
||||
|
||||
:rtype: openapi.Parameter or openapi.Items or openapi.Schema or openapi.SchemaRef
|
||||
"""
|
||||
return self.probe_inspectors(
|
||||
self.field_inspectors, 'field_to_swagger_object', field, {'field_inspectors': self.field_inspectors},
|
||||
swagger_object_type=swagger_object_type, use_references=use_references, **kwargs
|
||||
)
|
||||
|
||||
def _get_partial_types(self, field, swagger_object_type, use_references, **kwargs):
|
||||
"""Helper method to extract generic information from a field and return a partial constructor for the
|
||||
appropriate openapi object.
|
||||
|
||||
All arguments are the same as :meth:`.field_to_swagger_object`.
|
||||
|
||||
The return value is a tuple consisting of:
|
||||
|
||||
* a function for constructing objects of `swagger_object_type`; its prototype is: ::
|
||||
|
||||
def SwaggerType(existing_object=None, **instance_kwargs):
|
||||
|
||||
This function creates an instance of `swagger_object_type`, passing the following attributes to its init,
|
||||
in order of precedence:
|
||||
|
||||
- arguments specified by the ``kwargs`` parameter of :meth:`._get_partial_types`
|
||||
- ``instance_kwargs`` passed to the constructor function
|
||||
- ``title``, ``description``, ``required``, ``x-nullable`` and ``default`` inferred from the field,
|
||||
where appropriate
|
||||
|
||||
If ``existing_object`` is not ``None``, it is updated instead of creating a new object.
|
||||
|
||||
* a type that should be used for child objects if `field` is of an array type. This can currently have two
|
||||
values:
|
||||
|
||||
- :class:`.Schema` if `swagger_object_type` is :class:`.Schema`
|
||||
- :class:`.Items` if `swagger_object_type` is :class:`.Parameter` or :class:`.Items`
|
||||
|
||||
:rtype: (function,type[openapi.Schema] or type[openapi.Items])
|
||||
"""
|
||||
assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
|
||||
assert not isinstance(field, openapi.SwaggerDict), "passed field is already a SwaggerDict object"
|
||||
title = force_real_str(field.label) if field.label else None
|
||||
title = title if swagger_object_type == openapi.Schema else None # only Schema has title
|
||||
help_text = getattr(field, 'help_text', None)
|
||||
description = force_real_str(help_text) if help_text else None
|
||||
description = description if swagger_object_type != openapi.Items else None # Items has no description either
|
||||
|
||||
def SwaggerType(existing_object=None, use_field_title=True, **instance_kwargs):
|
||||
if 'required' not in instance_kwargs and swagger_object_type == openapi.Parameter:
|
||||
instance_kwargs['required'] = field.required
|
||||
|
||||
if 'default' not in instance_kwargs and swagger_object_type != openapi.Items:
|
||||
default = get_field_default(field)
|
||||
if default not in (None, serializers.empty):
|
||||
instance_kwargs['default'] = default
|
||||
|
||||
if use_field_title and instance_kwargs.get('type', None) != openapi.TYPE_ARRAY:
|
||||
instance_kwargs.setdefault('title', title)
|
||||
if description is not None:
|
||||
instance_kwargs.setdefault('description', description)
|
||||
|
||||
if getattr(field, 'allow_null', None):
|
||||
instance_kwargs['x_nullable'] = True
|
||||
|
||||
instance_kwargs.update(kwargs)
|
||||
|
||||
if existing_object is not None:
|
||||
assert isinstance(existing_object, swagger_object_type)
|
||||
for key, val in sorted(instance_kwargs.items()):
|
||||
setattr(existing_object, key, val)
|
||||
result = existing_object
|
||||
else:
|
||||
result = swagger_object_type(**instance_kwargs)
|
||||
|
||||
# Provide an option to add manual parameters to a schema
|
||||
# for example, to add examples
|
||||
if swagger_object_type == openapi.Schema:
|
||||
self.add_manual_fields(field, result)
|
||||
return result
|
||||
|
||||
# arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
|
||||
child_swagger_type = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
|
||||
return SwaggerType, child_swagger_type
|
||||
|
||||
|
||||
class SerializerInspector(FieldInspector):
|
||||
def get_schema(self, serializer):
|
||||
"""Convert a DRF Serializer instance to an :class:`.openapi.Schema`.
|
||||
|
||||
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`.
|
||||
|
||||
:param serializers.BaseSerializer serializer: the ``Serializer`` instance
|
||||
:rtype: openapi.Schema
|
||||
"""
|
||||
return NotHandled
|
||||
|
||||
def get_request_parameters(self, serializer, in_):
|
||||
"""Convert a DRF serializer into a list of :class:`.Parameter`\\ s.
|
||||
|
||||
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`.
|
||||
|
||||
:param serializers.BaseSerializer serializer: the ``Serializer`` instance
|
||||
:param str in_: the location of the parameters, one of the `openapi.IN_*` constants
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
return NotHandled
|
||||
|
||||
|
||||
class ViewInspector(BaseInspector):
|
||||
body_methods = ('PUT', 'PATCH', 'POST', 'DELETE') #: methods that are allowed to have a request body
|
||||
|
||||
#: methods that are assumed to require a request body determined by the view's ``serializer_class``
|
||||
implicit_body_methods = ('PUT', 'PATCH', 'POST')
|
||||
|
||||
#: methods which are assumed to return a list of objects when present on non-detail endpoints
|
||||
implicit_list_response_methods = ('GET',)
|
||||
|
||||
# real values set in __init__ to prevent import errors
|
||||
field_inspectors = [] #:
|
||||
filter_inspectors = [] #:
|
||||
paginator_inspectors = [] #:
|
||||
|
||||
def __init__(self, view, path, method, components, request, overrides):
|
||||
"""
|
||||
Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method.
|
||||
|
||||
:param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>`
|
||||
"""
|
||||
super(ViewInspector, self).__init__(view, path, method, components, request)
|
||||
self.overrides = overrides
|
||||
self._prepend_inspector_overrides('field_inspectors')
|
||||
self._prepend_inspector_overrides('filter_inspectors')
|
||||
self._prepend_inspector_overrides('paginator_inspectors')
|
||||
|
||||
def _prepend_inspector_overrides(self, inspectors):
|
||||
extra_inspectors = self.overrides.get(inspectors, None)
|
||||
if extra_inspectors:
|
||||
default_inspectors = [insp for insp in getattr(self, inspectors) if insp not in extra_inspectors]
|
||||
setattr(self, inspectors, extra_inspectors + default_inspectors)
|
||||
|
||||
def get_operation(self, operation_keys):
|
||||
"""Get an :class:`.Operation` for the given API endpoint (path, method).
|
||||
This includes query, body parameters and response schemas.
|
||||
|
||||
:param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
|
||||
e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
|
||||
:rtype: openapi.Operation
|
||||
"""
|
||||
raise NotImplementedError("ViewInspector must implement get_operation()!")
|
||||
|
||||
def is_list_view(self):
|
||||
"""Determine whether this view is a list or a detail view. The difference between the two is that
|
||||
detail views depend on a pk/id path parameter. Note that a non-detail view does not necessarily imply a list
|
||||
response (:meth:`.has_list_response`), nor are list responses limited to non-detail views.
|
||||
|
||||
For example, one might have a `/topic/<pk>/posts` endpoint which is a detail view that has a list response.
|
||||
|
||||
:rtype: bool"""
|
||||
return is_list_view(self.path, self.method, self.view)
|
||||
|
||||
def has_list_response(self):
|
||||
"""Determine whether this view returns multiple objects. By default this is any non-detail view
|
||||
(see :meth:`.is_list_view`) whose request method is one of :attr:`.implicit_list_response_methods`.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return self.is_list_view() and (self.method.upper() in self.implicit_list_response_methods)
|
||||
|
||||
def should_filter(self):
|
||||
"""Determine whether filter backend parameters should be included for this request.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return getattr(self.view, 'filter_backends', None) and self.has_list_response()
|
||||
|
||||
def get_filter_parameters(self):
|
||||
"""Return the parameters added to the view by its filter backends.
|
||||
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
if not self.should_filter():
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in getattr(self.view, 'filter_backends'):
|
||||
fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or []
|
||||
|
||||
return fields
|
||||
|
||||
def should_page(self):
|
||||
"""Determine whether paging parameters and structure should be added to this operation's request and response.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return getattr(self.view, 'paginator', None) and self.has_list_response()
|
||||
|
||||
def get_pagination_parameters(self):
|
||||
"""Return the parameters added to the view by its paginator.
|
||||
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
if not self.should_page():
|
||||
return []
|
||||
|
||||
return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters',
|
||||
getattr(self.view, 'paginator')) or []
|
||||
|
||||
def serializer_to_schema(self, serializer):
|
||||
"""Convert a serializer to an OpenAPI :class:`.Schema`.
|
||||
|
||||
:param serializers.BaseSerializer serializer: the ``Serializer`` instance
|
||||
:returns: the converted :class:`.Schema`, or ``None`` in case of an unknown serializer
|
||||
:rtype: openapi.Schema or openapi.SchemaRef
|
||||
"""
|
||||
return self.probe_inspectors(
|
||||
self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors}
|
||||
)
|
||||
|
||||
def serializer_to_parameters(self, serializer, in_):
|
||||
"""Convert a serializer to a possibly empty list of :class:`.Parameter`\\ s.
|
||||
|
||||
:param serializers.BaseSerializer serializer: the ``Serializer`` instance
|
||||
:param str in_: the location of the parameters, one of the `openapi.IN_*` constants
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
return self.probe_inspectors(
|
||||
self.field_inspectors, 'get_request_parameters', serializer, {'field_inspectors': self.field_inspectors},
|
||||
in_=in_
|
||||
) or []
|
||||
|
||||
def get_paginated_response(self, response_schema):
|
||||
"""Add appropriate paging fields to a response :class:`.Schema`.
|
||||
|
||||
:param openapi.Schema response_schema: the response schema that must be paged.
|
||||
:returns: the paginated response class:`.Schema`, or ``None`` in case of an unknown pagination scheme
|
||||
:rtype: openapi.Schema
|
||||
"""
|
||||
return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response',
|
||||
getattr(self.view, 'paginator'), response_schema=response_schema)
|
||||
@@ -0,0 +1,860 @@
|
||||
import datetime
|
||||
import inspect
|
||||
import logging
|
||||
import operator
|
||||
import uuid
|
||||
from contextlib import suppress
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
from inspect import signature as inspect_signature
|
||||
|
||||
import pkg_resources
|
||||
import typing
|
||||
from django.core import validators
|
||||
from django.db import models
|
||||
from packaging import version
|
||||
from rest_framework import serializers
|
||||
from rest_framework.settings import api_settings as rest_framework_settings
|
||||
|
||||
from .base import call_view_method, FieldInspector, NotHandled, SerializerInspector
|
||||
from .. import openapi
|
||||
from ..errors import SwaggerGenerationError
|
||||
from ..utils import (
|
||||
decimal_as_float, field_value_to_representation, filter_none, get_serializer_class, get_serializer_ref_name
|
||||
)
|
||||
|
||||
drf_version = pkg_resources.get_distribution("djangorestframework").version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InlineSerializerInspector(SerializerInspector):
|
||||
"""Provides serializer conversions using :meth:`.FieldInspector.field_to_swagger_object`."""
|
||||
|
||||
#: whether to output :class:`.Schema` definitions inline or into the ``definitions`` section
|
||||
use_definitions = False
|
||||
|
||||
def get_schema(self, serializer):
|
||||
return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions)
|
||||
|
||||
def add_manual_parameters(self, serializer, parameters):
|
||||
"""Add/replace parameters from the given list of automatically generated request parameters. This method
|
||||
is called only when the serializer is converted into a list of parameters for use in a form data request.
|
||||
|
||||
:param serializer: serializer instance
|
||||
:param list[openapi.Parameter] parameters: generated parameters
|
||||
:return: modified parameters
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
return parameters
|
||||
|
||||
def get_request_parameters(self, serializer, in_):
|
||||
fields = getattr(serializer, 'fields', {})
|
||||
parameters = [
|
||||
self.probe_field_inspectors(
|
||||
value, openapi.Parameter, self.use_definitions,
|
||||
name=self.get_parameter_name(key), in_=in_
|
||||
)
|
||||
for key, value
|
||||
in fields.items()
|
||||
if not getattr(value, 'read_only', False)
|
||||
]
|
||||
|
||||
return self.add_manual_parameters(serializer, parameters)
|
||||
|
||||
def get_property_name(self, field_name):
|
||||
return field_name
|
||||
|
||||
def get_parameter_name(self, field_name):
|
||||
return field_name
|
||||
|
||||
def get_serializer_ref_name(self, serializer):
|
||||
return get_serializer_ref_name(serializer)
|
||||
|
||||
def _has_ref_name(self, serializer):
|
||||
serializer_meta = getattr(serializer, 'Meta', None)
|
||||
return hasattr(serializer_meta, 'ref_name')
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
|
||||
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
|
||||
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
|
||||
limits = find_limits(field) or {}
|
||||
return SwaggerType(
|
||||
type=openapi.TYPE_ARRAY,
|
||||
items=child_schema,
|
||||
**limits
|
||||
)
|
||||
elif isinstance(field, serializers.Serializer):
|
||||
if swagger_object_type != openapi.Schema:
|
||||
raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
|
||||
|
||||
ref_name = self.get_serializer_ref_name(field)
|
||||
|
||||
def make_schema_definition(serializer=field):
|
||||
properties = OrderedDict()
|
||||
required = []
|
||||
for property_name, child in serializer.fields.items():
|
||||
property_name = self.get_property_name(property_name)
|
||||
prop_kwargs = {
|
||||
'read_only': bool(child.read_only) or None
|
||||
}
|
||||
prop_kwargs = filter_none(prop_kwargs)
|
||||
|
||||
child_schema = self.probe_field_inspectors(
|
||||
child, ChildSwaggerType, use_references, **prop_kwargs
|
||||
)
|
||||
properties[property_name] = child_schema
|
||||
|
||||
if child.required and not getattr(child_schema, 'read_only', False):
|
||||
required.append(property_name)
|
||||
|
||||
result = SwaggerType(
|
||||
# the title is derived from the field name and is better to
|
||||
# be omitted from models
|
||||
use_field_title=False,
|
||||
type=openapi.TYPE_OBJECT,
|
||||
properties=properties,
|
||||
required=required or None,
|
||||
)
|
||||
|
||||
setattr(result, '_NP_serializer', get_serializer_class(serializer))
|
||||
return result
|
||||
|
||||
if not ref_name or not use_references:
|
||||
return make_schema_definition()
|
||||
|
||||
definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
|
||||
actual_schema = definitions.setdefault(ref_name, make_schema_definition)
|
||||
actual_schema._remove_read_only()
|
||||
|
||||
actual_serializer = getattr(actual_schema, '_NP_serializer', None)
|
||||
this_serializer = get_serializer_class(field)
|
||||
if actual_serializer and actual_serializer != this_serializer:
|
||||
explicit_refs = self._has_ref_name(actual_serializer) and self._has_ref_name(this_serializer)
|
||||
if not explicit_refs:
|
||||
raise SwaggerGenerationError(
|
||||
"Schema for %s would override distinct serializer %s because they implicitly share the same "
|
||||
"ref_name; explicitly set the ref_name attribute on both serializers' Meta classes"
|
||||
% (actual_serializer, this_serializer))
|
||||
|
||||
return openapi.SchemaRef(definitions, ref_name)
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
class ReferencingSerializerInspector(InlineSerializerInspector):
|
||||
use_definitions = True
|
||||
|
||||
|
||||
def get_queryset_field(queryset, field_name):
|
||||
"""Try to get information about a model and model field from a queryset.
|
||||
|
||||
:param queryset: the queryset
|
||||
:param field_name: target field name
|
||||
:returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None``
|
||||
:rtype: tuple
|
||||
"""
|
||||
model = getattr(queryset, 'model', None)
|
||||
model_field = get_model_field(model, field_name)
|
||||
return model, model_field
|
||||
|
||||
|
||||
def get_model_field(model, field_name):
|
||||
"""Try to get the given field from a django db model.
|
||||
|
||||
:param model: the model
|
||||
:param field_name: target field name
|
||||
:return: model field or ``None``
|
||||
"""
|
||||
try:
|
||||
if field_name == 'pk':
|
||||
return model._meta.pk
|
||||
else:
|
||||
return model._meta.get_field(field_name)
|
||||
except Exception: # pragma: no cover
|
||||
return None
|
||||
|
||||
|
||||
def get_queryset_from_view(view, serializer=None):
|
||||
"""Try to get the queryset of the given view
|
||||
|
||||
:param view: the view instance or class
|
||||
:param serializer: if given, will check that the view's get_serializer_class return matches this serializer
|
||||
:return: queryset or ``None``
|
||||
"""
|
||||
try:
|
||||
queryset = call_view_method(view, 'get_queryset', 'queryset')
|
||||
|
||||
if queryset is not None and serializer is not None:
|
||||
# make sure the view is actually using *this* serializer
|
||||
assert type(serializer) == call_view_method(view, 'get_serializer_class', 'serializer_class')
|
||||
|
||||
return queryset
|
||||
except Exception: # pragma: no cover
|
||||
return None
|
||||
|
||||
|
||||
def get_parent_serializer(field):
|
||||
"""Get the nearest parent ``Serializer`` instance for the given field.
|
||||
|
||||
:return: ``Serializer`` or ``None``
|
||||
"""
|
||||
while field is not None:
|
||||
if isinstance(field, serializers.Serializer):
|
||||
return field
|
||||
|
||||
field = field.parent
|
||||
|
||||
return None # pragma: no cover
|
||||
|
||||
|
||||
def get_model_from_descriptor(descriptor):
|
||||
with suppress(Exception):
|
||||
try:
|
||||
return descriptor.rel.related_model
|
||||
except Exception:
|
||||
return descriptor.field.remote_field.model
|
||||
|
||||
|
||||
def get_related_model(model, source):
|
||||
"""Try to find the other side of a model relationship given the name of a related field.
|
||||
|
||||
:param model: one side of the relationship
|
||||
:param str source: related field name
|
||||
:return: related model or ``None``
|
||||
"""
|
||||
|
||||
with suppress(Exception):
|
||||
if '.' in source and source.index('.'):
|
||||
attr, source = source.split('.', maxsplit=1)
|
||||
return get_related_model(get_model_from_descriptor(getattr(model, attr)), source)
|
||||
return get_model_from_descriptor(getattr(model, source))
|
||||
|
||||
|
||||
class RelatedFieldInspector(FieldInspector):
|
||||
"""Provides conversions for ``RelatedField``\\ s."""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
|
||||
if isinstance(field, serializers.ManyRelatedField):
|
||||
child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references)
|
||||
return SwaggerType(
|
||||
type=openapi.TYPE_ARRAY,
|
||||
items=child_schema,
|
||||
unique_items=True,
|
||||
)
|
||||
|
||||
if not isinstance(field, serializers.RelatedField):
|
||||
return NotHandled
|
||||
|
||||
field_queryset = getattr(field, 'queryset', None)
|
||||
|
||||
if isinstance(field, (serializers.PrimaryKeyRelatedField, serializers.SlugRelatedField)):
|
||||
if getattr(field, 'pk_field', ''):
|
||||
# a PrimaryKeyRelatedField can have a `pk_field` attribute which is a
|
||||
# serializer field that will convert the PK value
|
||||
result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, **kwargs)
|
||||
# take the type, format, etc from `pk_field`, and the field-level information
|
||||
# like title, description, default from the PrimaryKeyRelatedField
|
||||
return SwaggerType(existing_object=result)
|
||||
|
||||
target_field = getattr(field, 'slug_field', 'pk')
|
||||
if field_queryset is not None:
|
||||
# if the RelatedField has a queryset, try to get the related model field from there
|
||||
model, model_field = get_queryset_field(field_queryset, target_field)
|
||||
else:
|
||||
# if the RelatedField has no queryset (e.g. read only), try to find the target model
|
||||
# from the view queryset or ModelSerializer model, if present
|
||||
parent_serializer = get_parent_serializer(field)
|
||||
|
||||
serializer_meta = getattr(parent_serializer, 'Meta', None)
|
||||
this_model = getattr(serializer_meta, 'model', None)
|
||||
if not this_model:
|
||||
view_queryset = get_queryset_from_view(self.view, parent_serializer)
|
||||
this_model = getattr(view_queryset, 'model', None)
|
||||
|
||||
source = getattr(field, 'source', '') or field.field_name
|
||||
if not source and isinstance(field.parent, serializers.ManyRelatedField):
|
||||
source = field.parent.field_name
|
||||
|
||||
model = get_related_model(this_model, source)
|
||||
model_field = get_model_field(model, target_field)
|
||||
|
||||
attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
|
||||
return SwaggerType(**attrs)
|
||||
elif isinstance(field, serializers.HyperlinkedRelatedField):
|
||||
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
|
||||
|
||||
return NotHandled # pragma: no cover
|
||||
|
||||
|
||||
def find_regex(regex_field):
|
||||
"""Given a ``Field``, look for a ``RegexValidator`` and try to extract its pattern and return it as a string.
|
||||
|
||||
:param serializers.Field regex_field: the field instance
|
||||
:return: the extracted pattern, or ``None``
|
||||
:rtype: str
|
||||
"""
|
||||
regex_validator = None
|
||||
for validator in regex_field.validators:
|
||||
if isinstance(validator, validators.RegexValidator):
|
||||
if isinstance(validator, validators.URLValidator) or validator == validators.validate_ipv4_address:
|
||||
# skip the default url and IP regexes because they are complex and unhelpful
|
||||
# validate_ipv4_address is a RegexValidator instance in Django 1.11
|
||||
continue
|
||||
if regex_validator is not None:
|
||||
# bail if multiple validators are found - no obvious way to choose
|
||||
return None # pragma: no cover
|
||||
regex_validator = validator
|
||||
|
||||
# regex_validator.regex should be a compiled re object...
|
||||
try:
|
||||
pattern = getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
|
||||
except Exception: # pragma: no cover
|
||||
logger.warning('failed to compile regex validator of ' + str(regex_field), exc_info=True)
|
||||
return None
|
||||
|
||||
if pattern:
|
||||
# attempt some basic cleanup to remove regex constructs not supported by JavaScript
|
||||
# -- swagger uses javascript-style regexes - see https://github.com/swagger-api/swagger-editor/issues/1601
|
||||
if pattern.endswith('\\Z') or pattern.endswith('\\z'):
|
||||
pattern = pattern[:-2] + '$'
|
||||
|
||||
return pattern
|
||||
|
||||
|
||||
numeric_fields = (serializers.IntegerField, serializers.FloatField, serializers.DecimalField)
|
||||
limit_validators = [
|
||||
# minimum and maximum apply to numbers
|
||||
(validators.MinValueValidator, numeric_fields, 'minimum', operator.__gt__),
|
||||
(validators.MaxValueValidator, numeric_fields, 'maximum', operator.__lt__),
|
||||
|
||||
# minLength and maxLength apply to strings
|
||||
(validators.MinLengthValidator, serializers.CharField, 'min_length', operator.__gt__),
|
||||
(validators.MaxLengthValidator, serializers.CharField, 'max_length', operator.__lt__),
|
||||
|
||||
# minItems and maxItems apply to lists
|
||||
(validators.MinLengthValidator, (serializers.ListField, serializers.ListSerializer), 'min_items', operator.__gt__),
|
||||
(validators.MaxLengthValidator, (serializers.ListField, serializers.ListSerializer), 'max_items', operator.__lt__),
|
||||
]
|
||||
|
||||
|
||||
def find_limits(field):
|
||||
"""Given a ``Field``, look for min/max value/length validators and return appropriate limit validation attributes.
|
||||
|
||||
:param serializers.Field field: the field instance
|
||||
:return: the extracted limits
|
||||
:rtype: OrderedDict
|
||||
"""
|
||||
limits = {}
|
||||
applicable_limits = [
|
||||
(validator, attr, improves)
|
||||
for validator, field_class, attr, improves in limit_validators
|
||||
if isinstance(field, field_class)
|
||||
]
|
||||
|
||||
if isinstance(field, serializers.DecimalField) and not decimal_as_float(field):
|
||||
return limits
|
||||
|
||||
for validator in field.validators:
|
||||
if not hasattr(validator, 'limit_value'):
|
||||
continue
|
||||
|
||||
limit_value = validator.limit_value
|
||||
if isinstance(limit_value, Decimal) and decimal_as_float(field):
|
||||
limit_value = float(limit_value)
|
||||
|
||||
for validator_class, attr, improves in applicable_limits:
|
||||
if isinstance(validator, validator_class):
|
||||
if attr not in limits or improves(limit_value, limits[attr]):
|
||||
limits[attr] = limit_value
|
||||
|
||||
if hasattr(field, "allow_blank") and not field.allow_blank:
|
||||
if limits.get('min_length', 0) < 1:
|
||||
limits['min_length'] = 1
|
||||
|
||||
return OrderedDict(sorted(limits.items()))
|
||||
|
||||
|
||||
def decimal_field_type(field):
|
||||
return openapi.TYPE_NUMBER if decimal_as_float(field) else openapi.TYPE_STRING
|
||||
|
||||
|
||||
model_field_to_basic_type = [
|
||||
(models.AutoField, (openapi.TYPE_INTEGER, None)),
|
||||
(models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
|
||||
(models.BooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||
(models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||
(models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||
(models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
|
||||
(models.DurationField, (openapi.TYPE_STRING, None)),
|
||||
(models.FloatField, (openapi.TYPE_NUMBER, None)),
|
||||
(models.IntegerField, (openapi.TYPE_INTEGER, None)),
|
||||
(models.IPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV4)),
|
||||
(models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)),
|
||||
(models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
|
||||
(models.TextField, (openapi.TYPE_STRING, None)),
|
||||
(models.TimeField, (openapi.TYPE_STRING, None)),
|
||||
(models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
||||
(models.CharField, (openapi.TYPE_STRING, None)),
|
||||
]
|
||||
|
||||
ip_format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}
|
||||
|
||||
serializer_field_to_basic_type = [
|
||||
(serializers.EmailField, (openapi.TYPE_STRING, openapi.FORMAT_EMAIL)),
|
||||
(serializers.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
|
||||
(serializers.URLField, (openapi.TYPE_STRING, openapi.FORMAT_URI)),
|
||||
(serializers.IPAddressField, (openapi.TYPE_STRING, lambda field: ip_format.get(field.protocol, None))),
|
||||
(serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
||||
(serializers.RegexField, (openapi.TYPE_STRING, None)),
|
||||
(serializers.CharField, (openapi.TYPE_STRING, None)),
|
||||
(serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||
(serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
|
||||
(serializers.FloatField, (openapi.TYPE_NUMBER, None)),
|
||||
(serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
|
||||
(serializers.DurationField, (openapi.TYPE_STRING, None)),
|
||||
(serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||
(serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||
(serializers.ModelField, (openapi.TYPE_STRING, None)),
|
||||
]
|
||||
|
||||
if version.parse(drf_version) < version.parse("3.14.0"):
|
||||
model_field_to_basic_type.append(
|
||||
(models.NullBooleanField, (openapi.TYPE_BOOLEAN, None))
|
||||
)
|
||||
|
||||
serializer_field_to_basic_type.append(
|
||||
(serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||
)
|
||||
|
||||
basic_type_info = serializer_field_to_basic_type + model_field_to_basic_type
|
||||
|
||||
|
||||
def get_basic_type_info(field):
|
||||
"""Given a serializer or model ``Field``, return its basic type information - ``type``, ``format``, ``pattern``,
|
||||
and any applicable min/max limit values.
|
||||
|
||||
:param field: the field instance
|
||||
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
|
||||
:rtype: OrderedDict
|
||||
"""
|
||||
if field is None:
|
||||
return None
|
||||
|
||||
for field_class, type_format in basic_type_info:
|
||||
if isinstance(field, field_class):
|
||||
swagger_type, format = type_format
|
||||
if callable(swagger_type):
|
||||
swagger_type = swagger_type(field)
|
||||
if callable(format):
|
||||
format = format(field)
|
||||
break
|
||||
else: # pragma: no cover
|
||||
return None
|
||||
|
||||
pattern = None
|
||||
if swagger_type == openapi.TYPE_STRING:
|
||||
pattern = find_regex(field)
|
||||
|
||||
limits = find_limits(field)
|
||||
|
||||
result = OrderedDict([
|
||||
('type', swagger_type),
|
||||
('format', format),
|
||||
('pattern', pattern)
|
||||
])
|
||||
result.update(limits)
|
||||
result = filter_none(result)
|
||||
return result
|
||||
|
||||
|
||||
def decimal_return_type():
|
||||
return openapi.TYPE_STRING if rest_framework_settings.COERCE_DECIMAL_TO_STRING else openapi.TYPE_NUMBER
|
||||
|
||||
|
||||
def get_origin_type(hint_class):
|
||||
return getattr(hint_class, '__origin__', None) or hint_class
|
||||
|
||||
|
||||
def hint_class_issubclass(hint_class, check_class):
|
||||
origin_type = get_origin_type(hint_class)
|
||||
return inspect.isclass(origin_type) and issubclass(origin_type, check_class)
|
||||
|
||||
|
||||
hinting_type_info = [
|
||||
(bool, (openapi.TYPE_BOOLEAN, None)),
|
||||
(int, (openapi.TYPE_INTEGER, None)),
|
||||
(str, (openapi.TYPE_STRING, None)),
|
||||
(float, (openapi.TYPE_NUMBER, None)),
|
||||
(dict, (openapi.TYPE_OBJECT, None)),
|
||||
(Decimal, (decimal_return_type, openapi.FORMAT_DECIMAL)),
|
||||
(uuid.UUID, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
||||
(datetime.datetime, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||
(datetime.date, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||
]
|
||||
|
||||
|
||||
if hasattr(typing, 'get_args'):
|
||||
# python >=3.8
|
||||
typing_get_args = typing.get_args
|
||||
else:
|
||||
# python <3.8
|
||||
def typing_get_args(tp):
|
||||
return getattr(tp, '__args__', ())
|
||||
|
||||
|
||||
def inspect_collection_hint_class(hint_class):
|
||||
args = typing_get_args(hint_class)
|
||||
child_class = args[0] if args else str
|
||||
child_type_info = get_basic_type_info_from_hint(child_class) or {'type': openapi.TYPE_STRING}
|
||||
|
||||
return OrderedDict([
|
||||
('type', openapi.TYPE_ARRAY),
|
||||
('items', openapi.Items(**child_type_info)),
|
||||
])
|
||||
|
||||
|
||||
hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
|
||||
|
||||
|
||||
def _get_union_types(hint_class):
|
||||
origin_type = get_origin_type(hint_class)
|
||||
if origin_type is typing.Union:
|
||||
return hint_class.__args__
|
||||
|
||||
|
||||
def get_basic_type_info_from_hint(hint_class):
|
||||
"""Given a class (eg from a SerializerMethodField's return type hint,
|
||||
return its basic type information - ``type``, ``format``, ``pattern``,
|
||||
and any applicable min/max limit values.
|
||||
|
||||
:param hint_class: the class
|
||||
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
|
||||
:rtype: OrderedDict
|
||||
"""
|
||||
union_types = _get_union_types(hint_class)
|
||||
|
||||
if union_types:
|
||||
# Optional is implemented as Union[T, None]
|
||||
if len(union_types) == 2 and isinstance(None, union_types[1]):
|
||||
result = get_basic_type_info_from_hint(union_types[0])
|
||||
if result:
|
||||
result['x-nullable'] = True
|
||||
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
for check_class, info in hinting_type_info:
|
||||
if hint_class_issubclass(hint_class, check_class):
|
||||
if callable(info):
|
||||
return info(hint_class)
|
||||
|
||||
swagger_type, format = info
|
||||
if callable(swagger_type):
|
||||
swagger_type = swagger_type()
|
||||
|
||||
return OrderedDict([
|
||||
('type', swagger_type),
|
||||
('format', format),
|
||||
])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class SerializerMethodFieldInspector(FieldInspector):
|
||||
"""Provides conversion for SerializerMethodField, optionally using information from the swagger_serializer_method
|
||||
decorator.
|
||||
"""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
if not isinstance(field, serializers.SerializerMethodField):
|
||||
return NotHandled
|
||||
|
||||
method = getattr(field.parent, field.method_name, None)
|
||||
if method is None:
|
||||
return NotHandled
|
||||
|
||||
# attribute added by the swagger_serializer_method decorator
|
||||
serializer = getattr(method, "_swagger_serializer", None)
|
||||
|
||||
if serializer:
|
||||
# in order of preference for description, use:
|
||||
# 1) field.help_text from SerializerMethodField(help_text)
|
||||
# 2) serializer.help_text from swagger_serializer_method(serializer)
|
||||
# 3) method's docstring
|
||||
description = field.help_text
|
||||
if description is None:
|
||||
description = getattr(serializer, 'help_text', None)
|
||||
if description is None:
|
||||
description = method.__doc__
|
||||
|
||||
label = field.label
|
||||
if label is None:
|
||||
label = getattr(serializer, 'label', None)
|
||||
|
||||
if inspect.isclass(serializer):
|
||||
serializer_kwargs = {
|
||||
"help_text": description,
|
||||
"label": label,
|
||||
"read_only": True,
|
||||
}
|
||||
|
||||
serializer = method._swagger_serializer(**serializer_kwargs)
|
||||
else:
|
||||
serializer.help_text = description
|
||||
serializer.label = label
|
||||
serializer.read_only = True
|
||||
|
||||
return self.probe_field_inspectors(serializer, swagger_object_type, use_references, read_only=True)
|
||||
else:
|
||||
# look for Python 3.5+ style type hinting of the return value
|
||||
hint_class = inspect_signature(method).return_annotation
|
||||
|
||||
if not inspect.isclass(hint_class) and hasattr(hint_class, '__args__'):
|
||||
hint_class = hint_class.__args__[0]
|
||||
if inspect.isclass(hint_class) and not issubclass(hint_class, inspect._empty):
|
||||
type_info = get_basic_type_info_from_hint(hint_class)
|
||||
|
||||
if type_info is not None:
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type,
|
||||
use_references, **kwargs)
|
||||
return SwaggerType(**type_info)
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
class SimpleFieldInspector(FieldInspector):
|
||||
"""Provides conversions for fields which can be described using just ``type``, ``format``, ``pattern``
|
||||
and min/max validators.
|
||||
"""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
type_info = get_basic_type_info(field)
|
||||
if type_info is None:
|
||||
return NotHandled
|
||||
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
return SwaggerType(**type_info)
|
||||
|
||||
|
||||
class ChoiceFieldInspector(FieldInspector):
|
||||
"""Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``."""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
|
||||
if isinstance(field, serializers.ChoiceField):
|
||||
enum_type = openapi.TYPE_STRING
|
||||
enum_values = []
|
||||
for choice in field.choices.keys():
|
||||
if isinstance(field, serializers.MultipleChoiceField):
|
||||
choice = field_value_to_representation(field, [choice])[0]
|
||||
else:
|
||||
choice = field_value_to_representation(field, choice)
|
||||
|
||||
enum_values.append(choice)
|
||||
|
||||
# for ModelSerializer, try to infer the type from the associated model field
|
||||
serializer = get_parent_serializer(field)
|
||||
if isinstance(serializer, serializers.ModelSerializer):
|
||||
model = getattr(getattr(serializer, 'Meta'), 'model')
|
||||
# Use the parent source for nested fields
|
||||
model_field = get_model_field(model, field.source or field.parent.source)
|
||||
# If the field has a base_field its type must be used
|
||||
if getattr(model_field, "base_field", None):
|
||||
model_field = model_field.base_field
|
||||
if model_field:
|
||||
model_type = get_basic_type_info(model_field)
|
||||
if model_type:
|
||||
enum_type = model_type.get('type', enum_type)
|
||||
else:
|
||||
# Try to infer field type based on enum values
|
||||
enum_value_types = {type(v) for v in enum_values}
|
||||
if len(enum_value_types) == 1:
|
||||
values_type = get_basic_type_info_from_hint(next(iter(enum_value_types)))
|
||||
if values_type:
|
||||
enum_type = values_type.get('type', enum_type)
|
||||
|
||||
if isinstance(field, serializers.MultipleChoiceField):
|
||||
result = SwaggerType(
|
||||
type=openapi.TYPE_ARRAY,
|
||||
items=ChildSwaggerType(
|
||||
type=enum_type,
|
||||
enum=enum_values
|
||||
)
|
||||
)
|
||||
if swagger_object_type == openapi.Parameter:
|
||||
if result['in'] in (openapi.IN_FORM, openapi.IN_QUERY):
|
||||
result.collection_format = 'multi'
|
||||
else:
|
||||
result = SwaggerType(type=enum_type, enum=enum_values)
|
||||
|
||||
return result
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
class FileFieldInspector(FieldInspector):
|
||||
"""Provides conversions for ``FileField``\\ s."""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
|
||||
if isinstance(field, serializers.FileField):
|
||||
# swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
|
||||
# OpenAPI 3.0 does support it, so a future implementation could handle this better
|
||||
err = SwaggerGenerationError("FileField is supported only in a formData Parameter or response Schema")
|
||||
if swagger_object_type == openapi.Schema:
|
||||
# FileField.to_representation returns URL or file name
|
||||
result = SwaggerType(type=openapi.TYPE_STRING, read_only=True)
|
||||
if getattr(field, 'use_url', rest_framework_settings.UPLOADED_FILES_USE_URL):
|
||||
result.format = openapi.FORMAT_URI
|
||||
return result
|
||||
elif swagger_object_type == openapi.Parameter:
|
||||
param = SwaggerType(type=openapi.TYPE_FILE)
|
||||
if param['in'] != openapi.IN_FORM:
|
||||
raise err # pragma: no cover
|
||||
return param
|
||||
else:
|
||||
raise err # pragma: no cover
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
class DictFieldInspector(FieldInspector):
|
||||
"""Provides conversion for ``DictField``."""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
|
||||
if isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
|
||||
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
|
||||
return SwaggerType(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
additional_properties=child_schema
|
||||
)
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
class HiddenFieldInspector(FieldInspector):
|
||||
"""Hide ``HiddenField``."""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
if isinstance(field, serializers.HiddenField):
|
||||
return None
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
class JSONFieldInspector(FieldInspector):
|
||||
"""Provides conversion for ``JSONField``."""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
|
||||
if isinstance(field, serializers.JSONField) and swagger_object_type == openapi.Schema:
|
||||
return SwaggerType(type=openapi.TYPE_OBJECT)
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
class StringDefaultFieldInspector(FieldInspector):
|
||||
"""For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects."""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # pragma: no cover
|
||||
# TODO unhandled fields: TimeField
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
|
||||
return SwaggerType(type=openapi.TYPE_STRING)
|
||||
|
||||
|
||||
try:
|
||||
from djangorestframework_camel_case.parser import CamelCaseJSONParser
|
||||
from djangorestframework_camel_case.render import CamelCaseJSONRenderer, camelize
|
||||
except ImportError: # pragma: no cover
|
||||
CamelCaseJSONParser = CamelCaseJSONRenderer = None
|
||||
|
||||
def camelize(data):
|
||||
return data
|
||||
|
||||
|
||||
class CamelCaseJSONFilter(FieldInspector):
|
||||
"""Converts property names to camelCase if ``djangorestframework_camel_case`` is used."""
|
||||
|
||||
def camelize_string(self, s):
|
||||
"""Hack to force ``djangorestframework_camel_case`` to camelize a plain string.
|
||||
|
||||
:param str s: the string
|
||||
:return: camelized string
|
||||
:rtype: str
|
||||
"""
|
||||
return next(iter(camelize({s: ''})))
|
||||
|
||||
def camelize_schema(self, schema):
|
||||
"""Recursively camelize property names for the given schema using ``djangorestframework_camel_case``.
|
||||
The target schema object must be modified in-place.
|
||||
|
||||
:param openapi.Schema schema: the :class:`.Schema` object
|
||||
"""
|
||||
if getattr(schema, 'properties', {}):
|
||||
schema.properties = OrderedDict(
|
||||
(self.camelize_string(key), self.camelize_schema(openapi.resolve_ref(val, self.components)) or val)
|
||||
for key, val in schema.properties.items()
|
||||
)
|
||||
|
||||
if getattr(schema, 'required', []):
|
||||
schema.required = [self.camelize_string(p) for p in schema.required]
|
||||
|
||||
def process_result(self, result, method_name, obj, **kwargs):
|
||||
if isinstance(result, openapi.Schema.OR_REF) and self.is_camel_case():
|
||||
schema = openapi.resolve_ref(result, self.components)
|
||||
self.camelize_schema(schema)
|
||||
|
||||
return result
|
||||
|
||||
if CamelCaseJSONParser and CamelCaseJSONRenderer:
|
||||
def is_camel_case(self):
|
||||
return (
|
||||
any(issubclass(parser, CamelCaseJSONParser) for parser in self.get_parser_classes()) or
|
||||
any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.get_renderer_classes())
|
||||
)
|
||||
else:
|
||||
def is_camel_case(self):
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
from rest_framework_recursive.fields import RecursiveField
|
||||
except ImportError: # pragma: no cover
|
||||
class RecursiveFieldInspector(FieldInspector):
|
||||
"""Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
|
||||
pass
|
||||
else:
|
||||
class RecursiveFieldInspector(FieldInspector):
|
||||
"""Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
|
||||
|
||||
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
|
||||
if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema:
|
||||
assert use_references is True, "Can not create schema for RecursiveField when use_references is False"
|
||||
|
||||
proxied = field.proxied
|
||||
if isinstance(field.proxied, serializers.ListSerializer):
|
||||
proxied = proxied.child
|
||||
|
||||
ref_name = get_serializer_ref_name(proxied)
|
||||
assert ref_name is not None, "Can't create RecursiveField schema for inline " + str(type(proxied))
|
||||
|
||||
definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
|
||||
|
||||
ref = openapi.SchemaRef(definitions, ref_name, ignore_unresolved=True)
|
||||
if isinstance(field.proxied, serializers.ListSerializer):
|
||||
ref = openapi.Items(type=openapi.TYPE_ARRAY, items=ref)
|
||||
|
||||
return ref
|
||||
|
||||
return NotHandled
|
||||
@@ -0,0 +1,131 @@
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
import coreschema
|
||||
except ImportError:
|
||||
coreschema = None
|
||||
|
||||
from .. import openapi
|
||||
from ..utils import force_real_str
|
||||
from .base import FilterInspector, PaginatorInspector, NotHandled
|
||||
|
||||
|
||||
def ignore_assert_decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except AssertionError:
|
||||
return NotHandled
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class DrfAPICompatInspector(PaginatorInspector, FilterInspector):
|
||||
def param_to_schema(self, param):
|
||||
return openapi.Parameter(
|
||||
name=param['name'],
|
||||
in_=param['in'],
|
||||
description=param.get('description'),
|
||||
required=param.get('required', False),
|
||||
**param['schema'],
|
||||
)
|
||||
|
||||
def get_paginator_parameters(self, paginator):
|
||||
if hasattr(paginator, 'get_schema_operation_parameters'):
|
||||
return list(map(self.param_to_schema, paginator.get_schema_operation_parameters(self.view)))
|
||||
return NotHandled
|
||||
|
||||
def get_filter_parameters(self, filter_backend):
|
||||
if hasattr(filter_backend, 'get_schema_operation_parameters'):
|
||||
return list(map(self.param_to_schema, filter_backend.get_schema_operation_parameters(self.view)))
|
||||
return NotHandled
|
||||
|
||||
|
||||
class CoreAPICompatInspector(PaginatorInspector, FilterInspector):
|
||||
"""Converts ``coreapi.Field``\\ s to :class:`.openapi.Parameter`\\ s for filters and paginators that implement a
|
||||
``get_schema_fields`` method.
|
||||
"""
|
||||
|
||||
@ignore_assert_decorator
|
||||
def get_paginator_parameters(self, paginator):
|
||||
fields = []
|
||||
if hasattr(paginator, 'get_schema_fields'):
|
||||
fields = paginator.get_schema_fields(self.view)
|
||||
|
||||
return [self.coreapi_field_to_parameter(field) for field in fields]
|
||||
|
||||
@ignore_assert_decorator
|
||||
def get_filter_parameters(self, filter_backend):
|
||||
fields = []
|
||||
if hasattr(filter_backend, 'get_schema_fields'):
|
||||
fields = filter_backend.get_schema_fields(self.view)
|
||||
return [self.coreapi_field_to_parameter(field) for field in fields]
|
||||
|
||||
def coreapi_field_to_parameter(self, field):
|
||||
"""Convert an instance of `coreapi.Field` to a swagger :class:`.Parameter` object.
|
||||
|
||||
:param coreapi.Field field:
|
||||
:rtype: openapi.Parameter
|
||||
"""
|
||||
location_to_in = {
|
||||
'query': openapi.IN_QUERY,
|
||||
'path': openapi.IN_PATH,
|
||||
'form': openapi.IN_FORM,
|
||||
'body': openapi.IN_FORM,
|
||||
}
|
||||
coreapi_types = {
|
||||
coreschema.Integer: openapi.TYPE_INTEGER,
|
||||
coreschema.Number: openapi.TYPE_NUMBER,
|
||||
coreschema.String: openapi.TYPE_STRING,
|
||||
coreschema.Boolean: openapi.TYPE_BOOLEAN,
|
||||
}
|
||||
|
||||
coreschema_attrs = ['format', 'pattern', 'enum', 'min_length', 'max_length']
|
||||
schema = field.schema
|
||||
return openapi.Parameter(
|
||||
name=field.name,
|
||||
in_=location_to_in[field.location],
|
||||
required=field.required,
|
||||
description=force_real_str(schema.description) if schema else None,
|
||||
type=coreapi_types.get(type(schema), openapi.TYPE_STRING),
|
||||
**OrderedDict((attr, getattr(schema, attr, None)) for attr in coreschema_attrs)
|
||||
)
|
||||
|
||||
|
||||
class DjangoRestResponsePagination(PaginatorInspector):
|
||||
"""Provides response schema pagination wrapping for django-rest-framework's LimitOffsetPagination,
|
||||
PageNumberPagination and CursorPagination
|
||||
"""
|
||||
|
||||
def fix_paginated_property(self, key: str, value: dict):
|
||||
# Need to remove useless params from schema
|
||||
value.pop('example', None)
|
||||
if 'nullable' in value:
|
||||
value['x-nullable'] = value.pop('nullable')
|
||||
if key in {'next', 'previous'} and 'format' not in value:
|
||||
value['format'] = 'uri'
|
||||
return openapi.Schema(**value)
|
||||
|
||||
def get_paginated_response(self, paginator, response_schema):
|
||||
if hasattr(paginator, 'get_paginated_response_schema'):
|
||||
paginator_schema = paginator.get_paginated_response_schema(response_schema)
|
||||
if paginator_schema['type'] == openapi.TYPE_OBJECT:
|
||||
properties = {
|
||||
k: self.fix_paginated_property(k, v)
|
||||
for k, v in paginator_schema.pop('properties').items()
|
||||
}
|
||||
if 'required' not in paginator_schema:
|
||||
paginator_schema.setdefault('required', [])
|
||||
for prop in ('count', 'results'):
|
||||
if prop in properties:
|
||||
paginator_schema['required'].append(prop)
|
||||
return openapi.Schema(
|
||||
**paginator_schema,
|
||||
properties=properties
|
||||
)
|
||||
else:
|
||||
return openapi.Schema(**paginator_schema)
|
||||
|
||||
return response_schema
|
||||
@@ -0,0 +1,408 @@
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from rest_framework.request import is_form_media_type
|
||||
from rest_framework.schemas import AutoSchema
|
||||
from rest_framework.status import is_success
|
||||
|
||||
from .. import openapi
|
||||
from ..errors import SwaggerGenerationError
|
||||
from ..utils import (
|
||||
filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
|
||||
merge_params, no_body, param_list_to_odict
|
||||
)
|
||||
from .base import ViewInspector, call_view_method
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SwaggerAutoSchema(ViewInspector):
|
||||
def __init__(self, view, path, method, components, request, overrides, operation_keys=None):
|
||||
super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides)
|
||||
self._sch = AutoSchema()
|
||||
self._sch.view = view
|
||||
self.operation_keys = operation_keys
|
||||
|
||||
def get_operation(self, operation_keys=None):
|
||||
operation_keys = operation_keys or self.operation_keys
|
||||
|
||||
consumes = self.get_consumes()
|
||||
produces = self.get_produces()
|
||||
|
||||
body = self.get_request_body_parameters(consumes)
|
||||
query = self.get_query_parameters()
|
||||
parameters = body + query
|
||||
parameters = filter_none(parameters)
|
||||
parameters = self.add_manual_parameters(parameters)
|
||||
|
||||
operation_id = self.get_operation_id(operation_keys)
|
||||
summary, description = self.get_summary_and_description()
|
||||
security = self.get_security()
|
||||
assert security is None or isinstance(security, list), "security must be a list of security requirement objects"
|
||||
deprecated = self.is_deprecated()
|
||||
tags = self.get_tags(operation_keys)
|
||||
|
||||
responses = self.get_responses()
|
||||
|
||||
return openapi.Operation(
|
||||
operation_id=operation_id,
|
||||
description=force_real_str(description),
|
||||
summary=force_real_str(summary),
|
||||
responses=responses,
|
||||
parameters=parameters,
|
||||
consumes=consumes,
|
||||
produces=produces,
|
||||
tags=tags,
|
||||
security=security,
|
||||
deprecated=deprecated
|
||||
)
|
||||
|
||||
def get_request_body_parameters(self, consumes):
|
||||
"""Return the request body parameters for this view. |br|
|
||||
This is either:
|
||||
|
||||
- a list with a single object Parameter with a :class:`.Schema` derived from the request serializer
|
||||
- a list of primitive Parameters parsed as form data
|
||||
|
||||
:param list[str] consumes: a list of accepted MIME types as returned by :meth:`.get_consumes`
|
||||
:return: a (potentially empty) list of :class:`.Parameter`\\ s either ``in: body`` or ``in: formData``
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
serializer = self.get_request_serializer()
|
||||
schema = None
|
||||
if serializer is None:
|
||||
return []
|
||||
|
||||
if isinstance(serializer, openapi.Schema.OR_REF):
|
||||
schema = serializer
|
||||
|
||||
if any(is_form_media_type(encoding) for encoding in consumes):
|
||||
if schema is not None:
|
||||
raise SwaggerGenerationError("form request body cannot be a Schema")
|
||||
return self.get_request_form_parameters(serializer)
|
||||
else:
|
||||
if schema is None:
|
||||
schema = self.get_request_body_schema(serializer)
|
||||
return [self.make_body_parameter(schema)] if schema is not None else []
|
||||
|
||||
def get_view_serializer(self):
|
||||
"""Return the serializer as defined by the view's ``get_serializer()`` method.
|
||||
|
||||
:return: the view's ``Serializer``
|
||||
:rtype: rest_framework.serializers.Serializer
|
||||
"""
|
||||
return call_view_method(self.view, 'get_serializer')
|
||||
|
||||
def _get_request_body_override(self):
|
||||
"""Parse the request_body key in the override dict. This method is not public API."""
|
||||
body_override = self.overrides.get('request_body', None)
|
||||
|
||||
if body_override is not None:
|
||||
if body_override is no_body:
|
||||
return no_body
|
||||
if self.method not in self.body_methods:
|
||||
raise SwaggerGenerationError("request_body can only be applied to (" + ','.join(self.body_methods) +
|
||||
"); are you looking for query_serializer or manual_parameters?")
|
||||
if isinstance(body_override, openapi.Schema.OR_REF):
|
||||
return body_override
|
||||
return force_serializer_instance(body_override)
|
||||
|
||||
return body_override
|
||||
|
||||
def get_request_serializer(self):
|
||||
"""Return the request serializer (used for parsing the request payload) for this endpoint.
|
||||
|
||||
:return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None``
|
||||
:rtype: rest_framework.serializers.Serializer
|
||||
"""
|
||||
body_override = self._get_request_body_override()
|
||||
|
||||
if body_override is None and self.method in self.implicit_body_methods:
|
||||
return self.get_view_serializer()
|
||||
|
||||
if body_override is no_body:
|
||||
return None
|
||||
|
||||
return body_override
|
||||
|
||||
def get_request_form_parameters(self, serializer):
|
||||
"""Given a Serializer, return a list of ``in: formData`` :class:`.Parameter`\\ s.
|
||||
|
||||
:param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
return self.serializer_to_parameters(serializer, in_=openapi.IN_FORM)
|
||||
|
||||
def get_request_body_schema(self, serializer):
|
||||
"""Return the :class:`.Schema` for a given request's body data. Only applies to PUT, PATCH and POST requests.
|
||||
|
||||
:param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
|
||||
:rtype: openapi.Schema
|
||||
"""
|
||||
return self.serializer_to_schema(serializer)
|
||||
|
||||
def make_body_parameter(self, schema):
|
||||
"""Given a :class:`.Schema` object, create an ``in: body`` :class:`.Parameter`.
|
||||
|
||||
:param openapi.Schema schema: the request body schema
|
||||
:rtype: openapi.Parameter
|
||||
"""
|
||||
return openapi.Parameter(name='data', in_=openapi.IN_BODY, required=True, schema=schema)
|
||||
|
||||
def add_manual_parameters(self, parameters):
|
||||
"""Add/replace parameters from the given list of automatically generated request parameters.
|
||||
|
||||
:param list[openapi.Parameter] parameters: generated parameters
|
||||
:return: modified parameters
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
manual_parameters = self.overrides.get('manual_parameters', None) or []
|
||||
|
||||
if any(param.in_ == openapi.IN_BODY for param in manual_parameters): # pragma: no cover
|
||||
raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
|
||||
if any(param.in_ == openapi.IN_FORM for param in manual_parameters): # pragma: no cover
|
||||
has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)
|
||||
if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes()):
|
||||
raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
|
||||
"did you forget to set an appropriate parser class on the view?")
|
||||
if self.method not in self.body_methods:
|
||||
raise SwaggerGenerationError("form parameters can only be applied to "
|
||||
"(" + ','.join(self.body_methods) + ") HTTP methods")
|
||||
|
||||
return merge_params(parameters, manual_parameters)
|
||||
|
||||
def get_responses(self):
|
||||
"""Get the possible responses for this view as a swagger :class:`.Responses` object.
|
||||
|
||||
:return: the documented responses
|
||||
:rtype: openapi.Responses
|
||||
"""
|
||||
response_serializers = self.get_response_serializers()
|
||||
return openapi.Responses(
|
||||
responses=self.get_response_schemas(response_serializers)
|
||||
)
|
||||
|
||||
def get_default_response_serializer(self):
|
||||
"""Return the default response serializer for this endpoint. This is derived from either the ``request_body``
|
||||
override or the request serializer (:meth:`.get_view_serializer`).
|
||||
|
||||
:return: response serializer, :class:`.Schema`, :class:`.SchemaRef`, ``None``
|
||||
"""
|
||||
body_override = self._get_request_body_override()
|
||||
if body_override and body_override is not no_body:
|
||||
return body_override
|
||||
|
||||
return self.get_view_serializer()
|
||||
|
||||
def get_default_responses(self):
|
||||
"""Get the default responses determined for this view from the request serializer and request method.
|
||||
|
||||
:type: dict[str, openapi.Schema]
|
||||
"""
|
||||
method = self.method.lower()
|
||||
|
||||
default_status = guess_response_status(method)
|
||||
default_schema = ''
|
||||
if method in ('get', 'post', 'put', 'patch'):
|
||||
default_schema = self.get_default_response_serializer()
|
||||
|
||||
default_schema = default_schema or ''
|
||||
if default_schema and not isinstance(default_schema, openapi.Schema):
|
||||
default_schema = self.serializer_to_schema(default_schema) or ''
|
||||
|
||||
if default_schema:
|
||||
if self.has_list_response():
|
||||
default_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=default_schema)
|
||||
if self.should_page():
|
||||
default_schema = self.get_paginated_response(default_schema) or default_schema
|
||||
|
||||
return OrderedDict({str(default_status): default_schema})
|
||||
|
||||
def get_response_serializers(self):
|
||||
"""Return the response codes that this view is expected to return, and the serializer for each response body.
|
||||
The return value should be a dict where the keys are possible status codes, and values are either strings,
|
||||
``Serializer``\\ s, :class:`.Schema`, :class:`.SchemaRef` or :class:`.Response` objects. See
|
||||
:func:`@swagger_auto_schema <.swagger_auto_schema>` for more details.
|
||||
|
||||
:return: the response serializers
|
||||
:rtype: dict
|
||||
"""
|
||||
manual_responses = self.overrides.get('responses', None) or {}
|
||||
manual_responses = OrderedDict((str(sc), resp) for sc, resp in manual_responses.items())
|
||||
|
||||
responses = OrderedDict()
|
||||
if not any(is_success(int(sc)) for sc in manual_responses if sc != 'default'):
|
||||
responses = self.get_default_responses()
|
||||
|
||||
responses.update((str(sc), resp) for sc, resp in manual_responses.items())
|
||||
return responses
|
||||
|
||||
def get_response_schemas(self, response_serializers):
|
||||
"""Return the :class:`.openapi.Response` objects calculated for this view.
|
||||
|
||||
:param dict response_serializers: response serializers as returned by :meth:`.get_response_serializers`
|
||||
:return: a dictionary of status code to :class:`.Response` object
|
||||
:rtype: dict[str, openapi.Response]
|
||||
"""
|
||||
responses = OrderedDict()
|
||||
for sc, serializer in response_serializers.items():
|
||||
if isinstance(serializer, str):
|
||||
response = openapi.Response(
|
||||
description=force_real_str(serializer)
|
||||
)
|
||||
elif not serializer:
|
||||
continue
|
||||
elif isinstance(serializer, openapi.Response):
|
||||
response = serializer
|
||||
if hasattr(response, 'schema') and not isinstance(response.schema, openapi.Schema.OR_REF):
|
||||
serializer = force_serializer_instance(response.schema)
|
||||
response.schema = self.serializer_to_schema(serializer)
|
||||
elif isinstance(serializer, openapi.Schema.OR_REF):
|
||||
response = openapi.Response(
|
||||
description='',
|
||||
schema=serializer,
|
||||
)
|
||||
elif isinstance(serializer, openapi._Ref):
|
||||
response = serializer
|
||||
else:
|
||||
serializer = force_serializer_instance(serializer)
|
||||
response = openapi.Response(
|
||||
description='',
|
||||
schema=self.serializer_to_schema(serializer),
|
||||
)
|
||||
|
||||
responses[str(sc)] = response
|
||||
|
||||
return responses
|
||||
|
||||
def get_query_serializer(self):
|
||||
"""Return the query serializer (used for parsing query parameters) for this endpoint.
|
||||
|
||||
:return: the query serializer, or ``None``
|
||||
"""
|
||||
query_serializer = self.overrides.get('query_serializer', None)
|
||||
if query_serializer is not None:
|
||||
query_serializer = force_serializer_instance(query_serializer)
|
||||
return query_serializer
|
||||
|
||||
def get_query_parameters(self):
|
||||
"""Return the query parameters accepted by this view.
|
||||
|
||||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
natural_parameters = self.get_filter_parameters() + self.get_pagination_parameters()
|
||||
|
||||
query_serializer = self.get_query_serializer()
|
||||
serializer_parameters = []
|
||||
if query_serializer is not None:
|
||||
serializer_parameters = self.serializer_to_parameters(query_serializer, in_=openapi.IN_QUERY)
|
||||
|
||||
if len(set(param_list_to_odict(natural_parameters)) & set(param_list_to_odict(serializer_parameters))) != 0:
|
||||
raise SwaggerGenerationError(
|
||||
"your query_serializer contains fields that conflict with the "
|
||||
"filter_backend or paginator_class on the view - %s %s" % (self.method, self.path)
|
||||
)
|
||||
|
||||
return natural_parameters + serializer_parameters
|
||||
|
||||
def get_operation_id(self, operation_keys=None):
|
||||
"""Return an unique ID for this operation. The ID must be unique across
|
||||
all :class:`.Operation` objects in the API.
|
||||
|
||||
:param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
|
||||
of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
|
||||
:rtype: str
|
||||
"""
|
||||
operation_keys = operation_keys or self.operation_keys
|
||||
|
||||
operation_id = self.overrides.get('operation_id', '')
|
||||
if not operation_id:
|
||||
operation_id = '_'.join(operation_keys)
|
||||
return operation_id
|
||||
|
||||
def split_summary_from_description(self, description):
|
||||
"""Decide if and how to split a summary out of the given description. The default implementation
|
||||
uses the first paragraph of the description as a summary if it is less than 120 characters long.
|
||||
|
||||
:param description: the full description to be analyzed
|
||||
:return: summary and description
|
||||
:rtype: (str,str)
|
||||
"""
|
||||
# https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
|
||||
summary = None
|
||||
summary_max_len = 120 # OpenAPI 2.0 spec says summary should be under 120 characters
|
||||
sections = description.split('\n\n', 1)
|
||||
if len(sections) == 2:
|
||||
sections[0] = sections[0].strip()
|
||||
if len(sections[0]) < summary_max_len:
|
||||
summary, description = sections
|
||||
description = description.strip()
|
||||
|
||||
return summary, description
|
||||
|
||||
def get_summary_and_description(self):
|
||||
"""Return an operation summary and description determined from the view's docstring.
|
||||
|
||||
:return: summary and description
|
||||
:rtype: (str,str)
|
||||
"""
|
||||
description = self.overrides.get('operation_description', None)
|
||||
summary = self.overrides.get('operation_summary', None)
|
||||
if description is None:
|
||||
description = self._sch.get_description(self.path, self.method) or ''
|
||||
description = description.strip().replace('\r', '')
|
||||
|
||||
if description and (summary is None):
|
||||
# description from docstring... do summary magic
|
||||
summary, description = self.split_summary_from_description(description)
|
||||
|
||||
return summary, description
|
||||
|
||||
def get_security(self):
|
||||
"""Return a list of security requirements for this operation.
|
||||
|
||||
Returning an empty list marks the endpoint as unauthenticated (i.e. removes all accepted
|
||||
authentication schemes). Returning ``None`` will inherit the top-level security requirements.
|
||||
|
||||
:return: security requirements
|
||||
:rtype: list[dict[str,list[str]]]"""
|
||||
return self.overrides.get('security', None)
|
||||
|
||||
def is_deprecated(self):
|
||||
"""Return ``True`` if this operation is to be marked as deprecated.
|
||||
|
||||
:return: deprecation status
|
||||
:rtype: bool
|
||||
"""
|
||||
return self.overrides.get('deprecated', None)
|
||||
|
||||
def get_tags(self, operation_keys=None):
|
||||
"""Get a list of tags for this operation. Tags determine how operations relate with each other, and in the UI
|
||||
each tag will show as a group containing the operations that use it. If not provided in overrides,
|
||||
tags will be inferred from the operation url.
|
||||
|
||||
:param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
|
||||
of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
|
||||
:rtype: list[str]
|
||||
"""
|
||||
operation_keys = operation_keys or self.operation_keys
|
||||
|
||||
tags = self.overrides.get('tags')
|
||||
if not tags:
|
||||
tags = [operation_keys[0]]
|
||||
|
||||
return tags
|
||||
|
||||
def get_consumes(self):
|
||||
"""Return the MIME types this endpoint can consume.
|
||||
|
||||
:rtype: list[str]
|
||||
"""
|
||||
return get_consumes(self.get_parser_classes())
|
||||
|
||||
def get_produces(self):
|
||||
"""Return the MIME types this endpoint can produce.
|
||||
|
||||
:rtype: list[str]
|
||||
"""
|
||||
return get_produces(self.get_renderer_classes())
|
||||
Reference in New Issue
Block a user