import argparse
import math
import random
import os

import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm

try:
    import wandb

except ImportError:
    # wandb is optional; logging to it is skipped when it is not installed.
    wandb = None

from model import Generator, Discriminator
from dataset import MultiResolutionDataset
from distributed import (
    get_rank,
    synchronize,
    reduce_loss_dict,
    reduce_sum,
    get_world_size,
)


def data_sampler(dataset, shuffle, distributed):
    """Return the sampler to use for ``dataset``.

    Args:
        dataset: Dataset the sampler will index.
        shuffle: Whether samples should be drawn in random order.
        distributed: If true, a ``DistributedSampler`` is returned (and
            ``shuffle`` is forwarded to it); otherwise a ``RandomSampler`` or
            ``SequentialSampler`` is chosen from ``shuffle``.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    return data.SequentialSampler(dataset)


def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = flag


def accumulate(model1, model2, decay=0.999):
    """Update ``model1``'s parameters in place as a moving average of ``model2``.

    For every named parameter of ``model1``:
    ``p1 = decay * p1 + (1 - decay) * p2``. With ``decay=0`` this copies
    ``model2``'s parameters into ``model1``.

    Args:
        model1: Model whose parameters are overwritten.
        model2: Model supplying the new values; must expose the same
            parameter names as ``model1``.
        decay: Weight kept from the existing ``model1`` value.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1:
        # Use the ``alpha=`` keyword: the positional ``add_(scalar, tensor)``
        # overload is deprecated in PyTorch.
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)


def sample_data(loader):
    """Yield batches from ``loader`` forever, restarting it when exhausted.

    When the loader's sampler is a ``DistributedSampler``, ``set_epoch`` is
    called before each pass; without it the sampler reuses the same shuffle
    order on every pass.
    """
    epoch = 0

    while True:
        sampler = getattr(loader, "sampler", None)

        if isinstance(sampler, data.distributed.DistributedSampler):
            sampler.set_epoch(epoch)

        for batch in loader:
            yield batch

        epoch += 1


def d_logistic_loss(real_pred, fake_pred):
    """Return ``softplus(-real_pred).mean() + softplus(fake_pred).mean()``."""
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)

    return real_loss.mean() + fake_loss.mean()


def d_r1_loss(real_pred, real_img):
    """Return the gradient penalty of ``real_pred`` with respect to ``real_img``.

    The penalty is the squared gradient summed over all non-batch dimensions
    and averaged over the batch. ``real_img`` must have ``requires_grad`` set
    before ``real_pred`` was computed. The graph is kept (``create_graph``)
    so the penalty itself can be back-propagated.
    """
    grad_real, = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    # ``reshape`` instead of ``view``: the gradient is not guaranteed to be
    # contiguous, and ``view`` raises on non-contiguous tensors.
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

    return grad_penalty


