This commit is contained in:
Iliyan Angelov
2025-09-14 23:24:25 +03:00
commit c67067a2a4
71311 changed files with 6800714 additions and 0 deletions

View File

@@ -0,0 +1,291 @@
from __future__ import absolute_import
from django.core.exceptions import (
ImproperlyConfigured,
MultipleObjectsReturned,
ValidationError,
)
from django.db.models import Q
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from allauth.core import context
from ..account.adapter import get_adapter as get_account_adapter
from ..account.app_settings import EmailVerificationMethod
from ..account.models import EmailAddress
from ..account.utils import user_email, user_field, user_username
from ..utils import (
deserialize_instance,
import_attribute,
serialize_instance,
valid_email_or_none,
)
from . import app_settings
class DefaultSocialAccountAdapter(object):
error_messages = {
"email_taken": _(
"An account already exists with this email address."
" Please sign in to that account first, then connect"
" your %s account."
)
}
def __init__(self, request=None):
# Explicitly passing `request` is deprecated, just use:
# `allauth.core.context.request`.
self.request = context.request
def pre_social_login(self, request, sociallogin):
"""
Invoked just after a user successfully authenticates via a
social provider, but before the login is actually processed
(and before the pre_social_login signal is emitted).
You can use this hook to intervene, e.g. abort the login by
raising an ImmediateHttpResponse
Why both an adapter hook and the signal? Intervening in
e.g. the flow from within a signal handler is bad -- multiple
handlers may be active and are executed in undetermined order.
"""
pass
def authentication_error(
self,
request,
provider_id,
error=None,
exception=None,
extra_context=None,
):
"""
Invoked when there is an error in the authentication cycle. In this
case, pre_social_login will not be reached.
You can use this hook to intervene, e.g. redirect to an
educational flow by raising an ImmediateHttpResponse.
"""
pass
def new_user(self, request, sociallogin):
"""
Instantiates a new User instance.
"""
return get_account_adapter().new_user(request)
def save_user(self, request, sociallogin, form=None):
"""
Saves a newly signed up social login. In case of auto-signup,
the signup form is not available.
"""
u = sociallogin.user
u.set_unusable_password()
if form:
get_account_adapter().save_user(request, u, form)
else:
get_account_adapter().populate_username(request, u)
sociallogin.save(request)
return u
def populate_user(self, request, sociallogin, data):
"""
Hook that can be used to further populate the user instance.
For convenience, we populate several common fields.
Note that the user instance being populated represents a
suggested User instance that represents the social user that is
in the process of being logged in.
The User instance need not be completely valid and conflict
free. For example, verifying whether or not the username
already exists, is not a responsibility.
"""
username = data.get("username")
first_name = data.get("first_name")
last_name = data.get("last_name")
email = data.get("email")
name = data.get("name")
user = sociallogin.user
user_username(user, username or "")
user_email(user, valid_email_or_none(email) or "")
name_parts = (name or "").partition(" ")
user_field(user, "first_name", first_name or name_parts[0])
user_field(user, "last_name", last_name or name_parts[2])
return user
def get_connect_redirect_url(self, request, socialaccount):
"""
Returns the default URL to redirect to after successfully
connecting a social account.
"""
url = reverse("socialaccount_connections")
return url
def validate_disconnect(self, account, accounts):
"""
Validate whether or not the socialaccount account can be
safely disconnected.
"""
if len(accounts) == 1:
# No usable password would render the local account unusable
if not account.user.has_usable_password():
raise ValidationError(_("Your account has no password set up."))
# No email address, no password reset
if app_settings.EMAIL_VERIFICATION == EmailVerificationMethod.MANDATORY:
if not EmailAddress.objects.filter(
user=account.user, verified=True
).exists():
raise ValidationError(
_("Your account has no verified email address.")
)
def is_auto_signup_allowed(self, request, sociallogin):
# If email is specified, check for duplicate and if so, no auto signup.
auto_signup = app_settings.AUTO_SIGNUP
return auto_signup
def is_open_for_signup(self, request, sociallogin):
"""
Checks whether or not the site is open for signups.
Next to simply returning True/False you can also intervene the
regular flow by raising an ImmediateHttpResponse
"""
return get_account_adapter(request).is_open_for_signup(request)
def get_signup_form_initial_data(self, sociallogin):
user = sociallogin.user
initial = {
"email": user_email(user) or "",
"username": user_username(user) or "",
"first_name": user_field(user, "first_name") or "",
"last_name": user_field(user, "last_name") or "",
}
return initial
def deserialize_instance(self, model, data):
return deserialize_instance(model, data)
def serialize_instance(self, instance):
return serialize_instance(instance)
def list_providers(self, request):
from allauth.socialaccount.providers import registry
ret = []
provider_classes = registry.get_class_list()
apps = self.list_apps(request)
apps_map = {}
for app in apps:
apps_map.setdefault(app.provider, []).append(app)
for provider_class in provider_classes:
provider_apps = apps_map.get(provider_class.id, [])
if not provider_apps:
if provider_class.uses_apps:
continue
provider_apps = [None]
for app in provider_apps:
provider = provider_class(request=request, app=app)
ret.append(provider)
return ret
def get_provider(self, request, provider):
"""Looks up a `provider`, supporting subproviders by looking up by
`provider_id`.
"""
from allauth.socialaccount.providers import registry
provider_class = registry.get_class(provider)
if provider_class is None or provider_class.uses_apps:
app = self.get_app(request, provider=provider)
if not provider_class:
# In this case, the `provider` argument passed was a
# `provider_id`.
provider_class = registry.get_class(app.provider)
if not provider_class:
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
return provider_class(request, app=app)
elif provider_class:
assert not provider_class.uses_apps
return provider_class(request, app=None)
else:
raise ImproperlyConfigured(f"unknown provider: {app.provider}")
def list_apps(self, request, provider=None, client_id=None):
"""SocialApp's can be setup in the database, or, via
`settings.SOCIALACCOUNT_PROVIDERS`. This methods returns a uniform list
of all known apps matching the specified criteria, and blends both
(db/settings) sources of data.
"""
# NOTE: Avoid loading models at top due to registry boot...
from allauth.socialaccount.models import SocialApp
# Map provider to the list of apps.
provider_to_apps = {}
# First, populate it with the DB backed apps.
db_apps = SocialApp.objects.on_site(request)
if provider:
db_apps = db_apps.filter(Q(provider=provider) | Q(provider_id=provider))
if client_id:
db_apps = db_apps.filter(client_id=client_id)
for app in db_apps:
apps = provider_to_apps.setdefault(app.provider, [])
apps.append(app)
# Then, extend it with the settings backed apps.
for p, pcfg in app_settings.PROVIDERS.items():
app_configs = pcfg.get("APPS")
if app_configs is None:
app_config = pcfg.get("APP")
if app_config is None:
continue
app_configs = [app_config]
apps = provider_to_apps.setdefault(p, [])
for config in app_configs:
app = SocialApp(provider=p)
for field in [
"name",
"provider_id",
"client_id",
"secret",
"key",
"certificate_key",
"settings",
]:
if field in config:
setattr(app, field, config[field])
if client_id and app.client_id != client_id:
continue
if (
provider
and app.provider_id != provider
and app.provider != provider
):
continue
apps.append(app)
# Flatten the list of apps.
apps = []
for provider_apps in provider_to_apps.values():
apps.extend(provider_apps)
return apps
def get_app(self, request, provider, client_id=None):
from allauth.socialaccount.models import SocialApp
apps = self.list_apps(request, provider=provider, client_id=client_id)
if len(apps) > 1:
raise MultipleObjectsReturned
elif len(apps) == 0:
raise SocialApp.DoesNotExist()
return apps[0]
def get_adapter(request=None):
return import_attribute(app_settings.ADAPTER)(request)

View File

@@ -0,0 +1,61 @@
from django import forms
from django.contrib import admin
from allauth import app_settings
from allauth.account.adapter import get_adapter
from .models import SocialAccount, SocialApp, SocialToken
class SocialAppForm(forms.ModelForm):
class Meta:
model = SocialApp
exclude = []
widgets = {
"client_id": forms.TextInput(attrs={"size": "100"}),
"key": forms.TextInput(attrs={"size": "100"}),
"secret": forms.TextInput(attrs={"size": "100"}),
}
class SocialAppAdmin(admin.ModelAdmin):
form = SocialAppForm
list_display = (
"name",
"provider",
)
filter_horizontal = ("sites",) if app_settings.SITES_ENABLED else ()
class SocialAccountAdmin(admin.ModelAdmin):
search_fields = []
raw_id_fields = ("user",)
list_display = ("user", "uid", "provider")
list_filter = ("provider",)
def get_search_fields(self, request):
base_fields = get_adapter().get_user_search_fields()
return list(map(lambda a: "user__" + a, base_fields))
class SocialTokenAdmin(admin.ModelAdmin):
raw_id_fields = (
"app",
"account",
)
list_display = ("app", "account", "truncated_token", "expires_at")
list_filter = ("app", "app__provider", "expires_at")
def truncated_token(self, token):
max_chars = 40
ret = token.token
if len(ret) > max_chars:
ret = ret[0:max_chars] + "...(truncated)"
return ret
truncated_token.short_description = "Token"
admin.site.register(SocialApp, SocialAppAdmin)
admin.site.register(SocialToken, SocialTokenAdmin)
admin.site.register(SocialAccount, SocialAccountAdmin)

