# coding: utf-8
""" BigGAN PyTorch model.
    From "Large Scale GAN Training for High Fidelity Natural Image Synthesis"
    By Andrew Brock, Jeff Donahue and Karen Simonyan.
    https://openreview.net/forum?id=B1xsqj09Fm

    PyTorch version implemented from the computational graph of the TF Hub module for BigGAN.
    Some parts of the code are adapted from https://github.com/brain-research/self-attention-gan

    This version only comprises the generator (since the discriminator's weights are not released).
    This version only comprises the "deep" version of BigGAN (see publication).
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import os
import logging
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import BigGANConfig
from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}

WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'


def snconv2d(eps=1e-12, **kwargs):
    """Build an ``nn.Conv2d`` from ``kwargs`` and wrap it in spectral normalization."""
    return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)


def snlinear(eps=1e-12, **kwargs):
    """Build an ``nn.Linear`` from ``kwargs`` and wrap it in spectral normalization."""
    return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)


def sn_embedding(eps=1e-12, **kwargs):
    """Build an ``nn.Embedding`` from ``kwargs`` and wrap it in spectral normalization."""
    return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)


class SelfAttn(nn.Module):
    """ Self attention Layer.

    The keys (phi) and values (g) are 2x2 max-pooled before the attention
    product, so the attention map has shape (h*w, h*w//4) per sample.
    """

    def __init__(self, in_channels, eps=1e-12):
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                        kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                      kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
                                    kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
                                         kernel_size=1, bias=False, eps=eps)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        # Residual gate; starts at zero so the layer is initially an identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``x + gamma * attention(x)`` for a 4D input (batch, ch, h, w).

        The reshapes below need ``h`` and ``w`` to be even (2x2 pooling).
        """
        _, ch, h, w = x.size()
        # Theta path (queries)
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch//8, h*w)
        # Phi path (keys, spatially pooled)
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch//8, h*w//4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path (values, spatially pooled)
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch//2, h*w//4)
        # Attn_g - o_conv
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch//2, h, w)
        attn_g = self.snconv1x1_o_conv(attn_g)
        # Out
        out = x + self.gamma*attn_g
        return out


class BigGANBatchNorm(nn.Module):
    """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
        activation means and variances for various truncation parameters.

        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch_norm ourselves without updating running means and variances.
        If you want to train this model you should add running means and variance computation logic.
    """

    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            # Identity affine transform until real weights are loaded
            # (torch.Tensor(n) would leave the parameters as uninitialised memory).
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, x, truncation, condition_vector=None):
        """Normalize ``x`` with the stored statistics matching ``truncation``.

        Args:
            x: activations with channels on dim 1.
            truncation: float used to select (or interpolate between) rows of
                the pre-computed statistics; row ``k`` corresponds to
                ``truncation == k * step_size``.
            condition_vector: input to the scale/offset projections; only read
                when the module is conditional.
        """
        # Retrieve pre-computed statistics associated to this truncation
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        if coef != 0.0:  # Interpolate
            # `coef` is the fractional distance from row `start_idx` towards the
            # next row, so it weights the *next* row. The upper index is clamped
            # so rounding noise at the top of the range cannot index past the
            # last row.
            end_idx = min(start_idx + 1, self.running_means.size(0) - 1)
            running_mean = self.running_means[start_idx] * (1 - coef) + self.running_means[end_idx] * coef
            running_var = self.running_vars[start_idx] * (1 - coef) + self.running_vars[end_idx] * coef
        else:
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            # Broadcast per-channel statistics over (batch, channel, h, w)
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            # Per-sample gain and bias predicted from the condition vector
            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out


