Spaces:
Runtime error
Runtime error
import re | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.utils.spectral_norm as spectral_norm | |
# Returns a function that creates a normalization function | |
# that does not condition on semantic map | |
def get_nonspade_norm_layer(norm_type='instance'): | |
# helper function to get # output channels of the previous layer | |
def get_out_channel(layer): | |
if hasattr(layer, 'out_channels'): | |
return getattr(layer, 'out_channels') | |
return layer.weight.size(0) | |
# this function will be returned | |
def add_norm_layer(layer): | |
nonlocal norm_type | |
if norm_type.startswith('spectral'): | |
layer = spectral_norm(layer) | |
subnorm_type = norm_type[len('spectral'):] | |
if subnorm_type == 'none' or len(subnorm_type) == 0: | |
return layer | |
# remove bias in the previous layer, which is meaningless | |
# since it has no effect after normalization | |
if getattr(layer, 'bias', None) is not None: | |
delattr(layer, 'bias') | |
layer.register_parameter('bias', None) | |
if subnorm_type == 'batch': | |
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) | |
elif subnorm_type == 'sync_batch': | |
norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) | |
elif subnorm_type == 'instance': | |
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) | |
else: | |
raise ValueError('normalization layer %s is not recognized' % subnorm_type) | |
return nn.Sequential(layer, norm_layer) | |
return add_norm_layer | |
# Creates SPADE normalization layer based on the given configuration | |
# SPADE consists of two steps. First, it normalizes the activations using | |
# your favorite normalization method, such as Batch Norm or Instance Norm. | |
# Second, it applies scale and bias to the normalized output, conditioned on | |
# the segmentation map. | |
# The format of |config_text| is spade(norm)(ks), where | |
# (norm) specifies the type of parameter-free normalization. | |
# (e.g. syncbatch, batch, instance) | |
# (ks) specifies the size of kernel in the SPADE module (e.g. 3x3) | |
# Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5. | |
# Also, the other arguments are | |
# |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE | |
# |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE | |
class SPADE(nn.Module): | |
def __init__(self, config_text, norm_nc, label_nc): | |
super().__init__() | |
assert config_text.startswith('spade') | |
parsed = re.search('spade(\D+)(\d)x\d', config_text) | |
param_free_norm_type = str(parsed.group(1)) | |
ks = int(parsed.group(2)) | |
if param_free_norm_type == 'instance': | |
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) | |
elif param_free_norm_type == 'syncbatch': | |
self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) | |
elif param_free_norm_type == 'batch': | |
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) | |
else: | |
raise ValueError('%s is not a recognized param-free norm type in SPADE' | |
% param_free_norm_type) | |
# The dimension of the intermediate embedding space. Yes, hardcoded. | |
nhidden = 128 | |
pw = ks // 2 | |
self.mlp_shared = nn.Sequential( | |
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), | |
nn.ReLU() | |
) | |
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) | |
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) | |
def forward(self, x, segmap): | |
# Part 1. generate parameter-free normalized activations | |
normalized = self.param_free_norm(x) | |
# Part 2. produce scaling and bias conditioned on semantic map | |
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') | |
actv = self.mlp_shared(segmap) | |
gamma = self.mlp_gamma(actv) | |
beta = self.mlp_beta(actv) | |
# apply scale and bias | |
out = normalized * (1 + gamma) + beta | |
return out |