Spaces:
Runtime error
Runtime error
import torch | |
import os | |
import torch.nn as nn | |
import functools | |
from torch.autograd import Variable | |
import numpy as np | |
import torch.nn.functional as F | |
import math | |
import torch | |
import itertools | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from grid_sample import grid_sample | |
from torch.autograd import Variable | |
from tps_grid_gen import TPSGridGen | |
############################################################################### | |
# Functions | |
############################################################################### | |
def weights_init(m):
    """Re-initialise a module in place according to its class name.

    Intended for use with ``module.apply(weights_init)``.

    * class name contains ``'Conv2d'``: weight drawn from N(0.0, 0.02)
    * class name contains ``'BatchNorm2d'``: weight drawn from N(1.0, 0.02),
      bias filled with 0
    * anything else is left untouched
    """
    name = type(m).__name__
    if 'Conv2d' in name:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm2d' in name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    """Return a normalisation-layer factory for the given name.

    Args:
        norm_type: ``'batch'`` -> ``nn.BatchNorm2d`` with ``affine=True``;
            ``'instance'`` -> ``nn.InstanceNorm2d`` with ``affine=False``.

    Returns:
        A ``functools.partial`` that still expects the channel count.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def define_G(input_nc, output_nc, ngf, netG, L=1, S=1, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[]):
    """Build a generator, print it, optionally move it to a GPU and initialise it.

    Args:
        input_nc, output_nc, ngf: forwarded to the generator constructor.
        netG: ``'global'`` builds a ``GlobalGenerator``; ``'local'`` builds a
            ``LocalEnhancer``.
        L, S: forwarded to ``GlobalGenerator`` only.
        n_downsample_global, n_blocks_global: forwarded to both generators.
        n_local_enhancers, n_blocks_local: forwarded to ``LocalEnhancer`` only.
        norm: name understood by ``get_norm_layer``.
        gpu_ids: when non-empty, the network is moved to ``gpu_ids[0]``.

    Returns:
        The constructed network after ``weights_init`` has been applied.

    Raises:
        NotImplementedError: if ``netG`` (or ``norm``) is not recognised.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(input_nc, output_nc, L, S, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                             n_local_enhancers, n_blocks_local, norm_layer)
    else:
        # Raising a bare string is a TypeError in Python 3; raise a real
        # exception so callers see the intended message.
        raise NotImplementedError('generator not implemented!')
    print(netG)
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_Unet(input_nc, gpu_ids=[]):
    """Build a ``Unet``, optionally move it to a GPU and apply ``weights_init``.

    Args:
        input_nc: forwarded to ``Unet``.
        gpu_ids: when non-empty, the network is moved to ``gpu_ids[0]``.

    Returns:
        The initialised network.
    """
    netG = Unet(input_nc)
    # The default gpu_ids=[] used to raise IndexError on gpu_ids[0]; only
    # move to a device when one was actually requested.
    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_UnetMask(input_nc, gpu_ids=[]):
    """Build a ``UnetMask`` with ``output_nc=4``, optionally move it to a GPU and initialise it.

    Args:
        input_nc: forwarded to ``UnetMask``.
        gpu_ids: when non-empty, the network is moved to ``gpu_ids[0]``.

    Returns:
        The network after ``weights_init`` has been applied.
    """
    netG = UnetMask(input_nc, output_nc=4)
    # The default gpu_ids=[] used to raise IndexError on gpu_ids[0]; only
    # move to a device when one was actually requested.
    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_Refine(input_nc, output_nc, gpu_ids=[]):
    """Build a ``Refine`` network, optionally move it to a GPU and initialise it.

    Args:
        input_nc, output_nc: forwarded to ``Refine``.
        gpu_ids: when non-empty, the network is moved to ``gpu_ids[0]``.

    Returns:
        The network after ``weights_init`` has been applied.
    """
    netG = Refine(input_nc, output_nc)
    # The default gpu_ids=[] used to raise IndexError on gpu_ids[0]; only
    # move to a device when one was actually requested.
    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
#################################################### | |
def define_Refine_ResUnet(input_nc, output_nc, gpu_ids=[]):
    """Build a ``Refine_ResUnet_New``, optionally move it to a GPU and initialise it.

    The network is constructed with its default ``norm_layer``.

    Args:
        input_nc, output_nc: forwarded to ``Refine_ResUnet_New``.
        gpu_ids: when non-empty, the network is moved to ``gpu_ids[0]``.

    Returns:
        The network after ``weights_init`` has been applied.
    """
    netG = Refine_ResUnet_New(input_nc, output_nc)
    # The default gpu_ids=[] used to raise IndexError on gpu_ids[0]; only
    # move to a device when one was actually requested.
    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
#################################################### | |
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    """Build a ``MultiscaleDiscriminator``, print it, and initialise its weights.

    Args:
        input_nc, ndf, n_layers_D, use_sigmoid, num_D, getIntermFeat:
            forwarded to ``MultiscaleDiscriminator``.
        norm: name understood by ``get_norm_layer``.
        gpu_ids: when non-empty, CUDA availability is asserted and the
            network is moved to ``gpu_ids[0]``.

    Returns:
        The discriminator after ``weights_init`` has been applied.
    """
    discriminator = MultiscaleDiscriminator(
        input_nc, ndf, n_layers_D, get_norm_layer(norm_type=norm), use_sigmoid, num_D, getIntermFeat)
    print(discriminator)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert (torch.cuda.is_available())
        discriminator.cuda(gpu_ids[0])
    discriminator.apply(weights_init)
    return discriminator
def define_VAE(input_nc, gpu_ids=[]):
    """Build a ``VAE(19, 32, 32, 1024)``, print it, and optionally move it to a GPU.

    ``input_nc`` is accepted but not used: the constructor arguments are
    fixed. No weight initialisation is applied here.

    Args:
        input_nc: unused.
        gpu_ids: when non-empty, CUDA availability is asserted and the
            network is moved to ``gpu_ids[0]``.

    Returns:
        The constructed network.
    """
    vae = VAE(19, 32, 32, 1024)
    print(vae)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert (torch.cuda.is_available())
        vae.cuda(gpu_ids[0])
    return vae
def define_B(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=3, norm='instance', gpu_ids=[]):
    """Build a ``BlendGenerator``, print it, and initialise its weights.

    Args:
        input_nc, output_nc, ngf, n_downsample_global, n_blocks_global:
            forwarded to ``BlendGenerator``.
        norm: name understood by ``get_norm_layer``.
        gpu_ids: when non-empty, CUDA availability is asserted and the
            network is moved to ``gpu_ids[0]``.

    Returns:
        The network after ``weights_init`` has been applied.
    """
    blender = BlendGenerator(
        input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, get_norm_layer(norm_type=norm))
    print(blender)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert (torch.cuda.is_available())
        blender.cuda(gpu_ids[0])
    blender.apply(weights_init)
    return blender
def define_partial_enc(input_nc, gpu_ids=[]):
    """Build a ``PartialConvEncoder``, print it, and initialise its weights.

    Args:
        input_nc: forwarded to ``PartialConvEncoder``.
        gpu_ids: when non-empty, CUDA availability is asserted and the
            network is moved to ``gpu_ids[0]``.

    Returns:
        The encoder after ``weights_init`` has been applied.
    """
    encoder = PartialConvEncoder(input_nc)
    print(encoder)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert (torch.cuda.is_available())
        encoder.cuda(gpu_ids[0])
    encoder.apply(weights_init)
    return encoder
def define_conv_enc(input_nc, gpu_ids=[]):
    """Build a ``ConvEncoder``, print it, and initialise its weights.

    Args:
        input_nc: forwarded to ``ConvEncoder``.
        gpu_ids: when non-empty, CUDA availability is asserted and the
            network is moved to ``gpu_ids[0]``.

    Returns:
        The encoder after ``weights_init`` has been applied.
    """
    encoder = ConvEncoder(input_nc)
    print(encoder)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert (torch.cuda.is_available())
        encoder.cuda(gpu_ids[0])
    encoder.apply(weights_init)
    return encoder
def define_AttG(output_nc, gpu_ids=[]):
    """Build an ``AttGenerator``, print it, and initialise its weights.

    Args:
        output_nc: forwarded to ``AttGenerator``.
        gpu_ids: when non-empty, CUDA availability is asserted and the
            network is moved to ``gpu_ids[0]``.

    Returns:
        The generator after ``weights_init`` has been applied.
    """
    generator = AttGenerator(output_nc)
    print(generator)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert (torch.cuda.is_available())
        generator.cuda(gpu_ids[0])
    generator.apply(weights_init)
    return generator
def print_network(net):
    """Print a network followed by its total parameter count.

    Args:
        net: a module, or a list whose first element is the module to report.
    """
    target = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in target.parameters())
    print(target)
    print('Total number of parameters: %d' % total)
############################################################################## | |
# Losses | |
############################################################################## | |
class GANLoss(nn.Module):
    """Adversarial loss against a constant real/fake target.

    Uses ``nn.MSELoss`` when ``use_lsgan`` is true and ``nn.BCELoss``
    otherwise. Target tensors are built with the ``tensor`` constructor,
    filled with the real or fake label, and cached between calls.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached targets; rebuilt whenever the prediction size changes.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def _target_like(self, cached, input, label):
        """Return ``cached`` if it matches ``input``'s size, else a fresh constant tensor."""
        # Compare full sizes, not element counts: two predictions with the
        # same numel but different shapes must not share a target.
        if cached is None or cached.size() != input.size():
            cached = self.Tensor(input.size()).fill_(label)
        return cached

    def get_target_tensor(self, input, target_is_real):
        """Return a tensor of ``input``'s size filled with the real or fake label."""
        if target_is_real:
            self.real_label_var = self._target_like(self.real_label_var, input, self.real_label)
            return self.real_label_var
        self.fake_label_var = self._target_like(self.fake_label_var, input, self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Compute the loss.

        Args:
            input: either a sequence whose last element is the prediction, or
                a sequence of lists (one per discriminator) whose last
                elements are the predictions; in the latter case the
                per-discriminator losses are summed.
            target_is_real: selects the real or the fake label as target.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                loss += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return loss
        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))
