rule-guided-music / scripts /classifier_train_aug.py
yjhuangcd
First commit
9965bf6
"""
Train a noised (timestep-conditioned) DiT classifier on piano-roll music data, optionally encoded into latents by a pretrained embedding model.
"""
import argparse
import os
import os.path as osp
import blobfile as bf
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.pr_datasets_all import load_data
from guided_diffusion.dit import DiT_models
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
add_dict_to_argparser,
create_diffusion,
args_to_dict,
classifier_and_diffusion_defaults,
create_classifier_and_diffusion,
)
from load_utils import load_model
from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict, get_kl_input
def main():
    """Train a noised classifier on (optionally encoded) piano-roll data.

    Builds a DiT classifier, a diffusion object (used here only to sample
    timesteps and noise inputs via ``q_sample``), and an embedding model that
    is passed to ``get_kl_input`` when ``--get_KL`` is set. Then runs a
    distributed training loop with periodic validation, logging and
    checkpointing.

    The label source is selected by ``--rule``:
      * ``None``                -> class labels ``extra["y"]``, cross-entropy,
                                   top-1 accuracy logged.
      * ``"chord_progression"`` -> a key label plus per-chord labels, each
                                   trained with cross-entropy; the two losses
                                   are averaged.
      * any other rule name     -> regression target ``extra[rule]``, MSE.
    """
    args = create_argparser().parse_args()

    comm = dist_util.setup_dist(port=args.port)
    logger.configure(args=args, comm=comm)

    logger.log("creating model and diffusion...")
    is_chord = args.rule == 'chord_progression'
    model = DiT_models[args.model](
        input_size=args.image_size,
        in_channels=args.in_channels,
        num_classes=args.num_classes,
        chord=is_chord,
    )
    diffusion = create_diffusion(
        learn_sigma=args.learn_sigma,
        diffusion_steps=args.diffusion_steps,
        noise_schedule=args.noise_schedule,
        timestep_respacing=args.timestep_respacing,
        use_kl=args.use_kl,
        predict_xstart=args.predict_xstart,
        rescale_timesteps=args.rescale_timesteps,
        rescale_learned_sigmas=args.rescale_learned_sigmas,
    )

    # create embed model (kept in eval mode; its `loss` submodule is dropped)
    embed_model = load_model(args.embed_model_name, args.embed_model_ckpt)
    del embed_model.loss
    embed_model.to(dist_util.dev())
    embed_model.eval()

    model.to(dist_util.dev())
    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion
        )

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(
                    args.resume_checkpoint, map_location=dist_util.dev()
                )
            )

    # Needed for creating correct EMAs and fp16 parameters.
    # (Presumably also propagates rank 0's freshly loaded weights to the other
    # ranks -- verify against dist_util.sync_params.)
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
    )

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    # Class labels are only loaded when no rule is used (cfg-style training).
    class_cond = args.rule is None
    # The loader yields batch_size // encode_rep raw excerpts; when the encoded
    # batch comes back larger, labels are repeated encode_rep times below.
    loader_batch_size = args.batch_size // args.encode_rep
    data = load_data(
        data_dir=args.data_dir + "_train.csv",
        batch_size=loader_batch_size,
        class_cond=class_cond,
        image_size=args.pr_image_size,
        rule=args.rule,
    )
    # NOTE: val_data_dir only acts as an on/off switch; the validation split is
    # always read from data_dir + "_test.csv".
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.data_dir + "_test.csv",
            batch_size=loader_batch_size,
            class_cond=class_cond,
            image_size=args.pr_image_size,
            rule=args.rule,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        opt_checkpoint = bf.join(
            bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
        )
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
        )

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train", rule=None):
        """Process one batch: compute and log losses, accumulate gradients.

        Gradients are only accumulated when the loss requires grad, so calling
        this under ``th.no_grad()`` (validation) just computes and logs.
        """
        batch, extra = next(data_loader)
        if rule is None:
            labels = extra["y"].to(dist_util.dev())
        elif rule == 'chord_progression':
            labels_key = extra["key"].to(dist_util.dev())  # B x 1
            labels_chord = extra["chord"].to(dist_util.dev())  # B x 8
            labels = th.concat((labels_key, labels_chord), dim=-1)  # B x (1+8)
        else:
            labels = extra[rule].to(dist_util.dev())

        if args.get_KL:
            # need more sample diversity in a batch for classification
            batch = get_kl_input(batch, microbatch=args.microbatch_encode, model=embed_model,
                                 scale_factor=args.scale_factor, recombine=False)
            # The encoded batch may be larger than the label batch; expand the
            # labels to match.
            if batch.shape[0] != labels.shape[0]:
                labels = labels.repeat_interleave(args.encode_rep, dim=0)
        batch = batch.to(dist_util.dev())

        # Noisy inputs
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            # With no_high_noise, timesteps above 750 are reflected to 1000 - t,
            # so the classifier never sees t > 750 (presumably because the
            # decoder cannot decode samples that noisy).
            # NOTE(review): 750/1000 assume 1000 unrespaced diffusion steps.
            if args.no_high_noise:
                high_noise = t > 750
                t[high_noise] = 1000 - t[high_noise]
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
            split_microbatches(args.microbatch, batch, labels, t)
        ):
            acc = None
            if rule == 'chord_progression':
                key, chord = model(sub_batch, sub_t)
                # Column 0 is the key label, the rest are per-chord labels.
                # Index (instead of squeeze()) so a microbatch of size 1 keeps
                # its batch dimension.
                sub_labels_key = sub_labels[:, 0]
                sub_labels_chord = sub_labels[:, 1:].reshape(-1)
                chord = chord.reshape(-1, chord.shape[-1])
                loss_key = F.cross_entropy(key, sub_labels_key, reduction="none")
                loss_chord = F.cross_entropy(chord, sub_labels_chord, reduction="none")
                # reshape to B x n_chord (8), and average along n_chord
                loss_chord = loss_chord.reshape(sub_batch.shape[0], -1).mean(dim=-1)
                loss = (loss_key + loss_chord) / 2
                # Per-chord hits averaged per sample, so there is one value per
                # element of sub_t (log_loss_dict pairs them up).
                chord_hits = compute_top_k(chord, sub_labels_chord, k=1, reduction="none")
                acc = chord_hits.reshape(sub_batch.shape[0], -1).mean(dim=-1)
            elif rule is not None:
                logits = model(sub_batch, sub_t)
                loss = F.mse_loss(logits, sub_labels, reduction="none").mean(dim=-1)
            else:  # train for cfg condition
                logits = model(sub_batch, sub_t)
                loss = F.cross_entropy(logits, sub_labels, reduction="none")
                acc = compute_top_k(logits, sub_labels, k=1, reduction="none")

            losses = {f"{prefix}_loss": loss.detach()}
            if acc is not None:
                losses[f"{prefix}_acc@1"] = acc
            log_loss_dict(diffusion, sub_t, losses)
            del losses

            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                # Weight by the microbatch fraction so the accumulated gradient
                # matches the full-batch mean.
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    # Pre-initialised so the final save below works even if the loop body
    # never runs (iterations <= resume_step).
    global_step = resume_step
    for step in range(args.iterations - resume_step):
        global_step = step + resume_step
        logger.logkv("step", global_step)
        logger.logkv(
            "samples",
            (global_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, global_step / args.iterations)
        forward_backward_log(data, rule=args.rule)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val", rule=args.rule)
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (
            step
            and dist.get_rank() == 0
            and not global_step % args.save_interval
        ):
            logger.log("saving model...")
            save_model(mp_trainer, opt, global_step)

    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, global_step)
    dist.barrier()
