# fifa-tryon-demo / models / networks.py
# NOTE(review): the lines below were scraped Hugging Face page chrome
# ("hasibzunair's picture / added files / 4a285f6 / raw history blame /
# No virus / 70.3 kB"), not Python source; kept here as a comment because
# as bare text they made the module unimportable (syntax error).
import torch
import os
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import math
import torch
import itertools
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from grid_sample import grid_sample
from torch.autograd import Variable
from tps_grid_gen import TPSGridGen
###############################################################################
# Functions
###############################################################################
def weights_init(m):
    """Initialize a module for use with ``nn.Module.apply``.

    Conv2d-like layers get weights ~ N(0, 0.02); BatchNorm2d-like layers get
    weights ~ N(1, 0.02) and zero bias.  Other module types are untouched.
    """
    layer_name = m.__class__.__name__
    if 'Conv2d' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    """Return a partially-applied 2-D normalization constructor.

    Args:
        norm_type: 'batch' -> affine BatchNorm2d, 'instance' -> non-affine
            InstanceNorm2d.

    Raises:
        NotImplementedError: for any other norm_type.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def define_G(input_nc, output_nc, ngf, netG, L=1, S=1, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[]):
    """Construct, print and initialize a generator network.

    Args:
        input_nc / output_nc: input / output channel counts.
        ngf: base number of generator filters.
        netG: architecture name, 'global' or 'local'.
        L, S: label / style encoder depths forwarded to GlobalGenerator.
        n_downsample_global, n_blocks_global: global branch hyper-parameters.
        n_local_enhancers, n_blocks_local: local branch hyper-parameters.
        norm: 'batch' or 'instance' (see get_norm_layer).
        gpu_ids: if non-empty, the net is moved to gpu_ids[0].

    Returns:
        The generator module with weights_init applied.

    Raises:
        NotImplementedError: if netG names an unknown architecture.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(input_nc, output_nc, L, S, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                             n_local_enhancers, n_blocks_local, norm_layer)
    else:
        # Bug fix: the original `raise ('generator not implemented!')` raised a
        # bare string, which is itself a TypeError in Python 3 ("exceptions must
        # derive from BaseException") and masked the real problem.
        raise NotImplementedError('generator [%s] is not implemented' % netG)
    print(netG)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_Unet(input_nc, gpu_ids=[]):
    """Build and initialize the warping U-Net.

    Bug fix: the original unconditionally indexed gpu_ids[0], which raised
    IndexError with the default empty list (and crashed on CPU-only hosts);
    the GPU move is now guarded the same way as in define_G / define_D.
    """
    netG = Unet(input_nc)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_UnetMask(input_nc, gpu_ids=[]):
    """Build and initialize the mask-predicting U-Net (4 output channels).

    Bug fix: gpu_ids[0] was indexed unconditionally, raising IndexError with
    the default empty list; the GPU move is now guarded like in define_G.
    """
    netG = UnetMask(input_nc, output_nc=4)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_Refine(input_nc, output_nc, gpu_ids=[]):
    """Build and initialize the refinement U-Net.

    Bug fix: gpu_ids[0] was indexed unconditionally, raising IndexError with
    the default empty list; the GPU move is now guarded like in define_G.
    """
    netG = Refine(input_nc, output_nc)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
####################################################
def define_Refine_ResUnet(input_nc, output_nc, gpu_ids=[]):
    """Build and initialize the residual-U-Net refinement network.

    Bug fix: gpu_ids[0] was indexed unconditionally, raising IndexError with
    the default empty list; the GPU move is now guarded like in define_G.
    (Stale ipdb.set_trace() debug comments removed.)
    """
    netG = Refine_ResUnet_New(input_nc, output_nc)  # norm_layer=nn.InstanceNorm2d
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
####################################################
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    """Build a multiscale PatchGAN discriminator, print it, optionally move it
    to gpu_ids[0], and apply the normal-distribution weight init."""
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, get_norm_layer(norm_type=norm),
                                   use_sigmoid, num_D, getIntermFeat)
    print(netD)
    if gpu_ids:
        assert torch.cuda.is_available()
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    return netD
def define_VAE(input_nc, gpu_ids=[]):
    """Build the label VAE and optionally move it to gpu_ids[0].

    Note: input_nc is accepted for signature symmetry with the other
    factories, but the architecture is hard-wired to VAE(19, 32, 32, 1024);
    unlike the other factories, weights_init is not applied here (kept as-is).
    """
    netVAE = VAE(19, 32, 32, 1024)
    print(netVAE)
    if gpu_ids:
        assert torch.cuda.is_available()
        netVAE.cuda(gpu_ids[0])
    return netVAE
def define_B(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=3, norm='instance', gpu_ids=[]):
    """Build the blending generator, print it, optionally move it to
    gpu_ids[0], and apply the normal-distribution weight init."""
    netB = BlendGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                          get_norm_layer(norm_type=norm))
    print(netB)
    if gpu_ids:
        assert torch.cuda.is_available()
        netB.cuda(gpu_ids[0])
    netB.apply(weights_init)
    return netB
def define_partial_enc(input_nc, gpu_ids=[]):
    """Build the partial-convolution encoder, print it, optionally move it to
    gpu_ids[0], and apply the normal-distribution weight init."""
    net = PartialConvEncoder(input_nc)
    print(net)
    if gpu_ids:
        assert torch.cuda.is_available()
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
def define_conv_enc(input_nc, gpu_ids=[]):
    """Build the plain convolutional encoder, print it, optionally move it to
    gpu_ids[0], and apply the normal-distribution weight init."""
    net = ConvEncoder(input_nc)
    print(net)
    if gpu_ids:
        assert torch.cuda.is_available()
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
def define_AttG(output_nc, gpu_ids=[]):
    """Build the attention generator, print it, optionally move it to
    gpu_ids[0], and apply the normal-distribution weight init."""
    net = AttGenerator(output_nc)
    print(net)
    if gpu_ids:
        assert torch.cuda.is_available()
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
def print_network(net):
    """Print a network's repr followed by its total parameter count.

    Accepts either a module or a list of modules; for a list only the first
    entry is reported.
    """
    if isinstance(net, list):
        net = net[0]
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
    """LSGAN / vanilla GAN criterion with cached target tensors.

    Target tensors (filled with real_label or fake_label) are created lazily
    and rebuilt only when the prediction size changes.  The input may be a
    list of per-layer features (getIntermFeat style) or a list of multiscale
    outputs (list of lists); only the last element of each is scored.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        # MSE gives the least-squares GAN objective, BCE the original GAN one.
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a cached constant tensor shaped like `input`."""
        if target_is_real:
            stale = (self.real_label_var is None
                     or self.real_label_var.numel() != input.numel())
            if stale:
                filled = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(filled, requires_grad=False)
            return self.real_label_var
        stale = (self.fake_label_var is None
                 or self.fake_label_var.numel() != input.numel())
        if stale:
            filled = self.Tensor(input.size()).fill_(self.fake_label)
            self.fake_label_var = Variable(filled, requires_grad=False)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        if isinstance(input[0], list):
            # Multiscale discriminator: sum the loss over every scale.
            total = 0
            for scale in input:
                pred = scale[-1]
                total += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return total
        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))
class VGGLossWarp(nn.Module):
    """Perceptual L1 loss for the warping stage.

    Only the deepest of the five VGG19 feature maps is compared (weight 1.0);
    the target features are detached so no gradient flows into y.
    """

    def __init__(self, gpu_ids):
        super(VGGLossWarp, self).__init__()
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        # Only the last feature level contributes.
        return self.weights[4] * self.criterion(feats_x[4], feats_y[4].detach())
