Projet

Général

Profil

0002-auth_oidc-add-an-oidc-sync-provider-command-62710.patch

Paul Marillonnet, 28 octobre 2022 09:40

Télécharger (15,5 ko)

Voir les différences:

Subject: [PATCH 2/2] auth_oidc: add an oidc-sync-provider command (#62710)

 .../management/commands/oidc-sync-provider.py |  45 ++++++
 .../migrations/0013_synchronization_fields.py |  23 +++
 src/authentic2_auth_oidc/models.py            |  96 +++++++++++-
 tests/test_commands.py                        | 139 +++++++++++++++++-
 4 files changed, 300 insertions(+), 3 deletions(-)
 create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
 create mode 100644 src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2022  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import logging
18

  
19
from authentic2.base_commands import LogToConsoleCommand
20
from authentic2_auth_oidc.models import OIDCProvider
21

  
22

  
23
class Command(LogToConsoleCommand):
    """Synchronize users with OIDC providers supporting the authentic2 synchronization API.

    Without ``--provider``, every provider whose ``a2_synchronization_supported``
    flag is set is synchronized; with ``--provider SLUG`` only that provider is
    considered (and still only if the flag is set).
    """

    # same logger name as the one used by OIDCProvider.perform_synchronization
    # (logging.getLogger(__name__) in authentic2_auth_oidc.models), so its
    # messages are shown on the console too
    loggername = 'authentic2_auth_oidc.models'

    def add_arguments(self, parser):
        parser.add_argument(
            '--provider',
            type=str,
            default=None,
            help='slug of the single provider to synchronize',
        )

    def core_command(self, *args, **kwargs):
        slug = kwargs['provider']
        logger = logging.getLogger(self.loggername)

        providers = OIDCProvider.objects.filter(a2_synchronization_supported=True)
        if slug:
            providers = providers.filter(slug=slug)
        # evaluate the queryset once instead of issuing two COUNT queries, a
        # values_list query and a final SELECT for the loop
        providers = list(providers)
        if not providers:
            logger.error('no provider supporting synchronization found, exiting')
            return
        logger.info(
            'got %s provider(s): %s',
            len(providers),
            ' '.join(provider.slug for provider in providers),
        )
        for provider in providers:
            provider.perform_synchronization()
src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py
1
# Generated by Django 2.2.26 on 2022-08-03 09:30
2

  
3
from django.db import migrations, models
4

  
5

  
6
class Migration(migrations.Migration):
    """Add the OIDCProvider fields used by the authentic2 user synchronization API."""

    # NOTE(review): this file is numbered 0013 but depends on 0016; consider
    # renaming it (e.g. 0017_synchronization_fields) so that numbering follows
    # the actual migration order.
    dependencies = [('authentic2_auth_oidc', '0016_auto_20221019_1148')]

    operations = [
        # opt-in flag: only providers with this set are synchronized
        migrations.AddField(
            model_name='oidcprovider',
            name='a2_synchronization_supported',
            field=models.BooleanField(
                verbose_name='Authentic2 synchronization supported',
                default=False,
            ),
        ),
        # empty until the first successful synchronization
        migrations.AddField(
            model_name='oidcprovider',
            name='last_sync_time',
            field=models.DateTimeField(
                verbose_name='Last synchronization time',
                null=True,
                blank=True,
            ),
        ),
    ]
src/authentic2_auth_oidc/models.py
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17 17
import json
18
import logging
19
from datetime import datetime, timedelta
18 20

  
21
import requests
19 22
from django.conf import settings
20 23
from django.contrib.postgres.fields import JSONField
21 24
from django.core.exceptions import ValidationError
22 25
from django.db import models
23 26
from django.shortcuts import render
27
from django.utils.timezone import now
24 28
from django.utils.translation import gettext_lazy as _
25 29
from django.utils.translation import pgettext_lazy
26 30
from jwcrypto.jwk import InvalidJWKValue, JWKSet
......
32 36
    BaseAuthenticator,
33 37
)
34 38
from authentic2.utils.misc import make_url, redirect_to_login
35
from authentic2.utils.template import validate_template
39
from authentic2.utils.template import Template, validate_template
36 40

  
37 41
from . import managers
38 42

  
......
115 119
        verbose_name=_('max authentication age'), blank=True, null=True
116 120
    )
117 121

  
122
    # authentic2 specific synchronization api:
    # - a2_synchronization_supported gates perform_synchronization();
    # - last_sync_time is written at the end of perform_synchronization() and
    #   used as the lower bound when fetching modified users on the next run.
    a2_synchronization_supported = models.BooleanField(
        default=False, verbose_name=_('Authentic2 synchronization supported')
    )
    last_sync_time = models.DateTimeField(
        blank=True, null=True, verbose_name=_('Last synchronization time')
    )
132

  
118 133
    # metadata
119 134
    created = models.DateTimeField(verbose_name=_('creation date'), auto_now_add=True)
120 135
    modified = models.DateTimeField(verbose_name=_('last modification date'), auto_now=True)
......
241 256
        ]
242 257
        return render(request, template_names, context)
