Projet

Général

Profil

0001-snapshot-store-a-patch-instead-of-serialization-5729.patch

Lauréline Guérin, 05 octobre 2021 14:09

Télécharger (19,7 ko)

Voir les différences:

Subject: [PATCH] snapshot: store a patch instead of serialization (#57299)

 tests/test_snapshots.py     | 146 +++++++++++----------------------
 wcs/backoffice/snapshots.py |   2 +-
 wcs/snapshots.py            | 157 ++++++++++++++++++++++++++++++++----
 wcs/sql.py                  |  24 ++++--
 4 files changed, 207 insertions(+), 122 deletions(-)
tests/test_snapshots.py
9 9
from wcs.blocks import BlockDef
10 10
from wcs.carddef import CardDef
11 11
from wcs.data_sources import NamedDataSource
12
from wcs.fields import ItemField
12
from wcs.fields import ItemField, StringField
13 13
from wcs.formdef import FormDef
14 14
from wcs.mail_templates import MailTemplate
15 15
from wcs.qommon.form import UploadedFile
......
66 66
    formdef.fields = []
67 67
    formdef.store()
68 68

  
69
    carddef = CardDef()
70
    carddef.name = 'testcard'
71
    carddef.fields = []
72
    carddef.store()
69
    # first occurrence, complete snapshot stored
70
    assert pub.snapshot_class.count() == 1
71
    snapshot1 = pub.snapshot_class.get_latest('formdef', formdef.id)
72
    assert snapshot1.serialization is not None
73
    assert '>testform<' in snapshot1.serialization
74
    assert snapshot1.patch is None
75
    assert snapshot1.instance  # possible to restore
73 76

  
77
    # no changes
78
    formdef.store()
79
    assert pub.snapshot_class.count() == 1
80

  
81
    # patch only
74 82
    formdef.name = 'testform2'
75 83
    formdef.store()
84
    assert pub.snapshot_class.count() == 2
76 85

  
77
    carddef.name = 'testcard2'
78
    carddef.store()
86
    snapshot2 = pub.snapshot_class.get_latest('formdef', formdef.id)
87
    assert snapshot2.serialization is None
88
    assert '>testform2<' in snapshot2.patch
89
    assert snapshot2.instance  # possible to restore
79 90

  
80
    data_source = NamedDataSource(name='foobar')
81
    data_source.data_source = {'type': 'formula', 'value': repr([('1', 'un'), ('2', 'deux')])}
82
    data_source.store()
91
    # no diff with latest snap but label is given
92
    pub.snapshot_class.snap(instance=formdef, label="foo bar")
93
    assert pub.snapshot_class.count() == 3
94
    snapshot3 = pub.snapshot_class.get_latest('formdef', formdef.id)
95
    assert snapshot3.serialization is None
96
    assert '>testform2<' in snapshot3.patch
97
    assert snapshot2.patch == snapshot3.patch
98
    assert snapshot3.instance  # possible to restore
99

  
100
    # patch is longer than serialization, store serialization
101
    formdef.name = 'testform3'
102
    formdef.fields = [StringField(id=str(i), label='Test %s' % i, type='string') for i in range(0, 10)]
103
    formdef.store()
104
    assert pub.snapshot_class.count() == 4
105
    snapshot4 = pub.snapshot_class.get_latest('formdef', formdef.id)
106
    assert snapshot4.serialization is not None
107
    assert '>testform3<' in snapshot4.serialization
108
    assert snapshot4.patch is None
109
    assert snapshot4.instance  # possible to restore
110

  
111
    # no diff with latest snap but label is given
112
    pub.snapshot_class.snap(instance=formdef, label="foo bar")
83 113
    assert pub.snapshot_class.count() == 5
84

  
85
    # check we got correct data in the serializations
86
    snapshot = pub.snapshot_class.get_latest('formdef', formdef.id)
87
    assert '>testform2<' in snapshot.serialization
88

  
89
    snapshot = pub.snapshot_class.get_latest('carddef', carddef.id)
90
    assert '>testcard2<' in snapshot.serialization
114
    snapshot5 = pub.snapshot_class.get_latest('formdef', formdef.id)
115
    assert snapshot5.serialization is None
116
    assert snapshot5.patch == ''  # no difference with latest snap, which has a serialization
117
    assert snapshot5.instance  # possible to restore
91 118

  
92 119

  
93 120
def test_snapshot_instance(pub):
......
117 144
    snapshots = pub.snapshot_class.select_object_history(formdef)
118 145
    assert len(snapshots) == 10
119 146
    for i in range(10):
120
        assert snapshots[i].serialization is None
147
        assert snapshots[i].serialization is None  # not loaded
148
        assert snapshots[i].patch is None  # not loaded
121 149
        assert pub.snapshot_class.get(snapshots[i].id).instance.name == 'testform %s' % (9 - i)
122 150

  
123 151
    snapshots = pub.snapshot_class.select_object_history(carddef)
......
676 704
    ]
