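"""Gradio web UI for fast Whisper speech-to-text with 🤗 Transformers.

Loads a user-selected Whisper checkpoint into an automatic-speech-recognition
pipeline and writes .srt/.vtt/.txt subtitles for each input file.
"""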
import logging
import os
import time
from sys import platform

import gradio as gr
import torch
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available

from languages import get_language_names
from subtitle_manager import Subtitle
|
logging.basicConfig(level=logging.INFO)

# Module-level cache: the pipeline is only rebuilt when the selected model changes.
last_model = None
pipe = None
|
|
def write_file(output_file, subtitle):
    """Write a rendered subtitle string to disk."""
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)
|
|
def create_pipe(model, flash):
    """Build an automatic-speech-recognition pipeline for the given model id."""
    # Prefer CUDA, then Apple Silicon (MPS), otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda:0"
    elif platform == "darwin":
        device = "mps"
    else:
        device = "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = model

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        # Use Flash Attention 2 only when requested and installed; otherwise use
        # PyTorch's scaled-dot-product attention.
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe
|
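# A minimal usage sketch (illustrative only; "sample.wav" is a hypothetical file):
#   asr = create_pipe("openai/whisper-base", flash=False)
#   result = asr("sample.wav", return_timestamps=True)
#   print(result["text"])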
|
|
def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    global last_model
    global pipe

    progress(0, desc="Loading Audio..")
    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")

    # Build the pipeline on first use, rebuild it when the model selection
    # changes, and otherwise reuse the cached one.
    if last_model is None:
        logging.info("first model")
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    elif modelName != last_model:
        logging.info("new model")
        # Release GPU memory held by the previous model (no-op without CUDA).
        torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    else:
        logging.info("Model not changed")
    last_model = modelName
|
    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    # Collect all provided inputs (uploads, URL, microphone) into one work list.
    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)
    logging.info(files)

    generate_kwargs = {}
    # English-only checkpoints (*.en) accept neither a language nor a task argument.
    if languageName != "Automatic Detection" and not modelName.endswith(".en"):
        generate_kwargs["language"] = languageName
    if not modelName.endswith(".en"):
        generate_kwargs["task"] = task
|
    files_out = []
    vtt = txt = ""  # keep the text outputs defined even when no input was provided
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        logging.info(f"transcribe: {time.time() - start_time} sec.")

        # Name the outputs after the input file and write all three formats.
        file_out = os.path.basename(file)
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_out + ".srt", srt)
        write_file(file_out + ".vtt", vtt)
        write_file(file_out + ".txt", txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")

    return files_out, vtt, txt
|
|
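# Gradio UI: a Blocks app exposing the transcription function as a single Interface.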
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "Transcribe audio files with Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
    article = "Read the [documentation here](https://github.com/Vaibhavs10/insanely-fast-whisper#cli-options)."
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]
    waveform_options = gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )
|
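    # The inputs below must stay in the same order as the parameters of
    # transcribe_webui_simple_progress.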
    simple_transcribe = gr.Interface(
        fn=transcribe_webui_simple_progress,
        description=description,
        article=article,
        inputs=[
            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive=True),
            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive=True),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive=True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input", waveform_options=waveform_options),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive=True),
            gr.Checkbox(label="Flash", info="Use Flash Attention 2"),
            gr.Number(label="chunk_length_s", value=30, interactive=True),
            gr.Number(label="batch_size", value=24, interactive=True),
        ],
        outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments"),
        ],
    )
|
|
if __name__ == "__main__":
    demo.launch()
|