def g_nonsaturating_loss(fake_pred):
    """Return ``softplus(-fake_pred).mean()``."""
    loss = F.softplus(-fake_pred).mean()

    return loss


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Compute the path-length penalty for the generator.

    Args:
        fake_img: Generator output; dims 2 and 3 are used to scale the noise
            (presumably height and width -- confirm against the generator).
        latents: Latents the image was generated from; the gradient is
            reduced over dim 2 and then dim 1, so it must be at least 3-D.
        mean_path_length: Running mean of the path length (number or tensor).
        decay: Step size used to move the running mean towards the current
            batch mean.

    Returns:
        Tuple ``(path_penalty, path_mean, path_lengths)`` where ``path_mean``
        is the updated running mean, detached from the graph.
    """
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    grad, = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths


def make_noise(batch, latent_dim, n_noise, device):
    """Draw standard-normal latent noise on ``device``.

    Returns a single ``(batch, latent_dim)`` tensor when ``n_noise == 1``,
    otherwise a tuple of ``n_noise`` such tensors.
    """
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)

    return noises


def mixing_noise(batch, latent_dim, prob, device):
    """Return a sequence of latent noises for the generator.

    With probability ``prob`` two noise tensors are returned; otherwise a
    one-element list holding a single noise tensor.
    """
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)

    return [make_noise(batch, latent_dim, 1, device)]


def set_grad_none(model, targets):
    """Set ``grad`` to ``None`` on parameters of ``model`` named in ``targets``."""
    for n, p in model.named_parameters():
        if n in targets:
            p.grad = None


def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    """Run the adversarial training loop.

    Each iteration performs a discriminator step, an R1-regularisation step
    every ``args.d_reg_every`` iterations, a generator step, a path-length
    regularisation step every ``args.g_reg_every`` iterations, and an
    update of ``g_ema`` from the generator. On rank 0 it updates the progress
    bar, optionally logs to wandb, writes a sample grid to ``sample/`` every
    100 iterations and a checkpoint to ``checkpoint/`` every 10000.

    Args:
        args: Namespace providing ``iter``, ``start_iter``, ``distributed``,
            ``n_sample``, ``latent``, ``batch``, ``mixing``, ``d_reg_every``,
            ``r1``, ``g_reg_every``, ``path_batch_shrink``,
            ``path_regularize`` and ``wandb``.
        loader: Data loader yielding batches of real images.
        generator: Generator network (DDP-wrapped when ``args.distributed``).
        discriminator: Discriminator network (DDP-wrapped when
            ``args.distributed``).
        g_optim: Optimizer for the generator.
        d_optim: Optimizer for the discriminator.
        g_ema: Moving-average copy of the generator, updated each iteration.
        device: Device that tensors are created on / moved to.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

        # Make sure the output directories exist before the first write.
        os.makedirs("sample", exist_ok=True)
        os.makedirs("checkpoint", exist_ok=True)

    mean_path_length = 0

    # Placeholders so the loss dict is fully populated on iterations where
    # the corresponding regularisation step is skipped.
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        # Unwrap DistributedDataParallel to reach the underlying modules.
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Decay passed to ``accumulate`` for the g_ema update.
    accum = 0.5 ** (32 / (10 * 1000))

    # Fixed latents so periodic sample grids are comparable across iterations.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        # ---- Discriminator R1 regularisation (every d_reg_every iters) ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # ``0 * real_pred[0]`` keeps the discriminator output in the graph
            # without changing the loss value.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Generator path-length regularisation (every g_reg_every) ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            # A shrink factor of 0 means "do not shrink"; guard the division
            # so that setting does not raise ZeroDivisionError.
            if args.path_batch_shrink:
                path_batch_size = max(1, args.batch // args.path_batch_shrink)

            else:
                path_batch_size = args.batch

            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Zero-valued term that ties the image output into the loss.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        # Log the cross-rank average (the value shown in the
                        # progress bar) rather than the rank-local tensor.
                        "Mean Path Length": mean_path_length_avg,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    # Map [-1, 1] to [0, 1] by hand. This is what
                    # ``normalize=True`` with a (-1, 1) range does, but it
                    # avoids the ``range=`` keyword, which newer torchvision
                    # releases renamed to ``value_range=``.
                    sample = sample.clamp(-1, 1).add(1).div(2)
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser()

    parser.add_argument("path", type=str)
    parser.add_argument("--iter", type=int, default=800000)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--n_sample", type=int, default=64)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--r1", type=float, default=10)
    parser.add_argument("--path_regularize", type=float, default=2)
    parser.add_argument("--path_batch_shrink", type=int, default=2)
    parser.add_argument("--d_reg_every", type=int, default=16)
    parser.add_argument("--g_reg_every", type=int, default=4)
    parser.add_argument("--mixing", type=float, default=0.9)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.002)
    parser.add_argument("--channel_multiplier", type=int, default=2)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--local_rank", type=int, default=0)

    args = parser.parse_args()

    # WORLD_SIZE is read from the environment; more than one process means
    # distributed training.
    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.eval()
    # decay=0 copies the generator's initial weights into g_ema.
    accumulate(g_ema, generator, 0)

    # Learning rate and Adam betas are rescaled by reg_every / (reg_every + 1)
    # for each network.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        # NOTE: torch.load unpickles the file; only load checkpoints from
        # sources you trust.
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        try:
            # Resume the iteration counter from a purely numeric file name
            # (e.g. "010000.pt"); any other name leaves start_iter at 0.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])

        g_optim.load_state_dict(ckpt["g_optim"])
        d_optim.load_state_dict(ckpt["d_optim"])

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project="stylegan 2")

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)