# 3D_Photo_Inpainting / networks.py
# (Copied from a web file viewer: author "Saini", commit "init" 0b9f920, 22.1 kB.)
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseNetwork(nn.Module):
    """Common base class that adds an in-place weight (re-)initialisation helper."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='normal', gain=0.02):
        """Re-initialise every Conv*/Linear and BatchNorm2d sub-module in place.

        Args:
            init_type: 'normal' | 'xavier' | 'kaiming' | 'orthogonal'. Any other
                value leaves Conv/Linear weights untouched, but their biases are
                still zeroed.
            gain: std for 'normal' (and for the BatchNorm2d weights, which are
                drawn around 1.0); gain factor for 'xavier' and 'orthogonal';
                ignored by 'kaiming'.

        Adapted from
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        """
        weight_initialisers = {
            'normal': lambda w: nn.init.normal_(w, 0.0, gain),
            'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
            'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
            'orthogonal': lambda w: nn.init.orthogonal_(w, gain=gain),
        }

        def _init_module(module):
            kind = type(module).__name__
            is_linear_like = 'Conv' in kind or 'Linear' in kind
            if is_linear_like and hasattr(module, 'weight'):
                initialiser = weight_initialisers.get(init_type)
                if initialiser is not None:
                    initialiser(module.weight.data)
                if getattr(module, 'bias', None) is not None:
                    nn.init.constant_(module.bias.data, 0.0)
            elif 'BatchNorm2d' in kind:
                nn.init.normal_(module.weight.data, 1.0, gain)
                nn.init.constant_(module.bias.data, 0.0)

        self.apply(_init_module)
def weights_init(init_type='gaussian'):
    """Return an ``nn.Module.apply`` callback that initialises Conv*/Linear layers.

    Only modules whose class name *starts with* ``Conv`` or ``Linear`` (e.g.
    ``Conv2d``, ``ConvTranspose2d``, ``Linear``) and that have a ``weight`` are
    touched. Their bias, if present, is zeroed.

    Args:
        init_type: 'gaussian' (N(0, 0.02)), 'xavier' / 'orthogonal' (gain sqrt(2)),
            'kaiming' (fan_in), or 'default' (keep PyTorch's own weight init).

    Raises:
        AssertionError: from the returned callback, when it meets a matching
            layer and ``init_type`` is not one of the values above. An explicit
            raise is used instead of ``assert`` so the check survives ``python -O``.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.startswith('Conv') or classname.startswith('Linear')) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                raise AssertionError("Unsupported initialization: {}".format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
    return init_fun
class PartialConv(nn.Module):
    """Partial convolution: convolve only valid inputs and re-normalise by coverage.

    ``input_conv`` sees ``input * mask``. ``mask_conv`` is a frozen all-ones
    filter that counts how many mask entries fall in each window. Output
    positions whose window had no valid entry are set to 0 and marked 0 in the
    returned mask.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)
        self.input_conv.apply(weights_init('kaiming'))
        # Number of inputs one output position sees: (in_channels / groups) * kH * kW.
        # It is read off the weight shape so that tuple kernel sizes and groups > 1
        # are handled. For an int kernel with groups == 1 this equals
        # in_channels * kernel_size * kernel_size.
        self.slide_winsize = self.mask_conv.weight[0].numel()
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        # The mask convolution is a fixed box filter and is never trained.
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, input, mask):
        """Apply the partial convolution.

        Args:
            input: feature tensor fed to ``input_conv``.
            mask: tensor multiplied element-wise with ``input`` (same shape,
                presumably 0/1 validity — TODO confirm with callers).

        Returns:
            (output, new_mask): ``output`` is the re-normalised convolution with
            no-coverage positions zeroed. ``new_mask`` is 1 where the window
            contained a non-zero mask sum and 0 elsewhere.
        """
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T * (M .* X) * sum(1) / sum(M) + b = [C(M .* X) - C(0)] * sum(1) / D(M) + C(0)
        output = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)
        else:
            output_bias = torch.zeros_like(output)
        with torch.no_grad():
            output_mask = self.mask_conv(mask)
        no_update_holes = output_mask == 0
        # Replace zero sums by 1 to avoid dividing by zero; those positions are
        # overwritten with 0 just below.
        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
        output_pre = ((output - output_bias) * self.slide_winsize) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes, 0.0)
        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
        return output, new_mask
class PCBActiv(nn.Module):
    """PartialConv followed by an optional BatchNorm2d and an optional activation.

    ``sample`` selects (kernel, stride, padding) of the partial convolution:
    'down-5' -> (5, 2, 2), 'down-7' -> (7, 2, 3), 'down-3' -> (3, 2, 1), and any
    other value -> (3, 1, 1). ``activ`` is 'relu' or 'leaky' (LeakyReLU with
    slope 0.2); any other value, e.g. None, means no activation.
    """

    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='relu',
                 conv_bias=False):
        super().__init__()
        geometry = {
            'down-5': (5, 2, 2),
            'down-7': (7, 2, 3),
            'down-3': (3, 2, 1),
        }
        kernel, stride, pad = geometry.get(sample, (3, 1, 1))
        self.conv = PartialConv(in_ch, out_ch, kernel, stride, pad, bias=conv_bias)
        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        # Only register `activation` for a recognised name; forward() checks for it.
        if activ == 'relu':
            self.activation = nn.ReLU()
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input, input_mask):
        """Return ``(features, mask)`` after the partial conv and optional BN/activation."""
        feat, feat_mask = self.conv(input, input_mask)
        for stage in ('bn', 'activation'):
            if hasattr(self, stage):
                feat = getattr(self, stage)(feat)
        return feat, feat_mask
class Inpaint_Depth_Net(nn.Module):
    """Partial-convolution U-Net that inpaints a 1-channel depth map.

    The network takes 4 input channels. ``forward_3P`` concatenates them as
    (depth, edge, context, mask), presumably one channel each (TODO confirm).
    In ``forward`` the last two channels are summed and clamped to [0, 1] to
    form the validity mask for the partial convolutions.
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest'):
        super().__init__()
        in_channels = 4
        out_channels = 1
        # NOTE(review): not read anywhere in this class (no train() override here).
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True)
        self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        # Extra 512-channel levels enc_5 .. enc_{layer_size} and matching decoders.
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512, 512, sample='down-3'))
        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
        self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1 = PCBActiv(64 + in_channels, out_channels,
                              bn=False, activ=None, conv_bias=True)

    def add_border(self, input, mask_flag, PCONV=True):
        """Pad ``input`` so that H and W become multiples of 2 ** layer_size.

        The padding is centred, with any odd pixel going to the bottom/right.
        The padded area is 0. The one exception is ``mask_flag`` true with
        ``PCONV`` False, where the padded area is 1.

        Returns:
            (padded, [top, bottom, left, right]) where
            ``padded[..., top:bottom, left:right]`` is ``input``.
        """
        with torch.no_grad():
            h, w = input.shape[-2], input.shape[-1]
            require_len_unit = 2 ** self.layer_size
            residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h)
            residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w)
            # new_zeros keeps input's device and dtype.
            enlarge_input = input.new_zeros(
                (input.shape[0], input.shape[1], h + residual_h, w + residual_w))
            if mask_flag and PCONV is False:
                enlarge_input.fill_(1.0)
            # With PCONV the mask padding stays 0 (already zero-initialised). The
            # old `[:, 2] = 0.0` write was a no-op that failed for < 3 channels.
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
            return enlarge_input, [anchor_h, anchor_h + h, anchor_w, anchor_w + w]

    def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None):
        """Inpaint depth for inputs of arbitrary H x W without tracking gradients.

        Concatenates (depth, edge, context, mask) on the channel axis and
        zero-pads (centred) to multiples of ``unit_length``. It then runs
        :meth:`forward` and crops the result back to the original H x W.

        Args:
            unit_length: padding granularity. It should be a multiple of
                2 ** layer_size.
            cuda: device for the padded input. Defaults to the inputs' device.
        """
        with torch.no_grad():
            input = torch.cat((depth, edge, context, mask), dim=1)
            device = input.device if cuda is None else cuda
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
            depth_output = self.forward(enlarge_input)
            return depth_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

    def forward(self, input_feat, refine_border=False, sample=False, PCONV=True):
        """Run the U-Net encoder/decoder.

        Args:
            input_feat: (N, C, H, W) tensor. Its last two channels are summed and
                clamped to [0, 1], then repeated over all C channels to form the mask.
            refine_border: pad to a multiple of 2 ** layer_size first (see
                :meth:`add_border`) and crop the output back afterwards.
            sample: unused; kept for API compatibility.
            PCONV: forwarded to :meth:`add_border` for the mask padding.

        Returns:
            Output of ``dec_1``, which has no activation.
        """
        input = input_feat
        input_mask = (input_feat[:, -2:-1] + input_feat[:, -1:]).clamp(0, 1).repeat(1, input.shape[1], 1, 1)
        if refine_border is True:
            input, anchor = self.add_border(input, mask_flag=False)
            input_mask, anchor = self.add_border(input_mask, mask_flag=True, PCONV=PCONV)

        # Encoder: level 0 is the network input; keep every level for the skips.
        feats, masks = [input], [input_mask]
        for i in range(1, self.layer_size + 1):
            h, h_mask = getattr(self, 'enc_{:d}'.format(i))(feats[-1], masks[-1])
            feats.append(h)
            masks.append(h_mask)

        # Decoder: upsample x2, concatenate the level i-1 skip, then apply dec_i.
        h, h_mask = feats[-1], masks[-1]
        for i in range(self.layer_size, 0, -1):
            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)
            h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')
            h = torch.cat([h, feats[i - 1]], dim=1)
            h_mask = torch.cat([h_mask, masks[i - 1]], dim=1)
            h, h_mask = getattr(self, 'dec_{:d}'.format(i))(h, h_mask)

        output = h
        if refine_border is True:
            output = output[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
        return output
class Inpaint_Edge_Net(BaseNetwork):
    """Encoder / residual-middle / decoder network that predicts a 1-channel edge map.

    The network takes 7 input channels, which ``forward_3P`` assembles as
    (rgb, normalised disparity, edge, context, mask). Each decoder stage
    concatenates the matching encoder features, and the output goes through
    a sigmoid.
    """

    def __init__(self, residual_blocks=8, init_weights=True):
        super(Inpaint_Edge_Net, self).__init__()
        in_channels = 7
        out_channels = 1
        # NOTE(review): not used by this class; kept in case outside code reads it.
        self.encoder = []
        # 0: full-resolution 7x7 conv (reflection padded), 7 -> 64 channels
        self.encoder_0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))
        # 1: stride-2 downsample, 64 -> 128 channels
        self.encoder_1 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))
        # 2: stride-2 downsample, 128 -> 256 channels
        self.encoder_2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True))
        # 3: residual blocks (dilation 2) at the lowest resolution
        blocks = []
        for _ in range(residual_blocks):
            blocks.append(ResnetBlock(256, 2))
        self.middle = nn.Sequential(*blocks)
        # decoder_0: middle output + encoder_2 skip, stride-2 transposed conv
        self.decoder_0 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=256+256, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))
        # decoder_1: + encoder_1 skip, stride-2 transposed conv
        self.decoder_1 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=128+128, out_channels=64, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))
        # decoder_2: + encoder_0 skip, 7x7 conv to the output channel
        self.decoder_2 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64+64, out_channels=out_channels, kernel_size=7, padding=0),
        )
        if init_weights:
            self.init_weights()

    def add_border(self, input, channel_pad_1=None):
        """Pad ``input`` (centred) so that H and W become multiples of 16.

        The padded area is 0, except in the channels listed in
        ``channel_pad_1``, where the whole plane is first set to 1.

        Returns:
            (padded, [top, bottom, left, right]) where
            ``padded[..., top:bottom, left:right]`` is ``input``.
        """
        h = input.shape[-2]
        w = input.shape[-1]
        require_len_unit = 16
        residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h)
        residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w)
        # new_zeros keeps input's device and dtype.
        enlarge_input = input.new_zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w))
        if channel_pad_1 is not None:
            for channel in channel_pad_1:
                enlarge_input[:, channel] = 1
        anchor_h = residual_h // 2
        anchor_w = residual_w // 2
        enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
        return enlarge_input, [anchor_h, anchor_h + h, anchor_w, anchor_w + w]

    def forward_3P(self, mask, context, rgb, disp, edge, unit_length=128, cuda=None):
        """Predict edges for inputs of arbitrary H x W without tracking gradients.

        ``disp`` is divided by its maximum before use; the division is skipped
        when that maximum is 0 so the input is not filled with NaN/inf. The
        concatenated input is zero-padded (centred) to multiples of
        ``unit_length``, run through :meth:`forward`, and cropped back.

        Args:
            cuda: device for the padded input. Defaults to the inputs' device.
        """
        with torch.no_grad():
            disp_max = disp.max()
            disp_norm = disp / disp_max if disp_max != 0 else disp
            input = torch.cat((rgb, disp_norm, edge, context, mask), dim=1)
            device = input.device if cuda is None else cuda
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
            edge_output = self.forward(enlarge_input)
            return edge_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

    def forward(self, x, refine_border=False):
        """Return the sigmoid edge map for ``x``.

        If ``refine_border`` is set, ``x`` is first padded to a multiple of 16.
        Channel 5 of the padding is filled with 1; given forward_3P's ordering
        this is presumably the context channel (TODO confirm). The output is
        cropped back afterwards.
        """
        if refine_border:
            x, anchor = self.add_border(x, [5])
        x1 = self.encoder_0(x)
        x2 = self.encoder_1(x1)
        x3 = self.encoder_2(x2)
        x4 = self.middle(x3)
        x5 = self.decoder_0(torch.cat((x4, x3), dim=1))
        x6 = self.decoder_1(torch.cat((x5, x2), dim=1))
        x7 = self.decoder_2(torch.cat((x6, x1), dim=1))
        x = torch.sigmoid(x7)
        if refine_border:
            x = x[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
        return x
class Inpaint_Color_Net(nn.Module):
    """Seven-level partial-convolution U-Net that inpaints RGB colour.

    The network takes 6 input channels, which ``forward_3P`` assembles as
    (rgb, edge, context, mask). The output is a 3-channel image passed
    through a sigmoid.
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest', add_hole_mask=False, add_two_layer=False, add_border=False):
        """Build the network.

        ``layer_size`` is stored, but the architecture is hard-wired to 7
        levels. ``add_hole_mask``, ``add_two_layer`` and ``add_border`` are
        accepted for API compatibility and are unused.
        """
        super().__init__()
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        in_channels = 6
        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7')
        self.enc_2 = PCBActiv(64, 128, sample='down-5')
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        self.enc_5 = PCBActiv(512, 512, sample='down-3')
        self.enc_6 = PCBActiv(512, 512, sample='down-3')
        self.enc_7 = PCBActiv(512, 512, sample='down-3')
        self.dec_7 = PCBActiv(512+512, 512, activ='leaky')
        self.dec_6 = PCBActiv(512+512, 512, activ='leaky')
        self.dec_5A = PCBActiv(512 + 512, 512, activ='leaky')
        self.dec_4A = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3A = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2A = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1A = PCBActiv(64 + in_channels, 3, bn=False, activ=None, conv_bias=True)
        # A second decoder branch (dec_5B .. dec_1B, ending in a 1-channel output)
        # was disabled in the original code; only the "A" branch is built.

    def cat(self, A, B):
        """Concatenate two tensors along the channel axis (dim 1)."""
        return torch.cat((A, B), dim=1)

    def upsample(self, feat, mask):
        """Upsample features (``upsampling_mode``) and mask (nearest) by 2x."""
        feat = F.interpolate(feat, scale_factor=2, mode=self.upsampling_mode)
        mask = F.interpolate(mask, scale_factor=2, mode='nearest')
        return feat, mask

    def forward_3P(self, mask, context, rgb, edge, unit_length=128, cuda=None):
        """Inpaint colour for inputs of arbitrary H x W without tracking gradients.

        Concatenates (rgb, edge, context, mask) and zero-pads (centred) to
        multiples of ``unit_length``. It then runs :meth:`forward` and crops the
        result back to the original H x W.

        Args:
            cuda: device for the padded input. Defaults to the inputs' device.
        """
        with torch.no_grad():
            input = torch.cat((rgb, edge, context, mask), dim=1)
            device = input.device if cuda is None else cuda
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
            rgb_output = self.forward(enlarge_input)
            return rgb_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

    def forward(self, input, add_border=False):
        """Run the U-Net and return the sigmoid RGB output.

        The last two channels of ``input`` are summed and clamped to [0, 1], then
        repeated over all channels to form the partial-conv mask.
        ``add_border`` is unused and kept for API compatibility.
        """
        input_mask = (input[:, -2:-1] + input[:, -1:]).clamp(0, 1)
        # Level 0 is the raw input; levels 1..7 are the encoder outputs.
        feats = [input]
        masks = [input_mask.repeat((1, input.shape[1], 1, 1))]
        for enc in (self.enc_1, self.enc_2, self.enc_3, self.enc_4,
                    self.enc_5, self.enc_6, self.enc_7):
            f, h = enc(feats[-1], masks[-1])
            feats.append(f)
            masks.append(h)

        # Each decoder consumes the upsampled previous output plus the skip from
        # one level up. dec_7 uses level 6, and so on down to dec_2A with level 1.
        o, k = self.upsample(feats[7], masks[7])
        decoders = (self.dec_7, self.dec_6, self.dec_5A,
                    self.dec_4A, self.dec_3A, self.dec_2A)
        for level, dec in zip(range(6, 0, -1), decoders):
            o, k = dec(self.cat(o, feats[level]), self.cat(k, masks[level]))
            o, k = self.upsample(o, k)
        o, _ = self.dec_1A(self.cat(o, feats[0]), self.cat(k, masks[0]))
        return torch.sigmoid(o)

    def train(self, mode=True):
        """Set training mode and return ``self``, like ``nn.Module.train``.

        When ``freeze_enc_bn`` is set, BatchNorm2d layers whose name contains
        'enc' are put back into eval mode. Their running statistics are then
        not updated, but their affine parameters still receive gradients.
        """
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
                    module.eval()
        # nn.Module.eval() returns self.train(False); without this return,
        # net.eval() and net.train() would evaluate to None.
        return self
class Discriminator(BaseNetwork):
    """Fully convolutional discriminator built from five 4x4 convolutions.

    The first three convolutions have stride 2 and the last two stride 1. The
    output is a 1-channel score map, optionally passed through a sigmoid.
    ``forward`` also returns every stage's activation.
    """

    def __init__(self, use_sigmoid=True, use_spectral_norm=True, init_weights=True, in_channels=None):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        # Spectral norm replaces the bias: convs are bias-free when it is on.
        with_bias = not use_spectral_norm

        def sn_conv(c_in, c_out, stride):
            return spectral_norm(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4,
                          stride=stride, padding=1, bias=with_bias),
                use_spectral_norm)

        # `features` is an alias of `conv1`. Both attribute names are registered
        # as sub-modules, so keep the alias for checkpoint compatibility.
        self.conv1 = self.features = nn.Sequential(
            sn_conv(in_channels, 64, 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv2 = nn.Sequential(sn_conv(64, 128, 2), nn.LeakyReLU(0.2, inplace=True))
        self.conv3 = nn.Sequential(sn_conv(128, 256, 2), nn.LeakyReLU(0.2, inplace=True))
        self.conv4 = nn.Sequential(sn_conv(256, 512, 1), nn.LeakyReLU(0.2, inplace=True))
        self.conv5 = nn.Sequential(sn_conv(512, 1, 1))
        if init_weights:
            self.init_weights()

    def forward(self, x):
        """Return ``(scores, [conv1, ..., conv5 activations])``."""
        stage_outputs = []
        current = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            current = stage(current)
            stage_outputs.append(current)
        scores = torch.sigmoid(current) if self.use_sigmoid else current
        return scores, stage_outputs
class ResnetBlock(nn.Module):
    """Residual block: x + IN(conv3x3(pad(LeakyReLU(IN(dilated_conv3x3(pad(x))))))).

    Both convolutions are spectrally normalised and bias-free. Reflection
    padding (``dilation`` pixels, then 1 pixel) keeps the spatial size.
    """

    def __init__(self, dim, dilation=1):
        super(ResnetBlock, self).__init__()
        layers = [
            nn.ReflectionPad2d(dilation),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3,
                                    padding=0, dilation=dilation, bias=False), True),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.LeakyReLU(negative_slope=0.2),
            nn.ReflectionPad2d(1),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3,
                                    padding=0, dilation=1, bias=False), True),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch, with no activation after the sum."""
        # Rationale for omitting the final ReLU:
        # http://torch.ch/blog/2016/02/04/resnets.html
        residual = self.conv_block(x)
        return x + residual
def spectral_norm(module, mode=True):
    """Optionally apply ``torch.nn.utils.spectral_norm`` to ``module``.

    Returns the wrapped module when ``mode`` is truthy. Otherwise ``module`` is
    returned unchanged.
    """
    return nn.utils.spectral_norm(module) if mode else module