class VGGLossWarp(nn.Module):
    """L1 distance between the fifth (index 4) ``Vgg19`` feature maps of two inputs.

    ``gpu_ids`` is accepted but not used; the feature extractor is always
    moved with ``.cuda()``.
    """

    def __init__(self, gpu_ids):
        super(VGGLossWarp, self).__init__()
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return ``weights[4] * L1(vgg(x)[4], vgg(y)[4])``; the ``y`` features are detached."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        level = 4
        distance = self.criterion(feats_x[level], feats_y[level].detach())
        return self.weights[level] * distance
class VGGLoss(nn.Module):
    """Weighted L1 distance between ``Vgg19`` feature maps of two inputs.

    ``gpu_ids`` is accepted but not used; the feature extractor is always
    moved with ``.cuda()``.
    """

    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Sum ``weights[i] * L1(vgg(x)[i], vgg(y)[i])`` over every feature level of ``x``."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for level, feat_x in enumerate(feats_x):
            total += self.weights[level] * self.criterion(feat_x, feats_y[level].detach())
        return total

    def warp(self, x, y):
        """Same as ``forward`` but using only the feature level at index 4."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        level = 4
        distance = self.criterion(feats_x[level], feats_y[level].detach())
        return self.weights[level] * distance
class StyleLoss(nn.Module):
    """Gram-matrix distance between ``Vgg19`` feature maps of two inputs.

    ``gpu_ids`` is accepted but not used; the feature extractor is always
    moved with ``.cuda()``.
    """

    def __init__(self, gpu_ids):
        super(StyleLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Accumulate, per feature level and per sample, the weighted RMS
        difference between the Gram matrices of ``x`` and ``y`` features.

        Each Gram matrix is ``phi @ phi.T / (C * H * W)`` where ``phi`` is the
        sample's feature map reshaped to ``(C, H * W)``.
        """
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for level, feat_x in enumerate(feats_x):
            batch, chans, height, width = feat_x.shape
            norm = chans * height * width
            for sample in range(batch):
                flat_x = feat_x[sample].reshape(chans, height * width)
                flat_y = feats_y[level][sample].reshape(chans, height * width)
                gram_x = torch.matmul(flat_x, flat_x.t()) / norm
                gram_y = torch.matmul(flat_y, flat_y.t()) / norm
                rms = torch.sqrt(torch.mean((gram_x - gram_y) ** 2))
                total += rms * self.weights[level]
        return total
############################################################################## | |
# Generator | |
############################################################################## | |
class PartialConvEncoder(nn.Module):
    """Encoder made of one 7x7 ``PartialConv`` stem and four stride-2 ``PartialConv`` stages.

    Channel widths go ``input_nc -> ngf -> 2*ngf -> 4*ngf -> 8*ngf -> 16*ngf``;
    every convolution is followed by ``norm_layer`` and a shared ReLU.
    """

    def __init__(self, input_nc, ngf=32, norm_layer=nn.BatchNorm2d):
        super(PartialConvEncoder, self).__init__()
        widths = [ngf * (2 ** k) for k in range(5)]

        # Stem
        self.pad1 = nn.ReflectionPad2d(3)
        self.partial_conv1 = PartialConv(input_nc, widths[0], kernel_size=7)
        self.norm_layer1 = norm_layer(widths[0])
        self.activation = nn.ReLU(True)

        # Down-sampling stages
        self.down1 = PartialConv(widths[0], widths[1], kernel_size=3, stride=2, padding=1)
        self.norm_layer2 = norm_layer(widths[1])
        self.down2 = PartialConv(widths[1], widths[2], kernel_size=3, stride=2, padding=1)
        self.norm_layer3 = norm_layer(widths[2])
        self.down3 = PartialConv(widths[2], widths[3], kernel_size=3, stride=2, padding=1)
        self.norm_layer4 = norm_layer(widths[3])
        self.down4 = PartialConv(widths[3], widths[4], kernel_size=3, stride=2, padding=1)
        self.norm_layer5 = norm_layer(widths[4])

    def forward(self, input, mask):
        """Encode ``input`` under ``mask``; only the features are returned.

        Both tensors are reflection-padded with the same layer before the
        stem; each stage then threads the updated mask into the next one.
        """
        feat = self.pad1(input)
        valid = self.pad1(mask)
        stages = (
            (self.partial_conv1, self.norm_layer1),
            (self.down1, self.norm_layer2),
            (self.down2, self.norm_layer3),
            (self.down3, self.norm_layer4),
            (self.down4, self.norm_layer5),
        )
        for conv, norm in stages:
            feat, valid = conv(feat, valid)
            feat = self.activation(norm(feat))
        return feat
class ConvEncoder(nn.Module):
    """Plain convolutional encoder: a 7x7 stem followed by stride-2 3x3 stages.

    Each stage doubles the channel count, starting from ``ngf``. The
    ``n_blocks`` and ``padding_type`` arguments are accepted but not used.
    """

    def __init__(self, input_nc, ngf=32, n_downsampling=4, n_blocks=4, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        super(ConvEncoder, self).__init__()
        # One in-place ReLU instance is shared by every stage.
        relu = nn.ReLU(True)
        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), relu]
        width = ngf
        for _ in range(n_downsampling):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(width * 2),
                relu,
            ])
            width = width * 2
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the sequential encoder."""
        return self.model(input)
class AttGenerator(nn.Module):
    """Decoder: ``ResnetBlock`` bottleneck, transposed-conv up-sampling stages
    interleaved with ``AttentionNorm`` layers, and a 7x7 Tanh output head.

    NOTE: four ``AttentionNorm`` layers are built from ``out_channels[0..3]``
    and ``forward`` always runs four up-sampling stages, so the constructor
    needs ``n_downsampling >= 4``.
    """

    def __init__(self, output_nc, ngf=32, n_blocks=4, n_downsampling=4, padding_type='reflect'):
        super(AttGenerator, self).__init__()
        bottleneck_nc = ngf * (2 ** n_downsampling) * 2
        self.model = nn.Sequential(*[
            ResnetBlock(bottleneck_nc, norm_type='in', padding_type=padding_type)
            for _ in range(n_blocks)
        ])

        # Up-sampling stages: each one halves the channel count.
        norm_layer = nn.BatchNorm2d
        activation = nn.ReLU(True)
        up_stages = []
        self.out_channels = []
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            in_nc = ngf * mult * 2
            out_nc = int(ngf * mult / 2) * 2
            up_stages.append(nn.Sequential(
                nn.ConvTranspose2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1,
                                   output_padding=1),
                norm_layer(out_nc),
                activation,
            ))
            self.out_channels.append(out_nc)
        self.upsampling = nn.Sequential(*up_stages)

        # (first_rate, second_rate) of the attention layer after each stage.
        rates = ((2, 4), (2, 2), (1, 2), (1, 1))
        self.AttNorm = nn.Sequential(*[
            AttentionNorm(5, self.out_channels[idx], first, second)
            for idx, (first, second) in enumerate(rates)
        ])

        self.last_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf * 2, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        )

    def forward(self, input, unattended):
        """Decode ``unattended`` with ``input`` as the attention reference.

        The attention layer is applied after the first three up-sampling
        stages only; the fourth stage feeds the output head directly.
        """
        up = self.model(unattended)
        final_stage = 3
        for i in range(4):
            up = self.upsampling[i](up)
            if i != final_stage:
                up = self.AttNorm[i](input, up)
        return self.last_conv(up)
class PartialConv(nn.Module):
    """Partial convolution: a convolution renormalised by the mask coverage of each window.

    ``input_conv`` is the learnable convolution. ``mask_conv`` has the same
    geometry, no bias, all-ones weights and frozen parameters, so its output
    is the sum of mask values under each window.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(PartialConv, self).__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)
        self.input_conv.apply(weights_init)
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        # The mask convolution is a fixed window-sum; it is never trained.
        for frozen in self.mask_conv.parameters():
            frozen.requires_grad = False

    def forward(self, input, mask):
        """Apply the partial convolution.

        Returns:
            ``(output, new_mask)``: ``output`` is zero wherever the window saw
            no mask coverage; ``new_mask`` is 0 at those positions and 1
            elsewhere.
        """
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T * (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
        features = self.input_conv(input * mask)
        bias = self.input_conv.bias
        if bias is None:
            bias_map = torch.zeros_like(features)
        else:
            bias_map = bias.view(1, -1, 1, 1).expand_as(features)

        with torch.no_grad():
            coverage = self.mask_conv(mask)
        holes = coverage == 0
        # Avoid dividing by zero at fully-masked windows; they are zeroed below.
        coverage = coverage.masked_fill(holes, 1.0)

        renormalised = (features - bias_map) / coverage + bias_map
        output = renormalised.masked_fill(holes, 0.0)
        new_mask = torch.ones_like(output).masked_fill(holes, 0.0)
        return output, new_mask
class AttentionNorm(nn.Module):
    """Normalise a feature map and modulate it with a gate and bias predicted from a reference.

    ``first_rate`` and ``second_rate`` (each 1, 2 or 4) select which of the
    stride-1/2/4 convolutions is used at the first and second stage. All
    nine candidate convolutions are created regardless of the selection.
    """

    def __init__(self, ref_channels, out_channels, first_rate, second_rate):
        super(AttentionNorm, self).__init__()
        self.first = first_rate
        self.second = second_rate
        mid_channels = int(out_channels / 2)

        def conv3x3(in_nc, out_nc, stride):
            return nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=stride, padding=1)

        # First stage: reference -> intermediate features.
        self.conv_1time_f = conv3x3(ref_channels, mid_channels, 1)
        self.conv_2times_f = conv3x3(ref_channels, mid_channels, 2)
        self.conv_4times_f = conv3x3(ref_channels, mid_channels, 4)
        # Second stage, bias branch.
        self.conv_1time_s = conv3x3(mid_channels, out_channels, 1)
        self.conv_2times_s = conv3x3(mid_channels, out_channels, 2)
        self.conv_4times_s = conv3x3(mid_channels, out_channels, 4)
        # Second stage, gate branch.
        self.conv_1time_m = conv3x3(mid_channels, out_channels, 1)
        self.conv_2times_m = conv3x3(mid_channels, out_channels, 2)
        self.conv_4times_m = conv3x3(mid_channels, out_channels, 4)
        self.norm = nn.BatchNorm2d(out_channels)
        self.conv = conv3x3(out_channels, out_channels, 1)

    def forward(self, input, unattended):
        """Return ``unattended + conv(relu(norm(unattended) * sigmoid(gate) + bias))``.

        ``gate`` and ``bias`` are computed from ``input`` through the
        convolutions selected by ``self.first`` and ``self.second``.
        """
        ref = input
        for rate, reducer in ((1, self.conv_1time_f), (2, self.conv_2times_f), (4, self.conv_4times_f)):
            if self.first == rate:
                ref = reducer(input)
                break

        gate = None
        second_stage = (
            (1, self.conv_1time_s, self.conv_1time_m),
            (2, self.conv_2times_s, self.conv_2times_m),
            (4, self.conv_4times_s, self.conv_4times_m),
        )
        for rate, to_shift, to_gate in second_stage:
            if self.second == rate:
                shift = to_shift(ref)
                gate = to_gate(ref)
                break
        gate = torch.sigmoid(gate)

        modulated = self.norm(unattended) * gate + shift
        refined = self.conv(torch.relu(modulated))
        return refined + unattended
