diff --git a/README.md b/README.md index 65d9a87e6b0362358d3fd43b64aa80e1ff4463ef..c4b70abea48f0b4457025136df12c0facc92fd5c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ --- title: PTI -emoji: 🦀 +emoji: 🦀 colorFrom: gray colorTo: pink sdk: gradio @@ -34,4 +34,4 @@ Path to your main application file (which contains either `gradio` or `streamlit Path is relative to the root of the repository. `pinned`: _boolean_ -Whether the Space stays on top of your list. +Whether the Space stays on top of your list. \ No newline at end of file diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/__pycache__/__init__.cpython-36.pyc b/configs/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec015a30a95656948c6de2bea661f6842de30021 Binary files /dev/null and b/configs/__pycache__/__init__.cpython-36.pyc differ diff --git a/configs/__pycache__/__init__.cpython-39.pyc b/configs/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35c8adfc39e66c83980c1948ce266dc4f50c990c Binary files /dev/null and b/configs/__pycache__/__init__.cpython-39.pyc differ diff --git a/configs/__pycache__/global_config.cpython-36.pyc b/configs/__pycache__/global_config.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2055726f7b344c2477fb5b50040e97868cf627ac Binary files /dev/null and b/configs/__pycache__/global_config.cpython-36.pyc differ diff --git a/configs/__pycache__/global_config.cpython-39.pyc b/configs/__pycache__/global_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3900bf4b4950316b439221e213310926f8c77254 Binary files /dev/null and b/configs/__pycache__/global_config.cpython-39.pyc differ diff --git a/configs/__pycache__/hyperparameters.cpython-36.pyc b/configs/__pycache__/hyperparameters.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..894d516944dda2241a12371058b024d6171826b1 Binary files /dev/null and b/configs/__pycache__/hyperparameters.cpython-36.pyc differ diff --git a/configs/__pycache__/hyperparameters.cpython-39.pyc b/configs/__pycache__/hyperparameters.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..141d861e309d73ffe56cd3ab4d90c641a834ac54 Binary files /dev/null and b/configs/__pycache__/hyperparameters.cpython-39.pyc differ diff --git a/configs/__pycache__/paths_config.cpython-36.pyc b/configs/__pycache__/paths_config.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0afd84f05e1ab89a9cd91befcc4307ad711b2ce Binary files /dev/null and b/configs/__pycache__/paths_config.cpython-36.pyc differ diff --git a/configs/__pycache__/paths_config.cpython-39.pyc b/configs/__pycache__/paths_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cb1b6684474a52e43c6a3c57e95ab51a3c45c68 Binary files /dev/null and b/configs/__pycache__/paths_config.cpython-39.pyc differ diff --git a/configs/evaluation_config.py b/configs/evaluation_config.py new file mode 100644 index 0000000000000000000000000000000000000000..16b621d4a47df9e25828c4235cf1692899d14d50 --- /dev/null +++ b/configs/evaluation_config.py @@ -0,0 +1 @@ +evaluated_methods = ['e4e', 'SG2', 'SG2Plus'] \ No newline at end of file diff --git a/configs/global_config.py b/configs/global_config.py new file mode 100644 
index 0000000000000000000000000000000000000000..f95c51c56d8179f45095fd6e47554bbfaaf2ee55 --- /dev/null +++ b/configs/global_config.py @@ -0,0 +1,12 @@ +## Device +cuda_visible_devices = '0' +device = 'cuda:0' + +## Logs +training_step = 1 +image_rec_result_log_snapshot = 100 +pivotal_training_steps = 0 +model_snapshot_interval = 400 + +## Run name to be updated during PTI +run_name = '' diff --git a/configs/hyperparameters.py b/configs/hyperparameters.py new file mode 100644 index 0000000000000000000000000000000000000000..ab50db62b29ef29eeb128663e4f7ff737df81ca3 --- /dev/null +++ b/configs/hyperparameters.py @@ -0,0 +1,28 @@ +## Architecture +lpips_type = 'alex' +first_inv_type = 'w' +optim_type = 'adam' + +## Locality regularization +latent_ball_num_of_samples = 1 +locality_regularization_interval = 1 +use_locality_regularization = False +regulizer_l2_lambda = 0.1 +regulizer_lpips_lambda = 0.1 +regulizer_alpha = 30 + +## Loss +pt_l2_lambda = 1 +pt_lpips_lambda = 1 + +## Steps +LPIPS_value_threshold = 0.06 +max_pti_steps = 350 +first_inv_steps = 450 +max_images_to_invert = 300 + +## Optimization +pti_learning_rate = 3e-4 +first_inv_lr = 5e-3 +train_batch_size = 1 +use_last_w_pivots = False diff --git a/configs/paths_config.py b/configs/paths_config.py new file mode 100644 index 0000000000000000000000000000000000000000..508a0cb459c50b8d74b2399eea25a543d99e8eb4 --- /dev/null +++ b/configs/paths_config.py @@ -0,0 +1,31 @@ +## Pretrained models paths +e4e = './pretrained_models/e4e_ffhq_encode.pt' +stylegan2_ada_ffhq = '/home/sayantan/PTI/pretrained_models/ffhq.pkl' +style_clip_pretrained_mappers = '' +ir_se50 = './pretrained_models/model_ir_se50.pth' +dlib = './pretrained_models/align.dat' + +## Dirs for output files +checkpoints_dir = './checkpoints' +embedding_base_dir = './embeddings' +styleclip_output_dir = './StyleCLIP_results' +experiments_output_dir = './output' + +## Input info +### Input dir, where the images reside +input_data_path = '' +### Inversion identifier, used to keep track of the inversion results. 
Both the latent code and the generator +input_data_id = 'rocky' + +## Keywords +pti_results_keyword = 'PTI' +e4e_results_keyword = 'e4e' +sg2_results_keyword = 'SG2' +sg2_plus_results_keyword = 'SG2_plus' +multi_id_model_type = 'multi_id' + +## Edit directions +interfacegan_age = 'editings/interfacegan_directions/age.pt' +interfacegan_smile = 'editings/interfacegan_directions/smile.pt' +interfacegan_rotation = 'editings/interfacegan_directions/rotation.pt' +ffhq_pca = 'editings/ganspace_pca/ffhq_pca.pt' diff --git a/criteria/__init__.py b/criteria/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/criteria/__pycache__/__init__.cpython-36.pyc b/criteria/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa37579b15ea2b20a2f3bc499dc49b0a1bff2115 Binary files /dev/null and b/criteria/__pycache__/__init__.cpython-36.pyc differ diff --git a/criteria/__pycache__/__init__.cpython-39.pyc b/criteria/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db20edd4f1709d48fc8b6802798623b1c7554f6e Binary files /dev/null and b/criteria/__pycache__/__init__.cpython-39.pyc differ diff --git a/criteria/__pycache__/l2_loss.cpython-36.pyc b/criteria/__pycache__/l2_loss.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3405ac62e5c18a9007f976fc2dbe12932ea3c57a Binary files /dev/null and b/criteria/__pycache__/l2_loss.cpython-36.pyc differ diff --git a/criteria/__pycache__/l2_loss.cpython-39.pyc b/criteria/__pycache__/l2_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9756932ac28382dee8fddf21c2d0ebd6647db099 Binary files /dev/null and b/criteria/__pycache__/l2_loss.cpython-39.pyc differ diff --git a/criteria/__pycache__/localitly_regulizer.cpython-36.pyc b/criteria/__pycache__/localitly_regulizer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38c6aeb48a4ddf2160afeecade53f9dea353063d Binary files /dev/null and b/criteria/__pycache__/localitly_regulizer.cpython-36.pyc differ diff --git a/criteria/__pycache__/localitly_regulizer.cpython-39.pyc b/criteria/__pycache__/localitly_regulizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92aa18adfe0f70d7f009a03bdc0c7ce5385d4c12 Binary files /dev/null and b/criteria/__pycache__/localitly_regulizer.cpython-39.pyc differ diff --git a/criteria/l2_loss.py b/criteria/l2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ac2753b02dfa9d21ccf03fa3b87b9d6fc3f01d --- /dev/null +++ b/criteria/l2_loss.py @@ -0,0 +1,8 @@ +import torch + +l2_criterion = torch.nn.MSELoss(reduction='mean') + + +def l2_loss(real_images, generated_images): + loss = l2_criterion(real_images, generated_images) + return loss diff --git a/criteria/localitly_regulizer.py b/criteria/localitly_regulizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9925ab97cf80119e8099e0d89de42e33d58faf0e --- /dev/null +++ b/criteria/localitly_regulizer.py @@ -0,0 +1,59 @@ +import torch +import numpy as np +import wandb +from criteria import l2_loss +from configs import hyperparameters +from configs import global_config + + +class Space_Regulizer: + def __init__(self, original_G, lpips_net): + self.original_G = original_G + self.morphing_regulizer_alpha = hyperparameters.regulizer_alpha + self.lpips_loss = lpips_net + + def 
get_morphed_w_code(self, new_w_code, fixed_w): + interpolation_direction = new_w_code - fixed_w + interpolation_direction_norm = torch.norm(interpolation_direction, p=2) + direction_to_move = hyperparameters.regulizer_alpha * interpolation_direction / interpolation_direction_norm + result_w = fixed_w + direction_to_move + self.morphing_regulizer_alpha * fixed_w + (1 - self.morphing_regulizer_alpha) * new_w_code + + return result_w + + def get_image_from_ws(self, w_codes, G): + return torch.cat([G.synthesis(w_code, noise_mode='none', force_fp32=True) for w_code in w_codes]) + + def ball_holder_loss_lazy(self, new_G, num_of_sampled_latents, w_batch, use_wandb=False): + loss = 0.0 + + z_samples = np.random.randn(num_of_sampled_latents, self.original_G.z_dim) + w_samples = self.original_G.mapping(torch.from_numpy(z_samples).to(global_config.device), None, + truncation_psi=0.5) + territory_indicator_ws = [self.get_morphed_w_code(w_code.unsqueeze(0), w_batch) for w_code in w_samples] + + for w_code in territory_indicator_ws: + new_img = new_G.synthesis(w_code, noise_mode='none', force_fp32=True) + with torch.no_grad(): + old_img = self.original_G.synthesis(w_code, noise_mode='none', force_fp32=True) + + if hyperparameters.regulizer_l2_lambda > 0: + l2_loss_val = l2_loss.l2_loss(old_img, new_img) + if use_wandb: + wandb.log({f'space_regulizer_l2_loss_val': l2_loss_val.detach().cpu()}, + step=global_config.training_step) + loss += l2_loss_val * hyperparameters.regulizer_l2_lambda + + if hyperparameters.regulizer_lpips_lambda > 0: + loss_lpips = self.lpips_loss(old_img, new_img) + loss_lpips = torch.mean(torch.squeeze(loss_lpips)) + if use_wandb: + wandb.log({f'space_regulizer_lpips_loss_val': loss_lpips.detach().cpu()}, + step=global_config.training_step) + loss += loss_lpips * hyperparameters.regulizer_lpips_lambda + + return loss / len(territory_indicator_ws) + + def space_regulizer_loss(self, new_G, w_batch, use_wandb): + ret_val = self.ball_holder_loss_lazy(new_G, hyperparameters.latent_ball_num_of_samples, w_batch, use_wandb) + return ret_val diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f08cf36f11f9b0fd94c1b7caeadf69b98375b04 --- /dev/null +++ b/dnnlib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +from .util import EasyDict, make_cache_dir_path diff --git a/dnnlib/__pycache__/__init__.cpython-36.pyc b/dnnlib/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a63c5b87101c658fc55fb3705245a4c22f57c7b Binary files /dev/null and b/dnnlib/__pycache__/__init__.cpython-36.pyc differ diff --git a/dnnlib/__pycache__/__init__.cpython-39.pyc b/dnnlib/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ed3fe40efb8da63d3ff5301cb434f878a8e9b2c Binary files /dev/null and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ diff --git a/dnnlib/__pycache__/util.cpython-36.pyc b/dnnlib/__pycache__/util.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..540b5dae3b6669507417156b7e8afcbf288be394 Binary files /dev/null and b/dnnlib/__pycache__/util.cpython-36.pyc differ diff --git a/dnnlib/__pycache__/util.cpython-39.pyc b/dnnlib/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22ec89e4e225d718f6fa68250500fe0b7cc7c5b8 Binary files /dev/null and b/dnnlib/__pycache__/util.cpython-39.pyc differ diff --git a/dnnlib/util.py b/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..76725336d01e75e1c68daa88be47f4fde0bbc63b --- /dev/null +++ b/dnnlib/util.py @@ -0,0 +1,477 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + 
if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? 
+ for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. 
+ Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." 
% url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/edit.py b/edit.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7326b4859fa4c53869976e0288daed5fe67a4a --- /dev/null +++ b/edit.py @@ -0,0 +1,84 @@ +import wandb +import click +import os +import sys +import pickle +import numpy as np +from PIL import Image +import glob +import torch +from configs import paths_config, hyperparameters, global_config +from IPython.display import display +import matplotlib.pyplot as plt +from scripts.latent_editor_wrapper import LatentEditorWrapper + + +image_dir_name = '/home/sayantan/processed_images' +use_multi_id_training = False +global_config.device = 'cuda' +paths_config.e4e = '/home/sayantan/PTI/pretrained_models/e4e_ffhq_encode.pt' +paths_config.input_data_id = image_dir_name +paths_config.input_data_path = f'{image_dir_name}' +paths_config.stylegan2_ada_ffhq = '/home/sayantan/PTI/pretrained_models/ffhq.pkl' +paths_config.checkpoints_dir = '/home/sayantan/PTI/' +paths_config.style_clip_pretrained_mappers = '/home/sayantan/PTI/pretrained_models' +hyperparameters.use_locality_regularization = False +hyperparameters.lpips_type = 'squeeze' + +model_id = "MYJJDFVGATAT" + + + +def display_alongside_source_image(images): + res = np.concatenate([np.array(image) for image in images], axis=1) + return Image.fromarray(res) + +def load_generators(model_id, image_name): + with open(paths_config.stylegan2_ada_ffhq, 'rb') as f: + old_G = pickle.load(f)['G_ema'].cuda() + + with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new: + new_G = torch.load(f_new).cuda() + + return old_G, new_G + +def plot_syn_images(syn_images,text): + for img in syn_images: + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0] + plt.axis('off') + resized_image = Image.fromarray(img,mode='RGB').resize((256,256)) + display(resized_image) + #wandb.log({text: [wandb.Image(resized_image, caption="Label")]}) + del img + del 
resized_image + torch.cuda.empty_cache() + +def syn_images_wandb(img): + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0] + plt.axis('off') + resized_image = Image.fromarray(img,mode='RGB').resize((256,256)) + return resized_image + + +def edit(image_name): + generator_type = paths_config.multi_id_model_type if use_multi_id_training else image_name + old_G, new_G = load_generators(model_id, generator_type) + w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' + + embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' + w_pivot = torch.load(f'{embedding_dir}/0.pt') + + old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True) + new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True) + + latent_editor = LatentEditorWrapper() + latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, [i for i in range(-5,5)]) + + for direction, factor_and_edit in latents_after_edit.items(): + for editkey in factor_and_edit.keys(): + new_image = new_G.synthesis(factor_and_edit[editkey], noise_mode='const', force_fp32 = True) + image_pil = syn_images_wandb(new_image).save(f"/home/sayantan/PTI/{direction}/{editkey}/{image_name}.jpg") + +if __name__ == '__main__': + for image_name in [f.split(".")[0].split("_")[2] for f in sorted(glob.glob("*.pt"))]: + edit(image_name) diff --git a/editings/__pycache__/ganspace.cpython-39.pyc b/editings/__pycache__/ganspace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd0afb556095b4a93e9b3ad442840502cdd43027 Binary files /dev/null and b/editings/__pycache__/ganspace.cpython-39.pyc differ diff --git a/editings/__pycache__/latent_editor.cpython-39.pyc b/editings/__pycache__/latent_editor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d19dedb29da77546f668678ffb4ac4aa9cd513b Binary files /dev/null and b/editings/__pycache__/latent_editor.cpython-39.pyc differ diff --git a/editings/ganspace.py b/editings/ganspace.py new file mode 100644 index 0000000000000000000000000000000000000000..ee1e28c76de89f690e563902def42e3738dc677f --- /dev/null +++ b/editings/ganspace.py @@ -0,0 +1,21 @@ +import torch + + +def edit(latents, pca, edit_directions): + edit_latents = [] + for latent in latents: + for pca_idx, start, end, strength in edit_directions: + delta = get_delta(pca, latent, pca_idx, strength) + delta_padded = torch.zeros(latent.shape).to('cuda') + delta_padded[start:end] += delta.repeat(end - start, 1) + edit_latents.append(latent + delta_padded) + return torch.stack(edit_latents) + + +def get_delta(pca, latent, idx, strength): + w_centered = latent - pca['mean'].to('cuda') + lat_comp = pca['comp'].to('cuda') + lat_std = pca['std'].to('cuda') + w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx] + delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx] + return delta diff --git a/editings/ganspace_pca/ffhq_pca.pt b/editings/ganspace_pca/ffhq_pca.pt new file mode 100644 index 0000000000000000000000000000000000000000..c07f44fc070557d62ad2c8b105486ebf78a1f82d Binary files /dev/null and b/editings/ganspace_pca/ffhq_pca.pt differ diff --git a/editings/interfacegan_directions/age.pt b/editings/interfacegan_directions/age.pt new file mode 100644 index 0000000000000000000000000000000000000000..73b4e6c9848e68d4d033146c20921cd0594f5943 Binary files /dev/null and b/editings/interfacegan_directions/age.pt differ diff --git 
a/editings/interfacegan_directions/rotation.pt b/editings/interfacegan_directions/rotation.pt new file mode 100644 index 0000000000000000000000000000000000000000..919dfda31918ecc39ab44bf2131cca5712c6c37c Binary files /dev/null and b/editings/interfacegan_directions/rotation.pt differ diff --git a/editings/interfacegan_directions/smile.pt b/editings/interfacegan_directions/smile.pt new file mode 100644 index 0000000000000000000000000000000000000000..3c44456cefdeecf940ab21e0c2d3024a3a6d6432 Binary files /dev/null and b/editings/interfacegan_directions/smile.pt differ diff --git a/editings/latent_editor.py b/editings/latent_editor.py new file mode 100644 index 0000000000000000000000000000000000000000..32554e8010c4da27aaded1b0ce938bd37d5e242b --- /dev/null +++ b/editings/latent_editor.py @@ -0,0 +1,23 @@ +import torch + +from configs import paths_config +from editings import ganspace +from utils.data_utils import tensor2im + + +class LatentEditor(object): + + def apply_ganspace(self, latent, ganspace_pca, edit_directions): + edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions) + return edit_latents + + def apply_interfacegan(self, latent, direction, factor=1, factor_range=None): + edit_latents = [] + if factor_range is not None: # Apply a range of editing factors. for example, (-5, 5) + for f in range(*factor_range): + edit_latent = latent + f * direction + edit_latents.append(edit_latent) + edit_latents = torch.cat(edit_latents) + else: + edit_latents = latent + factor * direction + return edit_latents diff --git a/evaluation/experiment_setting_creator.py b/evaluation/experiment_setting_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ad234ba845d84ddd435424a7fe9ed238af3ff6 --- /dev/null +++ b/evaluation/experiment_setting_creator.py @@ -0,0 +1,43 @@ +import glob +import os +from configs import global_config, paths_config, hyperparameters +from scripts.latent_creators.sg2_plus_latent_creator import SG2PlusLatentCreator +from scripts.latent_creators.e4e_latent_creator import E4ELatentCreator +from scripts.run_pti import run_PTI +import pickle +import torch +from utils.models_utils import toogle_grad, load_old_G + + +class ExperimentRunner: + + def __init__(self, run_id=''): + self.images_paths = glob.glob(f'{paths_config.input_data_path}/*') + self.target_paths = glob.glob(f'{paths_config.input_data_path}/*') + self.run_id = run_id + self.sampled_ws = None + + self.old_G = load_old_G() + + toogle_grad(self.old_G, False) + + def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False): + if run_pt: + self.run_id = run_PTI(self.run_id, use_wandb=use_wandb, use_multi_id_training=use_multi_id_training) + if create_other_latents: + sg2_plus_latent_creator = SG2PlusLatentCreator(use_wandb=use_wandb) + sg2_plus_latent_creator.create_latents() + e4e_latent_creator = E4ELatentCreator(use_wandb=use_wandb) + e4e_latent_creator.create_latents() + + torch.cuda.empty_cache() + + return self.run_id + + +if __name__ == '__main__': + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices + + runner = ExperimentRunner() + runner.run_experiment(True, False, False) diff --git a/evaluation/qualitative_edit_comparison.py b/evaluation/qualitative_edit_comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..39ed13264a9df5a257746f02b070c54934eb3117 --- /dev/null +++ b/evaluation/qualitative_edit_comparison.py @@ -0,0 +1,156 @@ +import os +from 
random import choice +from string import ascii_uppercase +from PIL import Image +from tqdm import tqdm +from scripts.latent_editor_wrapper import LatentEditorWrapper +from evaluation.experiment_setting_creator import ExperimentRunner +import torch +from configs import paths_config, hyperparameters, evaluation_config +from utils.log_utils import save_concat_image, save_single_image +from utils.models_utils import load_tuned_G + + +class EditComparison: + + def __init__(self, save_single_images, save_concatenated_images, run_id): + + self.run_id = run_id + self.experiment_creator = ExperimentRunner(run_id) + self.save_single_images = save_single_images + self.save_concatenated_images = save_concatenated_images + self.latent_editor = LatentEditorWrapper() + + def save_reconstruction_images(self, image_latents, new_inv_image_latent, new_G, target_image): + if self.save_concatenated_images: + save_concat_image(self.concat_base_dir, image_latents, new_inv_image_latent, new_G, + self.experiment_creator.old_G, + 'rec', + target_image) + + if self.save_single_images: + save_single_image(self.single_base_dir, new_inv_image_latent, new_G, 'rec') + target_image.save(f'{self.single_base_dir}/Original.jpg') + + def create_output_dirs(self, full_image_name): + output_base_dir_path = f'{paths_config.experiments_output_dir}/{paths_config.input_data_id}/{self.run_id}/{full_image_name}' + os.makedirs(output_base_dir_path, exist_ok=True) + + self.concat_base_dir = f'{output_base_dir_path}/concat_images' + self.single_base_dir = f'{output_base_dir_path}/single_images' + + os.makedirs(self.concat_base_dir, exist_ok=True) + os.makedirs(self.single_base_dir, exist_ok=True) + + def get_image_latent_codes(self, image_name): + image_latents = [] + for method in evaluation_config.evaluated_methods: + if method == 'SG2': + image_latents.append(torch.load( + f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/' + f'{paths_config.pti_results_keyword}/{image_name}/0.pt')) + else: + image_latents.append(torch.load( + f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{method}/{image_name}/0.pt')) + new_inv_image_latent = torch.load( + f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}/{image_name}/0.pt') + + return image_latents, new_inv_image_latent + + def save_interfacegan_edits(self, image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image): + new_w_inv_edits = self.latent_editor.get_single_interface_gan_edits(new_inv_image_latent, + interfacegan_factors) + + inv_edits = [] + for latent in image_latents: + inv_edits.append(self.latent_editor.get_single_interface_gan_edits(latent, interfacegan_factors)) + + for direction, edits in new_w_inv_edits.items(): + for factor, edit_tensor in edits.items(): + if self.save_concatenated_images: + save_concat_image(self.concat_base_dir, [edits[direction][factor] for edits in inv_edits], + new_w_inv_edits[direction][factor], + new_G, + self.experiment_creator.old_G, + f'{direction}_{factor}', target_image) + if self.save_single_images: + save_single_image(self.single_base_dir, new_w_inv_edits[direction][factor], new_G, + f'{direction}_{factor}') + + def save_ganspace_edits(self, image_latents, new_inv_image_latent, factors, new_G, target_image): + new_w_inv_edits = self.latent_editor.get_single_ganspace_edits(new_inv_image_latent, factors) + inv_edits = [] + for latent in image_latents: + inv_edits.append(self.latent_editor.get_single_ganspace_edits(latent, factors)) + + for idx in 
range(len(new_w_inv_edits)): + if self.save_concatenated_images: + save_concat_image(self.concat_base_dir, [edit[idx] for edit in inv_edits], new_w_inv_edits[idx], + new_G, + self.experiment_creator.old_G, + f'ganspace_{idx}', target_image) + if self.save_single_images: + save_single_image(self.single_base_dir, new_w_inv_edits[idx], new_G, + f'ganspace_{idx}') + + def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False): + images_counter = 0 + new_G = None + interfacegan_factors = [val / 2 for val in range(-6, 7) if val != 0] + ganspace_factors = range(-20, 25, 5) + self.experiment_creator.run_experiment(run_pt, create_other_latents, use_multi_id_training, use_wandb) + + if use_multi_id_training: + new_G = load_tuned_G(self.run_id, paths_config.multi_id_model_type) + + for idx, image_path in tqdm(enumerate(self.experiment_creator.images_paths), + total=len(self.experiment_creator.images_paths)): + + if images_counter >= hyperparameters.max_images_to_invert: + break + + image_name = image_path.split('.')[0].split('/')[-1] + target_image = Image.open(self.experiment_creator.target_paths[idx]) + + if not use_multi_id_training: + new_G = load_tuned_G(self.run_id, image_name) + + image_latents, new_inv_image_latent = self.get_image_latent_codes(image_name) + + self.create_output_dirs(image_name) + + self.save_reconstruction_images(image_latents, new_inv_image_latent, new_G, target_image) + + self.save_interfacegan_edits(image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image) + + self.save_ganspace_edits(image_latents, new_inv_image_latent, ganspace_factors, new_G, target_image) + + target_image.close() + torch.cuda.empty_cache() + images_counter += 1 + + +def run_pti_and_full_edit(iid): + evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2'] + edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True, + run_id=f'{paths_config.input_data_id}_pti_full_edit_{iid}') + edit_figure_creator.run_experiment(True, True, use_multi_id_training=False, use_wandb=False) + + +def pti_no_comparison(iid): + evaluation_config.evaluated_methods = [] + edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True, + run_id=f'{paths_config.input_data_id}_pti_no_comparison_{iid}') + edit_figure_creator.run_experiment(True, False, use_multi_id_training=False, use_wandb=False) + + +def edits_for_existed_experiment(run_id): + evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2'] + edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True, + run_id=run_id) + edit_figure_creator.run_experiment(False, True, use_multi_id_training=False, use_wandb=False) + + +if __name__ == '__main__': + iid = ''.join(choice(ascii_uppercase) for i in range(7)) + pti_no_comparison(iid) diff --git a/makedirs.py b/makedirs.py new file mode 100644 index 0000000000000000000000000000000000000000..ef633d28b8eb07cfdf5c1cf6759a3d532cff0424 --- /dev/null +++ b/makedirs.py @@ -0,0 +1,84 @@ +import wandb +import click +import os +import sys +import pickle +import numpy as np +from PIL import Image +import torch +from configs import paths_config, hyperparameters, global_config +from IPython.display import display +import matplotlib.pyplot as plt +from scripts.latent_editor_wrapper import LatentEditorWrapper + + +image_dir_name = '/home/sayantan/processed_images' +use_multi_id_training = False +global_config.device = 'cuda' +paths_config.e4e = 
'/home/sayantan/PTI/pretrained_models/e4e_ffhq_encode.pt' +paths_config.input_data_id = image_dir_name +paths_config.input_data_path = f'{image_dir_name}' +paths_config.stylegan2_ada_ffhq = '/home/sayantan/PTI/pretrained_models/ffhq.pkl' +paths_config.checkpoints_dir = '/home/sayantan/PTI/' +paths_config.style_clip_pretrained_mappers = '/home/sayantan/PTI/pretrained_models' +hyperparameters.use_locality_regularization = False +hyperparameters.lpips_type = 'squeeze' + +model_id = "MYJJDFVGATAT" + + + +def display_alongside_source_image(images): + res = np.concatenate([np.array(image) for image in images], axis=1) + return Image.fromarray(res) + +def load_generators(model_id, image_name): + with open(paths_config.stylegan2_ada_ffhq, 'rb') as f: + old_G = pickle.load(f)['G_ema'].cuda() + + with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new: + new_G = torch.load(f_new).cuda() + + return old_G, new_G + +def plot_syn_images(syn_images,text): + for img in syn_images: + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0] + plt.axis('off') + resized_image = Image.fromarray(img,mode='RGB').resize((256,256)) + display(resized_image) + #wandb.log({text: [wandb.Image(resized_image, caption="Label")]}) + del img + del resized_image + torch.cuda.empty_cache() + +def syn_images_wandb(img): + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0] + plt.axis('off') + resized_image = Image.fromarray(img,mode='RGB').resize((256,256)) + return resized_image + +@click.command() +@click.pass_context +@click.option('--image_name', prompt='image name', help='The name for image') + +def makedir(ctx: click.Context,image_name): + generator_type = paths_config.multi_id_model_type if use_multi_id_training else image_name + old_G, new_G = load_generators(model_id, generator_type) + w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' + + embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' + w_pivot = torch.load(f'{embedding_dir}/0.pt') + + old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True) + new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True) + + latent_editor = LatentEditorWrapper() + latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, [i for i in range(-5,5)]) + + for direction, factor_and_edit in latents_after_edit.items(): + for editkey in factor_and_edit.keys(): + os.makedirs(f"/home/sayantan/PTI/{direction}/{editkey}") + +if __name__ == '__main__': + makedir() diff --git a/models/StyleCLIP/__init__.py b/models/StyleCLIP/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/criteria/__init__.py b/models/StyleCLIP/criteria/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/criteria/clip_loss.py b/models/StyleCLIP/criteria/clip_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..18176ee8eb0d992d69d5b951d7f36e2efa92a37b --- /dev/null +++ b/models/StyleCLIP/criteria/clip_loss.py @@ -0,0 +1,17 @@ + +import torch +import clip + + +class CLIPLoss(torch.nn.Module): + + def __init__(self, opts): + super(CLIPLoss, self).__init__() + self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") + self.upsample = 
torch.nn.Upsample(scale_factor=7) + self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32) + + def forward(self, image, text): + image = self.avg_pool(self.upsample(image)) + similarity = 1 - self.model(image, text)[0] / 100 + return similarity \ No newline at end of file diff --git a/models/StyleCLIP/criteria/id_loss.py b/models/StyleCLIP/criteria/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a828023e115243e48918538d31b91d662cd12d0f --- /dev/null +++ b/models/StyleCLIP/criteria/id_loss.py @@ -0,0 +1,39 @@ +import torch +from torch import nn + +from models.facial_recognition.model_irse import Backbone + + +class IDLoss(nn.Module): + def __init__(self, opts): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(opts.ir_se50_weights)) + self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + self.opts = opts + + def extract_feats(self, x): + if x.shape[2] != 256: + x = self.pool(x) + x = x[:, :, 35:223, 32:220] # Crop interesting region + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + def forward(self, y_hat, y): + n_samples = y.shape[0] + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + loss += 1 - diff_target + count += 1 + + return loss / count, sim_improvement / count diff --git a/models/StyleCLIP/global_directions/GUI.py b/models/StyleCLIP/global_directions/GUI.py new file mode 100644 index 0000000000000000000000000000000000000000..19f7f8cce9305819b22664642799200d9e1cfff0 --- /dev/null +++ b/models/StyleCLIP/global_directions/GUI.py @@ -0,0 +1,103 @@ + + +from tkinter import Tk,Frame ,Label,Button,messagebox,Canvas,Text,Scale +from tkinter import HORIZONTAL + +class View(): + def __init__(self,master): + + self.width=600 + self.height=600 + + + self.root=master + self.root.geometry("600x600") + + self.left_frame=Frame(self.root,width=600) + self.left_frame.pack_propagate(0) + self.left_frame.pack(fill='both', side='left', expand='True') + + self.retrieval_frame=Frame(self.root,bg='snow3') + self.retrieval_frame.pack_propagate(0) + self.retrieval_frame.pack(fill='both', side='right', expand='True') + + self.bg_frame=Frame(self.left_frame,bg='snow3',height=600,width=600) + self.bg_frame.pack_propagate(0) + self.bg_frame.pack(fill='both', side='top', expand='True') + + self.command_frame=Frame(self.left_frame,bg='snow3') + self.command_frame.pack_propagate(0) + self.command_frame.pack(fill='both', side='bottom', expand='True') +# self.command_frame.grid(row=1, column=0,padx=0, pady=0) + + self.bg=Canvas(self.bg_frame,width=self.width,height=self.height, bg='gray') + self.bg.place(relx=0.5, rely=0.5, anchor='center') + + self.mani=Canvas(self.retrieval_frame,width=1024,height=1024, bg='gray') + self.mani.grid(row=0, column=0,padx=0, pady=42) + + self.SetCommand() + + + + + def run(self): + self.root.mainloop() + + def helloCallBack(self): + category=self.set_category.get() + messagebox.showinfo( "Hello Python",category) + + def SetCommand(self): + + tmp = Label(self.command_frame, text="neutral", width=10 ,bg='snow3') + tmp.grid(row=1, column=0,padx=10, pady=10) + + tmp = 
Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3') + tmp.grid(row=1, column=1,padx=10, pady=10) + + self.neutral = Text ( self.command_frame, height=2, width=30) + self.neutral.grid(row=1, column=2,padx=10, pady=10) + + + tmp = Label(self.command_frame, text="target", width=10 ,bg='snow3') + tmp.grid(row=2, column=0,padx=10, pady=10) + + tmp = Label(self.command_frame, text="a photo of a", width=10 ,bg='snow3') + tmp.grid(row=2, column=1,padx=10, pady=10) + + self.target = Text ( self.command_frame, height=2, width=30) + self.target.grid(row=2, column=2,padx=10, pady=10) + + tmp = Label(self.command_frame, text="strength", width=10 ,bg='snow3') + tmp.grid(row=3, column=0,padx=10, pady=10) + + self.alpha = Scale(self.command_frame, from_=-15, to=25, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.01) + self.alpha.grid(row=3, column=2,padx=10, pady=10) + + + tmp = Label(self.command_frame, text="disentangle", width=10 ,bg='snow3') + tmp.grid(row=4, column=0,padx=10, pady=10) + + self.beta = Scale(self.command_frame, from_=0.08, to=0.4, orient=HORIZONTAL,bg='snow3', length=250,resolution=0.001) + self.beta.grid(row=4, column=2,padx=10, pady=10) + + self.reset = Button(self.command_frame, text='Reset') + self.reset.grid(row=5, column=1,padx=10, pady=10) + + + self.set_init = Button(self.command_frame, text='Accept') + self.set_init.grid(row=5, column=2,padx=10, pady=10) + +#%% +if __name__ == "__main__": + master=Tk() + self=View(master) + self.run() + + + + + + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/GenerateImg.py b/models/StyleCLIP/global_directions/GenerateImg.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6dee48f2d6d9ac37c00ee77c7a46c2cc6b25e1 --- /dev/null +++ b/models/StyleCLIP/global_directions/GenerateImg.py @@ -0,0 +1,50 @@ + +import os +import numpy as np +import argparse +from manipulate import Manipulator + +from PIL import Image +#%% + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + + args = parser.parse_args() + dataset_name=args.dataset_name + + if not os.path.isdir('./data/'+dataset_name): + os.system('mkdir ./data/'+dataset_name) + #%% + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + print(M.dataset_name) + #%% + + M.img_index=0 + M.num_images=50 + M.alpha=[0] + M.step=1 + lindex,bname=0,0 + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(bname) + #%% + + for i in range(len(out)): + img=out[i,0] + img=Image.fromarray(img) + img.save('./data/'+dataset_name+'/'+str(i)+'.jpg') + #%% + w=np.load('./npy/'+dataset_name+'/W.npy') + + tmp=w[:M.num_images] + tmp=tmp[:,None,:] + tmp=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1)) + + np.save('./data/'+dataset_name+'/w_plus.npy',tmp) + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/GetCode.py b/models/StyleCLIP/global_directions/GetCode.py new file mode 100644 index 0000000000000000000000000000000000000000..62e64dc8cbc5ad2bb16aef5da8f6d41c26b24170 --- /dev/null +++ b/models/StyleCLIP/global_directions/GetCode.py @@ -0,0 +1,232 @@ + + + +import os +import pickle +import numpy as np +from dnnlib import tflib +import tensorflow as tf + +import argparse + +def LoadModel(dataset_name): + # Initialize TensorFlow. 
+ tflib.init_tf() + model_path='./model/' + model_name=dataset_name+'.pkl' + + tmp=os.path.join(model_path,model_name) + with open(tmp, 'rb') as f: + _, _, Gs = pickle.load(f) + return Gs + +def lerp(a,b,t): + return a + (b - a) * t + +#stylegan-ada +def SelectName(layer_name,suffix): + if suffix==None: + tmp1='add:0' in layer_name + tmp2='shape=(?,' in layer_name + tmp4='G_synthesis_1' in layer_name + tmp= tmp1 and tmp2 and tmp4 + else: + tmp1=('/Conv0_up'+suffix) in layer_name + tmp2=('/Conv1'+suffix) in layer_name + tmp3=('4x4/Conv'+suffix) in layer_name + tmp4='G_synthesis_1' in layer_name + tmp5=('/ToRGB'+suffix) in layer_name + tmp= (tmp1 or tmp2 or tmp3 or tmp5) and tmp4 + return tmp + + +def GetSNames(suffix): + #get style tensor name + with tf.Session() as sess: + op = sess.graph.get_operations() + layers=[m.values() for m in op] + + + select_layers=[] + for layer in layers: + layer_name=str(layer) + if SelectName(layer_name,suffix): + select_layers.append(layer[0]) + return select_layers + +def SelectName2(layer_name): + tmp1='mod_bias' in layer_name + tmp2='mod_weight' in layer_name + tmp3='ToRGB' in layer_name + + tmp= (tmp1 or tmp2) and (not tmp3) + return tmp + +def GetKName(Gs): + + layers=[var for name, var in Gs.components.synthesis.vars.items()] + + select_layers=[] + for layer in layers: + layer_name=str(layer) + if SelectName2(layer_name): + select_layers.append(layer) + return select_layers + +def GetCode(Gs,random_state,num_img,num_once,dataset_name): + rnd = np.random.RandomState(random_state) #5 + + truncation_psi=0.7 + truncation_cutoff=8 + + dlatent_avg=Gs.get_var('dlatent_avg') + + dlatents=np.zeros((num_img,512),dtype='float32') + for i in range(int(num_img/num_once)): + src_latents = rnd.randn(num_once, Gs.input_shape[1]) + src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] + + # Apply truncation trick. 
+ if truncation_psi is not None and truncation_cutoff is not None: + layer_idx = np.arange(src_dlatents.shape[1])[np.newaxis, :, np.newaxis] + ones = np.ones(layer_idx.shape, dtype=np.float32) + coefs = np.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones) + src_dlatents_np=lerp(dlatent_avg, src_dlatents, coefs) + src_dlatents=src_dlatents_np[:,0,:].astype('float32') + dlatents[(i*num_once):((i+1)*num_once),:]=src_dlatents + print('get all z and w') + + tmp='./npy/'+dataset_name+'/W' + np.save(tmp,dlatents) + + +def GetImg(Gs,num_img,num_once,dataset_name,save_name='images'): + print('Generate Image') + tmp='./npy/'+dataset_name+'/W.npy' + dlatents=np.load(tmp) + fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) + + all_images=[] + for i in range(int(num_img/num_once)): + print(i) + images=[] + for k in range(num_once): + tmp=dlatents[i*num_once+k] + tmp=tmp[None,None,:] + tmp=np.tile(tmp,(1,Gs.components.synthesis.input_shape[1],1)) + image2= Gs.components.synthesis.run(tmp, randomize_noise=False, output_transform=fmt) + images.append(image2) + + images=np.concatenate(images) + + all_images.append(images) + + all_images=np.concatenate(all_images) + + tmp='./npy/'+dataset_name+'/'+save_name + np.save(tmp,all_images) + +def GetS(dataset_name,num_img): + print('Generate S') + tmp='./npy/'+dataset_name+'/W.npy' + dlatents=np.load(tmp)[:num_img] + + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + + Gs=LoadModel(dataset_name) + Gs.print_layers() #for ada + select_layers1=GetSNames(suffix=None) #None,'/mul_1:0','/mod_weight/read:0','/MatMul:0' + dlatents=dlatents[:,None,:] + dlatents=np.tile(dlatents,(1,Gs.components.synthesis.input_shape[1],1)) + + all_s = sess.run( + select_layers1, + feed_dict={'G_synthesis_1/dlatents_in:0': dlatents}) + + layer_names=[layer.name for layer in select_layers1] + save_tmp=[layer_names,all_s] + return save_tmp + + + + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. + Can be used as an output transformation for Network.run(). 
+ """ + if nchw_to_nhwc: + images = np.transpose(images, [0, 2, 3, 1]) + + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + + np.clip(images, 0, 255, out=images) + images=images.astype('uint8') + return images + + +def GetCodeMS(dlatents): + m=[] + std=[] + for i in range(len(dlatents)): + tmp= dlatents[i] + tmp_mean=tmp.mean(axis=0) + tmp_std=tmp.std(axis=0) + m.append(tmp_mean) + std.append(tmp_std) + return m,std + + + +#%% +if __name__ == "__main__": + + + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + parser.add_argument('--code_type',choices=['w','s','s_mean_std'],default='w') + + args = parser.parse_args() + random_state=5 + num_img=100_000 + num_once=1_000 + dataset_name=args.dataset_name + + if not os.path.isfile('./model/'+dataset_name+'.pkl'): + url='https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/' + name='stylegan2-'+dataset_name+'-config-f.pkl' + os.system('wget ' +url+name + ' -P ./model/') + os.system('mv ./model/'+name+' ./model/'+dataset_name+'.pkl') + + if not os.path.isdir('./npy/'+dataset_name): + os.system('mkdir ./npy/'+dataset_name) + + if args.code_type=='w': + Gs=LoadModel(dataset_name=dataset_name) + GetCode(Gs,random_state,num_img,num_once,dataset_name) +# GetImg(Gs,num_img=num_img,num_once=num_once,dataset_name=dataset_name,save_name='images_100K') #no need + elif args.code_type=='s': + save_name='S' + save_tmp=GetS(dataset_name,num_img=2_000) + tmp='./npy/'+dataset_name+'/'+save_name + with open(tmp, "wb") as fp: + pickle.dump(save_tmp, fp) + + elif args.code_type=='s_mean_std': + save_tmp=GetS(dataset_name,num_img=num_img) + dlatents=save_tmp[1] + m,std=GetCodeMS(dlatents) + save_tmp=[m,std] + save_name='S_mean_std' + tmp='./npy/'+dataset_name+'/'+save_name + with open(tmp, "wb") as fp: + pickle.dump(save_tmp, fp) + + + + + diff --git a/models/StyleCLIP/global_directions/GetGUIData.py b/models/StyleCLIP/global_directions/GetGUIData.py new file mode 100644 index 0000000000000000000000000000000000000000..52f77213ab88edf8b33eff166b89b9e56ac4ff01 --- /dev/null +++ b/models/StyleCLIP/global_directions/GetGUIData.py @@ -0,0 +1,67 @@ + +import os +import numpy as np +import argparse +from manipulate import Manipulator +import torch +from PIL import Image +#%% + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + + parser.add_argument('--real', action='store_true') + + args = parser.parse_args() + dataset_name=args.dataset_name + + if not os.path.isdir('./data/'+dataset_name): + os.system('mkdir ./data/'+dataset_name) + #%% + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + print(M.dataset_name) + #%% + #remove all .jpg + names=os.listdir('./data/'+dataset_name+'/') + for name in names: + if '.jpg' in name: + os.system('rm ./data/'+dataset_name+'/'+name) + + + #%% + if args.real: + latents=torch.load('./data/'+dataset_name+'/latents.pt') + w_plus=latents.cpu().detach().numpy() + else: + w=np.load('./npy/'+dataset_name+'/W.npy') + tmp=w[:50] #only use 50 images + tmp=tmp[:,None,:] + w_plus=np.tile(tmp,(1,M.Gs.components.synthesis.input_shape[1],1)) + np.save('./data/'+dataset_name+'/w_plus.npy',w_plus) + + #%% + tmp=M.W2S(w_plus) + M.dlatents=tmp + + M.img_index=0 + M.num_images=len(w_plus) + 
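As a quick sanity check of convert_images_to_uint8 above: with drange=[-1, 1] the scale is 255/2 = 127.5, so -1 maps to 0.5 (clipped and cast to 0), 0 maps to 128, and +1 maps to 255.5 (clipped to 255). A standalone NumPy sketch of that mapping (toy values only):

import numpy as np

imgs = np.array([[-1.0, 0.0, 1.0]], dtype=np.float32)    # toy pixel values in [-1, 1]
scale = 255 / (1 - (-1))                                  # 127.5
out = imgs * scale + (0.5 - (-1) * scale)                 # [[0.5, 128.0, 255.5]]
out = np.clip(out, 0, 255).astype('uint8')                # [[0, 128, 255]]
print(out)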
M.alpha=[0] + M.step=1 + lindex,bname=0,0 + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(bname) + #%% + + for i in range(len(out)): + img=out[i,0] + img=Image.fromarray(img) + img.save('./data/'+dataset_name+'/'+str(i)+'.jpg') + #%% + + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/Inference.py b/models/StyleCLIP/global_directions/Inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a292787c88a370b15b4f0d633ac27bb5bed2b510 --- /dev/null +++ b/models/StyleCLIP/global_directions/Inference.py @@ -0,0 +1,106 @@ + + +from manipulate import Manipulator +import tensorflow as tf +import numpy as np +import torch +import clip +from MapTS import GetBoundary,GetDt + +class StyleCLIP(): + + def __init__(self,dataset_name='ffhq'): + print('load clip') + device = "cuda" if torch.cuda.is_available() else "cpu" + self.model, preprocess = clip.load("ViT-B/32", device=device) + self.LoadData(dataset_name) + + def LoadData(self, dataset_name): + tf.keras.backend.clear_session() + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + fs3=np.load('./npy/'+dataset_name+'/fs3.npy') + + self.M=M + self.fs3=fs3 + + w_plus=np.load('./data/'+dataset_name+'/w_plus.npy') + self.M.dlatents=M.W2S(w_plus) + + if dataset_name=='ffhq': + self.c_threshold=20 + else: + self.c_threshold=100 + self.SetInitP() + + def SetInitP(self): + self.M.alpha=[3] + self.M.num_images=1 + + self.target='' + self.neutral='' + self.GetDt2() + img_index=0 + self.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.M.dlatents] + + + def GetDt2(self): + classnames=[self.target,self.neutral] + dt=GetDt(classnames,self.model) + + self.dt=dt + num_cs=[] + betas=np.arange(0.1,0.3,0.01) + for i in range(len(betas)): + boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=betas[i]) + print(betas[i]) + num_cs.append(num_c) + + num_cs=np.array(num_cs) + select=num_cs>self.c_threshold + + if sum(select)==0: + self.beta=0.1 + else: + self.beta=betas[select][-1] + + + def GetCode(self): + boundary_tmp2,num_c=GetBoundary(self.fs3,self.dt,self.M,threshold=self.beta) + codes=self.M.MSCode(self.M.dlatent_tmp,boundary_tmp2) + return codes + + def GetImg(self): + + codes=self.GetCode() + out=self.M.GenerateImg(codes) + img=out[0,0] + return img + + + + +#%% +if __name__ == "__main__": + style_clip=StyleCLIP() + self=style_clip + + + + + + + + + + + + + + + + + + + + diff --git a/models/StyleCLIP/global_directions/MapTS.py b/models/StyleCLIP/global_directions/MapTS.py new file mode 100644 index 0000000000000000000000000000000000000000..2160a62cdbb0278d213076637f79b1e6f66db906 --- /dev/null +++ b/models/StyleCLIP/global_directions/MapTS.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Thu Feb 4 17:36:31 2021 + +@author: wuzongze +""" + +import os +#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +#os.environ["CUDA_VISIBLE_DEVICES"] = "1" #(or "1" or "2") + +import sys + +#sys.path=['', '/usr/local/tensorflow/avx-avx2-gpu/1.14.0/python3.7/site-packages', '/usr/local/matlab/2018b/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python37.zip', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/lib-dynload', '/usr/lib/python3.7', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages', '/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/copkmeans-1.5-py3.7.egg', 
'/cs/labs/danix/wuzongze/pythonV/venv3.7/lib/python3.7/site-packages/spherecluster-0.1.7-py3.7.egg', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.7/dist-packages', '/usr/lib/python3/dist-packages/IPython/extensions'] + +import tensorflow as tf + +import numpy as np +import torch +import clip +from PIL import Image +import pickle +import copy +import matplotlib.pyplot as plt + +def GetAlign(out,dt,model,preprocess): + imgs=out + imgs1=imgs.reshape([-1]+list(imgs.shape[2:])) + + tmp=[] + for i in range(len(imgs1)): + + img=Image.fromarray(imgs1[i]) + image = preprocess(img).unsqueeze(0).to(device) + tmp.append(image) + + image=torch.cat(tmp) + + with torch.no_grad(): + image_features = model.encode_image(image) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + + image_features1=image_features.cpu().numpy() + + image_features1=image_features1.reshape(list(imgs.shape[:2])+[512]) + + fd=image_features1[:,1:,:]-image_features1[:,:-1,:] + + fd1=fd.reshape([-1,512]) + fd2=fd1/np.linalg.norm(fd1,axis=1)[:,None] + + tmp=np.dot(fd2,dt) + m=tmp.mean() + acc=np.sum(tmp>0)/len(tmp) + print(m,acc) + return m,acc + + +def SplitS(ds_p,M,if_std): + all_ds=[] + start=0 + for i in M.mindexs: + tmp=M.dlatents[i].shape[1] + end=start+tmp + tmp=ds_p[start:end] +# tmp=tmp*M.code_std[i] + + all_ds.append(tmp) + start=end + + all_ds2=[] + tmp_index=0 + for i in range(len(M.s_names)): + if (not 'RGB' in M.s_names[i]) and (not len(all_ds[tmp_index])==0): + +# tmp=np.abs(all_ds[tmp_index]/M.code_std[i]) +# print(i,tmp.mean()) +# tmp=np.dot(M.latent_codes[i],all_ds[tmp_index]) +# print(tmp) + if if_std: + tmp=all_ds[tmp_index]*M.code_std[i] + else: + tmp=all_ds[tmp_index] + + all_ds2.append(tmp) + tmp_index+=1 + else: + tmp=np.zeros(len(M.dlatents[i][0])) + all_ds2.append(tmp) + return all_ds2 + + +imagenet_templates = [ + 'a bad photo of a {}.', +# 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'a photo of a {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo 
of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', +] + + +def zeroshot_classifier(classnames, templates,model): + with torch.no_grad(): + zeroshot_weights = [] + for classname in classnames: + texts = [template.format(classname) for template in templates] #format with class + texts = clip.tokenize(texts).cuda() #tokenize + class_embeddings = model.encode_text(texts) #embed with text encoder + class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) + class_embedding = class_embeddings.mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() + return zeroshot_weights + + +def GetDt(classnames,model): + text_features=zeroshot_classifier(classnames, imagenet_templates,model).t() + + dt=text_features[0]-text_features[1] + dt=dt.cpu().numpy() + +# t_m1=t_m/np.linalg.norm(t_m) +# dt=text_features.cpu().numpy()[0]-t_m1 + print(np.linalg.norm(dt)) + dt=dt/np.linalg.norm(dt) + return dt + + +def GetBoundary(fs3,dt,M,threshold): + tmp=np.dot(fs3,dt) + + ds_imp=copy.copy(tmp) + select=np.abs(tmp)", self.text_n) + self.view.target.bind("", self.text_t) + self.view.alpha.bind('', self.ChangeAlpha) + self.view.beta.bind('', self.ChangeBeta) + self.view.set_init.bind('', self.SetInit) + self.view.reset.bind('', self.Reset) + self.view.bg.bind('', self.open_img) + + + self.drawn = None + + self.view.target.delete(1.0, "end") + self.view.target.insert("end", self.style_clip.target) +# + self.view.neutral.delete(1.0, "end") + self.view.neutral.insert("end", self.style_clip.neutral) + + + def Reset(self,event): + self.style_clip.GetDt2() + self.style_clip.M.alpha=[0] + + self.view.beta.set(self.style_clip.beta) + self.view.alpha.set(0) + + img=self.style_clip.GetImg() + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + + def SetInit(self,event): + codes=self.style_clip.GetCode() + self.style_clip.M.dlatent_tmp=[tmp[:,0] for tmp in codes] + print('set init') + + def ChangeAlpha(self,event): + tmp=self.view.alpha.get() + self.style_clip.M.alpha=[float(tmp)] + + img=self.style_clip.GetImg() + print('manipulate one') + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + def ChangeBeta(self,event): + tmp=self.view.beta.get() + self.style_clip.beta=float(tmp) + + img=self.style_clip.GetImg() + print('manipulate one') + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + def ChangeDataset(self,event): + + dataset_name=self.view.set_category.get() + + self.style_clip.LoadData(dataset_name) + + self.view.target.delete(1.0, "end") + self.view.target.insert("end", self.style_clip.target) + + self.view.neutral.delete(1.0, "end") + self.view.neutral.insert("end", self.style_clip.neutral) + + def text_t(self,event): + tmp=self.view.target.get("1.0",'end') + tmp=tmp.replace('\n','') + + self.view.target.delete(1.0, "end") + self.view.target.insert("end", tmp) + + print('target',tmp,'###') + self.style_clip.target=tmp + self.style_clip.GetDt2() + self.view.beta.set(self.style_clip.beta) + self.view.alpha.set(3) + 
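The GetDt routine above reduces the target/neutral prompt pair to a single CLIP-space edit direction: each description is averaged over the prompt templates, and the normalized difference of the two embeddings is the direction dt that GetBoundary then projects onto the style channels. A minimal sketch of the same computation with a single template, assuming the openai/CLIP package (the prompt strings are illustrative):

import numpy as np
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)

def text_direction(target, neutral):
    with torch.no_grad():
        feats = []
        for text in [target, neutral]:
            tokens = clip.tokenize([f"a photo of a {text}."]).to(device)
            emb = model.encode_text(tokens)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            feats.append(emb[0])
    dt = (feats[0] - feats[1]).cpu().numpy()
    return dt / np.linalg.norm(dt)

dt = text_direction("face with glasses", "face")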
self.style_clip.M.alpha=[3] + + img=self.style_clip.GetImg() + print('manipulate one') + img=Image.fromarray(img) + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + + def text_n(self,event): + tmp=self.view.neutral.get("1.0",'end') + tmp=tmp.replace('\n','') + + self.view.neutral.delete(1.0, "end") + self.view.neutral.insert("end", tmp) + + print('neutral',tmp,'###') + self.style_clip.neutral=tmp + self.view.target.delete(1.0, "end") + self.view.target.insert("end", tmp) + + + def run(self): + self.root.mainloop() + + def addImage(self,img): + self.view.bg.create_image(self.view.width/2, self.view.height/2, image=img, anchor='center') + self.image=img #save a copy of image. if not the image will disappear + + def addImage_m(self,img): + self.view.mani.create_image(512, 512, image=img, anchor='center') + self.image2=img + + + def openfn(self): + filename = askopenfilename(title='open',initialdir='./data/'+self.style_clip.M.dataset_name+'/',filetypes=[("all image format", ".jpg"),("all image format", ".png")]) + return filename + + def open_img(self,event): + x = self.openfn() + print(x) + + + img = Image.open(x) + img2 = img.resize(( 512,512), Image.ANTIALIAS) + img2 = ImageTk.PhotoImage(img2) + self.addImage(img2) + + img = ImageTk.PhotoImage(img) + self.addImage_m(img) + + img_index=x.split('/')[-1].split('.')[0] + img_index=int(img_index) + print(img_index) + self.style_clip.M.img_index=img_index + self.style_clip.M.dlatent_tmp=[tmp[img_index:(img_index+1)] for tmp in self.style_clip.M.dlatents] + + + self.style_clip.GetDt2() + self.view.beta.set(self.style_clip.beta) + self.view.alpha.set(3) + + #%% +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='ffhq', + help='name of dataset, for example, ffhq') + + args = parser.parse_args() + dataset_name=args.dataset_name + + self=PlayInteractively(dataset_name) + self.run() + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/SingleChannel.py b/models/StyleCLIP/global_directions/SingleChannel.py new file mode 100644 index 0000000000000000000000000000000000000000..ecaa7ec7898d37f8f5db171f9141a5253af3fa73 --- /dev/null +++ b/models/StyleCLIP/global_directions/SingleChannel.py @@ -0,0 +1,109 @@ + + + +import numpy as np +import torch +import clip +from PIL import Image +import copy +from manipulate import Manipulator +import argparse + +def GetImgF(out,model,preprocess): + imgs=out + imgs1=imgs.reshape([-1]+list(imgs.shape[2:])) + + tmp=[] + for i in range(len(imgs1)): + + img=Image.fromarray(imgs1[i]) + image = preprocess(img).unsqueeze(0).to(device) + tmp.append(image) + + image=torch.cat(tmp) + with torch.no_grad(): + image_features = model.encode_image(image) + + image_features1=image_features.cpu().numpy() + image_features1=image_features1.reshape(list(imgs.shape[:2])+[512]) + + return image_features1 + +def GetFs(fs): + tmp=np.linalg.norm(fs,axis=-1) + fs1=fs/tmp[:,:,:,None] + fs2=fs1[:,:,1,:]-fs1[:,:,0,:] # 5*sigma - (-5)* sigma + fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None] + fs3=fs3.mean(axis=1) + fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None] + return fs3 + +#%% +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process some integers.') + + parser.add_argument('--dataset_name',type=str,default='cat', + help='name of dataset, for example, ffhq') + args = parser.parse_args() + dataset_name=args.dataset_name + + #%% + device = "cuda" if 
torch.cuda.is_available() else "cpu" + model, preprocess = clip.load("ViT-B/32", device=device) + #%% + M=Manipulator(dataset_name=dataset_name) + np.set_printoptions(suppress=True) + print(M.dataset_name) + #%% + img_sindex=0 + num_images=100 + dlatents_o=[] + tmp=img_sindex*num_images + for i in range(len(M.dlatents)): + tmp1=M.dlatents[i][tmp:(tmp+num_images)] + dlatents_o.append(tmp1) + #%% + + all_f=[] + M.alpha=[-5,5] #ffhq 5 + M.step=2 + M.num_images=num_images + select=np.array(M.mindexs)<=16 #below or equal to 128 resolution + mindexs2=np.array(M.mindexs)[select] + for lindex in mindexs2: #ignore ToRGB layers + print(lindex) + num_c=M.dlatents[lindex].shape[1] + for cindex in range(num_c): + + M.dlatents=copy.copy(dlatents_o) + M.dlatents[lindex][:,cindex]=M.code_mean[lindex][cindex] + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(cindex) + image_features1=GetImgF(out,model,preprocess) + all_f.append(image_features1) + + all_f=np.array(all_f) + + fs3=GetFs(all_f) + + #%% + file_path='./npy/'+M.dataset_name+'/' + np.save(file_path+'fs3',fs3) + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/models/StyleCLIP/global_directions/__init__.py b/models/StyleCLIP/global_directions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy b/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy new file mode 100644 index 0000000000000000000000000000000000000000..2039c3c5817022d644936586ca807fafe69b0cee Binary files /dev/null and b/models/StyleCLIP/global_directions/data/ffhq/w_plus.npy differ diff --git a/models/StyleCLIP/global_directions/dnnlib/__init__.py b/models/StyleCLIP/global_directions/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c73940d81233142ae3dcd9a37b7ec2185c5d5fc5 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .util import EasyDict, make_cache_dir_path diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py b/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca852844ec488c0134bffa647e25a40646ff4718 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from . import autosummary +from . import network +from . import optimizer +from . import tfutil +from . 
import custom_ops + +from .tfutil import * +from .network import Network + +from .optimizer import Optimizer + +from .custom_ops import get_plugin diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py b/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py new file mode 100644 index 0000000000000000000000000000000000000000..56dfb96093bb5b1129a99585b4ce655b98d80009 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/autosummary.py @@ -0,0 +1,193 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Helper for adding automatically tracked values to Tensorboard. + +Autosummary creates an identity op that internally keeps track of the input +values and automatically shows up in TensorBoard. The reported value +represents an average over input components. The average is accumulated +constantly over time and flushed when save_summaries() is called. + +Notes: +- The output tensor must be used as an input for something else in the + graph. Otherwise, the autosummary op will not get executed, and the average + value will not get accumulated. +- It is perfectly fine to include autosummaries with the same name in + several places throughout the graph, even if they are executed concurrently. +- It is ok to also pass in a python scalar or numpy array. In this case, it + is added to the average immediately. +""" + +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorboard import summary as summary_lib +from tensorboard.plugins.custom_scalar import layout_pb2 + +from . import tfutil +from .tfutil import TfExpression +from .tfutil import TfExpressionEx + +# Enable "Custom scalars" tab in TensorBoard for advanced formatting. +# Disabled by default to reduce tfevents file size. +enable_custom_scalars = False + +_dtype = tf.float64 +_vars = OrderedDict() # name => [var, ...] 
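Each accumulator created below stores the running triple [count, sum(x), sum(x**2)]; finalize_autosummaries later recovers the reported mean and the custom-scalar margins from exactly these moments. A standalone NumPy sketch of that reduction (toy values):

import numpy as np

x = np.array([2.0, 4.0, 6.0])
moments = np.array([x.size, x.sum(), np.square(x).sum()])  # [count, sum(x), sum(x**2)]
moments /= moments[0]                                       # -> [1, mean, E[x**2]]
mean = moments[1]
std = np.sqrt(moments[2] - moments[1] ** 2)
print(mean, std)                                            # 4.0  ~1.633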
+_immediate = OrderedDict() # name => update_op, update_value +_finalized = False +_merge_op = None + + +def _create_var(name: str, value_expr: TfExpression) -> TfExpression: + """Internal helper for creating autosummary accumulators.""" + assert not _finalized + name_id = name.replace("/", "_") + v = tf.cast(value_expr, _dtype) + + if v.shape.is_fully_defined(): + size = np.prod(v.shape.as_list()) + size_expr = tf.constant(size, dtype=_dtype) + else: + size = None + size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) + + if size == 1: + if v.shape.ndims != 0: + v = tf.reshape(v, []) + v = [size_expr, v, tf.square(v)] + else: + v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] + v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) + + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): + var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] + update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) + + if name in _vars: + _vars[name].append(var) + else: + _vars[name] = [var] + return update_op + + +def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx: + """Create a new autosummary. + + Args: + name: Name to use in TensorBoard + value: TensorFlow expression or python value to track + passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. + + Example use of the passthru mechanism: + + n = autosummary('l2loss', loss, passthru=n) + + This is a shorthand for the following code: + + with tf.control_dependencies([autosummary('l2loss', loss)]): + n = tf.identity(n) + """ + tfutil.assert_tf_initialized() + name_id = name.replace("/", "_") + + if tfutil.is_tf_expression(value): + with tf.name_scope("summary_" + name_id), tf.device(value.device): + condition = tf.convert_to_tensor(condition, name='condition') + update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op) + with tf.control_dependencies([update_op]): + return tf.identity(value if passthru is None else passthru) + + else: # python scalar or numpy array + assert not tfutil.is_tf_expression(passthru) + assert not tfutil.is_tf_expression(condition) + if condition: + if name not in _immediate: + with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): + update_value = tf.placeholder(_dtype) + update_op = _create_var(name, update_value) + _immediate[name] = update_op, update_value + update_op, update_value = _immediate[name] + tfutil.run(update_op, {update_value: value}) + return value if passthru is None else passthru + + +def finalize_autosummaries() -> None: + """Create the necessary ops to include autosummaries in TensorBoard report. + Note: This should be done only once per graph. + """ + global _finalized + tfutil.assert_tf_initialized() + + if _finalized: + return None + + _finalized = True + tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) + + # Create summary ops. 
+ with tf.device(None), tf.control_dependencies(None): + for name, vars_list in _vars.items(): + name_id = name.replace("/", "_") + with tfutil.absolute_name_scope("Autosummary/" + name_id): + moments = tf.add_n(vars_list) + moments /= moments[0] + with tf.control_dependencies([moments]): # read before resetting + reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] + with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting + mean = moments[1] + std = tf.sqrt(moments[2] - tf.square(moments[1])) + tf.summary.scalar(name, mean) + if enable_custom_scalars: + tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) + tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) + + # Setup layout for custom scalars. + layout = None + if enable_custom_scalars: + cat_dict = OrderedDict() + for series_name in sorted(_vars.keys()): + p = series_name.split("/") + cat = p[0] if len(p) >= 2 else "" + chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] + if cat not in cat_dict: + cat_dict[cat] = OrderedDict() + if chart not in cat_dict[cat]: + cat_dict[cat][chart] = [] + cat_dict[cat][chart].append(series_name) + categories = [] + for cat_name, chart_dict in cat_dict.items(): + charts = [] + for chart_name, series_names in chart_dict.items(): + series = [] + for series_name in series_names: + series.append(layout_pb2.MarginChartContent.Series( + value=series_name, + lower="xCustomScalars/" + series_name + "/margin_lo", + upper="xCustomScalars/" + series_name + "/margin_hi")) + margin = layout_pb2.MarginChartContent(series=series) + charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) + categories.append(layout_pb2.Category(title=cat_name, chart=charts)) + layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) + return layout + +def save_summaries(file_writer, global_step=None): + """Call FileWriter.add_summary() with all summaries in the default graph, + automatically finalizing and merging them on the first call. + """ + global _merge_op + tfutil.assert_tf_initialized() + + if _merge_op is None: + layout = finalize_autosummaries() + if layout is not None: + file_writer.add_summary(layout) + with tf.device(None), tf.control_dependencies(None): + _merge_op = tf.summary.merge_all() + + file_writer.add_summary(_merge_op.eval(), global_step) diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py b/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..702471e2006af6858345c1225c1e55b0acd17d32 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/custom_ops.py @@ -0,0 +1,181 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""TensorFlow custom ops builder. +""" + +import glob +import os +import re +import uuid +import hashlib +import tempfile +import shutil +import tensorflow as tf +from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module + +from .. import util + +#---------------------------------------------------------------------------- +# Global configs. 
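save_summaries above is intended to be called periodically from a training loop with a tf.summary.FileWriter; the first call finalizes and merges every autosummary registered so far. A minimal usage sketch under graph-mode TF1 (the log directory and the stand-in loss tensor are illustrative, and the import path assumes the global_directions folder is on sys.path):

import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib import autosummary

tflib.init_tf()
loss = tf.constant(0.25)                                # stand-in for a real training loss
loss = autosummary.autosummary('Loss/train', loss)      # tracked identity op

writer = tf.summary.FileWriter('./logs', tf.get_default_graph())
for step in range(3):
    tflib.run(loss)                                     # the autosummary update only runs when the op is executed
    autosummary.save_summaries(writer, global_step=step)
writer.close()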
+ +cuda_cache_path = None +cuda_cache_version_tag = 'v1' +do_not_hash_included_headers = True # Speed up compilation by assuming that headers included by the CUDA code never change. +verbose = True # Print status messages to stdout. + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + hostx64_paths = sorted(glob.glob('C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64'), reverse=True) + if hostx64_paths != []: + return hostx64_paths[0] + vc_bin_dir = 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin' + if os.path.isdir(vc_bin_dir): + return vc_bin_dir + return None + +def _get_compute_cap(device): + caps_str = device.physical_device_desc + m = re.search('compute capability: (\\d+).(\\d+)', caps_str) + major = m.group(1) + minor = m.group(2) + return (major, minor) + +def _get_cuda_gpu_arch_string(): + gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU'] + if len(gpus) == 0: + raise RuntimeError('No GPU devices found') + (major, minor) = _get_compute_cap(gpus[0]) + return 'sm_%s%s' % (major, minor) + +def _run_cmd(cmd): + with os.popen(cmd) as pipe: + output = pipe.read() + status = pipe.close() + if status is not None: + raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output)) + +def _prepare_nvcc_cli(opts): + cmd = 'nvcc ' + opts.strip() + cmd += ' --disable-warnings' + cmd += ' --include-path "%s"' % tf.sysconfig.get_include() + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl') + cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive') + + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + # Require that _find_compiler_bindir succeeds on Windows. Allow + # nvcc to use whatever is the default on Linux. + if os.name == 'nt': + raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__) + else: + cmd += ' --compiler-bindir "%s"' % compiler_bindir + cmd += ' 2>&1' + return cmd + +#---------------------------------------------------------------------------- +# Main entry point. + +_plugin_cache = dict() + +def get_plugin(cuda_file, extra_nvcc_options=[]): + cuda_file_base = os.path.basename(cuda_file) + cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base) + + # Already in cache? + if cuda_file in _plugin_cache: + return _plugin_cache[cuda_file] + + # Setup plugin. + if verbose: + print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True) + try: + # Hash CUDA source. + md5 = hashlib.md5() + with open(cuda_file, 'rb') as f: + md5.update(f.read()) + md5.update(b'\n') + + # Hash headers included by the CUDA code by running it through the preprocessor. + if not do_not_hash_included_headers: + if verbose: + print('Preprocessing... 
', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext) + _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))) + with open(tmp_file, 'rb') as f: + bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros + good_file_str = ('"' + cuda_file_base + '"').encode('utf-8') + for ln in f: + if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas + ln = ln.replace(bad_file_str, good_file_str) + md5.update(ln) + md5.update(b'\n') + + # Select compiler configs. + compile_opts = '' + if os.name == 'nt': + compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib') + elif os.name == 'posix': + compile_opts += f' --compiler-options \'-fPIC\'' + compile_opts += f' --compiler-options \'{" ".join(tf.sysconfig.get_compile_flags())}\'' + compile_opts += f' --linker-options \'{" ".join(tf.sysconfig.get_link_flags())}\'' + else: + assert False # not Windows or Linux, w00t? + compile_opts += f' --gpu-architecture={_get_cuda_gpu_arch_string()}' + compile_opts += ' --use_fast_math' + for opt in extra_nvcc_options: + compile_opts += ' ' + opt + nvcc_cmd = _prepare_nvcc_cli(compile_opts) + + # Hash build configuration. + md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n') + md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n') + md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n') + + # Compile if not already compiled. + cache_dir = util.make_cache_dir_path('tflib-cudacache') if cuda_cache_path is None else cuda_cache_path + bin_file_ext = '.dll' if os.name == 'nt' else '.so' + bin_file = os.path.join(cache_dir, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext) + if not os.path.isfile(bin_file): + if verbose: + print('Compiling... ', end='', flush=True) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext) + _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)) + os.makedirs(cache_dir, exist_ok=True) + intermediate_file = os.path.join(cache_dir, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext) + shutil.copyfile(tmp_file, intermediate_file) + os.rename(intermediate_file, bin_file) # atomic + + # Load. + if verbose: + print('Loading... ', end='', flush=True) + plugin = tf.load_op_library(bin_file) + + # Add to cache. + _plugin_cache[cuda_file] = plugin + if verbose: + print('Done.', flush=True) + return plugin + + except: + if verbose: + print('Failed!', flush=True) + raise + +#---------------------------------------------------------------------------- diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/network.py b/models/StyleCLIP/global_directions/dnnlib/tflib/network.py new file mode 100644 index 0000000000000000000000000000000000000000..ff0c169eabdc579041dac0650fbc6da956646594 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/network.py @@ -0,0 +1,781 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Helper for managing networks.""" + +import types +import inspect +import re +import uuid +import sys +import copy +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import Any, List, Tuple, Union, Callable + +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +# pylint: disable=protected-access +# pylint: disable=attribute-defined-outside-init +# pylint: disable=too-many-public-methods + +_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import. +_import_module_src = dict() # Source code for temporary modules created during pickle import. + + +def import_handler(handler_func): + """Function decorator for declaring custom import handlers.""" + _import_handlers.append(handler_func) + return handler_func + + +class Network: + """Generic network abstraction. + + Acts as a convenience wrapper for a parameterized network construction + function, providing several utility methods and convenient access to + the inputs/outputs/weights. + + Network objects can be safely pickled and unpickled for long-term + archival purposes. The pickling works reliably as long as the underlying + network construction function is defined in a standalone Python module + that has no side effects or application-specific imports. + + Args: + name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None. + func_name: Fully qualified name of the underlying network construction function, or a top-level function object. + static_kwargs: Keyword arguments to be passed in to the network construction function. + """ + + def __init__(self, name: str = None, func_name: Any = None, **static_kwargs): + # Locate the user-specified build function. + assert isinstance(func_name, str) or util.is_top_level_function(func_name) + if util.is_top_level_function(func_name): + func_name = util.get_top_level_function_name(func_name) + module, func_name = util.get_module_from_obj_name(func_name) + func = util.get_obj_from_module(module, func_name) + + # Dig up source code for the module containing the build function. + module_src = _import_module_src.get(module, None) + if module_src is None: + module_src = inspect.getsource(module) + + # Initialize fields. + self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src) + + def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None: + tfutil.assert_tf_initialized() + assert isinstance(name, str) + assert len(name) >= 1 + assert re.fullmatch(r"[A-Za-z0-9_.\\-]*", name) + assert isinstance(static_kwargs, dict) + assert util.is_pickleable(static_kwargs) + assert callable(build_func) + assert isinstance(build_func_name, str) + assert isinstance(build_module_src, str) + + # Choose TensorFlow name scope. + with tf.name_scope(None): + scope = tf.get_default_graph().unique_name(name, mark_as_used=True) + + # Query current TensorFlow device. + with tfutil.absolute_name_scope(scope), tf.control_dependencies(None): + device = tf.no_op(name="_QueryDevice").device + + # Immutable state. 
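As the class docstring above explains, a Network wraps a top-level build function identified by its fully qualified name, with static_kwargs baked into the pickle. A minimal construction sketch, assuming a hypothetical module my_networks.py that defines the build function (nothing below comes from this repo):

# my_networks.py (hypothetical)
import tensorflow as tf

def toy_net(x, **_kwargs):          # positional args without defaults become input placeholders
    x.set_shape([None, 8])          # input shapes must have a known rank
    w = tf.get_variable('w', shape=[8, 4])
    return tf.matmul(x, w)

# elsewhere, after tflib.init_tf():
#   import dnnlib.tflib as tflib
#   net = tflib.Network('Toy', func_name='my_networks.toy_net')
#   print(net.input_shapes, net.output_shapes)   # [[None, 8]] [[None, 4]]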
+ self._name = name + self._scope = scope + self._device = device + self._static_kwargs = util.EasyDict(copy.deepcopy(static_kwargs)) + self._build_func = build_func + self._build_func_name = build_func_name + self._build_module_src = build_module_src + + # State before _init_graph(). + self._var_inits = dict() # var_name => initial_value, set to None by _init_graph() + self._all_inits_known = False # Do we know for sure that _var_inits covers all the variables? + self._components = None # subnet_name => Network, None if the components are not known yet + + # Initialized by _init_graph(). + self._input_templates = None + self._output_templates = None + self._own_vars = None + + # Cached values initialized the respective methods. + self._input_shapes = None + self._output_shapes = None + self._input_names = None + self._output_names = None + self._vars = None + self._trainables = None + self._var_global_to_local = None + self._run_cache = dict() + + def _init_graph(self) -> None: + assert self._var_inits is not None + assert self._input_templates is None + assert self._output_templates is None + assert self._own_vars is None + + # Initialize components. + if self._components is None: + self._components = util.EasyDict() + + # Choose build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs["is_template_graph"] = True + build_kwargs["components"] = self._components + + # Override scope and device, and ignore surrounding control dependencies. + with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None): + assert tf.get_variable_scope().name == self.scope + assert tf.get_default_graph().get_name_scope() == self.scope + + # Create input templates. + self._input_templates = [] + for param in inspect.signature(self._build_func).parameters.values(): + if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty: + self._input_templates.append(tf.placeholder(tf.float32, name=param.name)) + + # Call build func. + out_expr = self._build_func(*self._input_templates, **build_kwargs) + + # Collect output templates and variables. + assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/")) + + # Check for errors. + if len(self._input_templates) == 0: + raise ValueError("Network build func did not list any inputs.") + if len(self._output_templates) == 0: + raise ValueError("Network build func did not return any outputs.") + if any(not tfutil.is_tf_expression(t) for t in self._output_templates): + raise ValueError("Network outputs must be TensorFlow expressions.") + if any(t.shape.ndims is None for t in self._input_templates): + raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.") + if any(t.shape.ndims is None for t in self._output_templates): + raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.") + if any(not isinstance(comp, Network) for comp in self._components.values()): + raise ValueError("Components of a Network must be Networks themselves.") + if len(self._components) != len(set(comp.name for comp in self._components.values())): + raise ValueError("Components of a Network must have unique names.") + + # Initialize variables. 
+ if len(self._var_inits): + tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()}) + remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits] + if self._all_inits_known: + assert len(remaining_inits) == 0 + else: + tfutil.run(remaining_inits) + self._var_inits = None + + @property + def name(self): + """User-specified name string.""" + return self._name + + @property + def scope(self): + """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.""" + return self._scope + + @property + def device(self): + """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time.""" + return self._device + + @property + def static_kwargs(self): + """EasyDict of arguments passed to the user-supplied build func.""" + return copy.deepcopy(self._static_kwargs) + + @property + def components(self): + """EasyDict of sub-networks created by the build func.""" + return copy.copy(self._get_components()) + + def _get_components(self): + if self._components is None: + self._init_graph() + assert self._components is not None + return self._components + + @property + def input_shapes(self): + """List of input tensor shapes, including minibatch dimension.""" + if self._input_shapes is None: + self._input_shapes = [t.shape.as_list() for t in self.input_templates] + return copy.deepcopy(self._input_shapes) + + @property + def output_shapes(self): + """List of output tensor shapes, including minibatch dimension.""" + if self._output_shapes is None: + self._output_shapes = [t.shape.as_list() for t in self.output_templates] + return copy.deepcopy(self._output_shapes) + + @property + def input_shape(self): + """Short-hand for input_shapes[0].""" + return self.input_shapes[0] + + @property + def output_shape(self): + """Short-hand for output_shapes[0].""" + return self.output_shapes[0] + + @property + def num_inputs(self): + """Number of input tensors.""" + return len(self.input_shapes) + + @property + def num_outputs(self): + """Number of output tensors.""" + return len(self.output_shapes) + + @property + def input_names(self): + """Name string for each input.""" + if self._input_names is None: + self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates] + return copy.copy(self._input_names) + + @property + def output_names(self): + """Name string for each output.""" + if self._output_names is None: + self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates] + return copy.copy(self._output_names) + + @property + def input_templates(self): + """Input placeholders in the template graph.""" + if self._input_templates is None: + self._init_graph() + assert self._input_templates is not None + return copy.copy(self._input_templates) + + @property + def output_templates(self): + """Output tensors in the template graph.""" + if self._output_templates is None: + self._init_graph() + assert self._output_templates is not None + return copy.copy(self._output_templates) + + @property + def own_vars(self): + """Variables defined by this network (local_name => var), excluding sub-networks.""" + return copy.copy(self._get_own_vars()) + + def _get_own_vars(self): + if self._own_vars is None: + self._init_graph() + assert self._own_vars is not None + return self._own_vars + + @property + def vars(self): + """All variables (local_name => var).""" + 
return copy.copy(self._get_vars()) + + def _get_vars(self): + if self._vars is None: + self._vars = OrderedDict(self._get_own_vars()) + for comp in self._get_components().values(): + self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items()) + return self._vars + + @property + def trainables(self): + """All trainable variables (local_name => var).""" + return copy.copy(self._get_trainables()) + + def _get_trainables(self): + if self._trainables is None: + self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable) + return self._trainables + + @property + def var_global_to_local(self): + """Mapping from variable global names to local names.""" + return copy.copy(self._get_var_global_to_local()) + + def _get_var_global_to_local(self): + if self._var_global_to_local is None: + self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items()) + return self._var_global_to_local + + def reset_own_vars(self) -> None: + """Re-initialize all variables of this network, excluding sub-networks.""" + if self._var_inits is None or self._components is None: + tfutil.run([var.initializer for var in self._get_own_vars().values()]) + else: + self._var_inits.clear() + self._all_inits_known = False + + def reset_vars(self) -> None: + """Re-initialize all variables of this network, including sub-networks.""" + if self._var_inits is None: + tfutil.run([var.initializer for var in self._get_vars().values()]) + else: + self._var_inits.clear() + self._all_inits_known = False + if self._components is not None: + for comp in self._components.values(): + comp.reset_vars() + + def reset_trainables(self) -> None: + """Re-initialize all trainable variables of this network, including sub-networks.""" + tfutil.run([var.initializer for var in self._get_trainables().values()]) + + def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]: + """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s). + The graph is placed on the current TensorFlow device.""" + assert len(in_expr) == self.num_inputs + assert not all(expr is None for expr in in_expr) + self._get_vars() # ensure that all variables have been created + + # Choose build func kwargs. + build_kwargs = dict(self.static_kwargs) + build_kwargs.update(dynamic_kwargs) + build_kwargs["is_template_graph"] = False + build_kwargs["components"] = self._components + + # Build TensorFlow graph to evaluate the network. + with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name): + assert tf.get_variable_scope().name == self.scope + valid_inputs = [expr for expr in in_expr if expr is not None] + final_inputs = [] + for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes): + if expr is not None: + expr = tf.identity(expr, name=name) + else: + expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name) + final_inputs.append(expr) + out_expr = self._build_func(*final_inputs, **build_kwargs) + + # Propagate input shapes back to the user-specified expressions. + for expr, final in zip(in_expr, final_inputs): + if isinstance(expr, tf.Tensor): + expr.set_shape(final.shape) + + # Express outputs in the desired format. 
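get_output_for above is the symbolic counterpart of run() further down: it returns TensorFlow expressions that can be wired into a larger graph (e.g. a loss term), with None inputs replaced by zero tensors as the loop above shows. A usage sketch, assuming Gs is a two-input mapping+synthesis generator loaded the way LoadModel does earlier in this diff:

import tensorflow as tf

latents = tf.random_normal([4, 512])            # batch of z vectors
fake_images = Gs.get_output_for(latents, None)  # None labels become zeros internally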
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple) + if return_as_list: + out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr) + return out_expr + + def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str: + """Get the local name of a given variable, without any surrounding name scopes.""" + assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str) + global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name + return self._get_var_global_to_local()[global_name] + + def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression: + """Find variable by local or global name.""" + assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str) + return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name + + def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray: + """Get the value of a given variable as NumPy array. + Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.""" + return self.find_var(var_or_local_name).eval() + + def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None: + """Set the value of a given variable based on the given NumPy array. + Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.""" + tfutil.set_vars({self.find_var(var_or_local_name): new_value}) + + def __getstate__(self) -> dict: + """Pickle export.""" + state = dict() + state["version"] = 5 + state["name"] = self.name + state["static_kwargs"] = dict(self.static_kwargs) + state["components"] = dict(self.components) + state["build_module_src"] = self._build_module_src + state["build_func_name"] = self._build_func_name + state["variables"] = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values())))) + state["input_shapes"] = self.input_shapes + state["output_shapes"] = self.output_shapes + state["input_names"] = self.input_names + state["output_names"] = self.output_names + return state + + def __setstate__(self, state: dict) -> None: + """Pickle import.""" + + # Execute custom import handlers. + for handler in _import_handlers: + state = handler(state) + + # Get basic fields. + assert state["version"] in [2, 3, 4, 5] + name = state["name"] + static_kwargs = state["static_kwargs"] + build_module_src = state["build_module_src"] + build_func_name = state["build_func_name"] + + # Create temporary module from the imported source code. + module_name = "_tflib_network_import_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _import_module_src[module] = build_module_src + exec(build_module_src, module.__dict__) # pylint: disable=exec-used + build_func = util.get_obj_from_module(module, build_func_name) + + # Initialize fields. 
+ self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src) + self._var_inits.update(copy.deepcopy(state["variables"])) + self._all_inits_known = True + self._components = util.EasyDict(state.get("components", {})) + self._input_shapes = copy.deepcopy(state.get("input_shapes", None)) + self._output_shapes = copy.deepcopy(state.get("output_shapes", None)) + self._input_names = copy.deepcopy(state.get("input_names", None)) + self._output_names = copy.deepcopy(state.get("output_names", None)) + + def clone(self, name: str = None, **new_static_kwargs) -> "Network": + """Create a clone of this network with its own copy of the variables.""" + static_kwargs = dict(self.static_kwargs) + static_kwargs.update(new_static_kwargs) + net = object.__new__(Network) + net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src) + net.copy_vars_from(self) + return net + + def copy_own_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, excluding sub-networks.""" + + # Source has unknown variables or unknown components => init now. + if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None: + src_net._get_vars() + + # Both networks are inited => copy directly. + if src_net._var_inits is None and self._var_inits is None: + names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()] + tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names})) + return + + # Read from source. + if src_net._var_inits is None: + value_dict = tfutil.run(src_net._get_own_vars()) + else: + value_dict = src_net._var_inits + + # Write to destination. + if self._var_inits is None: + tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()}) + else: + self._var_inits.update(value_dict) + + def copy_vars_from(self, src_net: "Network") -> None: + """Copy the values of all variables from the given network, including sub-networks.""" + + # Source has unknown variables or unknown components => init now. + if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None: + src_net._get_vars() + + # Source is inited, but destination components have not been created yet => set as initial values. + if src_net._var_inits is None and self._components is None: + self._var_inits.update(tfutil.run(src_net._get_vars())) + return + + # Destination has unknown components => init now. + if self._components is None: + self._get_vars() + + # Both networks are inited => copy directly. + if src_net._var_inits is None and self._var_inits is None: + names = [name for name in self._get_vars().keys() if name in src_net._get_vars()] + tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names})) + return + + # Copy recursively, component by component. 
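The usual StyleGAN training pattern built on these copy helpers is to clone the live generator into an averaged copy and then pull that copy toward the live weights after every optimizer step via setup_as_moving_average_of (defined just below). A sketch, assuming G is an existing tflib.Network:

Gs = G.clone('Gs')                                          # independent copy of all variables
Gs_update_op = Gs.setup_as_moving_average_of(G, beta=0.99)

# inside the training loop, after each optimizer step:
#   tflib.run(Gs_update_op)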
+        self.copy_own_vars_from(src_net)
+        for name, src_comp in src_net._components.items():
+            if name in self._components:
+                self._components[name].copy_vars_from(src_comp)
+
+    def copy_trainables_from(self, src_net: "Network") -> None:
+        """Copy the values of all trainable variables from the given network, including sub-networks."""
+        names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()]
+        tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
+
+    def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
+        """Create new network with the given parameters, and copy all variables from this network."""
+        if new_name is None:
+            new_name = self.name
+        static_kwargs = dict(self.static_kwargs)
+        static_kwargs.update(new_static_kwargs)
+        net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
+        net.copy_vars_from(self)
+        return net
+
+    def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
+        """Construct a TensorFlow op that updates the variables of this network
+        to be slightly closer to those of the given network."""
+        with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
+            ops = []
+            for name, var in self._get_vars().items():
+                if name in src_net._get_vars():
+                    cur_beta = beta if var.trainable else beta_nontrainable
+                    new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta)
+                    ops.append(var.assign(new_value))
+            return tf.group(*ops)
+
+    def run(self,
+            *in_arrays: Tuple[Union[np.ndarray, None], ...],
+            input_transform: dict = None,
+            output_transform: dict = None,
+            return_as_list: bool = False,
+            print_progress: bool = False,
+            minibatch_size: int = None,
+            num_gpus: int = 1,
+            assume_frozen: bool = False,
+            **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
+        """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
+
+        Args:
+            input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
+                The dict must contain a 'func' field that points to a top-level function. The function is called with the input
+                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
+            output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
+                The dict must contain a 'func' field that points to a top-level function. The function is called with the output
+                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
+            return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
+            print_progress: Print progress to the console? Useful for very large input arrays.
+            minibatch_size: Maximum minibatch size to use, None = disable batching.
+            num_gpus: Number of GPUs to use.
+            assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
+            dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
+ """ + assert len(in_arrays) == self.num_inputs + assert not all(arr is None for arr in in_arrays) + assert input_transform is None or util.is_top_level_function(input_transform["func"]) + assert output_transform is None or util.is_top_level_function(output_transform["func"]) + output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs) + num_items = in_arrays[0].shape[0] + if minibatch_size is None: + minibatch_size = num_items + + # Construct unique hash key from all arguments that affect the TensorFlow graph. + key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs) + def unwind_key(obj): + if isinstance(obj, dict): + return [(key, unwind_key(value)) for key, value in sorted(obj.items())] + if callable(obj): + return util.get_top_level_function_name(obj) + return obj + key = repr(unwind_key(key)) + + # Build graph. + if key not in self._run_cache: + with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None): + with tf.device("/cpu:0"): + in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names] + in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr])) + + out_split = [] + for gpu in range(num_gpus): + with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu): + net_gpu = self.clone() if assume_frozen else self + in_gpu = in_split[gpu] + + if input_transform is not None: + in_kwargs = dict(input_transform) + in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs) + in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu) + + assert len(in_gpu) == self.num_inputs + out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs) + + if output_transform is not None: + out_kwargs = dict(output_transform) + out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs) + out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu) + + assert len(out_gpu) == self.num_outputs + out_split.append(out_gpu) + + with tf.device("/cpu:0"): + out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)] + self._run_cache[key] = in_expr, out_expr + + # Run minibatches. + in_expr, out_expr = self._run_cache[key] + out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr] + + for mb_begin in range(0, num_items, minibatch_size): + if print_progress: + print("\r%d / %d" % (mb_begin, num_items), end="") + + mb_end = min(mb_begin + minibatch_size, num_items) + mb_num = mb_end - mb_begin + mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)] + mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in))) + + for dst, src in zip(out_arrays, mb_out): + dst[mb_begin: mb_end] = src + + # Done. 
+ if print_progress: + print("\r%d / %d" % (num_items, num_items)) + + if not return_as_list: + out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays) + return out_arrays + + def list_ops(self) -> List[TfExpression]: + _ = self.output_templates # ensure that the template graph has been created + include_prefix = self.scope + "/" + exclude_prefix = include_prefix + "_" + ops = tf.get_default_graph().get_operations() + ops = [op for op in ops if op.name.startswith(include_prefix)] + ops = [op for op in ops if not op.name.startswith(exclude_prefix)] + return ops + + def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]: + """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to + individual layers of the network. Mainly intended to be used for reporting.""" + layers = [] + + def recurse(scope, parent_ops, parent_vars, level): + if len(parent_ops) == 0 and len(parent_vars) == 0: + return + + # Ignore specific patterns. + if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]): + return + + # Filter ops and vars by scope. + global_prefix = scope + "/" + local_prefix = global_prefix[len(self.scope) + 1:] + cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]] + cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]] + if not cur_ops and not cur_vars: + return + + # Filter out all ops related to variables. + for var in [op for op in cur_ops if op.type.startswith("Variable")]: + var_prefix = var.name + "/" + cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)] + + # Scope does not contain ops as immediate children => recurse deeper. + contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops) + if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0): + visited = set() + for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]: + token = rel_name.split("/")[0] + if token not in visited: + recurse(global_prefix + token, cur_ops, cur_vars, level + 1) + visited.add(token) + return + + # Report layer. 
+ layer_name = scope[len(self.scope) + 1:] + layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1] + layer_trainables = [var for _name, var in cur_vars if var.trainable] + layers.append((layer_name, layer_output, layer_trainables)) + + recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0) + return layers + + def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None: + """Print a summary table of the network structure.""" + rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]] + rows += [["---"] * 4] + total_params = 0 + + for layer_name, layer_output, layer_trainables in self.list_layers(): + num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables) + weights = [var for var in layer_trainables if var.name.endswith("/weight:0")] + weights.sort(key=lambda x: len(x.name)) + if len(weights) == 0 and len(layer_trainables) == 1: + weights = layer_trainables + total_params += num_params + + if not hide_layers_with_no_params or num_params != 0: + num_params_str = str(num_params) if num_params > 0 else "-" + output_shape_str = str(layer_output.shape) + weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-" + rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]] + + rows += [["---"] * 4] + rows += [["Total", str(total_params), "", ""]] + + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths))) + print() + + def setup_weight_histograms(self, title: str = None) -> None: + """Construct summary ops to include histograms of all trainable parameters in TensorBoard.""" + if title is None: + title = self.name + + with tf.name_scope(None), tf.device(None), tf.control_dependencies(None): + for local_name, var in self._get_trainables().items(): + if "/" in local_name: + p = local_name.split("/") + name = title + "_" + p[-1] + "/" + "_".join(p[:-1]) + else: + name = title + "_toplevel/" + local_name + + tf.summary.histogram(name, var) + +#---------------------------------------------------------------------------- +# Backwards-compatible emulation of legacy output transformation in Network.run(). 
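The Network class is complete at this point. As a minimal usage sketch of the run() API documented above (not part of this diff): the generator handle `Gs`, the latent shapes, and an already-initialized TF session are assumptions, and the `dnnlib.tflib` import assumes the bundled package is on the Python path.

# Illustrative sketch only: assumes tflib.init_tf() has been called and `Gs`
# is a Network instance unpickled elsewhere (e.g. a StyleGAN generator).
import numpy as np
import dnnlib.tflib as tflib

latents = np.random.randn(8, *Gs.input_shape[1:])            # one latent row per image
images = Gs.run(latents, None,                               # second input (labels) left as None
                output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
                minibatch_size=4,                            # 8 items evaluated as two minibatches
                print_progress=True)
# With a single output and return_as_list=False (the default), `images` comes
# back as one uint8 NumPy array in NHWC layout.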
+ +_print_legacy_warning = True + +def _handle_legacy_output_transforms(output_transform, dynamic_kwargs): + global _print_legacy_warning + legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"] + if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs): + return output_transform, dynamic_kwargs + + if _print_legacy_warning: + _print_legacy_warning = False + print() + print("WARNING: Old-style output transformations in Network.run() are deprecated.") + print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'") + print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.") + print() + assert output_transform is None + + new_kwargs = dict(dynamic_kwargs) + new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs} + new_transform["func"] = _legacy_output_transform_func + return new_transform, new_kwargs + +def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None): + if out_mul != 1.0: + expr = [x * out_mul for x in expr] + + if out_add != 0.0: + expr = [x + out_add for x in expr] + + if out_shrink > 1: + ksize = [1, 1, out_shrink, out_shrink] + expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr] + + if out_dtype is not None: + if tf.as_dtype(out_dtype).is_integer: + expr = [tf.round(x) for x in expr] + expr = [tf.saturate_cast(x, out_dtype) for x in expr] + return expr diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43cce37364064146fd30e18612b1d9e3a84f513a --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..0268f14395319003240b4a5a59141d703e9a4257 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.cu @@ -0,0 +1,220 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
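Before the CUDA sources that follow, one note on the legacy shim that closes network.py above: _handle_legacy_output_transforms rewrites old-style kwargs into the new output_transform form. A hedged sketch of the equivalence, reusing the hypothetical `Gs`, `latents`, and imports from the previous example:

# Old style (still accepted, but triggers the deprecation warning above):
images = Gs.run(latents, None, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)

# Roughly what _handle_legacy_output_transforms turns it into internally:
images = Gs.run(latents, None,
                output_transform=dict(func=_legacy_output_transform_func,
                                      out_mul=127.5, out_add=127.5, out_dtype=np.uint8))

# For outputs in [-1, 1], y = 127.5*x + 127.5 lands in [0, 255], which is why the
# warning recommends dict(func=tflib.convert_images_to_uint8) as the replacement.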
+ +#define EIGEN_USE_GPU +#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include + +using namespace tensorflow; +using namespace tensorflow::shape_inference; + +#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +struct FusedBiasActKernelParams +{ + const T* x; // [sizeX] + const T* b; // [sizeB] or NULL + const T* xref; // [sizeX] or NULL + const T* yref; // [sizeX] or NULL + T* y; // [sizeX] + + int grad; + int axis; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +template +static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams p) +{ + const float expRange = 80.0f; + const float halfExpRange = 40.0f; + const float seluScale = 1.0507009873554804934193349852946f; + const float seluAlpha = 1.6732632423543772848170429916717f; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load and apply bias. + float x = (float)p.x[xi]; + if (p.b) + x += (float)p.b[(xi / p.stepB) % p.sizeB]; + float xref = (p.xref) ? (float)p.xref[xi] : 0.0f; + float yref = (p.yref) ? (float)p.yref[xi] : 0.0f; + float yy = (p.gain != 0.0f) ? yref / p.gain : 0.0f; + + // Evaluate activation func. + float y; + switch (p.act * 10 + p.grad) + { + // linear + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0f; break; + + // relu + case 20: y = (x > 0.0f) ? x : 0.0f; break; + case 21: y = (yy > 0.0f) ? x : 0.0f; break; + case 22: y = 0.0f; break; + + // lrelu + case 30: y = (x > 0.0f) ? x : x * p.alpha; break; + case 31: y = (yy > 0.0f) ? x : x * p.alpha; break; + case 32: y = 0.0f; break; + + // tanh + case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break; + case 41: y = x * (1.0f - yy * yy); break; + case 42: y = x * (1.0f - yy * yy) * (-2.0f * yy); break; + + // sigmoid + case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break; + case 51: y = x * yy * (1.0f - yy); break; + case 52: y = x * yy * (1.0f - yy) * (1.0f - 2.0f * yy); break; + + // elu + case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break; + case 61: y = (yy >= 0.0f) ? x : x * (yy + 1.0f); break; + case 62: y = (yy >= 0.0f) ? 0.0f : x * (yy + 1.0f); break; + + // selu + case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break; + case 71: y = (yy >= 0.0f) ? x * seluScale : x * (yy + seluScale * seluAlpha); break; + case 72: y = (yy >= 0.0f) ? 0.0f : x * (yy + seluScale * seluAlpha); break; + + // softplus + case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break; + case 81: y = x * (1.0f - expf(-yy)); break; + case 82: { float c = expf(-yy); y = x * c * (1.0f - c); } break; + + // swish + case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break; + case 91: + case 92: + { + float c = expf(xref); + float d = c + 1.0f; + if (p.grad == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 
0.0f : x * c * (xref * (2.0f - d) + 2.0f * d) / (d * d * d); + yref = (xref < -expRange) ? 0.0f : xref / (expf(-xref) + 1.0f) * p.gain; + } + break; + } + + // Apply gain. + y *= p.gain; + + // Clamp. + if (p.clamp >= 0.0f) + { + if (p.grad == 0) + y = (fabsf(y) < p.clamp) ? y : (y >= 0.0f) ? p.clamp : -p.clamp; + else + y = (fabsf(yref) < p.clamp) ? y : 0.0f; + } + + // Store. + p.y[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// TensorFlow op. + +template +struct FusedBiasActOp : public OpKernel +{ + FusedBiasActKernelParams m_attribs; + + FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("clamp", &m_attribs.clamp)); + OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative")); + OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative")); + OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative")); + } + + void Compute(OpKernelContext* ctx) + { + FusedBiasActKernelParams p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + const Tensor& x = ctx->input(0); // [...] + const Tensor& b = ctx->input(1); // [sizeB] or [0] + const Tensor& xref = ctx->input(2); // x.shape or [0] + const Tensor& yref = ctx->input(3); // x.shape or [0] + p.x = x.flat().data(); + p.b = (b.NumElements()) ? b.flat().data() : NULL; + p.xref = (xref.NumElements()) ? xref.flat().data() : NULL; + p.yref = (yref.NumElements()) ? 
yref.flat().data() : NULL; + OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds")); + OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1")); + OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements")); + OP_REQUIRES(ctx, xref.NumElements() == 0 || xref.NumElements() == x.NumElements(), errors::InvalidArgument("xref has wrong number of elements")); + OP_REQUIRES(ctx, yref.NumElements() == 0 || yref.NumElements() == x.NumElements(), errors::InvalidArgument("yref has wrong number of elements")); + OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large")); + + p.sizeX = (int)x.NumElements(); + p.sizeB = (int)b.NumElements(); + p.stepB = 1; + for (int i = m_attribs.axis + 1; i < x.dims(); i++) + p.stepB *= (int)x.dim_size(i); + + Tensor* y = NULL; // x.shape + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); + p.y = y->flat().data(); + + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("FusedBiasAct") + .Input ("x: T") + .Input ("b: T") + .Input ("xref: T") + .Input ("yref: T") + .Output ("y: T") + .Attr ("T: {float, half}") + .Attr ("grad: int = 0") + .Attr ("axis: int = 1") + .Attr ("act: int = 0") + .Attr ("alpha: float = 0.0") + .Attr ("gain: float = 1.0") + .Attr ("clamp: float = -1.0"); +REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); +REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint("T"), FusedBiasActOp); + +//------------------------------------------------------------------------ diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..79991b0497d3d92f25194a31668b9568048163f8 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/fused_bias_act.py @@ -0,0 +1,211 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom TensorFlow ops for efficient bias and activation.""" + +import os +import numpy as np +import tensorflow as tf +from .. 
import custom_ops +from ...util import EasyDict + +def _get_plugin(): + return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True), + 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True), + 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True), + 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False), + 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False), + 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False), + 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False), + 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False), + 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False), +} + +#---------------------------------------------------------------------------- + +def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can have any shape, but if `b` is defined, the + dimension corresponding to `axis`, as well as the rank, must be known. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `axis`. + axis: The dimension in `x` corresponding to the elements of `b`. + The value of `axis` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying `1.0`. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. 
+ """ + + impl_dict = { + 'ref': _fused_bias_act_ref, + 'cuda': _fused_bias_act_cuda, + } + return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +def _fused_bias_act_ref(x, b, axis, act, alpha, gain, clamp): + """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops.""" + + # Validate arguments. + x = tf.convert_to_tensor(x) + b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype) + act_spec = activation_funcs[act] + assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) + assert b.shape[0] == 0 or 0 <= axis < x.shape.rank + if alpha is None: + alpha = act_spec.def_alpha + if gain is None: + gain = act_spec.def_gain + + # Add bias. + if b.shape[0] != 0: + x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)]) + + # Evaluate activation function. + x = act_spec.func(x, alpha=alpha) + + # Scale by gain. + if gain != 1: + x *= gain + + # Clamp. + if clamp is not None: + clamp = np.asarray(clamp, dtype=x.dtype.name) + assert clamp.shape == () and clamp >= 0 + x = tf.clip_by_value(x, -clamp, clamp) + return x + +#---------------------------------------------------------------------------- + +def _fused_bias_act_cuda(x, b, axis, act, alpha, gain, clamp): + """Fast CUDA implementation of `fused_bias_act()` using custom ops.""" + + # Validate arguments. + x = tf.convert_to_tensor(x) + empty_tensor = tf.constant([], dtype=x.dtype) + b = tf.convert_to_tensor(b) if b is not None else empty_tensor + act_spec = activation_funcs[act] + assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis]) + assert b.shape[0] == 0 or 0 <= axis < x.shape.rank + if alpha is None: + alpha = act_spec.def_alpha + if gain is None: + gain = act_spec.def_gain + + # Special cases. + if act == 'linear' and b is None and gain == 1.0: + return x + if act_spec.cuda_idx is None: + return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain, clamp=clamp) + + # CUDA op. + cuda_op = _get_plugin().fused_bias_act + cuda_kwargs = dict(axis=int(axis), act=int(act_spec.cuda_idx), gain=float(gain)) + if alpha is not None: + cuda_kwargs['alpha'] = float(alpha) + if clamp is not None: + clamp = np.asarray(clamp, dtype=x.dtype.name) + assert clamp.shape == () and clamp >= 0 + cuda_kwargs['clamp'] = float(clamp.astype(np.float32)) + def ref(tensor, name): + return tensor if act_spec.ref == name else empty_tensor + + # Forward pass: y = func(x, b). 
+ def func_y(x, b): + y = cuda_op(x=x, b=b, xref=empty_tensor, yref=empty_tensor, grad=0, **cuda_kwargs) + y.set_shape(x.shape) + return y + + # Backward pass: dx, db = grad(dy, x, y) + def grad_dx(dy, x, y): + dx = cuda_op(x=dy, b=empty_tensor, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs) + dx.set_shape(x.shape) + return dx + def grad_db(dx): + if b.shape[0] == 0: + return empty_tensor + db = dx + if axis < x.shape.rank - 1: + db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank))) + if axis > 0: + db = tf.reduce_sum(db, list(range(axis))) + db.set_shape(b.shape) + return db + + # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y) + def grad2_d_dy(d_dx, d_db, x, y): + d_dy = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=1, **cuda_kwargs) + d_dy.set_shape(x.shape) + return d_dy + def grad2_d_x(d_dx, d_db, x, y): + d_x = cuda_op(x=d_dx, b=d_db, xref=ref(x,'x'), yref=ref(y,'y'), grad=2, **cuda_kwargs) + d_x.set_shape(x.shape) + return d_x + + # Fast version for piecewise-linear activation funcs. + @tf.custom_gradient + def func_zero_2nd_grad(x, b): + y = func_y(x, b) + @tf.custom_gradient + def grad(dy): + dx = grad_dx(dy, x, y) + db = grad_db(dx) + def grad2(d_dx, d_db): + d_dy = grad2_d_dy(d_dx, d_db, x, y) + return d_dy + return (dx, db), grad2 + return y, grad + + # Slow version for general activation funcs. + @tf.custom_gradient + def func_nonzero_2nd_grad(x, b): + y = func_y(x, b) + def grad_wrap(dy): + @tf.custom_gradient + def grad_impl(dy, x): + dx = grad_dx(dy, x, y) + db = grad_db(dx) + def grad2(d_dx, d_db): + d_dy = grad2_d_dy(d_dx, d_db, x, y) + d_x = grad2_d_x(d_dx, d_db, x, y) + return d_dy, d_x + return (dx, db), grad2 + return grad_impl(dy, x) + return y, grad_wrap + + # Which version to use? + if act_spec.zero_2nd_grad: + return func_zero_2nd_grad(x, b) + return func_nonzero_2nd_grad(x, b) + +#---------------------------------------------------------------------------- diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..7aad60d53e57d4f3e60f36a24df80a6278f1bb63 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.cu @@ -0,0 +1,359 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#define EIGEN_USE_GPU +#define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include + +using namespace tensorflow; +using namespace tensorflow::shape_inference; + +//------------------------------------------------------------------------ +// Helpers. 
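Before the upfirdn_2d helpers below, a short hedged sketch of calling the fused_bias_act() wrapper that just closed. It uses impl='ref', so no CUDA plugin needs to be compiled; the tensor shapes and the import path (relative to the bundled dnnlib copy) are assumptions.

# Hedged sketch, TensorFlow 1.x style as used by this code base.
import tensorflow as tf
from dnnlib.tflib.ops.fused_bias_act import fused_bias_act

x = tf.placeholder(tf.float32, [None, 64, 32, 32])   # NCHW activations
b = tf.Variable(tf.zeros([64]))                       # one bias per channel (axis=1)
y = fused_bias_act(x, b, axis=1, act='lrelu', impl='ref')
# The reference path is equivalent to
#   tf.nn.leaky_relu(x + tf.reshape(b, [1, -1, 1, 1]), alpha=0.2) * sqrt(2)
# because 'lrelu' defaults to alpha=0.2 and def_gain=sqrt(2) in activation_funcs.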
+ +#define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false) + +static __host__ __device__ __forceinline__ int floorDiv(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// CUDA kernel params. + +template +struct UpFirDn2DKernelParams +{ + const T* x; // [majorDim, inH, inW, minorDim] + const T* k; // [kernelH, kernelW] + T* y; // [majorDim, outH, outW, minorDim] + + int upx; + int upy; + int downx; + int downy; + int padx0; + int padx1; + int pady0; + int pady1; + + int majorDim; + int inH; + int inW; + int minorDim; + int kernelH; + int kernelW; + int outH; + int outW; + int loopMajor; + int loopX; +}; + +//------------------------------------------------------------------------ +// General CUDA implementation for large filter kernels. + +template +static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams p) +{ + // Calculate thread index. + int minorIdx = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorIdx / p.minorDim; + minorIdx -= outY * p.minorDim; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorIdxBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim) + return; + + // Setup Y receptive field. + int midY = outY * p.downy + p.upy - 1 - p.pady0; + int inY = min(max(floorDiv(midY, p.upy), 0), p.inH); + int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY; + int kernelY = midY + p.kernelH - (inY + 1) * p.upy; + + // Loop over majorDim and outX. + for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++) + for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.downx + p.upx - 1 - p.padx0; + int inX = min(max(floorDiv(midX, p.upx), 0), p.inW); + int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX; + int kernelX = midX + p.kernelW - (inX + 1) * p.upx; + + // Initialize pointers. + const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx]; + const T* kp = &p.k[kernelY * p.kernelW + kernelX]; + int xpx = p.minorDim; + int kpx = -p.upx; + int xpy = p.inW * p.minorDim; + int kpy = -p.upy * p.kernelW; + + // Inner loop. + float v = 0.0f; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (float)(*xp) * (float)(*kp); + xp += xpx; + kp += kpx; + } + xp += xpy - w * xpx; + kp += kpy - w * kpx; + } + + // Store result. + p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v; + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filter kernels. + +template +static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams p) +{ + //assert(kernelW % upx == 0); + //assert(kernelH % upy == 0); + const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1; + __shared__ volatile float sk[kernelH][kernelW]; + __shared__ volatile float sx[tileInH][tileInW]; + + // Calculate tile index. 
+ int minorIdx = blockIdx.x; + int tileOutY = minorIdx / p.minorDim; + minorIdx -= tileOutY * p.minorDim; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorIdxBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim) + return; + + // Load filter kernel (flipped). + for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x) + { + int ky = tapIdx / kernelW; + int kx = tapIdx - ky * kernelW; + float v = 0.0f; + if (kx < p.kernelW & ky < p.kernelH) + v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)]; + sk[ky][kx] = v; + } + + // Loop over majorDim and outX. + for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++) + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.padx0; + int tileMidY = tileOutY * downy + upy - 1 - p.pady0; + int tileInX = floorDiv(tileMidX, upx); + int tileInY = floorDiv(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x) + { + int relInY = inIdx / tileInW; + int relInX = inIdx - relInY * tileInW; + int inX = relInX + tileInX; + int inY = relInY + tileInY; + float v = 0.0f; + if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH) + v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx]; + sx[relInY][relInX] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x) + { + int relOutY = outIdx / tileOutW; + int relOutX = outIdx - relOutY * tileOutW; + int outX = relOutX + tileOutX; + int outY = relOutY + tileOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floorDiv(midX, upx); + int inY = floorDiv(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int kernelX = (inX + 1) * upx - midX - 1; // flipped + int kernelY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + float v = 0.0f; + #pragma unroll + for (int y = 0; y < kernelH / upy; y++) + #pragma unroll + for (int x = 0; x < kernelW / upx; x++) + v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx]; + + // Store result. + if (outX < p.outW & outY < p.outH) + p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// TensorFlow op. 
+ +template +struct UpFirDn2DOp : public OpKernel +{ + UpFirDn2DKernelParams m_attribs; + + UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx) + { + memset(&m_attribs, 0, sizeof(m_attribs)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1)); + OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1")); + OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1")); + } + + void Compute(OpKernelContext* ctx) + { + UpFirDn2DKernelParams p = m_attribs; + cudaStream_t stream = ctx->eigen_device().stream(); + + const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim] + const Tensor& k = ctx->input(1); // [kernelH, kernelW] + p.x = x.flat().data(); + p.k = k.flat().data(); + OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4")); + OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2")); + OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large")); + OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large")); + + p.majorDim = (int)x.dim_size(0); + p.inH = (int)x.dim_size(1); + p.inW = (int)x.dim_size(2); + p.minorDim = (int)x.dim_size(3); + p.kernelH = (int)k.dim_size(0); + p.kernelW = (int)k.dim_size(1); + OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1")); + + p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx; + p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy; + OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1")); + + Tensor* y = NULL; // [majorDim, outH, outW, minorDim] + TensorShape ys; + ys.AddDim(p.majorDim); + ys.AddDim(p.outH); + ys.AddDim(p.outW); + ys.AddDim(p.minorDim); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y)); + p.y = y->flat().data(); + OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large")); + + // Choose CUDA kernel to use. 
+ void* cudaKernel = (void*)UpFirDn2DKernel_large; + int tileOutW = -1; + int tileOutH = -1; + + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + + if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 16; } + if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && 
p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 2 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 128; tileOutH = 8; } + if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + if (p.upx == 1 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 1 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 32; } + + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 24 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 20 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 16 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 12 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 1 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 64; tileOutH = 8; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 24) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy 
== 2 && p.kernelW <= 1 && p.kernelH <= 20) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 16) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 12) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 16; } + if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 2 && p.kernelW <= 1 && p.kernelH <= 8 ) { cudaKernel = (void*)UpFirDn2DKernel_small; tileOutW = 32; tileOutH = 16; } + + // Choose launch params. + dim3 blockSize; + dim3 gridSize; + if (tileOutW > 0 && tileOutH > 0) // small + { + p.loopMajor = (p.majorDim - 1) / 16384 + 1; + p.loopX = 1; + blockSize = dim3(32 * 8, 1, 1); + gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1); + } + else // large + { + p.loopMajor = (p.majorDim - 1) / 16384 + 1; + p.loopX = 4; + blockSize = dim3(4, 32, 1); + gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1); + } + + // Launch CUDA kernel. + void* args[] = {&p}; + OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream)); + } +}; + +REGISTER_OP("UpFirDn2D") + .Input ("x: T") + .Input ("k: T") + .Output ("y: T") + .Attr ("T: {float, half}") + .Attr ("upx: int = 1") + .Attr ("upy: int = 1") + .Attr ("downx: int = 1") + .Attr ("downy: int = 1") + .Attr ("padx0: int = 0") + .Attr ("padx1: int = 0") + .Attr ("pady0: int = 0") + .Attr ("pady1: int = 0"); +REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp); +REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint("T"), UpFirDn2DOp); + +//------------------------------------------------------------------------ diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..55a31af7e146da7afeb964db018f14aca3134920 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/ops/upfirdn_2d.py @@ -0,0 +1,418 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom TensorFlow ops for efficient resampling of 2D images.""" + +import os +import numpy as np +import tensorflow as tf +from .. import custom_ops + +def _get_plugin(): + return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu') + +#---------------------------------------------------------------------------- + +def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'): + r"""Pad, upsample, FIR filter, and downsample a batch of 2D images. + + Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]` + and performs the following operations for each image, batched across + `majorDim` and `minorDim`: + + 1. 
Upsample the image by inserting the zeros after each pixel (`upx`, `upy`). + + 2. Pad the image with zeros by the specified number of pixels on each side + (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value + corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the + image so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by throwing away pixels (`downx`, `downy`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + + Args: + x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`. + k: 2D FIR filter of the shape `[firH, firW]`. + upx: Integer upsampling factor along the X-axis (default: 1). + upy: Integer upsampling factor along the Y-axis (default: 1). + downx: Integer downsampling factor along the X-axis (default: 1). + downy: Integer downsampling factor along the Y-axis (default: 1). + padx0: Number of pixels to pad on the left side (default: 0). + padx1: Number of pixels to pad on the right side (default: 0). + pady0: Number of pixels to pad on the top side (default: 0). + pady1: Number of pixels to pad on the bottom side (default: 0). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`. + """ + + impl_dict = { + 'ref': _upfirdn_2d_ref, + 'cuda': _upfirdn_2d_cuda, + } + return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1) + +#---------------------------------------------------------------------------- + +def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): + """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops.""" + + x = tf.convert_to_tensor(x) + k = np.asarray(k, dtype=np.float32) + assert x.shape.rank == 4 + inH = x.shape[1].value + inW = x.shape[2].value + minorDim = _shape(x, 3) + kernelH, kernelW = k.shape + assert inW >= 1 and inH >= 1 + assert kernelW >= 1 and kernelH >= 1 + assert isinstance(upx, int) and isinstance(upy, int) + assert isinstance(downx, int) and isinstance(downy, int) + assert isinstance(padx0, int) and isinstance(padx1, int) + assert isinstance(pady0, int) and isinstance(pady1, int) + + # Upsample (insert zeros). + x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim]) + x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]]) + x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim]) + + # Pad (crop if negative). + x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]]) + x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :] + + # Convolve with filter. + x = tf.transpose(x, [0, 3, 1, 2]) + x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1]) + w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype) + x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW') + x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1]) + x = tf.transpose(x, [0, 2, 3, 1]) + + # Downsample (throw away pixels). 
+ return x[:, ::downy, ::downx, :] + +#---------------------------------------------------------------------------- + +def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): + """Fast CUDA implementation of `upfirdn_2d()` using custom ops.""" + + x = tf.convert_to_tensor(x) + k = np.asarray(k, dtype=np.float32) + majorDim, inH, inW, minorDim = x.shape.as_list() + kernelH, kernelW = k.shape + assert inW >= 1 and inH >= 1 + assert kernelW >= 1 and kernelH >= 1 + assert isinstance(upx, int) and isinstance(upy, int) + assert isinstance(downx, int) and isinstance(downy, int) + assert isinstance(padx0, int) and isinstance(padx1, int) + assert isinstance(pady0, int) and isinstance(pady1, int) + + outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1 + outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1 + assert outW >= 1 and outH >= 1 + + cuda_op = _get_plugin().up_fir_dn2d + kc = tf.constant(k, dtype=x.dtype) + gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype) + gpadx0 = kernelW - padx0 - 1 + gpady0 = kernelH - pady0 - 1 + gpadx1 = inW * upx - outW * downx + padx0 - upx + 1 + gpady1 = inH * upy - outH * downy + pady0 - upy + 1 + + @tf.custom_gradient + def func(x): + y = cuda_op(x=x, k=kc, upx=int(upx), upy=int(upy), downx=int(downx), downy=int(downy), padx0=int(padx0), padx1=int(padx1), pady0=int(pady0), pady1=int(pady1)) + y.set_shape([majorDim, outH, outW, minorDim]) + @tf.custom_gradient + def grad(dy): + dx = cuda_op(x=dy, k=gkc, upx=int(downx), upy=int(downy), downx=int(upx), downy=int(upy), padx0=int(gpadx0), padx1=int(gpadx1), pady0=int(gpady0), pady1=int(gpady1)) + dx.set_shape([majorDim, inH, inW, minorDim]) + return dx, func + return y, grad + return func(x) + +#---------------------------------------------------------------------------- + +def filter_2d(x, k, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Filter a batch of 2D images with the given FIR filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and filters each image with the given filter. The filter is normalized so that + if the input pixels are constant, they will be scaled by the specified `gain`. + Pixels outside the image are assumed to be zero. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + + assert isinstance(padding, int) + k = _FilterKernel(k=k, gain=gain) + assert k.w == k.h + pad0 = k.w // 2 + padding + pad1 = (k.w - 1) // 2 + padding + return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Upsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and upsamples each image with the given filter. The filter is normalized so that + if the input pixels are constant, they will be scaled by the specified `gain`. 
+ Pixels outside the image are assumed to be zero, and the filter is padded with + zeros so that its shape is a multiple of the upsampling factor. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to nearest-neighbor + upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2)) + assert k.w == k.h + pad0 = (k.w + factor - 1) // 2 + padding + pad1 = (k.w - factor) // 2 + padding + return _simple_upfirdn_2d(x, k, up=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample_2d(x, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Downsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and downsamples each image with the given filter. The filter is normalized so that + if the input pixels are constant, they will be scaled by the specified `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded with + zeros so that its shape is a multiple of the downsampling factor. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + k = _FilterKernel(k if k is not None else [1] * factor, gain) + assert k.w == k.h + pad0 = (k.w - factor + 1) // 2 + padding * factor + pad1 = (k.w - factor) // 2 + padding * factor + return _simple_upfirdn_2d(x, k, down=factor, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. 
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to nearest-neighbor + upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + + # Check weight shape. + w = tf.convert_to_tensor(w) + ch, cw, _inC, _outC = w.shape.as_list() + inC = _shape(w, 2) + outC = _shape(w, 3) + assert cw == ch + + # Fast path for 1x1 convolution. + if cw == 1 and ch == 1: + x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID') + x = upsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl) + return x + + # Setup filter kernel. + k = _FilterKernel(k if k is not None else [1] * factor, gain * (factor ** 2)) + assert k.w == k.h + + # Determine data dimensions. + if data_format == 'NCHW': + stride = [1, 1, factor, factor] + output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + ch, (_shape(x, 3) - 1) * factor + cw] + num_groups = _shape(x, 1) // inC + else: + stride = [1, factor, factor, 1] + output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + ch, (_shape(x, 2) - 1) * factor + cw, outC] + num_groups = _shape(x, 3) // inC + + # Transpose weights. + w = tf.reshape(w, [ch, cw, inC, num_groups, -1]) + w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2]) + w = tf.reshape(w, [ch, cw, -1, num_groups * inC]) + + # Execute. + x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format) + pad0 = (k.w + factor - cw) // 2 + padding + pad1 = (k.w - factor - cw + 3) // 2 + padding + return _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + +#---------------------------------------------------------------------------- + +def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0, data_format='NCHW', impl='cuda'): + r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. + Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). + The default is `[1] * factor`, which corresponds to average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + padding: Number of pixels to pad or crop the output on each side (default: 0). + data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + assert isinstance(padding, int) + + # Check weight shape. + w = tf.convert_to_tensor(w) + ch, cw, _inC, _outC = w.shape.as_list() + assert cw == ch + + # Fast path for 1x1 convolution. + if cw == 1 and ch == 1: + x = downsample_2d(x, k, factor=factor, gain=gain, padding=padding, data_format=data_format, impl=impl) + x = tf.nn.conv2d(x, w, data_format=data_format, strides=[1,1,1,1], padding='VALID') + return x + + # Setup filter kernel. + k = _FilterKernel(k if k is not None else [1] * factor, gain) + assert k.w == k.h + + # Determine stride. + if data_format == 'NCHW': + s = [1, 1, factor, factor] + else: + s = [1, factor, factor, 1] + + # Execute. + pad0 = (k.w - factor + cw) // 2 + padding * factor + pad1 = (k.w - factor + cw - 1) // 2 + padding * factor + x = _simple_upfirdn_2d(x, k, pad0=pad0, pad1=pad1, data_format=data_format, impl=impl) + return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format) + +#---------------------------------------------------------------------------- +# Internal helpers. + +class _FilterKernel: + def __init__(self, k, gain=1): + k = np.asarray(k, dtype=np.float32) + k /= np.sum(k) + + # Separable. + if k.ndim == 1 and k.size >= 8: + self.w = k.size + self.h = k.size + self.kx = k[np.newaxis, :] + self.ky = k[:, np.newaxis] * gain + self.kxy = None + + # Non-separable. + else: + if k.ndim == 1: + k = np.outer(k, k) + assert k.ndim == 2 + self.w = k.shape[1] + self.h = k.shape[0] + self.kx = None + self.ky = None + self.kxy = k * gain + +def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'): + assert isinstance(k, _FilterKernel) + assert data_format in ['NCHW', 'NHWC'] + assert x.shape.rank == 4 + y = x + if data_format == 'NCHW': + y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1]) + if k.kx is not None: + y = upfirdn_2d(y, k.kx, upx=up, downx=down, padx0=pad0, padx1=pad1, impl=impl) + if k.ky is not None: + y = upfirdn_2d(y, k.ky, upy=up, downy=down, pady0=pad0, pady1=pad1, impl=impl) + if k.kxy is not None: + y = upfirdn_2d(y, k.kxy, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl) + if data_format == 'NCHW': + y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)]) + return y + +def _shape(tf_expr, dim_idx): + if tf_expr.shape.rank is not None: + dim = tf_expr.shape[dim_idx].value + if dim is not None: + return dim + return tf.shape(tf_expr)[dim_idx] + +#---------------------------------------------------------------------------- diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py b/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cae5ffff3d11aaccd705d6936e080175ab97dd0e --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/optimizer.py @@ -0,0 +1,372 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
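For orientation, the `Optimizer` wrapper defined in this file is typically driven as sketched below. This is a minimal, hypothetical example and is not part of optimizer.py: it assumes TensorFlow 1.x in graph mode, that `dnnlib` is importable from the working directory, and that `dnnlib.tflib` re-exports `init_tf`/`run` as it does elsewhere in this diff; the names `x`, `w`, `loss`, `'TrainToy'` and the toy data are illustrative only.

```python
import numpy as np
import tensorflow as tf
from dnnlib import tflib
from dnnlib.tflib.optimizer import Optimizer

tflib.init_tf()  # create the default TF session (see tfutil.py further down)

# Toy graph: a single trainable matrix and a scalar loss (illustrative only).
x = tf.placeholder(tf.float32, [None, 16], name='x')
w = tf.get_variable('w', shape=[16, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

opt = Optimizer(name='TrainToy', learning_rate=0.001)
opt.register_gradients(loss, [w])   # in multi-GPU training this is called once per GPU
train_op = opt.apply_updates()      # single op that sums, scales, and applies gradients

tflib.run(tf.global_variables_initializer())
tflib.run(train_op, {x: np.random.randn(8, 16).astype(np.float32)})
```

The wrapper handles gradient averaging, accumulation, and loss scaling internally, so the training loop only ever touches `register_gradients()` and the returned `train_op`.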
+ +"""Helper wrapper for a Tensorflow optimizer.""" + +import platform +import numpy as np +import tensorflow as tf + +from collections import OrderedDict +from typing import List, Union + +from . import autosummary +from . import tfutil +from .. import util + +from .tfutil import TfExpression, TfExpressionEx + +_collective_ops_warning_printed = False +_collective_ops_group_key = 831766147 +_collective_ops_instance_key = 436340067 + +class Optimizer: + """A Wrapper for tf.train.Optimizer. + + Automatically takes care of: + - Gradient averaging for multi-GPU training. + - Gradient accumulation for arbitrarily large minibatches. + - Dynamic loss scaling and typecasts for FP16 training. + - Ignoring corrupted gradients that contain NaNs/Infs. + - Reporting statistics. + - Well-chosen default settings. + """ + + def __init__(self, + name: str = "Train", # Name string that will appear in TensorFlow graph. + tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class. + learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time. + minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients. + share: "Optimizer" = None, # Share internal state with a previously created optimizer? + use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training? + loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor. + loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow. + loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow. + report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard? + **kwargs): + + # Public fields. + self.name = name + self.learning_rate = learning_rate + self.minibatch_multiplier = minibatch_multiplier + self.id = self.name.replace("/", ".") + self.scope = tf.get_default_graph().unique_name(self.id) + self.optimizer_class = util.get_obj_by_name(tf_optimizer) + self.optimizer_kwargs = dict(kwargs) + self.use_loss_scaling = use_loss_scaling + self.loss_scaling_init = loss_scaling_init + self.loss_scaling_inc = loss_scaling_inc + self.loss_scaling_dec = loss_scaling_dec + + # Private fields. + self._updates_applied = False + self._devices = OrderedDict() # device_name => EasyDict() + self._shared_optimizers = OrderedDict() # device_name => optimizer_class + self._gradient_shapes = None # [shape, ...] + self._report_mem_usage = report_mem_usage + + # Validate arguments. + assert callable(self.optimizer_class) + + # Share internal state if requested. + if share is not None: + assert isinstance(share, Optimizer) + assert self.optimizer_class is share.optimizer_class + assert self.learning_rate is share.learning_rate + assert self.optimizer_kwargs == share.optimizer_kwargs + self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access + + def _get_device(self, device_name: str): + """Get internal state for the given TensorFlow device.""" + tfutil.assert_tf_initialized() + if device_name in self._devices: + return self._devices[device_name] + + # Initialize fields. + device = util.EasyDict() + device.name = device_name + device.optimizer = None # Underlying optimizer: optimizer_class + device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable + device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...] 
+ device.grad_clean = OrderedDict() # Clean gradients: var => grad + device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable + device.grad_acc_count = None # Accumulation counter: tf.Variable + device.grad_acc = OrderedDict() # Accumulated gradients: var => grad + + # Setup TensorFlow objects. + with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None): + if device_name not in self._shared_optimizers: + optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers) + self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) + device.optimizer = self._shared_optimizers[device_name] + if self.use_loss_scaling: + device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var") + + # Register device. + self._devices[device_name] = device + return device + + def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: + """Register the gradients of the given loss function with respect to the given variables. + Intended to be called once per GPU.""" + tfutil.assert_tf_initialized() + assert not self._updates_applied + device = self._get_device(loss.device) + + # Validate trainables. + if isinstance(trainable_vars, dict): + trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars + assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 + assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) + assert all(var.device == device.name for var in trainable_vars) + + # Validate shapes. + if self._gradient_shapes is None: + self._gradient_shapes = [var.shape.as_list() for var in trainable_vars] + assert len(trainable_vars) == len(self._gradient_shapes) + assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes)) + + # Report memory usage if requested. + deps = [loss] + if self._report_mem_usage: + self._report_mem_usage = False + try: + with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]): + deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30)) + except tf.errors.NotFoundError: + pass + + # Compute gradients. + with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps): + loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) + gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage + grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate) + + # Register gradients. + for grad, var in grad_list: + if var not in device.grad_raw: + device.grad_raw[var] = [] + device.grad_raw[var].append(grad) + + def apply_updates(self, allow_no_op: bool = False) -> tf.Operation: + """Construct training op to update the registered variables based on their gradients.""" + tfutil.assert_tf_initialized() + assert not self._updates_applied + self._updates_applied = True + all_ops = [] + + # Check for no-op. + if allow_no_op and len(self._devices) == 0: + with tfutil.absolute_name_scope(self.scope): + return tf.no_op(name='TrainingOp') + + # Clean up gradients. 
+ for device_idx, device in enumerate(self._devices.values()): + with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name): + for var, grad in device.grad_raw.items(): + + # Filter out disconnected gradients and convert to float32. + grad = [g for g in grad if g is not None] + grad = [tf.cast(g, tf.float32) for g in grad] + + # Sum within the device. + if len(grad) == 0: + grad = tf.zeros(var.shape) # No gradients => zero. + elif len(grad) == 1: + grad = grad[0] # Single gradient => use as is. + else: + grad = tf.add_n(grad) # Multiple gradients => sum. + + # Scale as needed. + scale = 1.0 / len(device.grad_raw[var]) / len(self._devices) + scale = tf.constant(scale, dtype=tf.float32, name="scale") + if self.minibatch_multiplier is not None: + scale /= tf.cast(self.minibatch_multiplier, tf.float32) + scale = self.undo_loss_scaling(scale) + device.grad_clean[var] = grad * scale + + # Sum gradients across devices. + if len(self._devices) > 1: + with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None): + if platform.system() == "Windows": # Windows => NCCL ops are not available. + self._broadcast_fallback() + elif tf.VERSION.startswith("1.15."): # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539 + self._broadcast_fallback() + else: # Otherwise => NCCL ops are safe to use. + self._broadcast_nccl() + + # Apply updates separately on each device. + for device_idx, device in enumerate(self._devices.values()): + with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name): + # pylint: disable=cell-var-from-loop + + # Accumulate gradients over time. + if self.minibatch_multiplier is None: + acc_ok = tf.constant(True, name='acc_ok') + device.grad_acc = OrderedDict(device.grad_clean) + else: + # Create variables. + with tf.control_dependencies(None): + for var in device.grad_clean.keys(): + device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var") + device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count") + + # Track counter. + count_cur = device.grad_acc_count + 1.0 + count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur) + count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([])) + acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32)) + all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op)) + + # Track gradients. + for var, grad in device.grad_clean.items(): + acc_var = device.grad_acc_vars[var] + acc_cur = acc_var + grad + device.grad_acc[var] = acc_cur + with tf.control_dependencies([acc_cur]): + acc_inc_op = lambda: tf.assign(acc_var, acc_cur) + acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape)) + all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op)) + + # No overflow => apply gradients. + all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()])) + apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()]) + all_ops.append(tf.cond(all_ok, apply_op, tf.no_op)) + + # Adjust loss scaling. 
+ if self.use_loss_scaling: + ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc) + ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec) + ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op)) + all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op)) + + # Last device => report statistics. + if device_idx == len(self._devices) - 1: + all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate))) + all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok)) + if self.use_loss_scaling: + all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var)) + + # Initialize variables. + self.reset_optimizer_state() + if self.use_loss_scaling: + tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()]) + if self.minibatch_multiplier is not None: + tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]]) + + # Group everything into a single op. + with tfutil.absolute_name_scope(self.scope): + return tf.group(*all_ops, name="TrainingOp") + + def reset_optimizer_state(self) -> None: + """Reset internal state of the underlying optimizer.""" + tfutil.assert_tf_initialized() + tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()]) + + def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: + """Get or create variable representing log2 of the current dynamic loss scaling factor.""" + return self._get_device(device).loss_scaling_var + + def apply_loss_scaling(self, value: TfExpression) -> TfExpression: + """Apply dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + if not self.use_loss_scaling: + return value + return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) + + def undo_loss_scaling(self, value: TfExpression) -> TfExpression: + """Undo the effect of dynamic loss scaling for the given expression.""" + assert tfutil.is_tf_expression(value) + if not self.use_loss_scaling: + return value + return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type + + def _broadcast_nccl(self): + """Sum gradients across devices using NCCL ops (fast path).""" + from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module + for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]): + if any(x.shape.num_elements() > 0 for x in all_vars): + all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)] + all_grads = nccl_ops.all_sum(all_grads) + for device, var, grad in zip(self._devices.values(), all_vars, all_grads): + device.grad_clean[var] = grad + + def _broadcast_fallback(self): + """Sum gradients across devices using TensorFlow collective ops (slow fallback path).""" + from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module + global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key + if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()): + return + if not _collective_ops_warning_printed: + print("------------------------------------------------------------------------") + print("WARNING: Using slow fallback implementation 
for inter-GPU communication.") + print("Please use TensorFlow 1.14 on Linux for optimal training performance.") + print("------------------------------------------------------------------------") + _collective_ops_warning_printed = True + for device in self._devices.values(): + with tf.device(device.name): + combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()] + combo = tf.concat(combo, axis=0) + combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id', + group_size=len(self._devices), group_key=_collective_ops_group_key, + instance_key=_collective_ops_instance_key) + cur_ofs = 0 + for var, grad_old in device.grad_clean.items(): + grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape) + cur_ofs += grad_old.shape.num_elements() + device.grad_clean[var] = grad_new + _collective_ops_instance_key += 1 + + +class SimpleAdam: + """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer.""" + + def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8): + self.name = name + self.learning_rate = learning_rate + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.all_state_vars = [] + + def variables(self): + return self.all_state_vars + + def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE): + assert gate_gradients == tf.train.Optimizer.GATE_NONE + return list(zip(tf.gradients(loss, var_list), var_list)) + + def apply_gradients(self, grads_and_vars): + with tf.name_scope(self.name): + state_vars = [] + update_ops = [] + + # Adjust learning rate to deal with startup bias. + with tf.control_dependencies(None): + b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False) + b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False) + state_vars += [b1pow_var, b2pow_var] + b1pow_new = b1pow_var * self.beta1 + b2pow_new = b2pow_var * self.beta2 + update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)] + lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new) + + # Construct ops to update each variable. + for grad, var in grads_and_vars: + with tf.control_dependencies(None): + m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False) + v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False) + state_vars += [m_var, v_var] + m_new = self.beta1 * m_var + (1 - self.beta1) * grad + v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad) + var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon) + update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)] + + # Group everything together. + self.all_state_vars += state_vars + return tf.group(*update_ops) diff --git a/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py b/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py new file mode 100644 index 0000000000000000000000000000000000000000..fe21100299251492ee6d49a7fab566ffb8283702 --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/tflib/tfutil.py @@ -0,0 +1,262 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous helper utils for Tensorflow.""" + +import os +import numpy as np +import tensorflow as tf + +# Silence deprecation warnings from TensorFlow 1.13 onwards +import logging +logging.getLogger('tensorflow').setLevel(logging.ERROR) +import tensorflow.contrib # requires TensorFlow 1.x! +tf.contrib = tensorflow.contrib + +from typing import Any, Iterable, List, Union + +TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] +"""A type that represents a valid Tensorflow expression.""" + +TfExpressionEx = Union[TfExpression, int, float, np.ndarray] +"""A type that can be converted to a valid Tensorflow expression.""" + + +def run(*args, **kwargs) -> Any: + """Run the specified ops in the default session.""" + assert_tf_initialized() + return tf.get_default_session().run(*args, **kwargs) + + +def is_tf_expression(x: Any) -> bool: + """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" + return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) + + +def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: + """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code.""" + return [dim.value for dim in shape] + + +def flatten(x: TfExpressionEx) -> TfExpression: + """Shortcut function for flattening a tensor.""" + with tf.name_scope("Flatten"): + return tf.reshape(x, [-1]) + + +def log2(x: TfExpressionEx) -> TfExpression: + """Logarithm in base 2.""" + with tf.name_scope("Log2"): + return tf.log(x) * np.float32(1.0 / np.log(2.0)) + + +def exp2(x: TfExpressionEx) -> TfExpression: + """Exponent in base 2.""" + with tf.name_scope("Exp2"): + return tf.exp(x * np.float32(np.log(2.0))) + + +def erfinv(y: TfExpressionEx) -> TfExpression: + """Inverse of the error function.""" + # pylint: disable=no-name-in-module + from tensorflow.python.ops.distributions import special_math + return special_math.erfinv(y) + + +def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: + """Linear interpolation.""" + with tf.name_scope("Lerp"): + return a + (b - a) * t + + +def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: + """Linear interpolation with clip.""" + with tf.name_scope("LerpClip"): + return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) + + +def absolute_name_scope(scope: str) -> tf.name_scope: + """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" + return tf.name_scope(scope + "/") + + +def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: + """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" + return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) + + +def _sanitize_tf_config(config_dict: dict = None) -> dict: + # Defaults. + cfg = dict() + cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. + cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. + cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 
+ cfg["env.HDF5_USE_FILE_LOCKING"] = "FALSE" # Disable HDF5 file locking to avoid concurrency issues with network shares. + cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. + cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. + + # Remove defaults for environment variables that are already set. + for key in list(cfg): + fields = key.split(".") + if fields[0] == "env": + assert len(fields) == 2 + if fields[1] in os.environ: + del cfg[key] + + # User overrides. + if config_dict is not None: + cfg.update(config_dict) + return cfg + + +def init_tf(config_dict: dict = None) -> None: + """Initialize TensorFlow session using good default settings.""" + # Skip if already initialized. + if tf.get_default_session() is not None: + return + + # Setup config dict and random seeds. + cfg = _sanitize_tf_config(config_dict) + np_random_seed = cfg["rnd.np_random_seed"] + if np_random_seed is not None: + np.random.seed(np_random_seed) + tf_random_seed = cfg["rnd.tf_random_seed"] + if tf_random_seed == "auto": + tf_random_seed = np.random.randint(1 << 31) + if tf_random_seed is not None: + tf.set_random_seed(tf_random_seed) + + # Setup environment variables. + for key, value in cfg.items(): + fields = key.split(".") + if fields[0] == "env": + assert len(fields) == 2 + os.environ[fields[1]] = str(value) + + # Create default TensorFlow session. + create_session(cfg, force_as_default=True) + + +def assert_tf_initialized(): + """Check that TensorFlow session has been initialized.""" + if tf.get_default_session() is None: + raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") + + +def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: + """Create tf.Session based on config dict.""" + # Setup TensorFlow config proto. + cfg = _sanitize_tf_config(config_dict) + config_proto = tf.ConfigProto() + for key, value in cfg.items(): + fields = key.split(".") + if fields[0] not in ["rnd", "env"]: + obj = config_proto + for field in fields[:-1]: + obj = getattr(obj, field) + setattr(obj, fields[-1], value) + + # Create session. + session = tf.Session(config=config_proto) + if force_as_default: + # pylint: disable=protected-access + session._default_session = session.as_default() + session._default_session.enforce_nesting = False + session._default_session.__enter__() + return session + + +def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: + """Initialize all tf.Variables that have not already been initialized. + + Equivalent to the following, but more efficient and does not bloat the tf graph: + tf.variables_initializer(tf.report_uninitialized_variables()).run() + """ + assert_tf_initialized() + if target_vars is None: + target_vars = tf.global_variables() + + test_vars = [] + test_ops = [] + + with tf.control_dependencies(None): # ignore surrounding control_dependencies + for var in target_vars: + assert is_tf_expression(var) + + try: + tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) + except KeyError: + # Op does not exist => variable may be uninitialized. 
+ test_vars.append(var) + + with absolute_name_scope(var.name.split(":")[0]): + test_ops.append(tf.is_variable_initialized(var)) + + init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] + run([var.initializer for var in init_vars]) + + +def set_vars(var_to_value_dict: dict) -> None: + """Set the values of given tf.Variables. + + Equivalent to the following, but more efficient and does not bloat the tf graph: + tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] + """ + assert_tf_initialized() + ops = [] + feed_dict = {} + + for var, value in var_to_value_dict.items(): + assert is_tf_expression(var) + + try: + setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op + except KeyError: + with absolute_name_scope(var.name.split(":")[0]): + with tf.control_dependencies(None): # ignore surrounding control_dependencies + setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter + + ops.append(setter) + feed_dict[setter.op.inputs[1]] = value + + run(ops, feed_dict) + + +def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): + """Create tf.Variable with large initial value without bloating the tf graph.""" + assert_tf_initialized() + assert isinstance(initial_value, np.ndarray) + zeros = tf.zeros(initial_value.shape, initial_value.dtype) + var = tf.Variable(zeros, *args, **kwargs) + set_vars({var: initial_value}) + return var + + +def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): + """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. + Can be used as an input transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if nhwc_to_nchw: + images = tf.transpose(images, [0, 3, 1, 2]) + return images * ((drange[1] - drange[0]) / 255) + drange[0] + + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. + Can be used as an output transformation for Network.run(). + """ + images = tf.cast(images, tf.float32) + if shrink > 1: + ksize = [1, 1, shrink, shrink] + images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") + if nchw_to_nhwc: + images = tf.transpose(images, [0, 2, 3, 1]) + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + return tf.saturate_cast(images, tf.uint8) diff --git a/models/StyleCLIP/global_directions/dnnlib/util.py b/models/StyleCLIP/global_directions/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..0c35b8923bb27bcd91fd0c14234480067138a3fc --- /dev/null +++ b/models/StyleCLIP/global_directions/dnnlib/util.py @@ -0,0 +1,472 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
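For orientation, the helpers defined in this file are typically used as sketched below. This is a hypothetical snippet, not part of util.py: it assumes `dnnlib` is importable from the working directory, and the URL and values are placeholders.

```python
from dnnlib import util

# EasyDict: a dict whose keys can also be accessed as attributes.
cfg = util.EasyDict(batch_size=8, learning_rate=0.001)
cfg.learning_rate = 0.002          # equivalent to cfg['learning_rate'] = 0.002

# format_time: seconds -> human-readable duration string.
print(util.format_time(3723))      # -> '1h 02m 03s'

# open_url: fetch a (placeholder) URL, cache it under the dnnlib cache dir,
# and return a binary file object for reading.
with util.open_url('https://example.com/checkpoint.pkl', cache=True) as f:
    blob = f.read()
```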
+ +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: str) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m 
{1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? 
+ for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. 
+ Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." 
% url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/models/StyleCLIP/global_directions/manipulate.py b/models/StyleCLIP/global_directions/manipulate.py new file mode 100644 index 0000000000000000000000000000000000000000..e1a2480caad8016fea0c06f0bfe521b25f084436 --- /dev/null +++ b/models/StyleCLIP/global_directions/manipulate.py @@ -0,0 +1,278 @@ + + +import os +import os.path +import pickle +import numpy as np +import tensorflow as tf +from dnnlib import tflib +from global_directions.utils.visualizer import HtmlPageVisualizer + + +def Vis(bname,suffix,out,rownames=None,colnames=None): + num_images=out.shape[0] + step=out.shape[1] + + if colnames is None: + colnames=[f'Step {i:02d}' for i in range(1, step + 1)] + if rownames is None: + rownames=[str(i) for i in range(num_images)] + + + visualizer = HtmlPageVisualizer( + num_rows=num_images, num_cols=step + 1, viz_size=256) + visualizer.set_headers( + ['Name'] +colnames) + + for i in range(num_images): + visualizer.set_cell(i, 0, text=rownames[i]) + + for i in range(num_images): + for k in range(step): + image=out[i,k,:,:,:] + visualizer.set_cell(i, 1+k, image=image) + + # Save results. + visualizer.save(f'./html/'+bname+'_'+suffix+'.html') + + + + +def LoadData(img_path): + tmp=img_path+'S' + with open(tmp, "rb") as fp: #Pickling + s_names,all_s=pickle.load( fp) + dlatents=all_s + + pindexs=[] + mindexs=[] + for i in range(len(s_names)): + name=s_names[i] + if not('ToRGB' in name): + mindexs.append(i) + else: + pindexs.append(i) + + tmp=img_path+'S_mean_std' + with open(tmp, "rb") as fp: #Pickling + m,std=pickle.load( fp) + + return dlatents,s_names,mindexs,pindexs,m,std + + +def LoadModel(model_path,model_name): + # Initialize TensorFlow. + tflib.init_tf() + tmp=os.path.join(model_path,model_name) + with open(tmp, 'rb') as f: + _, _, Gs = pickle.load(f) + Gs.print_layers() + return Gs + +def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False): + """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 
+ Can be used as an output transformation for Network.run(). + """ + if nchw_to_nhwc: + images = np.transpose(images, [0, 2, 3, 1]) + + scale = 255 / (drange[1] - drange[0]) + images = images * scale + (0.5 - drange[0] * scale) + + np.clip(images, 0, 255, out=images) + images=images.astype('uint8') + return images + + +def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): + """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. + Can be used as an input transformation for Network.run(). + """ + if nhwc_to_nchw: + images=np.rollaxis(images, 3, 1) + return images/ 255 *(drange[1] - drange[0])+ drange[0] + + +class Manipulator(): + def __init__(self,dataset_name='ffhq'): + self.file_path='./' + self.img_path=self.file_path+'npy/'+dataset_name+'/' + self.model_path=self.file_path+'model/' + self.dataset_name=dataset_name + self.model_name=dataset_name+'.pkl' + + self.alpha=[0] #manipulation strength + self.num_images=10 + self.img_index=0 #which image to start + self.viz_size=256 + self.manipulate_layers=None #which layer to manipulate, list + + self.dlatents,self.s_names,self.mindexs,self.pindexs,self.code_mean,self.code_std=LoadData(self.img_path) + + self.sess=tf.InteractiveSession() + init = tf.global_variables_initializer() + self.sess.run(init) + self.Gs=LoadModel(self.model_path,self.model_name) + self.num_layers=len(self.dlatents) + + self.Vis=Vis + self.noise_constant={} + + for i in range(len(self.s_names)): + tmp1=self.s_names[i].split('/') + if not 'ToRGB' in tmp1: + tmp1[-1]='random_normal:0' + size=int(tmp1[1].split('x')[0]) + tmp1='/'.join(tmp1) + tmp=(1,1,size,size) + self.noise_constant[tmp1]=np.random.random(tmp) + + tmp=self.Gs.components.synthesis.input_shape[1] + d={} + d['G_synthesis_1/dlatents_in:0']=np.zeros([1,tmp,512]) + names=list(self.noise_constant.keys()) + tmp=tflib.run(names,d) + for i in range(len(names)): + self.noise_constant[names[i]]=tmp[i] + + self.fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) + self.img_size=self.Gs.output_shape[-1] + + def GenerateImg(self,codes): + + + num_images,step=codes[0].shape[:2] + + + out=np.zeros((num_images,step,self.img_size,self.img_size,3),dtype='uint8') + for i in range(num_images): + for k in range(step): + d={} + for m in range(len(self.s_names)): + d[self.s_names[m]]=codes[m][i,k][None,:] #need to change + d['G_synthesis_1/4x4/Const/Shape:0']=np.array([1,18, 512], dtype=np.int32) + d.update(self.noise_constant) + img=tflib.run('G_synthesis_1/images_out:0', d) + image=convert_images_to_uint8(img, nchw_to_nhwc=True) + out[i,k,:,:,:]=image[0] + return out + + + + def MSCode(self,dlatent_tmp,boundary_tmp): + + step=len(self.alpha) + dlatent_tmp1=[tmp.reshape((self.num_images,-1)) for tmp in dlatent_tmp] + dlatent_tmp2=[np.tile(tmp[:,None],(1,step,1)) for tmp in dlatent_tmp1] # (10, 7, 512) + + l=np.array(self.alpha) + l=l.reshape( + [step if axis == 1 else 1 for axis in range(dlatent_tmp2[0].ndim)]) + + if type(self.manipulate_layers)==int: + tmp=[self.manipulate_layers] + elif type(self.manipulate_layers)==list: + tmp=self.manipulate_layers + elif self.manipulate_layers is None: + tmp=np.arange(len(boundary_tmp)) + else: + raise ValueError('manipulate_layers is wrong') + + for i in tmp: + dlatent_tmp2[i]+=l*boundary_tmp[i] + + codes=[] + for i in range(len(dlatent_tmp2)): + tmp=list(dlatent_tmp[i].shape) + tmp.insert(1,step) + codes.append(dlatent_tmp2[i].reshape(tmp)) + return codes + + + def EditOne(self,bname,dlatent_tmp=None): + if 
dlatent_tmp==None: + dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] + + boundary_tmp=[] + for i in range(len(self.boundary)): + tmp=self.boundary[i] + if len(tmp)<=bname: + boundary_tmp.append([]) + else: + boundary_tmp.append(tmp[bname]) + + codes=self.MSCode(dlatent_tmp,boundary_tmp) + + out=self.GenerateImg(codes) + return codes,out + + def EditOneC(self,cindex,dlatent_tmp=None): + if dlatent_tmp==None: + dlatent_tmp=[tmp[self.img_index:(self.img_index+self.num_images)] for tmp in self.dlatents] + + boundary_tmp=[[] for i in range(len(self.dlatents))] + + #'only manipulate 1 layer and one channel' + assert len(self.manipulate_layers)==1 + + ml=self.manipulate_layers[0] + tmp=dlatent_tmp[ml].shape[1] #ada + tmp1=np.zeros(tmp) + tmp1[cindex]=self.code_std[ml][cindex] #1 + boundary_tmp[ml]=tmp1 + + codes=self.MSCode(dlatent_tmp,boundary_tmp) + out=self.GenerateImg(codes) + return codes,out + + + def W2S(self,dlatent_tmp): + + all_s = self.sess.run( + self.s_names, + feed_dict={'G_synthesis_1/dlatents_in:0': dlatent_tmp}) + return all_s + + + + + + + + +#%% +if __name__ == "__main__": + + + M=Manipulator(dataset_name='ffhq') + + + #%% + M.alpha=[-5,0,5] + M.num_images=20 + lindex,cindex=6,501 + + M.manipulate_layers=[lindex] + codes,out=M.EditOneC(cindex) #dlatent_tmp + tmp=str(M.manipulate_layers)+'_'+str(cindex) + M.Vis(tmp,'c',out) + + + + + + + + + + + + + + + + + + + + diff --git a/models/StyleCLIP/global_directions/utils/__init__.py b/models/StyleCLIP/global_directions/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/global_directions/utils/editor.py b/models/StyleCLIP/global_directions/utils/editor.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c2ac56fd7b4b127f948c6b8cf15874a8fe9d93 --- /dev/null +++ b/models/StyleCLIP/global_directions/utils/editor.py @@ -0,0 +1,507 @@ +# python 3.7 +"""Utility functions for image editing from latent space.""" + +import os.path +import numpy as np + +__all__ = [ + 'parse_indices', 'interpolate', 'mix_style', + 'get_layerwise_manipulation_strength', 'manipulate', 'parse_boundary_list' +] + + +def parse_indices(obj, min_val=None, max_val=None): + """Parses indices. + + If the input is a list or tuple, this function has no effect. + + The input can also be a string, which is either a comma separated list of + numbers 'a, b, c', or a dash separated range 'a - c'. Space in the string will + be ignored. + + Args: + obj: The input object to parse indices from. + min_val: If not `None`, this function will check that all indices are equal + to or larger than this value. (default: None) + max_val: If not `None`, this function will check that all indices are equal + to or smaller than this field. (default: None) + + Returns: + A list of integers. + + Raises: + If the input is invalid, i.e., neither a list or tuple, nor a string. 
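+
+  Example (illustrative, added for clarity):
+    parse_indices('0, 2-4') returns [0, 2, 3, 4]; duplicates are removed and
+    the result is returned in ascending order.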
+ """ + if obj is None or obj == '': + indices = [] + elif isinstance(obj, int): + indices = [obj] + elif isinstance(obj, (list, tuple, np.ndarray)): + indices = list(obj) + elif isinstance(obj, str): + indices = [] + splits = obj.replace(' ', '').split(',') + for split in splits: + numbers = list(map(int, split.split('-'))) + if len(numbers) == 1: + indices.append(numbers[0]) + elif len(numbers) == 2: + indices.extend(list(range(numbers[0], numbers[1] + 1))) + else: + raise ValueError(f'Invalid type of input: {type(obj)}!') + + assert isinstance(indices, list) + indices = sorted(list(set(indices))) + for idx in indices: + assert isinstance(idx, int) + if min_val is not None: + assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!' + if max_val is not None: + assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!' + + return indices + + +def interpolate(src_codes, dst_codes, step=5): + """Interpolates two sets of latent codes linearly. + + Args: + src_codes: Source codes, with shape [num, *code_shape]. + dst_codes: Target codes, with shape [num, *code_shape]. + step: Number of interplolation steps, with source and target included. For + example, if `step = 5`, three more samples will be inserted. (default: 5) + + Returns: + Interpolated codes, with shape [num, step, *code_shape]. + + Raises: + ValueError: If the input two sets of latent codes are with different shapes. + """ + if not (src_codes.ndim >= 2 and src_codes.shape == dst_codes.shape): + raise ValueError(f'Shapes of source codes and target codes should both be ' + f'[num, *code_shape], but {src_codes.shape} and ' + f'{dst_codes.shape} are received!') + num = src_codes.shape[0] + code_shape = src_codes.shape[1:] + + a = src_codes[:, np.newaxis] + b = dst_codes[:, np.newaxis] + l = np.linspace(0.0, 1.0, step).reshape( + [step if axis == 1 else 1 for axis in range(a.ndim)]) + results = a + l * (b - a) + assert results.shape == (num, step, *code_shape) + + return results + + +def mix_style(style_codes, + content_codes, + num_layers=1, + mix_layers=None, + is_style_layerwise=True, + is_content_layerwise=True): + """Mixes styles from style codes to those of content codes. + + Each style code or content code consists of `num_layers` codes, each of which + is typically fed into a particular layer of the generator. This function mixes + styles by partially replacing the codes of `content_codes` from some certain + layers with those of `style_codes`. + + For example, if both style code and content code are with shape [10, 512], + meaning to have 10 layers and each employs a 512-dimensional latent code. And + the 1st, 2nd, and 3rd layers are the target layers to perform style mixing. + Then the top half of the content code (with shape [3, 512]) will be replaced + by the top half of the style code (also with shape [3, 512]). + + NOTE: This function also supports taking single-layer latent codes as inputs, + i.e., setting `is_style_layerwise` or `is_content_layerwise` as False. In this + case, the corresponding code will be first repeated for `num_layers` before + performing style mixing. + + Args: + style_codes: Style codes, with shape [num_styles, *code_shape] or + [num_styles, num_layers, *code_shape]. + content_codes: Content codes, with shape [num_contents, *code_shape] or + [num_contents, num_layers, *code_shape]. + num_layers: Total number of layers in the generative model. (default: 1) + mix_layers: Indices of the layers to perform style mixing. 
`None` means to + replace all layers, in which case the content code will be completely + replaced by style code. (default: None) + is_style_layerwise: Indicating whether the input `style_codes` are + layer-wise codes. (default: True) + is_content_layerwise: Indicating whether the input `content_codes` are + layer-wise codes. (default: True) + num_layers + + Returns: + Codes after style mixing, with shape [num_styles, num_contents, num_layers, + *code_shape]. + + Raises: + ValueError: If input `content_codes` or `style_codes` is with invalid shape. + """ + if not is_style_layerwise: + style_codes = style_codes[:, np.newaxis] + style_codes = np.tile( + style_codes, + [num_layers if axis == 1 else 1 for axis in range(style_codes.ndim)]) + if not is_content_layerwise: + content_codes = content_codes[:, np.newaxis] + content_codes = np.tile( + content_codes, + [num_layers if axis == 1 else 1 for axis in range(content_codes.ndim)]) + + if not (style_codes.ndim >= 3 and style_codes.shape[1] == num_layers and + style_codes.shape[1:] == content_codes.shape[1:]): + raise ValueError(f'Shapes of style codes and content codes should be ' + f'[num_styles, num_layers, *code_shape] and ' + f'[num_contents, num_layers, *code_shape] respectively, ' + f'but {style_codes.shape} and {content_codes.shape} are ' + f'received!') + + layer_indices = parse_indices(mix_layers, min_val=0, max_val=num_layers - 1) + if not layer_indices: + layer_indices = list(range(num_layers)) + + num_styles = style_codes.shape[0] + num_contents = content_codes.shape[0] + code_shape = content_codes.shape[2:] + + s = style_codes[:, np.newaxis] + s = np.tile(s, [num_contents if axis == 1 else 1 for axis in range(s.ndim)]) + c = content_codes[np.newaxis] + c = np.tile(c, [num_styles if axis == 0 else 1 for axis in range(c.ndim)]) + + from_style = np.zeros(s.shape, dtype=bool) + from_style[:, :, layer_indices] = True + results = np.where(from_style, s, c) + assert results.shape == (num_styles, num_contents, num_layers, *code_shape) + + return results + + +def get_layerwise_manipulation_strength(num_layers, + truncation_psi, + truncation_layers): + """Gets layer-wise strength for manipulation. + + Recall the truncation trick played on layer [0, truncation_layers): + + w = truncation_psi * w + (1 - truncation_psi) * w_avg + + So, when using the same boundary to manipulate different layers, layer + [0, truncation_layers) and layer [truncation_layers, num_layers) should use + different strength to eliminate the effect from the truncation trick. More + concretely, the strength for layer [0, truncation_layers) is set as + `truncation_psi`, while that for other layers are set as 1. + """ + strength = [1.0 for _ in range(num_layers)] + if truncation_layers > 0: + for layer_idx in range(0, truncation_layers): + strength[layer_idx] = truncation_psi + return strength + + +def manipulate(latent_codes, + boundary, + start_distance=-5.0, + end_distance=5.0, + step=21, + layerwise_manipulation=False, + num_layers=1, + manipulate_layers=None, + is_code_layerwise=False, + is_boundary_layerwise=False, + layerwise_manipulation_strength=1.0): + """Manipulates the given latent codes with respect to a particular boundary. + + Basically, this function takes a set of latent codes and a boundary as inputs, + and outputs a collection of manipulated latent codes. + + For example, let `step` to be 10, `latent_codes` to be with shape [num, + *code_shape], and `boundary` to be with shape [1, *code_shape] and unit norm. 
+ Then the output will be with shape [num, 10, *code_shape]. For each 10-element + manipulated codes, the first code is `start_distance` away from the original + code (i.e., the input) along the `boundary` direction, while the last code is + `end_distance` away. Remaining codes are linearly interpolated. Here, + `distance` is sign sensitive. + + NOTE: This function also supports layer-wise manipulation, in which case the + generator should be able to take layer-wise latent codes as inputs. For + example, if the generator has 18 convolutional layers in total, and each of + which takes an independent latent code as input. It is possible, sometimes + with even better performance, to only partially manipulate these latent codes + corresponding to some certain layers yet keeping others untouched. + + NOTE: Boundary is assumed to be normalized to unit norm already. + + Args: + latent_codes: The input latent codes for manipulation, with shape + [num, *code_shape] or [num, num_layers, *code_shape]. + boundary: The semantic boundary as reference, with shape [1, *code_shape] or + [1, num_layers, *code_shape]. + start_distance: Start point for manipulation. (default: -5.0) + end_distance: End point for manipulation. (default: 5.0) + step: Number of manipulation steps. (default: 21) + layerwise_manipulation: Whether to perform layer-wise manipulation. + (default: False) + num_layers: Number of layers. Only active when `layerwise_manipulation` is + set as `True`. Should be a positive integer. (default: 1) + manipulate_layers: Indices of the layers to perform manipulation. `None` + means to manipulate latent codes from all layers. (default: None) + is_code_layerwise: Whether the input latent codes are layer-wise. If set as + `False`, the function will first repeat the input codes for `num_layers` + times before perform manipulation. (default: False) + is_boundary_layerwise: Whether the input boundary is layer-wise. If set as + `False`, the function will first repeat boundary for `num_layers` times + before perform manipulation. (default: False) + layerwise_manipulation_strength: Manipulation strength for each layer. Only + active when `layerwise_manipulation` is set as `True`. This field can be + used to resolve the strength discrepancy across layers when truncation + trick is on. See function `get_layerwise_manipulation_strength()` for + details. A tuple, list, or `numpy.ndarray` is expected. If set as a single + number, this strength will be used for all layers. (default: 1.0) + + Returns: + Manipulated codes, with shape [num, step, *code_shape] if + `layerwise_manipulation` is set as `False`, or shape [num, step, + num_layers, *code_shape] if `layerwise_manipulation` is set as `True`. + + Raises: + ValueError: If the input latent codes, boundary, or strength are with + invalid shape. + """ + if not (boundary.ndim >= 2 and boundary.shape[0] == 1): + raise ValueError(f'Boundary should be with shape [1, *code_shape] or ' + f'[1, num_layers, *code_shape], but ' + f'{boundary.shape} is received!') + + if not layerwise_manipulation: + assert not is_code_layerwise + assert not is_boundary_layerwise + num_layers = 1 + manipulate_layers = None + layerwise_manipulation_strength = 1.0 + + # Preprocessing for layer-wise manipulation. + # Parse indices of manipulation layers. + layer_indices = parse_indices( + manipulate_layers, min_val=0, max_val=num_layers - 1) + if not layer_indices: + layer_indices = list(range(num_layers)) + # Make latent codes layer-wise if needed. 
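+  # (Added commentary, not part of the original code.) The block below inserts
+  # a layer axis and repeats every latent code `num_layers` times, e.g. an
+  # input of shape (num, 512) becomes (num, num_layers, 512), so that each
+  # layer can later be manipulated independently.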
+ assert num_layers > 0 + if not is_code_layerwise: + x = latent_codes[:, np.newaxis] + x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)]) + else: + x = latent_codes + if x.shape[1] != num_layers: + raise ValueError(f'Latent codes should be with shape [num, num_layers, ' + f'*code_shape], where `num_layers` equals to ' + f'{num_layers}, but {x.shape} is received!') + # Make boundary layer-wise if needed. + if not is_boundary_layerwise: + b = boundary + b = np.tile(b, [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) + else: + b = boundary[0] + if b.shape[0] != num_layers: + raise ValueError(f'Boundary should be with shape [num_layers, ' + f'*code_shape], where `num_layers` equals to ' + f'{num_layers}, but {b.shape} is received!') + # Get layer-wise manipulation strength. + if isinstance(layerwise_manipulation_strength, (int, float)): + s = [float(layerwise_manipulation_strength) for _ in range(num_layers)] + elif isinstance(layerwise_manipulation_strength, (list, tuple)): + s = layerwise_manipulation_strength + if len(s) != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` ' + f'mismatches number of layers `{num_layers}`!') + elif isinstance(layerwise_manipulation_strength, np.ndarray): + s = layerwise_manipulation_strength + if s.size != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` ' + f'mismatches number of layers `{num_layers}`!') + else: + raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!') + s = np.array(s).reshape( + [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) + b = b * s + + if x.shape[1:] != b.shape: + raise ValueError(f'Latent code shape {x.shape} and boundary shape ' + f'{b.shape} mismatch!') + num = x.shape[0] + code_shape = x.shape[2:] + + x = x[:, np.newaxis] + b = b[np.newaxis, np.newaxis, :] + l = np.linspace(start_distance, end_distance, step).reshape( + [step if axis == 1 else 1 for axis in range(x.ndim)]) + results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)]) + is_manipulatable = np.zeros(results.shape, dtype=bool) + is_manipulatable[:, :, layer_indices] = True + results = np.where(is_manipulatable, x + l * b, results) + assert results.shape == (num, step, num_layers, *code_shape) + + return results if layerwise_manipulation else results[:, :, 0] + + +def manipulate2(latent_codes, + proj, + mindex, + start_distance=-5.0, + end_distance=5.0, + step=21, + layerwise_manipulation=False, + num_layers=1, + manipulate_layers=None, + is_code_layerwise=False, + layerwise_manipulation_strength=1.0): + + + if not layerwise_manipulation: + assert not is_code_layerwise +# assert not is_boundary_layerwise + num_layers = 1 + manipulate_layers = None + layerwise_manipulation_strength = 1.0 + + # Preprocessing for layer-wise manipulation. + # Parse indices of manipulation layers. + layer_indices = parse_indices( + manipulate_layers, min_val=0, max_val=num_layers - 1) + if not layer_indices: + layer_indices = list(range(num_layers)) + # Make latent codes layer-wise if needed. + assert num_layers > 0 + if not is_code_layerwise: + x = latent_codes[:, np.newaxis] + x = np.tile(x, [num_layers if axis == 1 else 1 for axis in range(x.ndim)]) + else: + x = latent_codes + if x.shape[1] != num_layers: + raise ValueError(f'Latent codes should be with shape [num, num_layers, ' + f'*code_shape], where `num_layers` equals to ' + f'{num_layers}, but {x.shape} is received!') + # Make boundary layer-wise if needed. 
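+  # (Added commentary, not part of the original code.) Unlike manipulate(),
+  # manipulate2() moves codes along a component of the fitted `proj` object
+  # (a PCA-style projection with transform/inverse_transform, see MPC below),
+  # so the explicit boundary preprocessing is kept only as commented-out
+  # reference code.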
+# if not is_boundary_layerwise: +# b = boundary +# b = np.tile(b, [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) +# else: +# b = boundary[0] +# if b.shape[0] != num_layers: +# raise ValueError(f'Boundary should be with shape [num_layers, ' +# f'*code_shape], where `num_layers` equals to ' +# f'{num_layers}, but {b.shape} is received!') + # Get layer-wise manipulation strength. + if isinstance(layerwise_manipulation_strength, (int, float)): + s = [float(layerwise_manipulation_strength) for _ in range(num_layers)] + elif isinstance(layerwise_manipulation_strength, (list, tuple)): + s = layerwise_manipulation_strength + if len(s) != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{len(s)}` ' + f'mismatches number of layers `{num_layers}`!') + elif isinstance(layerwise_manipulation_strength, np.ndarray): + s = layerwise_manipulation_strength + if s.size != num_layers: + raise ValueError(f'Shape of layer-wise manipulation strength `{s.size}` ' + f'mismatches number of layers `{num_layers}`!') + else: + raise ValueError(f'Unsupported type of `layerwise_manipulation_strength`!') +# s = np.array(s).reshape( +# [num_layers if axis == 0 else 1 for axis in range(b.ndim)]) +# b = b * s + +# if x.shape[1:] != b.shape: +# raise ValueError(f'Latent code shape {x.shape} and boundary shape ' +# f'{b.shape} mismatch!') + num = x.shape[0] + code_shape = x.shape[2:] + + x = x[:, np.newaxis] +# b = b[np.newaxis, np.newaxis, :] +# l = np.linspace(start_distance, end_distance, step).reshape( +# [step if axis == 1 else 1 for axis in range(x.ndim)]) + results = np.tile(x, [step if axis == 1 else 1 for axis in range(x.ndim)]) + is_manipulatable = np.zeros(results.shape, dtype=bool) + is_manipulatable[:, :, layer_indices] = True + + tmp=MPC(proj,x,mindex,start_distance,end_distance,step) + tmp = tmp[:, :,np.newaxis] + tmp1 = np.tile(tmp, [num_layers if axis == 2 else 1 for axis in range(tmp.ndim)]) + + + results = np.where(is_manipulatable, tmp1, results) +# print(results.shape) + assert results.shape == (num, step, num_layers, *code_shape) + return results if layerwise_manipulation else results[:, :, 0] + +def MPC(proj,x,mindex,start_distance,end_distance,step): + # x shape (batch_size,1,num_layers,feature) +# print(x.shape) + x1=proj.transform(x[:,0,0,:]) #/np.sqrt(proj.explained_variance_) # (batch_size,num_pc) + + x1 = x1[:, np.newaxis] + x1 = np.tile(x1, [step if axis == 1 else 1 for axis in range(x1.ndim)]) + + + l = np.linspace(start_distance, end_distance, step)[None,:] + x1[:,:,mindex]+=l + + tmp=x1.reshape((-1,x1.shape[-1])) #*np.sqrt(proj.explained_variance_) +# print('xxx') + x2=proj.inverse_transform(tmp) + x2=x2.reshape((x1.shape[0],x1.shape[1],-1)) + +# x1 = x1[:, np.newaxis] +# x1 = np.tile(x1, [step if axis == 1 else 1 for axis in range(x1.ndim)]) + + return x2 + + + + +def parse_boundary_list(boundary_list_path): + """Parses boundary list. + + Sometimes, a text file containing a list of boundaries will significantly + simplify image manipulation with a large amount of boundaries. This function + is used to parse boundary information from such list file. + + Basically, each item in the list should be with format + `($NAME, $SPACE_TYPE): $PATH`. `DISABLE` at the beginning of the line can + disable a particular boundary. + + Sample: + + (age, z): $AGE_BOUNDARY_PATH + (gender, w): $GENDER_BOUNDARY_PATH + DISABLE(pose, wp): $POSE_BOUNDARY_PATH + + Args: + boundary_list_path: Path to the boundary list. 
+ + Returns: + A dictionary, whose key is a two-element tuple (boundary_name, space_type) + and value is the corresponding boundary path. + + Raise: + ValueError: If the given boundary list does not exist. + """ + if not os.path.isfile(boundary_list_path): + raise ValueError(f'Boundary list `boundary_list_path` does not exist!') + + boundaries = {} + with open(boundary_list_path, 'r') as f: + for line in f: + if line[:len('DISABLE')] == 'DISABLE': + continue + boundary_info, boundary_path = line.strip().split(':') + boundary_name, space_type = boundary_info.strip()[1:-1].split(',') + boundary_name = boundary_name.strip() + space_type = space_type.strip().lower() + boundary_path = boundary_path.strip() + boundaries[(boundary_name, space_type)] = boundary_path + return boundaries diff --git a/models/StyleCLIP/global_directions/utils/train_boundary.py b/models/StyleCLIP/global_directions/utils/train_boundary.py new file mode 100644 index 0000000000000000000000000000000000000000..710d062bc4b42913fcc5b12bd545e47af00c7123 --- /dev/null +++ b/models/StyleCLIP/global_directions/utils/train_boundary.py @@ -0,0 +1,158 @@ + +import numpy as np +from sklearn import svm + + + + + +def train_boundary(latent_codes, + scores, + chosen_num_or_ratio=0.02, + split_ratio=0.7, + invalid_value=None, + logger=None, + logger_name='train_boundary'): + """Trains boundary in latent space with offline predicted attribute scores. + + Given a collection of latent codes and the attribute scores predicted from the + corresponding images, this function will train a linear SVM by treating it as + a bi-classification problem. Basically, the samples with highest attribute + scores are treated as positive samples, while those with lowest scores as + negative. For now, the latent code can ONLY be with 1 dimension. + + NOTE: The returned boundary is with shape (1, latent_space_dim), and also + normalized with unit norm. + + Args: + latent_codes: Input latent codes as training data. + scores: Input attribute scores used to generate training labels. + chosen_num_or_ratio: How many samples will be chosen as positive (negative) + samples. If this field lies in range (0, 0.5], `chosen_num_or_ratio * + latent_codes_num` will be used. Otherwise, `min(chosen_num_or_ratio, + 0.5 * latent_codes_num)` will be used. (default: 0.02) + split_ratio: Ratio to split training and validation sets. (default: 0.7) + invalid_value: This field is used to filter out data. (default: None) + logger: Logger for recording log messages. If set as `None`, a default + logger, which prints messages from all levels to screen, will be created. + (default: None) + + Returns: + A decision boundary with type `numpy.ndarray`. + + Raises: + ValueError: If the input `latent_codes` or `scores` are with invalid format. 
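+
+  NOTE (added for clarity): as implemented below, the function returns a tuple
+  `(boundary, val_accuracy)`, where `boundary` has shape
+  (1, latent_space_dim) and unit norm, and `val_accuracy` is the linear SVM
+  accuracy on the held-out validation split.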
+ """ +# if not logger: +# logger = setup_logger(work_dir='', logger_name=logger_name) + + if (not isinstance(latent_codes, np.ndarray) or + not len(latent_codes.shape) == 2): + raise ValueError(f'Input `latent_codes` should be with type' + f'`numpy.ndarray`, and shape [num_samples, ' + f'latent_space_dim]!') + num_samples = latent_codes.shape[0] + latent_space_dim = latent_codes.shape[1] + if (not isinstance(scores, np.ndarray) or not len(scores.shape) == 2 or + not scores.shape[0] == num_samples or not scores.shape[1] == 1): + raise ValueError(f'Input `scores` should be with type `numpy.ndarray`, and ' + f'shape [num_samples, 1], where `num_samples` should be ' + f'exactly same as that of input `latent_codes`!') + if chosen_num_or_ratio <= 0: + raise ValueError(f'Input `chosen_num_or_ratio` should be positive, ' + f'but {chosen_num_or_ratio} received!') + +# logger.info(f'Filtering training data.') + print('Filtering training data.') + if invalid_value is not None: + latent_codes = latent_codes[scores[:, 0] != invalid_value] + scores = scores[scores[:, 0] != invalid_value] + +# logger.info(f'Sorting scores to get positive and negative samples.') + print('Sorting scores to get positive and negative samples.') + + sorted_idx = np.argsort(scores, axis=0)[::-1, 0] + latent_codes = latent_codes[sorted_idx] + scores = scores[sorted_idx] + num_samples = latent_codes.shape[0] + if 0 < chosen_num_or_ratio <= 1: + chosen_num = int(num_samples * chosen_num_or_ratio) + else: + chosen_num = int(chosen_num_or_ratio) + chosen_num = min(chosen_num, num_samples // 2) + +# logger.info(f'Spliting training and validation sets:') + print('Filtering training data.') + + train_num = int(chosen_num * split_ratio) + val_num = chosen_num - train_num + # Positive samples. + positive_idx = np.arange(chosen_num) + np.random.shuffle(positive_idx) + positive_train = latent_codes[:chosen_num][positive_idx[:train_num]] + positive_val = latent_codes[:chosen_num][positive_idx[train_num:]] + # Negative samples. + negative_idx = np.arange(chosen_num) + np.random.shuffle(negative_idx) + negative_train = latent_codes[-chosen_num:][negative_idx[:train_num]] + negative_val = latent_codes[-chosen_num:][negative_idx[train_num:]] + # Training set. + train_data = np.concatenate([positive_train, negative_train], axis=0) + train_label = np.concatenate([np.ones(train_num, dtype=np.int), + np.zeros(train_num, dtype=np.int)], axis=0) +# logger.info(f' Training: {train_num} positive, {train_num} negative.') + print(f' Training: {train_num} positive, {train_num} negative.') + # Validation set. + val_data = np.concatenate([positive_val, negative_val], axis=0) + val_label = np.concatenate([np.ones(val_num, dtype=np.int), + np.zeros(val_num, dtype=np.int)], axis=0) +# logger.info(f' Validation: {val_num} positive, {val_num} negative.') + print(f' Validation: {val_num} positive, {val_num} negative.') + + # Remaining set. 
+ remaining_num = num_samples - chosen_num * 2 + remaining_data = latent_codes[chosen_num:-chosen_num] + remaining_scores = scores[chosen_num:-chosen_num] + decision_value = (scores[0] + scores[-1]) / 2 + remaining_label = np.ones(remaining_num, dtype=np.int) + remaining_label[remaining_scores.ravel() < decision_value] = 0 + remaining_positive_num = np.sum(remaining_label == 1) + remaining_negative_num = np.sum(remaining_label == 0) +# logger.info(f' Remaining: {remaining_positive_num} positive, ' +# f'{remaining_negative_num} negative.') + print(f' Remaining: {remaining_positive_num} positive, ' + f'{remaining_negative_num} negative.') +# logger.info(f'Training boundary.') + print(f'Training boundary.') + + clf = svm.SVC(kernel='linear') + classifier = clf.fit(train_data, train_label) +# logger.info(f'Finish training.') + print(f'Finish training.') + + + if val_num: + val_prediction = classifier.predict(val_data) + correct_num = np.sum(val_label == val_prediction) +# logger.info(f'Accuracy for validation set: ' +# f'{correct_num} / {val_num * 2} = ' +# f'{correct_num / (val_num * 2):.6f}') + print(f'Accuracy for validation set: ' + f'{correct_num} / {val_num * 2} = ' + f'{correct_num / (val_num * 2):.6f}') + vacc=correct_num/len(val_label) + ''' + if remaining_num: + remaining_prediction = classifier.predict(remaining_data) + correct_num = np.sum(remaining_label == remaining_prediction) + logger.info(f'Accuracy for remaining set: ' + f'{correct_num} / {remaining_num} = ' + f'{correct_num / remaining_num:.6f}') + ''' + a = classifier.coef_.reshape(1, latent_space_dim).astype(np.float32) + return a / np.linalg.norm(a),vacc + + + + + diff --git a/models/StyleCLIP/global_directions/utils/visualizer.py b/models/StyleCLIP/global_directions/utils/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4a1fba06bf6bc680aa59bf645f796283f6f1c6 --- /dev/null +++ b/models/StyleCLIP/global_directions/utils/visualizer.py @@ -0,0 +1,605 @@ +# python 3.7 +"""Utility functions for visualizing results on html page.""" + +import base64 +import os.path +import cv2 +import numpy as np + +__all__ = [ + 'get_grid_shape', 'get_blank_image', 'load_image', 'save_image', + 'resize_image', 'add_text_to_image', 'fuse_images', 'HtmlPageVisualizer', + 'VideoReader', 'VideoWriter', 'adjust_pixel_range' +] + + +def adjust_pixel_range(images, min_val=-1.0, max_val=1.0, channel_order='NCHW'): + """Adjusts the pixel range of the input images. + + This function assumes the input array (image batch) is with shape [batch_size, + channel, height, width] if `channel_order = NCHW`, or with shape [batch_size, + height, width] if `channel_order = NHWC`. The returned images are with shape + [batch_size, height, width, channel] and pixel range [0, 255]. + + NOTE: The channel order of output images will remain the same as the input. + + Args: + images: Input images to adjust pixel range. + min_val: Min value of the input images. (default: -1.0) + max_val: Max value of the input images. (default: 1.0) + channel_order: Channel order of the input array. (default: NCHW) + + Returns: + The postprocessed images with dtype `numpy.uint8` and range [0, 255]. + + Raises: + ValueError: If the input `images` are not with type `numpy.ndarray` or the + shape is invalid according to `channel_order`. 
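+
+  Example (illustrative, added for clarity): a float32 batch of shape
+  (N, 3, H, W) with values in [-1, 1] and `channel_order='NCHW'` is returned
+  as a uint8 batch of shape (N, H, W, 3) with values in [0, 255].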
+ """ + if not isinstance(images, np.ndarray): + raise ValueError(f'Images should be with type `numpy.ndarray`!') + + channel_order = channel_order.upper() + if channel_order not in ['NCHW', 'NHWC']: + raise ValueError(f'Invalid channel order `{channel_order}`!') + + if images.ndim != 4: + raise ValueError(f'Input images are expected to be with shape `NCHW` or ' + f'`NHWC`, but `{images.shape}` is received!') + if channel_order == 'NCHW' and images.shape[1] not in [1, 3]: + raise ValueError(f'Input images should have 1 or 3 channels under `NCHW` ' + f'channel order!') + if channel_order == 'NHWC' and images.shape[3] not in [1, 3]: + raise ValueError(f'Input images should have 1 or 3 channels under `NHWC` ' + f'channel order!') + + images = images.astype(np.float32) + images = (images - min_val) * 255 / (max_val - min_val) + images = np.clip(images + 0.5, 0, 255).astype(np.uint8) + if channel_order == 'NCHW': + images = images.transpose(0, 2, 3, 1) + + return images + + +def get_grid_shape(size, row=0, col=0, is_portrait=False): + """Gets the shape of a grid based on the size. + + This function makes greatest effort on making the output grid square if + neither `row` nor `col` is set. If `is_portrait` is set as `False`, the height + will always be equal to or smaller than the width. For example, if input + `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, output shape + will be (3, 5). Otherwise, the height will always be equal to or larger than + the width. + + Args: + size: Size (height * width) of the target grid. + is_portrait: Whether to return a portrait size of a landscape size. + (default: False) + + Returns: + A two-element tuple, representing height and width respectively. + """ + assert isinstance(size, int) + assert isinstance(row, int) + assert isinstance(col, int) + if size == 0: + return (0, 0) + + if row > 0 and col > 0 and row * col != size: + row = 0 + col = 0 + + if row > 0 and size % row == 0: + return (row, size // row) + if col > 0 and size % col == 0: + return (size // col, col) + + row = int(np.sqrt(size)) + while row > 0: + if size % row == 0: + col = size // row + break + row = row - 1 + + return (col, row) if is_portrait else (row, col) + + +def get_blank_image(height, width, channels=3, is_black=True): + """Gets a blank image, either white of black. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + height: Height of the returned image. + width: Width of the returned image. + channels: Number of channels. (default: 3) + is_black: Whether to return a black image or white image. (default: True) + """ + shape = (height, width, channels) + if is_black: + return np.zeros(shape, dtype=np.uint8) + return np.ones(shape, dtype=np.uint8) * 255 + + +def load_image(path): + """Loads an image from disk. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + path: Path to load the image from. + + Returns: + An image with dtype `np.ndarray` or `None` if input `path` does not exist. + """ + if not os.path.isfile(path): + return None + + image = cv2.imread(path) + return image[:, :, ::-1] + + +def save_image(path, image): + """Saves an image to disk. + + NOTE: The input image (if colorful) is assumed to be with `RGB` channel order + and pixel range [0, 255]. + + Args: + path: Path to save the image to. + image: Image to save. 
+ """ + if image is None: + return + + assert len(image.shape) == 3 and image.shape[2] in [1, 3] + cv2.imwrite(path, image[:, :, ::-1]) + + +def resize_image(image, *args, **kwargs): + """Resizes image. + + This is a wrap of `cv2.resize()`. + + NOTE: THe channel order of the input image will not be changed. + + Args: + image: Image to resize. + """ + if image is None: + return None + + assert image.ndim == 3 and image.shape[2] in [1, 3] + image = cv2.resize(image, *args, **kwargs) + if image.ndim == 2: + return image[:, :, np.newaxis] + return image + + +def add_text_to_image(image, + text='', + position=None, + font=cv2.FONT_HERSHEY_TRIPLEX, + font_size=1.0, + line_type=cv2.LINE_8, + line_width=1, + color=(255, 255, 255)): + """Overlays text on given image. + + NOTE: The input image is assumed to be with `RGB` channel order. + + Args: + image: The image to overlay text on. + text: Text content to overlay on the image. (default: '') + position: Target position (bottom-left corner) to add text. If not set, + center of the image will be used by default. (default: None) + font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX) + font_size: Font size of the text added. (default: 1.0) + line_type: Line type used to depict the text. (default: cv2.LINE_8) + line_width: Line width used to depict the text. (default: 1) + color: Color of the text added in `RGB` channel order. (default: + (255, 255, 255)) + + Returns: + An image with target text overlayed on. + """ + if image is None or not text: + return image + + cv2.putText(img=image, + text=text, + org=position, + fontFace=font, + fontScale=font_size, + color=color, + thickness=line_width, + lineType=line_type, + bottomLeftOrigin=False) + + return image + + +def fuse_images(images, + image_size=None, + row=0, + col=0, + is_row_major=True, + is_portrait=False, + row_spacing=0, + col_spacing=0, + border_left=0, + border_right=0, + border_top=0, + border_bottom=0, + black_background=True): + """Fuses a collection of images into an entire image. + + Args: + images: A collection of images to fuse. Should be with shape [num, height, + width, channels]. + image_size: Int or two-element tuple. This field is used to resize the image + before fusing. `None` disables resizing. (default: None) + row: Number of rows used for image fusion. If not set, this field will be + automatically assigned based on `col` and total number of images. + (default: None) + col: Number of columns used for image fusion. If not set, this field will be + automatically assigned based on `row` and total number of images. + (default: None) + is_row_major: Whether the input images should be arranged row-major or + column-major. (default: True) + is_portrait: Only active when both `row` and `col` should be assigned + automatically. (default: False) + row_spacing: Space between rows. (default: 0) + col_spacing: Space between columns. (default: 0) + border_left: Width of left border. (default: 0) + border_right: Width of right border. (default: 0) + border_top: Width of top border. (default: 0) + border_bottom: Width of bottom border. (default: 0) + + Returns: + The fused image. + + Raises: + ValueError: If the input `images` is not with shape [num, height, width, + width]. 
+ """ + if images is None: + return images + + if not images.ndim == 4: + raise ValueError(f'Input `images` should be with shape [num, height, ' + f'width, channels], but {images.shape} is received!') + + num, image_height, image_width, channels = images.shape + if image_size is not None: + if isinstance(image_size, int): + image_size = (image_size, image_size) + assert isinstance(image_size, (list, tuple)) and len(image_size) == 2 + width, height = image_size + else: + height, width = image_height, image_width + row, col = get_grid_shape(num, row=row, col=col, is_portrait=is_portrait) + fused_height = ( + height * row + row_spacing * (row - 1) + border_top + border_bottom) + fused_width = ( + width * col + col_spacing * (col - 1) + border_left + border_right) + fused_image = get_blank_image( + fused_height, fused_width, channels=channels, is_black=black_background) + images = images.reshape(row, col, image_height, image_width, channels) + if not is_row_major: + images = images.transpose(1, 0, 2, 3, 4) + + for i in range(row): + y = border_top + i * (height + row_spacing) + for j in range(col): + x = border_left + j * (width + col_spacing) + if image_size is not None: + image = cv2.resize(images[i, j], image_size) + else: + image = images[i, j] + fused_image[y:y + height, x:x + width] = image + + return fused_image + + +def get_sortable_html_header(column_name_list, sort_by_ascending=False): + """Gets header for sortable html page. + + Basically, the html page contains a sortable table, where user can sort the + rows by a particular column by clicking the column head. + + Example: + + column_name_list = [name_1, name_2, name_3] + header = get_sortable_html_header(column_name_list) + footer = get_sortable_html_footer() + sortable_table = ... + html_page = header + sortable_table + footer + + Args: + column_name_list: List of column header names. + sort_by_ascending: Default sorting order. If set as `True`, the html page + will be sorted by ascending order when the header is clicked for the first + time. + + Returns: + A string, which represents for the header for a sortable html page. + """ + header = '\n'.join([ + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + '']) + for idx, column_name in enumerate(column_name_list): + header += f' \n' + header += '\n' + header += '\n' + header += '\n' + + return header + + +def get_sortable_html_footer(): + """Gets footer for sortable html page. + + Check function `get_sortable_html_header()` for more details. + """ + return '\n
{column_name}
\n\n\n\n' + + +def encode_image_to_html_str(image, image_size=None): + """Encodes an image to html language. + + Args: + image: The input image to encode. Should be with `RGB` channel order. + image_size: Int or two-element tuple. This field is used to resize the image + before encoding. `None` disables resizing. (default: None) + + Returns: + A string which represents the encoded image. + """ + if image is None: + return '' + + assert len(image.shape) == 3 and image.shape[2] in [1, 3] + + # Change channel order to `BGR`, which is opencv-friendly. + image = image[:, :, ::-1] + + # Resize the image if needed. + if image_size is not None: + if isinstance(image_size, int): + image_size = (image_size, image_size) + assert isinstance(image_size, (list, tuple)) and len(image_size) == 2 + image = cv2.resize(image, image_size) + + # Encode the image to html-format string. + encoded_image = cv2.imencode(".jpg", image)[1].tostring() + encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8') + html_str = f'' + + return html_str + + +class HtmlPageVisualizer(object): + """Defines the html page visualizer. + + This class can be used to visualize image results as html page. Basically, it + is based on an html-format sorted table with helper functions + `get_sortable_html_header()`, `get_sortable_html_footer()`, and + `encode_image_to_html_str()`. To simplify the usage, specifying the following + fields is enough to create a visualization page: + + (1) num_rows: Number of rows of the table (header-row exclusive). + (2) num_cols: Number of columns of the table. + (3) header contents (optional): Title of each column. + + NOTE: `grid_size` can be used to assign `num_rows` and `num_cols` + automatically. + + Example: + + html = HtmlPageVisualizer(num_rows, num_cols) + html.set_headers([...]) + for i in range(num_rows): + for j in range(num_cols): + html.set_cell(i, j, text=..., image=...) + html.save('visualize.html') + """ + + def __init__(self, + num_rows=0, + num_cols=0, + grid_size=0, + is_portrait=False, + viz_size=None): + if grid_size > 0: + num_rows, num_cols = get_grid_shape( + grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait) + assert num_rows > 0 and num_cols > 0 + + self.num_rows = num_rows + self.num_cols = num_cols + self.viz_size = viz_size + self.headers = ['' for _ in range(self.num_cols)] + self.cells = [[{ + 'text': '', + 'image': '', + } for _ in range(self.num_cols)] for _ in range(self.num_rows)] + + def set_header(self, column_idx, content): + """Sets the content of a particular header by column index.""" + self.headers[column_idx] = content + + def set_headers(self, contents): + """Sets the contents of all headers.""" + if isinstance(contents, str): + contents = [contents] + assert isinstance(contents, (list, tuple)) + assert len(contents) == self.num_cols + for column_idx, content in enumerate(contents): + self.set_header(column_idx, content) + + def set_cell(self, row_idx, column_idx, text='', image=None): + """Sets the content of a particular cell. + + Basically, a cell contains some text as well as an image. Both text and + image can be empty. + + Args: + row_idx: Row index of the cell to edit. + column_idx: Column index of the cell to edit. + text: Text to add into the target cell. + image: Image to show in the target cell. Should be with `RGB` channel + order. 
+ """ + self.cells[row_idx][column_idx]['text'] = text + self.cells[row_idx][column_idx]['image'] = encode_image_to_html_str( + image, self.viz_size) + + def save(self, save_path): + """Saves the html page.""" + html = '' + for i in range(self.num_rows): + html += f'\n' + for j in range(self.num_cols): + text = self.cells[i][j]['text'] + image = self.cells[i][j]['image'] + if text: + html += f' {text}

{image}\n' + else: + html += f' {image}\n' + html += f'\n' + + header = get_sortable_html_header(self.headers) + footer = get_sortable_html_footer() + + with open(save_path, 'w') as f: + f.write(header + html + footer) + + +class VideoReader(object): + """Defines the video reader. + + This class can be used to read frames from a given video. + """ + + def __init__(self, path): + """Initializes the video reader by loading the video from disk.""" + if not os.path.isfile(path): + raise ValueError(f'Video `{path}` does not exist!') + + self.path = path + self.video = cv2.VideoCapture(path) + assert self.video.isOpened() + self.position = 0 + + self.length = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT)) + self.frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.fps = self.video.get(cv2.CAP_PROP_FPS) + + def __del__(self): + """Releases the opened video.""" + self.video.release() + + def read(self, position=None): + """Reads a certain frame. + + NOTE: The returned frame is assumed to be with `RGB` channel order. + + Args: + position: Optional. If set, the reader will read frames from the exact + position. Otherwise, the reader will read next frames. (default: None) + """ + if position is not None and position < self.length: + self.video.set(cv2.CAP_PROP_POS_FRAMES, position) + self.position = position + + success, frame = self.video.read() + self.position = self.position + 1 + + return frame[:, :, ::-1] if success else None + + +class VideoWriter(object): + """Defines the video writer. + + This class can be used to create a video. + + NOTE: `.avi` and `DIVX` is the most recommended codec format since it does not + rely on other dependencies. + """ + + def __init__(self, path, frame_height, frame_width, fps=24, codec='DIVX'): + """Creates the video writer.""" + self.path = path + self.frame_height = frame_height + self.frame_width = frame_width + self.fps = fps + self.codec = codec + + self.video = cv2.VideoWriter(filename=path, + fourcc=cv2.VideoWriter_fourcc(*codec), + fps=fps, + frameSize=(frame_width, frame_height)) + + def __del__(self): + """Releases the opened video.""" + self.video.release() + + def write(self, frame): + """Writes a target frame. + + NOTE: The input frame is assumed to be with `RGB` channel order. 
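+
+    NOTE (added for clarity): the frame is flipped from `RGB` to `BGR` before
+    being passed to the underlying `cv2.VideoWriter`.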
+ """ + self.video.write(frame[:, :, ::-1]) diff --git a/models/StyleCLIP/mapper/__init__.py b/models/StyleCLIP/mapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/datasets/__init__.py b/models/StyleCLIP/mapper/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/datasets/latents_dataset.py b/models/StyleCLIP/mapper/datasets/latents_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dde6ef52b7488e864ccd2fa2930b5100c1025c17 --- /dev/null +++ b/models/StyleCLIP/mapper/datasets/latents_dataset.py @@ -0,0 +1,15 @@ +from torch.utils.data import Dataset + + +class LatentsDataset(Dataset): + + def __init__(self, latents, opts): + self.latents = latents + self.opts = opts + + def __len__(self): + return self.latents.shape[0] + + def __getitem__(self, index): + + return self.latents[index] diff --git a/models/StyleCLIP/mapper/latent_mappers.py b/models/StyleCLIP/mapper/latent_mappers.py new file mode 100644 index 0000000000000000000000000000000000000000..63637adc9646986a3546edd19f4555a2f75a379f --- /dev/null +++ b/models/StyleCLIP/mapper/latent_mappers.py @@ -0,0 +1,81 @@ +import torch +from torch import nn +from torch.nn import Module + +from models.StyleCLIP.models.stylegan2.model import EqualLinear, PixelNorm + + +class Mapper(Module): + + def __init__(self, opts): + super(Mapper, self).__init__() + + self.opts = opts + layers = [PixelNorm()] + + for i in range(4): + layers.append( + EqualLinear( + 512, 512, lr_mul=0.01, activation='fused_lrelu' + ) + ) + + self.mapping = nn.Sequential(*layers) + + + def forward(self, x): + x = self.mapping(x) + return x + + +class SingleMapper(Module): + + def __init__(self, opts): + super(SingleMapper, self).__init__() + + self.opts = opts + + self.mapping = Mapper(opts) + + def forward(self, x): + out = self.mapping(x) + return out + + +class LevelsMapper(Module): + + def __init__(self, opts): + super(LevelsMapper, self).__init__() + + self.opts = opts + + if not opts.no_coarse_mapper: + self.course_mapping = Mapper(opts) + if not opts.no_medium_mapper: + self.medium_mapping = Mapper(opts) + if not opts.no_fine_mapper: + self.fine_mapping = Mapper(opts) + + def forward(self, x): + x_coarse = x[:, :4, :] + x_medium = x[:, 4:8, :] + x_fine = x[:, 8:, :] + + if not self.opts.no_coarse_mapper: + x_coarse = self.course_mapping(x_coarse) + else: + x_coarse = torch.zeros_like(x_coarse) + if not self.opts.no_medium_mapper: + x_medium = self.medium_mapping(x_medium) + else: + x_medium = torch.zeros_like(x_medium) + if not self.opts.no_fine_mapper: + x_fine = self.fine_mapping(x_fine) + else: + x_fine = torch.zeros_like(x_fine) + + + out = torch.cat([x_coarse, x_medium, x_fine], dim=1) + + return out + diff --git a/models/StyleCLIP/mapper/options/__init__.py b/models/StyleCLIP/mapper/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/options/test_options.py b/models/StyleCLIP/mapper/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..aab2e5a5bba1038b97110fa6c8e8bce14de7390c --- /dev/null +++ b/models/StyleCLIP/mapper/options/test_options.py @@ -0,0 +1,42 @@ +from argparse import ArgumentParser + + +class TestOptions: + + def __init__(self): + self.parser = 
ArgumentParser() + self.initialize() + + def initialize(self): + # arguments for inference script + self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') + self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint') + self.parser.add_argument('--couple_outputs', action='store_true', + help='Whether to also save inputs + outputs side-by-side') + + self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') + self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") + self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") + self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") + self.parser.add_argument('--stylegan_size', default=1024, type=int) + + self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') + self.parser.add_argument('--latents_test_path', default=None, type=str, help="The latents for the validation") + self.parser.add_argument('--test_workers', default=2, type=int, + help='Number of test/inference dataloader workers') + + self.parser.add_argument('--n_images', type=int, default=None, + help='Number of images to output. If None, run on all data') + + self.parser.add_argument('--run_id', type=str, default='PKNWUQRQRKXQ', + help='The generator id to use') + + self.parser.add_argument('--image_name', type=str, default='', + help='image to run on') + + self.parser.add_argument('--edit_name', type=str, default='', + help='edit type') + + def parse(self): + opts = self.parser.parse_args() + return opts diff --git a/models/StyleCLIP/mapper/options/train_options.py b/models/StyleCLIP/mapper/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..a365217f8b76d38aaef4a42b90152ec7a8e7bf1f --- /dev/null +++ b/models/StyleCLIP/mapper/options/train_options.py @@ -0,0 +1,49 @@ +from argparse import ArgumentParser + + +class TrainOptions: + + def __init__(self): + self.parser = ArgumentParser() + self.initialize() + + def initialize(self): + self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') + self.parser.add_argument('--mapper_type', default='LevelsMapper', type=str, help='Which mapper to use') + self.parser.add_argument('--no_coarse_mapper', default=False, action="store_true") + self.parser.add_argument('--no_medium_mapper', default=False, action="store_true") + self.parser.add_argument('--no_fine_mapper', default=False, action="store_true") + self.parser.add_argument('--latents_train_path', default="train_faces.pt", type=str, help="The latents for the training") + self.parser.add_argument('--latents_test_path', default="test_faces.pt", type=str, help="The latents for the validation") + self.parser.add_argument('--train_dataset_size', default=5000, type=int, help="Will be used only if no latents are given") + self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given") + + self.parser.add_argument('--batch_size', default=2, type=int, help='Batch size for training') + self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference') + self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') + self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers') + + 
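+        # (Added commentary, not part of the original code.) The defaults
+        # below follow the StyleCLIP mapper training setup: the Ranger
+        # optimizer with a comparatively large learning rate, together with
+        # ID, CLIP and latent-L2 loss weights that are typically tuned per
+        # edit.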
self.parser.add_argument('--learning_rate', default=0.5, type=float, help='Optimizer learning rate') + self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') + + self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') + self.parser.add_argument('--clip_lambda', default=1.0, type=float, help='CLIP loss multiplier factor') + self.parser.add_argument('--latent_l2_lambda', default=0.8, type=float, help='Latent L2 loss multiplier factor') + + self.parser.add_argument('--stylegan_weights', default='../pretrained_models/stylegan2-ffhq-config-f.pt', type=str, help='Path to StyleGAN model weights') + self.parser.add_argument('--stylegan_size', default=1024, type=int) + self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss") + self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to StyleCLIPModel model checkpoint') + + self.parser.add_argument('--max_steps', default=50000, type=int, help='Maximum number of training steps') + self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') + self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard') + self.parser.add_argument('--val_interval', default=2000, type=int, help='Validation interval') + self.parser.add_argument('--save_interval', default=2000, type=int, help='Model checkpoint interval') + + self.parser.add_argument('--description', required=True, type=str, help='Driving text prompt') + + + def parse(self): + opts = self.parser.parse_args() + return opts \ No newline at end of file diff --git a/models/StyleCLIP/mapper/scripts/inference.py b/models/StyleCLIP/mapper/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..98d765b3607bc6ecf4d137ac3a876b400269c82a --- /dev/null +++ b/models/StyleCLIP/mapper/scripts/inference.py @@ -0,0 +1,80 @@ +import os +import pickle +from argparse import Namespace +import torchvision +import torch +import sys +import time + +from configs import paths_config, global_config +from models.StyleCLIP.mapper.styleclip_mapper import StyleCLIPMapper +from utils.models_utils import load_tuned_G, load_old_G + +sys.path.append(".") +sys.path.append("..") + + +def run(test_opts, model_id, image_name, use_multi_id_G): + out_path_results = os.path.join(test_opts.exp_dir, test_opts.data_dir_name) + os.makedirs(out_path_results, exist_ok=True) + out_path_results = os.path.join(out_path_results, test_opts.image_name) + os.makedirs(out_path_results, exist_ok=True) + + # update test configs with configs used during training + ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu') + opts = ckpt['opts'] + opts.update(vars(test_opts)) + opts = Namespace(**opts) + + net = StyleCLIPMapper(opts, test_opts.run_id) + net.eval() + net.to(global_config.device) + + generator_type = paths_config.multi_id_model_type if use_multi_id_G else image_name + + new_G = load_tuned_G(model_id, generator_type) + old_G = load_old_G() + + run_styleclip(net, new_G, opts, paths_config.pti_results_keyword, out_path_results, test_opts) + run_styleclip(net, old_G, opts, paths_config.e4e_results_keyword, out_path_results, test_opts) + + +def run_styleclip(net, G, opts, method, out_path_results, test_opts): + net.set_G(G) + + out_path_results = os.path.join(out_path_results, 
method) + os.makedirs(out_path_results, exist_ok=True) + + latent = torch.load(opts.latents_test_path) + + global_i = 0 + global_time = [] + with torch.no_grad(): + input_cuda = latent.cuda().float() + tic = time.time() + result_batch = run_on_batch(input_cuda, net, test_opts.couple_outputs) + toc = time.time() + global_time.append(toc - tic) + + for i in range(opts.test_batch_size): + im_path = f'{test_opts.image_name}_{test_opts.edit_name}' + if test_opts.couple_outputs: + couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)]) + torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"), + normalize=True, range=(-1, 1)) + else: + torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"), + normalize=True, range=(-1, 1)) + torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt")) + + +def run_on_batch(inputs, net, couple_outputs=False): + w = inputs + with torch.no_grad(): + w_hat = w + 0.06 * net.mapper(w) + x_hat = net.decoder.synthesis(w_hat, noise_mode='const', force_fp32=True) + result_batch = (x_hat, w_hat) + if couple_outputs: + x = net.decoder.synthesis(w, noise_mode='const', force_fp32=True) + result_batch = (x_hat, w_hat, x) + return result_batch diff --git a/models/StyleCLIP/mapper/scripts/train.py b/models/StyleCLIP/mapper/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4141436fb3edee8ab5f7576fde0c0e53b529ef66 --- /dev/null +++ b/models/StyleCLIP/mapper/scripts/train.py @@ -0,0 +1,32 @@ +""" +This file runs the main training/val loop +""" +import os +import json +import sys +import pprint + +sys.path.append(".") +sys.path.append("..") + +from mapper.options.train_options import TrainOptions +from mapper.training.coach import Coach + + +def main(opts): + if os.path.exists(opts.exp_dir): + raise Exception('Oops... 
{} already exists'.format(opts.exp_dir)) + os.makedirs(opts.exp_dir, exist_ok=True) + + opts_dict = vars(opts) + pprint.pprint(opts_dict) + with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: + json.dump(opts_dict, f, indent=4, sort_keys=True) + + coach = Coach(opts) + coach.train() + + +if __name__ == '__main__': + opts = TrainOptions().parse() + main(opts) diff --git a/models/StyleCLIP/mapper/styleclip_mapper.py b/models/StyleCLIP/mapper/styleclip_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..86c04bee5744a551f4c0d31359e0de1f5492ff7e --- /dev/null +++ b/models/StyleCLIP/mapper/styleclip_mapper.py @@ -0,0 +1,76 @@ +import torch +from torch import nn +from models.StyleCLIP.mapper import latent_mappers +from models.StyleCLIP.models.stylegan2.model import Generator + + +def get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt + + +class StyleCLIPMapper(nn.Module): + + def __init__(self, opts, run_id): + super(StyleCLIPMapper, self).__init__() + self.opts = opts + # Define architecture + self.mapper = self.set_mapper() + self.run_id = run_id + + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + # Load weights if needed + self.load_weights() + + def set_mapper(self): + if self.opts.mapper_type == 'SingleMapper': + mapper = latent_mappers.SingleMapper(self.opts) + elif self.opts.mapper_type == 'LevelsMapper': + mapper = latent_mappers.LevelsMapper(self.opts) + else: + raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type)) + return mapper + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True) + + def set_G(self, new_G): + self.decoder = new_G + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.mapper(x) + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + images = self.decoder.synthesis(codes, noise_mode='const') + result_latent = None + # images, result_latent = self.decoder([codes], + # input_is_latent=input_is_latent, + # randomize_noise=randomize_noise, + # return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images diff --git a/models/StyleCLIP/mapper/training/__init__.py b/models/StyleCLIP/mapper/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/mapper/training/coach.py b/models/StyleCLIP/mapper/training/coach.py new file mode 100644 index 0000000000000000000000000000000000000000..fd38eb226106a21e19beb306cd9b0de6a1e7db04 --- /dev/null +++ b/models/StyleCLIP/mapper/training/coach.py @@ -0,0 +1,242 @@ +import os + +import clip +import torch +import torchvision +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import criteria.clip_loss as 
clip_loss +from criteria import id_loss +from mapper.datasets.latents_dataset import LatentsDataset +from mapper.styleclip_mapper import StyleCLIPMapper +from mapper.training.ranger import Ranger +from mapper.training import train_utils + + +class Coach: + def __init__(self, opts): + self.opts = opts + + self.global_step = 0 + + self.device = 'cuda:0' + self.opts.device = self.device + + # Initialize network + self.net = StyleCLIPMapper(self.opts).to(self.device) + + # Initialize loss + if self.opts.id_lambda > 0: + self.id_loss = id_loss.IDLoss(self.opts).to(self.device).eval() + if self.opts.clip_lambda > 0: + self.clip_loss = clip_loss.CLIPLoss(opts) + if self.opts.latent_l2_lambda > 0: + self.latent_l2_loss = nn.MSELoss().to(self.device).eval() + + # Initialize optimizer + self.optimizer = self.configure_optimizers() + + # Initialize dataset + self.train_dataset, self.test_dataset = self.configure_datasets() + self.train_dataloader = DataLoader(self.train_dataset, + batch_size=self.opts.batch_size, + shuffle=True, + num_workers=int(self.opts.workers), + drop_last=True) + self.test_dataloader = DataLoader(self.test_dataset, + batch_size=self.opts.test_batch_size, + shuffle=False, + num_workers=int(self.opts.test_workers), + drop_last=True) + + self.text_inputs = torch.cat([clip.tokenize(self.opts.description)]).cuda() + + # Initialize logger + log_dir = os.path.join(opts.exp_dir, 'logs') + os.makedirs(log_dir, exist_ok=True) + self.log_dir = log_dir + self.logger = SummaryWriter(log_dir=log_dir) + + # Initialize checkpoint dir + self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.best_val_loss = None + if self.opts.save_interval is None: + self.opts.save_interval = self.opts.max_steps + + def train(self): + self.net.train() + while self.global_step < self.opts.max_steps: + for batch_idx, batch in enumerate(self.train_dataloader): + self.optimizer.zero_grad() + w = batch + w = w.to(self.device) + with torch.no_grad(): + x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1) + w_hat = w + 0.1 * self.net.mapper(w) + x_hat, w_hat = self.net.decoder([w_hat], input_is_latent=True, return_latents=True, randomize_noise=False, truncation=1) + loss, loss_dict = self.calc_loss(w, x, w_hat, x_hat) + loss.backward() + self.optimizer.step() + + # Logging related + if self.global_step % self.opts.image_interval == 0 or ( + self.global_step < 1000 and self.global_step % 1000 == 0): + self.parse_and_log_images(x, x_hat, title='images_train') + if self.global_step % self.opts.board_interval == 0: + self.print_metrics(loss_dict, prefix='train') + self.log_metrics(loss_dict, prefix='train') + + # Validation related + val_loss_dict = None + if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps: + val_loss_dict = self.validate() + if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss): + self.best_val_loss = val_loss_dict['loss'] + self.checkpoint_me(val_loss_dict, is_best=True) + + if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: + if val_loss_dict is not None: + self.checkpoint_me(val_loss_dict, is_best=False) + else: + self.checkpoint_me(loss_dict, is_best=False) + + if self.global_step == self.opts.max_steps: + print('OMG, finished training!') + break + + self.global_step += 1 + + def validate(self): + self.net.eval() + agg_loss_dict = [] + for batch_idx, batch in 
enumerate(self.test_dataloader): + if batch_idx > 200: + break + + w = batch + + with torch.no_grad(): + w = w.to(self.device).float() + x, _ = self.net.decoder([w], input_is_latent=True, randomize_noise=True, truncation=1) + w_hat = w + 0.1 * self.net.mapper(w) + x_hat, _ = self.net.decoder([w_hat], input_is_latent=True, randomize_noise=True, truncation=1) + loss, cur_loss_dict = self.calc_loss(w, x, w_hat, x_hat) + agg_loss_dict.append(cur_loss_dict) + + # Logging related + self.parse_and_log_images(x, x_hat, title='images_val', index=batch_idx) + + # For first step just do sanity test on small amount of data + if self.global_step == 0 and batch_idx >= 4: + self.net.train() + return None # Do not log, inaccurate in first batch + + loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict) + self.log_metrics(loss_dict, prefix='test') + self.print_metrics(loss_dict, prefix='test') + + self.net.train() + return loss_dict + + def checkpoint_me(self, loss_dict, is_best): + save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step) + save_dict = self.__get_save_dict() + checkpoint_path = os.path.join(self.checkpoint_dir, save_name) + torch.save(save_dict, checkpoint_path) + with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: + if is_best: + f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict)) + else: + f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict)) + + def configure_optimizers(self): + params = list(self.net.mapper.parameters()) + if self.opts.optim_name == 'adam': + optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) + else: + optimizer = Ranger(params, lr=self.opts.learning_rate) + return optimizer + + def configure_datasets(self): + if self.opts.latents_train_path: + train_latents = torch.load(self.opts.latents_train_path) + else: + train_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda() + train_latents = [] + for b in range(self.opts.train_dataset_size // self.opts.batch_size): + with torch.no_grad(): + _, train_latents_b = self.net.decoder([train_latents_z[b: b + self.opts.batch_size]], + truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True) + train_latents.append(train_latents_b) + train_latents = torch.cat(train_latents) + + if self.opts.latents_test_path: + test_latents = torch.load(self.opts.latents_test_path) + else: + test_latents_z = torch.randn(self.opts.train_dataset_size, 512).cuda() + test_latents = [] + for b in range(self.opts.test_dataset_size // self.opts.test_batch_size): + with torch.no_grad(): + _, test_latents_b = self.net.decoder([test_latents_z[b: b + self.opts.test_batch_size]], + truncation=0.7, truncation_latent=self.net.latent_avg, return_latents=True) + test_latents.append(test_latents_b) + test_latents = torch.cat(test_latents) + + train_dataset_celeba = LatentsDataset(latents=train_latents.cpu(), + opts=self.opts) + test_dataset_celeba = LatentsDataset(latents=test_latents.cpu(), + opts=self.opts) + train_dataset = train_dataset_celeba + test_dataset = test_dataset_celeba + print("Number of training samples: {}".format(len(train_dataset))) + print("Number of test samples: {}".format(len(test_dataset))) + return train_dataset, test_dataset + + def calc_loss(self, w, x, w_hat, x_hat): + loss_dict = {} + loss = 0.0 + if self.opts.id_lambda > 0: + loss_id, sim_improvement = self.id_loss(x_hat, x) + loss_dict['loss_id'] = float(loss_id) + loss_dict['id_improve'] = float(sim_improvement) + loss = 
loss_id * self.opts.id_lambda + if self.opts.clip_lambda > 0: + loss_clip = self.clip_loss(x_hat, self.text_inputs).mean() + loss_dict['loss_clip'] = float(loss_clip) + loss += loss_clip * self.opts.clip_lambda + if self.opts.latent_l2_lambda > 0: + loss_l2_latent = self.latent_l2_loss(w_hat, w) + loss_dict['loss_l2_latent'] = float(loss_l2_latent) + loss += loss_l2_latent * self.opts.latent_l2_lambda + loss_dict['loss'] = float(loss) + return loss, loss_dict + + def log_metrics(self, metrics_dict, prefix): + for key, value in metrics_dict.items(): + #pass + print(f"step: {self.global_step} \t metric: {prefix}/{key} \t value: {value}") + self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step) + + def print_metrics(self, metrics_dict, prefix): + print('Metrics for {}, step {}'.format(prefix, self.global_step)) + for key, value in metrics_dict.items(): + print('\t{} = '.format(key), value) + + def parse_and_log_images(self, x, x_hat, title, index=None): + if index is None: + path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}.jpg') + else: + path = os.path.join(self.log_dir, title, f'{str(self.global_step).zfill(5)}_{str(index).zfill(5)}.jpg') + os.makedirs(os.path.dirname(path), exist_ok=True) + torchvision.utils.save_image(torch.cat([x.detach().cpu(), x_hat.detach().cpu()]), path, + normalize=True, scale_each=True, range=(-1, 1), nrow=self.opts.batch_size) + + def __get_save_dict(self): + save_dict = { + 'state_dict': self.net.state_dict(), + 'opts': vars(self.opts) + } + return save_dict \ No newline at end of file diff --git a/models/StyleCLIP/mapper/training/ranger.py b/models/StyleCLIP/mapper/training/ranger.py new file mode 100644 index 0000000000000000000000000000000000000000..9442fd10d42fcc19f4e0dd798d1573b31ed2c0a0 --- /dev/null +++ b/models/StyleCLIP/mapper/training/ranger.py @@ -0,0 +1,164 @@ +# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. + +# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer +# and/or +# https://github.com/lessw2020/Best-Deep-Learning-Optimizers + +# Ranger has now been used to capture 12 records on the FastAI leaderboard. + +# This version = 20.4.11 + +# Credits: +# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization +# RAdam --> https://github.com/LiyuanLucasLiu/RAdam +# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. +# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 + +# summary of changes: +# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. +# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), +# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. +# changes 8/31/19 - fix references to *self*.N_sma_threshold; +# changed eps to 1e-5 as better default than 1e-8. 
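+ # Usage sketch (added note, not part of the upstream Ranger file; `net` and `opts`
+ # stand for the Coach's model and options defined in coach.py above): Ranger
+ # subclasses torch.optim.Optimizer, so it is a drop-in replacement for Adam, e.g.
+ #   optimizer = Ranger(net.mapper.parameters(), lr=opts.learning_rate)
+ #   loss.backward()
+ #   optimizer.step()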
+ +import math +import torch +from torch.optim.optimizer import Optimizer + + +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, # lr + alpha=0.5, k=6, N_sma_threshhold=5, # Ranger configs + betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam configs + use_gc=True, gc_conv_only=False + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + ): + + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. + + # prep defaults and init torch.optim base + defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, + eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # look ahead params + + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # gc on or off + self.use_gc = use_gc + + # level of gradient centralization + self.gc_gradient_threshold = 3 if gc_conv_only else 1 + + def __setstate__(self, state): + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + + if grad.is_sparse: + raise RuntimeError('Ranger optimizer does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if len(state) == 0: # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state['slow_buffer'] = torch.empty_like(p.data) + state['slow_buffer'].copy_(p.data) + + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # GC operation for Conv layers and FC layers + if grad.dim() > self.gc_gradient_threshold: + grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) + + state['step'] += 1 + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + # compute mean moving avg + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + buffered = self.radam_buffer[int(state['step'] % 10)] + + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) * 
(N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # apply lr + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... + # we do it at the param level instead of group level + if state['step'] % group['k'] == 0: + slow_p = state['slow_buffer'] # get access to slow param tensor + slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha + p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor + + return loss \ No newline at end of file diff --git a/models/StyleCLIP/mapper/training/train_utils.py b/models/StyleCLIP/mapper/training/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c55177f7442010bc1fcc64de3d142585c22adc0 --- /dev/null +++ b/models/StyleCLIP/mapper/training/train_utils.py @@ -0,0 +1,13 @@ + +def aggregate_loss_dict(agg_loss_dict): + mean_vals = {} + for output in agg_loss_dict: + for key in output: + mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] + for key in mean_vals: + if len(mean_vals[key]) > 0: + mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) + else: + print('{} has no value'.format(key)) + mean_vals[key] = 0 + return mean_vals diff --git a/models/StyleCLIP/models/__init__.py b/models/StyleCLIP/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/models/facial_recognition/__init__.py b/models/StyleCLIP/models/facial_recognition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/models/facial_recognition/helpers.py b/models/StyleCLIP/models/facial_recognition/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..feb7b1cec9efc54f71c90a91fd010fe12e76e1a9 --- /dev/null +++ b/models/StyleCLIP/models/facial_recognition/helpers.py @@ -0,0 +1,119 @@ +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. 
""" + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + seltorch.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = seltorch.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/models/StyleCLIP/models/facial_recognition/model_irse.py b/models/StyleCLIP/models/facial_recognition/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c79e0366e4a6fd92011e86df80f8b31ec671ae --- /dev/null +++ b/models/StyleCLIP/models/facial_recognition/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from models.facial_recognition.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation 
from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/models/StyleCLIP/models/stylegan2/__init__.py b/models/StyleCLIP/models/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/StyleCLIP/models/stylegan2/model.py b/models/StyleCLIP/models/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5559203f4f3843fc814b090780ffa129a6fdf0 --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/model.py @@ -0,0 +1,674 @@ +import math +import random + +import torch +from torch import nn +from torch.nn import functional as F + +from models.StyleCLIP.models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + 
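# (added comment) the separable blur kernel is scaled by factor ** 2 so that signal magnitude is preserved after the zero-insertion upsampling performed in upfirdn2d +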
self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, 
pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = 
self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = 
styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, 
stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out + diff --git a/models/StyleCLIP/models/stylegan2/op/__init__.py b/models/StyleCLIP/models/stylegan2/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/StyleCLIP/models/stylegan2/op/fused_act.py b/models/StyleCLIP/models/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2d575bc9198e6d46eee040eb374c6d8f64c3363c --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/op/fused_act.py @@ -0,0 +1,40 @@ +import os + +import torch +from torch import nn +from torch.nn import functional as F + +module_path = os.path.dirname(__file__) + + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + rest_dim = [1] * (input.ndim - bias.ndim - 1) + input = input.cuda() + if input.ndim == 3: + return ( + F.leaky_relu( + input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope + ) + * scale + ) + else: + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope + ) + * scale + ) + diff --git a/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py b/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..02fc25af780868d9b883631eb6b03a25c225d745 --- /dev/null +++ b/models/StyleCLIP/models/stylegan2/op/upfirdn2d.py @@ -0,0 +1,60 @@ +import os + +import torch +from torch.nn import functional as F + + +module_path = os.path.dirname(__file__) + + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 
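+ # (added comment) out_h / out_w are the expected spatial sizes after upsampling by
+ # (up_y, up_x), padding, correlating with the kernel_h x kernel_w kernel, and
+ # subsampling by (down_y, down_x)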
+ + return out.view(-1, channel, out_h, out_w) \ No newline at end of file diff --git a/models/StyleCLIP/optimization/run_optimization.py b/models/StyleCLIP/optimization/run_optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..766d0c81400951202bed51e3f1812e1260ccf071 --- /dev/null +++ b/models/StyleCLIP/optimization/run_optimization.py @@ -0,0 +1,128 @@ +import argparse +import math +import os +import pickle + +import torch +import torchvision +from torch import optim +from tqdm import tqdm + +from StyleCLIP.criteria.clip_loss import CLIPLoss +from StyleCLIP.models.stylegan2.model import Generator +import clip +from StyleCLIP.utils import ensure_checkpoint_exists + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + + return initial_lr * lr_ramp + + +def main(args, use_old_G): + ensure_checkpoint_exists(args.ckpt) + text_inputs = torch.cat([clip.tokenize(args.description)]).cuda() + os.makedirs(args.results_dir, exist_ok=True) + new_generator_path = f'/disk2/danielroich/Sandbox/stylegan2_ada_pytorch/checkpoints/model_{args.run_id}_{args.image_name}.pt' + old_generator_path = '/disk2/danielroich/Sandbox/pretrained_models/ffhq.pkl' + + if not use_old_G: + with open(new_generator_path, 'rb') as f: + G = torch.load(f).cuda().eval() + else: + with open(old_generator_path, 'rb') as f: + G = pickle.load(f)['G_ema'].cuda().eval() + + if args.latent_path: + latent_code_init = torch.load(args.latent_path).cuda() + elif args.mode == "edit": + latent_code_init_not_trunc = torch.randn(1, 512).cuda() + with torch.no_grad(): + latent_code_init = G.mapping(latent_code_init_not_trunc, None) + + latent = latent_code_init.detach().clone() + latent.requires_grad = True + + clip_loss = CLIPLoss(args) + + optimizer = optim.Adam([latent], lr=args.lr) + + pbar = tqdm(range(args.step)) + + for i in pbar: + t = i / args.step + lr = get_lr(t, args.lr) + optimizer.param_groups[0]["lr"] = lr + + img_gen = G.synthesis(latent, noise_mode='const') + + c_loss = clip_loss(img_gen, text_inputs) + + if args.mode == "edit": + l2_loss = ((latent_code_init - latent) ** 2).sum() + loss = c_loss + args.l2_lambda * l2_loss + else: + loss = c_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description( + ( + f"loss: {loss.item():.4f};" + ) + ) + if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0: + with torch.no_grad(): + img_gen = G.synthesis(latent, noise_mode='const') + + torchvision.utils.save_image(img_gen, + f"/disk2/danielroich/Sandbox/StyleCLIP/results/inference_results/{str(i).zfill(5)}.png", + normalize=True, range=(-1, 1)) + + if args.mode == "edit": + with torch.no_grad(): + img_orig = G.synthesis(latent_code_init, noise_mode='const') + + final_result = torch.cat([img_orig, img_gen]) + else: + final_result = img_gen + + return final_result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--description", type=str, default="a person with purple hair", + help="the text that guides the editing/generation") + parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt", + help="pretrained StyleGAN2 weights") + parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution") + parser.add_argument("--lr_rampup", type=float, default=0.05) + parser.add_argument("--lr", type=float, 
default=0.1) + parser.add_argument("--step", type=int, default=300, help="number of optimization steps") + parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"], + help="choose between edit an image an generate a free one") + parser.add_argument("--l2_lambda", type=float, default=0.008, + help="weight of the latent distance (used for editing only)") + parser.add_argument("--latent_path", type=str, default=None, + help="starts the optimization from the given latent code if provided. Otherwose, starts from" + "the mean latent in a free generation, and from a random one in editing. " + "Expects a .pt format") + parser.add_argument("--truncation", type=float, default=0.7, + help="used only for the initial latent vector, and only when a latent code path is" + "not provided") + parser.add_argument("--save_intermediate_image_every", type=int, default=20, + help="if > 0 then saves intermidate results during the optimization") + parser.add_argument("--results_dir", type=str, default="results") + + args = parser.parse_args() + + result_image = main(args) + + torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"), + normalize=True, scale_each=True, range=(-1, 1)) diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/__pycache__/__init__.cpython-36.pyc b/models/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5552d0c1c5a1466f571cf01a4cd23d296e98633 Binary files /dev/null and b/models/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb15d6aa77f804e148bd6a3a084d31a021efbda1 Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/e4e/__init__.py b/models/e4e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/e4e/__pycache__/__init__.cpython-36.pyc b/models/e4e/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e254b21efd74265c6f5da25a7b066314b6af50de Binary files /dev/null and b/models/e4e/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/e4e/__pycache__/__init__.cpython-39.pyc b/models/e4e/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d66a397405b688424c78f42938abcd1651e68387 Binary files /dev/null and b/models/e4e/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/e4e/__pycache__/psp.cpython-36.pyc b/models/e4e/__pycache__/psp.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..339288a87e21039ffde208413b5c593badd54060 Binary files /dev/null and b/models/e4e/__pycache__/psp.cpython-36.pyc differ diff --git a/models/e4e/__pycache__/psp.cpython-39.pyc b/models/e4e/__pycache__/psp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..222521bf5b87d1d77d6dcf2b2f069440106fc430 Binary files /dev/null and b/models/e4e/__pycache__/psp.cpython-39.pyc differ diff --git a/models/e4e/discriminator.py b/models/e4e/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..16bf3722c7f2e35cdc9bd177a33ed0975e67200d --- /dev/null +++ 
b/models/e4e/discriminator.py @@ -0,0 +1,20 @@ +from torch import nn + + +class LatentCodesDiscriminator(nn.Module): + def __init__(self, style_dim, n_mlp): + super().__init__() + + self.style_dim = style_dim + + layers = [] + for i in range(n_mlp-1): + layers.append( + nn.Linear(style_dim, style_dim) + ) + layers.append(nn.LeakyReLU(0.2)) + layers.append(nn.Linear(512, 1)) + self.mlp = nn.Sequential(*layers) + + def forward(self, w): + return self.mlp(w) diff --git a/models/e4e/encoders/__init__.py b/models/e4e/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/e4e/encoders/__pycache__/__init__.cpython-36.pyc b/models/e4e/encoders/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b4e2fb03cc70eef5ca7f29877162a9a5f45b0de Binary files /dev/null and b/models/e4e/encoders/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/e4e/encoders/__pycache__/__init__.cpython-39.pyc b/models/e4e/encoders/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c1c6e49609f89568a3301798a70e1fb2e1f2f07 Binary files /dev/null and b/models/e4e/encoders/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/e4e/encoders/__pycache__/helpers.cpython-36.pyc b/models/e4e/encoders/__pycache__/helpers.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bae725e327c1c74de623303675b8ded7f66aeeaa Binary files /dev/null and b/models/e4e/encoders/__pycache__/helpers.cpython-36.pyc differ diff --git a/models/e4e/encoders/__pycache__/helpers.cpython-39.pyc b/models/e4e/encoders/__pycache__/helpers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e8862f662d0fac28bb589d9bc2e9b22c4636a88 Binary files /dev/null and b/models/e4e/encoders/__pycache__/helpers.cpython-39.pyc differ diff --git a/models/e4e/encoders/__pycache__/psp_encoders.cpython-36.pyc b/models/e4e/encoders/__pycache__/psp_encoders.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e7a646a4ddcc5f37de092dceadd46169db0d153 Binary files /dev/null and b/models/e4e/encoders/__pycache__/psp_encoders.cpython-36.pyc differ diff --git a/models/e4e/encoders/__pycache__/psp_encoders.cpython-39.pyc b/models/e4e/encoders/__pycache__/psp_encoders.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af724868f389e7fe4d0d4905de8a58838110139c Binary files /dev/null and b/models/e4e/encoders/__pycache__/psp_encoders.cpython-39.pyc differ diff --git a/models/e4e/encoders/helpers.py b/models/e4e/encoders/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..cf31d3c16b1d2df4c34390d5aa1141398a4aa5cd --- /dev/null +++ b/models/e4e/encoders/helpers.py @@ -0,0 +1,140 @@ +from collections import namedtuple +import torch +import torch.nn.functional as F +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. 
""" + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + seltorch.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = seltorch.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +def _upsample_add(x, y): + """Upsample and add two feature maps. + Args: + x: (Variable) top feature map to be upsampled. + y: (Variable) lateral feature map. + Returns: + (Variable) added feature map. + Note in PyTorch, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. 
+ """ + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y diff --git a/models/e4e/encoders/model_irse.py b/models/e4e/encoders/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..976ce2c61104efdc6b0015d895830346dd01bc10 --- /dev/null +++ b/models/e4e/encoders/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from encoder4editing.models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/models/e4e/encoders/psp_encoders.py b/models/e4e/encoders/psp_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7c70e5e2586bd6a0de825e45a80e9116156166 --- /dev/null +++ b/models/e4e/encoders/psp_encoders.py @@ -0,0 +1,200 @@ +from enum import Enum +import math +import numpy as np +import torch +from torch import nn +from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module + +from models.e4e.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add +from models.e4e.stylegan2.model import 
EqualLinear + + +class ProgressiveStage(Enum): + WTraining = 0 + Delta1Training = 1 + Delta2Training = 2 + Delta3Training = 3 + Delta4Training = 4 + Delta5Training = 5 + Delta6Training = 6 + Delta7Training = 7 + Delta8Training = 8 + Delta9Training = 9 + Delta10Training = 10 + Delta11Training = 11 + Delta12Training = 12 + Delta13Training = 13 + Delta14Training = 14 + Delta15Training = 15 + Delta16Training = 16 + Delta17Training = 17 + Inference = 18 + + +class GradualStyleBlock(Module): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.view(-1, self.out_c) + x = self.linear(x) + return x + + +class GradualStyleEncoder(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = _upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = _upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = torch.stack(latents, dim=1) + return out + + +class Encoder4Editing(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(Encoder4Editing, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + 
for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + self.progressive_stage = ProgressiveStage.Inference + + def get_deltas_starting_dimensions(self): + ''' Get a list of the initial dimension of every delta from which it is applied ''' + return list(range(self.style_count)) # Each dimension has a delta applied to it + + def set_progressive_stage(self, new_stage: ProgressiveStage): + self.progressive_stage = new_stage + print('Changed progressive stage to: ', new_stage) + + def forward(self, x): + x = self.input_layer(x) + + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + # Infer main W and duplicate it + w0 = self.styles[0](c3) + w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) + stage = self.progressive_stage.value + features = c3 + for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas + if i == self.coarse_ind: + p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features + features = p2 + elif i == self.middle_ind: + p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features + features = p1 + delta_i = self.styles[i](features) + w[:, i] += delta_i + return w diff --git a/models/e4e/latent_codes_pool.py b/models/e4e/latent_codes_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..0281d4b5e80f8eb26e824fa35b4f908dcb6634e6 --- /dev/null +++ b/models/e4e/latent_codes_pool.py @@ -0,0 +1,55 @@ +import random +import torch + + +class LatentCodesPool: + """This class implements latent codes buffer that stores previously generated w latent codes. + This buffer enables us to update discriminators using a history of generated w's + rather than the ones produced by the latest encoder. + """ + + def __init__(self, pool_size): + """Initialize the ImagePool class + Parameters: + pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created + """ + self.pool_size = pool_size + if self.pool_size > 0: # create an empty pool + self.num_ws = 0 + self.ws = [] + + def query(self, ws): + """Return w's from the pool. + Parameters: + ws: the latest generated w's from the generator + Returns w's from the buffer. + By 50/100, the buffer will return input w's. + By 50/100, the buffer will return w's previously stored in the buffer, + and insert the current w's to the buffer. 
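+        In other words: with probability 0.5 the incoming w is returned unchanged, and with
+        probability 0.5 a randomly chosen stored w is returned instead while the incoming one
+        takes its place in the buffer (the same scheme as the CycleGAN-style ImagePool above).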
+ """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return ws + return_ws = [] + for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512) + # w = torch.unsqueeze(image.data, 0) + if w.ndim == 2: + i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate + w = w[i] + self.handle_w(w, return_ws) + return_ws = torch.stack(return_ws, 0) # collect all the images and return + return return_ws + + def handle_w(self, w, return_ws): + if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer + self.num_ws = self.num_ws + 1 + self.ws.append(w) + return_ws.append(w) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.ws[random_id].clone() + self.ws[random_id] = w + return_ws.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_ws.append(w) diff --git a/models/e4e/psp.py b/models/e4e/psp.py new file mode 100644 index 0000000000000000000000000000000000000000..bf9f75dbaa66997abfc1e3e0e4f19ddfec7fedac --- /dev/null +++ b/models/e4e/psp.py @@ -0,0 +1,97 @@ +import matplotlib +from configs import paths_config +matplotlib.use('Agg') +import torch +from torch import nn +from models.e4e.encoders import psp_encoders +from models.e4e.stylegan2.model import Generator + + +def get_keys(d, name): + if 'state_dict' in d: + d = d['state_dict'] + d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} + return d_filt + + +class pSp(nn.Module): + + def __init__(self, opts): + super(pSp, self).__init__() + self.opts = opts + # Define architecture + self.encoder = self.set_encoder() + self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2) + self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + # Load weights if needed + self.load_weights() + + def set_encoder(self): + if self.opts.encoder_type == 'GradualStyleEncoder': + encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) + elif self.opts.encoder_type == 'Encoder4Editing': + encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts) + else: + raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type)) + return encoder + + def load_weights(self): + if self.opts.checkpoint_path is not None: + print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) + self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) + self.__load_latent_avg(ckpt) + else: + print('Loading encoders weights from irse50!') + encoder_ckpt = torch.load(paths_config.ir_se50) + self.encoder.load_state_dict(encoder_ckpt, strict=False) + print('Loading decoder weights from pretrained!') + ckpt = torch.load(self.opts.stylegan_weights) + self.decoder.load_state_dict(ckpt['g_ema'], strict=False) + self.__load_latent_avg(ckpt, repeat=self.encoder.style_count) + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.encoder(x) + # normalize with respect to the center of an average face + if self.opts.start_from_latent_avg: + if codes.ndim == 2: + codes = codes + 
self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] + else: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + images, result_latent = self.decoder([codes], + input_is_latent=input_is_latent, + randomize_noise=randomize_noise, + return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images + + def __load_latent_avg(self, ckpt, repeat=None): + if 'latent_avg' in ckpt: + self.latent_avg = ckpt['latent_avg'].to(self.opts.device) + if repeat is not None: + self.latent_avg = self.latent_avg.repeat(repeat, 1) + else: + self.latent_avg = None diff --git a/models/e4e/stylegan2/__init__.py b/models/e4e/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/e4e/stylegan2/__pycache__/__init__.cpython-36.pyc b/models/e4e/stylegan2/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68d3433763d37454e99721869433a5187adb571e Binary files /dev/null and b/models/e4e/stylegan2/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/e4e/stylegan2/__pycache__/__init__.cpython-39.pyc b/models/e4e/stylegan2/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f575df3e18085b8f650360106f861630fddbec7 Binary files /dev/null and b/models/e4e/stylegan2/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/e4e/stylegan2/__pycache__/model.cpython-36.pyc b/models/e4e/stylegan2/__pycache__/model.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..345213714551ed617ef44360745c2af608ac31bc Binary files /dev/null and b/models/e4e/stylegan2/__pycache__/model.cpython-36.pyc differ diff --git a/models/e4e/stylegan2/__pycache__/model.cpython-39.pyc b/models/e4e/stylegan2/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36920e45f0c9cc4c7ce00ac3a364cba8c6ae3201 Binary files /dev/null and b/models/e4e/stylegan2/__pycache__/model.cpython-39.pyc differ diff --git a/models/e4e/stylegan2/model.py b/models/e4e/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ede4360148e260363887662bae7fe68c987ee60e --- /dev/null +++ b/models/e4e/stylegan2/model.py @@ -0,0 +1,674 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from .op.fused_act import FusedLeakyReLU, fused_leaky_relu +from .op.upfirdn2d import upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) 
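+        # e.g. with the default blur kernel [1, 3, 3, 1] and factor=2: make_kernel yields a 4x4 tap,
+        # so p = 4 - 2 = 2, pad0 = (2 + 1) // 2 + 2 - 1 = 2 and pad1 = 2 // 2 = 1, i.e. pad = (2, 1),
+        # which makes upfirdn2d return an output exactly factor times the input resolution.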
+ + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = 
p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + 
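+        # ToRGB maps features to RGB with a 1x1 modulated conv (demodulate=False) and, when a skip
+        # image is passed in, upsamples the coarser RGB output so it can be accumulated here.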
+ if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + return_features=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = 
random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + elif return_features: + return image, out + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = 
self.final_linear(out) + + return out diff --git a/models/e4e/stylegan2/op/__init__.py b/models/e4e/stylegan2/op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/models/e4e/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/e4e/stylegan2/op/__pycache__/__init__.cpython-36.pyc b/models/e4e/stylegan2/op/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d562dab67e9de18d463f537d5f04e07fe1c7a0eb Binary files /dev/null and b/models/e4e/stylegan2/op/__pycache__/__init__.cpython-36.pyc differ diff --git a/models/e4e/stylegan2/op/__pycache__/__init__.cpython-39.pyc b/models/e4e/stylegan2/op/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..618ee34d3a56161c5731b7c265464949c01352e0 Binary files /dev/null and b/models/e4e/stylegan2/op/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/e4e/stylegan2/op/__pycache__/fused_act.cpython-36.pyc b/models/e4e/stylegan2/op/__pycache__/fused_act.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da31273be531830a4019ed5556ac270408d811fd Binary files /dev/null and b/models/e4e/stylegan2/op/__pycache__/fused_act.cpython-36.pyc differ diff --git a/models/e4e/stylegan2/op/__pycache__/fused_act.cpython-39.pyc b/models/e4e/stylegan2/op/__pycache__/fused_act.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..480f419f20318a2f2b19903e477d029c9036e462 Binary files /dev/null and b/models/e4e/stylegan2/op/__pycache__/fused_act.cpython-39.pyc differ diff --git a/models/e4e/stylegan2/op/__pycache__/upfirdn2d.cpython-36.pyc b/models/e4e/stylegan2/op/__pycache__/upfirdn2d.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4fcc2bb8722b9ce46298d3d64c74b61669525ee Binary files /dev/null and b/models/e4e/stylegan2/op/__pycache__/upfirdn2d.cpython-36.pyc differ diff --git a/models/e4e/stylegan2/op/__pycache__/upfirdn2d.cpython-39.pyc b/models/e4e/stylegan2/op/__pycache__/upfirdn2d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..389ba4ec5ffc37db254315501c75fa5f5073876f Binary files /dev/null and b/models/e4e/stylegan2/op/__pycache__/upfirdn2d.cpython-39.pyc differ diff --git a/models/e4e/stylegan2/op/fused_act.py b/models/e4e/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..90949545ba955dabf2e17d8cf5e524d5cb190a63 --- /dev/null +++ b/models/e4e/stylegan2/op/fused_act.py @@ -0,0 +1,34 @@ +import os + +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function + + +module_path = os.path.dirname(__file__) + + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + rest_dim = [1] * (input.ndim - bias.ndim - 1) + input = input.cuda() + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope + ) + * scale + ) + diff --git 
a/models/e4e/stylegan2/op/fused_bias_act.cpp b/models/e4e/stylegan2/op/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949 --- /dev/null +++ b/models/e4e/stylegan2/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/models/e4e/stylegan2/op/fused_bias_act_kernel.cu b/models/e4e/stylegan2/op/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8 --- /dev/null +++ b/models/e4e/stylegan2/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/models/e4e/stylegan2/op/upfirdn2d.cpp b/models/e4e/stylegan2/op/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e --- /dev/null +++ b/models/e4e/stylegan2/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/models/e4e/stylegan2/op/upfirdn2d.py b/models/e4e/stylegan2/op/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..02fc25af780868d9b883631eb6b03a25c225d745 --- /dev/null +++ b/models/e4e/stylegan2/op/upfirdn2d.py @@ -0,0 +1,60 @@ +import os + +import torch +from torch.nn import functional as F + + +module_path = os.path.dirname(__file__) + + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = upfirdn2d_native( + input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // 
down_x + 1 + + return out.view(-1, channel, out_h, out_w) \ No newline at end of file diff --git a/models/e4e/stylegan2/op/upfirdn2d_kernel.cu b/models/e4e/stylegan2/op/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e --- /dev/null +++ b/models/e4e/stylegan2/op/upfirdn2d_kernel.cu @@ -0,0 +1,272 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + + +template +__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = 
tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + + #pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) + #pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; + } + } + } + } +} + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; + + auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h; + int tile_out_w; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 2: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 3: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 4: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 5: + upfirdn2d_kernel<<>>( + 
out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 6: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + } + }); + + return out; +} \ No newline at end of file diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/__pycache__/__init__.cpython-36.pyc b/scripts/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1e0698746f8f73936b3ff4da15e898c725803cf Binary files /dev/null and b/scripts/__pycache__/__init__.cpython-36.pyc differ diff --git a/scripts/__pycache__/__init__.cpython-39.pyc b/scripts/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0e63aadc2c17779d5435e653abb0e8a7edcfcc Binary files /dev/null and b/scripts/__pycache__/__init__.cpython-39.pyc differ diff --git a/scripts/__pycache__/latent_editor_wrapper.cpython-39.pyc b/scripts/__pycache__/latent_editor_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92ba3f4f2136f2ae96204eac9a3fbbba24b59dae Binary files /dev/null and b/scripts/__pycache__/latent_editor_wrapper.cpython-39.pyc differ diff --git a/scripts/__pycache__/run_pti.cpython-36.pyc b/scripts/__pycache__/run_pti.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c98fd8a4a3f386660b6f5c0cb62bba00a30ba8e Binary files /dev/null and b/scripts/__pycache__/run_pti.cpython-36.pyc differ diff --git a/scripts/__pycache__/run_pti.cpython-39.pyc b/scripts/__pycache__/run_pti.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d90bf53985c0e872bcbfaa5c167b73dab42d6073 Binary files /dev/null and b/scripts/__pycache__/run_pti.cpython-39.pyc differ diff --git a/scripts/latent_creators/__init__.py b/scripts/latent_creators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/latent_creators/base_latent_creator.py b/scripts/latent_creators/base_latent_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fcdbdb95541f9aa0485039f9e46d04b3007fcf --- /dev/null +++ b/scripts/latent_creators/base_latent_creator.py @@ -0,0 +1,62 @@ +import abc +import logging + +import pickle + +import os +from random import choice +from string import ascii_uppercase +import torch +from torch.utils.data import DataLoader +import wandb +from configs import global_config, paths_config +from tqdm import tqdm + +from torchvision import transforms + +from utils.ImagesDataset import ImagesDataset + + +class BaseLatentCreator: + + def __init__(self, method_name, dara_preprocess=None, use_wandb=False): + global_config.run_name = ''.join(choice(ascii_uppercase) for i in range(12)) + self.use_wandb = use_wandb + if use_wandb: + run = wandb.init(project="personalized_stylegan", reinit=True, name=global_config.run_name) + + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices + + if dara_preprocess is None: + self.projection_preprocess = transforms.Compose([ + transforms.Resize(1024), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + else: + self.projection_preprocess = dara_preprocess + + image_dataset = ImagesDataset(f'{paths_config.input_data_path}', self.projection_preprocess) + self.image_dataloader = 
DataLoader(image_dataset, batch_size=1, shuffle=False) + + base_latent_folder_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' + os.makedirs(base_latent_folder_path, exist_ok=True) + self.latent_folder_path = f'{base_latent_folder_path}/{method_name}' + os.makedirs(self.latent_folder_path, exist_ok=True) + + with open(paths_config.stylegan2_ada_ffhq, 'rb') as f: + self.old_G = pickle.load(f)['G_ema'].cuda() + + @abc.abstractmethod + def run_projection(self, fname, image): + return None + + def create_latents(self): + for fname, image in tqdm(self.image_dataloader): + fname = fname[0] + cur_latent_folder_path = f'{self.latent_folder_path}/{fname}' + image = image.cuda() + w = self.run_projection(fname, image) + + os.makedirs(cur_latent_folder_path, exist_ok=True) + torch.save(w, f'{cur_latent_folder_path}/0.pt') diff --git a/scripts/latent_creators/e4e_latent_creator.py b/scripts/latent_creators/e4e_latent_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..5726a3e286374020609a1d58708fa2659ba73b22 --- /dev/null +++ b/scripts/latent_creators/e4e_latent_creator.py @@ -0,0 +1,44 @@ +import torch +from argparse import Namespace +from torchvision.transforms import transforms + +from configs import paths_config +from models.e4e.psp import pSp +from scripts.latent_creators.base_latent_creator import BaseLatentCreator +from utils.log_utils import log_image_from_w + + +class E4ELatentCreator(BaseLatentCreator): + + def __init__(self, use_wandb=False): + self.e4e_inversion_pre_process = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + + super().__init__('e4e', self.e4e_inversion_pre_process, use_wandb=use_wandb) + + e4e_model_path = paths_config.e4e + ckpt = torch.load(e4e_model_path, map_location='cpu') + opts = ckpt['opts'] + opts['batch_size'] = 1 + opts['checkpoint_path'] = e4e_model_path + opts = Namespace(**opts) + self.e4e_inversion_net = pSp(opts) + self.e4e_inversion_net.eval() + self.e4e_inversion_net = self.e4e_inversion_net.cuda() + + def run_projection(self, fname, image): + _, e4e_image_latent = self.e4e_inversion_net(image, randomize_noise=False, return_latents=True, + resize=False, + input_code=False) + + if self.use_wandb: + log_image_from_w(e4e_image_latent, self.old_G, 'First e4e inversion') + + return e4e_image_latent + + +if __name__ == '__main__': + e4e_latent_creator = E4ELatentCreator() + e4e_latent_creator.create_latents() diff --git a/scripts/latent_creators/sg2_latent_creator.py b/scripts/latent_creators/sg2_latent_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..5c33704038a5a1bbc410be8caf9aa317346f400a --- /dev/null +++ b/scripts/latent_creators/sg2_latent_creator.py @@ -0,0 +1,24 @@ +import torch +from configs import global_config, paths_config +from scripts.latent_creators.base_latent_creator import BaseLatentCreator +from training.projectors import w_projector + + +class SG2LatentCreator(BaseLatentCreator): + + def __init__(self, use_wandb=False, projection_steps=600): + super().__init__(paths_config.sg2_results_keyword, use_wandb=use_wandb) + + self.projection_steps = projection_steps + + def run_projection(self, fname, image): + image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255 + w = w_projector.project(self.old_G, image, device=torch.device(global_config.device), + num_steps=self.projection_steps, w_name=fname, use_wandb=self.use_wandb) + + return w + + +if __name__ == 
'__main__': + id_change_report = SG2LatentCreator() + id_change_report.create_latents() diff --git a/scripts/latent_creators/sg2_plus_latent_creator.py b/scripts/latent_creators/sg2_plus_latent_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..3f11ce883699a6801510712c8ee2eaa7a63fac1d --- /dev/null +++ b/scripts/latent_creators/sg2_plus_latent_creator.py @@ -0,0 +1,24 @@ +import torch +from configs import global_config, paths_config +from scripts.latent_creators.base_latent_creator import BaseLatentCreator +from training.projectors import w_plus_projector + + +class SG2PlusLatentCreator(BaseLatentCreator): + + def __init__(self, use_wandb=False, projection_steps=2000): + super().__init__(paths_config.sg2_plus_results_keyword, use_wandb=use_wandb) + + self.projection_steps = projection_steps + + def run_projection(self, fname, image): + image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255 + w = w_plus_projector.project(self.old_G, image, device=torch.device(global_config.device), + num_steps=self.projection_steps, w_name=fname, use_wandb=self.use_wandb) + + return w + + +if __name__ == '__main__': + id_change_report = SG2PlusLatentCreator() + id_change_report.create_latents() diff --git a/scripts/latent_editor_wrapper.py b/scripts/latent_editor_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3297b5574e495376a2c1686466780dee953a7d --- /dev/null +++ b/scripts/latent_editor_wrapper.py @@ -0,0 +1,46 @@ +import torch +from configs import paths_config +from editings.latent_editor import LatentEditor + + +class LatentEditorWrapper: + + def __init__(self): + + self.interfacegan_directions = {'age': f'{paths_config.interfacegan_age}', + 'smile': f'{paths_config.interfacegan_smile}', + 'rotation': f'{paths_config.interfacegan_rotation}'} + self.interfacegan_directions_tensors = {name: torch.load(path).cuda() for name, path in + self.interfacegan_directions.items()} + self.ganspace_pca = torch.load(f'{paths_config.ffhq_pca}') + + ## For more edit directions please visit .. 
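+        # Each tuple appears to encode (GANSpace PCA component index, start layer, end layer, strength);
+        # get_single_ganspace_edits below overrides the final strength entry with each requested factor.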
+ self.ganspace_directions = { + 'eye_openness': (54, 7, 8, 5), + 'smile': (46, 4, 5, -6), + 'trimmed_beard': (58, 7, 9, 7), + } + + self.latent_editor = LatentEditor() + + def get_single_ganspace_edits(self, start_w, factors): + latents_to_display = [] + for ganspace_direction in self.ganspace_directions.values(): + for factor in factors: + edit_direction = list(ganspace_direction) + edit_direction[-1] = factor + edit_direction = tuple(edit_direction) + new_w = self.latent_editor.apply_ganspace(start_w, self.ganspace_pca, [edit_direction]) + latents_to_display.append(new_w) + return latents_to_display + + def get_single_interface_gan_edits(self, start_w, factors): + latents_to_display = {} + for direction in ['rotation', 'smile', 'age']: + for factor in factors: + if direction not in latents_to_display: + latents_to_display[direction] = {} + latents_to_display[direction][factor] = self.latent_editor.apply_interfacegan( + start_w, self.interfacegan_directions_tensors[direction], factor / 2) + + return latents_to_display diff --git a/scripts/pti_styleclip.py b/scripts/pti_styleclip.py new file mode 100644 index 0000000000000000000000000000000000000000..c886ec41acc274732c0924a4c1cb03ccd469fce8 --- /dev/null +++ b/scripts/pti_styleclip.py @@ -0,0 +1,57 @@ +import glob +from argparse import Namespace +from configs import paths_config +from models.StyleCLIP.mapper.scripts.inference import run +from scripts.run_pti import run_PTI + +meta_data = { + 'afro': ['afro', False, False, True], + 'angry': ['angry', False, False, True], + 'Beyonce': ['beyonce', False, False, False], + 'bobcut': ['bobcut', False, False, True], + 'bowlcut': ['bowlcut', False, False, True], + 'curly hair': ['curly_hair', False, False, True], + 'Hilary Clinton': ['hilary_clinton', False, False, False], + 'Jhonny Depp': ['depp', False, False, False], + 'mohawk': ['mohawk', False, False, True], + 'purple hair': ['purple_hair', False, False, False], + 'surprised': ['surprised', False, False, True], + 'Taylor Swift': ['taylor_swift', False, False, False], + 'trump': ['trump', False, False, False], + 'Mark Zuckerberg': ['zuckerberg', False, False, False] +} + + +def styleclip_edit(use_multi_id_G, run_id, use_wandb, edit_types): + images_dir = paths_config.input_data_path + pretrained_mappers = paths_config.style_clip_pretrained_mappers + data_dir_name = paths_config.input_data_id + if run_id == '': + run_id = run_PTI(run_name='', use_wandb=use_wandb, use_multi_id_training=False) + images = glob.glob(f"{images_dir}/*.jpeg") + w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' + for image_name in images: + image_name = image_name.split(".")[0].split("/")[-1] + embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' + latent_path = f'{embedding_dir}/0.pt' + for edit_type in set(meta_data.keys()).intersection(edit_types): + edit_id = meta_data[edit_type][0] + args = { + "exp_dir": f'{paths_config.styleclip_output_dir}', + "checkpoint_path": f"{pretrained_mappers}/{edit_id}.pt", + "couple_outputs": False, + "mapper_type": "LevelsMapper", + "no_coarse_mapper": meta_data[edit_type][1], + "no_medium_mapper": meta_data[edit_type][2], + "no_fine_mapper": meta_data[edit_type][3], + "stylegan_size": 1024, + "test_batch_size": 1, + "latents_test_path": latent_path, + "test_workers": 1, + "run_id": run_id, + "image_name": image_name, + 'edit_name': edit_type, + "data_dir_name": data_dir_name + } + + run(Namespace(**args), run_id, image_name, use_multi_id_G) diff --git a/scripts/run_pti.py 
b/scripts/run_pti.py new file mode 100644 index 0000000000000000000000000000000000000000..04e11b265cdadcc1d110e8009785d6d4511a3991 --- /dev/null +++ b/scripts/run_pti.py @@ -0,0 +1,48 @@ +from random import choice +from string import ascii_uppercase +from torch.utils.data import DataLoader +from torchvision.transforms import transforms +import os +from configs import global_config, paths_config +import wandb + +from training.coaches.multi_id_coach import MultiIDCoach +from training.coaches.single_id_coach import SingleIDCoach +from utils.ImagesDataset import ImagesDataset + + +def run_PTI(run_name='', use_wandb=False, use_multi_id_training=False): + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices + + if run_name == '': + global_config.run_name = ''.join(choice(ascii_uppercase) for i in range(12)) + else: + global_config.run_name = run_name + + if use_wandb: + run = wandb.init(project=paths_config.pti_results_keyword, reinit=True, name=global_config.run_name) + global_config.pivotal_training_steps = 1 + global_config.training_step = 1 + + embedding_dir_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}' + os.makedirs(embedding_dir_path, exist_ok=True) + + dataset = ImagesDataset(paths_config.input_data_path, transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])) + + dataloader = DataLoader(dataset, batch_size=1, shuffle=False) + + if use_multi_id_training: + coach = MultiIDCoach(dataloader, use_wandb) + else: + coach = SingleIDCoach(dataloader, use_wandb) + + coach.train() + + return global_config.run_name + + +if __name__ == '__main__': + run_PTI(run_name='', use_wandb=False, use_multi_id_training=False) diff --git a/torch_utils/__init__.py b/torch_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9 --- /dev/null +++ b/torch_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/torch_utils/custom_ops.py b/torch_utils/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e --- /dev/null +++ b/torch_utils/custom_ops.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import glob +import torch +import torch.utils.cpp_extension +import importlib +import hashlib +import shutil +from pathlib import Path + +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. 
+ +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Compile and load. + verbose_build = (verbosity == 'full') + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. 
+ baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/torch_utils/misc.py b/torch_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7829f4d9f168557ce8a9a6dec289aa964234cb8c --- /dev/null +++ b/torch_utils/misc.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). 
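# A minimal usage sketch (illustrative; `x` and `size` are hypothetical locals):
#     with suppress_tracer_warnings():
#         size = torch.as_tensor(x.shape[0])
# assert_shape() below wraps its symbolic comparisons in exactly this manner so
# that torch.jit.trace() does not warn about tensor-to-constant conversion.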
+ +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. 
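# A minimal usage sketch (illustrative; `old_G`/`new_G` are hypothetical names
# for a frozen source generator and a trainable copy): clone a module's state
# without going through state_dict, preserving the destination's requires_grad:
#     new_G = copy.deepcopy(old_G)
#     copy_params_and_buffers(old_G, new_G, require_all=True)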
+ +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. 
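    # Descriptive note: each row lists a submodule's qualified name, parameter
    # and buffer counts, and the shape/dtype of its first output; '-' marks an
    # empty cell, and additional outputs of the same module get ':idx' rows.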
+    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+    rows += [['---'] * len(rows[0])]
+    param_total = 0
+    buffer_total = 0
+    submodule_names = {mod: name for name, mod in module.named_modules()}
+    for e in entries:
+        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+        param_size = sum(t.numel() for t in e.unique_params)
+        buffer_size = sum(t.numel() for t in e.unique_buffers)
+        output_shapes = [str(list(t.shape)) for t in e.outputs]
+        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+        rows += [[
+            name + (':0' if len(e.outputs) >= 2 else ''),
+            str(param_size) if param_size else '-',
+            str(buffer_size) if buffer_size else '-',
+            (output_shapes + ['-'])[0],
+            (output_dtypes + ['-'])[0],
+        ]]
+        for idx in range(1, len(e.outputs)):
+            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+        param_total += param_size
+        buffer_total += buffer_size
+    rows += [['---'] * len(rows[0])]
+    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+    # Print table.
+    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+    print()
+    for row in rows:
+        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+    print()
+    return outputs
+
+#----------------------------------------------------------------------------
diff --git a/torch_utils/ops/__init__.py b/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9
--- /dev/null
+++ b/torch_utils/ops/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+ +# empty diff --git a/torch_utils/ops/__pycache__/__init__.cpython-36.pyc b/torch_utils/ops/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e19dbbd4ea4e37f41c36a356d3c06cd74790948 Binary files /dev/null and b/torch_utils/ops/__pycache__/__init__.cpython-36.pyc differ diff --git a/torch_utils/ops/__pycache__/__init__.cpython-39.pyc b/torch_utils/ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d3eb8a2f49c65902b814b93cc0e4923008a54a5 Binary files /dev/null and b/torch_utils/ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/bias_act.cpython-36.pyc b/torch_utils/ops/__pycache__/bias_act.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ce9ca546927ccef66b4740b66eb9e5079033526 Binary files /dev/null and b/torch_utils/ops/__pycache__/bias_act.cpython-36.pyc differ diff --git a/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc b/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb869727475b88ebc0c14254f1e8fca517eb53da Binary files /dev/null and b/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-36.pyc b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dbd8afac6eda2926c063b2532294ff85f81e9d2 Binary files /dev/null and b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-36.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e795f4d5c1b73a85a2e801b37db9ca432c4f450b Binary files /dev/null and b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_resample.cpython-36.pyc b/torch_utils/ops/__pycache__/conv2d_resample.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e38f39e632d6074a5cb839abf7cf4d5e95bd0199 Binary files /dev/null and b/torch_utils/ops/__pycache__/conv2d_resample.cpython-36.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc b/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe20935b8281fb21703536440ef23dc4bfe0de9f Binary files /dev/null and b/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/fma.cpython-36.pyc b/torch_utils/ops/__pycache__/fma.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c84c7a552a09c63b80ab620a4b1e94c7850240b Binary files /dev/null and b/torch_utils/ops/__pycache__/fma.cpython-36.pyc differ diff --git a/torch_utils/ops/__pycache__/fma.cpython-39.pyc b/torch_utils/ops/__pycache__/fma.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f93221c842885ce8b2c41a35e3a55c35902e98b Binary files /dev/null and b/torch_utils/ops/__pycache__/fma.cpython-39.pyc differ diff --git a/torch_utils/ops/__pycache__/upfirdn2d.cpython-36.pyc b/torch_utils/ops/__pycache__/upfirdn2d.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31e31d5ccae91919e780d0a9338bff3c27ebfdf8 Binary files /dev/null and b/torch_utils/ops/__pycache__/upfirdn2d.cpython-36.pyc differ diff 
--git a/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc b/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab558ff15a3c3a3ae40b913ea814e60b82e7aa82 Binary files /dev/null and b/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc differ diff --git a/torch_utils/ops/bias_act.cpp b/torch_utils/ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330 --- /dev/null +++ b/torch_utils/ops/bias_act.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. 
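    // Descriptive note: p.grad selects the kernel mode (0 = forward activation,
    // 1 = first-order gradient, 2 = second-order term), mirroring how
    // torch_utils/ops/bias_act.py calls this op.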
+ bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.cu b/torch_utils/ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880 --- /dev/null +++ b/torch_utils/ops/bias_act.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? 
x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.h b/torch_utils/ops/bias_act.h new file mode 100644 index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4 --- /dev/null +++ b/torch_utils/ops/bias_act.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/bias_act.py b/torch_utils/ops/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcb409a89ccf6c6f6ecfca5962683df2d280b1f --- /dev/null +++ b/torch_utils/ops/bias_act.py @@ -0,0 +1,212 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import warnings +import numpy as np +import torch +import dnnlib +import traceback + +from .. import custom_ops +from .. 
import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _inited, _plugin + if not _inited: + _inited = True + sources = ['bias_act.cpp', 'bias_act.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. 
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. 
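    # Descriptive note: BiasActCudaGrad evaluates the first-order gradient
    # (grad=1) from dy and the saved reference tensors; its own backward() calls
    # the plugin again with grad=2 for activations that declare has_2nd_grad.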
+ class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..e95e10d0b1d0315a63a76446fd4c5c293c8bbc6d --- /dev/null +++ b/torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import warnings +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
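# A minimal usage sketch (illustrative; `loss` and `images` are hypothetical):
# opt in to the custom op, and skip weight gradients when only input gradients
# are needed, e.g. for an R1-style penalty:
#     conv2d_gradfix.enabled = True
#     with conv2d_gradfix.no_weight_gradients():
#         grads = torch.autograd.grad(outputs=[loss], inputs=[images], create_graph=True)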
+ +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. 
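    # Descriptive note: calc_output_padding() above computes the output_padding
    # that lets the transposed convolution used in Conv2d.backward() reproduce
    # the exact spatial shape of the original input, so that
    # grad_input.shape == input.shape holds.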
+ class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_resample.py b/torch_utils/ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4750744c83354bab78704d4ef51ad1070fcc4a --- /dev/null +++ b/torch_utils/ops/conv2d_resample.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . 
import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d_gradfix.conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. 
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/fma.py b/torch_utils/ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..2eeac58a626c49231e04122b93e321ada954c5d3 --- /dev/null +++ b/torch_utils/ops/fma.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6b3413ea72a734703c34382c023b84523601fd --- /dev/null +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import warnings +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. 
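For orientation, a minimal usage sketch of this drop-in replacement, assuming `torch_utils` is importable from the repository root and a PyTorch version the module supports; on other versions `grid_sample()` simply falls back to `torch.nn.functional.grid_sample()`. The tensor shapes below are arbitrary examples, not anything mandated by the code.

```python
# Illustrative sketch only; shapes are arbitrary examples.
import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True  # opt in to the custom op (module default is False)

image = torch.randn(1, 3, 8, 8, requires_grad=True)
grid = torch.rand(1, 8, 8, 2) * 2 - 1  # sampling coordinates normalized to [-1, 1]
out = grid_sample_gradfix.grid_sample(image, grid)

# First-order gradient w.r.t. the image; create_graph=True keeps the graph so
# this gradient can itself be differentiated, which is the point of the custom op.
(g,) = torch.autograd.grad(out.sum(), image, create_graph=True)
print(g.shape)  # torch.Size([1, 3, 8, 8])
```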
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+    if _should_use_custom_op():
+        return _GridSample2dForward.apply(input, grid)
+    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+    if not enabled:
+        return False
+    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
+        return True
+    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
+    return False
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, grid):
+        assert input.ndim == 4
+        assert grid.ndim == 4
+        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+        ctx.save_for_backward(input, grid)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, grid = ctx.saved_tensors
+        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+        return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, grad_output, input, grid):
+        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+        ctx.save_for_backward(grid)
+        return grad_input, grad_grid
+
+    @staticmethod
+    def backward(ctx, grad2_grad_input, grad2_grad_grid):
+        _ = grad2_grad_grid # unused
+        grid, = ctx.saved_tensors
+        grad2_grad_output = None
+        grad2_input = None
+        grad2_grid = None
+
+        if ctx.needs_input_grad[0]:
+            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+        assert not ctx.needs_input_grad[2]
+        return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/torch_utils/ops/upfirdn2d.cpp b/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+    // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.cu b/torch_utils/ops/upfirdn2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916 --- /dev/null +++ b/torch_utils/ops/upfirdn2d.cu @@ -0,0 +1,350 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. 
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. 
+ __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if 
(fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 
2) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.h b/torch_utils/ops/upfirdn2d.h new file mode 100644 index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd --- /dev/null +++ b/torch_utils/ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.py b/torch_utils/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..ceeac2b9834e33b7c601c28bf27f32aa91c69256 --- /dev/null +++ b/torch_utils/ops/upfirdn2d.py @@ -0,0 +1,384 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import warnings +import numpy as np +import torch +import traceback + +from .. import custom_ops +from .. import misc +from . import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None + +def _init(): + global _inited, _plugin + if not _inited: + sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. 
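As a concrete illustration of the upsample, pad, filter, downsample sequence described above, here is a small sketch using the helpers in this file. It assumes `torch_utils` is importable and uses `impl='ref'` so the compiled CUDA plugin is not needed; the filter taps and tensor shape are arbitrary examples.

```python
# Illustrative sketch only; values are arbitrary examples.
import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 16, 16)

# 4-tap binomial low-pass filter; setup_filter() normalizes it to unit DC gain
# and, because it has fewer than 8 taps, expands it to a non-separable 4x4 kernel.
f = upfirdn2d.setup_filter([1.0, 3.0, 3.0, 1.0])

# 2x upsampling: zero-insert, pad by [2, 1, 2, 1] around the upsampled grid,
# filter with f, and keep every pixel. gain=4 compensates for the inserted zeros.
y = upfirdn2d.upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4, impl='ref')
print(y.shape)  # torch.Size([1, 3, 32, 32])
```

The padding and gain used here match what the `upsample2d()` wrapper defined later in this file would derive automatically; `downsample2d()` plays the analogous role for the `down` path.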
+ + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. 
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. 
+ f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- diff --git a/torch_utils/persistence.py b/torch_utils/persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..0186cfd97bca0fcb397a7b73643520c1d1105a02 --- /dev/null +++ b/torch_utils/persistence.py @@ -0,0 +1,251 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for pickling Python code alongside other data. 
+ +The pickled code is automatically imported into a separate Python module +during unpickling. This way, any previously exported pickles will remain +usable even if the original code is no longer available, or if the current +version of the code is not consistent with what was originally pickled.""" + +import sys +import pickle +import io +import inspect +import copy +import uuid +import types +import dnnlib + +#---------------------------------------------------------------------------- + +_version = 6 # internal version number +_decorators = set() # {decorator_class, ...} +_import_hooks = [] # [hook_function, ...] +_module_to_src_dict = dict() # {module: src, ...} +_src_to_module_dict = dict() # {src: module, ...} + +#---------------------------------------------------------------------------- + +def persistent_class(orig_class): + r"""Class decorator that extends a given class to save its source code + when pickled. + + Example: + + from torch_utils import persistence + + @persistence.persistent_class + class MyNetwork(torch.nn.Module): + def __init__(self, num_inputs, num_outputs): + super().__init__() + self.fc = MyLayer(num_inputs, num_outputs) + ... + + @persistence.persistent_class + class MyLayer(torch.nn.Module): + ... + + When pickled, any instance of `MyNetwork` and `MyLayer` will save its + source code alongside other internal state (e.g., parameters, buffers, + and submodules). This way, any previously exported pickle will remain + usable even if the class definitions have been modified or are no + longer available. + + The decorator saves the source code of the entire Python module + containing the decorated class. It does *not* save the source code of + any imported modules. Thus, the imported modules must be available + during unpickling, also including `torch_utils.persistence` itself. + + It is ok to call functions defined in the same module from the + decorated class. However, if the decorated class depends on other + classes defined in the same module, they must be decorated as well. + This is illustrated in the above example in the case of `MyLayer`. + + It is also possible to employ the decorator just-in-time before + calling the constructor. For example: + + cls = MyLayer + if want_to_make_it_persistent: + cls = persistence.persistent_class(cls) + layer = cls(num_inputs, num_outputs) + + As an additional feature, the decorator also keeps track of the + arguments that were used to construct each instance of the decorated + class. The arguments can be queried via `obj.init_args` and + `obj.init_kwargs`, and they are automatically pickled alongside other + object state. 
A typical use case is to first unpickle a previous + instance of a persistent class, and then upgrade it to use the latest + version of the source code: + + with open('old_pickle.pkl', 'rb') as f: + old_net = pickle.load(f) + new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) + misc.copy_params_and_buffers(old_net, new_net, require_all=True) + """ + assert isinstance(orig_class, type) + if is_persistent(orig_class): + return orig_class + + assert orig_class.__module__ in sys.modules + orig_module = sys.modules[orig_class.__module__] + orig_module_src = _module_to_src(orig_module) + + class Decorator(orig_class): + _orig_module_src = orig_module_src + _orig_class_name = orig_class.__name__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._init_args = copy.deepcopy(args) + self._init_kwargs = copy.deepcopy(kwargs) + assert orig_class.__name__ in orig_module.__dict__ + _check_pickleable(self.__reduce__()) + + @property + def init_args(self): + return copy.deepcopy(self._init_args) + + @property + def init_kwargs(self): + return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) + + def __reduce__(self): + fields = list(super().__reduce__()) + fields += [None] * max(3 - len(fields), 0) + if fields[0] is not _reconstruct_persistent_obj: + meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) + fields[0] = _reconstruct_persistent_obj # reconstruct func + fields[1] = (meta,) # reconstruct args + fields[2] = None # state dict + return tuple(fields) + + Decorator.__name__ = orig_class.__name__ + _decorators.add(Decorator) + return Decorator + +#---------------------------------------------------------------------------- + +def is_persistent(obj): + r"""Test whether the given object or class is persistent, i.e., + whether it will save its source code when pickled. + """ + try: + if obj in _decorators: + return True + except TypeError: + pass + return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck + +#---------------------------------------------------------------------------- + +def import_hook(hook): + r"""Register an import hook that is called whenever a persistent object + is being unpickled. A typical use case is to patch the pickled source + code to avoid errors and inconsistencies when the API of some imported + module has changed. + + The hook should have the following signature: + + hook(meta) -> modified meta + + `meta` is an instance of `dnnlib.EasyDict` with the following fields: + + type: Type of the persistent object, e.g. `'class'`. + version: Internal version number of `torch_utils.persistence`. + module_src Original source code of the Python module. + class_name: Class name in the original Python module. + state: Internal state of the object. + + Example: + + @persistence.import_hook + def wreck_my_network(meta): + if meta.class_name == 'MyNetwork': + print('MyNetwork is being imported. I will wreck it!') + meta.module_src = meta.module_src.replace("True", "False") + return meta + """ + assert callable(hook) + _import_hooks.append(hook) + +#---------------------------------------------------------------------------- + +def _reconstruct_persistent_obj(meta): + r"""Hook that is called internally by the `pickle` module to unpickle + a persistent object. 
+ """ + meta = dnnlib.EasyDict(meta) + meta.state = dnnlib.EasyDict(meta.state) + for hook in _import_hooks: + meta = hook(meta) + assert meta is not None + + assert meta.version == _version + module = _src_to_module(meta.module_src) + + assert meta.type == 'class' + orig_class = module.__dict__[meta.class_name] + decorator_class = persistent_class(orig_class) + obj = decorator_class.__new__(decorator_class) + + setstate = getattr(obj, '__setstate__', None) + if callable(setstate): + setstate(meta.state) # pylint: disable=not-callable + else: + obj.__dict__.update(meta.state) + return obj + +#---------------------------------------------------------------------------- + +def _module_to_src(module): + r"""Query the source code of a given Python module. + """ + src = _module_to_src_dict.get(module, None) + if src is None: + src = inspect.getsource(module) + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + return src + +def _src_to_module(src): + r"""Get or create a Python module for the given source code. + """ + module = _src_to_module_dict.get(src, None) + if module is None: + module_name = "_imported_module_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + exec(src, module.__dict__) # pylint: disable=exec-used + return module + +#---------------------------------------------------------------------------- + +def _check_pickleable(obj): + r"""Check that the given object is pickleable, raising an exception if + it is not. This function is expected to be considerably more efficient + than actually pickling the object. + """ + def recurse(obj): + if isinstance(obj, (list, tuple, set)): + return [recurse(x) for x in obj] + if isinstance(obj, dict): + return [[recurse(x), recurse(y)] for x, y in obj.items()] + if isinstance(obj, (str, int, float, bool, bytes, bytearray)): + return None # Python primitive types are pickleable. + if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/torch_utils/training_stats.py b/torch_utils/training_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..26f467f9eaa074ee13de1cf2625cd7da44880847 --- /dev/null +++ b/torch_utils/training_stats.py @@ -0,0 +1,268 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . 
import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. None = single-process. +_sync_called = False # Has _sync() been called yet? +_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor + +#---------------------------------------------------------------------------- + +def init_multiprocessing(rank, sync_device): + r"""Initializes `torch_utils.training_stats` for collecting statistics + across multiple processes. + + This function must be called after + `torch.distributed.init_process_group()` and before `Collector.update()`. + The call is not necessary if multi-process collection is not needed. + + Args: + rank: Rank of the current process. + sync_device: PyTorch device to use for inter-process + communication, or None to disable multi-process + collection. Typically `torch.device('cuda', rank)`. + """ + global _rank, _sync_device + assert not _sync_called + _rank = rank + _sync_device = sync_device + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def report(name, value): + r"""Broadcasts the given set of scalars to all interested instances of + `Collector`, across device and process boundaries. + + This function is expected to be extremely cheap and can be safely + called from anywhere in the training loop, loss function, or inside a + `torch.nn.Module`. + + Warning: The current implementation expects the set of unique names to + be consistent across processes. Please make sure that `report()` is + called at least once for each unique name by each process, and in the + same order. If a given process has no scalars to broadcast, it can do + `report(name, [])` (empty list). + + Args: + name: Arbitrary string specifying the name of the statistic. + Averages are accumulated separately for each unique name. + value: Arbitrary set of scalars. Can be a list, tuple, + NumPy array, PyTorch tensor, or Python scalar. + + Returns: + The same `value` that was passed in. + """ + if name not in _counters: + _counters[name] = dict() + + elems = torch.as_tensor(value) + if elems.numel() == 0: + return value + + elems = elems.detach().flatten().to(_reduce_dtype) + moments = torch.stack([ + torch.ones_like(elems).sum(), + elems.sum(), + elems.square().sum(), + ]) + assert moments.ndim == 1 and moments.shape[0] == _num_moments + moments = moments.to(_counter_dtype) + + device = moments.device + if device not in _counters[name]: + _counters[name][device] = torch.zeros_like(moments) + _counters[name][device].add_(moments) + return value + +#---------------------------------------------------------------------------- + +def report0(name, value): + r"""Broadcasts the given set of scalars by the first process (`rank = 0`), + but ignores any scalars provided by the other processes. + See `report()` for further details. 
+ """ + report(name, value if _rank == 0 else []) + return value + +#---------------------------------------------------------------------------- + +class Collector: + r"""Collects the scalars broadcasted by `report()` and `report0()` and + computes their long-term averages (mean and standard deviation) over + user-defined periods of time. + + The averages are first collected into internal counters that are not + directly visible to the user. They are then copied to the user-visible + state as a result of calling `update()` and can then be queried using + `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the + internal counters for the next round, so that the user-visible state + effectively reflects averages collected between the last two calls to + `update()`. + + Args: + regex: Regular expression defining which statistics to + collect. The default is to collect everything. + keep_previous: Whether to retain the previous averages if no + scalars were collected on a given round + (default: True). + """ + def __init__(self, regex='.*', keep_previous=True): + self._regex = re.compile(regex) + self._keep_previous = keep_previous + self._cumulative = dict() + self._moments = dict() + self.update() + self._moments.clear() + + def names(self): + r"""Returns the names of all statistics broadcasted so far that + match the regular expression specified at construction time. + """ + return [name for name in _counters if self._regex.fullmatch(name)] + + def update(self): + r"""Copies current values of the internal counters to the + user-visible state and resets them for the next round. + + If `keep_previous=True` was specified at construction time, the + operation is skipped for statistics that have received no scalars + since the last update, retaining their previous averages. + + This method performs a number of GPU-to-CPU transfers and one + `torch.distributed.all_reduce()`. It is intended to be called + periodically in the main training loop, typically once every + N training steps. + """ + if not self._keep_previous: + self._moments.clear() + for name, cumulative in _sync(self.names()): + if name not in self._cumulative: + self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + delta = cumulative - self._cumulative[name] + self._cumulative[name].copy_(cumulative) + if float(delta[0]) != 0: + self._moments[name] = delta + + def _get_delta(self, name): + r"""Returns the raw moments that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + assert self._regex.fullmatch(name) + if name not in self._moments: + self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + return self._moments[name] + + def num(self, name): + r"""Returns the number of scalars that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + delta = self._get_delta(name) + return int(delta[0]) + + def mean(self, name): + r"""Returns the mean of the scalars that were accumulated for the + given statistic between the last two calls to `update()`, or NaN if + no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0: + return float('nan') + return float(delta[1] / delta[0]) + + def std(self, name): + r"""Returns the standard deviation of the scalars that were + accumulated for the given statistic between the last two calls to + `update()`, or NaN if no scalars were collected. 
+ """ + delta = self._get_delta(name) + if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): + return float('nan') + if int(delta[0]) == 1: + return float(0) + mean = float(delta[1] / delta[0]) + raw_var = float(delta[2] / delta[0]) + return np.sqrt(max(raw_var - np.square(mean), 0)) + + def as_dict(self): + r"""Returns the averages accumulated between the last two calls to + `update()` as an `dnnlib.EasyDict`. The contents are as follows: + + dnnlib.EasyDict( + NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), + ... + ) + """ + stats = dnnlib.EasyDict() + for name in self.names(): + stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) + return stats + + def __getitem__(self, name): + r"""Convenience getter. + `collector[name]` is a synonym for `collector.mean(name)`. + """ + return self.mean(name) + +#---------------------------------------------------------------------------- + +def _sync(names): + r"""Synchronize the global cumulative counters across devices and + processes. Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. + deltas = [] + device = _sync_device if _sync_device is not None else torch.device('cpu') + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. 
+ return [(name, _cumulative[name]) for name in names] + +#---------------------------------------------------------------------------- diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/__pycache__/__init__.cpython-36.pyc b/training/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b68a9792bf4d8f075a5d99994996adc46f77da0a Binary files /dev/null and b/training/__pycache__/__init__.cpython-36.pyc differ diff --git a/training/__pycache__/__init__.cpython-39.pyc b/training/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abfb4c39d316c74b1891467a78df5af8e1ff5e00 Binary files /dev/null and b/training/__pycache__/__init__.cpython-39.pyc differ diff --git a/training/coaches/__init__.py b/training/coaches/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/coaches/__pycache__/__init__.cpython-36.pyc b/training/coaches/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3607f582d31c7082f98276bfb400b1c060e3bfdb Binary files /dev/null and b/training/coaches/__pycache__/__init__.cpython-36.pyc differ diff --git a/training/coaches/__pycache__/__init__.cpython-39.pyc b/training/coaches/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fe0c17e0d6a7b171fc8b3c842ad336d4a3fea41 Binary files /dev/null and b/training/coaches/__pycache__/__init__.cpython-39.pyc differ diff --git a/training/coaches/__pycache__/base_coach.cpython-36.pyc b/training/coaches/__pycache__/base_coach.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b53723f2411aa97ae744c071f60f261482eff7ab Binary files /dev/null and b/training/coaches/__pycache__/base_coach.cpython-36.pyc differ diff --git a/training/coaches/__pycache__/base_coach.cpython-39.pyc b/training/coaches/__pycache__/base_coach.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1946575ea2f0a593e73bf56a973fd3415b9df86d Binary files /dev/null and b/training/coaches/__pycache__/base_coach.cpython-39.pyc differ diff --git a/training/coaches/__pycache__/multi_id_coach.cpython-36.pyc b/training/coaches/__pycache__/multi_id_coach.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6922dc23f6b188c5c958057cc7ff6b9db113d955 Binary files /dev/null and b/training/coaches/__pycache__/multi_id_coach.cpython-36.pyc differ diff --git a/training/coaches/__pycache__/multi_id_coach.cpython-39.pyc b/training/coaches/__pycache__/multi_id_coach.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4d5fdfce7576bdbc2bac7552d3aef0b5f85e3d7 Binary files /dev/null and b/training/coaches/__pycache__/multi_id_coach.cpython-39.pyc differ diff --git a/training/coaches/__pycache__/single_id_coach.cpython-36.pyc b/training/coaches/__pycache__/single_id_coach.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37000f357fd3acad0fd3f95a97d337fc32750302 Binary files /dev/null and b/training/coaches/__pycache__/single_id_coach.cpython-36.pyc differ diff --git a/training/coaches/__pycache__/single_id_coach.cpython-39.pyc b/training/coaches/__pycache__/single_id_coach.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8be32a83f9339635b254bc31543c6a6f1177a930 Binary files /dev/null and b/training/coaches/__pycache__/single_id_coach.cpython-39.pyc differ diff --git a/training/coaches/base_coach.py b/training/coaches/base_coach.py new file mode 100644 index 0000000000000000000000000000000000000000..892a4825b68c6b801a7ba2c06ddc32c1127fe4c3 --- /dev/null +++ b/training/coaches/base_coach.py @@ -0,0 +1,152 @@ +import abc +import os +import pickle +from argparse import Namespace +import wandb +import os.path +from criteria.localitly_regulizer import Space_Regulizer +import torch +from torchvision import transforms +from lpips import LPIPS +from training.projectors import w_projector +from configs import global_config, paths_config, hyperparameters +from criteria import l2_loss +from models.e4e.psp import pSp +from utils.log_utils import log_image_from_w +from utils.models_utils import toogle_grad, load_old_G + + +class BaseCoach: + def __init__(self, data_loader, use_wandb): + + self.use_wandb = use_wandb + self.data_loader = data_loader + self.w_pivots = {} + self.image_counter = 0 + + if hyperparameters.first_inv_type == 'w+': + self.initilize_e4e() + + self.e4e_image_transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + + # Initialize loss + self.lpips_loss = LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval() + + self.restart_training() + + # Initialize checkpoint dir + self.checkpoint_dir = paths_config.checkpoints_dir + os.makedirs(self.checkpoint_dir, exist_ok=True) + + def restart_training(self): + + # Initialize networks + self.G = load_old_G() + toogle_grad(self.G, True) + + self.original_G = load_old_G() + + self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss) + self.optimizer = self.configure_optimizers() + + def get_inversion(self, w_path_dir, image_name, image): + embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' + os.makedirs(embedding_dir, exist_ok=True) + + w_pivot = None + + if hyperparameters.use_last_w_pivots: + w_pivot = self.load_inversions(w_path_dir, image_name) + + if not hyperparameters.use_last_w_pivots or w_pivot is None: + w_pivot = self.calc_inversions(image, image_name) + torch.save(w_pivot, f'{embedding_dir}/0.pt') + + w_pivot = w_pivot.to(global_config.device) + return w_pivot + + def load_inversions(self, w_path_dir, image_name): + if image_name in self.w_pivots: + return self.w_pivots[image_name] + + if hyperparameters.first_inv_type == 'w+': + w_potential_path = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}/0.pt' + else: + w_potential_path = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}/0.pt' + if not os.path.isfile(w_potential_path): + return None + w = torch.load(w_potential_path).to(global_config.device) + self.w_pivots[image_name] = w + return w + + def calc_inversions(self, image, image_name): + + if hyperparameters.first_inv_type == 'w+': + w = self.get_e4e_inversion(image) + + else: + id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255 + w = w_projector.project(self.G, id_image, device=torch.device(global_config.device), w_avg_samples=600, + num_steps=hyperparameters.first_inv_steps, w_name=image_name, + use_wandb=self.use_wandb) + + return w + + @abc.abstractmethod + def train(self): + pass + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.G.parameters(), 
lr=hyperparameters.pti_learning_rate) + + return optimizer + + def calc_loss(self, generated_images, real_images, log_name, new_G, use_ball_holder, w_batch): + loss = 0.0 + + if hyperparameters.pt_l2_lambda > 0: + l2_loss_val = l2_loss.l2_loss(generated_images, real_images) + if self.use_wandb: + wandb.log({f'MSE_loss_val_{log_name}': l2_loss_val.detach().cpu()}, step=global_config.training_step) + loss += l2_loss_val * hyperparameters.pt_l2_lambda + if hyperparameters.pt_lpips_lambda > 0: + loss_lpips = self.lpips_loss(generated_images, real_images) + loss_lpips = torch.squeeze(loss_lpips) + if self.use_wandb: + wandb.log({f'LPIPS_loss_val_{log_name}': loss_lpips.detach().cpu()}, step=global_config.training_step) + loss += loss_lpips * hyperparameters.pt_lpips_lambda + + if use_ball_holder and hyperparameters.use_locality_regularization: + ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(new_G, w_batch, use_wandb=self.use_wandb) + loss += ball_holder_loss_val + + return loss, l2_loss_val, loss_lpips + + def forward(self, w): + generated_images = self.G.synthesis(w, noise_mode='const', force_fp32=True) + + return generated_images + + def initilize_e4e(self): + ckpt = torch.load(paths_config.e4e, map_location='cpu') + opts = ckpt['opts'] + opts['batch_size'] = hyperparameters.train_batch_size + opts['checkpoint_path'] = paths_config.e4e + opts = Namespace(**opts) + self.e4e_inversion_net = pSp(opts) + self.e4e_inversion_net.eval() + self.e4e_inversion_net = self.e4e_inversion_net.to(global_config.device) + toogle_grad(self.e4e_inversion_net, False) + + def get_e4e_inversion(self, image): + image = (image + 1) / 2 + new_image = self.e4e_image_transform(image[0]).to(global_config.device) + _, w = self.e4e_inversion_net(new_image.unsqueeze(0), randomize_noise=False, return_latents=True, resize=False, + input_code=False) + if self.use_wandb: + log_image_from_w(w, self.G, 'First e4e inversion') + return w diff --git a/training/coaches/multi_id_coach.py b/training/coaches/multi_id_coach.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc600a22fb9b63201a4787a8e53b5dc9f462bc7 --- /dev/null +++ b/training/coaches/multi_id_coach.py @@ -0,0 +1,72 @@ +import os + +import torch +from tqdm import tqdm + +from configs import paths_config, hyperparameters, global_config +from training.coaches.base_coach import BaseCoach +from utils.log_utils import log_images_from_w + + +class MultiIDCoach(BaseCoach): + + def __init__(self, data_loader, use_wandb): + super().__init__(data_loader, use_wandb) + + def train(self): + self.G.synthesis.train() + self.G.mapping.train() + + w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' + os.makedirs(w_path_dir, exist_ok=True) + os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True) + + use_ball_holder = True + w_pivots = [] + images = [] + + for fname, image in self.data_loader: + if self.image_counter >= hyperparameters.max_images_to_invert: + break + + image_name = fname[0] + if hyperparameters.first_inv_type == 'w+': + embedding_dir = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}' + else: + embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' + os.makedirs(embedding_dir, exist_ok=True) + + w_pivot = self.get_inversion(w_path_dir, image_name, image) + w_pivots.append(w_pivot) + images.append((image_name, image)) + self.image_counter += 1 + + for i in tqdm(range(hyperparameters.max_pti_steps)): + self.image_counter = 0 + + for data, w_pivot 
in zip(images, w_pivots): + image_name, image = data + + if self.image_counter >= hyperparameters.max_images_to_invert: + break + + real_images_batch = image.to(global_config.device) + + generated_images = self.forward(w_pivot) + loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name, + self.G, use_ball_holder, w_pivot) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0 + + global_config.training_step += 1 + self.image_counter += 1 + + if self.use_wandb: + log_images_from_w(w_pivots, self.G, [image[0] for image in images]) + + torch.save(self.G, + f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt') diff --git a/training/coaches/single_id_coach.py b/training/coaches/single_id_coach.py new file mode 100644 index 0000000000000000000000000000000000000000..33bf98dd4124568fb401473bee390090dc09283c --- /dev/null +++ b/training/coaches/single_id_coach.py @@ -0,0 +1,73 @@ +import os +import torch +from tqdm import tqdm +from configs import paths_config, hyperparameters, global_config +from training.coaches.base_coach import BaseCoach +from utils.log_utils import log_images_from_w + + +class SingleIDCoach(BaseCoach): + + def __init__(self, data_loader, use_wandb): + super().__init__(data_loader, use_wandb) + + def train(self): + + w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}' + os.makedirs(w_path_dir, exist_ok=True) + os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True) + + use_ball_holder = True + + for fname, image in tqdm(self.data_loader): + image_name = fname[0] + + self.restart_training() + + if self.image_counter >= hyperparameters.max_images_to_invert: + break + + embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}' + os.makedirs(embedding_dir, exist_ok=True) + + w_pivot = None + + if hyperparameters.use_last_w_pivots: + w_pivot = self.load_inversions(w_path_dir, image_name) + + elif not hyperparameters.use_last_w_pivots or w_pivot is None: + w_pivot = self.calc_inversions(image, image_name) + + # w_pivot = w_pivot.detach().clone().to(global_config.device) + w_pivot = w_pivot.to(global_config.device) + + torch.save(w_pivot, f'{embedding_dir}/0.pt') + log_images_counter = 0 + real_images_batch = image.to(global_config.device) + + for i in tqdm(range(hyperparameters.max_pti_steps)): + + generated_images = self.forward(w_pivot) + loss, l2_loss_val, loss_lpips = self.calc_loss(generated_images, real_images_batch, image_name, + self.G, use_ball_holder, w_pivot) + + self.optimizer.zero_grad() + + if loss_lpips <= hyperparameters.LPIPS_value_threshold: + break + + loss.backward() + self.optimizer.step() + + use_ball_holder = global_config.training_step % hyperparameters.locality_regularization_interval == 0 + + if self.use_wandb and log_images_counter % global_config.image_rec_result_log_snapshot == 0: + log_images_from_w([w_pivot], self.G, [image_name]) + + global_config.training_step += 1 + log_images_counter += 1 + + self.image_counter += 1 + + torch.save(self.G, + f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_{image_name}.pt') diff --git a/training/projectors/__init__.py b/training/projectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/projectors/__pycache__/__init__.cpython-36.pyc 
b/training/projectors/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f15e8754bece91fb94f173755861fdb852ef8dcc Binary files /dev/null and b/training/projectors/__pycache__/__init__.cpython-36.pyc differ diff --git a/training/projectors/__pycache__/__init__.cpython-39.pyc b/training/projectors/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f917b2bd8631d82d383308557e427482daa31fd9 Binary files /dev/null and b/training/projectors/__pycache__/__init__.cpython-39.pyc differ diff --git a/training/projectors/__pycache__/w_projector.cpython-36.pyc b/training/projectors/__pycache__/w_projector.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..673de6ab254efc6755d82f718a81b5722939cb97 Binary files /dev/null and b/training/projectors/__pycache__/w_projector.cpython-36.pyc differ diff --git a/training/projectors/__pycache__/w_projector.cpython-39.pyc b/training/projectors/__pycache__/w_projector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a145183d82c8ff34b3ba199d8974fb80fc694a1 Binary files /dev/null and b/training/projectors/__pycache__/w_projector.cpython-39.pyc differ diff --git a/training/projectors/w_plus_projector.py b/training/projectors/w_plus_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cce427e5374c5ddce90199e1184f84a13d30c5 --- /dev/null +++ b/training/projectors/w_plus_projector.py @@ -0,0 +1,145 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Project given image to the latent space of pretrained network pickle.""" + +import copy +import wandb +import numpy as np +import torch +import torch.nn.functional as F +from tqdm import tqdm +from configs import global_config, hyperparameters +import dnnlib +from utils.log_utils import log_image_from_w + + +def project( + G, + target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution + *, + num_steps=1000, + w_avg_samples=10000, + initial_learning_rate=0.01, + initial_noise_factor=0.05, + lr_rampdown_length=0.25, + lr_rampup_length=0.05, + noise_ramp_length=0.75, + regularize_noise_weight=1e5, + verbose=False, + device: torch.device, + use_wandb=False, + initial_w=None, + image_log_step=global_config.image_rec_result_log_snapshot, + w_name: str +): + assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution) + + def logprint(*args): + if verbose: + print(*args) + + G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore + + # Compute w stats. 
+ logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') + z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) + w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C] + w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] + w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] + w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device) + w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 + + start_w = initial_w if initial_w is not None else w_avg + + # Setup noise inputs. + noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name} + + # Load VGG16 feature detector. + url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' + with dnnlib.util.open_url(url) as f: + vgg16 = torch.jit.load(f).eval().to(device) + + # Features for target image. + target_images = target.unsqueeze(0).to(device).to(torch.float32) + if target_images.shape[2] > 256: + target_images = F.interpolate(target_images, size=(256, 256), mode='area') + target_features = vgg16(target_images, resize_images=False, return_lpips=True) + + start_w = np.repeat(start_w, G.mapping.num_ws, axis=1) + w_opt = torch.tensor(start_w, dtype=torch.float32, device=device, + requires_grad=True) # pylint: disable=not-callable + + optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), + lr=hyperparameters.first_inv_lr) + + # Init noise. + for buf in noise_bufs.values(): + buf[:] = torch.randn_like(buf) + buf.requires_grad = True + + for step in tqdm(range(num_steps)): + + # Learning rate schedule. + t = step / num_steps + w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2 + lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) + lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) + lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) + lr = initial_learning_rate * lr_ramp + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + # Synth images from opt_w. + w_noise = torch.randn_like(w_opt) * w_noise_scale + ws = (w_opt + w_noise) + + synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True) + + # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. + synth_images = (synth_images + 1) * (255 / 2) + if synth_images.shape[2] > 256: + synth_images = F.interpolate(synth_images, size=(256, 256), mode='area') + + # Features for synth images. + synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) + dist = (target_features - synth_features).square().sum() + + # Noise regularization. + reg_loss = 0.0 + for v in noise_bufs.values(): + noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() + while True: + reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + loss = dist + reg_loss * regularize_noise_weight + + if step % image_log_step == 0: + with torch.no_grad(): + if use_wandb: + global_config.training_step += 1 + wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step) + log_image_from_w(w_opt, G, w_name) + + # Step + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}') + + # Normalize noise. 
+ with torch.no_grad(): + for buf in noise_bufs.values(): + buf -= buf.mean() + buf *= buf.square().mean().rsqrt() + + del G + return w_opt diff --git a/training/projectors/w_projector.py b/training/projectors/w_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..bec89870027dcb86961a1898e89f987a65f6f9a8 --- /dev/null +++ b/training/projectors/w_projector.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Project given image to the latent space of pretrained network pickle.""" + +import copy +import wandb +import numpy as np +import torch +import torch.nn.functional as F +from tqdm import tqdm +from configs import global_config, hyperparameters +from utils import log_utils +import dnnlib + + +def project( + G, + target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution + *, + num_steps=1000, + w_avg_samples=10000, + initial_learning_rate=0.01, + initial_noise_factor=0.05, + lr_rampdown_length=0.25, + lr_rampup_length=0.05, + noise_ramp_length=0.75, + regularize_noise_weight=1e5, + verbose=False, + device: torch.device, + use_wandb=False, + initial_w=None, + image_log_step=global_config.image_rec_result_log_snapshot, + w_name: str +): + assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution) + + def logprint(*args): + if verbose: + print(*args) + + G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore + + # Compute w stats. + logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') + z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) + w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C] + w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] + w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] + w_avg_tensor = torch.from_numpy(w_avg).to(global_config.device) + w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 + + start_w = initial_w if initial_w is not None else w_avg + + # Setup noise inputs. + noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name} + + # Load VGG16 feature detector. + url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' + with dnnlib.util.open_url(url) as f: + vgg16 = torch.jit.load(f).eval().to(device) + + # Features for target image. + target_images = target.unsqueeze(0).to(device).to(torch.float32) + if target_images.shape[2] > 256: + target_images = F.interpolate(target_images, size=(256, 256), mode='area') + target_features = vgg16(target_images, resize_images=False, return_lpips=True) + + w_opt = torch.tensor(start_w, dtype=torch.float32, device=device, + requires_grad=True) # pylint: disable=not-callable + optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), + lr=hyperparameters.first_inv_lr) + + # Init noise. + for buf in noise_bufs.values(): + buf[:] = torch.randn_like(buf) + buf.requires_grad = True + + for step in tqdm(range(num_steps)): + + # Learning rate schedule. 
+ t = step / num_steps + w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2 + lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) + lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) + lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) + lr = initial_learning_rate * lr_ramp + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + # Synth images from opt_w. + w_noise = torch.randn_like(w_opt) * w_noise_scale + ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1]) + synth_images = G.synthesis(ws, noise_mode='const', force_fp32=True) + + # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. + synth_images = (synth_images + 1) * (255 / 2) + if synth_images.shape[2] > 256: + synth_images = F.interpolate(synth_images, size=(256, 256), mode='area') + + # Features for synth images. + synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) + dist = (target_features - synth_features).square().sum() + + # Noise regularization. + reg_loss = 0.0 + for v in noise_bufs.values(): + noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() + while True: + reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + loss = dist + reg_loss * regularize_noise_weight + + if step % image_log_step == 0: + with torch.no_grad(): + if use_wandb: + global_config.training_step += 1 + wandb.log({f'first projection _{w_name}': loss.detach().cpu()}, step=global_config.training_step) + log_utils.log_image_from_w(w_opt.repeat([1, G.mapping.num_ws, 1]), G, w_name) + + # Step + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}') + + # Normalize noise. 
+ with torch.no_grad(): + for buf in noise_bufs.values(): + buf -= buf.mean() + buf *= buf.square().mean().rsqrt() + + del G + return w_opt.repeat([1, 18, 1]) diff --git a/tune.py b/tune.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8a1921786f5cc85a39729f9070baea7607e5e5 --- /dev/null +++ b/tune.py @@ -0,0 +1,40 @@ +import wandb +import click +import os +import sys +import pickle +import numpy as np +from PIL import Image +import torch +from configs import paths_config, hyperparameters, global_config +from IPython.display import display +import matplotlib.pyplot as plt +from scripts.latent_editor_wrapper import LatentEditorWrapper + +image_dir_name = '/home/sayantan/processed_images' +use_multi_id_training = False +global_config.device = 'cuda' +paths_config.e4e = '/home/sayantan/PTI/pretrained_models/e4e_ffhq_encode.pt' +paths_config.input_data_id = image_dir_name +paths_config.input_data_path = f'{image_dir_name}' +paths_config.stylegan2_ada_ffhq = '/home/sayantan/PTI/pretrained_models/ffhq.pkl' +paths_config.checkpoints_dir = '/home/sayantan/PTI/' +paths_config.style_clip_pretrained_mappers = '/home/sayantan/PTI/pretrained_models' +hyperparameters.use_locality_regularization = False +hyperparameters.lpips_type = 'squeeze' + +from scripts.run_pti import run_PTI + +@click.command() +@click.pass_context +@click.option('--rname', prompt='wandb RUN NAME', help='The name to give for the wandb run') + +def tune(ctx: click.Context,rname): + runn = wandb.init(project='PTI', entity='masc', name = rname) + model_id = run_PTI(run_name='',use_wandb=True, use_multi_id_training=False) + +#---------------------------------------------------------------------------- +if __name__ == '__main__': + tune() + +#---------------------------------------------------------------------------- diff --git a/upload_wandb.py b/upload_wandb.py new file mode 100644 index 0000000000000000000000000000000000000000..dec295baaa7823e98a9e8dd342bef477f52265c1 --- /dev/null +++ b/upload_wandb.py @@ -0,0 +1,9 @@ +import wandb +api = wandb.Api() +run = api.run("masc/PTIseg/rhh4r09q") +import os +fils = os.listdir("/home/sayantan/processed_images") + +for i in fils: + run.upload_file("/home/sayantan/processed_images/"+i,root="/home/sayantan/") +print("uploaded all") diff --git a/utils/ImagesDataset.py b/utils/ImagesDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..20aa43e9b9ad242a5e1fd4309f83ab182021cbc6 --- /dev/null +++ b/utils/ImagesDataset.py @@ -0,0 +1,25 @@ +import os + +from torch.utils.data import Dataset +from PIL import Image + +from utils.data_utils import make_dataset + + +class ImagesDataset(Dataset): + + def __init__(self, source_root, source_transform=None): + self.source_paths = sorted(make_dataset(source_root)) + self.source_transform = source_transform + + def __len__(self): + return len(self.source_paths) + + def __getitem__(self, index): + fname, from_path = self.source_paths[index] + from_im = Image.open(from_path).convert('RGB') + + if self.source_transform: + from_im = self.source_transform(from_im) + + return fname, from_im diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/align_data.py b/utils/align_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d292b9cd3dc5ca816c6c567a3271a0ef732093dc --- /dev/null +++ b/utils/align_data.py @@ -0,0 +1,35 @@ +from configs import paths_config +import dlib 
+import glob +import os +from tqdm import tqdm +from utils.alignment import align_face + + +def pre_process_images(raw_images_path): + current_directory = os.getcwd() + + IMAGE_SIZE = 256 + predictor = dlib.shape_predictor(paths_config.dlib) + os.chdir(raw_images_path) + images_names = glob.glob(f'*') + + aligned_images = [] + for image_name in tqdm(images_names): + try: + aligned_image = align_face(filepath=f'{raw_images_path}/{image_name}', + predictor=predictor, output_size=IMAGE_SIZE) + aligned_images.append(aligned_image) + except Exception as e: + print(e) + + os.makedirs(paths_config.input_data_path, exist_ok=True) + for image, name in zip(aligned_images, images_names): + real_name = name.split('.')[0] + image.save(f'{paths_config.input_data_path}/{real_name}.jpg') + + os.chdir(current_directory) + + +if __name__ == "__main__": + pre_process_images('') \ No newline at end of file diff --git a/utils/alignment.py b/utils/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..033e78e5c3f82721ad369ed1814b600b155f45e2 --- /dev/null +++ b/utils/alignment.py @@ -0,0 +1,114 @@ +import numpy as np +import PIL +import PIL.Image +import scipy +import scipy.ndimage +import dlib + + +def get_landmark(filepath, predictor): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + detector = dlib.get_frontal_face_detector() + + img = dlib.load_rgb_image(filepath) + dets = detector(img, 1) + + for k, d in enumerate(dets): + shpe = predictor(img, d) + + t = list(shpe.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + return lm + + +def align_face(filepath, predictor, output_size): + """ + :param filepath: str + :return: PIL Image + """ + + lm = get_landmark(filepath, predictor) + + lm_chin = lm[0: 17] # left-right + lm_eyebrow_left = lm[17: 22] # left-right + lm_eyebrow_right = lm[22: 27] # left-right + lm_nose = lm[27: 31] # top-down + lm_nostrils = lm[31: 36] # top-down + lm_eye_left = lm[36: 42] # left-clockwise + lm_eye_right = lm[42: 48] # left-clockwise + lm_mouth_outer = lm[48: 60] # left-clockwise + lm_mouth_inner = lm[60: 68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # read image + img = PIL.Image.open(filepath).convert('RGB') + + transform_size = output_size + enable_padding = True + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, PIL.Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. 
+ border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) + if output_size < transform_size: + img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) + + # Return aligned image. + return img diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a477bb62396989bf1000a9a46c695687b5c15f59 --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,34 @@ +import os + +from PIL import Image + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def tensor2im(var): + # var shape: (3, H, W) + var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() + var = ((var + 1) / 2) + var[var < 0] = 0 + var[var > 1] = 1 + var = var * 255 + return Image.fromarray(var.astype('uint8')) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + fname = fname.split('.')[0] + images.append((fname, path)) + return images diff --git a/utils/log_utils.py b/utils/log_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7149cf8877be2759ed885901946db683d1295768 --- /dev/null +++ b/utils/log_utils.py @@ -0,0 +1,79 @@ +import numpy as np +from PIL import Image +import wandb +from configs import global_config +import torch +import matplotlib.pyplot as plt + + +def log_image_from_w(w, G, name): + img = get_image_from_w(w, G) + pillow_image = Image.fromarray(img) + wandb.log( + {f"{name}": [ + wandb.Image(pillow_image, caption=f"current inversion {name}")]}, + step=global_config.training_step) + + +def log_images_from_w(ws, G, names): + for name, w in zip(names, ws): + w = w.to(global_config.device) + log_image_from_w(w, G, name) + + +def plot_image_from_w(w, G): + img = get_image_from_w(w, G) 
+ pillow_image = Image.fromarray(img) + plt.imshow(pillow_image) + plt.show() + + +def plot_image(img): + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy() + pillow_image = Image.fromarray(img[0]) + plt.imshow(pillow_image) + plt.show() + + +def save_image(name, method_type, results_dir, image, run_id): + image.save(f'{results_dir}/{method_type}_{name}_{run_id}.jpg') + + +def save_w(w, G, name, method_type, results_dir): + im = get_image_from_w(w, G) + im = Image.fromarray(im, mode='RGB') + save_image(name, method_type, results_dir, im) + + +def save_concat_image(base_dir, image_latents, new_inv_image_latent, new_G, + old_G, + file_name, + extra_image=None): + images_to_save = [] + if extra_image is not None: + images_to_save.append(extra_image) + for latent in image_latents: + images_to_save.append(get_image_from_w(latent, old_G)) + images_to_save.append(get_image_from_w(new_inv_image_latent, new_G)) + result_image = create_alongside_images(images_to_save) + result_image.save(f'{base_dir}/{file_name}.jpg') + + +def save_single_image(base_dir, image_latent, G, file_name): + image_to_save = get_image_from_w(image_latent, G) + image_to_save = Image.fromarray(image_to_save, mode='RGB') + image_to_save.save(f'{base_dir}/{file_name}.jpg') + + +def create_alongside_images(images): + res = np.concatenate([np.array(image) for image in images], axis=1) + return Image.fromarray(res, mode='RGB') + + +def get_image_from_w(w, G): + if len(w.size()) <= 2: + w = w.unsqueeze(0) + with torch.no_grad(): + img = G.synthesis(w, noise_mode='const') + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy() + return img[0] diff --git a/utils/models_utils.py b/utils/models_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0465f38b8f140d77c40a2323f5a8cce4571f6194 --- /dev/null +++ b/utils/models_utils.py @@ -0,0 +1,25 @@ +import pickle +import functools +import torch +from configs import paths_config, global_config + + +def toogle_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + +def load_tuned_G(run_id, type): + new_G_path = f'{paths_config.checkpoints_dir}/model_{run_id}_{type}.pt' + with open(new_G_path, 'rb') as f: + new_G = torch.load(f).to(global_config.device).eval() + new_G = new_G.float() + toogle_grad(new_G, False) + return new_G + + +def load_old_G(): + with open(paths_config.stylegan2_ada_ffhq, 'rb') as f: + old_G = pickle.load(f)['G_ema'].to(global_config.device).eval() + old_G = old_G.float() + return old_G
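
A minimal end-to-end sketch of how the modules added in this patch compose: `utils.ImagesDataset.ImagesDataset` feeds aligned images to `SingleIDCoach`, which inverts each image to a w pivot and fine-tunes a copy of the generator around it, after which `utils.models_utils.load_tuned_G` reloads the per-image checkpoint. The paths, the run name, and the 1024x1024 resolution are illustrative assumptions (an FFHQ-style generator), not values taken from this patch; the remaining entries in `configs/paths_config.py` (e4e checkpoint, `checkpoints_dir`, dlib predictor, etc.) must likewise point at real files before this would run.

```python
# Hypothetical driver script: every path and the run name below are
# placeholders; they only illustrate how the pieces in this patch wire together.
from torch.utils.data import DataLoader
from torchvision import transforms

from configs import paths_config, global_config
from utils.ImagesDataset import ImagesDataset
from utils.models_utils import load_tuned_G
from training.coaches.single_id_coach import SingleIDCoach

global_config.device = 'cuda'
global_config.run_name = 'demo'                        # placeholder run id
paths_config.input_data_path = '/path/to/aligned'      # images aligned via utils/align_data.py
paths_config.stylegan2_ada_ffhq = '/path/to/ffhq.pkl'  # frozen generator pickle

# w_projector asserts that the target matches G's resolution, so resize to
# 1024x1024 here under the assumption of an FFHQ generator; Normalize maps
# images to [-1, 1], matching what the coaches expect.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = ImagesDataset(paths_config.input_data_path, source_transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Pivotal tuning: invert each image to a w pivot, then fine-tune G around it.
SingleIDCoach(data_loader, use_wandb=False).train()

# SingleIDCoach saves one checkpoint per image name; reload it by passing the
# image name as the `type` suffix expected by load_tuned_G.
tuned_G = load_tuned_G(run_id=global_config.run_name, type='my_image')
```

For joint tuning over several identities, `MultiIDCoach` exposes the same constructor and `train()` interface but optimizes a single generator over all collected pivots and writes one `model_<run_name>_multi_id.pt` checkpoint instead of one per image.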