Projet

Général

Profil

0001-agent-a2-prevent-useless-thread-launching-34484.patch

Benjamin Dauvergne, 04 juillet 2019 23:05

Télécharger (24,2 ko)

Voir les différences:

Subject: [PATCH] agent/a2: prevent useless thread launching (#34484)

 hobo/agent/authentic2/middleware.py    |   8 +-
 hobo/agent/authentic2/provisionning.py | 178 +++++++++++++++----------
 tests_authentic/conftest.py            | 129 ++++++++++++------
 tests_authentic/settings.py            |   4 +
 tests_authentic/test_provisionning.py  |  38 +++++-
 5 files changed, 234 insertions(+), 123 deletions(-)
hobo/agent/authentic2/middleware.py
1
from django.conf import settings
2

  
1 3
from .provisionning import provisionning
2 4

  
3 5

  
......
6 8
        provisionning.start()
7 9

  
8 10
    def process_exception(self, request, exception):
9
        provisionning.clean()
11
        provisionning.clear()
10 12

  
11 13
    def process_response(self, request, response):
12
        provisionning.provision()
14
        provisionning.stop(provision=True, wait=getattr(settings, 'HOBO_PROVISIONNING_SYNCHRONOUS', False))
15
        provisionning.clear()
13 16
        return response
14

  
hobo/agent/authentic2/provisionning.py
15 15
from authentic2.a2_rbac.models import RoleAttribute
16 16
from authentic2.models import AttributeValue
17 17

  
18
User = get_user_model()
19
Role = get_role_model()
20
OU = get_ou_model()
21
RoleParenting = get_role_parenting_model()
18 22

  
19
class Provisionning(object):
20
    local = threading.local()
23
logger = logging.getLogger(__name__)
24

  
25

  
26
class Provisionning(threading.local):
27
    __slots__ = ['threads']
21 28
    threads = set()
22 29

  
23 30
    def __init__(self):
24
        self.User = get_user_model()
25
        self.Role = get_role_model()
26
        self.OU = get_ou_model()
27
        self.RoleParenting = get_role_parenting_model()
28
        self.logger = logging.getLogger(__name__)
31
        self.stack = []
29 32

  
30 33
    def start(self):
31
        self.local.saved = {}
32
        self.local.deleted = {}
34
        self.stack.append({
35
            'saved': {},
36
            'deleted': {},
37
        })
38

  
39
    def clear(self):
40
        self.stack = []
41

  
42
    def stop(self, provision=True, wait=True):
43
        if not self.stack:
44
            return
45

  
46
        context = self.stack.pop()
47

  
48
        if provision:
49
            self.provision(**context)
50
            if wait:
51
                self.wait()
33 52

  
34
    def clean(self):
35
        if hasattr(self.local, 'saved'):
36
            del self.local.saved
37
        if hasattr(self.local, 'deleted'):
38
            del self.local.deleted
53
    @property
54
    def saved(self):
55
        if self.stack:
56
            return self.stack[-1]['saved']
57
        return None
39 58

  
40
    def saved(self, *args):
41
        if not hasattr(self.local, 'saved'):
59
    @property
60
    def deleted(self):
61
        if self.stack:
62
            return self.stack[-1]['deleted']
63
        return None
64

  
65
    def add_saved(self, *args):
66
        if not self.stack:
42 67
            return
43 68

  
44 69
        for instance in args:
45
            klass = self.User if isinstance(instance, self.User) else self.Role
46
            self.local.saved.setdefault(klass, set()).add(instance)
70
            klass = User if isinstance(instance, User) else Role
71
            self.saved.setdefault(klass, set()).add(instance)
47 72

  
48
    def deleted(self, *args):
49
        if not hasattr(self.local, 'saved'):
73
    def add_deleted(self, *args):
74
        if not self.stack:
50 75
            return
51 76

  
52 77
        for instance in args:
53
            klass = self.User if isinstance(instance, self.User) else self.Role
54
            self.local.deleted.setdefault(klass, set()).add(instance)
55
            self.local.saved.get(klass, set()).discard(instance)
78
            klass = User if isinstance(instance, User) else Role
79
            self.deleted.setdefault(klass, set()).add(instance)
80
            self.saved.get(klass, set()).discard(instance)
56 81

  
57 82
    def resolve_ou(self, instances, ous):
58 83
        for instance in instances:
......
61 86

  
62 87
    def notify_users(self, ous, users, mode='provision'):
63 88
        if mode == 'provision':
64
            users = (self.User.objects.filter(id__in=[u.id for u in users])
89
            users = (User.objects.filter(id__in=[u.id for u in users])
65 90
                     .select_related('ou').prefetch_related('attribute_values__attribute'))
66 91
        else:
67 92
            self.resolve_ou(users, ous)
......
105 130
            # Find roles giving a superuser attribute
106 131
            # If there is any role of this kind, we do one provisionning message for each user and
107 132
            # each service.
108
            roles_with_attributes = (self.Role.objects.filter(members__in=users)
133
            roles_with_attributes = (Role.objects.filter(members__in=users)
109 134
                                     .parents(include_self=True)
110 135
                                     .filter(attributes__name='is_superuser')
111 136
                                     .exists())
112 137

  
113
            all_roles = (self.Role.objects.filter(members__in=users).parents()
138
            all_roles = (Role.objects.filter(members__in=users).parents()
114 139
                         .prefetch_related('attributes').distinct())
115 140
            roles = dict((r.id, r) for r in all_roles)
116 141
            user_roles = {}
117 142
            parents = {}
118
            for rp in self.RoleParenting.objects.filter(child__in=all_roles):
143
            for rp in RoleParenting.objects.filter(child__in=all_roles):
119 144
                parents.setdefault(rp.child.id, []).append(rp.parent.id)
120
            Through = self.Role.members.through
145
            Through = Role.members.through
121 146
            for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id',
122 147
                                                                                      'role_id'):
123 148
                user_roles.setdefault(u_id, set()).add(roles[r_id])
......
133 158
                for ou, users in ous.iteritems():
134 159
                    for service, audience in self.get_audience(ou):
135 160
                        for user in users:
136
                            self.logger.info(u'provisionning user %s to %s', user, audience)
161
                            logger.info(u'provisionning user %s to %s', user, audience)
137 162
                            notify_agents({
138 163
                                '@type': 'provision',
139 164
                                'issuer': issuer,
......
149 174
                    audience = [a for service, a in self.get_audience(ou)]
150 175
                    if not audience:
151 176
                        continue
152
                    self.logger.info(u'provisionning users %s to %s',
177
                    logger.info(u'provisionning users %s to %s',
153 178
                                     u', '.join(map(unicode, users)), u', '.join(audience))
154 179
                    notify_agents({
155 180
                        '@type': 'provision',
......
162 187
                        }
163 188
                    })
164 189
        elif users:
165
            audience = [audience for ou in self.OU.objects.all()
190
            audience = [audience for ou in OU.objects.all()
166 191
                        for s, audience in self.get_audience(ou)]
167
            self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
192
            logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)),
168 193
                             u', '.join(audience))
169 194
            notify_agents({
170 195
                '@type': 'deprovision',
......
213 238
                ]
214 239

  
215 240
            audience = [entity_id for service, entity_id in self.get_audience(ou)]
216
            self.logger.info(u'%sning roles %s to %s', mode, roles, audience)
241
            logger.info(u'%sning roles %s to %s', mode, roles, audience)
217 242
            notify_agents({
218 243
                '@type': mode,
219 244
                'audience': audience,
......
229 254
            sent_roles = set(ou_roles) | global_roles
230 255
            helper(ou, sent_roles)
231 256

  
232
    def provision(self):
257
    def provision(self, saved, deleted):
258
        # Returns if:
259
        # - we are not in a tenant
260
        # - provisionning is disabled
261
        # - there is nothing to do
233 262
        if (not hasattr(connection, 'tenant') or not connection.tenant or not
234 263
                hasattr(connection.tenant, 'domain_url')):
235 264
            return
236 265
        if not getattr(settings, 'HOBO_ROLE_EXPORT', True):
237 266
            return
238
        # exit early if not started
239
        if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'):
267
        if not (saved or deleted):
240 268
            return
241 269

  
242
        t = threading.Thread(target=self.do_provision, kwargs={
243
            'saved': getattr(self.local, 'saved', {}),
244
            'deleted': getattr(self.local, 'deleted', {}),
245
        })
270
        t = threading.Thread(
271
            target=self.do_provision,
272
            kwargs={'saved': saved, 'deleted': deleted})
246 273
        t.start()
247 274
        self.threads.add(t)
248 275

  
249
    def do_provision(self, saved, deleted, thread=None):
276
    def do_provision(self, saved, deleted):
250 277
        try:
251
            ous = {ou.id: ou for ou in self.OU.objects.all()}
252
            self.notify_roles(ous, saved.get(self.Role, []))
253
            self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision')
254
            self.notify_users(ous, saved.get(self.User, []))
255
            self.notify_users(ous, deleted.get(self.User, []), mode='deprovision')
278
            ous = {ou.id: ou for ou in OU.objects.all()}
279
            self.notify_roles(ous, saved.get(Role, []))
280
            self.notify_roles(ous, deleted.get(Role, []), mode='deprovision')
281
            self.notify_users(ous, saved.get(User, []))
282
            self.notify_users(ous, deleted.get(User, []), mode='deprovision')
256 283
        except Exception:
257 284
            # last step, clear everything
258
            self.logger.exception(u'error in provisionning thread')
285
            logger.exception(u'error in provisionning thread')
259 286
        finally:
260 287
            self.threads.discard(threading.current_thread())
261 288

  
......
267 294
        self.start()
268 295

  
269 296
    def __exit__(self, exc_type, exc_value, exc_tb):
270
        if exc_type is None:
271
            self.provision()
272
            self.clean()
273
            self.wait()
274
        else:
275
            self.clean()
297
        if not self.stack:
298
            return
299
        self.stop(provision=exc_type is None)
276 300

  
277 301
    def get_audience(self, ou):
278 302
        if ou:
......
298 322
        return urljoin(base_url, reverse('a2-idp-saml-metadata'))
299 323

  
300 324
    def pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
325
        if not self.stack:
326
            return
301 327
        # we skip new instances
302 328
        if not instance.pk:
303 329
            return
304
        if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
330
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
305 331
            return
306 332
        # ignore last_login update on login
307
        if isinstance(instance, self.User) and update_fields == ['last_login']:
333
        if isinstance(instance, User) and (update_fields and set(update_fields) == set(['last_login'])):
308 334
            return
309 335
        if isinstance(instance, RoleAttribute):
310 336
            instance = instance.role
311 337
        elif isinstance(instance, AttributeValue):
312
            if not isinstance(instance.owner, self.User):
338
            if not isinstance(instance.owner, User):
313 339
                return
314 340
            instance = instance.owner
315
        self.saved(instance)
341
        self.add_saved(instance)
316 342

  
317 343
    def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
344
        if not self.stack:
345
            return
318 346
        # during post_save we only handle new instances
319
        if isinstance(instance, self.RoleParenting):
320
            self.saved(*list(instance.child.all_members()))
347
        if isinstance(instance, RoleParenting):
348
            self.add_saved(*list(instance.child.all_members()))
321 349
            return
322 350
        if not created:
323 351
            return
324
        if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)):
352
        if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)):
325 353
            return
326 354
        if isinstance(instance, RoleAttribute):
327 355
            instance = instance.role
328 356
        elif isinstance(instance, AttributeValue):
329
            if not isinstance(instance.owner, self.User):
357
            if not isinstance(instance.owner, User):
330 358
                return
331 359
            instance = instance.owner
332
        self.saved(instance)
360
        self.add_saved(instance)
333 361

  
334 362
    def pre_delete(self, sender, instance, using, **kwargs):
335
        if isinstance(instance, (self.User, self.Role)):
336
            self.deleted(copy.copy(instance))
363
        if not self.stack:
364
            return
365
        if isinstance(instance, (User, Role)):
366
            self.add_deleted(copy.copy(instance))
337 367
        elif isinstance(instance, RoleAttribute):
338 368
            instance = instance.role
339
            self.saved(instance)
369
            self.add_saved(instance)
340 370
        elif isinstance(instance, AttributeValue):
341
            if not isinstance(instance.owner, self.User):
371
            if not isinstance(instance.owner, User):
342 372
                return
343 373
            instance = instance.owner
344
            self.saved(instance)
345
        elif isinstance(instance, self.RoleParenting):
346
            self.saved(*list(instance.child.all_members()))
374
            self.add_saved(instance)
375
        elif isinstance(instance, RoleParenting):
376
            self.add_saved(*list(instance.child.all_members()))
347 377

  
348 378
    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
379
        if not self.stack:
380
            return
349 381
        if action != 'pre_clear' and action.startswith('pre_'):
350 382
            return
351
        if sender is self.Role.members.through:
352
            self.saved(instance)
383
        if sender is Role.members.through:
384
            self.add_saved(instance)
353 385
            # on a clear, pk_set is None
354 386
            for other_instance in model.objects.filter(pk__in=pk_set or []):
355
                self.saved(other_instance)
387
                self.add_saved(other_instance)
356 388
            if action == 'pre_clear':
357 389
                # when the action is pre_clear we need to look up the current value of the members
358 390
                # relation, to re-provision all previously enrolled users.
359 391
                if not reverse:
360 392
                    for other_instance in instance.members.all():
361
                        self.saved(other_instance)
393
                        self.add_saved(other_instance)
tests_authentic/conftest.py
1 1
import os
2
import tempfile
3
import shutil
4 2
import json
5 3

  
4
from django_webtest import WebTestMixin, DjangoTestApp
6 5
import pytest
7 6

  
7
from django.db import connection, transaction
8
from tenant_schemas.postgresql_backend.base import FakeTenant
9
from tenant_schemas.utils import tenant_context
10

  
11
from hobo.multitenant.models import Tenant
12

  
8 13

  
9 14
@pytest.fixture
10
def tenant_base(request, settings):
11
    base = tempfile.mkdtemp('authentic-tenant-base')
15
def tenant_base(tmpdir, settings):
16
    base = str(tmpdir.mkdir('authentic-tenant-base'))
12 17
    settings.TENANT_BASE = base
13

  
14
    def fin():
15
        shutil.rmtree(base)
16
    request.addfinalizer(fin)
17 18
    return base
18 19

  
19 20

  
20
@pytest.fixture(scope='function')
21
def tenant(transactional_db, request, settings, tenant_base):
22
    from hobo.multitenant.models import Tenant
23
    base = tenant_base
21
@pytest.fixture
22
def tenant_factory(transactional_db, tenant_base):
23
    tenants = []
24 24

  
25
    @pytest.mark.django_db
26
    def make_tenant(name):
27
        tenant_dir = os.path.join(base, name)
25
    def factory(name):
26
        tenant_dir = os.path.join(tenant_base, name)
28 27
        os.mkdir(tenant_dir)
29 28
        with open(os.path.join(tenant_dir, 'unsecure'), 'w') as fd:
30 29
            fd.write('1')
......
37 36
                    'other_variable': 'foo',
38 37
                },
39 38
                'services': [
40
                    {'slug': 'test',
41
                     'service-id': 'authentic',
42
                     'title': 'Test',
43
                     'this': True,
44
                     'secret_key': '12345',
45
                     'base_url': 'http://%s' % name,
46
                     'variables': {
47
                         'other_variable': 'bar',
48
                     }
39
                    {
40
                        'slug': 'test',
41
                        'service-id': 'authentic',
42
                        'title': 'Test',
43
                        'this': True,
44
                        'secret_key': '12345',
45
                        'base_url': 'http://%s' % name,
46
                        'variables': {
47
                            'other_variable': 'bar',
48
                        }
49
                    },
50
                    {
51
                        'slug': 'other',
52
                        'title': 'Other',
53
                        'service-id': 'welco',
54
                        'secret_key': 'abcdef',
55
                        'base_url': 'http://other.example.net'
49 56
                    },
50
                    {'slug': 'other',
51
                     'title': 'Other',
52
                     'service-id': 'welco',
53
                     'secret_key': 'abcdef',
54
                     'base_url': 'http://other.example.net'},
55
                    ]}, fd)