class UnetMask(nn.Module):
    """``STNNet`` warp followed by a five-level U-Net with InstanceNorm and skip connections.

    Channel widths are 64-128-256-512-1024 on the way down and mirrored on
    the way up; dropout (p=0.5) follows the fourth and fifth levels.
    """

    def __init__(self, input_nc, output_nc=3):
        super(UnetMask, self).__init__()
        self.stn = STNNet()
        window = (2, 2)

        # Contracting path
        self.conv1 = self._conv_pair(input_nc, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=window)
        self.conv2 = self._conv_pair(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=window)
        self.conv3 = self._conv_pair(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=window)
        self.conv4 = self._conv_pair(256, 512)
        self.drop4 = nn.Dropout(0.5)
        self.pool4 = nn.MaxPool2d(kernel_size=window)
        self.conv5 = self._conv_pair(512, 1024)
        self.drop5 = nn.Dropout(0.5)

        # Expanding path (each conv block consumes skip + up-sampled features)
        self.up6 = self._upsample_conv(1024, 512)
        self.conv6 = self._conv_pair(1024, 512)
        self.up7 = self._upsample_conv(512, 256)
        self.conv7 = self._conv_pair(512, 256)
        self.up8 = self._upsample_conv(256, 128)
        self.conv8 = self._conv_pair(256, 128)
        self.up9 = self._upsample_conv(128, 64)
        self.conv9 = self._conv_pair(128, 64, final_nc=output_nc)

    @staticmethod
    def _conv_pair(in_nc, out_nc, final_nc=None):
        """Two 3x3 conv + InstanceNorm + ReLU units, plus an optional bare 3x3 output conv."""
        layers = [
            nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(out_nc), nn.ReLU(),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(out_nc), nn.ReLU(),
        ]
        if final_nc is not None:
            layers.append(nn.Conv2d(out_nc, final_nc, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample_conv(in_nc, out_nc):
        """Nearest-neighbour x2 up-sampling followed by 3x3 conv + InstanceNorm + ReLU."""
        return nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_nc),
            nn.ReLU(),
        )

    def _unet_pass(self, x):
        """Run the encoder/decoder on ``x`` and return the output of ``conv9``."""
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.pool1(enc1))
        enc3 = self.conv3(self.pool2(enc2))
        enc4 = self.drop4(self.conv4(self.pool3(enc3)))
        bottom = self.drop5(self.conv5(self.pool4(enc4)))
        dec = self.conv6(torch.cat([enc4, self.up6(bottom)], 1))
        dec = self.conv7(torch.cat([enc3, self.up7(dec)], 1))
        dec = self.conv8(torch.cat([enc2, self.up8(dec)], 1))
        return self.conv9(torch.cat([enc1, self.up9(dec)], 1))

    def forward(self, input, refer, mask, grid):
        """Warp ``input`` with the STN, then run the U-Net on ``[refer, warped]``.

        The STN is called with ``(input, cat([mask, refer, input]), mask, grid)``
        and seven values are unpacked from its result. Both U-Net inputs are
        detached before concatenation.

        Returns:
            ``(unet_output, warped_input, warped_mask, grid)`` where the last
            three come from the STN.
        """
        stn_in = torch.cat([mask, refer, input], 1)
        input, warped_mask, rx, ry, cx, cy, grid = self.stn(input, stn_in, mask, grid)
        output = self._unet_pass(torch.cat([refer.detach(), input.detach()], 1))
        return output, input, warped_mask, grid
class Unet(nn.Module):
    """``STNNet`` warp followed by a five-level U-Net with InstanceNorm and skip connections.

    Channel widths are 64-128-256-512-1024 on the way down and mirrored on
    the way up; dropout (p=0.5) follows the fourth and fifth levels.
    ``refine`` runs the U-Net alone, without the STN.
    """

    def __init__(self, input_nc, output_nc=3):
        super(Unet, self).__init__()
        self.stn = STNNet()
        window = (2, 2)

        # Contracting path
        self.conv1 = self._conv_pair(input_nc, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=window)
        self.conv2 = self._conv_pair(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=window)
        self.conv3 = self._conv_pair(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=window)
        self.conv4 = self._conv_pair(256, 512)
        self.drop4 = nn.Dropout(0.5)
        self.pool4 = nn.MaxPool2d(kernel_size=window)
        self.conv5 = self._conv_pair(512, 1024)
        self.drop5 = nn.Dropout(0.5)

        # Expanding path (each conv block consumes skip + up-sampled features)
        self.up6 = self._upsample_conv(1024, 512)
        self.conv6 = self._conv_pair(1024, 512)
        self.up7 = self._upsample_conv(512, 256)
        self.conv7 = self._conv_pair(512, 256)
        self.up8 = self._upsample_conv(256, 128)
        self.conv8 = self._conv_pair(256, 128)
        self.up9 = self._upsample_conv(128, 64)
        self.conv9 = self._conv_pair(128, 64, final_nc=output_nc)

    @staticmethod
    def _conv_pair(in_nc, out_nc, final_nc=None):
        """Two 3x3 conv + InstanceNorm + ReLU units, plus an optional bare 3x3 output conv."""
        layers = [
            nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(out_nc), nn.ReLU(),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(out_nc), nn.ReLU(),
        ]
        if final_nc is not None:
            layers.append(nn.Conv2d(out_nc, final_nc, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample_conv(in_nc, out_nc):
        """Nearest-neighbour x2 up-sampling followed by 3x3 conv + InstanceNorm + ReLU."""
        return nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_nc),
            nn.ReLU(),
        )

    def _unet_pass(self, x):
        """Run the encoder/decoder on ``x`` and return the output of ``conv9``."""
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.pool1(enc1))
        enc3 = self.conv3(self.pool2(enc2))
        enc4 = self.drop4(self.conv4(self.pool3(enc3)))
        bottom = self.drop5(self.conv5(self.pool4(enc4)))
        dec = self.conv6(torch.cat([enc4, self.up6(bottom)], 1))
        dec = self.conv7(torch.cat([enc3, self.up7(dec)], 1))
        dec = self.conv8(torch.cat([enc2, self.up8(dec)], 1))
        return self.conv9(torch.cat([enc1, self.up9(dec)], 1))

    def forward(self, input, refer, mask):
        """Warp ``input`` with the STN, then run the U-Net on ``[refer, warped]``.

        The STN is called with ``(input, cat([mask, refer, input]), mask)`` and
        six values are unpacked from its result. Both U-Net inputs are
        detached before concatenation.

        Returns:
            ``(unet_output, warped_input, warped_mask, rx, ry, cx, cy)`` where
            everything after the first element comes from the STN.
        """
        stn_in = torch.cat([mask, refer, input], 1)
        input, warped_mask, rx, ry, cx, cy = self.stn(input, stn_in, mask)
        output = self._unet_pass(torch.cat([refer.detach(), input.detach()], 1))
        return output, input, warped_mask, rx, ry, cx, cy

    def refine(self, input):
        """Run only the U-Net (no STN) on ``input``."""
        return self._unet_pass(input)
