Joshua Lochner committed on
Commit 90d1f68
1 Parent(s): 4678c9b

Add functionality to predict self-promo and interaction reminders

Files changed (7)
  1. src/evaluate.py +2 -4
  2. src/predict.py +31 -24
  3. src/preprocess.py +111 -55
  4. src/segment.py +3 -3
  5. src/shared.py +5 -6
  6. src/train.py +15 -15
  7. src/utils.py +6 -0
src/evaluate.py CHANGED
@@ -105,13 +105,13 @@ def calculate_metrics(labelled_words, predictions):
 
         if predicted_sponsor:
             # total_positive_time += duration
-            if word['sponsor']:  # Is actual sponsor
+            if word['category'] is not None:  # Is actual sponsor
                 metrics['true_positive'] += duration
             else:
                 metrics['false_positive'] += duration
         else:
             # total_negative_time += duration
-            if word['sponsor']:  # Is actual sponsor
+            if word['category'] is not None:  # Is actual sponsor
                 metrics['false_negative'] += duration
             else:
                 metrics['true_negative'] += duration
@@ -176,8 +176,6 @@ def main():
     with open(final_path) as fp:
         final_data = json.load(fp)
 
-    classifier, vectorizer = get_classifier_vectorizer(classifier_args)
-
     total_accuracy = 0
     total_precision = 0
     total_recall = 0
src/predict.py CHANGED
@@ -1,4 +1,4 @@
-from transformers.trainer_utils import get_last_checkpoint
+from utils import re_findall
 from shared import OutputArguments
 from typing import Optional
 from segment import (
@@ -11,21 +11,22 @@ from segment import (
     SegmentationArguments
 )
 import preprocess
-import re
 from errors import TranscriptError
 from model import get_classifier_vectorizer
 from transformers import (
     AutoModelForSeq2SeqLM,
-    AutoTokenizer
+    AutoTokenizer,
+    HfArgumentParser
 )
+from transformers.trainer_utils import get_last_checkpoint
 from dataclasses import dataclass, field
-from transformers import HfArgumentParser
 from shared import device
 import logging
 
 
 def seconds_to_time(seconds):
-    fractional = str(round(seconds % 1, 3))[1:]
+    fractional = round(seconds % 1, 3)
+    fractional = '' if fractional == 0 else str(fractional)[1:]
     h, remainder = divmod(abs(int(seconds)), 3600)
     m, s = divmod(remainder, 60)
     return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}{fractional}"
@@ -64,7 +65,7 @@ class PredictArguments(TrainingOutputArguments):
     )
 
 
-SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
+SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SEGMENT.value})\s*_(?P<category>\S+)\s*(?P<text>.*?)\s*(?={CustomTokens.END_SEGMENT.value}|$)'
 
 MATCH_WINDOW = 25       # Increase for accuracy, but takes longer: O(n^3)
 MERGE_TIME_WITHIN = 8   # Merge predictions if they are within x seconds
@@ -97,11 +98,13 @@ class ClassifierArguments:
         default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
 
 
-def filter_predictions(predictions, classifier, vectorizer, classifier_args):
+def filter_predictions(predictions, classifier_args):  # classifier, vectorizer,
     """Use classifier to filter predictions"""
     if not predictions:
         return predictions
 
+    classifier, vectorizer = get_classifier_vectorizer(classifier_args)
+
     transformed_segments = vectorizer.transform([
         preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
         for pred in predictions
@@ -142,9 +145,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
             words, prediction['start'], prediction['end'])
 
     if classifier_args is not None:
-        classifier, vectorizer = get_classifier_vectorizer(classifier_args)
-        predictions = filter_predictions(
-            predictions, classifier, vectorizer, classifier_args)
+        predictions = filter_predictions(predictions, classifier_args)
 
     return predictions
 
@@ -166,13 +167,10 @@ def greedy_match(list, sublist):
     return best_i, best_j, best_k
 
 
-DEFAULT_TOKEN_PREFIX = 'summarize: '
-
-
 def predict_sponsor_text(text, model, tokenizer):
     """Given a body of text, predict the words which are part of the sponsor"""
     input_ids = tokenizer(
-        f'{DEFAULT_TOKEN_PREFIX}{text}', return_tensors='pt', truncation=True).input_ids.to(device())
+        f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(device())
 
     # Can't be longer than input length + SAFETY_TOKENS or model input dim
     max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
@@ -183,10 +181,11 @@ def predict_sponsor_text(text, model, tokenizer):
 
 def predict_sponsor_matches(text, model, tokenizer):
     sponsorship_text = predict_sponsor_text(text, model, tokenizer)
-    if CustomTokens.NO_SPONSOR.value in sponsorship_text:
+
+    if CustomTokens.NO_SEGMENT.value in sponsorship_text:
         return []
 
-    return re.findall(SPONSOR_MATCH_RE, sponsorship_text)
+    return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
 
 
 def segments_to_prediction_times(segments, model, tokenizer):
@@ -202,7 +201,7 @@ def segments_to_prediction_times(segments, model, tokenizer):
         matches = predict_sponsor_matches(batch_text, model, tokenizer)
 
         for match in matches:
-            matched_text = match.split()
+            matched_text = match['text'].split()
            # TODO skip if too short
 
            i1, j1, k1 = greedy_match(
@@ -217,7 +216,8 @@ def segments_to_prediction_times(segments, model, tokenizer):
 
             predicted_time_ranges.append({
                 'start': word_start(extracted_words[0]),
-                'end': word_end(extracted_words[-1])
+                'end': word_end(extracted_words[-1]),
+                'category': match['category']
             })
 
     # Necessary to sort matches by start time
@@ -225,23 +225,29 @@ def segments_to_prediction_times(segments, model, tokenizer):
 
     # Merge overlapping predictions and sponsorships that are close together
    # Caused by model having max input size
-    last_end_time = -1
+
+    prev_prediction = None
+
     final_predicted_time_ranges = []
     for range in predicted_time_ranges:
         start_time = range['start']
         end_time = range['end']
 
-        if (start_time <= last_end_time <= end_time) or (last_end_time != -1 and start_time - last_end_time <= MERGE_TIME_WITHIN):
-            # Ending time of last segment is in this segment, so we extend last prediction range
+        if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
+            start_time <= prev_prediction['end'] <= end_time or start_time -
+            prev_prediction['end'] <= MERGE_TIME_WITHIN
+        ):
+            # Ending time of last segment is in this segment or c, so we extend last prediction range
             final_predicted_time_ranges[-1]['end'] = end_time
 
         else:  # No overlap, is a new prediction
             final_predicted_time_ranges.append({
                 'start': start_time,
                 'end': end_time,
+                'category': range['category']
             })
 
-        last_end_time = end_time
+        prev_prediction = range
 
     return final_predicted_time_ranges
 
@@ -268,7 +274,7 @@ def main():
 
     predict_args.video_id = predict_args.video_id.strip()
     predictions = predict(predict_args.video_id, model, tokenizer,
-                          segmentation_args, classifier_args=classifier_args)
+                          segmentation_args)  # TODO add back , classifier_args=classifier_args
 
     video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
     if not predictions:
@@ -282,7 +288,8 @@ def main():
              ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
        print('Time:', seconds_to_time(
            prediction['start']), '-->', seconds_to_time(prediction['end']))
-       print('Probability:', prediction['probability'])
+       print('Probability:', prediction.get('probability'))
+       print('Category:', prediction.get('category'))
        print()
 
 
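For reference, a minimal sketch of how the reworked matching behaves with these changes: SPONSOR_MATCH_RE now captures the category and the text as named groups, and re_findall (added in src/utils.py) returns one dict per match. The output string below is hypothetical; it only assumes predict.py's names are in scope and the token values defined in src/shared.py:

    # Hypothetical model output using the new category-suffixed segment tokens
    output = 'START_SEGMENT_TOKEN_SPONSOR thanks to our sponsor for supporting the video END_SEGMENT_TOKEN_SPONSOR'
    re_findall(SPONSOR_MATCH_RE, output)
    # -> [{'category': 'SPONSOR', 'text': 'thanks to our sponsor for supporting the video'}]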
 
src/preprocess.py CHANGED
@@ -1,5 +1,6 @@
+from datetime import datetime
 import itertools
-from typing import Optional
+from typing import Optional, List
 from datasets import load_dataset
 from model import ModelArguments
 import segment
@@ -24,8 +25,10 @@ def find(s, ch):
     return [i for i, ltr in enumerate(s) if ltr == ch]
 
 
-def wordify(transcript):
+def wordify(transcript, maximum_wps=1):
     """Try to replicate format for automatically generated transcripts"""
+
+    # Do not allow segments to be on screen for too long using maximum_wps
     words = []
 
     for line_index, line in enumerate(transcript):
@@ -34,9 +37,14 @@ def wordify(transcript):
            continue
 
        start = line['start']
-       next_start = transcript[line_index +
-                               1]['start'] if line_index < len(transcript) - 1 else float('inf')
-       end = min(start + line['duration'], next_start)
+       next_start = transcript[line_index + 1]['start'] \
+           if line_index < len(transcript) - 1 else float('inf')
+
+       # Use maximum wps to calculate latest end (to avoid segments which stay on screen too long)
+       longest_duration = maximum_wps * text.count(' ')
+       latest_end = start + longest_duration
+       end = min(start + line['duration'], next_start, latest_end)
+
        duration = end - start
 
        indices = find(text, ' ') + [len(text)]
@@ -52,9 +60,9 @@ def wordify(transcript):
            w_start = start + percentage * duration
 
            words.append({
-               'start': round(w_start, 5),
-               'duration': round(w_duration, 5),
-               'end': round(w_start + w_duration, 5),
+               'start': round(w_start, 3),
+               'duration': round(w_duration, 3),
+               'end': round(w_start + w_duration, 3),
                'text': word,
            })
 
@@ -69,6 +77,10 @@ def get_manual_words(transcript_list):
     return wordify(transcript)
 
 
+PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
+PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
+
+
 def get_auto_words(transcript_list):
     words = []
     transcript = transcript_list.find_generated_transcript(['en'])
@@ -82,7 +94,7 @@ def get_auto_words(transcript_list):
        offset_ms = word.get('tOffsetMs', 0)
 
        texts = word['utf8'].replace(
-           CustomTokens.PROFANITY_RAW.value, CustomTokens.PROFANITY_CONVERTED.value
+           PROFANITY_RAW, PROFANITY_CONVERTED
        ).strip().split()
 
        for text in texts:
@@ -94,7 +106,7 @@ def get_auto_words(transcript_list):
     return words
 
 
-def get_words(video_id, process=True, fallback=False, transcript_type='auto'):
+def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
     """Get parsed video transcript with caching system
     returns None if not processed yet and process is False
     """
@@ -148,21 +160,31 @@ def extract_sponsors(words, min_sponsor_segment_length=5):
 
     paragraphs = []
     current = []
+    prev_category = None
     for word in words:
-        if not word.get('sponsor') and not current:
-            continue
+        if word['category'] is None:  # and not current:
+            continue  # Skip unimportant
 
-        if word['sponsor']:
+        if word['category'] == prev_category:
            current.append(word['text'])
        else:
-           paragraphs.append(current)
+           paragraphs.append({
+               'words': current,
+               'category': prev_category,
+           })
            current = []
-    if current:
-        paragraphs.append(current)
+
+        prev_category = word['category']
+
+    if current and prev_category is not None:
+        paragraphs.append({
+            'words': current,
+            'category': prev_category,
+        })
 
     # Remove all too short:
     paragraphs = list(filter(lambda x: len(
-        x) >= min_sponsor_segment_length, paragraphs))
+        x['words']) >= min_sponsor_segment_length, paragraphs))
 
     return paragraphs
 
@@ -203,10 +225,8 @@ def clean_text(text):
     text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)
 
     # Replace profanity with special token
-    text = text.replace(CustomTokens.PROFANITY_RAW.value,
-                        CustomTokens.PROFANITY.value)
-    text = text.replace(CustomTokens.PROFANITY_CONVERTED.value,
-                        CustomTokens.PROFANITY.value)
+    text = text.replace(PROFANITY_RAW, CustomTokens.PROFANITY.value)
+    text = text.replace(PROFANITY_CONVERTED, CustomTokens.PROFANITY.value)
 
     return text.strip()
 
@@ -254,11 +274,25 @@ class PreprocessArguments:
     do_create: bool = field(
         default=False, metadata={'help': 'Merge sponsor segments into single file'}
     )
+
     min_votes: int = field(
         default=0, metadata={'help': 'Minimum number of votes'})
     # Downvotes will make this negative.
     # 1 = At least one positive vote
 
+    min_date: str = field(
+        default='20/08/2021', metadata={'help': 'Only use submissions from after this date, defaults to the release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)'})
+
+    categories: str = field(
+        default_factory=lambda: ['sponsor', 'selfpromo', 'interaction'],
+        metadata={
+            'nargs': '+',
+            'choices': ['intro', 'sponsor', 'interaction',
+                        'outro', 'selfpromo', 'preview',
+                        'poi_highlight', 'filler', 'music_offtopic']  # moreCategories
+        }
+    )
+
     do_transcribe: bool = field(
         default=False, metadata={'help': 'Get transcripts for videos'}
     )
@@ -266,7 +300,7 @@ class PreprocessArguments:
         default=4, metadata={'help': 'Number of transcripts to download in parallel'})
 
     overwrite: bool = field(
-        default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
+        default=True, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
     )
 
     do_generate: bool = field(
@@ -447,14 +481,26 @@ def main():
         preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
 
     def get_rows():
+
+        latest_time = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
+
        with open(raw_dataset_path, newline='') as csvfile:
            reader = csv.DictReader(csvfile)
+
            for line in reader:
+               submitted_time = datetime.fromtimestamp(
+                   float(line['timeSubmitted'])/1e3)
+
+               if submitted_time < latest_time:
+                   continue
+
                if line['service'] != 'YouTube':
                    continue
+               if len(line['videoID']) != 11:
+                   continue  # Invalid youtube video ID
 
                # TODO add support for other categories and action types?
-               if line['category'] != 'sponsor':
+               if line['category'] not in preprocess_args.categories:
                    continue
                if line['actionType'] != 'skip':
                    continue
@@ -463,9 +509,6 @@ def main():
                if line['hidden'] == '1' or line['shadowHidden'] == '1':
                    continue
 
-               if len(line['videoID']) != 11:
-                   continue  # Invalid youtube video ID
-
                # Skip those that aren't highly voted
                line['votes'] = int(line['votes'])
                # incorrect_votes = int(line['incorrectVotes'])
@@ -494,6 +537,8 @@ def main():
        for row in data_rows:
            video_ids.add(row['videoID'])
 
+       # TODO first set - os.listdir and do rest
+
        print('Start transcribing')
        with tqdm(total=len(video_ids)) as progress:
            def on_job_complete(job):
@@ -517,21 +562,18 @@ def main():
     final_path = os.path.join(
         processed_args.processed_dir, processed_args.processed_file)
 
-    if os.path.exists(final_path) and not preprocess_args.do_create:
-        logging.info(f'{final_path} exists, opening file')
-        with open(final_path) as fp:
-            final_data = json.load(fp)
-    else:
+    if preprocess_args.do_create:
        print('Create final data')
 
        final_data = {}
 
        if data_rows is None:
            data_rows = get_rows()
+           # data_rows = itertools.islice(data_rows, 1000)  # TODO temp
 
        # TODO add progress bar
        # TODO parallelise?
-       for line in data_rows:
+       for index, line in enumerate(data_rows):
            video_id = line['videoID']
 
            if video_id not in final_data:
@@ -540,7 +582,10 @@ def main():
            segment_start = float(line['startTime'])
            segment_end = float(line['endTime'])
 
-           video_words = get_words(video_id, process=True)
+           video_words = get_words(video_id, process=False)
+           if not video_words:
+               continue
+
            segment_words = segment.extract_segment(
                video_words, segment_start, segment_end)
 
@@ -552,7 +597,8 @@ def main():
            wps = len(segment_words)/duration if duration > 0 else 0
 
            if wps < preprocess_args.min_wps:
-               print('bad segment in', video_id, '| wps =', wps)
+               print(index, 'Skipping bad segment in',
+                     video_id, '| wps =', wps)
                continue
 
            final_data[video_id].append({
@@ -580,10 +626,16 @@ def main():
        #     raw_dataset_path, final_path, preprocess_args.min_votes)
        # # TODO save metadata in final.json?
 
-    logging.info(f'Found {len(final_data)} videos')
+    elif os.path.exists(final_path):
+        # Already exists
+        logging.info(f'{final_path} exists, opening file')
+        with open(final_path) as fp:
+            final_data = json.load(fp)
+        logging.info(f'Found {len(final_data)} videos')
+    else:
+        return  # Do not continue
 
     # TODO shuffle final_data
-
     # if not os.path.exists(excess_path) or preprocess_args.overwrite
     # TODO use overwrite param
 
@@ -610,10 +662,8 @@ def main():
     write_mode = 'w' if preprocess_args.overwrite else 'a'
 
     get_all = preprocess_args.max_videos is None
-    if get_all:
-        total = len(final_data)
-    else:
-        total = preprocess_args.max_videos
+
+    total = len(final_data) if get_all else preprocess_args.max_videos
 
     index = 0
     data = final_data.items()
@@ -641,7 +691,7 @@ def main():
        elif count_videos >= preprocess_args.max_videos:
            break
 
-       words = get_words(video_id, False)
+       words = get_words(video_id, process=False)
        if not words:
            continue
 
@@ -662,34 +712,40 @@ def main():
                progress.update()
 
            for seg in segments:
-
-               segment_text = ' '.join((x['text'] for x in seg))
-
-               extracted_text = ''
-               for p in extract_sponsors(seg):
-                   p_text = ' '.join(p)
-                   extracted_text += f'{CustomTokens.START_SPONSOR.value} {p_text} {CustomTokens.END_SPONSOR.value}. '
-
               duration = segment.word_end(
                   seg[-1]) - segment.word_start(seg[0])
               wps = len(seg)/duration if duration > 0 else 0
+
               # Ignore segments with "not enough words" in the transcript
               if wps < preprocess_args.min_wps:
                   continue
 
+              segment_text = ' '.join((x['text'] for x in seg))
+              extracted_segments = extract_sponsors(seg)
               d = {
                   'video_index': index,
                   'video_id': video_id,
                   'text': clean_text(segment_text),
-                  'words_per_second': wps,
+                  'words_per_second': round(wps, 3),
               }
 
-              d['sponsor'] = bool(extracted_text)
-              d['extracted'] = clean_text(
-                  extracted_text) if d['sponsor'] else CustomTokens.NO_SPONSOR.value
+              if extracted_segments:
+                  extracted_texts = []
+                  for s in extracted_segments:
+                      w = ' '.join(s['words'])
+                      category = s['category'].upper()
+
+                      t = f"{CustomTokens.START_SEGMENT.value}_{category} {w} {CustomTokens.END_SEGMENT.value}_{category}"
+                      extracted_texts.append(t)
+
+                  extracted_text = '\n'.join(extracted_texts)
+
+                  d['extracted'] = clean_text(extracted_text)
+                  print(json.dumps(d), file=positive)
 
-              print(json.dumps(d), file=(
-                  positive if d['sponsor'] else negative))
+              else:
+                  d['extracted'] = CustomTokens.NO_SEGMENT.value
+                  print(json.dumps(d), file=negative)
 
     if preprocess_args.do_split:
         print('Splitting')
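For context, a short sketch (illustrative values only, not taken from the dataset) of the rows now written by do_generate: labelled spans are serialised with category-suffixed segment tokens (category upper-cased), unlabelled segments fall back to the NO_SEGMENT token, and train.py prepends EXTRACT_SEGMENTS_PREFIX to the source text via source_prefix:

    # Source text fed to the seq2seq model (prefix added at training time):
    #   EXTRACT_SEGMENTS: <cleaned transcript text of the segment>
    # Target for a span labelled 'selfpromo' (hypothetical wording):
    #   START_SEGMENT_TOKEN_SELFPROMO remember to check out my merch END_SEGMENT_TOKEN_SELFPROMO
    # Target when the segment contains no labelled words:
    #   NO_SEGMENT_FOUND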
src/segment.py CHANGED
@@ -25,7 +25,7 @@ def get_overlapping_chunks_of_tokens(tokens, size, overlap):
 
 
 # Generate up to max_tokens - SAFETY_TOKENS
-SAFETY_TOKENS = 8
+SAFETY_TOKENS = 12
 
 
 # TODO play around with this?
@@ -36,10 +36,10 @@ def add_labels_to_words(words, sponsor_segments):
 
     # TODO binary search
     for word in words:
-        word['sponsor'] = False
+        word['category'] = None
         for sponsor_segment in sponsor_segments:
             if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
-                word['sponsor'] = True
+                word['category'] = sponsor_segment['category']
 
     # TODO use extract_segment with mapping function?
     # TODO remove sponsor segments that contain mostly empty space?
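A small illustration (hypothetical values) of the relabelling in add_labels_to_words: each word now carries the category of the segment it falls inside, rather than a boolean 'sponsor' flag:

    words = [{'text': 'like', 'start': 12.3, 'end': 12.5}]
    sponsor_segments = [{'start': 10.0, 'end': 20.0, 'category': 'interaction'}]
    add_labels_to_words(words, sponsor_segments)
    # words[0]['category'] == 'interaction'; words outside every segment keep category None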
src/shared.py CHANGED
@@ -7,16 +7,17 @@ from typing import Optional
 from dataclasses import dataclass, field
 from enum import Enum
 
-
 class CustomTokens(Enum):
+    EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
+
     URL = 'URL_TOKEN'
     HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
     NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
     NUMBER = 'NUMBER_TOKEN'
 
-    START_SPONSOR = 'START_SPONSOR'
-    END_SPONSOR = 'END_SPONSOR'
-    NO_SPONSOR = 'NO_SPONSOR_FOUND'
+    START_SEGMENT = 'START_SEGMENT_TOKEN'
+    END_SEGMENT = 'END_SEGMENT_TOKEN'
+    NO_SEGMENT = 'NO_SEGMENT_FOUND'
 
     SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
     LONG_WORD = 'LONG_WORD_TOKEN'
@@ -26,8 +27,6 @@ class CustomTokens(Enum):
     APPLAUSE = '[Applause]'
     LAUGHTER = '[Laughter]'
 
-    PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
-    PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
     PROFANITY = 'PROFANITY_TOKEN'
 
     @classmethod
src/train.py CHANGED
@@ -1,9 +1,8 @@
 from preprocess import load_datasets, DatasetArguments
-from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
-from shared import device, GeneralArguments, OutputArguments
-from model import ModelArguments
+from predict import ClassifierArguments, SPONSOR_MATCH_RE
+from shared import CustomTokens, device, GeneralArguments, OutputArguments
+from model import ModelArguments, get_model, get_tokenizer
 import transformers
-from model import get_model, get_tokenizer
 import logging
 import os
 import sys
@@ -22,7 +21,7 @@ from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from sklearn.linear_model import LogisticRegression
 from sklearn.feature_extraction.text import TfidfVectorizer
-
+from utils import re_findall
 import re
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -117,7 +116,7 @@ class DataTrainingArguments:
         },
     )
     source_prefix: Optional[str] = field(
-        default=DEFAULT_TOKEN_PREFIX, metadata={
+        default=CustomTokens.EXTRACT_SEGMENTS_PREFIX.value, metadata={
            'help': 'A prefix to add before every source text (useful for T5 models).'}
     )
 
@@ -135,11 +134,11 @@ class SequenceTrainingArguments(OutputArguments, Seq2SeqTrainingArguments):
     num_train_epochs: float = field(
         default=1, metadata={'help': 'Total number of training epochs to perform.'})
 
-    save_steps: int = field(default=2500, metadata={
+    save_steps: int = field(default=5000, metadata={
        'help': 'Save checkpoint every X updates steps.'})
-    eval_steps: int = field(default=2500, metadata={
+    eval_steps: int = field(default=5000, metadata={
        'help': 'Run an evaluation every X steps.'})
-    logging_steps: int = field(default=2500, metadata={
+    logging_steps: int = field(default=5000, metadata={
        'help': 'Log every X updates steps.'})
 
     skip_train_transformer: bool = field(default=False, metadata={
@@ -257,8 +256,8 @@ def main():
 
         ngram_range=(1, 2),  # best so far
         # max_features=8000  # remove for higher accuracy?
-        max_features=50000
-        # max_features=10000
+        # max_features=50000
+        max_features=10000
     )
 
     train_test_data = {
@@ -277,11 +276,12 @@ def main():
         dataset = raw_datasets[ds_type]
 
         for row in dataset:
-
             # Get matches:
-            if row['sponsor']:
-                matches = re.findall(SPONSOR_MATCH_RE, row['extracted'])
-            else:
+            matches = re_findall(SPONSOR_MATCH_RE, row['extracted'])
+
+            return  # TODO fix
+
+            if not matches:
                 matches = [row['text']]
 
             for match in matches:
src/utils.py CHANGED
@@ -1,6 +1,8 @@
+import re
 import asyncio
 import os
 
+
 class Job:
     def __init__(self, function, *args, **kwargs) -> None:
         self.function = function
@@ -84,3 +86,7 @@ class InterruptibleThreadPool:
         self.loop.close()
 
         return self.jobs
+
+
+def re_findall(pattern, string):
+    return [m.groupdict() for m in re.finditer(pattern, string)]
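Usage note: unlike re.findall, this helper returns the named groups of each match as a dict, e.g.:

    re_findall(r'(?P<key>\w+)=(?P<value>\w+)', 'a=1 b=2')
    # -> [{'key': 'a', 'value': '1'}, {'key': 'b', 'value': '2'}]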