0002-auth_oidc-add-an-oidc-sync-provider-command-62710.patch
src/authentic2_auth_oidc/management/commands/oidc-sync-provider.py | ||
---|---|---|
1 |
# authentic2 - versatile identity manager |
|
2 |
# Copyright (C) 2010-2022 Entr'ouvert |
|
3 |
# |
|
4 |
# This program is free software: you can redistribute it and/or modify it |
|
5 |
# under the terms of the GNU Affero General Public License as published |
|
6 |
# by the Free Software Foundation, either version 3 of the License, or |
|
7 |
# (at your option) any later version. |
|
8 |
# |
|
9 |
# This program is distributed in the hope that it will be useful, |
|
10 |
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
11 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
12 |
# GNU Affero General Public License for more details. |
|
13 |
# |
|
14 |
# You should have received a copy of the GNU Affero General Public License |
|
15 |
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
16 | ||
17 |
import logging |
|
18 | ||
19 |
from authentic2.base_commands import LogToConsoleCommand |
|
20 |
from authentic2_auth_oidc.models import OIDCProvider |
|
21 | ||
22 | ||
23 |
class Command(LogToConsoleCommand):
    """Run the account synchronization of OIDC providers.

    Only providers flagged with ``a2_synchronization_supported`` are
    considered; ``--provider`` further restricts the run to a single slug.
    """

    help = 'Synchronize local OIDC accounts with the providers supporting synchronization'
    loggername = 'authentic2_auth_oidc.models'

    def add_arguments(self, parser):
        """Declare the optional ``--provider`` slug filter."""
        parser.add_argument(
            '--provider',
            type=str,
            default=None,
            help='slug of the only provider to synchronize (default: all providers supporting synchronization)',
        )

    def core_command(self, *args, **kwargs):
        """Synchronize every selected provider, logging which ones were found."""
        provider_slug = kwargs['provider']

        logger = logging.getLogger(self.loggername)
        queryset = OIDCProvider.objects.filter(a2_synchronization_supported=True)
        if provider_slug:
            queryset = queryset.filter(slug=provider_slug)

        # Evaluate the queryset once: the same rows are used for the emptiness
        # check, the log line and the synchronization loop.
        providers = list(queryset)
        if not providers:
            logger.error('no provider supporting synchronization found, exiting')
            return
        logger.info(
            'got %s provider(s): %s',
            len(providers),
            ' '.join(candidate.slug for candidate in providers),
        )
        for candidate in providers:
            candidate.perform_synchronization()
src/authentic2_auth_oidc/migrations/0013_synchronization_fields.py | ||
---|---|---|
1 |
# Generated by Django 2.2.26 on 2022-08-03 09:30 |
|
2 | ||
3 |
from django.db import migrations, models |
|
4 | ||
5 | ||
6 |
class Migration(migrations.Migration):
    """Add the synchronization flag and last-synchronization timestamp to ``oidcprovider``."""

    # NOTE(review): this file is numbered 0013 but depends on migration 0016;
    # confirm the file name/number is the intended one before merging.
    dependencies = [('authentic2_auth_oidc', '0016_auto_20221019_1148')]

    operations = [
        # Opt-in flag, off by default for existing providers.
        migrations.AddField(
            model_name='oidcprovider',
            name='a2_synchronization_supported',
            field=models.BooleanField(
                verbose_name='Authentic2 synchronization supported',
                default=False,
            ),
        ),
        # Nullable timestamp: empty until a first synchronization is recorded.
        migrations.AddField(
            model_name='oidcprovider',
            name='last_sync_time',
            field=models.DateTimeField(
                verbose_name='Last synchronization time',
                null=True,
                blank=True,
            ),
        ),
    ]
