Projet

Général

Profil

0001-sql-store-transient-data-in-a-specific-table-66620.patch

Frédéric Péters, 28 juin 2022 18:47

Télécharger (12,3 ko)

Voir les différences:

Subject: [PATCH] sql: store transient data in a specific table (#66620)

 tests/form_pages/test_all.py |   9 ++-
 tests/test_sessions.py       |  77 +++++++++++++++++-
 tests/utilities.py           |   1 +
 wcs/publisher.py             |   1 +
 wcs/sql.py                   | 147 ++++++++++++++++++++++++++++++++++-
 5 files changed, 229 insertions(+), 6 deletions(-)
tests/form_pages/test_all.py
1521 1521
    user.store()
1522 1522
    resp = resp.form.submit('submit')
1523 1523
    resp = resp.follow()
1524
    assert 'The form has been recorded' in resp
1525
    assert formdef.data_class().count() == 1
1526
    assert formdef.data_class().select()[0].user_id is None
1524
    if pub.is_using_postgresql():
1525
        assert 'Sorry, your session have been lost.' in resp
1526
    else:
1527
        assert 'The form has been recorded' in resp
1528
        assert formdef.data_class().count() == 1
1529
        assert formdef.data_class().select()[0].user_id is None
1527 1530

  
1528 1531

  
1529 1532
def test_form_titles(pub):
tests/test_sessions.py
1 1
import datetime
2 2
import os
3
import pickle
3 4
import shutil
4 5
import time
5 6

  
6 7
import pytest
7 8

  
8
from wcs import fields
9
from wcs import fields, sql
9 10
from wcs.formdef import FormDef
10 11
from wcs.qommon.http_request import HTTPRequest
11 12
from wcs.qommon.ident.password_accounts import PasswordAccount
......
225 226
    user.is_active = False
226 227
    user.store()
227 228
    assert 'Logout' not in app.get('/')
229

  
230

  
231
def test_transient_data_removal(pub, app):
    """Transient data rows must disappear together with their session on logout."""
    if not pub.is_using_postgresql():
        # pytest.skip() raises, so nothing after it runs
        pytest.skip('this requires SQL')

    pub.session_manager.session_class.wipe()
    sql.TransientData.wipe()
    app.get('/')

    formdef = FormDef()
    formdef.name = 'foobar'
    formdef.fields = [fields.StringField(id='1', label='string', type='string')]
    formdef.store()
    formdef.data_class().wipe()

    resp = app.get('/foobar/')
    resp.form['f1'] = 'test'
    resp.form.submit('submit')

    # after submitting, there is one session and one transient data row
    assert sql.Session.count() == 1
    assert sql.TransientData.count() == 1

    # logging out removes the session and its transient data
    app.get('/logout')
    assert sql.Session.count() == 0
    assert sql.TransientData.count() == 0
256

  
257

  
258
def test_magictoken_migration(pub, app):
    """Magic tokens still pickled inside a session are migrated to the transient data table."""
    if not pub.is_using_postgresql():
        # pytest.skip() raises, so nothing after it runs
        pytest.skip('this requires SQL')

    pub.session_manager.session_class.wipe()
    sql.TransientData.wipe()
    app.get('/')

    formdef = FormDef()
    formdef.name = 'foobar'
    formdef.fields = [
        fields.PageField(id='0', label='1st PAGE', type='page'),
        fields.StringField(id='1', label='string', type='string'),
        fields.PageField(id='2', label='2nd PAGE', type='page'),
        fields.PageField(id='3', label='3rd PAGE', type='page'),
    ]
    formdef.store()
    formdef.data_class().wipe()

    resp = app.get('/foobar/')
    resp.form['f1'] = 'test'
    resp = resp.form.submit('submit')  # -> 2nd page

    # rewrite the stored session so it looks like it did before the transient
    # data table existed: tokens back in session.magictokens, table emptied
    assert pub.session_manager.session_class.count() == 1
    session = pub.session_manager.session_class.select()[0]
    session.magictokens = {
        transient_data.id: transient_data.data for transient_data in sql.TransientData.select()
    }
    sql.TransientData.wipe()

    conn, cur = sql.get_connection_and_cursor()
    sql_statement = 'UPDATE sessions SET session_data = %s WHERE id = %s'
    cur.execute(sql_statement, (bytearray(pickle.dumps(session.__dict__, protocol=2)), session.id))
    conn.commit()
    cur.close()

    # resume submitting the form; loading the session should run the migration
    resp = resp.form.submit('submit')  # -> 3rd page
    resp = resp.form.submit('submit')  # -> validation page
    resp = resp.form.submit('submit')  # -> submit
    assert formdef.data_class().count() == 1
    formdata = formdef.data_class().select()[0]
    assert formdata.data['1'] == 'test'
tests/utilities.py
175 175
        sql.do_tokens_table()
176 176
        sql.do_tracking_code_table()
177 177
        sql.do_session_table()
178
        sql.do_transient_data_table()
178 179
        sql.do_custom_views_table()
179 180
        sql.do_snapshots_table()
180 181
        sql.do_loggederrors_table()
wcs/publisher.py
358 358
        sql.do_role_table()
359 359
        sql.do_tracking_code_table()
360 360
        sql.do_custom_views_table()
361
        sql.do_transient_data_table()
361 362
        sql.do_snapshots_table()
362 363
        sql.do_loggederrors_table()
363 364
        sql.do_tokens_table()
wcs/sql.py
33 33
    import pickle
34 34

  
35 35
from django.utils.encoding import force_bytes, force_text
36
from django.utils.timezone import now
36 37
from quixote import get_publisher
37 38

  
38 39
import wcs.carddata
......
1250 1251
    cur.close()
1251 1252

  
1252 1253

  
1254
def do_transient_data_table():
    """Create or update the table backing TransientData.

    The table is created when missing. Its session_id column references
    sessions(id) with ON DELETE CASCADE, so rows go away with their session.
    On an existing table, columns listed in TransientData._table_static_fields
    but absent are added, and columns no longer listed there are dropped.
    """
    conn, cur = get_connection_and_cursor()
    table_name = TransientData._table_name

    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id varchar PRIMARY KEY,
                                session_id VARCHAR REFERENCES sessions(id) ON DELETE CASCADE,
                                data bytea,
                                last_update_time timestamptz
                                )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}
    needed_fields = {x[0] for x in TransientData._table_static_fields}

    # add missing fields; without this, a column later added to
    # _table_static_fields would never reach already-created tables
    for field_name, field_type in TransientData._table_static_fields:
        if field_name not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, field_name, field_type))

    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()
