Projet

Général

Profil

0003-tests-test-invalid-kid-in-id_token-39136.patch

Benjamin Dauvergne, 19 février 2020 02:27

Télécharger (16,2 ko)

Voir les différences:

Subject: [PATCH 3/3] tests: test invalid kid in id_token (#39136)

 tests/test_auth_oidc.py | 136 ++++++++++++++++++++++++++--------------
 tests/utils.py          |   7 ++-
 2 files changed, 92 insertions(+), 51 deletions(-)
tests/test_auth_oidc.py
55 55
        base64url_decode('x')
56 56
    base64url_decode('aa')
57 57

  
58
header_rsa_decoded = {'alg': 'RS256', 'kid': '1e9gdk7'}
58
KID = '1e9gdk7'
59
header_rsa_decoded = {'alg': 'RS256', 'kid': KID}
59 60
header_hmac_decoded = {'alg': 'HS256'}
60 61
payload_decoded = {
61 62
    'sub': '248289761001',
......
72 73
           'MFM2X1d6QTJNaiIsInN1YiI6IjI0ODI4OTc2MTAwMSJ9')
73 74

  
74 75

  
75
def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset, header,
76
                        signature):
76
def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset):
77
    header = _header(oidc_provider)
78
    signature = _signature(oidc_provider)
77 79
    with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
78 80
        with pytest.raises(InvalidJWSObject):
79 81
            parse_id_token('x%s.%s.%s' % (header, payload, signature), oidc_provider)
......
88 90
        assert parse_id_token('%s.%s.%s' % (header, payload, signature), oidc_provider)
89 91

  
90 92

  
91
def test_idtoken(oidc_provider, header, signature):
93
def test_idtoken(oidc_provider):
94
    signature = _signature(oidc_provider)
95
    header = _header(oidc_provider)
92 96
    token = IDToken('%s.%s.%s' % (header, payload, signature))
93 97
    token.deserialize(oidc_provider)
94 98
    assert token.sub == payload_decoded['sub']
......
101 105

  
102 106
@pytest.fixture
103 107
def oidc_provider_jwkset():
104
    key = JWK.generate(kty='RSA', size=512, kid='1e9gdk7')
108
    key = JWK.generate(kty='RSA', size=512, kid=KID)
105 109
    jwkset = JWKSet()
106 110
    jwkset.add(key)
107 111
    return jwkset