class VGGLoss(nn.Module):
    """Multi-level VGG19 perceptual L1 loss with fixed per-level weights.

    `warp` is a variant that compares only the deepest feature level.
    The target features are detached so no gradient flows into y.
    """

    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        loss = 0
        # Vgg19 returns five feature maps, one per weight entry.
        for weight, fx, fy in zip(self.weights, feats_x, feats_y):
            loss = loss + weight * self.criterion(fx, fy.detach())
        return loss

    def warp(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        return self.weights[4] * self.criterion(feats_x[4], feats_y[4].detach())
class StyleLoss(nn.Module):
    """Gram-matrix style loss over VGG19 features.

    For each feature level and each sample, the loss is the RMS difference of
    the (C*H*W)-normalized Gram matrices, weighted per level.  Note the target
    features are NOT detached here (kept identical to the original behavior).
    """

    def __init__(self, gpu_ids):
        super(StyleLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        loss = 0
        for weight, fx, fy in zip(self.weights, feats_x, feats_y):
            N, C, H, W = fx.shape
            for n in range(N):
                flat_x = fx[n].reshape(C, H * W)
                flat_y = fy[n].reshape(C, H * W)
                gram_x = torch.matmul(flat_x, flat_x.t()) / (C * H * W)
                gram_y = torch.matmul(flat_y, flat_y.t()) / (C * H * W)
                loss = loss + weight * torch.sqrt(torch.mean((gram_x - gram_y) ** 2))
        return loss
##############################################################################
# Generator
##############################################################################
class PartialConvEncoder(nn.Module):
    """Partial-convolution encoder.

    A 7x7 reflection-padded stem followed by four stride-2 partial-conv
    stages (ngf -> 16*ngf channels); each stage also propagates a validity
    mask.  Attribute names are preserved so pretrained state_dict keys match.
    """

    def __init__(self, input_nc, ngf=32, norm_layer=nn.BatchNorm2d):
        super(PartialConvEncoder, self).__init__()
        self.pad1 = nn.ReflectionPad2d(3)
        self.partial_conv1 = PartialConv(input_nc, ngf, kernel_size=7)
        self.norm_layer1 = norm_layer(ngf)
        self.activation = nn.ReLU(True)
        # Downsampling stage k maps ngf*2**k -> ngf*2**(k+1) channels.
        self.down1 = PartialConv(ngf, ngf * 2, kernel_size=3, stride=2, padding=1)
        self.norm_layer2 = norm_layer(ngf * 2)
        self.down2 = PartialConv(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1)
        self.norm_layer3 = norm_layer(ngf * 4)
        self.down3 = PartialConv(ngf * 4, ngf * 8, kernel_size=3, stride=2, padding=1)
        self.norm_layer4 = norm_layer(ngf * 8)
        self.down4 = PartialConv(ngf * 8, ngf * 16, kernel_size=3, stride=2, padding=1)
        self.norm_layer5 = norm_layer(ngf * 16)

    def forward(self, input, mask):
        feat = self.pad1(input)
        mask = self.pad1(mask)
        feat, mask = self.partial_conv1(feat, mask)
        feat = self.activation(self.norm_layer1(feat))
        stages = ((self.down1, self.norm_layer2),
                  (self.down2, self.norm_layer3),
                  (self.down3, self.norm_layer4),
                  (self.down4, self.norm_layer5))
        for down, norm in stages:
            feat, mask = down(feat, mask)
            feat = self.activation(norm(feat))
        return feat
class ConvEncoder(nn.Module):
    """Plain convolutional encoder.

    A 7x7 reflection-padded stem followed by n_downsampling stride-2
    conv/norm/ReLU stages; channels grow from ngf to ngf * 2**n_downsampling.
    (n_blocks and padding_type are accepted for signature compatibility but
    unused, as in the original.)
    """

    def __init__(self, input_nc, ngf=32, n_downsampling=4, n_blocks=4, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        super(ConvEncoder, self).__init__()
        relu = nn.ReLU(True)
        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                  norm_layer(ngf),
                  relu]
        channels = ngf
        for _ in range(n_downsampling):
            layers += [nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                       norm_layer(channels * 2),
                       relu]
            channels *= 2
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
class AttGenerator(nn.Module):
    """Decoder that upsamples a bottleneck map in four steps, applying
    AttentionNorm conditioning (driven by `input`) between steps, and ends
    with a reflection-padded 7x7 conv + Tanh.

    Attribute names and registration order are preserved so pretrained
    state_dict keys still match.
    """

    def __init__(self, output_nc, ngf=32, n_blocks=4, n_downsampling=4, padding_type='reflect'):
        super(AttGenerator, self).__init__()
        bottleneck_ch = ngf * (2 ** n_downsampling) * 2
        self.model = nn.Sequential(
            *[ResnetBlock(bottleneck_ch, norm_type='in', padding_type=padding_type)
              for _ in range(n_blocks)])

        norm_layer = nn.BatchNorm2d
        relu = nn.ReLU(True)
        up_stages = []
        self.out_channels = []
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            out_ch = int(ngf * mult / 2) * 2
            up_stages.append(nn.Sequential(
                nn.ConvTranspose2d(ngf * mult * 2, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(out_ch),
                relu))
            self.out_channels.append(out_ch)
        self.upsampling = nn.Sequential(*up_stages)

        # Attention-norm blocks; the (first, second) downsampling rates are
        # chosen per resolution level so the reference map matches each scale.
        self.AttNorm = nn.Sequential(
            AttentionNorm(5, self.out_channels[0], 2, 4),
            AttentionNorm(5, self.out_channels[1], 2, 2),
            AttentionNorm(5, self.out_channels[2], 1, 2),
            AttentionNorm(5, self.out_channels[3], 1, 1))

        self.last_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf * 2, output_nc, kernel_size=7, padding=0),
            nn.Tanh())

    def forward(self, input, unattended):
        up = self.model(unattended)
        for i in range(4):
            up = self.upsampling[i](up)
            # No attention after the final (highest-resolution) upsampling.
            if i == 3:
                break
            up = self.AttNorm[i](input, up)
        return self.last_conv(up)
class PartialConv(nn.Module):
    """Partial convolution: convolves the masked input, renormalizes by the
    per-window mask coverage, and returns the updated validity mask.

    Reference formula (see forward): C(X) = W^T X + b, D(M) = sum(M),
    output = [C(M .* X) - C(0)] / D(M) + C(0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(PartialConv, self).__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        # The mask branch is a fixed all-ones kernel without bias: it simply
        # counts how many valid input pixels fall inside each window.
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)
        self.input_conv.apply(weights_init)
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        # The mask kernel is never trained.
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, input, mask):
        # http://masc.cs.gmu.edu/wiki/partialconv
        raw = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            bias_view = self.input_conv.bias.view(1, -1, 1, 1).expand_as(raw)
        else:
            bias_view = torch.zeros_like(raw)
        with torch.no_grad():
            coverage = self.mask_conv(mask)
        # Windows with zero valid pixels stay holes; avoid dividing by zero.
        holes = coverage == 0
        denom = coverage.masked_fill_(holes, 1.0)
        normalized = (raw - bias_view) / denom + bias_view
        output = normalized.masked_fill_(holes, 0.0)
        new_mask = torch.ones_like(output).masked_fill_(holes, 0.0)
        return output, new_mask
class AttentionNorm(nn.Module):
    """Batch normalization modulated by an attention mask and bias that are
    predicted from a reference map.

    The reference (`input`) is first reduced by stride `first_rate` (1/2/4),
    then two heads with stride `second_rate` predict a sigmoid mask and an
    additive bias at the resolution of `unattended`.  The result is
    `conv(relu(norm(unattended) * mask + bias)) + unattended`.
    Attribute names are preserved so pretrained state_dict keys still match.
    """

    def __init__(self, ref_channels, out_channels, first_rate, second_rate):
        super(AttentionNorm, self).__init__()
        self.first = first_rate
        self.second = second_rate
        mid_channels = int(out_channels / 2)
        # Stage 1: reference reduction at stride 1 / 2 / 4.
        self.conv_1time_f = nn.Conv2d(ref_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.conv_2times_f = nn.Conv2d(ref_channels, mid_channels, kernel_size=3, stride=2, padding=1)
        self.conv_4times_f = nn.Conv2d(ref_channels, mid_channels, kernel_size=3, stride=4, padding=1)
        # Stage 2 bias heads at stride 1 / 2 / 4.
        self.conv_1time_s = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv_2times_s = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.conv_4times_s = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=4, padding=1)
        # Stage 2 mask heads at stride 1 / 2 / 4.
        self.conv_1time_m = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv_2times_m = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.conv_4times_m = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=4, padding=1)
        self.norm = nn.BatchNorm2d(out_channels)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, input, unattended):
        # Stage 1: bring the reference down to the intermediate resolution.
        if self.first == 1:
            reduced = self.conv_1time_f(input)
        elif self.first == 2:
            reduced = self.conv_2times_f(input)
        elif self.first == 4:
            reduced = self.conv_4times_f(input)
        else:
            reduced = input  # passes through unchanged for other rates
        # Stage 2: predict the modulation bias and attention mask.
        if self.second == 1:
            bias = self.conv_1time_s(reduced)
            mask = self.conv_1time_m(reduced)
        elif self.second == 2:
            bias = self.conv_2times_s(reduced)
            mask = self.conv_2times_m(reduced)
        elif self.second == 4:
            bias = self.conv_4times_s(reduced)
            mask = self.conv_4times_m(reduced)
        mask = torch.sigmoid(mask)
        modulated = torch.relu(self.norm(unattended) * mask + bias)
        # Residual connection around the attention modulation.
        return self.conv(modulated) + unattended
class UnetMask(nn.Module):
    """U-Net that first warps the clothing `input` (and its `mask`) with an
    STN and then fuses it with the reference `refer`.

    Returns (prediction, warped_input, warped_mask, grid).  output_nc defaults
    to 3 but define_UnetMask builds it with output_nc=4 — presumably RGB plus
    a mask channel; TODO confirm against the caller.
    Attribute names and Sequential layouts determine pretrained checkpoint
    state_dict keys, so the structure must not be reorganized.
    """
    def __init__(self, input_nc, output_nc=3):
        super(UnetMask, self).__init__()
        # Spatial transformer that aligns `input`/`mask` before the U-Net.
        self.stn = STNNet()
        nl = nn.InstanceNorm2d
        # Encoder: four double-conv stages (64 -> 512 channels), 2x2 max-pools.
        self.conv1 = nn.Sequential(*[nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
                                     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU()])
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = nn.Sequential(*[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU(),
                                     nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU()])
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv3 = nn.Sequential(*[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU(),
                                     nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU()])
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv4 = nn.Sequential(*[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU(),
                                     nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU()])
        self.drop4 = nn.Dropout(0.5)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2))
        # 1024-channel bottleneck with dropout.
        self.conv5 = nn.Sequential(*[nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1), nl(1024), nn.ReLU(),
                                     nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1), nl(1024), nn.ReLU()])
        self.drop5 = nn.Dropout(0.5)
        # Decoder: nearest-neighbour upsampling; each convN consumes the skip
        # concatenation, hence the doubled input channel counts.
        self.up6 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), nl(512),
              nn.ReLU()])
        self.conv6 = nn.Sequential(*[nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU(),
                                     nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU()])
        self.up7 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), nl(256),
              nn.ReLU()])
        self.conv7 = nn.Sequential(*[nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU(),
                                     nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU()])
        self.up8 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), nl(128),
              nn.ReLU()])
        self.conv8 = nn.Sequential(*[nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU(),
                                     nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU()])
        self.up9 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), nl(64),
              nn.ReLU()])
        self.conv9 = nn.Sequential(*[nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
                                     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
                                     nn.Conv2d(64, output_nc, kernel_size=3, stride=1, padding=1)
                                     ])
    def forward(self, input, refer, mask,grid):
        # The STN estimates its warp from [mask, refer, input] and returns the
        # warped input/mask plus bookkeeping outputs and the sampling grid.
        input, warped_mask,rx,ry,cx,cy,grid = self.stn(input, torch.cat([mask, refer, input], 1), mask,grid)
        # print(input.shape)
        # Both streams are detached: no gradient flows from the U-Net into the STN.
        conv1 = self.conv1(torch.cat([refer.detach(), input.detach()], 1))
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        drop4 = self.drop4(conv4)
        pool4 = self.pool4(drop4)
        conv5 = self.conv5(pool4)
        drop5 = self.drop5(conv5)
        # Decoder with skip connections from the matching encoder stage.
        up6 = self.up6(drop5)
        conv6 = self.conv6(torch.cat([drop4, up6], 1))
        up7 = self.up7(conv6)
        conv7 = self.conv7(torch.cat([conv3, up7], 1))
        up8 = self.up8(conv7)
        conv8 = self.conv8(torch.cat([conv2, up8], 1))
        up9 = self.up9(conv8)
        conv9 = self.conv9(torch.cat([conv1, up9], 1))
        return conv9, input, warped_mask,grid
class Unet(nn.Module):
    """U-Net with an STN warping front-end (grid-free variant of UnetMask).

    forward returns (prediction, warped_input, warped_mask, rx, ry, cx, cy);
    `refine` runs only the encoder/decoder on an already-assembled input.
    Attribute names and Sequential layouts determine pretrained checkpoint
    state_dict keys, so the structure must not be reorganized.
    """
    def __init__(self, input_nc, output_nc=3):
        super(Unet, self).__init__()
        # Spatial transformer that aligns `input`/`mask` before the U-Net.
        self.stn = STNNet()
        nl = nn.InstanceNorm2d
        # Encoder: four double-conv stages (64 -> 512 channels), 2x2 max-pools.
        self.conv1 = nn.Sequential(*[nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
                                     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU()])
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = nn.Sequential(*[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU(),
                                     nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU()])
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv3 = nn.Sequential(*[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU(),
                                     nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU()])
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv4 = nn.Sequential(*[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU(),
                                     nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU()])
        self.drop4 = nn.Dropout(0.5)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2))
        # 1024-channel bottleneck with dropout.
        self.conv5 = nn.Sequential(*[nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1), nl(1024), nn.ReLU(),
                                     nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1), nl(1024), nn.ReLU()])
        self.drop5 = nn.Dropout(0.5)
        # Decoder: nearest-neighbour upsampling; each convN consumes the skip
        # concatenation, hence the doubled input channel counts.
        self.up6 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), nl(512),
              nn.ReLU()])
        self.conv6 = nn.Sequential(*[nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU(),
                                     nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nl(512), nn.ReLU()])
        self.up7 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), nl(256),
              nn.ReLU()])
        self.conv7 = nn.Sequential(*[nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU(),
                                     nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nl(256), nn.ReLU()])
        self.up8 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), nl(128),
              nn.ReLU()])
        self.conv8 = nn.Sequential(*[nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU(),
                                     nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nl(128), nn.ReLU()])
        self.up9 = nn.Sequential(
            *[nn.UpsamplingNearest2d(scale_factor=2), nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), nl(64),
              nn.ReLU()])
        self.conv9 = nn.Sequential(*[nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
                                     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
                                     nn.Conv2d(64, output_nc, kernel_size=3, stride=1, padding=1)
                                     ])
    def forward(self, input, refer, mask):
        # The STN estimates its warp from [mask, refer, input].
        input, warped_mask,rx,ry,cx,cy = self.stn(input, torch.cat([mask, refer, input], 1), mask)
        # print(input.shape)
        # Both streams are detached: no gradient flows from the U-Net into the STN.
        conv1 = self.conv1(torch.cat([refer.detach(), input.detach()], 1))
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        drop4 = self.drop4(conv4)
        pool4 = self.pool4(drop4)
        conv5 = self.conv5(pool4)
        drop5 = self.drop5(conv5)
        # Decoder with skip connections from the matching encoder stage.
        up6 = self.up6(drop5)
        conv6 = self.conv6(torch.cat([drop4, up6], 1))
        up7 = self.up7(conv6)
        conv7 = self.conv7(torch.cat([conv3, up7], 1))
        up8 = self.up8(conv7)
        conv8 = self.conv8(torch.cat([conv2, up8], 1))
        up9 = self.up9(conv8)
        conv9 = self.conv9(torch.cat([conv1, up9], 1))
        return conv9, input, warped_mask,rx,ry,cx,cy
    def refine(self, input):
        """Encoder/decoder pass only — the STN is bypassed entirely."""
        conv1 = self.conv1(input)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        drop4 = self.drop4(conv4)
        pool4 = self.pool4(drop4)
        conv5 = self.conv5(pool4)
        drop5 = self.drop5(conv5)
        up6 = self.up6(drop5)
        conv6 = self.conv6(torch.cat([drop4, up6], 1))
        up7 = self.up7(conv6)
        conv7 = self.conv7(torch.cat([conv3, up7], 1))
        up8 = self.up8(conv7)
        conv8 = self.conv8(torch.cat([conv2, up8], 1))
        up9 = self.up9(conv8)
        conv9 = self.conv9(torch.cat([conv1, up9], 1))
        return conv9
class Refine(nn.Module):
    """Stand-alone refinement U-Net (no STN): four down stages, a
    1024-channel bottleneck with dropout, and four nearest-upsampling stages
    with skip connections.  The entry point is `refine`, not `forward`.

    The Sequential layouts are kept byte-equal to the original so pretrained
    state_dict keys still match.
    """

    def __init__(self, input_nc, output_nc=3):
        super(Refine, self).__init__()
        nl = nn.InstanceNorm2d

        def double_conv(c_in, c_out):
            # conv-norm-relu twice, preserving the 6-element Sequential layout.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1), nl(c_out), nn.ReLU(),
                nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1), nl(c_out), nn.ReLU())

        def up_block(c_in, c_out):
            # nearest-neighbour 2x upsample followed by conv-norm-relu.
            return nn.Sequential(
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1), nl(c_out), nn.ReLU())

        self.conv1 = double_conv(input_nc, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = double_conv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv3 = double_conv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv4 = double_conv(256, 512)
        self.drop4 = nn.Dropout(0.5)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv5 = double_conv(512, 1024)
        self.drop5 = nn.Dropout(0.5)
        # Decoder; each convN consumes [skip, upsampled] so inputs are doubled.
        self.up6 = up_block(1024, 512)
        self.conv6 = double_conv(1024, 512)
        self.up7 = up_block(512, 256)
        self.conv7 = double_conv(512, 256)
        self.up8 = up_block(256, 128)
        self.conv8 = double_conv(256, 128)
        self.up9 = up_block(128, 64)
        self.conv9 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
            nn.Conv2d(64, output_nc, kernel_size=3, stride=1, padding=1))

    def refine(self, input):
        """Run encoder/decoder with skip connections; returns the output map."""
        d1 = self.conv1(input)
        d2 = self.conv2(self.pool1(d1))
        d3 = self.conv3(self.pool2(d2))
        d4 = self.drop4(self.conv4(self.pool3(d3)))
        bottleneck = self.drop5(self.conv5(self.pool4(d4)))
        u6 = self.conv6(torch.cat([d4, self.up6(bottleneck)], 1))
        u7 = self.conv7(torch.cat([d3, self.up7(u6)], 1))
        u8 = self.conv8(torch.cat([d2, self.up8(u7)], 1))
        return self.conv9(torch.cat([d1, self.up9(u8)], 1))
###### ResUnet new
class ResidualBlock(nn.Module):
    """Two bias-free 3x3 convs (optionally normalized) with an identity skip,
    followed by ReLU: relu(block(x) + x)."""

    def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
        super(ResidualBlock, self).__init__()
        self.relu = nn.ReLU(True)

        def conv():
            return nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False)

        if norm_layer is None:
            # Norm-free variant: conv-relu-conv.
            self.block = nn.Sequential(conv(), nn.ReLU(inplace=True), conv())
        else:
            self.block = nn.Sequential(conv(), norm_layer(in_features),
                                       nn.ReLU(inplace=True), conv(),
                                       norm_layer(in_features))

    def forward(self, x):
        return self.relu(self.block(x) + x)
class Refine_ResUnet_New(nn.Module):
    """Residual U-Net refinement network assembled from nested
    ResUnetSkipConnectionBlock levels (innermost first).

    num_downs levels are built in total; the entry point is `refine`, not
    `forward`.
    """

    def __init__(self, input_nc, output_nc, num_downs=5, ngf=32,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(Refine_ResUnet_New, self).__init__()
        # Innermost bottleneck level.
        block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None,
                                           submodule=None, norm_layer=norm_layer,
                                           innermost=True)
        # Extra constant-width levels when num_downs > 5.
        for _ in range(num_downs - 5):
            block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None,
                                               submodule=block, norm_layer=norm_layer,
                                               use_dropout=use_dropout)
        # Channel-widening levels towards the output resolution.
        for outer in (ngf * 4, ngf * 2, ngf):
            block = ResUnetSkipConnectionBlock(outer, outer * 2, input_nc=None,
                                               submodule=block, norm_layer=norm_layer)
        # Outermost level maps input_nc -> output_nc.
        self.model = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc,
                                                submodule=block, outermost=True,
                                                norm_layer=norm_layer)

    def refine(self, input):
        return self.model(input)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class ResUnetSkipConnectionBlock(nn.Module):
    """One level of the residual U-Net.

    Down path: stride-2 conv (+norm) + ReLU + two ResidualBlocks; then the
    nested `submodule`; up path: nearest 2x upsample + conv (+norm) + ReLU +
    two ResidualBlocks.  Non-outermost levels concatenate their input with
    their output (skip connection), which is why non-innermost up convs take
    inner_nc * 2 input channels.
    """
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(ResUnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # InstanceNorm2d has no affine stats by default here, so keep conv bias.
        use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
                             stride=2, padding=1, bias=use_bias)
        # add two resblock
        res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
        res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]
        downrelu = nn.ReLU(True)
        uprelu = nn.ReLU(True)
        if norm_layer != None:
            downnorm = norm_layer(inner_nc)
            upnorm = norm_layer(outer_nc)
        if outermost:
            # Outermost: no norm/relu/resblocks after the final up conv.
            upsample = nn.Upsample(scale_factor=2, mode='nearest')
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv, downrelu] + res_downconv
            up = [upsample, upconv]
            model = down + [submodule] + up
        elif innermost:
            # Innermost: no submodule; up conv input is just inner_nc.
            upsample = nn.Upsample(scale_factor=2, mode='nearest')
            upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv, downrelu] + res_downconv
            if norm_layer == None:
                up = [upsample, upconv, uprelu] + res_upconv
            else:
                up = [upsample, upconv, upnorm, uprelu] + res_upconv
            model = down + up
        else:
            # Intermediate level: full down/up paths around the submodule.
            upsample = nn.Upsample(scale_factor=2, mode='nearest')
            upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            if norm_layer == None:
                down = [downconv, downrelu] + res_downconv
                up = [upsample, upconv, uprelu] + res_upconv
            else:
                down = [downconv, downnorm, downrelu] + res_downconv
                up = [upsample, upconv, upnorm, uprelu] + res_upconv
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up
        self.model = nn.Sequential(*model)
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # Skip connection: concatenate the input with this level's output.
            return torch.cat([x, self.model(x)], 1)
##################
class GlobalGenerator(nn.Module):
    """Coarse-to-fine global generator with AdaIN-conditioned resnet blocks.

    The main path (`self.model`) is stem -> downsampling -> n_blocks AdaIN
    resnet blocks -> upsampling -> output conv.  On every forward pass the
    AdaIN layers are re-parameterized from a style code produced by
    `enc_style` (fed the reference image plus label-encoder features).
    """
    def __init__(self, input_nc, output_nc, L, S, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        activation = nn.ReLU(True)
        # 7x7 reflection-padded stem.
        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]
        ### resnet blocks
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_type='adain', padding_type=padding_type)]
        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
                                         output_padding=1),
                      norm_layer(int(ngf * mult / 2)), activation]
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        self.model = nn.Sequential(*model)
        # style encoder
        # Its output width is sized to exactly fill every AdaIN layer's params.
        self.enc_style = StyleEncoder(5, S, 16, self.get_num_adain_params(self.model), norm='none', activ='relu',
                                      pad_type='reflect')
        # label encoder
        self.enc_label = LabelEncoder(5, L, 16, 64, norm='none', activ='relu', pad_type='reflect')
    def assign_adain_params(self, adain_params, model):
        # assign the adain_params to the AdaIN layers in model
        # Each AdaIN layer consumes 2*num_features entries (means then stds),
        # sliced off the front in module-iteration order.
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2 * m.num_features]
                # NOTE(review): bias/weight are overwritten with plain tensors
                # (not nn.Parameters) — presumably AdaptiveInstanceNorm2d reads
                # them as buffers; confirm against its definition.
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2 * m.num_features:
                    adain_params = adain_params[:, 2 * m.num_features:]
    def get_num_adain_params(self, model):
        # return the number of AdaIN parameters needed by the model
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2 * m.num_features
        return num_adain_params
    def forward(self, input, input_ref, image_ref):
        # Label features condition the style encoder; the resulting vector is
        # written into the AdaIN layers before the main path runs.
        fea1, fea2 = self.enc_label(input_ref)
        adain_params = self.enc_style((image_ref, fea1, fea2))
        self.assign_adain_params(adain_params, self.model)
        return self.model(input)
class BlendGenerator(nn.Module):
    """Predicts a soft per-pixel blending mask m from two concatenated inputs
    and returns (input1 * m + input2 * (1 - m), m)."""

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert n_blocks >= 0
        super(BlendGenerator, self).__init__()
        relu = nn.ReLU(True)
        # 7x7 reflection-padded stem.
        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                  norm_layer(ngf), relu]
        # Encoder: n_downsampling stride-2 stages.
        channels = ngf
        for _ in range(n_downsampling):
            layers += [nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                       norm_layer(channels * 2), relu]
            channels *= 2
        # Bottleneck resnet blocks (instance norm).
        for _ in range(n_blocks):
            layers += [ResnetBlock(channels, norm_type='in', padding_type=padding_type)]
        # Decoder: mirrored transposed-conv upsampling.
        for _ in range(n_downsampling):
            layers += [nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                                          padding=1, output_padding=1),
                       norm_layer(channels // 2), relu]
            channels //= 2
        # Sigmoid head -> blending mask in [0, 1].
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, input1, input2):
        m = self.model(torch.cat([input1, input2], 1))
        return input1 * m + input2 * (1 - m), m
# Define the Multiscale Discriminator.
class MultiscaleDiscriminator(nn.Module):
    """Runs num_D PatchGAN discriminators on progressively downsampled input."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:
                # expose every stage individually so intermediate features can
                # be tapped for feature-matching losses
                for stage in range(n_layers + 2):
                    setattr(self, 'scale' + str(scale) + '_layer' + str(stage),
                            getattr(netD, 'model' + str(stage)))
            else:
                setattr(self, 'layer' + str(scale), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator; returns per-stage features or just the final map."""
        if not self.getIntermFeat:
            return [model(input)]
        feats = [input]
        for stage in model:
            feats.append(stage(feats[-1]))
        return feats[1:]

    def forward(self, input):
        """Return one result list per scale, coarsest discriminator first."""
        result = []
        scaled = input
        for i in range(self.num_D):
            idx = self.num_D - 1 - i
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(idx) + '_layer' + str(j))
                         for j in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(idx))
            result.append(self.singleD_forward(model, scaled))
            if i != (self.num_D - 1):
                scaled = self.downsample(scaled)
        return result
# Define the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator with a configurable number of stride-2 stages."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))

        # each entry is one "stage"; stages are registered individually when
        # getIntermFeat is set so intermediate activations can be tapped
        stages = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                   nn.LeakyReLU(0.2, True)]]
        nf = ndf
        for n in range(1, n_layers):
            nf_prev, nf = nf, min(nf * 2, 512)
            stages.append([nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                           norm_layer(nf), nn.LeakyReLU(0.2, True)])
        # one stride-1 stage before the 1-channel prediction head
        nf_prev, nf = nf, min(nf * 2, 512)
        stages.append([nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
                       norm_layer(nf), nn.LeakyReLU(0.2, True)])
        stages.append([nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)])
        if use_sigmoid:
            stages.append([nn.Sigmoid()])

        if getIntermFeat:
            for n, stage in enumerate(stages):
                setattr(self, 'model' + str(n), nn.Sequential(*stage))
        else:
            flat = [layer for stage in stages for layer in stage]
            self.model = nn.Sequential(*flat)

    def forward(self, input):
        """Return the patch prediction map, or the list of stage outputs when getIntermFeat."""
        if not self.getIntermFeat:
            return self.model(input)
        feats = [input]
        for n in range(self.n_layers + 2):
            stage = getattr(self, 'model' + str(n))
            feats.append(stage(feats[-1]))
        return feats[1:]
from torchvision import models
class Vgg19(torch.nn.Module):
    """VGG-19 backbone split into five stages whose outputs feed a perceptual loss."""

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg = models.vgg19(pretrained=False)
        self.vgg = vgg
        features = vgg.features
        # stage boundaries inside vgg.features; each stage ends at a ReLU
        bounds = (0, 2, 7, 12, 21, 30)
        for k in range(5):
            stage = torch.nn.Sequential()
            # keep the original feature indices as submodule names so
            # state_dict keys stay unchanged
            for idx in range(bounds[k], bounds[k + 1]):
                stage.add_module(str(idx), features[idx])
            setattr(self, 'slice' + str(k + 1), stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the five intermediate activations for input batch X."""
        outs = []
        h = X
        for k in range(1, 6):
            h = getattr(self, 'slice' + str(k))(h)
            outs.append(h)
        return outs

    def extract(self, x):
        """Full backbone features followed by the VGG average pool."""
        x = self.vgg.features(x)
        x = self.vgg.avgpool(x)
        return x
# Define the MaskVAE
class VAE(nn.Module):
    """Convolutional VAE over masks.

    The encoder applies seven stride-2 convs, so inputs are assumed to be
    512x512 (leaving a 4x4 map that matches fc1/fc2 in_features) — confirm
    against callers.

    Fix over the original: `reparametrize` samples noise with
    torch.randn_like, which follows the device/dtype of its input instead of
    hard-coding torch.cuda.FloatTensor, so the model also runs on CPU.
    """

    def __init__(self, nc, ngf, ndf, latent_variable_size):
        """nc: image channels; ngf/ndf: decoder/encoder base width; latent_variable_size: z dim."""
        super(VAE, self).__init__()
        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf
        self.latent_variable_size = latent_variable_size

        # encoder: seven stride-2 convs, channels ndf -> ndf*64
        self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(ndf)

        self.e2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(ndf * 2)

        self.e3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(ndf * 4)

        self.e4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(ndf * 8)

        self.e5 = nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1)
        self.bn5 = nn.BatchNorm2d(ndf * 16)

        self.e6 = nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1)
        self.bn6 = nn.BatchNorm2d(ndf * 32)

        self.e7 = nn.Conv2d(ndf * 32, ndf * 64, 4, 2, 1)
        self.bn7 = nn.BatchNorm2d(ndf * 64)

        # two heads: mu and logvar (in_features assumes the 4x4 encoder output)
        self.fc1 = nn.Linear(ndf * 64 * 4 * 4, latent_variable_size)
        self.fc2 = nn.Linear(ndf * 64 * 4 * 4, latent_variable_size)

        # decoder: nearest-neighbor upsample + replication pad + 3x3 conv, x7
        self.d1 = nn.Linear(latent_variable_size, ngf * 64 * 4 * 4)

        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd1 = nn.ReplicationPad2d(1)
        self.d2 = nn.Conv2d(ngf * 64, ngf * 32, 3, 1)
        self.bn8 = nn.BatchNorm2d(ngf * 32, 1.e-3)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd2 = nn.ReplicationPad2d(1)
        self.d3 = nn.Conv2d(ngf * 32, ngf * 16, 3, 1)
        self.bn9 = nn.BatchNorm2d(ngf * 16, 1.e-3)

        self.up3 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd3 = nn.ReplicationPad2d(1)
        self.d4 = nn.Conv2d(ngf * 16, ngf * 8, 3, 1)
        self.bn10 = nn.BatchNorm2d(ngf * 8, 1.e-3)

        self.up4 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd4 = nn.ReplicationPad2d(1)
        self.d5 = nn.Conv2d(ngf * 8, ngf * 4, 3, 1)
        self.bn11 = nn.BatchNorm2d(ngf * 4, 1.e-3)

        self.up5 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd5 = nn.ReplicationPad2d(1)
        self.d6 = nn.Conv2d(ngf * 4, ngf * 2, 3, 1)
        self.bn12 = nn.BatchNorm2d(ngf * 2, 1.e-3)

        self.up6 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd6 = nn.ReplicationPad2d(1)
        self.d7 = nn.Conv2d(ngf * 2, ngf, 3, 1)
        self.bn13 = nn.BatchNorm2d(ngf, 1.e-3)

        self.up7 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd7 = nn.ReplicationPad2d(1)
        self.d8 = nn.Conv2d(ngf, nc, 3, 1)

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d((2, 2), (2, 2))  # unused, kept for checkpoint compatibility

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior q(z|x)."""
        h1 = self.leakyrelu(self.bn1(self.e1(x)))
        h2 = self.leakyrelu(self.bn2(self.e2(h1)))
        h3 = self.leakyrelu(self.bn3(self.e3(h2)))
        h4 = self.leakyrelu(self.bn4(self.e4(h3)))
        h5 = self.leakyrelu(self.bn5(self.e5(h4)))
        h6 = self.leakyrelu(self.bn6(self.e6(h5)))
        h7 = self.leakyrelu(self.bn7(self.e7(h6)))
        h7 = h7.view(-1, self.ndf * 64 * 4 * 4)
        return self.fc1(h7), self.fc2(h7)

    def reparametrize(self, mu, logvar):
        """Sample z = mu + std * eps with std = exp(logvar / 2)."""
        std = logvar.mul(0.5).exp_()
        # randn_like follows std's device/dtype (the original hard-coded CUDA noise,
        # which crashed on CPU-only machines)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map a latent code z back to image space (no output nonlinearity)."""
        h1 = self.relu(self.d1(z))
        h1 = h1.view(-1, self.ngf * 64, 4, 4)
        h2 = self.leakyrelu(self.bn8(self.d2(self.pd1(self.up1(h1)))))
        h3 = self.leakyrelu(self.bn9(self.d3(self.pd2(self.up2(h2)))))
        h4 = self.leakyrelu(self.bn10(self.d4(self.pd3(self.up3(h3)))))
        h5 = self.leakyrelu(self.bn11(self.d5(self.pd4(self.up4(h4)))))
        h6 = self.leakyrelu(self.bn12(self.d6(self.pd5(self.up5(h5)))))
        h7 = self.leakyrelu(self.bn13(self.d7(self.pd6(self.up6(h6)))))
        return self.d8(self.pd7(self.up7(h7)))

    def get_latent_var(self, x):
        """Return (z, mu, std); std is exp(logvar / 2), matching reparametrize."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar.mul(0.5).exp_()

    def forward(self, x):
        """Return (reconstruction, x, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, x, mu, logvar
# style encode part
class StyleEncoder(nn.Module):
    """Encodes an image into style (AdaIN) parameters, with two SFT fusion
    points where label-encoder features modulate the style stream."""

    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(StyleEncoder, self).__init__()
        # head: 7x7 stem + two stride-2 stages (channels double each time)
        head = [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        for _ in range(2):
            head.append(ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type))
            dim *= 2
        # middle: remaining downsamples at constant width
        middle = [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)
                  for _ in range(n_downsample - 2)]
        # tail: global average pooling + 1x1 projection to the style vector
        tail = [nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*head)
        self.model_middle = nn.Sequential(*middle)
        self.model_last = nn.Sequential(*tail)
        self.output_dim = dim
        self.sft1 = SFTLayer()
        self.sft2 = SFTLayer()

    def forward(self, x):
        """x = (image, label_fea1, label_fea2); returns the style parameter tensor."""
        image, fea1, fea2 = x
        h = self.sft1((self.model(image), fea1))
        h = self.sft2((self.model_middle(h), fea2))
        return self.model_last(h)