1288

  
1289

  
1253 1290
def do_custom_views_table():
1254 1291
    conn, cur = get_connection_and_cursor()
1255 1292
    table_name = 'custom_views'
......
3275 3312
        return o
3276 3313

  
3277 3314

  
3315
class TransientData(SqlMixin):
    """Transient submission data kept in its own table, one row per token.

    This keeps such data out of the pickled global session dictionary; each
    row carries the id of the session it belongs to.
    """

    _table_name = 'transient_data'
    _table_static_fields = [
        ('id', 'varchar'),
        ('session_id', 'varchar'),
        ('data', 'bytea'),
        ('last_update_time', 'timestamptz'),
    ]
    # ids are varchar tokens, not integers
    _numerical_id = False

    def __init__(self, id, session_id, data):
        self.id = id
        self.session_id = session_id
        self.data = data

    @guard_postgres
    def store(self):
        """Insert or update this row, pickling ``data`` and stamping last_update_time.

        A single INSERT ... ON CONFLICT (id) DO UPDATE statement is used
        rather than an UPDATE followed by an INSERT. With the two-step
        approach, concurrent stores of the same id could both attempt the
        INSERT and fail on the primary key.
        """
        sql_dict = {
            'id': self.id,
            'session_id': self.session_id,
            'data': bytearray(pickle.dumps(self.data, protocol=2)),
            'last_update_time': now(),
        }

        conn, cur = get_connection_and_cursor()
        column_names = list(sql_dict.keys())
        sql_statement = '''INSERT INTO %s (%s) VALUES (%s)
                           ON CONFLICT (id) DO UPDATE SET %s''' % (
            self._table_name,
            ', '.join(column_names),
            ', '.join(['%%(%s)s' % x for x in column_names]),
            # the conflicting id is unchanged; refresh every other column
            ', '.join(['%s = EXCLUDED.%s' % (x, x) for x in column_names if x != 'id']),
        )
        cur.execute(sql_statement, sql_dict)

        conn.commit()
        cur.close()

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a TransientData object from a database row.

        Row columns are expected in _table_static_fields order (id,
        session_id, data, ...); last_update_time is not loaded.
        """
        o = cls.__new__(cls)
        o.id = str_encode(row[0])
        o.session_id = row[1]
        # data was pickled by store(); it is unpickled here, so the table must
        # only ever contain trusted content
        o.data = pickle_loads(row[2])
        return o

    @classmethod
    def get_data_fields(cls):
        # no fields beyond _table_static_fields
        return []

    @classmethod
    @guard_postgres
    def remove_for_session(cls, session_id):
        """Delete every row attached to *session_id*."""
        conn, cur = get_connection_and_cursor()
        sql_statement = 'DELETE FROM %s WHERE ' % cls._table_name
        sql_statement += 'session_id = %s'
        cur.execute(sql_statement, (session_id,))
        conn.commit()
        cur.close()
3379

  
3380

  
3278 3381
class Session(SqlMixin, wcs.sessions.BasicSession):
3279 3382
    _table_name = 'sessions'
3280 3383
    _table_static_fields = [
......
3296 3399
        if self.message:
3297 3400
            # escape lazy gettext
3298 3401
            self.message = (self.message[0], str(self.message[1]))
3402

  
3403
        # store transient data
3404
        for v in (self.magictokens or {}).values():
3405
            v.store()
3406

  
3407
        # force to be empty, to make sure there's no leftover direct usage
3408
        session_data = copy.copy(self.__dict__)
3409
        session_data['magictokens'] = None
3410

  
3299 3411
        sql_dict = {
3300 3412
            'id': self.id,
3301
            'session_data': bytearray(pickle.dumps(self.__dict__, protocol=2)),
3413
            'session_data': bytearray(pickle.dumps(session_data, protocol=2)),
3302 3414
            # the other fields are stored to run optimized SELECT() against the
3303 3415
            # table, they are ignored when loading the data.
3304 3416
            'name_identifier': self.name_identifier,
......
3333 3445
        session_data = pickle_loads(row[1])
3334 3446
        for k, v in session_data.items():
3335 3447
            setattr(o, k, v)
3448
        if o.magictokens:
3449
            # migration, obsolete storage of magictokens in session
3450
            for k, v in o.magictokens.items():
3451
                o.add_magictoken(k, v)
3452
            o.magictokens = None
3453
            o.store()
3336 3454
        return o
3337 3455

  
3338 3456
    @classmethod
......
3375 3493
    def get_data_fields(cls):
3376 3494
        return []
3377 3495

  
3496
    def add_magictoken(self, token, data):
        """Register *data* under *token* and persist it as a TransientData row.

        The session must already have an id, as the row is linked to it.
        """
        assert self.id
        super().add_magictoken(token, data)
        # whatever the base class recorded under the token is replaced by a
        # TransientData wrapper, which is then saved immediately
        transient_data = TransientData(id=token, session_id=self.id, data=data)
        self.magictokens[token] = transient_data
        transient_data.store()
3501

  
3502
    def get_by_magictoken(self, token, default=None):
        """Return the data stored under *token* for this session, or *default*.

        Already-loaded entries are served from ``self.magictokens``. Otherwise
        the TransientData row matching both this session id and the token is
        fetched and cached there.
        """
        if not self.magictokens:
            self.magictokens = {}
        cache = self.magictokens
        if token not in cache:
            try:
                cache[token] = TransientData.select([Equal('session_id', self.id), Equal('id', token)])[0]
            except IndexError:
                # no row for this token in this session
                return default
        return cache[token].data
3513

  
3514
    def remove_magictoken(self, token):
        """Forget *token* in the base session, then delete its TransientData row."""
        base_remove = super().remove_magictoken
        base_remove(token)
        # the row id is the token itself (see add_magictoken)
        TransientData.remove_object(token)
3517

  
3378 3518

  
3379 3519
class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode):
3380 3520
    _table_name = 'tracking_codes'
......
4141 4281
# latest migration, number + description (description is not used
4142 4282
# programmatically but will ensure a git conflict if two migrations are
4143 4283
# separately added with the same number)
4144
SQL_LEVEL = (63, 'add index on snapshot table')
4284
SQL_LEVEL = (64, 'add transient data table')
4145 4285

  
4146 4286

  
4147 4287
def migrate_global_views(conn, cur):
......
4306 4446
        # 37: create custom_views table
4307 4447
        # 44: add is_default column to custom_views table
4308 4448
        do_custom_views_table()
4449
    if sql_level < 64:
4450
        # 64: add transient data table
4451
        do_transient_data_table()
4309 4452
    if sql_level < 30:
4310 4453
        # 30: actually remove evo.who on anonymised formdatas
4311 4454
        from wcs.formdef import FormDef
4312
-