......
109 113
OIDC_PROVIDER_PARAMS = [
110 114
    {},
111 115
    {
112
        'idtoken_algo': OIDCProvider.ALGO_HMAC
116
        'idtoken_algo': OIDCProvider.ALGO_HMAC,
113 117
    },
114 118
    {
115 119
        'claims_parameter_supported': True,
......
119 123

  
120 124
@pytest.fixture(params=OIDC_PROVIDER_PARAMS)
121 125
def oidc_provider(request, db, oidc_provider_jwkset):
122
    idtoken_algo = request.param.get('idtoken_algo', OIDCProvider.ALGO_RSA)
123 126
    claims_parameter_supported = request.param.get('claims_parameter_supported', False)
127
    idtoken_algo = request.param.get('idtoken_algo', OIDCProvider.ALGO_RSA)
128

  
129
    return make_oidc_provider(
130
        idtoken_algo=idtoken_algo,
131
        jwkset=oidc_provider_jwkset,
132
        claims_parameter_supported=claims_parameter_supported)
133

  
134

  
135
@pytest.fixture
136
def oidc_provider_rsa(request, db, oidc_provider_jwkset):
137
    return make_oidc_provider(
138
        idtoken_algo=OIDCProvider.ALGO_RSA,
139
        jwkset=oidc_provider_jwkset)
140

  
141

  
142
def make_oidc_provider(
143
        idtoken_algo=OIDCProvider._meta.get_field('idtoken_algo').default,
144
        jwkset=None,
145
        claims_parameter_supported=False):
124 146
    from authentic2_auth_oidc.utils import get_provider, get_provider_by_issuer
125 147
    get_provider.cache.clear()
126 148
    get_provider_by_issuer.cache.clear()
127
    if idtoken_algo == OIDCProvider.ALGO_RSA:
128
        jwkset = json.loads(oidc_provider_jwkset.export())
129
    else:
130
        jwkset = None
149

  
150
    if jwkset is not None:
151
        jwkset = json.loads(jwkset.export())
152

  
131 153
    provider = OIDCProvider.objects.create(
132 154
        ou=get_default_ou(),
133 155
        name='OIDIDP',
......
182 204
    return 'xxxx'
183 205

  
184 206

  
185
@pytest.fixture
186
def header(oidc_provider):
207
def _header(oidc_provider):
187 208
    if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
188 209
        return header_rsa
189 210
    elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC:
190 211
        return header_hmac
191 212

  
192 213

  
193
@pytest.fixture
194
def signature(oidc_provider):
214
def _signature(oidc_provider):
195 215
    if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
196
        key = oidc_provider.jwkset.get_key(kid='1e9gdk7')
216
        key = oidc_provider.jwkset.get_key(kid=KID)
197 217
        header_decoded = header_rsa_decoded
198 218
    elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC:
199 219
        key = JWK(kty='oct',
......
206 226

  
207 227

  
208 228
def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token=None,
209
                       extra_user_info=None, sub='john.doe', nonce=None):
229
                       extra_user_info=None, sub='john.doe', nonce=None, kid=KID):
210 230
    token_endpoint = urlparse.urlparse(oidc_provider.token_endpoint)
211 231
    userinfo_endpoint = urlparse.urlparse(oidc_provider.userinfo_endpoint)
212 232
    token_revocation_endpoint = urlparse.urlparse(oidc_provider.token_revocation_endpoint)
......
227 247
                id_token.update(extra_id_token)
228 248

  
229 249
            if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA:
230
                jwt = JWT(header={'alg': 'RS256', 'kid': '1e9gdk7'},
250
                jwt = JWT(header={'alg': 'RS256', 'kid': kid},
231 251
                          claims=id_token)
232 252
                jwt.make_signed_token(list(oidc_provider_jwkset['keys'])[0])
233 253
            else:
......
289 309
    return HTTMock(token_endpoint_mock, user_info_endpoint_mock, token_revocation_endpoint_mock)
290 310

  
291 311

  
292
@pytest.fixture
293
def login_url(oidc_provider):
294
    return reverse('oidc-login', kwargs={'pk': oidc_provider.pk})
295

  
296

  
297
@pytest.fixture
298 312
def login_callback_url(oidc_provider):
299 313
    return reverse('oidc-login-callback')
300 314

  
......
333 347
    assert response.pyquery('p#oidc-p-oidcidp-2')
334 348

  
335 349

  
336

  
337

  
338
def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, login_callback_url, hooks):
350
def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, hooks):
339 351
    OU = get_ou_model()
340 352
    cassis = OU.objects.create(name='Cassis', slug='cassis')
341 353
    OU.cached.cache.clear()
......
370 382

  
371 383
    with utils.check_log(caplog, 'failed to contact the token_endpoint'):
372 384
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
373
            response = app.get(login_callback_url, params={'code': 'yyyy', 'state': query['state']})
374
    with utils.check_log(caplog, 'invalid id_token %r'):
385
            response = app.get(login_callback_url(oidc_provider), params={'code': 'yyyy', 'state': query['state']})
386
    with utils.check_log(caplog, 'invalid id_token'):
375 387
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
376 388
                                extra_id_token={'iss': None}):
377
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
378
    with utils.check_log(caplog, 'invalid id_token %r'):
389
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
390
    with utils.check_log(caplog, 'invalid id_token'):
379 391
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
380 392
                                extra_id_token={'sub': None}):
381
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
393
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
382 394
    with utils.check_log(caplog, 'authentication is too old'):
383 395
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
384 396
                                extra_id_token={'iat': 1}):
385
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
386
    with utils.check_log(caplog, 'invalid id_token %r'):
397
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
398
    with utils.check_log(caplog, 'invalid id_token'):
387 399
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
388 400
                                extra_id_token={'exp': 1}):
389
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
401
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
390 402
    with utils.check_log(caplog, 'invalid id_token audience'):
391 403
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
392 404
                                extra_id_token={'aud': 'zz'}):
393
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
405
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
394 406
    with utils.check_log(caplog, 'expected nonce'):
395 407
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
396
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
408
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
397 409
    assert not hooks.auth_oidc_backend_modify_user
398 410
    with utils.check_log(caplog, 'created user'):
399 411
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce):
400
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
412
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
401 413
    assert len(hooks.auth_oidc_backend_modify_user) == 1
402 414
    assert set(hooks.auth_oidc_backend_modify_user[0]['kwargs']) >= set(
403 415
        ['user', 'provider', 'user_info', 'id_token', 'access_token'])
......
417 429

  
418 430
    with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
419 431
                            extra_user_info={'family_name_verified': True}, nonce=nonce):
420
        response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
432
        response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
421 433
    assert AttributeValue.objects.filter(content='Doe', verified=False).count() == 0
