|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Miscellaneous utility classes and functions.""" |
|
|
|
import ctypes |
|
import fnmatch |
|
import importlib |
|
import inspect |
|
import numpy as np |
|
import os |
|
import shutil |
|
import sys |
|
import types |
|
import io |
|
import pickle |
|
import re |
|
import requests |
|
import html |
|
import hashlib |
|
import glob |
|
import tempfile |
|
import urllib |
|
import urllib.request |
|
import uuid |
|
|
|
from distutils.util import strtobool |
|
from typing import Any, List, Tuple, Union |
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max=1.0):
    """Balance a generator loss against a reconstruction loss.

    Computes ``||grad(recon_loss)|| / (||grad(g_loss)|| + 1e-4)`` where both
    gradients are taken with respect to ``last_layer``. The ratio is clamped
    to ``[0, disc_weight_max]`` and detached from the autograd graph. Both
    gradient calls use ``retain_graph=True`` so the losses can still be
    back-propagated afterwards.
    """
    def _grad_norm(loss):
        # Gradient of `loss` w.r.t. the reference tensor, keeping the graph alive.
        grad = torch.autograd.grad(loss, last_layer, retain_graph=True)[0]
        return torch.norm(grad)

    recon_norm = _grad_norm(recon_loss)
    gen_norm = _grad_norm(g_loss)

    # The small epsilon avoids division by zero when the generator gradient vanishes.
    ratio = recon_norm / (gen_norm + 1e-4)
    return ratio.clamp(0.0, disc_weight_max).detach()
|
|
|
|
|
class EasyDict(dict):
    """A ``dict`` whose items can also be read, written and deleted as attributes.

    ``d.key`` reads ``d["key"]``; a missing key is reported as
    ``AttributeError`` so ``getattr``/``hasattr`` behave normally.
    ``d.key = v`` stores ``d["key"] = v`` and ``del d.key`` removes the item
    (a missing key propagates the ``KeyError`` from the dict).
    """

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails, so real dict
        # methods (keys, items, ...) still take precedence over items.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        self.__delitem__(name)
|
|
|
|
|
class Logger(object):
    """Mirror console output to an optional log file.

    On construction the current ``sys.stdout`` and ``sys.stderr`` are saved
    and both are replaced by this object, so text written to either stream is
    forwarded to the saved stdout. It is also copied to ``file_name`` when one
    is given. With ``should_flush`` both sinks are flushed after every
    non-empty write. ``close()`` (also run on context-manager exit) restores
    the saved streams and closes the file.
    """

    def __init__(self,
                 file_name: str = None,
                 file_mode: str = "w",
                 should_flush: bool = True):
        # Optional log file that receives a copy of every write.
        self.file = open(file_name, file_mode) if file_name is not None else None
        self.should_flush = should_flush

        # Keep handles to the real streams so close() can put them back.
        self.stdout, self.stderr = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write `text` (decoded first if given as bytes) to the log file and stdout."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0:
            # Nothing to emit; skip the flush too.
            return

        if self.file is not None:
            self.file.write(text)
        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush the log file (if open) and the saved stdout."""
        if self.file is not None:
            self.file.flush()
        self.stdout.flush()

    def close(self) -> None:
        """Flush, restore the saved streams, and close the log file if open."""
        self.flush()

        # Only undo the redirection if nothing else has replaced us since.
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
|
|
|
|
|
|
|
|
|
|
|
# Explicit cache root set via set_cache_dir(); None means "use the
# environment-based fallbacks in make_cache_dir_path()".
_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Set the root directory used by make_cache_dir_path().

    Passing ``None`` restores the environment-based lookup.
    """
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    """Join `paths` onto the dnnlib cache root and return the resulting path.

    The root is the first available of: the directory given to
    ``set_cache_dir``, ``$DNNLIB_CACHE_DIR``, ``$HOME/.cache/dnnlib``,
    ``%USERPROFILE%/.cache/dnnlib`` and ``<system temp dir>/.cache/dnnlib``.
    Nothing is created on disk.
    """
    env = os.environ
    if _dnnlib_cache_dir is not None:
        root_parts = (_dnnlib_cache_dir,)
    elif 'DNNLIB_CACHE_DIR' in env:
        root_parts = (env['DNNLIB_CACHE_DIR'],)
    elif 'HOME' in env:
        root_parts = (env['HOME'], '.cache', 'dnnlib')
    elif 'USERPROFILE' in env:
        root_parts = (env['USERPROFILE'], '.cache', 'dnnlib')
    else:
        root_parts = (tempfile.gettempdir(), '.cache', 'dnnlib')
    return os.path.join(*root_parts, *paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration as a human-readable string.

    The value is first rounded to whole seconds with ``np.rint``
    (round-half-to-even). The output then uses the appropriate units, e.g.
    ``"42s"``, ``"3m 07s"``, ``"2h 05m 09s"`` or, from one day upwards,
    ``"1d 04h 30m"`` (seconds dropped).
    """
    total = int(np.rint(seconds))
    total_minutes, secs = divmod(total, 60)
    total_hours, mins = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)

    if total < 60:
        return "{0}s".format(total)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(total_minutes, secs)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(total_hours, mins, secs)
    return "{0}d {1:02}h {2:02}m".format(days, hours, mins)
