Projet

Général

Profil

0001-sql-store-transient-data-in-a-specific-table-66620.patch

Frédéric Péters, 26 juin 2022 09:24

Télécharger (11,7 ko)

Voir les différences:

Subject: [PATCH] sql: store transient data in a specific table (#66620)

 tests/form_pages/test_all.py |   9 ++-
 tests/test_sessions.py       |  50 +++++++++++-
 tests/utilities.py           |   1 +
 wcs/publisher.py             |   1 +
 wcs/sql.py                   | 153 ++++++++++++++++++++++++++++++++++-
 5 files changed, 208 insertions(+), 6 deletions(-)
tests/form_pages/test_all.py
1521 1521
    user.store()
1522 1522
    resp = resp.form.submit('submit')
1523 1523
    resp = resp.follow()
1524
    assert 'The form has been recorded' in resp
1525
    assert formdef.data_class().count() == 1
1526
    assert formdef.data_class().select()[0].user_id is None
1524
    if pub.is_using_postgresql():
1525
        assert 'Sorry, your session have been lost.' in resp
1526
    else:
1527
        assert 'The form has been recorded' in resp
1528
        assert formdef.data_class().count() == 1
1529
        assert formdef.data_class().select()[0].user_id is None
1527 1530

  
1528 1531

  
1529 1532
def test_form_titles(pub):
tests/test_sessions.py
1 1
import datetime
2 2
import os
3
import pickle
3 4
import shutil
4 5
import time
5 6

  
6 7
import pytest
7 8

  
8
from wcs import fields
9
from wcs import fields, sql
9 10
from wcs.formdef import FormDef
10 11
from wcs.qommon.http_request import HTTPRequest
11 12
from wcs.qommon.ident.password_accounts import PasswordAccount
......
225 226
    user.is_active = False
226 227
    user.store()
227 228
    assert 'Logout' not in app.get('/')
229

  
230

  
231
def test_magictoken_migration(pub, app):
    """Check that magic tokens kept in the legacy session pickle get migrated.

    A submission is started, then its transient data is moved back from the
    transient_data table into the pickled session dictionary (the layout used
    before that table existed). Resuming the submission must still produce a
    formdata holding the value entered on the first page.
    """
232
    if not pub.is_using_postgresql():
233
        pytest.skip('this requires SQL')
234
        return  # never reached: pytest.skip() raises
235

  
236
    pub.session_manager.session_class.wipe()
237
    sql.TransientData.wipe()
238
    resp = app.get('/')
239

  
240
    formdef = FormDef()
241
    formdef.name = 'foobar'
242
    formdef.fields = [
243
        fields.PageField(id='0', label='1st PAGE', type='page'),
244
        fields.StringField(id='1', label='string', type='string'),
245
        fields.PageField(id='2', label='2nd PAGE', type='page'),
246
        fields.PageField(id='3', label='3rd PAGE', type='page'),
247
    ]
248
    formdef.store()
249
    formdef.data_class().wipe()
250

  
251
    resp = app.get('/foobar/')
252
    resp.form['f1'] = 'test'
253
    resp = resp.form.submit('submit')
254

  
255
    # migrate back session to look like before transient data table
256
    assert pub.session_manager.session_class.count() == 1
257
    session = pub.session_manager.session_class.select()[0]
258
    session.magictokens = {}
259
    for transient_data in sql.TransientData.select():
260
        session.magictokens[transient_data.id] = transient_data.data
261
    sql.TransientData.wipe()
262

  
263
    # Write the pickle with raw SQL: Session.store() would reset
    # 'magictokens' to None in the stored data and defeat the legacy layout.
    conn, cur = sql.get_connection_and_cursor()
264
    sql_statement = 'UPDATE sessions SET session_data = %s WHERE id = %s'
265
    cur.execute(sql_statement, (bytearray(pickle.dumps(session.__dict__, protocol=2)), session.id))
266
    conn.commit()
267
    cur.close()
268

  
269
    # and get back to submitting form, it should run migration
270
    resp = resp.form.submit('submit')  # -> 3rd page
271
    resp = resp.form.submit('submit')  # -> validation page
272
    resp = resp.form.submit('submit')  # -> submit
273
    assert formdef.data_class().count() == 1
274
    formdata = formdef.data_class().select()[0]
275
    assert formdata.data['1'] == 'test'
tests/utilities.py
175 175
        sql.do_tokens_table()
176 176
        sql.do_tracking_code_table()
177 177
        sql.do_session_table()
178
        sql.do_transient_data_table()
178 179
        sql.do_custom_views_table()
179 180
        sql.do_snapshots_table()
180 181
        sql.do_loggederrors_table()
wcs/publisher.py
358 358
        sql.do_role_table()
359 359
        sql.do_tracking_code_table()
360 360
        sql.do_custom_views_table()
361
        sql.do_transient_data_table()
361 362
        sql.do_snapshots_table()
362 363
        sql.do_loggederrors_table()
363 364
        sql.do_tokens_table()
wcs/sql.py
33 33
    import pickle
34 34

  
35 35
from django.utils.encoding import force_bytes, force_text
36
from django.utils.timezone import now
36 37
from quixote import get_publisher
37 38

  
38 39
import wcs.carddata
......
1250 1251
    cur.close()
1251 1252

  
1252 1253

  
1254
def do_transient_data_table():
    """Create the transient_data table if needed and drop obsolete columns.

    The table name comes from TransientData._table_name. Columns present in
    the database but not listed in TransientData._table_static_fields are
    dropped. Changes are committed before the cursor is closed.

    NOTE(review): no index is created on session_id, although
    TransientData.remove_for_session() and Session.get_by_magictoken() filter
    on that column; consider adding one (CREATE INDEX IF NOT EXISTS).
    """
1255
    conn, cur = get_connection_and_cursor()
1256
    table_name = TransientData._table_name
1257

  
1258
    # does the table already exist in the public schema?
    cur.execute(
1259
        '''SELECT COUNT(*) FROM information_schema.tables
1260
                    WHERE table_schema = 'public'
1261
                      AND table_name = %s''',
1262
        (table_name,),
1263
    )
1264
    if cur.fetchone()[0] == 0:
1265
        # column list kept in sync by hand with TransientData._table_static_fields
        cur.execute(
1266
            '''CREATE TABLE %s (id varchar PRIMARY KEY,
1267
                                session_id varchar,
1268
                                data bytea,
1269
                                last_update_time timestamptz
1270
                                )'''
1271
            % table_name
1272
        )
1273
    # collect the columns currently present on the table
    cur.execute(
1274
        '''SELECT column_name FROM information_schema.columns
1275
                    WHERE table_schema = 'public'
1276
                      AND table_name = %s''',
1277
        (table_name,),
1278
    )
1279
    existing_fields = {x[0] for x in cur.fetchall()}
1280
    needed_fields = {x[0] for x in TransientData._table_static_fields}
1281

  
1282
    # delete obsolete fields
    # NOTE(review): only removal is handled; a column listed in
    # _table_static_fields but missing from an existing table is never added
    # here, so a future column would need an explicit ADD COLUMN step.
1283
    for field in existing_fields - needed_fields:
1284
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))
1285

  
1286
    conn.commit()