56
        t = Tenant(domain_url=name,
57
                      schema_name=name.replace('-', '_').replace('.', '_'))
58
        t.create_schema()
57
                ]
58
            }, fd)
59
        schema_name = name.replace('-', '_').replace('.', '_')
60
        t = Tenant(domain_url=name, schema_name=schema_name)
61
        with transaction.atomic():
62
            t.create_schema()
63
        tenants.append(t)
59 64
        return t
60
    tenants = [make_tenant('authentic.example.net')]
61

  
62
    def fin():
63
        from django.db import connection
65
    try:
66
        yield factory
67
    finally:
68
        # cleanup all created tenants
64 69
        connection.set_schema_to_public()
65
        for t in tenants:
66
            t.delete(True)
67
    request.addfinalizer(fin)
68
    return tenants[0]
70
        with tenant_context(FakeTenant('public')):
71
            for tenant in tenants:
72
                tenant.delete(force_drop=True)
73

  
74

  
75
@pytest.fixture
76
def tenant(tenant_factory):
77
    return tenant_factory('authentic.example.net')
78

  
79

  
80
@pytest.fixture
81
def another_tenant(tenant_factory):
82
    return tenant_factory('another-authentic.example.net')
83

  
84

  
85
# remove this class when this issue is fixed
86
# https://github.com/django-webtest/django-webtest/issues/102
87
class FixEnvironDjangoTestApp(DjangoTestApp):
88
    def _update_environ(self, environ, user=None):
