import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Not used by the networks in this module; imported optionally so the
    # module still loads where matplotlib is not installed.
    import matplotlib.pyplot as plt  # noqa: F401
except ImportError:
    plt = None


def _center_pad(x, unit_length, device=None, fill_channels=()):
    """Pad the last two dims of ``x`` up to multiples of ``unit_length``.

    ``x`` is copied into the centre of a new float32 canvas; when the padding
    is odd, the extra row/column goes to the bottom/right.  Padding is 0,
    except in the channels listed in ``fill_channels``, where it is 1.

    Args:
        x: tensor laid out as (N, C, H, W).
        unit_length: H and W of the canvas are rounded up to multiples of this.
        device: device of the canvas; defaults to ``x.device``.
        fill_channels: channel indices whose padding is set to 1.

    Returns:
        ``(canvas, [top, top + H, left, left + W])``, where the list is the
        crop box that recovers ``x`` from ``canvas``.
    """
    h, w = x.shape[-2], x.shape[-1]
    pad_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
    pad_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
    canvas = torch.zeros((x.shape[0], x.shape[1], h + pad_h, w + pad_w),
                         device=x.device if device is None else device)
    for channel in fill_channels:
        canvas[:, channel] = 1
    top, left = pad_h // 2, pad_w // 2
    canvas[..., top:top + h, left:left + w] = x
    return canvas, [top, top + h, left, left + w]


class BaseNetwork(nn.Module):
    """``nn.Module`` base class that adds a configurable weight initializer."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialize the weights of every Conv*/Linear/BatchNorm2d submodule.

        Args:
            init_type: 'normal' | 'xavier' | 'kaiming' | 'orthogonal'.
            gain: std for 'normal', gain for 'xavier'/'orthogonal', and the std
                of BatchNorm2d weights around 1.0.

        Raises:
            NotImplementedError: if ``init_type`` is not one of the above and
                the network contains a Conv*/Linear layer.

        NOTE(review): for layers wrapped in ``spectral_norm`` the ``weight``
        attribute is recomputed from ``weight_orig`` before each forward, so
        this initialization may not persist for them; verify if it matters.

        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        """
        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                else:
                    raise NotImplementedError(
                        'initialization method [{}] is not implemented'.format(init_type))
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif classname.find('BatchNorm2d') != -1:
                # affine=False BatchNorm has no weight/bias to initialize.
                if m.weight is not None:
                    nn.init.normal_(m.weight.data, 1.0, gain)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)


def weights_init(init_type='gaussian'):
    """Return a function for ``module.apply`` that initializes Conv*/Linear layers.

    Args:
        init_type: 'gaussian' (N(0, 0.02)), 'xavier', 'kaiming', 'orthogonal'
            or 'default' (leave weights untouched).  Biases are zeroed in
            every case.

    Raises (when the returned function meets a Conv*/Linear layer):
        ValueError: for an unsupported ``init_type``.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                # Explicit raise instead of `assert 0`, which -O would strip.
                raise ValueError("Unsupported initialization: {}".format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun


class PartialConv(nn.Module):
    """Partial convolution: convolve only valid pixels and renormalize.

    See http://masc.cs.gmu.edu/wiki/partialconv.  ``input_conv`` is the
    learnable convolution applied to ``input * mask``; ``mask_conv`` has fixed
    all-ones weights, so it counts valid mask entries in each window.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)
        self.input_conv.apply(weights_init('kaiming'))
        # Number of mask entries one output location sums over:
        # (in_channels / groups) * kernel_h * kernel_w.  Reading it from the
        # weight shape keeps it correct for grouped and non-square kernels
        # (it equals in_channels * k * k for groups=1 and a square kernel).
        self.slide_winsize = self.mask_conv.weight[0].numel()
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        # The mask convolution is fixed: it is never updated by training.
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, input, mask):
        """Apply the partial convolution.

        Args:
            input: feature tensor with ``in_channels`` channels.
            mask: tensor multiplied element-wise with ``input`` and fed to
                ``mask_conv``; presumably 1 = valid, 0 = hole (TODO confirm
                it is always binary).

        Returns:
            ``(output, new_mask)``.  ``new_mask`` is 1 where the window saw a
            non-zero mask sum and 0 elsewhere; ``output`` is 0 at those holes.
        """
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T * (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
        output = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)
        else:
            output_bias = torch.zeros_like(output)

        with torch.no_grad():
            output_mask = self.mask_conv(mask)

        no_update_holes = output_mask == 0
        # Avoid dividing by zero; these positions are zeroed just below.
        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)

        output_pre = ((output - output_bias) * self.slide_winsize) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes, 0.0)

        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
        return output, new_mask