# label encode part
class LabelEncoder(nn.Module):
    """Encodes a label map into two feature tensors consumed by StyleEncoder's SFT layers."""

    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        # NOTE: style_dim is accepted for signature symmetry but unused here
        super(LabelEncoder, self).__init__()
        trunk = [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type),
                 ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        dim *= 2
        # last trunk block has no activation: its raw output is returned as fea1
        trunk.append(ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type))
        dim *= 2
        # tail re-applies the missing ReLU, then keeps downsampling at constant width
        tail = [nn.ReLU()]
        for _ in range(n_downsample - 3):
            tail.append(ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type))
        tail.append(ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type))
        self.model = nn.Sequential(*trunk)
        self.model_last = nn.Sequential(*tail)
        self.output_dim = dim

    def forward(self, x):
        """Return (fea1, fea2): pre-activation trunk output and the tail's output."""
        fea = self.model(x)
        return fea, self.model_last(fea)
# Define the basic block
class ConvBlock(nn.Module):
    """pad -> conv -> optional norm -> optional activation, each selected by name."""

    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(ConvBlock, self).__init__()
        self.use_bias = True

        # padding layer
        pad_layers = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
            'zero': nn.ZeroPad2d,
        }
        assert pad_type in pad_layers, "Unsupported padding type: {}".format(pad_type)
        self.pad = pad_layers[pad_type](padding)

        # normalization layer (lazy factories so unselected classes are never touched)
        norm_factories = {
            'bn': lambda d: nn.BatchNorm2d(d),
            'in': lambda d: nn.InstanceNorm2d(d),
            'ln': lambda d: LayerNorm(d),
            'adain': lambda d: AdaptiveInstanceNorm2d(d),
        }
        if norm in norm_factories:
            self.norm = norm_factories[norm](output_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # activation
        act_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': nn.PReLU,
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': nn.Tanh,
        }
        if activation == 'none':
            self.activation = None
        else:
            assert activation in act_factories, "Unsupported activation: {}".format(activation)
            self.activation = act_factories[activation]()

        # convolution ('sn' wraps it in spectral normalization instead of a norm layer)
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        out = self.conv(self.pad(x))
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
class LinearBlock(nn.Module):
    """fc -> optional norm -> optional activation, mirroring ConvBlock for 1-D features."""

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True

        # fully connected layer (spectrally normalized when norm == 'sn')
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # normalization (lazy factories so unselected classes are never touched)
        norm_factories = {
            'bn': lambda d: nn.BatchNorm1d(d),
            'in': lambda d: nn.InstanceNorm1d(d),
            'ln': lambda d: LayerNorm(d),
        }
        if norm in norm_factories:
            self.norm = norm_factories[norm](output_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # activation
        act_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': nn.PReLU,
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': nn.Tanh,
        }
        if activation == 'none':
            self.activation = None
        else:
            assert activation in act_factories, "Unsupported activation: {}".format(activation)
            self.activation = act_factories[activation]()

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
# Define a resnet block
class ResnetBlock(nn.Module):
    """Two 3x3 ConvBlocks (relu on the first only) with a skip connection."""

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, norm_type, padding_type, use_dropout)

    def build_conv_block(self, dim, norm_type, padding_type, use_dropout):
        # NOTE: use_dropout is accepted but currently unused
        return nn.Sequential(
            ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation='relu', pad_type=padding_type),
            ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation='none', pad_type=padding_type),
        )

    def forward(self, x):
        return x + self.conv_block(x)
