Projet

Général

Profil

0001-agent-a2-prevent-useless-thread-launching-34484.patch

Benjamin Dauvergne, 04 juillet 2019 11:14

Télécharger (24 ko)

Voir les différences:

Subject: [PATCH] agent/a2: prevent useless thread launching (#34484)

 hobo/agent/authentic2/middleware.py    |   7 +-
 hobo/agent/authentic2/provisionning.py | 170 ++++++++++++++-----------
 tests_authentic/conftest.py            | 126 ++++++++++++------
 tests_authentic/settings.py            |   4 +
 tests_authentic/test_provisionning.py  |  38 +++++-
 5 files changed, 224 insertions(+), 121 deletions(-)
hobo/agent/authentic2/middleware.py
1
from django.conf import settings
2

  
1 3
from .provisionning import provisionning
2 4

  
3 5

  
......
6 8
        provisionning.start()
7 9

  
8 10
    def process_exception(self, request, exception):
9
        provisionning.clean()
11
        provisionning.stop(provision=False)
10 12

  
11 13
    def process_response(self, request, response):
12
        provisionning.provision()
14
        provisionning.stop(provision=True, wait=getattr(settings, 'HOBO_PROVISIONNING_SYNCHRONOUS', False))
13 15
        return response
14

  
hobo/agent/authentic2/provisionning.py
15 15
from authentic2.a2_rbac.models import RoleAttribute
16 16
from authentic2.models import AttributeValue
17 17

  
18
User = get_user_model()
19
Role = get_role_model()
20
OU = get_ou_model()
21
RoleParenting = get_role_parenting_model()
18 22

  
19
class Provisionning(object):
20
    local = threading.local()
23
logger = logging.getLogger(__name__)
24

  
25

  
26
class Provisionning(threading.local):
27
    __slots__ = ['threads']
21 28
    threads = set()
22 29

  
23 30
    def __init__(self):
24
        self.User = get_user_model()
25
        self.Role = get_role_model()
26
        self.OU = get_ou_model()
27
        self.RoleParenting = get_role_parenting_model()
28
        self.logger = logging.getLogger(__name__)
31
        self.stack = []
29 32

  
30 33
    def start(self):
31
        self.local.saved = {}
32
        self.local.deleted = {}
34
        self.stack.append({
35
            'saved': {},
36
            'deleted': {},
37
        })
38

  
39
    def stop(self, provision=True, wait=True):
40
        context = self.stack.pop()
33 41

  
34
    def clean(self):
35
        if hasattr(self.local, 'saved'):
36
            del self.local.saved
37
        if hasattr(self.local, 'deleted'):
38
            del self.local.deleted
42
        if provision:
43
            self.provision(**context)
44
            if wait:
45
                self.wait()
39 46

  
40
    def saved(self, *args):
41
        if not hasattr(self.local, 'saved'):
47
    @property
48
    def saved(self):
49
        if self.stack:
50
            return self.stack[-1]['saved']
51
        return None
52

  
53
    @property
54
    def deleted(self):
55
        if self.stack:
56
            return self.stack[-1]['deleted']
57
        return None
58

  
59
    def add_saved(self, *args):
60
        if not self.stack:
42 61
            return
43 62

  
44 63
        for instance in args:
45
            klass = self.User if isinstance(instance, self.User) else self.Role
46
            self.local.saved.setdefault(klass, set()).add(instance)
64
            klass = User if isinstance(instance, User) else Role
65
            self.saved.setdefault(klass, set()).add(instance)
47 66

  
48
    def deleted(self, *args):
49
        if not hasattr(self.local, 'saved'):
67
    def add_deleted(self, *args):
68
        if not self.stack:
50 69
            return
51 70

  
52 71
        for instance in args:
53
            klass = self.User if isinstance(instance, self.User) else self.Role
54
            self.local.deleted.setdefault(klass, set()).add(instance)
55
            self.local.saved.get(klass, set()).discard(instance)
72
            klass = User if isinstance(instance, User) else Role
73
            self.deleted.setdefault(klass, set()).add(instance)
74
            self.saved.get(klass, set()).discard(instance)
56 75

  
57 76
    def resolve_ou(self, instances, ous):
58 77
        for instance in instances:
......
61 80

  
62 81
    def notify_users(self, ous, users, mode='provision'):
63 82
        if mode == 'provision':
64
            users = (self.User.objects.filter(id__in=[u.id for u in users])
83
            users = (User.objects.filter(id__in=[u.id for u in users])
65 84
                     .select_related('ou').prefetch_related('attribute_values__attribute'))
66 85
        else:
67 86
            self.resolve_ou(users, ous)
......
105 124
            # Find roles giving a superuser attribute
106 125
            # If there is any role of this kind, we do one provisionning message for each user and
107 126
            # each service.
108
            roles_with_attributes = (self.Role.objects.filter(members__in=users)
127
            roles_with_attributes = (Role.objects.filter(members__in=users)
109 128
                                     .parents(include_self=True)
110 129
                                     .filter(attributes__name='is_superuser')
111 130
                                     .exists())
112 131

  
113
            all_roles = (self.Role.objects.filter(members__in=users).parents()
132
            all_roles = (Role.objects.filter(members__in=users).parents()
114 133
                         .prefetch_related('attributes').distinct())
115 134
            roles = dict((r.id, r) for r in all_roles)
116 135
            user_roles = {}
117 136
            parents = {}
118
            for rp in self.RoleParenting.objects.filter(child__in=all_roles):
137
            for rp in RoleParenting.objects.filter(child__in=all_roles):
119 138
                parents.setdefault(rp.child.id, []).append(rp.parent.id)
120
            Through = self.Role.members.through
139
            Through = Role.members.through
121 140
            for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id',
122 141
                                                                                      'role_id'):
123 142
                user_roles.setdefault(u_id, set()).add(roles[r_id])
......
133 152
                for ou, users in ous.iteritems():
134 153
                    for service, audience in self.get_audience(ou):
135 154
                        for user in users:
136
                            self.logger.info(u'provisionning user %s to %s', user, audience)
155
                            logger.info(u'provisionning user %s to %s', user, audience)
137 156
                            notify_agents({
138 157
                                '@type': 'provision',
139 158
                                'issuer': issuer,
......
149 168
                    audience = [a for service, a in self.get_audience(ou)]
150 169
                    if not audience:
151 170
                        continue
152
                    self.logger.info(u'provisionning users %s to %s',
171
                    logger.info(u'provisionning users %s to %s',
153 172
                                     u', '.join(map(unicode, users)), u', '.join(audience))
154 173
                    notify_agents({
155 174
                        '@type': 'provision',
......
162 181
                        }
163 182
                    })
164 183
        elif users:
165
            audience = [audience for ou in self.OU.objects.all()
184
            audience = [audience for ou in OU.objects.all()
166 185
                        for s, audience in self.get_audience(ou)]
167
            self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
186
            logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
168 187
                             u', '.join(audience))
169 188
            notify_agents({
170 189
                '@type': 'deprovision',
......
213 232
                ]
214 233

  
215 234
            audience = [entity_id for service, entity_id in self.get_audience(ou)]
216
            self.logger.info(u'%sning roles %s to %s', mode, roles, audience)
235
            logger.info(u'%sning roles %s to %s', mode, roles, audience)
217 236
            notify_agents({
218 237
                '@type': mode,
219 238
                'audience': audience,
......
229 248
            sent_roles = set(ou_roles) | global_roles
230 249
            helper(ou, sent_roles)
231 250

  
232
    def provision(self):
251
    def provision(self, saved, deleted):
252
        # Returns if:
253
        # - we are not in a tenant
254
        # - provsionning is disabled
255
        # - there is nothing to do
233 256
        if (not hasattr(connection, 'tenant') or not connection.tenant or not
234 257
                hasattr(connection.tenant, 'domain_url')):
235 258
            return
236 259
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
237 260
            return
238
        # exit early if not started
239
        if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'):
261
        if not (saved or deleted):
240 262
            return
241 263

  
242
        t = threading.Thread(target=self.do_provision, kwargs={
243
            'saved': getattr(self.local, 'saved', {}),
244
            'deleted': getattr(self.local, 'deleted', {}),
245
        })
264
        t = threading.Thread(
265
            target=self.do_provision,
266
            kwargs={'saved': saved, 'deleted': deleted})
246 267
        t.start()
247 268
        self.threads.add(t)
248 269

  
249 270
    def do_provision(self, saved, deleted, thread=None):
250 271
        try:
251
            ous = {ou.id: ou for ou in self.OU.objects.all()}
252
            self.notify_roles(ous, saved.get(self.Role, []))
253
            self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision')
254
            self.notify_users(ous, saved.get(self.User, []))
255
            self.notify_users(ous, deleted.get(self.User, []), mode='deprovision')
272
            ous = {ou.id: ou for ou in OU.objects.all()}
273
            self.notify_roles(ous, saved.get(Role, []))
274
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
275
            self.notify_users(ous, saved.get(User, []))
276
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
256 277
        except Exception:
257 278
            # last step, clear everything
258
            self.logger.exception(u'error in provisionning thread')
279
            logger.exception(u'error in provisionning thread')
259 280
        finally:
260 281
            self.threads.discard(threading.current_thread())
261 282

  
......
267 288
        self.start()
268 289

  
269 290
    def __exit__(self, exc_type, exc_value, exc_tb):
270
        if exc_type is None:
271
            self.provision()
272
            self.clean()
273
            self.wait()
274
        else:
275
            self.clean()
291
        if not self.stack:
292
            return
293
        self.stop(provision=exc_type is None)
276 294

  
277 295
    def get_audience(self, ou):
278 296
        if ou:
......
298 316
        return urljoin(base_url, reverse('a2-idp-saml-metadata'))
299 317

  
300 318
    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
319
        if not self.stack:
320
            return
301 321
        # we skip new instances
302 322
        if not instance.pk:
303 323
            return
304
        if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
324
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
305 325
            return
306 326
        # ignore last_login update on login
307
        if isinstance(instance, self.User) and update_fields == ['last_login']:
327
        if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
308 328
            return
309 329
        if isinstance(instance, RoleAttribute):
310 330
            instance = instance.role
311 331
        elif isinstance(instance, AttributeValue):
312
            if not isinstance(instance.owner, self.User):
332
            if not isinstance(instance.owner, User):
313 333
                return
314 334
            instance = instance.owner
315
        self.saved(instance)
335
        self.add_saved(instance)
316 336

  
317 337
    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
338
        if not self.stack:
339
            return
318 340
        # during post_save we only handle new instances
319
        if isinstance(instance, self.RoleParenting):
320
            self.saved(*list(instance.child.all_members()))
341
        if isinstance(instance, RoleParenting):
342
            self.add_saved(*list(instance.child.all_members()))
321 343
            return
322 344
        if not created:
323 345
            return
324
        if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
346
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
325 347
            return
326 348
        if isinstance(instance, RoleAttribute):
327 349
            instance = instance.role
328 350
        elif isinstance(instance, AttributeValue):
329
            if not isinstance(instance.owner, self.User):
351
            if not isinstance(instance.owner, User):
330 352
                return
331 353
            instance = instance.owner
332
        self.saved(instance)
354
        self.add_saved(instance)
333 355

  
334 356
    def pre_delete(self, sender, instance, using, **kwargs):
335
        if isinstance(instance, (self.User, self.Role)):
336
            self.deleted(copy.copy(instance))
357
        if not self.stack:
358
            return
359
        if isinstance(instance, (User, Role)):
360
            self.add_deleted(copy.copy(instance))
337 361
        elif isinstance(instance, RoleAttribute):
338 362
            instance = instance.role
339
            self.saved(instance)
363
            self.add_saved(instance)
340 364
        elif isinstance(instance, AttributeValue):
341
            if not isinstance(instance.owner, self.User):
365
            if not isinstance(instance.owner, User):
342 366
                return
343 367
            instance = instance.owner
344
            self.saved(instance)
345
        elif isinstance(instance, self.RoleParenting):
346
            self.saved(*list(instance.child.all_members()))
368
            self.add_saved(instance)
369
        elif isinstance(instance, RoleParenting):
370
            self.add_saved(*list(instance.child.all_members()))
347 371

  
348 372
    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
373
        if not self.stack:
374
            return
349 375
        if action != 'pre_clear' and action.startswith('pre_'):
350 376
            return
351
        if sender is self.Role.members.through:
352
            self.saved(instance)
377
        if sender is Role.members.through:
378
            self.add_saved(instance)
353 379
            # on a clear, pk_set is None
354 380
            for other_instance in model.objects.filter(pk__in=pk_set or []):
355
                self.saved(other_instance)
381
                self.add_saved(other_instance)
356 382
            if action == 'pre_clear':
357 383
                # when the action is pre_clear we need to lookup the current value of the members
358 384
                # relation, to re-provision all previously enroled users.
359 385
                if not reverse:
360 386
                    for other_instance in instance.members.all():
361
                        self.saved(other_instance)
387
                        self.add_saved(other_instance)
tests_authentic/conftest.py
1 1
import os
2
import tempfile
3
import shutil
4 2
import json
5 3

  
4
from django_webtest import WebTestMixin, DjangoTestApp
6 5
import pytest
7 6

  
7
from django.db import connection
8
from tenant_schemas.postgresql_backend.base import FakeTenant
9
from tenant_schemas.utils import tenant_context
10

  
11
from hobo.multitenant.models import Tenant
12

  
8 13

  
9 14
@pytest.fixture
10
def tenant_base(request, settings):
11
    base = tempfile.mkdtemp('authentic-tenant-base')
15
def tenant_base(tmpdir, settings):
16
    base = str(tmpdir.mkdir('authentic-tenant-base'))
12 17
    settings.TENANT_BASE = base
13

  
14
    def fin():
15
        shutil.rmtree(base)
16
    request.addfinalizer(fin)
17 18
    return base
18 19

  
19 20

  
20
@pytest.fixture(scope='function')
21
def tenant(transactional_db, request, settings, tenant_base):
22
    from hobo.multitenant.models import Tenant
23
    base = tenant_base
21
@pytest.fixture
22
def tenant_factory(transactional_db, tenant_base):
23
    tenants = []
24 24

  
25
    @pytest.mark.django_db
26
    def make_tenant(name):
27
        tenant_dir = os.path.join(base, name)
25
    def factory(name):
26
        tenant_dir = os.path.join(tenant_base, name)
28 27
        os.mkdir(tenant_dir)
29 28
        with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd:
30 29
            fd.write('1')
......
37 36
                    'other_variable': 'foo',
38 37
                },
39 38
                'services': [
40
                    {'slug': 'test',
41
                     'service-id': 'authentic',
42
                     'title': 'Test',
43
                     'this': True,
44
                     'secret_key': '12345',
45
                     'base_url': 'http://%s' % name,
46
                     'variables': {
47
                         'other_variable': 'bar',
48
                     }
39
                    {
40
                        'slug': 'test',
41
                        'service-id': 'authentic',
42
                        'title': 'Test',
43
                        'this': True,
44
                        'secret_key': '12345',
45
                        'base_url': 'http://%s' % name,
46
                        'variables': {
47
                            'other_variable': 'bar',
48
                        }
49
                    },
50
                    {
51
                        'slug': 'other',
52
                        'title': 'Other',
53
                        'service-id': 'welco',
54
                        'secret_key': 'abcdef',
55
                        'base_url': 'http://other.example.net'
49 56
                    },
50
                    {'slug': 'other',
51
                     'title': 'Other',
52
                     'service-id': 'welco',
53
                     'secret_key': 'abcdef',
54
                     'base_url': 'http://other.example.net'},
55
                    ]}, fd)
56
        t = Tenant(domain_url=name,
57
                      schema_name=name.replace('-', '_').replace('.', '_'))
57
                ]
58
            }, fd)
59
        schema_name = name.replace('-', '_').replace('.', '_')
60
        t = Tenant(domain_url=name, schema_name=schema_name)
58 61
        t.create_schema()
62
        tenants.append(t)
59 63
        return t
60
    tenants = [make_tenant('authentic.example.net')]
61

  
62
    def fin():
63
        from django.db import connection
64
    try:
65
        yield factory
66
    finally:
67
        # cleanup all created tenants
64 68
        connection.set_schema_to_public()
65
        for t in tenants:
66
            t.delete(True)
67
    request.addfinalizer(fin)
68
    return tenants[0]
69
        with tenant_context(FakeTenant('public')):
70
            for tenant in tenants:
71
                tenant.delete(force_drop=True)
72

  
73

  
74
@pytest.fixture
75
def tenant(tenant_factory):
76
    return tenant_factory('authentic.example.net')
77

  
78

  
79
@pytest.fixture
80
def another_tenant(tenant_factory):
81
    return tenant_factory('another-authentic.example.net')
82

  
83

  
84
# remove this class when this issue is fixed
85
# https://github.com/django-webtest/django-webtest/issues/102
86
class FixEnvironDjangoTestApp(DjangoTestApp):
87
    def _update_environ(self, environ, user=None):
88
        fixup = ((not environ or 'HTTP_HOST' not in environ)
89
                 and (self.extra_environ and 'HTTP_HOST' in self.extra_environ))
90
        environ = super(FixEnvironDjangoTestApp, self)._update_environ(environ, user=user)
91
        if fixup:
92
            environ.pop('HTTP_HOST', None)
93
        return environ
94

  
95

  
96
@pytest.fixture
97
def app_factory(request):
98
    wtm = WebTestMixin()
99
    wtm._patch_settings()
100

  
101
    def factory(hostname='testserver'):
102
        if hasattr(hostname, 'domain_url'):
103
            hostname = hostname.domain_url
104
        return FixEnvironDjangoTestApp(extra_environ={'HTTP_HOST': hostname})
105

  
106
    try:
107
        yield factory
108
    finally:
109
        wtm._unpatch_settings()
110

  
111

  
112
@pytest.fixture
113
def notify_agents(mocker):
114
    yield mocker.patch('hobo.agent.authentic2.provisionning.notify_agents')
tests_authentic/settings.py
39 39
}
40 40

  
41 41
HOBO_ROLE_EXPORT = True
42

  
43
ALLOWED_HOSTS = ['*']
44
SESSION_COOKIE_SECURE = False
45
CSRF_COOKIE_SECURE = False
tests_authentic/test_provisionning.py
17 17
from authentic2.models import Attribute, AttributeValue
18 18
from hobo.agent.authentic2.provisionning import provisionning
19 19

  
20
User = get_user_model()
21

  
20 22
pytestmark = pytest.mark.django_db
21 23

  
22 24

  
......
126 128
            role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]')
127 129
            child_role = Role.objects.create(name='child', ou=get_default_ou())
128 130
            notify_agents.reset_mock()
129
            User = get_user_model()
130 131
            attribute = Attribute.objects.create(label='Code postal', name='code_postal',
131 132
                                                 kind='string')
132 133
            with provisionning:
......
277 278
                data = objects['data']
278 279
                assert isinstance(data, list)
279 280
                assert len(data) == 1
280
                print data
281 281
                for o in data:
282 282
                    assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
283 283
                                                 'is_superuser', 'last_name', 'email', 'roles'])
