From 3d3df4e85830e2f3f1ef5624c1f7477d1fff81cb Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Tue, 12 Nov 2019 11:05:14 +0100 Subject: [PATCH] models: lock user model when changing multiple attribute values (#37390) --- src/authentic2/custom_user/models.py | 44 ++++++++++++++------------ src/authentic2/models.py | 46 +++++++++++++++------------- tests/test_concurrency.py | 43 ++++++++++++++------------ 3 files changed, 71 insertions(+), 62 deletions(-) diff --git a/src/authentic2/custom_user/models.py b/src/authentic2/custom_user/models.py index 3c867e43..348ad907 100644 --- a/src/authentic2/custom_user/models.py +++ b/src/authentic2/custom_user/models.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from django.db import models +from django.db import models, transaction from django.utils import timezone from django.core.mail import send_mail from django.utils import six @@ -55,29 +55,33 @@ class Attributes(object): for atv in self.owner.attribute_values.all(): attribute = get_attributes_map()[atv.attribute_id] atv.attribute = attribute - values[attribute.name] = atv + if attribute.multiple: + values.setdefault(attribute.name, []).append(atv) + else: + values[attribute.name] = atv self.__dict__['values'] = owner._a2_attributes_cache def __setattr__(self, name, value): - atv = self.values.get(name) - if atv: - if isinstance(atv, (list, tuple)): - attribute = atv[0].attribute + attribute = get_attributes_map().get(name) + if not attribute: + raise AttributeError(name) + + with transaction.atomic(): + if attribute.multiple: + attribute.set_value(self.owner, value, verified=bool(self.verified)) else: - attribute = atv.attribute - attribute.set_value(self.owner, value, verified=bool(self.verified), attribute_value=atv) - else: - attribute = get_attributes_map().get(name) - if not attribute: - raise AttributeError(name) - self.values[name] = attribute.set_value(self.owner, value, verified=bool(self.verified)) - - update_fields = ['modified'] - if name in ['first_name', 'last_name']: - if getattr(self.owner, name) != value: - setattr(self.owner, name, value) - update_fields.append(name) - self.owner.save(update_fields=update_fields) + atv = self.values.get(name) + self.values[name] = attribute.set_value( + self.owner, value, + verified=bool(self.verified), + attribute_value=atv) + + update_fields = ['modified'] + if name in ['first_name', 'last_name']: + if getattr(self.owner, name) != value: + setattr(self.owner, name, value) + update_fields.append(name) + self.owner.save(update_fields=update_fields) def __getattr__(self, name): if name not in get_attributes_map(): diff --git a/src/authentic2/models.py b/src/authentic2/models.py index 69f933b8..a17e3338 100644 --- a/src/authentic2/models.py +++ b/src/authentic2/models.py @@ -245,13 +245,15 @@ class Attribute(models.Model): AttributeValue.objects.with_owner(owner).filter(attribute=self).delete() return - if self.multiple: - assert isinstance(value, (list, set, tuple)) - values = value - avs = [] - content_list = [] + with transaction.atomic(): + if self.multiple: + assert isinstance(value, (list, set, tuple)) + values = value + avs = [] + content_list = [] + + list(owner.__class__.objects.filter(pk=owner.pk).select_for_update()) - with transaction.atomic(): for value in values: content = serialize(value) av, created = AttributeValue.objects.get_or_create( @@ -273,23 +275,23 @@ class Attribute(models.Model): object_id=owner.pk, multiple=True ).exclude(content__in=content_list).delete() - return avs - else: - content = serialize(value) - if attribute_value: - av, created = attribute_value, False + return avs else: - av, created = AttributeValue.objects.get_or_create( - content_type=ContentType.objects.get_for_model(owner), - object_id=owner.pk, - attribute=self, - multiple=False, - defaults={'content': content, 'verified': verified}) - if not created and (av.content != content or av.verified != verified): - av.content = content - av.verified = verified - av.save() - return av + content = serialize(value) + if attribute_value: + av, created = attribute_value, False + else: + av, created = AttributeValue.objects.get_or_create( + content_type=ContentType.objects.get_for_model(owner), + object_id=owner.pk, + attribute=self, + multiple=False, + defaults={'content': content, 'verified': verified}) + if not created and (av.content != content or av.verified != verified): + av.content = content + av.verified = verified + av.save() + return av def natural_key(self): return (self.name,) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 145d22c3..2107cf37 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -25,11 +25,11 @@ from utils import skipif_sqlite @skipif_sqlite def test_attribute_value_uniqueness(migrations, transactional_db, simple_user, concurrency): - from django.db.transaction import set_autocommit + #from django.db.transaction import set_autocommit # disabled default attributes Attribute.objects.update(disabled=True) - set_autocommit(True) + #set_autocommit(True) acount = Attribute.objects.count() single_at = Attribute.objects.create( @@ -44,23 +44,26 @@ def test_attribute_value_uniqueness(migrations, transactional_db, simple_user, c multiple=True) assert Attribute.objects.count() == acount + 2 - def map_threads(f, l): - threads = [] - for i in l: - threads.append(threading.Thread(target=f, args=(i,))) - threads[-1].start() - for thread in threads: - thread.join() + AttributeValue.objects.all().delete() - def f(i): - simple_user.attributes.multiple = [str(i)] - connection.close() - map_threads(f, range(concurrency)) - map_threads(f, range(concurrency)) - assert AttributeValue.objects.filter(attribute=multiple_at).count() == 1 + for i in range(10): + def map_threads(f, l): + threads = [] + for i in l: + threads.append(threading.Thread(target=f, args=(i,))) + threads[-1].start() + for thread in threads: + thread.join() - def f(i): - simple_user.attributes.single = str(i) - connection.close() - map_threads(f, range(concurrency)) - assert AttributeValue.objects.filter(attribute=single_at).count() == 1 + def f(i): + simple_user.attributes.multiple = [str(i)] + connection.close() + map_threads(f, range(concurrency)) + map_threads(f, range(concurrency)) + assert AttributeValue.objects.filter(attribute=multiple_at).count() == 1 + + def f(i): + simple_user.attributes.single = str(i) + connection.close() + map_threads(f, range(concurrency)) + assert AttributeValue.objects.filter(attribute=single_at).count() == 1 -- 2.23.0