# Copyright (C) 2019 Entr'ouvert # # This program is free software: you can redistribute it and/or modify it # under the terms of the GNU Affero General Public License as published # by the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from __future__ import absolute_import import base64 from functools import wraps import hashlib import re from itertools import islice, chain import warnings from requests import Session as RequestSession, Response as RequestResponse from requests.adapters import HTTPAdapter from requests.structures import CaseInsensitiveDict from urllib3.exceptions import InsecureRequestWarning from django.conf import settings from django.core.cache import cache from django.core.exceptions import PermissionDenied from django.http import HttpResponse, HttpResponseBadRequest from django.template import Template, Context from django.utils.encoding import force_bytes, force_text from django.utils.functional import lazy from django.utils.html import mark_safe from io import BytesIO from django.views.generic.detail import SingleObjectMixin from django.contrib.contenttypes.models import ContentType from django.db import transaction from django.utils.decorators import available_attrs from passerelle.base.signature import check_query, check_url mark_safe_lazy = lazy(mark_safe, str) def response_for_json(request, data): import json response = HttpResponse(content_type='application/json') json_str = json.dumps(data) for variable in ('jsonpCallback', 'callback'): if variable in request.GET: identifier = request.GET[variable] if not 
re.match(r'^[$A-Za-z_][0-9A-Za-z_$]*$', identifier): return HttpResponseBadRequest('invalid JSONP callback name') json_str = '%s(%s);' % (identifier, json_str) response['Content-Type'] = 'application/javascript' break response.write(json_str) return response def get_request_users(request): from passerelle.base.models import ApiUser users = [] users.extend(ApiUser.objects.filter(keytype='')) if 'orig' in request.GET and 'signature' in request.GET: orig = request.GET['orig'] query = request.META['QUERY_STRING'] signature_users = ApiUser.objects.filter(keytype='SIGN', username=orig) for signature_user in signature_users: if check_query(query, signature_user.key): users.append(signature_user) elif 'apikey' in request.GET: users.extend(ApiUser.objects.filter(keytype='API', key=request.GET['apikey'])) elif 'HTTP_AUTHORIZATION' in request.META: http_authorization = request.META['HTTP_AUTHORIZATION'].split(' ', 1) scheme = http_authorization[0].lower() if scheme == 'basic' and len(http_authorization) > 1: param = http_authorization[1] try: decoded = force_text(base64.b64decode(force_bytes(param.strip()))) username, password = decoded.split(':', 1) except (TypeError, ValueError): pass else: users.extend(ApiUser.objects.filter(keytype='SIGN', username=username, key=password)) def ip_match(ip, match): if not ip: return True if ip == match: return True return False users = [x for x in users if ip_match(x.ipsource, request.META.get('REMOTE_ADDR'))] return users def get_trusted_services(): ''' All services in settings.KNOWN_SERVICES are "trusted" ''' trusted_services = [] for service_type in getattr(settings, 'KNOWN_SERVICES', {}): for slug, service in settings.KNOWN_SERVICES[service_type].items(): if service.get('secret') and service.get('verif_orig'): trusted_service = service.copy() trusted_service['service_type'] = service_type trusted_service['slug'] = slug trusted_services.append(trusted_service) return trusted_services def is_trusted(request): ''' True if query-string is 
signed by a trusted service (see get_trusted_services() above) ''' if not request.GET.get('orig') or not request.GET.get('signature'): return False full_path = request.get_full_path() for service in get_trusted_services(): if (service.get('verif_orig') == request.GET['orig'] and service.get('secret') and check_url(full_path, service['secret'])): return True return False def is_authorized(request, obj, perm): from passerelle.base.models import AccessRight if request.user.is_superuser: return True if is_trusted(request): return True resource_type = ContentType.objects.get_for_model(obj) rights = AccessRight.objects.filter(resource_type=resource_type, resource_pk=obj.id, codename=perm) users = [x.apiuser for x in rights] return set(users).intersection(get_request_users(request)) def protected_api(perm): def decorator(view_func): @wraps(view_func, assigned=available_attrs(view_func)) def _wrapped_view(instance, request, *args, **kwargs): if not isinstance(instance, SingleObjectMixin): raise Exception("protected_api must be applied on a method of a class based view") obj = instance.get_object() if not is_authorized(request, obj, perm): raise PermissionDenied() return view_func(instance, request, *args, **kwargs) return _wrapped_view return decorator def content_type_match(ctype): content_types = settings.LOGGED_CONTENT_TYPES_MESSAGES if not ctype: return False for content_type in content_types: if re.match(content_type, ctype): return True return False def make_headers_safe(headers): '''Convert dict of HTTP headers to text safely, as some services returns 8-bits encoding in headers. 
''' return { force_text(key, errors='replace'): force_text(value, errors='replace') for key, value in headers.items() } def log_http_request(logger, request, response=None, exception=None, error_log=True, extra=None): log_function = logger.info message = '' extra = extra or {} if request is not None: message = '%s %s' % (request.method, request.url) extra['request_url'] = request.url if logger.level == 10 and request: # DEBUG extra['request_headers'] = make_headers_safe(request.headers) if request.body: if hasattr(logger, 'connector'): max_size = logger.connector.logging_parameters.requests_max_size else: max_size = settings.LOGGED_REQUESTS_MAX_SIZE extra['request_payload'] = repr(request.body[:max_size]) if response is not None: message = message + ' (=> %s)' % response.status_code extra['response_status'] = response.status_code if logger.level == 10: # DEBUG extra['response_headers'] = make_headers_safe(response.headers) # log body only if content type is allowed if content_type_match(response.headers.get('Content-Type')): if hasattr(logger, 'connector'): max_size = logger.connector.logging_parameters.responses_max_size else: max_size = settings.LOGGED_RESPONSES_MAX_SIZE content = response.content[:max_size] extra['response_content'] = repr(content) if response.status_code // 100 == 3: log_function = logger.warning elif response.status_code // 100 >= 4: log_function = logger.error elif exception: if message: message = message + ' (=> %s)' % repr(exception) else: message = repr(exception) extra['response_exception'] = repr(exception) log_function = logger.error # allow resources to disable any error log at requests level if not error_log: log_function = logger.info log_function(message, extra=extra) # Wrapper around requests.Session # - log input and output data # - use HTTP Basic auth if resource.basic_auth_username and resource.basic_auth_password exist # - use client side certificate if resource.client_certificate (FileField) exists # - verify server 
#   certificate CA if resource.trusted_certificate_authorities (FileField) exists
# - otherwise pass resource.verify_cert (BooleanField), if it exists, as `verify`
#   (False disables CA verification)
# - use a proxy for HTTP and HTTPS if resource.http_proxy exists
class Request(RequestSession):
    '''requests.Session configured from an optional passerelle resource.

    Constructor keyword arguments: ``logger`` (required, used to log every
    sent request) and ``resource`` (optional).  ``request()`` additionally
    accepts ``cache_duration`` (cache successful GET responses in the Django
    cache for that many seconds) and ``invalidate_cache`` (ignore any cached
    entry and refetch).
    '''

    ADAPTER_REGISTRY = {}  # connection pooling, one adapter per resource class

    def __init__(self, *args, **kwargs):
        self.logger = kwargs.pop('logger')
        self.resource = kwargs.pop('resource', None)
        super(Request, self).__init__(*args, **kwargs)
        if self.resource:
            resource_type = type(self.resource)
            adapter = Request.ADAPTER_REGISTRY.get(resource_type)
            if adapter is None:
                # Build an adapter only when none is registered yet;
                # setdefault() still guarantees a single registered adapter
                # if two sessions race to create it.
                adapter = Request.ADAPTER_REGISTRY.setdefault(resource_type, HTTPAdapter())
            self.mount('https://', adapter)
            self.mount('http://', adapter)

    def request(self, method, url, **kwargs):
        cache_duration = kwargs.pop('cache_duration', None)
        invalidate_cache = kwargs.pop('invalidate_cache', False)

        # explicit keyword arguments always win over resource settings
        if self.resource:
            if 'auth' not in kwargs:
                username = getattr(self.resource, 'basic_auth_username', None)
                if username and hasattr(self.resource, 'basic_auth_password'):
                    kwargs['auth'] = (username, self.resource.basic_auth_password)
            if 'cert' not in kwargs:
                keystore = getattr(self.resource, 'client_certificate', None)
                if keystore:
                    kwargs['cert'] = keystore.path
            if 'verify' not in kwargs:
                trusted_certificate_authorities = getattr(self.resource, 'trusted_certificate_authorities', None)
                if trusted_certificate_authorities:
                    kwargs['verify'] = trusted_certificate_authorities.path
                elif hasattr(self.resource, 'verify_cert'):
                    kwargs['verify'] = self.resource.verify_cert
            if 'proxies' not in kwargs:
                proxy = getattr(self.resource, 'http_proxy', None)
                if proxy:
                    kwargs['proxies'] = {'http': proxy, 'https': proxy}

        if method == 'GET' and cache_duration:
            # the key covers the URL and every request option set so far
            cache_key = hashlib.md5(force_bytes('%r;%r' % (url, kwargs))).hexdigest()
            cache_content = cache.get(cache_key)
            if cache_content and not invalidate_cache:
                # rebuild a minimal Response from the cached parts
                response = RequestResponse()
                response.raw = BytesIO(cache_content.get('content'))
                response.headers = CaseInsensitiveDict(cache_content.get('headers', {}))
                response.status_code = cache_content.get('status_code')
                return response

        if settings.REQUESTS_PROXIES and 'proxies' not in kwargs:
            kwargs['proxies'] = settings.REQUESTS_PROXIES

        if 'timeout' not in kwargs:
            kwargs['timeout'] = settings.REQUESTS_TIMEOUT

        with warnings.catch_warnings():
            if kwargs.get('verify') is False:
                # disable urllib3 warnings
                warnings.simplefilter(action='ignore', category=InsecureRequestWarning)
            response = super(Request, self).request(method, url, **kwargs)

        if method == 'GET' and cache_duration and (response.status_code // 100 == 2):
            cache.set(cache_key, {
                'content': response.content,
                'headers': response.headers,
                'status_code': response.status_code,
            }, cache_duration)

        return response

    def send(self, request, **kwargs):
        '''Send *request*, logging it with its response or the raised exception.'''
        try:
            response = super(Request, self).send(request, **kwargs)
        except Exception as exc:
            self.log_http_request(request, exception=exc)
            raise
        self.log_http_request(request, response=response)
        return response

    def log_http_request(self, request, response=None, exception=None):
        # resources may set log_requests_errors = False to log errors at info level
        error_log = getattr(self.resource, 'log_requests_errors', True)
        log_http_request(self.logger, request=request, response=response, exception=exception, error_log=error_log)


def export_site(slugs=None):
    '''Dump passerelle configuration (users, resources and ACLs) to a JSON
    dumpable dictionary.

    If *slugs* is given, only resources whose slug is in it are exported
    (API users are always exported).  For a resource class whose
    ``export_json()`` raises NotImplementedError, the remaining resources of
    that class are skipped.
    '''
    from passerelle.base.models import ApiUser
    from passerelle.base.models import BaseResource

    d = {}
    d['apiusers'] = [apiuser.export_json() for apiuser in ApiUser.objects.all()]
    d['resources'] = resources = []
    for subclass in BaseResource.__subclasses__():
        if subclass._meta.abstract:
            continue
        for resource in subclass.objects.all():
            if slugs and resource.slug not in slugs:
                continue
            try:
                resources.append(resource.export_json())
            except NotImplementedError:
                break
    return d


def import_site(d, if_empty=False, clean=False, overwrite=False, import_users=False):
    '''Load passerelle configuration (users, resources and ACLs) from a
    dictionary loaded from JSON.

    - *if_empty*: do nothing unless there are no resources (and, when
      *import_users* is set, no API users);
    - *clean*: delete existing resources (and API users when *import_users*
      is set) before importing; the deletion happens in the same transaction
      as the import, so a failing import leaves existing data in place;
    - *overwrite* and *import_users* are passed to the models' import_json().
    '''
    from passerelle.base.models import ApiUser
    from passerelle.base.models import BaseResource

    d = d.copy()

    def is_empty():
        if import_users:
            if ApiUser.objects.count():
                return False

        for subclass in BaseResource.__subclasses__():
            if subclass._meta.abstract:
                continue
            if subclass.objects.count():
                return False
        return True

    if if_empty and not is_empty():
        return

    with transaction.atomic():
        if clean:
            for subclass in BaseResource.__subclasses__():
                if subclass._meta.abstract:
                    continue
                subclass.objects.all().delete()
            if import_users:
                ApiUser.objects.all().delete()

        if import_users:
            for apiuser in d.get('apiusers', []):
                ApiUser.import_json(apiuser, overwrite=overwrite)

        for resource in d.get('resources', []):
            BaseResource.import_json(resource, overwrite=overwrite, import_users=import_users)


def batch(iterable, size):
    '''Batch an iterable into an iterable of iterables, each at most *size*
    elements long.

    The batches share the underlying iterator, so each one should be
    consumed before the next is requested.
    '''
    sourceiter = iter(iterable)
    while True:
        batchiter = islice(sourceiter, size)
        # call next() at least one time to advance, if the caller does not
        # consume the returned iterators, sourceiter will never be exhausted.
        try:
            yield chain([next(batchiter)], batchiter)
        except StopIteration:
            return


# legacy import, other modules keep importing to_json from passerelle.utils
from .jsonresponse import to_json
from .soap import SOAPClient, SOAPTransport
from .sftp import SFTPField, SFTP