Added base files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Demo.ipynb +0 -0
- align_images.py +57 -0
- config.py +22 -0
- dnnlib/__init__.py +20 -0
- dnnlib/__pycache__/__init__.cpython-36.pyc +0 -0
- dnnlib/__pycache__/__init__.cpython-37.pyc +0 -0
- dnnlib/__pycache__/util.cpython-36.pyc +0 -0
- dnnlib/__pycache__/util.cpython-37.pyc +0 -0
- dnnlib/submission/__init__.py +9 -0
- dnnlib/submission/__pycache__/__init__.cpython-36.pyc +0 -0
- dnnlib/submission/__pycache__/__init__.cpython-37.pyc +0 -0
- dnnlib/submission/__pycache__/run_context.cpython-36.pyc +0 -0
- dnnlib/submission/__pycache__/run_context.cpython-37.pyc +0 -0
- dnnlib/submission/__pycache__/submit.cpython-36.pyc +0 -0
- dnnlib/submission/__pycache__/submit.cpython-37.pyc +0 -0
- dnnlib/submission/_internal/run.py +45 -0
- dnnlib/submission/run_context.py +99 -0
- dnnlib/submission/submit.py +290 -0
- dnnlib/tflib/__init__.py +16 -0
- dnnlib/tflib/__pycache__/__init__.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/__init__.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/network.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/network.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc +0 -0
- dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc +0 -0
- dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc +0 -0
- dnnlib/tflib/autosummary.py +184 -0
- dnnlib/tflib/network.py +628 -0
- dnnlib/tflib/optimizer.py +214 -0
- dnnlib/tflib/tfutil.py +242 -0
- dnnlib/util.py +408 -0
- encode_images.py +242 -0
- encoder/__init__.py +0 -0
- encoder/__pycache__/__init__.cpython-36.pyc +0 -0
- encoder/__pycache__/__init__.cpython-37.pyc +0 -0
- encoder/__pycache__/generator_model.cpython-36.pyc +0 -0
- encoder/__pycache__/generator_model.cpython-37.pyc +0 -0
- encoder/__pycache__/perceptual_model.cpython-36.pyc +0 -0
- encoder/__pycache__/perceptual_model.cpython-37.pyc +0 -0
- encoder/generator_model.py +137 -0
- encoder/perceptual_model.py +304 -0
- ffhq_dataset/__init__.py +0 -0
- ffhq_dataset/__pycache__/__init__.cpython-36.pyc +0 -0
- ffhq_dataset/__pycache__/__init__.cpython-37.pyc +0 -0
- ffhq_dataset/__pycache__/face_alignment.cpython-36.pyc +0 -0
- ffhq_dataset/__pycache__/face_alignment.cpython-37.pyc +0 -0
- ffhq_dataset/__pycache__/landmarks_detector.cpython-36.pyc +0 -0
Demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
align_images.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import bz2
|
4 |
+
import argparse
|
5 |
+
from keras.utils import get_file
|
6 |
+
from ffhq_dataset.face_alignment import image_align
|
7 |
+
from ffhq_dataset.landmarks_detector import LandmarksDetector
|
8 |
+
import multiprocessing
|
9 |
+
|
10 |
+
def unpack_bz2(src_path):
    """Decompress a ``.bz2`` archive next to itself.

    Args:
        src_path: Path to a file whose name ends in ".bz2".

    Returns:
        Path of the decompressed file (src_path with the ".bz2" suffix removed).
    """
    dst_path = src_path[:-4]  # strip the trailing ".bz2"
    # Use context managers so both handles are closed even on error
    # (the original never closed the BZ2File handle).
    with bz2.BZ2File(src_path) as archive:
        data = archive.read()
    with open(dst_path, 'wb') as fp:
        fp.write(data)
    return dst_path
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == "__main__":
|
19 |
+
"""
|
20 |
+
Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
|
21 |
+
python align_images.py /raw_images /aligned_images
|
22 |
+
"""
|
23 |
+
parser = argparse.ArgumentParser(description='Align faces from input images', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
24 |
+
parser.add_argument('raw_dir', help='Directory with raw images for face alignment')
|
25 |
+
parser.add_argument('aligned_dir', help='Directory for storing aligned images')
|
26 |
+
parser.add_argument('--output_size', default=1024, help='The dimension of images for input to the model', type=int)
|
27 |
+
parser.add_argument('--x_scale', default=1, help='Scaling factor for x dimension', type=float)
|
28 |
+
parser.add_argument('--y_scale', default=1, help='Scaling factor for y dimension', type=float)
|
29 |
+
parser.add_argument('--em_scale', default=0.1, help='Scaling factor for eye-mouth distance', type=float)
|
30 |
+
parser.add_argument('--use_alpha', default=False, help='Add an alpha channel for masking', type=bool)
|
31 |
+
|
32 |
+
args, other_args = parser.parse_known_args()
|
33 |
+
|
34 |
+
landmarks_model_path = unpack_bz2("shape_predictor_68_face_landmarks.dat.bz2")
|
35 |
+
RAW_IMAGES_DIR = args.raw_dir
|
36 |
+
ALIGNED_IMAGES_DIR = args.aligned_dir
|
37 |
+
|
38 |
+
landmarks_detector = LandmarksDetector(landmarks_model_path)
|
39 |
+
for img_name in os.listdir(RAW_IMAGES_DIR):
|
40 |
+
print('Aligning %s ...' % img_name)
|
41 |
+
try:
|
42 |
+
raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
|
43 |
+
fn = face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], 1)
|
44 |
+
if os.path.isfile(fn):
|
45 |
+
continue
|
46 |
+
print('Getting landmarks...')
|
47 |
+
for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
|
48 |
+
try:
|
49 |
+
print('Starting face alignment...')
|
50 |
+
face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
|
51 |
+
aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
|
52 |
+
image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=args.output_size, x_scale=args.x_scale, y_scale=args.y_scale, em_scale=args.em_scale, alpha=args.use_alpha)
|
53 |
+
print('Wrote result %s' % aligned_face_path)
|
54 |
+
except:
|
55 |
+
print("Exception in face alignment!")
|
56 |
+
except:
|
57 |
+
print("Exception in landmark detection!")
|
config.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Global configuration."""
|
9 |
+
|
10 |
+
#----------------------------------------------------------------------------
|
11 |
+
# Paths.
|
12 |
+
|
13 |
+
# Root directory where run results are written.
result_dir = 'results'
# Root directory for datasets.
data_dir = 'datasets'
# Directory for cached/intermediate files.
cache_dir = 'cache'
# Directory names to skip when copying sources into a run dir.
run_dir_ignore = ['results', 'datasets', 'cache']

# experimental - replace Dense layers with TreeConnect
use_treeconnect = False
# NOTE(review): presumably the minimum layer size at which TreeConnect kicks
# in -- confirm against the code that reads this flag.
treeconnect_threshold = 1024

#----------------------------------------------------------------------------
|
dnnlib/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
from . import submission
|
9 |
+
|
10 |
+
from .submission.run_context import RunContext
|
11 |
+
|
12 |
+
from .submission.submit import SubmitTarget
|
13 |
+
from .submission.submit import PathType
|
14 |
+
from .submission.submit import SubmitConfig
|
15 |
+
from .submission.submit import get_path_from_template
|
16 |
+
from .submission.submit import submit_run
|
17 |
+
|
18 |
+
from .util import EasyDict
|
19 |
+
|
20 |
+
submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
|
dnnlib/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (468 Bytes). View file
|
|
dnnlib/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (498 Bytes). View file
|
|
dnnlib/__pycache__/util.cpython-36.pyc
ADDED
Binary file (12.1 kB). View file
|
|
dnnlib/__pycache__/util.cpython-37.pyc
ADDED
Binary file (12.1 kB). View file
|
|
dnnlib/submission/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
from . import run_context
|
9 |
+
from . import submit
|
dnnlib/submission/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (188 Bytes). View file
|
|
dnnlib/submission/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (192 Bytes). View file
|
|
dnnlib/submission/__pycache__/run_context.cpython-36.pyc
ADDED
Binary file (4.35 kB). View file
|
|
dnnlib/submission/__pycache__/run_context.cpython-37.pyc
ADDED
Binary file (4.35 kB). View file
|
|
dnnlib/submission/__pycache__/submit.cpython-36.pyc
ADDED
Binary file (9.19 kB). View file
|
|
dnnlib/submission/__pycache__/submit.cpython-37.pyc
ADDED
Binary file (9.19 kB). View file
|
|
dnnlib/submission/_internal/run.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Helper for launching run functions in computing clusters.
|
9 |
+
|
10 |
+
During the submit process, this file is copied to the appropriate run dir.
|
11 |
+
When the job is launched in the cluster, this module is the first thing that
|
12 |
+
is run inside the docker container.
|
13 |
+
"""
|
14 |
+
|
15 |
+
import os
|
16 |
+
import pickle
|
17 |
+
import sys
|
18 |
+
|
19 |
+
# PYTHONPATH should have been set so that the run_dir/src is in it
|
20 |
+
import dnnlib
|
21 |
+
|
22 |
+
def main():
    """Entry point executed inside the run dir / container.

    Expects three positional command-line arguments: run_dir, task_name and
    host_name.  Loads the pickled SubmitConfig from the run dir, patches it
    with the actual task/host names, and hands control to run_wrapper().

    Raises:
        RuntimeError: If arguments are missing or the pickle file is absent.
    """
    if len(sys.argv) < 4:  # was "if not len(sys.argv) >= 4" -- same meaning, clearer
        raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")

    run_dir = str(sys.argv[1])
    task_name = str(sys.argv[2])
    host_name = str(sys.argv[3])

    submit_config_path = os.path.join(run_dir, "submit_config.pkl")

    # SubmitConfig should have been pickled to the run dir by the submit step
    if not os.path.exists(submit_config_path):
        raise RuntimeError("SubmitConfig pickle file does not exist!")

    # use a context manager so the file handle is closed (original leaked it)
    with open(submit_config_path, "rb") as f:
        submit_config: dnnlib.SubmitConfig = pickle.load(f)
    dnnlib.submission.submit.set_user_name_override(submit_config.user_name)

    # record where/as-what this task is actually running
    submit_config.task_name = task_name
    submit_config.host_name = host_name

    dnnlib.submission.submit.run_wrapper(submit_config)


if __name__ == "__main__":
    main()
|
dnnlib/submission/run_context.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Helpers for managing the run/training loop."""
|
9 |
+
|
10 |
+
import datetime
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
import pprint
|
14 |
+
import time
|
15 |
+
import types
|
16 |
+
|
17 |
+
from typing import Any
|
18 |
+
|
19 |
+
from . import submit
|
20 |
+
|
21 |
+
|
22 |
+
class RunContext(object):
    """Helper class for managing the run/training loop.

    The context will hide the implementation details of a basic run/training loop.
    It will set things up properly, tell if run should be stopped, and then cleans up.
    User should call update periodically and use should_stop to determine if run should be stopped.

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: The whole config module that is used for the current run.
        max_epoch: Optional cached value for the max_epoch variable used in update.
    """

    def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
        self.submit_config = submit_config
        self.should_stop_flag = False  # latched True once a stop condition is seen; never reset
        self.has_closed = False        # guards close() so cleanup only runs once
        self.start_time = time.time()
        self.last_update_time = time.time()
        self.last_update_interval = 0.0
        self.max_epoch = max_epoch

        # pretty print all the relevant content of the config module to a text file
        # (private names, modules, functions/lambdas, SubmitConfigs and classes are filtered out)
        if config_module is not None:
            with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
                filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
                pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)

        # write out details about the run to a text file (close() rewrites it with a stop_time)
        self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
        with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def __enter__(self) -> "RunContext":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # context-manager exit delegates to close(); exception info is ignored
        self.close()

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.
        Should be called often enough but not in a tight loop."""
        assert not self.has_closed

        self.last_update_interval = time.time() - self.last_update_time
        self.last_update_time = time.time()

        # dropping a file named "abort.txt" into the run dir requests a graceful stop
        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

        # NOTE(review): max_epoch_val (and the loss/cur_epoch parameters) are
        # unused in this trimmed-down version; kept for interface compatibility.
        max_epoch_val = self.max_epoch if max_epoch is None else max_epoch

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.
        Should only be called once."""
        if not self.has_closed:
            # update the run.txt with stopping time
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
                pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

            self.has_closed = True
|
dnnlib/submission/submit.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Submit a function to be run either locally or in a computing cluster."""
|
9 |
+
|
10 |
+
import copy
|
11 |
+
import io
|
12 |
+
import os
|
13 |
+
import pathlib
|
14 |
+
import pickle
|
15 |
+
import platform
|
16 |
+
import pprint
|
17 |
+
import re
|
18 |
+
import shutil
|
19 |
+
import time
|
20 |
+
import traceback
|
21 |
+
|
22 |
+
import zipfile
|
23 |
+
|
24 |
+
from enum import Enum
|
25 |
+
|
26 |
+
from .. import util
|
27 |
+
from ..util import EasyDict
|
28 |
+
|
29 |
+
|
30 |
+
class SubmitTarget(Enum):
    """The target where the function should be run.

    LOCAL: Run it locally.
    """
    # Only local execution exists in this trimmed-down version of dnnlib.
    LOCAL = 1
|
36 |
+
|
37 |
+
|
38 |
+
class PathType(Enum):
    """Determines in which format should a path be formatted.

    WINDOWS: Format with Windows style.
    LINUX: Format with Linux/Posix style.
    AUTO: Use current OS type to select either WINDOWS or LINUX.
    """
    WINDOWS = 1
    LINUX = 2
    AUTO = 3  # resolved to WINDOWS/LINUX inside get_path_from_template()
|
48 |
+
|
49 |
+
|
50 |
+
# Module-level override for the user name reported by get_user_name().
# Set via set_user_name_override(); None means "detect from the OS".
_user_name_override = None
|
51 |
+
|
52 |
+
|
53 |
+
class SubmitConfig(util.EasyDict):
    """Strongly typed config dict needed to submit runs.

    Attributes:
        run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
        run_desc: Description of the run. Will be used in the run dir and task name.
        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
        submit_target: Submit target enum value. Used to select where the run is actually launched.
        num_gpus: Number of GPUs used/requested for the run.
        print_info: Whether to print debug information when submitting.
        ask_confirmation: Whether to ask a confirmation before submitting.
        run_id: Automatically populated value during submit.
        run_name: Automatically populated value during submit.
        run_dir: Automatically populated value during submit.
        run_func_name: Automatically populated value during submit.
        run_func_kwargs: Automatically populated value during submit.
        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
        task_name: Automatically populated value during submit.
        host_name: Automatically populated value during submit.
    """

    def __init__(self):
        super().__init__()

        # run (set these)
        self.run_dir_root = ""  # should always be passed through get_path_from_template
        self.run_desc = ""
        self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
        self.run_dir_extra_files = None

        # submit (set these)
        self.submit_target = SubmitTarget.LOCAL
        self.num_gpus = 1
        self.print_info = False
        self.ask_confirmation = False

        # (automatically populated during submit_run / run_wrapper)
        self.run_id = None
        self.run_name = None
        self.run_dir = None
        self.run_func_name = None
        self.run_func_kwargs = None
        self.user_name = None
        self.task_name = None
        self.host_name = "localhost"
