import numpy as np
import math
import functools
import os
import sys

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F

sys.path.insert(1, os.path.join(sys.path[0], ".."))
import BigGAN_PyTorch.layers as layers
from BigGAN_PyTorch.diffaugment_utils import DiffAugment

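
# Architecture dictionaries for the Generator, keyed by output resolution.
# "attention" is an underscore-separated list of resolutions (e.g. "32_64")
# at which a self-attention block is inserted.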
def G_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[512] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
        "upsample": [True] * 7,
        "resolution": [8, 16, 32, 64, 128, 256, 512],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 10)
        },
    }
    arch[256] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
        "upsample": [True] * 6,
        "resolution": [8, 16, 32, 64, 128, 256],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 9)
        },
    }
    arch[128] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
        "upsample": [True] * 5,
        "resolution": [8, 16, 32, 64, 128],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 8)
        },
    }
    arch[64] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [8, 16, 32, 64],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 7)
        },
    }
    arch[32] = {
        "in_channels": [ch * item for item in [4, 4, 4]],
        "out_channels": [ch * item for item in [4, 4, 4]],
        "upsample": [True] * 3,
        "resolution": [8, 16, 32],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 6)
        },
    }
    return arch

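
# Generator: maps a latent vector z (plus optional class-label and/or
# instance-feature conditioning) to an image, following the BigGAN design.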
class Generator(nn.Module):
    def __init__(
        self,
        G_ch=64,
        dim_z=128,
        bottom_width=4,
        resolution=128,
        G_kernel_size=3,
        G_attn="64",
        n_classes=1000,
        num_G_SVs=1,
        num_G_SV_itrs=1,
        G_shared=True,
        shared_dim=0,
        hier=False,
        cross_replica=False,
        mybn=False,
        G_activation=nn.ReLU(inplace=False),
        G_lr=5e-5,
        G_B1=0.0,
        G_B2=0.999,
        adam_eps=1e-8,
        BN_eps=1e-5,
        SN_eps=1e-12,
        G_mixed_precision=False,
        G_fp16=False,
        G_init="ortho",
        skip_init=False,
        no_optim=False,
        G_param="SN",
        norm_style="bn",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        G_shared_feat=True,
        shared_dim_feat=2048,
        **kwargs
    ):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # Spatial size of the initial feature map
        self.bottom_width = bottom_width
        # Output resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = G_kernel_size
        # Resolutions at which to use self-attention
        self.attention = G_attn
        # Number of classes for class-conditional generation
        self.n_classes = n_classes
        # Use a shared class embedding?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding (defaults to dim_z)
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space (feed a chunk of z to every block)?
        self.hier = hier
        # Cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use custom ("my") batchnorm?
        self.mybn = mybn
        # Nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style (spectral norm or not)
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm
        self.BN_eps = BN_eps
        # Epsilon for spectral norm
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Use a shared projection for instance features?
        self.G_shared_feat = G_shared_feat
        self.shared_dim_feat = shared_dim_feat
        # Architecture dict for the requested resolution
        self.arch = G_arch(self.ch, self.attention)[resolution]

        if self.hier:
            # Number of places z slots into (one per block, plus the input layer)
            self.num_slots = len(self.arch["in_channels"]) + 1
            self.z_chunk_size = self.dim_z // self.num_slots
            # Recalculate latent dimensionality so z splits evenly into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

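        # Which conv and linear layers to use (spectral-norm variants when G_param == "SN")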
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

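        # Class-conditional batchnorm: its conditioning input is the per-block z chunk
        # plus the class embedding and/or the projected instance features.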
        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )
        if not class_cond and not instance_cond:
            input_sz_bn = self.n_classes
        else:
            input_sz_bn = self.z_chunk_size
            if class_cond:
                input_sz_bn += self.shared_dim
            if instance_cond:
                input_sz_bn += self.shared_dim_feat
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=input_sz_bn,
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )

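        # Prepare the conditioning modules: shared class embedding, shared
        # projection for instance features, and the first linear layer.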
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim)
            if G_shared
            else layers.identity()
        )
        self.shared_feat = (
            self.which_linear(2048, self.shared_dim_feat)
            if G_shared_feat
            else layers.identity()
        )
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch["in_channels"][0] * (self.bottom_width ** 2),
        )

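        # self.blocks is a doubly-nested list of modules: the outer list indexes
        # resolutions, the inner list holds a GBlock optionally followed by attention.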
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    layers.GBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=(
                            functools.partial(F.interpolate, scale_factor=2)
                            if self.arch["upsample"][index]
                            else None
                        ),
                    )
                ]
            ]

            # If attention is enabled at this resolution, append it to the block
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in G at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it is properly registered
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

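        # Output layer: batchnorm -> activation -> 3-channel conv (tanh is applied in forward)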
        self.output_layer = nn.Sequential(
            layers.bn(
                self.arch["out_channels"][-1],
                cross_replica=self.cross_replica,
                mybn=self.mybn,
            ),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 3),
        )

        if not skip_init:
            self.init_weights()

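        # Optionally set up an embedded Adam optimizer; skip when no_optim is set
        # or when an external optimizer is used instead.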
        if no_optim or not embedded_optimizer:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            import utils

            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )

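    # Initialize weights: orthogonal, N(0, 0.02), or Xavier, depending on G_init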
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def get_condition_embeddings(self, cl=None, feat=None):
        # Build the conditioning vector from the class embedding and/or the
        # projected instance features, concatenated along the last dimension.
        c_embed = []
        if cl is not None:
            c_embed.append(self.shared(cl))
        if feat is not None:
            c_embed.append(self.shared_feat(feat))
        if len(c_embed) > 0:
            c_embed = torch.cat(c_embed, dim=-1)
        return c_embed

    def forward(self, z, label=None, feats=None):
        y = self.get_condition_embeddings(label, feats)
        # If hierarchical, split z into chunks and concatenate each with y
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.blocks)

        # First linear layer
        h = self.linear(z)
        # Reshape to the initial (bottom_width x bottom_width) feature map
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Loop over blocks (inner loop in case a block has multiple layers)
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h, ys[index])

        # Apply the batchnorm-activation-conv output layer, then tanh
        return torch.tanh(self.output_layer(h))

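
# Architecture dictionaries for the Discriminator, keyed by input resolution;
# "attention" uses the same underscore-separated format as G_arch.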
def D_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[256] = {
        "in_channels": [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
        "out_channels": [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
        "downsample": [True] * 6 + [False],
        "resolution": [128, 64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[128] = {
        "in_channels": [3] + [ch * item for item in [1, 2, 4, 8, 16]],
        "out_channels": [item * ch for item in [1, 2, 4, 8, 16, 16]],
        "downsample": [True] * 5 + [False],
        "resolution": [64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[64] = {
        "in_channels": [3] + [ch * item for item in [1, 2, 4, 8]],
        "out_channels": [item * ch for item in [1, 2, 4, 8, 16]],
        "downsample": [True] * 4 + [False],
        "resolution": [32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 7)
        },
    }
    arch[32] = {
        "in_channels": [3] + [item * ch for item in [4, 4, 4]],
        "out_channels": [item * ch for item in [4, 4, 4, 4]],
        "downsample": [True, True, False, False],
        "resolution": [16, 16, 16, 16],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 6)
        },
    }
    return arch

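
# Discriminator with projection-style conditioning on class labels and/or
# instance features.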
class Discriminator(nn.Module):
    def __init__(
        self,
        D_ch=64,
        D_wide=True,
        resolution=128,
        D_kernel_size=3,
        D_attn="64",
        n_classes=1000,
        num_D_SVs=1,
        num_D_SV_itrs=1,
        D_activation=nn.ReLU(inplace=False),
        D_lr=2e-4,
        D_B1=0.0,
        D_B2=0.999,
        adam_eps=1e-8,
        SN_eps=1e-12,
        output_dim=1,
        D_mixed_precision=False,
        D_fp16=False,
        D_init="ortho",
        skip_init=False,
        D_param="SN",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        instance_sz=2048,
        **kwargs
    ):
        super(Discriminator, self).__init__()
        # Channel width multiplier
        self.ch = D_ch
        # Use a wide D as in BigGAN/SA-GAN, or a skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Input resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Resolutions at which to use self-attention
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for spectral norm
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = D_fp16
        # Architecture dict for the requested resolution
        self.arch = D_arch(self.ch, self.attention)[resolution]

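        # Which conv, linear, and embedding layers to use (spectral-norm variants;
        # there is currently no option to disable SN in D)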
        if self.D_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_embedding = functools.partial(
                layers.SNEmbedding,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )

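        # self.blocks is a doubly-nested list of modules: one DBlock per resolution,
        # each optionally followed by a self-attention layer.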
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    layers.DBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(
                            nn.AvgPool2d(2) if self.arch["downsample"][index] else None
                        ),
                    )
                ]
            ]
            # If attention is enabled at this resolution, append it to the block
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in D at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it is properly registered
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

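        # Final unconditional output layer, plus the projection heads used for
        # class-label and/or instance-feature conditioning.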
        self.linear = self.which_linear(self.arch["out_channels"][-1], output_dim)

        if class_cond and instance_cond:
            self.linear_feat = self.which_linear(
                instance_sz, self.arch["out_channels"][-1] // 2
            )
            self.embed = self.which_embedding(
                self.n_classes, self.arch["out_channels"][-1] // 2
            )
        elif class_cond:
            self.embed = self.which_embedding(
                self.n_classes, self.arch["out_channels"][-1]
            )
        elif instance_cond:
            self.linear_feat = self.which_linear(
                instance_sz, self.arch["out_channels"][-1]
            )

        if not skip_init:
            self.init_weights()

        if embedded_optimizer:
            self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
            if D_mixed_precision:
                print("Using fp16 adam in D...")
                import utils

                self.optim = utils.Adam16(
                    params=self.parameters(),
                    lr=self.lr,
                    betas=(self.B1, self.B2),
                    weight_decay=0,
                    eps=self.adam_eps,
                )
            else:
                self.optim = optim.Adam(
                    params=self.parameters(),
                    lr=self.lr,
                    betas=(self.B1, self.B2),
                    weight_decay=0,
                    eps=self.adam_eps,
                )

    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None, feat=None):
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Global sum pooling over the spatial dimensions
        h = torch.sum(self.activation(h), [2, 3])
        # Unconditional output
        out = self.linear(h)
        # Projection conditioning: add the inner product between the pooled
        # features and the class embedding and/or instance-feature projection
        if y is not None and feat is not None:
            out = out + torch.sum(
                torch.cat([self.embed(y), self.linear_feat(feat)], dim=-1) * h,
                1,
                keepdim=True,
            )
        elif y is not None:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        elif feat is not None:
            out = out + torch.sum(self.linear_feat(feat) * h, 1, keepdim=True)
        return out

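
# Wrapper that fuses G and D into a single module so that a training step can
# run both forward passes with one call (optionally with DiffAugment).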
class G_D(nn.Module):
    def __init__(self, G, D, optimizer_G=None, optimizer_D=None):
        super(G_D, self).__init__()
        self.G = G
        self.D = D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D

    def forward(
        self,
        z,
        gy,
        feats_g=None,
        x=None,
        dy=None,
        feats=None,
        train_G=False,
        return_G_z=False,
        split_D=False,
        policy=False,
        DA=False,
    ):
        # Only track gradients through G when training G
        with torch.set_grad_enabled(train_G):
            # Get Generator output given noise and conditioning
            G_z = self.G(z, gy, feats_g)

        # split_D: run D once on fake data and once on real data, rather than
        # concatenating them along the batch dimension.
        if split_D:
            D_fake = self.D(G_z, gy, feats_g)
            if x is not None:
                D_real = self.D(x, dy, feats)
                return D_fake, D_real
            else:
                if return_G_z:
                    return D_fake, G_z
                else:
                    return D_fake
        # Otherwise, concatenate fake and real (and their conditioning) along the
        # batch dimension and run D once.
        else:
            D_input = torch.cat([G_z, x], 0) if x is not None else G_z
            D_class = torch.cat([gy, dy], 0) if dy is not None else gy
            if feats_g is not None:
                D_feats = (
                    torch.cat([feats_g, feats], 0) if feats is not None else feats_g
                )
            else:
                D_feats = None
            # Optionally apply differentiable augmentation to D's input
            if DA:
                D_input = DiffAugment(D_input, policy=policy)

            D_out = self.D(D_input, D_class, D_feats)
            if x is not None:
                # Split the scores back into (fake, real)
                return torch.split(D_out, [G_z.shape[0], x.shape[0]])
            else:
                if return_G_z:
                    return D_out, G_z
                else:
                    return D_out