class PCBActiv(nn.Module):
    """PartialConv followed by an optional BatchNorm2d and activation.

    ``sample``: 'down-7' / 'down-5' / 'down-3' give a stride-2 conv with that
    kernel size (padding k // 2); any other value gives a 3x3 stride-1 conv.
    ``activ``: 'relu', 'leaky' (slope 0.2), or anything else for none.
    """

    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='relu',
                 conv_bias=False):
        super().__init__()
        # sample -> (kernel_size, stride, padding)
        configs = {
            'down-5': (5, 2, 2),
            'down-7': (7, 2, 3),
            'down-3': (3, 2, 1),
        }
        kernel_size, stride, padding = configs.get(sample, (3, 1, 1))
        self.conv = PartialConv(in_ch, out_ch, kernel_size, stride, padding,
                                bias=conv_bias)

        # Optional submodules are only created when requested; forward()
        # checks for their presence with hasattr.
        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        if activ == 'relu':
            self.activation = nn.ReLU()
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input, input_mask):
        """Return ``(features, updated_mask)`` after conv, BN and activation."""
        h, h_mask = self.conv(input, input_mask)
        if hasattr(self, 'bn'):
            h = self.bn(h)
        if hasattr(self, 'activation'):
            h = self.activation(h)
        return h, h_mask


