From efba1643959a7513cbb24214d1d94bfe200e13a3 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Tue, 21 Jan 2020 13:44:08 +0100 Subject: [PATCH 3/3] tests: test invalid kid in id_token (#39136) --- tests/test_auth_oidc.py | 136 ++++++++++++++++++++++++++-------------- tests/utils.py | 7 ++- 2 files changed, 92 insertions(+), 51 deletions(-) diff --git a/tests/test_auth_oidc.py b/tests/test_auth_oidc.py index c68d86d4..5f62b642 100644 --- a/tests/test_auth_oidc.py +++ b/tests/test_auth_oidc.py @@ -55,7 +55,8 @@ def test_base64url_decode(): base64url_decode('x') base64url_decode('aa') -header_rsa_decoded = {'alg': 'RS256', 'kid': '1e9gdk7'} +KID = '1e9gdk7' +header_rsa_decoded = {'alg': 'RS256', 'kid': KID} header_hmac_decoded = {'alg': 'HS256'} payload_decoded = { 'sub': '248289761001', @@ -72,8 +73,9 @@ payload = ('eyJhdWQiOiJzNkJoZFJrcXQzIiwiZXhwIjoyMjAxMDk0Mjc4LCJpYXQiOjEzMTEyOD' 'MFM2X1d6QTJNaiIsInN1YiI6IjI0ODI4OTc2MTAwMSJ9') -def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset, header, - signature): +def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset): + header = _header(oidc_provider) + signature = _signature(oidc_provider) with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code): with pytest.raises(InvalidJWSObject): parse_id_token('x%s.%s.%s' % (header, payload, signature), oidc_provider) @@ -88,7 +90,9 @@ def test_parse_id_token(code, oidc_provider, oidc_provider_jwkset, header, assert parse_id_token('%s.%s.%s' % (header, payload, signature), oidc_provider) -def test_idtoken(oidc_provider, header, signature): +def test_idtoken(oidc_provider): + signature = _signature(oidc_provider) + header = _header(oidc_provider) token = IDToken('%s.%s.%s' % (header, payload, signature)) token.deserialize(oidc_provider) assert token.sub == payload_decoded['sub'] @@ -101,7 +105,7 @@ def test_idtoken(oidc_provider, header, signature): @pytest.fixture def oidc_provider_jwkset(): - key = JWK.generate(kty='RSA', size=512, kid='1e9gdk7') + key = JWK.generate(kty='RSA', size=512, kid=KID) jwkset = JWKSet() jwkset.add(key) return jwkset @@ -109,7 +113,7 @@ def oidc_provider_jwkset(): OIDC_PROVIDER_PARAMS = [ {}, { - 'idtoken_algo': OIDCProvider.ALGO_HMAC + 'idtoken_algo': OIDCProvider.ALGO_HMAC, }, { 'claims_parameter_supported': True, @@ -119,15 +123,33 @@ OIDC_PROVIDER_PARAMS = [ @pytest.fixture(params=OIDC_PROVIDER_PARAMS) def oidc_provider(request, db, oidc_provider_jwkset): - idtoken_algo = request.param.get('idtoken_algo', OIDCProvider.ALGO_RSA) claims_parameter_supported = request.param.get('claims_parameter_supported', False) + idtoken_algo = request.param.get('idtoken_algo', OIDCProvider.ALGO_RSA) + + return make_oidc_provider( + idtoken_algo=idtoken_algo, + jwkset=oidc_provider_jwkset, + claims_parameter_supported=claims_parameter_supported) + + +@pytest.fixture +def oidc_provider_rsa(request, db, oidc_provider_jwkset): + return make_oidc_provider( + idtoken_algo=OIDCProvider.ALGO_RSA, + jwkset=oidc_provider_jwkset) + + +def make_oidc_provider( + idtoken_algo=OIDCProvider._meta.get_field('idtoken_algo').default, + jwkset=None, + claims_parameter_supported=False): from authentic2_auth_oidc.utils import get_provider, get_provider_by_issuer get_provider.cache.clear() get_provider_by_issuer.cache.clear() - if idtoken_algo == OIDCProvider.ALGO_RSA: - jwkset = json.loads(oidc_provider_jwkset.export()) - else: - jwkset = None + + if jwkset is not None: + jwkset = json.loads(jwkset.export()) + provider = OIDCProvider.objects.create( ou=get_default_ou(), name='OIDIDP', @@ -182,18 +204,16 @@ def code(): return 'xxxx' -@pytest.fixture -def header(oidc_provider): +def _header(oidc_provider): if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA: return header_rsa elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC: return header_hmac -@pytest.fixture -def signature(oidc_provider): +def _signature(oidc_provider): if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA: - key = oidc_provider.jwkset.get_key(kid='1e9gdk7') + key = oidc_provider.jwkset.get_key(kid=KID) header_decoded = header_rsa_decoded elif oidc_provider.idtoken_algo == OIDCProvider.ALGO_HMAC: key = JWK(kty='oct', @@ -206,7 +226,7 @@ def signature(oidc_provider): def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token=None, - extra_user_info=None, sub='john.doe', nonce=None): + extra_user_info=None, sub='john.doe', nonce=None, kid=KID): token_endpoint = urlparse.urlparse(oidc_provider.token_endpoint) userinfo_endpoint = urlparse.urlparse(oidc_provider.userinfo_endpoint) token_revocation_endpoint = urlparse.urlparse(oidc_provider.token_revocation_endpoint) @@ -227,7 +247,7 @@ def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token id_token.update(extra_id_token) if oidc_provider.idtoken_algo == OIDCProvider.ALGO_RSA: - jwt = JWT(header={'alg': 'RS256', 'kid': '1e9gdk7'}, + jwt = JWT(header={'alg': 'RS256', 'kid': kid}, claims=id_token) jwt.make_signed_token(list(oidc_provider_jwkset['keys'])[0]) else: @@ -289,12 +309,6 @@ def oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token return HTTMock(token_endpoint_mock, user_info_endpoint_mock, token_revocation_endpoint_mock) -@pytest.fixture -def login_url(oidc_provider): - return reverse('oidc-login', kwargs={'pk': oidc_provider.pk}) - - -@pytest.fixture def login_callback_url(oidc_provider): return reverse('oidc-login-callback') @@ -333,9 +347,7 @@ def test_providers_on_login_page(oidc_provider, app): assert response.pyquery('p#oidc-p-oidcidp-2') - - -def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, login_callback_url, hooks): +def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, hooks): OU = get_ou_model() cassis = OU.objects.create(name='Cassis', slug='cassis') OU.cached.cache.clear() @@ -370,34 +382,34 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, with utils.check_log(caplog, 'failed to contact the token_endpoint'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code): - response = app.get(login_callback_url, params={'code': 'yyyy', 'state': query['state']}) - with utils.check_log(caplog, 'invalid id_token %r'): + response = app.get(login_callback_url(oidc_provider), params={'code': 'yyyy', 'state': query['state']}) + with utils.check_log(caplog, 'invalid id_token'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token={'iss': None}): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) - with utils.check_log(caplog, 'invalid id_token %r'): + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) + with utils.check_log(caplog, 'invalid id_token'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token={'sub': None}): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) with utils.check_log(caplog, 'authentication is too old'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token={'iat': 1}): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) - with utils.check_log(caplog, 'invalid id_token %r'): + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) + with utils.check_log(caplog, 'invalid id_token'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token={'exp': 1}): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) with utils.check_log(caplog, 'invalid id_token audience'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_id_token={'aud': 'zz'}): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) with utils.check_log(caplog, 'expected nonce'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) assert not hooks.auth_oidc_backend_modify_user with utils.check_log(caplog, 'created user'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) assert len(hooks.auth_oidc_backend_modify_user) == 1 assert set(hooks.auth_oidc_backend_modify_user[0]['kwargs']) >= set( ['user', 'provider', 'user_info', 'id_token', 'access_token']) @@ -417,19 +429,19 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_user_info={'family_name_verified': True}, nonce=nonce): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) assert AttributeValue.objects.filter(content='Doe', verified=False).count() == 0 assert AttributeValue.objects.filter(content='Doe', verified=True).count() == 1 with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, extra_user_info={'ou': 'cassis'}, nonce=nonce): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) assert User.objects.count() == 1 user = User.objects.get() assert user.ou == cassis with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) assert User.objects.count() == 1 user = User.objects.get() assert user.ou == get_default_ou() @@ -438,7 +450,7 @@ def test_sso(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, time.sleep(0.1) with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) assert User.objects.count() == 1 user = User.objects.get() assert user.ou == get_default_ou() @@ -469,8 +481,7 @@ def test_show_on_login_page(app, oidc_provider): assert 'oidc-a-oididp' not in response.text -def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwkset, login_url, - login_callback_url, simple_user): +def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwkset, simple_user): get_providers.cache.clear() has_providers.cache.clear() @@ -492,12 +503,12 @@ def test_strategy_find_uuid(app, caplog, code, oidc_provider, oidc_provider_jwks # sub=john.doe, MUST not work with utils.check_log(caplog, 'cannot create user'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, nonce=nonce): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) # sub=simple_user.uuid MUST work with utils.check_log(caplog, 'found user using UUID'): with oidc_provider_mock(oidc_provider, oidc_provider_jwkset, code, sub=simple_user.uuid, nonce=nonce): - response = app.get(login_callback_url, params={'code': code, 'state': query['state']}) + response = app.get(login_callback_url(oidc_provider), params={'code': code, 'state': query['state']}) assert urlparse.urlparse(response['Location']).path == '/' assert User.objects.count() == 1 @@ -542,7 +553,7 @@ def test_register_issuer(db, app, caplog, oidc_provider_jwkset): openid_configuration=oidc_conf) -def test_required_keys(db, oidc_provider, header, signature, caplog): +def test_required_keys(db, oidc_provider, caplog): erroneous_payload = base64url_encode(json.dumps({ 'sub': '248289761001', 'iss': 'http://server.example.com', @@ -553,5 +564,34 @@ def test_required_keys(db, oidc_provider, header, signature, caplog): with pytest.raises(IDTokenError): with utils.check_log(caplog, 'missing field'): - token = IDToken('{}.{}.{}'.format(header, erroneous_payload, signature)) + token = IDToken('{}.{}.{}'.format(_header(oidc_provider), erroneous_payload, _signature(oidc_provider))) token.deserialize(oidc_provider) + + +def test_invalid_kid(app, caplog, code, oidc_provider_rsa, + oidc_provider_jwkset, simple_user): + + get_providers.cache.clear() + has_providers.cache.clear() + # no mapping please + OIDCClaimMapping.objects.all().delete() + + User = get_user_model() + assert User.objects.count() == 1 + + response = app.get('/').maybe_follow() + assert oidc_provider_rsa.name in response.text + response = response.click(oidc_provider_rsa.name) + location = urlparse.urlparse(response.location) + query = check_simple_qs(urlparse.parse_qs(location.query)) + nonce = app.session['auth_oidc'][query['state']]['request']['nonce'] + + # test invalid kid + with utils.check_log(caplog, message='not in key set', levelname='WARNING'): + with oidc_provider_mock(oidc_provider_rsa, oidc_provider_jwkset, code, nonce=nonce, kid='coin'): + response = app.get(login_callback_url(oidc_provider_rsa), params={'code': code, 'state': query['state']}) + + # test missing kid + with utils.check_log(caplog, message='Missing Key ID', levelname='WARNING'): + with oidc_provider_mock(oidc_provider_rsa, oidc_provider_jwkset, code, nonce=nonce, kid=None): + response = app.get(login_callback_url(oidc_provider_rsa), params={'code': code, 'state': query['state']}) diff --git a/tests/utils.py b/tests/utils.py index b5ebab4d..f19972c6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -154,11 +154,12 @@ class Authentic2TestCase(TestCase): @contextmanager -def check_log(caplog, msg): +def check_log(caplog, message, levelname=None): idx = len(caplog.records) yield - assert any(msg in record.msg for record in caplog.records[idx:]), \ - '%r not found in log records' % msg + assert any(message in record.message for record in caplog.records[idx:] + if not levelname or record.levelname == levelname), \ + '%r not found in log records' % message def can_resolve_dns(): -- 2.24.0