This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1,8 +1,12 @@
from __future__ import annotations
import enum
import json
import typing
from collections.abc import AsyncIterator, Iterable
from typing import Any, cast
from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
@@ -10,10 +14,11 @@ class WebSocketState(enum.Enum):
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
RESPONSE = 3
class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
@@ -35,10 +40,7 @@ class WebSocket(HTTPConnection):
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
raise RuntimeError(
'Expected ASGI message "websocket.connect", '
f"but got {message_type!r}"
)
raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
@@ -46,16 +48,13 @@ class WebSocket(HTTPConnection):
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
'Expected ASGI message "websocket.receive" or '
f'"websocket.disconnect", but got {message_type!r}'
f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
else:
raise RuntimeError(
'Cannot call "receive" once a disconnect message has been received.'
)
raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
async def send(self, message: Message) -> None:
"""
@@ -63,13 +62,15 @@ class WebSocket(HTTPConnection):
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
if message_type not in {"websocket.accept", "websocket.close"}:
if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
raise RuntimeError(
'Expected ASGI message "websocket.accept" or '
f'"websocket.close", but got {message_type!r}'
'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
@@ -77,58 +78,60 @@ class WebSocket(HTTPConnection):
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
'Expected ASGI message "websocket.send" or "websocket.close", '
f"but got {message_type!r}"
f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
try:
await self._send(message)
except OSError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
async def accept(
self,
subprotocol: typing.Optional[str] = None,
headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
subprotocol: str | None = None,
headers: Iterable[tuple[bytes, bytes]] | None = None,
) -> None:
headers = headers or []
if self.client_state == WebSocketState.CONNECTING:
if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
await self.send(
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
)
await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
raise WebSocketDisconnect(message["code"])
raise WebSocketDisconnect(message["code"], message.get("reason"))
async def receive_text(self) -> str:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return message["text"]
return cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return message["bytes"]
return cast(bytes, message["bytes"])
async def receive_json(self, mode: str = "text") -> typing.Any:
async def receive_json(self, mode: str = "text") -> Any:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
@@ -138,21 +141,21 @@ class WebSocket(HTTPConnection):
text = message["bytes"].decode("utf-8")
return json.loads(text)
async def iter_text(self) -> typing.AsyncIterator[str]:
async def iter_text(self) -> AsyncIterator[str]:
try:
while True:
yield await self.receive_text()
except WebSocketDisconnect:
pass
async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
async def iter_bytes(self) -> AsyncIterator[bytes]:
try:
while True:
yield await self.receive_bytes()
except WebSocketDisconnect:
pass
async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
async def iter_json(self) -> AsyncIterator[Any]:
try:
while True:
yield await self.receive_json()
@@ -165,29 +168,29 @@ class WebSocket(HTTPConnection):
async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})
async def send_json(self, data: typing.Any, mode: str = "text") -> None:
async def send_json(self, data: Any, mode: str = "text") -> None:
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data, separators=(",", ":"))
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
await self.send({"type": "websocket.send", "text": text})
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
async def close(
self, code: int = 1000, reason: typing.Optional[str] = None
) -> None:
await self.send(
{"type": "websocket.close", "code": code, "reason": reason or ""}
)
async def close(self, code: int = 1000, reason: str | None = None) -> None:
await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
self.code = code
self.reason = reason or ""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send(
{"type": "websocket.close", "code": self.code, "reason": self.reason}
)
await send({"type": "websocket.close", "code": self.code, "reason": self.reason})