Joshua Lochner committed
Commit df35612
1 Parent(s): 8fc746d

Add boilerplate code to detect whether segment was split due to length

Files changed (1)
  1. src/segment.py +9 -6
src/segment.py CHANGED
@@ -50,7 +50,6 @@ def word_end(word):
 
 
 def generate_segments(words, tokenizer, segmentation_args):
-    first_pass_segments = []
 
     cleaned_words_list = []
     for w in words:
@@ -61,6 +60,7 @@ def generate_segments(words, tokenizer, segmentation_args):
     num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
                                 truncation=True, return_attention_mask=False, return_length=True).length
 
+    first_pass_segments = []
     for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
         word['num_tokens'] = num_tokens
 
@@ -81,14 +81,14 @@ def generate_segments(words, tokenizer, segmentation_args):
     for segment in first_pass_segments:
         current_segment_num_tokens = 0
         current_segment = []
-
+        after_split_segments = []
         for word in segment:
             new_seg = current_segment_num_tokens + \
                 word['num_tokens'] >= max_q_size
             if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
-                second_pass_segments.append(current_segment)
+                after_split_segments.append(current_segment)
 
             # Add tokens to current segment
             current_segment.append(word)
@@ -106,10 +106,13 @@ def generate_segments(words, tokenizer, segmentation_args):
                 current_segment = current_segment[last_index:]
 
         if current_segment:  # Add remaining segment
-            second_pass_segments.append(current_segment)
+            after_split_segments.append(current_segment)
+
+        # TODO if len(after_split_segments) > 1, a split occurred
+
+        second_pass_segments.extend(after_split_segments)
 
     # Cleaning up, delete 'num_tokens' from each word
-    # for segment in second_pass_segments:
     for word in words:
         word.pop('num_tokens', None)
 
@@ -120,7 +123,7 @@ def extract_segment(words, start, end, map_function=None):
    """Extracts all words with time in [start, end]"""
 
     a = max(binary_search_below(words, 0, len(words), start), 0)
-    b = min(binary_search_above(words, -1, len(words) -1, end) + 1, len(words))
+    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
 
     to_transform = map_function is not None and callable(map_function)
 
 
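The TODO introduced above is the hook for the behaviour named in the commit title: if the second pass turns one first-pass segment into more than one chunk, that segment was split because it exceeded the token budget. Below is a minimal, self-contained sketch of that check, leaving out the sliding-window overlap handling around line 106 of the diff; the helper name split_segment, the was_split flag and the example data are illustrative and not taken from the repository.

def split_segment(segment, max_q_size):
    """Split one first-pass segment into chunks that respect the max_q_size token budget.

    Returns (chunks, was_split); was_split is True when the segment had to be
    broken up because of its length -- the condition the TODO refers to.
    """
    after_split_segments = []
    current_segment = []
    current_segment_num_tokens = 0

    for word in segment:
        if current_segment_num_tokens + word['num_tokens'] >= max_q_size:
            # Adding this word would exceed the budget: save the current
            # chunk and start a new one (no overlap in this simplified sketch).
            after_split_segments.append(current_segment)
            current_segment = []
            current_segment_num_tokens = 0

        current_segment.append(word)
        current_segment_num_tokens += word['num_tokens']

    if current_segment:  # Add remaining words
        after_split_segments.append(current_segment)

    # More than one chunk means a length-based split occurred.
    was_split = len(after_split_segments) > 1
    return after_split_segments, was_split

# Example with made-up word dicts:
words = [{'text': f'w{i}', 'num_tokens': 3} for i in range(10)]
chunks, was_split = split_segment(words, max_q_size=10)
print(len(chunks), was_split)  # 4 True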
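For context on the whitespace-only change in extract_segment, the two binary searches pick out the slice of words whose times fall inside [start, end]. The sketch below is a rough stand-in built on the standard-library bisect module; it assumes words are sorted by time and that each word dict carries 'start' and 'end' keys, and it only approximates what binary_search_below and binary_search_above do in the repository.

from bisect import bisect_left, bisect_right

def extract_segment_sketch(words, start, end):
    """Return all words whose time overlaps [start, end] (inclusive bounds)."""
    ends = [w['end'] for w in words]      # word end times, assumed sorted
    starts = [w['start'] for w in words]  # word start times, assumed sorted

    a = max(bisect_left(ends, start), 0)             # first word ending at or after start
    b = min(bisect_right(starts, end), len(words))   # one past the last word starting at or before end
    return words[a:b]

# Example with made-up timestamps:
words = [{'text': f'w{i}', 'start': i, 'end': i + 0.5} for i in range(10)]
print([w['text'] for w in extract_segment_sketch(words, 2, 5)])  # ['w2', 'w3', 'w4', 'w5']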