Spaces:
Runtime error
Runtime error
import argparse | |
import math | |
import random | |
import os | |
import numpy as np | |
import torch | |
from torch import nn, autograd, optim | |
from torch.nn import functional as F | |
from torch.utils import data | |
import torch.distributed as dist | |
from torchvision import transforms, utils | |
from tqdm import tqdm | |
try: | |
import wandb | |
except ImportError: | |
wandb = None | |
from model import Generator, Discriminator | |
from dataset import MultiResolutionDataset | |
from distributed import ( | |
get_rank, | |
synchronize, | |
reduce_loss_dict, | |
reduce_sum, | |
get_world_size, | |
) | |
def data_sampler(dataset, shuffle, distributed):
    """Pick the sampler used to draw indices from ``dataset``.

    Args:
        dataset: The dataset handed straight to the sampler constructor.
        shuffle: Whether indices should be drawn in random order.
        distributed: When true, a ``DistributedSampler`` is returned and
            ``shuffle`` is forwarded to it.

    Returns:
        A ``DistributedSampler``, ``RandomSampler`` or ``SequentialSampler``.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        flag: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    for param in model.parameters():
        param.requires_grad = flag
def accumulate(model1, model2, decay=0.999):
    """Blend the parameters of ``model2`` into ``model1`` in place.

    For every named parameter of ``model1`` the update is
    ``p1 = decay * p1 + (1 - decay) * p2``, where ``p2`` is the parameter of
    the same name in ``model2``. With ``decay=0`` this overwrites ``model1``'s
    parameters with those of ``model2``.

    Args:
        model1: Module whose parameters are updated in place.
        model2: Module supplying the new values; it is not modified.
        decay: Weight kept from ``model1``'s current parameters.

    Raises:
        KeyError: If ``model2`` lacks a parameter name present in ``model1``.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for name, target in par1.items():
        # Use the keyword ``alpha`` form; the positional-scalar overload
        # ``add_(scalar, tensor)`` is deprecated in PyTorch.
        target.data.mul_(decay).add_(par2[name].data, alpha=1 - decay)
def sample_data(loader):
    """Yield batches from ``loader`` forever.

    Each time ``loader`` is exhausted it is iterated again from the start,
    so the returned generator never stops on its own.
    """
    while True:
        yield from loader
def d_logistic_loss(real_pred, fake_pred):
    """Return the discriminator loss for a batch of real and fake scores.

    The result is ``mean(softplus(-real_pred)) + mean(softplus(fake_pred))``.
    """
    loss_on_real = F.softplus(-real_pred).mean()
    loss_on_fake = F.softplus(fake_pred).mean()
    return loss_on_real + loss_on_fake
def d_r1_loss(real_pred, real_img):
    """Return the gradient penalty of ``real_pred`` with respect to ``real_img``.

    The gradient of ``real_pred.sum()`` w.r.t. ``real_img`` is squared, summed
    over every non-batch dimension, and averaged over the batch. The gradient
    is taken with ``create_graph=True`` so the penalty itself can be
    back-propagated.

    Args:
        real_pred: Scores computed from ``real_img``.
        real_img: Input tensor; it must require grad and be part of the
            graph that produced ``real_pred``.
    """
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    batch_size = grad_real.shape[0]
    squared_norms = grad_real.pow(2).view(batch_size, -1).sum(1)
    return squared_norms.mean()
def g_nonsaturating_loss(fake_pred):
    """Return the generator loss ``mean(softplus(-fake_pred))``."""
    return F.softplus(-fake_pred).mean()
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Compute the path-length penalty for a generated batch.

    Random noise shaped like ``fake_img`` (scaled by ``1 / sqrt(H * W)``,
    taking H and W from dimensions 2 and 3) is used to project the image, and
    the gradient of that projection w.r.t. ``latents`` gives a per-sample
    path length. The running mean is moved towards the batch mean by
    ``decay``, and the penalty is the mean squared deviation of the path
    lengths from that updated mean.

    Args:
        fake_img: Generated images; indexed as a tensor with at least 4 dims.
        latents: Latent tensor the images were produced from; the gradient is
            reduced with ``sum(2)`` then ``mean(1)``, so it is indexed as 3-D.
        mean_path_length: Current running mean of the path length.
        decay: Step size for the running-mean update.

    Returns:
        Tuple ``(path_penalty, detached_updated_mean, path_lengths)``.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)

    projection = (fake_img * noise).sum()
    (grad,) = autograd.grad(outputs=projection, inputs=latents, create_graph=True)

    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)

    deviation = path_lengths - path_mean
    path_penalty = deviation.pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths
def make_noise(batch, latent_dim, n_noise, device):
    """Draw standard-normal latent noise on ``device``.

    Returns:
        A single ``(batch, latent_dim)`` tensor when ``n_noise == 1``;
        otherwise a tuple of ``n_noise`` such tensors, obtained by unbinding
        a ``(n_noise, batch, latent_dim)`` tensor along dimension 0.
    """
    if n_noise != 1:
        stacked = torch.randn(n_noise, batch, latent_dim, device=device)
        return stacked.unbind(0)

    return torch.randn(batch, latent_dim, device=device)
def mixing_noise(batch, latent_dim, prob, device):
    """Return one or two latent noise tensors.

    With probability ``prob`` (and only when ``prob > 0``) two noise tensors
    are returned as produced by ``make_noise(..., 2, ...)``; otherwise a
    one-element list holding a single noise tensor is returned.
    """
    # ``random.random()`` is only consulted when prob is positive.
    use_two = prob > 0 and random.random() < prob

    if not use_two:
        return [make_noise(batch, latent_dim, 1, device)]

    return make_noise(batch, latent_dim, 2, device)
