updates
This commit is contained in:
@@ -1,39 +1,16 @@
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
import certifi
|
||||
|
||||
from ._compat import set_minimum_tls_version_1_2
|
||||
from ._models import Headers
|
||||
from ._types import CertTypes, HeaderTypes, TimeoutTypes, URLTypes, VerifyTypes
|
||||
from ._types import CertTypes, HeaderTypes, TimeoutTypes
|
||||
from ._urls import URL
|
||||
from ._utils import get_ca_bundle_from_env
|
||||
|
||||
DEFAULT_CIPHERS = ":".join(
|
||||
[
|
||||
"ECDHE+AESGCM",
|
||||
"ECDHE+CHACHA20",
|
||||
"DHE+AESGCM",
|
||||
"DHE+CHACHA20",
|
||||
"ECDH+AESGCM",
|
||||
"DH+AESGCM",
|
||||
"ECDH+AES",
|
||||
"DH+AES",
|
||||
"RSA+AESGCM",
|
||||
"RSA+AES",
|
||||
"!aNULL",
|
||||
"!eNULL",
|
||||
"!MD5",
|
||||
"!DSS",
|
||||
]
|
||||
)
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger("httpx")
|
||||
__all__ = ["Limits", "Proxy", "Timeout", "create_ssl_context"]
|
||||
|
||||
|
||||
class UnsetType:
|
||||
@@ -44,152 +21,52 @@ UNSET = UnsetType()
|
||||
|
||||
|
||||
def create_ssl_context(
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
verify: ssl.SSLContext | str | bool = True,
|
||||
cert: CertTypes | None = None,
|
||||
trust_env: bool = True,
|
||||
http2: bool = False,
|
||||
) -> ssl.SSLContext:
|
||||
return SSLConfig(
|
||||
cert=cert, verify=verify, trust_env=trust_env, http2=http2
|
||||
).ssl_context
|
||||
import ssl
|
||||
import warnings
|
||||
|
||||
import certifi
|
||||
|
||||
class SSLConfig:
|
||||
"""
|
||||
SSL Configuration.
|
||||
"""
|
||||
|
||||
DEFAULT_CA_BUNDLE_PATH = Path(certifi.where())
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cert: typing.Optional[CertTypes] = None,
|
||||
verify: VerifyTypes = True,
|
||||
trust_env: bool = True,
|
||||
http2: bool = False,
|
||||
):
|
||||
self.cert = cert
|
||||
self.verify = verify
|
||||
self.trust_env = trust_env
|
||||
self.http2 = http2
|
||||
self.ssl_context = self.load_ssl_context()
|
||||
|
||||
def load_ssl_context(self) -> ssl.SSLContext:
|
||||
logger.debug(
|
||||
"load_ssl_context verify=%r cert=%r trust_env=%r http2=%r",
|
||||
self.verify,
|
||||
self.cert,
|
||||
self.trust_env,
|
||||
self.http2,
|
||||
)
|
||||
|
||||
if self.verify:
|
||||
return self.load_ssl_context_verify()
|
||||
return self.load_ssl_context_no_verify()
|
||||
|
||||
def load_ssl_context_no_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for unverified connections.
|
||||
"""
|
||||
context = self._create_default_ssl_context()
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
self._load_client_certs(context)
|
||||
return context
|
||||
|
||||
def load_ssl_context_verify(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Return an SSL context for verified connections.
|
||||
"""
|
||||
if self.trust_env and self.verify is True:
|
||||
ca_bundle = get_ca_bundle_from_env()
|
||||
if ca_bundle is not None:
|
||||
self.verify = ca_bundle
|
||||
|
||||
if isinstance(self.verify, ssl.SSLContext):
|
||||
# Allow passing in our own SSLContext object that's pre-configured.
|
||||
context = self.verify
|
||||
self._load_client_certs(context)
|
||||
return context
|
||||
elif isinstance(self.verify, bool):
|
||||
ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
|
||||
elif Path(self.verify).exists():
|
||||
ca_bundle_path = Path(self.verify)
|
||||
if verify is True:
|
||||
if trust_env and os.environ.get("SSL_CERT_FILE"): # pragma: nocover
|
||||
ctx = ssl.create_default_context(cafile=os.environ["SSL_CERT_FILE"])
|
||||
elif trust_env and os.environ.get("SSL_CERT_DIR"): # pragma: nocover
|
||||
ctx = ssl.create_default_context(capath=os.environ["SSL_CERT_DIR"])
|
||||
else:
|
||||
raise IOError(
|
||||
"Could not find a suitable TLS CA certificate bundle, "
|
||||
"invalid path: {}".format(self.verify)
|
||||
)
|
||||
# Default case...
|
||||
ctx = ssl.create_default_context(cafile=certifi.where())
|
||||
elif verify is False:
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
elif isinstance(verify, str): # pragma: nocover
|
||||
message = (
|
||||
"`verify=<str>` is deprecated. "
|
||||
"Use `verify=ssl.create_default_context(cafile=...)` "
|
||||
"or `verify=ssl.create_default_context(capath=...)` instead."
|
||||
)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
if os.path.isdir(verify):
|
||||
return ssl.create_default_context(capath=verify)
|
||||
return ssl.create_default_context(cafile=verify)
|
||||
else:
|
||||
ctx = verify
|
||||
|
||||
context = self._create_default_ssl_context()
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
context.check_hostname = True
|
||||
if cert: # pragma: nocover
|
||||
message = (
|
||||
"`cert=...` is deprecated. Use `verify=<ssl_context>` instead,"
|
||||
"with `.load_cert_chain()` to configure the certificate chain."
|
||||
)
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
if isinstance(cert, str):
|
||||
ctx.load_cert_chain(cert)
|
||||
else:
|
||||
ctx.load_cert_chain(*cert)
|
||||
|
||||
# Signal to server support for PHA in TLS 1.3. Raises an
|
||||
# AttributeError if only read-only access is implemented.
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
try:
|
||||
context.post_handshake_auth = True
|
||||
except AttributeError: # pragma: no cover
|
||||
pass
|
||||
|
||||
# Disable using 'commonName' for SSLContext.check_hostname
|
||||
# when the 'subjectAltName' extension isn't available.
|
||||
try:
|
||||
context.hostname_checks_common_name = False
|
||||
except AttributeError: # pragma: no cover
|
||||
pass
|
||||
|
||||
if ca_bundle_path.is_file():
|
||||
cafile = str(ca_bundle_path)
|
||||
logger.debug("load_verify_locations cafile=%r", cafile)
|
||||
context.load_verify_locations(cafile=cafile)
|
||||
elif ca_bundle_path.is_dir():
|
||||
capath = str(ca_bundle_path)
|
||||
logger.debug("load_verify_locations capath=%r", capath)
|
||||
context.load_verify_locations(capath=capath)
|
||||
|
||||
self._load_client_certs(context)
|
||||
|
||||
return context
|
||||
|
||||
def _create_default_ssl_context(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Creates the default SSLContext object that's used for both verified
|
||||
and unverified connections.
|
||||
"""
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
set_minimum_tls_version_1_2(context)
|
||||
context.options |= ssl.OP_NO_COMPRESSION
|
||||
context.set_ciphers(DEFAULT_CIPHERS)
|
||||
|
||||
if ssl.HAS_ALPN:
|
||||
alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"]
|
||||
context.set_alpn_protocols(alpn_idents)
|
||||
|
||||
if sys.version_info >= (3, 8): # pragma: no cover
|
||||
keylogfile = os.environ.get("SSLKEYLOGFILE")
|
||||
if keylogfile and self.trust_env:
|
||||
context.keylog_filename = keylogfile
|
||||
|
||||
return context
|
||||
|
||||
def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
|
||||
"""
|
||||
Loads client certificates into our SSLContext object
|
||||
"""
|
||||
if self.cert is not None:
|
||||
if isinstance(self.cert, str):
|
||||
ssl_context.load_cert_chain(certfile=self.cert)
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 2:
|
||||
ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
|
||||
elif isinstance(self.cert, tuple) and len(self.cert) == 3:
|
||||
ssl_context.load_cert_chain(
|
||||
certfile=self.cert[0],
|
||||
keyfile=self.cert[1],
|
||||
password=self.cert[2], # type: ignore
|
||||
)
|
||||
return ctx
|
||||
|
||||
|
||||
class Timeout:
|
||||
@@ -208,13 +85,13 @@ class Timeout:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
|
||||
timeout: TimeoutTypes | UnsetType = UNSET,
|
||||
*,
|
||||
connect: typing.Union[None, float, UnsetType] = UNSET,
|
||||
read: typing.Union[None, float, UnsetType] = UNSET,
|
||||
write: typing.Union[None, float, UnsetType] = UNSET,
|
||||
pool: typing.Union[None, float, UnsetType] = UNSET,
|
||||
):
|
||||
connect: None | float | UnsetType = UNSET,
|
||||
read: None | float | UnsetType = UNSET,
|
||||
write: None | float | UnsetType = UNSET,
|
||||
pool: None | float | UnsetType = UNSET,
|
||||
) -> None:
|
||||
if isinstance(timeout, Timeout):
|
||||
# Passed as a single explicit Timeout.
|
||||
assert connect is UNSET
|
||||
@@ -252,7 +129,7 @@ class Timeout:
|
||||
self.write = timeout if isinstance(write, UnsetType) else write
|
||||
self.pool = timeout if isinstance(pool, UnsetType) else pool
|
||||
|
||||
def as_dict(self) -> typing.Dict[str, typing.Optional[float]]:
|
||||
def as_dict(self) -> dict[str, float | None]:
|
||||
return {
|
||||
"connect": self.connect,
|
||||
"read": self.read,
|
||||
@@ -296,10 +173,10 @@ class Limits:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_connections: typing.Optional[int] = None,
|
||||
max_keepalive_connections: typing.Optional[int] = None,
|
||||
keepalive_expiry: typing.Optional[float] = 5.0,
|
||||
):
|
||||
max_connections: int | None = None,
|
||||
max_keepalive_connections: int | None = None,
|
||||
keepalive_expiry: float | None = 5.0,
|
||||
) -> None:
|
||||
self.max_connections = max_connections
|
||||
self.max_keepalive_connections = max_keepalive_connections
|
||||
self.keepalive_expiry = keepalive_expiry
|
||||
@@ -324,15 +201,16 @@ class Limits:
|
||||
class Proxy:
|
||||
def __init__(
|
||||
self,
|
||||
url: URLTypes,
|
||||
url: URL | str,
|
||||
*,
|
||||
auth: typing.Optional[typing.Tuple[str, str]] = None,
|
||||
headers: typing.Optional[HeaderTypes] = None,
|
||||
):
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
auth: tuple[str, str] | None = None,
|
||||
headers: HeaderTypes | None = None,
|
||||
) -> None:
|
||||
url = URL(url)
|
||||
headers = Headers(headers)
|
||||
|
||||
if url.scheme not in ("http", "https", "socks5"):
|
||||
if url.scheme not in ("http", "https", "socks5", "socks5h"):
|
||||
raise ValueError(f"Unknown scheme for proxy URL {url!r}")
|
||||
|
||||
if url.username or url.password:
|
||||
@@ -343,9 +221,10 @@ class Proxy:
|
||||
self.url = url
|
||||
self.auth = auth
|
||||
self.headers = headers
|
||||
self.ssl_context = ssl_context
|
||||
|
||||
@property
|
||||
def raw_auth(self) -> typing.Optional[typing.Tuple[bytes, bytes]]:
|
||||
def raw_auth(self) -> tuple[bytes, bytes] | None:
|
||||
# The proxy authentication as raw bytes.
|
||||
return (
|
||||
None
|
||||
|
||||
Reference in New Issue
Block a user