Joshua Lochner committed on
Commit ee58f38
1 Parent(s): 34790a9

Update pipeline.py

Files changed (1)
  1. pipeline.py +6 -325
pipeline.py CHANGED
@@ -1,333 +1,14 @@
-import youtube_transcript_api2
-import json
-import re
-import requests
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    TextClassificationPipeline,
-)
 from typing import Any, Dict, List
 
-CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
-
-PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
-PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
-
-NUM_DECIMALS = 3
-
-# https://www.fincher.org/Utilities/CountryLanguageList.shtml
-# https://lingohub.com/developers/supported-locales/language-designators-with-regions
-LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
-                            'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
-                            'en']
-
-
-def parse_transcript_json(json_data, granularity):
-    assert json_data['wireMagic'] == 'pb3'
-
-    assert granularity in ('word', 'chunk')
-
-    # TODO remove bracketed words?
-    # (kiss smacks)
-    # (upbeat music)
-    # [text goes here]
-
-    # Some manual transcripts aren't that well formatted... but do have punctuation
-    # https://www.youtube.com/watch?v=LR9FtWVjk2c
-
-    parsed_transcript = []
-
-    events = json_data['events']
-
-    for event_index, event in enumerate(events):
-        segments = event.get('segs')
-        if not segments:
-            continue
-
-        # This value is known (when phrase appears on screen)
-        start_ms = event['tStartMs']
-        total_characters = 0
-
-        new_segments = []
-        for seg in segments:
-            # Replace \n, \t, etc. with space
-            text = ' '.join(seg['utf8'].split())
-
-            # Remove zero-width spaces and strip trailing and leading whitespace
-            text = text.replace('\u200b', '').replace('\u200c', '').replace(
-                '\u200d', '').replace('\ufeff', '').strip()
-
-            # Alternatively,
-            # text = text.encode('ascii', 'ignore').decode()
-
-            # Needed for auto-generated transcripts
-            text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
-
-            if not text:
-                continue
-
-            offset_ms = seg.get('tOffsetMs', 0)
-
-            new_segments.append({
-                'text': text,
-                'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
-            })
-
-            total_characters += len(text)
-
-        if not new_segments:
-            continue
-
-        if event_index < len(events) - 1:
-            next_start_ms = events[event_index + 1]['tStartMs']
-            total_event_duration_ms = min(
-                event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
-        else:
-            total_event_duration_ms = event.get('dDurationMs', 0)
-
-        # Ensure duration is non-negative
-        total_event_duration_ms = max(total_event_duration_ms, 0)
-
-        avg_seconds_per_character = (
-            total_event_duration_ms/total_characters)/1000
-
-        num_char_count = 0
-        for seg_index, seg in enumerate(new_segments):
-            num_char_count += len(seg['text'])
-
-            # Estimate segment end
-            seg_end = seg['start'] + \
-                (num_char_count * avg_seconds_per_character)
-
-            if seg_index < len(new_segments) - 1:
-                # Do not allow longer than next
-                seg_end = min(seg_end, new_segments[seg_index+1]['start'])
-
-            seg['end'] = round(seg_end, NUM_DECIMALS)
-            parsed_transcript.append(seg)
-
-    final_parsed_transcript = []
-    for i in range(len(parsed_transcript)):
-
-        word_level = granularity == 'word'
-        if word_level:
-            split_text = parsed_transcript[i]['text'].split()
-        elif granularity == 'chunk':
-            # Split on space after punctuation
-            split_text = re.split(
-                r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
-            if len(split_text) == 1:
-                split_on_whitespace = parsed_transcript[i]['text'].split()
-
-                if len(split_on_whitespace) >= 8:  # Too many words
-                    # Rather split on whitespace instead of punctuation
-                    split_text = split_on_whitespace
-                else:
-                    word_level = True
-        else:
-            raise ValueError('Unknown granularity')
-
-        segment_end = parsed_transcript[i]['end']
-        if i < len(parsed_transcript) - 1:
-            segment_end = min(segment_end, parsed_transcript[i+1]['start'])
-
-        segment_duration = segment_end - parsed_transcript[i]['start']
-
-        num_chars_in_text = sum(map(len, split_text))
-
-        num_char_count = 0
-        current_offset = 0
-        for s in split_text:
-            num_char_count += len(s)
-
-            next_offset = (num_char_count/num_chars_in_text) * segment_duration
-
-            word_start = round(
-                parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
-            word_end = round(
-                parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
-
-            # Make the reasonable assumption that min wps is 1.5
-            final_parsed_transcript.append({
-                'text': s,
-                'start': word_start,
-                'end': min(word_end, word_start + 1.5) if word_level else word_end
-            })
-            current_offset = next_offset
-
-    return final_parsed_transcript
-
-
-def list_transcripts(video_id):
-    try:
-        return youtube_transcript_api2.YouTubeTranscriptApi.list_transcripts(video_id)
-    except json.decoder.JSONDecodeError:
-        return None
-
-
-WORDS_TO_REMOVE = [
-    '[Music]'
-    '[Applause]'
-    '[Laughter]'
-]
-
-
-def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
-    """Get parsed video transcript with caching system
-    returns None if not processed yet and process is False
-    """
-
-    raw_transcript_json = None
-    try:
-        transcript_list = list_transcripts(video_id)
-
-        if transcript_list is not None:
-            if transcript_type == 'manual':
-                ts = transcript_list.find_manually_created_transcript(
-                    LANGUAGE_PREFERENCE_LIST)
-            else:
-                ts = transcript_list.find_generated_transcript(
-                    LANGUAGE_PREFERENCE_LIST)
-            raw_transcript = ts._http_client.get(
-                f'{ts._url}&fmt=json3').content
-            if raw_transcript:
-                raw_transcript_json = json.loads(raw_transcript)
-
-    except (youtube_transcript_api2.TooManyRequests, youtube_transcript_api2.YouTubeRequestFailed):
-        raise  # Cannot recover from these errors and do not mark as empty transcript
-
-    except requests.exceptions.RequestException:  # Can recover
-        return get_words(video_id, transcript_type, fallback, granularity)
-
-    except youtube_transcript_api2.CouldNotRetrieveTranscript:  # Retrying won't solve
-        pass  # Mark as empty transcript
-
-    except json.decoder.JSONDecodeError:
-        return get_words(video_id, transcript_type, fallback, granularity)
-
-    if not raw_transcript_json and fallback is not None:
-        return get_words(video_id, transcript_type=fallback, fallback=None, granularity=granularity)
-
-    if raw_transcript_json:
-        processed_transcript = parse_transcript_json(
-            raw_transcript_json, granularity)
-        if filter_words_to_remove:
-            processed_transcript = list(
-                filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
-    else:
-        processed_transcript = raw_transcript_json  # Either None or []
-
-    return processed_transcript
-
-
-def word_start(word):
-    return word['start']
-
-
-def word_end(word):
-    return word.get('end', word['start'])
-
-
-def extract_segment(words, start, end, map_function=None):
-    """Extracts all words with time in [start, end]"""
-
-    a = max(binary_search_below(words, 0, len(words), start), 0)
-    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
-
-    to_transform = map_function is not None and callable(map_function)
-
-    return [
-        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
-    ]
-
-
-def avg(*items):
-    return sum(items)/len(items)
-
-
-def binary_search_below(transcript, start_index, end_index, time):
-    if start_index >= end_index:
-        return end_index
-
-    middle_index = (start_index + end_index) // 2
-    middle = transcript[middle_index]
-    middle_time = avg(word_start(middle), word_end(middle))
-
-    if time <= middle_time:
-        return binary_search_below(transcript, start_index, middle_index, time)
-    else:
-        return binary_search_below(transcript, middle_index + 1, end_index, time)
-
-
-def binary_search_above(transcript, start_index, end_index, time):
-    if start_index >= end_index:
-        return end_index
-
-    middle_index = (start_index + end_index + 1) // 2
-    middle = transcript[middle_index]
-    middle_time = avg(word_start(middle), word_end(middle))
-
-    if time >= middle_time:
-        return binary_search_above(transcript, middle_index, end_index, time)
-    else:
-        return binary_search_above(transcript, start_index, middle_index - 1, time)
-
-
-class SponsorBlockClassificationPipeline(TextClassificationPipeline):
-    def __init__(self, model, tokenizer):
-        super().__init__(model=model, tokenizer=tokenizer, return_all_scores=True)
-
-    def preprocess(self, video, **tokenizer_kwargs):
-
-        words = get_words(video['video_id'])
-        segment_words = extract_segment(words, video['start'], video['end'])
-        text = ' '.join(x['text'] for x in segment_words)
-
-        model_inputs = self.tokenizer(
-            text, return_tensors=self.framework, **tokenizer_kwargs)
-        return {'video': video, 'model_inputs': model_inputs}
-
-    def _forward(self, data):
-        model_outputs = self.model(**data['model_inputs'])
-        return {'video': data['video'], 'model_outputs': model_outputs}
-
-    def postprocess(self, data, function_to_apply=None, return_all_scores=False):
-        model_outputs = data['model_outputs']
-
-        results = super().postprocess(model_outputs, function_to_apply, return_all_scores)
-
-        for result in results:
-            result['label_text'] = CATEGORIES[result['label']]
-
-        return results  # {**data['video'], 'result': results}
-
-
-# model_id = "Xenova/sponsorblock-classifier-v2"
-# model = AutoModelForSequenceClassification.from_pretrained(model_id)
-# tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-# pl = SponsorBlockClassificationPipeline(model=model, tokenizer=tokenizer)
-data = [{
-    'video_id': 'pqh4LfPeCYs',
-    'start': 835.933,
-    'end': 927.581,
-    'category': 'sponsor'
-}]
-# print(pl(data))
-
-# MODEL_ID = "Xenova/sponsorblock-classifier-v2"
 class PreTrainedPipeline():
     def __init__(self, path: str):
         # load the model
-        self.model = AutoModelForSequenceClassification.from_pretrained(path)
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.pipeline = SponsorBlockClassificationPipeline(
-            model=self.model, tokenizer=self.tokenizer)
+        self.model = None
 
     def __call__(self, inputs: str) -> List[Dict[str, Any]]:
-        json_data = json.loads(inputs)
-        return self.pipeline(json_data)
 
-# a = PreTrainedPipeline('Xenova/sponsorblock-classifier-v2')(json.dumps(data))
-# print(a)
+        return [{
+            "label": 0,
+            "mask": 1,
+            "score": 1.0,
+        }]
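For reference, a minimal usage sketch of the stubbed pipeline this commit leaves in place. The sample segment is the `data` entry from the removed test block; importing the class as `from pipeline import PreTrainedPipeline` and passing the model repo id as the path are assumptions for illustration, since the stub ignores both the path and the inputs.

import json

from pipeline import PreTrainedPipeline  # assumes pipeline.py is on the import path

# Sample segment taken from the test data removed by this commit.
data = [{
    'video_id': 'pqh4LfPeCYs',
    'start': 835.933,
    'end': 927.581,
    'category': 'sponsor'
}]

# The stub ignores the model path and the inputs and always returns a fixed prediction.
pipe = PreTrainedPipeline('Xenova/sponsorblock-classifier-v2')
print(pipe(json.dumps(data)))  # -> [{'label': 0, 'mask': 1, 'score': 1.0}]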