This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,19 @@
from .client_mixin import OAuth2ClientMixin
from .functions import create_bearer_token_validator
from .functions import create_query_client_func
from .functions import create_query_token_func
from .functions import create_revocation_endpoint
from .functions import create_save_token_func
from .tokens_mixins import OAuth2AuthorizationCodeMixin
from .tokens_mixins import OAuth2TokenMixin
__all__ = [
"OAuth2ClientMixin",
"OAuth2AuthorizationCodeMixin",
"OAuth2TokenMixin",
"create_query_client_func",
"create_save_token_func",
"create_query_token_func",
"create_revocation_endpoint",
"create_bearer_token_validator",
]

View File

@@ -0,0 +1,147 @@
import secrets
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Text
from authlib.common.encoding import json_dumps
from authlib.common.encoding import json_loads
from authlib.oauth2.rfc6749 import ClientMixin
from authlib.oauth2.rfc6749 import list_to_scope
from authlib.oauth2.rfc6749 import scope_to_list
class OAuth2ClientMixin(ClientMixin):
client_id = Column(String(48), index=True)
client_secret = Column(String(120))
client_id_issued_at = Column(Integer, nullable=False, default=0)
client_secret_expires_at = Column(Integer, nullable=False, default=0)
_client_metadata = Column("client_metadata", Text)
@property
def client_info(self):
"""Implementation for Client Info in OAuth 2.0 Dynamic Client
Registration Protocol via `Section 3.2.1`_.
.. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
"""
return dict(
client_id=self.client_id,
client_secret=self.client_secret,
client_id_issued_at=self.client_id_issued_at,
client_secret_expires_at=self.client_secret_expires_at,
)
@property
def client_metadata(self):
if "client_metadata" in self.__dict__:
return self.__dict__["client_metadata"]
if self._client_metadata:
data = json_loads(self._client_metadata)
self.__dict__["client_metadata"] = data
return data
return {}
def set_client_metadata(self, value):
self._client_metadata = json_dumps(value)
if "client_metadata" in self.__dict__:
del self.__dict__["client_metadata"]
@property
def redirect_uris(self):
return self.client_metadata.get("redirect_uris", [])
@property
def token_endpoint_auth_method(self):
return self.client_metadata.get(
"token_endpoint_auth_method", "client_secret_basic"
)
@property
def grant_types(self):
return self.client_metadata.get("grant_types", [])
@property
def response_types(self):
return self.client_metadata.get("response_types", [])
@property
def client_name(self):
return self.client_metadata.get("client_name")
@property
def client_uri(self):
return self.client_metadata.get("client_uri")
@property
def logo_uri(self):
return self.client_metadata.get("logo_uri")
@property
def scope(self):
return self.client_metadata.get("scope", "")
@property
def contacts(self):
return self.client_metadata.get("contacts", [])
@property
def tos_uri(self):
return self.client_metadata.get("tos_uri")
@property
def policy_uri(self):
return self.client_metadata.get("policy_uri")
@property
def jwks_uri(self):
return self.client_metadata.get("jwks_uri")
@property
def jwks(self):
return self.client_metadata.get("jwks", [])
@property
def software_id(self):
return self.client_metadata.get("software_id")
@property
def software_version(self):
return self.client_metadata.get("software_version")
@property
def id_token_signed_response_alg(self):
return self.client_metadata.get("id_token_signed_response_alg")
def get_client_id(self):
return self.client_id
def get_default_redirect_uri(self):
if self.redirect_uris:
return self.redirect_uris[0]
def get_allowed_scope(self, scope):
if not scope:
return ""
allowed = set(self.scope.split())
scopes = scope_to_list(scope)
return list_to_scope([s for s in scopes if s in allowed])
def check_redirect_uri(self, redirect_uri):
return redirect_uri in self.redirect_uris
def check_client_secret(self, client_secret):
return secrets.compare_digest(self.client_secret, client_secret)
def check_endpoint_auth_method(self, method, endpoint):
if endpoint == "token":
return self.token_endpoint_auth_method == method
# TODO
return True
def check_response_type(self, response_type):
return response_type in self.response_types
def check_grant_type(self, grant_type):
return grant_type in self.grant_types

