EColi commited on
Commit
dc4dce6
1 Parent(s): 1d7009e

Fix structure

Browse files
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - text-classification
4
+ - generic
5
+ library_name: generic
6
+ widget:
7
+ - text: 'This video is sponsored by squarespace'
8
+ example_title: Sponsor
9
+ - text: 'Check out the merch at linustechtips.com'
10
+ example_title: Unpaid/self promotion
11
+ - text: "Don't forget to like, comment and subscribe"
12
+ example_title: Interaction reminder
13
+ - text: 'pqh4LfPeCYs,824.695,826.267,826.133,829.876,835.933,927.581'
14
+ example_title: Extract text from video
15
+ ---
checkpoint-325000/added_tokens.json → added_tokens.json RENAMED
File without changes
checkpoint-325000/optimizer.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:96765e5aa06e0e6bb3828a8da9c276e30fefada85f8a18852f84b00ff074a1ff
3
- size 876116189
 
 
 
checkpoint-325000/config.json → config.json RENAMED
File without changes
pipeline.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import lru_cache
3
+ from youtube_transcript_api import (
4
+ YouTubeTranscriptApi,
5
+ TooManyRequests,
6
+ YouTubeRequestFailed,
7
+ CouldNotRetrieveTranscript
8
+ )
9
+ import json
10
+ import re
11
+ import requests
12
+ from transformers import (
13
+ AutoModelForSequenceClassification,
14
+ AutoTokenizer,
15
+ TextClassificationPipeline,
16
+ )
17
+ from typing import Any, Dict, List
18
+ import os
19
+ import numpy as np
20
+
21
+ CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
22
+
23
+ PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
24
+ PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
25
+
26
+ NUM_DECIMALS = 3
27
+
28
+ # https://www.fincher.org/Utilities/CountryLanguageList.shtml
29
+ # https://lingohub.com/developers/supported-locales/language-designators-with-regions
30
+ LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
31
+ 'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
32
+ 'en']
33
+
34
+
35
+ def parse_transcript_json(json_data, granularity):
36
+ assert json_data['wireMagic'] == 'pb3'
37
+
38
+ assert granularity in ('word', 'chunk')
39
+
40
+ # TODO remove bracketed words?
41
+ # (kiss smacks)
42
+ # (upbeat music)
43
+ # [text goes here]
44
+
45
+ # Some manual transcripts aren't that well formatted... but do have punctuation
46
+ # https://www.youtube.com/watch?v=LR9FtWVjk2c
47
+
48
+ parsed_transcript = []
49
+
50
+ events = json_data['events']
51
+
52
+ for event_index, event in enumerate(events):
53
+ segments = event.get('segs')
54
+ if not segments:
55
+ continue
56
+
57
+ # This value is known (when phrase appears on screen)
58
+ start_ms = event['tStartMs']
59
+ total_characters = 0
60
+
61
+ new_segments = []
62
+ for seg in segments:
63
+ # Replace \n, \t, etc. with space
64
+ text = ' '.join(seg['utf8'].split())
65
+
66
+ # Remove zero-width spaces and strip trailing and leading whitespace
67
+ text = text.replace('\u200b', '').replace('\u200c', '').replace(
68
+ '\u200d', '').replace('\ufeff', '').strip()
69
+
70
+ # Alternatively,
71
+ # text = text.encode('ascii', 'ignore').decode()
72
+
73
+ # Needed for auto-generated transcripts
74
+ text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
75
+
76
+ if not text:
77
+ continue
78
+
79
+ offset_ms = seg.get('tOffsetMs', 0)
80
+
81
+ new_segments.append({
82
+ 'text': text,
83
+ 'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
84
+ })
85
+
86
+ total_characters += len(text)
87
+
88
+ if not new_segments:
89
+ continue
90
+
91
+ if event_index < len(events) - 1:
92
+ next_start_ms = events[event_index + 1]['tStartMs']
93
+ total_event_duration_ms = min(
94
+ event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
95
+ else:
96
+ total_event_duration_ms = event.get('dDurationMs', 0)
97
+
98
+ # Ensure duration is non-negative
99
+ total_event_duration_ms = max(total_event_duration_ms, 0)
100
+
101
+ avg_seconds_per_character = (
102
+ total_event_duration_ms/total_characters)/1000
103
+
104
+ num_char_count = 0
105
+ for seg_index, seg in enumerate(new_segments):
106
+ num_char_count += len(seg['text'])
107
+
108
+ # Estimate segment end
109
+ seg_end = seg['start'] + \
110
+ (num_char_count * avg_seconds_per_character)
111
+
112
+ if seg_index < len(new_segments) - 1:
113
+ # Do not allow longer than next
114
+ seg_end = min(seg_end, new_segments[seg_index+1]['start'])
115
+
116
+ seg['end'] = round(seg_end, NUM_DECIMALS)
117
+ parsed_transcript.append(seg)
118
+
119
+ final_parsed_transcript = []
120
+ for i in range(len(parsed_transcript)):
121
+
122
+ word_level = granularity == 'word'
123
+ if word_level:
124
+ split_text = parsed_transcript[i]['text'].split()
125
+ elif granularity == 'chunk':
126
+ # Split on space after punctuation
127
+ split_text = re.split(
128
+ r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
129
+ if len(split_text) == 1:
130
+ split_on_whitespace = parsed_transcript[i]['text'].split()
131
+
132
+ if len(split_on_whitespace) >= 8: # Too many words
133
+ # Rather split on whitespace instead of punctuation
134
+ split_text = split_on_whitespace
135
+ else:
136
+ word_level = True
137
+ else:
138
+ raise ValueError('Unknown granularity')
139
+
140
+ segment_end = parsed_transcript[i]['end']
141
+ if i < len(parsed_transcript) - 1:
142
+ segment_end = min(segment_end, parsed_transcript[i+1]['start'])
143
+
144
+ segment_duration = segment_end - parsed_transcript[i]['start']
145
+
146
+ num_chars_in_text = sum(map(len, split_text))
147
+
148
+ num_char_count = 0
149
+ current_offset = 0
150
+ for s in split_text:
151
+ num_char_count += len(s)
152
+
153
+ next_offset = (num_char_count/num_chars_in_text) * segment_duration
154
+
155
+ word_start = round(
156
+ parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
157
+ word_end = round(
158
+ parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
159
+
160
+ # Make the reasonable assumption that min wps is 1.5
161
+ final_parsed_transcript.append({
162
+ 'text': s,
163
+ 'start': word_start,
164
+ 'end': min(word_end, word_start + 1.5) if word_level else word_end
165
+ })
166
+ current_offset = next_offset
167
+
168
+ return final_parsed_transcript
169
+
170
+
171
+ def list_transcripts(video_id):
172
+ try:
173
+ return YouTubeTranscriptApi.list_transcripts(video_id)
174
+ except json.decoder.JSONDecodeError:
175
+ return None
176
+
177
+
178
+ WORDS_TO_REMOVE = [
179
+ '[Music]'
180
+ '[Applause]'
181
+ '[Laughter]'
182
+ ]
183
+
184
+
185
+ @lru_cache(maxsize=16)
186
+ def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
187
+ """Get parsed video transcript with caching system
188
+ returns None if not processed yet and process is False
189
+ """
190
+
191
+ raw_transcript_json = None
192
+ try:
193
+ transcript_list = list_transcripts(video_id)
194
+
195
+ if transcript_list is not None:
196
+ if transcript_type == 'manual':
197
+ ts = transcript_list.find_manually_created_transcript(
198
+ LANGUAGE_PREFERENCE_LIST)
199
+ else:
200
+ ts = transcript_list.find_generated_transcript(
201
+ LANGUAGE_PREFERENCE_LIST)
202
+ raw_transcript = ts._http_client.get(
203
+ f'{ts._url}&fmt=json3').content
204
+ if raw_transcript:
205
+ raw_transcript_json = json.loads(raw_transcript)
206
+ except (TooManyRequests, YouTubeRequestFailed):
207
+ raise # Cannot recover from these errors and do not mark as empty transcript
208
+
209
+ except requests.exceptions.RequestException: # Can recover
210
+ return get_words(video_id, transcript_type, fallback, granularity)
211
+
212
+ except CouldNotRetrieveTranscript: # Retrying won't solve
213
+ pass # Mark as empty transcript
214
+
215
+ except json.decoder.JSONDecodeError:
216
+ return get_words(video_id, transcript_type, fallback, granularity)
217
+
218
+ if not raw_transcript_json and fallback is not None:
219
+ return get_words(video_id, transcript_type=fallback, fallback=None, granularity=granularity)
220
+
221
+ if raw_transcript_json:
222
+ processed_transcript = parse_transcript_json(
223
+ raw_transcript_json, granularity)
224
+ if filter_words_to_remove:
225
+ processed_transcript = list(
226
+ filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
227
+ else:
228
+ processed_transcript = raw_transcript_json # Either None or []
229
+
230
+ return processed_transcript
231
+
232
+
233
+ def word_start(word):
234
+ return word['start']
235
+
236
+
237
+ def word_end(word):
238
+ return word.get('end', word['start'])
239
+
240
+
241
+ def extract_segment(words, start, end, map_function=None):
242
+ """Extracts all words with time in [start, end]"""
243
+
244
+ a = max(binary_search_below(words, 0, len(words), start), 0)
245
+ b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
246
+
247
+ to_transform = map_function is not None and callable(map_function)
248
+
249
+ return [
250
+ map_function(words[i]) if to_transform else words[i] for i in range(a, b)
251
+ ]
252
+
253
+
254
+ def avg(*items):
255
+ return sum(items)/len(items)
256
+
257
+
258
+ def binary_search_below(transcript, start_index, end_index, time):
259
+ if start_index >= end_index:
260
+ return end_index
261
+
262
+ middle_index = (start_index + end_index) // 2
263
+ middle = transcript[middle_index]
264
+ middle_time = avg(word_start(middle), word_end(middle))
265
+
266
+ if time <= middle_time:
267
+ return binary_search_below(transcript, start_index, middle_index, time)
268
+ else:
269
+ return binary_search_below(transcript, middle_index + 1, end_index, time)
270
+
271
+
272
+ def binary_search_above(transcript, start_index, end_index, time):
273
+ if start_index >= end_index:
274
+ return end_index
275
+
276
+ middle_index = (start_index + end_index + 1) // 2
277
+ middle = transcript[middle_index]
278
+ middle_time = avg(word_start(middle), word_end(middle))
279
+
280
+ if time >= middle_time:
281
+ return binary_search_above(transcript, middle_index, end_index, time)
282
+ else:
283
+ return binary_search_above(transcript, start_index, middle_index - 1, time)
284
+
285
+
286
+ class PreTrainedPipeline():
287
+ def __init__(self, path: str):
288
+ self.model2 = AutoModelForSequenceClassification.from_pretrained(path)
289
+ self.tokenizer2 = AutoTokenizer.from_pretrained(path)
290
+ self.pipeline2 = SponsorBlockClassificationPipeline(
291
+ model=self.model2, tokenizer=self.tokenizer2)
292
+
293
+ def __call__(self, inputs: str) -> List[Dict[str, Any]]:
294
+
295
+ # Automated call (compressed string)
296
+ if ' ' not in inputs and inputs.count(',') >= 2:
297
+ split_info = inputs.split(',', 1)
298
+ times = np.reshape(np.array(split_info[1].split(',')), (-1, 2))
299
+ data = []
300
+ for start, end in times:
301
+ data.append({
302
+ 'video_id': split_info[0],
303
+ 'start': float(start),
304
+ 'end': float(end)
305
+ })
306
+ else:
307
+ data = inputs
308
+
309
+ return self.pipeline2(data)
310
+
311
+
312
+ class SponsorBlockClassificationPipeline(TextClassificationPipeline):
313
+ def __init__(self, model, tokenizer):
314
+ super().__init__(model=model, tokenizer=tokenizer, return_all_scores=True)
315
+
316
+ def preprocess(self, data, **tokenizer_kwargs):
317
+ if isinstance(data, str): # If string, assume this is what user wants to classify
318
+ text = data
319
+ else: # Otherwise, get data from transcript
320
+ words = get_words(data['video_id'])
321
+ segment_words = extract_segment(words, data['start'], data['end'])
322
+ text = ' '.join(x['text'] for x in segment_words)
323
+
324
+ return self.tokenizer(
325
+ text, return_tensors=self.framework, **tokenizer_kwargs)
checkpoint-325000/pytorch_model.bin → pytorch_model.bin RENAMED
File without changes
requirements.txt ADDED
@@ -0,0 +1 @@
 
1
+ youtube_transcript_api
checkpoint-325000/rng_state.pth → rng_state.pth RENAMED
File without changes
checkpoint-325000/scheduler.pt → scheduler.pt RENAMED
File without changes
checkpoint-325000/special_tokens_map.json → special_tokens_map.json RENAMED
File without changes
checkpoint-325000/tokenizer.json → tokenizer.json RENAMED
File without changes
checkpoint-325000/tokenizer_config.json → tokenizer_config.json RENAMED
File without changes
checkpoint-325000/trainer_state.json → trainer_state.json RENAMED
File without changes
checkpoint-325000/training_args.bin → training_args.bin RENAMED
File without changes
checkpoint-325000/vocab.txt → vocab.txt RENAMED
File without changes