import glob
import imp
import os
from pathlib import Path

import argbind
import audiotools
import numpy as np
import pandas as pd
import torch
from flatten_dict import flatten
from rich.progress import track
from torch.utils.tensorboard import SummaryWriter

import wav2wav

# Load the training script as a module so its dataset builders can be reused.
train = imp.load_source("train", str(Path(__file__).absolute().parent / "train.py"))
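
def summary(model_tag: str) -> pd.DataFrame:
    """Aggregate the per-item metric CSVs written by ``evaluate`` into mean
    metrics per dataset split.

    Hypothetical stand-in: the original script calls ``summary`` without
    defining or importing it, so this sketch only assumes the CSV layout
    produced below. Swap in the project's own summary helper if one exists.
    """
    frames = [pd.read_csv(p) for p in glob.glob(f"{model_tag}/metrics/*.csv")]
    df = pd.concat(frames, ignore_index=True)
    # Average only the numeric metric columns, grouped by dataset split.
    return df.groupby("dataset")[["visqol", "sisdr", "stft", "mel"]].mean()
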
@argbind.bind(without_prefix=True)  # assumed: exposes these keyword arguments on the CLI
def evaluate(
    args,
    model_tag: str = "ckpt/best",
    device: str = "cuda",
    exp: str = None,
    overwrite: bool = False,
):
    """Evaluate a trained checkpoint on the test set, writing per-item metrics
    (ViSQOL, SI-SDR, multi-scale STFT, mel) to CSV and TensorBoard."""
    assert exp is not None, "Pass the experiment directory via `exp`."

    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    # Load the trained modules from the experiment directory.
    with audiotools.util.chdir(exp):
        vampnet = wav2wav.modules.vampnet.transformer.VampNet.load(
            f"{model_tag}/vampnet/package.pth"
        )
        vampnet = vampnet.to(device)

        # The conditioning network only exists if the model was trained with it.
        if vampnet.cond_dim > 0:
            condnet = wav2wav.modules.condnet.transformer.CondNet.load(
                f"{model_tag}/condnet/package.pth"
            )
            condnet = condnet.to(device)
        else:
            condnet = None

        vqvae = wav2wav.modules.generator.Generator.load(
            f"{model_tag}/vqvae/package.pth"
        )
    # Datasets are built outside the experiment directory, relative to the
    # original working directory.
    _, _, test_data = train.build_datasets(args, vqvae.sample_rate)

    with audiotools.util.chdir(exp):
        datasets = {
            "test": test_data,
        }

        metrics_path = Path(f"{model_tag}/metrics")
        metrics_path.mkdir(parents=True, exist_ok=True)
        for key, dataset in datasets.items():
            csv_path = metrics_path / f"{key}.csv"
            if csv_path.exists() and not overwrite:
                # Skip splits that have already been evaluated.
                continue

            metrics = []
            for i in track(range(len(dataset))):
                # TODO: for coarse2fine
                # grab the signal
                # mask all the codebooks except the conditioning ones
                # and infer
                # then compute metrics
                # for a baseline, just use the coarsest codebook
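                # The inference step above is still a TODO; the lines below are
                # a hedged sketch so that `clean` and `enhanced` exist. They
                # assume an audiotools-style dataset item with the clean
                # AudioSignal under "signal". Replace the placeholder with real
                # coarse2fine inference from vampnet/condnet.
                item = dataset[i]
                clean = item["signal"].to(device)
                enhanced = clean.clone()  # placeholder output until inference is implemented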
                # ViSQOL can fail on some items; record None in that case.
                try:
                    visqol = audiotools.metrics.quality.visqol(
                        enhanced, clean, "audio"
                    ).item()
                except Exception:
                    visqol = None

                sisdr = sisdr_loss(enhanced, clean)
                stft = stft_loss(enhanced, clean)
                mel = mel_loss(enhanced, clean)
                metrics.append(
                    {
                        "visqol": visqol,
                        "sisdr": sisdr.item(),
                        "stft": stft.item(),
                        "mel": mel.item(),
                        "dataset": key,
                        "condition": exp,
                    }
                )
                print(metrics[-1])
                # Record any scalar transform arguments alongside the metrics.
                transform_args = flatten(item["transform_args"], "dot")
                for k, v in transform_args.items():
                    if torch.is_tensor(v) and v.ndim == 0:
                        metrics[-1][k] = v.item()
            metrics = pd.DataFrame(metrics)
            with open(csv_path, "w") as f:
                metrics.to_csv(f)
        # Aggregate the per-item metrics (see summary() above) and log them to
        # TensorBoard so they appear alongside the training curves.
        data = summary(model_tag).to_dict()
        metrics = {}
        for k1, v1 in data.items():
            for k2, v2 in v1.items():
                metrics[f"metrics/{k2}/{k1}"] = v2

        # Repeat each scalar for a few steps so it renders as a visible line.
        writer = SummaryWriter(log_dir=metrics_path)
        num_steps = 10
        for k, v in metrics.items():
            for i in range(num_steps):
                writer.add_scalar(k, v, i)
        writer.close()

if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        evaluate(args)
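
# Example invocation (hedged; exact flags depend on how argbind exposes
# evaluate's keyword arguments in this project):
#   python evaluate.py --exp runs/my_experiment --model_tag ckpt/best --device cuda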