0002-misc-move-tokens-to-an-SQL-table-60665.patch
tests/test_token.py | ||
---|---|---|
1 |
import os |
|
2 |
import time |
|
3 | ||
4 |
import pytest |
|
5 |
from django.utils.timezone import now |
|
6 |
from quixote import get_publisher |
|
7 | ||
8 |
import wcs.sql |
|
9 |
from wcs.qommon.tokens import Token |
|
10 | ||
11 |
from .utilities import create_temporary_pub |
|
12 | ||
13 | ||
14 |
def pytest_generate_tests(metafunc): |
|
15 |
if 'two_pubs' in metafunc.fixturenames: |
|
16 |
metafunc.parametrize('two_pubs', ['pickle', 'sql'], indirect=True) |
|
17 | ||
18 | ||
19 |
@pytest.fixture |
|
20 |
def pub(request): |
|
21 |
return create_temporary_pub(sql_mode=True) |
|
22 | ||
23 | ||
24 |
@pytest.fixture |
|
25 |
def two_pubs(request): |
|
26 |
return create_temporary_pub(sql_mode=(request.param == 'sql')) |
|
27 | ||
28 | ||
29 |
def test_migrate_to_sql(pub): |
|
30 |
get_publisher().token_class.wipe() |
|
31 |
assert get_publisher().token_class.count() == 0 |
|
32 |
token = Token() |
|
33 |
token.expiration = time.time() + 86400 # expiration stored as timestamp |
|
34 |
token.context = {'a': 'b'} |
|
35 |
token.store() |
|
36 |
assert os.path.exists(token.get_object_filename()) |
|
37 | ||
38 |
token2 = Token() |
|
39 |
token2.expiration = time.time() - 86400 # already expired |
|
40 |
token2.context = {'a': 'b'} |
|
41 |
token2.store() |
|
42 |
assert os.path.exists(token2.get_object_filename()) |
|
43 | ||
44 |
wcs.sql.migrate_legacy_tokens() |
|
45 |
assert os.path.exists(token.get_object_filename()) |
|
46 |
assert not os.path.exists(token2.get_object_filename()) |
|
47 |
os.unlink(token.get_object_filename()) |
|
48 |
assert get_publisher().token_class.count() == 1 |
|
49 |
sql_token = get_publisher().token_class.get(token.id) |
|
50 |
assert sql_token.id == token.id |
|
51 |
assert sql_token.context == token.context |
|
52 |
assert sql_token.expiration.year == now().year |
|
53 | ||
54 | ||
55 |
def test_expiration(two_pubs): |
|
56 |
get_publisher().token_class.wipe() |
|
57 |
token = get_publisher().token_class() |
|
58 |
token.store() |
|
59 |
assert get_publisher().token_class().get(token.id) |
|
60 | ||
61 |
token = get_publisher().token_class(expiration_delay=-3600) # already expired |
|
62 |
token.store() |
|
63 |
with pytest.raises(KeyError): |
|
64 |
assert get_publisher().token_class().get(token.id) |
tests/utilities.py | ||
---|---|---|
21 | 21 |
from wcs import compat, custom_views, sessions, sql |
22 | 22 |
from wcs.qommon import force_str |
23 | 23 |
from wcs.qommon.errors import ConnectionError |
24 |
from wcs.qommon.tokens import Token |
|
24 | 25 |
from wcs.roles import Role |
25 | 26 |
from wcs.tracking_code import TrackingCode |
26 | 27 |
from wcs.users import User |
... | ... | |
78 | 79 |
if sql_mode: |
79 | 80 |
pub.user_class = sql.SqlUser |
80 | 81 |
pub.role_class = sql.Role |
82 |
pub.token_class = sql.Token |
|
81 | 83 |
pub.tracking_code_class = sql.TrackingCode |
82 | 84 |
pub.session_class = sql.Session |
83 | 85 |
pub.custom_view_class = sql.CustomView |
... | ... | |
87 | 89 |
else: |
88 | 90 |
pub.user_class = User |
89 | 91 |
pub.role_class = Role |
92 |
pub.token_class = Token |
|
90 | 93 |
pub.tracking_code_class = TrackingCode |
91 | 94 |
pub.session_class = sessions.BasicSession |
92 | 95 |
pub.custom_view_class = custom_views.CustomView |
... | ... | |
169 | 172 | |
170 | 173 |
sql.do_user_table() |
171 | 174 |
sql.do_role_table() |
175 |
sql.do_tokens_table() |
|
172 | 176 |
sql.do_tracking_code_table() |
173 | 177 |
sql.do_session_table() |
174 | 178 |
sql.do_custom_views_table() |
wcs/publisher.py | ||
---|---|---|
34 | 34 |
from .Defaults import * # noqa pylint: disable=wildcard-import |
35 | 35 |
from .qommon.cron import CronJob |
36 | 36 |
from .qommon.publisher import QommonPublisher, get_request, set_publisher_class |
37 |
from .qommon.tokens import Token |
|
37 | 38 |
from .roles import Role |
38 | 39 |
from .root import RootDirectory |
39 | 40 |
from .tracking_code import TrackingCode |
... | ... | |
144 | 145 | |
145 | 146 |
self.user_class = sql.SqlUser |
146 | 147 |
self.role_class = sql.Role |
148 |
self.token_class = sql.Token |
|
147 | 149 |
self.tracking_code_class = sql.TrackingCode |
148 | 150 |
self.session_class = sql.Session |
149 | 151 |
self.custom_view_class = sql.CustomView |
... | ... | |
153 | 155 |
else: |
154 | 156 |
self.user_class = User |
155 | 157 |
self.role_class = Role |
158 |
self.token_class = Token |
|
156 | 159 |
self.tracking_code_class = TrackingCode |
157 | 160 |
self.session_class = sessions.BasicSession |
158 | 161 |
self.custom_view_class = custom_views.CustomView |
wcs/qommon/ident/password.py | ||
---|---|---|
369 | 369 |
'change_url': get_request().get_frontoffice_url() + '?t=%s&a=cfmpw' % token.id, |
370 | 370 |
'cancel_url': get_request().get_frontoffice_url() + '?t=%s&a=cxlpw' % token.id, |
371 | 371 |
'token': token.id, |
372 |
'time': misc.localstrftime(time.localtime(token.expiration)),
|
|
372 |
'time': misc.localstrftime(token.expiration),
|
|
373 | 373 |
} |
374 | 374 | |
375 | 375 |
try: |
wcs/qommon/tokens.py | ||
---|---|---|
14 | 14 |
# You should have received a copy of the GNU General Public License |
15 | 15 |
# along with this program; if not, see <http://www.gnu.org/licenses/>. |
16 | 16 | |
17 |
import datetime |
|
17 | 18 |
import random |
18 | 19 |
import string |
19 |
import time |
|
20 | ||
21 |
from django.utils.timezone import make_aware, now |
|
20 | 22 | |
21 | 23 |
from .storage import StorableObject |
22 | 24 | |
... | ... | |
29 | 31 |
context = None |
30 | 32 | |
31 | 33 |
def __init__(self, expiration_delay=86400, size=16, chars=None): |
32 |
chars = chars or list(string.digits + string.ascii_letters) |
|
33 |
StorableObject.__init__(self, self.get_new_id(size, chars)) |
|
34 |
super().__init__(self.get_new_id(size, chars)) |
|
34 | 35 |
if expiration_delay: |
35 |
self.expiration = time.time() + expiration_delay |
|
36 |
self.set_expiration_delay(expiration_delay) |
|
37 | ||
38 |
def set_expiration_delay(self, expiration_delay): |
|
39 |
self.expiration = now() + datetime.timedelta(seconds=expiration_delay) |
|
36 | 40 | |
37 | 41 |
@classmethod |
38 |
def get_new_id(cls, size, chars): |
|
42 |
def get_new_id(cls, size=16, chars=None): |
|
43 |
chars = chars or list(string.digits + string.ascii_letters) |
|
39 | 44 |
r = random.SystemRandom() |
40 | 45 |
while True: |
41 | 46 |
id = ''.join([r.choice(chars) for x in range(size)]) |
... | ... | |
43 | 48 |
return id |
44 | 49 | |
45 | 50 |
def migrate(self): |
46 |
if self.expiration and self.expiration < time.time(): |
|
51 |
if isinstance(self.expiration, (float, int)): |
|
52 |
self.expiration = make_aware(datetime.datetime.fromtimestamp(self.expiration)) |
|
53 |
self.expiration_check() |
|
54 | ||
55 |
def expiration_check(self): |
|
56 |
if self.expiration and self.expiration < now(): |
|
47 | 57 |
self.remove_self() |
48 | 58 |
raise KeyError() |
wcs/sql.py | ||
---|---|---|
40 | 40 |
import wcs.custom_views |
41 | 41 |
import wcs.formdata |
42 | 42 |
import wcs.logged_errors |
43 |
import wcs.qommon.tokens |
|
43 | 44 |
import wcs.roles |
44 | 45 |
import wcs.snapshots |
45 | 46 |
import wcs.tracking_code |
... | ... | |
1175 | 1176 |
cur.close() |
1176 | 1177 | |
1177 | 1178 | |
1179 |
def do_tokens_table(concurrently=False): |
|
1180 |
conn, cur = get_connection_and_cursor() |
|
1181 |
table_name = Token._table_name |
|
1182 | ||
1183 |
cur.execute( |
|
1184 |
'''SELECT COUNT(*) FROM information_schema.tables |
|
1185 |
WHERE table_schema = 'public' |
|
1186 |
AND table_name = %s''', |
|
1187 |
(table_name,), |
|
1188 |
) |
|
1189 |
if cur.fetchone()[0] == 0: |
|
1190 |
cur.execute( |
|
1191 |
'''CREATE TABLE %s (id VARCHAR PRIMARY KEY, |
|
1192 |
type VARCHAR, |
|
1193 |
expiration TIMESTAMPTZ, |
|
1194 |
context JSONB |
|
1195 |
)''' |
|
1196 |
% table_name |
|
1197 |
) |
|
1198 |
cur.execute( |
|
1199 |
'''SELECT column_name FROM information_schema.columns |
|
1200 |
WHERE table_schema = 'public' |
|
1201 |
AND table_name = %s''', |
|
1202 |
(table_name,), |
|
1203 |
) |
|
1204 |
existing_fields = {x[0] for x in cur.fetchall()} |
|
1205 | ||
1206 |
needed_fields = {x[0] for x in Token._table_static_fields} |
|
1207 | ||
1208 |
# delete obsolete fields |
|
1209 |
for field in existing_fields - needed_fields: |
|
1210 |
cur.execute('''ALTER TABLE %s DROP COLUMN %s''' % (table_name, field)) |
|
1211 | ||
1212 |
conn.commit() |
|
1213 |
cur.close() |
|
1214 | ||
1215 | ||
1216 |
def migrate_legacy_tokens(): |
|
1217 |
# store old pickle tokens in SQL |
|
1218 |
for token_id in wcs.qommon.tokens.Token.keys(): |
|
1219 |
try: |
|
1220 |
token = wcs.qommon.tokens.Token.get(token_id) |
|
1221 |
except KeyError: |
|
1222 |
continue |
|
1223 |
except AttributeError: |
|
1224 |
# old python2 tokens: |
|
1225 |
# AttributeError: module 'builtins' has no attribute 'unicode' |
|
1226 |
wcs.qommon.tokens.Token.remove_object(token_id) |
|
1227 |
continue |
|
1228 |
token.__class__ = Token |
|
1229 |
token.store() |
|
1230 | ||
1231 | ||
1178 | 1232 |
@guard_postgres |
1179 | 1233 |
def do_meta_table(conn=None, cur=None, insert_current_sql_level=True): |
1180 | 1234 |
own_conn = False |
... | ... | |
1662 | 1716 | |
1663 | 1717 |
return ' ORDER BY %s %s' % (order_by, direction) |
1664 | 1718 | |
1719 |
@classmethod |
|
1720 |
@guard_postgres |
|
1721 |
def has_key(cls, id): |
|
1722 |
conn, cur = get_connection_and_cursor() |
|
1723 |
sql_statement = 'SELECT EXISTS(SELECT 1 FROM %s WHERE id = %%s)' % cls._table_name |
|
1724 |
with cur: |
|
1725 |
cur.execute(sql_statement, (id,)) |
|
1726 |
result = cur.fetchall()[0][0] |
|
1727 |
conn.commit() |
|
1728 |
return result |
|
1729 | ||
1665 | 1730 |
@classmethod |
1666 | 1731 |
@guard_postgres |
1667 | 1732 |
def select_iterator( |
... | ... | |
3193 | 3258 |
cur.close() |
3194 | 3259 | |
3195 | 3260 | |
3261 |
class Token(SqlMixin, wcs.qommon.tokens.Token): |
|
3262 |
_table_name = 'tokens' |
|
3263 |
_table_static_fields = [ |
|
3264 |
('id', 'varchar'), |
|
3265 |
('type', 'varchar'), |
|
3266 |
('expiration', 'timestamptz'), |
|
3267 |
('context', 'jsonb'), |
|
3268 |
] |
|
3269 | ||
3270 |
_numerical_id = False |
|
3271 | ||
3272 |
@guard_postgres |
|
3273 |
def store(self): |
|
3274 |
sql_dict = { |
|
3275 |
'id': self.id, |
|
3276 |
'type': self.type, |
|
3277 |
'expiration': self.expiration, |
|
3278 |
'context': self.context, |
|
3279 |
} |
|
3280 | ||
3281 |
conn, cur = get_connection_and_cursor() |
|
3282 |
column_names = sql_dict.keys() |
|
3283 | ||
3284 |
if not self.id: |
|
3285 |
sql_dict['id'] = self.get_new_id() |
|
3286 |
sql_statement = '''INSERT INTO %s (%s) |
|
3287 |
VALUES (%s) |
|
3288 |
RETURNING id''' % ( |
|
3289 |
self._table_name, |
|
3290 |
', '.join(column_names), |
|
3291 |
', '.join(['%%(%s)s' % x for x in column_names]), |
|
3292 |
) |
|
3293 |
while True: |
|
3294 |
try: |
|
3295 |
cur.execute(sql_statement, sql_dict) |
|
3296 |
except psycopg2.IntegrityError: |
|
3297 |
conn.rollback() |
|
3298 |
sql_dict['id'] = self.get_new_id() |
|
3299 |
else: |
|
3300 |
break |
|
3301 |
self.id = cur.fetchone()[0] |
|
3302 |
else: |
|
3303 |
sql_statement = '''UPDATE %s SET %s WHERE id = %%(id)s RETURNING id''' % ( |
|
3304 |
self._table_name, |
|
3305 |
', '.join(['%s = %%(%s)s' % (x, x) for x in column_names]), |
|
3306 |
) |
|
3307 |
cur.execute(sql_statement, sql_dict) |
|
3308 |
if cur.fetchone() is None: |
|
3309 |
sql_statement = '''INSERT INTO %s (%s) VALUES (%s)''' % ( |
|
3310 |
self._table_name, |
|
3311 |
', '.join(column_names), |
|
3312 |
', '.join(['%%(%s)s' % x for x in column_names]), |
|
3313 |
) |
|
3314 |
cur.execute(sql_statement, sql_dict) |
|
3315 | ||
3316 |
conn.commit() |
|
3317 |
cur.close() |
|
3318 | ||
3319 |
@classmethod |
|
3320 |
def get_data_fields(cls): |
|
3321 |
return [] |
|
3322 | ||
3323 |
@classmethod |
|
3324 |
def _row2ob(cls, row, **kwargs): |
|
3325 |
o = cls() |
|
3326 |
for field, value in zip(cls._table_static_fields, tuple(row)): |
|
3327 |
setattr(o, field[0], value) |
|
3328 |
o.expiration_check() |
|
3329 |
return o |
|
3330 | ||
3331 | ||
3196 | 3332 |
class classproperty: |
3197 | 3333 |
def __init__(self, f): |
3198 | 3334 |
self.f = f |
... | ... | |
3509 | 3645 |
# latest migration, number + description (description is not used |
3510 | 3646 |
# programmaticaly but will make sure git conflicts if two migrations are |
3511 | 3647 |
# separately added with the same number) |
3512 |
SQL_LEVEL = (56, 'add gin indexes to concerned_roles_array and actions_roles_array')
|
|
3648 |
SQL_LEVEL = (57, 'store tokens in SQL')
|
|
3513 | 3649 | |
3514 | 3650 | |
3515 | 3651 |
def migrate_global_views(conn, cur): |
... | ... | |
3701 | 3837 |
# 50: switch role uuid column to varchar |
3702 | 3838 |
do_role_table() |
3703 | 3839 |
migrate_legacy_roles() |
3840 |
if sql_level < 57: |
|
3841 |
# 57: store tokens in SQL |
|
3842 |
do_tokens_table() |
|
3843 |
migrate_legacy_tokens() |
|
3704 | 3844 | |
3705 | 3845 |
cur.execute('''UPDATE wcs_meta SET value = %s WHERE key = %s''', (str(SQL_LEVEL[0]), 'sql_level')) |
3706 | 3846 | |
3707 |
- |