|
|
|
|
|
def format_time_brief(seconds: Union[int, float]) -> str:
    """Render a duration using at most the two most significant units.

    The value is rounded to whole seconds with ``np.rint``. The output is
    ``"42s"``, ``"3m 07s"``, ``"2h 05m"`` (seconds dropped) or ``"1d 04h"``
    (minutes and seconds dropped).
    """
    total = int(np.rint(seconds))
    total_minutes, secs = divmod(total, 60)
    total_hours, mins = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)

    if total < 60:
        return "{0}s".format(total)
    if total < 60 * 60:
        return "{0}m {1:02}s".format(total_minutes, secs)
    if total < 24 * 60 * 60:
        return "{0}h {1:02}m".format(total_hours, mins)
    return "{0}d {1:02}h".format(days, hours)
|
|
|
|
|
def ask_yes_no(question: str) -> bool:
    """Ask the user `question` until a recognisable yes/no answer is given.

    Answers are matched case-insensitively after stripping surrounding
    whitespace. ``y``/``yes``/``t``/``true``/``on``/``1`` give ``True`` and
    ``n``/``no``/``f``/``false``/``off``/``0`` give ``False``; anything else
    re-prompts. These are the spellings ``distutils.util.strtobool``
    accepted. The parsing is done locally because distutils was removed in
    Python 3.12, and strtobool returned ``1``/``0`` instead of a real bool.
    """
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    while True:
        print("{0} [y/n]".format(question))
        answer = input().strip().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
|
|
|
|
|
def tuple_product(t: Tuple) -> Any:
    """Return the product of all elements of `t` (``1`` for an empty tuple)."""
    total = 1
    for factor in t:
        total *= factor
    return total
|
|
|
|
|
# NumPy type name -> ctypes type; get_dtype_and_ctype() checks that the
# paired types have the same size in bytes.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Resolve `type_obj` to a matching ``(numpy dtype, ctypes type)`` pair.

    `type_obj` may be a type-name string such as ``"float32"``, or any object
    exposing that name through ``__name__`` (e.g. ``np.float32``) or, failing
    that, ``name``. Only names present in ``_str_to_ctype`` are supported.

    Raises:
        RuntimeError: if no type name can be extracted from `type_obj`.
        AssertionError: if the name is unsupported or the two types differ in size.
    """
    if isinstance(type_obj, str):
        type_str = type_obj
    else:
        # Prefer __name__ over name, mirroring the lookup order for type objects.
        for attr in ("__name__", "name"):
            if hasattr(type_obj, attr):
                type_str = getattr(type_obj, attr)
                break
        else:
            raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype

    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]
    assert dtype.itemsize == ctypes.sizeof(ctype)
    return dtype, ctype
|
|
|
|
|
def is_pickleable(obj: Any) -> bool:
    """Return True if `obj` can be serialized with the standard pickle module.

    Any ``Exception`` raised while pickling (PicklingError, TypeError,
    AttributeError, RecursionError, ...) is treated as "not pickleable".
    KeyboardInterrupt and SystemExit are deliberately not caught.
    """
    try:
        pickle.dumps(obj)
    except Exception:
        return False
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.

    Returns the module and the object name (original name with module part removed).

    The prefixes ``np.`` and ``tf.`` are expanded to ``numpy.`` and
    ``tensorflow.``. Every split of the dotted name into
    ``(module path, attribute path)`` is tried, longest module path first.

    Raises:
        ImportError: if a candidate module exists but fails to import (or a
            parent package of a candidate is missing), or if no candidate
            resolves at all.
        AttributeError: propagated from ``get_obj_from_module`` when a
            candidate module imports but lacks the requested attribute path.
    """
    # Expand common aliases. The dot is escaped so that e.g. "npz.foo" is not
    # rewritten to "numpy..foo".
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # All (module, attribute-path) splits, longest module path first.
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:]))
                  for i in range(len(parts), 0, -1)]

    # 1) Happy path: the first split whose module imports and resolves the object.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)
            get_obj_from_module(module, local_obj_name)
            return module, local_obj_name
        except Exception:
            pass

    # 2) Surface real import failures: anything other than "this exact module
    #    does not exist" (e.g. an error inside the module, or a missing parent).
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)
        except ImportError as err:
            if not str(err).startswith("No module named '" + module_name + "'"):
                raise

    # 3) A module imports but the attribute is missing: let the
    #    AttributeError from get_obj_from_module propagate.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)
            get_obj_from_module(module, local_obj_name)
        except ImportError:
            pass

    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object.

    An empty `obj_name` returns `module` itself; otherwise each dot-separated
    part is resolved with ``getattr`` (AttributeError propagates).
    """
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
|
|
|
|
|
def get_obj_by_name(name: str) -> Any:
    """Resolve a dotted path such as ``"os.path.join"`` to the object it names."""
    return get_obj_from_module(*get_module_from_obj_name(name))
|
|
|
|
|
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Look up the callable named by the dotted path `func_name` and invoke it.

    Positional and keyword arguments are forwarded unchanged. Raises
    AssertionError if `func_name` is missing or does not name a callable.
    """
    assert func_name is not None
    target = get_obj_by_name(func_name)
    assert callable(target)
    return target(*args, **kwargs)
