Projet

Général

Profil

0002-misc-move-tokens-to-an-SQL-table-60665.patch

Frédéric Péters, 18 janvier 2022 17:51

Télécharger (13,6 ko)

Voir les différences:

Subject: [PATCH 2/5] misc: move tokens to an SQL table (#60665)

 tests/test_token.py          |  64 ++++++++++++++++
 tests/utilities.py           |   4 +
 wcs/publisher.py             |   3 +
 wcs/qommon/ident/password.py |   2 +-
 wcs/qommon/tokens.py         |  22 ++++--
 wcs/sql.py                   | 142 ++++++++++++++++++++++++++++++++++-
 6 files changed, 229 insertions(+), 8 deletions(-)
 create mode 100644 tests/test_token.py
tests/test_token.py
1
import os
2
import time
3

  
4
import pytest
5
from django.utils.timezone import now
6
from quixote import get_publisher
7

  
8
import wcs.sql
9
from wcs.qommon.tokens import Token
10

  
11
from .utilities import create_temporary_pub
12

  
13

  
14
def pytest_generate_tests(metafunc):
    """Run each test that requests ``two_pubs`` once per storage backend."""
    if 'two_pubs' not in metafunc.fixturenames:
        return
    backends = ['pickle', 'sql']
    metafunc.parametrize('two_pubs', backends, indirect=True)
17

  
18

  
19
@pytest.fixture
def pub(request):
    """Temporary publisher configured for the SQL backend."""
    sql_publisher = create_temporary_pub(sql_mode=True)
    return sql_publisher
22

  
23

  
24
@pytest.fixture
def two_pubs(request):
    """Temporary publisher for the backend named by ``request.param``.

    ``request.param`` is ``'pickle'`` or ``'sql'`` (see pytest_generate_tests);
    only ``'sql'`` enables SQL mode.
    """
    use_sql = request.param == 'sql'
    return create_temporary_pub(sql_mode=use_sql)
27

  
28

  
29
def test_migrate_to_sql(pub):
    """Legacy pickle tokens are copied into SQL; expired ones are dropped."""
    get_publisher().token_class.wipe()
    assert get_publisher().token_class.count() == 0
    token = Token()
    token.expiration = time.time() + 86400  # expiration stored as timestamp
    token.context = {'a': 'b'}
    token.store()
    assert os.path.exists(token.get_object_filename())

    token2 = Token()
    token2.expiration = time.time() - 86400  # already expired
    token2.context = {'a': 'b'}
    token2.store()
    assert os.path.exists(token2.get_object_filename())

    wcs.sql.migrate_legacy_tokens()
    # the valid pickle file is left in place, the expired one is removed
    assert os.path.exists(token.get_object_filename())
    assert not os.path.exists(token2.get_object_filename())
    os.unlink(token.get_object_filename())
    assert get_publisher().token_class.count() == 1
    sql_token = get_publisher().token_class.get(token.id)
    assert sql_token.id == token.id
    assert sql_token.context == token.context
    # Compare instants rather than calendar years: a year comparison fails
    # when the test runs within a day of New Year's Eve.
    remaining = (sql_token.expiration - now()).total_seconds()
    assert abs(remaining - 86400) < 60
53

  
54

  
55
def test_expiration(two_pubs):
    """Tokens can be fetched until they expire; expired ones raise KeyError."""
    token_class = get_publisher().token_class
    token_class.wipe()

    token = token_class()
    token.store()
    # call get() on the class: instantiating a token just to look one up
    # would needlessly generate (and, in SQL mode, check) a new random id
    assert token_class.get(token.id)

    token = token_class(expiration_delay=-3600)  # already expired
    token.store()
    with pytest.raises(KeyError):
        token_class.get(token.id)
tests/utilities.py
21 21
from wcs import compat, custom_views, sessions, sql
22 22
from wcs.qommon import force_str
23 23
from wcs.qommon.errors import ConnectionError
24
from wcs.qommon.tokens import Token
24 25
from wcs.roles import Role
25 26
from wcs.tracking_code import TrackingCode
26 27
from wcs.users import User
......
78 79
    if sql_mode:
79 80
        pub.user_class = sql.SqlUser
80 81
        pub.role_class = sql.Role
82
        pub.token_class = sql.Token
81 83
        pub.tracking_code_class = sql.TrackingCode
82 84
        pub.session_class = sql.Session
83 85
        pub.custom_view_class = sql.CustomView
......
87 89
    else:
88 90
        pub.user_class = User
89 91
        pub.role_class = Role
92
        pub.token_class = Token
90 93
        pub.tracking_code_class = TrackingCode
91 94
        pub.session_class = sessions.BasicSession
92 95
        pub.custom_view_class = custom_views.CustomView
......
169 172

  
170 173
        sql.do_user_table()
171 174
        sql.do_role_table()
175
        sql.do_tokens_table()
172 176
        sql.do_tracking_code_table()
173 177
        sql.do_session_table()
174 178
        sql.do_custom_views_table()
wcs/publisher.py
34 34
from .Defaults import *  # noqa pylint: disable=wildcard-import
35 35
from .qommon.cron import CronJob
36 36
from .qommon.publisher import QommonPublisher, get_request, set_publisher_class
37
from .qommon.tokens import Token
37 38
from .roles import Role
38 39
from .root import RootDirectory
39 40
from .tracking_code import TrackingCode
......
144 145

  
145 146
            self.user_class = sql.SqlUser
146 147
            self.role_class = sql.Role
148
            self.token_class = sql.Token
147 149
            self.tracking_code_class = sql.TrackingCode
148 150
            self.session_class = sql.Session
149 151
            self.custom_view_class = sql.CustomView
......
153 155
        else:
154 156
            self.user_class = User
155 157
            self.role_class = Role
158
            self.token_class = Token
156 159
            self.tracking_code_class = TrackingCode
157 160
            self.session_class = sessions.BasicSession
158 161
            self.custom_view_class = custom_views.CustomView
wcs/qommon/ident/password.py
369 369
            'change_url': get_request().get_frontoffice_url() + '?t=%s&a=cfmpw' % token.id,
370 370
            'cancel_url': get_request().get_frontoffice_url() + '?t=%s&a=cxlpw' % token.id,
371 371
            'token': token.id,
372
            'time': misc.localstrftime(time.localtime(token.expiration)),
372
            'time': misc.localstrftime(token.expiration),
373 373
        }
