Projet

Général

Profil

Support #45845 » __init__.py

Maxime TEILLAUD, 12 août 2020 11:50

 
1
# Copyright (C) 2019  Entr'ouvert
2
#
3
# This program is free software: you can redistribute it and/or modify it
4
# under the terms of the GNU Affero General Public License as published
5
# by the Free Software Foundation, either version 3 of the License, or
6
# (at your option) any later version.
7
#
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
# GNU Affero General Public License for more details.
12
#
13
# You should have received a copy of the GNU Affero General Public License
14
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
15

    
16
from __future__ import absolute_import
17

    
18
import base64
19
from functools import wraps
20
import hashlib
21
import re
22
from itertools import islice, chain
23
import warnings
24

    
25
from requests import Session as RequestSession, Response as RequestResponse
26
from requests.adapters import HTTPAdapter
27
from requests.structures import CaseInsensitiveDict
28
from urllib3.exceptions import InsecureRequestWarning
29

    
30
from django.conf import settings
31
from django.core.cache import cache
32
from django.core.exceptions import PermissionDenied
33
from django.http import HttpResponse, HttpResponseBadRequest
34
from django.template import Template, Context
35
from django.utils.encoding import force_bytes, force_text
36
from django.utils.functional import lazy
37
from django.utils.html import mark_safe
38
from io import BytesIO
39
from django.views.generic.detail import SingleObjectMixin
40
from django.contrib.contenttypes.models import ContentType
41
from django.db import transaction
42
from django.utils.decorators import available_attrs
43

    
44
from passerelle.base.signature import check_query, check_url
45

    
46

    
47
# mark_safe() deferred through django.utils.functional.lazy: calling
# mark_safe_lazy(...) returns a lazy proxy whose result class is str, and the
# real mark_safe() call only happens when that proxy is evaluated.
mark_safe_lazy = lazy(mark_safe, str)


def response_for_json(request, data):
    '''Serialize *data* as JSON into an HttpResponse.

    When the query string carries a ``jsonpCallback`` or ``callback``
    parameter (looked up in that order, first one present wins), the payload
    is wrapped as the JSONP call ``<callback>(<json>);`` and served as
    ``application/javascript``. The callback must be a bare JavaScript
    identifier, otherwise an HttpResponseBadRequest is returned.
    '''
    import json

    response = HttpResponse(content_type='application/json')
    payload = json.dumps(data)
    present = [name for name in ('jsonpCallback', 'callback') if name in request.GET]
    if present:
        callback = request.GET[present[0]]
        # only a plain identifier is accepted, so the query string cannot
        # inject arbitrary script into the response
        if re.match(r'^[$A-Za-z_][0-9A-Za-z_$]*$', callback) is None:
            return HttpResponseBadRequest('invalid JSONP callback name')
        payload = '%s(%s);' % (callback, payload)
        response['Content-Type'] = 'application/javascript'
    response.write(payload)
    return response
65

    
66

    
67
def get_request_users(request):
    '''Return the list of ApiUser objects that *request* authenticates as.

    Users whose keytype is empty are always candidates. At most one
    credential source is then examined, by order of precedence:

    * ``orig`` and ``signature`` query parameters: SIGN users named ``orig``
      whose key validates the query string with check_query();
    * ``apikey`` query parameter: API users whose key equals it;
    * ``Authorization: Basic`` header: SIGN users whose username and key
      match the decoded credentials (a malformed header is ignored).

    Finally a candidate with a non-empty ``ipsource`` is kept only when it
    equals the request's REMOTE_ADDR.
    '''
    from passerelle.base.models import ApiUser

    candidates = list(ApiUser.objects.filter(keytype=''))
    params = request.GET
    meta = request.META

    if 'orig' in params and 'signature' in params:
        signer = params['orig']
        query_string = meta['QUERY_STRING']
        candidates += [
            user
            for user in ApiUser.objects.filter(keytype='SIGN', username=signer)
            if check_query(query_string, user.key)
        ]
    elif 'apikey' in params:
        candidates += ApiUser.objects.filter(keytype='API', key=params['apikey'])
    elif 'HTTP_AUTHORIZATION' in meta:
        scheme, separator, credentials = meta['HTTP_AUTHORIZATION'].partition(' ')
        if scheme.lower() == 'basic' and separator:
            try:
                decoded = force_text(base64.b64decode(force_bytes(credentials.strip())))
                username, password = decoded.split(':', 1)
            except (TypeError, ValueError):
                # bad base64, undecodable bytes or no ':' separator: ignore the header
                pass
            else:
                candidates += ApiUser.objects.filter(keytype='SIGN', username=username, key=password)

    remote_addr = meta.get('REMOTE_ADDR')
    # an empty ipsource means the user may connect from any address
    return [user for user in candidates if not user.ipsource or user.ipsource == remote_addr]
