From 2fc874c24476b430a83b48936220f519815819da Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Fri, 28 Jan 2022 00:31:04 +0100 Subject: [PATCH 2/2] misc: maintain home url, service and ou (#61199) --- src/authentic2/authenticators.py | 10 +- src/authentic2/constants.py | 1 - src/authentic2/context_processors.py | 13 +++ src/authentic2/idp/saml/saml2_endpoints.py | 9 +- src/authentic2/journal.py | 4 +- src/authentic2/middleware.py | 14 --- src/authentic2/models.py | 3 + src/authentic2/settings.py | 2 +- src/authentic2/templates/authentic2/base.html | 4 +- src/authentic2/utils/misc.py | 33 +++--- src/authentic2/utils/service.py | 106 +++++++++--------- src/authentic2/views.py | 57 +++++----- src/authentic2_auth_fc/views.py | 20 +--- src/authentic2_idp_cas/views.py | 3 + src/authentic2_idp_oidc/views.py | 8 +- tests/auth_fc/conftest.py | 8 +- tests/auth_fc/test_auth_fc.py | 20 ++-- tests/idp_oidc/test_misc.py | 20 ++-- tests/test_idp_saml2.py | 3 +- tests/test_login.py | 34 +++--- tests/test_template.py | 9 +- tests/utils.py | 18 +++ 22 files changed, 216 insertions(+), 183 deletions(-) diff --git a/src/authentic2/authenticators.py b/src/authentic2/authenticators.py index 07173f73..df32d293 100644 --- a/src/authentic2/authenticators.py +++ b/src/authentic2/authenticators.py @@ -29,7 +29,7 @@ from . import app_settings, views from .forms import authentication as authentication_forms from .utils import misc as utils_misc from .utils.evaluate import evaluate_condition -from .utils.service import get_service_from_request +from .utils.service import get_service from .utils.views import csrf_token_check logger = logging.getLogger(__name__) @@ -88,7 +88,8 @@ class LoginPasswordAuthenticator(BaseAuthenticator): return [] return OU.objects.filter(pk__in=service_ou_ids) - def get_preferred_ous(self, request, service): + def get_preferred_ous(self, request): + service = get_service(request) preferred_ous_cookie = utils_misc.get_remember_cookie(request, 'preferred-ous') preferred_ous = [] if preferred_ous_cookie: @@ -102,7 +103,6 @@ class LoginPasswordAuthenticator(BaseAuthenticator): return preferred_ous def login(self, request, *args, **kwargs): - service = get_service_from_request(request) context = kwargs.get('context', {}) is_post = request.method == 'POST' and self.submit_name in request.POST data = request.POST if is_post else None @@ -112,7 +112,7 @@ class LoginPasswordAuthenticator(BaseAuthenticator): # Special handling when the form contains an OU selector if app_settings.A2_LOGIN_FORM_OU_SELECTOR: - preferred_ous = self.get_preferred_ous(request, service) + preferred_ous = self.get_preferred_ous(request) if preferred_ous: initial['ou'] = preferred_ous[0] @@ -135,7 +135,7 @@ class LoginPasswordAuthenticator(BaseAuthenticator): if form.cleaned_data.get('remember_me'): request.session['remember_me'] = True request.session.set_expiry(app_settings.A2_USER_REMEMBER_ME) - response = utils_misc.login(request, form.get_user(), how, service=service) + response = utils_misc.login(request, form.get_user(), how) if 'ou' in form.fields: utils_misc.prepend_remember_cookie( request, response, 'preferred-ous', form.cleaned_data['ou'].pk diff --git a/src/authentic2/constants.py b/src/authentic2/constants.py index f32c226c..c0bd0d34 100644 --- a/src/authentic2/constants.py +++ b/src/authentic2/constants.py @@ -20,5 +20,4 @@ CANCEL_FIELD_NAME = 'cancel' AUTHENTICATION_EVENTS_SESSION_KEY = 'authentication-events' SWITCH_USER_SESSION_KEY = '_switch_user' LAST_LOGIN_SESSION_KEY = '_last_login' -SERVICE_FIELD_NAME = 'service' NEXT_URL_SIGNATURE = 'next-signature' diff --git a/src/authentic2/context_processors.py b/src/authentic2/context_processors.py index d821bf55..3d5e3887 100644 --- a/src/authentic2/context_processors.py +++ b/src/authentic2/context_processors.py @@ -20,6 +20,7 @@ from pkg_resources import get_distribution from . import app_settings, constants from .models import Service from .utils import misc as utils_misc +from .utils.service import get_service class UserFederations: @@ -69,3 +70,15 @@ def a2_processor(request): except Service.DoesNotExist: pass return variables + + +def home(request): + ctx = {} + if 'home_url' in request.session: + ctx['home_url'] = request.session['home_url'] + service = get_service(request) + if service: + ctx['home_service'] = service + if service.ou: + ctx['home_ou'] = service.ou + return ctx diff --git a/src/authentic2/idp/saml/saml2_endpoints.py b/src/authentic2/idp/saml/saml2_endpoints.py index 08dacc95..12b6ed6b 100644 --- a/src/authentic2/idp/saml/saml2_endpoints.py +++ b/src/authentic2/idp/saml/saml2_endpoints.py @@ -113,6 +113,7 @@ from authentic2.utils import misc as utils_misc from authentic2.utils.misc import datetime_to_xs_datetime, find_authentication_event from authentic2.utils.misc import get_backends as get_idp_backends from authentic2.utils.misc import login_require, make_url +from authentic2.utils.service import set_service from authentic2.utils.view_decorators import check_view_restriction, enable_view_restriction from . import app_settings @@ -582,6 +583,7 @@ def sso(request): }, ) else: + set_service(request, provider_loaded) policy = get_sp_options_policy(provider_loaded) if not policy: return error_page(request, _('sso: No SP policy defined'), logger=logger, warning=True) @@ -628,7 +630,7 @@ def sso(request): return sso_after_process_request(request, login, nid_format=nid_format) -def need_login(request, login, nid_format, service): +def need_login(request, login, nid_format): """Redirect to the login page with a nonce parameter to verify later that the login form was submitted """ @@ -640,7 +642,6 @@ def need_login(request, login, nid_format, service): request, next_url=next_url, params={NONCE_FIELD_NAME: nonce}, - service=service, login_hint=get_login_hints_extension(login), ) @@ -789,7 +790,7 @@ def sso_after_process_request( if not passive and (user.is_anonymous or (force_authn and not did_auth)): logger.debug('login required') - return need_login(request, login, nid_format, service) + return need_login(request, login, nid_format) # No user is authenticated and passive is True, deny request if passive and user.is_anonymous: @@ -1296,6 +1297,7 @@ def slo_soap(request): except ObjectDoesNotExist: logger.warning('provider %r unknown', logout.remoteProviderId) return return_logout_error(request, logout, AUTHENTIC_STATUS_CODE_UNAUTHORIZED) + set_service(request, provider) policy = get_sp_options_policy(provider) if not policy: logger.warning('No policy found for %s', logout.remoteProviderId) @@ -1385,6 +1387,7 @@ def slo(request): except ObjectDoesNotExist: logger.debug('provider %r unknown', logout.remoteProviderId) return return_logout_error(request, logout, AUTHENTIC_STATUS_CODE_UNAUTHORIZED) + set_service(request, provider) policy = get_sp_options_policy(provider) if not policy: logger.debug('No policy found for %s', logout.remoteProviderId) diff --git a/src/authentic2/journal.py b/src/authentic2/journal.py index b0172d52..b6eb707d 100644 --- a/src/authentic2/journal.py +++ b/src/authentic2/journal.py @@ -15,7 +15,7 @@ # along with this program. If not, see . from authentic2.apps.journal.journal import Journal as BaseJournal -from authentic2.utils.service import get_service_from_request +from authentic2.utils.service import get_service class Journal(BaseJournal): @@ -25,7 +25,7 @@ class Journal(BaseJournal): @property def service(self): - return self._service or (get_service_from_request(self.request) if self.request else None) + return self._service or get_service(self.request) if self.request else None def massage_kwargs(self, record_parameters, kwargs): if 'service' not in kwargs and 'service' in record_parameters: diff --git a/src/authentic2/middleware.py b/src/authentic2/middleware.py index 6be6eb5b..f0e95520 100644 --- a/src/authentic2/middleware.py +++ b/src/authentic2/middleware.py @@ -28,12 +28,10 @@ from django.conf import settings from django.contrib import messages from django.db.models import Model from django.utils.deprecation import MiddlewareMixin -from django.utils.functional import SimpleLazyObject from django.utils.translation import ugettext as _ from . import app_settings, plugins from .utils import misc as utils_misc -from .utils.service import get_service_from_request, get_service_from_session class CollectIPMiddleware(MiddlewareMixin): @@ -263,18 +261,6 @@ class CookieTestMiddleware(MiddlewareMixin): return response -class SaveServiceInSessionMiddleware: - def __init__(self, get_response): - self.get_response = get_response - - def __call__(self, request): - service = get_service_from_request(request) - if service: - request.session['service_pk'] = service.pk - request.service = SimpleLazyObject(lambda: get_service_from_session(request)) - return self.get_response(request) - - def journal_middleware(get_response): from . import journal diff --git a/src/authentic2/models.py b/src/authentic2/models.py index c7dddb49..3a31a1c6 100644 --- a/src/authentic2/models.py +++ b/src/authentic2/models.py @@ -451,6 +451,9 @@ class Service(models.Model): def get_absolute_url(self): return reverse('a2-manager-service', kwargs={'service_pk': self.pk}) + def get_base_urls(self): + return [] + Service._meta.natural_key = [['slug', 'ou']] diff --git a/src/authentic2/settings.py b/src/authentic2/settings.py index a05511d9..c4b1619e 100644 --- a/src/authentic2/settings.py +++ b/src/authentic2/settings.py @@ -81,6 +81,7 @@ TEMPLATES = [ 'django.contrib.messages.context_processors.messages', 'django.template.context_processors.static', 'authentic2.context_processors.a2_processor', + 'authentic2.context_processors.home', ], }, }, @@ -96,7 +97,6 @@ MIDDLEWARE = ( 'django.middleware.common.CommonMiddleware', 'django.middleware.http.ConditionalGetMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', - 'authentic2.middleware.SaveServiceInSessionMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.locale.LocaleMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', diff --git a/src/authentic2/templates/authentic2/base.html b/src/authentic2/templates/authentic2/base.html index 4c0269fc..bd56e803 100644 --- a/src/authentic2/templates/authentic2/base.html +++ b/src/authentic2/templates/authentic2/base.html @@ -12,7 +12,9 @@ {% endblock %} {% block bodyargs %} - data-service-slug="{{ service.slug }}" data-service-name="{{ service.name }}" + {% if home_url %}data-home-url="{{ home_url }}"{% endif %} + {% if home_service %}data-home-service-slug="{{ home_service.slug }}" data-home-service-name="{{ home_service.name }}"{% endif %} + {% if home_ou %}data-home-ou-slug="{{ home_ou.slug }}" data-home-ou-name="{{ home_ou.name }}"{% endif %} {% endblock %} {% block extrascripts %} diff --git a/src/authentic2/utils/misc.py b/src/authentic2/utils/misc.py index 8cd6e871..0db3b3da 100644 --- a/src/authentic2/utils/misc.py +++ b/src/authentic2/utils/misc.py @@ -50,7 +50,6 @@ from django.utils.translation import ungettext from authentic2.saml.saml2utils import filter_attribute_private_key, filter_element_private_key from .. import app_settings, constants, crypto, plugins -from .service import set_service_ref class CleanLogMessage(logging.Filter): @@ -455,15 +454,13 @@ def last_authentication_event(request=None, session=None): return None -def login(request, user, how, service=None, service_slug=None, nonce=None, record=True, **kwargs): +def login(request, user, how, nonce=None, record=True, **kwargs): """Login a user model, record the authentication event and redirect to next URL or settings.LOGIN_REDIRECT_URL.""" from .. import hooks + from .service import get_service from .views import check_cookie_works - if service: - assert service_slug is None - service_slug = service.slug check_cookie_works(request) last_login = user.last_login auth_login(request, user) @@ -472,23 +469,21 @@ def login(request, user, how, service=None, service_slug=None, nonce=None, recor if constants.LAST_LOGIN_SESSION_KEY not in request.session: request.session[constants.LAST_LOGIN_SESSION_KEY] = localize(to_current_timezone(last_login), True) record_authentication_event(request, how, nonce=nonce) - hooks.call_hooks('event', name='login', user=user, how=how, service=service_slug) + hooks.call_hooks('event', name='login', user=user, how=how, service=get_service(request)) # prevent logint-hint to influence next use of the login page if 'login-hint' in request.session: del request.session['login-hint'] if record: - request.journal.record('user.login', how=how, service=service) + request.journal.record('user.login', how=how) return continue_to_next_url(request, **kwargs) -def login_require(request, next_url=None, login_url='auth_login', service=None, login_hint=(), **kwargs): +def login_require(request, next_url=None, login_url='auth_login', login_hint=(), **kwargs): '''Require a login and come back to current URL''' next_url = next_url or request.get_full_path() params = kwargs.setdefault('params', {}) params[REDIRECT_FIELD_NAME] = next_url - if service: - set_service_ref(params, service) if login_hint: request.session['login-hint'] = list(login_hint) elif 'login-hint' in request.session: @@ -735,14 +730,12 @@ def get_fk_model(model, fieldname): return field.related_model -def get_registration_url(request, service=None): +def get_registration_url(request): next_url = select_next_url(request, settings.LOGIN_REDIRECT_URL) next_url = make_url( next_url, request=request, keep_params=True, include=(constants.NONCE_FIELD_NAME,), resolve=False ) params = {REDIRECT_FIELD_NAME: next_url} - if service: - set_service_ref(params, service) return make_url('registration_register', params=params) @@ -1041,9 +1034,17 @@ def get_next_url(params, field_name=None): return next_url -def select_next_url(request, default, field_name=None, include_post=False, replace=None): +EMPTY = object() + + +def select_next_url(request, default=EMPTY, field_name=None, include_post=False, replace=None): '''Select the first valid next URL''' # pylint: disable=consider-using-ternary + if default is EMPTY: + if request.user.is_authenticated and request.user.ou and request.user.ou.home_url: + default = request.user.ou.home_url + else: + default = settings.LOGIN_REDIRECT_URL next_url = (include_post and get_next_url(request.POST, field_name=field_name)) or get_next_url( request.GET, field_name=field_name ) @@ -1143,7 +1144,7 @@ def same_origin(url1, url2): return True -def simulate_authentication(request, user, method, backend=None, service=None, record=False, **kwargs): +def simulate_authentication(request, user, method, backend=None, record=False, **kwargs): """Simulate a normal login by eventually forcing a backend attribute on the user instance""" if not getattr(user, 'backend', None) and not backend: @@ -1151,7 +1152,7 @@ def simulate_authentication(request, user, method, backend=None, service=None, r if backend: user = copy.deepcopy(user) user.backend = backend - return login(request, user, method, service=service, record=record, **kwargs) + return login(request, user, method, record=record, **kwargs) def get_manager_login_url(): diff --git a/src/authentic2/utils/service.py b/src/authentic2/utils/service.py index 5bfdd805..8925d8b9 100644 --- a/src/authentic2/utils/service.py +++ b/src/authentic2/utils/service.py @@ -14,64 +14,64 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from authentic2.constants import SERVICE_FIELD_NAME +from django.apps import apps - -def service_ref(service): - if service.ou: - return '%s %s' % (service.ou.slug, service.slug) - else: - return service.slug +from authentic2.decorators import GlobalCache +from authentic2.utils.misc import same_origin -def get_service_from_ref(ref): +@GlobalCache(timeout=60) +def _base_urls_map(): from authentic2.models import Service - splitted = ref.split(' ') - - try: - ou_slug, service_slug = splitted - except ValueError: - pass - else: - return Service.objects.filter(ou__slug=ou_slug, slug=service_slug).first() + base_urls_map = {} + for service in Service.objects.select_related().select_subclasses(): + for url in service.get_base_urls(): + base_urls_map[url] = (type(service), service.pk) + return base_urls_map - try: - (service_slug,) = splitted - except ValueError: - return None - service = Service.objects.filter(ou__isnull=True, slug=service_slug).first() +def _set_service(session, service): + if service and service.ou and service.ou.home_url: + session['home_url'] = service.ou.home_url + elif 'home_url' in session: + del session['home_url'] if service: - return service - try: - return Service.objects.get(slug=service_slug) - except (Service.DoesNotExist, Service.MultipleObjectsReturned): - return None - - -def get_service_from_request(request): - service_ref = request.GET.get(SERVICE_FIELD_NAME) - if service_ref and '\x00' not in service_ref: - return get_service_from_ref(service_ref) - return None - - -def get_service_from_session(request): - session = getattr(request, 'session', None) - if session and 'service_pk' in session: - from authentic2.models import Service - - return Service.objects.get(pk=session['service_pk']) - return None - - -def get_service_from_token(params): - ref = params.get(SERVICE_FIELD_NAME) - if not ref: - return None - return get_service_from_ref(ref) - - -def set_service_ref(params, service): - params[SERVICE_FIELD_NAME] = service_ref(service) + session['service_type'] = [type(service)._meta.app_label, type(service)._meta.model_name] + session['service_pk'] = service.pk + else: + session.pop('sevice_type', None) + session.pop('sevice_pk', None) + + +def set_service(request, service): + request._service = service + _set_service(request.session, service) + + +def set_home_url(request, url=None): + if not url: + from .misc import select_next_url + + url = select_next_url(request) + if not url: + return + urls_map = _base_urls_map() + for base_url in urls_map: + if same_origin(base_url): + ServiceKlass, pk = urls_map[base_url] + set_service(request, ServiceKlass.object.get(pk=pk)) + request.session['home_url'] = url + break + + +def get_service(request): + if not hasattr(request, '_service'): + if 'service_type' in request.session and 'service_pk' in request.session: + ServiceKlass = apps.get_app_config(request.session['service_type'][0]).get_model( + request.session['service_type'][1] + ) + request._service = ServiceKlass.objects.get(pk=request.session['service_pk']) + else: + request._service = None + return getattr(request, '_service', None) diff --git a/src/authentic2/views.py b/src/authentic2/views.py index 1eab8dad..5dbed140 100644 --- a/src/authentic2/views.py +++ b/src/authentic2/views.py @@ -63,7 +63,7 @@ from .forms import registration as registration_forms from .utils import misc as utils_misc from .utils import switch_user as utils_switch_user from .utils.evaluate import make_condition_context -from .utils.service import get_service_from_request, get_service_from_token, set_service_ref +from .utils.service import get_service, set_home_url from .utils.view_decorators import enable_view_restriction User = get_user_model() @@ -71,7 +71,13 @@ User = get_user_model() logger = logging.getLogger(__name__) -class EditProfile(cbv.HookMixin, cbv.TemplateNamesMixin, UpdateView): +class HomeURLMixin: + def dispatch(self, request, *args, **kwargs): + set_home_url(request) + return super().dispatch(request, *args, **kwargs) + + +class EditProfile(HomeURLMixin, cbv.HookMixin, cbv.TemplateNamesMixin, UpdateView): model = User template_names = ['profiles/edit_profile.html', 'authentic2/accounts_edit.html'] title = _('Edit account data') @@ -182,7 +188,7 @@ class EditRequired(EditProfile): edit_required_profile = login_required(EditRequired.as_view()) -class EmailChangeView(cbv.TemplateNamesMixin, FormView): +class EmailChangeView(HomeURLMixin, cbv.TemplateNamesMixin, FormView): template_names = ['profiles/email_change.html', 'authentic2/change_email.html'] title = _('Email Change') success_url = '..' @@ -295,8 +301,6 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE redirect_to = request.GET.get(redirect_field_name) - service = get_service_from_request(request) - if not redirect_to or ' ' in redirect_to: redirect_to = settings.LOGIN_REDIRECT_URL # Heavier security check -- redirects to http://example.com should @@ -311,7 +315,7 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE blocks = [] - registration_url = utils_misc.get_registration_url(request, service=service) + registration_url = utils_misc.get_registration_url(request) context = { 'cancel': app_settings.A2_LOGIN_DISPLAY_A_CANCEL_BUTTON and nonce is not None, @@ -346,6 +350,7 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE parameters = {'request': request, 'context': context} login_hint = set(request.session.get('login-hint', [])) show_ctx = make_condition_context(request=request, login_hint=login_hint) + service = get_service(request) if service: show_ctx['service_ou_slug'] = service.ou and service.ou.slug show_ctx['service_slug'] = service.slug @@ -421,7 +426,12 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView): template_names = ['idp/homepage.html', 'authentic2/homepage.html'] def dispatch(self, request, *args, **kwargs): - if app_settings.A2_HOMEPAGE_URL: + home_url = request.session.get('home_url') + home_url = home_url or ( + request.user.is_authenticated and request.user and request.user.ou and request.user.ou.home_url + ) + home_url = app_settings.A2_HOMEPAGE_URL + if home_url: return utils_misc.redirect(request, app_settings.A2_HOMEPAGE_URL) return login_required(super().dispatch)(request, *args, **kwargs) @@ -435,7 +445,7 @@ class Homepage(cbv.TemplateNamesMixin, TemplateView): homepage = enable_view_restriction(Homepage.as_view()) -class ProfileView(cbv.TemplateNamesMixin, TemplateView): +class ProfileView(HomeURLMixin, cbv.TemplateNamesMixin, TemplateView): template_names = ['idp/account_management.html', 'authentic2/accounts.html'] title = _('Your account') @@ -889,7 +899,7 @@ class PasswordResetConfirmView(cbv.RedirectToNextURLViewMixin, FormView): password_reset_confirm = PasswordResetConfirmView.as_view() -class BaseRegistrationView(FormView): +class BaseRegistrationView(HomeURLMixin, FormView): form_class = registration_forms.RegistrationForm template_name = 'registration/registration_form.html' title = _('Registration') @@ -912,6 +922,7 @@ class BaseRegistrationView(FormView): if 'ou' in self.token: self.ou = OU.objects.get(pk=self.token['ou']) self.next_url = self.token.pop(REDIRECT_FIELD_NAME, utils_misc.select_next_url(request, None)) + set_home_url(request, self.next_url) return super().dispatch(request, *args, **kwargs) def form_valid(self, form): @@ -978,11 +989,6 @@ class BaseRegistrationView(FormView): for field in form.cleaned_data: self.token[field] = form.cleaned_data[field] - # propagate service to the registration completion view - service = get_service_from_request(self.request) - if service: - set_service_ref(self.token, service) - self.token.pop(REDIRECT_FIELD_NAME, None) self.token.pop('email', None) @@ -1066,8 +1072,7 @@ class RegistrationCompletionView(CreateView): if self.ou: self.email_is_unique |= self.ou.email_is_unique self.init_fields_labels_and_help_texts() - # if registration is done during an SSO add the service to the registration event - self.service = get_service_from_token(self.token) + set_home_url(request, self.get_success_url()) return super().dispatch(request, *args, **kwargs) def init_fields_labels_and_help_texts(self): @@ -1180,9 +1185,7 @@ class RegistrationCompletionView(CreateView): def get(self, request, *args, **kwargs): if len(self.users) == 1 and self.email_is_unique: # Found one user, EMAIL is unique, log her in - utils_misc.simulate_authentication( - request, self.users[0], method=self.authentication_method, service=self.service - ) + utils_misc.simulate_authentication(request, self.users[0], method=self.authentication_method) return utils_misc.redirect(request, self.get_success_url()) confirm_data = self.token.get('confirm_data', False) @@ -1220,9 +1223,7 @@ class RegistrationCompletionView(CreateView): uid = request.POST['uid'] for user in self.users: if str(user.id) == uid: - utils_misc.simulate_authentication( - request, user, method=self.authentication_method, service=self.service - ) + utils_misc.simulate_authentication(request, user, method=self.authentication_method) return utils_misc.redirect(request, self.get_success_url()) return super().post(request, *args, **kwargs) @@ -1284,14 +1285,12 @@ class RegistrationCompletionView(CreateView): view=self, authentication_method=self.authentication_method, token=self.token, - service=self.service and self.service.slug, + service=get_service(request), ) self.send_registration_success_email(user) def registration_success(self, request, user): - utils_misc.simulate_authentication( - request, user, method=self.authentication_method, service=self.service - ) + utils_misc.simulate_authentication(request, user, method=self.authentication_method) message_template = loader.get_template('authentic2/registration_success_message.html') messages.info(self.request, message_template.render(request=request)) return utils_misc.redirect(request, self.get_success_url()) @@ -1319,7 +1318,7 @@ class RegistrationCompletionView(CreateView): registration_completion = RegistrationCompletionView.as_view() -class AccountDeleteView(TemplateView): +class AccountDeleteView(HomeURLMixin, TemplateView): template_name = 'authentic2/accounts_delete_request.html' title = _('Request account deletion') @@ -1407,7 +1406,7 @@ class RegistrationCompleteView(TemplateView): registration_complete = RegistrationCompleteView.as_view() -class PasswordChangeView(DjPasswordChangeView): +class PasswordChangeView(HomeURLMixin, DjPasswordChangeView): title = _('Password Change') do_not_call_in_templates = True @@ -1471,7 +1470,7 @@ class SuView(View): su = SuView.as_view() -class Consents(ListView): +class Consents(HomeURLMixin, ListView): template_name = 'authentic2/consents.html' title = _('Consent Management') model = OIDCAuthorization diff --git a/src/authentic2_auth_fc/views.py b/src/authentic2_auth_fc/views.py index b69f9154..7b9c2413 100644 --- a/src/authentic2_auth_fc/views.py +++ b/src/authentic2_auth_fc/views.py @@ -42,7 +42,6 @@ from authentic2.forms.passwords import SetPasswordForm from authentic2.utils import misc as utils_misc from authentic2.utils import views as utils_views from authentic2.utils.models import safe_get_or_create -from authentic2.utils.service import get_service_from_ref, get_service_from_request, service_ref from . import app_settings, models from .utils import ( @@ -69,7 +68,6 @@ class LoginOrLinkView(View): """ _next_url = None - service = None @property def next_url(self): @@ -114,7 +112,7 @@ class LoginOrLinkView(View): def handle_authorization_response(self, request, code, state): # check state signature and parse it try: - state, self._next_url, self.service = self.decode_state(state) + state, self._next_url = self.decode_state(state) except ValueError: return utils_misc.redirect(request, settings.LOGIN_REDIRECT_URL) @@ -186,10 +184,8 @@ class LoginOrLinkView(View): else: return self.login(request) - def encode_state(self, state, next_url, service): + def encode_state(self, state, next_url): encoded_state = state + ' ' + self.next_url + ' ' - if service: - encoded_state += service_ref(service) encoded_state += ' ' + hmac_url(settings.SECRET_KEY, encoded_state) return encoded_state @@ -197,32 +193,26 @@ class LoginOrLinkView(View): payload, signature = state.rsplit(' ', 1) if not check_hmac_url(settings.SECRET_KEY, payload, signature): raise ValueError - # service_ref can be made of one or two parts try: state, next_url, service_ref = payload.split(' ') except ValueError: state, next_url, ou_slug, service_slug = payload.split(' ') - service_ref = ou_slug + ' ' + service_slug - service = get_service_from_ref(service_ref) - return state, next_url, service + return state, next_url def make_authorization_request(self, request): scope = ' '.join(set(['openid'] + app_settings.scopes)) - service = self.service or get_service_from_request(request) nonce_seed, nonce, state = hash_chain(3) # encode the target service and next_url in the state full_state = state + ' ' + self.next_url + ' ' - if service: - full_state += service_ref(service) full_state += ' ' + hmac_url(settings.SECRET_KEY, full_state) params = { 'client_id': app_settings.client_id, 'scope': scope, 'redirect_uri': self.redirect_uri, 'response_type': 'code', - 'state': self.encode_state(state, self.next_url, service), + 'state': self.encode_state(state, self.next_url), 'nonce': nonce, 'acr_values': 'eidas1', } @@ -340,7 +330,7 @@ class LoginOrLinkView(View): def finish_login(self, request, user, user_info, created): self.update_user_info(user, user_info) utils_views.check_cookie_works(request) - utils_misc.login(request, user, 'france-connect', service=self.service) + utils_misc.login(request, user, 'france-connect') # keep id_token around for logout request.session['fc_id_token'] = self.id_token diff --git a/src/authentic2_idp_cas/views.py b/src/authentic2_idp_cas/views.py index f9535c4b..63a04c6b 100644 --- a/src/authentic2_idp_cas/views.py +++ b/src/authentic2_idp_cas/views.py @@ -36,6 +36,7 @@ from authentic2.utils.misc import ( normalize_attribute_values, redirect, ) +from authentic2.utils.service import set_service from authentic2.utils.view_decorators import enable_view_restriction from authentic2.views import logout as logout_view from authentic2_idp_cas.constants import ( @@ -151,6 +152,7 @@ class LoginView(CasMixin, View): model = Service.objects.for_service(service) if not model: return self.failure(request, service, 'service unknown') + set_service(request, model) if renew and gateway: return self.failure(request, service, 'renew and gateway cannot be requested at the same time') @@ -464,6 +466,7 @@ class LogoutView(View): if referrer: model = Service.objects.for_service(referrer) if model: + set_service(request, model) return logout_view(request, next_url=next_url, check_referer=False, do_local=False) return redirect(request, next_url) diff --git a/src/authentic2_idp_oidc/views.py b/src/authentic2_idp_oidc/views.py index dccf2d16..a28709be 100644 --- a/src/authentic2_idp_oidc/views.py +++ b/src/authentic2_idp_oidc/views.py @@ -49,6 +49,7 @@ from authentic2.a2_rbac.models import OrganizationalUnit from authentic2.decorators import setting_enabled from authentic2.exponential_retry_timeout import ExponentialRetryTimeout from authentic2.utils.misc import last_authentication_event, login_require, make_url, redirect +from authentic2.utils.service import set_service from authentic2.utils.view_decorators import check_view_restriction from authentic2.views import logout as a2_logout @@ -254,6 +255,8 @@ def authorize(request, *args, **kwargs): client = get_client(client_id=client_id) if not client: raise InvalidRequest(_('Unknown client identifier: "%s"') % client_id) + # define the current service + set_service(request, client) try: client.validate_redirect_uri(redirect_uri) except ValueError: @@ -341,7 +344,7 @@ def authorize_for_client(request, client, redirect_uri): params = {} if nonce is not None: params['nonce'] = nonce - return login_require(request, params=params, service=client, login_hint=login_hint) + return login_require(request, params=params, login_hint=login_hint) # view restriction and passive SSO if hasattr(request, 'view_restriction_response'): @@ -360,7 +363,7 @@ def authorize_for_client(request, client, redirect_uri): params = {} if nonce is not None: params['nonce'] = nonce - return login_require(request, params=params, service=client, login_hint=login_hint) + return login_require(request, params=params, login_hint=login_hint) iat = now() # iat = issued at @@ -820,6 +823,7 @@ def logout(request): ) for provider in providers: if post_logout_redirect_uri in provider.post_logout_redirect_uris.split(): + set_service(request, provider) break else: messages.warning(request, _('Invalid post logout URI')) diff --git a/tests/auth_fc/conftest.py b/tests/auth_fc/conftest.py index 131caf31..83a23ad3 100644 --- a/tests/auth_fc/conftest.py +++ b/tests/auth_fc/conftest.py @@ -160,12 +160,16 @@ class FranceConnectMock: @pytest.fixture -def franceconnect(settings, db): +def service(db): + return Service.objects.create(name='portail', slug='portail', ou=get_default_ou()) + + +@pytest.fixture +def franceconnect(settings, service, db): settings.A2_FC_ENABLE = True settings.A2_FC_CLIENT_ID = CLIENT_ID settings.A2_FC_CLIENT_SECRET = CLIENT_SECRET - Service.objects.create(name='portail', slug='portail', ou=get_default_ou()) mock_object = FranceConnectMock() with mock_object(): yield mock_object diff --git a/tests/auth_fc/test_auth_fc.py b/tests/auth_fc/test_auth_fc.py index a8a58d52..1079a95a 100644 --- a/tests/auth_fc/test_auth_fc.py +++ b/tests/auth_fc/test_auth_fc.py @@ -32,12 +32,12 @@ from authentic2.a2_rbac.models import OrganizationalUnit as OU from authentic2.a2_rbac.utils import get_default_ou from authentic2.apps.journal.models import Event from authentic2.custom_user.models import DeletedUser -from authentic2.models import Attribute, Service +from authentic2.models import Attribute from authentic2_auth_fc import models from authentic2_auth_fc.backends import FcBackend from authentic2_auth_fc.utils import requests_retry_session -from ..utils import get_link_from_mail, login +from ..utils import get_link_from_mail, login, set_service User = get_user_model() @@ -54,7 +54,7 @@ def test_fc_url_on_login(app, franceconnect): def test_retry_authorization_if_state_is_lost(settings, app, franceconnect, hooks): - response = app.get('/fc/callback/?next=/idp/&service=default%20portail', status=302) + response = app.get('/fc/callback/?next=/idp/', status=302) # clear fc-state cookie app.cookiejar.clear() response = franceconnect.handle_authorization(app, response.location, status=302) @@ -81,26 +81,26 @@ def test_login_autorun(settings, app, franceconnect): assert response.location == reverse('fc-login-or-link') -def test_create(settings, app, franceconnect, hooks): +def test_create(settings, app, franceconnect, hooks, service): # test direct creation - - response = app.get('/login/?service=portail&next=/idp/') + set_service(app, service) + response = app.get('/login/?next=/idp/') response = response.click(href='callback') assert User.objects.count() == 0 - assert Event.objects.which_references(Service.objects.get()).count() == 0 + assert Event.objects.which_references(service).count() == 0 response = franceconnect.handle_authorization(app, response.location, status=302) assert 'fc-state' not in app.cookies assert User.objects.count() == 1 # check login for service=portail was registered - assert Event.objects.which_references(Service.objects.get()).count() == 1 + assert Event.objects.which_references(service).count() == 1 user = User.objects.get() assert user.verified_attributes.first_name == 'Ÿuñe' assert user.verified_attributes.last_name == 'Frédérique' assert path(response.location) == '/idp/' assert hooks.event[1]['kwargs']['name'] == 'login' - assert hooks.event[1]['kwargs']['service'] == 'portail' + assert hooks.event[1]['kwargs']['service'] == service # we must be connected assert app.session['_auth_user_id'] assert app.session.get_expire_at_browser_close() @@ -130,7 +130,7 @@ def test_create_expired(settings, app, franceconnect, hooks): # test direct creation failure on an expired id_token franceconnect.exp = now() - datetime.timedelta(seconds=30) - response = app.get('/login/?service=portail&next=/idp/') + response = app.get('/login/?next=/idp/') response = response.click(href='callback') assert User.objects.count() == 0 diff --git a/tests/idp_oidc/test_misc.py b/tests/idp_oidc/test_misc.py index 896b8a81..0f18cd22 100644 --- a/tests/idp_oidc/test_misc.py +++ b/tests/idp_oidc/test_misc.py @@ -965,6 +965,9 @@ def test_role_control_access(login_first, oidc_settings, oidc_client, simple_use def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simple_user, hooks, mailoutbox): redirect_uri = simple_oidc_client.redirect_uris.split()[0] + simple_oidc_client.ou.home_url = 'https://portal/' + simple_oidc_client.ou.save() + params = { 'client_id': simple_oidc_client.client_id, 'scope': 'openid profile email', @@ -977,19 +980,18 @@ def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simpl authorize_url = make_url('oidc-authorize', params=params) response = app.get(authorize_url) - location = urllib.parse.urlparse(response['Location']) - query = urllib.parse.parse_qs(location.query) - assert query['service'] == ['default client'] response = response.follow().click('Register') - location = urllib.parse.urlparse(response.request.url) - query = urllib.parse.parse_qs(location.query) - assert query['service'] == ['default client'] - response.form.set('email', 'john.doe@example.com') response = response.form.submit() assert len(mailoutbox) == 1 link = utils.get_link_from_mail(mailoutbox[0]) response = app.get(link) + body = response.pyquery('body')[0] + assert body.attrib['data-home-ou-slug'] == 'default' + assert body.attrib['data-home-ou-name'] == 'Default organizational unit' + assert body.attrib['data-home-service-slug'] == 'client' + assert body.attrib['data-home-service-name'] == 'client' + assert body.attrib['data-home-url'] == 'https://portal/' response.form.set('first_name', 'John') response.form.set('last_name', 'Doe') response.form.set('password1', 'T0==toto') @@ -999,11 +1001,11 @@ def test_registration_service_slug(oidc_settings, app, simple_oidc_client, simpl assert hooks.event[0]['kwargs']['service'].slug == 'client' assert hooks.event[1]['kwargs']['name'] == 'registration' - assert hooks.event[1]['kwargs']['service'] == 'client' + assert hooks.event[1]['kwargs']['service'].slug == 'client' assert hooks.event[2]['kwargs']['name'] == 'login' assert hooks.event[2]['kwargs']['how'] == 'email' - assert hooks.event[2]['kwargs']['service'] == 'client' + assert hooks.event[2]['kwargs']['service'].slug == 'client' def test_claim_default_value(oidc_settings, normal_oidc_client, simple_user, app): diff --git a/tests/test_idp_saml2.py b/tests/test_idp_saml2.py index e9abce31..dfb1227d 100644 --- a/tests/test_idp_saml2.py +++ b/tests/test_idp_saml2.py @@ -33,7 +33,7 @@ from django.utils.encoding import force_bytes, force_str, force_text from django.utils.translation import gettext as _ from authentic2.a2_rbac.models import OrganizationalUnit, Role -from authentic2.constants import NONCE_FIELD_NAME, SERVICE_FIELD_NAME +from authentic2.constants import NONCE_FIELD_NAME from authentic2.custom_user.models import User from authentic2.idp.saml import saml2_endpoints from authentic2.idp.saml.saml2_endpoints import get_extensions, get_login_hints_extension @@ -330,7 +330,6 @@ class Scenario: reverse('auth_login'), **{ 'nonce': '*', - SERVICE_FIELD_NAME: 'default ' + self.sp.slug, REDIRECT_FIELD_NAME: make_url( 'a2-idp-saml-continue', params={NONCE_FIELD_NAME: request_id} ), diff --git a/tests/test_login.py b/tests/test_login.py index 2fcf79be..e20c0879 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -22,7 +22,7 @@ from django.contrib.auth import get_user_model from authentic2 import models from authentic2.utils.misc import get_token_login_url -from .utils import assert_event, login +from .utils import assert_event, login, set_service User = get_user_model() @@ -85,22 +85,22 @@ def test_show_condition(db, app, settings, caplog): assert len(caplog.records) == 1 -def test_show_condition_service(db, app, settings): +def test_show_condition_service(db, rf, app, settings): + portal = models.Service.objects.create(pk=1, name='Service', slug='portal') + service = models.Service.objects.create(pk=2, name='Service', slug='service') settings.AUTH_FRONTENDS_KWARGS = {'password': {'show_condition': 'service_slug == \'portal\''}} - response = app.get('/login/', params={}) - assert 'name="login-password-submit"' not in response - # service doesn't exist - response = app.get('/login/', params={'service': 'portal'}) + response = app.get('/login/') assert 'name="login-password-submit"' not in response - # Create a service - models.Service.objects.create(name='Service', slug='portal') - response = app.get('/login/', params={'service': 'portal'}) + set_service(app, portal) + + response = app.get('/login/') assert 'name="login-password-submit"' in response - models.Service.objects.create(name='Service', slug='service') - response = app.get('/login/', params={'service': 'service'}) + set_service(app, service) + + response = app.get('/login/') assert 'name="login-password-submit"' not in response @@ -251,29 +251,31 @@ def test_ou_selector(app, settings, simple_user, ou1, ou2, user_ou1, role_ou1): response = app.get('/login/') assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit' + set_service(app, service) # service is specified but not access-control is defined, default for user is selected - response = app.get('/login/?service=service') + response = app.get('/login/') assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit' # service is specified, access control is defined but role is empty, default for user is selected service.authorized_roles.through.objects.create(service=service, role=role_ou1) - response = app.get('/login/?service=service') + response = app.get('/login/') assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit' # user is added to role_ou1, default for user is still selected user_ou1.roles.add(role_ou1) - response = app.get('/login/?service=service') + response = app.get('/login/') assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'Default organizational unit' # Clear cookies, OU1 is selected app.cookiejar.clear() - response = app.get('/login/?service=service') + set_service(app, service) + response = app.get('/login/') assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'OU1' # if we change the user's ou, then default selected OU changes user_ou1.ou = ou2 user_ou1.save() - response = app.get('/login/?service=service') + response = app.get('/login/') assert response.pyquery.find('select#id_ou option[selected]')[0].text == 'OU2' diff --git a/tests/test_template.py b/tests/test_template.py index e341bf58..d0612c6e 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -21,6 +21,8 @@ from authentic2.a2_rbac.utils import get_default_ou from authentic2.models import Service from authentic2.utils.template import Template, TemplateError +from . import utils + pytestmark = pytest.mark.django_db @@ -114,7 +116,9 @@ def test_render_template_missing_variable(): def test_service_in_template(app, simple_user, service): - resp = app.get(reverse('auth_login') + '?service=%s' % service.slug) + utils.set_service(app, service) + + resp = app.get(reverse('auth_login')) assert resp.pyquery('body').attr('data-service-slug') == service.slug assert resp.pyquery('body').attr('data-service-name') == service.name @@ -129,6 +133,7 @@ def test_service_in_template(app, simple_user, service): # if user comes back from a different service, the information is updated new_service = Service.objects.create(ou=get_default_ou(), slug='service2', name='Service2') - resp = app.get(reverse('account_management') + '?service=%s' % new_service.slug) + utils.set_service(app, new_service) + resp = app.get(reverse('account_management')) assert resp.pyquery('body').attr('data-service-slug') == new_service.slug assert resp.pyquery('body').attr('data-service-name') == new_service.name diff --git a/tests/utils.py b/tests/utils.py index 4ce94753..3f46269f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -307,3 +307,21 @@ def assert_event(event_type_name, user=None, session=None, service=None, target_ ) elif data and count > 1: assert qs.filter(**{'data__' + k: v for k, v in data.items()}).count() == 1 + + +def set_service(app, service): + from importlib import import_module + + from django.conf import settings + + from authentic2.utils.service import _set_service + + engine = import_module(settings.SESSION_ENGINE) + if app.session == {}: + session = engine.SessionStore() + else: + session = app.session + _set_service(session, service) + session.save() + if app.session == {}: + app.set_cookie(settings.SESSION_COOKIE_NAME, session.session_key) -- 2.34.1