From 2edc178948e43cf44361b1c7a41e0305f0a9c104 Mon Sep 17 00:00:00 2001 From: Serghei Mihai Date: Tue, 13 Nov 2018 14:05:58 +0100 Subject: [PATCH] csvdatasource: remove advanced lookup filters (#13748) --- passerelle/apps/csvdatasource/lookups.py | 70 ------------------------ passerelle/apps/csvdatasource/models.py | 39 +------------ passerelle/apps/csvdatasource/views.py | 49 ++++------------- tests/test_csv_datasource.py | 42 +------------- 4 files changed, 13 insertions(+), 187 deletions(-) delete mode 100644 passerelle/apps/csvdatasource/lookups.py diff --git a/passerelle/apps/csvdatasource/lookups.py b/passerelle/apps/csvdatasource/lookups.py deleted file mode 100644 index 667cdc07..00000000 --- a/passerelle/apps/csvdatasource/lookups.py +++ /dev/null @@ -1,70 +0,0 @@ -DELIMITER = '__' - -class InvalidOperatorError(Exception): - pass - -compare_str = cmp - - -def is_int(value): - try: - int(value) - return True - except (ValueError, TypeError): - return False - - -class Lookup(object): - - def contains(self, key, value): - return lambda x: value in x[key] - - def icontains(self, key, value): - return lambda x: value.lower() in x[key].lower() - - def gt(self, key, value): - return lambda x: int(x[key]) > int(value) - - def igt(self, key, value): - return lambda x: compare_str(x[key].lower(), value.lower()) > 0 - - def ge(self, key, value): - return lambda x: int(x[key]) >= int(value) - - def ige(self, key, value): - return lambda x: compare_str(x[key].lower(), value.lower()) >= 0 - - def lt(self, key, value): - return lambda x: int(x[key]) < int(value) - - def ilt(self, key, value): - return lambda x: compare_str(x[key].lower(), value.lower()) < 0 - - def le(self, key, value): - return lambda x: int(x[key]) <= int(value) - - def ile(self, key, value): - return lambda x: compare_str(x[key].lower(), value.lower()) <= 0 - - def eq(self, key, value): - if is_int(value): - return lambda x: int(value) == int(x[key]) - return lambda x: value == x[key] - - def ieq(self, key, value): - return lambda x: value.lower() == x[key].lower() - - def ne(self, key, value): - if is_int(value): - return lambda x: int(value) != int(x[key]) - return lambda x: value != x[key] - - def ine(self, key, value): - return lambda x: value.lower() != x[key].lower() - - -def get_lookup(operator, key, value): - try: - return getattr(Lookup(), operator)(key, value) - except (AttributeError,): - raise InvalidOperatorError('%s is not a valid operator' % operator) diff --git a/passerelle/apps/csvdatasource/models.py b/passerelle/apps/csvdatasource/models.py index 7b0d5eef..947aa76a 100644 --- a/passerelle/apps/csvdatasource/models.py +++ b/passerelle/apps/csvdatasource/models.py @@ -35,8 +35,6 @@ from passerelle.base.models import BaseResource from passerelle.utils.jsonresponse import APIError from passerelle.utils.api import endpoint -import lookups - identifier_re = re.compile(r"^[^\d\W]\w*\Z", re.UNICODE) @@ -242,41 +240,6 @@ class CsvDataSource(BaseResource): for data in self.get_cached_rows(initial=False): yield data - def get_data(self, filters=None): - titles = [t.strip() for t in self.columns_keynames.split(',')] - - # validate filters (appropriate columns must exist) - if filters: - for filter_key in filters.keys(): - if not filter_key.split(lookups.DELIMITER)[0] in titles: - del filters[filter_key] - - rows = self.get_cached_rows() - data = [] - - # build a generator of all filters - def filters_generator(filters, titles): - if not filters: - return - for key, value in filters.items(): - try: - key, op = key.split(lookups.DELIMITER) - except (ValueError,): - op = 'eq' - yield lookups.get_lookup(op, key, value) - - # apply filters to data - def super_filter(filters, data): - for f in filters: - data = itertools.ifilter(f, data) - return data - - data = list(super_filter( - filters_generator(filters, titles), rows - )) - - return data - @property def titles(self): return [smart_text(t.strip()) for t in self.columns_keynames.split(',')] @@ -288,7 +251,9 @@ class CsvDataSource(BaseResource): query = Query.objects.get(resource=self.id, slug=query_name) except Query.DoesNotExist: raise APIError(u'no such query') + return self.execute_query(request, query, **kwargs) + def execute_query(self, request, query, **kwargs): titles = self.titles data = self.get_cached_rows() diff --git a/passerelle/apps/csvdatasource/views.py b/passerelle/apps/csvdatasource/views.py index 1754c2d2..14c1584a 100644 --- a/passerelle/apps/csvdatasource/views.py +++ b/passerelle/apps/csvdatasource/views.py @@ -31,48 +31,19 @@ from .models import CsvDataSource, Query class CsvDataView(View, SingleObjectMixin): model = CsvDataSource - def _filters_builder(self, request): - filters = {} - obj = self.get_object() - - params = request.GET - - case_insensitive = 'case-insensitive' in params - query = params.get('q', None) - - if query: - if case_insensitive: - filters['text__icontains'] = query.lower() - else: - filters['text__contains'] = query - - # builds filters according to csv file header - for column_title in [t.strip() for t in obj.columns_keynames.split(',') if t]: - match = filter( - (lambda ct: lambda x: x.startswith(ct))(column_title), params.keys() - ) - for key in match: - if case_insensitive: - filters[key + '__ieq'] = params[key].lower() - else: - filters[key] = params[key] - - if 'text' in filters: - if case_insensitive: - filters['text__ieq'] = filters['text'].lower() - else: - filters['text__eq'] = filters['text'] - filters.pop('text') - - return filters - - @utils.protected_api('can_access') @utils.to_json() def get(self, request, *args, **kwargs): - obj = self.get_object() - filters = self._filters_builder(request) - return {'data': obj.get_data(filters)} + params = request.GET + filters = [] + for column_title in [t.strip() for t in self.get_object().columns_keynames.split(',') if t]: + if column_title in params.keys(): + if 'case-insensitive' in params: + filters.append("%s.lower() == query.get('%s', '').lower()" % (column_title, column_title)) + else: + filters.append("%s == query.get('%s')" % (column_title, column_title)) + query = Query(filters='\n'.join(filters)) + return self.get_object().execute_query(request, query, **params.dict()) class NewQueryView(CreateView): diff --git a/tests/test_csv_datasource.py b/tests/test_csv_datasource.py index 0d707df6..57e38672 100644 --- a/tests/test_csv_datasource.py +++ b/tests/test_csv_datasource.py @@ -155,13 +155,6 @@ def test_columns_keynames_with_spaces(client, setup, filetype): result = parse_response(resp) assert len(result) == 1 -def test_skipped_header_data(): - csv = CsvDataSource.objects.create(csv_file=File(StringIO(get_file_content('data.csv')), 'data.csv'), - columns_keynames=',id,,text,', - skip_header=True) - result = csv.get_data({'text': 'Eliot'}) - assert len(result) == 0 - def test_data(client, setup, filetype): csvdata, url = setup('fam,id,, text,sexe ', filename=filetype, data=get_file_content(filetype)) filters = {'text': 'Sacha'} @@ -234,39 +227,6 @@ def test_query_insensitive_and_filter(client, setup, filetype): assert result[0]['text'] == 'Eliot' assert len(result) == 1 -def test_advanced_filters(client, setup, filetype): - csvdata, url = setup(filename=filetype, data=get_file_content(filetype)) - filters = {'id__gt':20, 'id__lt': 40} - resp = client.get(url, filters) - result = parse_response(resp) - assert len(result) == 3 - for stuff in result: - assert stuff['id'] in ('22', '36', '38') - -def test_advanced_filters_combo(client, setup, filetype): - csvdata, url = setup(filename=filetype, data=get_file_content(filetype)) - filters = { - 'id__ge': '20', - 'id__lt': '40', - 'fam__gt': '234', - 'fam__le': '235', - 'fname__icontains': 'Sandra' - } - resp = client.get(url, filters) - result = parse_response(resp) - assert len(result) == 1 - assert result[0]['id'] == '22' - assert result[0]['lname'] == 'MARTIN' - -def test_unknown_operator(client, setup, filetype): - csvdata, url = setup(filename=filetype, data=get_file_content(filetype)) - filters = {'id__whatever': '25', 'fname__icontains':'Eliot'} - resp = client.get(url, filters) - result = json.loads(resp.content) - assert result['err'] == 1 - assert result['err_class'] == 'passerelle.apps.csvdatasource.lookups.InvalidOperatorError' - assert result['err_desc'] == 'whatever is not a valid operator' - def test_dialect(client, setup): csvdata, url = setup(data=data) @@ -280,7 +240,7 @@ def test_dialect(client, setup): } assert expected == csvdata.dialect_options - filters = {'id__gt': '20', 'id__lt': '40', 'fname__icontains': 'Sandra'} + filters = {'id': '22', 'fname': 'Sandra'} resp = client.get(url, filters) result = parse_response(resp) assert len(result) == 1 -- 2.19.1