# Evaluate AudioSep source separation on AudioSet with text-label queries.
import os
import sys
import re
from typing import Dict, List

import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import librosa
import lightning.pytorch as pl

sys.path.append('../AudioSep/')
from models.clap_encoder import CLAP_Encoder
from utils import (
    load_ss_model,
    calculate_sdr,
    calculate_sisdr,
    parse_yaml,
    get_mean_sdr_from_dict,
)
# Load the AudioSet class-label metadata and build a class-index -> label mapping.
meta_csv_file = "evaluation/metadata/class_labels_indices.csv"
df = pd.read_csv(meta_csv_file, sep=',')
IDS = df['mid'].tolist()
LABELS = df['display_name'].tolist()
CLASSES_NUM = len(LABELS)
IX_TO_LB = {i: label for i, label in enumerate(LABELS)}
class AudioSetEvaluator:
    def __init__(
        self,
        audios_dir='evaluation/data/audioset',
        classes_num=527,
        sampling_rate=32000,
        number_per_class=10,
    ) -> None:
        r"""AudioSet evaluator.

        Args:
            audios_dir (str): directory of evaluation segments
            classes_num (int): number of sound classes
            sampling_rate (int): sampling rate used to load the audio
            number_per_class (int): number of samples to evaluate per sound class

        Returns:
            None
        """
        self.audios_dir = audios_dir
        self.classes_num = classes_num
        self.number_per_class = number_per_class
        self.sampling_rate = sampling_rate
    def __call__(
        self,
        pl_model: pl.LightningModule
    ) -> Dict:
        r"""Evaluate the separation model on AudioSet with text-label queries."""
        pl_model.eval()

        sisdrs_dict = {class_id: [] for class_id in range(self.classes_num)}
        sdris_dict = {class_id: [] for class_id in range(self.classes_num)}

        print('Evaluation on AudioSet with [text label] queries.')

        for class_id in tqdm(range(self.classes_num)):
            sub_dir = os.path.join(
                self.audios_dir,
                "class_id={}".format(class_id))

            audio_names = self._get_audio_names(audios_dir=sub_dir)

            for audio_index, audio_name in enumerate(audio_names):
                if audio_index == self.number_per_class:
                    break

                source_path = os.path.join(
                    sub_dir, "{},source.wav".format(audio_name))
                mixture_path = os.path.join(
                    sub_dir, "{},mixture.wav".format(audio_name))

                source, fs = librosa.load(source_path, sr=self.sampling_rate, mono=True)
                mixture, fs = librosa.load(mixture_path, sr=self.sampling_rate, mono=True)

                # SDR of the unprocessed mixture; this is the baseline for SDRi.
                sdr_no_sep = calculate_sdr(ref=source, est=mixture)

                device = pl_model.device

                # Build the text query embedding from the class label.
                text = [IX_TO_LB[class_id]]
                conditions = pl_model.query_encoder.get_query_embed(
                    modality='text',
                    text=text,
                    device=device,
                )

                input_dict = {
                    "mixture": torch.Tensor(mixture)[None, None, :].to(device),
                    "condition": conditions,
                }

                sep_segment = pl_model.ss_model(input_dict)["waveform"]
                # sep_segment: (batch_size=1, channels_num=1, segment_samples)

                sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()
                # sep_segment: (segment_samples,)

                sdr = calculate_sdr(ref=source, est=sep_segment)
                sdri = sdr - sdr_no_sep
                sisdr = calculate_sisdr(ref=source, est=sep_segment)

                sisdrs_dict[class_id].append(sisdr)
                sdris_dict[class_id].append(sdri)

        stats_dict = {
            "sisdrs_dict": sisdrs_dict,
            "sdris_dict": sdris_dict,
        }

        return stats_dict
    def _get_audio_names(self, audios_dir: str) -> List[str]:
        r"""Get evaluation audio names."""
        audio_names = sorted(os.listdir(audios_dir))
        audio_names = [audio_name for audio_name in audio_names if audio_name.endswith('.wav')]

        # Strip the ",mixture.wav" / ",source.wav" suffixes and deduplicate.
        audio_names = [
            re.search(
                "(.*),(mixture|source).wav",
                audio_name).group(1) for audio_name in audio_names]

        audio_names = sorted(list(set(audio_names)))
        return audio_names
def get_median_metrics(stats_dict, metric_type):
    class_ids = stats_dict[metric_type].keys()
    median_stats_dict = {
        class_id: np.nanmedian(
            stats_dict[metric_type][class_id]) for class_id in class_ids}
    return median_stats_dict
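

# A minimal usage sketch, not part of the original script: the config and
# checkpoint paths below are hypothetical placeholders, and the helpers
# (parse_yaml, load_ss_model, get_mean_sdr_from_dict, CLAP_Encoder) are assumed
# to behave as in the AudioSep repo they are imported from above.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config_yaml = 'config/audiosep_base.yaml'          # hypothetical path
    checkpoint_path = 'checkpoint/audiosep_base.ckpt'  # hypothetical path

    configs = parse_yaml(config_yaml)
    query_encoder = CLAP_Encoder().eval()
    pl_model = load_ss_model(
        configs=configs,
        checkpoint_path=checkpoint_path,
        query_encoder=query_encoder,
    ).to(device)

    evaluator = AudioSetEvaluator(number_per_class=10)
    stats_dict = evaluator(pl_model=pl_model)

    # Reduce per-clip scores to per-class medians, then average over classes.
    median_sdris = get_median_metrics(stats_dict, 'sdris_dict')
    median_sisdrs = get_median_metrics(stats_dict, 'sisdrs_dict')
    print('SDRi: {:.3f} dB'.format(get_mean_sdr_from_dict(median_sdris)))
    print('SI-SDR: {:.3f} dB'.format(get_mean_sdr_from_dict(median_sisdrs)))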