# File size: 12,919 Bytes
# 814c8cf
"""
Ability to train vq-vae and prior
First try for random inputs
Then from maestros
"""
import sys
import fire
import warnings
import numpy as np
import torch as t
import jukebox.utils.dist_adapter as dist
from torch.nn.parallel import DistributedDataParallel
from jukebox.hparams import setup_hparams
from jukebox.make_models import make_vqvae, make_prior, restore_opt, save_checkpoint
from jukebox.utils.logger import init_logging
from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
from jukebox.utils.torch_utils import zero_grad, count_parameters
from jukebox.utils.dist_utils import print_once, allreduce, allgather
from jukebox.utils.ema import CPUEMA, FusedEMA, EMA
from jukebox.utils.fp16 import FP16FusedAdam, FusedAdam, LossScalar, clipped_grad_scale, backward
from jukebox.data.data_processor import DataProcessor
def prepare_aud(x, hps):
    """Post-process an audio tensor and gather the result.

    ``x`` is detached and made contiguous, passed through
    ``audio_postprocess`` together with ``hps``, and the output of
    ``allgather`` on that result is returned.
    """
    # NOTE(review): allgather presumably collects the tensor from every
    # distributed rank -- confirm against jukebox.utils.dist_utils.
    detached = x.detach().contiguous()
    processed = audio_postprocess(detached, hps)
    return allgather(processed)
def log_aud(logger, tag, x, hps):
    """Write the audio in ``x`` to ``logger`` under ``tag`` and flush.

    The audio is first run through ``prepare_aud``; the sample rate and the
    ``max_len`` / ``max_log`` limits are read from ``hps``.
    """
    audio = prepare_aud(x, hps)
    logger.add_audios(
        tag,
        audio,
        hps.sr,
        max_len=hps.max_len,
        max_log=hps.max_log,
    )
    logger.flush()
def log_labels(logger, labeller, tag, y, hps):
    """Log a text description of every label row in ``y``.

    Args:
        logger: object exposing ``add_text(tag, text)`` and ``flush()``.
        labeller: object exposing ``describe_label(row)``, which must return
            a mapping with ``'artist'``, ``'genre'`` and ``'lyrics'`` keys.
        tag: tag under which the text is logged.
        y: tensor of labels; it is moved to CPU and converted to numpy, and
            one line is produced per entry along its first axis.
        hps: unused here; kept for a signature parallel to the other loggers.
    """
    rows = []
    for item, label in enumerate(y.cpu().numpy()):
        description = labeller.describe_label(label)
        artist, genre, lyrics = description['artist'], description['genre'], description['lyrics']
        rows.append(f'{item} artist:{artist}, genre:{genre}, lyrics:{lyrics}\n')
    # Build the text once with join instead of repeated string concatenation.
    logger.add_text(tag, ''.join(rows))
    logger.flush()
def get_ddp(model, hps):
    """Wrap ``model`` in ``DistributedDataParallel`` and return the wrapper.

    The device index is the global rank modulo 8, and the gradient bucket
    size (in MB) is taken from ``hps.bucket``.
    """
    # NOTE(review): the modulo 8 presumably encodes "8 GPUs per node" --
    # confirm before running on machines with a different GPU count.
    device_index = dist.get_rank() % 8
    return DistributedDataParallel(
        model,
        device_ids=[device_index],
        output_device=device_index,
        broadcast_buffers=False,
        bucket_cap_mb=hps.bucket,
    )
def get_ema(model, hps):
    """Build the EMA helper selected by ``hps``, or return ``None``.

    ``None`` is returned unless both ``hps.ema`` and ``hps.train`` are truthy.
    Otherwise ``hps.cpu_ema`` selects ``CPUEMA``, else ``hps.ema_fused``
    selects ``FusedEMA``, else plain ``EMA`` is used. The decay ``mu`` is
    ``hps.mu`` when truthy, else ``1 - (hps.bs * hps.ngpus / 8) / 1000``.
    """
    mu = hps.mu or (1. - (hps.bs * hps.ngpus/8.)/1000)

    if not (hps.ema and hps.train):
        return None

    if hps.cpu_ema:
        # Only rank 0 announces the choice, to avoid one line per process.
        if dist.get_rank() == 0:
            print("Using CPU EMA")
        return CPUEMA(model.parameters(), mu=mu, freq=hps.cpu_ema_freq)

    if hps.ema_fused:
        return FusedEMA(model.parameters(), mu=mu)

    return EMA(model.parameters(), mu=mu)
