From e650ec2fd0217110b446547e596a0b20f90fd4c9 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Wed, 7 Dec 2022 10:58:26 +0100 Subject: [PATCH 1/3] misc: add next_url parameter to Authenticator.autorun() (#27135) --- src/authentic2/views.py | 2 +- src/authentic2_auth_fc/models.py | 4 ++-- src/authentic2_auth_fc/views.py | 5 ++++- src/authentic2_auth_oidc/models.py | 6 ++++-- src/authentic2_auth_saml/models.py | 4 ++-- tests/test_auth_oidc.py | 2 +- 6 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/authentic2/views.py b/src/authentic2/views.py index 3e7bd0d1..2e044592 100644 --- a/src/authentic2/views.py +++ b/src/authentic2/views.py @@ -415,7 +415,7 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE if hasattr(authenticator, 'autorun'): if 'message' in token: messages.info(request, token['message']) - return authenticator.autorun(request, block.get('id')) + return authenticator.autorun(request, block_id=block.get('id'), next_url=redirect_to) context.update( { diff --git a/src/authentic2_auth_fc/models.py b/src/authentic2_auth_fc/models.py index f767c720..958e8a41 100644 --- a/src/authentic2_auth_fc/models.py +++ b/src/authentic2_auth_fc/models.py @@ -118,8 +118,8 @@ class FcAuthenticator(BaseAuthenticator): else: return 'https://app.franceconnect.gouv.fr/api/v1/logout' - def autorun(self, request, block_id): - return views.LoginOrLinkView.as_view(display_message_on_redirect=True)(request) + def autorun(self, request, block_id, next_url): + return views.LoginOrLinkView.as_view(display_message_on_redirect=True)(request, next_url=next_url) def login(self, request, *args, **kwargs): return views.login(request, *args, **kwargs) diff --git a/src/authentic2_auth_fc/views.py b/src/authentic2_auth_fc/views.py index daca3b82..f23a6e23 100644 --- a/src/authentic2_auth_fc/views.py +++ b/src/authentic2_auth_fc/views.py @@ -147,7 +147,10 @@ class LoginOrLinkView(View): display_name += family_name return display_name - def get(self, request, *args, **kwargs): + def get(self, request, *args, next_url=None, **kwargs): + if next_url: + self._next_url = next_url + self.authenticator = get_object_or_404(models.FcAuthenticator, enabled=True) code = request.GET.get('code') diff --git a/src/authentic2_auth_oidc/models.py b/src/authentic2_auth_oidc/models.py index dd23e02d..f39af972 100644 --- a/src/authentic2_auth_oidc/models.py +++ b/src/authentic2_auth_oidc/models.py @@ -232,8 +232,10 @@ class OIDCProvider(BaseAuthenticator): def __repr__(self): return '' % self.issuer - def autorun(self, request, *args): - return redirect_to_login(request, login_url='oidc-login', kwargs={'pk': self.pk}) + def autorun(self, request, block_id, next_url): + return redirect_to_login( + request, login_url='oidc-login', kwargs={'pk': self.pk}, params={'next': next_url} + ) def login(self, request, *args, **kwargs): context = kwargs.get('context', {}).copy() diff --git a/src/authentic2_auth_saml/models.py b/src/authentic2_auth_saml/models.py index bdfa239b..0c951ab0 100644 --- a/src/authentic2_auth_saml/models.py +++ b/src/authentic2_auth_saml/models.py @@ -230,13 +230,13 @@ class SAMLAuthenticator(BaseAuthenticator): if not (self.metadata or self.metadata_url): raise ValidationError(_('One of the metadata fields must be filled.')) - def autorun(self, request, block_id): + def autorun(self, request, block_id, next_url): from .adapters import AuthenticAdapter settings = self.settings AuthenticAdapter().load_idp(settings, self.order) return redirect_to_login( - request, login_url='mellon_login', params={'entityID': settings['ENTITY_ID']} + request, login_url='mellon_login', params={'entityID': settings['ENTITY_ID'], 'next': next_url} ) def has_signing_key(self): diff --git a/tests/test_auth_oidc.py b/tests/test_auth_oidc.py index ee20552c..58c57a27 100644 --- a/tests/test_auth_oidc.py +++ b/tests/test_auth_oidc.py @@ -492,7 +492,7 @@ def test_login_autorun(oidc_provider, app, settings): slug='password-authenticator', defaults={'enabled': False} ) response = app.get('/login/', status=302) - assert response['Location'] == '/accounts/oidc/login/%s/' % oidc_provider.pk + assert response['Location'] == '/accounts/oidc/login/%s/?next=/' % oidc_provider.pk def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, hooks): -- 2.37.2