|
|
import os |
|
|
import csv |
|
|
import glob |
|
|
from tqdm import tqdm |
|
|
import torch |
|
|
import torchaudio |
|
|
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio |
|
|
|
|
|
|
|
|
def calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths):
    """Compute SDR and SI-SDR between the summed separated clips and the original audio.

    Every separated clip is loaded and added sample-wise into one mixture.
    The mixture is used as the prediction and the original recording as the
    target. All signals are truncated along dimension 1 to the shortest
    length among the original and the separated clips.

    NOTE(review): sample rates of the loaded files are not compared; this
    assumes every file shares the original's sample rate -- confirm.

    Args:
        original_audio_path: str, path to the original audio file.
        separated_audio_paths: List[str], paths of the separated audio clips.

    Returns:
        A tuple ``(sdr, sisdr)``:
        - sdr: float, the SDR value.
        - sisdr: float, the SI-SDR value.

    Raises:
        ValueError: If ``separated_audio_paths`` is empty.
    """
    # Fail early with a clear message; without any clip there is nothing to
    # sum, and the code below would otherwise crash on a ``None`` mixture.
    if not separated_audio_paths:
        raise ValueError(
            "separated_audio_paths is empty; at least one separated clip is required"
        )

    original_waveform, _ = torchaudio.load(original_audio_path)

    combined_waveform = None
    for path in separated_audio_paths:
        separated_waveform, _ = torchaudio.load(path)

        if combined_waveform is None:
            # The first clip is capped at the original's length.
            length = min(original_waveform.size(1), separated_waveform.size(1))
            combined_waveform = separated_waveform[:, :length]
        else:
            # Trim BOTH operands to their shared length. The running sum may
            # already be shorter than this clip, and adding tensors of
            # different lengths would fail.
            length = min(combined_waveform.size(1), separated_waveform.size(1))
            combined_waveform = (
                combined_waveform[:, :length] + separated_waveform[:, :length]
            )

    # Align the reference with the (possibly shorter) mixture.
    min_length = min(original_waveform.size(1), combined_waveform.size(1))
    original_waveform = original_waveform[:, :min_length]
    combined_waveform = combined_waveform[:, :min_length]

    # Both metrics take (preds, target): the mixture is the prediction and
    # the original recording is the target.
    sisdr_metric = ScaleInvariantSignalDistortionRatio()
    sisdr = sisdr_metric(combined_waveform, original_waveform).item()

    sdr_metric = SignalDistortionRatio()
    sdr = sdr_metric(combined_waveform, original_waveform).item()

    return sdr, sisdr
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Split name; used both as the sub-directory under each data root and as
    # the stem of the output CSV file.
    dset = 'balanced_train_segments'

    # Root holding the original recordings: <src_data_root>/<dset>/<video>.wav
    src_data_root = r'/data/sound/audioset/audios_32k'
    # Root holding the separated clips: <sep_data_root>/<dset>/<video>/*
    sep_data_root = r'data_engine_infer/audioset_separation_child_label'

    csv_path = os.path.join(sep_data_root, dset + '.csv')
    # The context manager guarantees the file is flushed and closed even if a
    # later item fails; newline='' is what the csv module expects so that it
    # controls row terminators itself.
    with open(csv_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['video', 'sdr', 'sisdr'])

        # Sorted so that the row order of the CSV is reproducible across runs.
        video_paths = sorted(glob.glob(os.path.join(sep_data_root, dset, '*')))
        for video_path in tqdm(video_paths):
            video = os.path.basename(video_path)
            original_audio_path = os.path.join(src_data_root, dset, video + '.wav')
            separated_audio_paths = glob.glob(video_path + '/*')

            # A stray file or an empty directory yields no clips; skip it with
            # a notice instead of aborting the whole run.
            if not separated_audio_paths:
                tqdm.write(f'skip {video}: no separated audio files found')
                continue

            sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
            writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|