This commit is contained in:
Iliyan Angelov
2025-11-23 18:59:18 +02:00
parent be07802066
commit 627959f52b
1840 changed files with 236564 additions and 3475 deletions

View File

@@ -1,17 +1,6 @@
# This contains the main Connection class. Everything in h11 revolves around
# this.
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
overload,
Tuple,
Type,
Union,
)
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
from ._events import (
ConnectionClosed,
@@ -68,7 +57,6 @@ class PAUSED(Sentinel, metaclass=Sentinel):
# - Apache: <8 KiB per line>
DEFAULT_MAX_INCOMPLETE_EVENT_SIZE = 16 * 1024
# RFC 7230's rules for connection lifecycles:
# - If either side says they want to close the connection, then the connection
# must close.
@@ -172,7 +160,7 @@ class Connection:
self._max_incomplete_event_size = max_incomplete_event_size
# State and role tracking
if our_role not in (CLIENT, SERVER):
raise ValueError(f"expected CLIENT or SERVER, not {our_role!r}")
raise ValueError("expected CLIENT or SERVER, not {!r}".format(our_role))
self.our_role = our_role
self.their_role: Type[Sentinel]
if our_role is CLIENT:
@@ -428,7 +416,7 @@ class Connection:
# return that event, and then the state will change and we'll
# get called again to generate the actual ConnectionClosed().
if hasattr(self._reader, "read_eof"):
event = self._reader.read_eof()
event = self._reader.read_eof() # type: ignore[attr-defined]
else:
event = ConnectionClosed()
if event is None:
@@ -500,20 +488,6 @@ class Connection:
else:
raise
@overload
def send(self, event: ConnectionClosed) -> None:
...
@overload
def send(
self, event: Union[Request, InformationalResponse, Response, Data, EndOfMessage]
) -> bytes:
...
@overload
def send(self, event: Event) -> Optional[bytes]:
...
def send(self, event: Event) -> Optional[bytes]:
"""Convert a high-level event into bytes that can be sent to the peer,
while updating our internal state machine.

View File

@@ -7,8 +7,8 @@
import re
from abc import ABC
from dataclasses import dataclass
from typing import List, Tuple, Union
from dataclasses import dataclass, field
from typing import Any, cast, Dict, List, Tuple, Union
from ._abnf import method, request_target
from ._headers import Headers, normalize_and_validate

View File

@@ -12,8 +12,6 @@ try:
except ImportError:
from typing_extensions import Literal # type: ignore
CONTENT_LENGTH_MAX_DIGITS = 20 # allow up to 1 billion TB - 1
# Facts
# -----
@@ -175,8 +173,6 @@ def normalize_and_validate(
raise LocalProtocolError("conflicting Content-Length headers")
value = lengths.pop()
validate(_content_length_re, value, "bad Content-Length")
if len(value) > CONTENT_LENGTH_MAX_DIGITS:
raise LocalProtocolError("bad Content-Length")
if seen_content_length is None:
seen_content_length = value
new_headers.append((raw_name, name, value))

View File

@@ -148,9 +148,10 @@ chunk_header_re = re.compile(chunk_header.encode("ascii"))
class ChunkedReader:
def __init__(self) -> None:
self._bytes_in_chunk = 0
# After reading a chunk, we have to throw away the trailing \r\n.
# This tracks the bytes that we need to match and throw away.
self._bytes_to_discard = b""
# After reading a chunk, we have to throw away the trailing \r\n; if
# this is >0 then we discard that many bytes before resuming regular
# de-chunkification.
self._bytes_to_discard = 0
self._reading_trailer = False
def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]:
@@ -159,19 +160,15 @@ class ChunkedReader:
if lines is None:
return None
return EndOfMessage(headers=list(_decode_header_lines(lines)))
if self._bytes_to_discard:
data = buf.maybe_extract_at_most(len(self._bytes_to_discard))
if self._bytes_to_discard > 0:
data = buf.maybe_extract_at_most(self._bytes_to_discard)
if data is None:
return None
if data != self._bytes_to_discard[: len(data)]:
raise LocalProtocolError(
f"malformed chunk footer: {data!r} (expected {self._bytes_to_discard!r})"
)
self._bytes_to_discard = self._bytes_to_discard[len(data) :]
if self._bytes_to_discard:
self._bytes_to_discard -= len(data)
if self._bytes_to_discard > 0:
return None
# else, fall through and read some more
assert self._bytes_to_discard == b""
assert self._bytes_to_discard == 0
if self._bytes_in_chunk == 0:
# We need to refill our chunk count
chunk_header = buf.maybe_extract_next_line()
@@ -197,7 +194,7 @@ class ChunkedReader:
return None
self._bytes_in_chunk -= len(data)
if self._bytes_in_chunk == 0:
self._bytes_to_discard = b"\r\n"
self._bytes_to_discard = 2
chunk_end = True
else:
chunk_end = False

View File

@@ -283,7 +283,9 @@ class ConnectionState:
assert role is SERVER
if server_switch_event not in self.pending_switch_proposals:
raise LocalProtocolError(
"Received server _SWITCH_UPGRADE event without a pending proposal"
"Received server {} event without a pending proposal".format(
server_switch_event
)
)
_event_type = (event_type, server_switch_event)
if server_switch_event is None and _event_type is Response:
@@ -356,7 +358,7 @@ class ConnectionState:
def start_next_cycle(self) -> None:
if self.states != {CLIENT: DONE, SERVER: DONE}:
raise LocalProtocolError(
f"not in a reusable state. self.states={self.states}"
"not in a reusable state. self.states={}".format(self.states)
)
# Can't reach DONE/DONE with any of these active, but still, let's be
# sure.

View File

@@ -13,4 +13,4 @@
# want. (Contrast with the special suffix 1.0.0.dev, which sorts *before*
# 1.0.0.)
__version__ = "0.16.0"
__version__ = "0.14.0"

View File

@@ -0,0 +1 @@
92b12bc045050b55b848d37167a1a63947c364579889ce1d39788e45e9fac9e5

View File

@@ -0,0 +1,101 @@
from typing import cast, List, Type, Union, ValuesView
from .._connection import Connection, NEED_DATA, PAUSED
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER
from .._util import Sentinel
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
def get_all_events(conn: Connection) -> List[Event]:
got_events = []
while True:
event = conn.next_event()
if event in (NEED_DATA, PAUSED):
break
event = cast(Event, event)
got_events.append(event)
if type(event) is ConnectionClosed:
break
return got_events
def receive_and_get(conn: Connection, data: bytes) -> List[Event]:
conn.receive_data(data)
return get_all_events(conn)
# Merges adjacent Data events, converts payloads to bytestrings, and removes
# chunk boundaries.
def normalize_data_events(in_events: List[Event]) -> List[Event]:
out_events: List[Event] = []
for event in in_events:
if type(event) is Data:
event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False)
if out_events and type(out_events[-1]) is type(event) is Data:
out_events[-1] = Data(
data=out_events[-1].data + event.data,
chunk_start=out_events[-1].chunk_start,
chunk_end=out_events[-1].chunk_end,
)
else:
out_events.append(event)
return out_events
# Given that we want to write tests that push some events through a Connection
# and check that its state updates appropriately... we might as make a habit
# of pushing them through two Connections with a fake network link in
# between.
class ConnectionPair:
def __init__(self) -> None:
self.conn = {CLIENT: Connection(CLIENT), SERVER: Connection(SERVER)}
self.other = {CLIENT: SERVER, SERVER: CLIENT}
@property
def conns(self) -> ValuesView[Connection]:
return self.conn.values()
# expect="match" if expect=send_events; expect=[...] to say what expected
def send(
self,
role: Type[Sentinel],
send_events: Union[List[Event], Event],
expect: Union[List[Event], Event, Literal["match"]] = "match",
) -> bytes:
if not isinstance(send_events, list):
send_events = [send_events]
data = b""
closed = False
for send_event in send_events:
new_data = self.conn[role].send(send_event)
if new_data is None:
closed = True
else:
data += new_data
# send uses b"" to mean b"", and None to mean closed
# receive uses b"" to mean closed, and None to mean "try again"
# so we have to translate between the two conventions
if data:
self.conn[self.other[role]].receive_data(data)
if closed:
self.conn[self.other[role]].receive_data(b"")
got_events = get_all_events(self.conn[self.other[role]])
if expect == "match":
expect = send_events
if not isinstance(expect, list):
expect = [expect]
assert got_events == expect
return data

View File

@@ -0,0 +1,115 @@
import json
import os.path
import socket
import socketserver
import threading
from contextlib import closing, contextmanager
from http.server import SimpleHTTPRequestHandler
from typing import Callable, Generator
from urllib.request import urlopen
import h11
@contextmanager
def socket_server(
handler: Callable[..., socketserver.BaseRequestHandler]
) -> Generator[socketserver.TCPServer, None, None]:
httpd = socketserver.TCPServer(("127.0.0.1", 0), handler)
thread = threading.Thread(
target=httpd.serve_forever, kwargs={"poll_interval": 0.01}
)
thread.daemon = True
try:
thread.start()
yield httpd
finally:
httpd.shutdown()
test_file_path = os.path.join(os.path.dirname(__file__), "data/test-file")
with open(test_file_path, "rb") as f:
test_file_data = f.read()
class SingleMindedRequestHandler(SimpleHTTPRequestHandler):
def translate_path(self, path: str) -> str:
return test_file_path
def test_h11_as_client() -> None:
with socket_server(SingleMindedRequestHandler) as httpd:
with closing(socket.create_connection(httpd.server_address)) as s:
c = h11.Connection(h11.CLIENT)
s.sendall(
c.send( # type: ignore[arg-type]
h11.Request(
method="GET", target="/foo", headers=[("Host", "localhost")]
)
)
)
s.sendall(c.send(h11.EndOfMessage())) # type: ignore[arg-type]
data = bytearray()
while True:
event = c.next_event()
print(event)
if event is h11.NEED_DATA:
# Use a small read buffer to make things more challenging
# and exercise more paths :-)
c.receive_data(s.recv(10))
continue
if type(event) is h11.Response:
assert event.status_code == 200
if type(event) is h11.Data:
data += event.data
if type(event) is h11.EndOfMessage:
break
assert bytes(data) == test_file_data
class H11RequestHandler(socketserver.BaseRequestHandler):
def handle(self) -> None:
with closing(self.request) as s:
c = h11.Connection(h11.SERVER)
request = None
while True:
event = c.next_event()
if event is h11.NEED_DATA:
# Use a small read buffer to make things more challenging
# and exercise more paths :-)
c.receive_data(s.recv(10))
continue
if type(event) is h11.Request:
request = event
if type(event) is h11.EndOfMessage:
break
assert request is not None
info = json.dumps(
{
"method": request.method.decode("ascii"),
"target": request.target.decode("ascii"),
"headers": {
name.decode("ascii"): value.decode("ascii")
for (name, value) in request.headers
},
}
)
s.sendall(c.send(h11.Response(status_code=200, headers=[]))) # type: ignore[arg-type]
s.sendall(c.send(h11.Data(data=info.encode("ascii"))))
s.sendall(c.send(h11.EndOfMessage()))
def test_h11_as_server() -> None:
with socket_server(H11RequestHandler) as httpd:
host, port = httpd.server_address
url = "http://{}:{}/some-path".format(host, port)
with closing(urlopen(url)) as f:
assert f.getcode() == 200
data = f.read()
info = json.loads(data.decode("ascii"))
print(info)
assert info["method"] == "GET"
assert info["target"] == "/some-path"
assert "urllib" in info["headers"]["user-agent"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,150 @@
from http import HTTPStatus
import pytest
from .. import _events
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._util import LocalProtocolError
def test_events() -> None:
with pytest.raises(LocalProtocolError):
# Missing Host:
req = Request(
method="GET", target="/", headers=[("a", "b")], http_version="1.1"
)
# But this is okay (HTTP/1.0)
req = Request(method="GET", target="/", headers=[("a", "b")], http_version="1.0")
# fields are normalized
assert req.method == b"GET"
assert req.target == b"/"
assert req.headers == [(b"a", b"b")]
assert req.http_version == b"1.0"
# This is also okay -- has a Host (with weird capitalization, which is ok)
req = Request(
method="GET",
target="/",
headers=[("a", "b"), ("hOSt", "example.com")],
http_version="1.1",
)
# we normalize header capitalization
assert req.headers == [(b"a", b"b"), (b"host", b"example.com")]
# Multiple host is bad too
with pytest.raises(LocalProtocolError):
req = Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Host", "a")],
http_version="1.1",
)
# Even for HTTP/1.0
with pytest.raises(LocalProtocolError):
req = Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Host", "a")],
http_version="1.0",
)
# Header values are validated
for bad_char in "\x00\r\n\f\v":
with pytest.raises(LocalProtocolError):
req = Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Foo", "asd" + bad_char)],
http_version="1.0",
)
# But for compatibility we allow non-whitespace control characters, even
# though they're forbidden by the spec.
Request(
method="GET",
target="/",
headers=[("Host", "a"), ("Foo", "asd\x01\x02\x7f")],
http_version="1.0",
)
# Request target is validated
for bad_byte in b"\x00\x20\x7f\xee":
target = bytearray(b"/")
target.append(bad_byte)
with pytest.raises(LocalProtocolError):
Request(
method="GET", target=target, headers=[("Host", "a")], http_version="1.1"
)
# Request method is validated
with pytest.raises(LocalProtocolError):
Request(
method="GET / HTTP/1.1",
target=target,
headers=[("Host", "a")],
http_version="1.1",
)
ir = InformationalResponse(status_code=100, headers=[("Host", "a")])
assert ir.status_code == 100
assert ir.headers == [(b"host", b"a")]
assert ir.http_version == b"1.1"
with pytest.raises(LocalProtocolError):
InformationalResponse(status_code=200, headers=[("Host", "a")])
resp = Response(status_code=204, headers=[], http_version="1.0") # type: ignore[arg-type]
assert resp.status_code == 204
assert resp.headers == []
assert resp.http_version == b"1.0"
with pytest.raises(LocalProtocolError):
resp = Response(status_code=100, headers=[], http_version="1.0") # type: ignore[arg-type]
with pytest.raises(LocalProtocolError):
Response(status_code="100", headers=[], http_version="1.0") # type: ignore[arg-type]
with pytest.raises(LocalProtocolError):
InformationalResponse(status_code=b"100", headers=[], http_version="1.0") # type: ignore[arg-type]
d = Data(data=b"asdf")
assert d.data == b"asdf"
eom = EndOfMessage()
assert eom.headers == []
cc = ConnectionClosed()
assert repr(cc) == "ConnectionClosed()"
def test_intenum_status_code() -> None:
# https://github.com/python-hyper/h11/issues/72
r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") # type: ignore[arg-type]
assert r.status_code == HTTPStatus.OK
assert type(r.status_code) is not type(HTTPStatus.OK)
assert type(r.status_code) is int
def test_header_casing() -> None:
r = Request(
method="GET",
target="/",
headers=[("Host", "example.org"), ("Connection", "keep-alive")],
http_version="1.1",
)
assert len(r.headers) == 2
assert r.headers[0] == (b"host", b"example.org")
assert r.headers == [(b"host", b"example.org"), (b"connection", b"keep-alive")]
assert r.headers.raw_items() == [
(b"Host", b"example.org"),
(b"Connection", b"keep-alive"),
]

View File

@@ -0,0 +1,157 @@
import pytest
from .._events import Request
from .._headers import (
get_comma_header,
has_expect_100_continue,
Headers,
normalize_and_validate,
set_comma_header,
)
from .._util import LocalProtocolError
def test_normalize_and_validate() -> None:
assert normalize_and_validate([("foo", "bar")]) == [(b"foo", b"bar")]
assert normalize_and_validate([(b"foo", b"bar")]) == [(b"foo", b"bar")]
# no leading/trailing whitespace in names
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo ", "bar")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b" foo", "bar")])
# no weird characters in names
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate([(b"foo bar", b"baz")])
assert "foo bar" in str(excinfo.value)
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo\x00bar", b"baz")])
# Not even 8-bit characters:
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo\xffbar", b"baz")])
# And not even the control characters we allow in values:
with pytest.raises(LocalProtocolError):
normalize_and_validate([(b"foo\x01bar", b"baz")])
# no return or NUL characters in values
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate([("foo", "bar\rbaz")])
assert "bar\\rbaz" in str(excinfo.value)
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "bar\nbaz")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "bar\x00baz")])
# no leading/trailing whitespace
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "barbaz ")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", " barbaz")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "barbaz\t")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("foo", "\tbarbaz")])
# content-length
assert normalize_and_validate([("Content-Length", "1")]) == [
(b"content-length", b"1")
]
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "asdf")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "1x")])
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "1"), ("Content-Length", "2")])
assert normalize_and_validate(
[("Content-Length", "0"), ("Content-Length", "0")]
) == [(b"content-length", b"0")]
assert normalize_and_validate([("Content-Length", "0 , 0")]) == [
(b"content-length", b"0")
]
with pytest.raises(LocalProtocolError):
normalize_and_validate(
[("Content-Length", "1"), ("Content-Length", "1"), ("Content-Length", "2")]
)
with pytest.raises(LocalProtocolError):
normalize_and_validate([("Content-Length", "1 , 1,2")])
# transfer-encoding
assert normalize_and_validate([("Transfer-Encoding", "chunked")]) == [
(b"transfer-encoding", b"chunked")
]
assert normalize_and_validate([("Transfer-Encoding", "cHuNkEd")]) == [
(b"transfer-encoding", b"chunked")
]
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate([("Transfer-Encoding", "gzip")])
assert excinfo.value.error_status_hint == 501 # Not Implemented
with pytest.raises(LocalProtocolError) as excinfo:
normalize_and_validate(
[("Transfer-Encoding", "chunked"), ("Transfer-Encoding", "gzip")]
)
assert excinfo.value.error_status_hint == 501 # Not Implemented
def test_get_set_comma_header() -> None:
headers = normalize_and_validate(
[
("Connection", "close"),
("whatever", "something"),
("connectiON", "fOo,, , BAR"),
]
)
assert get_comma_header(headers, b"connection") == [b"close", b"foo", b"bar"]
headers = set_comma_header(headers, b"newthing", ["a", "b"]) # type: ignore
with pytest.raises(LocalProtocolError):
set_comma_header(headers, b"newthing", [" a", "b"]) # type: ignore
assert headers == [
(b"connection", b"close"),
(b"whatever", b"something"),
(b"connection", b"fOo,, , BAR"),
(b"newthing", b"a"),
(b"newthing", b"b"),
]
headers = set_comma_header(headers, b"whatever", ["different thing"]) # type: ignore
assert headers == [
(b"connection", b"close"),
(b"connection", b"fOo,, , BAR"),
(b"newthing", b"a"),
(b"newthing", b"b"),
(b"whatever", b"different thing"),
]
def test_has_100_continue() -> None:
assert has_expect_100_continue(
Request(
method="GET",
target="/",
headers=[("Host", "example.com"), ("Expect", "100-continue")],
)
)
assert not has_expect_100_continue(
Request(method="GET", target="/", headers=[("Host", "example.com")])
)
# Case insensitive
assert has_expect_100_continue(
Request(
method="GET",
target="/",
headers=[("Host", "example.com"), ("Expect", "100-Continue")],
)
)
# Doesn't work in HTTP/1.0
assert not has_expect_100_continue(
Request(
method="GET",
target="/",
headers=[("Host", "example.com"), ("Expect", "100-continue")],
http_version="1.0",
)
)

View File

@@ -0,0 +1,32 @@
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .helpers import normalize_data_events
def test_normalize_data_events() -> None:
assert normalize_data_events(
[
Data(data=bytearray(b"1")),
Data(data=b"2"),
Response(status_code=200, headers=[]), # type: ignore[arg-type]
Data(data=b"3"),
Data(data=b"4"),
EndOfMessage(),
Data(data=b"5"),
Data(data=b"6"),
Data(data=b"7"),
]
) == [
Data(data=b"12"),
Response(status_code=200, headers=[]), # type: ignore[arg-type]
Data(data=b"34"),
EndOfMessage(),
Data(data=b"567"),
]

View File

@@ -0,0 +1,572 @@
from typing import Any, Callable, Generator, List
import pytest
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._headers import Headers, normalize_and_validate
from .._readers import (
_obsolete_line_fold,
ChunkedReader,
ContentLengthReader,
Http10Reader,
READERS,
)
from .._receivebuffer import ReceiveBuffer
from .._state import (
CLIENT,
CLOSED,
DONE,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from .._util import LocalProtocolError
from .._writers import (
ChunkedWriter,
ContentLengthWriter,
Http10Writer,
write_any_response,
write_headers,
write_request,
WRITERS,
)
from .helpers import normalize_data_events
SIMPLE_CASES = [
(
(CLIENT, IDLE),
Request(
method="GET",
target="/a",
headers=[("Host", "foo"), ("Connection", "close")],
),
b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"),
b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Response(status_code=200, headers=[], reason=b"OK"), # type: ignore[arg-type]
b"HTTP/1.1 200 OK\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
InformationalResponse(
status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade"
),
b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"), # type: ignore[arg-type]
b"HTTP/1.1 101 Upgrade\r\n\r\n",
),
]
def dowrite(writer: Callable[..., None], obj: Any) -> bytes:
got_list: List[bytes] = []
writer(obj, got_list.append)
return b"".join(got_list)
def tw(writer: Any, obj: Any, expected: Any) -> None:
got = dowrite(writer, obj)
assert got == expected
def makebuf(data: bytes) -> ReceiveBuffer:
buf = ReceiveBuffer()
buf += data
return buf
def tr(reader: Any, data: bytes, expected: Any) -> None:
def check(got: Any) -> None:
assert got == expected
# Headers should always be returned as bytes, not e.g. bytearray
# https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478
for name, value in getattr(got, "headers", []):
assert type(name) is bytes
assert type(value) is bytes
# Simple: consume whole thing
buf = makebuf(data)
check(reader(buf))
assert not buf
# Incrementally growing buffer
buf = ReceiveBuffer()
for i in range(len(data)):
assert reader(buf) is None
buf += data[i : i + 1]
check(reader(buf))
# Trailing data
buf = makebuf(data)
buf += b"trailing"
check(reader(buf))
assert bytes(buf) == b"trailing"
def test_writers_simple() -> None:
for ((role, state), event, binary) in SIMPLE_CASES:
tw(WRITERS[role, state], event, binary)
def test_readers_simple() -> None:
for ((role, state), event, binary) in SIMPLE_CASES:
tr(READERS[role, state], binary, event)
def test_writers_unusual() -> None:
# Simple test of the write_headers utility routine
tw(
write_headers,
normalize_and_validate([("foo", "bar"), ("baz", "quux")]),
b"foo: bar\r\nbaz: quux\r\n\r\n",
)
tw(write_headers, Headers([]), b"\r\n")
# We understand HTTP/1.0, but we don't speak it
with pytest.raises(LocalProtocolError):
tw(
write_request,
Request(
method="GET",
target="/",
headers=[("Host", "foo"), ("Connection", "close")],
http_version="1.0",
),
None,
)
with pytest.raises(LocalProtocolError):
tw(
write_any_response,
Response(
status_code=200, headers=[("Connection", "close")], http_version="1.0"
),
None,
)
def test_readers_unusual() -> None:
# Reading HTTP/1.0
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.0\r\nSome: header\r\n\r\n",
Request(
method="HEAD",
target="/foo",
headers=[("Some", "header")],
http_version="1.0",
),
)
# check no-headers, since it's only legal with HTTP/1.0
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.0\r\n\r\n",
Request(method="HEAD", target="/foo", headers=[], http_version="1.0"), # type: ignore[arg-type]
)
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\nSome: header\r\n\r\n",
Response(
status_code=200,
headers=[("Some", "header")],
http_version="1.0",
reason=b"OK",
),
)
# single-character header values (actually disallowed by the ABNF in RFC
# 7230 -- this is a bug in the standard that we originally copied...)
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\n" b"Foo: a a a a a \r\n\r\n",
Response(
status_code=200,
headers=[("Foo", "a a a a a")],
http_version="1.0",
reason=b"OK",
),
)
# Empty headers -- also legal
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\n" b"Foo:\r\n\r\n",
Response(
status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
),
)
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200 OK\r\n" b"Foo: \t \t \r\n\r\n",
Response(
status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
),
)
# Tolerate broken servers that leave off the response code
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.0 200\r\n" b"Foo: bar\r\n\r\n",
Response(
status_code=200, headers=[("Foo", "bar")], http_version="1.0", reason=b""
),
)
# Tolerate headers line endings (\r\n and \n)
# \n\r\b between headers and body
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.1 200 OK\r\nSomeHeader: val\n\r\n",
Response(
status_code=200,
headers=[("SomeHeader", "val")],
http_version="1.1",
reason="OK",
),
)
# delimited only with \n
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.1 200 OK\nSomeHeader1: val1\nSomeHeader2: val2\n\n",
Response(
status_code=200,
headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
http_version="1.1",
reason="OK",
),
)
# mixed \r\n and \n
tr(
READERS[SERVER, SEND_RESPONSE],
b"HTTP/1.1 200 OK\r\nSomeHeader1: val1\nSomeHeader2: val2\n\r\n",
Response(
status_code=200,
headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
http_version="1.1",
reason="OK",
),
)
# obsolete line folding
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Some: multi-line\r\n"
b" header\r\n"
b"\tnonsense\r\n"
b" \t \t\tI guess\r\n"
b"Connection: close\r\n"
b"More-nonsense: in the\r\n"
b" last header \r\n\r\n",
Request(
method="HEAD",
target="/foo",
headers=[
("Host", "example.com"),
("Some", "multi-line header nonsense I guess"),
("Connection", "close"),
("More-nonsense", "in the last header"),
],
),
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b" folded: line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"foo : line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"foo\t: line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"foo\t: line\r\n\r\n",
None,
)
with pytest.raises(LocalProtocolError):
tr(READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" b": line\r\n\r\n", None)
def test__obsolete_line_fold_bytes() -> None:
# _obsolete_line_fold has a defensive cast to bytearray, which is
# necessary to protect against O(n^2) behavior in case anyone ever passes
# in regular bytestrings... but right now we never pass in regular
# bytestrings. so this test just exists to get some coverage on that
# defensive cast.
assert list(_obsolete_line_fold([b"aaa", b"bbb", b" ccc", b"ddd"])) == [
b"aaa",
bytearray(b"bbb ccc"),
b"ddd",
]
def _run_reader_iter(
reader: Any, buf: bytes, do_eof: bool
) -> Generator[Any, None, None]:
while True:
event = reader(buf)
if event is None:
break
yield event
# body readers have undefined behavior after returning EndOfMessage,
# because this changes the state so they don't get called again
if type(event) is EndOfMessage:
break
if do_eof:
assert not buf
yield reader.read_eof()
def _run_reader(*args: Any) -> List[Event]:
events = list(_run_reader_iter(*args))
return normalize_data_events(events)
def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None:
# Simple: consume whole thing
print("Test 1")
buf = makebuf(data)
assert _run_reader(thunk(), buf, do_eof) == expected
# Incrementally growing buffer
print("Test 2")
reader = thunk()
buf = ReceiveBuffer()
events = []
for i in range(len(data)):
events += _run_reader(reader, buf, False)
buf += data[i : i + 1]
events += _run_reader(reader, buf, do_eof)
assert normalize_data_events(events) == expected
is_complete = any(type(event) is EndOfMessage for event in expected)
if is_complete and not do_eof:
buf = makebuf(data + b"trailing")
assert _run_reader(thunk(), buf, False) == expected
def test_ContentLengthReader() -> None:
t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()])
t_body_reader(
lambda: ContentLengthReader(10),
b"0123456789",
[Data(data=b"0123456789"), EndOfMessage()],
)
def test_Http10Reader() -> None:
t_body_reader(Http10Reader, b"", [EndOfMessage()], do_eof=True)
t_body_reader(Http10Reader, b"asdf", [Data(data=b"asdf")], do_eof=False)
t_body_reader(
Http10Reader, b"asdf", [Data(data=b"asdf"), EndOfMessage()], do_eof=True
)
def test_ChunkedReader() -> None:
t_body_reader(ChunkedReader, b"0\r\n\r\n", [EndOfMessage()])
t_body_reader(
ChunkedReader,
b"0\r\nSome: header\r\n\r\n",
[EndOfMessage(headers=[("Some", "header")])],
)
t_body_reader(
ChunkedReader,
b"5\r\n01234\r\n"
+ b"10\r\n0123456789abcdef\r\n"
+ b"0\r\n"
+ b"Some: header\r\n\r\n",
[
Data(data=b"012340123456789abcdef"),
EndOfMessage(headers=[("Some", "header")]),
],
)
t_body_reader(
ChunkedReader,
b"5\r\n01234\r\n" + b"10\r\n0123456789abcdef\r\n" + b"0\r\n\r\n",
[Data(data=b"012340123456789abcdef"), EndOfMessage()],
)
# handles upper and lowercase hex
t_body_reader(
ChunkedReader,
b"aA\r\n" + b"x" * 0xAA + b"\r\n" + b"0\r\n\r\n",
[Data(data=b"x" * 0xAA), EndOfMessage()],
)
# refuses arbitrarily long chunk integers
with pytest.raises(LocalProtocolError):
# Technically this is legal HTTP/1.1, but we refuse to process chunk
# sizes that don't fit into 20 characters of hex
t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")])
# refuses garbage in the chunk count
with pytest.raises(LocalProtocolError):
t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None)
# handles (and discards) "chunk extensions" omg wtf
t_body_reader(
ChunkedReader,
b"5; hello=there\r\n"
+ b"xxxxx"
+ b"\r\n"
+ b'0; random="junk"; some=more; canbe=lonnnnngg\r\n\r\n',
[Data(data=b"xxxxx"), EndOfMessage()],
)
t_body_reader(
ChunkedReader,
b"5 \r\n01234\r\n" + b"0\r\n\r\n",
[Data(data=b"01234"), EndOfMessage()],
)
def test_ContentLengthWriter() -> None:
w = ContentLengthWriter(5)
assert dowrite(w, Data(data=b"123")) == b"123"
assert dowrite(w, Data(data=b"45")) == b"45"
assert dowrite(w, EndOfMessage()) == b""
w = ContentLengthWriter(5)
with pytest.raises(LocalProtocolError):
dowrite(w, Data(data=b"123456"))
w = ContentLengthWriter(5)
dowrite(w, Data(data=b"123"))
with pytest.raises(LocalProtocolError):
dowrite(w, Data(data=b"456"))
w = ContentLengthWriter(5)
dowrite(w, Data(data=b"123"))
with pytest.raises(LocalProtocolError):
dowrite(w, EndOfMessage())
w = ContentLengthWriter(5)
dowrite(w, Data(data=b"123")) == b"123"
dowrite(w, Data(data=b"45")) == b"45"
with pytest.raises(LocalProtocolError):
dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
def test_ChunkedWriter() -> None:
w = ChunkedWriter()
assert dowrite(w, Data(data=b"aaa")) == b"3\r\naaa\r\n"
assert dowrite(w, Data(data=b"a" * 20)) == b"14\r\n" + b"a" * 20 + b"\r\n"
assert dowrite(w, Data(data=b"")) == b""
assert dowrite(w, EndOfMessage()) == b"0\r\n\r\n"
assert (
dowrite(w, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")]))
== b"0\r\nEtag: asdf\r\na: b\r\n\r\n"
)
def test_Http10Writer() -> None:
w = Http10Writer()
assert dowrite(w, Data(data=b"1234")) == b"1234"
assert dowrite(w, EndOfMessage()) == b""
with pytest.raises(LocalProtocolError):
dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
def test_reject_garbage_after_request_line() -> None:
with pytest.raises(LocalProtocolError):
tr(READERS[SERVER, SEND_RESPONSE], b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n", None)
def test_reject_garbage_after_response_line() -> None:
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1 xxxxxx\r\n" b"Host: a\r\n\r\n",
None,
)
def test_reject_garbage_in_header_line() -> None:
with pytest.raises(LocalProtocolError):
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n" b"Host: foo\x00bar\r\n\r\n",
None,
)
def test_reject_non_vchar_in_path() -> None:
for bad_char in b"\x00\x20\x7f\xee":
message = bytearray(b"HEAD /")
message.append(bad_char)
message.extend(b" HTTP/1.1\r\nHost: foobar\r\n\r\n")
with pytest.raises(LocalProtocolError):
tr(READERS[CLIENT, IDLE], message, None)
# https://github.com/python-hyper/h11/issues/57
def test_allow_some_garbage_in_cookies() -> None:
tr(
READERS[CLIENT, IDLE],
b"HEAD /foo HTTP/1.1\r\n"
b"Host: foo\r\n"
b"Set-Cookie: ___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900\r\n"
b"\r\n",
Request(
method="HEAD",
target="/foo",
headers=[
("Host", "foo"),
("Set-Cookie", "___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900"),
],
),
)
def test_host_comes_first() -> None:
tw(
write_headers,
normalize_and_validate([("foo", "bar"), ("Host", "example.com")]),
b"Host: example.com\r\nfoo: bar\r\n\r\n",
)

View File

@@ -0,0 +1,135 @@
import re
from typing import Tuple
import pytest
from .._receivebuffer import ReceiveBuffer
def test_receivebuffer() -> None:
b = ReceiveBuffer()
assert not b
assert len(b) == 0
assert bytes(b) == b""
b += b"123"
assert b
assert len(b) == 3
assert bytes(b) == b"123"
assert bytes(b) == b"123"
assert b.maybe_extract_at_most(2) == b"12"
assert b
assert len(b) == 1
assert bytes(b) == b"3"
assert bytes(b) == b"3"
assert b.maybe_extract_at_most(10) == b"3"
assert bytes(b) == b""
assert b.maybe_extract_at_most(10) is None
assert not b
################################################################
# maybe_extract_until_next
################################################################
b += b"123\n456\r\n789\r\n"
assert b.maybe_extract_next_line() == b"123\n456\r\n"
assert bytes(b) == b"789\r\n"
assert b.maybe_extract_next_line() == b"789\r\n"
assert bytes(b) == b""
b += b"12\r"
assert b.maybe_extract_next_line() is None
assert bytes(b) == b"12\r"
b += b"345\n\r"
assert b.maybe_extract_next_line() is None
assert bytes(b) == b"12\r345\n\r"
# here we stopped at the middle of b"\r\n" delimiter
b += b"\n6789aaa123\r\n"
assert b.maybe_extract_next_line() == b"12\r345\n\r\n"
assert b.maybe_extract_next_line() == b"6789aaa123\r\n"
assert b.maybe_extract_next_line() is None
assert bytes(b) == b""
################################################################
# maybe_extract_lines
################################################################
b += b"123\r\na: b\r\nfoo:bar\r\n\r\ntrailing"
lines = b.maybe_extract_lines()
assert lines == [b"123", b"a: b", b"foo:bar"]
assert bytes(b) == b"trailing"
assert b.maybe_extract_lines() is None
b += b"\r\n\r"
assert b.maybe_extract_lines() is None
assert b.maybe_extract_at_most(100) == b"trailing\r\n\r"
assert not b
# Empty body case (as happens at the end of chunked encoding if there are
# no trailing headers, e.g.)
b += b"\r\ntrailing"
assert b.maybe_extract_lines() == []
assert bytes(b) == b"trailing"
@pytest.mark.parametrize(
"data",
[
pytest.param(
(
b"HTTP/1.1 200 OK\r\n",
b"Content-type: text/plain\r\n",
b"Connection: close\r\n",
b"\r\n",
b"Some body",
),
id="with_crlf_delimiter",
),
pytest.param(
(
b"HTTP/1.1 200 OK\n",
b"Content-type: text/plain\n",
b"Connection: close\n",
b"\n",
b"Some body",
),
id="with_lf_only_delimiter",
),
pytest.param(
(
b"HTTP/1.1 200 OK\n",
b"Content-type: text/plain\r\n",
b"Connection: close\n",
b"\n",
b"Some body",
),
id="with_mixed_crlf_and_lf",
),
],
)
def test_receivebuffer_for_invalid_delimiter(data: Tuple[bytes]) -> None:
b = ReceiveBuffer()
for line in data:
b += line
lines = b.maybe_extract_lines()
assert lines == [
b"HTTP/1.1 200 OK",
b"Content-type: text/plain",
b"Connection: close",
]
assert bytes(b) == b"Some body"

View File

@@ -0,0 +1,271 @@
import pytest
from .._events import (
ConnectionClosed,
Data,
EndOfMessage,
Event,
InformationalResponse,
Request,
Response,
)
from .._state import (
_SWITCH_CONNECT,
_SWITCH_UPGRADE,
CLIENT,
CLOSED,
ConnectionState,
DONE,
IDLE,
MIGHT_SWITCH_PROTOCOL,
MUST_CLOSE,
SEND_BODY,
SEND_RESPONSE,
SERVER,
SWITCHED_PROTOCOL,
)
from .._util import LocalProtocolError
def test_ConnectionState() -> None:
cs = ConnectionState()
# Basic event-triggered transitions
assert cs.states == {CLIENT: IDLE, SERVER: IDLE}
cs.process_event(CLIENT, Request)
# The SERVER-Request special case:
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
# Illegal transitions raise an error and nothing happens
with pytest.raises(LocalProtocolError):
cs.process_event(CLIENT, Request)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, InformationalResponse)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, Response)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_BODY}
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, EndOfMessage)
assert cs.states == {CLIENT: DONE, SERVER: DONE}
# State-triggered transition
cs.process_event(SERVER, ConnectionClosed)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: CLOSED}
def test_ConnectionState_keep_alive() -> None:
# keep_alive = False
cs = ConnectionState()
cs.process_event(CLIENT, Request)
cs.process_keep_alive_disabled()
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE}
def test_ConnectionState_keep_alive_in_DONE() -> None:
# Check that if keep_alive is disabled when the CLIENT is already in DONE,
# then this is sufficient to immediately trigger the DONE -> MUST_CLOSE
# transition
cs = ConnectionState()
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
assert cs.states[CLIENT] is DONE
cs.process_keep_alive_disabled()
assert cs.states[CLIENT] is MUST_CLOSE
def test_ConnectionState_switch_denied() -> None:
for switch_type in (_SWITCH_CONNECT, _SWITCH_UPGRADE):
for deny_early in (True, False):
cs = ConnectionState()
cs.process_client_switch_proposal(switch_type)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, Data)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
assert switch_type in cs.pending_switch_proposals
if deny_early:
# before client reaches DONE
cs.process_event(SERVER, Response)
assert not cs.pending_switch_proposals
cs.process_event(CLIENT, EndOfMessage)
if deny_early:
assert cs.states == {CLIENT: DONE, SERVER: SEND_BODY}
else:
assert cs.states == {
CLIENT: MIGHT_SWITCH_PROTOCOL,
SERVER: SEND_RESPONSE,
}
cs.process_event(SERVER, InformationalResponse)
assert cs.states == {
CLIENT: MIGHT_SWITCH_PROTOCOL,
SERVER: SEND_RESPONSE,
}
cs.process_event(SERVER, Response)
assert cs.states == {CLIENT: DONE, SERVER: SEND_BODY}
assert not cs.pending_switch_proposals
_response_type_for_switch = {
_SWITCH_UPGRADE: InformationalResponse,
_SWITCH_CONNECT: Response,
None: Response,
}
def test_ConnectionState_protocol_switch_accepted() -> None:
for switch_event in [_SWITCH_UPGRADE, _SWITCH_CONNECT]:
cs = ConnectionState()
cs.process_client_switch_proposal(switch_event)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, Data)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, InformationalResponse)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
cs.process_event(SERVER, _response_type_for_switch[switch_event], switch_event)
assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL}
def test_ConnectionState_double_protocol_switch() -> None:
# CONNECT + Upgrade is legal! Very silly, but legal. So we support
# it. Because sometimes doing the silly thing is easier than not.
for server_switch in [None, _SWITCH_UPGRADE, _SWITCH_CONNECT]:
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_client_switch_proposal(_SWITCH_CONNECT)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
cs.process_event(
SERVER, _response_type_for_switch[server_switch], server_switch
)
if server_switch is None:
assert cs.states == {CLIENT: DONE, SERVER: SEND_BODY}
else:
assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL}
def test_ConnectionState_inconsistent_protocol_switch() -> None:
for client_switches, server_switch in [
([], _SWITCH_CONNECT),
([], _SWITCH_UPGRADE),
([_SWITCH_UPGRADE], _SWITCH_CONNECT),
([_SWITCH_CONNECT], _SWITCH_UPGRADE),
]:
cs = ConnectionState()
for client_switch in client_switches: # type: ignore[attr-defined]
cs.process_client_switch_proposal(client_switch)
cs.process_event(CLIENT, Request)
with pytest.raises(LocalProtocolError):
cs.process_event(SERVER, Response, server_switch)
def test_ConnectionState_keepalive_protocol_switch_interaction() -> None:
# keep_alive=False + pending_switch_proposals
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_event(CLIENT, Request)
cs.process_keep_alive_disabled()
cs.process_event(CLIENT, Data)
assert cs.states == {CLIENT: SEND_BODY, SERVER: SEND_RESPONSE}
# the protocol switch "wins"
cs.process_event(CLIENT, EndOfMessage)
assert cs.states == {CLIENT: MIGHT_SWITCH_PROTOCOL, SERVER: SEND_RESPONSE}
# but when the server denies the request, keep_alive comes back into play
cs.process_event(SERVER, Response)
assert cs.states == {CLIENT: MUST_CLOSE, SERVER: SEND_BODY}
def test_ConnectionState_reuse() -> None:
cs = ConnectionState()
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
cs.start_next_cycle()
assert cs.states == {CLIENT: IDLE, SERVER: IDLE}
# No keepalive
cs.process_event(CLIENT, Request)
cs.process_keep_alive_disabled()
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
# One side closed
cs = ConnectionState()
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(CLIENT, ConnectionClosed)
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
# Succesful protocol switch
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, InformationalResponse, _SWITCH_UPGRADE)
with pytest.raises(LocalProtocolError):
cs.start_next_cycle()
# Failed protocol switch
cs = ConnectionState()
cs.process_client_switch_proposal(_SWITCH_UPGRADE)
cs.process_event(CLIENT, Request)
cs.process_event(CLIENT, EndOfMessage)
cs.process_event(SERVER, Response)
cs.process_event(SERVER, EndOfMessage)
cs.start_next_cycle()
assert cs.states == {CLIENT: IDLE, SERVER: IDLE}
def test_server_request_is_illegal() -> None:
# There used to be a bug in how we handled the Request special case that
# made this allowed...
cs = ConnectionState()
with pytest.raises(LocalProtocolError):
cs.process_event(SERVER, Request)

View File

@@ -0,0 +1,112 @@
import re
import sys
import traceback
from typing import NoReturn
import pytest
from .._util import (
bytesify,
LocalProtocolError,
ProtocolError,
RemoteProtocolError,
Sentinel,
validate,
)
def test_ProtocolError() -> None:
with pytest.raises(TypeError):
ProtocolError("abstract base class")
def test_LocalProtocolError() -> None:
try:
raise LocalProtocolError("foo")
except LocalProtocolError as e:
assert str(e) == "foo"
assert e.error_status_hint == 400
try:
raise LocalProtocolError("foo", error_status_hint=418)
except LocalProtocolError as e:
assert str(e) == "foo"
assert e.error_status_hint == 418
def thunk() -> NoReturn:
raise LocalProtocolError("a", error_status_hint=420)
try:
try:
thunk()
except LocalProtocolError as exc1:
orig_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
exc1._reraise_as_remote_protocol_error()
except RemoteProtocolError as exc2:
assert type(exc2) is RemoteProtocolError
assert exc2.args == ("a",)
assert exc2.error_status_hint == 420
new_traceback = "".join(traceback.format_tb(sys.exc_info()[2]))
assert new_traceback.endswith(orig_traceback)
def test_validate() -> None:
my_re = re.compile(rb"(?P<group1>[0-9]+)\.(?P<group2>[0-9]+)")
with pytest.raises(LocalProtocolError):
validate(my_re, b"0.")
groups = validate(my_re, b"0.1")
assert groups == {"group1": b"0", "group2": b"1"}
# successful partial matches are an error - must match whole string
with pytest.raises(LocalProtocolError):
validate(my_re, b"0.1xx")
with pytest.raises(LocalProtocolError):
validate(my_re, b"0.1\n")
def test_validate_formatting() -> None:
my_re = re.compile(rb"foo")
with pytest.raises(LocalProtocolError) as excinfo:
validate(my_re, b"", "oops")
assert "oops" in str(excinfo.value)
with pytest.raises(LocalProtocolError) as excinfo:
validate(my_re, b"", "oops {}")
assert "oops {}" in str(excinfo.value)
with pytest.raises(LocalProtocolError) as excinfo:
validate(my_re, b"", "oops {} xx", 10)
assert "oops 10 xx" in str(excinfo.value)
def test_make_sentinel() -> None:
class S(Sentinel, metaclass=Sentinel):
pass
assert repr(S) == "S"
assert S == S
assert type(S).__name__ == "S"
assert S in {S}
assert type(S) is S
class S2(Sentinel, metaclass=Sentinel):
pass
assert repr(S2) == "S2"
assert S != S2
assert S not in {S2}
assert type(S) is not type(S2)
def test_bytesify() -> None:
assert bytesify(b"123") == b"123"
assert bytesify(bytearray(b"123")) == b"123"
assert bytesify("123") == b"123"
with pytest.raises(UnicodeEncodeError):
bytesify("\u1234")
with pytest.raises(TypeError):
bytesify(10)