import functools
import itertools
import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from grid_sample import grid_sample
from tps_grid_gen import TPSGridGen


###############################################################################
# Functions
###############################################################################
def weights_init(m):
    """Initialise one module in place; meant to be passed to ``net.apply``.

    Modules whose class name contains ``Conv2d`` get weights drawn from
    N(0, 0.02); modules whose class name contains ``BatchNorm2d`` get weights
    drawn from N(1, 0.02) and a zero bias. Everything else is left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def get_norm_layer(norm_type='instance'):
    """Return a normalisation-layer factory for ``'batch'`` or ``'instance'``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def _place_and_init(net, gpu_ids, init=True):
    """Move ``net`` to ``gpu_ids[0]`` when GPUs are requested, then initialise it.

    With an empty ``gpu_ids`` the network stays on the CPU. ``init=False``
    skips the ``weights_init`` pass.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.cuda(gpu_ids[0])
    if init:
        net.apply(weights_init)
    return net


def define_G(input_nc, output_nc, ngf, netG, L=1, S=1, n_downsample_global=3, n_blocks_global=9,
             n_local_enhancers=1, n_blocks_local=3, norm='instance', gpu_ids=[]):
    """Build, print, place and initialise the generator selected by ``netG``.

    ``netG`` must be ``'global'`` (``GlobalGenerator``) or ``'local'``
    (``LocalEnhancer``, which is defined outside this part of the file).

    Raises:
        NotImplementedError: for any other ``netG`` value. The original code
            raised a bare string, which Python 3 rejects with ``TypeError``.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(input_nc, output_nc, L, S, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                             n_local_enhancers, n_blocks_local, norm_layer)
    else:
        raise NotImplementedError('generator not implemented!')
    print(netG)
    return _place_and_init(netG, gpu_ids)


def define_Unet(input_nc, gpu_ids=[]):
    """Build an initialised ``Unet``; stays on the CPU when ``gpu_ids`` is empty."""
    return _place_and_init(Unet(input_nc), gpu_ids)


def define_UnetMask(input_nc, gpu_ids=[]):
    """Build an initialised ``UnetMask`` with four output channels."""
    return _place_and_init(UnetMask(input_nc, output_nc=4), gpu_ids)


def define_Refine(input_nc, output_nc, gpu_ids=[]):
    """Build an initialised ``Refine`` network."""
    return _place_and_init(Refine(input_nc, output_nc), gpu_ids)


####################################################
def define_Refine_ResUnet(input_nc, output_nc, gpu_ids=[]):
    """Build an initialised ``Refine_ResUnet_New`` network."""
    return _place_and_init(Refine_ResUnet_New(input_nc, output_nc), gpu_ids)


####################################################
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False,
             gpu_ids=[]):
    """Build, print, place and initialise a ``MultiscaleDiscriminator``."""
    norm_layer = get_norm_layer(norm_type=norm)
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
    print(netD)
    return _place_and_init(netD, gpu_ids)


def define_VAE(input_nc, gpu_ids=[]):
    """Build and print a ``VAE`` with fixed arguments ``(19, 32, 32, 1024)``.

    ``input_nc`` is accepted but not used, and no ``weights_init`` pass is run.
    """
    netVAE = VAE(19, 32, 32, 1024)
    print(netVAE)
    return _place_and_init(netVAE, gpu_ids, init=False)


def define_B(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=3, norm='instance', gpu_ids=[]):
    """Build, print, place and initialise a ``BlendGenerator``."""
    norm_layer = get_norm_layer(norm_type=norm)
    netB = BlendGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
    print(netB)
    return _place_and_init(netB, gpu_ids)


def define_partial_enc(input_nc, gpu_ids=[]):
    """Build, print, place and initialise a ``PartialConvEncoder``."""
    net = PartialConvEncoder(input_nc)
    print(net)
    return _place_and_init(net, gpu_ids)


def define_conv_enc(input_nc, gpu_ids=[]):
    """Build, print, place and initialise a ``ConvEncoder``."""
    net = ConvEncoder(input_nc)
    print(net)
    return _place_and_init(net, gpu_ids)


def define_AttG(output_nc, gpu_ids=[]):
    """Build, print, place and initialise an ``AttGenerator``."""
    net = AttGenerator(output_nc)
    print(net)
    return _place_and_init(net, gpu_ids)


def print_network(net):
    """Print ``net`` (or its first element if it is a list) and its parameter count."""
    if isinstance(net, list):
        net = net[0]
    num_params = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
    """Adversarial loss against constant real/fake target tensors.

    Uses ``nn.MSELoss`` when ``use_lsgan`` is true and ``nn.BCELoss``
    otherwise. Target tensors are created with the ``tensor`` constructor and
    cached until a prediction with a different element count arrives.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the cached constant target matching ``input``'s element count."""
        if target_is_real:
            cache_name, label = 'real_label_var', self.real_label
        else:
            cache_name, label = 'fake_label_var', self.fake_label
        cached = getattr(self, cache_name)
        # The cache is keyed on numel only, exactly as before.
        if cached is None or cached.numel() != input.numel():
            cached = Variable(self.Tensor(input.size()).fill_(label), requires_grad=False)
            setattr(self, cache_name, cached)
        return cached

    def __call__(self, input, target_is_real):
        """Compute the loss on the last element of each prediction list.

        ``input`` is either a list of lists (one per discriminator scale), in
        which case the per-scale losses are summed, or a single list.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        target_tensor = self.get_target_tensor(input[-1], target_is_real)
        return self.loss(input[-1], target_tensor)


class VGGLossWarp(nn.Module):
    """L1 distance between the last (index 4) ``Vgg19`` feature maps of x and y."""

    def __init__(self, gpu_ids):
        super(VGGLossWarp, self).__init__()
        # NOTE: always placed on CUDA; gpu_ids is accepted but not used.
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        loss += self.weights[4] * self.criterion(x_vgg[4], y_vgg[4].detach())
        return loss


class VGGLoss(nn.Module):
    """Weighted sum of L1 distances between ``Vgg19`` feature maps of x and y."""

    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        # NOTE: always placed on CUDA; gpu_ids is accepted but not used.
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Sum the weighted L1 distance over every returned feature level."""
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss

    def warp(self, x, y):
        """Same as ``forward`` but restricted to feature level 4."""
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        loss += self.weights[4] * self.criterion(x_vgg[4], y_vgg[4].detach())
        return loss


