update to python fastpi

This commit is contained in:
Iliyan Angelov
2025-11-16 15:59:05 +02:00
parent 93d4c1df80
commit 98ccd5b6ff
4464 changed files with 773233 additions and 13740 deletions

View File

@@ -0,0 +1,648 @@
from __future__ import annotations
import socket
import ssl as ssl_module
import threading
import warnings
from collections.abc import Sequence
from typing import Any, Callable, Literal, TypeVar, cast
from ..client import ClientProtocol
from ..datastructures import Headers, HeadersLike
from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import build_authorization_basic, build_host, validate_subprotocols
from ..http11 import USER_AGENT, Response
from ..protocol import CONNECTING, Event
from ..streams import StreamReader
from ..typing import LoggerLike, Origin, Subprotocol
from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri
from .connection import Connection
from .utils import Deadline
__all__ = ["connect", "unix_connect", "ClientConnection"]
class ClientConnection(Connection):
"""
:mod:`threading` implementation of a WebSocket client connection.
:class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for
receiving and sending messages.
It supports iteration to receive messages::
for message in websocket:
process(message)
The iterator exits normally when the connection is closed with close code
1000 (OK) or 1001 (going away) or without a close code. It raises a
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and
``max_queue`` arguments have the same meaning as in :func:`connect`.
Args:
socket: Socket connected to a WebSocket server.
protocol: Sans-I/O connection.
"""
def __init__(
self,
socket: socket.socket,
protocol: ClientProtocol,
*,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
max_queue: int | None | tuple[int | None, int | None] = 16,
) -> None:
self.protocol: ClientProtocol
self.response_rcvd = threading.Event()
super().__init__(
socket,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
def handshake(
self,
additional_headers: HeadersLike | None = None,
user_agent_header: str | None = USER_AGENT,
timeout: float | None = None,
) -> None:
"""
Perform the opening handshake.
"""
with self.send_context(expected_state=CONNECTING):
self.request = self.protocol.connect()
if additional_headers is not None:
self.request.headers.update(additional_headers)
if user_agent_header is not None:
self.request.headers.setdefault("User-Agent", user_agent_header)
self.protocol.send_request(self.request)
if not self.response_rcvd.wait(timeout):
raise TimeoutError("timed out while waiting for handshake response")
# self.protocol.handshake_exc is set when the connection is lost before
# receiving a response, when the response cannot be parsed, or when the
# response fails the handshake.
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
def process_event(self, event: Event) -> None:
"""
Process one incoming event.
"""
# First event - handshake response.
if self.response is None:
assert isinstance(event, Response)
self.response = event
self.response_rcvd.set()
# Later events - frames.
else:
super().process_event(event)
def recv_events(self) -> None:
"""
Read incoming data from the socket and process events.
"""
try:
super().recv_events()
finally:
# If the connection is closed during the handshake, unblock it.
self.response_rcvd.set()
def connect(
uri: str,
*,
# TCP/TLS
sock: socket.socket | None = None,
ssl: ssl_module.SSLContext | None = None,
server_hostname: str | None = None,
# WebSocket
origin: Origin | None = None,
extensions: Sequence[ClientExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
compression: str | None = "deflate",
# HTTP
additional_headers: HeadersLike | None = None,
user_agent_header: str | None = USER_AGENT,
proxy: str | Literal[True] | None = True,
proxy_ssl: ssl_module.SSLContext | None = None,
proxy_server_hostname: str | None = None,
# Timeouts
open_timeout: float | None = 10,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
# Limits
max_size: int | None = 2**20,
max_queue: int | None | tuple[int | None, int | None] = 16,
# Logging
logger: LoggerLike | None = None,
# Escape hatch for advanced customization
create_connection: type[ClientConnection] | None = None,
**kwargs: Any,
) -> ClientConnection:
"""
Connect to the WebSocket server at ``uri``.
This function returns a :class:`ClientConnection` instance, which you can
use to send and receive messages.
:func:`connect` may be used as a context manager::
from websockets.sync.client import connect
with connect(...) as websocket:
...
The connection is closed automatically when exiting the context.
Args:
uri: URI of the WebSocket server.
sock: Preexisting TCP socket. ``sock`` overrides the host and port
from ``uri``. You may call :func:`socket.create_connection` to
create a suitable TCP socket.
ssl: Configuration for enabling TLS on the connection.
server_hostname: Host name for the TLS handshake. ``server_hostname``
overrides the host name from ``uri``.
origin: Value of the ``Origin`` header, for servers that require it.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
preference.
compression: The "permessage-deflate" extension is enabled by default.
Set ``compression`` to :obj:`None` to disable it. See the
:doc:`compression guide <../../topics/compression>` for details.
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
to the handshake request.
user_agent_header: Value of the ``User-Agent`` request header.
It defaults to ``"Python/x.y.z websockets/X.Y"``.
Setting it to :obj:`None` removes the header.
proxy: If a proxy is configured, it is used by default. Set ``proxy``
to :obj:`None` to disable the proxy or to the address of a proxy
to override the system configuration. See the :doc:`proxy docs
<../../topics/proxies>` for details.
proxy_ssl: Configuration for enabling TLS on the proxy connection.
proxy_server_hostname: Host name for the TLS handshake with the proxy.
``proxy_server_hostname`` overrides the host name from ``proxy``.
open_timeout: Timeout for opening the connection in seconds.
:obj:`None` disables the timeout.
ping_interval: Interval between keepalive pings in seconds.
:obj:`None` disables keepalive.
ping_timeout: Timeout for keepalive pings in seconds.
:obj:`None` disables timeouts.
close_timeout: Timeout for closing the connection in seconds.
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
:obj:`None` disables the limit.
max_queue: High-water mark of the buffer where frames are received.
It defaults to 16 frames. The low-water mark defaults to ``max_queue
// 4``. You may pass a ``(high, low)`` tuple to set the high-water
and low-water marks. If you want to disable flow control entirely,
you may set it to ``None``, although that's a bad idea.
logger: Logger for this client.
It defaults to ``logging.getLogger("websockets.client")``.
See the :doc:`logging guide <../../topics/logging>` for details.
create_connection: Factory for the :class:`ClientConnection` managing
the connection. Set it to a wrapper or a subclass to customize
connection handling.
Any other keyword arguments are passed to :func:`~socket.create_connection`.
Raises:
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
OSError: If the TCP connection fails.
InvalidHandshake: If the opening handshake fails.
TimeoutError: If the opening handshake times out.
"""
# Process parameters
# Backwards compatibility: ssl used to be called ssl_context.
if ssl is None and "ssl_context" in kwargs:
ssl = kwargs.pop("ssl_context")
warnings.warn( # deprecated in 13.0 - 2024-08-20
"ssl_context was renamed to ssl",
DeprecationWarning,
)
ws_uri = parse_uri(uri)
if not ws_uri.secure and ssl is not None:
raise ValueError("ssl argument is incompatible with a ws:// URI")
# Private APIs for unix_connect()
unix: bool = kwargs.pop("unix", False)
path: str | None = kwargs.pop("path", None)
if unix:
if path is None and sock is None:
raise ValueError("missing path argument")
elif path is not None and sock is not None:
raise ValueError("path and sock arguments are incompatible")
if subprotocols is not None:
validate_subprotocols(subprotocols)
if compression == "deflate":
extensions = enable_client_permessage_deflate(extensions)
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")
if unix:
proxy = None
if sock is not None:
proxy = None
if proxy is True:
proxy = get_proxy(ws_uri)
# Calculate timeouts on the TCP, TLS, and WebSocket handshakes.
# The TCP and TLS timeouts must be set on the socket, then removed
# to avoid conflicting with the WebSocket timeout in handshake().
deadline = Deadline(open_timeout)
if create_connection is None:
create_connection = ClientConnection
try:
# Connect socket
if sock is None:
if unix:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(deadline.timeout())
assert path is not None # mypy cannot figure this out
sock.connect(path)
elif proxy is not None:
proxy_parsed = parse_proxy(proxy)
if proxy_parsed.scheme[:5] == "socks":
# Connect to the server through the proxy.
sock = connect_socks_proxy(
proxy_parsed,
ws_uri,
deadline,
# websockets is consistent with the socket module while
# python_socks is consistent across implementations.
local_addr=kwargs.pop("source_address", None),
)
elif proxy_parsed.scheme[:4] == "http":
# Validate the proxy_ssl argument.
if proxy_parsed.scheme != "https" and proxy_ssl is not None:
raise ValueError(
"proxy_ssl argument is incompatible with an http:// proxy"
)
# Connect to the server through the proxy.
sock = connect_http_proxy(
proxy_parsed,
ws_uri,
deadline,
user_agent_header=user_agent_header,
ssl=proxy_ssl,
server_hostname=proxy_server_hostname,
**kwargs,
)
else:
raise AssertionError("unsupported proxy")
else:
kwargs.setdefault("timeout", deadline.timeout())
sock = socket.create_connection(
(ws_uri.host, ws_uri.port),
**kwargs,
)
sock.settimeout(None)
# Disable Nagle algorithm
if not unix:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
# Initialize TLS wrapper and perform TLS handshake
if ws_uri.secure:
if ssl is None:
ssl = ssl_module.create_default_context()
if server_hostname is None:
server_hostname = ws_uri.host
sock.settimeout(deadline.timeout())
if proxy_ssl is None:
sock = ssl.wrap_socket(sock, server_hostname=server_hostname)
else:
sock_2 = SSLSSLSocket(sock, ssl, server_hostname=server_hostname)
# Let's pretend that sock is a socket, even though it isn't.
sock = cast(socket.socket, sock_2)
sock.settimeout(None)
# Initialize WebSocket protocol
protocol = ClientProtocol(
ws_uri,
origin=origin,
extensions=extensions,
subprotocols=subprotocols,
max_size=max_size,
logger=logger,
)
# Initialize WebSocket connection
connection = create_connection(
sock,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
except Exception:
if sock is not None:
sock.close()
raise
try:
connection.handshake(
additional_headers,
user_agent_header,
deadline.timeout(),
)
except Exception:
connection.close_socket()
connection.recv_events_thread.join()
raise
connection.start_keepalive()
return connection
def unix_connect(
path: str | None = None,
uri: str | None = None,
**kwargs: Any,
) -> ClientConnection:
"""
Connect to a WebSocket server listening on a Unix socket.
This function accepts the same keyword arguments as :func:`connect`.
It's only available on Unix.
It's mainly useful for debugging servers listening on Unix sockets.
Args:
path: File system path to the Unix socket.
uri: URI of the WebSocket server. ``uri`` defaults to
``ws://localhost/`` or, when a ``ssl`` is provided, to
``wss://localhost/``.
"""
if uri is None:
# Backwards compatibility: ssl used to be called ssl_context.
if kwargs.get("ssl") is None and kwargs.get("ssl_context") is None:
uri = "ws://localhost/"
else:
uri = "wss://localhost/"
return connect(uri=uri, unix=True, path=path, **kwargs)
try:
from python_socks import ProxyType
from python_socks.sync import Proxy as SocksProxy
SOCKS_PROXY_TYPES = {
"socks5h": ProxyType.SOCKS5,
"socks5": ProxyType.SOCKS5,
"socks4a": ProxyType.SOCKS4,
"socks4": ProxyType.SOCKS4,
}
SOCKS_PROXY_RDNS = {
"socks5h": True,
"socks5": False,
"socks4a": True,
"socks4": False,
}
def connect_socks_proxy(
proxy: Proxy,
ws_uri: WebSocketURI,
deadline: Deadline,
**kwargs: Any,
) -> socket.socket:
"""Connect via a SOCKS proxy and return the socket."""
socks_proxy = SocksProxy(
SOCKS_PROXY_TYPES[proxy.scheme],
proxy.host,
proxy.port,
proxy.username,
proxy.password,
SOCKS_PROXY_RDNS[proxy.scheme],
)
kwargs.setdefault("timeout", deadline.timeout())
# connect() is documented to raise OSError and TimeoutError.
# Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
try:
return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
except (OSError, TimeoutError, socket.timeout):
raise
except Exception as exc:
raise ProxyError("failed to connect to SOCKS proxy") from exc
except ImportError:
def connect_socks_proxy(
proxy: Proxy,
ws_uri: WebSocketURI,
deadline: Deadline,
**kwargs: Any,
) -> socket.socket:
raise ImportError("python-socks is required to use a SOCKS proxy")
def prepare_connect_request(
proxy: Proxy,
ws_uri: WebSocketURI,
user_agent_header: str | None = None,
) -> bytes:
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
headers = Headers()
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
if user_agent_header is not None:
headers["User-Agent"] = user_agent_header
if proxy.username is not None:
assert proxy.password is not None # enforced by parse_proxy()
headers["Proxy-Authorization"] = build_authorization_basic(
proxy.username, proxy.password
)
# We cannot use the Request class because it supports only GET requests.
return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize()
def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response:
reader = StreamReader()
parser = Response.parse(
reader.read_line,
reader.read_exact,
reader.read_to_eof,
include_body=False,
)
try:
while True:
sock.settimeout(deadline.timeout())
data = sock.recv(4096)
if data:
reader.feed_data(data)
else:
reader.feed_eof()
next(parser)
except StopIteration as exc:
assert isinstance(exc.value, Response) # help mypy
response = exc.value
if 200 <= response.status_code < 300:
return response
else:
raise InvalidProxyStatus(response)
except socket.timeout:
raise TimeoutError("timed out while connecting to HTTP proxy")
except Exception as exc:
raise InvalidProxyMessage(
"did not receive a valid HTTP response from proxy"
) from exc
finally:
sock.settimeout(None)
def connect_http_proxy(
proxy: Proxy,
ws_uri: WebSocketURI,
deadline: Deadline,
*,
user_agent_header: str | None = None,
ssl: ssl_module.SSLContext | None = None,
server_hostname: str | None = None,
**kwargs: Any,
) -> socket.socket:
# Connect socket
kwargs.setdefault("timeout", deadline.timeout())
sock = socket.create_connection((proxy.host, proxy.port), **kwargs)
# Initialize TLS wrapper and perform TLS handshake
if proxy.scheme == "https":
if ssl is None:
ssl = ssl_module.create_default_context()
if server_hostname is None:
server_hostname = proxy.host
sock.settimeout(deadline.timeout())
sock = ssl.wrap_socket(sock, server_hostname=server_hostname)
sock.settimeout(None)
# Send CONNECT request to the proxy and read response.
sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header))
try:
read_connect_response(sock, deadline)
except Exception:
sock.close()
raise
return sock
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., T])
class SSLSSLSocket:
"""
Socket-like object providing TLS-in-TLS.
Only methods that are used by websockets are implemented.
"""
recv_bufsize = 65536
def __init__(
self,
sock: socket.socket,
ssl_context: ssl_module.SSLContext,
server_hostname: str | None = None,
) -> None:
self.incoming = ssl_module.MemoryBIO()
self.outgoing = ssl_module.MemoryBIO()
self.ssl_socket = sock
self.ssl_object = ssl_context.wrap_bio(
self.incoming,
self.outgoing,
server_hostname=server_hostname,
)
self.run_io(self.ssl_object.do_handshake)
def run_io(self, func: Callable[..., T], *args: Any) -> T:
while True:
want_read = False
want_write = False
try:
result = func(*args)
except ssl_module.SSLWantReadError:
want_read = True
except ssl_module.SSLWantWriteError: # pragma: no cover
want_write = True
# Write outgoing data in all cases.
data = self.outgoing.read()
if data:
self.ssl_socket.sendall(data)
# Read incoming data and retry on SSLWantReadError.
if want_read:
data = self.ssl_socket.recv(self.recv_bufsize)
if data:
self.incoming.write(data)
else:
self.incoming.write_eof()
continue
# Retry after writing outgoing data on SSLWantWriteError.
if want_write: # pragma: no cover
continue
# Return result if no error happened.
return result
def recv(self, buflen: int) -> bytes:
try:
return self.run_io(self.ssl_object.read, buflen)
except ssl_module.SSLEOFError:
return b"" # always ignore ragged EOFs
def send(self, data: bytes) -> int:
return self.run_io(self.ssl_object.write, data)
def sendall(self, data: bytes) -> None:
# adapted from ssl_module.SSLSocket.sendall()
count = 0
with memoryview(data) as view, view.cast("B") as byte_view:
amount = len(byte_view)
while count < amount:
count += self.send(byte_view[count:])
# recv_into(), recvfrom(), recvfrom_into(), sendto(), unwrap(), and the
# flags argument aren't implemented because websockets doesn't need them.
def __getattr__(self, name: str) -> Any:
return getattr(self.ssl_socket, name)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,345 @@
from __future__ import annotations
import codecs
import queue
import threading
from typing import Any, Callable, Iterable, Iterator, Literal, overload
from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
from ..typing import Data
from .utils import Deadline
__all__ = ["Assembler"]
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
class Assembler:
"""
Assemble messages from frames.
:class:`Assembler` expects only data frames. The stream of frames must
respect the protocol; if it doesn't, the behavior is undefined.
Args:
pause: Called when the buffer of frames goes above the high water mark;
should pause reading from the network.
resume: Called when the buffer of frames goes below the low water mark;
should resume reading from the network.
"""
def __init__(
self,
high: int | None = None,
low: int | None = None,
pause: Callable[[], Any] = lambda: None,
resume: Callable[[], Any] = lambda: None,
) -> None:
# Serialize reads and writes -- except for reads via synchronization
# primitives provided by the threading and queue modules.
self.mutex = threading.Lock()
# Queue of incoming frames.
self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue()
# We cannot put a hard limit on the size of the queue because a single
# call to Protocol.data_received() could produce thousands of frames,
# which must be buffered. Instead, we pause reading when the buffer goes
# above the high limit and we resume when it goes under the low limit.
if high is not None and low is None:
low = high // 4
if high is None and low is not None:
high = low * 4
if high is not None and low is not None:
if low < 0:
raise ValueError("low must be positive or equal to zero")
if high < low:
raise ValueError("high must be greater than or equal to low")
self.high, self.low = high, low
self.pause = pause
self.resume = resume
self.paused = False
# This flag prevents concurrent calls to get() by user code.
self.get_in_progress = False
# This flag marks the end of the connection.
self.closed = False
def get_next_frame(self, timeout: float | None = None) -> Frame:
# Helper to factor out the logic for getting the next frame from the
# queue, while handling timeouts and reaching the end of the stream.
if self.closed:
try:
frame = self.frames.get(block=False)
except queue.Empty:
raise EOFError("stream of frames ended") from None
else:
try:
# Check for a frame that's already received if timeout <= 0.
# SimpleQueue.get() doesn't support negative timeout values.
if timeout is not None and timeout <= 0:
frame = self.frames.get(block=False)
else:
frame = self.frames.get(block=True, timeout=timeout)
except queue.Empty:
raise TimeoutError(f"timed out in {timeout:.1f}s") from None
if frame is None:
raise EOFError("stream of frames ended")
return frame
def reset_queue(self, frames: Iterable[Frame]) -> None:
# Helper to put frames back into the queue after they were fetched.
# This happens only when the queue is empty. However, by the time
# we acquire self.mutex, put() may have added items in the queue.
# Therefore, we must handle the case where the queue is not empty.
frame: Frame | None
with self.mutex:
queued = []
try:
while True:
queued.append(self.frames.get(block=False))
except queue.Empty:
pass
for frame in frames:
self.frames.put(frame)
# This loop runs only when a race condition occurs.
for frame in queued: # pragma: no cover
self.frames.put(frame)
# This overload structure is required to avoid the error:
# "parameter without a default follows parameter with a default"
@overload
def get(self, timeout: float | None, decode: Literal[True]) -> str: ...
@overload
def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ...
@overload
def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ...
@overload
def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ...
@overload
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ...
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
Read the next message.
:meth:`get` returns a single :class:`str` or :class:`bytes`.
If the message is fragmented, :meth:`get` waits until the last frame is
received, then it reassembles the message and returns it. To receive
messages frame by frame, use :meth:`get_iter` instead.
Args:
timeout: If a timeout is provided and elapses before a complete
message is received, :meth:`get` raises :exc:`TimeoutError`.
decode: :obj:`False` disables UTF-8 decoding of text frames and
returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
binary frames and returns :class:`str`.
Raises:
EOFError: If the stream of frames has ended.
UnicodeDecodeError: If a text frame contains invalid UTF-8.
ConcurrencyError: If two coroutines run :meth:`get` or
:meth:`get_iter` concurrently.
TimeoutError: If a timeout is provided and elapses before a
complete message is received.
"""
with self.mutex:
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
self.get_in_progress = True
# Locking with get_in_progress prevents concurrent execution
# until get() fetches a complete message or times out.
try:
deadline = Deadline(timeout)
# First frame
frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False))
with self.mutex:
self.maybe_resume()
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
decode = frame.opcode is OP_TEXT
frames = [frame]
# Following frames, for fragmented messages
while not frame.fin:
try:
frame = self.get_next_frame(
deadline.timeout(raise_if_elapsed=False)
)
except TimeoutError:
# Put frames already received back into the queue
# so that future calls to get() can return them.
self.reset_queue(frames)
raise
with self.mutex:
self.maybe_resume()
assert frame.opcode is OP_CONT
frames.append(frame)
finally:
self.get_in_progress = False
data = b"".join(frame.data for frame in frames)
if decode:
return data.decode()
else:
return data
@overload
def get_iter(self, decode: Literal[True]) -> Iterator[str]: ...
@overload
def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ...
@overload
def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ...
def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
"""
Stream the next message.
Iterating the return value of :meth:`get_iter` yields a :class:`str` or
:class:`bytes` for each frame in the message.
The iterator must be fully consumed before calling :meth:`get_iter` or
:meth:`get` again. Else, :exc:`ConcurrencyError` is raised.
This method only makes sense for fragmented messages. If messages aren't
fragmented, use :meth:`get` instead.
Args:
decode: :obj:`False` disables UTF-8 decoding of text frames and
returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
binary frames and returns :class:`str`.
Raises:
EOFError: If the stream of frames has ended.
UnicodeDecodeError: If a text frame contains invalid UTF-8.
ConcurrencyError: If two coroutines run :meth:`get` or
:meth:`get_iter` concurrently.
"""
with self.mutex:
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")
self.get_in_progress = True
# Locking with get_in_progress prevents concurrent execution
# until get_iter() fetches a complete message or times out.
# If get_iter() raises an exception e.g. in decoder.decode(),
# get_in_progress remains set and the connection becomes unusable.
# First frame
frame = self.get_next_frame()
with self.mutex:
self.maybe_resume()
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
decode = frame.opcode is OP_TEXT
if decode:
decoder = UTF8Decoder()
yield decoder.decode(frame.data, frame.fin)
else:
yield frame.data
# Following frames, for fragmented messages
while not frame.fin:
frame = self.get_next_frame()
with self.mutex:
self.maybe_resume()
assert frame.opcode is OP_CONT
if decode:
yield decoder.decode(frame.data, frame.fin)
else:
yield frame.data
self.get_in_progress = False
def put(self, frame: Frame) -> None:
"""
Add ``frame`` to the next message.
Raises:
EOFError: If the stream of frames has ended.
"""
with self.mutex:
if self.closed:
raise EOFError("stream of frames ended")
self.frames.put(frame)
self.maybe_pause()
# put() and get/get_iter() call maybe_pause() and maybe_resume() while
# holding self.mutex. This guarantees that the calls interleave properly.
# Specifically, it prevents a race condition where maybe_resume() would
# run before maybe_pause(), leaving the connection incorrectly paused.
# A race condition is possible when get/get_iter() call self.frames.get()
# without holding self.mutex. However, it's harmless — and even beneficial!
# It can only result in popping an item from the queue before maybe_resume()
# runs and skipping a pause() - resume() cycle that would otherwise occur.
def maybe_pause(self) -> None:
"""Pause the writer if queue is above the high water mark."""
# Skip if flow control is disabled
if self.high is None:
return
assert self.mutex.locked()
# Check for "> high" to support high = 0
if self.frames.qsize() > self.high and not self.paused:
self.paused = True
self.pause()
def maybe_resume(self) -> None:
"""Resume the writer if queue is below the low water mark."""
# Skip if flow control is disabled
if self.low is None:
return
assert self.mutex.locked()
# Check for "<= low" to support low = 0
if self.frames.qsize() <= self.low and self.paused:
self.paused = False
self.resume()
def close(self) -> None:
"""
End the stream of frames.
Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
or :meth:`put` is safe. They will raise :exc:`EOFError`.
"""
with self.mutex:
if self.closed:
return
self.closed = True
if self.get_in_progress:
# Unblock get() or get_iter().
self.frames.put(None)
if self.paused:
# Unblock recv_events().
self.paused = False
self.resume()

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
import http
import ssl as ssl_module
import urllib.parse
from typing import Any, Callable, Literal
from werkzeug.exceptions import NotFound
from werkzeug.routing import Map, RequestRedirect
from ..http11 import Request, Response
from .server import Server, ServerConnection, serve
__all__ = ["route", "unix_route", "Router"]
class Router:
"""WebSocket router supporting :func:`route`."""
def __init__(
self,
url_map: Map,
server_name: str | None = None,
url_scheme: str = "ws",
) -> None:
self.url_map = url_map
self.server_name = server_name
self.url_scheme = url_scheme
for rule in self.url_map.iter_rules():
rule.websocket = True
def get_server_name(self, connection: ServerConnection, request: Request) -> str:
if self.server_name is None:
return request.headers["Host"]
else:
return self.server_name
def redirect(self, connection: ServerConnection, url: str) -> Response:
response = connection.respond(http.HTTPStatus.FOUND, f"Found at {url}")
response.headers["Location"] = url
return response
def not_found(self, connection: ServerConnection) -> Response:
return connection.respond(http.HTTPStatus.NOT_FOUND, "Not Found")
def route_request(
self, connection: ServerConnection, request: Request
) -> Response | None:
"""Route incoming request."""
url_map_adapter = self.url_map.bind(
server_name=self.get_server_name(connection, request),
url_scheme=self.url_scheme,
)
try:
parsed = urllib.parse.urlparse(request.path)
handler, kwargs = url_map_adapter.match(
path_info=parsed.path,
query_args=parsed.query,
)
except RequestRedirect as redirect:
return self.redirect(connection, redirect.new_url)
except NotFound:
return self.not_found(connection)
connection.handler, connection.handler_kwargs = handler, kwargs
return None
def handler(self, connection: ServerConnection) -> None:
"""Handle a connection."""
return connection.handler(connection, **connection.handler_kwargs)
def route(
url_map: Map,
*args: Any,
server_name: str | None = None,
ssl: ssl_module.SSLContext | Literal[True] | None = None,
create_router: type[Router] | None = None,
**kwargs: Any,
) -> Server:
"""
Create a WebSocket server dispatching connections to different handlers.
This feature requires the third-party library `werkzeug`_:
.. code-block:: console
$ pip install werkzeug
.. _werkzeug: https://werkzeug.palletsprojects.com/
:func:`route` accepts the same arguments as
:func:`~websockets.sync.server.serve`, except as described below.
The first argument is a :class:`werkzeug.routing.Map` that maps URL patterns
to connection handlers. In addition to the connection, handlers receive
parameters captured in the URL as keyword arguments.
Here's an example::
from websockets.sync.router import route
from werkzeug.routing import Map, Rule
def channel_handler(websocket, channel_id):
...
url_map = Map([
Rule("/channel/<uuid:channel_id>", endpoint=channel_handler),
...
])
with route(url_map, ...) as server:
server.serve_forever()
Refer to the documentation of :mod:`werkzeug.routing` for details.
If you define redirects with ``Rule(..., redirect_to=...)`` in the URL map,
when the server runs behind a reverse proxy that modifies the ``Host``
header or terminates TLS, you need additional configuration:
* Set ``server_name`` to the name of the server as seen by clients. When not
provided, websockets uses the value of the ``Host`` header.
* Set ``ssl=True`` to generate ``wss://`` URIs without actually enabling
TLS. Under the hood, this bind the URL map with a ``url_scheme`` of
``wss://`` instead of ``ws://``.
There is no need to specify ``websocket=True`` in each rule. It is added
automatically.
Args:
url_map: Mapping of URL patterns to connection handlers.
server_name: Name of the server as seen by clients. If :obj:`None`,
websockets uses the value of the ``Host`` header.
ssl: Configuration for enabling TLS on the connection. Set it to
:obj:`True` if a reverse proxy terminates TLS connections.
create_router: Factory for the :class:`Router` dispatching requests to
handlers. Set it to a wrapper or a subclass to customize routing.
"""
url_scheme = "ws" if ssl is None else "wss"
if ssl is not True and ssl is not None:
kwargs["ssl"] = ssl
if create_router is None:
create_router = Router
router = create_router(url_map, server_name, url_scheme)
_process_request: (
Callable[
[ServerConnection, Request],
Response | None,
]
| None
) = kwargs.pop("process_request", None)
if _process_request is None:
process_request: Callable[
[ServerConnection, Request],
Response | None,
] = router.route_request
else:
def process_request(
connection: ServerConnection, request: Request
) -> Response | None:
response = _process_request(connection, request)
if response is not None:
return response
return router.route_request(connection, request)
return serve(router.handler, *args, process_request=process_request, **kwargs)
def unix_route(
url_map: Map,
path: str | None = None,
**kwargs: Any,
) -> Server:
"""
Create a WebSocket Unix server dispatching connections to different handlers.
:func:`unix_route` combines the behaviors of :func:`route` and
:func:`~websockets.sync.server.unix_serve`.
Args:
url_map: Mapping of URL patterns to connection handlers.
path: File system path to the Unix socket.
"""
return route(url_map, unix=True, path=path, **kwargs)

View File

@@ -0,0 +1,763 @@
from __future__ import annotations
import hmac
import http
import logging
import os
import re
import selectors
import socket
import ssl as ssl_module
import sys
import threading
import warnings
from collections.abc import Iterable, Sequence
from types import TracebackType
from typing import Any, Callable, Mapping, cast
from ..exceptions import InvalidHeader
from ..extensions.base import ServerExtensionFactory
from ..extensions.permessage_deflate import enable_server_permessage_deflate
from ..frames import CloseCode
from ..headers import (
build_www_authenticate_basic,
parse_authorization_basic,
validate_subprotocols,
)
from ..http11 import SERVER, Request, Response
from ..protocol import CONNECTING, OPEN, Event
from ..server import ServerProtocol
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
from .connection import Connection
from .utils import Deadline
__all__ = ["serve", "unix_serve", "ServerConnection", "Server", "basic_auth"]
class ServerConnection(Connection):
"""
:mod:`threading` implementation of a WebSocket server connection.
:class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
receiving and sending messages.
It supports iteration to receive messages::
for message in websocket:
process(message)
The iterator exits normally when the connection is closed with close code
1000 (OK) or 1001 (going away) or without a close code. It raises a
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and
``max_queue`` arguments have the same meaning as in :func:`serve`.
Args:
socket: Socket connected to a WebSocket client.
protocol: Sans-I/O connection.
"""
def __init__(
self,
socket: socket.socket,
protocol: ServerProtocol,
*,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
max_queue: int | None | tuple[int | None, int | None] = 16,
) -> None:
self.protocol: ServerProtocol
self.request_rcvd = threading.Event()
super().__init__(
socket,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
self.username: str # see basic_auth()
self.handler: Callable[[ServerConnection], None] # see route()
self.handler_kwargs: Mapping[str, Any] # see route()
def respond(self, status: StatusLike, text: str) -> Response:
"""
Create a plain text HTTP response.
``process_request`` and ``process_response`` may call this method to
return an HTTP response instead of performing the WebSocket opening
handshake.
You can modify the response before returning it, for example by changing
HTTP headers.
Args:
status: HTTP status code.
text: HTTP response body; it will be encoded to UTF-8.
Returns:
HTTP response to send to the client.
"""
return self.protocol.reject(status, text)
def handshake(
self,
process_request: (
Callable[
[ServerConnection, Request],
Response | None,
]
| None
) = None,
process_response: (
Callable[
[ServerConnection, Request, Response],
Response | None,
]
| None
) = None,
server_header: str | None = SERVER,
timeout: float | None = None,
) -> None:
"""
Perform the opening handshake.
"""
if not self.request_rcvd.wait(timeout):
raise TimeoutError("timed out while waiting for handshake request")
if self.request is not None:
with self.send_context(expected_state=CONNECTING):
response = None
if process_request is not None:
try:
response = process_request(self, self.request)
except Exception as exc:
self.protocol.handshake_exc = exc
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if response is None:
self.response = self.protocol.accept(self.request)
else:
self.response = response
if server_header:
self.response.headers["Server"] = server_header
response = None
if process_response is not None:
try:
response = process_response(self, self.request, self.response)
except Exception as exc:
self.protocol.handshake_exc = exc
response = self.protocol.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
(
"Failed to open a WebSocket connection.\n"
"See server log for more information.\n"
),
)
if response is not None:
self.response = response
self.protocol.send_response(self.response)
# self.protocol.handshake_exc is set when the connection is lost before
# receiving a request, when the request cannot be parsed, or when the
# handshake fails, including when process_request or process_response
# raises an exception.
# It isn't set when process_request or process_response sends an HTTP
# response that rejects the handshake.
if self.protocol.handshake_exc is not None:
raise self.protocol.handshake_exc
def process_event(self, event: Event) -> None:
"""
Process one incoming event.
"""
# First event - handshake request.
if self.request is None:
assert isinstance(event, Request)
self.request = event
self.request_rcvd.set()
# Later events - frames.
else:
super().process_event(event)
def recv_events(self) -> None:
"""
Read incoming data from the socket and process events.
"""
try:
super().recv_events()
finally:
# If the connection is closed during the handshake, unblock it.
self.request_rcvd.set()
class Server:
"""
WebSocket server returned by :func:`serve`.
This class mirrors the API of :class:`~socketserver.BaseServer`, notably the
:meth:`~socketserver.BaseServer.serve_forever` and
:meth:`~socketserver.BaseServer.shutdown` methods, as well as the context
manager protocol.
Args:
socket: Server socket listening for new connections.
handler: Handler for one connection. Receives the socket and address
returned by :meth:`~socket.socket.accept`.
logger: Logger for this server.
It defaults to ``logging.getLogger("websockets.server")``.
See the :doc:`logging guide <../../topics/logging>` for details.
"""
def __init__(
self,
socket: socket.socket,
handler: Callable[[socket.socket, Any], None],
logger: LoggerLike | None = None,
) -> None:
self.socket = socket
self.handler = handler
if logger is None:
logger = logging.getLogger("websockets.server")
self.logger = logger
if sys.platform != "win32":
self.shutdown_watcher, self.shutdown_notifier = os.pipe()
def serve_forever(self) -> None:
"""
See :meth:`socketserver.BaseServer.serve_forever`.
This method doesn't return. Calling :meth:`shutdown` from another thread
stops the server.
Typical use::
with serve(...) as server:
server.serve_forever()
"""
poller = selectors.DefaultSelector()
try:
poller.register(self.socket, selectors.EVENT_READ)
except ValueError: # pragma: no cover
# If shutdown() is called before poller.register(),
# the socket is closed and poller.register() raises
# ValueError: Invalid file descriptor: -1
return
if sys.platform != "win32":
poller.register(self.shutdown_watcher, selectors.EVENT_READ)
while True:
poller.select()
try:
# If the socket is closed, this will raise an exception and exit
# the loop. So we don't need to check the return value of select().
sock, addr = self.socket.accept()
except OSError:
break
# Since there isn't a mechanism for tracking connections and waiting
# for them to terminate, we cannot use daemon threads, or else all
# connections would be terminate brutally when closing the server.
thread = threading.Thread(target=self.handler, args=(sock, addr))
thread.start()
def shutdown(self) -> None:
"""
See :meth:`socketserver.BaseServer.shutdown`.
"""
self.socket.close()
if sys.platform != "win32":
os.write(self.shutdown_notifier, b"x")
def fileno(self) -> int:
"""
See :meth:`socketserver.BaseServer.fileno`.
"""
return self.socket.fileno()
def __enter__(self) -> Server:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.shutdown()
def __getattr__(name: str) -> Any:
if name == "WebSocketServer":
warnings.warn( # deprecated in 13.0 - 2024-08-20
"WebSocketServer was renamed to Server",
DeprecationWarning,
)
return Server
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def serve(
handler: Callable[[ServerConnection], None],
host: str | None = None,
port: int | None = None,
*,
# TCP/TLS
sock: socket.socket | None = None,
ssl: ssl_module.SSLContext | None = None,
# WebSocket
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
extensions: Sequence[ServerExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
select_subprotocol: (
Callable[
[ServerConnection, Sequence[Subprotocol]],
Subprotocol | None,
]
| None
) = None,
compression: str | None = "deflate",
# HTTP
process_request: (
Callable[
[ServerConnection, Request],
Response | None,
]
| None
) = None,
process_response: (
Callable[
[ServerConnection, Request, Response],
Response | None,
]
| None
) = None,
server_header: str | None = SERVER,
# Timeouts
open_timeout: float | None = 10,
ping_interval: float | None = 20,
ping_timeout: float | None = 20,
close_timeout: float | None = 10,
# Limits
max_size: int | None = 2**20,
max_queue: int | None | tuple[int | None, int | None] = 16,
# Logging
logger: LoggerLike | None = None,
# Escape hatch for advanced customization
create_connection: type[ServerConnection] | None = None,
**kwargs: Any,
) -> Server:
"""
Create a WebSocket server listening on ``host`` and ``port``.
Whenever a client connects, the server creates a :class:`ServerConnection`,
performs the opening handshake, and delegates to the ``handler``.
The handler receives the :class:`ServerConnection` instance, which you can
use to send and receive messages.
Once the handler completes, either normally or with an exception, the server
performs the closing handshake and closes the connection.
This function returns a :class:`Server` whose API mirrors
:class:`~socketserver.BaseServer`. Treat it as a context manager to ensure
that it will be closed and call :meth:`~Server.serve_forever` to serve
requests::
from websockets.sync.server import serve
def handler(websocket):
...
with serve(handler, ...) as server:
server.serve_forever()
Args:
handler: Connection handler. It receives the WebSocket connection,
which is a :class:`ServerConnection`, in argument.
host: Network interfaces the server binds to.
See :func:`~socket.create_server` for details.
port: TCP port the server listens on.
See :func:`~socket.create_server` for details.
sock: Preexisting TCP socket. ``sock`` replaces ``host`` and ``port``.
You may call :func:`socket.create_server` to create a suitable TCP
socket.
ssl: Configuration for enabling TLS on the connection.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Values can be
:class:`str` to test for an exact match or regular expressions
compiled by :func:`re.compile` to test against a pattern. Include
:obj:`None` in the list if the lack of an origin is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
preference.
select_subprotocol: Callback for selecting a subprotocol among
those supported by the client and the server. It receives a
:class:`ServerConnection` (not a
:class:`~websockets.server.ServerProtocol`!) instance and a list of
subprotocols offered by the client. Other than the first argument,
it has the same behavior as the
:meth:`ServerProtocol.select_subprotocol
<websockets.server.ServerProtocol.select_subprotocol>` method.
compression: The "permessage-deflate" extension is enabled by default.
Set ``compression`` to :obj:`None` to disable it. See the
:doc:`compression guide <../../topics/compression>` for details.
process_request: Intercept the request during the opening handshake.
Return an HTTP response to force the response. Return :obj:`None` to
continue normally. When you force an HTTP 101 Continue response, the
handshake is successful. Else, the connection is aborted.
process_response: Intercept the response during the opening handshake.
Modify the response or return a new HTTP response to force the
response. Return :obj:`None` to continue normally. When you force an
HTTP 101 Continue response, the handshake is successful. Else, the
connection is aborted.
server_header: Value of the ``Server`` response header.
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
:obj:`None` removes the header.
open_timeout: Timeout for opening connections in seconds.
:obj:`None` disables the timeout.
ping_interval: Interval between keepalive pings in seconds.
:obj:`None` disables keepalive.
ping_timeout: Timeout for keepalive pings in seconds.
:obj:`None` disables timeouts.
close_timeout: Timeout for closing connections in seconds.
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
:obj:`None` disables the limit.
max_queue: High-water mark of the buffer where frames are received.
It defaults to 16 frames. The low-water mark defaults to ``max_queue
// 4``. You may pass a ``(high, low)`` tuple to set the high-water
and low-water marks. If you want to disable flow control entirely,
you may set it to ``None``, although that's a bad idea.
logger: Logger for this server.
It defaults to ``logging.getLogger("websockets.server")``. See the
:doc:`logging guide <../../topics/logging>` for details.
create_connection: Factory for the :class:`ServerConnection` managing
the connection. Set it to a wrapper or a subclass to customize
connection handling.
Any other keyword arguments are passed to :func:`~socket.create_server`.
"""
# Process parameters
# Backwards compatibility: ssl used to be called ssl_context.
if ssl is None and "ssl_context" in kwargs:
ssl = kwargs.pop("ssl_context")
warnings.warn( # deprecated in 13.0 - 2024-08-20
"ssl_context was renamed to ssl",
DeprecationWarning,
)
if subprotocols is not None:
validate_subprotocols(subprotocols)
if compression == "deflate":
extensions = enable_server_permessage_deflate(extensions)
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")
if create_connection is None:
create_connection = ServerConnection
# Bind socket and listen
# Private APIs for unix_connect()
unix: bool = kwargs.pop("unix", False)
path: str | None = kwargs.pop("path", None)
if sock is None:
if unix:
if path is None:
raise ValueError("missing path argument")
kwargs.setdefault("family", socket.AF_UNIX)
sock = socket.create_server(path, **kwargs)
else:
sock = socket.create_server((host, port), **kwargs)
else:
if path is not None:
raise ValueError("path and sock arguments are incompatible")
# Initialize TLS wrapper
if ssl is not None:
sock = ssl.wrap_socket(
sock,
server_side=True,
# Delay TLS handshake until after we set a timeout on the socket.
do_handshake_on_connect=False,
)
# Define request handler
def conn_handler(sock: socket.socket, addr: Any) -> None:
# Calculate timeouts on the TLS and WebSocket handshakes.
# The TLS timeout must be set on the socket, then removed
# to avoid conflicting with the WebSocket timeout in handshake().
deadline = Deadline(open_timeout)
try:
# Disable Nagle algorithm
if not unix:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
# Perform TLS handshake
if ssl is not None:
sock.settimeout(deadline.timeout())
# mypy cannot figure this out
assert isinstance(sock, ssl_module.SSLSocket)
sock.do_handshake()
sock.settimeout(None)
# Create a closure to give select_subprotocol access to connection.
protocol_select_subprotocol: (
Callable[
[ServerProtocol, Sequence[Subprotocol]],
Subprotocol | None,
]
| None
) = None
if select_subprotocol is not None:
def protocol_select_subprotocol(
protocol: ServerProtocol,
subprotocols: Sequence[Subprotocol],
) -> Subprotocol | None:
# mypy doesn't know that select_subprotocol is immutable.
assert select_subprotocol is not None
# Ensure this function is only used in the intended context.
assert protocol is connection.protocol
return select_subprotocol(connection, subprotocols)
# Initialize WebSocket protocol
protocol = ServerProtocol(
origins=origins,
extensions=extensions,
subprotocols=subprotocols,
select_subprotocol=protocol_select_subprotocol,
max_size=max_size,
logger=logger,
)
# Initialize WebSocket connection
assert create_connection is not None # help mypy
connection = create_connection(
sock,
protocol,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=close_timeout,
max_queue=max_queue,
)
except Exception:
sock.close()
return
try:
try:
connection.handshake(
process_request,
process_response,
server_header,
deadline.timeout(),
)
except TimeoutError:
connection.close_socket()
connection.recv_events_thread.join()
return
except Exception:
connection.logger.error("opening handshake failed", exc_info=True)
connection.close_socket()
connection.recv_events_thread.join()
return
assert connection.protocol.state is OPEN
try:
connection.start_keepalive()
handler(connection)
except Exception:
connection.logger.error("connection handler failed", exc_info=True)
connection.close(CloseCode.INTERNAL_ERROR)
else:
connection.close()
except Exception: # pragma: no cover
# Don't leak sockets on unexpected errors.
sock.close()
# Initialize server
return Server(sock, conn_handler, logger)
def unix_serve(
handler: Callable[[ServerConnection], None],
path: str | None = None,
**kwargs: Any,
) -> Server:
"""
Create a WebSocket server listening on a Unix socket.
This function accepts the same keyword arguments as :func:`serve`.
It's only available on Unix.
It's useful for deploying a server behind a reverse proxy such as nginx.
Args:
handler: Connection handler. It receives the WebSocket connection,
which is a :class:`ServerConnection`, in argument.
path: File system path to the Unix socket.
"""
return serve(handler, unix=True, path=path, **kwargs)
def is_credentials(credentials: Any) -> bool:
try:
username, password = credentials
except (TypeError, ValueError):
return False
else:
return isinstance(username, str) and isinstance(password, str)
def basic_auth(
realm: str = "",
credentials: tuple[str, str] | Iterable[tuple[str, str]] | None = None,
check_credentials: Callable[[str, str], bool] | None = None,
) -> Callable[[ServerConnection, Request], Response | None]:
"""
Factory for ``process_request`` to enforce HTTP Basic Authentication.
:func:`basic_auth` is designed to integrate with :func:`serve` as follows::
from websockets.sync.server import basic_auth, serve
with serve(
...,
process_request=basic_auth(
realm="my dev server",
credentials=("hello", "iloveyou"),
),
):
If authentication succeeds, the connection's ``username`` attribute is set.
If it fails, the server responds with an HTTP 401 Unauthorized status.
One of ``credentials`` or ``check_credentials`` must be provided; not both.
Args:
realm: Scope of protection. It should contain only ASCII characters
because the encoding of non-ASCII characters is undefined. Refer to
section 2.2 of :rfc:`7235` for details.
credentials: Hard coded authorized credentials. It can be a
``(username, password)`` pair or a list of such pairs.
check_credentials: Function that verifies credentials.
It receives ``username`` and ``password`` arguments and returns
whether they're valid.
Raises:
TypeError: If ``credentials`` or ``check_credentials`` is wrong.
ValueError: If ``credentials`` and ``check_credentials`` are both
provided or both not provided.
"""
if (credentials is None) == (check_credentials is None):
raise ValueError("provide either credentials or check_credentials")
if credentials is not None:
if is_credentials(credentials):
credentials_list = [cast(tuple[str, str], credentials)]
elif isinstance(credentials, Iterable):
credentials_list = list(cast(Iterable[tuple[str, str]], credentials))
if not all(is_credentials(item) for item in credentials_list):
raise TypeError(f"invalid credentials argument: {credentials}")
else:
raise TypeError(f"invalid credentials argument: {credentials}")
credentials_dict = dict(credentials_list)
def check_credentials(username: str, password: str) -> bool:
try:
expected_password = credentials_dict[username]
except KeyError:
return False
return hmac.compare_digest(expected_password, password)
assert check_credentials is not None # help mypy
def process_request(
connection: ServerConnection,
request: Request,
) -> Response | None:
"""
Perform HTTP Basic Authentication.
If it succeeds, set the connection's ``username`` attribute and return
:obj:`None`. If it fails, return an HTTP 401 Unauthorized responss.
"""
try:
authorization = request.headers["Authorization"]
except KeyError:
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Missing credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
try:
username, password = parse_authorization_basic(authorization)
except InvalidHeader:
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Unsupported credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
if not check_credentials(username, password):
response = connection.respond(
http.HTTPStatus.UNAUTHORIZED,
"Invalid credentials\n",
)
response.headers["WWW-Authenticate"] = build_www_authenticate_basic(realm)
return response
connection.username = username
return None
return process_request

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import time
__all__ = ["Deadline"]
class Deadline:
"""
Manage timeouts across multiple steps.
Args:
timeout: Time available in seconds or :obj:`None` if there is no limit.
"""
def __init__(self, timeout: float | None) -> None:
self.deadline: float | None
if timeout is None:
self.deadline = None
else:
self.deadline = time.monotonic() + timeout
def timeout(self, *, raise_if_elapsed: bool = True) -> float | None:
"""
Calculate a timeout from a deadline.
Args:
raise_if_elapsed: Whether to raise :exc:`TimeoutError`
if the deadline lapsed.
Raises:
TimeoutError: If the deadline lapsed.
Returns:
Time left in seconds or :obj:`None` if there is no limit.
"""
if self.deadline is None:
return None
timeout = self.deadline - time.monotonic()
if raise_if_elapsed and timeout <= 0:
raise TimeoutError("timed out")
return timeout