# Source: "Upload DiffAE" by soumickmj (commit c2ced9d, verified).
import copy
import numpy as np
import torch
from pytorch_lightning.callbacks import *
from torch.optim.optimizer import Optimizer
from transformers import PreTrainedModel
from .DiffAEConfig import DiffAEConfig
from .DiffAE_support import *
class DiffAE(PreTrainedModel):
    """Diffusion autoencoder (DiffAE) wrapped as a Hugging Face ``PreTrainedModel``.

    Holds the trainable network (``self.net``), a frozen EMA copy
    (``self.ema_net``) and the diffusion samplers built from the DiffAE
    configuration object assembled in ``__init__``.

    NOTE(review): ``training_step``, ``validation_step``, ``test_step``,
    ``predict_step``, ``configure_optimizers`` and the ``on_*`` hooks follow the
    PyTorch Lightning ``LightningModule`` API and use ``self.log_dict``,
    ``self.global_rank`` and ``self.img_logger``, none of which are defined in
    this class or by ``PreTrainedModel``. They presumably only work when driven
    through a Lightning wrapper -- confirm before training with this class
    directly. ``amp`` (used in ``forward``/``training_step``) is assumed to come
    from ``from .DiffAE_support import *`` -- confirm.
    """

    config_class = DiffAEConfig

    def __init__(self, config):
        """Build the networks and samplers from a ``DiffAEConfig``.

        Args:
            config: ``DiffAEConfig`` whose attributes override the defaults
                returned by ``ukbb_autoenc(n_latents=config.latent_dim)``.
        """
        super().__init__(config)
        conf = ukbb_autoenc(n_latents=config.latent_dim)
        # Every attribute of the supplied config overrides the default DiffAE params.
        conf.__dict__.update(**vars(config))
        if config.test_with_TEval:
            # Use the evaluation step count for both inversion and rendering.
            conf.T_inv = conf.T_eval
            conf.T_step = conf.T_eval
        conf.fp16 = config.ampmode not in ["32", "32-true"]
        conf.refresh_values()
        # NOTE(review): result is discarded; kept in case make_model_conf()
        # mutates `conf` as a side effect -- confirm and drop if it does not.
        conf.make_model_conf()
        self.config = config
        self.conf = conf

        self.net = conf.make_model_conf().make_model()
        # Frozen copy of the network, updated with `ema(...)` in on_train_batch_end.
        self.ema_net = copy.deepcopy(self.net)
        self.ema_net.requires_grad_(False)
        self.ema_net.eval()

        model_size = sum(param.data.nelement() for param in self.net.parameters())
        # NOTE: divides by 1024**2, so the printed "M" is 2**20 parameters.
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
        # Timestep sampler shared by the image model and the latent model.
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        # Fixed initial noise for consistent sampling; a buffer, so it moves
        # with the module on `.to(...)` and is saved in the state dict.
        self.register_buffer('x_T', torch.randn(conf.sample_size, conf.in_channels, *conf.input_shape))

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            # Load on CPU, like the pretrain checkpoint above, so stats saved
            # from a GPU run also load on CPU-only hosts; the registered
            # buffers move with the module afterwards.
            state = torch.load(conf.latent_infer_path, map_location='cpu')
            self.conds = state['conds']
            self.register_buffer('conds_mean', state['conds_mean'][None, :])
            self.register_buffer('conds_std', state['conds_std'][None, :])
        else:
            self.conds_mean = None
            self.conds_std = None

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, move it to CPU and return it as a numpy array.

        Half-precision tensors (bfloat16/float16) are upcast to float32 first;
        numpy has no bfloat16 dtype.
        """
        tensor = tensor.detach().cpu()
        if tensor.dtype in (torch.bfloat16, torch.float16):
            tensor = tensor.to(dtype=torch.float32)
        return tensor.numpy()

    def normalise(self, cond):
        """Z-normalise ``cond`` with the stored ``conds_mean`` / ``conds_std``.

        Only usable when ``conf.latent_infer_path`` was set at construction;
        otherwise the statistics are ``None``.
        """
        return (cond - self.conds_mean.to(self.device)) / self.conds_std.to(self.device)

    def denormalise(self, cond):
        """Invert :meth:`normalise` using the stored latent statistics."""
        return (cond * self.conds_std.to(self.device)) + self.conds_mean.to(self.device)

    def sample(self, N, device, T=None, T_latent=None):
        """Generate ``N`` images unconditionally with the EMA network.

        Args:
            N: number of samples to draw.
            device: device on which the initial noise is created.
            T: number of image-diffusion steps; ``None`` uses ``self.eval_sampler``.
            T_latent: number of latent-diffusion steps; only read when ``T`` is given.

        Returns:
            The rendered images mapped through ``(x + 1) / 2`` (presumably from
            [-1, 1] to [0, 1] -- confirm against the sampler's output range).
        """
        if T is None:
            sampler = self.eval_sampler
            latent_sampler = self.latent_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
            latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler()
        noise = torch.randn(N, self.conf.in_channels, *self.conf.input_shape, device=device)
        pred_img = render_uncondition(
            self.conf,
            self.ema_net,
            noise,
            sampler=sampler,
            latent_sampler=latent_sampler,
            conds_mean=self.conds_mean,
            conds_std=self.conds_std,
        )
        return (pred_img + 1) / 2

    def render(self, noise, cond=None, T=None, use_ema=True):
        """Decode ``noise`` into images, conditioned on ``cond`` when given.

        Args:
            noise: starting noise / stochastic latent for the sampler.
            cond: semantic latent; ``None`` renders unconditionally (no latent sampler).
            T: number of diffusion steps; ``None`` uses ``self.eval_sampler``.
            use_ema: use ``ema_net`` (default) instead of ``net``.

        Returns:
            Rendered images mapped through ``(x + 1) / 2``.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        model = self.ema_net if use_ema else self.net
        if cond is not None:
            pred_img = render_condition(self.conf, model, noise, sampler=sampler, cond=cond)
        else:
            pred_img = render_uncondition(self.conf, model, noise, sampler=sampler, latent_sampler=None)
        return (pred_img + 1) / 2

    def encode(self, x, use_ema=True):
        """Return the semantic latent of ``x`` from the network's encoder."""
        assert self.conf.model_type.has_autoenc()
        model = self.ema_net if use_ema else self.net
        return model.encoder.forward(x)

    def encode_stochastic(self, x, cond, T=None, use_ema=True):
        """Invert ``x`` to its stochastic latent with the DDIM reverse loop.

        Args:
            x: input images.
            cond: semantic latent passed to the model as ``cond``.
            T: number of diffusion steps; ``None`` uses ``self.eval_sampler``.
            use_ema: use ``ema_net`` (default) instead of ``net``.

        Returns:
            The ``'sample'`` entry of ``ddim_reverse_sample_loop``'s output.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        model = self.ema_net if use_ema else self.net
        out = sampler.ddim_reverse_sample_loop(model, x, model_kwargs={'cond': cond})
        return out['sample']

    def forward(self, x_start=None, noise=None, ema_model: bool = False):
        """Run ``self.eval_sampler.sample`` with autocast disabled.

        The sample shape comes from ``noise`` when given, otherwise from ``x_start``.
        """
        with amp.autocast(False):
            model = self.ema_net if ema_model else self.net
            return self.eval_sampler.sample(
                model=model,
                noise=noise,
                x_start=x_start,
                shape=noise.shape if noise is not None else x_start.shape,
            )

    def is_last_accum(self, batch_idx):
        """Return True if ``batch_idx`` closes a gradient-accumulation cycle.

        With ``conf.accum_batches > 1``, this tells whether the optimizer
        performs its ``step`` in this iteration.
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch (no optimizer step here).

        In dataset-infer modes, ``batch[0]`` holds precomputed latents.
        Otherwise the images are read from ``batch['inp']['data']``.
        """
        with amp.autocast(False):
            if self.conf.train_mode.require_dataset_infer():
                # This mode has a pre-calculated cond.
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = self.normalise(cond)
                batch_size = cond.shape[0]
            else:
                x_start = batch['inp']['data']
                batch_size = x_start.shape[0]

            if self.conf.train_mode == TrainMode.diffusion:
                # Main training mode: image diffusion.
                # Timesteps are drawn via the torch-based T_sampler because
                # numpy-seeded sampling made the sampled t's correlated.
                t, weight = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(model=self.net, x_start=x_start, t=t)
            elif self.conf.train_mode.is_latent_diffusion():
                # Train only the latent diffusion model on the conds.
                t, weight = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.net.latent_net, x_start=cond, t=t)
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss'],
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            loss_dict = {"train_loss": loss}
            for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
                if key in losses:
                    loss_dict[f'train_{key}'] = losses[key].mean()
            self.log_dict(loss_dict, on_step=True, on_epoch=True, reduce_fx="mean", sync_dist=True, batch_size=batch_size)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
        """Update the EMA weights after the step that ran ``optimizer.step()``."""
        if self.is_last_accum(batch_idx):
            # Apply EMA only on the last gradient-accumulation step.
            if self.conf.train_mode == TrainMode.latent_diffusion:
                # Only the latent net is trained, so only its EMA is updated.
                ema(self.net.latent_net, self.ema_net.latent_net, self.conf.ema_decay)
            else:
                ema(self.net, self.ema_net, self.conf.ema_decay)

    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
        """Clip the gradient norm of all optimizer params to ``conf.grad_clip`` (if > 0).

        Clipping happens in this hook to work around the fp16 + clip-grad-norm
        problem with PyTorch Lightning.
        """
        if self.conf.grad_clip > 0:
            params = [p for group in optimizer.param_groups for p in group['params']]
            torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)

    # Validation
    def validation_step(self, batch, batch_idx):
        """Reconstruct the batch with both the EMA and base nets; log SSIM and images.

        ``val_loss`` is logged as the negative EMA SSIM.
        NOTE(review): if ``config.test_emb_only`` is set, ``inference_pass``
        returns no prediction and this step fails -- confirm it is never
        enabled during validation.
        """
        data = batch['inp']['data']
        _, prediction_ema = self.inference_pass(data, T_inv=self.conf.T_eval, T_step=self.conf.T_eval, use_ema=True)
        _, prediction_base = self.inference_pass(data, T_inv=self.conf.T_eval, T_step=self.conf.T_eval, use_ema=False)
        # Same (x + 1) / 2 rescaling as applied to the rendered predictions.
        inp = (data.cpu() + 1) / 2
        _, val_ssim_ema = self._eval_prediction(inp, prediction_ema)
        _, val_ssim_base = self._eval_prediction(inp, prediction_base)
        self.log_dict({"val_ssim_ema": val_ssim_ema, "val_ssim_base": val_ssim_base, "val_loss": -val_ssim_ema}, on_step=True, on_epoch=True, reduce_fx="mean", sync_dist=True, batch_size=data.shape[0])
        self.img_logger("val_ema", batch_idx, inp, prediction_ema)
        self.img_logger("val_base", batch_idx, inp, prediction_base)

    def _eval_prediction(self, inp, prediction):
        """Convert ``prediction`` to numpy and compute its SSIM against ``inp``.

        When ``config.grey2RGB`` is 0 or 2, only channel index 1 of both
        tensors is compared (presumably the channels are replicated grey --
        confirm).

        Returns:
            ``(prediction_as_numpy, ssim)``.
        """
        prediction = self._to_numpy(prediction)
        if self.config.grey2RGB in [0, 2]:
            inp = inp[:, 1, ...].unsqueeze(1)
            prediction = np.expand_dims(prediction[:, 1, ...], axis=1)
        val_ssim = getSSIM(self._to_numpy(inp), prediction, data_range=1)
        return prediction, val_ssim

    def inference_pass(self, inp, T_inv, T_step, use_ema=True):
        """Encode ``inp`` and, unless ``config.test_emb_only``, reconstruct it.

        The same network (EMA or base, per ``use_ema``) is used for semantic
        encoding, DDIM inversion and rendering, so the inversion stays
        consistent with the decoder.

        Returns:
            ``(semantic_latent, prediction)``; ``prediction`` is ``None`` in
            embedding-only mode.
        """
        semantic_latent = self.encode(inp, use_ema=use_ema)
        if self.config.test_emb_only:
            return semantic_latent, None
        stochastic_latent = self.encode_stochastic(inp, semantic_latent, T=T_inv, use_ema=use_ema)
        prediction = self.render(stochastic_latent, semantic_latent, T=T_step, use_ema=use_ema)
        return semantic_latent, prediction

    # Testing
    def test_step(self, batch, batch_idx):
        """Return ``(embedding_as_numpy, reconstruction)`` for the batch."""
        emb, recon = self.inference_pass(batch['inp']['data'], T_inv=self.conf.T_inv, T_step=self.conf.T_step, use_ema=self.config.test_ema)
        return self._to_numpy(emb), recon

    # Prediction
    def predict_step(self, batch, batch_idx):
        """Return the (EMA) semantic embedding of the batch as a numpy array."""
        return self._to_numpy(self.encode(batch['inp']['data']))

    def configure_optimizers(self):
        """Create the Adam/AdamW optimizer and an optional per-step warmup scheduler."""
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.net.parameters(), lr=self.conf.lr, weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.net.parameters(), lr=self.conf.lr, weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out = {'optimizer': optim}
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=WarmupLR(self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """Return this rank's contiguous slice of ``x`` along the batch dimension.

        Args:
            x: tensor of shape ``(n, c)``.

        Returns:
            Tensor of shape ``(n // world_size, c)``. Trailing rows that do not
            divide evenly across ranks are dropped.
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]