This commit is contained in:
Iliyan Angelov
2025-11-19 12:27:01 +02:00
parent 2043ac897c
commit 34b4c969d4
469 changed files with 26870 additions and 8329 deletions

View File

@@ -5,7 +5,8 @@
from __future__ import annotations
import abc
import typing
from cryptography import utils
# This exists to break an import cycle. It is normally accessible from the
# ciphers module.
@@ -21,7 +22,7 @@ class CipherAlgorithm(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def key_sizes(self) -> typing.FrozenSet[int]:
def key_sizes(self) -> frozenset[int]:
"""
Valid key sizes for this algorithm in bits
"""
@@ -35,7 +36,7 @@ class CipherAlgorithm(metaclass=abc.ABCMeta):
class BlockCipherAlgorithm(CipherAlgorithm):
key: bytes
key: utils.Buffer
@property
@abc.abstractmethod
@@ -43,3 +44,17 @@ class BlockCipherAlgorithm(CipherAlgorithm):
"""
The size of a block as an integer in bits (e.g. 64, 128).
"""
def _verify_key_size(
algorithm: CipherAlgorithm, key: utils.Buffer
) -> utils.Buffer:
# Verify that the key is instance of bytes
utils._check_byteslike("key", key)
# Verify that the key size matches the expected key size
if len(key) * 8 not in algorithm.key_sizes:
raise ValueError(
f"Invalid key size ({len(key) * 8}) for {algorithm.name}."
)
return key

View File

@@ -5,7 +5,6 @@
from __future__ import annotations
import abc
import typing
from cryptography import utils
from cryptography.hazmat.primitives.hashes import HashAlgorithm
@@ -78,9 +77,9 @@ class KeySerializationEncryptionBuilder:
self,
format: PrivateFormat,
*,
_kdf_rounds: typing.Optional[int] = None,
_hmac_hash: typing.Optional[HashAlgorithm] = None,
_key_cert_algorithm: typing.Optional[PBES] = None,
_kdf_rounds: int | None = None,
_hmac_hash: HashAlgorithm | None = None,
_key_cert_algorithm: PBES | None = None,
) -> None:
self._format = format
@@ -127,8 +126,7 @@ class KeySerializationEncryptionBuilder:
) -> KeySerializationEncryptionBuilder:
if self._format is not PrivateFormat.PKCS12:
raise TypeError(
"key_cert_algorithm only supported with "
"PrivateFormat.PKCS12"
"key_cert_algorithm only supported with PrivateFormat.PKCS12"
)
if self._key_cert_algorithm is not None:
raise ValueError("key_cert_algorithm already set")
@@ -158,9 +156,9 @@ class _KeySerializationEncryption(KeySerializationEncryption):
format: PrivateFormat,
password: bytes,
*,
kdf_rounds: typing.Optional[int],
hmac_hash: typing.Optional[HashAlgorithm],
key_cert_algorithm: typing.Optional[PBES],
kdf_rounds: int | None,
hmac_hash: HashAlgorithm | None,
key_cert_algorithm: PBES | None,
):
self._format = format
self.password = password

View File

@@ -5,142 +5,16 @@
from __future__ import annotations
import abc
import typing
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization
def generate_parameters(
generator: int, key_size: int, backend: typing.Any = None
) -> DHParameters:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.generate_dh_parameters(generator, key_size)
generate_parameters = rust_openssl.dh.generate_parameters
class DHParameterNumbers:
def __init__(self, p: int, g: int, q: typing.Optional[int] = None) -> None:
if not isinstance(p, int) or not isinstance(g, int):
raise TypeError("p and g must be integers")
if q is not None and not isinstance(q, int):
raise TypeError("q must be integer or None")
if g < 2:
raise ValueError("DH generator must be 2 or greater")
if p.bit_length() < rust_openssl.dh.MIN_MODULUS_SIZE:
raise ValueError(
f"p (modulus) must be at least "
f"{rust_openssl.dh.MIN_MODULUS_SIZE}-bit"
)
self._p = p
self._g = g
self._q = q
def __eq__(self, other: object) -> bool:
if not isinstance(other, DHParameterNumbers):
return NotImplemented
return (
self._p == other._p and self._g == other._g and self._q == other._q
)
def parameters(self, backend: typing.Any = None) -> DHParameters:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_dh_parameter_numbers(self)
@property
def p(self) -> int:
return self._p
@property
def g(self) -> int:
return self._g
@property
def q(self) -> typing.Optional[int]:
return self._q
class DHPublicNumbers:
def __init__(self, y: int, parameter_numbers: DHParameterNumbers) -> None:
if not isinstance(y, int):
raise TypeError("y must be an integer.")
if not isinstance(parameter_numbers, DHParameterNumbers):
raise TypeError(
"parameters must be an instance of DHParameterNumbers."
)
self._y = y
self._parameter_numbers = parameter_numbers
def __eq__(self, other: object) -> bool:
if not isinstance(other, DHPublicNumbers):
return NotImplemented
return (
self._y == other._y
and self._parameter_numbers == other._parameter_numbers
)
def public_key(self, backend: typing.Any = None) -> DHPublicKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_dh_public_numbers(self)
@property
def y(self) -> int:
return self._y
@property
def parameter_numbers(self) -> DHParameterNumbers:
return self._parameter_numbers
class DHPrivateNumbers:
def __init__(self, x: int, public_numbers: DHPublicNumbers) -> None:
if not isinstance(x, int):
raise TypeError("x must be an integer.")
if not isinstance(public_numbers, DHPublicNumbers):
raise TypeError(
"public_numbers must be an instance of " "DHPublicNumbers."
)
self._x = x
self._public_numbers = public_numbers
def __eq__(self, other: object) -> bool:
if not isinstance(other, DHPrivateNumbers):
return NotImplemented
return (
self._x == other._x
and self._public_numbers == other._public_numbers
)
def private_key(self, backend: typing.Any = None) -> DHPrivateKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_dh_private_numbers(self)
@property
def public_numbers(self) -> DHPublicNumbers:
return self._public_numbers
@property
def x(self) -> int:
return self._x
DHPrivateNumbers = rust_openssl.dh.DHPrivateNumbers
DHPublicNumbers = rust_openssl.dh.DHPublicNumbers
DHParameterNumbers = rust_openssl.dh.DHParameterNumbers
class DHParameters(metaclass=abc.ABCMeta):
@@ -207,6 +81,12 @@ class DHPublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> DHPublicKey:
"""
Returns a copy.
"""
DHPublicKeyWithSerialization = DHPublicKey
DHPublicKey.register(rust_openssl.dh.DHPublicKey)
@@ -256,6 +136,12 @@ class DHPrivateKey(metaclass=abc.ABCMeta):
Returns the key serialized as bytes.
"""
@abc.abstractmethod
def __copy__(self) -> DHPrivateKey:
"""
Returns a copy.
"""
DHPrivateKeyWithSerialization = DHPrivateKey
DHPrivateKey.register(rust_openssl.dh.DHPrivateKey)

View File

@@ -10,6 +10,7 @@ import typing
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization, hashes
from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
from cryptography.utils import Buffer
class DSAParameters(metaclass=abc.ABCMeta):
@@ -53,8 +54,8 @@ class DSAPrivateKey(metaclass=abc.ABCMeta):
@abc.abstractmethod
def sign(
self,
data: bytes,
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
data: Buffer,
algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
) -> bytes:
"""
Signs the data
@@ -77,6 +78,12 @@ class DSAPrivateKey(metaclass=abc.ABCMeta):
Returns the key serialized as bytes.
"""
@abc.abstractmethod
def __copy__(self) -> DSAPrivateKey:
"""
Returns a copy.
"""
DSAPrivateKeyWithSerialization = DSAPrivateKey
DSAPrivateKey.register(rust_openssl.dsa.DSAPrivateKey)
@@ -115,9 +122,9 @@ class DSAPublicKey(metaclass=abc.ABCMeta):
@abc.abstractmethod
def verify(
self,
signature: bytes,
data: bytes,
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
signature: Buffer,
data: Buffer,
algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
) -> None:
"""
Verifies the signature of the data.
@@ -129,171 +136,32 @@ class DSAPublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> DSAPublicKey:
"""
Returns a copy.
"""
DSAPublicKeyWithSerialization = DSAPublicKey
DSAPublicKey.register(rust_openssl.dsa.DSAPublicKey)
class DSAParameterNumbers:
def __init__(self, p: int, q: int, g: int):
if (
not isinstance(p, int)
or not isinstance(q, int)
or not isinstance(g, int)
):
raise TypeError(
"DSAParameterNumbers p, q, and g arguments must be integers."
)
self._p = p
self._q = q
self._g = g
@property
def p(self) -> int:
return self._p
@property
def q(self) -> int:
return self._q
@property
def g(self) -> int:
return self._g
def parameters(self, backend: typing.Any = None) -> DSAParameters:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_dsa_parameter_numbers(self)
def __eq__(self, other: object) -> bool:
if not isinstance(other, DSAParameterNumbers):
return NotImplemented
return self.p == other.p and self.q == other.q and self.g == other.g
def __repr__(self) -> str:
return (
"<DSAParameterNumbers(p={self.p}, q={self.q}, "
"g={self.g})>".format(self=self)
)
class DSAPublicNumbers:
def __init__(self, y: int, parameter_numbers: DSAParameterNumbers):
if not isinstance(y, int):
raise TypeError("DSAPublicNumbers y argument must be an integer.")
if not isinstance(parameter_numbers, DSAParameterNumbers):
raise TypeError(
"parameter_numbers must be a DSAParameterNumbers instance."
)
self._y = y
self._parameter_numbers = parameter_numbers
@property
def y(self) -> int:
return self._y
@property
def parameter_numbers(self) -> DSAParameterNumbers:
return self._parameter_numbers
def public_key(self, backend: typing.Any = None) -> DSAPublicKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_dsa_public_numbers(self)
def __eq__(self, other: object) -> bool:
if not isinstance(other, DSAPublicNumbers):
return NotImplemented
return (
self.y == other.y
and self.parameter_numbers == other.parameter_numbers
)
def __repr__(self) -> str:
return (
"<DSAPublicNumbers(y={self.y}, "
"parameter_numbers={self.parameter_numbers})>".format(self=self)
)
class DSAPrivateNumbers:
def __init__(self, x: int, public_numbers: DSAPublicNumbers):
if not isinstance(x, int):
raise TypeError("DSAPrivateNumbers x argument must be an integer.")
if not isinstance(public_numbers, DSAPublicNumbers):
raise TypeError(
"public_numbers must be a DSAPublicNumbers instance."
)
self._public_numbers = public_numbers
self._x = x
@property
def x(self) -> int:
return self._x
@property
def public_numbers(self) -> DSAPublicNumbers:
return self._public_numbers
def private_key(self, backend: typing.Any = None) -> DSAPrivateKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_dsa_private_numbers(self)
def __eq__(self, other: object) -> bool:
if not isinstance(other, DSAPrivateNumbers):
return NotImplemented
return (
self.x == other.x and self.public_numbers == other.public_numbers
)
DSAPrivateNumbers = rust_openssl.dsa.DSAPrivateNumbers
DSAPublicNumbers = rust_openssl.dsa.DSAPublicNumbers
DSAParameterNumbers = rust_openssl.dsa.DSAParameterNumbers
def generate_parameters(
key_size: int, backend: typing.Any = None
) -> DSAParameters:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
if key_size not in (1024, 2048, 3072, 4096):
raise ValueError("Key size must be 1024, 2048, 3072, or 4096 bits.")
return ossl.generate_dsa_parameters(key_size)
return rust_openssl.dsa.generate_parameters(key_size)
def generate_private_key(
key_size: int, backend: typing.Any = None
) -> DSAPrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.generate_dsa_private_key_and_parameters(key_size)
def _check_dsa_parameters(parameters: DSAParameterNumbers) -> None:
if parameters.p.bit_length() not in [1024, 2048, 3072, 4096]:
raise ValueError(
"p must be exactly 1024, 2048, 3072, or 4096 bits long"
)
if parameters.q.bit_length() not in [160, 224, 256]:
raise ValueError("q must be exactly 160, 224, or 256 bits long")
if not (1 < parameters.g < parameters.p):
raise ValueError("g, p don't satisfy 1 < g < p.")
def _check_dsa_private_numbers(numbers: DSAPrivateNumbers) -> None:
parameters = numbers.public_numbers.parameter_numbers
_check_dsa_parameters(parameters)
if numbers.x <= 0 or numbers.x >= parameters.q:
raise ValueError("x must be > 0 and < q.")
if numbers.public_numbers.y != pow(parameters.g, numbers.x, parameters.p):
raise ValueError("y must be equal to (g ** x % p).")
parameters = generate_parameters(key_size)
return parameters.generate_private_key()

View File

