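#!/usr/bin/env python3
# Spaces: Runtime error
# (The line above is a stray page banner preserved from the captured copy of
# this file; it is not part of the program.)
"""Trains Karras et al. (2022) diffusion models."""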
import argparse
from copy import deepcopy
from functools import partial
import importlib.util
import math
import json
from pathlib import Path
import time

import accelerate
import safetensors.torch as safetorch
import torch
import torch._dynamo
from torch import distributed as dist
from torch import multiprocessing as mp
from torch import optim
from torch.utils import data, flop_counter
from torchvision import datasets, transforms, utils
from tqdm.auto import tqdm

import k_diffusion as K
def ensure_distributed():
    """Make sure a torch.distributed process group exists.

    If no group has been initialized yet, a single-process group
    (world_size=1, rank=0) backed by an in-memory ``HashStore`` is created.
    An already-initialized group is left untouched.
    """
    if dist.is_initialized():
        return
    dist.init_process_group(world_size=1, rank=0, store=dist.HashStore())
def main():
    """Parse the command line, build everything from the config, and train.

    The function loads the configuration file named by ``--config``, builds
    the model, its EMA copy, the optimizer, the LR/EMA schedules and the
    training data loader, optionally resumes from a checkpoint, and then runs
    the training loop until ``--end-step`` is reached or the user interrupts
    it. Demo grids, FID/KID evaluation and checkpoints are produced at the
    intervals given on the command line. With ``--evaluate-only`` a single
    evaluation is run instead of training.

    Raises:
        ValueError: for an unknown optimizer, schedule, dataset or feature
            extractor type, or when ``--evaluate-only`` is requested while
            evaluation is disabled.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', type=int, default=64,
                   help='the batch size')
    p.add_argument('--checkpointing', action='store_true',
                   help='enable gradient checkpointing')
    p.add_argument('--clip-model', type=str, default='ViT-B/16',
                   choices=K.evaluation.CLIPFeatureExtractor.available_models(),
                   help='the CLIP model to use to evaluate')
    p.add_argument('--compile', action='store_true',
                   help='compile the model')
    p.add_argument('--config', type=str, required=True,
                   help='the configuration file')
    p.add_argument('--demo-every', type=int, default=500,
                   help='save a demo grid every this many steps')
    p.add_argument('--dinov2-model', type=str, default='vitl14',
                   choices=K.evaluation.DINOv2FeatureExtractor.available_models(),
                   help='the DINOv2 model to use to evaluate')
    p.add_argument('--end-step', type=int, default=None,
                   help='the step to end training at')
    p.add_argument('--evaluate-every', type=int, default=10000,
                   help='evaluate every this many steps')
    p.add_argument('--evaluate-n', type=int, default=2000,
                   help='the number of samples to draw to evaluate')
    p.add_argument('--evaluate-only', action='store_true',
                   help='evaluate instead of training')
    p.add_argument('--evaluate-with', type=str, default='inception',
                   choices=['inception', 'clip', 'dinov2'],
                   help='the feature extractor to use for evaluation')
    p.add_argument('--gns', action='store_true',
                   help='measure the gradient noise scale (DDP only, disables stratified sampling)')
    p.add_argument('--grad-accum-steps', type=int, default=1,
                   help='the number of gradient accumulation steps')
    p.add_argument('--lr', type=float,
                   help='the learning rate')
    p.add_argument('--mixed-precision', type=str,
                   help='the mixed precision type')
    p.add_argument('--name', type=str, default='model',
                   help='the name of the run')
    p.add_argument('--num-workers', type=int, default=8,
                   help='the number of data loader workers')
    p.add_argument('--reset-ema', action='store_true',
                   help='reset the EMA')
    p.add_argument('--resume', type=str,
                   help='the checkpoint to resume from')
    p.add_argument('--resume-inference', type=str,
                   help='the inference checkpoint to resume from')
    p.add_argument('--sample-n', type=int, default=64,
                   help='the number of images to sample for demo grids')
    p.add_argument('--save-every', type=int, default=10000,
                   help='save every this many steps')
    p.add_argument('--seed', type=int,
                   help='the random seed')
    p.add_argument('--start-method', type=str, default='spawn',
                   choices=['fork', 'forkserver', 'spawn'],
                   help='the multiprocessing start method')
    p.add_argument('--wandb-entity', type=str,
                   help='the wandb entity name')
    p.add_argument('--wandb-group', type=str,
                   help='the wandb group name')
    p.add_argument('--wandb-project', type=str,
                   help='the wandb project name (specify this to enable wandb)')
    p.add_argument('--wandb-save-model', action='store_true',
                   help='save model to wandb')
    args = p.parse_args()

    mp.set_start_method(args.start_method)
    torch.backends.cuda.matmul.allow_tf32 = True
    try:
        torch._dynamo.config.automatic_dynamic_shapes = False
    except AttributeError:
        # The option does not exist in this torch build; nothing to disable.
        pass

    config = K.config.load_config(args.config)
    model_config = config['model']
    dataset_config = config['dataset']
    opt_config = config['optimizer']
    sched_config = config['lr_sched']
    ema_sched_config = config['ema_sched']

    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.grad_accum_steps, mixed_precision=args.mixed_precision)
    # demo() calls dist.broadcast directly, so a process group must exist
    # even in a single-process run.
    ensure_distributed()
    device = accelerator.device
    unwrap = accelerator.unwrap_model
    print(f'Process {accelerator.process_index} using device: {device}', flush=True)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        print(f'World size: {accelerator.num_processes}', flush=True)
        print(f'Batch size: {args.batch_size * accelerator.num_processes}', flush=True)

    if args.seed is not None:
        # Derive one seed per process from the user-supplied seed.
        seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed))
        torch.manual_seed(seeds[accelerator.process_index])
    # Separate generator for demo grids; its state is saved in checkpoints.
    demo_gen = torch.Generator().manual_seed(torch.randint(-2 ** 63, 2 ** 63 - 1, ()).item())
    elapsed = 0.0

    inner_model = K.config.make_model(config)
    inner_model_ema = deepcopy(inner_model)

    if args.compile:
        inner_model.compile()
        # inner_model_ema.compile()

    if accelerator.is_main_process:
        print(f'Parameters: {K.utils.n_params(inner_model):,}')

    # If logging to wandb, initialize the run
    use_wandb = accelerator.is_main_process and args.wandb_project
    if use_wandb:
        import wandb
        # Copy the namespace dict: vars(args) returns args.__dict__ itself, so
        # writing 'config' into it would overwrite args.config (the config
        # file path), which the 'custom' dataset branch below still reads.
        log_config = dict(vars(args))
        log_config['config'] = config
        log_config['parameters'] = K.utils.n_params(inner_model)
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True)

    # The command-line learning rate, when given, overrides the config value.
    lr = opt_config['lr'] if args.lr is None else args.lr
    groups = inner_model.param_groups(lr)
    if opt_config['type'] == 'adamw':
        opt = optim.AdamW(groups,
                          lr=lr,
                          betas=tuple(opt_config['betas']),
                          eps=opt_config['eps'],
                          weight_decay=opt_config['weight_decay'])
    elif opt_config['type'] == 'adam8bit':
        import bitsandbytes as bnb
        opt = bnb.optim.Adam8bit(groups,
                                 lr=lr,
                                 betas=tuple(opt_config['betas']),
                                 eps=opt_config['eps'],
                                 weight_decay=opt_config['weight_decay'])
    elif opt_config['type'] == 'sgd':
        opt = optim.SGD(groups,
                        lr=lr,
                        momentum=opt_config.get('momentum', 0.),
                        nesterov=opt_config.get('nesterov', False),
                        weight_decay=opt_config.get('weight_decay', 0.))
    else:
        raise ValueError('Invalid optimizer type')

    if sched_config['type'] == 'inverse':
        sched = K.utils.InverseLR(opt,
                                  inv_gamma=sched_config['inv_gamma'],
                                  power=sched_config['power'],
                                  warmup=sched_config['warmup'])
    elif sched_config['type'] == 'exponential':
        sched = K.utils.ExponentialLR(opt,
                                      num_steps=sched_config['num_steps'],
                                      decay=sched_config['decay'],
                                      warmup=sched_config['warmup'])
    elif sched_config['type'] == 'constant':
        sched = K.utils.ConstantLRWithWarmup(opt, warmup=sched_config['warmup'])
    else:
        raise ValueError('Invalid schedule type')

    assert ema_sched_config['type'] == 'inverse'
    ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'],
                                  max_value=ema_sched_config['max_value'])
    ema_stats = {}

    tf = transforms.Compose([
        transforms.Resize(size[0], interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size[0]),
        K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob'], disable_all=model_config['augment_prob'] == 0),
    ])

    if dataset_config['type'] == 'imagefolder':
        train_set = K.utils.FolderOfImages(dataset_config['location'], transform=tf)
    elif dataset_config['type'] == 'imagefolder-class':
        train_set = datasets.ImageFolder(dataset_config['location'], transform=tf)
    elif dataset_config['type'] == 'cifar10':
        train_set = datasets.CIFAR10(dataset_config['location'], train=True, download=True, transform=tf)
    elif dataset_config['type'] == 'mnist':
        train_set = datasets.MNIST(dataset_config['location'], train=True, download=True, transform=tf)
    elif dataset_config['type'] == 'huggingface':
        from datasets import load_dataset
        train_set = load_dataset(dataset_config['location'])
        train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_config['image_key']))
        train_set = train_set['train']
    elif dataset_config['type'] == 'custom':
        # The dataset module path is resolved relative to the config file.
        location = (Path(args.config).parent / dataset_config['location']).resolve()
        spec = importlib.util.spec_from_file_location('custom_dataset', location)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        get_dataset = getattr(module, dataset_config.get('get_dataset', 'get_dataset'))
        custom_dataset_config = dataset_config.get('config', {})
        train_set = get_dataset(custom_dataset_config, transform=tf)
    else:
        raise ValueError('Invalid dataset type')

    if accelerator.is_main_process:
        try:
            print(f'Number of items in dataset: {len(train_set):,}')
        except TypeError:
            # The dataset does not define a length; skip the report.
            pass

    image_key = dataset_config.get('image_key', 0)
    num_classes = dataset_config.get('num_classes', 0)
    cond_dropout_rate = dataset_config.get('cond_dropout_rate', 0.1)
    class_key = dataset_config.get('class_key', 1)

    train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True,
                               num_workers=args.num_workers, persistent_workers=True, pin_memory=True)

    inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl)

    # Run one dummy forward pass under the FLOP counter and report the cost.
    with torch.no_grad(), K.models.flops.flop_counter() as fc:
        x = torch.zeros([1, model_config['input_channels'], size[0], size[1]], device=device)
        sigma = torch.ones([1], device=device)
        extra_args = {}
        if getattr(unwrap(inner_model), "num_classes", 0):
            extra_args['class_cond'] = torch.zeros([1], dtype=torch.long, device=device)
        inner_model(x, sigma, **extra_args)
        if accelerator.is_main_process:
            print(f"Forward pass GFLOPs: {fc.flops / 1_000_000_000:,.3f}", flush=True)

    if use_wandb:
        wandb.watch(inner_model)
    if accelerator.num_processes == 1:
        # Gradient noise scale measurement is switched off for one process.
        args.gns = False
    if args.gns:
        gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model)
        gns_stats = K.gns.GradientNoiseScale()
    else:
        gns_stats = None
    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']
    sample_density = K.config.make_sample_density(model_config)

    model = K.config.make_denoiser_wrapper(config)(inner_model)
    model_ema = K.config.make_denoiser_wrapper(config)(inner_model_ema)

    state_path = Path(f'{args.name}_state.json')

    if state_path.exists() or args.resume:
        if args.resume:
            ckpt_path = args.resume
        else:
            # No explicit checkpoint: follow the pointer written by save().
            with open(state_path) as state_file:
                state = json.load(state_file)
            ckpt_path = state['latest_checkpoint']
        if accelerator.is_main_process:
            print(f'Resuming from {ckpt_path}...')
        # NOTE(security): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        unwrap(model.inner_model).load_state_dict(ckpt['model'])
        unwrap(model_ema.inner_model).load_state_dict(ckpt['model_ema'])
        opt.load_state_dict(ckpt['opt'])
        sched.load_state_dict(ckpt['sched'])
        ema_sched.load_state_dict(ckpt['ema_sched'])
        ema_stats = ckpt.get('ema_stats', ema_stats)
        epoch = ckpt['epoch'] + 1
        step = ckpt['step'] + 1
        if args.gns and ckpt.get('gns_stats', None) is not None:
            gns_stats.load_state_dict(ckpt['gns_stats'])
        demo_gen.set_state(ckpt['demo_gen'])
        elapsed = ckpt.get('elapsed', 0.0)

        del ckpt
    else:
        epoch = 0
        step = 0

    if args.reset_ema:
        # Copy the EMA weights into the trained model and restart the EMA
        # warmup schedule and running statistics.
        unwrap(model.inner_model).load_state_dict(unwrap(model_ema.inner_model).state_dict())
        ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'],
                                      max_value=ema_sched_config['max_value'])
        ema_stats = {}

    if args.resume_inference:
        if accelerator.is_main_process:
            print(f'Loading {args.resume_inference}...')
        ckpt = safetorch.load_file(args.resume_inference)
        unwrap(model.inner_model).load_state_dict(ckpt)
        unwrap(model_ema.inner_model).load_state_dict(ckpt)
        del ckpt

    evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0
    metrics_log = None
    if evaluate_enabled:
        if args.evaluate_with == 'inception':
            extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
        elif args.evaluate_with == 'clip':
            extractor = K.evaluation.CLIPFeatureExtractor(args.clip_model, device=device)
        elif args.evaluate_with == 'dinov2':
            extractor = K.evaluation.DINOv2FeatureExtractor(args.dinov2_model, device=device)
        else:
            raise ValueError('Invalid evaluation feature extractor')
        train_iter = iter(train_dl)
        if accelerator.is_main_process:
            print('Computing features for reals...')
        # Element 1 of the per-image tuple is the one the training loop
        # discards (reals, _, aug_cond) -- presumably the un-augmented image;
        # confirm against KarrasAugmentationPipeline.
        reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size)
        if accelerator.is_main_process and not args.evaluate_only:
            metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'time', 'loss', 'fid', 'kid'])
        del train_iter

    cfg_scale = 1.

    def make_cfg_model_fn(model):
        """Wrap ``model`` for classifier-free guidance when cfg_scale != 1."""
        def cfg_model_fn(x, sigma, class_cond):
            # Evaluate the unconditional and conditional branches in one batch.
            x_in = torch.cat([x, x])
            sigma_in = torch.cat([sigma, sigma])
            # The index num_classes serves as the "no class" label, matching
            # the conditioning dropout in the training loop.
            class_uncond = torch.full_like(class_cond, num_classes)
            class_cond_in = torch.cat([class_uncond, class_cond])
            out = model(x_in, sigma_in, class_cond=class_cond_in)
            out_uncond, out_cond = out.chunk(2)
            return out_uncond + (out_cond - out_uncond) * cfg_scale
        if cfg_scale != 1:
            return cfg_model_fn
        return model

    # Sampling must not record an autograd graph across the sampler steps,
    # and the EMA model should be in evaluation mode while sampling.
    @torch.no_grad()
    @K.utils.eval_mode(model_ema)
    def demo():
        """Sample a grid of images from the EMA model and save it as a PNG."""
        if accelerator.is_main_process:
            tqdm.write('Sampling...')
        filename = f'{args.name}_demo_{step:08}.png'
        n_per_proc = math.ceil(args.sample_n / accelerator.num_processes)
        # Noise is drawn for every process and then broadcast from rank 0 so
        # all ranks agree on it; each rank keeps only its own slice.
        x = torch.randn([accelerator.num_processes, n_per_proc, model_config['input_channels'], size[0], size[1]], generator=demo_gen).to(device)
        dist.broadcast(x, 0)
        x = x[accelerator.process_index] * sigma_max
        model_fn, extra_args = model_ema, {}
        if num_classes:
            class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device)
            dist.broadcast(class_cond, 0)
            extra_args['class_cond'] = class_cond[accelerator.process_index]
            model_fn = make_cfg_model_fn(model_ema)
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
        x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=not accelerator.is_main_process)
        # n_per_proc is rounded up, so trim the gathered batch to sample_n.
        x_0 = accelerator.gather(x_0)[:args.sample_n]
        if accelerator.is_main_process:
            grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0)
            K.utils.to_pil_image(grid).save(filename)
            if use_wandb:
                wandb.log({'demo_grid': wandb.Image(filename)}, step=step)

    @torch.no_grad()
    @K.utils.eval_mode(model_ema)
    def evaluate():
        """Compute FID and KID of EMA-model samples against the real features."""
        if not evaluate_enabled:
            return
        if accelerator.is_main_process:
            tqdm.write('Evaluating...')
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)

        def sample_fn(n):
            x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
            model_fn, extra_args = model_ema, {}
            if num_classes:
                extra_args['class_cond'] = torch.randint(0, num_classes, [n], device=device)
                model_fn = make_cfg_model_fn(model_ema)
            x_0 = K.sampling.sample_dpmpp_2m_sde(model_fn, x, sigmas, extra_args=extra_args, eta=0.0, solver_type='heun', disable=True)
            return x_0

        fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size)
        if accelerator.is_main_process:
            fid = K.evaluation.fid(fakes_features, reals_features)
            kid = K.evaluation.kid(fakes_features, reals_features)
            print(f'FID: {fid.item():g}, KID: {kid.item():g}')
            if accelerator.is_main_process and metrics_log is not None:
                metrics_log.write(step, elapsed, ema_stats['loss'], fid.item(), kid.item())
            if use_wandb:
                wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step)

    def save():
        """Write a full training checkpoint and update the state pointer file."""
        accelerator.wait_for_everyone()
        filename = f'{args.name}_{step:08}.pth'
        if accelerator.is_main_process:
            tqdm.write(f'Saving to {filename}...')
        raw_model = unwrap(model.inner_model)
        raw_model_ema = unwrap(model_ema.inner_model)
        obj = {
            'config': config,
            'model': raw_model.state_dict(),
            'model_ema': raw_model_ema.state_dict(),
            'opt': opt.state_dict(),
            'sched': sched.state_dict(),
            'ema_sched': ema_sched.state_dict(),
            'epoch': epoch,
            'step': step,
            'gns_stats': gns_stats.state_dict() if gns_stats is not None else None,
            'ema_stats': ema_stats,
            'demo_gen': demo_gen.get_state(),
            'elapsed': elapsed,
        }
        accelerator.save(obj, filename)
        if accelerator.is_main_process:
            state_obj = {'latest_checkpoint': filename}
            # Close the file deterministically so the pointer is on disk
            # before anything tries to resume from it.
            with open(state_path, 'w') as state_file:
                json.dump(state_obj, state_file)
        if args.wandb_save_model and use_wandb:
            wandb.save(filename)

    if args.evaluate_only:
        if not evaluate_enabled:
            raise ValueError('--evaluate-only requested but evaluation is disabled')
        evaluate()
        return

    losses_since_last_print = []

    try:
        while True:
            for batch in tqdm(train_dl, smoothing=0.1, disable=not accelerator.is_main_process):
                # Time the step: CUDA events on GPU, wall clock otherwise.
                if device.type == 'cuda':
                    start_timer = torch.cuda.Event(enable_timing=True)
                    end_timer = torch.cuda.Event(enable_timing=True)
                    torch.cuda.synchronize()
                    start_timer.record()
                else:
                    start_timer = time.time()

                with accelerator.accumulate(model):
                    reals, _, aug_cond = batch[image_key]
                    class_cond, extra_args = None, {}
                    if num_classes:
                        class_cond = batch[class_key]
                        # Conditioning dropout: replace a random subset of
                        # labels with the extra index num_classes.
                        drop = torch.rand(class_cond.shape, device=class_cond.device)
                        class_cond.masked_fill_(drop < cond_dropout_rate, num_classes)
                        extra_args['class_cond'] = class_cond
                    noise = torch.randn_like(reals)
                    with K.utils.enable_stratified_accelerate(accelerator, disable=args.gns):
                        sigma = sample_density([reals.shape[0]], device=device)
                    with K.models.checkpointing(args.checkpointing):
                        losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)
                    # Mean loss across all processes, for logging only.
                    loss = accelerator.gather(losses).mean().item()
                    losses_since_last_print.append(loss)
                    accelerator.backward(losses.mean())
                    if args.gns:
                        sq_norm_small_batch, sq_norm_large_batch = gns_stats_hook.get_stats()
                        gns_stats.update(sq_norm_small_batch, sq_norm_large_batch, reals.shape[0], reals.shape[0] * accelerator.num_processes)
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), 1.)
                    opt.step()
                    sched.step()
                    opt.zero_grad()

                    ema_decay = ema_sched.get_value()
                    # The loss average is updated every micro-batch, so its
                    # decay is rescaled by the accumulation step count.
                    K.utils.ema_update_dict(ema_stats, {'loss': loss}, ema_decay ** (1 / args.grad_accum_steps))
                    if accelerator.sync_gradients:
                        K.utils.ema_update(model, model_ema, ema_decay)
                        ema_sched.step()

                if device.type == 'cuda':
                    end_timer.record()
                    torch.cuda.synchronize()
                    # elapsed_time() is in milliseconds.
                    elapsed += start_timer.elapsed_time(end_timer) / 1000
                else:
                    elapsed += time.time() - start_timer

                if step % 25 == 0:
                    loss_disp = sum(losses_since_last_print) / len(losses_since_last_print)
                    losses_since_last_print.clear()
                    avg_loss = ema_stats['loss']
                    if accelerator.is_main_process:
                        if args.gns:
                            tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss_disp:g}, avg loss: {avg_loss:g}, gns: {gns_stats.get_gns():g}')
                        else:
                            tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss_disp:g}, avg loss: {avg_loss:g}')

                if use_wandb:
                    log_dict = {
                        'epoch': epoch,
                        'loss': loss,
                        'lr': sched.get_last_lr()[0],
                        'ema_decay': ema_decay,
                    }
                    if args.gns:
                        log_dict['gradient_noise_scale'] = gns_stats.get_gns()
                    wandb.log(log_dict, step=step)

                step += 1

                # A non-positive interval disables the periodic action
                # instead of raising ZeroDivisionError on the modulo.
                if args.demo_every > 0 and step % args.demo_every == 0:
                    demo()

                if evaluate_enabled and step > 0 and step % args.evaluate_every == 0:
                    evaluate()

                if step == args.end_step or (step > 0 and args.save_every > 0 and step % args.save_every == 0):
                    save()

                if step == args.end_step:
                    if accelerator.is_main_process:
                        tqdm.write('Done!')
                    return

            epoch += 1
    except KeyboardInterrupt:
        # Ctrl-C ends training quietly; no checkpoint is written here.
        pass
# Script entry point: start training only when this file is run directly.
if __name__ == "__main__":
    main()