Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
def get_normalization(config, conditional=True):
    """Return the normalization layer class selected by the config.

    The class itself is returned, not an instance; the caller is expected
    to construct it.

    Args:
        config: Object exposing ``config.model.normalization``, the name of
            the normalization scheme.
        conditional: If true, select among the class-conditional layers
            (constructed with ``num_features, num_classes``); otherwise
            select among the unconditional ones.

    Returns:
        The matching layer class.  In the unconditional case a
        ``normalization`` of ``None`` yields ``None``.

    Raises:
        NotImplementedError: If the name is not recognised for the
            requested mode (this includes ``None`` when ``conditional`` is
            true).
    """
    norm = config.model.normalization

    # An if/elif chain (rather than a dict) keeps name lookup lazy: only the
    # class that is actually selected needs to be resolvable.
    if conditional:
        if norm == 'NoneNorm':
            return ConditionalNoneNorm2d
        elif norm == 'InstanceNorm++':
            return ConditionalInstanceNorm2dPlus
        elif norm == 'InstanceNorm':
            return ConditionalInstanceNorm2d
        elif norm == 'BatchNorm':
            return ConditionalBatchNorm2d
        elif norm == 'VarianceNorm':
            return ConditionalVarianceNorm2d
        else:
            raise NotImplementedError("{} does not exist!".format(norm))
    else:
        if norm == 'BatchNorm':
            return nn.BatchNorm2d
        elif norm == 'InstanceNorm':
            return nn.InstanceNorm2d
        elif norm == 'InstanceNorm++':
            return InstanceNorm2dPlus
        elif norm == 'VarianceNorm':
            return VarianceNorm2d
        elif norm == 'NoneNorm':
            return NoneNorm2d
        elif norm is None:
            return None
        else:
            raise NotImplementedError("{} does not exist!".format(norm))
