Joshua Lochner committed on
Commit
52340fc
1 Parent(s): 9604abd

Add `output_as_json` argument for inference

Browse files
Files changed (2) hide show
  1. src/evaluate.py +73 -48
  2. src/predict.py +4 -0
src/evaluate.py CHANGED
@@ -15,6 +15,9 @@ import os
15
  import random
16
  from shared import seconds_to_time
17
  from urllib.parse import quote
 
 
 
18
 
19
 
20
  @dataclass
@@ -140,8 +143,8 @@ def main():
140
  dataset_args.data_dir, dataset_args.processed_file)
141
 
142
  if not os.path.exists(final_path):
143
- print('ERROR: Processed database not found.',
144
- f'Run `python src/preprocess.py --update_database --do_process_database` to generate "{final_path}".')
145
  return
146
 
147
  model, tokenizer = get_model_tokenizer(
@@ -180,7 +183,7 @@ def main():
180
 
181
  sponsor_segments = final_data.get(video_id)
182
  if not sponsor_segments:
183
- print('No labels found for', video_id)
184
  continue
185
 
186
  words = get_words(video_id)
@@ -220,56 +223,78 @@ def main():
220
  incorrect_segments = [
221
  seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
222
 
 
 
 
 
 
223
  else:
224
  # Not in database (all segments missed)
225
  missed_segments = predictions
226
- incorrect_segments = None
227
 
228
  if missed_segments or incorrect_segments:
229
- print(f'Issues identified for {video_id} (#{video_index})')
230
- # Potentially missed segments (model predicted, but not in database)
231
- if missed_segments:
232
- print(' - Missed segments:')
233
- segments_to_submit = []
234
- for i, missed_segment in enumerate(missed_segments, start=1):
235
- print(f'\t#{i}:', seconds_to_time(
236
- missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
237
- print('\t\tText: "', ' '.join(
238
- [w['text'] for w in missed_segment['words']]), '"', sep='')
239
- print('\t\tCategory:',
240
- missed_segment.get('category'))
241
- if 'probability' in missed_segment:
242
- print('\t\tProbability:',
243
- missed_segment['probability'])
244
-
245
- segments_to_submit.append({
246
- 'segment': [missed_segment['start'], missed_segment['end']],
247
- 'category': missed_segment['category'].lower(),
248
- 'actionType': 'skip'
249
- })
250
-
251
- json_data = quote(json.dumps(segments_to_submit))
252
  print(
253
- f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
254
-
255
- # Potentially incorrect segments (model didn't predict, but in database)
256
- if incorrect_segments:
257
- print(' - Incorrect segments:')
258
- for i, incorrect_segment in enumerate(incorrect_segments, start=1):
259
- print(f'\t#{i}:', seconds_to_time(
260
- incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
261
-
262
- seg_words = extract_segment(
263
- words, incorrect_segment['start'], incorrect_segment['end'])
264
- print('\t\tText: "', ' '.join(
265
- [w['text'] for w in seg_words]), '"', sep='')
266
- print('\t\tUUID:', incorrect_segment['uuid'])
267
- print('\t\tCategory:',
268
- incorrect_segment['category'])
269
- print('\t\tVotes:', incorrect_segment['votes'])
270
- print('\t\tViews:', incorrect_segment['views'])
271
- print('\t\tLocked:', incorrect_segment['locked'])
272
- print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  except KeyboardInterrupt:
275
  pass
@@ -277,7 +302,7 @@ def main():
277
  df = pd.DataFrame(out_metrics)
278
 
279
  df.to_csv(evaluation_args.output_file)
280
- print(df.mean())
281
 
282
 
283
  if __name__ == '__main__':
 
15
  import random
16
  from shared import seconds_to_time
17
  from urllib.parse import quote
18
+ import logging
19
+
20
+ logger = logging.getLogger(__name__)
21
 
22
 
23
  @dataclass
 
143
  dataset_args.data_dir, dataset_args.processed_file)
144
 
145
  if not os.path.exists(final_path):
146
+ logger.error('ERROR: Processed database not found.',
147
+ f'Run `python src/preprocess.py --update_database --do_process_database` to generate "{final_path}".')
148
  return
149
 
150
  model, tokenizer = get_model_tokenizer(
 
183
 
184
  sponsor_segments = final_data.get(video_id)
185
  if not sponsor_segments:
186
+ logger.warning('No labels found for', video_id)
187
  continue
188
 
189
  words = get_words(video_id)
 
223
  incorrect_segments = [
224
  seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
225
 
226
+ # Add words to incorrect segments
227
+ for seg in incorrect_segments:
228
+ seg['words'] = extract_segment(
229
+ words, seg['start'], seg['end'])
230
+
231
  else:
232
  # Not in database (all segments missed)
233
  missed_segments = predictions
234
+ incorrect_segments = []
235
 
236
  if missed_segments or incorrect_segments:
237
+ if evaluation_args.output_as_json:
238
+ to_print = {'video_id': video_id}
239
+
240
+ for z in missed_segments + incorrect_segments:
241
+ z['text'] = ' '.join(x['text']
242
+ for x in z.pop('words', []))
243
+
244
+ if missed_segments:
245
+ to_print['missed'] = missed_segments
246
+
247
+ if incorrect_segments:
248
+ to_print['incorrect'] = incorrect_segments
249
+
250
+ print(json.dumps(to_print))
251
+ else:
 
 
 
 
 
 
 
 
252
  print(
253
+ f'Issues identified for {video_id} (#{video_index})')
254
+ # Potentially missed segments (model predicted, but not in database)
255
+ if missed_segments:
256
+ print(' - Missed segments:')
257
+ segments_to_submit = []
258
+ for i, missed_segment in enumerate(missed_segments, start=1):
259
+ print(f'\t#{i}:', seconds_to_time(
260
+ missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
261
+ print('\t\tText: "', ' '.join(
262
+ [w['text'] for w in missed_segment['words']]), '"', sep='')
263
+ print('\t\tCategory:',
264
+ missed_segment.get('category'))
265
+ if 'probability' in missed_segment:
266
+ print('\t\tProbability:',
267
+ missed_segment['probability'])
268
+
269
+ segments_to_submit.append({
270
+ 'segment': [missed_segment['start'], missed_segment['end']],
271
+ 'category': missed_segment['category'].lower(),
272
+ 'actionType': 'skip'
273
+ })
274
+
275
+ json_data = quote(json.dumps(segments_to_submit))
276
+ print(
277
+ f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
278
+
279
+ # Potentially incorrect segments (model didn't predict, but in database)
280
+ if incorrect_segments:
281
+ print(' - Incorrect segments:')
282
+ for i, incorrect_segment in enumerate(incorrect_segments, start=1):
283
+ print(f'\t#{i}:', seconds_to_time(
284
+ incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
285
+
286
+ seg_words = extract_segment(
287
+ words, incorrect_segment['start'], incorrect_segment['end'])
288
+ print('\t\tText: "', ' '.join(
289
+ [w['text'] for w in seg_words]), '"', sep='')
290
+ print('\t\tUUID:', incorrect_segment['uuid'])
291
+ print('\t\tCategory:',
292
+ incorrect_segment['category'])
293
+ print('\t\tVotes:', incorrect_segment['votes'])
294
+ print('\t\tViews:', incorrect_segment['views'])
295
+ print('\t\tLocked:',
296
+ incorrect_segment['locked'])
297
+ print()
298
 
299
  except KeyboardInterrupt:
300
  pass
 
302
  df = pd.DataFrame(out_metrics)
303
 
304
  df.to_csv(evaluation_args.output_file)
305
+ logger.info(df.mean())
306
 
307
 
308
  if __name__ == '__main__':
src/predict.py CHANGED
@@ -111,6 +111,9 @@ class InferenceArguments:
111
  }
112
  )
113
 
 
 
 
114
  def __post_init__(self):
115
  # Try to load model from latest checkpoint
116
  if self.model_path is None:
@@ -415,6 +418,7 @@ def main():
415
  print('No predictions found for', video_url, end='\n\n')
416
  continue
417
 
 
418
  print(len(predictions), 'predictions found for', video_url)
419
  for index, prediction in enumerate(predictions, start=1):
420
  print(f'Prediction #{index}:')
 
111
  }
112
  )
113
 
114
+ output_as_json: bool = field(default=False, metadata={
115
+ 'help': 'Output evaluations as JSON'})
116
+
117
  def __post_init__(self):
118
  # Try to load model from latest checkpoint
119
  if self.model_path is None:
 
418
  print('No predictions found for', video_url, end='\n\n')
419
  continue
420
 
421
+ # TODO use predict_args.output_as_json
422
  print(len(predictions), 'predictions found for', video_url)
423
  for index, prediction in enumerate(predictions, start=1):
424
  print(f'Prediction #{index}:')