This commit is contained in:
Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,49 @@
#-*- coding: utf-8 -*-
"""
This package is an implementation of the OpenID specification in
Python. It contains code for both server and consumer
implementations. For information on implementing an OpenID consumer,
see the C{L{openid.consumer.consumer}} module. For information on
implementing an OpenID server, see the C{L{openid.server.server}}
module.
@contact: U{http://github.com/necaris/python3-openid/}
@copyright: (C) 2005-2008 JanRain, Inc., 2012-2017 Rami Chowdhury
@license: Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
U{http://www.apache.org/licenses/LICENSE-2.0}
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions
and limitations under the License.
"""
version_info = (3, 2, 0)
__version__ = ".".join(str(x) for x in version_info)
__all__ = [
'association',
'consumer',
'cryptutil',
'dh',
'extension',
'extensions',
'fetchers',
'kvform',
'message',
'oidutil',
'server',
'sreg',
'store',
'urinorm',
'yadis',
]

View File

@@ -0,0 +1,562 @@
#-*-test-case-name: openid.test.test_association-*-
#-*- coding: utf-8 -*-
"""
This module contains code for dealing with associations between
consumers and servers. Associations contain a shared secret that is
used to sign C{openid.mode=id_res} messages.
Users of the library should not usually need to interact directly with
associations. The L{store<openid.store>}, L{server<openid.server.server>}
and L{consumer<openid.consumer.consumer>} objects will create and manage
the associations. The consumer and server code will make use of a
C{L{SessionNegotiator}} when managing associations, which enables
users to express a preference for what kind of associations should be
allowed, and what kind of exchange should be done to establish the
association.
@var default_negotiator: A C{L{SessionNegotiator}} that allows all
association types that are specified by the OpenID
specification. It prefers to use HMAC-SHA1/DH-SHA1, if it's
available. If HMAC-SHA256 is not supported by your Python runtime,
HMAC-SHA256 and DH-SHA256 will not be available.
@var encrypted_negotiator: A C{L{SessionNegotiator}} that
does not support C{'no-encryption'} associations. It prefers
HMAC-SHA1/DH-SHA1 association types if available.
"""
import time
import functools
from openid import cryptutil
from openid import kvform
from openid import oidutil
from openid.message import OPENID_NS
__all__ = [
'default_negotiator',
'encrypted_negotiator',
'SessionNegotiator',
'Association',
]
all_association_types = [
'HMAC-SHA1',
'HMAC-SHA256',
]
if hasattr(cryptutil, 'hmacSha256'):
supported_association_types = list(all_association_types)
default_association_order = [
('HMAC-SHA1', 'DH-SHA1'),
('HMAC-SHA1', 'no-encryption'),
('HMAC-SHA256', 'DH-SHA256'),
('HMAC-SHA256', 'no-encryption'),
]
only_encrypted_association_order = [
('HMAC-SHA1', 'DH-SHA1'),
('HMAC-SHA256', 'DH-SHA256'),
]
else:
supported_association_types = ['HMAC-SHA1']
default_association_order = [
('HMAC-SHA1', 'DH-SHA1'),
('HMAC-SHA1', 'no-encryption'),
]
only_encrypted_association_order = [
('HMAC-SHA1', 'DH-SHA1'),
]
def getSessionTypes(assoc_type):
"""Return the allowed session types for a given association type"""
assoc_to_session = {
'HMAC-SHA1': ['DH-SHA1', 'no-encryption'],
'HMAC-SHA256': ['DH-SHA256', 'no-encryption'],
}
return assoc_to_session.get(assoc_type, [])
def checkSessionType(assoc_type, session_type):
"""Check to make sure that this pair of assoc type and session
type are allowed"""
if session_type not in getSessionTypes(assoc_type):
raise ValueError('Session type %r not valid for assocation type %r' %
(session_type, assoc_type))
class SessionNegotiator(object):
"""A session negotiator controls the allowed and preferred
association types and association session types. Both the
C{L{Consumer<openid.consumer.consumer.Consumer>}} and
C{L{Server<openid.server.server.Server>}} use negotiators when
creating associations.
You can create and use negotiators if you:
- Do not want to do Diffie-Hellman key exchange because you use
transport-layer encryption (e.g. SSL)
- Want to use only SHA-256 associations
- Do not want to support plain-text associations over a non-secure
channel
It is up to you to set a policy for what kinds of associations to
accept. By default, the library will make any kind of association
that is allowed in the OpenID 2.0 specification.
Use of negotiators in the library
=================================
When a consumer makes an association request, it calls
C{L{getAllowedType}} to get the preferred association type and
association session type.
The server gets a request for a particular association/session
type and calls C{L{isAllowed}} to determine if it should
create an association. If it is supported, negotiation is
complete. If it is not, the server calls C{L{getAllowedType}} to
get an allowed association type to return to the consumer.
If the consumer gets an error response indicating that the
requested association/session type is not supported by the server
that contains an assocation/session type to try, it calls
C{L{isAllowed}} to determine if it should try again with the
given combination of association/session type.
@ivar allowed_types: A list of association/session types that are
allowed by the server. The order of the pairs in this list
determines preference. If an association/session type comes
earlier in the list, the library is more likely to use that
type.
@type allowed_types: [(str, str)]
"""
def __init__(self, allowed_types):
self.setAllowedTypes(allowed_types)
def copy(self):
return self.__class__(list(self.allowed_types))
def setAllowedTypes(self, allowed_types):
"""Set the allowed association types, checking to make sure
each combination is valid."""
for (assoc_type, session_type) in allowed_types:
checkSessionType(assoc_type, session_type)
self.allowed_types = allowed_types
def addAllowedType(self, assoc_type, session_type=None):
"""Add an association type and session type to the allowed
types list. The assocation/session pairs are tried in the
order that they are added."""
if self.allowed_types is None:
self.allowed_types = []
if session_type is None:
available = getSessionTypes(assoc_type)
if not available:
raise ValueError('No session available for association type %r'
% (assoc_type, ))
for session_type in getSessionTypes(assoc_type):
self.addAllowedType(assoc_type, session_type)
else:
checkSessionType(assoc_type, session_type)
self.allowed_types.append((assoc_type, session_type))
def isAllowed(self, assoc_type, session_type):
"""Is this combination of association type and session type allowed?"""
assoc_good = (assoc_type, session_type) in self.allowed_types
matches = session_type in getSessionTypes(assoc_type)
return assoc_good and matches
def getAllowedType(self):
"""Get a pair of assocation type and session type that are
supported"""
try:
return self.allowed_types[0]
except IndexError:
return (None, None)
default_negotiator = SessionNegotiator(default_association_order)
encrypted_negotiator = SessionNegotiator(only_encrypted_association_order)
def getSecretSize(assoc_type):
if assoc_type == 'HMAC-SHA1':
return 20
elif assoc_type == 'HMAC-SHA256':
return 32
else:
raise ValueError('Unsupported association type: %r' % (assoc_type, ))
@functools.total_ordering
class Association(object):
"""
This class represents an association between a server and a
consumer. In general, users of this library will never see
instances of this object. The only exception is if you implement
a custom C{L{OpenIDStore<openid.store.interface.OpenIDStore>}}.
If you do implement such a store, it will need to store the values
of the C{L{handle}}, C{L{secret}}, C{L{issued}}, C{L{lifetime}}, and
C{L{assoc_type}} instance variables.
@ivar handle: This is the handle the server gave this association.
@type handle: C{str}
@ivar secret: This is the shared secret the server generated for
this association.
@type secret: C{str}
@ivar issued: This is the time this association was issued, in
seconds since 00:00 GMT, January 1, 1970. (ie, a unix
timestamp)
@type issued: C{int}
@ivar lifetime: This is the amount of time this association is
good for, measured in seconds since the association was
issued.
@type lifetime: C{int}
@ivar assoc_type: This is the type of association this instance
represents. The only valid value of this field at this time
is C{'HMAC-SHA1'}, but new types may be defined in the future.
@type assoc_type: C{str}
@sort: __init__, fromExpiresIn, expiresIn, __eq__, __ne__,
handle, secret, issued, lifetime, assoc_type
"""
# The ordering and name of keys as stored by serialize
assoc_keys = [
'version',
'handle',
'secret',
'issued',
'lifetime',
'assoc_type',
]
_macs = {
'HMAC-SHA1': cryptutil.hmacSha1,
'HMAC-SHA256': cryptutil.hmacSha256,
}
@classmethod
def fromExpiresIn(cls, expires_in, handle, secret, assoc_type):
"""
This is an alternate constructor used by the OpenID consumer
library to create associations. C{L{OpenIDStore
<openid.store.interface.OpenIDStore>}} implementations
shouldn't use this constructor.
@param expires_in: This is the amount of time this association
is good for, measured in seconds since the association was
issued.
@type expires_in: C{int}
@param handle: This is the handle the server gave this
association.
@type handle: C{str}
@param secret: This is the shared secret the server generated
for this association.
@type secret: C{str}
@param assoc_type: This is the type of association this
instance represents. The only valid value of this field
at this time is C{'HMAC-SHA1'}, but new types may be
defined in the future.
@type assoc_type: C{str}
"""
issued = int(time.time())
lifetime = expires_in
return cls(handle, secret, issued, lifetime, assoc_type)
def __init__(self, handle, secret, issued, lifetime, assoc_type):
"""
This is the standard constructor for creating an association.
@param handle: This is the handle the server gave this
association.
@type handle: C{str}
@param secret: This is the shared secret the server generated
for this association.
@type secret: C{str}
@param issued: This is the time this association was issued,
in seconds since 00:00 GMT, January 1, 1970. (ie, a unix
timestamp)
@type issued: C{int}
@param lifetime: This is the amount of time this association
is good for, measured in seconds since the association was
issued.
@type lifetime: C{int}
@param assoc_type: This is the type of association this
instance represents. The only valid value of this field
at this time is C{'HMAC-SHA1'}, but new types may be
defined in the future.
@type assoc_type: C{str}
"""
if assoc_type not in all_association_types:
fmt = '%r is not a supported association type'
raise ValueError(fmt % (assoc_type, ))
# secret_size = getSecretSize(assoc_type)
# if len(secret) != secret_size:
# fmt = 'Wrong size secret (%s bytes) for association type %s'
# raise ValueError(fmt % (len(secret), assoc_type))
self.handle = handle
if isinstance(secret, str):
secret = secret.encode("utf-8") # should be bytes
self.secret = secret
self.issued = issued
self.lifetime = lifetime
self.assoc_type = assoc_type
@property
def expiresIn(self, now=None):
"""
This returns the number of seconds this association is still
valid for, or C{0} if the association is no longer valid.
@return: The number of seconds this association is still valid
for, or C{0} if the association is no longer valid.
@rtype: C{int}
"""
if now is None:
now = int(time.time())
return max(0, self.issued + self.lifetime - now)
def __lt__(self, other):
"""
Compare two C{L{Association}} instances to determine relative
ordering.
Currently compares object lifetimes -- C{L{Association}} A < B
if A.lifetime < B.lifetime.
"""
return self.lifetime < other.lifetime
def __eq__(self, other):
"""
This checks to see if two C{L{Association}} instances
represent the same association.
@return: C{True} if the two instances represent the same
association, C{False} otherwise.
@rtype: C{bool}
"""
return type(self) is type(other) and self.__dict__ == other.__dict__
def __ne__(self, other):
"""
This checks to see if two C{L{Association}} instances
represent different associations.
@return: C{True} if the two instances represent different
associations, C{False} otherwise.
@rtype: C{bool}
"""
return not (self == other)
def serialize(self):
"""
Convert an association to KV form.
@return: String in KV form suitable for deserialization by
deserialize.
@rtype: str
"""
data = {
'version': '2',
'handle': self.handle,
'secret': oidutil.toBase64(self.secret),
'issued': str(int(self.issued)),
'lifetime': str(int(self.lifetime)),
'assoc_type': self.assoc_type
}
assert len(data) == len(self.assoc_keys)
pairs = []
for field_name in self.assoc_keys:
pairs.append((field_name, data[field_name]))
return kvform.seqToKV(pairs, strict=True)
@classmethod
def deserialize(cls, assoc_s):
"""
Parse an association as stored by serialize().
inverse of serialize
@param assoc_s: Association as serialized by serialize()
@type assoc_s: bytes
@return: instance of this class
"""
pairs = kvform.kvToSeq(assoc_s, strict=True)
keys = []
values = []
for k, v in pairs:
keys.append(k)
values.append(v)
if keys != cls.assoc_keys:
raise ValueError('Unexpected key values: %r', keys)
version, handle, secret, issued, lifetime, assoc_type = values
if version != '2':
raise ValueError('Unknown version: %r' % version)
issued = int(issued)
lifetime = int(lifetime)
secret = oidutil.fromBase64(secret)
return cls(handle, secret, issued, lifetime, assoc_type)
def sign(self, pairs):
"""
Generate a signature for a sequence of (key, value) pairs
@param pairs: The pairs to sign, in order
@type pairs: sequence of (str, str)
@return: The binary signature of this sequence of pairs
@rtype: bytes
"""
kv = kvform.seqToKV(pairs)
try:
mac = self._macs[self.assoc_type]
except KeyError:
raise ValueError('Unknown association type: %r' %
(self.assoc_type, ))
return mac(self.secret, kv)
def getMessageSignature(self, message):
"""Return the signature of a message.
If I am not a sign-all association, the message must have a
signed list.
@return: the signature, base64 encoded
@rtype: bytes
@raises ValueError: If there is no signed list and I am not a sign-all
type of association.
"""
pairs = self._makePairs(message)
return oidutil.toBase64(self.sign(pairs))
def signMessage(self, message):
"""Add a signature (and a signed list) to a message.
@return: a new Message object with a signature
@rtype: L{openid.message.Message}
"""
if (message.hasKey(OPENID_NS, 'sig') or
message.hasKey(OPENID_NS, 'signed')):
raise ValueError('Message already has signed list or signature')
extant_handle = message.getArg(OPENID_NS, 'assoc_handle')
if extant_handle and extant_handle != self.handle:
raise ValueError("Message has a different association handle")
signed_message = message.copy()
signed_message.setArg(OPENID_NS, 'assoc_handle', self.handle)
message_keys = list(signed_message.toPostArgs().keys())
signed_list = [k[7:] for k in message_keys if k.startswith('openid.')]
signed_list.append('signed')
signed_list.sort()
signed_message.setArg(OPENID_NS, 'signed', ','.join(signed_list))
sig = self.getMessageSignature(signed_message)
signed_message.setArg(OPENID_NS, 'sig', sig)
return signed_message
def checkMessageSignature(self, message):
"""Given a message with a signature, calculate a new signature
and return whether it matches the signature in the message.
@raises ValueError: if the message has no signature or no signature
can be calculated for it.
"""
message_sig = message.getArg(OPENID_NS, 'sig')
if not message_sig:
raise ValueError("%s has no sig." % (message, ))
calculated_sig = self.getMessageSignature(message)
# remember, getMessageSignature returns bytes
calculated_sig = calculated_sig.decode('utf-8')
return cryptutil.const_eq(calculated_sig, message_sig)
def _makePairs(self, message):
signed = message.getArg(OPENID_NS, 'signed')
if not signed:
raise ValueError('Message has no signed list: %s' % (message, ))
signed_list = signed.split(',')
pairs = []
data = message.toPostArgs()
for field in signed_list:
pairs.append((field, data.get('openid.' + field, '')))
return pairs
def __repr__(self):
return "<%s.%s %s %s>" % (self.__class__.__module__,
self.__class__.__name__, self.assoc_type,
self.handle)

View File

@@ -0,0 +1,91 @@
import codecs
try:
chr(0x10000)
except ValueError:
# narrow python build
UCSCHAR = [
(0xA0, 0xD7FF),
(0xF900, 0xFDCF),
(0xFDF0, 0xFFEF),
]
IPRIVATE = [
(0xE000, 0xF8FF),
]
else:
UCSCHAR = [
(0xA0, 0xD7FF),
(0xF900, 0xFDCF),
(0xFDF0, 0xFFEF),
(0x10000, 0x1FFFD),
(0x20000, 0x2FFFD),
(0x30000, 0x3FFFD),
(0x40000, 0x4FFFD),
(0x50000, 0x5FFFD),
(0x60000, 0x6FFFD),
(0x70000, 0x7FFFD),
(0x80000, 0x8FFFD),
(0x90000, 0x9FFFD),
(0xA0000, 0xAFFFD),
(0xB0000, 0xBFFFD),
(0xC0000, 0xCFFFD),
(0xD0000, 0xDFFFD),
(0xE1000, 0xEFFFD),
]
IPRIVATE = [
(0xE000, 0xF8FF),
(0xF0000, 0xFFFFD),
(0x100000, 0x10FFFD),
]
_ESCAPE_RANGES = UCSCHAR + IPRIVATE
def _in_escape_range(octet):
for start, end in _ESCAPE_RANGES:
if start <= octet <= end:
return True
return False
def _starts_surrogate_pair(character):
char_value = ord(character)
return 0xD800 <= char_value <= 0xDBFF
def _ends_surrogate_pair(character):
char_value = ord(character)
return 0xDC00 <= char_value <= 0xDFFF
def _pct_encoded_replacements(chunk):
replacements = []
chunk_iter = iter(chunk)
for character in chunk_iter:
codepoint = ord(character)
if _in_escape_range(codepoint):
for char in chr(codepoint).encode("utf-8"):
replacements.append("%%%X" % char)
elif _starts_surrogate_pair(character):
next_character = next(chunk_iter)
for char in (character + next_character).encode("utf-8"):
replacements.append("%%%X" % char)
else:
replacements.append(chr(codepoint))
return replacements
def _pct_escape_handler(err):
'''
Encoding error handler that does percent-escaping of Unicode, to be used
with codecs.register_error
TODO: replace use of this with urllib.parse.quote as appropriate
'''
chunk = err.object[err.start:err.end]
replacements = _pct_encoded_replacements(chunk)
return ("".join(replacements), err.end)
codecs.register_error("oid_percent_escape", _pct_escape_handler)

View File

@@ -0,0 +1,6 @@
"""
This package contains the portions of the library used only when
implementing an OpenID consumer.
"""
__all__ = ['consumer', 'discover']

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,467 @@
# -*- test-case-name: openid.test.test_discover -*-
"""Functions to discover OpenID endpoints from identifiers.
"""
__all__ = [
'DiscoveryFailure',
'OPENID_1_0_NS',
'OPENID_1_0_TYPE',
'OPENID_1_1_TYPE',
'OPENID_2_0_TYPE',
'OPENID_IDP_2_0_TYPE',
'OpenIDServiceEndpoint',
'discover',
]
import urllib.parse
import logging
from openid import fetchers, urinorm
from openid import yadis
from openid.yadis.etxrd import nsTag, XRDSError, XRD_NS_2_0
from openid.yadis.services import applyFilter as extractServices
from openid.yadis.discover import discover as yadisDiscover
from openid.yadis.discover import DiscoveryFailure
from openid.yadis import xrires, filters
from openid.yadis import xri
from openid.consumer import html_parse
OPENID_1_0_NS = 'http://openid.net/xmlns/1.0'
OPENID_IDP_2_0_TYPE = 'http://specs.openid.net/auth/2.0/server'
OPENID_2_0_TYPE = 'http://specs.openid.net/auth/2.0/signon'
OPENID_1_1_TYPE = 'http://openid.net/signon/1.1'
OPENID_1_0_TYPE = 'http://openid.net/signon/1.0'
from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS
from openid.message import OPENID2_NS as OPENID_2_0_MESSAGE_NS
logger = logging.getLogger(__name__)
class OpenIDServiceEndpoint(object):
"""Object representing an OpenID service endpoint.
@ivar identity_url: the verified identifier.
@ivar canonicalID: For XRI, the persistent identifier.
"""
# OpenID service type URIs, listed in order of preference. The
# ordering of this list affects yadis and XRI service discovery.
openid_type_uris = [
OPENID_IDP_2_0_TYPE,
OPENID_2_0_TYPE,
OPENID_1_1_TYPE,
OPENID_1_0_TYPE,
]
def __init__(self):
self.claimed_id = None
self.server_url = None
self.type_uris = []
self.local_id = None
self.canonicalID = None
self.used_yadis = False # whether this came from an XRDS
self.display_identifier = None
def usesExtension(self, extension_uri):
return extension_uri in self.type_uris
def preferredNamespace(self):
if (OPENID_IDP_2_0_TYPE in self.type_uris or
OPENID_2_0_TYPE in self.type_uris):
return OPENID_2_0_MESSAGE_NS
else:
return OPENID_1_0_MESSAGE_NS
def supportsType(self, type_uri):
"""Does this endpoint support this type?
I consider C{/server} endpoints to implicitly support C{/signon}.
"""
return ((type_uri in self.type_uris) or
(type_uri == OPENID_2_0_TYPE and self.isOPIdentifier()))
def getDisplayIdentifier(self):
"""Return the display_identifier if set, else return the claimed_id.
"""
if self.display_identifier is not None:
return self.display_identifier
if self.claimed_id is None:
return None
else:
return urllib.parse.urldefrag(self.claimed_id)[0]
def compatibilityMode(self):
return self.preferredNamespace() != OPENID_2_0_MESSAGE_NS
def isOPIdentifier(self):
return OPENID_IDP_2_0_TYPE in self.type_uris
def parseService(self, yadis_url, uri, type_uris, service_element):
"""Set the state of this object based on the contents of the
service element."""
self.type_uris = type_uris
self.server_url = uri
self.used_yadis = True
if not self.isOPIdentifier():
# XXX: This has crappy implications for Service elements
# that contain both 'server' and 'signon' Types. But
# that's a pathological configuration anyway, so I don't
# think I care.
self.local_id = findOPLocalIdentifier(service_element,
self.type_uris)
self.claimed_id = yadis_url
def getLocalID(self):
"""Return the identifier that should be sent as the
openid.identity parameter to the server."""
# I looked at this conditional and thought "ah-hah! there's the bug!"
# but Python actually makes that one big expression somehow, i.e.
# "x is x is x" is not the same thing as "(x is x) is x".
# That's pretty weird, dude. -- kmt, 1/07
if (self.local_id is self.canonicalID is None):
return self.claimed_id
else:
return self.local_id or self.canonicalID
def fromBasicServiceEndpoint(cls, endpoint):
"""Create a new instance of this class from the endpoint
object passed in.
@return: None or OpenIDServiceEndpoint for this endpoint object"""
type_uris = endpoint.matchTypes(cls.openid_type_uris)
# If any Type URIs match and there is an endpoint URI
# specified, then this is an OpenID endpoint
if type_uris and endpoint.uri is not None:
openid_endpoint = cls()
openid_endpoint.parseService(endpoint.yadis_url, endpoint.uri,
endpoint.type_uris,
endpoint.service_element)
else:
openid_endpoint = None
return openid_endpoint
fromBasicServiceEndpoint = classmethod(fromBasicServiceEndpoint)
def fromHTML(cls, uri, html):
"""Parse the given document as HTML looking for an OpenID <link
rel=...>
@rtype: [OpenIDServiceEndpoint]
"""
discovery_types = [
(OPENID_2_0_TYPE, 'openid2.provider', 'openid2.local_id'),
(OPENID_1_1_TYPE, 'openid.server', 'openid.delegate'),
]
link_attrs = html_parse.parseLinkAttrs(html)
services = []
for type_uri, op_endpoint_rel, local_id_rel in discovery_types:
op_endpoint_url = html_parse.findFirstHref(link_attrs,
op_endpoint_rel)
if op_endpoint_url is None:
continue
service = cls()
service.claimed_id = uri
service.local_id = html_parse.findFirstHref(link_attrs,
local_id_rel)
service.server_url = op_endpoint_url
service.type_uris = [type_uri]
services.append(service)
return services
fromHTML = classmethod(fromHTML)
def fromXRDS(cls, uri, xrds):
"""Parse the given document as XRDS looking for OpenID services.
@rtype: [OpenIDServiceEndpoint]
@raises XRDSError: When the XRDS does not parse.
@since: 2.1.0
"""
return extractServices(uri, xrds, cls)
fromXRDS = classmethod(fromXRDS)
def fromDiscoveryResult(cls, discoveryResult):
"""Create endpoints from a DiscoveryResult.
@type discoveryResult: L{DiscoveryResult}
@rtype: list of L{OpenIDServiceEndpoint}
@raises XRDSError: When the XRDS does not parse.
@since: 2.1.0
"""
if discoveryResult.isXRDS():
method = cls.fromXRDS
else:
method = cls.fromHTML
return method(discoveryResult.normalized_uri,
discoveryResult.response_text)
fromDiscoveryResult = classmethod(fromDiscoveryResult)
def fromOPEndpointURL(cls, op_endpoint_url):
"""Construct an OP-Identifier OpenIDServiceEndpoint object for
a given OP Endpoint URL
@param op_endpoint_url: The URL of the endpoint
@rtype: OpenIDServiceEndpoint
"""
service = cls()
service.server_url = op_endpoint_url
service.type_uris = [OPENID_IDP_2_0_TYPE]
return service
fromOPEndpointURL = classmethod(fromOPEndpointURL)
def __str__(self):
return ("<%s.%s "
"server_url=%r "
"claimed_id=%r "
"local_id=%r "
"canonicalID=%r "
"used_yadis=%s "
">" % (self.__class__.__module__, self.__class__.__name__,
self.server_url, self.claimed_id, self.local_id,
self.canonicalID, self.used_yadis))
def findOPLocalIdentifier(service_element, type_uris):
"""Find the OP-Local Identifier for this xrd:Service element.
This considers openid:Delegate to be a synonym for xrd:LocalID if
both OpenID 1.X and OpenID 2.0 types are present. If only OpenID
1.X is present, it returns the value of openid:Delegate. If only
OpenID 2.0 is present, it returns the value of xrd:LocalID. If
there is more than one LocalID tag and the values are different,
it raises a DiscoveryFailure. This is also triggered when the
xrd:LocalID and openid:Delegate tags are different.
@param service_element: The xrd:Service element
@type service_element: ElementTree.Node
@param type_uris: The xrd:Type values present in this service
element. This function could extract them, but higher level
code needs to do that anyway.
@type type_uris: [str]
@raises DiscoveryFailure: when discovery fails.
@returns: The OP-Local Identifier for this service element, if one
is present, or None otherwise.
@rtype: str or unicode or NoneType
"""
# XXX: Test this function on its own!
# Build the list of tags that could contain the OP-Local Identifier
local_id_tags = []
if (OPENID_1_1_TYPE in type_uris or OPENID_1_0_TYPE in type_uris):
local_id_tags.append(nsTag(OPENID_1_0_NS, 'Delegate'))
if OPENID_2_0_TYPE in type_uris:
local_id_tags.append(nsTag(XRD_NS_2_0, 'LocalID'))
# Walk through all the matching tags and make sure that they all
# have the same value
local_id = None
for local_id_tag in local_id_tags:
for local_id_element in service_element.findall(local_id_tag):
if local_id is None:
local_id = local_id_element.text
elif local_id != local_id_element.text:
format = 'More than one %r tag found in one service element'
message = format % (local_id_tag, )
raise DiscoveryFailure(message, None)
return local_id
def normalizeURL(url):
"""Normalize a URL, converting normalization failures to
DiscoveryFailure"""
try:
normalized = urinorm.urinorm(url)
except ValueError as why:
raise DiscoveryFailure('Normalizing identifier: %s' % (why, ), None)
else:
return urllib.parse.urldefrag(normalized)[0]
def normalizeXRI(xri):
"""Normalize an XRI, stripping its scheme if present"""
if xri.startswith("xri://"):
xri = xri[6:]
return xri
def arrangeByType(service_list, preferred_types):
"""Rearrange service_list in a new list so services are ordered by
types listed in preferred_types. Return the new list."""
def enumerate(elts):
"""Return an iterable that pairs the index of an element with
that element.
For Python 2.2 compatibility"""
return list(zip(list(range(len(elts))), elts))
def bestMatchingService(service):
"""Return the index of the first matching type, or something
higher if no type matches.
This provides an ordering in which service elements that
contain a type that comes earlier in the preferred types list
come before service elements that come later. If a service
element has more than one type, the most preferred one wins.
"""
for i, t in enumerate(preferred_types):
if preferred_types[i] in service.type_uris:
return i
return len(preferred_types)
# Build a list with the service elements in tuples whose
# comparison will prefer the one with the best matching service
prio_services = [(bestMatchingService(s), orig_index, s)
for (orig_index, s) in enumerate(service_list)]
prio_services.sort()
# Now that the services are sorted by priority, remove the sort
# keys from the list.
for i in range(len(prio_services)):
prio_services[i] = prio_services[i][2]
return prio_services
def getOPOrUserServices(openid_services):
"""Extract OP Identifier services. If none found, return the
rest, sorted with most preferred first according to
OpenIDServiceEndpoint.openid_type_uris.
openid_services is a list of OpenIDServiceEndpoint objects.
Returns a list of OpenIDServiceEndpoint objects."""
op_services = arrangeByType(openid_services, [OPENID_IDP_2_0_TYPE])
openid_services = arrangeByType(openid_services,
OpenIDServiceEndpoint.openid_type_uris)
return op_services or openid_services
def discoverYadis(uri):
"""Discover OpenID services for a URI. Tries Yadis and falls back
on old-style <link rel='...'> discovery if Yadis fails.
@param uri: normalized identity URL
@type uri: str
@return: (claimed_id, services)
@rtype: (str, list(OpenIDServiceEndpoint))
@raises DiscoveryFailure: when discovery fails.
"""
# Might raise a yadis.discover.DiscoveryFailure if no document
# came back for that URI at all. I don't think falling back
# to OpenID 1.0 discovery on the same URL will help, so don't
# bother to catch it.
response = yadisDiscover(uri)
yadis_url = response.normalized_uri
body = response.response_text
try:
openid_services = OpenIDServiceEndpoint.fromXRDS(yadis_url, body)
except XRDSError:
# Does not parse as a Yadis XRDS file
openid_services = []
if not openid_services:
# Either not an XRDS or there are no OpenID services.
if response.isXRDS():
# if we got the Yadis content-type or followed the Yadis
# header, re-fetch the document without following the Yadis
# header, with no Accept header.
return discoverNoYadis(uri)
# Try to parse the response as HTML.
# <link rel="...">
openid_services = OpenIDServiceEndpoint.fromHTML(yadis_url, body)
return (yadis_url, getOPOrUserServices(openid_services))
def discoverXRI(iname):
endpoints = []
iname = normalizeXRI(iname)
try:
canonicalID, services = xrires.ProxyResolver().query(
iname, OpenIDServiceEndpoint.openid_type_uris)
if canonicalID is None:
raise XRDSError('No CanonicalID found for XRI %r' % (iname, ))
flt = filters.mkFilter(OpenIDServiceEndpoint)
for service_element in services:
endpoints.extend(flt.getServiceEndpoints(iname, service_element))
except XRDSError:
logger.exception('xrds error on ' + iname)
for endpoint in endpoints:
# Is there a way to pass this through the filter to the endpoint
# constructor instead of tacking it on after?
endpoint.canonicalID = canonicalID
endpoint.claimed_id = canonicalID
endpoint.display_identifier = iname
# FIXME: returned xri should probably be in some normal form
return iname, getOPOrUserServices(endpoints)
def discoverNoYadis(uri):
http_resp = fetchers.fetch(uri)
if http_resp.status not in (200, 206):
raise DiscoveryFailure(
'HTTP Response status from identity URL host is not 200. '
'Got status %r' % (http_resp.status, ), http_resp)
claimed_id = http_resp.final_url
openid_services = OpenIDServiceEndpoint.fromHTML(claimed_id,
http_resp.body)
return claimed_id, openid_services
def discoverURI(uri):
parsed = urllib.parse.urlparse(uri)
if parsed[0] and parsed[1]:
if parsed[0] not in ['http', 'https']:
raise DiscoveryFailure('URI scheme is not HTTP or HTTPS', None)
else:
uri = 'http://' + uri
uri = normalizeURL(uri)
claimed_id, openid_services = discoverYadis(uri)
claimed_id = normalizeURL(claimed_id)
return claimed_id, openid_services
def discover(identifier):
if xri.identifierScheme(identifier) == "XRI":
return discoverXRI(identifier)
else:
return discoverURI(identifier)