class ConditionalBatchNorm2d(nn.Module):
    """Batch normalization followed by a per-class scale (and shift).

    The input is normalized with a non-affine ``nn.BatchNorm2d``; the scale
    ``gamma`` and optional shift ``beta`` are looked up from an embedding
    table indexed by ``y``.

    Args:
        num_features: Number of channels of the input.
        num_classes: Number of rows in the embedding table.
        bias: If true, a per-class shift is learned in addition to the
            scale.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        if self.bias:
            # One row per class: first half is the scale, second half the shift.
            self.embed = nn.Embedding(num_classes, num_features * 2)
            # Scale starts at U[0, 1) (uniform_() defaults), shift at 0.
            self.embed.weight.data[:, :num_features].uniform_()
            self.embed.weight.data[:, num_features:].zero_()
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        """Normalize ``x`` and apply the affine parameters selected by ``y``.

        Args:
            x: Input accepted by ``nn.BatchNorm2d`` with ``num_features``
                channels.
            y: Integer index tensor into the embedding table
                (presumably one class label per sample -- confirm with
                callers).

        Returns:
            Tensor with the same shape as ``x``.
        """
        out = self.bn(x)
        if self.bias:
            # Split along the embedding axis, consistent with the other
            # conditional layers in this file.
            gamma, beta = self.embed(y).chunk(2, dim=-1)
            out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1) * out
        return out
class ConditionalInstanceNorm2d(nn.Module):
    """Instance normalization followed by a per-class scale (and shift).

    The input is normalized with a non-affine ``nn.InstanceNorm2d`` that
    keeps no running statistics; the scale ``gamma`` and optional shift
    ``beta`` are looked up from an embedding table indexed by ``y``.

    Args:
        num_features: Number of channels of the input.
        num_classes: Number of rows in the embedding table.
        bias: If true, a per-class shift is learned in addition to the
            scale.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        if bias:
            # One row per class: first half is the scale, second half the shift.
            self.embed = nn.Embedding(num_classes, num_features * 2)
            # Scale starts at U[0, 1) (uniform_() defaults), shift at 0.
            self.embed.weight.data[:, :num_features].uniform_()
            self.embed.weight.data[:, num_features:].zero_()
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        """Normalize ``x`` and apply the affine parameters selected by ``y``.

        Args:
            x: Input accepted by ``nn.InstanceNorm2d`` with
                ``num_features`` channels.
            y: Integer index tensor into the embedding table
                (presumably one class label per sample -- confirm with
                callers).

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.instance_norm(x)
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=-1)
            out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1) * h
        return out
class ConditionalVarianceNorm2d(nn.Module):
    """Divide each channel by its spatial std, then apply a per-class scale.

    Unlike instance normalization the mean is not subtracted: each
    ``(sample, channel)`` slice is only divided by the square root of its
    variance over dims 2 and 3 (plus ``1e-5``).

    Args:
        num_features: Number of channels of the input.
        num_classes: Number of rows in the embedding table.
        bias: Stored on the module but not used by this layer; no shift is
            ever applied.
    """

    def __init__(self, num_features, num_classes, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.embed = nn.Embedding(num_classes, num_features)
        # Per-class scale starts at N(1, 0.02).
        self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        """Rescale ``x`` to unit spatial variance and scale it by class.

        Args:
            x: 4-D input whose dim 1 has ``num_features`` channels and
                whose dims 2 and 3 are reduced over.
            y: Integer index tensor into the embedding table
                (presumably one class label per sample -- confirm with
                callers).

        Returns:
            Tensor with the same shape as ``x``.
        """
        variance = torch.var(x, dim=(2, 3), keepdim=True)
        h = x / torch.sqrt(variance + 1e-5)
        gamma = self.embed(y)
        out = gamma.view(-1, self.num_features, 1, 1) * h
        return out
class VarianceNorm2d(nn.Module):
    """Divide each channel by its spatial std, then apply a learned scale.

    The mean is not subtracted: each ``(sample, channel)`` slice is only
    divided by the square root of its variance over dims 2 and 3 (plus
    ``1e-5``) and multiplied by the per-channel parameter ``alpha``.

    Args:
        num_features: Number of channels of the input.
        bias: Stored on the module but not used by this layer; no shift is
            ever applied.
    """

    def __init__(self, num_features, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.alpha = nn.Parameter(torch.zeros(num_features))
        # Per-channel scale starts at N(1, 0.02).
        self.alpha.data.normal_(1, 0.02)

    def forward(self, x):
        """Rescale ``x`` to unit spatial variance and scale it by ``alpha``.

        Args:
            x: 4-D input whose dim 1 has ``num_features`` channels and
                whose dims 2 and 3 are reduced over.

        Returns:
            Tensor with the same shape as ``x``.
        """
        variance = torch.var(x, dim=(2, 3), keepdim=True)
        h = x / torch.sqrt(variance + 1e-5)
        out = self.alpha.view(-1, self.num_features, 1, 1) * h
        return out
class ConditionalNoneNorm2d(nn.Module):
    """Per-class scale (and shift) with no normalization of the input.

    The scale ``gamma`` and optional shift ``beta`` are looked up from an
    embedding table indexed by ``y`` and applied directly to ``x``.

    Args:
        num_features: Number of channels of the input.
        num_classes: Number of rows in the embedding table.
        bias: If true, a per-class shift is learned in addition to the
            scale.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        if bias:
            # One row per class: first half is the scale, second half the shift.
            self.embed = nn.Embedding(num_classes, num_features * 2)
            # Scale starts at U[0, 1) (uniform_() defaults), shift at 0.
            self.embed.weight.data[:, :num_features].uniform_()
            self.embed.weight.data[:, num_features:].zero_()
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        """Apply the affine parameters selected by ``y`` to ``x``.

        Args:
            x: Input that broadcasts against a
                ``(-1, num_features, 1, 1)`` scale.
            y: Integer index tensor into the embedding table
                (presumably one class label per sample -- confirm with
                callers).

        Returns:
            The scaled (and, with ``bias``, shifted) tensor.
        """
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=-1)
            out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1) * x
        return out
