From 2f3745792a0a44a321c6515d6053daa55d0b26c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20P=C3=A9ters?= Date: Tue, 23 May 2017 18:17:47 +0200 Subject: [PATCH] sql: get varchar/text values as unicode (#15802) This matches what's being done in Django and will help integrating w.c.s. with Django applications. --- tests/test_sql.py | 3 +++ wcs/sql.py | 30 ++++++++++++++++++++++-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/test_sql.py b/tests/test_sql.py index 932c0895..8d3e21e9 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -156,6 +156,7 @@ def check_sql_field(no, value): @postgresql def test_sql_field_string(): check_sql_field('0', 'hello world') + check_sql_field('0', 'élo world') @postgresql def test_sql_field_email(): @@ -164,6 +165,7 @@ def test_sql_field_email(): @postgresql def test_sql_field_text(): check_sql_field('2', 'long text') + check_sql_field('2', 'long tèxt') @postgresql def test_sql_field_bool(): @@ -182,6 +184,7 @@ def test_sql_field_date(): def test_sql_field_items(): check_sql_field('6', ['apricot']) check_sql_field('6', ['apricot', 'pear']) + check_sql_field('6', ['pomme', 'poire', 'pêche']) @postgresql def test_sql_geoloc(): diff --git a/wcs/sql.py b/wcs/sql.py index 019e0859..b444d735 100644 --- a/wcs/sql.py +++ b/wcs/sql.py @@ -15,6 +15,7 @@ # along with this program; if not, see . import psycopg2 +import psycopg2.extensions import datetime import time import re @@ -30,6 +31,11 @@ import wcs.formdata import wcs.tracking_code import wcs.users +# enable psycogp2 unicode mode, this will fetch postgresql varchar/text columns +# as unicode objects +psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) +psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY) + SQL_TYPE_MAPPING = { 'title': None, 'subtitle': None, @@ -234,6 +240,13 @@ def parse_clause(clause): return (where_clauses, parameters, None) +def str_encode(value): + if isinstance(value, list): + return [str_encode(x) for x in value] + if not isinstance(value, unicode): + return value + return value.encode(get_publisher().site_charset) + def get_connection(new=False): if new: @@ -999,7 +1012,8 @@ class SqlMixin(object): if sql_type is None: continue value = row[i] - if value: + if value is not None: + value = str_encode(value) if field.key == 'ranked-items': d = {} for data, rank in value: @@ -1020,7 +1034,7 @@ class SqlMixin(object): obdata[field.id] = value i += 1 if field.store_display_value: - value = row[i] + value = str_encode(row[i]) obdata['%s_display' % field.id] = value i += 1 if field.store_structured_value: @@ -1124,7 +1138,7 @@ class SqlFormData(SqlMixin, wcs.formdata.FormData): @classmethod def _row2evo(cls, row): o = wcs.formdata.Evolution() - o._sql_id, o.who, o.status, o.time, o.comment = tuple(row[:5]) + o._sql_id, o.who, o.status, o.time, o.comment = [str_encode(x) for x in tuple(row[:5])] if o.time: o.time = o.time.timetuple() if row[5]: @@ -1319,7 +1333,7 @@ class SqlFormData(SqlMixin, wcs.formdata.FormData): o = cls() for static_field, value in zip(cls._table_static_fields, tuple(row[:len(cls._table_static_fields)])): - setattr(o, static_field[0], value) + setattr(o, static_field[0], str_encode(value)) if o.receipt_time: o.receipt_time = o.receipt_time.timetuple() if o.workflow_data: @@ -1559,7 +1573,7 @@ class SqlUser(SqlMixin, wcs.users.User): o = cls() (o.id, o.name, o.email, o.roles, o.is_admin, o.anonymous, o.name_identifiers, o.verified_fields, o.lasso_dump, - o.last_seen, ascii_name) = tuple(row[:11]) + o.last_seen, ascii_name) = [str_encode(x) for x in tuple(row[:11])] if o.last_seen: o.last_seen = time.mktime(o.last_seen.timetuple()) if o.roles: @@ -1688,7 +1702,7 @@ class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode): sql_dict['id'] = self.get_new_id() else: break - self.id = cur.fetchone()[0] + self.id = str_encode(cur.fetchone()[0]) else: column_names = sql_dict.keys() sql_dict['id'] = self.id @@ -1705,7 +1719,7 @@ class TrackingCode(SqlMixin, wcs.tracking_code.TrackingCode): @classmethod def _row2ob(cls, row): o = cls() - (o.id, o.formdef_id, o.formdata_id) = tuple(row[:3]) + (o.id, o.formdef_id, o.formdata_id) = [str_encode(x) for x in tuple(row[:3])] return o @classmethod @@ -1756,7 +1770,7 @@ class AnyFormData(SqlMixin): o = formdef.data_class()() for static_field, value in zip(cls._table_static_fields, tuple(row[:len(cls._table_static_fields)])): - setattr(o, static_field[0], value) + setattr(o, static_field[0], str_encode(value)) # [CRITICALITY_2] transform criticality_level back to the expected # range (see [CRITICALITY_1]) levels = len(formdef.workflow.criticality_levels or [0]) -- 2.13.1