Projet

Général

Profil

0002-auth_oidc-add-an-oidc-sync-provider-command-62710.patch

Paul Marillonnet, 05 août 2022 13:56

Télécharger (15,2 ko)

Voir les différences:

Subject: [PATCH 2/2] auth_oidc: add an oidc-sync-provider command (#62710)

 .../management/commands/oidc-sync-provider.py |  45 ++++++
 .../migrations/0013_auto_20220803_1130.py     |  23 +++
 src/authentic2_auth_oidc/models.py            |  92 +++++++++++-
 tests/test_commands.py                        | 131 +++++++++++++++++-
 4 files changed, 288 insertions(+), 3 deletions(-)
 create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
 create mode 100644 src/authentic2_auth_oidc/migrations/0013_auto_20220803_1130.py
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2022  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import logging
18

  
19
from authentic2.base_commands import LogToConsoleCommand
20
from authentic2_auth_oidc.models import OIDCProvider
21

  
22

  
23
class Command(LogToConsoleCommand):
    """Run the Authentic2 user synchronization of OIDC providers.

    Only providers with ``a2_synchronization_supported`` set are considered;
    ``--provider SLUG`` further restricts the run to the provider with that slug.
    """

    # Name of the logger used by the synchronization code in
    # authentic2_auth_oidc.models; presumably LogToConsoleCommand routes this
    # logger to the console (TODO confirm in authentic2.base_commands).
    loggername = 'authentic2_auth_oidc.models'

    def add_arguments(self, parser):
        parser.add_argument(
            '--provider',
            type=str,
            default=None,
            help='slug of the provider to synchronize (default: every provider supporting synchronization)',
        )

    def core_command(self, *args, **kwargs):
        logger = logging.getLogger(self.loggername)
        provider_slug = kwargs['provider']

        queryset = OIDCProvider.objects.filter(a2_synchronization_supported=True)
        if provider_slug:
            queryset = queryset.filter(slug=provider_slug)
        # Evaluate once: the original issued two COUNT queries plus a slug query.
        providers = list(queryset)
        if not providers:
            logger.error('no provider supporting synchronization found, exiting')
            return

        logger.info(
            'got %s provider(s): %s',
            len(providers),
            ' '.join(provider.slug for provider in providers),
        )
        for provider in providers:
            provider.perform_synchronization()
src/authentic2_auth_oidc/migrations/0013_auto_20220803_1130.py
1
# Generated by Django 2.2.26 on 2022-08-03 09:30
2

  
3
from django.db import migrations, models
4

  
5

  
6
class Migration(migrations.Migration):
    """Add the Authentic2 synchronization fields to the ``oidcprovider`` model."""

    dependencies = [('authentic2_auth_oidc', '0012_auto_20220524_1147')]

    operations = [
        # Opt-in flag: only providers with this flag set are synchronized.
        migrations.AddField(
            model_name='oidcprovider',
            name='a2_synchronization_supported',
            field=models.BooleanField(
                default=False,
                verbose_name='Authentic2 synchronization supported',
            ),
        ),
        # Reference time of the previous synchronization run; empty until the first run.
        migrations.AddField(
            model_name='oidcprovider',
            name='last_sync_time',
            field=models.DateTimeField(
                null=True,
                blank=True,
                verbose_name='Last synchronization time',
            ),
        ),
    ]
src/authentic2_auth_oidc/models.py
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17 17
import json
18
import logging
18 19
import uuid
20
from datetime import datetime, timedelta
19 21

  
22
import requests
20 23
from django.conf import settings
21 24
from django.contrib.postgres.fields import JSONField
22 25
from django.core.exceptions import ValidationError
23 26
from django.db import models
24 27
from django.shortcuts import render
28
from django.utils.timezone import now
25 29
from django.utils.translation import gettext_lazy as _
26 30
from jwcrypto.jwk import InvalidJWKValue, JWKSet
27 31

  
28 32
from authentic2.a2_rbac.utils import get_default_ou
29 33
from authentic2.apps.authenticators.models import BaseAuthenticator
30 34
from authentic2.utils.misc import make_url, redirect_to_login
31
from authentic2.utils.template import validate_template
35
from authentic2.utils.template import Template, validate_template
32 36

  
33 37
from . import managers
34 38

  
......
107 111
        verbose_name=_('max authentication age'), blank=True, null=True
108 112
    )
109 113

  
114
    # authentic2 specific synchronization api
115
    a2_synchronization_supported = models.BooleanField(
116
        verbose_name=_('Authentic2 synchronization supported'),
117
        default=False,
118
    )
119
    last_sync_time = models.DateTimeField(
120
        verbose_name=_('Last synchronization time'),
121
        null=True,
122
        blank=True,
123
    )
124

  
110 125
    # metadata
111 126
    created = models.DateTimeField(verbose_name=_('created'), auto_now_add=True)
112 127
    modified = models.DateTimeField(verbose_name=_('modified'), auto_now=True)
......
213 228
        ]
214 229
        return render(request, template_names, context)
215 230

  
231
    def perform_synchronization(self, sync_time=None, timeout=30):
        """Synchronize local OIDC accounts with an Authentic2 identity provider.

        Two passes are made against the provider's Authentic2-specific API:

        1. every known ``sub`` of this provider is posted to
           ``<issuer>/api/users/synchronization/``; accounts reported back as
           unknown are deleted, unless they exceed 5% of this provider's
           accounts (then nothing is deleted and an error is logged);
        2. users modified since ``last_sync_time`` are fetched from
           ``<issuer>/api/users/`` (following ``next`` pagination links) and the
           provider's claim mappings are applied to the local users they match
           by email.

        :param sync_time: value stored in ``last_sync_time`` at the end of the
            run; defaults to one minute before now (presumably a safety margin
            so changes made during this run are fetched again next time).
        :param timeout: timeout in seconds applied to each HTTP request.
        :raises requests.RequestException: on network errors, timeouts or an
            error HTTP status returned by the provider.
        """
        logger = logging.getLogger(__name__)

        if not self.a2_synchronization_supported:
            logger.error('OIDC provider %s does not support synchronization', self.slug)
            return
        if not sync_time:
            sync_time = now() - timedelta(minutes=1)

        auth = (self.client_id, self.client_secret)
        self._sync_deleted_accounts(auth, timeout, logger)
        self._sync_modified_users(auth, timeout, logger)

        self.last_sync_time = sync_time
        self.save(update_fields=['last_sync_time'])

    def _sync_deleted_accounts(self, auth, timeout, logger):
        """Delete this provider's accounts whose sub the provider reports as unknown."""
        url = self.issuer + '/api/users/synchronization/'
        accounts = OIDCAccount.objects.filter(provider=self)
        subs = list(accounts.values_list('sub', flat=True))
        if not subs:
            # Nothing to check; also avoids a division by zero in the ratio below.
            return

        chunk_size = 100
        unknown_uuids = []
        for start in range(0, len(subs), chunk_size):
            resp = requests.post(
                url,
                json={'known_uuids': subs[start : start + chunk_size]},
                auth=auth,
                timeout=timeout,
            )
            resp.raise_for_status()
            # A missing or null key means "nothing unknown" rather than a crash.
            unknown_uuids.extend(resp.json().get('unknown_uuids') or [])

        deletion_ratio = len(unknown_uuids) / len(subs)
        if deletion_ratio > 0.05:  # higher than 5%, something definitely went wrong
            logger.error(
                'deletion ratio is abnormally high (%s), aborting unknown users deletion', deletion_ratio
            )
            return
        # Scoped to this provider so accounts of other providers are never deleted.
        accounts.filter(sub__in=unknown_uuids).delete()

    def _sync_modified_users(self, auth, timeout, logger):
        """Apply claim mappings to local users modified on the provider since the last run."""
        since = self.last_sync_time or datetime.utcfromtimestamp(0)
        url = self.issuer + '/api/users/?modified__gt=%s&claim_resolution' % since.strftime(
            '%Y-%m-%dT%H:%M:%S'
        )
        # Loaded once instead of once per user.
        claim_mappings = list(self.claim_mappings.all())
        while url:
            resp = requests.get(url, auth=auth, timeout=timeout)
            resp.raise_for_status()
            payload = resp.json()
            url = payload.get('next')
            results = payload['results']
            logger.info('got %s users', len(results))
            for user_dict in results:
                email = user_dict.get('email')
                if not email:
                    # Cannot be matched; an empty email could hit an unrelated user.
                    continue
                try:
                    account = OIDCAccount.objects.get(provider=self, user__email=email)
                except (OIDCAccount.DoesNotExist, OIDCAccount.MultipleObjectsReturned):
                    continue
                user = account.user
                if self._apply_claim_mappings(user, user_dict, claim_mappings):
                    logger.debug('had changes, saving %r', user)
                    user.save()

    @staticmethod
    def _apply_claim_mappings(user, user_dict, claim_mappings):
        """Copy mapped claim values from *user_dict* onto *user*; return True if anything changed."""
        had_changes = False
        for claim in claim_mappings:
            if '{{' in claim.claim or '{%' in claim.claim:
                attribute_value = Template(claim.claim).render(context=user_dict)
            else:
                attribute_value = user_dict.get(claim.claim)
            # Current value: model field first, then the user's extra attributes.
            try:
                old_attribute_value = getattr(user, claim.attribute)
            except AttributeError:
                try:
                    old_attribute_value = getattr(user.attributes, claim.attribute)
                except AttributeError:
                    old_attribute_value = None
            if old_attribute_value == attribute_value:
                continue
            had_changes = True
            setattr(user, claim.attribute, attribute_value)
            try:
                setattr(user.attributes, claim.attribute, attribute_value)
            except AttributeError:
                pass
        return had_changes
305

  
216 306

  
217 307
class OIDCClaimMapping(models.Model):
218 308
    NOT_VERIFIED = 0
tests/test_commands.py
17 17
import datetime
18 18
import importlib
19 19
import json
20
import random
21
import uuid
20 22
from io import BufferedReader, BufferedWriter, TextIOWrapper
21 23

  
24
import httmock
22 25
import py
23 26
import pytest
24 27
import webtest
25 28
from django.contrib.auth import get_user_model
26 29
from django.contrib.contenttypes.models import ContentType
30
from django.core.management import call_command
27 31
from django.utils.timezone import now
28 32
from jwcrypto.jwk import JWK, JWKSet
29 33

  
......
32 36
from authentic2.apps.journal.models import Event
33 37
from authentic2.custom_user.models import DeletedUser
34 38
from authentic2.models import UserExternalId
35
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
39
from authentic2.utils import crypto
40
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
36 41
from django_rbac.models import ADMIN_OP, Operation
37 42
from django_rbac.utils import get_operation
38 43

  
39
from .utils import call_command, login
44
from .utils import call_command, check_log, login
40 45

  
41 46
User = get_user_model()
42 47

  
......
437 442
    call_command('clean-user-exports')
438 443
    with pytest.raises(webtest.app.AppError):
439 444
        resp.click('Download CSV')
445

  
446

  
447
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid):
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    # plain claims, read directly from the user JSON returned by the provider
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='username',
        idtoken_claim=False,
        claim='username',
    )
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='email',
        idtoken_claim=False,
        claim='email',
    )
    # idtoken claims, provided by the mocked claim resolution below
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='last_name',
        idtoken_claim=True,
        claim='family_name',
    )
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='first_name',
        idtoken_claim=True,
        claim='given_name',
    )
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        sub = crypto.aes_base64url_deterministic_encrypt(
            settings.SECRET_KEY.encode('utf-8'),
            uuid.UUID(user.uuid).bytes,
            'some.provider',
        ).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/synchronization/$', method='POST')
    def synchronization_post_deletion_response(url, request):
        # report a random subset of the known accounts as deleted provider-side
        deleted = random.sample(list(OIDCAccount.objects.all()), deletion_number)
        content = {'unknown_uuids': [account.sub for account in deleted]}
        return httmock.response(
            status_code=200, headers={'content-type': 'application/json'}, content=content, request=request
        )

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/*', method='GET')
    def synchronization_get_modified_response(url, request):
        # Randomized batch of modified users, sampled only among users still
        # linked to an OIDC account: sampling all users (admin fixture, users
        # whose account was just deleted) made the updated count nondeterministic.
        accounts = random.sample(list(OIDCAccount.objects.select_related('user')), 20)
        results = []
        for count, account in enumerate(accounts):
            user_json = account.user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'
            results.append(user_json)
        return httmock.response(
            status_code=200,
            headers={'content-type': 'application/json'},
            content={'results': results},
            request=request,
        )

    with httmock.HTTMock(synchronization_post_deletion_response, synchronization_get_modified_response):
        with check_log(caplog, 'no provider supporting synchronization'):
            call_command('oidc-sync-provider', '-v1')

        oidc_provider.a2_synchronization_supported = True
        oidc_provider.save()

        with check_log(caplog, 'no provider supporting synchronization'):
            call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')

        with check_log(caplog, 'got 20 users'):
            call_command('oidc-sync-provider', '-v1')

    # search the records instead of relying on a fixed index in caplog.records
    deletion_errors = [
        record
        for record in caplog.records
        if record.levelname == 'ERROR' and 'deletion ratio is abnormally high' in record.message
    ]
    # existing users check
    if deletion_valid:
        assert OIDCAccount.objects.count() == 100 - deletion_number
        assert not deletion_errors
    else:
        assert OIDCAccount.objects.count() == 100
        assert len(deletion_errors) == 1

    # users update: every returned user still has an account, so all 20 are updated
    assert User.objects.filter(username__startswith='modified').count() == 20
    assert User.objects.filter(first_name='Mod', last_name='Ified').count() == 20
            )
440
-