vampnet / scripts /exp /eval.py
Hugo Flores Garcia
more tweaks
93b48cb
raw
history blame
No virus
4.18 kB
from pathlib import Path
import os
from functools import partial
from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
import torch
from tqdm import tqdm
import audiotools
from audiotools import AudioSignal
def _numbered_audio_files(directory, audio_ext):
    """Return the files in ``directory`` ending in ``audio_ext``, ordered by integer stem.

    File stems are converted with ``int()`` for the sort key, so a stem that
    is not an integer (anything other than e.g. ``12.wav``) raises ``ValueError``.
    """
    return sorted(directory.glob(f"*{audio_ext}"), key=lambda path: int(path.stem))


def _score_pair(baseline_file, cond_file, condition, mel_loss, frechet_score):
    """Build the metrics row comparing one condition file against its baseline file.

    ``frechet_score`` is computed once per condition by the caller and is
    copied into every row unchanged.
    """
    # make sure the files match (same name)
    assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

    baseline_sig = AudioSignal(str(baseline_file))
    cond_sig = AudioSignal(str(cond_file))

    # bring the condition signal to the baseline's sample rate and cap it at
    # the baseline's length before comparing
    cond_sig.resample(baseline_sig.sample_rate)
    cond_sig.truncate_samples(baseline_sig.length)

    # if our condition is inpainting, we need to trim the conditioning off.
    # the context duration (in seconds) is read from the last "_"-separated
    # field of the condition directory name.
    if "inpaint" in condition:
        ctx_amt = float(condition.split("_")[-1])
        ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
        print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
        cond_sig.trim(ctx_samples, ctx_samples)
        baseline_sig.trim(ctx_samples, ctx_samples)

    return {
        "mel": mel_loss(baseline_sig, cond_sig).item(),
        "frechet": frechet_score,
        "condition": condition,
        "file": baseline_file.stem,
    }


@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Score every condition directory of an experiment against the baseline directory.

    Each sub-directory of ``exp_dir`` other than ``baseline_key`` is treated as
    a condition. For each condition a Frechet audio distance is computed
    between the two directories, and a mel-spectrogram loss is computed for
    each baseline/condition file pair. Per-file rows are written to
    ``metrics-all.csv`` and per-condition mean/count/std tables to one
    ``stats-<metric>.csv`` file per metric, all inside ``exp_dir``.

    The function name shadows the ``eval`` builtin; it is kept because it is
    the name bound by argbind and called by the script entry point.

    Parameters
    ----------
    exp_dir : str
        Experiment directory holding one sub-directory per condition.
        Required, despite the ``None`` default.
    baseline_key : str
        Name of the sub-directory holding the reference audio.
    audio_ext : str
        File extension (including the dot) of the audio files to compare.

    Raises
    ------
    AssertionError
        If ``exp_dir`` is missing or does not exist, if ``baseline_key`` is
        not a sub-directory of it, or if a paired baseline/condition file
        have different stems.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
        audio_load_worker=4,
    )
    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    baseline_files = _numbered_audio_files(baseline_dir, audio_ext)

    metrics = []
    for condition in tqdm(conditions):
        cond_dir = exp_dir / condition
        cond_files = _numbered_audio_files(cond_dir, audio_ext)

        # FAD is computed over the two directories as a whole
        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # make sure we have the same number of files. The truncation is kept
        # local to this condition: overwriting `baseline_files` here would
        # shrink the baseline for every later condition as well.
        num_files = min(len(baseline_files), len(cond_files))
        paired_baseline = baseline_files[:num_files]
        paired_cond = cond_files[:num_files]

        print(f"processing {num_files} files in {baseline_dir} and {cond_dir}")
        metrics.extend(
            _score_pair(baseline_file, cond_file, condition, mel_loss, frechet_score)
            for baseline_file, cond_file in tqdm(
                zip(paired_baseline, paired_cond), total=num_files
            )
        )

    # nothing to summarise (no conditions, or no paired files): stop here
    # instead of failing on metrics[0] below
    if not metrics:
        print(f"no metrics were computed for {exp_dir}; nothing to write")
        return

    # build the table once and reuse it for every summary and the full dump
    df = pandas.DataFrame(metrics)

    metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
    for mk in metric_keys:
        stat = df.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
if __name__ == "__main__":
    # Script entry point: parse the command-line arguments with argbind and
    # call the bound `eval` function while those arguments are in scope.
    with argbind.scope(argbind.parse_args()):
        eval()