89
        fixup = ((not environ or 'HTTP_HOST' not in environ)
90
                 and (self.extra_environ and 'HTTP_HOST' in self.extra_environ))
91
        environ = super(FixEnvironDjangoTestApp, self)._update_environ(environ, user=user)
92
        if fixup:
93
            environ.pop('HTTP_HOST', None)
94
        return environ
95

  
96

  
97
@pytest.fixture
98
def app_factory(request):
99
    wtm = WebTestMixin()
100
    wtm._patch_settings()
101

  
102
    def factory(hostname='testserver'):
103
        if hasattr(hostname, 'domain_url'):
104
            hostname = hostname.domain_url
105
        return FixEnvironDjangoTestApp(extra_environ={'HTTP_HOST': hostname})
106

  
107
    try:
108
        yield factory
109
    finally:
110
        wtm._unpatch_settings()
111

  
112

  
113
@pytest.fixture
114
def notify_agents(mocker):
115
    yield mocker.patch('hobo.agent.authentic2.provisionning.notify_agents')
tests_authentic/settings.py
39 39
}
40 40

  
41 41
HOBO_ROLE_EXPORT = True
42

  
43
ALLOWED_HOSTS = ['*']
44
SESSION_COOKIE_SECURE = False
45
CSRF_COOKIE_SECURE = False
tests_authentic/test_provisionning.py
17 17
from authentic2.models import Attribute, AttributeValue
18 18
from hobo.agent.authentic2.provisionning import provisionning
19 19

  
20
User = get_user_model()
21

  
20 22
pytestmark = pytest.mark.django_db
21 23

  
22 24

  
......
126 128
            role2.attributes.create(kind='json', name='emails', value='["zob@example.net"]')