class StyleLoss(nn.Module):
    """Weighted RMS difference between Gram matrices of ``Vgg19`` features."""

    def __init__(self, gpu_ids):
        super(StyleLoss, self).__init__()
        # NOTE: always placed on CUDA; gpu_ids is accepted but not used.
        self.vgg = Vgg19().cuda()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            N, C, H, W = x_vgg[i].shape
            for n in range(N):
                phi_x = x_vgg[i][n].reshape(C, H * W)
                phi_y = y_vgg[i][n].reshape(C, H * W)
                # Gram matrices, normalised by the number of feature elements.
                G_x = torch.matmul(phi_x, phi_x.t()) / (C * H * W)
                G_y = torch.matmul(phi_y, phi_y.t()) / (C * H * W)
                loss += torch.sqrt(torch.mean((G_x - G_y) ** 2)) * self.weights[i]
        return loss


##############################################################################
# Generator
##############################################################################
class PartialConvEncoder(nn.Module):
    """7x7 partial-conv stem followed by four stride-2 partial convolutions.

    Channels double at every stride-2 stage, from ``ngf`` to ``ngf * 16``.
    """

    def __init__(self, input_nc, ngf=32, norm_layer=nn.BatchNorm2d):
        super(PartialConvEncoder, self).__init__()
        self.pad1 = nn.ReflectionPad2d(3)
        self.partial_conv1 = PartialConv(input_nc, ngf, kernel_size=7)
        self.norm_layer1 = norm_layer(ngf)
        self.activation = nn.ReLU(True)
        # down sample: registers down1..down4 and norm_layer2..norm_layer5
        for stage in range(4):
            mult = 2 ** stage
            setattr(self, 'down%d' % (stage + 1),
                    PartialConv(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1))
            setattr(self, 'norm_layer%d' % (stage + 2), norm_layer(ngf * mult * 2))

    def forward(self, input, mask):
        """Encode ``input`` under ``mask``; only the features are returned."""
        input = self.pad1(input)
        mask = self.pad1(mask)
        input, mask = self.partial_conv1(input, mask)
        input = self.activation(self.norm_layer1(input))
        for stage in range(1, 5):
            input, mask = getattr(self, 'down%d' % stage)(input, mask)
            input = self.activation(getattr(self, 'norm_layer%d' % (stage + 1))(input))
        return input


class ConvEncoder(nn.Module):
    """7x7 conv stem followed by ``n_downsampling`` stride-2 convolutions.

    ``n_blocks`` and ``padding_type`` are accepted but not used.
    """

    def __init__(self, input_nc, ngf=32, n_downsampling=4, n_blocks=4, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        super(ConvEncoder, self).__init__()
        activation = nn.ReLU(True)
        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


class AttGenerator(nn.Module):
    """Resnet bottleneck plus transposed-conv decoder with attention norms.

    The constructor indexes ``out_channels[0..3]`` and ``forward`` runs four
    upsampling stages, so ``n_downsampling`` must be at least 4.
    """

    def __init__(self, output_nc, ngf=32, n_blocks=4, n_downsampling=4, padding_type='reflect'):
        super(AttGenerator, self).__init__()
        mult = 2 ** n_downsampling
        model = []
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult * 2, norm_type='in', padding_type=padding_type)]
        self.model = nn.Sequential(*model)

        ## upsampling
        norm_layer = nn.BatchNorm2d
        activation = nn.ReLU(True)
        upsampling = []
        self.out_channels = []
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            out_nc = int(ngf * mult / 2) * 2
            upsampling += [nn.Sequential(
                nn.ConvTranspose2d(ngf * mult * 2, out_nc, kernel_size=3, stride=2, padding=1,
                                   output_padding=1),
                norm_layer(out_nc),
                activation)]
            self.out_channels += [out_nc]
        self.upsampling = nn.Sequential(*upsampling)

        # AttNorm[i] is applied right after upsampling[i], so it must be built
        # for out_channels[i]. The last entry is built but skipped in forward.
        self.AttNorm = nn.Sequential(
            AttentionNorm(5, self.out_channels[0], 2, 4),
            AttentionNorm(5, self.out_channels[1], 2, 2),
            AttentionNorm(5, self.out_channels[2], 1, 2),
            AttentionNorm(5, self.out_channels[3], 1, 1))

        self.last_conv = nn.Sequential(
            nn.ReflectionPad2d(3), nn.Conv2d(ngf * 2, output_nc, kernel_size=7, padding=0), nn.Tanh())

    def forward(self, input, unattended):
        """Decode ``unattended`` features, modulated by the reference ``input``."""
        up = self.model(unattended)
        for i in range(4):
            up = self.upsampling[i](up)
            if i == 3:
                break
            up = self.AttNorm[i](input, up)
        return self.last_conv(up)


