Joshua Lochner committed
Commit aacb405
Parent: f7b6109

Update pipeline.py

Files changed (1): pipeline.py (+2, -316)
pipeline.py CHANGED
@@ -1,13 +1,3 @@
-import youtube_transcript_api2
-import json
-import re
-import requests
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    TextClassificationPipeline,
-)
-from typing import Any, Dict, List
 import json
 from typing import Any, Dict, List
 
@@ -18,319 +8,16 @@ import io
 import os
 import numpy as np
 from PIL import Image
-CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
-
-PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
-PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
-
-NUM_DECIMALS = 3
-
-# https://www.fincher.org/Utilities/CountryLanguageList.shtml
-# https://lingohub.com/developers/supported-locales/language-designators-with-regions
-LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
-                            'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
-                            'en']
-
-
-def parse_transcript_json(json_data, granularity):
-    assert json_data['wireMagic'] == 'pb3'
-
-    assert granularity in ('word', 'chunk')
-
-    # TODO remove bracketed words?
-    # (kiss smacks)
-    # (upbeat music)
-    # [text goes here]
-
-    # Some manual transcripts aren't that well formatted... but do have punctuation
-    # https://www.youtube.com/watch?v=LR9FtWVjk2c
-
-    parsed_transcript = []
-
-    events = json_data['events']
-
-    for event_index, event in enumerate(events):
-        segments = event.get('segs')
-        if not segments:
-            continue
-
-        # This value is known (when phrase appears on screen)
-        start_ms = event['tStartMs']
-        total_characters = 0
-
-        new_segments = []
-        for seg in segments:
-            # Replace \n, \t, etc. with space
-            text = ' '.join(seg['utf8'].split())
-
-            # Remove zero-width spaces and strip trailing and leading whitespace
-            text = text.replace('\u200b', '').replace('\u200c', '').replace(
-                '\u200d', '').replace('\ufeff', '').strip()
-
-            # Alternatively,
-            # text = text.encode('ascii', 'ignore').decode()
-
-            # Needed for auto-generated transcripts
-            text = text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)
-
-            if not text:
-                continue
-
-            offset_ms = seg.get('tOffsetMs', 0)
-
-            new_segments.append({
-                'text': text,
-                'start': round((start_ms + offset_ms)/1000, NUM_DECIMALS)
-            })
-
-            total_characters += len(text)
-
-        if not new_segments:
-            continue
-
-        if event_index < len(events) - 1:
-            next_start_ms = events[event_index + 1]['tStartMs']
-            total_event_duration_ms = min(
-                event.get('dDurationMs', float('inf')), next_start_ms - start_ms)
-        else:
-            total_event_duration_ms = event.get('dDurationMs', 0)
-
-        # Ensure duration is non-negative
-        total_event_duration_ms = max(total_event_duration_ms, 0)
-
-        avg_seconds_per_character = (
-            total_event_duration_ms/total_characters)/1000
-
-        num_char_count = 0
-        for seg_index, seg in enumerate(new_segments):
-            num_char_count += len(seg['text'])
-
-            # Estimate segment end
-            seg_end = seg['start'] + \
-                (num_char_count * avg_seconds_per_character)
-
-            if seg_index < len(new_segments) - 1:
-                # Do not allow longer than next
-                seg_end = min(seg_end, new_segments[seg_index+1]['start'])
-
-            seg['end'] = round(seg_end, NUM_DECIMALS)
-            parsed_transcript.append(seg)
-
-    final_parsed_transcript = []
-    for i in range(len(parsed_transcript)):
-
-        word_level = granularity == 'word'
-        if word_level:
-            split_text = parsed_transcript[i]['text'].split()
-        elif granularity == 'chunk':
-            # Split on space after punctuation
-            split_text = re.split(
-                r'(?<=[.!?,-;])\s+', parsed_transcript[i]['text'])
-            if len(split_text) == 1:
-                split_on_whitespace = parsed_transcript[i]['text'].split()
-
-                if len(split_on_whitespace) >= 8:  # Too many words
-                    # Rather split on whitespace instead of punctuation
-                    split_text = split_on_whitespace
-                else:
-                    word_level = True
-        else:
-            raise ValueError('Unknown granularity')
-
-        segment_end = parsed_transcript[i]['end']
-        if i < len(parsed_transcript) - 1:
-            segment_end = min(segment_end, parsed_transcript[i+1]['start'])
-
-        segment_duration = segment_end - parsed_transcript[i]['start']
-
-        num_chars_in_text = sum(map(len, split_text))
-
-        num_char_count = 0
-        current_offset = 0
-        for s in split_text:
-            num_char_count += len(s)
-
-            next_offset = (num_char_count/num_chars_in_text) * segment_duration
-
-            word_start = round(
-                parsed_transcript[i]['start'] + current_offset, NUM_DECIMALS)
-            word_end = round(
-                parsed_transcript[i]['start'] + next_offset, NUM_DECIMALS)
-
-            # Make the reasonable assumption that min wps is 1.5
-            final_parsed_transcript.append({
-                'text': s,
-                'start': word_start,
-                'end': min(word_end, word_start + 1.5) if word_level else word_end
-            })
-            current_offset = next_offset
-
-    return final_parsed_transcript
-
 
-def list_transcripts(video_id):
-    try:
-        return youtube_transcript_api2.YouTubeTranscriptApi.list_transcripts(video_id)
-    except json.decoder.JSONDecodeError:
-        return None
 
 
-WORDS_TO_REMOVE = [
-    '[Music]',
-    '[Applause]',
-    '[Laughter]',
-]
-
-
-def get_words(video_id, transcript_type='auto', fallback='manual', filter_words_to_remove=True, granularity='word'):
-    """Get the parsed transcript for a video, falling back to the other
-    transcript type (manual/auto) if the preferred one is unavailable.
-    """
-
-    raw_transcript_json = None
-    try:
-        transcript_list = list_transcripts(video_id)
-
-        if transcript_list is not None:
-            if transcript_type == 'manual':
-                ts = transcript_list.find_manually_created_transcript(
-                    LANGUAGE_PREFERENCE_LIST)
-            else:
-                ts = transcript_list.find_generated_transcript(
-                    LANGUAGE_PREFERENCE_LIST)
-            raw_transcript = ts._http_client.get(
-                f'{ts._url}&fmt=json3').content
-            if raw_transcript:
-                raw_transcript_json = json.loads(raw_transcript)
-
-    except (youtube_transcript_api2.TooManyRequests, youtube_transcript_api2.YouTubeRequestFailed):
-        raise  # Cannot recover from these errors and do not mark as empty transcript
-
-    except requests.exceptions.RequestException:  # Can recover
-        return get_words(video_id, transcript_type, fallback, granularity=granularity)
-
-    except youtube_transcript_api2.CouldNotRetrieveTranscript:  # Retrying won't solve
-        pass  # Mark as empty transcript
-
-    except json.decoder.JSONDecodeError:
-        return get_words(video_id, transcript_type, fallback, granularity=granularity)
-
-    if not raw_transcript_json and fallback is not None:
-        return get_words(video_id, transcript_type=fallback, fallback=None, granularity=granularity)
-
-    if raw_transcript_json:
-        processed_transcript = parse_transcript_json(
-            raw_transcript_json, granularity)
-        if filter_words_to_remove:
-            processed_transcript = list(
-                filter(lambda x: x['text'] not in WORDS_TO_REMOVE, processed_transcript))
-    else:
-        processed_transcript = raw_transcript_json  # Either None or []
-
-    return processed_transcript
-
-
-def word_start(word):
-    return word['start']
-
-
-def word_end(word):
-    return word.get('end', word['start'])
-
-
-def extract_segment(words, start, end, map_function=None):
-    """Extracts all words with time in [start, end]"""
-
-    a = max(binary_search_below(words, 0, len(words), start), 0)
-    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
-
-    to_transform = map_function is not None and callable(map_function)
-
-    return [
-        map_function(words[i]) if to_transform else words[i] for i in range(a, b)
-    ]
-
-
-def avg(*items):
-    return sum(items)/len(items)
-
-
-def binary_search_below(transcript, start_index, end_index, time):
-    if start_index >= end_index:
-        return end_index
-
-    middle_index = (start_index + end_index) // 2
-    middle = transcript[middle_index]
-    middle_time = avg(word_start(middle), word_end(middle))
-
-    if time <= middle_time:
-        return binary_search_below(transcript, start_index, middle_index, time)
-    else:
-        return binary_search_below(transcript, middle_index + 1, end_index, time)
-
-
-def binary_search_above(transcript, start_index, end_index, time):
-    if start_index >= end_index:
-        return end_index
-
-    middle_index = (start_index + end_index + 1) // 2
-    middle = transcript[middle_index]
-    middle_time = avg(word_start(middle), word_end(middle))
-
-    if time >= middle_time:
-        return binary_search_above(transcript, middle_index, end_index, time)
-    else:
-        return binary_search_above(transcript, start_index, middle_index - 1, time)
-
-
-class SponsorBlockClassificationPipeline(TextClassificationPipeline):
-    def __init__(self, model, tokenizer):
-        super().__init__(model=model, tokenizer=tokenizer, return_all_scores=True)
-
-    def preprocess(self, video, **tokenizer_kwargs):
-
-        words = get_words(video['video_id'])
-        segment_words = extract_segment(words, video['start'], video['end'])
-        text = ' '.join(x['text'] for x in segment_words)
-
-        model_inputs = self.tokenizer(
-            text, return_tensors=self.framework, **tokenizer_kwargs)
-        return {'video': video, 'model_inputs': model_inputs}
-
-    def _forward(self, data):
-        model_outputs = self.model(**data['model_inputs'])
-        return {'video': data['video'], 'model_outputs': model_outputs}
-
-    def postprocess(self, data, function_to_apply=None, return_all_scores=False):
-        model_outputs = data['model_outputs']
-
-        results = super().postprocess(model_outputs, function_to_apply, return_all_scores)
-
-        for result in results:
-            result['label_text'] = CATEGORIES[result['label']]
-
-        return results  # {**data['video'], 'result': results}
-
 class PreTrainedPipeline():
     def __init__(self, path: str):
         # load the model
-        self.model = AutoModelForSequenceClassification.from_pretrained(path)
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.pipeline = SponsorBlockClassificationPipeline(
-            model=self.model, tokenizer=self.tokenizer)
+        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))
 
-    # def __call__(self, inputs: str) -> List[Dict[str, Any]]:
-    #     json_data = json.loads(inputs)
-    #     return self.pipeline(json_data)
     def __call__(self, inputs: "Image.Image")-> List[Dict[str, Any]]:
-        data = [{
-            'video_id': 'pqh4LfPeCYs',
-            'start': 835.933,
-            'end': 927.581,
-            'category': 'sponsor'
-        }]
-        results = self.pipeline(data)
-
+
         # convert img to numpy array, resize and normalize to make the prediction
         img = np.array(inputs)
 
@@ -387,6 +74,5 @@ class PreTrainedPipeline():
                 "label": f"LABEL_{cls}",
                 "mask": mask_codes[f"mask_{cls}"],
                 "score": 1.0,
-                # "q": results
             })
         return labels
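
After this commit, PreTrainedPipeline is a thin wrapper around a Keras image model. Below is a minimal smoke-test sketch for the updated entry point; the local clone directory, the example.png test image, and keras being imported by pipeline.py's unchanged lines are assumptions, not something this diff shows:

# Hypothetical smoke test for the updated PreTrainedPipeline (not part of the commit).
# Assumes pipeline.py lives in a locally cloned model repo next to tf_model.h5,
# and that the file's unchanged lines import keras.
from PIL import Image

from pipeline import PreTrainedPipeline

pipe = PreTrainedPipeline(path=".")        # directory containing tf_model.h5
outputs = pipe(Image.open("example.png"))  # any test image
for entry in outputs:
    # each result carries a class label, a mask, and a score (see the final hunk)
    print(entry["label"], entry["score"])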
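
Of the removed transcript code, the subtlest part is extract_segment, which selects the words whose midpoint time falls in [start, end] using two complementary recursive binary searches. A self-contained check of that behavior follows; the helper bodies are condensed from the deleted code above, and the word timings are invented test data:

# Standalone check of the removed segment-extraction logic (condensed from the
# deleted code above); the word list is made-up test data, not from the commit.
def word_start(word):
    return word['start']


def word_end(word):
    return word.get('end', word['start'])


def avg(*items):
    return sum(items) / len(items)


def binary_search_below(transcript, start_index, end_index, time):
    # First index whose midpoint time is >= `time`
    if start_index >= end_index:
        return end_index
    middle_index = (start_index + end_index) // 2
    middle = transcript[middle_index]
    if time <= avg(word_start(middle), word_end(middle)):
        return binary_search_below(transcript, start_index, middle_index, time)
    return binary_search_below(transcript, middle_index + 1, end_index, time)


def binary_search_above(transcript, start_index, end_index, time):
    # Last index whose midpoint time is <= `time`
    if start_index >= end_index:
        return end_index
    middle_index = (start_index + end_index + 1) // 2
    middle = transcript[middle_index]
    if time >= avg(word_start(middle), word_end(middle)):
        return binary_search_above(transcript, middle_index, end_index, time)
    return binary_search_above(transcript, start_index, middle_index - 1, time)


def extract_segment(words, start, end):
    # Keep the words whose midpoints fall within [start, end]
    a = max(binary_search_below(words, 0, len(words), start), 0)
    b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
    return words[a:b]


words = [{'text': t, 'start': s, 'end': s + 0.4}
         for t, s in [('this', 0.0), ('video', 0.5), ('is', 1.0),
                      ('sponsored', 1.5), ('by', 2.0)]]
print([w['text'] for w in extract_segment(words, 0.9, 2.1)])
# -> ['is', 'sponsored']  ('by' is excluded: its midpoint, 2.2, lies past 2.1)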