View File

@@ -0,0 +1,148 @@
class AppSettings(object):
def __init__(self, prefix):
self.prefix = prefix
def _setting(self, name, dflt):
from allauth.utils import get_setting
return get_setting(self.prefix + name, dflt)
@property
def QUERY_EMAIL(self):
"""
Request email address from 3rd party account provider?
E.g. using OpenID AX
"""
from allauth.account import app_settings as account_settings
return self._setting("QUERY_EMAIL", account_settings.EMAIL_REQUIRED)
@property
def AUTO_SIGNUP(self):
"""
Attempt to bypass the signup form by using fields (e.g. username,
email) retrieved from the social account provider. If a conflict
arises due to a duplicate email signup form will still kick in.
"""
return self._setting("AUTO_SIGNUP", True)
@property
def PROVIDERS(self):
"""
Provider specific settings
"""
ret = self._setting("PROVIDERS", {})
oidc = ret.get("openid_connect")
if oidc:
ret["openid_connect"] = self._migrate_oidc(oidc)
return ret
def _migrate_oidc(self, oidc):
servers = oidc.get("SERVERS")
if servers is None:
return oidc
ret = {}
apps = []
for server in servers:
app = dict(**server["APP"])
app_settings = {}
if "token_auth_method" in server:
app_settings["token_auth_method"] = server["token_auth_method"]
app_settings["server_url"] = server["server_url"]
app.update(
{
"name": server.get("name", ""),
"provider_id": server["id"],
"settings": app_settings,
}
)
assert app["provider_id"]
apps.append(app)
ret["APPS"] = apps
return ret
@property
def EMAIL_REQUIRED(self):
"""
The user is required to hand over an email address when signing up
"""
from allauth.account import app_settings as account_settings
return self._setting("EMAIL_REQUIRED", account_settings.EMAIL_REQUIRED)
@property
def EMAIL_VERIFICATION(self):
"""
See email verification method
"""
from allauth.account import app_settings as account_settings
return self._setting("EMAIL_VERIFICATION", account_settings.EMAIL_VERIFICATION)
@property
def EMAIL_AUTHENTICATION(self):
"""Consider a scenario where a social login occurs, and the social
account comes with a verified email address (verified by the account
provider), but that email address is already taken by a local user
account. Additionally, assume that the local user account does not have
any social account connected. Now, if the provider can be fully trusted,
you can argue that we should treat this scenario as a login to the
existing local user account even if the local account does not already
have the social account connected, because -- according to the provider
-- the user logging in has ownership of the email address. This is how
this scenario is handled when `EMAIL_AUTHENTICATION` is set to
`True`. As this implies that an untrustworthy provider can login to any
local account by fabricating social account data, this setting defaults
to `False`. Only set it to `True` if you are using providers that can be
fully trusted.
"""
return self._setting("EMAIL_AUTHENTICATION", False)
@property
def EMAIL_AUTHENTICATION_AUTO_CONNECT(self):
"""In case email authentication is applied, this setting controls
whether or not the social account is automatically connected to the
local account. In case of ``False`` (the default) the local account
remains unchanged during the login. In case of ``True``, the social
account for which the email matched, is automatically added to the list
of social accounts connected to the local account. As a result, even if
the user were to change the email address afterwards, social login
would still be possible when using ``True``, but not in case of
``False``.
"""
return self._setting("EMAIL_AUTHENTICATION_AUTO_CONNECT", False)
@property
def ADAPTER(self):
return self._setting(
"ADAPTER",
"allauth.socialaccount.adapter.DefaultSocialAccountAdapter",
)
@property
def FORMS(self):
return self._setting("FORMS", {})
@property
def LOGIN_ON_GET(self):
return self._setting("LOGIN_ON_GET", False)
@property
def STORE_TOKENS(self):
return self._setting("STORE_TOKENS", False)
@property
def UID_MAX_LENGTH(self):
return 191
@property
def SOCIALACCOUNT_STR(self):
return self._setting("SOCIALACCOUNT_STR", None)
_app_settings = AppSettings("SOCIALACCOUNT_")
def __getattr__(name):
# See https://peps.python.org/pep-0562/
return getattr(_app_settings, name)

View File

@@ -0,0 +1,8 @@
from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
class SocialAccountConfig(AppConfig):
name = "allauth.socialaccount"
verbose_name = _("Social Accounts")
default_auto_field = "django.db.models.AutoField"

View File

@@ -0,0 +1,25 @@
import pytest
from allauth.account.models import EmailAddress
from allauth.socialaccount.models import SocialAccount, SocialLogin
@pytest.fixture
def sociallogin_factory(user_factory):
def factory(
email=None,
with_email=True,
provider="unittest-server",
uid="123",
email_verified=True,
):
user = user_factory(email=email, commit=False, with_email=with_email)
account = SocialAccount(provider=provider, uid=uid)
sociallogin = SocialLogin(user=user, account=account)
if with_email:
sociallogin.email_addresses = [
EmailAddress(email=user.email, verified=email_verified, primary=True)
]
return sociallogin
return factory

View File

@@ -0,0 +1,51 @@
# Courtesy of django-social-auth
import json
from django.core.exceptions import ValidationError
from django.db import models
class JSONField(models.TextField):
"""Simple JSON field that stores python structures as JSON strings
on database.
"""
def from_db_value(self, value, *args, **kwargs):
return self.to_python(value)
def to_python(self, value):
"""
Convert the input JSON value into python structures, raises
django.core.exceptions.ValidationError if the data can't be converted.
"""
if self.blank and not value:
return None
if isinstance(value, str):
try:
return json.loads(value)
except Exception as e:
raise ValidationError(str(e))
else:
return value
def validate(self, value, model_instance):
"""Check value is a valid JSON string, raise ValidationError on
error."""
if isinstance(value, str):
super(JSONField, self).validate(value, model_instance)
try:
json.loads(value)
except Exception as e:
raise ValidationError(str(e))
def get_prep_value(self, value):
"""Convert value to JSON string before save"""
try:
return json.dumps(value)
except Exception as e:
raise ValidationError(str(e))
def value_from_object(self, obj):
"""Return value dumped to string."""
val = super(JSONField, self).value_from_object(obj)
return self.get_prep_value(val)

View File

@@ -0,0 +1,67 @@
from __future__ import absolute_import
from django import forms
from allauth.account.forms import BaseSignupForm
from . import app_settings, signals
from .adapter import get_adapter
from .models import SocialAccount
class SignupForm(BaseSignupForm):
def __init__(self, *args, **kwargs):
self.sociallogin = kwargs.pop("sociallogin")
initial = get_adapter().get_signup_form_initial_data(self.sociallogin)
kwargs.update(
{
"initial": initial,
"email_required": kwargs.get(
"email_required", app_settings.EMAIL_REQUIRED
),
}
)
super(SignupForm, self).__init__(*args, **kwargs)
def save(self, request):
adapter = get_adapter()
user = adapter.save_user(request, self.sociallogin, form=self)
self.custom_signup(request, user)
return user
def validate_unique_email(self, value):
try:
return super(SignupForm, self).validate_unique_email(value)
except forms.ValidationError:
raise forms.ValidationError(
get_adapter().error_messages["email_taken"]
% self.sociallogin.account.get_provider().name
)
class DisconnectForm(forms.Form):
account = forms.ModelChoiceField(
queryset=SocialAccount.objects.none(),
widget=forms.RadioSelect,
required=True,
)
def __init__(self, *args, **kwargs):
self.request = kwargs.pop("request")
self.accounts = SocialAccount.objects.filter(user=self.request.user)
super(DisconnectForm, self).__init__(*args, **kwargs)
self.fields["account"].queryset = self.accounts
def clean(self):
cleaned_data = super(DisconnectForm, self).clean()
account = cleaned_data.get("account")
if account:
get_adapter(self.request).validate_disconnect(account, self.accounts)
return cleaned_data
def save(self):
account = self.cleaned_data["account"]
account.delete()
signals.social_account_removed.send(
sender=SocialAccount, request=self.request, socialaccount=account
)

View File

