This commit is contained in:
Iliyan Angelov
2025-11-19 12:27:01 +02:00
parent 2043ac897c
commit 34b4c969d4
469 changed files with 26870 additions and 8329 deletions

View File

@@ -9,6 +9,7 @@ import re
import sys
import typing
import warnings
from collections.abc import Iterable, Iterator
from cryptography import utils
from cryptography.hazmat.bindings._rust import x509 as rust_x509
@@ -31,7 +32,7 @@ class _ASN1Type(utils.Enum):
_ASN1_TYPE_TO_ENUM = {i.value: i for i in _ASN1Type}
_NAMEOID_DEFAULT_TYPE: typing.Dict[ObjectIdentifier, _ASN1Type] = {
_NAMEOID_DEFAULT_TYPE: dict[ObjectIdentifier, _ASN1Type] = {
NameOID.COUNTRY_NAME: _ASN1Type.PrintableString,
NameOID.JURISDICTION_COUNTRY_NAME: _ASN1Type.PrintableString,
NameOID.SERIAL_NUMBER: _ASN1Type.PrintableString,
@@ -59,8 +60,14 @@ _NAMEOID_TO_NAME: _OidNameMap = {
}
_NAME_TO_NAMEOID = {v: k for k, v in _NAMEOID_TO_NAME.items()}
_NAMEOID_LENGTH_LIMIT = {
NameOID.COUNTRY_NAME: (2, 2),
NameOID.JURISDICTION_COUNTRY_NAME: (2, 2),
NameOID.COMMON_NAME: (1, 64),
}
def _escape_dn_value(val: typing.Union[str, bytes]) -> str:
def _escape_dn_value(val: str | bytes) -> str:
"""Escape special characters in RFC4514 Distinguished Name value."""
if not val:
@@ -108,12 +115,21 @@ def _unescape_dn_value(val: str) -> str:
return _RFC4514NameParser._PAIR_RE.sub(sub, val)
class NameAttribute:
NameAttributeValueType = typing.TypeVar(
"NameAttributeValueType",
typing.Union[str, bytes],
str,
bytes,
covariant=True,
)
class NameAttribute(typing.Generic[NameAttributeValueType]):
def __init__(
self,
oid: ObjectIdentifier,
value: typing.Union[str, bytes],
_type: typing.Optional[_ASN1Type] = None,
value: NameAttributeValueType,
_type: _ASN1Type | None = None,
*,
_validate: bool = True,
) -> None:
@@ -128,26 +144,23 @@ class NameAttribute:
)
if not isinstance(value, bytes):
raise TypeError("value must be bytes for BitString")
else:
if not isinstance(value, str):
raise TypeError("value argument must be a str")
elif not isinstance(value, str):
raise TypeError("value argument must be a str")
if (
oid == NameOID.COUNTRY_NAME
or oid == NameOID.JURISDICTION_COUNTRY_NAME
):
length_limits = _NAMEOID_LENGTH_LIMIT.get(oid)
if length_limits is not None:
min_length, max_length = length_limits
assert isinstance(value, str)
c_len = len(value.encode("utf8"))
if c_len != 2 and _validate is True:
raise ValueError(
"Country name must be a 2 character country code"
)
elif c_len != 2:
warnings.warn(
"Country names should be two characters, but the "
"attribute is {} characters in length.".format(c_len),
stacklevel=2,
if c_len < min_length or c_len > max_length:
msg = (
f"Attribute's length must be >= {min_length} and "
f"<= {max_length}, but it was {c_len}"
)
if _validate is True:
raise ValueError(msg)
else:
warnings.warn(msg, stacklevel=2)
# The appropriate ASN1 string type varies by OID and is defined across
# multiple RFCs including 2459, 3280, and 5280. In general UTF8String
@@ -162,15 +175,15 @@ class NameAttribute:
raise TypeError("_type must be from the _ASN1Type enum")
self._oid = oid
self._value = value
self._type = _type
self._value: NameAttributeValueType = value
self._type: _ASN1Type = _type
@property
def oid(self) -> ObjectIdentifier:
return self._oid
@property
def value(self) -> typing.Union[str, bytes]:
def value(self) -> NameAttributeValueType:
return self._value
@property
@@ -182,7 +195,7 @@ class NameAttribute:
return _NAMEOID_TO_NAME.get(self.oid, self.oid.dotted_string)
def rfc4514_string(
self, attr_name_overrides: typing.Optional[_OidNameMap] = None
self, attr_name_overrides: _OidNameMap | None = None
) -> str:
"""
Format as RFC4514 Distinguished Name string.
@@ -208,11 +221,11 @@ class NameAttribute:
return hash((self.oid, self.value))
def __repr__(self) -> str:
return "<NameAttribute(oid={0.oid}, value={0.value!r})>".format(self)
return f"<NameAttribute(oid={self.oid}, value={self.value!r})>"
class RelativeDistinguishedName:
def __init__(self, attributes: typing.Iterable[NameAttribute]):
def __init__(self, attributes: Iterable[NameAttribute]):
attributes = list(attributes)
if not attributes:
raise ValueError("a relative distinguished name cannot be empty")
@@ -227,12 +240,13 @@ class RelativeDistinguishedName:
raise ValueError("duplicate attributes are not allowed")
def get_attributes_for_oid(
self, oid: ObjectIdentifier
) -> typing.List[NameAttribute]:
self,
oid: ObjectIdentifier,
) -> list[NameAttribute[str | bytes]]:
return [i for i in self if i.oid == oid]
def rfc4514_string(
self, attr_name_overrides: typing.Optional[_OidNameMap] = None
self, attr_name_overrides: _OidNameMap | None = None
) -> str:
"""
Format as RFC4514 Distinguished Name string.
@@ -254,7 +268,7 @@ class RelativeDistinguishedName:
def __hash__(self) -> int:
return hash(self._attribute_set)
def __iter__(self) -> typing.Iterator[NameAttribute]:
def __iter__(self) -> Iterator[NameAttribute]:
return iter(self._attributes)
def __len__(self) -> int:
@@ -266,20 +280,16 @@ class RelativeDistinguishedName:
class Name:
@typing.overload
def __init__(self, attributes: typing.Iterable[NameAttribute]) -> None:
...
def __init__(self, attributes: Iterable[NameAttribute]) -> None: ...
@typing.overload
def __init__(
self, attributes: typing.Iterable[RelativeDistinguishedName]
) -> None:
...
self, attributes: Iterable[RelativeDistinguishedName]
) -> None: ...
def __init__(
self,
attributes: typing.Iterable[
typing.Union[NameAttribute, RelativeDistinguishedName]
],
attributes: Iterable[NameAttribute | RelativeDistinguishedName],
) -> None:
attributes = list(attributes)
if all(isinstance(x, NameAttribute) for x in attributes):
@@ -301,12 +311,12 @@ class Name:
def from_rfc4514_string(
cls,
data: str,
attr_name_overrides: typing.Optional[_NameOidMap] = None,
attr_name_overrides: _NameOidMap | None = None,
) -> Name:
return _RFC4514NameParser(data, attr_name_overrides or {}).parse()
def rfc4514_string(
self, attr_name_overrides: typing.Optional[_OidNameMap] = None
self, attr_name_overrides: _OidNameMap | None = None
) -> str:
"""
Format as RFC4514 Distinguished Name string.
@@ -324,12 +334,13 @@ class Name:
)
def get_attributes_for_oid(
self, oid: ObjectIdentifier
) -> typing.List[NameAttribute]:
self,
oid: ObjectIdentifier,
) -> list[NameAttribute[str | bytes]]:
return [i for i in self if i.oid == oid]
@property
def rdns(self) -> typing.List[RelativeDistinguishedName]:
def rdns(self) -> list[RelativeDistinguishedName]:
return self._attributes
def public_bytes(self, backend: typing.Any = None) -> bytes:
@@ -346,10 +357,9 @@ class Name:
# for you, consider optimizing!
return hash(tuple(self._attributes))
def __iter__(self) -> typing.Iterator[NameAttribute]:
def __iter__(self) -> Iterator[NameAttribute]:
for rdn in self._attributes:
for ava in rdn:
yield ava
yield from rdn
def __len__(self) -> int:
return sum(len(rdn) for rdn in self._attributes)
@@ -395,7 +405,7 @@ class _RFC4514NameParser:
def _has_data(self) -> bool:
return self._idx < len(self._data)
def _peek(self) -> typing.Optional[str]:
def _peek(self) -> str | None:
if self._has_data():
return self._data[self._idx]
return None
@@ -422,6 +432,10 @@ class _RFC4514NameParser:
we parse it, we need to reverse again to get the RDNs on the
correct order.
"""
if not self._has_data():
return Name([])
rdns = [self._parse_rdn()]
while self._has_data():