107

    
108

    
109
def get_trusted_services():
    '''List the "trusted" services declared in settings.KNOWN_SERVICES.

    A service is trusted when both its ``secret`` and ``verif_orig`` entries
    are truthy. Each returned item is a copy of the service dictionary with
    two extra keys: ``service_type`` (its KNOWN_SERVICES group) and ``slug``.
    '''
    trusted = []
    known_services = getattr(settings, 'KNOWN_SERVICES', {})
    for service_type, services in known_services.items():
        for slug, service in services.items():
            if not (service.get('secret') and service.get('verif_orig')):
                continue
            entry = service.copy()
            entry['service_type'] = service_type
            entry['slug'] = slug
            trusted.append(entry)
    return trusted
122

    
123

    
124
def is_trusted(request):
    '''Tell whether the query string is signed by a trusted service.

    ``orig`` and ``signature`` must both be present and non-empty. The full
    path is then checked with check_url() against the secret of each service
    from get_trusted_services() whose ``verif_orig`` equals ``orig``.
    '''
    params = request.GET
    origin = params.get('orig')
    if not (origin and params.get('signature')):
        return False
    full_path = request.get_full_path()
    return any(
        service.get('verif_orig') == origin
        and service.get('secret')
        and check_url(full_path, service['secret'])
        for service in get_trusted_services()
    )
137

    
138

    
139
def is_authorized(request, obj, perm):
    '''Tell whether *request* is granted permission *perm* on resource *obj*.

    Superusers and requests signed by a trusted service (is_trusted()) are
    always authorized and get ``True``. Otherwise the result is the set of
    ApiUsers holding an AccessRight *perm* on *obj* that also authenticate
    the request (get_request_users()). Callers such as protected_api() test
    its truthiness, so an empty set means "denied".
    '''
    from passerelle.base.models import AccessRight

    if request.user.is_superuser:
        return True
    if is_trusted(request):
        return True
    resource_type = ContentType.objects.get_for_model(obj)
    # select_related loads the api users with the rights in a single query
    # instead of issuing one extra query per access right
    rights = AccessRight.objects.filter(
        resource_type=resource_type, resource_pk=obj.id, codename=perm
    ).select_related('apiuser')
    granted_users = set(right.apiuser for right in rights)
    if not granted_users:
        # nobody holds this permission: no need to authenticate the request
        return granted_users
    return granted_users.intersection(get_request_users(request))
150

    
151

    
152
def protected_api(perm):
    '''Build a decorator restricting a class-based view method to *perm*.

    The decorated method must belong to a SingleObjectMixin view. The object
    returned by its get_object() is checked with is_authorized(), and
    PermissionDenied is raised when access is refused.
    '''
    def protect(view_func):
        @wraps(view_func, assigned=available_attrs(view_func))
        def _guarded(instance, request, *args, **kwargs):
            if not isinstance(instance, SingleObjectMixin):
                raise Exception("protected_api must be applied on a method of a class based view")
            target = instance.get_object()
            if is_authorized(request, target, perm):
                return view_func(instance, request, *args, **kwargs)
            raise PermissionDenied()
        return _guarded
    return protect
