import whisperx
import torch
import gradio as gr
import os
import spaces

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 4  # reduce if GPU memory is insufficient
compute_type = "float32"  # change to "int8" if GPU memory is insufficient (may reduce accuracy)


@spaces.GPU
def transcribe_whisperx(audio_file, task):
    # Validate input before loading the model so we don't waste GPU time
    if audio_file is None:
        raise gr.Error("Please upload or record an audio file before submitting!")

    # Load the WhisperX model; pass `task` through so the
    # "transcribe"/"translate" radio selection actually takes effect
    model = whisperx.load_model(
        "large-v3", device=device, compute_type=compute_type, task=task
    )

    # Load the audio file
    audio = whisperx.load_audio(audio_file)

    # Initial transcription
    result = model.transcribe(audio, batch_size=batch_size)

    # Release the ASR model to avoid running out of GPU memory
    del model
    torch.cuda.empty_cache()

    # Load the alignment model and align the transcription
    model_a, metadata = whisperx.load_align_model(
        language_code=result["language"], device=device
    )
    result = whisperx.align(
        result["segments"], model_a, metadata, audio, device,
        return_char_alignments=False,
    )

    # Speaker diarization
    hf_token = os.getenv("HF_TOKEN")
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
    diarize_segments = diarize_model(audio_file)
    result = whisperx.assign_word_speakers(diarize_segments, result)

    # Format the output text
    output_text = ""
    for segment in result["segments"]:
        speaker = segment.get("speaker", "Unknown")
        text = segment["text"]
        output_text += f"{speaker}: {text}\n"

    return output_text


# Gradio UI
demo = gr.Blocks(theme=gr.themes.Ocean())

transcribe_interface = gr.Interface(
    fn=transcribe_whisperx,
    inputs=[
        gr.Audio(sources=["microphone", "upload"], type="filepath"),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
    ],
    outputs="text",
    title="WhisperX: Transcribe and Diarize Audio",
    description="Transcribe and diarize audio files or microphone input with WhisperX.",
)

with demo:
    # render() is required to mount an Interface that was created
    # outside the Blocks context; merely naming the variable does nothing
    transcribe_interface.render()

demo.queue().launch(ssr_mode=False)
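# Note: WhisperX's diarization pipeline relies on gated pyannote models, so
# HF_TOKEN must hold a Hugging Face token whose account has accepted the
# pyannote model terms on the Hub. A sketch of a local run, assuming this file
# is saved as app.py (hypothetical filename) and the dependencies (whisperx,
# torch, gradio, spaces) are installed; hf_xxx is a placeholder token:
#
#   HF_TOKEN=hf_xxx python app.py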