class Refine(nn.Module):
    """Five-level U-Net with InstanceNorm and skip connections, exposed through ``refine``.

    Channel widths are 64-128-256-512-1024 on the way down and mirrored on
    the way up; dropout (p=0.5) follows the fourth and fifth levels. The
    class defines no ``forward``; callers use ``refine`` directly.
    """

    def __init__(self, input_nc, output_nc=3):
        super(Refine, self).__init__()
        window = (2, 2)

        # Contracting path
        self.conv1 = self._conv_pair(input_nc, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=window)
        self.conv2 = self._conv_pair(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=window)
        self.conv3 = self._conv_pair(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=window)
        self.conv4 = self._conv_pair(256, 512)
        self.drop4 = nn.Dropout(0.5)
        self.pool4 = nn.MaxPool2d(kernel_size=window)
        self.conv5 = self._conv_pair(512, 1024)
        self.drop5 = nn.Dropout(0.5)

        # Expanding path (each conv block consumes skip + up-sampled features)
        self.up6 = self._upsample_conv(1024, 512)
        self.conv6 = self._conv_pair(1024, 512)
        self.up7 = self._upsample_conv(512, 256)
        self.conv7 = self._conv_pair(512, 256)
        self.up8 = self._upsample_conv(256, 128)
        self.conv8 = self._conv_pair(256, 128)
        self.up9 = self._upsample_conv(128, 64)
        self.conv9 = self._conv_pair(128, 64, final_nc=output_nc)

    @staticmethod
    def _conv_pair(in_nc, out_nc, final_nc=None):
        """Two 3x3 conv + InstanceNorm + ReLU units, plus an optional bare 3x3 output conv."""
        layers = [
            nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(out_nc), nn.ReLU(),
            nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(out_nc), nn.ReLU(),
        ]
        if final_nc is not None:
            layers.append(nn.Conv2d(out_nc, final_nc, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample_conv(in_nc, out_nc):
        """Nearest-neighbour x2 up-sampling followed by 3x3 conv + InstanceNorm + ReLU."""
        return nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_nc),
            nn.ReLU(),
        )

    def refine(self, input):
        """Run the encoder/decoder on ``input`` and return the output of ``conv9``."""
        enc1 = self.conv1(input)
        enc2 = self.conv2(self.pool1(enc1))
        enc3 = self.conv3(self.pool2(enc2))
        enc4 = self.drop4(self.conv4(self.pool3(enc3)))
        bottom = self.drop5(self.conv5(self.pool4(enc4)))
        dec = self.conv6(torch.cat([enc4, self.up6(bottom)], 1))
        dec = self.conv7(torch.cat([enc3, self.up7(dec)], 1))
        dec = self.conv8(torch.cat([enc2, self.up8(dec)], 1))
        return self.conv9(torch.cat([enc1, self.up9(dec)], 1))
