This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,16 @@
"""authlib.jose.rfc7517.
~~~~~~~~~~~~~~~~~~~~~
This module represents a direct implementation of
JSON Web Key (JWK).
https://tools.ietf.org/html/rfc7517
"""
from ._cryptography_key import load_pem_key
from .asymmetric_key import AsymmetricKey
from .base_key import Key
from .jwk import JsonWebKey
from .key_set import KeySet
__all__ = ["Key", "AsymmetricKey", "KeySet", "JsonWebKey", "load_pem_key"]

View File

@@ -0,0 +1,35 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from cryptography.hazmat.primitives.serialization import load_ssh_public_key
from cryptography.x509 import load_pem_x509_certificate
from authlib.common.encoding import to_bytes
def load_pem_key(raw, ssh_type=None, key_type=None, password=None):
raw = to_bytes(raw)
if ssh_type and raw.startswith(ssh_type):
return load_ssh_public_key(raw, backend=default_backend())
if key_type == "public":
return load_pem_public_key(raw, backend=default_backend())
if key_type == "private" or password is not None:
return load_pem_private_key(raw, password=password, backend=default_backend())
if b"PUBLIC" in raw:
return load_pem_public_key(raw, backend=default_backend())
if b"PRIVATE" in raw:
return load_pem_private_key(raw, password=password, backend=default_backend())
if b"CERTIFICATE" in raw:
cert = load_pem_x509_certificate(raw, default_backend())
return cert.public_key()
try:
return load_pem_private_key(raw, password=password, backend=default_backend())
except ValueError:
return load_pem_public_key(raw, backend=default_backend())

View File

@@ -0,0 +1,196 @@
from cryptography.hazmat.primitives.serialization import BestAvailableEncryption
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import NoEncryption
from cryptography.hazmat.primitives.serialization import PrivateFormat
from cryptography.hazmat.primitives.serialization import PublicFormat
from authlib.common.encoding import to_bytes
from ._cryptography_key import load_pem_key
from .base_key import Key
class AsymmetricKey(Key):
"""This is the base class for a JSON Web Key."""
PUBLIC_KEY_FIELDS = []
PRIVATE_KEY_FIELDS = []
PRIVATE_KEY_CLS = bytes
PUBLIC_KEY_CLS = bytes
SSH_PUBLIC_PREFIX = b""
def __init__(self, private_key=None, public_key=None, options=None):
super().__init__(options)
self.private_key = private_key
self.public_key = public_key
@property
def public_only(self):
if self.private_key:
return False
if "d" in self.tokens:
return False
return True
def get_op_key(self, operation):
"""Get the raw key for the given key_op. This method will also
check if the given key_op is supported by this key.
:param operation: key operation value, such as "sign", "encrypt".
:return: raw key
"""
self.check_key_op(operation)
if operation in self.PUBLIC_KEY_OPS:
return self.get_public_key()
return self.get_private_key()
def get_public_key(self):
if self.public_key:
return self.public_key
private_key = self.get_private_key()
if private_key:
return private_key.public_key()
return self.public_key
def get_private_key(self):
if self.private_key:
return self.private_key
if self.tokens:
self.load_raw_key()
return self.private_key
def load_raw_key(self):
if "d" in self.tokens:
self.private_key = self.load_private_key()
else:
self.public_key = self.load_public_key()
def load_dict_key(self):
if self.private_key:
self._dict_data.update(self.dumps_private_key())
else:
self._dict_data.update(self.dumps_public_key())
def dumps_private_key(self):
raise NotImplementedError()
def dumps_public_key(self):
raise NotImplementedError()
def load_private_key(self):
raise NotImplementedError()
def load_public_key(self):
raise NotImplementedError()
def as_dict(self, is_private=False, **params):
"""Represent this key as a dict of the JSON Web Key."""
tokens = self.tokens
if is_private and "d" not in tokens:
raise ValueError("This is a public key")
kid = tokens.get("kid")
if "d" in tokens and not is_private:
# filter out private fields
tokens = {k: tokens[k] for k in tokens if k in self.PUBLIC_KEY_FIELDS}
tokens["kty"] = self.kty
if kid:
tokens["kid"] = kid
if not kid:
tokens["kid"] = self.thumbprint()
tokens.update(params)
return tokens
def as_key(self, is_private=False):
"""Represent this key as raw key."""
if is_private:
return self.get_private_key()
return self.get_public_key()
def as_bytes(self, encoding=None, is_private=False, password=None):
"""Export key into PEM/DER format bytes.
:param encoding: "PEM" or "DER"
:param is_private: export private key or public key
:param password: encrypt private key with password
:return: bytes
"""
if encoding is None or encoding == "PEM":
encoding = Encoding.PEM
elif encoding == "DER":
encoding = Encoding.DER
else:
raise ValueError(f"Invalid encoding: {encoding!r}")
raw_key = self.as_key(is_private)
if is_private:
if not raw_key:
raise ValueError("This is a public key")
if password is None:
encryption_algorithm = NoEncryption()
else:
encryption_algorithm = BestAvailableEncryption(to_bytes(password))
return raw_key.private_bytes(
encoding=encoding,
format=PrivateFormat.PKCS8,
encryption_algorithm=encryption_algorithm,
)
return raw_key.public_bytes(
encoding=encoding,
format=PublicFormat.SubjectPublicKeyInfo,
)
def as_pem(self, is_private=False, password=None):
return self.as_bytes(is_private=is_private, password=password)
def as_der(self, is_private=False, password=None):
return self.as_bytes(encoding="DER", is_private=is_private, password=password)
@classmethod
def import_dict_key(cls, raw, options=None):
cls.check_required_fields(raw)
key = cls(options=options)
key._dict_data = raw
return key
@classmethod
def import_key(cls, raw, options=None):
if isinstance(raw, cls):
if options is not None:
raw.options.update(options)
return raw
if isinstance(raw, cls.PUBLIC_KEY_CLS):
key = cls(public_key=raw, options=options)
elif isinstance(raw, cls.PRIVATE_KEY_CLS):
key = cls(private_key=raw, options=options)
elif isinstance(raw, dict):
key = cls.import_dict_key(raw, options)
else:
if options is not None:
password = options.pop("password", None)
else:
password = None
raw_key = load_pem_key(raw, cls.SSH_PUBLIC_PREFIX, password=password)
if isinstance(raw_key, cls.PUBLIC_KEY_CLS):
key = cls(public_key=raw_key, options=options)
elif isinstance(raw_key, cls.PRIVATE_KEY_CLS):
key = cls(private_key=raw_key, options=options)
else:
raise ValueError("Invalid data for importing key")
return key
@classmethod
def validate_raw_key(cls, key):
return isinstance(key, cls.PUBLIC_KEY_CLS) or isinstance(
key, cls.PRIVATE_KEY_CLS
)
@classmethod
def generate_key(cls, crv_or_size, options=None, is_private=False):
raise NotImplementedError()

