Joshua Lochner committed
Commit 25f1183
1 Parent(s): 320a2ba

Use multiclass classifier to filter predictions

out/runs/Jan18_13-34-23_DESKTOP-I39NJG7/1642505668.7632372/events.out.tfevents.1642505668.DESKTOP-I39NJG7.27016.1 ADDED
Binary file (5.12 kB)
 
out/runs/Jan18_13-34-23_DESKTOP-I39NJG7/events.out.tfevents.1642505668.DESKTOP-I39NJG7.27016.0 ADDED
Binary file (3.51 kB)
 
src/predict.py CHANGED
@@ -1,3 +1,4 @@
+from shared import START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE
 from utils import re_findall
 from shared import OutputArguments
 from typing import Optional
@@ -25,6 +26,7 @@ import logging
 
 import re
 
+
 def seconds_to_time(seconds, remove_leading_zeroes=False):
     fractional = round(seconds % 1, 3)
     fractional = '' if fractional == 0 else str(fractional)[1:]
@@ -35,6 +37,7 @@ def seconds_to_time(seconds, remove_leading_zeroes=False):
     hms = re.sub(r'^0(?:0:0?)?', '', hms)
     return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
 
+
 @dataclass
 class TrainingOutputArguments:
 
@@ -68,13 +71,15 @@ class PredictArguments(TrainingOutputArguments):
     )
 
 
-SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SEGMENT.value})\s*_(?P<category>\S+)\s*(?P<text>.*?)\s*(?={CustomTokens.END_SEGMENT.value}|$)'
+_SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
+_SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
+SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
 
 MATCH_WINDOW = 25  # Increase for accuracy, but takes longer: O(n^3)
 MERGE_TIME_WITHIN = 8  # Merge predictions if they are within x seconds
 
 
-@dataclass
+@dataclass(frozen=True, eq=True)
 class ClassifierArguments:
     classifier_dir: Optional[str] = field(
         default='classifiers',
@@ -101,7 +106,7 @@ class ClassifierArguments:
         default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
 
 
-def filter_predictions(predictions, classifier_args):  # classifier, vectorizer,
+def add_predictions(predictions, classifier_args):  # classifier, vectorizer,
     """Use classifier to filter predictions"""
     if not predictions:
         return predictions
@@ -114,14 +119,34 @@ def filter_predictions(predictions, classifier_args): # classifier, vectorizer,
     ])
     probabilities = classifier.predict_proba(transformed_segments)
 
+    # Transformer sometimes says segment is of another category, so we
+    # update category and probabilities if classifier is confident it is another category
     filtered_predictions = []
-    for prediction, probability in zip(predictions, probabilities):
-        prediction['probability'] = probability[1]
+    for prediction, probabilities in zip(predictions, probabilities):
+        predicted_probabilities = {k: v for k,
+                                   v in zip(CATEGORIES, probabilities)}
+
+        # Get best category + probability
+        classifier_category = max(
+            predicted_probabilities, key=predicted_probabilities.get)
+        classifier_probability = predicted_probabilities[classifier_category]
+
+        if classifier_category is None and classifier_probability > classifier_args.min_probability:
+            continue  # Ignore
+
+        if classifier_category is not None and classifier_probability > 0.5:  # TODO make param
+            # Confident enough to overrule, so we update category
+            prediction['category'] = classifier_category
 
-        if prediction['probability'] >= classifier_args.min_probability:
-            filtered_predictions.append(prediction)
-        # else:
-        #     print('removing segment', prediction)
+        prediction['probability'] = predicted_probabilities[prediction['category']]
+
+        # TODO add probabilities, but remove None and normalise rest
+        prediction['probabilities'] = predicted_probabilities
+
+        # if prediction['probability'] < classifier_args.min_probability:
+        #     continue
+
+        filtered_predictions.append(prediction)
 
     return filtered_predictions
 
@@ -140,7 +165,6 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
     )
 
     predictions = segments_to_predictions(segments, model, tokenizer)
-
     # Add words back to time_ranges
     for prediction in predictions:
         # Stores words in the range
@@ -148,8 +172,8 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
             words, prediction['start'], prediction['end'])
 
     # TODO add back
-    # if classifier_args is not None:
-    #     predictions = filter_predictions(predictions, classifier_args)
+    if classifier_args is not None:
+        predictions = add_predictions(predictions, classifier_args)
 
     return predictions
 
@@ -171,6 +195,9 @@ def greedy_match(list, sublist):
     return best_i, best_j, best_k
 
 
+CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
+
+
 def predict_sponsor_text(text, model, tokenizer):
     """Given a body of text, predict the words which are part of the sponsor"""
     input_ids = tokenizer(
@@ -189,7 +216,7 @@ def predict_sponsor_matches(text, model, tokenizer):
     if CustomTokens.NO_SEGMENT.value in sponsorship_text:
         return []
 
-    return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
+    return re_findall(SEGMENT_MATCH_RE, sponsorship_text)
 
 
 def segments_to_predictions(segments, model, tokenizer):
@@ -237,12 +264,11 @@ def segments_to_predictions(segments, model, tokenizer):
         start_time = range['start']
         end_time = range['end']
 
-        if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
-                start_time <= prev_prediction['end'] <= end_time or \
-                start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN
-        ):
-            # Ending time of last segment is in this segment or within the merge threshold,
-            # so we extend last prediction range
+        if prev_prediction is not None and \
+                (start_time <= prev_prediction['end'] <= end_time or  # Merge overlapping segments
+                    (range['category'] == prev_prediction['category']  # Merge disconnected segments if same category and within threshold
+                        and start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN)):
+            # Extend last prediction range
             final_predicted_time_ranges[-1]['end'] = end_time
 
         else:  # No overlap, is a new prediction
@@ -279,7 +305,7 @@ def main():
 
     predict_args.video_id = predict_args.video_id.strip()
     predictions = predict(predict_args.video_id, model, tokenizer,
-                          segmentation_args)  # TODO add back , classifier_args=classifier_args
+                          segmentation_args, classifier_args=classifier_args)
 
     video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
     if not predictions:
@@ -292,7 +318,7 @@ def main():
         print('Text: "',
               ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
         print('Time:', seconds_to_time(
-            prediction['start']), '-->', seconds_to_time(prediction['end']))
+            prediction['start']), '\u2192', seconds_to_time(prediction['end']))
         print('Probability:', prediction.get('probability'))
         print('Category:', prediction.get('category'))
         print()
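
The heart of this commit is the new decision rule in add_predictions: instead of a binary keep/drop on probability[1], each prediction now gets a per-category probability distribution, may be reclassified, and is only dropped when the classifier is confident it is no segment at all. Below is a minimal, self-contained sketch of that rule; the name filter_and_reclassify and the hard-coded probability rows are illustrative stand-ins for the output of classifier.predict_proba, and OVERRULE_THRESHOLD is just a name for the 0.5 the commit leaves as "TODO make param".

    CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
    MIN_PROBABILITY = 0.5     # mirrors ClassifierArguments.min_probability
    OVERRULE_THRESHOLD = 0.5  # hard-coded 0.5 in the commit ("TODO make param")

    def filter_and_reclassify(predictions, probability_rows):
        filtered = []
        for prediction, row in zip(predictions, probability_rows):
            probs = dict(zip(CATEGORIES, row))
            best = max(probs, key=probs.get)

            if best is None and probs[best] > MIN_PROBABILITY:
                continue  # confidently "no segment": drop the prediction

            if best is not None and probs[best] > OVERRULE_THRESHOLD:
                prediction['category'] = best  # overrule the transformer

            prediction['probability'] = probs[prediction['category']]
            filtered.append(prediction)
        return filtered

    predictions = [
        {'category': 'SPONSOR', 'start': 12.0, 'end': 45.5},
        {'category': 'SPONSOR', 'start': 80.0, 'end': 95.0},
        {'category': 'INTERACTION', 'start': 200.0, 'end': 210.0},
    ]
    rows = [
        [0.05, 0.90, 0.03, 0.02],  # agrees: stays SPONSOR
        [0.10, 0.20, 0.65, 0.05],  # disagrees: reclassified to SELFPROMO
        [0.85, 0.05, 0.05, 0.05],  # confidently None: dropped
    ]
    for p in filter_and_reclassify(predictions, rows):
        print(p)

Note that a prediction is only removed when None wins the argmax, so low-confidence segments survive with a small probability attached; the commented-out min_probability check in the diff would tighten this.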
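
The rewritten segment regex is easier to see with concrete template values. The templates live in shared.py and are not shown in this diff, so the 'START_{}_TOKEN' / 'END_{}_TOKEN' strings below are assumptions for illustration only:

    import re

    # Assumed values; the real templates are defined in shared.py (not in this diff).
    START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'
    END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'

    _SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
    _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
    SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'

    text = ('check it out START_SPONSOR_TOKEN this video is sponsored by Example '
            'END_SPONSOR_TOKEN back to the video')
    for match in re.finditer(SEGMENT_MATCH_RE, text):
        print(match.group('category'), '->', match.group('text'))
    # prints: SPONSOR -> this video is sponsored by Example

Unlike the old SPONSOR_MATCH_RE, the category is captured as a named group from the start token itself (the end token is matched with a plain \w+, or end-of-string for unterminated segments), which is what lets predict_sponsor_matches return per-category matches for the new multiclass filtering.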