From 49ddf9fc27a4e3e1d477bbbbf3f817677f034f57 Mon Sep 17 00:00:00 2001 From: Paul Marillonnet Date: Mon, 30 May 2022 17:59:15 +0200 Subject: [PATCH 2/2] auth_oidc: add an oidc-sync-provider command (#62710) --- .../management/commands/oidc-sync-provider.py | 45 ++++++ .../migrations/0013_synchronization_fields.py | 23 ++++ src/authentic2_auth_oidc/models.py | 92 ++++++++++++- tests/test_commands.py | 130 +++++++++++++++++- 4 files changed, 287 insertions(+), 3 deletions(-) create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py create mode 100644 src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py diff --git a/src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py b/src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py new file mode 100644 index 00000000..02d80d12 --- /dev/null +++ b/src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py @@ -0,0 +1,45 @@ +# authentic2 - versatile identity manager +# Copyright (C) 2010-2022 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import logging + +from authentic2.base_commands import LogToConsoleCommand +from authentic2_auth_oidc.models import OIDCProvider + + +class Command(LogToConsoleCommand): + loggername = 'authentic2_auth_oidc.models' + + def add_arguments(self, parser): + parser.add_argument('--provider', type=str, default=None) + + def core_command(self, *args, **kwargs): + provider = kwargs['provider'] + + logger = logging.getLogger(self.loggername) + providers = OIDCProvider.objects.filter(a2_synchronization_supported=True) + if provider: + providers = providers.filter(slug=provider) + if not providers.count(): + logger.error('no provider supporting synchronization found, exiting') + return + logger.info( + 'got %s provider(s): %s', + providers.count(), + ' '.join(providers.values_list('slug', flat=True)), + ) + for provider in providers: + provider.perform_synchronization() diff --git a/src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py b/src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py new file mode 100644 index 00000000..25d8d027 --- /dev/null +++ b/src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py @@ -0,0 +1,23 @@ +# Generated by Django 2.2.26 on 2022-08-03 09:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentic2_auth_oidc', '0016_auto_20221019_1148'), + ] + + operations = [ + migrations.AddField( + model_name='oidcprovider', + name='a2_synchronization_supported', + field=models.BooleanField(default=False, verbose_name='Authentic2 synchronization supported'), + ), + migrations.AddField( + model_name='oidcprovider', + name='last_sync_time', + field=models.DateTimeField(blank=True, null=True, verbose_name='Last synchronization time'), + ), + ] diff --git a/src/authentic2_auth_oidc/models.py b/src/authentic2_auth_oidc/models.py index 37a23987..9709f980 100644 --- a/src/authentic2_auth_oidc/models.py +++ b/src/authentic2_auth_oidc/models.py @@ -15,13 +15,17 @@ # along with this program. If not, see . import json +import logging import uuid +from datetime import datetime, timedelta +import requests from django.conf import settings from django.contrib.postgres.fields import JSONField from django.core.exceptions import ValidationError from django.db import models from django.shortcuts import render +from django.utils.timezone import now from django.utils.translation import gettext_lazy as _ from django.utils.translation import pgettext_lazy from jwcrypto.jwk import InvalidJWKValue, JWKSet @@ -33,7 +37,7 @@ from authentic2.apps.authenticators.models import ( BaseAuthenticator, ) from authentic2.utils.misc import make_url, redirect_to_login -from authentic2.utils.template import validate_template +from authentic2.utils.template import Template, validate_template from . import managers @@ -116,6 +120,17 @@ class OIDCProvider(BaseAuthenticator): verbose_name=_('max authentication age'), blank=True, null=True ) + # authentic2 specific synchronization api + a2_synchronization_supported = models.BooleanField( + verbose_name=_('Authentic2 synchronization supported'), + default=False, + ) + last_sync_time = models.DateTimeField( + verbose_name=_('Last synchronization time'), + null=True, + blank=True, + ) + # metadata created = models.DateTimeField(verbose_name=_('creation date'), auto_now_add=True) modified = models.DateTimeField(verbose_name=_('last modification date'), auto_now=True) @@ -242,6 +257,81 @@ class OIDCProvider(BaseAuthenticator): ] return render(request, template_names, context) + def perform_synchronization(self, sync_time=None, timeout=30): + logger = logging.getLogger(__name__) + + if not self.a2_synchronization_supported: + logger.error('OIDC provider %s does not support synchronization', self.slug) + return + if not sync_time: + sync_time = now() - timedelta(minutes=1) + + # check all existing users + def chunks(l, n): + for i in range(0, len(l), n): + yield l[i : i + n] + + url = self.issuer + '/api/users/synchronization/' + + unknown_uuids = [] + auth = (self.client_id, self.client_secret) + for accounts in chunks(OIDCAccount.objects.filter(provider=self), 100): + subs = [x.sub for x in accounts] + resp = requests.post(url, json={'known_uuids': subs}, auth=auth, timeout=timeout) + resp.raise_for_status() + unknown_uuids.extend(resp.json().get('unknown_uuids')) + deletion_ratio = len(unknown_uuids) / OIDCAccount.objects.filter(provider=self).count() + if deletion_ratio > 0.05: # higher than 5%, something definitely went wrong + logger.error( + 'deletion ratio is abnormally high (%s), aborting unkwown users deletion', deletion_ratio + ) + else: + OIDCAccount.objects.filter(sub__in=unknown_uuids).delete() + + # update recently modified users + url = self.issuer + '/api/users/?modified__gt=%s&claim_resolution' % ( + self.last_sync_time or datetime.utcfromtimestamp(0) + ).strftime('%Y-%m-%dT%H:%M:%S') + while url: + resp = requests.get(url, auth=auth, timeout=timeout) + resp.raise_for_status() + url = resp.json().get('next') + logger.info('got %s users', len(resp.json()['results'])) + for user_dict in resp.json()['results']: + try: + account = OIDCAccount.objects.get(user__email=user_dict['email']) + except OIDCAccount.DoesNotExist: + continue + except OIDCAccount.MultipleObjectsReturned: + continue + had_changes = False + for claim in self.claim_mappings.all(): + if '{{' in claim.claim or '{%' in claim.claim: + template = Template(claim.claim) + attribute_value = template.render(context=user_dict) + else: + attribute_value = user_dict.get(claim.claim) + try: + old_attribute_value = getattr(account.user, claim.attribute) + except AttributeError: + try: + old_attribute_value = getattr(account.user.attributes, claim.attribute) + except AttributeError: + old_attribute_value = None + if old_attribute_value == attribute_value: + continue + had_changes = True + setattr(account.user, claim.attribute, attribute_value) + try: + setattr(account.user.attributes, claim.attribute, attribute_value) + except AttributeError: + pass + if had_changes: + logger.debug('had changes, saving %r', account.user) + account.user.save() + self.last_sync_time = sync_time + self.save(update_fields=['last_sync_time']) + class OIDCClaimMapping(AuthenticatorRelatedObjectBase): NOT_VERIFIED = 0 diff --git a/tests/test_commands.py b/tests/test_commands.py index fba3c68c..1aa0848a 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -17,8 +17,11 @@ import datetime import importlib import json +import random +import uuid from io import BufferedReader, BufferedWriter, TextIOWrapper +import httmock import py import pytest import webtest @@ -40,10 +43,11 @@ from authentic2.a2_rbac.utils import get_default_ou from authentic2.apps.journal.models import Event from authentic2.custom_user.models import DeletedUser from authentic2.models import UserExternalId -from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider +from authentic2.utils import crypto +from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider from django_rbac.utils import get_operation -from .utils import call_command, login +from .utils import call_command, check_log, login User = get_user_model() @@ -456,3 +460,125 @@ def test_clean_user_exports(settings, app, superuser, freezer): call_command('clean-user-exports') with pytest.raises(webtest.app.AppError): resp.click('Download CSV') + + +@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)]) +def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid): + oidc_provider = OIDCProvider.objects.create( + issuer='https://some.provider', + name='Some Provider', + slug='some-provider', + ou=get_default_ou(), + ) + OIDCClaimMapping.objects.create( + authenticator=oidc_provider, + attribute='username', + idtoken_claim=False, + claim='username', + ) + OIDCClaimMapping.objects.create( + authenticator=oidc_provider, + attribute='email', + idtoken_claim=False, + claim='email', + ) + # last one, with an idtoken claim + OIDCClaimMapping.objects.create( + authenticator=oidc_provider, + attribute='last_name', + idtoken_claim=True, + claim='family_name', + ) + # typo in template string + OIDCClaimMapping.objects.create( + authenticator=oidc_provider, + attribute='first_name', + idtoken_claim=True, + claim='given_name', + ) + User = get_user_model() + for i in range(100): + user = User.objects.create( + first_name='John%s' % i, + last_name='Doe%s' % i, + username='john.doe.%s' % i, + email='john.doe.%s@ad.dre.ss' % i, + ou=get_default_ou(), + ) + identifier = uuid.UUID(user.uuid).bytes + sector_identifier = 'some.provider' + cipher_args = [ + settings.SECRET_KEY.encode('utf-8'), + identifier, + sector_identifier, + ] + sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8') + OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub) + + def synchronization_post_deletion_response(url, request): + headers = {'content-type': 'application/json'} + content = { + 'unknown_uuids': [ + account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number) + ] + } + return httmock.response(status_code=200, headers=headers, content=content, request=request) + + def synchronization_get_modified_response(url, request): + headers = {'content-type': 'application/json'} + # randomized batch of modified users + modified_users = random.sample(list(User.objects.all()), 20) + results = [] + for count, user in enumerate(modified_users): + user_json = user.to_json() + user_json['username'] = f'modified_{count}' + user_json['first_name'] = 'Mod' + user_json['last_name'] = 'Ified' + # mocking claim resolution by oidc provider + user_json['given_name'] = 'Mod' + user_json['family_name'] = 'Ified' + results.append(user_json) + content = {'results': results} + return httmock.response(status_code=200, headers=headers, content=content, request=request) + + with httmock.HTTMock( + httmock.urlmatch( + netloc=r'some\.provider', + path=r'^/api/users/synchronization/$', + method='POST', + )(synchronization_post_deletion_response) + ): + + with httmock.HTTMock( + httmock.urlmatch( + netloc=r'some\.provider', + path=r'^/api/users/*', + method='GET', + )(synchronization_get_modified_response) + ): + with check_log(caplog, 'no provider supporting synchronization'): + call_command('oidc-sync-provider', '-v1') + + oidc_provider.a2_synchronization_supported = True + oidc_provider.save() + + with check_log(caplog, 'no provider supporting synchronization'): + call_command('oidc-sync-provider', '--provider', 'whatever', '-v1') + + with check_log(caplog, 'got 20 users'): + call_command('oidc-sync-provider', '-v1') + if deletion_valid: + # existing users check + assert OIDCAccount.objects.count() == 100 - deletion_number + else: + assert OIDCAccount.objects.count() == 100 + assert caplog.records[3].levelname == 'ERROR' + assert 'deletion ratio is abnormally high' in caplog.records[3].message + + # users update + assert User.objects.filter(username__startswith='modified').count() in range( + 20 - deletion_number, 21 + ) + assert User.objects.filter(first_name='Mod', last_name='Ified').count() in range( + 20 - deletion_number, 21 + ) -- 2.37.2