# Spaces: Sleeping  (status banner captured from the hosting page; not part of the program)
import spaces | |
import torch | |
import gradio as gr | |
import whisperx | |
from transformers.pipelines.audio_utils import ffmpeg_read | |
import tempfile | |
import gc | |
import os | |
# Runtime configuration
# Use the GPU whenever torch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

BATCH_SIZE = 4  # lower this when GPU memory is tight
COMPUTE_TYPE = "float32"  # switch to "int8" when GPU memory is tight
# Presumably an upload size cap in megabytes; not referenced in the visible code.
FILE_LIMIT_MB = 1000
def _release_gpu_memory():
    """Run the garbage collector and ask torch to drop its cached CUDA blocks."""
    gc.collect()
    torch.cuda.empty_cache()


def _format_segments(segments):
    """Render each segment as a ``<speaker>: <text>`` line, newline-terminated."""
    return "".join(
        f"{segment.get('speaker', 'Unknown Speaker')}: {segment['text']}\n"
        for segment in segments
    )


def transcribe_audio(inputs, task):
    """Transcribe (or translate) an audio file and label each segment by speaker.

    The pipeline runs WhisperX ``large-v3``, then word alignment (transcription
    only), then speaker diarization, releasing each model before the next one
    is loaded to keep peak GPU memory down.

    NOTE(review): the file imports ``spaces`` but this function carries no
    ``@spaces.GPU`` decorator -- confirm whether the deployment needs one.

    Args:
        inputs: Path of the audio file to process, or ``None`` when nothing
            was submitted.
        task: ``"transcribe"`` or ``"translate"``; forwarded to the Whisper
            model.

    Returns:
        One ``<speaker>: <text>`` line per segment. Segments without a speaker
        label are attributed to ``"Unknown Speaker"``.

    Raises:
        gr.Error: If no audio was submitted, if ``HF_TOKEN`` is not set, or if
            any processing stage fails (the original exception is chained).
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    # Fail fast with a readable message instead of a bare KeyError that would
    # only surface after the expensive transcription step has already run.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token is None:
        raise gr.Error("Error processing audio: the HF_TOKEN environment variable is not set (required for speaker diarization).")

    try:
        audio = whisperx.load_audio(inputs)

        # 1. Transcribe or translate with the base Whisper model.
        model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE)
        result = model.transcribe(audio, batch_size=BATCH_SIZE, task=task)
        del model
        _release_gpu_memory()

        # 2. Align the Whisper output. Translated text no longer matches the
        #    spoken audio, so alignment is skipped for the "translate" task
        #    (the WhisperX command-line tool does the same).
        if task != "translate":
            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
            result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
            del model_a
            _release_gpu_memory()

        # 3. Diarize the audio.
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=DEVICE)
        diarize_segments = diarize_model(audio)
        del diarize_model

        # 4. Assign speaker labels and format the transcript.
        result = whisperx.assign_word_speakers(diarize_segments, result)
        return _format_segments(result["segments"])
    except gr.Error:
        # Already user-facing; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Error processing audio: {str(e)}") from e
    finally:
        # Final cleanup, whether or not processing succeeded.
        _release_gpu_memory()
# Build the Gradio UI: inputs on the left, transcript and notes on the right.
_FEATURES_NOTE = """
### Features:
- High-accuracy transcription using WhisperX
- Automatic speaker diarization
- Support for both microphone recording and file upload
- GPU-accelerated processing
### Note:
Processing may take a few moments depending on the audio length and system resources.
"""

with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")

    with gr.Row():
        # Left column: audio source, task selector and the submit button.
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input (Microphone or File Upload)",
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe",
            )
            submit_button = gr.Button("Process Audio")

        # Right column: the diarized transcript followed by usage notes.
        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here...",
            )
            gr.Markdown(_FEATURES_NOTE)

    # Run the pipeline on click and show its return value in the textbox.
    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text,
    )

# Enable request queuing, then start the server with server-side rendering off.
demo.queue()
demo.launch(ssr_mode=False)