View File

@@ -0,0 +1,120 @@
import hashlib
from collections import OrderedDict
from authlib.common.encoding import json_dumps
from authlib.common.encoding import to_bytes
from authlib.common.encoding import to_unicode
from authlib.common.encoding import urlsafe_b64encode
from ..errors import InvalidUseError
class Key:
"""This is the base class for a JSON Web Key."""
kty = "_"
ALLOWED_PARAMS = ["use", "key_ops", "alg", "kid", "x5u", "x5c", "x5t", "x5t#S256"]
PRIVATE_KEY_OPS = [
"sign",
"decrypt",
"unwrapKey",
]
PUBLIC_KEY_OPS = [
"verify",
"encrypt",
"wrapKey",
]
REQUIRED_JSON_FIELDS = []
def __init__(self, options=None):
self.options = options or {}
self._dict_data = {}
@property
def tokens(self):
if not self._dict_data:
self.load_dict_key()
rv = dict(self._dict_data)
rv["kty"] = self.kty
for k in self.ALLOWED_PARAMS:
if k not in rv and k in self.options:
rv[k] = self.options[k]
return rv
@property
def kid(self):
return self.tokens.get("kid")
def keys(self):
return self.tokens.keys()
def __getitem__(self, item):
return self.tokens[item]
@property
def public_only(self):
raise NotImplementedError()
def load_raw_key(self):
raise NotImplementedError()
def load_dict_key(self):
raise NotImplementedError()
def check_key_op(self, operation):
"""Check if the given key_op is supported by this key.
:param operation: key operation value, such as "sign", "encrypt".
:raise: ValueError
"""
key_ops = self.tokens.get("key_ops")
if key_ops is not None and operation not in key_ops:
raise ValueError(f'Unsupported key_op "{operation}"')
if operation in self.PRIVATE_KEY_OPS and self.public_only:
raise ValueError(f'Invalid key_op "{operation}" for public key')
use = self.tokens.get("use")
if use:
if operation in ["sign", "verify"]:
if use != "sig":
raise InvalidUseError()
elif operation in ["decrypt", "encrypt", "wrapKey", "unwrapKey"]:
if use != "enc":
raise InvalidUseError()
def as_dict(self, is_private=False, **params):
raise NotImplementedError()
def as_json(self, is_private=False, **params):
"""Represent this key as a JSON string."""
obj = self.as_dict(is_private, **params)
return json_dumps(obj)
def thumbprint(self):
"""Implementation of RFC7638 JSON Web Key (JWK) Thumbprint."""
fields = list(self.REQUIRED_JSON_FIELDS)
fields.append("kty")
fields.sort()
data = OrderedDict()
for k in fields:
data[k] = self.tokens[k]
json_data = json_dumps(data)
digest_data = hashlib.sha256(to_bytes(json_data)).digest()
return to_unicode(urlsafe_b64encode(digest_data))
@classmethod
def check_required_fields(cls, data):
for k in cls.REQUIRED_JSON_FIELDS:
if k not in data:
raise ValueError(f'Missing required field: "{k}"')
@classmethod
def validate_raw_key(cls, key):
raise NotImplementedError()