|
99 |
+
|
100 |
+
|
101 |
+
def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
    """Expand the tags in a path template and format the result for Windows or Linux.

    Raises:
        RuntimeError: If the platform cannot be determined or path_type is invalid.
    """
    # resolve AUTO to a concrete path type based on the running OS
    if path_type == PathType.AUTO:
        system = platform.system()
        if system == "Windows":
            path_type = PathType.WINDOWS
        elif system == "Linux":
            path_type = PathType.LINUX
        else:
            raise RuntimeError("Unknown platform")

    # substitute the <USERNAME> tag before formatting
    expanded = path_template.replace("<USERNAME>", get_user_name())

    # format with the separator conventions of the chosen OS
    if path_type == PathType.WINDOWS:
        return str(pathlib.PureWindowsPath(expanded))
    if path_type == PathType.LINUX:
        return str(pathlib.PurePosixPath(expanded))
    raise RuntimeError("Unknown platform")
|
121 |
+
|
122 |
+
|
123 |
+
def get_template_from_path(path: str) -> str:
    """Convert a concrete path back to its template representation.

    Currently this only normalizes Windows separators to forward slashes.
    """
    return path.replace("\\", "/")
|
128 |
+
|
129 |
+
|
130 |
+
def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
    """Re-format a normal path for the given path type by round-tripping it through template form."""
    return get_path_from_template(get_template_from_path(path), path_type)
|
135 |
+
|
136 |
+
|
137 |
+
def set_user_name_override(name: str) -> None:
    """Override the user name that get_user_name() reports for the rest of the process."""
    global _user_name_override
    _user_name_override = name
|
141 |
+
|
142 |
+
|
143 |
+
def get_user_name():
    """Get the current user name.

    Returns the value set via set_user_name_override() when present, otherwise
    queries the OS.  On Linux, falls back to "unknown" when the passwd lookup
    fails (e.g. minimal containers with no passwd entry) -- deliberate
    best-effort behavior.

    Raises:
        RuntimeError: On platforms other than Windows/Linux.
    """
    if _user_name_override is not None:
        return _user_name_override
    elif platform.system() == "Windows":
        return os.getlogin()
    elif platform.system() == "Linux":
        try:
            import pwd  # pylint: disable=import-error
            return pwd.getpwuid(os.geteuid()).pw_name  # pylint: disable=no-member
        except Exception:  # narrowed from a bare except: don't swallow KeyboardInterrupt/SystemExit
            return "unknown"
    else:
        raise RuntimeError("Unknown platform")
|
157 |
+
|
158 |
+
|
159 |
+
def _create_run_dir_local(submit_config: SubmitConfig) -> str:
    """Create and return a fresh run dir whose name starts with the next free run ID."""
    run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)

    # make sure the root exists before scanning it for existing run IDs
    if not os.path.exists(run_dir_root):
        print("Creating the run dir root: {}".format(run_dir_root))
        os.makedirs(run_dir_root)

    next_id = _get_next_run_id_local(run_dir_root)
    submit_config.run_id = next_id
    submit_config.run_name = "{0:05d}-{1}".format(next_id, submit_config.run_desc)
    run_dir = os.path.join(run_dir_root, submit_config.run_name)

    # the increasing-ID scheme should guarantee uniqueness; fail loudly if not
    if os.path.exists(run_dir):
        raise RuntimeError("The run dir already exists! ({0})".format(run_dir))

    print("Creating the run dir: {}".format(run_dir))
    os.makedirs(run_dir)
    return run_dir
|
178 |
+
|
179 |
+
|
180 |
+
def _get_next_run_id_local(run_dir_root: str) -> int:
|
181 |
+
"""Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
|
182 |
+
dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
|
183 |
+
r = re.compile("^\\d+") # match one or more digits at the start of the string
|
184 |
+
run_id = 0
|
185 |
+
|
186 |
+
for dir_name in dir_names:
|
187 |
+
m = r.match(dir_name)
|
188 |
+
|
189 |
+
if m is not None:
|
190 |
+
i = int(m.group())
|
191 |
+
run_id = max(run_id, i + 1)
|
192 |
+
|
193 |
+
return run_id
|
194 |
+
|
195 |
+
|
196 |
+
def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
    """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
    print("Copying files to the run dir")
    files = []

    # collect the sources of the top-level package containing the run function
    run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
    assert '.' in submit_config.run_func_name
    for _idx in range(submit_config.run_func_name.count('.') - 1):
        run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
    files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)

    # ...plus the dnnlib package itself
    dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
    files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)

    if submit_config.run_dir_extra_files is not None:
        files += submit_config.run_dir_extra_files

    # rebase everything under <run_dir>/src and add the bootstrap script as run.py
    files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
    files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]

    util.copy_files_and_create_dirs(files)

    # persist the config as a pickle (read back by run.py) -- use a context
    # manager so the handle is flushed and closed (the original leaked it)
    with open(os.path.join(run_dir, "submit_config.pkl"), "wb") as f:
        pickle.dump(submit_config, f)

    # ...and as human-readable text for inspection
    with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
        pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
|
222 |
+
|
223 |
+
|
224 |
+
def run_wrapper(submit_config: SubmitConfig) -> None:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
    is_local = submit_config.submit_target == SubmitTarget.LOCAL

    # NOTE(review): vestigial -- nothing assigns a checker in this version,
    # so the stop() call at the end is effectively dead code.
    checker = None

    # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
    if is_local:
        logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
    else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
        logger = util.Logger(file_name=None, should_flush=True)

    # publish the active config as a package-level variable for the run function
    import dnnlib
    dnnlib.submit_config = submit_config

    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:  # deliberately broad: locally re-raise; on a cluster log and copy the log next to the run dir root
        if is_local:
            raise
        else:
            traceback.print_exc()

            log_src = os.path.join(submit_config.run_dir, "log.txt")
            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
            shutil.copyfile(log_src, log_dst)
    finally:
        # an empty "_finished.txt" marks the run as done regardless of outcome
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

    dnnlib.submit_config = None
    logger.close()

    if checker is not None:
        checker.stop()
|
261 |
+
|
262 |
+
|
263 |
+
def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
    # work on a shallow copy so the caller's config object is not mutated
    submit_config = copy.copy(submit_config)

    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    # NOTE(review): only local submission is supported; the membership test
    # below is redundant given the assert, and run_dir would be unbound later
    # if the assert were ever stripped (-O) and the target were not LOCAL.
    assert submit_config.submit_target == SubmitTarget.LOCAL
    if submit_config.submit_target in {SubmitTarget.LOCAL}:
        run_dir = _create_run_dir_local(submit_config)

        submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
        submit_config.run_dir = run_dir
        _populate_run_dir(run_dir, submit_config)

    if submit_config.print_info:
        print("\nSubmit config:\n")
        pprint.pprint(submit_config, indent=4, width=200, compact=False)
        print()

    if submit_config.ask_confirmation:
        if not util.ask_yes_no("Continue submitting the job?"):
            return

    # run in-process (local target)
    run_wrapper(submit_config)
|
dnnlib/tflib/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
from . import autosummary
|
9 |
+
from . import network
|
10 |
+
from . import optimizer
|
11 |
+
from . import tfutil
|
12 |
+
|
13 |
+
from .tfutil import *
|
14 |
+
from .network import Network
|
15 |
+
|
16 |
+
from .optimizer import Optimizer
|
dnnlib/tflib/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (322 Bytes). View file
|
|
dnnlib/tflib/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (326 Bytes). View file
|
|
dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc
ADDED
Binary file (6.38 kB). View file
|
|
dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc
ADDED
Binary file (6.38 kB). View file
|
|
dnnlib/tflib/__pycache__/network.cpython-36.pyc
ADDED
Binary file (31 kB). View file
|
|
dnnlib/tflib/__pycache__/network.cpython-37.pyc
ADDED
Binary file (31 kB). View file
|
|
dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc
ADDED
Binary file (8.52 kB). View file
|
|
dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc
ADDED
Binary file (8.53 kB). View file
|
|
dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc
ADDED
Binary file (8.47 kB). View file
|
|
dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc
ADDED
Binary file (8.44 kB). View file
|
|
dnnlib/tflib/autosummary.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Helper for adding automatically tracked values to Tensorboard.
|
9 |
+
|
10 |
+
Autosummary creates an identity op that internally keeps track of the input
|
11 |
+
values and automatically shows up in TensorBoard. The reported value
|
12 |
+
represents an average over input components. The average is accumulated
|
13 |
+
constantly over time and flushed when save_summaries() is called.
|
14 |
+
|
15 |
+
Notes:
|
16 |
+
- The output tensor must be used as an input for something else in the
|
17 |
+
graph. Otherwise, the autosummary op will not get executed, and the average
|
18 |
+
value will not get accumulated.
|
19 |
+
- It is perfectly fine to include autosummaries with the same name in
|
20 |
+
several places throughout the graph, even if they are executed concurrently.
|
21 |
+
- It is ok to also pass in a python scalar or numpy array. In this case, it
|
22 |
+
is added to the average immediately.
|
23 |
+
"""
|
24 |
+
|
25 |
+
from collections import OrderedDict
|
26 |
+
import numpy as np
|
27 |
+
import tensorflow as tf
|
28 |
+
from tensorboard import summary as summary_lib
|
29 |
+
from tensorboard.plugins.custom_scalar import layout_pb2
|
30 |
+
|
31 |
+
from . import tfutil
|
32 |
+
from .tfutil import TfExpression
|
33 |
+
from .tfutil import TfExpressionEx
|
34 |
+
|
35 |
+
# Module-level state shared by all autosummary ops in the process.
_dtype = tf.float64  # accumulator precision for the running moments
_vars = OrderedDict()  # name => [var, ...]; one accumulator var per call site
_immediate = OrderedDict()  # name => update_op, update_value; feeds for python-scalar inputs
_finalized = False  # set by finalize_autosummaries(); no new accumulators after this
_merge_op = None  # cached tf.summary.merge_all() op, built lazily by save_summaries()
40 |
+
|
41 |
+
|
42 |
+
def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Internal helper for creating autosummary accumulators.

    Builds a non-trainable 3-vector variable holding [count, sum(x), sum(x**2)]
    for *name* and returns the op that folds *value_expr* into it.
    """
    assert not _finalized
    name_id = name.replace("/", "_")  # "/" is a scope separator in TF names
    v = tf.cast(value_expr, _dtype)

    # Element count: constant when the static shape is known, dynamic otherwise.
    if v.shape.is_fully_defined():
        size = np.prod(tfutil.shape_to_list(v.shape))
        size_expr = tf.constant(size, dtype=_dtype)
    else:
        size = None
        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))

    # Assemble the [count, sum, sum-of-squares] moment triple.
    if size == 1:
        if v.shape.ndims != 0:
            v = tf.reshape(v, [])
        v = [size_expr, v, tf.square(v)]
    else:
        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
    # Skip the update entirely (contribute zeros) when the sum is NaN/Inf.
    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))

    # Place the accumulator outside any surrounding control dependencies.
    with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
        # First update initializes the var; later updates accumulate into it.
        update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))

    # Several call sites may share one summary name; keep one var per site.
    if name in _vars:
        _vars[name].append(var)
    else:
        _vars[name] = [var]
    return update_op
|
72 |
+
|
73 |
+
|
74 |
+
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name:     Name to use in TensorBoard
        value:    TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.

    Example use of the passthru mechanism:

        n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

        with tf.control_dependencies([autosummary('l2loss', loss)]):
            n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")  # "/" is a scope separator in TF names

    if tfutil.is_tf_expression(value):
        # Graph path: attach the accumulator update as a control dependency of
        # the returned identity, on the same device as the tracked value.
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            update_op = _create_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        # Immediate path: lazily build a placeholder-fed accumulator once per
        # name, then run the update eagerly with the given value.
        if name not in _immediate:
            with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(_dtype)
                update_op = _create_var(name, update_value)
                _immediate[name] = update_op, update_value

        update_op, update_value = _immediate[name]
        tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
|
110 |
+
|
111 |
+
|
112 |
+
def finalize_autosummaries():
    """Create the necessary ops to include autosummaries in TensorBoard report.

    Note: This should be done only once per graph.

    Returns:
        A custom-scalar layout summary protobuf describing the margin charts
        (consumed by ``save_summaries`` via ``FileWriter.add_summary``), or
        None if the autosummaries were already finalized.
    """
    # NOTE: the original signature was annotated `-> None`, which contradicted
    # the actual `return layout` below; the annotation has been removed.
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")  # "/" is a scope separator in TF names
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                # Combine the per-call-site accumulators and normalize by count.
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Group by category and chart name.
    cat_dict = OrderedDict()
    for series_name in sorted(_vars.keys()):
        p = series_name.split("/")
        cat = p[0] if len(p) >= 2 else ""
        chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
        if cat not in cat_dict:
            cat_dict[cat] = OrderedDict()
        if chart not in cat_dict[cat]:
            cat_dict[cat][chart] = []
        cat_dict[cat][chart].append(series_name)

    # Setup custom_scalar layout.
    categories = []
    for cat_name, chart_dict in cat_dict.items():
        charts = []
        for chart_name, series_names in chart_dict.items():
            series = []
            for series_name in series_names:
                series.append(layout_pb2.MarginChartContent.Series(
                    value=series_name,
                    lower="xCustomScalars/" + series_name + "/margin_lo",
                    upper="xCustomScalars/" + series_name + "/margin_hi"))
            margin = layout_pb2.MarginChartContent(series=series)
            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
        categories.append(layout_pb2.Category(title=cat_name, chart=charts))
    layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout
|
169 |
+
|
170 |
+
def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.

    Args:
        file_writer: tf.summary.FileWriter to write the merged summaries to.
        global_step: Optional step value to associate with the summaries.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    # First call: finalize the accumulators, write the custom-scalar layout,
    # and cache the merged summary op for subsequent calls.
    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    file_writer.add_summary(_merge_op.eval(), global_step)