243 258

  
259
    def perform_synchronization(
        self, sync_time=None, timeout=30, max_deletion_ratio=0.05, chunk_size=100
    ):
        """Synchronize local OIDC accounts and users with this provider.

        Two passes are made against the provider's authentic2 API:

        1. every ``sub`` of this provider's accounts is POSTed, ``chunk_size``
           at a time, to ``/api/users/synchronization/``; the accounts reported
           as unknown are deleted, unless they represent more than
           ``max_deletion_ratio`` of this provider's accounts, in which case the
           deletion is aborted and an error is logged;
        2. users modified since ``last_sync_time`` are fetched from
           ``/api/users/`` (following ``next`` pagination links) and this
           provider's claim mappings are applied to the matching local users.

        :param sync_time: value stored in ``last_sync_time`` once both passes
            completed; defaults to one minute before the call.
        :param timeout: timeout, in seconds, of every HTTP request.
        :param max_deletion_ratio: fraction of accounts above which the deletion
            of unknown accounts is considered abnormal and skipped.
        :param chunk_size: number of subs sent per synchronization request
            (must be positive).
        :raises requests.RequestException: on network errors, timeouts or HTTP
            error statuses (via ``raise_for_status``); ``last_sync_time`` is
            then left unchanged.
        """
        logger = logging.getLogger(__name__)

        if not self.a2_synchronization_supported:
            logger.error('OIDC provider %s does not support synchronization', self.slug)
            return
        if not sync_time:
            sync_time = now() - timedelta(minutes=1)

        auth = (self.client_id, self.client_secret)
        # avoid '//api/...' when the issuer is configured with a trailing slash
        base_url = self.issuer.rstrip('/')

        self._a2_sync_delete_unknown_accounts(
            base_url, auth, timeout, max_deletion_ratio, chunk_size, logger
        )
        self._a2_sync_update_modified_users(base_url, auth, timeout, logger)

        self.last_sync_time = sync_time
        self.save(update_fields=['last_sync_time'])

    def _a2_sync_delete_unknown_accounts(
        self, base_url, auth, timeout, max_deletion_ratio, chunk_size, logger
    ):
        """Delete this provider's accounts unknown to the provider, unless too many are affected."""
        url = base_url + '/api/users/synchronization/'
        subs = list(OIDCAccount.objects.filter(provider=self).values_list('sub', flat=True))
        if not subs:
            # nothing to check; also avoids a division by zero in the ratio below
            return

        unknown_uuids = []
        for start in range(0, len(subs), chunk_size):
            resp = requests.post(
                url,
                json={'known_uuids': subs[start : start + chunk_size]},
                auth=auth,
                timeout=timeout,
            )
            resp.raise_for_status()
            # a missing or null key means "no unknown uuid in this chunk"
            unknown_uuids.extend(resp.json().get('unknown_uuids') or [])

        deletion_ratio = len(unknown_uuids) / len(subs)
        if deletion_ratio > max_deletion_ratio:  # something definitely went wrong
            logger.error(
                'deletion ratio is abnormally high (%s), aborting unknown users deletion', deletion_ratio
            )
        elif unknown_uuids:
            # restrict to this provider: the same sub may exist for another provider
            OIDCAccount.objects.filter(provider=self, sub__in=unknown_uuids).delete()

    def _a2_sync_update_modified_users(self, base_url, auth, timeout, logger):
        """Apply this provider's claim mappings to local users modified since the last sync."""
        # NOTE(review): strftime drops any timezone information carried by
        # last_sync_time; confirm the provider interprets this naive timestamp
        # in the same timezone the stored value is expressed in.
        since = (self.last_sync_time or datetime.utcfromtimestamp(0)).strftime('%Y-%m-%dT%H:%M:%S')
        url = base_url + '/api/users/?modified__gt=%s&claim_resolution' % since
        # mappings do not change during the run: query them once, not per user
        claim_mappings = list(self.claim_mappings.all())
        # only ever match accounts linked to this provider
        accounts = OIDCAccount.objects.filter(provider=self)

        while url:
            resp = requests.get(url, auth=auth, timeout=timeout)
            resp.raise_for_status()
            data = resp.json()
            url = data.get('next')
            results = data['results']
            logger.info('got %s users', len(results))
            for user_dict in results:
                if user_dict.get('sub'):
                    lookup = {'sub': user_dict['sub']}
                elif user_dict.get('email'):
                    lookup = {'user__email': user_dict['email']}
                else:
                    # nothing reliable to match a local account against
                    continue
                try:
                    account = accounts.get(**lookup)
                except (OIDCAccount.DoesNotExist, OIDCAccount.MultipleObjectsReturned):
                    continue
                self._a2_sync_apply_claims(account.user, user_dict, claim_mappings, logger)

    @staticmethod
    def _a2_sync_apply_claims(user, user_dict, claim_mappings, logger):
        """Copy mapped claim values from *user_dict* onto *user*; save only if something changed."""
        had_changes = False
        for claim in claim_mappings:
            # claims containing template markup are rendered against the user dict
            if '{{' in claim.claim or '{%' in claim.claim:
                attribute_value = Template(claim.claim).render(context=user_dict)
            else:
                attribute_value = user_dict.get(claim.claim)
            try:
                old_attribute_value = getattr(user, claim.attribute)
            except AttributeError:
                try:
                    old_attribute_value = getattr(user.attributes, claim.attribute)
                except AttributeError:
                    old_attribute_value = None
            if old_attribute_value == attribute_value:
                continue
            had_changes = True
            setattr(user, claim.attribute, attribute_value)
            try:
                setattr(user.attributes, claim.attribute, attribute_value)
            except AttributeError:
                pass
        if had_changes:
            logger.debug('had changes, saving %r', user)
            user.save()