###### ResUnet new | |
class ResidualBlock(nn.Module):
    """Two bias-free 3x3 convolutions with an identity shortcut and a final ReLU.

    When ``norm_layer`` is ``None`` the block is conv-ReLU-conv; otherwise a
    normalisation layer follows each convolution.
    """

    def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
        super(ResidualBlock, self).__init__()
        self.relu = nn.ReLU(True)
        use_norm = norm_layer is not None

        layers = [nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False)]
        if use_norm:
            layers.append(norm_layer(in_features))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False))
        if use_norm:
            layers.append(norm_layer(in_features))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``relu(block(x) + x)``."""
        out = self.block(x)
        out += x
        return self.relu(out)
class Refine_ResUnet_New(nn.Module):
    """U-Net assembled from nested ``ResUnetSkipConnectionBlock`` modules.

    The network is built from the innermost block outwards. With the default
    ``num_downs=5`` there are five levels; every unit above five adds one
    more ``8*ngf -> 8*ngf`` level (the only levels that receive
    ``use_dropout``). The class defines no ``forward``; callers use ``refine``.
    """

    def __init__(self, input_nc, output_nc, num_downs=5, ngf=32,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(Refine_ResUnet_New, self).__init__()
        widest = ngf * 8
        # Innermost level
        block = ResUnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None,
                                           norm_layer=norm_layer, innermost=True)
        # Optional extra levels at the widest resolution
        for _ in range(num_downs - 5):
            block = ResUnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=block,
                                               norm_layer=norm_layer, use_dropout=use_dropout)
        # Progressively narrower levels, expressed as (outer, inner) multiples of ngf
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = ResUnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None,
                                               submodule=block, norm_layer=norm_layer)
        # Outermost level maps input_nc -> output_nc
        self.model = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                                outermost=True, norm_layer=norm_layer)

    def refine(self, input):
        """Run ``input`` through the nested U-Net."""
        return self.model(input)
# Defines the submodule with skip connection. | |
# X -------------------identity---------------------- X | |
# |-- downsampling -- |submodule| -- upsampling --| | |
class ResUnetSkipConnectionBlock(nn.Module): | |
def __init__(self, outer_nc, inner_nc, input_nc=None, | |
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): | |
super(ResUnetSkipConnectionBlock, self).__init__() | |
self.outermost = outermost | |
use_bias = norm_layer == nn.InstanceNorm2d | |
if input_nc is None: | |
input_nc = outer_nc | |
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, | |
stride=2, padding=1, bias=use_bias) | |
# add two resblock | |
res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)] | |
res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)] | |
downrelu = nn.ReLU(True) | |
uprelu = nn.ReLU(True) | |
if norm_layer != None: | |
downnorm = norm_layer(inner_nc) | |
upnorm = norm_layer(outer_nc) | |
if outermost: | |
upsample = nn.Upsample(scale_factor=2, mode='nearest') | |
upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) | |
down = [downconv, downrelu] + res_downconv | |
up = [upsample, upconv] | |
model = down + [submodule] + up | |
elif innermost: | |
upsample = nn.Upsample(scale_factor=2, mode='nearest') | |
upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) | |
down = [downconv, downrelu] + res_downconv | |
if norm_layer == None: | |
up = [upsample, upconv, uprelu] + res_upconv | |
else: | |
up = [upsample, upconv, upnorm, uprelu] + res_upconv | |
model = down + up | |
else: | |
upsample = nn.Upsample(scale_factor=2, mode='nearest') | |
upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) | |
if norm_layer == None: | |
down = [downconv, downrelu] + res_downconv | |
up = [upsample, upconv, uprelu] + res_upconv | |
else: | |
down = [downconv, downnorm, downrelu] + res_downconv | |
up = [upsample, upconv, upnorm, uprelu] + res_upconv | |
if use_dropout: | |
model = down + [submodule] + up + [nn.Dropout(0.5)] | |
else: | |
model = down + [submodule] + up | |
self.model = nn.Sequential(*model) | |
def forward(self, x):
    """Run the block; every level except the outermost concatenates its input.

    Returns ``self.model(x)`` for the outermost block, otherwise the
    channel-wise (dim 1) concatenation ``[x, self.model(x)]``.
    """
    out = self.model(x)
    if not self.outermost:
        out = torch.cat([x, out], 1)
    return out
################## | |
class GlobalGenerator(nn.Module):
    """Down-sample / residual / up-sample generator with AdaIN residual blocks.

    The residual blocks are built with ``norm_type='adain'``; their AdaIN
    weights and biases are not learned here but produced on every forward
    pass by ``enc_style`` (conditioned on features from ``enc_label``).
    """

    def __init__(self, input_nc, output_nc, L, S, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        """Build the main network plus the style and label encoders.

        Args:
            input_nc: channels of the tensor fed to the main network.
            output_nc: channels of the generated tensor.
            L: input channels of the label encoder.
            S: input channels of the style encoder.
            ngf: channel width after the first 7x7 convolution.
            n_downsampling: number of stride-2 stages (mirrored when upsampling).
            n_blocks: number of AdaIN residual blocks.
            norm_layer: callable building the norm used outside the residual blocks.
            padding_type: ``pad_type`` forwarded to the residual blocks.
        """
        assert (n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        relu = nn.ReLU(True)  # one shared instance, reused at every position

        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), relu]

        # Down-sampling: each stage doubles the channel count.
        for stage in range(n_downsampling):
            factor = 2 ** stage
            layers.extend([nn.Conv2d(ngf * factor, ngf * factor * 2, kernel_size=3, stride=2, padding=1),
                           norm_layer(ngf * factor * 2), relu])

        # Residual bottleneck with AdaIN normalisation.
        factor = 2 ** n_downsampling
        layers.extend(ResnetBlock(ngf * factor, norm_type='adain', padding_type=padding_type)
                      for _ in range(n_blocks))

        # Up-sampling: each stage halves the channel count.
        for stage in range(n_downsampling):
            factor = 2 ** (n_downsampling - stage)
            narrow = int(ngf * factor / 2)
            layers.extend([nn.ConvTranspose2d(ngf * factor, narrow, kernel_size=3, stride=2, padding=1,
                                              output_padding=1),
                           norm_layer(narrow), relu])

        layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)])
        self.model = nn.Sequential(*layers)

        # Style encoder: its output width equals the total number of AdaIN parameters.
        self.enc_style = StyleEncoder(5, S, 16, self.get_num_adain_params(self.model), norm='none', activ='relu',
                                      pad_type='reflect')
        # Label encoder: supplies the two conditioning feature maps for the style encoder.
        self.enc_label = LabelEncoder(5, L, 16, 64, norm='none', activ='relu', pad_type='reflect')

    def assign_adain_params(self, adain_params, model):
        """Distribute ``adain_params`` over the AdaIN layers found in ``model``.

        For each AdaIN layer the first ``num_features`` columns become its
        bias and the next ``num_features`` columns its weight (both flattened
        with ``view(-1)``); the consumed columns are then dropped as long as
        more than ``2 * num_features`` columns remain.
        """
        remaining = adain_params
        for module in model.modules():
            if type(module).__name__ != "AdaptiveInstanceNorm2d":
                continue
            width = module.num_features
            module.bias = remaining[:, :width].contiguous().view(-1)
            module.weight = remaining[:, width:2 * width].contiguous().view(-1)
            if remaining.size(1) > 2 * width:
                remaining = remaining[:, 2 * width:]

    def get_num_adain_params(self, model):
        """Return the number of AdaIN parameters (2 per feature) ``model`` needs."""
        return sum(2 * module.num_features
                   for module in model.modules()
                   if type(module).__name__ == "AdaptiveInstanceNorm2d")

    def forward(self, input, input_ref, image_ref):
        """Generate from ``input`` using a style derived from the reference pair.

        ``input_ref`` goes through the label encoder; its two feature maps,
        together with ``image_ref``, go through the style encoder to obtain
        the AdaIN parameters that are assigned before running the main network.
        """
        label_fea, label_fea_last = self.enc_label(input_ref)
        style_code = self.enc_style((image_ref, label_fea, label_fea_last))
        self.assign_adain_params(style_code, self.model)
        return self.model(input)
class BlendGenerator(nn.Module):
    """Predict a sigmoid mask from two inputs and blend them with it."""

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        """Build the mask network.

        Args:
            input_nc: channels of the concatenation of both inputs.
            output_nc: channels of the predicted mask.
            ngf: channel width after the first 7x7 convolution.
            n_downsampling: number of stride-2 stages (mirrored when upsampling).
            n_blocks: number of residual blocks (built with ``norm_type='in'``).
            norm_layer: callable building the norm used outside the residual blocks.
            padding_type: ``pad_type`` forwarded to the residual blocks.
        """
        assert (n_blocks >= 0)
        super(BlendGenerator, self).__init__()
        relu = nn.ReLU(True)  # one shared instance, reused at every position

        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), relu]

        # Down-sampling: each stage doubles the channel count.
        for stage in range(n_downsampling):
            factor = 2 ** stage
            layers.extend([nn.Conv2d(ngf * factor, ngf * factor * 2, kernel_size=3, stride=2, padding=1),
                           norm_layer(ngf * factor * 2), relu])

        # Residual bottleneck.
        factor = 2 ** n_downsampling
        layers.extend(ResnetBlock(ngf * factor, norm_type='in', padding_type=padding_type)
                      for _ in range(n_blocks))

        # Up-sampling: each stage halves the channel count.
        for stage in range(n_downsampling):
            factor = 2 ** (n_downsampling - stage)
            narrow = int(ngf * factor / 2)
            layers.extend([nn.ConvTranspose2d(ngf * factor, narrow, kernel_size=3, stride=2, padding=1,
                                              output_padding=1),
                           norm_layer(narrow), relu])

        # The final sigmoid keeps the mask inside [0, 1].
        layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()])
        self.model = nn.Sequential(*layers)

    def forward(self, input1, input2):
        """Return ``(input1 * m + input2 * (1 - m), m)`` for the predicted mask ``m``."""
        mask = self.model(torch.cat([input1, input2], 1))
        blended = input1 * mask + input2 * (1 - mask)
        return blended, mask
# Define the Multiscale Discriminator. | |
class MultiscaleDiscriminator(nn.Module):
    """A set of ``num_D`` NLayerDiscriminators applied to progressively down-sampled input."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        """Create the per-scale discriminators and the shared down-sampler.

        With ``getIntermFeat`` each discriminator's sub-models are registered
        individually as ``scale{i}_layer{j}``; otherwise each discriminator's
        single ``model`` is registered as ``layer{i}``.
        """
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            disc = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if not getIntermFeat:
                setattr(self, 'layer%d' % scale, disc.model)
                continue
            for layer in range(n_layers + 2):
                setattr(self, 'scale%d_layer%d' % (scale, layer), getattr(disc, 'model%d' % layer))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator.

        Returns a list: every intermediate output when ``getIntermFeat`` is
        set (``model`` is then a list of sub-models), else the single output.
        """
        if not self.getIntermFeat:
            return [model(input)]
        outputs = []
        current = input
        for stage in model:
            current = stage(current)
            outputs.append(current)
        return outputs

    def forward(self, input):
        """Return one result list per scale.

        The discriminator registered last (index ``num_D - 1``) sees the
        full-resolution input; each following one sees the input after one
        more application of ``self.downsample``.
        """
        scales = self.num_D
        results = []
        current = input
        for step in range(scales):
            which = scales - 1 - step
            if self.getIntermFeat:
                model = [getattr(self, 'scale%d_layer%d' % (which, layer))
                         for layer in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer%d' % which)
            results.append(self.singleD_forward(model, current))
            if step != scales - 1:
                current = self.downsample(current)
        return results
# Define the PatchGAN discriminator with the specified arguments. | |
class NLayerDiscriminator(nn.Module):
    """Stack of 4x4 convolutions ending in a 1-channel prediction map."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        """Build the layer groups.

        Groups, in order: a stride-2 conv + LeakyReLU; ``n_layers - 1``
        stride-2 conv/norm/LeakyReLU groups; one stride-1 conv/norm/LeakyReLU
        group; a stride-1 conv to one channel; optionally a Sigmoid. Channel
        width doubles per group and is capped at 512.

        With ``getIntermFeat`` each group is registered as ``model{n}``;
        otherwise all layers are flattened into a single ``self.model``.
        """
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))

        def conv(c_in, c_out, stride):
            return nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad)

        groups = [[conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]]

        width = ndf
        # (n_layers - 1) stride-2 groups followed by a single stride-1 group.
        for stride in [2] * (n_layers - 1) + [1]:
            previous, width = width, min(width * 2, 512)
            groups.append([conv(previous, width, stride), norm_layer(width), nn.LeakyReLU(0.2, True)])

        groups.append([conv(width, 1, 1)])
        if use_sigmoid:
            groups.append([nn.Sigmoid()])

        if getIntermFeat:
            for index, group in enumerate(groups):
                setattr(self, 'model' + str(index), nn.Sequential(*group))
        else:
            self.model = nn.Sequential(*[layer for group in groups for layer in group])

    def forward(self, input):
        """Return the prediction, or the outputs of the first ``n_layers + 2`` groups."""
        if not self.getIntermFeat:
            return self.model(input)
        features = []
        current = input
        for index in range(self.n_layers + 2):
            current = getattr(self, 'model' + str(index))(current)
            features.append(current)
        return features
from torchvision import models | |
class Vgg19(torch.nn.Module):
    """VGG-19 feature extractor split into five consecutive slices.

    The network is created with ``pretrained=False``; any weights must be
    loaded by the caller.
    """

    def __init__(self, requires_grad=False):
        """Create the VGG-19 and cut ``vgg.features`` at indices 2, 7, 12, 21 and 30.

        Args:
            requires_grad: when False every parameter of this module is frozen.
        """
        super(Vgg19, self).__init__()
        backbone = models.vgg19(pretrained=False)
        features = backbone.features
        self.vgg = backbone

        for number in range(1, 6):
            setattr(self, 'slice%d' % number, torch.nn.Sequential())

        # Layers keep their original index as sub-module name.
        cuts = (0, 2, 7, 12, 21, 30)
        for number, (start, stop) in enumerate(zip(cuts[:-1], cuts[1:]), start=1):
            stage = getattr(self, 'slice%d' % number)
            for index in range(start, stop):
                stage.add_module(str(index), features[index])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the outputs of the five slices, each fed with the previous one's output."""
        outputs = []
        current = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            current = stage(current)
            outputs.append(current)
        return outputs

    def extract(self, x):
        """Return ``vgg.avgpool(vgg.features(x))``."""
        return self.vgg.avgpool(self.vgg.features(x))
# Define the MaskVAE | |
class VAE(nn.Module):
    """Convolutional variational auto-encoder (the "MaskVAE").

    Encoder: seven stride-2 4x4 convolutions (``e1``..``e7``) with batch norm
    (``bn1``..``bn7``), then two linear heads ``fc1`` / ``fc2``.
    Decoder: linear ``d1`` followed by seven upsample / replication-pad /
    3x3-conv stages (``d2``..``d8``), the first six with batch norm
    (``bn8``..``bn13``).

    The linear layers assume a 4x4 map after the encoder, i.e. seven
    halvings of the input resolution.
    """

    def __init__(self, nc, ngf, ndf, latent_variable_size):
        """Create all layers.

        Args:
            nc: channels of the input and of the reconstruction.
            ngf: base channel width of the decoder.
            ndf: base channel width of the encoder.
            latent_variable_size: dimensionality of the latent code.
        """
        super(VAE, self).__init__()
        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf
        self.latent_variable_size = latent_variable_size

        # Encoder. Attribute names and creation order match the original
        # hand-written layer list so existing checkpoints still load.
        in_ch = nc
        for k in range(1, 8):
            out_ch = ndf * 2 ** (k - 1)
            setattr(self, 'e%d' % k, nn.Conv2d(in_ch, out_ch, 4, 2, 1))
            setattr(self, 'bn%d' % k, nn.BatchNorm2d(out_ch))
            in_ch = out_ch
        self.fc1 = nn.Linear(ndf * 64 * 4 * 4, latent_variable_size)
        self.fc2 = nn.Linear(ndf * 64 * 4 * 4, latent_variable_size)

        # Decoder: stage k owns up{k}, pd{k}, d{k+1} and bn{k+7}.
        self.d1 = nn.Linear(latent_variable_size, ngf * 64 * 4 * 4)
        in_ch = ngf * 64
        for k in range(1, 7):
            out_ch = ngf * 2 ** (6 - k)
            setattr(self, 'up%d' % k, nn.UpsamplingNearest2d(scale_factor=2))
            setattr(self, 'pd%d' % k, nn.ReplicationPad2d(1))
            setattr(self, 'd%d' % (k + 1), nn.Conv2d(in_ch, out_ch, 3, 1))
            setattr(self, 'bn%d' % (k + 7), nn.BatchNorm2d(out_ch, 1.e-3))
            in_ch = out_ch
        self.up7 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd7 = nn.ReplicationPad2d(1)
        self.d8 = nn.Conv2d(ngf, nc, 3, 1)

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d((2, 2), (2, 2))  # kept for compatibility; not used by any method here

    def encode(self, x):
        """Return ``(fc1(h), fc2(h))`` for the flattened encoder features ``h``."""
        h = x
        for k in range(1, 8):
            conv = getattr(self, 'e%d' % k)
            norm = getattr(self, 'bn%d' % k)
            h = self.leakyrelu(norm(conv(h)))
        h = h.view(-1, self.ndf * 64 * 4 * 4)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with standard-normal ``eps``.

        The noise is drawn with ``torch.randn_like`` so it lives on the same
        device and has the same dtype as ``logvar``; the previous
        ``torch.cuda.FloatTensor`` allocation failed on CPU-only machines.
        """
        std = logvar.mul(0.5).exp_()
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map a latent code back to an ``nc``-channel tensor."""
        h = self.relu(self.d1(z))
        h = h.view(-1, self.ngf * 64, 4, 4)
        for k in range(1, 7):
            up = getattr(self, 'up%d' % k)
            pad = getattr(self, 'pd%d' % k)
            conv = getattr(self, 'd%d' % (k + 1))
            norm = getattr(self, 'bn%d' % (k + 7))
            h = self.leakyrelu(norm(conv(pad(up(h)))))
        # Last stage has neither norm nor activation.
        return self.d8(self.pd7(self.up7(h)))

    def get_latent_var(self, x):
        """Return ``(z, mu, std)`` where ``std = exp(0.5 * logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar.mul(0.5).exp_()

    def forward(self, x):
        """Return ``(reconstruction, x, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, x, mu, logvar
# style encode part | |
class StyleEncoder(nn.Module):
    """Convolutional encoder modulated twice by SFT layers, ending in a 1x1 style code."""

    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        """Build the three stages and the two SFT layers.

        ``model``: a 7x7 ConvBlock followed by two stride-2 ConvBlocks that
        each double the width. ``model_middle``: ``n_downsample - 2`` further
        stride-2 ConvBlocks at constant width. ``model_last``: global average
        pooling and a 1x1 convolution to ``style_dim`` channels.
        """
        super(StyleEncoder, self).__init__()
        shared = dict(norm=norm, activation=activ, pad_type=pad_type)

        front = [ConvBlock(input_dim, dim, 7, 1, 3, **shared)]
        for _ in range(2):
            front.append(ConvBlock(dim, 2 * dim, 4, 2, 1, **shared))
            dim *= 2

        middle = [ConvBlock(dim, dim, 4, 2, 1, **shared) for _ in range(n_downsample - 2)]

        tail = [nn.AdaptiveAvgPool2d(1),  # global average pooling
                nn.Conv2d(dim, style_dim, 1, 1, 0)]

        self.model = nn.Sequential(*front)
        self.model_middle = nn.Sequential(*middle)
        self.model_last = nn.Sequential(*tail)
        self.output_dim = dim
        self.sft1 = SFTLayer()
        self.sft2 = SFTLayer()

    def forward(self, x):
        """Encode ``x[0]``, applying SFT with ``x[1]`` after the first stage and ``x[2]`` after the second."""
        image, cond_first, cond_second = x[0], x[1], x[2]
        fea = self.sft1((self.model(image), cond_first))
        fea = self.sft2((self.model_middle(fea), cond_second))
        return self.model_last(fea)
# label encode part | |
class LabelEncoder(nn.Module):
    """Encoder that returns two feature maps at different depths."""

    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        """Build both stages.

        ``model``: a 7x7 ConvBlock and two stride-2 ConvBlocks that each
        double the width; the last one has no activation.
        ``model_last``: ReLU, ``n_downsample - 3`` stride-2 ConvBlocks, and a
        final stride-2 ConvBlock without activation.

        ``style_dim`` is accepted but not used by this class.
        """
        super(LabelEncoder, self).__init__()
        shared = dict(norm=norm, pad_type=pad_type)

        tail = [nn.ReLU()]

        head = [ConvBlock(input_dim, dim, 7, 1, 3, activation=activ, **shared)]
        head.append(ConvBlock(dim, 2 * dim, 4, 2, 1, activation=activ, **shared))
        dim *= 2
        head.append(ConvBlock(dim, 2 * dim, 4, 2, 1, activation='none', **shared))
        dim *= 2

        for _ in range(n_downsample - 3):
            tail.append(ConvBlock(dim, dim, 4, 2, 1, activation=activ, **shared))
        tail.append(ConvBlock(dim, dim, 4, 2, 1, activation='none', **shared))

        self.model = nn.Sequential(*head)
        self.model_last = nn.Sequential(*tail)
        self.output_dim = dim

    def forward(self, x):
        """Return ``(fea, model_last(fea))`` where ``fea = model(x)``."""
        shallow = self.model(x)
        deep = self.model_last(shallow)
        return shallow, deep
# Define the basic block | |
class ConvBlock(nn.Module):
    """Explicit padding, then convolution, then optional norm and optional activation."""

    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        """Build the block.

        Args:
            input_dim / output_dim: convolution channel counts.
            kernel_size / stride: convolution geometry (the conv itself has no padding).
            padding: amount applied by the separate padding layer.
            norm: 'bn', 'in', 'ln', 'adain', 'none' or 'sn' ('sn' wraps the
                convolution in SpectralNorm and uses no norm layer).
            activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'.
            pad_type: 'reflect', 'replicate' or 'zero'.
        """
        super(ConvBlock, self).__init__()
        self.use_bias = True

        # Padding layer.
        pad_classes = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
            'zero': nn.ZeroPad2d,
        }
        pad_cls = pad_classes.get(pad_type)
        if pad_cls is None:
            assert 0, "Unsupported padding type: {}".format(pad_type)
        else:
            self.pad = pad_cls(padding)

        # Normalisation layer (factories are lazy so only the chosen one is built).
        norm_dim = output_dim
        norm_factories = {
            'bn': lambda: nn.BatchNorm2d(norm_dim),
            'in': lambda: nn.InstanceNorm2d(norm_dim),
            'ln': lambda: LayerNorm(norm_dim),
            'adain': lambda: AdaptiveInstanceNorm2d(norm_dim),
            'none': lambda: None,
            'sn': lambda: None,
        }
        make_norm = norm_factories.get(norm)
        if make_norm is None:
            assert 0, "Unsupported normalization: {}".format(norm)
        else:
            self.norm = make_norm()

        # Activation.
        activation_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'none': lambda: None,
        }
        make_activation = activation_factories.get(activation)
        if make_activation is None:
            assert 0, "Unsupported activation: {}".format(activation)
        else:
            self.activation = make_activation()

        # Convolution, spectrally normalised when requested.
        conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
        self.conv = SpectralNorm(conv) if norm == 'sn' else conv

    def forward(self, x):
        """Apply pad -> conv, then norm and activation when they are set."""
        out = self.conv(self.pad(x))
        for stage in (self.norm, self.activation):
            if stage:
                out = stage(out)
        return out