@@ -8,7 +8,9 @@ import abc
import typing
from cryptography import utils
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat._oid import ObjectIdentifier
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization, hashes
from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
@@ -50,13 +52,20 @@ class EllipticCurve(metaclass=abc.ABCMeta):
Bit size of a secret scalar for the curve.
"""
@property
@abc.abstractmethod
def group_order(self) -> int:
"""
The order of the curve's group.
"""
class EllipticCurveSignatureAlgorithm(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def algorithm(
self,
) -> typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm]:
) -> asym_utils.Prehashed | hashes.HashAlgorithm:
"""
The digest algorithm used with this signature.
"""
@@ -95,7 +104,7 @@ class EllipticCurvePrivateKey(metaclass=abc.ABCMeta):
@abc.abstractmethod
def sign(
self,
data: bytes,
data: utils.Buffer,
signature_algorithm: EllipticCurveSignatureAlgorithm,
) -> bytes:
"""
@@ -119,8 +128,15 @@ class EllipticCurvePrivateKey(metaclass=abc.ABCMeta):
Returns the key serialized as bytes.
"""
@abc.abstractmethod
def __copy__(self) -> EllipticCurvePrivateKey:
"""
Returns a copy.
"""
EllipticCurvePrivateKeyWithSerialization = EllipticCurvePrivateKey
EllipticCurvePrivateKey.register(rust_openssl.ec.ECPrivateKey)
class EllipticCurvePublicKey(metaclass=abc.ABCMeta):
@@ -157,8 +173,8 @@ class EllipticCurvePublicKey(metaclass=abc.ABCMeta):
@abc.abstractmethod
def verify(
self,
signature: bytes,
data: bytes,
signature: utils.Buffer,
data: utils.Buffer,
signature_algorithm: EllipticCurveSignatureAlgorithm,
) -> None:
"""
@@ -171,18 +187,13 @@ class EllipticCurvePublicKey(metaclass=abc.ABCMeta):
) -> EllipticCurvePublicKey:
utils._check_bytes("data", data)
if not isinstance(curve, EllipticCurve):
raise TypeError("curve must be an EllipticCurve instance")
if len(data) == 0:
raise ValueError("data must not be an empty byte string")
if data[0] not in [0x02, 0x03, 0x04]:
raise ValueError("Unsupported elliptic curve point type")
from cryptography.hazmat.backends.openssl.backend import backend
return backend.load_elliptic_curve_public_bytes(curve, data)
return rust_openssl.ec.from_public_bytes(curve, data)
@abc.abstractmethod
def __eq__(self, other: object) -> bool:
@@ -190,150 +201,199 @@ class EllipticCurvePublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> EllipticCurvePublicKey:
"""
Returns a copy.
"""
EllipticCurvePublicKeyWithSerialization = EllipticCurvePublicKey
EllipticCurvePublicKey.register(rust_openssl.ec.ECPublicKey)
EllipticCurvePrivateNumbers = rust_openssl.ec.EllipticCurvePrivateNumbers
EllipticCurvePublicNumbers = rust_openssl.ec.EllipticCurvePublicNumbers
class SECT571R1(EllipticCurve):
name = "sect571r1"
key_size = 570
group_order = 0x3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE661CE18FF55987308059B186823851EC7DD9CA1161DE93D5174D66E8382E9BB2FE84E47 # noqa: E501
class SECT409R1(EllipticCurve):
name = "sect409r1"
key_size = 409
group_order = 0x10000000000000000000000000000000000000000000000000001E2AAD6A612F33307BE5FA47C3C9E052F838164CD37D9A21173 # noqa: E501
class SECT283R1(EllipticCurve):
name = "sect283r1"
key_size = 283
group_order = 0x3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEF90399660FC938A90165B042A7CEFADB307 # noqa: E501
class SECT233R1(EllipticCurve):
name = "sect233r1"
key_size = 233
group_order = 0x1000000000000000000000000000013E974E72F8A6922031D2603CFE0D7
class SECT163R2(EllipticCurve):
name = "sect163r2"
key_size = 163
group_order = 0x40000000000000000000292FE77E70C12A4234C33
class SECT571K1(EllipticCurve):
name = "sect571k1"
key_size = 571
group_order = 0x20000000000000000000000000000000000000000000000000000000000000000000000131850E1F19A63E4B391A8DB917F4138B630D84BE5D639381E91DEB45CFE778F637C1001 # noqa: E501
class SECT409K1(EllipticCurve):
name = "sect409k1"
key_size = 409
group_order = 0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE5F83B2D4EA20400EC4557D5ED3E3E7CA5B4B5C83B8E01E5FCF # noqa: E501
class SECT283K1(EllipticCurve):
name = "sect283k1"
key_size = 283
group_order = 0x1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE9AE2ED07577265DFF7F94451E061E163C61 # noqa: E501
class SECT233K1(EllipticCurve):
name = "sect233k1"
key_size = 233
group_order = 0x8000000000000000000000000000069D5BB915BCD46EFB1AD5F173ABDF
class SECT163K1(EllipticCurve):
name = "sect163k1"
key_size = 163
group_order = 0x4000000000000000000020108A2E0CC0D99F8A5EF
class SECP521R1(EllipticCurve):
name = "secp521r1"
key_size = 521
group_order = 0x1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFA51868783BF2F966B7FCC0148F709A5D03BB5C9B8899C47AEBB6FB71E91386409 # noqa: E501
class SECP384R1(EllipticCurve):
name = "secp384r1"
key_size = 384
group_order = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFC7634D81F4372DDF581A0DB248B0A77AECEC196ACCC52973 # noqa: E501
class SECP256R1(EllipticCurve):
name = "secp256r1"
key_size = 256
group_order = (
0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551
)
class SECP256K1(EllipticCurve):
name = "secp256k1"
key_size = 256
group_order = (
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
)
class SECP224R1(EllipticCurve):
name = "secp224r1"
key_size = 224
group_order = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFF16A2E0B8F03E13DD29455C5C2A3D
class SECP192R1(EllipticCurve):
name = "secp192r1"
key_size = 192
group_order = 0xFFFFFFFFFFFFFFFFFFFFFFFF99DEF836146BC9B1B4D22831
class BrainpoolP256R1(EllipticCurve):
name = "brainpoolP256r1"
key_size = 256
group_order = (
0xA9FB57DBA1EEA9BC3E660A909D838D718C397AA3B561A6F7901E0E82974856A7
)
class BrainpoolP384R1(EllipticCurve):
name = "brainpoolP384r1"
key_size = 384
group_order = 0x8CB91E82A3386D280F5D6F7E50E641DF152F7109ED5456B31F166E6CAC0425A7CF3AB6AF6B7FC3103B883202E9046565 # noqa: E501
class BrainpoolP512R1(EllipticCurve):
name = "brainpoolP512r1"
key_size = 512
group_order = 0xAADD9DB8DBE9C48B3FD4E6AE33C9FC07CB308DB3B3C9D20ED6639CCA70330870553E5C414CA92619418661197FAC10471DB1D381085DDADDB58796829CA90069 # noqa: E501
_CURVE_TYPES: typing.Dict[str, typing.Type[EllipticCurve]] = {
"prime192v1": SECP192R1,
"prime256v1": SECP256R1,
"secp192r1": SECP192R1,
"secp224r1": SECP224R1,
"secp256r1": SECP256R1,
"secp384r1": SECP384R1,
"secp521r1": SECP521R1,
"secp256k1": SECP256K1,
"sect163k1": SECT163K1,
"sect233k1": SECT233K1,
"sect283k1": SECT283K1,
"sect409k1": SECT409K1,
"sect571k1": SECT571K1,
"sect163r2": SECT163R2,
"sect233r1": SECT233R1,
"sect283r1": SECT283R1,
"sect409r1": SECT409R1,
"sect571r1": SECT571R1,
"brainpoolP256r1": BrainpoolP256R1,
"brainpoolP384r1": BrainpoolP384R1,
"brainpoolP512r1": BrainpoolP512R1,
_CURVE_TYPES: dict[str, EllipticCurve] = {
"prime192v1": SECP192R1(),
"prime256v1": SECP256R1(),
"secp192r1": SECP192R1(),
"secp224r1": SECP224R1(),
"secp256r1": SECP256R1(),
"secp384r1": SECP384R1(),
"secp521r1": SECP521R1(),
"secp256k1": SECP256K1(),
"sect163k1": SECT163K1(),
"sect233k1": SECT233K1(),
"sect283k1": SECT283K1(),
"sect409k1": SECT409K1(),
"sect571k1": SECT571K1(),
"sect163r2": SECT163R2(),
"sect233r1": SECT233R1(),
"sect283r1": SECT283R1(),
"sect409r1": SECT409R1(),
"sect571r1": SECT571R1(),
"brainpoolP256r1": BrainpoolP256R1(),
"brainpoolP384r1": BrainpoolP384R1(),
"brainpoolP512r1": BrainpoolP512R1(),
}
class ECDSA(EllipticCurveSignatureAlgorithm):
def __init__(
self,
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
deterministic_signing: bool = False,
):
from cryptography.hazmat.backends.openssl.backend import backend
if (
deterministic_signing
and not backend.ecdsa_deterministic_supported()
):
raise UnsupportedAlgorithm(
"ECDSA with deterministic signature (RFC 6979) is not "
"supported by this version of OpenSSL.",
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
)
self._algorithm = algorithm
self._deterministic_signing = deterministic_signing
@property
def algorithm(
self,
) -> typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm]:
) -> asym_utils.Prehashed | hashes.HashAlgorithm:
return self._algorithm
@property
def deterministic_signing(
self,
) -> bool:
return self._deterministic_signing
def generate_private_key(
curve: EllipticCurve, backend: typing.Any = None
) -> EllipticCurvePrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.generate_elliptic_curve_private_key(curve)
generate_private_key = rust_openssl.ec.generate_private_key
def derive_private_key(
@@ -341,116 +401,13 @@ def derive_private_key(
curve: EllipticCurve,
backend: typing.Any = None,
) -> EllipticCurvePrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
if not isinstance(private_value, int):
raise TypeError("private_value must be an integer type.")
if private_value <= 0:
raise ValueError("private_value must be a positive integer.")
if not isinstance(curve, EllipticCurve):
raise TypeError("curve must provide the EllipticCurve interface.")
return ossl.derive_elliptic_curve_private_key(private_value, curve)
class EllipticCurvePublicNumbers:
def __init__(self, x: int, y: int, curve: EllipticCurve):
if not isinstance(x, int) or not isinstance(y, int):
raise TypeError("x and y must be integers.")
if not isinstance(curve, EllipticCurve):
raise TypeError("curve must provide the EllipticCurve interface.")
self._y = y
self._x = x
self._curve = curve
def public_key(self, backend: typing.Any = None) -> EllipticCurvePublicKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_elliptic_curve_public_numbers(self)
@property
def curve(self) -> EllipticCurve:
return self._curve
@property
def x(self) -> int:
return self._x
@property
def y(self) -> int:
return self._y
def __eq__(self, other: object) -> bool:
if not isinstance(other, EllipticCurvePublicNumbers):
return NotImplemented
return (
self.x == other.x
and self.y == other.y
and self.curve.name == other.curve.name
and self.curve.key_size == other.curve.key_size
)
def __hash__(self) -> int:
return hash((self.x, self.y, self.curve.name, self.curve.key_size))
def __repr__(self) -> str:
return (
"<EllipticCurvePublicNumbers(curve={0.curve.name}, x={0.x}, "
"y={0.y}>".format(self)
)
class EllipticCurvePrivateNumbers:
def __init__(
self, private_value: int, public_numbers: EllipticCurvePublicNumbers
):
if not isinstance(private_value, int):
raise TypeError("private_value must be an integer.")
if not isinstance(public_numbers, EllipticCurvePublicNumbers):
raise TypeError(
"public_numbers must be an EllipticCurvePublicNumbers "
"instance."
)
self._private_value = private_value
self._public_numbers = public_numbers
def private_key(
self, backend: typing.Any = None
) -> EllipticCurvePrivateKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_elliptic_curve_private_numbers(self)
@property
def private_value(self) -> int:
return self._private_value
@property
def public_numbers(self) -> EllipticCurvePublicNumbers:
return self._public_numbers
def __eq__(self, other: object) -> bool:
if not isinstance(other, EllipticCurvePrivateNumbers):
return NotImplemented
return (
self.private_value == other.private_value
and self.public_numbers == other.public_numbers
)
def __hash__(self) -> int:
return hash((self.private_value, self.public_numbers))
return rust_openssl.ec.derive_private_key(private_value, curve)
class ECDH:
@@ -480,7 +437,7 @@ _OID_TO_CURVE = {
}
def get_curve_for_oid(oid: ObjectIdentifier) -> typing.Type[EllipticCurve]:
def get_curve_for_oid(oid: ObjectIdentifier) -> type[EllipticCurve]:
try:
return _OID_TO_CURVE[oid]
except KeyError:

View File

@@ -9,6 +9,7 @@ import abc
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization
from cryptography.utils import Buffer
class Ed25519PublicKey(metaclass=abc.ABCMeta):
@@ -22,7 +23,7 @@ class Ed25519PublicKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
)
return backend.ed25519_load_public_bytes(data)
return rust_openssl.ed25519.from_public_bytes(data)
@abc.abstractmethod
def public_bytes(
@@ -42,7 +43,7 @@ class Ed25519PublicKey(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def verify(self, signature: bytes, data: bytes) -> None:
def verify(self, signature: Buffer, data: Buffer) -> None:
"""
Verify the signature.
"""
@@ -53,9 +54,14 @@ class Ed25519PublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> Ed25519PublicKey:
"""
Returns a copy.
"""
if hasattr(rust_openssl, "ed25519"):
Ed25519PublicKey.register(rust_openssl.ed25519.Ed25519PublicKey)
Ed25519PublicKey.register(rust_openssl.ed25519.Ed25519PublicKey)
class Ed25519PrivateKey(metaclass=abc.ABCMeta):
@@ -69,10 +75,10 @@ class Ed25519PrivateKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
)
return backend.ed25519_generate_key()
return rust_openssl.ed25519.generate_key()
@classmethod
def from_private_bytes(cls, data: bytes) -> Ed25519PrivateKey:
def from_private_bytes(cls, data: Buffer) -> Ed25519PrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend
if not backend.ed25519_supported():
@@ -81,7 +87,7 @@ class Ed25519PrivateKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
)
return backend.ed25519_load_private_bytes(data)
return rust_openssl.ed25519.from_private_bytes(data)
@abc.abstractmethod
def public_key(self) -> Ed25519PublicKey:
@@ -108,11 +114,16 @@ class Ed25519PrivateKey(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def sign(self, data: bytes) -> bytes:
def sign(self, data: Buffer) -> bytes:
"""
Signs the data.
"""
@abc.abstractmethod
def __copy__(self) -> Ed25519PrivateKey:
"""
Returns a copy.
"""
if hasattr(rust_openssl, "x25519"):
Ed25519PrivateKey.register(rust_openssl.ed25519.Ed25519PrivateKey)
Ed25519PrivateKey.register(rust_openssl.ed25519.Ed25519PrivateKey)

View File

@@ -9,6 +9,7 @@ import abc
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization
from cryptography.utils import Buffer
class Ed448PublicKey(metaclass=abc.ABCMeta):
@@ -22,7 +23,7 @@ class Ed448PublicKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
)
return backend.ed448_load_public_bytes(data)
return rust_openssl.ed448.from_public_bytes(data)
@abc.abstractmethod
def public_bytes(
@@ -42,7 +43,7 @@ class Ed448PublicKey(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def verify(self, signature: bytes, data: bytes) -> None:
def verify(self, signature: Buffer, data: Buffer) -> None:
"""
Verify the signature.
"""
@@ -53,6 +54,12 @@ class Ed448PublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> Ed448PublicKey:
"""
Returns a copy.
"""
if hasattr(rust_openssl, "ed448"):
Ed448PublicKey.register(rust_openssl.ed448.Ed448PublicKey)
@@ -68,10 +75,11 @@ class Ed448PrivateKey(metaclass=abc.ABCMeta):
"ed448 is not supported by this version of OpenSSL.",
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
)
return backend.ed448_generate_key()
return rust_openssl.ed448.generate_key()
@classmethod
def from_private_bytes(cls, data: bytes) -> Ed448PrivateKey:
def from_private_bytes(cls, data: Buffer) -> Ed448PrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend
if not backend.ed448_supported():
@@ -80,7 +88,7 @@ class Ed448PrivateKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
)
return backend.ed448_load_private_bytes(data)
return rust_openssl.ed448.from_private_bytes(data)
@abc.abstractmethod
def public_key(self) -> Ed448PublicKey:
@@ -89,7 +97,7 @@ class Ed448PrivateKey(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def sign(self, data: bytes) -> bytes:
def sign(self, data: Buffer) -> bytes:
"""
Signs the data.
"""
@@ -112,6 +120,12 @@ class Ed448PrivateKey(metaclass=abc.ABCMeta):
Equivalent to private_bytes(Raw, Raw, NoEncryption()).
"""
@abc.abstractmethod
def __copy__(self) -> Ed448PrivateKey:
"""
Returns a copy.
"""
if hasattr(rust_openssl, "x448"):
Ed448PrivateKey.register(rust_openssl.ed448.Ed448PrivateKey)

View File

@@ -5,7 +5,6 @@
from __future__ import annotations
import abc
import typing
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives._asymmetric import (
@@ -35,12 +34,12 @@ class PSS(AsymmetricPadding):
AUTO = _Auto()
DIGEST_LENGTH = _DigestLength()
name = "EMSA-PSS"
_salt_length: typing.Union[int, _MaxLength, _Auto, _DigestLength]
_salt_length: int | _MaxLength | _Auto | _DigestLength
def __init__(
self,
mgf: MGF,
salt_length: typing.Union[int, _MaxLength, _Auto, _DigestLength],
salt_length: int | _MaxLength | _Auto | _DigestLength,
) -> None:
self._mgf = mgf
@@ -57,6 +56,10 @@ class PSS(AsymmetricPadding):
self._salt_length = salt_length
@property
def mgf(self) -> MGF:
return self._mgf
class OAEP(AsymmetricPadding):
name = "EME-OAEP"
@@ -65,7 +68,7 @@ class OAEP(AsymmetricPadding):
self,
mgf: MGF,
algorithm: hashes.HashAlgorithm,
label: typing.Optional[bytes],
label: bytes | None,
):
if not isinstance(algorithm, hashes.HashAlgorithm):
raise TypeError("Expected instance of hashes.HashAlgorithm.")
@@ -74,14 +77,20 @@ class OAEP(AsymmetricPadding):
self._algorithm = algorithm
self._label = label
@property
def algorithm(self) -> hashes.HashAlgorithm:
return self._algorithm
@property
def mgf(self) -> MGF:
return self._mgf
class MGF(metaclass=abc.ABCMeta):
_algorithm: hashes.HashAlgorithm
class MGF1(MGF):
MAX_LENGTH = _MaxLength()
def __init__(self, algorithm: hashes.HashAlgorithm):
if not isinstance(algorithm, hashes.HashAlgorithm):
raise TypeError("Expected instance of hashes.HashAlgorithm.")
@@ -90,7 +99,7 @@ class MGF1(MGF):
def calculate_max_pss_salt_length(
key: typing.Union[rsa.RSAPrivateKey, rsa.RSAPublicKey],
key: rsa.RSAPrivateKey | rsa.RSAPublicKey,
hash_algorithm: hashes.HashAlgorithm,
) -> int:
if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):

View File

@@ -5,9 +5,11 @@
from __future__ import annotations
import abc
import random
import typing
from math import gcd
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization, hashes
from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
from cryptography.hazmat.primitives.asymmetric import utils as asym_utils
@@ -38,7 +40,7 @@ class RSAPrivateKey(metaclass=abc.ABCMeta):
self,
data: bytes,
padding: AsymmetricPadding,
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
) -> bytes:
"""
Signs the data.
@@ -61,8 +63,15 @@ class RSAPrivateKey(metaclass=abc.ABCMeta):
Returns the key serialized as bytes.
"""
@abc.abstractmethod
def __copy__(self) -> RSAPrivateKey:
"""
Returns a copy.
"""
RSAPrivateKeyWithSerialization = RSAPrivateKey
RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey)
class RSAPublicKey(metaclass=abc.ABCMeta):
@@ -101,7 +110,7 @@ class RSAPublicKey(metaclass=abc.ABCMeta):
signature: bytes,
data: bytes,
padding: AsymmetricPadding,
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
algorithm: asym_utils.Prehashed | hashes.HashAlgorithm,
) -> None:
"""
Verifies the signature of the data.
@@ -112,7 +121,7 @@ class RSAPublicKey(metaclass=abc.ABCMeta):
self,
signature: bytes,
padding: AsymmetricPadding,
algorithm: typing.Optional[hashes.HashAlgorithm],
algorithm: hashes.HashAlgorithm | None,
) -> bytes:
"""
Recovers the original data from the signature.
@@ -124,8 +133,18 @@ class RSAPublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> RSAPublicKey:
"""
Returns a copy.
"""
RSAPublicKeyWithSerialization = RSAPublicKey
RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey)
RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers
RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers
def generate_private_key(
@@ -133,10 +152,8 @@ def generate_private_key(
key_size: int,
backend: typing.Any = None,
) -> RSAPrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
_verify_rsa_parameters(public_exponent, key_size)
return ossl.generate_rsa_private_key(public_exponent, key_size)
return rust_openssl.rsa.generate_private_key(public_exponent, key_size)
def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
@@ -146,66 +163,8 @@ def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
"65537. Almost everyone should choose 65537 here!"
)
if key_size < 512:
raise ValueError("key_size must be at least 512-bits.")
def _check_private_key_components(
p: int,
q: int,
private_exponent: int,
dmp1: int,
dmq1: int,
iqmp: int,
public_exponent: int,
modulus: int,
) -> None:
if modulus < 3:
raise ValueError("modulus must be >= 3.")
if p >= modulus:
raise ValueError("p must be < modulus.")
if q >= modulus:
raise ValueError("q must be < modulus.")
if dmp1 >= modulus:
raise ValueError("dmp1 must be < modulus.")
if dmq1 >= modulus:
raise ValueError("dmq1 must be < modulus.")
if iqmp >= modulus:
raise ValueError("iqmp must be < modulus.")
if private_exponent >= modulus:
raise ValueError("private_exponent must be < modulus.")
if public_exponent < 3 or public_exponent >= modulus:
raise ValueError("public_exponent must be >= 3 and < modulus.")
if public_exponent & 1 == 0:
raise ValueError("public_exponent must be odd.")
if dmp1 & 1 == 0:
raise ValueError("dmp1 must be odd.")
if dmq1 & 1 == 0:
raise ValueError("dmq1 must be odd.")
if p * q != modulus:
raise ValueError("p*q must equal modulus.")
def _check_public_key_components(e: int, n: int) -> None:
if n < 3:
raise ValueError("n must be >= 3.")
if e < 3 or e >= n:
raise ValueError("e must be >= 3 and < n.")
if e & 1 == 0:
raise ValueError("e must be odd.")
if key_size < 1024:
raise ValueError("key_size must be at least 1024-bits.")
def _modinv(e: int, m: int) -> int:
@@ -225,6 +184,8 @@ def rsa_crt_iqmp(p: int, q: int) -> int:
"""
Compute the CRT (q ** -1) % p value from RSA primes p and q.
"""
if p <= 1 or q <= 1:
raise ValueError("Values can't be <= 1")
return _modinv(q, p)
@@ -233,6 +194,8 @@ def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
Compute the CRT private_exponent % (p - 1) value from the RSA
private_exponent (d) and p.
"""
if private_exponent <= 1 or p <= 1:
raise ValueError("Values can't be <= 1")
return private_exponent % (p - 1)
@@ -241,22 +204,49 @@ def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
Compute the CRT private_exponent % (q - 1) value from the RSA
private_exponent (d) and q.
"""
if private_exponent <= 1 or q <= 1:
raise ValueError("Values can't be <= 1")
return private_exponent % (q - 1)
def rsa_recover_private_exponent(e: int, p: int, q: int) -> int:
"""
Compute the RSA private_exponent (d) given the public exponent (e)
and the RSA primes p and q.
This uses the Carmichael totient function to generate the
smallest possible working value of the private exponent.
"""
# This lambda_n is the Carmichael totient function.
# The original RSA paper uses the Euler totient function
# here: phi_n = (p - 1) * (q - 1)
# Either version of the private exponent will work, but the
# one generated by the older formulation may be larger
# than necessary. (lambda_n always divides phi_n)
#
# TODO: Replace with lcm(p - 1, q - 1) once the minimum
# supported Python version is >= 3.9.
if e <= 1 or p <= 1 or q <= 1:
raise ValueError("Values can't be <= 1")
lambda_n = (p - 1) * (q - 1) // gcd(p - 1, q - 1)
return _modinv(e, lambda_n)
# Controls the number of iterations rsa_recover_prime_factors will perform
# to obtain the prime factors. Each iteration increments by 2 so the actual
# maximum attempts is half this number.
_MAX_RECOVERY_ATTEMPTS = 1000
# to obtain the prime factors.
_MAX_RECOVERY_ATTEMPTS = 500
def rsa_recover_prime_factors(
n: int, e: int, d: int
) -> typing.Tuple[int, int]:
def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]:
"""
Compute factors p and q from the private exponent d. We assume that n has
no more than two factors. This function is adapted from code in PyCrypto.
"""
# reject invalid values early
if d <= 1 or e <= 1:
raise ValueError("d, e can't be <= 1")
if 17 != pow(17, e * d, n):
raise ValueError("n, d, e don't match")
# See 8.2.2(i) in Handbook of Applied Cryptography.
ktot = d * e - 1
# The quantity d*e-1 is a multiple of phi(n), even,
@@ -270,8 +260,10 @@ def rsa_recover_prime_factors(
# See "Digitalized Signatures and Public Key Functions as Intractable
# as Factorization", M. Rabin, 1979
spotted = False
a = 2
while not spotted and a < _MAX_RECOVERY_ATTEMPTS:
tries = 0
while not spotted and tries < _MAX_RECOVERY_ATTEMPTS:
a = random.randint(2, n - 1)
tries += 1
k = t
# Cycle through all values a^{t*2^i}=a^k
while k < ktot:
@@ -284,8 +276,6 @@ def rsa_recover_prime_factors(
spotted = True
break
k *= 2
# This value was not any good... let's try another!
a += 2
if not spotted:
raise ValueError("Unable to compute factors p and q from exponent d.")
# Found !
@@ -293,147 +283,3 @@ def rsa_recover_prime_factors(
assert r == 0
p, q = sorted((p, q), reverse=True)
return (p, q)
class RSAPrivateNumbers:
def __init__(
self,
p: int,
q: int,
d: int,
dmp1: int,
dmq1: int,
iqmp: int,
public_numbers: RSAPublicNumbers,
):
if (
not isinstance(p, int)
or not isinstance(q, int)
or not isinstance(d, int)
or not isinstance(dmp1, int)
or not isinstance(dmq1, int)
or not isinstance(iqmp, int)
):
raise TypeError(
"RSAPrivateNumbers p, q, d, dmp1, dmq1, iqmp arguments must"
" all be an integers."
)
if not isinstance(public_numbers, RSAPublicNumbers):
raise TypeError(
"RSAPrivateNumbers public_numbers must be an RSAPublicNumbers"
" instance."
)
self._p = p
self._q = q
self._d = d
self._dmp1 = dmp1
self._dmq1 = dmq1
self._iqmp = iqmp
self._public_numbers = public_numbers
@property
def p(self) -> int:
return self._p
@property
def q(self) -> int:
return self._q
@property
def d(self) -> int:
return self._d
@property
def dmp1(self) -> int:
return self._dmp1
@property
def dmq1(self) -> int:
return self._dmq1
@property
def iqmp(self) -> int:
return self._iqmp
@property
def public_numbers(self) -> RSAPublicNumbers:
return self._public_numbers
def private_key(
self,
backend: typing.Any = None,
*,
unsafe_skip_rsa_key_validation: bool = False,
) -> RSAPrivateKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_rsa_private_numbers(
self, unsafe_skip_rsa_key_validation
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, RSAPrivateNumbers):
return NotImplemented
return (
self.p == other.p
and self.q == other.q
and self.d == other.d
and self.dmp1 == other.dmp1
and self.dmq1 == other.dmq1
and self.iqmp == other.iqmp
and self.public_numbers == other.public_numbers
)
def __hash__(self) -> int:
return hash(
(
self.p,
self.q,
self.d,
self.dmp1,
self.dmq1,
self.iqmp,
self.public_numbers,
)
)
class RSAPublicNumbers:
def __init__(self, e: int, n: int):
if not isinstance(e, int) or not isinstance(n, int):
raise TypeError("RSAPublicNumbers arguments must be integers.")
self._e = e
self._n = n
@property
def e(self) -> int:
return self._e
@property
def n(self) -> int:
return self._n
def public_key(self, backend: typing.Any = None) -> RSAPublicKey:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
return ossl.load_rsa_public_numbers(self)
def __repr__(self) -> str:
return "<RSAPublicNumbers(e={0.e}, n={0.n})>".format(self)
def __eq__(self, other: object) -> bool:
if not isinstance(other, RSAPublicNumbers):
return NotImplemented
return self.e == other.e and self.n == other.n
def __hash__(self) -> int:
return hash((self.e, self.n))

View File

@@ -9,6 +9,7 @@ import abc
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization
from cryptography.utils import Buffer
class X25519PublicKey(metaclass=abc.ABCMeta):
@@ -22,7 +23,7 @@ class X25519PublicKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM,
)
return backend.x25519_load_public_bytes(data)
return rust_openssl.x25519.from_public_bytes(data)
@abc.abstractmethod
def public_bytes(
@@ -47,10 +48,14 @@ class X25519PublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> X25519PublicKey:
"""
Returns a copy.
"""
# For LibreSSL
if hasattr(rust_openssl, "x25519"):
X25519PublicKey.register(rust_openssl.x25519.X25519PublicKey)
X25519PublicKey.register(rust_openssl.x25519.X25519PublicKey)
class X25519PrivateKey(metaclass=abc.ABCMeta):
@@ -63,10 +68,10 @@ class X25519PrivateKey(metaclass=abc.ABCMeta):
"X25519 is not supported by this version of OpenSSL.",
_Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM,
)
return backend.x25519_generate_key()
return rust_openssl.x25519.generate_key()
@classmethod
def from_private_bytes(cls, data: bytes) -> X25519PrivateKey:
def from_private_bytes(cls, data: Buffer) -> X25519PrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend
if not backend.x25519_supported():
@@ -75,12 +80,12 @@ class X25519PrivateKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM,
)
return backend.x25519_load_private_bytes(data)
return rust_openssl.x25519.from_private_bytes(data)
@abc.abstractmethod
def public_key(self) -> X25519PublicKey:
"""
Returns the public key assosciated with this private key
Returns the public key associated with this private key
"""
@abc.abstractmethod
@@ -107,7 +112,11 @@ class X25519PrivateKey(metaclass=abc.ABCMeta):
Performs a key exchange operation using the provided peer's public key.
"""
@abc.abstractmethod
def __copy__(self) -> X25519PrivateKey:
"""
Returns a copy.
"""
# For LibreSSL
if hasattr(rust_openssl, "x25519"):
X25519PrivateKey.register(rust_openssl.x25519.X25519PrivateKey)
X25519PrivateKey.register(rust_openssl.x25519.X25519PrivateKey)

