From 033c6d53b7b86bca4dc28f413ec6cf788eaee122 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Tue, 9 Mar 2021 22:25:10 +0100 Subject: [PATCH 1/3] journal: permit custom prefetching (#51808) --- src/authentic2/apps/journal/forms.py | 5 +++- src/authentic2/apps/journal/models.py | 22 +++++++++++++++- tests/test_journal.py | 38 ++++++++++++++++++++++++++- 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/authentic2/apps/journal/forms.py b/src/authentic2/apps/journal/forms.py index edc96a33..d6150f31 100644 --- a/src/authentic2/apps/journal/forms.py +++ b/src/authentic2/apps/journal/forms.py @@ -330,13 +330,16 @@ class JournalForm(forms.Form): first = len(page) <= limit last = True page = page[-limit:] - models.prefetch_events_references(page) + models.prefetch_events_references(page, prefetcher=self.prefetcher) if page: self.data = self.data.copy() self.cleaned_data['after_cursor'] = self.data['after_cursor'] = page[0].cursor.minus_one() self.cleaned_data['before_cursor'] = '' return Page(self, page, first, last) + def prefetcher(self, model, pks): + return [] + @cached_property def date_hierarchy(self): self.is_valid() diff --git a/src/authentic2/apps/journal/models.py b/src/authentic2/apps/journal/models.py index bb6d71b1..d5302f4f 100644 --- a/src/authentic2/apps/journal/models.py +++ b/src/authentic2/apps/journal/models.py @@ -20,6 +20,7 @@ from collections import defaultdict from contextlib import contextmanager from datetime import datetime, timedelta +import django from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType @@ -458,7 +459,7 @@ class EventCursor(str): return EventCursor('%s %s' % (self.timestamp.timestamp(), self.event_id - 1)) -def prefetch_events_references(events): +def prefetch_events_references(events, prefetcher=None): '''Prefetch references on an iterable of events, prevent N+1 queries problem.''' grouped_references = defaultdict(set) references = {} @@ -473,6 +474,25 @@ def prefetch_events_references(events): content_type = ContentType.objects.get_for_id(content_type_id) for instance in content_type.get_all_objects_for_this_type(pk__in=instance_pks): references[(content_type_id, instance.pk)] = instance + if prefetcher: + deleted_pks = [pk for pk in instance_pks if (content_type_id, pk) not in references] + if deleted_pks: + for found_pk, instance in prefetcher(content_type.model_class(), deleted_pks): + references[(content_type_id, found_pk)] = instance + + # prefetch the user column if absent + if prefetcher: + user_to_events = {} + for event in events: + if event.user is None and event.user_id: + user_to_events.setdefault(event.user_id, []).append(event) + for found_pk, instance in prefetcher(User, user_to_events.keys()): + for event in user_to_events[found_pk]: + # prevent TypeError in user's field descriptor __set__ method + if django.VERSION < (2,): + event._user_cache = instance + else: + event._state.fields_cache['user'] = instance # assign references to events for event in events: diff --git a/tests/test_journal.py b/tests/test_journal.py index 169c13f6..f87a59e3 100644 --- a/tests/test_journal.py +++ b/tests/test_journal.py @@ -28,7 +28,13 @@ from authentic2.a2_rbac.models import OrganizationalUnit as OU from authentic2.a2_rbac.utils import get_default_ou from authentic2.apps.journal.forms import JournalForm from authentic2.apps.journal.journal import Journal -from authentic2.apps.journal.models import Event, EventType, EventTypeDefinition, clean_registry +from authentic2.apps.journal.models import ( + Event, + EventType, + EventTypeDefinition, + clean_registry, + prefetch_events_references, +) from authentic2.models import Service User = get_user_model() @@ -146,6 +152,7 @@ def test_references(db): assert list(event.get_typed_references(User, None)) == [None, None] event = Event.objects.get() assert list(event.get_typed_references(Service, User)) == [None, None] + assert event.user is None def test_event_types(clean_event_types_definition_registry): @@ -669,3 +676,32 @@ def test_statistics_ou_with_no_service(db, freezer): ou_with_no_service = OU.objects.create(name='Second OU') stats = event_type_definition.get_method_statistics('month', services_ou=ou_with_no_service) assert stats == {'x_labels': [], 'series': []} + + +def test_prefetcher(db): + event_type = EventType.objects.get_for_name('user.login') + for i in range(10): + user = User.objects.create() + Event.objects.create(type=event_type, user=user, references=[user]) + Event.objects.create(type=event_type, user=user, references=[user]) + + User.objects.all().delete() + + events = list(Event.objects.all()) + prefetch_events_references(events) + for event in events: + assert event.user is None + assert list(event.get_typed_references(User)) == [None] + + def prefetcher(model, pks): + if not issubclass(model, User): + return + for pk in pks: + yield pk, 'deleted %s' % pk + + events = list(Event.objects.all()) + prefetch_events_references(events, prefetcher=prefetcher) + for event in events: + s = 'deleted %s' % event.user_id + assert event.user == s + assert list(event.get_typed_references((str, User))) == [s] -- 2.32.0.rc0