class Inpaint_Depth_Net(nn.Module):
    """Partial-convolution U-Net that predicts a 1-channel depth map.

    The input has 4 channels; ``forward_3P`` builds it as
    (depth, edge, context, mask).
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest'):
        super().__init__()
        in_channels = 4
        out_channels = 1
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size

        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True)
        self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512, 512, sample='down-3'))

        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
        self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1 = PCBActiv(64 + in_channels, out_channels, bn=False, activ=None, conv_bias=True)

    def add_border(self, input, mask_flag, PCONV=True):
        """Pad ``input`` (without gradients) to a multiple of ``2 ** layer_size``.

        Padding is 0, except when ``mask_flag`` is set and ``PCONV is False``,
        in which case the whole padded border is 1.

        Returns:
            ``(padded, [top, bottom, left, right])`` crop box of the original.
        """
        with torch.no_grad():
            fill = range(input.shape[1]) if (mask_flag and PCONV is False) else ()
            return _center_pad(input, 2 ** self.layer_size, fill_channels=fill)

    def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None):
        """Inpaint depth for one region: pad, run ``forward`` without grad, crop.

        The inputs are concatenated as (depth, edge, context, mask) along dim 1,
        centre-padded with zeros to a multiple of ``unit_length`` and the output
        is cropped back to the original height and width.

        Args:
            cuda: device for the padded input; defaults to the inputs' device.
        """
        with torch.no_grad():
            feat = torch.cat((depth, edge, context, mask), dim=1)
            device = feat.device if cuda is None else cuda
            enlarge_input, (top, bottom, left, right) = _center_pad(feat, unit_length, device=device)
            depth_output = self.forward(enlarge_input)
            depth_output = depth_output[..., top:bottom, left:right]
        return depth_output

    def forward(self, input_feat, refine_border=False, sample=False, PCONV=True):
        """Run the U-Net on an assembled 4-channel input.

        Args:
            input_feat: the network input.  Its last two channels are summed
                and clamped to [0, 1] to form the validity mask, which is
                repeated across all channels for the partial convolutions.
                H and W should be multiples of ``2 ** layer_size`` so decoder
                features line up with the skip connections.
            refine_border: if True, pad via ``add_border`` first and crop the
                output back afterwards.
            sample: unused; kept for interface compatibility.
            PCONV: forwarded to ``add_border`` for the mask.

        Returns:
            Output of ``dec_1`` (1 channel, no final activation).
        """
        feat = input_feat
        feat_mask = (input_feat[:, -2:-1] + input_feat[:, -1:]).clamp(0, 1).repeat(1, feat.shape[1], 1, 1)
        if refine_border is True:
            feat, anchor = self.add_border(feat, mask_flag=False)
            feat_mask, anchor = self.add_border(feat_mask, mask_flag=True, PCONV=PCONV)

        # Encoder outputs, kept for the skip connections: h_0 is the input.
        h_dict = {'h_0': feat}
        h_mask_dict = {'h_0': feat_mask}
        h_key_prev = 'h_0'
        for i in range(1, self.layer_size + 1):
            l_key = 'enc_{:d}'.format(i)
            h_key = 'h_{:d}'.format(i)
            h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(
                h_dict[h_key_prev], h_mask_dict[h_key_prev])
            h_key_prev = h_key

        h_key = 'h_{:d}'.format(self.layer_size)
        h, h_mask = h_dict[h_key], h_mask_dict[h_key]
        for i in range(self.layer_size, 0, -1):
            enc_h_key = 'h_{:d}'.format(i - 1)
            dec_l_key = 'dec_{:d}'.format(i)
            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)
            # Masks are always upsampled with 'nearest'.
            h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')
            h = torch.cat([h, h_dict[enc_h_key]], dim=1)
            h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1)
            h, h_mask = getattr(self, dec_l_key)(h, h_mask)

        output = h
        if refine_border is True:
            output = output[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
        return output


class Inpaint_Edge_Net(BaseNetwork):
    """Encoder / residual / decoder network that predicts a 1-channel edge map.

    The input has 7 channels; ``forward_3P`` builds it as
    (rgb, normalized disp, edge, context, mask).  Output is passed through
    a sigmoid.
    """

    def __init__(self, residual_blocks=8, init_weights=True):
        super(Inpaint_Edge_Net, self).__init__()
        in_channels = 7
        out_channels = 1
        # 0
        self.encoder_0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))
        # 1
        self.encoder_1 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))
        # 2
        self.encoder_2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True))
        # 3
        blocks = []
        for _ in range(residual_blocks):
            block = ResnetBlock(256, 2)
            blocks.append(block)
        self.middle = nn.Sequential(*blocks)
        # + 3
        self.decoder_0 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=256 + 256, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))
        # + 2
        self.decoder_1 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=128 + 128, out_channels=64, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))
        # + 1
        self.decoder_2 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64 + 64, out_channels=out_channels, kernel_size=7, padding=0),
        )

        if init_weights:
            self.init_weights()

    def add_border(self, input, channel_pad_1=None):
        """Pad ``input`` to a multiple of 16; channels in ``channel_pad_1`` pad with 1.

        Returns:
            ``(padded, [top, bottom, left, right])`` crop box of the original.
        """
        fill = () if channel_pad_1 is None else channel_pad_1
        return _center_pad(input, 16, fill_channels=fill)

    def forward_3P(self, mask, context, rgb, disp, edge, unit_length=128, cuda=None):
        """Inpaint edges for one region: pad, run ``forward`` without grad, crop.

        ``disp`` is divided by its maximum before concatenation.

        Args:
            cuda: device for the padded input; defaults to the inputs' device.
        """
        with torch.no_grad():
            feat = torch.cat((rgb, disp / disp.max(), edge, context, mask), dim=1)
            device = feat.device if cuda is None else cuda
            enlarge_input, (top, bottom, left, right) = _center_pad(feat, unit_length, device=device)
            edge_output = self.forward(enlarge_input)
            edge_output = edge_output[..., top:bottom, left:right]
        return edge_output

    def forward(self, x, refine_border=False):
        """Return the sigmoid edge map for ``x`` (optionally padded then cropped)."""
        if refine_border:
            # Channel 5 is padded with ones rather than zeros.
            x, anchor = self.add_border(x, [5])
        x1 = self.encoder_0(x)
        x2 = self.encoder_1(x1)
        x3 = self.encoder_2(x2)
        x4 = self.middle(x3)
        x5 = self.decoder_0(torch.cat((x4, x3), dim=1))
        x6 = self.decoder_1(torch.cat((x5, x2), dim=1))
        x7 = self.decoder_2(torch.cat((x6, x1), dim=1))
        x = torch.sigmoid(x7)
        if refine_border:
            x = x[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]
        return x


class Inpaint_Color_Net(nn.Module):
    """7-level partial-convolution U-Net that predicts a 3-channel color image.

    The input has 6 channels; ``forward_3P`` builds it as
    (rgb, edge, context, mask).  ``add_hole_mask``, ``add_two_layer`` and
    ``add_border`` are accepted for compatibility but unused.
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest', add_hole_mask=False,
                 add_two_layer=False, add_border=False):
        super().__init__()
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        in_channels = 6
        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7')
        self.enc_2 = PCBActiv(64, 128, sample='down-5')
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        self.enc_5 = PCBActiv(512, 512, sample='down-3')
        self.enc_6 = PCBActiv(512, 512, sample='down-3')
        self.enc_7 = PCBActiv(512, 512, sample='down-3')

        self.dec_7 = PCBActiv(512 + 512, 512, activ='leaky')
        self.dec_6 = PCBActiv(512 + 512, 512, activ='leaky')

        self.dec_5A = PCBActiv(512 + 512, 512, activ='leaky')
        self.dec_4A = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3A = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2A = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1A = PCBActiv(64 + in_channels, 3, bn=False, activ=None, conv_bias=True)

    def cat(self, A, B):
        """Concatenate two tensors along the channel dimension."""
        return torch.cat((A, B), dim=1)

    def upsample(self, feat, mask):
        """Upsample features (``upsampling_mode``) and mask ('nearest') by 2."""
        feat = F.interpolate(feat, scale_factor=2, mode=self.upsampling_mode)
        mask = F.interpolate(mask, scale_factor=2, mode='nearest')
        return feat, mask

    def forward_3P(self, mask, context, rgb, edge, unit_length=128, cuda=None):
        """Inpaint color for one region: pad, run ``forward`` without grad, crop.

        Args:
            cuda: device for the padded input; defaults to the inputs' device.
        """
        with torch.no_grad():
            feat = torch.cat((rgb, edge, context, mask), dim=1)
            device = feat.device if cuda is None else cuda
            enlarge_input, (top, bottom, left, right) = _center_pad(feat, unit_length, device=device)
            rgb_output = self.forward(enlarge_input)
            rgb_output = rgb_output[..., top:bottom, left:right]
        return rgb_output

    def forward(self, input, add_border=False):
        """Return the sigmoid RGB prediction; ``add_border`` is unused.

        The last two input channels are summed and clamped to [0, 1] to form
        the validity mask, repeated across all input channels.
        """
        input_mask = (input[:, -2:-1] + input[:, -1:]).clamp(0, 1)
        feats = [input]
        masks = [input_mask.repeat((1, input.shape[1], 1, 1))]
        for i in range(1, 8):
            f, m = getattr(self, 'enc_{:d}'.format(i))(feats[-1], masks[-1])
            feats.append(f)
            masks.append(m)

        o, k = self.upsample(feats[7], masks[7])
        # Each decoder consumes the upsampled path plus one encoder skip,
        # then its output is upsampled for the next level.
        for dec_name, skip in zip(('dec_7', 'dec_6', 'dec_5A', 'dec_4A', 'dec_3A', 'dec_2A'),
                                  range(6, 0, -1)):
            o, k = getattr(self, dec_name)(self.cat(o, feats[skip]), self.cat(k, masks[skip]))
            o, k = self.upsample(o, k)
        o, k = self.dec_1A(self.cat(o, feats[0]), self.cat(k, masks[0]))
        return torch.sigmoid(o)

    def train(self, mode=True):
        """Override ``train()`` to keep encoder BatchNorm layers frozen if requested.

        Returns ``self`` like ``nn.Module.train`` so calls can be chained.
        """
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
                    module.eval()
        return self