class LinearBlock(nn.Module):
    """Fully connected layer with optional norm and optional activation."""

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        """Build the block.

        Args:
            input_dim / output_dim: sizes of the linear layer.
            norm: 'bn', 'in', 'ln', 'none' or 'sn' ('sn' wraps the linear
                layer in SpectralNorm and uses no norm layer).
            activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'.
        """
        super(LinearBlock, self).__init__()
        use_bias = True

        # Fully connected layer, spectrally normalised when requested.
        linear = nn.Linear(input_dim, output_dim, bias=use_bias)
        self.fc = SpectralNorm(linear) if norm == 'sn' else linear

        # Normalisation layer (factories are lazy so only the chosen one is built).
        norm_dim = output_dim
        norm_factories = {
            'bn': lambda: nn.BatchNorm1d(norm_dim),
            'in': lambda: nn.InstanceNorm1d(norm_dim),
            'ln': lambda: LayerNorm(norm_dim),
            'none': lambda: None,
            'sn': lambda: None,
        }
        make_norm = norm_factories.get(norm)
        if make_norm is None:
            assert 0, "Unsupported normalization: {}".format(norm)
        else:
            self.norm = make_norm()

        # Activation.
        activation_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'none': lambda: None,
        }
        make_activation = activation_factories.get(activation)
        if make_activation is None:
            assert 0, "Unsupported activation: {}".format(activation)
        else:
            self.activation = make_activation()

    def forward(self, x):
        """Apply the linear layer, then norm and activation when they are set."""
        out = self.fc(x)
        for stage in (self.norm, self.activation):
            if stage:
                out = stage(out)
        return out