|
dnnlib/tflib/network.py
ADDED
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Helper for managing networks."""
|
9 |
+
|
10 |
+
import types
|
11 |
+
import inspect
|
12 |
+
import re
|
13 |
+
import uuid
|
14 |
+
import sys
|
15 |
+
import numpy as np
|
16 |
+
import tensorflow as tf
|
17 |
+
|
18 |
+
from collections import OrderedDict
|
19 |
+
from typing import Any, List, Tuple, Union
|
20 |
+
|
21 |
+
from . import tfutil
|
22 |
+
from .. import util
|
23 |
+
|
24 |
+
from .tfutil import TfExpression, TfExpressionEx
|
25 |
+
|
26 |
+
# Module-level registries used by the pickle import machinery below.
_import_handlers = []  # Custom import handlers for dealing with legacy data in pickle import.
_import_module_src = dict()  # Source code for temporary modules created during pickle import.
|
28 |
+
|
29 |
+
|
30 |
+
def import_handler(handler_func):
    """Function decorator that registers *handler_func* as a custom pickle
    import handler (see ``Network.__setstate__``) and returns it unchanged."""
    _import_handlers.append(handler_func)
    return handler_func
|
34 |
+
|
35 |
+
|
36 |
+
class Network:
|
37 |
+
"""Generic network abstraction.
|
38 |
+
|
39 |
+
Acts as a convenience wrapper for a parameterized network construction
|
40 |
+
function, providing several utility methods and convenient access to
|
41 |
+
the inputs/outputs/weights.
|
42 |
+
|
43 |
+
Network objects can be safely pickled and unpickled for long-term
|
44 |
+
archival purposes. The pickling works reliably as long as the underlying
|
45 |
+
network construction function is defined in a standalone Python module
|
46 |
+
that has no side effects or application-specific imports.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
name: Network name. Used to select TensorFlow name and variable scopes.
|
50 |
+
func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
|
51 |
+
static_kwargs: Keyword arguments to be passed in to the network construction function.
|
52 |
+
|
53 |
+
Attributes:
|
54 |
+
name: User-specified name, defaults to build func name if None.
|
55 |
+
scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
|
56 |
+
static_kwargs: Arguments passed to the user-supplied build func.
|
57 |
+
components: Container for sub-networks. Passed to the build func, and retained between calls.
|
58 |
+
num_inputs: Number of input tensors.
|
59 |
+
num_outputs: Number of output tensors.
|
60 |
+
input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
|
61 |
+
output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
|
62 |
+
input_shape: Short-hand for input_shapes[0].
|
63 |
+
output_shape: Short-hand for output_shapes[0].
|
64 |
+
input_templates: Input placeholders in the template graph.
|
65 |
+
output_templates: Output tensors in the template graph.
|
66 |
+
input_names: Name string for each input.
|
67 |
+
output_names: Name string for each output.
|
68 |
+
own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
|
69 |
+
vars: All variables (local_name => var).
|
70 |
+
trainables: All trainable variables (local_name => var).
|
71 |
+
var_global_to_local: Mapping from variable global names to local names.
|
72 |
+
"""
|
73 |
+
|
74 |
+
    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
        """Build the template graph for a network.

        Args:
            name:          Optional network name; defaults to the build func name.
            func_name:     Fully qualified name of the build function, or a top-level function object.
            static_kwargs: Keyword arguments forwarded to the build function; must be pickleable.
        """
        tfutil.assert_tf_initialized()
        assert isinstance(name, str) or name is None
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Locate the user-specified network build function.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Dig up source code for the module containing the build function,
        # preferring source captured during a previous pickle import.
        self._build_module_src = _import_module_src.get(module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(module)

        # Init TensorFlow graph and give all own variables fresh values.
        self._init_graph()
        self.reset_own_vars()
|
100 |
+
|
101 |
+
def _init_fields(self) -> None:
|
102 |
+
self.name = None
|
103 |
+
self.scope = None
|
104 |
+
self.static_kwargs = util.EasyDict()
|
105 |
+
self.components = util.EasyDict()
|
106 |
+
self.num_inputs = 0
|
107 |
+
self.num_outputs = 0
|
108 |
+
self.input_shapes = [[]]
|
109 |
+
self.output_shapes = [[]]
|
110 |
+
self.input_shape = []
|
111 |
+
self.output_shape = []
|
112 |
+
self.input_templates = []
|
113 |
+
self.output_templates = []
|
114 |
+
self.input_names = []
|
115 |
+
self.output_names = []
|
116 |
+
self.own_vars = OrderedDict()
|
117 |
+
self.vars = OrderedDict()
|
118 |
+
self.trainables = OrderedDict()
|
119 |
+
self.var_global_to_local = OrderedDict()
|
120 |
+
|
121 |
+
self._build_func = None # User-supplied build function that constructs the network.
|
122 |
+
self._build_func_name = None # Name of the build function.
|
123 |
+
self._build_module_src = None # Full source code of the module containing the build function.
|
124 |
+
self._run_cache = dict() # Cached graph data for Network.run().
|
125 |
+
|
126 |
+
    def _init_graph(self) -> None:
        """Build the template graph and populate all derived attributes
        (shapes, templates, variable dictionaries)."""
        # Collect inputs: every positional build-func parameter without a
        # default value becomes a network input.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope; the scope is made unique within the graph.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(None):
            self.scope = tf.compat.v1.get_default_graph().unique_name(self.name, mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(self.scope, reuse=tf.compat.v1.AUTO_REUSE), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
            assert tf.compat.v1.get_variable_scope().name == self.scope
            assert tf.compat.v1.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(None):  # ignore surrounding control dependencies
                self.input_templates = [tf.compat.v1.placeholder(tf.float32, name=name) for name in self.input_names]
                out_expr = self._build_func(*self.input_templates, **build_kwargs)

        # Collect outputs; the build func may return a single expression or a tuple.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self.components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self.components) != len(set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

        # List variables: own vars (keyed by scope-local name), all vars
        # including sub-networks, trainables, and the global->local name map.
        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.compat.v1.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
|
187 |
+
|
188 |
+
def reset_own_vars(self) -> None:
|
189 |
+
"""Re-initialize all variables of this network, excluding sub-networks."""
|
190 |
+
tfutil.run([var.initializer for var in self.own_vars.values()])
|
191 |
+
|
192 |
+
def reset_vars(self) -> None:
|
193 |
+
"""Re-initialize all variables of this network, including sub-networks."""
|
194 |
+
tfutil.run([var.initializer for var in self.vars.values()])
|
195 |
+
|
196 |
+
def reset_trainables(self) -> None:
|
197 |
+
"""Re-initialize all trainable variables of this network, including sub-networks."""
|
198 |
+
tfutil.run([var.initializer for var in self.trainables.values()])
|
199 |
+
|
200 |
+
    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
        """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).

        Args:
            in_expr:        One expression per network input; a None entry is
                            replaced by zeros matching the first valid input's batch size.
            return_as_list: Always return a list of outputs, even for a single output.
            dynamic_kwargs: Extra keyword arguments forwarded to the build function.
        """
        assert len(in_expr) == self.num_inputs
        assert not all(expr is None for expr in in_expr)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs.update(dynamic_kwargs)
        build_kwargs["is_template_graph"] = False
        build_kwargs["components"] = self.components

        # Build TensorFlow graph to evaluate the network, reusing the
        # template graph's variables.
        with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
            assert tf.compat.v1.get_variable_scope().name == self.scope
            valid_inputs = [expr for expr in in_expr if expr is not None]
            final_inputs = []
            for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
                if expr is not None:
                    expr = tf.identity(expr, name=name)
                else:
                    # Missing input: substitute zeros with the batch size taken
                    # from the first supplied input.
                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
                final_inputs.append(expr)
            out_expr = self._build_func(*final_inputs, **build_kwargs)

        # Propagate input shapes back to the user-specified expressions.
        for expr, final in zip(in_expr, final_inputs):
            if isinstance(expr, tf.Tensor):
                expr.set_shape(final.shape)

        # Express outputs in the desired format.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        return out_expr
|
234 |
+
|
235 |
+
def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
|
236 |
+
"""Get the local name of a given variable, without any surrounding name scopes."""
|
237 |
+
assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
|
238 |
+
global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
|
239 |
+
return self.var_global_to_local[global_name]
|
240 |
+
|
241 |
+
def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
|
242 |
+
"""Find variable by local or global name."""
|
243 |
+
assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
|
244 |
+
return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
|
245 |
+
|
246 |
+
def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
|
247 |
+
"""Get the value of a given variable as NumPy array.
|
248 |
+
Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
|
249 |
+
return self.find_var(var_or_local_name).eval()
|
250 |
+
|
251 |
+
def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
|
252 |
+
"""Set the value of a given variable based on the given NumPy array.
|
253 |
+
Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
|
254 |
+
tfutil.set_vars({self.find_var(var_or_local_name): new_value})
|
255 |
+
|
256 |
+
    def __getstate__(self) -> dict:
        """Pickle export.

        Serializes config, build-module source code, and the current values of
        all own variables (sub-network variables are pickled via components).
        """
        state = dict()
        state["version"] = 3  # bumped whenever the pickle format changes
        state["name"] = self.name
        state["static_kwargs"] = dict(self.static_kwargs)
        state["components"] = dict(self.components)
        state["build_module_src"] = self._build_module_src
        state["build_func_name"] = self._build_func_name
        # Evaluate all own variables in one session run and store (name, value) pairs.
        state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
        return state
|
267 |
+
|
268 |
+
    def __setstate__(self, state: dict) -> None:
        """Pickle import.

        Recreates the build module from its pickled source code, rebuilds the
        template graph, and restores variable values.
        """
        # pylint: disable=attribute-defined-outside-init
        tfutil.assert_tf_initialized()
        self._init_fields()

        # Execute custom import handlers (registered via @import_handler) to
        # migrate legacy pickle formats.
        for handler in _import_handlers:
            state = handler(state)

        # Set basic fields.
        assert state["version"] in [2, 3]
        self.name = state["name"]
        self.static_kwargs = util.EasyDict(state["static_kwargs"])
        self.components = util.EasyDict(state.get("components", {}))
        self._build_module_src = state["build_module_src"]
        self._build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code.
        # NOTE(review): exec of pickled source — only load pickles from trusted sources.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = self._build_module_src
        exec(self._build_module_src, module.__dict__)  # pylint: disable=exec-used

        # Locate network build function in the temporary module.
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Init TensorFlow graph and restore the pickled variable values.
        self._init_graph()
        self.reset_own_vars()
        tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
|
301 |
+
|
302 |
+
    def clone(self, name: str = None, **new_static_kwargs) -> "Network":
        """Create a clone of this network with its own copy of the variables.

        Args:
            name:              Optional name for the clone; defaults to this network's name.
            new_static_kwargs: Overrides merged into the clone's static kwargs.
        """
        # pylint: disable=protected-access
        # Bypass __init__ (which would re-locate the build func from disk) and
        # copy the build function/source directly from this instance.
        net = object.__new__(Network)
        net._init_fields()
        net.name = name if name is not None else self.name
        net.static_kwargs = util.EasyDict(self.static_kwargs)
        net.static_kwargs.update(new_static_kwargs)
        net._build_module_src = self._build_module_src
        net._build_func_name = self._build_func_name
        net._build_func = self._build_func
        net._init_graph()
        net.copy_vars_from(self)
        return net
|
316 |
+
|
317 |
+
def copy_own_vars_from(self, src_net: "Network") -> None:
|
318 |
+
"""Copy the values of all variables from the given network, excluding sub-networks."""
|
319 |
+
names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
|
320 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
321 |
+
|
322 |
+
def copy_vars_from(self, src_net: "Network") -> None:
|
323 |
+
"""Copy the values of all variables from the given network, including sub-networks."""
|
324 |
+
names = [name for name in self.vars.keys() if name in src_net.vars]
|
325 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
326 |
+
|
327 |
+
def copy_trainables_from(self, src_net: "Network") -> None:
|
328 |
+
"""Copy the values of all trainable variables from the given network, including sub-networks."""
|
329 |
+
names = [name for name in self.trainables.keys() if name in src_net.trainables]
|
330 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
331 |
+
|
332 |
+
def copy_compatible_trainables_from(self, src_net: "Network") -> None:
|
333 |
+
"""Copy the compatible values of all trainable variables from the given network, including sub-networks"""
|
334 |
+
names = []
|
335 |
+
for name in self.trainables.keys():
|
336 |
+
if name not in src_net.trainables:
|
337 |
+
print("Not restoring (not present): {}".format(name))
|
338 |
+
elif self.trainables[name].shape != src_net.trainables[name].shape:
|
339 |
+
print("Not restoring (different shape): {}".format(name))
|
340 |
+
|
341 |
+
if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape:
|
342 |
+
names.append(name)
|
343 |
+
|
344 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
|
345 |
+
|
346 |
+
def apply_swa(self, src_net, epoch):
|
347 |
+
"""Perform stochastic weight averaging on the compatible values of all trainable variables from the given network, including sub-networks"""
|
348 |
+
names = []
|
349 |
+
for name in self.trainables.keys():
|
350 |
+
if name not in src_net.trainables:
|
351 |
+
print("Not restoring (not present): {}".format(name))
|
352 |
+
elif self.trainables[name].shape != src_net.trainables[name].shape:
|
353 |
+
print("Not restoring (different shape): {}".format(name))
|
354 |
+
|
355 |
+
if name in src_net.trainables and self.trainables[name].shape == src_net.trainables[name].shape:
|
356 |
+
names.append(name)
|
357 |
+
|
358 |
+
scale_new_data = 1.0 / (epoch + 1)
|
359 |
+
scale_moving_average = (1.0 - scale_new_data)
|
360 |
+
tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * scale_new_data + self.vars[name] * scale_moving_average) for name in names}))
|
361 |
+
|
362 |
+
def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
    """Create new network with the given parameters, and copy all variables from this network."""
    # Keep the current name unless the caller supplied a new one.
    target_name = self.name if new_name is None else new_name
    # Merge the existing static kwargs with the caller's overrides.
    merged_kwargs = dict(self.static_kwargs)
    merged_kwargs.update(new_static_kwargs)
    converted = Network(name=target_name, func_name=new_func_name, **merged_kwargs)
    converted.copy_vars_from(self)
    return converted
|
371 |
+
|
372 |
+
def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
    """Construct a TensorFlow op that updates the variables of this network
    to be slightly closer to those of the given network."""
    with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
        update_ops = []
        for name, var in self.vars.items():
            # Only variables that exist in the source network participate.
            if name not in src_net.vars:
                continue
            # Trainables use `beta`; everything else uses `beta_nontrainable`.
            decay = beta if name in self.trainables else beta_nontrainable
            update_ops.append(var.assign(tfutil.lerp(src_net.vars[name], var, decay)))
        return tf.group(*update_ops)
|
383 |
+
|
384 |
+
def run(self,
        *in_arrays: Tuple[Union[np.ndarray, None], ...],
        input_transform: dict = None,
        output_transform: dict = None,
        return_as_list: bool = False,
        print_progress: bool = False,
        minibatch_size: int = None,
        num_gpus: int = 1,
        assume_frozen: bool = False,
        custom_inputs=None,
        **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
    """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

    Args:
        input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                            The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                            TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
        return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
        print_progress:     Print progress to the console? Useful for very large input arrays.
        minibatch_size:     Maximum minibatch size to use, None = disable batching.
        num_gpus:           Number of GPUs to use.
        assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
        dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
        custom_inputs:      Allow to use another Tensor as input instead of default Placeholders
    """
    assert len(in_arrays) == self.num_inputs
    assert not all(arr is None for arr in in_arrays)
    # Transform funcs must be top-level so they can serve as stable cache keys below.
    assert input_transform is None or util.is_top_level_function(input_transform["func"])
    assert output_transform is None or util.is_top_level_function(output_transform["func"])
    output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
    num_items = in_arrays[0].shape[0]
    if minibatch_size is None:
        minibatch_size = num_items

    # Construct unique hash key from all arguments that affect the TensorFlow graph.
    key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
    def unwind_key(obj):
        # Recursively turn the key dict into a deterministic, repr()-able
        # structure: dicts become sorted item lists, callables become their
        # top-level function names.
        if isinstance(obj, dict):
            return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
        if callable(obj):
            return util.get_top_level_function_name(obj)
        return obj
    key = repr(unwind_key(key))

    # Build graph (once per unique key; cached in self._run_cache afterwards).
    if key not in self._run_cache:
        with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
            if custom_inputs is not None:
                # Caller-supplied input builders produce the input expressions.
                with tf.device("/gpu:0"):
                    in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
                    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
            else:
                # Default path: float32 placeholders on the CPU, split evenly
                # across the requested number of GPUs.
                with tf.device("/cpu:0"):
                    in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

            out_split = []
            for gpu in range(num_gpus):
                with tf.device("/gpu:%d" % gpu):
                    # Cloning decouples this evaluation from future trainable
                    # updates, allowing better multi-GPU performance.
                    net_gpu = self.clone() if assume_frozen else self
                    in_gpu = in_split[gpu]

                    if input_transform is not None:
                        in_kwargs = dict(input_transform)
                        in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                        # Normalize a single tensor to a one-element list.
                        in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                    assert len(in_gpu) == self.num_inputs
                    out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                    if output_transform is not None:
                        out_kwargs = dict(output_transform)
                        out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                        out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                    assert len(out_gpu) == self.num_outputs
                    out_split.append(out_gpu)

            # Re-assemble per-GPU output shards along the batch dimension.
            with tf.device("/cpu:0"):
                out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                self._run_cache[key] = in_expr, out_expr

    # Run minibatches.
    in_expr, out_expr = self._run_cache[key]
    # Pre-allocate full-size output arrays; minibatch results are copied in below.
    out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]

    for mb_begin in range(0, num_items, minibatch_size):
        if print_progress:
            print("\r%d / %d" % (mb_begin, num_items), end="")

        mb_end = min(mb_begin + minibatch_size, num_items)
        mb_num = mb_end - mb_begin
        # None inputs are fed as zeros of the declared input shape.
        mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
        mb_out = tf.compat.v1.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

        for dst, src in zip(out_arrays, mb_out):
            dst[mb_begin: mb_end] = src

    # Done.
    if print_progress:
        print("\r%d / %d" % (num_items, num_items))

    if not return_as_list:
        out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
    return out_arrays
|
492 |
+
|
493 |
+
def list_ops(self) -> List[TfExpression]:
    """Return the TensorFlow ops in the default graph that belong to this
    network's scope, excluding internal sub-scopes (names starting with "_")."""
    scope_prefix = self.scope + "/"
    internal_prefix = scope_prefix + "_"
    graph_ops = tf.get_default_graph().get_operations()
    return [op for op in graph_ops
            if op.name.startswith(scope_prefix) and not op.name.startswith(internal_prefix)]
