updates
This commit is contained in:
@@ -1,11 +1,17 @@
|
||||
import typing
|
||||
from __future__ import annotations
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from starlette._exception_handler import (
|
||||
ExceptionHandlers,
|
||||
StatusHandlers,
|
||||
wrap_app_handling_exceptions,
|
||||
)
|
||||
from starlette.exceptions import HTTPException, WebSocketException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.types import ASGIApp, ExceptionHandler, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
@@ -13,28 +19,24 @@ class ExceptionMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
handlers: typing.Optional[
|
||||
typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
|
||||
] = None,
|
||||
handlers: Mapping[Any, ExceptionHandler] | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
|
||||
self._status_handlers: typing.Dict[int, typing.Callable] = {}
|
||||
self._exception_handlers: typing.Dict[
|
||||
typing.Type[Exception], typing.Callable
|
||||
] = {
|
||||
self._status_handlers: StatusHandlers = {}
|
||||
self._exception_handlers: ExceptionHandlers = {
|
||||
HTTPException: self.http_exception,
|
||||
WebSocketException: self.websocket_exception,
|
||||
}
|
||||
if handlers is not None:
|
||||
if handlers is not None: # pragma: no branch
|
||||
for key, value in handlers.items():
|
||||
self.add_exception_handler(key, value)
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
|
||||
handler: typing.Callable[[Request, Exception], Response],
|
||||
exc_class_or_status_code: int | type[Exception],
|
||||
handler: ExceptionHandler,
|
||||
) -> None:
|
||||
if isinstance(exc_class_or_status_code, int):
|
||||
self._status_handlers[exc_class_or_status_code] = handler
|
||||
@@ -42,68 +44,30 @@ class ExceptionMiddleware:
|
||||
assert issubclass(exc_class_or_status_code, Exception)
|
||||
self._exception_handlers[exc_class_or_status_code] = handler
|
||||
|
||||
def _lookup_exception_handler(
|
||||
self, exc: Exception
|
||||
) -> typing.Optional[typing.Callable]:
|
||||
for cls in type(exc).__mro__:
|
||||
if cls in self._exception_handlers:
|
||||
return self._exception_handlers[cls]
|
||||
return None
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response_started = False
|
||||
|
||||
async def sender(message: Message) -> None:
|
||||
nonlocal response_started
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_started = True
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, sender)
|
||||
except Exception as exc:
|
||||
handler = None
|
||||
|
||||
if isinstance(exc, HTTPException):
|
||||
handler = self._status_handlers.get(exc.status_code)
|
||||
|
||||
if handler is None:
|
||||
handler = self._lookup_exception_handler(exc)
|
||||
|
||||
if handler is None:
|
||||
raise exc
|
||||
|
||||
if response_started:
|
||||
msg = "Caught handled exception, but response already started."
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
if scope["type"] == "http":
|
||||
request = Request(scope, receive=receive)
|
||||
if is_async_callable(handler):
|
||||
response = await handler(request, exc)
|
||||
else:
|
||||
response = await run_in_threadpool(handler, request, exc)
|
||||
await response(scope, receive, sender)
|
||||
elif scope["type"] == "websocket":
|
||||
websocket = WebSocket(scope, receive=receive, send=send)
|
||||
if is_async_callable(handler):
|
||||
await handler(websocket, exc)
|
||||
else:
|
||||
await run_in_threadpool(handler, websocket, exc)
|
||||
|
||||
def http_exception(self, request: Request, exc: HTTPException) -> Response:
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(
|
||||
exc.detail, status_code=exc.status_code, headers=exc.headers
|
||||
scope["starlette.exception_handlers"] = (
|
||||
self._exception_handlers,
|
||||
self._status_handlers,
|
||||
)
|
||||
|
||||
async def websocket_exception(
|
||||
self, websocket: WebSocket, exc: WebSocketException
|
||||
) -> None:
|
||||
await websocket.close(code=exc.code, reason=exc.reason)
|
||||
conn: Request | WebSocket
|
||||
if scope["type"] == "http":
|
||||
conn = Request(scope, receive, send)
|
||||
else:
|
||||
conn = WebSocket(scope, receive, send)
|
||||
|
||||
await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
|
||||
|
||||
async def http_exception(self, request: Request, exc: Exception) -> Response:
|
||||
assert isinstance(exc, HTTPException)
|
||||
if exc.status_code in {204, 304}:
|
||||
return Response(status_code=exc.status_code, headers=exc.headers)
|
||||
return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
|
||||
|
||||
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
|
||||
assert isinstance(exc, WebSocketException)
|
||||
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
|
||||
|
||||
Reference in New Issue
Block a user