|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import click |
|
import pickle |
|
import re |
|
import copy |
|
import numpy as np |
|
import torch |
|
import sys |
|
|
|
sys.path.insert(1, os.path.join(sys.path[0], "../stylegan2_ada_pytorch")) |
|
import dnnlib |
|
from torch_utils import misc |
|
|
|
|
|
|
|
|
|
def load_network_pkl(f, force_fp16=False):
    """Load a network pickle and return a dict with G, D, and G_ema modules.

    Accepts both native PyTorch pickles (a dict of modules) and legacy
    TensorFlow pickles (a (G, D, Gs) tuple), converting the latter on the fly.

    Args:
        f: Binary file-like object to unpickle from.
        force_fp16: If True, rebuild each network with FP16 layers enabled
            (num_fp16_res=4, conv_clamp=256) and copy the weights over.

    Returns:
        dict with keys "G", "D", "G_ema" (torch.nn.Module each, for TF input)
        plus "training_set_kwargs" and "augment_pipe" (possibly None).
    """
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle: a 3-tuple of network stubs (G, D, Gs).
    came_from_tf = (
        isinstance(data, tuple)
        and len(data) == 3
        and all(isinstance(item, _TFNetworkStub) for item in data)
    )
    if came_from_tf:
        tf_G, tf_D, tf_Gs = data
        data = dict(
            G=convert_tf_generator(tf_G),
            D=convert_tf_discriminator(tf_D),
            G_ema=convert_tf_generator(tf_Gs),
        )

    # Add missing optional fields so downstream code can rely on the keys.
    for optional_key in ("training_set_kwargs", "augment_pipe"):
        if optional_key not in data:
            data[optional_key] = None

    # Validate contents.
    assert isinstance(data["G_ema"], torch.nn.Module)

    # Optionally rebuild each network with FP16 enabled and transplant weights.
    if force_fp16:
        for key in ["G", "D", "G_ema"]:
            old_net = data[key]
            new_kwargs = copy.deepcopy(old_net.init_kwargs)
            if key.startswith("G"):
                new_kwargs.synthesis_kwargs = dnnlib.EasyDict(
                    new_kwargs.get("synthesis_kwargs", {})
                )
                new_kwargs.synthesis_kwargs.num_fp16_res = 4
                new_kwargs.synthesis_kwargs.conv_clamp = 256
            if key.startswith("D"):
                new_kwargs.num_fp16_res = 4
                new_kwargs.conv_clamp = 256
            # Only rebuild when the FP16 settings actually changed something.
            if new_kwargs != old_net.init_kwargs:
                new_net = type(old_net)(**new_kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old_net, new_net, require_all=True)
                data[key] = new_net
    return data
|
|
|
|
|
|
|
|
|
|
|
class _TFNetworkStub(dnnlib.EasyDict):
    """Attribute-dict stand-in for `dnnlib.tflib.network.Network` objects
    reconstructed during unpickling (see `_LegacyUnpickler.find_class`)."""
|
|
|
|
|
class _LegacyUnpickler(pickle.Unpickler):
    """Unpickler that substitutes a stub for the TF Network class so legacy
    pickles can be loaded without TensorFlow installed."""

    def find_class(self, module, name):
        # Intercept only the TF network class; everything else resolves normally.
        if (module, name) == ("dnnlib.tflib.network", "Network"):
            return _TFNetworkStub
        return super().find_class(module, name)
|
|
|
|
|
|
|
|
|
|
|
def _collect_tf_params(tf_net):
    """Flatten a TF network stub into a {"component/var_name": value} dict.

    Recurses through nested components, joining names with "/".
    """
    params = dict()

    def walk(prefix, net):
        # Variables at this level, then recurse into each sub-component.
        for var_name, var_value in net.variables:
            params[prefix + var_name] = var_value
        for comp_name, comp in net.components.items():
            walk(prefix + comp_name + "/", comp)

    walk("", tf_net)
    return params
|
|
|
|
|
|
|
|
|
|
|
def _populate_module_params(module, *patterns):
    """Copy values into a module's params/buffers by regex-matched name.

    `patterns` is a flat sequence of (regex, value_fn) pairs. Each param or
    buffer name must fully match some regex; the paired callable receives the
    regex groups and returns the value to copy in. A value_fn of None means
    "recognized, but leave untouched" (e.g. resample filters).
    """
    pattern_pairs = list(zip(patterns[0::2], patterns[1::2]))
    for param_name, param_tensor in misc.named_params_and_buffers(module):
        matched = False
        new_value = None
        for regex, value_fn in pattern_pairs:
            m = re.fullmatch(regex, param_name)
            if m is None:
                continue
            matched = True
            if value_fn is not None:
                new_value = value_fn(*m.groups())
            break
        try:
            # Every param must be accounted for by some pattern.
            assert matched
            if new_value is not None:
                param_tensor.copy_(torch.from_numpy(np.array(new_value)))
        except:
            # Report which parameter failed (name + shape), then re-raise.
            print(param_name, list(param_tensor.shape))
            raise