View File

@@ -0,0 +1,104 @@
import time
def create_query_client_func(session, client_model):
"""Create an ``query_client`` function that can be used in authorization
server.
:param session: SQLAlchemy session
:param client_model: Client model class
"""
def query_client(client_id):
q = session.query(client_model)
return q.filter_by(client_id=client_id).first()
return query_client
def create_save_token_func(session, token_model):
"""Create an ``save_token`` function that can be used in authorization
server.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
def save_token(token, request):
if request.user:
user_id = request.user.get_user_id()
else:
user_id = None
client = request.client
item = token_model(client_id=client.client_id, user_id=user_id, **token)
session.add(item)
session.commit()
return save_token
def create_query_token_func(session, token_model):
"""Create an ``query_token`` function for revocation, introspection
token endpoints.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
def query_token(token, token_type_hint):
q = session.query(token_model)
if token_type_hint == "access_token":
return q.filter_by(access_token=token).first()
elif token_type_hint == "refresh_token":
return q.filter_by(refresh_token=token).first()
# without token_type_hint
item = q.filter_by(access_token=token).first()
if item:
return item
return q.filter_by(refresh_token=token).first()
return query_token
def create_revocation_endpoint(session, token_model):
"""Create a revocation endpoint class with SQLAlchemy session
and token model.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
from authlib.oauth2.rfc7009 import RevocationEndpoint
query_token = create_query_token_func(session, token_model)
class _RevocationEndpoint(RevocationEndpoint):
def query_token(self, token, token_type_hint):
return query_token(token, token_type_hint)
def revoke_token(self, token, request):
now = int(time.time())
hint = request.form.get("token_type_hint")
token.access_token_revoked_at = now
if hint != "access_token":
token.refresh_token_revoked_at = now
session.add(token)
session.commit()
return _RevocationEndpoint
def create_bearer_token_validator(session, token_model):
"""Create an bearer token validator class with SQLAlchemy session
and token model.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
from authlib.oauth2.rfc6750 import BearerTokenValidator
class _BearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string):
q = session.query(token_model)
return q.filter_by(access_token=token_string).first()
return _BearerTokenValidator

View File

@@ -0,0 +1,76 @@
import time
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Text
from authlib.oauth2.rfc6749 import AuthorizationCodeMixin
from authlib.oauth2.rfc6749 import TokenMixin
class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin):
code = Column(String(120), unique=True, nullable=False)
client_id = Column(String(48))
redirect_uri = Column(Text, default="")
response_type = Column(Text, default="")
scope = Column(Text, default="")
nonce = Column(Text)
auth_time = Column(Integer, nullable=False, default=lambda: int(time.time()))
acr = Column(Text, nullable=True)
amr = Column(Text, nullable=True)
code_challenge = Column(Text)
code_challenge_method = Column(String(48))
def is_expired(self):
return self.auth_time + 300 < time.time()
def get_redirect_uri(self):
return self.redirect_uri
def get_scope(self):
return self.scope
def get_auth_time(self):
return self.auth_time
def get_acr(self):
return self.acr
def get_amr(self):
return self.amr.split() if self.amr else []
def get_nonce(self):
return self.nonce
class OAuth2TokenMixin(TokenMixin):
client_id = Column(String(48))
token_type = Column(String(40))
access_token = Column(String(255), unique=True, nullable=False)
refresh_token = Column(String(255), index=True)
scope = Column(Text, default="")
issued_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
access_token_revoked_at = Column(Integer, nullable=False, default=0)
refresh_token_revoked_at = Column(Integer, nullable=False, default=0)
expires_in = Column(Integer, nullable=False, default=0)
def check_client(self, client):
return self.client_id == client.get_client_id()
def get_scope(self):
return self.scope
def get_expires_in(self):
return self.expires_in
def is_revoked(self):
return self.access_token_revoked_at or self.refresh_token_revoked_at
def is_expired(self):
if not self.expires_in:
return False
expires_at = self.issued_at + self.expires_in
return expires_at < time.time()