Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import os | |
from functools import partial | |
from frechet_audio_distance import FrechetAudioDistance | |
import pandas | |
import argbind | |
import torch | |
from tqdm import tqdm | |
import audiotools | |
from audiotools import AudioSignal | |
def _sorted_audio_files(directory: Path, audio_ext: str) -> list:
    """Return the files in *directory* ending in *audio_ext*, sorted numerically by stem.

    Stems are expected to be integers (e.g. ``0.wav``, ``1.wav``); ``int`` raises
    ValueError for any matching file whose stem is not.
    """
    return sorted(directory.glob(f"*{audio_ext}"), key=lambda p: int(p.stem))


def _pair_by_stem(baseline_files: list, cond_files: list) -> list:
    """Pair baseline and condition files that share a stem, in baseline order."""
    cond_by_stem = {f.stem: f for f in cond_files}
    return [(b, cond_by_stem[b.stem]) for b in baseline_files if b.stem in cond_by_stem]


def _pair_metrics(baseline_file, cond_file, condition, mel_loss, frechet_score) -> dict:
    """Load one baseline/condition file pair, align them, and return its metric row."""
    baseline_sig = AudioSignal(str(baseline_file))
    cond_sig = AudioSignal(str(cond_file))
    cond_sig.resample(baseline_sig.sample_rate)

    # Truncate both signals to the shorter length so mel_loss compares equal
    # shapes even when the condition output is shorter than the baseline.
    num_samples = min(baseline_sig.length, cond_sig.length)
    baseline_sig.truncate_samples(num_samples)
    cond_sig.truncate_samples(num_samples)

    # If the condition is inpainting, trim the conditioning context off both ends.
    # Assumes condition names end in "_<context seconds>" (e.g. "inpaint_1.0") — TODO confirm.
    if "inpaint" in condition:
        ctx_amt = float(condition.split("_")[-1])
        ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
        print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
        cond_sig.trim(ctx_samples, ctx_samples)
        baseline_sig.trim(ctx_samples, ctx_samples)

    return {
        "mel": mel_loss(baseline_sig, cond_sig).item(),
        # FAD is computed per directory, so every row of a condition shares it.
        "frechet": frechet_score,
        "condition": condition,
        "file": baseline_file.stem,
    }


@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Compare every condition sub-directory of an experiment against its baseline.

    Parameters
    ----------
    exp_dir : str
        Experiment directory holding one sub-directory per condition, including
        the baseline.
    baseline_key : str
        Name of the baseline sub-directory, by default "baseline".
    audio_ext : str
        Extension of the audio files to compare, by default ".wav".

    Writes ``stats-<metric>.csv`` (per-condition mean/count/std) and
    ``metrics-all.csv`` (one row per matched file pair) into ``exp_dir``.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
        audio_load_worker=4,
    )
    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Figure out what conditions we have; sorted so output order is reproducible.
    conditions = sorted(d.name for d in exp_dir.iterdir() if d.is_dir())
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)
    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    # Listed once and never reassigned: each condition is matched against the full baseline set.
    baseline_files = _sorted_audio_files(baseline_dir, audio_ext)

    metrics = []
    for condition in tqdm(conditions):
        cond_dir = exp_dir / condition
        cond_files = _sorted_audio_files(cond_dir, audio_ext)

        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # Match files by name rather than position so a missing file does not misalign pairs.
        pairs = _pair_by_stem(baseline_files, cond_files)
        if len(pairs) != max(len(baseline_files), len(cond_files)):
            print(
                f"warning: only {len(pairs)} files are shared between {baseline_dir} "
                f"({len(baseline_files)} files) and {cond_dir} ({len(cond_files)} files)"
            )

        print(f"processing {len(pairs)} files in {baseline_dir} and {cond_dir}")
        for baseline_file, cond_file in tqdm(pairs, total=len(pairs)):
            metrics.append(_pair_metrics(baseline_file, cond_file, condition, mel_loss, frechet_score))

    if not metrics:
        print(f"no matching file pairs found in {exp_dir}; nothing to write")
        return

    df = pandas.DataFrame(metrics)
    metric_keys = [k for k in df.columns if k not in ("condition", "file")]
    for mk in metric_keys:
        stat = df.groupby(["condition"])[mk].agg(["mean", "count", "std"])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")
    df.to_csv(exp_dir / "metrics-all.csv", index=False)
if __name__ == "__main__":
    # Build the command line from argbind-registered functions; inside the scope,
    # argbind-bound callables take their arguments from the parsed values.
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        eval()