@@ -0,0 +1,232 @@
from django.contrib import messages
from django.forms import ValidationError
from django.http import HttpResponseRedirect
from django.shortcuts import render
from django.urls import reverse
from allauth.account import app_settings as account_settings
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.utils import (
assess_unique_email,
complete_signup,
perform_login,
user_display,
user_email,
user_username,
)
from allauth.core.exceptions import ImmediateHttpResponse
from . import app_settings, signals
from .adapter import get_adapter
from .models import SocialLogin
from .providers.base import AuthError, AuthProcess
def _process_auto_signup(request, sociallogin):
auto_signup = get_adapter().is_auto_signup_allowed(request, sociallogin)
if not auto_signup:
return False, None
email = user_email(sociallogin.user)
# Let's check if auto_signup is really possible...
if email:
assessment = assess_unique_email(email)
if assessment is True:
# Auto signup is fine.
pass
elif assessment is False:
# Oops, another user already has this address. We cannot simply
# connect this social account to the existing user. Reason is
# that the email address may not be verified, meaning, the user
# may be a hacker that has added your email address to their
# account in the hope that you fall in their trap. We cannot
# check on 'email_address.verified' either, because
# 'email_address' is not guaranteed to be verified.
auto_signup = False
# TODO: We redirect to signup form -- user will see email
# address conflict only after posting whereas we detected it
# here already.
else:
assert assessment is None
# Prevent enumeration is properly turned on, meaning, we cannot
# show the signup form to allow the user to input another email
# address. Instead, we're going to send the user an email that
# the account already exists, and on the outside make it appear
# as if an email verification mail was sent.
account_adapter = get_account_adapter(request)
account_adapter.send_account_already_exists_mail(email)
resp = account_adapter.respond_email_verification_sent(request, None)
return False, resp
elif app_settings.EMAIL_REQUIRED:
# Nope, email is required and we don't have it yet...
auto_signup = False
return auto_signup, None
def _process_signup(request, sociallogin):
auto_signup, resp = _process_auto_signup(request, sociallogin)
if resp:
return resp
if not auto_signup:
request.session["socialaccount_sociallogin"] = sociallogin.serialize()
url = reverse("socialaccount_signup")
resp = HttpResponseRedirect(url)
else:
# Ok, auto signup it is, at least the email address is ok.
# We still need to check the username though...
if account_settings.USER_MODEL_USERNAME_FIELD:
username = user_username(sociallogin.user)
try:
get_account_adapter(request).clean_username(username)
except ValidationError:
# This username is no good ...
user_username(sociallogin.user, "")
# TODO: This part contains a lot of duplication of logic
# ("closed" rendering, create user, send email, in active
# etc..)
if not get_adapter().is_open_for_signup(request, sociallogin):
return render(
request,
"account/signup_closed." + account_settings.TEMPLATE_EXTENSION,
)
get_adapter().save_user(request, sociallogin, form=None)
resp = complete_social_signup(request, sociallogin)
return resp
def _login_social_account(request, sociallogin):
return perform_login(
request,
sociallogin.user,
email_verification=app_settings.EMAIL_VERIFICATION,
redirect_url=sociallogin.get_redirect_url(request),
signal_kwargs={"sociallogin": sociallogin},
)
def render_authentication_error(
request,
provider_id,
error=AuthError.UNKNOWN,
exception=None,
extra_context=None,
):
try:
if extra_context is None:
extra_context = {}
get_adapter().authentication_error(
request,
provider_id,
error=error,
exception=exception,
extra_context=extra_context,
)
except ImmediateHttpResponse as e:
return e.response
if error == AuthError.CANCELLED:
return HttpResponseRedirect(reverse("socialaccount_login_cancelled"))
context = {
"auth_error": {
"provider": provider_id,
"code": error,
"exception": exception,
}
}
context.update(extra_context)
return render(
request,
"socialaccount/authentication_error." + account_settings.TEMPLATE_EXTENSION,
context,
)
def _add_social_account(request, sociallogin):
if request.user.is_anonymous:
# This should not happen. Simply redirect to the connections
# view (which has a login required)
connect_redirect_url = get_adapter().get_connect_redirect_url(
request, sociallogin.account
)
return HttpResponseRedirect(connect_redirect_url)
level = messages.INFO
message = "socialaccount/messages/account_connected.txt"
action = None
if sociallogin.is_existing:
if sociallogin.user != request.user:
# Social account of other user. For now, this scenario
# is not supported. Issue is that one cannot simply
# remove the social account from the other user, as
# that may render the account unusable.
level = messages.ERROR
message = "socialaccount/messages/account_connected_other.txt"
else:
# This account is already connected -- we give the opportunity
# for customized behaviour through use of a signal.
action = "updated"
message = "socialaccount/messages/account_connected_updated.txt"
else:
# New account, let's connect
action = "added"
sociallogin.connect(request, request.user)
assert request.user.is_authenticated
default_next = get_adapter().get_connect_redirect_url(request, sociallogin.account)
next_url = sociallogin.get_redirect_url(request) or default_next
get_account_adapter(request).add_message(
request,
level,
message,
message_context={"sociallogin": sociallogin, "action": action},
)
return HttpResponseRedirect(next_url)
def complete_social_login(request, sociallogin):
assert not sociallogin.is_existing
sociallogin.lookup()
try:
get_adapter().pre_social_login(request, sociallogin)
signals.pre_social_login.send(
sender=SocialLogin, request=request, sociallogin=sociallogin
)
process = sociallogin.state.get("process")
if process == AuthProcess.REDIRECT:
return _social_login_redirect(request, sociallogin)
elif process == AuthProcess.CONNECT:
return _add_social_account(request, sociallogin)
else:
return _complete_social_login(request, sociallogin)
except ImmediateHttpResponse as e:
return e.response
def _social_login_redirect(request, sociallogin):
next_url = sociallogin.get_redirect_url(request) or "/"
return HttpResponseRedirect(next_url)
def _complete_social_login(request, sociallogin):
if request.user.is_authenticated:
get_account_adapter(request).logout(request)
if sociallogin.is_existing:
# Login existing user
ret = _login_social_account(request, sociallogin)
else:
# New social user
ret = _process_signup(request, sociallogin)
return ret
def complete_social_signup(request, sociallogin):
return complete_signup(
request,
sociallogin.user,
app_settings.EMAIL_VERIFICATION,
sociallogin.get_redirect_url(request),
signal_kwargs={"sociallogin": sociallogin},
)
def socialaccount_user_display(socialaccount):
func = app_settings.SOCIALACCOUNT_STR
if not func:
return user_display(socialaccount.user)
return func(socialaccount)

View File

@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.conf import settings
from django.db import migrations, models
import allauth.socialaccount.fields
from allauth import app_settings
from allauth.socialaccount.providers import registry
class Migration(migrations.Migration):
dependencies = (
[
("sites", "0001_initial"),
]
if app_settings.SITES_ENABLED
else []
+ [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
)
operations = [
migrations.CreateModel(
name="SocialAccount",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"provider",
models.CharField(
max_length=30,
verbose_name="provider",
choices=registry.as_choices(),
),
),
(
"uid",
models.CharField(
max_length=getattr(
settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191
),
verbose_name="uid",
),
),
(
"last_login",
models.DateTimeField(auto_now=True, verbose_name="last login"),
),
(
"date_joined",
models.DateTimeField(auto_now_add=True, verbose_name="date joined"),
),
(
"extra_data",
allauth.socialaccount.fields.JSONField(
default="{}", verbose_name="extra data"
),
),
(
"user",
models.ForeignKey(
to=settings.AUTH_USER_MODEL, on_delete=models.CASCADE
),
),
],
options={
"verbose_name": "social account",
"verbose_name_plural": "social accounts",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="SocialApp",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"provider",
models.CharField(
max_length=30,
verbose_name="provider",
choices=registry.as_choices(),
),
),
("name", models.CharField(max_length=40, verbose_name="name")),
(
"client_id",
models.CharField(
help_text="App ID, or consumer key",
max_length=100,
verbose_name="client id",
),
),
(
"secret",
models.CharField(
help_text="API secret, client secret, or consumer secret",
max_length=100,
verbose_name="secret key",
),
),
(
"key",
models.CharField(
help_text="Key",
max_length=100,
verbose_name="key",
blank=True,
),
),
]
+ (
[
("sites", models.ManyToManyField(to="sites.Site", blank=True)),
]
if app_settings.SITES_ENABLED
else []
),
options={
"verbose_name": "social application",
"verbose_name_plural": "social applications",
},
bases=(models.Model,),
),
migrations.CreateModel(
name="SocialToken",
fields=[
(
"id",
models.AutoField(
verbose_name="ID",
serialize=False,
auto_created=True,
primary_key=True,
),
),
(
"token",
models.TextField(
help_text='"oauth_token" (OAuth1) or access token (OAuth2)',
verbose_name="token",
),
),
(
"token_secret",
models.TextField(
help_text='"oauth_token_secret" (OAuth1) or refresh token (OAuth2)',
verbose_name="token secret",
blank=True,
),
),
(
"expires_at",
models.DateTimeField(
null=True, verbose_name="expires at", blank=True
),
),
(
"account",
models.ForeignKey(
to="socialaccount.SocialAccount",
on_delete=models.CASCADE,
),
),
(
"app",
models.ForeignKey(
to="socialaccount.SocialApp", on_delete=models.CASCADE
),
),
],
options={
"verbose_name": "social application token",
"verbose_name_plural": "social application tokens",
},
bases=(models.Model,),
),
migrations.AlterUniqueTogether(
name="socialtoken",
unique_together=set([("app", "account")]),
),
migrations.AlterUniqueTogether(
name="socialaccount",
unique_together=set([("provider", "uid")]),
),
]

View File

@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0001_initial"),
]
operations = [
migrations.AlterField(
model_name="socialaccount",
name="uid",
field=models.CharField(
max_length=getattr(settings, "SOCIALACCOUNT_UID_MAX_LENGTH", 191),
verbose_name="uid",
),
),
migrations.AlterField(
model_name="socialapp",
name="client_id",
field=models.CharField(
help_text="App ID, or consumer key",
max_length=191,
verbose_name="client id",
),
),
migrations.AlterField(
model_name="socialapp",
name="key",
field=models.CharField(
help_text="Key", max_length=191, verbose_name="key", blank=True
),
),
migrations.AlterField(
model_name="socialapp",
name="secret",
field=models.CharField(
help_text="API secret, client secret, or consumer secret",
max_length=191,
verbose_name="secret key",
blank=True,
),
),
]

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import migrations
import allauth.socialaccount.fields
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0002_token_max_lengths"),
]
operations = [
migrations.AlterField(
model_name="socialaccount",
name="extra_data",
field=allauth.socialaccount.fields.JSONField(
default=dict, verbose_name="extra data"
),
preserve_default=True,
),
]