# Define a resnet block | |
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two 3x3 ConvBlocks."""

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        """Create the residual branch for ``dim`` channels."""
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, norm_type, padding_type, use_dropout)

    def build_conv_block(self, dim, norm_type, padding_type, use_dropout):
        """Return the two-ConvBlock branch; only the first block has a ReLU.

        ``use_dropout`` is accepted but not used.
        """
        branch = [
            ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation=act, pad_type=padding_type)
            for act in ('relu', 'none')
        ]
        return nn.Sequential(*branch)

    def forward(self, x):
        """Add the branch output to the input."""
        residual = self.conv_block(x)
        return x + residual
class SFTLayer(nn.Module):
    """Feature modulation: ``x[0] * scale(x[1]) + shift(x[1])``.

    Scale and shift are each computed from the condition ``x[1]`` by two 1x1
    convolutions (64 -> 64 channels) with a LeakyReLU(0.1) in between.
    """

    def __init__(self):
        """Create the four 64-channel 1x1 convolutions."""
        super(SFTLayer, self).__init__()
        for attr in ('SFT_scale_conv1', 'SFT_scale_conv2', 'SFT_shift_conv1', 'SFT_shift_conv2'):
            setattr(self, attr, nn.Conv2d(64, 64, 1))

    def forward(self, x):
        """``x`` is a pair ``(features, condition)``; returns the modulated features."""
        features, condition = x[0], x[1]
        scale_hidden = F.leaky_relu(self.SFT_scale_conv1(condition), 0.1, inplace=True)
        scale = self.SFT_scale_conv2(scale_hidden)
        shift_hidden = F.leaky_relu(self.SFT_shift_conv1(condition), 0.1, inplace=True)
        shift = self.SFT_shift_conv2(shift_hidden)
        return features * scale + shift
class ConvBlock_SFT(nn.Module):
    """SFT modulation followed by a stride-2 ConvBlock, added back to the input.

    Input and output are ``(features, condition)`` pairs so several of these
    blocks can be chained.
    """

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        """Create the SFT layer and the 4x4 stride-2 ConvBlock.

        ``use_dropout`` is accepted but not used.
        """
        # The original called super() with the undefined name
        # ``ResnetBlock_SFT``, which raised NameError on instantiation.
        super(ConvBlock_SFT, self).__init__()
        self.sft1 = SFTLayer()
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        """Return ``(x[0] + relu(conv1(sft1(x))), x[1])``."""
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        # NOTE(review): conv1 has stride 2, so ``fea`` is spatially smaller
        # than ``x[0]``; the addition relies on broadcasting - confirm intent.
        return (x[0] + fea, x[1])
class ConvBlock_SFT_last(nn.Module):
    """Final SFT + stride-2 ConvBlock stage; returns only the feature tensor."""

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        """Create the SFT layer and the 4x4 stride-2 ConvBlock.

        ``use_dropout`` is accepted but not used.
        """
        # The original called super() with the undefined name
        # ``ResnetBlock_SFT_last``, which raised NameError on instantiation.
        super(ConvBlock_SFT_last, self).__init__()
        self.sft1 = SFTLayer()
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        """Return ``x[0] + relu(conv1(sft1(x)))`` for the pair ``x = (features, condition)``."""
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        # NOTE(review): conv1 has stride 2, so ``fea`` is spatially smaller
        # than ``x[0]``; the addition relies on broadcasting - confirm intent.
        return x[0] + fea
# Definition of normalization layer | |
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalisation whose weight and bias are assigned externally.

    ``weight`` and ``bias`` start as None and must be set (one value per
    sample-channel pair, i.e. length ``batch * channels``) before each call.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        """Store hyper-parameters and register the placeholder buffers."""
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Filled in from outside before every forward pass.
        self.weight = None
        self.bias = None
        # Dummy statistics: only their repeated copies are handed to batch_norm.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalise every (sample, channel) plane, then apply weight and bias."""
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        batch, channels = x.size(0), x.size(1)
        spatial = x.size()[2:]
        mean_buf = self.running_mean.repeat(batch)
        var_buf = self.running_var.repeat(batch)

        # Folding the batch into the channel axis turns batch norm in
        # training mode into instance norm.
        folded = x.contiguous().view(1, batch * channels, *spatial)
        normed = F.batch_norm(
            folded, mean_buf, var_buf, self.weight, self.bias,
            True, self.momentum, self.eps)
        return normed.view(batch, channels, *spatial)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.num_features)
class LayerNorm(nn.Module):
    """Per-sample normalisation over all non-batch dimensions.

    Uses ``(x - mean) / (std + eps)`` and, when ``affine`` is set, a
    per-channel scale ``gamma`` and offset ``beta``.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        """Store settings; create ``gamma`` (uniform init) and ``beta`` (zeros) if affine."""
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Normalise ``x`` sample by sample and apply the optional affine transform."""
        stat_shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # Single sample: statistics over the fully flattened tensor.
            flat = x.view(-1)
            mean = flat.mean().view(*stat_shape)
            std = flat.std().view(*stat_shape)
        else:
            per_sample = x.view(x.size(0), -1)
            mean = per_sample.mean(1).view(*stat_shape)
            std = per_sample.std(1).view(*stat_shape)

        out = (x - mean) / (std + self.eps)

        if self.affine:
            channel_shape = [1, -1] + [1] * (out.dim() - 2)
            out = out * self.gamma.view(*channel_shape) + self.beta.view(*channel_shape)
        return out
