hilamanor's picture
initial commit
e73da9c
raw
history blame
No virus
10.6 kB
import json
import logging
import math
import os
import time
from contextlib import suppress
import numpy as np
import torch
import torch.nn.functional as F
try:
import wandb
except ImportError:
wandb = None
from open_clip import LPLoss, LPMetrics, lp_gather_features
from open_clip.utils import do_mixup, get_mix_lambda
from .distributed import is_master
from .zero_shot import zero_shot_eval
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes:
        val: the most recent value passed to :meth:`update`.
        sum: running total of ``val * n`` over all updates since the last reset.
        count: running total of ``n`` over all updates since the last reset.
        avg: ``sum / count``, or ``0`` while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(..., n=0) right after reset) would
        # otherwise raise ZeroDivisionError; report the reset value instead.
        self.avg = self.sum / self.count if self.count else 0
def unwrap_model(model):
    """Return ``model.module`` when that attribute exists, otherwise ``model``.

    Presumably meant for wrappers that expose the wrapped network as
    ``.module`` (e.g. data-parallel containers) -- verify against callers.
    """
    return getattr(model, "module", model)
def _amp_step(opt, scaler, use_horovod):
    """Step one optimizer through the gradient scaler.

    With Horovod the optimizer is synchronized and its gradients are unscaled
    first, and the step itself runs inside ``skip_synchronize()``.
    """
    if use_horovod:
        opt.synchronize()
        scaler.unscale_(opt)
        with opt.skip_synchronize():
            scaler.step(opt)
    else:
        scaler.step(opt)


