WhisperX-V2 / app.py
StevenChen16's picture
Update app.py
aa547ad verified
raw
history blame
3.67 kB
import spaces
import torch
import gradio as gr
import whisperx
from transformers.pipelines.audio_utils import ffmpeg_read
import tempfile
import gc
import os
# Runtime configuration for the transcription pipeline.

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Batch size handed to the Whisper model's transcribe call; lower it if the
# GPU runs out of memory.
BATCH_SIZE = 4

# Precision passed to whisperx.load_model; switch to "int8" to save GPU memory.
COMPUTE_TYPE = "float32"

# Presumably the maximum accepted upload size, in megabytes — confirm at the
# use site.
FILE_LIMIT_MB = 1000
def _release_gpu_memory():
    """Run the garbage collector and release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def _check_file_size(path):
    """Raise ``gr.Error`` when the file at ``path`` is larger than FILE_LIMIT_MB."""
    size_mb = os.path.getsize(path) / (1024 * 1024)
    if size_mb > FILE_LIMIT_MB:
        raise gr.Error(
            f"File size exceeds the {FILE_LIMIT_MB}MB limit "
            f"(got a file of {size_mb:.2f}MB)."
        )


def _transcribe(audio, task):
    """Run Whisper large-v3 on ``audio`` for ``task`` and return WhisperX's raw result."""
    model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE)
    try:
        return model.transcribe(audio, batch_size=BATCH_SIZE, task=task)
    finally:
        # Free the Whisper model before the next stage loads its own model.
        del model
        _release_gpu_memory()


def _align(result, audio):
    """Align the transcript to the audio with the alignment model for the detected language."""
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
    try:
        return whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio,
            DEVICE,
            return_char_alignments=False,
        )
    finally:
        # Free the alignment model before diarization loads its pipeline.
        del model_a
        _release_gpu_memory()


def _diarize(result, audio, hf_token):
    """Diarize ``audio`` and attach speaker labels to the segments in ``result``."""
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=DEVICE)
    diarize_segments = diarize_model(audio)
    return whisperx.assign_word_speakers(diarize_segments, result)


def _format_transcript(segments):
    """Render one ``"<speaker>: <text>"`` line per segment."""
    # Segment text may carry surrounding whitespace; strip it so exactly one
    # space follows the speaker label.
    return "".join(
        f"{segment.get('speaker', 'Unknown Speaker')}: {segment['text'].strip()}\n"
        for segment in segments
    )


@spaces.GPU
def transcribe_audio(inputs, task):
    """Transcribe (or translate) an audio file and label each segment by speaker.

    Pipeline: Whisper large-v3 via WhisperX, transcript alignment (transcribe
    task only), speaker diarization, then speaker assignment per segment.

    Args:
        inputs: Path to the audio file (the UI's ``gr.Audio`` uses
            ``type="filepath"``), or ``None`` when nothing was submitted.
        task: ``"transcribe"`` or ``"translate"``; forwarded to Whisper.

    Returns:
        One ``"<speaker>: <text>"`` line per segment; segments without a
        speaker label are attributed to ``"Unknown Speaker"``.

    Raises:
        gr.Error: If no audio was submitted, the file exceeds FILE_LIMIT_MB,
            the ``HF_TOKEN`` environment variable is missing, or any pipeline
            stage fails.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    try:
        _check_file_size(inputs)

        # Check the secret before loading any model, so a misconfigured Space
        # fails immediately instead of after the expensive transcription step.
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise gr.Error("HF_TOKEN is not set; it is required to load the speaker diarization model.")

        # gr.Audio(type="filepath") supplies a path for both microphone
        # recordings and uploads, so one loader covers both sources.
        audio = whisperx.load_audio(inputs)

        # 1. Transcribe or translate, honouring the task selected in the UI.
        result = _transcribe(audio, task)

        # 2. Align. A translation's text is not in the audio's language, so the
        #    source-language alignment model cannot align it; skip that case.
        if task != "translate":
            result = _align(result, audio)

        # 3-4. Diarize and assign speaker labels to segments.
        result = _diarize(result, audio, hf_token)

        return _format_transcript(result["segments"])
    except gr.Error:
        # Our own user-facing errors already carry a clear message.
        raise
    except Exception as e:
        raise gr.Error(f"Error processing audio: {e}") from e
    finally:
        _release_gpu_memory()
# Build the Gradio user interface.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")

    with gr.Row():
        # Input column: audio source, task selector and the trigger button.
        with gr.Column():
            audio_input = gr.Audio(
                label="Audio Input (Microphone or File Upload)",
                type="filepath",
                sources=["microphone", "upload"],
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                value="transcribe",
                label="Task",
            )
            submit_button = gr.Button("Process Audio")

        # Output column: the speaker-labelled transcript.
        with gr.Column():
            output_text = gr.Textbox(
                placeholder="Transcribed text will appear here...",
                lines=10,
                label="Transcription with Speaker Diarization",
            )

    gr.Markdown("""
### Features:
- High-accuracy transcription using WhisperX
- Automatic speaker diarization
- Support for both microphone recording and file upload
- GPU-accelerated processing
### Note:
Processing may take a few moments depending on the audio length and system resources.
""")

    # (audio path, task) -> transcript text
    submit_button.click(transcribe_audio, [audio_input, task], output_text)

# Enable Gradio's request queue, then start the server with SSR disabled.
demo.queue().launch(ssr_mode=False)