def set_grad_none(model, targets):
    """Clear ``.grad`` on the parameters of ``model`` whose names are in ``targets``.

    Args:
        model: Object exposing ``named_parameters()``.
        targets: Container of parameter names; membership is tested with ``in``.
    """
    for name, param in model.named_parameters():
        if name not in targets:
            continue
        param.grad = None
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    """Run the adversarial training loop.

    Each iteration performs a discriminator step, an optional discriminator
    gradient-penalty step (every ``args.d_reg_every`` iterations), a generator
    step, an optional generator path-length step (every ``args.g_reg_every``
    iterations), and then blends the generator weights into ``g_ema``.
    On rank 0 the progress bar is updated, metrics are optionally sent to
    wandb, a sample grid is written every 100 iterations to ``sample/`` and a
    checkpoint every 10000 iterations to ``checkpoint/``.

    Args:
        args: Namespace providing ``iter``, ``start_iter``, ``distributed``,
            ``n_sample``, ``latent``, ``batch``, ``mixing``, ``d_reg_every``,
            ``r1``, ``g_reg_every``, ``path_batch_shrink``,
            ``path_regularize`` and ``wandb``.
        loader: Iterable of real-image batches; it is cycled indefinitely.
        generator: Generator network (wrapped in DDP when
            ``args.distributed`` is true, in which case ``.module`` is used
            for EMA accumulation and checkpointing).
        discriminator: Discriminator network (same wrapping convention).
        g_optim: Optimizer for the generator.
        d_optim: Optimizer for the discriminator.
        g_ema: Generator copy that receives the moving average of the
            generator parameters.
        device: Device on which real batches and noise are placed.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
        # Make sure the output locations exist before the first save.
        os.makedirs("sample", exist_ok=True)
        os.makedirs("checkpoint", exist_ok=True)

    mean_path_length = 0

    # Defaults reported on iterations where the regularization steps are skipped.
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Decay factor passed to accumulate() for the g_ema moving average.
    accum = 0.5 ** (32 / (10 * 1000))

    # Fixed latents so that periodic sample grids are comparable over time.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)

        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        # ---- Discriminator gradient penalty, every d_reg_every iterations ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # The zero-weighted real_pred[0] term keeps the discriminator
            # output in the backward graph without changing the loss value.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)
        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Generator path-length penalty, every g_reg_every iterations ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Zero-weighted term that ties the image output into the graph.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                    f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}"
                )
            )

            if wandb and args.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        # Log the reduced Python scalar (same value shown in
                        # the progress bar) rather than the raw tensor.
                        "Mean Path Length": mean_path_length_avg,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,
                        "Path Length": path_length_val,
                    }
                )

            if i % 100 == 0:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    utils.save_image(
                        sample,
                        f"sample/{str(i).zfill(6)}.png",
                        nrow=int(args.n_sample ** 0.5),
                        normalize=True,
                        # ``range`` was renamed to ``value_range`` in
                        # torchvision; the old keyword is no longer accepted.
                        value_range=(-1, 1),
                    )

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"checkpoint/{str(i).zfill(6)}.pt",
                )
if __name__ == "__main__":
    # Script entry point: parse options, build the models and optimizers,
    # optionally restore a checkpoint, then hand everything to train().
    device = "cuda"

    parser = argparse.ArgumentParser()

    # Positional dataset location, passed to MultiResolutionDataset below.
    parser.add_argument("path", type=str)
    parser.add_argument("--iter", type=int, default=800000)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--n_sample", type=int, default=64)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--r1", type=float, default=10)
    parser.add_argument("--path_regularize", type=float, default=2)
    parser.add_argument("--path_batch_shrink", type=int, default=2)
    parser.add_argument("--d_reg_every", type=int, default=16)
    parser.add_argument("--g_reg_every", type=int, default=4)
    parser.add_argument("--mixing", type=float, default=0.9)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.002)
    parser.add_argument("--channel_multiplier", type=int, default=2)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--local_rank", type=int, default=0)

    args = parser.parse_args()

    # Distributed mode is enabled when WORLD_SIZE is set to more than 1.
    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # Fixed settings that are not exposed on the command line.
    args.latent = 512
    args.n_mlp = 8

    args.start_iter = 0

    generator = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    discriminator = Discriminator(
        args.size, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.eval()
    # decay=0 makes accumulate() copy the generator's parameters into g_ema.
    accumulate(g_ema, generator, 0)

    # The Adam learning rate and betas are rescaled by reg_every / (reg_every + 1).
    # NOTE(review): presumably this compensates for the extra regularization-only
    # optimizer steps that train() takes every reg_every iterations; confirm.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = optim.Adam(
        generator.parameters(),
        lr=args.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        # The identity map_location keeps loaded storages where they were
        # deserialized instead of moving them to their saved device.
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        try:
            # Resume the iteration counter from a purely numeric file name
            # such as "010000.pt"; any other name leaves start_iter at 0.
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])

        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])

        g_optim.load_state_dict(ckpt["g_optim"])
        d_optim.load_state_dict(ckpt["d_optim"])

    # Wrap for multi-process training only after weights have been restored;
    # g_ema is left unwrapped.
    if args.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    # Random horizontal flip, then conversion to a tensor normalized with
    # mean 0.5 and std 0.5 per channel.
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    # Only rank 0 initializes wandb, and only when it is installed and requested.
    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project="stylegan 2")

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device)