View File

@@ -0,0 +1,278 @@
"""
This module implements a VERY limited parser that finds <link> tags in
the head of HTML or XHTML documents and parses out their attributes
according to the OpenID spec. It is a liberal parser, but it requires
these things from the data in order to work:
- There must be an open <html> tag
- There must be an open <head> tag inside of the <html> tag
- Only <link>s that are found inside of the <head> tag are parsed
(this is by design)
- The parser follows the OpenID specification in resolving the
attributes of the link tags. This means that the attributes DO NOT
get resolved as they would by an XML or HTML parser. In particular,
only certain entities get replaced, and href attributes do not get
resolved relative to a base URL.
From http://openid.net/specs.bml#linkrel:
- The openid.server URL MUST be an absolute URL. OpenID consumers
MUST NOT attempt to resolve relative URLs.
- The openid.server URL MUST NOT include entities other than &amp;,
&lt;, &gt;, and &quot;.
The parser ignores SGML comments and <![CDATA[blocks]]>. Both kinds of
quoting are allowed for attributes.
The parser deals with invalid markup in these ways:
- Tag names are not case-sensitive
- The <html> tag is accepted even when it is not at the top level
- The <head> tag is accepted even when it is not a direct child of
the <html> tag, but a <html> tag must be an ancestor of the <head>
tag
- <link> tags are accepted even when they are not direct children of
the <head> tag, but a <head> tag must be an ancestor of the <link>
tag
- If there is no closing tag for an open <html> or <head> tag, the
remainder of the document is viewed as being inside of the tag. If
there is no closing tag for a <link> tag, the link tag is treated
as a short tag. Exceptions to this rule are that <html> closes
<html> and <body> or <head> closes <head>
- Attributes of the <link> tag are not required to be quoted.
- In the case of duplicated attribute names, the attribute coming
last in the tag will be the value returned.
- Any text that does not parse as an attribute within a link tag will
be ignored. (e.g. <link pumpkin rel='openid.server' /> will ignore
pumpkin)
- If there are more than one <html> or <head> tag, the parser only
looks inside of the first one.
- The contents of <script> tags are ignored entirely, except unclosed
<script> tags. Unclosed <script> tags are ignored.
- Any other invalid markup is ignored, including unclosed SGML
comments and unclosed <![CDATA[blocks.
"""
__all__ = ['parseLinkAttrs']
import re
flags = (
re.DOTALL # Match newlines with '.'
| re.IGNORECASE | re.VERBOSE # Allow comments and whitespace in patterns
| re.UNICODE # Make \b respect Unicode word boundaries
)
# Stuff to remove before we start looking for tags
removed_re = re.compile(r'''
# Comments
<!--.*?-->
# CDATA blocks
| <!\[CDATA\[.*?\]\]>
# script blocks
| <script\b
# make sure script is not an XML namespace
(?!:)
[^>]*>.*?</script>
''', flags)
tag_expr = r'''
# Starts with the tag name at a word boundary, where the tag name is
# not a namespace
<%(tag_name)s\b(?!:)
# All of the stuff up to a ">", hopefully attributes.
(?P<attrs>[^>]*?)
(?: # Match a short tag
/>
| # Match a full tag
>
(?P<contents>.*?)
# Closed by
(?: # One of the specified close tags
</?%(closers)s\s*>
# End of the string
| \Z
)
)
'''
def tagMatcher(tag_name, *close_tags):
if close_tags:
options = '|'.join((tag_name, ) + close_tags)
closers = '(?:%s)' % (options, )
else:
closers = tag_name
expr = tag_expr % locals()
return re.compile(expr, flags)
# Must contain at least an open html and an open head tag
html_find = tagMatcher('html')
head_find = tagMatcher('head', 'body')
link_find = re.compile(r'<link\b(?!:)', flags)
attr_find = re.compile(r'''
# Must start with a sequence of word-characters, followed by an equals sign
(?P<attr_name>\w+)=
# Then either a quoted or unquoted attribute
(?:
# Match everything that\'s between matching quote marks
(?P<qopen>["\'])(?P<q_val>.*?)(?P=qopen)
|
# If the value is not quoted, match up to whitespace
(?P<unq_val>(?:[^\s<>/]|/(?!>))+)
)
|
(?P<end_link>[<>])
''', flags)
# Entity replacement:
replacements = {
'amp': '&',
'lt': '<',
'gt': '>',
'quot': '"',
}
ent_replace = re.compile(r'&(%s);' % '|'.join(list(replacements.keys())))
def replaceEnt(mo):
"Replace the entities that are specified by OpenID"
return replacements.get(mo.group(1), mo.group())
def parseLinkAttrs(html, ignore_errors=False):
"""Find all link tags in a string representing a HTML document and
return a list of their attributes.
@param html: the text to parse
@type html: str or unicode
@param ignore_errors: whether to return despite e.g. parsing errors
@type ignore_errors: bool
@return: A list of dictionaries of attributes, one for each link tag
@rtype: [[(type(html), type(html))]]
"""
if isinstance(html, bytes):
# Attempt to decode as UTF-8, since that's the most modern -- also
# try Latin-1, since that's suggested by HTTP/1.1. If neither of
# those works, fall over.
try:
html = html.decode("utf-8")
except UnicodeDecodeError:
try:
html = html.decode("latin1")
except UnicodeDecodeError:
if ignore_errors:
# Optionally ignore the errors and act as if no link attrs
# were found here
return []
else:
raise AssertionError("Unreadable HTML!")
stripped = removed_re.sub('', html)
html_mo = html_find.search(stripped)
if html_mo is None or html_mo.start('contents') == -1:
return []
start, end = html_mo.span('contents')
head_mo = head_find.search(stripped, start, end)
if head_mo is None or head_mo.start('contents') == -1:
return []
start, end = head_mo.span('contents')
link_mos = link_find.finditer(stripped, head_mo.start(), head_mo.end())
matches = []
for link_mo in link_mos:
start = link_mo.start() + 5
link_attrs = {}
for attr_mo in attr_find.finditer(stripped, start):
if attr_mo.lastgroup == 'end_link':
break
# Either q_val or unq_val must be present, but not both
# unq_val is a True (non-empty) value if it is present
attr_name, q_val, unq_val = attr_mo.group('attr_name', 'q_val',
'unq_val')
attr_val = ent_replace.sub(replaceEnt, unq_val or q_val)
link_attrs[attr_name] = attr_val
matches.append(link_attrs)
return matches
def relMatches(rel_attr, target_rel):
"""Does this target_rel appear in the rel_str?"""
# XXX: TESTME
rels = rel_attr.strip().split()
for rel in rels:
rel = rel.lower()
if rel == target_rel:
return 1
return 0
def linkHasRel(link_attrs, target_rel):
"""Does this link have target_rel as a relationship?"""
# XXX: TESTME
rel_attr = link_attrs.get('rel')
return rel_attr and relMatches(rel_attr, target_rel)
def findLinksRel(link_attrs_list, target_rel):
"""Filter the list of link attributes on whether it has target_rel
as a relationship."""
# XXX: TESTME
matchesTarget = lambda attrs: linkHasRel(attrs, target_rel)
return list(filter(matchesTarget, link_attrs_list))
def findFirstHref(link_attrs_list, target_rel):
"""Return the value of the href attribute for the first link tag
in the list that has target_rel as a relationship."""
# XXX: TESTME
matches = findLinksRel(link_attrs_list, target_rel)
if not matches:
return None
first = matches[0]
return first.get('href')

View File

@@ -0,0 +1,154 @@
"""Module containing a cryptographic-quality source of randomness and
other cryptographically useful functionality
Python 2.4 needs no external support for this module, nor does Python
2.3 on a system with /dev/urandom.
Other configurations will need a quality source of random bytes and
access to a function that will convert binary strings to long
integers. This module will work with the Python Cryptography Toolkit
(pycrypto) if it is present. pycrypto can be found with a search
engine, but is currently found at:
http://www.amk.ca/python/code/crypto
"""
__all__ = [
'base64ToLong',
'binaryToLong',
'hmacSha1',
'hmacSha256',
'longToBase64',
'longToBinary',
'randomString',
'randrange',
'sha1',
'sha256',
]
import hmac
import os
import random
from openid.oidutil import toBase64, fromBase64
import hashlib
class HashContainer(object):
def __init__(self, hash_constructor):
self.new = hash_constructor
self.digest_size = hash_constructor().digest_size
sha1_module = HashContainer(hashlib.sha1)
sha256_module = HashContainer(hashlib.sha256)
def hmacSha1(key, text):
if isinstance(key, str):
key = bytes(key, encoding="utf-8")
if isinstance(text, str):
text = bytes(text, encoding="utf-8")
return hmac.new(key, text, sha1_module).digest()
def sha1(s):
if isinstance(s, str):
s = bytes(s, encoding="utf-8")
return sha1_module.new(s).digest()
def hmacSha256(key, text):
if isinstance(key, str):
key = bytes(key, encoding="utf-8")
if isinstance(text, str):
text = bytes(text, encoding="utf-8")
return hmac.new(key, text, sha256_module).digest()
def sha256(s):
if isinstance(s, str):
s = bytes(s, encoding="utf-8")
return sha256_module.new(s).digest()
SHA256_AVAILABLE = True
try:
from Crypto.Util.number import long_to_bytes, bytes_to_long
except ImportError:
# In the case where we don't have pycrypto installed, define substitute
# functionality.
import pickle
def longToBinary(l):
if l == 0:
return b'\x00'
b = bytearray(pickle.encode_long(l))
b.reverse()
return bytes(b)
def binaryToLong(s):
if isinstance(s, str):
s = s.encode("utf-8")
b = bytearray(s)
b.reverse()
return pickle.decode_long(bytes(b))
else:
# We have pycrypto, so wrap its functions instead.
def longToBinary(l):
if l < 0:
raise ValueError('This function only supports positive integers')
bytestring = long_to_bytes(l)
if bytestring[0] > 127:
return b'\x00' + bytestring
else:
return bytestring
def binaryToLong(bytestring):
if not bytestring:
raise ValueError('Empty string passed to strToLong')
if bytestring[0] > 127:
raise ValueError('This function only supports positive integers')
return bytes_to_long(bytestring)
# A cryptographically safe source of random bytes
getBytes = os.urandom
# A randrange function that works for longs
randrange = random.randrange
def longToBase64(l):
return toBase64(longToBinary(l))
def base64ToLong(s):
return binaryToLong(fromBase64(s))
def randomString(length, chrs=None):
"""Produce a string of length random bytes, chosen from chrs."""
if chrs is None:
return getBytes(length)
else:
n = len(chrs)
return ''.join([chrs[randrange(n)] for _ in range(length)])
def const_eq(s1, s2):
if len(s1) != len(s2):
return False
result = True
for i in range(len(s1)):
result = result and (s1[i] == s2[i])
return result

View File

@@ -0,0 +1,47 @@
from openid import cryptutil
def strxor(x, y):
if len(x) != len(y):
raise ValueError('Inputs to strxor must have the same length')
if isinstance(x, str):
x = x.encode("utf-8")
if isinstance(y, str):
y = y.encode("utf-8")
return bytes([a ^ b for a, b in zip(x, y)])
class DiffieHellman(object):
DEFAULT_MOD = 155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698188993727816152646631438561595825688188889951272158842675419950341258706556549803580104870537681476726513255747040765857479291291572334510643245094715007229621094194349783925984760375594985848253359305585439638443
DEFAULT_GEN = 2
def fromDefaults(cls):
return cls(cls.DEFAULT_MOD, cls.DEFAULT_GEN)
fromDefaults = classmethod(fromDefaults)
def __init__(self, modulus, generator):
self.modulus = int(modulus)
self.generator = int(generator)
self._setPrivate(cryptutil.randrange(1, modulus - 1))
def _setPrivate(self, private):
"""This is here to make testing easier"""
self.private = private
self.public = pow(self.generator, self.private, self.modulus)
def usingDefaultValues(self):
return (self.modulus == self.DEFAULT_MOD and
self.generator == self.DEFAULT_GEN)
def getSharedSecret(self, composite):
return pow(composite, self.private, self.modulus)
def xorSecret(self, composite, secret, hash_func):
dh_shared = self.getSharedSecret(composite)
hashed_dh_shared = hash_func(cryptutil.longToBinary(dh_shared))
return strxor(secret, hashed_dh_shared)

View File

@@ -0,0 +1,51 @@
import warnings
from openid import message as message_module
class Extension(object):
"""An interface for OpenID extensions.
@ivar ns_uri: The namespace to which to add the arguments for this
extension
"""
ns_uri = None
ns_alias = None
def getExtensionArgs(self):
"""Get the string arguments that should be added to an OpenID
message for this extension.
@returns: A dictionary of completely non-namespaced arguments
to be added. For example, if the extension's alias is
'uncle', and this method returns {'meat':'Hot Rats'}, the
final message will contain {'openid.uncle.meat':'Hot Rats'}
"""
raise NotImplementedError()
def toMessage(self, message=None):
"""Add the arguments from this extension to the provided
message, or create a new message containing only those
arguments.
@returns: The message with the extension arguments added
"""
if message is None:
warnings.warn(
'Passing None to Extension.toMessage is deprecated. '
'Creating a message assuming you want OpenID 2.',
DeprecationWarning,
stacklevel=2)
message = message_module.Message(message_module.OPENID2_NS)
implicit = message.isOpenID1()
try:
message.namespaces.addAlias(
self.ns_uri, self.ns_alias, implicit=implicit)
except KeyError:
if message.namespaces.getAlias(self.ns_uri) != self.ns_alias:
raise
message.updateArgs(self.ns_uri, self.getExtensionArgs())
return message

View File

@@ -0,0 +1,5 @@
"""OpenID Extension modules."""
__all__ = ['ax', 'pape', 'sreg']
from openid.extensions.draft import pape5 as pape

View File

@@ -0,0 +1,781 @@
# -*- test-case-name: openid.test.test_ax -*-
"""Implements the OpenID Attribute Exchange specification, version 1.0.
@since: 2.1.0
"""
__all__ = [
'AttributeRequest',
'FetchRequest',
'FetchResponse',
'StoreRequest',
'StoreResponse',
]
from openid import extension
from openid.server.trustroot import TrustRoot
from openid.message import NamespaceMap, OPENID_NS
# Use this as the 'count' value for an attribute in a FetchRequest to
# ask for as many values as the OP can provide.
UNLIMITED_VALUES = "unlimited"
# Minimum supported alias length in characters. Here for
# completeness.
MINIMUM_SUPPORTED_ALIAS_LENGTH = 32
def checkAlias(alias):
"""
Check an alias for invalid characters; raise AXError if any are
found. Return None if the alias is valid.
"""
if ',' in alias:
raise AXError("Alias %r must not contain comma" % (alias, ))
if '.' in alias:
raise AXError("Alias %r must not contain period" % (alias, ))
class AXError(ValueError):
"""Results from data that does not meet the attribute exchange 1.0
specification"""
class NotAXMessage(AXError):
"""Raised when there is no Attribute Exchange mode in the message."""
def __repr__(self):
return self.__class__.__name__
def __str__(self):
return self.__class__.__name__
class AXMessage(extension.Extension):
"""Abstract class containing common code for attribute exchange messages
@cvar ns_alias: The preferred namespace alias for attribute
exchange messages
@cvar mode: The type of this attribute exchange message. This must
be overridden in subclasses.
"""
# This class is abstract, so it's OK that it doesn't override the
# abstract method in Extension:
#
#pylint:disable-msg=W0223
ns_alias = 'ax'
ns_uri = 'http://openid.net/srv/ax/1.0'
mode = None # NOTE mode is only ever set to a str value, see below
def _checkMode(self, ax_args):
"""Raise an exception if the mode in the attribute exchange
arguments does not match what is expected for this class.
@raises NotAXMessage: When there is no mode value in ax_args at all.
@raises AXError: When mode does not match.
"""
mode = ax_args.get('mode')
if isinstance(mode, bytes):
mode = str(mode, encoding="utf-8")
if mode != self.mode:
if not mode:
raise NotAXMessage()
else:
raise AXError('Expected mode %r; got %r' % (self.mode, mode))
def _newArgs(self):
"""Return a set of attribute exchange arguments containing the
basic information that must be in every attribute exchange
message.
"""
return {'mode': self.mode}
class AttrInfo(object):
"""Represents a single attribute in an attribute exchange
request. This should be added to an AXRequest object in order to
request the attribute.
@ivar required: Whether the attribute will be marked as required
when presented to the subject of the attribute exchange
request.
@type required: bool
@ivar count: How many values of this type to request from the
subject. Defaults to one.
@type count: int
@ivar type_uri: The identifier that determines what the attribute
represents and how it is serialized. For example, one type URI
representing dates could represent a Unix timestamp in base 10
and another could represent a human-readable string.
@type type_uri: str
@ivar alias: The name that should be given to this alias in the
request. If it is not supplied, a generic name will be
assigned. For example, if you want to call a Unix timestamp
value 'tstamp', set its alias to that value. If two attributes
in the same message request to use the same alias, the request
will fail to be generated.
@type alias: str or NoneType
"""
# It's OK that this class doesn't have public methods (it's just a
# holder for a bunch of attributes):
#
#pylint:disable-msg=R0903
def __init__(self, type_uri, count=1, required=False, alias=None):
self.required = required
self.count = count
self.type_uri = type_uri
self.alias = alias
if self.alias is not None:
checkAlias(self.alias)
def wantsUnlimitedValues(self):
"""
When processing a request for this attribute, the OP should
call this method to determine whether all available attribute
values were requested. If self.count == UNLIMITED_VALUES,
this returns True. Otherwise this returns False, in which
case self.count is an integer.
"""
return self.count == UNLIMITED_VALUES
def toTypeURIs(namespace_map, alias_list_s):
"""Given a namespace mapping and a string containing a
comma-separated list of namespace aliases, return a list of type
URIs that correspond to those aliases.
@param namespace_map: The mapping from namespace URI to alias
@type namespace_map: openid.message.NamespaceMap
@param alias_list_s: The string containing the comma-separated
list of aliases. May also be None for convenience.
@type alias_list_s: str or NoneType
@returns: The list of namespace URIs that corresponds to the
supplied list of aliases. If the string was zero-length or
None, an empty list will be returned.
@raise KeyError: If an alias is present in the list of aliases but
is not present in the namespace map.
"""
uris = []
if alias_list_s:
for alias in alias_list_s.split(','):
type_uri = namespace_map.getNamespaceURI(alias)
if type_uri is None:
raise KeyError('No type is defined for attribute name %r' %
(alias, ))
else:
uris.append(type_uri)
return uris
class FetchRequest(AXMessage):
"""An attribute exchange 'fetch_request' message. This message is
sent by a relying party when it wishes to obtain attributes about
the subject of an OpenID authentication request.
@ivar requested_attributes: The attributes that have been
requested thus far, indexed by the type URI.
@type requested_attributes: {str:AttrInfo}
@ivar update_url: A URL that will accept responses for this
attribute exchange request, even in the absence of the user
who made this request.
"""
mode = 'fetch_request'
def __init__(self, update_url=None):
AXMessage.__init__(self)
self.requested_attributes = {}
self.update_url = update_url
def add(self, attribute):
"""Add an attribute to this attribute exchange request.
@param attribute: The attribute that is being requested
@type attribute: C{L{AttrInfo}}
@returns: None
@raise KeyError: when the requested attribute is already
present in this fetch request.
"""
if attribute.type_uri in self.requested_attributes:
raise KeyError('The attribute %r has already been requested' %
(attribute.type_uri, ))
self.requested_attributes[attribute.type_uri] = attribute
def getExtensionArgs(self):
"""Get the serialized form of this attribute fetch request.
@returns: The fetch request message parameters
@rtype: {unicode:unicode}
"""
aliases = NamespaceMap()
required = []
if_available = []
ax_args = self._newArgs()
for type_uri, attribute in self.requested_attributes.items():
if attribute.alias is None:
alias = aliases.add(type_uri)
else:
# This will raise an exception when the second
# attribute with the same alias is added. I think it
# would be better to complain at the time that the
# attribute is added to this object so that the code
# that is adding it is identified in the stack trace,
# but it's more work to do so, and it won't be 100%
# accurate anyway, since the attributes are
# mutable. So for now, just live with the fact that
# we'll learn about the error later.
#
# The other possible approach is to hide the error and
# generate a new alias on the fly. I think that would
# probably be bad.
alias = aliases.addAlias(type_uri, attribute.alias)
if attribute.required:
required.append(alias)
else:
if_available.append(alias)
if attribute.count != 1:
ax_args['count.' + alias] = str(attribute.count)
ax_args['type.' + alias] = type_uri
if required:
ax_args['required'] = ','.join(required)
if if_available:
ax_args['if_available'] = ','.join(if_available)
return ax_args
def getRequiredAttrs(self):
"""Get the type URIs for all attributes that have been marked
as required.
@returns: A list of the type URIs for attributes that have
been marked as required.
@rtype: [str]
"""
required = []
for type_uri, attribute in self.requested_attributes.items():
if attribute.required:
required.append(type_uri)
return required
def fromOpenIDRequest(cls, openid_request):
"""Extract a FetchRequest from an OpenID message
@param openid_request: The OpenID authentication request
containing the attribute fetch request
@type openid_request: C{L{openid.server.server.CheckIDRequest}}
@rtype: C{L{FetchRequest}} or C{None}
@returns: The FetchRequest extracted from the message or None, if
the message contained no AX extension.
@raises KeyError: if the AuthRequest is not consistent in its use
of namespace aliases.
@raises AXError: When parseExtensionArgs would raise same.
@see: L{parseExtensionArgs}
"""
message = openid_request.message
ax_args = message.getArgs(cls.ns_uri)
self = cls()
try:
self.parseExtensionArgs(ax_args)
except NotAXMessage as err:
return None
if self.update_url:
# Update URL must match the openid.realm of the underlying
# OpenID 2 message.
realm = message.getArg(OPENID_NS, 'realm',
message.getArg(OPENID_NS, 'return_to'))
if not realm:
raise AXError(
("Cannot validate update_url %r " + "against absent realm")
% (self.update_url, ))
tr = TrustRoot.parse(realm)
if not tr.validateURL(self.update_url):
raise AXError(
"Update URL %r failed validation against realm %r" %
(self.update_url, realm, ))
return self
fromOpenIDRequest = classmethod(fromOpenIDRequest)
def parseExtensionArgs(self, ax_args):
"""Given attribute exchange arguments, populate this FetchRequest.
@param ax_args: Attribute Exchange arguments from the request.
As returned from L{Message.getArgs<openid.message.Message.getArgs>}.
@type ax_args: dict
@raises KeyError: if the message is not consistent in its use
of namespace aliases.
@raises NotAXMessage: If ax_args does not include an Attribute Exchange
mode.
@raises AXError: If the data to be parsed does not follow the
attribute exchange specification. At least when
'if_available' or 'required' is not specified for a
particular attribute type.
"""
# Raises an exception if the mode is not the expected value
self._checkMode(ax_args)
aliases = NamespaceMap()
for key, value in ax_args.items():
if key.startswith('type.'):
alias = key[5:]
type_uri = value
aliases.addAlias(type_uri, alias)
count_key = 'count.' + alias
count_s = ax_args.get(count_key)
if count_s:
try:
count = int(count_s)
if count <= 0:
raise AXError(
"Count %r must be greater than zero, got %r" %
(count_key, count_s, ))
except ValueError:
if count_s != UNLIMITED_VALUES:
raise AXError("Invalid count value for %r: %r" %
(count_key, count_s, ))
count = count_s
else:
count = 1
self.add(AttrInfo(type_uri, alias=alias, count=count))
required = toTypeURIs(aliases, ax_args.get('required'))
for type_uri in required:
self.requested_attributes[type_uri].required = True
if_available = toTypeURIs(aliases, ax_args.get('if_available'))
all_type_uris = required + if_available
for type_uri in aliases.iterNamespaceURIs():
if type_uri not in all_type_uris:
raise AXError('Type URI %r was in the request but not '
'present in "required" or "if_available"' %
(type_uri, ))
self.update_url = ax_args.get('update_url')
def iterAttrs(self):
"""Iterate over the AttrInfo objects that are
contained in this fetch_request.
"""
return iter(self.requested_attributes.values())
def __iter__(self):
"""Iterate over the attribute type URIs in this fetch_request
"""
return iter(self.requested_attributes)
def has_key(self, type_uri):
"""Is the given type URI present in this fetch_request?
"""
return type_uri in self.requested_attributes
__contains__ = has_key
class AXKeyValueMessage(AXMessage):
"""An abstract class that implements a message that has attribute
keys and values. It contains the common code between
fetch_response and store_request.
"""
# This class is abstract, so it's OK that it doesn't override the
# abstract method in Extension:
#
#pylint:disable-msg=W0223
def __init__(self):
AXMessage.__init__(self)
self.data = {}
def addValue(self, type_uri, value):
"""Add a single value for the given attribute type to the
message. If there are already values specified for this type,
this value will be sent in addition to the values already
specified.
@param type_uri: The URI for the attribute
@param value: The value to add to the response to the relying
party for this attribute
@type value: unicode
@returns: None
"""
try:
values = self.data[type_uri]
except KeyError:
values = self.data[type_uri] = []
values.append(value)
def setValues(self, type_uri, values):
"""Set the values for the given attribute type. This replaces
any values that have already been set for this attribute.
@param type_uri: The URI for the attribute
@param values: A list of values to send for this attribute.
@type values: [unicode]
"""
self.data[type_uri] = values
def _getExtensionKVArgs(self, aliases=None):
"""Get the extension arguments for the key/value pairs
contained in this message.
@param aliases: An alias mapping. Set to None if you don't
care about the aliases for this request.
"""
if aliases is None:
aliases = NamespaceMap()
ax_args = {}
for type_uri, values in self.data.items():
alias = aliases.add(type_uri)
ax_args['type.' + alias] = type_uri
ax_args['count.' + alias] = str(len(values))
for i, value in enumerate(values):
key = 'value.%s.%d' % (alias, i + 1)
ax_args[key] = value
return ax_args
def parseExtensionArgs(self, ax_args):
"""Parse attribute exchange key/value arguments into this
object.
@param ax_args: The attribute exchange fetch_response
arguments, with namespacing removed.
@type ax_args: {unicode:unicode}
@returns: None
@raises ValueError: If the message has bad values for
particular fields
@raises KeyError: If the namespace mapping is bad or required
arguments are missing
"""
self._checkMode(ax_args)
aliases = NamespaceMap()
for key, value in ax_args.items():
if key.startswith('type.'):
type_uri = value
alias = key[5:]
checkAlias(alias)
aliases.addAlias(type_uri, alias)
for type_uri, alias in aliases.items():
try:
count_s = ax_args['count.' + alias]
except KeyError:
value = ax_args['value.' + alias]
if value == '':
values = []
else:
values = [value]
else:
count = int(count_s)
values = []
for i in range(1, count + 1):
value_key = 'value.%s.%d' % (alias, i)
value = ax_args[value_key]
values.append(value)
self.data[type_uri] = values
def getSingle(self, type_uri, default=None):
"""Get a single value for an attribute. If no value was sent
for this attribute, use the supplied default. If there is more
than one value for this attribute, this method will fail.
@type type_uri: str
@param type_uri: The URI for the attribute
@param default: The value to return if the attribute was not
sent in the fetch_response.
@returns: The value of the attribute in the fetch_response
message, or the default supplied
@rtype: unicode or NoneType
@raises ValueError: If there is more than one value for this
parameter in the fetch_response message.
@raises KeyError: If the attribute was not sent in this response
"""
values = self.data.get(type_uri)
if not values:
return default
elif len(values) == 1:
return values[0]
else:
raise AXError('More than one value present for %r' % (type_uri, ))
def get(self, type_uri):
"""Get the list of values for this attribute in the
fetch_response.
XXX: what to do if the values are not present? default
parameter? this is funny because it's always supposed to
return a list, so the default may break that, though it's
provided by the user's code, so it might be okay. If no
default is supplied, should the return be None or []?
@param type_uri: The URI of the attribute
@returns: The list of values for this attribute in the
response. May be an empty list.
@rtype: [unicode]
@raises KeyError: If the attribute was not sent in the response
"""
return self.data[type_uri]
def count(self, type_uri):
"""Get the number of responses for a particular attribute in
this fetch_response message.
@param type_uri: The URI of the attribute
@returns: The number of values sent for this attribute
@raises KeyError: If the attribute was not sent in the
response. KeyError will not be raised if the number of
values was zero.
"""
return len(self.get(type_uri))
class FetchResponse(AXKeyValueMessage):
"""A fetch_response attribute exchange message
"""
mode = 'fetch_response'
def __init__(self, request=None, update_url=None):
"""
@param request: When supplied, I will use namespace aliases
that match those in this request. I will also check to
make sure I do not respond with attributes that were not
requested.
@type request: L{FetchRequest}
@param update_url: By default, C{update_url} is taken from the
request. But if you do not supply the request, you may set
the C{update_url} here.
@type update_url: str
"""
AXKeyValueMessage.__init__(self)
self.update_url = update_url
self.request = request
def getExtensionArgs(self):
"""Serialize this object into arguments in the attribute
exchange namespace
@returns: The dictionary of unqualified attribute exchange
arguments that represent this fetch_response.
@rtype: {unicode;unicode}
"""
aliases = NamespaceMap()
zero_value_types = []
if self.request is not None:
# Validate the data in the context of the request (the
# same attributes should be present in each, and the
# counts in the response must be no more than the counts
# in the request)
for type_uri in self.data:
if type_uri not in self.request:
raise KeyError(
'Response attribute not present in request: %r' %
(type_uri, ))
for attr_info in self.request.iterAttrs():
# Copy the aliases from the request so that reading
# the response in light of the request is easier
if attr_info.alias is None:
aliases.add(attr_info.type_uri)
else:
aliases.addAlias(attr_info.type_uri, attr_info.alias)
try:
values = self.data[attr_info.type_uri]
except KeyError:
values = []
zero_value_types.append(attr_info)
if (attr_info.count != UNLIMITED_VALUES) and \
(attr_info.count < len(values)):
raise AXError(
'More than the number of requested values were '
'specified for %r' % (attr_info.type_uri, ))
kv_args = self._getExtensionKVArgs(aliases)
# Add the KV args into the response with the args that are
# unique to the fetch_response
ax_args = self._newArgs()
# For each requested attribute, put its type/alias and count
# into the response even if no data were returned.
for attr_info in zero_value_types:
alias = aliases.getAlias(attr_info.type_uri)
kv_args['type.' + alias] = attr_info.type_uri
kv_args['count.' + alias] = '0'
update_url = ((self.request and self.request.update_url) or
self.update_url)
if update_url:
ax_args['update_url'] = update_url
ax_args.update(kv_args)
return ax_args
def parseExtensionArgs(self, ax_args):
"""@see: {Extension.parseExtensionArgs<openid.extension.Extension.parseExtensionArgs>}"""
super(FetchResponse, self).parseExtensionArgs(ax_args)
self.update_url = ax_args.get('update_url')
def fromSuccessResponse(cls, success_response, signed=True):
"""Construct a FetchResponse object from an OpenID library
SuccessResponse object.
@param success_response: A successful id_res response object
@type success_response: openid.consumer.consumer.SuccessResponse
@param signed: Whether non-signed args should be
processsed. If True (the default), only signed arguments
will be processsed.
@type signed: bool
@returns: A FetchResponse containing the data from the OpenID
message, or None if the SuccessResponse did not contain AX
extension data.
@raises AXError: when the AX data cannot be parsed.
"""
self = cls()
ax_args = success_response.extensionResponse(self.ns_uri, signed)
try:
self.parseExtensionArgs(ax_args)
except NotAXMessage as err:
return None
else:
return self
fromSuccessResponse = classmethod(fromSuccessResponse)
class StoreRequest(AXKeyValueMessage):
"""A store request attribute exchange message representation
"""
mode = 'store_request'
def __init__(self, aliases=None):
"""
@param aliases: The namespace aliases to use when making this
store request. Leave as None to use defaults.
"""
super(StoreRequest, self).__init__()
self.aliases = aliases
def getExtensionArgs(self):
"""
@see: L{Extension.getExtensionArgs<openid.extension.Extension.getExtensionArgs>}
"""
ax_args = self._newArgs()
kv_args = self._getExtensionKVArgs(self.aliases)
ax_args.update(kv_args)
return ax_args
class StoreResponse(AXMessage):
"""An indication that the store request was processed along with
this OpenID transaction.
"""
SUCCESS_MODE = 'store_response_success'
FAILURE_MODE = 'store_response_failure'
def __init__(self, succeeded=True, error_message=None):
AXMessage.__init__(self)
if succeeded and error_message is not None:
raise AXError('An error message may only be included in a '
'failing fetch response')
if succeeded:
self.mode = self.SUCCESS_MODE
else:
self.mode = self.FAILURE_MODE
self.error_message = error_message
def succeeded(self):
"""Was this response a success response?"""
return self.mode == self.SUCCESS_MODE
def getExtensionArgs(self):
"""@see: {Extension.getExtensionArgs<openid.extension.Extension.getExtensionArgs>}"""
ax_args = self._newArgs()
if not self.succeeded() and self.error_message:
ax_args['error'] = self.error_message
return ax_args

View File

@@ -0,0 +1,285 @@
"""An implementation of the OpenID Provider Authentication Policy
Extension 1.0
@see: http://openid.net/developers/specs/
@since: 2.1.0
"""
__all__ = [
'Request',
'Response',
'ns_uri',
'AUTH_PHISHING_RESISTANT',
'AUTH_MULTI_FACTOR',
'AUTH_MULTI_FACTOR_PHYSICAL',
]
from openid.extension import Extension
import re
ns_uri = "http://specs.openid.net/extensions/pape/1.0"
AUTH_MULTI_FACTOR_PHYSICAL = \
'http://schemas.openid.net/pape/policies/2007/06/multi-factor-physical'
AUTH_MULTI_FACTOR = \
'http://schemas.openid.net/pape/policies/2007/06/multi-factor'
AUTH_PHISHING_RESISTANT = \
'http://schemas.openid.net/pape/policies/2007/06/phishing-resistant'
TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$')
class Request(Extension):
"""A Provider Authentication Policy request, sent from a relying
party to a provider
@ivar preferred_auth_policies: The authentication policies that
the relying party prefers
@type preferred_auth_policies: [str]
@ivar max_auth_age: The maximum time, in seconds, that the relying
party wants to allow to have elapsed before the user must
re-authenticate
@type max_auth_age: int or NoneType
"""
ns_alias = 'pape'
def __init__(self, preferred_auth_policies=None, max_auth_age=None):
super(Request, self).__init__()
if not preferred_auth_policies:
preferred_auth_policies = []
self.preferred_auth_policies = preferred_auth_policies
self.max_auth_age = max_auth_age
def __bool__(self):
return bool(self.preferred_auth_policies or
self.max_auth_age is not None)
def addPolicyURI(self, policy_uri):
"""Add an acceptable authentication policy URI to this request
This method is intended to be used by the relying party to add
acceptable authentication types to the request.
@param policy_uri: The identifier for the preferred type of
authentication.
@see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-01.html#auth_policies
"""
if policy_uri not in self.preferred_auth_policies:
self.preferred_auth_policies.append(policy_uri)
def getExtensionArgs(self):
"""@see: C{L{Extension.getExtensionArgs}}
"""
ns_args = {
'preferred_auth_policies': ' '.join(self.preferred_auth_policies)
}
if self.max_auth_age is not None:
ns_args['max_auth_age'] = str(self.max_auth_age)
return ns_args
def fromOpenIDRequest(cls, request):
"""Instantiate a Request object from the arguments in a
C{checkid_*} OpenID message
"""
self = cls()
args = request.message.getArgs(self.ns_uri)
if args == {}:
return None
self.parseExtensionArgs(args)
return self
fromOpenIDRequest = classmethod(fromOpenIDRequest)
def parseExtensionArgs(self, args):
"""Set the state of this request to be that expressed in these
PAPE arguments
@param args: The PAPE arguments without a namespace
@rtype: None
@raises ValueError: When the max_auth_age is not parseable as
an integer
"""
# preferred_auth_policies is a space-separated list of policy URIs
self.preferred_auth_policies = []
policies_str = args.get('preferred_auth_policies')
if policies_str:
if isinstance(policies_str, bytes):
policies_str = str(policies_str, encoding="utf-8")
for uri in policies_str.split(' '):
if uri not in self.preferred_auth_policies:
self.preferred_auth_policies.append(uri)
# max_auth_age is base-10 integer number of seconds
max_auth_age_str = args.get('max_auth_age')
self.max_auth_age = None
if max_auth_age_str:
try:
self.max_auth_age = int(max_auth_age_str)
except ValueError:
pass
def preferredTypes(self, supported_types):
"""Given a list of authentication policy URIs that a provider
supports, this method returns the subsequence of those types
that are preferred by the relying party.
@param supported_types: A sequence of authentication policy
type URIs that are supported by a provider
@returns: The sub-sequence of the supported types that are
preferred by the relying party. This list will be ordered
in the order that the types appear in the supported_types
sequence, and may be empty if the provider does not prefer
any of the supported authentication types.
@returntype: [str]
"""
return list(
filter(self.preferred_auth_policies.__contains__, supported_types))
Request.ns_uri = ns_uri
class Response(Extension):
"""A Provider Authentication Policy response, sent from a provider
to a relying party
"""
ns_alias = 'pape'
def __init__(self,
auth_policies=None,
auth_time=None,
nist_auth_level=None):
super(Response, self).__init__()
if auth_policies:
self.auth_policies = auth_policies
else:
self.auth_policies = []
self.auth_time = auth_time
self.nist_auth_level = nist_auth_level
def addPolicyURI(self, policy_uri):
"""Add a authentication policy to this response
This method is intended to be used by the provider to add a
policy that the provider conformed to when authenticating the user.
@param policy_uri: The identifier for the preferred type of
authentication.
@see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-01.html#auth_policies
"""
if policy_uri not in self.auth_policies:
self.auth_policies.append(policy_uri)
def fromSuccessResponse(cls, success_response):
"""Create a C{L{Response}} object from a successful OpenID
library response
(C{L{openid.consumer.consumer.SuccessResponse}}) response
message
@param success_response: A SuccessResponse from consumer.complete()
@type success_response: C{L{openid.consumer.consumer.SuccessResponse}}
@rtype: Response or None
@returns: A provider authentication policy response from the
data that was supplied with the C{id_res} response or None
if the provider sent no signed PAPE response arguments.
"""
self = cls()
# PAPE requires that the args be signed.
args = success_response.getSignedNS(self.ns_uri)
# Only try to construct a PAPE response if the arguments were
# signed in the OpenID response. If not, return None.
if args is not None:
self.parseExtensionArgs(args)
return self
else:
return None
def parseExtensionArgs(self, args, strict=False):
"""Parse the provider authentication policy arguments into the
internal state of this object
@param args: unqualified provider authentication policy
arguments
@param strict: Whether to raise an exception when bad data is
encountered
@returns: None. The data is parsed into the internal fields of
this object.
"""
policies_str = args.get('auth_policies')
if policies_str and policies_str != 'none':
self.auth_policies = policies_str.split(' ')
nist_level_str = args.get('nist_auth_level')
if nist_level_str:
try:
nist_level = int(nist_level_str)
except ValueError:
if strict:
raise ValueError(
'nist_auth_level must be an integer between '
'zero and four, inclusive')
else:
self.nist_auth_level = None
else:
if 0 <= nist_level < 5:
self.nist_auth_level = nist_level
auth_time = args.get('auth_time')
if auth_time:
if TIME_VALIDATOR.match(auth_time):
self.auth_time = auth_time
elif strict:
raise ValueError("auth_time must be in RFC3339 format")
fromSuccessResponse = classmethod(fromSuccessResponse)
def getExtensionArgs(self):
"""@see: C{L{Extension.getExtensionArgs}}
"""
if len(self.auth_policies) == 0:
ns_args = {
'auth_policies': 'none',
}
else:
ns_args = {
'auth_policies': ' '.join(self.auth_policies),
}
if self.nist_auth_level is not None:
if self.nist_auth_level not in list(range(0, 5)):
raise ValueError('nist_auth_level must be an integer between '
'zero and four, inclusive')
ns_args['nist_auth_level'] = str(self.nist_auth_level)
if self.auth_time is not None:
if not TIME_VALIDATOR.match(self.auth_time):
raise ValueError('auth_time must be in RFC3339 format')
ns_args['auth_time'] = self.auth_time
return ns_args
Response.ns_uri = ns_uri

View File

@@ -0,0 +1,481 @@
"""An implementation of the OpenID Provider Authentication Policy
Extension 1.0, Draft 5
@see: http://openid.net/developers/specs/
@since: 2.1.0
"""
__all__ = [
'Request',
'Response',
'ns_uri',
'AUTH_PHISHING_RESISTANT',
'AUTH_MULTI_FACTOR',
'AUTH_MULTI_FACTOR_PHYSICAL',
'LEVELS_NIST',
'LEVELS_JISA',
]
from openid.extension import Extension
import warnings
import re
ns_uri = "http://specs.openid.net/extensions/pape/1.0"
AUTH_MULTI_FACTOR_PHYSICAL = \
'http://schemas.openid.net/pape/policies/2007/06/multi-factor-physical'
AUTH_MULTI_FACTOR = \
'http://schemas.openid.net/pape/policies/2007/06/multi-factor'
AUTH_PHISHING_RESISTANT = \
'http://schemas.openid.net/pape/policies/2007/06/phishing-resistant'
AUTH_NONE = \
'http://schemas.openid.net/pape/policies/2007/06/none'
TIME_VALIDATOR = re.compile(r'^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$')
LEVELS_NIST = 'http://csrc.nist.gov/publications/nistpubs/800-63/SP800-63V1_0_2.pdf'
LEVELS_JISA = 'http://www.jisa.or.jp/spec/auth_level.html'
class PAPEExtension(Extension):
_default_auth_level_aliases = {
'nist': LEVELS_NIST,
'jisa': LEVELS_JISA,
}
def __init__(self):
self.auth_level_aliases = self._default_auth_level_aliases.copy()
def _addAuthLevelAlias(self, auth_level_uri, alias=None):
"""Add an auth level URI alias to this request.
@param auth_level_uri: The auth level URI to send in the
request.
@param alias: The namespace alias to use for this auth level
in this message. May be None if the alias is not
important.
"""
if alias is None:
try:
alias = self._getAlias(auth_level_uri)
except KeyError:
alias = self._generateAlias()
else:
existing_uri = self.auth_level_aliases.get(alias)
if existing_uri is not None and existing_uri != auth_level_uri:
raise KeyError('Attempting to redefine alias %r from %r to %r',
alias, existing_uri, auth_level_uri)
self.auth_level_aliases[alias] = auth_level_uri
def _generateAlias(self):
"""Return an unused auth level alias"""
for i in range(1000):
alias = 'cust%d' % (i, )
if alias not in self.auth_level_aliases:
return alias
raise RuntimeError('Could not find an unused alias (tried 1000!)')
def _getAlias(self, auth_level_uri):
"""Return the alias for the specified auth level URI.
@raises KeyError: if no alias is defined
"""
for (alias, existing_uri) in self.auth_level_aliases.items():
if auth_level_uri == existing_uri:
return alias
raise KeyError(auth_level_uri)
class Request(PAPEExtension):
"""A Provider Authentication Policy request, sent from a relying
party to a provider
@ivar preferred_auth_policies: The authentication policies that
the relying party prefers
@type preferred_auth_policies: [str]
@ivar max_auth_age: The maximum time, in seconds, that the relying
party wants to allow to have elapsed before the user must
re-authenticate
@type max_auth_age: int or NoneType
@ivar preferred_auth_level_types: Ordered list of authentication
level namespace URIs
@type preferred_auth_level_types: [str]
"""
ns_alias = 'pape'
def __init__(self,
preferred_auth_policies=None,
max_auth_age=None,
preferred_auth_level_types=None):
super(Request, self).__init__()
if preferred_auth_policies is None:
preferred_auth_policies = []
self.preferred_auth_policies = preferred_auth_policies
self.max_auth_age = max_auth_age
self.preferred_auth_level_types = []
if preferred_auth_level_types is not None:
for auth_level in preferred_auth_level_types:
self.addAuthLevel(auth_level)
def __bool__(self):
return bool(self.preferred_auth_policies or
self.max_auth_age is not None or
self.preferred_auth_level_types)
def addPolicyURI(self, policy_uri):
"""Add an acceptable authentication policy URI to this request
This method is intended to be used by the relying party to add
acceptable authentication types to the request.
@param policy_uri: The identifier for the preferred type of
authentication.
@see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-05.html#auth_policies
"""
if policy_uri not in self.preferred_auth_policies:
self.preferred_auth_policies.append(policy_uri)
def addAuthLevel(self, auth_level_uri, alias=None):
self._addAuthLevelAlias(auth_level_uri, alias)
if auth_level_uri not in self.preferred_auth_level_types:
self.preferred_auth_level_types.append(auth_level_uri)
def getExtensionArgs(self):
"""@see: C{L{Extension.getExtensionArgs}}
"""
ns_args = {
'preferred_auth_policies': ' '.join(self.preferred_auth_policies),
}
if self.max_auth_age is not None:
ns_args['max_auth_age'] = str(self.max_auth_age)
if self.preferred_auth_level_types:
preferred_types = []
for auth_level_uri in self.preferred_auth_level_types:
alias = self._getAlias(auth_level_uri)
ns_args['auth_level.ns.%s' % (alias, )] = auth_level_uri
preferred_types.append(alias)
ns_args['preferred_auth_level_types'] = ' '.join(preferred_types)
return ns_args
def fromOpenIDRequest(cls, request):
"""Instantiate a Request object from the arguments in a
C{checkid_*} OpenID message
"""
self = cls()
args = request.message.getArgs(self.ns_uri)
is_openid1 = request.message.isOpenID1()
if args == {}:
return None
self.parseExtensionArgs(args, is_openid1)
return self
fromOpenIDRequest = classmethod(fromOpenIDRequest)
def parseExtensionArgs(self, args, is_openid1, strict=False):
"""Set the state of this request to be that expressed in these
PAPE arguments
@param args: The PAPE arguments without a namespace
@param strict: Whether to raise an exception if the input is
out of spec or otherwise malformed. If strict is false,
malformed input will be ignored.
@param is_openid1: Whether the input should be treated as part
of an OpenID1 request
@rtype: None
@raises ValueError: When the max_auth_age is not parseable as
an integer
"""
# preferred_auth_policies is a space-separated list of policy URIs
self.preferred_auth_policies = []
policies_str = args.get('preferred_auth_policies')
if policies_str:
if isinstance(policies_str, bytes):
policies_str = str(policies_str, encoding="utf-8")
for uri in policies_str.split(' '):
if uri not in self.preferred_auth_policies:
self.preferred_auth_policies.append(uri)
# max_auth_age is base-10 integer number of seconds
max_auth_age_str = args.get('max_auth_age')
self.max_auth_age = None
if max_auth_age_str:
try:
self.max_auth_age = int(max_auth_age_str)
except ValueError:
if strict:
raise
# Parse auth level information
preferred_auth_level_types = args.get('preferred_auth_level_types')
if preferred_auth_level_types:
aliases = preferred_auth_level_types.strip().split()
for alias in aliases:
key = 'auth_level.ns.%s' % (alias, )
try:
uri = args[key]
except KeyError:
if is_openid1:
uri = self._default_auth_level_aliases.get(alias)
else:
uri = None
if uri is None:
if strict:
raise ValueError('preferred auth level %r is not '
'defined in this message' % (alias, ))
else:
self.addAuthLevel(uri, alias)
def preferredTypes(self, supported_types):
"""Given a list of authentication policy URIs that a provider
supports, this method returns the subsequence of those types
that are preferred by the relying party.
@param supported_types: A sequence of authentication policy
type URIs that are supported by a provider
@returns: The sub-sequence of the supported types that are
preferred by the relying party. This list will be ordered
in the order that the types appear in the supported_types
sequence, and may be empty if the provider does not prefer
any of the supported authentication types.
@returntype: [str]
"""
return list(
filter(self.preferred_auth_policies.__contains__, supported_types))
Request.ns_uri = ns_uri
class Response(PAPEExtension):
"""A Provider Authentication Policy response, sent from a provider
to a relying party
@ivar auth_policies: List of authentication policies conformed to
by this OpenID assertion, represented as policy URIs
"""
ns_alias = 'pape'
def __init__(self, auth_policies=None, auth_time=None, auth_levels=None):
super(Response, self).__init__()
if auth_policies:
self.auth_policies = auth_policies
else:
self.auth_policies = []
self.auth_time = auth_time
self.auth_levels = {}
if auth_levels is None:
auth_levels = {}
for uri, level in auth_levels.items():
self.setAuthLevel(uri, level)
def setAuthLevel(self, level_uri, level, alias=None):
"""Set the value for the given auth level type.
@param level: string representation of an authentication level
valid for level_uri
@param alias: An optional namespace alias for the given auth
level URI. May be omitted if the alias is not
significant. The library will use a reasonable default for
widely-used auth level types.
"""
self._addAuthLevelAlias(level_uri, alias)
self.auth_levels[level_uri] = level
def getAuthLevel(self, level_uri):
"""Return the auth level for the specified auth level
identifier
@returns: A string that should map to the auth levels defined
for the auth level type
@raises KeyError: If the auth level type is not present in
this message
"""
return self.auth_levels[level_uri]
def _getNISTAuthLevel(self):
try:
return int(self.getAuthLevel(LEVELS_NIST))
except KeyError:
return None
nist_auth_level = property(
_getNISTAuthLevel,
doc="Backward-compatibility accessor for the NIST auth level")
def addPolicyURI(self, policy_uri):
"""Add a authentication policy to this response
This method is intended to be used by the provider to add a
policy that the provider conformed to when authenticating the user.
@param policy_uri: The identifier for the preferred type of
authentication.
@see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-01.html#auth_policies
"""
if policy_uri == AUTH_NONE:
raise RuntimeError(
'To send no policies, do not set any on the response.')
if policy_uri not in self.auth_policies:
self.auth_policies.append(policy_uri)
def fromSuccessResponse(cls, success_response):
"""Create a C{L{Response}} object from a successful OpenID
library response
(C{L{openid.consumer.consumer.SuccessResponse}}) response
message
@param success_response: A SuccessResponse from consumer.complete()
@type success_response: C{L{openid.consumer.consumer.SuccessResponse}}
@rtype: Response or None
@returns: A provider authentication policy response from the
data that was supplied with the C{id_res} response or None
if the provider sent no signed PAPE response arguments.
"""
self = cls()
# PAPE requires that the args be signed.
args = success_response.getSignedNS(self.ns_uri)
is_openid1 = success_response.isOpenID1()
# Only try to construct a PAPE response if the arguments were
# signed in the OpenID response. If not, return None.
if args is not None:
self.parseExtensionArgs(args, is_openid1)
return self
else:
return None
def parseExtensionArgs(self, args, is_openid1, strict=False):
"""Parse the provider authentication policy arguments into the
internal state of this object
@param args: unqualified provider authentication policy
arguments
@param strict: Whether to raise an exception when bad data is
encountered
@returns: None. The data is parsed into the internal fields of
this object.
"""
policies_str = args.get('auth_policies')
if policies_str:
auth_policies = policies_str.split(' ')
elif strict:
raise ValueError('Missing auth_policies')
else:
auth_policies = []
if (len(auth_policies) > 1 and strict and AUTH_NONE in auth_policies):
raise ValueError('Got some auth policies, as well as the special '
'"none" URI: %r' % (auth_policies, ))
if 'none' in auth_policies:
msg = '"none" used as a policy URI (see PAPE draft < 5)'
if strict:
raise ValueError(msg)
else:
warnings.warn(msg, stacklevel=2)
auth_policies = [
u for u in auth_policies if u not in ['none', AUTH_NONE]
]
self.auth_policies = auth_policies
for (key, val) in args.items():
if key.startswith('auth_level.'):
alias = key[11:]
# skip the already-processed namespace declarations
if alias.startswith('ns.'):
continue
try:
uri = args['auth_level.ns.%s' % (alias, )]
except KeyError:
if is_openid1:
uri = self._default_auth_level_aliases.get(alias)
else:
uri = None
if uri is None:
if strict:
raise ValueError('Undefined auth level alias: %r' %
(alias, ))
else:
self.setAuthLevel(uri, val, alias)
auth_time = args.get('auth_time')
if auth_time:
if TIME_VALIDATOR.match(auth_time):
self.auth_time = auth_time
elif strict:
raise ValueError("auth_time must be in RFC3339 format")
fromSuccessResponse = classmethod(fromSuccessResponse)
def getExtensionArgs(self):
"""@see: C{L{Extension.getExtensionArgs}}
"""
if len(self.auth_policies) == 0:
ns_args = {
'auth_policies': AUTH_NONE,
}
else:
ns_args = {
'auth_policies': ' '.join(self.auth_policies),
}
for level_type, level in self.auth_levels.items():
alias = self._getAlias(level_type)
ns_args['auth_level.ns.%s' % (alias, )] = level_type
ns_args['auth_level.%s' % (alias, )] = str(level)
if self.auth_time is not None:
if not TIME_VALIDATOR.match(self.auth_time):
raise ValueError('auth_time must be in RFC3339 format')
ns_args['auth_time'] = self.auth_time
return ns_args
Response.ns_uri = ns_uri

View File

@@ -0,0 +1,529 @@
"""Simple registration request and response parsing and object representation
This module contains objects representing simple registration requests
and responses that can be used with both OpenID relying parties and
OpenID providers.
1. The relying party creates a request object and adds it to the
C{L{AuthRequest<openid.consumer.consumer.AuthRequest>}} object
before making the C{checkid_} request to the OpenID provider::
auth_request.addExtension(SRegRequest(required=['email']))
2. The OpenID provider extracts the simple registration request from
the OpenID request using C{L{SRegRequest.fromOpenIDRequest}},
gets the user's approval and data, creates a C{L{SRegResponse}}
object and adds it to the C{id_res} response::
sreg_req = SRegRequest.fromOpenIDRequest(checkid_request)
# [ get the user's approval and data, informing the user that
# the fields in sreg_response were requested ]
sreg_resp = SRegResponse.extractResponse(sreg_req, user_data)
sreg_resp.toMessage(openid_response.fields)
3. The relying party uses C{L{SRegResponse.fromSuccessResponse}} to
extract the data from the OpenID response::
sreg_resp = SRegResponse.fromSuccessResponse(success_response)
@since: 2.0
@var sreg_data_fields: The names of the data fields that are listed in
the sreg spec, and a description of them in English
@var sreg_uri: The preferred URI to use for the simple registration
namespace and XRD Type value
"""
from openid.message import registerNamespaceAlias, \
NamespaceAliasRegistrationError
from openid.extension import Extension
import logging
logger = logging.getLogger(__name__)
try:
str #pylint:disable-msg=W0104
except NameError:
# For Python 2.2
str = (str, str) #pylint:disable-msg=W0622
__all__ = [
'SRegRequest',
'SRegResponse',
'data_fields',
'ns_uri',
'ns_uri_1_0',
'ns_uri_1_1',
'supportsSReg',
]
# The data fields that are listed in the sreg spec
data_fields = {
'fullname': 'Full Name',
'nickname': 'Nickname',
'dob': 'Date of Birth',
'email': 'E-mail Address',
'gender': 'Gender',
'postcode': 'Postal Code',
'country': 'Country',
'language': 'Language',
'timezone': 'Time Zone',
}
def checkFieldName(field_name):
"""Check to see that the given value is a valid simple
registration data field name.
@raise ValueError: if the field name is not a valid simple
registration data field name
"""
if field_name not in data_fields:
raise ValueError('%r is not a defined simple registration field' %
(field_name, ))
# URI used in the wild for Yadis documents advertising simple
# registration support
ns_uri_1_0 = 'http://openid.net/sreg/1.0'
# URI in the draft specification for simple registration 1.1
# <http://openid.net/specs/openid-simple-registration-extension-1_1-01.html>
ns_uri_1_1 = 'http://openid.net/extensions/sreg/1.1'
# This attribute will always hold the preferred URI to use when adding
# sreg support to an XRDS file or in an OpenID namespace declaration.
ns_uri = ns_uri_1_1
try:
registerNamespaceAlias(ns_uri_1_1, 'sreg')
except NamespaceAliasRegistrationError as e:
logger.exception('registerNamespaceAlias(%r, %r) failed: %s' %
(ns_uri_1_1, 'sreg', str(e), ))
def supportsSReg(endpoint):
"""Does the given endpoint advertise support for simple
registration?
@param endpoint: The endpoint object as returned by OpenID discovery
@type endpoint: openid.consumer.discover.OpenIDEndpoint
@returns: Whether an sreg type was advertised by the endpoint
@rtype: bool
"""
return (endpoint.usesExtension(ns_uri_1_1) or
endpoint.usesExtension(ns_uri_1_0))
class SRegNamespaceError(ValueError):
"""The simple registration namespace was not found and could not
be created using the expected name (there's another extension
using the name 'sreg')
This is not I{illegal}, for OpenID 2, although it probably
indicates a problem, since it's not expected that other extensions
will re-use the alias that is in use for OpenID 1.
If this is an OpenID 1 request, then there is no recourse. This
should not happen unless some code has modified the namespaces for
the message that is being processed.
"""
def getSRegNS(message):
"""Extract the simple registration namespace URI from the given
OpenID message. Handles OpenID 1 and 2, as well as both sreg
namespace URIs found in the wild, as well as missing namespace
definitions (for OpenID 1)
@param message: The OpenID message from which to parse simple
registration fields. This may be a request or response message.
@type message: C{L{openid.message.Message}}
@returns: the sreg namespace URI for the supplied message. The
message may be modified to define a simple registration
namespace.
@rtype: C{str}
@raise ValueError: when using OpenID 1 if the message defines
the 'sreg' alias to be something other than a simple
registration type.
"""
# See if there exists an alias for one of the two defined simple
# registration types.
for sreg_ns_uri in [ns_uri_1_1, ns_uri_1_0]:
alias = message.namespaces.getAlias(sreg_ns_uri)
if alias is not None:
break
else:
# There is no alias for either of the types, so try to add
# one. We default to using the modern value (1.1)
sreg_ns_uri = ns_uri_1_1
try:
message.namespaces.addAlias(ns_uri_1_1, 'sreg')
except KeyError as why:
# An alias for the string 'sreg' already exists, but it's
# defined for something other than simple registration
raise SRegNamespaceError(why)
# we know that sreg_ns_uri defined, because it's defined in the
# else clause of the loop as well, so disable the warning
return sreg_ns_uri #pylint:disable-msg=W0631
class SRegRequest(Extension):
"""An object to hold the state of a simple registration request.
@ivar required: A list of the required fields in this simple
registration request
@type required: [str]
@ivar optional: A list of the optional fields in this simple
registration request
@type optional: [str]
@ivar policy_url: The policy URL that was provided with the request
@type policy_url: str or NoneType
@group Consumer: requestField, requestFields, getExtensionArgs, addToOpenIDRequest
@group Server: fromOpenIDRequest, parseExtensionArgs
"""
ns_alias = 'sreg'
def __init__(self,
required=None,
optional=None,
policy_url=None,
sreg_ns_uri=ns_uri):
"""Initialize an empty simple registration request"""
Extension.__init__(self)
self.required = []
self.optional = []
self.policy_url = policy_url
self.ns_uri = sreg_ns_uri
if required:
self.requestFields(required, required=True, strict=True)
if optional:
self.requestFields(optional, required=False, strict=True)
# Assign getSRegNS to a static method so that it can be
# overridden for testing.
_getSRegNS = staticmethod(getSRegNS)
def fromOpenIDRequest(cls, request):
"""Create a simple registration request that contains the
fields that were requested in the OpenID request with the
given arguments
@param request: The OpenID request
@type request: openid.server.CheckIDRequest
@returns: The newly created simple registration request
@rtype: C{L{SRegRequest}}
"""
self = cls()
# Since we're going to mess with namespace URI mapping, don't
# mutate the object that was passed in.
message = request.message.copy()
self.ns_uri = self._getSRegNS(message)
args = message.getArgs(self.ns_uri)
self.parseExtensionArgs(args)
return self
fromOpenIDRequest = classmethod(fromOpenIDRequest)
def parseExtensionArgs(self, args, strict=False):
"""Parse the unqualified simple registration request
parameters and add them to this object.
This method is essentially the inverse of
C{L{getExtensionArgs}}. This method restores the serialized simple
registration request fields.
If you are extracting arguments from a standard OpenID
checkid_* request, you probably want to use C{L{fromOpenIDRequest}},
which will extract the sreg namespace and arguments from the
OpenID request. This method is intended for cases where the
OpenID server needs more control over how the arguments are
parsed than that method provides.
>>> args = message.getArgs(ns_uri)
>>> request.parseExtensionArgs(args)
@param args: The unqualified simple registration arguments
@type args: {str:str}
@param strict: Whether requests with fields that are not
defined in the simple registration specification should be
tolerated (and ignored)
@type strict: bool
@returns: None; updates this object
"""
for list_name in ['required', 'optional']:
required = (list_name == 'required')
items = args.get(list_name)
if items:
for field_name in items.split(','):
try:
self.requestField(field_name, required, strict)
except ValueError:
if strict:
raise
self.policy_url = args.get('policy_url')
def allRequestedFields(self):
"""A list of all of the simple registration fields that were
requested, whether they were required or optional.
@rtype: [str]
"""
return self.required + self.optional
def wereFieldsRequested(self):
"""Have any simple registration fields been requested?
@rtype: bool
"""
return bool(self.allRequestedFields())
def __contains__(self, field_name):
"""Was this field in the request?"""
return (field_name in self.required or field_name in self.optional)
def requestField(self, field_name, required=False, strict=False):
"""Request the specified field from the OpenID user
@param field_name: the unqualified simple registration field name
@type field_name: str
@param required: whether the given field should be presented
to the user as being a required to successfully complete
the request
@param strict: whether to raise an exception when a field is
added to a request more than once
@raise ValueError: when the field requested is not a simple
registration field or strict is set and the field was
requested more than once
"""
checkFieldName(field_name)
if strict:
if field_name in self.required or field_name in self.optional:
raise ValueError('That field has already been requested')
else:
if field_name in self.required:
return
if field_name in self.optional:
if required:
self.optional.remove(field_name)
else:
return
if required:
self.required.append(field_name)
else:
self.optional.append(field_name)
def requestFields(self, field_names, required=False, strict=False):
"""Add the given list of fields to the request
@param field_names: The simple registration data fields to request
@type field_names: [str]
@param required: Whether these values should be presented to
the user as required
@param strict: whether to raise an exception when a field is
added to a request more than once
@raise ValueError: when a field requested is not a simple
registration field or strict is set and a field was
requested more than once
"""
if isinstance(field_names, str):
raise TypeError('Fields should be passed as a list of '
'strings (not %r)' % (type(field_names), ))
for field_name in field_names:
self.requestField(field_name, required, strict=strict)
def getExtensionArgs(self):
"""Get a dictionary of unqualified simple registration
arguments representing this request.
This method is essentially the inverse of
C{L{parseExtensionArgs}}. This method serializes the simple
registration request fields.
@rtype: {str:str}
"""
args = {}
if self.required:
args['required'] = ','.join(self.required)
if self.optional:
args['optional'] = ','.join(self.optional)
if self.policy_url:
args['policy_url'] = self.policy_url
return args
class SRegResponse(Extension):
"""Represents the data returned in a simple registration response
inside of an OpenID C{id_res} response. This object will be
created by the OpenID server, added to the C{id_res} response
object, and then extracted from the C{id_res} message by the
Consumer.
@ivar data: The simple registration data, keyed by the unqualified
simple registration name of the field (i.e. nickname is keyed
by C{'nickname'})
@ivar ns_uri: The URI under which the simple registration data was
stored in the response message.
@group Server: extractResponse
@group Consumer: fromSuccessResponse
@group Read-only dictionary interface: keys, iterkeys, items, iteritems,
__iter__, get, __getitem__, keys, has_key
"""
ns_alias = 'sreg'
def __init__(self, data=None, sreg_ns_uri=ns_uri):
Extension.__init__(self)
if data is None:
self.data = {}
else:
self.data = data
self.ns_uri = sreg_ns_uri
def extractResponse(cls, request, data):
"""Take a C{L{SRegRequest}} and a dictionary of simple
registration values and create a C{L{SRegResponse}}
object containing that data.
@param request: The simple registration request object
@type request: SRegRequest
@param data: The simple registration data for this
response, as a dictionary from unqualified simple
registration field name to string (unicode) value. For
instance, the nickname should be stored under the key
'nickname'.
@type data: {str:str}
@returns: a simple registration response object
@rtype: SRegResponse
"""
self = cls()
self.ns_uri = request.ns_uri
for field in request.allRequestedFields():
value = data.get(field)
if value is not None:
self.data[field] = value
return self
extractResponse = classmethod(extractResponse)
# Assign getSRegArgs to a static method so that it can be
# overridden for testing
_getSRegNS = staticmethod(getSRegNS)
def fromSuccessResponse(cls, success_response, signed_only=True):
"""Create a C{L{SRegResponse}} object from a successful OpenID
library response
(C{L{openid.consumer.consumer.SuccessResponse}}) response
message
@param success_response: A SuccessResponse from consumer.complete()
@type success_response: C{L{openid.consumer.consumer.SuccessResponse}}
@param signed_only: Whether to process only data that was
signed in the id_res message from the server.
@type signed_only: bool
@rtype: SRegResponse
@returns: A simple registration response containing the data
that was supplied with the C{id_res} response.
"""
self = cls()
self.ns_uri = self._getSRegNS(success_response.message)
if signed_only:
args = success_response.getSignedNS(self.ns_uri)
else:
args = success_response.message.getArgs(self.ns_uri)
if not args:
return None
for field_name in data_fields:
if field_name in args:
self.data[field_name] = args[field_name]
return self
fromSuccessResponse = classmethod(fromSuccessResponse)
def getExtensionArgs(self):
"""Get the fields to put in the simple registration namespace
when adding them to an id_res message.
@see: openid.extension
"""
return self.data
# Read-only dictionary interface
def get(self, field_name, default=None):
"""Like dict.get, except that it checks that the field name is
defined by the simple registration specification"""
checkFieldName(field_name)
return self.data.get(field_name, default)
def items(self):
"""All of the data values in this simple registration response
"""
return list(self.data.items())
def iteritems(self):
return iter(self.data.items())
def keys(self):
return list(self.data.keys())
def iterkeys(self):
return iter(self.data.keys())
def has_key(self, key):
return key in self
def __contains__(self, field_name):
checkFieldName(field_name)
return field_name in self.data
def __iter__(self):
return iter(self.data)
def __getitem__(self, field_name):
checkFieldName(field_name)
return self.data[field_name]
def __bool__(self):
return bool(self.data)

View File

@@ -0,0 +1,493 @@
# -*- test-case-name: openid.test.test_fetchers -*-
"""
This module contains the HTTP fetcher interface and several implementations.
"""
__all__ = [
'fetch', 'getDefaultFetcher', 'setDefaultFetcher', 'HTTPResponse',
'HTTPFetcher', 'createHTTPFetcher', 'HTTPFetchingError', 'HTTPError'
]
import urllib.request
import urllib.error
import urllib.parse
import http.client
import time
import io
import sys
import contextlib
import openid
import openid.urinorm
# Try to import httplib2 for caching support
# http://bitworking.org/projects/httplib2/
try:
import httplib2
except ImportError:
# httplib2 not available
httplib2 = None
# try to import pycurl, which will let us use CurlHTTPFetcher
try:
import pycurl
except ImportError:
pycurl = None
USER_AGENT = "python-openid/%s (%s)" % (openid.__version__, sys.platform)
MAX_RESPONSE_KB = 1024
def fetch(url, body=None, headers=None):
"""Invoke the fetch method on the default fetcher. Most users
should need only this method.
@raises Exception: any exceptions that may be raised by the default fetcher
"""
fetcher = getDefaultFetcher()
return fetcher.fetch(url, body, headers)
def createHTTPFetcher():
"""Create a default HTTP fetcher instance
prefers Curl to urllib2."""
if pycurl is None:
fetcher = Urllib2Fetcher()
else:
fetcher = CurlHTTPFetcher()
return fetcher
# Contains the currently set HTTP fetcher. If it is set to None, the
# library will call createHTTPFetcher() to set it. Do not access this
# variable outside of this module.
_default_fetcher = None
def getDefaultFetcher():
"""Return the default fetcher instance
if no fetcher has been set, it will create a default fetcher.
@return: the default fetcher
@rtype: HTTPFetcher
"""
global _default_fetcher
if _default_fetcher is None:
setDefaultFetcher(createHTTPFetcher())
return _default_fetcher
def setDefaultFetcher(fetcher, wrap_exceptions=True):
"""Set the default fetcher
@param fetcher: The fetcher to use as the default HTTP fetcher
@type fetcher: HTTPFetcher
@param wrap_exceptions: Whether to wrap exceptions thrown by the
fetcher wil HTTPFetchingError so that they may be caught
easier. By default, exceptions will be wrapped. In general,
unwrapped fetchers are useful for debugging of fetching errors
or if your fetcher raises well-known exceptions that you would
like to catch.
@type wrap_exceptions: bool
"""
global _default_fetcher
if fetcher is None or not wrap_exceptions:
_default_fetcher = fetcher
else:
_default_fetcher = ExceptionWrappingFetcher(fetcher)
def usingCurl():
"""Whether the currently set HTTP fetcher is a Curl HTTP fetcher."""
fetcher = getDefaultFetcher()
if isinstance(fetcher, ExceptionWrappingFetcher):
fetcher = fetcher.fetcher
return isinstance(fetcher, CurlHTTPFetcher)
class HTTPResponse(object):
"""XXX document attributes"""
headers = None
status = None
body = None
final_url = None
def __init__(self, final_url=None, status=None, headers=None, body=None):
self.final_url = final_url
self.status = status
self.headers = headers
self.body = body
def __repr__(self):
return "<%s status %s for %s>" % (self.__class__.__name__, self.status,
self.final_url)
class HTTPFetcher(object):
"""
This class is the interface for openid HTTP fetchers. This
interface is only important if you need to write a new fetcher for
some reason.
"""
def fetch(self, url, body=None, headers=None):
"""
This performs an HTTP POST or GET, following redirects along
the way. If a body is specified, then the request will be a
POST. Otherwise, it will be a GET.
@param headers: HTTP headers to include with the request
@type headers: {str:str}
@return: An object representing the server's HTTP response. If
there are network or protocol errors, an exception will be
raised. HTTP error responses, like 404 or 500, do not
cause exceptions.
@rtype: L{HTTPResponse}
@raise Exception: Different implementations will raise
different errors based on the underlying HTTP library.
"""
raise NotImplementedError
def _allowedURL(url):
parsed = urllib.parse.urlparse(url)
# scheme is the first item in the tuple
return parsed[0] in ('http', 'https')
class HTTPFetchingError(Exception):
"""Exception that is wrapped around all exceptions that are raised
by the underlying fetcher when using the ExceptionWrappingFetcher
@ivar why: The exception that caused this exception
"""
def __init__(self, why=None):
Exception.__init__(self, why)
self.why = why
class ExceptionWrappingFetcher(HTTPFetcher):
"""Fetcher that wraps another fetcher, causing all exceptions
@cvar uncaught_exceptions: Exceptions that should be exposed to the
user if they are raised by the fetch call
"""
uncaught_exceptions = (SystemExit, KeyboardInterrupt, MemoryError)
def __init__(self, fetcher):
self.fetcher = fetcher
def fetch(self, *args, **kwargs):
try:
return self.fetcher.fetch(*args, **kwargs)
except self.uncaught_exceptions:
raise
except:
exc_cls, exc_inst = sys.exc_info()[:2]
if exc_inst is None:
# string exceptions
exc_inst = exc_cls
raise HTTPFetchingError(why=exc_inst)
class Urllib2Fetcher(HTTPFetcher):
"""An C{L{HTTPFetcher}} that uses urllib2.
"""
# Parameterized for the benefit of testing frameworks, see
# http://trac.openidenabled.com/trac/ticket/85
urlopen = staticmethod(urllib.request.urlopen)
def fetch(self, url, body=None, headers=None):
if not _allowedURL(url):
raise ValueError('Bad URL scheme: %r' % (url, ))
if headers is None:
headers = {}
headers.setdefault('User-Agent', "%s Python-urllib/%s" %
(USER_AGENT, urllib.request.__version__))
if isinstance(body, str):
body = bytes(body, encoding="utf-8")
req = urllib.request.Request(url, data=body, headers=headers)
url_resource = None
try:
url_resource = self.urlopen(req)
with contextlib.closing(url_resource):
return self._makeResponse(url_resource)
except urllib.error.HTTPError as why:
with contextlib.closing(why):
resp = self._makeResponse(why)
return resp
except (urllib.error.URLError, http.client.BadStatusLine) as why:
raise
except Exception as why:
raise AssertionError(why)
def _makeResponse(self, urllib2_response):
'''
Construct an HTTPResponse from the the urllib response. Attempt to
decode the response body from bytes to str if the necessary information
is available.
'''
resp = HTTPResponse()
resp.body = urllib2_response.read(MAX_RESPONSE_KB * 1024)
resp.final_url = urllib2_response.geturl()
resp.headers = self._lowerCaseKeys(
dict(list(urllib2_response.info().items())))
if hasattr(urllib2_response, 'code'):
resp.status = urllib2_response.code
else:
resp.status = 200
_, extra_dict = self._parseHeaderValue(
resp.headers.get("content-type", ""))
# Try to decode the response body to a string, if there's a
# charset known; fall back to ISO-8859-1 otherwise, since that's
# what's suggested in HTTP/1.1
charset = extra_dict.get('charset', 'latin1')
try:
resp.body = resp.body.decode(charset)
except Exception:
pass
return resp
def _lowerCaseKeys(self, headers_dict):
new_dict = {}
for k, v in headers_dict.items():
new_dict[k.lower()] = v
return new_dict
def _parseHeaderValue(self, header_value):
"""
Parse out a complex header value (such as Content-Type, with a value
like "text/html; charset=utf-8") into a main value and a dictionary of
extra information (in this case, 'text/html' and {'charset': 'utf8'}).
"""
values = header_value.split(';', 1)
if len(values) == 1:
# There's no extra info -- return the main value and an empty dict
return values[0], {}
main_value, extra_values = values[0], values[1].split(';')
extra_dict = {}
for value_string in extra_values:
try:
key, value = value_string.split('=', 1)
extra_dict[key.strip()] = value.strip()
except ValueError:
# Can't unpack it -- must be malformed. Ignore
pass
return main_value, extra_dict
class HTTPError(HTTPFetchingError):
"""
This exception is raised by the C{L{CurlHTTPFetcher}} when it
encounters an exceptional situation fetching a URL.
"""
pass
# XXX: define what we mean by paranoid, and make sure it is.
class CurlHTTPFetcher(HTTPFetcher):
"""
An C{L{HTTPFetcher}} that uses pycurl for fetching.
See U{http://pycurl.sourceforge.net/}.
"""
ALLOWED_TIME = 20 # seconds
def __init__(self):
HTTPFetcher.__init__(self)
if pycurl is None:
raise RuntimeError('Cannot find pycurl library')
def _parseHeaders(self, header_file):
header_file.seek(0)
# Remove all non "name: value" header lines from the input
lines = [line.decode().strip() for line in header_file if b':' in line]
headers = {}
for line in lines:
try:
name, value = line.split(':', 1)
except ValueError:
raise HTTPError("Malformed HTTP header line in response: %r" %
(line, ))
value = value.strip()
# HTTP headers are case-insensitive
name = name.lower()
headers[name] = value
return headers
def _checkURL(self, url):
# XXX: document that this can be overridden to match desired policy
# XXX: make sure url is well-formed and routeable
return _allowedURL(url)
def fetch(self, url, body=None, headers=None):
stop = int(time.time()) + self.ALLOWED_TIME
off = self.ALLOWED_TIME
if headers is None:
headers = {}
headers.setdefault('User-Agent',
"%s %s" % (USER_AGENT, pycurl.version, ))
header_list = []
if headers is not None:
for header_name, header_value in headers.items():
header = '%s: %s' % (header_name, header_value)
header_list.append(header.encode())
c = pycurl.Curl()
try:
c.setopt(pycurl.NOSIGNAL, 1)
if header_list:
c.setopt(pycurl.HTTPHEADER, header_list)
# Presence of a body indicates that we should do a POST
if body is not None:
c.setopt(pycurl.POST, 1)
c.setopt(pycurl.POSTFIELDS, body)
while off > 0:
if not self._checkURL(url):
raise HTTPError("Fetching URL not allowed: %r" % (url, ))
data = io.BytesIO()
def write_data(chunk):
if data.tell() > (1024 * MAX_RESPONSE_KB):
return 0
else:
return data.write(chunk)
response_header_data = io.BytesIO()
c.setopt(pycurl.WRITEFUNCTION, write_data)
c.setopt(pycurl.HEADERFUNCTION, response_header_data.write)
c.setopt(pycurl.TIMEOUT, off)
c.setopt(pycurl.URL, openid.urinorm.urinorm(url))
c.perform()
response_headers = self._parseHeaders(response_header_data)
code = c.getinfo(pycurl.RESPONSE_CODE)
if code in [301, 302, 303, 307]:
url = response_headers.get('location')
if url is None:
raise HTTPError(
'Redirect (%s) returned without a location' % code)
# Redirects are always GETs
c.setopt(pycurl.POST, 0)
# There is no way to reset POSTFIELDS to empty and
# reuse the connection, but we only use it once.
else:
resp = HTTPResponse()
resp.headers = response_headers
resp.status = code
resp.final_url = url
resp.body = data.getvalue().decode()
return resp
off = stop - int(time.time())
raise HTTPError("Timed out fetching: %r" % (url, ))
finally:
c.close()
class HTTPLib2Fetcher(HTTPFetcher):
"""A fetcher that uses C{httplib2} for performing HTTP
requests. This implementation supports HTTP caching.
@see: http://bitworking.org/projects/httplib2/
"""
def __init__(self, cache=None):
"""@param cache: An object suitable for use as an C{httplib2}
cache. If a string is passed, it is assumed to be a
directory name.
"""
if httplib2 is None:
raise RuntimeError('Cannot find httplib2 library. '
'See http://bitworking.org/projects/httplib2/')
super(HTTPLib2Fetcher, self).__init__()
# An instance of the httplib2 object that performs HTTP requests
self.httplib2 = httplib2.Http(cache)
# We want httplib2 to raise exceptions for errors, just like
# the other fetchers.
self.httplib2.force_exception_to_status_code = False
def fetch(self, url, body=None, headers=None):
"""Perform an HTTP request
@raises Exception: Any exception that can be raised by httplib2
@see: C{L{HTTPFetcher.fetch}}
"""
if body:
method = 'POST'
else:
method = 'GET'
if headers is None:
headers = {}
# httplib2 doesn't check to make sure that the URL's scheme is
# 'http' so we do it here.
if not (url.startswith('http://') or url.startswith('https://')):
raise ValueError('URL is not a HTTP URL: %r' % (url, ))
httplib2_response, content = self.httplib2.request(
url, method, body=body, headers=headers)
# Translate the httplib2 response to our HTTP response abstraction
# When a 400 is returned, there is no "content-location"
# header set. This seems like a bug to me. I can't think of a
# case where we really care about the final URL when it is an
# error response, but being careful about it can't hurt.
try:
final_url = httplib2_response['content-location']
except KeyError:
# We're assuming that no redirects occurred
assert not httplib2_response.previous
# And this should never happen for a successful response
assert httplib2_response.status != 200
final_url = url
return HTTPResponse(
body=content.decode(), # TODO Don't assume ASCII
final_url=final_url,
headers=dict(list(httplib2_response.items())),
status=httplib2_response.status, )

View File

@@ -0,0 +1,134 @@
import logging
logger = logging.getLogger(__name__)
__all__ = ['seqToKV', 'kvToSeq', 'dictToKV', 'kvToDict']
class KVFormError(ValueError):
pass
def seqToKV(seq, strict=False):
"""Represent a sequence of pairs of strings as newline-terminated
key:value pairs. The pairs are generated in the order given.
@param seq: The pairs
@type seq: [(str, (unicode|str))]
@return: A string representation of the sequence
@rtype: bytes
"""
def err(msg):
formatted = 'seqToKV warning: %s: %r' % (msg, seq)
if strict:
raise KVFormError(formatted)
else:
logger.warning(formatted)
lines = []
for k, v in seq:
if isinstance(k, bytes):
k = k.decode('utf-8')
elif not isinstance(k, str):
err('Converting key to string: %r' % k)
k = str(k)
if '\n' in k:
raise KVFormError(
'Invalid input for seqToKV: key contains newline: %r' % (k, ))
if ':' in k:
raise KVFormError(
'Invalid input for seqToKV: key contains colon: %r' % (k, ))
if k.strip() != k:
err('Key has whitespace at beginning or end: %r' % (k, ))
if isinstance(v, bytes):
v = v.decode('utf-8')
elif not isinstance(v, str):
err('Converting value to string: %r' % (v, ))
v = str(v)
if '\n' in v:
raise KVFormError(
'Invalid input for seqToKV: value contains newline: %r' %
(v, ))
if v.strip() != v:
err('Value has whitespace at beginning or end: %r' % (v, ))
lines.append(k + ':' + v + '\n')
return ''.join(lines).encode('utf-8')
def kvToSeq(data, strict=False):
"""
After one parse, seqToKV and kvToSeq are inverses, with no warnings::
seq = kvToSeq(s)
seqToKV(kvToSeq(seq)) == seq
@return str
"""
def err(msg):
formatted = 'kvToSeq warning: %s: %r' % (msg, data)
if strict:
raise KVFormError(formatted)
else:
logger.warning(formatted)
if isinstance(data, bytes):
data = data.decode("utf-8")
lines = data.split('\n')
if lines[-1]:
err('Does not end in a newline')
else:
del lines[-1]
pairs = []
line_num = 0
for line in lines:
line_num += 1
# Ignore blank lines
if not line.strip():
continue
pair = line.split(':', 1)
if len(pair) == 2:
k, v = pair
k_s = k.strip()
if k_s != k:
fmt = ('In line %d, ignoring leading or trailing '
'whitespace in key %r')
err(fmt % (line_num, k))
if not k_s:
err('In line %d, got empty key' % (line_num, ))
v_s = v.strip()
if v_s != v:
fmt = ('In line %d, ignoring leading or trailing '
'whitespace in value %r')
err(fmt % (line_num, v))
pairs.append((k_s, v_s))
else:
err('Line %d does not contain a colon' % line_num)
return pairs
def dictToKV(d):
return seqToKV(sorted(d.items()))
def kvToDict(s):
return dict(kvToSeq(s))

View File

@@ -0,0 +1,678 @@
"""Extension argument processing code
"""
__all__ = [
'Message', 'NamespaceMap', 'no_default', 'registerNamespaceAlias',
'OPENID_NS', 'BARE_NS', 'OPENID1_NS', 'OPENID2_NS', 'SREG_URI',
'IDENTIFIER_SELECT'
]
import copy
import warnings
import urllib.request
import urllib.error
from openid import oidutil
from openid import kvform
try:
ElementTree = oidutil.importElementTree()
except ImportError:
# No elementtree found, so give up, but don't fail to import,
# since we have fallbacks.
ElementTree = None
# This doesn't REALLY belong here, but where is better?
IDENTIFIER_SELECT = 'http://specs.openid.net/auth/2.0/identifier_select'
# URI for Simple Registration extension, the only commonly deployed
# OpenID 1.x extension, and so a special case
SREG_URI = 'http://openid.net/sreg/1.0'
# The OpenID 1.X namespace URI
OPENID1_NS = 'http://openid.net/signon/1.0'
THE_OTHER_OPENID1_NS = 'http://openid.net/signon/1.1'
OPENID1_NAMESPACES = OPENID1_NS, THE_OTHER_OPENID1_NS
# The OpenID 2.0 namespace URI
OPENID2_NS = 'http://specs.openid.net/auth/2.0'
# The namespace consisting of pairs with keys that are prefixed with
# "openid." but not in another namespace.
NULL_NAMESPACE = oidutil.Symbol('Null namespace')
# The null namespace, when it is an allowed OpenID namespace
OPENID_NS = oidutil.Symbol('OpenID namespace')
# The top-level namespace, excluding all pairs with keys that start
# with "openid."
BARE_NS = oidutil.Symbol('Bare namespace')
# Limit, in bytes, of identity provider and return_to URLs, including
# response payload. See OpenID 1.1 specification, Appendix D.
OPENID1_URL_LIMIT = 2047
# All OpenID protocol fields. Used to check namespace aliases.
OPENID_PROTOCOL_FIELDS = [
'ns',
'mode',
'error',
'return_to',
'contact',
'reference',
'signed',
'assoc_type',
'session_type',
'dh_modulus',
'dh_gen',
'dh_consumer_public',
'claimed_id',
'identity',
'realm',
'invalidate_handle',
'op_endpoint',
'response_nonce',
'sig',
'assoc_handle',
'trust_root',
'openid',
]
class UndefinedOpenIDNamespace(ValueError):
"""Raised if the generic OpenID namespace is accessed when there
is no OpenID namespace set for this message."""
class InvalidOpenIDNamespace(ValueError):
"""Raised if openid.ns is not a recognized value.
For recognized values, see L{Message.allowed_openid_namespaces}
"""
def __str__(self):
s = "Invalid OpenID Namespace"
if self.args:
s += " %r" % (self.args[0], )
return s
# Sentinel used for Message implementation to indicate that getArg
# should raise an exception instead of returning a default.
no_default = object()
# Global namespace / alias registration map. See
# registerNamespaceAlias.
registered_aliases = {}
class NamespaceAliasRegistrationError(Exception):
"""
Raised when an alias or namespace URI has already been registered.
"""
pass
def registerNamespaceAlias(namespace_uri, alias):
"""
Registers a (namespace URI, alias) mapping in a global namespace
alias map. Raises NamespaceAliasRegistrationError if either the
namespace URI or alias has already been registered with a
different value. This function is required if you want to use a
namespace with an OpenID 1 message.
"""
global registered_aliases
if registered_aliases.get(alias) == namespace_uri:
return
if namespace_uri in list(registered_aliases.values()):
raise NamespaceAliasRegistrationError(
'Namespace uri %r already registered' % (namespace_uri, ))
if alias in registered_aliases:
raise NamespaceAliasRegistrationError('Alias %r already registered' %
(alias, ))
registered_aliases[alias] = namespace_uri
class Message(object):
"""
In the implementation of this object, None represents the global
namespace as well as a namespace with no key.
@cvar namespaces: A dictionary specifying specific
namespace-URI to alias mappings that should be used when
generating namespace aliases.
@ivar ns_args: two-level dictionary of the values in this message,
grouped by namespace URI. The first level is the namespace
URI.
"""
allowed_openid_namespaces = [OPENID1_NS, THE_OTHER_OPENID1_NS, OPENID2_NS]
def __init__(self, openid_namespace=None):
"""Create an empty Message.
@raises InvalidOpenIDNamespace: if openid_namespace is not in
L{Message.allowed_openid_namespaces}
"""
self.args = {}
self.namespaces = NamespaceMap()
if openid_namespace is None:
self._openid_ns_uri = None
else:
implicit = openid_namespace in OPENID1_NAMESPACES
self.setOpenIDNamespace(openid_namespace, implicit)
@classmethod
def fromPostArgs(cls, args):
"""Construct a Message containing a set of POST arguments.
"""
self = cls()
# Partition into "openid." args and bare args
openid_args = {}
for key, value in args.items():
if isinstance(value, list):
raise TypeError("query dict must have one value for each key, "
"not lists of values. Query is %r" % (args, ))
try:
prefix, rest = key.split('.', 1)
except ValueError:
prefix = None
if prefix != 'openid':
self.args[(BARE_NS, key)] = value
else:
openid_args[rest] = value
self._fromOpenIDArgs(openid_args)
return self
@classmethod
def fromOpenIDArgs(cls, openid_args):
"""Construct a Message from a parsed KVForm message.
@raises InvalidOpenIDNamespace: if openid.ns is not in
L{Message.allowed_openid_namespaces}
"""
self = cls()
self._fromOpenIDArgs(openid_args)
return self
def _fromOpenIDArgs(self, openid_args):
ns_args = []
# Resolve namespaces
for rest, value in openid_args.items():
try:
ns_alias, ns_key = rest.split('.', 1)
except ValueError:
ns_alias = NULL_NAMESPACE
ns_key = rest
if ns_alias == 'ns':
self.namespaces.addAlias(value, ns_key)
elif ns_alias == NULL_NAMESPACE and ns_key == 'ns':
# null namespace
self.setOpenIDNamespace(value, False)
else:
ns_args.append((ns_alias, ns_key, value))
# Implicitly set an OpenID namespace definition (OpenID 1)
if not self.getOpenIDNamespace():
self.setOpenIDNamespace(OPENID1_NS, True)
# Actually put the pairs into the appropriate namespaces
for (ns_alias, ns_key, value) in ns_args:
ns_uri = self.namespaces.getNamespaceURI(ns_alias)
if ns_uri is None:
# we found a namespaced arg without a namespace URI defined
ns_uri = self._getDefaultNamespace(ns_alias)
if ns_uri is None:
ns_uri = self.getOpenIDNamespace()
ns_key = '%s.%s' % (ns_alias, ns_key)
else:
self.namespaces.addAlias(ns_uri, ns_alias, implicit=True)
self.setArg(ns_uri, ns_key, value)
def _getDefaultNamespace(self, mystery_alias):
"""OpenID 1 compatibility: look for a default namespace URI to
use for this alias."""
global registered_aliases
# Only try to map an alias to a default if it's an
# OpenID 1.x message.
if self.isOpenID1():
return registered_aliases.get(mystery_alias)
else:
return None
def setOpenIDNamespace(self, openid_ns_uri, implicit):
"""Set the OpenID namespace URI used in this message.
@raises InvalidOpenIDNamespace: if the namespace is not in
L{Message.allowed_openid_namespaces}
"""
if isinstance(openid_ns_uri, bytes):
openid_ns_uri = str(openid_ns_uri, encoding="utf-8")
if openid_ns_uri not in self.allowed_openid_namespaces:
raise InvalidOpenIDNamespace(openid_ns_uri)
self.namespaces.addAlias(openid_ns_uri, NULL_NAMESPACE, implicit)
self._openid_ns_uri = openid_ns_uri
def getOpenIDNamespace(self):
return self._openid_ns_uri
def isOpenID1(self):
return self.getOpenIDNamespace() in OPENID1_NAMESPACES
def isOpenID2(self):
return self.getOpenIDNamespace() == OPENID2_NS
def fromKVForm(cls, kvform_string):
"""Create a Message from a KVForm string"""
return cls.fromOpenIDArgs(kvform.kvToDict(kvform_string))
fromKVForm = classmethod(fromKVForm)
def copy(self):
return copy.deepcopy(self)
def toPostArgs(self):
"""
Return all arguments with openid. in front of namespaced arguments.
@return bytes
"""
args = {}
# Add namespace definitions to the output
for ns_uri, alias in self.namespaces.items():
if self.namespaces.isImplicit(ns_uri):
continue
if alias == NULL_NAMESPACE:
ns_key = 'openid.ns'
else:
ns_key = 'openid.ns.' + alias
args[ns_key] = oidutil.toUnicode(ns_uri)
for (ns_uri, ns_key), value in self.args.items():
key = self.getKey(ns_uri, ns_key)
# Ensure the resulting value is an UTF-8 encoded *bytestring*.
args[key] = oidutil.toUnicode(value)
return args
def toArgs(self):
"""Return all namespaced arguments, failing if any
non-namespaced arguments exist."""
# FIXME - undocumented exception
post_args = self.toPostArgs()
kvargs = {}
for k, v in post_args.items():
if not k.startswith('openid.'):
raise ValueError(
'This message can only be encoded as a POST, because it '
'contains arguments that are not prefixed with "openid."')
else:
kvargs[k[7:]] = v
return kvargs
def toFormMarkup(self,
action_url,
form_tag_attrs=None,
submit_text="Continue"):
"""Generate HTML form markup that contains the values in this
message, to be HTTP POSTed as x-www-form-urlencoded UTF-8.
@param action_url: The URL to which the form will be POSTed
@type action_url: str
@param form_tag_attrs: Dictionary of attributes to be added to
the form tag. 'accept-charset' and 'enctype' have defaults
that can be overridden. If a value is supplied for
'action' or 'method', it will be replaced.
@type form_tag_attrs: {unicode: unicode}
@param submit_text: The text that will appear on the submit
button for this form.
@type submit_text: unicode
@returns: A string containing (X)HTML markup for a form that
encodes the values in this Message object.
@rtype: str
"""
if ElementTree is None:
raise RuntimeError('This function requires ElementTree.')
assert action_url is not None
form = ElementTree.Element('form')
if form_tag_attrs:
for name, attr in form_tag_attrs.items():
form.attrib[name] = attr
form.attrib['action'] = oidutil.toUnicode(action_url)
form.attrib['method'] = 'post'
form.attrib['accept-charset'] = 'UTF-8'
form.attrib['enctype'] = 'application/x-www-form-urlencoded'
for name, value in self.toPostArgs().items():
attrs = {
'type': 'hidden',
'name': oidutil.toUnicode(name),
'value': oidutil.toUnicode(value)
}
form.append(ElementTree.Element('input', attrs))
submit = ElementTree.Element(
'input',
{'type': 'submit',
'value': oidutil.toUnicode(submit_text)})
form.append(submit)
return str(ElementTree.tostring(form, encoding='utf-8'),
encoding="utf-8")
def toURL(self, base_url):
"""Generate a GET URL with the parameters in this message
attached as query parameters."""
return oidutil.appendArgs(base_url, self.toPostArgs())
def toKVForm(self):
"""Generate a KVForm string that contains the parameters in
this message. This will fail if the message contains arguments
outside of the 'openid.' prefix.
"""
return kvform.dictToKV(self.toArgs())
def toURLEncoded(self):
"""Generate an x-www-urlencoded string"""
args = sorted(self.toPostArgs().items())
return urllib.parse.urlencode(args)
def _fixNS(self, namespace):
"""Convert an input value into the internally used values of
this object
@param namespace: The string or constant to convert
@type namespace: str or unicode or BARE_NS or OPENID_NS
"""
if isinstance(namespace, bytes):
namespace = str(namespace, encoding="utf-8")
if namespace == OPENID_NS:
if self._openid_ns_uri is None:
raise UndefinedOpenIDNamespace('OpenID namespace not set')
else:
namespace = self._openid_ns_uri
if namespace != BARE_NS and not isinstance(namespace, str):
raise TypeError(
"Namespace must be BARE_NS, OPENID_NS or a string. got %r" %
(namespace, ))
if namespace != BARE_NS and ':' not in namespace:
fmt = 'OpenID 2.0 namespace identifiers SHOULD be URIs. Got %r'
warnings.warn(fmt % (namespace, ), DeprecationWarning)
if namespace == 'sreg':
fmt = 'Using %r instead of "sreg" as namespace'
warnings.warn(
fmt % (SREG_URI, ),
DeprecationWarning, )
return SREG_URI
return namespace
def hasKey(self, namespace, ns_key):
namespace = self._fixNS(namespace)
return (namespace, ns_key) in self.args
def getKey(self, namespace, ns_key):
"""Get the key for a particular namespaced argument"""
namespace = self._fixNS(namespace)
if namespace == BARE_NS:
return ns_key
ns_alias = self.namespaces.getAlias(namespace)
# No alias is defined, so no key can exist
if ns_alias is None:
return None
if ns_alias == NULL_NAMESPACE:
tail = ns_key
else:
tail = '%s.%s' % (ns_alias, ns_key)
return 'openid.' + tail
def getArg(self, namespace, key, default=None):
"""Get a value for a namespaced key.
@param namespace: The namespace in the message for this key
@type namespace: str
@param key: The key to get within this namespace
@type key: str
@param default: The value to use if this key is absent from
this message. Using the special value
openid.message.no_default will result in this method
raising a KeyError instead of returning the default.
@rtype: str or the type of default
@raises KeyError: if default is no_default
@raises UndefinedOpenIDNamespace: if the message has not yet
had an OpenID namespace set
"""
namespace = self._fixNS(namespace)
args_key = (namespace, key)
try:
return self.args[args_key]
except KeyError:
if default is no_default:
raise KeyError((namespace, key))
else:
return default
def getArgs(self, namespace):
"""Get the arguments that are defined for this namespace URI
@returns: mapping from namespaced keys to values
@returntype: dict of {str:bytes}
"""
namespace = self._fixNS(namespace)
args = []
for ((pair_ns, ns_key), value) in self.args.items():
if pair_ns == namespace:
if isinstance(ns_key, bytes):
k = str(ns_key, encoding="utf-8")
else:
k = ns_key
if isinstance(value, bytes):
v = str(value, encoding="utf-8")
else:
v = value
args.append((k, v))
return dict(args)
def updateArgs(self, namespace, updates):
"""Set multiple key/value pairs in one call
@param updates: The values to set
@type updates: {unicode:unicode}
"""
namespace = self._fixNS(namespace)
for k, v in updates.items():
self.setArg(namespace, k, v)
def setArg(self, namespace, key, value):
"""Set a single argument in this namespace"""
assert key is not None
assert value is not None
namespace = self._fixNS(namespace)
# try to ensure that internally it's consistent, at least: str -> str
if isinstance(value, bytes):
value = str(value, encoding="utf-8")
self.args[(namespace, key)] = value
if not (namespace is BARE_NS):
self.namespaces.add(namespace)
def delArg(self, namespace, key):
namespace = self._fixNS(namespace)
del self.args[(namespace, key)]
def __repr__(self):
return "<%s.%s %r>" % (self.__class__.__module__,
self.__class__.__name__, self.args)
def __eq__(self, other):
return self.args == other.args
def __ne__(self, other):
return not (self == other)
def getAliasedArg(self, aliased_key, default=None):
if aliased_key == 'ns':
return self.getOpenIDNamespace()
if aliased_key.startswith('ns.'):
uri = self.namespaces.getNamespaceURI(aliased_key[3:])
if uri is None:
if default == no_default:
raise KeyError
else:
return default
else:
return uri
try:
alias, key = aliased_key.split('.', 1)
except ValueError:
# need more than x values to unpack
ns = None
else:
ns = self.namespaces.getNamespaceURI(alias)
if ns is None:
key = aliased_key
ns = self.getOpenIDNamespace()
return self.getArg(ns, key, default)
class NamespaceMap(object):
"""Maintains a bijective map between namespace uris and aliases.
"""
def __init__(self):
self.alias_to_namespace = {}
self.namespace_to_alias = {}
self.implicit_namespaces = []
def getAlias(self, namespace_uri):
return self.namespace_to_alias.get(namespace_uri)
def getNamespaceURI(self, alias):
return self.alias_to_namespace.get(alias)
def iterNamespaceURIs(self):
"""Return an iterator over the namespace URIs"""
return iter(self.namespace_to_alias)
def iterAliases(self):
"""Return an iterator over the aliases"""
return iter(self.alias_to_namespace)
def items(self):
"""Iterate over the mapping
@returns: iterator of (namespace_uri, alias)
"""
return self.namespace_to_alias.items()
def addAlias(self, namespace_uri, desired_alias, implicit=False):
"""Add an alias from this namespace URI to the desired alias
"""
if isinstance(namespace_uri, bytes):
namespace_uri = str(namespace_uri, encoding="utf-8")
# Check that desired_alias is not an openid protocol field as
# per the spec.
assert desired_alias not in OPENID_PROTOCOL_FIELDS, \
"%r is not an allowed namespace alias" % (desired_alias,)
# Check that desired_alias does not contain a period as per
# the spec.
if isinstance(desired_alias, str):
assert '.' not in desired_alias, \
"%r must not contain a dot" % (desired_alias,)
# Check that there is not a namespace already defined for
# the desired alias
current_namespace_uri = self.alias_to_namespace.get(desired_alias)
if (current_namespace_uri is not None and
current_namespace_uri != namespace_uri):
fmt = ('Cannot map %r to alias %r. '
'%r is already mapped to alias %r')
msg = fmt % (namespace_uri, desired_alias, current_namespace_uri,
desired_alias)
raise KeyError(msg)
# Check that there is not already a (different) alias for
# this namespace URI
alias = self.namespace_to_alias.get(namespace_uri)
if alias is not None and alias != desired_alias:
fmt = ('Cannot map %r to alias %r. '
'It is already mapped to alias %r')
raise KeyError(fmt % (namespace_uri, desired_alias, alias))
assert (desired_alias == NULL_NAMESPACE or
type(desired_alias) in [str, str]), repr(desired_alias)
assert namespace_uri not in self.implicit_namespaces
self.alias_to_namespace[desired_alias] = namespace_uri
self.namespace_to_alias[namespace_uri] = desired_alias
if implicit:
self.implicit_namespaces.append(namespace_uri)
return desired_alias
def add(self, namespace_uri):
"""Add this namespace URI to the mapping, without caring what
alias it ends up with"""
# See if this namespace is already mapped to an alias
alias = self.namespace_to_alias.get(namespace_uri)
if alias is not None:
return alias
# Fall back to generating a numerical alias
i = 0
while True:
alias = 'ext' + str(i)
try:
self.addAlias(namespace_uri, alias)
except KeyError:
i += 1
else:
return alias
assert False, "Not reached"
def isDefined(self, namespace_uri):
return namespace_uri in self.namespace_to_alias
def __contains__(self, namespace_uri):
return self.isDefined(namespace_uri)
def isImplicit(self, namespace_uri):
return namespace_uri in self.implicit_namespaces

View File

@@ -0,0 +1,235 @@
"""This module contains general utility code that is used throughout
the library.
"""
__all__ = [
'log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML',
'toUnicode'
]
import binascii
import logging
# import urllib.parse as urlparse
from urllib.parse import urlencode
logger = logging.getLogger(__name__)
xxe_safe_elementtree_modules = [
'defusedxml.cElementTree',
'defusedxml.ElementTree',
]
elementtree_modules = [
'xml.etree.cElementTree',
'xml.etree.ElementTree',
'cElementTree',
'elementtree.ElementTree',
]
def toUnicode(value):
"""Returns the given argument as a unicode object.
@param value: A UTF-8 encoded string or a unicode (coercable) object
@type message: str or unicode
@returns: Unicode object representing the input value.
"""
if isinstance(value, bytes):
return value.decode('utf-8')
return str(value)
def autoSubmitHTML(form, title='OpenID transaction in progress'):
if isinstance(form, bytes):
form = str(form, encoding="utf-8")
if isinstance(title, bytes):
title = str(title, encoding="utf-8")
html = """
<html>
<head>
<title>%s</title>
</head>
<body onload="document.forms[0].submit();">
%s
<script>
var elements = document.forms[0].elements;
for (var i = 0; i < elements.length; i++) {
elements[i].style.display = "none";
}
</script>
</body>
</html>
""" % (title, form)
return html
def importSafeElementTree(module_names=None):
"""Find a working ElementTree implementation that is not vulnerable
to XXE, using `defusedxml`.
>>> XXESafeElementTree = importSafeElementTree()
@param module_names: The names of modules to try to use as
a safe ElementTree. Defaults to C{L{xxe_safe_elementtree_modules}}
@returns: An ElementTree module that is not vulnerable to XXE.
"""
if module_names is None:
module_names = xxe_safe_elementtree_modules
try:
return importElementTree(module_names)
except ImportError:
raise ImportError('Unable to find a ElementTree module '
'that is not vulnerable to XXE. '
'Tried importing %r' % (module_names, ))
def importElementTree(module_names=None):
"""Find a working ElementTree implementation, trying the standard
places that such a thing might show up.
>>> ElementTree = importElementTree()
@param module_names: The names of modules to try to use as
ElementTree. Defaults to C{L{elementtree_modules}}
@returns: An ElementTree module
"""
if module_names is None:
module_names = elementtree_modules
for mod_name in module_names:
try:
ElementTree = __import__(mod_name, None, None, ['unused'])
except ImportError:
pass
else:
# Make sure it can actually parse XML
try:
ElementTree.XML('<unused/>')
except (SystemExit, MemoryError, AssertionError):
raise
except:
logger.exception(
'Not using ElementTree library %r because it failed to '
'parse a trivial document: %s' % mod_name)
else:
return ElementTree
else:
raise ImportError('No ElementTree library found. '
'You may need to install one. '
'Tried importing %r' % (module_names, ))
def log(message, level=0):
"""Handle a log message from the OpenID library.
This is a legacy function which redirects to logger.error.
The logging module should be used instead of this
@param message: A string containing a debugging message from the
OpenID library
@type message: str
@param level: The severity of the log message. This parameter is
currently unused, but in the future, the library may indicate
more important information with a higher level value.
@type level: int or None
@returns: Nothing.
"""
logger.error("This is a legacy log message, please use the "
"logging module. Message: %s", message)
def appendArgs(url, args):
"""Append query arguments to a HTTP(s) URL. If the URL already has
query arguemtns, these arguments will be added, and the existing
arguments will be preserved. Duplicate arguments will not be
detected or collapsed (both will appear in the output).
@param url: The url to which the arguments will be appended
@type url: str
@param args: The query arguments to add to the URL. If a
dictionary is passed, the items will be sorted before
appending them to the URL. If a sequence of pairs is passed,
the order of the sequence will be preserved.
@type args: A dictionary from string to string, or a sequence of
pairs of strings.
@returns: The URL with the parameters added
@rtype: str
"""
if hasattr(args, 'items'):
args = sorted(args.items())
else:
args = list(args)
if not isinstance(url, str):
url = str(url, encoding="utf-8")
if not args:
return url
if '?' in url:
sep = '&'
else:
sep = '?'
# Map unicode to UTF-8 if present. Do not make any assumptions
# about the encodings of plain bytes (str).
i = 0
for k, v in args:
if not isinstance(k, bytes):
k = k.encode('utf-8')
if not isinstance(v, bytes):
v = v.encode('utf-8')
args[i] = (k, v)
i += 1
return '%s%s%s' % (url, sep, urlencode(args))
def toBase64(s):
"""Represent string / bytes s as base64, omitting newlines"""
if isinstance(s, str):
s = s.encode("utf-8")
return binascii.b2a_base64(s)[:-1]
def fromBase64(s):
if isinstance(s, str):
s = s.encode("utf-8")
try:
return binascii.a2b_base64(s)
except binascii.Error as why:
# Convert to a common exception type
raise ValueError(str(why))
class Symbol(object):
"""This class implements an object that compares equal to others
of the same type that have the same name. These are distict from
str or unicode objects.
"""
def __init__(self, name):
self.name = name
def __eq__(self, other):
return type(self) is type(other) and self.name == other.name
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.__class__, self.name))
def __repr__(self):
return '<Symbol %s>' % (self.name, )

View File

@@ -0,0 +1,6 @@
"""
This package contains the portions of the library used only when
implementing an OpenID server. See L{openid.server.server}.
"""
__all__ = ['server', 'trustroot']

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,456 @@
# -*- test-case-name: openid.test.test_rpverify -*-
"""
This module contains the C{L{TrustRoot}} class, which helps handle
trust root checking. This module is used by the
C{L{openid.server.server}} module, but it is also available to server
implementers who wish to use it for additional trust root checking.
It also implements relying party return_to URL verification, based on
the realm.
"""
__all__ = [
'TrustRoot',
'RP_RETURN_TO_URL_TYPE',
'extractReturnToURLs',
'returnToMatches',
'verifyReturnTo',
]
from openid import urinorm
from openid.yadis import services
from urllib.parse import urlparse, urlunparse
import re
import logging
logger = logging.getLogger(__name__)
############################################
_protocols = ['http', 'https']
_top_level_domains = [
'ac', 'ad', 'ae', 'aero', 'af', 'ag', 'ai', 'al', 'am', 'an', 'ao', 'aq',
'ar', 'arpa', 'as', 'asia', 'at', 'au', 'aw', 'ax', 'az', 'ba', 'bb', 'bd',
'be', 'bf', 'bg', 'bh', 'bi', 'biz', 'bj', 'bm', 'bn', 'bo', 'br', 'bs',
'bt', 'bv', 'bw', 'by', 'bz', 'ca', 'cat', 'cc', 'cd', 'cf', 'cg', 'ch',
'ci', 'ck', 'cl', 'cm', 'cn', 'co', 'com', 'coop', 'cr', 'cu', 'cv', 'cx',
'cy', 'cz', 'de', 'dj', 'dk', 'dm', 'do', 'dz', 'ec', 'edu', 'ee', 'eg',
'er', 'es', 'et', 'eu', 'fi', 'fj', 'fk', 'fm', 'fo', 'fr', 'ga', 'gb',
'gd', 'ge', 'gf', 'gg', 'gh', 'gi', 'gl', 'gm', 'gn', 'gov', 'gp', 'gq',
'gr', 'gs', 'gt', 'gu', 'gw', 'gy', 'hk', 'hm', 'hn', 'hr', 'ht', 'hu',
'id', 'ie', 'il', 'im', 'in', 'info', 'int', 'io', 'iq', 'ir', 'is', 'it',
'je', 'jm', 'jo', 'jobs', 'jp', 'ke', 'kg', 'kh', 'ki', 'km', 'kn', 'kp',
'kr', 'kw', 'ky', 'kz', 'la', 'lb', 'lc', 'li', 'lk', 'lr', 'ls', 'lt',
'lu', 'lv', 'ly', 'ma', 'mc', 'md', 'me', 'mg', 'mh', 'mil', 'mk', 'ml',
'mm', 'mn', 'mo', 'mobi', 'mp', 'mq', 'mr', 'ms', 'mt', 'mu', 'museum',
'mv', 'mw', 'mx', 'my', 'mz', 'na', 'name', 'nc', 'ne', 'net', 'nf', 'ng',
'ni', 'nl', 'no', 'np', 'nr', 'nu', 'nz', 'om', 'org', 'pa', 'pe', 'pf',
'pg', 'ph', 'pk', 'pl', 'pm', 'pn', 'pr', 'pro', 'ps', 'pt', 'pw', 'py',
'qa', 're', 'ro', 'rs', 'ru', 'rw', 'sa', 'sb', 'sc', 'sd', 'se', 'sg',
'sh', 'si', 'sj', 'sk', 'sl', 'sm', 'sn', 'so', 'sr', 'st', 'su', 'sv',
'sy', 'sz', 'tc', 'td', 'tel', 'tf', 'tg', 'th', 'tj', 'tk', 'tl', 'tm',
'tn', 'to', 'tp', 'tr', 'travel', 'tt', 'tv', 'tw', 'tz', 'ua', 'ug', 'uk',
'us', 'uy', 'uz', 'va', 'vc', 've', 'vg', 'vi', 'vn', 'vu', 'wf', 'ws',
'xn--0zwm56d', 'xn--11b5bs3a9aj6g', 'xn--80akhbyknj4f', 'xn--9t4b11yi5a',
'xn--deba0ad', 'xn--g6w251d', 'xn--hgbk6aj7f53bba', 'xn--hlcj6aya9esc7a',
'xn--jxalpdlp', 'xn--kgbechtv', 'xn--zckzah', 'ye', 'yt', 'yu', 'za', 'zm',
'zw'
]
# Build from RFC3986, section 3.2.2. Used to reject hosts with invalid
# characters.
host_segment_re = re.compile(
r"(?:[-a-zA-Z0-9!$&'\(\)\*+,;=._~]|%[a-zA-Z0-9]{2})+$")
class RealmVerificationRedirected(Exception):
"""Attempting to verify this realm resulted in a redirect.
@since: 2.1.0
"""
def __init__(self, relying_party_url, rp_url_after_redirects):
self.relying_party_url = relying_party_url
self.rp_url_after_redirects = rp_url_after_redirects
def __str__(self):
return ("Attempting to verify %r resulted in "
"redirect to %r" % (self.relying_party_url,
self.rp_url_after_redirects))
def _parseURL(url):
try:
url = urinorm.urinorm(url)
except ValueError:
return None
proto, netloc, path, params, query, frag = urlparse(url)
if not path:
# Python <2.4 does not parse URLs with no path properly
if not query and '?' in netloc:
netloc, query = netloc.split('?', 1)
path = '/'
path = urlunparse(('', '', path, params, query, frag))
if ':' in netloc:
try:
host, port = netloc.split(':')
except ValueError:
return None
if not re.match(r'\d+$', port):
return None
else:
host = netloc
port = ''
host = host.lower()
if not host_segment_re.match(host):
return None
return proto, host, port, path
class TrustRoot(object):
"""
This class represents an OpenID trust root. The C{L{parse}}
classmethod accepts a trust root string, producing a
C{L{TrustRoot}} object. The method OpenID server implementers
would be most likely to use is the C{L{isSane}} method, which
checks the trust root for given patterns that indicate that the
trust root is too broad or points to a local network resource.
@sort: parse, isSane
"""
def __init__(self, unparsed, proto, wildcard, host, port, path):
self.unparsed = unparsed
self.proto = proto
self.wildcard = wildcard
self.host = host
self.port = port
self.path = path
def isSane(self):
"""
This method checks the to see if a trust root represents a
reasonable (sane) set of URLs. 'http://*.com/', for example
is not a reasonable pattern, as it cannot meaningfully specify
the site claiming it. This function attempts to find many
related examples, but it can only work via heuristics.
Negative responses from this method should be treated as
advisory, used only to alert the user to examine the trust
root carefully.
@return: Whether the trust root is sane
@rtype: C{bool}
"""
if self.host == 'localhost':
return True
host_parts = self.host.split('.')
if self.wildcard:
assert host_parts[0] == '', host_parts
del host_parts[0]
# If it's an absolute domain name, remove the empty string
# from the end.
if host_parts and not host_parts[-1]:
del host_parts[-1]
if not host_parts:
return False
# Do not allow adjacent dots
if '' in host_parts:
return False
tld = host_parts[-1]
if tld not in _top_level_domains:
return False
if len(host_parts) == 1:
return False
if self.wildcard:
if len(tld) == 2 and len(host_parts[-2]) <= 3:
# It's a 2-letter tld with a short second to last segment
# so there needs to be more than two segments specified
# (e.g. *.co.uk is insane)
return len(host_parts) > 2
# Passed all tests for insanity.
return True
def validateURL(self, url):
"""
Validates a URL against this trust root.
@param url: The URL to check
@type url: C{str}
@return: Whether the given URL is within this trust root.
@rtype: C{bool}
"""
url_parts = _parseURL(url)
if url_parts is None:
return False
proto, host, port, path = url_parts
if proto != self.proto:
return False
if port != self.port:
return False
if '*' in host:
return False
if not self.wildcard:
if host != self.host:
return False
elif ((not host.endswith(self.host)) and ('.' + host) != self.host):
return False
if path != self.path:
path_len = len(self.path)
trust_prefix = self.path[:path_len]
url_prefix = path[:path_len]
# must be equal up to the length of the path, at least
if trust_prefix != url_prefix:
return False
# These characters must be on the boundary between the end
# of the trust root's path and the start of the URL's
# path.
if '?' in self.path:
allowed = '&'
else:
allowed = '?/'
return (self.path[-1] in allowed or path[path_len] in allowed)
return True
def parse(cls, trust_root):
"""
This method creates a C{L{TrustRoot}} instance from the given
input, if possible.
@param trust_root: This is the trust root to parse into a
C{L{TrustRoot}} object.
@type trust_root: C{str}
@return: A C{L{TrustRoot}} instance if trust_root parses as a
trust root, C{None} otherwise.
@rtype: C{NoneType} or C{L{TrustRoot}}
"""
url_parts = _parseURL(trust_root)
if url_parts is None:
return None
proto, host, port, path = url_parts
# check for valid prototype
if proto not in _protocols:
return None
# check for URI fragment
if path.find('#') != -1:
return None
# extract wildcard if it is there
if host.find('*', 1) != -1:
# wildcard must be at start of domain: *.foo.com, not foo.*.com
return None
if host.startswith('*'):
# Starts with star, so must have a dot after it (if a
# domain is specified)
if len(host) > 1 and host[1] != '.':
return None
host = host[1:]
wilcard = True
else:
wilcard = False
# we have a valid trust root
tr = cls(trust_root, proto, wilcard, host, port, path)
return tr
parse = classmethod(parse)
def checkSanity(cls, trust_root_string):
"""str -> bool
is this a sane trust root?
"""
trust_root = cls.parse(trust_root_string)
if trust_root is None:
return False
else:
return trust_root.isSane()
checkSanity = classmethod(checkSanity)
def checkURL(cls, trust_root, url):
"""quick func for validating a url against a trust root. See the
TrustRoot class if you need more control."""
tr = cls.parse(trust_root)
return tr is not None and tr.validateURL(url)
checkURL = classmethod(checkURL)
def buildDiscoveryURL(self):
"""Return a discovery URL for this realm.
This function does not check to make sure that the realm is
valid. Its behaviour on invalid inputs is undefined.
@rtype: str
@returns: The URL upon which relying party discovery should be run
in order to verify the return_to URL
@since: 2.1.0
"""
if self.wildcard:
# Use "www." in place of the star
assert self.host.startswith('.'), self.host
www_domain = 'www' + self.host
return '%s://%s%s' % (self.proto, www_domain, self.path)
else:
return self.unparsed
def __repr__(self):
return "TrustRoot(%r, %r, %r, %r, %r, %r)" % (
self.unparsed, self.proto, self.wildcard, self.host, self.port,
self.path)
def __str__(self):
return repr(self)
# The URI for relying party discovery, used in realm verification.
#
# XXX: This should probably live somewhere else (like in
# openid.consumer or openid.yadis somewhere)
RP_RETURN_TO_URL_TYPE = 'http://specs.openid.net/auth/2.0/return_to'
def _extractReturnURL(endpoint):
"""If the endpoint is a relying party OpenID return_to endpoint,
return the endpoint URL. Otherwise, return None.
This function is intended to be used as a filter for the Yadis
filtering interface.
@see: C{L{openid.yadis.services}}
@see: C{L{openid.yadis.filters}}
@param endpoint: An XRDS BasicServiceEndpoint, as returned by
performing Yadis dicovery.
@returns: The endpoint URL or None if the endpoint is not a
relying party endpoint.
@rtype: str or NoneType
"""
if endpoint.matchTypes([RP_RETURN_TO_URL_TYPE]):
return endpoint.uri
else:
return None
def returnToMatches(allowed_return_to_urls, return_to):
"""Is the return_to URL under one of the supplied allowed
return_to URLs?
@since: 2.1.0
"""
for allowed_return_to in allowed_return_to_urls:
# A return_to pattern works the same as a realm, except that
# it's not allowed to use a wildcard. We'll model this by
# parsing it as a realm, and not trying to match it if it has
# a wildcard.
return_realm = TrustRoot.parse(allowed_return_to)
if ( # Parses as a trust root
return_realm is not None and
# Does not have a wildcard
not return_realm.wildcard and
# Matches the return_to that we passed in with it
return_realm.validateURL(return_to)):
return True
# No URL in the list matched
return False
def getAllowedReturnURLs(relying_party_url):
"""Given a relying party discovery URL return a list of return_to URLs.
@since: 2.1.0
"""
(rp_url_after_redirects, return_to_urls) = services.getServiceEndpoints(
relying_party_url, _extractReturnURL)
if rp_url_after_redirects != relying_party_url:
# Verification caused a redirect
raise RealmVerificationRedirected(relying_party_url,
rp_url_after_redirects)
return return_to_urls
# _vrfy parameter is there to make testing easier
def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs):
"""Verify that a return_to URL is valid for the given realm.
This function builds a discovery URL, performs Yadis discovery on
it, makes sure that the URL does not redirect, parses out the
return_to URLs, and finally checks to see if the current return_to
URL matches the return_to.
@raises DiscoveryFailure: When Yadis discovery fails
@returns: True if the return_to URL is valid for the realm
@since: 2.1.0
"""
realm = TrustRoot.parse(realm_str)
if realm is None:
# The realm does not parse as a URL pattern
return False
try:
allowable_urls = _vrfy(realm.buildDiscoveryURL())
except RealmVerificationRedirected as err:
logger.exception(str(err))
return False
if returnToMatches(allowable_urls, return_to):
return True
else:
logger.error("Failed to validate return_to %r for realm %r, was not "
"in %s" % (return_to, realm_str, allowable_urls))
return False

View File

@@ -0,0 +1,7 @@
"""moved to L{openid.extensions.sreg}"""
import warnings
warnings.warn("openid.sreg has moved to openid.extensions.sreg",
DeprecationWarning)
from openid.extensions.sreg import *

View File

@@ -0,0 +1,8 @@
"""
This package contains the modules related to this library's use of
persistent storage.
@sort: interface, filestore, sqlstore, memstore
"""
__all__ = ['interface', 'filestore', 'sqlstore', 'memstore', 'nonce']

View File

@@ -0,0 +1,399 @@
"""
This module contains an C{L{OpenIDStore}} implementation backed by
flat files.
"""
import string
import os
import os.path
import time
import logging
from errno import EEXIST, ENOENT
from tempfile import mkstemp
from openid.association import Association
from openid.store.interface import OpenIDStore
from openid.store import nonce
from openid import cryptutil, oidutil
logger = logging.getLogger(__name__)
_filename_allowed = string.ascii_letters + string.digits + '.'
_isFilenameSafe = set(_filename_allowed).__contains__
def _safe64(s):
h64 = oidutil.toBase64(cryptutil.sha1(s))
# to be able to manipulate it, make it a bytearray
h64 = bytearray(h64)
h64 = h64.replace(b'+', b'_')
h64 = h64.replace(b'/', b'.')
h64 = h64.replace(b'=', b'')
return bytes(h64)
def _filenameEscape(s):
filename_chunks = []
for c in s:
if _isFilenameSafe(c):
filename_chunks.append(c)
else:
filename_chunks.append('_%02X' % ord(c))
return ''.join(filename_chunks)
def _removeIfPresent(filename):
"""Attempt to remove a file, returning whether the file existed at
the time of the call.
str -> bool
"""
try:
os.unlink(filename)
except OSError as why:
if why.errno == ENOENT:
# Someone beat us to it, but it's gone, so that's OK
return 0
else:
raise
else:
# File was present
return 1
def _ensureDir(dir_name):
"""Create dir_name as a directory if it does not exist. If it
exists, make sure that it is, in fact, a directory.
Can raise OSError
str -> NoneType
"""
try:
os.makedirs(dir_name)
except OSError as why:
if why.errno != EEXIST or not os.path.isdir(dir_name):
raise
class FileOpenIDStore(OpenIDStore):
"""
This is a filesystem-based store for OpenID associations and
nonces. This store should be safe for use in concurrent systems
on both windows and unix (excluding NFS filesystems). There are a
couple race conditions in the system, but those failure cases have
been set up in such a way that the worst-case behavior is someone
having to try to log in a second time.
Most of the methods of this class are implementation details.
People wishing to just use this store need only pay attention to
the C{L{__init__}} method.
Methods of this object can raise OSError if unexpected filesystem
conditions, such as bad permissions or missing directories, occur.
"""
def __init__(self, directory):
"""
Initializes a new FileOpenIDStore. This initializes the
nonce and association directories, which are subdirectories of
the directory passed in.
@param directory: This is the directory to put the store
directories in.
@type directory: C{str}
"""
# Make absolute
directory = os.path.normpath(os.path.abspath(directory))
self.nonce_dir = os.path.join(directory, 'nonces')
self.association_dir = os.path.join(directory, 'associations')
# Temp dir must be on the same filesystem as the assciations
# directory
self.temp_dir = os.path.join(directory, 'temp')
self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds
self._setup()
def _setup(self):
"""Make sure that the directories in which we store our data
exist.
() -> NoneType
"""
_ensureDir(self.nonce_dir)
_ensureDir(self.association_dir)
_ensureDir(self.temp_dir)
def _mktemp(self):
"""Create a temporary file on the same filesystem as
self.association_dir.
The temporary directory should not be cleaned if there are any
processes using the store. If there is no active process using
the store, it is safe to remove all of the files in the
temporary directory.
() -> (file, str)
"""
fd, name = mkstemp(dir=self.temp_dir)
try:
file_obj = os.fdopen(fd, 'wb')
return file_obj, name
except:
_removeIfPresent(name)
raise
def getAssociationFilename(self, server_url, handle):
"""Create a unique filename for a given server url and
handle. This implementation does not assume anything about the
format of the handle. The filename that is returned will
contain the domain name from the server URL for ease of human
inspection of the data directory.
(str, str) -> str
"""
if server_url.find('://') == -1:
raise ValueError('Bad server URL: %r' % server_url)
proto, rest = server_url.split('://', 1)
domain = _filenameEscape(rest.split('/', 1)[0])
url_hash = _safe64(server_url)
if handle:
handle_hash = _safe64(handle)
else:
handle_hash = ''
filename = '%s-%s-%s-%s' % (proto, domain, url_hash, handle_hash)
return os.path.join(self.association_dir, filename)
def storeAssociation(self, server_url, association):
"""Store an association in the association directory.
(str, Association) -> NoneType
"""
association_s = association.serialize() # NOTE: UTF-8 encoded bytes
filename = self.getAssociationFilename(server_url, association.handle)
tmp_file, tmp = self._mktemp()
try:
try:
tmp_file.write(association_s)
os.fsync(tmp_file.fileno())
finally:
tmp_file.close()
try:
os.rename(tmp, filename)
except OSError as why:
if why.errno != EEXIST:
raise
# We only expect EEXIST to happen only on Windows. It's
# possible that we will succeed in unlinking the existing
# file, but not in putting the temporary file in place.
try:
os.unlink(filename)
except OSError as why:
if why.errno == ENOENT:
pass
else:
raise
# Now the target should not exist. Try renaming again,
# giving up if it fails.
os.rename(tmp, filename)
except:
# If there was an error, don't leave the temporary file
# around.
_removeIfPresent(tmp)
raise
def getAssociation(self, server_url, handle=None):
"""Retrieve an association. If no handle is specified, return
the association with the latest expiration.
(str, str or NoneType) -> Association or NoneType
"""
if handle is None:
handle = ''
# The filename with the empty handle is a prefix of all other
# associations for the given server URL.
filename = self.getAssociationFilename(server_url, handle)
if handle:
return self._getAssociation(filename)
else:
association_files = os.listdir(self.association_dir)
matching_files = []
# strip off the path to do the comparison
name = os.path.basename(filename)
for association_file in association_files:
if association_file.startswith(name):
matching_files.append(association_file)
matching_associations = []
# read the matching files and sort by time issued
for name in matching_files:
full_name = os.path.join(self.association_dir, name)
association = self._getAssociation(full_name)
if association is not None:
matching_associations.append(
(association.issued, association))
matching_associations.sort()
# return the most recently issued one.
if matching_associations:
(_, assoc) = matching_associations[-1]
return assoc
else:
return None
def _getAssociation(self, filename):
try:
assoc_file = open(filename, 'rb')
except IOError as why:
if why.errno == ENOENT:
# No association exists for that URL and handle
return None
else:
raise
try:
assoc_s = assoc_file.read()
finally:
assoc_file.close()
try:
association = Association.deserialize(assoc_s)
except ValueError:
_removeIfPresent(filename)
return None
# Clean up expired associations
if association.expiresIn == 0:
_removeIfPresent(filename)
return None
else:
return association
def removeAssociation(self, server_url, handle):
"""Remove an association if it exists. Do nothing if it does not.
(str, str) -> bool
"""
assoc = self.getAssociation(server_url, handle)
if assoc is None:
return 0
else:
filename = self.getAssociationFilename(server_url, handle)
return _removeIfPresent(filename)
def useNonce(self, server_url, timestamp, salt):
"""Return whether this nonce is valid.
str -> bool
"""
if abs(timestamp - time.time()) > nonce.SKEW:
return False
if server_url:
proto, rest = server_url.split('://', 1)
else:
# Create empty proto / rest values for empty server_url,
# which is part of a consumer-generated nonce.
proto, rest = '', ''
domain = _filenameEscape(rest.split('/', 1)[0])
url_hash = _safe64(server_url)
salt_hash = _safe64(salt)
filename = '%08x-%s-%s-%s-%s' % (timestamp, proto, domain, url_hash,
salt_hash)
filename = os.path.join(self.nonce_dir, filename)
try:
fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o200)
except OSError as why:
if why.errno == EEXIST:
return False
else:
raise
else:
os.close(fd)
return True
def _allAssocs(self):
all_associations = []
association_filenames = [
os.path.join(self.association_dir, filename)
for filename in os.listdir(self.association_dir)
]
for association_filename in association_filenames:
try:
association_file = open(association_filename, 'rb')
except IOError as why:
if why.errno == ENOENT:
logger.exception("%s disappeared during %s._allAssocs" % (
association_filename, self.__class__.__name__))
else:
raise
else:
try:
assoc_s = association_file.read()
finally:
association_file.close()
# Remove expired or corrupted associations
try:
association = Association.deserialize(assoc_s)
except ValueError:
_removeIfPresent(association_filename)
else:
all_associations.append(
(association_filename, association))
return all_associations
def cleanup(self):
"""Remove expired entries from the database. This is
potentially expensive, so only run when it is acceptable to
take time.
() -> NoneType
"""
self.cleanupAssociations()
self.cleanupNonces()
def cleanupAssociations(self):
removed = 0
for assoc_filename, assoc in self._allAssocs():
if assoc.expiresIn == 0:
_removeIfPresent(assoc_filename)
removed += 1
return removed
def cleanupNonces(self):
nonces = os.listdir(self.nonce_dir)
now = time.time()
removed = 0
# Check all nonces for expiry
for nonce_fname in nonces:
timestamp = nonce_fname.split('-', 1)[0]
timestamp = int(timestamp, 16)
if abs(timestamp - now) > nonce.SKEW:
filename = os.path.join(self.nonce_dir, nonce_fname)
_removeIfPresent(filename)
removed += 1
return removed

View File

@@ -0,0 +1,198 @@
"""
This module contains the definition of the C{L{OpenIDStore}}
interface.
"""
class OpenIDStore(object):
"""
This is the interface for the store objects the OpenID library
uses. It is a single class that provides all of the persistence
mechanisms that the OpenID library needs, for both servers and
consumers.
@change: Version 2.0 removed the C{storeNonce}, C{getAuthKey}, and C{isDumb}
methods, and changed the behavior of the C{L{useNonce}} method
to support one-way nonces. It added C{L{cleanupNonces}},
C{L{cleanupAssociations}}, and C{L{cleanup}}.
@sort: storeAssociation, getAssociation, removeAssociation,
useNonce
"""
def storeAssociation(self, server_url, association):
"""
This method puts a C{L{Association
<openid.association.Association>}} object into storage,
retrievable by server URL and handle.
@param server_url: The URL of the identity server that this
association is with. Because of the way the server
portion of the library uses this interface, don't assume
there are any limitations on the character set of the
input string. In particular, expect to see unescaped
non-url-safe characters in the server_url field.
@type server_url: C{str}
@param association: The C{L{Association
<openid.association.Association>}} to store.
@type association: C{L{Association
<openid.association.Association>}}
@return: C{None}
@rtype: C{NoneType}
"""
raise NotImplementedError
def getAssociation(self, server_url, handle=None):
"""
This method returns an C{L{Association
<openid.association.Association>}} object from storage that
matches the server URL and, if specified, handle. It returns
C{None} if no such association is found or if the matching
association is expired.
If no handle is specified, the store may return any
association which matches the server URL. If multiple
associations are valid, the recommended return value for this
method is the one most recently issued.
This method is allowed (and encouraged) to garbage collect
expired associations when found. This method must not return
expired associations.
@param server_url: The URL of the identity server to get the
association for. Because of the way the server portion of
the library uses this interface, don't assume there are
any limitations on the character set of the input string.
In particular, expect to see unescaped non-url-safe
characters in the server_url field.
@type server_url: C{str}
@param handle: This optional parameter is the handle of the
specific association to get. If no specific handle is
provided, any valid association matching the server URL is
returned.
@type handle: C{str} or C{NoneType}
@return: The C{L{Association
<openid.association.Association>}} for the given identity
server.
@rtype: C{L{Association <openid.association.Association>}} or
C{NoneType}
"""
raise NotImplementedError
def removeAssociation(self, server_url, handle):
"""
This method removes the matching association if it's found,
and returns whether the association was removed or not.
@param server_url: The URL of the identity server the
association to remove belongs to. Because of the way the
server portion of the library uses this interface, don't
assume there are any limitations on the character set of
the input string. In particular, expect to see unescaped
non-url-safe characters in the server_url field.
@type server_url: C{str}
@param handle: This is the handle of the association to
remove. If there isn't an association found that matches
both the given URL and handle, then there was no matching
handle found.
@type handle: C{str}
@return: Returns whether or not the given association existed.
@rtype: C{bool} or C{int}
"""
raise NotImplementedError
def useNonce(self, server_url, timestamp, salt):
"""Called when using a nonce.
This method should return C{True} if the nonce has not been
used before, and store it for a while to make sure nobody
tries to use the same value again. If the nonce has already
been used or the timestamp is not current, return C{False}.
You may use L{openid.store.nonce.SKEW} for your timestamp window.
@change: In earlier versions, round-trip nonces were used and
a nonce was only valid if it had been previously stored
with C{storeNonce}. Version 2.0 uses one-way nonces,
requiring a different implementation here that does not
depend on a C{storeNonce} call. (C{storeNonce} is no
longer part of the interface.)
@param server_url: The URL of the server from which the nonce
originated.
@type server_url: C{str}
@param timestamp: The time that the nonce was created (to the
nearest second), in seconds since January 1 1970 UTC.
@type timestamp: C{int}
@param salt: A random string that makes two nonces from the
same server issued during the same second unique.
@type salt: str
@return: Whether or not the nonce was valid.
@rtype: C{bool}
"""
raise NotImplementedError
def cleanupNonces(self):
"""Remove expired nonces from the store.
Discards any nonce from storage that is old enough that its
timestamp would not pass L{useNonce}.
This method is not called in the normal operation of the
library. It provides a way for store admins to keep
their storage from filling up with expired data.
@return: the number of nonces expired.
@returntype: int
"""
raise NotImplementedError
def cleanupAssociations(self):
"""Remove expired associations from the store.
This method is not called in the normal operation of the
library. It provides a way for store admins to keep
their storage from filling up with expired data.
@return: the number of associations expired.
@returntype: int
"""
raise NotImplementedError
def cleanup(self):
"""Shortcut for C{L{cleanupNonces}()}, C{L{cleanupAssociations}()}.
This method is not called in the normal operation of the
library. It provides a way for store admins to keep
their storage from filling up with expired data.
"""
return self.cleanupNonces(), self.cleanupAssociations()

View File

@@ -0,0 +1,126 @@
"""A simple store using only in-process memory."""
from openid.store import nonce
import copy
import time
class ServerAssocs(object):
def __init__(self):
self.assocs = {}
def set(self, assoc):
self.assocs[assoc.handle] = assoc
def get(self, handle):
return self.assocs.get(handle)
def remove(self, handle):
try:
del self.assocs[handle]
except KeyError:
return False
else:
return True
def best(self):
"""Returns association with the oldest issued date.
or None if there are no associations.
"""
best = None
for assoc in list(self.assocs.values()):
if best is None or best.issued < assoc.issued:
best = assoc
return best
def cleanup(self):
"""Remove expired associations.
@return: tuple of (removed associations, remaining associations)
"""
remove = []
for handle, assoc in self.assocs.items():
if assoc.expiresIn == 0:
remove.append(handle)
for handle in remove:
del self.assocs[handle]
return len(remove), len(self.assocs)
class MemoryStore(object):
"""In-process memory store.
Use for single long-running processes. No persistence supplied.
"""
def __init__(self):
self.server_assocs = {}
self.nonces = {}
def _getServerAssocs(self, server_url):
try:
return self.server_assocs[server_url]
except KeyError:
assocs = self.server_assocs[server_url] = ServerAssocs()
return assocs
def storeAssociation(self, server_url, assoc):
assocs = self._getServerAssocs(server_url)
assocs.set(copy.deepcopy(assoc))
def getAssociation(self, server_url, handle=None):
assocs = self._getServerAssocs(server_url)
if handle is None:
return assocs.best()
else:
return assocs.get(handle)
def removeAssociation(self, server_url, handle):
assocs = self._getServerAssocs(server_url)
return assocs.remove(handle)
def useNonce(self, server_url, timestamp, salt):
if abs(timestamp - time.time()) > nonce.SKEW:
return False
anonce = (str(server_url), int(timestamp), str(salt))
if anonce in self.nonces:
return False
else:
self.nonces[anonce] = None
return True
def cleanupNonces(self):
now = time.time()
expired = []
for anonce in self.nonces.keys():
if abs(anonce[1] - now) > nonce.SKEW:
# removing items while iterating over the set could be bad.
expired.append(anonce)
for anonce in expired:
del self.nonces[anonce]
return len(expired)
def cleanupAssociations(self):
remove_urls = []
removed_assocs = 0
for server_url, assocs in self.server_assocs.items():
removed, remaining = assocs.cleanup()
removed_assocs += removed
if not remaining:
remove_urls.append(server_url)
# Remove entries from server_assocs that had none remaining.
for server_url in remove_urls:
del self.server_assocs[server_url]
return removed_assocs
def __eq__(self, other):
return ((self.server_assocs == other.server_assocs) and
(self.nonces == other.nonces))
def __ne__(self, other):
return not (self == other)

View File

@@ -0,0 +1,101 @@
__all__ = [
'split',
'mkNonce',
'checkTimestamp',
]
from openid import cryptutil
from time import strptime, strftime, gmtime, time
from calendar import timegm
import string
NONCE_CHARS = string.ascii_letters + string.digits
# Keep nonces for five hours (allow five hours for the combination of
# request time and clock skew). This is probably way more than is
# necessary, but there is not much overhead in storing nonces.
SKEW = 60 * 60 * 5
time_fmt = '%Y-%m-%dT%H:%M:%SZ'
time_str_len = len('0000-00-00T00:00:00Z')
def split(nonce_string):
"""Extract a timestamp from the given nonce string
@param nonce_string: the nonce from which to extract the timestamp
@type nonce_string: str
@returns: A pair of a Unix timestamp and the salt characters
@returntype: (int, str)
@raises ValueError: if the nonce does not start with a correctly
formatted time string
"""
timestamp_str = nonce_string[:time_str_len]
try:
timestamp = timegm(strptime(timestamp_str, time_fmt))
except AssertionError: # Python 2.2
timestamp = -1
if timestamp < 0:
raise ValueError('time out of range')
return timestamp, nonce_string[time_str_len:]
def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None):
"""Is the timestamp that is part of the specified nonce string
within the allowed clock-skew of the current time?
@param nonce_string: The nonce that is being checked
@type nonce_string: str
@param allowed_skew: How many seconds should be allowed for
completing the request, allowing for clock skew.
@type allowed_skew: int
@param now: The current time, as a Unix timestamp
@type now: int
@returntype: bool
@returns: Whether the timestamp is correctly formatted and within
the allowed skew of the current time.
"""
try:
stamp, _ = split(nonce_string)
except ValueError:
return False
else:
if now is None:
now = time()
# Time after which we should not use the nonce
past = now - allowed_skew
# Time that is too far in the future for us to allow
future = now + allowed_skew
# the stamp is not too far in the future and is not too far in
# the past
return past <= stamp <= future
def mkNonce(when=None):
"""Generate a nonce with the current timestamp
@param when: Unix timestamp representing the issue time of the
nonce. Defaults to the current time.
@type when: int
@returntype: str
@returns: A string that should be usable as a one-way nonce
@see: time
"""
salt = cryptutil.randomString(6, NONCE_CHARS)
if when is None:
t = gmtime()
else:
t = gmtime(when)
time_str = strftime(time_fmt, t)
return time_str + salt

View File

@@ -0,0 +1,510 @@
"""
This module contains C{L{OpenIDStore}} implementations that use
various SQL databases to back them.
Example of how to initialize a store database::
python -c 'from openid.store import sqlstore; import pysqlite2.dbapi2;'
'sqlstore.SQLiteStore(pysqlite2.dbapi2.connect("cstore.db")).createTables()'
"""
import re
import time
from openid.association import Association
from openid.store.interface import OpenIDStore
from openid.store import nonce
def _inTxn(func):
def wrapped(self, *args, **kwargs):
return self._callInTransaction(func, self, *args, **kwargs)
if hasattr(func, '__name__'):
try:
wrapped.__name__ = func.__name__[4:]
except TypeError:
pass
if hasattr(func, '__doc__'):
wrapped.__doc__ = func.__doc__
return wrapped
class SQLStore(OpenIDStore):
"""
This is the parent class for the SQL stores, which contains the
logic common to all of the SQL stores.
The table names used are determined by the class variables
C{L{associations_table}} and
C{L{nonces_table}}. To change the name of the tables used, pass
new table names into the constructor.
To create the tables with the proper schema, see the
C{L{createTables}} method.
This class shouldn't be used directly. Use one of its subclasses
instead, as those contain the code necessary to use a specific
database.
All methods other than C{L{__init__}} and C{L{createTables}}
should be considered implementation details.
@cvar associations_table: This is the default name of the table to
keep associations in
@cvar nonces_table: This is the default name of the table to keep
nonces in.
@sort: __init__, createTables
"""
associations_table = 'oid_associations'
nonces_table = 'oid_nonces'
def __init__(self, conn, associations_table=None, nonces_table=None):
"""
This creates a new SQLStore instance. It requires an
established database connection be given to it, and it allows
overriding the default table names.
@param conn: This must be an established connection to a
database of the correct type for the SQLStore subclass
you're using.
@type conn: A python database API compatible connection
object.
@param associations_table: This is an optional parameter to
specify the name of the table used for storing
associations. The default value is specified in
C{L{SQLStore.associations_table}}.
@type associations_table: C{str}
@param nonces_table: This is an optional parameter to specify
the name of the table used for storing nonces. The
default value is specified in C{L{SQLStore.nonces_table}}.
@type nonces_table: C{str}
"""
self.conn = conn
self.cur = None
self._statement_cache = {}
self._table_names = {
'associations': associations_table or self.associations_table,
'nonces': nonces_table or self.nonces_table,
}
self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds
# DB API extension: search for "Connection Attributes .Error,
# .ProgrammingError, etc." in
# http://www.python.org/dev/peps/pep-0249/
if (hasattr(self.conn, 'IntegrityError') and
hasattr(self.conn, 'OperationalError')):
self.exceptions = self.conn
if not (hasattr(self.exceptions, 'IntegrityError') and
hasattr(self.exceptions, 'OperationalError')):
raise RuntimeError("Error using database connection module "
"(Maybe it can't be imported?)")
def blobDecode(self, blob):
"""Convert a blob as returned by the SQL engine into a str object.
str -> str"""
return blob
def blobEncode(self, s):
"""Convert a str object into the necessary object for storing
in the database as a blob."""
return s
def _getSQL(self, sql_name):
try:
return self._statement_cache[sql_name]
except KeyError:
sql = getattr(self, sql_name)
sql %= self._table_names
self._statement_cache[sql_name] = sql
return sql
def _execSQL(self, sql_name, *args):
sql = self._getSQL(sql_name)
# Kludge because we have reports of postgresql not quoting
# arguments if they are passed in as unicode instead of str.
# Currently the strings in our tables just have ascii in them,
# so this ought to be safe.
def unicode_to_str(arg):
if isinstance(arg, str):
return str(arg)
else:
return arg
str_args = list(map(unicode_to_str, args))
self.cur.execute(sql, str_args)
def __getattr__(self, attr):
# if the attribute starts with db_, use a default
# implementation that looks up the appropriate SQL statement
# as an attribute of this object and executes it.
if attr[:3] == 'db_':
sql_name = attr[3:] + '_sql'
def func(*args):
return self._execSQL(sql_name, *args)
setattr(self, attr, func)
return func
else:
raise AttributeError('Attribute %r not found' % (attr, ))
def _callInTransaction(self, func, *args, **kwargs):
"""Execute the given function inside of a transaction, with an
open cursor. If no exception is raised, the transaction is
comitted, otherwise it is rolled back."""
# No nesting of transactions
self.conn.rollback()
try:
self.cur = self.conn.cursor()
try:
ret = func(*args, **kwargs)
finally:
self.cur.close()
self.cur = None
except:
self.conn.rollback()
raise
else:
self.conn.commit()
return ret
def txn_createTables(self):
"""
This method creates the database tables necessary for this
store to work. It should not be called if the tables already
exist.
"""
self.db_create_nonce()
self.db_create_assoc()
createTables = _inTxn(txn_createTables)
def txn_storeAssociation(self, server_url, association):
"""Set the association for the server URL.
Association -> NoneType
"""
a = association
self.db_set_assoc(server_url, a.handle,
self.blobEncode(a.secret), a.issued, a.lifetime,
a.assoc_type)
storeAssociation = _inTxn(txn_storeAssociation)
def txn_getAssociation(self, server_url, handle=None):
"""Get the most recent association that has been set for this
server URL and handle.
str -> NoneType or Association
"""
if handle is not None:
self.db_get_assoc(server_url, handle)
else:
self.db_get_assocs(server_url)
rows = self.cur.fetchall()
if len(rows) == 0:
return None
else:
associations = []
for values in rows:
values = list(values)
values[1] = self.blobDecode(values[1])
assoc = Association(*values)
if assoc.expiresIn == 0:
self.txn_removeAssociation(server_url, assoc.handle)
else:
associations.append((assoc.issued, assoc))
if associations:
associations.sort()
return associations[-1][1]
else:
return None
getAssociation = _inTxn(txn_getAssociation)
def txn_removeAssociation(self, server_url, handle):
"""Remove the association for the given server URL and handle,
returning whether the association existed at all.
(str, str) -> bool
"""
self.db_remove_assoc(server_url, handle)
return self.cur.rowcount > 0 # -1 is undefined
removeAssociation = _inTxn(txn_removeAssociation)
def txn_useNonce(self, server_url, timestamp, salt):
"""Return whether this nonce is present, and if it is, then
remove it from the set.
str -> bool"""
if abs(timestamp - time.time()) > nonce.SKEW:
return False
try:
self.db_add_nonce(server_url, timestamp, salt)
except self.exceptions.IntegrityError:
# The key uniqueness check failed
return False
else:
# The nonce was successfully added
return True
useNonce = _inTxn(txn_useNonce)
def txn_cleanupNonces(self):
self.db_clean_nonce(int(time.time()) - nonce.SKEW)
return self.cur.rowcount
cleanupNonces = _inTxn(txn_cleanupNonces)
def txn_cleanupAssociations(self):
self.db_clean_assoc(int(time.time()))
return self.cur.rowcount
cleanupAssociations = _inTxn(txn_cleanupAssociations)
class SQLiteStore(SQLStore):
"""
This is an SQLite-based specialization of C{L{SQLStore}}.
To create an instance, see C{L{SQLStore.__init__}}. To create the
tables it will use, see C{L{SQLStore.createTables}}.
All other methods are implementation details.
"""
create_nonce_sql = """
CREATE TABLE %(nonces)s (
server_url VARCHAR,
timestamp INTEGER,
salt CHAR(40),
UNIQUE(server_url, timestamp, salt)
);
"""
create_assoc_sql = """
CREATE TABLE %(associations)s
(
server_url VARCHAR(2047),
handle VARCHAR(255),
secret BLOB(128),
issued INTEGER,
lifetime INTEGER,
assoc_type VARCHAR(64),
PRIMARY KEY (server_url, handle)
);
"""
set_assoc_sql = ('INSERT OR REPLACE INTO %(associations)s '
'(server_url, handle, secret, issued, '
'lifetime, assoc_type) '
'VALUES (?, ?, ?, ?, ?, ?);')
get_assocs_sql = ('SELECT handle, secret, issued, lifetime, assoc_type '
'FROM %(associations)s WHERE server_url = ?;')
get_assoc_sql = (
'SELECT handle, secret, issued, lifetime, assoc_type '
'FROM %(associations)s WHERE server_url = ? AND handle = ?;')
get_expired_sql = ('SELECT server_url '
'FROM %(associations)s WHERE issued + lifetime < ?;')
remove_assoc_sql = ('DELETE FROM %(associations)s '
'WHERE server_url = ? AND handle = ?;')
clean_assoc_sql = 'DELETE FROM %(associations)s WHERE issued + lifetime < ?;'
add_nonce_sql = 'INSERT INTO %(nonces)s VALUES (?, ?, ?);'
clean_nonce_sql = 'DELETE FROM %(nonces)s WHERE timestamp < ?;'
def blobEncode(self, s):
return memoryview(s)
def useNonce(self, *args, **kwargs):
# Older versions of the sqlite wrapper do not raise
# IntegrityError as they should, so we have to detect the
# message from the OperationalError.
try:
return super(SQLiteStore, self).useNonce(*args, **kwargs)
except self.exceptions.OperationalError as why:
if re.match('^columns .* are not unique$', str(why)):
return False
else:
raise
class MySQLStore(SQLStore):
"""
This is a MySQL-based specialization of C{L{SQLStore}}.
Uses InnoDB tables for transaction support.
To create an instance, see C{L{SQLStore.__init__}}. To create the
tables it will use, see C{L{SQLStore.createTables}}.
All other methods are implementation details.
"""
try:
import MySQLdb as exceptions
except ImportError:
exceptions = None
create_nonce_sql = """
CREATE TABLE %(nonces)s (
server_url BLOB NOT NULL,
timestamp INTEGER NOT NULL,
salt CHAR(40) NOT NULL,
PRIMARY KEY (server_url(255), timestamp, salt)
)
ENGINE=InnoDB;
"""
create_assoc_sql = """
CREATE TABLE %(associations)s
(
server_url BLOB NOT NULL,
handle VARCHAR(255) NOT NULL,
secret BLOB NOT NULL,
issued INTEGER NOT NULL,
lifetime INTEGER NOT NULL,
assoc_type VARCHAR(64) NOT NULL,
PRIMARY KEY (server_url(255), handle)
)
ENGINE=InnoDB;
"""
set_assoc_sql = ('REPLACE INTO %(associations)s '
'VALUES (%%s, %%s, %%s, %%s, %%s, %%s);')
get_assocs_sql = ('SELECT handle, secret, issued, lifetime, assoc_type'
' FROM %(associations)s WHERE server_url = %%s;')
get_expired_sql = ('SELECT server_url '
'FROM %(associations)s WHERE issued + lifetime < %%s;')
get_assoc_sql = (
'SELECT handle, secret, issued, lifetime, assoc_type'
' FROM %(associations)s WHERE server_url = %%s AND handle = %%s;')
remove_assoc_sql = ('DELETE FROM %(associations)s '
'WHERE server_url = %%s AND handle = %%s;')
clean_assoc_sql = 'DELETE FROM %(associations)s WHERE issued + lifetime < %%s;'
add_nonce_sql = 'INSERT INTO %(nonces)s VALUES (%%s, %%s, %%s);'
clean_nonce_sql = 'DELETE FROM %(nonces)s WHERE timestamp < %%s;'
class PostgreSQLStore(SQLStore):
"""
This is a PostgreSQL-based specialization of C{L{SQLStore}}.
To create an instance, see C{L{SQLStore.__init__}}. To create the
tables it will use, see C{L{SQLStore.createTables}}.
All other methods are implementation details.
"""
try:
import psycopg2
except ImportError:
from psycopg2cffi import compat
compat.register()
exceptions = None
create_nonce_sql = """
CREATE TABLE %(nonces)s (
server_url VARCHAR(2047) NOT NULL,
timestamp INTEGER NOT NULL,
salt CHAR(40) NOT NULL,
PRIMARY KEY (server_url, timestamp, salt)
);
"""
create_assoc_sql = """
CREATE TABLE %(associations)s
(
server_url VARCHAR(2047) NOT NULL,
handle VARCHAR(255) NOT NULL,
secret BYTEA NOT NULL,
issued INTEGER NOT NULL,
lifetime INTEGER NOT NULL,
assoc_type VARCHAR(64) NOT NULL,
PRIMARY KEY (server_url, handle),
CONSTRAINT secret_length_constraint CHECK (LENGTH(secret) <= 128)
);
"""
def db_set_assoc(self, server_url, handle, secret, issued, lifetime,
assoc_type):
"""
Set an association. This is implemented as a method because
REPLACE INTO is not supported by PostgreSQL (and is not
standard SQL).
"""
result = self.db_get_assoc(server_url, handle)
rows = self.cur.fetchall()
if len(rows):
# Update the table since this associations already exists.
return self.db_update_assoc(secret, issued, lifetime, assoc_type,
server_url, handle)
else:
# Insert a new record because this association wasn't
# found.
return self.db_new_assoc(server_url, handle, secret, issued,
lifetime, assoc_type)
new_assoc_sql = ('INSERT INTO %(associations)s '
'VALUES (%%s, %%s, %%s, %%s, %%s, %%s);')
update_assoc_sql = ('UPDATE %(associations)s SET '
'secret = %%s, issued = %%s, '
'lifetime = %%s, assoc_type = %%s '
'WHERE server_url = %%s AND handle = %%s;')
get_assocs_sql = ('SELECT handle, secret, issued, lifetime, assoc_type'
' FROM %(associations)s WHERE server_url = %%s;')
get_expired_sql = ('SELECT server_url '
'FROM %(associations)s WHERE issued + lifetime < %%s;')
get_assoc_sql = (
'SELECT handle, secret, issued, lifetime, assoc_type'
' FROM %(associations)s WHERE server_url = %%s AND handle = %%s;')
remove_assoc_sql = ('DELETE FROM %(associations)s '
'WHERE server_url = %%s AND handle = %%s;')
clean_assoc_sql = 'DELETE FROM %(associations)s WHERE issued + lifetime < %%s;'
add_nonce_sql = 'INSERT INTO %(nonces)s VALUES (%%s, %%s, %%s);'
clean_nonce_sql = 'DELETE FROM %(nonces)s WHERE timestamp < %%s;'
def blobEncode(self, blob):
from psycopg2 import Binary
return Binary(blob)
def blobDecode(self, blob):
return blob.tobytes()

View File

@@ -0,0 +1,161 @@
import re
from openid import codecutil # registers 'oid_percent_escape' encoding handler
# from appendix B of rfc 3986 (http://www.ietf.org/rfc/rfc3986.txt)
uri_pattern = r'^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?'
uri_re = re.compile(uri_pattern)
# gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@"
#
# sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
# / "*" / "+" / "," / ";" / "="
#
# unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
uri_illegal_char_re = re.compile(r"[^-A-Za-z0-9:/?#[\]@!$&'()*+,;=._~%]",
re.UNICODE)
authority_pattern = r'^([^@]*@)?([^:]*)(:.*)?'
authority_re = re.compile(authority_pattern)
pct_encoded_pattern = r'%([0-9A-Fa-f]{2})'
pct_encoded_re = re.compile(pct_encoded_pattern)
_unreserved = [False] * 256
for _ in range(ord('A'), ord('Z') + 1):
_unreserved[_] = True
for _ in range(ord('0'), ord('9') + 1):
_unreserved[_] = True
for _ in range(ord('a'), ord('z') + 1):
_unreserved[_] = True
_unreserved[ord('-')] = True
_unreserved[ord('.')] = True
_unreserved[ord('_')] = True
_unreserved[ord('~')] = True
def _pct_encoded_replace_unreserved(mo):
try:
i = int(mo.group(1), 16)
if _unreserved[i]:
return chr(i)
else:
return mo.group().upper()
except ValueError:
return mo.group()
def _pct_encoded_replace(mo):
try:
return chr(int(mo.group(1), 16))
except ValueError:
return mo.group()
def remove_dot_segments(path):
result_segments = []
while path:
if path.startswith('../'):
path = path[3:]
elif path.startswith('./'):
path = path[2:]
elif path.startswith('/./'):
path = path[2:]
elif path == '/.':
path = '/'
elif path.startswith('/../'):
path = path[3:]
if result_segments:
result_segments.pop()
elif path == '/..':
path = '/'
if result_segments:
result_segments.pop()
elif path == '..' or path == '.':
path = ''
else:
i = 0
if path[0] == '/':
i = 1
i = path.find('/', i)
if i == -1:
i = len(path)
result_segments.append(path[:i])
path = path[i:]
return ''.join(result_segments)
def urinorm(uri):
'''
Normalize a URI
'''
# TODO: use urllib.parse instead of these complex regular expressions
if isinstance(uri, bytes):
uri = str(uri, encoding='utf-8')
uri = uri.encode('ascii', errors='oid_percent_escape').decode('utf-8')
# _escapeme_re.sub(_pct_escape_unicode, uri).encode('ascii').decode()
illegal_mo = uri_illegal_char_re.search(uri)
if illegal_mo:
raise ValueError('Illegal characters in URI: %r at position %s' %
(illegal_mo.group(), illegal_mo.start()))
uri_mo = uri_re.match(uri)
scheme = uri_mo.group(2)
if scheme is None:
raise ValueError('No scheme specified')
scheme = scheme.lower()
if scheme not in ('http', 'https'):
raise ValueError('Not an absolute HTTP or HTTPS URI: %r' % (uri, ))
authority = uri_mo.group(4)
if authority is None:
raise ValueError('Not an absolute URI: %r' % (uri, ))
authority_mo = authority_re.match(authority)
if authority_mo is None:
raise ValueError('URI does not have a valid authority: %r' % (uri, ))
userinfo, host, port = authority_mo.groups()
if userinfo is None:
userinfo = ''
if '%' in host:
host = host.lower()
host = pct_encoded_re.sub(_pct_encoded_replace, host)
host = host.encode('idna').decode()
else:
host = host.lower()
if port:
if (port == ':' or (scheme == 'http' and port == ':80') or
(scheme == 'https' and port == ':443')):
port = ''
else:
port = ''
authority = userinfo + host + port
path = uri_mo.group(5)
path = pct_encoded_re.sub(_pct_encoded_replace_unreserved, path)
path = remove_dot_segments(path)
if not path:
path = '/'
query = uri_mo.group(6)
if query is None:
query = ''
fragment = uri_mo.group(8)
if fragment is None:
fragment = ''
return scheme + '://' + authority + path + query + fragment

View File

@@ -0,0 +1,17 @@
#-*- coding: utf-8 -*-
__all__ = [
'constants',
'discover',
'etxrd',
'filters',
'manager',
'parsehtml',
'services',
'xri',
'xrires',
]
version_info = (2, 0, 0)
__version__ = ".".join(str(x) for x in version_info)

View File

@@ -0,0 +1,137 @@
"""Functions for generating and parsing HTTP Accept: headers for
supporting server-directed content negotiation.
"""
def generateAcceptHeader(*elements):
"""Generate an accept header value
[str or (str, float)] -> str
"""
parts = []
for element in elements:
if type(element) is str:
qs = "1.0"
mtype = element
else:
mtype, q = element
q = float(q)
if q > 1 or q <= 0:
raise ValueError('Invalid preference factor: %r' % q)
qs = '%0.1f' % (q, )
parts.append((qs, mtype))
parts.sort()
chunks = []
for q, mtype in parts:
if q == '1.0':
chunks.append(mtype)
else:
chunks.append('%s; q=%s' % (mtype, q))
return ', '.join(chunks)
def parseAcceptHeader(value):
"""Parse an accept header, ignoring any accept-extensions
returns a list of tuples containing main MIME type, MIME subtype,
and quality markdown.
str -> [(str, str, float)]
"""
chunks = [chunk.strip() for chunk in value.split(',')]
accept = []
for chunk in chunks:
parts = [s.strip() for s in chunk.split(';')]
mtype = parts.pop(0)
if '/' not in mtype:
# This is not a MIME type, so ignore the bad data
continue
main, sub = mtype.split('/', 1)
for ext in parts:
if '=' in ext:
k, v = ext.split('=', 1)
if k == 'q':
try:
q = float(v)
break
except ValueError:
# Ignore poorly formed q-values
pass
else:
q = 1.0
accept.append((q, main, sub))
accept.sort()
accept.reverse()
return [(main, sub, q) for (q, main, sub) in accept]
def matchTypes(accept_types, have_types):
"""Given the result of parsing an Accept: header, and the
available MIME types, return the acceptable types with their
quality markdowns.
For example:
>>> acceptable = parseAcceptHeader('text/html, text/plain; q=0.5')
>>> matchTypes(acceptable, ['text/plain', 'text/html', 'image/jpeg'])
[('text/html', 1.0), ('text/plain', 0.5)]
Type signature: ([(str, str, float)], [str]) -> [(str, float)]
"""
if not accept_types:
# Accept all of them
default = 1
else:
default = 0
match_main = {}
match_sub = {}
for (main, sub, q) in accept_types:
if main == '*':
default = max(default, q)
continue
elif sub == '*':
match_main[main] = max(match_main.get(main, 0), q)
else:
match_sub[(main, sub)] = max(match_sub.get((main, sub), 0), q)
accepted_list = []
order_maintainer = 0
for mtype in have_types:
main, sub = mtype.split('/')
if (main, sub) in match_sub:
q = match_sub[(main, sub)]
else:
q = match_main.get(main, default)
if q:
accepted_list.append((1 - q, order_maintainer, q, mtype))
order_maintainer += 1
accepted_list.sort()
return [(mtype, q) for (_, _, q, mtype) in accepted_list]
def getAcceptable(accept_header, have_types):
"""Parse the accept header and return a list of available types in
preferred order. If a type is unacceptable, it will not be in the
resulting list.
This is a convenience wrapper around matchTypes and
parseAcceptHeader.
(str, [str]) -> [str]
"""
accepted = parseAcceptHeader(accept_header)
preferred = matchTypes(accepted, have_types)
return [mtype for (mtype, _) in preferred]

View File

@@ -0,0 +1,12 @@
__all__ = ['YADIS_HEADER_NAME', 'YADIS_CONTENT_TYPE', 'YADIS_ACCEPT_HEADER']
from openid.yadis.accept import generateAcceptHeader
YADIS_HEADER_NAME = 'X-XRDS-Location'
YADIS_CONTENT_TYPE = 'application/xrds+xml'
# A value suitable for using as an accept header when performing YADIS
# discovery, unless the application has special requirements
YADIS_ACCEPT_HEADER = generateAcceptHeader(
('text/html', 0.3),
('application/xhtml+xml', 0.5),
(YADIS_CONTENT_TYPE, 1.0), )

View File

@@ -0,0 +1,169 @@
# -*- test-case-name: openid.test.test_yadis_discover -*-
__all__ = ['discover', 'DiscoveryResult', 'DiscoveryFailure']
from io import StringIO
from openid import fetchers
from openid.yadis.constants import \
YADIS_HEADER_NAME, YADIS_CONTENT_TYPE, YADIS_ACCEPT_HEADER
from openid.yadis.parsehtml import MetaNotFound, findHTMLMeta
class DiscoveryFailure(Exception):
"""Raised when a YADIS protocol error occurs in the discovery process"""
identity_url = None
def __init__(self, message, http_response):
Exception.__init__(self, message)
self.http_response = http_response
class DiscoveryResult(object):
"""Contains the result of performing Yadis discovery on a URI"""
# The URI that was passed to the fetcher
request_uri = None
# The result of following redirects from the request_uri
normalized_uri = None
# The URI from which the response text was returned (set to
# None if there was no XRDS document found)
xrds_uri = None
# The content-type returned with the response_text
content_type = None
# The document returned from the xrds_uri
response_text = None
def __init__(self, request_uri):
"""Initialize the state of the object
sets all attributes to None except the request_uri
"""
self.request_uri = request_uri
def usedYadisLocation(self):
"""Was the Yadis protocol's indirection used?"""
if self.xrds_uri is None:
return False
return self.normalized_uri != self.xrds_uri
def isXRDS(self):
"""Is the response text supposed to be an XRDS document?"""
return (self.usedYadisLocation() or
self.content_type == YADIS_CONTENT_TYPE)
def discover(uri):
"""Discover services for a given URI.
@param uri: The identity URI as a well-formed http or https
URI. The well-formedness and the protocol are not checked, but
the results of this function are undefined if those properties
do not hold.
@return: DiscoveryResult object
@raises Exception: Any exception that can be raised by fetching a URL with
the given fetcher.
@raises DiscoveryFailure: When the HTTP response does not have a 200 code.
"""
result = DiscoveryResult(uri)
resp = fetchers.fetch(uri, headers={'Accept': YADIS_ACCEPT_HEADER})
if resp.status not in (200, 206):
raise DiscoveryFailure(
'HTTP Response status from identity URL host is not 200. '
'Got status %r' % (resp.status, ), resp)
# Note the URL after following redirects
result.normalized_uri = resp.final_url
# Attempt to find out where to go to discover the document
# or if we already have it
result.content_type = resp.headers.get('content-type')
result.xrds_uri = whereIsYadis(resp)
if result.xrds_uri and result.usedYadisLocation():
resp = fetchers.fetch(result.xrds_uri)
if resp.status not in (200, 206):
exc = DiscoveryFailure(
'HTTP Response status from Yadis host is not 200. '
'Got status %r' % (resp.status, ), resp)
exc.identity_url = result.normalized_uri
raise exc
result.content_type = resp.headers.get('content-type')
result.response_text = resp.body
return result
def whereIsYadis(resp):
"""Given a HTTPResponse, return the location of the Yadis document.
May be the URL just retrieved, another URL, or None if no suitable URL can
be found.
[non-blocking]
@returns: str or None
"""
# Attempt to find out where to go to discover the document
# or if we already have it
content_type = resp.headers.get('content-type')
# According to the spec, the content-type header must be an exact
# match, or else we have to look for an indirection.
if (content_type and
content_type.split(';', 1)[0].lower() == YADIS_CONTENT_TYPE):
return resp.final_url
else:
# Try the header
yadis_loc = resp.headers.get(YADIS_HEADER_NAME.lower())
if not yadis_loc:
# Parse as HTML if the header is missing.
#
# XXX: do we want to do something with content-type, like
# have a whitelist or a blacklist (for detecting that it's
# HTML)?
# Decode body by encoding of file
content_type = content_type or ''
encoding = content_type.rsplit(';', 1)
if (len(encoding) == 2 and
encoding[1].strip().startswith('charset=')):
encoding = encoding[1].split('=', 1)[1].strip()
else:
encoding = 'utf-8'
if isinstance(resp.body, bytes):
try:
content = resp.body.decode(encoding)
except UnicodeError:
# All right, the detected encoding has failed. Try with
# UTF-8 (even if there was no detected encoding and we've
# defaulted to UTF-8, it's not that expensive an operation)
try:
content = resp.body.decode('utf-8')
except UnicodeError:
# At this point the content cannot be decoded to a str
# using the detected encoding or falling back to utf-8,
# so we have to resort to replacing undecodable chars.
# This *will* result in broken content but there isn't
# anything else that can be done.
content = resp.body.decode(encoding, 'replace')
else:
content = resp.body
try:
yadis_loc = findHTMLMeta(StringIO(content))
except (MetaNotFound, UnicodeError):
# UnicodeError: Response body could not be encoded and xrds
# location could not be found before troubles occur.
pass
return yadis_loc

View File

@@ -0,0 +1,318 @@
# -*- test-case-name: yadis.test.test_etxrd -*-
"""
ElementTree interface to an XRD document.
"""
__all__ = [
'nsTag',
'mkXRDTag',
'isXRDS',
'parseXRDS',
'getCanonicalID',
'getYadisXRD',
'getPriorityStrict',
'getPriority',
'prioSort',
'iterServices',
'expandService',
'expandServices',
]
import sys
import random
import functools
from datetime import datetime
from time import strptime
from openid.oidutil import importElementTree, importSafeElementTree
ElementTree = importElementTree()
SafeElementTree = importSafeElementTree()
from openid.yadis import xri
class XRDSError(Exception):
"""An error with the XRDS document."""
# The exception that triggered this exception
reason = None
class XRDSFraud(XRDSError):
"""Raised when there's an assertion in the XRDS that it does not have
the authority to make.
"""
def parseXRDS(text):
"""Parse the given text as an XRDS document.
@return: ElementTree containing an XRDS document
@raises XRDSError: When there is a parse error or the document does
not contain an XRDS.
"""
try:
# lxml prefers to parse bytestrings, and occasionally chokes on a
# combination of text strings and declared XML encodings -- see
# https://github.com/necaris/python3-openid/issues/19
# To avoid this, we ensure that the 'text' we're parsing is actually
# a bytestring
bytestring = text.encode('utf8') if isinstance(text, str) else text
element = SafeElementTree.XML(bytestring)
except (SystemExit, MemoryError, AssertionError, ImportError):
raise
except Exception as why:
exc = XRDSError('Error parsing document as XML')
exc.reason = why
raise exc
else:
tree = ElementTree.ElementTree(element)
if not isXRDS(tree):
raise XRDSError('Not an XRDS document')
return tree
XRD_NS_2_0 = 'xri://$xrd*($v*2.0)'
XRDS_NS = 'xri://$xrds'
def nsTag(ns, t):
return '{%s}%s' % (ns, t)
def mkXRDTag(t):
"""basestring -> basestring
Create a tag name in the XRD 2.0 XML namespace suitable for using
with ElementTree
"""
return nsTag(XRD_NS_2_0, t)
def mkXRDSTag(t):
"""basestring -> basestring
Create a tag name in the XRDS XML namespace suitable for using
with ElementTree
"""
return nsTag(XRDS_NS, t)
# Tags that are used in Yadis documents
root_tag = mkXRDSTag('XRDS')
service_tag = mkXRDTag('Service')
xrd_tag = mkXRDTag('XRD')
type_tag = mkXRDTag('Type')
uri_tag = mkXRDTag('URI')
expires_tag = mkXRDTag('Expires')
# Other XRD tags
canonicalID_tag = mkXRDTag('CanonicalID')
def isXRDS(xrd_tree):
"""Is this document an XRDS document?"""
root = xrd_tree.getroot()
return root.tag == root_tag
def getYadisXRD(xrd_tree):
"""Return the XRD element that should contain the Yadis services"""
xrd = None
# for the side-effect of assigning the last one in the list to the
# xrd variable
for xrd in xrd_tree.findall(xrd_tag):
pass
# There were no elements found, or else xrd would be set to the
# last one
if xrd is None:
raise XRDSError('No XRD present in tree')
return xrd
def getXRDExpiration(xrd_element, default=None):
"""Return the expiration date of this XRD element, or None if no
expiration was specified.
@type xrd_element: ElementTree node
@param default: The value to use as the expiration if no
expiration was specified in the XRD.
@rtype: datetime.datetime
@raises ValueError: If the xrd:Expires element is present, but its
contents are not formatted according to the specification.
"""
expires_element = xrd_element.find(expires_tag)
if expires_element is None:
return default
else:
expires_string = expires_element.text
# Will raise ValueError if the string is not the expected format
expires_time = strptime(expires_string, "%Y-%m-%dT%H:%M:%SZ")
return datetime(*expires_time[0:6])
def getCanonicalID(iname, xrd_tree):
"""Return the CanonicalID from this XRDS document.
@param iname: the XRI being resolved.
@type iname: unicode
@param xrd_tree: The XRDS output from the resolver.
@type xrd_tree: ElementTree
@returns: The XRI CanonicalID or None.
@returntype: unicode or None
"""
xrd_list = xrd_tree.findall(xrd_tag)
xrd_list.reverse()
try:
canonicalID = xri.XRI(xrd_list[0].findall(canonicalID_tag)[0].text)
except IndexError:
return None
childID = canonicalID.lower()
for xrd in xrd_list[1:]:
parent_sought = childID.rsplit("!", 1)[0]
parent = xri.XRI(xrd.findtext(canonicalID_tag))
if parent_sought != parent.lower():
raise XRDSFraud("%r can not come from %s" % (childID, parent))
childID = parent_sought
root = xri.rootAuthority(iname)
if not xri.providerIsAuthoritative(root, childID):
raise XRDSFraud("%r can not come from root %r" % (childID, root))
return canonicalID
@functools.total_ordering
class _Max(object):
"""
Value that compares greater than any other value.
Should only be used as a singleton. Implemented for use as a
priority value for when a priority is not specified.
"""
def __lt__(self, other):
return isinstance(other, self.__class__)
def __eq__(self, other):
return isinstance(other, self.__class__)
Max = _Max()
def getPriorityStrict(element):
"""Get the priority of this element.
Raises ValueError if the value of the priority is invalid. If no
priority is specified, it returns a value that compares greater
than any other value.
"""
prio_str = element.get('priority')
if prio_str is not None:
prio_val = int(prio_str)
if prio_val >= 0:
return prio_val
else:
raise ValueError('Priority values must be non-negative integers')
# Any errors in parsing the priority fall through to here
return Max
def getPriority(element):
"""Get the priority of this element
Returns Max if no priority is specified or the priority value is invalid.
"""
try:
return getPriorityStrict(element)
except ValueError:
return Max
def prioSort(elements):
"""Sort a list of elements that have priority attributes"""
# Randomize the services before sorting so that equal priority
# elements are load-balanced.
random.shuffle(elements)
sorted_elems = sorted(elements, key=getPriority)
return sorted_elems
def iterServices(xrd_tree):
"""Return an iterable over the Service elements in the Yadis XRD
sorted by priority"""
xrd = getYadisXRD(xrd_tree)
return prioSort(xrd.findall(service_tag))
def sortedURIs(service_element):
"""Given a Service element, return a list of the contents of all
URI tags in priority order."""
return [
uri_element.text
for uri_element in prioSort(service_element.findall(uri_tag))
]
def getTypeURIs(service_element):
"""Given a Service element, return a list of the contents of all
Type tags"""
return [
type_element.text for type_element in service_element.findall(type_tag)
]
def expandService(service_element):
"""Take a service element and expand it into an iterator of:
([type_uri], uri, service_element)
"""
uris = sortedURIs(service_element)
if not uris:
uris = [None]
expanded = []
for uri in uris:
type_uris = getTypeURIs(service_element)
expanded.append((type_uris, uri, service_element))
return expanded
def expandServices(service_elements):
"""Take a sorted iterator of service elements and expand it into a
sorted iterator of:
([type_uri], uri, service_element)
There may be more than one item in the resulting list for each
service element if there is more than one URI or type for a
service, but each triple will be unique.
If there is no URI or Type for a Service element, it will not
appear in the result.
"""
expanded = []
for service_element in service_elements:
expanded.extend(expandService(service_element))
return expanded

View File

@@ -0,0 +1,213 @@
"""This module contains functions and classes used for extracting
endpoint information out of a Yadis XRD file using the ElementTree XML
parser.
"""
__all__ = [
'BasicServiceEndpoint',
'mkFilter',
'IFilter',
'TransformFilterMaker',
'CompoundFilter',
]
from openid.yadis.etxrd import expandService
try:
from collections.abc import Callable
except ImportError:
from collections import Callable
class BasicServiceEndpoint(object):
"""Generic endpoint object that contains parsed service
information, as well as a reference to the service element from
which it was generated. If there is more than one xrd:Type or
xrd:URI in the xrd:Service, this object represents just one of
those pairs.
This object can be used as a filter, because it implements
fromBasicServiceEndpoint.
The simplest kind of filter you can write implements
fromBasicServiceEndpoint, which takes one of these objects.
"""
def __init__(self, yadis_url, type_uris, uri, service_element):
self.type_uris = type_uris
self.yadis_url = yadis_url
self.uri = uri
self.service_element = service_element
def matchTypes(self, type_uris):
"""Query this endpoint to see if it has any of the given type
URIs. This is useful for implementing other endpoint classes
that e.g. need to check for the presence of multiple versions
of a single protocol.
@param type_uris: The URIs that you wish to check
@type type_uris: iterable of str
@return: all types that are in both in type_uris and
self.type_uris
"""
return [uri for uri in type_uris if uri in self.type_uris]
def fromBasicServiceEndpoint(endpoint):
"""Trivial transform from a basic endpoint to itself. This
method exists to allow BasicServiceEndpoint to be used as a
filter.
If you are subclassing this object, re-implement this function.
@param endpoint: An instance of BasicServiceEndpoint
@return: The object that was passed in, with no processing.
"""
return endpoint
fromBasicServiceEndpoint = staticmethod(fromBasicServiceEndpoint)
class IFilter(object):
"""Interface for Yadis filter objects. Other filter-like things
are convertable to this class."""
def getServiceEndpoints(self, yadis_url, service_element):
"""Returns an iterator of endpoint objects"""
raise NotImplementedError
class TransformFilterMaker(object):
"""Take a list of basic filters and makes a filter that transforms
the basic filter into a top-level filter. This is mostly useful
for the implementation of mkFilter, which should only be needed
for special cases or internal use by this library.
This object is useful for creating simple filters for services
that use one URI and are specified by one Type (we expect most
Types will fit this paradigm).
Creates a BasicServiceEndpoint object and apply the filter
functions to it until one of them returns a value.
"""
def __init__(self, filter_functions):
"""Initialize the filter maker's state
@param filter_functions: The endpoint transformer functions to
apply to the basic endpoint. These are called in turn
until one of them does not return None, and the result of
that transformer is returned.
"""
self.filter_functions = filter_functions
def getServiceEndpoints(self, yadis_url, service_element):
"""Returns an iterator of endpoint objects produced by the
filter functions."""
endpoints = []
# Do an expansion of the service element by xrd:Type and xrd:URI
for type_uris, uri, _ in expandService(service_element):
# Create a basic endpoint object to represent this
# yadis_url, Service, Type, URI combination
endpoint = BasicServiceEndpoint(yadis_url, type_uris, uri,
service_element)
e = self.applyFilters(endpoint)
if e is not None:
endpoints.append(e)
return endpoints
def applyFilters(self, endpoint):
"""Apply filter functions to an endpoint until one of them
returns non-None."""
for filter_function in self.filter_functions:
e = filter_function(endpoint)
if e is not None:
# Once one of the filters has returned an
# endpoint, do not apply any more.
return e
return None
class CompoundFilter(object):
"""Create a new filter that applies a set of filters to an endpoint
and collects their results.
"""
def __init__(self, subfilters):
self.subfilters = subfilters
def getServiceEndpoints(self, yadis_url, service_element):
"""Generate all endpoint objects for all of the subfilters of
this filter and return their concatenation."""
endpoints = []
for subfilter in self.subfilters:
endpoints.extend(
subfilter.getServiceEndpoints(yadis_url, service_element))
return endpoints
# Exception raised when something is not able to be turned into a filter
filter_type_error = TypeError(
'Expected a filter, an endpoint, a callable or a list of any of these.')
def mkFilter(parts):
"""Convert a filter-convertable thing into a filter
@param parts: a filter, an endpoint, a callable, or a list of any of these.
"""
# Convert the parts into a list, and pass to mkCompoundFilter
if parts is None:
parts = [BasicServiceEndpoint]
try:
parts = list(parts)
except TypeError:
return mkCompoundFilter([parts])
else:
return mkCompoundFilter(parts)
def mkCompoundFilter(parts):
"""Create a filter out of a list of filter-like things
Used by mkFilter
@param parts: list of filter, endpoint, callable or list of any of these
"""
# Separate into a list of callables and a list of filter objects
transformers = []
filters = []
for subfilter in parts:
try:
subfilter = list(subfilter)
except TypeError:
# If it's not an iterable
if hasattr(subfilter, 'getServiceEndpoints'):
# It's a full filter
filters.append(subfilter)
elif hasattr(subfilter, 'fromBasicServiceEndpoint'):
# It's an endpoint object, so put its endpoint
# conversion attribute into the list of endpoint
# transformers
transformers.append(subfilter.fromBasicServiceEndpoint)
elif isinstance(subfilter, Callable):
# It's a simple callable, so add it to the list of
# endpoint transformers
transformers.append(subfilter)
else:
raise filter_type_error
else:
filters.append(mkCompoundFilter(subfilter))
if transformers:
filters.append(TransformFilterMaker(transformers))
if len(filters) == 1:
return filters[0]
else:
return CompoundFilter(filters)

View File

@@ -0,0 +1,195 @@
class YadisServiceManager(object):
"""Holds the state of a list of selected Yadis services, managing
storing it in a session and iterating over the services in order."""
def __init__(self, starting_url, yadis_url, services, session_key):
# The URL that was used to initiate the Yadis protocol
self.starting_url = starting_url
# The URL after following redirects (the identifier)
self.yadis_url = yadis_url
# List of service elements
self.services = list(services)
self.session_key = session_key
# Reference to the current service object
self._current = None
def __len__(self):
"""How many untried services remain?"""
return len(self.services)
def __iter__(self):
return self
def __next__(self):
"""Return the next service
self.current() will continue to return that service until the
next call to this method."""
try:
self._current = self.services.pop(0)
except IndexError:
raise StopIteration
else:
return self._current
def current(self):
"""Return the current service.
Returns None if there are no services left.
"""
return self._current
def forURL(self, url):
return url in [self.starting_url, self.yadis_url]
def started(self):
"""Has the first service been returned?"""
return self._current is not None
def store(self, session):
"""Store this object in the session, by its session key."""
session[self.session_key] = self
class Discovery(object):
"""State management for discovery.
High-level usage pattern is to call .getNextService(discover) in
order to find the next available service for this user for this
session. Once a request completes, call .finish() to clean up the
session state.
@ivar session: a dict-like object that stores state unique to the
requesting user-agent. This object must be able to store
serializable objects.
@ivar url: the URL that is used to make the discovery request
@ivar session_key_suffix: The suffix that will be used to identify
this object in the session object.
"""
DEFAULT_SUFFIX = 'auth'
PREFIX = '_yadis_services_'
def __init__(self, session, url, session_key_suffix=None):
"""Initialize a discovery object"""
self.session = session
self.url = url
if session_key_suffix is None:
session_key_suffix = self.DEFAULT_SUFFIX
self.session_key_suffix = session_key_suffix
def getNextService(self, discover):
"""Return the next authentication service for the pair of
user_input and session. This function handles fallback.
@param discover: a callable that takes a URL and returns a
list of services
@type discover: str -> [service]
@return: the next available service
"""
manager = self.getManager()
if manager is not None and not manager:
self.destroyManager()
if not manager:
yadis_url, services = discover(self.url)
manager = self.createManager(services, yadis_url)
if manager:
service = next(manager)
manager.store(self.session)
else:
service = None
return service
def cleanup(self, force=False):
"""Clean up Yadis-related services in the session and return
the most-recently-attempted service from the manager, if one
exists.
@param force: True if the manager should be deleted regardless
of whether it's a manager for self.url.
@return: current service endpoint object or None if there is
no current service
"""
manager = self.getManager(force=force)
if manager is not None:
service = manager.current()
self.destroyManager(force=force)
else:
service = None
return service
### Lower-level methods
def getSessionKey(self):
"""Get the session key for this starting URL and suffix
@return: The session key
@rtype: str
"""
return self.PREFIX + self.session_key_suffix
def getManager(self, force=False):
"""Extract the YadisServiceManager for this object's URL and
suffix from the session.
@param force: True if the manager should be returned
regardless of whether it's a manager for self.url.
@return: The current YadisServiceManager, if it's for this
URL, or else None
"""
manager = self.session.get(self.getSessionKey())
if (manager is not None and (manager.forURL(self.url) or force)):
return manager
else:
return None
def createManager(self, services, yadis_url=None):
"""Create a new YadisService Manager for this starting URL and
suffix, and store it in the session.
@raises KeyError: When I already have a manager.
@return: A new YadisServiceManager or None
"""
key = self.getSessionKey()
if self.getManager():
raise KeyError('There is already a %r manager for %r' %
(key, self.url))
if not services:
return None
manager = YadisServiceManager(self.url, yadis_url, services, key)
manager.store(self.session)
return manager
def destroyManager(self, force=False):
"""Delete any YadisServiceManager with this starting URL and
suffix from the session.
If there is no service manager or the service manager is for a
different URL, it silently does nothing.
@param force: True if the manager should be deleted regardless
of whether it's a manager for self.url.
"""
if self.getManager(force=force) is not None:
key = self.getSessionKey()
del self.session[key]

View File

@@ -0,0 +1,207 @@
__all__ = ['findHTMLMeta', 'MetaNotFound']
from html.parser import HTMLParser
import html.entities
import re
import sys
from openid.yadis.constants import YADIS_HEADER_NAME
# Size of the chunks to search at a time (also the amount that gets
# read at a time)
CHUNK_SIZE = 1024 * 16 # 16 KB
class ParseDone(Exception):
"""Exception to hold the URI that was located when the parse is
finished. If the parse finishes without finding the URI, set it to
None."""
class MetaNotFound(Exception):
"""Exception to hold the content of the page if we did not find
the appropriate <meta> tag"""
re_flags = re.IGNORECASE | re.UNICODE | re.VERBOSE
ent_pat = r'''
&
(?: \#x (?P<hex> [a-f0-9]+ )
| \# (?P<dec> \d+ )
| (?P<word> \w+ )
)
;'''
ent_re = re.compile(ent_pat, re_flags)
def substituteMO(mo):
if mo.lastgroup == 'hex':
codepoint = int(mo.group('hex'), 16)
elif mo.lastgroup == 'dec':
codepoint = int(mo.group('dec'))
else:
assert mo.lastgroup == 'word'
codepoint = html.entities.name2codepoint.get(mo.group('word'))
if codepoint is None:
return mo.group()
else:
return chr(codepoint)
def substituteEntities(s):
return ent_re.sub(substituteMO, s)
class YadisHTMLParser(HTMLParser):
"""Parser that finds a meta http-equiv tag in the head of a html
document.
When feeding in data, if the tag is matched or it will never be
found, the parser will raise ParseDone with the uri as the first
attribute.
Parsing state diagram
=====================
Any unlisted input does not affect the state::
1, 2, 5 8
+--------------------------+ +-+
| | | |
4 | 3 1, 2, 5, 7 v | v
TOP -> HTML -> HEAD ----------> TERMINATED
| | ^ | ^ ^
| | 3 | | | |
| +------------+ +-> FOUND ------+ |
| 6 8 |
| 1, 2 |
+------------------------------------+
1. any of </body>, </html>, </head> -> TERMINATE
2. <body> -> TERMINATE
3. <head> -> HEAD
4. <html> -> HTML
5. <html> -> TERMINATE
6. <meta http-equiv='X-XRDS-Location'> -> FOUND
7. <head> -> TERMINATE
8. Any input -> TERMINATE
"""
TOP = 0
HTML = 1
HEAD = 2
FOUND = 3
TERMINATED = 4
def __init__(self):
if (sys.version_info.minor <= 2):
# Python 3.2 and below actually require the `strict` argument
# to `html.parser.HTMLParser` -- otherwise it's deprecated and
# we don't want to pass it
super(YadisHTMLParser, self).__init__(strict=False)
else:
super(YadisHTMLParser, self).__init__()
self.phase = self.TOP
def _terminate(self):
self.phase = self.TERMINATED
raise ParseDone(None)
def handle_endtag(self, tag):
# If we ever see an end of head, body, or html, bail out right away.
# [1]
if tag in ['head', 'body', 'html']:
self._terminate()
def handle_starttag(self, tag, attrs):
# if we ever see a start body tag, bail out right away, since
# we want to prevent the meta tag from appearing in the body
# [2]
if tag == 'body':
self._terminate()
if self.phase == self.TOP:
# At the top level, allow a html tag or a head tag to move
# to the head or html phase
if tag == 'head':
# [3]
self.phase = self.HEAD
elif tag == 'html':
# [4]
self.phase = self.HTML
elif self.phase == self.HTML:
# if we are in the html tag, allow a head tag to move to
# the HEAD phase. If we get another html tag, then bail
# out
if tag == 'head':
# [3]
self.phase = self.HEAD
elif tag == 'html':
# [5]
self._terminate()
elif self.phase == self.HEAD:
# If we are in the head phase, look for the appropriate
# meta tag. If we get a head or body tag, bail out.
if tag == 'meta':
attrs_d = dict(attrs)
http_equiv = attrs_d.get('http-equiv', '').lower()
if http_equiv == YADIS_HEADER_NAME.lower():
raw_attr = attrs_d.get('content')
yadis_loc = substituteEntities(raw_attr)
# [6]
self.phase = self.FOUND
raise ParseDone(yadis_loc)
elif tag in ('head', 'html'):
# [5], [7]
self._terminate()
def feed(self, chars):
# [8]
if self.phase in (self.TERMINATED, self.FOUND):
self._terminate()
return super(YadisHTMLParser, self).feed(chars)
def findHTMLMeta(stream):
"""Look for a meta http-equiv tag with the YADIS header name.
@param stream: Source of the html text
@type stream: Object that implements a read() method that works
like file.read
@return: The URI from which to fetch the XRDS document
@rtype: str
@raises MetaNotFound: raised with the content that was
searched as the first parameter.
"""
parser = YadisHTMLParser()
chunks = []
while 1:
chunk = stream.read(CHUNK_SIZE)
if not chunk:
# End of file
break
chunks.append(chunk)
try:
parser.feed(chunk)
except ParseDone as why:
uri = why.args[0]
if uri is None:
# Parse finished, but we may need the rest of the file
chunks.append(stream.read())
break
else:
return uri
content = ''.join(chunks)
raise MetaNotFound(content)

View File

@@ -0,0 +1,56 @@
# -*- test-case-name: openid.test.test_services -*-
from openid.yadis.filters import mkFilter
from openid.yadis.discover import discover, DiscoveryFailure
from openid.yadis.etxrd import parseXRDS, iterServices, XRDSError
def getServiceEndpoints(input_url, flt=None):
"""Perform the Yadis protocol on the input URL and return an
iterable of resulting endpoint objects.
@param flt: A filter object or something that is convertable to
a filter object (using mkFilter) that will be used to generate
endpoint objects. This defaults to generating BasicEndpoint
objects.
@param input_url: The URL on which to perform the Yadis protocol
@return: The normalized identity URL and an iterable of endpoint
objects generated by the filter function.
@rtype: (str, [endpoint])
@raises DiscoveryFailure: when Yadis fails to obtain an XRDS document.
"""
result = discover(input_url)
try:
endpoints = applyFilter(result.normalized_uri, result.response_text,
flt)
except XRDSError as err:
raise DiscoveryFailure(str(err), None)
return (result.normalized_uri, endpoints)
def applyFilter(normalized_uri, xrd_data, flt=None):
"""Generate an iterable of endpoint objects given this input data,
presumably from the result of performing the Yadis protocol.
@param normalized_uri: The input URL, after following redirects,
as in the Yadis protocol.
@param xrd_data: The XML text the XRDS file fetched from the
normalized URI.
@type xrd_data: str
"""
flt = mkFilter(flt)
et = parseXRDS(xrd_data)
endpoints = []
for service_element in iterServices(et):
endpoints.extend(
flt.getServiceEndpoints(normalized_uri, service_element))
return endpoints

View File

@@ -0,0 +1,122 @@
# -*- test-case-name: openid.test.test_xri -*-
"""Utility functions for handling XRIs.
@see: XRI Syntax v2.0 at the U{OASIS XRI Technical Committee<http://www.oasis-open.org/committees/tc_home.php?wg_abbrev=xri>}
"""
import re
from functools import reduce
from openid import codecutil # registers 'oid_percent_escape' encoding handler
XRI_AUTHORITIES = ['!', '=', '@', '+', '$', '(']
def identifierScheme(identifier):
"""Determine if this identifier is an XRI or URI.
@returns: C{"XRI"} or C{"URI"}
"""
if identifier.startswith('xri://') or (identifier and
identifier[0] in XRI_AUTHORITIES):
return "XRI"
else:
return "URI"
def toIRINormal(xri):
"""Transform an XRI to IRI-normal form."""
if not xri.startswith('xri://'):
xri = 'xri://' + xri
return escapeForIRI(xri)
_xref_re = re.compile(r'\((.*?)\)')
def _escape_xref(xref_match):
"""Escape things that need to be escaped if they're in a cross-reference.
"""
xref = xref_match.group()
xref = xref.replace('/', '%2F')
xref = xref.replace('?', '%3F')
xref = xref.replace('#', '%23')
return xref
def escapeForIRI(xri):
"""Escape things that need to be escaped when transforming to an IRI."""
xri = xri.replace('%', '%25')
xri = _xref_re.sub(_escape_xref, xri)
return xri
def toURINormal(xri):
"""Transform an XRI to URI normal form."""
return iriToURI(toIRINormal(xri))
def iriToURI(iri):
"""Transform an IRI to a URI by escaping unicode."""
# According to RFC 3987, section 3.1, "Mapping of IRIs to URIs"
if isinstance(iri, bytes):
iri = str(iri, encoding="utf-8")
return iri.encode('ascii', errors='oid_percent_escape').decode()
def providerIsAuthoritative(providerID, canonicalID):
"""Is this provider ID authoritative for this XRI?
@returntype: bool
"""
# XXX: can't use rsplit until we require python >= 2.4.
lastbang = canonicalID.rindex('!')
parent = canonicalID[:lastbang]
return parent == providerID
def rootAuthority(xri):
"""Return the root authority for an XRI.
Example::
rootAuthority("xri://@example") == "xri://@"
@type xri: unicode
@returntype: unicode
"""
if xri.startswith('xri://'):
xri = xri[6:]
authority = xri.split('/', 1)[0]
if authority[0] == '(':
# Cross-reference.
# XXX: This is incorrect if someone nests cross-references so there
# is another close-paren in there. Hopefully nobody does that
# before we have a real xriparse function. Hopefully nobody does
# that *ever*.
root = authority[:authority.index(')') + 1]
elif authority[0] in XRI_AUTHORITIES:
# Other XRI reference.
root = authority[0]
else:
# IRI reference. XXX: Can IRI authorities have segments?
segments = authority.split('!')
segments = reduce(list.__add__, [s.split('*') for s in segments])
root = segments[0]
return XRI(root)
def XRI(xri):
"""An XRI object allowing comparison of XRI.
Ideally, this would do full normalization and provide comparsion
operators as per XRI Syntax. Right now, it just does a bit of
canonicalization by ensuring the xri scheme is present.
@param xri: an xri string
@type xri: unicode
"""
if not xri.startswith('xri://'):
xri = 'xri://' + xri
return xri

View File

@@ -0,0 +1,123 @@
# -*- test-case-name: openid.test.test_xrires -*-
"""XRI resolution.
"""
from urllib.parse import urlencode
from openid import fetchers
from openid.yadis import etxrd
from openid.yadis.xri import toURINormal
from openid.yadis.services import iterServices
DEFAULT_PROXY = 'http://proxy.xri.net/'
class ProxyResolver(object):
"""Python interface to a remote XRI proxy resolver.
"""
def __init__(self, proxy_url=DEFAULT_PROXY):
self.proxy_url = proxy_url
def queryURL(self, xri, service_type=None):
"""Build a URL to query the proxy resolver.
@param xri: An XRI to resolve.
@type xri: unicode
@param service_type: The service type to resolve, if you desire
service endpoint selection. A service type is a URI.
@type service_type: str
@returns: a URL
@returntype: str
"""
# Trim off the xri:// prefix. The proxy resolver didn't accept it
# when this code was written, but that may (or may not) change for
# XRI Resolution 2.0 Working Draft 11.
qxri = toURINormal(xri)[6:]
hxri = self.proxy_url + qxri
args = {
# XXX: If the proxy resolver will ensure that it doesn't return
# bogus CanonicalIDs (as per Steve's message of 15 Aug 2006
# 11:13:42), then we could ask for application/xrd+xml instead,
# which would give us a bit less to process.
'_xrd_r': 'application/xrds+xml',
}
if service_type:
args['_xrd_t'] = service_type
else:
# Don't perform service endpoint selection.
args['_xrd_r'] += ';sep=false'
query = _appendArgs(hxri, args)
return query
def query(self, xri, service_types):
"""Resolve some services for an XRI.
Note: I don't implement any service endpoint selection beyond what
the resolver I'm querying does, so the Services I return may well
include Services that were not of the types you asked for.
May raise fetchers.HTTPFetchingError or L{etxrd.XRDSError} if
the fetching or parsing don't go so well.
@param xri: An XRI to resolve.
@type xri: unicode
@param service_types: A list of services types to query for. Service
types are URIs.
@type service_types: list of str
@returns: tuple of (CanonicalID, Service elements)
@returntype: (unicode, list of C{ElementTree.Element}s)
"""
# FIXME: No test coverage!
services = []
# Make a seperate request to the proxy resolver for each service
# type, as, if it is following Refs, it could return a different
# XRDS for each.
canonicalID = None
for service_type in service_types:
url = self.queryURL(xri, service_type)
response = fetchers.fetch(url)
if response.status not in (200, 206):
# XXX: sucks to fail silently.
# print "response not OK:", response
continue
et = etxrd.parseXRDS(response.body)
canonicalID = etxrd.getCanonicalID(xri, et)
some_services = list(iterServices(et))
services.extend(some_services)
# TODO:
# * If we do get hits for multiple service_types, we're almost
# certainly going to have duplicated service entries and
# broken priority ordering.
return canonicalID, services
def _appendArgs(url, args):
"""Append some arguments to an HTTP query.
"""
# to be merged with oidutil.appendArgs when we combine the projects.
if hasattr(args, 'items'):
args = list(args.items())
args.sort()
if len(args) == 0:
return url
# According to XRI Resolution section "QXRI query parameters":
#
# """If the original QXRI had a null query component (only a leading
# question mark), or a query component consisting of only question
# marks, one additional leading question mark MUST be added when
# adding any XRI resolution parameters."""
if '?' in url.rstrip('?'):
sep = '&'
else:
sep = '?'
return '%s%s%s' % (url, sep, urlencode(args))