# NOTE: non-Python artifacts (page header, byte count, line-number gutter)
# from a web scrape were removed here so the module parses.
import os
import sys
import re
from typing import Dict, List, Tuple
import csv
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import pathlib
import librosa
import lightning.pytorch as pl
from models.clap_encoder import CLAP_Encoder
sys.path.append('../AudioSep/')
from utils import (
load_ss_model,
calculate_sdr,
calculate_sisdr,
parse_yaml,
get_mean_sdr_from_dict,
)
class MUSICEvaluator:
    r"""Evaluate source separation on the MUSIC test set with text-label queries.

    Metadata is loaded once from ``evaluation/metadata/music_eval.csv``;
    audio segments are read from ``evaluation/data/music`` during evaluation.
    """

    def __init__(
        self,
        sampling_rate: int = 32000,
    ) -> None:
        # All audio is resampled to this rate on load.
        self.sampling_rate = sampling_rate

        # Each row after the header is (idx, caption, _, _); captions are
        # assumed to match entries of self.source_types — TODO confirm
        # against the CSV contents.
        with open('evaluation/metadata/music_eval.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            self.eval_list = list(csv_reader)[1:]  # skip header row

        self.audio_dir = 'evaluation/data/music'

        # Instrument classes used as text queries; metrics are aggregated
        # per class, then averaged across classes.
        self.source_types = [
            "acoustic guitar",
            "violin",
            "accordion",
            "xylophone",
            "erhu",
            "trumpet",
            "tuba",
            "cello",
            "flute",
            "saxophone",
        ]

    def __call__(
        self,
        pl_model: pl.LightningModule,
    ) -> Tuple[float, float]:
        r"""Evaluate ``pl_model`` on the MUSIC test set.

        Args:
            pl_model: Lightning module exposing ``query_encoder`` (for text
                query embeddings) and ``ss_model`` (the separation network).

        Returns:
            Tuple ``(mean_sisdr, mean_sdri)``: SI-SDR and SDR improvement,
            each averaged within a source class and then across classes.
        """
        print('Evaluation on MUSIC Test with [text label] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = {source_type: [] for source_type in self.source_types}
        sdris_list = {source_type: [] for source_type in self.source_types}

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):
                idx, caption, _, _ = eval_data

                source_path = os.path.join(self.audio_dir, f'segment-{idx}.wav')
                mixture_path = os.path.join(self.audio_dir, f'mixture-{idx}.wav')

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                # SDR of the unprocessed mixture; baseline for the SDRi metric.
                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                # Text caption is the separation query for this segment.
                text = [caption]
                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device,
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                # NOTE(review): raises KeyError if a CSV caption is not one of
                # self.source_types — verify the metadata stays in sync.
                sisdrs_list[caption].append(sisdr)
                sdris_list[caption].append(sdri)

        # Average within each class first so classes with more segments do
        # not dominate the overall mean.
        mean_sisdr_list = []
        mean_sdri_list = []
        for source_class in self.source_types:
            mean_sisdr_list.append(np.mean(sisdrs_list[source_class]))
            mean_sdri_list.append(np.mean(sdris_list[source_class]))

        mean_sdri = np.mean(mean_sdri_list)
        mean_sisdr = np.mean(mean_sisdr_list)

        return mean_sisdr, mean_sdri