From 186f823529a77392af726bc25eacf553b12f1b4f Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 21 Oct 2021 18:40:27 +0200 Subject: [PATCH 3/3] auth: inject dnsbl function in condition evaluation context (#58055) --- src/authentic2/utils/evaluate.py | 17 +++++++++++++++-- src/authentic2/views.py | 5 ++--- tests/test_auth_saml.py | 21 +++++++++++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/authentic2/utils/evaluate.py b/src/authentic2/utils/evaluate.py index 3f448da0..a78a0953 100644 --- a/src/authentic2/utils/evaluate.py +++ b/src/authentic2/utils/evaluate.py @@ -93,6 +93,7 @@ def is_valid_hostname(hostname): def check_dnsbl(dnsbl, remote_addr): domain = '.'.join(reversed(remote_addr.split('.'))) + '.' + dnsbl exception = None + log = logger.debug try: answers = dns.resolver.resolve(domain, 'A', lifetime=1) result = any(answer.address for answer in answers) @@ -104,9 +105,9 @@ def check_dnsbl(dnsbl, remote_addr): result = False except dns.exception.DNSException as e: exception = e - logger.warning(f'utils: could not check dnsbl {dnsbl} for domain "%s": %s', domain, e) + log = logger.warning result = False - logger.debug('utils: dnsbl lookup of "%s", result=%s exception=%s', domain, result, exception) + log('utils: dnsbl lookup of "%s", result=%s exception=%s', domain, result, exception) return result @@ -299,3 +300,15 @@ def evaluate_condition(expression, ctx=None, validator=None, on_raise=None): if on_raise is not None: return on_raise raise e + + +def make_condition_context(*, request=None, **kwargs): + '''Helper to make a condition context''' + ctx = { + 'dnsbl': dnsbl, + } + if request: + ctx['headers'] = HTTPHeaders(request) + ctx['remote_addr'] = request.META.get('REMOTE_ADDR') + ctx.update(kwargs) + return ctx diff --git a/src/authentic2/views.py b/src/authentic2/views.py index e9f51b1b..c691563e 100644 --- a/src/authentic2/views.py +++ b/src/authentic2/views.py @@ -62,7 +62,7 @@ from .forms import profile as profile_forms from .forms import registration as registration_forms from .utils import misc as utils_misc from .utils import switch_user as utils_switch_user -from .utils.evaluate import HTTPHeaders +from .utils.evaluate import make_condition_context from .utils.service import get_service_from_request, get_service_from_token, set_service_ref from .utils.view_decorators import enable_view_restriction @@ -344,9 +344,8 @@ def login(request, template_name='authentic2/login.html', redirect_field_name=RE else: # New frontends API auth_blocks = [] parameters = {'request': request, 'context': context} - remote_addr = request.META.get('REMOTE_ADDR') login_hint = set(request.session.get('login-hint', [])) - show_ctx = dict(remote_addr=remote_addr, login_hint=login_hint, headers=HTTPHeaders(request)) + show_ctx = make_condition_context(request=request, login_hint=login_hint) if service: show_ctx['service_ou_slug'] = service.ou and service.ou.slug show_ctx['service_slug'] = service.slug diff --git a/tests/test_auth_saml.py b/tests/test_auth_saml.py index e1c67d96..686b4ff5 100644 --- a/tests/test_auth_saml.py +++ b/tests/test_auth_saml.py @@ -16,6 +16,7 @@ import os import re +from unittest import mock import lasso import pytest @@ -239,6 +240,26 @@ def test_login_with_conditionnal_authenticators(db, app, settings, caplog): assert 'login-saml-1' not in response +def test_login_condition_dnsbl(db, app, settings, caplog): + settings.A2_AUTH_SAML_ENABLE = True + settings.MELLON_IDENTITY_PROVIDERS = [ + {"METADATA": os.path.join(os.path.dirname(__file__), 'metadata.xml')}, + {"METADATA": os.path.join(os.path.dirname(__file__), 'metadata.xml')}, + ] + settings.AUTH_FRONTENDS_KWARGS = { + 'saml': { + 'show_condition': { + '0': 'remote_addr in dnsbl(\'dnswl.example.com\')', + '1': 'remote_addr not in dnsbl(\'dnswl.example.com\')', + } + } + } + with mock.patch('authentic2.utils.evaluate.check_dnsbl', return_value=True): + response = app.get('/login/') + assert 'login-saml-0' in response + assert 'login-saml-1' not in response + + def test_login_autorun(db, app, settings): response = app.get('/login/') -- 2.33.0