Projet

Général

Profil

0002-auth_oidc-add-an-oidc-sync-provider-command-62710.patch

Benjamin Dauvergne, 25 octobre 2022 11:28

Télécharger (15,1 ko)

Voir les différences:

Subject: [PATCH 2/2] auth_oidc: add an oidc-sync-provider command (#62710)

 .../management/commands/oidc-sync-provider.py |  45 ++++++
 .../migrations/0013_synchronization_fields.py |  23 ++++
 src/authentic2_auth_oidc/models.py            |  92 ++++++++++++-
 tests/test_commands.py                        | 130 +++++++++++++++++-
 4 files changed, 287 insertions(+), 3 deletions(-)
 create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
 create mode 100644 src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2022  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import logging
18

  
19
from authentic2.base_commands import LogToConsoleCommand
20
from authentic2_auth_oidc.models import OIDCProvider
21

  
22

  
23
class Command(LogToConsoleCommand):
    """Run the Authentic2 user synchronization on OIDC providers.

    Only providers with ``a2_synchronization_supported`` set are considered;
    ``--provider SLUG`` restricts the run to the provider with that slug.
    """

    # Name of the logger used below; it is also the logger name of
    # authentic2_auth_oidc.models, where perform_synchronization() logs.
    loggername = 'authentic2_auth_oidc.models'

    def add_arguments(self, parser):
        parser.add_argument('--provider', type=str, default=None)

    def core_command(self, *args, **kwargs):
        slug = kwargs['provider']

        logger = logging.getLogger(self.loggername)
        queryset = OIDCProvider.objects.filter(a2_synchronization_supported=True)
        if slug:
            queryset = queryset.filter(slug=slug)
        # Evaluate the queryset once instead of issuing separate COUNT and
        # slug queries for the emptiness check and the summary log line.
        providers = list(queryset)
        if not providers:
            logger.error('no provider supporting synchronization found, exiting')
            return
        logger.info(
            'got %s provider(s): %s',
            len(providers),
            ' '.join(provider.slug for provider in providers),
        )
        for provider in providers:
            provider.perform_synchronization()
src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py
1
# Generated by Django 2.2.26 on 2022-08-03 09:30
2

  
3
from django.db import migrations, models
4

  
5

  
6
class Migration(migrations.Migration):
    """Add the Authentic2 synchronization fields to OIDCProvider.

    ``a2_synchronization_supported`` flags providers exposing the Authentic2
    synchronization API; ``last_sync_time`` records the last successful run.
    """

    # NOTE(review): this file is numbered 0013 but depends on 0016; Django only
    # follows ``dependencies`` so it applies correctly, but a 0017_ prefix would
    # be less confusing.
    dependencies = [
        ('authentic2_auth_oidc', '0016_auto_20221019_1148'),
    ]

    operations = [
        migrations.AddField(
            model_name='oidcprovider',
            name='a2_synchronization_supported',
            field=models.BooleanField(
                verbose_name='Authentic2 synchronization supported',
                default=False,
            ),
        ),
        migrations.AddField(
            model_name='oidcprovider',
            name='last_sync_time',
            field=models.DateTimeField(
                verbose_name='Last synchronization time',
                null=True,
                blank=True,
            ),
        ),
    ]
src/authentic2_auth_oidc/models.py
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17 17
import json
18
import logging
18 19
import uuid
20
from datetime import datetime, timedelta
19 21

  
22
import requests
20 23
from django.conf import settings
21 24
from django.contrib.postgres.fields import JSONField
22 25
from django.core.exceptions import ValidationError
23 26
from django.db import models
24 27
from django.shortcuts import render
28
from django.utils.timezone import now
25 29
from django.utils.translation import gettext_lazy as _
26 30
from django.utils.translation import pgettext_lazy
27 31
from jwcrypto.jwk import InvalidJWKValue, JWKSet
......
33 37
    BaseAuthenticator,
34 38
)
35 39
from authentic2.utils.misc import make_url, redirect_to_login
36
from authentic2.utils.template import validate_template
40
from authentic2.utils.template import Template, validate_template
37 41

  
38 42
from . import managers
39 43

  
......
116 120
        verbose_name=_('max authentication age'), blank=True, null=True
117 121
    )
118 122

  
123
    # authentic2 specific synchronization api
124
    a2_synchronization_supported = models.BooleanField(
125
        verbose_name=_('Authentic2 synchronization supported'),
126
        default=False,
127
    )
