This commit is contained in:
Iliyan Angelov
2025-11-19 12:27:01 +02:00
parent 2043ac897c
commit 34b4c969d4
469 changed files with 26870 additions and 8329 deletions

View File

@@ -33,9 +33,25 @@ from cryptography.hazmat.primitives.serialization.ssh import (
load_ssh_private_key,
load_ssh_public_identity,
load_ssh_public_key,
ssh_key_fingerprint,
)
__all__ = [
"BestAvailableEncryption",
"Encoding",
"KeySerializationEncryption",
"NoEncryption",
"ParameterFormat",
"PrivateFormat",
"PublicFormat",
"SSHCertPrivateKeyTypes",
"SSHCertPublicKeyTypes",
"SSHCertificate",
"SSHCertificateBuilder",
"SSHCertificateType",
"SSHPrivateKeyTypes",
"SSHPublicKeyTypes",
"_KeySerializationEncryption",
"load_der_parameters",
"load_der_private_key",
"load_der_public_key",
@@ -45,19 +61,5 @@ __all__ = [
"load_ssh_private_key",
"load_ssh_public_identity",
"load_ssh_public_key",
"Encoding",
"PrivateFormat",
"PublicFormat",
"ParameterFormat",
"KeySerializationEncryption",
"BestAvailableEncryption",
"NoEncryption",
"_KeySerializationEncryption",
"SSHCertificateBuilder",
"SSHCertificate",
"SSHCertificateType",
"SSHCertPublicKeyTypes",
"SSHCertPrivateKeyTypes",
"SSHPrivateKeyTypes",
"SSHPublicKeyTypes",
"ssh_key_fingerprint",
]

View File

@@ -2,72 +2,13 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import annotations
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
import typing
load_pem_private_key = rust_openssl.keys.load_pem_private_key
load_der_private_key = rust_openssl.keys.load_der_private_key
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
PublicKeyTypes,
)
load_pem_public_key = rust_openssl.keys.load_pem_public_key
load_der_public_key = rust_openssl.keys.load_der_public_key
def load_pem_private_key(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
*,
unsafe_skip_rsa_key_validation: bool = False,
) -> PrivateKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pem_private_key(
data, password, unsafe_skip_rsa_key_validation
)
def load_pem_public_key(
data: bytes, backend: typing.Any = None
) -> PublicKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pem_public_key(data)
def load_pem_parameters(
data: bytes, backend: typing.Any = None
) -> dh.DHParameters:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pem_parameters(data)
def load_der_private_key(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
*,
unsafe_skip_rsa_key_validation: bool = False,
) -> PrivateKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_der_private_key(
data, password, unsafe_skip_rsa_key_validation
)
def load_der_public_key(
data: bytes, backend: typing.Any = None
) -> PublicKeyTypes:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_der_public_key(data)
def load_der_parameters(
data: bytes, backend: typing.Any = None
) -> dh.DHParameters:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_der_parameters(data)
load_pem_parameters = rust_openssl.dh.from_pem_parameters
load_der_parameters = rust_openssl.dh.from_der_parameters

View File

