Projet

Général

Profil

0002-auth_oidc-add-an-oidc-sync-provider-command-62710.patch

Paul Marillonnet, 14 décembre 2022 11:30

Télécharger (15,3 ko)

Voir les différences:

Subject: [PATCH 2/2] auth_oidc: add an oidc-sync-provider command (#62710)

 .../management/commands/oidc-sync-provider.py |  45 ++++++
 .../migrations/0013_synchronization_fields.py |  23 +++
 src/authentic2_auth_oidc/models.py            |  94 +++++++++++-
 tests/test_commands.py                        | 139 +++++++++++++++++-
 4 files changed, 298 insertions(+), 3 deletions(-)
 create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
 create mode 100644 src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2022  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import logging
18

  
19
from authentic2.base_commands import LogToConsoleCommand
20
from authentic2_auth_oidc.models import OIDCProvider
21

  
22

  
23
class Command(LogToConsoleCommand):
    """Run the authentic2 synchronization of OIDC providers.

    Only providers with ``a2_synchronization_supported`` set are considered.
    When ``--provider`` is given, the selection is further restricted to the
    provider having that slug. An error is logged and nothing is done when the
    selection is empty.
    """

    loggername = 'authentic2_auth_oidc.models'

    def add_arguments(self, parser):
        parser.add_argument('--provider', type=str, default=None)

    def core_command(self, *args, **kwargs):
        logger = logging.getLogger(self.loggername)
        wanted_slug = kwargs['provider']

        candidates = OIDCProvider.objects.filter(a2_synchronization_supported=True)
        if wanted_slug:
            candidates = candidates.filter(slug=wanted_slug)

        total = candidates.count()
        if not total:
            logger.error('no provider supporting synchronization found, exiting')
            return

        slugs = ' '.join(candidates.values_list('slug', flat=True))
        logger.info('got %s provider(s): %s', total, slugs)
        for oidc_provider in candidates:
            oidc_provider.perform_synchronization()
src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py
1
# Generated by Django 2.2.26 on 2022-08-03 09:30
2

  
3
from django.db import migrations, models
4

  
5

  
6
class Migration(migrations.Migration):
    """Add the authentic2 synchronization fields to ``OIDCProvider``.

    NOTE(review): this file is numbered 0013 while it depends on
    ``0016_auto_20221019_1148``; consider renaming it after the latest
    migration (e.g. ``0017_synchronization_fields``) to keep the history
    readable.
    """

    dependencies = [('authentic2_auth_oidc', '0016_auto_20221019_1148')]

    operations = [
        # flag selecting the providers handled by the synchronization
        migrations.AddField(
            model_name='oidcprovider',
            name='a2_synchronization_supported',
            field=models.BooleanField(verbose_name='Authentic2 synchronization supported', default=False),
        ),
        # reference time of the last synchronization run, empty until the first one
        migrations.AddField(
            model_name='oidcprovider',
            name='last_sync_time',
            field=models.DateTimeField(verbose_name='Last synchronization time', null=True, blank=True),
        ),
    ]
src/authentic2_auth_oidc/models.py
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17 17
import json
18
import logging
19
from datetime import datetime, timedelta
18 20

  
19 21
import django
22
import requests
20 23
from django.conf import settings
21 24
from django.core.exceptions import ValidationError
22 25
from django.db import models
23 26
from django.shortcuts import render
27
from django.utils.timezone import now
24 28
from django.utils.translation import gettext_lazy as _
25 29
from django.utils.translation import pgettext_lazy
26 30
from jwcrypto.jwk import InvalidJWKValue, JWKSet
......
32 36
    BaseAuthenticator,
33 37
)
34 38
from authentic2.utils.misc import make_url, redirect_to_login
35
from authentic2.utils.template import validate_template
39
from authentic2.utils.template import Template, validate_template
36 40

  
37 41
from . import managers
38 42

  
......
120 124
        verbose_name=_('max authentication age'), blank=True, null=True
121 125
    )
122 126

  
127
    # authentic2-specific synchronization API (used by perform_synchronization())
    a2_synchronization_supported = models.BooleanField(
        default=False,
        verbose_name=_('Authentic2 synchronization supported'),
    )
    # reference time of the last synchronization run, empty until the first one
    last_sync_time = models.DateTimeField(
        blank=True,
        null=True,
        verbose_name=_('Last synchronization time'),
    )
137

  
123 138
    # metadata
124 139
    created = models.DateTimeField(verbose_name=_('creation date'), auto_now_add=True)
