Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import os | |
from functools import partial | |
from frechet_audio_distance import FrechetAudioDistance | |
import pandas | |
import argbind | |
from tqdm import tqdm | |
import audiotools | |
from audiotools import AudioSignal | |
@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "reconstructed",
    audio_ext: str = ".wav",
):
    """Score every condition directory in ``exp_dir`` against a baseline directory.

    ``exp_dir`` must contain one subdirectory per condition, one of which is
    named ``baseline_key``. Every other subdirectory must hold audio files with
    the same stems as the baseline. For each (baseline, condition) file pair
    this computes SI-SDR loss, multi-scale STFT loss, mel-spectrogram loss and
    ViSQOL. It also computes one Frechet Audio Distance per condition directory.

    Results are written into ``exp_dir``:
      * ``stats-<metric>.csv``: mean / count / std of each metric per condition.
      * ``metrics-all.csv``: one row per (condition, file).

    Args:
        exp_dir: experiment directory containing the condition subdirectories.
        baseline_key: name of the subdirectory used as the reference.
        audio_ext: extension of the audio files to compare (e.g. ".wav").

    Raises:
        AssertionError: if ``exp_dir`` is missing or does not exist, if the
            baseline subdirectory is absent, or if a condition's files do not
            match the baseline's files one-to-one.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=False,
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have (sorted for deterministic output)
    conditions = sorted(d.name for d in exp_dir.iterdir() if d.is_dir())
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)
    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    # glob() yields entries in arbitrary filesystem order. Sort so the baseline
    # and condition lists line up by file name before they are zipped below.
    baseline_files = sorted(baseline_dir.glob(f"*{audio_ext}"))

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = sorted(cond_dir.glob(f"*{audio_ext}"))

        # make sure we have the same number of files; checked before the
        # (expensive) FAD computation so a mismatch fails fast
        assert len(baseline_files) == len(cond_files), (
            f"number of files in {baseline_dir} and {cond_dir} do not match. "
            f"{len(baseline_files)} vs {len(cond_files)}"
        )

        print("computing fad")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
        for baseline_file, cond_file in pbar:
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(baseline_file)
            cond_sig = AudioSignal(cond_file)

            # ViSQOL is best-effort: on failure record 0.0 but say why, and
            # let KeyboardInterrupt/SystemExit propagate.
            # NOTE(review): confirm the (estimate, reference) argument order
            # expected by audiotools' visqol.
            try:
                vsq = visqol(baseline_sig, cond_sig)
            except Exception as err:
                tqdm.write(f"visqol failed for {baseline_file.stem}: {err}")
                vsq = 0.0

            # float() turns each metric (presumably scalar / 1-element torch
            # tensors for the audiotools metrics) into a plain number. This lets
            # pandas aggregate the columns and keeps the CSVs numeric.
            metrics.append({
                "sisdr": float(sisdr_loss(baseline_sig, cond_sig)),
                "stft": float(stft_loss(baseline_sig, cond_sig)),
                "mel": float(mel_loss(baseline_sig, cond_sig)),
                "frechet": float(frechet_score),
                "visqol": float(vsq),
                "condition": condition,
                "file": baseline_file.stem,
            })

    if not metrics:
        # no conditions or no audio files: nothing to aggregate
        print(f"no metrics computed for {exp_dir}; nothing to write")
        return

    df = pandas.DataFrame(metrics)
    metric_keys = [k for k in df.columns if k not in ("condition", "file")]
    grouped = df.groupby(["condition"])
    for mk in metric_keys:
        stat = grouped[mk].agg(["mean", "count", "std"])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")
    df.to_csv(exp_dir / "metrics-all.csv", index=False)
if __name__ == "__main__":
    # Parse command-line flags registered through argbind, then invoke the
    # evaluation with those values active in the argbind scope.
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        eval()