|
import json |
|
from typing import Any, Dict, List |
|
|
|
import tensorflow as tf |
|
from tensorflow import keras |
|
import base64 |
|
import io |
|
import os |
|
import numpy as np |
|
from PIL import Image |
|
|
|
import youtube_transcript_api2 |
|
import json |
|
import re |
|
import requests |
|
from transformers import ( |
|
AutoModelForSequenceClassification, |
|
AutoTokenizer, |
|
TextClassificationPipeline, |
|
) |
|
from typing import Any, Dict, List |
|
|
|
# Segment category labels. Not referenced in the code visible here.
# NOTE(review): index 0 (None) presumably means "no category" — confirm with users.
CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']

# Placeholder standing for a censored word in transcript text (presumably
# YouTube's marker), and the string it is rewritten to.
PROFANITY_RAW = '[ __ ]'
PROFANITY_CONVERTED = '*****'

# Decimal places kept for every start/end timestamp (seconds).
NUM_DECIMALS = 3

# Transcript languages to look for, most preferred first (passed to the
# find_*_transcript lookups in get_words).
LANGUAGE_PREFERENCE_LIST = ['en-GB', 'en-US', 'en-CA', 'en-AU', 'en-NZ', 'en-ZA',
                            'en-IE', 'en-IN', 'en-JM', 'en-BZ', 'en-TT', 'en-PH', 'en-ZW',
                            'en']

# Zero-width characters that carry no visible text; deleted from every segment.
_ZERO_WIDTH_TABLE = str.maketrans('', '', '\u200b\u200c\u200d\ufeff')

# Chunk boundary: whitespace preceded by one of . ! ? , ; -
# The hyphen must be last in the class to be literal: '[,-;]' would be the
# range ',' .. ';', which also matches digits, '.', '/' and ':'.
_CHUNK_BOUNDARY_RE = re.compile(r'(?<=[.!?,;-])\s+')

# An unpunctuated chunk with at least this many words is split into words.
_MIN_WORDS_TO_SPLIT_CHUNK = 8

# Upper bound (seconds) on the duration given to a word-level entry.
_MAX_WORD_DURATION = 1.5


def _clean_segment_text(raw_text):
    """Drop zero-width characters, collapse whitespace and mask profanity placeholders."""
    # Remove zero-width characters first so collapsing cannot leave double spaces.
    text = raw_text.translate(_ZERO_WIDTH_TABLE)
    text = ' '.join(text.split())
    return text.replace(PROFANITY_RAW, PROFANITY_CONVERTED)


def _events_to_segments(events):
    """Flatten json3 events into segments ``{'text', 'start', 'end'}`` (seconds).

    An event's duration is the smaller of its ``dDurationMs`` and the gap to the
    next event; it is shared among the event's segments in proportion to their
    character counts.
    """
    segments = []
    for event_index, event in enumerate(events):
        raw_segments = event.get('segs')
        if not raw_segments:
            continue

        start_ms = event['tStartMs']
        event_segments = []
        total_characters = 0
        for raw in raw_segments:
            text = _clean_segment_text(raw['utf8'])
            if not text:
                continue
            offset_ms = raw.get('tOffsetMs', 0)
            event_segments.append({
                'text': text,
                'start': round((start_ms + offset_ms) / 1000, NUM_DECIMALS),
            })
            total_characters += len(text)

        if not event_segments:
            continue

        if event_index < len(events) - 1:
            # An event may not run into the event that follows it.
            next_start_ms = events[event_index + 1]['tStartMs']
            duration_ms = min(event.get('dDurationMs', float('inf')),
                              next_start_ms - start_ms)
        else:
            duration_ms = event.get('dDurationMs', 0)
        duration_ms = max(duration_ms, 0)

        seconds_per_character = duration_ms / total_characters / 1000
        event_start = start_ms / 1000

        characters_so_far = 0
        for seg_index, seg in enumerate(event_segments):
            characters_so_far += len(seg['text'])
            # The cumulative character count is measured from the event start, so
            # the time it implies is added to the event start (adding it to the
            # segment's own start would push later segments past the event end).
            seg_end = event_start + characters_so_far * seconds_per_character
            if seg_index < len(event_segments) - 1:
                seg_end = min(seg_end, event_segments[seg_index + 1]['start'])
            # A late offset in a short event must not yield end < start.
            seg['end'] = round(max(seg_end, seg['start']), NUM_DECIMALS)
            segments.append(seg)
    return segments


