Projet

Général

Profil

0001-api-accept-get_or_create-parameter-to-user-creation-.patch

Benjamin Dauvergne, 15 mars 2019 03:19

Télécharger (5,52 ko)

Voir les différences:

Subject: [PATCH] api: accept get_or_create parameter to user creation endpoint
 (fixes #22376)

Done using, hopefully thread-safe, monkeypatching of model's manager
create() method.
 src/authentic2/api_mixins.py | 64 ++++++++++++++++++++++++++++++++++++
 src/authentic2/api_views.py  |  6 ++--
 tests/test_api.py            | 19 +++++++++--
 3 files changed, 85 insertions(+), 4 deletions(-)
 create mode 100644 src/authentic2/api_mixins.py
src/authentic2/api_mixins.py
1
# authentic2 - versatile identity manager
2
# Copyright (C) 2010-2018 Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import threading
18
import contextlib
19

  
20

  
21
@contextlib.contextmanager
22
def monkeypatch_method(instance, attribute, new_value):
23
    current_thread_id = threading.current_thread().ident
24
    old_value = getattr(instance, attribute)
25
    try:
26
        def f(*args, **kwargs):
27
            if threading.current_thread().ident == current_thread_id:
28
                return new_value(*args, **kwargs)
29
            else:
30
                return old_value(*args, **kwargs)
31
        setattr(instance, attribute, f)
32
        yield
33
    finally:
34
        try:
35
            delattr(instance, attribute)
36
        except AttributeError:
37
            pass
38

  
39

  
40
class GetOrCreateModelSerializer(object):
41
    def get_or_create(self, keys, validated_data):
42
        def get_or_create(**validated_data):
43
            kwargs = {}
44
            defaults = kwargs['defaults'] = {}
45
            missing_keys = set(keys) - set(validated_data)
46
            if missing_keys:
47
                raise TypeError('Keys %s are not a writable field' % keys)
48
            for key, value in validated_data.items():
49
                if key in keys:
50
                    kwargs[key] = value
51
                else:
52
                    defaults[key] = value
53
            return self.Meta.model.objects.get_or_create(**kwargs)[0]
54
        with monkeypatch_method(self.Meta.model.objects, 'create', get_or_create):
55
            return super(GetOrCreateModelSerializer, self).create(validated_data)
56

  
57
    def create(self, validated_data):
58
        try:
59
            keys = self.context['view'].request.GET.getlist('get_or_create')
60
        except Exception:
61
            pass
62
        else:
63
            return self.get_or_create(keys, validated_data)
64
        return super(GetOrCreateModelSerializer, self).create(validated_data)
src/authentic2/api_views.py
46 46

  
47 47
from .passwords import get_password_checker
48 48
from .custom_user.models import User
49
from . import utils, decorators, attribute_kinds, app_settings, hooks
49
from . import (utils, decorators, attribute_kinds, app_settings, hooks,
50
               api_mixins)
50 51
from .models import Attribute, PasswordReset, Service
51 52
from .a2_rbac.utils import get_default_ou
52 53

  
......
321 322
    return request.user.to_json()
322 323

  
323 324

  
324
class BaseUserSerializer(serializers.ModelSerializer):
325
class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer,
326
                         serializers.ModelSerializer):
325 327
    ou = serializers.SlugRelatedField(
326 328
        queryset=get_ou_model().objects.all(),
327 329
        slug_field='slug',
tests/test_api.py
16 16
from django.core import mail
17 17
from django.contrib.auth.hashers import check_password
18 18

  
19
from authentic2_idp_oidc.models import OIDCClient
20

  
21 19
from utils import login, basic_authorization_header, get_link_from_mail
22 20

  
23 21
pytestmark = pytest.mark.django_db
24 22

  
23
User = get_user_model()
24

  
25 25

  
26 26
def test_api_user_simple(logged_app):
27 27
    resp = logged_app.get('/api/user/')
......
991 991
    assert response.json['checks'][3]['result'] is True
992 992
    assert response.json['checks'][4]['label'] == 'must contain "ok"'
993 993
    assert response.json['checks'][4]['result'] is True
994

  
995

  
996
def test_api_users_get_or_create(settings, app, admin):
997
    app.authorization = ('Basic', (admin.username, admin.username))
998
    # test missing first_name
999
    payload = {
1000
        'email': 'john.doe@example.net',
1001
        'first_name': 'John',
1002
        'last_name': 'Doe',
1003
    }
1004
    resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
1005
    id = resp.json['id']
1006

  
1007
    resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201)
1008
    assert id == resp.json['id']
994
-