This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,36 @@
from authlib.oauth1 import SIGNATURE_HMAC_SHA1
from authlib.oauth1 import SIGNATURE_PLAINTEXT
from authlib.oauth1 import SIGNATURE_RSA_SHA1
from authlib.oauth1 import SIGNATURE_TYPE_BODY
from authlib.oauth1 import SIGNATURE_TYPE_HEADER
from authlib.oauth1 import SIGNATURE_TYPE_QUERY
from ..base_client import OAuthError
from .assertion_client import AssertionClient
from .assertion_client import AsyncAssertionClient
from .oauth1_client import AsyncOAuth1Client
from .oauth1_client import OAuth1Auth
from .oauth1_client import OAuth1Client
from .oauth2_client import AsyncOAuth2Client
from .oauth2_client import OAuth2Auth
from .oauth2_client import OAuth2Client
from .oauth2_client import OAuth2ClientAuth
__all__ = [
"OAuthError",
"OAuth1Auth",
"AsyncOAuth1Client",
"OAuth1Client",
"SIGNATURE_HMAC_SHA1",
"SIGNATURE_RSA_SHA1",
"SIGNATURE_PLAINTEXT",
"SIGNATURE_TYPE_HEADER",
"SIGNATURE_TYPE_QUERY",
"SIGNATURE_TYPE_BODY",
"OAuth2Auth",
"OAuth2ClientAuth",
"OAuth2Client",
"AsyncOAuth2Client",
"AssertionClient",
"AsyncAssertionClient",
]

View File

@@ -0,0 +1,124 @@
import httpx
from httpx import USE_CLIENT_DEFAULT
from httpx import Response
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from ..base_client import OAuthError
from .oauth2_client import OAuth2Auth
from .utils import extract_client_kwargs
__all__ = ["AsyncAssertionClient"]
class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(
self,
token_endpoint,
issuer,
subject,
audience=None,
grant_type=None,
claims=None,
token_placement="header",
scope=None,
**kwargs,
):
client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self,
session=None,
token_endpoint=token_endpoint,
issuer=issuer,
subject=subject,
audience=audience,
grant_type=grant_type,
claims=claims,
token_placement=token_placement,
scope=scope,
**kwargs,
)
async def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
) -> Response:
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
await self.refresh_token()
auth = self.token_auth
return await super().request(method, url, auth=auth, **kwargs)
async def _refresh_token(self, data):
resp = await self.request(
"POST", self.token_endpoint, data=data, withhold_token=True
)
return self.parse_response_token(resp)
class AssertionClient(_AssertionClient, httpx.Client):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(
self,
token_endpoint,
issuer,
subject,
audience=None,
grant_type=None,
claims=None,
token_placement="header",
scope=None,
**kwargs,
):
client_kwargs = extract_client_kwargs(kwargs)
# app keyword was dropped!
app_value = client_kwargs.pop("app", None)
if app_value is not None:
client_kwargs["transport"] = httpx.WSGITransport(app=app_value)
httpx.Client.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self,
session=self,
token_endpoint=token_endpoint,
issuer=issuer,
subject=subject,
audience=audience,
grant_type=grant_type,
claims=claims,
token_placement=token_placement,
scope=scope,
**kwargs,
)
def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
self.refresh_token()
auth = self.token_auth
return super().request(method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,145 @@
import typing
import httpx
from httpx import Auth
from httpx import Request
from httpx import Response
from authlib.common.encoding import to_unicode
from authlib.oauth1 import SIGNATURE_HMAC_SHA1
from authlib.oauth1 import SIGNATURE_TYPE_HEADER
from authlib.oauth1 import ClientAuth
from authlib.oauth1.client import OAuth1Client as _OAuth1Client
from ..base_client import OAuthError
from .utils import build_request
from .utils import extract_client_kwargs
class OAuth1Auth(Auth, ClientAuth):
"""Signs the httpx request using OAuth 1 (RFC5849)."""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content
)
headers["Content-Length"] = str(len(body))
yield build_request(
url=url, headers=headers, body=body, initial_request=request
)
class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient):
auth_class = OAuth1Auth
def __init__(
self,
client_id,
client_secret=None,
token=None,
token_secret=None,
redirect_uri=None,
rsa_key=None,
verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False,
**kwargs,
):
_client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self,
None,
client_id=client_id,
client_secret=client_secret,
token=token,
token_secret=token_secret,
redirect_uri=redirect_uri,
rsa_key=rsa_key,
verifier=verifier,
signature_method=signature_method,
signature_type=signature_type,
force_include_body=force_include_body,
**kwargs,
)
async def fetch_access_token(self, url, verifier=None, **kwargs):
"""Method for fetching an access token from the token endpoint.
This is the final step in the OAuth 1 workflow. An access token is
obtained using all previously obtained credentials, including the
verifier from the authorization step.
:param url: Access Token endpoint.
:param verifier: A verifier string to prove authorization was granted.
:param kwargs: Extra parameters to include for fetching access token.
:return: A token dict.
"""
if verifier:
self.auth.verifier = verifier
if not self.auth.verifier:
self.handle_error("missing_verifier", 'Missing "verifier" value')
token = await self._fetch_token(url, **kwargs)
self.auth.verifier = None
return token
async def _fetch_token(self, url, **kwargs):
resp = await self.post(url, **kwargs)
text = await resp.aread()
token = self.parse_response_token(resp.status_code, to_unicode(text))
self.token = token
return token
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
class OAuth1Client(_OAuth1Client, httpx.Client):
auth_class = OAuth1Auth
def __init__(
self,
client_id,
client_secret=None,
token=None,
token_secret=None,
redirect_uri=None,
rsa_key=None,
verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False,
**kwargs,
):
_client_kwargs = extract_client_kwargs(kwargs)
# app keyword was dropped!
app_value = _client_kwargs.pop("app", None)
if app_value is not None:
_client_kwargs["transport"] = httpx.WSGITransport(app=app_value)
httpx.Client.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self,
self,
client_id=client_id,
client_secret=client_secret,
token=token,
token_secret=token_secret,
redirect_uri=redirect_uri,
rsa_key=rsa_key,
verifier=verifier,
signature_method=signature_method,
signature_type=signature_type,
force_include_body=force_include_body,
**kwargs,
)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