def get_lr_scheduler(opt, hps):
    """Create the ``LambdaLR`` schedule (warmup plus decay) for ``opt``.

    The multiplier applied to the optimizer's base learning rate is
    ``hps.lr_scale * warmup * decay`` where

    * ``warmup`` ramps linearly from 0 to 1 over ``hps.lr_warmup`` steps;
      a non-positive ``hps.lr_warmup`` disables warmup (factor 1.0) instead
      of raising ``ZeroDivisionError``;
    * ``decay`` is linear, reaching 0 after ``hps.lr_decay`` steps past
      ``hps.lr_start_linear_decay``, when ``hps.lr_use_linear_decay`` is set;
      otherwise it is ``hps.lr_gamma ** (step // hps.lr_decay)``.

    Args:
        opt: optimizer to schedule.
        hps: hyperparameters providing the ``lr_*`` fields named above.

    Returns:
        The ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    def warmup_factor(step):
        # Non-positive warmup length means "no warmup"; also avoids a
        # division by zero when warmup is configured as 0.
        if hps.lr_warmup <= 0:
            return 1.0
        return min(1.0, step / hps.lr_warmup)

    def lr_lambda(step):
        if hps.lr_use_linear_decay:
            lr_scale = hps.lr_scale * warmup_factor(step)
            decay = max(0.0, 1.0 - max(0.0, step - hps.lr_start_linear_decay) / hps.lr_decay)
            # A zero multiplier marks the end of the linear schedule.
            if decay == 0.0 and dist.get_rank() == 0:
                print("Reached end of training")
            return lr_scale * decay
        return hps.lr_scale * (hps.lr_gamma ** (step // hps.lr_decay)) * warmup_factor(step)

    return t.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
def get_optimizer(model, hps):
    """Build the optimizer, LR scheduler and optional fp16 loss scaler.

    Returns:
        ``(opt, shd, scalar)``: the Adam-style optimizer (``FP16FusedAdam``
        when ``hps.fp16_opt`` is set, else ``FusedAdam``), its scheduler from
        ``get_lr_scheduler``, and a ``LossScalar`` when ``hps.fp16`` is set
        (``None`` otherwise).
    """
    # Optimizer
    betas = (hps.beta1, hps.beta2)
    adam_cls = FP16FusedAdam if hps.fp16_opt else FusedAdam
    opt = adam_cls(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps)

    # lr scheduler
    shd = get_lr_scheduler(opt, hps)

    # NOTE(review): restore_opt presumably reloads optimizer/scheduler state
    # from the checkpoint path -- confirm in jukebox.make_models.
    restore_path = hps.restore_vqvae
    if hps.prior:
        restore_path = hps.restore_prior
    restore_opt(opt, shd, restore_path)

    # fp16 dynamic loss scaler
    scalar = None
    if hps.fp16:
        local_rank = dist.get_rank() % 8
        scalar = LossScalar(hps.fp16_loss_scale, scale_factor=2 ** (1./hps.fp16_scale_window))
        if local_rank == 0:
            print(scalar.__dict__)

    zero_grad(model)
    return opt, shd, scalar
def log_inputs(orig_model, logger, x_in, y, x_out, hps, tag="train"):
    """Log a batch's input and output audio, plus labels or reconstructions.

    ``x_in`` and ``x_out`` are always logged as ``{tag}_x_in`` and
    ``{tag}_x_out``. When ``hps.prior`` is set, the labels ``y`` are logged
    too if ``hps.labels`` is set. Otherwise ``x_in`` is encoded with
    ``orig_model`` and decoded once per start level in ``range(hps.levels)``,
    and each decoding is logged as ``{tag}_x_ds_start_{level}``.
    """
    print(f"Logging {tag} inputs/ouputs")
    log_aud(logger, f'{tag}_x_in', x_in, hps)
    log_aud(logger, f'{tag}_x_out', x_out, hps)
    batch = x_in.shape[0]

    if not hps.prior:
        codes = orig_model.encode(x_in, start_level=0, bs_chunks=batch)
        # Decode every level first, then log, matching the original order.
        decoded = [
            orig_model.decode(codes[level:], start_level=level, bs_chunks=batch)
            for level in range(0, hps.levels)
        ]
        for level, audio in enumerate(decoded):
            log_aud(logger, f'{tag}_x_ds_start_{level}', audio, hps)
    elif hps.labels:
        log_labels(logger, orig_model.labeller, f'{tag}_y_in', allgather(y.cuda()), hps)

    logger.flush()
def sample_prior(orig_model, ema, logger, x_in, y, hps):
    """Draw samples from the prior at temperature 1.0 and log them.

    The model is switched to eval mode (after ``ema.swap()`` if an EMA is
    given) for the duration of the call, then switched back to train mode
    and swapped again at the end.

    Args:
        orig_model: prior model exposing ``encode``, ``decode``, ``sample``
            and, when labels are used, ``labeller``.
        ema: EMA helper with ``swap()``, or ``None``.
        logger: logger passed to ``log_aud`` / ``log_labels``.
        x_in: input audio batch; only the first ``hps.bs_sample`` items are used.
        y: labels for ``x_in``; truncated to ``hps.bs_sample`` rows when
            ``hps.labels`` is set, otherwise ignored.
        hps: hyperparameters (``bs_sample``, ``levels``, ``labels``, ``prior``).
    """
    if ema is not None:
        ema.swap()
    orig_model.eval()

    x_in = x_in[:hps.bs_sample]
    bs = x_in.shape[0]
    zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
    assert len(zs_in) == hps.levels
    x_ds = [orig_model.decode(zs_in[level:], start_level=level, bs_chunks=bs) for level in range(0, hps.levels)]

    # Every level keeps the first bs_sample labels (aligned with x_in); the
    # original top-level and lower-level branches were identical.
    y = y[:hps.bs_sample] if hps.labels else None

    # Temp 1.0
    # This encode call omits start_level, unlike the one above; only the
    # codes after the first one are used as conditioning.
    _, *z_conds = orig_model.encode(x_in, bs_chunks=bs)
    z = orig_model.sample(hps.bs_sample, z_conds=z_conds, y=y, fp16=False, temp=1.0)
    x_sample = orig_model.decode([z, *z_conds], bs_chunks=bs)
    log_aud(logger, 'sample_x_T1', x_sample, hps)
    if hps.prior and hps.labels:
        log_labels(logger, orig_model.labeller, 'sample_x_T1', allgather(y.cuda()), hps)

    # Recons
    for level, audio in enumerate(x_ds):
        log_aud(logger, f'x_ds_start_{level}', audio, hps)

    orig_model.train()
    if ema is not None:
        ema.swap()
    logger.flush()
def evaluate(model, orig_model, logger, metrics, data_processor, hps):
    """Run one pass over the test loader and return averaged metrics.

    Both models are put in eval mode and the whole pass runs under
    ``torch.no_grad()``. Per-batch metrics are accumulated in ``metrics``
    under ``test_<key>`` names, and the first batch's inputs/outputs are
    logged under the ``"test"`` tag so they do not collide with the
    training-time logs.

    Args:
        model: callable model used for the forward pass.
        orig_model: underlying model, passed to ``log_inputs``.
        logger: logger providing ``get_range``, ``set_postfix``,
            ``add_scalar`` and ``close_range``.
        metrics: accumulator providing ``update(name, value, n)`` and ``avg(name)``.
        data_processor: object whose ``test_loader`` yields ``x`` or ``(x, y)``.
        hps: hyperparameters (``prior``, ``fp16``, ``loss_fn``, ...).

    Returns:
        Dict mapping each metric key to ``metrics.avg("test_<key>")``; empty
        if the test loader yielded no batches.
    """
    model.eval()
    orig_model.eval()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd")
    else:
        _print_keys = dict(l="loss", rl="recons_loss", sl="spectral_loss")

    # Defined up front so an empty loader yields {} instead of a NameError.
    _metrics = {}
    with t.no_grad():
        for i, x in logger.get_range(data_processor.test_loader):
            if isinstance(x, (tuple, list)):
                x, y = x
            else:
                y = None

            x = x.to('cuda', non_blocking=True)
            if y is not None:
                y = y.to('cuda', non_blocking=True)

            x_in = x = audio_preprocess(x, hps)
            # Only the first batch has its audio logged.
            log_input_output = (i == 0)

            if hps.prior:
                forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
            else:
                forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

            x_out, loss, _metrics = model(x, **forw_kwargs)

            # Logging: convert tensors to Python scalars.
            for key, val in _metrics.items():
                _metrics[key] = val.item()
            _metrics["loss"] = loss = loss.item()  # Rebind to drop the tensor reference

            # Average and log
            for key, val in _metrics.items():
                _metrics[key] = metrics.update(f"test_{key}", val, x.shape[0])

            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps, tag="test")

            logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})

    for key in _metrics:
        logger.add_scalar(f"test_{key}", metrics.avg(f"test_{key}"))

    logger.close_range()
    return {key: metrics.avg(f"test_{key}") for key in _metrics}
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    """Run one pass over the training loader.

    For each batch: forward, backward, skip the step on loss/gradient
    overflow or when the gradient norm exceeds ``hps.ignore_grad_norm``
    (if that threshold is > 0), otherwise step the optimizer, scheduler and
    EMA, record metrics, and periodically checkpoint, sample and log audio.
    When the scheduler's next learning rate is exactly 0.0, training is
    considered finished: all ranks synchronise on a barrier and the process
    exits via ``sys.exit()``.

    Args:
        model: callable (distributed) model used for the forward pass.
        orig_model: underlying model used for zeroing grads, checkpoints,
            sampling and logging.
        opt: optimizer whose ``step`` accepts a ``scale`` keyword.
        shd: LR scheduler or ``None``.
        scalar: fp16 loss scaler or ``None``; forwarded to ``backward``.
        ema: EMA helper or ``None``.
        logger: logger providing ``get_range``, ``step``, ``iters``,
            ``add_scalar``, ``set_postfix`` and ``close_range``.
        metrics: accumulator providing ``update(name, value, n)`` and ``avg(name)``.
        data_processor: object whose ``train_loader`` yields ``x`` or ``(x, y)``.
        hps: hyperparameters.

    Returns:
        Dict mapping each metric key of the last processed batch to
        ``metrics.avg(key)``; empty if the loader yielded no batches.
    """
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    # Defined up front so an empty loader yields {} instead of a NameError.
    _metrics = {}
    for _, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        # Evaluated before logger.step(), i.e. on the pre-step iteration count.
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(
            loss=loss, params=list(model.parameters()),
            scalar=scalar, fp16=hps.fp16, logger=logger)

        # Skip step if overflow
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None:
            shd.step()
        if ema is not None:
            ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        # A learning rate of exactly zero is the end-of-schedule signal.
        finished_training = (next_lr == 0.0)

        # Logging: convert tensors to Python scalars.
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None:
                    ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None:
                    ema.swap()

        # Sample
        with t.no_grad():
            # range supports membership tests directly; no list is needed.
            if (logger.iters % 12000) in range(1, 1 + hps.iters_before_update) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            # sys.exit() rather than the site-provided exit(), which is not
            # guaranteed to exist in every interpreter configuration.
            sys.exit()

    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics}
def run(hps="teeny", port=29500, **kwargs):
    """Entry point: set up distributed state, data, models and run epochs.

    Args:
        hps: hyperparameter spec forwarded to ``setup_hparams``.
        port: port forwarded to ``setup_dist_from_mpi``.
        **kwargs: overrides forwarded to ``setup_hparams``.

    For every epoch in ``range(hps.curr_epoch, hps.epochs)`` this trains when
    ``hps.train`` is set and evaluates when ``hps.test`` is set; the
    evaluation is bracketed by ``ema.swap()`` calls when an EMA exists.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    def _format(stats):
        # One "key: value" pair per metric, four decimals each.
        return ' '.join([f'{key}: {val:0.4f}' for key,val in stats.items()])

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    # Sample batch size and worker count both follow the batch size.
    batch_size = hps.bs
    hps.bs_sample = batch_size
    hps.nworkers = batch_size

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    model = vqvae
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    is_main = (rank == 0)
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)

        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if is_main:
                print('Train', _format(train_metrics))
            dist.barrier()

        if hps.test:
            if ema:
                ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics, data_processor, hps)
            test_metrics['epoch'] = epoch
            if is_main:
                print('Ema', _format(test_metrics))
            dist.barrier()
            if ema:
                ema.swap()

        dist.barrier()
# Script entry point: hand `run` to python-fire when executed directly.
if __name__ == "__main__":
    fire.Fire(run)
|