from pathlib import Path
import os
from functools import partial

from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
from tqdm import tqdm
import audiotools
from audiotools import AudioSignal

# Keys of a per-file metrics row that identify the row rather than hold a score.
_ID_KEYS = ("condition", "file")


def _list_audio(directory: Path, audio_ext: str) -> list:
    """Return the files in *directory* matching ``*{audio_ext}``, sorted by path.

    ``Path.glob`` yields entries in arbitrary order, so sorting is required
    for baseline and condition files to pair up positionally by name.
    """
    return sorted(directory.glob(f"*{audio_ext}"))


def _scalar(value) -> float:
    """Reduce a metric result to a plain Python float.

    Values exposing ``.mean()`` (e.g. tensors/arrays) are averaged first, so
    a 0-d or single-element result maps to its only value; anything else is
    passed straight to ``float``. Storing plain floats keeps the DataFrame
    numeric for aggregation and CSV output.
    """
    if hasattr(value, "mean"):
        value = value.mean()
    return float(value)


def _write_stats(metrics: list, exp_dir: Path) -> None:
    """Write per-condition summary stats and the raw per-file metrics as CSV.

    For every metric column, writes ``stats-<metric>.csv`` holding the
    mean/count/std of that metric grouped by condition, then writes every
    row to ``metrics-all.csv``.
    """
    df = pandas.DataFrame(metrics)  # built once, reused for every metric
    metric_keys = [k for k in df.columns if k not in _ID_KEYS]
    for mk in metric_keys:
        stat = df.groupby(["condition"])[mk].agg(["mean", "count", "std"])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")
    df.to_csv(exp_dir / "metrics-all.csv", index=False)


@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "reconstructed",
    audio_ext: str = ".wav",
):
    """Compare every condition directory in *exp_dir* against a baseline.

    Each subdirectory of ``exp_dir`` is a condition. The one named
    ``baseline_key`` is the reference. Every other condition is compared
    file-by-file (paired by sorted filename; stems must match) using SI-SDR,
    multi-scale STFT, mel-spectrogram and ViSQOL metrics. A single Frechet
    Audio Distance is also computed per condition, directory against
    directory. Results are written as CSV files into ``exp_dir``.

    Args:
        exp_dir: Experiment directory containing one subdirectory per condition.
        baseline_key: Name of the subdirectory used as the reference.
        audio_ext: Extension (including the dot) of the audio files to compare.

    Raises:
        AssertionError: if ``exp_dir`` is missing or does not exist, if the
            baseline subdirectory is absent, or if a condition's files do not
            match the baseline's files in count or name.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=False,
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have (sorted for a deterministic order)
    conditions = sorted(d.name for d in exp_dir.iterdir() if d.is_dir())
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    baseline_files = _list_audio(baseline_dir, audio_ext)

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = _list_audio(cond_dir, audio_ext)

        # make sure we have the same number of files (checked before the
        # expensive FAD computation so a mismatch fails fast)
        assert len(baseline_files) == len(cond_files), (
            f"number of files in {baseline_dir} and {cond_dir} do not match. "
            f"{len(baseline_files)} vs {len(cond_files)}"
        )

        print("computing fad")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
        for baseline_file, cond_file in pbar:
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(baseline_file)
            cond_sig = AudioSignal(cond_file)

            # ViSQOL is best-effort: on failure record NaN (skipped by the
            # mean/count/std aggregation) instead of a 0.0 that skews stats.
            try:
                vsq = _scalar(visqol(baseline_sig, cond_sig))
            except Exception as err:
                tqdm.write(f"visqol failed for {cond_file}: {err}")
                vsq = float("nan")

            metrics.append({
                "sisdr": _scalar(sisdr_loss(baseline_sig, cond_sig)),
                "stft": _scalar(stft_loss(baseline_sig, cond_sig)),
                "mel": _scalar(mel_loss(baseline_sig, cond_sig)),
                "frechet": _scalar(frechet_score),
                "visqol": vsq,
                "condition": condition,
                "file": baseline_file.stem,
            })

    if not metrics:
        # Nothing to aggregate (no non-baseline conditions or no audio files).
        print("No metrics computed; nothing to write.")
        return

    _write_stats(metrics, exp_dir)


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        eval()