This commit is contained in:
Iliyan Angelov
2025-11-19 12:27:01 +02:00
parent 2043ac897c
commit 34b4c969d4
469 changed files with 26870 additions and 8329 deletions

View File

@@ -0,0 +1,13 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import annotations
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
Argon2id = rust_openssl.kdf.Argon2id
KeyDerivationFunction.register(Argon2id)
__all__ = ["Argon2id"]

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import typing
from collections.abc import Callable
from cryptography import utils
from cryptography.exceptions import AlreadyFinalized, InvalidKey
@@ -19,7 +20,7 @@ def _int_to_u32be(n: int) -> bytes:
def _common_args_checks(
algorithm: hashes.HashAlgorithm,
length: int,
otherinfo: typing.Optional[bytes],
otherinfo: bytes | None,
) -> None:
max_length = algorithm.digest_size * (2**32 - 1)
if length > max_length:
@@ -29,9 +30,9 @@ def _common_args_checks(
def _concatkdf_derive(
key_material: bytes,
key_material: utils.Buffer,
length: int,
auxfn: typing.Callable[[], hashes.HashContext],
auxfn: Callable[[], hashes.HashContext],
otherinfo: bytes,
) -> bytes:
utils._check_byteslike("key_material", key_material)
@@ -56,7 +57,7 @@ class ConcatKDFHash(KeyDerivationFunction):
self,
algorithm: hashes.HashAlgorithm,
length: int,
otherinfo: typing.Optional[bytes],
otherinfo: bytes | None,
backend: typing.Any = None,
):
_common_args_checks(algorithm, length, otherinfo)
@@ -69,7 +70,7 @@ class ConcatKDFHash(KeyDerivationFunction):
def _hash(self) -> hashes.Hash:
return hashes.Hash(self._algorithm)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True
@@ -87,8 +88,8 @@ class ConcatKDFHMAC(KeyDerivationFunction):
self,
algorithm: hashes.HashAlgorithm,
length: int,
salt: typing.Optional[bytes],
otherinfo: typing.Optional[bytes],
salt: bytes | None,
otherinfo: bytes | None,
backend: typing.Any = None,
):
_common_args_checks(algorithm, length, otherinfo)
@@ -111,7 +112,7 @@ class ConcatKDFHMAC(KeyDerivationFunction):
def _hmac(self) -> hmac.HMAC:
return hmac.HMAC(self._salt, self._algorithm)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True

View File