128
    last_sync_time = models.DateTimeField(
129
        verbose_name=_('Last synchronization time'),
130
        null=True,
131
        blank=True,
132
    )
133

  
119 134
    # metadata
120 135
    created = models.DateTimeField(verbose_name=_('creation date'), auto_now_add=True)
121 136
    modified = models.DateTimeField(verbose_name=_('last modification date'), auto_now=True)
......
242 257
        ]
243 258
        return render(request, template_names, context)
244 259

  
260
    def perform_synchronization(self, sync_time=None, timeout=30):
        """Synchronize this provider's local OIDC accounts with its Authentic2 API.

        Two phases are run:

        1. the ``sub`` of every account of this provider is posted, in batches
           of 100, to ``<issuer>/api/users/synchronization/``; accounts reported
           in ``unknown_uuids`` are deleted, unless they represent more than 5%
           of the accounts, in which case deletion is skipped and an error is
           logged;
        2. users modified since ``last_sync_time`` are fetched from
           ``<issuer>/api/users/`` (following ``next`` links) and this
           provider's claim mappings are applied to the matching local users
           (matched by email among this provider's accounts).

        :param sync_time: value stored in ``last_sync_time`` once both phases
            complete; defaults to one minute before now.
        :param timeout: timeout, in seconds, of each HTTP request.
        :raises requests.HTTPError: when the provider answers with an error
            status (``last_sync_time`` is then left untouched).
        """
        logger = logging.getLogger(__name__)

        if not self.a2_synchronization_supported:
            logger.error('OIDC provider %s does not support synchronization', self.slug)
            return
        if not sync_time:
            sync_time = now() - timedelta(minutes=1)

        # avoid "//api/..." when the issuer is configured with a trailing slash
        base_url = self.issuer.rstrip('/')
        auth = (self.client_id, self.client_secret)
        # every lookup/deletion below is scoped to this provider: subs and
        # emails are not guaranteed unique across providers
        accounts = OIDCAccount.objects.filter(provider=self)

        # check all existing users
        known_subs = list(accounts.values_list('sub', flat=True))
        batch_size = 100
        unknown_uuids = []
        for start in range(0, len(known_subs), batch_size):
            resp = requests.post(
                base_url + '/api/users/synchronization/',
                json={'known_uuids': known_subs[start : start + batch_size]},
                auth=auth,
                timeout=timeout,
            )
            resp.raise_for_status()
            unknown_uuids.extend(resp.json().get('unknown_uuids') or [])
        # no account at all: nothing to delete, and no division by zero
        deletion_ratio = len(unknown_uuids) / len(known_subs) if known_subs else 0
        if deletion_ratio > 0.05:  # higher than 5%, something definitely went wrong
            logger.error(
                'deletion ratio is abnormally high (%s), aborting unknown users deletion', deletion_ratio
            )
        elif unknown_uuids:
            accounts.filter(sub__in=unknown_uuids).delete()

        # update recently modified users
        since = (self.last_sync_time or datetime.utcfromtimestamp(0)).strftime('%Y-%m-%dT%H:%M:%S')
        url = base_url + '/api/users/?modified__gt=%s&claim_resolution' % since
        # mappings do not change during the run: fetch them once, not per user
        claim_mappings = list(self.claim_mappings.all())
        while url:
            resp = requests.get(url, auth=auth, timeout=timeout)
            resp.raise_for_status()
            payload = resp.json()
            url = payload.get('next')
            results = payload['results']
            logger.info('got %s users', len(results))
            for user_dict in results:
                email = user_dict.get('email')
                if not email:
                    # an empty/missing email could match an unrelated local user
                    continue
                try:
                    account = accounts.select_related('user').get(user__email=email)
                except (OIDCAccount.DoesNotExist, OIDCAccount.MultipleObjectsReturned):
                    continue
                user = account.user
                had_changes = False
                for claim in claim_mappings:
                    if '{{' in claim.claim or '{%' in claim.claim:
                        attribute_value = Template(claim.claim).render(context=user_dict)
                    else:
                        attribute_value = user_dict.get(claim.claim)
                    # the attribute is either a user field or an entry of user.attributes
                    try:
                        old_attribute_value = getattr(user, claim.attribute)
                    except AttributeError:
                        try:
                            old_attribute_value = getattr(user.attributes, claim.attribute)
                        except AttributeError:
                            old_attribute_value = None
                    if old_attribute_value == attribute_value:
                        continue
                    had_changes = True
                    setattr(user, claim.attribute, attribute_value)
                    try:
                        setattr(user.attributes, claim.attribute, attribute_value)
                    except AttributeError:
                        pass
                if had_changes:
                    logger.debug('had changes, saving %r', user)
                    user.save()
        self.last_sync_time = sync_time
        self.save(update_fields=['last_sync_time'])