def _split_segments(segments, granularity):
    """Split timed segments into words or chunks, sharing each one's time by character count."""
    pieces_out = []
    for index, segment in enumerate(segments):
        text = segment['text']
        seg_start = segment['start']

        # Word-level entries get their duration capped at _MAX_WORD_DURATION.
        word_level = granularity == 'word'
        if word_level:
            pieces = text.split()
        elif granularity == 'chunk':
            pieces = _CHUNK_BOUNDARY_RE.split(text)
            if len(pieces) == 1:
                words = text.split()
                if len(words) >= _MIN_WORDS_TO_SPLIT_CHUNK:
                    # Long unpunctuated chunk: fall back to individual words.
                    pieces = words
                else:
                    # Short unpunctuated chunk: keep it whole but cap it like a word.
                    word_level = True
        else:
            raise ValueError('Unknown granularity')

        seg_end = segment['end']
        if index < len(segments) - 1:
            seg_end = min(seg_end, segments[index + 1]['start'])
        duration = max(seg_end - seg_start, 0)

        total_chars = sum(map(len, pieces))
        chars_so_far = 0
        offset = 0
        for piece in pieces:
            chars_so_far += len(piece)
            next_offset = (chars_so_far / total_chars) * duration
            piece_start = round(seg_start + offset, NUM_DECIMALS)
            piece_end = round(seg_start + next_offset, NUM_DECIMALS)
            if word_level:
                piece_end = min(piece_end, piece_start + _MAX_WORD_DURATION)
            pieces_out.append({'text': piece, 'start': piece_start, 'end': piece_end})
            offset = next_offset
    return pieces_out


def parse_transcript_json(json_data, granularity):
    """Parse a YouTube ``json3`` transcript into timed words or chunks.

    Args:
        json_data: decoded json3 document with ``wireMagic == 'pb3'``. Each entry
            of its ``events`` list is read through ``tStartMs``, ``dDurationMs``
            and ``segs`` (items with ``utf8`` text and optional ``tOffsetMs``).
            A missing ``events`` key is treated as no events.
        granularity: ``'word'`` to split on whitespace, or ``'chunk'`` to split
            after punctuation (long unpunctuated chunks fall back to words).

    Returns:
        List of ``{'text', 'start', 'end'}`` dicts; times are in seconds,
        rounded to ``NUM_DECIMALS`` places.

    Raises:
        ValueError: if ``wireMagic`` is not ``'pb3'`` or ``granularity`` is unknown.
    """
    if json_data.get('wireMagic') != 'pb3':
        raise ValueError("Unsupported transcript format: expected wireMagic 'pb3'")
    if granularity not in ('word', 'chunk'):
        raise ValueError('Unknown granularity')

    segments = _events_to_segments(json_data.get('events') or [])
    return _split_segments(segments, granularity)
|
|
|
|
|
def list_transcripts(video_id):
    """Return youtube_transcript_api2's transcript list for ``video_id``.

    Returns None instead when the listing call raises a JSONDecodeError.
    """
    api = youtube_transcript_api2.YouTubeTranscriptApi
    try:
        result = api.list_transcripts(video_id)
    except json.decoder.JSONDecodeError:
        result = None
    return result
|
|
|
|
|
# Bracketed annotation entries that get_words drops from parsed transcripts
# when filter_words_to_remove is true.
# Every entry needs its own comma: without them Python concatenates adjacent
# string literals into the single string '[Music][Applause][Laughter]'.
WORDS_TO_REMOVE = [
    '[Music]',
    '[Applause]',
    '[Laughter]',
]
|
|
|
|
|
def _fetch_transcript_json(video_id, transcript_type, max_retries):
    """Download and decode the json3 transcript of one type for ``video_id``.

    Returns None when the transcript list or the requested transcript is
    unavailable, or the downloaded body is empty. Network errors and undecodable
    bodies are retried up to ``max_retries`` times; after that the last error is
    re-raised.
    """
    failures = 0
    while True:
        try:
            transcript_list = list_transcripts(video_id)
            if transcript_list is None:
                return None
            if transcript_type == 'manual':
                ts = transcript_list.find_manually_created_transcript(
                    LANGUAGE_PREFERENCE_LIST)
            else:
                ts = transcript_list.find_generated_transcript(
                    LANGUAGE_PREFERENCE_LIST)
            # NOTE(review): uses private attributes of the library's transcript
            # object to request the json3 format; may break on library upgrades.
            raw_transcript = ts._http_client.get(f'{ts._url}&fmt=json3').content
            return json.loads(raw_transcript) if raw_transcript else None

        except (youtube_transcript_api2.TooManyRequests,
                youtube_transcript_api2.YouTubeRequestFailed):
            # NOTE(review): presumably subclasses of CouldNotRetrieveTranscript;
            # this handler must stay first so they propagate instead of being
            # treated as "no transcript".
            raise

        except requests.exceptions.RequestException:
            failures += 1
            if failures > max_retries:
                raise

        except youtube_transcript_api2.CouldNotRetrieveTranscript:
            return None

        except json.decoder.JSONDecodeError:
            failures += 1
            if failures > max_retries:
                raise