......
447 447
    with tenant_context(tenant):
448 448
        # create a provider so notification messages have an audience.
449 449
        LibertyProvider.objects.create(ou=None, name='provider',
450
                                                 entity_id='http://provider.com',
451
                                                 protocol_conformance=lasso.PROTOCOL_SAML_2_0)
450
                                       entity_id='http://provider.com',
451
                                       protocol_conformance=lasso.PROTOCOL_SAML_2_0)
452 452
    with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
453 453
        call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
454 454
                     username='coin', email='coin@coin.org', interactive=False)
455 455
        assert notify_agents.call_count == 1
456 456

  
457 457

  
458
@patch('hobo.agent.authentic2.provisionning.notify_agents')
459 458
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
460
    User = get_user_model()
461 459
    with tenant_context(tenant):
462 460
        ou = get_default_ou()
463 461
        LibertyProvider.objects.create(ou=ou, name='provider',
......
486 484
    assert msg_2['full'] is False
487 485
    assert msg_2['objects']['@type'] == 'user'
488 486
    assert len(msg_2['objects']['data']) == 10
487

  
488

  
489
def test_middleware(notify_agents, app_factory, tenant, settings):
490
    settings.HOBO_PROVISIONNING_SYNCHRONOUS = True
491

  
492
    with tenant_context(tenant):
493
        user = User.objects.create(username='john', ou=get_default_ou())
494
        user.set_password('password')
495
        user.save()
496
        LibertyProvider.objects.create(ou=get_default_ou(),
497
                                       name='provider',
498
                                       entity_id='http://provider.com',
499
                                       protocol_conformance=lasso.PROTOCOL_SAML_2_0)
500
    assert notify_agents.call_count == 0
501

  
502
    app = app_factory(tenant)
503
    resp = app.get('/login/')
504
    form = resp.form
505
    form.set('username', 'john')
506
    form.set('password', 'password')
507
    resp = form.submit(name='login-password-submit').follow()
508
    resp = resp.click('Your account')
509
    resp = resp.click('Edit')
510
    resp.form.set('edit-profile-first_name', 'John')
511
    resp.form.set('edit-profile-last_name', 'Doe')
512
    assert notify_agents.call_count == 0
513
    resp = resp.form.submit().follow()
514
    assert notify_agents.call_count == 1
489
-