import re
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
from models.utils.batchnorm import SynchronizedBatchNorm2d
class SPADE(nn.Module):
    """Spatially-adaptive normalization layer.

    The input activation is normalized by a parameter-free norm layer. It is
    then modulated by a per-pixel scale ``gamma`` and bias ``beta``, which a
    small conv network predicts from the segmentation map.

    Args:
        config_text: spec string of the form ``spade<norm><k>x<k>``, e.g.
            ``'spadeinstance3x3'``. ``<norm>`` is ``instance``, ``syncbatch``
            or ``batch``. The first ``<k>`` (a single digit) is the kernel
            size of the modulation convs.
        norm_nc: number of channels of the activation being normalized.
        label_nc: number of channels of the segmentation map.

    Raises:
        AssertionError: if ``config_text`` does not start with ``'spade'``.
        ValueError: if ``config_text`` cannot be parsed or names an unknown
            param-free norm type.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        # Raw string: '\D' / '\d' are regex classes, not Python escapes.
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        if parsed is None:
            raise ValueError('%s is not a valid SPADE config string' % config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        # affine=False: the learned scale/bias come from the segmap instead.
        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        # 'same' padding for odd kernel sizes, so spatial size is preserved.
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with segmap-conditioned gamma/beta.

        Args:
            x: 4-D activation ``(N, norm_nc, H, W)``.
            segmap: 4-D segmentation map ``(N, label_nc, h, w)``. It is
                resized (nearest) to ``(H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias; (1 + gamma) makes gamma == 0 an identity scale
        out = normalized * (1 + gamma) + beta
        return out
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are SPADE layers.

    Main path: (SPADE -> leaky ReLU -> 3x3 conv) twice, mapping
    ``fin -> fmiddle -> fout`` with ``fmiddle = min(fin, fout)``.
    Shortcut: identity when ``fin == fout``. Otherwise it is a learned
    SPADE + bias-free 1x1 conv that maps ``fin -> fout``.

    Args:
        fin: input channels.
        fout: output channels.
        norm_G: normalization spec. If it contains ``'spectral'``, spectral
            norm wraps every conv. The rest of the string (with
            ``'spectral'`` removed) is passed to SPADE, e.g.
            ``'spectralspadesyncbatch3x3'`` -> ``'spadesyncbatch3x3'``.
        semantic_nc: channels of the segmentation map.
    """

    def __init__(self, fin, fout, norm_G, semantic_nc):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = norm_G.replace('spectral', '')
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        """Return ``shortcut(x) + residual(x)``, both conditioned on ``seg``."""
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))

        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        """Project ``x`` to ``fout`` channels when needed; otherwise pass it through."""
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
class BaseNetwork(nn.Module):
    """Common base for the networks here: parameter reporting and weight init."""

    # Accepted values for ``init_weights(init_type=...)``.
    _INIT_TYPES = ('normal', 'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', 'none')

    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for subclasses to extend ``parser``; the base returns it unchanged."""
        return parser

    def print_network(self):
        """Print the network class name and its parameter count in millions."""
        net = self
        # Tolerates an unbound call with a list of networks,
        # e.g. ``BaseNetwork.print_network([net])``; only the first is reported.
        if isinstance(net, list):
            net = net[0]
        num_params = sum(param.numel() for param in net.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(net).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialize the weights of every Conv*/Linear/BatchNorm2d submodule.

        Args:
            init_type: one of 'normal', 'xavier', 'xavier_uniform', 'kaiming',
                'orthogonal', or 'none' (the module's own
                ``reset_parameters()``, i.e. PyTorch's default init).
            gain: std for 'normal'; gain for 'xavier' and 'orthogonal'; std
                around 1.0 for BatchNorm2d weights. 'xavier_uniform' always
                uses gain 1.0 and 'kaiming' ignores it.

        Raises:
            NotImplementedError: if ``init_type`` is not recognized.
        """
        if init_type not in self._INIT_TYPES:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

        def init_func(m):
            classname = m.__class__.__name__
            # Substring match, so any class whose name contains 'BatchNorm2d'
            # (e.g. SynchronizedBatchNorm2d) is treated the same way.
            if 'BatchNorm2d' in classname:
                # affine=False norms have weight/bias set to None.
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                # Biases are zeroed for every init_type, including 'none'.
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        # Module.apply visits every submodule recursively, then self.
        self.apply(init_func)

        # propagate to children that define their own init_weights
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)
class SPADEGenerator(BaseNetwork):
    """SPADE generator: SPADE residual blocks interleaved with 2x upsampling.

    Args:
        z_dim: size of the latent vector used on the VAE path.
        semantic_nc: channels of the conditioning segmentation map.
        ngf: base channel width; block widths are ``dim_seq[i] * ngf``.
        dim_seq: channel multipliers for the head/middle blocks and
            up_0..up_3. Either a comma-separated string such as
            ``'16,8,4,2,1'`` or a sequence of ints; at least 5 entries.
        bev_grid_size: target output width. The initial grid width is
            ``bev_grid_size // 2 ** num_up_layers``; the output width is
            exactly ``bev_grid_size`` when it divides evenly.
        aspect_ratio: width / height of the initial grid.
        num_upsampling_layers: 'normal' (5 upsamplings), 'more' (6) or
            'most' (7; also adds the ``up_4`` block).
        not_use_vae: if true, start from a downsampled segmentation map;
            otherwise start from a latent vector.
        norm_G: normalization spec forwarded to every SPADEResnetBlock.
    """

    def __init__(self, z_dim, semantic_nc, ngf, dim_seq, bev_grid_size, aspect_ratio,
                 num_upsampling_layers, not_use_vae, norm_G):
        super().__init__()
        nf = ngf
        self.not_use_vae = not_use_vae
        self.z_dim = z_dim
        self.ngf = ngf
        if isinstance(dim_seq, str):
            dim_seq = dim_seq.split(',')
        self.dim_seq = [int(d) for d in dim_seq]
        if len(self.dim_seq) < 5:
            raise ValueError('dim_seq needs at least 5 entries, got %d' % len(self.dim_seq))
        self.num_upsampling_layers = num_upsampling_layers

        self.sw, self.sh = self.compute_latent_vector_size(num_upsampling_layers, bev_grid_size, aspect_ratio)

        if not not_use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(z_dim, self.dim_seq[0] * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(semantic_nc, self.dim_seq[0] * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc)

        self.G_middle_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc)
        self.G_middle_1 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc)

        self.up_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[1] * nf, norm_G, semantic_nc)
        self.up_1 = SPADEResnetBlock(self.dim_seq[1] * nf, self.dim_seq[2] * nf, norm_G, semantic_nc)
        self.up_2 = SPADEResnetBlock(self.dim_seq[2] * nf, self.dim_seq[3] * nf, norm_G, semantic_nc)
        self.up_3 = SPADEResnetBlock(self.dim_seq[3] * nf, self.dim_seq[4] * nf, norm_G, semantic_nc)

        final_nc = nf * self.dim_seq[4]

        if num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(self.dim_seq[4] * nf, nf // 2, norm_G, semantic_nc)
            final_nc = nf // 2

        # Output head produces 32 channels (not a 3-channel image).
        self.conv_img = nn.Conv2d(final_nc, 32, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, num_upsampling_layers, bev_grid_size, aspect_ratio):
        """Return ``(sw, sh)``, the width and height of the initial feature grid.

        Raises:
            ValueError: if ``num_upsampling_layers`` is not recognized.
        """
        if num_upsampling_layers == 'normal':
            num_up_layers = 5
        elif num_upsampling_layers == 'more':
            num_up_layers = 6
        elif num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('num_upsampling_layers [%s] not recognized' %
                             num_upsampling_layers)

        sw = bev_grid_size // (2**num_up_layers)
        sh = round(sw / aspect_ratio)
        return sw, sh

    def forward(self, input, z=None):
        """Generate features conditioned on the segmentation map ``input``.

        Args:
            input: 4-D segmentation map ``(N, semantic_nc, h, w)``. Every
                SPADE block resizes it to its own resolution.
            z: optional latent ``(N, z_dim)`` for the VAE path. It is sampled
                from N(0, 1) when omitted.

        Returns:
            Tensor ``(N, 32, H, W)``.
        """
        seg = input

        if not self.not_use_vae:
            # we sample z from unit normal and reshape the tensor
            if z is None:
                # input.device (not get_device(), which is -1 on CPU)
                z = torch.randn(input.size(0), self.z_dim,
                                dtype=torch.float32, device=input.device)
            x = self.fc(z)
            x = x.view(-1, self.dim_seq[0] * self.ngf, self.sh, self.sw)
        else:
            # we downsample segmap and run convolution
            x = F.interpolate(seg, size=(self.sh, self.sw))
            x = self.fc(x)

        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.G_middle_0(x, seg)

        if self.num_upsampling_layers == 'more' or \
                self.num_upsampling_layers == 'most':
            x = self.up(x)

        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)

        if self.num_upsampling_layers == 'most':
            x = self.up(x)
            x = self.up_4(x, seg)

        # Activation before the output conv. There is no final squashing
        # (tanh) here, so the output is unbounded.
        x = self.conv_img(F.leaky_relu(x, 2e-1))

        return x