127 129
            child_role = Role.objects.create(name='child', ou=get_default_ou())
128 130
            notify_agents.reset_mock()
129
            User = get_user_model()
130 131
            attribute = Attribute.objects.create(label='Code postal', name='code_postal',
131 132
                                                 kind='string')
132 133
            with provisionning:
......
277 278
                data = objects['data']
278 279
                assert isinstance(data, list)
279 280
                assert len(data) == 1
280
                print data
281 281
                for o in data:
282 282
                    assert set(o.keys()) >= set(['uuid', 'username', 'first_name',
283 283
                                                 'is_superuser', 'last_name', 'email', 'roles'])
......
447 447
    with tenant_context(tenant):
448 448
        # create a provider so notification messages have an audience.
449 449
        LibertyProvider.objects.create(ou=None, name='provider',
450
                                                 entity_id='http://provider.com',
451
                                                 protocol_conformance=lasso.PROTOCOL_SAML_2_0)
450
                                       entity_id='http://provider.com',
451
                                       protocol_conformance=lasso.PROTOCOL_SAML_2_0)
452 452
    with patch('hobo.agent.authentic2.provisionning.notify_agents') as notify_agents:
453 453
        call_command('createsuperuser', domain=tenant.domain_url, uuid='coin',
454 454
                     username='coin', email='coin@coin.org', interactive=False)
