This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,39 @@
from .ec_key import ECKey
from .jwe_algs import JWE_ALG_ALGORITHMS
from .jwe_algs import AESAlgorithm
from .jwe_algs import ECDHESAlgorithm
from .jwe_algs import u32be_len_input
from .jwe_encs import JWE_ENC_ALGORITHMS
from .jwe_encs import CBCHS2EncAlgorithm
from .jwe_zips import DeflateZipAlgorithm
from .jws_algs import JWS_ALGORITHMS
from .oct_key import OctKey
from .rsa_key import RSAKey
def register_jws_rfc7518(cls):
for algorithm in JWS_ALGORITHMS:
cls.register_algorithm(algorithm)
def register_jwe_rfc7518(cls):
for algorithm in JWE_ALG_ALGORITHMS:
cls.register_algorithm(algorithm)
for algorithm in JWE_ENC_ALGORITHMS:
cls.register_algorithm(algorithm)
cls.register_algorithm(DeflateZipAlgorithm())
__all__ = [
"register_jws_rfc7518",
"register_jwe_rfc7518",
"OctKey",
"RSAKey",
"ECKey",
"u32be_len_input",
"AESAlgorithm",
"ECDHESAlgorithm",
"CBCHS2EncAlgorithm",
]

View File

@@ -0,0 +1,108 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1
from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1
from cryptography.hazmat.primitives.asymmetric.ec import SECP521R1
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKeyWithSerialization,
)
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateNumbers
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from authlib.common.encoding import base64_to_int
from authlib.common.encoding import int_to_base64
from ..rfc7517 import AsymmetricKey
class ECKey(AsymmetricKey):
"""Key class of the ``EC`` key type."""
kty = "EC"
DSS_CURVES = {
"P-256": SECP256R1,
"P-384": SECP384R1,
"P-521": SECP521R1,
# https://tools.ietf.org/html/rfc8812#section-3.1
"secp256k1": SECP256K1,
}
CURVES_DSS = {
SECP256R1.name: "P-256",
SECP384R1.name: "P-384",
SECP521R1.name: "P-521",
SECP256K1.name: "secp256k1",
}
REQUIRED_JSON_FIELDS = ["crv", "x", "y"]
PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS
PRIVATE_KEY_FIELDS = ["crv", "d", "x", "y"]
PUBLIC_KEY_CLS = EllipticCurvePublicKey
PRIVATE_KEY_CLS = EllipticCurvePrivateKeyWithSerialization
SSH_PUBLIC_PREFIX = b"ecdsa-sha2-"
def exchange_shared_key(self, pubkey):
# # used in ECDHESAlgorithm
private_key = self.get_private_key()
if private_key:
return private_key.exchange(ec.ECDH(), pubkey)
raise ValueError("Invalid key for exchanging shared key")
@property
def curve_key_size(self):
raw_key = self.get_private_key()
if not raw_key:
raw_key = self.public_key
return raw_key.curve.key_size
def load_private_key(self):
curve = self.DSS_CURVES[self._dict_data["crv"]]()
public_numbers = EllipticCurvePublicNumbers(
base64_to_int(self._dict_data["x"]),
base64_to_int(self._dict_data["y"]),
curve,
)
private_numbers = EllipticCurvePrivateNumbers(
base64_to_int(self.tokens["d"]), public_numbers
)
return private_numbers.private_key(default_backend())
def load_public_key(self):
curve = self.DSS_CURVES[self._dict_data["crv"]]()
public_numbers = EllipticCurvePublicNumbers(
base64_to_int(self._dict_data["x"]),
base64_to_int(self._dict_data["y"]),
curve,
)
return public_numbers.public_key(default_backend())
def dumps_private_key(self):
numbers = self.private_key.private_numbers()
return {
"crv": self.CURVES_DSS[self.private_key.curve.name],
"x": int_to_base64(numbers.public_numbers.x),
"y": int_to_base64(numbers.public_numbers.y),
"d": int_to_base64(numbers.private_value),
}
def dumps_public_key(self):
numbers = self.public_key.public_numbers()
return {
"crv": self.CURVES_DSS[numbers.curve.name],
"x": int_to_base64(numbers.x),
"y": int_to_base64(numbers.y),
}
@classmethod
def generate_key(cls, crv="P-256", options=None, is_private=False) -> "ECKey":
if crv not in cls.DSS_CURVES:
raise ValueError(f'Invalid crv value: "{crv}"')
raw_key = ec.generate_private_key(
curve=cls.DSS_CURVES[crv](),
backend=default_backend(),
)
if not is_private:
raw_key = raw_key.public_key()
return cls.import_key(raw_key, options=options)