src/authentic2_auth_oidc/models.py | ||
---|---|---|
15 | 15 |
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
16 | 16 | |
17 | 17 |
import json |
18 |
import logging |
|
18 | 19 |
import uuid |
20 |
from datetime import datetime, timedelta |
|
19 | 21 | |
22 |
import requests |
|
20 | 23 |
from django.conf import settings |
21 | 24 |
from django.contrib.postgres.fields import JSONField |
22 | 25 |
from django.core.exceptions import ValidationError |
23 | 26 |
from django.db import models |
24 | 27 |
from django.shortcuts import render |
28 |
from django.utils.timezone import now |
|
25 | 29 |
from django.utils.translation import gettext_lazy as _ |
26 | 30 |
from django.utils.translation import pgettext_lazy |
27 | 31 |
from jwcrypto.jwk import InvalidJWKValue, JWKSet |
... | ... | |
33 | 37 |
BaseAuthenticator, |
34 | 38 |
) |
35 | 39 |
from authentic2.utils.misc import make_url, redirect_to_login |
36 |
from authentic2.utils.template import validate_template |
|
40 |
from authentic2.utils.template import Template, validate_template
|
|
37 | 41 | |
38 | 42 |
from . import managers |
39 | 43 | |
... | ... | |
116 | 120 |
verbose_name=_('max authentication age'), blank=True, null=True |
117 | 121 |
) |
118 | 122 | |
123 |
# authentic2-specific synchronization API

# Opt-in flag, disabled by default.
a2_synchronization_supported = models.BooleanField(
    default=False,
    verbose_name=_('Authentic2 synchronization supported'),
)
# Optional timestamp; may be left empty (NULL).
last_sync_time = models.DateTimeField(
    blank=True,
    null=True,
    verbose_name=_('Last synchronization time'),
)
|
133 | ||
119 | 134 |
# metadata |
120 | 135 |
created = models.DateTimeField(verbose_name=_('creation date'), auto_now_add=True) |
121 | 136 |
modified = models.DateTimeField(verbose_name=_('last modification date'), auto_now=True) |
... | ... | |
242 | 257 |
] |
243 | 258 |
return render(request, template_names, context) |
244 | 259 | |
260 |
def perform_synchronization(self, sync_time=None, timeout=30):
    """Synchronize this provider's local accounts with its authentic2 API.

    Two passes are made, authenticated with the client id/secret:

    1. the ``sub`` of every local account of this provider is posted to
       ``/api/users/synchronization/``; accounts reported as unknown are
       deleted unless they exceed 5% of this provider's accounts;
    2. users modified since ``last_sync_time`` are fetched from
       ``/api/users/`` (following ``next`` links) and the claim mappings
       are applied again to the matching local users.

    On success ``last_sync_time`` is set to ``sync_time`` and saved.

    :param sync_time: value recorded as the new ``last_sync_time``;
        defaults to one minute before the current time.
    :param timeout: timeout, in seconds, of each HTTP request.
    :raises requests.RequestException: when a request fails or returns an
        error status; ``last_sync_time`` is not updated in that case.
    """
    logger = logging.getLogger(__name__)

    if not self.a2_synchronization_supported:
        logger.error('OIDC provider %s does not support synchronization', self.slug)
        return
    if not sync_time:
        # presumably a safety margin so that changes made while the
        # synchronization runs are picked up by the next one -- TODO confirm
        sync_time = now() - timedelta(minutes=1)

    auth = (self.client_id, self.client_secret)
    self._sync_delete_unknown_accounts(auth, timeout)
    self._sync_update_modified_users(auth, timeout)

    self.last_sync_time = sync_time
    self.save(update_fields=['last_sync_time'])

def _sync_delete_unknown_accounts(self, auth, timeout, chunk_size=100, max_deletion_ratio=0.05):
    """Delete this provider's accounts that the provider no longer knows."""
    logger = logging.getLogger(__name__)
    accounts = OIDCAccount.objects.filter(provider=self)
    # Only the subject identifiers are needed, not full account rows.
    subs = list(accounts.values_list('sub', flat=True))
    if not subs:
        # Nothing to check; also avoids dividing by zero below.
        return

    url = self.issuer + '/api/users/synchronization/'
    unknown_uuids = []
    for start in range(0, len(subs), chunk_size):
        resp = requests.post(
            url, json={'known_uuids': subs[start : start + chunk_size]}, auth=auth, timeout=timeout
        )
        resp.raise_for_status()
        # A missing or null 'unknown_uuids' key means nothing to delete.
        unknown_uuids.extend(resp.json().get('unknown_uuids') or [])

    deletion_ratio = len(unknown_uuids) / len(subs)
    if deletion_ratio > max_deletion_ratio:  # something definitely went wrong
        logger.error(
            'deletion ratio is abnormally high (%s), aborting unkwown users deletion', deletion_ratio
        )
    elif unknown_uuids:
        # Restrict the deletion to this provider: another provider may
        # legitimately hold an account with the same sub.
        accounts.filter(sub__in=unknown_uuids).delete()

