sjc / ncsn /layers.py
amankishore's picture
Updated app.py
7a11626
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from .normalization import *
from functools import partial
import math
import torch.nn.init as init
def get_act(config):
    """Build the activation selected by ``config.model.nonlinearity``.

    The name is lower-cased before matching. Recognised names are
    'elu', 'relu', 'lrelu' (LeakyReLU with negative slope 0.2) and
    'swish' (``x * sigmoid(x)``, returned as a plain function rather
    than an ``nn.Module``).

    Raises:
        NotImplementedError: if the name matches none of the above.
    """
    name = config.model.nonlinearity.lower()

    if name == 'swish':
        def swish(x):
            return x * torch.sigmoid(x)
        return swish

    # Zero-argument factories so a fresh module is built on every call.
    factories = {
        'elu': nn.ELU,
        'relu': nn.ReLU,
        'lrelu': lambda: nn.LeakyReLU(negative_slope=0.2),
    }
    if name not in factories:
        raise NotImplementedError('activation function does not exist!')
    return factories[name]()
def spectral_norm(layer, n_iters=1):
    """Apply ``torch.nn.utils.spectral_norm`` to ``layer``.

    ``n_iters`` is forwarded as ``n_power_iterations``. The value returned
    by the PyTorch utility is handed back unchanged.
    """
    apply_spectral_norm = torch.nn.utils.spectral_norm
    return apply_spectral_norm(layer, n_power_iterations=n_iters)
def conv1x1(in_planes, out_planes, stride=1, bias=True, spec_norm=False):
    """Create a 1x1 ``nn.Conv2d`` with zero padding.

    When ``spec_norm`` is truthy the layer is passed through
    ``spectral_norm`` before being returned.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=bias,
    )
    return spectral_norm(layer) if spec_norm else layer
def conv3x3(in_planes, out_planes, stride=1, bias=True, spec_norm=False):
    """Create a 3x3 ``nn.Conv2d`` with padding 1.

    When ``spec_norm`` is truthy the layer is passed through
    ``spectral_norm`` before being returned.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=bias,
    )
    return spectral_norm(layer) if spec_norm else layer
def stride_conv3x3(in_planes, out_planes, kernel_size, bias=True, spec_norm=False):
    """Create a stride-2 ``nn.Conv2d`` with padding ``kernel_size // 2``.

    Despite the name, the kernel size is taken from ``kernel_size``.
    When ``spec_norm`` is truthy the layer is passed through
    ``spectral_norm`` before being returned.
    """
    half_kernel = kernel_size // 2
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=2,
        padding=half_kernel,
        bias=bias,
    )
    return spectral_norm(layer) if spec_norm else layer
