Projet

Général

Profil

0001-WIP-auth_oidc-add-an-oidc-sync-provider-command-6271.patch

Paul Marillonnet, 04 août 2022 09:43

Télécharger (14,2 ko)

Voir les différences:

Subject: [PATCH] [WIP] auth_oidc: add an oidc-sync-provider command (#62710)

 .../management/commands/oidc-sync-provider.py |  37 ++++++
 .../migrations/0013_auto_20220803_1130.py     |  23 ++++
 src/authentic2_auth_oidc/models.py            |  90 ++++++++++++-
 tests/test_commands.py                        | 120 +++++++++++++++++-
 4 files changed, 268 insertions(+), 2 deletions(-)
 create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
 create mode 100644 src/authentic2_auth_oidc/migrations/0013_auto_20220803_1130.py
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2022  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import logging
18

  
19
from django.core.management.base import BaseCommand
20

  
21
from authentic2_auth_oidc.models import OIDCProvider
22

  
23

  
24
class Command(BaseCommand):
    """Management command running the synchronization of OIDC providers."""

    def add_arguments(self, parser):
        """Register the optional ``--provider`` option, holding a provider slug."""
        parser.add_argument('--provider', type=str, default=None)

    def handle(self, *args, **options):
        """Call ``perform_synchronization()`` on every selected provider.

        When ``--provider`` is given, providers are selected by slug only;
        otherwise all the providers flagged ``a2_synchronization_supported``
        are selected.
        """
        console_logger = logging.getLogger('')  # root logger, i.e. console

        slug = options['provider']
        if slug:
            lookup = {'slug': slug}
        else:
            lookup = {'a2_synchronization_supported': True}

        for oidc_provider in OIDCProvider.objects.filter(**lookup):
            oidc_provider.perform_synchronization(logger=console_logger)
src/authentic2_auth_oidc/migrations/0013_auto_20220803_1130.py
1
# Generated by Django 2.2.26 on 2022-08-03 09:30
2

  
3
from django.db import migrations, models
4

  
5

  
6
class Migration(migrations.Migration):
    """Add the synchronization columns to the ``oidcprovider`` model."""

    # must run after the previous migration of the same application
    dependencies = [('authentic2_auth_oidc', '0012_auto_20220524_1147')]

    operations = [
        # flag telling whether the provider exposes the synchronization API
        migrations.AddField(
            model_name='oidcprovider',
            name='a2_synchronization_supported',
            field=models.BooleanField(
                verbose_name='Authentic2 synchronization supported',
                default=False,
            ),
        ),
        # timestamp of the last synchronization, empty until one happened
        migrations.AddField(
            model_name='oidcprovider',
            name='last_sync_time',
            field=models.DateTimeField(
                verbose_name='Last synchronization time',
                null=True,
                blank=True,
            ),
        ),
    ]
src/authentic2_auth_oidc/models.py
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17 17
import json
18
import logging
18 19
import uuid
20
from datetime import datetime, timedelta
19 21

  
22
import requests
20 23
from django.conf import settings
21 24
from django.contrib.postgres.fields import JSONField
22 25
from django.core.exceptions import ValidationError
23 26
from django.db import models
24 27
from django.shortcuts import render
28
from django.utils.timezone import now
25 29
from django.utils.translation import gettext_lazy as _
26 30
from jwcrypto.jwk import InvalidJWKValue, JWKSet
27 31

  
28 32
from authentic2.a2_rbac.utils import get_default_ou
29 33
from authentic2.apps.authenticators.models import BaseAuthenticator
30 34
from authentic2.utils.misc import make_url, redirect_to_login
31
from authentic2.utils.template import validate_template
35
from authentic2.utils.template import Template, validate_template
32 36

  
33 37
from . import managers
34 38

  
......
107 111
        verbose_name=_('max authentication age'), blank=True, null=True
108 112
    )
109 113

  
114
    # authentic2 specific synchronization api
115
    a2_synchronization_supported = models.BooleanField(
116
        verbose_name=_('Authentic2 synchronization supported'),
117
        default=False,
118
    )
119
    last_sync_time = models.DateTimeField(
120
        verbose_name=_('Last synchronization time'),
121
        null=True,
122
        blank=True,
123
    )
124

  
110 125
    # metadata
111 126
    created = models.DateTimeField(verbose_name=_('created'), auto_now_add=True)
112 127
    modified = models.DateTimeField(verbose_name=_('modified'), auto_now=True)
......
213 228
        ]
214 229
        return render(request, template_names, context)
215 230

  
231
    def perform_synchronization(self, sync_time=None, logger=None):
        """Synchronize the local accounts of this provider with its user API.

        Two steps are run against ``self.issuer``, authenticated with the
        client id and secret:

        1. the subs of the local accounts are posted to
           ``/api/users/synchronization/`` and the accounts reported as
           unknown are deleted, unless they represent more than 5% of the
           local accounts;
        2. the users modified since ``last_sync_time`` are fetched from
           ``/api/users/`` and the claim mappings are applied again to the
           matching local users.

        ``last_sync_time`` is finally set to ``sync_time`` and saved.

        :param sync_time: timestamp recorded as the new ``last_sync_time``,
            defaults to one minute before the current time.
        :param logger: logger receiving the progress messages, defaults to
            this module's logger.
        :raises requests.HTTPError: when the provider answers with an error
            status.
        :raises Exception: when the ratio of accounts to delete is abnormally
            high; nothing is deleted nor updated in that case.

        Nothing is done when ``a2_synchronization_supported`` is false.
        """
        if not logger:
            logger = logging.getLogger(__name__)
        if not self.a2_synchronization_supported:
            return
        if not sync_time:
            sync_time = now() - timedelta(minutes=1)

        # NOTE(review): the HTTP calls below carry no timeout, an unresponsive
        # provider blocks the synchronization; pick a value fitting the project.
        auth = (self.client_id, self.client_secret)
        self._sync_delete_unknown_accounts(auth, logger)
        self._sync_update_modified_users(auth, logger)

        self.last_sync_time = sync_time
        self.save(update_fields=['last_sync_time'])

    def _sync_delete_unknown_accounts(self, auth, logger, chunk_size=100, max_deletion_ratio=0.05):
        """Delete the accounts of this provider whose sub is unknown to it."""
        url = self.issuer + '/api/users/synchronization/'

        # only the subs are needed, do not load full model instances
        subs = list(OIDCAccount.objects.filter(provider=self).values_list('sub', flat=True))
        if not subs:
            # nothing to check, and the ratio below would divide by zero
            return

        unknown_uuids = []
        for start in range(0, len(subs), chunk_size):
            known_uuids = subs[start : start + chunk_size]
            resp = requests.post(url, json={'known_uuids': known_uuids}, auth=auth)
            resp.raise_for_status()
            # tolerate a missing or null key in the response
            unknown_uuids.extend(resp.json().get('unknown_uuids') or [])
        if not unknown_uuids:
            return

        deletion_ratio = len(unknown_uuids) / len(subs)
        if deletion_ratio > max_deletion_ratio:  # something definitely went wrong
            raise Exception(
                f'deletion ratio is abnormally high ({deletion_ratio}), aborting unkwown users deletion'
            )  # todo proper exception class

        logger.info(
            'oidc-sync: deleting %s accounts unknown to provider %s', len(unknown_uuids), self.issuer
        )
        # restrict to this provider, other providers may hold identical subs
        OIDCAccount.objects.filter(provider=self, sub__in=unknown_uuids).delete()

    def _sync_update_modified_users(self, auth, logger):
        """Apply the claim mappings to the users modified since the last synchronization."""
        since = self.last_sync_time or datetime.utcfromtimestamp(0)
        url = self.issuer + '/api/users/?modified__gt=%s&claim_resolution' % since.strftime(
            '%Y-%m-%dT%H:%M:%S'
        )
        # the mappings do not change during the run, query them once
        claim_mappings = list(self.claim_mappings.all())

        updated = 0
        while url:
            resp = requests.get(url, auth=auth)
            resp.raise_for_status()
            payload = resp.json()
            url = payload.get('next')  # absent or null on the last page
            for user_dict in payload['results']:
                email = user_dict.get('email')
                if not email:
                    # an empty email could match an unrelated local user
                    continue
                try:
                    account = OIDCAccount.objects.get(provider=self, user__email=email)
                except (OIDCAccount.DoesNotExist, OIDCAccount.MultipleObjectsReturned):
                    # no local account, or an ambiguous match: leave it alone
                    continue
                if self._sync_apply_claims(account.user, user_dict, claim_mappings):
                    account.user.save()
                    updated += 1
        logger.info('oidc-sync: updated %s users from provider %s', updated, self.issuer)

    @staticmethod
    def _sync_apply_claims(user, user_dict, claim_mappings):
        """Set the mapped attributes on ``user``, return True if a value changed."""
        had_changes = False
        for claim in claim_mappings:
            if '{{' in claim.claim or '{%' in claim.claim:
                # the claim is a template rendered against the whole payload
                attribute_value = Template(claim.claim).render(context=user_dict)
            elif claim.claim in user_dict:
                attribute_value = user_dict[claim.claim]
            else:
                # claim absent from the payload: keep the local value rather
                # than overwriting it with None
                continue

            try:
                old_attribute_value = getattr(user, claim.attribute)
            except AttributeError:
                old_attribute_value = getattr(user.attributes, claim.attribute, None)
            if old_attribute_value == attribute_value:
                continue

            had_changes = True
            setattr(user, claim.attribute, attribute_value)
            try:
                setattr(user.attributes, claim.attribute, attribute_value)
            except AttributeError:
                # deliberate best effort: not an attribute handled by
                # ``user.attributes``
                pass
        return had_changes
303

  
216 304

  
217 305
class OIDCClaimMapping(models.Model):
218 306
    NOT_VERIFIED = 0
tests/test_commands.py
17 17
import datetime
18 18
import importlib
19 19
import json
20
import random
21
import uuid
20 22
from io import BufferedReader, BufferedWriter, TextIOWrapper
21 23

  
24
import httmock
22 25
import py
23 26
import pytest
24 27
import webtest
25 28
from django.contrib.auth import get_user_model
26 29
from django.contrib.contenttypes.models import ContentType
30
from django.core.management import call_command
27 31
from django.utils.timezone import now
28 32
from jwcrypto.jwk import JWK, JWKSet
29 33

  
......
32 36
from authentic2.apps.journal.models import Event
33 37
from authentic2.custom_user.models import DeletedUser
34 38
from authentic2.models import UserExternalId
35
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
39
from authentic2.utils import crypto
40
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
36 41
from django_rbac.models import ADMIN_OP, Operation
37 42
from django_rbac.utils import get_operation
38 43

  
......
437 442
    call_command('clean-user-exports')
438 443
    with pytest.raises(webtest.app.AppError):
439 444
        resp.click('Download CSV')
445

  
446

  
447
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, capsys, deletion_number, deletion_valid):
    """Check the oidc-sync-provider command against a mocked provider API.

    100 users with an OIDC account are created; the mocked provider reports
    ``deletion_number`` of them as unknown and 20 random users as modified.
    A deletion ratio above 5% must abort the whole synchronization.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
        a2_synchronization_supported=True,
    )
    # the last two mappings are id token claims
    for attribute, claim, idtoken_claim in (
        ('username', 'username', False),
        ('email', 'email', False),
        ('last_name', 'family_name', True),
        ('first_name', 'given_name', True),
    ):
        OIDCClaimMapping.objects.create(
            provider=oidc_provider,
            attribute=attribute,
            idtoken_claim=idtoken_claim,
            claim=claim,
        )

    User = get_user_model()
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        identifier = uuid.UUID(user.uuid).bytes
        sector_identifier = 'some.provider'
        cipher_args = [
            settings.SECRET_KEY.encode('utf-8'),
            identifier,
            sector_identifier,
        ]
        sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/synchronization/$', method='POST')
    def synchronization_post_deletion_response(url, request):
        headers = {'content-type': 'application/json'}
        # randomized batch of accounts reported as unknown by the provider
        unknown_accounts = random.sample(list(OIDCAccount.objects.all()), deletion_number)
        content = {'unknown_uuids': [account.sub for account in unknown_accounts]}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/$', method='GET')
    def synchronization_get_modified_response(url, request):
        headers = {'content-type': 'application/json'}
        # randomized batch of modified users; the admin fixture is left out as
        # it has no OIDC account and would randomly lower the update count
        modified_users = random.sample(list(User.objects.exclude(pk=admin.pk)), 20)
        results = []
        for count, user in enumerate(modified_users):
            user_json = user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'
            results.append(user_json)
        content = {'results': results}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    with httmock.HTTMock(synchronization_post_deletion_response, synchronization_get_modified_response):
        if not deletion_valid:
            # the abnormal deletion ratio aborts the synchronization before
            # any deletion or update happens
            with pytest.raises(Exception, match='deletion ratio is abnormally high'):
                call_command('oidc-sync-provider', '-v1')
            assert OIDCAccount.objects.count() == 100
            assert not User.objects.filter(username__startswith='modified').exists()
            assert not User.objects.filter(first_name='Mod', last_name='Ified').exists()
            return

        call_command('oidc-sync-provider', '-v1')

    out, err = capsys.readouterr()
    assert not err

    # existing users check
    assert OIDCAccount.objects.count() == 100 - deletion_number

    # users update: modified users whose account was just deleted are skipped
    expected_updates = range(20 - deletion_number, 21)
    assert User.objects.filter(username__startswith='modified').count() in expected_updates
    assert User.objects.filter(first_name='Mod', last_name='Ified').count() in expected_updates
440
-