From 406654d96d61c47495402a7e528e3073650c511f Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Tue, 25 Aug 2020 10:07:49 +0200 Subject: [PATCH] misc: provide origin service in template context (#20699) --- src/authentic2/context_processors.py | 6 +++++ src/authentic2/middleware.py | 15 ++++++++++++ src/authentic2/settings.py | 1 + src/authentic2/templates/authentic2/base.html | 4 ++++ tests/test_template.py | 24 +++++++++++++++++++ 5 files changed, 50 insertions(+) diff --git a/src/authentic2/context_processors.py b/src/authentic2/context_processors.py index cb78679b..885ad91f 100644 --- a/src/authentic2/context_processors.py +++ b/src/authentic2/context_processors.py @@ -18,6 +18,7 @@ from pkg_resources import get_distribution from django.conf import settings from . import utils, app_settings, constants +from .models import Service class UserFederations(object): @@ -59,4 +60,9 @@ def a2_processor(request): if hasattr(request, 'session'): variables['LAST_LOGIN'] = request.session.get(constants.LAST_LOGIN_SESSION_KEY) variables['USER_SWITCHED'] = constants.SWITCH_USER_SESSION_KEY in request.session + if 'service_pk' in request.session: + try: + variables['service'] = Service.objects.get(pk=request.session['service_pk']) + except Service.DoesNotExist: + pass return variables diff --git a/src/authentic2/middleware.py b/src/authentic2/middleware.py index 38a6761b..4755a2af 100644 --- a/src/authentic2/middleware.py +++ b/src/authentic2/middleware.py @@ -32,6 +32,7 @@ from django.utils.six.moves.urllib import parse as urlparse from django.shortcuts import render from . import app_settings, utils, plugins +from .utils.service import get_service_from_request class CollectIPMiddleware(MiddlewareMixin): @@ -205,3 +206,17 @@ class CookieTestMiddleware(MiddlewareMixin): # set test cookie for 1 year response.set_cookie(self.COOKIE_NAME, '1', max_age=365 * 24 * 3600) return response + + +class SaveServiceInSessionMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + service = None + + service = get_service_from_request(request) + if service: + request.session['service_pk'] = service.pk + + return self.get_response(request) diff --git a/src/authentic2/settings.py b/src/authentic2/settings.py index f1a51130..ad469c4e 100644 --- a/src/authentic2/settings.py +++ b/src/authentic2/settings.py @@ -95,6 +95,7 @@ MIDDLEWARE = ( 'django.middleware.common.CommonMiddleware', 'django.middleware.http.ConditionalGetMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', + 'authentic2.middleware.SaveServiceInSessionMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.locale.LocaleMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', diff --git a/src/authentic2/templates/authentic2/base.html b/src/authentic2/templates/authentic2/base.html index c346df2b..91e724ab 100644 --- a/src/authentic2/templates/authentic2/base.html +++ b/src/authentic2/templates/authentic2/base.html @@ -11,6 +11,10 @@ {{ form.media.css }} {% endblock %} +{% block bodyargs %} + data-service-slug="{{ service.slug }}" data-service-name="{{ service.name }}" +{% endblock %} + {% block extrascripts %} {{ block.super }} {{ form.media.js }} diff --git a/tests/test_template.py b/tests/test_template.py index c91a27c4..0511ce25 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -16,6 +16,8 @@ import pytest +from authentic2.a2_rbac.utils import get_default_ou +from authentic2.models import Service from authentic2.utils.template import Template, TemplateError pytestmark = pytest.mark.django_db @@ -111,3 +113,25 @@ def test_render_template_missing_variable(): with pytest.raises(TemplateError) as raised: template.render(context=context) assert 'missing template variable' in raised + + +def test_service_in_template(app, simple_user, service): + resp = app.get(reverse('auth_login') + '?service=%s' % service.slug) + + assert resp.pyquery('body').attr('data-service-slug') == service.slug + assert resp.pyquery('body').attr('data-service-name') == service.name + + resp.form.set('username', simple_user.username) + resp.form.set('password', simple_user.username) + response = resp.form.submit(name='login-password-submit') + + resp = app.get(reverse('account_management')) + assert resp.pyquery('body').attr('data-service-slug') == service.slug + assert resp.pyquery('body').attr('data-service-name') == service.name + + # if user comes back from a different service, the information is updated + new_service = Service.objects.create(ou=get_default_ou(), slug='service2', + name='Service2') + resp = app.get(reverse('account_management') + '?service=%s' % new_service.slug) + assert resp.pyquery('body').attr('data-service-slug') == new_service.slug + assert resp.pyquery('body').attr('data-service-name') == new_service.name -- 2.20.1