From 75ee2290db7c0ba7ca1ee965fc3b329848204675 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 7 Mar 2019 19:41:54 +0100 Subject: [PATCH 1/3] utils: add defer module to run things later (#31204) --- passerelle/utils/defer.py | 91 +++++++++++++++++++++++++++ tests/test_utils_defer.py | 126 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 passerelle/utils/defer.py create mode 100644 tests/test_utils_defer.py diff --git a/passerelle/utils/defer.py b/passerelle/utils/defer.py new file mode 100644 index 00000000..d168035c --- /dev/null +++ b/passerelle/utils/defer.py @@ -0,0 +1,91 @@ +# Copyright (C) 2012 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +import functools +import logging +import threading + +import django.db + + +class DeferrableScope(threading.local): + @property + def stack(self): + if not hasattr(self, '_stack'): + self._stack = [] + return self._stack + + def __push(self): + self.stack.append([]) + + def defer(self, func, *args, **kwargs): + if self.should_defer: + self.stack[-1].append((func, args, kwargs)) + else: + return func(*args, **kwargs) + + def __pop(self): + return self.stack.pop() + + def __enter__(self): + self.__push() + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + for func, args, kwargs in self.__pop(): + try: + func(*args, **kwargs) + except Exception: + logging.exception('failed to run deferrable function %s', func) + + @property + def should_defer(self): + return bool(self.stack) + + def deferrable(self, func=None, predicate=None): + '''Automatically defer a function if dynamic scope is inside a deferrable_scope.''' + if not func: + return functools.partial(self.deferrable, predicate=predicate) + + @functools.wraps(func) + def f(*args, **kwargs): + if not predicate or predicate(): + return self.defer(func, *args, **kwargs) + else: + return func(*args, **kwargs) + + return f + + def __call__(self, func): + '''Wraps func in a deferrable_scope scope.''' + + @functools.wraps(func) + def f(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return f + + +deferrable_scope = DeferrableScope() + + +def is_in_transaction(): + return getattr(django.db.connection, 'in_atomic_block', False) + + +deferrable = deferrable_scope.deferrable +deferrable_if_in_transaction = deferrable_scope.deferrable(predicate=is_in_transaction) diff --git a/tests/test_utils_defer.py b/tests/test_utils_defer.py new file mode 100644 index 00000000..8943e6bd --- /dev/null +++ b/tests/test_utils_defer.py @@ -0,0 +1,126 @@ +# Copyright (C) 2022 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import threading + +import pytest +from django.db import transaction + +from passerelle.utils import defer + + +def test_deferrable_scope(): + x = [] + + def f(): + x.append(1) + + assert not defer.deferrable_scope.should_defer + with defer.deferrable_scope: + assert defer.deferrable_scope.should_defer + defer.deferrable_scope.defer(f) + assert x == [] + assert not defer.deferrable_scope.should_defer + assert x == [1] + + +def test_deferrable_scope_with_threading(): + x = [] + + def f(): + x.append(1) + + assert not defer.deferrable_scope.should_defer + with defer.deferrable_scope: + defer.deferrable_scope.defer(f) + assert x == [] + t = threading.Thread(target=defer.deferrable_scope.defer, args=(f,)) + assert x == [] + t.start() + t.join() + assert x == [1] + assert not defer.deferrable_scope.should_defer + assert x == [1, 1] + + +def test_deferrable(): + x = [] + + @defer.deferrable + def f(): + x.append(1) + + f() + assert x == [1] + + assert not defer.deferrable_scope.should_defer + with defer.deferrable_scope: + f() + assert x == [1] + assert not defer.deferrable_scope.should_defer + assert x == [1, 1] + + +def test_deferrable_with_threading(): + x = [] + + @defer.deferrable + def f(): + x.append(1) + + f() + assert x == [1] + + assert not defer.deferrable_scope.should_defer + with defer.deferrable_scope: + f() + assert x == [1] + t = threading.Thread(target=f) + t.start() + t.join() + assert x == [1, 1] + assert x == [1, 1, 1] + assert not defer.deferrable_scope.should_defer + + +def test_deferrable_if_in_transaction(transactional_db): + assert not defer.is_in_transaction() + + x = [] + + @defer.deferrable_if_in_transaction + def f(): + x.append(1) + + f() + assert x == [1] + + with transaction.atomic(): + assert defer.is_in_transaction() + f() + assert x == [1, 1] + + with pytest.raises(Exception): + with defer.deferrable_scope: + f() + assert x == [1, 1, 1] + try: + with transaction.atomic(): + f() + assert x == [1, 1, 1] + raise Exception + finally: + assert x == [1, 1, 1] + assert x == [1, 1, 1, 1] -- 2.36.1