From 5115297d93f6cc1f8d7fb06e33c71ab02674d7f8 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Mon, 31 Jan 2022 22:26:19 +0100 Subject: [PATCH 5/6] django_rbac: new update_transitive_closure algorithm (#57500) --- src/django_rbac/managers.py | 78 ++++++++++++++++++++----------------- src/django_rbac/models.py | 4 +- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/src/django_rbac/managers.py b/src/django_rbac/managers.py index 0faa1444..6dfe88fc 100644 --- a/src/django_rbac/managers.py +++ b/src/django_rbac/managers.py @@ -1,12 +1,12 @@ import contextlib -import functools import threading from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType -from django.db import models +from django.db import connection, models from django.db.models import query from django.db.models.query import Prefetch, Q +from django.db.transaction import atomic from . import utils @@ -177,40 +177,46 @@ class RoleParentingManager(models.Manager): self.tls.CLOSURE_UPDATED = True return - # existing indirect paths - old = set(self.filter(direct=False).values_list('parent_id', 'child_id')) - # existing direct paths - ris = set(self.filter(direct=True).values_list('parent_id', 'child_id')) - add = set() - new = set() - old_new = ris - - # Start computing new indirect paths - while True: - for (i, j) in ris: - for (k, l) in old_new: - if j == k and (i, l) not in ris: - new.add((i, l)) - if old_new != ris: - for (i, j) in old_new: - for (k, l) in ris: - if j == k and (i, l) not in ris: - new.add((i, l)) - if not new: - break - add.update(new) - ris.update(new) - old_new = new - new = set() - # Create new relations - self.model.objects.bulk_create( - self.model(parent_id=a, child_id=b, direct=False) for a, b in add - old - ) - # Delete old ones - obsolete = old - add - if obsolete: - queries = (query.Q(parent_id=a, child_id=b, direct=False) for a, b in obsolete) - self.model.objects.filter(functools.reduce(query.Q.__or__, queries)).delete() + with atomic(savepoint=False): + # existing direct paths + direct = set(self.filter(direct=True).values_list('parent_id', 'child_id')) + old_indirects = set(self.filter(direct=False).values_list('parent_id', 'child_id')) + indirects = set(direct) + + while True: + changed = False + for (i, j) in list(indirects): + for (k, l) in direct: + if j == k and i != l and (i, l) not in indirects: + indirects.add((i, l)) + changed = True + if not changed: + break + + with connection.cursor() as cur: + # Delete old ones + obsolete = old_indirects - indirects - direct + if obsolete: + obsolete_values = ', '.join('(%s, %s)' % (a, b) for a, b in obsolete) + sql = '''DELETE FROM "%s" AS relation \ +USING (VALUES %s) AS dead(parent_id, child_id) \ +WHERE relation.direct = 'false' AND relation.parent_id = dead.parent_id \ +AND relation.child_id = dead.child_id''' % ( + self.model._meta.db_table, + obsolete_values, + ) + cur.execute(sql) + # Create new indirect relations + new = indirects - old_indirects - direct + if new: + new_values = ', '.join( + ("(%s, %s, 'false')" % (parent_id, child_id) for parent_id, child_id in new) + ) + sql = '''INSERT INTO "%s" (parent_id, child_id, direct) VALUES %s''' % ( + self.model._meta.db_table, + new_values, + ) + cur.execute(sql) @contextlib.contextmanager diff --git a/src/django_rbac/models.py b/src/django_rbac/models.py index 2080a224..e3228c24 100644 --- a/src/django_rbac/models.py +++ b/src/django_rbac/models.py @@ -186,7 +186,7 @@ class RoleAbstractBase(AbstractOrganizationalUnitScopedBase, AbstractBase): def add_child(self, child): RoleParenting = utils.get_role_parenting_model() - RoleParenting.objects.get_or_create(parent=self, child=child) + RoleParenting.objects.get_or_create(parent=self, child=child, direct=True) def remove_child(self, child): RoleParenting = utils.get_role_parenting_model() @@ -194,7 +194,7 @@ class RoleAbstractBase(AbstractOrganizationalUnitScopedBase, AbstractBase): def add_parent(self, parent): RoleParenting = utils.get_role_parenting_model() - RoleParenting.objects.get_or_create(parent=parent, child=self) + RoleParenting.objects.get_or_create(parent=parent, child=self, direct=True) def remove_parent(self, parent): RoleParenting = utils.get_role_parenting_model() -- 2.34.1