import glob
import importlib.util
import os
from pathlib import Path

import argbind
import audiotools
import numpy as np
import pandas as pd
import torch
from flatten_dict import flatten
from rich.progress import track
from torch.utils.tensorboard import SummaryWriter

import wav2wav

# Load the training script as a module so its dataset builders can be reused.
# (`imp.load_source` is deprecated and removed in Python 3.12; this is the
# `importlib` equivalent.)
_train_spec = importlib.util.spec_from_file_location(
    "train", str(Path(__file__).absolute().parent / "train.py")
)
train = importlib.util.module_from_spec(_train_spec)
_train_spec.loader.exec_module(train)

@argbind.bind(without_prefix=True)
def evaluate(
    args,
    model_tag: str = "ckpt/best",
    device: str = "cuda",
    exp: str = None,
    overwrite: bool = False,
):
    assert exp is not None, "Pass the experiment directory via --exp"

    # Reference-based metrics used to score the enhanced audio.
    sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()

    # Load the checkpointed models from inside the experiment directory:
    # the VampNet token generator, the optional conditioning network, and
    # the VQ-VAE codec.
    with audiotools.util.chdir(exp):
        vampnet = wav2wav.modules.vampnet.transformer.VampNet.load(
            f"{model_tag}/vampnet/package.pth"
        )
        vampnet = vampnet.to(device)
        if vampnet.cond_dim > 0:
            condnet = wav2wav.modules.condnet.transformer.CondNet.load(
                f"{model_tag}/condnet/package.pth"
            )
            condnet = condnet.to(device)
        else:
            condnet = None
        vqvae = wav2wav.modules.generator.Generator.load(
            f"{model_tag}/vqvae/package.pth"
        )

    # Rebuild the datasets exactly as during training; only the test split
    # is evaluated here.
    _, _, test_data = train.build_datasets(args, vqvae.sample_rate)

    with audiotools.util.chdir(exp):
        datasets = {
            "test": test_data,
        }

        metrics_path = Path(f"{model_tag}/metrics")
        metrics_path.mkdir(parents=True, exist_ok=True)

        for key, dataset in datasets.items():
            csv_path = metrics_path / f"{key}.csv"
            if csv_path.exists() and not overwrite:
                # Skip datasets that were already evaluated; `break` here would
                # also skip any remaining datasets that still need metrics.
                continue

            metrics = []
            for i in track(range(len(dataset))):
                # TODO: for coarse2fine
                # grab the signal
                # mask all the codebooks except the conditioning ones
                # and infer
                # then compute metrics
                # for a baseline, just use the coarsest codebook
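                # The inference described in the TODO above is not implemented
                # yet, so `item`, `clean`, and `enhanced` are otherwise
                # undefined. A minimal placeholder, assuming the audiotools
                # dataset convention that `dataset[i]` returns a dict with
                # "signal" and "transform_args" entries; `enhanced` should come
                # from the models loaded above, but falls back to the clean
                # signal here so the metric loop below runs end to end.
                item = dataset[i]
                clean = item["signal"].to(device)
                enhanced = clean  # placeholder: substitute real model output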
                try:
                    visqol = audiotools.metrics.quality.visqol(
                        enhanced, clean, "audio"
                    ).item()
                except Exception:
                    # ViSQOL may be unavailable or fail on a given signal.
                    visqol = None
                sisdr = sisdr_loss(enhanced, clean)
                stft = stft_loss(enhanced, clean)
                mel = mel_loss(enhanced, clean)

                metrics.append(
                    {
                        "visqol": visqol,
                        "sisdr": sisdr.item(),
                        "stft": stft.item(),
                        "mel": mel.item(),
                        "dataset": key,
                        "condition": exp,
                    }
                )
                print(metrics[-1])
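
                # Record scalar transform parameters (e.g. the SNR drawn for a
                # noise augmentation) alongside the metrics, so results can be
                # sliced by condition later.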
                transform_args = flatten(item["transform_args"], "dot")
                for k, v in transform_args.items():
                    if torch.is_tensor(v) and v.ndim == 0:
                        metrics[-1][k] = v.item()

            metrics = pd.DataFrame.from_dict(metrics)
            with open(csv_path, "w") as f:
                metrics.to_csv(f)
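
        # NOTE: `summary` is assumed to be defined elsewhere in this script;
        # presumably it reads the per-dataset CSVs under `metrics_path` and
        # returns a DataFrame of aggregate statistics per dataset.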
        data = summary(model_tag).to_dict()
        metrics = {}
        for k1, v1 in data.items():
            for k2, v2 in v1.items():
                metrics[f"metrics/{k2}/{k1}"] = v2

        # Log each aggregate metric to TensorBoard as a constant line over a
        # few steps so it shows up clearly in the scalar plots.
        writer = SummaryWriter(log_dir=metrics_path)
        num_steps = 10  # number of steps to record
        for k, v in metrics.items():
            for i in range(num_steps):
                writer.add_scalar(k, v, i)


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        evaluate(args)
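
# Example invocation (flag names bind directly to `evaluate`'s keyword
# arguments because of `without_prefix=True`; the paths are illustrative):
#   python evaluate.py --exp runs/my_experiment --model_tag ckpt/best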