|
import copy
import json
import os
import re
from contextlib import nullcontext

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch import nn
from torch.cuda import amp
from torch.distributions import Categorical
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import ConcatDataset, TensorDataset
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from config import *
from dataset import *
from dist_utils import *
from lmdb_writer import *
from metrics import *
from renderer import *
|
|
|
|
|
class LitModel(pl.LightningModule): |
|
def __init__(self, conf: TrainConfig): |
|
super().__init__() |
|
assert conf.train_mode != TrainMode.manipulate |
|
if conf.seed is not None: |
|
pl.seed_everything(conf.seed) |
|
|
|
self.save_hyperparameters(conf.as_dict_jsonable()) |
|
|
|
self.conf = conf |
|
|
|
self.model = conf.make_model_conf().make_model() |
|
self.ema_model = copy.deepcopy(self.model) |
|
self.ema_model.requires_grad_(False) |
|
self.ema_model.eval() |
|
|
|
model_size = 0 |
|
for param in self.model.parameters(): |
|
model_size += param.data.nelement() |
|
        print('Model params: %.2f M' % (model_size / 1e6))
|
|
|
self.sampler = conf.make_diffusion_conf().make_sampler() |
|
self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler() |
|
|
|
|
|
self.T_sampler = conf.make_T_sampler() |
|
|
|
if conf.train_mode.use_latent_net(): |
|
self.latent_sampler = conf.make_latent_diffusion_conf( |
|
).make_sampler() |
|
self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf( |
|
).make_sampler() |
|
else: |
|
self.latent_sampler = None |
|
self.eval_latent_sampler = None |
|
|
|
|
|
self.register_buffer( |
|
'x_T', |
|
torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size)) |
|
|
|
if conf.pretrain is not None: |
|
print(f'loading pretrain ... {conf.pretrain.name}') |
|
state = torch.load(conf.pretrain.path, map_location='cpu') |
|
print('step:', state['global_step']) |
|
self.load_state_dict(state['state_dict'], strict=False) |
|
|
|
if conf.latent_infer_path is not None: |
|
print('loading latent stats ...') |
|
state = torch.load(conf.latent_infer_path) |
|
self.conds = state['conds'] |
|
self.register_buffer('conds_mean', state['conds_mean'][None, :]) |
|
self.register_buffer('conds_std', state['conds_std'][None, :]) |
|
        else:
            # conds is inferred lazily in train_dataloader() when it is None
            self.conds = None
            self.conds_mean = None
            self.conds_std = None
|
|
|
def normalize(self, cond): |
|
cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to( |
|
self.device) |
|
return cond |
|
|
|
def denormalize(self, cond): |
|
cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to( |
|
self.device) |
|
return cond |
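    # Note: normalize/denormalize are inverses of each other; they z-normalize the semantic
    # latent with the dataset-level statistics loaded from conf.latent_infer_path, so
    # denormalize(normalize(cond)) recovers cond up to floating-point error.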
|
|
|
def sample(self, N, device, T=None, T_latent=None): |
|
if T is None: |
|
sampler = self.eval_sampler |
|
latent_sampler = self.latent_sampler |
|
else: |
|
sampler = self.conf._make_diffusion_conf(T).make_sampler() |
|
latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler() |
|
|
|
noise = torch.randn(N, |
|
3, |
|
self.conf.img_size, |
|
self.conf.img_size, |
|
device=device) |
|
pred_img = render_uncondition( |
|
self.conf, |
|
self.ema_model, |
|
noise, |
|
sampler=sampler, |
|
latent_sampler=latent_sampler, |
|
conds_mean=self.conds_mean, |
|
conds_std=self.conds_std, |
|
) |
|
pred_img = (pred_img + 1) / 2 |
|
return pred_img |
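    # Usage sketch (assumed call site, not part of the training loop): after loading a
    # checkpoint, `imgs = lit.sample(N=16, device='cuda')` returns 16 unconditional samples
    # in [0, 1]; passing smaller T / T_latent trades sample quality for speed by building
    # samplers with fewer timesteps.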
|
|
|
def render(self, noise, cond=None, T=None): |
|
if T is None: |
|
sampler = self.eval_sampler |
|
else: |
|
sampler = self.conf._make_diffusion_conf(T).make_sampler() |
|
|
|
if cond is not None: |
|
pred_img = render_condition(self.conf, |
|
self.ema_model, |
|
noise, |
|
sampler=sampler, |
|
cond=cond) |
|
else: |
|
pred_img = render_uncondition(self.conf, |
|
self.ema_model, |
|
noise, |
|
sampler=sampler, |
|
latent_sampler=None) |
|
pred_img = (pred_img + 1) / 2 |
|
return pred_img |
|
|
|
def encode(self, x): |
|
|
|
assert self.conf.model_type.has_autoenc() |
|
cond = self.ema_model.encoder.forward(x) |
|
return cond |
|
|
|
def encode_stochastic(self, x, cond, T=None): |
|
if T is None: |
|
sampler = self.eval_sampler |
|
else: |
|
sampler = self.conf._make_diffusion_conf(T).make_sampler() |
|
out = sampler.ddim_reverse_sample_loop(self.ema_model, |
|
x, |
|
model_kwargs={'cond': cond}) |
|
return out['sample'] |
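    # Usage sketch (assumed tensors; `lit` is a hypothetical LitModel instance): for a batch
    # x in [-1, 1],
    #   cond = lit.encode(x)                  # semantic latent from the EMA encoder
    #   x_T = lit.encode_stochastic(x, cond)  # stochastic code via the DDIM reverse loop
    #   rec = lit.render(x_T, cond)           # reconstruction in [0, 1] (render() is defined above)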
|
|
|
def forward(self, noise=None, x_start=None, ema_model: bool = False): |
|
with amp.autocast(False): |
|
if ema_model: |
|
model = self.ema_model |
|
else: |
|
model = self.model |
|
gen = self.eval_sampler.sample(model=model, |
|
noise=noise, |
|
x_start=x_start) |
|
return gen |
|
|
|
def setup(self, stage=None) -> None: |
|
""" |
|
make datasets & seeding each worker separately |
|
""" |
|
|
|
|
|
if self.conf.seed is not None: |
|
seed = self.conf.seed * get_world_size() + self.global_rank |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
print('local seed:', seed) |
|
|
|
|
|
self.train_data = self.conf.make_dataset() |
|
print('train data:', len(self.train_data)) |
|
self.val_data = self.train_data |
|
print('val data:', len(self.val_data)) |
|
|
|
def _train_dataloader(self, drop_last=True): |
|
""" |
|
really make the dataloader |
|
""" |
|
|
|
|
|
conf = self.conf.clone() |
|
conf.batch_size = self.batch_size |
|
|
|
dataloader = conf.make_loader(self.train_data, |
|
shuffle=True, |
|
drop_last=drop_last) |
|
return dataloader |
|
|
|
def train_dataloader(self): |
|
""" |
|
return the dataloader, if diffusion mode => return image dataset |
|
if latent mode => return the inferred latent dataset |
|
""" |
|
print('on train dataloader start ...') |
|
if self.conf.train_mode.require_dataset_infer(): |
|
if self.conds is None: |
|
|
|
|
|
self.conds = self.infer_whole_dataset() |
|
|
|
|
|
self.conds_mean.data = self.conds.float().mean(dim=0, |
|
keepdim=True) |
|
self.conds_std.data = self.conds.float().std(dim=0, |
|
keepdim=True) |
|
print('mean:', self.conds_mean.mean(), 'std:', |
|
self.conds_std.mean()) |
|
|
|
|
|
conf = self.conf.clone() |
|
conf.batch_size = self.batch_size |
|
data = TensorDataset(self.conds) |
|
return conf.make_loader(data, shuffle=True) |
|
else: |
|
return self._train_dataloader() |
|
|
|
@property |
|
def batch_size(self): |
|
""" |
|
local batch size for each worker |
|
""" |
|
ws = get_world_size() |
|
assert self.conf.batch_size % ws == 0 |
|
return self.conf.batch_size // ws |
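    # Example (hypothetical numbers): conf.batch_size = 64 on 4 workers gives a local batch
    # of 16 per worker; the assert rejects global batch sizes that don't divide evenly.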
|
|
|
@property |
|
def num_samples(self): |
|
""" |
|
(global) batch size * iterations |
|
""" |
|
|
|
|
|
return self.global_step * self.conf.batch_size_effective |
|
|
|
def is_last_accum(self, batch_idx): |
|
""" |
|
is it the last gradient accumulation loop? |
|
used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not |
|
""" |
|
return (batch_idx + 1) % self.conf.accum_batches == 0 |
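    # Example (hypothetical setting): with conf.accum_batches = 4, this returns True only for
    # batch_idx 3, 7, 11, ..., i.e. the iterations where the optimizer actually steps and the
    # EMA update / sampling / eval hooks below should run.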
|
|
|
def infer_whole_dataset(self, |
|
with_render=False, |
|
T_render=None, |
|
render_save_path=None): |
|
""" |
|
predicting the latents given images using the encoder |
|
|
|
Args: |
|
both_flips: include both original and flipped images; no need, it's not an improvement |
|
with_render: whether to also render the images corresponding to that latent |
|
render_save_path: lmdb output for the rendered images |
|
""" |
|
data = self.conf.make_dataset() |
|
if isinstance(data, CelebAlmdb) and data.crop_d2c: |
|
|
|
data.transform = make_transform(self.conf.img_size, |
|
flip_prob=0, |
|
crop_d2c=True) |
|
else: |
|
data.transform = make_transform(self.conf.img_size, flip_prob=0) |
|
|
|
|
|
|
|
loader = self.conf.make_loader( |
|
data, |
|
shuffle=False, |
|
drop_last=False, |
|
batch_size=self.conf.batch_size_eval, |
|
parallel=True, |
|
) |
|
model = self.ema_model |
|
model.eval() |
|
conds = [] |
|
|
|
if with_render: |
|
sampler = self.conf._make_diffusion_conf( |
|
T=T_render or self.conf.T_eval).make_sampler() |
|
|
|
if self.global_rank == 0: |
|
writer = LMDBImageWriter(render_save_path, |
|
format='webp', |
|
quality=100) |
|
else: |
|
writer = nullcontext() |
|
else: |
|
writer = nullcontext() |
|
|
|
with writer: |
|
for batch in tqdm(loader, total=len(loader), desc='infer'): |
|
with torch.no_grad(): |
|
|
|
|
|
cond = model.encoder(batch['img'].to(self.device)) |
|
|
|
|
|
idx = batch['index'] |
|
idx = self.all_gather(idx) |
|
if idx.dim() == 2: |
|
idx = idx.flatten(0, 1) |
|
argsort = idx.argsort() |
|
|
|
if with_render: |
|
noise = torch.randn(len(cond), |
|
3, |
|
self.conf.img_size, |
|
self.conf.img_size, |
|
device=self.device) |
|
render = sampler.sample(model, noise=noise, cond=cond) |
|
render = (render + 1) / 2 |
|
|
|
|
|
render = self.all_gather(render) |
|
if render.dim() == 5: |
|
|
|
render = render.flatten(0, 1) |
|
|
|
|
|
|
|
if self.global_rank == 0: |
|
writer.put_images(render[argsort]) |
|
|
|
|
|
cond = self.all_gather(cond) |
|
|
|
if cond.dim() == 3: |
|
|
|
cond = cond.flatten(0, 1) |
|
|
|
conds.append(cond[argsort].cpu()) |
|
|
|
model.train() |
|
|
|
|
|
conds = torch.cat(conds).float() |
|
return conds |
|
|
|
def training_step(self, batch, batch_idx): |
|
""" |
|
given an input, calculate the loss function |
|
no optimization at this stage. |
|
""" |
|
with amp.autocast(False): |
|
|
|
|
|
if self.conf.train_mode.require_dataset_infer(): |
|
|
|
cond = batch[0] |
|
if self.conf.latent_znormalize: |
|
cond = (cond - self.conds_mean.to( |
|
self.device)) / self.conds_std.to(self.device) |
|
else: |
|
imgs, idxs = batch['img'], batch['index'] |
|
|
|
x_start = imgs |
|
|
|
if self.conf.train_mode == TrainMode.diffusion: |
|
""" |
|
main training mode!!! |
|
""" |
|
|
|
t, weight = self.T_sampler.sample(len(x_start), x_start.device) |
|
losses = self.sampler.training_losses(model=self.model, |
|
x_start=x_start, |
|
t=t) |
|
elif self.conf.train_mode.is_latent_diffusion(): |
|
""" |
|
training the latent variables! |
|
""" |
|
|
|
t, weight = self.T_sampler.sample(len(cond), cond.device) |
|
latent_losses = self.latent_sampler.training_losses( |
|
model=self.model.latent_net, x_start=cond, t=t) |
|
|
|
losses = { |
|
'latent': latent_losses['loss'], |
|
'loss': latent_losses['loss'] |
|
} |
|
else: |
|
raise NotImplementedError() |
|
|
|
loss = losses['loss'].mean() |
|
|
|
for key in ['loss', 'vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']: |
|
if key in losses: |
|
losses[key] = self.all_gather(losses[key]).mean() |
|
|
|
if self.global_rank == 0: |
|
self.logger.experiment.add_scalar('loss', losses['loss'], |
|
self.num_samples) |
|
for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']: |
|
if key in losses: |
|
self.logger.experiment.add_scalar( |
|
f'loss/{key}', losses[key], self.num_samples) |
|
|
|
return {'loss': loss} |
|
|
|
def on_train_batch_end(self, outputs, batch, batch_idx: int, |
|
dataloader_idx: int) -> None: |
|
""" |
|
after each training step ... |
|
""" |
|
if self.is_last_accum(batch_idx): |
|
|
|
|
|
if self.conf.train_mode == TrainMode.latent_diffusion: |
|
|
|
ema(self.model.latent_net, self.ema_model.latent_net, |
|
self.conf.ema_decay) |
|
else: |
|
ema(self.model, self.ema_model, self.conf.ema_decay) |
|
|
|
|
|
if self.conf.train_mode.require_dataset_infer(): |
|
imgs = None |
|
else: |
|
imgs = batch['img'] |
|
self.log_sample(x_start=imgs) |
|
self.evaluate_scores() |
|
|
|
def on_before_optimizer_step(self, optimizer: Optimizer, |
|
optimizer_idx: int) -> None: |
|
|
|
|
|
if self.conf.grad_clip > 0: |
|
|
|
params = [ |
|
p for group in optimizer.param_groups for p in group['params'] |
|
] |
|
|
|
torch.nn.utils.clip_grad_norm_(params, |
|
max_norm=self.conf.grad_clip) |
|
|
|
|
|
def log_sample(self, x_start): |
|
""" |
|
put images to the tensorboard |
|
""" |
|
def do(model, |
|
postfix, |
|
use_xstart, |
|
save_real=False, |
|
no_latent_diff=False, |
|
interpolate=False): |
|
model.eval() |
|
with torch.no_grad(): |
|
all_x_T = self.split_tensor(self.x_T) |
|
batch_size = min(len(all_x_T), self.conf.batch_size_eval) |
|
|
|
loader = DataLoader(all_x_T, batch_size=batch_size) |
|
|
|
Gen = [] |
|
for x_T in loader: |
|
if use_xstart: |
|
_xstart = x_start[:len(x_T)] |
|
else: |
|
_xstart = None |
|
|
|
if self.conf.train_mode.is_latent_diffusion( |
|
) and not use_xstart: |
|
|
|
gen = render_uncondition( |
|
conf=self.conf, |
|
model=model, |
|
x_T=x_T, |
|
sampler=self.eval_sampler, |
|
latent_sampler=self.eval_latent_sampler, |
|
conds_mean=self.conds_mean, |
|
conds_std=self.conds_std) |
|
else: |
|
if not use_xstart and self.conf.model_type.has_noise_to_cond( |
|
): |
|
model: BeatGANsAutoencModel |
|
|
|
cond = torch.randn(len(x_T), |
|
self.conf.style_ch, |
|
device=self.device) |
|
cond = model.noise_to_cond(cond) |
|
else: |
|
if interpolate: |
|
with amp.autocast(self.conf.fp16): |
|
cond = model.encoder(_xstart) |
|
i = torch.randperm(len(cond)) |
|
cond = (cond + cond[i]) / 2 |
|
else: |
|
cond = None |
|
gen = self.eval_sampler.sample(model=model, |
|
noise=x_T, |
|
cond=cond, |
|
x_start=_xstart) |
|
Gen.append(gen) |
|
|
|
gen = torch.cat(Gen) |
|
gen = self.all_gather(gen) |
|
if gen.dim() == 5: |
|
|
|
gen = gen.flatten(0, 1) |
|
|
|
if save_real and use_xstart: |
|
|
|
real = self.all_gather(_xstart) |
|
if real.dim() == 5: |
|
real = real.flatten(0, 1) |
|
|
|
if self.global_rank == 0: |
|
grid_real = (make_grid(real) + 1) / 2 |
|
self.logger.experiment.add_image( |
|
f'sample{postfix}/real', grid_real, |
|
self.num_samples) |
|
|
|
if self.global_rank == 0: |
|
|
|
grid = (make_grid(gen) + 1) / 2 |
|
sample_dir = os.path.join(self.conf.logdir, |
|
f'sample{postfix}') |
|
if not os.path.exists(sample_dir): |
|
os.makedirs(sample_dir) |
|
path = os.path.join(sample_dir, |
|
'%d.png' % self.num_samples) |
|
save_image(grid, path) |
|
self.logger.experiment.add_image(f'sample{postfix}', grid, |
|
self.num_samples) |
|
model.train() |
|
|
|
if self.conf.sample_every_samples > 0 and is_time( |
|
self.num_samples, self.conf.sample_every_samples, |
|
self.conf.batch_size_effective): |
|
|
|
if self.conf.train_mode.require_dataset_infer(): |
|
do(self.model, '', use_xstart=False) |
|
do(self.ema_model, '_ema', use_xstart=False) |
|
else: |
|
if self.conf.model_type.has_autoenc( |
|
) and self.conf.model_type.can_sample(): |
|
do(self.model, '', use_xstart=False) |
|
do(self.ema_model, '_ema', use_xstart=False) |
|
|
|
do(self.model, '_enc', use_xstart=True, save_real=True) |
|
do(self.ema_model, |
|
'_enc_ema', |
|
use_xstart=True, |
|
save_real=True) |
|
elif self.conf.train_mode.use_latent_net(): |
|
do(self.model, '', use_xstart=False) |
|
do(self.ema_model, '_ema', use_xstart=False) |
|
|
|
do(self.model, '_enc', use_xstart=True, save_real=True) |
|
do(self.model, |
|
'_enc_nodiff', |
|
use_xstart=True, |
|
save_real=True, |
|
no_latent_diff=True) |
|
do(self.ema_model, |
|
'_enc_ema', |
|
use_xstart=True, |
|
save_real=True) |
|
else: |
|
do(self.model, '', use_xstart=True, save_real=True) |
|
do(self.ema_model, '_ema', use_xstart=True, save_real=True) |
|
|
|
def evaluate_scores(self): |
|
""" |
|
evaluate FID and other scores during training (put to the tensorboard) |
|
For, FID. It is a fast version with 5k images (gold standard is 50k). |
|
Don't use its results in the paper! |
|
""" |
|
def fid(model, postfix): |
|
score = evaluate_fid(self.eval_sampler, |
|
model, |
|
self.conf, |
|
device=self.device, |
|
train_data=self.train_data, |
|
val_data=self.val_data, |
|
latent_sampler=self.eval_latent_sampler, |
|
conds_mean=self.conds_mean, |
|
conds_std=self.conds_std) |
|
if self.global_rank == 0: |
|
self.logger.experiment.add_scalar(f'FID{postfix}', score, |
|
self.num_samples) |
|
if not os.path.exists(self.conf.logdir): |
|
os.makedirs(self.conf.logdir) |
|
with open(os.path.join(self.conf.logdir, 'eval.txt'), |
|
'a') as f: |
|
metrics = { |
|
f'FID{postfix}': score, |
|
'num_samples': self.num_samples, |
|
} |
|
f.write(json.dumps(metrics) + "\n") |
|
|
|
def lpips(model, postfix): |
|
if self.conf.model_type.has_autoenc( |
|
) and self.conf.train_mode.is_autoenc(): |
|
|
|
score = evaluate_lpips(self.eval_sampler, |
|
model, |
|
self.conf, |
|
device=self.device, |
|
val_data=self.val_data, |
|
latent_sampler=self.eval_latent_sampler) |
|
|
|
if self.global_rank == 0: |
|
for key, val in score.items(): |
|
self.logger.experiment.add_scalar( |
|
f'{key}{postfix}', val, self.num_samples) |
|
|
|
if self.conf.eval_every_samples > 0 and self.num_samples > 0 and is_time( |
|
self.num_samples, self.conf.eval_every_samples, |
|
self.conf.batch_size_effective): |
|
print(f'eval fid @ {self.num_samples}') |
|
lpips(self.model, '') |
|
fid(self.model, '') |
|
|
|
if self.conf.eval_ema_every_samples > 0 and self.num_samples > 0 and is_time( |
|
self.num_samples, self.conf.eval_ema_every_samples, |
|
self.conf.batch_size_effective): |
|
print(f'eval fid ema @ {self.num_samples}') |
|
fid(self.ema_model, '_ema') |
|
|
|
|
|
|
|
def configure_optimizers(self): |
|
out = {} |
|
if self.conf.optimizer == OptimizerType.adam: |
|
optim = torch.optim.Adam(self.model.parameters(), |
|
lr=self.conf.lr, |
|
weight_decay=self.conf.weight_decay) |
|
elif self.conf.optimizer == OptimizerType.adamw: |
|
optim = torch.optim.AdamW(self.model.parameters(), |
|
lr=self.conf.lr, |
|
weight_decay=self.conf.weight_decay) |
|
else: |
|
raise NotImplementedError() |
|
out['optimizer'] = optim |
|
if self.conf.warmup > 0: |
|
sched = torch.optim.lr_scheduler.LambdaLR(optim, |
|
lr_lambda=WarmupLR( |
|
self.conf.warmup)) |
|
out['lr_scheduler'] = { |
|
'scheduler': sched, |
|
'interval': 'step', |
|
} |
|
return out |
|
|
|
def split_tensor(self, x): |
|
""" |
|
extract the tensor for a corresponding "worker" in the batch dimension |
|
|
|
Args: |
|
x: (n, c) |
|
|
|
Returns: x: (n_local, c) |
|
""" |
|
n = len(x) |
|
rank = self.global_rank |
|
world_size = get_world_size() |
|
|
|
per_rank = n // world_size |
|
return x[rank * per_rank:(rank + 1) * per_rank] |
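    # Example (hypothetical numbers): with n = 8 and world_size = 2, rank 0 gets x[0:4] and
    # rank 1 gets x[4:8]; any remainder when n is not divisible by world_size is dropped.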
|
|
|
def test_step(self, batch, *args, **kwargs): |
|
""" |
|
for the "eval" mode. |
|
We first select what to do according to the "conf.eval_programs". |
|
test_step will only run for "one iteration" (it's a hack!). |
|
|
|
We just want the multi-gpu support. |
|
""" |
|
|
|
self.setup() |
|
|
|
|
|
print('global step:', self.global_step) |
|
""" |
|
"infer" = predict the latent variables using the encoder on the whole dataset |
|
""" |
|
        if 'infer' in self.conf.eval_programs:
            print('infer ...')
            conds = self.infer_whole_dataset().float()

            save_path = f'checkpoints/{self.conf.name}/latent.pkl'
|
|
|
if self.global_rank == 0: |
|
conds_mean = conds.mean(dim=0) |
|
conds_std = conds.std(dim=0) |
|
if not os.path.exists(os.path.dirname(save_path)): |
|
os.makedirs(os.path.dirname(save_path)) |
|
torch.save( |
|
{ |
|
'conds': conds, |
|
'conds_mean': conds_mean, |
|
'conds_std': conds_std, |
|
}, save_path) |
|
""" |
|
"infer+render" = predict the latent variables using the encoder on the whole dataset |
|
THIS ALSO GENERATE CORRESPONDING IMAGES |
|
""" |
|
|
|
for each in self.conf.eval_programs: |
|
if each.startswith('infer+render'): |
|
m = re.match(r'infer\+render([0-9]+)', each) |
|
if m is not None: |
|
T = int(m[1]) |
|
self.setup() |
|
print(f'infer + reconstruction T{T} ...') |
|
conds = self.infer_whole_dataset( |
|
with_render=True, |
|
T_render=T, |
|
render_save_path= |
|
f'latent_infer_render{T}/{self.conf.name}.lmdb', |
|
) |
|
save_path = f'latent_infer_render{T}/{self.conf.name}.pkl' |
|
conds_mean = conds.mean(dim=0) |
|
conds_std = conds.std(dim=0) |
|
if not os.path.exists(os.path.dirname(save_path)): |
|
os.makedirs(os.path.dirname(save_path)) |
|
torch.save( |
|
{ |
|
'conds': conds, |
|
'conds_mean': conds_mean, |
|
'conds_std': conds_std, |
|
}, save_path) |
|
|
|
|
|
""" |
|
"fid<T>" = unconditional generation (conf.train_mode = diffusion). |
|
Note: Diff. autoenc will still receive real images in this mode. |
|
"fid<T>,<T_latent>" = unconditional generation for latent models (conf.train_mode = latent_diffusion). |
|
Note: Diff. autoenc will still NOT receive real images in this made. |
|
but you need to make sure that the train_mode is latent_diffusion. |
|
""" |
|
for each in self.conf.eval_programs: |
|
if each.startswith('fid'): |
|
m = re.match(r'fid\(([0-9]+),([0-9]+)\)', each) |
|
clip_latent_noise = False |
|
if m is not None: |
|
|
|
T = int(m[1]) |
|
T_latent = int(m[2]) |
|
print(f'evaluating FID T = {T}... latent T = {T_latent}') |
|
else: |
|
m = re.match(r'fidclip\(([0-9]+),([0-9]+)\)', each) |
|
if m is not None: |
|
|
|
T = int(m[1]) |
|
T_latent = int(m[2]) |
|
clip_latent_noise = True |
|
print( |
|
f'evaluating FID (clip latent noise) T = {T}... latent T = {T_latent}' |
|
) |
|
else: |
|
|
|
_, T = each.split('fid') |
|
T = int(T) |
|
T_latent = None |
|
print(f'evaluating FID T = {T}...') |
|
|
|
self.train_dataloader() |
|
sampler = self.conf._make_diffusion_conf(T=T).make_sampler() |
|
if T_latent is not None: |
|
latent_sampler = self.conf._make_latent_diffusion_conf( |
|
T=T_latent).make_sampler() |
|
else: |
|
latent_sampler = None |
|
|
|
conf = self.conf.clone() |
|
conf.eval_num_images = 50_000 |
|
score = evaluate_fid( |
|
sampler, |
|
self.ema_model, |
|
conf, |
|
device=self.device, |
|
train_data=self.train_data, |
|
val_data=self.val_data, |
|
latent_sampler=latent_sampler, |
|
conds_mean=self.conds_mean, |
|
conds_std=self.conds_std, |
|
remove_cache=False, |
|
clip_latent_noise=clip_latent_noise, |
|
) |
|
if T_latent is None: |
|
self.log(f'fid_ema_T{T}', score) |
|
else: |
|
name = 'fid' |
|
if clip_latent_noise: |
|
name += '_clip' |
|
name += f'_ema_T{T}_Tlatent{T_latent}' |
|
self.log(name, score) |
|
""" |
|
"recon<T>" = reconstruction & autoencoding (without noise inversion) |
|
""" |
|
for each in self.conf.eval_programs: |
|
if each.startswith('recon'): |
|
self.model: BeatGANsAutoencModel |
|
_, T = each.split('recon') |
|
T = int(T) |
|
print(f'evaluating reconstruction T = {T}...') |
|
|
|
sampler = self.conf._make_diffusion_conf(T=T).make_sampler() |
|
|
|
conf = self.conf.clone() |
|
|
|
conf.eval_num_images = len(self.val_data) |
|
|
|
score = evaluate_lpips(sampler, |
|
self.ema_model, |
|
conf, |
|
device=self.device, |
|
val_data=self.val_data, |
|
latent_sampler=None) |
|
for k, v in score.items(): |
|
self.log(f'{k}_ema_T{T}', v) |
|
""" |
|
"inv<T>" = reconstruction with noise inversion |
|
""" |
|
for each in self.conf.eval_programs: |
|
if each.startswith('inv'): |
|
self.model: BeatGANsAutoencModel |
|
_, T = each.split('inv') |
|
T = int(T) |
|
print( |
|
f'evaluating reconstruction with noise inversion T = {T}...' |
|
) |
|
|
|
sampler = self.conf._make_diffusion_conf(T=T).make_sampler() |
|
|
|
conf = self.conf.clone() |
|
|
|
conf.eval_num_images = len(self.val_data) |
|
|
|
score = evaluate_lpips(sampler, |
|
self.ema_model, |
|
conf, |
|
device=self.device, |
|
val_data=self.val_data, |
|
latent_sampler=None, |
|
use_inverted_noise=True) |
|
for k, v in score.items(): |
|
self.log(f'{k}_inv_ema_T{T}', v) |
|
|
|
|
|
def ema(source, target, decay): |
|
source_dict = source.state_dict() |
|
target_dict = target.state_dict() |
|
for key in source_dict.keys(): |
|
target_dict[key].data.copy_(target_dict[key].data * decay + |
|
source_dict[key].data * (1 - decay)) |
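# The update above is the standard exponential moving average: with an assumed
# decay = 0.9999, each call moves every EMA parameter 0.01% of the way toward the online
# model, i.e. target = 0.9999 * target + 0.0001 * source.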
|
|
|
|
|
class WarmupLR: |
|
def __init__(self, warmup) -> None: |
|
self.warmup = warmup |
|
|
|
def __call__(self, step): |
|
return min(step, self.warmup) / self.warmup |
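# Sketch of the resulting schedule (assumed warmup value): with warmup = 5000, the LR
# multiplier is step / 5000 for step < 5000 (so 0.0 at step 0) and stays at 1.0 afterwards;
# it is applied per optimizer step through the LambdaLR scheduler in configure_optimizers.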
|
|
|
|
|
def is_time(num_samples, every, step_size): |
|
closest = (num_samples // every) * every |
|
return num_samples - closest < step_size |
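# Example (hypothetical numbers): with every = 10_000 and step_size = batch_size_effective = 64,
# num_samples advances by 64 per step, so this returns True on exactly one step per 10k-sample
# interval (the first step whose counter is less than 64 past a multiple of 10_000),
# assuming every >= step_size.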
|
|
|
|
|
def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'): |
|
print('conf:', conf.name) |
|
|
|
|
|
model = LitModel(conf) |
|
|
|
if not os.path.exists(conf.logdir): |
|
os.makedirs(conf.logdir) |
|
checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}', |
|
save_last=True, |
|
save_top_k=1, |
|
every_n_train_steps=conf.save_every_samples // |
|
conf.batch_size_effective) |
|
checkpoint_path = f'{conf.logdir}/last.ckpt' |
|
print('ckpt path:', checkpoint_path) |
|
if os.path.exists(checkpoint_path): |
|
resume = checkpoint_path |
|
print('resume!') |
|
else: |
|
if conf.continue_from is not None: |
|
|
|
resume = conf.continue_from.path |
|
else: |
|
resume = None |
|
|
|
tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir, |
|
name=None, |
|
version='') |
|
|
|
|
|
|
|
plugins = [] |
|
if len(gpus) == 1 and nodes == 1: |
|
accelerator = None |
|
else: |
|
accelerator = 'ddp' |
|
from pytorch_lightning.plugins import DDPPlugin |
|
|
|
|
|
plugins.append(DDPPlugin(find_unused_parameters=False)) |
|
|
|
trainer = pl.Trainer( |
|
max_steps=conf.total_samples // conf.batch_size_effective, |
|
resume_from_checkpoint=resume, |
|
gpus=gpus, |
|
num_nodes=nodes, |
|
accelerator=accelerator, |
|
precision=16 if conf.fp16 else 32, |
|
callbacks=[ |
|
checkpoint, |
|
LearningRateMonitor(), |
|
], |
|
|
|
|
|
replace_sampler_ddp=True, |
|
logger=tb_logger, |
|
accumulate_grad_batches=conf.accum_batches, |
|
plugins=plugins, |
|
) |
|
|
|
if mode == 'train': |
|
trainer.fit(model) |
|
elif mode == 'eval': |
|
|
|
|
|
|
|
dummy = DataLoader(TensorDataset(torch.tensor([0.] * conf.batch_size)), |
|
batch_size=conf.batch_size) |
|
eval_path = conf.eval_path or checkpoint_path |
|
|
|
print('loading from:', eval_path) |
|
state = torch.load(eval_path, map_location='cpu') |
|
print('step:', state['global_step']) |
|
model.load_state_dict(state['state_dict']) |
|
|
|
out = trainer.test(model, dataloaders=dummy) |
|
|
|
out = out[0] |
|
print(out) |
|
|
|
if get_rank() == 0: |
|
|
|
for k, v in out.items(): |
|
tb_logger.experiment.add_scalar( |
|
k, v, state['global_step'] * conf.batch_size_effective) |
|
|
|
|
|
|
|
|
|
|
|
tgt = f'evals/{conf.name}.txt' |
|
dirname = os.path.dirname(tgt) |
|
if not os.path.exists(dirname): |
|
os.makedirs(dirname) |
|
with open(tgt, 'a') as f: |
|
f.write(json.dumps(out) + "\n") |
|
|
|
else: |
|
raise NotImplementedError() |
|
|
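# Minimal usage sketch (hypothetical config factory name; the actual TrainConfig presets live
# elsewhere in this repo):
#
#   conf = ffhq128_autoenc_130M()          # assumed helper returning a TrainConfig
#   train(conf, gpus=[0], mode='train')    # train on a single GPU
#
#   conf.eval_programs = ['fid10', 'recon10']   # programs parsed in test_step above
#   train(conf, gpus=[0], mode='eval')          # evaluates from {logdir}/last.ckpt or conf.eval_path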