View File

@@ -0,0 +1,350 @@
import os
import struct
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import GCM
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
from cryptography.hazmat.primitives.keywrap import aes_key_unwrap
from cryptography.hazmat.primitives.keywrap import aes_key_wrap
from authlib.common.encoding import to_bytes
from authlib.common.encoding import to_native
from authlib.common.encoding import urlsafe_b64decode
from authlib.common.encoding import urlsafe_b64encode
from authlib.jose.rfc7516 import JWEAlgorithm
from .ec_key import ECKey
from .oct_key import OctKey
from .rsa_key import RSAKey
class DirectAlgorithm(JWEAlgorithm):
name = "dir"
description = "Direct use of a shared symmetric key"
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
return {}
def wrap(self, enc_alg, headers, key, preset=None):
cek = key.get_op_key("encrypt")
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return {"ek": b"", "cek": cek}
def unwrap(self, enc_alg, ek, headers, key):
cek = key.get_op_key("decrypt")
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class RSAAlgorithm(JWEAlgorithm):
#: A key of size 2048 bits or larger MUST be used with these algorithms
#: RSA1_5, RSA-OAEP, RSA-OAEP-256
key_size = 2048
def __init__(self, name, description, pad_fn):
self.name = name
self.description = description
self.padding = pad_fn
def prepare_key(self, raw_data):
return RSAKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
cek = enc_alg.generate_cek()
return {"cek": cek}
def wrap(self, enc_alg, headers, key, preset=None):
if preset and "cek" in preset:
cek = preset["cek"]
else:
cek = enc_alg.generate_cek()
op_key = key.get_op_key("wrapKey")
if op_key.key_size < self.key_size:
raise ValueError("A key of size 2048 bits or larger MUST be used")
ek = op_key.encrypt(cek, self.padding)
return {"ek": ek, "cek": cek}
def unwrap(self, enc_alg, ek, headers, key):
# it will raise ValueError if failed
op_key = key.get_op_key("unwrapKey")
cek = op_key.decrypt(ek, self.padding)
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class AESAlgorithm(JWEAlgorithm):
def __init__(self, key_size):
self.name = f"A{key_size}KW"
self.description = f"AES Key Wrap using {key_size}-bit key"
self.key_size = key_size
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
cek = enc_alg.generate_cek()
return {"cek": cek}
def _check_key(self, key):
if len(key) * 8 != self.key_size:
raise ValueError(f"A key of size {self.key_size} bits is required.")
def wrap_cek(self, cek, key):
op_key = key.get_op_key("wrapKey")
self._check_key(op_key)
ek = aes_key_wrap(op_key, cek, default_backend())
return {"ek": ek, "cek": cek}
def wrap(self, enc_alg, headers, key, preset=None):
if preset and "cek" in preset:
cek = preset["cek"]
else:
cek = enc_alg.generate_cek()
return self.wrap_cek(cek, key)
def unwrap(self, enc_alg, ek, headers, key):
op_key = key.get_op_key("unwrapKey")
self._check_key(op_key)
cek = aes_key_unwrap(op_key, ek, default_backend())
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class AESGCMAlgorithm(JWEAlgorithm):
EXTRA_HEADERS = frozenset(["iv", "tag"])
def __init__(self, key_size):
self.name = f"A{key_size}GCMKW"
self.description = f"Key wrapping with AES GCM using {key_size}-bit key"
self.key_size = key_size
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
cek = enc_alg.generate_cek()
return {"cek": cek}
def _check_key(self, key):
if len(key) * 8 != self.key_size:
raise ValueError(f"A key of size {self.key_size} bits is required.")
def wrap(self, enc_alg, headers, key, preset=None):
if preset and "cek" in preset:
cek = preset["cek"]
else:
cek = enc_alg.generate_cek()
op_key = key.get_op_key("wrapKey")
self._check_key(op_key)
#: https://tools.ietf.org/html/rfc7518#section-4.7.1.1
#: The "iv" (initialization vector) Header Parameter value is the
#: base64url-encoded representation of the 96-bit IV value
iv_size = 96
iv = os.urandom(iv_size // 8)
cipher = Cipher(AES(op_key), GCM(iv), backend=default_backend())
enc = cipher.encryptor()
ek = enc.update(cek) + enc.finalize()
h = {
"iv": to_native(urlsafe_b64encode(iv)),
"tag": to_native(urlsafe_b64encode(enc.tag)),
}
return {"ek": ek, "cek": cek, "header": h}
def unwrap(self, enc_alg, ek, headers, key):
op_key = key.get_op_key("unwrapKey")
self._check_key(op_key)
iv = headers.get("iv")
if not iv:
raise ValueError('Missing "iv" in headers')
tag = headers.get("tag")
if not tag:
raise ValueError('Missing "tag" in headers')
iv = urlsafe_b64decode(to_bytes(iv))
tag = urlsafe_b64decode(to_bytes(tag))
cipher = Cipher(AES(op_key), GCM(iv, tag), backend=default_backend())
d = cipher.decryptor()
cek = d.update(ek) + d.finalize()
if len(cek) * 8 != enc_alg.CEK_SIZE:
raise ValueError('Invalid "cek" length')
return cek
class ECDHESAlgorithm(JWEAlgorithm):
EXTRA_HEADERS = ["epk", "apu", "apv"]
ALLOWED_KEY_CLS = ECKey
# https://tools.ietf.org/html/rfc7518#section-4.6
def __init__(self, key_size=None):
if key_size is None:
self.name = "ECDH-ES"
self.description = "ECDH-ES in the Direct Key Agreement mode"
else:
self.name = f"ECDH-ES+A{key_size}KW"
self.description = (
f"ECDH-ES using Concat KDF and CEK wrapped with A{key_size}KW"
)
self.key_size = key_size
self.aeskw = AESAlgorithm(key_size)
def prepare_key(self, raw_data):
if isinstance(raw_data, self.ALLOWED_KEY_CLS):
return raw_data
return ECKey.import_key(raw_data)
def generate_preset(self, enc_alg, key):
epk = self._generate_ephemeral_key(key)
h = self._prepare_headers(epk)
preset = {"epk": epk, "header": h}
if self.key_size is not None:
cek = enc_alg.generate_cek()
preset["cek"] = cek
return preset
def compute_fixed_info(self, headers, bit_size):
# AlgorithmID
if self.key_size is None:
alg_id = u32be_len_input(headers["enc"])
else:
alg_id = u32be_len_input(headers["alg"])
# PartyUInfo
apu_info = u32be_len_input(headers.get("apu"), True)
# PartyVInfo
apv_info = u32be_len_input(headers.get("apv"), True)
# SuppPubInfo
pub_info = struct.pack(">I", bit_size)
return alg_id + apu_info + apv_info + pub_info
def compute_derived_key(self, shared_key, fixed_info, bit_size):
ckdf = ConcatKDFHash(
algorithm=hashes.SHA256(),
length=bit_size // 8,
otherinfo=fixed_info,
backend=default_backend(),
)
return ckdf.derive(shared_key)
def deliver(self, key, pubkey, headers, bit_size):
shared_key = key.exchange_shared_key(pubkey)
fixed_info = self.compute_fixed_info(headers, bit_size)
return self.compute_derived_key(shared_key, fixed_info, bit_size)
def _generate_ephemeral_key(self, key):
return key.generate_key(key["crv"], is_private=True)
def _prepare_headers(self, epk):
# REQUIRED_JSON_FIELDS contains only public fields
pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS}
pub_epk["kty"] = epk.kty
return {"epk": pub_epk}
def wrap(self, enc_alg, headers, key, preset=None):
if self.key_size is None:
bit_size = enc_alg.CEK_SIZE
else:
bit_size = self.key_size
if preset and "epk" in preset:
epk = preset["epk"]
h = {}
else:
epk = self._generate_ephemeral_key(key)
h = self._prepare_headers(epk)
public_key = key.get_op_key("wrapKey")
dk = self.deliver(epk, public_key, headers, bit_size)
if self.key_size is None:
return {"ek": b"", "cek": dk, "header": h}
if preset and "cek" in preset:
preset_for_kw = {"cek": preset["cek"]}
else:
preset_for_kw = None
kek = self.aeskw.prepare_key(dk)
rv = self.aeskw.wrap(enc_alg, headers, kek, preset_for_kw)
rv["header"] = h
return rv
def unwrap(self, enc_alg, ek, headers, key):
if "epk" not in headers:
raise ValueError('Missing "epk" in headers')
if self.key_size is None:
bit_size = enc_alg.CEK_SIZE
else:
bit_size = self.key_size
epk = key.import_key(headers["epk"])
public_key = epk.get_op_key("wrapKey")
dk = self.deliver(key, public_key, headers, bit_size)
if self.key_size is None:
return dk
kek = self.aeskw.prepare_key(dk)
return self.aeskw.unwrap(enc_alg, ek, headers, kek)
def u32be_len_input(s, base64=False):
if not s:
return b"\x00\x00\x00\x00"
if base64:
s = urlsafe_b64decode(to_bytes(s))
else:
s = to_bytes(s)
return struct.pack(">I", len(s)) + s
JWE_ALG_ALGORITHMS = [
DirectAlgorithm(), # dir
RSAAlgorithm("RSA1_5", "RSAES-PKCS1-v1_5", padding.PKCS1v15()),
RSAAlgorithm(
"RSA-OAEP",
"RSAES OAEP using default parameters",
padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None),
),
RSAAlgorithm(
"RSA-OAEP-256",
"RSAES OAEP using SHA-256 and MGF1 with SHA-256",
padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None),
),
AESAlgorithm(128), # A128KW
AESAlgorithm(192), # A192KW
AESAlgorithm(256), # A256KW
AESGCMAlgorithm(128), # A128GCMKW
AESGCMAlgorithm(192), # A192GCMKW
AESGCMAlgorithm(256), # A256GCMKW
ECDHESAlgorithm(None), # ECDH-ES
ECDHESAlgorithm(128), # ECDH-ES+A128KW
ECDHESAlgorithm(192), # ECDH-ES+A192KW
ECDHESAlgorithm(256), # ECDH-ES+A256KW
]
# 'PBES2-HS256+A128KW': '',
# 'PBES2-HS384+A192KW': '',
# 'PBES2-HS512+A256KW': '',