677 705

  
678 706

  
679
@pytest.fixture
680
def size_limit(pub):
681
    pub.snapshot_class.WCS_MAX_LEN = 100
682
    yield
683
    pub.snapshot_class.WCS_MAX_LEN = 1000000
684

  
685

  
686
def test_workflow_snapshot_max_len(pub, size_limit):
687
    formdef = FormDef()
688
    formdef.name = 'testform'
689
    formdef.fields = []
690
    formdef.store()
691

  
692
    Workflow.wipe()
693
    workflow = Workflow(name='test')
694
    workflow.store()
695

  
696
    another_workflow = Workflow(name='other test')
697
    another_workflow.store()  # same object_type - check that other instances snapshots are not deleted
698

  
699
    assert formdef.id == workflow.id  # same id - check other object_type snapshots are not deleted
700

  
701
    # first one: saved
702
    assert pub.snapshot_class.count() == 3
703
    first_id = pub.snapshot_class.select(order_by='id')[0].id
704
    assert pub.snapshot_class.get(first_id).object_type == 'formdef'
705
    assert pub.snapshot_class.get(first_id + 1).object_type == 'workflow'
706
    assert pub.snapshot_class.get(first_id + 1).object_id == '1'
707
    old_timestamp = pub.snapshot_class.get(first_id + 1).timestamp
708
    assert pub.snapshot_class.get(first_id + 2).object_type == 'workflow'
709
    assert pub.snapshot_class.get(first_id + 2).object_id == '2'
710

  
711
    # save snapshot
712
    pub.snapshot_class.snap(instance=workflow, label="snapshot !")
713
    assert pub.snapshot_class.count() == 4
714
    assert pub.snapshot_class.get(first_id).object_type == 'formdef'
715
    assert pub.snapshot_class.get(first_id + 1).object_type == 'workflow'
716
    assert pub.snapshot_class.get(first_id + 1).object_id == '1'
717
    assert pub.snapshot_class.get(first_id + 1).label is None
718
    assert pub.snapshot_class.get(first_id + 1).timestamp == old_timestamp
719
    assert pub.snapshot_class.get(first_id + 1).instance.name == 'test'
720
    assert pub.snapshot_class.get(first_id + 2).object_type == 'workflow'
721
    assert pub.snapshot_class.get(first_id + 2).object_id == '2'
722
    assert pub.snapshot_class.get(first_id + 3).object_type == 'workflow'
723
    assert pub.snapshot_class.get(first_id + 3).object_id == '1'
724
    assert pub.snapshot_class.get(first_id + 3).label == "snapshot !"
725
    assert pub.snapshot_class.get(first_id + 3).instance.name == 'test'
726

  
727
    # no changes
728
    workflow.store()
729
    assert pub.snapshot_class.count() == 4
730
    assert pub.snapshot_class.get(first_id).object_type == 'formdef'