374 374

  
375 375
        try:
wcs/qommon/tokens.py
14 14
# You should have received a copy of the GNU General Public License
15 15
# along with this program; if not, see <http://www.gnu.org/licenses/>.
16 16

  
17
import datetime
17 18
import random
18 19
import string
19
import time
20

  
21
from django.utils.timezone import make_aware, now
20 22

  
21 23
from .storage import StorableObject
22 24

  
......
29 31
    context = None
30 32

  
31 33
    def __init__(self, expiration_delay=86400, size=16, chars=None):
32
        chars = chars or list(string.digits + string.ascii_letters)
33
        StorableObject.__init__(self, self.get_new_id(size, chars))
34
        super().__init__(self.get_new_id(size, chars))
34 35
        if expiration_delay:
35
            self.expiration = time.time() + expiration_delay
36
            self.set_expiration_delay(expiration_delay)
37

  
38
    def set_expiration_delay(self, expiration_delay):
        """Make the token expire ``expiration_delay`` seconds from now.

        ``expiration`` becomes a datetime based on django's ``now()``; a
        negative delay yields an already-expired token.
        """
        delay = datetime.timedelta(seconds=expiration_delay)
        self.expiration = now() + delay
36 40

  
37 41
    @classmethod
38
    def get_new_id(cls, size, chars):
42
    def get_new_id(cls, size=16, chars=None):
43
        chars = chars or list(string.digits + string.ascii_letters)
39 44
        r = random.SystemRandom()
40 45
        while True:
41 46
            id = ''.join([r.choice(chars) for x in range(size)])
......
43 48
                return id
44 49

  
45 50
    def migrate(self):