|
|
|
|
|
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Instantiate the class named by the dotted path `class_name`.

    Thin alias of ``call_func_by_name``: the class is looked up and called
    with the forwarded positional and keyword arguments.
    """
    return call_func_by_name(*args, **kwargs, func_name=class_name)
|
|
|
|
|
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory of the source file of the module that defines `obj_name`."""
    module = get_module_from_obj_name(obj_name)[0]
    module_file = inspect.getfile(module)
    return os.path.dirname(module_file)
|
|
|
|
|
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.

    This is a name check: the object must be callable, and its ``__name__``
    must be bound in the namespace of the module named by its ``__module__``.
    Callables without ``__name__``/``__module__`` (e.g. ``functools.partial``
    objects), or whose module is not in ``sys.modules``, yield ``False``
    instead of raising.
    """
    if not callable(obj):
        return False
    name = getattr(obj, "__name__", None)
    module = sys.modules.get(getattr(obj, "__module__", None))
    if name is None or module is None:
        return False
    return name in module.__dict__
|
|
|
|
|
def get_top_level_function_name(obj: Any) -> str:
    """Return ``"<module>.<name>"`` for a top-level function.

    For functions defined in ``__main__``, the running script's file name
    (without extension) is used as the module part. Raises AssertionError if
    `obj` is not a top-level function.
    """
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        main_file = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(main_file))[0]
    return module_name + "." + obj.__name__
|
|
|
|
|
|
|
|
|
|
|
|
|
def list_dir_recursively_with_ignore(
        dir_path: str,
        ignores: List[str] = None,
        add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Recursively list files under `dir_path`, skipping ignored names.

    Each entry of `ignores` is an ``fnmatch`` pattern applied to bare file
    and directory names. Matching directories are pruned, so nothing below
    them is visited. Returns ``(absolute_path, relative_path)`` tuples, where
    the relative path is taken against `dir_path`. With
    `add_base_to_relative`, the relative path is prefixed with the base name
    of `dir_path`.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = [] if ignores is None else ignores

    def _ignored(name):
        # True if `name` matches any ignore pattern.
        return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)

    result = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        # In-place slice assignment so os.walk does not descend into pruned dirs.
        dirs[:] = [d for d in dirs if not _ignored(d)]

        for name in files:
            if _ignored(name):
                continue
            abs_path = os.path.join(root, name)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            result.append((abs_path, rel_path))

    return result
|
|
|
|
|
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.

    Will create all necessary directories. Missing parent directories of each
    destination are created first (already-existing ones are fine). A
    destination without a directory part (a bare file name) is copied into
    the current working directory.
    """
    for file in files:
        src, dst = file[0], file[1]
        target_dir_name = os.path.dirname(dst)

        # dirname() is "" for a bare file name, and os.makedirs("") raises.
        # exist_ok avoids the exists()/makedirs() race.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)

        shutil.copyfile(src, dst)
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    The object must be a ``str`` containing ``://``. Its parsed form, and the
    parsed form of its site root (``urljoin(obj, "/")``), must both have a
    scheme and a network location containing a dot, so e.g.
    ``http://localhost/`` is rejected. With `allow_file_urls`, any string
    starting with ``file://`` is accepted without further checks.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = urllib.parse.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except ValueError:
        # urllib.parse raises ValueError for malformed input such as an
        # unterminated IPv6 literal ("http://[::1").
        return False
    return True