731
    assert pub.snapshot_class.get(first_id + 1).object_type == 'workflow'
732
    assert pub.snapshot_class.get(first_id + 1).object_id == '1'
733
    assert pub.snapshot_class.get(first_id + 1).label is None
734
    assert pub.snapshot_class.get(first_id + 1).timestamp == old_timestamp
735
    assert pub.snapshot_class.get(first_id + 1).instance.name == 'test'
736
    assert pub.snapshot_class.get(first_id + 2).object_type == 'workflow'
737
    assert pub.snapshot_class.get(first_id + 2).object_id == '2'
738
    assert pub.snapshot_class.get(first_id + 3).object_type == 'workflow'
739
    assert pub.snapshot_class.get(first_id + 3).object_id == '1'
740
    assert pub.snapshot_class.get(first_id + 3).label == "snapshot !"
741
    assert pub.snapshot_class.get(first_id + 3).instance.name == 'test'
742

  
743
    # with changes
744
    workflow.name = 'foo bar'
745
    workflow.store()
746
    assert pub.snapshot_class.count() == 4
747
    assert pub.snapshot_class.get(first_id).object_type == 'formdef'
748
    assert pub.snapshot_class.get(first_id + 2).object_type == 'workflow'
749
    assert pub.snapshot_class.get(first_id + 2).object_id == '2'
750
    assert pub.snapshot_class.get(first_id + 3).object_type == 'workflow'
751
    assert pub.snapshot_class.get(first_id + 3).object_id == '1'
752
    assert pub.snapshot_class.get(first_id + 3).label == "snapshot !"
753
    assert pub.snapshot_class.get(first_id + 3).instance.name == 'test'
754
    assert pub.snapshot_class.get(first_id + 4).object_type == 'workflow'
755
    assert pub.snapshot_class.get(first_id + 4).object_id == '1'
756
    assert pub.snapshot_class.get(first_id + 4).label is None
757
    assert pub.snapshot_class.get(first_id + 4).timestamp > old_timestamp
758
    assert pub.snapshot_class.get(first_id + 4).instance.name == 'foo bar'
759

  
760

  
761 707
def test_pickle_erroneous_snapshot_object(pub):
762 708
    # check snapshot object attribute is not restored
763 709
    formdef = FormDef()
wcs/backoffice/snapshots.py
117 117
                self.snapshot.timestamp.strftime('%Y%m%d-%H%M'),
118 118
            ),
119 119
        )
120
        return '<?xml version="1.0"?>\n' + self.snapshot.serialization
120
        return '<?xml version="1.0"?>\n' + self.snapshot.get_serialization()
121 121

  
122 122
    def restore(self):
123 123
        form = Form(enctype='multipart/form-data')
wcs/snapshots.py
14 14
# You should have received a copy of the GNU General Public License
15 15
# along with this program; if not, see <http://www.gnu.org/licenses/>.
16 16

  
17
import difflib
18
import re
17 19
import xml.etree.ElementTree as ET
18 20

  
19 21
from django.utils.timezone import now
20 22
from quixote import get_publisher, get_session
21 23

  
22 24
from wcs.qommon import _, misc
23
from wcs.qommon.storage import Null
24 25

  
25 26

  
26 27
class UnknownUser:
......
28 29
        return str(_('unknown user'))
