From bc3a052b2d996b1bf1b2c79c37d2577fbe7468a1 Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Sat, 16 Dec 2017 04:17:54 +0100 Subject: [PATCH 1/3] general: add custom implementation of interval sets (#20732) --- chrono/interval.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_interval.py | 62 ++++++++++++++++++++++++++ tox.ini | 1 + 3 files changed, 179 insertions(+) create mode 100644 chrono/interval.py create mode 100644 tests/test_interval.py diff --git a/chrono/interval.py b/chrono/interval.py new file mode 100644 index 0000000..92aa498 --- /dev/null +++ b/chrono/interval.py @@ -0,0 +1,116 @@ +# chrono - agendas system +# Copyright (C) 2017 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>.
import bisect


class Interval(object):
    """Half-open interval ``[begin, end)`` carrying an optional payload.

    ``begin`` and ``end`` may be any mutually comparable values (numbers,
    datetimes, ...). ``data`` is an opaque object attached to the interval.
    """
    __slots__ = ['begin', 'end', 'data']

    def __init__(self, begin, end, data=None):
        # Empty or reversed intervals are not supported.
        assert begin < end
        self.begin = begin
        self.end = end
        self.data = data

    def overlap(self, begin, end):
        """Return True if the range ``begin``..``end`` intersects this interval.

        Ranges that only touch at a boundary do not overlap.
        """
        return not (end <= self.begin or begin >= self.end)

    def __repr__(self):
        # A missing payload is rendered as an empty string; any other value
        # (including falsy ones such as 0) is shown as is.
        data = '' if self.data is None else self.data
        return '<Interval [%s, %s) %s>' % (self.begin, self.end, data)


class Intervals(object):
    """Maintain a list of mostly non overlapping intervals, allow removing overlap.

    ``points`` is the sorted list of distinct interval boundaries and
    ``container[i]`` lists the intervals that begin or end at ``points[i]``.
    """

    def __init__(self):
        self.points = []
        self.container = []
        # Longest interval length added so far. It is never reduced on
        # removal, so it stays a valid upper bound; it limits how far before
        # a query range an enclosing interval can start.
        self.__max_length = None
        # Set when a length could not be computed (boundaries that do not
        # support subtraction); queries then scan from the first point.
        self.__unbounded = False

    def __insert_point(self, point, interval):
        """Register ``interval`` at boundary ``point``, creating it if needed."""
        i = bisect.bisect_left(self.points, point)
        if i >= len(self.points) or self.points[i] != point:
            self.points.insert(i, point)
            self.container.insert(i, [])
        self.container[i].append(interval)

    def __track_length(self, interval):
        """Record the length of ``interval`` in the running maximum."""
        try:
            length = interval.end - interval.begin
        except TypeError:
            self.__unbounded = True
            return
        if self.__max_length is None or length > self.__max_length:
            self.__max_length = length

    def __scan_start(self, end):
        """Return the first point index where an interval reaching past ``end`` may begin."""
        if self.__unbounded or self.__max_length is None:
            return 0
        try:
            threshold = end - self.__max_length
        except (TypeError, OverflowError):
            return 0
        # An interval ending after ``end`` begins after ``end - max_length``.
        return bisect.bisect_left(self.points, threshold)

    def add(self, begin, end, data=None):
        """Add an interval."""
        self.add_interval(Interval(begin, end, data))

    def add_interval(self, interval):
        """Add an interval object."""
        self.__insert_point(interval.begin, interval)
        self.__insert_point(interval.end, interval)
        self.__track_length(interval)

    def __iter_interval(self, begin, end):
        """Yield ``(point, interval)`` candidates for the range ``begin``..``end``.

        An interval with both boundaries inside the range is yielded twice,
        once per boundary. Callers still have to filter with
        ``Interval.overlap``.
        """
        points = self.points
        first = bisect.bisect_left(points, begin)
        # Intervals strictly enclosing the range have no boundary inside it:
        # look for them among the intervals beginning before ``begin``.
        for i in range(self.__scan_start(end), first):
            point = points[i]
            for interval in self.container[i]:
                if interval.begin == point and interval.end > end:
                    yield point, interval
        i = first
        while i < len(points) and points[i] <= end:
            for interval in self.container[i]:
                yield points[i], interval
            i += 1

    def remove_overlap(self, begin, end):
        """Remove all overlapping intervals."""
        # Collect the matches first: removal modifies the lists being walked.
        for interval in list(self.search(begin, end)):
            self.__remove_interval(interval)

    def overlap(self, begin, end):
        """Test if some intervals overlap."""
        for point, interval in self.__iter_interval(begin, end):
            if interval.overlap(begin, end):
                return True
        return False

    def search(self, begin, end):
        """Search overlapping intervals."""
        for point, interval in self.__iter_interval(begin, end):
            if interval.overlap(begin, end):
                # prevent returning the same interval twice
                if interval.begin < begin or interval.begin == point:
                    yield interval

    def search_data(self, begin, end):
        """Search data elements of overlapping intervals."""
        for interval in self.search(begin, end):
            yield interval.data

    def iter(self):
        """Iterate intervals; always returns an iterator, even when empty."""
        if self.points:
            for interval in self.search(self.points[0], self.points[-1]):
                yield interval

    def iter_data(self):
        """Iterate data element attached to intervals."""
        for interval in self.iter():
            yield interval.data

    def __remove_interval(self, interval):
        """Unregister ``interval`` from both of its boundaries."""
        self.__remove_point_interval(interval.begin, interval)
        self.__remove_point_interval(interval.end, interval)

    def __remove_point_interval(self, point, interval):
        """Unregister ``interval`` at ``point`` and drop the point once unused."""
        i = bisect.bisect_left(self.points, point)
        assert i < len(self.points) and self.points[i] == point
        bucket = self.container[i]
        bucket.remove(interval)
        if not bucket:
            # Do not keep empty boundaries around.
            del self.points[i]
            del self.container[i]
5)) == {1} + assert set(intervals.search_data(0, 10)) == {1, 3} + assert set(intervals.search_data(5, 15)) == {1, 2, 3} + assert set(intervals.search_data(10, 20)) == {2, 3} + assert set(intervals.search_data(15, 20)) == {2} + + for i in range(20): + assert intervals.overlap(i, i + 1) + + intervals.remove_overlap(5, 15) + assert set(intervals.search_data(0, 20)) == {1, 2} + + for i in range(5): + assert intervals.overlap(i, i + 1) + for i in range(5, 15): + assert not intervals.overlap(i, i + 1) + for i in range(15, 20): + assert intervals.overlap(i, i + 1) diff --git a/tox.ini b/tox.ini index c4f4509..80f80ab 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,7 @@ deps = django111: django>=1.11,<1.12 pytest-cov pytest-django + intervaltree pytest>=3.3.0 WebTest mock -- 2.15.1