class Discriminator(BaseNetwork):
    """Five-layer convolutional discriminator.

    ``forward`` returns ``(outputs, [conv1, ..., conv5])``; ``outputs`` is
    sigmoid-squashed when ``use_sigmoid``.
    """

    def __init__(self, use_sigmoid=True, use_spectral_norm=True, init_weights=True, in_channels=None):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid

        # `features` aliases `conv1` (same module under two names); kept
        # as-is since saved checkpoints likely contain both key prefixes.
        self.conv1 = self.features = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv3 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv4 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv5 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
        )

        if init_weights:
            self.init_weights()

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        outputs = conv5
        if self.use_sigmoid:
            outputs = torch.sigmoid(conv5)

        return outputs, [conv1, conv2, conv3, conv4, conv5]


class ResnetBlock(nn.Module):
    """Residual block: two spectral-normalized 3x3 convs with reflection padding.

    The first conv uses ``dilation``; there is no activation after the sum.
    """

    def __init__(self, dim, dilation=1):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(dilation),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=False), True),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.LeakyReLU(negative_slope=0.2),

            nn.ReflectionPad2d(1),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=False), True),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        )

    def forward(self, x):
        out = x + self.conv_block(x)

        # Remove ReLU at the end of the residual block
        # http://torch.ch/blog/2016/02/04/resnets.html

        return out


def spectral_norm(module, mode=True):
    """Wrap ``module`` with ``nn.utils.spectral_norm`` when ``mode`` is truthy."""
    if mode:
        return nn.utils.spectral_norm(module)

    return module