updates
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
from .ec_key import ECKey
|
||||
from .jwe_algs import JWE_ALG_ALGORITHMS
|
||||
from .jwe_algs import AESAlgorithm
|
||||
from .jwe_algs import ECDHESAlgorithm
|
||||
from .jwe_algs import u32be_len_input
|
||||
from .jwe_encs import JWE_ENC_ALGORITHMS
|
||||
from .jwe_encs import CBCHS2EncAlgorithm
|
||||
from .jwe_zips import DeflateZipAlgorithm
|
||||
from .jws_algs import JWS_ALGORITHMS
|
||||
from .oct_key import OctKey
|
||||
from .rsa_key import RSAKey
|
||||
|
||||
|
||||
def register_jws_rfc7518(cls):
|
||||
for algorithm in JWS_ALGORITHMS:
|
||||
cls.register_algorithm(algorithm)
|
||||
|
||||
|
||||
def register_jwe_rfc7518(cls):
|
||||
for algorithm in JWE_ALG_ALGORITHMS:
|
||||
cls.register_algorithm(algorithm)
|
||||
|
||||
for algorithm in JWE_ENC_ALGORITHMS:
|
||||
cls.register_algorithm(algorithm)
|
||||
|
||||
cls.register_algorithm(DeflateZipAlgorithm())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"register_jws_rfc7518",
|
||||
"register_jwe_rfc7518",
|
||||
"OctKey",
|
||||
"RSAKey",
|
||||
"ECKey",
|
||||
"u32be_len_input",
|
||||
"AESAlgorithm",
|
||||
"ECDHESAlgorithm",
|
||||
"CBCHS2EncAlgorithm",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,108 @@
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import SECP521R1
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||
EllipticCurvePrivateKeyWithSerialization,
|
||||
)
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateNumbers
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
|
||||
|
||||
from authlib.common.encoding import base64_to_int
|
||||
from authlib.common.encoding import int_to_base64
|
||||
|
||||
from ..rfc7517 import AsymmetricKey
|
||||
|
||||
|
||||
class ECKey(AsymmetricKey):
|
||||
"""Key class of the ``EC`` key type."""
|
||||
|
||||
kty = "EC"
|
||||
DSS_CURVES = {
|
||||
"P-256": SECP256R1,
|
||||
"P-384": SECP384R1,
|
||||
"P-521": SECP521R1,
|
||||
# https://tools.ietf.org/html/rfc8812#section-3.1
|
||||
"secp256k1": SECP256K1,
|
||||
}
|
||||
CURVES_DSS = {
|
||||
SECP256R1.name: "P-256",
|
||||
SECP384R1.name: "P-384",
|
||||
SECP521R1.name: "P-521",
|
||||
SECP256K1.name: "secp256k1",
|
||||
}
|
||||
REQUIRED_JSON_FIELDS = ["crv", "x", "y"]
|
||||
|
||||
PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS
|
||||
PRIVATE_KEY_FIELDS = ["crv", "d", "x", "y"]
|
||||
|
||||
PUBLIC_KEY_CLS = EllipticCurvePublicKey
|
||||
PRIVATE_KEY_CLS = EllipticCurvePrivateKeyWithSerialization
|
||||
SSH_PUBLIC_PREFIX = b"ecdsa-sha2-"
|
||||
|
||||
def exchange_shared_key(self, pubkey):
|
||||
# # used in ECDHESAlgorithm
|
||||
private_key = self.get_private_key()
|
||||
if private_key:
|
||||
return private_key.exchange(ec.ECDH(), pubkey)
|
||||
raise ValueError("Invalid key for exchanging shared key")
|
||||
|
||||
@property
|
||||
def curve_key_size(self):
|
||||
raw_key = self.get_private_key()
|
||||
if not raw_key:
|
||||
raw_key = self.public_key
|
||||
return raw_key.curve.key_size
|
||||
|
||||
def load_private_key(self):
|
||||
curve = self.DSS_CURVES[self._dict_data["crv"]]()
|
||||
public_numbers = EllipticCurvePublicNumbers(
|
||||
base64_to_int(self._dict_data["x"]),
|
||||
base64_to_int(self._dict_data["y"]),
|
||||
curve,
|
||||
)
|
||||
private_numbers = EllipticCurvePrivateNumbers(
|
||||
base64_to_int(self.tokens["d"]), public_numbers
|
||||
)
|
||||
return private_numbers.private_key(default_backend())
|
||||
|
||||
def load_public_key(self):
|
||||
curve = self.DSS_CURVES[self._dict_data["crv"]]()
|
||||
public_numbers = EllipticCurvePublicNumbers(
|
||||
base64_to_int(self._dict_data["x"]),
|
||||
base64_to_int(self._dict_data["y"]),
|
||||
curve,
|
||||
)
|
||||
return public_numbers.public_key(default_backend())
|
||||
|
||||
def dumps_private_key(self):
|
||||
numbers = self.private_key.private_numbers()
|
||||
return {
|
||||
"crv": self.CURVES_DSS[self.private_key.curve.name],
|
||||
"x": int_to_base64(numbers.public_numbers.x),
|
||||
"y": int_to_base64(numbers.public_numbers.y),
|
||||
"d": int_to_base64(numbers.private_value),
|
||||
}
|
||||
|
||||
def dumps_public_key(self):
|
||||
numbers = self.public_key.public_numbers()
|
||||
return {
|
||||
"crv": self.CURVES_DSS[numbers.curve.name],
|
||||
"x": int_to_base64(numbers.x),
|
||||
"y": int_to_base64(numbers.y),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def generate_key(cls, crv="P-256", options=None, is_private=False) -> "ECKey":
|
||||
if crv not in cls.DSS_CURVES:
|
||||
raise ValueError(f'Invalid crv value: "{crv}"')
|
||||
raw_key = ec.generate_private_key(
|
||||
curve=cls.DSS_CURVES[crv](),
|
||||
backend=default_backend(),
|
||||
)
|
||||
if not is_private:
|
||||
raw_key = raw_key.public_key()
|
||||
return cls.import_key(raw_key, options=options)
|
||||
@@ -0,0 +1,350 @@
|
||||
import os
|
||||
import struct
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher
|
||||
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
||||
from cryptography.hazmat.primitives.ciphers.modes import GCM
|
||||
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
|
||||
from cryptography.hazmat.primitives.keywrap import aes_key_unwrap
|
||||
from cryptography.hazmat.primitives.keywrap import aes_key_wrap
|
||||
|
||||
from authlib.common.encoding import to_bytes
|
||||
from authlib.common.encoding import to_native
|
||||
from authlib.common.encoding import urlsafe_b64decode
|
||||
from authlib.common.encoding import urlsafe_b64encode
|
||||
from authlib.jose.rfc7516 import JWEAlgorithm
|
||||
|
||||
from .ec_key import ECKey
|
||||
from .oct_key import OctKey
|
||||
from .rsa_key import RSAKey
|
||||
|
||||
|
||||
class DirectAlgorithm(JWEAlgorithm):
|
||||
name = "dir"
|
||||
description = "Direct use of a shared symmetric key"
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return OctKey.import_key(raw_data)
|
||||
|
||||
def generate_preset(self, enc_alg, key):
|
||||
return {}
|
||||
|
||||
def wrap(self, enc_alg, headers, key, preset=None):
|
||||
cek = key.get_op_key("encrypt")
|
||||
if len(cek) * 8 != enc_alg.CEK_SIZE:
|
||||
raise ValueError('Invalid "cek" length')
|
||||
return {"ek": b"", "cek": cek}
|
||||
|
||||
def unwrap(self, enc_alg, ek, headers, key):
|
||||
cek = key.get_op_key("decrypt")
|
||||
if len(cek) * 8 != enc_alg.CEK_SIZE:
|
||||
raise ValueError('Invalid "cek" length')
|
||||
return cek
|
||||
|
||||
|
||||
class RSAAlgorithm(JWEAlgorithm):
|
||||
#: A key of size 2048 bits or larger MUST be used with these algorithms
|
||||
#: RSA1_5, RSA-OAEP, RSA-OAEP-256
|
||||
key_size = 2048
|
||||
|
||||
def __init__(self, name, description, pad_fn):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.padding = pad_fn
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return RSAKey.import_key(raw_data)
|
||||
|
||||
def generate_preset(self, enc_alg, key):
|
||||
cek = enc_alg.generate_cek()
|
||||
return {"cek": cek}
|
||||
|
||||
def wrap(self, enc_alg, headers, key, preset=None):
|
||||
if preset and "cek" in preset:
|
||||
cek = preset["cek"]
|
||||
else:
|
||||
cek = enc_alg.generate_cek()
|
||||
|
||||
op_key = key.get_op_key("wrapKey")
|
||||
if op_key.key_size < self.key_size:
|
||||
raise ValueError("A key of size 2048 bits or larger MUST be used")
|
||||
ek = op_key.encrypt(cek, self.padding)
|
||||
return {"ek": ek, "cek": cek}
|
||||
|
||||
def unwrap(self, enc_alg, ek, headers, key):
|
||||
# it will raise ValueError if failed
|
||||
op_key = key.get_op_key("unwrapKey")
|
||||
cek = op_key.decrypt(ek, self.padding)
|
||||
if len(cek) * 8 != enc_alg.CEK_SIZE:
|
||||
raise ValueError('Invalid "cek" length')
|
||||
return cek
|
||||
|
||||
|
||||
class AESAlgorithm(JWEAlgorithm):
|
||||
def __init__(self, key_size):
|
||||
self.name = f"A{key_size}KW"
|
||||
self.description = f"AES Key Wrap using {key_size}-bit key"
|
||||
self.key_size = key_size
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return OctKey.import_key(raw_data)
|
||||
|
||||
def generate_preset(self, enc_alg, key):
|
||||
cek = enc_alg.generate_cek()
|
||||
return {"cek": cek}
|
||||
|
||||
def _check_key(self, key):
|
||||
if len(key) * 8 != self.key_size:
|
||||
raise ValueError(f"A key of size {self.key_size} bits is required.")
|
||||
|
||||
def wrap_cek(self, cek, key):
|
||||
op_key = key.get_op_key("wrapKey")
|
||||
self._check_key(op_key)
|
||||
ek = aes_key_wrap(op_key, cek, default_backend())
|
||||
return {"ek": ek, "cek": cek}
|
||||
|
||||
def wrap(self, enc_alg, headers, key, preset=None):
|
||||
if preset and "cek" in preset:
|
||||
cek = preset["cek"]
|
||||
else:
|
||||
cek = enc_alg.generate_cek()
|
||||
return self.wrap_cek(cek, key)
|
||||
|
||||
def unwrap(self, enc_alg, ek, headers, key):
|
||||
op_key = key.get_op_key("unwrapKey")
|
||||
self._check_key(op_key)
|
||||
cek = aes_key_unwrap(op_key, ek, default_backend())
|
||||
if len(cek) * 8 != enc_alg.CEK_SIZE:
|
||||
raise ValueError('Invalid "cek" length')
|
||||
return cek
|
||||
|
||||
|
||||
class AESGCMAlgorithm(JWEAlgorithm):
|
||||
EXTRA_HEADERS = frozenset(["iv", "tag"])
|
||||
|
||||
def __init__(self, key_size):
|
||||
self.name = f"A{key_size}GCMKW"
|
||||
self.description = f"Key wrapping with AES GCM using {key_size}-bit key"
|
||||
self.key_size = key_size
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return OctKey.import_key(raw_data)
|
||||
|
||||
def generate_preset(self, enc_alg, key):
|
||||
cek = enc_alg.generate_cek()
|
||||
return {"cek": cek}
|
||||
|
||||
def _check_key(self, key):
|
||||
if len(key) * 8 != self.key_size:
|
||||
raise ValueError(f"A key of size {self.key_size} bits is required.")
|
||||
|
||||
def wrap(self, enc_alg, headers, key, preset=None):
|
||||
if preset and "cek" in preset:
|
||||
cek = preset["cek"]
|
||||
else:
|
||||
cek = enc_alg.generate_cek()
|
||||
|
||||
op_key = key.get_op_key("wrapKey")
|
||||
self._check_key(op_key)
|
||||
|
||||
#: https://tools.ietf.org/html/rfc7518#section-4.7.1.1
|
||||
#: The "iv" (initialization vector) Header Parameter value is the
|
||||
#: base64url-encoded representation of the 96-bit IV value
|
||||
iv_size = 96
|
||||
iv = os.urandom(iv_size // 8)
|
||||
|
||||
cipher = Cipher(AES(op_key), GCM(iv), backend=default_backend())
|
||||
enc = cipher.encryptor()
|
||||
ek = enc.update(cek) + enc.finalize()
|
||||
|
||||
h = {
|
||||
"iv": to_native(urlsafe_b64encode(iv)),
|
||||
"tag": to_native(urlsafe_b64encode(enc.tag)),
|
||||
}
|
||||
return {"ek": ek, "cek": cek, "header": h}
|
||||
|
||||
def unwrap(self, enc_alg, ek, headers, key):
|
||||
op_key = key.get_op_key("unwrapKey")
|
||||
self._check_key(op_key)
|
||||
|
||||
iv = headers.get("iv")
|
||||
if not iv:
|
||||
raise ValueError('Missing "iv" in headers')
|
||||
|
||||
tag = headers.get("tag")
|
||||
if not tag:
|
||||
raise ValueError('Missing "tag" in headers')
|
||||
|
||||
iv = urlsafe_b64decode(to_bytes(iv))
|
||||
tag = urlsafe_b64decode(to_bytes(tag))
|
||||
|
||||
cipher = Cipher(AES(op_key), GCM(iv, tag), backend=default_backend())
|
||||
d = cipher.decryptor()
|
||||
cek = d.update(ek) + d.finalize()
|
||||
if len(cek) * 8 != enc_alg.CEK_SIZE:
|
||||
raise ValueError('Invalid "cek" length')
|
||||
return cek
|
||||
|
||||
|
||||
class ECDHESAlgorithm(JWEAlgorithm):
|
||||
EXTRA_HEADERS = ["epk", "apu", "apv"]
|
||||
ALLOWED_KEY_CLS = ECKey
|
||||
|
||||
# https://tools.ietf.org/html/rfc7518#section-4.6
|
||||
def __init__(self, key_size=None):
|
||||
if key_size is None:
|
||||
self.name = "ECDH-ES"
|
||||
self.description = "ECDH-ES in the Direct Key Agreement mode"
|
||||
else:
|
||||
self.name = f"ECDH-ES+A{key_size}KW"
|
||||
self.description = (
|
||||
f"ECDH-ES using Concat KDF and CEK wrapped with A{key_size}KW"
|
||||
)
|
||||
self.key_size = key_size
|
||||
self.aeskw = AESAlgorithm(key_size)
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
if isinstance(raw_data, self.ALLOWED_KEY_CLS):
|
||||
return raw_data
|
||||
return ECKey.import_key(raw_data)
|
||||
|
||||
def generate_preset(self, enc_alg, key):
|
||||
epk = self._generate_ephemeral_key(key)
|
||||
h = self._prepare_headers(epk)
|
||||
preset = {"epk": epk, "header": h}
|
||||
if self.key_size is not None:
|
||||
cek = enc_alg.generate_cek()
|
||||
preset["cek"] = cek
|
||||
return preset
|
||||
|
||||
def compute_fixed_info(self, headers, bit_size):
|
||||
# AlgorithmID
|
||||
if self.key_size is None:
|
||||
alg_id = u32be_len_input(headers["enc"])
|
||||
else:
|
||||
alg_id = u32be_len_input(headers["alg"])
|
||||
|
||||
# PartyUInfo
|
||||
apu_info = u32be_len_input(headers.get("apu"), True)
|
||||
|
||||
# PartyVInfo
|
||||
apv_info = u32be_len_input(headers.get("apv"), True)
|
||||
|
||||
# SuppPubInfo
|
||||
pub_info = struct.pack(">I", bit_size)
|
||||
|
||||
return alg_id + apu_info + apv_info + pub_info
|
||||
|
||||
def compute_derived_key(self, shared_key, fixed_info, bit_size):
|
||||
ckdf = ConcatKDFHash(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=bit_size // 8,
|
||||
otherinfo=fixed_info,
|
||||
backend=default_backend(),
|
||||
)
|
||||
return ckdf.derive(shared_key)
|
||||
|
||||
def deliver(self, key, pubkey, headers, bit_size):
|
||||
shared_key = key.exchange_shared_key(pubkey)
|
||||
fixed_info = self.compute_fixed_info(headers, bit_size)
|
||||
return self.compute_derived_key(shared_key, fixed_info, bit_size)
|
||||
|
||||
def _generate_ephemeral_key(self, key):
|
||||
return key.generate_key(key["crv"], is_private=True)
|
||||
|
||||
def _prepare_headers(self, epk):
|
||||
# REQUIRED_JSON_FIELDS contains only public fields
|
||||
pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS}
|
||||
pub_epk["kty"] = epk.kty
|
||||
return {"epk": pub_epk}
|
||||
|
||||
def wrap(self, enc_alg, headers, key, preset=None):
|
||||
if self.key_size is None:
|
||||
bit_size = enc_alg.CEK_SIZE
|
||||
else:
|
||||
bit_size = self.key_size
|
||||
|
||||
if preset and "epk" in preset:
|
||||
epk = preset["epk"]
|
||||
h = {}
|
||||
else:
|
||||
epk = self._generate_ephemeral_key(key)
|
||||
h = self._prepare_headers(epk)
|
||||
|
||||
public_key = key.get_op_key("wrapKey")
|
||||
dk = self.deliver(epk, public_key, headers, bit_size)
|
||||
|
||||
if self.key_size is None:
|
||||
return {"ek": b"", "cek": dk, "header": h}
|
||||
|
||||
if preset and "cek" in preset:
|
||||
preset_for_kw = {"cek": preset["cek"]}
|
||||
else:
|
||||
preset_for_kw = None
|
||||
|
||||
kek = self.aeskw.prepare_key(dk)
|
||||
rv = self.aeskw.wrap(enc_alg, headers, kek, preset_for_kw)
|
||||
rv["header"] = h
|
||||
return rv
|
||||
|
||||
def unwrap(self, enc_alg, ek, headers, key):
|
||||
if "epk" not in headers:
|
||||
raise ValueError('Missing "epk" in headers')
|
||||
|
||||
if self.key_size is None:
|
||||
bit_size = enc_alg.CEK_SIZE
|
||||
else:
|
||||
bit_size = self.key_size
|
||||
|
||||
epk = key.import_key(headers["epk"])
|
||||
public_key = epk.get_op_key("wrapKey")
|
||||
dk = self.deliver(key, public_key, headers, bit_size)
|
||||
|
||||
if self.key_size is None:
|
||||
return dk
|
||||
|
||||
kek = self.aeskw.prepare_key(dk)
|
||||
return self.aeskw.unwrap(enc_alg, ek, headers, kek)
|
||||
|
||||
|
||||
def u32be_len_input(s, base64=False):
|
||||
if not s:
|
||||
return b"\x00\x00\x00\x00"
|
||||
if base64:
|
||||
s = urlsafe_b64decode(to_bytes(s))
|
||||
else:
|
||||
s = to_bytes(s)
|
||||
return struct.pack(">I", len(s)) + s
|
||||
|
||||
|
||||
JWE_ALG_ALGORITHMS = [
|
||||
DirectAlgorithm(), # dir
|
||||
RSAAlgorithm("RSA1_5", "RSAES-PKCS1-v1_5", padding.PKCS1v15()),
|
||||
RSAAlgorithm(
|
||||
"RSA-OAEP",
|
||||
"RSAES OAEP using default parameters",
|
||||
padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None),
|
||||
),
|
||||
RSAAlgorithm(
|
||||
"RSA-OAEP-256",
|
||||
"RSAES OAEP using SHA-256 and MGF1 with SHA-256",
|
||||
padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None),
|
||||
),
|
||||
AESAlgorithm(128), # A128KW
|
||||
AESAlgorithm(192), # A192KW
|
||||
AESAlgorithm(256), # A256KW
|
||||
AESGCMAlgorithm(128), # A128GCMKW
|
||||
AESGCMAlgorithm(192), # A192GCMKW
|
||||
AESGCMAlgorithm(256), # A256GCMKW
|
||||
ECDHESAlgorithm(None), # ECDH-ES
|
||||
ECDHESAlgorithm(128), # ECDH-ES+A128KW
|
||||
ECDHESAlgorithm(192), # ECDH-ES+A192KW
|
||||
ECDHESAlgorithm(256), # ECDH-ES+A256KW
|
||||
]
|
||||
|
||||
# 'PBES2-HS256+A128KW': '',
|
||||
# 'PBES2-HS384+A192KW': '',
|
||||
# 'PBES2-HS512+A256KW': '',
|
||||
@@ -0,0 +1,147 @@
|
||||
"""authlib.jose.rfc7518.
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Cryptographic Algorithms for Cryptographic Algorithms for Content
|
||||
Encryption per `Section 5`_.
|
||||
|
||||
.. _`Section 5`: https://tools.ietf.org/html/rfc7518#section-5
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
from cryptography.exceptions import InvalidTag
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher
|
||||
from cryptography.hazmat.primitives.ciphers.algorithms import AES
|
||||
from cryptography.hazmat.primitives.ciphers.modes import CBC
|
||||
from cryptography.hazmat.primitives.ciphers.modes import GCM
|
||||
from cryptography.hazmat.primitives.padding import PKCS7
|
||||
|
||||
from ..rfc7516 import JWEEncAlgorithm
|
||||
from .util import encode_int
|
||||
|
||||
|
||||
class CBCHS2EncAlgorithm(JWEEncAlgorithm):
|
||||
# The IV used is a 128-bit value generated randomly or
|
||||
# pseudo-randomly for use in the cipher.
|
||||
IV_SIZE = 128
|
||||
|
||||
def __init__(self, key_size, hash_type):
|
||||
self.name = f"A{key_size}CBC-HS{hash_type}"
|
||||
tpl = "AES_{}_CBC_HMAC_SHA_{} authenticated encryption algorithm"
|
||||
self.description = tpl.format(key_size, hash_type)
|
||||
|
||||
# bit length
|
||||
self.key_size = key_size
|
||||
# byte length
|
||||
self.key_len = key_size // 8
|
||||
|
||||
self.CEK_SIZE = key_size * 2
|
||||
self.hash_alg = getattr(hashlib, f"sha{hash_type}")
|
||||
|
||||
def _hmac(self, ciphertext, aad, iv, key):
|
||||
al = encode_int(len(aad) * 8, 64)
|
||||
msg = aad + iv + ciphertext + al
|
||||
d = hmac.new(key, msg, self.hash_alg).digest()
|
||||
return d[: self.key_len]
|
||||
|
||||
def encrypt(self, msg, aad, iv, key):
|
||||
"""Key Encryption with AES_CBC_HMAC_SHA2.
|
||||
|
||||
:param msg: text to be encrypt in bytes
|
||||
:param aad: additional authenticated data in bytes
|
||||
:param iv: initialization vector in bytes
|
||||
:param key: encrypted key in bytes
|
||||
:return: (ciphertext, iv, tag)
|
||||
"""
|
||||
self.check_iv(iv)
|
||||
hkey = key[: self.key_len]
|
||||
ekey = key[self.key_len :]
|
||||
|
||||
pad = PKCS7(AES.block_size).padder()
|
||||
padded_data = pad.update(msg) + pad.finalize()
|
||||
|
||||
cipher = Cipher(AES(ekey), CBC(iv), backend=default_backend())
|
||||
enc = cipher.encryptor()
|
||||
ciphertext = enc.update(padded_data) + enc.finalize()
|
||||
tag = self._hmac(ciphertext, aad, iv, hkey)
|
||||
return ciphertext, tag
|
||||
|
||||
def decrypt(self, ciphertext, aad, iv, tag, key):
|
||||
"""Key Decryption with AES AES_CBC_HMAC_SHA2.
|
||||
|
||||
:param ciphertext: ciphertext in bytes
|
||||
:param aad: additional authenticated data in bytes
|
||||
:param iv: initialization vector in bytes
|
||||
:param tag: authentication tag in bytes
|
||||
:param key: encrypted key in bytes
|
||||
:return: message
|
||||
"""
|
||||
self.check_iv(iv)
|
||||
hkey = key[: self.key_len]
|
||||
dkey = key[self.key_len :]
|
||||
|
||||
_tag = self._hmac(ciphertext, aad, iv, hkey)
|
||||
if not hmac.compare_digest(_tag, tag):
|
||||
raise InvalidTag()
|
||||
|
||||
cipher = Cipher(AES(dkey), CBC(iv), backend=default_backend())
|
||||
d = cipher.decryptor()
|
||||
data = d.update(ciphertext) + d.finalize()
|
||||
unpad = PKCS7(AES.block_size).unpadder()
|
||||
return unpad.update(data) + unpad.finalize()
|
||||
|
||||
|
||||
class GCMEncAlgorithm(JWEEncAlgorithm):
|
||||
# Use of an IV of size 96 bits is REQUIRED with this algorithm.
|
||||
# https://tools.ietf.org/html/rfc7518#section-5.3
|
||||
IV_SIZE = 96
|
||||
|
||||
def __init__(self, key_size):
|
||||
self.name = f"A{key_size}GCM"
|
||||
self.description = f"AES GCM using {key_size}-bit key"
|
||||
self.key_size = key_size
|
||||
self.CEK_SIZE = key_size
|
||||
|
||||
def encrypt(self, msg, aad, iv, key):
|
||||
"""Key Encryption with AES GCM.
|
||||
|
||||
:param msg: text to be encrypt in bytes
|
||||
:param aad: additional authenticated data in bytes
|
||||
:param iv: initialization vector in bytes
|
||||
:param key: encrypted key in bytes
|
||||
:return: (ciphertext, iv, tag)
|
||||
"""
|
||||
self.check_iv(iv)
|
||||
cipher = Cipher(AES(key), GCM(iv), backend=default_backend())
|
||||
enc = cipher.encryptor()
|
||||
enc.authenticate_additional_data(aad)
|
||||
ciphertext = enc.update(msg) + enc.finalize()
|
||||
return ciphertext, enc.tag
|
||||
|
||||
def decrypt(self, ciphertext, aad, iv, tag, key):
|
||||
"""Key Decryption with AES GCM.
|
||||
|
||||
:param ciphertext: ciphertext in bytes
|
||||
:param aad: additional authenticated data in bytes
|
||||
:param iv: initialization vector in bytes
|
||||
:param tag: authentication tag in bytes
|
||||
:param key: encrypted key in bytes
|
||||
:return: message
|
||||
"""
|
||||
self.check_iv(iv)
|
||||
cipher = Cipher(AES(key), GCM(iv, tag), backend=default_backend())
|
||||
d = cipher.decryptor()
|
||||
d.authenticate_additional_data(aad)
|
||||
return d.update(ciphertext) + d.finalize()
|
||||
|
||||
|
||||
JWE_ENC_ALGORITHMS = [
|
||||
CBCHS2EncAlgorithm(128, 256), # A128CBC-HS256
|
||||
CBCHS2EncAlgorithm(192, 384), # A192CBC-HS384
|
||||
CBCHS2EncAlgorithm(256, 512), # A256CBC-HS512
|
||||
GCMEncAlgorithm(128), # A128GCM
|
||||
GCMEncAlgorithm(192), # A192GCM
|
||||
GCMEncAlgorithm(256), # A256GCM
|
||||
]
|
||||
@@ -0,0 +1,34 @@
|
||||
import zlib
|
||||
|
||||
from ..rfc7516 import JsonWebEncryption
|
||||
from ..rfc7516 import JWEZipAlgorithm
|
||||
|
||||
GZIP_HEAD = bytes([120, 156])
|
||||
MAX_SIZE = 250 * 1024
|
||||
|
||||
|
||||
class DeflateZipAlgorithm(JWEZipAlgorithm):
|
||||
name = "DEF"
|
||||
description = "DEFLATE"
|
||||
|
||||
def compress(self, s: bytes) -> bytes:
|
||||
"""Compress bytes data with DEFLATE algorithm."""
|
||||
data = zlib.compress(s)
|
||||
# https://datatracker.ietf.org/doc/html/rfc1951
|
||||
# since DEF is always gzip, we can drop gzip headers and tail
|
||||
return data[2:-4]
|
||||
|
||||
def decompress(self, s: bytes) -> bytes:
|
||||
"""Decompress DEFLATE bytes data."""
|
||||
if s.startswith(GZIP_HEAD):
|
||||
decompressor = zlib.decompressobj()
|
||||
else:
|
||||
decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
|
||||
value = decompressor.decompress(s, MAX_SIZE)
|
||||
if decompressor.unconsumed_tail:
|
||||
raise ValueError(f"Decompressed string exceeds {MAX_SIZE} bytes")
|
||||
return value
|
||||
|
||||
|
||||
def register_jwe_rfc7518():
|
||||
JsonWebEncryption.register_algorithm(DeflateZipAlgorithm())
|
||||
@@ -0,0 +1,221 @@
|
||||
"""authlib.jose.rfc7518.
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
"alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_.
|
||||
|
||||
.. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
|
||||
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
|
||||
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature
|
||||
|
||||
from ..rfc7515 import JWSAlgorithm
|
||||
from .ec_key import ECKey
|
||||
from .oct_key import OctKey
|
||||
from .rsa_key import RSAKey
|
||||
from .util import decode_int
|
||||
from .util import encode_int
|
||||
|
||||
|
||||
class NoneAlgorithm(JWSAlgorithm):
|
||||
name = "none"
|
||||
description = "No digital signature or MAC performed"
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return None
|
||||
|
||||
def sign(self, msg, key):
|
||||
return b""
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
return sig == b""
|
||||
|
||||
|
||||
class HMACAlgorithm(JWSAlgorithm):
|
||||
"""HMAC using SHA algorithms for JWS. Available algorithms:
|
||||
|
||||
- HS256: HMAC using SHA-256
|
||||
- HS384: HMAC using SHA-384
|
||||
- HS512: HMAC using SHA-512
|
||||
"""
|
||||
|
||||
SHA256 = hashlib.sha256
|
||||
SHA384 = hashlib.sha384
|
||||
SHA512 = hashlib.sha512
|
||||
|
||||
def __init__(self, sha_type):
|
||||
self.name = f"HS{sha_type}"
|
||||
self.description = f"HMAC using SHA-{sha_type}"
|
||||
self.hash_alg = getattr(self, f"SHA{sha_type}")
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return OctKey.import_key(raw_data)
|
||||
|
||||
def sign(self, msg, key):
|
||||
# it is faster than the one in cryptography
|
||||
op_key = key.get_op_key("sign")
|
||||
return hmac.new(op_key, msg, self.hash_alg).digest()
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
op_key = key.get_op_key("verify")
|
||||
v_sig = hmac.new(op_key, msg, self.hash_alg).digest()
|
||||
return hmac.compare_digest(sig, v_sig)
|
||||
|
||||
|
||||
class RSAAlgorithm(JWSAlgorithm):
|
||||
"""RSA using SHA algorithms for JWS. Available algorithms:
|
||||
|
||||
- RS256: RSASSA-PKCS1-v1_5 using SHA-256
|
||||
- RS384: RSASSA-PKCS1-v1_5 using SHA-384
|
||||
- RS512: RSASSA-PKCS1-v1_5 using SHA-512
|
||||
"""
|
||||
|
||||
SHA256 = hashes.SHA256
|
||||
SHA384 = hashes.SHA384
|
||||
SHA512 = hashes.SHA512
|
||||
|
||||
def __init__(self, sha_type):
|
||||
self.name = f"RS{sha_type}"
|
||||
self.description = f"RSASSA-PKCS1-v1_5 using SHA-{sha_type}"
|
||||
self.hash_alg = getattr(self, f"SHA{sha_type}")
|
||||
self.padding = padding.PKCS1v15()
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return RSAKey.import_key(raw_data)
|
||||
|
||||
def sign(self, msg, key):
|
||||
op_key = key.get_op_key("sign")
|
||||
return op_key.sign(msg, self.padding, self.hash_alg())
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
op_key = key.get_op_key("verify")
|
||||
try:
|
||||
op_key.verify(sig, msg, self.padding, self.hash_alg())
|
||||
return True
|
||||
except InvalidSignature:
|
||||
return False
|
||||
|
||||
|
||||
class ECAlgorithm(JWSAlgorithm):
|
||||
"""ECDSA using SHA algorithms for JWS. Available algorithms:
|
||||
|
||||
- ES256: ECDSA using P-256 and SHA-256
|
||||
- ES384: ECDSA using P-384 and SHA-384
|
||||
- ES512: ECDSA using P-521 and SHA-512
|
||||
"""
|
||||
|
||||
SHA256 = hashes.SHA256
|
||||
SHA384 = hashes.SHA384
|
||||
SHA512 = hashes.SHA512
|
||||
|
||||
def __init__(self, name, curve, sha_type):
|
||||
self.name = name
|
||||
self.curve = curve
|
||||
self.description = f"ECDSA using {self.curve} and SHA-{sha_type}"
|
||||
self.hash_alg = getattr(self, f"SHA{sha_type}")
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
key = ECKey.import_key(raw_data)
|
||||
if key["crv"] != self.curve:
|
||||
raise ValueError(
|
||||
f'Key for "{self.name}" not supported, only "{self.curve}" allowed'
|
||||
)
|
||||
return key
|
||||
|
||||
def sign(self, msg, key):
|
||||
op_key = key.get_op_key("sign")
|
||||
der_sig = op_key.sign(msg, ECDSA(self.hash_alg()))
|
||||
r, s = decode_dss_signature(der_sig)
|
||||
size = key.curve_key_size
|
||||
return encode_int(r, size) + encode_int(s, size)
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
key_size = key.curve_key_size
|
||||
length = (key_size + 7) // 8
|
||||
|
||||
if len(sig) != 2 * length:
|
||||
return False
|
||||
|
||||
r = decode_int(sig[:length])
|
||||
s = decode_int(sig[length:])
|
||||
der_sig = encode_dss_signature(r, s)
|
||||
|
||||
try:
|
||||
op_key = key.get_op_key("verify")
|
||||
op_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
|
||||
return True
|
||||
except InvalidSignature:
|
||||
return False
|
||||
|
||||
|
||||
class RSAPSSAlgorithm(JWSAlgorithm):
|
||||
"""RSASSA-PSS using SHA algorithms for JWS. Available algorithms:
|
||||
|
||||
- PS256: RSASSA-PSS using SHA-256 and MGF1 with SHA-256
|
||||
- PS384: RSASSA-PSS using SHA-384 and MGF1 with SHA-384
|
||||
- PS512: RSASSA-PSS using SHA-512 and MGF1 with SHA-512
|
||||
"""
|
||||
|
||||
SHA256 = hashes.SHA256
|
||||
SHA384 = hashes.SHA384
|
||||
SHA512 = hashes.SHA512
|
||||
|
||||
def __init__(self, sha_type):
|
||||
self.name = f"PS{sha_type}"
|
||||
tpl = "RSASSA-PSS using SHA-{} and MGF1 with SHA-{}"
|
||||
self.description = tpl.format(sha_type, sha_type)
|
||||
self.hash_alg = getattr(self, f"SHA{sha_type}")
|
||||
|
||||
def prepare_key(self, raw_data):
|
||||
return RSAKey.import_key(raw_data)
|
||||
|
||||
def sign(self, msg, key):
|
||||
op_key = key.get_op_key("sign")
|
||||
return op_key.sign(
|
||||
msg,
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(self.hash_alg()), salt_length=self.hash_alg.digest_size
|
||||
),
|
||||
self.hash_alg(),
|
||||
)
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
op_key = key.get_op_key("verify")
|
||||
try:
|
||||
op_key.verify(
|
||||
sig,
|
||||
msg,
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(self.hash_alg()),
|
||||
salt_length=self.hash_alg.digest_size,
|
||||
),
|
||||
self.hash_alg(),
|
||||
)
|
||||
return True
|
||||
except InvalidSignature:
|
||||
return False
|
||||
|
||||
|
||||
JWS_ALGORITHMS = [
|
||||
NoneAlgorithm(), # none
|
||||
HMACAlgorithm(256), # HS256
|
||||
HMACAlgorithm(384), # HS384
|
||||
HMACAlgorithm(512), # HS512
|
||||
RSAAlgorithm(256), # RS256
|
||||
RSAAlgorithm(384), # RS384
|
||||
RSAAlgorithm(512), # RS512
|
||||
ECAlgorithm("ES256", "P-256", 256),
|
||||
ECAlgorithm("ES384", "P-384", 384),
|
||||
ECAlgorithm("ES512", "P-521", 512),
|
||||
ECAlgorithm("ES256K", "secp256k1", 256), # defined in RFC8812
|
||||
RSAPSSAlgorithm(256), # PS256
|
||||
RSAPSSAlgorithm(384), # PS384
|
||||
RSAPSSAlgorithm(512), # PS512
|
||||
]
|
||||
@@ -0,0 +1,96 @@
|
||||
import secrets
|
||||
|
||||
from authlib.common.encoding import to_bytes
|
||||
from authlib.common.encoding import to_unicode
|
||||
from authlib.common.encoding import urlsafe_b64decode
|
||||
from authlib.common.encoding import urlsafe_b64encode
|
||||
|
||||
from ..rfc7517 import Key
|
||||
|
||||
POSSIBLE_UNSAFE_KEYS = (
|
||||
b"-----BEGIN ",
|
||||
b"---- BEGIN ",
|
||||
b"ssh-rsa ",
|
||||
b"ssh-dss ",
|
||||
b"ssh-ed25519 ",
|
||||
b"ecdsa-sha2-",
|
||||
)
|
||||
|
||||
|
||||
class OctKey(Key):
|
||||
"""Key class of the ``oct`` key type."""
|
||||
|
||||
kty = "oct"
|
||||
REQUIRED_JSON_FIELDS = ["k"]
|
||||
|
||||
def __init__(self, raw_key=None, options=None):
|
||||
super().__init__(options)
|
||||
self.raw_key = raw_key
|
||||
|
||||
@property
|
||||
def public_only(self):
|
||||
return False
|
||||
|
||||
def get_op_key(self, operation):
|
||||
"""Get the raw key for the given key_op. This method will also
|
||||
check if the given key_op is supported by this key.
|
||||
|
||||
:param operation: key operation value, such as "sign", "encrypt".
|
||||
:return: raw key
|
||||
"""
|
||||
self.check_key_op(operation)
|
||||
if not self.raw_key:
|
||||
self.load_raw_key()
|
||||
return self.raw_key
|
||||
|
||||
def load_raw_key(self):
|
||||
self.raw_key = urlsafe_b64decode(to_bytes(self.tokens["k"]))
|
||||
|
||||
def load_dict_key(self):
|
||||
k = to_unicode(urlsafe_b64encode(self.raw_key))
|
||||
self._dict_data = {"kty": self.kty, "k": k}
|
||||
|
||||
def as_dict(self, is_private=False, **params):
|
||||
tokens = self.tokens
|
||||
if "kid" not in tokens:
|
||||
tokens["kid"] = self.thumbprint()
|
||||
|
||||
tokens.update(params)
|
||||
return tokens
|
||||
|
||||
@classmethod
|
||||
def validate_raw_key(cls, key):
|
||||
return isinstance(key, bytes)
|
||||
|
||||
@classmethod
|
||||
def import_key(cls, raw, options=None):
|
||||
"""Import a key from bytes, string, or dict data."""
|
||||
if isinstance(raw, cls):
|
||||
if options is not None:
|
||||
raw.options.update(options)
|
||||
return raw
|
||||
|
||||
if isinstance(raw, dict):
|
||||
cls.check_required_fields(raw)
|
||||
key = cls(options=options)
|
||||
key._dict_data = raw
|
||||
else:
|
||||
raw_key = to_bytes(raw)
|
||||
|
||||
# security check
|
||||
if raw_key.startswith(POSSIBLE_UNSAFE_KEYS):
|
||||
raise ValueError("This key may not be safe to import")
|
||||
|
||||
key = cls(raw_key=raw_key, options=options)
|
||||
return key
|
||||
|
||||
@classmethod
|
||||
def generate_key(cls, key_size=256, options=None, is_private=True):
|
||||
"""Generate a ``OctKey`` with the given bit size."""
|
||||
if not is_private:
|
||||
raise ValueError("oct key can not be generated as public")
|
||||
|
||||
if key_size % 8 != 0:
|
||||
raise ValueError("Invalid bit size for oct key")
|
||||
|
||||
return cls.import_key(secrets.token_bytes(int(key_size / 8)), options)
|
||||
@@ -0,0 +1,127 @@
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKeyWithSerialization
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateNumbers
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmp1
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmq1
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_iqmp
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_recover_prime_factors
|
||||
|
||||
from authlib.common.encoding import base64_to_int
|
||||
from authlib.common.encoding import int_to_base64
|
||||
|
||||
from ..rfc7517 import AsymmetricKey
|
||||
|
||||
|
||||
class RSAKey(AsymmetricKey):
|
||||
"""Key class of the ``RSA`` key type."""
|
||||
|
||||
kty = "RSA"
|
||||
PUBLIC_KEY_CLS = RSAPublicKey
|
||||
PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization
|
||||
|
||||
PUBLIC_KEY_FIELDS = ["e", "n"]
|
||||
PRIVATE_KEY_FIELDS = ["d", "dp", "dq", "e", "n", "p", "q", "qi"]
|
||||
REQUIRED_JSON_FIELDS = ["e", "n"]
|
||||
SSH_PUBLIC_PREFIX = b"ssh-rsa"
|
||||
|
||||
def dumps_private_key(self):
|
||||
numbers = self.private_key.private_numbers()
|
||||
return {
|
||||
"n": int_to_base64(numbers.public_numbers.n),
|
||||
"e": int_to_base64(numbers.public_numbers.e),
|
||||
"d": int_to_base64(numbers.d),
|
||||
"p": int_to_base64(numbers.p),
|
||||
"q": int_to_base64(numbers.q),
|
||||
"dp": int_to_base64(numbers.dmp1),
|
||||
"dq": int_to_base64(numbers.dmq1),
|
||||
"qi": int_to_base64(numbers.iqmp),
|
||||
}
|
||||
|
||||
def dumps_public_key(self):
|
||||
numbers = self.public_key.public_numbers()
|
||||
return {"n": int_to_base64(numbers.n), "e": int_to_base64(numbers.e)}
|
||||
|
||||
def load_private_key(self):
|
||||
obj = self._dict_data
|
||||
|
||||
if "oth" in obj: # pragma: no cover
|
||||
# https://tools.ietf.org/html/rfc7518#section-6.3.2.7
|
||||
raise ValueError('"oth" is not supported yet')
|
||||
|
||||
public_numbers = RSAPublicNumbers(
|
||||
base64_to_int(obj["e"]), base64_to_int(obj["n"])
|
||||
)
|
||||
|
||||
if has_all_prime_factors(obj):
|
||||
numbers = RSAPrivateNumbers(
|
||||
d=base64_to_int(obj["d"]),
|
||||
p=base64_to_int(obj["p"]),
|
||||
q=base64_to_int(obj["q"]),
|
||||
dmp1=base64_to_int(obj["dp"]),
|
||||
dmq1=base64_to_int(obj["dq"]),
|
||||
iqmp=base64_to_int(obj["qi"]),
|
||||
public_numbers=public_numbers,
|
||||
)
|
||||
else:
|
||||
d = base64_to_int(obj["d"])
|
||||
p, q = rsa_recover_prime_factors(public_numbers.n, d, public_numbers.e)
|
||||
numbers = RSAPrivateNumbers(
|
||||
d=d,
|
||||
p=p,
|
||||
q=q,
|
||||
dmp1=rsa_crt_dmp1(d, p),
|
||||
dmq1=rsa_crt_dmq1(d, q),
|
||||
iqmp=rsa_crt_iqmp(p, q),
|
||||
public_numbers=public_numbers,
|
||||
)
|
||||
|
||||
return numbers.private_key(default_backend())
|
||||
|
||||
def load_public_key(self):
|
||||
numbers = RSAPublicNumbers(
|
||||
base64_to_int(self._dict_data["e"]), base64_to_int(self._dict_data["n"])
|
||||
)
|
||||
return numbers.public_key(default_backend())
|
||||
|
||||
@classmethod
|
||||
def generate_key(cls, key_size=2048, options=None, is_private=False) -> "RSAKey":
|
||||
if key_size < 512:
|
||||
raise ValueError("key_size must not be less than 512")
|
||||
if key_size % 8 != 0:
|
||||
raise ValueError("Invalid key_size for RSAKey")
|
||||
raw_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=key_size,
|
||||
backend=default_backend(),
|
||||
)
|
||||
if not is_private:
|
||||
raw_key = raw_key.public_key()
|
||||
return cls.import_key(raw_key, options=options)
|
||||
|
||||
@classmethod
|
||||
def import_dict_key(cls, raw, options=None):
|
||||
cls.check_required_fields(raw)
|
||||
key = cls(options=options)
|
||||
key._dict_data = raw
|
||||
if "d" in raw and not has_all_prime_factors(raw):
|
||||
# reload dict key
|
||||
key.load_raw_key()
|
||||
key.load_dict_key()
|
||||
return key
|
||||
|
||||
|
||||
def has_all_prime_factors(obj):
|
||||
props = ["p", "q", "dp", "dq", "qi"]
|
||||
props_found = [prop in obj for prop in props]
|
||||
if all(props_found):
|
||||
return True
|
||||
|
||||
if any(props_found):
|
||||
raise ValueError(
|
||||
"RSA key must include all parameters if any are present besides d"
|
||||
)
|
||||
|
||||
return False
|
||||
@@ -0,0 +1,12 @@
|
||||
import binascii
|
||||
|
||||
|
||||
def encode_int(num, bits):
|
||||
length = ((bits + 7) // 8) * 2
|
||||
padded_hex = f"{num:0{length}x}"
|
||||
big_endian = binascii.a2b_hex(padded_hex.encode("ascii"))
|
||||
return big_endian
|
||||
|
||||
|
||||
def decode_int(b):
|
||||
return int(binascii.b2a_hex(b), 16)
|
||||
Reference in New Issue
Block a user