View File

@@ -0,0 +1,285 @@
import typing
from contextlib import asynccontextmanager
import httpx
from anyio import Lock # Import after httpx so import errors refer to httpx
from httpx import USE_CLIENT_DEFAULT
from httpx import Auth
from httpx import Request
from httpx import Response
from authlib.common.urls import url_decode
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.auth import TokenAuth
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
from ..base_client import InvalidTokenError
from ..base_client import MissingTokenError
from ..base_client import OAuthError
from ..base_client import UnsupportedTokenTypeError
from .utils import HTTPX_CLIENT_KWARGS
from .utils import build_request
__all__ = [
"OAuth2Auth",
"OAuth2ClientAuth",
"AsyncOAuth2Client",
"OAuth2Client",
]
class OAuth2Auth(Auth, TokenAuth):
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
try:
url, headers, body = self.prepare(
str(request.url), request.headers, request.content
)
headers["Content-Length"] = str(len(body))
yield build_request(
url=url, headers=headers, body=body, initial_request=request
)
except KeyError as error:
description = f"Unsupported token_type: {str(error)}"
raise UnsupportedTokenTypeError(description=description) from error
class OAuth2ClientAuth(Auth, ClientAuth):
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content
)
headers["Content-Length"] = str(len(body))
yield build_request(
url=url, headers=headers, body=body, initial_request=request
)
class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(
self,
client_id=None,
client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None,
redirect_uri=None,
token=None,
token_placement="header",
update_token=None,
leeway=60,
**kwargs,
):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
# We use a Lock to synchronize coroutines to prevent
# multiple concurrent attempts to refresh the same token
self._token_refresh_lock = Lock()
_OAuth2Client.__init__(
self,
session=None,
client_id=client_id,
client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope,
redirect_uri=redirect_uri,
token=token,
token_placement=token_placement,
update_token=update_token,
leeway=leeway,
**kwargs,
)
async def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
return await super().request(method, url, auth=auth, **kwargs)
@asynccontextmanager
async def stream(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
async with super().stream(method, url, auth=auth, **kwargs) as resp:
yield resp
async def ensure_active_token(self, token):
async with self._token_refresh_lock:
if self.token.is_expired(leeway=self.leeway):
refresh_token = token.get("refresh_token")
url = self.metadata.get("token_endpoint")
if refresh_token and url:
await self.refresh_token(url, refresh_token=refresh_token)
elif self.metadata.get("grant_type") == "client_credentials":
access_token = token["access_token"]
new_token = await self.fetch_token(
url, grant_type="client_credentials"
)
if self.update_token:
await self.update_token(new_token, access_token=access_token)
else:
raise InvalidTokenError()
async def _fetch_token(
self,
url,
body="",
headers=None,
auth=USE_CLIENT_DEFAULT,
method="POST",
**kwargs,
):
if method.upper() == "POST":
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs
)
else:
if "?" in url:
url = "&".join([url, body])
else:
url = "?".join([url, body])
resp = await self.get(url, headers=headers, auth=auth, **kwargs)
for hook in self.compliance_hook["access_token_response"]:
resp = hook(resp)
return self.parse_response_token(resp)
async def _refresh_token(
self,
url,
refresh_token=None,
body="",
headers=None,
auth=USE_CLIENT_DEFAULT,
**kwargs,
):
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs
)
for hook in self.compliance_hook["refresh_token_response"]:
resp = hook(resp)
token = self.parse_response_token(resp)
if "refresh_token" not in token:
self.token["refresh_token"] = refresh_token
if self.update_token:
await self.update_token(self.token, refresh_token=refresh_token)
return self.token
def _http_post(
self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs
):
return self.post(
url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs
)
class OAuth2Client(_OAuth2Client, httpx.Client):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(
self,
client_id=None,
client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None,
redirect_uri=None,
token=None,
token_placement="header",
update_token=None,
**kwargs,
):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
# app keyword was dropped!
app_value = client_kwargs.pop("app", None)
if app_value is not None:
client_kwargs["transport"] = httpx.WSGITransport(app=app_value)
httpx.Client.__init__(self, **client_kwargs)
_OAuth2Client.__init__(
self,
session=self,
client_id=client_id,
client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope,
redirect_uri=redirect_uri,
token=token,
token_placement=token_placement,
update_token=update_token,
**kwargs,
)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().request(method, url, auth=auth, **kwargs)
def stream(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().stream(method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,41 @@
from httpx import Request
HTTPX_CLIENT_KWARGS = [
"headers",
"cookies",
"verify",
"cert",
"http1",
"http2",
"proxy",
"mounts",
"timeout",
"follow_redirects",
"limits",
"max_redirects",
"event_hooks",
"base_url",
"transport",
"trust_env",
"default_encoding",
]
def extract_client_kwargs(kwargs):
client_kwargs = {}
for k in HTTPX_CLIENT_KWARGS:
if k in kwargs:
client_kwargs[k] = kwargs.pop(k)
return client_kwargs
def build_request(url, headers, body, initial_request: Request) -> Request:
"""Make sure that all the data from initial request is passed to the updated object."""
updated_request = Request(
method=initial_request.method, url=url, headers=headers, content=body
)
if hasattr(initial_request, "extensions"):
updated_request.extensions = initial_request.extensions
return updated_request