Joshua Lochner commited on
Commit
915339e
1 Parent(s): b309907

Change `extract_segment` to use a binary search

Browse files
Files changed (1) hide show
  1. src/segment.py +21 -15
src/segment.py CHANGED
@@ -127,21 +127,27 @@ def generate_segments(words, tokenizer, segmentation_args):
127
 
128
 
129
  def extract_segment(words, start, end, map_function=None):
130
- """Extract a segment of words that are between (inclusive) the start and end points"""
131
- segment_words = []
132
-
133
- if start > end:
134
- return segment_words
135
 
136
- # TODO change to binary search
137
- for w in words: # Assumes words are sorted
138
- if word_end(w) < start:
139
- continue # Ignore
140
- if word_start(w) > end:
141
- break # Done with range
142
- if map_function is not None and callable(map_function):
143
- w = map_function(w)
 
 
 
 
144
 
145
- segment_words.append(w)
146
 
147
- return segment_words
 
 
 
 
127
 
128
 
129
  def extract_segment(words, start, end, map_function=None):
130
+ """Extracts all words with time in [start, end]"""
131
+
132
+ a = binary_search(words, 0, len(words), start, True)
133
+ b = min(binary_search(words, 0, len(words), end , False) + 1, len(words))
 
134
 
135
+ to_transform = map_function is not None and callable(map_function)
136
+
137
+ return [
138
+ map_function(words[i]) if to_transform else words[i] for i in range(a, b)
139
+ ]
140
+
141
+ # Binary search to get first index of word whose start/end time is greater/less than some value
142
+ def binary_search(words, start_index, end_index, time, below):
143
+ if start_index >= end_index:
144
+ return end_index
145
+
146
+ middle_index = (start_index + end_index ) // 2
147
 
148
+ middle_time = word_start(words[middle_index]) if below else word_end(words[middle_index])
149
 
150
+ if time <= middle_time:
151
+ return binary_search(words, start_index, middle_index, time, below)
152
+ else:
153
+ return binary_search(words, middle_index + 1, end_index, time, below)