View File

@@ -0,0 +1,29 @@
# Generated by Django 3.2.19 on 2023-06-30 13:16
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0003_extra_data_default_dict"),
]
operations = [
migrations.AddField(
model_name="socialapp",
name="provider_id",
field=models.CharField(
blank=True, max_length=200, verbose_name="provider ID"
),
),
migrations.AddField(
model_name="socialapp",
name="settings",
field=models.JSONField(blank=True, default=dict),
),
migrations.AlterField(
model_name="socialaccount",
name="provider",
field=models.CharField(max_length=200, verbose_name="provider"),
),
]

View File

@@ -0,0 +1,23 @@
# Generated by Django 3.2.20 on 2023-09-03 19:46
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("socialaccount", "0004_app_provider_id_settings"),
]
operations = [
migrations.AlterField(
model_name="socialtoken",
name="app",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="socialaccount.socialapp",
),
),
]

View File

@@ -0,0 +1,378 @@
from __future__ import absolute_import
from django.contrib.auth import authenticate, get_user_model
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import PermissionDenied
from django.db import models
from django.utils.crypto import get_random_string
from django.utils.translation import gettext_lazy as _
import allauth.app_settings
from allauth.account.models import EmailAddress
from allauth.account.utils import get_next_redirect_url, setup_user_email
from allauth.core import context
from allauth.socialaccount import signals
from ..utils import get_request_param
from . import app_settings, providers
from .adapter import get_adapter
from .fields import JSONField
class SocialAppManager(models.Manager):
def on_site(self, request):
if allauth.app_settings.SITES_ENABLED:
site = get_current_site(request)
return self.filter(sites__id=site.id)
return self.all()
class SocialApp(models.Model):
objects = SocialAppManager()
# The provider type, e.g. "google", "telegram", "saml".
provider = models.CharField(
verbose_name=_("provider"),
max_length=30,
choices=providers.registry.as_choices(),
)
# For providers that support subproviders, such as OpenID Connect and SAML,
# this ID identifies that instance. SocialAccount's originating from app
# will have their `provider` field set to the `provider_id` if available,
# else `provider`.
provider_id = models.CharField(
verbose_name=_("provider ID"),
max_length=200,
blank=True,
)
name = models.CharField(verbose_name=_("name"), max_length=40)
client_id = models.CharField(
verbose_name=_("client id"),
max_length=191,
help_text=_("App ID, or consumer key"),
)
secret = models.CharField(
verbose_name=_("secret key"),
max_length=191,
blank=True,
help_text=_("API secret, client secret, or consumer secret"),
)
key = models.CharField(
verbose_name=_("key"), max_length=191, blank=True, help_text=_("Key")
)
settings = models.JSONField(default=dict, blank=True)
if allauth.app_settings.SITES_ENABLED:
# Most apps can be used across multiple domains, therefore we use
# a ManyToManyField. Note that Facebook requires an app per domain
# (unless the domains share a common base name).
# blank=True allows for disabling apps without removing them
sites = models.ManyToManyField("sites.Site", blank=True)
# We want to move away from storing secrets in the database. So, we're
# putting a halt towards adding more fields for additional secrets, such as
# the certificate some providers need. Therefore, the certificate is not a
# DB backed field and can only be set using the ``APP`` configuration key
# in the provider settings.
certificate_key = None
class Meta:
verbose_name = _("social application")
verbose_name_plural = _("social applications")
def __str__(self):
return self.name
def get_provider(self, request):
provider_class = providers.registry.get_class(self.provider)
return provider_class(request=request, app=self)
class SocialAccount(models.Model):
user = models.ForeignKey(allauth.app_settings.USER_MODEL, on_delete=models.CASCADE)
# Given a `SocialApp` from which this account originates, this field equals
# the app's `app.provider_id` if available, `app.provider` otherwise.
provider = models.CharField(
verbose_name=_("provider"),
max_length=200,
)
# Just in case you're wondering if an OpenID identity URL is going
# to fit in a 'uid':
#
# Ideally, URLField(max_length=1024, unique=True) would be used
# for identity. However, MySQL has a max_length limitation of 191
# for URLField (in case of utf8mb4). How about
# models.TextField(unique=True) then? Well, that won't work
# either for MySQL due to another bug[1]. So the only way out
# would be to drop the unique constraint, or switch to shorter
# identity URLs. Opted for the latter, as [2] suggests that
# identity URLs are supposed to be short anyway, at least for the
# old spec.
#
# [1] http://code.djangoproject.com/ticket/2495.
# [2] http://openid.net/specs/openid-authentication-1_1.html#limits
uid = models.CharField(
verbose_name=_("uid"), max_length=app_settings.UID_MAX_LENGTH
)
last_login = models.DateTimeField(verbose_name=_("last login"), auto_now=True)
date_joined = models.DateTimeField(verbose_name=_("date joined"), auto_now_add=True)
extra_data = JSONField(verbose_name=_("extra data"), default=dict)
class Meta:
unique_together = ("provider", "uid")
verbose_name = _("social account")
verbose_name_plural = _("social accounts")
def authenticate(self):
return authenticate(account=self)
def __str__(self):
from .helpers import socialaccount_user_display
return socialaccount_user_display(self)
def get_profile_url(self):
return self.get_provider_account().get_profile_url()
def get_avatar_url(self):
return self.get_provider_account().get_avatar_url()
def get_provider(self, request=None):
provider = getattr(self, "_provider", None)
if provider:
return provider
adapter = get_adapter()
provider = self._provider = adapter.get_provider(
request, provider=self.provider
)
return provider
def get_provider_account(self):
return self.get_provider().wrap_account(self)
class SocialToken(models.Model):
app = models.ForeignKey(SocialApp, on_delete=models.SET_NULL, blank=True, null=True)
account = models.ForeignKey(SocialAccount, on_delete=models.CASCADE)
token = models.TextField(
verbose_name=_("token"),
help_text=_('"oauth_token" (OAuth1) or access token (OAuth2)'),
)
token_secret = models.TextField(
blank=True,
verbose_name=_("token secret"),
help_text=_('"oauth_token_secret" (OAuth1) or refresh token (OAuth2)'),
)
expires_at = models.DateTimeField(
blank=True, null=True, verbose_name=_("expires at")
)
class Meta:
unique_together = ("app", "account")
verbose_name = _("social application token")
verbose_name_plural = _("social application tokens")
def __str__(self):
return self.token
class SocialLogin(object):
"""
Represents a social user that is in the process of being logged
in. This consists of the following information:
`account` (`SocialAccount` instance): The social account being
logged in. Providers are not responsible for checking whether or
not an account already exists or not. Therefore, a provider
typically creates a new (unsaved) `SocialAccount` instance. The
`User` instance pointed to by the account (`account.user`) may be
prefilled by the provider for use as a starting point later on
during the signup process.
`token` (`SocialToken` instance): An optional access token token
that results from performing a successful authentication
handshake.
`state` (`dict`): The state to be preserved during the
authentication handshake. Note that this state may end up in the
url -- do not put any secrets in here. It currently only contains
the url to redirect to after login.
`email_addresses` (list of `EmailAddress`): Optional list of
email addresses retrieved from the provider.
"""
def __init__(self, user=None, account=None, token=None, email_addresses=[]):
if token:
assert token.account is None or token.account == account
self.token = token
self.user = user
self.account = account
self.email_addresses = email_addresses
self.state = {}
def connect(self, request, user):
self.user = user
self.save(request, connect=True)
signals.social_account_added.send(
sender=SocialLogin, request=request, sociallogin=self
)
def serialize(self):
serialize_instance = get_adapter().serialize_instance
ret = dict(
account=serialize_instance(self.account),
user=serialize_instance(self.user),
state=self.state,
email_addresses=[serialize_instance(ea) for ea in self.email_addresses],
)
if self.token:
ret["token"] = serialize_instance(self.token)
return ret
@classmethod
def deserialize(cls, data):
deserialize_instance = get_adapter().deserialize_instance
account = deserialize_instance(SocialAccount, data["account"])
user = deserialize_instance(get_user_model(), data["user"])
if "token" in data:
token = deserialize_instance(SocialToken, data["token"])
else:
token = None
email_addresses = []
for ea in data["email_addresses"]:
email_address = deserialize_instance(EmailAddress, ea)
email_addresses.append(email_address)
ret = cls()
ret.token = token
ret.account = account
ret.user = user
ret.email_addresses = email_addresses
ret.state = data["state"]
return ret
def save(self, request, connect=False):
"""
Saves a new account. Note that while the account is new,
the user may be an existing one (when connecting accounts)
"""
user = self.user
user.save()
self.account.user = user
self.account.save()
if app_settings.STORE_TOKENS and self.token:
self.token.account = self.account
self.token.save()
if connect:
# TODO: Add any new email addresses automatically?
pass
else:
setup_user_email(request, user, self.email_addresses)
@property
def is_existing(self):
"""When `False`, this social login represents a temporary account, not
yet backed by a database record.
"""
if self.user.pk is None:
return False
return get_user_model().objects.filter(pk=self.user.pk).exists()
def lookup(self):
"""Look up the existing local user account to which this social login
points, if any.
"""
if not self._lookup_by_socialaccount():
provider_id = self.account.get_provider().id
if app_settings.EMAIL_AUTHENTICATION or app_settings.PROVIDERS.get(
provider_id, {}
).get("EMAIL_AUTHENTICATION", False):
self._lookup_by_email()
def _lookup_by_socialaccount(self):
assert not self.is_existing
try:
a = SocialAccount.objects.get(
provider=self.account.provider, uid=self.account.uid
)
# Update account
a.extra_data = self.account.extra_data
self.account = a
self.user = self.account.user
a.save()
signals.social_account_updated.send(
sender=SocialLogin, request=context.request, sociallogin=self
)
# Update token
if app_settings.STORE_TOKENS and self.token:
assert not self.token.pk
try:
t = SocialToken.objects.get(
account=self.account, app=self.token.app
)
t.token = self.token.token
if self.token.token_secret:
# only update the refresh token if we got one
# many oauth2 providers do not resend the refresh token
t.token_secret = self.token.token_secret
t.expires_at = self.token.expires_at
t.save()
self.token = t
except SocialToken.DoesNotExist:
self.token.account = a
self.token.save()
return True
except SocialAccount.DoesNotExist:
pass
def _lookup_by_email(self):
emails = [e.email for e in self.email_addresses if e.verified]
if not emails:
return
address = (
EmailAddress.objects.lookup(emails).order_by("-verified", "user_id").first()
)
if address:
if app_settings.EMAIL_AUTHENTICATION_AUTO_CONNECT:
self.connect(context.request, address.user)
else:
self.user = address.user
def get_redirect_url(self, request):
url = self.state.get("next")
return url
@classmethod
def state_from_request(cls, request):
state = {}
next_url = get_next_redirect_url(request)
if next_url:
state["next"] = next_url
state["process"] = get_request_param(request, "process", "login")
state["scope"] = get_request_param(request, "scope", "")
state["auth_params"] = get_request_param(request, "auth_params", "")
return state
@classmethod
def stash_state(cls, request):
state = cls.state_from_request(request)
verifier = get_random_string(16)
request.session["socialaccount_state"] = (state, verifier)
return verifier
@classmethod
def unstash_state(cls, request):
if "socialaccount_state" not in request.session:
raise PermissionDenied()
state, verifier = request.session.pop("socialaccount_state")
return state
@classmethod
def verify_and_unstash_state(cls, request, verifier):
if "socialaccount_state" not in request.session:
raise PermissionDenied()
state, verifier2 = request.session.pop("socialaccount_state")
if verifier != verifier2:
raise PermissionDenied()
return state

