import re

import gradio as gr
import requests
from transformers import AutoTokenizer, pipeline

# NOTE: TranscriptListFetcher lives in a private module of youtube_transcript_api,
# so this import may break across library versions.
from youtube_transcript_api._transcripts import TranscriptListFetcher

# Token-classification model fine-tuned to tag sponsor segments in transcripts.
tagger = pipeline(
    "token-classification",
    "./checkpoint-6000",
    aggregation_strategy="first",
)
tokenizer = AutoTokenizer.from_pretrained("./checkpoint-6000")
max_size = 512  # model context window, in tokens
classes = [False, True]  # label index -> is-sponsor flag

# Matches the 11-character video ID in the common YouTube URL formats
# (youtube.com/watch?v=..., youtu.be/..., embed/..., v/..., etc.).
pattern = re.compile(
    r"(?:https?:\/\/)?(?:[0-9A-Z-]+\.)?(?:youtube|youtu|youtube-nocookie)\.(?:com|be)\/(?:watch\?v=|watch\?.+&v=|embed\/|v\/|.+\?v=)?([^&=\n%\?]{11})"
)


def video_id(url):
    """Extract the video ID from a YouTube URL, or return None."""
    p = pattern.match(url)
    return p.group(1) if p else None


def process(obj):
    """Flatten YouTube's json3 caption events into {word, start, end} dicts."""
    events = obj["events"]
    new_l = []
    start_dur = None
    prev = None
    for line in events:
        if "segs" not in line:
            continue
        # A lone "\n" segment closes the previous caption line; flush its last word.
        if len(line["segs"]) == 1 and line["segs"][0]["utf8"] == "\n":
            if start_dur is not None:
                new_l.append(
                    {
                        "w": prev["utf8"],
                        "s": start_dur + prev["tOffsetMs"],
                        "e": line["tStartMs"],
                    }
                )
            continue
        start_dur = line["tStartMs"]
        prev = line["segs"][0]
        prev["tOffsetMs"] = 0  # the first segment carries no offset field
        for word in line["segs"][1:]:
            try:
                new_l.append(
                    {
                        "w": prev["utf8"],
                        "s": start_dur + prev["tOffsetMs"],
                        "e": start_dur + word["tOffsetMs"],
                    }
                )
                prev = word
            except KeyError:
                # Some segments lack tOffsetMs; skip them.
                pass
    return new_l


def get_transcript(video_id, session):
    """Fetch the English caption track for a video as word-level timestamps."""
    fetcher = TranscriptListFetcher(session)
    _json = fetcher._extract_captions_json(
        fetcher._fetch_video_html(video_id), video_id
    )
    caption_tracks = _json["captionTracks"]
    transcript_track_url = ""
    for track in caption_tracks:
        if track["languageCode"] == "en":
            transcript_track_url = track["baseUrl"] + "&fmt=json3"
    if not transcript_track_url:
        return None
    obj = session.get(transcript_track_url)
    return process(obj.json())


def transcript(url):
    i = video_id(url)
    if not i:
        return "ERROR: Failed to load transcript (is the link a valid YouTube URL?)..."
    words = get_transcript(i, requests.Session())
    if words is None:
        # get_transcript returns None when no English caption track exists.
        return "ERROR: No English caption track found for this video..."
    return " ".join(w["w"].strip() for w in words)


def inference(transcript):
    # Tokenize each word separately so batches can be split on word boundaries.
    tokens = tokenizer(transcript.split(" "))["input_ids"]
    current_length = 0
    current_word_length = 0
    batches = []
    for i, w in enumerate(tokens):
        # Discount the special tokens ([CLS]/[SEP]) that each per-word encoding
        # carries, keeping one of each for the sequence as a whole.
        word = w[:-1] if i == 0 else w[1:] if i == (len(tokens) - 1) else w[1:-1]
        if (current_length + len(word)) > max_size:
            # Adding this word would overflow the context window: flush the batch.
            batch = " ".join(
                tokenizer.batch_decode(
                    [
                        tok[1:-1]
                        for tok in tokens[max(0, i - current_word_length - 1) : i]
                    ]
                )
            )
            batches.append(batch)
            current_word_length = 0
            current_length = 0
            continue
        current_length += len(word)
        current_word_length += 1
    if current_length > 0:
        # Flush the trailing partial batch.
        batches.append(
            " ".join(
                tokenizer.batch_decode(
                    [tok[1:-1] for tok in tokens[i - current_word_length :]]
                )
            )
        )
    results = []
    for split in batches:
        values = tagger(split)
        results.extend(
            {
                "sponsor": v["entity_group"] == "LABEL_1",
                "phrase": v["word"],
            }
            for v in values
        )
    return results


def predict(transcript):
    # HighlightedText expects (text, label-or-None) tuples.
    return [
        (span["phrase"], "Sponsor" if span["sponsor"] else None)
        for span in inference(transcript)
    ]


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(
                label="Video URL", placeholder="Video URL", lines=1, max_lines=1
            )
            btn = gr.Button("Fetch Transcript")
            gr.Examples(["youtu.be/xsLJZyih3Ac"], [inp])
            text = gr.Textbox(label="Transcript", placeholder="")
            btn.click(fn=transcript, inputs=inp, outputs=text)
        with gr.Column():
            p = gr.Button("Predict Sponsors")
            highlight = gr.HighlightedText()
            p.click(fn=predict, inputs=text, outputs=highlight)

demo.launch()