# sjc/ncsn/normalization.py
import torch
import torch.nn as nn


def get_normalization(config, conditional=True):
    """Return the normalization layer class named by config.model.normalization."""
    norm = config.model.normalization
if conditional:
if norm == 'NoneNorm':
return ConditionalNoneNorm2d
elif norm == 'InstanceNorm++':
return ConditionalInstanceNorm2dPlus
elif norm == 'InstanceNorm':
return ConditionalInstanceNorm2d
elif norm == 'BatchNorm':
return ConditionalBatchNorm2d
elif norm == 'VarianceNorm':
return ConditionalVarianceNorm2d
else:
            raise NotImplementedError(f"Unknown normalization: {norm}")
else:
if norm == 'BatchNorm':
return nn.BatchNorm2d
elif norm == 'InstanceNorm':
return nn.InstanceNorm2d
elif norm == 'InstanceNorm++':
return InstanceNorm2dPlus
elif norm == 'VarianceNorm':
return VarianceNorm2d
elif norm == 'NoneNorm':
return NoneNorm2d
elif norm is None:
return None
else:
            raise NotImplementedError(f"Unknown normalization: {norm}")
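
# A minimal usage sketch for get_normalization (hedged: the config object is
# built here with SimpleNamespace for illustration; in the repo it normally
# comes from a parsed config file):
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(model=SimpleNamespace(normalization='InstanceNorm++'))
#     Norm = get_normalization(config, conditional=False)       # InstanceNorm2dPlus
#     layer = Norm(num_features=64)
#     CondNorm = get_normalization(config, conditional=True)    # ConditionalInstanceNorm2dPlus
#     cond_layer = CondNorm(num_features=64, num_classes=10)
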
class ConditionalBatchNorm2d(nn.Module):
def __init__(self, num_features, num_classes, bias=True):
super().__init__()
self.num_features = num_features
self.bias = bias
self.bn = nn.BatchNorm2d(num_features, affine=False)
if self.bias:
self.embed = nn.Embedding(num_classes, num_features * 2)
            self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
else:
self.embed = nn.Embedding(num_classes, num_features)
self.embed.weight.data.uniform_()
def forward(self, x, y):
out = self.bn(x)
if self.bias:
gamma, beta = self.embed(y).chunk(2, dim=1)
out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
else:
gamma = self.embed(y)
out = gamma.view(-1, self.num_features, 1, 1) * out
return out
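
# Usage sketch for ConditionalBatchNorm2d (shapes are illustrative): the layer
# batch-normalizes x without affine parameters, then applies a gamma/beta pair
# looked up from the class label y, so each class gets its own affine transform.
#
#     cbn = ConditionalBatchNorm2d(num_features=64, num_classes=10)
#     x = torch.randn(8, 64, 32, 32)       # (N, C, H, W)
#     y = torch.randint(0, 10, (8,))       # one class label per sample
#     out = cbn(x, y)                      # same shape as x
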
class ConditionalInstanceNorm2d(nn.Module):
def __init__(self, num_features, num_classes, bias=True):
super().__init__()
self.num_features = num_features
self.bias = bias
self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
if bias:
self.embed = nn.Embedding(num_classes, num_features * 2)
            self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
else:
self.embed = nn.Embedding(num_classes, num_features)
self.embed.weight.data.uniform_()
def forward(self, x, y):
h = self.instance_norm(x)
if self.bias:
gamma, beta = self.embed(y).chunk(2, dim=-1)
out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
else:
gamma = self.embed(y)
out = gamma.view(-1, self.num_features, 1, 1) * h
return out
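
# ConditionalInstanceNorm2d differs from the batch-norm variant only in the
# statistics: each (sample, channel) plane is normalized independently and no
# running statistics are tracked, so the layer behaves identically in train
# and eval mode and works with a batch size of 1. The call signature is the
# same: out = layer(x, y).
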
class ConditionalVarianceNorm2d(nn.Module):
def __init__(self, num_features, num_classes, bias=False):
super().__init__()
self.num_features = num_features
        self.bias = bias  # stored for interface parity; no bias term is learned here
        self.embed = nn.Embedding(num_classes, num_features)
        self.embed.weight.data.normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
def forward(self, x, y):
        # Divide by the per-(sample, channel) std; the mean is left untouched
        variance = torch.var(x, dim=(2, 3), keepdim=True)
        h = x / torch.sqrt(variance + 1e-5)
gamma = self.embed(y)
out = gamma.view(-1, self.num_features, 1, 1) * h
return out
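
# ConditionalVarianceNorm2d rescales by the per-(sample, channel) standard
# deviation without subtracting the mean, then applies a class-conditional
# per-channel scale. Leaving the mean intact preserves the per-channel DC
# component that full instance normalization would remove.
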
class VarianceNorm2d(nn.Module):
def __init__(self, num_features, bias=False):
super().__init__()
self.num_features = num_features
        self.bias = bias  # stored for interface parity; unused in forward
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
def forward(self, x):
        variance = torch.var(x, dim=(2, 3), keepdim=True)
        h = x / torch.sqrt(variance + 1e-5)
out = self.alpha.view(-1, self.num_features, 1, 1) * h
return out
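
# Quick shape check for VarianceNorm2d (a sketch, not part of the original file):
#
#     vn = VarianceNorm2d(num_features=64)
#     out = vn(torch.randn(2, 64, 16, 16))   # output shape equals input shape
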
class ConditionalNoneNorm2d(nn.Module):
def __init__(self, num_features, num_classes, bias=True):
super().__init__()
self.num_features = num_features
self.bias = bias
if bias:
self.embed = nn.Embedding(num_classes, num_features * 2)
            self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
else:
self.embed = nn.Embedding(num_classes, num_features)
self.embed.weight.data.uniform_()
def forward(self, x, y):
if self.bias:
gamma, beta = self.embed(y).chunk(2, dim=-1)
out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
else:
gamma = self.embed(y)
out = gamma.view(-1, self.num_features, 1, 1) * x
return out
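
# ConditionalNoneNorm2d skips normalization entirely and only applies the
# class-conditional affine transform (a FiLM-style modulation), which lets the
# normalization itself be ablated while keeping the conditioning pathway.
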
class NoneNorm2d(nn.Module):
def __init__(self, num_features, bias=True):
super().__init__()
def forward(self, x):
return x
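
# Note: NoneNorm2d above is a deliberate identity; num_features and bias are
# accepted only so that get_normalization can instantiate every norm through
# the same constructor signature.
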
class InstanceNorm2dPlus(nn.Module):
def __init__(self, num_features, bias=True):
super().__init__()
self.num_features = num_features
self.bias = bias
self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
self.alpha = nn.Parameter(torch.zeros(num_features))
self.gamma = nn.Parameter(torch.zeros(num_features))
self.alpha.data.normal_(1, 0.02)
self.gamma.data.normal_(1, 0.02)
if bias:
self.beta = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
        # Standardize each sample's vector of per-channel means across channels
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
h = self.instance_norm(x)
if self.bias:
h = h + means[..., None, None] * self.alpha[..., None, None]
out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
else:
h = h + means[..., None, None] * self.alpha[..., None, None]
out = self.gamma.view(-1, self.num_features, 1, 1) * h
return out
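
# InstanceNorm2dPlus ("InstanceNorm++") addresses a known weakness of plain
# instance norm: normalizing each channel plane separately discards the
# relative differences between channel means. The layer therefore standardizes
# each sample's vector of channel means and re-injects it through the learned
# per-channel scale alpha:
#
#     inpp = InstanceNorm2dPlus(num_features=64)
#     out = inpp(torch.randn(8, 64, 32, 32))   # same shape as the input
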
class ConditionalInstanceNorm2dPlus(nn.Module):
def __init__(self, num_features, num_classes, bias=True):
super().__init__()
self.num_features = num_features
self.bias = bias
self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
if bias:
self.embed = nn.Embedding(num_classes, num_features * 3)
            self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise gamma and alpha at N(1, 0.02)
            self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise beta (bias) at 0
else:
self.embed = nn.Embedding(num_classes, 2 * num_features)
self.embed.weight.data.normal_(1, 0.02)
def forward(self, x, y):
        # Standardize each sample's vector of per-channel means across channels
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
h = self.instance_norm(x)
if self.bias:
gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
h = h + means[..., None, None] * alpha[..., None, None]
out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
else:
gamma, alpha = self.embed(y).chunk(2, dim=-1)
h = h + means[..., None, None] * alpha[..., None, None]
out = gamma.view(-1, self.num_features, 1, 1) * h
return out
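

if __name__ == "__main__":
    # Hedged smoke test (added for illustration, not part of the original
    # module): run every layer on a dummy batch and check that the output
    # shape matches the input shape.
    x = torch.randn(4, 8, 16, 16)
    y = torch.randint(0, 5, (4,))
    for cond_cls in (ConditionalBatchNorm2d, ConditionalInstanceNorm2d,
                     ConditionalVarianceNorm2d, ConditionalNoneNorm2d,
                     ConditionalInstanceNorm2dPlus):
        layer = cond_cls(num_features=8, num_classes=5)
        assert layer(x, y).shape == x.shape, cond_cls.__name__
    for plain_cls in (VarianceNorm2d, NoneNorm2d, InstanceNorm2dPlus):
        layer = plain_cls(num_features=8)
        assert layer(x).shape == x.shape, plain_cls.__name__
    print("All normalization layers preserve the input shape.")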