0001-auth_oidc-add-an-oidc-sync-provider-command-62710.patch
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py | ||
---|---|---|
1 |
# authentic2 - versatile identity manager |
|
2 |
# Copyright (C) 2010-2022 Entr'ouvert |
|
3 |
# |
|
4 |
# This program is free software: you can redistribute it and/or modify it |
|
5 |
# under the terms of the GNU Affero General Public License as published |
|
6 |
# by the Free Software Foundation, either version 3 of the License, or |
|
7 |
# (at your option) any later version. |
|
8 |
# |
|
9 |
# This program is distributed in the hope that it will be useful, |
|
10 |
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
11 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
12 |
# GNU Affero General Public License for more details. |
|
13 |
# |
|
14 |
# You should have received a copy of the GNU Affero General Public License |
|
15 |
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
16 | ||
17 |
from django.core.management.base import BaseCommand |
|
18 | ||
19 |
from authentic2_auth_oidc.models import OIDCProvider |
|
20 | ||
21 | ||
22 |
class Command(BaseCommand):
    """Run the Authentic2 synchronization on OIDC providers supporting it.

    Only providers with ``a2_synchronization_supported`` set are considered;
    ``--provider`` further restricts the run to the provider with that slug.
    """

    help = 'Synchronize users with OIDC providers supporting Authentic2 synchronization'

    def add_arguments(self, parser):
        parser.add_argument(
            '--provider',
            type=str,
            default=None,
            help='slug of the provider to synchronize (default: every provider supporting synchronization)',
        )

    def handle(self, *args, **options):
        verbose = int(options['verbosity']) > 0
        slug = options['provider']

        providers = OIDCProvider.objects.filter(a2_synchronization_supported=True)
        if slug:
            providers = providers.filter(slug=slug)
        # evaluate the queryset once, so that the providers reported below are
        # exactly the ones synchronized (and to avoid repeated queries)
        providers = list(providers)
        if not providers:
            self.stdout.write('no provider supporting synchronization found, exiting')
            return

        self.stdout.write(
            'got %s provider(s): %s' % (len(providers), ' '.join(provider.slug for provider in providers))
        )
        for provider in providers:
            provider.perform_synchronization(verbose=verbose)
src/authentic2_auth_oidc/migrations/0013_auto_20220803_1130.py | ||
---|---|---|
1 |
# Generated by Django 2.2.26 on 2022-08-03 09:30 |
|
2 | ||
3 |
from django.db import migrations, models |
|
4 | ||
5 | ||
6 |
class Migration(migrations.Migration):
    """Add the Authentic2 synchronization columns to ``OIDCProvider``.

    ``a2_synchronization_supported`` is a boolean defaulting to ``False``;
    ``last_sync_time`` is an optional, nullable datetime.
    """

    dependencies = [('authentic2_auth_oidc', '0012_auto_20220524_1147')]

    operations = [
        migrations.AddField(
            'oidcprovider',
            'a2_synchronization_supported',
            field=models.BooleanField(verbose_name='Authentic2 synchronization supported', default=False),
        ),
        migrations.AddField(
            'oidcprovider',
            'last_sync_time',
            field=models.DateTimeField(verbose_name='Last synchronization time', null=True, blank=True),
        ),
    ]