164

    
165

    
166
def content_type_match(ctype):
    '''Tell whether *ctype* matches settings.LOGGED_CONTENT_TYPES_MESSAGES.

    Each configured entry is a regular expression applied with re.match,
    i.e. anchored at the start of *ctype*. An empty or missing *ctype* never
    matches.
    '''
    patterns = settings.LOGGED_CONTENT_TYPES_MESSAGES
    return bool(ctype) and any(re.match(pattern, ctype) for pattern in patterns)
174

    
175

    
176
def make_headers_safe(headers):
    '''Return a plain dict copy of *headers* whose keys and values are text.

    Some services send 8-bit encoded headers, so every key and value goes
    through force_text(..., errors='replace'). Undecodable bytes are then
    replaced instead of raising.
    '''
    safe_headers = {}
    for name, value in headers.items():
        safe_headers[force_text(name, errors='replace')] = force_text(value, errors='replace')
    return safe_headers
183

    
184

    
185
def _logged_max_size(logger, parameter, setting_name):
    '''Return the payload size limit for logging.

    The limit is the connector's logging parameter when *logger* is bound to
    a connector, and the global setting *setting_name* otherwise.
    '''
    if hasattr(logger, 'connector'):
        return getattr(logger.connector.logging_parameters, parameter)
    return getattr(settings, setting_name)


def log_http_request(logger, request, response=None, exception=None, error_log=True, extra=None):
    '''Log one outgoing HTTP exchange through *logger*.

    :param logger: logger-like object exposing ``level``, ``info``,
        ``warning`` and ``error``. When it has a ``connector`` attribute, the
        payload size limits come from ``connector.logging_parameters``;
        otherwise they come from settings.LOGGED_REQUESTS_MAX_SIZE and
        settings.LOGGED_RESPONSES_MAX_SIZE.
    :param request: the sent request (``method`` and ``url``, plus
        ``headers`` and ``body`` at DEBUG level), or None.
    :param response: the received response, if any.
    :param exception: the exception raised instead of a response, if any
        (ignored when *response* is given).
    :param error_log: when false, the record is always logged at info level.
    :param extra: additional fields for the log record. The caller's
        dictionary is copied, never modified.

    3xx responses are logged as warnings; 4xx/5xx responses and exceptions
    as errors; anything else as info. Headers and truncated bodies are only
    attached when ``logger.level`` is exactly DEBUG. Response bodies also
    require an allowed content type (content_type_match()).
    '''
    import logging

    # copy so that repeated calls never leak fields into the caller's dict
    extra = dict(extra or {})
    debug = logger.level == logging.DEBUG
    log_function = logger.info
    message = ''

    if request is not None:
        message = '%s %s' % (request.method, request.url)
        extra['request_url'] = request.url
        if debug:
            extra['request_headers'] = make_headers_safe(request.headers)
            if request.body:
                max_size = _logged_max_size(logger, 'requests_max_size', 'LOGGED_REQUESTS_MAX_SIZE')
                extra['request_payload'] = repr(request.body[:max_size])

    if response is not None:
        status = response.status_code
        message += ' (=> %s)' % status
        extra['response_status'] = status
        if debug:
            extra['response_headers'] = make_headers_safe(response.headers)
            # log body only if content type is allowed
            if content_type_match(response.headers.get('Content-Type')):
                max_size = _logged_max_size(logger, 'responses_max_size', 'LOGGED_RESPONSES_MAX_SIZE')
                extra['response_content'] = repr(response.content[:max_size])
        if status // 100 == 3:
            log_function = logger.warning
        elif status // 100 >= 4:
            log_function = logger.error
    elif exception:
        exception_repr = repr(exception)
        message = '%s (=> %s)' % (message, exception_repr) if message else exception_repr
        extra['response_exception'] = exception_repr
        log_function = logger.error

    # allow resources to disable any error log at requests level
    if not error_log:
        log_function = logger.info
    log_function(message, extra=extra)
