QLWD committed (verified)
Commit 340e149 · Parent(s): ba5925c

Update app.py

Files changed (1)
  1. app.py +74 -46
app.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-import spaces
-import gradio as gr
 import os
+import gradio as gr
 from pyannote.audio import Pipeline
 from pydub import AudioSegment
+from spaces import GPU

 # Get the Hugging Face access token
 HF_TOKEN = os.environ.get("HUGGINGFACE_READ_TOKEN")
@@ -60,7 +60,7 @@ def combine_audio_with_time(target_audio, mixed_audio):
     return {"start_time": target_start_time, "end_time": target_end_time}

 # Run speaker diarization on the concatenated audio with pyannote/speaker-diarization
-@spaces.GPU(duration=60 * 2)  # Use GPU acceleration; cap execution time at 120 seconds
+@GPU(duration=60 * 2)  # Use GPU acceleration; cap execution time at 120 seconds
 def diarize_audio(temp_file):
     if pipeline is None:
         return "Error: model not initialized"
@@ -74,6 +74,31 @@ def diarize_audio(temp_file):
     except Exception as e:
         return f"Error while processing audio: {e}"

+# Find the speaker that best matches the target audio
+def find_best_matching_speaker(target_start_time, target_end_time, diarization):
+    best_match = None
+    max_overlap = 0
+
+    # Iterate over all speaker turns and compute their overlap with the target audio
+    for turn, _, speaker in diarization.itertracks(yield_label=True):
+        start = turn.start
+        end = turn.end
+
+        # Compute the start and end of the overlapping portion
+        overlap_start = max(start, target_start_time)
+        overlap_end = min(end, target_end_time)
+
+        # If there is an overlap, compute its duration
+        if overlap_end > overlap_start:
+            overlap_duration = overlap_end - overlap_start
+
+            # If this overlap is larger, update the best-matching speaker
+            if overlap_duration > max_overlap:
+                max_overlap = overlap_duration
+                best_match = speaker
+
+    return best_match, max_overlap
+
 # Get the target speaker's segments (excluding the target audio's own time range)
 def get_speaker_segments(diarization, target_start_time, target_end_time, final_audio_length):
     speaker_segments = {}
@@ -83,34 +108,23 @@ def get_speaker_segments(diarization, target_start_time, target_end_time, final_audio_length):
         start = turn.start
         end = turn.end

-        # If this is the target speaker
-        if speaker == 'SPEAKER_00':
-            # If the turn overlaps the target audio, it has to be truncated
-            if start < target_end_time and end > target_start_time:
-                # Record the truncated segments
-                if start < target_start_time:
-                    # Part before the target audio starts
-                    speaker_segments.setdefault(speaker, []).append((start, min(target_start_time, end)))
-
-                if end > target_end_time:
-                    # Part after the target audio ends
-                    speaker_segments.setdefault(speaker, []).append((max(target_end_time, start), min(end, final_audio_length)))
-            else:
-                # Turns that do not overlap the target audio at all
-                if end <= target_start_time or start >= target_end_time:
-                    speaker_segments.setdefault(speaker, []).append((start, end))
+        # If the turn overlaps the target audio, it has to be truncated
+        if start < target_end_time and end > target_start_time:
+            # Record the truncated segments
+            if start < target_start_time:
+                # Part before the target audio starts
+                speaker_segments.setdefault(speaker, []).append((start, min(target_start_time, end)))
+
+            if end > target_end_time:
+                # Part after the target audio ends
+                speaker_segments.setdefault(speaker, []).append((max(target_end_time, start), min(end, final_audio_length)))
+        else:
+            # Turns that do not overlap the target audio at all
+            if end <= target_start_time or start >= target_end_time:
+                speaker_segments.setdefault(speaker, []).append((start, end))

     return speaker_segments

-# Clip audio: cut the audio according to the given time segments
-def clip_audio(audio_segment, segments):
-    clips = []
-    for start, end in segments:
-        start_ms = int(start * 1000)  # milliseconds
-        end_ms = int(end * 1000)  # milliseconds
-        clips.append(audio_segment[start_ms:end_ms])
-    return clips
-
 # Process the audio files and return the output
 def process_audio(target_audio, mixed_audio):
     print(f"Processing audio - target audio: {target_audio}, mixed audio: {mixed_audio}")
@@ -131,30 +145,44 @@ def process_audio(target_audio, mixed_audio):
     # Get the length of the concatenated audio
     final_audio_length = len(AudioSegment.from_wav("final_output.wav")) / 1000  # in seconds

-    # Get the target speaker's segments (excluding and truncating the target audio's time range)
-    speaker_segments = get_speaker_segments(
-        diarization_result,
+    # Find the best-matching speaker
+    best_match, overlap_duration = find_best_matching_speaker(
         time_dict['start_time'],
         time_dict['end_time'],
-        final_audio_length
+        diarization_result
     )

-    if speaker_segments and 'SPEAKER_00' in speaker_segments:
-        # Clip the target speaker's audio segments
-        final_audio_segment = AudioSegment.from_wav("final_output.wav")
-        clips = clip_audio(final_audio_segment, speaker_segments['SPEAKER_00'])
-
-        # Export the clipped segments as separate files
-        output_files = []
-        for i, clip in enumerate(clips):
-            clip_path = f"speaker_00_clip_{i + 1}.wav"
-            clip.export(clip_path, format="wav")
-            output_files.append(clip_path)
+    if best_match:
+        # Get the target speaker's segments (excluding and truncating the target audio's time range)
+        speaker_segments = get_speaker_segments(
+            diarization_result,
+            time_dict['start_time'],
+            time_dict['end_time'],
+            final_audio_length
+        )

-        # Return the paths of the clipped audio files
-        return output_files
+        if best_match in speaker_segments:
+            # Concatenate all of that speaker's segments
+            final_output = AudioSegment.empty()
+            for segment in speaker_segments[best_match]:
+                start_time_ms = int(segment[0] * 1000)  # convert to milliseconds
+                end_time_ms = int(segment[1] * 1000)
+                segment_audio = AudioSegment.from_wav("final_output.wav")[start_time_ms:end_time_ms]
+                final_output += segment_audio
+
+            # Export the final concatenated audio
+            final_output.export("final_combined_output.wav", format="wav")
+
+            return {
+                'best_matching_speaker': best_match,
+                'overlap_duration': overlap_duration,
+                'segments': speaker_segments[best_match],
+                'final_audio': "final_combined_output.wav"
+            }
+        else:
+            return "No segments were found for the matched speaker."
     else:
-        return "No segments were found for SPEAKER_00."
+        return "No matching speaker was found."

 # Gradio interface
 with gr.Blocks() as demo:
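
For reference, the flow this commit introduces (find the speaker with the largest overlap against the target span, keep that speaker's turns outside the span, then cut and join them with pydub) can be sketched end to end as below. This is a minimal illustration, not the app's code: the plain (start, end, speaker) tuples stand in for pyannote's diarization output, the 12-second silent clip stands in for the concatenated recording, and the output filename is hypothetical.

from pydub import AudioSegment

# Toy speaker turns standing in for pyannote's diarization output: (start_s, end_s, label).
turns = [
    (0.0, 4.0, "SPEAKER_00"),   # the reference clip itself, placed at the front of the mix
    (4.0, 9.5, "SPEAKER_01"),
    (9.5, 12.0, "SPEAKER_00"),
]
target_start, target_end = 0.0, 4.0   # where the target audio sits in the combined file

# 1) Pick the speaker whose turns overlap the target span the most (find_best_matching_speaker).
best_match, max_overlap = None, 0.0
for start, end, speaker in turns:
    overlap = min(end, target_end) - max(start, target_start)
    if overlap > max_overlap:
        best_match, max_overlap = speaker, overlap

# 2) Keep that speaker's turns outside the target span, truncating any turn that
#    crosses its boundaries (get_speaker_segments).
segments = []
for start, end, speaker in turns:
    if speaker != best_match:
        continue
    if end <= target_start or start >= target_end:
        segments.append((start, end))               # no overlap: keep the whole turn
    else:
        if start < target_start:
            segments.append((start, target_start))  # part before the target audio
        if end > target_end:
            segments.append((target_end, end))      # part after the target audio

# 3) Cut those spans out of the combined audio and join them (the export step in process_audio).
#    12 s of silence is used here so the sketch runs without any real input files.
combined = AudioSegment.silent(duration=12_000)
final_output = AudioSegment.empty()
for start, end in segments:
    final_output += combined[int(start * 1000):int(end * 1000)]
final_output.export("matched_speaker_sketch.wav", format="wav")

print(best_match, max_overlap, segments)  # SPEAKER_00 4.0 [(9.5, 12.0)]

Because the reference clip's own turn falls entirely inside the target span, it contributes no output segments; only the matched speaker's later turn at 9.5 to 12.0 s survives, which is the behaviour the truncation logic in the commit is designed to produce.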