# Source path: ic_gan/BigGAN_PyTorch/BigGAN.py
# Repository page header: ArantxaCasanova - "First model version" - a00ee36
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
import numpy as np
import math
import functools
import os
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
# from torch.nn import Parameter as P
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import BigGAN_PyTorch.layers as layers
# from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
from BigGAN_PyTorch.diffaugment_utils import DiffAugment
# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def G_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    """Build the generator architecture table, keyed by output resolution.

    Args:
        ch: Base channel width; every channel count is a multiple of it.
        attention: Underscore-separated resolutions at which an attention
            block is requested, e.g. ``"32_64"`` or just ``"64"``.
        ksize: Accepted for interface compatibility; not read here.
        dilation: Accepted for interface compatibility; not read here.

    Returns:
        A dict mapping each supported output resolution (512, 256, 128, 64,
        32) to a dict with the keys ``"in_channels"``, ``"out_channels"``,
        ``"upsample"``, ``"resolution"`` (one entry per block) and
        ``"attention"`` (resolution -> bool).

    Raises:
        ValueError: If an entry of ``attention`` is not an integer literal.
    """
    # Parse the attention spec once instead of once per resolution key.
    attn_resolutions = {int(item) for item in attention.split("_")}

    # output resolution -> (in-channel multipliers, out-channel multipliers)
    layouts = {
        512: ([16, 16, 8, 8, 4, 2, 1], [16, 8, 8, 4, 2, 1, 1]),
        256: ([16, 16, 8, 8, 4, 2], [16, 8, 8, 4, 2, 1]),
        128: ([16, 16, 8, 4, 2], [16, 8, 4, 2, 1]),
        64: ([16, 16, 8, 4], [16, 8, 4, 2]),
        32: ([4, 4, 4], [4, 4, 4]),
    }

    arch = {}
    for resolution, (in_mult, out_mult) in layouts.items():
        n_blocks = len(in_mult)
        # Block resolutions start at 8 and double with every block.
        block_resolutions = [2 ** i for i in range(3, 3 + n_blocks)]
        arch[resolution] = {
            "in_channels": [ch * item for item in in_mult],
            "out_channels": [ch * item for item in out_mult],
            "upsample": [True] * n_blocks,
            "resolution": block_resolutions,
            "attention": {
                res: (res in attn_resolutions) for res in block_resolutions
            },
        }
    return arch
class Generator(nn.Module):
    """BigGAN-style generator with optional class and instance conditioning.

    A latent vector is projected by a linear layer to a
    ``bottom_width x bottom_width`` feature map, run through a stack of
    ``layers.GBlock`` modules (each optionally followed by
    ``layers.Attention``) and turned into a 3-channel output by a
    norm -> activation -> conv head followed by ``tanh``.

    Class labels are embedded by ``self.shared`` and instance features are
    projected by ``self.shared_feat``; their concatenation is handed to every
    block as the conditioning input.
    """

    def __init__(
        self,
        G_ch=64,
        dim_z=128,
        bottom_width=4,
        resolution=128,
        G_kernel_size=3,
        G_attn="64",
        n_classes=1000,
        num_G_SVs=1,
        num_G_SV_itrs=1,
        G_shared=True,
        shared_dim=0,
        hier=False,
        cross_replica=False,
        mybn=False,
        G_activation=nn.ReLU(inplace=False),
        G_lr=5e-5,
        G_B1=0.0,
        G_B2=0.999,
        adam_eps=1e-8,
        BN_eps=1e-5,
        SN_eps=1e-12,
        G_mixed_precision=False,
        G_fp16=False,
        G_init="ortho",
        skip_init=False,
        no_optim=False,
        G_param="SN",
        norm_style="bn",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        G_shared_feat=True,
        shared_dim_feat=2048,
        **kwargs
    ):
        """Build the generator and, optionally, its embedded Adam optimizer.

        Args:
            G_ch: Channel width multiplier passed to ``G_arch``.
            dim_z: Latent dimensionality (rounded down to a multiple of the
                number of slots when ``hier`` is set).
            bottom_width: Spatial size of the first feature map.
            resolution: Output resolution; selects the entry of ``G_arch``.
            G_kernel_size: Stored as ``self.kernel_size``; the convolutions
                built here always use ``kernel_size=3``.
            G_attn: Attention spec forwarded to ``G_arch``.
            n_classes: Number of classes for the class embedding.
            num_G_SVs, num_G_SV_itrs: Forwarded to the spectral-norm layers
                as ``num_svs`` / ``num_itrs``.
            G_shared: Use an embedding layer for class labels; otherwise
                ``self.shared`` is ``layers.identity()``.
            shared_dim: Width of the class embedding; ``dim_z`` when <= 0.
            hier: Split ``z`` into equal chunks, one for the first linear
                layer and one appended to the conditioning of each block.
            cross_replica, mybn: Forwarded to ``layers.ccbn``/``layers.bn``.
            G_activation: Activation module used in blocks and output head.
            G_lr, G_B1, G_B2, adam_eps: Adam hyper-parameters.
            BN_eps: ``eps`` forwarded to ``layers.ccbn``.
            SN_eps: ``eps`` forwarded to the spectral-norm layers.
            G_mixed_precision: Use ``utils.Adam16`` instead of ``optim.Adam``.
            G_fp16: Stored as ``self.fp16``; not otherwise read here.
            G_init: Weight init style: "ortho", "N02", "glorot" or "xavier".
            skip_init: Skip ``init_weights``.
            no_optim: Do not create an optimizer.
            G_param: "SN" selects spectral-norm conv/linear layers; any other
                value selects plain ``nn.Conv2d`` / ``nn.Linear``.
            norm_style: Forwarded to ``layers.ccbn``.
            class_cond: Whether class conditioning contributes to the
                conditional-norm input size.
            embedded_optimizer: Create ``self.optim`` inside the module.
            instance_cond: Whether instance features contribute to the
                conditional-norm input size.
            G_shared_feat: Project instance features with a linear layer;
                otherwise ``self.shared_feat`` is ``layers.identity()``.
            shared_dim_feat: Output width of the instance-feature projection.
            **kwargs: Ignored; lets callers pass a shared config dict.
        """
        super().__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size (stored only; convs below hard-code kernel_size=3)
        self.kernel_size = G_kernel_size
        # Attention spec
        self.attention = G_attn
        # Number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding; unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Use embeddings for instance features?
        self.G_shared_feat = G_shared_feat
        self.shared_dim_feat = shared_dim_feat
        # Architecture dict for the requested output resolution
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into: the first linear layer + 1/block
            self.num_slots = len(self.arch["in_channels"]) + 1
            self.z_chunk_size = self.dim_z // self.num_slots
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # for some reason applying SN to G's embedding seems to randomly
        # cripple G.
        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )

        # Width of the conditioning vector fed to each conditional norm layer.
        if not class_cond and not instance_cond:
            input_sz_bn = self.n_classes
        else:
            input_sz_bn = self.z_chunk_size
            if class_cond:
                input_sz_bn += self.shared_dim
            if instance_cond:
                input_sz_bn += self.shared_dim_feat
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=input_sz_bn,
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim)
            if G_shared
            else layers.identity()
        )
        # NOTE(review): the incoming instance-feature width is fixed at 2048.
        self.shared_feat = (
            self.which_linear(2048, self.shared_dim_feat)
            if G_shared_feat
            else layers.identity()
        )
        # First linear layer
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch["in_channels"][0] * (self.bottom_width ** 2),
        )

        # block_groups is a doubly-nested list of modules: the outer level is
        # one entry per resolution, the inner level holds the resblock and,
        # when requested, the attention layer that follows it.
        block_groups = []
        stages = zip(
            self.arch["in_channels"],
            self.arch["out_channels"],
            self.arch["upsample"],
            self.arch["resolution"],
        )
        for in_ch, out_ch, do_upsample, block_res in stages:
            group = [
                layers.GBlock(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    which_conv=self.which_conv,
                    which_bn=self.which_bn,
                    activation=self.activation,
                    upsample=(
                        functools.partial(F.interpolate, scale_factor=2)
                        if do_upsample
                        else None
                    ),
                )
            ]
            # If attention on this block, attach it to the end
            if self.arch["attention"][block_res]:
                print("Adding attention layer in G at resolution %d" % block_res)
                group.append(layers.Attention(out_ch, self.which_conv))
            block_groups.append(group)

        # Turn the nested list into ModuleLists so it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(group) for group in block_groups])

        # Output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            layers.bn(
                self.arch["out_channels"][-1],
                cross_replica=self.cross_replica,
                mybn=self.mybn,
            ),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 3),
        )

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim or not embedded_optimizer:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            import utils

            optimizer_cls = utils.Adam16
        else:
            optimizer_cls = optim.Adam
        self.optim = optimizer_cls(
            params=self.parameters(),
            lr=self.lr,
            betas=(self.B1, self.B2),
            weight_decay=0,
            eps=self.adam_eps,
        )
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    def init_weights(self):
        """Initialise conv/linear/embedding weights according to ``self.init``.

        Also stores the number of parameters of the visited modules in
        ``self.param_count`` and prints it.
        """
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == "ortho":
                init.orthogonal_(module.weight)
            elif self.init == "N02":
                init.normal_(module.weight, 0, 0.02)
            elif self.init in ["glorot", "xavier"]:
                init.xavier_uniform_(module.weight)
            else:
                print("Init style not recognized...")
            self.param_count += sum(p.data.nelement() for p in module.parameters())
        print("Param count for Gs initialized parameters: %d" % self.param_count)

    def get_condition_embeddings(self, cl=None, feat=None):
        """Return the conditioning vector for the given labels and features.

        Args:
            cl: Class labels, passed through ``self.shared`` when not None.
            feat: Instance features, passed through ``self.shared_feat`` when
                not None.

        Returns:
            The available embeddings concatenated along the last dimension,
            or an empty list when both inputs are None.
        """
        c_embed = []
        if cl is not None:
            c_embed.append(self.shared(cl))
        if feat is not None:
            c_embed.append(self.shared_feat(feat))
        if len(c_embed) > 0:
            c_embed = torch.cat(c_embed, dim=-1)
        return c_embed

    def forward(self, z, label=None, feats=None):
        """Generate a batch of outputs from latents and conditioning inputs.

        Unlike a forward that expects pre-embedded conditioning, this one
        takes the raw ``label`` and ``feats`` and embeds them itself through
        ``get_condition_embeddings``.

        Args:
            z: Latent batch; dimension 0 is the batch and dimension 1 is
                split into chunks when ``self.hier`` is set.
            label: Class labels or None.
            feats: Instance features or None.

        Returns:
            ``tanh`` of the output layer applied to the final feature map.
        """
        y = self.get_condition_embeddings(label, feats)
        # If hierarchical, concatenate zs and ys
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            # The first chunk feeds the linear layer, the rest go to the blocks.
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.blocks)

        # First linear layer
        h = self.linear(z)
        # Reshape to (batch, channels, bottom_width, bottom_width)
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, ys[index])

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))
# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    """Build the discriminator architecture table, keyed by input resolution.

    Args:
        ch: Base channel width; every channel count except the 3-channel
            input is a multiple of it.
        attention: Underscore-separated resolutions at which an attention
            block is requested, e.g. ``"32_64"`` or just ``"64"``.
        ksize: Accepted for interface compatibility; not read here.
        dilation: Accepted for interface compatibility; not read here.

    Returns:
        A dict mapping each supported resolution (256, 128, 64, 32) to a dict
        with the keys ``"in_channels"``, ``"out_channels"``,
        ``"downsample"``, ``"resolution"`` (one entry per block) and
        ``"attention"`` (resolution -> bool).

    Raises:
        ValueError: If an entry of ``attention`` is not an integer literal.
    """
    # Parse the attention spec once instead of once per resolution key.
    attn_resolutions = {int(item) for item in attention.split("_")}

    # input resolution -> (out-channel multipliers, per-block downsample
    # flags, per-block resolution tag used to look up the attention flag,
    # exclusive upper exponent of the attention keys)
    layouts = {
        256: (
            [1, 2, 4, 8, 8, 16, 16],
            [True] * 6 + [False],
            [128, 64, 32, 16, 8, 4, 4],
            8,
        ),
        128: (
            [1, 2, 4, 8, 16, 16],
            [True] * 5 + [False],
            [64, 32, 16, 8, 4, 4],
            8,
        ),
        64: (
            [1, 2, 4, 8, 16],
            [True] * 4 + [False],
            [32, 16, 8, 4, 4],
            7,
        ),
        32: (
            [4, 4, 4, 4],
            [True, True, False, False],
            [16, 16, 16, 16],
            6,
        ),
    }

    arch = {}
    for resolution, (out_mult, downsample, block_res, stop_exp) in layouts.items():
        out_channels = [item * ch for item in out_mult]
        arch[resolution] = {
            # Each block consumes the previous block's output; the first
            # block consumes the 3-channel input.
            "in_channels": [3] + out_channels[:-1],
            "out_channels": out_channels,
            "downsample": downsample,
            "resolution": block_res,
            "attention": {
                2 ** i: (2 ** i in attn_resolutions) for i in range(2, stop_exp)
            },
        }
    return arch
class Discriminator(nn.Module):
    """Projection discriminator with optional class/instance conditioning.

    The input goes through a stack of ``layers.DBlock`` modules (each
    optionally followed by ``layers.Attention``), is activated and summed
    over its spatial dimensions, and mapped to ``output_dim`` logits by a
    linear layer. When labels and/or instance features are supplied, the
    inner product between their embedding and the pooled features is added
    to the output.
    """

    def __init__(
        self,
        D_ch=64,
        D_wide=True,
        resolution=128,
        D_kernel_size=3,
        D_attn="64",
        n_classes=1000,
        num_D_SVs=1,
        num_D_SV_itrs=1,
        D_activation=nn.ReLU(inplace=False),
        D_lr=2e-4,
        D_B1=0.0,
        D_B2=0.999,
        adam_eps=1e-8,
        SN_eps=1e-12,
        output_dim=1,
        D_mixed_precision=False,
        D_fp16=False,
        D_init="ortho",
        skip_init=False,
        D_param="SN",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        instance_sz=2048,
        **kwargs
    ):
        """Build the discriminator and, optionally, its embedded optimizer.

        Args:
            D_ch: Channel width multiplier passed to ``D_arch``.
            D_wide: Forwarded to ``layers.DBlock`` as ``wide``.
            resolution: Input resolution; selects the entry of ``D_arch``.
            D_kernel_size: Stored as ``self.kernel_size``; the convolutions
                built here always use ``kernel_size=3``.
            D_attn: Attention spec forwarded to ``D_arch``.
            n_classes: Number of classes for the projection embedding.
            num_D_SVs, num_D_SV_itrs: Forwarded to the spectral-norm layers
                as ``num_svs`` / ``num_itrs``.
            D_activation: Activation module used in blocks and before pooling.
            D_lr, D_B1, D_B2, adam_eps: Adam hyper-parameters.
            SN_eps: ``eps`` forwarded to the spectral-norm layers.
            output_dim: Width of the unconditional linear output.
            D_mixed_precision: Use ``utils.Adam16`` instead of ``optim.Adam``.
            D_fp16: Stored as ``self.fp16``; not otherwise read here.
            D_init: Weight init style: "ortho", "N02", "glorot" or "xavier".
            skip_init: Skip ``init_weights``.
            D_param: Parameterization style; only "SN" is supported.
            class_cond: Create the class projection embedding ``self.embed``.
            embedded_optimizer: Create ``self.optim`` inside the module.
            instance_cond: Create the instance projection ``self.linear_feat``.
            instance_sz: Input width of ``self.linear_feat``.
            **kwargs: Ignored; lets callers pass a shared config dict.

        Raises:
            ValueError: If ``D_param`` is not "SN".
        """
        super().__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size (stored only; convs below hard-code kernel_size=3)
        self.kernel_size = D_kernel_size
        # Attention spec
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16

        # There is no option to turn off SN in D. Fail fast with a clear
        # message rather than hitting a missing ``which_conv`` attribute
        # later during construction.
        if self.D_param != "SN":
            raise ValueError(
                "Unsupported D_param %r: the discriminator only implements 'SN'"
                % (self.D_param,)
            )

        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, linear layers and embeddings to use
        self.which_conv = functools.partial(
            layers.SNConv2d,
            kernel_size=3,
            padding=1,
            num_svs=num_D_SVs,
            num_itrs=num_D_SV_itrs,
            eps=self.SN_eps,
        )
        self.which_linear = functools.partial(
            layers.SNLinear,
            num_svs=num_D_SVs,
            num_itrs=num_D_SV_itrs,
            eps=self.SN_eps,
        )
        self.which_embedding = functools.partial(
            layers.SNEmbedding,
            num_svs=num_D_SVs,
            num_itrs=num_D_SV_itrs,
            eps=self.SN_eps,
        )

        # Prepare model
        # block_groups is a doubly-nested list of modules: the outer level is
        # one entry per architecture stage, the inner level holds the
        # resblock and, when requested, the attention layer that follows it.
        block_groups = []
        stages = zip(
            self.arch["in_channels"],
            self.arch["out_channels"],
            self.arch["downsample"],
            self.arch["resolution"],
        )
        for index, (in_ch, out_ch, do_downsample, block_res) in enumerate(stages):
            group = [
                layers.DBlock(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    which_conv=self.which_conv,
                    wide=self.D_wide,
                    activation=self.activation,
                    # Every block except the first is pre-activated.
                    preactivation=(index > 0),
                    downsample=(nn.AvgPool2d(2) if do_downsample else None),
                )
            ]
            # If attention on this block, attach it to the end
            if self.arch["attention"][block_res]:
                print("Adding attention layer in D at resolution %d" % block_res)
                group.append(layers.Attention(out_ch, self.which_conv))
            block_groups.append(group)

        # Turn the nested list into ModuleLists so it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(group) for group in block_groups])

        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        final_ch = self.arch["out_channels"][-1]
        self.linear = self.which_linear(final_ch, output_dim)

        # Embeddings for projection discrimination. With both conditionings
        # each one gets half of the final channel width, because forward()
        # concatenates them before the product with the pooled features.
        if class_cond and instance_cond:
            self.linear_feat = self.which_linear(instance_sz, final_ch // 2)
            self.embed = self.which_embedding(self.n_classes, final_ch // 2)
        elif class_cond:
            self.embed = self.which_embedding(self.n_classes, final_ch)
        elif instance_cond:
            self.linear_feat = self.which_linear(instance_sz, final_ch)

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        if embedded_optimizer:
            self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
            if D_mixed_precision:
                print("Using fp16 adam in D...")
                import utils

                optimizer_cls = utils.Adam16
            else:
                optimizer_cls = optim.Adam
            self.optim = optimizer_cls(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    def init_weights(self):
        """Initialise conv/linear/embedding weights according to ``self.init``.

        Also stores the number of parameters of the visited modules in
        ``self.param_count`` and prints it.
        """
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == "ortho":
                init.orthogonal_(module.weight)
            elif self.init == "N02":
                init.normal_(module.weight, 0, 0.02)
            elif self.init in ["glorot", "xavier"]:
                init.xavier_uniform_(module.weight)
            else:
                print("Init style not recognized...")
            self.param_count += sum(p.data.nelement() for p in module.parameters())
        print("Param count for Ds initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None, feat=None):
        """Score a batch of inputs.

        Args:
            x: Input batch; after the blocks it is summed over dimensions 2
                and 3, so a 4-D feature map is expected.
            y: Class labels for ``self.embed`` or None.
            feat: Instance features for ``self.linear_feat`` or None.

        Returns:
            The unconditional linear output plus, when ``y`` and/or ``feat``
            are given, the projection term summed over dimension 1.
        """
        h = x
        # Loop over blocks
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)

        # Pick the projection vector for the available conditioning.
        if y is not None and feat is not None:
            # Condition on both class and instance features
            projection = torch.cat([self.embed(y), self.linear_feat(feat)], dim=-1)
        elif y is not None:
            # Condition on class only
            projection = self.embed(y)
        elif feat is not None:
            # Condition on instance features only
            projection = self.linear_feat(feat)
        else:
            return out

        # Project the final feature set onto the conditioning vector and add
        # the result to the evidence.
        return out + torch.sum(projection * h, 1, keepdim=True)
# Parallelized G_D to minimize cross-gpu communication
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
    """Run a generator and a discriminator inside a single module.

    Per the file's own note, wrapping both networks avoids gathering and
    re-broadcasting the generator outputs across GPUs.
    """

    def __init__(self, G, D, optimizer_G=None, optimizer_D=None):
        """Store the two networks and, optionally, their optimizers."""
        super().__init__()
        self.G = G
        self.D = D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D

    def forward(
        self,
        z,
        gy,
        feats_g=None,
        x=None,
        dy=None,
        feats=None,
        train_G=False,
        return_G_z=False,
        split_D=False,
        policy=False,
        DA=False,
    ):
        """Generate samples and score them (and optional real data) with D.

        Args:
            z: Latents for the generator.
            gy: Labels for the generated samples.
            feats_g: Instance features for the generated samples.
            x: Real data, or None to score generated samples only.
            dy: Labels of the real data.
            feats: Instance features of the real data.
            train_G: Enable gradients while running the generator.
            return_G_z: Also return the generated samples; only honoured when
                ``x`` is None.
            split_D: Run D separately on fake and real data instead of on
                their concatenation along the batch dimension.
            policy: Forwarded to ``DiffAugment`` when ``DA`` is set.
            DA: Apply ``DiffAugment`` to D's input (concatenated path only).

        Returns:
            ``(D_fake, D_real)`` when ``x`` is given; otherwise ``D_fake``, or
            ``(D_fake, G_z)`` when ``return_G_z`` is set.
        """
        # The generator only records gradients when it is being trained.
        with torch.set_grad_enabled(train_G):
            fake = self.G(z, gy, feats_g)
            # Cast as necessary
            # if self.G.fp16 and not self.D.fp16:
            #     fake = fake.float()
            # if self.D.fp16 and not self.G.fp16:
            #     fake = fake.half()

        has_real = x is not None

        if split_D:
            # Two separate discriminator passes: fake first, then real.
            fake_score = self.D(fake, gy, feats_g)
            if has_real:
                real_score = self.D(x, dy, feats)
                return fake_score, real_score
            if return_G_z:
                return fake_score, fake
            return fake_score

        # Single pass: stack fake and real along the batch dimension for
        # improved efficiency.
        d_input = fake if not has_real else torch.cat([fake, x], 0)
        d_class = gy if dy is None else torch.cat([gy, dy], 0)
        if feats_g is None:
            d_feats = None
        elif feats is None:
            d_feats = feats_g
        else:
            d_feats = torch.cat([feats_g, feats], 0)

        if DA:
            d_input = DiffAugment(d_input, policy=policy)

        scores = self.D(d_input, d_class, d_feats)
        if has_real:
            # Undo the batch concatenation: D_fake, D_real
            return torch.split(scores, [fake.shape[0], x.shape[0]])
        if return_G_z:
            return scores, fake
        return scores