# COP-GEN-Beta / src / COP-GEN-Beta / train_triffuser_discrete.py
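"""Training script for the discrete-time Triffuser model (COP-GEN-Beta).

Trains a joint multi-modal U-ViT-style noise-prediction network in the latent
space of a pretrained autoencoder, and samples with DPM-Solver. Typical
invocation (paths are illustrative):

    accelerate launch train_triffuser_discrete.py \
        --config=configs/<config>.py --workdir=<workdir>
"""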
import builtins
import os
import tempfile

import accelerate
import einops
import ml_collections
import numpy as np
import torch
import wandb
from absl import logging
from torch import multiprocessing as mp
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm.auto import tqdm

import libs.autoencoder
import utils
from datasets import get_dataset
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from tools.fid_score import calculate_fid_given_paths
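
# Beta schedule used by Stable Diffusion ("scaled linear"): a linspace between
# sqrt(linear_start) and sqrt(linear_end), then squared.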
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
_betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
return _betas.numpy()
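
# Pairwise transition coefficients of q(x_t | x_s) for all 0 <= s <= t <= N:
#   skip_alphas[s, t] = prod_{k=s+1}^{t} alphas[k]
#   skip_betas[s, t]  = sum_{k=s+1}^{t} betas[k] * prod_{j=k+1}^{t} alphas[j]
# so that x_t = sqrt(skip_alphas[s, t]) * x_s + sqrt(skip_betas[s, t]) * eps.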
def get_skip(alphas, betas):
N = len(betas) - 1
skip_alphas = np.ones([N + 1, N + 1], dtype=betas.dtype)
for s in range(N + 1):
skip_alphas[s, s + 1:] = alphas[s + 1:].cumprod()
skip_betas = np.zeros([N + 1, N + 1], dtype=betas.dtype)
for t in range(N + 1):
prod = betas[1: t + 1] * skip_alphas[1: t + 1, t]
skip_betas[:t, t] = (prod[::-1].cumsum())[::-1]
return skip_alphas, skip_betas
def stp(s, ts: torch.Tensor): # scalar tensor product
if isinstance(s, np.ndarray):
s = torch.from_numpy(s).type_as(ts)
extra_dims = (1,) * (ts.dim() - 1)
return s.view(-1, *extra_dims) * ts
def mos(a, start_dim=1): # mean of square
return a.pow(2).flatten(start_dim=start_dim).mean(dim=-1)
class Schedule(object): # discrete time
def __init__(self, _betas):
r""" _betas[0...999] = betas[1...1000]
for n>=1, betas[n] is the variance of q(xn|xn-1)
for n=0, betas[0]=0
"""
self._betas = _betas
self.betas = np.append(0., _betas)
self.alphas = 1. - self.betas
self.N = len(_betas)
assert isinstance(self.betas, np.ndarray) and self.betas[0] == 0
assert isinstance(self.alphas, np.ndarray) and self.alphas[0] == 1
assert len(self.betas) == len(self.alphas)
# skip_alphas[s, t] = alphas[s + 1: t + 1].prod()
self.skip_alphas, self.skip_betas = get_skip(self.alphas, self.betas)
self.cum_alphas = self.skip_alphas[0] # cum_alphas = alphas.cumprod()
self.cum_betas = self.skip_betas[0]
self.snr = self.cum_alphas / self.cum_betas
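
    # Variance of the posterior q(x_s | x_t, x_0) for s < t; generalizes the
    # DDPM tilde-beta from adjacent steps to arbitrary step pairs.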
def tilde_beta(self, s, t):
return self.skip_betas[s, t] * self.cum_betas[s] / self.cum_betas[t]
def sample(self, x0, multi_modal=False): # sample from q(xn|x0), where n is uniform
if multi_modal:
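            # x0 is a list of per-modality latents; each modality gets its own
            # independent timestep n and noise eps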
n_list = []
eps_list = []
xn_list = []
for x0_i in x0:
n = np.random.choice(list(range(1, self.N + 1)), (len(x0_i),))
eps = torch.randn_like(x0_i)
xn = stp(self.cum_alphas[n] ** 0.5, x0_i) + stp(self.cum_betas[n] ** 0.5, eps)
n_list.append(torch.tensor(n, device=x0_i.device))
eps_list.append(eps)
xn_list.append(xn)
return n_list, eps_list, xn_list
else:
n = np.random.choice(list(range(1, self.N + 1)), (len(x0),))
eps = torch.randn_like(x0)
xn = stp(self.cum_alphas[n] ** 0.5, x0) + stp(self.cum_betas[n] ** 0.5, eps)
return torch.tensor(n, device=x0.device), eps, xn
def __repr__(self):
return f'Schedule({self.betas[:10]}..., {self.N})'
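
# The simple noise-prediction MSE objective (L_simple in DDPM). In multi-modal
# mode every modality is noised with its own timestep and noise, the network
# predicts all noises jointly, and the per-modality losses are summed.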
def LSimple(x0, nnet, schedule, multi_modal=False, **kwargs):
if multi_modal:
n_list, eps_list, xn_list = schedule.sample(x0, multi_modal=multi_modal) # n in {1, ..., 1000}
eps_pred = nnet(xn_list, n_list, **kwargs)
        return sum(mos(eps_i - eps_pred_i) for eps_i, eps_pred_i in zip(eps_list, eps_pred))
else:
n, eps, xn = schedule.sample(x0) # n in {1, ..., 1000}
eps_pred = nnet(xn, n, **kwargs)
return mos(eps - eps_pred)
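
# Full training entry point: distributed setup via accelerate, data pipeline,
# network + EMA + optimizer, then the step/log/sample/checkpoint loop.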
def train(config):
if config.get('benchmark', False):
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
mp.set_start_method('spawn')
accelerator = accelerate.Accelerator()
device = accelerator.device
accelerate.utils.set_seed(config.seed, device_specific=True)
logging.info(f'Process {accelerator.process_index} using device: {device}')
config.mixed_precision = accelerator.mixed_precision
config = ml_collections.FrozenConfigDict(config)
assert config.train.batch_size % accelerator.num_processes == 0
mini_batch_size = config.train.batch_size // accelerator.num_processes
if accelerator.is_main_process:
os.makedirs(config.ckpt_root, exist_ok=True)
os.makedirs(config.sample_dir, exist_ok=True)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
wandb.init(dir=os.path.abspath(config.workdir), project=f'uvit_{config.dataset.name}', config=config.to_dict(),
name=config.hparams, job_type='train', mode='offline')
utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
logging.info(config)
else:
utils.set_logger(log_level='error')
        builtins.print = lambda *args, **kwargs: None  # silence print on non-main processes (accept kwargs such as flush)
logging.info(f'Run on {accelerator.num_processes} devices')
dataset = get_dataset(**config.dataset)
assert os.path.exists(dataset.fid_stat)
train_dataset = dataset.get_split(split='train', labeled=config.train.mode == 'cond')
train_dataset_loader = DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, drop_last=True,
num_workers=8, pin_memory=True, persistent_workers=True)
train_state = utils.initialize_train_state(config, device)
nnet, nnet_ema, optimizer, train_dataset_loader = accelerator.prepare(
train_state.nnet, train_state.nnet_ema, train_state.optimizer, train_dataset_loader)
lr_scheduler = train_state.lr_scheduler
train_state.resume(config.ckpt_root)
autoencoder = libs.autoencoder.get_model(config.autoencoder.pretrained_path)
autoencoder.to(device)
    @torch.cuda.amp.autocast()
def encode(_batch):
return autoencoder.encode(_batch)
    @torch.cuda.amp.autocast()
def decode(_batch):
return autoencoder.decode(_batch)
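
    # Infinite iterator over the DataLoader: training is driven by the global
    # step count rather than epochs, so the loader is simply cycled.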
def get_data_generator():
while True:
for data in tqdm(train_dataset_loader, disable=not accelerator.is_main_process, desc='epoch'):
yield data
data_generator = get_data_generator()
_betas = stable_diffusion_beta_schedule()
_schedule = Schedule(_betas)
logging.info(f'use {_schedule}')
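
    # One optimization step: map the batch to latents ('feature' datasets are
    # sampled from stored posteriors via autoencoder.sample, raw images are
    # encoded on the fly), compute LSimple, backprop, update the EMA weights.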
def train_step(_batch):
_metrics = dict()
optimizer.zero_grad()
if config.train.mode == 'uncond': # Multi-modal data. Sample each modality independently
if config.train.multi_modal:
                _zs = [autoencoder.sample(modality) if 'feature' in config.dataset.name else encode(modality)
                       for modality in _batch]
loss = LSimple(_zs, nnet, _schedule, multi_modal=config.train.multi_modal)
else:
_z = autoencoder.sample(_batch) if 'feature' in config.dataset.name else encode(_batch)
loss = LSimple(_z, nnet, _schedule)
elif config.train.mode == 'cond':
_z = autoencoder.sample(_batch[0]) if 'feature' in config.dataset.name else encode(_batch[0])
loss = LSimple(_z, nnet, _schedule, y=_batch[1])
else:
raise NotImplementedError(config.train.mode)
_metrics['loss'] = accelerator.gather(loss.detach()).mean()
accelerator.backward(loss.mean())
optimizer.step()
lr_scheduler.step()
train_state.ema_update(config.get('ema_rate', 0.9999))
train_state.step += 1
return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)
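
    # Single-modality DPM-Solver sampling: the EMA network is wrapped as a
    # continuous-time noise predictor by rescaling t in (0, 1] to the N
    # discrete steps.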
def dpm_solver_sample(_n_samples, _sample_steps, **kwargs):
_z_init = torch.randn(_n_samples, *config.z_shape, device=device)
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
def model_fn(x, t_continuous):
t = t_continuous * _schedule.N
eps_pre = nnet_ema(x, t, **kwargs)
return eps_pre
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
_z = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
return decode(_z)
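
    # Joint sampling treats all modalities as one variable: per-modality
    # latents are flattened and concatenated (combine_joint), and split back
    # into (B, C, H, W) latents afterwards (split_joint).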
def combine_joint(z):
z = torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z], dim=-1)
return z
def split_joint(x, n_modalities):
C, H, W = config.z_shape
z_dim = C * H * W
z = x.split([z_dim] * n_modalities, dim=1)
z = [einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W) for z_i in z]
return z
def dpm_solver_sample_multi_modal(_n_modalities, _n_samples, _sample_steps, **kwargs):
"""here"""
_z_init = torch.randn(_n_modalities, _n_samples, *config.z_shape, device=device)
_z_init = combine_joint(_z_init)
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
def model_fn(x, t_continuous):
t = t_continuous * _schedule.N
timesteps = [t] * _n_modalities
z = split_joint(x, _n_modalities)
z_out = nnet_ema(z, t_imgs=timesteps)
x_out = combine_joint(z_out)
# eps_pre = nnet_ema(x, t, **kwargs)
return x_out
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
_zs = dpm_solver.sample(_z_init, steps=_sample_steps, eps=1. / _schedule.N, T=1.)
_zs = split_joint(_zs, _n_modalities)
samples_unstacked = [decode(_z) for _z in _zs]
return samples_unstacked
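
    # Sample n_samples images and compute FID against the dataset's
    # precomputed statistics; only the main process computes the score, which
    # is then shared across processes via a sum-reduce.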
def eval_step(n_samples, sample_steps):
        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, '
                     f'mini_batch_size={config.sample.mini_batch_size}')
def sample_fn(_n_samples):
if config.train.mode == 'uncond':
kwargs = dict()
elif config.train.mode == 'cond':
kwargs = dict(y=dataset.sample_label(_n_samples, device=device))
else:
raise NotImplementedError
return dpm_solver_sample(_n_samples, sample_steps, **kwargs)
with tempfile.TemporaryDirectory() as temp_path:
path = config.sample.path or temp_path
if accelerator.is_main_process:
os.makedirs(path, exist_ok=True)
utils.sample2dir(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess)
_fid = 0
if accelerator.is_main_process:
_fid = calculate_fid_given_paths((dataset.fid_stat, path))
logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)
_fid = torch.tensor(_fid, device=device)
_fid = accelerator.reduce(_fid, reduction='sum')
return _fid.item()
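
    # Main loop: train step, periodic logging and sample grids, periodic
    # checkpointing (per-checkpoint FID evaluation is currently disabled).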
logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')
step_fid = []
while train_state.step < config.train.n_steps:
nnet.train()
batch = tree_map(lambda x: x.to(device), next(data_generator))
metrics = train_step(batch)
nnet.eval()
if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
logging.info(config.workdir)
wandb.log(metrics, step=train_state.step)
if accelerator.is_main_process and train_state.step % config.train.eval_interval == 0:
torch.cuda.empty_cache()
logging.info('Save a grid of images...')
if config.train.mode == 'uncond':
if config.train.multi_modal:
samples = dpm_solver_sample_multi_modal(_n_modalities=config.nnet.num_modalities, _n_samples=5 * 10, _sample_steps=50)
else:
samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50)
elif config.train.mode == 'cond':
y = einops.repeat(torch.arange(5, device=device) % dataset.K, 'nrow -> (nrow ncol)', ncol=10)
samples = dpm_solver_sample(_n_samples=5 * 10, _sample_steps=50, y=y)
else:
raise NotImplementedError
if config.train.multi_modal:
samples = torch.stack([dataset.unpreprocess(sample) for sample in samples], dim=0) # stack instead of cat
b = samples.shape[1] # batch size
# Properly interleave samples from all modalities
# For each sample index, get all modalities before moving to next sample
samples = torch.stack([samples[j, i] for i in range(b) for j in range(config.nnet.num_modalities)]).view(-1, *samples.shape[2:])
# If the number of modalities is 3 then we plot in 9 columns
n_cols = 9 if config.nnet.num_modalities == 3 else 10
samples = make_grid(samples, n_cols)
else:
samples = make_grid(dataset.unpreprocess(samples), 10)
save_image(samples, os.path.join(config.sample_dir, f'{train_state.step}.png'))
wandb.log({'samples': wandb.Image(samples)}, step=train_state.step)
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
torch.cuda.empty_cache()
logging.info(f'Save and eval checkpoint {train_state.step}...')
if accelerator.is_main_process:
try:
train_state.save(os.path.join(config.ckpt_root, f'{train_state.step}.ckpt'))
except Exception as e:
logging.error(f" ==> Failed to save checkpoint: {e}!!!")
accelerator.wait_for_everyone()
# TODO: Skip FID for now
# fid = eval_step(n_samples=10000, sample_steps=50) # calculate fid of the saved checkpoint
# step_fid.append((train_state.step, fid))
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')
    if step_fid:  # only pick a best checkpoint if FID was evaluated during training
        step_best = sorted(step_fid, key=lambda x: x[1])[0][0]
        logging.info(f'step_best: {step_best}')
        train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
del metrics
accelerator.wait_for_everyone()
eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)
from absl import flags
from absl import app
from ml_collections import config_flags
import sys
from pathlib import Path
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
"config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
flags.DEFINE_string("workdir", None, "Work unit directory.")
def get_config_name():
argv = sys.argv
for i in range(1, len(argv)):
if argv[i].startswith('--config='):
return Path(argv[i].split('=')[-1]).stem
def get_config_path():
argv = sys.argv
for i in range(1, len(argv)):
if argv[i].startswith('--config='):
path = argv[i].split('=')[-1]
if path.startswith('configs/'):
path = path[len('configs/'):]
return path
def get_hparams():
argv = sys.argv
lst = []
for i in range(1, len(argv)):
assert '=' in argv[i]
if argv[i].startswith('--config.') and not argv[i].startswith('--config.dataset.path'):
hparam, val = argv[i].split('=')
hparam = hparam.split('.')[-1]
if hparam.endswith('path'):
val = Path(val).stem
lst.append(f'{hparam}={val}')
hparams = '-'.join(lst)
if hparams == '':
hparams = 'default'
return hparams
def main(argv):
config = FLAGS.config
# config.config_name = get_config_name()
    config.config_name = get_config_path().removesuffix('.py')  # strip('.py') would remove characters, not the suffix
config.hparams = get_hparams()
config.workdir = FLAGS.workdir or os.path.join('workdir', config.config_name, config.hparams)
config.ckpt_root = os.path.join(config.workdir, 'ckpts')
config.sample_dir = os.path.join(config.workdir, 'samples')
train(config)
if __name__ == "__main__":
app.run(main)