From 3a4ebae3c47ed623b78efb03d42325668b479626 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Tue, 2 Jul 2019 11:38:34 +0200 Subject: [PATCH] agent/a2: prevent useless thread launching (#34484) --- hobo/agent/authentic2/middleware.py | 5 +- hobo/agent/authentic2/provisionning.py | 170 ++++++++++++++----------- 2 files changed, 100 insertions(+), 75 deletions(-) diff --git a/hobo/agent/authentic2/middleware.py b/hobo/agent/authentic2/middleware.py index c8701f3..f4af687 100644 --- a/hobo/agent/authentic2/middleware.py +++ b/hobo/agent/authentic2/middleware.py @@ -6,9 +6,8 @@ class ProvisionningMiddleware(object): provisionning.start() def process_exception(self, request, exception): - provisionning.clean() + provisionning.stop(provision=False) def process_response(self, request, response): - provisionning.provision() + provisionning.provision(provision=True, wait=False) return response - diff --git a/hobo/agent/authentic2/provisionning.py b/hobo/agent/authentic2/provisionning.py index f21dbda..e23f146 100644 --- a/hobo/agent/authentic2/provisionning.py +++ b/hobo/agent/authentic2/provisionning.py @@ -15,44 +15,63 @@ from authentic2.saml.models import LibertyProvider from authentic2.a2_rbac.models import RoleAttribute from authentic2.models import AttributeValue +User = get_user_model() +Role = get_role_model() +OU = get_ou_model() +RoleParenting = get_role_parenting_model() -class Provisionning(object): - local = threading.local() +logger = logging.getLogger(__name__) + + +class Provisionning(threading.local): + __slots__ = ['threads'] threads = set() def __init__(self): - self.User = get_user_model() - self.Role = get_role_model() - self.OU = get_ou_model() - self.RoleParenting = get_role_parenting_model() - self.logger = logging.getLogger(__name__) + self.stack = [] def start(self): - self.local.saved = {} - self.local.deleted = {} + self.stack.append({ + 'saved': {}, + 'deleted': {}, + }) + + def stop(self, provision=True, wait=True): + context = self.stack.pop() - def clean(self): - if hasattr(self.local, 'saved'): - del self.local.saved - if hasattr(self.local, 'deleted'): - del self.local.deleted + if provision: + self.provision(**context) + if wait: + self.wait() - def saved(self, *args): - if not hasattr(self.local, 'saved'): + @property + def saved(self): + if self.stack: + return self.stack[-1]['saved'] + return None + + @property + def deleted(self): + if self.stack: + return self.stack[-1]['deleted'] + return None + + def add_saved(self, *args): + if not self.stack: return for instance in args: - klass = self.User if isinstance(instance, self.User) else self.Role - self.local.saved.setdefault(klass, set()).add(instance) + klass = User if isinstance(instance, User) else Role + self.saved.setdefault(klass, set()).add(instance) - def deleted(self, *args): - if not hasattr(self.local, 'saved'): + def add_deleted(self, *args): + if not self.stack: return for instance in args: - klass = self.User if isinstance(instance, self.User) else self.Role - self.local.deleted.setdefault(klass, set()).add(instance) - self.local.saved.get(klass, set()).discard(instance) + klass = User if isinstance(instance, User) else Role + self.deleted.setdefault(klass, set()).add(instance) + self.saved.get(klass, set()).discard(instance) def resolve_ou(self, instances, ous): for instance in instances: @@ -61,7 +80,7 @@ class Provisionning(object): def notify_users(self, ous, users, mode='provision'): if mode == 'provision': - users = (self.User.objects.filter(id__in=[u.id for u in users]) + users = (User.objects.filter(id__in=[u.id for u in users]) .select_related('ou').prefetch_related('attribute_values__attribute')) else: self.resolve_ou(users, ous) @@ -105,19 +124,19 @@ class Provisionning(object): # Find roles giving a superuser attribute # If there is any role of this kind, we do one provisionning message for each user and # each service. - roles_with_attributes = (self.Role.objects.filter(members__in=users) + roles_with_attributes = (Role.objects.filter(members__in=users) .parents(include_self=True) .filter(attributes__name='is_superuser') .exists()) - all_roles = (self.Role.objects.filter(members__in=users).parents() + all_roles = (Role.objects.filter(members__in=users).parents() .prefetch_related('attributes').distinct()) roles = dict((r.id, r) for r in all_roles) user_roles = {} parents = {} - for rp in self.RoleParenting.objects.filter(child__in=all_roles): + for rp in RoleParenting.objects.filter(child__in=all_roles): parents.setdefault(rp.child.id, []).append(rp.parent.id) - Through = self.Role.members.through + Through = Role.members.through for u_id, r_id in Through.objects.filter(role__members__in=users).values_list('user_id', 'role_id'): user_roles.setdefault(u_id, set()).add(roles[r_id]) @@ -133,7 +152,7 @@ class Provisionning(object): for ou, users in ous.iteritems(): for service, audience in self.get_audience(ou): for user in users: - self.logger.info(u'provisionning user %s to %s', user, audience) + logger.info(u'provisionning user %s to %s', user, audience) notify_agents({ '@type': 'provision', 'issuer': issuer, @@ -149,7 +168,7 @@ class Provisionning(object): audience = [a for service, a in self.get_audience(ou)] if not audience: continue - self.logger.info(u'provisionning users %s to %s', + logger.info(u'provisionning users %s to %s', u', '.join(map(unicode, users)), u', '.join(audience)) notify_agents({ '@type': 'provision', @@ -162,9 +181,9 @@ class Provisionning(object): } }) elif users: - audience = [audience for ou in self.OU.objects.all() + audience = [audience for ou in OU.objects.all() for s, audience in self.get_audience(ou)] - self.logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)), + logger.info(u'deprovisionning users %s from %s', u', '.join(map(unicode, users)), u', '.join(audience)) notify_agents({ '@type': 'deprovision', @@ -213,7 +232,7 @@ class Provisionning(object): ] audience = [entity_id for service, entity_id in self.get_audience(ou)] - self.logger.info(u'%sning roles %s to %s', mode, roles, audience) + logger.info(u'%sning roles %s to %s', mode, roles, audience) notify_agents({ '@type': mode, 'audience': audience, @@ -229,33 +248,35 @@ class Provisionning(object): sent_roles = set(ou_roles) | global_roles helper(ou, sent_roles) - def provision(self): + def provision(self, saved, deleted): + # Returns if: + # - we are not in a tenant + # - provsionning is disabled + # - there is nothing to do if (not hasattr(connection, 'tenant') or not connection.tenant or not hasattr(connection.tenant, 'domain_url')): return if not getattr(settings, 'HOBO_ROLE_EXPORT', True): return - # exit early if not started - if not hasattr(self.local, 'saved') or not hasattr(self.local, 'deleted'): + if not (saved or deleted): return - t = threading.Thread(target=self.do_provision, kwargs={ - 'saved': getattr(self.local, 'saved', {}), - 'deleted': getattr(self.local, 'deleted', {}), - }) + t = threading.Thread( + target=self.do_provision, + kwargs={'saved': saved, 'deleted': deleted}) t.start() self.threads.add(t) def do_provision(self, saved, deleted, thread=None): try: - ous = {ou.id: ou for ou in self.OU.objects.all()} - self.notify_roles(ous, saved.get(self.Role, [])) - self.notify_roles(ous, deleted.get(self.Role, []), mode='deprovision') - self.notify_users(ous, saved.get(self.User, [])) - self.notify_users(ous, deleted.get(self.User, []), mode='deprovision') + ous = {ou.id: ou for ou in OU.objects.all()} + self.notify_roles(ous, saved.get(Role, [])) + self.notify_roles(ous, deleted.get(Role, []), mode='deprovision') + self.notify_users(ous, saved.get(User, [])) + self.notify_users(ous, deleted.get(User, []), mode='deprovision') except Exception: # last step, clear everything - self.logger.exception(u'error in provisionning thread') + logger.exception(u'error in provisionning thread') finally: self.threads.discard(threading.current_thread()) @@ -267,12 +288,9 @@ class Provisionning(object): self.start() def __exit__(self, exc_type, exc_value, exc_tb): - if exc_type is None: - self.provision() - self.clean() - self.wait() - else: - self.clean() + if not self.stack: + return + self.stop(provision=exc_type is None) def get_audience(self, ou): if ou: @@ -298,64 +316,72 @@ class Provisionning(object): return urljoin(base_url, reverse('a2-idp-saml-metadata')) def pre_save(self, sender, instance, raw, using, update_fields, **kwargs): + if not self.stack: + return # we skip new instances if not instance.pk: return - if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)): + if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)): return # ignore last_login update on login - if isinstance(instance, self.User) and update_fields == ['last_login']: + if isinstance(instance, User) and update_fields == ['last_login']: return if isinstance(instance, RoleAttribute): instance = instance.role elif isinstance(instance, AttributeValue): - if not isinstance(instance.owner, self.User): + if not isinstance(instance.owner, User): return instance = instance.owner - self.saved(instance) + self.add_saved(instance) def post_save(self, sender, instance, created, raw, using, update_fields, **kwargs): + if not self.stack: + return # during post_save we only handle new instances - if isinstance(instance, self.RoleParenting): - self.saved(*list(instance.child.all_members())) + if isinstance(instance, RoleParenting): + self.add_saved(*list(instance.child.all_members())) return if not created: return - if not isinstance(instance, (self.User, self.Role, RoleAttribute, AttributeValue)): + if not isinstance(instance, (User, Role, RoleAttribute, AttributeValue)): return if isinstance(instance, RoleAttribute): instance = instance.role elif isinstance(instance, AttributeValue): - if not isinstance(instance.owner, self.User): + if not isinstance(instance.owner, User): return instance = instance.owner - self.saved(instance) + self.add_saved(instance) def pre_delete(self, sender, instance, using, **kwargs): - if isinstance(instance, (self.User, self.Role)): - self.deleted(copy.copy(instance)) + if not self.stack: + return + if isinstance(instance, (User, Role)): + self.add_deleted(copy.copy(instance)) elif isinstance(instance, RoleAttribute): instance = instance.role - self.saved(instance) + self.add_saved(instance) elif isinstance(instance, AttributeValue): - if not isinstance(instance.owner, self.User): + if not isinstance(instance.owner, User): return instance = instance.owner - self.saved(instance) - elif isinstance(instance, self.RoleParenting): - self.saved(*list(instance.child.all_members())) + self.add_saved(instance) + elif isinstance(instance, RoleParenting): + self.add_saved(*list(instance.child.all_members())) def m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs): + if not self.stack: + return if action != 'pre_clear' and action.startswith('pre_'): return - if sender is self.Role.members.through: - self.saved(instance) + if sender is Role.members.through: + self.add_saved(instance) # on a clear, pk_set is None for other_instance in model.objects.filter(pk__in=pk_set or []): - self.saved(other_instance) + self.add_saved(other_instance) if action == 'pre_clear': # when the action is pre_clear we need to lookup the current value of the members # relation, to re-provision all previously enroled users. if not reverse: for other_instance in instance.members.all(): - self.saved(other_instance) + self.add_saved(other_instance) -- 2.20.1