updates
This commit is contained in:
@@ -9,6 +9,7 @@ import re
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
from cryptography import utils
|
||||
from cryptography.hazmat.bindings._rust import x509 as rust_x509
|
||||
@@ -31,7 +32,7 @@ class _ASN1Type(utils.Enum):
|
||||
|
||||
|
||||
_ASN1_TYPE_TO_ENUM = {i.value: i for i in _ASN1Type}
|
||||
_NAMEOID_DEFAULT_TYPE: typing.Dict[ObjectIdentifier, _ASN1Type] = {
|
||||
_NAMEOID_DEFAULT_TYPE: dict[ObjectIdentifier, _ASN1Type] = {
|
||||
NameOID.COUNTRY_NAME: _ASN1Type.PrintableString,
|
||||
NameOID.JURISDICTION_COUNTRY_NAME: _ASN1Type.PrintableString,
|
||||
NameOID.SERIAL_NUMBER: _ASN1Type.PrintableString,
|
||||
@@ -59,8 +60,14 @@ _NAMEOID_TO_NAME: _OidNameMap = {
|
||||
}
|
||||
_NAME_TO_NAMEOID = {v: k for k, v in _NAMEOID_TO_NAME.items()}
|
||||
|
||||
_NAMEOID_LENGTH_LIMIT = {
|
||||
NameOID.COUNTRY_NAME: (2, 2),
|
||||
NameOID.JURISDICTION_COUNTRY_NAME: (2, 2),
|
||||
NameOID.COMMON_NAME: (1, 64),
|
||||
}
|
||||
|
||||
def _escape_dn_value(val: typing.Union[str, bytes]) -> str:
|
||||
|
||||
def _escape_dn_value(val: str | bytes) -> str:
|
||||
"""Escape special characters in RFC4514 Distinguished Name value."""
|
||||
|
||||
if not val:
|
||||
@@ -108,12 +115,21 @@ def _unescape_dn_value(val: str) -> str:
|
||||
return _RFC4514NameParser._PAIR_RE.sub(sub, val)
|
||||
|
||||
|
||||
class NameAttribute:
|
||||
NameAttributeValueType = typing.TypeVar(
|
||||
"NameAttributeValueType",
|
||||
typing.Union[str, bytes],
|
||||
str,
|
||||
bytes,
|
||||
covariant=True,
|
||||
)
|
||||
|
||||
|
||||
class NameAttribute(typing.Generic[NameAttributeValueType]):
|
||||
def __init__(
|
||||
self,
|
||||
oid: ObjectIdentifier,
|
||||
value: typing.Union[str, bytes],
|
||||
_type: typing.Optional[_ASN1Type] = None,
|
||||
value: NameAttributeValueType,
|
||||
_type: _ASN1Type | None = None,
|
||||
*,
|
||||
_validate: bool = True,
|
||||
) -> None:
|
||||
@@ -128,26 +144,23 @@ class NameAttribute:
|
||||
)
|
||||
if not isinstance(value, bytes):
|
||||
raise TypeError("value must be bytes for BitString")
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("value argument must be a str")
|
||||
elif not isinstance(value, str):
|
||||
raise TypeError("value argument must be a str")
|
||||
|
||||
if (
|
||||
oid == NameOID.COUNTRY_NAME
|
||||
or oid == NameOID.JURISDICTION_COUNTRY_NAME
|
||||
):
|
||||
length_limits = _NAMEOID_LENGTH_LIMIT.get(oid)
|
||||
if length_limits is not None:
|
||||
min_length, max_length = length_limits
|
||||
assert isinstance(value, str)
|
||||
c_len = len(value.encode("utf8"))
|
||||
if c_len != 2 and _validate is True:
|
||||
raise ValueError(
|
||||
"Country name must be a 2 character country code"
|
||||
)
|
||||
elif c_len != 2:
|
||||
warnings.warn(
|
||||
"Country names should be two characters, but the "
|
||||
"attribute is {} characters in length.".format(c_len),
|
||||
stacklevel=2,
|
||||
if c_len < min_length or c_len > max_length:
|
||||
msg = (
|
||||
f"Attribute's length must be >= {min_length} and "
|
||||
f"<= {max_length}, but it was {c_len}"
|
||||
)
|
||||
if _validate is True:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
warnings.warn(msg, stacklevel=2)
|
||||
|
||||
# The appropriate ASN1 string type varies by OID and is defined across
|
||||
# multiple RFCs including 2459, 3280, and 5280. In general UTF8String
|
||||
@@ -162,15 +175,15 @@ class NameAttribute:
|
||||
raise TypeError("_type must be from the _ASN1Type enum")
|
||||
|
||||
self._oid = oid
|
||||
self._value = value
|
||||
self._type = _type
|
||||
self._value: NameAttributeValueType = value
|
||||
self._type: _ASN1Type = _type
|
||||
|
||||
@property
|
||||
def oid(self) -> ObjectIdentifier:
|
||||
return self._oid
|
||||
|
||||
@property
|
||||
def value(self) -> typing.Union[str, bytes]:
|
||||
def value(self) -> NameAttributeValueType:
|
||||
return self._value
|
||||
|
||||
@property
|
||||
@@ -182,7 +195,7 @@ class NameAttribute:
|
||||
return _NAMEOID_TO_NAME.get(self.oid, self.oid.dotted_string)
|
||||
|
||||
def rfc4514_string(
|
||||
self, attr_name_overrides: typing.Optional[_OidNameMap] = None
|
||||
self, attr_name_overrides: _OidNameMap | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Format as RFC4514 Distinguished Name string.
|
||||
@@ -208,11 +221,11 @@ class NameAttribute:
|
||||
return hash((self.oid, self.value))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<NameAttribute(oid={0.oid}, value={0.value!r})>".format(self)
|
||||
return f"<NameAttribute(oid={self.oid}, value={self.value!r})>"
|
||||
|
||||
|
||||
class RelativeDistinguishedName:
|
||||
def __init__(self, attributes: typing.Iterable[NameAttribute]):
|
||||
def __init__(self, attributes: Iterable[NameAttribute]):
|
||||
attributes = list(attributes)
|
||||
if not attributes:
|
||||
raise ValueError("a relative distinguished name cannot be empty")
|
||||
@@ -227,12 +240,13 @@ class RelativeDistinguishedName:
|
||||
raise ValueError("duplicate attributes are not allowed")
|
||||
|
||||
def get_attributes_for_oid(
|
||||
self, oid: ObjectIdentifier
|
||||
) -> typing.List[NameAttribute]:
|
||||
self,
|
||||
oid: ObjectIdentifier,
|
||||
) -> list[NameAttribute[str | bytes]]:
|
||||
return [i for i in self if i.oid == oid]
|
||||
|
||||
def rfc4514_string(
|
||||
self, attr_name_overrides: typing.Optional[_OidNameMap] = None
|
||||
self, attr_name_overrides: _OidNameMap | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Format as RFC4514 Distinguished Name string.
|
||||
@@ -254,7 +268,7 @@ class RelativeDistinguishedName:
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._attribute_set)
|
||||
|
||||
def __iter__(self) -> typing.Iterator[NameAttribute]:
|
||||
def __iter__(self) -> Iterator[NameAttribute]:
|
||||
return iter(self._attributes)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -266,20 +280,16 @@ class RelativeDistinguishedName:
|
||||
|
||||
class Name:
|
||||
@typing.overload
|
||||
def __init__(self, attributes: typing.Iterable[NameAttribute]) -> None:
|
||||
...
|
||||
def __init__(self, attributes: Iterable[NameAttribute]) -> None: ...
|
||||
|
||||
@typing.overload
|
||||
def __init__(
|
||||
self, attributes: typing.Iterable[RelativeDistinguishedName]
|
||||
) -> None:
|
||||
...
|
||||
self, attributes: Iterable[RelativeDistinguishedName]
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attributes: typing.Iterable[
|
||||
typing.Union[NameAttribute, RelativeDistinguishedName]
|
||||
],
|
||||
attributes: Iterable[NameAttribute | RelativeDistinguishedName],
|
||||
) -> None:
|
||||
attributes = list(attributes)
|
||||
if all(isinstance(x, NameAttribute) for x in attributes):
|
||||
@@ -301,12 +311,12 @@ class Name:
|
||||
def from_rfc4514_string(
|
||||
cls,
|
||||
data: str,
|
||||
attr_name_overrides: typing.Optional[_NameOidMap] = None,
|
||||
attr_name_overrides: _NameOidMap | None = None,
|
||||
) -> Name:
|
||||
return _RFC4514NameParser(data, attr_name_overrides or {}).parse()
|
||||
|
||||
def rfc4514_string(
|
||||
self, attr_name_overrides: typing.Optional[_OidNameMap] = None
|
||||
self, attr_name_overrides: _OidNameMap | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Format as RFC4514 Distinguished Name string.
|
||||
@@ -324,12 +334,13 @@ class Name:
|
||||
)
|
||||
|
||||
def get_attributes_for_oid(
|
||||
self, oid: ObjectIdentifier
|
||||
) -> typing.List[NameAttribute]:
|
||||
self,
|
||||
oid: ObjectIdentifier,
|
||||
) -> list[NameAttribute[str | bytes]]:
|
||||
return [i for i in self if i.oid == oid]
|
||||
|
||||
@property
|
||||
def rdns(self) -> typing.List[RelativeDistinguishedName]:
|
||||
def rdns(self) -> list[RelativeDistinguishedName]:
|
||||
return self._attributes
|
||||
|
||||
def public_bytes(self, backend: typing.Any = None) -> bytes:
|
||||
@@ -346,10 +357,9 @@ class Name:
|
||||
# for you, consider optimizing!
|
||||
return hash(tuple(self._attributes))
|
||||
|
||||
def __iter__(self) -> typing.Iterator[NameAttribute]:
|
||||
def __iter__(self) -> Iterator[NameAttribute]:
|
||||
for rdn in self._attributes:
|
||||
for ava in rdn:
|
||||
yield ava
|
||||
yield from rdn
|
||||
|
||||
def __len__(self) -> int:
|
||||
return sum(len(rdn) for rdn in self._attributes)
|
||||
@@ -395,7 +405,7 @@ class _RFC4514NameParser:
|
||||
def _has_data(self) -> bool:
|
||||
return self._idx < len(self._data)
|
||||
|
||||
def _peek(self) -> typing.Optional[str]:
|
||||
def _peek(self) -> str | None:
|
||||
if self._has_data():
|
||||
return self._data[self._idx]
|
||||
return None
|
||||
@@ -422,6 +432,10 @@ class _RFC4514NameParser:
|
||||
we parse it, we need to reverse again to get the RDNs on the
|
||||
correct order.
|
||||
"""
|
||||
|
||||
if not self._has_data():
|
||||
return Name([])
|
||||
|
||||
rdns = [self._parse_rdn()]
|
||||
|
||||
while self._has_data():
|
||||
|
||||
Reference in New Issue
Block a user