Projet

Général

Profil

0001-add-a-custom-implementation-of-interval-sets.patch

Benjamin Dauvergne, 16 décembre 2017 05:40

Télécharger (11,3 ko)

Voir les différences:

Subject: [PATCH 1/2] add a custom implementation of interval sets

 chrono/interval.py     |  208 ++++++++++++++++++++++++++++++++++++++++++++++++
 tests/test_interval.py |  118 +++++++++++++++++++++++++++
 tox.ini                |    2 +
 3 files changed, 328 insertions(+)
 create mode 100644 chrono/interval.py
 create mode 100644 tests/test_interval.py
chrono/interval.py
1
# chrono - agendas system
2
# Copyright (C) 2016  Entr'ouvert
3
#
4
# This program is free software: you can redistribute it and/or modify it
5
# under the terms of the GNU Affero General Public License as published
6
# by the Free Software Foundation, either version 3 of the License, or
7
# (at your option) any later version.
8
#
9
# This program is distributed in the hope that it will be useful,
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
# GNU Affero General Public License for more details.
13
#
14
# You should have received a copy of the GNU Affero General Public License
15
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
16

  
17
import bisect
18

  
19

  
20
class Interval(object):
    """A range going from ``begin`` to ``end`` with an optional payload.

    Two ranges that merely touch at an endpoint are not considered to
    overlap (see overlap()).  Instances order by ``begin`` first, then by
    ``end``; only ``<`` is defined.
    """
    __slots__ = ['begin', 'end', 'data']

    def __init__(self, begin, end, data=None):
        """Store the bounds and payload; ``begin`` must be strictly before ``end``."""
        # kept as an assertion so the exception type seen by callers is unchanged
        assert begin < end, 'begin must be strictly lower than end'
        self.begin = begin
        self.end = end
        self.data = data

    def overlap(self, begin, end):
        """Return True if the range ``begin``..``end`` shares more than an endpoint with this interval."""
        return not (end <= self.begin or begin >= self.end)

    def intersect(self, interval):
        """Return True if ``interval`` overlaps this interval."""
        return self.overlap(interval.begin, interval.end)

    def __lt__(self, interval):
        """Order by ``begin``, then by ``end``."""
        if self.begin == interval.begin:
            return self.end < interval.end
        return self.begin < interval.begin

    def __repr__(self):
        # Only a missing payload is hidden; falsy payloads such as 0 are shown.
        shown = '' if self.data is None else self.data
        return '<Interval [%s, %s] %s>' % (self.begin, self.end, shown)
