MHesho committed
Commit
db534ca
Parent: fb542c6

Added base files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
Files changed (50):
  1. Demo.ipynb +0 -0
  2. align_images.py +57 -0
  3. config.py +22 -0
  4. dnnlib/__init__.py +20 -0
  5. dnnlib/__pycache__/__init__.cpython-36.pyc +0 -0
  6. dnnlib/__pycache__/__init__.cpython-37.pyc +0 -0
  7. dnnlib/__pycache__/util.cpython-36.pyc +0 -0
  8. dnnlib/__pycache__/util.cpython-37.pyc +0 -0
  9. dnnlib/submission/__init__.py +9 -0
  10. dnnlib/submission/__pycache__/__init__.cpython-36.pyc +0 -0
  11. dnnlib/submission/__pycache__/__init__.cpython-37.pyc +0 -0
  12. dnnlib/submission/__pycache__/run_context.cpython-36.pyc +0 -0
  13. dnnlib/submission/__pycache__/run_context.cpython-37.pyc +0 -0
  14. dnnlib/submission/__pycache__/submit.cpython-36.pyc +0 -0
  15. dnnlib/submission/__pycache__/submit.cpython-37.pyc +0 -0
  16. dnnlib/submission/_internal/run.py +45 -0
  17. dnnlib/submission/run_context.py +99 -0
  18. dnnlib/submission/submit.py +290 -0
  19. dnnlib/tflib/__init__.py +16 -0
  20. dnnlib/tflib/__pycache__/__init__.cpython-36.pyc +0 -0
  21. dnnlib/tflib/__pycache__/__init__.cpython-37.pyc +0 -0
  22. dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc +0 -0
  23. dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc +0 -0
  24. dnnlib/tflib/__pycache__/network.cpython-36.pyc +0 -0
  25. dnnlib/tflib/__pycache__/network.cpython-37.pyc +0 -0
  26. dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc +0 -0
  27. dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc +0 -0
  28. dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc +0 -0
  29. dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc +0 -0
  30. dnnlib/tflib/autosummary.py +184 -0
  31. dnnlib/tflib/network.py +628 -0
  32. dnnlib/tflib/optimizer.py +214 -0
  33. dnnlib/tflib/tfutil.py +242 -0
  34. dnnlib/util.py +408 -0
  35. encode_images.py +242 -0
  36. encoder/__init__.py +0 -0
  37. encoder/__pycache__/__init__.cpython-36.pyc +0 -0
  38. encoder/__pycache__/__init__.cpython-37.pyc +0 -0
  39. encoder/__pycache__/generator_model.cpython-36.pyc +0 -0
  40. encoder/__pycache__/generator_model.cpython-37.pyc +0 -0
  41. encoder/__pycache__/perceptual_model.cpython-36.pyc +0 -0
  42. encoder/__pycache__/perceptual_model.cpython-37.pyc +0 -0
  43. encoder/generator_model.py +137 -0
  44. encoder/perceptual_model.py +304 -0
  45. ffhq_dataset/__init__.py +0 -0
  46. ffhq_dataset/__pycache__/__init__.cpython-36.pyc +0 -0
  47. ffhq_dataset/__pycache__/__init__.cpython-37.pyc +0 -0
  48. ffhq_dataset/__pycache__/face_alignment.cpython-36.pyc +0 -0
  49. ffhq_dataset/__pycache__/face_alignment.cpython-37.pyc +0 -0
  50. ffhq_dataset/__pycache__/landmarks_detector.cpython-36.pyc +0 -0
Demo.ipynb ADDED
The diff for this file is too large to render.
 
align_images.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import sys
+ import bz2
+ import argparse
+ from keras.utils import get_file
+ from ffhq_dataset.face_alignment import image_align
+ from ffhq_dataset.landmarks_detector import LandmarksDetector
+ import multiprocessing
+
+
+ def unpack_bz2(src_path):
+     data = bz2.BZ2File(src_path).read()
+     dst_path = src_path[:-4]  # strip the .bz2 extension
+     with open(dst_path, 'wb') as fp:
+         fp.write(data)
+     return dst_path
+
+
+ if __name__ == "__main__":
+     """
+     Extracts and aligns all faces from images using DLib and a function from the original FFHQ dataset preparation step:
+     python align_images.py /raw_images /aligned_images
+     """
+     parser = argparse.ArgumentParser(description='Align faces from input images', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument('raw_dir', help='Directory with raw images for face alignment')
+     parser.add_argument('aligned_dir', help='Directory for storing aligned images')
+     parser.add_argument('--output_size', default=1024, help='The dimension of images for input to the model', type=int)
+     parser.add_argument('--x_scale', default=1, help='Scaling factor for x dimension', type=float)
+     parser.add_argument('--y_scale', default=1, help='Scaling factor for y dimension', type=float)
+     parser.add_argument('--em_scale', default=0.1, help='Scaling factor for eye-mouth distance', type=float)
+     parser.add_argument('--use_alpha', action='store_true', help='Add an alpha channel for masking')  # was type=bool, which treats any non-empty string as True
+
+     args, other_args = parser.parse_known_args()
+
+     landmarks_model_path = unpack_bz2("shape_predictor_68_face_landmarks.dat.bz2")
+     RAW_IMAGES_DIR = args.raw_dir
+     ALIGNED_IMAGES_DIR = args.aligned_dir
+
+     landmarks_detector = LandmarksDetector(landmarks_model_path)
+     for img_name in os.listdir(RAW_IMAGES_DIR):
+         print('Aligning %s ...' % img_name)
+         try:
+             raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
+             fn = face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], 1)
+             if os.path.isfile(os.path.join(ALIGNED_IMAGES_DIR, fn)):  # skip images whose first face was already aligned
+                 continue
+             print('Getting landmarks...')
+             for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
+                 try:
+                     print('Starting face alignment...')
+                     face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
+                     aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
+                     image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=args.output_size, x_scale=args.x_scale, y_scale=args.y_scale, em_scale=args.em_scale, alpha=args.use_alpha)
+                     print('Wrote result %s' % aligned_face_path)
+                 except Exception:
+                     print("Exception in face alignment!")
+         except Exception:
+             print("Exception in landmark detection!")
config.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ """Global configuration."""
+
+ #----------------------------------------------------------------------------
+ # Paths.
+
+ result_dir = 'results'
+ data_dir = 'datasets'
+ cache_dir = 'cache'
+ run_dir_ignore = ['results', 'datasets', 'cache']
+
+ # experimental - replace Dense layers with TreeConnect
+ use_treeconnect = False
+ treeconnect_threshold = 1024
+
+ #----------------------------------------------------------------------------
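
These module-level constants are read directly by the scripts in this repo. A minimal sketch of consuming them (the directory-creation step is illustrative, not part of config.py itself):

    import os
    import config

    # make sure the standard working directories exist before a run
    for d in (config.result_dir, config.data_dir, config.cache_dir):
        os.makedirs(d, exist_ok=True)
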
dnnlib/__init__.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ from . import submission
+
+ from .submission.run_context import RunContext
+
+ from .submission.submit import SubmitTarget
+ from .submission.submit import PathType
+ from .submission.submit import SubmitConfig
+ from .submission.submit import get_path_from_template
+ from .submission.submit import submit_run
+
+ from .util import EasyDict
+
+ submit_config: SubmitConfig = None  # Package-level variable for SubmitConfig which is only valid when inside the run function.
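
The package-level dnnlib.submit_config is assigned by run_wrapper() for the duration of a run, so code inside the run function can read its own run configuration globally. A hedged sketch (my_run_func and its wiring are hypothetical):

    import dnnlib

    def my_run_func(submit_config: dnnlib.SubmitConfig):
        # inside a run, the package-level variable mirrors the argument
        assert dnnlib.submit_config is submit_config
        print("run dir:", submit_config.run_dir)
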
dnnlib/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (468 Bytes).

dnnlib/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (498 Bytes).

dnnlib/__pycache__/util.cpython-36.pyc ADDED
Binary file (12.1 kB).

dnnlib/__pycache__/util.cpython-37.pyc ADDED
Binary file (12.1 kB).
 
dnnlib/submission/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ from . import run_context
+ from . import submit
dnnlib/submission/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (188 Bytes).

dnnlib/submission/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (192 Bytes).

dnnlib/submission/__pycache__/run_context.cpython-36.pyc ADDED
Binary file (4.35 kB).

dnnlib/submission/__pycache__/run_context.cpython-37.pyc ADDED
Binary file (4.35 kB).

dnnlib/submission/__pycache__/submit.cpython-36.pyc ADDED
Binary file (9.19 kB).

dnnlib/submission/__pycache__/submit.cpython-37.pyc ADDED
Binary file (9.19 kB).
 
dnnlib/submission/_internal/run.py ADDED
@@ -0,0 +1,45 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ """Helper for launching run functions in computing clusters.
+
+ During the submit process, this file is copied to the appropriate run dir.
+ When the job is launched in the cluster, this module is the first thing that
+ is run inside the docker container.
+ """
+
+ import os
+ import pickle
+ import sys
+
+ # PYTHONPATH should have been set so that run_dir/src is in it
+ import dnnlib
+
+
+ def main():
+     if len(sys.argv) < 4:
+         raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")
+
+     run_dir = str(sys.argv[1])
+     task_name = str(sys.argv[2])
+     host_name = str(sys.argv[3])
+
+     submit_config_path = os.path.join(run_dir, "submit_config.pkl")
+
+     # SubmitConfig should have been pickled to the run dir
+     if not os.path.exists(submit_config_path):
+         raise RuntimeError("SubmitConfig pickle file does not exist!")
+
+     with open(submit_config_path, "rb") as f:
+         submit_config: dnnlib.SubmitConfig = pickle.load(f)
+     dnnlib.submission.submit.set_user_name_override(submit_config.user_name)
+
+     submit_config.task_name = task_name
+     submit_config.host_name = host_name
+
+     dnnlib.submission.submit.run_wrapper(submit_config)
+
+
+ if __name__ == "__main__":
+     main()
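
run.py depends on submit.py having pickled the SubmitConfig into the run dir. A minimal sketch of the round trip it expects (the run dir path is illustrative only):

    import os
    import pickle
    import dnnlib

    cfg = dnnlib.SubmitConfig()
    cfg.run_desc = "example"
    run_dir = "results/00000-example"  # illustrative path
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, "submit_config.pkl"), "wb") as f:
        pickle.dump(cfg, f)
    # the cluster side then launches: python run.py <run_dir> <task_name> <host_name>
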
dnnlib/submission/run_context.py ADDED
@@ -0,0 +1,99 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ """Helpers for managing the run/training loop."""
+
+ import datetime
+ import json
+ import os
+ import pprint
+ import time
+ import types
+
+ from typing import Any
+
+ from . import submit
+
+
+ class RunContext(object):
+     """Helper class for managing the run/training loop.
+
+     The context hides the implementation details of a basic run/training loop.
+     It sets things up properly, tells whether the run should be stopped, and cleans up afterwards.
+     The user should call update() periodically and use should_stop() to determine whether the run should be stopped.
+
+     Args:
+         submit_config: The SubmitConfig that is used for the current run.
+         config_module: The whole config module that is used for the current run.
+         max_epoch: Optional cached value for the max_epoch variable used in update.
+     """
+
+     def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
+         self.submit_config = submit_config
+         self.should_stop_flag = False
+         self.has_closed = False
+         self.start_time = time.time()
+         self.last_update_time = time.time()
+         self.last_update_interval = 0.0
+         self.max_epoch = max_epoch
+
+         # pretty print all the relevant content of the config module to a text file
+         if config_module is not None:
+             with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
+                 filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
+                 pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)
+
+         # write out details about the run to a text file
+         self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
+         with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
+             pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
+
+     def __enter__(self) -> "RunContext":
+         return self
+
+     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+         self.close()
+
+     def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
+         """Do general housekeeping and keep the state of the context up-to-date.
+         Should be called often enough but not in a tight loop."""
+         assert not self.has_closed
+
+         self.last_update_interval = time.time() - self.last_update_time
+         self.last_update_time = time.time()
+
+         # dropping an abort.txt file into the run dir requests a graceful stop
+         if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
+             self.should_stop_flag = True
+
+         max_epoch_val = self.max_epoch if max_epoch is None else max_epoch  # currently unused; kept for parity with the original dnnlib API
+
+     def should_stop(self) -> bool:
+         """Tell whether a stopping condition has been triggered one way or another."""
+         return self.should_stop_flag
+
+     def get_time_since_start(self) -> float:
+         """How much time has passed since the creation of the context."""
+         return time.time() - self.start_time
+
+     def get_time_since_last_update(self) -> float:
+         """How much time has passed since the last call to update."""
+         return time.time() - self.last_update_time
+
+     def get_last_update_interval(self) -> float:
+         """How much time passed between the previous two calls to update."""
+         return self.last_update_interval
+
+     def close(self) -> None:
+         """Close the context and clean up.
+         Should only be called once."""
+         if not self.has_closed:
+             # update run.txt with the stopping time
+             self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
+             with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
+                 pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
+
+             self.has_closed = True
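
Typical usage wraps the training loop in a RunContext so that dropping an abort.txt file into the run dir stops the run cleanly. A minimal sketch (train_one_epoch is hypothetical):

    import dnnlib

    def my_run_func(submit_config: dnnlib.SubmitConfig):
        with dnnlib.RunContext(submit_config, max_epoch=100) as ctx:
            for epoch in range(100):
                if ctx.should_stop():
                    break
                loss = train_one_epoch()  # hypothetical training step
                ctx.update(loss=loss, cur_epoch=epoch, max_epoch=100)
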
dnnlib/submission/submit.py ADDED
@@ -0,0 +1,290 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ """Submit a function to be run either locally or in a computing cluster."""
+
+ import copy
+ import io
+ import os
+ import pathlib
+ import pickle
+ import platform
+ import pprint
+ import re
+ import shutil
+ import time
+ import traceback
+
+ import zipfile
+
+ from enum import Enum
+
+ from .. import util
+ from ..util import EasyDict
+
+
+ class SubmitTarget(Enum):
+     """The target where the function should be run.
+
+     LOCAL: Run it locally.
+     """
+     LOCAL = 1
+
+
+ class PathType(Enum):
+     """Determines how a path should be formatted.
+
+     WINDOWS: Format with Windows style.
+     LINUX: Format with Linux/Posix style.
+     AUTO: Use current OS type to select either WINDOWS or LINUX.
+     """
+     WINDOWS = 1
+     LINUX = 2
+     AUTO = 3
+
+
+ _user_name_override = None
+
+
+ class SubmitConfig(util.EasyDict):
+     """Strongly typed config dict needed to submit runs.
+
+     Attributes:
+         run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
+         run_desc: Description of the run. Will be used in the run dir and task name.
+         run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
+         run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
+         submit_target: Submit target enum value. Used to select where the run is actually launched.
+         num_gpus: Number of GPUs used/requested for the run.
+         print_info: Whether to print debug information when submitting.
+         ask_confirmation: Whether to ask a confirmation before submitting.
+         run_id: Automatically populated value during submit.
+         run_name: Automatically populated value during submit.
+         run_dir: Automatically populated value during submit.
+         run_func_name: Automatically populated value during submit.
+         run_func_kwargs: Automatically populated value during submit.
+         user_name: Automatically populated value during submit. Can be set by the user, which will then override the automatic value.
+         task_name: Automatically populated value during submit.
+         host_name: Automatically populated value during submit.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+         # run (set these)
+         self.run_dir_root = ""  # should always be passed through get_path_from_template
+         self.run_desc = ""
+         self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
+         self.run_dir_extra_files = None
+
+         # submit (set these)
+         self.submit_target = SubmitTarget.LOCAL
+         self.num_gpus = 1
+         self.print_info = False
+         self.ask_confirmation = False
+
+         # (automatically populated)
+         self.run_id = None
+         self.run_name = None
+         self.run_dir = None
+         self.run_func_name = None
+         self.run_func_kwargs = None
+         self.user_name = None
+         self.task_name = None
+         self.host_name = "localhost"
+
+
+ def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
+     """Replace tags in the given path template and return either a Windows or Linux formatted path."""
+     # automatically select path type depending on running OS
+     if path_type == PathType.AUTO:
+         if platform.system() == "Windows":
+             path_type = PathType.WINDOWS
+         elif platform.system() == "Linux":
+             path_type = PathType.LINUX
+         else:
+             raise RuntimeError("Unknown platform")
+
+     path_template = path_template.replace("<USERNAME>", get_user_name())
+
+     # return correctly formatted path
+     if path_type == PathType.WINDOWS:
+         return str(pathlib.PureWindowsPath(path_template))
+     elif path_type == PathType.LINUX:
+         return str(pathlib.PurePosixPath(path_template))
+     else:
+         raise RuntimeError("Unknown platform")
+
+
+ def get_template_from_path(path: str) -> str:
+     """Convert a normal path back to its template representation."""
+     # normalize path separators (no template tags to restore in this version)
+     path = path.replace("\\", "/")
+     return path
+
+
+ def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
+     """Convert a normal path to a template and then convert it back to a normal path with the given path type."""
+     path_template = get_template_from_path(path)
+     path = get_path_from_template(path_template, path_type)
+     return path
+
+
+ def set_user_name_override(name: str) -> None:
+     """Set the global username override value."""
+     global _user_name_override
+     _user_name_override = name
+
+
+ def get_user_name():
+     """Get the current user name."""
+     if _user_name_override is not None:
+         return _user_name_override
+     elif platform.system() == "Windows":
+         return os.getlogin()
+     elif platform.system() == "Linux":
+         try:
+             import pwd  # pylint: disable=import-error
+             return pwd.getpwuid(os.geteuid()).pw_name  # pylint: disable=no-member
+         except Exception:
+             return "unknown"
+     else:
+         raise RuntimeError("Unknown platform")
+
+
+ def _create_run_dir_local(submit_config: SubmitConfig) -> str:
+     """Create a new run dir with an increasing ID number at the start."""
+     run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)
+
+     if not os.path.exists(run_dir_root):
+         print("Creating the run dir root: {}".format(run_dir_root))
+         os.makedirs(run_dir_root)
+
+     submit_config.run_id = _get_next_run_id_local(run_dir_root)
+     submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
+     run_dir = os.path.join(run_dir_root, submit_config.run_name)
+
+     if os.path.exists(run_dir):
+         raise RuntimeError("The run dir already exists! ({0})".format(run_dir))
+
+     print("Creating the run dir: {}".format(run_dir))
+     os.makedirs(run_dir)
+
+     return run_dir
+
+
+ def _get_next_run_id_local(run_dir_root: str) -> int:
+     """Read all directory names in a given directory (non-recursive) and return the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
+     dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
+     r = re.compile("^\\d+")  # match one or more digits at the start of the string
+     run_id = 0
+
+     for dir_name in dir_names:
+         m = r.match(dir_name)
+
+         if m is not None:
+             i = int(m.group())
+             run_id = max(run_id, i + 1)
+
+     return run_id
+
+
+ def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
+     """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
+     print("Copying files to the run dir")
+     files = []
+
+     run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
+     assert '.' in submit_config.run_func_name
+     for _idx in range(submit_config.run_func_name.count('.') - 1):
+         run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
+     files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)
+
+     dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
+     files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)
+
+     if submit_config.run_dir_extra_files is not None:
+         files += submit_config.run_dir_extra_files
+
+     files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
+     files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]
+
+     util.copy_files_and_create_dirs(files)
+
+     with open(os.path.join(run_dir, "submit_config.pkl"), "wb") as f:
+         pickle.dump(submit_config, f)
+
+     with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
+         pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
+
+
+ def run_wrapper(submit_config: SubmitConfig) -> None:
+     """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
+     is_local = submit_config.submit_target == SubmitTarget.LOCAL
+
+     checker = None
+
+     # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
+     if is_local:
+         logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
+     else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
+         logger = util.Logger(file_name=None, should_flush=True)
+
+     import dnnlib
+     dnnlib.submit_config = submit_config
+
+     try:
+         print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
+         start_time = time.time()
+         util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
+         print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
+     except:
+         if is_local:
+             raise
+         else:
+             traceback.print_exc()
+
+             log_src = os.path.join(submit_config.run_dir, "log.txt")
+             log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
+             shutil.copyfile(log_src, log_dst)
+     finally:
+         open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()
+
+     dnnlib.submit_config = None
+     logger.close()
+
+     if checker is not None:
+         checker.stop()
+
+
+ def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
+     """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in the appropriate place."""
+     submit_config = copy.copy(submit_config)
+
+     if submit_config.user_name is None:
+         submit_config.user_name = get_user_name()
+
+     submit_config.run_func_name = run_func_name
+     submit_config.run_func_kwargs = run_func_kwargs
+
+     # only local submission is supported in this version
+     assert submit_config.submit_target == SubmitTarget.LOCAL
+     if submit_config.submit_target in {SubmitTarget.LOCAL}:
+         run_dir = _create_run_dir_local(submit_config)
+
+         submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
+         submit_config.run_dir = run_dir
+         _populate_run_dir(run_dir, submit_config)
+
+     if submit_config.print_info:
+         print("\nSubmit config:\n")
+         pprint.pprint(submit_config, indent=4, width=200, compact=False)
+         print()
+
+     if submit_config.ask_confirmation:
+         if not util.ask_yes_no("Continue submitting the job?"):
+             return
+
+     run_wrapper(submit_config)
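
End to end, a local submission creates a numbered run dir under run_dir_root, snapshots the source tree into it, pickles the config, and calls the named function. A minimal sketch (the run function name and kwarg are hypothetical):

    import dnnlib

    submit_config = dnnlib.SubmitConfig()
    submit_config.run_dir_root = "results"
    submit_config.run_desc = "example"
    dnnlib.submit_run(submit_config, "training.loop.my_run_func", lr=0.001)  # hypothetical target
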
dnnlib/tflib/__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ from . import autosummary
+ from . import network
+ from . import optimizer
+ from . import tfutil
+
+ from .tfutil import *
+ from .network import Network
+
+ from .optimizer import Optimizer
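
The wildcard import re-exports the tfutil helpers, so callers typically initialize TensorFlow through this package before building any Network. A hedged sketch (init_tf lives in tfutil.py, which is not shown in this truncated view):

    import dnnlib.tflib as tflib

    tflib.init_tf()  # assumed tfutil helper that creates the default TF session
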
dnnlib/tflib/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (322 Bytes).

dnnlib/tflib/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (326 Bytes).

dnnlib/tflib/__pycache__/autosummary.cpython-36.pyc ADDED
Binary file (6.38 kB).

dnnlib/tflib/__pycache__/autosummary.cpython-37.pyc ADDED
Binary file (6.38 kB).

dnnlib/tflib/__pycache__/network.cpython-36.pyc ADDED
Binary file (31 kB).

dnnlib/tflib/__pycache__/network.cpython-37.pyc ADDED
Binary file (31 kB).

dnnlib/tflib/__pycache__/optimizer.cpython-36.pyc ADDED
Binary file (8.52 kB).

dnnlib/tflib/__pycache__/optimizer.cpython-37.pyc ADDED
Binary file (8.53 kB).

dnnlib/tflib/__pycache__/tfutil.cpython-36.pyc ADDED
Binary file (8.47 kB).

dnnlib/tflib/__pycache__/tfutil.cpython-37.pyc ADDED
Binary file (8.44 kB).
 
dnnlib/tflib/autosummary.py ADDED
@@ -0,0 +1,184 @@
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ #
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
+ # 4.0 International License. To view a copy of this license, visit
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+ """Helper for adding automatically tracked values to TensorBoard.
+
+ Autosummary creates an identity op that internally keeps track of the input
+ values and automatically shows up in TensorBoard. The reported value
+ represents an average over input components. The average is accumulated
+ constantly over time and flushed when save_summaries() is called.
+
+ Notes:
+ - The output tensor must be used as an input for something else in the
+   graph. Otherwise, the autosummary op will not get executed, and the average
+   value will not get accumulated.
+ - It is perfectly fine to include autosummaries with the same name in
+   several places throughout the graph, even if they are executed concurrently.
+ - It is ok to also pass in a python scalar or numpy array. In this case, it
+   is added to the average immediately.
+ """
+
+ from collections import OrderedDict
+ import numpy as np
+ import tensorflow as tf
+ from tensorboard import summary as summary_lib
+ from tensorboard.plugins.custom_scalar import layout_pb2
+
+ from . import tfutil
+ from .tfutil import TfExpression
+ from .tfutil import TfExpressionEx
+
+ _dtype = tf.float64
+ _vars = OrderedDict()  # name => [var, ...]
+ _immediate = OrderedDict()  # name => update_op, update_value
+ _finalized = False
+ _merge_op = None
+
+
+ def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
+     """Internal helper for creating autosummary accumulators."""
+     assert not _finalized
+     name_id = name.replace("/", "_")
+     v = tf.cast(value_expr, _dtype)
+
+     if v.shape.is_fully_defined():
+         size = np.prod(tfutil.shape_to_list(v.shape))
+         size_expr = tf.constant(size, dtype=_dtype)
+     else:
+         size = None
+         size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
+
+     if size == 1:
+         if v.shape.ndims != 0:
+             v = tf.reshape(v, [])
+         v = [size_expr, v, tf.square(v)]
+     else:
+         v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
+     v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
+
+     with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
+         var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
+         update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
+
+     if name in _vars:
+         _vars[name].append(var)
+     else:
+         _vars[name] = [var]
+     return update_op
+
+
+ def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
+     """Create a new autosummary.
+
+     Args:
+         name:     Name to use in TensorBoard
+         value:    TensorFlow expression or python value to track
+         passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
+
+     Example use of the passthru mechanism:
+
+         n = autosummary('l2loss', loss, passthru=n)
+
+     This is a shorthand for the following code:
+
+         with tf.control_dependencies([autosummary('l2loss', loss)]):
+             n = tf.identity(n)
+     """
+     tfutil.assert_tf_initialized()
+     name_id = name.replace("/", "_")
+
+     if tfutil.is_tf_expression(value):
+         with tf.name_scope("summary_" + name_id), tf.device(value.device):
+             update_op = _create_var(name, value)
+             with tf.control_dependencies([update_op]):
+                 return tf.identity(value if passthru is None else passthru)
+
+     else:  # python scalar or numpy array
+         if name not in _immediate:
+             with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
+                 update_value = tf.placeholder(_dtype)
+                 update_op = _create_var(name, update_value)
+                 _immediate[name] = update_op, update_value
+
+         update_op, update_value = _immediate[name]
+         tfutil.run(update_op, {update_value: value})
+         return value if passthru is None else passthru
+
+
+ def finalize_autosummaries():
+     """Create the necessary ops to include autosummaries in the TensorBoard report.
+     Note: This should be done only once per graph.
+     Returns the custom_scalar layout, or None if the summaries were already finalized.
+     """
+     global _finalized
+     tfutil.assert_tf_initialized()
+
+     if _finalized:
+         return None
+
+     _finalized = True
+     tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
+
+     # Create summary ops.
+     with tf.device(None), tf.control_dependencies(None):
+         for name, vars_list in _vars.items():
+             name_id = name.replace("/", "_")
+             with tfutil.absolute_name_scope("Autosummary/" + name_id):
+                 moments = tf.add_n(vars_list)
+                 moments /= moments[0]
+                 with tf.control_dependencies([moments]):  # read before resetting
+                     reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
+                     with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
+                         mean = moments[1]
+                         std = tf.sqrt(moments[2] - tf.square(moments[1]))
+                         tf.summary.scalar(name, mean)
+                         tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
+                         tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
+
+     # Group by category and chart name.
+     cat_dict = OrderedDict()
+     for series_name in sorted(_vars.keys()):
+         p = series_name.split("/")
+         cat = p[0] if len(p) >= 2 else ""
+         chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
+         if cat not in cat_dict:
+             cat_dict[cat] = OrderedDict()
+         if chart not in cat_dict[cat]:
+             cat_dict[cat][chart] = []
+         cat_dict[cat][chart].append(series_name)
+
+     # Setup custom_scalar layout.
+     categories = []
+     for cat_name, chart_dict in cat_dict.items():
+         charts = []
+         for chart_name, series_names in chart_dict.items():
+             series = []
+             for series_name in series_names:
+                 series.append(layout_pb2.MarginChartContent.Series(
+                     value=series_name,
+                     lower="xCustomScalars/" + series_name + "/margin_lo",
+                     upper="xCustomScalars/" + series_name + "/margin_hi"))
+             margin = layout_pb2.MarginChartContent(series=series)
+             charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
+         categories.append(layout_pb2.Category(title=cat_name, chart=charts))
+     layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
+     return layout
+
+
+ def save_summaries(file_writer, global_step=None):
+     """Call FileWriter.add_summary() with all summaries in the default graph,
+     automatically finalizing and merging them on the first call.
+     """
+     global _merge_op
+     tfutil.assert_tf_initialized()
+
+     if _merge_op is None:
+         layout = finalize_autosummaries()
+         if layout is not None:
+             file_writer.add_summary(layout)
+         with tf.device(None), tf.control_dependencies(None):
+             _merge_op = tf.summary.merge_all()
+
+     file_writer.add_summary(_merge_op.eval(), global_step)
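
In a training loop, the immediate (non-graph) form of autosummary() plus periodic save_summaries() calls is enough to get running averages into TensorBoard. A minimal sketch under that assumption (init_tf is assumed to come from tfutil.py, which is not shown in this view):

    import tensorflow as tf
    from dnnlib.tflib import autosummary, tfutil

    tfutil.init_tf()  # assumed helper; creates the default session needed by save_summaries()
    writer = tf.compat.v1.summary.FileWriter("logs")
    for step in range(1000):
        loss = 1.0 / (step + 1)  # stand-in for a real loss value
        autosummary.autosummary("Train/loss", loss)  # python scalar: accumulated immediately
        if step % 100 == 0:
            autosummary.save_summaries(writer, global_step=step)
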
dnnlib/tflib/network.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # This work is licensed under the Creative Commons Attribution-NonCommercial
4
+ # 4.0 International License. To view a copy of this license, visit
5
+ # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6
+ # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7
+
8
+ """Helper for managing networks."""
9
+
10
+ import types
11
+ import inspect
12
+ import re
13
+ import uuid
14
+ import sys
15
+ import numpy as np
16
+ import tensorflow as tf
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, List, Tuple, Union
20
+
21
+ from . import tfutil
22
+ from .. import util
23
+
24
+ from .tfutil import TfExpression, TfExpressionEx
25
+
26
+ _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
27
+ _import_module_src = dict() # Source code for temporary modules created during pickle import.
28
+
29
+
30
+ def import_handler(handler_func):
31
+ """Function decorator for declaring custom import handlers."""
32
+ _import_handlers.append(handler_func)
33
+ return handler_func
34
+
35
+
36
+ class Network:
37
+ """Generic network abstraction.
38
+
39
+ Acts as a convenience wrapper for a parameterized network construction
40
+ function, providing several utility methods and convenient access to
41
+ the inputs/outputs/weights.
42
+
43
+ Network objects can be safely pickled and unpickled for long-term
44
+ archival purposes. The pickling works reliably as long as the underlying
45
+ network construction function is defined in a standalone Python module
46
+ that has no side effects or application-specific imports.
47
+
48
+ Args:
49
+ name: Network name. Used to select TensorFlow name and variable scopes.
50
+ func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
51
+ static_kwargs: Keyword arguments to be passed in to the network construction function.
52
+
53
+ Attributes:
54
+ name: User-specified name, defaults to build func name if None.
55
+ scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
56
+ static_kwargs: Arguments passed to the user-supplied build func.
57
+ components: Container for sub-networks. Passed to the build func, and retained between calls.
58
+ num_inputs: Number of input tensors.
59
+ num_outputs: Number of output tensors.
60
+ input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
61
+ output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
62
+ input_shape: Short-hand for input_shapes[0].
63
+ output_shape: Short-hand for output_shapes[0].
64
+ input_templates: Input placeholders in the template graph.
65
+ output_templates: Output tensors in the template graph.
66
+ input_names: Name string for each input.
67
+ output_names: Name string for each output.
68
+ own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
69
+ vars: All variables (local_name => var).
70
+ trainables: All trainable variables (local_name => var).
71
+ var_global_to_local: Mapping from variable global names to local names.
72
+ """
73
+
74
+ def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
75
+ tfutil.assert_tf_initialized()
76
+ assert isinstance(name, str) or name is None
77
+ assert func_name is not None
78
+ assert isinstance(func_name, str) or util.is_top_level_function(func_name)
79
+ assert util.is_pickleable(static_kwargs)
80
+
81
+ self._init_fields()
82
+ self.name = name
83
+ self.static_kwargs = util.EasyDict(static_kwargs)
84
+
85
+ # Locate the user-specified network build function.
86
+ if util.is_top_level_function(func_name):
87
+ func_name = util.get_top_level_function_name(func_name)
88
+ module, self._build_func_name = util.get_module_from_obj_name(func_name)
89
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
90
+ assert callable(self._build_func)
91
+
92
+ # Dig up source code for the module containing the build function.
93
+ self._build_module_src = _import_module_src.get(module, None)
94
+ if self._build_module_src is None:
95
+ self._build_module_src = inspect.getsource(module)
96
+
97
+ # Init TensorFlow graph.
98
+ self._init_graph()
99
+ self.reset_own_vars()
100
+
101
+ def _init_fields(self) -> None:
102
+ self.name = None
103
+ self.scope = None
104
+ self.static_kwargs = util.EasyDict()
105
+ self.components = util.EasyDict()
106
+ self.num_inputs = 0
107
+ self.num_outputs = 0
108
+ self.input_shapes = [[]]
109
+ self.output_shapes = [[]]
110
+ self.input_shape = []
111
+ self.output_shape = []
112
+ self.input_templates = []
113
+ self.output_templates = []
114
+ self.input_names = []
115
+ self.output_names = []
116
+ self.own_vars = OrderedDict()
117
+ self.vars = OrderedDict()
118
+ self.trainables = OrderedDict()
119
+ self.var_global_to_local = OrderedDict()
120
+
121
+ self._build_func = None # User-supplied build function that constructs the network.
122
+ self._build_func_name = None # Name of the build function.
123
+ self._build_module_src = None # Full source code of the module containing the build function.
124
+ self._run_cache = dict() # Cached graph data for Network.run().
125
+
126
+ def _init_graph(self) -> None:
127
+ # Collect inputs.
128
+ self.input_names = []
129
+
130
+ for param in inspect.signature(self._build_func).parameters.values():
131
+ if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
132
+ self.input_names.append(param.name)
133
+
134
+ self.num_inputs = len(self.input_names)
135
+ assert self.num_inputs >= 1
136
+
137
+ # Choose name and scope.
138
+ if self.name is None:
139
+ self.name = self._build_func_name
140
+ assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
141
+ with tf.name_scope(None):
142
+ self.scope = tf.compat.v1.get_default_graph().unique_name(self.name, mark_as_used=True)
143
+
144
+ # Finalize build func kwargs.
145
+ build_kwargs = dict(self.static_kwargs)
146
+ build_kwargs["is_template_graph"] = True
147
+ build_kwargs["components"] = self.components
148
+
149
+ # Build template graph.
150
+ with tfutil.absolute_variable_scope(self.scope, reuse=tf.compat.v1.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
151
+ assert tf.compat.v1.get_variable_scope().name == self.scope
152
+ assert tf.compat.v1.get_default_graph().get_name_scope() == self.scope
153
+ with tf.control_dependencies(None): # ignore surrounding control dependencies
154
+ self.input_templates = [tf.compat.v1.placeholder(tf.float32, name=name) for name in self.input_names]
155
+ out_expr = self._build_func(*self.input_templates, **build_kwargs)
156
+
157
+ # Collect outputs.
158
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
159
+ self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
160
+ self.num_outputs = len(self.output_templates)
161
+ assert self.num_outputs >= 1
162
+ assert all(tfutil.is_tf_expression(t) for t in self.output_templates)
163
+
164
+ # Perform sanity checks.
165
+ if any(t.shape.ndims is None for t in self.input_templates):
166
+ raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
167
+ if any(t.shape.ndims is None for t in self.output_templates):
168
+ raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
169
+ if any(not isinstance(comp, Network) for comp in self.components.values()):
170
+ raise ValueError("Components of a Network must be Networks themselves.")
171
+ if len(self.components) != len(set(comp.name for comp in self.components.values())):
172
+ raise ValueError("Components of a Network must have unique names.")
173
+
174
+ # List inputs and outputs.
175
+ self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
176
+ self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
177
+ self.input_shape = self.input_shapes[0]
178
+ self.output_shape = self.output_shapes[0]
179
+ self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
180
+
181
+ # List variables.
182
+ self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.compat.v1.global_variables(self.scope + "/"))
183
+ self.vars = OrderedDict(self.own_vars)
184
+ self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
185
+ self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
186
+ self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
187
+
188
+ def reset_own_vars(self) -> None:
189
+ """Re-initialize all variables of this network, excluding sub-networks."""
190
+ tfutil.run([var.initializer for var in self.own_vars.values()])
191
+
192
+ def reset_vars(self) -> None:
193
+ """Re-initialize all variables of this network, including sub-networks."""
194
+ tfutil.run([var.initializer for var in self.vars.values()])
195
+
196
+ def reset_trainables(self) -> None:
197
+ """Re-initialize all trainable variables of this network, including sub-networks."""
198
+ tfutil.run([var.initializer for var in self.trainables.values()])
199
+
200
+ def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
201
+ """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
202
+ assert len(in_expr) == self.num_inputs
203
+ assert not all(expr is None for expr in in_expr)
204
+
205
+ # Finalize build func kwargs.
206
+ build_kwargs = dict(self.static_kwargs)
207
+ build_kwargs.update(dynamic_kwargs)
208
+ build_kwargs["is_template_graph"] = False
209
+ build_kwargs["components"] = self.components
210
+
211
+ # Build TensorFlow graph to evaluate the network.
212
+ with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
213
+ assert tf.compat.v1.get_variable_scope().name == self.scope
214
+ valid_inputs = [expr for expr in in_expr if expr is not None]
215
+ final_inputs = []
216
+ for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
217
+ if expr is not None:
218
+ expr = tf.identity(expr, name=name)
219
+ else:
220
+ expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
221
+ final_inputs.append(expr)
222
+ out_expr = self._build_func(*final_inputs, **build_kwargs)
223
+
224
+ # Propagate input shapes back to the user-specified expressions.
225
+ for expr, final in zip(in_expr, final_inputs):
226
+ if isinstance(expr, tf.Tensor):
227
+ expr.set_shape(final.shape)
228
+
229
+ # Express outputs in the desired format.
230
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
231
+ if return_as_list:
232
+ out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
233
+ return out_expr
234
+
235
+ def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
236
+ """Get the local name of a given variable, without any surrounding name scopes."""
237
+ assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
238
+ global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
239
+ return self.var_global_to_local[global_name]
240
+
241
+ def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
242
+ """Find variable by local or global name."""
243
+ assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
244
+ return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
245
+
246
+ def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
247
+ """Get the value of a given variable as NumPy array.
248
+ Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
249
+ return self.find_var(var_or_local_name).eval()
250
+
251
+ def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
252
+ """Set the value of a given variable based on the given NumPy array.
253
+ Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
254
+ tfutil.set_vars({self.find_var(var_or_local_name): new_value})
255
+
256
+ def __getstate__(self) -> dict:
257
+ """Pickle export."""
258
+ state = dict()
259
+ state["version"] = 3
260
+ state["name"] = self.name
261
+ state["static_kwargs"] = dict(self.static_kwargs)
262
+ state["components"] = dict(self.components)
263
+ state["build_module_src"] = self._build_module_src
264
+ state["build_func_name"] = self._build_func_name
265
+ state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
266
+ return state
267
+
268
+ def __setstate__(self, state: dict) -> None:
269
+ """Pickle import."""
270
+ # pylint: disable=attribute-defined-outside-init
271
+ tfutil.assert_tf_initialized()
272
+ self._init_fields()
273
+
274
+ # Execute custom import handlers.
275
+ for handler in _import_handlers:
276
+ state = handler(state)
277
+
278
+ # Set basic fields.
279
+ assert state["version"] in [2, 3]
280
+ self.name = state["name"]
281
+ self.static_kwargs = util.EasyDict(state["static_kwargs"])
282
+ self.components = util.EasyDict(state.get("components", {}))
283
+ self._build_module_src = state["build_module_src"]
284
+ self._build_func_name = state["build_func_name"]
285
+
286
+ # Create temporary module from the imported source code.
287
+ module_name = "_tflib_network_import_" + uuid.uuid4().hex
288
+ module = types.ModuleType(module_name)
289
+ sys.modules[module_name] = module
290
+ _import_module_src[module] = self._build_module_src
291
+ exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
292
+
293
+ # Locate network build function in the temporary module.
294
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
295
+ assert callable(self._build_func)
296
+
297
+ # Init TensorFlow graph.
298
+ self._init_graph()
299
+ self.reset_own_vars()
300
+ tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
301
+
302
+ def clone(self, name: str = None, **new_static_kwargs) -> "Network":
303
+ """Create a clone of this network with its own copy of the variables."""
304
+ # pylint: disable=protected-access
305
+ net = object.__new__(Network)
306
+ net._init_fields()
307
+ net.name = name if name is not None else self.name
308
+        net.static_kwargs = util.EasyDict(self.static_kwargs)
+        net.static_kwargs.update(new_static_kwargs)
+        net._build_module_src = self._build_module_src
+        net._build_func_name = self._build_func_name
+        net._build_func = self._build_func
+        net._init_graph()
+        net.copy_vars_from(self)
+        return net
+
+    def copy_own_vars_from(self, src_net: "Network") -> None:
+        """Copy the values of all variables from the given network, excluding sub-networks."""
+        names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
+        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
+
+    def copy_vars_from(self, src_net: "Network") -> None:
+        """Copy the values of all variables from the given network, including sub-networks."""
+        names = [name for name in self.vars.keys() if name in src_net.vars]
+        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
+
+    def copy_trainables_from(self, src_net: "Network") -> None:
+        """Copy the values of all trainable variables from the given network, including sub-networks."""
+        names = [name for name in self.trainables.keys() if name in src_net.trainables]
+        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
+
+    def copy_compatible_trainables_from(self, src_net: "Network") -> None:
+        """Copy the compatible values of all trainable variables from the given network, including sub-networks."""
+        names = []
+        for name in self.trainables.keys():
+            if name not in src_net.trainables:
+                print("Not restoring (not present): {}".format(name))
+            elif self.trainables[name].shape != src_net.trainables[name].shape:
+                print("Not restoring (different shape): {}".format(name))
+            else:
+                names.append(name)
+
+        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
+
+    def apply_swa(self, src_net, epoch):
+        """Perform stochastic weight averaging on the compatible values of all trainable variables from the given network, including sub-networks."""
+        names = []
+        for name in self.trainables.keys():
+            if name not in src_net.trainables:
+                print("Not restoring (not present): {}".format(name))
+            elif self.trainables[name].shape != src_net.trainables[name].shape:
+                print("Not restoring (different shape): {}".format(name))
+            else:
+                names.append(name)
+
+        scale_new_data = 1.0 / (epoch + 1)
+        scale_moving_average = 1.0 - scale_new_data
+        tfutil.set_vars(tfutil.run({self.vars[name]: (src_net.vars[name] * scale_new_data + self.vars[name] * scale_moving_average) for name in names}))
+
+    def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
+        """Create a new network with the given parameters, and copy all variables from this network."""
+        if new_name is None:
+            new_name = self.name
+        static_kwargs = dict(self.static_kwargs)
+        static_kwargs.update(new_static_kwargs)
+        net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
+        net.copy_vars_from(self)
+        return net
+
+    def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
+        """Construct a TensorFlow op that updates the variables of this network
+        to be slightly closer to those of the given network."""
+        with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
+            ops = []
+            for name, var in self.vars.items():
+                if name in src_net.vars:
+                    cur_beta = beta if name in self.trainables else beta_nontrainable
+                    new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
+                    ops.append(var.assign(new_value))
+            return tf.group(*ops)
+
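
For context, a hedged sketch of how these copy/averaging helpers are typically combined in a StyleGAN-style training loop. The build-function path and the `beta` value are illustrative assumptions, not taken from this file:

# Hypothetical sketch: keep Gs as an exponential moving average of G.
import dnnlib.tflib as tflib

tflib.init_tf()
G = tflib.Network("G", func_name="training.networks_stylegan.G_style")  # assumed build func
Gs = G.clone("Gs")                                  # same topology, independent variables
Gs_update_op = Gs.setup_as_moving_average_of(G, beta=0.99)

# ... inside the training loop, after each optimizer step:
tflib.run(Gs_update_op)  # each Gs var <- lerp(G_var, Gs_var, 0.99)

Note that apply_swa implements the running average w_avg <- (epoch * w_avg + w_new) / (epoch + 1), so calling it once per epoch weights every epoch's snapshot equally.
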
+    def run(self,
+            *in_arrays: Tuple[Union[np.ndarray, None], ...],
+            input_transform: dict = None,
+            output_transform: dict = None,
+            return_as_list: bool = False,
+            print_progress: bool = False,
+            minibatch_size: int = None,
+            num_gpus: int = 1,
+            assume_frozen: bool = False,
+            custom_inputs=None,
+            **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
+        """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
+
+        Args:
+            input_transform:  A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
+                              The dict must contain a 'func' field that points to a top-level function. The function is called with the input
+                              TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
+            output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
+                              The dict must contain a 'func' field that points to a top-level function. The function is called with the output
+                              TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
+            return_as_list:   True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
+            print_progress:   Print progress to the console? Useful for very large input arrays.
+            minibatch_size:   Maximum minibatch size to use, None = disable batching.
+            num_gpus:         Number of GPUs to use.
+            assume_frozen:    Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
+            custom_inputs:    Allow using custom TensorFlow expressions as inputs instead of the default placeholders.
+            dynamic_kwargs:   Additional keyword arguments to be passed into the network build function.
+        """
+        assert len(in_arrays) == self.num_inputs
+        assert not all(arr is None for arr in in_arrays)
+        assert input_transform is None or util.is_top_level_function(input_transform["func"])
+        assert output_transform is None or util.is_top_level_function(output_transform["func"])
+        output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
+        num_items = in_arrays[0].shape[0]
+        if minibatch_size is None:
+            minibatch_size = num_items
+
+        # Construct unique hash key from all arguments that affect the TensorFlow graph.
+        key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
+        def unwind_key(obj):
+            if isinstance(obj, dict):
+                return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
+            if callable(obj):
+                return util.get_top_level_function_name(obj)
+            return obj
+        key = repr(unwind_key(key))
+
+        # Build graph.
+        if key not in self._run_cache:
+            with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
+                if custom_inputs is not None:
+                    with tf.device("/gpu:0"):
+                        in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
+                        in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
+                else:
+                    with tf.device("/cpu:0"):
+                        in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
+                        in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
+
+                out_split = []
+                for gpu in range(num_gpus):
+                    with tf.device("/gpu:%d" % gpu):
+                        net_gpu = self.clone() if assume_frozen else self
+                        in_gpu = in_split[gpu]
+
+                        if input_transform is not None:
+                            in_kwargs = dict(input_transform)
+                            in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
+                            in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
+
+                        assert len(in_gpu) == self.num_inputs
+                        out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
+
+                        if output_transform is not None:
+                            out_kwargs = dict(output_transform)
+                            out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
+                            out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
+
+                        assert len(out_gpu) == self.num_outputs
+                        out_split.append(out_gpu)
+
+                with tf.device("/cpu:0"):
+                    out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
+                    self._run_cache[key] = in_expr, out_expr
+
+        # Run minibatches.
+        in_expr, out_expr = self._run_cache[key]
+        out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]
+
+        for mb_begin in range(0, num_items, minibatch_size):
+            if print_progress:
+                print("\r%d / %d" % (mb_begin, num_items), end="")
+
+            mb_end = min(mb_begin + minibatch_size, num_items)
+            mb_num = mb_end - mb_begin
+            mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
+            mb_out = tf.compat.v1.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
+
+            for dst, src in zip(out_arrays, mb_out):
+                dst[mb_begin: mb_end] = src
+
+        # Done.
+        if print_progress:
+            print("\r%d / %d" % (num_items, num_items))
+
+        if not return_as_list:
+            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
+        return out_arrays
+
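
To make the caching and minibatching behavior above concrete, a hedged usage sketch. It assumes a pre-trained StyleGAN generator `Gs` has already been unpickled, with input shapes matching the FFHQ 1024x1024 model:

# Sketch: evaluate the generator on a batch of latents, converting the raw
# float output to uint8 NHWC images via an output transform.
import numpy as np
import dnnlib.tflib as tflib

latents = np.random.randn(8, 512)   # Gs expects (latents, labels) as inputs
labels = np.zeros((8, 0))           # the FFHQ model is unconditional
images = Gs.run(latents, labels,
                minibatch_size=4,   # processed as two minibatches of 4
                output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
print(images.shape, images.dtype)   # e.g. (8, 1024, 1024, 3) uint8
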
+    def list_ops(self) -> List[TfExpression]:
+        include_prefix = self.scope + "/"
+        exclude_prefix = include_prefix + "_"
+        ops = tf.get_default_graph().get_operations()
+        ops = [op for op in ops if op.name.startswith(include_prefix)]
+        ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
+        return ops
+
+    def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
+        """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
+        individual layers of the network. Mainly intended to be used for reporting."""
+        layers = []
+
+        def recurse(scope, parent_ops, parent_vars, level):
+            # Ignore specific patterns.
+            if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
+                return
+
+            # Filter ops and vars by scope.
+            global_prefix = scope + "/"
+            local_prefix = global_prefix[len(self.scope) + 1:]
+            cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
+            cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
+            if not cur_ops and not cur_vars:
+                return
+
+            # Filter out all ops related to variables.
+            for var in [op for op in cur_ops if op.type.startswith("Variable")]:
+                var_prefix = var.name + "/"
+                cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
+
+            # Scope does not contain ops as immediate children => recurse deeper.
+            contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
+            if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
+                visited = set()
+                for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
+                    token = rel_name.split("/")[0]
+                    if token not in visited:
+                        recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
+                        visited.add(token)
+                return
+
+            # Report layer.
+            layer_name = scope[len(self.scope) + 1:]
+            layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
+            layer_trainables = [var for _name, var in cur_vars if var.trainable]
+            layers.append((layer_name, layer_output, layer_trainables))
+
+        recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
+        return layers
+
+    def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
+        """Print a summary table of the network structure."""
+        rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
+        rows += [["---"] * 4]
+        total_params = 0
+
+        for layer_name, layer_output, layer_trainables in self.list_layers():
+            num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
+            weights = [var for var in layer_trainables if var.name.endswith("/weight:0") or var.name.endswith("/weight_1:0")]
+            weights.sort(key=lambda x: len(x.name))
+            if len(weights) == 0 and len(layer_trainables) == 1:
+                weights = layer_trainables
+            total_params += num_params
+
+            if not hide_layers_with_no_params or num_params != 0:
+                num_params_str = str(num_params) if num_params > 0 else "-"
+                output_shape_str = str(layer_output.shape)
+                weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
+                rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
+
+        rows += [["---"] * 4]
+        rows += [["Total", str(total_params), "", ""]]
+
+        widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+        print()
+        for row in rows:
+            print("  ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
+        print()
+
+    def setup_weight_histograms(self, title: str = None) -> None:
+        """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
+        if title is None:
+            title = self.name
+
+        with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
+            for local_name, var in self.trainables.items():
+                if "/" in local_name:
+                    p = local_name.split("/")
+                    name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
+                else:
+                    name = title + "_toplevel/" + local_name
+
+                tf.summary.histogram(name, var)
+
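
A brief sketch of the reporting helpers, again assuming a loaded network `Gs`:

# Sketch: print a layer table and register TensorBoard histograms.
Gs.print_layers()              # Params / OutputShape / WeightShape summary table
Gs.setup_weight_histograms()   # adds tf.summary.histogram ops for all trainables
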
+#----------------------------------------------------------------------------
+# Backwards-compatible emulation of legacy output transformation in Network.run().
+
+_print_legacy_warning = True
+
+def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
+    global _print_legacy_warning
+    legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
+    if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
+        return output_transform, dynamic_kwargs
+
+    if _print_legacy_warning:
+        _print_legacy_warning = False
+        print()
+        print("WARNING: Old-style output transformations in Network.run() are deprecated.")
+        print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
+        print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
+        print()
+    assert output_transform is None
+
+    new_kwargs = dict(dynamic_kwargs)
+    new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
+    new_transform["func"] = _legacy_output_transform_func
+    return new_transform, new_kwargs
+
+def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
+    if out_mul != 1.0:
+        expr = [x * out_mul for x in expr]
+
+    if out_add != 0.0:
+        expr = [x + out_add for x in expr]
+
+    if out_shrink > 1:
+        ksize = [1, 1, out_shrink, out_shrink]
+        expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
+
+    if out_dtype is not None:
+        if tf.as_dtype(out_dtype).is_integer:
+            expr = [tf.round(x) for x in expr]
+        expr = [tf.saturate_cast(x, out_dtype) for x in expr]
+    return expr
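
To illustrate the deprecation path handled above, the two calling styles look like this (same assumed `Gs`, `latents`, `labels` as in the earlier sketch; the results agree up to the 0.5 rounding offset that convert_images_to_uint8 applies before casting):

# Legacy keyword style -- still accepted, but prints the warning above:
images = Gs.run(latents, labels, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)

# Preferred style with an explicit output transform:
images = Gs.run(latents, labels,
                output_transform=dict(func=tflib.convert_images_to_uint8))
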
dnnlib/tflib/optimizer.py ADDED
@@ -0,0 +1,214 @@
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the Creative Commons Attribution-NonCommercial
+# 4.0 International License. To view a copy of this license, visit
+# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+"""Helper wrapper for a TensorFlow optimizer."""
+
+import numpy as np
+import tensorflow as tf
+
+from collections import OrderedDict
+from typing import List, Union
+
+from . import autosummary
+from . import tfutil
+from .. import util
+
+from .tfutil import TfExpression, TfExpressionEx
+
+try:
+    # TensorFlow 1.13
+    from tensorflow.python.ops import nccl_ops
+except:
+    # Older TensorFlow versions
+    import tensorflow.contrib.nccl as nccl_ops
+
+class Optimizer:
+    """A wrapper for tf.train.Optimizer.
+
+    Automatically takes care of:
+    - Gradient averaging for multi-GPU training.
+    - Dynamic loss scaling and typecasts for FP16 training.
+    - Ignoring corrupted gradients that contain NaNs/Infs.
+    - Reporting statistics.
+    - Well-chosen default settings.
+    """
+
+    def __init__(self,
+                 name: str = "Train",
+                 tf_optimizer: str = "tf.train.AdamOptimizer",
+                 learning_rate: TfExpressionEx = 0.001,
+                 use_loss_scaling: bool = False,
+                 loss_scaling_init: float = 64.0,
+                 loss_scaling_inc: float = 0.0005,
+                 loss_scaling_dec: float = 1.0,
+                 **kwargs):
+
+        # Init fields.
+        self.name = name
+        self.learning_rate = tf.convert_to_tensor(learning_rate)
+        self.id = self.name.replace("/", ".")
+        self.scope = tf.get_default_graph().unique_name(self.id)
+        self.optimizer_class = util.get_obj_by_name(tf_optimizer)
+        self.optimizer_kwargs = dict(kwargs)
+        self.use_loss_scaling = use_loss_scaling
+        self.loss_scaling_init = loss_scaling_init
+        self.loss_scaling_inc = loss_scaling_inc
+        self.loss_scaling_dec = loss_scaling_dec
+        self._grad_shapes = None          # [shape, ...]
+        self._dev_opt = OrderedDict()     # device => optimizer
+        self._dev_grads = OrderedDict()   # device => [[(grad, var), ...], ...]
+        self._dev_ls_var = OrderedDict()  # device => variable (log2 of loss scaling factor)
+        self._updates_applied = False
+
+    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
+        """Register the gradients of the given loss function with respect to the given variables.
+        Intended to be called once per GPU."""
+        assert not self._updates_applied
+
+        # Validate arguments.
+        if isinstance(trainable_vars, dict):
+            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
+
+        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
+        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
+
+        if self._grad_shapes is None:
+            self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]
+
+        assert len(trainable_vars) == len(self._grad_shapes)
+        assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))
+
+        dev = loss.device
+
+        assert all(var.device == dev for var in trainable_vars)
+
+        # Register device and compute gradients.
+        with tf.name_scope(self.id + "_grad"), tf.device(dev):
+            if dev not in self._dev_opt:
+                opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
+                assert callable(self.optimizer_class)
+                self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
+                self._dev_grads[dev] = []
+
+            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
+            grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage
+            grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]  # replace disconnected gradients with zeros
+            self._dev_grads[dev].append(grads)
+
+    def apply_updates(self) -> tf.Operation:
+        """Construct training op to update the registered variables based on their gradients."""
+        tfutil.assert_tf_initialized()
+        assert not self._updates_applied
+        self._updates_applied = True
+        devices = list(self._dev_grads.keys())
+        total_grads = sum(len(grads) for grads in self._dev_grads.values())
+        assert len(devices) >= 1 and total_grads >= 1
+        ops = []
+
+        with tfutil.absolute_name_scope(self.scope):
+            # Cast gradients to FP32 and calculate partial sum within each device.
+            dev_grads = OrderedDict()  # device => [(grad, var), ...]
+
+            for dev_idx, dev in enumerate(devices):
+                with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
+                    sums = []
+
+                    for gv in zip(*self._dev_grads[dev]):
+                        assert all(v is gv[0][1] for g, v in gv)
+                        g = [tf.cast(g, tf.float32) for g, v in gv]
+                        g = g[0] if len(g) == 1 else tf.add_n(g)
+                        sums.append((g, gv[0][1]))
+
+                    dev_grads[dev] = sums
+
+            # Sum gradients across devices.
+            if len(devices) > 1:
+                with tf.name_scope("SumAcrossGPUs"), tf.device(None):
+                    for var_idx, grad_shape in enumerate(self._grad_shapes):
+                        g = [dev_grads[dev][var_idx][0] for dev in devices]
+
+                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors
+                            g = nccl_ops.all_sum(g)
+
+                        for dev, gg in zip(devices, g):
+                            dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])
+
+            # Apply updates separately on each device.
+            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
+                with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
+                    # Scale gradients as needed.
+                    if self.use_loss_scaling or total_grads > 1:
+                        with tf.name_scope("Scale"):
+                            coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
+                            coef = self.undo_loss_scaling(coef)
+                            grads = [(g * coef, v) for g, v in grads]
+
+                    # Check for overflows.
+                    with tf.name_scope("CheckOverflow"):
+                        grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))
+
+                    # Update weights and adjust loss scaling.
+                    with tf.name_scope("UpdateWeights"):
+                        # pylint: disable=cell-var-from-loop
+                        opt = self._dev_opt[dev]
+                        ls_var = self.get_loss_scaling_var(dev)
+
+                        if not self.use_loss_scaling:
+                            ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
+                        else:
+                            ops.append(tf.cond(grad_ok,
+                                               lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
+                                               lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))
+
+                    # Report statistics on the last device.
+                    if dev == devices[-1]:
+                        with tf.name_scope("Statistics"):
+                            ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
+                            ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))
+
+                            if self.use_loss_scaling:
+                                ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))
+
+            # Initialize variables and group everything into a single op.
+            self.reset_optimizer_state()
+            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))
+
+            return tf.group(*ops, name="TrainingOp")
+
+    def reset_optimizer_state(self) -> None:
+        """Reset internal state of the underlying optimizer."""
+        tfutil.assert_tf_initialized()
+        tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])
+
+    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
+        """Get or create variable representing log2 of the current dynamic loss scaling factor."""
+        if not self.use_loss_scaling:
+            return None
+
+        if device not in self._dev_ls_var:
+            with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
+                self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")
+
+        return self._dev_ls_var[device]
+
+    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
+        """Apply dynamic loss scaling for the given expression."""
+        assert tfutil.is_tf_expression(value)
+
+        if not self.use_loss_scaling:
+            return value
+
+        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
+
+    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
+        """Undo the effect of dynamic loss scaling for the given expression."""
+        assert tfutil.is_tf_expression(value)
+
+        if not self.use_loss_scaling:
+            return value
+
+        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device))  # pylint: disable=invalid-unary-operand-type
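
A hedged sketch of the intended call pattern for this wrapper. The per-GPU losses, `num_gpus`, and the network `G` are placeholders assumed to exist in a surrounding training script:

# Sketch: one register_gradients() call per GPU tower, then a single
# apply_updates() that sums gradients across devices (via NCCL) and builds
# the training op.
import dnnlib.tflib as tflib
from dnnlib.tflib.optimizer import Optimizer

opt = Optimizer(name="TrainG", learning_rate=0.001, beta1=0.0, beta2=0.99)  # extra kwargs reach tf.train.AdamOptimizer
for gpu in range(num_gpus):
    opt.register_gradients(loss_per_gpu[gpu], G.trainables)
train_op = opt.apply_updates()
tflib.run(train_op)  # one optimization step
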
dnnlib/tflib/tfutil.py ADDED
@@ -0,0 +1,242 @@
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the Creative Commons Attribution-NonCommercial
+# 4.0 International License. To view a copy of this license, visit
+# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+"""Miscellaneous helper utils for TensorFlow."""
+
+import os
+import numpy as np
+import tensorflow as tf
+
+from typing import Any, Iterable, List, Union
+
+TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
+"""A type that represents a valid TensorFlow expression."""
+
+TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
+"""A type that can be converted to a valid TensorFlow expression."""
+
+
+def run(*args, **kwargs) -> Any:
+    """Run the specified ops in the default session."""
+    assert_tf_initialized()
+    return tf.compat.v1.get_default_session().run(*args, **kwargs)
+
+
+def is_tf_expression(x: Any) -> bool:
+    """Check whether the input is a valid TensorFlow expression, i.e., TensorFlow Tensor, Variable, or Operation."""
+    return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
+
+
+def shape_to_list(shape: Iterable[tf.compat.v1.Dimension]) -> List[Union[int, None]]:
+    """Convert a TensorFlow shape to a list of ints."""
+    return [dim.value for dim in shape]
+
+
+def flatten(x: TfExpressionEx) -> TfExpression:
+    """Shortcut function for flattening a tensor."""
+    with tf.name_scope("Flatten"):
+        return tf.reshape(x, [-1])
+
+
+def log2(x: TfExpressionEx) -> TfExpression:
+    """Logarithm in base 2."""
+    with tf.name_scope("Log2"):
+        return tf.log(x) * np.float32(1.0 / np.log(2.0))
+
+
+def exp2(x: TfExpressionEx) -> TfExpression:
+    """Exponent in base 2."""
+    with tf.name_scope("Exp2"):
+        return tf.exp(x * np.float32(np.log(2.0)))
+
+
+def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
+    """Linear interpolation."""
+    with tf.name_scope("Lerp"):
+        return a + (b - a) * t
+
+
+def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
+    """Linear interpolation with clip."""
+    with tf.name_scope("LerpClip"):
+        return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
+
+
+def absolute_name_scope(scope: str) -> tf.name_scope:
+    """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
+    return tf.name_scope(scope + "/")
+
+
+def absolute_variable_scope(scope: str, **kwargs) -> tf.compat.v1.variable_scope:
+    """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
+    return tf.compat.v1.variable_scope(tf.compat.v1.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
+
+
+def _sanitize_tf_config(config_dict: dict = None) -> dict:
+    # Defaults.
+    cfg = dict()
+    cfg["rnd.np_random_seed"] = None       # Random seed for NumPy. None = keep as is.
+    cfg["rnd.tf_random_seed"] = "auto"     # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
+    cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1"  # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
+    cfg["graph_options.place_pruned_graph"] = True  # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
+    cfg["gpu_options.allow_growth"] = True          # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
+
+    # User overrides.
+    if config_dict is not None:
+        cfg.update(config_dict)
+    return cfg
+
+
+def init_tf(config_dict: dict = None) -> None:
+    """Initialize TensorFlow session using good default settings."""
+    # Skip if already initialized.
+    if tf.compat.v1.get_default_session() is not None:
+        return
+
+    # Setup config dict and random seeds.
+    cfg = _sanitize_tf_config(config_dict)
+    np_random_seed = cfg["rnd.np_random_seed"]
+    if np_random_seed is not None:
+        np.random.seed(np_random_seed)
+    tf_random_seed = cfg["rnd.tf_random_seed"]
+    if tf_random_seed == "auto":
+        tf_random_seed = np.random.randint(1 << 31)
+    if tf_random_seed is not None:
+        tf.compat.v1.set_random_seed(tf_random_seed)
+
+    # Setup environment variables.
+    for key, value in list(cfg.items()):
+        fields = key.split(".")
+        if fields[0] == "env":
+            assert len(fields) == 2
+            os.environ[fields[1]] = str(value)
+
+    # Create default TensorFlow session.
+    create_session(cfg, force_as_default=True)
+
+
+def assert_tf_initialized():
+    """Check that TensorFlow session has been initialized."""
+    if tf.compat.v1.get_default_session() is None:
+        raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
+
+
+def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.compat.v1.Session:
+    """Create tf.Session based on config dict."""
+    # Setup TensorFlow config proto.
+    cfg = _sanitize_tf_config(config_dict)
+    config_proto = tf.compat.v1.ConfigProto()
+    for key, value in cfg.items():
+        fields = key.split(".")
+        if fields[0] not in ["rnd", "env"]:
+            obj = config_proto
+            for field in fields[:-1]:
+                obj = getattr(obj, field)
+            setattr(obj, fields[-1], value)
+
+    # Create session.
+    session = tf.compat.v1.Session(config=config_proto)
+    if force_as_default:
+        # pylint: disable=protected-access
+        session._default_session = session.as_default()
+        session._default_session.enforce_nesting = False
+        session._default_session.__enter__()  # pylint: disable=no-member
+
+    return session
+
+
+def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
+    """Initialize all tf.Variables that have not already been initialized.
+
+    Equivalent to the following, but more efficient and does not bloat the tf graph:
+    tf.variables_initializer(tf.report_uninitialized_variables()).run()
+    """
+    assert_tf_initialized()
+    if target_vars is None:
+        target_vars = tf.global_variables()
+
+    test_vars = []
+    test_ops = []
+
+    with tf.control_dependencies(None):  # ignore surrounding control_dependencies
+        for var in target_vars:
+            assert is_tf_expression(var)
+
+            try:
+                tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
+            except KeyError:
+                # Op does not exist => variable may be uninitialized.
+                test_vars.append(var)
+
+                with absolute_name_scope(var.name.split(":")[0]):
+                    test_ops.append(tf.is_variable_initialized(var))
+
+    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
+    run([var.initializer for var in init_vars])
+
+
+def set_vars(var_to_value_dict: dict) -> None:
+    """Set the values of given tf.Variables.
+
+    Equivalent to the following, but more efficient and does not bloat the tf graph:
+    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
+    """
+    assert_tf_initialized()
+    ops = []
+    feed_dict = {}
+
+    for var, value in var_to_value_dict.items():
+        assert is_tf_expression(var)
+
+        try:
+            setter = tf.compat.v1.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0"))  # look for existing op
+        except KeyError:
+            with absolute_name_scope(var.name.split(":")[0]):
+                with tf.control_dependencies(None):  # ignore surrounding control_dependencies
+                    setter = tf.compat.v1.assign(var, tf.compat.v1.placeholder(var.dtype, var.shape, "new_value"), name="setter")  # create new setter
+
+        ops.append(setter)
+        feed_dict[setter.op.inputs[1]] = value
+
+    run(ops, feed_dict)
+
+
+def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
+    """Create tf.Variable with large initial value without bloating the tf graph."""
+    assert_tf_initialized()
+    assert isinstance(initial_value, np.ndarray)
+    zeros = tf.zeros(initial_value.shape, initial_value.dtype)
+    var = tf.Variable(zeros, *args, **kwargs)
+    set_vars({var: initial_value})
+    return var
+
+
+def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
+    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
+    Can be used as an input transformation for Network.run().
+    """
+    images = tf.cast(images, tf.float32)
+    if nhwc_to_nchw:
+        images = tf.transpose(images, [0, 3, 1, 2])
+    return images * ((drange[1] - drange[0]) / 255) + drange[0]
+
+
+def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1, uint8_cast=True):
+    """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
+    Can be used as an output transformation for Network.run().
+    """
+    images = tf.cast(images, tf.float32)
+    if shrink > 1:
+        ksize = [1, 1, shrink, shrink]
+        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
+    if nchw_to_nhwc:
+        images = tf.transpose(images, [0, 2, 3, 1])
+    scale = 255 / (drange[1] - drange[0])
+    images = images * scale + (0.5 - drange[0] * scale)
+    if uint8_cast:
+        images = tf.saturate_cast(images, tf.uint8)
+    return images
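
A short, hedged sketch tying these helpers together (the variable and values are illustrative):

# Sketch: initialize TF with the defaults above, then read and write a variable.
import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib import tfutil

tflib.init_tf({"gpu_options.allow_growth": True})  # dotted keys map onto tf.ConfigProto fields
v = tf.Variable(np.float32(0.0), name="demo_var")
tfutil.init_uninitialized_vars([v])                # runs only the missing initializers
tfutil.set_vars({v: np.float32(3.5)})              # feeds a cached assign op; no graph bloat
print(tfutil.run(v))                               # -> 3.5
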
dnnlib/util.py ADDED
@@ -0,0 +1,408 @@
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# This work is licensed under the Creative Commons Attribution-NonCommercial
+# 4.0 International License. To view a copy of this license, visit
+# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
+# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self[name] = value
+
+    def __delattr__(self, name: str) -> None:
+        del self[name]
+
+
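
A minimal illustration of EasyDict (the key names are arbitrary):

# Sketch: attribute-style access on a plain dict subclass.
cfg = EasyDict(model_res=1024, batch_size=1)
cfg.lr = 0.25               # equivalent to cfg["lr"] = 0.25
print(cfg.model_res)        # -> 1024
del cfg.batch_size          # equivalent to del cfg["batch_size"]
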
+class Logger(object):
+    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+        self.file = None
+
+        if file_name is not None:
+            self.file = open(file_name, file_mode)
+
+        self.should_flush = should_flush
+        self.stdout = sys.stdout
+        self.stderr = sys.stderr
+
+        sys.stdout = self
+        sys.stderr = self
+
+    def __enter__(self) -> "Logger":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.close()
+
+    def write(self, text: str) -> None:
+        """Write text to stdout (and a file) and optionally flush."""
+        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+            return
+
+        if self.file is not None:
+            self.file.write(text)
+
+        self.stdout.write(text)
+
+        if self.should_flush:
+            self.flush()
+
+    def flush(self) -> None:
+        """Flush written text to both stdout and a file, if open."""
+        if self.file is not None:
+            self.file.flush()
+
+        self.stdout.flush()
+
+    def close(self) -> None:
+        """Flush, close possible files, and remove stdout/stderr mirroring."""
+        self.flush()
+
+        # if using multiple loggers, prevent closing in wrong order
+        if sys.stdout is self:
+            sys.stdout = self.stdout
+        if sys.stderr is self:
+            sys.stderr = self.stderr
+
+        if self.file is not None:
+            self.file.close()
+
+
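
And a minimal sketch of Logger as a context manager (the file name is arbitrary):

# Sketch: mirror stdout/stderr to a log file for the duration of a block.
with Logger(file_name="run.log", should_flush=True):
    print("this line goes to the console and to run.log")
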
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+    else:
+        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def ask_yes_no(question: str) -> bool:
+    """Ask the user the question until the user inputs a valid answer."""
+    while True:
+        try:
+            print("{0} [y/n]".format(question))
+            return strtobool(input().lower())
+        except ValueError:
+            pass
+
+
+def tuple_product(t: Tuple) -> Any:
+    """Calculate the product of the tuple elements."""
+    result = 1
+
+    for v in t:
+        result *= v
+
+    return result
+
+
+_str_to_ctype = {
+    "uint8": ctypes.c_ubyte,
+    "uint16": ctypes.c_uint16,
+    "uint32": ctypes.c_uint32,
+    "uint64": ctypes.c_uint64,
+    "int8": ctypes.c_byte,
+    "int16": ctypes.c_int16,
+    "int32": ctypes.c_int32,
+    "int64": ctypes.c_int64,
+    "float32": ctypes.c_float,
+    "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+    type_str = None
+
+    if isinstance(type_obj, str):
+        type_str = type_obj
+    elif hasattr(type_obj, "__name__"):
+        type_str = type_obj.__name__
+    elif hasattr(type_obj, "name"):
+        type_str = type_obj.name
+    else:
+        raise RuntimeError("Cannot infer type name from input")
+
+    assert type_str in _str_to_ctype.keys()
+
+    my_dtype = np.dtype(type_str)
+    my_ctype = _str_to_ctype[type_str]
+
+    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+    return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+    try:
+        with io.BytesIO() as stream:
+            pickle.dump(obj, stream)
+            return True
+    except:
+        return False
+
+
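
For example, format_time renders durations with the two or three largest relevant units:

# Sketch: expected outputs of format_time for a few inputs.
format_time(42)       # -> "42s"
format_time(90)       # -> "1m 30s"
format_time(90000)    # -> "1d 01h 00m"
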
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+    """Searches for the underlying module behind the name to some python object.
+    Returns the module and the object name (original name with module part removed)."""
+
+    # allow convenience shorthands, substitute them by full names
+    obj_name = re.sub("^np.", "numpy.", obj_name)
+    obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+    # list alternatives for (module_name, local_obj_name)
+    parts = obj_name.split(".")
+    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+    # try each alternative in turn
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name)  # may raise ImportError
+            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
+            return module, local_obj_name
+        except:
+            pass
+
+    # maybe some of the modules themselves contain errors?
+    for module_name, _local_obj_name in name_pairs:
+        try:
+            importlib.import_module(module_name)  # may raise ImportError
+        except ImportError:
+            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+                raise
+
+    # maybe the requested attribute is missing?
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name)  # may raise ImportError
+            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
+        except ImportError:
+            pass
+
+    # we are out of luck, but we have no idea why
+    raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+    """Traverses the object name and returns the last (rightmost) python object."""
+    if obj_name == '':
+        return module
+    obj = module
+    for part in obj_name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+    """Finds the python object with the given name."""
+    module, obj_name = get_module_from_obj_name(name)
+    return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+    """Finds the python object with the given name and calls it as a function."""
+    assert func_name is not None
+    func_obj = get_obj_by_name(func_name)
+    assert callable(func_obj)
+    return func_obj(*args, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+    """Get the directory path of the module containing the given object name."""
+    module, _ = get_module_from_obj_name(obj_name)
+    return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+    """Return the fully-qualified name of a top-level function."""
+    assert is_top_level_function(obj)
+    return obj.__module__ + "." + obj.__name__
+
+
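
A hedged sketch of the name-resolution helpers above. The `np.` and `tf.` shorthands are expanded by get_module_from_obj_name; this mechanism is also how Optimizer resolves its `tf_optimizer` string:

# Sketch: resolve and call Python objects by dotted name.
randn = get_obj_by_name("np.random.randn")                 # -> numpy.random.randn
x = call_func_by_name(4, 4, func_name="np.random.randn")   # numpy.random.randn(4, 4)
opt_cls = get_obj_by_name("tf.train.AdamOptimizer")        # as used by dnnlib.tflib.Optimizer
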
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+    """List all files recursively in a given directory while ignoring given file and directory names.
+    Returns list of tuples containing both absolute and relative paths."""
+    assert os.path.isdir(dir_path)
+    base_name = os.path.basename(os.path.normpath(dir_path))
+
+    if ignores is None:
+        ignores = []
+
+    result = []
+
+    for root, dirs, files in os.walk(dir_path, topdown=True):
+        for ignore_ in ignores:
+            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+            # dirs need to be edited in-place
+            for d in dirs_to_remove:
+                dirs.remove(d)
+
+            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+        absolute_paths = [os.path.join(root, f) for f in files]
+        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+        if add_base_to_relative:
+            relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+        assert len(absolute_paths) == len(relative_paths)
+        result += zip(absolute_paths, relative_paths)
+
+    return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+    """Takes in a list of tuples of (src, dst) paths and copies files.
+    Will create all necessary directories."""
+    for file in files:
+        target_dir_name = os.path.dirname(file[1])
+
+        # will create all intermediate-level directories
+        if not os.path.exists(target_dir_name):
+            os.makedirs(target_dir_name)
+
+        shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any) -> bool:
+    """Determine whether the given object is a valid URL string."""
+    if not isinstance(obj, str) or "://" not in obj:
+        return False
+    try:
+        res = requests.compat.urlparse(obj)
+        if not res.scheme or not res.netloc or "." not in res.netloc:
+            return False
+        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+        if not res.scheme or not res.netloc or "." not in res.netloc:
+            return False
+    except:
+        return False
+    return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
+    """Download the given URL and return a binary-mode file object to access the data."""
+    if not is_url(url) and os.path.isfile(url):
+        return open(url, 'rb')
+
+    assert is_url(url)
+    assert num_attempts >= 1
+
+    # Lookup from cache.
+    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+    if cache_dir is not None:
+        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+        if len(cache_files) == 1:
+            return open(cache_files[0], "rb")
+
+    # Download.
+    url_name = None
+    url_data = None
+    with requests.Session() as session:
+        if verbose:
+            print("Downloading %s ..." % url, end="", flush=True)
+        for attempts_left in reversed(range(num_attempts)):
+            try:
+                with session.get(url) as res:
+                    res.raise_for_status()
+                    if len(res.content) == 0:
+                        raise IOError("No data received")
+
+                    if len(res.content) < 8192:
+                        content_str = res.content.decode("utf-8")
+                        if "download_warning" in res.headers.get("Set-Cookie", ""):
+                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                            if len(links) == 1:
+                                url = requests.compat.urljoin(url, links[0])
+                                raise IOError("Google Drive virus checker nag")
+                        if "Google Drive - Quota exceeded" in content_str:
+                            raise IOError("Google Drive quota exceeded")
+
+                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                    url_name = match[1] if match else url
+                    url_data = res.content
+                    if verbose:
+                        print(" done")
+                    break
+            except:
+                if not attempts_left:
+                    if verbose:
+                        print(" failed")
+                    raise
+                if verbose:
+                    print(".", end="", flush=True)
+
+    # Save to cache.
+    if cache_dir is not None:
+        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(temp_file, "wb") as f:
+            f.write(url_data)
+        os.replace(temp_file, cache_file)  # atomic
+
+    # Return data as file object.
+    return io.BytesIO(url_data)
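
A hedged usage sketch for open_url. Local paths short-circuit the download entirely, and remote downloads are cached under `cache_dir` keyed by the URL's MD5 hash; the path below is the default model path used by encode_images.py:

# Sketch: load a pickle from a path or URL.
import pickle
from dnnlib.util import open_url

with open_url("./data/karras2019stylegan-ffhq-1024x1024.pkl", cache_dir="cache") as f:
    G, D, Gs = pickle.load(f)   # unpacking mirrors how encode_images.py uses it
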
encode_images.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+import os
+import argparse
+import pickle
+from tqdm import tqdm
+import PIL.Image
+from PIL import ImageFilter
+import numpy as np
+import dnnlib
+import dnnlib.tflib as tflib
+import config
+from encoder.generator_model import Generator
+from encoder.perceptual_model import PerceptualModel, load_images
+#from tensorflow.keras.models import load_model
+from keras.models import load_model
+from keras.applications.resnet50 import preprocess_input
+
+def split_to_batches(l, n):
+    for i in range(0, len(l), n):
+        yield l[i:i + n]
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
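
For reference, split_to_batches yields successive slices of at most n items, which is how the encoder walks the reference images below:

# Sketch: batching behavior of split_to_batches.
list(split_to_batches([1, 2, 3, 4, 5], 2))   # -> [[1, 2], [3, 4], [5]]
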
+ def main():
32
+ parser = argparse.ArgumentParser(description='Find latent representation of reference images using perceptual losses', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
33
+ parser.add_argument('src_dir', help='Directory with images for encoding')
34
+ parser.add_argument('generated_images_dir', help='Directory for storing generated images')
35
+ parser.add_argument('dlatent_dir', help='Directory for storing dlatent representations')
36
+ parser.add_argument('--data_dir', default='data', help='Directory for storing optional models')
37
+ parser.add_argument('--mask_dir', default='masks', help='Directory for storing optional masks')
38
+ parser.add_argument('--load_last', default='', help='Start with embeddings from directory')
39
+ parser.add_argument('--dlatent_avg', default='', help='Use dlatent from file specified here for truncation instead of dlatent_avg from Gs')
40
+ parser.add_argument('--model_url', default='./data/karras2019stylegan-ffhq-1024x1024.pkl', help='Fetch a StyleGAN model to train on from this URL')
41
+ parser.add_argument('--architecture', default='./data/vgg16_zhang_perceptual.pkl', help='Сonvolutional neural network model from this URL')
42
+ parser.add_argument('--model_res', default=1024, help='The dimension of images in the StyleGAN model', type=int)
43
+ parser.add_argument('--batch_size', default=1, help='Batch size for generator and perceptual model', type=int)
44
+ parser.add_argument('--optimizer', default='ggt', help='Optimization algorithm used for optimizing dlatents')
45
+
46
+ # Perceptual model params
47
+ parser.add_argument('--image_size', default=256, help='Size of images for perceptual model', type=int)
48
+ parser.add_argument('--resnet_image_size', default=256, help='Size of images for the Resnet model', type=int)
49
+ parser.add_argument('--lr', default=0.25, help='Learning rate for perceptual model', type=float)
50
+ parser.add_argument('--decay_rate', default=0.9, help='Decay rate for learning rate', type=float)
51
+ parser.add_argument('--iterations', default=100, help='Number of optimization steps for each batch', type=int)
52
+ parser.add_argument('--decay_steps', default=4, help='Decay steps for learning rate decay (as a percent of iterations)', type=float)
53
+ parser.add_argument('--early_stopping', default=True, help='Stop early once training stabilizes', type=str2bool, nargs='?', const=True)
54
+ parser.add_argument('--early_stopping_threshold', default=0.5, help='Stop after this threshold has been reached', type=float)
55
+ parser.add_argument('--early_stopping_patience', default=10, help='Number of iterations to wait below threshold', type=int)
56
+ parser.add_argument('--load_effnet', default='data/finetuned_effnet.h5', help='Model to load for EfficientNet approximation of dlatents')
57
+ parser.add_argument('--load_resnet', default='data/finetuned_resnet.h5', help='Model to load for ResNet approximation of dlatents')
58
+ parser.add_argument('--use_preprocess_input', default=True, help='Call process_input() first before using feed forward net', type=str2bool, nargs='?', const=True)
59
+ parser.add_argument('--use_best_loss', default=True, help='Output the lowest loss value found as the solution', type=str2bool, nargs='?', const=True)
60
+ parser.add_argument('--average_best_loss', default=0.25, help='Do a running weighted average with the previous best dlatents found', type=float)
61
+ parser.add_argument('--sharpen_input', default=True, help='Sharpen the input images', type=str2bool, nargs='?', const=True)
62
+
63
+ # Loss function options
64
+ parser.add_argument('--use_vgg_loss', default=0.4, help='Use VGG perceptual loss; 0 to disable, > 0 to scale.', type=float)
65
+ parser.add_argument('--use_vgg_layer', default=9, help='Pick which VGG layer to use.', type=int)
66
+ parser.add_argument('--use_pixel_loss', default=1.5, help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.', type=float)
67
+ parser.add_argument('--use_mssim_loss', default=200, help='Use MS-SIM perceptual loss; 0 to disable, > 0 to scale.', type=float)
68
+ parser.add_argument('--use_lpips_loss', default=100, help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.', type=float)
69
+ parser.add_argument('--use_l1_penalty', default=0.5, help='Use L1 penalty on latents; 0 to disable, > 0 to scale.', type=float)
70
+ parser.add_argument('--use_discriminator_loss', default=0.5, help='Use trained discriminator to evaluate realism.', type=float)
71
+ parser.add_argument('--use_adaptive_loss', default=False, help='Use the adaptive robust loss function from Google Research for pixel and VGG feature loss.', type=str2bool, nargs='?', const=True)
72
+
73
+ # Generator params
74
+ parser.add_argument('--randomize_noise', default=False, help='Add noise to dlatents during optimization', type=str2bool, nargs='?', const=True)
75
+ parser.add_argument('--tile_dlatents', default=False, help='Tile dlatents to use a single vector at each scale', type=str2bool, nargs='?', const=True)
76
+ parser.add_argument('--clipping_threshold', default=2.0, help='Stochastic clipping of gradient values outside of this threshold', type=float)
77
+
78
+ # Masking params
79
+ parser.add_argument('--load_mask', default=False, help='Load segmentation masks', type=str2bool, nargs='?', const=True)
80
+ parser.add_argument('--face_mask', default=True, help='Generate a mask for predicting only the face area', type=str2bool, nargs='?', const=True)
81
+ parser.add_argument('--use_grabcut', default=True, help='Use grabcut algorithm on the face mask to better segment the foreground', type=str2bool, nargs='?', const=True)
82
+ parser.add_argument('--scale_mask', default=1.4, help='Look over a wider section of foreground for grabcut', type=float)
83
+ parser.add_argument('--composite_mask', default=True, help='Merge the unmasked area back into the generated image', type=str2bool, nargs='?', const=True)
84
+ parser.add_argument('--composite_blur', default=8, help='Size of blur filter to smoothly composite the images', type=int)
85
+
86
+ # Video params
87
+ parser.add_argument('--video_dir', default='videos', help='Directory for storing training videos')
88
+ parser.add_argument('--output_video', default=False, help='Generate videos of the optimization process', type=bool)
89
+ parser.add_argument('--video_codec', default='MJPG', help='FOURCC-supported video codec name')
90
+ parser.add_argument('--video_frame_rate', default=24, help='Video frames per second', type=int)
91
+ parser.add_argument('--video_size', default=512, help='Video size in pixels', type=int)
92
+ parser.add_argument('--video_skip', default=1, help='Only write every n frames (1 = write every frame)', type=int)
93
+
94
+ args, other_args = parser.parse_known_args()
95
+
96
+ args.decay_steps *= 0.01 * args.iterations # Calculate steps as a percent of total iterations
97
+
98
+ if args.output_video:
99
+ import cv2
100
+ synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=False), minibatch_size=args.batch_size)
101
+
102
+ ref_images = [os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)]
103
+ ref_images = list(filter(os.path.isfile, ref_images))
104
+
105
+ if len(ref_images) == 0:
106
+ raise Exception('%s is empty' % args.src_dir)
107
+
108
+ os.makedirs(args.data_dir, exist_ok=True)
109
+ os.makedirs(args.mask_dir, exist_ok=True)
110
+ os.makedirs(args.generated_images_dir, exist_ok=True)
111
+ os.makedirs(args.dlatent_dir, exist_ok=True)
112
+ os.makedirs(args.video_dir, exist_ok=True)
113
+

    # Initialize generator and perceptual model
    tflib.init_tf()
    with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
        generator_network, discriminator_network, Gs_network = pickle.load(f)

    generator = Generator(Gs_network, args.batch_size, clipping_threshold=args.clipping_threshold, tiled_dlatent=args.tile_dlatents, model_res=args.model_res, randomize_noise=args.randomize_noise)
    if (args.dlatent_avg != ''):
        generator.set_dlatent_avg(np.load(args.dlatent_avg))

    perc_model = None
    if (args.use_lpips_loss > 0.00000001):
        with dnnlib.util.open_url(args.architecture, cache_dir=config.cache_dir) as f:
            perc_model = pickle.load(f)
    perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=args.batch_size)
    perceptual_model.build_perceptual_model(generator, discriminator_network)

    ff_model = None

    # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
    for images_batch in tqdm(split_to_batches(ref_images, args.batch_size), total=len(ref_images)//args.batch_size):
        names = [os.path.splitext(os.path.basename(x))[0] for x in images_batch]
        if args.output_video:
            video_out = {}
            for name in names:
                video_out[name] = cv2.VideoWriter(os.path.join(args.video_dir, f'{name}.avi'), cv2.VideoWriter_fourcc(*args.video_codec), args.video_frame_rate, (args.video_size, args.video_size))

        perceptual_model.set_reference_images(images_batch)
        dlatents = None
        if (args.load_last != ''): # load previous dlatents for initialization
            for name in names:
                dl = np.expand_dims(np.load(os.path.join(args.load_last, f'{name}.npy')), axis=0)
                if (dlatents is None):
                    dlatents = dl
                else:
                    dlatents = np.vstack((dlatents, dl))
        else:
            if (ff_model is None):
                if os.path.exists(args.load_resnet):
                    from keras.applications.resnet50 import preprocess_input
                    print("Loading ResNet Model:")
                    ff_model = load_model(args.load_resnet)
            if (ff_model is None):
                if os.path.exists(args.load_effnet):
                    import efficientnet
                    from efficientnet import preprocess_input
                    print("Loading EfficientNet Model:")
                    ff_model = load_model(args.load_effnet)
            if (ff_model is not None): # predict initial dlatents with the ResNet (or EfficientNet) model
                if (args.use_preprocess_input):
                    dlatents = ff_model.predict(preprocess_input(load_images(images_batch, image_size=args.resnet_image_size)))
                else:
                    dlatents = ff_model.predict(load_images(images_batch, image_size=args.resnet_image_size))
        if dlatents is not None:
            generator.set_dlatents(dlatents)
        op = perceptual_model.optimize(generator.dlatent_variable, iterations=args.iterations, use_optimizer=args.optimizer)
        pbar = tqdm(op, leave=False, total=args.iterations)
        vid_count = 0
        best_loss = None
        best_dlatent = None
        avg_loss_count = 0
        if args.early_stopping:
            avg_loss = prev_loss = None
        for loss_dict in pbar:
            if args.early_stopping: # early stopping feature
                if prev_loss is not None:
                    if avg_loss is not None:
                        avg_loss = 0.5 * avg_loss + (prev_loss - loss_dict["loss"])
                        if avg_loss < args.early_stopping_threshold: # count while under threshold; else reset
                            avg_loss_count += 1
                        else:
                            avg_loss_count = 0
                        if avg_loss_count > args.early_stopping_patience: # stop once the patience is exceeded
                            print("")
                            break
                    else:
                        avg_loss = prev_loss - loss_dict["loss"]
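            # Early-stopping recap: avg_loss is a decaying average of the per-step
            # improvement (avg_loss = 0.5*avg_loss + (prev_loss - loss)); once it
            # stays below early_stopping_threshold for more than
            # early_stopping_patience consecutive steps, optimization halts early.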
            pbar.set_description(" ".join(names) + ": " + "; ".join(["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
            if best_loss is None or loss_dict["loss"] < best_loss:
                if best_dlatent is None or args.average_best_loss <= 0.00000001:
                    best_dlatent = generator.get_dlatents()
                else:
                    best_dlatent = 0.25 * best_dlatent + 0.75 * generator.get_dlatents()
                if args.use_best_loss:
                    generator.set_dlatents(best_dlatent)
                best_loss = loss_dict["loss"]
            if args.output_video and (vid_count % args.video_skip == 0):
                batch_frames = generator.generate_images()
                for i, name in enumerate(names):
                    video_frame = PIL.Image.fromarray(batch_frames[i], 'RGB').resize((args.video_size, args.video_size), PIL.Image.LANCZOS)
                    video_out[name].write(cv2.cvtColor(np.array(video_frame).astype('uint8'), cv2.COLOR_RGB2BGR))
            vid_count += 1  # advance the frame counter so --video_skip actually skips frames
            generator.stochastic_clip_dlatents()
            prev_loss = loss_dict["loss"]
        if not args.use_best_loss:
            best_loss = prev_loss
        print(" ".join(names), " Loss {:.4f}".format(best_loss))

        if args.output_video:
            for name in names:
                video_out[name].release()

        # Generate images from found dlatents and save them
        if args.use_best_loss:
            generator.set_dlatents(best_dlatent)
        generated_images = generator.generate_images()
        generated_dlatents = generator.get_dlatents()
        for img_array, dlatent, img_path, img_name in zip(generated_images, generated_dlatents, images_batch, names):
            mask_img = None
            if args.composite_mask and (args.load_mask or args.face_mask):
                _, im_name = os.path.split(img_path)
                mask_img = os.path.join(args.mask_dir, f'{im_name}')
            if args.composite_mask and mask_img is not None and os.path.isfile(mask_img):
                orig_img = PIL.Image.open(img_path).convert('RGB')
                width, height = orig_img.size
                imask = PIL.Image.open(mask_img).convert('L').resize((width, height))
                imask = imask.filter(ImageFilter.GaussianBlur(args.composite_blur))
                mask = np.array(imask)/255
                mask = np.expand_dims(mask, axis=-1)
                img_array = mask*np.array(img_array) + (1.0-mask)*np.array(orig_img)
                img_array = img_array.astype(np.uint8)
                #img_array = np.where(mask, np.array(img_array), orig_img)
            img = PIL.Image.fromarray(img_array, 'RGB')
            img.save(os.path.join(args.generated_images_dir, f'{img_name}.png'), 'PNG')
            np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent)

        generator.reset_dlatents()


if __name__ == "__main__":
    main()
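For reference, a typical invocation of this script looks like the sketch below (the three positional arguments are defined in the earlier part of this file; the directory names are illustrative, not committed examples):

    python encode_images.py aligned_images/ generated_images/ latent_representations/ \
        --batch_size=1 --face_mask=True --use_grabcut=True --output_video=False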
encoder/__init__.py ADDED
File without changes
encoder/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (113 Bytes).
encoder/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (117 Bytes).
encoder/__pycache__/generator_model.cpython-36.pyc ADDED
Binary file (5.09 kB).
encoder/__pycache__/generator_model.cpython-37.pyc ADDED
Binary file (5.1 kB).
encoder/__pycache__/perceptual_model.cpython-36.pyc ADDED
Binary file (10.1 kB).
encoder/__pycache__/perceptual_model.cpython-37.pyc ADDED
Binary file (10 kB).
encoder/generator_model.py ADDED
@@ -0,0 +1,137 @@
import math
import tensorflow as tf
import numpy as np
import dnnlib.tflib as tflib
from functools import partial


def create_stub(name, batch_size):
    return tf.constant(0, dtype='float32', shape=(batch_size, 0))


def create_variable_for_generator(name, batch_size, tiled_dlatent, model_scale=18, tile_size=1):
    if tiled_dlatent:
        low_dim_dlatent = tf.get_variable('learnable_dlatents',
            shape=(batch_size, tile_size, 512),
            dtype='float32',
            initializer=tf.initializers.random_normal())
        return tf.tile(low_dim_dlatent, [1, model_scale // tile_size, 1])
    else:
        return tf.get_variable('learnable_dlatents',
            shape=(batch_size, model_scale, 512),
            dtype='float32',
            initializer=tf.initializers.random_normal())


class Generator:
    def __init__(self, model, batch_size, custom_input=None, clipping_threshold=2, tiled_dlatent=False, model_res=1024, randomize_noise=False):
        self.batch_size = batch_size
        self.tiled_dlatent = tiled_dlatent
        self.model_scale = int(2*(math.log(model_res, 2)-1)) # For example, 1024 -> 18

        if tiled_dlatent:
            self.initial_dlatents = np.zeros((self.batch_size, 512))
            model.components.synthesis.run(np.zeros((self.batch_size, self.model_scale, 512)),
                randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=True),
                               partial(create_stub, batch_size=batch_size)],
                structure='fixed')
        else:
            self.initial_dlatents = np.zeros((self.batch_size, self.model_scale, 512))
            if custom_input is not None:
                model.components.synthesis.run(self.initial_dlatents,
                    randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                    custom_inputs=[partial(custom_input.eval(), batch_size=batch_size), partial(create_stub, batch_size=batch_size)],
                    structure='fixed')
            else:
                model.components.synthesis.run(self.initial_dlatents,
                    randomize_noise=randomize_noise, minibatch_size=self.batch_size,
                    custom_inputs=[partial(create_variable_for_generator, batch_size=batch_size, tiled_dlatent=False, model_scale=self.model_scale),
                                   partial(create_stub, batch_size=batch_size)],
                    structure='fixed')

        self.dlatent_avg_def = model.get_var('dlatent_avg')
        self.reset_dlatent_avg()
        self.sess = tf.compat.v1.get_default_session()
        self.graph = tf.compat.v1.get_default_graph()

        self.dlatent_variable = next(v for v in tf.compat.v1.global_variables() if 'learnable_dlatents' in v.name)
        self._assign_dlatent_ph = tf.compat.v1.placeholder(tf.float32, name="assign_dlatent_ph")
        self._assign_dlatent = tf.assign(self.dlatent_variable, self._assign_dlatent_ph)
        self.set_dlatents(self.initial_dlatents)

        def get_tensor(name):
            try:
                return self.graph.get_tensor_by_name(name)
            except KeyError:
                return None

        self.generator_output = get_tensor('G_synthesis_1/_Run/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis_1/_Run/concat/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis_1/_Run/concat_1/concat:0')
        # If we loaded only Gs and didn't load G or D, then scope "G_synthesis_1" won't exist in the graph.
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat/concat:0')
        if self.generator_output is None:
            self.generator_output = get_tensor('G_synthesis/_Run/concat_1/concat:0')
        if self.generator_output is None:
            for op in self.graph.get_operations():
                print(op)
            raise Exception("Couldn't find G_synthesis_1/_Run/concat tensor output")
        self.generated_image = tflib.convert_images_to_uint8(self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
        self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

        # Implement stochastic clipping similar to what is described in https://arxiv.org/abs/1702.04782
        # (Slightly different in that the latent space is normal gaussian here and was uniform in [-1, 1] in that paper,
        # so we clip any vector components outside of [-2, 2]. It seems fine, but I haven't done an ablation check.)
        clipping_mask = tf.math.logical_or(self.dlatent_variable > clipping_threshold, self.dlatent_variable < -clipping_threshold)
        clipped_values = tf.where(clipping_mask, tf.random.normal(shape=self.dlatent_variable.shape), self.dlatent_variable)
        self.stochastic_clip_op = tf.assign(self.dlatent_variable, clipped_values)

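    # In NumPy terms, the stochastic clipping op above does (illustrative sketch):
    #
    #   out_of_range = np.abs(dlatents) > clipping_threshold
    #   dlatents = np.where(out_of_range, np.random.normal(size=dlatents.shape), dlatents)
    #
    # i.e. any dlatent component that wanders outside [-threshold, threshold] is
    # resampled from the prior instead of being hard-clamped to the boundary.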
    def reset_dlatents(self):
        self.set_dlatents(self.initial_dlatents)

    def set_dlatents(self, dlatents):
        if self.tiled_dlatent:
            if (dlatents.shape != (self.batch_size, 512)) and (dlatents.shape[1] != 512):
                dlatents = np.mean(dlatents, axis=1)
            if (dlatents.shape != (self.batch_size, 512)):
                dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], 512))])
            assert (dlatents.shape == (self.batch_size, 512))
        else:
            if (dlatents.shape[1] > self.model_scale):
                dlatents = dlatents[:, :self.model_scale, :]
            if (isinstance(dlatents.shape[0], int)):
                if (dlatents.shape != (self.batch_size, self.model_scale, 512)):
                    dlatents = np.vstack([dlatents, np.zeros((self.batch_size-dlatents.shape[0], self.model_scale, 512))])
                assert (dlatents.shape == (self.batch_size, self.model_scale, 512))
                self.sess.run([self._assign_dlatent], {self._assign_dlatent_ph: dlatents})
                return
            else:
                # dlatents is a symbolic tensor (its leading dimension is not a plain int), so assign it directly
                self._assign_dlatent = tf.assign(self.dlatent_variable, dlatents)
                return
        self.sess.run([self._assign_dlatent], {self._assign_dlatent_ph: dlatents})

    def stochastic_clip_dlatents(self):
        self.sess.run(self.stochastic_clip_op)

    def get_dlatents(self):
        return self.sess.run(self.dlatent_variable)

    def get_dlatent_avg(self):
        return self.dlatent_avg

    def set_dlatent_avg(self, dlatent_avg):
        self.dlatent_avg = dlatent_avg

    def reset_dlatent_avg(self):
        self.dlatent_avg = self.dlatent_avg_def

    def generate_images(self, dlatents=None):
        if dlatents is not None:
            self.set_dlatents(dlatents)
        return self.sess.run(self.generated_image_uint8)
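As a quick orientation, the Generator class above is typically driven like this minimal sketch (the .pkl path is an assumption for a local 1024px FFHQ StyleGAN model, mirroring the loading code in encode_images.py):

    import pickle
    import numpy as np
    import dnnlib.tflib as tflib
    from encoder.generator_model import Generator

    tflib.init_tf()
    with open('karras2019stylegan-ffhq-1024x1024.pkl', 'rb') as f:  # assumed local model file
        _G, _D, Gs = pickle.load(f)
    generator = Generator(Gs, batch_size=1, clipping_threshold=2.0, model_res=1024)
    generator.set_dlatents(np.zeros((1, 18, 512)))  # model_scale = 18 for 1024px models
    img = generator.generate_images()[0]            # HWC uint8 array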
encoder/perceptual_model.py ADDED
@@ -0,0 +1,304 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
#import tensorflow_probability as tfp
#tf.enable_eager_execution()

import os
import bz2
import PIL.Image
from PIL import ImageFilter
import numpy as np
from keras.models import Model
from keras.utils import get_file
from keras.applications.vgg16 import VGG16, preprocess_input
import keras.backend as K
import traceback
import dnnlib.tflib as tflib

def load_images(images_list, image_size=256, sharpen=False):
    loaded_images = list()
    for img_path in images_list:
        img = PIL.Image.open(img_path).convert('RGB')
        if image_size is not None:
            img = img.resize((image_size, image_size), PIL.Image.LANCZOS)
            if (sharpen):
                img = img.filter(ImageFilter.DETAIL)
        img = np.array(img)
        img = np.expand_dims(img, 0)
        loaded_images.append(img)
    loaded_images = np.vstack(loaded_images)
    return loaded_images

def tf_custom_adaptive_loss(a, b):
    from adaptive import lossfun
    shape = a.get_shape().as_list()
    dim = np.prod(shape[1:])
    a = tf.reshape(a, [-1, dim])
    b = tf.reshape(b, [-1, dim])
    loss, _, _ = lossfun(b-a, var_suffix='1')
    return tf.math.reduce_mean(loss)

def tf_custom_adaptive_rgb_loss(a, b):
    from adaptive import image_lossfun
    loss, _, _ = image_lossfun(b-a, color_space='RGB', representation='PIXEL')
    return tf.math.reduce_mean(loss)

def tf_custom_l1_loss(img1, img2):
    return tf.math.reduce_mean(tf.math.abs(img2-img1), axis=None)

def tf_custom_logcosh_loss(img1, img2):
    return tf.math.reduce_mean(tf.keras.losses.logcosh(img1, img2))

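# Note on tf_custom_logcosh_loss: log(cosh(x)) is approximately x**2 / 2 for
# small |x| and approximately |x| - log(2) for large |x|, so it acts like L2
# near the target and like L1 far from it -- a smooth, outlier-robust
# compromise between the two.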

def create_stub(batch_size):
    return tf.constant(0, dtype='float32', shape=(batch_size, 0))

def unpack_bz2(src_path):
    data = bz2.BZ2File(src_path).read()
    dst_path = src_path[:-4]
    with open(dst_path, 'wb') as fp:
        fp.write(data)
    return dst_path

class PerceptualModel:
    def __init__(self, args, batch_size=1, perc_model=None, sess=None):
        self.sess = tf.compat.v1.get_default_session() if sess is None else sess
        K.set_session(self.sess)
        self.epsilon = 0.00000001
        self.lr = args.lr
        self.decay_rate = args.decay_rate
        self.decay_steps = args.decay_steps
        self.img_size = args.image_size
        self.layer = args.use_vgg_layer
        self.vgg_loss = args.use_vgg_loss
        self.face_mask = args.face_mask
        self.use_grabcut = args.use_grabcut
        self.scale_mask = args.scale_mask
        self.mask_dir = args.mask_dir
        if (self.layer <= 0 or self.vgg_loss <= self.epsilon):
            self.vgg_loss = None
        self.pixel_loss = args.use_pixel_loss
        if (self.pixel_loss <= self.epsilon):
            self.pixel_loss = None
        self.mssim_loss = args.use_mssim_loss
        if (self.mssim_loss <= self.epsilon):
            self.mssim_loss = None
        self.lpips_loss = args.use_lpips_loss
        if (self.lpips_loss <= self.epsilon):
            self.lpips_loss = None
        self.l1_penalty = args.use_l1_penalty
        if (self.l1_penalty <= self.epsilon):
            self.l1_penalty = None
        self.adaptive_loss = args.use_adaptive_loss
        self.sharpen_input = args.sharpen_input
        self.batch_size = batch_size
        if perc_model is not None and self.lpips_loss is not None:
            self.perc_model = perc_model
        else:
            self.perc_model = None
        self.ref_img = None
        self.ref_weight = None
        self.perceptual_model = None
        self.ref_img_features = None
        self.features_weight = None
        self.loss = None
        self.discriminator_loss = args.use_discriminator_loss
        if (self.discriminator_loss <= self.epsilon):
            self.discriminator_loss = None
        if self.discriminator_loss is not None:
            self.discriminator = None
            self.stub = create_stub(batch_size)

        if self.face_mask:
            import dlib
            self.detector = dlib.get_frontal_face_detector()
            landmarks_model_path = unpack_bz2('shape_predictor_68_face_landmarks.dat.bz2')
            self.predictor = dlib.shape_predictor(landmarks_model_path)

    def add_placeholder(self, var_name):
        var_val = getattr(self, var_name)
        setattr(self, var_name + "_placeholder", tf.compat.v1.placeholder(var_val.dtype, shape=var_val.get_shape()))
        setattr(self, var_name + "_op", var_val.assign(getattr(self, var_name + "_placeholder")))

    def assign_placeholder(self, var_name, var_val):
        self.sess.run(getattr(self, var_name + "_op"), {getattr(self, var_name + "_placeholder"): var_val})

    def build_perceptual_model(self, generator, discriminator=None):
        # Learning rate
        global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
        incremented_global_step = tf.compat.v1.assign_add(global_step, 1)
        self._reset_global_step = tf.assign(global_step, 0)
        self.learning_rate = tf.compat.v1.train.exponential_decay(self.lr, incremented_global_step,
                self.decay_steps, self.decay_rate, staircase=True)
        self.sess.run([self._reset_global_step])

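        # With staircase=True the schedule above is
        # lr * decay_rate ** floor(step / decay_steps); e.g. (assumed values)
        # lr=0.02, decay_rate=0.9, decay_steps=10 yields 0.02 for steps 0-9,
        # 0.018 for steps 10-19, 0.0162 for steps 20-29, and so on.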
        if self.discriminator_loss is not None:
            self.discriminator = discriminator

        generated_image_tensor = generator.generated_image
        generated_image = tf.compat.v1.image.resize_nearest_neighbor(generated_image_tensor,
                (self.img_size, self.img_size), align_corners=True)

        self.ref_img = tf.get_variable('ref_img', shape=generated_image.shape,
                dtype='float32', initializer=tf.initializers.zeros())
        self.ref_weight = tf.get_variable('ref_weight', shape=generated_image.shape,
                dtype='float32', initializer=tf.initializers.zeros())
        self.add_placeholder("ref_img")
        self.add_placeholder("ref_weight")

        if (self.vgg_loss is not None):
            vgg16 = VGG16(include_top=False, weights='vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', input_shape=(self.img_size, self.img_size, 3)) # https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
            self.perceptual_model = Model(vgg16.input, vgg16.layers[self.layer].output)
            generated_img_features = self.perceptual_model(preprocess_input(self.ref_weight * generated_image))
            self.ref_img_features = tf.get_variable('ref_img_features', shape=generated_img_features.shape,
                    dtype='float32', initializer=tf.initializers.zeros())
            self.features_weight = tf.get_variable('features_weight', shape=generated_img_features.shape,
                    dtype='float32', initializer=tf.initializers.zeros())
            self.sess.run([self.features_weight.initializer, self.ref_img_features.initializer])
            self.add_placeholder("ref_img_features")
            self.add_placeholder("features_weight")

        if self.perc_model is not None and self.lpips_loss is not None:
            img1 = tflib.convert_images_from_uint8(self.ref_weight * self.ref_img, nhwc_to_nchw=True)
            img2 = tflib.convert_images_from_uint8(self.ref_weight * generated_image, nhwc_to_nchw=True)

        self.loss = 0
        # log-cosh loss on VGG16 features
        if (self.vgg_loss is not None):
            if self.adaptive_loss:
                self.loss += self.vgg_loss * tf_custom_adaptive_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
            else:
                self.loss += self.vgg_loss * tf_custom_logcosh_loss(self.features_weight * self.ref_img_features, self.features_weight * generated_img_features)
        # + logcosh loss on image pixels
        if (self.pixel_loss is not None):
            if self.adaptive_loss:
                self.loss += self.pixel_loss * tf_custom_adaptive_rgb_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
            else:
                self.loss += self.pixel_loss * tf_custom_logcosh_loss(self.ref_weight * self.ref_img, self.ref_weight * generated_image)
        # + MS-SSIM loss on image pixels
        if (self.mssim_loss is not None):
            self.loss += self.mssim_loss * tf.math.reduce_mean(1-tf.image.ssim_multiscale(self.ref_weight * self.ref_img, self.ref_weight * generated_image, 1))
        # + extra perceptual (LPIPS) loss on image pixels
        if self.perc_model is not None and self.lpips_loss is not None:
            self.loss += self.lpips_loss * tf.math.reduce_mean(self.perc_model.get_output_for(img1, img2))
        # + L1 penalty on dlatent weights
        if self.l1_penalty is not None:
            self.loss += self.l1_penalty * 512 * tf.math.reduce_mean(tf.math.abs(generator.dlatent_variable-generator.get_dlatent_avg()))
        # discriminator loss (realism)
        if self.discriminator_loss is not None:
            self.loss += self.discriminator_loss * tf.math.reduce_mean(self.discriminator.get_output_for(tflib.convert_images_from_uint8(generated_image_tensor, nhwc_to_nchw=True), self.stub))
            # - discriminator_network.get_output_for(tflib.convert_images_from_uint8(ref_img, nhwc_to_nchw=True), stub)

    def generate_face_mask(self, im):
        from imutils import face_utils
        import cv2
        rects = self.detector(im, 1)
        # loop over the face detections
        for (j, rect) in enumerate(rects):
            """
            Determine the facial landmarks for the face region, then convert the facial landmark (x, y)-coordinates to a NumPy array
            """
            shape = self.predictor(im, rect)
            shape = face_utils.shape_to_np(shape)

            # we extract the face
            vertices = cv2.convexHull(shape)
            mask = np.zeros(im.shape[:2], np.uint8)
            cv2.fillConvexPoly(mask, vertices, 1)
            if self.use_grabcut:
                bgdModel = np.zeros((1, 65), np.float64)
                fgdModel = np.zeros((1, 65), np.float64)
                rect = (0, 0, im.shape[1], im.shape[0])  # (x, y, w, h); im is HxWx3, so height is shape[0]
                (x, y), radius = cv2.minEnclosingCircle(vertices)
                center = (int(x), int(y))
                radius = int(radius*self.scale_mask)
                mask = cv2.circle(mask, center, radius, cv2.GC_PR_FGD, -1)
                cv2.fillConvexPoly(mask, vertices, cv2.GC_FGD)
                cv2.grabCut(im, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_MASK)
                mask = np.where((mask==2)|(mask==0), 0, 1)
            return mask

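    # GrabCut note: the mask passed to cv2.grabCut uses the GC_* label values --
    # everything starts as sure background (0 = GC_BGD), the enlarged circle
    # around the landmarks is marked probable foreground (GC_PR_FGD) and the
    # landmark hull itself sure foreground (GC_FGD); after 5 iterations, pixels
    # labelled background (GC_BGD or GC_PR_BGD, i.e. 0 or 2) are zeroed out of
    # the final binary mask.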
    def set_reference_images(self, images_list):
        assert(len(images_list) != 0 and len(images_list) <= self.batch_size)
        loaded_image = load_images(images_list, self.img_size, sharpen=self.sharpen_input)
        image_features = None
        if self.perceptual_model is not None:
            image_features = self.perceptual_model.predict_on_batch(preprocess_input(np.array(loaded_image)))
            weight_mask = np.ones(self.features_weight.shape)

        if self.face_mask:
            image_mask = np.zeros(self.ref_weight.shape)
            for (i, im) in enumerate(loaded_image):
                try:
                    _, img_name = os.path.split(images_list[i])
                    mask_img = os.path.join(self.mask_dir, f'{img_name}')
                    if (os.path.isfile(mask_img)):
                        print("Loading mask " + mask_img)
                        imask = PIL.Image.open(mask_img).convert('L')
                        mask = np.array(imask)/255
                        mask = np.expand_dims(mask, axis=-1)
                    else:
                        mask = self.generate_face_mask(im)
                        imask = (255*mask).astype('uint8')
                        imask = PIL.Image.fromarray(imask, 'L')
                        print("Saving mask " + mask_img)
                        imask.save(mask_img, 'PNG')
                        mask = np.expand_dims(mask, axis=-1)
                    mask = np.ones(im.shape, np.float32) * mask
                except Exception as e:
                    print("Exception in mask handling for " + mask_img)
                    traceback.print_exc()
                    mask = np.ones(im.shape[:2], np.uint8)
                    mask = np.ones(im.shape, np.float32) * np.expand_dims(mask, axis=-1)
                image_mask[i] = mask
            img = None
        else:
            image_mask = np.ones(self.ref_weight.shape)

        if len(images_list) != self.batch_size:
            if image_features is not None:
                features_space = list(self.features_weight.shape[1:])
                existing_features_shape = [len(images_list)] + features_space
                empty_features_shape = [self.batch_size - len(images_list)] + features_space
                existing_examples = np.ones(shape=existing_features_shape)
                empty_examples = np.zeros(shape=empty_features_shape)
                weight_mask = np.vstack([existing_examples, empty_examples])
                image_features = np.vstack([image_features, np.zeros(empty_features_shape)])

            images_space = list(self.ref_weight.shape[1:])
            existing_images_space = [len(images_list)] + images_space
            empty_images_space = [self.batch_size - len(images_list)] + images_space
            existing_images = np.ones(shape=existing_images_space)
            empty_images = np.zeros(shape=empty_images_space)
            image_mask = image_mask * np.vstack([existing_images, empty_images])
            loaded_image = np.vstack([loaded_image, np.zeros(empty_images_space)])

        if image_features is not None:
            self.assign_placeholder("features_weight", weight_mask)
            self.assign_placeholder("ref_img_features", image_features)
        self.assign_placeholder("ref_weight", image_mask)
        self.assign_placeholder("ref_img", loaded_image)

    def optimize(self, vars_to_optimize, iterations=200, use_optimizer='adam'):
        vars_to_optimize = vars_to_optimize if isinstance(vars_to_optimize, list) else [vars_to_optimize]
        if use_optimizer == 'lbfgs':
            optimizer = tf.contrib.opt.ScipyOptimizerInterface(self.loss, var_list=vars_to_optimize, method='L-BFGS-B', options={'maxiter': iterations})
        else:
            if use_optimizer == 'ggt':
                optimizer = tf.contrib.opt.GGTOptimizer(learning_rate=self.learning_rate)
            else:
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
            min_op = optimizer.minimize(self.loss, var_list=[vars_to_optimize])
            self.sess.run(tf.variables_initializer(optimizer.variables()))
            fetch_ops = [min_op, self.loss, self.learning_rate]
        #min_op = optimizer.minimize(self.sess)
        #optim_results = tfp.optimizer.lbfgs_minimize(make_val_and_grad_fn(get_loss), initial_position=vars_to_optimize, num_correction_pairs=10, tolerance=1e-8)
        self.sess.run(self._reset_global_step)
        #self.sess.graph.finalize() # Graph is read-only after this statement.
        for _ in range(iterations):
            if use_optimizer == 'lbfgs':
                optimizer.minimize(self.sess, fetches=[vars_to_optimize, self.loss])
                yield {"loss": self.loss.eval()}
            else:
                _, loss, lr = self.sess.run(fetch_ops)
                yield {"loss": loss, "lr": lr}
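Taken together, PerceptualModel is driven as a Python generator: optimize() yields one loss dict per step, which is what the tqdm loop in encode_images.py iterates over. A minimal driving loop looks roughly like the sketch below (assuming a Generator built as in encode_images.py, an args namespace with the fields used above, and an illustrative image path):

    perceptual_model = PerceptualModel(args, perc_model=None, batch_size=1)
    perceptual_model.build_perceptual_model(generator)
    perceptual_model.set_reference_images(['aligned_images/example_01.png'])  # assumed path
    for loss_dict in perceptual_model.optimize(generator.dlatent_variable, iterations=100):
        print(loss_dict)  # {'loss': ..., 'lr': ...} on the adam/ggt path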
ffhq_dataset/__init__.py ADDED
File without changes
ffhq_dataset/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (118 Bytes).
ffhq_dataset/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (122 Bytes).
ffhq_dataset/__pycache__/face_alignment.cpython-36.pyc ADDED
Binary file (3.17 kB).
ffhq_dataset/__pycache__/face_alignment.cpython-37.pyc ADDED
Binary file (3.17 kB).
ffhq_dataset/__pycache__/landmarks_detector.cpython-36.pyc ADDED
Binary file (1.16 kB).