46
        if self.expiration and self.expiration < time.time():
51
        if isinstance(self.expiration, (float, int)):
52
            self.expiration = make_aware(datetime.datetime.fromtimestamp(self.expiration))
53
        self.expiration_check()
54

  
55
    def expiration_check(self):
        """Delete the token and raise KeyError if its expiration has passed.

        Tokens without an expiration never expire.
        """
        if not self.expiration or self.expiration >= now():
            return
        self.remove_self()
        raise KeyError()
wcs/sql.py
40 40
import wcs.custom_views
41 41
import wcs.formdata
42 42
import wcs.logged_errors
43
import wcs.qommon.tokens
43 44
import wcs.roles
44 45
import wcs.snapshots
45 46
import wcs.tracking_code
......
1175 1176
    cur.close()
1176 1177

  
1177 1178

  
1179
@guard_postgres
def do_tokens_table(concurrently=False):
    """Create the tokens table if needed and align its columns with Token.

    Columns declared in ``Token._table_static_fields`` but missing from an
    existing table are added; columns no longer declared are dropped.

    :param concurrently: accepted for signature compatibility; unused here.
    """
    conn, cur = get_connection_and_cursor()
    table_name = Token._table_name

    cur.execute(
        '''SELECT COUNT(*) FROM information_schema.tables
                    WHERE table_schema = 'public'
                      AND table_name = %s''',
        (table_name,),
    )
    if cur.fetchone()[0] == 0:
        cur.execute(
            '''CREATE TABLE %s (id VARCHAR PRIMARY KEY,
                                type VARCHAR,
                                expiration TIMESTAMPTZ,
                                context JSONB
                               )'''
            % table_name
        )
    cur.execute(
        '''SELECT column_name FROM information_schema.columns
                    WHERE table_schema = 'public'
                      AND table_name = %s''',
        (table_name,),
    )
    existing_fields = {x[0] for x in cur.fetchall()}

    # add fields declared on the class but absent from the table (e.g. a
    # column introduced after the table was first created)
    for field_name, field_type in Token._table_static_fields:
        if field_name not in existing_fields:
            cur.execute('''ALTER TABLE %s ADD COLUMN %s %s''' % (table_name, field_name, field_type))

    needed_fields = {x[0] for x in Token._table_static_fields}

    # delete obsolete fields
    for field in existing_fields - needed_fields:
        cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field))

    conn.commit()
    cur.close()
1214

  
1215

  
1216
def migrate_legacy_tokens():
    """Copy tokens stored as pickle files into the SQL tokens table.

    Tokens that cannot be loaded with a KeyError (missing or expired; expired
    ones are presumably deleted on load by Token.migrate) are skipped.
    Python 2 era pickles that fail with AttributeError are deleted. The
    pickle files of migrated tokens are left in place.
    """
    legacy_class = wcs.qommon.tokens.Token
    for legacy_id in legacy_class.keys():
        try:
            legacy_token = legacy_class.get(legacy_id)
        except KeyError:
            continue
        except AttributeError:
            # old python2 tokens:
            # AttributeError: module 'builtins' has no attribute 'unicode'
            legacy_class.remove_object(legacy_id)
            continue
        # re-class the loaded object so store() writes it to SQL
        legacy_token.__class__ = Token
        legacy_token.store()
1230

  
1231

  
1178 1232
@guard_postgres
1179 1233
def do_meta_table(conn=None, cur=None, insert_current_sql_level=True):
1180 1234
    own_conn = False
......
1662 1716

  
1663 1717
        return ' ORDER BY %s %s' % (order_by, direction)
1664 1718

  
1719
    @classmethod
    @guard_postgres
    def has_key(cls, id):
        """Tell whether a row with the given ``id`` exists in ``cls._table_name``.

        The transaction is committed after the lookup.
        """
        conn, cur = get_connection_and_cursor()
        query = 'SELECT EXISTS(SELECT 1 FROM %s WHERE id = %%s)' % cls._table_name
        with cur:
            cur.execute(query, (id,))
            first_row = cur.fetchall()[0]
            found = first_row[0]
            conn.commit()
        return found
