# Spaces: Runtime error  (hosting-page status banner captured with this file; not part of the program)
import argparse
import math
import random
import os
from util import *
import numpy as np
import torch
torch.backends.cudnn.benchmark = True
from torch import nn, autograd
from torch import optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
from torch.optim import lr_scheduler
import copy
import kornia.augmentation as K
import kornia
import lpips
from model import *
from dataset import ImageFolder
from distributed import (
    get_rank,
    synchronize,
    reduce_loss_dict,
    reduce_sum,
    get_world_size,
)
# Shared mean-squared-error criterion; train() uses it for the style
# reconstruction term of the generator loss.
mse_criterion = nn.MSELoss()
def test(args, genA2B, genB2A, testA_loader, testB_loader, name, step):
    """Render translation samples for both directions and save them as JPEGs.

    Draws 16 items from each test loader. For every item it stacks, along
    dim 2, the real image, its same-domain pass, three translations (one using
    the style encoded from the other domain's image, two using random styles)
    and four reconstructions of those translations. The 16 stacks are then
    concatenated along dim 0 and written to
    ``{im_path}/{name}_A2B_{step:06d}.jpg`` and
    ``{im_path}/{name}_B2A_{step:06d}.jpg``.

    Args:
        args: Namespace; only ``args.latent_dim`` is read here.
        genA2B, genB2A: Generators exposing ``encode``, ``decode`` and a
            forward call returning a 3-tuple whose first item is the image.
        testA_loader, testB_loader: Iterables yielding image batches; each
            must yield at least 16 batches or ``StopIteration`` propagates.
        name: Tag embedded in the output file names.
        step: Iteration number, zero-padded to six digits in the file names.

    Relies on the module-level ``im_path`` set by the script entry point and
    on a CUDA device (inputs are moved with ``.cuda()``).
    """
    testA_iter = iter(testA_loader)
    testB_iter = iter(testB_loader)
    test_sample_num = 16

    # Remember each generator's mode so that generators that were already in
    # eval mode (the EMA copies) are not left in train mode after sampling.
    was_training_A2B = genA2B.training
    was_training_B2A = genB2A.training
    genA2B.eval(), genB2A.eval()

    try:
        with torch.no_grad():
            A2B = []
            B2A = []
            for i in range(test_sample_num):
                # Built-in next(): iterator objects are not guaranteed to
                # expose a ``.next()`` method.
                real_A = next(testA_iter)
                real_B = next(testB_iter)
                real_A, real_B = real_A.cuda(), real_B.cuda()

                A2B_content, A2B_style = genA2B.encode(real_A)
                B2A_content, B2A_style = genB2A.encode(real_B)

                # Fresh random styles every second sample, so consecutive
                # pairs of samples share the same two random codes.
                if i % 2 == 0:
                    A2B_mod1 = torch.randn([1, args.latent_dim]).cuda()
                    B2A_mod1 = torch.randn([1, args.latent_dim]).cuda()
                    A2B_mod2 = torch.randn([1, args.latent_dim]).cuda()
                    B2A_mod2 = torch.randn([1, args.latent_dim]).cuda()

                # Each generator applied to an image of its own output domain.
                fake_B2B, _, _ = genA2B(real_B)
                fake_A2A, _, _ = genB2A(real_A)

                colsA = [real_A, fake_A2A]
                colsB = [real_B, fake_B2B]

                # Translations: *_1 / *_2 use the random styles, *_3 uses the
                # style encoded from the other domain's real image.
                fake_A2B_1 = genA2B.decode(A2B_content, A2B_mod1)
                fake_B2A_1 = genB2A.decode(B2A_content, B2A_mod1)
                fake_A2B_2 = genA2B.decode(A2B_content, A2B_mod2)
                fake_B2A_2 = genB2A.decode(B2A_content, B2A_mod2)
                fake_A2B_3 = genA2B.decode(A2B_content, B2A_style)
                fake_B2A_3 = genB2A.decode(B2A_content, A2B_style)

                colsA += [fake_A2B_3, fake_A2B_1, fake_A2B_2]
                colsB += [fake_B2A_3, fake_B2A_1, fake_B2A_2]

                # Translate back, passing the source image's own style code.
                for fake_A2B, fake_B2A in (
                    (fake_A2B_3, fake_B2A_3),
                    (fake_A2B_1, fake_B2A_1),
                    (fake_A2B_2, fake_B2A_2),
                ):
                    fake_A2B2A, _, _ = genB2A(fake_A2B, A2B_style)
                    fake_B2A2B, _, _ = genA2B(fake_B2A, B2A_style)
                    colsA.append(fake_A2B2A)
                    colsB.append(fake_B2A2B)

                # Translate back without passing a style code.
                fake_A2B2A, _, _ = genB2A(fake_A2B_1)
                fake_B2A2B, _, _ = genA2B(fake_B2A_1)
                colsA.append(fake_A2B2A)
                colsB.append(fake_B2A2B)

                A2B.append(torch.cat(colsA, 2).detach().cpu())
                B2A.append(torch.cat(colsB, 2).detach().cpu())

            A2B = torch.cat(A2B, 0)
            B2A = torch.cat(B2A, 0)

            for direction, grid in (('A2B', A2B), ('B2A', B2A)):
                # Clamp to [-1, 1] and rescale to [0, 1] by hand. This is the
                # same mapping as normalize=True with a (-1, 1) range, but it
                # does not depend on whether the installed torchvision spells
                # that keyword ``range`` or ``value_range``.
                grid = grid.clamp(-1, 1).add(1).div(2)
                utils.save_image(grid, f'{im_path}/{name}_{direction}_{str(step).zfill(6)}.jpg', nrow=16)
    finally:
        genA2B.train(was_training_A2B)
        genB2A.train(was_training_B2A)
