From 4746e8f10f26aba0652342e328d455c70da24ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20P=C3=A9ters?= Date: Sat, 23 Dec 2017 13:08:19 +0100 Subject: [PATCH] utils: add cache support to requests wrapper (#17192) --- passerelle/utils/__init__.py | 26 ++++++++++++++++++- tests/test_requests.py | 62 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/passerelle/utils/__init__.py b/passerelle/utils/__init__.py index 220a642..b836d4b 100644 --- a/passerelle/utils/__init__.py +++ b/passerelle/utils/__init__.py @@ -1,11 +1,15 @@ +from cStringIO import StringIO from functools import wraps +import hashlib import json import re import logging -from requests import Session as RequestSession +from requests import Session as RequestSession, Response as RequestResponse +from requests.structures import CaseInsensitiveDict from django.conf import settings +from django.core.cache import cache from django.core.exceptions import PermissionDenied, ObjectDoesNotExist from django.core.serializers.json import DjangoJSONEncoder from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest @@ -165,6 +169,8 @@ class Request(RequestSession): super(Request, self).__init__(*args, **kwargs) def request(self, method, url, **kwargs): + cache_duration = kwargs.pop('cache_duration', None) + invalidate_cache = kwargs.pop('invalidate_cache', False) params = kwargs.get('params', '') self.logger.info('%s %s %s' % (method, url, params), extra={'requests_url': url} @@ -192,11 +198,29 @@ class Request(RequestSession): if proxy: kwargs['proxies'] = {'http': proxy, 'https': proxy} + if method == 'GET' and cache_duration: + cache_key = hashlib.md5('%r;%r' % (url, kwargs)).hexdigest() + cache_content = cache.get(cache_key) + if cache_content and not invalidate_cache: + response = RequestResponse() + response.raw = StringIO(cache_content.get('content')) + response.headers = CaseInsensitiveDict(cache_content.get('headers', {})) + response.status_code = cache_content.get('status_code') + print 'cached response' + return response + if settings.REQUESTS_PROXIES and 'proxies' not in kwargs: kwargs['proxies'] = settings.REQUESTS_PROXIES response = super(Request, self).request(method, url, **kwargs) + if method == 'GET' and cache_duration and (response.status_code // 100 == 2): + cache.set(cache_key, { + 'content': response.content, + 'headers': response.headers, + 'statux_code': response.status_code, + }, cache_duration) + self.logger.debug('Request Headers: {}'.format(''.join([ '%s: %s | ' % (k,v) for k,v in response.request.headers.items() ]))) diff --git a/tests/test_requests.py b/tests/test_requests.py index 2009fbf..82780b1 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -6,7 +6,7 @@ from httmock import urlmatch, HTTMock, response from django.test import override_settings -from passerelle.utils import Request +from passerelle.utils import Request, CaseInsensitiveDict import utils from utils import FakedResponse @@ -217,3 +217,63 @@ def test_resource_certificates(mocked_get, caplog, endpoint_response): request.get('http://example.net/whatever', cert='/local.pem', verify=False) assert mocked_get.call_args[1].get('cert') == '/local.pem' assert mocked_get.call_args[1].get('verify') is False + +@mock.patch('passerelle.utils.RequestSession.request') +def test_requests_cache(mocked_get, caplog): + resource = MockResource() + logger = logging.getLogger('requests') + request = Request(resource=resource, logger=logger) + + response_request = mock.Mock(headers={'Accept': '*/*'}, body=None) + mocked_get.return_value = FakedResponse( + headers={'Content-Type': 'text/plain; charset=charset=utf-8'}, + request=response_request, + content='hello world', status_code=200) + + # by default there is no cache + assert request.get('http://cache.example.org/').content == 'hello world' + assert request.get('http://cache.example.org/').content == 'hello world' + assert mocked_get.call_count == 2 + + # add some cache + mocked_get.reset_mock() + assert request.get('http://cache.example.org/', cache_duration=15).content == 'hello world' + assert mocked_get.call_count == 1 + assert request.get('http://cache.example.org/', cache_duration=15).content == 'hello world' + assert mocked_get.call_count == 1 # got a cached response + + # value changed + mocked_get.return_value = FakedResponse( + headers={'Content-Type': 'text/plain; charset=charset=utf-8'}, + request=response_request, + content='hello second world', status_code=200) + assert request.get('http://cache.example.org/', cache_duration=15).content == 'hello world' + assert mocked_get.call_count == 1 + + # force cache invalidation + assert request.get('http://cache.example.org/', invalidate_cache=True).content == 'hello second world' + assert mocked_get.call_count == 2 + + # do not cache errors + mocked_get.return_value = FakedResponse( + headers={'Content-Type': 'text/plain; charset=charset=utf-8'}, + request=response_request, + content='no such world', status_code=404) + mocked_get.reset_mock() + response = request.get('http://cache.example.org/404', cache_duration=15) + assert response.content == 'no such world' + assert response.status_code == 404 + assert mocked_get.call_count == 1 + response = request.get('http://cache.example.org/404', cache_duration=15) + assert mocked_get.call_count == 2 + + # check response headers + mocked_get.reset_mock() + mocked_get.return_value = FakedResponse( + headers=CaseInsensitiveDict({'Content-Type': 'image/png'}), + request=response_request, + content='hello world', status_code=200) + assert request.get('http://cache.example.org/img', cache_duration=15).headers.get('content-type') == 'image/png' + assert mocked_get.call_count == 1 + assert request.get('http://cache.example.org/img', cache_duration=15).headers.get('content-type') == 'image/png' + assert mocked_get.call_count == 1 # got a cached response -- 2.15.1