def set_annealed_lr(opt, base_lr, frac_done):
    """Set every param group's learning rate to a linear decay of ``base_lr``.

    ``frac_done`` is the fraction of training completed; the new rate is
    ``base_lr * (1 - frac_done)`` and is applied to all groups in ``opt``.
    """
    remaining = 1 - frac_done
    annealed = base_lr * remaining
    for group in opt.param_groups:
        group.update(lr=annealed)
def save_model(mp_trainer, opt, step):
    """On rank 0 only, save model and optimizer state into the logger dir.

    Writes ``model{step:06d}.pt`` (the trainer's master params converted to a
    state dict) and ``opt{step:06d}.pt`` (the optimizer state dict). Other
    ranks return without writing anything.
    """
    if dist.get_rank() != 0:
        return
    model_state = mp_trainer.master_params_to_state_dict(mp_trainer.master_params)
    model_path = os.path.join(logger.get_dir(), f"model{step:06d}.pt")
    th.save(model_state, model_path)
    opt_path = os.path.join(logger.get_dir(), f"opt{step:06d}.pt")
    th.save(opt.state_dict(), opt_path)
def compute_top_k(logits, labels, k, reduction="mean"):
    """Top-k accuracy of ``logits`` against integer class ``labels``.

    Args:
        logits: scores whose last dimension indexes classes.
        labels: integer targets, one per row of ``logits``.
        k: how many of the highest-scoring classes count as a hit.
        reduction: ``"mean"`` returns the accuracy as a Python float;
            ``"none"`` returns a float tensor of per-row hits (0.0 or 1.0).

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"none"``.
    """
    if reduction not in ("mean", "none"):
        raise ValueError(f"unknown reduction: {reduction!r}")
    _, top_ks = th.topk(logits, k, dim=-1)
    # topk indices are distinct, so each row matches its label at most once.
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)
    if reduction == "mean":
        return hits.mean().item()
    return hits
