import re

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm

from models.utils.batchnorm import SynchronizedBatchNorm2d


class SPADE(nn.Module):
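    """Spatially-adaptive normalization (SPADE), from Park et al.,
    "Semantic Image Synthesis with Spatially-Adaptive Normalization" (CVPR 2019).

    Normalizes the input with a parameter-free norm, then modulates it with a
    per-pixel scale (gamma) and bias (beta) predicted from the segmentation map.
    """
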
    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        # Parse a config string such as 'spadesyncbatch3x3' into the
        # parameter-free norm type and the conv kernel size.
        assert config_text.startswith('spade')
        parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'syncbatch':
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # Dimension of the intermediate embedding space (hardcoded).
        nhidden = 128

        # Shared embedding of the segmentation map, followed by two heads that
        # predict the per-pixel scale (gamma) and bias (beta).
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        # Part 1. Generate parameter-free normalized activations.
        normalized = self.param_free_norm(x)

        # Part 2. Produce scale and bias conditioned on the semantic map,
        # resized to the spatial size of x.
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # Apply the predicted scale and bias.
        out = normalized * (1 + gamma) + beta
        return out


class SPADEResnetBlock(nn.Module):
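    """ResNet block that uses SPADE. It differs from a plain ResNet block in
    that it takes the segmentation map as input and applies SPADE normalization
    before each convolution. A learned 1x1 shortcut is used when the number of
    input and output channels differ.
    """
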
    def __init__(self, fin, fout, norm_G, semantic_nc):
        super().__init__()

        # Use a learned shortcut only when channel counts change.
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # Create the conv layers.
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Apply spectral norm if specified, e.g. norm_G = 'spectralspadesyncbatch3x3'.
        if 'spectral' in norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # Define the SPADE normalization layers.
        spade_config_str = norm_G.replace('spectral', '')
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    # Note: the ResNet block with SPADE also takes the semantic segmentation
    # map (seg) as input.
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))

        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


class BaseNetwork(nn.Module):
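    """Base class providing parameter counting and weight-initialization
    helpers shared by the networks below.
    """
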
    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
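        """Initialize Conv/Linear weights with the chosen scheme ('normal',
        'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', or 'none' to keep
        the layer's own reset_parameters), and BatchNorm weights with N(1, gain).
        """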
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # Propagate to children that define their own initialization.
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)


class SPADEGenerator(BaseNetwork):
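    """SPADE generator. Starting either from a random z vector (VAE mode) or
    from a downsampled segmentation map (deterministic mode), it upsamples
    through a sequence of SPADE ResNet blocks conditioned on the segmentation
    map and outputs a 32-channel feature map.

    dim_seq is a comma-separated string of five channel multipliers, e.g.
    '16,8,4,2,1' (an illustrative value, not a prescribed default).
    """
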
    def __init__(self, z_dim, semantic_nc, ngf, dim_seq, bev_grid_size, aspect_ratio,
                 num_upsampling_layers, not_use_vae, norm_G):
        super().__init__()
        nf = ngf
        self.not_use_vae = not_use_vae
        self.z_dim = z_dim
        self.ngf = ngf
        self.dim_seq = list(map(int, dim_seq.split(',')))
        self.num_upsampling_layers = num_upsampling_layers

        self.sw, self.sh = self.compute_latent_vector_size(num_upsampling_layers, bev_grid_size, aspect_ratio)

        if not not_use_vae:
            # In case of VAE, sample from a random z vector.
            self.fc = nn.Linear(z_dim, self.dim_seq[0] * nf * self.sw * self.sh)
        else:
            # Otherwise, the network is deterministic and starts from a
            # downsampled segmentation map instead of a random z.
            self.fc = nn.Conv2d(semantic_nc, self.dim_seq[0] * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc)

        self.G_middle_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc)
        self.G_middle_1 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[0] * nf, norm_G, semantic_nc)

        self.up_0 = SPADEResnetBlock(self.dim_seq[0] * nf, self.dim_seq[1] * nf, norm_G, semantic_nc)
        self.up_1 = SPADEResnetBlock(self.dim_seq[1] * nf, self.dim_seq[2] * nf, norm_G, semantic_nc)
        self.up_2 = SPADEResnetBlock(self.dim_seq[2] * nf, self.dim_seq[3] * nf, norm_G, semantic_nc)
        self.up_3 = SPADEResnetBlock(self.dim_seq[3] * nf, self.dim_seq[4] * nf, norm_G, semantic_nc)

        final_nc = nf * self.dim_seq[4]

        if num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(self.dim_seq[4] * nf, nf // 2, norm_G, semantic_nc)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 32, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, num_upsampling_layers, bev_grid_size, aspect_ratio):
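        """Compute the width sw and height sh of the initial latent feature
        map such that the chosen number of 2x upsamplings recovers
        bev_grid_size; sh follows from the aspect ratio.
        """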
        if num_upsampling_layers == 'normal':
            num_up_layers = 5
        elif num_upsampling_layers == 'more':
            num_up_layers = 6
        elif num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('num_upsampling_layers [%s] not recognized' %
                             num_upsampling_layers)

        sw = bev_grid_size // (2**num_up_layers)
        sh = round(sw / aspect_ratio)

        return sw, sh

    def forward(self, input, z=None):
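        """Map a segmentation map `input` (and an optional latent `z` in VAE
        mode) to a 32-channel feature map at bev_grid_size resolution.
        """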
        seg = input

        if not self.not_use_vae:
            # Sample a random z vector when none is provided.
            if z is None:
                z = torch.randn(input.size(0), self.z_dim,
                                dtype=torch.float32, device=input.device)
            x = self.fc(z)
            x = x.view(-1, self.dim_seq[0] * self.ngf, self.sh, self.sw)
        else:
            # Downsample the segmentation map and project it to the initial
            # feature map.
            x = F.interpolate(seg, size=(self.sh, self.sw))
            x = self.fc(x)

        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.G_middle_0(x, seg)

        if self.num_upsampling_layers in ('more', 'most'):
            x = self.up(x)

        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)

        if self.num_upsampling_layers == 'most':
            x = self.up(x)
            x = self.up_4(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))

        return x


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--z_dim', type=int, default=10)
    parser.add_argument('--semantic_nc', type=int, default=10)
    parser.add_argument('--ngf', type=int, default=64)
    # dim_seq was missing from the original argument list; '16,8,4,2,1' is an
    # assumed placeholder default (five multipliers are required).
    parser.add_argument('--dim_seq', type=str, default='16,8,4,2,1')
    parser.add_argument('--bev_grid_size', type=int, default=512)
    parser.add_argument('--aspect_ratio', type=float, default=1.0)
    parser.add_argument('--num_upsampling_layers', type=str, default='more')
    parser.add_argument('--not_use_vae', action='store_true')
    parser.add_argument('--norm_G', type=str, default='spectralspadesyncbatch3x3',
                        help='instance, batch, or synchronized batch normalization')

    args = parser.parse_args()
    # SPADEGenerator takes individual arguments, not the argparse namespace.
    sg = SPADEGenerator(args.z_dim, args.semantic_nc, args.ngf, args.dim_seq,
                        args.bev_grid_size, args.aspect_ratio,
                        args.num_upsampling_layers, args.not_use_vae,
                        args.norm_G).cuda()

    # Smoke test: a dummy segmentation map with semantic_nc channels.
    seg = torch.zeros([2, 10, 5, 5]).cuda()
    out = sg(seg)
    print(out.shape)