View File

@@ -0,0 +1,64 @@
from authlib.common.encoding import json_loads
from ._cryptography_key import load_pem_key
from .key_set import KeySet
class JsonWebKey:
JWK_KEY_CLS = {}
@classmethod
def generate_key(cls, kty, crv_or_size, options=None, is_private=False):
"""Generate a Key with the given key type, curve name or bit size.
:param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP``
:param crv_or_size: curve name or bit size
:param options: a dict of other options for Key
:param is_private: create a private key or public key
:return: Key instance
"""
key_cls = cls.JWK_KEY_CLS[kty]
return key_cls.generate_key(crv_or_size, options, is_private)
@classmethod
def import_key(cls, raw, options=None):
"""Import a Key from bytes, string, PEM or dict.
:return: Key instance
"""
kty = None
if options is not None:
kty = options.get("kty")
if kty is None and isinstance(raw, dict):
kty = raw.get("kty")
if kty is None:
raw_key = load_pem_key(raw)
for _kty in cls.JWK_KEY_CLS:
key_cls = cls.JWK_KEY_CLS[_kty]
if key_cls.validate_raw_key(raw_key):
return key_cls.import_key(raw_key, options)
key_cls = cls.JWK_KEY_CLS[kty]
return key_cls.import_key(raw, options)
@classmethod
def import_key_set(cls, raw):
"""Import KeySet from string, dict or a list of keys.
:return: KeySet instance
"""
raw = _transform_raw_key(raw)
if isinstance(raw, dict) and "keys" in raw:
keys = raw.get("keys")
return KeySet([cls.import_key(k) for k in keys])
raise ValueError("Invalid key set format")
def _transform_raw_key(raw):
if isinstance(raw, str) and raw.startswith("{") and raw.endswith("}"):
return json_loads(raw)
elif isinstance(raw, (tuple, list)):
return {"keys": raw}
return raw

View File

@@ -0,0 +1,53 @@
from authlib.common.encoding import json_dumps
class KeySet:
"""This class represents a JSON Web Key Set."""
def __init__(self, keys):
self.keys = keys
def as_dict(self, is_private=False, **params):
"""Represent this key as a dict of the JSON Web Key Set."""
return {"keys": [k.as_dict(is_private, **params) for k in self.keys]}
def as_json(self, is_private=False, **params):
"""Represent this key set as a JSON string."""
obj = self.as_dict(is_private, **params)
return json_dumps(obj)
def find_by_kid(self, kid, **params):
"""Find the key matches the given kid value.
:param kid: A string of kid
:return: Key instance
:raise: ValueError
"""
# Proposed fix, feel free to do something else but the idea is that we take the only key
# of the set if no kid is specified
if kid is None and len(self.keys) == 1:
return self.keys[0]
keys = [key for key in self.keys if key.kid == kid]
if params:
keys = list(_filter_keys_by_params(keys, **params))
if keys:
return keys[0]
raise ValueError("Key not found")
def _filter_keys_by_params(keys, **params):
_use = params.get("use")
_alg = params.get("alg")
for key in keys:
designed_use = key.tokens.get("use")
if designed_use and _use and designed_use != _use:
continue
designed_alg = key.tokens.get("alg")
if designed_alg and _alg and designed_alg != _alg:
continue
yield key