From c7d0383071b99836c3d996b9dd9cace2a0b1a467 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Wed, 3 Oct 2018 00:05:51 +0200 Subject: [PATCH 2/2] authentic: send celery notification in a child process (fixes #26911) Celery use Kombu to communicate with RabbitMQ but kombu has the bad habit of modifying the default socket timeout globally messing with authentic. --- .../management/commands/hobo_provision.py | 4 +- hobo/agent/authentic2/provisionning.py | 69 +++++---- tests_authentic/test_provisionning.py | 132 +++++++++++++----- 3 files changed, 141 insertions(+), 64 deletions(-) diff --git a/hobo/agent/authentic2/management/commands/hobo_provision.py b/hobo/agent/authentic2/management/commands/hobo_provision.py index 4ff456c..49d89c2 100644 --- a/hobo/agent/authentic2/management/commands/hobo_provision.py +++ b/hobo/agent/authentic2/management/commands/hobo_provision.py @@ -34,7 +34,7 @@ class Command(BaseCommand): roles = get_role_model().objects.all() if self.verbosity > 0: self.stdout.write('Provisionning {} roles.'.format(roles.count())) - engine.notify_roles(ous, roles, full=True) + engine.notify_agents(list(engine.notify_roles(ous, roles, full=True))) def provision_users(self, engine, ous, batch_size=512, batch_sleep=30, verbosity=1): qs = get_user_model().objects.all() @@ -48,7 +48,7 @@ class Command(BaseCommand): while users: if verbosity > 0: self.stdout.write(' batch provisionning %d users and sleeping for %d seconds' % (len(users), batch_sleep)) - engine.notify_users(ous, users) + engine.notify_agents(list(engine.notify_users(ous, users))) users = list(qs.filter(id__gt=users[-1].pk)[:batch_size]) if users: time.sleep(batch_sleep) diff --git a/hobo/agent/authentic2/provisionning.py b/hobo/agent/authentic2/provisionning.py index f21dbda..bc25d21 100644 --- a/hobo/agent/authentic2/provisionning.py +++ b/hobo/agent/authentic2/provisionning.py @@ -1,8 +1,9 @@ +import os import json from urlparse import urljoin -import threading import copy import logging +import threading from django.contrib.auth import get_user_model from django.db import connection @@ -17,8 +18,8 @@ from authentic2.models import AttributeValue class Provisionning(object): + # allow provisionning to work with thread workers or gevent local = threading.local() - threads = set() def __init__(self): self.User = get_user_model() @@ -134,7 +135,7 @@ class Provisionning(object): for service, audience in self.get_audience(ou): for user in users: self.logger.info(u'provisionning user %s to %s', user, audience) - notify_agents({ + yield { '@type': 'provision', 'issuer': issuer, 'audience': [audience], @@ -143,7 +144,7 @@ class Provisionning(object): '@type': 'user', 'data': [user_to_json(service, user, user_roles)], } - }) + } else: for ou, users in ous.iteritems(): audience = [a for service, a in self.get_audience(ou)] @@ -151,7 +152,7 @@ class Provisionning(object): continue self.logger.info(u'provisionning users %s to %s', u', '.join(map(unicode, users)), u', '.join(audience)) - notify_agents({ + yield { '@type': 'provision', 'issuer': issuer, 'audience': audience, @@ -160,13 +161,13 @@ class Provisionning(object): '@type': 'user', 'data': [user_to_json(None, user, user_roles) for user in users], } - }) + } elif users: audience = [audience for ou in self.OU.objects.all() for s, audience in self.get_audience(ou)] self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)), u', '.join(audience)) - notify_agents({ + yield { '@type': 'deprovision', 'issuer': issuer, 'audience': audience, @@ -177,7 +178,7 @@ class Provisionning(object): 'uuid': user.uuid, } for user in users] } - }) + } def notify_roles(self, ous, roles, mode='provision', full=False): roles = set([role for role in roles if not role.slug.startswith('_')]) @@ -214,7 +215,7 @@ class Provisionning(object): audience = [entity_id for service, entity_id in self.get_audience(ou)] self.logger.info(u'%sning roles %s to %s', mode, roles, audience) - notify_agents({ + yield { '@type': mode, 'audience': audience, 'full': full, @@ -222,12 +223,13 @@ class Provisionning(object): '@type': 'role', 'data': data, } - }) + } global_roles = set(ous.get(None, [])) for ou, ou_roles in ous.iteritems(): sent_roles = set(ou_roles) | global_roles - helper(ou, sent_roles) + for x in helper(ou, sent_roles): + yield x def provision(self): if (not hasattr(connection, 'tenant') or not connection.tenant or not @@ -239,29 +241,40 @@ class Provisionning(object): if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'): return - t = threading.Thread(target=self.do_provision, kwargs={ - 'saved': getattr(self.local, 'saved', {}), - 'deleted': getattr(self.local, 'deleted', {}), - }) - t.start() - self.threads.add(t) + saved = getattr(self.local, 'saved', {}) + deleted = getattr(self.local, 'deleted', {}) + + msgs = [] + ous = {ou.id: ou for ou in self.OU.objects.all()} + msgs.extend(self.notify_roles(ous, saved.get(self.Role, []))) + msgs.extend(self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision')) + msgs.extend(self.notify_users(ous, saved.get(self.User, []))) + msgs.extend(self.notify_users(ous, deleted.get(self.User, []), mode='deprovision')) + self.notify_agents(msgs) - def do_provision(self, saved, deleted, thread=None): + def notify_agents(self, msgs): + if not msgs: + return + pid = os.fork() + if pid: + self.local.pid = pid + return try: - ous = {ou.id: ou for ou in self.OU.objects.all()} - self.notify_roles(ous, saved.get(self.Role, [])) - self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision') - self.notify_users(ous, saved.get(self.User, [])) - self.notify_users(ous, deleted.get(self.User, []), mode='deprovision') + for msg in msgs: + notify_agents(msg) except Exception: # last step, clear everything - self.logger.exception(u'error in provisionning thread') - finally: - self.threads.discard(threading.current_thread()) + self.logger.exception(u'error in provisionning process') + os._exit(-1) + os._exit(0) def wait(self): - for thread in list(self.threads): - thread.join() + if hasattr(self.local, 'pid'): + try: + os.waitpid(self.local.pid, 0) + except OSError: + pass + del self.local.pid def __enter__(self): self.start() diff --git a/tests_authentic/test_provisionning.py b/tests_authentic/test_provisionning.py index 685d65c..4e99e85 100644 --- a/tests_authentic/test_provisionning.py +++ b/tests_authentic/test_provisionning.py @@ -1,8 +1,12 @@ # -*- coding: utf-8 -*- +import os import json +import contextlib import pytest import lasso +import multiprocessing +from Queue import Empty from mock import patch, call, ANY @@ -20,8 +24,69 @@ from hobo.agent.authentic2.provisionning import provisionning pytestmark = pytest.mark.django_db +@contextlib.contextmanager +def mock_notify_agents(): + class CrossProcessSideEffect(object): + def __init__(self): + self.q = multiprocessing.Queue() + self.reset_mock() + + def reset_mock(self): + self._call_count = 0 + self._call_args_list = [] + + def __call__(self, *args, **kwargs): + self.q.put((args, kwargs)) + + @property + def call_count(self): + self.consume() + return self._call_count + + @property + def call_args(self): + self.consume() + return self._call_args_list[-1] + + @property + def call_args_list(self): + self.consume() + return self._call_args_list + + def consume(self): + # consume que queue content + while not self.q.empty(): + try: + o = self.q.get_nowait() + except Empty: + break + self._call_count += 1 + self._call_args_list.append(o) + + old_exit = os._exit + + try: + with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents: + side_effect = CrossProcessSideEffect() + + def new_exit(r): + # multiprocessing.Queue() use a background thread to push + # the queue content to other processes using a pipe, to be + # sure that every data has crossed, your must close the + # queue and join the sending background thread before the + # process exit. + side_effect.q.close() + side_effect.q.join_thread() + old_exit(r) + os._exit = new_exit + notify_agents.side_effect = side_effect + yield side_effect + finally: + os._exit = old_exit + + def test_provision_role(transactional_db, tenant, caplog): - with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents: + with mock_notify_agents() as notify_agents: with tenant_context(tenant): LibertyProvider.objects.create(ou=get_default_ou(), name='provider', entity_id='http://provider.com', @@ -115,7 +180,7 @@ def test_provision_role(transactional_db, tenant, caplog): def test_provision_user(transactional_db, tenant, caplog): - with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents: + with mock_notify_agents() as notify_agents: with tenant_context(tenant): service = LibertyProvider.objects.create(ou=get_default_ou(), name='provider', entity_id='http://provider.com', @@ -277,7 +342,6 @@ def test_provision_user(transactional_db, tenant, caplog): data = objects['data'] assert isinstance(data, list) assert len(data) == 1 - print data for o in data: assert set(o.keys()) >= set(['uuid', 'username', 'first_name', 'is_superuser', 'last_name', 'email', 'roles']) @@ -449,40 +513,40 @@ def test_provision_createsuperuser(transactional_db, tenant, caplog): LibertyProvider.objects.create(ou=None, name='provider', entity_id='http://provider.com', protocol_conformance=lasso.PROTOCOL_SAML_2_0) - with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents: + with mock_notify_agents() as notify_agents: call_command('createsuperuser', domain=tenant.domain_url, uuid='coin', username='coin', email='coin@coin.org', interactive=False) assert notify_agents.call_count == 1 -@patch('hobo.agent.authentic2.provisionning.notify_agents') -def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog): - User = get_user_model() - with tenant_context(tenant): - ou = get_default_ou() - LibertyProvider.objects.create(ou=ou, name='provider', - entity_id='http://provider.com', - protocol_conformance=lasso.PROTOCOL_SAML_2_0) - for i in range(10): - Role.objects.create(name='role-%s' % i, ou=ou) - for i in range(10): - User.objects.create(username='user-%s' % i, first_name='John', - last_name='Doe %s' % i, ou=ou, - email='jone.doe-%s@example.com') +def test_command_hobo_provision(transactional_db, tenant, caplog): + with mock_notify_agents() as notify_agents: + User = get_user_model() + with tenant_context(tenant): + ou = get_default_ou() + LibertyProvider.objects.create(ou=ou, name='provider', + entity_id='http://provider.com', + protocol_conformance=lasso.PROTOCOL_SAML_2_0) + for i in range(10): + Role.objects.create(name='role-%s' % i, ou=ou) + for i in range(10): + User.objects.create(username='user-%s' % i, first_name='John', + last_name='Doe %s' % i, ou=ou, + email='jone.doe-%s@example.com') - with tenant_context(tenant): - # call_command('tenant_command', 'hobo_provision', ...) doesn't work - # https://github.com/bernardopires/django-tenant-schemas/issues/495 - # so we call the command from the tenant context. - call_command('hobo_provision', roles=True, users=True) - - msg_1 = notify_agents.call_args_list[0][0][0] - msg_2 = notify_agents.call_args_list[1][0][0] - assert msg_1['@type'] == 'provision' - assert msg_1['full'] is True - assert msg_1['objects']['@type'] == 'role' - assert len(msg_1['objects']['data']) == 10 - assert msg_2['@type'] == 'provision' - assert msg_2['full'] is False - assert msg_2['objects']['@type'] == 'user' - assert len(msg_2['objects']['data']) == 10 + with tenant_context(tenant): + # call_command('tenant_command', 'hobo_provision', ...) doesn't work + # https://github.com/bernardopires/django-tenant-schemas/issues/495 + # so we call the command from the tenant context. + call_command('hobo_provision', roles=True, users=True) + + msg_1 = notify_agents.call_args_list[0][0][0] + msg_2 = notify_agents.call_args_list[1][0][0] + assert msg_1['@type'] == 'provision' + assert msg_1['full'] is True + assert msg_1['objects']['@type'] == 'role' + assert len(msg_1['objects']['data']) == 10 + assert msg_2['@type'] == 'provision' + assert msg_2['full'] is False + assert msg_2['objects']['@type'] == 'user' + assert len(msg_2['objects']['data']) == 10 -- 2.18.0