Projet

Général

Profil

0001-arcgis-handler-parameters-in-where-clauses-39612.patch

Benjamin Dauvergne, 06 février 2020 18:37

Télécharger (4,52 ko)

Voir les différences:

Subject: [PATCH] arcgis: handle parameters in where clauses (#39612)

 passerelle/apps/arcgis/models.py | 34 ++++++++++++++++++++++++++++++++
 tests/test_arcgis.py             | 25 +++++++++++++++++++++++
 2 files changed, 59 insertions(+)
passerelle/apps/arcgis/models.py
14 14
# You should have received a copy of the GNU Affero General Public License
15 15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 16

  
17
import re
17 18
from django.db import models
18 19
from django.template import Template, Context
19 20
from django.utils.six.moves.urllib import parse as urlparse
......
24 25
from passerelle.base.models import BaseResource, HTTPResource
25 26

  
26 27

  
28
PRINTF_RE = re.compile('%.')


def format_where(fmt, args):
    '''Format a where clause by interpolating request parameters safely.

    Each ``%s`` placeholder is replaced by the matching argument quoted as an
    SQL string literal (embedded single quotes are doubled), each ``%d``
    placeholder by the matching argument converted to an integer, and ``%%``
    yields a literal ``%``. All other conversions are considered errors.

    :param fmt: where clause containing printf-style placeholders
    :param args: sequence of values, one per ``%s``/``%d`` placeholder
    :returns: the formatted where clause
    :raises APIError: (HTTP 400) if the number of placeholders differs from
        the number of arguments, if an unsupported conversion is used, if a
        ``%d`` argument cannot be converted to an integer, or if ``fmt`` is
        otherwise not a valid format string (e.g. a lone trailing ``%``).
    '''
    # '%.' matches a '%' and the character after it; '%%' is an escaped
    # literal percent sign and consumes no argument.
    parcels = [parcel for parcel in PRINTF_RE.findall(fmt) if parcel != '%%']

    if len(parcels) != len(args):
        raise APIError(
            'number of %%? != number of params in where clause:'
            ' %s != %s' % (len(parcels), len(args)),
            http_status=400)

    converted_args = []
    for parcel, arg in zip(parcels, args):
        if parcel == '%s':
            # SQL string literal: wrap in quotes, double embedded quotes
            converted_args.append('\'%s\'' % str(arg).replace('\'', '\'\''))
        elif parcel == '%d':
            # parameters come from the client: a bad value is a client error
            try:
                converted_args.append(int(arg))
            except (TypeError, ValueError):
                raise APIError(
                    'invalid integer parameter in where clause: %r' % (arg,),
                    http_status=400)
        else:
            raise APIError('invalid where string: %r' % fmt, http_status=400)
    try:
        return fmt % tuple(converted_args)
    except (TypeError, ValueError):
        # e.g. a lone trailing '%' which PRINTF_RE cannot match
        raise APIError('invalid where string: %r' % fmt, http_status=400)
53

  
54

  
27 55
class ArcGISError(APIError):
    '''APIError subclass for failures of the ArcGIS connector.

    NOTE(review): presumably raised when the remote ArcGIS service reports an
    error; confirm against its raise sites.
    '''
29 57

  
......
79 107
            url = urlparse.urljoin(url, folder + '/')
80 108
        url = urlparse.urljoin(url, service + '/MapServer/' + layer + '/query')
81 109

  
110
        # remove where-param from kwargs
111
        kwargs.pop('where-param', None)
112
        where_params = request.GET.getlist('where-param', [])
113

  
82 114
        # build query params
83 115
        # cf https://developers.arcgis.com/rest/services-reference/query-map-service-layer-.htm
84 116
        params = {
......
107 139
            params['text'] = q
108 140
        # consider all remaining parameters as ArcGIS ones
109 141
        params.update(kwargs)
142
        if 'where' in params:
143
            params['where'] = format_where(params['where'], where_params)
110 144
        if 'where' not in params and 'text' not in params:
111 145
            params['where'] = '1=1'
112 146
        if 'distance' in params and 'units' not in params:
tests/test_arcgis.py
214 214
        assert args['distance'] == '5'
215 215
        assert args['units'] == 'esriSRUnit_NauticalMile'
216 216

  
217
        # parametric where
218
        requests_get.reset_mock()
219
        params = {
220
            'folder': 'fold',
221
            'service': 'serv',
222
            'layer': '1',
223
            'full': 'on',
224
            'where': 'adresse LIKE %s AND numero < %d',
225
            'where-param': ['AVENUE D\'ANNAM', '10']
226
        }
227
        resp = app.get(endpoint, params=params, status=200)
228
        assert requests_get.call_count == 1
229
        assert requests_get.call_args[0][0] == 'https://arcgis.example.net/services/fold/serv/MapServer/1/query'
230
        args = requests_get.call_args[1]['params']
231
        assert args['where'] == 'adresse LIKE \'AVENUE D\'\'ANNAM\' AND numero < 10'
232
        del params['where-param']
233
        resp = app.get(endpoint, params=params, status=400)
234
        assert 'number of %' in resp.json['err_desc']
235
        params['where-param'] = ['1', '2', '3']
236
        resp = app.get(endpoint, params=params, status=400)
237
        assert 'number of %' in resp.json['err_desc']
238
        params['where'] = '%f %% %d %s'
239
        resp = app.get(endpoint, params=params, status=400)
240
        assert 'invalid where string' in resp.json['err_desc']
241

  
217 242
    # call errors
218 243
    with mock.patch('passerelle.utils.Request.get') as requests_get:
219 244
        requests_get.return_value = utils.FakedResponse(content=STATES,
220
-