From 64be2f52a83890d9811c8cb8aa2999376bbf3936 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Fri, 15 Mar 2019 03:17:04 +0100 Subject: [PATCH] api: accept get_or_create parameter to user creation endpoint (fixes #22376) Done using, hopefully thread-safe, monkeypatching of model's manager create() method. --- src/authentic2/api_mixins.py | 64 ++++++++++++++++++++++++++++++++++++ src/authentic2/api_views.py | 6 ++-- tests/test_api.py | 19 +++++++++-- 3 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 src/authentic2/api_mixins.py diff --git a/src/authentic2/api_mixins.py b/src/authentic2/api_mixins.py new file mode 100644 index 00000000..28152b5b --- /dev/null +++ b/src/authentic2/api_mixins.py @@ -0,0 +1,64 @@ +# authentic2 - versatile identity manager +# Copyright (C) 2010-2018 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import threading +import contextlib + + +@contextlib.contextmanager +def monkeypatch_method(instance, attribute, new_value): + current_thread_id = threading.current_thread().ident + old_value = getattr(instance, attribute) + try: + def f(*args, **kwargs): + if threading.current_thread().ident == current_thread_id: + return new_value(*args, **kwargs) + else: + return old_value(*args, **kwargs) + setattr(instance, attribute, f) + yield + finally: + try: + delattr(instance, attribute) + except AttributeError: + pass + + +class GetOrCreateModelSerializer(object): + def get_or_create(self, keys, validated_data): + def get_or_create(**validated_data): + kwargs = {} + defaults = kwargs['defaults'] = {} + missing_keys = set(keys) - set(validated_data) + if missing_keys: + raise TypeError('Keys %s are not a writable field' % keys) + for key, value in validated_data.items(): + if key in keys: + kwargs[key] = value + else: + defaults[key] = value + return self.Meta.model.objects.get_or_create(**kwargs)[0] + with monkeypatch_method(self.Meta.model.objects, 'create', get_or_create): + return super(GetOrCreateModelSerializer, self).create(validated_data) + + def create(self, validated_data): + try: + keys = self.context['view'].request.GET.getlist('get_or_create') + except Exception: + pass + else: + return self.get_or_create(keys, validated_data) + return super(GetOrCreateModelSerializer, self).create(validated_data) diff --git a/src/authentic2/api_views.py b/src/authentic2/api_views.py index 54e3a626..05c8d1a8 100644 --- a/src/authentic2/api_views.py +++ b/src/authentic2/api_views.py @@ -46,7 +46,8 @@ from django_filters.rest_framework import FilterSet from .passwords import get_password_checker from .custom_user.models import User -from . import utils, decorators, attribute_kinds, app_settings, hooks +from . import (utils, decorators, attribute_kinds, app_settings, hooks, + api_mixins) from .models import Attribute, PasswordReset, Service from .a2_rbac.utils import get_default_ou @@ -321,7 +322,8 @@ def user(request): return request.user.to_json() -class BaseUserSerializer(serializers.ModelSerializer): +class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer, + serializers.ModelSerializer): ou = serializers.SlugRelatedField( queryset=get_ou_model().objects.all(), slug_field='slug', diff --git a/tests/test_api.py b/tests/test_api.py index cfd9fe6a..6b70f4b9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -16,12 +16,12 @@ from authentic2.models import Service from django.core import mail from django.contrib.auth.hashers import check_password -from authentic2_idp_oidc.models import OIDCClient - from utils import login, basic_authorization_header, get_link_from_mail pytestmark = pytest.mark.django_db +User = get_user_model() + def test_api_user_simple(logged_app): resp = logged_app.get('/api/user/') @@ -991,3 +991,18 @@ def test_validate_password_regex(app, settings): assert response.json['checks'][3]['result'] is True assert response.json['checks'][4]['label'] == 'must contain "ok"' assert response.json['checks'][4]['result'] is True + + +def test_api_users_get_or_create(settings, app, admin): + app.authorization = ('Basic', (admin.username, admin.username)) + # test missing first_name + payload = { + 'email': 'john.doe@example.net', + 'first_name': 'John', + 'last_name': 'Doe', + } + resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201) + id = resp.json['id'] + + resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201) + assert id == resp.json['id'] -- 2.20.1