View File

@@ -0,0 +1,57 @@
import importlib
from collections import OrderedDict
from django.apps import apps
from django.conf import settings
from allauth.utils import import_attribute
class ProviderRegistry(object):
def __init__(self):
self.provider_map = OrderedDict()
self.loaded = False
def get_class_list(self):
self.load()
return list(self.provider_map.values())
def register(self, cls):
self.provider_map[cls.id] = cls
def get_class(self, id):
return self.provider_map.get(id)
def as_choices(self):
self.load()
for provider_cls in self.provider_map.values():
yield (provider_cls.id, provider_cls.name)
def load(self):
# TODO: Providers register with the provider registry when
# loaded. Here, we build the URLs for all registered providers. So, we
# really need to be sure all providers did register, which is why we're
# forcefully importing the `provider` modules here. The overall
# mechanism is way to magical and depends on the import order et al, so
# all of this really needs to be revisited.
if not self.loaded:
for app_config in apps.get_app_configs():
try:
provider_module = importlib.import_module(
app_config.name + ".provider"
)
except ImportError:
pass
else:
provider_settings = getattr(settings, "SOCIALACCOUNT_PROVIDERS", {})
for cls in getattr(provider_module, "provider_classes", []):
provider_class = provider_settings.get(cls.id, {}).get(
"provider_class"
)
if provider_class:
cls = import_attribute(provider_class)
self.register(cls)
self.loaded = True
registry = ProviderRegistry()

View File

@@ -0,0 +1,39 @@
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AgaveAccount(ProviderAccount):
def get_profile_url(self):
return self.account.extra_data.get("web_url", "dflt")
def get_avatar_url(self):
return self.account.extra_data.get("avatar_url", "dflt")
def to_str(self):
dflt = super(AgaveAccount, self).to_str()
return self.account.extra_data.get("name", dflt)
class AgaveProvider(OAuth2Provider):
id = "agave"
name = "Agave"
account_class = AgaveAccount
def extract_uid(self, data):
return str(data.get("create_time"))
def extract_common_fields(self, data):
return dict(
email=data.get("email"),
username=data.get("username", ""),
name=(
(data.get("first_name", "") + " " + data.get("last_name", "")).strip()
),
)
def get_default_scope(self):
scope = ["PRODUCTION"]
return scope
provider_classes = [AgaveProvider]

View File

@@ -0,0 +1,31 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AgaveProvider
class AgaveTests(OAuth2TestsMixin, TestCase):
provider_id = AgaveProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{
"status": "success",
"message": "User details retrieved successfully.",
"version": "2.0.0-SNAPSHOT-rc3fad",
"result": {
"first_name": "John",
"last_name": "Doe",
"full_name": "John Doe",
"email": "jon@doe.edu",
"phone": "",
"mobile_phone": "",
"status": "Active",
"create_time": "20180322043812Z",
"username": "jdoe"
}
}
""",
)

View File

@@ -0,0 +1,5 @@
from allauth.socialaccount.providers.agave.provider import AgaveProvider
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
urlpatterns = default_urlpatterns(AgaveProvider)

View File

@@ -0,0 +1,39 @@
import requests
from allauth.socialaccount import app_settings
from allauth.socialaccount.providers.agave.provider import AgaveProvider
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AgaveAdapter(OAuth2Adapter):
provider_id = AgaveProvider.id
settings = app_settings.PROVIDERS.get(provider_id, {})
provider_base_url = settings.get("API_URL", "https://public.agaveapi.co")
access_token_url = "{0}/token".format(provider_base_url)
authorize_url = "{0}/authorize".format(provider_base_url)
profile_url = "{0}/profiles/v2/me".format(provider_base_url)
def complete_login(self, request, app, token, response):
extra_data = requests.get(
self.profile_url,
params={"access_token": token.token},
headers={
"Authorization": "Bearer " + token.token,
},
)
user_profile = (
extra_data.json()["result"] if "result" in extra_data.json() else {}
)
return self.get_provider().sociallogin_from_response(request, user_profile)
oauth2_login = OAuth2LoginView.adapter_view(AgaveAdapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AgaveAdapter)

View File

@@ -0,0 +1,33 @@
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AmazonAccount(ProviderAccount):
def to_str(self):
return self.account.extra_data.get("name", super(AmazonAccount, self).to_str())
class AmazonProvider(OAuth2Provider):
id = "amazon"
name = "Amazon"
account_class = AmazonAccount
def get_default_scope(self):
return ["profile"]
def extract_uid(self, data):
return str(data["user_id"])
def extract_common_fields(self, data):
# Hackish way of splitting the fullname.
# Assumes no middlenames.
name = data.get("name", "")
first_name, last_name = name, ""
if name and " " in name:
first_name, last_name = name.split(" ", 1)
return dict(
email=data.get("email", ""), last_name=last_name, first_name=first_name
)
provider_classes = [AmazonProvider]

View File

@@ -0,0 +1,21 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AmazonProvider
class AmazonTests(OAuth2TestsMixin, TestCase):
provider_id = AmazonProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{
"Profile":{
"CustomerId":"amzn1.account.K2LI23KL2LK2",
"Name":"John Doe",
"PrimaryEmail":"johndoe@example.com"
}
}""",
)

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AmazonProvider
urlpatterns = default_urlpatterns(AmazonProvider)

View File

@@ -0,0 +1,32 @@
import requests
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
from .provider import AmazonProvider
class AmazonOAuth2Adapter(OAuth2Adapter):
provider_id = AmazonProvider.id
access_token_url = "https://api.amazon.com/auth/o2/token"
authorize_url = "http://www.amazon.com/ap/oa"
profile_url = "https://api.amazon.com/user/profile"
supports_state = False
def complete_login(self, request, app, token, **kwargs):
response = requests.get(self.profile_url, params={"access_token": token})
extra_data = response.json()
if "Profile" in extra_data:
extra_data = {
"user_id": extra_data["Profile"]["CustomerId"],
"name": extra_data["Profile"]["Name"],
"email": extra_data["Profile"]["PrimaryEmail"],
}
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(AmazonOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonOAuth2Adapter)

View File

