From 11ad397384591a968f3a85528eb9c87fd50f2f60 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 6 Oct 2022 11:11:37 +0200 Subject: [PATCH] misc: move tenant conservation in Thread.start (#69942) --- hobo/multitenant/threads.py | 33 ++++++++++++----------------- tests_multitenant/test_threading.py | 6 +++--- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/hobo/multitenant/threads.py b/hobo/multitenant/threads.py index f6981b2..fa94aaf 100644 --- a/hobo/multitenant/threads.py +++ b/hobo/multitenant/threads.py @@ -16,40 +16,35 @@ import threading -_Thread_bootstrap_inner = threading.Thread._bootstrap_inner -_Thread__init__ = threading.Thread.__init__ +_Thread_start = threading.Thread.start +_Thread__bootstrap_inner = threading.Thread._bootstrap_inner -def _new__init__(self, *args, **kwargs): +def _new_start(self): from django.db import connection - try: - if hasattr(connection, 'get_tenant'): - self.tenant = connection.get_tenant() - else: - self.tenant = None - except RuntimeError: - # this happens when ImportError is raised at startup; ignore - # the error to let the real one be displayed. - self.tenant = None - _Thread__init__(self, *args, **kwargs) + tenant = getattr(connection, 'tenant', None) + self.tenant = tenant + return _Thread_start(self) -def _new_bootstrap_inner(self): - if self.tenant is not None: +def _new__bootstrap_inner(self): + tenant = getattr(self, 'tenant', None) + + if tenant is not None: from django.db import connection old_tenant = connection.get_tenant() connection.set_tenant(self.tenant) try: - _Thread_bootstrap_inner(self) + _Thread__bootstrap_inner(self) finally: connection.set_tenant(old_tenant) connection.close() else: - _Thread_bootstrap_inner(self) + _Thread__bootstrap_inner(self) def install_tenant_aware_threads(): - threading.Thread.__init__ = _new__init__ - threading.Thread._bootstrap_inner = _new_bootstrap_inner + threading.Thread.start = _new_start + threading.Thread._bootstrap_inner = _new__bootstrap_inner diff --git a/tests_multitenant/test_threading.py b/tests_multitenant/test_threading.py index f0bcfcd..0ef3cec 100644 --- a/tests_multitenant/test_threading.py +++ b/tests_multitenant/test_threading.py @@ -45,7 +45,7 @@ def test_thread(tenants, settings, client): with tenant_context(tenant): assert hasattr(settings, 'TEMPLATE_VARS') t2 = threading.Thread(target=f, args=(tenant,)) - t2.start() + t2.start() t2.join() assert not hasattr(django.conf.settings, 'TEMPLATE_VARS') @@ -74,7 +74,7 @@ def test_cache(tenants, client): assert cache.get('coin') == tenant.domain_url t1 = threading.Thread(target=f) - t1.start() + t1.start() t1.join() def g(): @@ -107,7 +107,7 @@ def test_timer_thread(tenants, settings, client): with tenant_context(tenant): assert hasattr(settings, 'TEMPLATE_VARS') t2 = threading.Timer(0.0, f, args=(tenant,)) - t2.start() + t2.start() t2.join() assert not hasattr(django.conf.settings, 'TEMPLATE_VARS') -- 2.37.2