QLWD committed on
Commit 1cdea95 · verified · 1 Parent(s): 21019ce

Update app.py

Files changed (1)
  app.py +150 -127
app.py CHANGED
@@ -3,162 +3,185 @@ import spaces
  import gradio as gr
  import os
  from pyannote.audio import Pipeline
  from pydub import AudioSegment

  # Get the Hugging Face auth token
  HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
- pipeline = None
-
- # Try to load the pyannote model
- try:
-     pipeline = Pipeline.from_pretrained(
-         "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
-     )
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     pipeline.to(device)
- except Exception as e:
-     print(f"Error initializing pipeline: {e}")
-     pipeline = None
-
- # Audio concatenation: append the target audio to the mixed audio and return the
- # target segment's start and end times as a dict
- def combine_audio_with_time(target_audio, mixed_audio):
-     if pipeline is None:
-         return "错误: 模型未初始化"
-
-     # Print the file paths to confirm they were passed correctly
-     print(f"目标音频文件路径: {target_audio}")
-     print(f"混合音频文件路径: {mixed_audio}")
-
-     # Load the target speaker's sample audio
-     try:
-         target_audio_segment = AudioSegment.from_wav(target_audio)
-     except Exception as e:
-         return f"加载目标音频时出错: {e}"
-
-     # Load the mixed audio
-     try:
-         mixed_audio_segment = AudioSegment.from_wav(mixed_audio)
-     except Exception as e:
-         return f"加载混合音频时出错: {e}"
-
-     # Record the start time of the target speaker's audio (seconds, to 0.01 s)
-     target_start_time = len(mixed_audio_segment) / 1000  # in seconds
-
-     # End time of the target audio (length of the concatenated audio)
-     target_end_time = target_start_time + len(target_audio_segment) / 1000  # in seconds
-
-     # Append the target speaker's clip to the end of the mixed audio
-     final_audio = mixed_audio_segment + target_audio_segment
-     final_audio.export("final_output.wav", format="wav")
-
-     # Return the start and end times of the target audio
-     return {"start_time": target_start_time, "end_time": target_end_time}
-
- # Run speaker diarization on the concatenated audio with pyannote/speaker-diarization
- @spaces.GPU(duration=60 * 2)  # Use GPU acceleration, limited to 120 seconds
- def diarize_audio(temp_file):
-     if pipeline is None:
-         return "错误: 模型未初始化"
-
-     try:
-         diarization = pipeline(temp_file)
-     except Exception as e:
-         return f"处理音频时出错: {e}"
-     print(diarization)
-     print(type(diarization))
-     # Return the diarization output as a string
-     return str(diarization)
-
- # Convert a timestamp to seconds
- def timestamp_to_seconds(timestamp):
-     try:
-         h, m, s = map(float, timestamp.split(':'))
-         return 3600 * h + 60 * m + s
-     except ValueError as e:
-         print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
-         return None
-
- # Compute the overlap between two time ranges (in seconds)
- def calculate_overlap(start1, end1, start2, end2):
-     overlap_start = max(start1, start2)
-     overlap_end = min(end1, end2)
-     overlap_duration = max(0, overlap_end - overlap_start)
-     return overlap_duration
-
- # Get the overlap ratio between the target time range and each speaker segment
- def get_best_match(target_time, diarization_output):
-     target_start_time = target_time['start_time']
-     target_end_time = target_time['end_time']
-
-     # Assume diarization_output is a list of speaker segments and labels
-     speaker_segments = []
-     for line in diarization_output.strip().split('\n'):
          try:
-             parts = line.strip()[1:-1].split(' --> ')
-             start_time = parts[0].strip()
-             end_time = parts[1].split(']')[0].strip()
-             label = line.split()[-1].strip()
-
-             start_seconds = timestamp_to_seconds(start_time)
-             end_seconds = timestamp_to_seconds(end_time)
-
-             # Compute the overlap between the target time range and this speaker segment
-             overlap = calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
-             overlap_ratio = overlap / (end_seconds - start_seconds)
-
-             # Record the speaker label and overlap ratio
-             speaker_segments.append((label, overlap_ratio, start_seconds, end_seconds))
          except Exception as e:
-             print(f"处理行时出错: '{line.strip()}'. 错误: {e}")
-
-     # Sort by overlap ratio and return the segment with the largest overlap
-     best_match = max(speaker_segments, key=lambda x: x[1], default=None)
-
-     return best_match
-
- # Process the audio files and return the output
- def process_audio(target_audio, mixed_audio):
-     # Print the file paths to confirm the inputs are valid
-     print(f"处理音频:目标音频: {target_audio}, 混合音频: {mixed_audio}")
-
-     # Concatenate the audio and return the target clip's start and end times (as a dict)
-     time_dict = combine_audio_with_time(target_audio, mixed_audio)
-
-     # Run speaker diarization
-     diarization_result = diarize_audio("final_output.wav")
-
-     if diarization_result.startswith("错误"):
-         return diarization_result, None  # Return the error message on failure
-     else:
-         # Get the best-matching speaker label and time range
-         best_match = get_best_match(time_dict, diarization_result)

          if best_match:
-             # Return the best-matching speaker's label and time range
-             return best_match[0], best_match[2], best_match[3]

  # Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("""
      # 🗣️ 音频拼接与说话人分类 🗣️
-     上传目标音频和混合音频,拼接并进行说话人分类。结果包括最佳匹配说话人的时间段。
      """)
-
      mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
      target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
-
      process_button = gr.Button("处理音频")
-
      # Output results
      diarization_output = gr.Textbox(label="最佳匹配说话人")
      time_range_output = gr.Textbox(label="最佳匹配时间段")

      # Trigger audio processing when the button is clicked
      process_button.click(
-         fn=process_audio,
          inputs=[target_audio_input, mixed_audio_input],
          outputs=[diarization_output, time_range_output]
      )

- demo.launch(share=True)
 
  import gradio as gr
  import os
  from pyannote.audio import Pipeline
+ from pyannote.core import Annotation, Segment
  from pydub import AudioSegment

  # Get the Hugging Face auth token
  HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")

+ class AudioProcessor:
+     def __init__(self):
+         self.pipeline = None
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Try to load the pyannote model
+         try:
+             self.pipeline = Pipeline.from_pretrained(
+                 "pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN
+             )
+             self.pipeline.to(self.device)
+             print("pyannote model loaded successfully.")
+         except Exception as e:
+             print(f"Error initializing pipeline: {e}")
+             self.pipeline = None
+
+     # Audio concatenation: append the target audio to the mixed audio and return the
+     # target segment's start and end times as a dict
+     def combine_audio_with_time(self, target_audio, mixed_audio):
+         if self.pipeline is None:
+             return "错误: 模型未初始化"
+
+         # Load the target speaker's sample audio
          try:
+             target_audio_segment = AudioSegment.from_wav(target_audio)
+         except Exception as e:
+             return f"加载目标音频时出错: {e}"
+
+         # Load the mixed audio
+         try:
+             mixed_audio_segment = AudioSegment.from_wav(mixed_audio)
+         except Exception as e:
+             return f"加载混合音频时出错: {e}"
+
+         # Record the start time of the target speaker's audio (seconds, to 0.01 s)
+         target_start_time = len(mixed_audio_segment) / 1000  # in seconds
+
+         # End time of the target audio (length of the concatenated audio)
+         target_end_time = target_start_time + len(target_audio_segment) / 1000  # in seconds

+         # Append the target speaker's clip to the end of the mixed audio
+         final_audio = mixed_audio_segment + target_audio_segment
+         final_audio.export("final_output.wav", format="wav")

+         # Return the start and end times of the target audio
+         return {"start_time": target_start_time, "end_time": target_end_time}
 
+     # Run speaker diarization on the concatenated audio with pyannote/speaker-diarization
+     @spaces.GPU(duration=60 * 2)  # Use GPU acceleration, limited to 120 seconds
+     def diarize_audio(self, temp_file):
+         if self.pipeline is None:
+             return "错误: 模型未初始化"

+         try:
+             diarization = self.pipeline(temp_file)  # returns an Annotation object
          except Exception as e:
+             return f"处理音频时出错: {e}"
+
+         return diarization  # return the Annotation object directly
+
+     # Convert a timestamp to seconds
+     def timestamp_to_seconds(self, timestamp):
+         try:
+             h, m, s = map(float, timestamp.split(':'))
+             return 3600 * h + 60 * m + s
+         except ValueError as e:
+             print(f"转换时间戳时出错: '{timestamp}'. 错误: {e}")
+             return None
+
+     # Compute the overlap between two time ranges (in seconds)
+     def calculate_overlap(self, start1, end1, start2, end2):
+         overlap_start = max(start1, start2)
+         overlap_end = min(end1, end2)
+         overlap_duration = max(0, overlap_end - overlap_start)
+         return overlap_duration
+
+     # Get all of this speaker's segments outside the target audio's time range
+     def get_speaker_time_segments(self, diarization_output, target_time, speaker_label):
+         remaining_segments = []
+
+         # Iterate over the diarization output and find all of this speaker's segments
+         for segment, _, label in diarization_output.itertracks(yield_label=True):
+             if label == speaker_label:
+                 start_seconds = segment.start
+                 end_seconds = segment.end
+
+                 # Compute the overlap with the target audio
+                 overlap_start = max(start_seconds, target_time['start_time'])
+                 overlap_end = min(end_seconds, target_time['end_time'])
+
+                 # If there is an overlap, exclude the overlapping part
+                 if overlap_start < overlap_end:
+                     if start_seconds < overlap_start:
+                         remaining_segments.append((start_seconds, overlap_start))
+                     if overlap_end < end_seconds:
+                         remaining_segments.append((overlap_end, end_seconds))
+                 else:
+                     remaining_segments.append((start_seconds, end_seconds))
+
+         return remaining_segments
+
+     # Get the best-matching speaker and return that speaker's non-overlapping segments
+     def get_best_match(self, target_time, diarization_output):
+         target_start_time = target_time['start_time']
+         target_end_time = target_time['end_time']
+
+         # Store the overlap ratio for each speaker segment
+         speaker_segments = []
+         for segment, _, label in diarization_output.itertracks(yield_label=True):
+             try:
+                 start_seconds = segment.start
+                 end_seconds = segment.end
+
+                 # Compute the overlap between the target time range and this speaker segment
+                 overlap = self.calculate_overlap(target_start_time, target_end_time, start_seconds, end_seconds)
+                 overlap_ratio = overlap / (end_seconds - start_seconds)
+
+                 # Record the speaker label and overlap ratio
+                 speaker_segments.append((label, overlap_ratio, start_seconds, end_seconds))
+
+             except Exception as e:
+                 print(f"处理行时出错: '{segment}'. 错误: {e}")
+
+         # Sort by overlap ratio and return the segment with the largest overlap
+         best_match = max(speaker_segments, key=lambda x: x[1], default=None)

          if best_match:
+             speaker_label = best_match[0]  # label of the best-matching speaker
+             remaining_segments = self.get_speaker_time_segments(diarization_output, target_time, speaker_label)
+             return speaker_label, remaining_segments
+         else:
+             return None, None
+
+     # Process the audio files and return the output
+     def process_audio(self, target_audio, mixed_audio):
+         # Concatenate the audio and return the target clip's start and end times (as a dict)
+         time_dict = self.combine_audio_with_time(target_audio, mixed_audio)
+
+         # Run speaker diarization
+         diarization_result = self.diarize_audio("final_output.wav")
+
+         print(f"Diarization result type: {type(diarization_result)}")
+
+         if isinstance(diarization_result, str) and diarization_result.startswith("错误"):
+             return diarization_result, None  # Return the error message on failure
+         else:
+             # Get the best-matching speaker label and time segments
+             speaker_label, remaining_segments = self.get_best_match(time_dict, diarization_result)
+
+             if speaker_label and remaining_segments:
+                 # Return the best-matching speaker's label and time segments
+                 return speaker_label, remaining_segments
+             return None, None  # No match found; keep two outputs for the Gradio handler
 
  # Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("""
      # 🗣️ 音频拼接与说话人分类 🗣️
+     上传目标音频和混合音频,拼接并进行说话人分类。结果包括最佳匹配说话人的时间段(排除目标音频时间段)。
      """)
+
      mixed_audio_input = gr.Audio(type="filepath", label="上传混合音频")
      target_audio_input = gr.Audio(type="filepath", label="上传目标说话人音频")
+
      process_button = gr.Button("处理音频")
+
      # Output results
      diarization_output = gr.Textbox(label="最佳匹配说话人")
      time_range_output = gr.Textbox(label="最佳匹配时间段")

      # Trigger audio processing when the button is clicked
      process_button.click(
+         fn=AudioProcessor().process_audio,
          inputs=[target_audio_input, mixed_audio_input],
          outputs=[diarization_output, time_range_output]
      )

+ demo.launch(share=True)
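
For reference (not part of the commit): the key change above is that `diarize_audio` now returns the pyannote `Annotation` object instead of its string form, and the matching methods iterate it with `itertracks(yield_label=True)`, which yields `(segment, track, label)` triples. A minimal sketch of that iteration is shown below; the speaker labels and timings are made up for illustration only.

from pyannote.core import Annotation, Segment

# Build a toy diarization result (labels and timings are illustrative, not real output)
toy = Annotation()
toy[Segment(0.0, 2.5)] = "SPEAKER_00"
toy[Segment(2.5, 4.0)] = "SPEAKER_01"

# itertracks(yield_label=True) yields (segment, track, label) triples
for segment, _, label in toy.itertracks(yield_label=True):
    print(f"{label}: {segment.start:.2f}s -> {segment.end:.2f}s")

This matches the `for segment, _, label in ...` unpacking used in `get_speaker_time_segments` and `get_best_match` above.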