# mfrashad - "Init code" (commit 97069e1)
import argparse
import math
import random
import os
import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
try:
import wandb
except ImportError:
wandb = None
from model import Generator, Discriminator
from dataset import MultiResolutionDataset
from distributed import (
get_rank,
synchronize,
reduce_loss_dict,
reduce_sum,
get_world_size,
)
def data_sampler(dataset, shuffle, distributed):
    """Build the sampler used by the training DataLoader.

    Args:
        dataset: Dataset the sampler will index.
        shuffle: Whether samples should be drawn in random order.
        distributed: When true, a ``DistributedSampler`` is returned and
            ``shuffle`` is forwarded to it.

    Returns:
        A ``DistributedSampler``, ``RandomSampler`` or ``SequentialSampler``.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        flag: Value assigned to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad = flag
def accumulate(model1, model2, decay=0.999):
    """Update ``model1``'s parameters in place as a moving average of ``model2``.

    For every named parameter of ``model1`` the new value is
    ``decay * model1_param + (1 - decay) * model2_param``.

    Args:
        model1: Model whose parameters are overwritten (the running average).
        model2: Model supplying the current parameter values. It must expose
            a parameter for every name found in ``model1``.
        decay: Weight kept from ``model1``; ``0`` copies ``model2`` outright.

    Raises:
        KeyError: If ``model2`` lacks a parameter name present in ``model1``.
    """
    averaged = dict(model1.named_parameters())
    current = dict(model2.named_parameters())

    for name, avg_param in averaged.items():
        # Pass the scale as the `alpha` keyword: the positional
        # add_(alpha, tensor) overload is deprecated in PyTorch.
        avg_param.data.mul_(decay).add_(current[name].data, alpha=1 - decay)
def sample_data(loader):
    """Yield batches from ``loader`` endlessly.

    Each time ``loader`` is exhausted it is iterated again from the start,
    so the generator never terminates on its own.
    """
    while True:
        yield from loader
def d_logistic_loss(real_pred, fake_pred):
    """Logistic discriminator loss.

    Returns ``mean(softplus(-real_pred)) + mean(softplus(fake_pred))``.

    Args:
        real_pred: Discriminator outputs for real images.
        fake_pred: Discriminator outputs for generated images.
    """
    loss_on_real = F.softplus(-real_pred).mean()
    loss_on_fake = F.softplus(fake_pred).mean()

    return loss_on_real + loss_on_fake
def d_r1_loss(real_pred, real_img):
    """Gradient penalty of the discriminator output on real images.

    Differentiates ``real_pred.sum()`` with respect to ``real_img`` (keeping
    the graph so the penalty itself can be back-propagated), sums the squared
    gradient per sample and averages over the batch.

    Args:
        real_pred: Discriminator outputs computed from ``real_img``.
        real_img: Input batch; must require grad.

    Returns:
        Scalar tensor with the mean per-sample squared gradient norm.
    """
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )

    batch_size = grad_real.shape[0]
    squared_norms = grad_real.pow(2).view(batch_size, -1).sum(1)

    return squared_norms.mean()
def g_nonsaturating_loss(fake_pred):
    """Non-saturating generator loss: ``mean(softplus(-fake_pred))``.

    Args:
        fake_pred: Discriminator outputs for generated images.
    """
    return F.softplus(-fake_pred).mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization term for the generator.

    Projects ``fake_img`` onto random noise (scaled by
    ``1 / sqrt(H * W)``), differentiates that projection with respect to
    ``latents`` and measures the per-sample gradient length. The running mean
    is moved toward the batch mean by ``decay`` and the penalty is the mean
    squared deviation of the lengths from that updated mean.

    Args:
        fake_img: Generated images; dims 2 and 3 are used as height/width.
        latents: Latent tensor ``fake_img`` was computed from; the gradient is
            reduced with ``sum(2).mean(1)``, so it is indexed as 3-D.
        mean_path_length: Current running mean of the path length.
        decay: Step size used when updating the running mean.

    Returns:
        Tuple ``(path_penalty, detached updated mean, per-sample path_lengths)``.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)

    (grad,) = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = grad.pow(2).sum(2).mean(1).sqrt()

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    deviation = path_lengths - path_mean
    path_penalty = deviation.pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths
def make_noise(batch, latent_dim, n_noise, device):
    """Sample standard-normal latent codes.

    Args:
        batch: Number of codes per tensor.
        latent_dim: Size of each code.
        n_noise: How many independent tensors to draw.
        device: Device the tensors are created on.

    Returns:
        A single ``(batch, latent_dim)`` tensor when ``n_noise == 1``;
        otherwise a tuple of ``n_noise`` such tensors.
    """
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    stacked = torch.randn(n_noise, batch, latent_dim, device=device)
    return stacked.unbind(0)
def mixing_noise(batch, latent_dim, prob, device):
    """Draw generator input noise, sometimes as a pair of latents.

    With probability ``prob`` (and only if ``prob > 0``) two latent tensors
    are returned; otherwise a one-element list with a single tensor.

    Args:
        batch: Number of codes per tensor.
        latent_dim: Size of each code.
        prob: Probability of returning two tensors.
        device: Device the tensors are created on.
    """
    # Short-circuit keeps random.random() from being consumed when prob <= 0.
    use_two_latents = prob > 0 and random.random() < prob

    if not use_two_latents:
        return [make_noise(batch, latent_dim, 1, device)]

    return make_noise(batch, latent_dim, 2, device)
def set_grad_none(model, targets):
    """Drop the gradients of selected parameters.

    Args:
        model: Object exposing ``named_parameters()``.
        targets: Container of parameter names whose ``grad`` is set to ``None``.
    """
    for name, param in model.named_parameters():
        if name not in targets:
            continue
        param.grad = None
def _save_sample_grid(sample, path, nrow):
    """Save ``sample`` as an image grid normalized from the [-1, 1] range."""
    try:
        utils.save_image(
            sample, path, nrow=nrow, normalize=True, value_range=(-1, 1)
        )
    except TypeError:
        # Older torchvision releases only know this keyword as `range`.
        utils.save_image(sample, path, nrow=nrow, normalize=True, range=(-1, 1))


def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    """Run the adversarial training loop.

    Every iteration performs one discriminator step and one generator step.
    Every ``args.d_reg_every`` iterations an extra discriminator step applies
    the R1 gradient penalty, and every ``args.g_reg_every`` iterations an
    extra generator step applies path-length regularization. After each
    iteration ``g_ema`` is moved toward the generator with ``accumulate``.
    Rank 0 updates the progress bar, optionally logs to wandb, writes a sample
    grid to ``sample/`` every 100 iterations and a checkpoint to
    ``checkpoint/`` every 10000 iterations.

    Args:
        args: Namespace providing ``iter``, ``start_iter``, ``distributed``,
            ``n_sample``, ``latent``, ``batch``, ``mixing``, ``d_reg_every``,
            ``r1``, ``g_reg_every``, ``path_batch_shrink``,
            ``path_regularize`` and ``wandb``.
        loader: Iterable of real-image batches; it is cycled indefinitely.
        generator: Generator, wrapped (``.module`` is read) when
            ``args.distributed`` is true.
        discriminator: Discriminator, wrapped the same way.
        g_optim: Optimizer for the generator.
        d_optim: Optimizer for the discriminator.
        g_ema: Averaged copy of the generator used for samples/checkpoints.
        device: Device batches and noise are placed on.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

        # Only rank 0 writes samples and checkpoints; make sure the target
        # directories exist so the very first save cannot fail.
        os.makedirs("sample", exist_ok=True)
        os.makedirs("checkpoint", exist_ok=True)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    # Decay handed to accumulate() for the g_ema update.
    accum = 0.5 ** (32 / (10 * 1000))

    # Fixed latents so successive sample grids are comparable.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # The zero-weighted real_pred[0] term keeps the discriminator
            # output in the graph of the backward pass.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        # Outside the branch on purpose: the last computed value (initially 0)
        # is reported on iterations without regularization.
        loss_dict["r1"] = r1_loss

        # ---- generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            # A falsy path_batch_shrink means "do not shrink"; guarding here
            # avoids a ZeroDivisionError for --path_batch_shrink 0.
            if args.path_batch_shrink:
                path_batch_size = max(1, args.batch // args.path_batch_shrink)
            else:
                path_batch_size = args.batch

            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        # Log the reduced scalar (same value as the progress
                        # bar) rather than this rank's raw tensor.
                        "Mean Path Length": mean_path_length_avg,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    _save_sample_grid(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        int(args.n_sample ** 0.5),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
if __name__ == "__main__":
    # Script entry point: parse options, build models/optimizers, optionally
    # resume from a checkpoint, build the data loader and start training.
    device = "cuda"

    parser = argparse.ArgumentParser()

    parser.add_argument("path", type=str)
    parser.add_argument("--iter", type=int, default=800000)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--n_sample", type=int, default=64)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--r1", type=float, default=10)
    parser.add_argument("--path_regularize", type=float, default=2)
    parser.add_argument("--path_batch_shrink", type=int, default=2)
    parser.add_argument("--d_reg_every", type=int, default=16)
    parser.add_argument("--g_reg_every", type=int, default=4)
    parser.add_argument("--mixing", type=float, default=0.9)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.002)
    parser.add_argument("--channel_multiplier", type=int, default=2)
    parser.add_argument("--wandb", action="store_true")
    # Launchers differ in how they hand over the local rank: older ones pass
    # --local_rank, newer ones pass --local-rank or only set LOCAL_RANK.
    # Accept all three; an explicit flag still overrides the environment.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )

    args = parser.parse_args()

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # Fixed settings that are not exposed on the command line.
    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.eval()
    # decay=0 makes g_ema start as an exact copy of the generator weights.
    accumulate(g_ema, generator, 0)

    # Learning rate and Adam betas are rescaled by reg_every / (reg_every + 1)
    # for each network.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        # Load onto CPU storage first; load_state_dict copies to the models.
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        try:
            # Checkpoints written by train() are named "<iteration>.pt";
            # recover the iteration so training resumes from it.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            # Best effort: a non-numeric file name simply restarts at 0.
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])

        g_optim.load_state_dict(ckpt["g_optim"])
        d_optim.load_state_dict(ckpt["d_optim"])

    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Maps each channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project="stylegan 2")

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)