|
500 |
+
|
501 |
+
def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
    """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
    individual layers of the network. Mainly intended to be used for reporting."""
    layers = []

    def recurse(scope, parent_ops, parent_vars, level):
        # Ignore specific patterns (bookkeeping ops that are not real layers).
        if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
            return

        # Filter ops and vars by scope. Op names are global (include self.scope),
        # var names are local (relative to self.scope) — hence the two prefixes.
        global_prefix = scope + "/"
        local_prefix = global_prefix[len(self.scope) + 1:]
        cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
        cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
        if not cur_ops and not cur_vars:
            return

        # Filter out all ops related to variables.
        for var in [op for op in cur_ops if op.type.startswith("Variable")]:
            var_prefix = var.name + "/"
            cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]

        # Scope does not contain ops as immediate children => recurse deeper.
        contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
        if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
            # Recurse into each distinct child scope token exactly once.
            visited = set()
            for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
                token = rel_name.split("/")[0]
                if token not in visited:
                    recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
                    visited.add(token)
            return

        # Report layer. The layer output is the last op's first output tensor,
        # falling back to the last variable when the scope holds no ops.
        layer_name = scope[len(self.scope) + 1:]
        layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
        layer_trainables = [var for _name, var in cur_vars if var.trainable]
        layers.append((layer_name, layer_output, layer_trainables))

    recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
    return layers
|
543 |
+
|
544 |
+
def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
    """Print a summary table of the network structure.

    Args:
        title: Heading for the first column; defaults to the network name.
        hide_layers_with_no_params: Skip rows whose layers have no trainables.
    """
    rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
    rows += [["---"] * 4]
    total_params = 0

    for layer_name, layer_output, layer_trainables in self.list_layers():
        num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
        # Prefer variables named ".../weight" (or "/weight_1") for the weight
        # shape column; shortest name wins.
        weights = [var for var in layer_trainables if var.name.endswith("/weight:0") or var.name.endswith("/weight_1:0")]
        weights.sort(key=lambda x: len(x.name))
        if len(weights) == 0 and len(layer_trainables) == 1:
            weights = layer_trainables
        total_params += num_params

        if not hide_layers_with_no_params or num_params != 0:
            num_params_str = str(num_params) if num_params > 0 else "-"
            output_shape_str = str(layer_output.shape)
            weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
            rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]

    rows += [["---"] * 4]
    rows += [["Total", str(total_params), "", ""]]

    # Pad every cell to its column's maximum width for alignment.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print("  ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
|
572 |
+
|
573 |
+
def setup_weight_histograms(self, title: str = None) -> None:
    """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
    prefix = self.name if title is None else title

    # Escape any enclosing name scope / device / control dependencies so the
    # summaries are created at the graph root.
    with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
        for local_name, var in self.trainables.items():
            if "/" in local_name:
                parts = local_name.split("/")
                summary_name = prefix + "_" + parts[-1] + "/" + "_".join(parts[:-1])
            else:
                summary_name = prefix + "_toplevel/" + local_name
            tf.summary.histogram(summary_name, var)
|
587 |
+
|
588 |
+
#----------------------------------------------------------------------------
|
589 |
+
# Backwards-compatible emulation of legacy output transformation in Network.run().
|
590 |
+
|
591 |
+
_print_legacy_warning = True
|
592 |
+
|
593 |
+
def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
|
594 |
+
global _print_legacy_warning
|
595 |
+
legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
|
596 |
+
if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
|
597 |
+
return output_transform, dynamic_kwargs
|
598 |
+
|
599 |
+
if _print_legacy_warning:
|
600 |
+
_print_legacy_warning = False
|
601 |
+
print()
|
602 |
+
print("WARNING: Old-style output transformations in Network.run() are deprecated.")
|
603 |
+
print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
|
604 |
+
print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
|
605 |
+
print()
|
606 |
+
assert output_transform is None
|
607 |
+
|
608 |
+
new_kwargs = dict(dynamic_kwargs)
|
609 |
+
new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
|
610 |
+
new_transform["func"] = _legacy_output_transform_func
|
611 |
+
return new_transform, new_kwargs
|
612 |
+
|
613 |
+
def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
|
614 |
+
if out_mul != 1.0:
|
615 |
+
expr = [x * out_mul for x in expr]
|
616 |
+
|
617 |
+
if out_add != 0.0:
|
618 |
+
expr = [x + out_add for x in expr]
|
619 |
+
|
620 |
+
if out_shrink > 1:
|
621 |
+
ksize = [1, 1, out_shrink, out_shrink]
|
622 |
+
expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
|
623 |
+
|
624 |
+
if out_dtype is not None:
|
625 |
+
if tf.as_dtype(out_dtype).is_integer:
|
626 |
+
expr = [tf.round(x) for x in expr]
|
627 |
+
expr = [tf.saturate_cast(x, out_dtype) for x in expr]
|
628 |
+
return expr
|
dnnlib/tflib/optimizer.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Helper wrapper for a Tensorflow optimizer."""
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import tensorflow as tf
|
12 |
+
|
13 |
+
from collections import OrderedDict
|
14 |
+
from typing import List, Union
|
15 |
+
|
16 |
+
from . import autosummary
|
17 |
+
from . import tfutil
|
18 |
+
from .. import util
|
19 |
+
|
20 |
+
from .tfutil import TfExpression, TfExpressionEx
|
21 |
+
|
22 |
+
try:
|
23 |
+
# TensorFlow 1.13
|
24 |
+
from tensorflow.python.ops import nccl_ops
|
25 |
+
except:
|
26 |
+
# Older TensorFlow versions
|
27 |
+
import tensorflow.contrib.nccl as nccl_ops
|
28 |
+
|
29 |
+
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.
    """

    def __init__(self,
                 name: str = "Train",
                 tf_optimizer: str = "tf.train.AdamOptimizer",
                 learning_rate: TfExpressionEx = 0.001,
                 use_loss_scaling: bool = False,
                 loss_scaling_init: float = 64.0,
                 loss_scaling_inc: float = 0.0005,
                 loss_scaling_dec: float = 1.0,
                 **kwargs):

        # Init fields. `tf_optimizer` is a dotted name resolved lazily via
        # util.get_obj_by_name; extra kwargs are forwarded to its constructor.
        self.name = name
        self.learning_rate = tf.convert_to_tensor(learning_rate)
        self.id = self.name.replace("/", ".")
        self.scope = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec
        self._grad_shapes = None  # [shape, ...]
        self._dev_opt = OrderedDict()  # device => optimizer
        self._dev_grads = OrderedDict()  # device => [[(grad, var), ...], ...]
        self._dev_ls_var = OrderedDict()  # device => variable (log2 of loss scaling factor)
        self._updates_applied = False

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU."""
        assert not self._updates_applied

        # Validate arguments.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars

        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])

        # First registration fixes the expected variable shapes; later calls
        # (other GPUs) must match them exactly.
        if self._grad_shapes is None:
            self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]

        assert len(trainable_vars) == len(self._grad_shapes)
        assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))

        dev = loss.device

        assert all(var.device == dev for var in trainable_vars)

        # Register device and compute gradients.
        with tf.name_scope(self.id + "_grad"), tf.device(dev):
            if dev not in self._dev_opt:
                opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
                assert callable(self.optimizer_class)
                self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
                self._dev_grads[dev] = []

            # Gradients are computed on the scaled loss; the scaling is undone
            # later in apply_updates().
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage
            grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]  # replace disconnected gradients with zeros
            self._dev_grads[dev].append(grads)

    def apply_updates(self) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients."""
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        devices = list(self._dev_grads.keys())
        total_grads = sum(len(grads) for grads in self._dev_grads.values())
        assert len(devices) >= 1 and total_grads >= 1
        ops = []

        with tfutil.absolute_name_scope(self.scope):
            # Cast gradients to FP32 and calculate partial sum within each device.
            dev_grads = OrderedDict()  # device => [(grad, var), ...]

            for dev_idx, dev in enumerate(devices):
                with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
                    sums = []

                    for gv in zip(*self._dev_grads[dev]):
                        # All registrations on a device must refer to the same variable.
                        assert all(v is gv[0][1] for g, v in gv)
                        g = [tf.cast(g, tf.float32) for g, v in gv]
                        g = g[0] if len(g) == 1 else tf.add_n(g)
                        sums.append((g, gv[0][1]))

                    dev_grads[dev] = sums

            # Sum gradients across devices.
            if len(devices) > 1:
                with tf.name_scope("SumAcrossGPUs"), tf.device(None):
                    for var_idx, grad_shape in enumerate(self._grad_shapes):
                        g = [dev_grads[dev][var_idx][0] for dev in devices]

                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors
                            g = nccl_ops.all_sum(g)

                        for dev, gg in zip(devices, g):
                            dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])

            # Apply updates separately on each device.
            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
                with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
                    # Scale gradients as needed: average over registrations and
                    # undo the dynamic loss scaling applied in register_gradients().
                    if self.use_loss_scaling or total_grads > 1:
                        with tf.name_scope("Scale"):
                            coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
                            coef = self.undo_loss_scaling(coef)
                            grads = [(g * coef, v) for g, v in grads]

                    # Check for overflows.
                    with tf.name_scope("CheckOverflow"):
                        grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))

                    # Update weights and adjust loss scaling. Gradients with
                    # NaNs/Infs are skipped entirely.
                    with tf.name_scope("UpdateWeights"):
                        # pylint: disable=cell-var-from-loop
                        opt = self._dev_opt[dev]
                        ls_var = self.get_loss_scaling_var(dev)

                        if not self.use_loss_scaling:
                            ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
                        else:
                            # Success => grow the loss scale; overflow => shrink it.
                            ops.append(tf.cond(grad_ok,
                                               lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
                                               lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))

                    # Report statistics on the last device.
                    if dev == devices[-1]:
                        with tf.name_scope("Statistics"):
                            ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                            ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))

                            if self.use_loss_scaling:
                                ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))

            # Initialize variables and group everything into a single op.
            self.reset_optimizer_state()
            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))

            return tf.group(*ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer."""
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor."""
        if not self.use_loss_scaling:
            return None

        if device not in self._dev_ls_var:
            with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
                self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")

        return self._dev_ls_var[device]

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        # The stored variable holds log2 of the scale factor.
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression."""
        assert tfutil.is_tf_expression(value)

        if not self.use_loss_scaling:
            return value

        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type
|
dnnlib/tflib/tfutil.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Miscellaneous helper utils for Tensorflow."""
|
9 |
+
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
import tensorflow as tf
|
13 |
+
|
14 |
+
from typing import Any, Iterable, List, Union
|
15 |
+
|
16 |
+
# Type aliases used throughout dnnlib.tflib for annotating TF-related helpers.
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid Tensorflow expression."""

TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid Tensorflow expression."""
|
21 |
+
|
22 |
+
|
23 |
+
def run(*args, **kwargs) -> Any:
    """Run the specified ops in the default session."""
    # Fail fast with a clear message if init_tf() has not been called.
    assert_tf_initialized()
    session = tf.compat.v1.get_default_session()
    return session.run(*args, **kwargs)
|
27 |
+
|
28 |
+
|
29 |
+
def is_tf_expression(x: Any) -> bool:
    """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
    tf_expression_types = (tf.Tensor, tf.Variable, tf.Operation)
    return isinstance(x, tf_expression_types)
|
32 |
+
|
33 |
+
|
34 |
+
def shape_to_list(shape: Iterable[tf.compat.v1.Dimension]) -> List[Union[int, None]]:
    """Convert a Tensorflow shape to a list of ints (None for unknown dims)."""
    result = []
    for dimension in shape:
        result.append(dimension.value)
    return result
|
37 |
+
|
38 |
+
|
39 |
+
def flatten(x: TfExpressionEx) -> TfExpression:
    """Shortcut function for flattening a tensor."""
    with tf.name_scope("Flatten"):
        # Collapse all dimensions into a single rank-1 tensor.
        return tf.reshape(x, [-1])
