# coding: utf-8
"""
Convert a BigGAN TF Hub checkpoint into a PyTorch model.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

from itertools import chain
import os
import argparse
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import normalize

from .model import BigGAN, WEIGHTS_NAME, CONFIG_NAME
from .config import BigGANConfig

logger = logging.getLogger(__name__)

def extract_batch_norm_stats(tf_model_path, batch_norm_stats_path=None):
    try:
        import numpy as np
        import tensorflow as tf
        import tensorflow_hub as hub
    except ImportError:
        raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow and TF Hub to be installed. "
                          "Please see https://www.tensorflow.org/install/ for installation instructions for TensorFlow "
                          "and https://github.com/tensorflow/hub for installing Hub. "
                          "Probably: pip install tensorflow tensorflow-hub")
    tf.reset_default_graph()
    logger.info('Loading BigGAN module from: {}'.format(tf_model_path))
    module = hub.Module(tf_model_path)
    inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
              for k, v in module.get_input_info_dict().items()}
    output = module(inputs)

    initializer = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(initializer)  # restores the hub module's variables in this session
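    # The released TF Hub BigGAN modules expose the accumulated ("standing") batch
    # norm statistics as stacked tensors named module_apply_default/stack_N. The
    # index pattern below (N = 10*i + 1, 3, 6, 8) matches the layout observed in
    # those graphs; iteration simply stops at the first missing index.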
    stacks = sum(((i*10 + 1, i*10 + 3, i*10 + 6, i*10 + 8) for i in range(50)), ())
    numpy_stacks = []
    for i in stacks:
        logger.info("Retrieving module_apply_default/stack_{}".format(i))
        try:
            stack_var = tf.get_default_graph().get_tensor_by_name("module_apply_default/stack_%d:0" % i)
        except KeyError:
            break  # We have all the stats
        numpy_stacks.append(sess.run(stack_var))

    if batch_norm_stats_path is not None:
        torch.save(numpy_stacks, batch_norm_stats_path)
    else:
        return numpy_stacks
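
# Example usage (a minimal sketch; the local path below is hypothetical and assumes
# a TF Hub BigGAN module has already been downloaded and unpacked there):
#   stats = extract_batch_norm_stats('/tmp/biggan-128')                    # return the stats
#   extract_batch_norm_stats('/tmp/biggan-128', 'batch_norm_stats.pt')     # or save them to disk
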
def build_tf_to_pytorch_map(model, config):
    """ Build a map from TF variables to PyTorch modules. """
    tf_to_pt_map = {}
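    # Naming conventions in the TF checkpoint: the 'ema_0.9999' suffix selects the
    # exponential-moving-average copy of a weight (decay 0.9999), which the released
    # BigGAN generators use at inference time; 'u0' is the persisted left singular
    # vector estimate used by spectral normalization (mapped to weight_u below).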
    # Embeddings and GenZ
    tf_to_pt_map.update({'linear/w/ema_0.9999': model.embeddings.weight,
                         'Generator/GenZ/G_linear/b/ema_0.9999': model.generator.gen_z.bias,
                         'Generator/GenZ/G_linear/w/ema_0.9999': model.generator.gen_z.weight_orig,
                         'Generator/GenZ/G_linear/u0': model.generator.gen_z.weight_u})
    # GBlock blocks
    model_layer_idx = 0
    for i, (up, in_channels, out_channels) in enumerate(config.layers):
        if i == config.attention_layer_position:
            model_layer_idx += 1
        layer_str = "Generator/GBlock_%d/" % i if i > 0 else "Generator/GBlock/"
        layer_pnt = model.generator.layers[model_layer_idx]
        for j in range(4):  # Batchnorms (j avoids shadowing the block index i)
            batch_str = layer_str + ("BatchNorm_%d/" % j if j > 0 else "BatchNorm/")
            batch_pnt = getattr(layer_pnt, 'bn_%d' % j)
            for name in ('offset', 'scale'):
                sub_module_str = batch_str + name + "/"
                sub_module_pnt = getattr(batch_pnt, name)
                tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
                                     sub_module_str + "u0": sub_module_pnt.weight_u})
        for j in range(4):  # Convolutions
            conv_str = layer_str + "conv%d/" % j
            conv_pnt = getattr(layer_pnt, 'conv_%d' % j)
            tf_to_pt_map.update({conv_str + "b/ema_0.9999": conv_pnt.bias,
                                 conv_str + "w/ema_0.9999": conv_pnt.weight_orig,
                                 conv_str + "u0": conv_pnt.weight_u})
        model_layer_idx += 1
    # Attention block
    layer_str = "Generator/attention/"
    layer_pnt = model.generator.layers[config.attention_layer_position]
    tf_to_pt_map.update({layer_str + "gamma/ema_0.9999": layer_pnt.gamma})
    for pt_name, tf_name in zip(['snconv1x1_g', 'snconv1x1_o_conv', 'snconv1x1_phi', 'snconv1x1_theta'],
                                ['g/', 'o_conv/', 'phi/', 'theta/']):
        sub_module_str = layer_str + tf_name
        sub_module_pnt = getattr(layer_pnt, pt_name)
        tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
                             sub_module_str + "u0": sub_module_pnt.weight_u})

    # Final batch norm and conv to rgb
    layer_str = "Generator/BatchNorm/"
    layer_pnt = model.generator.bn
    tf_to_pt_map.update({layer_str + "offset/ema_0.9999": layer_pnt.bias,
                         layer_str + "scale/ema_0.9999": layer_pnt.weight})
    layer_str = "Generator/conv_to_rgb/"
    layer_pnt = model.generator.conv_to_rgb
    tf_to_pt_map.update({layer_str + "b/ema_0.9999": layer_pnt.bias,
                         layer_str + "w/ema_0.9999": layer_pnt.weight_orig,
                         layer_str + "u0": layer_pnt.weight_u})
    return tf_to_pt_map

def load_tf_weights_in_biggan(model, config, tf_model_path, batch_norm_stats_path=None):
    """ Load TF checkpoint weights and standing batch norm statistics into a PyTorch model.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
                          "https://www.tensorflow.org/install/ for installation instructions.")

    # Load weights from TF model
    checkpoint_path = tf_model_path + "/variables/variables"
    init_vars = tf.train.list_variables(checkpoint_path)
    logger.debug("TF checkpoint variables:\n%s", "\n".join(str(v) for v in init_vars))
    # Extract batch norm statistics from model if needed
    if batch_norm_stats_path:
        stats = torch.load(batch_norm_stats_path)
    else:
        logger.info("Extracting batch norm stats")
        stats = extract_batch_norm_stats(tf_model_path)

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)
    tf_weights = {}
    for name in tf_to_pt_map.keys():
        array = tf.train.load_variable(checkpoint_path, name)
        tf_weights[name] = array
        # logger.info("Loading TF weight {} with shape {}".format(name, array.shape))

    # Load parameters
    with torch.no_grad():
        pt_params_pnt = set()
        for name, pointer in tf_to_pt_map.items():
            array = tf_weights[name]
            if pointer.dim() == 1:  # Biases and batch norm offsets/scales
                if pointer.dim() < array.ndim:
                    array = np.squeeze(array)
            elif pointer.dim() == 2:  # Linear weights: TF stores (in, out), PyTorch expects (out, in)
                array = np.transpose(array)
            elif pointer.dim() == 4:  # Conv kernels: TF stores HWIO, PyTorch expects OIHW
                array = np.transpose(array, (3, 2, 0, 1))
            else:
                raise ValueError("Wrong dimensions to adjust: " + str((pointer.shape, array.shape)))
            if pointer.shape != array.shape:
                raise ValueError("Wrong dimensions: " + str((pointer.shape, array.shape)))
            logger.info("Initialize PyTorch weight {} with shape {}".format(name, pointer.shape))
            pointer.data = torch.from_numpy(array) if isinstance(array, np.ndarray) else torch.tensor(array)
            tf_weights.pop(name, None)
            pt_params_pnt.add(pointer.data_ptr())
        # Prepare SpectralNorm buffers by running one step of Spectral Norm
        # (no need to train the model). One power-iteration step:
        #   v <- normalize(W^T u),  u <- normalize(W v)
        for module in model.modules():
            for n, buffer in module.named_buffers():
                if n == 'weight_v':
                    weight_mat = module.weight_orig
                    weight_mat = weight_mat.reshape(weight_mat.size(0), -1)
                    u = module.weight_u

                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=config.eps)
                    buffer.data = v
                    pt_params_pnt.add(buffer.data_ptr())

                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=config.eps)
                    module.weight_u.data = u
                    pt_params_pnt.add(module.weight_u.data_ptr())
        # Load batch norm statistics
        index = 0
        for layer in model.generator.layers:
            if not hasattr(layer, 'bn_0'):
                continue
            for i in range(4):  # Batchnorms
                bn_pointer = getattr(layer, 'bn_%d' % i)
                pointer = bn_pointer.running_means
                if pointer.shape != stats[index].shape:
                    raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
                pointer.data = torch.from_numpy(stats[index])
                pt_params_pnt.add(pointer.data_ptr())

                pointer = bn_pointer.running_vars
                if pointer.shape != stats[index+1].shape:
                    raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
                pointer.data = torch.from_numpy(stats[index+1])
                pt_params_pnt.add(pointer.data_ptr())

                index += 2

        bn_pointer = model.generator.bn
        pointer = bn_pointer.running_means
        if pointer.shape != stats[index].shape:
            raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
        pointer.data = torch.from_numpy(stats[index])
        pt_params_pnt.add(pointer.data_ptr())

        pointer = bn_pointer.running_vars
        if pointer.shape != stats[index+1].shape:
            raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
        pointer.data = torch.from_numpy(stats[index+1])
        pt_params_pnt.add(pointer.data_ptr())

    remaining_params = list(n for n, t in chain(model.named_parameters(), model.named_buffers())
                            if t.data_ptr() not in pt_params_pnt)

    logger.info("TF weights not copied to the PyTorch model: {}".format(', '.join(tf_weights.keys())))
    logger.info("Remaining parameters/buffers from the PyTorch model: {}".format(', '.join(remaining_params)))
    return model
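
# Example usage (a sketch; the path is hypothetical):
#   model = BigGAN(BigGAN128)
#   model = load_tf_weights_in_biggan(model, BigGAN128, '/tmp/biggan-128')

# Pre-defined configurations for the released 128/256/512 BigGAN generators.
# In each `layers` tuple the entries are (up-sample in this layer, input channel
# multiplier, output channel multiplier); the multipliers scale `channel_width`
# to give the actual channel counts.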
BigGAN128 = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)

BigGAN256 = BigGANConfig(output_dim=256, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)

BigGAN512 = BigGANConfig(output_dim=512, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1),
                                 (False, 1, 1),
                                 (True, 1, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)

def main():
    parser = argparse.ArgumentParser(description="Convert a BigGAN TF Hub model to a PyTorch model")
    parser.add_argument("--model_type", type=str, default="", required=True,
                        help="BigGAN model type (128, 256, 512)")
    parser.add_argument("--tf_model_path", type=str, default="", required=True,
                        help="Path of the downloaded TF Hub model")
    parser.add_argument("--pt_save_path", type=str, default="",
                        help="Folder to save the PyTorch model (default: folder of the TF Hub model)")
    parser.add_argument("--batch_norm_stats_path", type=str, default="",
                        help="Path of previously extracted batch norm statistics")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    if not args.pt_save_path:
        args.pt_save_path = args.tf_model_path

    if args.model_type == "128":
        config = BigGAN128
    elif args.model_type == "256":
        config = BigGAN256
    elif args.model_type == "512":
        config = BigGAN512
    else:
        raise ValueError("model_type should be one of 128, 256 or 512")

    model = BigGAN(config)
    model = load_tf_weights_in_biggan(model, config, args.tf_model_path, args.batch_norm_stats_path)

    model_save_path = os.path.join(args.pt_save_path, WEIGHTS_NAME)
    config_save_path = os.path.join(args.pt_save_path, CONFIG_NAME)

    logger.info("Save model dump to {}".format(model_save_path))
    torch.save(model.state_dict(), model_save_path)
    logger.info("Save configuration file to {}".format(config_save_path))
    with open(config_save_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())

if __name__ == "__main__":
    main()
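
# Example invocation (a sketch; the module name and paths are hypothetical, and
# assume the TF Hub module was downloaded and unpacked locally beforehand):
#   python -m pytorch_pretrained_biggan.convert_tf_hub_models \
#       --model_type 128 --tf_model_path /tmp/biggan-128 --pt_save_path /tmp/biggan-128-pt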