from pathlib import Path
import os
from functools import partial

from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
from tqdm import tqdm

import audiotools
from audiotools import AudioSignal


def _numbered_audio_files(directory, audio_ext):
    """Return files in ``directory`` ending in ``audio_ext``, ordered by integer stem."""
    return sorted(directory.glob(f"*{audio_ext}"), key=lambda p: int(p.stem))


@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Compare every condition directory in ``exp_dir`` against a baseline.

    Each sub-directory of ``exp_dir`` is treated as a condition; the one named
    ``baseline_key`` is the reference. Audio files are paired by position after
    sorting by their integer file stem, and for every pair the SI-SDR,
    multi-scale STFT, mel-spectrogram and ViSQOL metrics are computed. A
    Frechet audio distance is computed once per condition (directory against
    directory) and repeated on every row of that condition.

    Outputs written into ``exp_dir``:
        * ``metrics-all.csv``: one row per (condition, file) pair.
        * ``stats-<metric>.csv``: per-condition mean / count / std of a metric.

    Args:
        exp_dir: Directory containing one sub-directory per condition.
        baseline_key: Name of the sub-directory holding the reference audio.
        audio_ext: File-name suffix used to select audio files.

    Raises:
        AssertionError: If ``exp_dir`` is missing or does not exist, if
            ``baseline_key`` is not one of its sub-directories, or if two
            paired files do not share the same stem.
        ValueError: If an audio file's stem is not an integer.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    baseline_files = _numbered_audio_files(baseline_dir, audio_ext)

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = _numbered_audio_files(cond_dir, audio_ext)

        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # Pair only as many files as both sides have. The truncated views are
        # kept in per-condition locals so that a short condition does not
        # shrink the baseline list seen by the conditions that follow it.
        num_files = min(len(baseline_files), len(cond_files))
        if len(baseline_files) != len(cond_files):
            print(
                f"number of files in {baseline_dir} and {cond_dir} do not match. "
                f"{len(baseline_files)} vs {len(cond_files)}; "
                f"evaluating the first {num_files} of each"
            )
        paired_baseline = baseline_files[:num_files]
        paired_cond = cond_files[:num_files]

        pbar = tqdm(zip(paired_baseline, paired_cond), total=num_files)
        for baseline_file, cond_file in pbar:
            # make sure the files match (same name)
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(str(baseline_file))
            cond_sig = AudioSignal(str(cond_file))

            # compute the metrics
            # ViSQOL is best-effort: a failure is reported and recorded as 0.0
            # rather than aborting the whole run. Only Exception is caught so
            # Ctrl-C still interrupts the evaluation.
            # NOTE(review): confirm the (baseline, condition) argument order is
            # what audiotools' visqol expects for reference vs. degraded audio.
            try:
                # Coerce to a plain float so the value aggregates and
                # serializes like the other metrics (which use .item()).
                vsq = float(visqol(baseline_sig, cond_sig))
            except Exception as err:
                tqdm.write(
                    f"visqol failed for {condition}/{baseline_file.stem}: "
                    f"{err!r}; recording 0.0"
                )
                vsq = 0.0

            metrics.append({
                # the loss is negated before being stored as "sisdr"
                "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
                "stft": stft_loss(baseline_sig, cond_sig).item(),
                "mel": mel_loss(baseline_sig, cond_sig).item(),
                "frechet": frechet_score,
                "visqol": vsq,
                "condition": condition,
                "file": baseline_file.stem,
            })

    if not metrics:
        # No conditions, or no audio files to pair: nothing to summarize.
        print(f"no metrics were computed for {exp_dir}; nothing written")
        return

    df = pandas.DataFrame(metrics)

    metric_keys = [k for k in metrics[0] if k not in ("condition", "file")]
    for mk in metric_keys:
        stat = df.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)


if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        eval()