|
43 |
+
|
44 |
+
|
45 |
+
def log2(x: TfExpressionEx) -> TfExpression:
    """Logarithm in base 2."""
    with tf.name_scope("Log2"):
        # log2(x) = ln(x) / ln(2); the constant is folded at graph-build time.
        return tf.log(x) * np.float32(1.0 / np.log(2.0))
|
49 |
+
|
50 |
+
|
51 |
+
def exp2(x: TfExpressionEx) -> TfExpression:
    """Exponent in base 2."""
    with tf.name_scope("Exp2"):
        # 2**x = e**(x * ln(2)); the constant is folded at graph-build time.
        return tf.exp(x * np.float32(np.log(2.0)))
|
55 |
+
|
56 |
+
|
57 |
+
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
    """Linear interpolation: a at t=0, b at t=1 (t is not clamped)."""
    with tf.name_scope("Lerp"):
        return a + (b - a) * t
|
61 |
+
|
62 |
+
|
63 |
+
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
    """Linear interpolation with clip: t is clamped to [0, 1] before interpolating."""
    with tf.name_scope("LerpClip"):
        return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
|
67 |
+
|
68 |
+
|
69 |
+
def absolute_name_scope(scope: str) -> tf.name_scope:
    """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
    # A trailing "/" makes TensorFlow treat the name as absolute rather than
    # relative to the current scope.
    return tf.name_scope(scope + "/")
|
72 |
+
|
73 |
+
|
74 |
+
def absolute_variable_scope(scope: str, **kwargs) -> tf.compat.v1.variable_scope:
    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
    # Passing a VariableScope object (rather than a string) bypasses scope
    # nesting; auxiliary_name_scope=False keeps the name scope untouched.
    return tf.compat.v1.variable_scope(tf.compat.v1.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
|
77 |
+
|
78 |
+
|
79 |
+
def _sanitize_tf_config(config_dict: dict = None) -> dict:
|
80 |
+
# Defaults.
|
81 |
+
cfg = dict()
|
82 |
+
cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
|
83 |
+
cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
|
84 |
+
cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
|
85 |
+
cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
|
86 |
+
cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
|
87 |
+
|
88 |
+
# User overrides.
|
89 |
+
if config_dict is not None:
|
90 |
+
cfg.update(config_dict)
|
91 |
+
return cfg
|
92 |
+
|
93 |
+
|
94 |
+
def init_tf(config_dict: dict = None) -> None:
    """Initialize TensorFlow session using good default settings.

    Idempotent: does nothing if a default session already exists. Seeds NumPy
    and TensorFlow, exports "env.*" config keys as environment variables, and
    installs a new default session.
    """
    # Skip if already initialized.
    if tf.compat.v1.get_default_session() is not None:
        return

    # Setup config dict and random seeds.
    cfg = _sanitize_tf_config(config_dict)
    np_random_seed = cfg["rnd.np_random_seed"]
    if np_random_seed is not None:
        np.random.seed(np_random_seed)
    tf_random_seed = cfg["rnd.tf_random_seed"]
    if tf_random_seed == "auto":
        # Derive the TF seed from NumPy's random state.
        tf_random_seed = np.random.randint(1 << 31)
    if tf_random_seed is not None:
        tf.compat.v1.set_random_seed(tf_random_seed)

    # Setup environment variables: every "env.NAME" key becomes os.environ["NAME"].
    for key, value in list(cfg.items()):
        fields = key.split(".")
        if fields[0] == "env":
            assert len(fields) == 2
            os.environ[fields[1]] = str(value)

    # Create default TensorFlow session.
    create_session(cfg, force_as_default=True)
|
120 |
+
|
121 |
+
|
122 |
+
def assert_tf_initialized():
    """Check that TensorFlow session has been initialized."""
    session = tf.compat.v1.get_default_session()
    if session is None:
        raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
|
126 |
+
|
127 |
+
|
128 |
+
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.compat.v1.Session:
    """Create tf.Session based on config dict.

    Dotted config keys (e.g. "gpu_options.allow_growth") are walked into the
    ConfigProto message field by field; keys in the "rnd." and "env."
    namespaces are handled elsewhere and skipped here.
    """
    # Setup TensorFlow config proto.
    cfg = _sanitize_tf_config(config_dict)
    config_proto = tf.compat.v1.ConfigProto()
    for key, value in cfg.items():
        fields = key.split(".")
        if fields[0] not in ["rnd", "env"]:
            # Walk nested proto messages, then set the leaf attribute.
            obj = config_proto
            for field in fields[:-1]:
                obj = getattr(obj, field)
            setattr(obj, fields[-1], value)

    # Create session.
    session = tf.compat.v1.Session(config=config_proto)
    if force_as_default:
        # Make this session the default for the rest of the process without
        # requiring a "with" block: enter the as_default() context manager
        # manually and keep it alive on the session object. This relies on
        # TF-internal behavior (hence the protected-access pylint waiver).
        # pylint: disable=protected-access
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__() # pylint: disable=no-member

    return session
|
150 |
+
|
151 |
+
|
152 |
+
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
    """Initialize all tf.Variables that have not already been initialized.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tf.variables_initializer(tf.report_uninitialized_variables()).run()

    Args:
        target_vars: Variables to consider; defaults to all global variables.

    Fix: use the tf.compat.v1 namespace consistently with the rest of this
    module — tf.global_variables / tf.get_default_graph /
    tf.is_variable_initialized do not exist at the top level in TF 2.x.
    """
    assert_tf_initialized()
    if target_vars is None:
        target_vars = tf.compat.v1.global_variables()

    test_vars = []
    test_ops = []

    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
        for var in target_vars:
            assert is_tf_expression(var)

            try:
                # If an IsVariableInitialized op already exists for this
                # variable, it has been checked/initialized before; skip it.
                tf.compat.v1.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
            except KeyError:
                # Op does not exist => variable may be uninitialized.
                test_vars.append(var)

                with absolute_name_scope(var.name.split(":")[0]):
                    test_ops.append(tf.compat.v1.is_variable_initialized(var))

    # Run all the initialization checks in a single session call, then
    # initialize only the variables that reported uninitialized.
    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
    run([var.initializer for var in init_vars])
|
180 |
+
|
181 |
+
|
182 |
+
def set_vars(var_to_value_dict: dict) -> None:
    """Set the values of given tf.Variables.

    Equivalent to the following, but more efficient and does not bloat the tf graph:
    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]

    One assign op named "<var>/setter" (fed through a "<var>/new_value"
    placeholder) is created per variable and cached in the graph, so repeated
    calls reuse the same ops instead of growing the graph.
    """
    assert_tf_initialized()
    ops = []
    feed_dict = {}

    for var, value in var_to_value_dict.items():
        assert is_tf_expression(var)

        try:
            setter = tf.compat.v1.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
        except KeyError:
            with absolute_name_scope(var.name.split(":")[0]):
                with tf.control_dependencies(None): # ignore surrounding control_dependencies
                    setter = tf.compat.v1.assign(var, tf.compat.v1.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter

        ops.append(setter)
        # Input 1 of the assign op is the "new_value" placeholder; feed it.
        feed_dict[setter.op.inputs[1]] = value

    # Execute all assignments in a single session call.
    run(ops, feed_dict)
|
206 |
+
|
207 |
+
|
208 |
+
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
    """Create a tf.Variable holding *initial_value* without embedding the large
    constant into the graph: the variable is built from a zeros tensor and the
    real data is then fed in through set_vars()."""
    assert_tf_initialized()
    assert isinstance(initial_value, np.ndarray)
    dummy_init = tf.zeros(initial_value.shape, initial_value.dtype)
    variable = tf.Variable(dummy_init, *args, **kwargs)
    set_vars({variable: initial_value})
    return variable
|
216 |
+
|
217 |
+
|
218 |
+
def convert_images_from_uint8(images, drange=(-1, 1), nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
    Can be used as an input transformation for Network.run().

    Args:
        images: Tensor of uint8 pixel values in [0, 255].
        drange: Target (low, high) dynamic range; 0 maps to drange[0] and
            255 maps to drange[1]. (Fix: tuple default instead of a shared
            mutable list.)
        nhwc_to_nchw: If True, transpose from NHWC to NCHW layout first.

    Fix: the affine map was `(images - drange[0]) * span / 255`, which does
    not send 0 -> drange[0] / 255 -> drange[1] and is not the inverse of
    convert_images_to_uint8(). Corrected to `images * span / 255 + drange[0]`.
    """
    images = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        images = tf.transpose(images, [0, 3, 1, 2])
    return images * ((drange[1] - drange[0]) / 255) + drange[0]
|
226 |
+
|
227 |
+
|
228 |
+
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False, shrink=1, uint8_cast=True):
    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
    Can be used as an output transformation for Network.run().

    Args:
        images: NCHW float tensor.
        drange: (low, high) input dynamic range mapped onto [0, 255].
            (Fix: tuple default instead of a shared mutable list.)
        nchw_to_nhwc: If True, transpose to NHWC after optional shrinking.
        shrink: Integer downscale factor applied with NCHW average pooling.
        uint8_cast: If True, saturate-cast the result to uint8.
    """
    images = tf.cast(images, tf.float32)
    if shrink > 1:
        ksize = [1, 1, shrink, shrink]
        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        images = tf.transpose(images, [0, 2, 3, 1])
    scale = 255 / (drange[1] - drange[0])
    # The +0.5 bias rounds to nearest when the saturating cast truncates.
    images = images * scale + (0.5 - drange[0] * scale)
    if uint8_cast:
        images = tf.saturate_cast(images, tf.uint8)
    return images
|
dnnlib/util.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Miscellaneous utility classes and functions."""
|
9 |
+
|
10 |
+
import ctypes
|
11 |
+
import fnmatch
|
12 |
+
import importlib
|
13 |
+
import inspect
|
14 |
+
import numpy as np
|
15 |
+
import os
|
16 |
+
import shutil
|
17 |
+
import sys
|
18 |
+
import types
|
19 |
+
import io
|
20 |
+
import pickle
|
21 |
+
import re
|
22 |
+
import requests
|
23 |
+
import html
|
24 |
+
import hashlib
|
25 |
+
import glob
|
26 |
+
import uuid
|
27 |
+
|
28 |
+
from distutils.util import strtobool
|
29 |
+
from typing import Any, List, Tuple, Union
|
30 |
+
|
31 |
+
|
32 |
+
# Util classes
|
33 |
+
# ------------------------------------------------------------------------------------------
|
34 |
+
|
35 |
+
|
36 |
+
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails; map to dict lookup.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment is stored as a dict entry.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Deleting an attribute removes the dict entry.
        del self[name]
|
50 |
+
|
51 |
+
|
52 |
+
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        # file_name: optional path of a log file mirroring everything written.
        # file_mode: mode passed to open() for the log file (e.g. "w" / "a").
        # should_flush: flush stdout and the file after every write when True.
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Keep the real streams so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # NOTE: global side effect — from here on all stdout AND stderr
        # output in the process is routed through this object.
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: str) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        # (only restore the streams if this instance is still installed)
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
|
106 |
+
|
107 |
+
|
108 |
+
# Small util functions
|
109 |
+
# ------------------------------------------------------------------------------------------
|
110 |
+
|
111 |
+
|
112 |
+
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    total = int(np.rint(seconds))
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)

    # Show the two or three most significant units, largest first.
    if days > 0:
        return "{0}d {1:02}h {2:02}m".format(days, hours, minutes)
    if hours > 0:
        return "{0}h {1:02}m {2:02}s".format(hours, minutes, secs)
    if minutes > 0:
        return "{0}m {1:02}s".format(minutes, secs)
    return "{0}s".format(secs)
|
124 |
+
|
125 |
+
|
126 |
+
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    Accepts the same spellings as distutils.util.strtobool ("y"/"yes"/"t"/
    "true"/"on"/"1" and "n"/"no"/"f"/"false"/"off"/"0", case-insensitive)
    but returns a real bool as annotated and no longer depends on the
    deprecated distutils package (removed in Python 3.12).
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
        # Invalid input: ask again (mirrors the old ValueError retry loop).
|
134 |
+
|
135 |
+
|
136 |
+
def tuple_product(t: Tuple) -> Any:
    """Return the product of all elements of *t* (1 for an empty tuple)."""
    product = 1
    for element in t:
        product = product * element
    return product
|
144 |
+
|
145 |
+
|
146 |
+
# Lookup table: dtype name string -> ctypes type of identical byte width.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    # Resolve the type name from whatever form the caller passed in.
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]
    # Sanity check: the two representations must occupy the same byte width.
    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
    return my_dtype, my_ctype
