import spaces
import torch
import gradio as gr
import whisperx
import gc
import os

# Constants
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4 # reduce if low on GPU mem
COMPUTE_TYPE = "float32" # change to "int8" if low on GPU mem
FILE_LIMIT_MB = 1000
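
# On ZeroGPU Spaces, @spaces.GPU borrows a GPU for each call to the decorated
# function, for at most `duration` seconds, and releases it when the call ends.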
@spaces.GPU(duration=200)
def transcribe_audio(inputs, task):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    try:
        # Load audio; Gradio passes a file path for both upload and microphone input
        audio = whisperx.load_audio(inputs)

        # 1. Transcribe with base Whisper model, forwarding the task so the
        # "translate" option in the UI actually takes effect
        model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE, task=task)
        result = model.transcribe(audio, batch_size=BATCH_SIZE)

        # Clear GPU memory
        del model
        gc.collect()
        torch.cuda.empty_cache()
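
        # Alignment loads a language-specific phoneme model (wav2vec2-style)
        # to attach accurate word-level timestamps to each segment.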
        # 2. Align whisper output
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)

        # Clear GPU memory again
        del model_a
        gc.collect()
        torch.cuda.empty_cache()
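
        # NOTE: the diarization pipeline downloads gated pyannote models, so
        # HF_TOKEN must belong to an account that has accepted their terms.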
        # 3. Diarize audio
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ["HF_TOKEN"], device=DEVICE)
        diarize_segments = diarize_model(audio)

        # 4. Assign speaker labels
        result = whisperx.assign_word_speakers(diarize_segments, result)
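
        # Speakers come back labelled SPEAKER_00, SPEAKER_01, ..., so each
        # output line reads like "SPEAKER_00: Hello there."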
        # Format output
        output_text = ""
        for segment in result['segments']:
            speaker = segment.get('speaker', 'Unknown Speaker')
            text = segment['text']
            output_text += f"{speaker}: {text}\n"
        return output_text
    except Exception as e:
        raise gr.Error(f"Error processing audio: {str(e)}")
    finally:
        # Final cleanup
        gc.collect()
        torch.cuda.empty_cache()
# Create Gradio interface
demo = gr.Blocks(theme=gr.themes.Ocean())
with demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input (Microphone or File Upload)"
            )
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe"
            )
            submit_button = gr.Button("Process Audio")
        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here..."
            )
    gr.Markdown("""
    ### Features:
    - High-accuracy transcription using WhisperX
    - Automatic speaker diarization
    - Support for both microphone recording and file upload
    - GPU-accelerated processing

    ### Note:
    Processing may take a few moments depending on the audio length and system resources.
    """)

    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text
    )
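
# queue() makes concurrent requests wait their turn rather than run at once;
# ssr_mode=False disables Gradio's server-side rendering.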
demo.queue().launch(ssr_mode=False)