This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1,41 +1,43 @@
from __future__ import annotations
import inspect
import re
import typing
from collections.abc import Callable
from typing import Any, NamedTuple
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Mount, Route
from starlette.routing import BaseRoute, Host, Mount, Route
try:
import yaml
except ModuleNotFoundError: # pragma: nocover
except ModuleNotFoundError: # pragma: no cover
yaml = None # type: ignore[assignment]
class OpenAPIResponse(Response):
media_type = "application/vnd.oai.openapi"
def render(self, content: typing.Any) -> bytes:
def render(self, content: Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(
content, dict
), "The schema passed to OpenAPIResponse should be a dictionary."
assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
class EndpointInfo(typing.NamedTuple):
class EndpointInfo(NamedTuple):
path: str
http_method: str
func: typing.Callable
func: Callable[..., Any]
_remove_converter_pattern = re.compile(r":\w+}")
class BaseSchemaGenerator:
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(
self, routes: typing.List[BaseRoute]
) -> typing.List[EndpointInfo]:
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
"""
Given the routes, yields the following information:
@@ -46,12 +48,15 @@ class BaseSchemaGenerator:
- func
method ready to extract the docstring
"""
endpoints_info: list = []
endpoints_info: list[EndpointInfo] = []
for route in routes:
if isinstance(route, Mount):
path = self._remove_converter(route.path)
if isinstance(route, Mount | Host):
routes = route.routes or []
if isinstance(route, Mount):
path = self._remove_converter(route.path)
else:
path = ""
sub_endpoints = [
EndpointInfo(
path="".join((path, sub_endpoint.path)),
@@ -70,9 +75,7 @@ class BaseSchemaGenerator:
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(
EndpointInfo(path, method.lower(), route.endpoint)
)
endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
@@ -90,9 +93,9 @@ class BaseSchemaGenerator:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return re.sub(r":\w+}", "}", path)
return _remove_converter_pattern.sub("}", path)
def parse_docstring(self, func_or_method: typing.Callable) -> dict:
def parse_docstring(self, func_or_method: Callable[..., Any]) -> dict[str, Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
@@ -123,10 +126,10 @@ class BaseSchemaGenerator:
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, base_schema: dict) -> None:
def __init__(self, base_schema: dict[str, Any]) -> None:
self.base_schema = base_schema
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)