def dilated_conv3x3(in_planes, out_planes, dilation, bias=True, spec_norm=False):
    """Create a 3x3 dilated ``nn.Conv2d`` whose padding equals its dilation.

    When ``spec_norm`` is truthy the layer is passed through
    ``spectral_norm`` before being returned.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        padding=dilation,
        dilation=dilation,
        bias=bias,
    )
    return spectral_norm(layer) if spec_norm else layer
class CRPBlock(nn.Module):
    """Chain of (5x5 pool -> bias-free 3x3 conv) stages with running residual sums.

    The input is first passed through ``act``. Each stage pools and convolves
    the previous stage's output and adds the result to the running total.
    """

    def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True, spec_norm=False):
        super().__init__()
        self.convs = nn.ModuleList(
            [conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm)
             for _ in range(n_stages)]
        )
        self.n_stages = n_stages
        # The attribute is called ``maxpool`` even when average pooling is selected.
        pool_cls = nn.MaxPool2d if maxpool else nn.AvgPool2d
        self.maxpool = pool_cls(kernel_size=5, stride=1, padding=2)
        self.act = act

    def forward(self, x):
        x = self.act(x)
        branch = x
        for stage in range(self.n_stages):
            branch = self.convs[stage](self.maxpool(branch))
            x = branch + x
        return x
class CondCRPBlock(nn.Module):
    """Conditional variant of the chained residual pooling block.

    Each stage applies a conditional normalizer (called as ``norm(x, y)``),
    a 5x5 average pool and a bias-free 3x3 conv, then adds the result to the
    running total. The input is passed through ``act`` first.
    """

    def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU(), spec_norm=False):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.normalizer = normalizer
        # Build norm and conv in alternation so layer construction order is unchanged.
        for _ in range(n_stages):
            stage_norm = normalizer(features, num_classes, bias=True)
            self.norms.append(stage_norm)
            stage_conv = conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm)
            self.convs.append(stage_conv)
        self.n_stages = n_stages
        # Named ``maxpool`` but this block always uses average pooling.
        self.maxpool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
        self.act = act

    def forward(self, x, y):
        x = self.act(x)
        branch = x
        for stage in range(self.n_stages):
            normed = self.norms[stage](branch, y)
            branch = self.convs[stage](self.maxpool(normed))
            x = branch + x
        return x
class RCUBlock(nn.Module):
    """Stack of residual conv units.

    The block holds ``n_blocks`` residual units. Each unit applies
    ``n_stages`` rounds of (``act`` -> bias-free 3x3 conv) and then adds the
    unit's input back on.

    Convolutions are registered as attributes named ``'{block}_{stage}_conv'``
    (1-based), which therefore appear in the module's ``state_dict`` keys.

    Args:
        features: channel count used for every convolution's input and output.
        n_blocks: number of residual units.
        n_stages: number of act/conv rounds inside each unit.
        act: activation callable applied before each convolution.
        spec_norm: forwarded to ``conv3x3``.
    """

    def __init__(self, features, n_blocks, n_stages, act=nn.ReLU(), spec_norm=False):
        super().__init__()

        for i in range(n_blocks):
            for j in range(n_stages):
                setattr(self, '{}_{}_conv'.format(i + 1, j + 1),
                        conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm))

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act

    def forward(self, x):
        for i in range(self.n_blocks):
            residual = x
            for j in range(self.n_stages):
                x = self.act(x)
                x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)

            # Out-of-place add: with n_stages == 0 ``x`` still aliases the
            # caller's tensor, and an in-place ``+=`` would overwrite it (and
            # raise on a leaf tensor that requires grad).
            x = x + residual
        return x
class CondRCUBlock(nn.Module):
    """Stack of conditional residual conv units.

    The block holds ``n_blocks`` residual units. Each unit applies
    ``n_stages`` rounds of (conditional norm -> ``act`` -> bias-free 3x3 conv)
    and then adds the unit's input back on.

    Layers are registered as attributes named ``'{block}_{stage}_norm'`` and
    ``'{block}_{stage}_conv'`` (1-based), which therefore appear in the
    module's ``state_dict`` keys.

    Args:
        features: channel count used for every norm and convolution.
        n_blocks: number of residual units.
        n_stages: number of norm/act/conv rounds inside each unit.
        num_classes: forwarded to ``normalizer`` as its second argument.
        normalizer: factory called as ``normalizer(features, num_classes, bias=True)``;
            the resulting layer is called as ``layer(x, y)``.
        act: activation callable applied after each norm.
        spec_norm: forwarded to ``conv3x3``.
    """

    def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU(), spec_norm=False):
        super().__init__()

        for i in range(n_blocks):
            for j in range(n_stages):
                setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
                setattr(self, '{}_{}_conv'.format(i + 1, j + 1),
                        conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm))

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act
        self.normalizer = normalizer

    def forward(self, x, y):
        for i in range(self.n_blocks):
            residual = x
            for j in range(self.n_stages):
                x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
                x = self.act(x)
                x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)

            # Out-of-place add: with n_stages == 0 ``x`` still aliases the
            # caller's tensor, and an in-place ``+=`` would overwrite it (and
            # raise on a leaf tensor that requires grad).
            x = x + residual
        return x
class MSFBlock(nn.Module):
    """Fuse several feature maps by conv, bilinear resize and summation.

    Each input ``xs[i]`` goes through its own 3x3 convolution mapping
    ``in_planes[i]`` channels to ``features`` channels, is bilinearly resized
    (``align_corners=True``) to ``shape``, and the results are summed.
    """

    def __init__(self, in_planes, features, spec_norm=False):
        """
        :param in_planes: list or tuple with the channel count of each input
        :param features: number of output channels
        :param spec_norm: forwarded to ``conv3x3``
        """
        super().__init__()
        assert isinstance(in_planes, (list, tuple))
        self.convs = nn.ModuleList()
        self.features = features

        for planes in in_planes:
            self.convs.append(conv3x3(planes, features, stride=1, bias=True, spec_norm=spec_norm))

    def forward(self, xs, shape):
        """
        :param xs: sequence of input tensors, one per entry of ``in_planes``
        :param shape: target spatial size passed to ``F.interpolate``
        :return: tensor of size ``(xs[0].shape[0], features, *shape)``
        """
        # Allocate the accumulator with the input's dtype as well as its device;
        # the default dtype would silently turn half/double inputs into float32.
        sums = torch.zeros(xs[0].shape[0], self.features, *shape,
                           dtype=xs[0].dtype, device=xs[0].device)
        for i, conv in enumerate(self.convs):
            h = conv(xs[i])
            h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
            sums += h
        return sums
class CondMSFBlock(nn.Module):
    """Conditional multi-input fusion: norm, conv, bilinear resize, then sum.

    Each input ``xs[i]`` is normalized by its own conditional layer (called as
    ``norm(x, y)``), convolved from ``in_planes[i]`` to ``features`` channels,
    bilinearly resized (``align_corners=True``) to ``shape``, and the results
    are summed.
    """

    def __init__(self, in_planes, features, num_classes, normalizer, spec_norm=False):
        """
        :param in_planes: list or tuple with the channel count of each input
        :param features: number of output channels
        :param num_classes: forwarded to ``normalizer`` as its second argument
        :param normalizer: factory called as ``normalizer(planes, num_classes, bias=True)``
        :param spec_norm: forwarded to ``conv3x3``
        """
        super().__init__()
        assert isinstance(in_planes, (list, tuple))

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.features = features
        self.normalizer = normalizer

        # Conv before norm for each input, so layer construction order is unchanged.
        for planes in in_planes:
            self.convs.append(conv3x3(planes, features, stride=1, bias=True, spec_norm=spec_norm))
            self.norms.append(normalizer(planes, num_classes, bias=True))

    def forward(self, xs, y, shape):
        """
        :param xs: sequence of input tensors, one per entry of ``in_planes``
        :param y: conditioning value handed to every normalizer
        :param shape: target spatial size passed to ``F.interpolate``
        :return: tensor of size ``(xs[0].shape[0], features, *shape)``
        """
        # Allocate the accumulator with the input's dtype as well as its device;
        # the default dtype would silently turn half/double inputs into float32.
        sums = torch.zeros(xs[0].shape[0], self.features, *shape,
                           dtype=xs[0].dtype, device=xs[0].device)
        for i, conv in enumerate(self.convs):
            h = self.norms[i](xs[i], y)
            h = conv(h)
            h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
            sums += h
        return sums
class RefineBlock(nn.Module):
    """Refinement stage: per-input RCU, optional multi-input fusion, CRP, output RCU.

    One ``RCUBlock`` is built per entry of ``in_planes``. When more than one
    input is configured, their adapted features are fused by ``self.msf``;
    otherwise the single adapted feature is used directly. The result passes
    through a ``CRPBlock`` and a final ``RCUBlock`` (three units when ``end``
    is true, otherwise one).

    Note that ``self.msf`` is only created when ``start`` is false, while
    ``forward`` uses it whenever there is more than one input.
    """

    def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True, spec_norm=False):
        super().__init__()

        assert isinstance(in_planes, (tuple, list))
        n_blocks = len(in_planes)
        self.n_blocks = n_blocks

        self.adapt_convs = nn.ModuleList(
            [RCUBlock(in_planes[idx], 2, 2, act, spec_norm=spec_norm) for idx in range(n_blocks)]
        )

        n_output_units = 3 if end else 1
        self.output_convs = RCUBlock(features, n_output_units, 2, act, spec_norm=spec_norm)

        if not start:
            self.msf = MSFBlock(in_planes, features, spec_norm=spec_norm)

        self.crp = CRPBlock(features, 2, act, maxpool=maxpool, spec_norm=spec_norm)

    def forward(self, xs, output_shape):
        assert isinstance(xs, (tuple, list))

        adapted = [self.adapt_convs[idx](feat) for idx, feat in enumerate(xs)]

        # ``output_shape`` only matters when several inputs have to be fused.
        fused = self.msf(adapted, output_shape) if self.n_blocks > 1 else adapted[0]

        return self.output_convs(self.crp(fused))
class CondRefineBlock(nn.Module):
    """Conditional refinement stage: per-input RCU, optional fusion, CRP, output RCU.

    One ``CondRCUBlock`` is built per entry of ``in_planes``. When more than
    one input is configured, their adapted features are fused by ``self.msf``;
    otherwise the single adapted feature is used directly. The result passes
    through a ``CondCRPBlock`` and a final ``CondRCUBlock`` (three units when
    ``end`` is true, otherwise one). The conditioning value ``y`` is handed to
    every sub-block.

    Note that ``self.msf`` is only created when ``start`` is false, while
    ``forward`` uses it whenever there is more than one input.
    """

    def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False, spec_norm=False):
        super().__init__()

        assert isinstance(in_planes, (tuple, list))
        n_blocks = len(in_planes)
        self.n_blocks = n_blocks

        self.adapt_convs = nn.ModuleList(
            [CondRCUBlock(in_planes[idx], 2, 2, num_classes, normalizer, act, spec_norm=spec_norm)
             for idx in range(n_blocks)]
        )

        n_output_units = 3 if end else 1
        self.output_convs = CondRCUBlock(features, n_output_units, 2, num_classes, normalizer, act,
                                         spec_norm=spec_norm)

        if not start:
            self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer, spec_norm=spec_norm)

        self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act, spec_norm=spec_norm)

    def forward(self, xs, y, output_shape):
        assert isinstance(xs, (tuple, list))

        adapted = [self.adapt_convs[idx](feat, y) for idx, feat in enumerate(xs)]

        # ``output_shape`` only matters when several inputs have to be fused.
        fused = self.msf(adapted, y, output_shape) if self.n_blocks > 1 else adapted[0]

        return self.output_convs(self.crp(fused, y), y)
class ConvMeanPool(nn.Module):
    """Stride-1 convolution followed by averaging over 2x2 strided sub-grids.

    With ``adjust_padding`` the convolution is preceded by
    ``nn.ZeroPad2d((1, 0, 1, 0))`` and ``self.conv`` becomes an
    ``nn.Sequential`` (so the conv's parameters live under ``conv.1``);
    otherwise ``self.conv`` is the convolution itself.
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False, spec_norm=False):
        super().__init__()
        base_conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1,
                              padding=kernel_size // 2, bias=biases)
        if spec_norm:
            base_conv = spectral_norm(base_conv)

        if adjust_padding:
            self.conv = nn.Sequential(nn.ZeroPad2d((1, 0, 1, 0)), base_conv)
        else:
            self.conv = base_conv

    def forward(self, inputs):
        feats = self.conv(inputs)
        # Four interleaved sub-grids: even/odd rows crossed with even/odd columns.
        even_even = feats[:, :, ::2, ::2]
        odd_even = feats[:, :, 1::2, ::2]
        even_odd = feats[:, :, ::2, 1::2]
        odd_odd = feats[:, :, 1::2, 1::2]
        return sum([even_even, odd_even, even_odd, odd_odd]) / 4.
