From d03f4fc8d3da9c061f52f01580ecd5f733c977c9 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Fri, 15 Mar 2019 03:17:04 +0100 Subject: [PATCH] api: accept get/update_or_create parameter to user and role creation endpoint (fixes #22376) --- src/authentic2/api_mixins.py | 100 +++++++++++++++++++++++++++++++++++ src/authentic2/api_views.py | 8 +-- tests/test_api.py | 89 ++++++++++++++++++++++++++++--- 3 files changed, 187 insertions(+), 10 deletions(-) create mode 100644 src/authentic2/api_mixins.py diff --git a/src/authentic2/api_mixins.py b/src/authentic2/api_mixins.py new file mode 100644 index 00000000..74c6740c --- /dev/null +++ b/src/authentic2/api_mixins.py @@ -0,0 +1,100 @@ +# authentic2 - versatile identity manager +# Copyright (C) 2010-2018 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from django.db import transaction + +from rest_framework.serializers import raise_errors_on_nested_writes +from rest_framework.utils import model_meta + + +class GetOrCreateModelSerializer(object): + def get_or_create(self, keys, validated_data): + raise_errors_on_nested_writes('get_or_create', self, validated_data) + + ModelClass = self.Meta.model + + # Remove many-to-many relationships from validated_data. + # They are not valid arguments to the default `.create()` method, + # as they require that the instance has already been saved. + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) + + kwargs = {} + defaults = kwargs['defaults'] = {} + missing_keys = set(keys) - set(validated_data) + if missing_keys: + raise TypeError('Keys %s are missing' % missing_keys) + for key, value in validated_data.items(): + if key in keys: + kwargs[key] = value + else: + defaults[key] = value + with transaction.atomic(): + instance, created = self.Meta.model._default_manager.get_or_create(**kwargs) + if many_to_many and created: + self.update(instance, many_to_many) + return instance + + def update_or_create(self, keys, validated_data): + raise_errors_on_nested_writes('update_or_create', self, validated_data) + + ModelClass = self.Meta.model + + # Remove many-to-many relationships from validated_data. + # They are not valid arguments to the default `.create()` method, + # as they require that the instance has already been saved. + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + get_or_create_data = validated_data.copy() + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = get_or_create_data.pop(field_name) + + kwargs = {} + defaults = kwargs['defaults'] = {} + missing_keys = set(keys) - set(get_or_create_data) + if missing_keys: + raise TypeError('Keys %s are missing' % missing_keys) + for key, value in get_or_create_data.items(): + if key in keys: + kwargs[key] = value + else: + defaults[key] = value + with transaction.atomic(): + instance, created = self.Meta.model._default_manager.get_or_create(**kwargs) + if many_to_many or not created: + self.update(instance, validated_data) + return instance + + def create(self, validated_data): + try: + keys = self.context['view'].request.GET.getlist('get_or_create') + except Exception: + pass + else: + if keys: + return self.get_or_create(keys, validated_data) + try: + keys = self.context['view'].request.GET.getlist('update_or_create') + except Exception: + pass + else: + if keys: + return self.update_or_create(keys, validated_data) + return super(GetOrCreateModelSerializer, self).create(validated_data) diff --git a/src/authentic2/api_views.py b/src/authentic2/api_views.py index 48f95cc4..439062ea 100644 --- a/src/authentic2/api_views.py +++ b/src/authentic2/api_views.py @@ -46,7 +46,8 @@ from django_filters.rest_framework import FilterSet from .passwords import get_password_checker from .custom_user.models import User -from . import utils, decorators, attribute_kinds, app_settings, hooks +from . import (utils, decorators, attribute_kinds, app_settings, hooks, + api_mixins) from .models import Attribute, PasswordReset, Service from .a2_rbac.utils import get_default_ou @@ -321,7 +322,8 @@ def user(request): return request.user.to_json() -class BaseUserSerializer(serializers.ModelSerializer): +class BaseUserSerializer(api_mixins.GetOrCreateModelSerializer, + serializers.ModelSerializer): ou = serializers.SlugRelatedField( queryset=get_ou_model().objects.all(), slug_field='slug', @@ -490,7 +492,7 @@ class BaseUserSerializer(serializers.ModelSerializer): exclude = ('date_joined', 'user_permissions', 'groups', 'last_login') -class RoleSerializer(serializers.ModelSerializer): +class RoleSerializer(api_mixins.GetOrCreateModelSerializer, serializers.ModelSerializer): ou = serializers.SlugRelatedField( many=False, required=False, diff --git a/tests/test_api.py b/tests/test_api.py index 013c1b8d..cf3be3cb 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -22,22 +22,25 @@ import random import uuid -from django.core.urlresolvers import reverse +from django.contrib.auth.hashers import check_password from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType -from authentic2.a2_rbac.utils import get_default_ou -from django_rbac.utils import get_role_model, get_ou_model -from django_rbac.models import SEARCH_OP -from authentic2.models import Service from django.core import mail -from django.contrib.auth.hashers import check_password +from django.core.urlresolvers import reverse -from authentic2_idp_oidc.models import OIDCClient +from django_rbac.models import SEARCH_OP +from django_rbac.utils import get_role_model, get_ou_model + +from authentic2.a2_rbac.models import Role +from authentic2.a2_rbac.utils import get_default_ou +from authentic2.models import Service from utils import login, basic_authorization_header, get_link_from_mail pytestmark = pytest.mark.django_db +User = get_user_model() + def test_api_user_simple(logged_app): resp = logged_app.get('/api/user/') @@ -1146,3 +1149,75 @@ def test_validate_password_regex(app, settings): assert response.json['checks'][3]['result'] is True assert response.json['checks'][4]['label'] == 'must contain "ok"' assert response.json['checks'][4]['result'] is True + + +def test_api_users_get_or_create(settings, app, admin): + app.authorization = ('Basic', (admin.username, admin.username)) + # test missing first_name + payload = { + 'email': 'john.doe@example.net', + 'first_name': 'John', + 'last_name': 'Doe', + } + resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201) + id = resp.json['id'] + assert User.objects.get(id=id).first_name == 'John' + assert User.objects.get(id=id).last_name == 'Doe' + + resp = app.post_json('/api/users/?get_or_create=email', params=payload, status=201) + assert id == resp.json['id'] + assert User.objects.get(id=id).first_name == 'John' + assert User.objects.get(id=id).last_name == 'Doe' + + payload['first_name'] = 'Jane' + resp = app.post_json('/api/users/?update_or_create=email', params=payload, status=201) + assert id == resp.json['id'] + assert User.objects.get(id=id).first_name == 'Jane' + assert User.objects.get(id=id).last_name == 'Doe' + + +def test_api_users_get_or_create_multi_key(settings, app, admin): + app.authorization = ('Basic', (admin.username, admin.username)) + # test missing first_name + payload = { + 'email': 'john.doe@example.net', + 'first_name': 'John', + 'last_name': 'Doe', + } + resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=201) + id = resp.json['id'] + assert User.objects.get(id=id).first_name == 'John' + assert User.objects.get(id=id).last_name == 'Doe' + + resp = app.post_json('/api/users/?get_or_create=first_name&get_or_create=last_name', params=payload, status=201) + assert id == resp.json['id'] + assert User.objects.get(id=id).first_name == 'John' + assert User.objects.get(id=id).last_name == 'Doe' + + payload['email'] = 'john.doe@example2.net' + resp = app.post_json('/api/users/?update_or_create=first_name&update_or_create=last_name', params=payload, status=201) + assert id == resp.json['id'] + assert User.objects.get(id=id).email == 'john.doe@example2.net' + + +def test_api_roles_get_or_create(settings, ou1, app, admin): + app.authorization = ('Basic', (admin.username, admin.username)) + # test missing first_name + payload = { + 'ou_slug': 'ou1', + 'name': 'Role 1', + 'slug': 'role-1', + } + resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=201) + uuid = resp.json['uuid'] + assert Role.objects.get(uuid=uuid).name == 'Role 1' + assert Role.objects.get(uuid=uuid).slug == 'role-1' + + resp = app.post_json('/api/roles/?get_or_create=slug', params=payload, status=201) + assert uuid == resp.json['uuid'] + + payload['name'] = 'Role 2' + resp = app.post_json('/api/roles/?update_or_create=slug', params=payload, status=201) + assert uuid == resp.json['uuid'] + assert Role.objects.get(uuid=uuid).name == 'Role 2' + assert Role.objects.get(uuid=uuid).slug == 'role-1' -- 2.20.1