422 434
    assert AttributeValue.objects.filter(content='Doe', verified=True).count() == 1
423 435

  
424 436
    with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code,
425 437
                            extra_user_info={'ou': 'cassis'}, nonce=nonce):
426
        response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
438
        response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
427 439
    assert User.objects.count() == 1
428 440
    user = User.objects.get()
429 441
    assert user.ou == cassis
430 442

  
431 443
    with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce):
432
        response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
444
        response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
433 445
    assert User.objects.count() == 1
434 446
    user = User.objects.get()
435 447
    assert user.ou == get_default_ou()
......
438 450
    time.sleep(0.1)
439 451

  
440 452
    with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code):
441
        response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
453
        response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
442 454
    assert User.objects.count() == 1
443 455
    user = User.objects.get()
444 456
    assert user.ou == get_default_ou()
......
469 481
    assert 'oidc-a-oididp' not in response.text
470 482

  
471 483

  
472
def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url,
473
                            login_callback_url, simple_user):
484
def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwkset, simple_user):
474 485

  
475 486
    get_providers.cache.clear()
476 487
    has_providers.cache.clear()
......
492 503
    # sub=john.doe, MUST not work
493 504
    with utils.check_log(caplog, 'cannot create user'):
494 505
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce):
495
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
506
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
496 507

  
497 508
    # sub=simple_user.uuid MUST work
498 509
    with utils.check_log(caplog, 'found user using UUID'):
499 510
        with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, sub=simple_user.uuid, nonce=nonce):
500
            response = app.get(login_callback_url, params={'code': code, 'state': query['state']})
511
            response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']})
501 512

  
502 513
    assert urlparse.urlparse(response['Location']).path == '/'
503 514
    assert User.objects.count() == 1
......
542 553
            openid_configuration=oidc_conf)
543 554

  
544 555

  
545
def test_required_keys(db, oidc_provider, header, signature, caplog):
556
def test_required_keys(db, oidc_provider, caplog):
546 557
    erroneous_payload = base64url_encode(json.dumps({
547 558
        'sub': '248289761001',
548 559
        'iss': 'http://server.example.com',
......
553 564

  
554 565
    with pytest.raises(IDTokenError):
555 566
        with utils.check_log(caplog, 'missing field'):
556
            token = IDToken('{}.{}.{}'.format(header, erroneous_payload, signature))
567
            token = IDToken('{}.{}.{}'.format(_header(oidc_provider), erroneous_payload, _signature(oidc_provider)))
557 568
            token.deserialize(oidc_provider)
569

  
570

  
571
def test_invalid_kid(app, caplog, code, oidc_provider_rsa,
572
                     oidc_provider_jwkset, simple_user):
573

  
574
    get_providers.cache.clear()
575
    has_providers.cache.clear()
576
    # no mapping please
577
    OIDCClaimMapping.objects.all().delete()
578

  
579
    User = get_user_model()
580
    assert User.objects.count() == 1
581

  
582
    response = app.get('/').maybe_follow()
583
    assert oidc_provider_rsa.name in response.text
584
    response = response.click(oidc_provider_rsa.name)
585
    location = urlparse.urlparse(response.location)
586
    query = check_simple_qs(urlparse.parse_qs(location.query))
587
    nonce = app.session['auth_oidc'][query['state']]['request']['nonce']
588

  
589
    # test invalid kid
590
    with utils.check_log(caplog, message='not in key set', levelname='WARNING'):
591
        with oidc_provider_mock(oidc_provider_rsa, oidc_provider_jwkset, code, nonce=nonce, kid='coin'):
592
            response = app.get(login_callback_url(oidc_provider_rsa), params={'code': code, 'state': query['state']})
593

  
594
    # test missing kid
595
    with utils.check_log(caplog, message='Missing Key ID', levelname='WARNING'):
596
        with oidc_provider_mock(oidc_provider_rsa, oidc_provider_jwkset, code, nonce=nonce, kid=None):
597
            response = app.get(login_callback_url(oidc_provider_rsa), params={'code': code, 'state': query['state']})
tests/utils.py
154 154

  
155 155

  
156 156
@contextmanager
157
def check_log(caplog, msg):
157
def check_log(caplog, message, levelname=None):
158 158
    idx = len(caplog.records)
159 159
    yield
160
    assert any(msg in record.msg for record in caplog.records[idx:]), \
161
        '%r not found in log records' % msg
160
    assert any(message in record.message for record in caplog.records[idx:]
161
               if not levelname or record.levelname == levelname), \
162
        '%r not found in log records' % message
162 163

  
163 164

  
164 165
def can_resolve_dns():
165
-