View File

@@ -0,0 +1,147 @@
"""authlib.jose.rfc7518.
~~~~~~~~~~~~~~~~~~~~
Cryptographic Algorithms for Cryptographic Algorithms for Content
Encryption per `Section 5`_.
.. _`Section 5`: https://tools.ietf.org/html/rfc7518#section-5
"""
import hashlib
import hmac
from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import CBC
from cryptography.hazmat.primitives.ciphers.modes import GCM
from cryptography.hazmat.primitives.padding import PKCS7
from ..rfc7516 import JWEEncAlgorithm
from .util import encode_int
class CBCHS2EncAlgorithm(JWEEncAlgorithm):
# The IV used is a 128-bit value generated randomly or
# pseudo-randomly for use in the cipher.
IV_SIZE = 128
def __init__(self, key_size, hash_type):
self.name = f"A{key_size}CBC-HS{hash_type}"
tpl = "AES_{}_CBC_HMAC_SHA_{} authenticated encryption algorithm"
self.description = tpl.format(key_size, hash_type)
# bit length
self.key_size = key_size
# byte length
self.key_len = key_size // 8
self.CEK_SIZE = key_size * 2
self.hash_alg = getattr(hashlib, f"sha{hash_type}")
def _hmac(self, ciphertext, aad, iv, key):
al = encode_int(len(aad) * 8, 64)
msg = aad + iv + ciphertext + al
d = hmac.new(key, msg, self.hash_alg).digest()
return d[: self.key_len]
def encrypt(self, msg, aad, iv, key):
"""Key Encryption with AES_CBC_HMAC_SHA2.
:param msg: text to be encrypt in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param key: encrypted key in bytes
:return: (ciphertext, iv, tag)
"""
self.check_iv(iv)
hkey = key[: self.key_len]
ekey = key[self.key_len :]
pad = PKCS7(AES.block_size).padder()
padded_data = pad.update(msg) + pad.finalize()
cipher = Cipher(AES(ekey), CBC(iv), backend=default_backend())
enc = cipher.encryptor()
ciphertext = enc.update(padded_data) + enc.finalize()
tag = self._hmac(ciphertext, aad, iv, hkey)
return ciphertext, tag
def decrypt(self, ciphertext, aad, iv, tag, key):
"""Key Decryption with AES AES_CBC_HMAC_SHA2.
:param ciphertext: ciphertext in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param tag: authentication tag in bytes
:param key: encrypted key in bytes
:return: message
"""
self.check_iv(iv)
hkey = key[: self.key_len]
dkey = key[self.key_len :]
_tag = self._hmac(ciphertext, aad, iv, hkey)
if not hmac.compare_digest(_tag, tag):
raise InvalidTag()
cipher = Cipher(AES(dkey), CBC(iv), backend=default_backend())
d = cipher.decryptor()
data = d.update(ciphertext) + d.finalize()
unpad = PKCS7(AES.block_size).unpadder()
return unpad.update(data) + unpad.finalize()
class GCMEncAlgorithm(JWEEncAlgorithm):
# Use of an IV of size 96 bits is REQUIRED with this algorithm.
# https://tools.ietf.org/html/rfc7518#section-5.3
IV_SIZE = 96
def __init__(self, key_size):
self.name = f"A{key_size}GCM"
self.description = f"AES GCM using {key_size}-bit key"
self.key_size = key_size
self.CEK_SIZE = key_size
def encrypt(self, msg, aad, iv, key):
"""Key Encryption with AES GCM.
:param msg: text to be encrypt in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param key: encrypted key in bytes
:return: (ciphertext, iv, tag)
"""
self.check_iv(iv)
cipher = Cipher(AES(key), GCM(iv), backend=default_backend())
enc = cipher.encryptor()
enc.authenticate_additional_data(aad)
ciphertext = enc.update(msg) + enc.finalize()
return ciphertext, enc.tag
def decrypt(self, ciphertext, aad, iv, tag, key):
"""Key Decryption with AES GCM.
:param ciphertext: ciphertext in bytes
:param aad: additional authenticated data in bytes
:param iv: initialization vector in bytes
:param tag: authentication tag in bytes
:param key: encrypted key in bytes
:return: message
"""
self.check_iv(iv)
cipher = Cipher(AES(key), GCM(iv, tag), backend=default_backend())
d = cipher.decryptor()
d.authenticate_additional_data(aad)
return d.update(ciphertext) + d.finalize()
JWE_ENC_ALGORITHMS = [
CBCHS2EncAlgorithm(128, 256), # A128CBC-HS256
CBCHS2EncAlgorithm(192, 384), # A192CBC-HS384
CBCHS2EncAlgorithm(256, 512), # A256CBC-HS512
GCMEncAlgorithm(128), # A128GCM
GCMEncAlgorithm(192), # A192GCM
GCMEncAlgorithm(256), # A256GCM
]