def _sync_update_modified_users(self, auth, timeout):
    """Apply the claim mappings again to users modified since the last synchronization."""
    logger = logging.getLogger(__name__)
    since = self.last_sync_time or datetime.utcfromtimestamp(0)
    url = self.issuer + '/api/users/?modified__gt=%s&claim_resolution' % since.strftime(
        '%Y-%m-%dT%H:%M:%S'
    )
    # Mappings do not change during the run; load them once, not per user.
    claim_mappings = list(self.claim_mappings.all())

    while url:
        resp = requests.get(url, auth=auth, timeout=timeout)
        resp.raise_for_status()
        payload = resp.json()
        url = payload.get('next')
        results = payload['results']
        logger.info('got %s users', len(results))
        for user_dict in results:
            email = user_dict.get('email')
            if not email:
                # Without an email there is no reliable way to find the local user.
                continue
            try:
                account = OIDCAccount.objects.select_related('user').get(
                    provider=self, user__email=email
                )
            except (OIDCAccount.DoesNotExist, OIDCAccount.MultipleObjectsReturned):
                # Unknown or ambiguous user: leave it untouched.
                continue
            if self._sync_apply_claims(account.user, user_dict, claim_mappings):
                logger.debug('had changes, saving %r', account.user)
                account.user.save()

def _sync_apply_claims(self, user, user_dict, claim_mappings):
    """Set mapped attributes on ``user`` from ``user_dict``; return True if any value changed."""
    had_changes = False
    for claim in claim_mappings:
        if '{{' in claim.claim or '{%' in claim.claim:
            # The claim is a template rendered against the whole payload.
            attribute_value = Template(claim.claim).render(context=user_dict)
        else:
            # NOTE(review): a claim absent from the payload yields None and
            # overwrites the local value -- confirm this is intended.
            attribute_value = user_dict.get(claim.claim)

        # Look for the current value on the user first, then on user.attributes.
        try:
            old_attribute_value = getattr(user, claim.attribute)
        except AttributeError:
            try:
                old_attribute_value = getattr(user.attributes, claim.attribute)
            except AttributeError:
                old_attribute_value = None
        if old_attribute_value == attribute_value:
            continue

        had_changes = True
        setattr(user, claim.attribute, attribute_value)
        try:
            setattr(user.attributes, claim.attribute, attribute_value)
        except AttributeError:
            # Best effort: user.attributes rejected this attribute name.
            pass
    return had_changes