|
181 |
+
|
182 |
+
|
183 |
+
def is_pickleable(obj: Any) -> bool:
    """Return True if *obj* can be serialized with pickle, False otherwise.

    Fix: the bare `except:` also swallowed KeyboardInterrupt / SystemExit;
    narrowed to Exception so only real serialization failures report False.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        return False
|
190 |
+
|
191 |
+
|
192 |
+
# Functionality to import modules/objects by name, and call functions by name
|
193 |
+
# ------------------------------------------------------------------------------------------
|
194 |
+
|
195 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    # Fix: the dots are escaped — the old patterns "^np." / "^tf." treated
    # '.' as "any character", so e.g. "npx.foo" was also rewritten.
    obj_name = re.sub(r"^np\.", "numpy.", obj_name)
    obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name), longest module first
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)    # may raise AttributeError
            return module, local_obj_name
        except (ImportError, AttributeError):  # fix: was a bare except
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)  # may raise ImportError
        except ImportError:
            # Re-raise unless it is just this module that is missing.
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)    # AttributeError propagates to the caller
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj
|
244 |
+
|
245 |
+
|
246 |
+
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    # Resolve the containing module, then walk the remaining attribute path.
    containing_module, local_name = get_module_from_obj_name(name)
    return get_obj_from_module(containing_module, local_name)
|
250 |
+
|
251 |
+
|
252 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func = get_obj_by_name(func_name)
    assert callable(func)
    # Forward all positional and keyword arguments unchanged.
    return func(*args, **kwargs)
|
258 |
+
|
259 |
+
|
260 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    containing_module, _ = get_module_from_obj_name(obj_name)
    module_file = inspect.getfile(containing_module)
    return os.path.dirname(module_file)
|
264 |
+
|
265 |
+
|
266 |
+
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    if not callable(obj):
        return False
    # A top-level function's name appears in its defining module's namespace.
    return obj.__name__ in sys.modules[obj.__module__].__dict__
|
269 |
+
|
270 |
+
|
271 |
+
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    return "{0}.{1}".format(obj.__module__, obj.__name__)
|
275 |
+
|
276 |
+
|
277 |
+
# File system helpers
|
278 |
+
# ------------------------------------------------------------------------------------------
|
279 |
+
|
280 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = ignores if ignores is not None else []
    collected = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for pattern in patterns:
            # Prune ignored directories in place so os.walk never descends
            # into them (topdown=True makes in-place edits effective).
            for ignored_dir in [d for d in dirs if fnmatch.fnmatch(d, pattern)]:
                dirs.remove(ignored_dir)

            # Filter ignored files; each pattern narrows the remaining list.
            files = [f for f in files if not fnmatch.fnmatch(f, pattern)]

        for file_name in files:
            abs_path = os.path.join(root, file_name)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            collected.append((abs_path, rel_path))

    return collected
|
311 |
+
|
312 |
+
|
313 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for src_path, dst_path in files:
        dst_dir = os.path.dirname(dst_path)

        # os.makedirs creates all intermediate-level directories as needed.
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)

        shutil.copyfile(src_path, dst_path)
|
324 |
+
|
325 |
+
|
326 |
+
# URL helpers
|
327 |
+
# ------------------------------------------------------------------------------------------
|
328 |
+
|
329 |
+
def is_url(obj: Any) -> bool:
    """Determine whether the given object is a valid URL string.

    Improvement: uses the standard-library urllib.parse directly instead of
    requests.compat (whose urlparse/urljoin are aliases for the same stdlib
    functions), removing a third-party dependency; the bare `except:` is
    narrowed to Exception.
    """
    from urllib.parse import urlparse, urljoin  # stdlib; same functions requests.compat re-exports

    if not isinstance(obj, str) or "://" not in obj:
        return False
    try:
        # Require a scheme and a dotted host both for the URL itself...
        res = urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        # ...and for its root, catching URLs that only parse by accident.
        res = urlparse(urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:  # malformed input can raise ValueError from urlparse
        return False
    return True
|
343 |
+
|
344 |
+
|
345 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    A plain file path is opened directly. Downloads are retried up to
    num_attempts times and may be cached in cache_dir keyed by the URL's MD5.
    Contains special handling for Google Drive interstitial pages.
    """
    if not is_url(url) and os.path.isfile(url):
        return open(url, 'rb')

    assert is_url(url)
    assert num_attempts >= 1

    # Lookup from cache. Cache entries are named "<md5(url)>_<safe_name>".
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache_dir is not None:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            return open(cache_files[0], "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial HTML
                    # rather than the real payload; inspect and retry.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Rebind url to the confirm link; the raised
                                # IOError triggers a retry against it.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive quota exceeded")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except:
                # Bare except: any failure consumes one attempt; the final
                # failure re-raises the original exception.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache: write to a unique temp file first, then atomically
    # rename so concurrent processes never see a partial cache entry.
    if cache_dir is not None:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic

    # Return data as file object.
    return io.BytesIO(url_data)
|
encode_images.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import pickle
|
4 |
+
from tqdm import tqdm
|
5 |
+
import PIL.Image
|
6 |
+
from PIL import ImageFilter
|
7 |
+
import numpy as np
|
8 |
+
import dnnlib
|
9 |
+
import dnnlib.tflib as tflib
|
10 |
+
import config
|
11 |
+
from encoder.generator_model import Generator
|
12 |
+
from encoder.perceptual_model import PerceptualModel, load_images
|
13 |
+
#from tensorflow.keras.models import load_model
|
14 |
+
from keras.models import load_model
|
15 |
+
from keras.applications.resnet50 import preprocess_input
|
16 |
+
|
17 |
+
def split_to_batches(l, n):
    """Yield successive n-sized chunks of sequence *l*; the final chunk may be shorter."""
    for start in range(0, len(l), n):
        yield l[start:start + n]
|
20 |
+
|
21 |
+
def str2bool(v):
    """argparse type helper: parse common true/false spellings into a bool."""
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
30 |
+
|
31 |
+
def main():
|
32 |
+
parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
33 |
+
parser.add_argument('src_dir', help='Directory with images for encoding')
|
34 |
+
parser.add_argument('generated_images_dir', help='Directory for storing generated images')
|
35 |
+
parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations')
|
36 |
+
parser.add_argument('--data_dir', default='data', help='Directory for storing optional models')
|
37 |
+
parser.add_argument('--mask_dir', default='masks', help='Directory for storing optional masks')
|
38 |
+
parser.add_argument('--load_last', default='', help='Start with embeddings from directory')
|
39 |
+
parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
|
40 |
+
parser.add_argument('--model_url', default='./data/karras2019stylegan-ffhq-1024x1024.pkl', help='Fetch a StyleGAN model to train on from this URL')
|
41 |
+
parser.add_argument('--architecture', default='./data/vgg16_zhang_perceptual.pkl', help='Сonvolutional neural network model from this URL')
|
42 |
+
parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
|
43 |
+
parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
|
44 |
+
parser.add_argument('--optimizer', default='ggt', help='Optimization algorithm used for optimizing dlatents')
|
45 |
+
|
46 |
+
# Perceptual model params
|
47 |
+
parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
|
48 |
+
parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int)
|
49 |
+
parser.add_argument('--lr', default=0.25, help='Learning rate for perceptual model', type=float)
|
50 |
+
parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float)
|
51 |
+
parser.add_argument('--iterations', default=100, help='Number of optimization steps for each batch', type=int)
|
52 |
+
parser.add_argument('--decay_steps', default=4, help='Decay steps for learning rate decay (as a percent of iterations)', type=float)
|
53 |
+
parser.add_argument('--early_stopping', default=True, help='Stop early once training stabilizes', type=str2bool, nargs='?', const=True)
|
54 |
+
parser.add_argument('--early_stopping_threshold', default=0.5, help='Stop after this threshold has been reached', type=float)
|
55 |
+
parser.add_argument('--early_stopping_patience', default=10, help='Number of iterations to wait below threshold', type=int)
|
56 |
+
parser.add_argument('--load_effnet', default='data/finetuned_effnet.h5', help='Model to load for EfficientNet approximation of dlatents')
|
57 |
+
parser.add_argument('--load_resnet', default='data/finetuned_resnet.h5', help='Model to load for ResNet approximation of dlatents')
|
58 |
+
parser.add_argument('--use_preprocess_input', default=True, help='Call process_input() first before using feed forward net', type=str2bool, nargs='?', const=True)
|
59 |
+
parser.add_argument('--use_best_loss', default=True, help='Output the lowest loss value found as the solution', type=str2bool, nargs='?', const=True)
|
60 |
+
parser.add_argument('--average_best_loss', default=0.25, help='Do a running weighted average with the previous best dlatents found', type=float)
|
61 |
+
parser.add_argument('--sharpen_input', default=True, help='Sharpen the input images', type=str2bool, nargs='?', const=True)
|
62 |
+
|
63 |
+
# Loss function options
|
64 |
+
parser.add_argument('--use_vgg_loss', default=0.4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.', type=float)
|
65 |
+
parser.add_argument('--use_vgg_layer', default=9, help='Pick which VGG layer to use.', type=int)
|
66 |
+
parser.add_argument('--use_pixel_loss', default=1.5, help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float)
|
67 |
+
parser.add_argument('--use_mssim_loss', default=200, help='Use MS-SIM perceptual loss; 0 to disable, > 0 to scale.', type=float)
|
68 |
+
parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.', type=float)
|
69 |
+
parser.add_argument('--use_l1_penalty', default=0.5, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.', type=float)
|
70 |
+
parser.add_argument('--use_discriminator_loss', default=0.5, help='Use trained discriminator to evaluate realism.', type=float)
|
71 |
+
parser.add_argument('--use_adaptive_loss', default=False, help='Use the adaptive robust loss function from Google Research for pixel and VGG feature loss.', type=str2bool, nargs='?', const=True)
|
72 |
+
|
73 |
+
# Generator params
|
74 |
+
parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=str2bool, nargs='?', const=True)
|
75 |
+
parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=str2bool, nargs='?', const=True)
|
76 |
+
parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of gradient values outside of this threshold', type=float)
|
77 |
+
|
78 |
+
# Masking params
|
79 |
+
parser.add_argument('--load_mask', default=False, help='Load segmentation masks', type=str2bool, nargs='?', const=True)
|
80 |
+
parser.add_argument('--face_mask', default=True, help='Generate a mask for predicting only the face area', type=str2bool, nargs='?', const=True)
|
81 |
+
parser.add_argument('--use_grabcut', default=True, help='Use grabcut algorithm on the face mask to better segment the foreground', type=str2bool, nargs='?', const=True)
|
82 |
+
parser.add_argument('--scale_mask', default=1.4, help='Look over a wider section of foreground for grabcut', type=float)
|
83 |
+
parser.add_argument('--composite_mask', default=True, help='Merge the unmasked area back into the generated image', type=str2bool, nargs='?', const=True)
|
84 |
+
parser.add_argument('--composite_blur', default=8, help='Size of blur filter to smoothly composite the images', type=int)
|
85 |
+
|
86 |
+
# Video params
|
87 |
+
parser.add_argument('--video_dir', default='videos', help='Directory for storing training videos')
|
88 |
+
parser.add_argument('--output_video', default=False, help='Generate videos of the optimization process', type=bool)
|
89 |
+
parser.add_argument('--video_codec', default='MJPG', help='FOURCC-supported video codec name')
|
90 |
+
parser.add_argument('--video_frame_rate', default=24, help='Video frames per second', type=int)
|
91 |
+
parser.add_argument('--video_size', default=512, help='Video size in pixels', type=int)
|
92 |
+
parser.add_argument('--video_skip', default=1, help='Only write every n frames (1 = write every frame)', type=int)
|
93 |
+
|
94 |
+
args, other_args = parser.parse_known_args()
|
95 |
+
|
96 |
+
args.decay_steps *= 0.01 * args.iterations # Calculate steps as a percent of total iterations
|
97 |
+
|
98 |
+
if args.output_video:
|
99 |
+
import cv2
|
100 |
+
synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=False), minibatch_size=args.batch_size)
|
101 |
+
|
102 |
+
ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)]
|
103 |
+
ref_images = list(filter(os.path.isfile, ref_images))
|
104 |
+
|
105 |
+
if len(ref_images) == 0:
|
106 |
+
raise Exception('%s is empty' % args.src_dir)
|
107 |
+
|
108 |
+
os.makedirs(args.data_dir, exist_ok=True)
|
109 |
+
os.makedirs(args.mask_dir, exist_ok=True)
|
110 |
+
os.makedirs(args.generated_images_dir, exist_ok=True)
|
111 |
+
os.makedirs(args.dlatent_dir, exist_ok=True)
|
112 |
+
os.makedirs(args.video_dir, exist_ok=True)
|
113 |
+
|
114 |
+
# Initialize generator and perceptual model
|
115 |
+
tflib.init_tf()
|
116 |
+
with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
|
117 |
+
generator_network, discriminator_network, Gs_network = pickle.load(f)
|
118 |
+
|
119 |
+
generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise)
|
120 |
+
if (args.dlatent_avg != ''):
|
121 |
+
generator.set_dlatent_avg(np.load(args.dlatent_avg))
|
122 |
+
|
123 |
+
perc_model = None
|
124 |
+
if (args.use_lpips_loss > 0.00000001):
|
125 |
+
with dnnlib.util.open_url(args.architecture, cache_dir=config.cache_dir) as f:
|
126 |
+
perc_model = pickle.load(f)
|
127 |
+
perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)
|
128 |
+
perceptual_model.build_perceptual_model(generator, discriminator_network)
|
129 |
+
|
130 |
+
ff_model = None
|
131 |
+
|
132 |
+
# Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
|
133 |
+
for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=len(ref_images)//args.batch_size):
|
134 |
+
names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch]
|
135 |
+
if args.output_video:
|
136 |
+
video_out = {}
|
137 |
+
for name in names:
|
138 |
+
video_out[name] = cv2.VideoWriter(os.path.join(args.video_dir, f'{name}.avi'),cv2.VideoWriter_fourcc(*args.video_codec), args.video_frame_rate, (args.video_size,args.video_size))
|
139 |
+
|
140 |
+
perceptual_model.set_reference_images(images_batch)
|
141 |
+
dlatents = None
|
142 |
+
if (args.load_last != ''): # load previous dlatents for initialization
|
143 |
+
for name in names:
|
144 |
+
dl = np.expand_dims(np.load(os.path.join(args.load_last, f'{name}.npy')),axis=0)
|
145 |
+
if (dlatents is None):
|
146 |
+
dlatents = dl
|
147 |
+
else:
|
148 |
+
dlatents = np.vstack((dlatents,dl))
|
149 |
+
else:
|
150 |
+
if (ff_model is None):
|
151 |
+
if os.path.exists(args.load_resnet):
|
152 |
+
from keras.applications.resnet50 import preprocess_input
|
153 |
+
print("Loading ResNet Model:")
|
154 |
+
ff_model = load_model(args.load_resnet)
|
155 |
+
if (ff_model is None):
|
156 |
+
if os.path.exists(args.load_effnet):
|
157 |
+
import efficientnet
|
158 |
+
from efficientnet import preprocess_input
|
159 |
+
print("Loading EfficientNet Model:")
|
160 |
+
ff_model = load_model(args.load_effnet)
|
161 |
+
if (ff_model is not None): # predict initial dlatents with ResNet model
|
162 |
+
if (args.use_preprocess_input):
|
163 |
+
dlatents = ff_model.predict(preprocess_input(load_images(images_batch,image_size=args.resnet_image_size)))
|
164 |
+
else:
|
165 |
+
dlatents = ff_model.predict(load_images(images_batch,image_size=args.resnet_image_size))
|
166 |
+
if dlatents is not None:
|
167 |
+
generator.set_dlatents(dlatents)
|
168 |
+
op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations, use_optimizer=args.optimizer)
|
169 |
+
pbar = tqdm(op, leave=False, total=args.iterations)
|
170 |
+
vid_count = 0
|
171 |
+
best_loss = None
|
172 |
+
best_dlatent = None
|
173 |
+
avg_loss_count = 0
|
174 |
+
if args.early_stopping:
|
175 |
+
avg_loss = prev_loss = None
|
176 |
+
for loss_dict in pbar:
|
177 |
+
if args.early_stopping: # early stopping feature
|
178 |
+
if prev_loss is not None:
|
179 |
+
if avg_loss is not None:
|
180 |
+
avg_loss = 0.5 * avg_loss + (prev_loss - loss_dict["loss"])
|
181 |
+
if avg_loss < args.early_stopping_threshold: # count while under threshold; else reset
|
182 |
+
avg_loss_count += 1
|
183 |
+
else:
|
184 |
+
avg_loss_count = 0
|
185 |
+
if avg_loss_count > args.early_stopping_patience: # stop once threshold is reached
|
186 |
+
print("")
|
187 |
+
break
|
188 |
+
else:
|
189 |
+
avg_loss = prev_loss - loss_dict["loss"]
|
190 |
+
pbar.set_description(" ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
|
191 |
+
if best_loss is None or loss_dict["loss"] < best_loss:
|
192 |
+
if best_dlatent is None or args.average_best_loss <= 0.00000001:
|
193 |
+
best_dlatent = generator.get_dlatents()
|
194 |
+
else:
|
195 |
+
best_dlatent = 0.25 * best_dlatent + 0.75 * generator.get_dlatents()
|
196 |
+
if args.use_best_loss:
|
197 |
+
generator.set_dlatents(best_dlatent)
|
198 |
+
best_loss = loss_dict["loss"]
|
199 |
+
if args.output_video and (vid_count % args.video_skip == 0):
|
200 |
+
batch_frames = generator.generate_images()
|
201 |
+
for i, name in enumerate(names):
|
202 |
+
video_frame = PIL.Image.fromarray(batch_frames[i], 'RGB').resize((args.video_size,args.video_size),PIL.Image.LANCZOS)
|
203 |
+
video_out[name].write(cv2.cvtColor(np.array(video_frame).astype('uint8'), cv2.COLOR_RGB2BGR))
|
204 |
+
generator.stochastic_clip_dlatents()
|
205 |
+
prev_loss = loss_dict["loss"]
|
206 |
+
if not args.use_best_loss:
|
207 |
+
best_loss = prev_loss
|
208 |
+
print(" ".join(names), " Loss {:.4f}".format(best_loss))
|
209 |
+
|
210 |
+
if args.output_video:
|
211 |
+
for name in names:
|
212 |
+
video_out[name].release()
|
213 |
+
|
214 |
+
# Generate images from found dlatents and save them
|
215 |
+
if args.use_best_loss:
|
216 |
+
generator.set_dlatents(best_dlatent)
|
217 |
+
generated_images = generator.generate_images()
|
218 |
+
generated_dlatents = generator.get_dlatents()
|
219 |
+
for img_array, dlatent, img_path, img_name in zip(generated_images, generated_dlatents, images_batch, names):
|
220 |
+
mask_img = None
|
221 |
+
if args.composite_mask and (args.load_mask or args.face_mask):
|
222 |
+
_, im_name = os.path.split(img_path)
|
223 |
+
mask_img = os.path.join(args.mask_dir, f'{im_name}')
|
224 |
+
if args.composite_mask and mask_img is not None and os.path.isfile(mask_img):
|
225 |
+
orig_img = PIL.Image.open(img_path).convert('RGB')
|
226 |
+
width, height = orig_img.size
|
227 |
+
imask = PIL.Image.open(mask_img).convert('L').resize((width, height))
|
228 |
+
imask = imask.filter(ImageFilter.GaussianBlur(args.composite_blur))
|
229 |
+
mask = np.array(imask)/255
|
230 |
+
mask = np.expand_dims(mask,axis=-1)
|
231 |
+
img_array = mask*np.array(img_array) + (1.0-mask)*np.array(orig_img)
|
232 |
+
img_array = img_array.astype(np.uint8)
|
233 |
+
#img_array = np.where(mask, np.array(img_array), orig_img)
|
234 |
+
img = PIL.Image.fromarray(img_array, 'RGB')
|
235 |
+
img.save(os.path.join(args.generated_images_dir, f'{img_name}.png'), 'PNG')
|
236 |
+
np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent)
|
237 |
+
|
238 |
+
generator.reset_dlatents()
|
239 |
+
|
240 |
+
|
241 |
+
if __name__ == "__main__":
|
242 |
+
main()
|
encoder/__init__.py
ADDED
File without changes
|
encoder/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (113 Bytes). View file
|
|
encoder/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (117 Bytes). View file
|
|
encoder/__pycache__/generator_model.cpython-36.pyc
ADDED
Binary file (5.09 kB). View file
|
|
encoder/__pycache__/generator_model.cpython-37.pyc
ADDED
Binary file (5.1 kB). View file
|
|
encoder/__pycache__/perceptual_model.cpython-36.pyc
ADDED
Binary file (10.1 kB). View file
|
|
encoder/__pycache__/perceptual_model.cpython-37.pyc
ADDED
Binary file (10 kB). View file
|
|
encoder/generator_model.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import tensorflow as tf
|
3 |
+
import numpy as np
|
4 |
+
import dnnlib.tflib as tflib
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
|
8 |
+
def create_stub(name, batch_size):
    """Return a zero-width float32 constant used as a dummy second input.

    The synthesis network's custom-input hook calls this with a variable
    name; `name` is accepted only to satisfy that callback signature.
    """
    stub_shape = (batch_size, 0)
    return tf.constant(0, shape=stub_shape, dtype='float32')