class SFTLayer(nn.Module):
    """Spatial feature transform: scales and shifts x[0] using maps derived from x[1]."""

    def __init__(self):
        super(SFTLayer, self).__init__()
        # two small 1x1-conv branches over the condition: one for scale, one for shift
        self.SFT_scale_conv1 = nn.Conv2d(64, 64, 1)
        self.SFT_scale_conv2 = nn.Conv2d(64, 64, 1)
        self.SFT_shift_conv1 = nn.Conv2d(64, 64, 1)
        self.SFT_shift_conv2 = nn.Conv2d(64, 64, 1)

    def forward(self, x):
        """x = (features, condition); returns features * scale(condition) + shift(condition)."""
        features, condition = x
        hidden_scale = F.leaky_relu(self.SFT_scale_conv1(condition), 0.1, inplace=True)
        hidden_shift = F.leaky_relu(self.SFT_shift_conv1(condition), 0.1, inplace=True)
        scale = self.SFT_scale_conv2(hidden_scale)
        shift = self.SFT_shift_conv2(hidden_shift)
        return features * scale + shift
class ConvBlock_SFT(nn.Module):
    """SFT-modulated conv block with a residual connection.

    Bug fix: the original called super(ResnetBlock_SFT, self).__init__() — a
    NameError, since no class of that name exists; super() must name this class.
    """

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        super(ConvBlock_SFT, self).__init__()  # was: super(ResnetBlock_SFT, self)
        self.sft1 = SFTLayer()
        # NOTE(review): the stride-2 ConvBlock halves the spatial size, so the
        # residual add in forward only works for callers that account for it — confirm
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        """x = (features, condition); returns (features + relu(conv(SFT(x))), condition)."""
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        return (x[0] + fea, x[1])
class ConvBlock_SFT_last(nn.Module):
    """Final SFT-modulated conv block; returns only the feature stream.

    Bug fix: the original called super(ResnetBlock_SFT_last, self).__init__() —
    a NameError, since no class of that name exists; super() must name this class.
    """

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        super(ConvBlock_SFT_last, self).__init__()  # was: super(ResnetBlock_SFT_last, self)
        self.sft1 = SFTLayer()
        # NOTE(review): the stride-2 ConvBlock halves the spatial size, so the
        # residual add in forward only works for callers that account for it — confirm
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        """x = (features, condition); returns features + relu(conv(SFT(x)))."""
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        return x[0] + fea
# Definition of normalization layer
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance norm whose affine weight/bias are injected externally (AdaIN).

    `weight` and `bias` must be assigned (flattened to batch * channels) before
    the first forward call; see assign_adain_params.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # affine parameters are supplied at runtime, not learned here
        self.weight = None
        self.bias = None
        # dummy running stats so F.batch_norm has buffers to (harmlessly) update
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        # fold the batch into the channel axis so batch_norm computes
        # per-(sample, channel) statistics, i.e. instance normalization
        flat = x.contiguous().view(1, b * c, *x.size()[2:])
        normalized = F.batch_norm(
            flat,
            self.running_mean.repeat(b),
            self.running_var.repeat(b),
            self.weight, self.bias,
            True, self.momentum, self.eps)
        return normalized.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