@@ -5,8 +5,10 @@
from __future__ import annotations
import typing
from collections.abc import Iterable
from cryptography import x509
from cryptography.hazmat.bindings._rust import pkcs12 as rust_pkcs12
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives._serialization import PBES as PBES
from cryptography.hazmat.primitives.asymmetric import (
@@ -20,11 +22,12 @@ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
__all__ = [
"PBES",
"PKCS12PrivateKeyTypes",
"PKCS12Certificate",
"PKCS12KeyAndCertificates",
"PKCS12PrivateKeyTypes",
"load_key_and_certificates",
"load_pkcs12",
"serialize_java_truststore",
"serialize_key_and_certificates",
]
@@ -37,51 +40,15 @@ PKCS12PrivateKeyTypes = typing.Union[
]
class PKCS12Certificate:
def __init__(
self,
cert: x509.Certificate,
friendly_name: typing.Optional[bytes],
):
if not isinstance(cert, x509.Certificate):
raise TypeError("Expecting x509.Certificate object")
if friendly_name is not None and not isinstance(friendly_name, bytes):
raise TypeError("friendly_name must be bytes or None")
self._cert = cert
self._friendly_name = friendly_name
@property
def friendly_name(self) -> typing.Optional[bytes]:
return self._friendly_name
@property
def certificate(self) -> x509.Certificate:
return self._cert
def __eq__(self, other: object) -> bool:
if not isinstance(other, PKCS12Certificate):
return NotImplemented
return (
self.certificate == other.certificate
and self.friendly_name == other.friendly_name
)
def __hash__(self) -> int:
return hash((self.certificate, self.friendly_name))
def __repr__(self) -> str:
return "<PKCS12Certificate({}, friendly_name={!r})>".format(
self.certificate, self.friendly_name
)
PKCS12Certificate = rust_pkcs12.PKCS12Certificate
class PKCS12KeyAndCertificates:
def __init__(
self,
key: typing.Optional[PrivateKeyTypes],
cert: typing.Optional[PKCS12Certificate],
additional_certs: typing.List[PKCS12Certificate],
key: PrivateKeyTypes | None,
cert: PKCS12Certificate | None,
additional_certs: list[PKCS12Certificate],
):
if key is not None and not isinstance(
key,
@@ -112,15 +79,15 @@ class PKCS12KeyAndCertificates:
self._additional_certs = additional_certs
@property
def key(self) -> typing.Optional[PrivateKeyTypes]:
def key(self) -> PrivateKeyTypes | None:
return self._key
@property
def cert(self) -> typing.Optional[PKCS12Certificate]:
def cert(self) -> PKCS12Certificate | None:
return self._cert
@property
def additional_certs(self) -> typing.List[PKCS12Certificate]:
def additional_certs(self) -> list[PKCS12Certificate]:
return self._additional_certs
def __eq__(self, other: object) -> bool:
@@ -143,28 +110,8 @@ class PKCS12KeyAndCertificates:
return fmt.format(self.key, self.cert, self.additional_certs)
def load_key_and_certificates(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
) -> typing.Tuple[
typing.Optional[PrivateKeyTypes],
typing.Optional[x509.Certificate],
typing.List[x509.Certificate],
]:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_key_and_certificates_from_pkcs12(data, password)
def load_pkcs12(
data: bytes,
password: typing.Optional[bytes],
backend: typing.Any = None,
) -> PKCS12KeyAndCertificates:
from cryptography.hazmat.backends.openssl.backend import backend as ossl
return ossl.load_pkcs12(data, password)
load_key_and_certificates = rust_pkcs12.load_key_and_certificates
load_pkcs12 = rust_pkcs12.load_pkcs12
_PKCS12CATypes = typing.Union[
@@ -173,11 +120,29 @@ _PKCS12CATypes = typing.Union[
]
def serialize_java_truststore(
certs: Iterable[PKCS12Certificate],
encryption_algorithm: serialization.KeySerializationEncryption,
) -> bytes:
if not certs:
raise ValueError("You must supply at least one cert")
if not isinstance(
encryption_algorithm, serialization.KeySerializationEncryption
):
raise TypeError(
"Key encryption algorithm must be a "
"KeySerializationEncryption instance"
)
return rust_pkcs12.serialize_java_truststore(certs, encryption_algorithm)
def serialize_key_and_certificates(
name: typing.Optional[bytes],
key: typing.Optional[PKCS12PrivateKeyTypes],
cert: typing.Optional[x509.Certificate],
cas: typing.Optional[typing.Iterable[_PKCS12CATypes]],
name: bytes | None,
key: PKCS12PrivateKeyTypes | None,
cert: x509.Certificate | None,
cas: Iterable[_PKCS12CATypes] | None,
encryption_algorithm: serialization.KeySerializationEncryption,
) -> bytes:
if key is not None and not isinstance(
@@ -194,22 +159,6 @@ def serialize_key_and_certificates(
"Key must be RSA, DSA, EllipticCurve, ED25519, or ED448"
" private key, or None."
)
if cert is not None and not isinstance(cert, x509.Certificate):
raise TypeError("cert must be a certificate or None")
if cas is not None:
cas = list(cas)
if not all(
isinstance(
val,
(
x509.Certificate,
PKCS12Certificate,
),
)
for val in cas
):
raise TypeError("all values in cas must be certificates")
if not isinstance(
encryption_algorithm, serialization.KeySerializationEncryption
@@ -222,8 +171,6 @@ def serialize_key_and_certificates(
if key is None and cert is None and not cas:
raise ValueError("You must supply at least one of key, cert, or cas")
from cryptography.hazmat.backends.openssl.backend import backend
return backend.serialize_key_and_certificates_to_pkcs12(
return rust_pkcs12.serialize_key_and_certificates(
name, key, cert, cas, encryption_algorithm
)

View File

@@ -10,32 +10,23 @@ import email.message
import email.policy
import io
import typing
from collections.abc import Iterable
from cryptography import utils, x509
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.bindings._rust import pkcs7 as rust_pkcs7
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa
from cryptography.hazmat.primitives.ciphers import (
algorithms,
)
from cryptography.utils import _check_byteslike
load_pem_pkcs7_certificates = rust_pkcs7.load_pem_pkcs7_certificates
def load_pem_pkcs7_certificates(data: bytes) -> typing.List[x509.Certificate]:
from cryptography.hazmat.backends.openssl.backend import backend
return backend.load_pem_pkcs7_certificates(data)
def load_der_pkcs7_certificates(data: bytes) -> typing.List[x509.Certificate]:
from cryptography.hazmat.backends.openssl.backend import backend
return backend.load_der_pkcs7_certificates(data)
def serialize_certificates(
certs: typing.List[x509.Certificate],
encoding: serialization.Encoding,
) -> bytes:
return rust_pkcs7.serialize_certificates(certs, encoding)
load_der_pkcs7_certificates = rust_pkcs7.load_der_pkcs7_certificates
serialize_certificates = rust_pkcs7.serialize_certificates
PKCS7HashTypes = typing.Union[
hashes.SHA224,
@@ -48,6 +39,10 @@ PKCS7PrivateKeyTypes = typing.Union[
rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey
]
ContentEncryptionAlgorithm = typing.Union[
typing.Type[algorithms.AES128], typing.Type[algorithms.AES256]
]
class PKCS7Options(utils.Enum):
Text = "Add text/plain MIME type"
@@ -61,21 +56,22 @@ class PKCS7Options(utils.Enum):
class PKCS7SignatureBuilder:
def __init__(
self,
data: typing.Optional[bytes] = None,
signers: typing.List[
typing.Tuple[
data: utils.Buffer | None = None,
signers: list[
tuple[
x509.Certificate,
PKCS7PrivateKeyTypes,
PKCS7HashTypes,
padding.PSS | padding.PKCS1v15 | None,
]
] = [],
additional_certs: typing.List[x509.Certificate] = [],
additional_certs: list[x509.Certificate] = [],
):
self._data = data
self._signers = signers
self._additional_certs = additional_certs
def set_data(self, data: bytes) -> PKCS7SignatureBuilder:
def set_data(self, data: utils.Buffer) -> PKCS7SignatureBuilder:
_check_byteslike("data", data)
if self._data is not None:
raise ValueError("data may only be set once")
@@ -87,6 +83,8 @@ class PKCS7SignatureBuilder:
certificate: x509.Certificate,
private_key: PKCS7PrivateKeyTypes,
hash_algorithm: PKCS7HashTypes,
*,
rsa_padding: padding.PSS | padding.PKCS1v15 | None = None,
) -> PKCS7SignatureBuilder:
if not isinstance(
hash_algorithm,
@@ -109,9 +107,18 @@ class PKCS7SignatureBuilder:
):
raise TypeError("Only RSA & EC keys are supported at this time.")
if rsa_padding is not None:
if not isinstance(rsa_padding, (padding.PSS, padding.PKCS1v15)):
raise TypeError("Padding must be PSS or PKCS1v15")
if not isinstance(private_key, rsa.RSAPrivateKey):
raise TypeError("Padding is only supported for RSA keys")
return PKCS7SignatureBuilder(
self._data,
self._signers + [(certificate, private_key, hash_algorithm)],
[
*self._signers,
(certificate, private_key, hash_algorithm, rsa_padding),
],
)
def add_certificate(
@@ -121,13 +128,13 @@ class PKCS7SignatureBuilder:
raise TypeError("certificate must be a x509.Certificate")
return PKCS7SignatureBuilder(
self._data, self._signers, self._additional_certs + [certificate]
self._data, self._signers, [*self._additional_certs, certificate]
)
def sign(
self,
encoding: serialization.Encoding,
options: typing.Iterable[PKCS7Options],
options: Iterable[PKCS7Options],
backend: typing.Any = None,
) -> bytes:
if len(self._signers) == 0:
@@ -179,7 +186,131 @@ class PKCS7SignatureBuilder:
return rust_pkcs7.sign_and_serialize(self, encoding, options)
def _smime_encode(
class PKCS7EnvelopeBuilder:
def __init__(
self,
*,
_data: bytes | None = None,
_recipients: list[x509.Certificate] | None = None,
_content_encryption_algorithm: ContentEncryptionAlgorithm
| None = None,
):
from cryptography.hazmat.backends.openssl.backend import (
backend as ossl,
)
if not ossl.rsa_encryption_supported(padding=padding.PKCS1v15()):
raise UnsupportedAlgorithm(
"RSA with PKCS1 v1.5 padding is not supported by this version"
" of OpenSSL.",
_Reasons.UNSUPPORTED_PADDING,
)
self._data = _data
self._recipients = _recipients if _recipients is not None else []
self._content_encryption_algorithm = _content_encryption_algorithm
def set_data(self, data: bytes) -> PKCS7EnvelopeBuilder:
_check_byteslike("data", data)
if self._data is not None:
raise ValueError("data may only be set once")
return PKCS7EnvelopeBuilder(
_data=data,
_recipients=self._recipients,
_content_encryption_algorithm=self._content_encryption_algorithm,
)
def add_recipient(
self,
certificate: x509.Certificate,
) -> PKCS7EnvelopeBuilder:
if not isinstance(certificate, x509.Certificate):
raise TypeError("certificate must be a x509.Certificate")
if not isinstance(certificate.public_key(), rsa.RSAPublicKey):
raise TypeError("Only RSA keys are supported at this time.")
return PKCS7EnvelopeBuilder(
_data=self._data,
_recipients=[
*self._recipients,
certificate,
],
_content_encryption_algorithm=self._content_encryption_algorithm,
)
def set_content_encryption_algorithm(
self, content_encryption_algorithm: ContentEncryptionAlgorithm
) -> PKCS7EnvelopeBuilder:
if self._content_encryption_algorithm is not None:
raise ValueError("Content encryption algo may only be set once")
if content_encryption_algorithm not in {
algorithms.AES128,
algorithms.AES256,
}:
raise TypeError("Only AES128 and AES256 are supported")
return PKCS7EnvelopeBuilder(
_data=self._data,
_recipients=self._recipients,
_content_encryption_algorithm=content_encryption_algorithm,
)
def encrypt(
self,
encoding: serialization.Encoding,
options: Iterable[PKCS7Options],
) -> bytes:
if len(self._recipients) == 0:
raise ValueError("Must have at least one recipient")
if self._data is None:
raise ValueError("You must add data to encrypt")
# The default content encryption algorithm is AES-128, which the S/MIME
# v3.2 RFC specifies as MUST support (https://datatracker.ietf.org/doc/html/rfc5751#section-2.7)
content_encryption_algorithm = (
self._content_encryption_algorithm or algorithms.AES128
)
options = list(options)
if not all(isinstance(x, PKCS7Options) for x in options):
raise ValueError("options must be from the PKCS7Options enum")
if encoding not in (
serialization.Encoding.PEM,
serialization.Encoding.DER,
serialization.Encoding.SMIME,
):
raise ValueError(
"Must be PEM, DER, or SMIME from the Encoding enum"
)
# Only allow options that make sense for encryption
if any(
opt not in [PKCS7Options.Text, PKCS7Options.Binary]
for opt in options
):
raise ValueError(
"Only the following options are supported for encryption: "
"Text, Binary"
)
elif PKCS7Options.Text in options and PKCS7Options.Binary in options:
# OpenSSL accepts both options at the same time, but ignores Text.
# We fail defensively to avoid unexpected outputs.
raise ValueError(
"Cannot use Binary and Text options at the same time"
)
return rust_pkcs7.encrypt_and_serialize(
self, content_encryption_algorithm, encoding, options
)
pkcs7_decrypt_der = rust_pkcs7.decrypt_der
pkcs7_decrypt_pem = rust_pkcs7.decrypt_pem
pkcs7_decrypt_smime = rust_pkcs7.decrypt_smime
def _smime_signed_encode(
data: bytes, signature: bytes, micalg: str, text_mode: bool
) -> bytes:
# This function works pretty hard to replicate what OpenSSL does
@@ -227,6 +358,51 @@ def _smime_encode(
return fp.getvalue()
def _smime_enveloped_encode(data: bytes) -> bytes:
m = email.message.Message()
m.add_header("MIME-Version", "1.0")
m.add_header("Content-Disposition", "attachment", filename="smime.p7m")
m.add_header(
"Content-Type",
"application/pkcs7-mime",
smime_type="enveloped-data",
name="smime.p7m",
)
m.add_header("Content-Transfer-Encoding", "base64")
m.set_payload(email.base64mime.body_encode(data, maxlinelen=65))
return m.as_bytes(policy=m.policy.clone(linesep="\n", max_line_length=0))
def _smime_enveloped_decode(data: bytes) -> bytes:
m = email.message_from_bytes(data)
if m.get_content_type() not in {
"application/x-pkcs7-mime",
"application/pkcs7-mime",
}:
raise ValueError("Not an S/MIME enveloped message")
return bytes(m.get_payload(decode=True))
def _smime_remove_text_headers(data: bytes) -> bytes:
m = email.message_from_bytes(data)
# Using get() instead of get_content_type() since it has None as default,
# where the latter has "text/plain". Both methods are case-insensitive.
content_type = m.get("content-type")
if content_type is None:
raise ValueError(
"Decrypted MIME data has no 'Content-Type' header. "
"Please remove the 'Text' option to parse it manually."
)
if "text/plain" not in content_type:
raise ValueError(
f"Decrypted MIME data content type is '{content_type}', not "
"'text/plain'. Remove the 'Text' option to parse it manually."
)
return bytes(m.get_payload(decode=True))
class OpenSSLMimePart(email.message.MIMEPart):
# A MIMEPart subclass that replicates OpenSSL's behavior of not including
# a newline if there are no headers.

View File

@@ -64,6 +64,10 @@ _ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
_CERT_SUFFIX = b"-cert-v01@openssh.com"
# U2F application string suffixed pubkey
_SK_SSH_ED25519 = b"sk-ssh-ed25519@openssh.com"
_SK_SSH_ECDSA_NISTP256 = b"sk-ecdsa-sha2-nistp256@openssh.com"
# These are not key types, only algorithms, so they cannot appear
# as a public key type
_SSH_RSA_SHA256 = b"rsa-sha2-256"
@@ -87,21 +91,17 @@ _PADDING = memoryview(bytearray(range(1, 1 + 16)))
@dataclass
class _SSHCipher:
alg: typing.Type[algorithms.AES]
alg: type[algorithms.AES]
key_len: int
mode: typing.Union[
typing.Type[modes.CTR],
typing.Type[modes.CBC],
typing.Type[modes.GCM],
]
mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM]
block_len: int
iv_len: int
tag_len: typing.Optional[int]
tag_len: int | None
is_aead: bool
# ciphers that are actually used in key wrapping
_SSH_CIPHERS: typing.Dict[bytes, _SSHCipher] = {
_SSH_CIPHERS: dict[bytes, _SSHCipher] = {
b"aes256-ctr": _SSHCipher(
alg=algorithms.AES,
key_len=32,
@@ -139,9 +139,7 @@ _ECDSA_KEY_TYPE = {
}
def _get_ssh_key_type(
key: typing.Union[SSHPrivateKeyTypes, SSHPublicKeyTypes]
) -> bytes:
def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes:
if isinstance(key, ec.EllipticCurvePrivateKey):
key_type = _ecdsa_key_type(key.public_key())
elif isinstance(key, ec.EllipticCurvePublicKey):
@@ -171,20 +169,20 @@ def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes:
def _ssh_pem_encode(
data: bytes,
data: utils.Buffer,
prefix: bytes = _SK_START + b"\n",
suffix: bytes = _SK_END + b"\n",
) -> bytes:
return b"".join([prefix, _base64_encode(data), suffix])
def _check_block_size(data: bytes, block_len: int) -> None:
def _check_block_size(data: utils.Buffer, block_len: int) -> None:
"""Require data to be full blocks"""
if not data or len(data) % block_len != 0:
raise ValueError("Corrupt data: missing padding")
def _check_empty(data: bytes) -> None:
def _check_empty(data: utils.Buffer) -> None:
"""All data should have been parsed."""
if data:
raise ValueError("Corrupt data: unparsed data")
@@ -192,13 +190,15 @@ def _check_empty(data: bytes) -> None:
def _init_cipher(
ciphername: bytes,
password: typing.Optional[bytes],
password: bytes | None,
salt: bytes,
rounds: int,
) -> Cipher[typing.Union[modes.CBC, modes.CTR, modes.GCM]]:
) -> Cipher[modes.CBC | modes.CTR | modes.GCM]:
"""Generate key + iv and return cipher."""
if not password:
raise ValueError("Key is password-protected.")
raise TypeError(
"Key is password-protected, but password was not provided."
)
ciph = _SSH_CIPHERS[ciphername]
seed = _bcrypt_kdf(
@@ -210,21 +210,21 @@ def _init_cipher(
)
def _get_u32(data: memoryview) -> typing.Tuple[int, memoryview]:
def _get_u32(data: memoryview) -> tuple[int, memoryview]:
"""Uint32"""
if len(data) < 4:
raise ValueError("Invalid data")
return int.from_bytes(data[:4], byteorder="big"), data[4:]
def _get_u64(data: memoryview) -> typing.Tuple[int, memoryview]:
def _get_u64(data: memoryview) -> tuple[int, memoryview]:
"""Uint64"""
if len(data) < 8:
raise ValueError("Invalid data")
return int.from_bytes(data[:8], byteorder="big"), data[8:]
def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]:
"""Bytes with u32 length prefix"""
n, data = _get_u32(data)
if n > len(data):
@@ -232,7 +232,7 @@ def _get_sshstr(data: memoryview) -> typing.Tuple[memoryview, memoryview]:
return data[:n], data[n:]
def _get_mpint(data: memoryview) -> typing.Tuple[int, memoryview]:
def _get_mpint(data: memoryview) -> tuple[int, memoryview]:
"""Big integer."""
val, data = _get_sshstr(data)
if val and val[0] > 0x7F:
@@ -253,16 +253,14 @@ def _to_mpint(val: int) -> bytes:
class _FragList:
"""Build recursive structure without data copy."""
flist: typing.List[bytes]
flist: list[utils.Buffer]
def __init__(
self, init: typing.Optional[typing.List[bytes]] = None
) -> None:
def __init__(self, init: list[utils.Buffer] | None = None) -> None:
self.flist = []
if init:
self.flist.extend(init)
def put_raw(self, val: bytes) -> None:
def put_raw(self, val: utils.Buffer) -> None:
"""Add plain bytes"""
self.flist.append(val)
@@ -274,7 +272,7 @@ class _FragList:
"""Big-endian uint64"""
self.flist.append(val.to_bytes(length=8, byteorder="big"))
def put_sshstr(self, val: typing.Union[bytes, _FragList]) -> None:
def put_sshstr(self, val: bytes | _FragList) -> None:
"""Bytes prefixed with u32 length"""
if isinstance(val, (bytes, memoryview, bytearray)):
self.put_u32(len(val))
@@ -315,7 +313,9 @@ class _SSHFormatRSA:
mpint n, e, d, iqmp, p, q
"""
def get_public(self, data: memoryview):
def get_public(
self, data: memoryview
) -> tuple[tuple[int, int], memoryview]:
"""RSA public fields"""
e, data = _get_mpint(data)
n, data = _get_mpint(data)
@@ -323,7 +323,7 @@ class _SSHFormatRSA:
def load_public(
self, data: memoryview
) -> typing.Tuple[rsa.RSAPublicKey, memoryview]:
) -> tuple[rsa.RSAPublicKey, memoryview]:
"""Make RSA public key from data."""
(e, n), data = self.get_public(data)
public_numbers = rsa.RSAPublicNumbers(e, n)
@@ -331,8 +331,8 @@ class _SSHFormatRSA:
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[rsa.RSAPrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[rsa.RSAPrivateKey, memoryview]:
"""Make RSA private key from data."""
n, data = _get_mpint(data)
e, data = _get_mpint(data)
@@ -349,7 +349,9 @@ class _SSHFormatRSA:
private_numbers = rsa.RSAPrivateNumbers(
p, q, d, dmp1, dmq1, iqmp, public_numbers
)
private_key = private_numbers.private_key()
private_key = private_numbers.private_key(
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation
)
return private_key, data
def encode_public(
@@ -385,9 +387,7 @@ class _SSHFormatDSA:
mpint p, q, g, y, x
"""
def get_public(
self, data: memoryview
) -> typing.Tuple[typing.Tuple, memoryview]:
def get_public(self, data: memoryview) -> tuple[tuple, memoryview]:
"""DSA public fields"""
p, data = _get_mpint(data)
q, data = _get_mpint(data)
@@ -397,7 +397,7 @@ class _SSHFormatDSA:
def load_public(
self, data: memoryview
) -> typing.Tuple[dsa.DSAPublicKey, memoryview]:
) -> tuple[dsa.DSAPublicKey, memoryview]:
"""Make DSA public key from data."""
(p, q, g, y), data = self.get_public(data)
parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
@@ -407,8 +407,8 @@ class _SSHFormatDSA:
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[dsa.DSAPrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[dsa.DSAPrivateKey, memoryview]:
"""Make DSA private key from data."""
(p, q, g, y), data = self.get_public(data)
x, data = _get_mpint(data)
@@ -466,7 +466,7 @@ class _SSHFormatECDSA:
def get_public(
self, data: memoryview
) -> typing.Tuple[typing.Tuple, memoryview]:
) -> tuple[tuple[memoryview, memoryview], memoryview]:
"""ECDSA public fields"""
curve, data = _get_sshstr(data)
point, data = _get_sshstr(data)
@@ -478,17 +478,17 @@ class _SSHFormatECDSA:
def load_public(
self, data: memoryview
) -> typing.Tuple[ec.EllipticCurvePublicKey, memoryview]:
) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
"""Make ECDSA public key from data."""
(curve_name, point), data = self.get_public(data)
(_, point), data = self.get_public(data)
public_key = ec.EllipticCurvePublicKey.from_encoded_point(
self.curve, point.tobytes()
)
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[ec.EllipticCurvePrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[ec.EllipticCurvePrivateKey, memoryview]:
"""Make ECDSA private key from data."""
(curve_name, point), data = self.get_public(data)
secret, data = _get_mpint(data)
@@ -531,14 +531,14 @@ class _SSHFormatEd25519:
def get_public(
self, data: memoryview
) -> typing.Tuple[typing.Tuple, memoryview]:
) -> tuple[tuple[memoryview], memoryview]:
"""Ed25519 public fields"""
point, data = _get_sshstr(data)
return (point,), data
def load_public(
self, data: memoryview
) -> typing.Tuple[ed25519.Ed25519PublicKey, memoryview]:
) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
"""Make Ed25519 public key from data."""
(point,), data = self.get_public(data)
public_key = ed25519.Ed25519PublicKey.from_public_bytes(
@@ -547,8 +547,8 @@ class _SSHFormatEd25519:
return public_key, data
def load_private(
self, data: memoryview, pubfields
) -> typing.Tuple[ed25519.Ed25519PrivateKey, memoryview]:
self, data: memoryview, pubfields, unsafe_skip_rsa_key_validation: bool
) -> tuple[ed25519.Ed25519PrivateKey, memoryview]:
"""Make Ed25519 private key from data."""
(point,), data = self.get_public(data)
keypair, data = _get_sshstr(data)
@@ -586,6 +586,70 @@ class _SSHFormatEd25519:
f_priv.put_sshstr(f_keypair)
def load_application(data) -> tuple[memoryview, memoryview]:
"""
U2F application strings
"""
application, data = _get_sshstr(data)
if not application.tobytes().startswith(b"ssh:"):
raise ValueError(
"U2F application string does not start with b'ssh:' "
f"({application})"
)
return application, data
class _SSHFormatSKEd25519:
"""
The format of a sk-ssh-ed25519@openssh.com public key is:
string "sk-ssh-ed25519@openssh.com"
string public key
string application (user-specified, but typically "ssh:")
"""
def load_public(
self, data: memoryview
) -> tuple[ed25519.Ed25519PublicKey, memoryview]:
"""Make Ed25519 public key from data."""
public_key, data = _lookup_kformat(_SSH_ED25519).load_public(data)
_, data = load_application(data)
return public_key, data
def get_public(self, data: memoryview) -> typing.NoReturn:
# Confusingly `get_public` is an entry point used by private key
# loading.
raise UnsupportedAlgorithm(
"sk-ssh-ed25519 private keys cannot be loaded"
)
class _SSHFormatSKECDSA:
"""
The format of a sk-ecdsa-sha2-nistp256@openssh.com public key is:
string "sk-ecdsa-sha2-nistp256@openssh.com"
string curve name
ec_point Q
string application (user-specified, but typically "ssh:")
"""
def load_public(
self, data: memoryview
) -> tuple[ec.EllipticCurvePublicKey, memoryview]:
"""Make ECDSA public key from data."""
public_key, data = _lookup_kformat(_ECDSA_NISTP256).load_public(data)
_, data = load_application(data)
return public_key, data
def get_public(self, data: memoryview) -> typing.NoReturn:
# Confusingly `get_public` is an entry point used by private key
# loading.
raise UnsupportedAlgorithm(
"sk-ecdsa-sha2-nistp256 private keys cannot be loaded"
)
_KEY_FORMATS = {
_SSH_RSA: _SSHFormatRSA(),
_SSH_DSA: _SSHFormatDSA(),
@@ -593,10 +657,12 @@ _KEY_FORMATS = {
_ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
_ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
_ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
_SK_SSH_ED25519: _SSHFormatSKEd25519(),
_SK_SSH_ECDSA_NISTP256: _SSHFormatSKECDSA(),
}
def _lookup_kformat(key_type: bytes):
def _lookup_kformat(key_type: utils.Buffer):
"""Return valid format or throw error"""
if not isinstance(key_type, bytes):
key_type = memoryview(key_type).tobytes()
@@ -614,9 +680,11 @@ SSHPrivateKeyTypes = typing.Union[
def load_ssh_private_key(
data: bytes,
password: typing.Optional[bytes],
data: utils.Buffer,
password: bytes | None,
backend: typing.Any = None,
*,
unsafe_skip_rsa_key_validation: bool = False,
) -> SSHPrivateKeyTypes:
"""Load private key from OpenSSH custom encoding."""
utils._check_byteslike("data", data)
@@ -648,7 +716,7 @@ def load_ssh_private_key(
pubfields, pubdata = kformat.get_public(pubdata)
_check_empty(pubdata)
if (ciphername, kdfname) != (_NONE, _NONE):
if ciphername != _NONE or kdfname != _NONE:
ciphername_bytes = ciphername.tobytes()
if ciphername_bytes not in _SSH_CIPHERS:
raise UnsupportedAlgorithm(
@@ -683,6 +751,10 @@ def load_ssh_private_key(
# should be no output from finalize
_check_empty(dec.finalize())
else:
if password:
raise TypeError(
"Password was given but private key is not encrypted."
)
# load secret data
edata, data = _get_sshstr(data)
_check_empty(data)
@@ -697,8 +769,13 @@ def load_ssh_private_key(
key_type, edata = _get_sshstr(edata)
if key_type != pub_key_type:
raise ValueError("Corrupt data: key type mismatch")
private_key, edata = kformat.load_private(edata, pubfields)
comment, edata = _get_sshstr(edata)
private_key, edata = kformat.load_private(
edata,
pubfields,
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
)
# We don't use the comment
_, edata = _get_sshstr(edata)
# yes, SSH does padding check *after* all other parsing is done.
# need to follow as it writes zero-byte padding too.
@@ -820,11 +897,11 @@ class SSHCertificate:
_serial: int,
_cctype: int,
_key_id: memoryview,
_valid_principals: typing.List[bytes],
_valid_principals: list[bytes],
_valid_after: int,
_valid_before: int,
_critical_options: typing.Dict[bytes, bytes],
_extensions: typing.Dict[bytes, bytes],
_critical_options: dict[bytes, bytes],
_extensions: dict[bytes, bytes],
_sig_type: memoryview,
_sig_key: memoryview,
_inner_sig_type: memoryview,
@@ -876,7 +953,7 @@ class SSHCertificate:
return bytes(self._key_id)
@property
def valid_principals(self) -> typing.List[bytes]:
def valid_principals(self) -> list[bytes]:
return self._valid_principals
@property
@@ -888,11 +965,11 @@ class SSHCertificate:
return self._valid_after
@property
def critical_options(self) -> typing.Dict[bytes, bytes]:
def critical_options(self) -> dict[bytes, bytes]:
return self._critical_options
@property
def extensions(self) -> typing.Dict[bytes, bytes]:
def extensions(self) -> dict[bytes, bytes]:
return self._extensions
def signature_key(self) -> SSHCertPublicKeyTypes:
@@ -952,9 +1029,9 @@ def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm:
def _load_ssh_public_identity(
data: bytes,
data: utils.Buffer,
_legacy_dsa_allowed=False,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
) -> SSHCertificate | SSHPublicKeyTypes:
utils._check_byteslike("data", data)
m = _SSH_PUBKEY_RC.match(data)
@@ -1047,13 +1124,13 @@ def _load_ssh_public_identity(
def load_ssh_public_identity(
data: bytes,
) -> typing.Union[SSHCertificate, SSHPublicKeyTypes]:
data: utils.Buffer,
) -> SSHCertificate | SSHPublicKeyTypes:
return _load_ssh_public_identity(data)
def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
result: typing.Dict[bytes, bytes] = {}
def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]:
result: dict[bytes, bytes] = {}
last_name = None
while exts_opts:
name, exts_opts = _get_sshstr(exts_opts)
@@ -1064,26 +1141,38 @@ def _parse_exts_opts(exts_opts: memoryview) -> typing.Dict[bytes, bytes]:
raise ValueError("Fields not lexically sorted")
value, exts_opts = _get_sshstr(exts_opts)
if len(value) > 0:
try:
value, extra = _get_sshstr(value)
except ValueError:
warnings.warn(
"This certificate has an incorrect encoding for critical "
"options or extensions. This will be an exception in "
"cryptography 42",
utils.DeprecatedIn41,
stacklevel=4,
)
else:
if len(extra) > 0:
raise ValueError("Unexpected extra data after value")
value, extra = _get_sshstr(value)
if len(extra) > 0:
raise ValueError("Unexpected extra data after value")
result[bname] = bytes(value)
last_name = bname
return result
def ssh_key_fingerprint(
key: SSHPublicKeyTypes,
hash_algorithm: hashes.MD5 | hashes.SHA256,
) -> bytes:
if not isinstance(hash_algorithm, (hashes.MD5, hashes.SHA256)):
raise TypeError("hash_algorithm must be either MD5 or SHA256")
key_type = _get_ssh_key_type(key)
kformat = _lookup_kformat(key_type)
f_pub = _FragList()
f_pub.put_sshstr(key_type)
kformat.encode_public(key, f_pub)
ssh_binary_data = f_pub.tobytes()
# Hash the binary data
hash_obj = hashes.Hash(hash_algorithm)
hash_obj.update(ssh_binary_data)
return hash_obj.finalize()
def load_ssh_public_key(
data: bytes, backend: typing.Any = None
data: utils.Buffer, backend: typing.Any = None
) -> SSHPublicKeyTypes:
cert_or_key = _load_ssh_public_identity(data, _legacy_dsa_allowed=True)
public_key: SSHPublicKeyTypes
@@ -1137,16 +1226,16 @@ _SSHKEY_CERT_MAX_PRINCIPALS = 256
class SSHCertificateBuilder:
def __init__(
self,
_public_key: typing.Optional[SSHCertPublicKeyTypes] = None,
_serial: typing.Optional[int] = None,
_type: typing.Optional[SSHCertificateType] = None,
_key_id: typing.Optional[bytes] = None,
_valid_principals: typing.List[bytes] = [],
_public_key: SSHCertPublicKeyTypes | None = None,
_serial: int | None = None,
_type: SSHCertificateType | None = None,
_key_id: bytes | None = None,
_valid_principals: list[bytes] = [],
_valid_for_all_principals: bool = False,
_valid_before: typing.Optional[int] = None,
_valid_after: typing.Optional[int] = None,
_critical_options: typing.List[typing.Tuple[bytes, bytes]] = [],
_extensions: typing.List[typing.Tuple[bytes, bytes]] = [],
_valid_before: int | None = None,
_valid_after: int | None = None,
_critical_options: list[tuple[bytes, bytes]] = [],
_extensions: list[tuple[bytes, bytes]] = [],
):
self._public_key = _public_key
self._serial = _serial
@@ -1247,7 +1336,7 @@ class SSHCertificateBuilder:
)
def valid_principals(
self, valid_principals: typing.List[bytes]
self, valid_principals: list[bytes]
) -> SSHCertificateBuilder:
if self._valid_for_all_principals:
raise ValueError(
@@ -1304,9 +1393,7 @@ class SSHCertificateBuilder:
_extensions=self._extensions,
)
def valid_before(
self, valid_before: typing.Union[int, float]
) -> SSHCertificateBuilder:
def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder:
if not isinstance(valid_before, (int, float)):
raise TypeError("valid_before must be an int or float")
valid_before = int(valid_before)
@@ -1328,9 +1415,7 @@ class SSHCertificateBuilder:
_extensions=self._extensions,
)
def valid_after(
self, valid_after: typing.Union[int, float]
) -> SSHCertificateBuilder:
def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder:
if not isinstance(valid_after, (int, float)):
raise TypeError("valid_after must be an int or float")
valid_after = int(valid_after)
@@ -1370,7 +1455,7 @@ class SSHCertificateBuilder:
_valid_for_all_principals=self._valid_for_all_principals,
_valid_before=self._valid_before,
_valid_after=self._valid_after,
_critical_options=self._critical_options + [(name, value)],
_critical_options=[*self._critical_options, (name, value)],
_extensions=self._extensions,
)
@@ -1393,7 +1478,7 @@ class SSHCertificateBuilder:
_valid_before=self._valid_before,
_valid_after=self._valid_after,
_critical_options=self._critical_options,
_extensions=self._extensions + [(name, value)],
_extensions=[*self._extensions, (name, value)],
)
def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate: