from urllib.parse import urlparse from django.urls import reverse from django.utils.http import urlencode from allauth.socialaccount.providers.base import Provider, ProviderAccount from .utils import ( AXAttribute, OldAXAttribute, SRegField, get_email_from_response, get_value_from_response, ) class OpenIDAccount(ProviderAccount): def get_brand(self): ret = super(OpenIDAccount, self).get_brand() domain = urlparse(self.account.uid).netloc # FIXME: Instead of hardcoding, derive this from the domains # listed in the openid endpoints setting. provider_map = { "yahoo": dict(id="yahoo", name="Yahoo"), "hyves": dict(id="hyves", name="Hyves"), "google": dict(id="google", name="Google"), } for d, p in provider_map.items(): if domain.lower().find(d) >= 0: ret = p break return ret def to_str(self): return self.account.uid class OpenIDProvider(Provider): id = "openid" name = "OpenID" account_class = OpenIDAccount uses_apps = False def get_login_url(self, request, **kwargs): url = reverse("openid_login") if kwargs: url += "?" + urlencode(kwargs) return url def get_brands(self): # These defaults are a bit too arbitrary... default_servers = [ dict(id="yahoo", name="Yahoo", openid_url="http://me.yahoo.com"), dict(id="hyves", name="Hyves", openid_url="http://hyves.nl"), ] return self.get_settings().get("SERVERS", default_servers) def get_server_settings(self, endpoint): servers = self.get_settings().get("SERVERS", []) for server in servers: if endpoint is not None and endpoint.startswith(server.get("openid_url")): return server return {} def extract_extra_data(self, response): extra_data = {} server_settings = self.get_server_settings(response.endpoint.server_url) extra_attributes = server_settings.get("extra_attributes", []) for attribute_id, name, _ in extra_attributes: extra_data[attribute_id] = get_value_from_response( response, ax_names=[name] ) return extra_data def extract_uid(self, response): return response.identity_url def extract_common_fields(self, response): first_name = ( get_value_from_response( response, ax_names=[ AXAttribute.PERSON_FIRST_NAME, OldAXAttribute.PERSON_FIRST_NAME, ], ) or "" ) last_name = ( get_value_from_response( response, ax_names=[ AXAttribute.PERSON_LAST_NAME, OldAXAttribute.PERSON_LAST_NAME, ], ) or "" ) name = ( get_value_from_response( response, sreg_names=[SRegField.NAME], ax_names=[AXAttribute.PERSON_NAME, OldAXAttribute.PERSON_NAME], ) or "" ) return dict( email=get_email_from_response(response), first_name=first_name, last_name=last_name, name=name, ) provider_classes = [OpenIDProvider]