1287
    cur.close()
1288

  
1289

  
1253 1290
def do_custom_views_table():
1254 1291
    conn, cur = get_connection_and_cursor()
1255 1292
    table_name = 'custom_views'
......
3275 3312
        return o
3276 3313

  
3277 3314

  
3315
class TransientData(SqlMixin):
    """Row of the transient_data table, keyed by a session magic token.

    Each instance carries the token (``id``), the owning session id
    (``session_id``) and a payload (``data``, pickled into a bytea column).
    ``last_update_time`` is set to the current time on every store().
    """
3316
    # table to keep some transient submission data out of global session dictionary
3317
    _table_name = 'transient_data'
3318
    _table_static_fields = [
3319
        ('id', 'varchar'),
3320
        ('session_id', 'varchar'),
3321
        ('data', 'bytea'),
3322
        ('last_update_time', 'timestamptz'),
3323
    ]
3324
    # ids are varchar tokens (see 'id' above), not serial integers;
    # presumably this tells SqlMixin not to treat them as numbers.
    _numerical_id = False
3325

  
3326
    def __init__(self, id, session_id, data):
3327
        self.id = id
3328
        self.session_id = session_id
3329
        self.data = data
3330

  
3331
    @guard_postgres
3332
    def store(self):
        """Insert or update this row and refresh last_update_time.

        An UPDATE is tried first; when RETURNING yields no row, an INSERT is
        issued instead. The transaction is committed before returning.

        NOTE(review): the UPDATE-then-INSERT sequence is not atomic. Two
        concurrent store() calls for a new token can both miss the UPDATE,
        and the second INSERT then fails on the id primary key.
        ``INSERT ... ON CONFLICT (id) DO UPDATE`` (PostgreSQL >= 9.5) would do
        it in one statement.
        """
