diff --git a/.DS_Store b/.DS_Store index a7ead18d57b5a2ad26d6a3e435add8d497a0d8d4..9a686eb54f9a4bcfce2dede881dac1fa7ca72e0a 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..4173da7f9117717d1534a334f96161aad34c7c4f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.pth* filter=lfs diff=lfs merge=lfs -text + filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..48e03de7fd9ce9dda1a8f2c5b0f8cbca8f6b24e6 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,97 @@ +Copyright (c) 2021, NVIDIA Corporation. All rights reserved. + + +NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) + + +======================================================================= + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. Notwithstanding + the foregoing, NVIDIA and its affiliates may use the Work and any + derivative works commercially. As used herein, "non-commercially" + means for research or evaluation purposes only. + + 3.4 Patent Claims. 
If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grant in Section 2.1) will terminate + immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor’s or its affiliates’ names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grant in Section 2.1) will + terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +======================================================================= diff --git a/dnnlib/__init__.py b/dnnlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f08cf36f11f9b0fd94c1b7caeadf69b98375b04 --- /dev/null +++ b/dnnlib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +from .util import EasyDict, make_cache_dir_path diff --git a/dnnlib/__pycache__/__init__.cpython-38.pyc b/dnnlib/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec292a2ea1bce630121429b8f7f963290261a5b0 Binary files /dev/null and b/dnnlib/__pycache__/__init__.cpython-38.pyc differ diff --git a/dnnlib/__pycache__/util.cpython-38.pyc b/dnnlib/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad2b30269b88391a530fe09ed180fc18a46a7427 Binary files /dev/null and b/dnnlib/__pycache__/util.cpython-38.pyc differ diff --git a/dnnlib/util.py b/dnnlib/util.py new file mode 100644 index 0000000000000000000000000000000000000000..147c2be2e382e43898f23189690c4a4b6fb7732a --- /dev/null +++ b/dnnlib/util.py @@ -0,0 +1,473 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with 
days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? 
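+    # If a module imports cleanly but the attribute is missing, re-running the
+    # lookup below lets the original AttributeError propagate to the caller.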
+ for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. 
+ Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." 
% url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/encoder4editing/LICENSE b/encoder4editing/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..88ba9d421ea8acea9b4e3937535e72c282b3d4e6 --- /dev/null +++ b/encoder4editing/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 omertov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
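For context, the `open_url` helper added in `dnnlib/util.py` above is the usual entry point for fetching checkpoints: it returns a binary file object (an `io.BytesIO` for fresh downloads, an ordinary file handle for cache hits), with the cache kept under `make_cache_dir_path('downloads')` by default. A minimal usage sketch, assuming a placeholder URL rather than one shipped in this diff:

    import pickle
    import dnnlib.util

    # Downloads on first use, then serves the cached copy on later calls.
    # The URL is a placeholder for whatever checkpoint the caller needs.
    with dnnlib.util.open_url('https://example.com/pretrained_model.pkl') as f:
        data = pickle.load(f)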
diff --git a/encoder4editing/configs/__init__.py b/encoder4editing/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/configs/data_configs.py b/encoder4editing/configs/data_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..deccb0b1c266ad4b6abaef53d67ec1ed0ddbd462 --- /dev/null +++ b/encoder4editing/configs/data_configs.py @@ -0,0 +1,41 @@ +from configs import transforms_config +from configs.paths_config import dataset_paths + + +DATASETS = { + 'ffhq_encode': { + 'transforms': transforms_config.EncodeTransforms, + 'train_source_root': dataset_paths['ffhq'], + 'train_target_root': dataset_paths['ffhq'], + 'test_source_root': dataset_paths['celeba_test'], + 'test_target_root': dataset_paths['celeba_test'], + }, + 'cars_encode': { + 'transforms': transforms_config.CarsEncodeTransforms, + 'train_source_root': dataset_paths['cars_train'], + 'train_target_root': dataset_paths['cars_train'], + 'test_source_root': dataset_paths['cars_test'], + 'test_target_root': dataset_paths['cars_test'], + }, + 'horse_encode': { + 'transforms': transforms_config.EncodeTransforms, + 'train_source_root': dataset_paths['horse_train'], + 'train_target_root': dataset_paths['horse_train'], + 'test_source_root': dataset_paths['horse_test'], + 'test_target_root': dataset_paths['horse_test'], + }, + 'church_encode': { + 'transforms': transforms_config.EncodeTransforms, + 'train_source_root': dataset_paths['church_train'], + 'train_target_root': dataset_paths['church_train'], + 'test_source_root': dataset_paths['church_test'], + 'test_target_root': dataset_paths['church_test'], + }, + 'cats_encode': { + 'transforms': transforms_config.EncodeTransforms, + 'train_source_root': dataset_paths['cats_train'], + 'train_target_root': dataset_paths['cats_train'], + 'test_source_root': dataset_paths['cats_test'], + 'test_target_root': dataset_paths['cats_test'], + } +} diff --git a/encoder4editing/configs/paths_config.py b/encoder4editing/configs/paths_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4604f6063b8125364a52a492de52fcc54004f373 --- /dev/null +++ b/encoder4editing/configs/paths_config.py @@ -0,0 +1,28 @@ +dataset_paths = { + # Face Datasets (In the paper: FFHQ - train, CelebAHQ - test) + 'ffhq': '', + 'celeba_test': '', + + # Cars Dataset (In the paper: Stanford cars) + 'cars_train': '', + 'cars_test': '', + + # Horse Dataset (In the paper: LSUN Horse) + 'horse_train': '', + 'horse_test': '', + + # Church Dataset (In the paper: LSUN Church) + 'church_train': '', + 'church_test': '', + + # Cats Dataset (In the paper: LSUN Cat) + 'cats_train': '', + 'cats_test': '' +} + +model_paths = { + 'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt', + 'ir_se50': 'pretrained_models/model_ir_se50.pth', + 'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat', + 'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth' +} diff --git a/encoder4editing/configs/transforms_config.py b/encoder4editing/configs/transforms_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ac12b5d5ba0571f21715e0f6b24b9c1ebe84bf72 --- /dev/null +++ b/encoder4editing/configs/transforms_config.py @@ -0,0 +1,62 @@ +from abc import abstractmethod +import torchvision.transforms as transforms + + +class TransformsConfig(object): + + def __init__(self, opts): + self.opts = opts + + @abstractmethod + def get_transforms(self): + pass + + +class 
EncodeTransforms(TransformsConfig): + + def __init__(self, opts): + super(EncodeTransforms, self).__init__(opts) + + def get_transforms(self): + transforms_dict = { + 'transform_gt_train': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomHorizontalFlip(0.5), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_source': None, + 'transform_test': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_inference': transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + } + return transforms_dict + + +class CarsEncodeTransforms(TransformsConfig): + + def __init__(self, opts): + super(CarsEncodeTransforms, self).__init__(opts) + + def get_transforms(self): + transforms_dict = { + 'transform_gt_train': transforms.Compose([ + transforms.Resize((192, 256)), + transforms.RandomHorizontalFlip(0.5), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_source': None, + 'transform_test': transforms.Compose([ + transforms.Resize((192, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), + 'transform_inference': transforms.Compose([ + transforms.Resize((192, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + } + return transforms_dict diff --git a/encoder4editing/criteria/__init__.py b/encoder4editing/criteria/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/criteria/id_loss.py b/encoder4editing/criteria/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..bab806172eff18c0630536ae96817508c3197b8b --- /dev/null +++ b/encoder4editing/criteria/id_loss.py @@ -0,0 +1,47 @@ +import torch +from torch import nn +from configs.paths_config import model_paths +from models.encoders.model_irse import Backbone + + +class IDLoss(nn.Module): + def __init__(self): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(model_paths['ir_se50'])) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + for module in [self.facenet, self.face_pool]: + for param in module.parameters(): + param.requires_grad = False + + def extract_feats(self, x): + x = x[:, :, 35:223, 32:220] # Crop interesting region + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + id_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + id_logs.append({'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views)}) + loss += 1 - diff_target + id_diff = float(diff_target) - float(diff_views) + sim_improvement += id_diff + count += 1 + + return loss / count, sim_improvement / count, id_logs diff --git 
a/encoder4editing/criteria/lpips/__init__.py b/encoder4editing/criteria/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/criteria/lpips/lpips.py b/encoder4editing/criteria/lpips/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..1add6acc84c1c04cfcb536cf31ec5acdf24b716b --- /dev/null +++ b/encoder4editing/criteria/lpips/lpips.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from criteria.lpips.networks import get_network, LinLayers +from criteria.lpips.utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + def __init__(self, net_type: str = 'alex', version: str = '0.1'): + + assert version in ['0.1'], 'v0.1 is only supported now' + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type).to("cuda") + + # linear layers + self.lin = LinLayers(self.net.n_channels_list).to("cuda") + self.lin.load_state_dict(get_state_dict(net_type, version)) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0)) / x.shape[0] diff --git a/encoder4editing/criteria/lpips/networks.py b/encoder4editing/criteria/lpips/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0d13ad2d560278f16586da68d3a5eadb26e746 --- /dev/null +++ b/encoder4editing/criteria/lpips/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from criteria.lpips.utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = 
models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(True).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) \ No newline at end of file diff --git a/encoder4editing/criteria/lpips/utils.py b/encoder4editing/criteria/lpips/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5 --- /dev/null +++ b/encoder4editing/criteria/lpips/utils.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/encoder4editing/criteria/moco_loss.py b/encoder4editing/criteria/moco_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8fb13fbd426202cff9014c876c85b0d5c4ec6a9d --- /dev/null +++ b/encoder4editing/criteria/moco_loss.py @@ -0,0 +1,71 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from configs.paths_config import model_paths + + +class MocoLoss(nn.Module): + + def __init__(self, opts): + super(MocoLoss, self).__init__() + print("Loading MOCO model from path: {}".format(model_paths["moco"])) + self.model = self.__load_model() + self.model.eval() + for param in self.model.parameters(): + param.requires_grad = False + + @staticmethod + def __load_model(): + import torchvision.models as models + model = models.__dict__["resnet50"]() + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ['fc.weight', 'fc.bias']: + param.requires_grad = False + checkpoint = torch.load(model_paths['moco'], map_location="cpu") + state_dict = checkpoint['state_dict'] + # rename moco pre-trained keys + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): + # remove prefix + state_dict[k[len("module.encoder_q."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + # remove output layer + model = nn.Sequential(*list(model.children())[:-1]).cuda() + return model + + def extract_feats(self, x): + x = F.interpolate(x, size=224) + x_feats = self.model(x) + 
x_feats = nn.functional.normalize(x_feats, dim=1) + x_feats = x_feats.squeeze() + return x_feats + + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + sim_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + sim_logs.append({'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views)}) + loss += 1 - diff_target + sim_diff = float(diff_target) - float(diff_views) + sim_improvement += sim_diff + count += 1 + + return loss / count, sim_improvement / count, sim_logs diff --git a/encoder4editing/criteria/w_norm.py b/encoder4editing/criteria/w_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..a45ab6f67d8a3f7051be4b7236fa2f38446fd2c1 --- /dev/null +++ b/encoder4editing/criteria/w_norm.py @@ -0,0 +1,14 @@ +import torch +from torch import nn + + +class WNormLoss(nn.Module): + + def __init__(self, start_from_latent_avg=True): + super(WNormLoss, self).__init__() + self.start_from_latent_avg = start_from_latent_avg + + def forward(self, latent, latent_avg=None): + if self.start_from_latent_avg: + latent = latent - latent_avg + return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] diff --git a/encoder4editing/datasets/__init__.py b/encoder4editing/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/datasets/gt_res_dataset.py b/encoder4editing/datasets/gt_res_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c0beacfee5335aa10aa7e8b7cabe206d7f9a56f7 --- /dev/null +++ b/encoder4editing/datasets/gt_res_dataset.py @@ -0,0 +1,32 @@ +#!/usr/bin/python +# encoding: utf-8 +import os +from torch.utils.data import Dataset +from PIL import Image +import torch + +class GTResDataset(Dataset): + + def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None): + self.pairs = [] + for f in os.listdir(root_path): + image_path = os.path.join(root_path, f) + gt_path = os.path.join(gt_dir, f) + if f.endswith(".jpg") or f.endswith(".png"): + self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None]) + self.transform = transform + self.transform_train = transform_train + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, index): + from_path, to_path, _ = self.pairs[index] + from_im = Image.open(from_path).convert('RGB') + to_im = Image.open(to_path).convert('RGB') + + if self.transform: + to_im = self.transform(to_im) + from_im = self.transform(from_im) + + return from_im, to_im diff --git a/encoder4editing/datasets/images_dataset.py b/encoder4editing/datasets/images_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..00c54c7db944569a749af4c6f0c4d99fcc37f9cc --- /dev/null +++ b/encoder4editing/datasets/images_dataset.py @@ -0,0 +1,33 @@ +from torch.utils.data import Dataset +from PIL import Image +from utils import data_utils + + +class ImagesDataset(Dataset): + + def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None): + self.source_paths = sorted(data_utils.make_dataset(source_root)) + self.target_paths = sorted(data_utils.make_dataset(target_root)) + self.source_transform = 
source_transform + self.target_transform = target_transform + self.opts = opts + + def __len__(self): + return len(self.source_paths) + + def __getitem__(self, index): + from_path = self.source_paths[index] + from_im = Image.open(from_path) + from_im = from_im.convert('RGB') + + to_path = self.target_paths[index] + to_im = Image.open(to_path).convert('RGB') + if self.target_transform: + to_im = self.target_transform(to_im) + + if self.source_transform: + from_im = self.source_transform(from_im) + else: + from_im = to_im + + return from_im, to_im diff --git a/encoder4editing/datasets/inference_dataset.py b/encoder4editing/datasets/inference_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fb577d7b538d634f27013c2784d2ea32143154cb --- /dev/null +++ b/encoder4editing/datasets/inference_dataset.py @@ -0,0 +1,25 @@ +from torch.utils.data import Dataset +from PIL import Image +from utils import data_utils + + +class InferenceDataset(Dataset): + + def __init__(self, root, opts, transform=None, preprocess=None): + self.paths = sorted(data_utils.make_dataset(root)) + self.transform = transform + self.preprocess = preprocess + self.opts = opts + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + from_path = self.paths[index] + if self.preprocess is not None: + from_im = self.preprocess(from_path) + else: + from_im = Image.open(from_path).convert('RGB') + if self.transform: + from_im = self.transform(from_im) + return from_im diff --git a/encoder4editing/editings/ganspace.py b/encoder4editing/editings/ganspace.py new file mode 100644 index 0000000000000000000000000000000000000000..0c286a421280c542e9776a75e64bb65409da8fc7 --- /dev/null +++ b/encoder4editing/editings/ganspace.py @@ -0,0 +1,22 @@ +import torch + + +def edit(latents, pca, edit_directions): + edit_latents = [] + for latent in latents: + for pca_idx, start, end, strength in edit_directions: + delta = get_delta(pca, latent, pca_idx, strength) + delta_padded = torch.zeros(latent.shape).to('cuda') + delta_padded[start:end] += delta.repeat(end - start, 1) + edit_latents.append(latent + delta_padded) + return torch.stack(edit_latents) + + +def get_delta(pca, latent, idx, strength): + # pca: ganspace checkpoint. 
latent: (16, 512) w+ + w_centered = latent - pca['mean'].to('cuda') + lat_comp = pca['comp'].to('cuda') + lat_std = pca['std'].to('cuda') + w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx] + delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx] + return delta diff --git a/encoder4editing/editings/ganspace_pca/cars_pca.pt b/encoder4editing/editings/ganspace_pca/cars_pca.pt new file mode 100644 index 0000000000000000000000000000000000000000..41c2618317f92be5089f99e1f566e9a45650b1bb --- /dev/null +++ b/encoder4editing/editings/ganspace_pca/cars_pca.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392 +size 167562 diff --git a/encoder4editing/editings/ganspace_pca/ffhq_pca.pt b/encoder4editing/editings/ganspace_pca/ffhq_pca.pt new file mode 100644 index 0000000000000000000000000000000000000000..8c8be273036803a6845ad067c8f659867343932d --- /dev/null +++ b/encoder4editing/editings/ganspace_pca/ffhq_pca.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36 +size 167562 diff --git a/encoder4editing/editings/interfacegan_directions/age.pt b/encoder4editing/editings/interfacegan_directions/age.pt new file mode 100644 index 0000000000000000000000000000000000000000..64cdd22d071c643c59ce94d58334f09f647e8a83 --- /dev/null +++ b/encoder4editing/editings/interfacegan_directions/age.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0 +size 2808 diff --git a/encoder4editing/editings/interfacegan_directions/pose.pt b/encoder4editing/editings/interfacegan_directions/pose.pt new file mode 100644 index 0000000000000000000000000000000000000000..2b6ceffe285303e7b2b09287167dba965283570b --- /dev/null +++ b/encoder4editing/editings/interfacegan_directions/pose.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d +size 37624 diff --git a/encoder4editing/editings/interfacegan_directions/smile.pt b/encoder4editing/editings/interfacegan_directions/smile.pt new file mode 100644 index 0000000000000000000000000000000000000000..eeedc44689954510ce2c3bb585f9f9968ee06825 --- /dev/null +++ b/encoder4editing/editings/interfacegan_directions/smile.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653 +size 2808 diff --git a/encoder4editing/editings/latent_editor.py b/encoder4editing/editings/latent_editor.py new file mode 100644 index 0000000000000000000000000000000000000000..4bebca2f5c86f71b58fa1f30d24bfcb0da06d88f --- /dev/null +++ b/encoder4editing/editings/latent_editor.py @@ -0,0 +1,45 @@ +import torch +import sys +sys.path.append(".") +sys.path.append("..") +from editings import ganspace, sefa +from utils.common import tensor2im + + +class LatentEditor(object): + def __init__(self, stylegan_generator, is_cars=False): + self.generator = stylegan_generator + self.is_cars = is_cars # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output. 
+
+    def apply_ganspace(self, latent, ganspace_pca, edit_directions):
+        edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
+        return self._latents_to_image(edit_latents)
+
+    def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
+        edit_latents = []
+        if factor_range is not None:  # Apply a range of editing factors, for example (-5, 5)
+            for f in range(*factor_range):
+                edit_latent = latent + f * direction
+                edit_latents.append(edit_latent)
+            edit_latents = torch.cat(edit_latents)
+        else:
+            edit_latents = latent + factor * direction
+        return self._latents_to_image(edit_latents)
+
+    def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs):
+        edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
+        return self._latents_to_image(edit_latents)
+
+    # Currently, in order to apply StyleFlow editings, one should run inference,
+    # save the latent codes and load them from the official StyleFlow repository.
+    # def apply_styleflow(self):
+    #     pass
+
+    def _latents_to_image(self, latents):
+        with torch.no_grad():
+            images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
+            if self.is_cars:
+                images = images[:, :, 64:448, :]  # 512x512 -> 384x512
+        horizontal_concat_image = torch.cat(list(images), 2)
+        final_image = tensor2im(horizontal_concat_image)
+        return final_image
diff --git a/encoder4editing/editings/sefa.py b/encoder4editing/editings/sefa.py
new file mode 100644
index 0000000000000000000000000000000000000000..db7083ce463b765a7cf452807883a3b85fb63fa5
--- /dev/null
+++ b/encoder4editing/editings/sefa.py
@@ -0,0 +1,46 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+
+
+def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):
+
+    layers, boundaries, values = factorize_weight(generator, indices)
+    codes = latents.detach().cpu().numpy()  # (1,18,512)
+
+    # Generate visualization pages.
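+    # Each selected semantic is swept along its eigen-direction over `step`
+    # evenly spaced distances between start_distance and end_distance.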
+ distances = np.linspace(start_distance, end_distance, step) + num_sam = num_samples + num_sem = semantics + + edited_latents = [] + for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False): + boundary = boundaries[sem_id:sem_id + 1] + for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False): + code = codes[sam_id:sam_id + 1] + for col_id, d in enumerate(distances, start=1): + temp_code = code.copy() + temp_code[:, layers, :] += boundary * d + edited_latents.append(torch.from_numpy(temp_code).float().cuda()) + return torch.cat(edited_latents) + + +def factorize_weight(g_ema, layers='all'): + + weights = [] + if layers == 'all' or 0 in layers: + weight = g_ema.conv1.conv.modulation.weight.T + weights.append(weight.cpu().detach().numpy()) + + if layers == 'all': + layers = list(range(g_ema.num_layers - 1)) + else: + layers = [l - 1 for l in layers if l != 0] + + for idx in layers: + weight = g_ema.convs[idx].conv.modulation.weight.T + weights.append(weight.cpu().detach().numpy()) + weight = np.concatenate(weights, axis=1).astype(np.float32) + weight = weight / np.linalg.norm(weight, axis=0, keepdims=True) + eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T)) + return layers, eigen_vectors.T, eigen_values diff --git a/encoder4editing/environment/e4e_env.yaml b/encoder4editing/environment/e4e_env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f537615ebb47afd74b5a9856fb9cbea2e0c4bf4 --- /dev/null +++ b/encoder4editing/environment/e4e_env.yaml @@ -0,0 +1,73 @@ +name: e4e_env +channels: + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - ca-certificates=2020.4.5.1=hecc5488_0 + - certifi=2020.4.5.1=py36h9f0ad1d_0 + - libedit=3.1.20181209=hc058e9b_0 + - libffi=3.2.1=hd88cf55_4 + - libgcc-ng=9.1.0=hdf63c60_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.0=hc9558a2_0 + - openssl=1.1.1g=h516909a_0 + - pip=20.0.2=py36_3 + - python=3.6.7=h0371630_0 + - python_abi=3.6=1_cp36m + - readline=7.0=h7b6447c_5 + - setuptools=46.4.0=py36_0 + - sqlite=3.31.1=h62c20be_1 + - tk=8.6.8=hbc83047_0 + - wheel=0.34.2=py36_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - pip: + - absl-py==0.9.0 + - cachetools==4.1.0 + - chardet==3.0.4 + - cycler==0.10.0 + - decorator==4.4.2 + - future==0.18.2 + - google-auth==1.15.0 + - google-auth-oauthlib==0.4.1 + - grpcio==1.29.0 + - idna==2.9 + - imageio==2.8.0 + - importlib-metadata==1.6.0 + - kiwisolver==1.2.0 + - markdown==3.2.2 + - matplotlib==3.2.1 + - mxnet==1.6.0 + - networkx==2.4 + - numpy==1.18.4 + - oauthlib==3.1.0 + - opencv-python==4.2.0.34 + - pillow==7.1.2 + - protobuf==3.12.1 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pyparsing==2.4.7 + - python-dateutil==2.8.1 + - pytorch-lightning==0.7.1 + - pywavelets==1.1.1 + - requests==2.23.0 + - requests-oauthlib==1.3.0 + - rsa==4.0 + - scikit-image==0.17.2 + - scipy==1.4.1 + - six==1.15.0 + - tensorboard==2.2.1 + - tensorboard-plugin-wit==1.6.0.post3 + - tensorboardx==1.9 + - tifffile==2020.5.25 + - torch==1.6.0 + - torchvision==0.7.1 + - tqdm==4.46.0 + - urllib3==1.25.9 + - werkzeug==1.0.1 + - zipp==3.1.0 + - pyaml +prefix: ~/anaconda3/envs/e4e_env + diff --git a/encoder4editing/infer.py b/encoder4editing/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..925e89fc4f5725bcb9e6a3c545f4a803c0639fb7 --- /dev/null +++ b/encoder4editing/infer.py @@ -0,0 +1,134 @@ +import os +import argparse +from argparse import Namespace +import time +import os +import sys +import numpy as np +from 
PIL import Image +import torch +import torchvision.transforms as transforms + +sys.path.append(".") +sys.path.append("..") + +from utils.common import tensor2im +from models.psp import pSp # we use the pSp framework to load the e4e encoder. +experiment_type = 'ffhq_encode' + +parser = argparse.ArgumentParser() +parser.add_argument('--input_image', type=str, default="", help='input image path') +args = parser.parse_args() +opts = vars(args) +print(opts) +image_path = opts["input_image"] + +def get_download_model_command(file_id, file_name): + """ Get wget download command for downloading the desired model and save to directory pretrained_models. """ + current_directory = os.getcwd() + save_path = "encoder4editing/saves" + if not os.path.exists(save_path): + os.makedirs(save_path) + url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path) + return url + +MODEL_PATHS = { + "ffhq_encode": {"id": "1cUv_reLE6k3604or78EranS7XzuVMWeO", "name": "e4e_ffhq_encode.pt"}, + "cars_encode": {"id": "17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV", "name": "e4e_cars_encode.pt"}, + "horse_encode": {"id": "1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX", "name": "e4e_horse_encode.pt"}, + "church_encode": {"id": "1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa", "name": "e4e_church_encode.pt"} +} + +path = MODEL_PATHS[experiment_type] +download_command = get_download_model_command(file_id=path["id"], file_name=path["name"]) + +EXPERIMENT_DATA_ARGS = { + "ffhq_encode": { + "model_path": "encoder4editing/e4e_ffhq_encode.pt", + "image_path": "notebooks/images/input_img.jpg" + }, + "cars_encode": { + "model_path": "pretrained_models/e4e_cars_encode.pt", + "image_path": "notebooks/images/car_img.jpg" + }, + "horse_encode": { + "model_path": "pretrained_models/e4e_horse_encode.pt", + "image_path": "notebooks/images/horse_img.jpg" + }, + "church_encode": { + "model_path": "pretrained_models/e4e_church_encode.pt", + "image_path": "notebooks/images/church_img.jpg" + } + +} +# Setup required image transformations +EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type] +if experiment_type == 'cars_encode': + EXPERIMENT_ARGS['transform'] = transforms.Compose([ + transforms.Resize((192, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + resize_dims = (256, 192) +else: + EXPERIMENT_ARGS['transform'] = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + resize_dims = (256, 256) + + +model_path = EXPERIMENT_ARGS['model_path'] +ckpt = torch.load(model_path, map_location='cpu') +opts = ckpt['opts'] + +# update the training options +opts['checkpoint_path'] = model_path +opts= Namespace(**opts) +net = pSp(opts) +net.eval() +net.cuda() +print('Model successfully loaded!') + + +original_image = Image.open(image_path) +original_image = original_image.convert("RGB") + +def run_alignment(image_path): + import dlib + from utils.alignment import align_face + predictor = dlib.shape_predictor("encoder4editing/shape_predictor_68_face_landmarks.dat") + aligned_image = align_face(filepath=image_path, predictor=predictor) + print("Aligned image has shape: 
{}".format(aligned_image.size))
+    return aligned_image
+
+if experiment_type == "ffhq_encode":
+    input_image = run_alignment(image_path)
+else:
+    input_image = original_image
+
+input_image = input_image.resize(resize_dims)  # PIL's resize returns a new image; keep the result
+
+img_transforms = EXPERIMENT_ARGS['transform']
+transformed_image = img_transforms(input_image)
+
+def display_alongside_source_image(result_image, source_image):
+    res = np.concatenate([np.array(source_image.resize(resize_dims)),
+                          np.array(result_image.resize(resize_dims))], axis=1)
+    return Image.fromarray(res)
+
+def run_on_batch(inputs, net):
+    images, latents = net(inputs.to("cuda").float(), randomize_noise=False, return_latents=True)
+    if experiment_type == 'cars_encode':
+        images = images[:, :, 32:224, :]
+    return images, latents
+
+with torch.no_grad():
+    tic = time.time()
+    images, latents = run_on_batch(transformed_image.unsqueeze(0), net)
+    result_image, latent = images[0], latents[0]
+    toc = time.time()
+    print('Inference took {:.4f} seconds.'.format(toc - tic))
+
+# Display inversion:
+display_alongside_source_image(tensor2im(result_image), input_image)
+np.savez('encoder4editing/projected_w.npz', w=latents.cpu().numpy())
diff --git a/encoder4editing/metrics/LEC.py b/encoder4editing/metrics/LEC.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eef2d2f00a4d757a56b6e845a8fde16aab306ab
--- /dev/null
+++ b/encoder4editing/metrics/LEC.py
@@ -0,0 +1,134 @@
+import sys
+import argparse
+import torch
+import numpy as np
+from torch.utils.data import DataLoader
+
+sys.path.append(".")
+sys.path.append("..")
+
+from configs import data_configs
+from datasets.images_dataset import ImagesDataset
+from utils.model_utils import setup_model
+
+
+class LEC:
+    def __init__(self, net, is_cars=False):
+        """
+        Latent Editing Consistency metric as proposed in the main paper.
+        :param net: e4e model loaded over the pSp framework.
+        :param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images.
+        """
+        self.net = net
+        self.is_cars = is_cars
+
+    def _encode(self, images):
+        """
+        Encodes the given images into StyleGAN's latent space.
+        :param images: Tensor of shape NxCxHxW representing the images to be encoded.
+        :return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
+        """
+        codes = self.net.encoder(images)
+        assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
+        # normalize with respect to the center of an average face
+        if self.net.opts.start_from_latent_avg:
+            codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
+        return codes
+
+    def _generate(self, codes):
+        """
+        Generate the StyleGAN2 images of the given codes
+        :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
+        :return: Tensor of shape NxCxHxW representing the generated images.
+        """
+        images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
+        images = self.net.face_pool(images)
+        if self.is_cars:
+            images = images[:, :, 32:224, :]
+        return images
+
+    @staticmethod
+    def _filter_outliers(arr):
+        arr = np.array(arr)
+
+        lo = np.percentile(arr, 1, interpolation="lower")
+        hi = np.percentile(arr, 99, interpolation="higher")
+        return np.extract(
+            np.logical_and(lo <= arr, arr <= hi), arr
+        )
+
+    def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
+        """
+        Calculate the LEC metric score.
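+        The score is the mean L2 distance between the original latent codes and the
+        codes recovered after applying an edit, re-encoding the edited image, and
+        applying the inverse edit (with 1st/99th percentile outliers filtered out).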
+ :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader. + :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the + latent space. + :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the + `edit_function` parameter. + :return: The LEC metric score. + """ + distances = [] + with torch.no_grad(): + for batch in data_loader: + x, _ = batch + inputs = x.to(device).float() + + codes = self._encode(inputs) + edited_codes = edit_function(codes) + edited_image = self._generate(edited_codes) + edited_image_inversion_codes = self._encode(edited_image) + inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes) + + dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean() + distances.append(dist.to("cpu").numpy()) + + distances = self._filter_outliers(distances) + return distances.mean() + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser(description="LEC metric calculator") + + parser.add_argument("--batch", type=int, default=8, help="batch size for the models") + parser.add_argument("--images_dir", type=str, default=None, + help="Path to the images directory on which we calculate the LEC score") + parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints") + + args = parser.parse_args() + print(args) + + net, opts = setup_model(args.ckpt, device) + dataset_args = data_configs.DATASETS[opts.dataset_type] + transforms_dict = dataset_args['transforms'](opts).get_transforms() + + images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir + test_dataset = ImagesDataset(source_root=images_directory, + target_root=images_directory, + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_test'], + opts=opts) + + data_loader = DataLoader(test_dataset, + batch_size=args.batch, + shuffle=False, + num_workers=2, + drop_last=True) + + print(f'dataset length: {len(test_dataset)}') + + # In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric. + # Change the provided example according to your domain and needs. 
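+    # `age.pt` is assumed to hold a latent direction broadcastable against codes of shape
+    # (N, K, 512); any pair of mutually inverse latent-space functions can be substituted
+    # below, as long as inverse_edit_func(edit_func(w)) == w for every w.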
+    direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)
+
+    def edit_func_example(codes):
+        return codes + 3 * direction
+
+    def inverse_edit_func_example(codes):
+        return codes - 3 * direction
+
+    lec = LEC(net, is_cars='car' in opts.dataset_type)
+    result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
+    print(f"LEC: {result}")
diff --git a/encoder4editing/models/__init__.py b/encoder4editing/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/encoder4editing/models/discriminator.py b/encoder4editing/models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bf3722c7f2e35cdc9bd177a33ed0975e67200d
--- /dev/null
+++ b/encoder4editing/models/discriminator.py
@@ -0,0 +1,20 @@
+from torch import nn
+
+
+class LatentCodesDiscriminator(nn.Module):
+    def __init__(self, style_dim, n_mlp):
+        super().__init__()
+
+        self.style_dim = style_dim
+
+        layers = []
+        for i in range(n_mlp - 1):
+            layers.append(
+                nn.Linear(style_dim, style_dim)
+            )
+            layers.append(nn.LeakyReLU(0.2))
+        # final layer maps to a single real/fake score; use style_dim instead of a hard-coded 512
+        layers.append(nn.Linear(style_dim, 1))
+        self.mlp = nn.Sequential(*layers)
+
+    def forward(self, w):
+        return self.mlp(w)
diff --git a/encoder4editing/models/encoders/__init__.py b/encoder4editing/models/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/encoder4editing/models/encoders/helpers.py b/encoder4editing/models/encoders/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a58b34ea5ca6912fe53c63dede0a8696f5c024
--- /dev/null
+++ b/encoder4editing/models/encoders/helpers.py
@@ -0,0 +1,140 @@
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+"""
+ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Flatten(Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
+
+
+def l2_norm(input, axis=1):
+    norm = torch.norm(input, 2, axis, True)
+    output = torch.div(input, norm)
+    return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+    """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+    if num_layers == 50:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=4),
+            get_block(in_channel=128, depth=256, num_units=14),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 100:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=13),
+            get_block(in_channel=128, depth=256, num_units=30),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    elif num_layers == 152:
+        blocks = [
+            get_block(in_channel=64, depth=64, num_units=3),
+            get_block(in_channel=64, depth=128, num_units=8),
+            get_block(in_channel=128, depth=256, num_units=36),
+            get_block(in_channel=256, depth=512, num_units=3)
+        ]
+    else:
+        raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +def _upsample_add(x, y): + """Upsample and add two feature maps. + Args: + x: (Variable) top feature map to be upsampled. + y: (Variable) lateral feature map. + Returns: + (Variable) added feature map. + Note in PyTorch, when input size is odd, the upsampled feature map + with `F.upsample(..., scale_factor=2, mode='nearest')` + maybe not equal to the lateral feature map size. + e.g. + original input size: [N,_,15,15] -> + conv2d feature map size: [N,_,8,8] -> + upsampled feature map size: [N,_,16,16] + So we choose bilinear upsample which supports arbitrary output sizes. 
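+    In this repository x is the coarser FPN level and y is a lateral feature produced by a
+    1x1 conv (latlayer1/latlayer2 in the encoders), so the sum fuses coarse semantics with
+    finer spatial detail.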
+ """ + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y diff --git a/encoder4editing/models/encoders/model_irse.py b/encoder4editing/models/encoders/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..ea6c6091c1e71279ff0bc7e013b0cea287cb01b3 --- /dev/null +++ b/encoder4editing/models/encoders/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/encoder4editing/models/encoders/psp_encoders.py b/encoder4editing/models/encoders/psp_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..ab52d04dbd8eac5adf673a1587b0b4ea9d6e68dd --- /dev/null +++ b/encoder4editing/models/encoders/psp_encoders.py @@ -0,0 +1,235 @@ +from enum import Enum +import math +import numpy as np +import torch +from torch import nn +from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module + +from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, 
_upsample_add +from models.stylegan2.model import EqualLinear + + +class ProgressiveStage(Enum): + WTraining = 0 + Delta1Training = 1 + Delta2Training = 2 + Delta3Training = 3 + Delta4Training = 4 + Delta5Training = 5 + Delta6Training = 6 + Delta7Training = 7 + Delta8Training = 8 + Delta9Training = 9 + Delta10Training = 10 + Delta11Training = 11 + Delta12Training = 12 + Delta13Training = 13 + Delta14Training = 14 + Delta15Training = 15 + Delta16Training = 16 + Delta17Training = 17 + Inference = 18 + + +class GradualStyleBlock(Module): + def __init__(self, in_c, out_c, spatial): + super(GradualStyleBlock, self).__init__() + self.out_c = out_c + self.spatial = spatial + num_pools = int(np.log2(spatial)) + modules = [] + modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU()] + for i in range(num_pools - 1): + modules += [ + Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU() + ] + self.convs = nn.Sequential(*modules) + self.linear = EqualLinear(out_c, out_c, lr_mul=1) + + def forward(self, x): + x = self.convs(x) + x = x.view(-1, self.out_c) + x = self.linear(x) + return x + + +class GradualStyleEncoder(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(GradualStyleEncoder, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + x = self.input_layer(x) + + latents = [] + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + for j in range(self.coarse_ind): + latents.append(self.styles[j](c3)) + + p2 = _upsample_add(c3, self.latlayer1(c2)) + for j in range(self.coarse_ind, self.middle_ind): + latents.append(self.styles[j](p2)) + + p1 = _upsample_add(p2, self.latlayer2(c1)) + for j in range(self.middle_ind, self.style_count): + latents.append(self.styles[j](p1)) + + out = torch.stack(latents, dim=1) + return out + + +class Encoder4Editing(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(Encoder4Editing, self).__init__() + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 
+ BatchNorm2d(64), + PReLU(64)) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + self.styles = nn.ModuleList() + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + self.coarse_ind = 3 + self.middle_ind = 7 + + for i in range(self.style_count): + if i < self.coarse_ind: + style = GradualStyleBlock(512, 512, 16) + elif i < self.middle_ind: + style = GradualStyleBlock(512, 512, 32) + else: + style = GradualStyleBlock(512, 512, 64) + self.styles.append(style) + + self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) + self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) + + self.progressive_stage = ProgressiveStage.Inference + + def get_deltas_starting_dimensions(self): + ''' Get a list of the initial dimension of every delta from which it is applied ''' + return list(range(self.style_count)) # Each dimension has a delta applied to it + + def set_progressive_stage(self, new_stage: ProgressiveStage): + self.progressive_stage = new_stage + print('Changed progressive stage to: ', new_stage) + + def forward(self, x): + x = self.input_layer(x) + + modulelist = list(self.body._modules.values()) + for i, l in enumerate(modulelist): + x = l(x) + if i == 6: + c1 = x + elif i == 20: + c2 = x + elif i == 23: + c3 = x + + # Infer main W and duplicate it + w0 = self.styles[0](c3) + w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) + stage = self.progressive_stage.value + features = c3 + for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas + if i == self.coarse_ind: + p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features + features = p2 + elif i == self.middle_ind: + p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features + features = p1 + delta_i = self.styles[i](features) + w[:, i] += delta_i + return w + + +class BackboneEncoderUsingLastLayerIntoW(Module): + def __init__(self, num_layers, mode='ir', opts=None): + super(BackboneEncoderUsingLastLayerIntoW, self).__init__() + print('Using BackboneEncoderUsingLastLayerIntoW') + assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.linear = EqualLinear(512, 512, lr_mul=1) + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + log_size = int(math.log(opts.stylegan_size, 2)) + self.style_count = 2 * log_size - 2 + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_pool(x) + x = x.view(-1, 512) + x = self.linear(x) + return x.repeat(self.style_count, 1, 1).permute(1, 0, 2) diff --git a/encoder4editing/models/latent_codes_pool.py b/encoder4editing/models/latent_codes_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..0281d4b5e80f8eb26e824fa35b4f908dcb6634e6 --- /dev/null +++ b/encoder4editing/models/latent_codes_pool.py @@ -0,0 +1,55 @@ +import random +import torch + + +class LatentCodesPool: + """This 
class implements a latent codes buffer that stores previously generated w latent codes.
+    This buffer enables us to update discriminators using a history of generated w's
+    rather than the ones produced by the latest encoder.
+    """
+
+    def __init__(self, pool_size):
+        """Initialize the LatentCodesPool class
+        Parameters:
+            pool_size (int) -- the size of the latent codes buffer; if pool_size=0, no buffer is created
+        """
+        self.pool_size = pool_size
+        if self.pool_size > 0:  # create an empty pool
+            self.num_ws = 0
+            self.ws = []
+
+    def query(self, ws):
+        """Return w's from the pool.
+        Parameters:
+            ws: the latest generated w's from the generator
+        Returns w's from the buffer.
+        With 50% probability, the buffer returns the input w's.
+        With 50% probability, the buffer returns w's previously stored in it,
+        and inserts the current w's into the buffer.
+        """
+        if self.pool_size == 0:  # if the buffer size is 0, do nothing
+            return ws
+        return_ws = []
+        for w in ws:  # ws.shape: (batch, 512) or (batch, n_latent, 512)
+            # w = torch.unsqueeze(image.data, 0)
+            if w.ndim == 2:
+                i = random.randint(0, len(w) - 1)  # pick a random latent index as a candidate
+                w = w[i]
+            self.handle_w(w, return_ws)
+        return_ws = torch.stack(return_ws, 0)  # collect all the latent codes and return
+        return return_ws
+
+    def handle_w(self, w, return_ws):
+        if self.num_ws < self.pool_size:  # if the buffer is not full, keep inserting current codes into the buffer
+            self.num_ws = self.num_ws + 1
+            self.ws.append(w)
+            return_ws.append(w)
+        else:
+            p = random.uniform(0, 1)
+            if p > 0.5:  # by 50% chance, the buffer returns a previously stored latent code and inserts the current code into the buffer
+                random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
+                tmp = self.ws[random_id].clone()
+                self.ws[random_id] = w
+                return_ws.append(tmp)
+            else:  # by another 50% chance, the buffer returns the current code
+                return_ws.append(w)
diff --git a/encoder4editing/models/psp.py b/encoder4editing/models/psp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cae392654d6c3d678950a477921453781c81c7c
--- /dev/null
+++ b/encoder4editing/models/psp.py
@@ -0,0 +1,100 @@
+import matplotlib
+
+matplotlib.use('Agg')
+import torch
+from torch import nn
+from models.encoders import psp_encoders
+from models.stylegan2.model import Generator
+from configs.paths_config import model_paths
+
+
+def get_keys(d, name):
+    if 'state_dict' in d:
+        d = d['state_dict']
+    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+    return d_filt
+
+
+class pSp(nn.Module):
+
+    def __init__(self, opts):
+        super(pSp, self).__init__()
+        self.opts = opts
+        # Define architecture
+        self.encoder = self.set_encoder()
+        self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
+        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+        # Load weights if needed
+        self.load_weights()
+
+    def set_encoder(self):
+        if self.opts.encoder_type == 'GradualStyleEncoder':
+            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
+        elif self.opts.encoder_type == 'Encoder4Editing':
+            encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
+        elif self.opts.encoder_type == 'SingleStyleCodeEncoder':
+            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
+        else:
+            raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
+        return encoder
+
+    def load_weights(self):
+        if self.opts.checkpoint_path is not None:
+            print('Loading e4e over the pSp framework from checkpoint: 
{}'.format(self.opts.checkpoint_path)) + ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') + self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) + self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) + self.__load_latent_avg(ckpt) + else: + print('Loading encoders weights from irse50!') + encoder_ckpt = torch.load(model_paths['ir_se50']) + self.encoder.load_state_dict(encoder_ckpt, strict=False) + print('Loading decoder weights from pretrained!') + ckpt = torch.load(self.opts.stylegan_weights) + self.decoder.load_state_dict(ckpt['g_ema'], strict=False) + self.__load_latent_avg(ckpt, repeat=self.encoder.style_count) + + def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True, + inject_latent=None, return_latents=False, alpha=None): + if input_code: + codes = x + else: + codes = self.encoder(x) + # normalize with respect to the center of an average face + if self.opts.start_from_latent_avg: + if codes.ndim == 2: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] + else: + codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1) + + if latent_mask is not None: + for i in latent_mask: + if inject_latent is not None: + if alpha is not None: + codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i] + else: + codes[:, i] = inject_latent[:, i] + else: + codes[:, i] = 0 + + input_is_latent = not input_code + images, result_latent = self.decoder([codes], + input_is_latent=input_is_latent, + randomize_noise=randomize_noise, + return_latents=return_latents) + + if resize: + images = self.face_pool(images) + + if return_latents: + return images, result_latent + else: + return images + + def __load_latent_avg(self, ckpt, repeat=None): + if 'latent_avg' in ckpt: + self.latent_avg = ckpt['latent_avg'].to(self.opts.device) + if repeat is not None: + self.latent_avg = self.latent_avg.repeat(repeat, 1) + else: + self.latent_avg = None diff --git a/encoder4editing/models/stylegan2/__init__.py b/encoder4editing/models/stylegan2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/models/stylegan2/model.py b/encoder4editing/models/stylegan2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..54870486c6ef5a0d34e8e63b94ba5e3ac6e68944 --- /dev/null +++ b/encoder4editing/models/stylegan2/model.py @@ -0,0 +1,673 @@ +import math +import random +import torch +from torch import nn +from torch.nn import functional as F + +from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + 
kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__( + self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size) + ) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__( + self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, 
in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): 
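+        # A 1x1 modulated conv without demodulation projects features to 3 RGB channels;
+        # the learned per-channel bias is added, and the result is accumulated onto the
+        # 2x-upsampled RGB skip from the previous resolution (StyleGAN2 skip architecture).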
+ out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' + ) + ) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel + ) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device + ) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + return_features=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = 
self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + elif return_features: + return image, out + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, out_channel, 1, downsample=True, activate=False, bias=False + ) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view( + group, -1, self.stddev_feat, channel // self.stddev_feat, height, width + ) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/encoder4editing/models/stylegan2/op/__init__.py b/encoder4editing/models/stylegan2/op/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..d0918d92285955855be89f00096b888ee5597ce3 --- /dev/null +++ b/encoder4editing/models/stylegan2/op/__init__.py @@ -0,0 +1,2 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/encoder4editing/models/stylegan2/op/fused_act.py b/encoder4editing/models/stylegan2/op/fused_act.py new file mode 100644 index 0000000000000000000000000000000000000000..973a84fffde53668d31397da5fb993bbc95f7be0 --- /dev/null +++ b/encoder4editing/models/stylegan2/op/fused_act.py @@ -0,0 +1,85 @@ +import os + +import torch +from torch import nn +from torch.autograd import Function +from torch.utils.cpp_extension import load + +module_path = os.path.dirname(__file__) +fused = load( + 'fused', + sources=[ + os.path.join(module_path, 'fused_bias_act.cpp'), + os.path.join(module_path, 'fused_bias_act_kernel.cu'), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act( + grad_output, empty, out, 3, 1, negative_slope, scale + ) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act( + gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale + ) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/encoder4editing/models/stylegan2/op/fused_bias_act.cpp b/encoder4editing/models/stylegan2/op/fused_bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02be898f970bcc8ea297867fcaa4e71b24b3d949 --- /dev/null +++ b/encoder4editing/models/stylegan2/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, 
const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu b/encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9fa56fea7ede7072dc8925cfb0148f136eb85b8 --- /dev/null +++ b/encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
1 : 0;
+
+    int size_x = x.numel();
+    int size_b = b.numel();
+    int step_b = 1;
+
+    for (int i = 1 + 1; i < x.dim(); i++) {
+        step_b *= x.size(i);
+    }
+
+    int loop_x = 4;
+    int block_size = 4 * 32;
+    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+    auto y = torch::empty_like(x);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+            act,
+            grad,
+            alpha,
+            scale,
+            loop_x,
+            size_x,
+            step_b,
+            size_b,
+            use_bias,
+            use_ref
+        );
+    });
+
+    return y;
+}
\ No newline at end of file
diff --git a/encoder4editing/models/stylegan2/op/upfirdn2d.cpp b/encoder4editing/models/stylegan2/op/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d2e633dc896433c205e18bc3e455539192ff968e
--- /dev/null
+++ b/encoder4editing/models/stylegan2/op/upfirdn2d.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+                           int up_x, int up_y, int down_x, int down_y,
+                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+                        int up_x, int up_y, int down_x, int down_y,
+                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(kernel);
+
+    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
\ No newline at end of file
diff --git a/encoder4editing/models/stylegan2/op/upfirdn2d.py b/encoder4editing/models/stylegan2/op/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc5a1e331c2bbb1893ac748cfd0f144ff0651b4
--- /dev/null
+++ b/encoder4editing/models/stylegan2/op/upfirdn2d.py
@@ -0,0 +1,184 @@
+import os
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F  # required by the upfirdn2d_native fallback below
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+upfirdn2d_op = load(
+    'upfirdn2d',
+    sources=[
+        os.path.join(module_path, 'upfirdn2d.cpp'),
+        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+    ],
+)
+
+
+class UpFirDn2dBackward(Function):
+    @staticmethod
+    def forward(
+        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
+    ):
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_op.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            down_x,
+            down_y,
+            up_x,
+            up_y,
+            g_pad_x0,
+            g_pad_x1,
+            g_pad_y0,
+            g_pad_y1,
+        )
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        kernel, = ctx.saved_tensors
+
+        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+        gradgrad_out = upfirdn2d_op.upfirdn2d(
+            gradgrad_input,
+            kernel,
ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + out = UpFirDn2d.apply( + input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) + ) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + + return out[:, ::down_y, ::down_x, :] diff --git a/encoder4editing/models/stylegan2/op/upfirdn2d_kernel.cu b/encoder4editing/models/stylegan2/op/upfirdn2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a710aa6adc3d43ac93136a1814e3c39970e1c7e --- /dev/null +++ b/encoder4editing/models/stylegan2/op/upfirdn2d_kernel.cu @@ -0,0 +1,272 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + + +template +__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + + #pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) + #pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w 
& out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; + } + } + } + } +} + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; + + auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h; + int tile_out_w; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 2: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 3: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 4: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 5: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + + case 6: + upfirdn2d_kernel<<>>( + out.data_ptr(), x.data_ptr(), k.data_ptr(), p + ); + + break; + } + }); + + return out; +} \ No newline at end of file diff --git a/encoder4editing/options/__init__.py b/encoder4editing/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/options/train_options.py b/encoder4editing/options/train_options.py new file mode 100644 index 
diff --git a/encoder4editing/options/__init__.py b/encoder4editing/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/options/train_options.py b/encoder4editing/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..583ea1423fdc9a649cd7044d74d554bf0ac2bf51 --- /dev/null +++ b/encoder4editing/options/train_options.py @@ -0,0 +1,84 @@ +from argparse import ArgumentParser +from configs.paths_config import model_paths + + +class TrainOptions: + + def __init__(self): + self.parser = ArgumentParser() + self.initialize() + + def initialize(self): + self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') + self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, + help='Type of dataset/experiment to run') + self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use') + + self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training') + self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference') + self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers') + self.parser.add_argument('--test_workers', default=2, type=int, + help='Number of test/inference dataloader workers') + + self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate') + self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') + self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') + self.parser.add_argument('--start_from_latent_avg', action='store_true', + help='Whether to add the average latent vector when generating codes from the encoder') + self.parser.add_argument('--lpips_type', default='alex', type=str, help='LPIPS backbone') + + self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor') + self.parser.add_argument('--id_lambda', default=0.1, type=float, help='ID loss multiplier factor') + self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor') + + self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, + help='Path to StyleGAN model weights') + self.parser.add_argument('--stylegan_size', default=1024, type=int, + help='Size of the pretrained StyleGAN generator') + self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') + + self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps') + self.parser.add_argument('--image_interval', default=100, type=int, + help='Interval for logging train images during training') + self.parser.add_argument('--board_interval', default=50, type=int, + help='Interval for logging metrics to tensorboard') + self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval') + self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval') + + # Discriminator flags + self.parser.add_argument('--w_discriminator_lambda', default=0, type=float, help='Dw loss multiplier') + self.parser.add_argument('--w_discriminator_lr', default=2e-5, type=float, help='Dw learning rate') + self.parser.add_argument("--r1", type=float, default=10, help="Weight of the r1 regularization") + self.parser.add_argument("--d_reg_every", type=int, default=16, + help="Interval for applying r1 regularization") + self.parser.add_argument('--use_w_pool', action='store_true', + help='Whether to store a latent codes pool for the discriminator\'s training') +
self.parser.add_argument("--w_pool_size", type=int, default=50, + help="W\'s pool size, depends on --use_w_pool") + + # e4e specific + self.parser.add_argument('--delta_norm', type=int, default=2, help="norm type of the deltas") + self.parser.add_argument('--delta_norm_lambda', type=float, default=2e-4, help="lambda for delta norm loss") + + # Progressive training + self.parser.add_argument('--progressive_steps', nargs='+', type=int, default=None, + help="The training steps of training new deltas. steps[i] starts the delta_i training") + self.parser.add_argument('--progressive_start', type=int, default=None, + help="The training step to start training the deltas, overrides progressive_steps") + self.parser.add_argument('--progressive_step_every', type=int, default=2_000, + help="Amount of training steps for each progressive step") + + # Save additional training info to enable future training continuation from produced checkpoints + self.parser.add_argument('--save_training_data', action='store_true', + help='Save intermediate training data to resume training from the checkpoint') + self.parser.add_argument('--sub_exp_dir', default=None, type=str, help='Name of sub experiment directory') + self.parser.add_argument('--keep_optimizer', action='store_true', + help='Whether to continue from the checkpoint\'s optimizer') + self.parser.add_argument('--resume_training_from_ckpt', default=None, type=str, + help='Path to training checkpoint, works when --save_training_data was set to True') + self.parser.add_argument('--update_param_list', nargs='+', type=str, default=None, + help="Name of training parameters to update the loaded training checkpoint") + + def parse(self): + opts = self.parser.parse_args() + return opts diff --git a/encoder4editing/scripts/calc_losses_on_images.py b/encoder4editing/scripts/calc_losses_on_images.py new file mode 100644 index 0000000000000000000000000000000000000000..32b6bcee854da7ae357daf82bd986f30db9fb72c --- /dev/null +++ b/encoder4editing/scripts/calc_losses_on_images.py @@ -0,0 +1,87 @@ +from argparse import ArgumentParser +import os +import json +import sys +from tqdm import tqdm +import numpy as np +import torch +from torch.utils.data import DataLoader +import torchvision.transforms as transforms + +sys.path.append(".") +sys.path.append("..") + +from criteria.lpips.lpips import LPIPS +from datasets.gt_res_dataset import GTResDataset + + +def parse_args(): + parser = ArgumentParser(add_help=False) + parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2']) + parser.add_argument('--data_path', type=str, default='results') + parser.add_argument('--gt_path', type=str, default='gt_images') + parser.add_argument('--workers', type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--is_cars', action='store_true') + args = parser.parse_args() + return args + + +def run(args): + resize_dims = (256, 256) + if args.is_cars: + resize_dims = (192, 256) + transform = transforms.Compose([transforms.Resize(resize_dims), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + + print('Loading dataset') + dataset = GTResDataset(root_path=args.data_path, + gt_dir=args.gt_path, + transform=transform) + + dataloader = DataLoader(dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=int(args.workers), + drop_last=True) + + if args.mode == 'lpips': + loss_func = LPIPS(net_type='alex') + elif args.mode == 'l2': + loss_func = torch.nn.MSELoss() + else: + raise 
diff --git a/encoder4editing/scripts/calc_losses_on_images.py b/encoder4editing/scripts/calc_losses_on_images.py new file mode 100644 index 0000000000000000000000000000000000000000..32b6bcee854da7ae357daf82bd986f30db9fb72c --- /dev/null +++ b/encoder4editing/scripts/calc_losses_on_images.py @@ -0,0 +1,87 @@ +from argparse import ArgumentParser +import os +import json +import sys +from tqdm import tqdm +import numpy as np +import torch +from torch.utils.data import DataLoader +import torchvision.transforms as transforms + +sys.path.append(".") +sys.path.append("..") + +from criteria.lpips.lpips import LPIPS +from datasets.gt_res_dataset import GTResDataset + + +def parse_args(): + parser = ArgumentParser(add_help=False) + parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2']) + parser.add_argument('--data_path', type=str, default='results') + parser.add_argument('--gt_path', type=str, default='gt_images') + parser.add_argument('--workers', type=int, default=4) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--is_cars', action='store_true') + args = parser.parse_args() + return args + + +def run(args): + resize_dims = (256, 256) + if args.is_cars: + resize_dims = (192, 256) + transform = transforms.Compose([transforms.Resize(resize_dims), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) + + print('Loading dataset') + dataset = GTResDataset(root_path=args.data_path, + gt_dir=args.gt_path, + transform=transform) + + dataloader = DataLoader(dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=int(args.workers), + drop_last=True) + + if args.mode == 'lpips': + loss_func = LPIPS(net_type='alex') + elif args.mode == 'l2': + loss_func = torch.nn.MSELoss() + else: + raise Exception('Not a valid mode!') + loss_func.cuda() + + global_i = 0 + scores_dict = {} + all_scores = [] + for result_batch, gt_batch in tqdm(dataloader): + for i in range(args.batch_size): + loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda())) + all_scores.append(loss) + im_path = dataset.pairs[global_i][0] + scores_dict[os.path.basename(im_path)] = loss + global_i += 1 + + mean = np.mean(all_scores) + std = np.std(all_scores) + result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std) + print('Finished with', args.data_path) + print(result_str) + + out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics') + if not os.path.exists(out_path): + os.makedirs(out_path) + + with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f: + f.write(result_str) + with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f: + json.dump(scores_dict, f) + + +if __name__ == '__main__': + args = parse_args() + run(args) diff --git a/encoder4editing/scripts/inference.py b/encoder4editing/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..185b9b34db85dcd97b9793bd5dbfc9d1ca046549 --- /dev/null +++ b/encoder4editing/scripts/inference.py @@ -0,0 +1,133 @@ +import argparse + +import torch +import numpy as np +import sys +import os +import dlib + +sys.path.append(".") +sys.path.append("..") + +from configs import data_configs, paths_config +from datasets.inference_dataset import InferenceDataset +from torch.utils.data import DataLoader +from utils.model_utils import setup_model +from utils.common import tensor2im +from utils.alignment import align_face +from PIL import Image + + +def main(args): + net, opts = setup_model(args.ckpt, device) + is_cars = 'cars_' in opts.dataset_type + generator = net.decoder + generator.eval() + args, data_loader = setup_data_loader(args, opts) + + # Check if latents exist + latents_file_path = os.path.join(args.save_dir, 'latents.pt') + if os.path.exists(latents_file_path): + latent_codes = torch.load(latents_file_path).to(device) + else: + latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars) + torch.save(latent_codes, latents_file_path) + + if not args.latents_only: + generate_inversions(args, generator, latent_codes, is_cars=is_cars) + + +def setup_data_loader(args, opts): + dataset_args = data_configs.DATASETS[opts.dataset_type] + transforms_dict = dataset_args['transforms'](opts).get_transforms() + images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root'] + print(f"images path: {images_path}") + align_function = None + if args.align: + align_function = run_alignment + test_dataset = InferenceDataset(root=images_path, + transform=transforms_dict['transform_test'], + preprocess=align_function, + opts=opts) + + data_loader = DataLoader(test_dataset, + batch_size=args.batch, + shuffle=False, + num_workers=2, + drop_last=True) + + print(f'dataset length: {len(test_dataset)}') + + if args.n_sample is None: + args.n_sample = len(test_dataset) + return args, data_loader + + +def get_latents(net, x, is_cars=False): + codes = net.encoder(x) + if net.opts.start_from_latent_avg: + if codes.ndim == 2: + codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :] + else: + codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1) + if codes.shape[1] == 18 and is_cars: + codes = codes[:, :16, :] + return codes + +
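+# A brief note on get_latents() above (informal, not part of the upstream API):
+# with --start_from_latent_avg the encoder predicts offsets, so the generator's
+# average latent is added back in, broadcast over all W+ rows for 3-D codes or
+# taken from a single row for 2-D codes. For the cars generator (512px, hence
+# 16 style layers) the 18-row code is cropped to 16 rows to match the decoder.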
+def get_all_latents(net, data_loader, n_images=None, is_cars=False): + all_latents = [] + i = 0 + with torch.no_grad(): + for batch in data_loader: + if n_images is not None and i > n_images: + break + x = batch + inputs = x.to(device).float() + latents = get_latents(net, inputs, is_cars) + all_latents.append(latents) + i += len(latents) + return torch.cat(all_latents) + + +def save_image(img, save_dir, idx): + result = tensor2im(img) + im_save_path = os.path.join(save_dir, f"{idx:05d}.jpg") + Image.fromarray(np.array(result)).save(im_save_path) + + +@torch.no_grad() +def generate_inversions(args, g, latent_codes, is_cars): + print('Saving inversion images') + inversions_directory_path = os.path.join(args.save_dir, 'inversions') + os.makedirs(inversions_directory_path, exist_ok=True) + for i in range(args.n_sample): + imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True) + if is_cars: + imgs = imgs[:, :, 64:448, :] + save_image(imgs[0], inversions_directory_path, i + 1) + + +def run_alignment(image_path): + predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor']) + aligned_image = align_face(filepath=image_path, predictor=predictor) + print("Aligned image has shape: {}".format(aligned_image.size)) + return aligned_image + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser(description="Inference") + parser.add_argument("--images_dir", type=str, default=None, + help="The directory of the images to be inverted") + parser.add_argument("--save_dir", type=str, default=None, + help="The directory to save the latent codes and inversion images (default: images_dir)") + parser.add_argument("--batch", type=int, default=1, help="Batch size for the generator") + parser.add_argument("--n_sample", type=int, default=None, help="Number of samples to infer") + parser.add_argument("--latents_only", action="store_true", help="Infer only the latent codes of the directory") + parser.add_argument("--align", action="store_true", help="Align face images before inference") + parser.add_argument("ckpt", metavar="CHECKPOINT", help="Path to the generator checkpoint") + + args = parser.parse_args() + main(args) diff --git a/encoder4editing/scripts/train.py b/encoder4editing/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d885cfde49a0b21140e663e475918698d5e51ee3 --- /dev/null +++ b/encoder4editing/scripts/train.py @@ -0,0 +1,88 @@ +""" +This file runs the main training/val loop +""" +import os +import json +import math +import sys +import pprint +import torch +from argparse import Namespace + +sys.path.append(".") +sys.path.append("..") + +from options.train_options import TrainOptions +from training.coach import Coach + + +def main(): + opts = TrainOptions().parse() + previous_train_ckpt = None + if opts.resume_training_from_ckpt: + opts, previous_train_ckpt = load_train_checkpoint(opts) + else: + setup_progressive_steps(opts) + create_initial_experiment_dir(opts) + + coach = Coach(opts, previous_train_ckpt) + coach.train() + + +def load_train_checkpoint(opts): + train_ckpt_path = opts.resume_training_from_ckpt + previous_train_ckpt = torch.load(opts.resume_training_from_ckpt, map_location='cpu') + new_opts_dict = vars(opts) + opts = previous_train_ckpt['opts'] + opts['resume_training_from_ckpt'] = train_ckpt_path + update_new_configs(opts, new_opts_dict) + pprint.pprint(opts) + opts = Namespace(**opts) + if opts.sub_exp_dir is not None: + sub_exp_dir = opts.sub_exp_dir + opts.exp_dir =
os.path.join(opts.exp_dir, sub_exp_dir) + create_initial_experiment_dir(opts) + return opts, previous_train_ckpt + + +def setup_progressive_steps(opts): + log_size = int(math.log(opts.stylegan_size, 2)) + num_style_layers = 2*log_size - 2 + num_deltas = num_style_layers - 1 + if opts.progressive_start is not None: # If progressive delta training + opts.progressive_steps = [0] + next_progressive_step = opts.progressive_start + for i in range(num_deltas): + opts.progressive_steps.append(next_progressive_step) + next_progressive_step += opts.progressive_step_every + + assert opts.progressive_steps is None or is_valid_progressive_steps(opts, num_style_layers), \ + "Invalid progressive training input" + + +def is_valid_progressive_steps(opts, num_style_layers): + return len(opts.progressive_steps) == num_style_layers and opts.progressive_steps[0] == 0 + + +def create_initial_experiment_dir(opts): + if os.path.exists(opts.exp_dir): + raise Exception('Oops... {} already exists'.format(opts.exp_dir)) + os.makedirs(opts.exp_dir) + + opts_dict = vars(opts) + pprint.pprint(opts_dict) + with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: + json.dump(opts_dict, f, indent=4, sort_keys=True) + + +def update_new_configs(ckpt_opts, new_opts): + for k, v in new_opts.items(): + if k not in ckpt_opts: + ckpt_opts[k] = v + if new_opts['update_param_list']: + for param in new_opts['update_param_list']: + ckpt_opts[param] = new_opts[param] + + +if __name__ == '__main__': + main() diff --git a/encoder4editing/training/__init__.py b/encoder4editing/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/training/coach.py b/encoder4editing/training/coach.py new file mode 100644 index 0000000000000000000000000000000000000000..10b22a6830673752dcf922cee7914c39069a4333 --- /dev/null +++ b/encoder4editing/training/coach.py @@ -0,0 +1,439 @@ +import os +import random +import matplotlib +import matplotlib.pyplot as plt + +matplotlib.use('Agg') + +import torch +from torch import nn, autograd +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +import torch.nn.functional as F + +from utils import common, train_utils +from criteria import id_loss, moco_loss +from configs import data_configs +from datasets.images_dataset import ImagesDataset +from criteria.lpips.lpips import LPIPS +from models.psp import pSp +from models.latent_codes_pool import LatentCodesPool +from models.discriminator import LatentCodesDiscriminator +from models.encoders.psp_encoders import ProgressiveStage +from training.ranger import Ranger + +random.seed(0) +torch.manual_seed(0) + + +class Coach: + def __init__(self, opts, prev_train_checkpoint=None): + self.opts = opts + + self.global_step = 0 + + self.device = 'cuda:0' + self.opts.device = self.device + # Initialize network + self.net = pSp(self.opts).to(self.device) + + # Initialize loss + if self.opts.lpips_lambda > 0: + self.lpips_loss = LPIPS(net_type=self.opts.lpips_type).to(self.device).eval() + if self.opts.id_lambda > 0: + if 'ffhq' in self.opts.dataset_type or 'celeb' in self.opts.dataset_type: + self.id_loss = id_loss.IDLoss().to(self.device).eval() + else: + self.id_loss = moco_loss.MocoLoss(opts).to(self.device).eval() + self.mse_loss = nn.MSELoss().to(self.device).eval() + + # Initialize optimizer + self.optimizer = self.configure_optimizers() + + # Initialize discriminator + if self.opts.w_discriminator_lambda > 0: + self.discriminator = 
LatentCodesDiscriminator(512, 4).to(self.device) + self.discriminator_optimizer = torch.optim.Adam(list(self.discriminator.parameters()), + lr=opts.w_discriminator_lr) + self.real_w_pool = LatentCodesPool(self.opts.w_pool_size) + self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size) + + # Initialize dataset + self.train_dataset, self.test_dataset = self.configure_datasets() + self.train_dataloader = DataLoader(self.train_dataset, + batch_size=self.opts.batch_size, + shuffle=True, + num_workers=int(self.opts.workers), + drop_last=True) + self.test_dataloader = DataLoader(self.test_dataset, + batch_size=self.opts.test_batch_size, + shuffle=False, + num_workers=int(self.opts.test_workers), + drop_last=True) + + # Initialize logger + log_dir = os.path.join(opts.exp_dir, 'logs') + os.makedirs(log_dir, exist_ok=True) + self.logger = SummaryWriter(log_dir=log_dir) + + # Initialize checkpoint dir + self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.best_val_loss = None + if self.opts.save_interval is None: + self.opts.save_interval = self.opts.max_steps + + if prev_train_checkpoint is not None: + self.load_from_train_checkpoint(prev_train_checkpoint) + prev_train_checkpoint = None + + def load_from_train_checkpoint(self, ckpt): + print('Loading previous training data...') + self.global_step = ckpt['global_step'] + 1 + self.best_val_loss = ckpt['best_val_loss'] + self.net.load_state_dict(ckpt['state_dict']) + + if self.opts.keep_optimizer: + self.optimizer.load_state_dict(ckpt['optimizer']) + if self.opts.w_discriminator_lambda > 0: + self.discriminator.load_state_dict(ckpt['discriminator_state_dict']) + self.discriminator_optimizer.load_state_dict(ckpt['discriminator_optimizer_state_dict']) + if self.opts.progressive_steps: + self.check_for_progressive_training_update(is_resume_from_ckpt=True) + print(f'Resuming training from step {self.global_step}') + + def train(self): + self.net.train() + if self.opts.progressive_steps: + self.check_for_progressive_training_update() + while self.global_step < self.opts.max_steps: + for batch_idx, batch in enumerate(self.train_dataloader): + loss_dict = {} + if self.is_training_discriminator(): + loss_dict = self.train_discriminator(batch) + x, y, y_hat, latent = self.forward(batch) + loss, encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) + loss_dict = {**loss_dict, **encoder_loss_dict} + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # Logging related + if self.global_step % self.opts.image_interval == 0 or ( + self.global_step < 1000 and self.global_step % 25 == 0): + self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces') + if self.global_step % self.opts.board_interval == 0: + self.print_metrics(loss_dict, prefix='train') + self.log_metrics(loss_dict, prefix='train') + + # Validation related + val_loss_dict = None + if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps: + val_loss_dict = self.validate() + if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss): + self.best_val_loss = val_loss_dict['loss'] + self.checkpoint_me(val_loss_dict, is_best=True) + + if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: + if val_loss_dict is not None: + self.checkpoint_me(val_loss_dict, is_best=False) + else: + self.checkpoint_me(loss_dict, is_best=False) + + if self.global_step == self.opts.max_steps: + 
print('OMG, finished training!') + break + + self.global_step += 1 + if self.opts.progressive_steps: + self.check_for_progressive_training_update() + + def check_for_progressive_training_update(self, is_resume_from_ckpt=False): + for i in range(len(self.opts.progressive_steps)): + if is_resume_from_ckpt and self.global_step >= self.opts.progressive_steps[i]: # Case checkpoint + self.net.encoder.set_progressive_stage(ProgressiveStage(i)) + if self.global_step == self.opts.progressive_steps[i]: # Case training reached progressive step + self.net.encoder.set_progressive_stage(ProgressiveStage(i)) + + def validate(self): + self.net.eval() + agg_loss_dict = [] + for batch_idx, batch in enumerate(self.test_dataloader): + cur_loss_dict = {} + if self.is_training_discriminator(): + cur_loss_dict = self.validate_discriminator(batch) + with torch.no_grad(): + x, y, y_hat, latent = self.forward(batch) + loss, cur_encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) + cur_loss_dict = {**cur_loss_dict, **cur_encoder_loss_dict} + agg_loss_dict.append(cur_loss_dict) + + # Logging related + self.parse_and_log_images(id_logs, x, y, y_hat, + title='images/test/faces', + subscript='{:04d}'.format(batch_idx)) + + # For first step just do sanity test on small amount of data + if self.global_step == 0 and batch_idx >= 4: + self.net.train() + return None # Do not log, inaccurate in first batch + + loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict) + self.log_metrics(loss_dict, prefix='test') + self.print_metrics(loss_dict, prefix='test') + + self.net.train() + return loss_dict + + def checkpoint_me(self, loss_dict, is_best): + save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step) + save_dict = self.__get_save_dict() + checkpoint_path = os.path.join(self.checkpoint_dir, save_name) + torch.save(save_dict, checkpoint_path) + with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: + if is_best: + f.write( + '**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict)) + else: + f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict)) + + def configure_optimizers(self): + params = list(self.net.encoder.parameters()) + if self.opts.train_decoder: + params += list(self.net.decoder.parameters()) + else: + self.requires_grad(self.net.decoder, False) + if self.opts.optim_name == 'adam': + optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) + else: + optimizer = Ranger(params, lr=self.opts.learning_rate) + return optimizer + + def configure_datasets(self): + if self.opts.dataset_type not in data_configs.DATASETS.keys(): + raise Exception('{} is not a valid dataset_type'.format(self.opts.dataset_type)) + print('Loading dataset for {}'.format(self.opts.dataset_type)) + dataset_args = data_configs.DATASETS[self.opts.dataset_type] + transforms_dict = dataset_args['transforms'](self.opts).get_transforms() + train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'], + target_root=dataset_args['train_target_root'], + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_gt_train'], + opts=self.opts) + test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'], + target_root=dataset_args['test_target_root'], + source_transform=transforms_dict['transform_source'], + target_transform=transforms_dict['transform_test'], + opts=self.opts) + print("Number of training samples: {}".format(len(train_dataset))) + print("Number of test
samples: {}".format(len(test_dataset))) + return train_dataset, test_dataset + + def calc_loss(self, x, y, y_hat, latent): + loss_dict = {} + loss = 0.0 + id_logs = None + if self.is_training_discriminator(): # Adversarial loss + loss_disc = 0. + dims_to_discriminate = self.get_dims_to_discriminate() if self.is_progressive_training() else \ + list(range(self.net.decoder.n_latent)) + + for i in dims_to_discriminate: + w = latent[:, i, :] + fake_pred = self.discriminator(w) + loss_disc += F.softplus(-fake_pred).mean() + loss_disc /= len(dims_to_discriminate) + loss_dict['encoder_discriminator_loss'] = float(loss_disc) + loss += self.opts.w_discriminator_lambda * loss_disc + + if self.opts.progressive_steps and self.net.encoder.progressive_stage.value != 18: # delta regularization loss + total_delta_loss = 0 + deltas_latent_dims = self.net.encoder.get_deltas_starting_dimensions() + + first_w = latent[:, 0, :] + for i in range(1, self.net.encoder.progressive_stage.value + 1): + curr_dim = deltas_latent_dims[i] + delta = latent[:, curr_dim, :] - first_w + delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean() + loss_dict[f"delta{i}_loss"] = float(delta_loss) + total_delta_loss += delta_loss + loss_dict['total_delta_loss'] = float(total_delta_loss) + loss += self.opts.delta_norm_lambda * total_delta_loss + + if self.opts.id_lambda > 0: # Similarity loss + loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x) + loss_dict['loss_id'] = float(loss_id) + loss_dict['id_improve'] = float(sim_improvement) + loss += loss_id * self.opts.id_lambda + if self.opts.l2_lambda > 0: + loss_l2 = F.mse_loss(y_hat, y) + loss_dict['loss_l2'] = float(loss_l2) + loss += loss_l2 * self.opts.l2_lambda + if self.opts.lpips_lambda > 0: + loss_lpips = self.lpips_loss(y_hat, y) + loss_dict['loss_lpips'] = float(loss_lpips) + loss += loss_lpips * self.opts.lpips_lambda + loss_dict['loss'] = float(loss) + return loss, loss_dict, id_logs + + def forward(self, batch): + x, y = batch + x, y = x.to(self.device).float(), y.to(self.device).float() + y_hat, latent = self.net.forward(x, return_latents=True) + if self.opts.dataset_type == "cars_encode": + y_hat = y_hat[:, :, 32:224, :] + return x, y, y_hat, latent + + def log_metrics(self, metrics_dict, prefix): + for key, value in metrics_dict.items(): + self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step) + + def print_metrics(self, metrics_dict, prefix): + print('Metrics for {}, step {}'.format(prefix, self.global_step)) + for key, value in metrics_dict.items(): + print('\t{} = '.format(key), value) + + def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2): + im_data = [] + for i in range(display_count): + cur_im_data = { + 'input_face': common.log_input_image(x[i], self.opts), + 'target_face': common.tensor2im(y[i]), + 'output_face': common.tensor2im(y_hat[i]), + } + if id_logs is not None: + for key in id_logs[i]: + cur_im_data[key] = id_logs[i][key] + im_data.append(cur_im_data) + self.log_images(title, im_data=im_data, subscript=subscript) + + def log_images(self, name, im_data, subscript=None, log_latest=False): + fig = common.vis_faces(im_data) + step = self.global_step + if log_latest: + step = 0 + if subscript: + path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step)) + else: + path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step)) + os.makedirs(os.path.dirname(path), exist_ok=True) + fig.savefig(path) + plt.close(fig) + + def 
__get_save_dict(self): + save_dict = { + 'state_dict': self.net.state_dict(), + 'opts': vars(self.opts) + } + # save the latent avg in state_dict for inference if truncation of w was used during training + if self.opts.start_from_latent_avg: + save_dict['latent_avg'] = self.net.latent_avg + + if self.opts.save_training_data: # Save necessary information to enable training continuation from checkpoint + save_dict['global_step'] = self.global_step + save_dict['optimizer'] = self.optimizer.state_dict() + save_dict['best_val_loss'] = self.best_val_loss + if self.opts.w_discriminator_lambda > 0: + save_dict['discriminator_state_dict'] = self.discriminator.state_dict() + save_dict['discriminator_optimizer_state_dict'] = self.discriminator_optimizer.state_dict() + return save_dict + + def get_dims_to_discriminate(self): + deltas_starting_dimensions = self.net.encoder.get_deltas_starting_dimensions() + return deltas_starting_dimensions[:self.net.encoder.progressive_stage.value + 1] + + def is_progressive_training(self): + return self.opts.progressive_steps is not None + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Discriminator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + def is_training_discriminator(self): + return self.opts.w_discriminator_lambda > 0 + + @staticmethod + def discriminator_loss(real_pred, fake_pred, loss_dict): + real_loss = F.softplus(-real_pred).mean() + fake_loss = F.softplus(fake_pred).mean() + + loss_dict['d_real_loss'] = float(real_loss) + loss_dict['d_fake_loss'] = float(fake_loss) + + return real_loss + fake_loss + + @staticmethod + def discriminator_r1_loss(real_pred, real_w): + grad_real, = autograd.grad( + outputs=real_pred.sum(), inputs=real_w, create_graph=True + ) + grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() + + return grad_penalty + + @staticmethod + def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + def train_discriminator(self, batch): + loss_dict = {} + x, _ = batch + x = x.to(self.device).float() + self.requires_grad(self.discriminator, True) + + with torch.no_grad(): + real_w, fake_w = self.sample_real_and_fake_latents(x) + real_pred = self.discriminator(real_w) + fake_pred = self.discriminator(fake_w) + loss = self.discriminator_loss(real_pred, fake_pred, loss_dict) + loss_dict['discriminator_loss'] = float(loss) + + self.discriminator_optimizer.zero_grad() + loss.backward() + self.discriminator_optimizer.step() + + # r1 regularization + d_regularize = self.global_step % self.opts.d_reg_every == 0 + if d_regularize: + real_w = real_w.detach() + real_w.requires_grad = True + real_pred = self.discriminator(real_w) + r1_loss = self.discriminator_r1_loss(real_pred, real_w) + + self.discriminator.zero_grad() + r1_final_loss = self.opts.r1 / 2 * r1_loss * self.opts.d_reg_every + 0 * real_pred[0] + r1_final_loss.backward() + self.discriminator_optimizer.step() + loss_dict['discriminator_r1_loss'] = float(r1_final_loss) + + # Reset to previous state + self.requires_grad(self.discriminator, False) + + return loss_dict + + def validate_discriminator(self, test_batch): + with torch.no_grad(): + loss_dict = {} + x, _ = test_batch + x = x.to(self.device).float() + real_w, fake_w = self.sample_real_and_fake_latents(x) + real_pred = self.discriminator(real_w) + fake_pred = self.discriminator(fake_w) + loss = self.discriminator_loss(real_pred, fake_pred, loss_dict) + loss_dict['discriminator_loss'] = float(loss) + return loss_dict + + def sample_real_and_fake_latents(self, 
x): + sample_z = torch.randn(self.opts.batch_size, 512, device=self.device) + real_w = self.net.decoder.get_latent(sample_z) + fake_w = self.net.encoder(x) + if self.opts.start_from_latent_avg: + fake_w = fake_w + self.net.latent_avg.repeat(fake_w.shape[0], 1, 1) + if self.is_progressive_training(): # When progressive training, feed only unique w's + dims_to_discriminate = self.get_dims_to_discriminate() + fake_w = fake_w[:, dims_to_discriminate, :] + if self.opts.use_w_pool: + real_w = self.real_w_pool.query(real_w) + fake_w = self.fake_w_pool.query(fake_w) + if fake_w.ndim == 3: + fake_w = fake_w[:, 0, :] + return real_w, fake_w diff --git a/encoder4editing/training/ranger.py b/encoder4editing/training/ranger.py new file mode 100644 index 0000000000000000000000000000000000000000..3d63264dda6df0ee40cac143440f0b5f8977a9ad --- /dev/null +++ b/encoder4editing/training/ranger.py @@ -0,0 +1,164 @@ +# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. + +# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer +# and/or +# https://github.com/lessw2020/Best-Deep-Learning-Optimizers + +# Ranger has now been used to capture 12 records on the FastAI leaderboard. + +# This version = 20.4.11 + +# Credits: +# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization +# RAdam --> https://github.com/LiyuanLucasLiu/RAdam +# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. +# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 + +# summary of changes: +# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. +# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), +# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. +# changes 8/31/19 - fix references to *self*.N_sma_threshold; +# changed eps to 1e-5 as better default than 1e-8. + +import math +import torch +from torch.optim.optimizer import Optimizer + + +class Ranger(Optimizer): + + def __init__(self, params, lr=1e-3, # lr + alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options + betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options + use_gc=True, gc_conv_only=False + # Gradient centralization on or off, applied to conv layers only or conv + fc layers + ): + + # parameter checks + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + if not lr > 0: + raise ValueError(f'Invalid Learning Rate: {lr}') + if not eps > 0: + raise ValueError(f'Invalid eps: {eps}') + + # parameter comments: + # beta1 (momentum) of .95 seems to work better than .90... + # N_sma_threshold of 5 seems better in testing than 4. + # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
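+ # A rough sketch of the Lookahead update that step() applies every k steps
+ # (illustrative comment, mirroring the code at the end of step() below):
+ #     slow_buffer += alpha * (p.data - slow_buffer)
+ #     p.data = slow_buffer
+ # With the defaults alpha=0.5 and k=6, the fast weights are pulled halfway
+ # back toward the slow weights every 6 optimizer steps.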
+ + # prep defaults and init torch.optim base + defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, + eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + # adjustable threshold + self.N_sma_threshhold = N_sma_threshhold + + # look ahead params + + self.alpha = alpha + self.k = k + + # radam buffer for state + self.radam_buffer = [[None, None, None] for ind in range(10)] + + # gc on or off + self.use_gc = use_gc + + # level of gradient centralization + self.gc_gradient_threshold = 3 if gc_conv_only else 1 + + def __setstate__(self, state): + super(Ranger, self).__setstate__(state) + + def step(self, closure=None): + loss = None + + # Evaluate averages and grad, update param tensors + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + + if grad.is_sparse: + raise RuntimeError('Ranger optimizer does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] # get state dict for this param + + if len(state) == 0: # if first time to run...init dictionary with our desired entries + # if self.first_run_check==0: + # self.first_run_check=1 + # print("Initializing slow buffer...should not see this at load from saved model!") + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + + # look ahead weight storage now in state dict + state['slow_buffer'] = torch.empty_like(p.data) + state['slow_buffer'].copy_(p.data) + + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + # begin computations + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + # GC operation for Conv layers and FC layers + if grad.dim() > self.gc_gradient_threshold: + grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) + + state['step'] += 1 + + # compute variance mov avg + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + # compute mean moving avg + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + buffered = self.radam_buffer[int(state['step'] % 10)] + + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + if N_sma > self.N_sma_threshhold: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # apply lr + if N_sma > self.N_sma_threshhold: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + # integrated look ahead... 
+ # we do it at the param level instead of group level + if state['step'] % group['k'] == 0: + slow_p = state['slow_buffer'] # get access to slow param tensor + slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha + p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor + + return loss \ No newline at end of file diff --git a/encoder4editing/utils/__init__.py b/encoder4editing/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder4editing/utils/alignment.py b/encoder4editing/utils/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..a02798f0f7c9fdcc319f7884a491b9e6580cc8aa --- /dev/null +++ b/encoder4editing/utils/alignment.py @@ -0,0 +1,115 @@ +import numpy as np +import PIL +import PIL.Image +import scipy +import scipy.ndimage +import dlib + + +def get_landmark(filepath, predictor): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + detector = dlib.get_frontal_face_detector() + + img = dlib.load_rgb_image(filepath) + dets = detector(img, 1) + + for k, d in enumerate(dets): + shape = predictor(img, d) + + t = list(shape.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + return lm + + +def align_face(filepath, predictor): + """ + :param filepath: str + :return: PIL Image + """ + + lm = get_landmark(filepath, predictor) + + lm_chin = lm[0: 17] # left-right + lm_eyebrow_left = lm[17: 22] # left-right + lm_eyebrow_right = lm[22: 27] # left-right + lm_nose = lm[27: 31] # top-down + lm_nostrils = lm[31: 36] # top-down + lm_eye_left = lm[36: 42] # left-clockwise + lm_eye_right = lm[42: 48] # left-clockwise + lm_mouth_outer = lm[48: 60] # left-clockwise + lm_mouth_inner = lm[60: 68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # read image + img = PIL.Image.open(filepath) + + output_size = 256 + transform_size = 256 + enable_padding = True + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, PIL.Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. 
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) + if output_size < transform_size: + img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) + + # Return aligned image. + return img diff --git a/encoder4editing/utils/common.py b/encoder4editing/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b19e18ddcb78b06678fa18e4a76da44fc511b789 --- /dev/null +++ b/encoder4editing/utils/common.py @@ -0,0 +1,55 @@ +from PIL import Image +import matplotlib.pyplot as plt + + +# Log images +def log_input_image(x, opts): + return tensor2im(x) + + +def tensor2im(var): + # var shape: (3, H, W) + var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() + var = ((var + 1) / 2) + var[var < 0] = 0 + var[var > 1] = 1 + var = var * 255 + return Image.fromarray(var.astype('uint8')) + + +def vis_faces(log_hooks): + display_count = len(log_hooks) + fig = plt.figure(figsize=(8, 4 * display_count)) + gs = fig.add_gridspec(display_count, 3) + for i in range(display_count): + hooks_dict = log_hooks[i] + fig.add_subplot(gs[i, 0]) + if 'diff_input' in hooks_dict: + vis_faces_with_id(hooks_dict, fig, gs, i) + else: + vis_faces_no_id(hooks_dict, fig, gs, i) + plt.tight_layout() + return fig + + +def vis_faces_with_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face']) + plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), + float(hooks_dict['diff_target']))) + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) + + +def vis_faces_no_id(hooks_dict, fig, gs, i): + plt.imshow(hooks_dict['input_face'], cmap="gray") + plt.title('Input') + fig.add_subplot(gs[i, 1]) + plt.imshow(hooks_dict['target_face']) + plt.title('Target') + fig.add_subplot(gs[i, 2]) + plt.imshow(hooks_dict['output_face']) + plt.title('Output') diff --git a/encoder4editing/utils/data_utils.py b/encoder4editing/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ba79f4a2d5cc2b97dce76d87bf6e7cdebbc257 --- /dev/null +++ b/encoder4editing/utils/data_utils.py @@ -0,0 +1,25 @@ +""" +Code adopted from pix2pixHD: +https://github.com/NVIDIA/pix2pixHD/blob/master/data/image_folder.py +""" +import os + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', 
'.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images diff --git a/encoder4editing/utils/model_utils.py b/encoder4editing/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e51e95578f72b3218d6d832e3b604193cb68c1d7 --- /dev/null +++ b/encoder4editing/utils/model_utils.py @@ -0,0 +1,35 @@ +import torch +import argparse +from models.psp import pSp +from models.encoders.psp_encoders import Encoder4Editing + + +def setup_model(checkpoint_path, device='cuda'): + ckpt = torch.load(checkpoint_path, map_location='cpu') + opts = ckpt['opts'] + + opts['checkpoint_path'] = checkpoint_path + opts['device'] = device + opts = argparse.Namespace(**opts) + + net = pSp(opts) + net.eval() + net = net.to(device) + return net, opts + + +def load_e4e_standalone(checkpoint_path, device='cuda'): + ckpt = torch.load(checkpoint_path, map_location='cpu') + opts = argparse.Namespace(**ckpt['opts']) + e4e = Encoder4Editing(50, 'ir_se', opts) + e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')} + e4e.load_state_dict(e4e_dict) + e4e.eval() + e4e = e4e.to(device) + latent_avg = ckpt['latent_avg'].to(device) + + def add_latent_avg(model, inputs, outputs): + return outputs + latent_avg.repeat(outputs.shape[0], 1, 1) + + e4e.register_forward_hook(add_latent_avg) + return e4e diff --git a/encoder4editing/utils/train_utils.py b/encoder4editing/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c55177f7442010bc1fcc64de3d142585c22adc0 --- /dev/null +++ b/encoder4editing/utils/train_utils.py @@ -0,0 +1,13 @@ + +def aggregate_loss_dict(agg_loss_dict): + mean_vals = {} + for output in agg_loss_dict: + for key in output: + mean_vals[key] = mean_vals.setdefault(key, []) + [output[key]] + for key in mean_vals: + if len(mean_vals[key]) > 0: + mean_vals[key] = sum(mean_vals[key]) / len(mean_vals[key]) + else: + print('{} has no value'.format(key)) + mean_vals[key] = 0 + return mean_vals diff --git a/find_direction.py b/find_direction.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d9f6b95869e822b64bd676e2bd27f5512f10a6 --- /dev/null +++ b/find_direction.py @@ -0,0 +1,372 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
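+# Overview (an informal summary of the code below): generate_images() records
+# the 26 style-space ("S") vectors of each seed's image, then optimizes one
+# shared styles_direction so that CLIP similarity between the edited image and
+# --text_prompt increases while id_loss keeps the face identity close to the
+# original; the direction is saved to <outdir>/direction_<text_prompt>.npz
+# (spaces in the prompt replaced by underscores).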
+ +"""Generate images using pretrained network pickle.""" + +import os +import re +import random +import math +import time +import click +import legacy +from typing import List, Optional + +import cv2 +import clip +import dnnlib +import numpy as np +import torch +from torch import linalg as LA +import torch.nn.functional as F +import torchvision +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import PIL.Image +from PIL import Image +import matplotlib.pyplot as plt + +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma +import id_loss + + +def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. + if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + + +def unravel_index(index, shape): + out = [] + for dim in reversed(shape): + out.append(index % dim) + index = index // dim + return tuple(reversed(out)) + + +def num_range(s: str) -> List[int]: + """ + Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints. 
+ """ + + range_re = re.compile(r'^(\d+)-(\d+)$') + m = range_re.match(s) + if m: + return list(range(int(m.group(1)), int(m.group(2)) + 1)) + vals = s.split(',') + return [int(x) for x in vals] + + +@click.command() +@click.pass_context +@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--seeds', type=num_range, help='List of random seeds') +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') +@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) +@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE') +@click.option('--projected_s', help='Projection result file', type=str, metavar='FILE') +@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') +@click.option('--text_prompt', help='Text', type=str, required=True) +@click.option('--resolution', help='Resolution of output images', type=int, required=True) +@click.option('--batch_size', help='Batch Size', type=int, required=True) +@click.option('--identity_power', help='How much change occurs on the face', type=str, required=True) +def generate_images( + ctx: click.Context, + network_pkl: str, + seeds: Optional[List[int]], + truncation_psi: float, + noise_mode: str, + outdir: str, + class_idx: Optional[int], + projected_w: Optional[str], + projected_s: Optional[str], + text_prompt: str, + resolution: int, + batch_size: int, + identity_power: str, +): + """ + Generate images using pretrained network pickle. + + Examples: + # Generate curated MetFaces images without truncation (Fig.10 left) + python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + + # Generate uncurated MetFaces images with truncation (Fig.12 upper left) + python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + + # Generate class conditional CIFAR-10 images (Fig.17 left, Car) + python generate.py --outdir=out --seeds=0-35 --class=1 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl + + # Render an image from projected W + python generate.py --outdir=out --projected_w=projected_w.npz \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + """ + + print('Loading networks from "%s"...' 
% network_pkl) + # Use GPU if available + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + os.makedirs(outdir, exist_ok=True) + + # Synthesize the result of a W projection + if projected_w is not None: + if seeds is not None: + print('warn: --seeds is ignored when using --projected-w') + print(f'Generating images from projected W "{projected_w}"') + ws = np.load(projected_w)['w'] + ws = torch.tensor(ws, device=device) # pylint: disable=not-callable + assert ws.shape[1:] == (G.num_ws, G.w_dim) + for idx, w in enumerate(ws): + img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode) + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) + img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png') + return + + if seeds is None: + ctx.fail('--seeds option is required when not using --projected-w') + + # Labels + label = torch.zeros([1, G.c_dim], device=device).requires_grad_() + if G.c_dim != 0: + if class_idx is None: + ctx.fail('Must specify class label with --class when using a conditional network') + label[:, class_idx] = 1 + else: + if class_idx is not None: + print('warn: --class=lbl ignored when running on an unconditional network') + + model, preprocess = clip.load("ViT-B/32", device=device) + text = clip.tokenize([text_prompt]).to(device) + text_features = model.encode_text(text) + + # Generate images + for i in G.parameters(): + i.requires_grad = True + + mean = torch.as_tensor((0.48145466, 0.4578275, 0.40821073), dtype=torch.float, device=device) + std = torch.as_tensor((0.26862954, 0.26130258, 0.27577711), dtype=torch.float, device=device) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + + transf = Compose([Resize(224, interpolation=Image.BICUBIC), CenterCrop(224)]) + + styles_array = [] + print("seeds:", seeds) + t1 = time.time() + for seed_idx, seed in enumerate(seeds): + if seed == seeds[-1]: + print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) + z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) + ws = G.mapping(z, label, truncation_psi=truncation_psi) + + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, G.synthesis.num_ws, G.synthesis.w_dim]) + ws = ws.to(torch.float32) + + w_idx = 0 + for res in G.synthesis.block_resolutions: + block = getattr(G.synthesis, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + styles = torch.zeros(1, 26, 512, device=device) + styles_idx = 0 + temp_shapes = [] + for res, cur_ws in zip(G.synthesis.block_resolutions, block_ws): + block = getattr(G.synthesis, f'b{res}') + + if res == 4: + temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + styles[0, :1, :] = block.conv1.affine(cur_ws[0, :1, :]) + styles[0, 1:2, :] = block.torgb.affine(cur_ws[0, 1:2, :]) + if seed_idx == (len(seeds) - 1): + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + styles_idx += 2 + else: + temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + styles[0,styles_idx:styles_idx+1,:temp_shape[0]] = block.conv0.affine(cur_ws[0,:1,:]) + styles[0,styles_idx+1:styles_idx+2,:temp_shape[1]] = block.conv1.affine(cur_ws[0,1:2,:]) + styles[0,styles_idx+2:styles_idx+3,:temp_shape[2]] = block.torgb.affine(cur_ws[0,2:3,:]) + if seed_idx == (len(seeds) - 1): + block.conv0.affine = torch.nn.Identity() + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + styles_idx += 3 + temp_shapes.append(temp_shape) + + styles = styles.detach() + styles_array.append(styles) + + resolution_dict = {256: 6, 512: 7, 1024: 8} + id_coeff_dict = {"high": 2, "medium": 0.5, "low": 0.1, "none": 0} + id_coeff = id_coeff_dict[identity_power] + styles_direction = torch.zeros(1, 26, 512, device=device) + styles_direction_grad_el2 = torch.zeros(1, 26, 512, device=device) + styles_direction.requires_grad_() + + global id_loss + id_loss = id_loss.IDLoss("a").to(device).eval() + + temp_photos = [] + grads = [] + for i in range(math.ceil(len(seeds) / batch_size)): + # print(i*batch_size, "processed", time.time()-t1) + + styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device) + seed = seeds[i] + + styles_idx = 0 + x2 = img2 = None + + for k, (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)): + block = getattr(G.synthesis, f'b{res}') + if k > resolution_dict[resolution]: + continue + + if res == 4: + x2, img2 = block_forward(block, x2, img2, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 2 + else: + x2, img2 = block_forward(block, x2, img2, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 3 + + img2_cpu = img2.detach().cpu().numpy() + temp_photos.append(img2_cpu) + if i > 3: + continue + + styles2 = styles + styles_direction + + styles_idx = 0 + x = img = None + for k, (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)): + block = getattr(G.synthesis, f'b{res}') + if k > resolution_dict[resolution]: + continue + if res == 4: + x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 2 + else: + x, img = 
block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 3 + + identity_loss, _ = id_loss(img, img2) + identity_loss *= id_coeff + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255) + img = (transf(img.permute(0, 3, 1, 2)) / 255).sub_(mean).div_(std) + image_features = model.encode_image(img) + cos_sim = -1*F.cosine_similarity(image_features, (text_features[0]).unsqueeze(0)) + (identity_loss + cos_sim.sum()).backward(retain_graph=True) + + styles_direction.grad[:, list(range(26)), :] = 0 + with torch.no_grad(): + styles_direction *= 0 + + for i in range(math.ceil(len(seeds) / batch_size)): + seed = seeds[i] + styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device) + img2 = torch.tensor(temp_photos[i]).to(device) + styles2 = styles + styles_direction + + styles_idx = 0 + x = img = None + for k, (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)): + block = getattr(G.synthesis, f'b{res}') + if k > resolution_dict[resolution]: + continue + + if res == 4: + x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 2 + else: + x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 3 + + identity_loss, _ = id_loss(img, img2) + identity_loss *= id_coeff + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255) + img = (transf(img.permute(0, 3, 1, 2)) / 255).sub_(mean).div_(std) + image_features = model.encode_image(img) + cos_sim = -1*F.cosine_similarity(image_features, (text_features[0]).unsqueeze(0)) + (identity_loss + cos_sim.sum()).backward(retain_graph=True) + + styles_direction.grad[:, [0, 1, 4, 7, 10, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], :] = 0 + + if i % 2 == 1: + styles_direction.data = (styles_direction - styles_direction.grad * 5) + grads.append(styles_direction.grad.clone()) + styles_direction.grad.data.zero_() + if i > 3: + styles_direction_grad_el2[grads[-2] * grads[-1] < 0] += 1 + + styles_direction = styles_direction.detach() + styles_direction[styles_direction_grad_el2 > (len(seeds) / batch_size) / 4] = 0 + + output_filepath = f'{outdir}/direction_' + text_prompt.replace(" ", "_") + '.npz' + np.savez(output_filepath, s=styles_direction.cpu().numpy()) + + +if __name__ == "__main__": + generate_images() diff --git a/generate_fromS.py b/generate_fromS.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f17d07338b0fa7ee26adeaa95442ae43de9fef --- /dev/null +++ b/generate_fromS.py @@ -0,0 +1,277 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
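The script above ends by masking the learned direction for stability: entries of styles_direction whose accumulated gradient flips sign between consecutive updates often enough (more than roughly a quarter of the update steps) are zeroed before the direction is saved under the key s as a (1, 26, 512) array. A simplified, self-contained sketch of that sign-flip filter, with random placeholder tensors standing in for the real gradient history:

import torch

# Placeholder gradient history; in the script these are the accumulated
# gradients of styles_direction captured every second batch.
num_updates = 8
grads = [torch.randn(1, 26, 512) for _ in range(num_updates)]
direction = torch.randn(1, 26, 512)

# Count per-entry sign flips between consecutive gradients.
flips = torch.zeros(1, 26, 512)
for prev, cur in zip(grads, grads[1:]):
    flips[prev * cur < 0] += 1

# Zero out entries that oscillated in more than a quarter of the updates.
direction[flips > num_updates / 4] = 0

The generate_fromS.py file that follows consumes these direction_<prompt>.npz files and applies the direction at varying strengths.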
+ +"""Generate images using pretrained network pickle.""" + +import os +import re +import random +import math +import time +import click +import legacy +from typing import List, Optional + +import cv2 +import clip +import dnnlib +import numpy as np +import torchvision +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import PIL.Image +import matplotlib.pyplot as plt +import torch +from torch import linalg as LA +import torch.nn.functional as F +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma + + +def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. + if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + + +def unravel_index(index, shape): + out = [] + for dim in reversed(shape): + out.append(index % dim) + index = index // dim + return tuple(reversed(out)) + + +def num_range(s: str) -> List[int]: + """ + Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints. 
+ """ + + range_re = re.compile(r'^(\d+)-(\d+)$') + m = range_re.match(s) + if m: + return list(range(int(m.group(1)), int(m.group(2))+1)) + vals = s.split(',') + return [int(x) for x in vals] + + +@click.command() +@click.pass_context +@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--seeds', type=num_range, help='List of random seeds') +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=0.7, show_default=True) +@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') +@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) +@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE') +@click.option('--s_input', help='Projection result file', type=str, metavar='FILE') +@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') +@click.option('--text_prompt', help='Text', type=str, required=True) +@click.option('--change_power', help='Change power', type=int, required=True) +@click.option('--from_video', 'from_video', is_flag=True, help="generate from video") + +def generate_images( + ctx: click.Context, + network_pkl: str, + seeds: Optional[List[int]], + truncation_psi: float, + noise_mode: str, + outdir: str, + class_idx: Optional[int], + projected_w: Optional[str], + s_input: Optional[str], + text_prompt: str, + change_power: int, + from_video: bool, +): + """ + Generate images using pretrained network pickle. + + Examples: + # Generate curated MetFaces images without truncation (Fig.10 left) + python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + + # Generate uncurated MetFaces images with truncation (Fig.12 upper left) + python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + + # Generate class conditional CIFAR-10 images (Fig.17 left, Car) + python generate.py --outdir=out --seeds=0-35 --class=1 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl + + # Render an image from projected W + python generate.py --outdir=out --projected_w=projected_w.npz \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + """ + + print('Loading networks from "%s"...' % network_pkl) + # Use GPU if available + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + os.makedirs(outdir, exist_ok=True) + + # Synthesize the result of a W projection. 
+ if projected_w is not None: + if seeds is not None: + print ('warn: --seeds is ignored when using --projected-w') + print(f'Generating images from projected W "{projected_w}"') + ws = np.load(projected_w)['w'] + ws = torch.tensor(ws, device=device) # pylint: disable=not-callable + assert ws.shape[1:] == (G.num_ws, G.w_dim) + for idx, w in enumerate(ws): + img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode) + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) + img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB') + img.save(f'{outdir}/proj{idx:02d}.png') + return + + # Labels + label = torch.zeros([1, G.c_dim], device=device).requires_grad_() + if G.c_dim != 0: + if class_idx is None: + ctx.fail('Must specify class label with --class when using a conditional network') + label[:, class_idx] = 1 + else: + if class_idx is not None: + print ('warn: --class=lbl ignored when running on an unconditional network') + + # Generate images + for i in G.parameters(): + i.requires_grad = False + + + temp_shapes = [] + for res in G.synthesis.block_resolutions: + block = getattr(G.synthesis, f'b{res}') + if res == 4: + temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + + else: + temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + block.conv0.affine = torch.nn.Identity() + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + + temp_shapes.append(temp_shape) + + + if s_input is not None: + styles = np.load(s_input)['s'] + styles_direction = np.load(f'{outdir}/direction_'+text_prompt.replace(" ", "_")+'.npz')['s'] + + styles_direction = torch.tensor(styles_direction, device=device) + styles = torch.tensor(styles, device=device) + + if from_video and not os.path.isdir(f'{outdir}_video'): + os.makedirs(f'{outdir}_video') + + with torch.no_grad(): + if from_video: + name_i = 1000 + for grad_change in np.arange(0, 1, 0.02)*change_power: + imgs = [] + name_i += 1 + + styles += styles_direction*grad_change + styles_idx = 0 + x = img = None + for k , res in enumerate(G.synthesis.block_resolutions): + block = getattr(G.synthesis, f'b{res}') + + if res == 4: + x, img = block_forward(block, x, img, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 2 + else: + x, img = block_forward(block, x, img, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 3 + + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255) + imgs.append(img[0].to(torch.uint8).cpu().numpy()) + + styles -= styles_direction*grad_change + img_filepath = '{}_video/{}_{}_{}.jpeg'.format(outdir, text_prompt.replace(" ", "_"), change_power, name_i) + PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(img_filepath, quality=95) + else: + imgs = [] + grad_changes = [0, 0.25*change_power, 0.5*change_power, 0.75*change_power, change_power] + + for grad_change in grad_changes: + styles += styles_direction*grad_change + + styles_idx = 0 + x = img = None + for k , res in enumerate(G.synthesis.block_resolutions): + block = getattr(G.synthesis, f'b{res}') + + if res == 4: + x, img = block_forward(block, x, img, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 2 + else: + x, img = 
block_forward(block, x, img, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode, force_fp32=True) + styles_idx += 3 + + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255) + imgs.append(img[0].to(torch.uint8).cpu().numpy()) + + styles -= styles_direction*grad_change + + img_filepath = f'{outdir}/'+text_prompt.replace(" ", "_")+'_'+str(change_power)+'.jpeg' + PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(img_filepath, quality=95) + + + +if __name__ == "__main__": + generate_images() diff --git a/generate_multi.py b/generate_multi.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec4932c96891c7be5321c7822fd032a96974510 --- /dev/null +++ b/generate_multi.py @@ -0,0 +1,403 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Generate images using pretrained network pickle.""" + +import os +import re +from typing import List, Optional +import torchvision +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import click +import dnnlib +import numpy as np +import PIL.Image +import torch +from torch import linalg as LA +import clip +from PIL import Image +import legacy +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma +import random +import math +import time +import id_loss + + +def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. 
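+    # With the 'skip' architecture every block contributes an RGB image: the
+    # running image is upsampled 2x and this block's ToRGB output is added on
+    # top, which is why the callers below can stop early at a lower
+    # resolution (see resolution_dict) and still obtain a valid image.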
+ if img is not None: + misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2]) + img = upfirdn2d.upsample2d(img, self.resample_filter) + if self.is_last or self.architecture == 'skip': + y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv) + y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format) + img = img.add_(y) if img is not None else y + + assert x.dtype == dtype + assert img is None or img.dtype == torch.float32 + return x, img + +def unravel_index(index, shape): + out = [] + for dim in reversed(shape): + out.append(index % dim) + index = index // dim + return tuple(reversed(out)) + + +#---------------------------------------------------------------------------- + +def num_range(s: str) -> List[int]: + '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' + + range_re = re.compile(r'^(\d+)-(\d+)$') + m = range_re.match(s) + if m: + return list(range(int(m.group(1)), int(m.group(2))+1)) + vals = s.split(',') + return [int(x) for x in vals] + +#---------------------------------------------------------------------------- + +@click.command() +@click.pass_context +@click.option('--network', 'network_pkl', help='Network pickle filename', required=True) +@click.option('--seeds', type=num_range, help='List of random seeds') +@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) +@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') +@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) +@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE') +@click.option('--projected_s', help='Projection result file', type=str, metavar='FILE') +@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') +@click.option('--resolution', help='Resolution of output images', type=int, required=True) +@click.option('--batch_size', help='Batch Size', type=int, required=True) +@click.option('--identity_power', help='How much change occurs on the face', type=str, required=True) +def generate_images( + ctx: click.Context, + network_pkl: str, + seeds: Optional[List[int]], + truncation_psi: float, + noise_mode: str, + outdir: str, + class_idx: Optional[int], + projected_w: Optional[str], + projected_s: Optional[str], + resolution: int, + batch_size: int, + identity_power: str +): + """Generate images using pretrained network pickle. 
+ + Examples: + + \b + # Generate curated MetFaces images without truncation (Fig.10 left) + python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + + \b + # Generate uncurated MetFaces images with truncation (Fig.12 upper left) + python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + + \b + # Generate class conditional CIFAR-10 images (Fig.17 left, Car) + python generate.py --outdir=out --seeds=0-35 --class=1 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl + + \b + # Render an image from projected W + python generate.py --outdir=out --projected_w=projected_w.npz \\ + --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl + """ + + print('Loading networks from "%s"...' % network_pkl) + device = torch.device('cuda') + with dnnlib.util.open_url(network_pkl) as f: + G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + os.makedirs(outdir, exist_ok=True) + + # Synthesize the result of a W projection. + if projected_w is not None: + if seeds is not None: + print ('warn: --seeds is ignored when using --projected-w') + print(f'Generating images from projected W "{projected_w}"') + ws = np.load(projected_w)['w'] + ws = torch.tensor(ws, device=device) # pylint: disable=not-callable + assert ws.shape[1:] == (G.num_ws, G.w_dim) + for idx, w in enumerate(ws): + img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode) + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) + img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png') + return + + if seeds is None: + ctx.fail('--seeds option is required when not using --projected-w') + + # Labels. + label = torch.zeros([1, G.c_dim], device=device).requires_grad_() + if G.c_dim != 0: + if class_idx is None: + ctx.fail('Must specify class label with --class when using a conditional network') + label[:, class_idx] = 1 + else: + if class_idx is not None: + print ('warn: --class=lbl ignored when running on an unconditional network') + + model, preprocess = clip.load("ViT-B/32", device=device) + + text_prompts_file = open("text_prompts.txt") + text_prompts = text_prompts_file.read().split("\n") + text_prompts_file.close() + + text = clip.tokenize(text_prompts).to(device) + text_features = model.encode_text(text) + + # Generate images. + for i in G.parameters(): + i.requires_grad = True + + mean = torch.as_tensor((0.48145466, 0.4578275, 0.40821073), dtype=torch.float, device=device) + std = torch.as_tensor((0.26862954, 0.26130258, 0.27577711), dtype=torch.float, device=device) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + + transf = Compose([ + Resize(224, interpolation=Image.BICUBIC), + CenterCrop(224), + ]) + + styles_array = [] + print("seeds:", seeds) + t1 = time.time() + for seed_idx, seed in enumerate(seeds): + if seed==seeds[-1]: + print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) + z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) + ws = G.mapping(z, label, truncation_psi=truncation_psi) + + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, G.synthesis.num_ws, G.synthesis.w_dim]) + ws = ws.to(torch.float32) + + + w_idx = 0 + for res in G.synthesis.block_resolutions: + block = getattr(G.synthesis, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + + styles = torch.zeros(1,26,512, device=device) + styles_idx = 0 + temp_shapes = [] + for res, cur_ws in zip(G.synthesis.block_resolutions, block_ws): + block = getattr(G.synthesis, f'b{res}') + + if res == 4: + temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + styles[0,:1,:] = block.conv1.affine(cur_ws[0,:1,:]) + styles[0,1:2,:] = block.torgb.affine(cur_ws[0,1:2,:]) + if seed_idx==(len(seeds)-1): + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + styles_idx += 2 + else: + temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + styles[0,styles_idx:styles_idx+1,:temp_shape[0]] = block.conv0.affine(cur_ws[0,:1,:]) + styles[0,styles_idx+1:styles_idx+2,:temp_shape[1]] = block.conv1.affine(cur_ws[0,1:2,:]) + styles[0,styles_idx+2:styles_idx+3,:temp_shape[2]] = block.torgb.affine(cur_ws[0,2:3,:]) + if seed_idx==(len(seeds)-1): + block.conv0.affine = torch.nn.Identity() + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + styles_idx += 3 + temp_shapes.append(temp_shape) + + + styles = styles.detach() + styles_array.append(styles) + + resolution_dict = {256: 6, 512: 7, 1024: 8} + identity_coefficient_dict = {"high": 2,"medium": 0.5, "low": 0.1, "none": 0} + identity_coefficient = identity_coefficient_dict[identity_power] + styles_wanted_direction = torch.zeros(1,26,512, device=device) + styles_wanted_direction_grad_el2 = torch.zeros(1,26,512, device=device) + styles_wanted_direction.requires_grad_() + + global id_loss + id_loss = id_loss.IDLoss("a").to(device).eval() + + temp_photos = [] + grads = [] + for i in range(math.ceil(len(seeds)/batch_size)): + #print(i*batch_size, "processed", time.time()-t1) + + + styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device) + + + seed = seeds[i] + + styles_idx = 0 + x2 = img2 = None + + for k , (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)): + block = getattr(G.synthesis, f'b{res}') + if k>resolution_dict[resolution]: + continue + + if res == 4: + x2, img2 = block_forward(block, x2, img2, styles[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode) + styles_idx += 2 + else: + x2, img2 = block_forward(block, x2, img2, styles[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode) + styles_idx += 3 + + img2_cpu = img2.detach().cpu().numpy() + temp_photos.append(img2_cpu) + if i>3: + continue + + styles2 = styles + styles_wanted_direction + + styles_idx = 0 + x = img = None + for k , (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)): + block = getattr(G.synthesis, f'b{res}') + if k>resolution_dict[resolution]: + continue + if res == 4: + x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode) + styles_idx += 2 + else: + x, img = 
block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode) + styles_idx += 3 + + identity_loss, _ = id_loss(img, img2) + identity_loss *= identity_coefficient + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255) + img = (transf(img.permute(0, 3, 1, 2))/255).sub_(mean).div_(std) + image_features = model.encode_image(img) + cos_sim = -1*F.cosine_similarity(image_features, (text_features[0]).unsqueeze(0)) + (identity_loss + cos_sim.sum()).backward(retain_graph=True) + + + + + #t1 = time.time() + + for text_counter in range(len(text_prompts)): + text_prompt = text_prompts[text_counter] + print(text_prompt) + + styles_wanted_direction.grad.data.zero_() + styles_wanted_direction_grad_el2 = torch.zeros(1,26,512, device=device) + with torch.no_grad(): + styles_wanted_direction *= 0 + + for i in range(math.ceil(len(seeds)/batch_size)): + print(i*batch_size, "processed", time.time()-t1) + + + styles = torch.vstack(styles_array[i*batch_size:(i+1)*batch_size]).to(device) + + + seed = seeds[i] + + img2 = torch.tensor(temp_photos[i]).to(device) + + styles2 = styles + styles_wanted_direction + + styles_idx = 0 + x = img = None + for k , (res, cur_ws) in enumerate(zip(G.synthesis.block_resolutions, block_ws)): + block = getattr(G.synthesis, f'b{res}') + if k>resolution_dict[resolution]: + continue + + if res == 4: + x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+2, :], temp_shapes[k], noise_mode=noise_mode) + styles_idx += 2 + else: + x, img = block_forward(block, x, img, styles2[:, styles_idx:styles_idx+3, :], temp_shapes[k], noise_mode=noise_mode) + styles_idx += 3 + + identity_loss, _ = id_loss(img, img2) + identity_loss *= identity_coefficient + img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255) + img = (transf(img.permute(0, 3, 1, 2))/255).sub_(mean).div_(std) + image_features = model.encode_image(img) + cos_sim = -1*F.cosine_similarity(image_features, (text_features[text_counter]).unsqueeze(0)) + (identity_loss + cos_sim.sum()).backward(retain_graph=True) + + + styles_wanted_direction.grad[:, [0, 1, 4, 7, 10, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], :] = 0 + + + if i%2==1: + styles_wanted_direction.data = styles_wanted_direction - styles_wanted_direction.grad*5 + grads.append(styles_wanted_direction.grad.clone()) + styles_wanted_direction.grad.data.zero_() + + if i>3: + styles_wanted_direction_grad_el2[grads[-2]*grads[-1]<0] += 1 + + + styles_wanted_direction_cpu = styles_wanted_direction.detach() + styles_wanted_direction_cpu[styles_wanted_direction_grad_el2>(len(seeds)/batch_size)/4] = 0 + np.savez(f'{outdir}/direction_'+text_prompt.replace(" ", "_")+'.npz', s=styles_wanted_direction_cpu.cpu().numpy()) + + print("time passed:", time.time()-t1) +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + generate_images() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/generate_w.py b/generate_w.py new file mode 100644 index 0000000000000000000000000000000000000000..9546c98f13a0ec06c15608825749acca4f865b9c --- /dev/null +++ b/generate_w.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Generate images using pretrained network pickle.""" + +import os +import re +from typing import List, Optional +import torchvision +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import click +import dnnlib +import numpy as np +import PIL.Image +import torch +from torch import linalg as LA +import clip +from PIL import Image +import legacy +import torch.nn.functional as F +import cv2 +import matplotlib.pyplot as plt +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma +import random +import math +import time +import id_loss + + +def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. 
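+    # The next(w_iter)[..., :shapes[i]] slices trim each zero-padded 512-wide
+    # style row down to the layer's true channel count; shapes carries the
+    # per-layer affine output dimensions recorded in temp_shapes.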
+    if img is not None:
+        misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+        img = upfirdn2d.upsample2d(img, self.resample_filter)
+    if self.is_last or self.architecture == 'skip':
+        y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv)
+        y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        img = img.add_(y) if img is not None else y
+
+    assert x.dtype == dtype
+    assert img is None or img.dtype == torch.float32
+    return x, img
+
+def unravel_index(index, shape):
+    out = []
+    for dim in reversed(shape):
+        out.append(index % dim)
+        index = index // dim
+    return tuple(reversed(out))
+
+
+def num_range(s: str) -> List[int]:
+    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
+
+    range_re = re.compile(r'^(\d+)-(\d+)$')
+    m = range_re.match(s)
+    if m:
+        return list(range(int(m.group(1)), int(m.group(2))+1))
+    vals = s.split(',')
+    return [int(x) for x in vals]
+
+
+@click.command()
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=num_range, help='List of random seeds')
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+def generate_images(
+    ctx: click.Context,
+    network_pkl: str,
+    seeds: Optional[List[int]],
+    truncation_psi: float,
+    noise_mode: str,
+    class_idx: Optional[int],
+):
+
+    print('Loading networks from "%s"...' % network_pkl)
+    # Use GPU if available.
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    with dnnlib.util.open_url(network_pkl) as f:
+        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+    if seeds is None:
+        ctx.fail('--seeds option is required')
+
+    # Labels.
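+    # Conditional networks (c_dim > 0) expect a one-hot class vector; in the
+    # unconditional case c_dim == 0 and the label tensor below is empty.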
+ label = torch.zeros([1, G.c_dim], device=device).requires_grad_() + if G.c_dim != 0: + if class_idx is None: + ctx.fail('Must specify class label with --class when using a conditional network') + label[:, class_idx] = 1 + else: + if class_idx is not None: + print ('warn: --class=lbl ignored when running on an unconditional network') + + z = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device) + ws = G.mapping(z, label, truncation_psi=truncation_psi) + np.savez(f'encoder4editing/projected_w.npz', w=ws.detach().cpu().numpy()) + + +if __name__ == "__main__": + generate_images() diff --git a/helpers.py b/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f388ecb420a7db06981b766b80b951554daed613 --- /dev/null +++ b/helpers.py @@ -0,0 +1,119 @@ +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut \ No newline at end of file diff --git a/id_loss.py b/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..deb37def3d7c49cbf922e6849de541bec86f79ab --- /dev/null +++ b/id_loss.py @@ -0,0 +1,45 @@ +import torch +from torch import nn +from model_irse import Backbone + + +# Use GPU if available +if torch.cuda.is_available(): + device = torch.device("cuda") +else: + device = torch.device("cpu") + + +class IDLoss(nn.Module): + def __init__(self, opts): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load("model_ir_se50.pth", map_location=device)) + self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + self.opts = opts + + def extract_feats(self, x): + if x.shape[2] != 256: + x = self.pool(x) + x = x[:, :, 35:223, 32:220] # Crop interesting region + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + def forward(self, y_hat, y): + n_samples = y.shape[0] + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + loss += 1 - diff_target + count += 1 + + return loss / count, sim_improvement / count diff --git a/legacy.py b/legacy.py new file mode 100644 index 
0000000000000000000000000000000000000000..54dfbcee9f2d594b9870ff4851d66ba8ab63ae16 --- /dev/null +++ b/legacy.py @@ -0,0 +1,319 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import click +import pickle +import re +import copy +import numpy as np +import torch +import dnnlib +from torch_utils import misc + + +def load_network_pkl(f, force_fp16=False): + data = _LegacyUnpickler(f).load() + + # Legacy TensorFlow pickle => convert. + if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data): + tf_G, tf_D, tf_Gs = data + G = convert_tf_generator(tf_G) + D = convert_tf_discriminator(tf_D) + G_ema = convert_tf_generator(tf_Gs) + data = dict(G=G, D=D, G_ema=G_ema) + + # Add missing fields. + if 'training_set_kwargs' not in data: + data['training_set_kwargs'] = None + if 'augment_pipe' not in data: + data['augment_pipe'] = None + + # Validate contents. + assert isinstance(data['G'], torch.nn.Module) + assert isinstance(data['D'], torch.nn.Module) + assert isinstance(data['G_ema'], torch.nn.Module) + assert isinstance(data['training_set_kwargs'], (dict, type(None))) + assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None))) + + # Force FP16. + if force_fp16: + for key in ['G', 'D', 'G_ema']: + old = data[key] + kwargs = copy.deepcopy(old.init_kwargs) + if key.startswith('G'): + kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {})) + kwargs.synthesis_kwargs.num_fp16_res = 4 + kwargs.synthesis_kwargs.conv_clamp = 256 + if key.startswith('D'): + kwargs.num_fp16_res = 4 + kwargs.conv_clamp = 256 + if kwargs != old.init_kwargs: + new = type(old)(**kwargs).eval().requires_grad_(False) + misc.copy_params_and_buffers(old, new, require_all=True) + data[key] = new + return data + +#---------------------------------------------------------------------------- + +class _TFNetworkStub(dnnlib.EasyDict): + pass + +class _LegacyUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if module == 'dnnlib.tflib.network' and name == 'Network': + return _TFNetworkStub + return super().find_class(module, name) + +#---------------------------------------------------------------------------- + +def _collect_tf_params(tf_net): + # pylint: disable=protected-access + tf_params = dict() + def recurse(prefix, tf_net): + for name, value in tf_net.variables: + tf_params[prefix + name] = value + for name, comp in tf_net.components.items(): + recurse(prefix + name + '/', comp) + recurse('', tf_net) + return tf_params + +#---------------------------------------------------------------------------- + +def _populate_module_params(module, *patterns): + for name, tensor in misc.named_params_and_buffers(module): + found = False + value = None + for pattern, value_fn in zip(patterns[0::2], patterns[1::2]): + match = re.fullmatch(pattern, name) + if match: + found = True + if value_fn is not None: + value = value_fn(*match.groups()) + break + try: + assert found + if value is not None: + tensor.copy_(torch.from_numpy(np.array(value))) + except: + print(name, list(tensor.shape)) + raise + 
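_populate_module_params (above) pairs even-indexed regex patterns with value callbacks: each parameter or buffer name is matched against the patterns in order, regex capture groups are forwarded to the winning callback, and the returned array is copied into the tensor (a None callback leaves the tensor untouched, as used for the resample filters). A toy sketch of the calling convention, with a made-up module and parameter table:

import numpy as np
import torch

net = torch.nn.Linear(4, 4)  # stand-in module; parameter names: 'weight', 'bias'
fake_tf_params = {
    'Dense/weight': np.eye(4, dtype=np.float32),
    'Dense/bias': np.zeros(4, dtype=np.float32),
}

# Patterns without capture groups invoke their callback with no arguments.
_populate_module_params(net,
    r'weight', lambda: fake_tf_params['Dense/weight'].transpose(),
    r'bias',   lambda: fake_tf_params['Dense/bias'],
)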
+#----------------------------------------------------------------------------
+
+def convert_tf_generator(tf_G):
+    if tf_G.version < 4:
+        raise ValueError('TensorFlow pickle version too low')
+
+    # Collect kwargs.
+    tf_kwargs = tf_G.static_kwargs
+    known_kwargs = set()
+    def kwarg(tf_name, default=None, none=None):
+        known_kwargs.add(tf_name)
+        val = tf_kwargs.get(tf_name, default)
+        return val if val is not None else none
+
+    # Convert kwargs.
+    kwargs = dnnlib.EasyDict(
+        z_dim = kwarg('latent_size', 512),
+        c_dim = kwarg('label_size', 0),
+        w_dim = kwarg('dlatent_size', 512),
+        img_resolution = kwarg('resolution', 1024),
+        img_channels = kwarg('num_channels', 3),
+        mapping_kwargs = dnnlib.EasyDict(
+            num_layers = kwarg('mapping_layers', 8),
+            embed_features = kwarg('label_fmaps', None),
+            layer_features = kwarg('mapping_fmaps', None),
+            activation = kwarg('mapping_nonlinearity', 'lrelu'),
+            lr_multiplier = kwarg('mapping_lrmul', 0.01),
+            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
+        ),
+        synthesis_kwargs = dnnlib.EasyDict(
+            channel_base = kwarg('fmap_base', 16384) * 2,
+            channel_max = kwarg('fmap_max', 512),
+            num_fp16_res = kwarg('num_fp16_res', 0),
+            conv_clamp = kwarg('conv_clamp', None),
+            architecture = kwarg('architecture', 'skip'),
+            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
+            use_noise = kwarg('use_noise', True),
+            activation = kwarg('nonlinearity', 'lrelu'),
+        ),
+    )
+
+    # Check for unknown kwargs.
+    kwarg('truncation_psi')
+    kwarg('truncation_cutoff')
+    kwarg('style_mixing_prob')
+    kwarg('structure')
+    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
+    if len(unknown_kwargs) > 0:
+        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
+
+    # Collect params.
+    tf_params = _collect_tf_params(tf_G)
+    for name, value in list(tf_params.items()):
+        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
+        if match:
+            r = kwargs.img_resolution // (2 ** int(match.group(1)))
+            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
+            kwargs.synthesis_kwargs.architecture = 'orig'
+    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
+
+    # Convert params.
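+    # Each regex below names a parameter of the PyTorch Generator; captured
+    # groups (block resolution r, layer index i) are passed to the paired
+    # lambda, which fetches and re-lays-out the matching TensorFlow tensor.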
+ from training import networks + G = networks.Generator(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + _populate_module_params(G, + r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'], + r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'], + r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0], + r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'], + r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0], + r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'], + r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(), + r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'], + r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0], + r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'], + r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1, + r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'], + r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0], + r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'], + r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1, + r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1), + r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'], + r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(), + r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1, + r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1), + r'.*\.resample_filter', None, + ) + return G + +#---------------------------------------------------------------------------- + +def convert_tf_discriminator(tf_D): + if tf_D.version < 4: + raise ValueError('TensorFlow pickle version too low') + + # Collect kwargs. 
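+    # static_kwargs holds the constructor arguments of the original TF
+    # network; every name read through kwarg() is recorded in known_kwargs so
+    # that unrecognized TF arguments can be reported as errors further down.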
+ tf_kwargs = tf_D.static_kwargs + known_kwargs = set() + def kwarg(tf_name, default=None): + known_kwargs.add(tf_name) + return tf_kwargs.get(tf_name, default) + + # Convert kwargs. + kwargs = dnnlib.EasyDict( + c_dim = kwarg('label_size', 0), + img_resolution = kwarg('resolution', 1024), + img_channels = kwarg('num_channels', 3), + architecture = kwarg('architecture', 'resnet'), + channel_base = kwarg('fmap_base', 16384) * 2, + channel_max = kwarg('fmap_max', 512), + num_fp16_res = kwarg('num_fp16_res', 0), + conv_clamp = kwarg('conv_clamp', None), + cmap_dim = kwarg('mapping_fmaps', None), + block_kwargs = dnnlib.EasyDict( + activation = kwarg('nonlinearity', 'lrelu'), + resample_filter = kwarg('resample_kernel', [1,3,3,1]), + freeze_layers = kwarg('freeze_layers', 0), + ), + mapping_kwargs = dnnlib.EasyDict( + num_layers = kwarg('mapping_layers', 0), + embed_features = kwarg('mapping_fmaps', None), + layer_features = kwarg('mapping_fmaps', None), + activation = kwarg('nonlinearity', 'lrelu'), + lr_multiplier = kwarg('mapping_lrmul', 0.1), + ), + epilogue_kwargs = dnnlib.EasyDict( + mbstd_group_size = kwarg('mbstd_group_size', None), + mbstd_num_channels = kwarg('mbstd_num_features', 1), + activation = kwarg('nonlinearity', 'lrelu'), + ), + ) + + # Check for unknown kwargs. + kwarg('structure') + unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs) + if len(unknown_kwargs) > 0: + raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0]) + + # Collect params. + tf_params = _collect_tf_params(tf_D) + for name, value in list(tf_params.items()): + match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name) + if match: + r = kwargs.img_resolution // (2 ** int(match.group(1))) + tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value + kwargs.architecture = 'orig' + #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}') + + # Convert params. 
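+    # Note the Conv{i}{["","_down"][int(i)]} lookup in the patterns below: in
+    # the TF naming scheme Conv0 keeps its name while the strided Conv1 is
+    # stored as Conv1_down.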
+ from training import networks + D = networks.Discriminator(**kwargs).eval().requires_grad_(False) + # pylint: disable=unnecessary-lambda + _populate_module_params(D, + r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'], + r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1), + r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'], + r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1), + r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(), + r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'], + r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(), + r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'], + r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1), + r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'], + r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(), + r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'], + r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(), + r'b4\.out\.bias', lambda: tf_params[f'Output/bias'], + r'.*\.resample_filter', None, + ) + return D + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--source', help='Input pickle', required=True, metavar='PATH') +@click.option('--dest', help='Output pickle', required=True, metavar='PATH') +@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True) +def convert_network_pickle(source, dest, force_fp16): + """Convert legacy network pickle into the native PyTorch format. + + The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA. + It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks. + + Example: + + \b + python legacy.py \\ + --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\ + --dest=stylegan2-cat-config-f.pkl + """ + print(f'Loading "{source}"...') + with dnnlib.util.open_url(source) as f: + data = load_network_pkl(f, force_fp16=force_fp16) + print(f'Saving "{dest}"...') + with open(dest, 'wb') as f: + pickle.dump(data, f) + print('Done.') + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + convert_network_pickle() # pylint: disable=no-value-for-parameter + +#---------------------------------------------------------------------------- diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e1a5ba99e56a56ecaa14f7d4fa41777789c0cf --- /dev/null +++ b/metrics/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +# empty diff --git a/metrics/frechet_inception_distance.py b/metrics/frechet_inception_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..1d38ec731b33e6f4f20cd4601e58d7e5ce2eaaa3 --- /dev/null +++ b/metrics/frechet_inception_distance.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Frechet Inception Distance (FID) from the paper +"GANs trained by a two time-scale update rule converge to a local Nash +equilibrium". Matches the original implementation by Heusel et al. at +https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" + +import numpy as np +import scipy.linalg +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_fid(opts, max_real, num_gen): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. + + mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() + + mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() + + if opts.rank != 0: + return float('nan') + + m = np.square(mu_gen - mu_real).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) + return float(fid) + +#---------------------------------------------------------------------------- diff --git a/metrics/inception_score.py b/metrics/inception_score.py new file mode 100644 index 0000000000000000000000000000000000000000..3822c1435901a47e8c192b52cd3ed1ce5de67acd --- /dev/null +++ b/metrics/inception_score.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Inception Score (IS) from the paper "Improved techniques for training +GANs". Matches the original implementation by Salimans et al. at +https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" + +import numpy as np +from . 
import metric_utils + +#---------------------------------------------------------------------------- + +def compute_is(opts, num_gen, num_splits): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' + detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. + + gen_probs = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan'), float('nan') + + scores = [] + for i in range(num_splits): + part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] + kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) + kl = np.mean(np.sum(kl, axis=1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)), float(np.std(scores)) + +#---------------------------------------------------------------------------- diff --git a/metrics/kernel_inception_distance.py b/metrics/kernel_inception_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac978925b5cf810463ef8e8a6f0dcd3f9078e6d --- /dev/null +++ b/metrics/kernel_inception_distance.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Kernel Inception Distance (KID) from the paper "Demystifying MMD +GANs". Matches the original implementation by Binkowski et al. at +https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" + +import numpy as np +from . import metric_utils + +#---------------------------------------------------------------------------- + +def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): + # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' + detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
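+    # (What follows is the unbiased estimator of MMD^2 with the cubic polynomial
+    #  kernel k(x, y) = (x.T @ y / d + 1) ** 3, where d is the feature dimensionality;
+    #  the estimate is averaged over num_subsets random subsets of size m.)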
+ + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() + + if opts.rank != 0: + return float('nan') + + n = real_features.shape[1] + m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) + t = 0 + for _subset_idx in range(num_subsets): + x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] + y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] + a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 + b = (x @ y.T / n + 1) ** 3 + t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m + kid = t / num_subsets / m + return float(kid) + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_main.py b/metrics/metric_main.py new file mode 100644 index 0000000000000000000000000000000000000000..738804a6fbdba7bee3b0c68ca2fca4646527bc28 --- /dev/null +++ b/metrics/metric_main.py @@ -0,0 +1,152 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import time +import json +import torch +import dnnlib + +from . import metric_utils +from . import frechet_inception_distance +from . import kernel_inception_distance +from . import precision_recall +from . import perceptual_path_length +from . import inception_score + +#---------------------------------------------------------------------------- + +_metric_dict = dict() # name => fn + +def register_metric(fn): + assert callable(fn) + _metric_dict[fn.__name__] = fn + return fn + +def is_valid_metric(metric): + return metric in _metric_dict + +def list_valid_metrics(): + return list(_metric_dict.keys()) + +#---------------------------------------------------------------------------- + +def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. + assert is_valid_metric(metric) + opts = metric_utils.MetricOptions(**kwargs) + + # Calculate. + start_time = time.time() + results = _metric_dict[metric](opts) + total_time = time.time() - start_time + + # Broadcast results. + for key, value in list(results.items()): + if opts.num_gpus > 1: + value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) + torch.distributed.broadcast(tensor=value, src=0) + value = float(value.cpu()) + results[key] = value + + # Decorate with metadata. 
+ return dnnlib.EasyDict( + results = dnnlib.EasyDict(results), + metric = metric, + total_time = total_time, + total_time_str = dnnlib.util.format_time(total_time), + num_gpus = opts.num_gpus, + ) + +#---------------------------------------------------------------------------- + +def report_metric(result_dict, run_dir=None, snapshot_pkl=None): + metric = result_dict['metric'] + assert is_valid_metric(metric) + if run_dir is not None and snapshot_pkl is not None: + snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) + + jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) + print(jsonl_line) + if run_dir is not None and os.path.isdir(run_dir): + with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: + f.write(jsonl_line + '\n') + +#---------------------------------------------------------------------------- +# Primary metrics. + +@register_metric +def fid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) + return dict(fid50k_full=fid) + +@register_metric +def kid50k_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k_full=kid) + +@register_metric +def pr50k3_full(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) + +@register_metric +def ppl2_wend(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) + return dict(ppl2_wend=ppl) + +@register_metric +def is50k(opts): + opts.dataset_kwargs.update(max_size=None, xflip=False) + mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) + return dict(is50k_mean=mean, is50k_std=std) + +#---------------------------------------------------------------------------- +# Legacy metrics. 
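+# (These variants follow the original StyleGAN/StyleGAN2 evaluation protocol:
+# real-image statistics come from at most 50k images and the dataset's xflip
+# setting is left untouched, unlike the *_full variants above, which disable
+# xflip and use every available real image.)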
+ +@register_metric +def fid50k(opts): + opts.dataset_kwargs.update(max_size=None) + fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) + return dict(fid50k=fid) + +@register_metric +def kid50k(opts): + opts.dataset_kwargs.update(max_size=None) + kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) + return dict(kid50k=kid) + +@register_metric +def pr50k3(opts): + opts.dataset_kwargs.update(max_size=None) + precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) + return dict(pr50k3_precision=precision, pr50k3_recall=recall) + +@register_metric +def ppl_zfull(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2) + return dict(ppl_zfull=ppl) + +@register_metric +def ppl_wfull(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2) + return dict(ppl_wfull=ppl) + +@register_metric +def ppl_zend(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2) + return dict(ppl_zend=ppl) + +@register_metric +def ppl_wend(opts): + ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2) + return dict(ppl_wend=ppl) + +#---------------------------------------------------------------------------- diff --git a/metrics/metric_utils.py b/metrics/metric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16de1eae3ee79549412eff5313dcf26b5d7a4bb9 --- /dev/null +++ b/metrics/metric_utils.py @@ -0,0 +1,275 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
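
Taken together, the registry above gives a single entry point for evaluation. A typical call from a training loop looks roughly like the following sketch (G and training_set_kwargs are assumed to be supplied by the caller; the run directory and snapshot path are illustrative):

    import torch
    from metrics import metric_main

    result_dict = metric_main.calc_metric(metric='fid50k_full', G=G,
        dataset_kwargs=training_set_kwargs, num_gpus=1, rank=0, device=torch.device('cuda'))
    metric_main.report_metric(result_dict, run_dir='training-runs/00000', snapshot_pkl='network-snapshot.pkl')

The score itself is then available as result_dict.results.fid50k_full, and report_metric() appends the same record to metric-fid50k_full.jsonl in the run directory.
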
+ +import os +import time +import hashlib +import pickle +import copy +import uuid +import numpy as np +import torch +import dnnlib + +#---------------------------------------------------------------------------- + +class MetricOptions: + def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True): + assert 0 <= rank < num_gpus + self.G = G + self.G_kwargs = dnnlib.EasyDict(G_kwargs) + self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) + self.num_gpus = num_gpus + self.rank = rank + self.device = device if device is not None else torch.device('cuda', rank) + self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() + self.cache = cache + +#---------------------------------------------------------------------------- + +_feature_detector_cache = dict() + +def get_feature_detector_name(url): + return os.path.splitext(url.split('/')[-1])[0] + +def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): + assert 0 <= rank < num_gpus + key = (url, device) + if key not in _feature_detector_cache: + is_leader = (rank == 0) + if not is_leader and num_gpus > 1: + torch.distributed.barrier() # leader goes first + with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: + _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) + if is_leader and num_gpus > 1: + torch.distributed.barrier() # others follow + return _feature_detector_cache[key] + +#---------------------------------------------------------------------------- + +class FeatureStats: + def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): + self.capture_all = capture_all + self.capture_mean_cov = capture_mean_cov + self.max_items = max_items + self.num_items = 0 + self.num_features = None + self.all_features = None + self.raw_mean = None + self.raw_cov = None + + def set_num_features(self, num_features): + if self.num_features is not None: + assert num_features == self.num_features + else: + self.num_features = num_features + self.all_features = [] + self.raw_mean = np.zeros([num_features], dtype=np.float64) + self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) + + def is_full(self): + return (self.max_items is not None) and (self.num_items >= self.max_items) + + def append(self, x): + x = np.asarray(x, dtype=np.float32) + assert x.ndim == 2 + if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): + if self.num_items >= self.max_items: + return + x = x[:self.max_items - self.num_items] + + self.set_num_features(x.shape[1]) + self.num_items += x.shape[0] + if self.capture_all: + self.all_features.append(x) + if self.capture_mean_cov: + x64 = x.astype(np.float64) + self.raw_mean += x64.sum(axis=0) + self.raw_cov += x64.T @ x64 + + def append_torch(self, x, num_gpus=1, rank=0): + assert isinstance(x, torch.Tensor) and x.ndim == 2 + assert 0 <= rank < num_gpus + if num_gpus > 1: + ys = [] + for src in range(num_gpus): + y = x.clone() + torch.distributed.broadcast(y, src=src) + ys.append(y) + x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples + self.append(x.cpu().numpy()) + + def get_all(self): + assert self.capture_all + return np.concatenate(self.all_features, axis=0) + + def get_all_torch(self): + return torch.from_numpy(self.get_all()) + + def get_mean_cov(self): + assert self.capture_mean_cov + mean = self.raw_mean / self.num_items + cov = self.raw_cov / self.num_items + cov = cov - np.outer(mean, mean) + 
return mean, cov + + def save(self, pkl_file): + with open(pkl_file, 'wb') as f: + pickle.dump(self.__dict__, f) + + @staticmethod + def load(pkl_file): + with open(pkl_file, 'rb') as f: + s = dnnlib.EasyDict(pickle.load(f)) + obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) + obj.__dict__.update(s) + return obj + +#---------------------------------------------------------------------------- + +class ProgressMonitor: + def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): + self.tag = tag + self.num_items = num_items + self.verbose = verbose + self.flush_interval = flush_interval + self.progress_fn = progress_fn + self.pfn_lo = pfn_lo + self.pfn_hi = pfn_hi + self.pfn_total = pfn_total + self.start_time = time.time() + self.batch_time = self.start_time + self.batch_items = 0 + if self.progress_fn is not None: + self.progress_fn(self.pfn_lo, self.pfn_total) + + def update(self, cur_items): + assert (self.num_items is None) or (cur_items <= self.num_items) + if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): + return + cur_time = time.time() + total_time = cur_time - self.start_time + time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) + if (self.verbose) and (self.tag is not None): + print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') + self.batch_time = cur_time + self.batch_items = cur_items + + if (self.progress_fn is not None) and (self.num_items is not None): + self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) + + def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): + return ProgressMonitor( + tag = tag, + num_items = num_items, + flush_interval = flush_interval, + verbose = self.verbose, + progress_fn = self.progress_fn, + pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, + pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, + pfn_total = self.pfn_total, + ) + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs): + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + if data_loader_kwargs is None: + data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) + + # Try to lookup from cache. + cache_file = None + if opts.cache: + # Choose cache file name. + args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) + md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) + cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' + cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') + + # Check if the file exists (all processes must agree). + flag = os.path.isfile(cache_file) if opts.rank == 0 else False + if opts.num_gpus > 1: + flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) + torch.distributed.broadcast(tensor=flag, src=0) + flag = (float(flag.cpu()) != 0) + + # Load. + if flag: + return FeatureStats.load(cache_file) + + # Initialize. 
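+    # (FeatureStats silently truncates anything appended beyond max_items, so the
+    #  interleaved per-rank subsets fed to it below cannot overshoot the count.)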
+ num_items = len(dataset) + if max_items is not None: + num_items = min(num_items, max_items) + stats = FeatureStats(max_items=num_items, **stats_kwargs) + progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] + for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector(images.to(opts.device), **detector_kwargs) + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + + # Save to cache. + if cache_file is not None and opts.rank == 0: + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + temp_file = cache_file + '.' + uuid.uuid4().hex + stats.save(temp_file) + os.replace(temp_file, cache_file) # atomic + return stats + +#---------------------------------------------------------------------------- + +def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs): + if batch_gen is None: + batch_gen = min(batch_size, 4) + assert batch_size % batch_gen == 0 + + # Setup generator and load labels. + G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + + # Image generation func. + def run_generator(z, c): + img = G(z=z, c=c, **opts.G_kwargs) + img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) + return img + + # JIT. + if jit: + z = torch.zeros([batch_gen, G.z_dim], device=opts.device) + c = torch.zeros([batch_gen, G.c_dim], device=opts.device) + run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False) + + # Initialize. + stats = FeatureStats(**stats_kwargs) + assert stats.max_items is not None + progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) + detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) + + # Main loop. + while not stats.is_full(): + images = [] + for _i in range(batch_size // batch_gen): + z = torch.randn([batch_gen, G.z_dim], device=opts.device) + c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] + c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) + images.append(run_generator(z, c)) + images = torch.cat(images) + if images.shape[1] == 1: + images = images.repeat([1, 3, 1, 1]) + features = detector(images, **detector_kwargs) + stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) + progress.update(stats.num_items) + return stats + +#---------------------------------------------------------------------------- diff --git a/metrics/perceptual_path_length.py b/metrics/perceptual_path_length.py new file mode 100644 index 0000000000000000000000000000000000000000..d070f45a04efed7e9492fddb85078be306753282 --- /dev/null +++ b/metrics/perceptual_path_length.py @@ -0,0 +1,131 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator +Architecture for Generative Adversarial Networks". Matches the original +implementation by Karras et al. at +https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" + +import copy +import numpy as np +import torch +import dnnlib +from . import metric_utils + +#---------------------------------------------------------------------------- + +# Spherical interpolation of a batch of vectors. +def slerp(a, b, t): + a = a / a.norm(dim=-1, keepdim=True) + b = b / b.norm(dim=-1, keepdim=True) + d = (a * b).sum(dim=-1, keepdim=True) + p = t * torch.acos(d) + c = b - d * a + c = c / c.norm(dim=-1, keepdim=True) + d = a * torch.cos(p) + c * torch.sin(p) + d = d / d.norm(dim=-1, keepdim=True) + return d + +#---------------------------------------------------------------------------- + +class PPLSampler(torch.nn.Module): + def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): + assert space in ['z', 'w'] + assert sampling in ['full', 'end'] + super().__init__() + self.G = copy.deepcopy(G) + self.G_kwargs = G_kwargs + self.epsilon = epsilon + self.space = space + self.sampling = sampling + self.crop = crop + self.vgg16 = copy.deepcopy(vgg16) + + def forward(self, c): + # Generate random latents and interpolation t-values. + t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) + z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) + + # Interpolate in W or Z. + if self.space == 'w': + w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) + wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) + wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) + else: # space == 'z' + zt0 = slerp(z0, z1, t.unsqueeze(1)) + zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) + wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) + + # Randomize noise buffers. + for name, buf in self.G.named_buffers(): + if name.endswith('.noise_const'): + buf.copy_(torch.randn_like(buf)) + + # Generate images. + img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) + + # Center crop. + if self.crop: + assert img.shape[2] == img.shape[3] + c = img.shape[2] // 8 + img = img[:, :, c*3 : c*7, c*2 : c*6] + + # Downsample to 256x256. + factor = self.G.img_resolution // 256 + if factor > 1: + img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) + + # Scale dynamic range from [-1,1] to [0,255]. + img = (img + 1) * (255 / 2) + if self.G.img_channels == 1: + img = img.repeat([1, 3, 1, 1]) + + # Evaluate differential LPIPS. 
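+        # (With return_lpips=True the detector emits unit-normalized LPIPS feature
+        #  vectors, so the squared feature distance below is the LPIPS distance;
+        #  dividing by epsilon**2 gives a finite-difference estimate of the squared
+        #  perceptual path length at the sampled point.)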
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) + dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 + return dist + +#---------------------------------------------------------------------------- + +def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False): + dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) + vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' + vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) + + # Setup sampler. + sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) + sampler.eval().requires_grad_(False).to(opts.device) + if jit: + c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) + sampler = torch.jit.trace(sampler, [c], check_trace=False) + + # Sampling loop. + dist = [] + progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) + for batch_start in range(0, num_samples, batch_size * opts.num_gpus): + progress.update(batch_start) + c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] + c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) + x = sampler(c) + for src in range(opts.num_gpus): + y = x.clone() + if opts.num_gpus > 1: + torch.distributed.broadcast(y, src=src) + dist.append(y) + progress.update(num_samples) + + # Compute PPL. + if opts.rank != 0: + return float('nan') + dist = torch.cat(dist)[:num_samples].cpu().numpy() + lo = np.percentile(dist, 1, interpolation='lower') + hi = np.percentile(dist, 99, interpolation='higher') + ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() + return float(ppl) + +#---------------------------------------------------------------------------- diff --git a/metrics/precision_recall.py b/metrics/precision_recall.py new file mode 100644 index 0000000000000000000000000000000000000000..8200b7ef51963ae218e3b871de270a826bf10459 --- /dev/null +++ b/metrics/precision_recall.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Precision/Recall (PR) from the paper "Improved Precision and Recall +Metric for Assessing Generative Models". Matches the original implementation +by Kynkaanniemi et al. at +https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" + +import torch +from . 
import metric_utils + +#---------------------------------------------------------------------------- + +def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): + assert 0 <= rank < num_gpus + num_cols = col_features.shape[0] + num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus + col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) + dist_batches = [] + for col_batch in col_batches[rank :: num_gpus]: + dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] + for src in range(num_gpus): + dist_broadcast = dist_batch.clone() + if num_gpus > 1: + torch.distributed.broadcast(dist_broadcast, src=src) + dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) + return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None + +#---------------------------------------------------------------------------- + +def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): + detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' + detector_kwargs = dict(return_features=True) + + real_features = metric_utils.compute_feature_stats_for_dataset( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) + + gen_features = metric_utils.compute_feature_stats_for_generator( + opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, + rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) + + results = dict() + for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: + kth = [] + for manifold_batch in manifold.split(row_batch_size): + dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) + kth = torch.cat(kth) if opts.rank == 0 else None + pred = [] + for probes_batch in probes.split(row_batch_size): + dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) + pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) + results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') + return results['precision'], results['recall'] + +#---------------------------------------------------------------------------- diff --git a/model_ir_se50.pth b/model_ir_se50.pth new file mode 100644 index 0000000000000000000000000000000000000000..d3a030dd9a353d94023d3fc3a5baa0991ca3873b --- /dev/null +++ b/model_ir_se50.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a035c768259b98ab1ce0e646312f48b9e1e218197a0f80ac6765e88f8b6ddf28 +size 175367323 diff --git a/model_irse.py b/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..b34e1b7dc21dbaa6488387b670f49322c4d906c0 --- /dev/null +++ b/model_irse.py @@ -0,0 +1,84 @@ +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from 
[TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model \ No newline at end of file diff --git a/torch_utils/.DS_Store b/torch_utils/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..868d4412b155d34002b27a12e375afa768192658 Binary files /dev/null and b/torch_utils/.DS_Store differ diff --git a/torch_utils/__init__.py b/torch_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9 --- /dev/null +++ b/torch_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
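
The IR-SE backbone above is typically used as a frozen identity critic, with the matching weights shipped in model_ir_se50.pth. A minimal sketch (the 112x112 aligned crops are stand-ins, and exact checkpoint compatibility with these constructor defaults is an assumption):

    import torch
    from model_irse import Backbone

    facenet = Backbone(input_size=112, num_layers=50, mode='ir_se')
    facenet.load_state_dict(torch.load('model_ir_se50.pth'))
    facenet.eval().requires_grad_(False)

    x_a = torch.randn([1, 3, 112, 112])        # stand-ins for two aligned 112x112 face crops
    x_b = torch.randn([1, 3, 112, 112])
    emb_a, emb_b = facenet(x_a), facenet(x_b)  # L2-normalized 512-D embeddings
    identity_sim = (emb_a * emb_b).sum(dim=1)  # cosine similarity (embeddings are unit-norm)
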
+ +# empty diff --git a/torch_utils/__pycache__/__init__.cpython-38.pyc b/torch_utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36fdf2d42f8391370ca21abb25990854f7af5dc3 Binary files /dev/null and b/torch_utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/torch_utils/__pycache__/custom_ops.cpython-38.pyc b/torch_utils/__pycache__/custom_ops.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70e8130882c8c099ed38cebc3c57f39d82b0c8f7 Binary files /dev/null and b/torch_utils/__pycache__/custom_ops.cpython-38.pyc differ diff --git a/torch_utils/__pycache__/misc.cpython-38.pyc b/torch_utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f1f9d22efaffffb6ab8b662cd54038cc2fab3f3 Binary files /dev/null and b/torch_utils/__pycache__/misc.cpython-38.pyc differ diff --git a/torch_utils/__pycache__/persistence.cpython-38.pyc b/torch_utils/__pycache__/persistence.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a04c96718fcc78478953a6d11b041ad330351f1 Binary files /dev/null and b/torch_utils/__pycache__/persistence.cpython-38.pyc differ diff --git a/torch_utils/custom_ops.py b/torch_utils/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e --- /dev/null +++ b/torch_utils/custom_ops.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import glob +import torch +import torch.utils.cpp_extension +import importlib +import hashlib +import shutil +from pathlib import Path + +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... 
', end='', flush=True) + + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Compile and load. + verbose_build = (verbosity == 'full') + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. + baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/torch_utils/misc.py b/torch_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7829f4d9f168557ce8a9a6dec289aa964234cb8c --- /dev/null +++ b/torch_utils/misc.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings +import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). + +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). 
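+# Example (illustrative): assert_shape(x, [None, 3, 256, 256]) accepts any batch
+# size but raises if the channel count or spatial dimensions differ.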
+ +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. 
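+# Typical use during gradient accumulation (illustrative):
+#     with ddp_sync(ddp_module, sync=(round_idx == num_rounds - 1)):
+#         loss.backward()
+# Intermediate rounds skip the gradient all-reduce; only the final round syncs.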
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+    assert isinstance(module, torch.nn.Module)
+    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+        yield
+    else:
+        with module.no_sync():
+            yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+    assert isinstance(module, torch.nn.Module)
+    for name, tensor in named_params_and_buffers(module):
+        fullname = type(module).__name__ + '.' + name
+        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+            continue
+        tensor = tensor.detach()
+        other = tensor.clone()
+        torch.distributed.broadcast(tensor=other, src=0)
+        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+    assert isinstance(module, torch.nn.Module)
+    assert not isinstance(module, torch.jit.ScriptModule)
+    assert isinstance(inputs, (tuple, list))
+
+    # Register hooks.
+    entries = []
+    nesting = [0]
+    def pre_hook(_mod, _inputs):
+        nesting[0] += 1
+    def post_hook(mod, _inputs, outputs):
+        nesting[0] -= 1
+        if nesting[0] <= max_nesting:
+            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+    # Run module.
+    outputs = module(*inputs)
+    for hook in hooks:
+        hook.remove()
+
+    # Identify unique outputs, parameters, and buffers.
+    tensors_seen = set()
+    for e in entries:
+        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+    # Filter out redundant entries.
+    if skip_redundant:
+        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+    # Construct table.
+    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+    rows += [['---'] * len(rows[0])]
+    param_total = 0
+    buffer_total = 0
+    submodule_names = {mod: name for name, mod in module.named_modules()}
+    for e in entries:
+        name = '' if e.mod is module else submodule_names[e.mod]
+        param_size = sum(t.numel() for t in e.unique_params)
+        buffer_size = sum(t.numel() for t in e.unique_buffers)
+        output_shapes = [str(list(t.shape)) for t in e.outputs]
+        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+        rows += [[
+            name + (':0' if len(e.outputs) >= 2 else ''),
+            str(param_size) if param_size else '-',
+            str(buffer_size) if buffer_size else '-',
+            (output_shapes + ['-'])[0],
+            (output_dtypes + ['-'])[0],
+        ]]
+        for idx in range(1, len(e.outputs)):
+            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+        param_total += param_size
+        buffer_total += buffer_size
+    rows += [['---'] * len(rows[0])]
+    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+    # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/__init__.py b/torch_utils/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9 --- /dev/null +++ b/torch_utils/ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/torch_utils/ops/__pycache__/__init__.cpython-38.pyc b/torch_utils/ops/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b18840381ad0cc2603ce35800f9cfea52971162a Binary files /dev/null and b/torch_utils/ops/__pycache__/__init__.cpython-38.pyc differ diff --git a/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc b/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..136a43d9f09d3b41097f3ee5b6bace1805274702 Binary files /dev/null and b/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0164bf7e169086042db03eb5ee1f837096d64ca0 Binary files /dev/null and b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc differ diff --git a/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc b/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9e975c77a618051b9146f18bbe1f5efd905487d Binary files /dev/null and b/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc differ diff --git a/torch_utils/ops/__pycache__/fma.cpython-38.pyc b/torch_utils/ops/__pycache__/fma.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0093da982ddca7375754be809743ecd3606028ab Binary files /dev/null and b/torch_utils/ops/__pycache__/fma.cpython-38.pyc differ diff --git a/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc b/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55edbd1d76ee40a2d4a7f8d56a1a7c7a6d0599db Binary files /dev/null and b/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc differ diff --git a/torch_utils/ops/bias_act.cpp b/torch_utils/ops/bias_act.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330 --- /dev/null +++ b/torch_utils/ops/bias_act.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+    if (x.dim() != y.dim())
+        return false;
+    for (int64_t i = 0; i < x.dim(); i++)
+    {
+        if (x.size(i) != y.size(i))
+            return false;
+        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+            return false;
+    }
+    return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+    TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+    // Validate layout.
+    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    torch::Tensor y = torch::empty_like(x);
+    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+    // Initialize CUDA kernel parameters.
+    bias_act_kernel_params p;
+    p.x     = x.data_ptr();
+    p.b     = (b.numel()) ? b.data_ptr() : NULL;
+    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
+    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
+    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
+    p.y     = y.data_ptr();
+    p.grad  = grad;
+    p.act   = act;
+    p.alpha = alpha;
+    p.gain  = gain;
+    p.clamp = clamp;
+    p.sizeX = (int)x.numel();
+    p.sizeB = (int)b.numel();
+    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+    // Choose CUDA kernel.
+    void* kernel;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+    {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+    });
+    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+    // Launch CUDA kernel.
+    p.loopX = 4;
+    int blockSize = 4 * 32;
+    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/bias_act.cu b/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880
--- /dev/null
+++ b/torch_utils/ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>    { typedef double scalar_t; };
+template <> struct InternalType<float>     { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+    int G                 = p.grad;
+    scalar_t alpha        = (scalar_t)p.alpha;
+    scalar_t gain         = (scalar_t)p.gain;
+    scalar_t clamp        = (scalar_t)p.clamp;
+    scalar_t one          = (scalar_t)1;
+    scalar_t two          = (scalar_t)2;
+    scalar_t expRange     = (scalar_t)80;
+    scalar_t halfExpRange = (scalar_t)40;
+    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
+    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;
+
+    // Loop over elements.
+    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+    {
+        // Load.
+        scalar_t x = (scalar_t)((const T*)p.x)[xi];
+        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+        scalar_t yy = (gain != 0) ? yref / gain : 0;
+        scalar_t y = 0;
+
+        // Apply bias.
+        ((G == 0) ? x : xref) += b;
+
+        // linear
+        if (A == 1)
+        {
+            if (G == 0) y = x;
+            if (G == 1) y = x;
+        }
+
+        // relu
+        if (A == 2)
+        {
+            if (G == 0) y = (x > 0) ? x : 0;
+            if (G == 1) y = (yy > 0) ? x : 0;
+        }
+
+        // lrelu
+        if (A == 3)
+        {
+            if (G == 0) y = (x > 0) ? x : x * alpha;
+            if (G == 1) y = (yy > 0) ? x : x * alpha;
+        }
+
+        // tanh
+        if (A == 4)
+        {
+            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+            if (G == 1) y = x * (one - yy * yy);
+            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+        }
+
+        // sigmoid
+        if (A == 5)
+        {
+            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+            if (G == 1) y = x * yy * (one - yy);
+            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+        }
+
+        // elu
+        if (A == 6)
+        {
+            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+        }
+
+        // selu
+        if (A == 7)
+        {
+            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+        }
+
+        // softplus
+        if (A == 8)
+        {
+            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+            if (G == 1) y = x * (one - exp(-yy));
+            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+        }
+
+        // swish
+        if (A == 9)
+        {
+            if (G == 0)
+                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+            else
+            {
+                scalar_t c = exp(xref);
+                scalar_t d = c + one;
+                if (G == 1)
+                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+                else
+                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+            }
+        }
+
+        // Apply gain.
+        y *= gain * dy;
+
+        // Clamp.
+        if (clamp >= 0)
+        {
+            if (G == 0)
+                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+            else
+                y = (yref > -clamp & yref < clamp) ? y : 0;
+        }
+
+        // Store.
+        ((T*)p.y)[xi] = (T)y;
+    }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+    return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/bias_act.h b/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/torch_utils/ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
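+//
+// One kernel serves the forward pass and both gradient passes: `grad`
+// selects between them (0 = forward y = act(x + b), 1 = first-order
+// gradient, 2 = second-order gradient), matching the G == 0/1/2 branches
+// in bias_act.cu. `xref` and `yref` carry the activations saved from the
+// forward pass that the gradient formulas are evaluated against.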
+
+struct bias_act_kernel_params
+{
+    const void* x;      // [sizeX]
+    const void* b;      // [sizeB] or NULL
+    const void* xref;   // [sizeX] or NULL
+    const void* yref;   // [sizeX] or NULL
+    const void* dy;     // [sizeX] or NULL
+    void*       y;      // [sizeX]
+
+    int         grad;
+    int         act;
+    float       alpha;
+    float       gain;
+    float       clamp;
+
+    int         sizeX;
+    int         sizeB;
+    int         stepB;
+    int         loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/bias_act.py b/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bcb409a89ccf6c6f6ecfca5962683df2d280b1f
--- /dev/null
+++ b/torch_utils/ops/bias_act.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import dnnlib
+import traceback
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+    'linear':   dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+    'relu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+    'tanh':     dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+    'elu':      dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+    'selu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+    'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+    'swish':    dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+    global _inited, _plugin
+    if not _inited:
+        _inited = True
+        sources = ['bias_act.cpp', 'bias_act.cu']
+        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+        try:
+            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+        except:
+            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+    return _plugin is not None
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+    r"""Fused bias and activation function.
+
+    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+    and scales the result by `gain`. Each of the steps is optional. In most cases,
+    the fused op is considerably more efficient than performing the same calculation
+    using standard PyTorch ops. It supports first and second order gradients,
+    but not third order gradients.
+
+    Args:
+        x:      Input activation tensor. Can be of any shape.
+        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+                as `x`. The shape must be known, and it must match the dimension of `x`
+                corresponding to `dim`.
+        dim:    The dimension in `x` corresponding to the elements of `b`.
+                The value of `dim` is ignored if `b` is not specified.
+        act:    Name of the activation function to evaluate, or `"linear"` to disable.
+                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+                See `activation_funcs` for a full list. `None` is not allowed.
+        alpha:  Shape parameter for the activation function, or `None` to use the default.
+        gain:   Scaling factor for the output tensor, or `None` to use the default.
+                See `activation_funcs` for the default scaling of each activation function.
+                If unsure, consider specifying 1.
+        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+                the clamping (default).
+        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+    Returns:
+        Tensor of the same shape and datatype as `x`.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert impl in ['ref', 'cuda']
+    if impl == 'cuda' and x.device.type == 'cuda' and _init():
+        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert clamp is None or clamp >= 0
+    spec = activation_funcs[act]
+    alpha = float(alpha if alpha is not None else spec.def_alpha)
+    gain = float(gain if gain is not None else spec.def_gain)
+    clamp = float(clamp if clamp is not None else -1)
+
+    # Add bias.
+    if b is not None:
+        assert isinstance(b, torch.Tensor) and b.ndim == 1
+        assert 0 <= dim < x.ndim
+        assert b.shape[0] == x.shape[dim]
+        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+    # Evaluate activation function.
+    x = spec.func(x, alpha=alpha)
+
+    # Scale by gain.
+    if gain != 1:
+        x = x * gain
+
+    # Clamp.
+    if clamp >= 0:
+        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+    return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Fast CUDA implementation of `bias_act()` using custom ops.
+    """
+    # Parse arguments.
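+    # Defaults mirror _bias_act_ref(): alpha and gain fall back to the values
+    # in `activation_funcs`, and clamp = -1 encodes "disabled" on the CUDA side.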
+ assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/conv2d_gradfix.py b/torch_utils/ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ffe7eb48086ed98eff3d5a83f5b361ff4a7b4d --- /dev/null +++ b/torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import warnings +import contextlib +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. 
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d diff --git a/torch_utils/ops/conv2d_resample.py b/torch_utils/ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4750744c83354bab78704d4ef51ad1070fcc4a --- /dev/null +++ b/torch_utils/ops/conv2d_resample.py @@ -0,0 +1,156 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. 
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d_gradfix.conv2d(x, w, groups=groups) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. 
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/fma.py b/torch_utils/ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..2eeac58a626c49231e04122b93e321ada954c5d3 --- /dev/null +++ b/torch_utils/ops/fma.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- diff --git a/torch_utils/ops/grid_sample_gradfix.py b/torch_utils/ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6b3413ea72a734703c34382c023b84523601fd --- /dev/null +++ b/torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import warnings +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. 
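+
+# A minimal usage sketch (assuming a CUDA tensor `images` and a sampling grid
+# `grid` in [-1, 1] normalized coordinates, as for the standard op):
+#
+#     from torch_utils.ops import grid_sample_gradfix
+#     grid_sample_gradfix.enabled = True
+#     warped = grid_sample_gradfix.grid_sample(images, grid)
+#
+# The forward result matches torch.nn.functional.grid_sample() with
+# mode='bilinear', padding_mode='zeros', align_corners=False.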
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+    if _should_use_custom_op():
+        return _GridSample2dForward.apply(input, grid)
+    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+    if not enabled:
+        return False
+    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
+        return True
+    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
+    return False
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, grid):
+        assert input.ndim == 4
+        assert grid.ndim == 4
+        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+        ctx.save_for_backward(input, grid)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, grid = ctx.saved_tensors
+        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+        return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, grad_output, input, grid):
+        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+        ctx.save_for_backward(grid)
+        return grad_input, grad_grid
+
+    @staticmethod
+    def backward(ctx, grad2_grad_input, grad2_grad_grid):
+        _ = grad2_grad_grid # unused
+        grid, = ctx.saved_tensors
+        grad2_grad_output = None
+        grad2_input = None
+        grad2_grid = None
+
+        if ctx.needs_input_grad[0]:
+            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+        assert not ctx.needs_input_grad[2]
+        return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/torch_utils/ops/upfirdn2d.cpp b/torch_utils/ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+
+    // Initialize CUDA kernel parameters.
+    upfirdn2d_kernel_params p;
+    p.x = x.data_ptr();
+    p.f = f.data_ptr<float>();
+    p.y = y.data_ptr();
+    p.up = make_int2(upx, upy);
+    p.down = make_int2(downx, downy);
+    p.pad0 = make_int2(padx0, pady0);
+    p.flip = (flip) ? 1 : 0;
+    p.gain = gain;
+    p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+    p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+    p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+    p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+    p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+    p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+    p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+    // Choose CUDA kernel.
+    upfirdn2d_kernel_spec spec;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+    {
+        spec = choose_upfirdn2d_kernel<scalar_t>(p);
+    });
+
+    // Set looping options.
+    p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+    p.loopMinor = spec.loopMinor;
+    p.loopX = spec.loopX;
+    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+    // Compute grid size.
+    dim3 blockSize, gridSize;
+    if (spec.tileOutW < 0) // large
+    {
+        blockSize = dim3(4, 32, 1);
+        gridSize = dim3(
+            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+            p.launchMajor);
+    }
+    else // small
+    {
+        blockSize = dim3(256, 1, 1);
+        gridSize = dim3(
+            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+            p.launchMajor);
+    }
+
+    // Launch CUDA kernel.
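+    // Example: a contiguous [8, 64, 128, 128] input with a 4x4 filter,
+    // up = down = 1, and padding chosen so the output stays 128x128 selects
+    // the small-filter path with 64x16 output tiles, giving
+    // gridSize = (128/16, 128/64, 8*64) = (8, 2, 512) blocks of 256 threads.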
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/upfirdn2d.cu b/torch_utils/ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,350 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>    { typedef double scalar_t; };
+template <> struct InternalType<float>     { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+    int t = 1 - a / b;
+    return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+
+    // Calculate thread index.
+    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+    int outY = minorBase / p.launchMinor;
+    minorBase -= outY * p.launchMinor;
+    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+    int majorBase = blockIdx.z * p.loopMajor;
+    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+        return;
+
+    // Setup Y receptive field.
+    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+    if (p.flip)
+        filterY = p.filterSize.y - 1 - filterY;
+
+    // Loop over major, minor, and X.
+    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+    {
+        int nc = major * p.sizeMinor + minor;
+        int n = nc / p.inSize.z;
+        int c = nc - n * p.inSize.z;
+        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+        {
+            // Setup X receptive field.
+            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+            if (p.flip)
+                filterX = p.filterSize.x - 1 - filterX;
+
+            // Initialize pointers.
+            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+            int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+            // Inner loop.
+            scalar_t v = 0;
+            for (int y = 0; y < h; y++)
+            {
+                for (int x = 0; x < w; x++)
+                {
+                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
+                    xp += p.inStride.x;
+                    fp += filterStepX;
+                }
+                xp += p.inStride.y - w * p.inStride.x;
+                fp += filterStepY - w * filterStepX;
+            }
+
+            // Store result.
+            v *= p.gain;
+            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+        }
+    }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template <class T, int upx, int upy, int downx, int downy, int tileOutW, int tileOutH, int loopMinor, int filterW, int filterH>
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+    __shared__ volatile scalar_t sf[filterH][filterW];
+    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+    // Calculate tile index.
+    int minorBase = blockIdx.x;
+    int tileOutY = minorBase / p.launchMinor;
+    minorBase -= tileOutY * p.launchMinor;
+    minorBase *= loopMinor;
+    tileOutY *= tileOutH;
+    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+    int majorBase = blockIdx.z * p.loopMajor;
+    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+        return;
+
+    // Load filter (flipped).
+    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+    {
+        int fy = tapIdx / filterW;
+        int fx = tapIdx - fy * filterW;
+        scalar_t v = 0;
+        if (fx < p.filterSize.x & fy < p.filterSize.y)
+        {
+            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+        }
+        sf[fy][fx] = v;
+    }
+
+    // Loop over major and X.
+    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+    {
+        int baseNC = major * p.sizeMinor + minorBase;
+        int n = baseNC / p.inSize.z;
+        int baseC = baseNC - n * p.inSize.z;
+        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+        {
+            // Load input pixels.
+            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+            int tileInX = floor_div(tileMidX, upx);
+            int tileInY = floor_div(tileMidY, upy);
+            __syncthreads();
+            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+            {
+                int relC = inIdx;
+                int relInX = relC / loopMinor;
+                int relInY = relInX / tileInW;
+                relC -= relInX * loopMinor;
+                relInX -= relInY * tileInW;
+                int c = baseC + relC;
+                int inX = tileInX + relInX;
+                int inY = tileInY + relInY;
+                scalar_t v = 0;
+                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+                sx[relInY][relInX][relC] = v;
+            }
+
+            // Loop over output pixels.
+            __syncthreads();
+            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+            {
+                int relC = outIdx;
+                int relOutX = relC / loopMinor;
+                int relOutY = relOutX / tileOutW;
+                relC -= relOutX * loopMinor;
+                relOutX -= relOutY * tileOutW;
+                int c = baseC + relC;
+                int outX = tileOutX + relOutX;
+                int outY = tileOutY + relOutY;
+
+                // Setup receptive field.
+                int midX = tileMidX + relOutX * downx;
+                int midY = tileMidY + relOutY * downy;
+                int inX = floor_div(midX, upx);
+                int inY = floor_div(midY, upy);
+                int relInX = inX - tileInX;
+                int relInY = inY - tileInY;
+                int filterX = (inX + 1) * upx - midX - 1; // flipped
+                int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+                // Inner loop.
+                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+                {
+                    scalar_t v = 0;
+                    #pragma unroll
+                    for (int y = 0; y < filterH / upy; y++)
+                        #pragma unroll
+                        for (int x = 0; x < filterW / upx; x++)
+                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+                    v *= p.gain;
+                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+                }
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 7,7>,  64,16,1, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 6,6>,  64,16,1, 1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 5,5>,  64,16,1, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 4,4>,  64,16,1, 1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 64,16,1, 3,3>,  64,16,1, 1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 24,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 20,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 16,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 12,1>, 128,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,8,1, 8,1>,  128,8,1, 1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 1,24>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 1,20>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 1,16>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 1,12>, 32,32,1, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 32,32,1, 1,8>,  32,32,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 7,7>,   16,16,8, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 6,6>,   16,16,8, 1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 5,5>,   16,16,8, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 4,4>,   16,16,8, 1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16,8, 3,3>,   16,16,8, 1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 24,1>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 20,1>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 16,1>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 12,1>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 128,1,16, 8,1>,  128,1,16, 1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 1,24>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 1,20>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 1,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 1,12>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,128,16, 1,8>,  1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 8,8>, 64,16,1, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 6,6>, 64,16,1, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 4,4>, 64,16,1, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 64,16,1, 2,2>, 64,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 8,8>, 16,16,8, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 6,6>, 16,16,8, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 4,4>, 16,16,8, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16,8, 2,2>, 16,16,8, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 24,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 20,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 16,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 12,1>, 128,8,1, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,8,1, 8,1>,  128,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 24,1>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 20,1>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 16,1>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 12,1>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 128,1,16, 8,1>,  128,1,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 1,24>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 1,20>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 1,16>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 1,12>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 32,32,1, 1,8>,  32,32,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 1,24>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 1,20>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 1,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 1,12>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,128,16, 1,8>,  1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 8,8>, 32,8,1, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 6,6>, 32,8,1, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 4,4>, 32,8,1, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 32,8,1, 2,2>, 32,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 8,8>, 8,8,8, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 6,6>, 8,8,8, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 4,4>, 8,8,8, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,8, 2,2>, 8,8,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 24,1>, 64,8,1, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 20,1>, 64,8,1, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 16,1>, 64,8,1, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 12,1>, 64,8,1, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,8,1, 8,1>,  64,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 24,1>, 64,1,8, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 20,1>, 64,1,8, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 16,1>, 64,1,8, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 12,1>, 64,1,8, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 64,1,8, 8,1>,  64,1,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 1,24>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 1,20>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 1,16>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 1,12>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 32,16,1, 1,8>,  32,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 1,24>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 1,20>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 1,16>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 1,12>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,64,8, 1,8>,  1,64,8, 1};
+    }
+    return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/torch_utils/ops/upfirdn2d.h b/torch_utils/ops/upfirdn2d.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd
--- /dev/null
+++ b/torch_utils/ops/upfirdn2d.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/torch_utils/ops/upfirdn2d.py b/torch_utils/ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..ceeac2b9834e33b7c601c28bf27f32aa91c69256 --- /dev/null +++ b/torch_utils/ops/upfirdn2d.py @@ -0,0 +1,384 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import warnings +import numpy as np +import torch +import traceback + +from .. import custom_ops +from .. import misc +from . import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None + +def _init(): + global _inited, _plugin + if not _inited: + sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. 
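+
+    Example (a minimal sketch; `x` is assumed to be a float32 `[N, C, H, W]`
+    tensor, and with this 4-tap filter the call matches `upsample2d(x, f)`):
+
+        f = setup_filter([1, 3, 3, 1])  # normalized to a [4, 4] kernel
+        y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4)  # [N, C, 2*H, 2*W]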
+ + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. 
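+    # (One Upfirdn2dCuda class is created per unique parameter combination and
+    # then reused: the class itself encodes up/down/padding, so a second call
+    # with the same arguments returns the very same cached class object.)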
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. 
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        up:          Integer upsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the output. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    upx, upy = _parse_scaling(up)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + (fw + upx - 1) // 2,
+        padx1 + (fw - upx) // 2,
+        pady0 + (fh + upy - 1) // 2,
+        pady1 + (fh - upy) // 2,
+    ]
+    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+    r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+    By default, the result is padded so that its shape is a fraction of the input.
+    User-specified padding is applied on top of that, with negative values
+    indicating cropping. Pixels outside the image are assumed to be zero.
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        down:        Integer downsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the input. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + (fw - downx + 1) // 2,
+        padx1 + (fw - downx) // 2,
+        pady0 + (fh - downy + 1) // 2,
+        pady1 + (fh - downy) // 2,
+    ]
+    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
+#----------------------------------------------------------------------------
diff --git a/torch_utils/persistence.py b/torch_utils/persistence.py
new file mode 100644
index 0000000000000000000000000000000000000000..0186cfd97bca0fcb397a7b73643520c1d1105a02
--- /dev/null
+++ b/torch_utils/persistence.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Facilities for pickling Python code alongside other data.
+ +The pickled code is automatically imported into a separate Python module +during unpickling. This way, any previously exported pickles will remain +usable even if the original code is no longer available, or if the current +version of the code is not consistent with what was originally pickled.""" + +import sys +import pickle +import io +import inspect +import copy +import uuid +import types +import dnnlib + +#---------------------------------------------------------------------------- + +_version = 6 # internal version number +_decorators = set() # {decorator_class, ...} +_import_hooks = [] # [hook_function, ...] +_module_to_src_dict = dict() # {module: src, ...} +_src_to_module_dict = dict() # {src: module, ...} + +#---------------------------------------------------------------------------- + +def persistent_class(orig_class): + r"""Class decorator that extends a given class to save its source code + when pickled. + + Example: + + from torch_utils import persistence + + @persistence.persistent_class + class MyNetwork(torch.nn.Module): + def __init__(self, num_inputs, num_outputs): + super().__init__() + self.fc = MyLayer(num_inputs, num_outputs) + ... + + @persistence.persistent_class + class MyLayer(torch.nn.Module): + ... + + When pickled, any instance of `MyNetwork` and `MyLayer` will save its + source code alongside other internal state (e.g., parameters, buffers, + and submodules). This way, any previously exported pickle will remain + usable even if the class definitions have been modified or are no + longer available. + + The decorator saves the source code of the entire Python module + containing the decorated class. It does *not* save the source code of + any imported modules. Thus, the imported modules must be available + during unpickling, also including `torch_utils.persistence` itself. + + It is ok to call functions defined in the same module from the + decorated class. However, if the decorated class depends on other + classes defined in the same module, they must be decorated as well. + This is illustrated in the above example in the case of `MyLayer`. + + It is also possible to employ the decorator just-in-time before + calling the constructor. For example: + + cls = MyLayer + if want_to_make_it_persistent: + cls = persistence.persistent_class(cls) + layer = cls(num_inputs, num_outputs) + + As an additional feature, the decorator also keeps track of the + arguments that were used to construct each instance of the decorated + class. The arguments can be queried via `obj.init_args` and + `obj.init_kwargs`, and they are automatically pickled alongside other + object state. 
A typical use case is to first unpickle a previous
+    instance of a persistent class, and then upgrade it to use the latest
+    version of the source code:
+
+        with open('old_pickle.pkl', 'rb') as f:
+            old_net = pickle.load(f)
+        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
+        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
+    """
+    assert isinstance(orig_class, type)
+    if is_persistent(orig_class):
+        return orig_class
+
+    assert orig_class.__module__ in sys.modules
+    orig_module = sys.modules[orig_class.__module__]
+    orig_module_src = _module_to_src(orig_module)
+
+    class Decorator(orig_class):
+        _orig_module_src = orig_module_src
+        _orig_class_name = orig_class.__name__
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self._init_args = copy.deepcopy(args)
+            self._init_kwargs = copy.deepcopy(kwargs)
+            assert orig_class.__name__ in orig_module.__dict__
+            _check_pickleable(self.__reduce__())
+
+        @property
+        def init_args(self):
+            return copy.deepcopy(self._init_args)
+
+        @property
+        def init_kwargs(self):
+            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
+
+        def __reduce__(self):
+            fields = list(super().__reduce__())
+            fields += [None] * max(3 - len(fields), 0)
+            if fields[0] is not _reconstruct_persistent_obj:
+                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
+                fields[0] = _reconstruct_persistent_obj # reconstruct func
+                fields[1] = (meta,) # reconstruct args
+                fields[2] = None # state dict
+            return tuple(fields)
+
+    Decorator.__name__ = orig_class.__name__
+    _decorators.add(Decorator)
+    return Decorator
+
+#----------------------------------------------------------------------------
+
+def is_persistent(obj):
+    r"""Test whether the given object or class is persistent, i.e.,
+    whether it will save its source code when pickled.
+    """
+    try:
+        if obj in _decorators:
+            return True
+    except TypeError:
+        pass
+    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
+
+#----------------------------------------------------------------------------
+
+def import_hook(hook):
+    r"""Register an import hook that is called whenever a persistent object
+    is being unpickled. A typical use case is to patch the pickled source
+    code to avoid errors and inconsistencies when the API of some imported
+    module has changed.
+
+    The hook should have the following signature:
+
+        hook(meta) -> modified meta
+
+    `meta` is an instance of `dnnlib.EasyDict` with the following fields:
+
+        type:        Type of the persistent object, e.g. `'class'`.
+        version:     Internal version number of `torch_utils.persistence`.
+        module_src:  Original source code of the Python module.
+        class_name:  Class name in the original Python module.
+        state:       Internal state of the object.
+
+    Example:
+
+        @persistence.import_hook
+        def wreck_my_network(meta):
+            if meta.class_name == 'MyNetwork':
+                print('MyNetwork is being imported. I will wreck it!')
+                meta.module_src = meta.module_src.replace("True", "False")
+            return meta
+    """
+    assert callable(hook)
+    _import_hooks.append(hook)
+
+#----------------------------------------------------------------------------
+
+def _reconstruct_persistent_obj(meta):
+    r"""Hook that is called internally by the `pickle` module to unpickle
+    a persistent object.
+ """ + meta = dnnlib.EasyDict(meta) + meta.state = dnnlib.EasyDict(meta.state) + for hook in _import_hooks: + meta = hook(meta) + assert meta is not None + + assert meta.version == _version + module = _src_to_module(meta.module_src) + + assert meta.type == 'class' + orig_class = module.__dict__[meta.class_name] + decorator_class = persistent_class(orig_class) + obj = decorator_class.__new__(decorator_class) + + setstate = getattr(obj, '__setstate__', None) + if callable(setstate): + setstate(meta.state) # pylint: disable=not-callable + else: + obj.__dict__.update(meta.state) + return obj + +#---------------------------------------------------------------------------- + +def _module_to_src(module): + r"""Query the source code of a given Python module. + """ + src = _module_to_src_dict.get(module, None) + if src is None: + src = inspect.getsource(module) + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + return src + +def _src_to_module(src): + r"""Get or create a Python module for the given source code. + """ + module = _src_to_module_dict.get(src, None) + if module is None: + module_name = "_imported_module_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + exec(src, module.__dict__) # pylint: disable=exec-used + return module + +#---------------------------------------------------------------------------- + +def _check_pickleable(obj): + r"""Check that the given object is pickleable, raising an exception if + it is not. This function is expected to be considerably more efficient + than actually pickling the object. + """ + def recurse(obj): + if isinstance(obj, (list, tuple, set)): + return [recurse(x) for x in obj] + if isinstance(obj, dict): + return [[recurse(x), recurse(y)] for x, y in obj.items()] + if isinstance(obj, (str, int, float, bool, bytes, bytearray)): + return None # Python primitive types are pickleable. + if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/torch_utils/training_stats.py b/torch_utils/training_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..26f467f9eaa074ee13de1cf2625cd7da44880847 --- /dev/null +++ b/torch_utils/training_stats.py @@ -0,0 +1,268 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . 
import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. None = single-process. +_sync_called = False # Has _sync() been called yet? +_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor + +#---------------------------------------------------------------------------- + +def init_multiprocessing(rank, sync_device): + r"""Initializes `torch_utils.training_stats` for collecting statistics + across multiple processes. + + This function must be called after + `torch.distributed.init_process_group()` and before `Collector.update()`. + The call is not necessary if multi-process collection is not needed. + + Args: + rank: Rank of the current process. + sync_device: PyTorch device to use for inter-process + communication, or None to disable multi-process + collection. Typically `torch.device('cuda', rank)`. + """ + global _rank, _sync_device + assert not _sync_called + _rank = rank + _sync_device = sync_device + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def report(name, value): + r"""Broadcasts the given set of scalars to all interested instances of + `Collector`, across device and process boundaries. + + This function is expected to be extremely cheap and can be safely + called from anywhere in the training loop, loss function, or inside a + `torch.nn.Module`. + + Warning: The current implementation expects the set of unique names to + be consistent across processes. Please make sure that `report()` is + called at least once for each unique name by each process, and in the + same order. If a given process has no scalars to broadcast, it can do + `report(name, [])` (empty list). + + Args: + name: Arbitrary string specifying the name of the statistic. + Averages are accumulated separately for each unique name. + value: Arbitrary set of scalars. Can be a list, tuple, + NumPy array, PyTorch tensor, or Python scalar. + + Returns: + The same `value` that was passed in. + """ + if name not in _counters: + _counters[name] = dict() + + elems = torch.as_tensor(value) + if elems.numel() == 0: + return value + + elems = elems.detach().flatten().to(_reduce_dtype) + moments = torch.stack([ + torch.ones_like(elems).sum(), + elems.sum(), + elems.square().sum(), + ]) + assert moments.ndim == 1 and moments.shape[0] == _num_moments + moments = moments.to(_counter_dtype) + + device = moments.device + if device not in _counters[name]: + _counters[name][device] = torch.zeros_like(moments) + _counters[name][device].add_(moments) + return value + +#---------------------------------------------------------------------------- + +def report0(name, value): + r"""Broadcasts the given set of scalars by the first process (`rank = 0`), + but ignores any scalars provided by the other processes. + See `report()` for further details. 
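+
+    Example (hypothetical statistic name):
+
+        training_stats.report0('Progress/tick', cur_tick)  # counted once, not once per rank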
+ """ + report(name, value if _rank == 0 else []) + return value + +#---------------------------------------------------------------------------- + +class Collector: + r"""Collects the scalars broadcasted by `report()` and `report0()` and + computes their long-term averages (mean and standard deviation) over + user-defined periods of time. + + The averages are first collected into internal counters that are not + directly visible to the user. They are then copied to the user-visible + state as a result of calling `update()` and can then be queried using + `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the + internal counters for the next round, so that the user-visible state + effectively reflects averages collected between the last two calls to + `update()`. + + Args: + regex: Regular expression defining which statistics to + collect. The default is to collect everything. + keep_previous: Whether to retain the previous averages if no + scalars were collected on a given round + (default: True). + """ + def __init__(self, regex='.*', keep_previous=True): + self._regex = re.compile(regex) + self._keep_previous = keep_previous + self._cumulative = dict() + self._moments = dict() + self.update() + self._moments.clear() + + def names(self): + r"""Returns the names of all statistics broadcasted so far that + match the regular expression specified at construction time. + """ + return [name for name in _counters if self._regex.fullmatch(name)] + + def update(self): + r"""Copies current values of the internal counters to the + user-visible state and resets them for the next round. + + If `keep_previous=True` was specified at construction time, the + operation is skipped for statistics that have received no scalars + since the last update, retaining their previous averages. + + This method performs a number of GPU-to-CPU transfers and one + `torch.distributed.all_reduce()`. It is intended to be called + periodically in the main training loop, typically once every + N training steps. + """ + if not self._keep_previous: + self._moments.clear() + for name, cumulative in _sync(self.names()): + if name not in self._cumulative: + self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + delta = cumulative - self._cumulative[name] + self._cumulative[name].copy_(cumulative) + if float(delta[0]) != 0: + self._moments[name] = delta + + def _get_delta(self, name): + r"""Returns the raw moments that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + assert self._regex.fullmatch(name) + if name not in self._moments: + self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + return self._moments[name] + + def num(self, name): + r"""Returns the number of scalars that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + delta = self._get_delta(name) + return int(delta[0]) + + def mean(self, name): + r"""Returns the mean of the scalars that were accumulated for the + given statistic between the last two calls to `update()`, or NaN if + no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0: + return float('nan') + return float(delta[1] / delta[0]) + + def std(self, name): + r"""Returns the standard deviation of the scalars that were + accumulated for the given statistic between the last two calls to + `update()`, or NaN if no scalars were collected. 
+ """ + delta = self._get_delta(name) + if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): + return float('nan') + if int(delta[0]) == 1: + return float(0) + mean = float(delta[1] / delta[0]) + raw_var = float(delta[2] / delta[0]) + return np.sqrt(max(raw_var - np.square(mean), 0)) + + def as_dict(self): + r"""Returns the averages accumulated between the last two calls to + `update()` as an `dnnlib.EasyDict`. The contents are as follows: + + dnnlib.EasyDict( + NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), + ... + ) + """ + stats = dnnlib.EasyDict() + for name in self.names(): + stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) + return stats + + def __getitem__(self, name): + r"""Convenience getter. + `collector[name]` is a synonym for `collector.mean(name)`. + """ + return self.mean(name) + +#---------------------------------------------------------------------------- + +def _sync(names): + r"""Synchronize the global cumulative counters across devices and + processes. Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. + deltas = [] + device = _sync_device if _sync_device is not None else torch.device('cpu') + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. + return [(name, _cumulative[name]) for name in names] + +#---------------------------------------------------------------------------- diff --git a/w_s_converter.py b/w_s_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..91ba9630bd94ab741017e0f6df6b2a76de960f35 --- /dev/null +++ b/w_s_converter.py @@ -0,0 +1,188 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Generate images using pretrained network pickle.""" + +import os +import re +import random +import math +import time +import click +import legacy +from typing import List, Optional + +import cv2 +import clip +import dnnlib +import numpy as np +import torchvision +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +import torch +from torch import linalg as LA +import torch.nn.functional as F +from PIL import Image +import matplotlib.pyplot as plt + +from torch_utils import misc +from torch_utils import persistence +from torch_utils.ops import conv2d_resample +from torch_utils.ops import upfirdn2d +from torch_utils.ops import bias_act +from torch_utils.ops import fma + + +def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs): + misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim]) + w_iter = iter(ws.unbind(dim=1)) + dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 + memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format + if fused_modconv is None: + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1) + + # Input. + if self.in_channels == 0: + x = self.const.to(dtype=dtype, memory_format=memory_format) + x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1]) + else: + misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2]) + x = x.to(dtype=dtype, memory_format=memory_format) + + # Main layers. + if self.in_channels == 0: + x = self.conv1(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + elif self.architecture == 'resnet': + y = self.skip(x, gain=np.sqrt(0.5)) + x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs) + x = y.add_(x) + else: + x = self.conv0(x, next(w_iter)[...,:shapes[0]], fused_modconv=fused_modconv, **layer_kwargs) + x = self.conv1(x, next(w_iter)[...,:shapes[1]], fused_modconv=fused_modconv, **layer_kwargs) + + # ToRGB. 
+    if img is not None:
+        misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+        img = upfirdn2d.upsample2d(img, self.resample_filter)
+    if self.is_last or self.architecture == 'skip':
+        y = self.torgb(x, next(w_iter)[...,:shapes[2]], fused_modconv=fused_modconv)
+        y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        img = img.add_(y) if img is not None else y
+
+    assert x.dtype == dtype
+    assert img is None or img.dtype == torch.float32
+    return x, img
+
+
+def unravel_index(index, shape):
+    out = []
+    for dim in reversed(shape):
+        out.append(index % dim)
+        index = index // dim
+    return tuple(reversed(out))
+
+
+def num_range(s: str) -> List[int]:
+    '''Accept either a comma-separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
+
+    range_re = re.compile(r'^(\d+)-(\d+)$')
+    m = range_re.match(s)
+    if m:
+        return list(range(int(m.group(1)), int(m.group(2))+1))
+    vals = s.split(',')
+    return [int(x) for x in vals]
+
+
+@click.command()
+@click.pass_context
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--seeds', type=num_range, help='List of random seeds')
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
+@click.option('--projected_s', help='Projection result file', type=str, metavar='FILE')
+@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
+
+def generate_images(
+    ctx: click.Context,
+    network_pkl: str,
+    seeds: Optional[List[int]],
+    truncation_psi: float,
+    noise_mode: str,
+    outdir: str,
+    class_idx: Optional[int],
+    projected_w: Optional[str],
+    projected_s: Optional[str]
+):
+
+    print('Loading networks from "%s"...' % network_pkl)
+    # Use GPU if available.
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    with dnnlib.util.open_url(network_pkl) as f:
+        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+    os.makedirs(outdir, exist_ok=True)
+
+    # Convert the projected W+ latents to S-space styles.
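+    # Each block's learned affine maps its slice of the W+ code to that
+    # layer's style vector. The styles are collected into a zero-padded
+    # [1, 26, 512] tensor, the affines are then replaced with
+    # torch.nn.Identity, and the result is saved to <outdir>/input.npz.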
+ for i in G.parameters(): + i.requires_grad = True + + ws = np.load(projected_w)['w'] + ws = torch.tensor(ws, device=device) + + block_ws = [] + with torch.autograd.profiler.record_function('split_ws'): + misc.assert_shape(ws, [None, G.synthesis.num_ws, G.synthesis.w_dim]) + ws = ws.to(torch.float32) + + + w_idx = 0 + for res in G.synthesis.block_resolutions: + block = getattr(G.synthesis, f'b{res}') + block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) + w_idx += block.num_conv + + + styles = torch.zeros(1,26,512, device=device) + styles_idx = 0 + temp_shapes = [] + for res, cur_ws in zip(G.synthesis.block_resolutions, block_ws): + block = getattr(G.synthesis, f'b{res}') + + if res == 4: + temp_shape = (block.conv1.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + styles[0,:1,:] = block.conv1.affine(cur_ws[0,:1,:]) + styles[0,1:2,:] = block.torgb.affine(cur_ws[0,1:2,:]) + + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + styles_idx += 2 + else: + temp_shape = (block.conv0.affine.weight.shape[0], block.conv1.affine.weight.shape[0], block.torgb.affine.weight.shape[0]) + styles[0,styles_idx:styles_idx+1,:temp_shape[0]] = block.conv0.affine(cur_ws[0,:1,:]) + styles[0,styles_idx+1:styles_idx+2,:temp_shape[1]] = block.conv1.affine(cur_ws[0,1:2,:]) + styles[0,styles_idx+2:styles_idx+3,:temp_shape[2]] = block.torgb.affine(cur_ws[0,2:3,:]) + + block.conv0.affine = torch.nn.Identity() + block.conv1.affine = torch.nn.Identity() + block.torgb.affine = torch.nn.Identity() + styles_idx += 3 + temp_shapes.append(temp_shape) + + + styles = styles.detach() + np.savez(f'{outdir}/input.npz', s=styles.cpu().numpy()) + + +if __name__ == "__main__": + generate_images()
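
For reference, a minimal sketch of loading the converter's output (the paths are illustrative, and the [1, 26, 512] layout assumes a 1024x1024 generator):

    import numpy as np
    s = np.load('out/input.npz')['s']  # 'out' = hypothetical --outdir
    print(s.shape)                     # (1, 26, 512): zero-padded S-space styles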