This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,29 @@
from .errors import InvalidTokenError
from .errors import MismatchingStateError
from .errors import MissingRequestTokenError
from .errors import MissingTokenError
from .errors import OAuthError
from .errors import TokenExpiredError
from .errors import UnsupportedTokenTypeError
from .framework_integration import FrameworkIntegration
from .registry import BaseOAuth
from .sync_app import BaseApp
from .sync_app import OAuth1Mixin
from .sync_app import OAuth2Mixin
from .sync_openid import OpenIDMixin
__all__ = [
"BaseOAuth",
"BaseApp",
"OAuth1Mixin",
"OAuth2Mixin",
"OpenIDMixin",
"FrameworkIntegration",
"OAuthError",
"MissingRequestTokenError",
"MissingTokenError",
"TokenExpiredError",
"InvalidTokenError",
"UnsupportedTokenTypeError",
"MismatchingStateError",
]

View File

@@ -0,0 +1,152 @@
import logging
import time
from authlib.common.urls import urlparse
from .errors import MissingRequestTokenError
from .errors import MissingTokenError
from .sync_app import OAuth1Base
from .sync_app import OAuth2Base
log = logging.getLogger(__name__)
__all__ = ["AsyncOAuth1Mixin", "AsyncOAuth2Mixin"]
class AsyncOAuth1Mixin(OAuth1Base):
async def request(self, method, url, token=None, **kwargs):
async with self._get_oauth_client() as session:
return await _http_request(self, session, method, url, token, kwargs)
async def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
if not self.authorize_url:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
async with self._get_oauth_client() as client:
client.redirect_uri = redirect_uri
params = {}
if self.request_token_params:
params.update(self.request_token_params)
request_token = await client.fetch_request_token(
self.request_token_url, **params
)
log.debug(f"Fetch request token: {request_token!r}")
url = client.create_authorization_url(self.authorize_url, **kwargs)
state = request_token["oauth_token"]
return {"url": url, "request_token": request_token, "state": state}
async def fetch_access_token(self, request_token=None, **kwargs):
"""Fetch access token in one step.
:param request_token: A previous request token for OAuth 1.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
async with self._get_oauth_client() as client:
if request_token is None:
raise MissingRequestTokenError()
# merge request token with verifier
token = {}
token.update(request_token)
token.update(kwargs)
client.token = token
params = self.access_token_params or {}
token = await client.fetch_access_token(self.access_token_url, **params)
return token
class AsyncOAuth2Mixin(OAuth2Base):
async def _on_update_token(self, token, refresh_token=None, access_token=None):
if self._update_token:
await self._update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
async def load_server_metadata(self):
if self._server_metadata_url and "_loaded_at" not in self.server_metadata:
async with self.client_cls(**self.client_kwargs) as client:
resp = await client.request(
"GET", self._server_metadata_url, withhold_token=True
)
resp.raise_for_status()
metadata = resp.json()
metadata["_loaded_at"] = time.time()
self.server_metadata.update(metadata)
return self.server_metadata
async def request(self, method, url, token=None, **kwargs):
metadata = await self.load_server_metadata()
async with self._get_oauth_client(**metadata) as session:
return await _http_request(self, session, method, url, token, kwargs)
async def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
metadata = await self.load_server_metadata()
authorization_endpoint = self.authorize_url or metadata.get(
"authorization_endpoint"
)
if not authorization_endpoint:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
async with self._get_oauth_client(**metadata) as client:
client.redirect_uri = redirect_uri
return self._create_oauth2_authorization_url(
client, authorization_endpoint, **kwargs
)
async def fetch_access_token(self, redirect_uri=None, **kwargs):
"""Fetch access token in the final step.
:param redirect_uri: Callback or Redirect URI that is used in
previous :meth:`authorize_redirect`.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
metadata = await self.load_server_metadata()
token_endpoint = self.access_token_url or metadata.get("token_endpoint")
async with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
params = {}
if self.access_token_params:
params.update(self.access_token_params)
params.update(kwargs)
token = await client.fetch_token(token_endpoint, **params)
return token
async def _http_request(ctx, session, method, url, token, kwargs):
request = kwargs.pop("request", None)
withhold_token = kwargs.get("withhold_token")
if ctx.api_base_url and not url.startswith(("https://", "http://")):
url = urlparse.urljoin(ctx.api_base_url, url)
if withhold_token:
return await session.request(method, url, **kwargs)
if token is None and ctx._fetch_token and request:
token = await ctx._fetch_token(request)
if token is None:
raise MissingTokenError()
session.token = token
return await session.request(method, url, **kwargs)

View File