125 140
    modified = models.DateTimeField(verbose_name=_('last modification date'), auto_now=True)
......
247 262
        ]
248 263
        return render(request, template_names, context)
249 264

  
265
    def perform_synchronization(self, sync_time=None, timeout=30):
        """Synchronize this provider's local OIDC accounts with the provider.

        Nothing is done (an error is logged) unless ``a2_synchronization_supported``
        is set. Otherwise two passes are performed:

        1. the ``sub`` of every local account of this provider is POSTed, by
           batches of 100, to ``<issuer>/api/users/synchronization/``; accounts
           reported as unknown are deleted, unless they exceed 5% of this
           provider's accounts, in which case deletion is aborted and an error
           is logged;
        2. users modified since ``last_sync_time`` are fetched from
           ``<issuer>/api/users/`` (following the ``next`` pagination link) and
           this provider's claim mappings are applied to the matching local
           users, which are saved only when a value changed.

        :param sync_time: value stored in ``last_sync_time`` at the end of the
            run; defaults to one minute before the call.
        :param timeout: timeout, in seconds, of each HTTP request.
        :raises requests.HTTPError: when the provider answers with an error
            status (``last_sync_time`` is then left untouched).
        """
        logger = logging.getLogger(__name__)

        if not self.a2_synchronization_supported:
            logger.error('OIDC provider %s does not support synchronization', self.slug)
            return
        if not sync_time:
            sync_time = now() - timedelta(minutes=1)

        # tolerate a trailing slash in the issuer (avoids '//api/...' URLs)
        base_url = self.issuer.rstrip('/')
        auth = (self.client_id, self.client_secret)
        # every account lookup is restricted to this provider: the same sub may
        # belong to another provider's account, which must never be touched here
        accounts = OIDCAccount.objects.filter(provider=self)
        batch_size = 100

        # check all existing users
        url = base_url + '/api/users/synchronization/'
        known_subs = list(accounts.values_list('sub', flat=True))
        unknown_uuids = []
        for start in range(0, len(known_subs), batch_size):
            resp = requests.post(
                url,
                json={'known_uuids': known_subs[start : start + batch_size]},
                auth=auth,
                timeout=timeout,
            )
            resp.raise_for_status()
            # a missing or null key means no unknown account in this batch
            unknown_uuids.extend(resp.json().get('unknown_uuids') or [])
        if known_subs:  # without local accounts there is no ratio and nothing to delete
            deletion_ratio = len(unknown_uuids) / len(known_subs)
            if deletion_ratio > 0.05:  # higher than 5%, something definitely went wrong
                logger.error(
                    'deletion ratio is abnormally high (%s), aborting unknown users deletion', deletion_ratio
                )
            elif unknown_uuids:
                accounts.filter(sub__in=unknown_uuids).delete()

        # update recently modified users
        modified_since = (self.last_sync_time or datetime.utcfromtimestamp(0)).strftime('%Y-%m-%dT%H:%M:%S')
        url = base_url + '/api/users/?modified__gt=%s&claim_resolution' % modified_since
        # fetched once instead of once per synchronized user
        claim_mappings = list(self.claim_mappings.all())
        while url:
            resp = requests.get(url, auth=auth, timeout=timeout)
            resp.raise_for_status()
            payload = resp.json()
            url = payload.get('next')
            results = payload['results']
            logger.info('got %s users', len(results))
            for user_dict in results:
                sub = user_dict.get('sub')
                if not sub:
                    continue
                try:
                    account = accounts.get(sub=sub)
                except (OIDCAccount.DoesNotExist, OIDCAccount.MultipleObjectsReturned):
                    continue
                user = account.user
                had_changes = False
                for claim in claim_mappings:
                    # claims containing template markup are rendered against the user payload
                    if '{{' in claim.claim or '{%' in claim.claim:
                        attribute_value = Template(claim.claim).render(context=user_dict)
                    else:
                        attribute_value = user_dict.get(claim.claim)
                    try:
                        old_attribute_value = getattr(user, claim.attribute)
                    except AttributeError:
                        try:
                            old_attribute_value = getattr(user.attributes, claim.attribute)
                        except AttributeError:
                            old_attribute_value = None
                    if old_attribute_value == attribute_value:
                        continue
                    had_changes = True
                    setattr(user, claim.attribute, attribute_value)
                    try:
                        # presumably needed for attributes managed through
                        # user.attributes; AttributeError means it is not one
                        setattr(user.attributes, claim.attribute, attribute_value)
                    except AttributeError:
                        pass
                if had_changes:
                    logger.debug('had changes, saving %r', user)
                    user.save()
        self.last_sync_time = sync_time
        self.save(update_fields=['last_sync_time'])