class LayerNorm(nn.Module):
    """Per-sample layer normalization over all non-batch dims, with optional affine."""

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        stat_shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # single-sample fast path (noticeably faster than the batched view
            # in old pytorch versions)
            mean = x.view(-1).mean().view(*stat_shape)
            std = x.view(-1).std().view(*stat_shape)
        else:
            flat = x.view(x.size(0), -1)
            mean = flat.mean(1).view(*stat_shape)
            std = flat.std(1).view(*stat_shape)
        normalized = (x - mean) / (std + self.eps)
        if self.affine:
            affine_shape = [1, -1] + [1] * (x.dim() - 2)
            normalized = normalized * self.gamma.view(*affine_shape) + self.beta.view(*affine_shape)
        return normalized
def l2normalize(v, eps=1e-12):
    """Return v scaled to unit L2 norm; eps guards against division by zero."""
    denom = v.norm() + eps
    return v / denom
class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
    and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Wraps `module` and, on every forward call, rescales its `name` parameter by
    an estimate of its largest singular value (obtained via power iteration).
    """

    def __init__(self, module, name='weight', power_iterations=1):
        # module: layer whose parameter is normalized
        # name: parameter attribute to normalize
        # power_iterations: power-iteration steps per forward call
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        # refine the leading singular vector pair (u, v) of the raw weight,
        # then expose name = weight_bar / sigma on the wrapped module
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        # sigma approximates the spectral norm; dividing by it bounds the layer's gain
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        # True when the u/v/bar parameters from a previous wrap already exist
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        # replace the raw parameter with `<name>_bar` and register frozen
        # power-iteration state vectors u (left) and v (right)
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # remove the original parameter so setattr in _update_u_v can install
        # the normalized tensor under the same attribute name
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        # re-normalize the weight, then delegate to the wrapped module
        self._update_u_v()
        return self.module.forward(*args)
### STN TPS
class CNN(nn.Module):
    """Regression CNN used by the STN: maps a 5-channel 256x192 input to num_output values.

    Fixes over the original:
    - the first conv now respects `input_nc` (it was hard-coded to 5 channels);
    - the two 256-channel refinement convs use norm_layer(256) instead of
      norm_layer(64) (harmless with the default InstanceNorm2d, which ignores
      num_features when affine=False, but wrong for e.g. BatchNorm2d).
    """

    def __init__(self, num_output, input_nc=5, ngf=8, n_layers=5, norm_layer=nn.InstanceNorm2d, use_dropout=False):
        # NOTE: use_dropout is accepted but unused; forward always applies F.dropout
        super(CNN, self).__init__()
        # stem: stride-2 conv (input_nc was previously ignored)
        model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
                 nn.ReLU(True), norm_layer(ngf)]
        # downsampling stack; channel count doubles per layer, capped at 1024
        for i in range(n_layers):
            in_ngf = 2 ** i * ngf if 2 ** i * ngf < 1024 else 1024
            out_ngf = 2 ** (i + 1) * ngf if 2 ** i * ngf < 1024 else 1024
            model += [nn.Conv2d(in_ngf, out_ngf, kernel_size=4, stride=2, padding=1),
                      norm_layer(out_ngf), nn.ReLU(True)]
        # two 256-channel refinement convs; 256 assumes the default ngf=8, n_layers=5
        model += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), norm_layer(256), nn.ReLU(True)]
        model += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), norm_layer(256), nn.ReLU(True)]
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.model = nn.Sequential(*model)
        # fc1 in_features = 512 assumes a 256x192 input (256 channels * 2 * 1 after pooling)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_output)

    def forward(self, x):
        """Return a (batch, num_output) regression tensor."""
        x = self.model(x)
        x = self.maxpool(x)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x
class ClsNet(nn.Module):
    """10-way classifier head over the STN regression CNN; returns log-probabilities."""

    def __init__(self):
        super(ClsNet, self).__init__()
        self.cnn = CNN(10)

    def forward(self, x):
        # dim=1 made explicit: F.log_softmax without dim is deprecated and relied
        # on implicit dim inference (which resolves to 1 for the 2-D (batch, 10) output)
        return F.log_softmax(self.cnn(x), dim=1)
class BoundedGridLocNet(nn.Module):
    """Predicts TPS control points bounded to (-1, 1) via tanh, plus grid-smoothness losses.

    Fixes over the original:
    - torch.tanh replaces the deprecated F.tanh;
    - the 0.08 slack constants are created on coor's device instead of being
      hard-coded with .cuda(), so the module also runs on CPU;
    - dead locals (sum/flag/max) in get_row/get_col removed.
    """

    def __init__(self, grid_height, grid_width, target_control_points):
        super(BoundedGridLocNet, self).__init__()
        self.cnn = CNN(grid_height * grid_width * 2)
        # initialize the regressor to output exactly the target grid:
        # zero weights + arctanh(target) bias, so tanh(bias) == target
        bias = torch.from_numpy(np.arctanh(target_control_points.numpy()))
        bias = bias.view(-1)
        self.cnn.fc2.bias.data.copy_(bias)
        self.cnn.fc2.weight.data.zero_()

    def forward(self, x):
        """Return (control points (B, H*W, 2), rx, ry, cx, cy smoothness losses)."""
        batch_size = x.size(0)
        points = torch.tanh(self.cnn(x))
        coor = points.view(batch_size, -1, 2)
        row = self.get_row(coor, 5)
        col = self.get_col(coor, 5)
        # slack floor: second differences below 0.08 are not penalized further
        slack = coor.new_tensor(0.08)
        row_x, row_y = row[:, :, 0], row[:, :, 1]
        col_x, col_y = col[:, :, 0], col[:, :, 1]
        rx_loss = torch.max(slack, row_x).mean()
        ry_loss = torch.max(slack, row_y).mean()
        cx_loss = torch.max(slack, col_x).mean()
        cy_loss = torch.max(slack, col_y).mean()
        return coor, rx_loss, ry_loss, cx_loss, cy_loss

    def get_row(self, coor, num):
        """Second differences of squared gaps along each grid row; shape (B, num*(num-2), 2)."""
        sec_dic = []
        for j in range(num):
            prev = None
            for i in range(num - 1):
                differ = (coor[:, j * num + i + 1, :] - coor[:, j * num + i, :]) ** 2
                # the first gap in a row has no predecessor, so it contributes nothing
                if prev is not None:
                    sec_dic.append(torch.abs(differ - prev))
                prev = differ
        return torch.stack(sec_dic, dim=1)

    def get_col(self, coor, num):
        """Second differences of squared gaps along each grid column; shape (B, num*(num-2), 2)."""
        sec_dic = []
        for i in range(num):
            prev = None
            for j in range(num - 1):
                differ = (coor[:, (j + 1) * num + i, :] - coor[:, j * num + i, :]) ** 2
                # the first gap in a column has no predecessor, so it contributes nothing
                if prev is not None:
                    sec_dic.append(torch.abs(differ - prev))
                prev = differ
        return torch.stack(sec_dic, dim=1)
class UnBoundedGridLocNet(nn.Module):
    """Predicts TPS control points directly (no tanh bounding)."""

    def __init__(self, grid_height, grid_width, target_control_points):
        super(UnBoundedGridLocNet, self).__init__()
        self.cnn = CNN(grid_height * grid_width * 2)
        # zero weights + target bias => the initial output is exactly the target grid
        bias = target_control_points.view(-1)
        self.cnn.fc2.bias.data.copy_(bias)
        self.cnn.fc2.weight.data.zero_()

    def forward(self, x):
        """Return control points of shape (batch, grid_h * grid_w, 2)."""
        predictions = self.cnn(x)
        return predictions.view(x.size(0), -1, 2)
class STNNet(nn.Module):
    """Spatial transformer: regresses bounded TPS control points from `reference`
    and warps x / mask / grid_pic with the resulting grid (fixed 256x192 output)."""

    def __init__(self):
        super(STNNet, self).__init__()
        # NOTE(review): `range` shadows the builtin; it is the half-extent of the control grid
        range = 0.9
        r1 = range
        r2 = range
        grid_size_h = 5
        grid_size_w = 5
        assert r1 < 1 and r2 < 1  # if >= 1, arctanh will cause error in BoundedGridLocNet
        # 5x5 control points evenly spaced over [-r, r] along each axis
        target_control_points = torch.Tensor(list(itertools.product(
            np.arange(-r1, r1 + 0.00001, 2.0 * r1 / (grid_size_h - 1)),
            np.arange(-r2, r2 + 0.00001, 2.0 * r2 / (grid_size_w - 1)),
        )))
        # itertools.product yields (y, x) pairs; swap columns to (x, y) order
        Y, X = target_control_points.split(1, dim=1)
        target_control_points = torch.cat([X, Y], dim=1)
        self.target_control_points=target_control_points
        # self.get_row(target_control_points,5)
        # the bounded variant keeps predicted points inside (-1, 1) via tanh
        GridLocNet = {
            'unbounded_stn': UnBoundedGridLocNet,
            'bounded_stn': BoundedGridLocNet,
        }['bounded_stn']
        self.loc_net = GridLocNet(grid_size_h, grid_size_w, target_control_points)
        self.tps = TPSGridGen(256, 192, target_control_points)

    def get_row(self, coor, num):
        # debug helper: prints the mean second difference of squared gaps per grid row
        # NOTE(review): locals `sum` and `max` shadow builtins; `max` is unused
        for j in range(num):
            sum = 0
            buffer = 0
            flag = False
            max = -1
            for i in range(num - 1):
                differ = (coor[j * num + i + 1, :] - coor[j * num + i, :]) ** 2
                if not flag:
                    # first gap in the row has no predecessor
                    second_dif = 0
                    flag = True
                else:
                    second_dif = torch.abs(differ - buffer)

                buffer = differ
                sum += second_dif
            print(sum / num)

    def get_col(self,coor,num):
        # debug helper: prints the total second difference of squared gaps per grid column
        for i in range(num):
            sum = 0
            buffer = 0
            flag = False
            max = -1
            for j in range(num - 1):
                differ = (coor[ (j + 1) * num + i, :] - coor[j * num + i, :]) ** 2
                if not flag:
                    # first gap in the column has no predecessor
                    second_dif = 0
                    flag = True
                else:
                    second_dif = torch.abs(differ-buffer)

                buffer = differ
                sum += second_dif
            print(sum)

    def forward(self, x, reference, mask,grid_pic):
        """Warp x, mask and grid_pic by the TPS grid regressed from `reference`.

        Returns (warped_x, warped_mask, rx, ry, cx, cy, warped_grid_pic), where
        rx/ry/cx/cy are the loc_net smoothness losses.
        """
        batch_size = x.size(0)
        source_control_points,rx,ry,cx,cy = self.loc_net(reference)
        source_control_points=(source_control_points)
        # print('control points',source_control_points.shape)
        # TPS turns the control points into a dense (256*192, 2) sampling grid
        source_coordinate = self.tps(source_control_points)
        grid = source_coordinate.view(batch_size, 256, 192, 2)
        # print('grid size',grid.shape)
        transformed_x = grid_sample(x, grid, canvas=0)
        warped_mask = grid_sample(mask, grid, canvas=0)
        warped_gpic= grid_sample(grid_pic, grid, canvas=0)
        return transformed_x, warped_mask,rx,ry,cx,cy,warped_gpic