update
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -6,9 +6,7 @@ import uuid
|
||||
from contextlib import suppress
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
from inspect import signature as inspect_signature
|
||||
|
||||
import pkg_resources
|
||||
import typing
|
||||
from django.core import validators
|
||||
from django.db import models
|
||||
@@ -23,7 +21,20 @@ from ..utils import (
|
||||
decimal_as_float, field_value_to_representation, filter_none, get_serializer_class, get_serializer_ref_name
|
||||
)
|
||||
|
||||
drf_version = pkg_resources.get_distribution("djangorestframework").version
|
||||
try:
|
||||
from importlib import metadata
|
||||
drf_version = metadata.version("djangorestframework")
|
||||
except ImportError: # Python < 3.8
|
||||
import pkg_resources
|
||||
drf_version = pkg_resources.get_distribution("djangorestframework").version
|
||||
|
||||
try:
|
||||
from types import NoneType, UnionType
|
||||
|
||||
UNION_TYPES = (typing.Union, UnionType)
|
||||
except ImportError: # Python < 3.10
|
||||
NoneType = type(None)
|
||||
UNION_TYPES = (typing.Union,)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -189,7 +200,7 @@ def get_queryset_from_view(view, serializer=None):
|
||||
|
||||
if queryset is not None and serializer is not None:
|
||||
# make sure the view is actually using *this* serializer
|
||||
assert type(serializer) == call_view_method(view, 'get_serializer_class', 'serializer_class')
|
||||
assert type(serializer) is call_view_method(view, 'get_serializer_class', 'serializer_class')
|
||||
|
||||
return queryset
|
||||
except Exception: # pragma: no cover
|
||||
@@ -476,15 +487,6 @@ def decimal_return_type():
|
||||
return openapi.TYPE_STRING if rest_framework_settings.COERCE_DECIMAL_TO_STRING else openapi.TYPE_NUMBER
|
||||
|
||||
|
||||
def get_origin_type(hint_class):
|
||||
return getattr(hint_class, '__origin__', None) or hint_class
|
||||
|
||||
|
||||
def hint_class_issubclass(hint_class, check_class):
|
||||
origin_type = get_origin_type(hint_class)
|
||||
return inspect.isclass(origin_type) and issubclass(origin_type, check_class)
|
||||
|
||||
|
||||
hinting_type_info = [
|
||||
(bool, (openapi.TYPE_BOOLEAN, None)),
|
||||
(int, (openapi.TYPE_INTEGER, None)),
|
||||
@@ -501,11 +503,15 @@ hinting_type_info = [
|
||||
if hasattr(typing, 'get_args'):
|
||||
# python >=3.8
|
||||
typing_get_args = typing.get_args
|
||||
typing_get_origin = typing.get_origin
|
||||
else:
|
||||
# python <3.8
|
||||
def typing_get_args(tp):
|
||||
return getattr(tp, '__args__', ())
|
||||
|
||||
def typing_get_origin(tp):
|
||||
return getattr(tp, '__origin__', None)
|
||||
|
||||
|
||||
def inspect_collection_hint_class(hint_class):
|
||||
args = typing_get_args(hint_class)
|
||||
@@ -521,12 +527,6 @@ def inspect_collection_hint_class(hint_class):
|
||||
hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
|
||||
|
||||
|
||||
def _get_union_types(hint_class):
|
||||
origin_type = get_origin_type(hint_class)
|
||||
if origin_type is typing.Union:
|
||||
return hint_class.__args__
|
||||
|
||||
|
||||
def get_basic_type_info_from_hint(hint_class):
|
||||
"""Given a class (eg from a SerializerMethodField's return type hint,
|
||||
return its basic type information - ``type``, ``format``, ``pattern``,
|
||||
@@ -536,12 +536,12 @@ def get_basic_type_info_from_hint(hint_class):
|
||||
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
|
||||
:rtype: OrderedDict
|
||||
"""
|
||||
union_types = _get_union_types(hint_class)
|
||||
|
||||
if union_types:
|
||||
if typing_get_origin(hint_class) in UNION_TYPES:
|
||||
# Optional is implemented as Union[T, None]
|
||||
if len(union_types) == 2 and isinstance(None, union_types[1]):
|
||||
result = get_basic_type_info_from_hint(union_types[0])
|
||||
filtered_types = [t for t in typing_get_args(hint_class) if t is not NoneType]
|
||||
if len(filtered_types) == 1:
|
||||
result = get_basic_type_info_from_hint(filtered_types[0])
|
||||
if result:
|
||||
result['x-nullable'] = True
|
||||
|
||||
@@ -549,8 +549,15 @@ def get_basic_type_info_from_hint(hint_class):
|
||||
|
||||
return None
|
||||
|
||||
# resolve the origin class if the class is generic
|
||||
resolved_class = typing_get_origin(hint_class) or hint_class
|
||||
|
||||
# bail out early
|
||||
if not inspect.isclass(resolved_class):
|
||||
return None
|
||||
|
||||
for check_class, info in hinting_type_info:
|
||||
if hint_class_issubclass(hint_class, check_class):
|
||||
if issubclass(resolved_class, check_class):
|
||||
if callable(info):
|
||||
return info(hint_class)
|
||||
|
||||
@@ -613,17 +620,19 @@ class SerializerMethodFieldInspector(FieldInspector):
|
||||
return self.probe_field_inspectors(serializer, swagger_object_type, use_references, read_only=True)
|
||||
else:
|
||||
# look for Python 3.5+ style type hinting of the return value
|
||||
hint_class = inspect_signature(method).return_annotation
|
||||
hint_class = typing.get_type_hints(method).get('return')
|
||||
|
||||
if not inspect.isclass(hint_class) and hasattr(hint_class, '__args__'):
|
||||
hint_class = hint_class.__args__[0]
|
||||
if inspect.isclass(hint_class) and not issubclass(hint_class, inspect._empty):
|
||||
type_info = get_basic_type_info_from_hint(hint_class)
|
||||
# annotations such as typing.Optional have an __instancecheck__
|
||||
# hook and will not look like classes, but `issubclass` needs
|
||||
# a class as its first argument, so only in that case abort
|
||||
if inspect.isclass(hint_class) and issubclass(hint_class, inspect._empty):
|
||||
return NotHandled
|
||||
|
||||
if type_info is not None:
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type,
|
||||
use_references, **kwargs)
|
||||
return SwaggerType(**type_info)
|
||||
type_info = get_basic_type_info_from_hint(hint_class)
|
||||
if type_info is not None:
|
||||
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type,
|
||||
use_references, **kwargs)
|
||||
return SwaggerType(**type_info)
|
||||
|
||||
return NotHandled
|
||||
|
||||
|
||||
Reference in New Issue
Block a user