# vampnet/scripts/exp/c2f_eval.py
# Hugo Flores -- "save each metric on its own" (275afd0, presumably the commit hash)
# 3.27 kB
from pathlib import Path
import os
from functools import partial
from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
from tqdm import tqdm
import audiotools
from audiotools import AudioSignal
@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "reconstructed",
    audio_ext: str = ".wav",
):
    """Score every condition folder of an experiment against a baseline folder.

    Each sub-directory of ``exp_dir`` is treated as one condition. For every
    condition other than the baseline, the audio files are paired with the
    baseline files by file stem and SI-SDR, multi-scale STFT, mel-spectrogram
    and visqol values are computed per file. One Frechet audio distance is
    computed per condition and repeated on each of its rows.

    Results are written into ``exp_dir``: one ``stats-<metric>.csv`` per
    metric (mean, count and std grouped by condition) and ``metrics-all.csv``
    holding every per-file row.

    Parameters
    ----------
    exp_dir : str
        Experiment directory containing one sub-directory per condition.
    baseline_key : str
        Name of the sub-directory used as the reference for all metrics.
    audio_ext : str
        File extension (including the dot) of the audio files to compare.

    Raises
    ------
    AssertionError
        If ``exp_dir`` is missing or does not exist, the baseline directory is
        not found, or a condition's files do not line up with the baseline's.
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=False,
    )
    visqol = partial(audiotools.metrics.quality.visqol, mode="audio")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    # glob() yields files in filesystem order; sort so that the baseline and
    # condition listings line up by name when zipped below.
    baseline_files = sorted(baseline_dir.glob(f"*{audio_ext}"))

    metrics = []
    for condition in conditions:
        cond_dir = exp_dir / condition
        cond_files = sorted(cond_dir.glob(f"*{audio_ext}"))

        # make sure we have the same number of files
        # (checked before the FAD so a mismatched folder fails fast)
        assert len(baseline_files) == len(cond_files), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(baseline_files)} vs {len(cond_files)}"

        print("computing fad")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        pbar = tqdm(zip(baseline_files, cond_files), total=len(baseline_files))
        for baseline_file, cond_file in pbar:
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
            pbar.set_description(baseline_file.stem)

            # load the files
            baseline_sig = AudioSignal(baseline_file)
            cond_sig = AudioSignal(cond_file)

            # compute the metrics
            # visqol is best-effort: a failure is reported and recorded as 0.0
            # instead of aborting the whole evaluation.
            try:
                vsq = visqol(baseline_sig, cond_sig)
            except Exception as err:
                tqdm.write(f"visqol failed for {baseline_file.stem}: {err!r}; recording 0.0")
                vsq = 0.0

            # Store plain floats so pandas can aggregate the columns and the
            # CSVs contain numbers. NOTE(review): assumes each metric returns
            # a single-element value convertible with float() -- confirm.
            metrics.append({
                "sisdr": float(sisdr_loss(baseline_sig, cond_sig)),
                "stft": float(stft_loss(baseline_sig, cond_sig)),
                "mel": float(mel_loss(baseline_sig, cond_sig)),
                "frechet": frechet_score,
                "visqol": float(vsq),
                "condition": condition,
                "file": baseline_file.stem,
            })

    if not metrics:
        # Nothing was scored (no non-baseline conditions or no audio files);
        # there is nothing to aggregate or write.
        print(f"no metrics computed for {exp_dir}; nothing to save")
        return

    # Build the table once and reuse it for every per-metric summary.
    df = pandas.DataFrame(metrics)

    metric_keys = [k for k in metrics[0] if k not in ("condition", "file")]
    for mk in metric_keys:
        stat = df.groupby(["condition"])[mk].agg(["mean", "count", "std"])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
if __name__ == "__main__":
    # Parse the command line and expose the parsed values to every
    # argbind-bound function for the duration of the run.
    with argbind.scope(argbind.parse_args()):
        eval()