updates
This commit is contained in:
@@ -1,41 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import BaseRoute, Mount, Route
|
||||
from starlette.routing import BaseRoute, Host, Mount, Route
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ModuleNotFoundError: # pragma: nocover
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
yaml = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class OpenAPIResponse(Response):
|
||||
media_type = "application/vnd.oai.openapi"
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
def render(self, content: Any) -> bytes:
|
||||
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
|
||||
assert isinstance(
|
||||
content, dict
|
||||
), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
|
||||
return yaml.dump(content, default_flow_style=False).encode("utf-8")
|
||||
|
||||
|
||||
class EndpointInfo(typing.NamedTuple):
|
||||
class EndpointInfo(NamedTuple):
|
||||
path: str
|
||||
http_method: str
|
||||
func: typing.Callable
|
||||
func: Callable[..., Any]
|
||||
|
||||
|
||||
_remove_converter_pattern = re.compile(r":\w+}")
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def get_endpoints(
|
||||
self, routes: typing.List[BaseRoute]
|
||||
) -> typing.List[EndpointInfo]:
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
"""
|
||||
Given the routes, yields the following information:
|
||||
|
||||
@@ -46,12 +48,15 @@ class BaseSchemaGenerator:
|
||||
- func
|
||||
method ready to extract the docstring
|
||||
"""
|
||||
endpoints_info: list = []
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
|
||||
for route in routes:
|
||||
if isinstance(route, Mount):
|
||||
path = self._remove_converter(route.path)
|
||||
if isinstance(route, Mount | Host):
|
||||
routes = route.routes or []
|
||||
if isinstance(route, Mount):
|
||||
path = self._remove_converter(route.path)
|
||||
else:
|
||||
path = ""
|
||||
sub_endpoints = [
|
||||
EndpointInfo(
|
||||
path="".join((path, sub_endpoint.path)),
|
||||
@@ -70,9 +75,7 @@ class BaseSchemaGenerator:
|
||||
for method in route.methods or ["GET"]:
|
||||
if method == "HEAD":
|
||||
continue
|
||||
endpoints_info.append(
|
||||
EndpointInfo(path, method.lower(), route.endpoint)
|
||||
)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ["get", "post", "put", "patch", "delete", "options"]:
|
||||
@@ -90,9 +93,9 @@ class BaseSchemaGenerator:
|
||||
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
|
||||
Should be represented as `/users/{id}` in the OpenAPI schema.
|
||||
"""
|
||||
return re.sub(r":\w+}", "}", path)
|
||||
return _remove_converter_pattern.sub("}", path)
|
||||
|
||||
def parse_docstring(self, func_or_method: typing.Callable) -> dict:
|
||||
def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Given a function, parse the docstring as YAML and return a dictionary of info.
|
||||
"""
|
||||
@@ -123,10 +126,10 @@ class BaseSchemaGenerator:
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
def __init__(self, base_schema: dict) -> None:
|
||||
def __init__(self, base_schema: dict[str, Any]) -> None:
|
||||
self.base_schema = base_schema
|
||||
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault("paths", {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
|
||||
Reference in New Issue
Block a user