From e10eaf08abb9fc6c36ad4e5755998e0bed70df6b Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Tue, 5 Oct 2021 22:38:13 +0200 Subject: [PATCH 2/3] sql: use iterate on server only for iterators Iteration on server is useless when loading all objects in memory and increase the number of needed connections to the SQL server, at least by a factor 2 (you need the current one and the one needed to populate the returned list). --- tests/conftest.py | 3 ++- tests/workflow/test_all.py | 2 +- wcs/sql.py | 15 +++++++++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0050687d..93b40544 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,7 +96,7 @@ def sql_queries(monkeypatch): def cursor(*args, **kwargs): cur = old_cursor(*args, **kwargs) - mocked_cur = mock.Mock(wraps=cur) + mocked_cur = mock.MagicMock(wraps=cur, spec=cur) old_execute = cur.execute def execute(*args, **kwargs): @@ -104,6 +104,7 @@ def sql_queries(monkeypatch): return old_execute(*args, **kwargs) mocked_cur.execute = execute + mocked_cur.__iter__ = lambda *args, **kwargs: cur.__iter__() return mocked_cur mocked_conn.cursor = cursor diff --git a/tests/workflow/test_all.py b/tests/workflow/test_all.py index d80775b1..06b68721 100644 --- a/tests/workflow/test_all.py +++ b/tests/workflow/test_all.py @@ -3183,7 +3183,7 @@ def test_geolocate_address(two_pubs): assert formdata.geolocations == {} if two_pubs.is_using_postgresql(): assert two_pubs.loggederror_class.count() == 2 - logged_error = two_pubs.loggederror_class.select()[1] + logged_error = two_pubs.loggederror_class.select(order_by='id')[1] assert logged_error.summary == 'error calling geocoding service [some error]' assert logged_error.formdata_id == str(formdata.id) assert logged_error.exception_class == 'ConnectionError' diff --git a/wcs/sql.py b/wcs/sql.py index 57f9b16e..520f1a86 100644 --- a/wcs/sql.py +++ b/wcs/sql.py @@ -1630,7 +1630,9 @@ class SqlMixin: @classmethod @guard_postgres - def select_iterator(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None): + def select_iterator( + cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None, iterator=True + ): table_static_fields = [ x[0] if x[0] not in cls._table_select_skipped_fields else 'NULL AS %s' % x[0] for x in cls._table_static_fields @@ -1654,7 +1656,7 @@ class SqlMixin: sql_statement += ' OFFSET %(offset)s' parameters['offset'] = offset - if cls._iterate_on_server: + if cls._iterate_on_server and iterator: conn = get_connection(isolate=True) cur = conn.cursor(name='select_iterator_%s' % uuid.uuid4()) cur.itersize = 1 @@ -1671,7 +1673,7 @@ class SqlMixin: finally: cur.close() conn.commit() - if cls._iterate_on_server: + if cls._iterate_on_server and iterator: # close isolated connection conn.close() @@ -1679,7 +1681,12 @@ class SqlMixin: @guard_postgres def select(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None, iterator=False): objects = cls.select_iterator( - clause=clause, order_by=order_by, ignore_errors=ignore_errors, limit=limit, offset=offset + clause=clause, + order_by=order_by, + ignore_errors=ignore_errors, + limit=limit, + offset=offset, + iterator=iterator, ) func_clause = parse_clause(clause)[2] if func_clause and (limit or offset): -- 2.33.0