class MeanPoolConv(nn.Module):
    """Average over 2x2 strided sub-grids, then apply a stride-1 convolution.

    The convolution uses padding ``kernel_size // 2`` and is wrapped with
    ``spectral_norm`` when ``spec_norm`` is truthy.
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, spec_norm=False):
        super().__init__()
        conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1,
                         padding=kernel_size // 2, bias=biases)
        self.conv = spectral_norm(conv) if spec_norm else conv

    def forward(self, inputs):
        # Four interleaved sub-grids: even/odd rows crossed with even/odd columns.
        even_even = inputs[:, :, ::2, ::2]
        odd_even = inputs[:, :, 1::2, ::2]
        even_odd = inputs[:, :, ::2, 1::2]
        odd_odd = inputs[:, :, 1::2, 1::2]
        pooled = sum([even_even, odd_even, even_odd, odd_odd]) / 4.
        return self.conv(pooled)
class UpsampleConv(nn.Module):
    """Double the spatial size via channel tiling + PixelShuffle, then convolve.

    The input is concatenated with itself four times along dim 1, passed
    through ``nn.PixelShuffle(2)`` and then through a stride-1 convolution
    with padding ``kernel_size // 2``.

    NOTE(review): for inputs with more than one channel, PixelShuffle appears
    to fill each output channel's 2x2 sub-pixels from consecutive tiled
    channels, so this does not look like a per-channel nearest-neighbour
    upsample. Confirm the intent before changing it, since trained weights
    depend on the current behaviour.
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, spec_norm=False):
        super().__init__()
        conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1,
                         padding=kernel_size // 2, bias=biases)
        self.conv = spectral_norm(conv) if spec_norm else conv
        self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, inputs):
        tiled = torch.cat([inputs] * 4, dim=1)
        upsampled = self.pixelshuffle(tiled)
        return self.conv(upsampled)