337

  
244 338

  
245 339
class OIDCClaimMapping(AuthenticatorRelatedObjectBase):
246 340
    NOT_VERIFIED = 0
tests/test_commands.py
17 17
import datetime
18 18
import importlib
19 19
import json
20
import random
21
import uuid
20 22
from io import BufferedReader, BufferedWriter, TextIOWrapper
21 23

  
24
import httmock
22 25
import py
23 26
import pytest
24 27
import webtest
......
40 43
from authentic2.apps.journal.models import Event
41 44
from authentic2.custom_user.models import DeletedUser
42 45
from authentic2.models import UserExternalId
43
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
46
from authentic2.utils import crypto
47
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
44 48
from django_rbac.utils import get_operation
45 49

  
46
from .utils import call_command, login
50
from .utils import call_command, check_log, login
47 51

  
48 52
User = get_user_model()
49 53

  
......
476 480
    call_command('clean-user-exports')
477 481
    with pytest.raises(webtest.app.AppError):
478 482
        resp.click('Download CSV')
483

  
484

  
485
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid):
    """Check the oidc-sync-provider command against a mocked authentic2 provider.

    The mocked provider reports ``deletion_number`` of the 100 local accounts as
    unknown (deletion must be aborted above a 5% ratio) and returns a batch of 20
    modified users whose claims must be applied to the local users.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='username',
        idtoken_claim=False,
        claim='username',
    )
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='email',
        idtoken_claim=False,
        claim='email',
    )
    # the two last mappings are ID token claims, which the mocked provider
    # returns as if claim resolution had been performed
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='last_name',
        idtoken_claim=True,
        claim='family_name',
    )
    OIDCClaimMapping.objects.create(
        authenticator=oidc_provider,
        attribute='first_name',
        idtoken_claim=True,
        claim='given_name',
    )
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        identifier = uuid.UUID(user.uuid).bytes
        sector_identifier = 'some.provider'
        cipher_args = [
            settings.SECRET_KEY.encode('utf-8'),
            identifier,
            sector_identifier,
        ]
        sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    def synchronization_post_deletion_response(url, request):
        headers = {'content-type': 'application/json'}
        content = {
            'unknown_uuids': [
                account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)
            ]
        }
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    def synchronization_get_modified_response(url, request):
        headers = {'content-type': 'application/json'}
        # randomized batch of modified users, taken only among the users created
        # above: other users (e.g. the admin fixture) have no OIDC account and
        # would make the number of updated users drop below the asserted range
        modified_users = random.sample(list(User.objects.filter(username__startswith='john.doe.')), 20)
        results = []
        for count, user in enumerate(modified_users):
            user_json = user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'

            # add user sub to response (absent when the account was deleted)
            try:
                account = OIDCAccount.objects.get(user=user)
            except OIDCAccount.DoesNotExist:
                pass
            else:
                user_json['sub'] = account.sub

            results.append(user_json)
        content = {'results': results}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    with httmock.HTTMock(
        httmock.urlmatch(
            netloc=r'some\.provider',
            path=r'^/api/users/synchronization/$',
            method='POST',
        )(synchronization_post_deletion_response)
    ):

        with httmock.HTTMock(
            httmock.urlmatch(
                netloc=r'some\.provider',
                path=r'^/api/users/*',
                method='GET',
            )(synchronization_get_modified_response)
        ):
            with check_log(caplog, 'no provider supporting synchronization'):
                call_command('oidc-sync-provider', '-v1')

            oidc_provider.a2_synchronization_supported = True
            oidc_provider.save()

            with check_log(caplog, 'no provider supporting synchronization'):
                call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')

            with check_log(caplog, 'got 20 users'):
                call_command('oidc-sync-provider', '-v1')

            # look the error up instead of relying on its position among records
            ratio_errors = [
                record
                for record in caplog.records
                if record.levelname == 'ERROR'
                and 'deletion ratio is abnormally high' in record.getMessage()
            ]
            if deletion_valid:
                # existing users check
                assert OIDCAccount.objects.count() == 100 - deletion_number
                assert not ratio_errors
            else:
                assert OIDCAccount.objects.count() == 100
                assert len(ratio_errors) == 1

            # users update: sampled users whose account was deleted are skipped
            assert User.objects.filter(username__startswith='modified').count() in range(
                20 - deletion_number, 21
            )
            assert User.objects.filter(first_name='Mod', last_name='Ified').count() in range(
                20 - deletion_number, 21
            )
479
-