1729

  
1665 1730
    @classmethod
1666 1731
    @guard_postgres
1667 1732
    def select_iterator(
......
3193 3258
        cur.close()
3194 3259

  
3195 3260

  
3261
class Token(SqlMixin, wcs.qommon.tokens.Token):
    """Token stored as a row of the ``tokens`` SQL table.

    Each entry of ``_table_static_fields`` maps an attribute to a column;
    ``context`` is kept in a JSONB column.
    """

    _table_name = 'tokens'
    _table_static_fields = [
        ('id', 'varchar'),
        ('type', 'varchar'),
        ('expiration', 'timestamptz'),
        ('context', 'jsonb'),
    ]

    # token ids are random strings (see get_new_id), not integers
    _numerical_id = False

    @guard_postgres
    def store(self):
        """Insert or update this token's row.

        Without an id, a new one is generated and the INSERT is retried with
        a fresh id on IntegrityError (presumably an id collision). With an
        id, an UPDATE is attempted first and the row is inserted if no row
        matched.
        """
        sql_dict = {
            'id': self.id,
            'type': self.type,
            'expiration': self.expiration,
            'context': self.context,
        }

        conn, cur = get_connection_and_cursor()
        column_names = sql_dict.keys()

        if not self.id:
            sql_dict['id'] = self.get_new_id()
            sql_statement = '''INSERT INTO %s (%s)
                               VALUES (%s)
                               RETURNING id''' % (
                self._table_name,
                ', '.join(column_names),
                ', '.join(['%%(%s)s' % x for x in column_names]),
            )
            while True:
                try:
                    cur.execute(sql_statement, sql_dict)
                except psycopg2.IntegrityError:
                    conn.rollback()
                    sql_dict['id'] = self.get_new_id()
                else:
                    break
            self.id = cur.fetchone()[0]
        else:
            sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % (
                self._table_name,
                ', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]),
            )
            cur.execute(sql_statement, sql_dict)
            if cur.fetchone() is None:
                sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % (
                    self._table_name,
                    ', '.join(column_names),
                    ', '.join(['%%(%s)s' % x for x in column_names]),
                )
                cur.execute(sql_statement, sql_dict)

        conn.commit()
        cur.close()

    @classmethod
    def get_data_fields(cls):
        """Tokens have no data fields beyond the static columns."""
        return []

    @classmethod
    def _row2ob(cls, row, **kwargs):
        """Build a token from a database row.

        Raises KeyError (after deleting the row, via expiration_check) when
        the token has expired.
        """
        # Bypass __init__: it would generate a throw-away random id
        # (get_new_id loops until the id is unused, presumably querying the
        # table through has_key) and an expiration, both overwritten below.
        o = cls.__new__(cls)
        for field, value in zip(cls._table_static_fields, tuple(row)):
            setattr(o, field[0], value)
        o.expiration_check()
        return o
3330

  
3331

  
3196 3332
class classproperty:
3197 3333
    def __init__(self, f):
3198 3334
        self.f = f
......
3509 3645
# latest migration, number + description (description is not used
3510 3646
# programmaticaly but will make sure git conflicts if two migrations are
3511 3647
# separately added with the same number)
3512
SQL_LEVEL = (56, 'add gin indexes to concerned_roles_array and actions_roles_array')
3648
SQL_LEVEL = (57, 'store tokens in SQL')
3513 3649

  
3514 3650

  
3515 3651
def migrate_global_views(conn, cur):
......
3701 3837
        # 50: switch role uuid column to varchar
3702 3838
        do_role_table()
3703 3839
        migrate_legacy_roles()
3840
    if sql_level < 57:
3841
        # 57: store tokens in SQL
3842
        do_tokens_table()
3843
        migrate_legacy_tokens()
3704 3844

  
3705 3845
    cur.execute('''UPDATE wcs_meta SET value = %s WHERE key = %s''', (str(SQL_LEVEL[0]), 'sql_level'))
3706 3846

  
3707
-