230

    
231

    
232
# Wrapper around requests.Session adding logging, per-resource defaults and
# an optional cache for GET requests (see the class docstring).
class Request(RequestSession):
    '''requests.Session subclass used to talk to remote services.

    - Every request sent is logged with log_http_request() through the
      mandatory ``logger`` argument. If sending raises instead, the exception
      is logged.
    - When a ``resource`` is given, its optional attributes provide defaults
      for arguments the caller did not pass explicitly:

      * ``basic_auth_username`` / ``basic_auth_password``: HTTP Basic auth,
        used when the username is non-empty;
      * ``client_certificate`` (file field): client side certificate;
      * ``trusted_certificate_authorities`` (file field): CA bundle used to
        verify the server certificate. Without it, ``verify_cert`` (when
        present) is passed as ``verify``, so False disables verification;
      * ``http_proxy``: proxy used for both http and https.

    - settings.REQUESTS_PROXIES is used when no proxy was chosen, and
      settings.REQUESTS_TIMEOUT when no ``timeout`` was given.
    - GET requests accept ``cache_duration`` (a Django cache timeout) to
      cache their 2xx responses. ``invalidate_cache=True`` ignores any cached
      entry; a fresh 2xx response then replaces it.
    '''

    ADAPTER_REGISTRY = {}  # connection pooling: one HTTPAdapter per resource class

    def __init__(self, *args, **kwargs):
        self.logger = kwargs.pop('logger')
        self.resource = kwargs.pop('resource', None)
        super(Request, self).__init__(*args, **kwargs)
        if self.resource:
            resource_class = type(self.resource)
            adapter = Request.ADAPTER_REGISTRY.get(resource_class)
            if adapter is None:
                # build an adapter (and its pool manager) only when none is
                # registered; setdefault keeps one registered in the meantime
                adapter = Request.ADAPTER_REGISTRY.setdefault(resource_class, HTTPAdapter())
            self.mount('https://', adapter)
            self.mount('http://', adapter)

    def _apply_resource_defaults(self, kwargs):
        '''Fill auth/cert/verify/proxies in *kwargs* from self.resource unless already given.'''
        resource = self.resource
        if 'auth' not in kwargs:
            username = getattr(resource, 'basic_auth_username', None)
            if username and hasattr(resource, 'basic_auth_password'):
                kwargs['auth'] = (username, resource.basic_auth_password)
        if 'cert' not in kwargs:
            keystore = getattr(resource, 'client_certificate', None)
            if keystore:
                kwargs['cert'] = keystore.path
        if 'verify' not in kwargs:
            trusted_certificate_authorities = getattr(resource, 'trusted_certificate_authorities', None)
            if trusted_certificate_authorities:
                kwargs['verify'] = trusted_certificate_authorities.path
            elif hasattr(resource, 'verify_cert'):
                kwargs['verify'] = resource.verify_cert
        if 'proxies' not in kwargs:
            proxy = getattr(resource, 'http_proxy', None)
            if proxy:
                kwargs['proxies'] = {'http': proxy, 'https': proxy}

    def request(self, method, url, **kwargs):
        cache_duration = kwargs.pop('cache_duration', None)
        invalidate_cache = kwargs.pop('invalidate_cache', False)

        if self.resource:
            self._apply_resource_defaults(kwargs)

        # requests upper-cases the method itself, so 'get' is a GET too
        use_cache = bool(cache_duration) and method.upper() == 'GET'
        if use_cache:
            cache_key = hashlib.md5(force_bytes('%r;%r' % (url, kwargs))).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content and not invalidate_cache:
                response = RequestResponse()
                response.raw = BytesIO(cache_content.get('content'))
                response.headers = CaseInsensitiveDict(cache_content.get('headers', {}))
                response.status_code = cache_content.get('status_code')
                return response

        if settings.REQUESTS_PROXIES and 'proxies' not in kwargs:
            kwargs['proxies'] = settings.REQUESTS_PROXIES

        if 'timeout' not in kwargs:
            kwargs['timeout'] = settings.REQUESTS_TIMEOUT

        with warnings.catch_warnings():
            if kwargs.get('verify') is False:
                # verification was disabled on purpose: silence urllib3 warnings
                warnings.simplefilter(action='ignore', category=InsecureRequestWarning)
            response = super(Request, self).request(method, url, **kwargs)

        if use_cache and response.status_code // 100 == 2:
            cache.set(cache_key, {
                'content': response.content,
                'headers': response.headers,
                'status_code': response.status_code,
            }, cache_duration)

        return response

    def send(self, request, **kwargs):
        try:
            response = super(Request, self).send(request, **kwargs)
        except Exception as exc:
            self.log_http_request(request, exception=exc)
            raise
        self.log_http_request(request, response=response)
        return response

    def log_http_request(self, request, response=None, exception=None):
        # a resource may set log_requests_errors = False to log failures at info level
        error_log = getattr(self.resource, 'log_requests_errors', True)
        log_http_request(self.logger, request=request, response=response, exception=exception, error_log=error_log)