def train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device):
    """Run the adversarial training loop for the two generators and three discriminators.

    Each iteration:
      1. draws one batch per domain, builds an augmented batch (``aug_*``) and
         a batch made of one randomly chosen image repeated and re-augmented
         (``A`` / ``B``);
      2. updates the discriminators (image discriminators ``D_A`` / ``D_B``
         and the latent discriminator ``D_L``), adding an R1 penalty every
         ``args.d_reg_every`` iterations;
      3. updates the generators with adversarial, style-variance, cycle (MSE
         and LPIPS) and style-reconstruction terms;
      4. steps both LR schedulers and updates the EMA generators;
      5. on rank 0, refreshes the progress bar, writes sample grids every 1000
         iterations and a checkpoint every 2000 iterations.

    Args:
        args: Namespace providing ``iter``, ``start_iter``, ``batch``,
            ``latent_dim``, ``d_reg_every``, ``r1`` and ``distributed``.
        trainA_loader, trainB_loader: Training loaders for the two domains.
        testA_loader, testB_loader: Loaders forwarded to ``test``.
        G_A2B, G_B2A: Generators.
        D_A, D_B: Image discriminators.
        G_optim, D_optim: Optimizers for the generators / discriminators.
        device: Device the batches and random styles are moved to.

    Besides its parameters this function reads module-level names created by
    the script entry point: ``D_L``, ``G_A2B_ema``, ``G_B2A_ema``, ``aug``,
    ``lpips_fn`` and ``model_path``. Helpers such as ``sample_data``,
    ``accumulate``, ``shuffle_batch`` and the loss functions are not defined
    in this file (presumably supplied by the ``util`` star import — confirm).
    """
    G_A2B.train(), G_B2A.train(), D_A.train(), D_B.train()
    trainA_loader = sample_data(trainA_loader)
    trainB_loader = sample_data(trainB_loader)
    G_scheduler = lr_scheduler.StepLR(G_optim, step_size=100000, gamma=0.5)
    D_scheduler = lr_scheduler.StepLR(D_optim, step_size=100000, gamma=0.5)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.1)

    loss_dict = {}

    # Under distributed wrapping the underlying networks are reached through
    # ``.module``; they are what gets averaged into the EMA copies and saved.
    if args.distributed:
        G_A2B_module = G_A2B.module
        G_B2A_module = G_B2A.module
        D_A_module = D_A.module
        D_B_module = D_B.module
        D_L_module = D_L.module
    else:
        G_A2B_module = G_A2B
        G_B2A_module = G_B2A
        D_A_module = D_A
        D_B_module = D_B
        D_L_module = D_L

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print('Done!')
            break

        ori_A = next(trainA_loader)
        ori_B = next(trainB_loader)
        if isinstance(ori_A, list):
            ori_A = ori_A[0]
        if isinstance(ori_B, list):
            ori_B = ori_B[0]
        ori_A = ori_A.to(device)
        ori_B = ori_B.to(device)

        aug_A = aug(ori_A)
        aug_B = aug(ori_B)
        # One randomly picked image repeated across the batch and augmented
        # again, so every row shares its source image.
        A = aug(ori_A[[np.random.randint(args.batch)]].expand_as(ori_A))
        B = aug(ori_B[[np.random.randint(args.batch)]].expand_as(ori_B))

        # The R1 penalty differentiates the real logits w.r.t. the real inputs.
        if i % args.d_reg_every == 0:
            aug_A.requires_grad = True
            aug_B.requires_grad = True

        A2B_content, A2B_style = G_A2B.encode(A)
        B2A_content, B2A_style = G_B2A.encode(B)

        # get new style
        aug_A2B_style = G_B2A.style_encode(aug_B)
        aug_B2A_style = G_A2B.style_encode(aug_A)
        rand_A2B_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_()
        rand_B2A_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_()

        # styles: pick args.batch codes at random from the random + encoded pool
        perm = torch.randperm(2*args.batch)
        input_A2B_style = torch.cat([rand_A2B_style, aug_A2B_style], 0)[perm][:args.batch]
        perm = torch.randperm(2*args.batch)
        input_B2A_style = torch.cat([rand_B2A_style, aug_B2A_style], 0)[perm][:args.batch]

        fake_A2B = G_A2B.decode(A2B_content, input_A2B_style)
        fake_B2A = G_B2A.decode(B2A_content, input_B2A_style)

        # train disc
        real_A_logit = D_A(aug_A)
        real_B_logit = D_B(aug_B)
        # For the latent discriminator the random codes are the "real" samples
        # and the encoded styles are the "fake" ones.
        real_L_logit1 = D_L(rand_A2B_style)
        real_L_logit2 = D_L(rand_B2A_style)

        fake_B_logit = D_B(fake_A2B.detach())
        fake_A_logit = D_A(fake_B2A.detach())
        fake_L_logit1 = D_L(aug_A2B_style.detach())
        fake_L_logit2 = D_L(aug_B2A_style.detach())

        # global loss
        D_loss = (
            d_logistic_loss(real_A_logit, fake_A_logit)
            + d_logistic_loss(real_B_logit, fake_B_logit)
            + d_logistic_loss(real_L_logit1, fake_L_logit1)
            + d_logistic_loss(real_L_logit2, fake_L_logit2)
        )
        loss_dict['D_adv'] = D_loss

        if i % args.d_reg_every == 0:
            r1_A_loss = d_r1_loss(real_A_logit, aug_A)
            r1_B_loss = d_r1_loss(real_B_logit, aug_B)
            r1_L_loss = d_r1_loss(real_L_logit1, rand_A2B_style) + d_r1_loss(real_L_logit2, rand_B2A_style)
            r1_loss = r1_A_loss + r1_B_loss + r1_L_loss
            D_r1_loss = (args.r1 / 2 * r1_loss * args.d_reg_every)
            # Out-of-place add: an in-place ``+=`` would also change the tensor
            # stored in loss_dict['D_adv'] and fold the penalty into the
            # logged adversarial loss.
            D_loss = D_loss + D_r1_loss

        D_optim.zero_grad()
        D_loss.backward()
        D_optim.step()

        # Generator
        # adv loss
        fake_B_logit = D_B(fake_A2B)
        fake_A_logit = D_A(fake_B2A)
        fake_L_logit1 = D_L(aug_A2B_style)
        fake_L_logit2 = D_L(aug_B2A_style)

        lambda_adv = (1, 1, 1)
        G_adv_loss = 1 * (
            g_nonsaturating_loss(fake_A_logit, lambda_adv)
            + g_nonsaturating_loss(fake_B_logit, lambda_adv)
            + 2*g_nonsaturating_loss(fake_L_logit1, (1,))
            + 2*g_nonsaturating_loss(fake_L_logit2, (1,))
        )

        # style consis loss: A and B each repeat one source image, so the
        # variance of their style codes across the batch is penalised.
        G_con_loss = 50 * (A2B_style.var(0, unbiased=False).sum() + B2A_style.var(0, unbiased=False).sum())

        # cycle recon
        A2B2A_content, A2B2A_style = G_B2A.encode(fake_A2B)
        B2A2B_content, B2A2B_style = G_A2B.encode(fake_B2A)
        fake_A2B2A = G_B2A.decode(A2B2A_content, shuffle_batch(A2B_style))
        fake_B2A2B = G_A2B.decode(B2A2B_content, shuffle_batch(B2A_style))

        G_cycle_loss = 20 * (F.mse_loss(fake_A2B2A, A) + F.mse_loss(fake_B2A2B, B))
        lpips_loss = 10 * (lpips_fn(fake_A2B2A, A).mean() + lpips_fn(fake_B2A2B, B).mean())  # 10 for anime

        # style reconstruction
        G_style_loss = 5 * (
            mse_criterion(A2B2A_style, input_A2B_style)
            + mse_criterion(B2A2B_style, input_B2A_style)
        )

        G_loss = G_adv_loss + G_cycle_loss + G_con_loss + lpips_loss + G_style_loss

        loss_dict['G_adv'] = G_adv_loss
        loss_dict['G_con'] = G_con_loss
        loss_dict['G_cycle'] = G_cycle_loss
        loss_dict['lpips'] = lpips_loss

        G_optim.zero_grad()
        G_loss.backward()
        G_optim.step()

        G_scheduler.step()
        D_scheduler.step()

        accumulate(G_A2B_ema, G_A2B_module)
        accumulate(G_B2A_ema, G_B2A_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        D_adv_loss_val = loss_reduced['D_adv'].mean().item()
        G_adv_loss_val = loss_reduced['G_adv'].mean().item()
        G_cycle_loss_val = loss_reduced['G_cycle'].mean().item()
        G_con_loss_val = loss_reduced['G_con'].mean().item()
        lpips_val = loss_reduced['lpips'].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (
                    f'Dadv: {D_adv_loss_val:.2f}; lpips: {lpips_val:.2f} '
                    f'Gadv: {G_adv_loss_val:.2f}; Gcycle: {G_cycle_loss_val:.2f}; GMS: {G_con_loss_val:.2f} {G_style_loss.item():.2f}'
                )
            )

            if i % 1000 == 0:
                with torch.no_grad():
                    test(args, G_A2B, G_B2A, testA_loader, testB_loader, 'normal', i)
                    test(args, G_A2B_ema, G_B2A_ema, testA_loader, testB_loader, 'ema', i)

            if (i+1) % 2000 == 0:
                torch.save(
                    {
                        'G_A2B': G_A2B_module.state_dict(),
                        'G_B2A': G_B2A_module.state_dict(),
                        'G_A2B_ema': G_A2B_ema.state_dict(),
                        'G_B2A_ema': G_B2A_ema.state_dict(),
                        'D_A': D_A_module.state_dict(),
                        'D_B': D_B_module.state_dict(),
                        'D_L': D_L_module.state_dict(),
                        'G_optim': G_optim.state_dict(),
                        'D_optim': D_optim.state_dict(),
                        'iter': i,
                    },
                    os.path.join(model_path, 'ck.pt'),
                )
