0001-auth_oidc-add-an-oidc-sync-provider-command-62710.patch
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py | ||
---|---|---|
1 |
# authentic2 - versatile identity manager |
|
2 |
# Copyright (C) 2010-2022 Entr'ouvert |
|
3 |
# |
|
4 |
# This program is free software: you can redistribute it and/or modify it |
|
5 |
# under the terms of the GNU Affero General Public License as published |
|
6 |
# by the Free Software Foundation, either version 3 of the License, or |
|
7 |
# (at your option) any later version. |
|
8 |
# |
|
9 |
# This program is distributed in the hope that it will be useful, |
|
10 |
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
11 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
12 |
# GNU Affero General Public License for more details. |
|
13 |
# |
|
14 |
# You should have received a copy of the GNU Affero General Public License |
|
15 |
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
16 | ||
17 |
import datetime |
|
18 | ||
19 |
import requests |
|
20 |
from django.core.exceptions import MultipleObjectsReturned |
|
21 |
from django.core.management.base import BaseCommand |
|
22 | ||
23 |
from authentic2.utils.template import Template |
|
24 |
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider |
|
25 | ||
26 | ||
27 |
class Command(BaseCommand):
    """Synchronize the local OIDC accounts of a provider with the provider's users.

    Two passes are made against the provider's API, authenticated with the
    provider's client id and secret:

    1. the subs of all local accounts linked to the provider are POSTed, in
       chunks, to ``/api/users/synchronization/``; accounts whose sub is
       reported back in ``unknown_uuids`` are deleted, unless they represent
       more than ``max_deletion_ratio`` of the provider's accounts (in which
       case the deletion is aborted as a probable provider-side error);
    2. users modified during the last ``--delta`` seconds are fetched from
       ``/api/users/`` (following the ``next`` pagination links) and the
       provider's claim mappings are applied to the matching local users.
    """

    help = 'Synchronize the OIDC accounts of a provider with the users it declares.'

    # Number of subs sent in each synchronization request.
    chunk_size = 100
    # Above this ratio of unknown accounts, something went wrong: do not delete.
    max_deletion_ratio = 0.05
    # Timeout, in seconds, of each HTTP request to the provider; without it a
    # stalled provider would block the command forever.
    request_timeout = 60

    def add_arguments(self, parser):
        parser.add_argument(
            '--delta',
            metavar='DELTA',
            type=int,
            default=300,
            help='only update users modified during the last DELTA seconds (default: 300)',
        )
        parser.add_argument(
            '--provider', type=str, default=None, help='slug of the OIDC provider to synchronize'
        )

    def handle(self, *args, **options):
        verbose = int(options['verbosity']) > 0
        delta = options['delta']
        slug = options['provider']

        if not slug:
            self.stdout.write(self.style.ERROR('no declared provider, exiting...'))
            return
        try:
            provider = OIDCProvider.objects.get(slug=slug)
        except OIDCProvider.DoesNotExist:
            self.stdout.write(self.style.ERROR(f'provider {slug} not found, exiting...'))
            return

        # strip any trailing slash so that API URLs are not built with '//'
        base_url = provider.issuer.rstrip('/')
        auth = (provider.client_id, provider.client_secret)

        self._delete_unknown_accounts(provider, base_url, auth)
        self._update_modified_users(provider, base_url, auth, delta, verbose)

    def _delete_unknown_accounts(self, provider, base_url, auth):
        """Delete the provider's local accounts whose sub the provider reports as unknown."""
        url = base_url + '/api/users/synchronization/'
        accounts = OIDCAccount.objects.filter(provider=provider)
        subs = list(accounts.values_list('sub', flat=True))
        if not subs:
            # nothing to check (and the deletion ratio below would divide by zero)
            return

        unknown_uuids = []
        for start in range(0, len(subs), self.chunk_size):
            resp = requests.post(
                url,
                json={'known_uuids': subs[start : start + self.chunk_size]},
                auth=auth,
                timeout=self.request_timeout,
            )
            resp.raise_for_status()
            # a missing or null key is treated as "no unknown sub in this chunk"
            unknown_uuids.extend(resp.json().get('unknown_uuids') or [])

        deletion_ratio = len(unknown_uuids) / len(subs)
        if deletion_ratio > self.max_deletion_ratio:
            self.stdout.write(
                self.style.ERROR(
                    f'deletion ratio is abnormally high ({deletion_ratio}), aborting unknown users deletion'
                )
            )
        else:
            # restricted to this provider: another provider's account may share a sub value
            accounts.filter(sub__in=unknown_uuids).delete()

    def _update_modified_users(self, provider, base_url, auth, delta, verbose):
        """Apply the provider's claim mappings to local users recently modified on the provider."""
        # NOTE(review): naive local time formatted without an offset; presumably
        # interpreted by the provider in its own timezone — confirm.
        since = (datetime.datetime.now() - datetime.timedelta(seconds=delta)).strftime('%Y-%m-%dT%H:%M:%S')
        url = base_url + '/api/users/?modified__gt=%s&claim_resolution' % since
        # mappings are not modified by this command: query them once, not once per user
        claim_mappings = list(provider.claim_mappings.all())

        while url:
            resp = requests.get(url, auth=auth, timeout=self.request_timeout)
            resp.raise_for_status()
            data = resp.json()
            url = data.get('next')
            results = data['results']
            if verbose:
                self.stdout.write('got %s users' % len(results))
            for user_dict in results:
                email = user_dict.get('email')
                if not email:
                    # an empty email would match an arbitrary local user without email
                    continue
                try:
                    account = OIDCAccount.objects.get(provider=provider, user__email=email)
                except (OIDCAccount.DoesNotExist, MultipleObjectsReturned):
                    # no local match, or an ambiguous one: do not guess which user to update
                    continue
                user = account.user
                if self._apply_claim_mappings(user, user_dict, claim_mappings):
                    if verbose:
                        self.stdout.write('had changes, saving %r' % user)
                    user.save()

    @staticmethod
    def _apply_claim_mappings(user, user_dict, claim_mappings):
        """Copy the mapped claims of ``user_dict`` onto ``user``.

        A claim containing ``{{`` or ``{%`` is rendered as a template against
        ``user_dict``; otherwise it is a key of ``user_dict``. The current value
        is read from the user model, then from ``user.attributes``, and is
        ``None`` if neither has it. The user is not saved here.

        Returns True if at least one attribute value changed.
        """
        had_changes = False
        for claim in claim_mappings:
            if '{{' in claim.claim or '{%' in claim.claim:
                attribute_value = Template(claim.claim).render(context=user_dict)
            else:
                attribute_value = user_dict.get(claim.claim)
            try:
                old_attribute_value = getattr(user, claim.attribute)
            except AttributeError:
                try:
                    old_attribute_value = getattr(user.attributes, claim.attribute)
                except AttributeError:
                    old_attribute_value = None
            if old_attribute_value == attribute_value:
                continue
            had_changes = True
            setattr(user, claim.attribute, attribute_value)
            try:
                setattr(user.attributes, claim.attribute, attribute_value)
            except AttributeError:
                # best effort: presumably raised when claim.attribute is not an
                # extended attribute, in which case the setattr above suffices
                pass
        return had_changes
