import os
import csv
import glob
from tqdm import tqdm
import torch
import torchaudio
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
def calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths):
"""
计算叠加的音频与原始音频之间的 SDR 和 SI-SDR。
参数:
- original_audio_path: str, 原始音频文件路径。
- separated_audio_paths: List[str], 分割后的音频片段文件路径列表。
返回:
- sdr: float, SDR 值。
- sisdr: float, SI-SDR 值。
"""
    # Load the original audio
    original_waveform, sample_rate = torchaudio.load(original_audio_path)
    # Initialize the summed waveform
    combined_waveform = None
    # Load and sum the separated segments
    for path in separated_audio_paths:
        separated_waveform, _ = torchaudio.load(path)
        if combined_waveform is None:
            combined_waveform = separated_waveform
        else:
            # Truncate both to the shorter length before summing so that segments
            # of slightly different lengths can still be mixed back together
            min_length = min(combined_waveform.size(1), separated_waveform.size(1))
            combined_waveform = combined_waveform[:, :min_length] + separated_waveform[:, :min_length]
    # Make sure the combined audio and the original audio have the same length
    min_length = min(original_waveform.size(1), combined_waveform.size(1))
    original_waveform = original_waveform[:, :min_length]
    combined_waveform = combined_waveform[:, :min_length]
    # Compute SI-SDR
    sisdr_metric = ScaleInvariantSignalDistortionRatio()
    sisdr = sisdr_metric(combined_waveform, original_waveform).item()
    # Compute SDR
    sdr_metric = SignalDistortionRatio()
    sdr = sdr_metric(combined_waveform, original_waveform).item()
    # print(f"SI-SDR between original and combined audio: {sisdr} dB")
    # print(f"SDR between original and combined audio: {sdr} dB")
    return sdr, sisdr
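
# A minimal self-check sketch (not part of the original pipeline; the scratch
# file names below are hypothetical): comparing a slightly perturbed copy of a
# synthetic signal against the clean reference should give a large, finite
# SDR / SI-SDR, which is a quick way to verify the metric wiring end to end.
def _self_check(clean_path='./_sdr_check_clean.wav', noisy_path='./_sdr_check_noisy.wav'):
    sample_rate = 16000
    torch.manual_seed(0)
    clean = 0.1 * torch.randn(1, sample_rate)           # 1 s of noise-like audio, shape (channels, samples)
    noisy = clean + 1e-3 * torch.randn(1, sample_rate)  # weak additive perturbation
    torchaudio.save(clean_path, clean, sample_rate)
    torchaudio.save(noisy_path, noisy, sample_rate)
    try:
        sdr, sisdr = calculate_sdr_and_sisdr(clean_path, [noisy_path])
        print(f"self-check: SDR={sdr:.1f} dB, SI-SDR={sisdr:.1f} dB")  # both should be roughly 40 dB with these levels
    finally:
        for p in (clean_path, noisy_path):
            if os.path.exists(p):
                os.remove(p)
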
if __name__ == "__main__":
    # Example: specify the original audio and the separated segment paths
    # original_audio_path = "path_to_original_audio.wav"
    # separated_audio_paths = [
    #     "path_to_segment_1.wav",
    #     "path_to_segment_2.wav",
    #     "path_to_segment_3.wav",
    # ]
    # # Compute SDR and SI-SDR
    # sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
    dset = 'balanced_train_segments'
    # dset = 'eval_segments'
    src_data_root = r'/data/sound/audioset/audios_32k'
    sep_data_root = r'data_engine_infer/audioset_separation_child_label'
    # newline='' avoids blank rows on Windows; the with-block ensures the CSV is flushed and closed
    with open(os.path.join(sep_data_root, dset + '.csv'), 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['video', 'sdr', 'sisdr'])
        for video_path in tqdm(glob.glob(os.path.join(sep_data_root, dset, '*'))):
            video = video_path.split('/')[-1]
            original_audio_path = os.path.join(src_data_root, dset, video + '.wav')
            separated_audio_paths = glob.glob(video_path + '/*')
            sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
            writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])
    # dset = 'unbalanced_train_segments'
    # src_data_root = r'/data/sound/audioset/audios_32k'
    # sep_data_root = r'data_engine_infer/audioset_separation_child_label'
    # with open(os.path.join(sep_data_root, dset + '.csv'), 'w', newline='') as csv_file:
    #     writer = csv.writer(csv_file)
    #     writer.writerow(['video', 'sdr', 'sisdr'])
    #     for video_path in tqdm(glob.glob(os.path.join(sep_data_root, dset, '*', '*'))):
    #         part = video_path.split('/')[-2]
    #         video = video_path.split('/')[-1]
    #         original_audio_path = os.path.join(src_data_root, dset, part, video + '.wav')
    #         separated_audio_paths = glob.glob(video_path + '/*')
    #         sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
    #         writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])
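
    # A hypothetical post-processing sketch (not in the original script): once the
    # CSV above has been written, the mean SDR / SI-SDR over all clips can be
    # reported like this. It relies only on the columns written above.
    # rows = list(csv.DictReader(open(os.path.join(sep_data_root, dset + '.csv'))))
    # mean_sdr = sum(float(r['sdr']) for r in rows) / len(rows)
    # mean_sisdr = sum(float(r['sisdr']) for r in rows) / len(rows)
    # print(f'{dset}: mean SDR = {mean_sdr:.3f} dB, mean SI-SDR = {mean_sisdr:.3f} dB')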