# Source file-view metadata (scraped page residue, not code):
# author: oguzakif | commit: d4b77ac ("init repo") | size: 14.2 kB | scan: "No virus"
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Names of some of the network classes defined in this module.
# NOTE(review): how this list is consumed is not visible here — confirm.
networks = [
    'BaseNetwork',
    'Discriminator',
    'ASPP',
]
# Base model borrows from PEN-NET
# https://github.com/researchmm/PEN-Net-for-Inpainting
class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
def print_network(self):
if isinstance(self, list):
self = self[0]
num_params = 0
for param in self.parameters():
num_params += param.numel()
print('Network [%s] was created. Total number of parameters: %.1f million. '
'To see the architecture, do print(network).' % (type(self).__name__, num_params / 1000000))
def init_weights(self, init_type='normal', gain=0.02):
'''
initialize network's weights
init_type: normal | xavier | kaiming | orthogonal
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
'''
def init_func(m):
classname = m.__class__.__name__
if classname.find('InstanceNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
nn.init.constant_(m.weight.data, 1.0)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
nn.init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'xavier_uniform':
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
elif init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
nn.init.orthogonal_(m.weight.data, gain=gain)
elif init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(init_type, gain)
# temporal patch gan: from Free-form Video Inpainting with 3D Gated Convolution and Temporal PatchGAN in 2019 ICCV
# todo: debug this model
class Discriminator(BaseNetwork):
def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
super(Discriminator, self).__init__()
self.use_sigmoid = use_sigmoid
nf = 64
self.conv = nn.Sequential(
DisBuildingBlock(in_channel=in_channels, out_channel=nf * 1, kernel_size=(3, 5, 5), stride=(1, 2, 2),
padding=1, use_spectral_norm=use_spectral_norm),
# nn.InstanceNorm2d(64, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
DisBuildingBlock(in_channel=nf * 1, out_channel=nf * 2, kernel_size=(3, 5, 5), stride=(1, 2, 2),
padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
# nn.InstanceNorm2d(128, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
DisBuildingBlock(in_channel=nf * 2, out_channel=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
DisBuildingBlock(in_channel=nf * 4, out_channel=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
DisBuildingBlock(in_channel=nf * 4, out_channel=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5),
stride=(1, 2, 2), padding=(1, 2, 2))
)
if init_weights:
self.init_weights()
def forward(self, xs):
# B, C, T, H, W = xs.shape
feat = self.conv(xs)
if self.use_sigmoid:
feat = torch.sigmoid(feat)
return feat
class DisBuildingBlock(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride, padding, use_spectral_norm):
super(DisBuildingBlock, self).__init__()
self.block = self._getBlock(in_channel, out_channel, kernel_size, stride, padding, use_spectral_norm)
def _getBlock(self, in_channel, out_channel, kernel_size, stride, padding, use_spectral_norm):
feature_conv = nn.Conv3d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size,
stride=stride, padding=padding, bias=not use_spectral_norm)
if use_spectral_norm:
feature_conv = nn.utils.spectral_norm(feature_conv)
return feature_conv
def forward(self, inputs):
out = self.block(inputs)
return out
class ASPP(nn.Module):
def __init__(self, input_channels, output_channels, rate=[1, 2, 4, 8]):
super(ASPP, self).__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.rate = rate
for i in range(len(rate)):
self.__setattr__('conv{}'.format(str(i).zfill(2)), nn.Sequential(
nn.Conv2d(input_channels, output_channels // len(rate), kernel_size=3, dilation=rate[i],
padding=rate[i]),
nn.LeakyReLU(negative_slope=0.2, inplace=True)
))
def forward(self, inputs):
inputs = inputs / len(self.rate)
tmp = []
for i in range(len(self.rate)):
tmp.append(self.__getattr__('conv{}'.format(str(i).zfill(2)))(inputs))
output = torch.cat(tmp, dim=1)
return output
class GatedConv2dWithActivation(torch.nn.Module):
"""
Gated Convlution layer with activation (default activation:LeakyReLU)
Params: same as conv2d
Input: The feature from last layer "I"
Output:\phi(f(I))*\sigmoid(g(I))
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
super(GatedConv2dWithActivation, self).__init__()
self.batch_norm = batch_norm
self.activation = activation
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
bias)
self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
self.sigmoid = torch.nn.Sigmoid()
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
def gated(self, mask):
return self.sigmoid(mask)
def forward(self, inputs):
x = self.conv2d(inputs)
mask = self.mask_conv2d(inputs)
if self.activation is not None:
x = self.activation(x) * self.gated(mask)
else:
x = x * self.gated(mask)
if self.batch_norm:
return self.batch_norm2d(x)
else:
return x
class GatedDeConv2dWithActivation(torch.nn.Module):
"""
Gated DeConvlution layer with activation (default activation:LeakyReLU)
resize + conv
Params: same as conv2d
Input: The feature from last layer "I"
Output:\phi(f(I))*\sigmoid(g(I))
"""
def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
bias=True, batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
super(GatedDeConv2dWithActivation, self).__init__()
self.conv2d = GatedConv2dWithActivation(in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, bias, batch_norm, activation)
self.scale_factor = scale_factor
def forward(self, inputs):
# print(input.size())
x = F.interpolate(inputs, scale_factor=self.scale_factor)
return self.conv2d(x)
class SNGatedConv2dWithActivation(torch.nn.Module):
"""
Gated Convolution with spetral normalization
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
super(SNGatedConv2dWithActivation, self).__init__()
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
bias)
self.activation = activation
self.batch_norm = batch_norm
self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
self.sigmoid = torch.nn.Sigmoid()
self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
self.mask_conv2d = torch.nn.utils.spectral_norm(self.mask_conv2d)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
def gated(self, mask):
return self.sigmoid(mask)
def forward(self, inputs):
x = self.conv2d(inputs)
mask = self.mask_conv2d(inputs)
if self.activation is not None:
x = self.activation(x) * self.gated(mask)
else:
x = x * self.gated(mask)
if self.batch_norm:
return self.batch_norm2d(x)
else:
return x
class SNGatedDeConv2dWithActivation(torch.nn.Module):
"""
Gated DeConvlution layer with activation (default activation:LeakyReLU)
resize + conv
Params: same as conv2d
Input: The feature from last layer "I"
Output:\phi(f(I))*\sigmoid(g(I))
"""
def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
bias=True, batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
super(SNGatedDeConv2dWithActivation, self).__init__()
self.conv2d = SNGatedConv2dWithActivation(in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, bias, batch_norm, activation)
self.scale_factor = scale_factor
def forward(self, inputs):
x = F.interpolate(inputs, scale_factor=2)
return self.conv2d(x)
class GatedConv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
activation=nn.LeakyReLU(0.2, inplace=True)):
super(GatedConv3d, self).__init__()
self.input_conv = nn.Conv3d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, bias)
self.gating_conv = nn.Conv3d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, bias)
self.activation = activation
def forward(self, inputs):
feature = self.input_conv(inputs)
if self.activation:
feature = self.activation(feature)
gating = torch.sigmoid(self.gating_conv(inputs))
return feature * gating
class GatedDeconv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, scale_factor, dilation=1, groups=1,
bias=True, activation=nn.LeakyReLU(0.2, inplace=True)):
super().__init__()
self.scale_factor = scale_factor
self.deconv = GatedConv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
activation)
def forward(self, inputs):
inputs = F.interpolate(inputs, scale_factor=(1, self.scale_factor, self.scale_factor))
return self.deconv(inputs)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor