From 707a47bb4528c9afcc384f9bb6836372ba98e459 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Wed, 20 Oct 2021 10:49:44 +0200 Subject: [PATCH] sql: use batch iteration on ids instead of named cursors (#58013) Named cursors imposed the use of isolated connections and were misused resulting in reading using one SQL query by row (because of the use of .fetchone() with cursors). This commit revert to the behaviour of one connection per request and reading full SQL statement results at a time without using cursors. --- wcs/qommon/storage.py | 1 + wcs/sql.py | 103 ++++++++++++++++++++++++++++-------------- wcs/wf/jump.py | 2 +- 3 files changed, 70 insertions(+), 36 deletions(-) diff --git a/wcs/qommon/storage.py b/wcs/qommon/storage.py index 55202cef..ea0888a7 100644 --- a/wcs/qommon/storage.py +++ b/wcs/qommon/storage.py @@ -431,6 +431,7 @@ class StorableObject: limit=None, offset=None, iterator=False, + itersize=None, **kwargs, ): # iterator: only for compatibility with sql select() diff --git a/wcs/sql.py b/wcs/sql.py index 90441061..cf0963be 100644 --- a/wcs/sql.py +++ b/wcs/sql.py @@ -20,7 +20,6 @@ import io import json import re import time -import uuid import psycopg2 import psycopg2.extensions @@ -365,11 +364,11 @@ def site_unicode(value): return force_text(value, get_publisher().site_charset) -def get_connection(new=False, isolate=False): - if new and not isolate: +def get_connection(new=False): + if new: cleanup_connection() - if isolate or not getattr(get_publisher(), 'pgconn', None): + if not getattr(get_publisher(), 'pgconn', None): postgresql_cfg = {} for param in ('database', 'user', 'password', 'host', 'port'): value = get_cfg('postgresql', {}).get(param) @@ -378,11 +377,9 @@ def get_connection(new=False, isolate=False): try: pgconn = psycopg2.connect(**postgresql_cfg) except psycopg2.Error: - if new or isolate: + if new: raise pgconn = None - if isolate: - return pgconn get_publisher().pgconn = pgconn @@ -1431,7 +1428,7 @@ class SqlMixin: _table_name = None _numerical_id = True _table_select_skipped_fields = [] - _iterate_on_server = True + _has_id = True @classmethod @guard_postgres @@ -1651,16 +1648,38 @@ class SqlMixin: @classmethod @guard_postgres - def select_iterator(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None): + def select_iterator( + cls, + clause=None, + order_by=None, + ignore_errors=False, + limit=None, + offset=None, + itersize=None, + ): table_static_fields = [ x[0] if x[0] not in cls._table_select_skipped_fields else 'NULL AS %s' % x[0] for x in cls._table_static_fields ] - sql_statement = '''SELECT %s - FROM %s''' % ( - ', '.join(table_static_fields + cls.get_data_fields()), - cls._table_name, - ) + + def retrieve(): + for object in cls.get_objects(cur, iterator=True): + if object is None: + continue + if func_clause and not func_clause(object): + continue + yield object + + if itersize and cls._has_id: + # this case concerns almost all data tables: formdata, card, users, roles + sql_statement = '''SELECT id FROM %s''' % cls._table_name + else: + # this case concerns aggregated views like wcs_all_forms (class + # AnyFormData) which does not have a surrogate key id column + sql_statement = '''SELECT %s FROM %s''' % ( + ', '.join(table_static_fields + cls.get_data_fields()), + cls._table_name, + ) where_clauses, parameters, func_clause = parse_clause(clause) if where_clauses: sql_statement += ' WHERE ' + ' AND '.join(where_clauses) @@ -1675,31 +1694,45 @@ class SqlMixin: sql_statement += ' OFFSET %(offset)s' parameters['offset'] = offset - if cls._iterate_on_server: - conn = get_connection(isolate=True) - cur = conn.cursor(name='select_iterator_%s' % uuid.uuid4()) - else: - conn, cur = get_connection_and_cursor() - cur.execute(sql_statement, parameters) - try: - for object in cls.get_objects(cur, iterator=True): - if object is None: - continue - if func_clause and not func_clause(object): - continue - yield object - finally: - cur.close() + conn, cur = get_connection_and_cursor() + with cur: + cur.execute(sql_statement, parameters) conn.commit() - if cls._iterate_on_server: - # close isolated connection - conn.close() + if itersize and cls._has_id: + sql_id_statement = '''SELECT %s FROM %s WHERE id IN %%s''' % ( + ', '.join(table_static_fields + cls.get_data_fields()), + cls._table_name, + ) + sql_id_statement += cls.get_order_by_clause(order_by) + ids = [row[0] for row in cur] + while ids: + cur.execute(sql_id_statement, [tuple(ids[:itersize])]) + conn.commit() + yield from retrieve() + ids = ids[itersize:] + else: + yield from retrieve() @classmethod @guard_postgres - def select(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None, iterator=False): + def select( + cls, + clause=None, + order_by=None, + ignore_errors=False, + limit=None, + offset=None, + iterator=False, + itersize=None, + ): + if iterator and not itersize: + itersize = 200 objects = cls.select_iterator( - clause=clause, order_by=order_by, ignore_errors=ignore_errors, limit=limit, offset=offset + clause=clause, + order_by=order_by, + ignore_errors=ignore_errors, + limit=limit, + offset=offset, ) func_clause = parse_clause(clause)[2] if func_clause and (limit or offset): @@ -3147,7 +3180,7 @@ class classproperty: class AnyFormData(SqlMixin): _table_name = 'wcs_all_forms' _formdef_cache = {} - _iterate_on_server = False + _has_id = False @classproperty def _table_static_fields(self): diff --git a/wcs/wf/jump.py b/wcs/wf/jump.py index 90bb634d..6308884e 100644 --- a/wcs/wf/jump.py +++ b/wcs/wf/jump.py @@ -350,7 +350,7 @@ def _apply_timeouts(publisher, **kwargs): (datetime.datetime.now() - datetime.timedelta(seconds=delay)).timetuple(), ), ] - formdatas = formdata_class.select_iterator(criterias, ignore_errors=True) + formdatas = formdata_class.select_iterator(criterias, ignore_errors=True, itersize=200) else: formdatas = formdata_class.get_with_indexed_value('status', status_id, ignore_errors=True) -- 2.33.0