speaker / app.py
QLWD's picture
Update app.py
6d92921 verified
raw
history blame
8.7 kB
import torch
import os
import gradio as gr
from pyannote.audio import Pipeline
from pydub import AudioSegment
from spaces import GPU
# 获取 Hugging Face 认证令牌
HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
pipeline = None
# 尝试加载 pyannote 模型
try:
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline.to(device)
except Exception as e:
print(f"Error initializing pipeline: {e}")
pipeline = None
# 时间戳转换为秒
def timestamp_to_seconds(timestamp):
h, m, s = map(float, timestamp.split(':'))
return 3600 * h + 60 * m + s
def convert_to_wav(audio_file):
try:
# 使用 pydub 打开上传的音频文件
audio = AudioSegment.from_file(audio_file)
# 创建一个 BytesIO 对象以存储转换后的音频
wav_output = BytesIO()
# 将音频转换为 wav 格式并存储在 BytesIO 对象中
audio.export(wav_output, format="wav")
# 将 BytesIO 对象的位置重置为开始位置,以便之后可以读取
wav_output.seek(0)
return wav_output # 返回转换后的 wav 音频
except Exception as e:
return f"音频转换失败: {e}"
# 音频拼接函数:拼接目标音频和混合音频,返回目标音频的起始时间和结束时间作为字典
def combine_audio_with_time(target_audio, mixed_audio):
if pipeline is None:
return "错误: 模型未初始化"
# 打印文件路径,确保文件正确传递
print(f"目标音频文件路径: {target_audio}")
print(f"混合音频文件路径: {mixed_audio}")
# 加载目标说话人的样本音频
try:
target_audio_segment = convert_to_wav(AudioSegment.from_wav(target_audio))
except Exception as e:
return f"加载目标音频时出错: {e}"
# 加载混合音频
try:
mixed_audio_segment = convert_to_wav(AudioSegment.from_wav(mixed_audio))
except Exception as e:
return f"加载混合音频时出错: {e}"
# 记录目标说话人音频的时间点(精确到0.01秒)
target_start_time = len(mixed_audio_segment) / 1000 # 秒为单位,精确到 0.01 秒
# 目标音频的结束时间(拼接后的音频长度)
target_end_time = target_start_time + len(target_audio_segment) / 1000 # 秒为单位
# 将目标说话人的音频片段添加到混合音频的最后
final_audio = mixed_audio_segment + target_audio_segment
final_audio.export("final_output.wav", format="wav")
# 返回目标音频的起始时间和结束时间
return {"start_time": target_start_time, "end_time": target_end_time}
# 使用 pyannote/speaker-diarization 对拼接后的音频进行说话人分离
@GPU(duration=60 * 2) # 使用 GPU 加速,限制执行时间为 120 秒
def diarize_audio(temp_file):
if pipeline is None:
return "错误: 模型未初始化"
try:
diarization = pipeline(temp_file)
print("说话人分离结果:")
for turn, _, speaker in diarization.itertracks(yield_label=True):
print(f"[{turn.start:.3f} --> {turn.end:.3f}] {speaker}")
return diarization
except Exception as e:
return f"处理音频时出错: {e}"
# 查找最匹配的说话人
def find_best_matching_speaker(target_start_time, target_end_time, diarization):
best_match = None
max_overlap = 0
# 遍历所有说话人时间段,计算与目标音频的重叠部分
for turn, _, speaker in diarization.itertracks(yield_label=True):
start = turn.start
end = turn.end
# 计算重叠部分的开始和结束时间
overlap_start = max(start, target_start_time)
overlap_end = min(end, target_end_time)
# 如果有重叠部分,计算重叠的持续时间
if overlap_end > overlap_start:
overlap_duration = overlap_end - overlap_start
# 如果当前重叠部分更大,则更新最匹配的说话人
if overlap_duration > max_overlap:
max_overlap = overlap_duration
best_match = speaker
return best_match, max_overlap
# 获取目标说话人的时间段(排除目标音频时间段)
def get_speaker_segments(diarization, target_start_time, target_end_time, final_audio_length):
speaker_segments = {}
# 遍历所有说话人时间段
for turn, _, speaker in diarization.itertracks(yield_label=True):
start = turn.start
end = turn.end
# 如果时间段与目标音频有重叠,需要截断
if start < target_end_time and end > target_start_time:
# 记录被截断的时间段
if start < target_start_time:
# 目标音频开始前的时间段
speaker_segments.setdefault(speaker, []).append((start, min(target_start_time, end)))
if end > target_end_time:
# 目标音频结束后的时间段
speaker_segments.setdefault(speaker, []).append((max(target_end_time, start), min(end, final_audio_length)))
else:
# 完全不与目标音频重叠的时间段
if end <= target_start_time or start >= target_end_time:
speaker_segments.setdefault(speaker, []).append((start, end))
return speaker_segments
# 处理音频文件并返回输出
def process_audio(target_audio, mixed_audio):
print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")
# 进行音频拼接并返回目标音频的起始和结束时间(作为字典)
time_dict = combine_audio_with_time(target_audio, mixed_audio)
# 如果音频拼接出错,返回错误信息
if isinstance(time_dict, str):
return time_dict
# 执行说话人分离
diarization_result = diarize_audio("final_output.wav")
if isinstance(diarization_result, str) and diarization_result.startswith("错误"):
return diarization_result # 出错时返回错误信息
else:
# 获取拼接后的音频长度
final_audio_length = len(AudioSegment.from_wav("final_output.wav")) / 1000 # 秒为单位
# 查找最匹配的说话人
best_match, overlap_duration = find_best_matching_speaker(
time_dict['start_time'],
time_dict['end_time'],
diarization_result
)
if best_match:
# 获取目标说话人的时间段(排除和截断目标音频时间段)
speaker_segments = get_speaker_segments(
diarization_result,
time_dict['start_time'],
time_dict['end_time'],
final_audio_length
)
if best_match in speaker_segments:
# 拼接所有片段
final_output = AudioSegment.empty()
for segment in speaker_segments[best_match]:
start_time_ms = int(segment[0] * 1000) # 转为毫秒
end_time_ms = int(segment[1] * 1000)
segment_audio = AudioSegment.from_wav("final_output.wav")[start_time_ms:end_time_ms]
final_output += segment_audio
# 导出最终拼接音频
final_output.export("final_combined_output.wav", format="wav")
return "final_combined_output.wav"
else:
return "没有找到匹配的说话人时间段。"
else:
return "未找到匹配的说话人。"
# Gradio 接口
with gr.Blocks() as demo:
gr.Markdown("""
# 🗣️ 音频拼接与说话人分类 🗣️
上传目标音频和混合音频,拼接并进行说话人分类。
结果包括目标说话人(SPEAKER_00)的时间段,已排除和截断目标录音时间段,并自动剪辑目标音频。
""")
mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
process_button = gr.Button("处理音频")
# 输出结果
output_audio = gr.Audio(label="剪辑后的音频")
# 点击按钮时触发处理音频
process_button.click(
fn=process_audio,
inputs=[target_audio_input, mixed_audio_input],
outputs=[output_audio]
)
demo.launch(share=True)