View File

@@ -0,0 +1,34 @@
import zlib
from ..rfc7516 import JsonWebEncryption
from ..rfc7516 import JWEZipAlgorithm
GZIP_HEAD = bytes([120, 156])
MAX_SIZE = 250 * 1024
class DeflateZipAlgorithm(JWEZipAlgorithm):
name = "DEF"
description = "DEFLATE"
def compress(self, s: bytes) -> bytes:
"""Compress bytes data with DEFLATE algorithm."""
data = zlib.compress(s)
# https://datatracker.ietf.org/doc/html/rfc1951
# since DEF is always gzip, we can drop gzip headers and tail
return data[2:-4]
def decompress(self, s: bytes) -> bytes:
"""Decompress DEFLATE bytes data."""
if s.startswith(GZIP_HEAD):
decompressor = zlib.decompressobj()
else:
decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
value = decompressor.decompress(s, MAX_SIZE)
if decompressor.unconsumed_tail:
raise ValueError(f"Decompressed string exceeds {MAX_SIZE} bytes")
return value
def register_jwe_rfc7518():
JsonWebEncryption.register_algorithm(DeflateZipAlgorithm())

View File

@@ -0,0 +1,221 @@
"""authlib.jose.rfc7518.
~~~~~~~~~~~~~~~~~~~~
"alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_.
.. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3
"""
import hashlib
import hmac
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature
from ..rfc7515 import JWSAlgorithm
from .ec_key import ECKey
from .oct_key import OctKey
from .rsa_key import RSAKey
from .util import decode_int
from .util import encode_int
class NoneAlgorithm(JWSAlgorithm):
name = "none"
description = "No digital signature or MAC performed"
def prepare_key(self, raw_data):
return None
def sign(self, msg, key):
return b""
def verify(self, msg, sig, key):
return sig == b""
class HMACAlgorithm(JWSAlgorithm):
"""HMAC using SHA algorithms for JWS. Available algorithms:
- HS256: HMAC using SHA-256
- HS384: HMAC using SHA-384
- HS512: HMAC using SHA-512
"""
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
def __init__(self, sha_type):
self.name = f"HS{sha_type}"
self.description = f"HMAC using SHA-{sha_type}"
self.hash_alg = getattr(self, f"SHA{sha_type}")
def prepare_key(self, raw_data):
return OctKey.import_key(raw_data)
def sign(self, msg, key):
# it is faster than the one in cryptography
op_key = key.get_op_key("sign")
return hmac.new(op_key, msg, self.hash_alg).digest()
def verify(self, msg, sig, key):
op_key = key.get_op_key("verify")
v_sig = hmac.new(op_key, msg, self.hash_alg).digest()
return hmac.compare_digest(sig, v_sig)
class RSAAlgorithm(JWSAlgorithm):
"""RSA using SHA algorithms for JWS. Available algorithms:
- RS256: RSASSA-PKCS1-v1_5 using SHA-256
- RS384: RSASSA-PKCS1-v1_5 using SHA-384
- RS512: RSASSA-PKCS1-v1_5 using SHA-512
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
def __init__(self, sha_type):
self.name = f"RS{sha_type}"
self.description = f"RSASSA-PKCS1-v1_5 using SHA-{sha_type}"
self.hash_alg = getattr(self, f"SHA{sha_type}")
self.padding = padding.PKCS1v15()
def prepare_key(self, raw_data):
return RSAKey.import_key(raw_data)
def sign(self, msg, key):
op_key = key.get_op_key("sign")
return op_key.sign(msg, self.padding, self.hash_alg())
def verify(self, msg, sig, key):
op_key = key.get_op_key("verify")
try:
op_key.verify(sig, msg, self.padding, self.hash_alg())
return True
except InvalidSignature:
return False
class ECAlgorithm(JWSAlgorithm):
"""ECDSA using SHA algorithms for JWS. Available algorithms:
- ES256: ECDSA using P-256 and SHA-256
- ES384: ECDSA using P-384 and SHA-384
- ES512: ECDSA using P-521 and SHA-512
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
def __init__(self, name, curve, sha_type):
self.name = name
self.curve = curve
self.description = f"ECDSA using {self.curve} and SHA-{sha_type}"
self.hash_alg = getattr(self, f"SHA{sha_type}")
def prepare_key(self, raw_data):
key = ECKey.import_key(raw_data)
if key["crv"] != self.curve:
raise ValueError(
f'Key for "{self.name}" not supported, only "{self.curve}" allowed'
)
return key
def sign(self, msg, key):
op_key = key.get_op_key("sign")
der_sig = op_key.sign(msg, ECDSA(self.hash_alg()))
r, s = decode_dss_signature(der_sig)
size = key.curve_key_size
return encode_int(r, size) + encode_int(s, size)
def verify(self, msg, sig, key):
key_size = key.curve_key_size
length = (key_size + 7) // 8
if len(sig) != 2 * length:
return False
r = decode_int(sig[:length])
s = decode_int(sig[length:])
der_sig = encode_dss_signature(r, s)
try:
op_key = key.get_op_key("verify")
op_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
class RSAPSSAlgorithm(JWSAlgorithm):
"""RSASSA-PSS using SHA algorithms for JWS. Available algorithms:
- PS256: RSASSA-PSS using SHA-256 and MGF1 with SHA-256
- PS384: RSASSA-PSS using SHA-384 and MGF1 with SHA-384
- PS512: RSASSA-PSS using SHA-512 and MGF1 with SHA-512
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
def __init__(self, sha_type):
self.name = f"PS{sha_type}"
tpl = "RSASSA-PSS using SHA-{} and MGF1 with SHA-{}"
self.description = tpl.format(sha_type, sha_type)
self.hash_alg = getattr(self, f"SHA{sha_type}")
def prepare_key(self, raw_data):
return RSAKey.import_key(raw_data)
def sign(self, msg, key):
op_key = key.get_op_key("sign")
return op_key.sign(
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()), salt_length=self.hash_alg.digest_size
),
self.hash_alg(),
)
def verify(self, msg, sig, key):
op_key = key.get_op_key("verify")
try:
op_key.verify(
sig,
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size,
),
self.hash_alg(),
)
return True
except InvalidSignature:
return False
JWS_ALGORITHMS = [
NoneAlgorithm(), # none
HMACAlgorithm(256), # HS256
HMACAlgorithm(384), # HS384
HMACAlgorithm(512), # HS512
RSAAlgorithm(256), # RS256
RSAAlgorithm(384), # RS384
RSAAlgorithm(512), # RS512
ECAlgorithm("ES256", "P-256", 256),
ECAlgorithm("ES384", "P-384", 384),
ECAlgorithm("ES512", "P-521", 512),
ECAlgorithm("ES256K", "secp256k1", 256), # defined in RFC8812
RSAPSSAlgorithm(256), # PS256
RSAPSSAlgorithm(384), # PS384
RSAPSSAlgorithm(512), # PS512
]

View File

@@ -0,0 +1,96 @@
import secrets
from authlib.common.encoding import to_bytes
from authlib.common.encoding import to_unicode
from authlib.common.encoding import urlsafe_b64decode
from authlib.common.encoding import urlsafe_b64encode
from ..rfc7517 import Key
POSSIBLE_UNSAFE_KEYS = (
b"-----BEGIN ",
b"---- BEGIN ",
b"ssh-rsa ",
b"ssh-dss ",
b"ssh-ed25519 ",
b"ecdsa-sha2-",
)
class OctKey(Key):
"""Key class of the ``oct`` key type."""
kty = "oct"
REQUIRED_JSON_FIELDS = ["k"]
def __init__(self, raw_key=None, options=None):
super().__init__(options)
self.raw_key = raw_key
@property
def public_only(self):
return False
def get_op_key(self, operation):
"""Get the raw key for the given key_op. This method will also
check if the given key_op is supported by this key.
:param operation: key operation value, such as "sign", "encrypt".
:return: raw key
"""
self.check_key_op(operation)
if not self.raw_key:
self.load_raw_key()
return self.raw_key
def load_raw_key(self):
self.raw_key = urlsafe_b64decode(to_bytes(self.tokens["k"]))
def load_dict_key(self):
k = to_unicode(urlsafe_b64encode(self.raw_key))
self._dict_data = {"kty": self.kty, "k": k}
def as_dict(self, is_private=False, **params):
tokens = self.tokens
if "kid" not in tokens:
tokens["kid"] = self.thumbprint()
tokens.update(params)
return tokens
@classmethod
def validate_raw_key(cls, key):
return isinstance(key, bytes)
@classmethod
def import_key(cls, raw, options=None):
"""Import a key from bytes, string, or dict data."""
if isinstance(raw, cls):
if options is not None:
raw.options.update(options)
return raw
if isinstance(raw, dict):
cls.check_required_fields(raw)
key = cls(options=options)
key._dict_data = raw
else:
raw_key = to_bytes(raw)
# security check
if raw_key.startswith(POSSIBLE_UNSAFE_KEYS):
raise ValueError("This key may not be safe to import")
key = cls(raw_key=raw_key, options=options)
return key
@classmethod
def generate_key(cls, key_size=256, options=None, is_private=True):
"""Generate a ``OctKey`` with the given bit size."""
if not is_private:
raise ValueError("oct key can not be generated as public")
if key_size % 8 != 0:
raise ValueError("Invalid bit size for oct key")
return cls.import_key(secrets.token_bytes(int(key_size / 8)), options)

View File

@@ -0,0 +1,127 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKeyWithSerialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmp1
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmq1
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_iqmp
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_recover_prime_factors
from authlib.common.encoding import base64_to_int
from authlib.common.encoding import int_to_base64
from ..rfc7517 import AsymmetricKey
class RSAKey(AsymmetricKey):
"""Key class of the ``RSA`` key type."""
kty = "RSA"
PUBLIC_KEY_CLS = RSAPublicKey
PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization
PUBLIC_KEY_FIELDS = ["e", "n"]
PRIVATE_KEY_FIELDS = ["d", "dp", "dq", "e", "n", "p", "q", "qi"]
REQUIRED_JSON_FIELDS = ["e", "n"]
SSH_PUBLIC_PREFIX = b"ssh-rsa"
def dumps_private_key(self):
numbers = self.private_key.private_numbers()
return {
"n": int_to_base64(numbers.public_numbers.n),
"e": int_to_base64(numbers.public_numbers.e),
"d": int_to_base64(numbers.d),
"p": int_to_base64(numbers.p),
"q": int_to_base64(numbers.q),
"dp": int_to_base64(numbers.dmp1),
"dq": int_to_base64(numbers.dmq1),
"qi": int_to_base64(numbers.iqmp),
}
def dumps_public_key(self):
numbers = self.public_key.public_numbers()
return {"n": int_to_base64(numbers.n), "e": int_to_base64(numbers.e)}
def load_private_key(self):
obj = self._dict_data
if "oth" in obj: # pragma: no cover
# https://tools.ietf.org/html/rfc7518#section-6.3.2.7
raise ValueError('"oth" is not supported yet')
public_numbers = RSAPublicNumbers(
base64_to_int(obj["e"]), base64_to_int(obj["n"])
)
if has_all_prime_factors(obj):
numbers = RSAPrivateNumbers(
d=base64_to_int(obj["d"]),
p=base64_to_int(obj["p"]),
q=base64_to_int(obj["q"]),
dmp1=base64_to_int(obj["dp"]),
dmq1=base64_to_int(obj["dq"]),
iqmp=base64_to_int(obj["qi"]),
public_numbers=public_numbers,
)
else:
d = base64_to_int(obj["d"])
p, q = rsa_recover_prime_factors(public_numbers.n, d, public_numbers.e)
numbers = RSAPrivateNumbers(
d=d,
p=p,
q=q,
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
public_numbers=public_numbers,
)
return numbers.private_key(default_backend())
def load_public_key(self):
numbers = RSAPublicNumbers(
base64_to_int(self._dict_data["e"]), base64_to_int(self._dict_data["n"])
)
return numbers.public_key(default_backend())
@classmethod
def generate_key(cls, key_size=2048, options=None, is_private=False) -> "RSAKey":
if key_size < 512:
raise ValueError("key_size must not be less than 512")
if key_size % 8 != 0:
raise ValueError("Invalid key_size for RSAKey")
raw_key = rsa.generate_private_key(
public_exponent=65537,
key_size=key_size,
backend=default_backend(),
)
if not is_private:
raw_key = raw_key.public_key()
return cls.import_key(raw_key, options=options)
@classmethod
def import_dict_key(cls, raw, options=None):
cls.check_required_fields(raw)
key = cls(options=options)
key._dict_data = raw
if "d" in raw and not has_all_prime_factors(raw):
# reload dict key
key.load_raw_key()
key.load_dict_key()
return key
def has_all_prime_factors(obj):
props = ["p", "q", "dp", "dq", "qi"]
props_found = [prop in obj for prop in props]
if all(props_found):
return True
if any(props_found):
raise ValueError(
"RSA key must include all parameters if any are present besides d"
)
return False

View File

@@ -0,0 +1,12 @@
import binascii
def encode_int(num, bits):
length = ((bits + 7) // 8) * 2
padded_hex = f"{num:0{length}x}"
big_endian = binascii.a2b_hex(padded_hex.encode("ascii"))
return big_endian
def decode_int(b):
return int(binascii.b2a_hex(b), 16)