From 2ea5197ca6200d1728cc3d2cf95b6bd29833d9a5 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 16 Sep 2021 17:31:42 +0200 Subject: [PATCH] sql: lock wcs_meta during schema updates (#57017) --- wcs/sql.py | 114 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/wcs/sql.py b/wcs/sql.py index a535acb7..45598a9f 100644 --- a/wcs/sql.py +++ b/wcs/sql.py @@ -14,8 +14,10 @@ # You should have received a copy of the GNU General Public License # along with this program; if not, see . +import contextvars import copy import datetime +import functools import io import json import re @@ -388,6 +390,12 @@ def get_connection(new=False, isolate=False): return get_publisher().pgconn +def lock_wcs_meta(): + conn, cur = get_connection_and_cursor() + do_meta_table(conn, cur, insert_current_sql_level=False) + cur.execute('LOCK wcs_meta') + + def cleanup_connection(): if hasattr(get_publisher(), 'pgconn') and get_publisher().pgconn is not None: get_publisher().pgconn.close() @@ -510,7 +518,44 @@ def guard_postgres(func): return f +class RedoWithLock(Exception): + pass + + +locked = contextvars.ContextVar('locked', default=False) +guard_lock_depth = contextvars.ContextVar('guard_lock_depth', default=0) + + +def lock(): + if not locked.get(): + raise RedoWithLock + + +def guard_lock(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + depth = guard_lock_depth.get() + depth_token = guard_lock_depth.set(depth + 1) + try: + try: + return func(*args, **kwargs) + except RedoWithLock: + if depth > 0: + raise + token = locked.set(True) + try: + lock_wcs_meta() + return func(*args, **kwargs) + finally: + locked.reset(token) + finally: + guard_lock_depth.reset(depth_token) + + return wrapper + + @guard_postgres +@guard_lock def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild_global_views=True): if formdef.id is None: return [] @@ -537,6 +582,7 @@ def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id serial PRIMARY KEY, user_id varchar, @@ -601,53 +647,69 @@ def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild # migrations if 'fts' not in existing_fields: # full text search + lock() cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name) cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (table_name, table_name)) if 'workflow_roles' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles bytea''' % table_name) cur.execute('''ALTER TABLE %s ADD COLUMN workflow_roles_array text[]''' % table_name) if 'concerned_roles_array' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN concerned_roles_array text[]''' % table_name) if 'actions_roles_array' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN actions_roles_array text[]''' % table_name) if 'page_no' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN page_no varchar''' % table_name) if 'anonymised' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN anonymised timestamptz''' % table_name) if 'tracking_code' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN tracking_code varchar''' % table_name) if 'backoffice_submission' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN backoffice_submission boolean''' % table_name) if 'submission_context' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN submission_context bytea''' % table_name) if 'submission_agent_id' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN submission_agent_id varchar''' % table_name) if 'submission_channel' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN submission_channel varchar''' % table_name) if 'criticality_level' not in existing_fields: + lock() cur.execute( '''ALTER TABLE %s ADD COLUMN criticality_level integer NOT NULL DEFAULT(0)''' % table_name ) if 'last_update_time' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp''' % table_name) if 'digests' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN digests jsonb''' % table_name) if 'user_label' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN user_label varchar''' % table_name) if 'prefilling_data' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN prefilling_data bytea''' % table_name) # add new fields @@ -658,10 +720,12 @@ def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild continue needed_fields.add(get_field_id(field)) if get_field_id(field) not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, get_field_id(field), sql_type)) if field.store_display_value: needed_fields.add('%s_display' % get_field_id(field)) if '%s_display' % get_field_id(field) not in existing_fields: + lock() cur.execute( '''ALTER TABLE %s ADD COLUMN %s varchar''' % (table_name, '%s_display' % get_field_id(field)) @@ -669,6 +733,7 @@ def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild if field.store_structured_value: needed_fields.add('%s_structured' % get_field_id(field)) if '%s_structured' % get_field_id(field) not in existing_fields: + lock() cur.execute( '''ALTER TABLE %s ADD COLUMN %s bytea''' % (table_name, '%s_structured' % get_field_id(field)) @@ -678,10 +743,12 @@ def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild column_name = 'geoloc_%s' % field needed_fields.add(column_name) if column_name not in existing_fields: + lock() cur.execute('ALTER TABLE %s ADD COLUMN %s %s' '' % (table_name, column_name, 'POINT')) # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s CASCADE''' % (table_name, field)) # migrations on _evolutions table @@ -694,11 +761,13 @@ def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild ) evo_existing_fields = {x[0] for x in cur.fetchall()} if 'last_jump_datetime' not in evo_existing_fields: + lock() cur.execute('''ALTER TABLE %s_evolutions ADD COLUMN last_jump_datetime timestamp''' % table_name) if rebuild_views or len(existing_fields - needed_fields): # views may have been dropped when dropping columns, so we recreate # them even if not asked to. + lock() redo_views(conn, cur, formdef, rebuild_global_views=rebuild_global_views) if own_conn: @@ -718,6 +787,7 @@ def do_formdef_tables(formdef, conn=None, cur=None, rebuild_views=False, rebuild return actions +@guard_lock def do_formdef_indexes(formdef, created, conn, cur, concurrently=False): table_name = get_formdef_table_name(formdef) evolutions_table_name = table_name + '_evolutions' @@ -737,12 +807,14 @@ def do_formdef_indexes(formdef, created, conn, cur, concurrently=False): create_index = 'CREATE INDEX CONCURRENTLY' if evolutions_table_name + '_fid' not in existing_indexes: + lock() cur.execute( '''%s %s_fid ON %s (formdata_id)''' % (create_index, evolutions_table_name, evolutions_table_name) ) for attr in ('receipt_time', 'anonymised', 'user_id', 'status'): if table_name + '_' + attr + '_idx' not in existing_indexes: + lock() cur.execute( '%(create_index)s %(table_name)s_%(attr)s_idx ON %(table_name)s (%(attr)s)' % {'create_index': create_index, 'table_name': table_name, 'attr': attr} @@ -750,6 +822,7 @@ def do_formdef_indexes(formdef, created, conn, cur, concurrently=False): @guard_postgres +@guard_lock def do_user_table(): conn, cur = get_connection_and_cursor() table_name = 'users' @@ -761,6 +834,7 @@ def do_user_table(): (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id serial PRIMARY KEY, name varchar, @@ -813,10 +887,12 @@ def do_user_table(): continue needed_fields.add(get_field_id(field)) if get_field_id(field) not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, get_field_id(field), sql_type)) if field.store_display_value: needed_fields.add('%s_display' % get_field_id(field)) if '%s_display' % get_field_id(field) not in existing_fields: + lock() cur.execute( '''ALTER TABLE %s ADD COLUMN %s varchar''' % (table_name, '%s_display' % get_field_id(field)) @@ -824,6 +900,7 @@ def do_user_table(): if field.store_structured_value: needed_fields.add('%s_structured' % get_field_id(field)) if '%s_structured' % get_field_id(field) not in existing_fields: + lock() cur.execute( '''ALTER TABLE %s ADD COLUMN %s bytea''' % (table_name, '%s_structured' % get_field_id(field)) @@ -832,29 +909,36 @@ def do_user_table(): # migrations if 'fts' not in existing_fields: # full text search + lock() cur.execute('''ALTER TABLE %s ADD COLUMN fts tsvector''' % table_name) cur.execute('''CREATE INDEX %s_fts ON %s USING gin(fts)''' % (table_name, table_name)) if 'verified_fields' not in existing_fields: + lock() cur.execute('ALTER TABLE %s ADD COLUMN verified_fields text[]' % table_name) if 'ascii_name' not in existing_fields: + lock() cur.execute('ALTER TABLE %s ADD COLUMN ascii_name varchar' % table_name) if 'deleted_timestamp' not in existing_fields: + lock() cur.execute('ALTER TABLE %s ADD COLUMN deleted_timestamp timestamp' % table_name) if 'is_active' not in existing_fields: + lock() cur.execute('ALTER TABLE %s ADD COLUMN is_active bool DEFAULT TRUE' % table_name) cur.execute('UPDATE %s SET is_active = FALSE WHERE deleted_timestamp IS NOT NULL' % table_name) # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) conn.commit() try: + lock() cur.execute('CREATE INDEX users_name_idx ON users (name)') conn.commit() except psycopg2.ProgrammingError: @@ -863,6 +947,7 @@ def do_user_table(): cur.close() +@guard_lock def do_role_table(concurrently=False): conn, cur = get_connection_and_cursor() table_name = 'roles' @@ -874,6 +959,7 @@ def do_role_table(concurrently=False): (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id VARCHAR PRIMARY KEY, name VARCHAR, @@ -899,6 +985,7 @@ def do_role_table(concurrently=False): # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) conn.commit() @@ -913,6 +1000,7 @@ def migrate_legacy_roles(): role.store() +@guard_lock def do_tracking_code_table(): conn, cur = get_connection_and_cursor() table_name = 'tracking_codes' @@ -924,6 +1012,7 @@ def do_tracking_code_table(): (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id varchar PRIMARY KEY, formdef_id varchar, @@ -942,12 +1031,14 @@ def do_tracking_code_table(): # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) conn.commit() cur.close() +@guard_lock def do_session_table(): conn, cur = get_connection_and_cursor() table_name = 'sessions' @@ -959,6 +1050,7 @@ def do_session_table(): (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id varchar PRIMARY KEY, session_data bytea, @@ -979,17 +1071,20 @@ def do_session_table(): # migrations if 'last_update_time' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN last_update_time timestamp DEFAULT NOW()''' % table_name) cur.execute('''CREATE INDEX %s_ts ON %s (last_update_time)''' % (table_name, table_name)) # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) conn.commit() cur.close() +@guard_lock def do_custom_views_table(): conn, cur = get_connection_and_cursor() table_name = 'custom_views' @@ -1001,6 +1096,7 @@ def do_custom_views_table(): (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id varchar PRIMARY KEY, title varchar, @@ -1028,16 +1124,19 @@ def do_custom_views_table(): # migrations if 'is_default' not in existing_fields: + lock() cur.execute('''ALTER TABLE %s ADD COLUMN is_default boolean DEFAULT FALSE''' % table_name) # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) conn.commit() cur.close() +@guard_lock def do_snapshots_table(): conn, cur = get_connection_and_cursor() table_name = 'snapshots' @@ -1049,6 +1148,7 @@ def do_snapshots_table(): (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id SERIAL, object_type VARCHAR, @@ -1073,12 +1173,14 @@ def do_snapshots_table(): # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) conn.commit() cur.close() +@guard_lock def do_loggederrors_table(concurrently=False): conn, cur = get_connection_and_cursor() table_name = 'loggederrors' @@ -1090,6 +1192,7 @@ def do_loggederrors_table(concurrently=False): (table_name,), ) if cur.fetchone()[0] == 0: + lock() cur.execute( '''CREATE TABLE %s (id SERIAL PRIMARY KEY, tech_id VARCHAR UNIQUE, @@ -1123,6 +1226,7 @@ def do_loggederrors_table(concurrently=False): # delete obsolete fields for field in existing_fields - needed_fields: + lock() cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) create_index = 'CREATE INDEX' @@ -1142,6 +1246,7 @@ def do_loggederrors_table(concurrently=False): for attr in ('formdef_id', 'workflow_id'): if table_name + '_' + attr + '_idx' not in existing_indexes: + lock() cur.execute( '%(create_index)s %(table_name)s_%(attr)s_idx ON %(table_name)s (%(attr)s)' % {'create_index': create_index, 'table_name': table_name, 'attr': attr} @@ -1198,6 +1303,7 @@ def redo_views(conn, cur, formdef, rebuild_global_views=False): @guard_postgres +@guard_lock def drop_views(formdef, conn, cur): # remove the global views drop_global_views(conn, cur) @@ -1229,6 +1335,7 @@ def drop_views(formdef, conn, cur): view_names.append(row[0]) for view_name in view_names: + lock() cur.execute('''DROP VIEW IF EXISTS %s''' % view_name) @@ -1254,6 +1361,7 @@ def get_view_fields(formdef): @guard_postgres +@guard_lock def do_views(formdef, conn, cur, rebuild_global_views=True): # create new view table_name = get_formdef_table_name(formdef) @@ -1338,13 +1446,16 @@ def do_views(formdef, conn, cur, rebuild_global_views=True): fields_list = ', '.join(['%s AS %s' % (force_text(x), force_text(y)) for (x, y) in view_fields]) + lock() cur.execute('''CREATE VIEW %s AS SELECT %s FROM %s''' % (view_name, fields_list, table_name)) if rebuild_global_views: do_global_views(conn, cur) # recreate global views +@guard_lock def drop_global_views(conn, cur): + lock() cur.execute( '''SELECT table_name FROM information_schema.views WHERE table_schema = 'public' @@ -1364,6 +1475,7 @@ def drop_global_views(conn, cur): cur.execute('''DROP VIEW IF EXISTS wcs_all_forms''') +@guard_lock def do_global_views(conn, cur): # recreate global views from wcs.formdef import FormDef @@ -1387,6 +1499,7 @@ def do_global_views(conn, cur): if not view_names: return + lock() cur.execute('''DROP VIEW IF EXISTS wcs_all_forms CASCADE''') fake_formdef = FormDef() @@ -3510,6 +3623,7 @@ def migrate_views(conn, cur): @guard_postgres +@guard_lock def migrate(): conn, cur = get_connection_and_cursor() sql_level = get_sql_level(conn, cur) -- 2.33.0