def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its norm; ``eps`` guards against division by zero."""
    magnitude = v.norm()
    return v / (magnitude + eps)
class SpectralNorm(nn.Module):
    """
    Wrapper that divides a module's weight by an estimate of its largest
    singular value before every forward pass.

    Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
    and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """

    def __init__(self, module, name='weight', power_iterations=1):
        """Wrap ``module`` and normalise its parameter called ``name``.

        Args:
            module: the module to wrap.
            name: name of the parameter to normalise.
            power_iterations: power-iteration steps per forward pass.
        """
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refresh ``u`` / ``v`` by power iteration and set the normalised weight."""
        prefix = self.name
        u = getattr(self.module, prefix + "_u")
        v = getattr(self.module, prefix + "_v")
        w = getattr(self.module, prefix + "_bar")

        rows = w.data.shape[0]
        for _ in range(self.power_iterations):
            matrix = w.view(rows, -1).data
            v.data = l2normalize(torch.mv(torch.t(matrix), u.data))
            u.data = l2normalize(torch.mv(matrix, v.data))

        # sigma goes through ``w`` (not ``w.data``) so gradients reach ``w_bar``.
        sigma = u.dot(w.view(rows, -1).mv(v))
        setattr(self.module, prefix, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True when the ``_u``, ``_v`` and ``_bar`` attributes already exist."""
        return all(hasattr(self.module, self.name + suffix) for suffix in ("_u", "_v", "_bar"))

    def _make_params(self):
        """Replace the weight parameter by ``<name>_bar`` and add the ``u`` / ``v`` vectors."""
        w = getattr(self.module, self.name)

        rows = w.data.shape[0]
        cols = w.view(rows, -1).data.shape[1]

        u = nn.Parameter(w.data.new(rows).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(cols).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # The original parameter is removed; forward() re-creates the
        # attribute as a plain tensor on every call.
        del self.module._parameters[self.name]

        for suffix, param in (("_u", u), ("_v", v), ("_bar", w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        """Update the normalised weight, then call the wrapped module."""
        self._update_u_v()
        return self.module.forward(*args)
### STN TPS | |
class CNN(nn.Module):
    """Stride-2 convolutional encoder followed by a two-layer MLP head.

    Layout: one conv/ReLU/norm stem, ``n_layers`` conv/norm/ReLU stages that
    double the width (capped at 1024), two 3x3 conv/norm/ReLU stages at
    constant width, 2x2 max pooling, then ``fc1`` (512 -> 128) and ``fc2``.
    """

    def __init__(self, num_output, input_nc=5, ngf=8, n_layers=5, norm_layer=nn.InstanceNorm2d, use_dropout=False):
        """Build the network.

        Args:
            num_output: size of the final linear layer's output.
            input_nc: channels of the input (previously ignored and fixed to 5).
            ngf: channel width after the stem.
            n_layers: number of width-doubling stride-2 stages.
            norm_layer: callable building a norm layer from a channel count.
            use_dropout: accepted but not used; ``forward`` always applies
                ``F.dropout`` in training mode.
        """
        super(CNN, self).__init__()
        layers = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1), nn.ReLU(True), norm_layer(ngf)]

        channels = ngf
        for i in range(n_layers):
            in_ngf = 2 ** i * ngf if 2 ** i * ngf < 1024 else 1024
            out_ngf = 2 ** (i + 1) * ngf if 2 ** i * ngf < 1024 else 1024
            layers += [nn.Conv2d(in_ngf, out_ngf, kernel_size=4, stride=2, padding=1),
                       norm_layer(out_ngf), nn.ReLU(True)]
            channels = out_ngf

        # Tail convolutions follow the actual encoder width (256 with the
        # defaults) and their norm layers now match that width; the original
        # built norm_layer(64) after 256-channel convolutions.
        for _ in range(2):
            layers += [nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
                       norm_layer(channels), nn.ReLU(True)]

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.model = nn.Sequential(*layers)
        # NOTE(review): 512 is the flattened feature size after pooling; it
        # matches the default widths with a 256x192 input - confirm for other sizes.
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, num_output)

    def forward(self, x):
        """Return the ``num_output``-dimensional prediction for each sample."""
        x = self.model(x)
        x = self.maxpool(x)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x
class ClsNet(nn.Module):
    """10-way classifier built on ``CNN`` that outputs log-probabilities."""

    def __init__(self):
        """Create the underlying ``CNN`` with 10 outputs."""
        super(ClsNet, self).__init__()
        self.cnn = CNN(10)

    def forward(self, x):
        """Return log-softmax over the class dimension of ``self.cnn(x)``."""
        # dim is given explicitly: the implicit-dimension form is deprecated.
        # For the 2-D (batch, classes) output the implicit choice was dim 1.
        return F.log_softmax(self.cnn(x), dim=1)
class BoundedGridLocNet(nn.Module):
    """Localisation network that predicts tanh-bounded control points.

    Besides the control points, ``forward`` returns four regularisation
    terms computed from second differences of squared point spacings along
    the rows and columns of the control grid.
    """

    def __init__(self, grid_height, grid_width, target_control_points):
        """Create the CNN and initialise it to output ``target_control_points``.

        The last layer's weights are zeroed and its bias is set to
        ``arctanh(target_control_points)``, so the initial prediction after
        ``tanh`` equals the target points.
        """
        super(BoundedGridLocNet, self).__init__()
        self.cnn = CNN(grid_height * grid_width * 2)
        bias = torch.from_numpy(np.arctanh(target_control_points.numpy()))
        bias = bias.view(-1)
        self.cnn.fc2.bias.data.copy_(bias)
        self.cnn.fc2.weight.data.zero_()

    def forward(self, x):
        """Return ``(coor, rx_loss, ry_loss, cx_loss, cy_loss)``.

        ``coor`` has shape ``(batch, -1, 2)``. Each loss is the mean of
        ``max(0.08, d)`` over the x or y component of the row / column
        second differences ``d``.
        """
        batch_size = x.size(0)
        points = torch.tanh(self.cnn(x))  # F.tanh is deprecated
        coor = points.view(batch_size, -1, 2)
        # NOTE(review): the grid size is fixed to 5 here, independent of the
        # grid_height / grid_width given to __init__ - confirm intent.
        row = self.get_row(coor, 5)
        col = self.get_col(coor, 5)
        # Threshold lives on the same device / dtype as the prediction; the
        # original used an unconditional .cuda() and failed on CPU.
        threshold = coor.new_tensor(0.08)
        rx_loss = torch.max(threshold, row[:, :, 0]).mean()
        ry_loss = torch.max(threshold, row[:, :, 1]).mean()
        cx_loss = torch.max(threshold, col[:, :, 0]).mean()
        cy_loss = torch.max(threshold, col[:, :, 1]).mean()
        return coor, rx_loss, ry_loss, cx_loss, cy_loss

    @staticmethod
    def _second_differences(coor, lines):
        """Stack ``|s[k+1] - s[k]|`` over all lines, ``s`` being squared point-to-point steps.

        ``lines`` is a sequence of index lists into dim 1 of ``coor``; the
        result has shape ``(batch, total, 2)``.
        """
        diffs = []
        for indices in lines:
            previous = None
            for start, stop in zip(indices[:-1], indices[1:]):
                step = (coor[:, stop, :] - coor[:, start, :]) ** 2
                if previous is not None:
                    diffs.append(torch.abs(step - previous))
                previous = step
        return torch.stack(diffs, dim=1)

    def get_row(self, coor, num):
        """Second differences along each of the ``num`` rows of a ``num`` x ``num`` grid."""
        rows = [[j * num + i for i in range(num)] for j in range(num)]
        return self._second_differences(coor, rows)

    def get_col(self, coor, num):
        """Second differences along each of the ``num`` columns of a ``num`` x ``num`` grid."""
        cols = [[j * num + i for j in range(num)] for i in range(num)]
        return self._second_differences(coor, cols)
class UnBoundedGridLocNet(nn.Module):
    """Localisation network that predicts unconstrained control points."""

    def __init__(self, grid_height, grid_width, target_control_points):
        """Create the CNN and initialise it to output ``target_control_points``.

        The last layer's weights are zeroed and its bias is set to the
        flattened target points.
        """
        super(UnBoundedGridLocNet, self).__init__()
        self.cnn = CNN(grid_height * grid_width * 2)
        head = self.cnn.fc2
        head.bias.data.copy_(target_control_points.view(-1))
        head.weight.data.zero_()

    def forward(self, x):
        """Return the prediction reshaped to ``(batch, -1, 2)``."""
        prediction = self.cnn(x)
        return prediction.view(x.size(0), -1, 2)
class STNNet(nn.Module):
    """Thin-plate-spline spatial transformer driven by a bounded localisation net.

    A 5x5 grid of target control points spanning [-0.9, 0.9] is fixed at
    construction; the localisation network predicts the matching source
    control points and ``TPSGridGen`` turns them into a 256x192 sampling grid.
    """

    def __init__(self):
        """Create the target control points, the localisation net and the TPS grid generator."""
        super(STNNet, self).__init__()
        span = 0.9
        r1 = r2 = span
        grid_size_h = grid_size_w = 5
        assert r1 < 1 and r2 < 1  # if >= 1, arctanh will cause error in BoundedGridLocNet

        first_axis = np.arange(-r1, r1 + 0.00001, 2.0 * r1 / (grid_size_h - 1))
        second_axis = np.arange(-r2, r2 + 0.00001, 2.0 * r2 / (grid_size_w - 1))
        pairs = torch.Tensor(list(itertools.product(first_axis, second_axis)))
        # product() yields (Y, X); swap the columns so each row is (X, Y).
        Y, X = pairs.split(1, dim=1)
        target_control_points = torch.cat([X, Y], dim=1)
        self.target_control_points = target_control_points

        loc_nets = {
            'unbounded_stn': UnBoundedGridLocNet,
            'bounded_stn': BoundedGridLocNet,
        }
        loc_net_cls = loc_nets['bounded_stn']
        self.loc_net = loc_net_cls(grid_size_h, grid_size_w, target_control_points)
        self.tps = TPSGridGen(256, 192, target_control_points)

    def get_row(self, coor, num):
        """Debug helper: per grid row, print the accumulated second differences divided by ``num``.

        ``coor`` is indexed without a batch dimension.
        """
        for j in range(num):
            total = 0
            previous = None
            for i in range(num - 1):
                step = (coor[j * num + i + 1, :] - coor[j * num + i, :]) ** 2
                if previous is not None:
                    total += torch.abs(step - previous)
                previous = step
            print(total / num)

    def get_col(self, coor, num):
        """Debug helper: per grid column, print the accumulated second differences.

        ``coor`` is indexed without a batch dimension.
        """
        for i in range(num):
            total = 0
            previous = None
            for j in range(num - 1):
                step = (coor[(j + 1) * num + i, :] - coor[j * num + i, :]) ** 2
                if previous is not None:
                    total += torch.abs(step - previous)
                previous = step
            print(total)

    def forward(self, x, reference, mask, grid_pic):
        """Warp ``x``, ``mask`` and ``grid_pic`` with the grid predicted from ``reference``.

        Returns ``(transformed_x, warped_mask, rx, ry, cx, cy, warped_gpic)``
        where ``rx, ry, cx, cy`` are the four terms returned by the
        localisation network alongside the control points.
        """
        batch_size = x.size(0)
        source_control_points, rx, ry, cx, cy = self.loc_net(reference)
        source_coordinate = self.tps(source_control_points)
        grid = source_coordinate.view(batch_size, 256, 192, 2)

        # All three tensors are sampled with the same grid.
        transformed_x, warped_mask, warped_gpic = [
            grid_sample(tensor, grid, canvas=0) for tensor in (x, mask, grid_pic)
        ]
        return transformed_x, warped_mask, rx, ry, cx, cy, warped_gpic