# Script entry point: parse options, build the networks, optimizers,
# augmentation pipeline and data loaders, optionally resume from a checkpoint,
# then hand over to train(). Names assigned here (im_path, model_path, D_L,
# lpips_fn, G_A2B_ema, G_B2A_ema, aug) are module globals read by test()/train().
if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('--iter', type=int, default=300000)
    parser.add_argument('--batch', type=int, default=4)
    parser.add_argument('--n_sample', type=int, default=64)
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--r1', type=float, default=10)
    parser.add_argument('--lambda_cycle', type=int, default=1)
    parser.add_argument('--path_regularize', type=float, default=2)
    parser.add_argument('--path_batch_shrink', type=int, default=2)
    parser.add_argument('--d_reg_every', type=int, default=16)
    parser.add_argument('--g_reg_every', type=int, default=4)
    parser.add_argument('--mixing', type=float, default=0.9)
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--num_down', type=int, default=3)
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('--d_path', type=str, required=True)
    parser.add_argument('--latent_dim', type=int, default=8)
    parser.add_argument('--lr_mlp', type=float, default=0.01)
    parser.add_argument('--n_res', type=int, default=1)
    args = parser.parse_args()

    # n_gpu is computed but not used below: distributed mode is switched off
    # unconditionally, so the distributed branches in this block never run.
    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    args.distributed = False

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()

    # Output layout: ./<name>/sample for image grids, ./<name>/checkpoint for weights.
    save_path = f'./{args.name}'
    im_path = os.path.join(save_path, 'sample')
    model_path = os.path.join(save_path, 'checkpoint')
    os.makedirs(im_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)

    args.n_mlp = 5
    args.start_iter = 0

    G_A2B = Generator(args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device)
    D_A = Discriminator(args.size).to(device)
    G_B2A = Generator(args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device)
    D_B = Discriminator(args.size).to(device)
    D_L = LatDiscriminator(args.latent_dim).to(device)
    lpips_fn = lpips.LPIPS(net='vgg').to(device)

    # EMA copies of the generators, kept in eval mode.
    G_A2B_ema = copy.deepcopy(G_A2B).to(device).eval()
    G_B2A_ema = copy.deepcopy(G_B2A).to(device).eval()

    # g_reg_ratio is not referenced again in this block; d_reg_ratio rescales
    # the discriminator's Adam betas.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    G_optim = optim.Adam(list(G_A2B.parameters()) + list(G_B2A.parameters()), lr=args.lr, betas=(0, 0.99))
    D_optim = optim.Adam(
        list(D_L.parameters()) + list(D_A.parameters()) + list(D_B.parameters()),
        lr=args.lr, betas=(0**d_reg_ratio, 0.99**d_reg_ratio))

    if args.ckpt is not None:
        # Load onto CPU storage first; load_state_dict copies into the
        # already-placed modules.
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        G_A2B.load_state_dict(ckpt['G_A2B'])
        G_B2A.load_state_dict(ckpt['G_B2A'])
        G_A2B_ema.load_state_dict(ckpt['G_A2B_ema'])
        G_B2A_ema.load_state_dict(ckpt['G_B2A_ema'])
        D_A.load_state_dict(ckpt['D_A'])
        D_B.load_state_dict(ckpt['D_B'])
        D_L.load_state_dict(ckpt['D_L'])
        G_optim.load_state_dict(ckpt['G_optim'])
        D_optim.load_state_dict(ckpt['D_optim'])
        # The iteration stored in the checkpoint is the resume point. (The
        # file name used to be parsed as well, but that value was always
        # overwritten by this assignment.)
        args.start_iter = ckpt['iter']

    if args.distributed:
        G_A2B = nn.parallel.DistributedDataParallel(
            G_A2B,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        D_A = nn.parallel.DistributedDataParallel(
            D_A,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        G_B2A = nn.parallel.DistributedDataParallel(
            G_B2A,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        D_B = nn.parallel.DistributedDataParallel(
            D_B,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        D_L = nn.parallel.DistributedDataParallel(
            D_L,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    # Training images are only converted and normalised to [-1, 1] here;
    # resizing and cropping happen in the `aug` pipeline below.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
    ])
    test_transform = transforms.Compose([
        transforms.Resize((args.size, args.size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True)
    ])

    # Batch-level augmentation. The resize/crop sizes follow --size so the
    # training resolution matches what the networks and the test transform are
    # built for (identical to the former hard-coded 256 at the default).
    aug = nn.Sequential(
        K.RandomAffine(degrees=(-20, 20), scale=(0.8, 1.2), translate=(0.1, 0.1), shear=0.15),
        kornia.geometry.transform.Resize(args.size + 30),
        K.RandomCrop((args.size, args.size)),
        K.RandomHorizontalFlip(),
    )

    d_path = args.d_path
    trainA = ImageFolder(os.path.join(d_path, 'trainA'), train_transform)
    trainB = ImageFolder(os.path.join(d_path, 'trainB'), train_transform)
    testA = ImageFolder(os.path.join(d_path, 'testA'), test_transform)
    testB = ImageFolder(os.path.join(d_path, 'testB'), test_transform)

    trainA_loader = data.DataLoader(trainA, batch_size=args.batch,
            sampler=data_sampler(trainA, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5)
    trainB_loader = data.DataLoader(trainB, batch_size=args.batch,
            sampler=data_sampler(trainB, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5)
    testA_loader = data.DataLoader(testA, batch_size=1, shuffle=False)
    testB_loader = data.DataLoader(testB, batch_size=1, shuffle=False)

    train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device)