HTS-Audio-Transformer / sed_model.py
artelabsuper
copy work from repo
daf1ccd
raw
history blame
No virus
16 kB
# Ke Chen
# knutchen@ucsd.edu
# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
# The Model Training Wrapper
import numpy as np
import librosa
import os
import bisect
from numpy.lib.function_base import average
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
from utils import get_loss_func, get_mix_lambda, d_prime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import torch.optim as optim
from torch.nn.parameter import Parameter
import torch.distributed as dist
import pytorch_lightning as pl
from utils import do_mixup, get_mix_lambda, do_mixup_label
class SEDWrapper(pl.LightningModule):
def __init__(self, sed_model, config, dataset):
super().__init__()
self.sed_model = sed_model
self.config = config
self.dataset = dataset
self.loss_func = get_loss_func(config.loss_type)
def evaluate_metric(self, pred, ans):
ap = []
if self.config.dataset_type == "audioset":
mAP = np.mean(average_precision_score(ans, pred, average = None))
mAUC = np.mean(roc_auc_score(ans, pred, average = None))
dprime = d_prime(mAUC)
return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
else:
acc = accuracy_score(ans, np.argmax(pred, 1))
return {"acc": acc}
def forward(self, x, mix_lambda = None):
output_dict = self.sed_model(x, mix_lambda)
return output_dict["clipwise_output"], output_dict["framewise_output"]
def inference(self, x):
self.device_type = next(self.parameters()).device
self.eval()
x = torch.from_numpy(x).float().to(self.device_type)
print(x.shape)
output_dict = self.sed_model(x, None, True)
for key in output_dict.keys():
output_dict[key] = output_dict[key].detach().cpu().numpy()
return output_dict
def training_step(self, batch, batch_idx):
self.device_type = next(self.parameters()).device
mix_lambda = torch.from_numpy(get_mix_lambda(0.5, len(batch["waveform"]))).to(self.device_type)
# Another Choice: also mixup the target, but AudioSet is not a perfect data
# so "adding noise" might be better than purly "mix"
# batch["target"] = do_mixup_label(batch["target"])
# batch["target"] = do_mixup(batch["target"], mix_lambda)
pred, _ = self(batch["waveform"], mix_lambda)
loss = self.loss_func(pred, batch["target"])
self.log("loss", loss, on_epoch= True, prog_bar=True)
return loss
def training_epoch_end(self, outputs):
# Change: SWA, deprecated
# for opt in self.trainer.optimizers:
# if not type(opt) is SWA:
# continue
# opt.swap_swa_sgd()
self.dataset.generate_queue()
def validation_step(self, batch, batch_idx):
pred, _ = self(batch["waveform"])
return [pred.detach(), batch["target"].detach()]
def validation_epoch_end(self, validation_step_outputs):
self.device_type = next(self.parameters()).device
pred = torch.cat([d[0] for d in validation_step_outputs], dim = 0)
target = torch.cat([d[1] for d in validation_step_outputs], dim = 0)
gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
dist.barrier()
if self.config.dataset_type == "audioset":
metric_dict = {
"mAP": 0.,
"mAUC": 0.,
"dprime": 0.
}
else:
metric_dict = {
"acc":0.
}
dist.all_gather(gather_pred, pred)
dist.all_gather(gather_target, target)
if dist.get_rank() == 0:
gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
if self.config.dataset_type == "scv2":
gather_target = np.argmax(gather_target, 1)
metric_dict = self.evaluate_metric(gather_pred, gather_target)
print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
if self.config.dataset_type == "audioset":
self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
else:
self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
dist.barrier()
def time_shifting(self, x, shift_len):
shift_len = int(shift_len)
new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
return new_sample
def test_step(self, batch, batch_idx):
print(batch['waveform'].shape)
exit()
self.device_type = next(self.parameters()).device
preds = []
# time shifting optimization
if self.config.fl_local or self.config.dataset_type != "audioset":
shift_num = 1 # framewise localization cannot allow the time shifting
else:
shift_num = 10
for i in range(shift_num):
pred, pred_map = self(batch["waveform"])
preds.append(pred.unsqueeze(0))
batch["waveform"] = self.time_shifting(batch["waveform"], shift_len = 100 * (i + 1))
preds = torch.cat(preds, dim=0)
pred = preds.mean(dim = 0)
if self.config.fl_local:
return [
pred.detach().cpu().numpy(),
pred_map.detach().cpu().numpy(),
batch["audio_name"],
batch["real_len"].cpu().numpy()
]
else:
return [pred.detach(), batch["target"].detach()]
def test_epoch_end(self, test_step_outputs):
self.device_type = next(self.parameters()).device
if self.config.fl_local:
pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
save_npy = [
{
"audio_name": audio_name[i],
"heatmap": pred_map[i],
"pred": pred[i],
"real_len":real_len[i]
}
for i in range(len(pred))
]
np.save(heatmap_file, save_npy)
else:
self.device_type = next(self.parameters()).device
pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
dist.barrier()
if self.config.dataset_type == "audioset":
metric_dict = {
"mAP": 0.,
"mAUC": 0.,
"dprime": 0.
}
else:
metric_dict = {
"acc":0.
}
dist.all_gather(gather_pred, pred)
dist.all_gather(gather_target, target)
if dist.get_rank() == 0:
gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
if self.config.dataset_type == "scv2":
gather_target = np.argmax(gather_target, 1)
metric_dict = self.evaluate_metric(gather_pred, gather_target)
print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
if self.config.dataset_type == "audioset":
self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
else:
self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
dist.barrier()
def configure_optimizers(self):
optimizer = optim.AdamW(
filter(lambda p: p.requires_grad, self.parameters()),
lr = self.config.learning_rate,
betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0.05,
)
# Change: SWA, deprecated
# optimizer = SWA(optimizer, swa_start=10, swa_freq=5)
def lr_foo(epoch):
if epoch < 3:
# warm up lr
lr_scale = self.config.lr_rate[epoch]
else:
# warmup schedule
lr_pos = int(-1 - bisect.bisect_left(self.config.lr_scheduler_epoch, epoch))
if lr_pos < -3:
lr_scale = max(self.config.lr_rate[0] * (0.98 ** epoch), 0.03 )
else:
lr_scale = self.config.lr_rate[lr_pos]
return lr_scale
scheduler = optim.lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lr_foo
)
return [optimizer], [scheduler]
class Ensemble_SEDWrapper(pl.LightningModule):
def __init__(self, sed_models, config, dataset):
super().__init__()
self.sed_models = nn.ModuleList(sed_models)
self.config = config
self.dataset = dataset
def evaluate_metric(self, pred, ans):
if self.config.dataset_type == "audioset":
mAP = np.mean(average_precision_score(ans, pred, average = None))
mAUC = np.mean(roc_auc_score(ans, pred, average = None))
dprime = d_prime(mAUC)
return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
else:
acc = accuracy_score(ans, np.argmax(pred, 1))
return {"acc": acc}
def forward(self, x, sed_index, mix_lambda = None):
self.sed_models[sed_index].eval()
preds = []
pred_maps = []
# time shifting optimization
if self.config.fl_local or self.config.dataset_type != "audioset":
shift_num = 1 # framewise localization cannot allow the time shifting
else:
shift_num = 10
for i in range(shift_num):
pred, pred_map = self.sed_models[sed_index](x)
pred_maps.append(pred_map.unsqueeze(0))
preds.append(pred.unsqueeze(0))
x = self.time_shifting(x, shift_len = 100 * (i + 1))
preds = torch.cat(preds, dim=0)
pred_maps = torch.cat(pred_maps, dim = 0)
pred = preds.mean(dim = 0)
pred_map = pred_maps.mean(dim = 0)
return pred, pred_map
def time_shifting(self, x, shift_len):
shift_len = int(shift_len)
new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
return new_sample
def test_step(self, batch, batch_idx):
self.device_type = next(self.parameters()).device
if self.config.fl_local:
pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
pred_map = torch.zeros(len(batch["waveform"]), 1024, self.config.classes_num).float().to(self.device_type)
for j in range(len(self.sed_models)):
temp_pred, temp_pred_map = self(batch["waveform"], j)
pred = pred + temp_pred
pred_map = pred_map + temp_pred_map
pred = pred / len(self.sed_models)
pred_map = pred_map / len(self.sed_models)
return [
pred.detach().cpu().numpy(),
pred_map.detach().cpu().numpy(),
batch["audio_name"],
batch["real_len"].cpu().numpy()
]
else:
pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
for j in range(len(self.sed_models)):
temp_pred, _ = self(batch["waveform"], j)
pred = pred + temp_pred
pred = pred / len(self.sed_models)
return [
pred.detach(),
batch["target"].detach(),
]
def test_epoch_end(self, test_step_outputs):
self.device_type = next(self.parameters()).device
if self.config.fl_local:
pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
print(pred.shape)
print(pred_map.shape)
print(real_len.shape)
save_npy = [
{
"audio_name": audio_name[i],
"heatmap": pred_map[i],
"pred": pred[i],
"real_len":real_len[i]
}
for i in range(len(pred))
]
np.save(heatmap_file, save_npy)
else:
pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
dist.barrier()
if self.config.dataset_type == "audioset":
metric_dict = {
"mAP": 0.,
"mAUC": 0.,
"dprime": 0.
}
else:
metric_dict = {
"acc":0.
}
dist.all_gather(gather_pred, pred)
dist.all_gather(gather_target, target)
if dist.get_rank() == 0:
gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
if self.config.dataset_type == "scv2":
gather_target = np.argmax(gather_target, 1)
metric_dict = self.evaluate_metric(gather_pred, gather_target)
print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
if self.config.dataset_type == "audioset":
self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
else:
self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
dist.barrier()