class ConditionalResidualBlock(nn.Module):
    """Pre-activation residual block with conditional normalization.

    The main path is norm -> act -> conv1 -> norm -> act -> conv2, where both
    normalization layers are called as ``norm(x, y)``. The skip path is the
    identity when ``input_dim == output_dim`` and ``resample is None``;
    otherwise it is a learned ``self.shortcut`` layer.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        num_classes: forwarded to ``normalization`` as its second argument.
        resample: ``'down'`` or ``None``.
        act: activation callable.
        normalization: factory called as ``normalization(channels, num_classes)``.
        adjust_padding: forwarded to ``ConvMeanPool`` on the non-dilated
            ``'down'`` path.
        dilation: when not ``None``, dilated 3x3 convolutions are used for
            conv1, conv2 and the shortcut.
        spec_norm: forwarded to every convolution helper.

    Raises:
        Exception: if ``resample`` is neither ``'down'`` nor ``None``.
    """

    def __init__(self, input_dim, output_dim, num_classes, resample=None, act=nn.ELU(),
                 normalization=ConditionalBatchNorm2d, adjust_padding=False, dilation=None, spec_norm=False):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization
        if resample == 'down':
            if dilation is not None:
                self.conv1 = dilated_conv3x3(input_dim, input_dim, dilation=dilation, spec_norm=spec_norm)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = dilated_conv3x3(input_dim, output_dim, dilation=dilation, spec_norm=spec_norm)
                conv_shortcut = partial(dilated_conv3x3, dilation=dilation, spec_norm=spec_norm)
            else:
                self.conv1 = conv3x3(input_dim, input_dim, spec_norm=spec_norm)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding, spec_norm=spec_norm)
                conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding, spec_norm=spec_norm)

        elif resample is None:
            if dilation is not None:
                conv_shortcut = partial(dilated_conv3x3, dilation=dilation, spec_norm=spec_norm)
                self.conv1 = dilated_conv3x3(input_dim, output_dim, dilation=dilation, spec_norm=spec_norm)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = dilated_conv3x3(output_dim, output_dim, dilation=dilation, spec_norm=spec_norm)
            else:
                # A bare ``nn.Conv2d`` factory cannot be called with only
                # (input_dim, output_dim) because ``kernel_size`` is required,
                # so a channel-changing block used to fail here. Use a 1x1
                # conv, matching ``ResidualBlock``.
                conv_shortcut = partial(conv1x1, spec_norm=spec_norm)
                self.conv1 = conv3x3(input_dim, output_dim, spec_norm=spec_norm)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = conv3x3(output_dim, output_dim, spec_norm=spec_norm)
        else:
            raise Exception('invalid resample value')

        # A learned shortcut is only needed when the identity cannot be used.
        if output_dim != input_dim or resample is not None:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim, num_classes)

    def forward(self, x, y):
        output = self.normalize1(x, y)
        output = self.non_linearity(output)
        output = self.conv1(output)
        output = self.normalize2(output, y)
        output = self.non_linearity(output)
        output = self.conv2(output)

        if self.output_dim == self.input_dim and self.resample is None:
            shortcut = x
        else:
            shortcut = self.shortcut(x)

        return shortcut + output