|
|
|
|
|
def open_url(url: str,
             cache_dir: str = None,
             num_attempts: int = 10,
             verbose: bool = True,
             return_filename: bool = False,
             cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Resolution order:
      * No ``scheme://`` prefix: `url` is a local path and is opened directly.
      * ``file://`` URL: converted to a local path and opened directly. A
        leading slash before a Windows drive letter (``/C:``) is dropped.
      * Otherwise `url` must pass ``is_url``. If caching is enabled and
        exactly one cached file matches, that file is used. Otherwise the
        data is downloaded, with up to `num_attempts` attempts.

    Args:
        url: Local path, ``file://`` URL or remote URL.
        cache_dir: Download cache directory; defaults to ``make_cache_dir_path('downloads')``.
        num_attempts: Maximum number of download attempts (must be >= 1).
        verbose: Print download progress to stdout.
        return_filename: Return a path instead of a file object; requires ``cache=True``.
        cache: Read from / write to the download cache.

    Returns:
        A path when `return_filename` is set. Otherwise, an open binary file
        for local or cached data, or an ``io.BytesIO`` holding freshly
        downloaded data.

    Raises:
        AssertionError: on invalid argument combinations or an unrecognized URL.
        Exception: the last download error once all attempts are exhausted.
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Local path (no URL scheme): open directly.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # file:// URL: map to a local path.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    # Cache entries are named "<md5(url)>_<sanitized name>". The directory is
    # escaped so glob metacharacters in its path (e.g. "[") match literally.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(glob.escape(cache_dir), url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive HTML page rather
                    # than the payload. Decode leniently: a small binary file
                    # is generally not valid UTF-8, and a strict decode would
                    # fail every attempt.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get(
                                "Set-Cookie", ""):
                            links = [
                                html.unescape(link)
                                for link in content_str.split('"')
                                if "export=download" in link
                            ]
                            if len(links) == 1:
                                # Retry against the confirmation link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError(
                                "Google Drive download quota exceeded -- please try again later"
                            )

                    match = re.search(
                        r'filename="([^"]*)"',
                        res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # KeyboardInterrupt / SystemExit are BaseException and
                # propagate immediately; other errors are retried.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    if cache:
        # Write to a unique temp file, then atomically move it into place.
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(
            cache_dir,
            "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)
        if return_filename:
            return cache_file

    assert not return_filename
    return io.BytesIO(url_data)
|
|
|
class InfiniteSampler(torch.utils.data.Sampler): |
|
|
|
def __init__(self, |
|
dataset, |
|
rank=0, |
|
num_replicas=1, |
|
shuffle=True, |
|
seed=0, |
|
window_size=0.5): |
|
assert len(dataset) > 0 |
|
assert num_replicas > 0 |
|
assert 0 <= rank < num_replicas |
|
assert 0 <= window_size <= 1 |
|
super().__init__(dataset) |
|
self.dataset = dataset |
|
self.rank = rank |
|
self.num_replicas = num_replicas |
|
self.shuffle = shuffle |
|
self.seed = seed |
|
self.window_size = window_size |
|
|
|
def __iter__(self): |
|
order = np.arange(len(self.dataset)) |
|
rnd = None |
|
window = 0 |
|
if self.shuffle: |
|
rnd = np.random.RandomState(self.seed) |
|
rnd.shuffle(order) |
|
window = int(np.rint(order.size * self.window_size)) |
|
|
|
idx = 0 |
|
while True: |
|
i = idx % order.size |
|
if idx % self.num_replicas == self.rank: |
|
yield order[i] |
|
if window >= 2: |
|
j = (i - rnd.randint(window)) % order.size |
|
order[i], order[j] = order[j], order[i] |
|
idx += 1 |
|
|
|
def requires_grad(model, flag=True):
    """Set ``requires_grad = flag`` on every tensor yielded by ``model.parameters()``."""
    for param in model.parameters():
        param.requires_grad = flag
|
|