from unittest.mock import patch from urllib.parse import parse_qs, urlparse from django.urls import reverse from django.utils.http import urlencode import pytest from allauth.account.models import EmailAddress from allauth.socialaccount.models import SocialAccount from allauth.socialaccount.providers.saml.utils import build_saml_config @pytest.mark.parametrize( "relay_state, expected_url", [ (None, "/accounts/profile/"), ("/foo", "/foo"), ], ) def test_acs( client, db, saml_settings, acs_saml_response, mocked_signature_validation, expected_url, relay_state, ): data = {"SAMLResponse": acs_saml_response} if relay_state is not None: data["RelayState"] = relay_state resp = client.post( reverse("saml_acs", kwargs={"organization_slug": "org"}), data=data ) finish_url = reverse("saml_finish_acs", kwargs={"organization_slug": "org"}) assert resp.status_code == 302 assert resp["location"] == finish_url resp = client.get(finish_url) assert resp["location"] == expected_url account = SocialAccount.objects.get( provider="urn:dev-123.us.auth0.com", uid="dummysamluid" ) assert account.extra_data["Role"] == ["view-profile", "manage-account-links"] email = EmailAddress.objects.get(user=account.user) assert email.email == "john.doe@email.org" def test_acs_error(client, db, saml_settings): data = {"SAMLResponse": "bad-response"} resp = client.post( reverse("saml_acs", kwargs={"organization_slug": "org"}), data=data ) assert resp.status_code == 200 assert "socialaccount/authentication_error.html" in (t.name for t in resp.templates) @pytest.mark.parametrize( "query,expected_relay_state", [ ("", None), ("?process=connect", "/social/connections/"), ("?process=connect&next=/foo", "/foo"), ("?next=/bar", "/bar"), ], ) def test_login(client, db, saml_settings, query, expected_relay_state): resp = client.get( reverse("saml_login", kwargs={"organization_slug": "org"}) + query ) assert resp.status_code == 302 location = resp["location"] assert location.startswith("https://dev-123.us.auth0.com/samlp/456?SAMLRequest=") resp_query = parse_qs(urlparse(location).query) if expected_relay_state is None: assert "RelayState" not in resp_query else: assert resp_query.get("RelayState")[0] == expected_relay_state def test_metadata( client, db, saml_settings, ): resp = client.get(reverse("saml_metadata", kwargs={"organization_slug": "org"})) assert resp.status_code == 200 assert resp.content.startswith( b'\n