|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import os |
|
from torch.autograd import Variable |
|
from .base_model import BaseModel |
|
from . import networks |
|
|
|
class SpecificNorm(nn.Module): |
|
def __init__(self, epsilon=1e-8): |
|
""" |
|
@notice: avoid in-place ops. |
|
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 |
|
""" |
|
super(SpecificNorm, self).__init__() |
|
self.mean = np.array([0.485, 0.456, 0.406]) |
|
self.mean = torch.from_numpy(self.mean).float().cuda() |
|
self.mean = self.mean.view([1, 3, 1, 1]) |
|
|
|
self.std = np.array([0.229, 0.224, 0.225]) |
|
self.std = torch.from_numpy(self.std).float().cuda() |
|
self.std = self.std.view([1, 3, 1, 1]) |
|
|
|
def forward(self, x): |
|
mean = self.mean.expand([1, 3, x.shape[2], x.shape[3]]) |
|
std = self.std.expand([1, 3, x.shape[2], x.shape[3]]) |
|
|
|
x = (x - mean) / std |
|
|
|
return x |
|
|
|
class fsModel(BaseModel): |
|
def name(self): |
|
return 'fsModel' |
|
|
|
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss): |
|
flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True) |
|
|
|
def loss_filter(g_gan, g_gan_feat, g_vgg, g_id, g_rec, g_mask, d_real, d_fake): |
|
return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_id, g_rec, g_mask, d_real, d_fake), flags) if f] |
|
|
|
return loss_filter |
|
|
|
def initialize(self, opt): |
|
BaseModel.initialize(self, opt) |
|
if opt.resize_or_crop != 'none' or not opt.isTrain: |
|
torch.backends.cudnn.benchmark = True |
|
self.isTrain = opt.isTrain |
|
|
|
device = torch.device("cuda:0") |
|
|
|
if opt.crop_size == 224: |
|
from .fs_networks import Generator_Adain_Upsample, Discriminator |
|
elif opt.crop_size == 512: |
|
from .fs_networks_512 import Generator_Adain_Upsample, Discriminator |
|
|
|
|
|
self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False) |
|
self.netG.to(device) |
|
|
|
|
|
netArc_checkpoint = opt.Arc_path |
|
netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu")) |
|
self.netArc = netArc_checkpoint |
|
self.netArc = self.netArc.to(device) |
|
self.netArc.eval() |
|
|
|
if not self.isTrain: |
|
pretrained_path = '' if not self.isTrain else opt.load_pretrain |
|
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) |
|
return |
|
|
|
|
|
if opt.gan_mode == 'original': |
|
use_sigmoid = True |
|
else: |
|
use_sigmoid = False |
|
self.netD1 = Discriminator(input_nc=3, use_sigmoid=use_sigmoid) |
|
self.netD2 = Discriminator(input_nc=3, use_sigmoid=use_sigmoid) |
|
self.netD1.to(device) |
|
self.netD2.to(device) |
|
|
|
|
|
self.spNorm =SpecificNorm() |
|
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) |
|
|
|
|
|
if opt.continue_train or opt.load_pretrain: |
|
pretrained_path = '' if not self.isTrain else opt.load_pretrain |
|
|
|
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) |
|
self.load_network(self.netD1, 'D1', opt.which_epoch, pretrained_path) |
|
self.load_network(self.netD2, 'D2', opt.which_epoch, pretrained_path) |
|
|
|
|
|
|
|
if self.isTrain: |
|
|
|
self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) |
|
|
|
self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.Tensor, opt=self.opt) |
|
self.criterionFeat = nn.L1Loss() |
|
self.criterionRec = nn.L1Loss() |
|
|
|
|
|
self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_ID', 'G_Rec', 'D_GP', |
|
'D_real', 'D_fake') |
|
|
|
|
|
|
|
|
|
params = list(self.netG.parameters()) |
|
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) |
|
|
|
|
|
params = list(self.netD1.parameters()) + list(self.netD2.parameters()) |
|
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) |
|
|
|
def _gradinet_penalty_D(self, netD, img_att, img_fake): |
|
|
|
bs = img_fake.shape[0] |
|
alpha = torch.rand(bs, 1, 1, 1).expand_as(img_fake).cuda() |
|
interpolated = Variable(alpha * img_att + (1 - alpha) * img_fake, requires_grad=True) |
|
pred_interpolated = netD.forward(interpolated) |
|
pred_interpolated = pred_interpolated[-1] |
|
|
|
|
|
grad = torch.autograd.grad(outputs=pred_interpolated, |
|
inputs=interpolated, |
|
grad_outputs=torch.ones(pred_interpolated.size()).cuda(), |
|
retain_graph=True, |
|
create_graph=True, |
|
only_inputs=True)[0] |
|
|
|
|
|
grad = grad.view(grad.size(0), -1) |
|
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) |
|
loss_d_gp = torch.mean((grad_l2norm - 1) ** 2) |
|
|
|
return loss_d_gp |
|
|
|
def cosin_metric(self, x1, x2): |
|
|
|
return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1)) |
|
|
|
def forward(self, img_id, img_att, latent_id, latent_att, for_G=False): |
|
loss_D_fake, loss_D_real, loss_D_GP = 0, 0, 0 |
|
loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_ID, loss_G_Rec = 0,0,0,0,0 |
|
|
|
img_fake = self.netG.forward(img_att, latent_id) |
|
if not self.isTrain: |
|
return img_fake |
|
img_fake_downsample = self.downsample(img_fake) |
|
img_att_downsample = self.downsample(img_att) |
|
|
|
|
|
|
|
|
|
fea1_fake = self.netD1.forward(img_fake.detach()) |
|
fea2_fake = self.netD2.forward(img_fake_downsample.detach()) |
|
pred_fake = [fea1_fake, fea2_fake] |
|
loss_D_fake = self.criterionGAN(pred_fake, False, for_discriminator=True) |
|
|
|
|
|
|
|
fea1_real = self.netD1.forward(img_att) |
|
fea2_real = self.netD2.forward(img_att_downsample) |
|
pred_real = [fea1_real, fea2_real] |
|
fea_real = [fea1_real, fea2_real] |
|
loss_D_real = self.criterionGAN(pred_real, True, for_discriminator=True) |
|
|
|
|
|
|
|
|
|
loss_D_GP = 0 |
|
|
|
|
|
fea1_fake = self.netD1.forward(img_fake) |
|
fea2_fake = self.netD2.forward(img_fake_downsample) |
|
|
|
pred_fake = [fea1_fake, fea2_fake] |
|
fea_fake = [fea1_fake, fea2_fake] |
|
loss_G_GAN = self.criterionGAN(pred_fake, True, for_discriminator=False) |
|
|
|
|
|
n_layers_D = 4 |
|
num_D = 2 |
|
if not self.opt.no_ganFeat_loss: |
|
feat_weights = 4.0 / (n_layers_D + 1) |
|
D_weights = 1.0 / num_D |
|
for i in range(num_D): |
|
for j in range(0, len(fea_fake[i]) - 1): |
|
loss_G_GAN_Feat += D_weights * feat_weights * \ |
|
self.criterionFeat(fea_fake[i][j], |
|
fea_real[i][j].detach()) * self.opt.lambda_feat |
|
|
|
|
|
|
|
img_fake_down = F.interpolate(img_fake, size=(112,112)) |
|
img_fake_down = self.spNorm(img_fake_down) |
|
latent_fake = self.netArc(img_fake_down) |
|
loss_G_ID = (1 - self.cosin_metric(latent_fake, latent_id)) |
|
|
|
|
|
|
|
|
|
loss_G_Rec = self.criterionRec(img_fake, img_att) * self.opt.lambda_rec |
|
|
|
|
|
return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_ID, loss_G_Rec, loss_D_GP, loss_D_real, loss_D_fake), |
|
img_fake] |
|
|
|
|
|
def save(self, which_epoch): |
|
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) |
|
self.save_network(self.netD1, 'D1', which_epoch, self.gpu_ids) |
|
self.save_network(self.netD2, 'D2', which_epoch, self.gpu_ids) |
|
'''if self.gen_features: |
|
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)''' |
|
|
|
def update_fixed_params(self): |
|
|
|
params = list(self.netG.parameters()) |
|
if self.gen_features: |
|
params += list(self.netE.parameters()) |
|
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) |
|
if self.opt.verbose: |
|
print('------------ Now also finetuning global generator -----------') |
|
|
|
def update_learning_rate(self): |
|
lrd = self.opt.lr / self.opt.niter_decay |
|
lr = self.old_lr - lrd |
|
for param_group in self.optimizer_D.param_groups: |
|
param_group['lr'] = lr |
|
for param_group in self.optimizer_G.param_groups: |
|
param_group['lr'] = lr |
|
if self.opt.verbose: |
|
print('update learning rate: %f -> %f' % (self.old_lr, lr)) |
|
self.old_lr = lr |
|
|
|
|
|
|