tests/test_commands.py | ||
---|---|---|
17 | 17 |
import datetime |
18 | 18 |
import importlib |
19 | 19 |
import json |
20 |
import random |
|
21 |
import uuid |
|
20 | 22 |
from io import BufferedReader, BufferedWriter, TextIOWrapper |
21 | 23 | |
24 |
import httmock |
|
22 | 25 |
import py |
23 | 26 |
import pytest |
24 | 27 |
import webtest |
25 | 28 |
from django.contrib.auth import get_user_model |
26 | 29 |
from django.contrib.contenttypes.models import ContentType |
30 |
from django.core.management import call_command |
|
27 | 31 |
from django.utils.timezone import now |
28 | 32 |
from jwcrypto.jwk import JWK, JWKSet |
29 | 33 | |
... | ... | |
32 | 36 |
from authentic2.apps.journal.models import Event |
33 | 37 |
from authentic2.custom_user.models import DeletedUser |
34 | 38 |
from authentic2.models import UserExternalId |
35 |
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider |
|
39 |
from authentic2.utils import crypto |
|
40 |
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider |
|
36 | 41 |
from django_rbac.models import ADMIN_OP, Operation |
37 | 42 |
from django_rbac.utils import get_operation |
38 | 43 | |
... | ... | |
437 | 442 |
call_command('clean-user-exports') |
438 | 443 |
with pytest.raises(webtest.app.AppError): |
439 | 444 |
resp.click('Download CSV') |
445 | ||
446 | ||
447 |
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, capsys, deletion_number, deletion_valid):
    """oidc-sync-provider deletes unknown accounts and updates modified users.

    The provider API is mocked: the synchronization endpoint reports
    ``deletion_number`` random subs (out of 100 accounts) as unknown, so 10 is
    above the command's 5% threshold and must abort the deletion; the users
    endpoint returns 20 modified users, all linked to a remaining account.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    # claims with the same name as the local attribute
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='username',
        idtoken_claim=False,
        claim='username',
    )
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='email',
        idtoken_claim=False,
        claim='email',
    )
    # id token claims, whose names differ from the local attributes
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='last_name',
        idtoken_claim=True,
        claim='family_name',
    )
    OIDCClaimMapping.objects.create(
        provider=oidc_provider,
        attribute='first_name',
        idtoken_claim=True,
        claim='given_name',
    )
    User = get_user_model()
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        identifier = uuid.UUID(user.uuid).bytes
        sector_identifier = 'some.provider'
        cipher_args = [
            settings.SECRET_KEY.encode('utf-8'),
            identifier,
            sector_identifier,
        ]
        sub = crypto.aes_base64url_deterministic_encrypt(*cipher_args).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    # subs reported as unknown by the mocked synchronization endpoint
    unknown_subs = []

    def synchronization_post_deletion_response(url, request):
        headers = {'content-type': 'application/json'}
        sample = [account.sub for account in random.sample(list(OIDCAccount.objects.all()), deletion_number)]
        unknown_subs.extend(sample)
        content = {'unknown_uuids': sample}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    def synchronization_get_modified_response(url, request):
        headers = {'content-type': 'application/json'}
        # randomized batch of modified users, taken among users that still have
        # an account (deletions happen before this request) so that every one of
        # them can be matched; users without an account, e.g. any user created by
        # the admin fixture, would be skipped and make the counts below random
        accounts = random.sample(list(OIDCAccount.objects.select_related('user')), 20)
        results = []
        for count, account in enumerate(accounts):
            user_json = account.user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'
            results.append(user_json)
        content = {'results': results}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    post_mock = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/synchronization/$',
        method='POST',
    )(synchronization_post_deletion_response)
    get_mock = httmock.urlmatch(
        netloc=r'some\.provider',
        path=r'^/api/users/$',
        method='GET',
    )(synchronization_get_modified_response)

    with httmock.HTTMock(post_mock, get_mock):
        # argument checks: the command exits before any request
        call_command('oidc-sync-provider', '--delta', '300', '-v1')
        out, err = capsys.readouterr()
        assert 'no declared provider' in out

        call_command('oidc-sync-provider', '--delta', '300', '--provider', 'unknown-provider', '-v1')
        out, err = capsys.readouterr()
        assert 'provider unknown-provider not found' in out

        call_command('oidc-sync-provider', '--delta', '300', '--provider', 'some-provider', '-v1')
        out, err = capsys.readouterr()
        assert not err
        if deletion_valid:
            # existing users check: exactly the reported accounts are gone
            assert OIDCAccount.objects.count() == 100 - deletion_number
            assert not OIDCAccount.objects.filter(sub__in=unknown_subs).exists()
        else:
            assert 'deletion ratio is abnormally high' in out
            assert OIDCAccount.objects.count() == 100

        # users update: every returned user has an account and is updated
        assert 'got 20 users' in out
        assert User.objects.filter(username__startswith='modified').count() == 20
        assert User.objects.filter(first_name='Mod', last_name='Ified').count() == 20
|
440 |
- |