@@ -0,0 +1,78 @@
from allauth.account.models import EmailAddress
from allauth.socialaccount.providers.amazon_cognito.utils import (
convert_to_python_bool_if_value_is_json_string_bool,
)
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AmazonCognitoAccount(ProviderAccount):
def to_str(self):
dflt = super(AmazonCognitoAccount, self).to_str()
return self.account.extra_data.get("username", dflt)
def get_avatar_url(self):
return self.account.extra_data.get("picture")
def get_profile_url(self):
return self.account.extra_data.get("profile")
class AmazonCognitoProvider(OAuth2Provider):
id = "amazon_cognito"
name = "Amazon Cognito"
account_class = AmazonCognitoAccount
def extract_uid(self, data):
return str(data["sub"])
def extract_common_fields(self, data):
return {
"email": data.get("email"),
"first_name": data.get("given_name"),
"last_name": data.get("family_name"),
}
def get_default_scope(self):
return ["openid", "profile", "email"]
def extract_email_addresses(self, data):
email = data.get("email")
verified = convert_to_python_bool_if_value_is_json_string_bool(
data.get("email_verified", False)
)
return (
[EmailAddress(email=email, verified=verified, primary=True)]
if email
else []
)
def extract_extra_data(self, data):
return {
"address": data.get("address"),
"birthdate": data.get("birthdate"),
"gender": data.get("gender"),
"locale": data.get("locale"),
"middlename": data.get("middlename"),
"nickname": data.get("nickname"),
"phone_number": data.get("phone_number"),
"phone_number_verified": convert_to_python_bool_if_value_is_json_string_bool(
data.get("phone_number_verified")
),
"picture": data.get("picture"),
"preferred_username": data.get("preferred_username"),
"profile": data.get("profile"),
"website": data.get("website"),
"zoneinfo": data.get("zoneinfo"),
}
@classmethod
def get_slug(cls):
# IMPORTANT: Amazon Cognito does not support `_` characters
# as part of their redirect URI.
return super(AmazonCognitoProvider, cls).get_slug().replace("_", "-")
provider_classes = [AmazonCognitoProvider]

View File

@@ -0,0 +1,69 @@
import json
from django.test import override_settings
from allauth.account.models import EmailAddress
from allauth.socialaccount.models import SocialAccount
from allauth.socialaccount.providers.amazon_cognito.provider import (
AmazonCognitoProvider,
)
from allauth.socialaccount.providers.amazon_cognito.views import (
AmazonCognitoOAuth2Adapter,
)
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
def _get_mocked_claims():
return {
"sub": "4993b410-8a1b-4c36-b843-a9c1a697e6b7",
"given_name": "John",
"family_name": "Doe",
"email": "jdoe@example.com",
"username": "johndoe",
}
@override_settings(
SOCIALACCOUNT_PROVIDERS={
"amazon_cognito": {"DOMAIN": "https://domain.auth.us-east-1.amazoncognito.com"}
}
)
class AmazonCognitoTestCase(OAuth2TestsMixin, TestCase):
provider_id = AmazonCognitoProvider.id
def get_mocked_response(self):
mocked_payload = json.dumps(_get_mocked_claims())
return MockedResponse(status_code=200, content=mocked_payload)
@override_settings(SOCIALACCOUNT_PROVIDERS={"amazon_cognito": {}})
def test_oauth2_adapter_raises_if_domain_settings_is_missing(
self,
):
mocked_response = self.get_mocked_response()
with self.assertRaises(
ValueError,
msg=AmazonCognitoOAuth2Adapter.DOMAIN_KEY_MISSING_ERROR,
):
self.login(mocked_response)
def test_saves_email_as_verified_if_email_is_verified_in_cognito(
self,
):
mocked_claims = _get_mocked_claims()
mocked_claims["email_verified"] = True
mocked_payload = json.dumps(mocked_claims)
mocked_response = MockedResponse(status_code=200, content=mocked_payload)
self.login(mocked_response)
user_id = SocialAccount.objects.get(uid=mocked_claims["sub"]).user_id
email_address = EmailAddress.objects.get(user_id=user_id)
self.assertEqual(email_address.email, mocked_claims["email"])
self.assertTrue(email_address.verified)
def test_provider_slug_replaces_underscores_with_hyphens(self):
self.assertTrue("_" not in self.provider.get_slug())

View File

@@ -0,0 +1,7 @@
from allauth.socialaccount.providers.amazon_cognito.provider import (
AmazonCognitoProvider,
)
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
urlpatterns = default_urlpatterns(AmazonCognitoProvider)

View File

@@ -0,0 +1,7 @@
def convert_to_python_bool_if_value_is_json_string_bool(s):
if s == "true":
return True
elif s == "false":
return False
return s

View File

@@ -0,0 +1,57 @@
import requests
from allauth.socialaccount import app_settings
from allauth.socialaccount.providers.amazon_cognito.provider import (
AmazonCognitoProvider,
)
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
class AmazonCognitoOAuth2Adapter(OAuth2Adapter):
provider_id = AmazonCognitoProvider.id
DOMAIN_KEY_MISSING_ERROR = (
'"DOMAIN" key is missing in Amazon Cognito configuration.'
)
@property
def settings(self):
return app_settings.PROVIDERS.get(self.provider_id, {})
@property
def domain(self):
domain = self.settings.get("DOMAIN")
if domain is None:
raise ValueError(self.DOMAIN_KEY_MISSING_ERROR)
return domain
@property
def access_token_url(self):
return "{}/oauth2/token".format(self.domain)
@property
def authorize_url(self):
return "{}/oauth2/authorize".format(self.domain)
@property
def profile_url(self):
return "{}/oauth2/userInfo".format(self.domain)
def complete_login(self, request, app, access_token, **kwargs):
headers = {
"Authorization": "Bearer {}".format(access_token),
}
extra_data = requests.get(self.profile_url, headers=headers)
extra_data.raise_for_status()
return self.get_provider().sociallogin_from_response(request, extra_data.json())
oauth2_login = OAuth2LoginView.adapter_view(AmazonCognitoOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AmazonCognitoOAuth2Adapter)

View File

@@ -0,0 +1,33 @@
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AngelListAccount(ProviderAccount):
def get_profile_url(self):
return self.account.extra_data.get("angellist_url")
def get_avatar_url(self):
return self.account.extra_data.get("image")
def to_str(self):
dflt = super(AngelListAccount, self).to_str()
return self.account.extra_data.get("name", dflt)
class AngelListProvider(OAuth2Provider):
id = "angellist"
name = "AngelList"
account_class = AngelListAccount
def extract_uid(self, data):
return str(data["id"])
def extract_common_fields(self, data):
return dict(
email=data.get("email"),
username=data.get("angellist_url").split("/")[-1],
name=data.get("name"),
)
provider_classes = [AngelListProvider]

View File

@@ -0,0 +1,25 @@
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase
from .provider import AngelListProvider
class AngelListTests(OAuth2TestsMixin, TestCase):
provider_id = AngelListProvider.id
def get_mocked_response(self):
return MockedResponse(
200,
"""
{"name":"pennersr","id":424732,"bio":"","follower_count":0,
"angellist_url":"https://angel.co/dsxtst",
"image":"https://angel.co/images/shared/nopic.png",
"email":"raymond.penners@example.com","blog_url":null,
"online_bio_url":null,"twitter_url":"https://twitter.com/dsxtst",
"facebook_url":null,"linkedin_url":null,"aboutme_url":null,
"github_url":null,"dribbble_url":null,"behance_url":null,
"what_ive_built":null,"locations":[],"roles":[],"skills":[],
"investor":false,"scopes":["message","talent","dealflow","comment",
"email"]}
""",
)

View File

@@ -0,0 +1,6 @@
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AngelListProvider
urlpatterns = default_urlpatterns(AngelListProvider)

View File

@@ -0,0 +1,26 @@
import requests
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
from .provider import AngelListProvider
class AngelListOAuth2Adapter(OAuth2Adapter):
provider_id = AngelListProvider.id
access_token_url = "https://angel.co/api/oauth/token/"
authorize_url = "https://angel.co/api/oauth/authorize/"
profile_url = "https://api.angel.co/1/me/"
supports_state = False
def complete_login(self, request, app, token, **kwargs):
resp = requests.get(self.profile_url, params={"access_token": token.token})
extra_data = resp.json()
return self.get_provider().sociallogin_from_response(request, extra_data)
oauth2_login = OAuth2LoginView.adapter_view(AngelListOAuth2Adapter)
oauth2_callback = OAuth2CallbackView.adapter_view(AngelListOAuth2Adapter)

View File

@@ -0,0 +1,8 @@
from allauth.socialaccount.sessions import LoginSession
APPLE_SESSION_COOKIE_NAME = "apple-login-session"
def get_apple_session(request):
return LoginSession(request, "apple_login_session", APPLE_SESSION_COOKIE_NAME)

View File

