Project

General

Profile

0002-authentic-send-celery-notification-in-a-child-proces.patch

Benjamin Dauvergne, 02 Mar 2019 11:30 AM

Download (15.6 KB)

View differences:

Subject: [PATCH 2/2] authentic: send celery notification in a child process
 (fixes #26911)

Celery use Kombu to communicate with RabbitMQ but kombu has the bad
habit of modifying the default socket timeout globally messing with
authentic.
 .../management/commands/hobo_provision.py     |   4 +-
 hobo/agent/authentic2/provisionning.py        |  69 +++++----
 tests_authentic/test_provisionning.py         | 132 +++++++++++++-----
 3 files changed, 141 insertions(+), 64 deletions(-)
hobo/agent/authentic2/management/commands/hobo_provision.py
34 34
        roles = get_role_model().objects.all()
35 35
        if self.verbosity > 0:
36 36
            self.stdout.write('Provisionning {} roles.'.format(roles.count()))
37
        engine.notify_roles(ous, roles, full=True)
37
        engine.notify_agents(list(engine.notify_roles(ous, roles, full=True)))
38 38

  
39 39
    def provision_users(self, engine, ous, batch_size=512, batch_sleep=30, verbosity=1):
40 40
        qs = get_user_model().objects.all()
......
48 48
            while users:
49 49
                if verbosity > 0:
50 50
                    self.stdout.write('  batch provisionning %d users and sleeping for %d seconds' % (len(users), batch_sleep))
51
                engine.notify_users(ous, users)
51
                engine.notify_agents(list(engine.notify_users(ous, users)))
52 52
                users = list(qs.filter(id__gt=users[-1].pk)[:batch_size])
53 53
                if users:
54 54
                    time.sleep(batch_sleep)
hobo/agent/authentic2/provisionning.py
1
import os
1 2
import json
2 3
from urlparse import urljoin
3
import threading
4 4
import copy
5 5
import logging
6
import threading
6 7

  
7 8
from django.contrib.auth import get_user_model
8 9
from django.db import connection
......
17 18

  
18 19

  
19 20
class Provisionning(object):
21
    # allow provisionning to work with thread workers or gevent
20 22
    local = threading.local()
21
    threads = set()
22 23

  
23 24
    def __init__(self):
24 25
        self.User = get_user_model()
......
134 135
                    for service, audience in self.get_audience(ou):
135 136
                        for user in users:
136 137
                            self.logger.info(u'provisionning user %s to %s', user, audience)
137
                            notify_agents({
138
                            yield {
138 139
                                '@type': 'provision',
139 140
                                'issuer': issuer,
140 141
                                'audience': [audience],
......
143 144
                                    '@type': 'user',
144 145
                                    'data': [user_to_json(service, user, user_roles)],
145 146
                                }
146
                            })
147
                            }
147 148
            else:
148 149
                for ou, users in ous.iteritems():
149 150
                    audience = [a for service, a in self.get_audience(ou)]
......
151 152
                        continue
152 153
                    self.logger.info(u'provisionning users %s to %s',
153 154
                                     u', '.join(map(unicode, users)), u', '.join(audience))
154
                    notify_agents({
155
                    yield {
155 156
                        '@type': 'provision',
156 157
                        'issuer': issuer,
157 158
                        'audience': audience,
......
160 161
                            '@type': 'user',
161 162
                            'data': [user_to_json(None, user, user_roles) for user in users],
162 163
                        }
163
                    })
164
                    }
164 165
        elif users:
165 166
            audience = [audience for ou in self.OU.objects.all()
166 167
                        for s, audience in self.get_audience(ou)]
167 168
            self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
168 169
                             u', '.join(audience))
169
            notify_agents({
170
            yield {
170 171
                '@type': 'deprovision',
171 172
                'issuer': issuer,
172 173
                'audience': audience,
......
177 178
                        'uuid': user.uuid,
178 179
                    } for user in users]
179 180
                }
180
            })
181
            }
181 182

  
182 183
    def notify_roles(self, ous, roles, mode='provision', full=False):
183 184
        roles = set([role for role in roles if not role.slug.startswith('_')])
......
214 215

  
215 216
            audience = [entity_id for service, entity_id in self.get_audience(ou)]
216 217
            self.logger.info(u'%sning roles %s to %s', mode, roles, audience)
217
            notify_agents({
218
            yield {
218 219
                '@type': mode,
219 220
                'audience': audience,
220 221
                'full': full,
......
222 223
                    '@type': 'role',
223 224
                    'data': data,
224 225
                }
225
            })
226
            }
226 227

  
227 228
        global_roles = set(ous.get(None, []))