|
334 | ||
245 | 335 | |
246 | 336 |
class OIDCClaimMapping(AuthenticatorRelatedObjectBase): |
247 | 337 |
NOT_VERIFIED = 0 |
tests/test_commands.py | ||
---|---|---|
17 | 17 |
import datetime |
18 | 18 |
import importlib |
19 | 19 |
import json |
20 |
import random |
|
21 |
import uuid |
|
20 | 22 |
from io import BufferedReader, BufferedWriter, TextIOWrapper |
21 | 23 | |
24 |
import httmock |
|
22 | 25 |
import py |
23 | 26 |
import pytest |
24 | 27 |
import webtest |
... | ... | |
40 | 43 |
from authentic2.apps.journal.models import Event |
41 | 44 |
from authentic2.custom_user.models import DeletedUser |
42 | 45 |
from authentic2.models import UserExternalId |
43 |
from authentic2_auth_oidc.models import OIDCAccount, OIDCProvider |
|
46 |
from authentic2.utils import crypto |
|
47 |
from authentic2_auth_oidc.models import OIDCAccount, OIDCClaimMapping, OIDCProvider |
|
44 | 48 |
from django_rbac.utils import get_operation |
45 | 49 | |
46 |
from .utils import call_command, login |
|
50 |
from .utils import call_command, check_log, login
|
|
47 | 51 | |
48 | 52 |
User = get_user_model() |
49 | 53 | |
... | ... | |
456 | 460 |
call_command('clean-user-exports') |
457 | 461 |
with pytest.raises(webtest.app.AppError): |
458 | 462 |
resp.click('Download CSV') |
463 | ||
464 | ||
465 |
@pytest.mark.parametrize('deletion_number,deletion_valid', [(2, True), (5, True), (10, False)])
def test_oidc_sync_provider(db, app, admin, settings, caplog, deletion_number, deletion_valid):
    """Check the oidc-sync-provider command: provider selection, deletion and update passes.

    100 accounts are created; the mocked provider reports ``deletion_number``
    of them as unknown and 20 other users as modified. Deletion must happen
    up to a 5% ratio and be refused above it; updates must always be applied.
    """
    oidc_provider = OIDCProvider.objects.create(
        issuer='https://some.provider',
        name='Some Provider',
        slug='some-provider',
        ou=get_default_ou(),
    )
    # Plain and id-token claim mappings: the synchronization iterates over all
    # of the provider's mappings, so both kinds are exercised.
    for attribute, claim, idtoken_claim in (
        ('username', 'username', False),
        ('email', 'email', False),
        ('last_name', 'family_name', True),
        ('first_name', 'given_name', True),
    ):
        OIDCClaimMapping.objects.create(
            authenticator=oidc_provider,
            attribute=attribute,
            idtoken_claim=idtoken_claim,
            claim=claim,
        )

    accounts = []
    for i in range(100):
        user = User.objects.create(
            first_name='John%s' % i,
            last_name='Doe%s' % i,
            username='john.doe.%s' % i,
            email='john.doe.%s@ad.dre.ss' % i,
            ou=get_default_ou(),
        )
        sub = crypto.aes_base64url_deterministic_encrypt(
            settings.SECRET_KEY.encode('utf-8'),
            uuid.UUID(user.uuid).bytes,
            'some.provider',
        ).decode('utf-8')
        accounts.append(OIDCAccount.objects.create(user=user, provider=oidc_provider, sub=sub))

    # Draw the unknown accounts and the modified users from disjoint pools so
    # that the expected counts are exact whatever the random draw.
    unknown_subs = {account.sub for account in random.sample(accounts, deletion_number)}
    surviving_user_pks = [account.user.pk for account in accounts if account.sub not in unknown_subs]
    modified_user_pks = random.sample(surviving_user_pks, 20)

    json_headers = {'content-type': 'application/json'}

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/synchronization/$', method='POST')
    def synchronization_post_deletion_response(url, request):
        content = {'unknown_uuids': sorted(unknown_subs)}
        return httmock.response(status_code=200, headers=json_headers, content=content, request=request)

    @httmock.urlmatch(netloc=r'some\.provider', path=r'^/api/users/$', method='GET')
    def synchronization_get_modified_response(url, request):
        results = []
        for count, user in enumerate(User.objects.filter(pk__in=modified_user_pks)):
            user_json = user.to_json()
            user_json['username'] = f'modified_{count}'
            user_json['first_name'] = 'Mod'
            user_json['last_name'] = 'Ified'
            # mocking claim resolution by oidc provider
            user_json['given_name'] = 'Mod'
            user_json['family_name'] = 'Ified'
            results.append(user_json)
        content = {'results': results}
        return httmock.response(status_code=200, headers=json_headers, content=content, request=request)

    with httmock.HTTMock(synchronization_post_deletion_response, synchronization_get_modified_response):
        # no provider is flagged as supporting synchronization yet
        with check_log(caplog, 'no provider supporting synchronization'):
            call_command('oidc-sync-provider', '-v1')

        oidc_provider.a2_synchronization_supported = True
        oidc_provider.save()

        # unknown slug
        with check_log(caplog, 'no provider supporting synchronization'):
            call_command('oidc-sync-provider', '--provider', 'whatever', '-v1')

        with check_log(caplog, 'got 20 users'):
            call_command('oidc-sync-provider', '-v1')

    # existing users check
    ratio_errors = [
        record
        for record in caplog.records
        if record.levelname == 'ERROR' and 'deletion ratio is abnormally high' in record.message
    ]
    if deletion_valid:
        assert not ratio_errors
        assert OIDCAccount.objects.count() == 100 - deletion_number
        assert not OIDCAccount.objects.filter(sub__in=unknown_subs).exists()
    else:
        assert len(ratio_errors) == 1
        assert OIDCAccount.objects.count() == 100

    # users update
    assert User.objects.filter(username__startswith='modified').count() == 20
    assert User.objects.filter(first_name='Mod', last_name='Ified').count() == 20

    # the run is recorded on the provider
    oidc_provider.refresh_from_db()
    assert oidc_provider.last_sync_time is not None
|
459 |
- |