Joshua Lochner committed
Commit 583f4cf • 1 Parent(s): df35612
Improve how transcripts are stored and how manual transcripts are segmented
Files changed: src/preprocess.py (+213 -136)

src/preprocess.py CHANGED

--- src/preprocess.py (before)

@@ -1,9 +1,8 @@
-from utils import jaccard
 from functools import lru_cache
 from datetime import datetime
 import itertools
 from typing import Optional, List
-from datasets import load_dataset
 from model import ModelArguments
 import segment
 from tqdm import tqdm
@@ -21,94 +20,141 @@ import time
 import requests


-    """Try to replicate format for automatically generated transcripts"""

-    # Do not allow segments to be on screen for too long using maximum_wps
-    words = []

-        if not text:
-            continue

-        next_start = transcript[line_index + 1]['start'] \
-            if line_index < len(transcript) - 1 else float('inf')

-        start_index = 0
-        for i in range(len(indices)):
-            word = text[start_index:indices[i]].strip()
-            if not word:
-                continue  # Skip empty words (e.g., \n)
-            percentage = start_index/indices[-1]

-                'end': round(w_start + w_duration, 3),
-                'text': word,
-            })

-    transcript = transcript_list.find_manually_created_transcript(
-        ['en-GB', 'en-US', 'en']).fetch()
-    return wordify(transcript)

-def get_auto_words(transcript_list):
-    words = []
-    transcript = transcript_list.find_generated_transcript(['en'])
-    url = transcript._url + '&fmt=json3'
-    info = transcript._http_client.get(url)

-        start_ms = event.get('tStartMs', 0)

-            PROFANITY_RAW, PROFANITY_CONVERTED
-        ).strip().split()

-            })


 def list_transcripts(video_id):


 WORDS_TO_REMOVE = [
@@ -119,60 +165,74 @@ WORDS_TO_REMOVE = [


 @lru_cache(maxsize=16)
-def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True):
     """Get parsed video transcript with caching system
    returns None if not processed yet and process is False
    """
     transcript_path = os.path.join(  # TODO use relative path to this
         'transcripts', transcript_type, f'{video_id}.json')

     try:
-        if os.path.exists(transcript_path):  # Load from file
             with open(transcript_path) as fp:

         elif process:
             transcript_list = list_transcripts(video_id)

-            if

     except (TooManyRequests, YouTubeRequestFailed):
         raise  # Cannot recover from these errors and do not mark as empty transcript

-    except requests.exceptions.
         time.sleep(10)  # Timeout
-        return get_words(video_id, process, transcript_type, fallback)

     except CouldNotRetrieveTranscript:  # Retrying won't solve
         pass  # Mark as empty transcript

     except json.decoder.JSONDecodeError:
         print('JSONDecodeError for', video_id)
-        os.

     # Tried to process it, but it was empty...
-    if process and not os.path.exists(transcript_path):
         with open(transcript_path, 'w') as fp:
-            json.dump(

-    if not
-        return get_words(video_id, process, transcript_type=fallback, fallback=None)

-    if

-    return


 # TODO make min_sponsor_segment_length param
 # TODO rename to extract_segments
 def extract_sponsors(words, min_sponsor_segment_length=3):
-    if not words
         return []

     paragraphs = []
@@ -302,9 +362,13 @@ class PreprocessArguments:

     max_date: str = field(
         # default='01/01/9999', # Include all
-        default='
         metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})

     do_process_database: bool = field(
         default=False, metadata={'help': 'Process the raw database'}
     )
@@ -393,23 +457,6 @@ def download_file(url, filename):
     return total_bytes == os.path.getsize(filename)


-def load_datasets(dataset_args):
-    print('Reading datasets')
-    data_files = {}
-
-    if dataset_args.train_file is not None:
-        data_files['train'] = os.path.join(
-            dataset_args.data_dir, dataset_args.train_file)
-    if dataset_args.validation_file is not None:
-        data_files['validation'] = os.path.join(
-            dataset_args.data_dir, dataset_args.validation_file)
-    if dataset_args.test_file is not None:
-        data_files['test'] = os.path.join(
-            dataset_args.data_dir, dataset_args.test_file)
-
-    return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
-
-
 @dataclass
 class DatasetArguments:
     data_dir: Optional[str] = field(
@@ -503,9 +550,11 @@ def main():
             break
         print('Failed, trying next')

     processed_db_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_database)

     @lru_cache(maxsize=1)
     def read_db():
         if not preprocess_args.overwrite and os.path.exists(processed_db_path):
@@ -520,6 +569,7 @@ def main():

             for line in reader:

                 if line['service'] != 'YouTube':
                     continue
                 if len(line['videoID']) != 11:
@@ -565,9 +615,10 @@ def main():
                 })

         # Remove duplicate sponsor segments by choosing best (most votes)

         # We now remove whole videos from the list
         # Helps with obtaining "fully-labelled" videos
@@ -616,20 +667,44 @@ def main():

         video_ids = list(parsed_database.keys() - finished)


     final_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_file)
@@ -641,9 +716,14 @@ def main():

         parsed_database = read_db()

                 if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
                     break
                 progress.set_description(f'Processing {video_id}')
@@ -654,7 +734,8 @@ def main():
                     continue

                 final_vid_segs = []
                     segment_words = segment.extract_segment(
                         video_words, seg['start'], seg['end'])
@@ -697,8 +778,6 @@ def main():
     # if not os.path.exists(excess_path) or preprocess_args.overwrite
     # TODO use overwrite param

-    os.makedirs(dataset_args.data_dir, exist_ok=True)
-
     positive_file = os.path.join(
         dataset_args.data_dir, dataset_args.positive_file)
     negative_file = os.path.join(
@@ -724,7 +803,7 @@ def main():

     data = list(itertools.islice(data, start_index, end_index))

-    write_mode = 'w' if preprocess_args.overwrite else 'a'
     with open(positive_file, write_mode, encoding='utf-8') as positive, \
             open(negative_file, write_mode, encoding='utf-8') as negative, \
             tqdm(data) as progress:
@@ -734,16 +813,14 @@ def main():
             progress.set_description(f'Processing {video_id}')
             progress.update()


             if not words:
                 continue

-            if num_words <= 1:
                 continue

-            # TODO only count words that aren't [Music], [Applause], etc.
             segments = segment.generate_labelled_segments(
                 words, tokenizer, segmentation_args, sponsor_segments)
@@ -753,13 +830,13 @@ def main():
             for seg in segments:
                 seg_start = segment.word_start(seg[0])
                 seg_end = segment.word_end(seg[-1])


                 d = {
                     'video_index': offset + start_index,

+++ src/preprocess.py (after)

@@ -1,9 +1,8 @@
+from utils import jaccard
 from functools import lru_cache
 from datetime import datetime
 import itertools
 from typing import Optional, List
 from model import ModelArguments
 import segment
 from tqdm import tqdm
@@ -21,94 +20,141 @@ import time
 import requests


+PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
+PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing


+NUM_DECIMALS = 3


+def parse_transcript_json(json_data, granularity):
+    assert json_data['wireMagic'] == 'pb3'
+
+    assert granularity in ('word', 'chunk')
+
+    # TODO remove bracketed words?
+    # (kiss smacks)
+    # (upbeat music)
+    # [text goes here]
+
+    # Some manual transcripts aren't that well formatted... but do have punctuation
+    # https://www.youtube.com/watch?v=LR9FtWVjk2c
+
+    parsed_transcript = []
+
+    events = json_data['events']
+
+    for event_index, event in enumerate(events):
+        segments = event.get('segs')
+        if not segments:
+            continue
+
+        # This value is known (when phrase appears on screen)
+        start_ms = event['tStartMs']
+        total_characters = 0
+
+        new_segments = []
+        for seg in segments:
+            text = seg['utf8'].replace('\n', ' ').replace(
+                PROFANITY_RAW, PROFANITY_CONVERTED,  # Needed for auto-generated transcripts
+            ).strip()
+            if not text:
+                continue
+
+            offset_ms = seg.get('tOffsetMs', 0)
+
+            new_segments.append({
+                'text': text,
+                'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
+            })
+
+            total_characters += len(text)
+
+        if not new_segments:
+            continue
+
+        if event_index < len(events) - 1:
+            next_start_ms = events[event_index + 1]['tStartMs']
+            total_event_duration_ms = min(
+                event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
+        else:
+            total_event_duration_ms = event.get('dDurationMs', 0)
+
+        avg_seconds_per_character = (
+            total_event_duration_ms/total_characters)/1000
+
+        num_char_count = 0
+        for seg_index, seg in enumerate(new_segments):
+            num_char_count += len(seg['text'])
+
+            # Estimate segment end
+            seg_end = seg['start'] + \
+                (num_char_count * avg_seconds_per_character)
+
+            if seg_index < len(new_segments) - 1:
+                # Do not allow longer than next
+                seg_end = min(seg_end, new_segments[seg_index+1]['start'])
+
+            seg['end'] = round(seg_end, NUM_DECIMALS)
+            parsed_transcript.append(seg)
+
+    final_parsed_transcript = []
+    for i in range(len(parsed_transcript)):
+
+        word_level = granularity == 'word'
+        if word_level:
+            split_text = parsed_transcript[i]['text'].split()
+        elif granularity == 'chunk':
+            # Split on space after punctuation
+            split_text = re.split(
+                r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
+            if len(split_text) == 1:
+                split_on_whitespace = parsed_transcript[i]['text'].split()
+
+                if len(split_on_whitespace) >= 8:  # Too many words
+                    # Rather split on whitespace instead of punctuation
+                    split_text = split_on_whitespace
+                else:
+                    word_level = True
+        else:
+            raise ValueError('Unknown granularity')
+
+        segment_end = parsed_transcript[i]['end']
+        if i < len(parsed_transcript) - 1:
+            segment_end = min(segment_end, parsed_transcript[i+1]['start'])
+
+        segment_duration = segment_end - parsed_transcript[i]['start']
+
+        num_chars_in_text = sum(map(len, split_text))
+
+        num_char_count = 0
+        current_offset = 0
+        for s in split_text:
+            num_char_count += len(s)
+
+            next_offset = (num_char_count/num_chars_in_text) * segment_duration
+
+            word_start = round(
+                parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
+            word_end = round(
+                parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
+
+            # Make the reasonable assumption that min wps is 1.5
+            final_parsed_transcript.append({
+                'text': s,
+                'start': word_start,
+                'end': min(word_end, word_start + 1.5) if word_level else word_end
+            })
+            current_offset = next_offset
+
+    return final_parsed_transcript


 def list_transcripts(video_id):
+    try:
+        return YouTubeTranscriptApi.list_transcripts(video_id)
+    except json.decoder.JSONDecodeError:
+        return None


 WORDS_TO_REMOVE = [
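For reference, parse_transcript_json consumes the raw json3 payload that YouTube's timedtext endpoint returns (wireMagic 'pb3', a list of events, each with segs). A minimal sketch of calling it on a hand-built payload — the sample event and the `preprocess` import path are illustrative assumptions, not taken from the repository:

    # Illustrative sketch: a tiny json3-style payload with one event and two segments.
    from preprocess import parse_transcript_json  # assumes src/ is on PYTHONPATH

    sample = {
        'wireMagic': 'pb3',
        'events': [{
            'tStartMs': 1000,          # event appears on screen at 1.0 s
            'dDurationMs': 2000,       # and stays up for 2.0 s
            'segs': [
                {'utf8': 'hello '},
                {'utf8': 'world', 'tOffsetMs': 500},
            ],
        }],
    }

    # 'word' granularity yields one entry per whitespace-separated word, with end
    # times estimated from the event's average seconds-per-character pacing.
    for word in parse_transcript_json(sample, granularity='word'):
        print(word)  # e.g. {'text': 'hello', 'start': 1.0, 'end': 1.5}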
@@ -119,60 +165,74 @@ WORDS_TO_REMOVE = [


 @lru_cache(maxsize=16)
+def get_words(video_id, process=True, transcript_type='auto', fallback='manual', filter_words_to_remove=True, download=False, granularity='word'):
     """Get parsed video transcript with caching system
    returns None if not processed yet and process is False
    """
+    # NOTE: granularity='chunk' should only be used for generating training data... nowhere else
+
     transcript_path = os.path.join(  # TODO use relative path to this
         'transcripts', transcript_type, f'{video_id}.json')

+    raw_transcript_json = None
     try:
+        if not download and os.path.exists(transcript_path):  # Load from file
             with open(transcript_path) as fp:
+                raw_transcript_json = json.load(fp)  # May be empty

         elif process:
             transcript_list = list_transcripts(video_id)

+            if transcript_list is not None:
+                if transcript_type == 'manual':
+                    ts = transcript_list.find_manually_created_transcript(
+                        ['en-GB', 'en-US', 'en'])
+                else:
+                    ts = transcript_list.find_generated_transcript(['en'])
+
+                raw_transcript_json = ts._http_client.get(
+                    f'{ts._url}&fmt=json3').json()

     except (TooManyRequests, YouTubeRequestFailed):
         raise  # Cannot recover from these errors and do not mark as empty transcript

+    except requests.exceptions.RequestException:  # Can recover
         time.sleep(10)  # Timeout
+        return get_words(video_id, process, transcript_type, fallback, granularity)

     except CouldNotRetrieveTranscript:  # Retrying won't solve
         pass  # Mark as empty transcript

     except json.decoder.JSONDecodeError:
         print('JSONDecodeError for', video_id)
+        if os.path.exists(transcript_path):
+            os.remove(transcript_path)  # Remove file and try again
+        return get_words(video_id, process, transcript_type, fallback, granularity)

     # Tried to process it, but it was empty...
+    if download or (process and not os.path.exists(transcript_path)):
         with open(transcript_path, 'w') as fp:
+            json.dump(raw_transcript_json, fp)

+    if not raw_transcript_json and fallback is not None:
+        return get_words(video_id, process, transcript_type=fallback, fallback=None, granularity=granularity)

+    if raw_transcript_json:
+        processed_transcript = parse_transcript_json(
+            raw_transcript_json, granularity)
+        if filter_words_to_remove:
+            processed_transcript = list(
+                filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
+    else:
+        processed_transcript = raw_transcript_json  # Either None or []

+    return processed_transcript


 # TODO make min_sponsor_segment_length param
 # TODO rename to extract_segments
 def extract_sponsors(words, min_sponsor_segment_length=3):
+    if not words:
         return []

     paragraphs = []
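As a usage sketch (the video ID below is only a placeholder), the reworked get_words call now selects the transcript source, the fallback, and the segmentation granularity in one place:

    from preprocess import get_words  # assumes src/ is on PYTHONPATH

    # Defaults: word-level timing, auto-generated transcript first, manual as fallback,
    # with the raw JSON cached under transcripts/auto/<video_id>.json.
    words = get_words('dQw4w9WgXcQ')

    # Chunk granularity keeps punctuation-delimited phrases together; per the NOTE in
    # get_words, this is only intended for generating training data.
    chunks = get_words('dQw4w9WgXcQ', process=False, granularity='chunk')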
@@ -302,9 +362,13 @@ class PreprocessArguments:

     max_date: str = field(
         # default='01/01/9999', # Include all
+        default='02/02/2022',
         metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})

+    keep_duplicate_segments: bool = field(
+        default=False, metadata={'help': 'Keep duplicate segments'}
+    )
+
     do_process_database: bool = field(
         default=False, metadata={'help': 'Process the raw database'}
     )
@@ -393,23 +457,6 @@ def download_file(url, filename):
     return total_bytes == os.path.getsize(filename)


 @dataclass
 class DatasetArguments:
     data_dir: Optional[str] = field(
@@ -503,9 +550,11 @@ def main():
             break
         print('Failed, trying next')

+    os.makedirs(dataset_args.data_dir, exist_ok=True)
     processed_db_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_database)

+    # TODO process all valid possible items and then do filtering only later
     @lru_cache(maxsize=1)
     def read_db():
         if not preprocess_args.overwrite and os.path.exists(processed_db_path):
@@ -520,6 +569,7 @@ def main():

             for line in reader:

+                # Never show:
                 if line['service'] != 'YouTube':
                     continue
                 if len(line['videoID']) != 11:
@@ -565,9 +615,10 @@ def main():
                 })

         # Remove duplicate sponsor segments by choosing best (most votes)
+        if not preprocess_args.keep_duplicate_segments:
+            print('Remove duplicate segments')
+            for key in db:
+                db[key] = remove_duplicate_segments(db[key])

         # We now remove whole videos from the list
         # Helps with obtaining "fully-labelled" videos
@@ -616,20 +667,44 @@ def main():

         video_ids = list(parsed_database.keys() - finished)

+        # https://stackoverflow.com/a/63495323
+        import concurrent
+        POLL_INTERVAL = 0.1
+
+        # Wrap get words function to return video_id after completion
+        def get_words_wrapper(video_id):
+            get_words(video_id)
+            return video_id
+
+        print('Setting up ThreadPoolExecutor')
+        with concurrent.futures.ThreadPoolExecutor(max_workers=preprocess_args.num_jobs) as pool, \
+                tqdm(total=len(video_ids)) as progress:
+
+            all_futures = (pool.submit(get_words_wrapper, video_id)
+                           for video_id in video_ids)
+            to_process = set(itertools.islice(
+                all_futures, preprocess_args.num_jobs))
+            try:
+                while to_process:
+                    just_finished, to_process = concurrent.futures.wait(
+                        to_process, timeout=POLL_INTERVAL)
+                    to_process |= set(itertools.islice(
+                        all_futures, len(just_finished)))
+
+                    for d in just_finished:
+                        progress.set_description(f'Processed {d.result()}')
+                        progress.update()
+
+            except KeyboardInterrupt:
+                print('Gracefully shutting down: Cancelling unscheduled tasks')
+
+                # only futures that are not done will prevent exiting
+                for future in to_process:
+                    future.cancel()
+
+                print('Waiting for in-progress tasks to complete')
+                concurrent.futures.wait(to_process, timeout=None)
+                print('Cancellation successful')

     final_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_file)
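The transcription pass above uses a bounded-submission pattern (from the linked Stack Overflow answer): only num_jobs futures are in flight at a time, and each completed future is replaced by the next item pulled from the islice'd generator. A self-contained toy version of the same pattern, with a dummy task standing in for get_words_wrapper:

    import concurrent.futures
    import itertools
    import time

    def dummy_task(n):  # stands in for get_words_wrapper
        time.sleep(0.01)
        return n

    items = range(20)
    max_workers = 4
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        all_futures = (pool.submit(dummy_task, n) for n in items)
        to_process = set(itertools.islice(all_futures, max_workers))
        while to_process:
            done, to_process = concurrent.futures.wait(to_process, timeout=0.1)
            # Top the in-flight set back up with one new future per completed one
            to_process |= set(itertools.islice(all_futures, len(done)))
            for fut in done:
                print('finished', fut.result())

Because futures are created lazily rather than all submitted up front, memory stays bounded even for very long video lists.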
@@ -641,9 +716,14 @@ def main():

         parsed_database = read_db()

+        transcribed = set(x.split('.')[0] for x in os.listdir(
+            'transcripts/auto/') + os.listdir('transcripts/manual/'))
+
+        # Only consider videos that have been transcribed already
+        video_ids = parsed_database.keys() & transcribed
+
+        with tqdm(total=len(video_ids)) as progress:
+            for index, video_id in enumerate(video_ids):
                 if preprocess_args.max_videos is not None and index >= preprocess_args.max_videos:
                     break
                 progress.set_description(f'Processing {video_id}')
@@ -654,7 +734,8 @@ def main():
                     continue

                 final_vid_segs = []
+                # Only add segments with high enough wps
+                for seg in parsed_database[video_id]:
                     segment_words = segment.extract_segment(
                         video_words, seg['start'], seg['end'])
|
|
778 |
# if not os.path.exists(excess_path) or preprocess_args.overwrite
|
779 |
# TODO use overwrite param
|
780 |
|
|
|
|
|
781 |
positive_file = os.path.join(
|
782 |
dataset_args.data_dir, dataset_args.positive_file)
|
783 |
negative_file = os.path.join(
|
|
|
@@ -724,7 +803,7 @@ def main():

     data = list(itertools.islice(data, start_index, end_index))

+    write_mode = 'w'  # if preprocess_args.overwrite else 'a'
     with open(positive_file, write_mode, encoding='utf-8') as positive, \
             open(negative_file, write_mode, encoding='utf-8') as negative, \
             tqdm(data) as progress:
@@ -734,16 +813,14 @@ def main():
             progress.set_description(f'Processing {video_id}')
             progress.update()

+            # Use chunk granularity to improve manual transcripts
+            words = get_words(video_id, process=False, granularity='chunk')
             if not words:
                 continue

+            if len(words) <= 1:
                 continue

             segments = segment.generate_labelled_segments(
                 words, tokenizer, segmentation_args, sponsor_segments)
|
|
830 |
for seg in segments:
|
831 |
seg_start = segment.word_start(seg[0])
|
832 |
seg_end = segment.word_end(seg[-1])
|
833 |
+
duration = seg_end - seg_start
|
834 |
+
wps = len(seg)/duration if duration > 0 else 0
|
835 |
|
836 |
+
# Ignore segments with "not enough words" in the transcript
|
837 |
+
# Must do here since this includes non-sponsor segments
|
838 |
+
if wps < preprocess_args.min_wps:
|
839 |
+
continue
|
840 |
|
841 |
d = {
|
842 |
'video_index': offset + start_index,
|
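The new wps (words per second) check filters out segments whose transcript coverage is too sparse to be useful. With made-up numbers: a segment spanning 100.0–120.0 s whose extracted transcript contains only 10 words gives wps = 0.5, so it is skipped whenever preprocess_args.min_wps is above that:

    # Hypothetical numbers, just to show the arithmetic behind the filter above.
    seg_start, seg_end, num_words = 100.0, 120.0, 10
    duration = seg_end - seg_start             # 20.0 s
    wps = num_words / duration if duration > 0 else 0
    print(wps)                                 # 0.5 -> dropped if min_wps > 0.5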