# Hugging Face Space by humeur — commit edc4957 ("Fix app"), 2.51 kB.
import os
import subprocess
import tempfile

import gradio as gr
import soundfile
from pytube import YouTube
from transformers import pipeline
from transformers import WhisperProcessor, WhisperForConditionalGeneration
class GradioInference():
    """Download a YouTube video's audio and translate its Swedish speech to English.

    Wraps the ``humeur/whisper-small-sv-en`` Whisper checkpoint. An instance is
    callable (used as the Gradio "Transcribe" handler) and also provides
    :meth:`populate_metadata` for the thumbnail/title preview.
    """

    # Whisper's feature extractor requires 16 kHz input.
    SAMPLING_RATE = 16000
    # Whisper pads/truncates every input to a 30 s window, so longer audio is
    # split into windows of this length instead of being silently cut off.
    CHUNK_SECONDS = 30
    # Number of 30 s windows passed to ``generate`` at once; bounds memory use.
    BATCH_SIZE = 8

    def __init__(self):
        self.processor = WhisperProcessor.from_pretrained("humeur/whisper-small-sv-en")
        self.model = WhisperForConditionalGeneration.from_pretrained("humeur/whisper-small-sv-en")
        # Most recently resolved video. Kept for backward compatibility only:
        # transcription resolves its own link so concurrent users cannot
        # overwrite each other's video choice.
        self.yt = None

    @staticmethod
    def _chunk_audio(samples, chunk_size):
        """Split *samples* into consecutive slices of at most *chunk_size* items.

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        return [samples[start:start + chunk_size] for start in range(0, len(samples), chunk_size)]

    def _download_audio(self, link, workdir):
        """Download the audio of *link* into *workdir* and return mono 16 kHz float32 samples.

        Raises:
            ValueError: if the video exposes no audio-only stream.
            subprocess.CalledProcessError: if ffmpeg fails to convert the audio.
        """
        yt = YouTube(link)
        self.yt = yt
        stream = yt.streams.filter(only_audio=True).first()
        if stream is None:
            raise ValueError(f"No audio stream available for {link!r}")
        src = stream.download(output_path=workdir, filename="tmp.mp4")
        wav = os.path.join(workdir, "tmp.wav")
        subprocess.run(
            [
                "ffmpeg", "-y",  # -y: overwrite without an interactive prompt
                "-i", src,
                "-vn", "-acodec", "pcm_s16le",
                "-ac", "1",  # mono, so soundfile returns a 1-D array
                "-ar", str(self.SAMPLING_RATE),
                "-f", "wav", wav,
            ],
            check=True,
            stdin=subprocess.DEVNULL,  # never block waiting on stdin
        )
        # soundfile.read returns (data, samplerate); only the data is needed.
        samples, _ = soundfile.read(wav, dtype="float32")
        return samples

    def __call__(self, link):
        """Translate the Swedish speech in a YouTube video to English text.

        Args:
            link: URL of the YouTube video.

        Returns:
            The English translation as one string (outputs of successive 30 s
            windows joined with spaces), or ``""`` if the audio track is empty.

        Raises:
            ValueError: if the video exposes no audio-only stream.
            subprocess.CalledProcessError: if ffmpeg fails to convert the audio.
        """
        # Per-call scratch directory: no clashes between concurrent requests,
        # and the files are removed afterwards.
        with tempfile.TemporaryDirectory() as workdir:
            samples = self._download_audio(link, workdir)
        if len(samples) == 0:
            return ""

        # NOTE: windows are cut at fixed boundaries, so a word spanning a
        # boundary may be split between two windows.
        windows = self._chunk_audio(samples, self.SAMPLING_RATE * self.CHUNK_SECONDS)
        forced_decoder_ids = self.processor.get_decoder_prompt_ids(language="sv", task="translate")
        texts = []
        for batch in self._chunk_audio(windows, self.BATCH_SIZE):
            input_features = self.processor(
                batch, sampling_rate=self.SAMPLING_RATE, return_tensors="pt"
            ).input_features
            predicted_ids = self.model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            texts.extend(self.processor.batch_decode(predicted_ids, skip_special_tokens=True))
        return " ".join(piece.strip() for piece in texts)

    def populate_metadata(self, link):
        """Resolve *link* and return ``(thumbnail_url, title)`` for the preview widgets.

        Also records the resolved video in ``self.yt``.
        """
        self.yt = YouTube(link)
        return self.yt.thumbnail_url, self.yt.title
# Single shared inference object; loading the Whisper model happens once here.
gio = GradioInference()
title = "SWED->EN Youtube Transcriber (Whisper)"
description = "Speech to text transcription of Youtube videos using OpenAI's Whisper fine-tuned for Swedish to English translation"

block = gr.Blocks()
with block:
    # Page header; `title` and `description` are interpolated into the HTML.
    gr.HTML(
        f"""
        <div style="text-align: center; max-width: 500px; margin: 0 auto;">
          <div>
            <h1>{title}</h1>
          </div>
          <p style="margin-bottom: 10px; font-size: 94%">
            {description}
          </p>
        </div>
        """
    )
    with gr.Group():
        with gr.Box():
            link = gr.Textbox(label="YouTube Link")
            # Named `video_title` so it does not shadow the page `title` string above.
            video_title = gr.Label(label="Video Title")
            with gr.Row().style(equal_height=True):
                img = gr.Image(label="Thumbnail")
                text = gr.Textbox(label="Transcription", placeholder="Transcription Output", lines=10)
            with gr.Row().style(equal_height=True):
                btn = gr.Button("Transcribe")
            # Download + translate when the button is pressed.
            btn.click(gio, inputs=[link], outputs=[text])
            # Refresh the thumbnail/title preview whenever the link textbox changes.
            link.change(gio.populate_metadata, inputs=[link], outputs=[img, video_title])
block.launch()