@@ -0,0 +1,98 @@
import requests
from datetime import datetime, timedelta
from urllib.parse import parse_qsl, quote, urlencode
from django.core.exceptions import ImproperlyConfigured
import jwt
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.providers.oauth2.client import (
OAuth2Client,
OAuth2Error,
)
def jwt_encode(*args, **kwargs):
resp = jwt.encode(*args, **kwargs)
if isinstance(resp, bytes):
# For PyJWT <2
resp = resp.decode("utf-8")
return resp
class Scope(object):
EMAIL = "email"
NAME = "name"
class AppleOAuth2Client(OAuth2Client):
"""
Custom client because `Sign In With Apple`:
* requires `response_mode` field in redirect_url
* requires special `client_secret` as JWT
"""
def generate_client_secret(self):
"""Create a JWT signed with an apple provided private key"""
now = datetime.utcnow()
app = get_adapter(self.request).get_app(self.request, "apple")
if not app.key:
raise ImproperlyConfigured("Apple 'key' missing")
if not app.certificate_key:
raise ImproperlyConfigured("Apple 'certificate_key' missing")
claims = {
"iss": app.key,
"aud": "https://appleid.apple.com",
"sub": self.get_client_id(),
"iat": now,
"exp": now + timedelta(hours=1),
}
headers = {"kid": self.consumer_secret, "alg": "ES256"}
client_secret = jwt_encode(
payload=claims, key=app.certificate_key, algorithm="ES256", headers=headers
)
return client_secret
def get_client_id(self):
"""We support multiple client_ids, but use the first one for api calls"""
return self.consumer_key.split(",")[0]
def get_access_token(self, code, pkce_code_verifier=None):
url = self.access_token_url
client_secret = self.generate_client_secret()
data = {
"client_id": self.get_client_id(),
"code": code,
"grant_type": "authorization_code",
"redirect_uri": self.callback_url,
"client_secret": client_secret,
}
if pkce_code_verifier:
data["code_verifier"] = pkce_code_verifier
self._strip_empty_keys(data)
resp = requests.request(
self.access_token_method, url, data=data, headers=self.headers
)
access_token = None
if resp.status_code in [200, 201]:
try:
access_token = resp.json()
except ValueError:
access_token = dict(parse_qsl(resp.text))
if not access_token or "access_token" not in access_token:
raise OAuth2Error("Error retrieving access token: %s" % resp.content)
return access_token
def get_redirect_url(self, authorization_url, extra_params):
params = {
"client_id": self.get_client_id(),
"redirect_uri": self.callback_url,
"response_mode": "form_post",
"scope": self.scope,
"response_type": "code id_token",
}
if self.state:
params["state"] = self.state
params.update(extra_params)
return "%s?%s" % (authorization_url, urlencode(params, quote_via=quote))

View File

@@ -0,0 +1,49 @@
from allauth.account.models import EmailAddress
from allauth.socialaccount.app_settings import QUERY_EMAIL
from allauth.socialaccount.providers.base import ProviderAccount
from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider
class AppleProvider(OAuth2Provider):
id = "apple"
name = "Apple"
account_class = ProviderAccount
def extract_uid(self, data):
return str(data["sub"])
def extract_common_fields(self, data):
fields = {"email": data.get("email")}
# If the name was provided
name = data.get("name")
if name:
fields["first_name"] = name.get("firstName", "")
fields["last_name"] = name.get("lastName", "")
return fields
def extract_email_addresses(self, data):
ret = []
email = data.get("email")
verified = data.get("email_verified")
if isinstance(verified, str):
verified = verified.lower() == "true"
if email:
ret.append(
EmailAddress(
email=email,
verified=verified,
primary=True,
)
)
return ret
def get_default_scope(self):
scopes = ["name"]
if QUERY_EMAIL:
scopes.append("email")
return scopes
provider_classes = [AppleProvider]

View File