29 30

  
30 31

  
32
def indent(tree, space="  ", level=0):
    """Indent an XML (sub)tree in place so it serializes one element per line.

    Uses ``xml.etree.ElementTree.indent`` when the running Python provides it
    and falls back to the bundled backport of that function otherwise.

    tree: an ``ET.Element`` or an ``ET.ElementTree`` (its root is used); the
        ``text``/``tail`` whitespace of the elements is modified in place.
    space: whitespace inserted once per indentation level.
    level: initial indentation level of the root element, must be >= 0.

    Returns None.  Raises ValueError if ``level`` is negative.
    """
    stdlib_indent = getattr(ET, 'indent', None)
    if stdlib_indent is not None:
        # prefer the maintained standard-library implementation
        stdlib_indent(tree, space=space, level=level)
        return

    # backport from Lib/xml/etree/ElementTree.py python 3.9
    if isinstance(tree, ET.ElementTree):
        tree = tree.getroot()
    if level < 0:
        raise ValueError(f"Initial indentation level must be >= 0, got {level}")
    if len(tree) == 0:
        # nothing to indent in a childless element
        return

    # Reduce the memory consumption by reusing indentation strings.
    indentations = ["\n" + level * space]

    def _indent_children(elem, level):
        # Start a new indentation level for the first child.
        child_level = level + 1
        try:
            child_indentation = indentations[child_level]
        except IndexError:
            child_indentation = indentations[level] + space
            indentations.append(child_indentation)

        # only whitespace-only text is replaced, real content is kept
        if not elem.text or not elem.text.strip():
            elem.text = child_indentation

        for child in elem:
            if len(child):
                _indent_children(child, child_level)
            if not child.tail or not child.tail.strip():
                child.tail = child_indentation

        # Dedent after the last child by overwriting the previous indentation.
        if not child.tail.strip():
            child.tail = indentations[level]

    _indent_children(tree, 0)
67

  
68

  
69
# marker line emitted after a diff line whose source line had no trailing newline
_no_eol = "\\ No newline at end of file"
# unified diff hunk header: "@@ -start[,count] +start[,count] @@"
_hdr_pat = re.compile(r"^@@ -(\d+),?(\d+)? \+(\d+),?(\d+)? @@$")


def make_patch(a, b):
    """
    Get unified string diff between two strings. Trims top two lines.
    Returns empty string if strings are identical.

    The diff is produced without context lines (n=0).  A diff line whose
    source line had no trailing newline is followed by a
    "\\ No newline at end of file" marker line.
    """
    diffs = difflib.unified_diff(a.splitlines(True), b.splitlines(True), n=0)
    # drop the "---" / "+++" file header lines; identical inputs yield nothing
    next(diffs, None)
    next(diffs, None)
    return ''.join(d if d[-1] == '\n' else d + '\n' + _no_eol + '\n' for d in diffs)


def apply_patch(s, patch, revert=False):
    """
    Apply patch to string s to recover newer string.
    If revert is True, treat s as the newer string, recover older string.

    Leading "---" / "+++" header lines in the patch are skipped.
    Raises ValueError when a hunk header cannot be parsed or when it refers
    to a line outside of s (or before the previous hunk).
    """
    source = s.splitlines(True)
    lines = patch.splitlines(True)
    out = []  # collected pieces, joined once at the end
    i = pos = 0
    # regex group holding the start line, and the sign of the lines to keep
    (midx, sign) = (3, '-') if revert else (1, '+')
    while i < len(lines) and lines[i].startswith(("---", "+++")):
        i += 1  # skip header lines
    while i < len(lines):
        m = _hdr_pat.match(lines[i])
        if not m:
            raise ValueError("Bad patch -- regex mismatch [line " + str(i) + "]")
        # 0-based start of the hunk; a zero count means "insert after that line"
        start = int(m.group(midx)) - 1 + (m.group(midx + 1) == '0')
        if pos > start or start > len(source):
            raise ValueError("Bad patch -- bad line num [line " + str(i) + "]")
        out.extend(source[pos:start])
        pos = start
        i += 1
        while i < len(lines) and lines[i][0] != '@':
            if i + 1 < len(lines) and lines[i + 1][0] == '\\':
                # next line is the "no newline" marker: strip the newline
                # added by make_patch and skip the marker
                line = lines[i][:-1]
                i += 2
            else:
                line = lines[i]
                i += 1
            if line:
                if line[0] == sign or line[0] == ' ':
                    out.append(line[1:])
                # every line not carrying our sign consumes a line of s
                pos += line[0] != sign
    out.extend(source[pos:])
    return ''.join(out)