341

  
250 342

  
251 343
class OIDCClaimMapping(AuthenticatorRelatedObjectBase):
252 344
    NOT_VERIFIED = 0
tests/test_commands.py
17 17
import datetime
18 18
import importlib
19 19
import json
20
import random
21
import uuid
20 22
from io import BufferedReader, BufferedWriter, TextIOWrapper
21 23

  
24
import httmock
22 25
import py
23 26
import pytest
24 27
import webtest
......
40 43
from authentic2.apps.journal.models import Event
41 44
from authentic2.custom_user.models import DeletedUser
42 45
from authentic2.models import UserExternalId
43
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
46
from authentic2.utils import crypto
47
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
44 48

  
45
from .utils import call_command, login
49
from .utils import call_command, check_log, login
46 50

  
47 51
User = get_user_model()
48 52

  
......
520 524
    call_command('clean-user-exports')
521 525
    with pytest.raises(webtest.app.AppError):
522 526
        resp.click('Download CSV')
527

  
528

  
529
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid):
    """Run oidc-sync-provider against a mocked authentic2 OIDC provider.

    The mocked synchronization endpoint reports ``deletion_number`` random
    accounts (out of 100) as unknown; they must only be deleted when they stay
    within the 5% safety ratio (``deletion_valid``). The mocked users endpoint
    returns 20 random modified users whose new values must be applied through
    the provider's claim mappings.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='username',
        idtoken_claim=False,
        claim='username',
    )
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='email',
        idtoken_claim=False,
        claim='email',
    )
    # the two following mappings are flagged as idtoken claims
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='last_name',
        idtoken_claim=True,
        claim='family_name',
    )
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='first_name',
        idtoken_claim=True,
        claim='given_name',
    )
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        identifier = uuid.UUID(user.uuid).bytes
        sector_identifier = 'some.provider'
        cipher_args = [
            settings.SECRET_KEY.encode('utf-8'),
            identifier,
            sector_identifier,
        ]
        sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    def synchronization_post_deletion_response(url, request):
        headers = {'content-type': 'application/json'}
        content = {
            'unknown_uuids': [
                account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)
            ]
        }
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    def synchronization_get_modified_response(url, request):
        headers = {'content-type': 'application/json'}
        # randomized batch of modified users; the admin fixture user is
        # excluded because it has no OIDC account (hence no sub) and would be
        # skipped by the synchronization, making the lower bounds below flaky
        modified_users = random.sample(list(User.objects.exclude(pk=admin.pk)), 20)
        results = []
        for count, user in enumerate(modified_users):
            user_json = user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'

            # add user sub to response, unless its account was deleted
            try:
                account = OIDCAccount.objects.get(user=user)
            except OIDCAccount.DoesNotExist:
                pass
            else:
                user_json['sub'] = account.sub

            results.append(user_json)
        content = {'results': results}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    with httmock.HTTMock(
        httmock.urlmatch(
            netloc=r'some\.provider',
            path=r'^/api/users/synchronization/$',
            method='POST',
        )(synchronization_post_deletion_response)
    ):

        with httmock.HTTMock(
            httmock.urlmatch(
                netloc=r'some\.provider',
                path=r'^/api/users/*',
                method='GET',
            )(synchronization_get_modified_response)
        ):
            with check_log(caplog, 'no provider supporting synchronization'):
                call_command('oidc-sync-provider', '-v1')

            oidc_provider.a2_synchronization_supported = True
            oidc_provider.save()

            with check_log(caplog, 'no provider supporting synchronization'):
                call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')

            with check_log(caplog, 'got 20 users'):
                call_command('oidc-sync-provider', '-v1')
            if deletion_valid:
                # existing users check
                assert OIDCAccount.objects.count() == 100 - deletion_number
            else:
                assert OIDCAccount.objects.count() == 100
                # look the error up rather than relying on its position in the log
                assert any(
                    record.levelname == 'ERROR' and 'deletion ratio is abnormally high' in record.getMessage()
                    for record in caplog.records
                )

            # users update: sampled users whose account was deleted have no sub
            # anymore and are skipped, the others must all be updated
            min_updated = 20 - deletion_number if deletion_valid else 20
            assert User.objects.filter(username__startswith='modified').count() in range(min_updated, 21)
            assert User.objects.filter(first_name='Mod', last_name='Ified').count() in range(
                min_updated, 21
            )
523
-