|
10 |
+
|
11 |
+
|
12 |
+
def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size = 1):
    """Create the trainable 'learnable_dlatents' variable fed to the synthesis net.

    In non-tiled mode the variable has one 512-vector per scale layer
    (batch, model_scale, 512). In tiled mode a compact
    (batch, tile_size, 512) variable is created and repeated along the
    layer axis to reach model_scale rows.
    """
    if not tiled_dlatent:
        return tf.get_variable('learnable_dlatents',
                               shape=(batch_size, model_scale, 512),
                               dtype='float32',
                               initializer=tf.initializers.random_normal())
    # Tiled mode: optimize a small latent and tile it up to the full scale.
    compact_dlatent = tf.get_variable('learnable_dlatents',
                                      shape=(batch_size, tile_size, 512),
                                      dtype='float32',
                                      initializer=tf.initializers.random_normal())
    repeat_count = model_scale // tile_size
    return tf.tile(compact_dlatent, [1, repeat_count, 1])
|
24 |
+
|
25 |
+
|
26 |
+
class Generator:
    """Wraps a StyleGAN Gs network so its dlatent input is a trainable TF
    variable ('learnable_dlatents') that an optimizer can update, and exposes
    helpers to read/write that variable and render images from it.
    """

    def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False):
        self.batch_size = batch_size
        self.tiled_dlatent=tiled_dlatent
        # Number of dlatent rows: two per resolution doubling above 4x4.
        self.model_scale = int(2*(math.log(model_res,2)-1)) # For example, 1024 -> 18

        # Running synthesis once with custom_inputs builds the graph with the
        # learnable variable wired in as the dlatent source.
        if tiled_dlatent:
            self.initial_dlatents = np.zeros((self.batch_size, 512))
            model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)),
                randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True),
                               partial(create_stub, batch_size=batch_size)],
                structure='fixed')
        else:
            self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512))
            if custom_input is not None:
                # NOTE(review): custom_input.eval() is passed as the callable for
                # partial(); presumably custom_input is expected to yield a
                # usable input factory here — confirm against callers.
                model.components.synthesis.run(self.initial_dlatents,
                    randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                    custom_inputs=[partial(custom_input.eval(), batch_size=batch_size), partial(create_stub, batch_size=batch_size)],
                    structure='fixed')
            else:
                model.components.synthesis.run(self.initial_dlatents,
                    randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                    custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale),
                                   partial(create_stub, batch_size=batch_size)],
                    structure='fixed')

        self.dlatent_avg_def = model.get_var('dlatent_avg')
        self.reset_dlatent_avg()
        self.sess = tf.compat.v1.get_default_session()
        self.graph = tf.compat.v1.get_default_graph()

        # Grab the variable created inside synthesis.run() above.
        self.dlatent_variable = next(v for v in tf.compat.v1.global_variables() if 'learnable_dlatents' in v.name)
        self._assign_dlatent_ph = tf.compat.v1.placeholder(tf.float32, name="assign_dlatent_ph")
        # NOTE: attribute name '_assign_dlantent' is misspelled in the original
        # code; kept as-is because other methods reference it.
        self._assign_dlantent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph)
        self.set_dlatents(self.initial_dlatents)

        def get_tensor(name):
            # Return the named tensor, or None if it does not exist in the graph.
            try:
                return self.graph.get_tensor_by_name(name)
            except KeyError:
                return None

        # The output tensor's name varies with TF version and with whether
        # G/D were loaded alongside Gs, so probe the known candidates in order.
        self.generator_output = get_tensor('G_synthesis_1/_Run/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis_1/_Run/concat/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis_1/_Run/concat_1/concat:0')
        # If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist in the graph.
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat_1/concat:0')
        if self.generator_output is None:
            # Dump all ops to aid debugging before failing hard.
            for op in self.graph.get_operations():
                print(op)
            raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output")
        self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

        # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782
        # (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper,
        # so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.)
        clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold)
        clipped_values = tf.where(clipping_mask, tf.random.normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
        self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)

    def reset_dlatents(self):
        """Reset the learnable dlatents back to the all-zero initial value."""
        self.set_dlatents(self.initial_dlatents)

    def set_dlatents(self, dlatents):
        """Assign `dlatents` to the learnable variable, padding/averaging to fit.

        Tiled mode expects (batch, 512); non-tiled expects
        (batch, model_scale, 512). Smaller batches are zero-padded.
        """
        if self.tiled_dlatent:
            # If given per-layer dlatents, collapse them to a single vector.
            if (dlatents.shape != (self.batch_size, 512)) and (dlatents.shape[1] != 512):
                dlatents = np.mean(dlatents, axis=1)
            if (dlatents.shape != (self.batch_size, 512)):
                # Pad a short batch with zero rows.
                dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], 512))])
            assert (dlatents.shape == (self.batch_size, 512))
        else:
            # Truncate extra layers if the input has more rows than the model.
            if (dlatents.shape[1] > self.model_scale):
                dlatents = dlatents[:,:self.model_scale,:]
            if (isinstance(dlatents.shape[0], int)):
                if (dlatents.shape != (self.batch_size, self.model_scale, 512)):
                    dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], self.model_scale, 512))])
                assert (dlatents.shape == (self.batch_size, self.model_scale, 512))
                self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents})
                return
            else:
                # Symbolic first dimension (e.g. a tensor input): build an assign
                # op directly from the tensor.
                # NOTE(review): this op is created but never sess.run() here, and it
                # permanently replaces the placeholder-based assign op used by the
                # other branches — verify this branch is actually exercised.
                self._assign_dlantent = tf.assign(self.dlatent_variable, dlatents)
                return
        # Tiled-mode path falls through to here and runs the assign.
        self.sess.run([self._assign_dlantent], {self._assign_dlatent_ph: dlatents})

    def stochastic_clip_dlatents(self):
        """Re-sample any dlatent components outside [-threshold, threshold]."""
        self.sess.run(self.stochastic_clip_op)

    def get_dlatents(self):
        """Return the current value of the learnable dlatents as a numpy array."""
        return self.sess.run(self.dlatent_variable)

    def get_dlatent_avg(self):
        """Return the current dlatent average (used for L1 regularization)."""
        return self.dlatent_avg

    def set_dlatent_avg(self, dlatent_avg):
        """Override the dlatent average with an externally supplied value."""
        self.dlatent_avg = dlatent_avg

    def reset_dlatent_avg(self):
        """Restore the dlatent average to the value stored in the model."""
        self.dlatent_avg = self.dlatent_avg_def

    def generate_images(self, dlatents=None):
        """Render uint8 images from the current (or given) dlatents."""
        if dlatents is not None:
            self.set_dlatents(dlatents)
        return self.sess.run(self.generated_image_uint8)
|
encoder/perceptual_model.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
2 |
+
import tensorflow as tf
|
3 |
+
#import tensorflow_probability as tfp
|
4 |
+
#tf.enable_eager_execution()
|
5 |
+
|
6 |
+
import os
|
7 |
+
import bz2
|
8 |
+
import PIL.Image
|
9 |
+
from PIL import ImageFilter
|
10 |
+
import numpy as np
|
11 |
+
from keras.models import Model
|
12 |
+
from keras.utils import get_file
|
13 |
+
from keras.applications.vgg16 import VGG16, preprocess_input
|
14 |
+
import keras.backend as K
|
15 |
+
import traceback
|
16 |
+
import dnnlib.tflib as tflib
|
17 |
+
|
18 |
+
def load_images(images_list, image_size=256, sharpen=False):
    """Load image files into a single stacked (N, H, W, 3) array.

    Each image is converted to RGB; when `image_size` is given, it is resized
    square with Lanczos resampling and optionally sharpened with the DETAIL
    filter before conversion to a numpy array.
    """
    batch = []
    for path in images_list:
        image = PIL.Image.open(path).convert('RGB')
        if image_size is not None:
            image = image.resize((image_size, image_size), PIL.Image.LANCZOS)
            if sharpen:
                image = image.filter(ImageFilter.DETAIL)
        batch.append(np.expand_dims(np.array(image), 0))
    return np.vstack(batch)
|
31 |
+
|
32 |
+
def tf_custom_adaptive_loss(a,b):
    """Mean robust adaptive loss (from the `adaptive` package) on b - a,
    with both tensors flattened to (batch, features) first.
    """
    from adaptive import lossfun
    feature_dim = np.prod(a.get_shape().as_list()[1:])
    flat_a = tf.reshape(a, [-1, feature_dim])
    flat_b = tf.reshape(b, [-1, feature_dim])
    loss, _, _ = lossfun(flat_b - flat_a, var_suffix='1')
    return tf.math.reduce_mean(loss)
|
40 |
+
|
41 |
+
def tf_custom_adaptive_rgb_loss(a,b):
    """Mean adaptive image loss (RGB pixel representation) on b - a."""
    from adaptive import image_lossfun
    loss, _, _ = image_lossfun(b-a, color_space='RGB', representation='PIXEL')
    mean_loss = tf.math.reduce_mean(loss)
    return mean_loss
|
45 |
+
|
46 |
+
def tf_custom_l1_loss(img1,img2):
    """Mean absolute error between the two tensors."""
    absolute_diff = tf.math.abs(img2 - img1)
    return tf.math.reduce_mean(absolute_diff, axis=None)
|
48 |
+
|
49 |
+
def tf_custom_logcosh_loss(img1,img2):
    """Mean log-cosh error between the two tensors."""
    per_element = tf.keras.losses.logcosh(img1, img2)
    return tf.math.reduce_mean(per_element)
|
51 |
+
|
52 |
+
def create_stub(batch_size):
    """Zero-width float32 constant standing in for the labels input."""
    stub_shape = (batch_size, 0)
    return tf.constant(0, shape=stub_shape, dtype='float32')
|
54 |
+
|
55 |
+
def unpack_bz2(src_path):
    """Decompress a bzip2 archive in place and return the decompressed path.

    Args:
        src_path: Path to a '.bz2' file.

    Returns:
        The path with the trailing '.bz2' stripped, now containing the
        decompressed bytes.
    """
    dst_path = src_path[:-4]  # strip the trailing '.bz2'
    # Use context managers so both handles are closed even on error
    # (the original code left the BZ2File handle open).
    with bz2.BZ2File(src_path) as src, open(dst_path, 'wb') as fp:
        fp.write(src.read())
    return dst_path
