Projet

Général

Profil

0001-models-lock-user-model-when-changing-multiple-attrib.patch

Benjamin Dauvergne, 12 novembre 2019 11:33

Télécharger (8,32 ko)

Voir les différences:

Subject: [PATCH] models: lock user model when changing multiple attribute
 values (#37390)

 src/authentic2/custom_user/models.py | 44 ++++++++++++++------------
 src/authentic2/models.py             | 46 +++++++++++++++-------------
 tests/test_concurrency.py            | 43 ++++++++++++++------------
 3 files changed, 71 insertions(+), 62 deletions(-)
src/authentic2/custom_user/models.py
14 14
# You should have received a copy of the GNU Affero General Public License
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17
from django.db import models
17
from django.db import models, transaction
18 18
from django.utils import timezone
19 19
from django.core.mail import send_mail
20 20
from django.utils import six
......
55 55
            for atv in self.owner.attribute_values.all():
56 56
                attribute = get_attributes_map()[atv.attribute_id]
57 57
                atv.attribute = attribute
58
                values[attribute.name] = atv
58
                if attribute.multiple:
59
                    values.setdefault(attribute.name, []).append(atv)
60
                else:
61
                    values[attribute.name] = atv
59 62
        self.__dict__['values'] = owner._a2_attributes_cache
60 63

  
61 64
    def __setattr__(self, name, value):
62
        atv = self.values.get(name)
63
        if atv:
64
            if isinstance(atv, (list, tuple)):
65
                attribute = atv[0].attribute
65
        attribute = get_attributes_map().get(name)
66
        if not attribute:
67
            raise AttributeError(name)
68

  
69
        with transaction.atomic():
70
            if attribute.multiple:
71
                attribute.set_value(self.owner, value, verified=bool(self.verified))
66 72
            else:
67
                attribute = atv.attribute
68
            attribute.set_value(self.owner, value, verified=bool(self.verified), attribute_value=atv)
69
        else:
70
            attribute = get_attributes_map().get(name)
71
            if not attribute:
72
                raise AttributeError(name)
73
            self.values[name] = attribute.set_value(self.owner, value, verified=bool(self.verified))
74

  
75
        update_fields = ['modified']
76
        if name in ['first_name', 'last_name']:
77
            if getattr(self.owner, name) != value:
78
                setattr(self.owner, name, value)
79
                update_fields.append(name)
80
        self.owner.save(update_fields=update_fields)
73
                atv = self.values.get(name)
74
                self.values[name] = attribute.set_value(
75
                    self.owner, value,
76
                    verified=bool(self.verified),
77
                    attribute_value=atv)
78

  
79
            update_fields = ['modified']
80
            if name in ['first_name', 'last_name']:
81
                if getattr(self.owner, name) != value:
82
                    setattr(self.owner, name, value)
83
                    update_fields.append(name)
84
            self.owner.save(update_fields=update_fields)
81 85

  
82 86
    def __getattr__(self, name):
83 87
        if name not in get_attributes_map():
src/authentic2/models.py
245 245
            AttributeValue.objects.with_owner(owner).filter(attribute=self).delete()
246 246
            return
247 247

  
248
        if self.multiple:
249
            assert isinstance(value, (list, set, tuple))
250
            values = value
251
            avs = []
252
            content_list = []
248
        with transaction.atomic():
249
            if self.multiple:
250
                assert isinstance(value, (list, set, tuple))
251
                values = value
252
                avs = []
253
                content_list = []
254

  
255
                list(owner.__class__.objects.filter(pk=owner.pk).select_for_update())
253 256

  
254
            with transaction.atomic():
255 257
                for value in values:
256 258
                    content = serialize(value)
257 259
                    av, created = AttributeValue.objects.get_or_create(
......
273 275
                    object_id=owner.pk,
274 276
                    multiple=True
275 277
                ).exclude(content__in=content_list).delete()
276
            return avs
277
        else:
278
            content = serialize(value)
279
            if attribute_value:
280
                av, created = attribute_value, False
278
                return avs
281 279
            else:
282
                av, created = AttributeValue.objects.get_or_create(
283
                    content_type=ContentType.objects.get_for_model(owner),
284
                    object_id=owner.pk,
285
                    attribute=self,
286
                    multiple=False,
287
                    defaults={'content': content, 'verified': verified})
288
            if not created and (av.content != content or av.verified != verified):
289
                av.content = content
290
                av.verified = verified
291
                av.save()
292
            return av
280
                content = serialize(value)
281
                if attribute_value:
282
                    av, created = attribute_value, False
283
                else:
284
                    av, created = AttributeValue.objects.get_or_create(
285
                        content_type=ContentType.objects.get_for_model(owner),
286
                        object_id=owner.pk,
287
                        attribute=self,
288
                        multiple=False,
289
                        defaults={'content': content, 'verified': verified})
290
                if not created and (av.content != content or av.verified != verified):
291
                    av.content = content
292
                    av.verified = verified
293
                    av.save()
294
                return av
293 295

  
294 296
    def natural_key(self):
295 297
        return (self.name,)
tests/test_concurrency.py
25 25

  
26 26
@skipif_sqlite
27 27
def test_attribute_value_uniqueness(migrations, transactional_db, simple_user, concurrency):
28
    from django.db.transaction import set_autocommit
28
    #from django.db.transaction import set_autocommit
29 29
    # disabled default attributes
30 30
    Attribute.objects.update(disabled=True)
31 31

  
32
    set_autocommit(True)
32
    #set_autocommit(True)
33 33
    acount = Attribute.objects.count()
34 34

  
35 35
    single_at = Attribute.objects.create(
......
44 44
        multiple=True)
45 45
    assert Attribute.objects.count() == acount + 2
46 46

  
47
    def map_threads(f, l):
48
        threads = []
49
        for i in l:
50
            threads.append(threading.Thread(target=f, args=(i,)))
51
            threads[-1].start()
52
        for thread in threads:
53
            thread.join()
47
    AttributeValue.objects.all().delete()
54 48

  
55
    def f(i):
56
        simple_user.attributes.multiple = [str(i)]
57
        connection.close()
58
    map_threads(f, range(concurrency))
59
    map_threads(f, range(concurrency))
60
    assert AttributeValue.objects.filter(attribute=multiple_at).count() == 1
49
    for i in range(10):
50
        def map_threads(f, l):
51
            threads = []
52
            for i in l:
53
                threads.append(threading.Thread(target=f, args=(i,)))
54
                threads[-1].start()
55
            for thread in threads:
56
                thread.join()
61 57

  
62
    def f(i):
63
        simple_user.attributes.single = str(i)
64
        connection.close()
65
    map_threads(f, range(concurrency))
66
    assert AttributeValue.objects.filter(attribute=single_at).count() == 1
58
        def f(i):
59
            simple_user.attributes.multiple = [str(i)]
60
            connection.close()
61
        map_threads(f, range(concurrency))
62
        map_threads(f, range(concurrency))
63
        assert AttributeValue.objects.filter(attribute=multiple_at).count() == 1
64

  
65
        def f(i):
66
            simple_user.attributes.single = str(i)
67
            connection.close()
68
        map_threads(f, range(concurrency))
69
        assert AttributeValue.objects.filter(attribute=single_at).count() == 1
67
-