class ResidualBlock(nn.Module):
    """Pre-activation residual block with unconditional normalization.

    The main path is norm -> act -> conv1 -> norm -> act -> conv2. The skip
    path is the identity when ``input_dim == output_dim`` and
    ``resample is None``; otherwise it is a learned ``self.shortcut`` layer
    (dilated 3x3 conv, 1x1 ``ConvMeanPool`` or 1x1 conv, depending on
    ``dilation`` and ``resample``).

    Raises ``Exception('invalid resample value')`` when ``resample`` is
    neither ``'down'`` nor ``None``.
    """

    def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
                 normalization=nn.BatchNorm2d, adjust_padding=False, dilation=None, spec_norm=False):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization

        downsample = resample == 'down'
        if not downsample and resample is not None:
            raise Exception('invalid resample value')

        # On the 'down' path conv1 keeps the input width and conv2 changes it;
        # otherwise conv1 already maps to the output width.
        hidden_dim = input_dim if downsample else output_dim

        # Layers are created in the order conv1, normalize2, conv2 in every branch.
        if dilation is not None:
            make_dilated = partial(dilated_conv3x3, dilation=dilation, spec_norm=spec_norm)
            self.conv1 = make_dilated(input_dim, hidden_dim)
            self.normalize2 = normalization(hidden_dim)
            self.conv2 = make_dilated(hidden_dim, output_dim)
            conv_shortcut = make_dilated
        elif downsample:
            self.conv1 = conv3x3(input_dim, input_dim, spec_norm=spec_norm)
            self.normalize2 = normalization(input_dim)
            self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding,
                                      spec_norm=spec_norm)
            conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding,
                                    spec_norm=spec_norm)
        else:
            self.conv1 = conv3x3(input_dim, output_dim, spec_norm=spec_norm)
            self.normalize2 = normalization(output_dim)
            self.conv2 = conv3x3(output_dim, output_dim, spec_norm=spec_norm)
            conv_shortcut = partial(conv1x1, spec_norm=spec_norm)

        needs_projection = output_dim != input_dim or resample is not None
        if needs_projection:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim)

    def forward(self, x):
        hidden = self.conv1(self.non_linearity(self.normalize1(x)))
        hidden = self.conv2(self.non_linearity(self.normalize2(hidden)))

        use_identity = self.output_dim == self.input_dim and self.resample is None
        skip = x if use_identity else self.shortcut(x)
        return skip + hidden