|
61 |
+
|
62 |
+
class PerceptualModel:
    """Builds the combined image-matching loss (VGG features, pixel, MS-SSIM,
    LPIPS, L1 dlatent penalty, discriminator realism) over a Generator's
    output, and exposes an iterative optimizer over the dlatent variable.
    """

    def __init__(self, args, batch_size=1, perc_model=None, sess=None):
        self.sess = tf.compat.v1.get_default_session() if sess is None else sess
        K.set_session(self.sess)
        # Loss weights at or below this epsilon are treated as disabled (None).
        self.epsilon = 0.00000001
        self.lr = args.lr
        self.decay_rate = args.decay_rate
        self.decay_steps = args.decay_steps
        self.img_size = args.image_size
        self.layer = args.use_vgg_layer
        self.vgg_loss = args.use_vgg_loss
        self.face_mask = args.face_mask
        self.use_grabcut = args.use_grabcut
        self.scale_mask = args.scale_mask
        self.mask_dir = args.mask_dir
        if (self.layer <= 0 or self.vgg_loss <= self.epsilon):
            self.vgg_loss = None
        self.pixel_loss = args.use_pixel_loss
        if (self.pixel_loss <= self.epsilon):
            self.pixel_loss = None
        self.mssim_loss = args.use_mssim_loss
        if (self.mssim_loss <= self.epsilon):
            self.mssim_loss = None
        self.lpips_loss = args.use_lpips_loss
        if (self.lpips_loss <= self.epsilon):
            self.lpips_loss = None
        self.l1_penalty = args.use_l1_penalty
        if (self.l1_penalty <= self.epsilon):
            self.l1_penalty = None
        self.adaptive_loss = args.use_adaptive_loss
        self.sharpen_input = args.sharpen_input
        self.batch_size = batch_size
        # LPIPS network is only kept if its loss weight is enabled.
        if perc_model is not None and self.lpips_loss is not None:
            self.perc_model = perc_model
        else:
            self.perc_model = None
        # Graph variables; created later in build_perceptual_model().
        self.ref_img = None
        self.ref_weight = None
        self.perceptual_model = None
        self.ref_img_features = None
        self.features_weight = None
        self.loss = None
        self.discriminator_loss = args.use_discriminator_loss
        if (self.discriminator_loss <= self.epsilon):
            self.discriminator_loss = None
        if self.discriminator_loss is not None:
            self.discriminator = None
            self.stub = create_stub(batch_size)

        if self.face_mask:
            # dlib is imported lazily so it is only required when face
            # masking is actually enabled.
            import dlib
            self.detector = dlib.get_frontal_face_detector()
            # Expects the landmarks archive to already exist in the working dir.
            landmarks_model_path = unpack_bz2('shape_predictor_68_face_landmarks.dat.bz2')
            self.predictor = dlib.shape_predictor(landmarks_model_path)

    def add_placeholder(self, var_name):
        """Create <var>_placeholder and <var>_op attributes so the TF variable
        named `var_name` can later be assigned via assign_placeholder()."""
        var_val = getattr(self, var_name)
        setattr(self, var_name + "_placeholder", tf.compat.v1.placeholder(var_val.dtype, shape=var_val.get_shape()))
        setattr(self, var_name + "_op", var_val.assign(getattr(self, var_name + "_placeholder")))

    def assign_placeholder(self, var_name, var_val):
        """Run the assign op created by add_placeholder() with `var_val`."""
        self.sess.run(getattr(self, var_name + "_op"), {getattr(self, var_name + "_placeholder"): var_val})

    def build_perceptual_model(self, generator, discriminator=None):
        """Construct the full loss graph over `generator.generated_image`.

        Must be called once before set_reference_images()/optimize().
        """
        # Learning rate
        global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
        # Each evaluation of the learning rate also advances the step counter.
        incremented_global_step = tf.compat.v1.assign_add(global_step, 1)
        self._reset_global_step = tf.assign(global_step, 0)
        self.learning_rate = tf.compat.v1.train.exponential_decay(self.lr, incremented_global_step,
                self.decay_steps, self.decay_rate, staircase=True)
        self.sess.run([self._reset_global_step])

        if self.discriminator_loss is not None:
            self.discriminator = discriminator

        generated_image_tensor = generator.generated_image
        generated_image = tf.compat.v1.image.resize_nearest_neighbor(generated_image_tensor,
                                                          (self.img_size, self.img_size), align_corners=True)

        # Reference image and its per-pixel weight mask, filled in later by
        # set_reference_images() through the placeholder-assign mechanism.
        self.ref_img = tf.get_variable('ref_img', shape=generated_image.shape,
                                                dtype='float32', initializer=tf.initializers.zeros())
        self.ref_weight = tf.get_variable('ref_weight', shape=generated_image.shape,
                                               dtype='float32', initializer=tf.initializers.zeros())
        self.add_placeholder("ref_img")
        self.add_placeholder("ref_weight")

        if (self.vgg_loss is not None):
            vgg16 = VGG16(include_top=False, weights='vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', input_shape=(self.img_size, self.img_size, 3)) # https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
            # Truncate VGG16 at the configured layer to get a feature extractor.
            self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output)
            generated_img_features = self.perceptual_model(preprocess_input(self.ref_weight * generated_image))
            self.ref_img_features = tf.get_variable('ref_img_features', shape=generated_img_features.shape,
                                                dtype='float32', initializer=tf.initializers.zeros())
            self.features_weight = tf.get_variable('features_weight', shape=generated_img_features.shape,
                                               dtype='float32', initializer=tf.initializers.zeros())
            # NOTE(review): this runs features_weight.initializer twice; it
            # likely intended [self.ref_img_features.initializer,
            # self.features_weight.initializer] — confirm and fix upstream.
            self.sess.run([self.features_weight.initializer, self.features_weight.initializer])
            self.add_placeholder("ref_img_features")
            self.add_placeholder("features_weight")

        if self.perc_model is not None and self.lpips_loss is not None:
            # LPIPS network expects NCHW float images.
            img1 = tflib.convert_images_from_uint8(self.ref_weight * self.ref_img, nhwc_to_nchw=True)
            img2 = tflib.convert_images_from_uint8(self.ref_weight * generated_image, nhwc_to_nchw=True)

        # Total loss is the weighted sum of all enabled terms.
        self.loss = 0
        # L1 loss on VGG16 features
        if (self.vgg_loss is not None):
            if self.adaptive_loss:
                self.loss += self.vgg_loss * tf_custom_adaptive_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
            else:
                self.loss += self.vgg_loss * tf_custom_logcosh_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
        # + logcosh loss on image pixels
        if (self.pixel_loss is not None):
            if self.adaptive_loss:
                self.loss += self.pixel_loss * tf_custom_adaptive_rgb_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
            else:
                self.loss += self.pixel_loss * tf_custom_logcosh_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
        # + MS-SIM loss on image pixels
        if (self.mssim_loss is not None):
            self.loss += self.mssim_loss * tf.math.reduce_mean(1-tf.image.ssim_multiscale(self.ref_weight * self.ref_img, self.ref_weight * generated_image, 1))
        # + extra perceptual loss on image pixels
        if self.perc_model is not None and self.lpips_loss is not None:
            self.loss += self.lpips_loss * tf.math.reduce_mean(self.perc_model.get_output_for(img1, img2))
        # + L1 penalty on dlatent weights
        if self.l1_penalty is not None:
            self.loss += self.l1_penalty * 512 * tf.math.reduce_mean(tf.math.abs(generator.dlatent_variable-generator.get_dlatent_avg()))
        # discriminator loss (realism)
        if self.discriminator_loss is not None:
            self.loss += self.discriminator_loss * tf.math.reduce_mean(self.discriminator.get_output_for(tflib.convert_images_from_uint8(generated_image_tensor, nhwc_to_nchw=True), self.stub))
        # - discriminator_network.get_output_for(tflib.convert_images_from_uint8(ref_img, nhwc_to_nchw=True), stub)


    def generate_face_mask(self, im):
        """Build a binary face mask for image `im` using dlib landmarks and,
        optionally, grabcut refinement.

        Returns the mask for the FIRST detected face (the loop returns on its
        first iteration); implicitly returns None when no face is detected.
        """
        from imutils import face_utils
        import cv2
        rects = self.detector(im, 1)
        # loop over the face detections
        for (j, rect) in enumerate(rects):
            """
            Determine the facial landmarks for the face region, then convert the facial landmark (x, y)-coordinates to a NumPy array
            """
            shape = self.predictor(im, rect)
            shape = face_utils.shape_to_np(shape)

            # we extract the face
            vertices = cv2.convexHull(shape)
            mask = np.zeros(im.shape[:2],np.uint8)
            cv2.fillConvexPoly(mask, vertices, 1)
            if self.use_grabcut:
                bgdModel = np.zeros((1,65),np.float64)
                fgdModel = np.zeros((1,65),np.float64)
                # NOTE(review): im.shape[2] is the channel count (3); a
                # width/height rect probably meant im.shape[0]. Harmless only
                # because GC_INIT_WITH_MASK ignores the rect — verify.
                rect = (0,0,im.shape[1],im.shape[2])
                (x,y),radius = cv2.minEnclosingCircle(vertices)
                center = (int(x),int(y))
                # Widen the probable-foreground circle by scale_mask.
                radius = int(radius*self.scale_mask)
                mask = cv2.circle(mask,center,radius,cv2.GC_PR_FGD,-1)
                cv2.fillConvexPoly(mask, vertices, cv2.GC_FGD)
                cv2.grabCut(im,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_MASK)
                # Collapse grabcut's 4-state output to a binary mask.
                mask = np.where((mask==2)|(mask==0),0,1)
            return mask

    def set_reference_images(self, images_list):
        """Load target images (and masks), pad to batch_size, and push them
        into the ref_img / ref_weight / feature variables."""
        assert(len(images_list) != 0 and len(images_list) <= self.batch_size)
        loaded_image = load_images(images_list, self.img_size, sharpen=self.sharpen_input)
        image_features = None
        if self.perceptual_model is not None:
            image_features = self.perceptual_model.predict_on_batch(preprocess_input(np.array(loaded_image)))
            weight_mask = np.ones(self.features_weight.shape)

        if self.face_mask:
            image_mask = np.zeros(self.ref_weight.shape)
            for (i, im) in enumerate(loaded_image):
                try:
                    _, img_name = os.path.split(images_list[i])
                    mask_img = os.path.join(self.mask_dir, f'{img_name}')
                    if (os.path.isfile(mask_img)):
                        # Reuse a previously saved mask from disk.
                        print("Loading mask " + mask_img)
                        imask = PIL.Image.open(mask_img).convert('L')
                        mask = np.array(imask)/255
                        mask = np.expand_dims(mask,axis=-1)
                    else:
                        # Generate, save, and use a fresh face mask.
                        mask = self.generate_face_mask(im)
                        imask = (255*mask).astype('uint8')
                        imask = PIL.Image.fromarray(imask, 'L')
                        print("Saving mask " + mask_img)
                        imask.save(mask_img, 'PNG')
                        mask = np.expand_dims(mask,axis=-1)
                    # Broadcast the single-channel mask across RGB.
                    mask = np.ones(im.shape,np.float32) * mask
                except Exception as e:
                    # Best-effort: on any mask failure fall back to an
                    # all-ones (no masking) weight for this image.
                    print("Exception in mask handling for " + mask_img)
                    traceback.print_exc()
                    mask = np.ones(im.shape[:2],np.uint8)
                    mask = np.ones(im.shape,np.float32) * np.expand_dims(mask,axis=-1)
                image_mask[i] = mask
            img = None  # NOTE(review): unused assignment; appears to be leftover.
        else:
            image_mask = np.ones(self.ref_weight.shape)

        # Zero-pad a short final batch so shapes match the graph variables;
        # the zero weights exclude the padding from the loss.
        if len(images_list) != self.batch_size:
            if image_features is not None:
                features_space = list(self.features_weight.shape[1:])
                existing_features_shape = [len(images_list)] + features_space
                empty_features_shape = [self.batch_size - len(images_list)] + features_space
                existing_examples = np.ones(shape=existing_features_shape)
                empty_examples = np.zeros(shape=empty_features_shape)
                weight_mask = np.vstack([existing_examples, empty_examples])
                image_features = np.vstack([image_features, np.zeros(empty_features_shape)])

            images_space = list(self.ref_weight.shape[1:])
            existing_images_space = [len(images_list)] + images_space
            empty_images_space = [self.batch_size - len(images_list)] + images_space
            existing_images = np.ones(shape=existing_images_space)
            empty_images = np.zeros(shape=empty_images_space)
            image_mask = image_mask * np.vstack([existing_images, empty_images])
            loaded_image = np.vstack([loaded_image, np.zeros(empty_images_space)])

        if image_features is not None:
            self.assign_placeholder("features_weight", weight_mask)
            self.assign_placeholder("ref_img_features", image_features)
        self.assign_placeholder("ref_weight", image_mask)
        self.assign_placeholder("ref_img", loaded_image)

    def optimize(self, vars_to_optimize, iterations=200, use_optimizer='adam'):
        """Yield per-iteration loss dicts while minimizing self.loss.

        Generator function: yields {"loss": ...} (lbfgs) or
        {"loss": ..., "lr": ...} (adam/ggt) once per iteration.
        """
        vars_to_optimize = vars_to_optimize if isinstance(vars_to_optimize, list) else [vars_to_optimize]
        if use_optimizer == 'lbfgs':
            # NOTE(review): with maxiter=iterations AND an outer loop of
            # `iterations` calls to minimize(), lbfgs may do far more work
            # than the other optimizers — confirm intended.
            optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.loss, var_list=vars_to_optimize, method='L-BFGS-B', options={'maxiter': iterations})
        else:
            if use_optimizer == 'ggt':
                optimizer = tf.contrib.opt.GGTOptimizer(learning_rate=self.learning_rate)
            else:
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            min_op = optimizer.minimize(self.loss, var_list=[vars_to_optimize])
            # Initialize the optimizer's slot variables (moments etc.).
            self.sess.run(tf.variables_initializer(optimizer.variables()))
            fetch_ops = [min_op, self.loss, self.learning_rate]
        #min_op = optimizer.minimize(self.sess)
        #optim_results = tfp.optimizer.lbfgs_minimize(make_val_and_grad_fn(get_loss), initial_position=vars_to_optimize, num_correction_pairs=10, tolerance=1e-8)
        self.sess.run(self._reset_global_step)
        #self.sess.graph.finalize() # Graph is read-only after this statement.
        for _ in range(iterations):
            if use_optimizer == 'lbfgs':
                optimizer.minimize(self.sess, fetches=[vars_to_optimize, self.loss])
                yield {"loss":self.loss.eval()}
            else:
                _, loss, lr = self.sess.run(fetch_ops)
                yield {"loss":loss,"lr":lr}
|
ffhq_dataset/__init__.py
ADDED
File without changes
|
ffhq_dataset/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (118 Bytes). View file
|
|
ffhq_dataset/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (122 Bytes). View file
|
|
ffhq_dataset/__pycache__/face_alignment.cpython-36.pyc
ADDED
Binary file (3.17 kB). View file
|
|
ffhq_dataset/__pycache__/face_alignment.cpython-37.pyc
ADDED
Binary file (3.17 kB). View file
|
|
ffhq_dataset/__pycache__/landmarks_detector.cpython-36.pyc
ADDED
Binary file (1.16 kB). View file
|
|