View File

@@ -9,6 +9,7 @@ import abc
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import _serialization
from cryptography.utils import Buffer
class X448PublicKey(metaclass=abc.ABCMeta):
@@ -22,7 +23,7 @@ class X448PublicKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM,
)
return backend.x448_load_public_bytes(data)
return rust_openssl.x448.from_public_bytes(data)
@abc.abstractmethod
def public_bytes(
@@ -47,6 +48,12 @@ class X448PublicKey(metaclass=abc.ABCMeta):
Checks equality.
"""
@abc.abstractmethod
def __copy__(self) -> X448PublicKey:
"""
Returns a copy.
"""
if hasattr(rust_openssl, "x448"):
X448PublicKey.register(rust_openssl.x448.X448PublicKey)
@@ -62,10 +69,11 @@ class X448PrivateKey(metaclass=abc.ABCMeta):
"X448 is not supported by this version of OpenSSL.",
_Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM,
)
return backend.x448_generate_key()
return rust_openssl.x448.generate_key()
@classmethod
def from_private_bytes(cls, data: bytes) -> X448PrivateKey:
def from_private_bytes(cls, data: Buffer) -> X448PrivateKey:
from cryptography.hazmat.backends.openssl.backend import backend
if not backend.x448_supported():
@@ -74,7 +82,7 @@ class X448PrivateKey(metaclass=abc.ABCMeta):
_Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM,
)
return backend.x448_load_private_bytes(data)
return rust_openssl.x448.from_private_bytes(data)
@abc.abstractmethod
def public_key(self) -> X448PublicKey:
@@ -106,6 +114,12 @@ class X448PrivateKey(metaclass=abc.ABCMeta):
Performs a key exchange operation using the provided peer's public key.
"""
@abc.abstractmethod
def __copy__(self) -> X448PrivateKey:
"""
Returns a copy.
"""
if hasattr(rust_openssl, "x448"):
X448PrivateKey.register(rust_openssl.x448.X448PrivateKey)

View File

@@ -17,11 +17,11 @@ from cryptography.hazmat.primitives.ciphers.base import (
)
__all__ = [
"Cipher",
"CipherAlgorithm",
"BlockCipherAlgorithm",
"CipherContext",
"AEADCipherContext",
"AEADDecryptionContext",
"AEADEncryptionContext",
"BlockCipherAlgorithm",
"Cipher",
"CipherAlgorithm",
"CipherContext",
]

View File

@@ -4,375 +4,20 @@
from __future__ import annotations
import os
import typing
from cryptography import exceptions, utils
from cryptography.hazmat.backends.openssl import aead
from cryptography.hazmat.backends.openssl.backend import backend
from cryptography.hazmat.bindings._rust import FixedPool
class ChaCha20Poly1305:
_MAX_SIZE = 2**31 - 1
def __init__(self, key: bytes):
if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
"ChaCha20Poly1305 is not supported by this version of OpenSSL",
exceptions._Reasons.UNSUPPORTED_CIPHER,
)
utils._check_byteslike("key", key)
if len(key) != 32:
raise ValueError("ChaCha20Poly1305 key must be 32 bytes.")
self._key = key
self._pool = FixedPool(self._create_fn)
@classmethod
def generate_key(cls) -> bytes:
return os.urandom(32)
def _create_fn(self):
return aead._aead_create_ctx(backend, self, self._key)
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)
self._check_params(nonce, data, associated_data)
with self._pool.acquire() as ctx:
return aead._encrypt(
backend, self, nonce, data, [associated_data], 16, ctx
)
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
self._check_params(nonce, data, associated_data)
with self._pool.acquire() as ctx:
return aead._decrypt(
backend, self, nonce, data, [associated_data], 16, ctx
)
def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_byteslike("data", data)
utils._check_byteslike("associated_data", associated_data)
if len(nonce) != 12:
raise ValueError("Nonce must be 12 bytes")
class AESCCM:
_MAX_SIZE = 2**31 - 1
def __init__(self, key: bytes, tag_length: int = 16):
utils._check_byteslike("key", key)
if len(key) not in (16, 24, 32):
raise ValueError("AESCCM key must be 128, 192, or 256 bits.")
self._key = key
if not isinstance(tag_length, int):
raise TypeError("tag_length must be an integer")
if tag_length not in (4, 6, 8, 10, 12, 14, 16):
raise ValueError("Invalid tag_length")
self._tag_length = tag_length
if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
"AESCCM is not supported by this version of OpenSSL",
exceptions._Reasons.UNSUPPORTED_CIPHER,
)
@classmethod
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")
if bit_length not in (128, 192, 256):
raise ValueError("bit_length must be 128, 192, or 256")
return os.urandom(bit_length // 8)
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)
self._check_params(nonce, data, associated_data)
self._validate_lengths(nonce, len(data))
return aead._encrypt(
backend, self, nonce, data, [associated_data], self._tag_length
)
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
self._check_params(nonce, data, associated_data)
return aead._decrypt(
backend, self, nonce, data, [associated_data], self._tag_length
)
def _validate_lengths(self, nonce: bytes, data_len: int) -> None:
# For information about computing this, see
# https://tools.ietf.org/html/rfc3610#section-2.1
l_val = 15 - len(nonce)
if 2 ** (8 * l_val) < data_len:
raise ValueError("Data too long for nonce")
def _check_params(
self, nonce: bytes, data: bytes, associated_data: bytes
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_byteslike("data", data)
utils._check_byteslike("associated_data", associated_data)
if not 7 <= len(nonce) <= 13:
raise ValueError("Nonce must be between 7 and 13 bytes")
class AESGCM:
_MAX_SIZE = 2**31 - 1
def __init__(self, key: bytes):
utils._check_byteslike("key", key)
if len(key) not in (16, 24, 32):
raise ValueError("AESGCM key must be 128, 192, or 256 bits.")
self._key = key
@classmethod
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")
if bit_length not in (128, 192, 256):
raise ValueError("bit_length must be 128, 192, or 256")
return os.urandom(bit_length // 8)
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)
self._check_params(nonce, data, associated_data)
return aead._encrypt(backend, self, nonce, data, [associated_data], 16)
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
self._check_params(nonce, data, associated_data)
return aead._decrypt(backend, self, nonce, data, [associated_data], 16)
def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_byteslike("data", data)
utils._check_byteslike("associated_data", associated_data)
if len(nonce) < 8 or len(nonce) > 128:
raise ValueError("Nonce must be between 8 and 128 bytes")
class AESOCB3:
_MAX_SIZE = 2**31 - 1
def __init__(self, key: bytes):
utils._check_byteslike("key", key)
if len(key) not in (16, 24, 32):
raise ValueError("AESOCB3 key must be 128, 192, or 256 bits.")
self._key = key
if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
"OCB3 is not supported by this version of OpenSSL",
exceptions._Reasons.UNSUPPORTED_CIPHER,
)
@classmethod
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")
if bit_length not in (128, 192, 256):
raise ValueError("bit_length must be 128, 192, or 256")
return os.urandom(bit_length // 8)
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)
self._check_params(nonce, data, associated_data)
return aead._encrypt(backend, self, nonce, data, [associated_data], 16)
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if associated_data is None:
associated_data = b""
self._check_params(nonce, data, associated_data)
return aead._decrypt(backend, self, nonce, data, [associated_data], 16)
def _check_params(
self,
nonce: bytes,
data: bytes,
associated_data: bytes,
) -> None:
utils._check_byteslike("nonce", nonce)
utils._check_byteslike("data", data)
utils._check_byteslike("associated_data", associated_data)
if len(nonce) < 12 or len(nonce) > 15:
raise ValueError("Nonce must be between 12 and 15 bytes")
class AESSIV:
_MAX_SIZE = 2**31 - 1
def __init__(self, key: bytes):
utils._check_byteslike("key", key)
if len(key) not in (32, 48, 64):
raise ValueError("AESSIV key must be 256, 384, or 512 bits.")
self._key = key
if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
"AES-SIV is not supported by this version of OpenSSL",
exceptions._Reasons.UNSUPPORTED_CIPHER,
)
@classmethod
def generate_key(cls, bit_length: int) -> bytes:
if not isinstance(bit_length, int):
raise TypeError("bit_length must be an integer")
if bit_length not in (256, 384, 512):
raise ValueError("bit_length must be 256, 384, or 512")
return os.urandom(bit_length // 8)
def encrypt(
self,
data: bytes,
associated_data: typing.Optional[typing.List[bytes]],
) -> bytes:
if associated_data is None:
associated_data = []
self._check_params(data, associated_data)
if len(data) > self._MAX_SIZE or any(
len(ad) > self._MAX_SIZE for ad in associated_data
):
# This is OverflowError to match what cffi would raise
raise OverflowError(
"Data or associated data too long. Max 2**31 - 1 bytes"
)
return aead._encrypt(backend, self, b"", data, associated_data, 16)
def decrypt(
self,
data: bytes,
associated_data: typing.Optional[typing.List[bytes]],
) -> bytes:
if associated_data is None:
associated_data = []
self._check_params(data, associated_data)
return aead._decrypt(backend, self, b"", data, associated_data, 16)
def _check_params(
self,
data: bytes,
associated_data: typing.List[bytes],
) -> None:
utils._check_byteslike("data", data)
if len(data) == 0:
raise ValueError("data must not be zero length")
if not isinstance(associated_data, list):
raise TypeError(
"associated_data must be a list of bytes-like objects or None"
)
for x in associated_data:
utils._check_byteslike("associated_data elements", x)
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
__all__ = [
"AESCCM",
"AESGCM",
"AESGCMSIV",
"AESOCB3",
"AESSIV",
"ChaCha20Poly1305",
]
AESGCM = rust_openssl.aead.AESGCM
ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305
AESCCM = rust_openssl.aead.AESCCM
AESSIV = rust_openssl.aead.AESSIV
AESOCB3 = rust_openssl.aead.AESOCB3
AESGCMSIV = rust_openssl.aead.AESGCMSIV

View File

@@ -5,33 +5,38 @@
from __future__ import annotations
from cryptography import utils
from cryptography.hazmat.decrepit.ciphers.algorithms import (
ARC4 as ARC4,
)
from cryptography.hazmat.decrepit.ciphers.algorithms import (
CAST5 as CAST5,
)
from cryptography.hazmat.decrepit.ciphers.algorithms import (
IDEA as IDEA,
)
from cryptography.hazmat.decrepit.ciphers.algorithms import (
SEED as SEED,
)
from cryptography.hazmat.decrepit.ciphers.algorithms import (
Blowfish as Blowfish,
)
from cryptography.hazmat.decrepit.ciphers.algorithms import (
TripleDES as TripleDES,
)
from cryptography.hazmat.primitives._cipheralgorithm import _verify_key_size
from cryptography.hazmat.primitives.ciphers import (
BlockCipherAlgorithm,
CipherAlgorithm,
)
def _verify_key_size(algorithm: CipherAlgorithm, key: bytes) -> bytes:
# Verify that the key is instance of bytes
utils._check_byteslike("key", key)
# Verify that the key size matches the expected key size
if len(key) * 8 not in algorithm.key_sizes:
raise ValueError(
"Invalid key size ({}) for {}.".format(
len(key) * 8, algorithm.name
)
)
return key
class AES(BlockCipherAlgorithm):
name = "AES"
block_size = 128
# 512 added to support AES-256-XTS, which uses 512-bit keys
key_sizes = frozenset([128, 192, 256, 512])
def __init__(self, key: bytes):
def __init__(self, key: utils.Buffer):
self.key = _verify_key_size(self, key)
@property
@@ -45,7 +50,7 @@ class AES128(BlockCipherAlgorithm):
key_sizes = frozenset([128])
key_size = 128
def __init__(self, key: bytes):
def __init__(self, key: utils.Buffer):
self.key = _verify_key_size(self, key)
@@ -55,7 +60,7 @@ class AES256(BlockCipherAlgorithm):
key_sizes = frozenset([256])
key_size = 256
def __init__(self, key: bytes):
def __init__(self, key: utils.Buffer):
self.key = _verify_key_size(self, key)
@@ -64,7 +69,7 @@ class Camellia(BlockCipherAlgorithm):
block_size = 128
key_sizes = frozenset([128, 192, 256])
def __init__(self, key: bytes):
def __init__(self, key: utils.Buffer):
self.key = _verify_key_size(self, key)
@property
@@ -72,124 +77,27 @@ class Camellia(BlockCipherAlgorithm):
return len(self.key) * 8
class TripleDES(BlockCipherAlgorithm):
name = "3DES"
block_size = 64
key_sizes = frozenset([64, 128, 192])
def __init__(self, key: bytes):
if len(key) == 8:
key += key + key
elif len(key) == 16:
key += key[:8]
self.key = _verify_key_size(self, key)
@property
def key_size(self) -> int:
return len(self.key) * 8
class Blowfish(BlockCipherAlgorithm):
name = "Blowfish"
block_size = 64
key_sizes = frozenset(range(32, 449, 8))
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)
@property
def key_size(self) -> int:
return len(self.key) * 8
_BlowfishInternal = Blowfish
utils.deprecated(
Blowfish,
ARC4,
__name__,
"Blowfish has been deprecated",
utils.DeprecatedIn37,
name="Blowfish",
"ARC4 has been moved to "
"cryptography.hazmat.decrepit.ciphers.algorithms.ARC4 and "
"will be removed from "
"cryptography.hazmat.primitives.ciphers.algorithms in 48.0.0.",
utils.DeprecatedIn43,
name="ARC4",
)
class CAST5(BlockCipherAlgorithm):
name = "CAST5"
block_size = 64
key_sizes = frozenset(range(40, 129, 8))
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)
@property
def key_size(self) -> int:
return len(self.key) * 8
_CAST5Internal = CAST5
utils.deprecated(
CAST5,
TripleDES,
__name__,
"CAST5 has been deprecated",
utils.DeprecatedIn37,
name="CAST5",
)
class ARC4(CipherAlgorithm):
name = "RC4"
key_sizes = frozenset([40, 56, 64, 80, 128, 160, 192, 256])
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)
@property
def key_size(self) -> int:
return len(self.key) * 8
class IDEA(BlockCipherAlgorithm):
name = "IDEA"
block_size = 64
key_sizes = frozenset([128])
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)
@property
def key_size(self) -> int:
return len(self.key) * 8
_IDEAInternal = IDEA
utils.deprecated(
IDEA,
__name__,
"IDEA has been deprecated",
utils.DeprecatedIn37,
name="IDEA",
)
class SEED(BlockCipherAlgorithm):
name = "SEED"
block_size = 128
key_sizes = frozenset([128])
def __init__(self, key: bytes):
self.key = _verify_key_size(self, key)
@property
def key_size(self) -> int:
return len(self.key) * 8
_SEEDInternal = SEED
utils.deprecated(
SEED,
__name__,
"SEED has been deprecated",
utils.DeprecatedIn37,
name="SEED",
"TripleDES has been moved to "
"cryptography.hazmat.decrepit.ciphers.algorithms.TripleDES and "
"will be removed from "
"cryptography.hazmat.primitives.ciphers.algorithms in 48.0.0.",
utils.DeprecatedIn43,
name="TripleDES",
)
@@ -197,7 +105,7 @@ class ChaCha20(CipherAlgorithm):
name = "ChaCha20"
key_sizes = frozenset([256])
def __init__(self, key: bytes, nonce: bytes):
def __init__(self, key: utils.Buffer, nonce: utils.Buffer):
self.key = _verify_key_size(self, key)
utils._check_byteslike("nonce", nonce)
@@ -207,7 +115,7 @@ class ChaCha20(CipherAlgorithm):
self._nonce = nonce
@property
def nonce(self) -> bytes:
def nonce(self) -> utils.Buffer:
return self._nonce
@property

