From fa54d5b3ad8b4c634fb76be49bc23e12d7cbb950 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 7 Mar 2019 19:41:54 +0100 Subject: [PATCH 1/4] utils: add defer module to run things later (#31204) --- passerelle/utils/defer.py | 88 ++++++++++++++++++++++++++ tests/test_defer.py | 128 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 passerelle/utils/defer.py create mode 100644 tests/test_defer.py diff --git a/passerelle/utils/defer.py b/passerelle/utils/defer.py new file mode 100644 index 00000000..b72ad369 --- /dev/null +++ b/passerelle/utils/defer.py @@ -0,0 +1,88 @@ +# Copyright (C) 2019 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +import functools +import logging +import threading + +import django.db + + +class DeferrableBarrier(threading.local): + @property + def stack(self): + if not hasattr(self, '_stack'): + self._stack = [] + return self._stack + + def __push(self): + self.stack.append([]) + + def defer(self, func, *args, **kwargs): + if self.should_defer: + self.stack[-1].append((func, args, kwargs)) + else: + return func(*args, **kwargs) + + def __pop(self): + return self.stack.pop() + + def __enter__(self): + self.__push() + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + for func, args, kwargs in self.__pop(): + try: + func(*args, **kwargs) + except Exception: + logging.exception('failed to run deferrable function %s', func) + + @property + def should_defer(self): + return bool(self.stack) + + def deferrable(self, func=None, predicate=None): + '''Automatically defer a function if dynamic scope is inside a deferrable_barrier.''' + if not func: + return functools.partial(self.deferrable, predicate=predicate) + + @functools.wraps(func) + def f(*args, **kwargs): + if not predicate or predicate(): + return self.defer(func, *args, **kwargs) + else: + return func(*args, **kwargs) + return f + + def __call__(self, func): + '''Wraps func in a deferrable_barrier scope.''' + + @functools.wraps(func) + def f(*args, **kwargs): + with self: + return func(*args, **kwargs) + return f + + +deferrable_barrier = DeferrableBarrier() + + +def is_in_transaction(): + return getattr(django.db.connection, 'in_atomic_block', False) + +deferrable = deferrable_barrier.deferrable +deferrable_if_in_transaction = deferrable_barrier.deferrable(predicate=is_in_transaction) diff --git a/tests/test_defer.py b/tests/test_defer.py new file mode 100644 index 00000000..fd2aea91 --- /dev/null +++ b/tests/test_defer.py @@ -0,0 +1,128 @@ +# Copyright (C) 2019 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import threading + +from django.db import transaction + +from passerelle.utils import defer + +import pytest + + +def test_deferrable_barrier(): + x = [] + + def f(): + x.append(1) + + assert not defer.deferrable_barrier.should_defer + with defer.deferrable_barrier: + assert defer.deferrable_barrier.should_defer + defer.deferrable_barrier.defer(f) + assert x == [] + assert not defer.deferrable_barrier.should_defer + assert x == [1] + + +def test_deferrable_barrier_with_threading(): + x = [] + + def f(): + x.append(1) + + assert not defer.deferrable_barrier.should_defer + with defer.deferrable_barrier: + defer.deferrable_barrier.defer(f) + assert x == [] + t = threading.Thread(target=defer.deferrable_barrier.defer, args=(f,)) + assert x == [] + t.start() + t.join() + assert x == [1] + assert not defer.deferrable_barrier.should_defer + assert x == [1, 1] + + +def test_deferrable(): + x = [] + + @defer.deferrable + def f(): + x.append(1) + + f() + assert x == [1] + + assert not defer.deferrable_barrier.should_defer + with defer.deferrable_barrier: + f() + assert x == [1] + assert not defer.deferrable_barrier.should_defer + assert x == [1, 1] + + +def test_deferrable_with_threading(): + x = [] + + @defer.deferrable + def f(): + x.append(1) + + f() + assert x == [1] + + assert not defer.deferrable_barrier.should_defer + with defer.deferrable_barrier: + f() + assert x == [1] + t = threading.Thread(target=f) + t.start() + t.join() + assert x == [1, 1] + assert x == [1, 1, 1] + assert not defer.deferrable_barrier.should_defer + + +def test_deferrable_if_in_transaction(transactional_db): + assert not defer.is_in_transaction() + + x = [] + + @defer.deferrable_if_in_transaction + def f(): + x.append(1) + + f() + assert x == [1] + + with transaction.atomic(): + assert defer.is_in_transaction() + f() + assert x == [1, 1] + + with pytest.raises(Exception): + with defer.deferrable_barrier: + f() + assert x == [1, 1, 1] + try: + with transaction.atomic(): + f() + assert x == [1, 1, 1] + raise Exception + finally: + assert x == [1, 1, 1] + assert x == [1, 1, 1, 1] + -- 2.20.1