src/authentic2_auth_oidc/models.py | ||
---|---|---|
15 | 15 |
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
16 | 16 | |
17 | 17 |
import json |
18 |
import logging |
|
18 | 19 |
import uuid |
20 |
from datetime import datetime, timedelta |
|
19 | 21 | |
22 |
import requests |
|
20 | 23 |
from django.conf import settings |
21 | 24 |
from django.contrib.postgres.fields import JSONField |
22 | 25 |
from django.core.exceptions import ValidationError |
23 | 26 |
from django.db import models |
24 | 27 |
from django.shortcuts import render |
28 |
from django.utils.timezone import now |
|
25 | 29 |
from django.utils.translation import gettext_lazy as _ |
26 | 30 |
from jwcrypto.jwk import InvalidJWKValue, JWKSet |
27 | 31 | |
28 | 32 |
from authentic2.a2_rbac.utils import get_default_ou |
29 | 33 |
from authentic2.apps.authenticators.models import BaseAuthenticator |
30 | 34 |
from authentic2.utils.misc import make_url, redirect_to_login |
31 |
from authentic2.utils.template import validate_template |
|
35 |
from authentic2.utils.template import Template, validate_template
|
|
32 | 36 | |
33 | 37 |
from . import managers |
34 | 38 | |
... | ... | |
107 | 111 |
verbose_name=_('max authentication age'), blank=True, null=True |
108 | 112 |
) |
109 | 113 | |
114 |
# authentic2 specific synchronization api |
|
115 |
a2_synchronization_supported = models.BooleanField( |
|
116 |
verbose_name=_('Authentic2 synchronization supported'), |
|
117 |
default=False, |
|
118 |
) |
|
119 |
last_sync_time = models.DateTimeField( |
|
120 |
verbose_name=_('Last synchronization time'), |
|
121 |
null=True, |
|
122 |
blank=True, |
|
123 |
) |
|
124 | ||
110 | 125 |
# metadata |
111 | 126 |
created = models.DateTimeField(verbose_name=_('created'), auto_now_add=True) |
112 | 127 |
modified = models.DateTimeField(verbose_name=_('modified'), auto_now=True) |
... | ... | |
213 | 228 |
] |
214 | 229 |
return render(request, template_names, context) |
215 | 230 | |
231 |
def perform_synchronization(self, sync_time=None, verbose=False): |
|
232 |
logger = logging.getLogger(__name__) |
|
233 | ||
234 |
if not self.a2_synchronization_supported: |
|
235 |
logger.error('OIDC provider %s does not support synchronization', self.slug) |
|
236 |
return |
|
237 |
if not sync_time: |
|
238 |
sync_time = now() - timedelta(minutes=1) |
|
239 | ||
240 |
# check all existing users |
|
241 |
def chunks(l, n): |
|
242 |
for i in range(0, len(l), n): |
|
243 |
yield l[i : i + n] |
|
244 | ||
245 |
url = self.issuer + '/api/users/synchronization/' |
|
246 | ||
247 |
unknown_uuids = [] |
|
248 |
auth = (self.client_id, self.client_secret) |
|
249 |
for accounts in chunks(OIDCAccount.objects.filter(provider=self), 100): |
|
250 |
subs = [x.sub for x in accounts] |
|
251 |
resp = requests.post(url, json={'known_uuids': subs}, auth=auth) |
|
252 |
resp.raise_for_status() |
|
253 |
unknown_uuids.extend(resp.json().get('unknown_uuids')) |
|
254 |
deletion_ratio = len(unknown_uuids) / OIDCAccount.objects.filter(provider=self).count() |
|
255 |
if deletion_ratio > 0.05: # higher than 5%, something definitely went wrong |
|
256 |
logger.error( |
|
257 |
'deletion ratio is abnormally high (%s), aborting unkwown users deletion', deletion_ratio |
|
258 |
) |
|
259 |
else: |
|
260 |
OIDCAccount.objects.filter(sub__in=unknown_uuids).delete() |
|
261 | ||
262 |
# update recently modified users |
|
263 |
url = self.issuer + '/api/users/?modified__gt=%s&claim_resolution' % ( |
|
264 |
self.last_sync_time or datetime.utcfromtimestamp(0) |
|
265 |
).strftime('%Y-%m-%dT%H:%M:%S') |
|
266 |
while url: |
|
267 |
resp = requests.get(url, auth=auth) |
|
268 |
resp.raise_for_status() |
|
269 |
url = resp.json().get('next') |
|
270 |
if verbose: |
|
271 |
logger.info('got %s users', len(resp.json()['results'])) |
|
272 |
for user_dict in resp.json()['results']: |
|
273 |
try: |
|
274 |
account = OIDCAccount.objects.get(user__email=user_dict['email']) |
|
275 |
except OIDCAccount.DoesNotExist: |
|
276 |
continue |
|
277 |
except OIDCAccount.MultipleObjectsReturned: |
|
278 |
continue |
|
279 |
had_changes = False |
|
280 |
for claim in self.claim_mappings.all(): |
|
281 |
if '{{' in claim.claim or '{%' in claim.claim: |
|
282 |
template = Template(claim.claim) |
|
283 |
attribute_value = template.render(context=user_dict) |
|
284 |
else: |
|
285 |
attribute_value = user_dict.get(claim.claim) |
|
286 |
try: |
|
287 |
old_attribute_value = getattr(account.user, claim.attribute) |
|
288 |
except AttributeError: |
|
289 |
try: |
|
290 |
old_attribute_value = getattr(account.user.attributes, claim.attribute) |
|
291 |
except AttributeError: |
|
292 |
old_attribute_value = None |
|
293 |
if old_attribute_value == attribute_value: |
|
294 |
continue |
|
295 |
had_changes = True |
|
296 |
setattr(account.user, claim.attribute, attribute_value) |
|
297 |
try: |
|
298 |
setattr(account.user.attributes, claim.attribute, attribute_value) |
|
299 |
except AttributeError: |
|
300 |
pass |
|
301 |
if had_changes: |
|
302 |
if verbose: |
|
303 |
logger.info('had changes, saving %r', account.user) |
|
304 |
account.user.save() |
|
305 |
self.last_sync_time = sync_time |
|
306 |
self.save(update_fields=['last_sync_time']) |
|
307 | ||
216 | 308 | |
217 | 309 |
class OIDCClaimMapping(models.Model): |
218 | 310 |
NOT_VERIFIED = 0 |
tests/test_commands.py | ||
---|---|---|
17 | 17 |
import datetime |
18 | 18 |
import importlib |
19 | 19 |
import json |
20 |
import random |
|
21 |
import uuid |
|
20 | 22 |
from io import BufferedReader, BufferedWriter, TextIOWrapper |
21 | 23 | |
24 |
import httmock |
|
22 | 25 |
import py |
23 | 26 |
import pytest |
24 | 27 |
import webtest |
25 | 28 |
from django.contrib.auth import get_user_model |
26 | 29 |
from django.contrib.contenttypes.models import ContentType |
30 |
from django.core.management import call_command |
|
27 | 31 |
from django.utils.timezone import now |
28 | 32 |
from jwcrypto.jwk import JWK, JWKSet |
29 | 33 | |
... | ... | |
32 | 36 |
from authentic2.apps.journal.models import Event |
33 | 37 |
from authentic2.custom_user.models import DeletedUser |
34 | 38 |
from authentic2.models import UserExternalId |
35 |
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider |
|
39 |
from authentic2.utils import crypto |
|
40 |
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider |
|
36 | 41 |
from django_rbac.models import ADMIN_OP, Operation |
37 | 42 |
from django_rbac.utils import get_operation |
38 | 43 | |
39 |
from .utils import call_command, login |
|
44 |
from .utils import call_command, check_log, login
|
|
40 | 45 | |
41 | 46 |
User = get_user_model() |
42 | 47 | |
... | ... | |
437 | 442 |
call_command('clean-user-exports') |
438 | 443 |
with pytest.raises(webtest.app.AppError): |
439 | 444 |
resp.click('Download CSV') |
445 | ||
446 | ||
447 |
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, settings, caplog, capsys, deletion_number, deletion_valid):
    """Run the oidc-sync-provider command against a mocked provider.

    The mocked provider reports ``deletion_number`` of the 100 local accounts
    as unknown (10 of them exceeds the 5% safety threshold, so nothing must be
    deleted) and returns 20 randomly chosen local users, modified.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    # the last two mappings are flagged as id token claims
    for attribute, claim, idtoken_claim in (
        ('username', 'username', False),
        ('email', 'email', False),
        ('last_name', 'family_name', True),
        ('first_name', 'given_name', True),
    ):
        OIDCClaimMapping.objects.create(
            provider=oidc_provider,
            attribute=attribute,
            idtoken_claim=idtoken_claim,
            claim=claim,
        )

    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        # deterministic sub derived from the user's uuid and the sector identifier
        sub = crypto.aes_base64url_deterministic_encrypt(
            settings.SECRET_KEY.encode('utf-8'),
            uuid.UUID(user.uuid).bytes,
            'some.provider',
        ).decode('utf-8')
        OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub)

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/synchronization/$', method='POST')
    def synchronization_post_deletion_response(url, request):
        headers = {'content-type': 'application/json'}
        unknown = random.sample(list(OIDCAccount.objects.all()), deletion_number)
        content = {'unknown_uuids': [account.sub for account in unknown]}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/$', method='GET')
    def synchronization_get_modified_response(url, request):
        headers = {'content-type': 'application/json'}
        # sample only the users created above: any other user (e.g. created by
        # a fixture) has no OIDC account and would make the counts below flaky
        candidates = list(User.objects.filter(email__endswith='@ad.dre.ss'))
        results = []
        for count, user in enumerate(random.sample(candidates, 20)):
            user_json = user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'
            results.append(user_json)
        content = {'results': results}
        return httmock.response(status_code=200, headers=headers, content=content, request=request)

    with httmock.HTTMock(synchronization_post_deletion_response, synchronization_get_modified_response):
        call_command('oidc-sync-provider', '-v1')
        out, _ = capsys.readouterr()
        assert 'no provider supporting synchronization' in out

        oidc_provider.a2_synchronization_supported = True
        oidc_provider.save()

        call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')
        out, _ = capsys.readouterr()
        assert 'no provider supporting synchronization' in out

        with check_log(caplog, 'got 20 users'):
            call_command('oidc-sync-provider', '-v1')

        if deletion_valid:
            # existing users check
            assert OIDCAccount.objects.count() == 100 - deletion_number
            # sampled users whose account has just been deleted are not updated
            min_updated = 20 - deletion_number
        else:
            assert OIDCAccount.objects.count() == 100
            assert caplog.records[0].levelname == 'ERROR'
            assert 'deletion ratio is abnormally high' in caplog.records[0].message
            # no account deleted: every sampled user is updated
            min_updated = 20

        # users update
        expected_updated = range(min_updated, 21)
        assert User.objects.filter(username__startswith='modified').count() in expected_updated
        assert User.objects.filter(first_name='Mod', last_name='Ified').count() in expected_updated
|
440 |
- |