def get_words(video_id, transcript_type='auto', fallback='manual',
              filter_words_to_remove=True, granularity='word', *, max_retries=3):
    """Fetch a YouTube video's transcript and parse it into timed words or chunks.

    Args:
        video_id: YouTube video id.
        transcript_type: ``'manual'`` selects a manually created transcript; any
            other value selects the auto-generated one.
        fallback: transcript type to try when the requested one yields nothing,
            or None to disable the fallback.
        filter_words_to_remove: drop entries whose text is in WORDS_TO_REMOVE.
        granularity: ``'word'`` or ``'chunk'`` (see parse_transcript_json).
        max_retries: how many times network or JSON-decoding errors are retried
            before the error is re-raised.

    Returns:
        List of ``{'text', 'start', 'end'}`` dicts, or None when neither the
        requested nor the fallback transcript could be retrieved.

    Raises:
        youtube_transcript_api2.TooManyRequests,
        youtube_transcript_api2.YouTubeRequestFailed: propagated unchanged.
        requests.exceptions.RequestException, json.decoder.JSONDecodeError:
            once ``max_retries`` retries are exhausted.
    """
    raw_transcript_json = _fetch_transcript_json(video_id, transcript_type, max_retries)

    if not raw_transcript_json:
        if fallback is not None:
            # Pass every option by keyword so none is lost or shifted.
            return get_words(video_id, transcript_type=fallback, fallback=None,
                             filter_words_to_remove=filter_words_to_remove,
                             granularity=granularity, max_retries=max_retries)
        return None

    words = parse_transcript_json(raw_transcript_json, granularity)
    if filter_words_to_remove:
        words = [word for word in words if word['text'] not in WORDS_TO_REMOVE]
    return words
|
|
|
|
|
def word_start(word):
    """Return the value stored under the ``'start'`` key of ``word``."""
    start_time = word['start']
    return start_time
|
|
|
|
|
def word_end(word):
    """Return ``word['end']``, or ``word['start']`` when there is no end key.

    ``word['start']`` is always read, so a word without a start raises KeyError.
    """
    fallback = word['start']
    try:
        return word['end']
    except KeyError:
        return fallback
|
|
|
|
|
def extract_segment(words, start, end, map_function=None):
    """Extract the words whose midpoint time lies in [start, end].

    Args:
        words: sequence of word dicts ordered by time (the binary searches
            require this). None or an empty sequence yields ``[]``.
        start: lower time bound, same units as the words' 'start'/'end'.
        end: upper time bound.
        map_function: optional callable applied to each selected word; a
            non-callable value is ignored.

    Returns:
        List of the selected words (or their mapped values), in input order.
    """
    if words is None or len(words) == 0:
        # get_words returns None when no transcript is available.
        return []

    n = len(words)
    # First index whose midpoint is >= start.
    first = max(binary_search_below(words, 0, n, start), 0)
    # One past the last index whose midpoint is <= end (-1 is the "none" sentinel).
    last = min(binary_search_above(words, -1, n - 1, end) + 1, n)

    if callable(map_function):
        return [map_function(words[i]) for i in range(first, last)]
    return [words[i] for i in range(first, last)]
|
|
|
|
|
def avg(*items):
    """Return the arithmetic mean of the positional arguments.

    Raises ZeroDivisionError when called with no arguments.
    """
    total = sum(items)
    count = len(items)
    return total / count
