""" Ability to train vq-vae and prior First try for random inputs Then from maestros """ import sys import fire import warnings import numpy as np import torch as t import jukebox.utils.dist_adapter as dist from torch.nn.parallel import DistributedDataParallel from jukebox.hparams import setup_hparams from jukebox.make_models import make_vqvae, make_prior, restore_opt, save_checkpoint from jukebox.utils.logger import init_logging from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess from jukebox.utils.torch_utils import zero_grad, count_parameters from jukebox.utils.dist_utils import print_once, allreduce, allgather from jukebox.utils.ema import CPUEMA, FusedEMA, EMA from jukebox.utils.fp16 import FP16FusedAdam, FusedAdam, LossScalar, clipped_grad_scale, backward from jukebox.data.data_processor import DataProcessor def prepare_aud(x, hps): x = audio_postprocess(x.detach().contiguous(), hps) return allgather(x) def log_aud(logger, tag, x, hps): logger.add_audios(tag, prepare_aud(x, hps), hps.sr, max_len=hps.max_len, max_log=hps.max_log) logger.flush() def log_labels(logger, labeller, tag, y, hps): y = y.cpu().numpy() txt = '' for item in range(y.shape[0]): description = labeller.describe_label(y[item]) artist, genre, lyrics = description['artist'], description['genre'], description['lyrics'] txt += f'{item} artist:{artist}, genre:{genre}, lyrics:{lyrics}\n' logger.add_text(tag, txt) logger.flush() def get_ddp(model, hps): rank = dist.get_rank() local_rank = rank % 8 ddp = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False, bucket_cap_mb=hps.bucket) return ddp def get_ema(model, hps): mu = hps.mu or (1. - (hps.bs * hps.ngpus/8.)/1000) ema = None if hps.ema and hps.train: if hps.cpu_ema: if dist.get_rank() == 0: print("Using CPU EMA") ema = CPUEMA(model.parameters(), mu=mu, freq=hps.cpu_ema_freq) elif hps.ema_fused: ema = FusedEMA(model.parameters(), mu=mu) else: ema = EMA(model.parameters(), mu=mu) return ema def get_lr_scheduler(opt, hps): def lr_lambda(step): if hps.lr_use_linear_decay: lr_scale = hps.lr_scale * min(1.0, step / hps.lr_warmup) decay = max(0.0, 1.0 - max(0.0, step - hps.lr_start_linear_decay) / hps.lr_decay) if decay == 0.0: if dist.get_rank() == 0: print("Reached end of training") return lr_scale * decay else: return hps.lr_scale * (hps.lr_gamma ** (step // hps.lr_decay)) * min(1.0, step / hps.lr_warmup) shd = t.optim.lr_scheduler.LambdaLR(opt, lr_lambda) return shd def get_optimizer(model, hps): # Optimizer betas = (hps.beta1, hps.beta2) if hps.fp16_opt: opt = FP16FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps) else: opt = FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps) # lr scheduler shd = get_lr_scheduler(opt, hps) restore_path = hps.restore_prior if hps.prior else hps.restore_vqvae restore_opt(opt, shd, restore_path) # fp16 dynamic loss scaler scalar = None if hps.fp16: rank = dist.get_rank() local_rank = rank % 8 scalar = LossScalar(hps.fp16_loss_scale, scale_factor=2 ** (1./hps.fp16_scale_window)) if local_rank == 0: print(scalar.__dict__) zero_grad(model) return opt, shd, scalar def log_inputs(orig_model, logger, x_in, y, x_out, hps, tag="train"): print(f"Logging {tag} inputs/ouputs") log_aud(logger, f'{tag}_x_in', x_in, hps) log_aud(logger, f'{tag}_x_out', x_out, hps) bs = x_in.shape[0] if hps.prior: if hps.labels: log_labels(logger, orig_model.labeller, f'{tag}_y_in', allgather(y.cuda()), 
def get_optimizer(model, hps):
    # Optimizer
    betas = (hps.beta1, hps.beta2)
    if hps.fp16_opt:
        opt = FP16FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps)
    else:
        opt = FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps)

    # lr scheduler
    shd = get_lr_scheduler(opt, hps)

    # Restore optimizer/scheduler state alongside the model checkpoint
    restore_path = hps.restore_prior if hps.prior else hps.restore_vqvae
    restore_opt(opt, shd, restore_path)

    # fp16 dynamic loss scaler
    scalar = None
    if hps.fp16:
        rank = dist.get_rank()
        local_rank = rank % 8
        scalar = LossScalar(hps.fp16_loss_scale, scale_factor=2 ** (1./hps.fp16_scale_window))
        if local_rank == 0:
            print(scalar.__dict__)

    zero_grad(model)
    return opt, shd, scalar
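# For context: LossScalar above implements dynamic fp16 loss scaling. Below is
# a minimal sketch of the usual update rule (an assumption about its behavior,
# not jukebox's exact code; `_loss_scale_update_sketch` is hypothetical):
def _loss_scale_update_sketch(scale, overflow, scale_factor=2 ** (1. / 1000)):
    # The loss is multiplied by `scale` before backward so small fp16
    # gradients don't flush to zero. On overflow the step is skipped and the
    # scale shrinks sharply; otherwise it grows slowly back toward the
    # largest value that doesn't overflow.
    return scale * 0.5 if overflow else scale * scale_factor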
def log_inputs(orig_model, logger, x_in, y, x_out, hps, tag="train"):
    print(f"Logging {tag} inputs/outputs")
    log_aud(logger, f'{tag}_x_in', x_in, hps)
    log_aud(logger, f'{tag}_x_out', x_out, hps)
    bs = x_in.shape[0]
    if hps.prior:
        if hps.labels:
            log_labels(logger, orig_model.labeller, f'{tag}_y_in', allgather(y.cuda()), hps)
    else:
        zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
        x_ds = [orig_model.decode(zs_in[level:], start_level=level, bs_chunks=bs) for level in range(0, hps.levels)]
        for i in range(len(x_ds)):
            log_aud(logger, f'{tag}_x_ds_start_{i}', x_ds[i], hps)
    logger.flush()

def sample_prior(orig_model, ema, logger, x_in, y, hps):
    if ema is not None: ema.swap()
    orig_model.eval()
    x_in = x_in[:hps.bs_sample]
    bs = x_in.shape[0]
    zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
    assert len(zs_in) == hps.levels
    x_ds = [orig_model.decode(zs_in[level:], start_level=level, bs_chunks=bs) for level in range(0, hps.levels)]

    if not hps.labels:
        y = None
    elif hps.level == (hps.levels - 1):
        # Topmost level: labels in order
        y = y[:hps.bs_sample]  # t.ones((hps.bs_sample, 1), device=y.device, dtype=t.long) * dist.get_rank()
    else:
        # Other levels keep labels to match x_cond
        y = y[:hps.bs_sample]

    # Sample at temperature 1.0
    _, *z_conds = orig_model.encode(x_in, bs_chunks=bs)
    z = orig_model.sample(hps.bs_sample, z_conds=z_conds, y=y, fp16=False, temp=1.0)
    x_sample = orig_model.decode([z, *z_conds], bs_chunks=bs)
    log_aud(logger, 'sample_x_T1', x_sample, hps)
    if hps.prior and hps.labels:
        log_labels(logger, orig_model.labeller, f'sample_x_T1', allgather(y.cuda()), hps)

    # Recons
    for i in range(len(x_ds)):
        log_aud(logger, f'x_ds_start_{i}', x_ds[i], hps)
    orig_model.train()
    if ema is not None: ema.swap()
    logger.flush()

def evaluate(model, orig_model, logger, metrics, data_processor, hps):
    model.eval()
    orig_model.eval()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd")
    else:
        _print_keys = dict(l="loss", rl="recons_loss", sl="spectral_loss")

    with t.no_grad():
        for i, x in logger.get_range(data_processor.test_loader):
            if isinstance(x, (tuple, list)):
                x, y = x
            else:
                y = None

            x = x.to('cuda', non_blocking=True)
            if y is not None:
                y = y.to('cuda', non_blocking=True)

            x_in = x = audio_preprocess(x, hps)
            log_input_output = (i == 0)

            if hps.prior:
                forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
            else:
                forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

            x_out, loss, _metrics = model(x, **forw_kwargs)

            # Logging
            for key, val in _metrics.items():
                _metrics[key] = val.item()
            _metrics["loss"] = loss = loss.item()  # Call item() to free the graph

            # Average and log
            for key, val in _metrics.items():
                _metrics[key] = metrics.update(f"test_{key}", val, x.shape[0])

            with t.no_grad():
                if log_input_output:
                    log_inputs(orig_model, logger, x_in, y, x_out, hps)

            logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})

    for key, val in _metrics.items():
        logger.add_scalar(f"test_{key}", metrics.avg(f"test_{key}"))

    logger.close_range()
    return {key: metrics.avg(f"test_{key}") for key in _metrics.keys()}

def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage",
                           uc="used_curr", gn="gn", pn="pn", dk="dk")

    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                        scalar=scalar, fp16=hps.fp16, logger=logger)

        # Skip step if overflow
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None: shd.step()
        if ema is not None: ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Call item() to free the graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None: ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None: ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()

    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
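# For context: clipped_grad_scale, used in the opt.step call above, folds
# gradient clipping into the fp16 unscale. A sketch under the assumption that
# opt.step divides the gradients by the returned scale
# (`_clipped_grad_scale_sketch` is hypothetical):
def _clipped_grad_scale_sketch(grad_norm, max_norm, scale):
    # Gradients still carry the fp16 loss scale. If the gradient norm exceeds
    # max_norm, grow the divisor proportionally so the effective norm after
    # unscaling is clipped to max_norm.
    clip = grad_norm / max_norm
    if clip > 1:
        scale = clip * scale
    return scale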
def run(hps="teeny", port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)

    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)
        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                print('Train', ' '.join([f'{key}: {val:0.4f}' for key, val in train_metrics.items()]))
            dist.barrier()

        if hps.test:
            if ema: ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics, data_processor, hps)
            test_metrics['epoch'] = epoch
            if rank == 0:
                print('Ema', ' '.join([f'{key}: {val:0.4f}' for key, val in test_metrics.items()]))
            dist.barrier()
            if ema: ema.swap()
        dist.barrier()

if __name__ == '__main__':
    fire.Fire(run)
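# Example invocation (a sketch; flag values and paths are placeholders, see
# the repository README for the exact training recipes):
#   mpiexec -n 8 python jukebox/train.py --hps=small_vqvae --name=small_vqvae \
#       --sample_length=262144 --bs=4 --audio_files_dir=<audio_dir> \
#       --labels=False --train --aug_shift --aug_blend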