455 455
        assert notify_agents.call_count == 1
456 456

  
457 457

  
458
@patch('hobo.agent.authentic2.provisionning.notify_agents')
459 458
def test_command_hobo_provision(notify_agents, transactional_db, tenant, caplog):
460
    User = get_user_model()
461 459
    with tenant_context(tenant):
462 460
        ou = get_default_ou()
463 461
        LibertyProvider.objects.create(ou=ou, name='provider',
......
486 484
    assert msg_2['full'] is False
487 485
    assert msg_2['objects']['@type'] == 'user'
488 486
    assert len(msg_2['objects']['data']) == 10
487

  
488

  
489
def test_middleware(notify_agents, app_factory, tenant, settings):
490
    settings.HOBO_PROVISIONNING_SYNCHRONOUS = True
491

  
492
    with tenant_context(tenant):
493
        user = User.objects.create(username='john', ou=get_default_ou())
494
        user.set_password('password')
495
        user.save()
496
        LibertyProvider.objects.create(ou=get_default_ou(),
497
                                       name='provider',
498
                                       entity_id='http://provider.com',
499
                                       protocol_conformance=lasso.PROTOCOL_SAML_2_0)
500
    assert notify_agents.call_count == 0
501

  
502
    app = app_factory(tenant)
503
    resp = app.get('/login/')
504
    form = resp.form
505
    form.set('username', 'john')
506
    form.set('password', 'password')
507
    resp = form.submit(name='login-password-submit').follow()
508
    resp = resp.click('Your account')
509
    resp = resp.click('Edit')
510
    resp.form.set('edit-profile-first_name', 'John')
511
    resp.form.set('edit-profile-last_name', 'Doe')
512
    assert notify_agents.call_count == 0
513
    resp = resp.form.submit().follow()
514
    assert notify_agents.call_count == 1
489
-