gobeldan's picture
Update app.py
17c872a verified
import logging
import os
import time
from sys import platform

import gradio as gr
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available

from languages import get_language_names
from subtitle_manager import Subtitle
# Configure root logging at INFO so model-loading and progress messages show up.
logging.basicConfig(level=logging.INFO)

# Lazily-built ASR pipeline and the model id it was created for. Both start
# empty and are filled in by transcribe_webui_simple_progress on first use.
pipe = None
last_model = None
def write_file(output_file, subtitle):
    """Save *subtitle* to *output_file* as UTF-8 text.

    Any existing file at that path is truncated and overwritten.

    Args:
        output_file: Destination path.
        subtitle: Text content to write.
    """
    destination = open(output_file, mode="w", encoding="utf-8")
    with destination:
        destination.write(subtitle)
def _select_device():
    """Return the torch device string for inference: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    # Being on macOS is not enough: Intel Macs and torch builds without the
    # MPS backend would fail later at model.to("mps").
    if platform == "darwin" and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def create_pipe(model, flash):
    """Build a transformers automatic-speech-recognition pipeline for *model*.

    Args:
        model: Hub id (or local path) of a speech seq2seq checkpoint,
            e.g. "openai/whisper-small".
        flash: When truthy and flash-attn 2 is installed, load the model with
            the "flash_attention_2" attention implementation; otherwise "sdpa".

    Returns:
        A transformers ASR pipeline whose model has been moved to the selected
        device (CUDA, MPS or CPU).
    """
    device = _select_device()
    # Half precision only on CUDA; MPS and CPU stay in float32.
    torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

    use_flash = bool(flash) and is_flash_attn_2_available()
    if flash and not use_flash:
        logging.warning("Flash Attention 2 requested but not available; falling back to sdpa.")

    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        # "sdpa" uses torch.nn.functional.scaled_dot_product_attention and needs
        # torch>=2.1.1; "eager" would be the manual fallback implementation.
        attn_implementation="flash_attention_2" if use_flash else "sdpa",
    )
    asr_model.to(device)
    processor = AutoProcessor.from_pretrained(model)
    # chunk_length_s / batch_size are supplied per call by the caller,
    # not fixed here.
    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
def _collect_inputs(urlData, multipleFiles, microphoneData):
    """Gather every provided audio source: uploaded files first, then the URL, then the mic clip."""
    sources = list(multipleFiles) if multipleFiles else []
    if urlData:
        sources.append(urlData)
    if microphoneData:
        sources.append(microphoneData)
    return sources


def _build_generate_kwargs(modelName, languageName, task):
    """Return Whisper generate_kwargs for the pipeline call.

    English-only checkpoints (ids ending in ".en") get neither "language" nor
    "task". Other checkpoints always get "task". They get "language" only when
    the user did not pick "Automatic Detection".
    """
    generate_kwargs = {}
    if modelName.endswith(".en"):
        return generate_kwargs
    if languageName != "Automatic Detection":
        generate_kwargs["language"] = languageName
    generate_kwargs["task"] = task
    return generate_kwargs


def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Transcribe every supplied audio source and write .srt/.vtt/.txt files for each.

    The ASR pipeline is cached in the module globals ``pipe`` / ``last_model``.
    It is rebuilt only when the selected model or the effective Flash Attention
    setting changes.

    Args:
        modelName: Hub id of the Whisper checkpoint to use.
        languageName: Spoken language, or "Automatic Detection".
        urlData: Optional string passed straight to the pipeline as an input.
        multipleFiles: Optional list of uploaded files (presumably file paths).
        microphoneData: Optional file path of a recorded/uploaded clip.
        task: "transcribe" or "translate" (ignored for ".en" models).
        flash: Request Flash Attention 2 when building the pipeline.
        chunk_length_s: Chunk length in seconds forwarded to the pipeline.
        batch_size: Batch size forwarded to the pipeline (coerced to int).
        progress: Gradio progress tracker.

    Returns:
        (files_out, vtt, txt):
            files_out: paths of every .srt/.vtt/.txt file written to the
                current working directory.
            vtt, txt: the VTT and TXT text of the last processed source.

    Raises:
        gr.Error: if no audio source was supplied.
    """
    global last_model, last_flash, pipe
    progress(0, desc="Loading Audio..")
    logging.info("urlData:%s", urlData)
    logging.info("multipleFiles:%s", multipleFiles)
    logging.info("microphoneData:%s", microphoneData)
    logging.info("task: %s", task)
    logging.info("is_flash_attn_2_available: %s", is_flash_attn_2_available())
    logging.info("chunk_length_s: %s", chunk_length_s)
    logging.info("batch_size: %s", batch_size)

    files = _collect_inputs(urlData, multipleFiles, microphoneData)
    logging.info(files)
    if not files:
        # Fail fast with a user-visible message instead of loading a model and
        # then returning subtitle variables that were never assigned.
        raise gr.Error("No audio provided: enter a URL, upload files or record from the microphone.")

    # Track the *effective* attention choice so toggling Flash without
    # flash-attn installed does not trigger a pointless reload.
    flash_key = bool(flash) and is_flash_attn_2_available()
    if last_model is None:
        logging.info("first model")
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    elif modelName != last_model or flash_key != last_flash:
        # last_flash is bound here: in this file it is assigned together with
        # last_model below, and last_model is only non-None after that.
        logging.info("new model")
        # Drop the old pipeline first so empty_cache() can actually release its
        # GPU memory. Clearing last_model makes a failed load below force a
        # fresh load next call instead of reusing pipe=None.
        pipe = None
        last_model = None
        torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    else:
        logging.info("Model not changed")
    last_model = modelName
    last_flash = flash_key

    if batch_size is not None:
        # gr.Number may hand back a float (e.g. 24.0); the pipeline's
        # DataLoader requires an int batch size.
        batch_size = int(batch_size)
    generate_kwargs = _build_generate_kwargs(modelName, languageName, task)

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")
    files_out = []
    vtt = txt = ""
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        logging.info("transcribe: %s sec.", time.time() - start_time)
        chunks = outputs["chunks"]
        srt = srt_sub.get_subtitle(chunks)
        vtt = vtt_sub.get_subtitle(chunks)
        txt = txt_sub.get_subtitle(chunks)
        # Outputs land in the current working directory, named after the
        # source's base name (os.path handles Windows separators too).
        file_out = os.path.basename(file)
        for extension, content in ((".srt", srt), (".vtt", vtt), (".txt", txt)):
            write_file(file_out + extension, content)
            files_out.append(file_out + extension)
    progress(1, desc="Completed!")
    return files_out, vtt, txt
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
    article = "Read the [documentation here](https://github.com/Vaibhavs10/insanely-fast-whisper#cli-options)."
    # Checkpoint ids offered in the Model dropdown; ids ending in ".en" are
    # treated as English-only by the transcription handler.
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]
    waveform_options = gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )
    # Input order must match the positional parameters of
    # transcribe_webui_simple_progress (minus `progress`).
    simple_transcribe = gr.Interface(
        fn=transcribe_webui_simple_progress,
        description=description,
        article=article,
        inputs=[
            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive=True),
            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive=True),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive=True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input", waveform_options=waveform_options),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive=True),
            gr.Checkbox(label='Flash', info='Use Flash Attention 2'),
            # A negative chunk length is meaningless; 0 disables chunking.
            gr.Number(label='chunk_length_s', value=30, minimum=0, interactive=True),
            # precision=0 makes Gradio deliver an int, which the pipeline's
            # DataLoader requires for its batch size.
            gr.Number(label='batch_size', value=24, precision=0, minimum=1, interactive=True),
        ],
        outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments"),
        ],
    )

if __name__ == "__main__":
    demo.launch()