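# WhisperX demo: transcription, word-level alignment, and speaker diarization in a Gradio app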
import whisperx
import torch
import gradio as gr
import os
import spaces

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 4  # reduce if you run out of GPU memory
compute_type = "float32"  # change to "int8" if GPU memory is limited (may reduce accuracy)

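# spaces.GPU requests GPU hardware for the duration of each call on Hugging Face ZeroGPU Spaces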
@spaces.GPU
def transcribe_whisperx(audio_file, task):
    if audio_file is None:
        raise gr.Error("Please upload or record an audio file before submitting!")

    # Load the WhisperX model; `task` selects "transcribe" or "translate" (to English)
    model = whisperx.load_model("large-v3", device=device, compute_type=compute_type, task=task)

    # Load the audio file
    audio = whisperx.load_audio(audio_file)

    # Run the initial (batched) transcription
    result = model.transcribe(audio, batch_size=batch_size)

    # Release the ASR model before loading the alignment model to avoid running out of GPU memory
    del model
    torch.cuda.empty_cache()

    # Load the alignment model and align the transcript for word-level timestamps
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

    # Run speaker diarization; requires an HF_TOKEN with access to the gated pyannote models
    hf_token = os.getenv("HF_TOKEN")
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token,
                                                 device=device)
    diarize_segments = diarize_model(audio)  # reuse the waveform loaded above instead of re-reading the file
    result = whisperx.assign_word_speakers(diarize_segments, result)

    # Format the output as "speaker: text" lines
    output_text = ""
    for segment in result["segments"]:
        speaker = segment.get("speaker", "UNKNOWN")
        text = segment["text"]
        output_text += f"{speaker}: {text}\n"
    
    return output_text

# Gradio UI
demo = gr.Blocks(theme=gr.themes.Ocean())

transcribe_interface = gr.Interface(
    fn=transcribe_whisperx,
    inputs=[
        gr.Audio(sources=["microphone", "upload"], type="filepath"),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
    ],
    outputs="text",
    title="WhisperX: Transcribe and Diarize Audio",
    description="Transcribe and diarize audio files or microphone input with WhisperX."
)

with demo:
    transcribe_interface.render()

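# Queue requests and launch; ssr_mode=False disables Gradio's server-side rendering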
demo.queue().launch(ssr_mode=False)