Projet

Général

Profil

0003-a2_rbac-move-managers-from-django_rbac-70894.patch

Valentin Deniaud, 03 novembre 2022 14:40

Télécharger (27,8 ko)

Voir les différences:

Subject: [PATCH 3/4] a2_rbac: move managers from django_rbac (#70894)

 src/authentic2/a2_rbac/managers.py        | 304 +++++++++++++++++++++-
 src/authentic2/a2_rbac/models.py          |  11 +-
 src/authentic2/a2_rbac/signal_handlers.py |   3 +-
 src/django_rbac/managers.py               | 292 ---------------------
 4 files changed, 303 insertions(+), 307 deletions(-)
 delete mode 100644 src/django_rbac/managers.py
src/authentic2/a2_rbac/managers.py
14 14
# You should have received a copy of the GNU Affero General Public License
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17
import contextlib
18
import datetime
19
import threading
20

  
21
from django.contrib.auth import get_user_model
17 22
from django.contrib.contenttypes.models import ContentType
23
from django.db import connection, models
24
from django.db.models import query
25
from django.db.models.query import Prefetch, Q
26
from django.db.transaction import atomic
18 27

  
19
from authentic2.a2_rbac import models
20
from django_rbac.managers import AbstractBaseManager
21
from django_rbac.managers import RoleManager as BaseRoleManager
28
from django_rbac import utils
22 29
from django_rbac.utils import get_operation
23 30

  
31
from . import models as a2_models
32
from . import signals
33

  
34

  
35
class AbstractBaseManager(models.Manager):
    """Manager for models whose natural key is their ``uuid`` field."""

    def get_by_natural_key(self, uuid):
        """Return the single object whose ``uuid`` equals *uuid*.

        Lookup errors are those raised by ``Manager.get``.
        """
        criteria = {'uuid': uuid}
        return self.get(**criteria)
38

  
39

  
40
class OperationManager(models.Manager):
    """Manager for operations: slug natural-key lookup and permission checks."""

    def get_by_natural_key(self, slug):
        """Return the operation identified by *slug* (its natural key)."""
        return self.get(slug=slug)

    def has_perm(self, user, operation_slug, object_or_model, ou=None):
        """Test if a user can do the operation given by operation_slug
        on the given object_or_model, optionally scoped by an organizational
        unit given by ou.

        A permission matches when it targets the ContentType row of
        object_or_model's model or, when object_or_model is a model instance,
        that instance itself. It must be unscoped, or scoped to
        ``ou.as_scope()`` when ou is given.

        Returns True or False.
        """
        ou_query = query.Q(ou__isnull=True)
        if ou:
            ou_query |= query.Q(ou=ou.as_scope())
        ct = ContentType.objects.get_for_model(object_or_model)
        # Model-wide permission: its target is the ContentType row of the model.
        target_query = query.Q(target_ct=ContentType.objects.get_for_model(ContentType), target_id=ct.pk)
        if isinstance(object_or_model, models.Model):
            # Instance-level permission. Use the argument's pk: the builtin
            # ``object`` has no ``pk`` attribute.
            target_query |= query.Q(target_ct=ct, target_id=object_or_model.pk)
        Permission = utils.get_permission_model()
        qs = Permission.objects.for_user(user)
        qs = qs.filter(operation__slug=operation_slug)
        qs = qs.filter(ou_query & target_query)
        return qs.exists()
63

  
64

  
65
class PermissionManagerBase(models.Manager):
    """Base manager for permissions, resolving them from their natural key."""

    def get_by_natural_key(self, operation_slug, ou_nk, target_ct, target_nk):
        """Return the permission described by a natural key.

        *ou_nk* is the natural key of the scoping organizational unit, or a
        falsy value for an unscoped permission. *target_ct* is a content-type
        natural key and *target_nk* the natural key of the target object.
        Raises ``self.model.DoesNotExist`` when any referenced object is missing.
        """
        permissions = self.filter(operation__slug=operation_slug)
        if not ou_nk:
            permissions = permissions.filter(ou__isnull=True)
        else:
            ou_model = utils.get_ou_model()
            try:
                scope = ou_model.objects.get_by_natural_key(*ou_nk)
            except ou_model.DoesNotExist:
                raise self.model.DoesNotExist
            permissions = permissions.filter(ou=scope)

        try:
            content_type = ContentType.objects.get_by_natural_key(*target_ct)
        except ContentType.DoesNotExist:
            raise self.model.DoesNotExist

        target_model = content_type.model_class()
        try:
            target = target_model.objects.get_by_natural_key(*target_nk)
        except target_model.DoesNotExist:
            raise self.model.DoesNotExist

        return permissions.get(
            target_ct=ContentType.objects.get_for_model(target),
            target_id=target.pk,
        )
87

  
88

  
89
class PermissionQueryset(query.QuerySet):
    """QuerySet helpers to filter permissions by target or by holder."""

    def by_target_ct(self, target):
        """Keep permissions whose target content type is the one of *target*."""
        return self.filter(target_ct=ContentType.objects.get_for_model(target))

    def by_target(self, target):
        """Keep permissions whose target is exactly *target*."""
        same_type = self.by_target_ct(target)
        return same_type.filter(target_id=target.pk)

    def for_user(self, user):
        """Return the permissions held by *user* through its roles and the
        roles they inherit from.
        """
        role_model = utils.get_role_model()
        return self.filter(roles__in=role_model.objects.for_user(user=user))

    def cleanup(self):
        """Delete permissions that reference a target which no longer resolves.

        A permission is removed when its ``target`` is falsy although
        ``target_ct_id`` or ``target_id`` is set. Returns how many were deleted.
        """
        removed = 0
        for permission in self:
            if permission.target:
                continue
            if permission.target_ct_id or permission.target_id:
                permission.delete()
                removed += 1
        return removed
116

  
117

  
118
# Manager class built with Manager.from_queryset(). It combines
# PermissionManagerBase (natural-key lookup) with the PermissionQueryset
# methods (by_target_ct, by_target, for_user, cleanup), so those methods are
# also callable directly on the manager.
PermissionManager = PermissionManagerBase.from_queryset(PermissionQueryset)
119

  
120

  
121
class IntCast(models.Func):
    """SQL ``CAST((<expression>) AS int)``.

    Used below so that ``Max()`` can aggregate the boolean ``direct`` flag of
    role relations.
    """

    template = 'CAST((%(expressions)s) AS %(function)s)'
    function = 'int'
124

  
125

  
126
class RoleQuerySet(query.QuerySet):
    """QuerySet for roles, with helpers that walk the role-parenting relations."""

    def for_user(self, user):
        """Return the roles held by *user* together with all their parents.

        Objects exposing an ``apiclient_roles`` attribute are matched through
        the ``apiclients`` relation; anything else through ``members``.
        """
        holder_field = 'apiclients' if hasattr(user, 'apiclient_roles') else 'members'
        held = self.filter(**{holder_field: user})
        return held.parents().distinct()

    def _relatives(self, relation, other_end, include_self, annotate, direct):
        """Shared implementation of parents() and children().

        *relation* is the reverse accessor to follow ('child_relation' or
        'parent_relation'). *other_end* is the field of that relation that must
        point into this queryset ('child' or 'parent'). Only relations whose
        ``deleted`` field is NULL are followed.
        """
        assert annotate is False or direct is not True, 'annotate=True cannot be used with direct=True'
        lookups = {
            relation + '__deleted__isnull': True,
            relation + '__' + other_end + '__in': self,
        }
        if direct is not None:
            lookups[relation + '__direct'] = direct
        related = self.model.objects.filter(**lookups)
        if include_self:
            related = self | related
        related = related.distinct()
        if annotate:
            related = related.annotate(direct=models.Max(IntCast(relation + '__direct')))
        return related

    def parents(self, include_self=True, annotate=False, direct=None):
        """Return the roles that are parents of the roles in this queryset.

        include_self: also keep the roles of this queryset.
        direct: when not None, only follow relations with that ``direct`` value.
        annotate: add a ``direct`` annotation (Max of the int-cast flag over the
            followed relations); cannot be combined with ``direct=True``.
        """
        return self._relatives('child_relation', 'child', include_self, annotate, direct)

    def children(self, include_self=True, annotate=False, direct=None):
        """Return the roles that are children of the roles in this queryset.

        Arguments have the same meaning as for parents().
        """
        return self._relatives('parent_relation', 'parent', include_self, annotate, direct)

    def all_members(self):
        """Return users who are members of these roles or of one of their children.

        Each returned user gets a ``direct`` attribute holding the roles of this
        queryset found through the user's ``roles`` relation.
        """
        user_model = get_user_model()
        direct_roles = Prefetch('roles', queryset=self, to_attr='direct')
        membership = Q(roles__in=self) | Q(
            roles__parent_relation__parent__in=self,
            roles__parent_relation__deleted__isnull=True,
        )
        users = user_model.objects.filter(membership).distinct()
        return users.prefetch_related(direct_roles)

    def by_admin_scope_ct(self, admin_scope):
        """Keep roles whose admin scope content type is the one of *admin_scope*."""
        return self.filter(admin_scope_ct=ContentType.objects.get_for_model(admin_scope))

    def cleanup(self):
        """Delete roles that reference an admin scope which no longer resolves.

        Returns the number of deleted roles.
        """
        scoped = self.filter(Q(admin_scope_ct_id__isnull=False) | Q(admin_scope_id__isnull=False))
        removed = 0
        for role in scoped:
            if role.admin_scope:
                continue
            role.delete()
            removed += 1
        return removed
197

  
198

  
199
# Manager class for roles: the uuid natural-key lookup of AbstractBaseManager
# plus the RoleQuerySet methods, combined with Manager.from_queryset().
BaseRoleManager = AbstractBaseManager.from_queryset(RoleQuerySet)
200

  
201

  
202
class RoleParentingManager(models.Manager):
    """Manager for parent/child relations between roles.

    Relations are soft-deleted: rows whose ``deleted`` field is non-NULL are
    ignored by the queries below. Rows with ``direct=True`` are explicit links.
    Rows with ``direct=False`` are the indirect links derived from them by
    update_transitive_closure().
    """

    class Local(threading.local):
        # Per-thread flags. When DO_UPDATE_CLOSURE is False,
        # update_transitive_closure() only sets CLOSURE_UPDATED to record that a
        # recomputation was requested (see defer_update_transitive_closure()).
        DO_UPDATE_CLOSURE = True
        CLOSURE_UPDATED = False

    tls = Local()

    def get_by_natural_key(self, parent_nk, child_nk, direct):
        """Return the relation between two roles given by their natural keys.

        Raises self.model.DoesNotExist if either role cannot be found.
        """
        Role = utils.get_role_model()
        try:
            parent = Role.objects.get_by_natural_key(*parent_nk)
        except Role.DoesNotExist:
            raise self.model.DoesNotExist
        try:
            child = Role.objects.get_by_natural_key(*child_nk)
        except Role.DoesNotExist:
            raise self.model.DoesNotExist
        return self.get(parent=parent, child=child, direct=direct)

    def soft_create(self, parent, child):
        """Create a direct relation, or revive it if it was soft-deleted.

        Sends signals.post_soft_create when the relation was created or revived.
        """
        with atomic(savepoint=False):
            rp, created = self.get_or_create(parent=parent, child=child, direct=True)
            # truthy when the row is new or was soft-deleted (deleted is set)
            new = created or rp.deleted
            if not created and rp.deleted:
                rp.created = datetime.datetime.now()
                rp.deleted = None
                rp.save(update_fields=['created', 'deleted'])
            if new:
                signals.post_soft_create.send(sender=self.model, instance=rp)

    def soft_delete(self, parent, child):
        """Soft-delete the live direct relation between parent and child.

        Sends signals.post_soft_delete only if the UPDATE affected a row.
        """
        qs = self.filter(parent=parent, child=child, deleted__isnull=True, direct=True)
        with atomic(savepoint=False):
            rp = qs.first()
            if rp:
                count = qs.update(deleted=datetime.datetime.now())
                # read-committed: the view of the tables can change during the
                # transaction, so only signal if this UPDATE actually hit a row
                if count:
                    signals.post_soft_delete.send(sender=self.model, instance=rp)

    def update_transitive_closure(self):
        """Recompute the transitive closure of the inheritance relation
        from scratch. Add missing indirect relations and delete
        obsolete indirect relations.

        The closure is computed in Python as a fixed point: starting from the
        live direct pairs, (i, l) is added whenever (i, j) is in the set and
        (j, l) is direct, skipping self-loops (i == l). Each pass is
        O(len(indirects) * len(direct)).

        Indirect rows no longer in the closure are soft-deleted. Missing ones
        are inserted, or revived through ON CONFLICT. The SQL is
        PostgreSQL-flavoured (UPDATE ... FROM, ON CONFLICT, now()).

        If DO_UPDATE_CLOSURE is False for this thread, nothing is computed;
        CLOSURE_UPDATED is set instead.
        """
        if not self.tls.DO_UPDATE_CLOSURE:
            self.tls.CLOSURE_UPDATED = True
            return

        with atomic(savepoint=False):
            # existing direct paths
            direct = set(self.filter(direct=True, deleted__isnull=True).values_list('parent_id', 'child_id'))
            old_indirects = set(
                self.filter(direct=False, deleted__isnull=True).values_list('parent_id', 'child_id')
            )
            indirects = set(direct)

            # iterate until no new (ancestor, descendant) pair is discovered
            while True:
                changed = False
                for (i, j) in list(indirects):
                    for (k, l) in direct:
                        if j == k and i != l and (i, l) not in indirects:
                            indirects.add((i, l))
                            changed = True
                if not changed:
                    break

            # The ids interpolated into the SQL below were read from this same
            # table by the values_list() queries above.
            with connection.cursor() as cur:
                # Delete old ones
                obsolete = old_indirects - indirects - direct
                if obsolete:
                    sql = '''UPDATE "%s" AS relation \
SET deleted = now()\
FROM (VALUES %s) AS dead(parent_id, child_id) \
WHERE relation.direct = 'false' AND relation.parent_id = dead.parent_id \
AND relation.child_id = dead.child_id AND deleted IS NULL''' % (
                        self.model._meta.db_table,
                        ', '.join('(%s, %s)' % (a, b) for a, b in obsolete),
                    )
                    cur.execute(sql)
                # Create new indirect relations
                new = indirects - old_indirects - direct
                if new:
                    new_values = ', '.join(
                        (
                            "(%s, %s, 'false', now(), NULL)" % (parent_id, child_id)
                            for parent_id, child_id in new
                        )
                    )
                    # ON CONFLICT presumably relies on a unique constraint on
                    # (parent_id, child_id, direct) — TODO confirm in the model/migrations
                    sql = '''INSERT INTO "%s" (parent_id, child_id, direct, created, deleted) VALUES %s \
ON CONFLICT (parent_id, child_id, direct) DO UPDATE SET created = EXCLUDED.created, deleted = NULL''' % (
                        self.model._meta.db_table,
                        new_values,
                    )
                    cur.execute(sql)
297

  
298

  
299
@contextlib.contextmanager
def defer_update_transitive_closure():
    """Postpone transitive-closure recomputation until the ``with`` block exits.

    Inside the block, RoleParentingManager.update_transitive_closure() only
    records (per thread) that an update was requested. On normal exit, a single
    recomputation runs if one was requested. If the block raises, none runs.
    The per-thread flags are always reset.
    """
    # The module-level ``utils`` is django_rbac.utils, which provides
    # get_role_parenting_model(). A local ``from . import utils`` would now
    # resolve to authentic2.a2_rbac.utils instead.
    RoleParentingManager.tls.DO_UPDATE_CLOSURE = False
    try:
        yield
        if RoleParentingManager.tls.CLOSURE_UPDATED:
            # Re-enable updates first. Otherwise update_transitive_closure()
            # would only record the request again and return without work.
            RoleParentingManager.tls.DO_UPDATE_CLOSURE = True
            utils.get_role_parenting_model().objects.update_transitive_closure()
    finally:
        RoleParentingManager.tls.DO_UPDATE_CLOSURE = True
        RoleParentingManager.tls.CLOSURE_UPDATED = False
311

  
24 312

  
25 313
class OrganizationalUnitManager(AbstractBaseManager):
26 314
    def get_by_natural_key(self, uuid):
......
70 358
        # find an operation matching the template
71 359
        op = get_operation(operation)
72 360
        if create:
73
            perm, _ = models.Permission.objects.update_or_create(
361
            perm, _ = a2_models.Permission.objects.update_or_create(
74 362
                operation=op,
75 363
                target_ct=ContentType.objects.get_for_model(instance),
76 364
                target_id=instance.pk,
......
79 367
            )
80 368
        else:
81 369
            try:
82
                perm = models.Permission.objects.get(
370
                perm = a2_models.Permission.objects.get(
83 371
                    operation=op,
84 372
                    target_ct=ContentType.objects.get_for_model(instance),
85 373
                    target_id=instance.pk,
86 374
                    **kwargs,
87 375
                )
88
            except models.Permission.DoesNotExist:
376
            except a2_models.Permission.DoesNotExist:
89 377
                return None
90 378

  
91 379
        # in which ou do we put the role ?
......
157 445
            kwargs['ou__isnull'] = True
158 446
        else:
159 447
            try:
160
                ou = models.OrganizationalUnit.objects.get_by_natural_key(*ou_natural_key)
161
            except models.OrganizationalUnit.DoesNotExist:
448
                ou = a2_models.OrganizationalUnit.objects.get_by_natural_key(*ou_natural_key)
449
            except a2_models.OrganizationalUnit.DoesNotExist:
162 450
                raise self.model.DoesNotExist
163 451
            kwargs['ou'] = ou
164 452
        if service_natural_key is None:
src/authentic2/a2_rbac/models.py
37 37
from authentic2.decorators import errorcollector
38 38
from authentic2.utils.cache import GlobalCache
39 39
from authentic2.validators import HexaColourValidator
40
from django_rbac import managers as rbac_managers
41 40
from django_rbac import utils as rbac_utils
42 41

  
43 42
from . import app_settings, fields, managers
......
55 54
    slug = models.SlugField(max_length=256, verbose_name=_('slug'))
56 55
    description = models.TextField(verbose_name=_('description'), blank=True)
57 56

  
58
    objects = rbac_managers.AbstractBaseManager()
57
    objects = managers.AbstractBaseManager()
59 58

  
60 59
    def __str__(self):
61 60
        return str(self.name)
......
280 279
    target_id = models.PositiveIntegerField()
281 280
    target = GenericForeignKey('target_ct', 'target_id')
282 281

  
283
    objects = rbac_managers.PermissionManager()
282
    objects = managers.PermissionManager()
284 283

  
285 284
    class Meta:
286 285
        verbose_name = _('permission')
......
412 411
        default=True, verbose_name=_('Allow adding or deleting role members')
413 412
    )
414 413

  
415
    objects = rbac_managers.RoleQuerySet.as_manager()
414
    objects = managers.RoleQuerySet.as_manager()
416 415

  
417 416
    def add_child(self, child):
418 417
        RoleParenting = rbac_utils.get_role_parenting_model()
......
720 719
    created = models.DateTimeField(verbose_name=_('Creation date'), auto_now_add=True)
721 720
    deleted = models.DateTimeField(verbose_name=_('Deletion date'), null=True)
722 721

  
723
    objects = rbac_managers.RoleParentingManager()
722
    objects = managers.RoleParentingManager()
724 723
    alive = QueryManager(deleted__isnull=True)
725 724

  
726 725
    def natural_key(self):
......
778 777

  
779 778
    _registry = {}
780 779

  
781
    objects = rbac_managers.OperationManager()
780
    objects = managers.OperationManager()
782 781

  
783 782

  
784 783
# Record the natural-key field of Operation (its slug, matching
# OperationManager.get_by_natural_key). This is a custom ``_meta`` attribute,
# presumably read elsewhere in authentic2 — confirm.
Operation._meta.natural_key = ['slug']
src/authentic2/a2_rbac/signal_handlers.py
22 22

  
23 23
from authentic2.a2_rbac.models import OrganizationalUnit, Role
24 24
from authentic2.utils.misc import get_fk_model
25
from django_rbac.managers import defer_update_transitive_closure
26 25
from django_rbac.utils import get_operation, get_role_parenting_model
27 26

  
27
from .managers import defer_update_transitive_closure
28

  
28 29

  
29 30
def create_default_ou(app_config, verbosity=2, interactive=True, using=DEFAULT_DB_ALIAS, **kwargs):
30 31
    if not router.allow_migrate(using, OrganizationalUnit):
src/django_rbac/managers.py
1
import contextlib
2
import datetime
3
import threading
4

  
5
from django.contrib.auth import get_user_model
6
from django.contrib.contenttypes.models import ContentType
7
from django.db import connection, models
8
from django.db.models import query
9
from django.db.models.query import Prefetch, Q
10
from django.db.transaction import atomic
11

  
12
from authentic2.a2_rbac import signals
13

  
14
from . import utils
15

  
16

  
17
class AbstractBaseManager(models.Manager):
    """Manager for models whose natural key is their ``uuid`` field."""

    def get_by_natural_key(self, uuid):
        """Return the single object whose ``uuid`` equals *uuid*.

        Lookup errors are those raised by ``Manager.get``.
        """
        criteria = {'uuid': uuid}
        return self.get(**criteria)
20

  
21

  
22
class OperationManager(models.Manager):
    """Manager for operations: slug natural-key lookup and permission checks."""

    def get_by_natural_key(self, slug):
        """Return the operation identified by *slug* (its natural key)."""
        return self.get(slug=slug)

    def has_perm(self, user, operation_slug, object_or_model, ou=None):
        """Test if a user can do the operation given by operation_slug
        on the given object_or_model, optionally scoped by an organizational
        unit given by ou.

        A permission matches when it targets the ContentType row of
        object_or_model's model or, when object_or_model is a model instance,
        that instance itself. It must be unscoped, or scoped to
        ``ou.as_scope()`` when ou is given.

        Returns True or False.
        """
        ou_query = query.Q(ou__isnull=True)
        if ou:
            ou_query |= query.Q(ou=ou.as_scope())
        ct = ContentType.objects.get_for_model(object_or_model)
        # Model-wide permission: its target is the ContentType row of the model.
        target_query = query.Q(target_ct=ContentType.objects.get_for_model(ContentType), target_id=ct.pk)
        if isinstance(object_or_model, models.Model):
            # Instance-level permission. Use the argument's pk: the builtin
            # ``object`` has no ``pk`` attribute.
            target_query |= query.Q(target_ct=ct, target_id=object_or_model.pk)
        Permission = utils.get_permission_model()
        qs = Permission.objects.for_user(user)
        qs = qs.filter(operation__slug=operation_slug)
        qs = qs.filter(ou_query & target_query)
        return qs.exists()
45

  
46

  
47
class PermissionManagerBase(models.Manager):
    """Base manager for permissions, resolving them from their natural key."""

    def get_by_natural_key(self, operation_slug, ou_nk, target_ct, target_nk):
        """Return the permission described by a natural key.

        *ou_nk* is the natural key of the scoping organizational unit, or a
        falsy value for an unscoped permission. *target_ct* is a content-type
        natural key and *target_nk* the natural key of the target object.
        Raises ``self.model.DoesNotExist`` when any referenced object is missing.
        """
        permissions = self.filter(operation__slug=operation_slug)
        if not ou_nk:
            permissions = permissions.filter(ou__isnull=True)
        else:
            ou_model = utils.get_ou_model()
            try:
                scope = ou_model.objects.get_by_natural_key(*ou_nk)
            except ou_model.DoesNotExist:
                raise self.model.DoesNotExist
            permissions = permissions.filter(ou=scope)

        try:
            content_type = ContentType.objects.get_by_natural_key(*target_ct)
        except ContentType.DoesNotExist:
            raise self.model.DoesNotExist

        target_model = content_type.model_class()
        try:
            target = target_model.objects.get_by_natural_key(*target_nk)
        except target_model.DoesNotExist:
            raise self.model.DoesNotExist

        return permissions.get(
            target_ct=ContentType.objects.get_for_model(target),
            target_id=target.pk,
        )
69

  
70

  
71
class PermissionQueryset(query.QuerySet):
    """QuerySet helpers to filter permissions by target or by holder."""

    def by_target_ct(self, target):
        """Keep permissions whose target content type is the one of *target*."""
        return self.filter(target_ct=ContentType.objects.get_for_model(target))

    def by_target(self, target):
        """Keep permissions whose target is exactly *target*."""
        same_type = self.by_target_ct(target)
        return same_type.filter(target_id=target.pk)

    def for_user(self, user):
        """Return the permissions held by *user* through its roles and the
        roles they inherit from.
        """
        role_model = utils.get_role_model()
        return self.filter(roles__in=role_model.objects.for_user(user=user))

    def cleanup(self):
        """Delete permissions that reference a target which no longer resolves.

        A permission is removed when its ``target`` is falsy although
        ``target_ct_id`` or ``target_id`` is set. Returns how many were deleted.
        """
        removed = 0
        for permission in self:
            if permission.target:
                continue
            if permission.target_ct_id or permission.target_id:
                permission.delete()
                removed += 1
        return removed
98

  
99

  
100
# Manager class built with Manager.from_queryset(). It combines
# PermissionManagerBase (natural-key lookup) with the PermissionQueryset
# methods (by_target_ct, by_target, for_user, cleanup), so those methods are
# also callable directly on the manager.
PermissionManager = PermissionManagerBase.from_queryset(PermissionQueryset)
101

  
102

  
103
class IntCast(models.Func):
    """SQL ``CAST((<expression>) AS int)``.

    Used below so that ``Max()`` can aggregate the boolean ``direct`` flag of
    role relations.
    """

    template = 'CAST((%(expressions)s) AS %(function)s)'
    function = 'int'
106

  
107

  
108
class RoleQuerySet(query.QuerySet):
    """QuerySet for roles, with helpers that walk the role-parenting relations."""

    def for_user(self, user):
        """Return the roles held by *user* together with all their parents.

        Objects exposing an ``apiclient_roles`` attribute are matched through
        the ``apiclients`` relation; anything else through ``members``.
        """
        holder_field = 'apiclients' if hasattr(user, 'apiclient_roles') else 'members'
        held = self.filter(**{holder_field: user})
        return held.parents().distinct()

    def _relatives(self, relation, other_end, include_self, annotate, direct):
        """Shared implementation of parents() and children().

        *relation* is the reverse accessor to follow ('child_relation' or
        'parent_relation'). *other_end* is the field of that relation that must
        point into this queryset ('child' or 'parent'). Only relations whose
        ``deleted`` field is NULL are followed.
        """
        assert annotate is False or direct is not True, 'annotate=True cannot be used with direct=True'
        lookups = {
            relation + '__deleted__isnull': True,
            relation + '__' + other_end + '__in': self,
        }
        if direct is not None:
            lookups[relation + '__direct'] = direct
        related = self.model.objects.filter(**lookups)
        if include_self:
            related = self | related
        related = related.distinct()
        if annotate:
            related = related.annotate(direct=models.Max(IntCast(relation + '__direct')))
        return related

    def parents(self, include_self=True, annotate=False, direct=None):
        """Return the roles that are parents of the roles in this queryset.

        include_self: also keep the roles of this queryset.
        direct: when not None, only follow relations with that ``direct`` value.
        annotate: add a ``direct`` annotation (Max of the int-cast flag over the
            followed relations); cannot be combined with ``direct=True``.
        """
        return self._relatives('child_relation', 'child', include_self, annotate, direct)

    def children(self, include_self=True, annotate=False, direct=None):
        """Return the roles that are children of the roles in this queryset.

        Arguments have the same meaning as for parents().
        """
        return self._relatives('parent_relation', 'parent', include_self, annotate, direct)

    def all_members(self):
        """Return users who are members of these roles or of one of their children.

        Each returned user gets a ``direct`` attribute holding the roles of this
        queryset found through the user's ``roles`` relation.
        """
        user_model = get_user_model()
        direct_roles = Prefetch('roles', queryset=self, to_attr='direct')
        membership = Q(roles__in=self) | Q(
            roles__parent_relation__parent__in=self,
            roles__parent_relation__deleted__isnull=True,
        )
        users = user_model.objects.filter(membership).distinct()
        return users.prefetch_related(direct_roles)

    def by_admin_scope_ct(self, admin_scope):
        """Keep roles whose admin scope content type is the one of *admin_scope*."""
        return self.filter(admin_scope_ct=ContentType.objects.get_for_model(admin_scope))

    def cleanup(self):
        """Delete roles that reference an admin scope which no longer resolves.

        Returns the number of deleted roles.
        """
        scoped = self.filter(Q(admin_scope_ct_id__isnull=False) | Q(admin_scope_id__isnull=False))
        removed = 0
        for role in scoped:
            if role.admin_scope:
                continue
            role.delete()
            removed += 1
        return removed
179

  
180

  
181
# Manager class for roles: the uuid natural-key lookup of AbstractBaseManager
# plus the RoleQuerySet methods, combined with Manager.from_queryset().
RoleManager = AbstractBaseManager.from_queryset(RoleQuerySet)
182

  
183

  
184
class RoleParentingManager(models.Manager):
    """Manager for parent/child relations between roles.

    Relations are soft-deleted: rows whose ``deleted`` field is non-NULL are
    ignored by the queries below. Rows with ``direct=True`` are explicit links.
    Rows with ``direct=False`` are the indirect links derived from them by
    update_transitive_closure().
    """

    class Local(threading.local):
        # Per-thread flags. When DO_UPDATE_CLOSURE is False,
        # update_transitive_closure() only sets CLOSURE_UPDATED to record that a
        # recomputation was requested (see defer_update_transitive_closure()).
        DO_UPDATE_CLOSURE = True
        CLOSURE_UPDATED = False

    tls = Local()

    def get_by_natural_key(self, parent_nk, child_nk, direct):
        """Return the relation between two roles given by their natural keys.

        Raises self.model.DoesNotExist if either role cannot be found.
        """
        Role = utils.get_role_model()
        try:
            parent = Role.objects.get_by_natural_key(*parent_nk)
        except Role.DoesNotExist:
            raise self.model.DoesNotExist
        try:
            child = Role.objects.get_by_natural_key(*child_nk)
        except Role.DoesNotExist:
            raise self.model.DoesNotExist
        return self.get(parent=parent, child=child, direct=direct)

    def soft_create(self, parent, child):
        """Create a direct relation, or revive it if it was soft-deleted.

        Sends signals.post_soft_create when the relation was created or revived.
        """
        with atomic(savepoint=False):
            rp, created = self.get_or_create(parent=parent, child=child, direct=True)
            # truthy when the row is new or was soft-deleted (deleted is set)
            new = created or rp.deleted
            if not created and rp.deleted:
                rp.created = datetime.datetime.now()
                rp.deleted = None
                rp.save(update_fields=['created', 'deleted'])
            if new:
                signals.post_soft_create.send(sender=self.model, instance=rp)

    def soft_delete(self, parent, child):
        """Soft-delete the live direct relation between parent and child.

        Sends signals.post_soft_delete only if the UPDATE affected a row.
        """
        qs = self.filter(parent=parent, child=child, deleted__isnull=True, direct=True)
        with atomic(savepoint=False):
            rp = qs.first()
            if rp:
                count = qs.update(deleted=datetime.datetime.now())
                # read-committed: the view of the tables can change during the
                # transaction, so only signal if this UPDATE actually hit a row
                if count:
                    signals.post_soft_delete.send(sender=self.model, instance=rp)

    def update_transitive_closure(self):
        """Recompute the transitive closure of the inheritance relation
        from scratch. Add missing indirect relations and delete
        obsolete indirect relations.

        The closure is computed in Python as a fixed point: starting from the
        live direct pairs, (i, l) is added whenever (i, j) is in the set and
        (j, l) is direct, skipping self-loops (i == l). Each pass is
        O(len(indirects) * len(direct)).

        Indirect rows no longer in the closure are soft-deleted. Missing ones
        are inserted, or revived through ON CONFLICT. The SQL is
        PostgreSQL-flavoured (UPDATE ... FROM, ON CONFLICT, now()).

        If DO_UPDATE_CLOSURE is False for this thread, nothing is computed;
        CLOSURE_UPDATED is set instead.
        """
        if not self.tls.DO_UPDATE_CLOSURE:
            self.tls.CLOSURE_UPDATED = True
            return

        with atomic(savepoint=False):
            # existing direct paths
            direct = set(self.filter(direct=True, deleted__isnull=True).values_list('parent_id', 'child_id'))
            old_indirects = set(
                self.filter(direct=False, deleted__isnull=True).values_list('parent_id', 'child_id')
            )
            indirects = set(direct)

            # iterate until no new (ancestor, descendant) pair is discovered
            while True:
                changed = False
                for (i, j) in list(indirects):
                    for (k, l) in direct:
                        if j == k and i != l and (i, l) not in indirects:
                            indirects.add((i, l))
                            changed = True
                if not changed:
                    break

            # The ids interpolated into the SQL below were read from this same
            # table by the values_list() queries above.
            with connection.cursor() as cur:
                # Delete old ones
                obsolete = old_indirects - indirects - direct
                if obsolete:
                    sql = '''UPDATE "%s" AS relation \
SET deleted = now()\
FROM (VALUES %s) AS dead(parent_id, child_id) \
WHERE relation.direct = 'false' AND relation.parent_id = dead.parent_id \
AND relation.child_id = dead.child_id AND deleted IS NULL''' % (
                        self.model._meta.db_table,
                        ', '.join('(%s, %s)' % (a, b) for a, b in obsolete),
                    )
                    cur.execute(sql)
                # Create new indirect relations
                new = indirects - old_indirects - direct
                if new:
                    new_values = ', '.join(
                        (
                            "(%s, %s, 'false', now(), NULL)" % (parent_id, child_id)
                            for parent_id, child_id in new
                        )
                    )
                    # ON CONFLICT presumably relies on a unique constraint on
                    # (parent_id, child_id, direct) — TODO confirm in the model/migrations
                    sql = '''INSERT INTO "%s" (parent_id, child_id, direct, created, deleted) VALUES %s \
ON CONFLICT (parent_id, child_id, direct) DO UPDATE SET created = EXCLUDED.created, deleted = NULL''' % (
                        self.model._meta.db_table,
                        new_values,
                    )
                    cur.execute(sql)
279

  
280

  
281
@contextlib.contextmanager
def defer_update_transitive_closure():
    """Postpone transitive-closure recomputation until the ``with`` block exits.

    Inside the block, RoleParentingManager.update_transitive_closure() only
    records (per thread) that an update was requested. On normal exit, a single
    recomputation runs if one was requested. If the block raises, none runs.
    The per-thread flags are always reset.
    """
    from . import utils

    RoleParentingManager.tls.DO_UPDATE_CLOSURE = False
    try:
        yield
        if RoleParentingManager.tls.CLOSURE_UPDATED:
            # Re-enable updates first. Otherwise update_transitive_closure()
            # would only record the request again and return without work.
            RoleParentingManager.tls.DO_UPDATE_CLOSURE = True
            utils.get_role_parenting_model().objects.update_transitive_closure()
    finally:
        RoleParentingManager.tls.DO_UPDATE_CLOSURE = True
        RoleParentingManager.tls.CLOSURE_UPDATED = False
293
-