334

  
245 335

  
246 336
class OIDCClaimMapping(AuthenticatorRelatedObjectBase):
247 337
    NOT_VERIFIED = 0
tests/test_commands.py
17 17
import datetime
18 18
import importlib
19 19
import json
20
import random
21
import uuid
20 22
from io import BufferedReader, BufferedWriter, TextIOWrapper
21 23

  
24
import httmock
22 25
import py
23 26
import pytest
24 27
import webtest
......
40 43
from authentic2.apps.journal.models import Event
41 44
from authentic2.custom_user.models import DeletedUser
42 45
from authentic2.models import UserExternalId
43
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
46
from authentic2.utils import crypto
47
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
44 48
from django_rbac.utils import get_operation
45 49

  
46
from .utils import call_command, login
50
from .utils import call_command, check_log, login
47 51

  
48 52
User = get_user_model()
49 53

  
......
456 460
    call_command('clean-user-exports')
457 461
    with pytest.raises(webtest.app.AppError):
458 462
        resp.click('Download CSV')
463

  
464

  
465
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid):
    """Check the oidc-sync-provider command against a mocked Authentic2 provider.

    The mocked synchronization endpoint reports ``deletion_number`` of the 100
    accounts as unknown: up to 5% they must be deleted, above that the deletion
    must be aborted with an error. The mocked users endpoint returns 20 random
    users with a modified username, first name and last name, which must be
    applied through the claim mappings, except for users whose account was just
    deleted.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    # (attribute, claim, idtoken_claim)
    for attribute, claim, idtoken_claim in [
        ('username', 'username', False),
        ('email', 'email', False),
        ('last_name', 'family_name', True),
        ('first_name', 'given_name', True),
    ]:
        OIDCClaimMapping.objects.create(
            authenticator=oidc_provider,
            attribute=attribute,
            idtoken_claim=idtoken_claim,
            claim=claim,
        )
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        # sub derived from the user uuid by deterministic encryption
        # (presumably mirroring the IdP's pairwise sub generation)
        sub = crypto.aes_base64url_deterministic_encrypt(
            settings.SECRET_KEY.encode('utf-8'),
            uuid.UUID(user.uuid).bytes,
            'some.provider',
        ).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    def synchronization_post_deletion_response(url, request):
        # report a random sample of the existing accounts as unknown
        headers = {'content-type': 'application/json'}
        content = {
            'unknown_uuids': [
                account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)
            ]
        }
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    def synchronization_get_modified_response(url, request):
        headers = {'content-type': 'application/json'}
        # randomized batch of modified users
        results = []
        for count, user in enumerate(random.sample(list(User.objects.all()), 20)):
            user_json = user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'
            results.append(user_json)
        # no 'next' key: a single page of results
        content = {'results': results}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    post_mock = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/synchronization/$',
        method='POST',
    )(synchronization_post_deletion_response)
    get_mock = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/*',
        method='GET',
    )(synchronization_get_modified_response)

    def deletion_aborted():
        # search all records rather than relying on a fixed record index
        return any(
            record.levelname == 'ERROR' and 'deletion ratio is abnormally high' in record.getMessage()
            for record in caplog.records
        )

    with httmock.HTTMock(post_mock, get_mock):
        with check_log(caplog, 'no provider supporting synchronization'):
            call_command('oidc-sync-provider', '-v1')

        oidc_provider.a2_synchronization_supported = True
        oidc_provider.save()

        with check_log(caplog, 'no provider supporting synchronization'):
            call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')

        with check_log(caplog, 'got 20 users'):
            call_command('oidc-sync-provider', '-v1')

    if deletion_valid:
        # existing users check
        assert OIDCAccount.objects.count() == 100 - deletion_number
        assert not deletion_aborted()
        # users whose account was just deleted are not updated
        min_updated = 20 - deletion_number
    else:
        assert OIDCAccount.objects.count() == 100
        assert deletion_aborted()
        # no account was deleted: every returned user is updated
        min_updated = 20

    # users update
    assert User.objects.filter(username__startswith='modified').count() in range(min_updated, 21)
    assert User.objects.filter(first_name='Mod', last_name='Ified').count() in range(min_updated, 21)
459
-