|
|
|
|
|
|
|
|
|
|
|
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow StyleGAN2(-ADA) generator stub into a
    PyTorch `training.networks.Generator` with the same weights.

    Args:
        tf_G: `_TFNetworkStub` holding `version`, `static_kwargs`, `variables`,
            and `components` from the TF pickle.

    Returns:
        torch.nn.Module generator in eval mode with gradients disabled.

    Raises:
        ValueError: if the pickle version is too old or contains an
            unrecognized static kwarg.
    """
    if tf_G.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs, tracking which TF names we consumed.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        # Mark the TF kwarg as recognized; map an explicit None to `none`.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Translate TF kwarg names to the PyTorch Generator constructor kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg("latent_size", 512),
        c_dim=kwarg("label_size", 0),
        w_dim=kwarg("dlatent_size", 512),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 8),
            embed_features=kwarg("label_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("mapping_nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.01),
            w_avg_beta=kwarg("w_avg_beta", 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg("fmap_base", 16384) * 2,
            channel_max=kwarg("fmap_max", 512),
            num_fp16_res=kwarg("num_fp16_res", 0),
            conv_clamp=kwarg("conv_clamp", None),
            architecture=kwarg("architecture", "skip"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            use_noise=kwarg("use_noise", True),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Sampling-only TF options are recognized but intentionally ignored.
    kwarg("truncation_psi")
    kwarg("truncation_cutoff")
    kwarg("style_mixing_prob")
    kwarg("structure")
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params; remap progressive-growing ToRGB_lod<N> names to the
    # per-resolution naming scheme and force the "orig" architecture.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"ToRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/ToRGB/{match.group(2)}"] = value
            # BUGFIX: was `kwargs.synthesis.kwargs.architecture`, which raises
            # AttributeError (no `synthesis` key exists). Mirrors the working
            # `kwargs.architecture = "orig"` in convert_tf_discriminator.
            kwargs.synthesis_kwargs.architecture = "orig"

    # Convert params: build the PyTorch generator and copy weights in by name.
    from training import networks

    G = networks.Generator(**kwargs).eval().requires_grad_(False)

    # (regex, value_fn) pairs mapping PyTorch param names to TF variables.
    # Conv weights: TF HWIO -> PyTorch OIHW; up-conv kernels are also flipped.
    # Affine (modulation) biases get +1 because PyTorch folds the TF "+1" in.
    _populate_module_params(
        G,
        r"mapping\.w_avg",
        lambda: tf_params[f"dlatent_avg"],
        r"mapping\.embed\.weight",
        lambda: tf_params[f"mapping/LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params[f"mapping/LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"mapping/Dense{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"mapping/Dense{i}/bias"],
        r"synthesis\.b4\.const",
        lambda: tf_params[f"synthesis/4x4/Const/const"][0],
        r"synthesis\.b4\.conv1\.weight",
        lambda: tf_params[f"synthesis/4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b4\.conv1\.bias",
        lambda: tf_params[f"synthesis/4x4/Conv/bias"],
        r"synthesis\.b4\.conv1\.noise_const",
        lambda: tf_params[f"synthesis/noise0"][0, 0],
        r"synthesis\.b4\.conv1\.noise_strength",
        lambda: tf_params[f"synthesis/4x4/Conv/noise_strength"],
        r"synthesis\.b4\.conv1\.affine\.weight",
        lambda: tf_params[f"synthesis/4x4/Conv/mod_weight"].transpose(),
        r"synthesis\.b4\.conv1\.affine\.bias",
        lambda: tf_params[f"synthesis/4x4/Conv/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv0\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/weight"][::-1, ::-1].transpose(
            3, 2, 0, 1
        ),
        r"synthesis\.b(\d+)\.conv0\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/bias"],
        r"synthesis\.b(\d+)\.conv0\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-5}"][0, 0],
        r"synthesis\.b(\d+)\.conv0\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/noise_strength"],
        r"synthesis\.b(\d+)\.conv0\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv0\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv1\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv1\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/bias"],
        r"synthesis\.b(\d+)\.conv1\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-4}"][0, 0],
        r"synthesis\.b(\d+)\.conv1\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/noise_strength"],
        r"synthesis\.b(\d+)\.conv1\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv1\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.torgb\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.torgb\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/bias"],
        r"synthesis\.b(\d+)\.torgb\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.torgb\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.skip\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Skip/weight"][::-1, ::-1].transpose(
            3, 2, 0, 1
        ),
        r".*\.resample_filter",
        None,
    )
    return G
|
|
|
|
|
|
|
|
|
|
|
def convert_tf_discriminator(tf_D):
    """Convert a legacy TensorFlow StyleGAN2(-ADA) discriminator stub into a
    PyTorch `training.networks.Discriminator` with the same weights.

    Args:
        tf_D: `_TFNetworkStub` holding `version`, `static_kwargs`, `variables`,
            and `components` from the TF pickle.

    Returns:
        torch.nn.Module discriminator in eval mode with gradients disabled.

    Raises:
        ValueError: if the pickle version is too old or contains an
            unrecognized static kwarg.
    """
    if tf_D.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs, tracking which TF names we consumed.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None):
        # Mark the TF kwarg as recognized and fetch it (or the default).
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Translate TF kwarg names to the PyTorch Discriminator constructor kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim=kwarg("label_size", 0),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        architecture=kwarg("architecture", "resnet"),
        channel_base=kwarg("fmap_base", 16384) * 2,
        channel_max=kwarg("fmap_max", 512),
        num_fp16_res=kwarg("num_fp16_res", 0),
        conv_clamp=kwarg("conv_clamp", None),
        cmap_dim=kwarg("mapping_fmaps", None),
        block_kwargs=dnnlib.EasyDict(
            activation=kwarg("nonlinearity", "lrelu"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            freeze_layers=kwarg("freeze_layers", 0),
        ),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 0),
            embed_features=kwarg("mapping_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.1),
        ),
        epilogue_kwargs=dnnlib.EasyDict(
            mbstd_group_size=kwarg("mbstd_group_size", None),
            mbstd_num_channels=kwarg("mbstd_num_features", 1),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # "structure" is recognized but intentionally ignored; any other leftover
    # TF kwarg is an error.
    kwarg("structure")
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params; remap progressive-growing FromRGB_lod<N> names to the
    # per-resolution naming scheme and force the "orig" architecture.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"FromRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/FromRGB/{match.group(2)}"] = value
            kwargs.architecture = "orig"

    # Convert params: build the PyTorch discriminator and copy weights by name.
    # (Imported here because training.networks is only needed for conversion.)
    from training import networks

    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)

    # (regex, value_fn) pairs mapping PyTorch param names to TF variables.
    # Conv weights are transposed from TF HWIO to PyTorch OIHW layout; dense
    # weights are plain-transposed. Conv1 in each block is the down-sampling
    # variant ("Conv1_down") on the TF side.
    _populate_module_params(
        D,
        r"b(\d+)\.fromrgb\.weight",
        lambda r: tf_params[f"{r}x{r}/FromRGB/weight"].transpose(3, 2, 0, 1),
        r"b(\d+)\.fromrgb\.bias",
        lambda r: tf_params[f"{r}x{r}/FromRGB/bias"],
        r"b(\d+)\.conv(\d+)\.weight",
        lambda r, i: tf_params[
            f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'
        ].transpose(3, 2, 0, 1),
        r"b(\d+)\.conv(\d+)\.bias",
        lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r"b(\d+)\.skip\.weight",
        lambda r: tf_params[f"{r}x{r}/Skip/weight"].transpose(3, 2, 0, 1),
        r"mapping\.embed\.weight",
        lambda: tf_params[f"LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params[f"LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"Mapping{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"Mapping{i}/bias"],
        r"b4\.conv\.weight",
        lambda: tf_params[f"4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"b4\.conv\.bias",
        lambda: tf_params[f"4x4/Conv/bias"],
        r"b4\.fc\.weight",
        lambda: tf_params[f"4x4/Dense0/weight"].transpose(),
        r"b4\.fc\.bias",
        lambda: tf_params[f"4x4/Dense0/bias"],
        r"b4\.out\.weight",
        lambda: tf_params[f"Output/weight"].transpose(),
        r"b4\.out\.bias",
        lambda: tf_params[f"Output/bias"],
        r".*\.resample_filter",
        None,
    )
    return D
|
|
|
|
|
|
|
|
|
|
|
@click.command()
@click.option("--source", help="Input pickle", required=True, metavar="PATH")
@click.option("--dest", help="Output pickle", required=True, metavar="PATH")
@click.option(
    "--force-fp16",
    help="Force the networks to use FP16",
    type=bool,
    default=False,
    metavar="BOOL",
    show_default=True,
)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    # Load (possibly from a URL), convert if needed, then re-pickle natively.
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as src_file:
        data = load_network_pkl(src_file, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, "wb") as dst_file:
        pickle.dump(data, dst_file)
    print("Done.")
|
|
|
|
|
|
|
|
|
# Script entry point: run the click CLI when executed directly.
if __name__ == "__main__":
    convert_network_pickle()
|
|
|
|
|
|