Joshua Lochner committed
Commit 52340fc • 1 Parent(s): 9604abd

Add `output_as_json` argument for inference

Files changed:
- src/evaluate.py (+73 -48)
- src/predict.py (+4 -0)
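
This commit adds an `output_as_json` option: `src/evaluate.py` can now print each video's issues as one JSON object per line instead of the human-readable report, and `src/predict.py` gains the corresponding `InferenceArguments` field (wiring it up there is left as a TODO).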
src/evaluate.py
CHANGED
@@ -15,6 +15,9 @@ import os
 import random
 from shared import seconds_to_time
 from urllib.parse import quote
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 @dataclass
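
One caveat with the new module-level logger: nothing in this commit configures logging, so Python's last-resort handler applies and only WARNING and above reach stderr, which would leave the `logger.info(df.mean())` call added at the end of `main()` silent. A minimal sketch of the missing setup (hypothetical placement, e.g. at the start of `main()`):

    import logging

    # Without this (or an equivalent handler), records below WARNING are
    # dropped by logging's last-resort handler; INFO never appears.
    logging.basicConfig(level=logging.INFO)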
@@ -140,8 +143,8 @@ def main():
         dataset_args.data_dir, dataset_args.processed_file)
 
     if not os.path.exists(final_path):
+        logger.error('ERROR: Processed database not found.',
+                     f'Run `python src/preprocess.py --update_database --do_process_database` to generate "{final_path}".')
         return
 
     model, tokenizer = get_model_tokenizer(
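
Note that `logging` calls do not join multiple positional arguments the way `print` does: extra arguments are treated as %-style format values, so the two-string `logger.error(...)` above (and the `logger.warning('No labels found for', video_id)` call below) will garble the log record. A sketch of the conventional single-message form, reusing the same `final_path` variable:

    # logging interpolates extra args into %s placeholders; it does not
    # concatenate them like print().
    logger.error(
        'Processed database not found. Run `python src/preprocess.py '
        '--update_database --do_process_database` to generate "%s".',
        final_path)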
@@ -180,7 +183,7 @@ def main():
 
                 sponsor_segments = final_data.get(video_id)
                 if not sponsor_segments:
+                    logger.warning('No labels found for', video_id)
                     continue
 
                 words = get_words(video_id)
@@ -220,56 +223,78 @@ def main():
                     incorrect_segments = [
                         seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
 
+                    # Add words to incorrect segments
+                    for seg in incorrect_segments:
+                        seg['words'] = extract_segment(
+                            words, seg['start'], seg['end'])
+
                 else:
                     # Not in database (all segments missed)
                     missed_segments = predictions
+                    incorrect_segments = []
 
                 if missed_segments or incorrect_segments:
+                    if evaluation_args.output_as_json:
+                        to_print = {'video_id': video_id}
+
+                        for z in missed_segments + incorrect_segments:
+                            z['text'] = ' '.join(x['text']
+                                                 for x in z.pop('words', []))
+
+                        if missed_segments:
+                            to_print['missed'] = missed_segments
+
+                        if incorrect_segments:
+                            to_print['incorrect'] = incorrect_segments
+
+                        print(json.dumps(to_print))
+                    else:
                         print(
+                            f'Issues identified for {video_id} (#{video_index})')
+                        # Potentially missed segments (model predicted, but not in database)
+                        if missed_segments:
+                            print(' - Missed segments:')
+                            segments_to_submit = []
+                            for i, missed_segment in enumerate(missed_segments, start=1):
+                                print(f'\t#{i}:', seconds_to_time(
+                                    missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
+                                print('\t\tText: "', ' '.join(
+                                    [w['text'] for w in missed_segment['words']]), '"', sep='')
+                                print('\t\tCategory:',
+                                      missed_segment.get('category'))
+                                if 'probability' in missed_segment:
+                                    print('\t\tProbability:',
+                                          missed_segment['probability'])
+
+                                segments_to_submit.append({
+                                    'segment': [missed_segment['start'], missed_segment['end']],
+                                    'category': missed_segment['category'].lower(),
+                                    'actionType': 'skip'
+                                })
+
+                            json_data = quote(json.dumps(segments_to_submit))
+                            print(
+                                f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
+
+                        # Potentially incorrect segments (model didn't predict, but in database)
+                        if incorrect_segments:
+                            print(' - Incorrect segments:')
+                            for i, incorrect_segment in enumerate(incorrect_segments, start=1):
+                                print(f'\t#{i}:', seconds_to_time(
+                                    incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
+
+                                seg_words = extract_segment(
+                                    words, incorrect_segment['start'], incorrect_segment['end'])
+                                print('\t\tText: "', ' '.join(
+                                    [w['text'] for w in seg_words]), '"', sep='')
+                                print('\t\tUUID:', incorrect_segment['uuid'])
+                                print('\t\tCategory:',
+                                      incorrect_segment['category'])
+                                print('\t\tVotes:', incorrect_segment['votes'])
+                                print('\t\tViews:', incorrect_segment['views'])
+                                print('\t\tLocked:',
+                                      incorrect_segment['locked'])
+                        print()
 
         except KeyboardInterrupt:
             pass
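
With `--output_as_json`, every video that has issues is emitted as a single JSON object per line, carrying a `video_id` key plus optional `missed` and `incorrect` lists, so the output can be piped straight into other tools. A minimal consumer sketch, assuming stdout was redirected to a hypothetical `results.jsonl`:

    import json

    # Tally issues per video from a JSON-lines dump of the evaluator.
    with open('results.jsonl') as f:
        for line in f:
            video = json.loads(line)
            print(video['video_id'],
                  '- missed:', len(video.get('missed', [])),
                  '- incorrect:', len(video.get('incorrect', [])))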
@@ -277,7 +302,7 @@ def main():
     df = pd.DataFrame(out_metrics)
 
     df.to_csv(evaluation_args.output_file)
+    logger.info(df.mean())
 
 
 if __name__ == '__main__':
src/predict.py
CHANGED
@@ -111,6 +111,9 @@ class InferenceArguments:
         }
     )
 
+    output_as_json: bool = field(default=False, metadata={
+        'help': 'Output evaluations as JSON'})
+
     def __post_init__(self):
        # Try to load model from latest checkpoint
        if self.model_path is None:
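
Since `output_as_json` is a dataclass `field` with `help` metadata, it surfaces on the command line as `--output_as_json` once the dataclass is handed to an argument parser. A sketch assuming (as this codebase's other scripts appear to do) that `transformers.HfArgumentParser` performs the parsing:

    from transformers import HfArgumentParser

    # Hypothetical parsing step; InferenceArguments is the dataclass
    # extended above in src/predict.py.
    (inference_args,) = HfArgumentParser(
        InferenceArguments).parse_args_into_dataclasses()
    print(inference_args.output_as_json)  # True when --output_as_json is passed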
@@ -415,6 +418,7 @@ def main():
             print('No predictions found for', video_url, end='\n\n')
             continue
 
+            # TODO use predict_args.output_as_json
             print(len(predictions), 'predictions found for', video_url)
             for index, prediction in enumerate(predictions, start=1):
                 print(f'Prediction #{index}:')