StevenChen16 committed • Commit 550cf61
Parent(s): d6c72bf

Update app.py

app.py CHANGED
@@ -1,66 +1,114 @@
-import
-import whisperx
+import os
 import torch
 import gradio as gr
+import whisperx
+from transformers.pipelines.audio_utils import ffmpeg_read
 import tempfile
-import
+import gc

-model = whisperx.load_model("large-v3", device=device, compute_type=compute_type)
+# Constants
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+BATCH_SIZE = 4
+COMPUTE_TYPE = "float32"
+FILE_LIMIT_MB = 1000
+
+def transcribe_audio(inputs, task):
+    if inputs is None:
+        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
+
+    try:
+        # Load audio
+        if isinstance(inputs, str):
+            # For file path input
+            audio = whisperx.load_audio(inputs)
+        else:
+            # For microphone input (needs conversion)
+            audio = whisperx.load_audio(inputs)
+
+        # 1. Transcribe with base Whisper model
+        model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE)
+        result = model.transcribe(audio, batch_size=BATCH_SIZE)
+
+        # Clear GPU memory
+        del model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # 2. Align whisper output
+        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
+        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
+
+        # Clear GPU memory again
+        del model_a
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # 3. Diarize audio
+        diarize_model = whisperx.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=DEVICE)
+        diarize_segments = diarize_model(audio)
+
+        # 4. Assign speaker labels
+        result = whisperx.assign_word_speakers(diarize_segments, result)
+
+        # Format output
+        output_text = ""
+        for segment in result['segments']:
+            speaker = segment.get('speaker', 'Unknown Speaker')
+            text = segment['text']
+            output_text += f"{speaker}: {text}\n"
+
+        return output_text
+
+    except Exception as e:
+        raise gr.Error(f"Error processing audio: {str(e)}")
+
+    finally:
+        # Final cleanup
+        gc.collect()
+        torch.cuda.empty_cache()

-# Gradio
+# Create Gradio interface
 demo = gr.Blocks(theme=gr.themes.Ocean())

-transcribe_interface = gr.Interface(
-    fn=transcribe_whisperx,
-    inputs=[
-        gr.Audio(sources=["microphone", "upload"], type="filepath"),
-        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
-    ],
-    outputs="text",
-    title="WhisperX: Transcribe and Diarize Audio",
-    description="Transcribe and diarize audio files or microphone input with WhisperX."
-)
-
 with demo:
+    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")
+
+    with gr.Row():
+        with gr.Column():
+            audio_input = gr.Audio(
+                sources=["microphone", "upload"],
+                type="filepath",
+                label="Audio Input (Microphone or File Upload)"
+            )
+            task = gr.Radio(
+                ["transcribe", "translate"],
+                label="Task",
+                value="transcribe"
+            )
+            submit_button = gr.Button("Process Audio")
+
+        with gr.Column():
+            output_text = gr.Textbox(
+                label="Transcription with Speaker Diarization",
+                lines=10,
+                placeholder="Transcribed text will appear here..."
+            )
+
+    gr.Markdown("""
+    ### Features:
+    - High-accuracy transcription using WhisperX
+    - Automatic speaker diarization
+    - Support for both microphone recording and file upload
+    - GPU-accelerated processing
+
+    ### Note:
+    Processing may take a few moments depending on the audio length and system resources.
+    """)
+
+    submit_button.click(
+        fn=transcribe_audio,
+        inputs=[audio_input, task],
+        outputs=output_text
+    )

-demo.queue().launch(
+demo.queue().launch(share=True)
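
Review note: FILE_LIMIT_MB is defined in the new version but never checked, so the constant currently has no effect. A minimal sketch of enforcing it before the audio is loaded; check_file_size is a hypothetical helper, not part of this commit:

    import os
    import gradio as gr

    def check_file_size(path, limit_mb=1000):
        # Reject oversized uploads before whisperx loads them into memory.
        # Works here because the Gradio component uses type="filepath".
        size_mb = os.path.getsize(path) / (1024 * 1024)
        if size_mb > limit_mb:
            raise gr.Error(f"Audio file is {size_mb:.0f} MB; the limit is {limit_mb} MB.")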
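
Review note: both branches of the isinstance check call whisperx.load_audio(inputs), so the branch is a no-op; with type="filepath", Gradio hands transcribe_audio a path for microphone recordings as well as uploads (the ffmpeg_read import is likewise unused). The load step can collapse to a single call:

    # Both input sources deliver a file path, so one call covers them.
    audio = whisperx.load_audio(inputs)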
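
Review note: the transcribe/translate radio is passed into transcribe_audio, but the task argument is never used, so selecting "translate" behaves exactly like "transcribe". A sketch of forwarding it in step 1, assuming the installed whisperx release accepts a task keyword in load_model (worth verifying against your version):

    # Inside transcribe_audio, replacing step 1; DEVICE, COMPUTE_TYPE,
    # BATCH_SIZE and audio are as defined in the committed file.
    # Assumption: whisperx.load_model supports task= in this release.
    model = whisperx.load_model(
        "large-v3",
        device=DEVICE,
        compute_type=COMPUTE_TYPE,
        task=task,  # "transcribe" or "translate"
    )
    result = model.transcribe(audio, batch_size=BATCH_SIZE)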
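
Review note: the diarization pipeline ships with the literal placeholder use_auth_token="YOUR_HF_TOKEN", so step 3 will fail at runtime until a real token is supplied (the pyannote models behind whisperx.DiarizationPipeline are gated). Since os is now imported, a sketch that reads the token from a Space secret; the HF_TOKEN name is an assumed convention, not something in this commit:

    import os
    import whisperx

    # Spaces expose configured secrets as environment variables.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("Configure an HF_TOKEN secret so the gated diarization model can be downloaded.")

    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device="cuda")  # or DEVICE from app.py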