# ic_gan/BigGAN_PyTorch/TFHub/converter.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
"""Utilities for converting TFHub BigGAN generator weights to PyTorch.
Recommended usage:
To convert all BigGAN variants and generate test samples, use:
```bash
CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples
```
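To convert a single resolution only (e.g. 128), pass `-r`/`--resolution`:
```bash
CUDA_VISIBLE_DEVICES=0 python converter.py -r 128 --generate_samples
```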
See `parse_args` for additional options.
"""
import argparse
import os
import sys
import h5py
import torch
import torch.nn as nn
from torchvision.utils import save_image
import tensorflow as tf
import tensorflow_hub as hub
import parse
# Import the reference BigGAN (v1) implementation from this folder
import biggan_v1 as biggan_for_conversion
# Import model from main folder
sys.path.append("..")
import BigGAN
DEVICE = "cuda"
HDF5_TMPL = "biggan-{}.h5"
PTH_TMPL = "biggan-{}.pth"
MODULE_PATH_TMPL = "https://tfhub.dev/deepmind/biggan-{}/2"
Z_DIMS = {128: 120, 256: 140, 512: 128}  # resolution -> latent (z) dimensionality of each TFHub generator
RESOLUTIONS = list(Z_DIMS)
def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False):
"""Loads TFHub weights and saves them to intermediate HDF5 file.
Args:
module_path ([Path-like]): Path to TFHub module.
hdf5_path ([Path-like]): Path to output HDF5 file.
Returns:
[h5py.File]: Loaded hdf5 file containing module weights.
"""
if os.path.exists(hdf5_path) and (not redownload):
print("Loading BigGAN hdf5 file from:", hdf5_path)
return h5py.File(hdf5_path, "r")
print("Loading BigGAN module from:", module_path)
tf.reset_default_graph()
hub.Module(module_path)
print("Loaded BigGAN module from:", module_path)
initializer = tf.global_variables_initializer()
sess = tf.Session()
sess.run(initializer)
print("Saving BigGAN weights to :", hdf5_path)
h5f = h5py.File(hdf5_path, "w")
for var in tf.global_variables():
val = sess.run(var)
h5f.create_dataset(var.name, data=val)
print(f"Saving {var.name} with shape {val.shape}")
h5f.close()
return h5py.File(hdf5_path, "r")
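# Note: the HDF5 datasets are keyed by the original TF variable names, which the
# loader below reconstructs via load_tf_tensor, e.g. (illustrative, with EMA weights):
#   "module/Generator/GBlock/conv0/w/ema_b999900:0"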
class TFHub2Pytorch(object):
TF_ROOT = "module"
NUM_GBLOCK = {128: 5, 256: 6, 512: 7}
w = "w"
b = "b"
u = "u0"
v = "u1"
gamma = "gamma"
beta = "beta"
def __init__(
self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False
):
self.state_dict = state_dict
self.tf_weights = tf_weights
self.resolution = resolution
self.verbose = verbose
if load_ema:
for name in ["w", "b", "gamma", "beta"]:
setattr(self, name, getattr(self, name) + "/ema_b999900")
def load(self):
self.load_generator()
return self.state_dict
def load_generator(self):
GENERATOR_ROOT = os.path.join(self.TF_ROOT, "Generator")
for i in range(self.NUM_GBLOCK[self.resolution]):
name_tf = os.path.join(GENERATOR_ROOT, "GBlock")
name_tf += f"_{i}" if i != 0 else ""
self.load_GBlock(f"GBlock.{i}.", name_tf)
self.load_attention("attention.", os.path.join(GENERATOR_ROOT, "attention"))
self.load_linear("linear", os.path.join(self.TF_ROOT, "linear"), bias=False)
self.load_snlinear("G_linear", os.path.join(GENERATOR_ROOT, "G_Z", "G_linear"))
self.load_colorize("colorize", os.path.join(GENERATOR_ROOT, "conv_2d"))
self.load_ScaledCrossReplicaBNs(
"ScaledCrossReplicaBN", os.path.join(GENERATOR_ROOT, "ScaledCrossReplicaBN")
)
def load_linear(self, name_pth, name_tf, bias=True):
self.state_dict[name_pth + ".weight"] = self.load_tf_tensor(
name_tf, self.w
).permute(1, 0)
if bias:
self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(name_tf, self.b)
def load_snlinear(self, name_pth, name_tf, bias=True):
self.state_dict[name_pth + ".module.weight_u"] = self.load_tf_tensor(
name_tf, self.u
).squeeze()
self.state_dict[name_pth + ".module.weight_v"] = self.load_tf_tensor(
name_tf, self.v
).squeeze()
self.state_dict[name_pth + ".module.weight_bar"] = self.load_tf_tensor(
name_tf, self.w
).permute(1, 0)
if bias:
self.state_dict[name_pth + ".module.bias"] = self.load_tf_tensor(
name_tf, self.b
)
def load_colorize(self, name_pth, name_tf):
self.load_snconv(name_pth, name_tf)
def load_GBlock(self, name_pth, name_tf):
self.load_convs(name_pth, name_tf)
self.load_HyperBNs(name_pth, name_tf)
def load_convs(self, name_pth, name_tf):
self.load_snconv(name_pth + "conv0", os.path.join(name_tf, "conv0"))
self.load_snconv(name_pth + "conv1", os.path.join(name_tf, "conv1"))
self.load_snconv(name_pth + "conv_sc", os.path.join(name_tf, "conv_sc"))
def load_snconv(self, name_pth, name_tf, bias=True):
if self.verbose:
print(f"loading: {name_pth} from {name_tf}")
self.state_dict[name_pth + ".module.weight_u"] = self.load_tf_tensor(
name_tf, self.u
).squeeze()
self.state_dict[name_pth + ".module.weight_v"] = self.load_tf_tensor(
name_tf, self.v
).squeeze()
self.state_dict[name_pth + ".module.weight_bar"] = self.load_tf_tensor(
name_tf, self.w
).permute(3, 2, 0, 1)
if bias:
self.state_dict[name_pth + ".module.bias"] = self.load_tf_tensor(
name_tf, self.b
).squeeze()
def load_conv(self, name_pth, name_tf, bias=True):
self.state_dict[name_pth + ".weight_u"] = self.load_tf_tensor(
name_tf, self.u
).squeeze()
self.state_dict[name_pth + ".weight_v"] = self.load_tf_tensor(
name_tf, self.v
).squeeze()
self.state_dict[name_pth + ".weight_bar"] = self.load_tf_tensor(
name_tf, self.w
).permute(3, 2, 0, 1)
if bias:
self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(name_tf, self.b)
def load_HyperBNs(self, name_pth, name_tf):
self.load_HyperBN(name_pth + "HyperBN", os.path.join(name_tf, "HyperBN"))
self.load_HyperBN(name_pth + "HyperBN_1", os.path.join(name_tf, "HyperBN_1"))
def load_ScaledCrossReplicaBNs(self, name_pth, name_tf):
self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(
name_tf, self.beta
).squeeze()
self.state_dict[name_pth + ".weight"] = self.load_tf_tensor(
name_tf, self.gamma
).squeeze()
self.state_dict[name_pth + ".running_mean"] = self.load_tf_tensor(
name_tf + "bn", "accumulated_mean"
)
self.state_dict[name_pth + ".running_var"] = self.load_tf_tensor(
name_tf + "bn", "accumulated_var"
)
self.state_dict[name_pth + ".num_batches_tracked"] = torch.tensor(
self.tf_weights[os.path.join(name_tf + "bn", "accumulation_counter:0")][()],
dtype=torch.float32,
)
def load_HyperBN(self, name_pth, name_tf):
if self.verbose:
print(f"loading: {name_pth} from {name_tf}")
beta = name_pth + ".beta_embed.module"
gamma = name_pth + ".gamma_embed.module"
self.state_dict[beta + ".weight_u"] = self.load_tf_tensor(
os.path.join(name_tf, "beta"), self.u
).squeeze()
self.state_dict[gamma + ".weight_u"] = self.load_tf_tensor(
os.path.join(name_tf, "gamma"), self.u
).squeeze()
self.state_dict[beta + ".weight_v"] = self.load_tf_tensor(
os.path.join(name_tf, "beta"), self.v
).squeeze()
self.state_dict[gamma + ".weight_v"] = self.load_tf_tensor(
os.path.join(name_tf, "gamma"), self.v
).squeeze()
self.state_dict[beta + ".weight_bar"] = self.load_tf_tensor(
os.path.join(name_tf, "beta"), self.w
).permute(1, 0)
self.state_dict[gamma + ".weight_bar"] = self.load_tf_tensor(
os.path.join(name_tf, "gamma"), self.w
).permute(1, 0)
cr_bn_name = name_tf.replace("HyperBN", "CrossReplicaBN")
self.state_dict[name_pth + ".bn.running_mean"] = self.load_tf_tensor(
cr_bn_name, "accumulated_mean"
)
self.state_dict[name_pth + ".bn.running_var"] = self.load_tf_tensor(
cr_bn_name, "accumulated_var"
)
self.state_dict[name_pth + ".bn.num_batches_tracked"] = torch.tensor(
self.tf_weights[os.path.join(cr_bn_name, "accumulation_counter:0")][()],
dtype=torch.float32,
)
def load_attention(self, name_pth, name_tf):
self.load_snconv(name_pth + "theta", os.path.join(name_tf, "theta"), bias=False)
self.load_snconv(name_pth + "phi", os.path.join(name_tf, "phi"), bias=False)
self.load_snconv(name_pth + "g", os.path.join(name_tf, "g"), bias=False)
self.load_snconv(
name_pth + "o_conv", os.path.join(name_tf, "o_conv"), bias=False
)
self.state_dict[name_pth + "gamma"] = self.load_tf_tensor(name_tf, self.gamma)
def load_tf_tensor(self, prefix, var, device="0"):
# Note: `device` is the TF variable's output-index suffix (":0"), not a GPU id.
name = os.path.join(prefix, var) + f":{device}"
return torch.from_numpy(self.tf_weights[name][:])
# Convert from v1: this function maps the reference (v1) generator's state-dict keys onto the layout expected by BigGAN.Generator in the parent folder.
def convert_from_v1(hub_dict, resolution=128):
weightname_dict = {"weight_u": "u0", "weight_bar": "weight", "bias": "bias"}
convnum_dict = {"conv0": "conv1", "conv1": "conv2", "conv_sc": "conv_sc"}
attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution]
hub2me = {
"linear.weight": "shared.weight", # This is actually the shared weight
# Linear stuff
"G_linear.module.weight_bar": "linear.weight",
"G_linear.module.bias": "linear.bias",
"G_linear.module.weight_u": "linear.u0",
# output layer stuff
"ScaledCrossReplicaBN.weight": "output_layer.0.gain",
"ScaledCrossReplicaBN.bias": "output_layer.0.bias",
"ScaledCrossReplicaBN.running_mean": "output_layer.0.stored_mean",
"ScaledCrossReplicaBN.running_var": "output_layer.0.stored_var",
"colorize.module.weight_bar": "output_layer.2.weight",
"colorize.module.bias": "output_layer.2.bias",
"colorize.module.weight_u": "output_layer.2.u0",
# Attention stuff
"attention.gamma": "blocks.%d.1.gamma" % attention_blocknum,
"attention.theta.module.weight_u": "blocks.%d.1.theta.u0" % attention_blocknum,
"attention.theta.module.weight_bar": "blocks.%d.1.theta.weight"
% attention_blocknum,
"attention.phi.module.weight_u": "blocks.%d.1.phi.u0" % attention_blocknum,
"attention.phi.module.weight_bar": "blocks.%d.1.phi.weight"
% attention_blocknum,
"attention.g.module.weight_u": "blocks.%d.1.g.u0" % attention_blocknum,
"attention.g.module.weight_bar": "blocks.%d.1.g.weight" % attention_blocknum,
"attention.o_conv.module.weight_u": "blocks.%d.1.o.u0" % attention_blocknum,
"attention.o_conv.module.weight_bar": "blocks.%d.1.o.weight"
% attention_blocknum,
}
# Loop over the hub dict and build the hub2me map
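# For example (illustrative), the first block's conv entry
#   "GBlock.0.conv0.module.weight_bar"
# ends up mapped to
#   "blocks.0.0.conv1.weight"
# (conv numbers are shifted by one and the `.module.` spectral-norm wrapper prefix is dropped).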
for name in hub_dict.keys():
if "GBlock" in name:
if "HyperBN" not in name: # it's a conv
out = parse.parse("GBlock.{:d}.{}.module.{}", name)
blocknum, convnum, weightname = out
if weightname not in weightname_dict:
continue  # Skip entries not listed in weightname_dict (e.g. weight_v)
out_name = "blocks.%d.0.%s.%s" % (
blocknum,
convnum_dict[convnum],
weightname_dict[weightname],
) # Increment conv number by 1
else: # hyperbn not conv
BNnum = 2 if "HyperBN_1" in name else 1
if "embed" in name:
out = parse.parse("GBlock.{:d}.{}.module.{}", name)
blocknum, gamma_or_beta, weightname = out
if weightname not in weightname_dict: # Ignore weight_v
continue
out_name = "blocks.%d.0.bn%d.%s.%s" % (
blocknum,
BNnum,
"gain" if "gamma" in gamma_or_beta else "bias",
weightname_dict[weightname],
)
else:
out = parse.parse("GBlock.{:d}.{}.bn.{}", name)
blocknum, dummy, mean_or_var = out
if "num_batches_tracked" in mean_or_var:
continue
out_name = "blocks.%d.0.bn%d.%s" % (
blocknum,
BNnum,
"stored_mean" if "mean" in mean_or_var else "stored_var",
)
hub2me[name] = out_name
# Invert the hub2me map
me2hub = {hub2me[item]: item for item in hub2me}
new_dict = {}
dimz_dict = {128: 20, 256: 20, 512: 16}
for item in me2hub:
# Swap the input-dim ordering on the conditional batchnorm gains/biases to account for this repo's different ordering when concatenating Ys and Zs
if (
("bn" in item and "weight" in item)
and ("gain" in item or "bias" in item)
and ("output_layer" not in item)
):
new_dict[item] = torch.cat(
[
hub_dict[me2hub[item]][:, -128:],
hub_dict[me2hub[item]][:, : dimz_dict[resolution]],
],
1,
)
# Reshape the first linear weight, bias, and u0
elif item == "linear.weight":
new_dict[item] = (
hub_dict[me2hub[item]]
.contiguous()
.view(4, 4, 96 * 16, -1)
.permute(2, 0, 1, 3)
.contiguous()
.view(-1, dimz_dict[resolution])
)
elif item == "linear.bias":
new_dict[item] = (
hub_dict[me2hub[item]]
.view(4, 4, 96 * 16)
.permute(2, 0, 1)
.contiguous()
.view(-1)
)
elif item == "linear.u0":
new_dict[item] = (
hub_dict[me2hub[item]]
.view(4, 4, 96 * 16)
.permute(2, 0, 1)
.contiguous()
.view(1, -1)
)
elif (
me2hub[item] == "linear.weight"
): # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER
# Transpose shared weight so that it's an embedding
new_dict[item] = hub_dict[me2hub[item]].t()
elif "weight_u" in me2hub[item]: # Unsqueeze u0s
new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0)
else:
new_dict[item] = hub_dict[me2hub[item]]
return new_dict
def get_config(resolution):
attn_dict = {128: "64", 256: "128", 512: "64"}
dim_z_dict = {128: 120, 256: 140, 512: 128}
config = {
"G_param": "SN",
"D_param": "SN",
"G_ch": 96,
"D_ch": 96,
"D_wide": True,
"G_shared": True,
"shared_dim": 128,
"dim_z": dim_z_dict[resolution],
"hier": True,
"cross_replica": False,
"mybn": False,
"G_activation": nn.ReLU(inplace=True),
"G_attn": attn_dict[resolution],
"norm_style": "bn",
"G_init": "ortho",
"skip_init": True,
"no_optim": True,
"G_fp16": False,
"G_mixed_precision": False,
"accumulate_stats": False,
"num_standing_accumulations": 16,
"G_eval_mode": True,
"BN_eps": 1e-04,
"SN_eps": 1e-04,
"num_G_SVs": 1,
"num_G_SV_itrs": 1,
"resolution": resolution,
"n_classes": 1000,
}
return config
def convert_biggan(
resolution, weight_dir, redownload=False, no_ema=False, verbose=False
):
module_path = MODULE_PATH_TMPL.format(resolution)
hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution))
pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution))
tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload)
G_temp = getattr(biggan_for_conversion, f"Generator{resolution}")()
state_dict_temp = G_temp.state_dict()
converter = TFHub2Pytorch(
state_dict_temp,
tf_weights,
resolution=resolution,
load_ema=(not no_ema),
verbose=verbose,
)
state_dict_v1 = converter.load()
state_dict = convert_from_v1(state_dict_v1, resolution)
# Get the config, build the model
config = get_config(resolution)
G = BigGAN.Generator(**config)
G.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries
torch.save(state_dict, pth_path)
return G
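# Example (sketch, path is illustrative): reloading converted weights later without
# re-running the TFHub dump, using the same config as at conversion time:
#   config = get_config(128)
#   G = BigGAN.Generator(**config)
#   state_dict = torch.load("pretrained_weights/biggan-128.pth")
#   G.load_state_dict(state_dict, strict=False)  # the converted dict has no sv0 buffers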
def generate_sample(G, z_dim, batch_size, filename, parallel=False):
G.eval()
G.to(DEVICE)
with torch.no_grad():
z = torch.randn(batch_size, G.dim_z).to(DEVICE)
y = torch.randint(
low=0,
high=1000,
size=(batch_size,),
device=DEVICE,
dtype=torch.int64,
requires_grad=False,
)
if parallel:
images = nn.parallel.data_parallel(G, (z, G.shared(y)))
else:
images = G(z, G.shared(y))
save_image(images, filename, scale_each=True, normalize=True)
def parse_args():
usage = "Parser for conversion script."
parser = argparse.ArgumentParser(description=usage)
parser.add_argument(
"--resolution",
"-r",
type=int,
default=None,
choices=[128, 256, 512],
help="Resolution of TFHub module to convert. Converts all resolutions if None.",
)
parser.add_argument(
"--redownload",
action="store_true",
default=False,
help="Redownload weights and overwrite current hdf5 file, if present.",
)
parser.add_argument("--weights_dir", type=str, default="pretrained_weights")
parser.add_argument("--samples_dir", type=str, default="pretrained_samples")
parser.add_argument(
"--no_ema", action="store_true", default=False, help="Do not load ema weights."
)
parser.add_argument(
"--verbose", action="store_true", default=False, help="Additionally logging."
)
parser.add_argument(
"--generate_samples",
action="store_true",
default=False,
help="Generate test sample with pretrained model.",
)
parser.add_argument(
"--batch_size", type=int, default=64, help="Batch size used for test sample."
)
parser.add_argument(
"--parallel", action="store_true", default=False, help="Parallelize G?"
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.weights_dir, exist_ok=True)
os.makedirs(args.samples_dir, exist_ok=True)
if args.resolution is not None:
G = convert_biggan(
args.resolution,
args.weights_dir,
redownload=args.redownload,
no_ema=args.no_ema,
verbose=args.verbose,
)
if args.generate_samples:
filename = os.path.join(
args.samples_dir, f"biggan{args.resolution}_samples.jpg"
)
print("Generating samples...")
generate_sample(
G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel
)
else:
for res in RESOLUTIONS:
G = convert_biggan(
res,
args.weights_dir,
redownload=args.redownload,
no_ema=args.no_ema,
verbose=args.verbose,
)
if args.generate_samples:
filename = os.path.join(args.samples_dir, f"biggan{res}_samples.jpg")
print("Generating samples...")
generate_sample(
G, Z_DIMS[res], args.batch_size, filename, args.parallel
)