updates
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import netrc
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
@@ -8,13 +9,16 @@ from base64 import b64encode
|
||||
from urllib.request import parse_http_list
|
||||
|
||||
from ._exceptions import ProtocolError
|
||||
from ._models import Request, Response
|
||||
from ._models import Cookies, Request, Response
|
||||
from ._utils import to_bytes, to_str, unquote
|
||||
|
||||
if typing.TYPE_CHECKING: # pragma: no cover
|
||||
from hashlib import _Hash
|
||||
|
||||
|
||||
__all__ = ["Auth", "BasicAuth", "DigestAuth", "NetRCAuth"]
|
||||
|
||||
|
||||
class Auth:
|
||||
"""
|
||||
Base class for all authentication schemes.
|
||||
@@ -125,18 +129,14 @@ class BasicAuth(Auth):
|
||||
and uses HTTP Basic authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
):
|
||||
def __init__(self, username: str | bytes, password: str | bytes) -> None:
|
||||
self._auth_header = self._build_auth_header(username, password)
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
request.headers["Authorization"] = self._auth_header
|
||||
yield request
|
||||
|
||||
def _build_auth_header(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> str:
|
||||
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
|
||||
userpass = b":".join((to_bytes(username), to_bytes(password)))
|
||||
token = b64encode(userpass).decode()
|
||||
return f"Basic {token}"
|
||||
@@ -147,7 +147,11 @@ class NetRCAuth(Auth):
|
||||
Use a 'netrc' file to lookup basic auth credentials based on the url host.
|
||||
"""
|
||||
|
||||
def __init__(self, file: typing.Optional[str] = None):
|
||||
def __init__(self, file: str | None = None) -> None:
|
||||
# Lazily import 'netrc'.
|
||||
# There's no need for us to load this module unless 'NetRCAuth' is being used.
|
||||
import netrc
|
||||
|
||||
self._netrc_info = netrc.netrc(file)
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
@@ -162,16 +166,14 @@ class NetRCAuth(Auth):
|
||||
)
|
||||
yield request
|
||||
|
||||
def _build_auth_header(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> str:
|
||||
def _build_auth_header(self, username: str | bytes, password: str | bytes) -> str:
|
||||
userpass = b":".join((to_bytes(username), to_bytes(password)))
|
||||
token = b64encode(userpass).decode()
|
||||
return f"Basic {token}"
|
||||
|
||||
|
||||
class DigestAuth(Auth):
|
||||
_ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable[[bytes], "_Hash"]] = {
|
||||
_ALGORITHM_TO_HASH_FUNCTION: dict[str, typing.Callable[[bytes], _Hash]] = {
|
||||
"MD5": hashlib.md5,
|
||||
"MD5-SESS": hashlib.md5,
|
||||
"SHA": hashlib.sha1,
|
||||
@@ -182,12 +184,10 @@ class DigestAuth(Auth):
|
||||
"SHA-512-SESS": hashlib.sha512,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
|
||||
) -> None:
|
||||
def __init__(self, username: str | bytes, password: str | bytes) -> None:
|
||||
self._username = to_bytes(username)
|
||||
self._password = to_bytes(password)
|
||||
self._last_challenge: typing.Optional[_DigestAuthChallenge] = None
|
||||
self._last_challenge: _DigestAuthChallenge | None = None
|
||||
self._nonce_count = 1
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
@@ -217,11 +217,13 @@ class DigestAuth(Auth):
|
||||
request.headers["Authorization"] = self._build_auth_header(
|
||||
request, self._last_challenge
|
||||
)
|
||||
if response.cookies:
|
||||
Cookies(response.cookies).set_cookie_header(request=request)
|
||||
yield request
|
||||
|
||||
def _parse_challenge(
|
||||
self, request: Request, response: Response, auth_header: str
|
||||
) -> "_DigestAuthChallenge":
|
||||
) -> _DigestAuthChallenge:
|
||||
"""
|
||||
Returns a challenge from a Digest WWW-Authenticate header.
|
||||
These take the form of:
|
||||
@@ -232,7 +234,7 @@ class DigestAuth(Auth):
|
||||
# This method should only ever have been called with a Digest auth header.
|
||||
assert scheme.lower() == "digest"
|
||||
|
||||
header_dict: typing.Dict[str, str] = {}
|
||||
header_dict: dict[str, str] = {}
|
||||
for field in parse_http_list(fields):
|
||||
key, value = field.strip().split("=", 1)
|
||||
header_dict[key] = unquote(value)
|
||||
@@ -251,7 +253,7 @@ class DigestAuth(Auth):
|
||||
raise ProtocolError(message, request=request) from exc
|
||||
|
||||
def _build_auth_header(
|
||||
self, request: Request, challenge: "_DigestAuthChallenge"
|
||||
self, request: Request, challenge: _DigestAuthChallenge
|
||||
) -> str:
|
||||
hash_func = self._ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm.upper()]
|
||||
|
||||
@@ -275,17 +277,18 @@ class DigestAuth(Auth):
|
||||
|
||||
qop = self._resolve_qop(challenge.qop, request=request)
|
||||
if qop is None:
|
||||
# Following RFC 2069
|
||||
digest_data = [HA1, challenge.nonce, HA2]
|
||||
else:
|
||||
digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
|
||||
key_digest = b":".join(digest_data)
|
||||
# Following RFC 2617/7616
|
||||
digest_data = [HA1, challenge.nonce, nc_value, cnonce, qop, HA2]
|
||||
|
||||
format_args = {
|
||||
"username": self._username,
|
||||
"realm": challenge.realm,
|
||||
"nonce": challenge.nonce,
|
||||
"uri": path,
|
||||
"response": digest(b":".join((HA1, key_digest))),
|
||||
"response": digest(b":".join(digest_data)),
|
||||
"algorithm": challenge.algorithm.encode(),
|
||||
}
|
||||
if challenge.opaque:
|
||||
@@ -305,7 +308,7 @@ class DigestAuth(Auth):
|
||||
|
||||
return hashlib.sha1(s).hexdigest()[:16].encode()
|
||||
|
||||
def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
|
||||
def _get_header_value(self, header_fields: dict[str, bytes]) -> str:
|
||||
NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
|
||||
QUOTED_TEMPLATE = '{}="{}"'
|
||||
NON_QUOTED_TEMPLATE = "{}={}"
|
||||
@@ -323,9 +326,7 @@ class DigestAuth(Auth):
|
||||
|
||||
return header_value
|
||||
|
||||
def _resolve_qop(
|
||||
self, qop: typing.Optional[bytes], request: Request
|
||||
) -> typing.Optional[bytes]:
|
||||
def _resolve_qop(self, qop: bytes | None, request: Request) -> bytes | None:
|
||||
if qop is None:
|
||||
return None
|
||||
qops = re.split(b", ?", qop)
|
||||
@@ -343,5 +344,5 @@ class _DigestAuthChallenge(typing.NamedTuple):
|
||||
realm: bytes
|
||||
nonce: bytes
|
||||
algorithm: str
|
||||
opaque: typing.Optional[bytes]
|
||||
qop: typing.Optional[bytes]
|
||||
opaque: bytes | None
|
||||
qop: bytes | None
|
||||
|
||||
Reference in New Issue
Block a user