PandA / networks /biggan /
james-oldfield's picture
Upload 194 files
history blame contribute delete
No virus
14.8 kB
# coding: utf-8
""" BigGAN PyTorch model.
From "Large Scale GAN Training for High Fidelity Natural Image Synthesis"
By Andrew Brock, Jeff Donahue and Karen Simonyan.
PyTorch version implemented from the computational graph of the TF Hub module for BigGAN.
Some parts of the code are adapted from another implementation (the source link is missing from this copy).
This version only comprises the generator (since the discriminator's weights are not released).
This version only comprises the "deep" version of BigGAN (see publication).
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import os
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import BigGANConfig
from .file_utils import cached_path
logger = logging.getLogger(__name__)

# Shortcut model name -> location of the pretrained weights file.
# TODO(review): the URL strings are empty in this copy of the file; fill them in
# before relying on from_pretrained() with a shortcut name.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128': "",
    'biggan-deep-256': "",
    'biggan-deep-512': "",
}

# Shortcut model name -> location of the matching JSON configuration file.
# TODO(review): same as above, the URL strings are empty in this copy.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128': "",
    'biggan-deep-256': "",
    'biggan-deep-512': "",
}

# File names looked up inside a local model directory.
WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'
def snconv2d(eps=1e-12, **kwargs):
    """Build an ``nn.Conv2d`` from ``kwargs`` and wrap it in spectral normalization.

    Args:
        eps: epsilon forwarded to ``nn.utils.spectral_norm``.
        **kwargs: forwarded unchanged to the ``nn.Conv2d`` constructor.

    Returns:
        The convolution module with spectral normalization applied to its weight.
    """
    conv = nn.Conv2d(**kwargs)
    return nn.utils.spectral_norm(conv, eps=eps)
def snlinear(eps=1e-12, **kwargs):
    """Build an ``nn.Linear`` from ``kwargs`` and wrap it in spectral normalization.

    Args:
        eps: epsilon forwarded to ``nn.utils.spectral_norm``.
        **kwargs: forwarded unchanged to the ``nn.Linear`` constructor.

    Returns:
        The linear module with spectral normalization applied to its weight.
    """
    linear = nn.Linear(**kwargs)
    return nn.utils.spectral_norm(linear, eps=eps)
def sn_embedding(eps=1e-12, **kwargs):
    """Build an ``nn.Embedding`` from ``kwargs`` and wrap it in spectral normalization.

    Args:
        eps: epsilon forwarded to ``nn.utils.spectral_norm``.
        **kwargs: forwarded unchanged to the ``nn.Embedding`` constructor.

    Returns:
        The embedding module with spectral normalization applied to its weight.
    """
    embedding = nn.Embedding(**kwargs)
    return nn.utils.spectral_norm(embedding, eps=eps)
class SelfAttn(nn.Module):
    """Self-attention layer built from spectrally normalized 1x1 convolutions.

    The ``phi`` and ``g`` paths are 2x2 max-pooled before the attention product,
    so the attention map relates every input position to a pooled position.
    ``gamma`` is initialised to zero, so a freshly constructed layer returns its
    input unchanged.
    """

    def __init__(self, in_channels, eps=1e-12):
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels

        def conv1x1(n_in, n_out):
            # Bias-free, spectrally normalized pointwise convolution.
            return snconv2d(in_channels=n_in, out_channels=n_out,
                            kernel_size=1, bias=False, eps=eps)

        # Construction order is kept as-is: it fixes parameter order and RNG use.
        self.snconv1x1_theta = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_phi = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_g = conv1x1(in_channels, in_channels // 2)
        self.snconv1x1_o_conv = conv1x1(in_channels // 2, in_channels)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``x`` plus the ``gamma``-scaled attention output (same shape as ``x``)."""
        _, channels, height, width = x.size()
        n_positions = height * width
        n_pooled = n_positions // 4  # 2x2 max-pool quarters the number of positions

        # Theta path: one vector per input position.
        queries = self.snconv1x1_theta(x).view(-1, channels // 8, n_positions)

        # Phi path: one vector per pooled position.
        keys = self.maxpool(self.snconv1x1_phi(x)).view(-1, channels // 8, n_pooled)

        # Attention map, normalised over the pooled positions.
        attn = self.softmax(torch.bmm(queries.permute(0, 2, 1), keys))

        # g path: values at the pooled positions.
        values = self.maxpool(self.snconv1x1_g(x)).view(-1, channels // 2, n_pooled)

        # Weighted sum of values, reshaped back to a feature map and projected.
        attended = torch.bmm(values, attn.permute(0, 2, 1))
        attended = attended.view(-1, channels // 2, height, width)
        projected = self.snconv1x1_o_conv(attended)

        return x + self.gamma * projected
class BigGANBatchNorm(nn.Module):
    """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
        activation means and variances for various truncation parameters.

        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch_norm ourselves without updating running means and variances.
        If you want to train this model you should add running means and variance computation logic.

        Args:
            num_features: number of channels being normalized.
            condition_vector_dim: input size of the scale/offset projections; required when ``conditional``.
            n_stats: number of truncation values (evenly spaced over [0, 1]) with stored statistics.
            eps: added to the variance before the square root; also passed to the spectral-norm wrappers.
            conditional: if True, scale and offset are predicted from a condition vector;
                otherwise plain ``weight``/``bias`` parameters are used.
    """

    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            # Left uninitialized here; expected to be filled from a checkpoint.
            self.weight = torch.nn.Parameter(torch.Tensor(num_features))
            self.bias = torch.nn.Parameter(torch.Tensor(num_features))

    def forward(self, x, truncation, condition_vector=None):
        """Normalize ``x`` with the stored statistics for ``truncation``.

        Args:
            x: activations with channels on dim 1.
            truncation: float selecting the row of pre-computed statistics (``truncation / step_size``).
            condition_vector: input of the scale/offset projections; only read when ``self.conditional``.

        Returns:
            The normalized, scaled and shifted activations.
        """
        # Retrieve pre-computed statistics associated to this truncation
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        last_idx = self.running_means.shape[0] - 1

        if coef != 0.0 and start_idx < last_idx:
            # Linear interpolation: ``coef`` is the fractional distance from row
            # ``start_idx`` towards row ``start_idx + 1``, so it weights the upper row.
            lower_weight = 1 - coef
            running_mean = self.running_means[start_idx] * lower_weight + self.running_means[start_idx + 1] * coef
            running_var = self.running_vars[start_idx] * lower_weight + self.running_vars[start_idx + 1] * coef
        else:
            # Exact hit, or float round-off just past the last row: use the row as is.
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            # Reshape the (C,) statistics to (1, C, 1, 1) so they broadcast against x.
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            # Per-sample scale (centred on 1) and offset predicted from the condition vector.
            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out
class GenBlock(nn.Module):
    """Residual bottleneck block of the generator.

    Four (conditional batch norm -> ReLU -> spectrally normalized conv) stages:
    a 1x1 conv down to ``in_size // reduction_factor`` channels, two 3x3 convs at
    that width, and a 1x1 conv up to ``out_size``. With ``up_sample`` the main
    path and the skip path are both upsampled 2x (nearest neighbour).
    """

    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        # When the widths differ, the skip path keeps only part of its channels.
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        def cond_bn(width):
            # Conditional batch norm sharing this block's settings.
            return BigGANBatchNorm(width, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)

        # Construction order is kept as-is: it fixes parameter order and RNG use.
        self.bn_0 = cond_bn(in_size)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)

        self.bn_1 = cond_bn(middle_size)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = cond_bn(middle_size)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = cond_bn(middle_size)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        """Run the block on ``x`` and return main path plus skip path."""
        shortcut = x

        hidden = self.conv_0(self.relu(self.bn_0(x, truncation, cond_vector)))

        hidden = self.relu(self.bn_1(hidden, truncation, cond_vector))
        if self.up_sample:
            hidden = F.interpolate(hidden, scale_factor=2, mode='nearest')
        hidden = self.conv_1(hidden)

        hidden = self.conv_2(self.relu(self.bn_2(hidden, truncation, cond_vector)))
        hidden = self.conv_3(self.relu(self.bn_3(hidden, truncation, cond_vector)))

        if self.drop_channels:
            # Keep the first half of the input channels on the skip path
            # (assumes out_size == in_size // 2 -- TODO confirm against the config).
            kept = shortcut.shape[1] // 2
            shortcut = shortcut[:, :kept, ...]

        if self.up_sample:
            shortcut = F.interpolate(shortcut, scale_factor=2, mode='nearest')

        return hidden + shortcut
class Generator(nn.Module):
    """BigGAN-deep generator body: linear stem, a stack of GenBlock / SelfAttn layers, RGB head.

    ``forward`` can run a sub-range ``[start, stop)`` of the layer stack, so
    intermediate feature maps can be extracted and fed back in later.
    """

    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config
        ch = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch*layer[1], eps=config.eps))
            # NOTE(review): the GenBlock construction was missing from this copy and is
            # reconstructed here; it assumes each config.layers entry is
            # (up_sample, in_channel_multiplier, out_channel_multiplier) -- confirm
            # against BigGANConfig.
            layers.append(GenBlock(ch*layer[1],
                                   ch*layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation, z=None, start=0, stop=None):
        """Run layers ``start`` .. ``stop - 1`` and, if the stack is completed, the RGB head.

        Args:
            cond_vector: condition vector given to ``gen_z`` and to every GenBlock.
            truncation: truncation value forwarded to the batch-norm layers.
            z: feature map to resume from. When ``start == 0`` and ``z`` is None it is
                computed from ``cond_vector``.
            start: index of the first layer to run.
            stop: index one past the last layer to run; None means all remaining layers.

        Returns:
            The feature map after layer ``stop - 1``, or the 3-channel tanh image when
            ``stop`` equals the number of layers.

        Raises:
            ValueError: if no feature map is available to start from.
        """
        if start == 0 and z is None:
            z = self.gen_z(cond_vector)

            # We use this conversion step to be able to use TF weights:
            # TF convention on shape is [batch, height, width, channels]
            # PT convention on shape is [batch, channels, height, width]
            z = z.view(-1, 4, 4, 16 * self.config.channel_width)
            z = z.permute(0, 3, 1, 2).contiguous()

        if z is None:
            raise ValueError("z must be provided when start is not 0")

        if stop is None:
            stop = len(self.layers)

        for i in range(start, stop):
            layer = self.layers[i]
            if isinstance(layer, GenBlock):
                z = layer(z, cond_vector, truncation)
            else:
                z = layer(z)

        if stop == len(self.layers):
            z = self.bn(z, truncation)
            z = self.relu(z)
            z = self.conv_to_rgb(z)
            z = z[:, :3, ...]  # conv_to_rgb outputs ch channels; only the first three are kept
            z = self.tanh(z)

        return z
class BigGAN(nn.Module):
    """BigGAN Generator."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """Instantiate a model from a shortcut name or a local directory.

        Args:
            pretrained_model_name_or_path: a key of ``PRETRAINED_MODEL_ARCHIVE_MAP``, or a
                directory containing ``WEIGHTS_NAME`` and ``CONFIG_NAME`` files.
            cache_dir: forwarded to ``cached_path``.
            *inputs, **kwargs: forwarded to the class constructor after the config.

        Returns:
            The model with the loaded state dict applied (``strict=False``).

        Raises:
            EnvironmentError: re-raised from ``cached_path`` after logging when a file cannot be resolved.
        """
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            # NOTE(review): the format arguments were missing from this copy; they are
            # reconstructed from the three placeholders in the message.
            logger.error("Wrong model name, should be a valid path to a folder containing "
                         "a {} file and a {} file or a model name in {}".format(
                             WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super(BigGAN, self).__init__()
        self.config = config
        # Bias-free linear map from the class-label vector to a z_dim-sized embedding.
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation, cond_vector=None, start=0, stop=None):
        """Generate from noise and class label, optionally over a sub-range of generator layers.

        Args:
            z: noise vector when ``start == 0``; otherwise the feature map to resume from.
            class_label: input of the class embedding (only read when the condition vector is built here).
            truncation: value in (0, 1] forwarded to the generator.
            cond_vector: precomputed condition vector; built from ``z`` and ``class_label``
                when None and ``start == 0``.
            start, stop: layer range forwarded to the generator.

        Returns:
            dict with ``'cond_vector'`` (the condition vector used) and ``'z'`` (the generator output).
        """
        assert 0 < truncation <= 1

        results = {}

        if start == 0 and cond_vector is None:
            embed = self.embeddings(class_label)
            cond_vector = torch.cat((z, embed), dim=1)

        # NOTE(review): indentation was lost in this copy, so it is unclear whether this
        # key was only set when the vector is built above; it is stored unconditionally here.
        results['cond_vector'] = cond_vector
        results['z'] = self.generator(cond_vector, truncation, z=None if start == 0 else z, start=start, stop=stop)
        return results
if __name__ == "__main__":
import PIL
from .utils import truncated_noise_sample, save_as_images, one_hot_from_names
from .convert_tf_to_pytorch import load_tf_weights_in_biggan
load_cache = False
cache_path = './'
config = BigGANConfig()
model = BigGAN(config)
if not load_cache:
model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin'), cache_path)
truncation = 0.4
noise = truncated_noise_sample(batch_size=2, truncation=truncation)
label = one_hot_from_names('diver', batch_size=2)
# Tests
# noise = np.zeros((1, 128))
# label = [983]
noise = torch.tensor(noise, dtype=torch.float)
label = torch.tensor(label, dtype=torch.float)
with torch.no_grad():
outputs = model(noise, label, truncation)