162 lines
5.5 KiB
Python
162 lines
5.5 KiB
Python
from django.http import HttpResponseRedirect
|
|
from django.shortcuts import render
|
|
from django.urls import reverse
|
|
from django.views import View
|
|
from django.views.decorators.csrf import csrf_exempt
|
|
|
|
from openid.consumer import consumer
|
|
from openid.consumer.discover import DiscoveryFailure
|
|
from openid.extensions.ax import AttrInfo, FetchRequest
|
|
from openid.extensions.sreg import SRegRequest
|
|
|
|
from allauth.socialaccount.app_settings import QUERY_EMAIL
|
|
from allauth.socialaccount.helpers import (
|
|
complete_social_login,
|
|
render_authentication_error,
|
|
)
|
|
from allauth.socialaccount.models import SocialLogin
|
|
|
|
from ..base import AuthError
|
|
from .forms import LoginForm
|
|
from .provider import OpenIDProvider
|
|
from .utils import AXAttributes, DBOpenIDStore, JSONSafeSession, SRegFields
|
|
|
|
|
|
def _openid_consumer(request, provider, endpoint):
    """Create a python-openid ``Consumer`` bound to the current session.

    :param request: the Django request; its ``session`` is wrapped in
        ``JSONSafeSession`` and handed to the consumer.
    :param provider: object exposing ``get_server_settings(endpoint)``.
    :param endpoint: OpenID identifier / endpoint used to look up settings.
    :returns: a ``consumer.Consumer`` instance.

    When the server settings for *endpoint* have a truthy ``"stateless"``
    entry, no store is passed (``None``); otherwise a fresh
    ``DBOpenIDStore`` is used.
    """
    is_stateless = provider.get_server_settings(endpoint).get("stateless", False)
    openid_store = None
    if not is_stateless:
        openid_store = DBOpenIDStore()
    return consumer.Consumer(JSONSafeSession(request.session), openid_store)
|
|
|
|
|
|
class OpenIDLoginView(View):
    """Render the OpenID login form and start the authentication redirect.

    A ``GET`` without an ``openid`` parameter shows an unbound form. A
    ``GET`` carrying ``openid``, or a ``POST``, validates the identifier.
    On success it redirects the browser to the OpenID provider.
    """

    template_name = "openid/login.html"
    form_class = LoginForm
    provider = OpenIDProvider

    def get(self, request):
        form = self.get_form()
        if not form.is_valid():
            return render(request, self.template_name, {"form": form})

        try:
            return self.perform_openid_auth(form)
        except (UnicodeDecodeError, DiscoveryFailure) as e:
            # UnicodeDecodeError: necaris/python3-openid#1
            return render_authentication_error(request, self.provider.id, exception=e)

    def post(self, request):
        form = self.get_form()
        if form.is_valid():
            try:
                return self.perform_openid_auth(form)
            except (UnicodeDecodeError, DiscoveryFailure) as e:
                # Show discovery problems as an error on the identifier field
                # so the user can correct it and resubmit.
                form._errors["openid"] = form.error_class([e])

        return render(request, self.template_name, {"form": form})

    def get_form(self):
        """Build the login form for the current request.

        A ``GET`` without ``openid`` yields an unbound form pre-filled with
        ``next``/``process`` from the query string. Otherwise the form is
        bound to GET and POST parameters merged, with POST values winning
        on duplicate keys.
        """
        if self.request.method == "GET" and "openid" not in self.request.GET:
            return self.form_class(
                initial={
                    "next": self.request.GET.get("next"),
                    "process": self.request.GET.get("process"),
                }
            )

        return self.form_class(
            dict(list(self.request.GET.items()) + list(self.request.POST.items()))
        )

    def get_client(self, provider, endpoint):
        return _openid_consumer(self.request, provider, endpoint)

    def get_realm(self, provider):
        """Return the configured ``REALM``, defaulting to the site root URL."""
        return provider.get_settings().get(
            "REALM", self.request.build_absolute_uri("/")
        )

    def get_callback_url(self):
        return reverse(callback)

    def perform_openid_auth(self, form):
        """Begin OpenID authentication for a validated *form*.

        :returns: an ``HttpResponseRedirect`` to the provider.
        :raises DiscoveryFailure, UnicodeDecodeError: propagated from
            ``client.begin`` and handled by ``get``/``post``.
        """
        if not form.is_valid():
            # NOTE(review): returns the form itself, not an HttpResponse;
            # get/post only call this with an already-valid form.
            return form

        request = self.request
        provider = self.provider(request)
        endpoint = form.cleaned_data["openid"]
        client = self.get_client(provider, endpoint)
        realm = self.get_realm(provider)

        auth_request = client.begin(endpoint)
        if QUERY_EMAIL:
            self._add_profile_extensions(auth_request, provider, endpoint)

        SocialLogin.stash_state(request)

        # Fix for issues 1523 and 2072 (github django-allauth)
        next_url = form.cleaned_data.get("next")
        if next_url:
            auth_request.return_to_args["next"] = next_url

        redirect_url = auth_request.redirectURL(
            realm, request.build_absolute_uri(self.get_callback_url())
        )
        return HttpResponseRedirect(redirect_url)

    def _add_profile_extensions(self, auth_request, provider, endpoint):
        """Attach SReg and AX fetch requests for profile data to *auth_request*."""
        sreg = SRegRequest()
        for name in SRegFields:
            sreg.requestField(field_name=name, required=True)
        auth_request.addExtension(sreg)

        ax = FetchRequest()
        for name in AXAttributes:
            ax.add(AttrInfo(name, required=True))
        # Look up per-server extras with the same provider instance and
        # submitted endpoint used to build the consumer. Reading
        # request.GET here would miss identifiers submitted via POST.
        server_settings = provider.get_server_settings(endpoint)
        for _, name, required in server_settings.get("extra_attributes", []):
            ax.add(AttrInfo(name, required=required))
        auth_request.addExtension(ax)


login = OpenIDLoginView.as_view()
|
|
|
|
|
|
class OpenIDCallbackView(View):
    """Complete OpenID authentication when the provider returns the user.

    The response may arrive as a ``GET`` or a ``POST`` (``post = get``).
    Parameters are therefore read from GET and POST merged, with POST
    values winning on duplicate keys.
    """

    provider = OpenIDProvider

    def get(self, request):
        provider = self.provider(request)
        # Read the endpoint from the same merged parameters passed to
        # client.complete(). A POSTed assertion carries it only in the body.
        endpoint = self._get_query().get("openid.op_endpoint", "")
        client = self.get_client(provider, endpoint)
        response = self.get_openid_response(client)

        if response.status == consumer.SUCCESS:
            login = provider.sociallogin_from_response(request, response)
            login.state = SocialLogin.unstash_state(request)
            return self.complete_login(login)

        if response.status == consumer.CANCEL:
            error = AuthError.CANCELLED
        else:
            error = AuthError.UNKNOWN
        return self.render_error(error)

    post = get

    def complete_login(self, login):
        return complete_social_login(self.request, login)

    def render_error(self, error):
        return render_authentication_error(self.request, self.provider.id, error=error)

    def get_client(self, provider, endpoint):
        return _openid_consumer(self.request, provider, endpoint)

    def get_openid_response(self, client):
        """Hand the provider's parameters and this view's URL to ``client.complete``."""
        return client.complete(
            self._get_query(),
            self.request.build_absolute_uri(self.request.path),
        )

    def _get_query(self):
        """Return GET and POST parameters merged into one dict (POST wins)."""
        return dict(list(self.request.GET.items()) + list(self.request.POST.items()))


callback = csrf_exempt(OpenIDCallbackView.as_view())
|