class PartialConv(nn.Module):
    """Partial convolution: a convolution renormalised by the valid-mask sum.

    ``mask_conv`` has all-ones frozen weights and no bias, so its output is
    the sum of the mask inside each receptive field.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(PartialConv, self).__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)
        self.input_conv.apply(weights_init)
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        # mask is not updated
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, input, mask):
        """Return ``(output, new_mask)``.

        ``new_mask`` is 1 wherever the receptive field contained at least one
        valid mask element and 0 elsewhere; ``output`` is zeroed at the latter.
        """
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T * (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
        output = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)
        else:
            output_bias = torch.zeros_like(output)

        with torch.no_grad():
            output_mask = self.mask_conv(mask)

        no_update_holes = output_mask == 0
        # Avoid dividing by zero where the whole receptive field is masked out.
        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)

        output_pre = (output - output_bias) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes, 0.0)

        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
        return output, new_mask


class AttentionNorm(nn.Module):
    """Modulate batch-normalised features with a gate and bias from a reference.

    ``first_rate`` and ``second_rate`` choose the stride (1, 2 or 4) of the
    two convolutions applied to the reference input. All stride variants are
    always constructed, whichever ones are used.
    """

    def __init__(self, ref_channels, out_channels, first_rate, second_rate):
        super(AttentionNorm, self).__init__()
        self.first = first_rate
        self.second = second_rate
        mid_channels = int(out_channels / 2)

        self.conv_1time_f = nn.Conv2d(ref_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.conv_2times_f = nn.Conv2d(ref_channels, mid_channels, kernel_size=3, stride=2, padding=1)
        self.conv_4times_f = nn.Conv2d(ref_channels, mid_channels, kernel_size=3, stride=4, padding=1)

        self.conv_1time_s = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv_2times_s = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.conv_4times_s = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=4, padding=1)

        self.conv_1time_m = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv_2times_m = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.conv_4times_m = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=4, padding=1)

        self.norm = nn.BatchNorm2d(out_channels)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, input, unattended):
        """Return ``unattended`` plus its gated, biased and convolved normalisation."""
        # attention weights
        if self.first == 1:
            input = self.conv_1time_f(input)
        elif self.first == 2:
            input = self.conv_2times_f(input)
        elif self.first == 4:
            input = self.conv_4times_f(input)

        mask = None
        if self.second == 1:
            bias = self.conv_1time_s(input)
            mask = self.conv_1time_m(input)
        elif self.second == 2:
            bias = self.conv_2times_s(input)
            mask = self.conv_2times_m(input)
        elif self.second == 4:
            bias = self.conv_4times_s(input)
            mask = self.conv_4times_m(input)
        mask = torch.sigmoid(mask)

        attended = self.norm(unattended)
        attended = attended * mask + bias
        attended = torch.relu(attended)
        attended = self.conv(attended)
        return attended + unattended


def _unet_stage(in_nc, out_nc, norm):
    """Return the [3x3 conv, norm, ReLU] triple used throughout the U-Nets."""
    return [nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1), norm(out_nc), nn.ReLU()]


def _attach_unet_layers(net, input_nc, output_nc):
    """Register the shared U-Net layers (conv1 .. conv9) on ``net``.

    Attribute names, registration order and the layer order inside every
    ``nn.Sequential`` match the original hand-written classes, so existing
    state dicts keep loading.
    """
    nl = nn.InstanceNorm2d

    def double(in_nc, out_nc):
        return nn.Sequential(*(_unet_stage(in_nc, out_nc, nl) + _unet_stage(out_nc, out_nc, nl)))

    def up(in_nc, out_nc):
        return nn.Sequential(*([nn.UpsamplingNearest2d(scale_factor=2)] + _unet_stage(in_nc, out_nc, nl)))

    net.conv1 = double(input_nc, 64)
    net.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
    net.conv2 = double(64, 128)
    net.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
    net.conv3 = double(128, 256)
    net.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
    net.conv4 = double(256, 512)
    net.drop4 = nn.Dropout(0.5)
    net.pool4 = nn.MaxPool2d(kernel_size=(2, 2))
    net.conv5 = double(512, 1024)
    net.drop5 = nn.Dropout(0.5)

    net.up6 = up(1024, 512)
    net.conv6 = double(1024, 512)
    net.up7 = up(512, 256)
    net.conv7 = double(512, 256)
    net.up8 = up(256, 128)
    net.conv8 = double(256, 128)
    net.up9 = up(128, 64)
    net.conv9 = nn.Sequential(*(_unet_stage(128, 64, nl) + _unet_stage(64, 64, nl)
                                + [nn.Conv2d(64, output_nc, kernel_size=3, stride=1, padding=1)]))


def _unet_pass(net, x):
    """Run the encoder/decoder registered by ``_attach_unet_layers`` on ``x``."""
    conv1 = net.conv1(x)
    conv2 = net.conv2(net.pool1(conv1))
    conv3 = net.conv3(net.pool2(conv2))
    drop4 = net.drop4(net.conv4(net.pool3(conv3)))
    drop5 = net.drop5(net.conv5(net.pool4(drop4)))

    conv6 = net.conv6(torch.cat([drop4, net.up6(drop5)], 1))
    conv7 = net.conv7(torch.cat([conv3, net.up7(conv6)], 1))
    conv8 = net.conv8(torch.cat([conv2, net.up8(conv7)], 1))
    return net.conv9(torch.cat([conv1, net.up9(conv8)], 1))


class UnetMask(nn.Module):
    """``STNNet`` warp followed by a U-Net over the reference and warped input."""

    def __init__(self, input_nc, output_nc=3):
        super(UnetMask, self).__init__()
        self.stn = STNNet()
        _attach_unet_layers(self, input_nc, output_nc)

    def forward(self, input, refer, mask, grid):
        """Return ``(unet_output, warped_input, warped_mask, grid)``."""
        input, warped_mask, rx, ry, cx, cy, grid = self.stn(
            input, torch.cat([mask, refer, input], 1), mask, grid)
        # Both U-Net inputs are detached, so its gradients do not reach the STN.
        conv9 = _unet_pass(self, torch.cat([refer.detach(), input.detach()], 1))
        return conv9, input, warped_mask, grid


class Unet(nn.Module):
    """``STNNet`` warp followed by a U-Net; ``refine`` runs the U-Net alone."""

    def __init__(self, input_nc, output_nc=3):
        super(Unet, self).__init__()
        self.stn = STNNet()
        _attach_unet_layers(self, input_nc, output_nc)

    def forward(self, input, refer, mask):
        """Return ``(unet_output, warped_input, warped_mask, rx, ry, cx, cy)``."""
        input, warped_mask, rx, ry, cx, cy = self.stn(input, torch.cat([mask, refer, input], 1), mask)
        # Both U-Net inputs are detached, so its gradients do not reach the STN.
        conv9 = _unet_pass(self, torch.cat([refer.detach(), input.detach()], 1))
        return conv9, input, warped_mask, rx, ry, cx, cy

    def refine(self, input):
        """Run only the U-Net on ``input``, without the STN."""
        return _unet_pass(self, input)


class Refine(nn.Module):
    """Plain U-Net exposed through ``refine`` (no ``forward`` is defined)."""

    def __init__(self, input_nc, output_nc=3):
        super(Refine, self).__init__()
        _attach_unet_layers(self, input_nc, output_nc)

    def refine(self, input):
        return _unet_pass(self, input)


###### ResUnet new
class ResidualBlock(nn.Module):
    """Two bias-free 3x3 convolutions with an identity skip and a final ReLU.

    Pass ``norm_layer=None`` to build the block without normalisation.
    """

    def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
        super(ResidualBlock, self).__init__()
        self.relu = nn.ReLU(True)
        if norm_layer is None:
            self.block = nn.Sequential(
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
            )
        else:
            self.block = nn.Sequential(
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                norm_layer(in_features),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                norm_layer(in_features)
            )

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual
        out = self.relu(out)
        return out


class Refine_ResUnet_New(nn.Module):
    """U-Net assembled from nested ``ResUnetSkipConnectionBlock`` modules.

    The nest is built innermost-first. ``num_downs - 5`` extra ``ngf * 8``
    blocks are inserted; with the default ``num_downs=5`` there are none.
    """

    def __init__(self, input_nc, output_nc, num_downs=5, ngf=32,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(Refine_ResUnet_New, self).__init__()
        # construct unet structure
        unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                                norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                    norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                                norm_layer=norm_layer)
        unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                                norm_layer=norm_layer)
        unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                                norm_layer=norm_layer)
        unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                                outermost=True, norm_layer=norm_layer)
        self.model = unet_block

    def refine(self, input):
        return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class ResUnetSkipConnectionBlock(nn.Module):
    """One U-Net level: stride-2 conv down, optional submodule, upsample + conv up.

    Every block except the outermost returns its input concatenated with its
    output along dim 1. ``use_dropout`` only affects blocks that are neither
    outermost nor innermost. Pass ``norm_layer=None`` to build without norms.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
                 use_dropout=False):
        super(ResUnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        has_norm = norm_layer is not None
        use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
                             stride=2, padding=1, bias=use_bias)
        # add two resblock
        res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
        res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]

        downrelu = nn.ReLU(True)
        uprelu = nn.ReLU(True)
        if has_norm:
            downnorm = norm_layer(inner_nc)
            upnorm = norm_layer(outer_nc)

        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        if outermost:
            # The submodule returns its input concatenated with its output,
            # hence inner_nc * 2 input channels.
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv, downrelu] + res_downconv
            up = [upsample, upconv]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv, downrelu] + res_downconv
            if has_norm:
                up = [upsample, upconv, upnorm, uprelu] + res_upconv
            else:
                up = [upsample, upconv, uprelu] + res_upconv
            model = down + up
        else:
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            if has_norm:
                down = [downconv, downnorm, downrelu] + res_downconv
                up = [upsample, upconv, upnorm, uprelu] + res_upconv
            else:
                down = [downconv, downrelu] + res_downconv
                up = [upsample, upconv, uprelu] + res_upconv

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)


