# Whisper-medium / app.py
# Author: sanjitaa
# Commit: 97bf953 ("Update app.py")
import functools
import os
import tempfile

import gradio as gr
import torch
import yt_dlp as youtube_dl
from faster_whisper import WhisperModel
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
# faster-whisper model size / checkpoint name used for the speech stage.
MODEL_NAME = "medium"
# NOTE(review): BATCH_SIZE and FILE_LIMIT_MB are not referenced by the code
# shown here; kept in case other code depends on them.
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000
# Hugging Face checkpoint for the mBART-50 text translation stage.
MBART_CHECKPOINT = "sanjitaa/mbart-many-to-many"

# Use a device *string*: faster-whisper (CTranslate2) expects "cuda"/"cpu",
# not the bare integer index 0, and transformers accepts the string as well.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The mBART model and tokenizer are loaded once at import time and shared by
# every request.
model = MBartForConditionalGeneration.from_pretrained(MBART_CHECKPOINT)
tokenizer = MBart50TokenizerFast.from_pretrained(MBART_CHECKPOINT)
@functools.lru_cache(maxsize=1)
def _get_whisper_model():
    """Return the faster-whisper model, loading it only on the first call.

    Loading the Whisper checkpoint is slow and memory-heavy, so a single
    instance is cached and reused across requests instead of being rebuilt
    every time ``translate`` runs.
    """
    # faster-whisper (CTranslate2) expects a device string, not an int index.
    whisper_device = "cuda" if torch.cuda.is_available() else "cpu"
    return WhisperModel(MODEL_NAME, device=whisper_device, compute_type="int8")


def translate(inputs, task):
    """Translate the speech in an audio file to French.

    Two-stage pipeline: faster-whisper runs its ``"translate"`` task to turn
    the speech into English text. The mBART-50 model then translates that
    English text (source ``en_XX``) to French (``fr_XX``).

    Args:
        inputs: Path to the audio file to translate, or ``None`` when nothing
            was submitted.
        task: Value of the "Task" selector. It is ignored: the Whisper stage
            always uses ``"translate"`` because the mBART stage is fed
            English text.

    Returns:
        The French translation as a single string.

    Raises:
        gr.Error: If no audio file was submitted.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    segments, _ = _get_whisper_model().transcribe(inputs, task="translate")
    # ``segments`` is lazy; consuming it here is what runs the transcription.
    english_text = "".join(segment.text for segment in segments)

    # The source language must be set *before* tokenizing so the matching
    # language code is added to the encoded input.
    tokenizer.src_lang = "en_XX"
    # NOTE(review): mBART accepts a bounded number of input tokens, so very
    # long transcripts may exceed it. Confirm, and consider chunking.
    encoded_text = tokenizer(english_text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_text,
        # Same id as lang_code_to_id["fr_XX"], but portable across
        # transformers versions.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("fr_XX"),
    )
    # batch_decode returns one string per input sequence; there is exactly
    # one, so return the string itself rather than a one-element list.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Top-level Gradio app; the single Interface below is rendered inside it.
demo = gr.Blocks()

# Upload-only interface: audio file in, French translation text out.
file_transcribe = gr.Interface(
    fn=translate,
    inputs=[
        gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Audio file"),
        # The default must be one of the offered choices; "transcribe" was not.
        gr.inputs.Radio(["translate"], label="Task", default="translate"),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
    title="Whisper Medium: Translate Audio",
    description=(
        "Translate uploaded audio to French with the click of a button! Demo uses the faster-whisper"
        f" `{MODEL_NAME}` model to translate speech to English and the mBART-50 checkpoint"
        " [sanjitaa/mbart-many-to-many](https://huggingface.co/sanjitaa/mbart-many-to-many)"
        " from 🤗 Transformers to translate the English text to French."
    ),
    allow_flagging="never",
)

with demo:
    gr.TabbedInterface([file_transcribe], ["Audio file"])

# Enable request queuing via Blocks.queue(); launch(enable_queue=...) is deprecated.
demo.queue().launch()