speaker / app.py
QLWD's picture
Update app.py
6d3bc8f verified
raw
history blame
4.95 kB
import torch
import spaces
import gradio as gr
import os
from pyannote.audio import Pipeline
from pydub import AudioSegment
# 初始化 pyannote/speaker-diarization 模型
HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
pipeline = None
try:
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline.to(device)
except Exception as e:
print(f"Error initializing pipeline: {e}")
pipeline = None
# 音频拼接函数:拼接目标音频和混合音频
def combine_audio_with_time(target_audio, mixed_audio):
if pipeline is None:
return "错误: 模型未初始化"
# 加载目标说话人的样本音频
target_audio_segment = AudioSegment.from_wav(target_audio)
# 加载混合音频
mixed_audio_segment = AudioSegment.from_wav(mixed_audio)
# 记录目标说话人音频的时间点(精确到0.01秒)
target_start_time = len(mixed_audio_segment) / 1000 # 秒为单位,精确到 0.01 秒
# 将目标说话人的音频片段添加到混合音频的最后
final_audio = mixed_audio_segment + target_audio_segment
# 保存拼接后的音频并返回时间点
final_audio.export("final_output.wav", format="wav")
return "final_output.wav", target_start_time
# 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离
@spaces.GPU(duration=60 * 2) # 使用 GPU 加速,限制执行时间为 120 秒
def diarize_audio(temp_file):
if pipeline is None:
return "错误: 模型未初始化"
try:
diarization = pipeline(temp_file)
except Exception as e:
return f"处理音频时出错: {e}"
# 返回 diarization 输出
return str(diarization)
# 生成标签文件的函数
def generate_labels_from_diarization(diarization_output):
labels_path = 'labels.txt'
successful_lines = 0
try:
with open(labels_path, 'w') as outfile:
lines = diarization_output.strip().split('\n')
for line in lines:
try:
parts = line.strip()[1:-1].split(' --> ')
start_time = parts[0].strip()
end_time = parts[1].split(']')[0].strip()
label = line.split()[-1].strip()
start_seconds = timestamp_to_seconds(start_time)
end_seconds = timestamp_to_seconds(end_time)
outfile.write(f"{start_seconds}\t{end_seconds}\t{label}\n")
successful_lines += 1
except Exception as e:
print(f"处理行时出错: '{line.strip()}'. 错误: {e}")
print(f"成功处理了 {successful_lines} 行。")
return labels_path if successful_lines > 0 else None
except Exception as e:
print(f"写入文件时出错: {e}")
return None
# 将时间戳转换为秒
def timestamp_to_seconds(timestamp):
try:
h, m, s = map(float, timestamp.split(':'))
return 3600 * h + 60 * m + s
except ValueError as e:
print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
return None
# 处理音频文件并返回输出
def process_audio(target_audio, mixed_audio):
# 进行音频拼接
final_audio_path, target_start_time = combine_audio_with_time(target_audio, mixed_audio)
# 执行说话人分离
diarization_result = diarize_audio(final_audio_path)
if diarization_result.startswith("错误"):
return diarization_result, None, None # 出错时返回错误信息
else:
# 生成标签文件
label_file = generate_labels_from_diarization(diarization_result)
return diarization_result, label_file, final_audio_path # 返回说话人分离结果、标签文件和剪辑后的音频路径
# Gradio 接口
with gr.Blocks() as demo:
gr.Markdown("""
# 🗣️ 音频拼接与说话人分类 🗣️
上传目标说话人音频和混合音频,拼接并进行说话人分类。结果包括说话人分离输出、标签文件和剪辑后的音频文件。
""")
target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
process_button = gr.Button("处理音频")
# 输出结果
diarization_output = gr.Textbox(label="说话人分离结果")
label_file_link = gr.File(label="下载标签文件")
# 修改为 gr.Audio 组件来返回音频
final_audio_link = gr.Audio(label="下载剪辑后的音频", type="file")
# 点击按钮时触发处理音频
process_button.click(
fn=process_audio,
inputs=[target_audio_input, mixed_audio_input],
outputs=[diarization_output, label_file_link, final_audio_link]
)
demo.launch(share=False)