##################
def _encoder_resnet_decoder(input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer,
                            padding_type, res_norm_type):
    """Return the layer list shared by ``GlobalGenerator`` and ``BlendGenerator``.

    Layout: 7x7 stem, ``n_downsampling`` stride-2 convs, ``n_blocks``
    ``ResnetBlock`` modules built with ``norm_type=res_norm_type``,
    ``n_downsampling`` transposed convs, and a final 7x7 conv to ``output_nc``.
    One ReLU instance is reused throughout, as in the original code.
    """
    activation = nn.ReLU(True)
    layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
              norm_layer(ngf), activation]
    ### downsample
    for i in range(n_downsampling):
        mult = 2 ** i
        layers += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                   norm_layer(ngf * mult * 2), activation]

    ### resnet blocks
    mult = 2 ** n_downsampling
    for i in range(n_blocks):
        layers += [ResnetBlock(ngf * mult, norm_type=res_norm_type, padding_type=padding_type)]

    ### upsample
    for i in range(n_downsampling):
        mult = 2 ** (n_downsampling - i)
        layers += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
                                      output_padding=1),
                   norm_layer(int(ngf * mult / 2)), activation]
    layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
    return layers


class GlobalGenerator(nn.Module):
    """Encoder/resnet/decoder generator whose AdaIN parameters are predicted.

    ``enc_label`` encodes ``input_ref`` into two features. ``enc_style`` turns
    ``(image_ref, fea1, fea2)`` into the AdaIN parameters, which are written
    into the ``AdaptiveInstanceNorm2d`` layers of ``model`` on every forward.
    """

    def __init__(self, input_nc, output_nc, L, S, ngf=64, n_downsampling=3, n_blocks=9,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        self.model = nn.Sequential(*_encoder_resnet_decoder(
            input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer, padding_type, 'adain'))

        # style encoder
        self.enc_style = StyleEncoder(5, S, 16, self.get_num_adain_params(self.model), norm='none',
                                      activ='relu', pad_type='reflect')
        # label encoder
        self.enc_label = LabelEncoder(5, L, 16, 64, norm='none', activ='relu', pad_type='reflect')

    def assign_adain_params(self, adain_params, model):
        """Assign consecutive slices of ``adain_params`` to the AdaIN layers.

        Each ``AdaptiveInstanceNorm2d`` consumes ``2 * num_features`` columns:
        the first half becomes its ``bias``, the second half its ``weight``.
        """
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2 * m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2 * m.num_features:
                    adain_params = adain_params[:, 2 * m.num_features:]

    def get_num_adain_params(self, model):
        """Return the number of AdaIN parameters needed by ``model``."""
        return sum(2 * m.num_features for m in model.modules()
                   if m.__class__.__name__ == "AdaptiveInstanceNorm2d")

    def forward(self, input, input_ref, image_ref):
        fea1, fea2 = self.enc_label(input_ref)
        adain_params = self.enc_style((image_ref, fea1, fea2))
        self.assign_adain_params(adain_params, self.model)
        return self.model(input)


class BlendGenerator(nn.Module):
    """Predict a sigmoid mask and use it to blend two inputs.

    ``input_nc`` must equal the summed channel count of both inputs, since
    they are concatenated along dim 1 before entering the network.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(BlendGenerator, self).__init__()
        layers = _encoder_resnet_decoder(
            input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer, padding_type, 'in')
        layers += [nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, input1, input2):
        """Return ``(input1 * m + input2 * (1 - m), m)`` for the predicted mask ``m``."""
        m = self.model(torch.cat([input1, input2], 1))
        return input1 * m + input2 * (1 - m), m


# Define the Multiscale Discriminator.
class MultiscaleDiscriminator(nn.Module):
    """Runs ``num_D`` NLayerDiscriminators on successively downsampled inputs.

    The sub-networks are stored as flat attributes: ``layer<i>`` when
    intermediate features are not requested, otherwise
    ``scale<i>_layer<j>`` for each of the ``n_layers + 2`` stages.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        """Create ``num_D`` discriminators and the shared average-pool downsampler."""
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            single = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer,
                                         use_sigmoid, getIntermFeat)
            if not getIntermFeat:
                setattr(self, 'layer{}'.format(scale), single.model)
                continue
            for stage in range(n_layers + 2):
                setattr(self, 'scale{}_layer{}'.format(scale, stage),
                        getattr(single, 'model{}'.format(stage)))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Apply one discriminator.

        With ``getIntermFeat`` set, ``model`` is a list of stages and the
        output of every stage is returned; otherwise ``model`` is a single
        module and a one-element list holding its output is returned.
        """
        if not self.getIntermFeat:
            return [model(input)]
        outputs = []
        current = input
        for stage in model:
            current = stage(current)
            outputs.append(current)
        return outputs

    def forward(self, input):
        """Return one result list per scale.

        The highest-numbered discriminator sees the full-resolution input;
        the input is average-pooled before each following discriminator.
        """
        last = self.num_D - 1
        results = []
        current = input
        for step in range(self.num_D):
            scale = last - step
            if self.getIntermFeat:
                model = [getattr(self, 'scale{}_layer{}'.format(scale, stage))
                         for stage in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer{}'.format(scale))
            results.append(self.singleD_forward(model, current))
            if step != last:
                current = self.downsample(current)
        return results


# Define the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module): def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): super(NLayerDiscriminator, self).__init__() self.getIntermFeat = getIntermFeat self.n_layers = n_layers kw = 4 padw = int(np.ceil((kw - 1.0) / 2)) sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] nf = ndf for n in range(1, n_layers): nf_prev = nf nf = min(nf * 2, 512) sequence += [[ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), norm_layer(nf), nn.LeakyReLU(0.2, True) ]] nf_prev = nf nf = min(nf * 2, 512) sequence += [[ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), norm_layer(nf), nn.LeakyReLU(0.2, True) ]] sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] if use_sigmoid: sequence += [[nn.Sigmoid()]] if getIntermFeat: for n in range(len(sequence)): setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) else: sequence_stream = [] for n in range(len(sequence)): sequence_stream += sequence[n] self.model = nn.Sequential(*sequence_stream) def forward(self, input): if self.getIntermFeat: res = [input] for n in range(self.n_layers + 2): model = getattr(self, 'model' + str(n)) res.append(model(res[-1])) return res[1:] else: return self.model(input) from torchvision import models class Vgg19(torch.nn.Module): def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg = models.vgg19(pretrained=False) vgg_pretrained_features = vgg.features self.vgg = vgg self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): 
self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h_relu1 = self.slice1(X) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out def extract(self, x): x = self.vgg.features(x) x = self.vgg.avgpool(x) return x # Define the MaskVAE class VAE(nn.Module): def __init__(self, nc, ngf, ndf, latent_variable_size): super(VAE, self).__init__() # self.cuda = True self.nc = nc self.ngf = ngf self.ndf = ndf self.latent_variable_size = latent_variable_size # encoder self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1) self.bn1 = nn.BatchNorm2d(ndf) self.e2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1) self.bn2 = nn.BatchNorm2d(ndf * 2) self.e3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1) self.bn3 = nn.BatchNorm2d(ndf * 4) self.e4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1) self.bn4 = nn.BatchNorm2d(ndf * 8) self.e5 = nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1) self.bn5 = nn.BatchNorm2d(ndf * 16) self.e6 = nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1) self.bn6 = nn.BatchNorm2d(ndf * 32) self.e7 = nn.Conv2d(ndf * 32, ndf * 64, 4, 2, 1) self.bn7 = nn.BatchNorm2d(ndf * 64) self.fc1 = nn.Linear(ndf * 64 * 4 * 4, latent_variable_size) self.fc2 = nn.Linear(ndf * 64 * 4 * 4, latent_variable_size) # decoder self.d1 = nn.Linear(latent_variable_size, ngf * 64 * 4 * 4) self.up1 = nn.UpsamplingNearest2d(scale_factor=2) self.pd1 = nn.ReplicationPad2d(1) self.d2 = nn.Conv2d(ngf * 64, ngf * 32, 3, 1) self.bn8 = nn.BatchNorm2d(ngf * 32, 1.e-3) self.up2 = nn.UpsamplingNearest2d(scale_factor=2) self.pd2 = nn.ReplicationPad2d(1) self.d3 = nn.Conv2d(ngf * 32, ngf * 16, 3, 1) self.bn9 = nn.BatchNorm2d(ngf * 16, 1.e-3) self.up3 = nn.UpsamplingNearest2d(scale_factor=2) self.pd3 = nn.ReplicationPad2d(1) self.d4 = nn.Conv2d(ngf * 
16, ngf * 8, 3, 1) self.bn10 = nn.BatchNorm2d(ngf * 8, 1.e-3) self.up4 = nn.UpsamplingNearest2d(scale_factor=2) self.pd4 = nn.ReplicationPad2d(1) self.d5 = nn.Conv2d(ngf * 8, ngf * 4, 3, 1) self.bn11 = nn.BatchNorm2d(ngf * 4, 1.e-3) self.up5 = nn.UpsamplingNearest2d(scale_factor=2) self.pd5 = nn.ReplicationPad2d(1) self.d6 = nn.Conv2d(ngf * 4, ngf * 2, 3, 1) self.bn12 = nn.BatchNorm2d(ngf * 2, 1.e-3) self.up6 = nn.UpsamplingNearest2d(scale_factor=2) self.pd6 = nn.ReplicationPad2d(1) self.d7 = nn.Conv2d(ngf * 2, ngf, 3, 1) self.bn13 = nn.BatchNorm2d(ngf, 1.e-3) self.up7 = nn.UpsamplingNearest2d(scale_factor=2) self.pd7 = nn.ReplicationPad2d(1) self.d8 = nn.Conv2d(ngf, nc, 3, 1) self.leakyrelu = nn.LeakyReLU(0.2) self.relu = nn.ReLU() # self.sigmoid = nn.Sigmoid() self.maxpool = nn.MaxPool2d((2, 2), (2, 2)) def encode(self, x): h1 = self.leakyrelu(self.bn1(self.e1(x))) h2 = self.leakyrelu(self.bn2(self.e2(h1))) h3 = self.leakyrelu(self.bn3(self.e3(h2))) h4 = self.leakyrelu(self.bn4(self.e4(h3))) h5 = self.leakyrelu(self.bn5(self.e5(h4))) h6 = self.leakyrelu(self.bn6(self.e6(h5))) h7 = self.leakyrelu(self.bn7(self.e7(h6))) h7 = h7.view(-1, self.ndf * 64 * 4 * 4) return self.fc1(h7), self.fc2(h7) def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() # if self.cuda: eps = torch.cuda.FloatTensor(std.size()).normal_() # else: # eps = torch.FloatTensor(std.size()).normal_() eps = Variable(eps) return eps.mul(std).add_(mu) def decode(self, z): h1 = self.relu(self.d1(z)) h1 = h1.view(-1, self.ngf * 64, 4, 4) h2 = self.leakyrelu(self.bn8(self.d2(self.pd1(self.up1(h1))))) h3 = self.leakyrelu(self.bn9(self.d3(self.pd2(self.up2(h2))))) h4 = self.leakyrelu(self.bn10(self.d4(self.pd3(self.up3(h3))))) h5 = self.leakyrelu(self.bn11(self.d5(self.pd4(self.up4(h4))))) h6 = self.leakyrelu(self.bn12(self.d6(self.pd5(self.up5(h5))))) h7 = self.leakyrelu(self.bn13(self.d7(self.pd6(self.up6(h6))))) return self.d8(self.pd7(self.up7(h7))) def get_latent_var(self, x): mu, logvar = 
self.encode(x) z = self.reparametrize(mu, logvar) return z, mu, logvar.mul(0.5).exp_() def forward(self, x): mu, logvar = self.encode(x) z = self.reparametrize(mu, logvar) res = self.decode(z) return res, x, mu, logvar # style encode part class StyleEncoder(nn.Module): def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): super(StyleEncoder, self).__init__() self.model = [] self.model_middle = [] self.model_last = [] self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] for i in range(2): self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] dim *= 2 for i in range(n_downsample - 2): self.model_middle += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] self.model_last += [nn.AdaptiveAvgPool2d(1)] # global average pooling self.model_last += [nn.Conv2d(dim, style_dim, 1, 1, 0)] self.model = nn.Sequential(*self.model) self.model_middle = nn.Sequential(*self.model_middle) self.model_last = nn.Sequential(*self.model_last) self.output_dim = dim self.sft1 = SFTLayer() self.sft2 = SFTLayer() def forward(self, x): fea = self.model(x[0]) fea = self.sft1((fea, x[1])) fea = self.model_middle(fea) fea = self.sft2((fea, x[2])) return self.model_last(fea) # label encode part class LabelEncoder(nn.Module): def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): super(LabelEncoder, self).__init__() self.model = [] self.model_last = [nn.ReLU()] self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] dim *= 2 self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type)] dim *= 2 for i in range(n_downsample - 3): self.model_last += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] self.model_last += 
[ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type)] self.model = nn.Sequential(*self.model) self.model_last = nn.Sequential(*self.model_last) self.output_dim = dim def forward(self, x): fea = self.model(x) return fea, self.model_last(fea) # Define the basic block class ConvBlock(nn.Module): def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0, norm='none', activation='relu', pad_type='zero'): super(ConvBlock, self).__init__() self.use_bias = True # initialize padding if pad_type == 'reflect': self.pad = nn.ReflectionPad2d(padding) elif pad_type == 'replicate': self.pad = nn.ReplicationPad2d(padding) elif pad_type == 'zero': self.pad = nn.ZeroPad2d(padding) else: assert 0, "Unsupported padding type: {}".format(pad_type) # initialize normalization norm_dim = output_dim if norm == 'bn': self.norm = nn.BatchNorm2d(norm_dim) elif norm == 'in': # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) self.norm = nn.InstanceNorm2d(norm_dim) elif norm == 'ln': self.norm = LayerNorm(norm_dim) elif norm == 'adain': self.norm = AdaptiveInstanceNorm2d(norm_dim) elif norm == 'none' or norm == 'sn': self.norm = None else: assert 0, "Unsupported normalization: {}".format(norm) # initialize activation if activation == 'relu': self.activation = nn.ReLU(inplace=True) elif activation == 'lrelu': self.activation = nn.LeakyReLU(0.2, inplace=True) elif activation == 'prelu': self.activation = nn.PReLU() elif activation == 'selu': self.activation = nn.SELU(inplace=True) elif activation == 'tanh': self.activation = nn.Tanh() elif activation == 'none': self.activation = None else: assert 0, "Unsupported activation: {}".format(activation) # initialize convolution if norm == 'sn': self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)) else: self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) def forward(self, x): x = self.conv(self.pad(x)) if self.norm: x 
= self.norm(x) if self.activation: x = self.activation(x) return x class LinearBlock(nn.Module): def __init__(self, input_dim, output_dim, norm='none', activation='relu'): super(LinearBlock, self).__init__() use_bias = True # initialize fully connected layer if norm == 'sn': self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias)) else: self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) # initialize normalization norm_dim = output_dim if norm == 'bn': self.norm = nn.BatchNorm1d(norm_dim) elif norm == 'in': self.norm = nn.InstanceNorm1d(norm_dim) elif norm == 'ln': self.norm = LayerNorm(norm_dim) elif norm == 'none' or norm == 'sn': self.norm = None else: assert 0, "Unsupported normalization: {}".format(norm) # initialize activation if activation == 'relu': self.activation = nn.ReLU(inplace=True) elif activation == 'lrelu': self.activation = nn.LeakyReLU(0.2, inplace=True) elif activation == 'prelu': self.activation = nn.PReLU() elif activation == 'selu': self.activation = nn.SELU(inplace=True) elif activation == 'tanh': self.activation = nn.Tanh() elif activation == 'none': self.activation = None else: assert 0, "Unsupported activation: {}".format(activation) def forward(self, x): out = self.fc(x) if self.norm: out = self.norm(out) if self.activation: out = self.activation(out) return out # Define a resnet block class ResnetBlock(nn.Module): def __init__(self, dim, norm_type, padding_type, use_dropout=False): super(ResnetBlock, self).__init__() self.conv_block = self.build_conv_block(dim, norm_type, padding_type, use_dropout) def build_conv_block(self, dim, norm_type, padding_type, use_dropout): conv_block = [] conv_block += [ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation='relu', pad_type=padding_type)] conv_block += [ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation='none', pad_type=padding_type)] return nn.Sequential(*conv_block) def forward(self, x): out = x + self.conv_block(x) return out class SFTLayer(nn.Module): def 
__init__(self): super(SFTLayer, self).__init__() self.SFT_scale_conv1 = nn.Conv2d(64, 64, 1) self.SFT_scale_conv2 = nn.Conv2d(64, 64, 1) self.SFT_shift_conv1 = nn.Conv2d(64, 64, 1) self.SFT_shift_conv2 = nn.Conv2d(64, 64, 1) def forward(self, x): scale = self.SFT_scale_conv2(F.leaky_relu(self.SFT_scale_conv1(x[1]), 0.1, inplace=True)) shift = self.SFT_shift_conv2(F.leaky_relu(self.SFT_shift_conv1(x[1]), 0.1, inplace=True)) return x[0] * scale + shift class ConvBlock_SFT(nn.Module): def __init__(self, dim, norm_type, padding_type, use_dropout=False): super(ResnetBlock_SFT, self).__init__() self.sft1 = SFTLayer() self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type) def forward(self, x): fea = self.sft1((x[0], x[1])) fea = F.relu(self.conv1(fea), inplace=True) return (x[0] + fea, x[1]) class ConvBlock_SFT_last(nn.Module): def __init__(self, dim, norm_type, padding_type, use_dropout=False): super(ResnetBlock_SFT_last, self).__init__() self.sft1 = SFTLayer() self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type) def forward(self, x): fea = self.sft1((x[0], x[1])) fea = F.relu(self.conv1(fea), inplace=True) return x[0] + fea # Definition of normalization layer class AdaptiveInstanceNorm2d(nn.Module): def __init__(self, num_features, eps=1e-5, momentum=0.1): super(AdaptiveInstanceNorm2d, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum # weight and bias are dynamically assigned self.weight = None self.bias = None # just dummy buffers, not used self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) def forward(self, x): assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" 
b, c = x.size(0), x.size(1) running_mean = self.running_mean.repeat(b) running_var = self.running_var.repeat(b) # Apply instance norm x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) out = F.batch_norm( x_reshaped, running_mean, running_var, self.weight, self.bias, True, self.momentum, self.eps) return out.view(b, c, *x.size()[2:]) def __repr__(self): return self.__class__.__name__ + '(' + str(self.num_features) + ')' class LayerNorm(nn.Module): def __init__(self, num_features, eps=1e-5, affine=True): super(LayerNorm, self).__init__() self.num_features = num_features self.affine = affine self.eps = eps if self.affine: self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) self.beta = nn.Parameter(torch.zeros(num_features)) def forward(self, x): shape = [-1] + [1] * (x.dim() - 1) # print(x.size()) if x.size(0) == 1: # These two lines run much faster in pytorch 0.4 than the two lines listed below. mean = x.view(-1).mean().view(*shape) std = x.view(-1).std().view(*shape) else: mean = x.view(x.size(0), -1).mean(1).view(*shape) std = x.view(x.size(0), -1).std(1).view(*shape) x = (x - mean) / (std + self.eps) if self.affine: shape = [1, -1] + [1] * (x.dim() - 2) x = x * self.gamma.view(*shape) + self.beta.view(*shape) return x def l2normalize(v, eps=1e-12): return v / (v.norm() + eps) class SpectralNorm(nn.Module): """ Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan """ def __init__(self, module, name='weight', power_iterations=1): super(SpectralNorm, self).__init__() self.module = module self.name = name self.power_iterations = power_iterations if not self._made_params(): self._make_params() def _update_u_v(self): u = getattr(self.module, self.name + "_u") v = getattr(self.module, self.name + "_v") w = getattr(self.module, self.name + "_bar") height = 
w.data.shape[0] for _ in range(self.power_iterations): v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) sigma = u.dot(w.view(height, -1).mv(v)) setattr(self.module, self.name, w / sigma.expand_as(w)) def _made_params(self): try: u = getattr(self.module, self.name + "_u") v = getattr(self.module, self.name + "_v") w = getattr(self.module, self.name + "_bar") return True except AttributeError: return False def _make_params(self): w = getattr(self.module, self.name) height = w.data.shape[0] width = w.view(height, -1).data.shape[1] u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) u.data = l2normalize(u.data) v.data = l2normalize(v.data) w_bar = nn.Parameter(w.data) del self.module._parameters[self.name] self.module.register_parameter(self.name + "_u", u) self.module.register_parameter(self.name + "_v", v) self.module.register_parameter(self.name + "_bar", w_bar) def forward(self, *args): self._update_u_v() return self.module.forward(*args) ### STN TPS class CNN(nn.Module): def __init__(self, num_output, input_nc=5, ngf=8, n_layers=5, norm_layer=nn.InstanceNorm2d, use_dropout=False): super(CNN, self).__init__() downconv = nn.Conv2d(5, ngf, kernel_size=4, stride=2, padding=1) model = [downconv, nn.ReLU(True), norm_layer(ngf)] for i in range(n_layers): in_ngf = 2 ** i * ngf if 2 ** i * ngf < 1024 else 1024 out_ngf = 2 ** (i + 1) * ngf if 2 ** i * ngf < 1024 else 1024 downconv = nn.Conv2d(in_ngf, out_ngf, kernel_size=4, stride=2, padding=1) model += [downconv, norm_layer(out_ngf), nn.ReLU(True)] model += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), norm_layer(64), nn.ReLU(True)] model += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), norm_layer(64), nn.ReLU(True)] self.maxpool = 
nn.MaxPool2d(kernel_size=2, stride=2) self.model = nn.Sequential(*model) self.fc1 = nn.Linear(512, 128) self.fc2 = nn.Linear(128, num_output) def forward(self, x): x = self.model(x) x = self.maxpool(x) x = x.view(x.shape[0], -1) x = F.relu(self.fc1(x)) x = F.dropout(x, training=self.training) x = self.fc2(x) return x class ClsNet(nn.Module): def __init__(self): super(ClsNet, self).__init__() self.cnn = CNN(10) def forward(self, x): return F.log_softmax(self.cnn(x)) class BoundedGridLocNet(nn.Module): def __init__(self, grid_height, grid_width, target_control_points): super(BoundedGridLocNet, self).__init__() self.cnn = CNN(grid_height * grid_width * 2) bias = torch.from_numpy(np.arctanh(target_control_points.numpy())) bias = bias.view(-1) self.cnn.fc2.bias.data.copy_(bias) self.cnn.fc2.weight.data.zero_() def forward(self, x): batch_size = x.size(0) points = F.tanh(self.cnn(x)) coor=points.view(batch_size, -1, 2) # coor+=torch.randn(coor.shape).cuda()/10 row=self.get_row(coor,5) col=self.get_col(coor,5) rx,ry,cx,cy=torch.tensor(0.08).cuda(),torch.tensor(0.08).cuda()\ ,torch.tensor(0.08).cuda(),torch.tensor(0.08).cuda() row_x,row_y=row[:,:,0],row[:,:,1] col_x,col_y=col[:,:,0],col[:,:,1] rx_loss=torch.max(rx,row_x).mean() ry_loss=torch.max(ry,row_y).mean() cx_loss=torch.max(cx,col_x).mean() cy_loss=torch.max(cy,col_y).mean() return coor,rx_loss,ry_loss,cx_loss,cy_loss def get_row(self,coor,num): sec_dic=[] for j in range(num): sum=0 buffer=0 flag=False max=-1 for i in range(num-1): differ=(coor[:,j*num+i+1,:]-coor[:,j*num+i,:])**2 if not flag: second_dif=0 flag=True else: second_dif=torch.abs(differ-buffer) sec_dic.append(second_dif) buffer=differ sum+=second_dif return torch.stack(sec_dic,dim=1) def get_col(self,coor,num): sec_dic=[] for i in range(num): sum = 0 buffer = 0 flag = False max = -1 for j in range(num - 1): differ = (coor[:, (j+1) * num + i , :] - coor[:, j * num + i, :]) ** 2 if not flag: second_dif = 0 flag = True else: second_dif = 
torch.abs(differ-buffer) sec_dic.append(second_dif) buffer = differ sum += second_dif return torch.stack(sec_dic,dim=1) class UnBoundedGridLocNet(nn.Module): def __init__(self, grid_height, grid_width, target_control_points): super(UnBoundedGridLocNet, self).__init__() self.cnn = CNN(grid_height * grid_width * 2) bias = target_control_points.view(-1) self.cnn.fc2.bias.data.copy_(bias) self.cnn.fc2.weight.data.zero_() def forward(self, x): batch_size = x.size(0) points = self.cnn(x) return points.view(batch_size, -1, 2) class STNNet(nn.Module): def __init__(self): super(STNNet, self).__init__() range = 0.9 r1 = range r2 = range grid_size_h = 5 grid_size_w = 5 assert r1 < 1 and r2 < 1 # if >= 1, arctanh will cause error in BoundedGridLocNet target_control_points = torch.Tensor(list(itertools.product( np.arange(-r1, r1 + 0.00001, 2.0 * r1 / (grid_size_h - 1)), np.arange(-r2, r2 + 0.00001, 2.0 * r2 / (grid_size_w - 1)), ))) Y, X = target_control_points.split(1, dim=1) target_control_points = torch.cat([X, Y], dim=1) self.target_control_points=target_control_points # self.get_row(target_control_points,5) GridLocNet = { 'unbounded_stn': UnBoundedGridLocNet, 'bounded_stn': BoundedGridLocNet, }['bounded_stn'] self.loc_net = GridLocNet(grid_size_h, grid_size_w, target_control_points) self.tps = TPSGridGen(256, 192, target_control_points) def get_row(self, coor, num): for j in range(num): sum = 0 buffer = 0 flag = False max = -1 for i in range(num - 1): differ = (coor[j * num + i + 1, :] - coor[j * num + i, :]) ** 2 if not flag: second_dif = 0 flag = True else: second_dif = torch.abs(differ - buffer) buffer = differ sum += second_dif print(sum / num) def get_col(self,coor,num): for i in range(num): sum = 0 buffer = 0 flag = False max = -1 for j in range(num - 1): differ = (coor[ (j + 1) * num + i, :] - coor[j * num + i, :]) ** 2 if not flag: second_dif = 0 flag = True else: second_dif = torch.abs(differ-buffer) buffer = differ sum += second_dif print(sum) def forward(self, 
x, reference, mask,grid_pic): batch_size = x.size(0) source_control_points,rx,ry,cx,cy = self.loc_net(reference) source_control_points=(source_control_points) # print('control points',source_control_points.shape) source_coordinate = self.tps(source_control_points) grid = source_coordinate.view(batch_size, 256, 192, 2) # print('grid size',grid.shape) transformed_x = grid_sample(x, grid, canvas=0) warped_mask = grid_sample(mask, grid, canvas=0) warped_gpic= grid_sample(grid_pic, grid, canvas=0) return transformed_x, warped_mask,rx,ry,cx,cy,warped_gpic