From 9c48014869822fff313e69e6551d074f4313dba8 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Thu, 6 Feb 2020 18:36:55 +0100 Subject: [PATCH] arcgis: handler parameters in where clauses (#39612) --- passerelle/apps/arcgis/models.py | 34 ++++++++++++++++++++++++++++++++ tests/test_arcgis.py | 25 +++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/passerelle/apps/arcgis/models.py b/passerelle/apps/arcgis/models.py index 6c1e72da..c24da9b4 100644 --- a/passerelle/apps/arcgis/models.py +++ b/passerelle/apps/arcgis/models.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import re from django.db import models from django.template import Template, Context from django.utils.six.moves.urllib import parse as urlparse @@ -24,6 +25,33 @@ from passerelle.utils.api import endpoint from passerelle.base.models import BaseResource, HTTPResource +PRINTF_RE = re.compile('%.') + + +def format_where(fmt, args): + '''Format where clause by applying SQL string quoting for %s interpolation + and converting string to integer for %s interpolation. All other conversions + are considered errors. + ''' + parcels = [parcel for parcel in PRINTF_RE.findall(fmt) if parcel != '%%'] + + if len(parcels) != len(args): + raise APIError( + 'number of %%? != number of params in where clause:' + ' %s != %s' % (len(parcels), len(args)), + http_status=400) + + converted_args = [] + for parcel, arg in zip(parcels, args): + if parcel == '%s': + converted_args.append('\'%s\'' % str(arg).replace('\'', '\'\'')) + elif parcel == '%d': + converted_args.append(int(arg)) + else: + raise APIError('invalid where string: %r' % fmt, http_status=400) + return fmt % tuple(converted_args) + + class ArcGISError(APIError): pass @@ -79,6 +107,10 @@ class ArcGIS(BaseResource, HTTPResource): url = urlparse.urljoin(url, folder + '/') url = urlparse.urljoin(url, service + '/MapServer/' + layer + '/query') + # remove where-param from kwargs + kwargs.pop('where-param', None) + where_params = request.GET.getlist('where-param', []) + # build query params # cf https://developers.arcgis.com/rest/services-reference/query-map-service-layer-.htm params = { @@ -107,6 +139,8 @@ class ArcGIS(BaseResource, HTTPResource): params['text'] = q # consider all remaining parameters as ArcGIS ones params.update(kwargs) + if 'where' in params: + params['where'] = format_where(params['where'], where_params) if 'where' not in params and 'text' not in params: params['where'] = '1=1' if 'distance' in params and 'units' not in params: diff --git a/tests/test_arcgis.py b/tests/test_arcgis.py index 3c197057..280ae691 100644 --- a/tests/test_arcgis.py +++ b/tests/test_arcgis.py @@ -214,6 +214,31 @@ def test_arcgis_mapservice_query(app, arcgis): assert args['distance'] == '5' assert args['units'] == 'esriSRUnit_NauticalMile' + # parametric where + requests_get.reset_mock() + params = { + 'folder': 'fold', + 'service': 'serv', + 'layer': '1', + 'full': 'on', + 'where': 'adresse LIKE %s AND numero < %d', + 'where-param': ['AVENUE D\'ANNAM', '10'] + } + resp = app.get(endpoint, params=params, status=200) + assert requests_get.call_count == 1 + assert requests_get.call_args[0][0] == 'https://arcgis.example.net/services/fold/serv/MapServer/1/query' + args = requests_get.call_args[1]['params'] + assert args['where'] == 'adresse LIKE \'AVENUE D\'\'ANNAM\' AND numero < 10' + del params['where-param'] + resp = app.get(endpoint, params=params, status=400) + assert 'number of %' in resp.json['err_desc'] + params['where-param'] = ['1', '2', '3'] + resp = app.get(endpoint, params=params, status=400) + assert 'number of %' in resp.json['err_desc'] + params['where'] = '%f %% %d %s' + resp = app.get(endpoint, params=params, status=400) + assert 'invalid where string' in resp.json['err_desc'] + # call errors with mock.patch('passerelle.utils.Request.get') as requests_get: requests_get.return_value = utils.FakedResponse(content=STATES, -- 2.24.0