3333
        sql_dict = {
3334
            'id': self.id,
3335
            'session_id': self.session_id,
3336
            'data': bytearray(pickle.dumps(self.data, protocol=2)),
3337
            'last_update_time': now(),
3338
        }
3339

  
3340
        conn, cur = get_connection_and_cursor()
3341
        column_names = sql_dict.keys()
3342
        sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
3343
            self._table_name,
3344
            ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
3345
        )
3346
        cur.execute(sql_statement, sql_dict)
3347
        # no row updated: this token has never been stored, insert it
        if cur.fetchone() is None:
3348
            sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
3349
                self._table_name,
3350
                ', '.join(column_names),
3351
                ', '.join(['%%(%s)s' % x for x in column_names]),
3352
            )
3353
            cur.execute(sql_statement, sql_dict)
3354

  
3355
        conn.commit()
3356
        cur.close()
3357

  
3358
    @classmethod
3359
    def _row2ob(cls, row, **kwargs):
        """Build an instance from a database row without calling __init__.

        Reads row[0] as id, row[1] as session_id and unpickles row[2] as data;
        presumably the row follows _table_static_fields order, as selected by
        SqlMixin. last_update_time is not loaded.
        """
3360
        o = cls.__new__(cls)
3361
        o.id = str_encode(row[0])
3362
        o.session_id = row[1]
3363
        o.data = pickle_loads(row[2])
3364
        return o
3365

  
3366
    @classmethod
3367
    def get_data_fields(cls):
        # no columns beyond _table_static_fields (presumably what SqlMixin
        # asks for when building SELECTs)
3368
        return []
3369

  
3370
    @classmethod
3371
    @guard_postgres
3372
    def remove_for_session(cls, session_id):
        """Delete every row belonging to *session_id* and commit."""
3373
        conn, cur = get_connection_and_cursor()
3374
        sql_statement = 'DELETE FROM %s WHERE ' % cls._table_name
3375
        sql_statement += 'session_id = %s'
3376
        cur.execute(sql_statement, (session_id,))
3377
        conn.commit()
3378
        cur.close()
3379

  
3380

  
3278 3381
class Session(SqlMixin, wcs.sessions.BasicSession):
3279 3382
    _table_name = 'sessions'
