# Source: vampnet/scripts/exp/eval.py
# Last commit: "refactor bugfixes" by Hugo Flores (534a89c), 3.81 kB
import glob
import imp
import importlib.util
import os
import sys
from pathlib import Path

import argbind
import audiotools
import numpy as np
import pandas as pd
import torch
from flatten_dict import flatten
from rich.progress import track
from torch.utils.tensorboard import SummaryWriter

import wav2wav
# Load the sibling ``train.py`` as a module named "train" so its helpers
# (e.g. ``build_datasets``) can be reused here. This goes through importlib
# because the legacy ``imp`` module is deprecated and removed in Python 3.12.
# NOTE(review): the ``import imp`` line at the top of the file is no longer
# needed by this statement and still fails on Python >= 3.12 -- drop it once
# confirmed nothing else relies on it.
_train_path = Path(__file__).absolute().parent / "train.py"
_train_spec = importlib.util.spec_from_file_location("train", str(_train_path))
train = importlib.util.module_from_spec(_train_spec)
# Register the module before executing it, mirroring what ``imp.load_source``
# did, so it stays reachable as ``sys.modules["train"]``.
sys.modules["train"] = train
_train_spec.loader.exec_module(train)
@argbind.bind(without_prefix=True)
def evaluate(
    args,
    model_tag: str = "ckpt/best",
    device: str = "cuda",
    exp: str = None,
    overwrite: bool = False,
):
    """Evaluate a saved checkpoint on the test split and log the metrics.

    Loads the VampNet (plus CondNet when ``cond_dim > 0``) and VQ-VAE packages
    found under ``<exp>/<model_tag>``, writes one CSV of per-item metrics per
    dataset to ``<exp>/<model_tag>/metrics/<dataset>.csv`` and finally records
    the summarised metrics as TensorBoard scalars in that same directory.

    Parameters
    ----------
    args
        Parsed argbind arguments; forwarded to ``train.build_datasets``.
    model_tag : str
        Checkpoint folder, relative to ``exp``, holding the model packages.
    device : str
        Device the VampNet / CondNet models are moved to.
    exp : str
        Experiment directory. Required; every path is resolved relative to it.
    overwrite : bool
        Recompute a dataset's metrics even if its CSV already exists.

    Raises
    ------
    ValueError
        If ``exp`` is not provided.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if exp is None:
        raise ValueError("`exp` must point to the experiment directory to evaluate")

    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()

    # Package paths are relative to the experiment directory.
    with audiotools.util.chdir(exp):
        vampnet = wav2wav.modules.vampnet.transformer.VampNet.load(
            f"{model_tag}/vampnet/package.pth"
        )
        vampnet = vampnet.to(device)

        if vampnet.cond_dim > 0:
            condnet = wav2wav.modules.condnet.transformer.CondNet.load(
                f"{model_tag}/condnet/package.pth"
            )
            condnet = condnet.to(device)
        else:
            condnet = None

        vqvae = wav2wav.modules.generator.Generator.load(
            f"{model_tag}/vqvae/package.pth"
        )

    # Only the third dataset returned by the training script is evaluated.
    _, _, test_data = train.build_datasets(args, vqvae.sample_rate)

    with audiotools.util.chdir(exp):
        datasets = {
            "test": test_data,
        }

        metrics_path = Path(f"{model_tag}/metrics")
        metrics_path.mkdir(parents=True, exist_ok=True)

        for key, dataset in datasets.items():
            csv_path = metrics_path / f"{key}.csv"
            if csv_path.exists() and not overwrite:
                # Cached results: skip this dataset only. (A ``break`` here
                # would also abandon every dataset that follows it.)
                continue

            rows = []
            for i in track(range(len(dataset))):
                # TODO: for coarse2fine
                # grab the signal
                # mask all the codebooks except the conditioning ones
                # and infer
                # then compute metrics
                # for a baseline, just use the coarsest codebook
                # NOTE(review): `item`, `enhanced` and `clean` are never
                # assigned in this file; until the TODO above is implemented
                # the metric computations below raise NameError.

                # ViSQOL is best-effort: a failure is recorded as a missing
                # value. Only `Exception` is caught so that Ctrl-C still
                # aborts the run.
                try:
                    visqol = audiotools.metrics.quality.visqol(
                        enhanced, clean, "audio"
                    ).item()
                except Exception:
                    visqol = None

                sisdr = sisdr_loss(enhanced, clean)
                stft = stft_loss(enhanced, clean)
                mel = mel_loss(enhanced, clean)

                row = {
                    "visqol": visqol,
                    "sisdr": sisdr.item(),
                    "stft": stft.item(),
                    "mel": mel.item(),
                    "dataset": key,
                    "condition": exp,
                }
                rows.append(row)
                print(row)

                # Attach scalar (0-dim tensor) transform arguments as extra
                # columns, keyed by their dot-flattened names.
                transform_args = flatten(item["transform_args"], "dot")
                for name, value in transform_args.items():
                    if torch.is_tensor(value) and value.ndim == 0:
                        row[name] = value.item()

            pd.DataFrame(rows).to_csv(csv_path)

        # NOTE(review): `summary` is not defined or imported in this file;
        # this call raises NameError unless it is provided elsewhere.
        data = summary(model_tag).to_dict()
        scalars = {
            f"metrics/{k2}/{k1}": v2
            for k1, v1 in data.items()
            for k2, v2 in v1.items()
        }

        # Number of steps to record: each scalar is written with the same
        # value at every step (presumably to render as a flat line -- confirm).
        num_steps = 10
        writer = SummaryWriter(log_dir=metrics_path)
        try:
            for tag, value in scalars.items():
                for step in range(num_steps):
                    writer.add_scalar(tag, value, step)
        finally:
            # Flush pending events and release the event file.
            writer.close()
if __name__ == "__main__":
    # Script entry point: parse the command-line arguments once, then run the
    # evaluation inside the argbind scope created from them.
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        evaluate(cli_args)