0001-api-accept-get_or_create-parameter-to-user-creation-.patch
src/authentic2/api_mixins.py | ||
---|---|---|
1 |
# authentic2 - versatile identity manager |
|
2 |
# Copyright (C) 2010-2018 Entr'ouvert |
|
3 |
# |
|
4 |
# This program is free software: you can redistribute it and/or modify it |
|
5 |
# under the terms of the GNU Affero General Public License as published |
|
6 |
# by the Free Software Foundation, either version 3 of the License, or |
|
7 |
# (at your option) any later version. |
|
8 |
# |
|
9 |
# This program is distributed in the hope that it will be useful, |
|
10 |
# but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
11 |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
12 |
# GNU Affero General Public License for more details. |
|
13 |
# |
|
14 |
# You should have received a copy of the GNU Affero General Public License |
|
15 |
# along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
16 | ||
17 |
import threading |
|
18 |
import contextlib |
|
19 | ||
20 | ||
21 |
@contextlib.contextmanager |
|
22 |
def monkeypatch_method(instance, attribute, new_value): |
|
23 |
current_thread_id = threading.current_thread().ident |
|
24 |
old_value = getattr(instance, attribute) |
|
25 |
try: |
|
26 |
def f(*args, **kwargs): |
|
27 |
if threading.current_thread().ident == current_thread_id: |
|
28 |
return new_value(*args, **kwargs) |
|
29 |
else: |
|
30 |
return old_value(*args, **kwargs) |
|
31 |
setattr(instance, attribute, f) |
|
32 |
yield |
|
33 |
finally: |
|
34 |
try: |
|
35 |
delattr(instance, attribute) |
|
36 |
except AttributeError: |
|
37 |
pass |
|
38 | ||
39 | ||
40 |
class GetOrCreateModelSerializer(object): |
|
41 |
def get_or_create(self, keys, validated_data): |
|
42 |
def get_or_create(**validated_data): |
|
43 |
kwargs = {} |
|
44 |
defaults = kwargs['defaults'] = {} |
|
45 |
missing_keys = set(keys) - set(validated_data) |
|
46 |
if missing_keys: |
|
47 |
raise TypeError('Keys %s are not a writable field' % keys) |
|
48 |
for key, value in validated_data.items(): |
|
49 |
if key in keys: |
|
50 |
kwargs[key] = value |
|
51 |
else: |
|
52 |
defaults[key] = value |
|
53 |
return self.Meta.model.objects.get_or_create(**kwargs)[0] |
|
54 |
with monkeypatch_method(self.Meta.model.objects, 'create', get_or_create): |
|
55 |
return super(GetOrCreateModelSerializer, self).create(validated_data) |
|
56 | ||
57 |
def create(self, validated_data): |
|
58 |
try: |
|
59 |
keys = self.context['view'].request.GET.getlist('get_or_create') |
|
60 |
except Exception: |
|
61 |
pass |
|
62 |
else: |
|
63 |
return self.get_or_create(keys, validated_data) |
|
64 |
return super(GetOrCreateModelSerializer, self).create(validated_data) |
src/authentic2/api_views.py | ||
---|---|---|
46 | 46 | |
47 | 47 |
from .passwords import get_password_checker |
48 | 48 |
from .custom_user.models import User |
49 |
from . import utils, decorators, attribute_kinds, app_settings, hooks |
|
49 |
from . import (utils, decorators, attribute_kinds, app_settings, hooks, |
|
50 |
api_mixins) |
|
50 | 51 |
from .models import Attribute, PasswordReset, Service |
51 | 52 |
from .a2_rbac.utils import get_default_ou |
52 | 53 | |
... | ... | |
321 | 322 |
return request.user.to_json() |
322 | 323 | |
323 | 324 | |
324 |
class BaseUserSerializer(serializers.ModelSerializer): |
|
325 |
class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer, |
|
326 |
serializers.ModelSerializer): |
|
325 | 327 |
ou = serializers.SlugRelatedField( |
326 | 328 |
queryset=get_ou_model().objects.all(), |
327 | 329 |
slug_field='slug', |
tests/test_api.py | ||
---|---|---|
16 | 16 |
from django.core import mail |
17 | 17 |
from django.contrib.auth.hashers import check_password |
18 | 18 | |
19 |
from authentic2_idp_oidc.models import OIDCClient |
|
20 | ||
21 | 19 |
from utils import login, basic_authorization_header, get_link_from_mail |
22 | 20 | |
23 | 21 |
pytestmark = pytest.mark.django_db |
24 | 22 | |
23 |
User = get_user_model() |
|
24 | ||
25 | 25 | |
26 | 26 |
def test_api_user_simple(logged_app): |
27 | 27 |
resp = logged_app.get('/api/user/') |
... | ... | |
991 | 991 |
assert response.json['checks'][3]['result'] is True |
992 | 992 |
assert response.json['checks'][4]['label'] == 'must contain "ok"' |
993 | 993 |
assert response.json['checks'][4]['result'] is True |
994 | ||
995 | ||
996 |
def test_api_users_get_or_create(settings, app, admin): |
|
997 |
app.authorization = ('Basic', (admin.username, admin.username)) |
|
998 |
# test missing first_name |
|
999 |
payload = { |
|
1000 |
'email': 'john.doe@example.net', |
|
1001 |
'first_name': 'John', |
|
1002 |
'last_name': 'Doe', |
|
1003 |
} |
|
1004 |
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201) |
|
1005 |
id = resp.json['id'] |
|
1006 | ||
1007 |
resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201) |
|
1008 |
assert id == resp.json['id'] |
|
994 |
- |