321

    
322

    
323

    
324
def export_site(slugs=None):
    '''Dump passerelle configuration (users, resources and ACLs) to a JSON-dumpable dictionary.

    :param slugs: optional collection of resource slugs. When non-empty, only
        resources whose slug is in it are exported. API users are always
        exported in full.

    Only the concrete direct subclasses of BaseResource are scanned. When
    export_json() raises NotImplementedError, the remaining instances of that
    resource class are skipped.
    '''
    from passerelle.base.models import ApiUser, BaseResource

    site = {'apiusers': [user.export_json() for user in ApiUser.objects.all()]}
    site['resources'] = exported = []
    for resource_class in BaseResource.__subclasses__():
        if resource_class._meta.abstract:
            continue
        for resource in resource_class.objects.all():
            if slugs and resource.slug not in slugs:
                continue
            try:
                exported.append(resource.export_json())
            except NotImplementedError:
                # this resource class has no export support: leave its other instances out
                break
    return site
343

    
344

    
345
def import_site(d, if_empty=False, clean=False, overwrite=False, import_users=False):
    '''Load passerelle configuration (users, resources and ACLs) from a dictionary loaded from
       JSON.

    :param d: dictionary shaped like export_site() output, with optional
        ``apiusers`` and ``resources`` lists.
    :param if_empty: do nothing unless the site has no resource (and, when
        *import_users* is set, no API user).
    :param clean: delete existing resources (and API users when
        *import_users* is set) before importing.
    :param overwrite: forwarded to the import_json() methods.
    :param import_users: also import (and clean) API users.

    Cleaning and importing run in one transaction, so a failing import rolls
    the deletions back instead of leaving the site emptied.
    '''
    from passerelle.base.models import ApiUser
    from passerelle.base.models import BaseResource

    def concrete_resource_classes():
        # only direct, non-abstract subclasses of BaseResource are handled
        return [cls for cls in BaseResource.__subclasses__() if not cls._meta.abstract]

    def is_empty():
        if import_users and ApiUser.objects.count():
            return False
        return not any(cls.objects.count() for cls in concrete_resource_classes())

    if if_empty and not is_empty():
        return

    with transaction.atomic():
        if clean:
            for resource_class in concrete_resource_classes():
                resource_class.objects.all().delete()
            if import_users:
                ApiUser.objects.all().delete()

        if import_users:
            for apiuser in d.get('apiusers', []):
                ApiUser.import_json(apiuser, overwrite=overwrite)

        for resource in d.get('resources', []):
            BaseResource.import_json(resource, overwrite=overwrite, import_users=import_users)
384

    
385

    
386
def batch(iterable, size):
    '''Split *iterable* into consecutive iterators of at most *size* items.

    Each yielded batch is a lazy iterator over the shared source. Consume it
    before asking for the next one; otherwise the batches overlap.
    '''
    source = iter(iterable)
    while True:
        chunk = islice(source, size)
        # pull the first item eagerly: this advances the source even when the
        # caller skips a batch, and detects exhaustion
        try:
            head = next(chunk)
        except StopIteration:
            return
        yield chain((head,), chunk)
399

    
400
# legacy import, other modules keep importing to_json from passerelle.utils
401
from .jsonresponse import to_json
402
from .soap import SOAPClient, SOAPTransport
403
from .sftp import SFTPField, SFTP