PandA / networks /biggan /model.py
james-oldfield's picture
Upload 194 files
2a76164
raw history blame
No virus
14.8 kB
# coding: utf-8
""" BigGAN PyTorch model.
From "Large Scale GAN Training for High Fidelity Natural Image Synthesis"
By Andrew Brock, Jeff Donahue and Karen Simonyan.
https://openreview.net/forum?id=B1xsqj09Fm
PyTorch version implemented from the computational graph of the TF Hub module for BigGAN.
Some parts of the code are adapted from https://github.com/brain-research/self-attention-gan
This version only comprises the generator (since the discriminator's weights are not released).
This version only comprises the "deep" version of BigGAN (see publication).
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import os
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import BigGANConfig
from .file_utils import cached_path
logger = logging.getLogger(__name__)
# Common URL prefix under which the released checkpoints and configs are hosted.
_BIGGAN_S3_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/biggan/"

# Shortcut model names, in the order they are listed in the archive maps.
_BIGGAN_MODEL_NAMES = ('biggan-deep-128', 'biggan-deep-256', 'biggan-deep-512')

# Shortcut name -> URL of the PyTorch weights file.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    name: _BIGGAN_S3_PREFIX + name + "-pytorch_model.bin"
    for name in _BIGGAN_MODEL_NAMES
}

# Shortcut name -> URL of the JSON configuration file.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    name: _BIGGAN_S3_PREFIX + name + "-config.json"
    for name in _BIGGAN_MODEL_NAMES
}

# File names looked up inside a local model directory.
WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'
def snconv2d(eps=1e-12, **kwargs):
    """Create an `nn.Conv2d` from `kwargs` and wrap it in spectral normalization.

    `eps` is passed to `nn.utils.spectral_norm`; all remaining keyword
    arguments are forwarded to the `nn.Conv2d` constructor.
    """
    conv = nn.Conv2d(**kwargs)
    return nn.utils.spectral_norm(conv, eps=eps)
def snlinear(eps=1e-12, **kwargs):
    """Create an `nn.Linear` from `kwargs` and wrap it in spectral normalization.

    `eps` is passed to `nn.utils.spectral_norm`; all remaining keyword
    arguments are forwarded to the `nn.Linear` constructor.
    """
    linear = nn.Linear(**kwargs)
    return nn.utils.spectral_norm(linear, eps=eps)
def sn_embedding(eps=1e-12, **kwargs):
    """Create an `nn.Embedding` from `kwargs` and wrap it in spectral normalization.

    `eps` is passed to `nn.utils.spectral_norm`; all remaining keyword
    arguments are forwarded to the `nn.Embedding` constructor.
    """
    embedding = nn.Embedding(**kwargs)
    return nn.utils.spectral_norm(embedding, eps=eps)
class SelfAttn(nn.Module):
    """Self-attention layer built from spectrally-normalized 1x1 convolutions.

    Queries (theta) are computed at full spatial resolution, while keys (phi)
    and values (g) are 2x2 max-pooled first. The attended features are projected
    back to `in_channels` and added to the input, scaled by the learnable scalar
    `gamma`, which is initialised to zero.
    """

    def __init__(self, in_channels, eps=1e-12):
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels

        def conv1x1(n_in, n_out):
            # All four projections are bias-free 1x1 convolutions.
            return snconv2d(in_channels=n_in, out_channels=n_out,
                            kernel_size=1, bias=False, eps=eps)

        self.snconv1x1_theta = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_phi = conv1x1(in_channels, in_channels // 8)
        self.snconv1x1_g = conv1x1(in_channels, in_channels // 2)
        self.snconv1x1_o_conv = conv1x1(in_channels // 2, in_channels)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return `x` plus the gamma-scaled attention output (same shape as `x`)."""
        _, channels, height, width = x.size()
        n_full = height * width
        n_pooled = n_full // 4  # spatial positions left after the 2x2 max-pool

        # Queries at full resolution: (batch, channels // 8, n_full).
        queries = self.snconv1x1_theta(x).view(-1, channels // 8, n_full)

        # Keys at pooled resolution: (batch, channels // 8, n_pooled).
        keys = self.maxpool(self.snconv1x1_phi(x)).view(-1, channels // 8, n_pooled)

        # Attention weights, normalised over the pooled positions.
        attn = self.softmax(torch.bmm(queries.transpose(1, 2), keys))

        # Values at pooled resolution: (batch, channels // 2, n_pooled).
        values = self.maxpool(self.snconv1x1_g(x)).view(-1, channels // 2, n_pooled)

        # Aggregate the values per query position, then restore the spatial layout.
        attended = torch.bmm(values, attn.transpose(1, 2))
        attended = attended.view(-1, channels // 2, height, width)

        return x + self.gamma * self.snconv1x1_o_conv(attended)
class BigGANBatchNorm(nn.Module):
    """ Inference-only batch norm that uses pre-computed, truncation-indexed statistics.

        The module can handle conditional input and is provided with pre-computed
        activation means and variances for `n_stats` truncation values evenly
        spaced between 0 and 1. For a truncation that falls between two tabulated
        values, the statistics are linearly interpolated.

        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch_norm ourselves without
        updating running means and variances.
        If you want to train this model you should add running means and variance computation logic.

        Args:
            num_features: number of channels being normalised.
            condition_vector_dim: size of the conditioning vector; required when
                `conditional` is True.
            n_stats: number of tabulated truncation values (rows of the buffers).
            eps: value added to the variance for numerical stability (also
                forwarded to the spectral-norm wrappers of the conditional layers).
            conditional: if True, gain and bias are predicted from the condition
                vector; otherwise they are plain learnable parameters.
    """

    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            # Start from the identity affine transform instead of uninitialised
            # memory, so the layer is well-defined even if no weights are loaded.
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, x, truncation, condition_vector=None):
        """Normalise `x` with the statistics associated to `truncation`.

        `condition_vector` is only used (and required) when the module was
        built with `conditional=True`.
        """
        # Retrieve pre-computed statistics associated to this truncation.
        # `frac` is how far the truncation lies past grid point `lower`.
        frac, lower = math.modf(truncation / self.step_size)
        lower = int(lower)

        if frac != 0.0:  # Interpolate between the two neighbouring grid points
            # Clamp so that rounding noise at the last grid point cannot index
            # past the end of the table.
            upper = min(lower + 1, self.running_means.shape[0] - 1)
            running_mean = self.running_means[lower] * (1 - frac) + self.running_means[upper] * frac
            running_var = self.running_vars[lower] * (1 - frac) + self.running_vars[upper] * frac
        else:
            running_mean = self.running_means[lower]
            running_var = self.running_vars[lower]

        if self.conditional:
            # Reshape statistics to (1, C, 1, 1) so they broadcast over x.
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            # Gain and bias are predicted from the condition vector.
            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out
class GenBlock(nn.Module):
    """Residual bottleneck block of the generator.

    The main path is four (conditional batch norm -> ReLU -> conv) stages:
    a 1x1 conv reducing `in_size` to `in_size // reduction_factor` channels,
    two 3x3 convs at that width, and a 1x1 conv producing `out_size` channels.
    When `up_sample` is True, a nearest-neighbour 2x upsampling is applied
    before the first 3x3 conv and to the skip connection.

    When `in_size != out_size`, the skip connection keeps only the leading
    channels of the input so that it matches the main path's channel count.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        condition_vector_dim: size of the conditioning vector fed to each batch norm.
        reduction_factor: divisor applied to `in_size` to get the bottleneck width.
        up_sample: whether to double the spatial resolution.
        n_stats: number of tabulated truncation values in each batch norm.
        eps: epsilon forwarded to the batch norms and spectral-norm convs.
    """

    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)

        self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        """Apply the block to `x`, conditioned on `cond_vector` at the given `truncation`."""
        x0 = x  # kept for the skip connection

        x = self.bn_0(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_0(x)

        x = self.bn_1(x, truncation, cond_vector)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)

        x = self.bn_2(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_2(x)

        x = self.bn_3(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_3(x)

        if self.drop_channels:
            # Keep as many leading input channels as the main path produced.
            # This equals the former hard-coded `x0.shape[1] // 2` whenever the
            # block halves the channel count, and also supports other ratios.
            x0 = x0[:, :x.shape[1], ...]

        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')

        out = x + x0
        return out
class Generator(nn.Module):
    """BigGAN-deep generator network.

    Built from `config`, reading `channel_width`, `z_dim`, `eps`, `n_stats`,
    `layers` and `attention_layer_position`. Each entry of `config.layers` is
    indexed as `(up_sample, in_multiplier, out_multiplier)`, the multipliers
    being relative to `channel_width`. A `SelfAttn` layer is inserted right
    before the `GenBlock` at index `attention_layer_position`.

    The forward pass can run on a sub-range `[start, stop)` of `self.layers`,
    which allows intermediate activations to be extracted and re-injected.
    """

    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config
        ch = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch*layer[1], eps=config.eps))
            layers.append(GenBlock(ch*layer[1],
                                   ch*layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        # The conv emits `ch` channels; forward() keeps only the first three.
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation, z=None, start=0, stop=None):
        """Run layers `start` (inclusive) to `stop` (exclusive) of the generator.

        Args:
            cond_vector: conditioning vector passed to every `GenBlock`.
            truncation: truncation value selecting the batch-norm statistics.
            z: activations to feed into layer `start`. May be omitted only when
                `start == 0`, in which case they are computed from `cond_vector`.
            start: index of the first entry of `self.layers` to apply.
            stop: index one past the last entry to apply; defaults to all layers.

        Returns:
            The activations after layer `stop - 1`. When `stop` equals the
            number of layers, the output head (batch norm, ReLU, conv, keep
            three channels, tanh) is applied as well.

        Raises:
            ValueError: if `start > 0` and `z` is not provided.
        """
        if start == 0 and z is None:
            # We use this conversion step to be able to use TF weights:
            # TF convention on shape is [batch, height, width, channels]
            # PT convention on shape is [batch, channels, height, width]
            z = self.gen_z(cond_vector)
            z = z.view(-1, 4, 4, 16 * self.config.channel_width)
            z = z.permute(0, 3, 1, 2).contiguous()

        if z is None:
            # Resuming mid-network needs the activations that feed layer `start`.
            raise ValueError("`z` must be provided when `start` is not 0 (got start={})".format(start))

        n_layers = len(self.layers)
        if stop is None:
            stop = n_layers

        for i in range(start, stop):
            layer = self.layers[i]
            if isinstance(layer, GenBlock):
                z = layer(z, cond_vector, truncation)
            else:
                z = layer(z)

        # The output head only runs when the whole stack has been traversed.
        if stop == n_layers:
            z = self.bn(z, truncation)
            z = self.relu(z)
            z = self.conv_to_rgb(z)
            z = z[:, :3, ...]
            z = self.tanh(z)

        return z
class BigGAN(nn.Module):
    """BigGAN Generator.

    Combines a bias-free linear class embedding with the `Generator` network.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """Instantiate a model and load its weights.

        `pretrained_model_name_or_path` is either a key of
        `PRETRAINED_MODEL_ARCHIVE_MAP` or a path to a folder containing the
        weights and config files. Files are resolved through `cached_path`.
        Extra positional and keyword arguments are forwarded to the constructor.
        """
        name_or_path = pretrained_model_name_or_path
        is_shortcut_name = name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP
        if not is_shortcut_name:
            model_file = os.path.join(name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(name_or_path, CONFIG_NAME)
        else:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[name_or_path]

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error("Wrong model name, should be a valid path to a folder containing "
                         "a {} file and a {} file or a model name in {}".format(
                             WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

        # Build the configuration first, then the model that depends on it.
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        model = cls(config, *inputs, **kwargs)

        # Fall back to CPU storage when no GPU is available.
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        map_location = None if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(resolved_model_file, map_location=map_location)
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super(BigGAN, self).__init__()
        self.config = config
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation, cond_vector=None, start=0, stop=None):
        """Run the generator and return a dict with keys 'cond_vector' and 'z'.

        When `start == 0` and no `cond_vector` is given, the condition vector is
        the concatenation (along dim 1) of `z` and the embedded `class_label`.
        When `start != 0`, `z` is handed to the generator as the activations to
        resume from. 'z' in the result holds the generator output.
        """
        assert 0 < truncation <= 1

        if start == 0 and cond_vector is None:
            class_embedding = self.embeddings(class_label)
            cond_vector = torch.cat((z, class_embedding), dim=1)

        # At start == 0 the generator derives its input from the condition vector.
        resume_from = z if start != 0 else None
        generated = self.generator(cond_vector, truncation, z=resume_from, start=start, stop=stop)

        return {'cond_vector': cond_vector, 'z': generated}
if __name__ == "__main__":
    # Smoke test: convert (or reload) the 128px model, generate two samples and save them.
    import PIL  # not referenced below; presumably needed by save_as_images — TODO confirm
    from .utils import truncated_noise_sample, save_as_images, one_hot_from_names
    from .convert_tf_to_pytorch import load_tf_weights_in_biggan

    load_cache = False
    cache_path = './saved_model.pt'
    config = BigGANConfig()
    model = BigGAN(config)
    if not load_cache:
        # Convert the TF checkpoint and cache the resulting state dict.
        model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin')
        torch.save(model.state_dict(), cache_path)
    else:
        model.load_state_dict(torch.load(cache_path))

    model.eval()

    truncation = 0.4
    noise = truncated_noise_sample(batch_size=2, truncation=truncation)
    label = one_hot_from_names('diver', batch_size=2)

    # Tests
    # noise = np.zeros((1, 128))
    # label = [983]

    noise = torch.tensor(noise, dtype=torch.float)
    label = torch.tensor(label, dtype=torch.float)
    with torch.no_grad():
        outputs = model(noise, label, truncation)

    # BigGAN.forward returns a dict; the generated batch is stored under 'z'.
    images = outputs['z']
    print(images.shape)

    # NOTE(review): assumes save_as_images expects the image tensor, not the result dict — confirm.
    save_as_images(images)