""" Layers |
|
This file contains various layers for the BigGAN models. |
|
""" |
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

import sys

# Make the parent directory importable so the BigGAN_PyTorch package resolves.
sys.path.insert(1, os.path.join(sys.path[0], ".."))
from BigGAN_PyTorch.sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d


# Projection of x onto y
def proj(x, y):
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())


# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
    for y in ys:
        x = x - proj(x, y)
    return x


# Apply one step of the power method to estimate the top singular values of W.
def power_iteration(W, u_, update=True, eps=1e-12):
    # Lists holding singular vectors and values
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # Run one step of the power iteration
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of all other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            # Add to the list
            vs += [v]
            # Update the other singular vector
            u = torch.matmul(v, W.t())
            # Run Gram-Schmidt to subtract components of all other singular vectors
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            # Add to the list
            us += [u]
            if update:
                u_[i][:] = u
        # Compute this singular value and add it to the list
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
    return svs, us, vs
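

# Example (illustrative sketch, not part of the original module): estimate the
# largest singular value of a random matrix and compare it against PyTorch's
# SVD (torch.linalg.svdvals is available in recent PyTorch versions).
#
#   W = torch.randn(64, 128)
#   u = [torch.randn(1, 64)]                 # one left singular-vector estimate
#   for _ in range(50):
#       svs, us, vs = power_iteration(W, u, update=True)
#   print(svs[0].item(), torch.linalg.svdvals(W)[0].item())   # should be close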


# Convenience passthrough function
class identity(nn.Module):
    def forward(self, input):
        return input


# Spectral normalization base class
class SN(object):
    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # Number of power iterations per step
        self.num_itrs = num_itrs
        # Number of singular values
        self.num_svs = num_svs
        # Transposed?
        self.transpose = transpose
        # Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular-vector estimate and a singular value for each sv
        for i in range(self.num_svs):
            self.register_buffer("u%d" % i, torch.randn(1, num_outputs))
            self.register_buffer("sv%d" % i, torch.ones(1))

    # Singular vectors (u side)
    @property
    def u(self):
        return [getattr(self, "u%d" % i) for i in range(self.num_svs)]

    # Singular values; these buffers are just for logging and are not used in training.
    @property
    def sv(self):
        return [getattr(self, "sv%d" % i) for i in range(self.num_svs)]

    # Compute the spectrally-normalized weight
    def W_(self):
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(
                W_mat, self.u, update=self.training, eps=self.eps
            )
        # Update the singular-value buffers (inside no_grad to avoid memory leaks)
        if self.training:
            with torch.no_grad():
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]


# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        num_svs=1,
        num_itrs=1,
        eps=1e-12,
    ):
        nn.Conv2d.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        return F.conv2d(
            x,
            self.W_(),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    def __init__(
        self, in_features, out_features, bias=True, num_svs=1, num_itrs=1, eps=1e-12
    ):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)
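

# Example (illustrative sketch): SNConv2d / SNLinear are drop-in replacements
# for nn.Conv2d / nn.Linear; the spectrally normalized weight is recomputed by
# W_() on every forward pass.
#
#   conv = SNConv2d(3, 64, kernel_size=3, padding=1, num_svs=1, num_itrs=1)
#   fc = SNLinear(64, 1)
#   feats = conv(torch.randn(2, 3, 32, 32))   # (2, 64, 32, 32)
#   logits = fc(feats.mean([2, 3]))           # (2, 1)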


# Embedding layer with spectral norm.
# We use num_embeddings as the dim instead of embedding_dim here for convenience.
class SNEmbedding(nn.Embedding, SN):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        num_svs=1,
        num_itrs=1,
        eps=1e-12,
    ):
        nn.Embedding.__init__(
            self,
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())


# A non-local (self-attention) block as used in SA-GAN
class Attention(nn.Module):
    def __init__(self, ch, which_conv=SNConv2d, name="attention"):
        super(Attention, self).__init__()
        # Channel multiplier
        self.ch = ch
        self.which_conv = which_conv
        self.theta = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.phi = self.which_conv(
            self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
        )
        self.g = self.which_conv(
            self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
        )
        self.o = self.which_conv(
            self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
        )
        # Learnable gain parameter
        self.gamma = P(torch.tensor(0.0), requires_grad=True)

    def forward(self, x, y=None):
        # Apply convs
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(
            torch.bmm(g, beta.transpose(1, 2)).view(
                -1, self.ch // 2, x.shape[2], x.shape[3]
            )
        )
        return self.gamma * o + x
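

# Example (illustrative sketch): the block is resolution-agnostic, but ch must
# be divisible by 8 and the spatial dims by 2 (because phi and g are max-pooled).
#
#   attn = Attention(64)
#   out = attn(torch.randn(2, 64, 16, 16))    # same shape as the input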


# Fused batchnorm op: apply scale and shift; if gain and bias are provided,
# fuse them in here.
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If bias is provided, use it
    if bias is not None:
        shift = shift - bias
    return x * scale - shift


# Manual BN: calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared
    var = m2 - m ** 2
    # Cast back to the input dtype (e.g. float16) if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)
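

# Example (illustrative sketch): with no gain/bias, manual_bn matches
# F.batch_norm in training mode (no affine parameters) up to float precision.
#
#   x = torch.randn(8, 16, 4, 4)
#   ref = F.batch_norm(x, None, None, training=True, eps=1e-5)
#   out = manual_bn(x, eps=1e-5)
#   print(torch.allclose(ref, out, atol=1e-4))   # True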


# My batchnorm, supports standing stats
class myBN(nn.Module):
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # Momentum for updating running stats
        self.momentum = momentum
        # Epsilon to avoid dividing by 0
        self.eps = eps
        # Register buffers for the running statistics
        self.register_buffer("stored_mean", torch.zeros(num_channels))
        self.register_buffer("stored_var", torch.ones(num_channels))
        self.register_buffer("accumulation_counter", torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # Reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(
                x, gain, bias, return_mean_var=True, eps=self.eps
            )
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = (
                    self.stored_mean * (1 - self.momentum) + mean * self.momentum
                )
                self.stored_var[:] = (
                    self.stored_var * (1 - self.momentum) + var * self.momentum
                )
            return out
        # If not in training mode, use the stored statistics
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)


# Simple function to handle the groupnorm norm style
def groupnorm(x, norm_style):
    # If number of channels per group is specified in norm_style
    if "ch" in norm_style:
        ch = int(norm_style.split("_")[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups is specified in norm_style
    elif "grp" in norm_style:
        groups = int(norm_style.split("_")[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)
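

# Example (illustrative sketch): the norm_style string selects the grouping.
#
#   x = torch.randn(2, 64, 8, 8)
#   groupnorm(x, "ch_16")   # groups of 16 channels each -> 4 groups
#   groupnorm(x, "grp_8")   # exactly 8 groups
#   groupnorm(x, "other")   # neither pattern -> default of 16 groups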


# Class-conditional bn.
# output_size is the number of channels, input_size is the size of the
# conditioning vector fed to the linear gain/bias layers.
class ccbn(nn.Module):
    def __init__(
        self,
        output_size,
        input_size,
        which_linear,
        eps=1e-5,
        momentum=0.1,
        cross_replica=False,
        mybn=False,
        norm_style="bn",
    ):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        # Epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica:
            # Affine transform is handled by the class-conditional gain/bias below
            self.bn = nn.BatchNorm2d(
                output_size, eps=self.eps, momentum=self.momentum, affine=False
            )
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ["bn", "in"]:
            self.register_buffer("stored_mean", torch.zeros(output_size))
            self.register_buffer("stored_var", torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # The bn op normalizes; the conditional affine transform is applied on top
        if self.cross_replica:
            out = self.bn(x)
            out = out * gain + bias
            return out
        elif self.mybn:
            return self.bn(x, gain=gain, bias=bias)
        else:
            if self.norm_style == "bn":
                out = F.batch_norm(
                    x,
                    self.stored_mean,
                    self.stored_var,
                    None,
                    None,
                    self.training,
                    0.1,
                    self.eps,
                )
            elif self.norm_style == "in":
                out = F.instance_norm(
                    x,
                    self.stored_mean,
                    self.stored_var,
                    None,
                    None,
                    self.training,
                    0.1,
                    self.eps,
                )
            elif self.norm_style == "gn":
                out = groupnorm(x, self.norm_style)
            elif self.norm_style == "nonorm":
                out = x
            return out * gain + bias

    def extra_repr(self):
        s = "out: {output_size}, in: {input_size},"
        s += " cross_replica={cross_replica}"
        return s.format(**self.__dict__)
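

# Example (illustrative sketch): class-conditional BN driven by an embedded
# class vector; in the full model which_linear is typically a partially applied
# (SN)Linear or shared embedding.
#
#   import functools
#   cc = ccbn(output_size=64, input_size=128,
#             which_linear=functools.partial(SNLinear, bias=False))
#   x = torch.randn(4, 64, 8, 8)
#   y = torch.randn(4, 128)         # class / shared-embedding vector
#   out = cc(x, y)                  # (4, 64, 8, 8)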


# Normal, non-class-conditional BN
class bn(nn.Module):
    def __init__(
        self,
        output_size,
        eps=1e-5,
        momentum=0.1,
        cross_replica=False,
        mybn=False,
        **kwargs
    ):
        super(bn, self).__init__()
        self.output_size = output_size
        # Epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica:
            self.bn = nn.BatchNorm2d(
                output_size, eps=self.eps, momentum=self.momentum, affine=True
            )
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        else:
            self.register_buffer("stored_mean", torch.zeros(output_size))
            self.register_buffer("stored_var", torch.ones(output_size))

        # Prepare gain and bias parameters (the cross-replica bn is already affine)
        if not self.cross_replica:
            self.gain = P(torch.ones(output_size), requires_grad=True)
            self.bias = P(torch.zeros(output_size), requires_grad=True)

    def forward(self, x, y=None):
        if self.cross_replica:
            out = self.bn(x)
            return out
        elif self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        else:
            return F.batch_norm(
                x,
                self.stored_mean,
                self.stored_var,
                self.gain,
                self.bias,
                self.training,
                self.momentum,
                self.eps,
            )


# Generator residual block.
# Note that this class assumes the kernel size, padding, and any other conv
# settings have been selected in the main generator module and passed in
# through the which_conv arg; similar rules apply to which_bn.
class GBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        which_conv=nn.Conv2d,
        which_bn=bn,
        activation=None,
        upsample=None,
    ):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(
                in_channels, out_channels, kernel_size=1, padding=0
            )
        # Batchnorm layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)

    def forward(self, x, y):
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x
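

# Example (illustrative sketch): which_conv / which_bn are expected to arrive
# partially applied from the main generator; here plain 3x3 convs, the
# unconditional bn above, and nearest-neighbour upsampling are assumed.
#
#   import functools
#   block = GBlock(
#       64, 32,
#       which_conv=functools.partial(nn.Conv2d, kernel_size=3, padding=1),
#       which_bn=bn,
#       activation=nn.ReLU(inplace=False),
#       upsample=functools.partial(F.interpolate, scale_factor=2),
#   )
#   out = block(torch.randn(2, 64, 8, 8), y=None)   # (2, 32, 16, 16)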


# Residual block for the discriminator
class DBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        which_conv=SNConv2d,
        wide=True,
        preactivation=False,
        activation=None,
        downsample=None,
    ):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        self.learnable_sc = (
            True if (in_channels != out_channels) or downsample else False
        )
        if self.learnable_sc:
            self.conv_sc = self.which_conv(
                in_channels, out_channels, kernel_size=1, padding=0
            )

    def shortcut(self, x):
        if self.preactivation:
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
        else:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
        return x

    def forward(self, x):
        if self.preactivation:
            # Use an out-of-place ReLU here so the shortcut path still sees
            # the unmodified input.
            h = F.relu(x)
        else:
            h = x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)

        return h + self.shortcut(x)
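

# Example (illustrative sketch): a pre-activation discriminator block with
# average-pool downsampling, mirroring how the main model typically wires it up.
#
#   import functools
#   block = DBlock(
#       64, 128,
#       which_conv=functools.partial(SNConv2d, kernel_size=3, padding=1),
#       preactivation=True,
#       activation=nn.ReLU(inplace=False),
#       downsample=nn.AvgPool2d(2),
#   )
#   out = block(torch.randn(2, 64, 16, 16))   # (2, 128, 8, 8)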