Projet

Général

Profil

0001-auth_oidc-add-an-oidc-sync-provider-command-62710.patch

Paul Marillonnet, 01 juin 2022 15:41

Télécharger (11,7 ko)

Voir les différences:

Subject: [PATCH] auth_oidc: add an oidc-sync-provider command (#62710)

 .../management/commands/oidc-sync-provider.py | 112 +++++++++++++++
 tests/test_commands.py                        | 129 +++++++++++++++++-
 2 files changed, 240 insertions(+), 1 deletion(-)
 create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2022  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import datetime
18

  
19
import requests
20
from django.core.exceptions import MultipleObjectsReturned
21
from django.core.management.base import BaseCommand
22

  
23
from authentic2.utils.template import Template
24
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
25

  
26

  
27
class Command(BaseCommand):
    """Synchronize the local OIDC accounts of one provider with that provider.

    Two passes are made against the provider API, authenticated with the
    provider's client id and client secret:

    1. the ``sub`` of every local account of the provider is sent, by chunks
       of ``chunk_size``, to ``/api/users/synchronization/``; the accounts the
       provider reports as unknown are deleted, unless they represent more than
       ``max_deletion_ratio`` of the provider's accounts, in which case the
       answer is considered abnormal and nothing is deleted;
    2. the users modified during the last ``--delta`` seconds are fetched from
       ``/api/users/`` (following the ``next`` pagination links) and the
       provider's claim mappings are applied to the local user of this
       provider having the same email.
    """

    help = 'Synchronize local accounts with an OIDC provider'

    # Share of accounts reported unknown above which deletion is refused.
    max_deletion_ratio = 0.05
    # Number of subs sent in each synchronization request.
    chunk_size = 100

    def add_arguments(self, parser):
        parser.add_argument('--delta', metavar='DELTA', type=int, default=300)
        parser.add_argument('--provider', type=str, default=None)
        parser.add_argument(
            '--timeout',
            metavar='SECONDS',
            type=float,
            default=30,
            help='timeout of each HTTP request made to the provider',
        )

    def handle(self, *args, **options):
        verbose = int(options['verbosity']) > 0
        slug = options['provider']

        if not slug:
            self.stdout.write(self.style.ERROR('no declared provider, exiting...'))
            return
        try:
            provider = OIDCProvider.objects.get(slug=slug)
        except OIDCProvider.DoesNotExist:
            self.stdout.write(self.style.ERROR(f'provider {slug} not found, exiting...'))
            return

        # keyword arguments shared by every request made to the provider; the
        # timeout prevents the command from hanging on an unresponsive provider
        request_kwargs = {
            'auth': (provider.client_id, provider.client_secret),
            'timeout': options['timeout'],
        }
        # avoid a double slash in API URLs when the issuer ends with '/'
        base_url = provider.issuer.rstrip('/')

        self._delete_unknown_accounts(provider, base_url, request_kwargs)
        self._update_modified_users(provider, base_url, request_kwargs, options['delta'], verbose)

    def _delete_unknown_accounts(self, provider, base_url, request_kwargs):
        """Delete this provider's accounts whose sub the provider reports as unknown."""
        url = base_url + '/api/users/synchronization/'
        subs = list(OIDCAccount.objects.filter(provider=provider).values_list('sub', flat=True))
        if not subs:
            # nothing to check (and no ratio to compute)
            return

        unknown_uuids = set()
        for start in range(0, len(subs), self.chunk_size):
            resp = requests.post(
                url, json={'known_uuids': subs[start : start + self.chunk_size]}, **request_kwargs
            )
            resp.raise_for_status()
            unknown_uuids.update(resp.json().get('unknown_uuids') or [])

        deletion_ratio = len(unknown_uuids) / len(subs)
        if deletion_ratio > self.max_deletion_ratio:  # something definitely went wrong
            self.stdout.write(
                self.style.ERROR(
                    f'deletion ratio is abnormally high ({deletion_ratio}), aborting unknown users deletion'
                )
            )
            return
        if unknown_uuids:
            # scoped to the provider: subs of other providers may collide
            OIDCAccount.objects.filter(provider=provider, sub__in=unknown_uuids).delete()

    def _update_modified_users(self, provider, base_url, request_kwargs, delta, verbose):
        """Fetch users modified in the last ``delta`` seconds and update the local copies."""
        # NOTE(review): naive local time sent without timezone information;
        # confirm the provider interprets ``modified__gt`` in the same timezone.
        since = datetime.datetime.now() - datetime.timedelta(seconds=delta)
        url = base_url + '/api/users/?modified__gt=%s&claim_resolution' % since.strftime('%Y-%m-%dT%H:%M:%S')
        claim_mappings = list(provider.claim_mappings.all())
        while url:
            resp = requests.get(url, **request_kwargs)
            resp.raise_for_status()
            data = resp.json()
            url = data.get('next')
            results = data['results']
            if verbose:
                self.stdout.write('got %s users' % len(results))
            for user_dict in results:
                self._update_user(provider, claim_mappings, user_dict, verbose)

    def _update_user(self, provider, claim_mappings, user_dict, verbose):
        """Apply ``claim_mappings`` from ``user_dict`` to the matching local user, saving it on change."""
        email = user_dict.get('email')
        if not email:
            # without an email the local user cannot be identified
            return
        try:
            account = OIDCAccount.objects.select_related('user').get(provider=provider, user__email=email)
        except (OIDCAccount.DoesNotExist, MultipleObjectsReturned):
            # no unambiguous local account for this provider user: skip it
            return

        user = account.user
        had_changes = False
        for claim in claim_mappings:
            if '{{' in claim.claim or '{%' in claim.claim:
                # templated claim, rendered with the user payload as context
                attribute_value = Template(claim.claim).render(context=user_dict)
            else:
                attribute_value = user_dict.get(claim.claim)

            # current value: a user field first, then a user attribute, else None
            try:
                old_attribute_value = getattr(user, claim.attribute)
            except AttributeError:
                try:
                    old_attribute_value = getattr(user.attributes, claim.attribute)
                except AttributeError:
                    old_attribute_value = None
            if old_attribute_value == attribute_value:
                continue

            had_changes = True
            setattr(user, claim.attribute, attribute_value)
            try:
                setattr(user.attributes, claim.attribute, attribute_value)
            except AttributeError:
                pass

        if had_changes:
            if verbose:
                self.stdout.write('had changes, saving %r' % user)
            user.save()
tests/test_commands.py
17 17
import datetime
18 18
import importlib
19 19
import json
20
import random
21
import uuid
20 22
from io import BufferedReader, BufferedWriter, TextIOWrapper
21 23

  
24
import httmock
22 25
import py
23 26
import pytest
24 27
import webtest
25 28
from django.contrib.auth import get_user_model
26 29
from django.contrib.contenttypes.models import ContentType
30
from django.core.management import call_command
27 31
from django.utils.timezone import now
28 32
from jwcrypto.jwk import JWK, JWKSet
29 33

  
......
32 36
from authentic2.apps.journal.models import Event
33 37
from authentic2.custom_user.models import DeletedUser
34 38
from authentic2.models import UserExternalId
35
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
39
from authentic2.utils import crypto
40
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider
36 41
from django_rbac.models import ADMIN_OP, Operation
37 42
from django_rbac.utils import get_operation
38 43

  
......
437 442
    call_command('clean-user-exports')
438 443
    with pytest.raises(webtest.app.AppError):
439 444
        resp.click('Download CSV')
445

  
446

  
447
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, capsys, deletion_number, deletion_valid):
    """Run oidc-sync-provider against a mocked provider API.

    100 local accounts are created for the provider. The mocked
    synchronization endpoint reports ``deletion_number`` of them as unknown
    (more than 5% of the accounts must be refused), and the mocked users
    endpoint returns 20 randomly picked users with modified claims, which must
    be copied onto the local users.
    """
    User = get_user_model()
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    # claims read as-is from the user payload
    for attribute in ('username', 'email'):
        OIDCClaimMapping.objects.create(
            provider=oidc_provider,
            attribute=attribute,
            idtoken_claim=False,
            claim=attribute,
        )
    # id token claims, whose resolution by the provider is mocked below
    for attribute, claim in (('last_name', 'family_name'), ('first_name', 'given_name')):
        OIDCClaimMapping.objects.create(
            provider=oidc_provider,
            attribute=attribute,
            idtoken_claim=True,
            claim=claim,
        )

    # only these users have an OIDC account (the admin fixture user has none)
    created_users = []
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        sub = crypto.aes_base64url_deterministic_encrypt(
            settings.SECRET_KEY.encode('utf-8'),
            uuid.UUID(user.uuid).bytes,
            'some.provider',
        ).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)
        created_users.append(user)

    reported_unknown_subs = []

    def synchronization_post_deletion_response(url, request):
        subs = [account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)]
        reported_unknown_subs.extend(subs)
        return httmock.response(
            status_code=200,
            headers={'content-type': 'application/json'},
            content={'unknown_uuids': subs},
            request=request,
        )

    def synchronization_get_modified_response(url, request):
        # randomized batch of modified users, all having a local account, so
        # that the admin fixture user cannot be picked and skew the counts
        results = []
        for count, user in enumerate(random.sample(created_users, 20)):
            user_json = user.to_json()
            user_json.update(
                username=f'modified_{count}',
                first_name='Mod',
                last_name='Ified',
                # mocking claim resolution by oidc provider
                given_name='Mod',
                family_name='Ified',
            )
            results.append(user_json)
        return httmock.response(
            status_code=200,
            headers={'content-type': 'application/json'},
            content={'results': results},
            request=request,
        )

    post_handler = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/synchronization/$',
        method='POST',
    )(synchronization_post_deletion_response)
    get_handler = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/$',
        method='GET',
    )(synchronization_get_modified_response)

    with httmock.HTTMock(post_handler, get_handler):
        call_command('oidc-sync-provider', '--delta', '300', '-v1')
        out, err = capsys.readouterr()
        assert 'no declared provider' in out

        call_command('oidc-sync-provider', '--delta', '300', '--provider', 'unknown-provider', '-v1')
        out, err = capsys.readouterr()
        assert 'provider unknown-provider not found' in out

        call_command('oidc-sync-provider', '--delta', '300', '--provider', 'some-provider', '-v1')
        out, err = capsys.readouterr()
        assert not err
        if deletion_valid:
            # existing users check: exactly the reported accounts are gone
            assert OIDCAccount.objects.count() == 100 - deletion_number
            assert not OIDCAccount.objects.filter(sub__in=reported_unknown_subs).exists()
            # modified users whose account was just deleted cannot be updated
            min_updated = 20 - deletion_number
        else:
            assert 'deletion ratio is abnormally high' in out
            assert OIDCAccount.objects.count() == 100
            min_updated = 20

        # users update
        assert 'got 20 users' in out
        assert User.objects.filter(username__startswith='modified').count() in range(min_updated, 21)
        assert User.objects.filter(first_name='Mod', last_name='Ified').count() in range(min_updated, 21)
440
-