Spaces:
Sleeping
Sleeping
import os
import sys
import re
import csv
import pathlib
from typing import Dict, List, Tuple

import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import librosa
import lightning.pytorch as pl

from models.clap_encoder import CLAP_Encoder

# The path tweak must run before `utils` is imported.
sys.path.append('../AudioSep/')
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)
class MUSICEvaluator:
    """Evaluate a text-queried separation model on the MUSIC test set.

    The constructor loads the evaluation metadata CSV into memory. Calling
    the instance with a model separates every listed mixture using the
    row's caption as the text query, and returns class-averaged scores.
    """

    def __init__(
        self,
        sampling_rate=32000,
        metadata_csv='evaluation/metadata/music_eval.csv',
        audio_dir='evaluation/data/music',
    ) -> None:
        """Load the evaluation list.

        Args:
            sampling_rate: Rate (Hz) that audio files are resampled to on load.
            metadata_csv: Path of the metadata CSV. The first row is treated
                as a header and dropped; in each remaining row the first
                column is the item index and the second is the caption.
            audio_dir: Directory holding ``segment-{idx}.wav`` and
                ``mixture-{idx}.wav`` for every item index in the CSV.

        Raises:
            OSError: If ``metadata_csv`` cannot be opened.
        """
        self.sampling_rate = sampling_rate

        # newline='' is what the csv module expects from the file object.
        with open(metadata_csv, newline='', encoding='utf-8') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            next(csv_reader, None)  # skip the header row
            # Blank lines come back as empty rows; drop them so that the
            # per-row unpacking in __call__ does not fail on them.
            self.eval_list = [row for row in csv_reader if row]

        self.audio_dir = audio_dir

        # Captions in the metadata are expected to be one of these labels.
        self.source_types = [
            "acoustic guitar",
            "violin",
            "accordion",
            "xylophone",
            "erhu",
            "trumpet",
            "tuba",
            "cello",
            "flute",
            "saxophone",
        ]

    def __call__(
        self,
        pl_model: "pl.LightningModule"
    ) -> Tuple[float, float]:
        """Evaluate ``pl_model`` on every item of the evaluation list.

        The model is switched to eval mode and run under ``torch.no_grad()``.
        It must expose ``device``, ``query_encoder.get_query_embed`` and
        ``ss_model`` as used below.

        Args:
            pl_model: Model to evaluate.

        Returns:
            ``(mean_sisdr, mean_sdri)``: the per-class means of the
            ``calculate_sisdr`` score and of the SDR improvement
            (``calculate_sdr`` of the separated output minus that of the raw
            mixture), averaged over the source classes that have at least
            one evaluated item. Both values are NaN when nothing was
            evaluated.

        Raises:
            KeyError: If a caption is not one of ``self.source_types``.
        """
        print('Evaluation on MUSIC Test with [text label] queries.')

        pl_model.eval()
        device = pl_model.device

        sisdrs_list = {source_type: [] for source_type in self.source_types}
        sdris_list = {source_type: [] for source_type in self.source_types}

        with torch.no_grad():
            for eval_data in tqdm(self.eval_list):
                # Only the first two columns are used.
                idx, caption, *_ = eval_data

                # Fail before the expensive load/separation rather than after.
                if caption not in sisdrs_list:
                    raise KeyError(
                        f'Unknown source type {caption!r} for item {idx}; '
                        f'expected one of {self.source_types}'
                    )

                source_path = os.path.join(
                    self.audio_dir, f'segment-{idx}.wav')
                mixture_path = os.path.join(
                    self.audio_dir, f'mixture-{idx}.wav')

                source, _ = librosa.load(
                    source_path, sr=self.sampling_rate, mono=True)
                mixture, _ = librosa.load(
                    mixture_path, sr=self.sampling_rate, mono=True)

                # Baseline: score of the unprocessed mixture against the target.
                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=[caption],
                    device=device
                )

                input_dict = {
                    # Add leading batch and channel axes.
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = (
                    sep_segment.squeeze(0).squeeze(0).detach().cpu().numpy()
                )
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_list[caption].append(sisdr)
                sdris_list[caption].append(sdri)

        # Macro-average: mean per class first, then mean over classes.
        # Classes without any sample are left out, because np.mean([]) is NaN
        # and would otherwise turn the overall result into NaN.
        mean_sisdr_list = [
            np.mean(sisdrs_list[source_class])
            for source_class in self.source_types
            if sisdrs_list[source_class]
        ]
        mean_sdri_list = [
            np.mean(sdris_list[source_class])
            for source_class in self.source_types
            if sdris_list[source_class]
        ]

        if not mean_sisdr_list:
            # Nothing was evaluated (empty metadata).
            return float('nan'), float('nan')

        mean_sdri = np.mean(mean_sdri_list)
        mean_sisdr = np.mean(mean_sisdr_list)

        return mean_sisdr, mean_sdri