|
|
|
|
|
def binary_search_below(transcript, start_index, end_index, time):
    """Return the first index in [start_index, end_index) whose word midpoint is >= ``time``.

    A word's midpoint is the average of word_start and word_end. Returns
    ``end_index`` when no such index exists or the range is empty.
    ``transcript`` must be ordered by time.
    """
    low, high = start_index, end_index
    while low < high:
        probe = (low + high) // 2
        word = transcript[probe]
        if time <= avg(word_start(word), word_end(word)):
            high = probe
        else:
            low = probe + 1
    return high
|
|
|
|
|
def binary_search_above(transcript, start_index, end_index, time):
    """Return the last index in (start_index, end_index] whose word midpoint is <= ``time``.

    A word's midpoint is the average of word_start and word_end. ``start_index``
    itself is never inspected: it is returned when no index qualifies, which is
    why callers pass -1 as a sentinel. If ``start_index >= end_index``,
    ``end_index`` is returned immediately. ``transcript`` must be ordered by time.
    """
    low, high = start_index, end_index
    while low < high:
        probe = (low + high + 1) // 2
        word = transcript[probe]
        if time >= avg(word_start(word), word_end(word)):
            low = probe
        else:
            high = probe - 1
    return high
|
|
|
|
|
class PreTrainedPipeline:
    """Image-segmentation pipeline around a Keras model loaded from ``tf_model.h5``.

    Calling the pipeline on an image returns one entry per model output class,
    each holding a base64-encoded PNG binary mask of the pixels whose most
    likely class is that class.
    """

    # Spatial size images are resized to before prediction.
    # NOTE(review): presumably the model's training resolution — confirm.
    INPUT_SIZE = (128, 128)

    # NOTE(review): every result is tagged with the transcript words spoken in
    # this fixed window (seconds) of this fixed YouTube video. That data is
    # unrelated to the input image and looks like leftover experimentation; it
    # is kept only so the output schema (the "words" key) stays the same.
    DEMO_VIDEO_ID = "pqh4LfPeCYs"
    DEMO_SEGMENT_BOUNDS = (835.933, 927.581)

    def __init__(self, path: str):
        """Load the Keras model stored at ``<path>/tf_model.h5``."""
        self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))
        # Demo transcript segment: fetched on the first call, then reused so the
        # network is not hit on every prediction.
        self._demo_words = None

    def _get_demo_words(self):
        """Return the demo transcript segment, fetching it on first use."""
        if getattr(self, "_demo_words", None) is None:
            words = get_words(self.DEMO_VIDEO_ID)
            self._demo_words = extract_segment(words, *self.DEMO_SEGMENT_BOUNDS)
        return self._demo_words

    @staticmethod
    def _encode_png_base64(mask):
        """Encode a 2-D uint8 array as a base64 (UTF-8 text) PNG string."""
        with io.BytesIO() as out:
            # A 2-D uint8 array maps to an 8-bit grayscale ("L") image.
            Image.fromarray(mask).save(out, format="PNG")
            return base64.b64encode(out.getvalue()).decode("utf-8")

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment ``inputs`` and return one labelled mask per model class.

        Returns:
            List of dicts with keys "label" (``"LABEL_<cls>"``), "mask" (base64
            PNG: 255 where the pixel's most likely class is ``cls``, else 0),
            "score" (always 1.0) and "words" (the demo transcript segment).
        """
        segment = self._get_demo_words()

        # NOTE(review): assumes the image converts to an (H, W, channels) array
        # with the channel count the model expects; grayscale/RGBA inputs are
        # not converted here.
        img = np.array(inputs)
        im = tf.image.resize(img, self.INPUT_SIZE)
        im = tf.cast(im, tf.float32) / 255.0
        pred_mask = self.model.predict(im[tf.newaxis, ...])

        # pred_mask is indexed as (batch, H, W, classes); keep the most likely
        # class per pixel of the single image. Height and width both come from
        # this map, so non-square outputs are handled.
        class_map = np.argmax(np.asarray(pred_mask), axis=-1)[0]

        labels = []
        for cls in range(pred_mask.shape[-1]):
            # uint8, not int8: 255 does not fit in a signed byte.
            binary_mask = np.where(class_map == cls, 255, 0).astype(np.uint8)
            labels.append({
                "label": f"LABEL_{cls}",
                "mask": self._encode_png_base64(binary_mask),
                "score": 1.0,
                "words": segment,
            })
        return labels