class GenBlock(nn.Module):
    """Residual bottleneck block of the generator.

    Four conditional-batch-norm / ReLU / conv stages (1x1, 3x3, 3x3, 1x1) with
    an optional 2x nearest-neighbour up-sampling before the first 3x3 conv,
    added to a skip path.
    """

    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)

        self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        """Apply the block to ``x`` and return the sum of main and skip paths."""
        x0 = x

        x = self.bn_0(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_0(x)

        x = self.bn_1(x, truncation, cond_vector)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)

        x = self.bn_2(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_2(x)

        x = self.bn_3(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_3(x)

        if self.drop_channels:
            # Skip path keeps only the first half of the input channels.
            # NOTE(review): the final addition assumes out_size == in_size // 2
            # whenever the sizes differ -- confirm against the layer configs.
            new_channels = x0.shape[1] // 2
            x0 = x0[:, :new_channels, ...]
        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')

        out = x + x0
        return out


class Generator(nn.Module):
    """Stack of ``GenBlock`` layers (plus one ``SelfAttn``) built from ``config.layers``."""

    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config
        ch = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            # Each entry is indexed as (up_sample, in multiplier, out multiplier).
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch*layer[1], eps=config.eps))
            layers.append(GenBlock(ch*layer[1],
                                   ch*layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        # Outputs `ch` channels; forward() keeps only the first three.
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation, z=None, start=0, stop=None):
        """Run layers ``start`` (inclusive) to ``stop`` (exclusive).

        Args:
            cond_vector: condition vector fed to ``gen_z`` and every ``GenBlock``.
            truncation: float forwarded to the batch-norm modules.
            z: activations to resume from. When ``start == 0`` and ``z`` is
                None, the initial 4x4 feature map is computed from ``cond_vector``.
            start: index of the first entry of ``self.layers`` to run.
            stop: index one past the last entry to run; defaults to all layers.

        Returns:
            The activations after layer ``stop - 1``; when ``stop`` equals the
            number of layers, the final batch norm, ReLU, RGB conv and tanh
            are applied as well.
        """
        if start == 0 and z is None:
            z = self.gen_z(cond_vector)

            # We use this conversion step to be able to use TF weights:
            # TF convention on shape is [batch, height, width, channels]
            # PT convention on shape is [batch, channels, height, width]
            z = z.view(-1, 4, 4, 16 * self.config.channel_width)
            z = z.permute(0, 3, 1, 2).contiguous()

        if stop is None:
            stop = len(self.layers)

        for i in range(start, stop):
            layer = self.layers[i]
            if isinstance(layer, GenBlock):
                z = layer(z, cond_vector, truncation)
            else:
                z = layer(z)

        if stop == len(self.layers):
            z = self.bn(z, truncation)
            z = self.relu(z)
            z = self.conv_to_rgb(z)
            z = z[:, :3, ...]
            z = self.tanh(z)

        return z


class BigGAN(nn.Module):
    """BigGAN Generator."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """Instantiate a model from a known model name or a local folder.

        Args:
            pretrained_model_name_or_path: a key of ``PRETRAINED_MODEL_ARCHIVE_MAP``
                or a folder containing ``WEIGHTS_NAME`` and ``CONFIG_NAME`` files.
            cache_dir: forwarded to ``cached_path``.
            *inputs, **kwargs: forwarded to the class constructor after ``config``.

        Raises:
            EnvironmentError: re-raised (after logging) when either file
                cannot be resolved.
        """
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error("Wrong model name, should be a valid path to a folder containing "
                         "a {} file and a {} file or a model name in {}".format(
                             WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        state_dict = torch.load(resolved_model_file,
                                map_location='cpu' if not torch.cuda.is_available() else None)
        # strict=False: missing or unexpected keys in the checkpoint are tolerated.
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super(BigGAN, self).__init__()
        self.config = config
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation, cond_vector=None, start=0, stop=None):
        """Run the generator, optionally on a sub-range of its layers.

        Args:
            z: noise vector when ``start == 0``; otherwise the intermediate
                activations to resume from.
            class_label: input of the class embedding projection; only read
                when the condition vector has to be built here.
            truncation: float in ``(0, 1]``.
            cond_vector: pre-computed condition vector. When None and
                ``start == 0`` it is built as ``cat(z, embed(class_label))``.
            start, stop: layer range forwarded to ``Generator.forward``.

        Returns:
            dict with key ``'z'`` (generator output) and, when the condition
            vector was built here, key ``'cond_vector'``.
        """
        assert 0 < truncation <= 1

        results = {}
        if start == 0 and cond_vector is None:
            embed = self.embeddings(class_label)
            cond_vector = torch.cat((z, embed), dim=1)
            results['cond_vector'] = cond_vector

        results['z'] = self.generator(cond_vector, truncation,
                                      z=None if start == 0 else z,
                                      start=start, stop=stop)
        return results


if __name__ == "__main__":
    import PIL
    from .utils import truncated_noise_sample, save_as_images, one_hot_from_names
    from .convert_tf_to_pytorch import load_tf_weights_in_biggan

    load_cache = False
    cache_path = './saved_model.pt'
    config = BigGANConfig()
    model = BigGAN(config)
    if not load_cache:
        model = load_tf_weights_in_biggan(model, config, './models/model_128/',
                                          './models/model_128/batchnorms_stats.bin')
        torch.save(model.state_dict(), cache_path)
    else:
        model.load_state_dict(torch.load(cache_path))

    model.eval()

    truncation = 0.4
    noise = truncated_noise_sample(batch_size=2, truncation=truncation)
    label = one_hot_from_names('diver', batch_size=2)

    # Tests
    # noise = np.zeros((1, 128))
    # label = [983]

    noise = torch.tensor(noise, dtype=torch.float)
    label = torch.tensor(label, dtype=torch.float)
    with torch.no_grad():
        # forward() returns a dict; the generated images live under 'z'.
        outputs = model(noise, label, truncation)['z']
    print(outputs.shape)

    save_as_images(outputs)