View File

@@ -7,30 +7,22 @@ from __future__ import annotations
import abc
import typing
from cryptography.exceptions import (
AlreadyFinalized,
AlreadyUpdated,
NotYetFinalized,
)
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives._cipheralgorithm import CipherAlgorithm
from cryptography.hazmat.primitives.ciphers import modes
if typing.TYPE_CHECKING:
from cryptography.hazmat.backends.openssl.ciphers import (
_CipherContext as _BackendCipherContext,
)
from cryptography.utils import Buffer
class CipherContext(metaclass=abc.ABCMeta):
@abc.abstractmethod
def update(self, data: bytes) -> bytes:
def update(self, data: Buffer) -> bytes:
"""
Processes the provided bytes through the cipher and returns the results
as bytes.
"""
@abc.abstractmethod
def update_into(self, data: bytes, buf: bytes) -> int:
def update_into(self, data: Buffer, buf: Buffer) -> int:
"""
Processes the provided bytes and writes the resulting data into the
provided buffer. Returns the number of bytes written.
@@ -42,10 +34,18 @@ class CipherContext(metaclass=abc.ABCMeta):
Returns the results of processing the final block as bytes.
"""
@abc.abstractmethod
def reset_nonce(self, nonce: bytes) -> None:
"""
Resets the nonce for the cipher context to the provided value.
Raises an exception if it does not support reset or if the
provided nonce does not have a valid length.
"""
class AEADCipherContext(CipherContext, metaclass=abc.ABCMeta):
@abc.abstractmethod
def authenticate_additional_data(self, data: bytes) -> None:
def authenticate_additional_data(self, data: Buffer) -> None:
"""
Authenticates the provided bytes.
"""
@@ -97,14 +97,12 @@ class Cipher(typing.Generic[Mode]):
@typing.overload
def encryptor(
self: Cipher[modes.ModeWithAuthenticationTag],
) -> AEADEncryptionContext:
...
) -> AEADEncryptionContext: ...
@typing.overload
def encryptor(
self: _CIPHER_TYPE,
) -> CipherContext:
...
) -> CipherContext: ...
def encryptor(self):
if isinstance(self.mode, modes.ModeWithAuthenticationTag):
@@ -112,158 +110,37 @@ class Cipher(typing.Generic[Mode]):
raise ValueError(
"Authentication tag must be None when encrypting."
)
from cryptography.hazmat.backends.openssl.backend import backend
ctx = backend.create_symmetric_encryption_ctx(
return rust_openssl.ciphers.create_encryption_ctx(
self.algorithm, self.mode
)
return self._wrap_ctx(ctx, encrypt=True)
@typing.overload
def decryptor(
self: Cipher[modes.ModeWithAuthenticationTag],
) -> AEADDecryptionContext:
...
) -> AEADDecryptionContext: ...
@typing.overload
def decryptor(
self: _CIPHER_TYPE,
) -> CipherContext:
...
) -> CipherContext: ...
def decryptor(self):
from cryptography.hazmat.backends.openssl.backend import backend
ctx = backend.create_symmetric_decryption_ctx(
return rust_openssl.ciphers.create_decryption_ctx(
self.algorithm, self.mode
)
return self._wrap_ctx(ctx, encrypt=False)
def _wrap_ctx(
self, ctx: _BackendCipherContext, encrypt: bool
) -> typing.Union[
AEADEncryptionContext, AEADDecryptionContext, CipherContext
]:
if isinstance(self.mode, modes.ModeWithAuthenticationTag):
if encrypt:
return _AEADEncryptionContext(ctx)
else:
return _AEADDecryptionContext(ctx)
else:
return _CipherContext(ctx)
_CIPHER_TYPE = Cipher[
typing.Union[
modes.ModeWithNonce,
modes.ModeWithTweak,
None,
modes.ECB,
modes.ModeWithInitializationVector,
None,
]
]
class _CipherContext(CipherContext):
_ctx: typing.Optional[_BackendCipherContext]
def __init__(self, ctx: _BackendCipherContext) -> None:
self._ctx = ctx
def update(self, data: bytes) -> bytes:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
return self._ctx.update(data)
def update_into(self, data: bytes, buf: bytes) -> int:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
return self._ctx.update_into(data, buf)
def finalize(self) -> bytes:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
data = self._ctx.finalize()
self._ctx = None
return data
class _AEADCipherContext(AEADCipherContext):
_ctx: typing.Optional[_BackendCipherContext]
_tag: typing.Optional[bytes]
def __init__(self, ctx: _BackendCipherContext) -> None:
self._ctx = ctx
self._bytes_processed = 0
self._aad_bytes_processed = 0
self._tag = None
self._updated = False
def _check_limit(self, data_size: int) -> None:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
self._updated = True
self._bytes_processed += data_size
if self._bytes_processed > self._ctx._mode._MAX_ENCRYPTED_BYTES:
raise ValueError(
"{} has a maximum encrypted byte limit of {}".format(
self._ctx._mode.name, self._ctx._mode._MAX_ENCRYPTED_BYTES
)
)
def update(self, data: bytes) -> bytes:
self._check_limit(len(data))
# mypy needs this assert even though _check_limit already checked
assert self._ctx is not None
return self._ctx.update(data)
def update_into(self, data: bytes, buf: bytes) -> int:
self._check_limit(len(data))
# mypy needs this assert even though _check_limit already checked
assert self._ctx is not None
return self._ctx.update_into(data, buf)
def finalize(self) -> bytes:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
data = self._ctx.finalize()
self._tag = self._ctx.tag
self._ctx = None
return data
def authenticate_additional_data(self, data: bytes) -> None:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
if self._updated:
raise AlreadyUpdated("Update has been called on this context.")
self._aad_bytes_processed += len(data)
if self._aad_bytes_processed > self._ctx._mode._MAX_AAD_BYTES:
raise ValueError(
"{} has a maximum AAD byte limit of {}".format(
self._ctx._mode.name, self._ctx._mode._MAX_AAD_BYTES
)
)
self._ctx.authenticate_additional_data(data)
class _AEADDecryptionContext(_AEADCipherContext, AEADDecryptionContext):
def finalize_with_tag(self, tag: bytes) -> bytes:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
data = self._ctx.finalize_with_tag(tag)
self._tag = self._ctx.tag
self._ctx = None
return data
class _AEADEncryptionContext(_AEADCipherContext, AEADEncryptionContext):
@property
def tag(self) -> bytes:
if self._ctx is not None:
raise NotYetFinalized(
"You must finalize encryption before " "getting the tag."
)
assert self._tag is not None
return self._tag
CipherContext.register(rust_openssl.ciphers.CipherContext)
AEADEncryptionContext.register(rust_openssl.ciphers.AEADEncryptionContext)
AEADDecryptionContext.register(rust_openssl.ciphers.AEADDecryptionContext)

View File

@@ -5,7 +5,6 @@
from __future__ import annotations
import abc
import typing
from cryptography import utils
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
@@ -35,7 +34,7 @@ class Mode(metaclass=abc.ABCMeta):
class ModeWithInitializationVector(Mode, metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def initialization_vector(self) -> bytes:
def initialization_vector(self) -> utils.Buffer:
"""
The value of the initialization vector for this mode as bytes.
"""
@@ -44,7 +43,7 @@ class ModeWithInitializationVector(Mode, metaclass=abc.ABCMeta):
class ModeWithTweak(Mode, metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def tweak(self) -> bytes:
def tweak(self) -> utils.Buffer:
"""
The value of the tweak for this mode as bytes.
"""
@@ -53,7 +52,7 @@ class ModeWithTweak(Mode, metaclass=abc.ABCMeta):
class ModeWithNonce(Mode, metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def nonce(self) -> bytes:
def nonce(self) -> utils.Buffer:
"""
The value of the nonce for this mode as bytes.
"""
@@ -62,7 +61,7 @@ class ModeWithNonce(Mode, metaclass=abc.ABCMeta):
class ModeWithAuthenticationTag(Mode, metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def tag(self) -> typing.Optional[bytes]:
def tag(self) -> bytes | None:
"""
The value of the tag supplied to the constructor of this mode.
"""
@@ -78,16 +77,13 @@ def _check_aes_key_length(self: Mode, algorithm: CipherAlgorithm) -> None:
def _check_iv_length(
self: ModeWithInitializationVector, algorithm: BlockCipherAlgorithm
) -> None:
if len(self.initialization_vector) * 8 != algorithm.block_size:
raise ValueError(
"Invalid IV size ({}) for {}.".format(
len(self.initialization_vector), self.name
)
)
iv_len = len(self.initialization_vector)
if iv_len * 8 != algorithm.block_size:
raise ValueError(f"Invalid IV size ({iv_len}) for {self.name}.")
def _check_nonce_length(
nonce: bytes, name: str, algorithm: CipherAlgorithm
nonce: utils.Buffer, name: str, algorithm: CipherAlgorithm
) -> None:
if not isinstance(algorithm, BlockCipherAlgorithm):
raise UnsupportedAlgorithm(
@@ -113,12 +109,12 @@ def _check_iv_and_key_length(
class CBC(ModeWithInitializationVector):
name = "CBC"
def __init__(self, initialization_vector: bytes):
def __init__(self, initialization_vector: utils.Buffer):
utils._check_byteslike("initialization_vector", initialization_vector)
self._initialization_vector = initialization_vector
@property
def initialization_vector(self) -> bytes:
def initialization_vector(self) -> utils.Buffer:
return self._initialization_vector
validate_for_algorithm = _check_iv_and_key_length
@@ -127,7 +123,7 @@ class CBC(ModeWithInitializationVector):
class XTS(ModeWithTweak):
name = "XTS"
def __init__(self, tweak: bytes):
def __init__(self, tweak: utils.Buffer):
utils._check_byteslike("tweak", tweak)
if len(tweak) != 16:
@@ -136,7 +132,7 @@ class XTS(ModeWithTweak):
self._tweak = tweak
@property
def tweak(self) -> bytes:
def tweak(self) -> utils.Buffer:
return self._tweak
def validate_for_algorithm(self, algorithm: CipherAlgorithm) -> None:
@@ -162,12 +158,12 @@ class ECB(Mode):
class OFB(ModeWithInitializationVector):
name = "OFB"
def __init__(self, initialization_vector: bytes):
def __init__(self, initialization_vector: utils.Buffer):
utils._check_byteslike("initialization_vector", initialization_vector)
self._initialization_vector = initialization_vector
@property
def initialization_vector(self) -> bytes:
def initialization_vector(self) -> utils.Buffer:
return self._initialization_vector
validate_for_algorithm = _check_iv_and_key_length
@@ -176,12 +172,12 @@ class OFB(ModeWithInitializationVector):
class CFB(ModeWithInitializationVector):
name = "CFB"
def __init__(self, initialization_vector: bytes):
def __init__(self, initialization_vector: utils.Buffer):
utils._check_byteslike("initialization_vector", initialization_vector)
self._initialization_vector = initialization_vector
@property
def initialization_vector(self) -> bytes:
def initialization_vector(self) -> utils.Buffer:
return self._initialization_vector
validate_for_algorithm = _check_iv_and_key_length
@@ -190,12 +186,12 @@ class CFB(ModeWithInitializationVector):
class CFB8(ModeWithInitializationVector):
name = "CFB8"
def __init__(self, initialization_vector: bytes):
def __init__(self, initialization_vector: utils.Buffer):
utils._check_byteslike("initialization_vector", initialization_vector)
self._initialization_vector = initialization_vector
@property
def initialization_vector(self) -> bytes:
def initialization_vector(self) -> utils.Buffer:
return self._initialization_vector
validate_for_algorithm = _check_iv_and_key_length
@@ -204,12 +200,12 @@ class CFB8(ModeWithInitializationVector):
class CTR(ModeWithNonce):
name = "CTR"
def __init__(self, nonce: bytes):
def __init__(self, nonce: utils.Buffer):
utils._check_byteslike("nonce", nonce)
self._nonce = nonce
@property
def nonce(self) -> bytes:
def nonce(self) -> utils.Buffer:
return self._nonce
def validate_for_algorithm(self, algorithm: CipherAlgorithm) -> None:
@@ -224,8 +220,8 @@ class GCM(ModeWithInitializationVector, ModeWithAuthenticationTag):
def __init__(
self,
initialization_vector: bytes,
tag: typing.Optional[bytes] = None,
initialization_vector: utils.Buffer,
tag: bytes | None = None,
min_tag_length: int = 16,
):
# OpenSSL 3.0.0 constrains GCM IVs to [64, 1024] bits inclusive
@@ -243,19 +239,18 @@ class GCM(ModeWithInitializationVector, ModeWithAuthenticationTag):
raise ValueError("min_tag_length must be >= 4")
if len(tag) < min_tag_length:
raise ValueError(
"Authentication tag must be {} bytes or longer.".format(
min_tag_length
)
f"Authentication tag must be {min_tag_length} bytes or "
"longer."
)
self._tag = tag
self._min_tag_length = min_tag_length
@property
def tag(self) -> typing.Optional[bytes]:
def tag(self) -> bytes | None:
return self._tag
@property
def initialization_vector(self) -> bytes:
def initialization_vector(self) -> utils.Buffer:
return self._initialization_vector
def validate_for_algorithm(self, algorithm: CipherAlgorithm) -> None:
@@ -268,7 +263,6 @@ class GCM(ModeWithInitializationVector, ModeWithAuthenticationTag):
block_size_bytes = algorithm.block_size // 8
if self._tag is not None and len(self._tag) > block_size_bytes:
raise ValueError(
"Authentication tag cannot be more than {} bytes.".format(
block_size_bytes
)
f"Authentication tag cannot be more than {block_size_bytes} "
"bytes."
)

View File

@@ -4,62 +4,7 @@
from __future__ import annotations
import typing
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography import utils
from cryptography.exceptions import AlreadyFinalized
from cryptography.hazmat.primitives import ciphers
if typing.TYPE_CHECKING:
from cryptography.hazmat.backends.openssl.cmac import _CMACContext
class CMAC:
_ctx: typing.Optional[_CMACContext]
_algorithm: ciphers.BlockCipherAlgorithm
def __init__(
self,
algorithm: ciphers.BlockCipherAlgorithm,
backend: typing.Any = None,
ctx: typing.Optional[_CMACContext] = None,
) -> None:
if not isinstance(algorithm, ciphers.BlockCipherAlgorithm):
raise TypeError("Expected instance of BlockCipherAlgorithm.")
self._algorithm = algorithm
if ctx is None:
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
self._ctx = ossl.create_cmac_ctx(self._algorithm)
else:
self._ctx = ctx
def update(self, data: bytes) -> None:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
utils._check_bytes("data", data)
self._ctx.update(data)
def finalize(self) -> bytes:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
digest = self._ctx.finalize()
self._ctx = None
return digest
def verify(self, signature: bytes) -> None:
utils._check_bytes("signature", signature)
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
ctx, self._ctx = self._ctx, None
ctx.verify(signature)
def copy(self) -> CMAC:
if self._ctx is None:
raise AlreadyFinalized("Context was already finalized.")
return CMAC(self._algorithm, ctx=self._ctx.copy())
__all__ = ["CMAC"]
CMAC = rust_openssl.cmac.CMAC

View File

@@ -5,32 +5,33 @@
from __future__ import annotations
import abc
import typing
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.utils import Buffer
__all__ = [
"HashAlgorithm",
"HashContext",
"Hash",
"ExtendableOutputFunction",
"MD5",
"SHA1",
"SHA512_224",
"SHA512_256",
"SHA224",
"SHA256",
"SHA384",
"SHA512",
"SHA3_224",
"SHA3_256",
"SHA3_384",
"SHA3_512",
"SHA224",
"SHA256",
"SHA384",
"SHA512",
"SHA512_224",
"SHA512_256",
"SHAKE128",
"SHAKE256",
"MD5",
"SM3",
"BLAKE2b",
"BLAKE2s",
"SM3",
"ExtendableOutputFunction",
"Hash",
"HashAlgorithm",
"HashContext",
"XOFHash",
]
@@ -51,7 +52,7 @@ class HashAlgorithm(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def block_size(self) -> typing.Optional[int]:
def block_size(self) -> int | None:
"""
The internal block size of the hash function, or None if the hash
function does not use blocks internally (e.g. SHA3).
@@ -67,7 +68,7 @@ class HashContext(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
def update(self, data: bytes) -> None:
def update(self, data: Buffer) -> None:
"""
Processes the provided bytes through the hash.
"""
@@ -88,6 +89,8 @@ class HashContext(metaclass=abc.ABCMeta):
Hash = rust_openssl.hashes.Hash
HashContext.register(Hash)
XOFHash = rust_openssl.hashes.XOFHash
class ExtendableOutputFunction(metaclass=abc.ABCMeta):
"""

View File

@@ -0,0 +1,13 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import annotations
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
Argon2id = rust_openssl.kdf.Argon2id
KeyDerivationFunction.register(Argon2id)
__all__ = ["Argon2id"]

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import typing
from collections.abc import Callable
from cryptography import utils
from cryptography.exceptions import AlreadyFinalized, InvalidKey
@@ -19,7 +20,7 @@ def _int_to_u32be(n: int) -> bytes:
def _common_args_checks(
algorithm: hashes.HashAlgorithm,
length: int,
otherinfo: typing.Optional[bytes],
otherinfo: bytes | None,
) -> None:
max_length = algorithm.digest_size * (2**32 - 1)
if length > max_length:
@@ -29,9 +30,9 @@ def _common_args_checks(
def _concatkdf_derive(
key_material: bytes,
key_material: utils.Buffer,
length: int,
auxfn: typing.Callable[[], hashes.HashContext],
auxfn: Callable[[], hashes.HashContext],
otherinfo: bytes,
) -> bytes:
utils._check_byteslike("key_material", key_material)
@@ -56,7 +57,7 @@ class ConcatKDFHash(KeyDerivationFunction):
self,
algorithm: hashes.HashAlgorithm,
length: int,
otherinfo: typing.Optional[bytes],
otherinfo: bytes | None,
backend: typing.Any = None,
):
_common_args_checks(algorithm, length, otherinfo)
@@ -69,7 +70,7 @@ class ConcatKDFHash(KeyDerivationFunction):
def _hash(self) -> hashes.Hash:
return hashes.Hash(self._algorithm)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True
@@ -87,8 +88,8 @@ class ConcatKDFHMAC(KeyDerivationFunction):
self,
algorithm: hashes.HashAlgorithm,
length: int,
salt: typing.Optional[bytes],
otherinfo: typing.Optional[bytes],
salt: bytes | None,
otherinfo: bytes | None,
backend: typing.Any = None,
):
_common_args_checks(algorithm, length, otherinfo)
@@ -111,7 +112,7 @@ class ConcatKDFHMAC(KeyDerivationFunction):
def _hmac(self) -> hmac.HMAC:
return hmac.HMAC(self._salt, self._algorithm)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True

View File

@@ -4,98 +4,13 @@
from __future__ import annotations
import typing
from cryptography import utils
from cryptography.exceptions import AlreadyFinalized, InvalidKey
from cryptography.hazmat.primitives import constant_time, hashes, hmac
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
HKDF = rust_openssl.kdf.HKDF
HKDFExpand = rust_openssl.kdf.HKDFExpand
class HKDF(KeyDerivationFunction):
def __init__(
self,
algorithm: hashes.HashAlgorithm,
length: int,
salt: typing.Optional[bytes],
info: typing.Optional[bytes],
backend: typing.Any = None,
):
self._algorithm = algorithm
KeyDerivationFunction.register(HKDF)
KeyDerivationFunction.register(HKDFExpand)
if salt is None:
salt = b"\x00" * self._algorithm.digest_size
else:
utils._check_bytes("salt", salt)
self._salt = salt
self._hkdf_expand = HKDFExpand(self._algorithm, length, info)
def _extract(self, key_material: bytes) -> bytes:
h = hmac.HMAC(self._salt, self._algorithm)
h.update(key_material)
return h.finalize()
def derive(self, key_material: bytes) -> bytes:
utils._check_byteslike("key_material", key_material)
return self._hkdf_expand.derive(self._extract(key_material))
def verify(self, key_material: bytes, expected_key: bytes) -> None:
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
raise InvalidKey
class HKDFExpand(KeyDerivationFunction):
def __init__(
self,
algorithm: hashes.HashAlgorithm,
length: int,
info: typing.Optional[bytes],
backend: typing.Any = None,
):
self._algorithm = algorithm
max_length = 255 * algorithm.digest_size
if length > max_length:
raise ValueError(
f"Cannot derive keys larger than {max_length} octets."
)
self._length = length
if info is None:
info = b""
else:
utils._check_bytes("info", info)
self._info = info
self._used = False
def _expand(self, key_material: bytes) -> bytes:
output = [b""]
counter = 1
while self._algorithm.digest_size * (len(output) - 1) < self._length:
h = hmac.HMAC(key_material, self._algorithm)
h.update(output[-1])
h.update(self._info)
h.update(bytes([counter]))
output.append(h.finalize())
counter += 1
return b"".join(output)[: self._length]
def derive(self, key_material: bytes) -> bytes:
utils._check_byteslike("key_material", key_material)
if self._used:
raise AlreadyFinalized
self._used = True
return self._expand(key_material)
def verify(self, key_material: bytes, expected_key: bytes) -> None:
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
raise InvalidKey
__all__ = ["HKDF", "HKDFExpand"]

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import typing
from collections.abc import Callable
from cryptography import utils
from cryptography.exceptions import (
@@ -36,16 +37,16 @@ class CounterLocation(utils.Enum):
class _KBKDFDeriver:
def __init__(
self,
prf: typing.Callable,
prf: Callable,
mode: Mode,
length: int,
rlen: int,
llen: typing.Optional[int],
llen: int | None,
location: CounterLocation,
break_location: typing.Optional[int],
label: typing.Optional[bytes],
context: typing.Optional[bytes],
fixed: typing.Optional[bytes],
break_location: int | None,
label: bytes | None,
context: bytes | None,
fixed: bytes | None,
):
assert callable(prf)
@@ -75,7 +76,7 @@ class _KBKDFDeriver:
if (label or context) and fixed:
raise ValueError(
"When supplying fixed data, " "label and context are ignored."
"When supplying fixed data, label and context are ignored."
)
if rlen is None or not self._valid_byte_length(rlen):
@@ -87,6 +88,9 @@ class _KBKDFDeriver:
if llen is not None and not isinstance(llen, int):
raise TypeError("llen must be an integer")
if llen == 0:
raise ValueError("llen must be non-zero")
if label is None:
label = b""
@@ -113,11 +117,11 @@ class _KBKDFDeriver:
raise TypeError("value must be of type int")
value_bin = utils.int_to_bytes(1, value)
if not 1 <= len(value_bin) <= 4:
return False
return True
return 1 <= len(value_bin) <= 4
def derive(self, key_material: bytes, prf_output_size: int) -> bytes:
def derive(
self, key_material: utils.Buffer, prf_output_size: int
) -> bytes:
if self._used:
raise AlreadyFinalized
@@ -181,14 +185,14 @@ class KBKDFHMAC(KeyDerivationFunction):
mode: Mode,
length: int,
rlen: int,
llen: typing.Optional[int],
llen: int | None,
location: CounterLocation,
label: typing.Optional[bytes],
context: typing.Optional[bytes],
fixed: typing.Optional[bytes],
label: bytes | None,
context: bytes | None,
fixed: bytes | None,
backend: typing.Any = None,
*,
break_location: typing.Optional[int] = None,
break_location: int | None = None,
):
if not isinstance(algorithm, hashes.HashAlgorithm):
raise UnsupportedAlgorithm(
@@ -224,7 +228,7 @@ class KBKDFHMAC(KeyDerivationFunction):
def _prf(self, key_material: bytes) -> hmac.HMAC:
return hmac.HMAC(key_material, self._algorithm)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
return self._deriver.derive(key_material, self._algorithm.digest_size)
def verify(self, key_material: bytes, expected_key: bytes) -> None:
@@ -239,14 +243,14 @@ class KBKDFCMAC(KeyDerivationFunction):
mode: Mode,
length: int,
rlen: int,
llen: typing.Optional[int],
llen: int | None,
location: CounterLocation,
label: typing.Optional[bytes],
context: typing.Optional[bytes],
fixed: typing.Optional[bytes],
label: bytes | None,
context: bytes | None,
fixed: bytes | None,
backend: typing.Any = None,
*,
break_location: typing.Optional[int] = None,
break_location: int | None = None,
):
if not issubclass(
algorithm, ciphers.BlockCipherAlgorithm
@@ -257,7 +261,7 @@ class KBKDFCMAC(KeyDerivationFunction):
)
self._algorithm = algorithm
self._cipher: typing.Optional[ciphers.BlockCipherAlgorithm] = None
self._cipher: ciphers.BlockCipherAlgorithm | None = None
self._deriver = _KBKDFDeriver(
self._prf,
@@ -277,7 +281,7 @@ class KBKDFCMAC(KeyDerivationFunction):
return cmac.CMAC(self._cipher)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
self._cipher = self._algorithm(key_material)
assert self._cipher is not None

View File

@@ -33,9 +33,7 @@ class PBKDF2HMAC(KeyDerivationFunction):
if not ossl.pbkdf2_hmac_supported(algorithm):
raise UnsupportedAlgorithm(
"{} is not supported for PBKDF2 by this backend.".format(
algorithm.name
),
f"{algorithm.name} is not supported for PBKDF2.",
_Reasons.UNSUPPORTED_HASH,
)
self._used = False
@@ -45,7 +43,7 @@ class PBKDF2HMAC(KeyDerivationFunction):
self._salt = salt
self._iterations = iterations
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized("PBKDF2 instances can only be used once.")
self._used = True

View File

@@ -5,76 +5,15 @@
from __future__ import annotations
import sys
import typing
from cryptography import utils
from cryptography.exceptions import (
AlreadyFinalized,
InvalidKey,
UnsupportedAlgorithm,
)
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import constant_time
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
# This is used by the scrypt tests to skip tests that require more memory
# than the MEM_LIMIT
_MEM_LIMIT = sys.maxsize // 2
Scrypt = rust_openssl.kdf.Scrypt
KeyDerivationFunction.register(Scrypt)
class Scrypt(KeyDerivationFunction):
def __init__(
self,
salt: bytes,
length: int,
n: int,
r: int,
p: int,
backend: typing.Any = None,
):
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
if not ossl.scrypt_supported():
raise UnsupportedAlgorithm(
"This version of OpenSSL does not support scrypt"
)
self._length = length
utils._check_bytes("salt", salt)
if n < 2 or (n & (n - 1)) != 0:
raise ValueError("n must be greater than 1 and be a power of 2.")
if r < 1:
raise ValueError("r must be greater than or equal to 1.")
if p < 1:
raise ValueError("p must be greater than or equal to 1.")
self._used = False
self._salt = salt
self._n = n
self._r = r
self._p = p
def derive(self, key_material: bytes) -> bytes:
if self._used:
raise AlreadyFinalized("Scrypt instances can only be used once.")
self._used = True
utils._check_byteslike("key_material", key_material)
return rust_openssl.kdf.derive_scrypt(
key_material,
self._salt,
self._n,
self._r,
self._p,
_MEM_LIMIT,
self._length,
)
def verify(self, key_material: bytes, expected_key: bytes) -> None:
derived_key = self.derive(key_material)
if not constant_time.bytes_eq(derived_key, expected_key):
raise InvalidKey("Keys do not match.")
__all__ = ["Scrypt"]

View File

@@ -21,7 +21,7 @@ class X963KDF(KeyDerivationFunction):
self,
algorithm: hashes.HashAlgorithm,
length: int,
sharedinfo: typing.Optional[bytes],
sharedinfo: bytes | None,
backend: typing.Any = None,
):
max_len = algorithm.digest_size * (2**32 - 1)
@@ -35,7 +35,7 @@ class X963KDF(KeyDerivationFunction):
self._sharedinfo = sharedinfo
self._used = False
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True

View File

@@ -15,7 +15,7 @@ from cryptography.hazmat.primitives.constant_time import bytes_eq
def _wrap_core(
wrapping_key: bytes,
a: bytes,
r: typing.List[bytes],
r: list[bytes],
) -> bytes:
# RFC 3394 Key Wrap - 2.2.1 (index method)
encryptor = Cipher(AES(wrapping_key), ECB()).encryptor()
@@ -58,8 +58,8 @@ def aes_key_wrap(
def _unwrap_core(
wrapping_key: bytes,
a: bytes,
r: typing.List[bytes],
) -> typing.Tuple[bytes, typing.List[bytes]]:
r: list[bytes],
) -> tuple[bytes, list[bytes]]:
# Implement RFC 3394 Key Unwrap - 2.2.2 (index method)
decryptor = Cipher(AES(wrapping_key), ECB()).decryptor()
n = len(r)
@@ -86,7 +86,7 @@ def aes_key_wrap_with_padding(
if len(wrapping_key) not in [16, 24, 32]:
raise ValueError("The wrapping key must be a valid AES key length")
aiv = b"\xA6\x59\x59\xA6" + len(key_to_wrap).to_bytes(
aiv = b"\xa6\x59\x59\xa6" + len(key_to_wrap).to_bytes(
length=4, byteorder="big"
)
# pad the key to wrap if necessary

View File

@@ -5,19 +5,19 @@
from __future__ import annotations
import abc
import typing
from cryptography import utils
from cryptography.exceptions import AlreadyFinalized
from cryptography.hazmat.bindings._rust import (
check_ansix923_padding,
check_pkcs7_padding,
ANSIX923PaddingContext,
ANSIX923UnpaddingContext,
PKCS7PaddingContext,
PKCS7UnpaddingContext,
)
class PaddingContext(metaclass=abc.ABCMeta):
@abc.abstractmethod
def update(self, data: bytes) -> bytes:
def update(self, data: utils.Buffer) -> bytes:
"""
Pads the provided bytes and returns any available data as bytes.
"""
@@ -37,131 +37,20 @@ def _byte_padding_check(block_size: int) -> None:
raise ValueError("block_size must be a multiple of 8.")
def _byte_padding_update(
buffer_: typing.Optional[bytes], data: bytes, block_size: int
) -> typing.Tuple[bytes, bytes]:
if buffer_ is None:
raise AlreadyFinalized("Context was already finalized.")
utils._check_byteslike("data", data)
buffer_ += bytes(data)
finished_blocks = len(buffer_) // (block_size // 8)
result = buffer_[: finished_blocks * (block_size // 8)]
buffer_ = buffer_[finished_blocks * (block_size // 8) :]
return buffer_, result
def _byte_padding_pad(
buffer_: typing.Optional[bytes],
block_size: int,
paddingfn: typing.Callable[[int], bytes],
) -> bytes:
if buffer_ is None:
raise AlreadyFinalized("Context was already finalized.")
pad_size = block_size // 8 - len(buffer_)
return buffer_ + paddingfn(pad_size)
def _byte_unpadding_update(
buffer_: typing.Optional[bytes], data: bytes, block_size: int
) -> typing.Tuple[bytes, bytes]:
if buffer_ is None:
raise AlreadyFinalized("Context was already finalized.")
utils._check_byteslike("data", data)
buffer_ += bytes(data)
finished_blocks = max(len(buffer_) // (block_size // 8) - 1, 0)
result = buffer_[: finished_blocks * (block_size // 8)]
buffer_ = buffer_[finished_blocks * (block_size // 8) :]
return buffer_, result
def _byte_unpadding_check(
buffer_: typing.Optional[bytes],
block_size: int,
checkfn: typing.Callable[[bytes], int],
) -> bytes:
if buffer_ is None:
raise AlreadyFinalized("Context was already finalized.")
if len(buffer_) != block_size // 8:
raise ValueError("Invalid padding bytes.")
valid = checkfn(buffer_)
if not valid:
raise ValueError("Invalid padding bytes.")
pad_size = buffer_[-1]
return buffer_[:-pad_size]
class PKCS7:
def __init__(self, block_size: int):
_byte_padding_check(block_size)
self.block_size = block_size
def padder(self) -> PaddingContext:
return _PKCS7PaddingContext(self.block_size)
return PKCS7PaddingContext(self.block_size)
def unpadder(self) -> PaddingContext:
return _PKCS7UnpaddingContext(self.block_size)
return PKCS7UnpaddingContext(self.block_size)
class _PKCS7PaddingContext(PaddingContext):
_buffer: typing.Optional[bytes]
def __init__(self, block_size: int):
self.block_size = block_size
# TODO: more copies than necessary, we should use zero-buffer (#193)
self._buffer = b""
def update(self, data: bytes) -> bytes:
self._buffer, result = _byte_padding_update(
self._buffer, data, self.block_size
)
return result
def _padding(self, size: int) -> bytes:
return bytes([size]) * size
def finalize(self) -> bytes:
result = _byte_padding_pad(
self._buffer, self.block_size, self._padding
)
self._buffer = None
return result
class _PKCS7UnpaddingContext(PaddingContext):
_buffer: typing.Optional[bytes]
def __init__(self, block_size: int):
self.block_size = block_size
# TODO: more copies than necessary, we should use zero-buffer (#193)
self._buffer = b""
def update(self, data: bytes) -> bytes:
self._buffer, result = _byte_unpadding_update(
self._buffer, data, self.block_size
)
return result
def finalize(self) -> bytes:
result = _byte_unpadding_check(
self._buffer, self.block_size, check_pkcs7_padding
)
self._buffer = None
return result
PaddingContext.register(PKCS7PaddingContext)
PaddingContext.register(PKCS7UnpaddingContext)
class ANSIX923:
@@ -170,56 +59,11 @@ class ANSIX923:
self.block_size = block_size
def padder(self) -> PaddingContext:
return _ANSIX923PaddingContext(self.block_size)
return ANSIX923PaddingContext(self.block_size)
def unpadder(self) -> PaddingContext:
return _ANSIX923UnpaddingContext(self.block_size)
return ANSIX923UnpaddingContext(self.block_size)
class _ANSIX923PaddingContext(PaddingContext):
_buffer: typing.Optional[bytes]
def __init__(self, block_size: int):
self.block_size = block_size
# TODO: more copies than necessary, we should use zero-buffer (#193)
self._buffer = b""
def update(self, data: bytes) -> bytes:
self._buffer, result = _byte_padding_update(
self._buffer, data, self.block_size
)
return result
def _padding(self, size: int) -> bytes:
return bytes([0]) * (size - 1) + bytes([size])
def finalize(self) -> bytes:
result = _byte_padding_pad(
self._buffer, self.block_size, self._padding
)
self._buffer = None
return result
class _ANSIX923UnpaddingContext(PaddingContext):
_buffer: typing.Optional[bytes]
def __init__(self, block_size: int):
self.block_size = block_size
# TODO: more copies than necessary, we should use zero-buffer (#193)
self._buffer = b""
def update(self, data: bytes) -> bytes:
self._buffer, result = _byte_unpadding_update(
self._buffer, data, self.block_size
)
return result
def finalize(self) -> bytes:
result = _byte_unpadding_check(
self._buffer,
self.block_size,
check_ansix923_padding,
)
self._buffer = None
return result
PaddingContext.register(ANSIX923PaddingContext)
PaddingContext.register(ANSIX923UnpaddingContext)

View File

@@ -33,9 +33,25 @@ from cryptography.hazmat.primitives.serialization.ssh import (
load_ssh_private_key,
load_ssh_public_identity,
load_ssh_public_key,
ssh_key_fingerprint,
)
__all__ = [
"BestAvailableEncryption",
"Encoding",
"KeySerializationEncryption",
"NoEncryption",
"ParameterFormat",
"PrivateFormat",
"PublicFormat",
"SSHCertPrivateKeyTypes",
"SSHCertPublicKeyTypes",
"SSHCertificate",
"SSHCertificateBuilder",
"SSHCertificateType",
"SSHPrivateKeyTypes",
"SSHPublicKeyTypes",
"_KeySerializationEncryption",
"load_der_parameters",
"load_der_private_key",
"load_der_public_key",
@@ -45,19 +61,5 @@ __all__ = [
"load_ssh_private_key",
"load_ssh_public_identity",
"load_ssh_public_key",
"Encoding",
"PrivateFormat",
"PublicFormat",
"ParameterFormat",
"KeySerializationEncryption",
"BestAvailableEncryption",
"NoEncryption",
"_KeySerializationEncryption",
"SSHCertificateBuilder",
"SSHCertificate",
"SSHCertificateType",
"SSHCertPublicKeyTypes",
"SSHCertPrivateKeyTypes",
"SSHPrivateKeyTypes",
"SSHPublicKeyTypes",
"ssh_key_fingerprint",
]

View File

@@ -2,72 +2,13 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import annotations
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
import typing
load_pem_private_key = rust_openssl.keys.load_pem_private_key
load_der_private_key = rust_openssl.keys.load_der_private_key
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
PublicKeyTypes,
)
load_pem_public_key = rust_openssl.keys.load_pem_public_key
load_der_public_key = rust_openssl.keys.load_der_public_key
def load_pem_private_key(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
*,
unsafe_skip_rsa_key_validation: bool = False,
) -> PrivateKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pem_private_key(
data, password, unsafe_skip_rsa_key_validation
)
def load_pem_public_key(
data: bytes, backend: typing.Any = None
) -> PublicKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pem_public_key(data)
def load_pem_parameters(
data: bytes, backend: typing.Any = None
) -> dh.DHParameters:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pem_parameters(data)
def load_der_private_key(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
*,
unsafe_skip_rsa_key_validation: bool = False,
) -> PrivateKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_der_private_key(
data, password, unsafe_skip_rsa_key_validation
)
def load_der_public_key(
data: bytes, backend: typing.Any = None
) -> PublicKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_der_public_key(data)
def load_der_parameters(
data: bytes, backend: typing.Any = None
) -> dh.DHParameters:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_der_parameters(data)
load_pem_parameters = rust_openssl.dh.from_pem_parameters
load_der_parameters = rust_openssl.dh.from_der_parameters

View File

@@ -5,8 +5,10 @@
from __future__ import annotations
import typing
from collections.abc import Iterable
from cryptography import x509
from cryptography.hazmat.bindings._rust import pkcs12 as rust_pkcs12
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives._serialization import PBES as PBES
from cryptography.hazmat.primitives.asymmetric import (
@@ -20,11 +22,12 @@ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
__all__ = [
"PBES",
"PKCS12PrivateKeyTypes",
"PKCS12Certificate",
"PKCS12KeyAndCertificates",
"PKCS12PrivateKeyTypes",
"load_key_and_certificates",
"load_pkcs12",
"serialize_java_truststore",
"serialize_key_and_certificates",
]
@@ -37,51 +40,15 @@ PKCS12PrivateKeyTypes = typing.Union[
]
class PKCS12Certificate:
def __init__(
self,
cert: x509.Certificate,
friendly_name: typing.Optional[bytes],
):
if not isinstance(cert, x509.Certificate):
raise TypeError("Expecting x509.Certificate object")
if friendly_name is not None and not isinstance(friendly_name, bytes):
raise TypeError("friendly_name must be bytes or None")
self._cert = cert
self._friendly_name = friendly_name
@property
def friendly_name(self) -> typing.Optional[bytes]:
return self._friendly_name
@property
def certificate(self) -> x509.Certificate:
return self._cert
def __eq__(self, other: object) -> bool:
if not isinstance(other, PKCS12Certificate):
return NotImplemented
return (
self.certificate == other.certificate
and self.friendly_name == other.friendly_name
)
def __hash__(self) -> int:
return hash((self.certificate, self.friendly_name))
def __repr__(self) -> str:
return "<PKCS12Certificate({}, friendly_name={!r})>".format(
self.certificate, self.friendly_name
)
PKCS12Certificate = rust_pkcs12.PKCS12Certificate
class PKCS12KeyAndCertificates:
def __init__(
self,
key: typing.Optional[PrivateKeyTypes],
cert: typing.Optional[PKCS12Certificate],
additional_certs: typing.List[PKCS12Certificate],
key: PrivateKeyTypes | None,
cert: PKCS12Certificate | None,
additional_certs: list[PKCS12Certificate],
):
if key is not None and not isinstance(
key,
@@ -112,15 +79,15 @@ class PKCS12KeyAndCertificates:
self._additional_certs = additional_certs
@property
def key(self) -> typing.Optional[PrivateKeyTypes]:
def key(self) -> PrivateKeyTypes | None:
return self._key
@property
def cert(self) -> typing.Optional[PKCS12Certificate]:
def cert(self) -> PKCS12Certificate | None:
return self._cert
@property
def additional_certs(self) -> typing.List[PKCS12Certificate]:
def additional_certs(self) -> list[PKCS12Certificate]:
return self._additional_certs
def __eq__(self, other: object) -> bool:
@@ -143,28 +110,8 @@ class PKCS12KeyAndCertificates:
return fmt.format(self.key, self.cert, self.additional_certs)
def load_key_and_certificates(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
) -> typing.Tuple[
typing.Optional[PrivateKeyTypes],
typing.Optional[x509.Certificate],
typing.List[x509.Certificate],
]:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_key_and_certificates_from_pkcs12(data, password)
def load_pkcs12(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
) -> PKCS12KeyAndCertificates:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pkcs12(data, password)
load_key_and_certificates = rust_pkcs12.load_key_and_certificates
load_pkcs12 = rust_pkcs12.load_pkcs12
_PKCS12CATypes = typing.Union[
@@ -173,11 +120,29 @@ _PKCS12CATypes = typing.Union[
]
def serialize_java_truststore(
certs: Iterable[PKCS12Certificate],
encryption_algorithm: serialization.KeySerializationEncryption,
) -> bytes:
if not certs:
raise ValueError("You must supply at least one cert")
if not isinstance(
encryption_algorithm, serialization.KeySerializationEncryption
):
raise TypeError(
"Key encryption algorithm must be a "
"KeySerializationEncryption instance"
)
return rust_pkcs12.serialize_java_truststore(certs, encryption_algorithm)
def serialize_key_and_certificates(
name: typing.Optional[bytes],
key: typing.Optional[PKCS12PrivateKeyTypes],
cert: typing.Optional[x509.Certificate],
cas: typing.Optional[typing.Iterable[_PKCS12CATypes]],
name: bytes | None,
key: PKCS12PrivateKeyTypes | None,
cert: x509.Certificate | None,
cas: Iterable[_PKCS12CATypes] | None,
encryption_algorithm: serialization.KeySerializationEncryption,
) -> bytes:
if key is not None and not isinstance(
@@ -194,22 +159,6 @@ def serialize_key_and_certificates(
"Key must be RSA, DSA, EllipticCurve, ED25519, or ED448"
" private key, or None."
)
if cert is not None and not isinstance(cert, x509.Certificate):
raise TypeError("cert must be a certificate or None")
if cas is not None:
cas = list(cas)
if not all(
isinstance(
val,
(
x509.Certificate,
PKCS12Certificate,
),
)
for val in cas
):
raise TypeError("all values in cas must be certificates")
if not isinstance(
encryption_algorithm, serialization.KeySerializationEncryption
@@ -222,8 +171,6 @@ def serialize_key_and_certificates(
if key is None and cert is None and not cas:
raise ValueError("You must supply at least one of key, cert, or cas")
from cryptography.hazmat.backends.openssl.backend import backend
return backend.serialize_key_and_certificates_to_pkcs12(
return rust_pkcs12.serialize_key_and_certificates(
name, key, cert, cas, encryption_algorithm
)

View File

@@ -10,32 +10,23 @@ import email.message
import email.policy
import io
import typing
from collections.abc import Iterable
from cryptography import utils, x509
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.bindings._rust import pkcs7 as rust_pkcs7
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa
from cryptography.hazmat.primitives.ciphers import (
algorithms,
)
from cryptography.utils import _check_byteslike
load_pem_pkcs7_certificates = rust_pkcs7.load_pem_pkcs7_certificates
def load_pem_pkcs7_certificates(data: bytes) -> typing.List[x509.Certificate]:
from cryptography.hazmat.backends.openssl.backend import backend
return backend.load_pem_pkcs7_certificates(data)
def load_der_pkcs7_certificates(data: bytes) -> typing.List[x509.Certificate]:
from cryptography.hazmat.backends.openssl.backend import backend
return backend.load_der_pkcs7_certificates(data)
def serialize_certificates(
certs: typing.List[x509.Certificate],
encoding: serialization.Encoding,
) -> bytes:
return rust_pkcs7.serialize_certificates(certs, encoding)
load_der_pkcs7_certificates = rust_pkcs7.load_der_pkcs7_certificates
serialize_certificates = rust_pkcs7.serialize_certificates
PKCS7HashTypes = typing.Union[
hashes.SHA224,
@@ -48,6 +39,10 @@ PKCS7PrivateKeyTypes = typing.Union[
rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey
]
ContentEncryptionAlgorithm = typing.Union[
typing.Type[algorithms.AES128], typing.Type[algorithms.AES256]
]
class PKCS7Options(utils.Enum):
Text = "Add text/plain MIME type"
@@ -61,21 +56,22 @@ class PKCS7Options(utils.Enum):
class PKCS7SignatureBuilder:
def __init__(
self,
data: typing.Optional[bytes] = None,
signers: typing.List[
typing.Tuple[
data: utils.Buffer | None = None,
signers: list[
tuple[
x509.Certificate,
PKCS7PrivateKeyTypes,
PKCS7HashTypes,
padding.PSS | padding.PKCS1v15 | None,
]
] = [],
additional_certs: typing.List[x509.Certificate] = [],
additional_certs: list[x509.Certificate] = [],
):
self._data = data
self._signers = signers
self._additional_certs = additional_certs
def set_data(self, data: bytes) -> PKCS7SignatureBuilder:
def set_data(self, data: utils.Buffer) -> PKCS7SignatureBuilder:
_check_byteslike("data", data)
if self._data is not None:
raise ValueError("data may only be set once")
@@ -87,6 +83,8 @@ class PKCS7SignatureBuilder:
certificate: x509.Certificate,
private_key: PKCS7PrivateKeyTypes,
hash_algorithm: PKCS7HashTypes,
*,
rsa_padding: padding.PSS | padding.PKCS1v15 | None = None,
) -> PKCS7SignatureBuilder:
if not isinstance(
hash_algorithm,
@@ -109,9 +107,18 @@ class PKCS7SignatureBuilder:
):
raise TypeError("Only RSA & EC keys are supported at this time.")
if rsa_padding is not None:
if not isinstance(rsa_padding, (padding.PSS, padding.PKCS1v15)):
raise TypeError("Padding must be PSS or PKCS1v15")
if not isinstance(private_key, rsa.RSAPrivateKey):
raise TypeError("Padding is only supported for RSA keys")
return PKCS7SignatureBuilder(
self._data,
self._signers + [(certificate, private_key, hash_algorithm)],
[
*self._signers,
(certificate, private_key, hash_algorithm, rsa_padding),
],
)
def add_certificate(
@@ -121,13 +128,13 @@ class PKCS7SignatureBuilder:
raise TypeError("certificate must be a x509.Certificate")
return PKCS7SignatureBuilder(
self._data, self._signers, self._additional_certs + [certificate]
self._data, self._signers, [*self._additional_certs, certificate]
)
def sign(
self,
encoding: serialization.Encoding,
options: typing.Iterable[PKCS7Options],
options: Iterable[PKCS7Options],
backend: typing.Any = None,
) -> bytes:
if len(self._signers) == 0:
@@ -179,7 +186,131 @@ class PKCS7SignatureBuilder:
return rust_pkcs7.sign_and_serialize(self, encoding, options)
def _smime_encode(
class PKCS7EnvelopeBuilder:
def __init__(
self,
*,
_data: bytes | None = None,
_recipients: list[x509.Certificate] | None = None,
_content_encryption_algorithm: ContentEncryptionAlgorithm
| None = None,
):
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
if not ossl.rsa_encryption_supported(padding=padding.PKCS1v15()):
raise UnsupportedAlgorithm(
"RSA with PKCS1 v1.5 padding is not supported by this version"
" of OpenSSL.",
_Reasons.UNSUPPORTED_PADDING,
)
self._data = _data
self._recipients = _recipients if _recipients is not None else []
self._content_encryption_algorithm = _content_encryption_algorithm
def set_data(self, data: bytes) -> PKCS7EnvelopeBuilder:
_check_byteslike("data", data)
if self._data is not None:
raise ValueError("data may only be set once")
return PKCS7EnvelopeBuilder(
_data=data,
_recipients=self._recipients,
_content_encryption_algorithm=self._content_encryption_algorithm,
)
def add_recipient(
self,
certificate: x509.Certificate,
) -> PKCS7EnvelopeBuilder:
if not isinstance(certificate, x509.Certificate):
raise TypeError("certificate must be a x509.Certificate")
if not isinstance(certificate.public_key(), rsa.RSAPublicKey):
raise TypeError("Only RSA keys are supported at this time.")
return PKCS7EnvelopeBuilder(
_data=self._data,
_recipients=[
*self._recipients,
certificate,
],
_content_encryption_algorithm=self._content_encryption_algorithm,
)
def set_content_encryption_algorithm(
self, content_encryption_algorithm: ContentEncryptionAlgorithm
) -> PKCS7EnvelopeBuilder:
if self._content_encryption_algorithm is not None:
raise ValueError("Content encryption algo may only be set once")
if content_encryption_algorithm not in {
algorithms.AES128,
algorithms.AES256,
}:
raise TypeError("Only AES128 and AES256 are supported")
return PKCS7EnvelopeBuilder(
_data=self._data,
_recipients=self._recipients,
_content_encryption_algorithm=content_encryption_algorithm,
)
def encrypt(
self,
encoding: serialization.Encoding,
options: Iterable[PKCS7Options],
) -> bytes:
if len(self._recipients) == 0:
raise ValueError("Must have at least one recipient")
if self._data is None:
raise ValueError("You must add data to encrypt")
# The default content encryption algorithm is AES-128, which the S/MIME
# v3.2 RFC specifies as MUST support (https://datatracker.ietf.org/doc/html/rfc5751#section-2.7)
content_encryption_algorithm = (
self._content_encryption_algorithm or algorithms.AES128
)
options = list(options)
if not all(isinstance(x, PKCS7Options) for x in options):
raise ValueError("options must be from the PKCS7Options enum")
if encoding not in (
serialization.Encoding.PEM,
serialization.Encoding.DER,
serialization.Encoding.SMIME,
):
raise ValueError(
"Must be PEM, DER, or SMIME from the Encoding enum"
)
# Only allow options that make sense for encryption
if any(
opt not in [PKCS7Options.Text, PKCS7Options.Binary]
for opt in options
):
raise ValueError(
"Only the following options are supported for encryption: "
"Text, Binary"
)
elif PKCS7Options.Text in options and PKCS7Options.Binary in options:
# OpenSSL accepts both options at the same time, but ignores Text.
# We fail defensively to avoid unexpected outputs.
raise ValueError(
"Cannot use Binary and Text options at the same time"
)
return rust_pkcs7.encrypt_and_serialize(
self, content_encryption_algorithm, encoding, options
)
pkcs7_decrypt_der = rust_pkcs7.decrypt_der
pkcs7_decrypt_pem = rust_pkcs7.decrypt_pem
pkcs7_decrypt_smime = rust_pkcs7.decrypt_smime
def _smime_signed_encode(
data: bytes, signature: bytes, micalg: str, text_mode: bool
) -> bytes:
# This function works pretty hard to replicate what OpenSSL does
@@ -227,6 +358,51 @@ def _smime_encode(
return fp.getvalue()
def _smime_enveloped_encode(data: bytes) -> bytes:
m = email.message.Message()
m.add_header("MIME-Version", "1.0")
m.add_header("Content-Disposition", "attachment", filename="smime.p7m")
m.add_header(
"Content-Type",
"application/pkcs7-mime",
smime_type="enveloped-data",
name="smime.p7m",
)
m.add_header("Content-Transfer-Encoding", "base64")
m.set_payload(email.base64mime.body_encode(data, maxlinelen=65))
return m.as_bytes(policy=m.policy.clone(linesep="\n", max_line_length=0))
def _smime_enveloped_decode(data: bytes) -> bytes:
m = email.message_from_bytes(data)
if m.get_content_type() not in {
"application/x-pkcs7-mime",
"application/pkcs7-mime",
}:
raise ValueError("Not an S/MIME enveloped message")
return bytes(m.get_payload(decode=True))
def _smime_remove_text_headers(data: bytes) -> bytes:
m = email.message_from_bytes(data)
# Using get() instead of get_content_type() since it has None as default,
# where the latter has "text/plain". Both methods are case-insensitive.
content_type = m.get("content-type")
if content_type is None:
raise ValueError(
"Decrypted MIME data has no 'Content-Type' header. "
"Please remove the 'Text' option to parse it manually."
)
if "text/plain" not in content_type:
raise ValueError(
f"Decrypted MIME data content type is '{content_type}', not "
"'text/plain'. Remove the 'Text' option to parse it manually."
)
return bytes(m.get_payload(decode=True))
class OpenSSLMimePart(email.message.MIMEPart):
# A MIMEPart subclass that replicates OpenSSL's behavior of not including
# a newline if there are no headers.

View File

@@ -64,6 +64,10 @@ _ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
_CERT_SUFFIX = b"-cert-v01@openssh.com"
# U2F application string suffixed pubkey
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"
# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
@@ -87,21 +91,17 @@ _PADDING = memoryview(bytearray(range(1, 1 + 16)))
@dataclass
class _SSHCipher:
alg: typing.Type[algorithms.AES]
alg: type[algorithms.AES]
key_len: int
mode: typing.Union[
typing.Type[modes.CTR],
typing.Type[modes.CBC],
typing.Type[modes.GCM],
]
mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
block_len: int
iv_len: int
tag_len: typing.Optional[int]
tag_len: int | None
is_aead: bool
# ciphers that are actually used in key wrapping
_SSH_CIPHERS: typing.Dict[bytes, _SSHCipher] = {
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
b"aes256-ctr": _SSHCipher(
alg=algorithms.AES,
key_len=32,
@@ -139,9 +139,7 @@ _ECDSA_KEY_TYPE = {
}
def _get_ssh_key_type(
key: typing.Union[SSHPrivateKeyTypes, SSHPublicKeyTypes]
) -> bytes:
def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
if isinstance(key, ec.EllipticCurvePrivateKey):
key_type = _ecdsa_key_type(key.public_key())
elif isinstance(key, ec.EllipticCurvePublicKey):
@@ -171,20 +169,20 @@ def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
def _ssh_pem_encode(
data: bytes,
data: utils.Buffer,
prefix: bytes = _SK_START + b"\n",
suffix: bytes = _SK_END + b"\n",
) -> bytes:
return b"".join([prefix, _base64_encode(data), suffix])
def _check_block_size(data: bytes, block_len: int) -> None:
def _check_block_size(data: utils.Buffer, block_len: int) -> None:
"""Require data to be full blocks"""
if not data or len(data) % block_len != 0:
raise ValueError("Corrupt data: missing padding")
def _check_empty(data: bytes) -> None:
def _check_empty(data: utils.Buffer) -> None:
"""All data should have been parsed."""
if data:
raise ValueError("Corrupt data: unparsed data")
@@ -192,13 +190,15 @@ def _check_empty(data: bytes) -> None:
def _init_cipher(
ciphername: bytes,
password: typing.Optional[bytes],
password: bytes | None,
salt: bytes,
rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR, modes.GCM]]:
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
"""Generate key + iv and return cipher."""
if not password:
raise ValueError("Key is password-protected.")
raise TypeError(
"Key is password-protected, but password was not provided."
)
ciph = _SSH_CIPHERS[ciphername]
seed = _bcrypt_kdf(
@@ -210,21 +210,21 @@ def _init_cipher(
)
def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
def _get_u32(data: memoryview) -> tuple[int, memoryview]:
"""Uint32"""
if len(data) < 4:
raise ValueError("Invalid data")
return int.from_bytes(data[:4], byteorder="big"), data[4:]
def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
def _get_u64(data: memoryview) -> tuple[int, memoryview]:
"""Uint64"""
if len(data) < 8:
raise ValueError("Invalid data")
return int.from_bytes(data[:8], byteorder="big"), data[8:]
def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
"""Bytes with u32 length prefix"""
n, data = _get_u32(data)
if n > len(data):
@@ -232,7 +232,7 @@ def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
return data[:n], data[n:]
def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
"""Big integer."""
val, data = _get_sshstr(data)
if val and val[0] > 0x7F:
@@ -253,16 +253,14 @@ def _to_mpint(val: int) -> bytes:
class _FragList:
"""Build recursive structure without data copy."""
flist: typing.List[bytes]
flist: list[utils.Buffer]
def __init__(
self, init: typing.Optional[typing.List[bytes]] = None
) -> None:
def __init__(self, init: list[utils.Buffer] | None = None) -> None:
self.flist = []
if init:
self.flist.extend(init)
def put_raw(self, val: bytes) -> None:
def put_raw(self, val: utils.Buffer) -> None:
"""Add plain bytes"""
self.flist.append(val)
@@ -274,7 +272,7 @@ class _FragList:
"""Big-endian uint64"""
self.flist.append(val.to_bytes(length=8, byteorder="big"))
def put_sshstr(self, val: typing.Union[bytes, _FragList]) -> None:
def put_sshstr(self, val: bytes | _FragList) -> None:
"""Bytes prefixed with u32 length"""
if isinstance(val, (bytes, memoryview, bytearray)):
self.put_u32(len(val))
@@ -315,7 +313,9 @@ class _SSHFormatRSA:
mpint n, e, d, iqmp, p, q
"""
def get_public(self, data: memoryview):
def get_public(
self, data: memoryview
) -> tuple[tuple[int, int], memoryview]:
"""RSA public fields"""
e, data = _get_mpint(data)
n, data = _get_mpint(data)
@@ -323,7 +323,7 @@ class _SSHFormatRSA:
def load_public(
self, data: memoryview
) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
) -> tuple[rsa.RSAPublicKey, memoryview]:
"""Make RSA public key from data."""
(e, n), data = self.get_public(data)
public_numbers = rsa.RSAPublicNumbers(e, n)
@@ -331,8 +331,8 @@ class _SSHFormatRSA:
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[rsa.RSAPrivateKey, memoryview]:
"""Make RSA private key from data."""
n, data = _get_mpint(data)
e, data = _get_mpint(data)
@@ -349,7 +349,9 @@ class _SSHFormatRSA:
private_numbers = rsa.RSAPrivateNumbers(
p, q, d, dmp1, dmq1, iqmp, public_numbers
)
private_key = private_numbers.private_key()
private_key = private_numbers.private_key(
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation
)
return private_key, data
def encode_public(
@@ -385,9 +387,7 @@ class _SSHFormatDSA:
mpint p, q, g, y, x
"""
def get_public(
self, data: memoryview
) -> typing.Tuple[typing.Tuple, memoryview]:
def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
"""DSA public fields"""
p, data = _get_mpint(data)
q, data = _get_mpint(data)
@@ -397,7 +397,7 @@ class _SSHFormatDSA:
def load_public(
self, data: memoryview
) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
) -> tuple[dsa.DSAPublicKey, memoryview]:
"""Make DSA public key from data."""
(p, q, g, y), data = self.get_public(data)
parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
@@ -407,8 +407,8 @@ class _SSHFormatDSA:
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[dsa.DSAPrivateKey, memoryview]:
"""Make DSA private key from data."""
(p, q, g, y), data = self.get_public(data)
x, data = _get_mpint(data)
@@ -466,7 +466,7 @@ class _SSHFormatECDSA:
def get_public(
self, data: memoryview
) -> typing.Tuple[typing.Tuple, memoryview]:
) -> tuple[tuple[memoryview, memoryview], memoryview]:
"""ECDSA public fields"""
curve, data = _get_sshstr(data)
point, data = _get_sshstr(data)
@@ -478,17 +478,17 @@ class _SSHFormatECDSA:
def load_public(
self, data: memoryview
) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
"""Make ECDSA public key from data."""
(curve_name, point), data = self.get_public(data)
(_, point), data = self.get_public(data)
public_key = ec.EllipticCurvePublicKey.from_encoded_point(
self.curve, point.tobytes()
)
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
"""Make ECDSA private key from data."""
(curve_name, point), data = self.get_public(data)
secret, data = _get_mpint(data)
@@ -531,14 +531,14 @@ class _SSHFormatEd25519:
def get_public(
self, data: memoryview
) -> typing.Tuple[typing.Tuple, memoryview]:
) -> tuple[tuple[memoryview], memoryview]:
"""Ed25519 public fields"""
point, data = _get_sshstr(data)
return (point,), data
def load_public(
self, data: memoryview
) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
"""Make Ed25519 public key from data."""
(point,), data = self.get_public(data)
public_key = ed25519.Ed25519PublicKey.from_public_bytes(
@@ -547,8 +547,8 @@ class _SSHFormatEd25519:
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
"""Make Ed25519 private key from data."""
(point,), data = self.get_public(data)
keypair, data = _get_sshstr(data)
@@ -586,6 +586,70 @@ class _SSHFormatEd25519:
f_priv.put_sshstr(f_keypair)
def load_application(data) -> tuple[memoryview, memoryview]:
"""
U2F application strings
"""
application, data = _get_sshstr(data)
if not application.tobytes().startswith(b"ssh:"):
raise ValueError(
"U2F application string does not start with b'ssh:' "
f"({application})"
)
return application, data
class _SSHFormatSKEd25519:
"""
The format of a sk-ssh-ed25519@openssh.com public key is:
string "sk-ssh-ed25519@openssh.com"
string public key
string application (user-specified, but typically "ssh:")
"""
def load_public(
self, data: memoryview
) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
"""Make Ed25519 public key from data."""
public_key, data = _lookup_kformat(_SSH_ED25519).load_public(data)
_, data = load_application(data)
return public_key, data
def get_public(self, data: memoryview) -> typing.NoReturn:
# Confusingly `get_public` is an entry point used by private key
# loading.
raise UnsupportedAlgorithm(
"sk-ssh-ed25519 private keys cannot be loaded"
)
class _SSHFormatSKECDSA:
"""
The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is:
string "sk-ecdsa-sha2-nistp256@openssh.com"
string curve name
ec_point Q
string application (user-specified, but typically "ssh:")
"""
def load_public(
self, data: memoryview
) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
"""Make ECDSA public key from data."""
public_key, data = _lookup_kformat(_ECDSA_NISTP256).load_public(data)
_, data = load_application(data)
return public_key, data
def get_public(self, data: memoryview) -> typing.NoReturn:
# Confusingly `get_public` is an entry point used by private key
# loading.
raise UnsupportedAlgorithm(
"sk-ecdsa-sha2-nistp256 private keys cannot be loaded"
)
_KEY_FORMATS = {
_SSH_RSA: _SSHFormatRSA(),
_SSH_DSA: _SSHFormatDSA(),
@@ -593,10 +657,12 @@ _KEY_FORMATS = {
_ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
_ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
_ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
_SK_SSH_ED25519: _SSHFormatSKEd25519(),
_SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
def _lookup_kformat(key_type: bytes):
def _lookup_kformat(key_type: utils.Buffer):
"""Return valid format or throw error"""
if not isinstance(key_type, bytes):
key_type = memoryview(key_type).tobytes()
@@ -614,9 +680,11 @@ SSHPrivateKeyTypes = typing.Union[
def load_ssh_private_key(
data: bytes,
password: typing.Optional[bytes],
data: utils.Buffer,
password: bytes | None,
backend: typing.Any = None,
*,
unsafe_skip_rsa_key_validation: bool = False,
) -> SSHPrivateKeyTypes:
"""Load private key from OpenSSH custom encoding."""
utils._check_byteslike("data", data)
@@ -648,7 +716,7 @@ def load_ssh_private_key(
pubfields, pubdata = kformat.get_public(pubdata)
_check_empty(pubdata)
if (ciphername, kdfname) != (_NONE, _NONE):
if ciphername != _NONE or kdfname != _NONE:
ciphername_bytes = ciphername.tobytes()
if ciphername_bytes not in _SSH_CIPHERS:
raise UnsupportedAlgorithm(
@@ -683,6 +751,10 @@ def load_ssh_private_key(
# should be no output from finalize
_check_empty(dec.finalize())
else:
if password:
raise TypeError(
"Password was given but private key is not encrypted."
)
# load secret data
edata, data = _get_sshstr(data)
_check_empty(data)
@@ -697,8 +769,13 @@ def load_ssh_private_key(
key_type, edata = _get_sshstr(edata)
if key_type != pub_key_type:
raise ValueError("Corrupt data: key type mismatch")
private_key, edata = kformat.load_private(edata, pubfields)
comment, edata = _get_sshstr(edata)
private_key, edata = kformat.load_private(
edata,
pubfields,
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
)
# We don't use the comment
_, edata = _get_sshstr(edata)
# yes, SSH does padding check *after* all other parsing is done.
# need to follow as it writes zero-byte padding too.
@@ -820,11 +897,11 @@ class SSHCertificate:
_serial: int,
_cctype: int,
_key_id: memoryview,
_valid_principals: typing.List[bytes],
_valid_principals: list[bytes],
_valid_after: int,
_valid_before: int,
_critical_options: typing.Dict[bytes, bytes],
_extensions: typing.Dict[bytes, bytes],
_critical_options: dict[bytes, bytes],
_extensions: dict[bytes, bytes],
_sig_type: memoryview,
_sig_key: memoryview,
_inner_sig_type: memoryview,
@@ -876,7 +953,7 @@ class SSHCertificate:
return bytes(self._key_id)
@property
def valid_principals(self) -> typing.List[bytes]:
def valid_principals(self) -> list[bytes]:
return self._valid_principals
@property
@@ -888,11 +965,11 @@ class SSHCertificate:
return self._valid_after
@property
def critical_options(self) -> typing.Dict[bytes, bytes]:
def critical_options(self) -> dict[bytes, bytes]:
return self._critical_options
@property
def extensions(self) -> typing.Dict[bytes, bytes]:
def extensions(self) -> dict[bytes, bytes]:
return self._extensions
def signature_key(self) -> SSHCertPublicKeyTypes:
@@ -952,9 +1029,9 @@ def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
def _load_ssh_public_identity(
data: bytes,
data: utils.Buffer,
_legacy_dsa_allowed=False,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
) -> SSHCertificate | SSHPublicKeyTypes:
utils._check_byteslike("data", data)
m = _SSH_PUBKEY_RC.match(data)
@@ -1047,13 +1124,13 @@ def _load_ssh_public_identity(
def load_ssh_public_identity(
data: bytes,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
data: utils.Buffer,
) -> SSHCertificate | SSHPublicKeyTypes:
return _load_ssh_public_identity(data)
def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
result: typing.Dict[bytes, bytes] = {}
def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
result: dict[bytes, bytes] = {}
last_name = None
while exts_opts:
name, exts_opts = _get_sshstr(exts_opts)
@@ -1064,26 +1141,38 @@ def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
raise ValueError("Fields not lexically sorted")
value, exts_opts = _get_sshstr(exts_opts)
if len(value) > 0:
try:
value, extra = _get_sshstr(value)
except ValueError:
warnings.warn(
"This certificate has an incorrect encoding for critical "
"options or extensions. This will be an exception in "
"cryptography 42",
utils.DeprecatedIn41,
stacklevel=4,
)
else:
if len(extra) > 0:
raise ValueError("Unexpected extra data after value")
value, extra = _get_sshstr(value)
if len(extra) > 0:
raise ValueError("Unexpected extra data after value")
result[bname] = bytes(value)
last_name = bname
return result
def ssh_key_fingerprint(
key: SSHPublicKeyTypes,
hash_algorithm: hashes.MD5 | hashes.SHA256,
) -> bytes:
if not isinstance(hash_algorithm, (hashes.MD5, hashes.SHA256)):
raise TypeError("hash_algorithm must be either MD5 or SHA256")
key_type = _get_ssh_key_type(key)
kformat = _lookup_kformat(key_type)
f_pub = _FragList()
f_pub.put_sshstr(key_type)
kformat.encode_public(key, f_pub)
ssh_binary_data = f_pub.tobytes()
# Hash the binary data
hash_obj = hashes.Hash(hash_algorithm)
hash_obj.update(ssh_binary_data)
return hash_obj.finalize()
def load_ssh_public_key(
data: bytes, backend: typing.Any = None
data: utils.Buffer, backend: typing.Any = None
) -> SSHPublicKeyTypes:
cert_or_key = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
public_key: SSHPublicKeyTypes
@@ -1137,16 +1226,16 @@ _SSHKEY_CERT_MAX_PRINCIPALS = 256
class SSHCertificateBuilder:
def __init__(
self,
_public_key: typing.Optional[SSHCertPublicKeyTypes] = None,
_serial: typing.Optional[int] = None,
_type: typing.Optional[SSHCertificateType] = None,
_key_id: typing.Optional[bytes] = None,
_valid_principals: typing.List[bytes] = [],
_public_key: SSHCertPublicKeyTypes | None = None,
_serial: int | None = None,
_type: SSHCertificateType | None = None,
_key_id: bytes | None = None,
_valid_principals: list[bytes] = [],
_valid_for_all_principals: bool = False,
_valid_before: typing.Optional[int] = None,
_valid_after: typing.Optional[int] = None,
_critical_options: typing.List[typing.Tuple[bytes, bytes]] = [],
_extensions: typing.List[typing.Tuple[bytes, bytes]] = [],
_valid_before: int | None = None,
_valid_after: int | None = None,
_critical_options: list[tuple[bytes, bytes]] = [],
_extensions: list[tuple[bytes, bytes]] = [],
):
self._public_key = _public_key
self._serial = _serial
@@ -1247,7 +1336,7 @@ class SSHCertificateBuilder:
)
def valid_principals(
self, valid_principals: typing.List[bytes]
self, valid_principals: list[bytes]
) -> SSHCertificateBuilder:
if self._valid_for_all_principals:
raise ValueError(
@@ -1304,9 +1393,7 @@ class SSHCertificateBuilder:
_extensions=self._extensions,
)
def valid_before(
self, valid_before: typing.Union[int, float]
) -> SSHCertificateBuilder:
def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
if not isinstance(valid_before, (int, float)):
raise TypeError("valid_before must be an int or float")
valid_before = int(valid_before)
@@ -1328,9 +1415,7 @@ class SSHCertificateBuilder:
_extensions=self._extensions,
)
def valid_after(
self, valid_after: typing.Union[int, float]
) -> SSHCertificateBuilder:
def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
if not isinstance(valid_after, (int, float)):
raise TypeError("valid_after must be an int or float")
valid_after = int(valid_after)
@@ -1370,7 +1455,7 @@ class SSHCertificateBuilder:
_valid_for_all_principals=self._valid_for_all_principals,
_valid_before=self._valid_before,
_valid_after=self._valid_after,
_critical_options=self._critical_options + [(name, value)],
_critical_options=[*self._critical_options, (name, value)],
_extensions=self._extensions,
)
@@ -1393,7 +1478,7 @@ class SSHCertificateBuilder:
_valid_before=self._valid_before,
_valid_after=self._valid_after,
_critical_options=self._critical_options,
_extensions=self._extensions + [(name, value)],
_extensions=[*self._extensions, (name, value)],
)
def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate:

View File

@@ -11,6 +11,7 @@ from urllib.parse import quote, urlencode
from cryptography.hazmat.primitives import constant_time, hmac
from cryptography.hazmat.primitives.hashes import SHA1, SHA256, SHA512
from cryptography.hazmat.primitives.twofactor import InvalidToken
from cryptography.utils import Buffer
HOTPHashTypes = typing.Union[SHA1, SHA256, SHA512]
@@ -19,8 +20,8 @@ def _generate_uri(
hotp: HOTP,
type_name: str,
account_name: str,
issuer: typing.Optional[str],
extra_parameters: typing.List[typing.Tuple[str, int]],
issuer: str | None,
extra_parameters: list[tuple[str, int]],
) -> str:
parameters = [
("digits", hotp._length),
@@ -44,7 +45,7 @@ def _generate_uri(
class HOTP:
def __init__(
self,
key: bytes,
key: Buffer,
length: int,
algorithm: HOTPHashTypes,
backend: typing.Any = None,
@@ -67,6 +68,9 @@ class HOTP:
self._algorithm = algorithm
def generate(self, counter: int) -> bytes:
if not isinstance(counter, int):
raise TypeError("Counter parameter must be an integer type.")
truncated_value = self._dynamic_truncate(counter)
hotp = truncated_value % (10**self._length)
return "{0:0{1}}".format(hotp, self._length).encode()
@@ -77,7 +81,12 @@ class HOTP:
def _dynamic_truncate(self, counter: int) -> int:
ctx = hmac.HMAC(self._key, self._algorithm)
ctx.update(counter.to_bytes(length=8, byteorder="big"))
try:
ctx.update(counter.to_bytes(length=8, byteorder="big"))
except OverflowError:
raise ValueError(f"Counter must be between 0 and {2**64 - 1}.")
hmac_value = ctx.finalize()
offset = hmac_value[len(hmac_value) - 1] & 0b1111
@@ -85,7 +94,7 @@ class HOTP:
return int.from_bytes(p, byteorder="big") & 0x7FFFFFFF
def get_provisioning_uri(
self, account_name: str, counter: int, issuer: typing.Optional[str]
self, account_name: str, counter: int, issuer: str | None
) -> str:
return _generate_uri(
self, "hotp", account_name, issuer, [("counter", int(counter))]

View File

@@ -13,12 +13,13 @@ from cryptography.hazmat.primitives.twofactor.hotp import (
HOTPHashTypes,
_generate_uri,
)
from cryptography.utils import Buffer
class TOTP:
def __init__(
self,
key: bytes,
key: Buffer,
length: int,
algorithm: HOTPHashTypes,
time_step: int,
@@ -30,7 +31,12 @@ class TOTP:
key, length, algorithm, enforce_key_length=enforce_key_length
)
def generate(self, time: typing.Union[int, float]) -> bytes:
def generate(self, time: int | float) -> bytes:
if not isinstance(time, (int, float)):
raise TypeError(
"Time parameter must be an integer type or float type."
)
counter = int(time / self._time_step)
return self._hotp.generate(counter)
@@ -39,7 +45,7 @@ class TOTP:
raise InvalidToken("Supplied TOTP value does not match.")
def get_provisioning_uri(
self, account_name: str, issuer: typing.Optional[str]
self, account_name: str, issuer: str | None
) -> str:
return _generate_uri(
self._hotp,