def split_microbatches(microbatch, *args):
    """Yield aligned slices of ``args`` of at most ``microbatch`` items.

    The length of the first argument determines the batch size. If
    ``microbatch`` is non-positive (e.g. the conventional ``-1``) or at least
    the batch size, the arguments are yielded unchanged as a single tuple.
    ``None`` arguments are passed through as ``None`` in every slice.
    """
    total = len(args[0])
    # Any non-positive value disables splitting. A negative step other than -1
    # would otherwise yield nothing, and 0 would crash range().
    if microbatch <= 0 or microbatch >= total:
        yield tuple(args)
        return
    for start in range(0, total, microbatch):
        stop = start + microbatch
        yield tuple(None if x is None else x[start:stop] for x in args)
def create_argparser():
    """Build the command-line parser for classifier training.

    Script-specific defaults come first. Entries from
    ``classifier_and_diffusion_defaults()`` are merged on top and override
    any shared keys. All entries are then registered on the parser by
    ``add_dict_to_argparser``.
    """
    script_defaults = {
        "project": "music-guided-classifier",
        "dir": "",
        "data_dir": "",  # prefix; "_train.csv" / "_test.csv" are appended
        "val_data_dir": "",  # non-empty enables validation on data_dir + "_test.csv"
        "model": "DiT-B/8",  # DiT model names
        "in_channels": 4,
        "noised": True,
        "no_high_noise": False,  # reflect sampled timesteps > 750 to 1000 - t
        "iterations": 150000,
        "lr": 3e-4,
        "weight_decay": 0.0,
        "anneal_lr": False,
        "batch_size": 4,
        # repetition factor for encoded excerpts: the loader batch is
        # batch_size // encode_rep and labels are repeated encode_rep times
        # when the encoded batch is larger than the label batch
        "encode_rep": 1,
        "microbatch": -1,  # -1: no microbatch splitting
        "schedule_sampler": "uniform",
        "resume_checkpoint": "",
        "log_interval": 10,
        "eval_interval": 5,
        "save_interval": 10000,
        "get_KL": True,
        "scale_factor": 1.0,
        "embed_model_name": "kl/f8-all-onset",
        "embed_model_ckpt": "taming-transformers/checkpoints/all_onset/epoch_14.ckpt",
        "microbatch_encode": -1,
        "pr_image_size": 1024,
        "rule": None,
        "num_classes": 9,  # number of outputs from classifier
        "training": False,  # not training diffusion
        "port": None,  # whether to use fixed port for ngc
    }
    defaults = {**script_defaults, **classifier_and_diffusion_defaults()}
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)
    return parser
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run classifier training.
    main()