# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
from torch.nn.utils import spectral_norm

# from util.util import SwitchNorm2d
import torch.nn.functional as F

###############################################################################
# Functions
###############################################################################


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type="instance"):
    if norm_type == "batch":
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == "instance":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == "spectral":
        # spectral_norm wraps a module; return the function itself instead of
        # calling it with no arguments (which would raise a TypeError).
        norm_layer = spectral_norm
    elif norm_type == "SwitchNorm":
        # Requires the commented-out `from util.util import SwitchNorm2d` import above.
        norm_layer = SwitchNorm2d
    else:
        raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
    return norm_layer


def print_network(net):
    if isinstance(net, list):
        net = net[0]
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print("Total number of parameters: %d" % num_params)


def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9,
             n_local_enhancers=1, n_blocks_local=3, norm='instance', gpu_ids=[], opt=None):
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        # if opt.self_gen:
        if opt.use_v2:
            netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global,
                                          norm_layer, opt=opt)
        else:
            # GlobalGenerator_v2 is not defined in this file; it must be provided
            # elsewhere for the use_v2=False path.
            netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global,
                                      n_blocks_global, norm_layer, opt=opt)
    else:
        raise NotImplementedError('generator not implemented!')
    print(netG)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG


def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1,
             getIntermFeat=False, gpu_ids=[]):
    norm_layer = get_norm_layer(norm_type=norm)
    netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid,
                                   num_D, getIntermFeat)
    print(netD)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    return netD
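
# Minimal usage sketch (illustrative, not part of the training code):
# get_norm_layer returns a factory that the constructors below later call with
# a channel count.
def _demo_get_norm_layer():
    norm_layer = get_norm_layer("instance")  # functools.partial(nn.InstanceNorm2d, affine=False)
    layer = norm_layer(64)                   # instantiate for 64 feature channels
    out = layer(torch.randn(1, 64, 32, 32))
    print(out.shape)                         # torch.Size([1, 64, 32, 32])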

class GlobalGenerator_DCDCv2(nn.Module):
    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        k_size=3,
        n_downsampling=8,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        opt=None,
    ):
        super(GlobalGenerator_DCDCv2, self).__init__()
        activation = nn.ReLU(True)

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0),
            norm_layer(min(ngf, opt.mc)),  # match the conv output, which is capped at opt.mc
            activation,
        ]
        ### downsample
        for i in range(opt.start_r):
            mult = 2 ** i
            model += [
                nn.Conv2d(
                    min(ngf * mult, opt.mc),
                    min(ngf * mult * 2, opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                ),
                norm_layer(min(ngf * mult * 2, opt.mc)),
                activation,
            ]
        for i in range(opt.start_r, n_downsampling - 1):
            mult = 2 ** i
            model += [
                nn.Conv2d(
                    min(ngf * mult, opt.mc),
                    min(ngf * mult * 2, opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                ),
                norm_layer(min(ngf * mult * 2, opt.mc)),
                activation,
            ]
            model += [
                ResnetBlock(
                    min(ngf * mult * 2, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
            model += [
                ResnetBlock(
                    min(ngf * mult * 2, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]

        mult = 2 ** (n_downsampling - 1)

        if opt.spatio_size == 32:
            model += [
                nn.Conv2d(
                    min(ngf * mult, opt.mc),
                    min(ngf * mult * 2, opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                ),
                norm_layer(min(ngf * mult * 2, opt.mc)),
                activation,
            ]
        if opt.spatio_size == 64:
            model += [
                ResnetBlock(
                    min(ngf * mult * 2, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
        model += [
            ResnetBlock(
                min(ngf * mult * 2, opt.mc),
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
            )
        ]
        # model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)]
        if opt.feat_dim > 0:
            model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)]
        self.encoder = nn.Sequential(*model)

        # decode
        model = []
        if opt.feat_dim > 0:
            model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)]
        # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)]
        o_pad = 0 if k_size == 4 else 1
        mult = 2 ** n_downsampling
        model += [
            ResnetBlock(
                min(ngf * mult, opt.mc),
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
            )
        ]

        if opt.spatio_size == 32:
            model += [
                nn.ConvTranspose2d(
                    min(ngf * mult, opt.mc),
                    min(int(ngf * mult / 2), opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(min(int(ngf * mult / 2), opt.mc)),
                activation,
            ]
        if opt.spatio_size == 64:
            model += [
                ResnetBlock(
                    min(ngf * mult, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]

        for i in range(1, n_downsampling - opt.start_r):
            mult = 2 ** (n_downsampling - i)
            model += [
                ResnetBlock(
                    min(ngf * mult, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
            model += [
                ResnetBlock(
                    min(ngf * mult, opt.mc),
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                )
            ]
            model += [
                nn.ConvTranspose2d(
                    min(ngf * mult, opt.mc),
                    min(int(ngf * mult / 2), opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(min(int(ngf * mult / 2), opt.mc)),
                activation,
            ]
        for i in range(n_downsampling - opt.start_r, n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    min(ngf * mult, opt.mc),
                    min(int(ngf * mult / 2), opt.mc),
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(min(int(ngf * mult / 2), opt.mc)),
                activation,
            ]
        if opt.use_segmentation_model:
            model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0)]
        else:
            model += [
                nn.ReflectionPad2d(3),
                nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0),
                nn.Tanh(),
            ]
        self.decoder = nn.Sequential(*model)

    def forward(self, input, flow="enc_dec"):
        if flow == "enc":
            return self.encoder(input)
        elif flow == "dec":
            return self.decoder(input)
        elif flow == "enc_dec":
            x = self.encoder(input)
            x = self.decoder(x)
            return x
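
# Minimal usage sketch, assuming an `opt` namespace that carries the fields this
# generator reads (use_v2, mc, start_r, spatio_size, feat_dim,
# use_segmentation_model). The concrete values below are illustrative
# assumptions, not the defaults shipped with the original options.
def _demo_global_generator():
    import argparse

    opt = argparse.Namespace(
        use_v2=True,                   # route define_G to GlobalGenerator_DCDCv2
        mc=64,                         # channel cap applied through min(..., opt.mc)
        start_r=1,                     # downsampling layers before ResnetBlocks kick in
        spatio_size=64,                # bottleneck branch selector (assumed)
        feat_dim=-1,                   # <= 0 disables the 1x1 feature projection
        use_segmentation_model=False,  # keep the Tanh output head
    )
    netG = define_G(3, 3, 64, "global", n_downsample_global=3, opt=opt)
    with torch.no_grad():
        out = netG(torch.randn(1, 3, 256, 256))  # flow="enc_dec" by default
    print(out.shape)  # torch.Size([1, 3, 256, 256])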

# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(
        self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1
    ):
        super(ResnetBlock, self).__init__()
        self.opt = opt
        self.dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(self.dilation)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(self.dilation)]
        elif padding_type == "zero":
            p = self.dilation
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=self.dilation),
            norm_layer(dim),
            activation,
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=1), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class Encoder(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True),
            ]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1
                ),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True),
            ]

        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        outputs = self.model(input)

        # instance-wise average pooling
        outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b : b + 1] == int(i)).nonzero()  # n x 4
                for j in range(self.output_nc):
                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[
                        indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]
                    ] = mean_feat
        return outputs_mean


def SN(module, mode=True):
    if mode:
        return torch.nn.utils.spectral_norm(module)
    return module
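
# Minimal sketch of the residual block above: the output keeps the input's
# shape (out = x + conv_block(x)), so blocks can be chained freely inside the
# generators.
def _demo_resnet_block():
    block = ResnetBlock(
        64, padding_type="reflect", norm_layer=get_norm_layer("instance"), opt=None
    )
    x = torch.randn(1, 64, 32, 32)
    print(block(x).shape)  # torch.Size([1, 64, 32, 32])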

class NonLocalBlock2D_with_mask_Res(nn.Module):
    def __init__(
        self,
        in_channels,
        inter_channels,
        mode="add",
        re_norm=False,
        temperature=1.0,
        use_self=False,
        cosin=False,
    ):
        super(NonLocalBlock2D_with_mask_Res, self).__init__()

        self.cosin = cosin
        self.renorm = re_norm
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.W = nn.Conv2d(
            in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0
        )
        # for pytorch 0.3.1
        # nn.init.constant(self.W.weight, 0)
        # nn.init.constant(self.W.bias, 0)
        # for pytorch 0.4.0
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.phi = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.mode = mode
        self.temperature = temperature
        self.use_self = use_self

        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(3):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def forward(self, x, mask):  ## The shape of mask is Batch*1*H*W
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        if self.cosin:
            theta_x = F.normalize(theta_x, dim=2)
            phi_x = F.normalize(phi_x, dim=1)

        f = torch.matmul(theta_x, phi_x)
        f /= self.temperature
        f_div_C = F.softmax(f, dim=2)

        tmp = 1 - mask
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        mask = 1 - mask
        tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
        mask *= tmp

        mask_expand = mask.view(batch_size, 1, -1)
        mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1)

        # mask = 1 - mask
        # mask=F.interpolate(mask,(x.size(2),x.size(3)))
        # mask_expand=mask.view(batch_size,1,-1)
        # mask_expand=mask_expand.repeat(1,x.size(2)*x.size(3),1)

        if self.use_self:
            mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0

        f_div_C = mask_expand * f_div_C
        if self.renorm:
            f_div_C = F.normalize(f_div_C, p=1, dim=2)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        W_y = self.res_block(W_y)

        if self.mode == "combine":
            full_mask = mask.repeat(1, self.inter_channels, 1, 1)
            z = full_mask * x + (1 - full_mask) * W_y
            return z
        # Only "combine" is exercised by this file; any other mode used to fall
        # through with `z` undefined, so fail loudly instead.
        raise NotImplementedError("NonLocalBlock2D_with_mask_Res mode [%s] is not implemented" % self.mode)


class MultiscaleDiscriminator(nn.Module):
    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:
                for j in range(n_layers + 2):
                    setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j))
                         for j in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result
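
# Minimal usage sketch, assuming `opt.use_SN` is the only option the
# discriminator stack reads. With getIntermFeat=True each scale also returns
# its intermediate feature maps (n_layers + 2 of them) for feature-matching
# losses; the shapes and values below are illustrative assumptions.
def _demo_define_D():
    import argparse

    opt = argparse.Namespace(use_SN=True)  # toggle spectral norm in NLayerDiscriminator
    netD = define_D(3, 64, n_layers_D=3, opt=opt, num_D=2, getIntermFeat=True)
    preds = netD(torch.randn(1, 3, 256, 256))
    print(len(preds), len(preds[0]))  # 2 scales, 5 feature maps per scale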

# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), opt.use_SN),
                     nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), opt.use_SN),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), opt.use_SN),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw), opt.use_SN)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers + 2):
                model = getattr(self, 'model' + str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)


class Patch_Attention_4(nn.Module):  ## While combining the feature maps, use conv and mask
    def __init__(self, in_channels, inter_channels, patch_size):
        super(Patch_Attention_4, self).__init__()

        self.patch_size = patch_size

        # The g/W/theta/phi 1x1 convolutions of NonLocalBlock2D_with_mask_Res are
        # intentionally absent here: the hard attention below copies whole
        # patches instead of computing soft embedded-Gaussian responses.

        # 1025 = in_channels (z) + in_channels (composed feature) + 1 (mask),
        # assuming in_channels == 512 as in the original configuration.
        self.F_Combine = nn.Conv2d(in_channels=1025, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(1):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def Hard_Compose(self, input, dim, index):
        # batch index select
        # input: [B,C,HW]
        # dim: scalar > 0
        # index: [B, HW]
        views = [input.size(0)] + [1 if i != dim else -1 for i in range(1, len(input.size()))]
        expanse = list(input.size())
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    def forward(self, z, mask):  ## The shape of mask is Batch*1*H*W
        x = self.res_block(z)

        b, c, h, w = x.shape

        ## mask resize + dilation
        # tmp = 1 - mask
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        # mask = 1 - mask
        # tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
        # mask *= tmp
        # mask=1-mask
        ## 1: mask position 0: non-mask

        mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()
        all_patch_num = h * w / self.patch_size / self.patch_size
        non_mask_region = non_mask_region.repeat(1, int(all_patch_num), 1)

        x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        y_unfold = x_unfold.permute(0, 2, 1)
        x_unfold_normalized = F.normalize(x_unfold, dim=1)
        y_unfold_normalized = F.normalize(y_unfold, dim=2)
        correlation_matrix = torch.bmm(y_unfold_normalized, x_unfold_normalized)
        # Mostly-masked patches (mean mask value > 0.6) may not serve as sources.
        correlation_matrix = correlation_matrix.masked_fill(non_mask_region == 1., -1e9)
        correlation_matrix = F.softmax(correlation_matrix, dim=2)

        R, max_arg = torch.max(correlation_matrix, dim=2)

        composed_unfold = self.Hard_Compose(x_unfold, 2, max_arg)
        composed_fold = F.fold(composed_unfold, output_size=(h, w),
                               kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)

        return concat_1

    def inference_forward(self, z, mask):  ## Reduce the extra memory cost
        x = self.res_block(z)

        b, c, h, w = x.shape

        ## mask resize + dilation
        # tmp = 1 - mask
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        # mask = 1 - mask
        # tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
        # mask *= tmp
        # mask=1-mask
        ## 1: mask position 0: non-mask

        mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()[0, 0, :]  # 1*1*all_patch_num

        all_patch_num = h * w / self.patch_size / self.patch_size

        # indices of mostly-masked patches: the queries that need filling
        mask_index = torch.nonzero(non_mask_region, as_tuple=True)[0]
        if len(mask_index) == 0:  ## No mask patch is selected, no attention is needed
            composed_fold = x
        else:
            unmask_index = torch.nonzero(non_mask_region != 1, as_tuple=True)[0]

            x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)

            Query_Patch = torch.index_select(x_unfold, 2, mask_index)
            Key_Patch = torch.index_select(x_unfold, 2, unmask_index)

            Query_Patch = Query_Patch.permute(0, 2, 1)
            Query_Patch_normalized = F.normalize(Query_Patch, dim=2)
            Key_Patch_normalized = F.normalize(Key_Patch, dim=1)

            correlation_matrix = torch.bmm(Query_Patch_normalized, Key_Patch_normalized)
            correlation_matrix = F.softmax(correlation_matrix, dim=2)

            R, max_arg = torch.max(correlation_matrix, dim=2)

            composed_unfold = self.Hard_Compose(Key_Patch, 2, max_arg)
            x_unfold[:, :, mask_index] = composed_unfold
            composed_fold = F.fold(x_unfold, output_size=(h, w),
                                   kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)

        return concat_1
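
# Minimal sketch of the hard patch attention above. in_channels/inter_channels
# of 512 match the hard-coded 1025-channel F_Combine input (512 + 512 + 1); the
# spatial sizes and patch_size below are illustrative assumptions.
def _demo_patch_attention():
    pa = Patch_Attention_4(in_channels=512, inter_channels=512, patch_size=4)
    z = torch.randn(1, 512, 32, 32)
    mask = torch.zeros(1, 1, 256, 256)
    mask[:, :, 64:128, 64:128] = 1.0  # 1 marks the masked (hole) region
    with torch.no_grad():
        out = pa(z, mask)
    print(out.shape)  # torch.Size([1, 512, 32, 32])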

##############################################################################
# Losses
##############################################################################


class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)
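
# Minimal usage sketch: GANLoss expects discriminator outputs wrapped in lists.
# A list of lists (as MultiscaleDiscriminator returns) sums the loss over
# scales; a flat list uses only its last element, as below.
def _demo_gan_loss():
    criterion = GANLoss(use_lsgan=True)
    pred = [torch.rand(1, 1, 30, 30)]  # single-scale PatchGAN prediction map
    print(criterion(pred, target_is_real=True).item())   # MSE against an all-ones target
    print(criterion(pred, target_is_real=False).item())  # MSE against an all-zeros target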

####################################### VGG Loss

from torchvision import models


class VGG19_torch(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG19_torch, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGLoss_torch(nn.Module):
    def __init__(self, gpu_ids):
        super(VGGLoss_torch, self).__init__()
        self.vgg = VGG19_torch().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
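
# Minimal usage sketch. VGGLoss_torch unconditionally moves VGG19 to CUDA and
# downloads the torchvision ImageNet weights on first use, so this demo bails
# out on CPU-only machines; the input sizes are illustrative assumptions.
def _demo_vgg_loss():
    if not torch.cuda.is_available():
        return
    criterion = VGGLoss_torch(gpu_ids=[0])  # gpu_ids is accepted but unused internally
    x = torch.rand(1, 3, 224, 224).cuda()
    y = torch.rand(1, 3, 224, 224).cuda()
    print(criterion(x, y).item())  # weighted L1 over five relu feature scales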