45

  
46

  
47
class Intervals(object):
    """Maintain a list of mostly non overlapping intervals, allow removing overlap.

    Every stored interval is indexed under both of its endpoints:
    ``points`` is the sorted list of distinct endpoints and ``container[i]``
    lists the intervals having ``points[i]`` as begin or as end.  Removing an
    interval empties its slots but never drops the point itself, so indexes
    stay stable while a removal is iterating.
    """

    def __init__(self, sequence=None):
        """Create the set, optionally pre-filled from an iterable of intervals."""
        self.points = []
        self.container = []
        # Largest ``end - begin`` seen so far (never lowered on removal); it
        # bounds how far before a queried range a stored interval can start.
        self._max_length = None
        # Set when a length cannot be computed; lookups then scan every point.
        self._scan_all = False

        if sequence:
            by_point = {}
            for interval in sequence:
                by_point.setdefault(interval.begin, []).append(interval)
                by_point.setdefault(interval.end, []).append(interval)
                self.__track_length(interval)
            # iterate over the sorted keys: dict.iteritems() is Python 2 only
            for point in sorted(by_point):
                self.points.append(point)
                self.container.append(by_point[point])

    def __track_length(self, interval):
        'Remember the longest interval so lookups know how far back to look'
        if self._scan_all:
            return
        try:
            length = interval.end - interval.begin
            if self._max_length is None or self._max_length < length:
                self._max_length = length
        except TypeError:
            # bounds that cannot be subtracted: fall back to full scans
            self._scan_all = True

    def __first_index(self, begin):
        'Index of the first point that may belong to an interval reaching begin'
        if self._scan_all:
            return 0
        start = begin
        if self._max_length is not None:
            try:
                # an interval overlapping ``begin`` starts after begin - max length
                start = begin - self._max_length
            except (TypeError, OverflowError):
                return 0
        return bisect.bisect_left(self.points, start)

    def __insert_point(self, point, interval):
        'Register interval under point, creating the slot if needed'
        i = bisect.bisect_left(self.points, point)
        if i >= len(self.container) or self.points[i] != point:
            self.points.insert(i, point)
            self.container.insert(i, [])
        self.container[i].append(interval)

    def add(self, begin, end, data=None):
        'Add an interval'
        self.add_interval(Interval(begin, end, data))

    def add_interval(self, interval):
        'Add an interval object'
        self.__insert_point(interval.begin, interval)
        self.__insert_point(interval.end, interval)
        self.__track_length(interval)

    def __iter_interval(self, begin, end, modify=False):
        """Yield ``(point, interval)`` pairs for the slots that may hold an
        interval overlapping ``begin``..``end``.

        With ``modify`` each slot is copied when it is reached, so the
        caller may remove intervals while iterating.
        """
        i = self.__first_index(begin)
        while i < len(self.points) and self.points[i] <= end:
            container = self.container[i]
            if modify:
                container = list(container)
            for interval in container:
                yield self.points[i], interval
            i += 1

    def remove_overlap(self, begin, end):
        'Remove all overlapping intervals'
        for point, interval in self.__iter_interval(begin, end, modify=True):
            if interval.overlap(begin, end):
                # also empties the other endpoint's slot, which is copied
                # lazily, so the interval is never met a second time
                self.__remove_interval(interval)

    def overlap(self, begin, end):
        'Test if some intervals overlap'
        return any(interval.overlap(begin, end)
                   for point, interval in self.__iter_interval(begin, end))

    def search(self, begin, end):
        'Search overlapping intervals'
        first = self.__first_index(begin)
        if first >= len(self.points):
            return
        lowest = self.points[first]
        for point, interval in self.__iter_interval(begin, end):
            if not interval.overlap(begin, end):
                continue
            # prevent returning the same interval two times: report it at its
            # begin point, or at its end point when the begin is not visited
            if interval.begin == point or interval.begin < lowest:
                yield interval

    def search_data(self, begin, end):
        'Search data elements of overlapping intervals'
        for interval in self.search(begin, end):
            yield interval.data

    def iter(self):
        'Iterate intervals'
        if not self.points:
            return []
        return self.search(self.points[0], self.points[-1])

    def iter_data(self):
        'Iterate data element attached to intervals'
        for interval in self.iter():
            yield interval.data

    def __remove_interval(self, interval):
        'Unregister interval from both of its endpoints'
        self.__remove_point_interval(interval.begin, interval)
        self.__remove_point_interval(interval.end, interval)

    def __remove_point_interval(self, point, interval):
        'Unregister interval from the slot of point; the point itself is kept'
        i = bisect.bisect_left(self.points, point)
        assert self.points[i] == point
        self.container[i].remove(interval)

    def iter_merge(self):
        'Returns the union of all contained intervals'
        begin = None
        end = None
        for point, container in zip(self.points, self.container):
            if begin is None and container:
                begin = point
            for interval in container:
                if end is None or end < interval.end:
                    end = interval.end
            assert end is None or (not end < point)
            # nothing reaches beyond this point: the current run is complete
            if end == point:
                yield Interval(begin, end)
                begin = None
                end = None

    def remove_overlap_non_overlapping_ordered_intervals(self, intervals):
        '''Remove overlapping intervals with a strictly
           ordered sequence of intervals such as returned
           by iter_merge().

           Stored intervals are visited once, at their begin point, while
           a single cursor advances over ``intervals``; each removed
           interval is unregistered from both of its endpoints.
        '''
        removals = iter(intervals)
        current = next(removals, None)
        for point, container in zip(self.points, self.container):
            # removals ending at or before this point cannot overlap an
            # interval starting here or later
            while current is not None and current.end <= point:
                current = next(removals, None)
            if current is None:
                break
            # ``current`` is the first removal ending after ``point``: if it
            # does not overlap an interval starting here, no later one does
            for stored in list(container):
                if stored.begin == point and stored.intersect(current):
                    self.__remove_interval(stored)

    @classmethod
    def from_ordered_intervals(cls, sequence):
        'Build a new Intervals from a strictly ordered sequence of intervals'
        by_point = {}
        points = []
        intervals = cls()
        last = None
        for interval in sequence:
            assert last is None or last < interval, (last, interval)
            for point in (interval.begin, interval.end):
                slot = by_point.get(point)
                if slot is None:
                    slot = by_point[point] = []
                    points.append(point)
                slot.append(interval)
            intervals.__track_length(interval)
            last = interval

        # An end may lie after the begin of a following interval, so the
        # points are not always met in increasing order; sorting an almost
        # sorted list is cheap.
        points.sort()
        intervals.points = points
        intervals.container = [by_point[point] for point in points]
        return intervals