3280 3383
    _table_static_fields = [
......
3296 3399
        if self.message:
3297 3400
            # escape lazy gettext
3298 3401
            self.message = (self.message[0], str(self.message[1]))
3402

  
3403
        # store transient data
3404
        for v in (self.magictokens or {}).values():
3405
            v.store()
3406

  
3407
        # force to be empty, to make sure there's no leftover direct usage
3408
        session_data = copy.copy(self.__dict__)
3409
        session_data['magictokens'] = None
3410

  
3299 3411
        sql_dict = {
3300 3412
            'id': self.id,
3301
            'session_data': bytearray(pickle.dumps(self.__dict__, protocol=2)),
3413
            'session_data': bytearray(pickle.dumps(session_data, protocol=2)),
3302 3414
            # the other fields are stored to run optimized SELECT() against the
3303 3415
            # table, they are ignored when loading the data.
3304 3416
            'name_identifier': self.name_identifier,
......
3333 3445
        session_data = pickle_loads(row[1])
3334 3446
        for k, v in session_data.items():
3335 3447
            setattr(o, k, v)
3448
        if o.magictokens:
3449
            # migration, obsolete storage of magictokens in session
3450
            for k, v in o.magictokens.items():
3451
                o.add_magictoken(k, v)
3452
            o.magictokens = None
3453
            o.store()
3336 3454
        return o
3337 3455

  
3338 3456
    @classmethod
......
3375 3493
    def get_data_fields(cls):
3376 3494
        return []
3377 3495

  
3496
    def add_magictoken(self, token, data):
        """Register *data* under *token* and persist it as a TransientData row.

        The session is stored first when it has no id yet, so the new row can
        reference it through session_id.

        NOTE(review): when self.id is empty, self.store() runs while
        self.magictokens[token] may still hold the raw *data* (if the parent
        add_magictoken() puts it there). Session.store() calls .store() on
        every magictokens value, which would fail on a plain payload. Confirm
        the parent's behaviour.
        """
3497
        super().add_magictoken(token, data)
3498
        if not self.id:
3499
            self.store()
3500
        self.magictokens[token] = TransientData(id=token, session_id=self.id, data=data)
3501
        self.magictokens[token].store()
3502

  
3503
    def get_by_magictoken(self, token, default=None):
        """Return the data stored under *token* for this session, or *default*.

        The TransientData row is loaded from the database on first access
        (filtered on both session_id and token) and cached in
        self.magictokens. A miss is not cached, so an unknown token queries
        the database on every call.
        """
3504
        if not self.magictokens:
3505
            self.magictokens = {}
3506
        try:
3507
            if token not in self.magictokens:
3508
                self.magictokens[token] = TransientData.select(
3509
                    [Equal('session_id', self.id), Equal('id', token)]
3510
                )[0]
3511
            return self.magictokens[token].data
3512
        # an empty select() result means the token is unknown for this session
        except IndexError:
3513
            return default
3514

  
3515
    def remove_magictoken(self, token):
        """Drop *token* from the session and delete its transient_data row.

        NOTE(review): the row is removed by token id only, not restricted to
        this session's session_id; this relies on tokens being unguessable.
        """
3516
        super().remove_magictoken(token)
3517
        TransientData.remove_object(token)
3518

  
3519
    @classmethod
3520
    def remove_object(cls, id):
        """Delete session *id* along with all of its transient_data rows.

        NOTE(review): only this path cleans transient_data. Removals that
        bypass remove_object() (e.g. a bulk wipe()) would leave orphaned rows,
        and last_update_time is not used for any expiry visible here.
        """
3521
        TransientData.remove_for_session(id)
3522
        super().remove_object(id)
3523

  
3378 3524

  
3379 3525
class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
3380 3526
    _table_name = 'tracking_codes'
......
4141 4287
# latest migration, number + description (description is not used
4142 4288
# programmaticaly but will make sure git conflicts if two migrations are
4143 4289
# separately added with the same number)
4144
SQL_LEVEL = (63, 'add index on snapshot table')
4290
SQL_LEVEL = (64, 'add transient data table')
4145 4291

  
4146 4292

  
4147 4293
def migrate_global_views(conn, cur):
......
4306 4452
        # 37: create custom_views table
4307 4453
        # 44: add is_default column to custom_views table
4308 4454
        do_custom_views_table()
4455
    if sql_level < 64:
4456
        # 64: add transient data table
4457
        do_transient_data_table()
4309 4458
    if sql_level < 30:
4310 4459
        # 30: actually remove evo.who on anonymised formdatas
4311 4460
        from wcs.formdef import FormDef
4312
-