@@ -0,0 +1,85 @@
from authlib.jose import JsonWebKey
from authlib.jose import JsonWebToken
from authlib.oidc.core import CodeIDToken
from authlib.oidc.core import ImplicitIDToken
from authlib.oidc.core import UserInfo
__all__ = ["AsyncOpenIDMixin"]
class AsyncOpenIDMixin:
async def fetch_jwk_set(self, force=False):
metadata = await self.load_server_metadata()
jwk_set = metadata.get("jwks")
if jwk_set and not force:
return jwk_set
uri = metadata.get("jwks_uri")
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')
async with self.client_cls(**self.client_kwargs) as client:
resp = await client.request("GET", uri, withhold_token=True)
resp.raise_for_status()
jwk_set = resp.json()
self.server_metadata["jwks"] = jwk_set
return jwk_set
async def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
metadata = await self.load_server_metadata()
resp = await self.get(metadata["userinfo_endpoint"], **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
async def parse_id_token(
self, token, nonce, claims_options=None, claims_cls=None, leeway=120
):
"""Return an instance of UserInfo from token's ``id_token``."""
claims_params = dict(
nonce=nonce,
client_id=self.client_id,
)
if claims_cls is None:
if "access_token" in token:
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
metadata = await self.load_server_metadata()
if claims_options is None and "issuer" in metadata:
claims_options = {"iss": {"values": [metadata["issuer"]]}}
alg_values = metadata.get("id_token_signing_alg_values_supported")
if not alg_values:
alg_values = ["RS256"]
jwt = JsonWebToken(alg_values)
jwk_set = await self.fetch_jwk_set()
try:
claims = jwt.decode(
token["id_token"],
key=JsonWebKey.import_key_set(jwk_set),
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
except ValueError:
jwk_set = await self.fetch_jwk_set(force=True)
claims = jwt.decode(
token["id_token"],
key=JsonWebKey.import_key_set(jwk_set),
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
# https://github.com/authlib/authlib/issues/259
if claims.get("nonce_supported") is False:
claims.params["nonce"] = None
claims.validate(leeway=leeway)
return UserInfo(claims)

View File

@@ -0,0 +1,30 @@
from authlib.common.errors import AuthlibBaseError
class OAuthError(AuthlibBaseError):
error = "oauth_error"
class MissingRequestTokenError(OAuthError):
error = "missing_request_token"
class MissingTokenError(OAuthError):
error = "missing_token"
class TokenExpiredError(OAuthError):
error = "token_expired"
class InvalidTokenError(OAuthError):
error = "token_invalid"
class UnsupportedTokenTypeError(OAuthError):
error = "unsupported_token_type"
class MismatchingStateError(OAuthError):
error = "mismatching_state"
description = "CSRF Warning! State not equal in request and response."

View File

@@ -0,0 +1,64 @@
import json
import time
class FrameworkIntegration:
expires_in = 3600
def __init__(self, name, cache=None):
self.name = name
self.cache = cache
def _get_cache_data(self, key):
value = self.cache.get(key)
if not value:
return None
try:
return json.loads(value)
except (TypeError, ValueError):
return None
def _clear_session_state(self, session):
now = time.time()
for key in dict(session):
if "_authlib_" in key:
# TODO: remove in future
session.pop(key)
elif key.startswith("_state_"):
value = session[key]
exp = value.get("exp")
if not exp or exp < now:
session.pop(key)
def get_state_data(self, session, state):
key = f"_state_{self.name}_{state}"
if self.cache:
value = self._get_cache_data(key)
else:
value = session.get(key)
if value:
return value.get("data")
return None
def set_state_data(self, session, state, data):
key = f"_state_{self.name}_{state}"
if self.cache:
self.cache.set(key, json.dumps({"data": data}), self.expires_in)
else:
now = time.time()
session[key] = {"data": data, "exp": now + self.expires_in}
def clear_state_data(self, session, state):
key = f"_state_{self.name}_{state}"
if self.cache:
self.cache.delete(key)
else:
session.pop(key, None)
self._clear_session_state(session)
def update_token(self, token, refresh_token=None, access_token=None):
raise NotImplementedError()
@staticmethod
def load_config(oauth, name, params):
raise NotImplementedError()

View File

@@ -0,0 +1,139 @@
import functools
from .framework_integration import FrameworkIntegration
__all__ = ["BaseOAuth"]
OAUTH_CLIENT_PARAMS = (
"client_id",
"client_secret",
"request_token_url",
"request_token_params",
"access_token_url",
"access_token_params",
"refresh_token_url",
"refresh_token_params",
"authorize_url",
"authorize_params",
"api_base_url",
"client_kwargs",
"server_metadata_url",
)
class BaseOAuth:
"""Registry for oauth clients.
Create an instance for registry::
oauth = OAuth()
"""
oauth1_client_cls = None
oauth2_client_cls = None
framework_integration_cls = FrameworkIntegration
def __init__(self, cache=None, fetch_token=None, update_token=None):
self._registry = {}
self._clients = {}
self.cache = cache
self.fetch_token = fetch_token
self.update_token = update_token
def create_client(self, name):
"""Create or get the given named OAuth client. For instance, the
OAuth registry has ``.register`` a twitter client, developers may
access the client with::
client = oauth.create_client("twitter")
:param: name: Name of the remote application
:return: OAuth remote app
"""
if name in self._clients:
return self._clients[name]
if name not in self._registry:
return None
overwrite, config = self._registry[name]
client_cls = config.pop("client_cls", None)
if client_cls and client_cls.OAUTH_APP_CONFIG:
kwargs = client_cls.OAUTH_APP_CONFIG
kwargs.update(config)
else:
kwargs = config
kwargs = self.generate_client_kwargs(name, overwrite, **kwargs)
framework = self.framework_integration_cls(name, self.cache)
if client_cls:
client = client_cls(framework, name, **kwargs)
elif kwargs.get("request_token_url"):
client = self.oauth1_client_cls(framework, name, **kwargs)
else:
client = self.oauth2_client_cls(framework, name, **kwargs)
self._clients[name] = client
return client
def register(self, name, overwrite=False, **kwargs):
"""Registers a new remote application.
:param name: Name of the remote application.
:param overwrite: Overwrite existing config with framework settings.
:param kwargs: Parameters for :class:`RemoteApp`.
Find parameters for the given remote app class. When a remote app is
registered, it can be accessed with *named* attribute::
oauth.register('twitter', client_id='', ...)
oauth.twitter.get('timeline')
"""
self._registry[name] = (overwrite, kwargs)
return self.create_client(name)
def generate_client_kwargs(self, name, overwrite, **kwargs):
fetch_token = kwargs.pop("fetch_token", None)
update_token = kwargs.pop("update_token", None)
config = self.load_config(name, OAUTH_CLIENT_PARAMS)
if config:
kwargs = _config_client(config, kwargs, overwrite)
if not fetch_token and self.fetch_token:
fetch_token = functools.partial(self.fetch_token, name)
kwargs["fetch_token"] = fetch_token
if not kwargs.get("request_token_url"):
if not update_token and self.update_token:
update_token = functools.partial(self.update_token, name)
kwargs["update_token"] = update_token
return kwargs
def load_config(self, name, params):
return self.framework_integration_cls.load_config(self, name, params)
def __getattr__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError as exc:
if key in self._registry:
return self.create_client(key)
raise AttributeError(f"No such client: {key}") from exc
def _config_client(config, kwargs, overwrite):
for k in OAUTH_CLIENT_PARAMS:
v = config.get(k, None)
if k not in kwargs:
kwargs[k] = v
elif overwrite and v:
if isinstance(kwargs[k], dict):
kwargs[k].update(v)
else:
kwargs[k] = v
return kwargs

View File

@@ -0,0 +1,377 @@
import logging
import time
from authlib.common.security import generate_token
from authlib.common.urls import urlparse
from authlib.consts import default_user_agent
from .errors import MismatchingStateError
from .errors import MissingRequestTokenError
from .errors import MissingTokenError
log = logging.getLogger(__name__)
class BaseApp:
client_cls = None
OAUTH_APP_CONFIG = None
def request(self, method, url, token=None, **kwargs):
raise NotImplementedError()
def get(self, url, **kwargs):
"""Invoke GET http request.
If ``api_base_url`` configured, shortcut is available::
client.get("users/lepture")
"""
return self.request("GET", url, **kwargs)
def post(self, url, **kwargs):
"""Invoke POST http request.
If ``api_base_url`` configured, shortcut is available::
client.post("timeline", json={"text": "Hi"})
"""
return self.request("POST", url, **kwargs)
def patch(self, url, **kwargs):
"""Invoke PATCH http request.
If ``api_base_url`` configured, shortcut is available::
client.patch("profile", json={"name": "Hsiaoming Yang"})
"""
return self.request("PATCH", url, **kwargs)
def put(self, url, **kwargs):
"""Invoke PUT http request.
If ``api_base_url`` configured, shortcut is available::
client.put("profile", json={"name": "Hsiaoming Yang"})
"""
return self.request("PUT", url, **kwargs)
def delete(self, url, **kwargs):
"""Invoke DELETE http request.
If ``api_base_url`` configured, shortcut is available::
client.delete("posts/123")
"""
return self.request("DELETE", url, **kwargs)
class _RequestMixin:
def _get_requested_token(self, request):
if self._fetch_token and request:
return self._fetch_token(request)
def _send_token_request(self, session, method, url, token, kwargs):
request = kwargs.pop("request", None)
withhold_token = kwargs.get("withhold_token")
if self.api_base_url and not url.startswith(("https://", "http://")):
url = urlparse.urljoin(self.api_base_url, url)
if withhold_token:
return session.request(method, url, **kwargs)
if token is None:
token = self._get_requested_token(request)
if token is None:
raise MissingTokenError()
session.token = token
return session.request(method, url, **kwargs)
class OAuth1Base:
client_cls = None
def __init__(
self,
framework,
name=None,
fetch_token=None,
client_id=None,
client_secret=None,
request_token_url=None,
request_token_params=None,
access_token_url=None,
access_token_params=None,
authorize_url=None,
authorize_params=None,
api_base_url=None,
client_kwargs=None,
user_agent=None,
**kwargs,
):
self.framework = framework
self.name = name
self.client_id = client_id
self.client_secret = client_secret
self.request_token_url = request_token_url
self.request_token_params = request_token_params
self.access_token_url = access_token_url
self.access_token_params = access_token_params
self.authorize_url = authorize_url
self.authorize_params = authorize_params
self.api_base_url = api_base_url
self.client_kwargs = client_kwargs or {}
self._fetch_token = fetch_token
self._user_agent = user_agent or default_user_agent
self._kwargs = kwargs
def _get_oauth_client(self):
session = self.client_cls(
self.client_id, self.client_secret, **self.client_kwargs
)
session.headers["User-Agent"] = self._user_agent
return session
class OAuth1Mixin(_RequestMixin, OAuth1Base):
def request(self, method, url, token=None, **kwargs):
with self._get_oauth_client() as session:
return self._send_token_request(session, method, url, token, kwargs)
def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
if not self.authorize_url:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
with self._get_oauth_client() as client:
client.redirect_uri = redirect_uri
params = self.request_token_params or {}
request_token = client.fetch_request_token(self.request_token_url, **params)
log.debug(f"Fetch request token: {request_token!r}")
url = client.create_authorization_url(self.authorize_url, **kwargs)
state = request_token["oauth_token"]
return {"url": url, "request_token": request_token, "state": state}
def fetch_access_token(self, request_token=None, **kwargs):
"""Fetch access token in one step.
:param request_token: A previous request token for OAuth 1.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
with self._get_oauth_client() as client:
if request_token is None:
raise MissingRequestTokenError()
# merge request token with verifier
token = {}
token.update(request_token)
token.update(kwargs)
client.token = token
params = self.access_token_params or {}
token = client.fetch_access_token(self.access_token_url, **params)
return token
class OAuth2Base:
client_cls = None
def __init__(
self,
framework,
name=None,
fetch_token=None,
update_token=None,
client_id=None,
client_secret=None,
access_token_url=None,
access_token_params=None,
authorize_url=None,
authorize_params=None,
api_base_url=None,
client_kwargs=None,
server_metadata_url=None,
compliance_fix=None,
client_auth_methods=None,
user_agent=None,
**kwargs,
):
self.framework = framework
self.name = name
self.client_id = client_id
self.client_secret = client_secret
self.access_token_url = access_token_url
self.access_token_params = access_token_params
self.authorize_url = authorize_url
self.authorize_params = authorize_params
self.api_base_url = api_base_url
self.client_kwargs = client_kwargs or {}
self.compliance_fix = compliance_fix
self.client_auth_methods = client_auth_methods
self._fetch_token = fetch_token
self._update_token = update_token
self._user_agent = user_agent or default_user_agent
self._server_metadata_url = server_metadata_url
self.server_metadata = kwargs
def _on_update_token(self, token, refresh_token=None, access_token=None):
raise NotImplementedError()
def _get_oauth_client(self, **metadata):
client_kwargs = {}
client_kwargs.update(self.client_kwargs)
client_kwargs.update(metadata)
if self.authorize_url:
client_kwargs["authorization_endpoint"] = self.authorize_url
if self.access_token_url:
client_kwargs["token_endpoint"] = self.access_token_url
session = self.client_cls(
client_id=self.client_id,
client_secret=self.client_secret,
update_token=self._on_update_token,
**client_kwargs,
)
if self.client_auth_methods:
for f in self.client_auth_methods:
session.register_client_auth_method(f)
if self.compliance_fix:
self.compliance_fix(session)
session.headers["User-Agent"] = self._user_agent
return session
@staticmethod
def _format_state_params(state_data, params):
if state_data is None:
raise MismatchingStateError()
code_verifier = state_data.get("code_verifier")
if code_verifier:
params["code_verifier"] = code_verifier
redirect_uri = state_data.get("redirect_uri")
if redirect_uri:
params["redirect_uri"] = redirect_uri
return params
@staticmethod
def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs):
rv = {}
if client.code_challenge_method:
code_verifier = kwargs.get("code_verifier")
if not code_verifier:
code_verifier = generate_token(48)
kwargs["code_verifier"] = code_verifier
rv["code_verifier"] = code_verifier
log.debug(f"Using code_verifier: {code_verifier!r}")
scope = kwargs.get("scope", client.scope)
scope = (
(scope if isinstance(scope, (list, tuple)) else scope.split())
if scope
else None
)
if scope and "openid" in scope:
# this is an OpenID Connect service
nonce = kwargs.get("nonce")
if not nonce:
nonce = generate_token(20)
kwargs["nonce"] = nonce
rv["nonce"] = nonce
url, state = client.create_authorization_url(authorization_endpoint, **kwargs)
rv["url"] = url
rv["state"] = state
return rv
class OAuth2Mixin(_RequestMixin, OAuth2Base):
def _on_update_token(self, token, refresh_token=None, access_token=None):
if callable(self._update_token):
self._update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
self.framework.update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
def request(self, method, url, token=None, **kwargs):
metadata = self.load_server_metadata()
with self._get_oauth_client(**metadata) as session:
return self._send_token_request(session, method, url, token, kwargs)
def load_server_metadata(self):
if self._server_metadata_url and "_loaded_at" not in self.server_metadata:
with self.client_cls(**self.client_kwargs) as session:
resp = session.request(
"GET", self._server_metadata_url, withhold_token=True
)
resp.raise_for_status()
metadata = resp.json()
metadata["_loaded_at"] = time.time()
self.server_metadata.update(metadata)
return self.server_metadata
def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
metadata = self.load_server_metadata()
authorization_endpoint = self.authorize_url or metadata.get(
"authorization_endpoint"
)
if not authorization_endpoint:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
return self._create_oauth2_authorization_url(
client, authorization_endpoint, **kwargs
)
def fetch_access_token(self, redirect_uri=None, **kwargs):
"""Fetch access token in the final step.
:param redirect_uri: Callback or Redirect URI that is used in
previous :meth:`authorize_redirect`.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
metadata = self.load_server_metadata()
token_endpoint = self.access_token_url or metadata.get("token_endpoint")
with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
params = {}
if self.access_token_params:
params.update(self.access_token_params)
params.update(kwargs)
token = client.fetch_token(token_endpoint, **params)
return token

View File

@@ -0,0 +1,95 @@
from authlib.jose import JsonWebKey
from authlib.jose import JsonWebToken
from authlib.jose import jwt
from authlib.oidc.core import CodeIDToken
from authlib.oidc.core import ImplicitIDToken
from authlib.oidc.core import UserInfo
class OpenIDMixin:
def fetch_jwk_set(self, force=False):
metadata = self.load_server_metadata()
jwk_set = metadata.get("jwks")
if jwk_set and not force:
return jwk_set
uri = metadata.get("jwks_uri")
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')
with self.client_cls(**self.client_kwargs) as session:
resp = session.request("GET", uri, withhold_token=True)
resp.raise_for_status()
jwk_set = resp.json()
self.server_metadata["jwks"] = jwk_set
return jwk_set
def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
metadata = self.load_server_metadata()
resp = self.get(metadata["userinfo_endpoint"], **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
def parse_id_token(
self, token, nonce, claims_options=None, claims_cls=None, leeway=120
):
"""Return an instance of UserInfo from token's ``id_token``."""
if "id_token" not in token:
return None
load_key = self.create_load_key()
claims_params = dict(
nonce=nonce,
client_id=self.client_id,
)
if claims_cls is None:
if "access_token" in token:
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
metadata = self.load_server_metadata()
if claims_options is None and "issuer" in metadata:
claims_options = {"iss": {"values": [metadata["issuer"]]}}
alg_values = metadata.get("id_token_signing_alg_values_supported")
if alg_values:
_jwt = JsonWebToken(alg_values)
else:
_jwt = jwt
claims = _jwt.decode(
token["id_token"],
key=load_key,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
# https://github.com/authlib/authlib/issues/259
if claims.get("nonce_supported") is False:
claims.params["nonce"] = None
claims.validate(leeway=leeway)
return UserInfo(claims)
def create_load_key(self):
def load_key(header, _):
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
try:
return jwk_set.find_by_kid(
header.get("kid"), use="sig", alg=header.get("alg")
)
except ValueError:
# re-try with new jwk set
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
return jwk_set.find_by_kid(
header.get("kid"), use="sig", alg=header.get("alg")
)
return load_key

View File

@@ -0,0 +1,22 @@
from ..base_client import BaseOAuth
from ..base_client import OAuthError
from .apps import DjangoOAuth1App
from .apps import DjangoOAuth2App
from .integration import DjangoIntegration
from .integration import token_update
class OAuth(BaseOAuth):
oauth1_client_cls = DjangoOAuth1App
oauth2_client_cls = DjangoOAuth2App
framework_integration_cls = DjangoIntegration
__all__ = [
"OAuth",
"DjangoOAuth1App",
"DjangoOAuth2App",
"DjangoIntegration",
"token_update",
"OAuthError",
]

View File

@@ -0,0 +1,99 @@
from django.http import HttpResponseRedirect
from ..base_client import BaseApp
from ..base_client import OAuth1Mixin
from ..base_client import OAuth2Mixin
from ..base_client import OAuthError
from ..base_client import OpenIDMixin
from ..requests_client import OAuth1Session
from ..requests_client import OAuth2Session
class DjangoAppMixin:
def save_authorize_data(self, request, **kwargs):
state = kwargs.pop("state", None)
if state:
self.framework.set_state_data(request.session, state, kwargs)
else:
raise RuntimeError("Missing state value")
def authorize_redirect(self, request, redirect_uri=None, **kwargs):
"""Create a HTTP Redirect for Authorization Endpoint.
:param request: HTTP request instance from Django view.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: A HTTP redirect response.
"""
rv = self.create_authorization_url(redirect_uri, **kwargs)
self.save_authorize_data(request, redirect_uri=redirect_uri, **rv)
return HttpResponseRedirect(rv["url"])
class DjangoOAuth1App(DjangoAppMixin, OAuth1Mixin, BaseApp):
client_cls = OAuth1Session
def authorize_access_token(self, request, **kwargs):
"""Fetch access token in one step.
:param request: HTTP request instance from Django view.
:return: A token dict.
"""
params = request.GET.dict()
state = params.get("oauth_token")
if not state:
raise OAuthError(description='Missing "oauth_token" parameter')
data = self.framework.get_state_data(request.session, state)
if not data:
raise OAuthError(description='Missing "request_token" in temporary data')
params["request_token"] = data["request_token"]
params.update(kwargs)
self.framework.clear_state_data(request.session, state)
return self.fetch_access_token(**params)
class DjangoOAuth2App(DjangoAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
client_cls = OAuth2Session
def authorize_access_token(self, request, **kwargs):
"""Fetch access token in one step.
:param request: HTTP request instance from Django view.
:return: A token dict.
"""
if request.method == "GET":
error = request.GET.get("error")
if error:
description = request.GET.get("error_description")
raise OAuthError(error=error, description=description)
params = {
"code": request.GET.get("code"),
"state": request.GET.get("state"),
}
else:
params = {
"code": request.POST.get("code"),
"state": request.POST.get("state"),
}
state_data = self.framework.get_state_data(request.session, params.get("state"))
self.framework.clear_state_data(request.session, params.get("state"))
params = self._format_state_params(state_data, params)
claims_options = kwargs.pop("claims_options", None)
claims_cls = kwargs.pop("claims_cls", None)
leeway = kwargs.pop("leeway", 120)
token = self.fetch_access_token(**params, **kwargs)
if "id_token" in token and "nonce" in state_data:
userinfo = self.parse_id_token(
token,
nonce=state_data["nonce"],
claims_options=claims_options,
claims_cls=claims_cls,
leeway=leeway,
)
token["userinfo"] = userinfo
return token

View File

@@ -0,0 +1,23 @@
from django.conf import settings
from django.dispatch import Signal
from ..base_client import FrameworkIntegration
token_update = Signal()
class DjangoIntegration(FrameworkIntegration):
def update_token(self, token, refresh_token=None, access_token=None):
token_update.send(
sender=self.__class__,
name=self.name,
token=token,
refresh_token=refresh_token,
access_token=access_token,
)
@staticmethod
def load_config(oauth, name, params):
config = getattr(settings, "AUTHLIB_OAUTH_CLIENTS", None)
if config:
return config.get(name)

View File

@@ -0,0 +1,5 @@
from .authorization_server import BaseServer
from .authorization_server import CacheAuthorizationServer
from .resource_protector import ResourceProtector
__all__ = ["BaseServer", "CacheAuthorizationServer", "ResourceProtector"]

View File

@@ -0,0 +1,128 @@
import logging
from django.conf import settings
from django.core.cache import cache
from django.http import HttpResponse
from authlib.common.security import generate_token
from authlib.common.urls import url_encode
from authlib.oauth1 import AuthorizationServer as _AuthorizationServer
from authlib.oauth1 import OAuth1Request
from authlib.oauth1 import TemporaryCredential
from .nonce import exists_nonce_in_cache
log = logging.getLogger(__name__)
class BaseServer(_AuthorizationServer):
def __init__(self, client_model, token_model, token_generator=None):
self.client_model = client_model
self.token_model = token_model
if token_generator is None:
def token_generator():
return {
"oauth_token": generate_token(42),
"oauth_token_secret": generate_token(48),
}
self.token_generator = token_generator
self._config = getattr(settings, "AUTHLIB_OAUTH1_PROVIDER", {})
self._nonce_expires_in = self._config.get("nonce_expires_in", 86400)
methods = self._config.get("signature_methods")
if methods:
self.SUPPORTED_SIGNATURE_METHODS = methods
def get_client_by_id(self, client_id):
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def exists_nonce(self, nonce, request):
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
def create_token_credential(self, request):
temporary_credential = request.credential
token = self.token_generator()
item = self.token_model(
oauth_token=token["oauth_token"],
oauth_token_secret=token["oauth_token_secret"],
user_id=temporary_credential.get_user_id(),
client_id=temporary_credential.get_client_id(),
)
item.save()
return item
def check_authorization_request(self, request):
req = self.create_oauth1_request(request)
self.validate_authorization_request(req)
return req
def create_oauth1_request(self, request):
if request.method == "POST":
body = request.POST.dict()
else:
body = None
url = request.build_absolute_uri()
return OAuth1Request(request.method, url, body, request.headers)
def handle_response(self, status_code, payload, headers):
resp = HttpResponse(url_encode(payload), status=status_code)
for k, v in headers:
resp[k] = v
return resp
class CacheAuthorizationServer(BaseServer):
def __init__(self, client_model, token_model, token_generator=None):
super().__init__(client_model, token_model, token_generator)
self._temporary_expires_in = self._config.get(
"temporary_credential_expires_in", 86400
)
self._temporary_credential_key_prefix = self._config.get(
"temporary_credential_key_prefix", "temporary_credential:"
)
def create_temporary_credential(self, request):
key_prefix = self._temporary_credential_key_prefix
token = self.token_generator()
client_id = request.client_id
redirect_uri = request.redirect_uri
key = key_prefix + token["oauth_token"]
token["client_id"] = client_id
if redirect_uri:
token["oauth_callback"] = redirect_uri
cache.set(key, token, timeout=self._temporary_expires_in)
return TemporaryCredential(token)
def get_temporary_credential(self, request):
if not request.token:
return None
key_prefix = self._temporary_credential_key_prefix
key = key_prefix + request.token
value = cache.get(key)
if value:
return TemporaryCredential(value)
def delete_temporary_credential(self, request):
if request.token:
key_prefix = self._temporary_credential_key_prefix
key = key_prefix + request.token
cache.delete(key)
def create_authorization_verifier(self, request):
key_prefix = self._temporary_credential_key_prefix
verifier = generate_token(36)
credential = request.credential
user = request.user
key = key_prefix + credential.get_oauth_token()
credential["oauth_verifier"] = verifier
credential["user_id"] = user.pk
cache.set(key, credential, timeout=self._temporary_expires_in)
return verifier

View File

@@ -0,0 +1,15 @@
from django.core.cache import cache
def exists_nonce_in_cache(nonce, request, timeout):
key_prefix = "nonce:"
timestamp = request.timestamp
client_id = request.client_id
token = request.token
key = f"{key_prefix}{nonce}-{timestamp}-{client_id}"
if token:
key = f"{key}-{token}"
rv = bool(cache.get(key))
cache.set(key, 1, timeout=timeout)
return rv

View File

@@ -0,0 +1,68 @@
import functools
from django.conf import settings
from django.http import JsonResponse
from authlib.oauth1 import ResourceProtector as _ResourceProtector
from authlib.oauth1.errors import OAuth1Error
from .nonce import exists_nonce_in_cache
class ResourceProtector(_ResourceProtector):
def __init__(self, client_model, token_model):
self.client_model = client_model
self.token_model = token_model
config = getattr(settings, "AUTHLIB_OAUTH1_PROVIDER", {})
methods = config.get("signature_methods", [])
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self._nonce_expires_in = config.get("nonce_expires_in", 86400)
def get_client_by_id(self, client_id):
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def get_token_credential(self, request):
try:
return self.token_model.objects.get(
client_id=request.client_id, oauth_token=request.token
)
except self.token_model.DoesNotExist:
return None
def exists_nonce(self, nonce, request):
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
def acquire_credential(self, request):
if request.method in ["POST", "PUT"]:
body = request.POST.dict()
else:
body = None
url = request.build_absolute_uri()
req = self.validate_request(request.method, url, body, request.headers)
return req.credential
def __call__(self, realm=None):
def wrapper(f):
@functools.wraps(f)
def decorated(request, *args, **kwargs):
try:
credential = self.acquire_credential(request)
request.oauth1_credential = credential
except OAuth1Error as error:
body = dict(error.get_body())
resp = JsonResponse(body, status=error.status_code)
resp["Cache-Control"] = "no-store"
resp["Pragma"] = "no-cache"
return resp
return f(request, *args, **kwargs)
return decorated
return wrapper

View File

@@ -0,0 +1,9 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .endpoints import RevocationEndpoint
from .resource_protector import BearerTokenValidator
from .resource_protector import ResourceProtector
from .signals import client_authenticated
from .signals import token_authenticated
from .signals import token_revoked

View File

@@ -0,0 +1,122 @@
from django.conf import settings
from django.http import HttpResponse
from django.utils.module_loading import import_string
from authlib.common.encoding import json_dumps
from authlib.common.security import generate_token as _generate_token
from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from .requests import DjangoJsonRequest
from .requests import DjangoOAuth2Request
from .signals import client_authenticated
from .signals import token_revoked
class AuthorizationServer(_AuthorizationServer):
"""Django implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
Initialize it with client model and token model::
from authlib.integrations.django_oauth2 import AuthorizationServer
from your_project.models import OAuth2Client, OAuth2Token
server = AuthorizationServer(OAuth2Client, OAuth2Token)
"""
def __init__(self, client_model, token_model):
super().__init__()
self.client_model = client_model
self.token_model = token_model
self.load_config(getattr(settings, "AUTHLIB_OAUTH2_PROVIDER", {}))
def load_config(self, config):
self.config = config
scopes_supported = self.config.get("scopes_supported")
self.scopes_supported = scopes_supported
# add default token generator
self.register_token_generator("default", self.create_bearer_token_generator())
def query_client(self, client_id):
"""Default method for ``AuthorizationServer.query_client``. Developers MAY
rewrite this function to meet their own needs.
"""
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def save_token(self, token, request):
"""Default method for ``AuthorizationServer.save_token``. Developers MAY
rewrite this function to meet their own needs.
"""
client = request.client
if request.user:
user_id = request.user.pk
else:
user_id = client.user_id
item = self.token_model(client_id=client.client_id, user_id=user_id, **token)
item.save()
return item
def create_oauth2_request(self, request):
return DjangoOAuth2Request(request)
def create_json_request(self, request):
return DjangoJsonRequest(request)
def handle_response(self, status_code, payload, headers):
if isinstance(payload, dict):
payload = json_dumps(payload)
resp = HttpResponse(payload, status=status_code)
for k, v in headers:
resp[k] = v
return resp
def send_signal(self, name, *args, **kwargs):
if name == "after_authenticate_client":
client_authenticated.send(*args, sender=self.__class__, **kwargs)
elif name == "after_revoke_token":
token_revoked.send(*args, sender=self.__class__, **kwargs)
def create_bearer_token_generator(self):
"""Default method to create BearerToken generator."""
conf = self.config.get("access_token_generator", True)
access_token_generator = create_token_generator(conf, 42)
conf = self.config.get("refresh_token_generator", False)
refresh_token_generator = create_token_generator(conf, 48)
conf = self.config.get("token_expires_in")
expires_generator = create_token_expires_in_generator(conf)
return BearerTokenGenerator(
access_token_generator=access_token_generator,
refresh_token_generator=refresh_token_generator,
expires_generator=expires_generator,
)
def create_token_generator(token_generator_conf, length=42):
if callable(token_generator_conf):
return token_generator_conf
if isinstance(token_generator_conf, str):
return import_string(token_generator_conf)
elif token_generator_conf is True:
def token_generator(*args, **kwargs):
return _generate_token(length)
return token_generator
def create_token_expires_in_generator(expires_in_conf=None):
data = {}
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
if expires_in_conf:
data.update(expires_in_conf)
def expires_in(client, grant_type):
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
return expires_in

View File

@@ -0,0 +1,56 @@
from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint
class RevocationEndpoint(_RevocationEndpoint):
"""The revocation endpoint for OAuth authorization servers allows clients
to notify the authorization server that a previously obtained refresh or
access token is no longer needed.
Register it into authorization server, and create token endpoint response
for token revocation::
from django.views.decorators.http import require_http_methods
# see register into authorization server instance
server.register_endpoint(RevocationEndpoint)
@require_http_methods(["POST"])
def revoke_token(request):
return server.create_endpoint_response(
RevocationEndpoint.ENDPOINT_NAME, request
)
"""
def query_token(self, token, token_type_hint):
"""Query requested token from database."""
token_model = self.server.token_model
if token_type_hint == "access_token":
rv = _query_access_token(token_model, token)
elif token_type_hint == "refresh_token":
rv = _query_refresh_token(token_model, token)
else:
rv = _query_access_token(token_model, token)
if not rv:
rv = _query_refresh_token(token_model, token)
return rv
def revoke_token(self, token, request):
"""Mark the give token as revoked."""
token.revoked = True
token.save()
def _query_access_token(token_model, token):
try:
return token_model.objects.get(access_token=token)
except token_model.DoesNotExist:
return None
def _query_refresh_token(token_model, token):
try:
return token_model.objects.get(refresh_token=token)
except token_model.DoesNotExist:
return None

View File

@@ -0,0 +1,65 @@
from collections import defaultdict
from django.http import HttpRequest
from django.utils.functional import cached_property
from authlib.common.encoding import json_loads
from authlib.oauth2.rfc6749 import JsonPayload
from authlib.oauth2.rfc6749 import JsonRequest
from authlib.oauth2.rfc6749 import OAuth2Payload
from authlib.oauth2.rfc6749 import OAuth2Request
class DjangoOAuth2Payload(OAuth2Payload):
def __init__(self, request: HttpRequest):
self._request = request
@cached_property
def data(self):
data = {}
data.update(self._request.GET.dict())
data.update(self._request.POST.dict())
return data
@cached_property
def datalist(self):
values = defaultdict(list)
for k in self._request.GET:
values[k].extend(self._request.GET.getlist(k))
for k in self._request.POST:
values[k].extend(self._request.POST.getlist(k))
return values
class DjangoOAuth2Request(OAuth2Request):
def __init__(self, request: HttpRequest):
super().__init__(
method=request.method,
uri=request.build_absolute_uri(),
headers=request.headers,
)
self.payload = DjangoOAuth2Payload(request)
self._request = request
@property
def args(self):
return self._request.GET
@property
def form(self):
return self._request.POST
class DjangoJsonPayload(JsonPayload):
def __init__(self, request: HttpRequest):
self._request = request
@cached_property
def data(self):
return json_loads(self._request.body)
class DjangoJsonRequest(JsonRequest):
def __init__(self, request: HttpRequest):
super().__init__(request.method, request.build_absolute_uri(), request.headers)
self.payload = DjangoJsonPayload(request)

View File

@@ -0,0 +1,75 @@
import functools
from django.http import JsonResponse
from authlib.oauth2 import OAuth2Error
from authlib.oauth2 import ResourceProtector as _ResourceProtector
from authlib.oauth2.rfc6749 import MissingAuthorizationError
from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
from .requests import DjangoJsonRequest
from .signals import token_authenticated
class ResourceProtector(_ResourceProtector):
def acquire_token(self, request, scopes=None, **kwargs):
"""A method to acquire current valid token with the given scope.
:param request: Django HTTP request instance
:param scopes: a list of scope values
:return: token object
"""
req = DjangoJsonRequest(request)
# backward compatibility
kwargs["scopes"] = scopes
for claim in kwargs:
if isinstance(kwargs[claim], str):
kwargs[claim] = [kwargs[claim]]
token = self.validate_request(request=req, **kwargs)
token_authenticated.send(sender=self.__class__, token=token)
return token
def __call__(self, scopes=None, optional=False, **kwargs):
claims = kwargs
# backward compatibility
claims["scopes"] = scopes
def wrapper(f):
@functools.wraps(f)
def decorated(request, *args, **kwargs):
try:
token = self.acquire_token(request, **claims)
request.oauth_token = token
except MissingAuthorizationError as error:
if optional:
request.oauth_token = None
return f(request, *args, **kwargs)
return return_error_response(error)
except OAuth2Error as error:
return return_error_response(error)
return f(request, *args, **kwargs)
return decorated
return wrapper
class BearerTokenValidator(_BearerTokenValidator):
def __init__(self, token_model, realm=None, **extra_attributes):
self.token_model = token_model
super().__init__(realm, **extra_attributes)
def authenticate_token(self, token_string):
try:
return self.token_model.objects.get(access_token=token_string)
except self.token_model.DoesNotExist:
return None
def return_error_response(error):
body = dict(error.get_body())
resp = JsonResponse(body, status=error.status_code)
headers = error.get_headers()
for k, v in headers:
resp[k] = v
return resp

View File

@@ -0,0 +1,10 @@
from django.dispatch import Signal
#: signal when client is authenticated
client_authenticated = Signal()
#: signal when token is revoked
token_revoked = Signal()
#: signal when token is authenticated
token_authenticated = Signal()

View File

@@ -0,0 +1,59 @@
from werkzeug.local import LocalProxy
from ..base_client import BaseOAuth
from ..base_client import OAuthError
from .apps import FlaskOAuth1App
from .apps import FlaskOAuth2App
from .integration import FlaskIntegration
from .integration import token_update
class OAuth(BaseOAuth):
oauth1_client_cls = FlaskOAuth1App
oauth2_client_cls = FlaskOAuth2App
framework_integration_cls = FlaskIntegration
def __init__(self, app=None, cache=None, fetch_token=None, update_token=None):
super().__init__(
cache=cache, fetch_token=fetch_token, update_token=update_token
)
self.app = app
if app:
self.init_app(app)
def init_app(self, app, cache=None, fetch_token=None, update_token=None):
"""Initialize lazy for Flask app. This is usually used for Flask application
factory pattern.
"""
self.app = app
if cache is not None:
self.cache = cache
if fetch_token:
self.fetch_token = fetch_token
if update_token:
self.update_token = update_token
app.extensions = getattr(app, "extensions", {})
app.extensions["authlib.integrations.flask_client"] = self
def create_client(self, name):
if not self.app:
raise RuntimeError("OAuth is not init with Flask app.")
return super().create_client(name)
def register(self, name, overwrite=False, **kwargs):
self._registry[name] = (overwrite, kwargs)
if self.app:
return self.create_client(name)
return LocalProxy(lambda: self.create_client(name))
__all__ = [
"OAuth",
"FlaskIntegration",
"FlaskOAuth1App",
"FlaskOAuth2App",
"token_update",
"OAuthError",
]

View File

@@ -0,0 +1,122 @@
from flask import g
from flask import redirect
from flask import request
from flask import session
from ..base_client import BaseApp
from ..base_client import OAuth1Mixin
from ..base_client import OAuth2Mixin
from ..base_client import OAuthError
from ..base_client import OpenIDMixin
from ..requests_client import OAuth1Session
from ..requests_client import OAuth2Session
class FlaskAppMixin:
@property
def token(self):
attr = f"_oauth_token_{self.name}"
token = g.get(attr)
if token:
return token
if self._fetch_token:
token = self._fetch_token()
self.token = token
return token
@token.setter
def token(self, token):
attr = f"_oauth_token_{self.name}"
setattr(g, attr, token)
def _get_requested_token(self, *args, **kwargs):
return self.token
def save_authorize_data(self, **kwargs):
state = kwargs.pop("state", None)
if state:
self.framework.set_state_data(session, state, kwargs)
else:
raise RuntimeError("Missing state value")
def authorize_redirect(self, redirect_uri=None, **kwargs):
"""Create a HTTP Redirect for Authorization Endpoint.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: A HTTP redirect response.
"""
rv = self.create_authorization_url(redirect_uri, **kwargs)
self.save_authorize_data(redirect_uri=redirect_uri, **rv)
return redirect(rv["url"])
class FlaskOAuth1App(FlaskAppMixin, OAuth1Mixin, BaseApp):
client_cls = OAuth1Session
def authorize_access_token(self, **kwargs):
"""Fetch access token in one step.
:return: A token dict.
"""
params = request.args.to_dict(flat=True)
state = params.get("oauth_token")
if not state:
raise OAuthError(description='Missing "oauth_token" parameter')
data = self.framework.get_state_data(session, state)
if not data:
raise OAuthError(description='Missing "request_token" in temporary data')
params["request_token"] = data["request_token"]
params.update(kwargs)
self.framework.clear_state_data(session, state)
token = self.fetch_access_token(**params)
self.token = token
return token
class FlaskOAuth2App(FlaskAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
client_cls = OAuth2Session
def authorize_access_token(self, **kwargs):
"""Fetch access token in one step.
:return: A token dict.
"""
if request.method == "GET":
error = request.args.get("error")
if error:
description = request.args.get("error_description")
raise OAuthError(error=error, description=description)
params = {
"code": request.args.get("code"),
"state": request.args.get("state"),
}
else:
params = {
"code": request.form.get("code"),
"state": request.form.get("state"),
}
state_data = self.framework.get_state_data(session, params.get("state"))
self.framework.clear_state_data(session, params.get("state"))
params = self._format_state_params(state_data, params)
claims_options = kwargs.pop("claims_options", None)
claims_cls = kwargs.pop("claims_cls", None)
leeway = kwargs.pop("leeway", 120)
token = self.fetch_access_token(**params, **kwargs)
self.token = token
if "id_token" in token and "nonce" in state_data:
userinfo = self.parse_id_token(
token,
nonce=state_data["nonce"],
claims_options=claims_options,
claims_cls=claims_cls,
leeway=leeway,
)
token["userinfo"] = userinfo
return token

View File

@@ -0,0 +1,29 @@
from flask import current_app
from flask.signals import Namespace
from ..base_client import FrameworkIntegration
_signal = Namespace()
#: signal when token is updated
token_update = _signal.signal("token_update")
class FlaskIntegration(FrameworkIntegration):
def update_token(self, token, refresh_token=None, access_token=None):
token_update.send(
current_app,
name=self.name,
token=token,
refresh_token=refresh_token,
access_token=access_token,
)
@staticmethod
def load_config(oauth, name, params):
rv = {}
for k in params:
conf_key = f"{name}_{k}".upper()
v = oauth.app.config.get(conf_key, None)
if v is not None:
rv[k] = v
return rv

View File

@@ -0,0 +1,8 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .cache import create_exists_nonce_func
from .cache import register_nonce_hooks
from .cache import register_temporary_credential_hooks
from .resource_protector import ResourceProtector
from .resource_protector import current_credential

View File

@@ -0,0 +1,168 @@
import logging
from flask import Response
from flask import request as flask_req
from werkzeug.utils import import_string
from authlib.common.security import generate_token
from authlib.common.urls import url_encode
from authlib.oauth1 import AuthorizationServer as _AuthorizationServer
from authlib.oauth1 import OAuth1Request
log = logging.getLogger(__name__)
class AuthorizationServer(_AuthorizationServer):
"""Flask implementation of :class:`authlib.rfc5849.AuthorizationServer`.
Initialize it with Flask app instance, client model class and cache::
server = AuthorizationServer(app=app, query_client=query_client)
# or initialize lazily
server = AuthorizationServer()
server.init_app(app, query_client=query_client)
:param app: A Flask app instance
:param query_client: A function to get client by client_id. The client
model class MUST implement the methods described by
:class:`~authlib.oauth1.rfc5849.ClientMixin`.
:param token_generator: A function to generate token
"""
def __init__(self, app=None, query_client=None, token_generator=None):
self.app = app
self.query_client = query_client
self.token_generator = token_generator
self._hooks = {
"exists_nonce": None,
"create_temporary_credential": None,
"get_temporary_credential": None,
"delete_temporary_credential": None,
"create_authorization_verifier": None,
"create_token_credential": None,
}
if app is not None:
self.init_app(app)
def init_app(self, app, query_client=None, token_generator=None):
if query_client is not None:
self.query_client = query_client
if token_generator is not None:
self.token_generator = token_generator
if self.token_generator is None:
self.token_generator = self.create_token_generator(app)
methods = app.config.get("OAUTH1_SUPPORTED_SIGNATURE_METHODS")
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self.app = app
def register_hook(self, name, func):
if name not in self._hooks:
raise ValueError('Invalid "name" of hook')
self._hooks[name] = func
def create_token_generator(self, app):
token_generator = app.config.get("OAUTH1_TOKEN_GENERATOR")
if isinstance(token_generator, str):
token_generator = import_string(token_generator)
else:
length = app.config.get("OAUTH1_TOKEN_LENGTH", 42)
def token_generator():
return generate_token(length)
secret_generator = app.config.get("OAUTH1_TOKEN_SECRET_GENERATOR")
if isinstance(secret_generator, str):
secret_generator = import_string(secret_generator)
else:
length = app.config.get("OAUTH1_TOKEN_SECRET_LENGTH", 48)
def secret_generator():
return generate_token(length)
def create_token():
return {
"oauth_token": token_generator(),
"oauth_token_secret": secret_generator(),
}
return create_token
def get_client_by_id(self, client_id):
return self.query_client(client_id)
def exists_nonce(self, nonce, request):
func = self._hooks["exists_nonce"]
if callable(func):
timestamp = request.timestamp
client_id = request.client_id
token = request.token
return func(nonce, timestamp, client_id, token)
raise RuntimeError('"exists_nonce" hook is required.')
def create_temporary_credential(self, request):
func = self._hooks["create_temporary_credential"]
if callable(func):
token = self.token_generator()
return func(token, request.client_id, request.redirect_uri)
raise RuntimeError('"create_temporary_credential" hook is required.')
def get_temporary_credential(self, request):
func = self._hooks["get_temporary_credential"]
if callable(func):
return func(request.token)
raise RuntimeError('"get_temporary_credential" hook is required.')
def delete_temporary_credential(self, request):
func = self._hooks["delete_temporary_credential"]
if callable(func):
return func(request.token)
raise RuntimeError('"delete_temporary_credential" hook is required.')
def create_authorization_verifier(self, request):
func = self._hooks["create_authorization_verifier"]
if callable(func):
verifier = generate_token(36)
func(request.credential, request.user, verifier)
return verifier
raise RuntimeError('"create_authorization_verifier" hook is required.')
def create_token_credential(self, request):
func = self._hooks["create_token_credential"]
if callable(func):
temporary_credential = request.credential
token = self.token_generator()
return func(token, temporary_credential)
raise RuntimeError('"create_token_credential" hook is required.')
def check_authorization_request(self):
req = self.create_oauth1_request(None)
self.validate_authorization_request(req)
return req
def create_authorization_response(self, request=None, grant_user=None):
return super().create_authorization_response(request, grant_user)
def create_token_response(self, request=None):
return super().create_token_response(request)
def create_oauth1_request(self, request):
if request is None:
request = flask_req
if request.method in ("POST", "PUT"):
body = request.form.to_dict(flat=True)
else:
body = None
return OAuth1Request(request.method, request.url, body, request.headers)
def handle_response(self, status_code, payload, headers):
return Response(url_encode(payload), status=status_code, headers=headers)

View File

@@ -0,0 +1,88 @@
from authlib.oauth1 import TemporaryCredential
def register_temporary_credential_hooks(
authorization_server, cache, key_prefix="temporary_credential:"
):
"""Register temporary credential related hooks to authorization server.
:param authorization_server: AuthorizationServer instance
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
"""
def create_temporary_credential(token, client_id, redirect_uri):
key = key_prefix + token["oauth_token"]
token["client_id"] = client_id
if redirect_uri:
token["oauth_callback"] = redirect_uri
cache.set(key, token, timeout=86400) # cache for one day
return TemporaryCredential(token)
def get_temporary_credential(oauth_token):
if not oauth_token:
return None
key = key_prefix + oauth_token
value = cache.get(key)
if value:
return TemporaryCredential(value)
def delete_temporary_credential(oauth_token):
if oauth_token:
key = key_prefix + oauth_token
cache.delete(key)
def create_authorization_verifier(credential, grant_user, verifier):
key = key_prefix + credential.get_oauth_token()
credential["oauth_verifier"] = verifier
credential["user_id"] = grant_user.get_user_id()
cache.set(key, credential, timeout=86400)
return credential
authorization_server.register_hook(
"create_temporary_credential", create_temporary_credential
)
authorization_server.register_hook(
"get_temporary_credential", get_temporary_credential
)
authorization_server.register_hook(
"delete_temporary_credential", delete_temporary_credential
)
authorization_server.register_hook(
"create_authorization_verifier", create_authorization_verifier
)
def create_exists_nonce_func(cache, key_prefix="nonce:", expires=86400):
"""Create an ``exists_nonce`` function that can be used in hooks and
resource protector.
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
:param expires: Expire time for nonce
"""
def exists_nonce(nonce, timestamp, client_id, oauth_token):
key = f"{key_prefix}{nonce}-{timestamp}-{client_id}"
if oauth_token:
key = f"{key}-{oauth_token}"
rv = cache.has(key)
cache.set(key, 1, timeout=expires)
return rv
return exists_nonce
def register_nonce_hooks(
authorization_server, cache, key_prefix="nonce:", expires=86400
):
"""Register nonce related hooks to authorization server.
:param authorization_server: AuthorizationServer instance
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
:param expires: Expire time for nonce
"""
exists_nonce = create_exists_nonce_func(cache, key_prefix, expires)
authorization_server.register_hook("exists_nonce", exists_nonce)

View File

@@ -0,0 +1,121 @@
import functools
from flask import Response
from flask import g
from flask import json
from flask import request as _req
from werkzeug.local import LocalProxy
from authlib.consts import default_json_headers
from authlib.oauth1 import ResourceProtector as _ResourceProtector
from authlib.oauth1.errors import OAuth1Error
class ResourceProtector(_ResourceProtector):
"""A protecting method for resource servers. Initialize a resource
protector with the these method:
1. query_client
2. query_token,
3. exists_nonce
Usually, a ``query_client`` method would look like (if using SQLAlchemy)::
def query_client(client_id):
return Client.query.filter_by(client_id=client_id).first()
A ``query_token`` method accept two parameters, ``client_id`` and ``oauth_token``::
def query_token(client_id, oauth_token):
return Token.query.filter_by(
client_id=client_id, oauth_token=oauth_token
).first()
And for ``exists_nonce``, if using cache, we have a built-in hook to create this method::
from authlib.integrations.flask_oauth1 import create_exists_nonce_func
exists_nonce = create_exists_nonce_func(cache)
Then initialize the resource protector with those methods::
require_oauth = ResourceProtector(
app,
query_client=query_client,
query_token=query_token,
exists_nonce=exists_nonce,
)
"""
def __init__(
self, app=None, query_client=None, query_token=None, exists_nonce=None
):
self.query_client = query_client
self.query_token = query_token
self._exists_nonce = exists_nonce
self.app = app
if app:
self.init_app(app)
def init_app(self, app, query_client=None, query_token=None, exists_nonce=None):
if query_client is not None:
self.query_client = query_client
if query_token is not None:
self.query_token = query_token
if exists_nonce is not None:
self._exists_nonce = exists_nonce
methods = app.config.get("OAUTH1_SUPPORTED_SIGNATURE_METHODS")
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self.app = app
def get_client_by_id(self, client_id):
return self.query_client(client_id)
def get_token_credential(self, request):
return self.query_token(request.client_id, request.token)
def exists_nonce(self, nonce, request):
if not self._exists_nonce:
raise RuntimeError('"exists_nonce" function is required.')
timestamp = request.timestamp
client_id = request.client_id
token = request.token
return self._exists_nonce(nonce, timestamp, client_id, token)
def acquire_credential(self):
req = self.validate_request(
_req.method, _req.url, _req.form.to_dict(flat=True), _req.headers
)
g.authlib_server_oauth1_credential = req.credential
return req.credential
def __call__(self, scope=None):
def wrapper(f):
@functools.wraps(f)
def decorated(*args, **kwargs):
try:
self.acquire_credential()
except OAuth1Error as error:
body = dict(error.get_body())
return Response(
json.dumps(body),
status=error.status_code,
headers=default_json_headers,
)
return f(*args, **kwargs)
return decorated
return wrapper
def _get_current_credential():
return g.get("authlib_server_oauth1_credential")
current_credential = LocalProxy(_get_current_credential)

View File

@@ -0,0 +1,8 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .resource_protector import ResourceProtector
from .resource_protector import current_token
from .signals import client_authenticated
from .signals import token_authenticated
from .signals import token_revoked

View File

@@ -0,0 +1,165 @@
from flask import Response
from flask import json
from flask import request as flask_req
from werkzeug.utils import import_string
from authlib.common.security import generate_token
from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from .requests import FlaskJsonRequest
from .requests import FlaskOAuth2Request
from .signals import client_authenticated
from .signals import token_revoked
class AuthorizationServer(_AuthorizationServer):
"""Flask implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
Initialize it with ``query_client``, ``save_token`` methods and Flask
app instance::
def query_client(client_id):
return Client.query.filter_by(client_id=client_id).first()
def save_token(token, request):
if request.user:
user_id = request.user.id
else:
user_id = None
client = request.client
tok = Token(client_id=client.client_id, user_id=user.id, **token)
db.session.add(tok)
db.session.commit()
server = AuthorizationServer(app, query_client, save_token)
# or initialize lazily
server = AuthorizationServer()
server.init_app(app, query_client, save_token)
"""
def __init__(self, app=None, query_client=None, save_token=None):
super().__init__()
self._query_client = query_client
self._save_token = save_token
self._error_uris = None
if app is not None:
self.init_app(app)
def init_app(self, app, query_client=None, save_token=None):
"""Initialize later with Flask app instance."""
if query_client is not None:
self._query_client = query_client
if save_token is not None:
self._save_token = save_token
self.load_config(app.config)
def load_config(self, config):
self.register_token_generator(
"default", self.create_bearer_token_generator(config)
)
self.scopes_supported = config.get("OAUTH2_SCOPES_SUPPORTED")
self._error_uris = config.get("OAUTH2_ERROR_URIS")
def query_client(self, client_id):
return self._query_client(client_id)
def save_token(self, token, request):
return self._save_token(token, request)
def get_error_uri(self, request, error):
if self._error_uris:
uris = dict(self._error_uris)
return uris.get(error.error)
def create_oauth2_request(self, request):
return FlaskOAuth2Request(flask_req)
def create_json_request(self, request):
return FlaskJsonRequest(flask_req)
def handle_response(self, status_code, payload, headers):
if isinstance(payload, dict):
payload = json.dumps(payload)
return Response(payload, status=status_code, headers=headers)
def send_signal(self, name, *args, **kwargs):
if name == "after_authenticate_client":
client_authenticated.send(self, *args, **kwargs)
elif name == "after_revoke_token":
token_revoked.send(self, *args, **kwargs)
def create_bearer_token_generator(self, config):
"""Create a generator function for generating ``token`` value. This
method will create a Bearer Token generator with
:class:`authlib.oauth2.rfc6750.BearerToken`.
Configurable settings:
1. OAUTH2_ACCESS_TOKEN_GENERATOR: Boolean or import string, default is True.
2. OAUTH2_REFRESH_TOKEN_GENERATOR: Boolean or import string, default is False.
3. OAUTH2_TOKEN_EXPIRES_IN: Dict or import string, default is None.
By default, it will not generate ``refresh_token``, which can be turn on by
configure ``OAUTH2_REFRESH_TOKEN_GENERATOR``.
Here are some examples of the token generator::
OAUTH2_ACCESS_TOKEN_GENERATOR = "your_project.generators.gen_token"
# and in module `your_project.generators`, you can define:
def gen_token(client, grant_type, user, scope):
# generate token according to these parameters
token = create_random_token()
return f"{client.id}-{user.id}-{token}"
Here is an example of ``OAUTH2_TOKEN_EXPIRES_IN``::
OAUTH2_TOKEN_EXPIRES_IN = {
"authorization_code": 864000,
"urn:ietf:params:oauth:grant-type:jwt-bearer": 3600,
}
"""
conf = config.get("OAUTH2_ACCESS_TOKEN_GENERATOR", True)
access_token_generator = create_token_generator(conf, 42)
conf = config.get("OAUTH2_REFRESH_TOKEN_GENERATOR", False)
refresh_token_generator = create_token_generator(conf, 48)
expires_conf = config.get("OAUTH2_TOKEN_EXPIRES_IN")
expires_generator = create_token_expires_in_generator(expires_conf)
return BearerTokenGenerator(
access_token_generator, refresh_token_generator, expires_generator
)
def create_token_expires_in_generator(expires_in_conf=None):
if isinstance(expires_in_conf, str):
return import_string(expires_in_conf)
data = {}
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
if isinstance(expires_in_conf, dict):
data.update(expires_in_conf)
def expires_in(client, grant_type):
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
return expires_in
def create_token_generator(token_generator_conf, length=42):
if callable(token_generator_conf):
return token_generator_conf
if isinstance(token_generator_conf, str):
return import_string(token_generator_conf)
elif token_generator_conf is True:
def token_generator(*args, **kwargs):
return generate_token(length)
return token_generator

View File

@@ -0,0 +1,40 @@
import importlib.metadata
from werkzeug.exceptions import HTTPException
_version = importlib.metadata.version("werkzeug").split(".")[0]
if _version in ("0", "1"):
class _HTTPException(HTTPException):
def __init__(self, code, body, headers, response=None):
super().__init__(None, response)
self.code = code
self.body = body
self.headers = headers
def get_body(self, environ=None):
return self.body
def get_headers(self, environ=None):
return self.headers
else:
class _HTTPException(HTTPException):
def __init__(self, code, body, headers, response=None):
super().__init__(None, response)
self.code = code
self.body = body
self.headers = headers
def get_body(self, environ=None, scope=None):
return self.body
def get_headers(self, environ=None, scope=None):
return self.headers
def raise_http_exception(status, body, headers):
raise _HTTPException(status, body, headers)

View File

@@ -0,0 +1,57 @@
from collections import defaultdict
from functools import cached_property
from flask.wrappers import Request
from authlib.oauth2.rfc6749 import JsonPayload
from authlib.oauth2.rfc6749 import JsonRequest
from authlib.oauth2.rfc6749 import OAuth2Payload
from authlib.oauth2.rfc6749 import OAuth2Request
class FlaskOAuth2Payload(OAuth2Payload):
def __init__(self, request: Request):
self._request = request
@property
def data(self):
return self._request.values
@cached_property
def datalist(self):
values = defaultdict(list)
for k in self.data:
values[k].extend(self.data.getlist(k))
return values
class FlaskOAuth2Request(OAuth2Request):
def __init__(self, request: Request):
super().__init__(
method=request.method, uri=request.url, headers=request.headers
)
self._request = request
self.payload = FlaskOAuth2Payload(request)
@property
def args(self):
return self._request.args
@property
def form(self):
return self._request.form
class FlaskJsonPayload(JsonPayload):
def __init__(self, request: Request):
self._request = request
@property
def data(self):
return self._request.get_json()
class FlaskJsonRequest(JsonRequest):
def __init__(self, request: Request):
super().__init__(request.method, request.url, request.headers)
self.payload = FlaskJsonPayload(request)

View File

@@ -0,0 +1,121 @@
import functools
from contextlib import contextmanager
from flask import g
from flask import json
from flask import request as _req
from werkzeug.local import LocalProxy
from authlib.oauth2 import OAuth2Error
from authlib.oauth2 import ResourceProtector as _ResourceProtector
from authlib.oauth2.rfc6749 import MissingAuthorizationError
from .errors import raise_http_exception
from .requests import FlaskJsonRequest
from .signals import token_authenticated
class ResourceProtector(_ResourceProtector):
"""A protecting method for resource servers. Creating a ``require_oauth``
decorator easily with ResourceProtector::
from authlib.integrations.flask_oauth2 import ResourceProtector
require_oauth = ResourceProtector()
# add bearer token validator
from authlib.oauth2.rfc6750 import BearerTokenValidator
from project.models import Token
class MyBearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string):
return Token.query.filter_by(access_token=token_string).first()
require_oauth.register_token_validator(MyBearerTokenValidator())
# protect resource with require_oauth
@app.route("/user")
@require_oauth(["profile"])
def user_profile():
user = User.get(current_token.user_id)
return jsonify(user.to_dict())
"""
def raise_error_response(self, error):
"""Raise HTTPException for OAuth2Error. Developers can re-implement
this method to customize the error response.
:param error: OAuth2Error
:raise: HTTPException
"""
status = error.status_code
body = json.dumps(dict(error.get_body()))
headers = error.get_headers()
raise_http_exception(status, body, headers)
def acquire_token(self, scopes=None, **kwargs):
"""A method to acquire current valid token with the given scope.
:param scopes: a list of scope values
:return: token object
"""
request = FlaskJsonRequest(_req)
# backward compatibility
kwargs["scopes"] = scopes
for claim in kwargs:
if isinstance(kwargs[claim], str):
kwargs[claim] = [kwargs[claim]]
token = self.validate_request(request=request, **kwargs)
token_authenticated.send(self, token=token)
g.authlib_server_oauth2_token = token
return token
@contextmanager
def acquire(self, scopes=None):
"""The with statement of ``require_oauth``. Instead of using a
decorator, you can use a with statement instead::
@app.route("/api/user")
def user_api():
with require_oauth.acquire("profile") as token:
user = User.get(token.user_id)
return jsonify(user.to_dict())
"""
try:
yield self.acquire_token(scopes)
except OAuth2Error as error:
self.raise_error_response(error)
def __call__(self, scopes=None, optional=False, **kwargs):
claims = kwargs
# backward compatibility
claims["scopes"] = scopes
def wrapper(f):
@functools.wraps(f)
def decorated(*args, **kwargs):
try:
self.acquire_token(**claims)
except MissingAuthorizationError as error:
if optional:
return f(*args, **kwargs)
self.raise_error_response(error)
except OAuth2Error as error:
self.raise_error_response(error)
return f(*args, **kwargs)
return decorated
return wrapper
def _get_current_token():
return g.get("authlib_server_oauth2_token")
current_token = LocalProxy(_get_current_token)

View File

@@ -0,0 +1,12 @@
from flask.signals import Namespace
_signal = Namespace()
#: signal when client is authenticated
client_authenticated = _signal.signal("client_authenticated")
#: signal when token is revoked
token_revoked = _signal.signal("token_revoked")
#: signal when token is authenticated
token_authenticated = _signal.signal("token_authenticated")

View File

@@ -0,0 +1,36 @@
from authlib.oauth1 import SIGNATURE_HMAC_SHA1
from authlib.oauth1 import SIGNATURE_PLAINTEXT
from authlib.oauth1 import SIGNATURE_RSA_SHA1
from authlib.oauth1 import SIGNATURE_TYPE_BODY
from authlib.oauth1 import SIGNATURE_TYPE_HEADER
from authlib.oauth1 import SIGNATURE_TYPE_QUERY
from ..base_client import OAuthError
from .assertion_client import AssertionClient
from .assertion_client import AsyncAssertionClient
from .oauth1_client import AsyncOAuth1Client
from .oauth1_client import OAuth1Auth
from .oauth1_client import OAuth1Client
from .oauth2_client import AsyncOAuth2Client
from .oauth2_client import OAuth2Auth
from .oauth2_client import OAuth2Client
from .oauth2_client import OAuth2ClientAuth
__all__ = [
"OAuthError",
"OAuth1Auth",
"AsyncOAuth1Client",
"OAuth1Client",
"SIGNATURE_HMAC_SHA1",
"SIGNATURE_RSA_SHA1",
"SIGNATURE_PLAINTEXT",
"SIGNATURE_TYPE_HEADER",
"SIGNATURE_TYPE_QUERY",
"SIGNATURE_TYPE_BODY",
"OAuth2Auth",
"OAuth2ClientAuth",
"OAuth2Client",
"AsyncOAuth2Client",
"AssertionClient",
"AsyncAssertionClient",
]

View File

@@ -0,0 +1,124 @@
import httpx
from httpx import USE_CLIENT_DEFAULT
from httpx import Response
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from ..base_client import OAuthError
from .oauth2_client import OAuth2Auth
from .utils import extract_client_kwargs
__all__ = ["AsyncAssertionClient"]
class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(
self,
token_endpoint,
issuer,
subject,
audience=None,
grant_type=None,
claims=None,
token_placement="header",
scope=None,
**kwargs,
):
client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self,
session=None,
token_endpoint=token_endpoint,
issuer=issuer,
subject=subject,
audience=audience,
grant_type=grant_type,
claims=claims,
token_placement=token_placement,
scope=scope,
**kwargs,
)
async def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
) -> Response:
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
await self.refresh_token()
auth = self.token_auth
return await super().request(method, url, auth=auth, **kwargs)
async def _refresh_token(self, data):
resp = await self.request(
"POST", self.token_endpoint, data=data, withhold_token=True
)
return self.parse_response_token(resp)
class AssertionClient(_AssertionClient, httpx.Client):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(
self,
token_endpoint,
issuer,
subject,
audience=None,
grant_type=None,
claims=None,
token_placement="header",
scope=None,
**kwargs,
):
client_kwargs = extract_client_kwargs(kwargs)
# app keyword was dropped!
app_value = client_kwargs.pop("app", None)
if app_value is not None:
client_kwargs["transport"] = httpx.WSGITransport(app=app_value)
httpx.Client.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self,
session=self,
token_endpoint=token_endpoint,
issuer=issuer,
subject=subject,
audience=audience,
grant_type=grant_type,
claims=claims,
token_placement=token_placement,
scope=scope,
**kwargs,
)
def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
self.refresh_token()
auth = self.token_auth
return super().request(method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,145 @@
import typing
import httpx
from httpx import Auth
from httpx import Request
from httpx import Response
from authlib.common.encoding import to_unicode
from authlib.oauth1 import SIGNATURE_HMAC_SHA1
from authlib.oauth1 import SIGNATURE_TYPE_HEADER
from authlib.oauth1 import ClientAuth
from authlib.oauth1.client import OAuth1Client as _OAuth1Client
from ..base_client import OAuthError
from .utils import build_request
from .utils import extract_client_kwargs
class OAuth1Auth(Auth, ClientAuth):
"""Signs the httpx request using OAuth 1 (RFC5849)."""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content
)
headers["Content-Length"] = str(len(body))
yield build_request(
url=url, headers=headers, body=body, initial_request=request
)
class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient):
auth_class = OAuth1Auth
def __init__(
self,
client_id,
client_secret=None,
token=None,
token_secret=None,
redirect_uri=None,
rsa_key=None,
verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False,
**kwargs,
):
_client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self,
None,
client_id=client_id,
client_secret=client_secret,
token=token,
token_secret=token_secret,
redirect_uri=redirect_uri,
rsa_key=rsa_key,
verifier=verifier,
signature_method=signature_method,
signature_type=signature_type,
force_include_body=force_include_body,
**kwargs,
)
async def fetch_access_token(self, url, verifier=None, **kwargs):
"""Method for fetching an access token from the token endpoint.
This is the final step in the OAuth 1 workflow. An access token is
obtained using all previously obtained credentials, including the
verifier from the authorization step.
:param url: Access Token endpoint.
:param verifier: A verifier string to prove authorization was granted.
:param kwargs: Extra parameters to include for fetching access token.
:return: A token dict.
"""
if verifier:
self.auth.verifier = verifier
if not self.auth.verifier:
self.handle_error("missing_verifier", 'Missing "verifier" value')
token = await self._fetch_token(url, **kwargs)
self.auth.verifier = None
return token
async def _fetch_token(self, url, **kwargs):
resp = await self.post(url, **kwargs)
text = await resp.aread()
token = self.parse_response_token(resp.status_code, to_unicode(text))
self.token = token
return token
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
class OAuth1Client(_OAuth1Client, httpx.Client):
auth_class = OAuth1Auth
def __init__(
self,
client_id,
client_secret=None,
token=None,
token_secret=None,
redirect_uri=None,
rsa_key=None,
verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False,
**kwargs,
):
_client_kwargs = extract_client_kwargs(kwargs)
# app keyword was dropped!
app_value = _client_kwargs.pop("app", None)
if app_value is not None:
_client_kwargs["transport"] = httpx.WSGITransport(app=app_value)
httpx.Client.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self,
self,
client_id=client_id,
client_secret=client_secret,
token=token,
token_secret=token_secret,
redirect_uri=redirect_uri,
rsa_key=rsa_key,
verifier=verifier,
signature_method=signature_method,
signature_type=signature_type,
force_include_body=force_include_body,
**kwargs,
)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

View File

@@ -0,0 +1,285 @@
import typing
from contextlib import asynccontextmanager
import httpx
from anyio import Lock # Import after httpx so import errors refer to httpx
from httpx import USE_CLIENT_DEFAULT
from httpx import Auth
from httpx import Request
from httpx import Response
from authlib.common.urls import url_decode
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.auth import TokenAuth
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
from ..base_client import InvalidTokenError
from ..base_client import MissingTokenError
from ..base_client import OAuthError
from ..base_client import UnsupportedTokenTypeError
from .utils import HTTPX_CLIENT_KWARGS
from .utils import build_request
__all__ = [
"OAuth2Auth",
"OAuth2ClientAuth",
"AsyncOAuth2Client",
"OAuth2Client",
]
class OAuth2Auth(Auth, TokenAuth):
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
try:
url, headers, body = self.prepare(
str(request.url), request.headers, request.content
)
headers["Content-Length"] = str(len(body))
yield build_request(
url=url, headers=headers, body=body, initial_request=request
)
except KeyError as error:
description = f"Unsupported token_type: {str(error)}"
raise UnsupportedTokenTypeError(description=description) from error
class OAuth2ClientAuth(Auth, ClientAuth):
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content
)
headers["Content-Length"] = str(len(body))
yield build_request(
url=url, headers=headers, body=body, initial_request=request
)
class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(
self,
client_id=None,
client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None,
redirect_uri=None,
token=None,
token_placement="header",
update_token=None,
leeway=60,
**kwargs,
):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
# We use a Lock to synchronize coroutines to prevent
# multiple concurrent attempts to refresh the same token
self._token_refresh_lock = Lock()
_OAuth2Client.__init__(
self,
session=None,
client_id=client_id,
client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope,
redirect_uri=redirect_uri,
token=token,
token_placement=token_placement,
update_token=update_token,
leeway=leeway,
**kwargs,
)
async def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
return await super().request(method, url, auth=auth, **kwargs)
@asynccontextmanager
async def stream(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
async with super().stream(method, url, auth=auth, **kwargs) as resp:
yield resp
async def ensure_active_token(self, token):
async with self._token_refresh_lock:
if self.token.is_expired(leeway=self.leeway):
refresh_token = token.get("refresh_token")
url = self.metadata.get("token_endpoint")
if refresh_token and url:
await self.refresh_token(url, refresh_token=refresh_token)
elif self.metadata.get("grant_type") == "client_credentials":
access_token = token["access_token"]
new_token = await self.fetch_token(
url, grant_type="client_credentials"
)
if self.update_token:
await self.update_token(new_token, access_token=access_token)
else:
raise InvalidTokenError()
async def _fetch_token(
self,
url,
body="",
headers=None,
auth=USE_CLIENT_DEFAULT,
method="POST",
**kwargs,
):
if method.upper() == "POST":
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs
)
else:
if "?" in url:
url = "&".join([url, body])
else:
url = "?".join([url, body])
resp = await self.get(url, headers=headers, auth=auth, **kwargs)
for hook in self.compliance_hook["access_token_response"]:
resp = hook(resp)
return self.parse_response_token(resp)
async def _refresh_token(
self,
url,
refresh_token=None,
body="",
headers=None,
auth=USE_CLIENT_DEFAULT,
**kwargs,
):
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs
)
for hook in self.compliance_hook["refresh_token_response"]:
resp = hook(resp)
token = self.parse_response_token(resp)
if "refresh_token" not in token:
self.token["refresh_token"] = refresh_token
if self.update_token:
await self.update_token(self.token, refresh_token=refresh_token)
return self.token
def _http_post(
self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs
):
return self.post(
url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs
)
class OAuth2Client(_OAuth2Client, httpx.Client):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(
self,
client_id=None,
client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None,
redirect_uri=None,
token=None,
token_placement="header",
update_token=None,
**kwargs,
):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
# app keyword was dropped!
app_value = client_kwargs.pop("app", None)
if app_value is not None:
client_kwargs["transport"] = httpx.WSGITransport(app=app_value)
httpx.Client.__init__(self, **client_kwargs)
_OAuth2Client.__init__(
self,
session=self,
client_id=client_id,
client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope,
redirect_uri=redirect_uri,
token=token,
token_placement=token_placement,
update_token=update_token,
**kwargs,
)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
def request(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().request(method, url, auth=auth, **kwargs)
def stream(
self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs
):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().stream(method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,41 @@
from httpx import Request
HTTPX_CLIENT_KWARGS = [
"headers",
"cookies",
"verify",
"cert",
"http1",
"http2",
"proxy",
"mounts",
"timeout",
"follow_redirects",
"limits",
"max_redirects",
"event_hooks",
"base_url",
"transport",
"trust_env",
"default_encoding",
]
def extract_client_kwargs(kwargs):
client_kwargs = {}
for k in HTTPX_CLIENT_KWARGS:
if k in kwargs:
client_kwargs[k] = kwargs.pop(k)
return client_kwargs
def build_request(url, headers, body, initial_request: Request) -> Request:
"""Make sure that all the data from initial request is passed to the updated object."""
updated_request = Request(
method=initial_request.method, url=url, headers=headers, content=body
)
if hasattr(initial_request, "extensions"):
updated_request.extensions = initial_request.extensions
return updated_request

View File

@@ -0,0 +1,28 @@
from authlib.oauth1 import SIGNATURE_HMAC_SHA1
from authlib.oauth1 import SIGNATURE_PLAINTEXT
from authlib.oauth1 import SIGNATURE_RSA_SHA1
from authlib.oauth1 import SIGNATURE_TYPE_BODY
from authlib.oauth1 import SIGNATURE_TYPE_HEADER
from authlib.oauth1 import SIGNATURE_TYPE_QUERY
from ..base_client import OAuthError
from .assertion_session import AssertionSession
from .oauth1_session import OAuth1Auth
from .oauth1_session import OAuth1Session
from .oauth2_session import OAuth2Auth
from .oauth2_session import OAuth2Session
__all__ = [
"OAuthError",
"OAuth1Session",
"OAuth1Auth",
"SIGNATURE_HMAC_SHA1",
"SIGNATURE_RSA_SHA1",
"SIGNATURE_PLAINTEXT",
"SIGNATURE_TYPE_HEADER",
"SIGNATURE_TYPE_QUERY",
"SIGNATURE_TYPE_BODY",
"OAuth2Session",
"OAuth2Auth",
"AssertionSession",
]

View File

@@ -0,0 +1,70 @@
from requests import Session
from authlib.oauth2.rfc7521 import AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from .oauth2_session import OAuth2Auth
from .utils import update_session_configure
class AssertionAuth(OAuth2Auth):
def ensure_active_token(self):
if self.client and (
not self.token or self.token.is_expired(self.client.leeway)
):
return self.client.refresh_token()
class AssertionSession(AssertionClient, Session):
"""Constructs a new Assertion Framework for OAuth 2.0 Authorization Grants
per RFC7521_.
.. _RFC7521: https://tools.ietf.org/html/rfc7521
"""
token_auth_class = AssertionAuth
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(
self,
token_endpoint,
issuer,
subject,
audience=None,
grant_type=None,
claims=None,
token_placement="header",
scope=None,
default_timeout=None,
leeway=60,
**kwargs,
):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
AssertionClient.__init__(
self,
session=self,
token_endpoint=token_endpoint,
issuer=issuer,
subject=subject,
audience=audience,
grant_type=grant_type,
claims=claims,
token_placement=token_placement,
scope=scope,
leeway=leeway,
**kwargs,
)
def request(self, method, url, withhold_token=False, auth=None, **kwargs):
"""Send request with auto refresh token feature."""
if self.default_timeout:
kwargs.setdefault("timeout", self.default_timeout)
if not withhold_token and auth is None:
auth = self.token_auth
return super().request(method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,74 @@
from requests import Session
from requests.auth import AuthBase
from authlib.common.encoding import to_native
from authlib.oauth1 import SIGNATURE_HMAC_SHA1
from authlib.oauth1 import SIGNATURE_TYPE_HEADER
from authlib.oauth1 import ClientAuth
from authlib.oauth1.client import OAuth1Client
from ..base_client import OAuthError
from .utils import update_session_configure
class OAuth1Auth(AuthBase, ClientAuth):
"""Signs the request using OAuth 1 (RFC5849)."""
def __call__(self, req):
url, headers, body = self.prepare(req.method, req.url, req.headers, req.body)
req.url = to_native(url)
req.prepare_headers(headers)
if body:
req.body = body
return req
class OAuth1Session(OAuth1Client, Session):
auth_class = OAuth1Auth
def __init__(
self,
client_id,
client_secret=None,
token=None,
token_secret=None,
redirect_uri=None,
rsa_key=None,
verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False,
**kwargs,
):
Session.__init__(self)
update_session_configure(self, kwargs)
OAuth1Client.__init__(
self,
session=self,
client_id=client_id,
client_secret=client_secret,
token=token,
token_secret=token_secret,
redirect_uri=redirect_uri,
rsa_key=rsa_key,
verifier=verifier,
signature_method=signature_method,
signature_type=signature_type,
force_include_body=force_include_body,
**kwargs,
)
def rebuild_auth(self, prepared_request, response):
"""When being redirected we should always strip Authorization
header, since nonce may not be reused as per OAuth spec.
"""
if "Authorization" in prepared_request.headers:
# If we get redirected to a new host, we should strip out
# any authentication headers.
prepared_request.headers.pop("Authorization", True)
prepared_request.prepare_auth(self.auth)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

View File

@@ -0,0 +1,140 @@
from requests import Session
from requests.auth import AuthBase
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.auth import TokenAuth
from authlib.oauth2.client import OAuth2Client
from ..base_client import InvalidTokenError
from ..base_client import MissingTokenError
from ..base_client import OAuthError
from ..base_client import UnsupportedTokenTypeError
from .utils import update_session_configure
__all__ = ["OAuth2Session", "OAuth2Auth"]
class OAuth2Auth(AuthBase, TokenAuth):
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
def ensure_active_token(self):
if self.client and not self.client.ensure_active_token(self.token):
raise InvalidTokenError()
def __call__(self, req):
self.ensure_active_token()
try:
req.url, req.headers, req.body = self.prepare(
req.url, req.headers, req.body
)
except KeyError as error:
description = f"Unsupported token_type: {str(error)}"
raise UnsupportedTokenTypeError(description=description) from error
return req
class OAuth2ClientAuth(AuthBase, ClientAuth):
"""Attaches OAuth Client Authentication to the given Request object."""
def __call__(self, req):
req.url, req.headers, req.body = self.prepare(
req.method, req.url, req.headers, req.body
)
return req
class OAuth2Session(OAuth2Client, Session):
"""Construct a new OAuth 2 client requests session.
:param client_id: Client ID, which you get from client registration.
:param client_secret: Client Secret, which you get from registration.
:param authorization_endpoint: URL of the authorization server's
authorization endpoint.
:param token_endpoint: URL of the authorization server's token endpoint.
:param token_endpoint_auth_method: client authentication method for
token endpoint.
:param revocation_endpoint: URL of the authorization server's OAuth 2.0
revocation endpoint.
:param revocation_endpoint_auth_method: client authentication method for
revocation endpoint.
:param scope: Scope that you needed to access user resources.
:param state: Shared secret to prevent CSRF attack.
:param redirect_uri: Redirect URI you registered as callback.
:param token: A dict of token attributes such as ``access_token``,
``token_type`` and ``expires_at``.
:param token_placement: The place to put token in HTTP request. Available
values: "header", "body", "uri".
:param update_token: A function for you to update token. It accept a
:class:`OAuth2Token` as parameter.
:param leeway: Time window in seconds before the actual expiration of the
authentication token, that the token is considered expired and will
be refreshed.
:param default_timeout: If settled, every requests will have a default timeout.
"""
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
SESSION_REQUEST_PARAMS = (
"allow_redirects",
"timeout",
"cookies",
"files",
"proxies",
"hooks",
"stream",
"verify",
"cert",
"json",
)
def __init__(
self,
client_id=None,
client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None,
state=None,
redirect_uri=None,
token=None,
token_placement="header",
update_token=None,
leeway=60,
default_timeout=None,
**kwargs,
):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
OAuth2Client.__init__(
self,
session=self,
client_id=client_id,
client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope,
state=state,
redirect_uri=redirect_uri,
token=token,
token_placement=token_placement,
update_token=update_token,
leeway=leeway,
**kwargs,
)
def fetch_access_token(self, url=None, **kwargs):
"""Alias for fetch_token."""
return self.fetch_token(url, **kwargs)
def request(self, method, url, withhold_token=False, auth=None, **kwargs):
"""Send request with auto refresh token feature (if available)."""
if self.default_timeout:
kwargs.setdefault("timeout", self.default_timeout)
if not withhold_token and auth is None:
if not self.token:
raise MissingTokenError()
auth = self.token_auth
return super().request(method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,15 @@
REQUESTS_SESSION_KWARGS = [
"proxies",
"hooks",
"stream",
"verify",
"cert",
"max_redirects",
"trust_env",
]
def update_session_configure(session, kwargs):
for k in REQUESTS_SESSION_KWARGS:
if k in kwargs:
setattr(session, k, kwargs.pop(k))

View File

@@ -0,0 +1,19 @@
from .client_mixin import OAuth2ClientMixin
from .functions import create_bearer_token_validator
from .functions import create_query_client_func
from .functions import create_query_token_func
from .functions import create_revocation_endpoint
from .functions import create_save_token_func
from .tokens_mixins import OAuth2AuthorizationCodeMixin
from .tokens_mixins import OAuth2TokenMixin
__all__ = [
"OAuth2ClientMixin",
"OAuth2AuthorizationCodeMixin",
"OAuth2TokenMixin",
"create_query_client_func",
"create_save_token_func",
"create_query_token_func",
"create_revocation_endpoint",
"create_bearer_token_validator",
]

View File

@@ -0,0 +1,147 @@
import secrets
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Text
from authlib.common.encoding import json_dumps
from authlib.common.encoding import json_loads
from authlib.oauth2.rfc6749 import ClientMixin
from authlib.oauth2.rfc6749 import list_to_scope
from authlib.oauth2.rfc6749 import scope_to_list
class OAuth2ClientMixin(ClientMixin):
client_id = Column(String(48), index=True)
client_secret = Column(String(120))
client_id_issued_at = Column(Integer, nullable=False, default=0)
client_secret_expires_at = Column(Integer, nullable=False, default=0)
_client_metadata = Column("client_metadata", Text)
@property
def client_info(self):
"""Implementation for Client Info in OAuth 2.0 Dynamic Client
Registration Protocol via `Section 3.2.1`_.
.. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
"""
return dict(
client_id=self.client_id,
client_secret=self.client_secret,
client_id_issued_at=self.client_id_issued_at,
client_secret_expires_at=self.client_secret_expires_at,
)
@property
def client_metadata(self):
if "client_metadata" in self.__dict__:
return self.__dict__["client_metadata"]
if self._client_metadata:
data = json_loads(self._client_metadata)
self.__dict__["client_metadata"] = data
return data
return {}
def set_client_metadata(self, value):
self._client_metadata = json_dumps(value)
if "client_metadata" in self.__dict__:
del self.__dict__["client_metadata"]
@property
def redirect_uris(self):
return self.client_metadata.get("redirect_uris", [])
@property
def token_endpoint_auth_method(self):
return self.client_metadata.get(
"token_endpoint_auth_method", "client_secret_basic"
)
@property
def grant_types(self):
return self.client_metadata.get("grant_types", [])
@property
def response_types(self):
return self.client_metadata.get("response_types", [])
@property
def client_name(self):
return self.client_metadata.get("client_name")
@property
def client_uri(self):
return self.client_metadata.get("client_uri")
@property
def logo_uri(self):
return self.client_metadata.get("logo_uri")
@property
def scope(self):
return self.client_metadata.get("scope", "")
@property
def contacts(self):
return self.client_metadata.get("contacts", [])
@property
def tos_uri(self):
return self.client_metadata.get("tos_uri")
@property
def policy_uri(self):
return self.client_metadata.get("policy_uri")
@property
def jwks_uri(self):
return self.client_metadata.get("jwks_uri")
@property
def jwks(self):
return self.client_metadata.get("jwks", [])
@property
def software_id(self):
return self.client_metadata.get("software_id")
@property
def software_version(self):
return self.client_metadata.get("software_version")
@property
def id_token_signed_response_alg(self):
return self.client_metadata.get("id_token_signed_response_alg")
def get_client_id(self):
return self.client_id
def get_default_redirect_uri(self):
if self.redirect_uris:
return self.redirect_uris[0]
def get_allowed_scope(self, scope):
if not scope:
return ""
allowed = set(self.scope.split())
scopes = scope_to_list(scope)
return list_to_scope([s for s in scopes if s in allowed])
def check_redirect_uri(self, redirect_uri):
return redirect_uri in self.redirect_uris
def check_client_secret(self, client_secret):
return secrets.compare_digest(self.client_secret, client_secret)
def check_endpoint_auth_method(self, method, endpoint):
if endpoint == "token":
return self.token_endpoint_auth_method == method
# TODO
return True
def check_response_type(self, response_type):
return response_type in self.response_types
def check_grant_type(self, grant_type):
return grant_type in self.grant_types

View File

@@ -0,0 +1,104 @@
import time
def create_query_client_func(session, client_model):
"""Create an ``query_client`` function that can be used in authorization
server.
:param session: SQLAlchemy session
:param client_model: Client model class
"""
def query_client(client_id):
q = session.query(client_model)
return q.filter_by(client_id=client_id).first()
return query_client
def create_save_token_func(session, token_model):
"""Create an ``save_token`` function that can be used in authorization
server.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
def save_token(token, request):
if request.user:
user_id = request.user.get_user_id()
else:
user_id = None
client = request.client
item = token_model(client_id=client.client_id, user_id=user_id, **token)
session.add(item)
session.commit()
return save_token
def create_query_token_func(session, token_model):
"""Create an ``query_token`` function for revocation, introspection
token endpoints.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
def query_token(token, token_type_hint):
q = session.query(token_model)
if token_type_hint == "access_token":
return q.filter_by(access_token=token).first()
elif token_type_hint == "refresh_token":
return q.filter_by(refresh_token=token).first()
# without token_type_hint
item = q.filter_by(access_token=token).first()
if item:
return item
return q.filter_by(refresh_token=token).first()
return query_token
def create_revocation_endpoint(session, token_model):
"""Create a revocation endpoint class with SQLAlchemy session
and token model.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
from authlib.oauth2.rfc7009 import RevocationEndpoint
query_token = create_query_token_func(session, token_model)
class _RevocationEndpoint(RevocationEndpoint):
def query_token(self, token, token_type_hint):
return query_token(token, token_type_hint)
def revoke_token(self, token, request):
now = int(time.time())
hint = request.form.get("token_type_hint")
token.access_token_revoked_at = now
if hint != "access_token":
token.refresh_token_revoked_at = now
session.add(token)
session.commit()
return _RevocationEndpoint
def create_bearer_token_validator(session, token_model):
"""Create an bearer token validator class with SQLAlchemy session
and token model.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
from authlib.oauth2.rfc6750 import BearerTokenValidator
class _BearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string):
q = session.query(token_model)
return q.filter_by(access_token=token_string).first()
return _BearerTokenValidator

View File

@@ -0,0 +1,76 @@
import time
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Text
from authlib.oauth2.rfc6749 import AuthorizationCodeMixin
from authlib.oauth2.rfc6749 import TokenMixin
class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin):
code = Column(String(120), unique=True, nullable=False)
client_id = Column(String(48))
redirect_uri = Column(Text, default="")
response_type = Column(Text, default="")
scope = Column(Text, default="")
nonce = Column(Text)
auth_time = Column(Integer, nullable=False, default=lambda: int(time.time()))
acr = Column(Text, nullable=True)
amr = Column(Text, nullable=True)
code_challenge = Column(Text)
code_challenge_method = Column(String(48))
def is_expired(self):
return self.auth_time + 300 < time.time()
def get_redirect_uri(self):
return self.redirect_uri
def get_scope(self):
return self.scope
def get_auth_time(self):
return self.auth_time
def get_acr(self):
return self.acr
def get_amr(self):
return self.amr.split() if self.amr else []
def get_nonce(self):
return self.nonce
class OAuth2TokenMixin(TokenMixin):
client_id = Column(String(48))
token_type = Column(String(40))
access_token = Column(String(255), unique=True, nullable=False)
refresh_token = Column(String(255), index=True)
scope = Column(Text, default="")
issued_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
access_token_revoked_at = Column(Integer, nullable=False, default=0)
refresh_token_revoked_at = Column(Integer, nullable=False, default=0)
expires_in = Column(Integer, nullable=False, default=0)
def check_client(self, client):
return self.client_id == client.get_client_id()
def get_scope(self):
return self.scope
def get_expires_in(self):
return self.expires_in
def is_revoked(self):
return self.access_token_revoked_at or self.refresh_token_revoked_at
def is_expired(self):
if not self.expires_in:
return False
expires_at = self.issued_at + self.expires_in
return expires_at < time.time()

View File

@@ -0,0 +1,26 @@
from ..base_client import BaseOAuth
from ..base_client import OAuthError
from .apps import StarletteOAuth1App
from .apps import StarletteOAuth2App
from .integration import StarletteIntegration
class OAuth(BaseOAuth):
oauth1_client_cls = StarletteOAuth1App
oauth2_client_cls = StarletteOAuth2App
framework_integration_cls = StarletteIntegration
def __init__(self, config=None, cache=None, fetch_token=None, update_token=None):
super().__init__(
cache=cache, fetch_token=fetch_token, update_token=update_token
)
self.config = config
__all__ = [
"OAuth",
"OAuthError",
"StarletteIntegration",
"StarletteOAuth1App",
"StarletteOAuth2App",
]

Some files were not shown because too many files have changed in this diff Show More