From ec1c8b7f322119b5d56db78968f2827b20631fd2 Mon Sep 17 00:00:00 2001 From: Valentin Deniaud Date: Thu, 6 Jan 2022 18:17:16 +0100 Subject: [PATCH 1/2] general: use tree queries for handling page hierarchy (#60018) --- combo/data/models.py | 26 +++------ combo/data/query.py | 125 +++++++++++++++++++++++++++++++++++++++++ combo/manager/forms.py | 4 +- tests/test_manager.py | 4 +- tests/test_pages.py | 1 + tests/test_public.py | 2 +- 6 files changed, 139 insertions(+), 23 deletions(-) create mode 100644 combo/data/query.py diff --git a/combo/data/models.py b/combo/data/models.py index 795f086e..8841e4da 100644 --- a/combo/data/models.py +++ b/combo/data/models.py @@ -60,6 +60,7 @@ from combo.utils import NothingInCacheException from .fields import RichTextField, TemplatableURLField from .library import get_cell_class, get_cell_classes, register_cell_class +from .query import TreeManager class PostException(Exception): @@ -151,7 +152,7 @@ class Placeholder: return self.name -class PageManager(models.Manager): +class PageManager(TreeManager): snapshots = False def __init__(self, *args, **kwargs): @@ -272,12 +273,10 @@ class Page(models.Model): return super().save(*args, **kwargs) def get_parents_and_self(self): - pages = [self] - page = self - while page.parent_id: - page = page._parent if hasattr(page, '_parent') else page.parent - pages.append(page) - return list(reversed(pages)) + if not self.parent_id: + return [self] + + return list(Page.objects.ancestors(self, include_self=False)) + [self] def get_online_url(self, follow_redirection=True): if ( @@ -323,18 +322,7 @@ class Page(models.Model): return Page.objects.filter(parent_id=self.id).exists() def get_descendants(self, include_myself=False): - def get_descendant_pages(page, include_page=True): - if include_page: - descendants = [page] - else: - descendants = [] - for item in page.get_children(): - descendants.extend(get_descendant_pages(item)) - return descendants - - return Page.objects.filter( - id__in=[x.id for x in get_descendant_pages(self, include_page=include_myself)] - ) + return Page.objects.descendants(self, include_self=include_myself) def get_descendants_and_me(self): return self.get_descendants(include_myself=True) diff --git a/combo/data/query.py b/combo/data/query.py new file mode 100644 index 00000000..cc5049b8 --- /dev/null +++ b/combo/data/query.py @@ -0,0 +1,125 @@ +# Copyright (c) 2018, Feinheit AG and individual contributors. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of Feinheit AG nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from django.db import connections, models +from django.db.models.sql.compiler import SQLCompiler +from django.db.models.sql.query import Query +from django.db.models.sql.where import AND, ExtraWhere + + +class TreeQuery(Query): + def get_compiler(self, using=None, connection=None, elide_empty=True): + # Copied from django/db/models/sql/query.py + if using is None and connection is None: + raise ValueError('Need either using or connection') + if connection is None: + connection = connections[using] + return TreeCompiler(self, connection, using) + + +class TreeExtraWhere(ExtraWhere): + def relabeled_clone(self, change_map): + new_sqls = [] + for sql in self.sqls: + for old_table_name, new_table_name in change_map.items(): + sql = sql.replace(old_table_name, new_table_name) + new_sqls.append(sql) + self.sqls = new_sqls + return self + + +class TreeCompiler(SQLCompiler): + CTE = ''' + WITH RECURSIVE __tree ( + "tree_depth", + "tree_path", + "tree_ordering", + "tree_pk" + ) AS ( + SELECT + 0 AS tree_depth, + array[T.id] AS tree_path, + array["order"] AS tree_ordering, + T.id + FROM data_page T + WHERE T.parent_id IS NULL + + UNION ALL + + SELECT + __tree.tree_depth + 1 AS tree_depth, + __tree.tree_path || T.id, + __tree.tree_ordering || "order", + T.id + FROM data_page T + JOIN __tree ON T.parent_id = __tree.tree_pk + ) + ''' + + def as_sql(self, *args, **kwargs): + if '__tree' not in self.query.extra_tables: + self.query.add_extra( + select={ + 'tree_depth': '__tree.tree_depth', + 'tree_path': '__tree.tree_path', + 'tree_ordering': '__tree.tree_ordering', + }, + select_params=None, + where=None, + params=None, + tables=['__tree'], + order_by=['tree_ordering'], + ) + table_name = self.query.table_alias('data_page')[0] + self.query.where.add(TreeExtraWhere(['__tree.tree_pk = %s.id' % table_name], None), AND) + + sql = super().as_sql(*args, **kwargs) + return (''.join([self.CTE, sql[0]]), sql[1]) + + +class TreeQuerySet(models.QuerySet): + def with_tree_fields(self): + self.query.__class__ = TreeQuery + return self + + def ancestors(self, page, include_self=True): + if not hasattr(page, 'tree_path'): + pk = page.pk if not page.snapshot else page.snapshot.page_id + page = self.with_tree_fields().get(pk=pk) + + ids = page.tree_path if include_self else page.tree_path[:-1] + return self.with_tree_fields().filter(id__in=ids).order_by('tree_depth') + + def descendants(self, page, include_self=True): + queryset = self.with_tree_fields().extra(where=['%s = ANY(tree_path)' % page.pk]) + if not include_self: + return queryset.exclude(pk=page.pk) + return queryset + + +TreeManager = models.Manager.from_queryset(TreeQuerySet) diff --git a/combo/manager/forms.py b/combo/manager/forms.py index e6b3eae0..532981b3 100644 --- a/combo/manager/forms.py +++ b/combo/manager/forms.py @@ -245,7 +245,9 @@ class PageEditIncludeInNavigationForm(forms.ModelForm): super().save(*args, **kwargs) if self.cleaned_data.get('apply_to_subpages'): subpages = self.instance.get_descendants(include_myself=True) - subpages.update(exclude_from_navigation=bool(not self.cleaned_data['include_in_navigation'])) + Page.objects.filter(pk__in=subpages.values('pk')).update( + exclude_from_navigation=bool(not self.cleaned_data['include_in_navigation']) + ) else: self.instance.exclude_from_navigation = not self.cleaned_data['include_in_navigation'] self.instance.save() diff --git a/tests/test_manager.py b/tests/test_manager.py index 5d05f35f..93b413f4 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -648,7 +648,7 @@ def test_edit_page_num_queries(settings, app, admin_user): app.get('/manage/pages/%s/' % page.pk) # load once to populate caches with CaptureQueriesContext(connection) as ctx: app.get('/manage/pages/%s/' % page.pk) - assert len(ctx.captured_queries) == 33 + assert len(ctx.captured_queries) == 34 def test_delete_page(app, admin_user): @@ -926,7 +926,7 @@ def test_site_export_import_json(app, admin_user): resp.form['site_file'] = Upload('site-export.json', site_export, 'application/json') with CaptureQueriesContext(connection) as ctx: resp = resp.form.submit() - assert len(ctx.captured_queries) in [303, 304] + assert len(ctx.captured_queries) in [308, 309] assert Page.objects.count() == 4 assert PageSnapshot.objects.all().count() == 4 diff --git a/tests/test_pages.py b/tests/test_pages.py index 8bf04deb..ec43e1b3 100644 --- a/tests/test_pages.py +++ b/tests/test_pages.py @@ -35,6 +35,7 @@ def test_page_url(): page2 = Page() page2.slug = 'bar' page2.parent = page + page2.save() assert page2.get_online_url() == '/foo/bar/' # directly give redirect url of linked page diff --git a/tests/test_public.py b/tests/test_public.py index 29250e56..f42daf32 100644 --- a/tests/test_public.py +++ b/tests/test_public.py @@ -186,7 +186,7 @@ def test_page_footer_acquisition(app): assert resp.text.count('BAR2FOO') == 1 queries_count_third = len(ctx.captured_queries) # +2 for validity info of parent page - assert queries_count_third == queries_count_second + 1 + assert queries_count_third == queries_count_second + 2 with CaptureQueriesContext(connection) as ctx: resp = app.get('/second/third/fourth/', status=200) -- 2.30.2