121

  
122

  
31 123
class Snapshot:
32 124
    id = None
33 125
    object_type = None  # (formdef, carddef, blockdef, workflow, data_source, etc.)
......
36 128
    user_id = None
37 129
    comment = None
38 130
    serialization = None
131
    patch = None
39 132
    label = None  # (named snapshot)
40 133

  
41 134
    # cache
42 135
    _instance = None
43 136
    _user = None
44 137

  
45
    WCS_MAX_LEN = 1000000
46

  
47 138
    @classmethod
48 139
    def snap(cls, instance, comment=None, label=None):
49 140
        obj = cls()
......
52 143
        obj.timestamp = now()
53 144
        if get_session():
54 145
            obj.user_id = get_session().user
55
        obj.serialization = ET.tostring(instance.export_to_xml(include_id=True)).decode('utf-8')
146
        tree = instance.export_to_xml(include_id=True)
147
        obj.serialization = ET.tostring(tree).decode('utf-8')
56 148
        obj.comment = str(comment) if comment else None
57 149
        obj.label = label
58
        latest = cls.get_latest(obj.object_type, obj.object_id)
59
        if label is not None or latest is None or obj.serialization != latest.serialization:
60
            # save snapshot if there are changes or an explicit label was
61
            # given.
62
            if label is None and len(obj.serialization) > cls.WCS_MAX_LEN:
63
                # keep only latest snapshot for big objects
64
                # (typically workflows with embedded documents)
65
                for old_snapshot in cls.select_object_history(instance, clause=[Null('label')]):
66
                    cls.remove_object(old_snapshot.id)
150

  
151
        latest_complete = cls.get_latest(obj.object_type, obj.object_id, complete=True)
152
        if latest_complete is None:
153
            # no complete snapshot, store it, with serialization and no patch
154
            obj.store()
155
            return
156

  
157
        # get patch between latest serialization and current instance
158
        # indent xml to minimize patch
159
        latest_tree = ET.fromstring(latest_complete.serialization)
160
        indent(tree)
161
        indent(latest_tree)
162
        patch = make_patch(ET.tostring(latest_tree).decode('utf-8'), ET.tostring(tree).decode('utf-8'))
163
        # should we store a snapshot ?
164
        store_snapshot = False
165
        if label is not None:
166
            # always store if label is set
167
            store_snapshot = True
168
        else:
169
            # compare with patch of latest snapshot
170
            latest = cls.get_latest(obj.object_type, obj.object_id)
171
            if latest.patch and patch != latest.patch:
172
                # the patch has changed
173
                store_snapshot = True
174
            elif latest.serialization and patch:
175
                # there is a patch
176
                store_snapshot = True
177

  
178
        if store_snapshot:
179
            if len(obj.serialization) > len(patch):
180
                # serialization is bigger than patch, store patch
181
                obj.serialization = None
182
                obj.patch = patch
183
            # else: keep serialization and ignore patch
67 184
            obj.store()
68 185

  
69 186
    def get_object_class(self):
......
80 197
                return klass
81 198
        raise KeyError('no class for object type: %s' % self.object_type)
82 199

  
200
    def get_serialization(self):
        """Return the complete XML serialization of this snapshot, as a string.

        If the snapshot holds a complete serialization, it is returned as is.
        Otherwise the serialization is rebuilt by applying the stored patch
        to the indented serialization of the latest complete snapshot of the
        same object.

        Raises ValueError if no complete snapshot exists for the object.
        """
        # there is a complete serialization
        if self.serialization:
            return self.serialization

        # get latest version with serialization
        # NOTE(review): this is the latest complete snapshot at read time; if
        # a newer complete snapshot has been stored since this patch was
        # computed, the base differs from the one the patch was made against
        # -- confirm against snap().
        latest_complete = type(self).get_latest(self.object_type, self.object_id, complete=True)
        if latest_complete is None:
            raise ValueError(
                'no complete snapshot for %s %s, cannot rebuild serialization'
                % (self.object_type, self.object_id)
            )
        latest_tree = ET.fromstring(latest_complete.serialization)
        # the base is indented before the patch is applied to it
        indent(latest_tree)
        base = ET.tostring(latest_tree).decode('utf-8')
        return apply_patch(base, self.patch or '')
