File size: 9,728 Bytes
f884940 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.autograd import Variable
from .base_model import BaseModel
from . import networks
class SpecificNorm(nn.Module):
    """Normalize an RGB image batch channel-wise with fixed mean/std constants.

    ``forward`` returns ``(x - mean) / std`` using the per-channel values
    mean = (0.485, 0.456, 0.406) and std = (0.229, 0.224, 0.225).
    """

    def __init__(self, epsilon=1e-8):
        """
        @notice: avoid in-place ops.
        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3

        Args:
            epsilon: unused; kept only so existing call sites keep working.
        """
        super(SpecificNorm, self).__init__()
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)
        # Buffers (unlike plain attributes) follow module.to()/.cuda(); they are
        # non-persistent so the module's state_dict stays empty as before.
        # Creating them on the CPU means construction no longer needs a GPU.
        self.register_buffer('mean', mean, persistent=False)
        self.register_buffer('std', std, persistent=False)

    def forward(self, x):
        """Return ``(x - mean) / std`` for a batch ``x``.

        Assumes ``x`` is (N, 3, H, W) — the (1, 3, 1, 1) constants broadcast
        over batch and spatial dims. The constants are moved to ``x``'s device
        so the module works whether or not it was moved explicitly.
        """
        mean = self.mean.to(x.device)
        std = self.std.to(x.device)
        return (x - mean) / std
class fsModel(BaseModel):
    """Face-swapping model: generator, identity network and two discriminators.

    ``netG`` renders ``img_att`` conditioned on the identity latent
    ``latent_id``. ``netArc`` is an identity network loaded whole from
    ``opt.Arc_path``, kept in eval mode and not given to any optimizer. In
    training mode two discriminators are built: ``netD1`` on full-resolution
    images and ``netD2`` on a stride-2 average-pooled copy.
    """

    def name(self):
        return 'fsModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Build a callable that drops disabled loss terms.

        The returned function takes eight values in the fixed order
        (G_GAN, G_GAN_Feat, G_VGG, G_ID, G_Rec, <slot 6>, D_real, D_fake) and
        returns them as a list, omitting G_GAN_Feat unless ``use_gan_feat_loss``
        and G_VGG unless ``use_vgg_loss``. It is applied both to loss values and
        to their display names so the two lists stay aligned.
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, g_id, g_rec, g_mask, d_real, d_fake):
            values = (g_gan, g_gan_feat, g_vgg, g_id, g_rec, g_mask, d_real, d_fake)
            return [value for value, keep in zip(values, flags) if keep]

        return loss_filter

    def initialize(self, opt):
        """Build networks, losses and optimizers from the options object ``opt``.

        Raises:
            ValueError: if ``opt.crop_size`` is neither 224 nor 512.
        """
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain

        device = torch.device("cuda:0")

        if opt.crop_size == 224:
            from .fs_networks import Generator_Adain_Upsample, Discriminator
        elif opt.crop_size == 512:
            from .fs_networks_512 import Generator_Adain_Upsample, Discriminator
        else:
            # Previously fell through to an UnboundLocalError on the next line.
            raise ValueError(
                'Unsupported crop_size %r: expected 224 or 512' % (opt.crop_size,))

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=False)
        self.netG.to(device)

        # Id network. NOTE(security): torch.load unpickles an entire module
        # object here, which can execute arbitrary code — only load trusted
        # checkpoints via opt.Arc_path.
        netArc_checkpoint = torch.load(opt.Arc_path, map_location=torch.device("cpu"))
        self.netArc = netArc_checkpoint.to(device)
        self.netArc.eval()

        if not self.isTrain:
            # Inference needs only the generator weights (empty pretrained path).
            self.load_network(self.netG, 'G', opt.which_epoch, '')
            return

        # ---- Everything below runs in training mode only. ----

        # Discriminator networks (sigmoid output only for the 'original' GAN loss)
        use_sigmoid = opt.gan_mode == 'original'
        self.netD1 = Discriminator(input_nc=3, use_sigmoid=use_sigmoid)
        self.netD2 = Discriminator(input_nc=3, use_sigmoid=use_sigmoid)
        self.netD1.to(device)
        self.netD2.to(device)

        # Normalization applied to generator output before netArc.
        self.spNorm = SpecificNorm().to(device)
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

        # Load networks to resume or fine-tune.
        if opt.continue_train or opt.load_pretrain:
            pretrained_path = opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(self.netD1, 'D1', opt.which_epoch, pretrained_path)
            self.load_network(self.netD2, 'D2', opt.which_epoch, pretrained_path)

        # Loss functions
        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
        self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.Tensor, opt=self.opt)
        self.criterionFeat = nn.L1Loss()
        self.criterionRec = nn.L1Loss()
        # Names so we can break out the losses (same order as forward's output).
        self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_ID', 'G_Rec', 'D_GP',
                                           'D_real', 'D_fake')

        # Optimizers: G over the generator, D over both discriminators.
        params = list(self.netG.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
        params = list(self.netD1.parameters()) + list(self.netD2.parameters())
        self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        # Current learning rate; update_learning_rate() reads and decays it.
        # It was never initialized before, so the first decay step crashed.
        self.old_lr = opt.lr

    def _gradinet_penalty_D(self, netD, img_att, img_fake):
        """Return the mean squared deviation of the gradient norm from 1.

        Gradients of ``netD``'s last output are taken w.r.t. a random
        per-sample interpolation between ``img_att`` and ``img_fake``.
        Currently unused by ``forward`` (which sets the D_GP term to 0).
        """
        bs = img_fake.shape[0]
        alpha = torch.rand(bs, 1, 1, 1, device=img_fake.device).expand_as(img_fake)
        # Detached leaf with requires_grad, as the former Variable(...) produced.
        interpolated = (alpha * img_att + (1 - alpha) * img_fake).detach().requires_grad_(True)
        pred_interpolated = netD.forward(interpolated)[-1]

        # create_graph so the penalty itself can be back-propagated.
        grad = torch.autograd.grad(outputs=pred_interpolated,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones_like(pred_interpolated),
                                   retain_graph=True,
                                   create_graph=True)[0]

        grad = grad.reshape(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        return torch.mean((grad_l2norm - 1) ** 2)

    def cosin_metric(self, x1, x2):
        """Row-wise cosine similarity, reducing over dim 1.

        Assumes ``x1``/``x2`` are (N, D) embedding batches; returns shape (N,).
        The denominator is clamped so an all-zero embedding gives 0, not NaN.
        """
        denom = (torch.norm(x1, dim=1) * torch.norm(x2, dim=1)).clamp_min(1e-8)
        return torch.sum(x1 * x2, dim=1) / denom

    def forward(self, img_id, img_att, latent_id, latent_att, for_G=False):
        """Generate the swapped image and, in training mode, all losses.

        Args:
            img_id: unused here.
            img_att: attribute/target image batch fed to the generator and used
                as the "real" sample for the discriminators.
            latent_id: identity latent passed to ``netG`` and compared against
                ``netArc``'s embedding of the fake image.
            latent_att: unused here.
            for_G: unused here.

        Returns:
            Inference: the generated image tensor.
            Training: ``[losses, img_fake]`` where ``losses`` is the
            ``loss_filter`` output. G_ID is per-sample (not reduced here) —
            presumably the caller averages it; confirm.
        """
        img_fake = self.netG.forward(img_att, latent_id)
        if not self.isTrain:
            return img_fake

        loss_G_VGG = 0  # no VGG loss is computed in this model
        loss_D_GP = 0   # gradient penalty disabled (see _gradinet_penalty_D)

        img_fake_downsample = self.downsample(img_fake)
        img_att_downsample = self.downsample(img_att)

        # D_fake: detached so the discriminator loss does not reach netG.
        fea1_fake = self.netD1.forward(img_fake.detach())
        fea2_fake = self.netD2.forward(img_fake_downsample.detach())
        loss_D_fake = self.criterionGAN([fea1_fake, fea2_fake], False, for_discriminator=True)

        # D_real
        fea_real = [self.netD1.forward(img_att), self.netD2.forward(img_att_downsample)]
        loss_D_real = self.criterionGAN(fea_real, True, for_discriminator=True)

        # G_GAN: non-detached pass so gradients flow into netG.
        fea_fake = [self.netD1.forward(img_fake), self.netD2.forward(img_fake_downsample)]
        loss_G_GAN = self.criterionGAN(fea_fake, True, for_discriminator=False)

        # GAN feature matching loss over every discriminator feature except the last.
        loss_G_GAN_Feat = 0
        n_layers_D = 4
        num_D = 2
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (n_layers_D + 1)
            D_weights = 1.0 / num_D
            for i in range(num_D):
                for j in range(len(fea_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(fea_fake[i][j],
                                           fea_real[i][j].detach()) * self.opt.lambda_feat

        # G_ID: identity embedding of the fake image vs. the target identity.
        img_fake_down = F.interpolate(img_fake, size=(112, 112))
        img_fake_down = self.spNorm(img_fake_down)
        latent_fake = self.netArc(img_fake_down)
        loss_G_ID = 1 - self.cosin_metric(latent_fake, latent_id)

        # G_Rec
        loss_G_Rec = self.criterionRec(img_fake, img_att) * self.opt.lambda_rec

        losses = self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_ID,
                                  loss_G_Rec, loss_D_GP, loss_D_real, loss_D_fake)
        return [losses, img_fake]

    def save(self, which_epoch):
        """Save generator and both discriminators under the given epoch label."""
        for net, label in ((self.netG, 'G'), (self.netD1, 'D1'), (self.netD2, 'D2')):
            self.save_network(net, label, which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        """Rebuild optimizer_G so it also covers the global generator."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        # gen_features/netE are never set in this class; guard instead of
        # raising AttributeError.
        if getattr(self, 'gen_features', False):
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Decrease the LR of both optimizers by ``opt.lr / opt.niter_decay``."""
        old_lr = getattr(self, 'old_lr', self.opt.lr)
        lrd = self.opt.lr / self.opt.niter_decay
        lr = old_lr - lrd
        for optimizer in (self.optimizer_D, self.optimizer_G):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (old_lr, lr))
        self.old_lr = lr
|