@@ -4,98 +4,13 @@
from __future__ import annotations
import typing
from cryptography import utils
from cryptography.exceptions import AlreadyFinalized, InvalidKey
from cryptography.hazmat.primitives import constant_time, hashes, hmac
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
HKDF = rust_openssl.kdf.HKDF
HKDFExpand = rust_openssl.kdf.HKDFExpand
class HKDF(KeyDerivationFunction):
def __init__(
self,
algorithm: hashes.HashAlgorithm,
length: int,
salt: typing.Optional[bytes],
info: typing.Optional[bytes],
backend: typing.Any = None,
):
self._algorithm = algorithm
KeyDerivationFunction.register(HKDF)
KeyDerivationFunction.register(HKDFExpand)
if salt is None:
salt = b"\x00" * self._algorithm.digest_size
else:
utils._check_bytes("salt", salt)
self._salt = salt
self._hkdf_expand = HKDFExpand(self._algorithm, length, info)
def _extract(self, key_material: bytes) -> bytes:
h = hmac.HMAC(self._salt, self._algorithm)
h.update(key_material)
return h.finalize()
def derive(self, key_material: bytes) -> bytes:
utils._check_byteslike("key_material", key_material)
return self._hkdf_expand.derive(self._extract(key_material))
def verify(self, key_material: bytes, expected_key: bytes) -> None:
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
raise InvalidKey
class HKDFExpand(KeyDerivationFunction):
def __init__(
self,
algorithm: hashes.HashAlgorithm,
length: int,
info: typing.Optional[bytes],
backend: typing.Any = None,
):
self._algorithm = algorithm
max_length = 255 * algorithm.digest_size
if length > max_length:
raise ValueError(
f"Cannot derive keys larger than {max_length} octets."
)
self._length = length
if info is None:
info = b""
else:
utils._check_bytes("info", info)
self._info = info
self._used = False
def _expand(self, key_material: bytes) -> bytes:
output = [b""]
counter = 1
while self._algorithm.digest_size * (len(output) - 1) < self._length:
h = hmac.HMAC(key_material, self._algorithm)
h.update(output[-1])
h.update(self._info)
h.update(bytes([counter]))
output.append(h.finalize())
counter += 1
return b"".join(output)[: self._length]
def derive(self, key_material: bytes) -> bytes:
utils._check_byteslike("key_material", key_material)
if self._used:
raise AlreadyFinalized
self._used = True
return self._expand(key_material)
def verify(self, key_material: bytes, expected_key: bytes) -> None:
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
raise InvalidKey
__all__ = ["HKDF", "HKDFExpand"]

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
import typing
from collections.abc import Callable
from cryptography import utils
from cryptography.exceptions import (
@@ -36,16 +37,16 @@ class CounterLocation(utils.Enum):
class _KBKDFDeriver:
def __init__(
self,
prf: typing.Callable,
prf: Callable,
mode: Mode,
length: int,
rlen: int,
llen: typing.Optional[int],
llen: int | None,
location: CounterLocation,
break_location: typing.Optional[int],
label: typing.Optional[bytes],
context: typing.Optional[bytes],
fixed: typing.Optional[bytes],
break_location: int | None,
label: bytes | None,
context: bytes | None,
fixed: bytes | None,
):
assert callable(prf)
@@ -75,7 +76,7 @@ class _KBKDFDeriver:
if (label or context) and fixed:
raise ValueError(
"When supplying fixed data, " "label and context are ignored."
"When supplying fixed data, label and context are ignored."
)
if rlen is None or not self._valid_byte_length(rlen):
@@ -87,6 +88,9 @@ class _KBKDFDeriver:
if llen is not None and not isinstance(llen, int):
raise TypeError("llen must be an integer")
if llen == 0:
raise ValueError("llen must be non-zero")
if label is None:
label = b""
@@ -113,11 +117,11 @@ class _KBKDFDeriver:
raise TypeError("value must be of type int")
value_bin = utils.int_to_bytes(1, value)
if not 1 <= len(value_bin) <= 4:
return False
return True
return 1 <= len(value_bin) <= 4
def derive(self, key_material: bytes, prf_output_size: int) -> bytes:
def derive(
self, key_material: utils.Buffer, prf_output_size: int
) -> bytes:
if self._used:
raise AlreadyFinalized
@@ -181,14 +185,14 @@ class KBKDFHMAC(KeyDerivationFunction):
mode: Mode,
length: int,
rlen: int,
llen: typing.Optional[int],
llen: int | None,
location: CounterLocation,
label: typing.Optional[bytes],
context: typing.Optional[bytes],
fixed: typing.Optional[bytes],
label: bytes | None,
context: bytes | None,
fixed: bytes | None,
backend: typing.Any = None,
*,
break_location: typing.Optional[int] = None,
break_location: int | None = None,
):
if not isinstance(algorithm, hashes.HashAlgorithm):
raise UnsupportedAlgorithm(
@@ -224,7 +228,7 @@ class KBKDFHMAC(KeyDerivationFunction):
def _prf(self, key_material: bytes) -> hmac.HMAC:
return hmac.HMAC(key_material, self._algorithm)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
return self._deriver.derive(key_material, self._algorithm.digest_size)
def verify(self, key_material: bytes, expected_key: bytes) -> None:
@@ -239,14 +243,14 @@ class KBKDFCMAC(KeyDerivationFunction):
mode: Mode,
length: int,
rlen: int,
llen: typing.Optional[int],
llen: int | None,
location: CounterLocation,
label: typing.Optional[bytes],
context: typing.Optional[bytes],
fixed: typing.Optional[bytes],
label: bytes | None,
context: bytes | None,
fixed: bytes | None,
backend: typing.Any = None,
*,
break_location: typing.Optional[int] = None,
break_location: int | None = None,
):
if not issubclass(
algorithm, ciphers.BlockCipherAlgorithm
@@ -257,7 +261,7 @@ class KBKDFCMAC(KeyDerivationFunction):
)
self._algorithm = algorithm
self._cipher: typing.Optional[ciphers.BlockCipherAlgorithm] = None
self._cipher: ciphers.BlockCipherAlgorithm | None = None
self._deriver = _KBKDFDeriver(
self._prf,
@@ -277,7 +281,7 @@ class KBKDFCMAC(KeyDerivationFunction):
return cmac.CMAC(self._cipher)
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
self._cipher = self._algorithm(key_material)
assert self._cipher is not None

View File

@@ -33,9 +33,7 @@ class PBKDF2HMAC(KeyDerivationFunction):
if not ossl.pbkdf2_hmac_supported(algorithm):
raise UnsupportedAlgorithm(
"{} is not supported for PBKDF2 by this backend.".format(
algorithm.name
),
f"{algorithm.name} is not supported for PBKDF2.",
_Reasons.UNSUPPORTED_HASH,
)
self._used = False
@@ -45,7 +43,7 @@ class PBKDF2HMAC(KeyDerivationFunction):
self._salt = salt
self._iterations = iterations
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized("PBKDF2 instances can only be used once.")
self._used = True

View File

@@ -5,76 +5,15 @@
from __future__ import annotations
import sys
import typing
from cryptography import utils
from cryptography.exceptions import (
AlreadyFinalized,
InvalidKey,
UnsupportedAlgorithm,
)
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.primitives import constant_time
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
# This is used by the scrypt tests to skip tests that require more memory
# than the MEM_LIMIT
_MEM_LIMIT = sys.maxsize // 2
Scrypt = rust_openssl.kdf.Scrypt
KeyDerivationFunction.register(Scrypt)
class Scrypt(KeyDerivationFunction):
def __init__(
self,
salt: bytes,
length: int,
n: int,
r: int,
p: int,
backend: typing.Any = None,
):
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
if not ossl.scrypt_supported():
raise UnsupportedAlgorithm(
"This version of OpenSSL does not support scrypt"
)
self._length = length
utils._check_bytes("salt", salt)
if n < 2 or (n & (n - 1)) != 0:
raise ValueError("n must be greater than 1 and be a power of 2.")
if r < 1:
raise ValueError("r must be greater than or equal to 1.")
if p < 1:
raise ValueError("p must be greater than or equal to 1.")
self._used = False
self._salt = salt
self._n = n
self._r = r
self._p = p
def derive(self, key_material: bytes) -> bytes:
if self._used:
raise AlreadyFinalized("Scrypt instances can only be used once.")
self._used = True
utils._check_byteslike("key_material", key_material)
return rust_openssl.kdf.derive_scrypt(
key_material,
self._salt,
self._n,
self._r,
self._p,
_MEM_LIMIT,
self._length,
)
def verify(self, key_material: bytes, expected_key: bytes) -> None:
derived_key = self.derive(key_material)
if not constant_time.bytes_eq(derived_key, expected_key):
raise InvalidKey("Keys do not match.")
__all__ = ["Scrypt"]

View File

@@ -21,7 +21,7 @@ class X963KDF(KeyDerivationFunction):
self,
algorithm: hashes.HashAlgorithm,
length: int,
sharedinfo: typing.Optional[bytes],
sharedinfo: bytes | None,
backend: typing.Any = None,
):
max_len = algorithm.digest_size * (2**32 - 1)
@@ -35,7 +35,7 @@ class X963KDF(KeyDerivationFunction):
self._sharedinfo = sharedinfo
self._used = False
def derive(self, key_material: bytes) -> bytes:
def derive(self, key_material: utils.Buffer) -> bytes:
if self._used:
raise AlreadyFinalized
self._used = True