class NoneNorm2d(nn.Module):
    """Identity layer with the constructor signature of a normalizer.

    Args:
        num_features: Accepted for signature compatibility; unused.
        bias: Accepted for signature compatibility; unused.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged (the very same object)."""
        return x
class InstanceNorm2dPlus(nn.Module):
    """Instance normalization that re-injects standardized channel means.

    Plain instance normalization discards each channel's spatial mean.
    This layer standardizes those means across the channel axis, scales
    them by a learned per-channel ``alpha`` and adds them back to the
    instance-normalized activations before the usual ``gamma`` scale and
    optional ``beta`` shift.

    Args:
        num_features: Number of channels of the input.
        bias: If true, a learned per-channel shift ``beta`` is added.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.zeros(num_features))
        # Both scales start at N(1, 0.02); the shift (if any) starts at 0.
        self.alpha.data.normal_(1, 0.02)
        self.gamma.data.normal_(1, 0.02)
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Apply the normalization to ``x``.

        Args:
            x: 4-D input whose dim 1 has ``num_features`` channels and
                whose dims 2 and 3 are reduced over.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Spatial mean per (sample, channel), standardized across channels.
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        if means.shape[-1] > 1:
            v = torch.var(means, dim=-1, keepdim=True)
        else:
            # torch.var applies Bessel's correction by default, so a single
            # channel would give NaN and poison the whole output.
            v = torch.zeros_like(m)
        means = (means - m) / (torch.sqrt(v + 1e-5))

        h = self.instance_norm(x)
        h = h + means[..., None, None] * self.alpha[..., None, None]
        out = self.gamma.view(-1, self.num_features, 1, 1) * h
        if self.bias:
            out = out + self.beta.view(-1, self.num_features, 1, 1)
        return out
class ConditionalInstanceNorm2dPlus(nn.Module):
    """Class-conditional instance norm that re-injects standardized means.

    Each channel's spatial mean is standardized across the channel axis,
    scaled by a per-class ``alpha`` and added back to the
    instance-normalized activations, which are then scaled by a per-class
    ``gamma`` and optionally shifted by a per-class ``beta``.  All three
    are looked up from one embedding table indexed by ``y``.

    Args:
        num_features: Number of channels of the input.
        num_classes: Number of rows in the embedding table.
        bias: If true, a per-class shift is learned as well.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        if bias:
            # Row layout: [gamma | alpha | beta], each num_features wide.
            self.embed = nn.Embedding(num_classes, num_features * 3)
            # gamma and alpha start at N(1, 0.02); beta starts at 0.
            self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)
            self.embed.weight.data[:, 2 * num_features:].zero_()
        else:
            # Row layout: [gamma | alpha].
            self.embed = nn.Embedding(num_classes, 2 * num_features)
            self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        """Apply the normalization to ``x`` with parameters selected by ``y``.

        Args:
            x: 4-D input whose dim 1 has ``num_features`` channels and
                whose dims 2 and 3 are reduced over.
            y: Integer index tensor into the embedding table
                (presumably one class label per sample -- confirm with
                callers).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Spatial mean per (sample, channel), standardized across channels.
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        if means.shape[-1] > 1:
            v = torch.var(means, dim=-1, keepdim=True)
        else:
            # torch.var applies Bessel's correction by default, so a single
            # channel would give NaN and poison the whole output.
            v = torch.zeros_like(m)
        means = (means - m) / (torch.sqrt(v + 1e-5))

        h = self.instance_norm(x)
        if self.bias:
            gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
        else:
            gamma, alpha = self.embed(y).chunk(2, dim=-1)
            beta = None

        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h
        if beta is not None:
            out = out + beta.view(-1, self.num_features, 1, 1)
        return out