This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,26 @@
from ..base_client import BaseOAuth
from ..base_client import OAuthError
from .apps import StarletteOAuth1App
from .apps import StarletteOAuth2App
from .integration import StarletteIntegration
class OAuth(BaseOAuth):
oauth1_client_cls = StarletteOAuth1App
oauth2_client_cls = StarletteOAuth2App
framework_integration_cls = StarletteIntegration
def __init__(self, config=None, cache=None, fetch_token=None, update_token=None):
super().__init__(
cache=cache, fetch_token=fetch_token, update_token=update_token
)
self.config = config
__all__ = [
"OAuth",
"OAuthError",
"StarletteIntegration",
"StarletteOAuth1App",
"StarletteOAuth2App",
]

View File

@@ -0,0 +1,106 @@
from starlette.datastructures import URL
from starlette.responses import RedirectResponse
from ..base_client import BaseApp
from ..base_client import OAuthError
from ..base_client.async_app import AsyncOAuth1Mixin
from ..base_client.async_app import AsyncOAuth2Mixin
from ..base_client.async_openid import AsyncOpenIDMixin
from ..httpx_client import AsyncOAuth1Client
from ..httpx_client import AsyncOAuth2Client
class StarletteAppMixin:
async def save_authorize_data(self, request, **kwargs):
state = kwargs.pop("state", None)
if state:
if self.framework.cache:
session = None
else:
session = request.session
await self.framework.set_state_data(session, state, kwargs)
else:
raise RuntimeError("Missing state value")
async def authorize_redirect(self, request, redirect_uri=None, **kwargs):
"""Create a HTTP Redirect for Authorization Endpoint.
:param request: HTTP request instance from Starlette view.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: A HTTP redirect response.
"""
# Handle Starlette >= 0.26.0 where redirect_uri may now be a URL and not a string
if redirect_uri and isinstance(redirect_uri, URL):
redirect_uri = str(redirect_uri)
rv = await self.create_authorization_url(redirect_uri, **kwargs)
await self.save_authorize_data(request, redirect_uri=redirect_uri, **rv)
return RedirectResponse(rv["url"], status_code=302)
class StarletteOAuth1App(StarletteAppMixin, AsyncOAuth1Mixin, BaseApp):
client_cls = AsyncOAuth1Client
async def authorize_access_token(self, request, **kwargs):
params = dict(request.query_params)
state = params.get("oauth_token")
if not state:
raise OAuthError(description='Missing "oauth_token" parameter')
data = await self.framework.get_state_data(request.session, state)
if not data:
raise OAuthError(description='Missing "request_token" in temporary data')
params["request_token"] = data["request_token"]
params.update(kwargs)
await self.framework.clear_state_data(request.session, state)
return await self.fetch_access_token(**params)
class StarletteOAuth2App(
StarletteAppMixin, AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp
):
client_cls = AsyncOAuth2Client
async def authorize_access_token(self, request, **kwargs):
if request.scope.get("method", "GET") == "GET":
error = request.query_params.get("error")
if error:
description = request.query_params.get("error_description")
raise OAuthError(error=error, description=description)
params = {
"code": request.query_params.get("code"),
"state": request.query_params.get("state"),
}
else:
async with request.form() as form:
params = {
"code": form.get("code"),
"state": form.get("state"),
}
if self.framework.cache:
session = None
else:
session = request.session
state_data = await self.framework.get_state_data(session, params.get("state"))
await self.framework.clear_state_data(session, params.get("state"))
params = self._format_state_params(state_data, params)
claims_options = kwargs.pop("claims_options", None)
claims_cls = kwargs.pop("claims_cls", None)
leeway = kwargs.pop("leeway", 120)
token = await self.fetch_access_token(**params, **kwargs)
if "id_token" in token and "nonce" in state_data:
userinfo = await self.parse_id_token(
token,
nonce=state_data["nonce"],
claims_options=claims_options,
claims_cls=claims_cls,
leeway=leeway,
)
token["userinfo"] = userinfo
return token

View File

@@ -0,0 +1,72 @@
import json
import time
from collections.abc import Hashable
from typing import Any
from typing import Optional
from ..base_client import FrameworkIntegration
class StarletteIntegration(FrameworkIntegration):
async def _get_cache_data(self, key: Hashable):
value = await self.cache.get(key)
if not value:
return None
try:
return json.loads(value)
except (TypeError, ValueError):
return None
async def get_state_data(
self, session: Optional[dict[str, Any]], state: str
) -> dict[str, Any]:
key = f"_state_{self.name}_{state}"
if self.cache:
value = await self._get_cache_data(key)
elif session is not None:
value = session.get(key)
else:
value = None
if value:
return value.get("data")
return None
async def set_state_data(
self, session: Optional[dict[str, Any]], state: str, data: Any
):
key_prefix = f"_state_{self.name}_"
key = f"{key_prefix}{state}"
if self.cache:
await self.cache.set(key, json.dumps({"data": data}), self.expires_in)
elif session is not None:
# clear old state data to avoid session size growing
for old_key in list(session.keys()):
if old_key.startswith(key_prefix):
session.pop(old_key)
now = time.time()
session[key] = {"data": data, "exp": now + self.expires_in}
async def clear_state_data(self, session: Optional[dict[str, Any]], state: str):
key = f"_state_{self.name}_{state}"
if self.cache:
await self.cache.delete(key)
elif session is not None:
session.pop(key, None)
self._clear_session_state(session)
def update_token(self, token, refresh_token=None, access_token=None):
pass
@staticmethod
def load_config(oauth, name, params):
if not oauth.config:
return {}
rv = {}
for k in params:
conf_key = f"{name}_{k}".upper()
v = oauth.config.get(conf_key, default=None)
if v is not None:
rv[k] = v
return rv