Joshua Lochner committed
Commit 583f4cf
1 Parent(s): df35612

Improve how transcripts are stored and how manual transcripts are segmented
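For context before reading the diff: the commit switches to caching the raw json3 response from YouTube under transcripts/<type>/<video_id>.json and parsing it on load with the new parse_transcript_json(), at either 'word' or 'chunk' granularity. A minimal sketch of that flow, based only on the code in this diff (the import path and video ID below are assumptions for illustration, not part of the commit):

    import json
    import os

    from preprocess import parse_transcript_json  # assumes src/preprocess.py is importable as `preprocess`

    video_id = 'dQw4w9WgXcQ'  # hypothetical example ID
    path = os.path.join('transcripts', 'manual', f'{video_id}.json')

    # The raw json3 payload is stored, not the parsed words, so the same cached
    # file can later be re-parsed at either granularity.
    if os.path.exists(path):
        with open(path) as fp:
            raw = json.load(fp)  # may be empty if no transcript was available

        if raw:
            words = parse_transcript_json(raw, 'word')    # word-level timestamps
            chunks = parse_transcript_json(raw, 'chunk')  # punctuation-based chunks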

Files changed (1):
  1. src/preprocess.py +213 -136
src/preprocess.py CHANGED
@@ -1,9 +1,8 @@
-from utils import jaccard, Task, InterruptibleTaskPool
+from utils import jaccard
 from functools import lru_cache
 from datetime import datetime
 import itertools
 from typing import Optional, List
-from datasets import load_dataset
 from model import ModelArguments
 import segment
 from tqdm import tqdm
@@ -21,94 +20,141 @@ import time
 import requests


-def find(s, ch):
-    return [i for i, ltr in enumerate(s) if ltr == ch]
-
-
-def wordify(transcript, maximum_wps=1):
-    """Try to replicate format for automatically generated transcripts"""
-
-    # Do not allow segments to be on screen for too long using maximum_wps
-    words = []
-
-    for line_index, line in enumerate(transcript):
-        text = line['text'].replace('\n', ' ').strip()
-        if not text:
-            continue
-
-        start = line['start']
-        next_start = transcript[line_index + 1]['start'] \
-            if line_index < len(transcript) - 1 else float('inf')
-
-        # Use maximum wps to calculate latest end (to avoid segments which stay on screen too long)
-        longest_duration = maximum_wps * text.count(' ')
-        latest_end = start + longest_duration
-        end = min(start + line['duration'], next_start, latest_end)
-
-        duration = end - start
-
-        indices = find(text, ' ') + [len(text)]
-        start_index = 0
-        for i in range(len(indices)):
-            word = text[start_index:indices[i]].strip()
-            if not word:
-                continue  # Skip empty words (e.g., \n)
-            percentage = start_index/indices[-1]
-
-            w_duration = len(word)/indices[-1] * duration
-
-            w_start = start + percentage * duration
-
-            words.append({
-                'start': round(w_start, 3),
-                'duration': round(w_duration, 3),
-                'end': round(w_start + w_duration, 3),
-                'text': word,
-            })
-
-            start_index = indices[i] + 1
-
-    return words
-
-
-def get_manual_words(transcript_list):
-    transcript = transcript_list.find_manually_created_transcript(
-        ['en-GB', 'en-US', 'en']).fetch()
-    return wordify(transcript)
-
-
-PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
-PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
-
-
-# TODO add end time for words
-def get_auto_words(transcript_list):
-    words = []
-    transcript = transcript_list.find_generated_transcript(['en'])
-    url = transcript._url + '&fmt=json3'
-    info = transcript._http_client.get(url)
-
-    for event in info.json()['events']:
-        start_ms = event.get('tStartMs', 0)
-
-        for word in event.get('segs') or []:
-            offset_ms = word.get('tOffsetMs', 0)
-
-            texts = word['utf8'].replace(
-                PROFANITY_RAW, PROFANITY_CONVERTED
-            ).strip().split()
-
-            for text in texts:
-                words.append({
-                    'start': (start_ms + offset_ms)/1000,
-                    'text': text
-                })
-
-    return words
+PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
+PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
+
+NUM_DECIMALS = 3
+
+
+def parse_transcript_json(json_data, granularity):
+    assert json_data['wireMagic'] == 'pb3'
+
+    assert granularity in ('word', 'chunk')
+
+    # TODO remove bracketed words?
+    # (kiss smacks)
+    # (upbeat music)
+    # [text goes here]
+
+    # Some manual transcripts aren't that well formatted... but do have punctuation
+    # https://www.youtube.com/watch?v=LR9FtWVjk2c
+
+    parsed_transcript = []
+
+    events = json_data['events']
+
+    for event_index, event in enumerate(events):
+        segments = event.get('segs')
+        if not segments:
+            continue
+
+        # This value is known (when phrase appears on screen)
+        start_ms = event['tStartMs']
+        total_characters = 0
+
+        new_segments = []
+        for seg in segments:
+            text = seg['utf8'].replace('\n', ' ').replace(
+                PROFANITY_RAW, PROFANITY_CONVERTED,  # Needed for auto-generated transcripts
+            ).strip()
+            if not text:
+                continue
+
+            offset_ms = seg.get('tOffsetMs', 0)
+
+            new_segments.append({
+                'text': text,
+                'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
+            })
+
+            total_characters += len(text)
+
+        if not new_segments:
+            continue
+
+        if event_index < len(events) - 1:
+            next_start_ms = events[event_index + 1]['tStartMs']
+            total_event_duration_ms = min(
+                event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
+        else:
+            total_event_duration_ms = event.get('dDurationMs', 0)
+
+        avg_seconds_per_character = (
+            total_event_duration_ms/total_characters)/1000
+
+        num_char_count = 0
+        for seg_index, seg in enumerate(new_segments):
+            num_char_count += len(seg['text'])
+
+            # Estimate segment end
+            seg_end = seg['start'] + \
+                (num_char_count * avg_seconds_per_character)
+
+            if seg_index < len(new_segments) - 1:
+                # Do not allow longer than next
+                seg_end = min(seg_end, new_segments[seg_index+1]['start'])
+
+            seg['end'] = round(seg_end, NUM_DECIMALS)
+            parsed_transcript.append(seg)
+
+    final_parsed_transcript = []
+    for i in range(len(parsed_transcript)):
+
+        word_level = granularity == 'word'
+        if word_level:
+            split_text = parsed_transcript[i]['text'].split()
+        elif granularity == 'chunk':
+            # Split on space after punctuation
+            split_text = re.split(
+                r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
+            if len(split_text) == 1:
+                split_on_whitespace = parsed_transcript[i]['text'].split()
+
+                if len(split_on_whitespace) >= 8:  # Too many words
+                    # Rather split on whitespace instead of punctuation
+                    split_text = split_on_whitespace
+                else:
+                    word_level = True
+        else:
+            raise ValueError('Unknown granularity')
+
+        segment_end = parsed_transcript[i]['end']
+        if i < len(parsed_transcript) - 1:
+            segment_end = min(segment_end, parsed_transcript[i+1]['start'])
+
+        segment_duration = segment_end - parsed_transcript[i]['start']
+
+        num_chars_in_text = sum(map(len, split_text))
+
+        num_char_count = 0
+        current_offset = 0
+        for s in split_text:
+            num_char_count += len(s)
+
+            next_offset = (num_char_count/num_chars_in_text) * segment_duration
+
+            word_start = round(
+                parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
+            word_end = round(
+                parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
+
+            # Make the reasonable assumption that min wps is 1.5
+            final_parsed_transcript.append({
+                'text': s,
+                'start': word_start,
+                'end': min(word_end, word_start + 1.5) if word_level else word_end
+            })
+            current_offset = next_offset
+
+    return final_parsed_transcript


 def list_transcripts(video_id):
-    return YouTubeTranscriptApi.list_transcripts(video_id)
+    try:
+        return YouTubeTranscriptApi.list_transcripts(video_id)
+    except json.decoder.JSONDecodeError:
+        return None


 WORDS_TO_REMOVE = [
@@ -119,60 +165,74 @@ WORDS_TO_REMOVE = [


 @lru_cache(maxsize=16)
-def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True):
+def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True, download=False, granularity='word'):
     """Get parsed video transcript with caching system
     returns None if not processed yet and process is False
     """
+    # NOTE: granularity='chunk' should only be used for generating training data... nowhere else
+
     transcript_path = os.path.join(  # TODO use relative path to this
         'transcripts', transcript_type, f'{video_id}.json')

-    words = None
+    raw_transcript_json = None
     try:
-        if os.path.exists(transcript_path):  # Load from file
+        if not download and os.path.exists(transcript_path):  # Load from file
             with open(transcript_path) as fp:
-                words = json.load(fp)  # May be empty
+                raw_transcript_json = json.load(fp)  # May be empty

         elif process:
             transcript_list = list_transcripts(video_id)

-            if transcript_type == 'manual':
-                words = get_manual_words(transcript_list)
-            else:
-                words = get_auto_words(transcript_list)
+            if transcript_list is not None:
+                if transcript_type == 'manual':
+                    ts = transcript_list.find_manually_created_transcript(
+                        ['en-GB', 'en-US', 'en'])
+                else:
+                    ts = transcript_list.find_generated_transcript(['en'])
+
+                raw_transcript_json = ts._http_client.get(
+                    f'{ts._url}&fmt=json3').json()

     except (TooManyRequests, YouTubeRequestFailed):
         raise  # Cannot recover from these errors and do not mark as empty transcript

-    except requests.exceptions.ConnectionError:  # Can recover
+    except requests.exceptions.RequestException:  # Can recover
         time.sleep(10)  # Timeout
-        return get_words(video_id, process, transcript_type, fallback)
+        return get_words(video_id, process, transcript_type, fallback, granularity)

     except CouldNotRetrieveTranscript:  # Retrying won't solve
         pass  # Mark as empty transcript

     except json.decoder.JSONDecodeError:
         print('JSONDecodeError for', video_id)
-        os.remove(transcript_path)  # Remove file and try again
-        return get_words(video_id, process, transcript_type, fallback)
+        if os.path.exists(transcript_path):
+            os.remove(transcript_path)  # Remove file and try again
+        return get_words(video_id, process, transcript_type, fallback, granularity)

     # Tried to process it, but it was empty...
-    if process and not os.path.exists(transcript_path):
+    if download or (process and not os.path.exists(transcript_path)):
         with open(transcript_path, 'w') as fp:
-            json.dump(words, fp)
+            json.dump(raw_transcript_json, fp)

-    if not words and fallback is not None:
-        return get_words(video_id, process, transcript_type=fallback, fallback=None)
+    if not raw_transcript_json and fallback is not None:
+        return get_words(video_id, process, transcript_type=fallback, fallback=None, granularity=granularity)

-    if words and filter_words_to_remove:
-        words = list(filter(lambda x: x['text'] not in WORDS_TO_REMOVE, words))
+    if raw_transcript_json:
+        processed_transcript = parse_transcript_json(
+            raw_transcript_json, granularity)
+        if filter_words_to_remove:
+            processed_transcript = list(
+                filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
+    else:
+        processed_transcript = raw_transcript_json  # Either None or []

-    return words
+    return processed_transcript


 # TODO make min_sponsor_segment_length param
 # TODO rename to extract_segments
 def extract_sponsors(words, min_sponsor_segment_length=3):
-    if not words or len(words) < min_sponsor_segment_length:
+    if not words:
         return []

     paragraphs = []
@@ -302,9 +362,13 @@ class PreprocessArguments:

     max_date: str = field(
         # default='01/01/9999',  # Include all
-        default='27/01/2022',
+        default='02/02/2022',
         metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})

+    keep_duplicate_segments: bool = field(
+        default=False, metadata={'help': 'Keep duplicate segments'}
+    )
+
     do_process_database: bool = field(
         default=False, metadata={'help': 'Process the raw database'}
     )
@@ -393,23 +457,6 @@ def download_file(url, filename):
     return total_bytes == os.path.getsize(filename)


-def load_datasets(dataset_args):
-    print('Reading datasets')
-    data_files = {}
-
-    if dataset_args.train_file is not None:
-        data_files['train'] = os.path.join(
-            dataset_args.data_dir, dataset_args.train_file)
-    if dataset_args.validation_file is not None:
-        data_files['validation'] = os.path.join(
-            dataset_args.data_dir, dataset_args.validation_file)
-    if dataset_args.test_file is not None:
-        data_files['test'] = os.path.join(
-            dataset_args.data_dir, dataset_args.test_file)
-
-    return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
-
-
 @dataclass
 class DatasetArguments:
     data_dir: Optional[str] = field(
@@ -503,9 +550,11 @@ def main():
                 break
             print('Failed, trying next')

+    os.makedirs(dataset_args.data_dir, exist_ok=True)
     processed_db_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_database)

+    # TODO process all valid possible items and then do filtering only later
     @lru_cache(maxsize=1)
     def read_db():
         if not preprocess_args.overwrite and os.path.exists(processed_db_path):
@@ -520,6 +569,7 @@ def main():
             for line in reader:

+                # Never show:
                 if line['service'] != 'YouTube':
                     continue
                 if len(line['videoID']) != 11:
@@ -565,9 +615,10 @@ def main():
                 })

         # Remove duplicate sponsor segments by choosing best (most votes)
-        print('Remove duplicate segments')
-        for key in db:
-            db[key] = remove_duplicate_segments(db[key])
+        if not preprocess_args.keep_duplicate_segments:
+            print('Remove duplicate segments')
+            for key in db:
+                db[key] = remove_duplicate_segments(db[key])

         # We now remove whole videos from the list
         # Helps with obtaining "fully-labelled" videos
@@ -616,20 +667,44 @@ def main():

         video_ids = list(parsed_database.keys() - finished)

-        # Create tasks generator
-        tasks = (
-            Task(get_words, video_id)
-            for video_id in video_ids
-        )
-
-        print('Downloading transcripts')
-        with tqdm(total=len(video_ids)) as progress:
-            def callback(task):
-                progress.set_description(f'Processing {task.args[0]}')
-                progress.update()
-
-            InterruptibleTaskPool(
-                tasks, preprocess_args.num_jobs, callback).start()
+        # https://stackoverflow.com/a/63495323
+        import concurrent
+        POLL_INTERVAL = 0.1
+
+        # Wrap get words function to return video_id after completion
+        def get_words_wrapper(video_id):
+            get_words(video_id)
+            return video_id
+
+        print('Setting up ThreadPoolExecutor')
+        with concurrent.futures.ThreadPoolExecutor(max_workers=preprocess_args.num_jobs) as pool, \
+                tqdm(total=len(video_ids)) as progress:
+
+            all_futures = (pool.submit(get_words_wrapper, video_id)
+                           for video_id in video_ids)
+            to_process = set(itertools.islice(
+                all_futures, preprocess_args.num_jobs))
+            try:
+                while to_process:
+                    just_finished, to_process = concurrent.futures.wait(
+                        to_process, timeout=POLL_INTERVAL)
+                    to_process |= set(itertools.islice(
+                        all_futures, len(just_finished)))
+
+                    for d in just_finished:
+                        progress.set_description(f'Processed {d.result()}')
+                        progress.update()
+
+            except KeyboardInterrupt:
+                print('Gracefully shutting down: Cancelling unscheduled tasks')
+
+                # only futures that are not done will prevent exiting
+                for future in to_process:
+                    future.cancel()
+
+                print('Waiting for in-progress tasks to complete')
+                concurrent.futures.wait(to_process, timeout=None)
+                print('Cancellation successful')

     final_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_file)
@@ -641,9 +716,14 @@ def main():

         parsed_database = read_db()

-        # TODO parallelise?
-        with tqdm(total=len(parsed_database)) as progress:
-            for index, (video_id, segments) in enumerate(parsed_database.items()):
+        transcribed = set(x.split('.')[0] for x in os.listdir(
+            'transcripts/auto/') + os.listdir('transcripts/manual/'))
+
+        # Only consider videos that have been transcribed already
+        video_ids = parsed_database.keys() & transcribed
+
+        with tqdm(total=len(video_ids)) as progress:
+            for index, video_id in enumerate(video_ids):
                 if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
                     break
                 progress.set_description(f'Processing {video_id}')
@@ -654,7 +734,8 @@ def main():
                     continue

                 final_vid_segs = []
-                for seg in segments:  # Only add segments with high enough wps
+                # Only add segments with high enough wps
+                for seg in parsed_database[video_id]:
                     segment_words = segment.extract_segment(
                         video_words, seg['start'], seg['end'])
@@ -697,8 +778,6 @@ def main():
     # if not os.path.exists(excess_path) or preprocess_args.overwrite
     # TODO use overwrite param

-    os.makedirs(dataset_args.data_dir, exist_ok=True)
-
     positive_file = os.path.join(
         dataset_args.data_dir, dataset_args.positive_file)
     negative_file = os.path.join(
@@ -724,7 +803,7 @@ def main():

         data = list(itertools.islice(data, start_index, end_index))

-        write_mode = 'w' if preprocess_args.overwrite else 'a'
+        write_mode = 'w'  # if preprocess_args.overwrite else 'a'
         with open(positive_file, write_mode, encoding='utf-8') as positive, \
                 open(negative_file, write_mode, encoding='utf-8') as negative, \
                 tqdm(data) as progress:
@@ -734,16 +813,14 @@ def main():
                 progress.set_description(f'Processing {video_id}')
                 progress.update()

-                words = get_words(video_id, process=False)
+                # Use chunk granularity to improve manual transcripts
+                words = get_words(video_id, process=False, granularity='chunk')
                 if not words:
                     continue

-                num_words = len(words)
-                if num_words <= 1:
+                if len(words) <= 1:
                     continue

-                # TODO only count words that aren't [Music], [Applause], etc.
-
                 segments = segment.generate_labelled_segments(
                     words, tokenizer, segmentation_args, sponsor_segments)
@@ -753,13 +830,13 @@ def main():
                 for seg in segments:
                     seg_start = segment.word_start(seg[0])
                     seg_end = segment.word_end(seg[-1])
-                    # duration = seg_end - seg_start
-                    # wps = len(seg)/duration if duration > 0 else 0
+                    duration = seg_end - seg_start
+                    wps = len(seg)/duration if duration > 0 else 0

-                    # # Ignore segments with "not enough words" in the transcript
-                    # # Must do here since this includes non-sponsor segments
-                    # if wps < preprocess_args.min_wps:
-                    #     continue
+                    # Ignore segments with "not enough words" in the transcript
+                    # Must do here since this includes non-sponsor segments
+                    if wps < preprocess_args.min_wps:
+                        continue

                     d = {
                         'video_index': offset + start_index,
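As a usage note on the segmentation change above: 'chunk' granularity is what the training-data step now requests, and for manual transcripts it splits each caption on whitespace that follows punctuation (falling back to whitespace splitting when a long caption has no punctuation) rather than into single words. A small sketch of how that is consumed, under the assumption that src/preprocess.py is importable as `preprocess`; the video ID and sample sentence are made up for illustration:

    import re

    from preprocess import get_words  # assumption: module importable as `preprocess`

    # Chunk granularity is meant only for generating training data (see the NOTE in get_words).
    # With process=False this returns None unless the transcript was downloaded in the earlier step.
    words = get_words('dQw4w9WgXcQ', process=False, granularity='chunk')  # hypothetical video ID

    # The same splitting rule parse_transcript_json applies for 'chunk' granularity:
    text = 'Thanks to our sponsor. Use code EXAMPLE for 10% off! Now, back to the video.'
    print(re.split(r'(?<=[.!?,-;])\s+', text))
    # ['Thanks to our sponsor.', 'Use code EXAMPLE for 10% off!', 'Now,', 'back to the video.']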