@@ -0,0 +1,253 @@
import json
from datetime import datetime, timedelta
from importlib import import_module
from urllib.parse import parse_qs, urlparse
from django.conf import settings
from django.test.utils import override_settings
from django.urls import reverse
from django.utils.http import urlencode
import jwt
from allauth.socialaccount.tests import OAuth2TestsMixin
from allauth.tests import MockedResponse, TestCase, mocked_response
from .apple_session import APPLE_SESSION_COOKIE_NAME
from .client import jwt_encode
from .provider import AppleProvider
# Generated on https://mkjwk.org/, used to sign and verify the apple id_token
TESTING_JWT_KEYSET = {
"p": (
"4ADzS5jKx_kdQihyOocVS0Qwwo7m0f7Ow56EadySJ-cmnwoHHF3AxgRaq-h-KwybSphv"
"dc-X7NbS79-b9dumHKyt1MeVLAsDZD1a-uQCEneY1g9LsQkscNr7OggcpvMg5UUFwv6A"
"kavu8cB0iyhNdha5_AWX27K5lNebvpaXEJ8"
),
"kty": "RSA",
"q": (
"yy5UvMjrvZyO1Os_nxXIugCa3NyWOkC8oMppPvr1Bl5AnF_xwXN2n9ozPd9Nb3Q3n-om"
"NgLayyUxhwIjWDlI67Vbx-ESuff8ZEBKuTK0Gdmr4C_QU_j0gvvNMNJweSPxDdRmIUgO"
"njTVNWmdqFTZs43jXAT4J519rgveNLAkGNE"
),
"d": (
"riPuGIDde88WS03CVbo_mZ9toFWPyTxvuz8VInJ9S1ZxULo-hQWDBohWGYwvg8cgfXck"
"cqWt5OBqNvPYdLgwb84uVi2JeEHmhcQSc_x0zfRTau5HVE2KdR-gWxQjPWoaBHeDVqwo"
"PKaU2XYxa-gYDXcuSJWHz3BX13oInDEFCXr6VwiLiwLBFsb63EEHwyWXJbTpoar7AARW"
"kz76qtngDkk4t9gk_Q0L1y1qf1GeWiAL7xWb-bdptma4-1ui-R2219-1ONEZ41v_jsIS"
"_z8ooXmVCbUsHV4Z1UDpRvpORVE3u57WK3qXUdAtZsXjaIwkdItbDmL1jFUgefwfO91Y"
"YQ"
),
"e": "AQAB",
"use": "sig",
"kid": "testkey",
"qi": (
"R0Hu4YmpHzw3SKWGYuAcAo6B97-JlN2fXiTjZ2g8eHGQX7LSoKEu0Hmu5hcBZYSgOuor"
"IPsPUu3mNtx3pjLMOaJRk34VwcYu7h23ogEKGcPUt1c4tTotFDdw8WFptDOw4ow31Tml"
"BPExLqzzGjJeQSNULB1bExuuhYMWx6wBXo8"
),
"dp": (
"WBaHlnbjZ3hDVTzqjrGIYizSr-_aPUJitPKlR6wBncd8nJYo7bLAmB4mOewXkX5HozIG"
"wuF78RsZoFLi1fAmhqgxQ7eopcU-9DBcksUPO4vkgmlJbrkYzNiQauW9vrllekOGXIQQ"
"szhVoqP4MLEMpR-Sy9S3PyItcKbJDE3T4ik"
),
"alg": "RS256",
"dq": (
"Ar5kbIw2CsBzeVKX8FkF9eUOMk9URAMdyPoSw8P1zRk2vCXbiOY7Qttad8ptLEUgfytV"
"SsNtGvMsoQsZWRak8nHnhGJ4s0QzB1OK7sdNgU_cL1HV-VxSSPaHhdJBrJEcrzggDPEB"
"KYfDHU6Iz34d1nvjBxoWE8rfqJsGbCW4xxE"
),
"n": (
"sclLPioUv4VOcOZWAKoRhcvwIH2jOhoHhSI_Cj5c5zSp7qaK8jCU6T7-GObsgrhpty-k"
"26ZuqRdgu9d-62WO8OBGt1e0wxbTh128-nTTrOESHUlV_K1wpJmXOxNpJiybcgzZNbAm"
"ACmsHfxZvN9bt7gKPXxf3-_zFAf12PbYMrOionAJ1N_4HxL7fz3xkr5C87Av06QNilIC"
"-mA-4n9Eqw_R2DYNpE3RYMdWtwKqBwJC8qs3677RpG9vcc-yZ_97pEiytd2FBJ8uoTwH"
"d3DHJB8UVgBSh1kMUpSdoM7HxVzKx732nx6Kusln79LrsfOzrXF4enkfKJYI40-uwT95"
"zw"
),
}
# Mocked version of the test data from https://appleid.apple.com/auth/keys
KEY_SERVER_RESP_JSON = json.dumps(
{
"keys": [
{
"kty": TESTING_JWT_KEYSET["kty"],
"kid": TESTING_JWT_KEYSET["kid"],
"use": TESTING_JWT_KEYSET["use"],
"alg": TESTING_JWT_KEYSET["alg"],
"n": TESTING_JWT_KEYSET["n"],
"e": TESTING_JWT_KEYSET["e"],
}
]
}
)
def sign_id_token(payload):
"""
Sign a payload as apple normally would for the id_token.
"""
signing_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(TESTING_JWT_KEYSET))
return jwt_encode(
payload,
signing_key,
algorithm="RS256",
headers={"kid": TESTING_JWT_KEYSET["kid"]},
)
@override_settings(
SOCIALACCOUNT_STORE_TOKENS=False,
SOCIALACCOUNT_PROVIDERS={
"apple": {
"APP": {
"client_id": "app123id",
"key": "apple",
"secret": "dummy",
"certificate_key": """-----BEGIN PRIVATE KEY-----
MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg2+Eybl8ojH4wB30C
3/iDkpsrxuPfs3DZ+3nHNghBOpmhRANCAAQSpo1eQ+EpNgQQyQVs/F27dkq3gvAI
28m95JEk26v64YAea5NTH56mru30RDqTKPgRVi5qRu3XGyqy3mdb8gMy
-----END PRIVATE KEY-----
""",
}
}
},
)
class AppleTests(OAuth2TestsMixin, TestCase):
provider_id = AppleProvider.id
def get_apple_id_token_payload(self):
now = datetime.utcnow()
return {
"iss": "https://appleid.apple.com",
"aud": "app123id", # Matches `setup_app`
"exp": now + timedelta(hours=1),
"iat": now,
"sub": "000313.c9720f41e9434e18987a.1218",
"at_hash": "CkaUPjk4MJinaAq6Z0tGUA",
"email": "test@privaterelay.appleid.com",
"email_verified": "true",
"is_private_email": "true",
"auth_time": 1234345345, # not converted automatically by pyjwt
}
def get_login_response_json(self, with_refresh_token=True):
"""
`with_refresh_token` is not optional for apple, so it's ignored.
"""
id_token = sign_id_token(self.get_apple_id_token_payload())
return json.dumps(
{
"access_token": "testac", # Matches OAuth2TestsMixin value
"expires_in": 3600,
"id_token": id_token,
"refresh_token": "testrt", # Matches OAuth2TestsMixin value
"token_type": "Bearer",
}
)
def get_mocked_response(self):
"""
Apple is unusual in that the `id_token` contains all the user info
so no profile info request is made. However, it does need the
public key verification, so this mocked response is the public
key request in order to verify the authenticity of the id_token.
"""
return MockedResponse(
200, KEY_SERVER_RESP_JSON, {"content-type": "application/json"}
)
def get_complete_parameters(self, auth_request_params):
"""
Add apple specific response parameters which they include in the
form_post response.
https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms
"""
params = super().get_complete_parameters(auth_request_params)
params.update(
{
"id_token": sign_id_token(self.get_apple_id_token_payload()),
"user": json.dumps(
{
"email": "private@appleid.apple.com",
"name": {
"firstName": "A",
"lastName": "B",
},
}
),
}
)
return params
def login(self, resp_mock, process="login", with_refresh_token=True):
resp = self.client.post(
reverse(self.provider.id + "_login")
+ "?"
+ urlencode(dict(process=process))
)
p = urlparse(resp["location"])
q = parse_qs(p.query)
complete_url = reverse(self.provider.id + "_callback")
self.assertGreater(q["redirect_uri"][0].find(complete_url), 0)
response_json = self.get_login_response_json(
with_refresh_token=with_refresh_token
)
with mocked_response(
MockedResponse(200, response_json, {"content-type": "application/json"}),
resp_mock,
):
resp = self.client.post(
complete_url,
data=self.get_complete_parameters(q),
)
assert reverse("apple_finish_callback") in resp.url
# Follow the redirect
resp = self.client.get(resp.url)
return resp
def test_authentication_error(self):
"""Override base test because apple posts errors"""
resp = self.client.post(
reverse(self.provider.id + "_callback"),
data={"error": "misc", "state": "testingstate123"},
)
assert reverse("apple_finish_callback") in resp.url
# Follow the redirect
resp = self.client.get(resp.url)
self.assertTemplateUsed(
resp,
"socialaccount/authentication_error.%s"
% getattr(settings, "ACCOUNT_TEMPLATE_EXTENSION", "html"),
)
def test_apple_finish(self):
resp = self.login(self.get_mocked_response())
# Check request generating the response
finish_url = reverse("apple_finish_callback")
self.assertEqual(resp.request["PATH_INFO"], finish_url)
self.assertTrue("state" in resp.request["QUERY_STRING"])
self.assertTrue("code" in resp.request["QUERY_STRING"])
# Check have cookie containing apple session
self.assertTrue(APPLE_SESSION_COOKIE_NAME in self.client.cookies)
# Session should have been cleared
apple_session_cookie = self.client.cookies.get(APPLE_SESSION_COOKIE_NAME)
engine = import_module(settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
apple_login_session = SessionStore(apple_session_cookie.value)
self.assertEqual(len(apple_login_session.keys()), 0)
# Check cookie path was correctly set
self.assertEqual(apple_session_cookie.get("path"), finish_url)

View File

@@ -0,0 +1,16 @@
from django.urls import path
from allauth.socialaccount.providers.oauth2.urls import default_urlpatterns
from .provider import AppleProvider
from .views import oauth2_finish_login
urlpatterns = default_urlpatterns(AppleProvider)
urlpatterns += [
path(
AppleProvider.get_slug() + "/login/callback/finish/",
oauth2_finish_login,
name="apple_finish_callback",
),
]

View File

@@ -0,0 +1,183 @@
import json
import requests
from datetime import timedelta
from django.http import HttpResponseNotAllowed, HttpResponseRedirect
from django.urls import reverse
from django.utils import timezone
from django.utils.http import urlencode
from django.views.decorators.csrf import csrf_exempt
import jwt
from allauth.socialaccount.adapter import get_adapter
from allauth.socialaccount.models import SocialToken
from allauth.socialaccount.providers.oauth2.client import OAuth2Error
from allauth.socialaccount.providers.oauth2.views import (
OAuth2Adapter,
OAuth2CallbackView,
OAuth2LoginView,
)
from allauth.utils import build_absolute_uri, get_request_param
from .apple_session import get_apple_session
from .client import AppleOAuth2Client
from .provider import AppleProvider
class AppleOAuth2Adapter(OAuth2Adapter):
client_class = AppleOAuth2Client
provider_id = AppleProvider.id
access_token_url = "https://appleid.apple.com/auth/token"
authorize_url = "https://appleid.apple.com/auth/authorize"
public_key_url = "https://appleid.apple.com/auth/keys"
def _get_apple_public_key(self, kid):
response = requests.get(self.public_key_url)
response.raise_for_status()
try:
data = response.json()
except json.JSONDecodeError as e:
raise OAuth2Error("Error retrieving apple public key.") from e
for d in data["keys"]:
if d["kid"] == kid:
return d
def get_public_key(self, id_token):
"""
Get the public key which matches the `kid` in the id_token header.
"""
kid = jwt.get_unverified_header(id_token)["kid"]
apple_public_key = self._get_apple_public_key(kid=kid)
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(apple_public_key))
return public_key
def get_client_id(self, provider):
app = get_adapter().get_app(request=None, provider=self.provider_id)
return [aud.strip() for aud in app.client_id.split(",")]
def get_verified_identity_data(self, id_token):
provider = self.get_provider()
allowed_auds = self.get_client_id(provider)
try:
public_key = self.get_public_key(id_token)
identity_data = jwt.decode(
id_token,
public_key,
algorithms=["RS256"],
audience=allowed_auds,
issuer="https://appleid.apple.com",
)
return identity_data
except jwt.PyJWTError as e:
raise OAuth2Error("Invalid id_token") from e
def parse_token(self, data):
token = SocialToken(
token=data["access_token"],
)
token.token_secret = data.get("refresh_token", "")
expires_in = data.get(self.expires_in_key)
if expires_in:
token.expires_at = timezone.now() + timedelta(seconds=int(expires_in))
# `user_data` is a big flat dictionary with the parsed JWT claims
# access_tokens, and user info from the apple post.
identity_data = self.get_verified_identity_data(data["id_token"])
token.user_data = {**data, **identity_data}
return token
def complete_login(self, request, app, token, **kwargs):
extra_data = token.user_data
login = self.get_provider().sociallogin_from_response(
request=request, response=extra_data
)
login.state["id_token"] = token.user_data
# We can safely remove the apple login session now
# Note: The cookie will remain, but it's set to delete on browser close
get_apple_session(request).delete()
return login
def get_user_scope_data(self, request):
user_scope_data = request.apple_login_session.get("user", "")
try:
return json.loads(user_scope_data)
except json.JSONDecodeError:
# We do not care much about user scope data as it maybe blank
# so return blank dictionary instead
return {}
def get_access_token_data(self, request, app, client):
"""We need to gather the info from the apple specific login"""
apple_session = get_apple_session(request)
# Exchange `code`
code = get_request_param(request, "code")
pkce_code_verifier = request.session.pop("pkce_code_verifier", None)
access_token_data = client.get_access_token(
code, pkce_code_verifier=pkce_code_verifier
)
id_token = access_token_data.get("id_token", None)
# In case of missing id_token in access_token_data
if id_token is None:
id_token = apple_session.store.get("id_token")
return {
**access_token_data,
**self.get_user_scope_data(request),
"id_token": id_token,
}
@csrf_exempt
def apple_post_callback(request, finish_endpoint_name="apple_finish_callback"):
"""
Apple uses a `form_post` response type, which due to
CORS/Samesite-cookie rules means this request cannot access
the request since the session cookie is unavailable.
We work around this by storing the apple response in a
separate, temporary session and redirecting to a more normal
oauth flow.
args:
finish_endpoint_name (str): The name of a defined URL, which can be
overridden in your url configuration if you have more than one
callback endpoint.
"""
if request.method != "POST":
return HttpResponseNotAllowed(["POST"])
apple_session = get_apple_session(request)
# Add regular OAuth2 params to the URL - reduces the overrides required
keys_to_put_in_url = ["code", "state", "error"]
url_params = {}
for key in keys_to_put_in_url:
value = get_request_param(request, key, "")
if value:
url_params[key] = value
# Add other params to the apple_login_session
keys_to_save_to_session = ["user", "id_token"]
for key in keys_to_save_to_session:
apple_session.store[key] = get_request_param(request, key, "")
url = build_absolute_uri(request, reverse(finish_endpoint_name))
response = HttpResponseRedirect(
"{url}?{query}".format(url=url, query=urlencode(url_params))
)
apple_session.save(response)
return response
oauth2_login = OAuth2LoginView.adapter_view(AppleOAuth2Adapter)
oauth2_callback = apple_post_callback
oauth2_finish_login = OAuth2CallbackView.adapter_view(AppleOAuth2Adapter)

Some files were not shown because too many files have changed in this diff Show More