228 229
        for ou, ou_roles in ous.iteritems():
229 230
            sent_roles = set(ou_roles) | global_roles
230
            helper(ou, sent_roles)
231
            for x in helper(ou, sent_roles):
232
                yield x
231 233

  
232 234
    def provision(self):
233 235
        if (not hasattr(connection, 'tenant') or not connection.tenant or not
......
239 241
        if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'):
240 242
            return
241 243

  
242
        t = threading.Thread(target=self.do_provision, kwargs={
243
            'saved': getattr(self.local, 'saved', {}),
244
            'deleted': getattr(self.local, 'deleted', {}),
245
        })
246
        t.start()
247
        self.threads.add(t)
244
        saved = getattr(self.local, 'saved', {})
245
        deleted = getattr(self.local, 'deleted', {})
246

  
247
        msgs = []
248
        ous = {ou.id: ou for ou in self.OU.objects.all()}
249
        msgs.extend(self.notify_roles(ous, saved.get(self.Role, [])))
250
        msgs.extend(self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision'))
251
        msgs.extend(self.notify_users(ous, saved.get(self.User, [])))
252
        msgs.extend(self.notify_users(ous, deleted.get(self.User, []), mode='deprovision'))
253
        self.notify_agents(msgs)
248 254

  
249
    def do_provision(self, saved, deleted, thread=None):
255
    def notify_agents(self, msgs):
256
        if not msgs:
257
            return
258
        pid = os.fork()
259
        if pid:
260
            self.local.pid = pid
261
            return
250 262
        try:
251
            ous = {ou.id: ou for ou in self.OU.objects.all()}
252
            self.notify_roles(ous, saved.get(self.Role, []))
253
            self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision')
254
            self.notify_users(ous, saved.get(self.User, []))
255
            self.notify_users(ous, deleted.get(self.User, []), mode='deprovision')
263
            for msg in msgs:
264
                notify_agents(msg)
256 265
        except Exception:
257 266
            # last step, clear everything
258
            self.logger.exception(u'error in provisionning thread')
259
        finally:
260
            self.threads.discard(threading.current_thread())
267
            self.logger.exception(u'error in provisionning process')
268
            os._exit(-1)
269
        os._exit(0)
261 270

  
262 271
    def wait(self):
263
        for thread in list(self.threads):
264
            thread.join()
272
        if hasattr(self.local, 'pid'):
273
            try:
274
                os.waitpid(self.local.pid, 0)
275
            except OSError:
276
                pass
277
            del self.local.pid
265 278

  
266 279
    def __enter__(self):
267 280
        self.start()
tests_authentic/test_provisionning.py
1 1
# -*- coding: utf-8 -*-
2
import os
2 3
import json
4
import contextlib
3 5

  
4 6
import pytest
5 7
import lasso
8
import multiprocessing
9
from Queue import Empty
6 10

  
7 11
from mock import patch, call, ANY
8 12

  
......
20 24
pytestmark = pytest.mark.django_db
21 25

  
22 26

  
27
@contextlib.contextmanager
28
def mock_notify_agents():
29
    class CrossProcessSideEffect(object):
30
        def __init__(self):
31
            self.q = multiprocessing.Queue()
32
            self.reset_mock()
33

  
34
        def reset_mock(self):
35
            self._call_count = 0
36
            self._call_args_list = []
37

  
38
        def __call__(self, *args, **kwargs):
39
            self.q.put((args, kwargs))
40

  
41
        @property
42
        def call_count(self):
43
            self.consume()
44
            return self._call_count
45

  
46
        @property
47
        def call_args(self):
48
            self.consume()
49
            return self._call_args_list[-1]
50

  
51
        @property
52
        def call_args_list(self):
53
            self.consume()
54
            return self._call_args_list
55

  
56
        def consume(self):
57
            # consume que queue content
58
            while not self.q.empty():
59
                try:
60
                    o = self.q.get_nowait()
61
                except Empty:
62
                    break
63
                self._call_count += 1
64
                self._call_args_list.append(o)
65

  
66
    old_exit = os._exit
67

  
68
    try:
69
        with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
70
                side_effect = CrossProcessSideEffect()
71

  
72
                def new_exit(r):
73
                    # multiprocessing.Queue() use a background thread to push
74
                    # the queue content to other processes using a pipe, to be
75
                    # sure that every data has crossed, your must close the
76
                    # queue and join the sending background thread before the
77
                    # process exit.
78
                    side_effect.q.close()
79
                    side_effect.q.join_thread()
80
                    old_exit(r)
81
                os._exit = new_exit
82
                notify_agents.side_effect = side_effect
83
                yield side_effect
84
    finally:
85
        os._exit = old_exit
86

  
87

  
23 88
def test_provision_role(transactional_db, tenant, caplog):
24
    with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
89
    with mock_notify_agents() as notify_agents:
25 90
        with tenant_context(tenant):
26 91
            LibertyProvider.objects.create(ou=get_default_ou(), name='provider',
27 92
                                           entity_id='http://provider.com',
......
115 180

  
116 181

  
117 182
def test_provision_user(transactional_db, tenant, caplog):
118
    with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
183
    with mock_notify_agents() as notify_agents:
119 184
        with tenant_context(tenant):
120 185
            service = LibertyProvider.objects.create(ou=get_default_ou(), name='provider',
121 186
                                                     entity_id='http://provider.com',
......
277 342
                data = objects['data']
278 343
                assert isinstance(data, list)
279 344
                assert len(data) == 1
280
                print data
281 345
                for o in data:
282 346
                    assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
283 347
                                                 'is_superuser', 'last_name', 'email', 'roles'])
......
449 513
        LibertyProvider.objects.create(ou=None, name='provider',
450 514
                                                 entity_id='http://provider.com',
451 515
                                                 protocol_conformance=lasso.PROTOCOL_SAML_2_0)
452
    with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
516
    with mock_notify_agents() as notify_agents:
453 517
        call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
454 518
                     username='coin', email='coin@coin.org', interactive=False)
455 519
        assert notify_agents.call_count == 1
456 520

  
457 521

  
458
@patch('hobo.agent.authentic2.provisionning.notify_agents')
459
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
460
    User = get_user_model()
461
    with tenant_context(tenant):
462
        ou = get_default_ou()
463
        LibertyProvider.objects.create(ou=ou, name='provider',
464
                                       entity_id='http://provider.com',
465
                                       protocol_conformance=lasso.PROTOCOL_SAML_2_0)
466
        for i in range(10):
467
            Role.objects.create(name='role-%s' % i, ou=ou)
468
        for i in range(10):
469
            User.objects.create(username='user-%s' % i, first_name='John',
470
                                last_name='Doe %s' % i, ou=ou,
471
                                email='jone.doe-%s@example.com')
522
def test_command_hobo_provision(transactional_db, tenant, caplog):
523
    with mock_notify_agents() as notify_agents:
524
        User = get_user_model()
525
        with tenant_context(tenant):
526
            ou = get_default_ou()
527
            LibertyProvider.objects.create(ou=ou, name='provider',
528
                                           entity_id='http://provider.com',
529
                                           protocol_conformance=lasso.PROTOCOL_SAML_2_0)
530
            for i in range(10):
531
                Role.objects.create(name='role-%s' % i, ou=ou)
532
            for i in range(10):
533
                User.objects.create(username='user-%s' % i, first_name='John',
534
                                    last_name='Doe %s' % i, ou=ou,
535
                                    email='jone.doe-%s@example.com')
472 536

  
473
    with tenant_context(tenant):
474
        # call_command('tenant_command', 'hobo_provision', ...) doesn't work
475
        # https://github.com/bernardopires/django-tenant-schemas/issues/495
476
        # so we call the command from the tenant context.
477
        call_command('hobo_provision', roles=True, users=True)
478

  
479
    msg_1 = notify_agents.call_args_list[0][0][0]
480
    msg_2 = notify_agents.call_args_list[1][0][0]
481
    assert msg_1['@type'] == 'provision'
482
    assert msg_1['full'] is True
483
    assert msg_1['objects']['@type'] == 'role'
484
    assert len(msg_1['objects']['data']) == 10
485
    assert msg_2['@type'] == 'provision'
486
    assert msg_2['full'] is False
487
    assert msg_2['objects']['@type'] == 'user'
488
    assert len(msg_2['objects']['data']) == 10
537
        with tenant_context(tenant):
538
            # call_command('tenant_command', 'hobo_provision', ...) doesn't work
539
            # https://github.com/bernardopires/django-tenant-schemas/issues/495
540
            # so we call the command from the tenant context.
541
            call_command('hobo_provision', roles=True, users=True)
542

  
543
        msg_1 = notify_agents.call_args_list[0][0][0]
544
        msg_2 = notify_agents.call_args_list[1][0][0]
545
        assert msg_1['@type'] == 'provision'
546
        assert msg_1['full'] is True
547
        assert msg_1['objects']['@type'] == 'role'
548
        assert len(msg_1['objects']['data']) == 10
549
        assert msg_2['@type'] == 'provision'
550
        assert msg_2['full'] is False
551
        assert msg_2['objects']['@type'] == 'user'
552
        assert len(msg_2['objects']['data']) == 10
489
-