From 1d0eb14c38f8c6341472685bafd48beab88789be Mon Sep 17 00:00:00 2001
From: Paul Marillonnet <pmarillonnet@entrouvert.com>
Date: Fri, 18 Nov 2022 11:08:18 +0100
Subject: [PATCH] api: filter synchronization queryset by ou permissions
 (#71463)

---
 src/authentic2/api_views.py | 22 ++++++++++--
 tests/test_api_client.py    | 69 +++++++++++++++++++++++++++++++++++++
 2 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/src/authentic2/api_views.py b/src/authentic2/api_views.py
index a2026ea73..1555fc2e9 100644
--- a/src/authentic2/api_views.py
+++ b/src/authentic2/api_views.py
@@ -769,6 +769,22 @@ class UsersAPI(api_mixins.GetOrCreateMixinView, HookMixin, ExceptionHandlerMixin
                 queryset = queryset.none()
         return queryset
 
+    def filter_queryset_by_ou_perm(self, perm):
+        queryset = User.objects
+        allowed_ous = []
+
+        if self.request.user.has_perm(perm):
+            return queryset
+
+        for ou in OrganizationalUnit.objects.all():
+            if self.request.user.has_ou_perm(perm, ou):
+                allowed_ous.append(ou)
+        if not allowed_ous:
+            raise PermissionDenied("You do not have permission to perform this action.")
+
+        queryset = queryset.filter(ou__in=allowed_ous)
+        return queryset
+
     def update(self, request, *args, **kwargs):
         kwargs['partial'] = True
         return super().update(request, *args, **kwargs)
@@ -824,15 +840,17 @@ class UsersAPI(api_mixins.GetOrCreateMixinView, HookMixin, ExceptionHandlerMixin
                     modified_users_uuids.add(users_pks[instance_pk].uuid)
         return modified_users_uuids
 
-    @action(detail=False, methods=['post'], permission_classes=(DjangoPermission('custom_user.search_user'),))
+    @action(detail=False, methods=['post'], permission_classes=(permissions.IsAuthenticated,))
     def synchronization(self, request):
         serializer = self.SynchronizationSerializer(data=request.data)
+        users = self.filter_queryset_by_ou_perm('custom_user.search_user')
+
         if not serializer.is_valid():
             response = {'result': 0, 'errors': serializer.errors}
             return Response(response, status.HTTP_400_BAD_REQUEST)
         hooks.call_hooks('api_modify_serializer_after_validation', self, serializer)
         remote_uuids = serializer.validated_data.get('known_uuids', [])
-        users = User.objects.filter(uuid__in=remote_uuids).only('id', 'uuid')
+        users = users.filter(uuid__in=remote_uuids).only('id', 'uuid')
         unknown_uuids = self.check_unknown_uuids(remote_uuids, users)
         data = {
             'result': 1,
diff --git a/tests/test_api_client.py b/tests/test_api_client.py
index 4450e66b2..11a410bf2 100644
--- a/tests/test_api_client.py
+++ b/tests/test_api_client.py
@@ -7,6 +7,7 @@ from django.contrib.contenttypes.models import ContentType
 from django.urls import reverse
 
 from authentic2.a2_rbac.models import ADD_OP, SEARCH_OP, VIEW_OP, Role
+from authentic2.a2_rbac.utils import get_default_ou
 from authentic2.models import APIClient
 
 User = get_user_model()
@@ -52,6 +53,37 @@ def test_api_users_list(app, api_client):
     assert len(resp.json['results']) == 1
 
 
+def test_api_users_list_ou(app, api_client, ou1):
+    user = User.objects.create(username='user1')
+    api_client.ou = ou1
+    api_client.save()
+
+    app.authorization = ('Basic', ('foo', 'bar'))
+    resp = app.get('/api/users/', status=401)
+
+    app.authorization = ('Basic', (api_client.identifier, api_client.password))
+    resp = app.get('/api/users/')
+    assert len(resp.json['results']) == 0
+
+    # give permissions
+    r = Role.objects.get_admin_role(
+        ContentType.objects.get_for_model(User),
+        name='role',
+        slug='role',
+        ou=ou1,
+        operation=VIEW_OP,
+    )
+    api_client.apiclient_roles.add(r)
+    resp = app.get('/api/users/')
+    assert len(resp.json['results']) == 0
+
+    user.ou = ou1
+    user.save()
+
+    resp = app.get('/api/users/')
+    assert len(resp.json['results']) == 1
+
+
 def test_api_user_synchronization(app, api_client):
     uuids = []
     for _ in range(100):
@@ -80,6 +112,43 @@ def test_api_user_synchronization(app, api_client):
     assert set(response.json['unknown_uuids']) == set(unknown_uuids)
 
 
+def test_api_user_synchronization_ou(app, api_client, ou1):
+    uuids = []
+    authorized_uuids = []
+    for index in range(100):
+        ou = ou1 if index % 2 else get_default_ou()
+        user = User.objects.create(first_name='ben', last_name='dauve', ou=ou)
+        uuids.append(user.uuid)
+        if index % 2:
+            authorized_uuids.append(user.uuid)
+    unknown_uuids = [uuid.uuid4().hex for i in range(100)]
+    url = reverse('a2-api-users-synchronization')
+    content = {
+        'known_uuids': uuids + unknown_uuids,
+    }
+    random.shuffle(content['known_uuids'])
+
+    app.authorization = ('Basic', ('foo', 'bar'))
+    response = app.post_json(url, params=content, status=401)
+
+    app.authorization = ('Basic', (api_client.identifier, api_client.password))
+    response = app.post_json(url, params=content, status=403)
+
+    # give custom_user.search_user permission to user
+    r = Role.objects.get_admin_role(
+        ContentType.objects.get_for_model(User),
+        name='role',
+        slug='role',
+        ou=ou1,
+        operation=SEARCH_OP,
+    )
+    api_client.apiclient_roles.add(r)
+    response = app.post_json(url, params=content)
+    assert response.json['result'] == 1
+    assert set(response.json['unknown_uuids']) != set(unknown_uuids)
+    assert set(unknown_uuids).issubset(set(response.json['unknown_uuids']))
+
+
 def test_api_users_create(app, api_client):
     payload = {
         'username': 'janedoe',
-- 
2.38.1

