updates
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from starlette import status
|
||||
from starlette._utils import is_async_callable
|
||||
@@ -23,20 +26,14 @@ class HTTPEndpoint:
|
||||
if getattr(self, method.lower(), None) is not None
|
||||
]
|
||||
|
||||
def __await__(self) -> typing.Generator:
|
||||
def __await__(self) -> Generator[Any, None, None]:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
request = Request(self.scope, receive=self.receive)
|
||||
handler_name = (
|
||||
"get"
|
||||
if request.method == "HEAD" and not hasattr(self, "head")
|
||||
else request.method.lower()
|
||||
)
|
||||
handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
|
||||
|
||||
handler: typing.Callable[[Request], typing.Any] = getattr(
|
||||
self, handler_name, self.method_not_allowed
|
||||
)
|
||||
handler: Callable[[Request], Any] = getattr(self, handler_name, self.method_not_allowed)
|
||||
is_async = is_async_callable(handler)
|
||||
if is_async:
|
||||
response = await handler(request)
|
||||
@@ -55,7 +52,7 @@ class HTTPEndpoint:
|
||||
|
||||
|
||||
class WebSocketEndpoint:
|
||||
encoding: typing.Optional[str] = None # May be "text", "bytes", or "json".
|
||||
encoding: Literal["text", "bytes", "json"] | None = None
|
||||
|
||||
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
assert scope["type"] == "websocket"
|
||||
@@ -63,7 +60,7 @@ class WebSocketEndpoint:
|
||||
self.receive = receive
|
||||
self.send = send
|
||||
|
||||
def __await__(self) -> typing.Generator:
|
||||
def __await__(self) -> Generator[Any, None, None]:
|
||||
return self.dispatch().__await__()
|
||||
|
||||
async def dispatch(self) -> None:
|
||||
@@ -78,10 +75,8 @@ class WebSocketEndpoint:
|
||||
if message["type"] == "websocket.receive":
|
||||
data = await self.decode(websocket, message)
|
||||
await self.on_receive(websocket, data)
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
close_code = int(
|
||||
message.get("code") or status.WS_1000_NORMAL_CLOSURE
|
||||
)
|
||||
elif message["type"] == "websocket.disconnect": # pragma: no branch
|
||||
close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
|
||||
break
|
||||
except Exception as exc:
|
||||
close_code = status.WS_1011_INTERNAL_ERROR
|
||||
@@ -89,7 +84,7 @@ class WebSocketEndpoint:
|
||||
finally:
|
||||
await self.on_disconnect(websocket, close_code)
|
||||
|
||||
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
|
||||
async def decode(self, websocket: WebSocket, message: Message) -> Any:
|
||||
if self.encoding == "text":
|
||||
if "text" not in message:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
@@ -114,16 +109,14 @@ class WebSocketEndpoint:
|
||||
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
raise RuntimeError("Malformed JSON data received.")
|
||||
|
||||
assert (
|
||||
self.encoding is None
|
||||
), f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
return message["text"] if message.get("text") else message["bytes"]
|
||||
|
||||
async def on_connect(self, websocket: WebSocket) -> None:
|
||||
"""Override to handle an incoming websocket connection"""
|
||||
await websocket.accept()
|
||||
|
||||
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
|
||||
async def on_receive(self, websocket: WebSocket, data: Any) -> None:
|
||||
"""Override to handle an incoming websocket message"""
|
||||
|
||||
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
|
||||
|
||||
Reference in New Issue
Block a user