211

  
83 212
    @property
84 213
    def instance(self):
85 214
        if self._instance is None:
86
            tree = ET.fromstring(self.serialization)
215
            tree = ET.fromstring(self.get_serialization())
87 216
            self._instance = self.get_object_class().import_from_xml_tree(
88 217
                tree,
89 218
                include_id=True,
wcs/sql.py
1057 1057
                                        user_id VARCHAR,
1058 1058
                                        comment TEXT,
1059 1059
                                        serialization TEXT,
1060
                                        patch TEXT,
1060 1061
                                        label VARCHAR
1061 1062
                                        )'''
1062 1063
            % table_name
......
1069 1070
    )
1070 1071
    existing_fields = {x[0] for x in cur.fetchall()}
1071 1072

  
1073
    # migrations
1074
    if 'patch' not in existing_fields:
1075
        cur.execute('''ALTER TABLE %s ADD COLUMN patch TEXT''' % table_name)
1076

  
1072 1077
    needed_fields = {x[0] for x in Snapshot._table_static_fields}
1073 1078

  
1074 1079
    # delete obsolete fields
......
2959 2964
        ('user_id', 'varchar'),
2960 2965
        ('comment', 'text'),
2961 2966
        ('serialization', 'text'),
2967
        ('patch', 'text'),
2962 2968
        ('label', 'varchar'),
2963 2969
    ]
2964
    _table_select_skipped_fields = ['serialization']
2970
    _table_select_skipped_fields = ['serialization', 'patch']
2965 2971

  
2966 2972
    @guard_postgres
2967 2973
    @invalidate_substitution_cache
......
3018 3024
        return []
3019 3025

  
3020 3026
    @classmethod
3021
    def get_latest(cls, object_type, object_id):
3027
    def get_latest(cls, object_type, object_id, complete=False):
3022 3028
        conn, cur = get_connection_and_cursor()
3023 3029
        sql_statement = '''SELECT id FROM snapshots
3024
                            WHERE object_type = %(object_type)s
3025
                              AND object_id = %(object_id)s
3030
                            WHERE object_type = %%(object_type)s
3031
                              AND object_id = %%(object_id)s
3032
                              %s
3026 3033
                         ORDER BY timestamp DESC
3027
                            LIMIT 1'''
3034
                            LIMIT 1''' % (
3035
            'AND serialization IS NOT NULL' if complete else ''
3036
        )
3028 3037
        cur.execute(sql_statement, {'object_type': object_type, 'object_id': object_id})
3029 3038
        row = cur.fetchone()
3030 3039
        conn.commit()
......
3441 3450
# latest migration, number + description (description is not used
3442 3451
# programmaticaly but will make sure git conflicts if two migrations are
3443 3452
# separately added with the same number)
3444
SQL_LEVEL = (53, 'add kind column on logged_errors table')
3453
SQL_LEVEL = (54, 'add patch column on snapshot table')
3445 3454

  
3446 3455

  
3447 3456
def migrate_global_views(conn, cur):
......
3617 3626
                continue
3618 3627
            for formdata in formdef.data_class().select_iterator():
3619 3628
                formdata._set_auto_fields(cur)  # build digests
3620
    if sql_level < 42:
3629
    if sql_level < 54:
3621 3630
        # 42: create snapshots table
3631
        # 54: add patch column
3622 3632
        do_snapshots_table()
3623 3633
    if sql_level < 53:
3624 3634
        # 47: store LoggedErrors in SQL
3625
-