# Spaces: Runtime error
import os
import subprocess
import tempfile

import gradio as gr
import soundfile
from pytube import YouTube
from transformers import pipeline
from transformers import WhisperProcessor, WhisperForConditionalGeneration
class GradioInference:
    """Translate the Swedish speech of a YouTube video into English text.

    The instance is callable so it can be passed directly to a Gradio event
    handler: ``gio(link)`` downloads the video's audio track, converts it to
    WAV with ``ffmpeg`` and runs it through the fine-tuned Whisper model.
    """

    # Hugging Face model id used for both the processor and the model.
    MODEL_NAME = "humeur/whisper-small-sv-en"
    # The Whisper feature extractor is documented as operating on 16 kHz
    # audio in 30-second windows; longer inputs are truncated to one window,
    # so the audio is resampled to this rate and fed window by window.
    SAMPLE_RATE = 16000
    CHUNK_SECONDS = 30

    def __init__(self):
        """Load the Whisper processor and model; no video is selected yet."""
        self.processor = WhisperProcessor.from_pretrained(self.MODEL_NAME)
        self.model = WhisperForConditionalGeneration.from_pretrained(self.MODEL_NAME)
        # Most recently resolved pytube video and the link it was built from.
        self.yt = None
        self._link = None

    def __call__(self, link):
        """Return the English translation of the video at *link*.

        Args:
            link: YouTube URL typed by the user.

        Returns:
            The translated text as a single string; an empty string when
            *link* is blank or the audio track is empty.

        Raises:
            ValueError: if the video exposes no audio-only stream.
            subprocess.CalledProcessError: if ``ffmpeg`` fails to convert.
        """
        if not link or not link.strip():
            return ""

        video = self._video_for(link)
        # A per-call temporary directory keeps concurrent requests from
        # overwriting each other's files and removes them afterwards.
        with tempfile.TemporaryDirectory() as workdir:
            wav_path = self._fetch_audio(video, workdir)
            # soundfile.read returns (samples, samplerate); only the samples
            # are handed to the model.
            audio, _ = soundfile.read(wav_path, dtype="float32")
        return self._translate(audio)

    def populate_metadata(self, link):
        """Return ``(thumbnail_url, title)`` for the video at *link*.

        A blank link yields ``(None, None)`` so the preview widgets are
        cleared instead of raising while the textbox is empty.
        """
        if not link or not link.strip():
            return None, None
        video = self._video_for(link)
        return video.thumbnail_url, video.title

    def _video_for(self, link):
        """Return the pytube video for *link*, rebuilding it when the link changed."""
        if self.yt is None or self._link != link:
            self.yt = YouTube(link)
            self._link = link
        return self.yt

    def _fetch_audio(self, video, workdir):
        """Download *video*'s audio into *workdir* and return the WAV path."""
        stream = video.streams.filter(only_audio=True).first()
        if stream is None:
            raise ValueError("No audio-only stream available for this video")
        source_path = stream.download(output_path=workdir, filename="tmp.mp4")
        wav_path = os.path.join(workdir, "tmp.wav")
        # -y overwrites without prompting; check=True surfaces conversion
        # failures instead of reading a missing or stale file afterwards.
        subprocess.run(
            [
                "ffmpeg", "-y", "-i", source_path, "-vn",
                "-acodec", "pcm_s16le", "-ac", "1",
                "-ar", str(self.SAMPLE_RATE), "-f", "wav", wav_path,
            ],
            check=True,
        )
        return wav_path

    def _translate(self, audio):
        """Run Whisper over *audio* window by window and join the decoded text."""
        # The decoder prompt is the same for every window, so build it once.
        forced_decoder_ids = self.processor.get_decoder_prompt_ids(
            language="sv", task="translate"
        )
        window = self.SAMPLE_RATE * self.CHUNK_SECONDS
        pieces = []
        for start in range(0, len(audio), window):
            input_features = self.processor(
                audio[start:start + window],
                sampling_rate=self.SAMPLE_RATE,
                return_tensors="pt",
            ).input_features
            predicted_ids = self.model.generate(
                input_features, forced_decoder_ids=forced_decoder_ids
            )
            decoded = self.processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )
            pieces.append(decoded[0].strip())
        return " ".join(piece for piece in pieces if piece)
# Build the inference helper once so the same instance serves every event.
gio = GradioInference()

title = "SWED->EN Youtube Transcriber (Whisper)"
description = (
    "Speech to text transcription of Youtube videos using OpenAI's Whisper "
    "fine-tuned for Swedish to English translation"
)

# NOTE(review): gr.Box and the .style(...) calls look like Gradio 3.x APIs —
# confirm the Gradio version this app is deployed with still provides them.
block = gr.Blocks()
with block:
    # Page header rendered from the title/description strings above.
    gr.HTML(
        f"""
        <div style="text-align: center; max-width: 500px; margin: 0 auto;">
        <div>
        <h1>{title}</h1>
        </div>
        <p style="margin-bottom: 10px; font-size: 94%">
        {description}
        </p>
        </div>
        """
    )
    with gr.Group():
        with gr.Box():
            link = gr.Textbox(label="YouTube Link")
            # Named video_title so the page-title string above is not
            # rebound to a component.
            video_title = gr.Label(label="Video Title")
            with gr.Row().style(equal_height=True):
                img = gr.Image(label="Thumbnail")
                text = gr.Textbox(
                    label="Transcription",
                    placeholder="Transcription Output",
                    lines=10,
                )
            with gr.Row().style(equal_height=True):
                btn = gr.Button("Transcribe")
    # Clicking the button runs the translation; editing the link refreshes
    # the thumbnail and title preview.
    btn.click(gio, inputs=[link], outputs=[text])
    link.change(gio.populate_metadata, inputs=[link], outputs=[img, video_title])

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    block.launch()