if __name__ == '__main__':
    # Smoke test: build a generator from CLI options and run one forward pass.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--z_dim', type=int, default=10)
    parser.add_argument('--semantic_nc', type=int, default=10)
    parser.add_argument('--ngf', type=int, default=64)
    parser.add_argument('--dim_seq', type=str, default='16,8,4,2,1',
                        help='comma-separated channel multipliers (at least 5)')
    parser.add_argument('--bev_grid_size', type=int, default=512)
    parser.add_argument('--aspect_ratio', type=float, default=1.0)
    parser.add_argument('--num_upsampling_layers', type=str, default='more')
    parser.add_argument('--not_use_vae', action="store_true")
    parser.add_argument('--norm_G', type=str, default='spectralspadesyncbatch3x3', help='instance normalization or batch normalization')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # SPADEGenerator takes explicit arguments, not the argparse namespace.
    sg = SPADEGenerator(args.z_dim, args.semantic_nc, args.ngf, args.dim_seq,
                        args.bev_grid_size, args.aspect_ratio,
                        args.num_upsampling_layers, args.not_use_vae,
                        args.norm_G).to(device)
    sg.print_network()

    seg = torch.zeros([2, args.semantic_nc, 5, 5], device=device)
    with torch.no_grad():
        out = sg(seg)
    print(tuple(out.shape))