208

  
tests/test_interval.py
1
from chrono.interval import Interval, Intervals
2

  
3

  
4
def test_interval_overlap():
    """Check Interval.overlap() for ranges before, across, inside and after [1, 4]."""
    a = Interval(1, 4)

    expectations = [
        ((0, 1), False),
        ((0, 2), True),
        ((1, 4), True),
        ((2, 3), True),
        ((3, 5), True),
        ((5, 6), False),
    ]
    for (begin, end), expected in expectations:
        assert bool(a.overlap(begin, end)) == expected
13

  
14
def test_intervals():
    """Exercise search, overlap, merge and removal on three runs of unit intervals."""
    intervals = Intervals()

    # unit intervals: data 1 over 0..10, data 2 over 10..20, data 3 over 5..15
    for start, stop, data in ((0, 10, 1), (10, 20, 2), (5, 15, 3)):
        for i in range(start, stop):
            intervals.add(i, i + 1, data)

    queries = [
        ((0, 5), 5, {1}),
        ((0, 10), 15, {1, 3}),
        ((5, 15), 20, {1, 2, 3}),
        ((10, 20), 15, {2, 3}),
        ((15, 20), 5, {2}),
    ]
    for (begin, end), count, _ in queries:
        assert len(list(intervals.search(begin, end))) == count
    for (begin, end), _, data in queries:
        assert set(intervals.search_data(begin, end)) == data

    for i in range(20):
        assert intervals.overlap(i, i + 1)

    assert len(list(intervals.iter_merge())) == 1
    merged = list(intervals.iter_merge())
    assert (merged[0].begin, merged[0].end) == (0, 20)

    intervals.remove_overlap(5, 15)

    assert set(intervals.search_data(0, 20)) == {1, 2}

    # only the intervals outside of 5..15 must remain
    for i in range(20):
        if 5 <= i < 15:
            assert not intervals.overlap(i, i + 1)
        else:
            assert intervals.overlap(i, i + 1)

    assert len(list(intervals.iter_merge())) == 2
    first, second = list(intervals.iter_merge())
    assert (first.begin, first.end) == (0, 5)
    assert (second.begin, second.end) == (15, 20)
64

  
65

  
66
# Number of iterations used by the interval generator and the benchmark loops.
COUNT = 10000
67

  
68

  
69
def g():
    """Yield, for every multiple of 3 below 3 * COUNT, a short (data 2) then a
    long (data 1) interval starting there."""
    for base in range(0, 3 * COUNT, 3):
        yield Interval(base, base + 1, 2)
        yield Interval(base, base + 3, 1)
73

  
74

  
75
def amortized():
    """Benchmark body: build the set from g(), then remove every overlap with
    the ordered intervals (3 * i + 2, 3 * i + 3) in a single pass."""
    # the module-level g() yields exactly what the former nested copy did
    intervals = Intervals.from_ordered_intervals(g())

    intervals.remove_overlap_non_overlapping_ordered_intervals(
        Interval(3 * i + 2, 3 * i + 3) for i in range(COUNT))
89

  
90

  
91
def amortized2():
    """Benchmark body: build the set from g(), then call remove_overlap() once
    per range (3 * i + 2, 3 * i + 3)."""
    intervals = Intervals.from_ordered_intervals(g())

    for base in range(0, 3 * COUNT, 3):
        intervals.remove_overlap(base + 2, base + 3)
97

  
98

  
99
def amortized3():
    """Benchmark body: same insertions and removals as amortized2(), run on
    intervaltree's IntervalTree."""
    from intervaltree import IntervalTree

    tree = IntervalTree()
    for base in range(0, 3 * COUNT, 3):
        tree.addi(base, base + 1, 2)
        tree.addi(base, base + 3, 1)

    for base in range(0, 3 * COUNT, 3):
        tree.remove_overlap(base + 2, base + 3)
109

  
110

  
111
def test_amortized(benchmark):
    """Run amortized() through the ``benchmark`` fixture."""
    benchmark(amortized)
113

  
114
def test_amortized2(benchmark):
    """Run amortized2() through the ``benchmark`` fixture."""
    benchmark(amortized2)
116

  
117
def test_amortized3(benchmark):
    """Run amortized3() through the ``benchmark`` fixture."""
    benchmark(amortized3)
tox.ini
15 15
  django111: django>=1.11,<1.12
16 16
  pytest-cov
17 17
  pytest-django
18
  pytest-benchmark
19
  intervaltree
18 20
  pytest>=3.3.0
19 21
  WebTest
20 22
  mock
21
-