Projet

Général

Profil

0001-auth_oidc-add-an-oidc-sync-provider-command-62710.patch

Paul Marillonnet, 01 juin 2022 09:51

Télécharger (10,6 ko)

Voir les différences:

Subject: [PATCH] auth_oidc: add an oidc-sync-provider command (#62710)

 .../management/commands/oidc-sync-provider.py | 112 ++++++++++++++++++
 tests/test_commands.py                        |  92 ++++++++++++++
 2 files changed, 204 insertions(+)
 create mode 100644 src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2022  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import datetime
18

  
19
import requests
20
from django.core.exceptions import MultipleObjectsReturned
21
from django.core.management.base import BaseCommand
22

  
23
from authentic2.utils.template import Template
24
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
25

  
26

  
27
class Command(BaseCommand):
    """Synchronize local OIDC accounts with a remote OIDC provider.

    The command runs two passes against the provider's user API:

    1. it posts the locally known ``sub`` values to the provider's
       ``/api/users/synchronization/`` endpoint and deletes the local accounts
       reported back as unknown, unless they exceed 5% of the provider's
       accounts;
    2. it fetches users modified during the last ``--delta`` seconds and
       applies the provider's claim mappings to the matching local users.
    """

    # number of account identifiers sent per synchronization request
    chunk_size = 100
    # above this ratio of unknown accounts the deletion is considered unsafe
    max_deletion_ratio = 0.05

    def add_arguments(self, parser):
        """Declare the ``--delta`` (seconds, default 300) and ``--provider`` (slug) options."""
        parser.add_argument('--delta', metavar='DELTA', type=int, default=300)
        parser.add_argument('--provider', type=str, default=None)

    def handle(self, *args, **options):
        """Resolve the provider from its slug, then run both synchronization passes.

        Writes an error to stdout and returns early when no provider slug is
        given or when the slug matches no ``OIDCProvider``.
        """
        verbose = int(options['verbosity']) > 0
        delta = options['delta']
        provider = options['provider']

        if not provider:
            self.stdout.write(self.style.ERROR('no declared provider, exiting...'))
            return
        try:
            provider = OIDCProvider.objects.get(slug=provider)
        except OIDCProvider.DoesNotExist:
            self.stdout.write(self.style.ERROR(f'provider {provider} not found, exiting...'))
            return

        auth = (provider.client_id, provider.client_secret)
        self._delete_unknown_accounts(provider, auth)
        self._update_modified_users(provider, auth, delta, verbose)

    def _delete_unknown_accounts(self, provider, auth):
        """Delete the provider's local accounts that the provider reports as unknown."""
        accounts = OIDCAccount.objects.filter(provider=provider)
        subs = list(accounts.values_list('sub', flat=True))
        if not subs:
            # nothing to check, and the ratio below would divide by zero
            return

        url = provider.issuer + '/api/users/synchronization/'
        unknown_uuids = []
        for start in range(0, len(subs), self.chunk_size):
            known_uuids = subs[start : start + self.chunk_size]
            resp = requests.post(url, json={'known_uuids': known_uuids}, auth=auth)
            resp.raise_for_status()
            # tolerate a missing or null key instead of failing on extend(None)
            unknown_uuids.extend(resp.json().get('unknown_uuids') or [])

        deletion_ratio = len(unknown_uuids) / len(subs)
        if deletion_ratio > self.max_deletion_ratio:  # something definitely went wrong
            self.stdout.write(
                self.style.ERROR(
                    f'deletion ratio is abnormally high ({deletion_ratio}), aborting unkwown users deletion'
                )
            )
        else:
            # stay within this provider's accounts: another provider may hold an
            # account with the same sub
            accounts.filter(sub__in=unknown_uuids).delete()

    def _update_modified_users(self, provider, auth, delta, verbose):
        """Apply claim mappings to local users modified remotely in the last ``delta`` seconds."""
        # NOTE(review): this is a naive local timestamp; confirm which timezone
        # the provider API expects for modified__gt.
        since = datetime.datetime.now() - datetime.timedelta(seconds=delta)
        url = provider.issuer + '/api/users/?modified__gt=%s' % since.strftime('%Y-%m-%dT%H:%M:%S')
        # the mappings do not depend on the user, fetch them once
        claim_mappings = list(provider.claim_mappings.all())

        while url:
            resp = requests.get(url, auth=auth)
            resp.raise_for_status()
            payload = resp.json()
            # follow pagination until the API stops returning a next page
            url = payload.get('next')
            results = payload['results']
            if verbose:
                self.stdout.write('got %s users' % len(results))
            for user_dict in results:
                try:
                    account = OIDCAccount.objects.get(provider=provider, user__email=user_dict['email'])
                except (OIDCAccount.DoesNotExist, MultipleObjectsReturned):
                    # no local account, or an ambiguous email: skip this user
                    continue
                if self._apply_claims(account.user, user_dict, claim_mappings):
                    if verbose:
                        self.stdout.write('had changes, saving %r' % account.user)
                    account.user.save()

    @staticmethod
    def _apply_claims(user, user_dict, claim_mappings):
        """Set mapped attributes on ``user`` from ``user_dict``; return True if any value changed."""
        had_changes = False
        for claim in claim_mappings:
            if '{{' in claim.claim or '{%' in claim.claim:
                # the claim is a template rendered against the whole user payload
                attribute_value = Template(claim.claim).render(context=user_dict)
            else:
                attribute_value = user_dict.get(claim.claim)
            # read the current value from the user itself, then from its
            # attributes, and fall back to None when neither has it
            try:
                old_attribute_value = getattr(user, claim.attribute)
            except AttributeError:
                try:
                    old_attribute_value = getattr(user.attributes, claim.attribute)
                except AttributeError:
                    old_attribute_value = None
            if old_attribute_value == attribute_value:
                continue
            had_changes = True
            setattr(user, claim.attribute, attribute_value)
            try:
                setattr(user.attributes, claim.attribute, attribute_value)
            except AttributeError:
                pass
        return had_changes
tests/test_commands.py
17 17
import datetime
18 18
import importlib
19 19
import json
20
import random
21
import uuid
20 22
from io import BufferedReader, BufferedWriter, TextIOWrapper
21 23

  
24
import httmock
22 25
import py
23 26
import pytest
24 27
import webtest
25 28
from django.contrib.auth import get_user_model
26 29
from django.contrib.contenttypes.models import ContentType
30
from django.core.management import call_command
27 31
from django.utils.timezone import now
28 32
from jwcrypto.jwk import JWK, JWKSet
29 33

  
......
32 36
from authentic2.apps.journal.models import Event
33 37
from authentic2.custom_user.models import DeletedUser
34 38
from authentic2.models import UserExternalId
39
from authentic2.utils import crypto
35 40
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider
36 41
from django_rbac.models import ADMIN_OP, Operation
37 42
from django_rbac.utils import get_operation
......
437 442
    call_command('clean-user-exports')
438 443
    with pytest.raises(webtest.app.AppError):
439 444
        resp.click('Download CSV')
445

  
446

  
447
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, capsys, deletion_number, deletion_valid):
    """Exercise the oidc-sync-provider command against a mocked provider API.

    One hundred accounts are created for the provider. The mocked
    synchronization endpoint reports ``deletion_number`` of them as unknown;
    the command must delete them only while the ratio stays at or below 5%.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    User = get_user_model()
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            # interpolate the index so that each user gets a distinct email;
            # the command looks accounts up by email
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        identifier = uuid.UUID(user.uuid).bytes
        sector_identifier = 'some.provider'
        cipher_args = [
            settings.SECRET_KEY.encode('utf-8'),
            identifier,
            sector_identifier,
        ]
        sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    def synchronization_post_deletion_response(url, request):
        # report a random batch of existing accounts as unknown to the provider
        headers = {'content-type': 'application/json'}
        content = {
            'unknown_uuids': [
                account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)
            ]
        }
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    def synchronization_get_modified_response(url, request):
        headers = {'content-type': 'application/json'}
        # randomized batch of modified users
        modified_users = random.sample(list(User.objects.all()), 20)
        for count, user in enumerate(modified_users):
            user.username = f'modified_{count}'
            user.first_name = 'Mod'
            user.last_name = 'Ified'
            user.save()
        content = {'results': [user.to_json() for user in modified_users]}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    post_mock = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/synchronization/$',
        method='POST',
    )(synchronization_post_deletion_response)
    get_mock = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/*',
        method='GET',
    )(synchronization_get_modified_response)

    with httmock.HTTMock(post_mock, get_mock):
        # missing --provider option
        call_command('oidc-sync-provider', '--delta', '300', '-v1')
        out, err = capsys.readouterr()
        assert 'no declared provider' in out

        # slug matching no provider
        call_command('oidc-sync-provider', '--delta', '300', '--provider', 'unknown-provider', '-v1')
        out, err = capsys.readouterr()
        assert 'provider unknown-provider not found' in out

        call_command('oidc-sync-provider', '--delta', '300', '--provider', 'some-provider', '-v1')
        out, err = capsys.readouterr()
        assert not err
        if deletion_valid:
            # existing users check
            assert OIDCAccount.objects.count() == 100 - deletion_number
        else:
            assert 'deletion ratio is abnormally high' in out
            assert OIDCAccount.objects.count() == 100

        # users update
        assert 'got 20 users' in out
        assert User.objects.filter(username__startswith='modified').count() == 20
        assert User.objects.filter(first_name='Mod', last_name='Ified').count() == 20
440
-