From 5199e6ad31a717267be70eabf044079399336fe4 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Wed, 20 Oct 2021 10:49:44 +0200 Subject: [PATCH] sql: use batch iteration on ids instead of named cursors (#58013) Named cursors imposed the use of isolated connections and were misused resulting in reading using one SQL query by row (because of the use of .fetchone() with cursors). This commit revert to the behaviour of one connection per request and reading full SQL statement results at a time without using cursors. --- wcs/qommon/storage.py | 1 + wcs/sql.py | 131 ++++++++++++++++++++++++++++-------------- wcs/wf/jump.py | 2 +- 3 files changed, 91 insertions(+), 43 deletions(-) diff --git a/wcs/qommon/storage.py b/wcs/qommon/storage.py index 55202cef..ea0888a7 100644 --- a/wcs/qommon/storage.py +++ b/wcs/qommon/storage.py @@ -431,6 +431,7 @@ class StorableObject: limit=None, offset=None, iterator=False, + itersize=None, **kwargs, ): # iterator: only for compatibility with sql select() diff --git a/wcs/sql.py b/wcs/sql.py index 751c0426..973dbbc0 100644 --- a/wcs/sql.py +++ b/wcs/sql.py @@ -364,11 +364,11 @@ def site_unicode(value): return force_text(value, get_publisher().site_charset) -def get_connection(new=False, isolate=False): - if new and not isolate: +def get_connection(new=False): + if new: cleanup_connection() - if isolate or not getattr(get_publisher(), 'pgconn', None): + if not getattr(get_publisher(), 'pgconn', None): postgresql_cfg = {} for param in ('database', 'user', 'password', 'host', 'port'): value = get_cfg('postgresql', {}).get(param) @@ -377,11 +377,9 @@ def get_connection(new=False, isolate=False): try: pgconn = psycopg2.connect(**postgresql_cfg) except psycopg2.Error: - if new or isolate: + if new: raise pgconn = None - if isolate: - return pgconn get_publisher().pgconn = pgconn @@ -1430,7 +1428,7 @@ class SqlMixin: _table_name = None _numerical_id = True _table_select_skipped_fields = [] - _iterate_on_server = True + _has_id = True @classmethod @guard_postgres @@ -1650,55 +1648,104 @@ class SqlMixin: @classmethod @guard_postgres - def select_iterator(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None): + def select_iterator( + cls, + clause=None, + order_by=None, + ignore_errors=False, + limit=None, + offset=None, + itersize=None, + ): table_static_fields = [ x[0] if x[0] not in cls._table_select_skipped_fields else 'NULL AS %s' % x[0] for x in cls._table_static_fields ] - sql_statement = '''SELECT %s - FROM %s''' % ( - ', '.join(table_static_fields + cls.get_data_fields()), - cls._table_name, - ) - where_clauses, parameters, func_clause = parse_clause(clause) - if where_clauses: - sql_statement += ' WHERE ' + ' AND '.join(where_clauses) - sql_statement += cls.get_order_by_clause(order_by) - - if not func_clause: - if limit: - sql_statement += ' LIMIT %(limit)s' - parameters['limit'] = limit - if offset: - sql_statement += ' OFFSET %(offset)s' - parameters['offset'] = offset - - if cls._iterate_on_server: - conn = get_connection(isolate=True) - cur = conn.cursor(name='select_iterator_%s' % uuid.uuid4()) - else: - conn, cur = get_connection_and_cursor() - cur.execute(sql_statement, parameters) - try: + def retrieve(): for object in cls.get_objects(cur, iterator=True): if object is None: continue if func_clause and not func_clause(object): continue + print(object.id) yield object - finally: - cur.close() - conn.commit() - if cls._iterate_on_server: - # close isolated connection - conn.close() + + if itersize and cls._has_id: + sql_statement = '''SELECT id FROM %s''' % cls._table_name + where_clauses, parameters, func_clause = parse_clause(clause) + if where_clauses: + sql_statement += ' WHERE ' + ' AND '.join(where_clauses) + + sql_statement += cls.get_order_by_clause(order_by) + + if not func_clause: + if limit: + sql_statement += ' LIMIT %(limit)s' + parameters['limit'] = limit + if offset: + sql_statement += ' OFFSET %(offset)s' + parameters['offset'] = offset + + sql_id_statement = '''SELECT %s + FROM %s WHERE id IN %%s''' % ( + ', '.join(table_static_fields + cls.get_data_fields()), + cls._table_name, + ) + sql_id_statement += cls.get_order_by_clause(order_by) + + conn, cur = get_connection_and_cursor() + with cur: + cur.execute(sql_statement, parameters) + ids = [row[0] for row in cur] + while ids: + cur.execute(sql_id_statement, [tuple(ids[:itersize])]) + yield from retrieve() + ids = ids[itersize:] + else: # for AnyFormData + sql_statement = '''SELECT %s FROM %s''' % ( + ', '.join(table_static_fields + cls.get_data_fields()), + cls._table_name, + ) + where_clauses, parameters, func_clause = parse_clause(clause) + if where_clauses: + sql_statement += ' WHERE ' + ' AND '.join(where_clauses) + + sql_statement += cls.get_order_by_clause(order_by) + + if not func_clause: + if limit: + sql_statement += ' LIMIT %(limit)s' + parameters['limit'] = limit + if offset: + sql_statement += ' OFFSET %(offset)s' + parameters['offset'] = offset + + conn, cur = get_connection_and_cursor() + with cur: + cur.execute(sql_statement, parameters) + yield from retrieve() @classmethod @guard_postgres - def select(cls, clause=None, order_by=None, ignore_errors=False, limit=None, offset=None, iterator=False): + def select( + cls, + clause=None, + order_by=None, + ignore_errors=False, + limit=None, + offset=None, + iterator=False, + itersize=None, + ): + if iterator and not itersize: + itersize = 200 objects = cls.select_iterator( - clause=clause, order_by=order_by, ignore_errors=ignore_errors, limit=limit, offset=offset + clause=clause, + order_by=order_by, + ignore_errors=ignore_errors, + limit=limit, + offset=offset, ) func_clause = parse_clause(clause)[2] if func_clause and (limit or offset): @@ -3145,7 +3192,7 @@ class classproperty: class AnyFormData(SqlMixin): _table_name = 'wcs_all_forms' _formdef_cache = {} - _iterate_on_server = False + _has_id = False @classproperty def _table_static_fields(self): diff --git a/wcs/wf/jump.py b/wcs/wf/jump.py index dfb5caed..d9b84b86 100644 --- a/wcs/wf/jump.py +++ b/wcs/wf/jump.py @@ -345,7 +345,7 @@ def _apply_timeouts(publisher): (datetime.datetime.now() - datetime.timedelta(seconds=delay)).timetuple(), ), ] - formdatas = formdata_class.select_iterator(criterias, ignore_errors=True) + formdatas = formdata_class.select_iterator(criterias, ignore_errors=True, itersize=200) else: formdatas = formdata_class.get_with_indexed_value('status', status_id, ignore_errors=True) -- 2.33.0