def train_one_epoch(
    model,
    data,
    epoch,
    optimizer,
    scaler,
    scheduler,
    args,
    tb_writer=None,
    extra_suffix="",
):
    """Run one training epoch of the classifier over ``data["train"]``.

    Args:
        model: called as ``model(batch, mix_lambda=..., device=...)``; the
            unwrapped model must expose ``clap_model.logit_scale_a`` and
            ``clap_model.logit_scale_t``.
        data: mapping whose ``"train"`` entry has ``dataloader`` and
            ``sampler`` attributes; the dataloader must provide
            ``num_batches`` and ``num_samples``.
        epoch: epoch index, used for the sampler, the global step and logs.
        optimizer: a single optimizer or a dict of optimizers.
        scaler: gradient scaler used for the backward pass and the optimizer
            steps, or ``None`` to step the optimizers directly.
        scheduler: a callable taking the global step, or a dict of them.
        args: run configuration; reads ``device``, ``precision``, ``lp_loss``,
            ``distributed``, ``dataset_type``, ``mixup``, ``horovod``,
            ``world_size`` and ``wandb``.
        tb_writer: optional writer providing ``add_scalar``.
        extra_suffix: appended to the ``train`` prefix of logged metric names.

    Returns:
        None. Progress is reported through ``logging``, ``tb_writer`` and
        wandb.
    """
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = LPLoss(args.lp_loss)

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # for toy dataset
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    # Normalize the "single object or dict of objects" arguments once so the
    # loop body needs only one code path for each.
    multi_optimizer = isinstance(optimizer, dict)
    if multi_optimizer:
        optimizer_names = list(optimizer.keys())
        optimizers = list(optimizer.values())
    else:
        optimizer_names = []
        optimizers = [optimizer]
    schedulers = list(scheduler.values()) if isinstance(scheduler, dict) else [scheduler]

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        for sched in schedulers:
            sched(step)

        audio = batch  # contains mel_spec, wavform, and longer list
        class_label = batch["class_label"]
        class_label = class_label.to(device=device, non_blocking=True)
        if args.mixup:
            # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146
            mix_lambda = torch.from_numpy(
                get_mix_lambda(0.5, len(audio["waveform"]))
            ).to(device)
            class_label = do_mixup(class_label, mix_lambda)
        else:
            mix_lambda = None

        data_time_m.update(time.time() - end)
        for opt in optimizers:
            opt.zero_grad()

        with autocast():
            pred = model(audio, mix_lambda=mix_lambda, device=device)
            total_loss = loss(pred, class_label)

        if scaler is not None:
            scaler.scale(total_loss).backward()
            for opt in optimizers:
                _amp_step(opt, scaler, args.horovod)
            # One scaler update per iteration, after every optimizer stepped.
            scaler.update()
        else:
            total_loss.backward()
            for opt in optimizers:
                opt.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
            unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            batch_size = len(audio["waveform"]) if isinstance(audio, dict) else len(audio)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)

            if multi_optimizer:
                current_lr = [opt.param_groups[0]["lr"] for opt in optimizers]
                lr_text = f"LR: {current_lr}"
            else:
                current_lr = optimizer.param_groups[0]["lr"]
                lr_text = f"LR: {current_lr:5f} "
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f} "
                + lr_text
            )

            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "lr": current_lr,
            }
            for name, val in log_data.items():
                tag = f"train{extra_suffix}/{name}"
                if tb_writer is not None:
                    if isinstance(val, list):
                        # add_scalar takes a single scalar, so the
                        # per-optimizer learning rates are written one by
                        # one, tagged with the optimizer's dict key.
                        for opt_name, opt_lr in zip(optimizer_names, val):
                            tb_writer.add_scalar(f"{tag}_{opt_name}", opt_lr, step)
                    else:
                        tb_writer.add_scalar(tag, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({tag: val, "step": step})

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for
def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
    """Evaluate ``model`` on ``data["val"]`` and report the resulting metrics.

    Validation runs only when ``"val"`` is in ``data``, ``args.val_frequency``
    is truthy, and ``epoch`` is a multiple of it or equals ``args.epochs``.
    With ``args.parallel_eval`` every rank runs the loop and predictions and
    labels are gathered through ``lp_gather_features``; otherwise non-master
    ranks return immediately. Metrics are computed on the master only.

    Args:
        model: called as ``model(batch, device=...)``.
        data: mapping whose optional ``"val"`` entry has ``dataloader`` (and,
            for parallel evaluation, ``sampler``) attributes.
        epoch: epoch index, used for scheduling and as the logging step.
        args: run configuration; reads ``parallel_eval``, ``device``,
            ``lp_metrics`` (comma-separated metric names), ``precision``,
            ``val_frequency``, ``epochs``, ``distributed``, ``world_size``,
            ``horovod``, ``save_logs``, ``checkpoint_path`` and ``wandb``.
        tb_writer: optional writer providing ``add_scalar``.
        extra_suffix: appended to the ``val`` prefix of logged metric names.

    Returns:
        dict of metric name to value, plus an ``"epoch"`` entry, on the master
        after a validation pass; an empty dict otherwise.
    """
    metrics = {}
    if not args.parallel_eval and not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    # CHANGE
    # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    # metrics.update(zero_shot_metrics)
    if is_master(args):
        print("Evaluating...")
    metric_names = args.lp_metrics.split(",")
    eval_tool = LPMetrics(metric_names=metric_names)

    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    if "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        dataloader = data["val"].dataloader
        if args.parallel_eval:
            sampler = data["val"].sampler
            if args.distributed and sampler is not None:
                sampler.set_epoch(epoch)
        samples_per_val = dataloader.num_samples
        # Initialized for both evaluation modes: it used to be set only on
        # the non-parallel path, so the "+=" below raised UnboundLocalError
        # whenever args.parallel_eval was enabled.
        num_samples = 0

        eval_info = {"pred": [], "target": []}
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audio = batch  # contains mel_spec, wavform, and longer list
                class_label = batch["class_label"]
                class_label = class_label.to(device=device, non_blocking=True)

                with autocast():
                    pred = model(audio, device=device)
                    if args.parallel_eval:
                        pred, class_label = lp_gather_features(
                            pred, class_label, args.world_size, args.horovod
                        )
                    eval_info["pred"].append(pred)
                    eval_info["target"].append(class_label)

                num_samples += class_label.shape[0]

                if (i % 100) == 0:  # and i != 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

            if is_master(args):
                eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()
                eval_info["target"] = torch.cat(eval_info["target"], 0).cpu()
                # "evaluate_mertics" (sic) is the method name this code calls
                # on LPMetrics.
                metric_dict = eval_tool.evaluate_mertics(
                    eval_info["pred"], eval_info["target"]
                )
                metrics.update(metric_dict)
                metrics.setdefault("epoch", epoch)

    # Only the master reports, and only when there is something to report.
    if not is_master(args) or not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\n".join(f"{m}: {round(metrics[m], 4):.4f}" for m in metrics)
    )

    if args.save_logs:
        if tb_writer is not None:
            for name, val in metrics.items():
                tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)

        # One JSON object per evaluation, appended to results.jsonl.
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, "Please install wandb."
        for name, val in metrics.items():
            wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})

    return metrics