Projet

Général

Profil

0001-idp_oidc-perform-claim-resolution-when-api-is-called.patch

Paul Marillonnet, 02 juin 2022 14:08

Télécharger (9,75 ko)

Voir les différences:

Subject: [PATCH] idp_oidc: perform claim resolution when api is called by a
 client (#65877)

 src/authentic2/api_views.py     | 13 +++++++
 src/authentic2_idp_oidc/apps.py | 65 +++++++++++++++++++++++++-------
 tests/idp_oidc/test_api.py      | 66 ++++++++++++++++++++++++++++++++-
 3 files changed, 129 insertions(+), 15 deletions(-)
src/authentic2/api_views.py
826 826
        hooks.call_hooks('api_modify_response', self, 'synchronization', data)
827 827
        return Response(data)
828 828

  
829
    def list(self, request, *args, **kwargs):
830
        queryset = self.filter_queryset(self.get_queryset())
831

  
832
        page = self.paginate_queryset(queryset)
833
        if page is not None:
834
            serializer = self.get_serializer(page, many=True)
835
            hooks.call_hooks('api_modify_serializer_after_validation', self, serializer)
836
            return self.get_paginated_response(serializer.data)
837

  
838
        serializer = self.get_serializer(queryset, many=True)
839
        hooks.call_hooks('api_modify_serializer_after_validation', self, serializer)
840
        return Response(serializer.data)
841

  
829 842
    @action(
830 843
        detail=True,
831 844
        methods=['post'],
src/authentic2_idp_oidc/apps.py
14 14
# You should have received a copy of the GNU Affero General Public License
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17
import copy
18

  
17 19
import django.apps
18 20
from django.template.loader import render_to_string
19 21
from django.utils.encoding import smart_bytes
......
49 51
    # implement translation of encrypted pairwise identifiers when an OIDC Client is using the
50 52
    # A2 API
51 53
    def a2_hook_api_modify_serializer(self, view, serializer):
54
        from django.utils.timezone import utc
52 55
        from rest_framework import serializers
53 56

  
57
        from authentic2.utils.template import Template
58

  
54 59
        from . import utils
60
        from .models import OIDCClaim
55 61

  
56 62
        if hasattr(view.request.user, 'oidc_client'):
57 63
            client = view.request.user.oidc_client
......
63 69
                serializer.get_oidc_uuid = get_oidc_uuuid
64 70
                serializer.fields['uuid'] = serializers.SerializerMethodField(method_name='get_oidc_uuid')
65 71

  
72
            # /api/users/ returned content needs oidc claim resolution
73
            if view.__class__.__name__ == 'UsersAPI' and 'claim_resolution' in view.request.GET:
74
                # use deepcopy to prevent overwrite of field.field_name
75
                # see: https://github.com/encode/django-rest-framework/blob/bce9df9b5e0f54a6076519835393fea59accb40c/rest_framework/utils/serializer_helpers.py#L169
76
                serializer.fields['sub'] = copy.deepcopy(serializer.fields['uuid'])
77
                del serializer.fields['is_superuser']
78
                del serializer.fields['is_staff']
79
                del serializer.fields['password']
80
                # forbid modification of email through POST/PATCH/PUT
81
                if serializer.instance:
82
                    serializer.fields['email'].read_only = True
83
                for claim in OIDCClaim.objects.filter(client=client):
84
                    if claim.name in serializer.fields:
85
                        continue
86
                    serializer.fields[claim.name] = serializers.CharField(
87
                        read_only=True,
88
                    )
89

  
66 90
    @classmethod
67 91
    def get_oidc_client(cls, view):
68 92
        request = view.request
......
93 117
    def a2_hook_api_modify_serializer_after_validation(self, view, serializer):
94 118
        import uuid
95 119

  
120
        from authentic2.utils.template import Template
121

  
96 122
        from . import utils
123
        from .models import OIDCClaim
97 124

  
98 125
        if view.__class__.__name__ != 'UsersAPI':
99 126
            return
100
        if serializer.__class__.__name__ != 'SynchronizationSerializer':
127
        if serializer.__class__.__name__ not in ('SynchronizationSerializer', 'ListSerializer'):
101 128
            return
102 129
        request = view.request
103 130
        if not hasattr(request.user, 'oidc_client'):
......
105 132
        client = request.user.oidc_client
106 133
        if client.identifier_policy != client.POLICY_PAIRWISE_REVERSIBLE:
107 134
            return
108
        new_known_uuids = []
109
        uuid_map = request.uuid_map = {}
110
        request.unknown_uuids = []
111
        for u in serializer.validated_data['known_uuids']:
112
            decrypted = utils.reverse_pairwise_sub(client, smart_bytes(u))
113
            if decrypted:
114
                new_known_uuid = uuid.UUID(bytes=decrypted).hex
115
                new_known_uuids.append(new_known_uuid)
116
                uuid_map[new_known_uuid] = u
117
            else:
118
                request.unknown_uuids.append(u)
119
            # undecipherable sub are just not checked at all
120
        serializer.validated_data['known_uuids'] = new_known_uuids
135
        if serializer.__class__.__name__ == 'SynchronizationSerializer':
136
            new_known_uuids = []
137
            request.unknown_uuids = []
138
            uuid_map = request.uuid_map = {}
139
            for u in serializer.validated_data['known_uuids']:
140
                decrypted = utils.reverse_pairwise_sub(client, smart_bytes(u))
141
                if decrypted:
142
                    new_known_uuid = uuid.UUID(bytes=decrypted).hex
143
                    new_known_uuids.append(new_known_uuid)
144
                    uuid_map[new_known_uuid] = u
145
                else:
146
                    # undecipherable sub are just not checked at all
147
                    request.unknown_uuids.append(u)
148
            serializer.validated_data['known_uuids'] = new_known_uuids
149
        elif serializer.__class__.__name__ == 'ListSerializer' and 'claim_resolution' in view.request.GET:
150
            for user_dict in serializer.data:
151
                context = user_dict.copy()
152
                for claim in OIDCClaim.objects.filter(client=client):
153
                    value = claim.value
154
                    if claim.value and ('{{' in claim.value or '{%' in claim.value):
155
                        template = Template(claim.value)
156
                        value = template.render(context=context)
157
                    user_dict[claim.name] = value
121 158

  
122 159
    def a2_hook_api_modify_response(self, view, method_name, data):
123 160
        """Reverse mapping applied in a2_hook_api_modify_serializer_after_validation using the
tests/idp_oidc/test_api.py
14 14
# You should have received a copy of the GNU Affero General Public License
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17
import random
18

  
19
from django.utils.timezone import now
20

  
17 21
from authentic2.custom_user.models import User
18
from authentic2_idp_oidc.models import OIDCClient
22
from authentic2_idp_oidc.models import OIDCClaim, OIDCClient
19 23
from authentic2_idp_oidc.utils import make_sub
20 24

  
21 25

  
......
39 43
    if status == 200:
40 44
        assert response.json['result'] == 1
41 45
        assert set(response.json['unknown_uuids']) == deleted_subs
46

  
47

  
48
def test_api_users_list_claim_resolution(app, oidc_client):
49
    oidc_client.has_api_access = True
50
    oidc_client.identifier_policy = OIDCClient.POLICY_PAIRWISE
51
    oidc_client.save()
52

  
53
    # default claims are boring, create custom templated ones
54
    OIDCClaim.objects.all().delete()
55
    OIDCClaim.objects.create(
56
        name='given_name',
57
        value='Templated {{ first_name }}',
58
        scopes='profile',
59
        client=oidc_client,
60
    )
61
    OIDCClaim.objects.create(
62
        name='family_name',
63
        value='Templated {{ last_name }}',
64
        scopes='profile',
65
        client=oidc_client,
66
    )
67
    OIDCClaim.objects.create(
68
        name='email',
69
        value='{{ last_name }}@templated.nowhere.null',
70
        scopes='email',
71
        client=oidc_client,
72
    )
73

  
74
    users = [User.objects.create(username=f'user-{i}', last_name=f'Name-{i}') for i in range(10)]
75
    pre_modification = now().strftime('%Y-%m-%dT%H:%M:%S')
76
    for count, user in enumerate(users):
77
        user.first_name = f'User {count}'
78
        user.save()
79

  
80
    app.authorization = ('Basic', (oidc_client.client_id, oidc_client.client_secret))
81
    app.get(
82
        f'/api/users/?modified__gt={pre_modification}&claim_resolution',
83
        status=401,
84
    )
85

  
86
    oidc_client.identifier_policy = OIDCClient.POLICY_PAIRWISE_REVERSIBLE
87
    oidc_client.save()
88

  
89
    response = app.get(
90
        f'/api/users/?modified__gt={pre_modification}&claim_resolution',
91
        status=200,
92
    )
93

  
94
    for user_dict in random.choices(response.json['results'], k=3):
95
        assert user_dict['last_name']
96
        assert user_dict['family_name'].startswith('Templated')
97
        assert user_dict['family_name'].endswith(user_dict['last_name'])
98

  
99
        assert user_dict['first_name']
100
        assert user_dict['given_name'].startswith('Templated')
101
        assert user_dict['given_name'].endswith(user_dict['first_name'])
102

  
103
        assert user_dict['email']
104
        assert user_dict['email'].startswith(user_dict['last_name'])
105
        assert user_dict['email'].endswith('@templated.nowhere.null')
42
-