Spaces (Sleeping)
limoran committed
Commit 7e2a2a5
1 Parent(s): 8220eea
add basic files

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff.
- configs/__init__.py +30 -0
- configs/__pycache__/__init__.cpython-38.pyc +0 -0
- configs/__pycache__/base_config.cpython-38.pyc +0 -0
- configs/__pycache__/style_based_pix2pixII_config.cpython-38.pyc +0 -0
- configs/base_config.py +160 -0
- configs/style_based_pix2pixII_config.py +42 -0
- data/__init__.py +58 -0
- data/__pycache__/__init__.cpython-38.pyc +0 -0
- data/__pycache__/static_data.cpython-38.pyc +0 -0
- data/__pycache__/super_dataset.cpython-38.pyc +0 -0
- data/__pycache__/test_data.cpython-38.pyc +0 -0
- data/__pycache__/test_video_data.cpython-38.pyc +0 -0
- data/deprecated/custom_data.py +121 -0
- data/deprecated/landmark_data.py +89 -0
- data/deprecated/numpy_paired_data.py +81 -0
- data/deprecated/numpy_unpaired_data.py +100 -0
- data/deprecated/paired_data.py +80 -0
- data/deprecated/patch_data.py +44 -0
- data/deprecated/unpaired_data.py +101 -0
- data/static_data.py +457 -0
- data/super_dataset.py +321 -0
- data/test_data.py +51 -0
- data/test_video_data.py +28 -0
- exp/sp2pII-phase1.yaml +49 -0
- exp/sp2pII-phase2.yaml +49 -0
- exp/sp2pII-phase3.yaml +50 -0
- exp/sp2pII-phase4.yaml +49 -0
- logs/01_2023_09_07__18_32_26/events.out.tfevents.1694082748.aiplatform-wlf2-hi-12.idchb2az2.hb2.kwaidc.com.16044.0 +3 -0
- logs/01_2023_09_12__14_54_32/events.out.tfevents.1694501684.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.76748.0 +3 -0
- logs/01_2023_09_12__14_55_34/events.out.tfevents.1694501736.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77369.0 +3 -0
- logs/01_2023_09_12__15_03_47/events.out.tfevents.1694502229.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77940.0 +3 -0
- models/__init__.py +68 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/base_model.cpython-38.pyc +0 -0
- models/__pycache__/style_based_pix2pixII_model.cpython-38.pyc +0 -0
- models/base_model.py +340 -0
- models/modules/__init__.py +0 -0
- models/modules/__pycache__/__init__.cpython-38.pyc +0 -0
- models/modules/__pycache__/networks.cpython-38.pyc +0 -0
- models/modules/networks.py +1101 -0
- models/modules/sr/light_model_270M.py +347 -0
- models/modules/sr/light_model_470M.py +442 -0
- models/modules/stylegan2/__pycache__/model.cpython-38.pyc +0 -0
- models/modules/stylegan2/__pycache__/non_leaking.cpython-38.pyc +0 -0
- models/modules/stylegan2/model.py +716 -0
- models/modules/stylegan2/non_leaking.py +465 -0
- models/modules/stylegan2/op/__init__.py +2 -0
- models/modules/stylegan2/op/__pycache__/__init__.cpython-38.pyc +0 -0
- models/modules/stylegan2/op/__pycache__/conv2d_gradfix.cpython-38.pyc +0 -0
- models/modules/stylegan2/op/__pycache__/fused_act.cpython-38.pyc +0 -0
configs/__init__.py
ADDED
@@ -0,0 +1,30 @@
import importlib
from configs.base_config import BaseConfig

def find_config_by_name(config_name):
    # load config lib by config name
    config_file = "configs." + config_name + '_config'
    config_lib = importlib.import_module(config_file)
    print(config_lib)

    # find the subclass of BaseConfig
    config = None
    target_config_name = config_name.replace('_', '') + 'config'
    target_config_name = target_config_name.lower()
    for name, cls in config_lib.__dict__.items():
        if name.lower() == target_config_name and issubclass(cls, BaseConfig):
            config = cls

    if config is None:
        raise Exception('No valid config found.')

    return config

def parse_config(cfg_file):
    # parse config using BaseConfig
    cfg = BaseConfig().parse_config(cfg_file)
    model_name = cfg['common']['model']

    # re-parse using specified Config class
    config = find_config_by_name(model_name)
    return config().parse_config(cfg_file)

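As a minimal usage sketch of the two helpers above (illustrative only; it assumes the YAML passed in sets common/model to 'style_based_pix2pixII', which is how the name is resolved back to its config class):

# Usage sketch, not part of the diff.
from configs import parse_config, find_config_by_name

# 'style_based_pix2pixII' -> module configs.style_based_pix2pixII_config,
# then 'stylebasedpix2pixiiconfig' matches class StyleBasedPix2PixIIConfig.
config_cls = find_config_by_name('style_based_pix2pixII')
print(config_cls.__name__)  # StyleBasedPix2PixIIConfig

# parse_config() first parses with BaseConfig to read common/model,
# then re-parses the same YAML with the model-specific config class.
cfg = parse_config('exp/sp2pII-phase1.yaml')
print(cfg['common']['model'])
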
configs/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (900 Bytes).

configs/__pycache__/base_config.cpython-38.pyc
ADDED
Binary file (4.96 kB).

configs/__pycache__/style_based_pix2pixII_config.cpython-38.pyc
ADDED
Binary file (2.13 kB).
configs/base_config.py
ADDED
@@ -0,0 +1,160 @@
import yaml
import copy
from typing import Union

class BaseConfig():

    def __init__(self):
        self.__config_dict = {}
        self.__check_func_dict = {}

        is_greater_than_0 = lambda x: x > 0

        # common config
        self._add_option('common', 'name', str, 'style_master')
        self._add_option('common', 'model', str, 'cycle_gan')
        self._add_option('common', 'phase', str, 'train', check_func=lambda x: x in ['train', 'test'])
        self._add_option('common', 'gpu_ids', list, [0])
        self._add_option('common', 'verbose', bool, False)

        # model config
        self._add_option('model', 'input_nc', int, 3, check_func=is_greater_than_0)
        self._add_option('model', 'output_nc', int, 3, check_func=is_greater_than_0)

        # dataset config
        # common dataset options
        self._add_option('dataset', 'use_absolute_datafile', bool, True)
        self._add_option('dataset', 'batch_size', int, 1, check_func=is_greater_than_0)
        self._add_option('dataset', 'n_threads', int, 4, check_func=is_greater_than_0)
        self._add_option('dataset', 'dataroot', str, './')
        self._add_option('dataset', 'drop_last', bool, False)
        self._add_option('dataset', 'landmark_scale', list, None)
        self._add_option('dataset', 'check_all_data', bool, False)
        self._add_option('dataset', 'accept_data_error', bool, True)  # Upon loading a bad data, if this is true,
                                                                      # dataloader will throw an exception and
                                                                      # load the next good data.
                                                                      # If this is false, process will crash.

        self._add_option('dataset', 'train_data', dict, {})
        self._add_option('dataset', 'val_data', dict, {})

        # paired data config
        self._add_option('dataset', 'paired_trainA_folder', str, '')
        self._add_option('dataset', 'paired_trainB_folder', str, '')
        self._add_option('dataset', 'paired_train_filelist', str, '')
        self._add_option('dataset', 'paired_valA_folder', str, '')
        self._add_option('dataset', 'paired_valB_folder', str, '')
        self._add_option('dataset', 'paired_val_filelist', str, '')

        # unpaired data config
        self._add_option('dataset', 'unpaired_trainA_folder', str, '')
        self._add_option('dataset', 'unpaired_trainB_folder', str, '')
        self._add_option('dataset', 'unpaired_trainA_filelist', str, '')
        self._add_option('dataset', 'unpaired_trainB_filelist', str, '')
        self._add_option('dataset', 'unpaired_valA_folder', str, '')
        self._add_option('dataset', 'unpaired_valB_folder', str, '')
        self._add_option('dataset', 'unpaired_valA_filelist', str, '')
        self._add_option('dataset', 'unpaired_valB_filelist', str, '')

        # custom data
        self._add_option('dataset', 'custom_train_data', dict, {})
        self._add_option('dataset', 'custom_val_data', dict, {})

        # training config
        self._add_option('training', 'checkpoints_dir', str, './checkpoints')
        self._add_option('training', 'log_dir', str, './logs')
        self._add_option('training', 'use_new_log', bool, False)
        self._add_option('training', 'continue_train', bool, False)
        self._add_option('training', 'which_epoch', str, 'latest')
        self._add_option('training', 'n_epochs', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'n_epochs_decay', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'save_latest_freq', int, 5000, check_func=is_greater_than_0)
        self._add_option('training', 'print_freq', int, 200, check_func=is_greater_than_0)
        self._add_option('training', 'save_epoch_freq', int, 5, check_func=is_greater_than_0)
        self._add_option('training', 'epoch_as_iter', bool, False)
        self._add_option('training', 'lr', float, 2e-4, check_func=is_greater_than_0)
        self._add_option('training', 'lr_policy', str, 'linear',
                         check_func=lambda x: x in ['linear', 'step', 'plateau', 'cosine'])
        self._add_option('training', 'lr_decay_iters', int, 50, check_func=is_greater_than_0)
        self._add_option('training', 'DDP', bool, False)
        self._add_option('training', 'num_nodes', int, 1, check_func=is_greater_than_0)
        self._add_option('training', 'DDP_address', str, '127.0.0.1')
        self._add_option('training', 'DDP_port', str, '29700')
        self._add_option('training', 'find_unused_parameters', bool, False)  # a DDP option that allows backward on a subgraph of the model
        self._add_option('training', 'val_percent', float, 5.0, check_func=is_greater_than_0)  # Uses x% of training data to validate
        self._add_option('training', 'val', bool, True)  # perform validation every epoch
        self._add_option('training', 'save_training_progress', bool, False)  # save images to create a training progression video

        # testing config
        self._add_option('testing', 'results_dir', str, './results')
        self._add_option('testing', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'preprocess', list, ['scale_width'])
        self._add_option('testing', 'visual_names', list, [])
        self._add_option('testing', 'num_test', int, 999999, check_func=is_greater_than_0)
        self._add_option('testing', 'image_format', str, 'jpg', check_func=lambda x: x in ['input', 'jpg', 'jpeg', 'png'])

    def _add_option(self, group_name, option_name, value_type, default_value, check_func=None):
        # check name type
        if not type(group_name) is str or not type(option_name) is str:
            raise Exception('Type of {} and {} must be str.'.format(group_name, option_name))

        # add group
        if not group_name in self.__config_dict:
            self.__config_dict[group_name] = {}
            self.__check_func_dict[group_name] = {}

        # check type & default value
        if not type(value_type) is type:
            try:
                if value_type.__origin__ is not Union:
                    raise Exception('{} is not a type.'.format(value_type))
            except Exception as e:
                print(e)
        if not type(default_value) is value_type:
            try:
                if value_type.__origin__ is not Union:
                    raise Exception('Type of {} must be {}.'.format(default_value, value_type))
            except Exception as e:
                print(e)

        # add option to dict
        if not option_name in self.__config_dict[group_name]:
            if not check_func is None and not check_func(default_value):
                raise Exception('Checking {}/{} failed.'.format(group_name, option_name))
            self.__config_dict[group_name][option_name] = default_value
            self.__check_func_dict[group_name][option_name] = check_func
        else:
            raise Exception('{} has been already added.'.format(option_name))

    def parse_config(self, cfg_file):
        # load config from yaml file
        with open(cfg_file, 'r') as f:
            yaml_config = yaml.safe_load(f)
        if not type(yaml_config) is dict:
            raise Exception('Loading yaml file failed.')

        # replace default options
        config_dict = copy.deepcopy(self.__config_dict)
        for group in config_dict:
            if group in yaml_config:
                for option in config_dict[group]:
                    if option in yaml_config[group]:
                        value = yaml_config[group][option]
                        if not type(value) is type(config_dict[group][option]):
                            try:  # if <config_dict[group][option]> is not union, it won't have __origin__ attribute. So will throw an error.
                                # The line below is necessary because we check if <config_dict[group][option]> has __origin__ attribute.
                                if config_dict[group][option].__origin__ is Union:
                                    # check to see if type of <value> belongs to a type in the union.
                                    if not isinstance(value, config_dict[group][option].__args__):
                                        raise Exception('Type of {}/{} must be {}.'.format(group, option,
                                                                                           config_dict[group][option].__args__))
                            except Exception as e:  # if the error was thrown, we know there's a type error.
                                print(e)
                        else:
                            check_func = self.__check_func_dict[group][option]
                            if not check_func is None and not check_func(value):
                                raise Exception('Checking {}/{} failed.'.format(group, option))
                        config_dict[group][option] = value
        return config_dict

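To illustrate how parse_config merges a YAML file over the defaults registered with _add_option (a minimal sketch; the file path and values below are invented, not taken from this commit):

# Illustrative only: any option missing from the YAML keeps its registered
# default, and a value that fails its check_func raises an Exception.
import yaml
from configs.base_config import BaseConfig

example = {
    'common': {'model': 'style_based_pix2pixII', 'phase': 'train'},
    'training': {'lr': 1e-4, 'lr_policy': 'cosine'},
    'dataset': {'batch_size': 8},
}
with open('/tmp/example_config.yaml', 'w') as f:
    yaml.safe_dump(example, f)

cfg = BaseConfig().parse_config('/tmp/example_config.yaml')
print(cfg['training']['lr'])        # 0.0001 (overridden by the YAML)
print(cfg['training']['n_epochs'])  # 100 (default kept)
# A value such as training.lr = -1 would fail is_greater_than_0 and raise.
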
configs/style_based_pix2pixII_config.py
ADDED
@@ -0,0 +1,42 @@
from .base_config import BaseConfig
from typing import Union as Union

class StyleBasedPix2PixIIConfig(BaseConfig):

    def __init__(self):
        super(StyleBasedPix2PixIIConfig, self).__init__()

        is_greater_than_0 = lambda x: x > 0

        # model config
        self._add_option('model', 'ngf', int, 64, check_func=is_greater_than_0)
        self._add_option('model', 'min_feats_size', list, [4, 4])

        # dataset config
        self._add_option('dataset', 'data_type', list, ['unpaired'])
        self._add_option('dataset', 'direction', str, 'AtoB')
        self._add_option('dataset', 'serial_batches', bool, False)
        self._add_option('dataset', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('dataset', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('dataset', 'preprocess', Union[list, str], ['resize'])
        self._add_option('dataset', 'no_flip', bool, True)

        # training config
        self._add_option('training', 'beta1', float, 0.1, check_func=is_greater_than_0)
        self._add_option('training', 'data_aug_prob', float, 0.0, check_func=lambda x: x >= 0.0)
        self._add_option('training', 'style_mixing_prob', float, 0.0, check_func=lambda x: x >= 0.0)
        self._add_option('training', 'phase', int, 1, check_func=lambda x: x in [1, 2, 3, 4])
        self._add_option('training', 'pretrained_model', str, 'model.pth')
        self._add_option('training', 'src_text_prompt', str, 'photo')
        self._add_option('training', 'text_prompt', str, 'a portrait in style of sketch')
        self._add_option('training', 'image_prompt', str, 'style.png')
        self._add_option('training', 'lambda_L1', float, 1.0)
        self._add_option('training', 'lambda_Feat', float, 4.0)
        self._add_option('training', 'lambda_ST', float, 1.0)
        self._add_option('training', 'lambda_GAN', float, 1.0)
        self._add_option('training', 'lambda_CLIP', float, 1.0)
        self._add_option('training', 'lambda_PROJ', float, 1.0)
        self._add_option('training', 'ema', float, 0.999)

        # testing config
        self._add_option('testing', 'aspect_ratio', float, 1.0, check_func=is_greater_than_0)

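These options are the ones the exp/sp2pII-phase*.yaml files in this commit are expected to set (their exact contents are not shown here). A hypothetical fragment, with invented values, just to show the grouping this subclass expects:

# Hypothetical YAML fragment (values invented for illustration).
from configs.style_based_pix2pixII_config import StyleBasedPix2PixIIConfig

yaml_text = """
common:
  model: style_based_pix2pixII
training:
  phase: 2
  pretrained_model: checkpoints/phase1_latest.pth
  text_prompt: a portrait in style of sketch
  lambda_CLIP: 2.0
"""
path = '/tmp/sp2pII_example.yaml'
with open(path, 'w') as f:
    f.write(yaml_text)

cfg = StyleBasedPix2PixIIConfig().parse_config(path)
print(cfg['training']['phase'])        # 2 (allowed values are 1-4)
print(cfg['model']['ngf'])             # 64, subclass default kept
print(cfg['training']['lambda_CLIP'])  # 2.0
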
data/__init__.py
ADDED
@@ -0,0 +1,58 @@
"""This package includes all the modules related to data loading and preprocessing

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point from data loader.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

class CustomDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, config, dataset, DDP_gpu=None, drop_last=False):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.config = config
        self.dataset = dataset

        if DDP_gpu is None:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=config['dataset']['batch_size'],
                shuffle=not config['dataset']['serial_batches'],
                num_workers=int(config['dataset']['n_threads']), drop_last=drop_last)
        else:
            sampler = DistributedSampler(self.dataset, num_replicas=self.config['training']['world_size'],
                                         rank=DDP_gpu)
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=config['dataset']['batch_size'],
                shuffle=False,
                num_workers=int(config['dataset']['n_threads']),
                sampler=sampler,
                drop_last=drop_last)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), 1e9)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.config['dataset']['batch_size'] >= 1e9:
                break
            yield data

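A minimal sketch of how CustomDataLoader is meant to be wrapped around a dataset (the toy dataset and config values below are placeholders; the actual datasets live in data/static_data.py and data/super_dataset.py further down):

# Illustrative only: wrap any torch Dataset with CustomDataLoader.
import torch
from torch.utils.data import Dataset
from data import CustomDataLoader  # the wrapper defined above

class ToyDataset(Dataset):          # placeholder stand-in for a real dataset
    def __len__(self):
        return 16
    def __getitem__(self, idx):
        return {'real_A': torch.randn(3, 8, 8), 'index': idx}

config = {'dataset': {'batch_size': 4, 'serial_batches': False, 'n_threads': 0}}
loader = CustomDataLoader(config, ToyDataset()).load_data()
print(len(loader))                   # 16 samples
for batch in loader:
    print(batch['real_A'].shape)     # torch.Size([4, 3, 8, 8])
    break
# With DDP_gpu set, a DistributedSampler is used instead and
# config['training']['world_size'] must be present.
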
data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (2.64 kB).

data/__pycache__/static_data.cpython-38.pyc
ADDED
Binary file (10.8 kB).

data/__pycache__/super_dataset.cpython-38.pyc
ADDED
Binary file (9.17 kB).

data/__pycache__/test_data.cpython-38.pyc
ADDED
Binary file (1.74 kB).

data/__pycache__/test_video_data.cpython-38.pyc
ADDED
Binary file (1.43 kB).
data/deprecated/custom_data.py
ADDED
@@ -0,0 +1,121 @@
import os
import random
import numpy as np
from utils.augmentation import ImagePathToImage
from utils.data_utils import Transforms, check_img_loaded, check_numpy_loaded


class CustomData(object):

    def __init__(self, config, shuffle=False):
        self.paired_file_groups = []
        self.paired_type_groups = []
        self.len_of_groups = []
        self.landmark_scale = config['dataset']['landmark_scale']
        self.shuffle = shuffle
        self.config = config

        data_dict = config['dataset']['custom_' + config['common']['phase'] + '_data']
        if len(data_dict) == 0:
            self.len_of_groups.append(0)
            return

        for i, group in enumerate(data_dict.values()):  # one example: (0, group_1), (1, group_2)
            data_types = group['data_types']  # one example: 'image', 'patch'
            data_names = group['data_names']  # one example: 'real_A', 'patch_A'
            file_list = group['file_list']  # one example: "lmt/data/trainA.txt"
            assert(len(data_types) == len(data_names))

            self.paired_file_groups.append({})
            self.paired_type_groups.append({})
            for data_name, data_type in zip(data_names, data_types):
                self.paired_file_groups[i][data_name] = []
                self.paired_type_groups[i][data_name] = data_type

            paired_file = open(file_list, 'rt')
            lines = paired_file.readlines()
            if self.shuffle:
                random.shuffle(lines)
            for line in lines:
                items = line.strip().split(' ')
                if len(items) == len(data_names):
                    ok = True
                    for item in items:
                        ok = ok and os.path.exists(item) and os.path.getsize(item) > 0
                    if ok:
                        for data_name, item in zip(data_names, items):
                            self.paired_file_groups[i][data_name].append(item)
            paired_file.close()

            self.len_of_groups.append(len(self.paired_file_groups[i][data_names[0]]))

        self.transform = Transforms(config)
        self.transform.get_transform_from_config()
        self.transform.get_transforms().insert(0, ImagePathToImage())
        self.transform = self.transform.compose_transforms()

    def get_len(self):
        return max(self.len_of_groups)

    def get_item(self, idx):
        return_dict = {}
        for i in range(len(self.paired_file_groups)):
            inner_idx = idx if idx < self.len_of_groups[i] else random.randint(0, self.len_of_groups[i] - 1)
            img_list = []
            img_k_list = []
            for k, v in self.paired_file_groups[i].items():
                if self.paired_type_groups[i][k] == 'image':
                    # gather images for processing later
                    img_k_list.append(k)
                    img_list.append(v[inner_idx])
                elif self.paired_type_groups[i][k] == 'landmark':
                    # different from images, landmark doesn't use data augmentation. So process them directly here.
                    lmk = np.load(v[inner_idx])
                    lmk[:, 0] *= self.landmark_scale[0]
                    lmk[:, 1] *= self.landmark_scale[1]
                    return_dict[k] = lmk
                    return_dict[k + '_path'] = v[inner_idx]

            # transform all images
            if len(img_list) == 1:
                return_dict[img_k_list[0]], _ = self.transform(img_list[0], None)
            elif len(img_list) > 1:
                input1, input2 = img_list[0], img_list[1:]
                output1, output2 = self.transform(input1, input2)  # output1 is one image. output2 is a list of images.
                return_dict[img_k_list[0]] = output1
                for j in range(1, len(img_list)):
                    return_dict[img_k_list[j]] = output2[j-1]

        return return_dict

    def split_data_into_bins(self, num_bins):
        bins = []
        for i in range(0, num_bins):
            bins.append([])
        for i in range(0, len(self.paired_file_groups)):
            for b in range(0, num_bins):
                bins[b].append({})
            for dataname, item_list in self.paired_file_groups[i].items():
                if len(item_list) < self.config['dataset']['n_threads']:
                    bins[0][i][dataname] = item_list
                else:
                    num_items_in_bin = len(item_list) // num_bins
                    for j in range(0, len(item_list)):
                        which_bin = min(j // num_items_in_bin, num_bins - 1)
                        if dataname not in bins[which_bin][i]:
                            bins[which_bin][i][dataname] = []
                        else:
                            bins[which_bin][i][dataname].append(item_list[j])
        return bins

    def check_data_helper(self, data):
        all_pass = True
        for paired_file_group in data:
            for k, v in paired_file_group.items():
                if len(v) > 0:
                    for v1 in v:
                        if '.npy' in v1:  # case: numpy array or landmark
                            all_pass = all_pass and check_numpy_loaded(v1)
                        else:  # case: image
                            all_pass = all_pass and check_img_loaded(v1)
        return all_pass

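For reference, the file_list a custom data group points at is a plain text file with one space-separated path per data_name on each line. A sketch of the expected config shape (group names, paths, and values below are invented for illustration):

# Illustrative only: shape of a custom data group as CustomData reads it.
config = {
    'common': {'phase': 'train'},
    'dataset': {
        'landmark_scale': [1.0, 1.0],
        'custom_train_data': {
            'group_1': {
                'data_types': ['image', 'image', 'landmark'],
                'data_names': ['real_A', 'real_B', 'A_lmk'],
                'file_list': 'lmt/data/train_list.txt',
                # each line of train_list.txt (hypothetical paths):
                # /data/A/0001.png /data/B/0001.png /data/lmk/0001.npy
            }
        },
    },
}
# CustomData(config) then exposes get_len() / get_item(idx); get_item returns a
# dict containing 'real_A', 'real_B', 'A_lmk' plus the matching '*_path' keys.
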
data/deprecated/landmark_data.py
ADDED
@@ -0,0 +1,89 @@
import os
from PIL import Image
import numpy as np
from utils.data_utils import check_create_shuffled_order, check_equal_length

def landmark_path_to_numpy(lmk_path, image_path, image_tensor):
    """Convert a landmark path to the actual landmarks in a numpy array. Also applies scaling to the landmarks
    according to the final image's size.

    Parameters:
        lmk_path -- the landmark file path.
        image_path -- the original image file path.
        image_tensor -- the image tensor after all transformations.
    """
    lmk = np.load(lmk_path)
    ow, oh = Image.open(image_path).size
    h, w = image_tensor.size()[1:]
    lmk[:, 0] *= w / ow
    lmk[:, 1] *= h / oh
    return lmk

def add_landmark_data(data, config, paired_data_order):
    A_lmk_paths = []
    B_lmk_paths = []

    if config['dataset']['paired_' + config['common']['phase'] + '_filelist'] != '':
        paired_data_file = open(config['dataset']['paired_' + config['common']['phase'] + '_filelist'], 'r')
        Lines = paired_data_file.readlines()
        paired_data_order = check_create_shuffled_order(Lines, paired_data_order)
        check_equal_length(Lines, paired_data_order, data)
        for i in paired_data_order:
            line = Lines[i]
            if not config['dataset']['use_absolute_datafile']:
                file3 = os.path.join(config['dataset']['dataroot'], line.split(" ")[2]).strip()
                file4 = os.path.join(config['dataset']['dataroot'], line.split(" ")[3]).strip()
            else:
                file3 = line.split(" ")[2].strip()
                file4 = line.split(" ")[3].strip()
            if os.path.exists(file3) and os.path.exists(file4):
                A_lmk_paths.append(file3)
                B_lmk_paths.append(file4)
        paired_data_file.close()
    elif config['dataset']['paired_' + config['common']['phase'] + 'A_folder'] != '' and \
            config['dataset']['paired_' + config['common']['phase'] + 'B_folder'] != '' and \
            os.path.exists(config['dataset']['paired_' + config['common']['phase'] + 'A_lmk_folder']) and \
            os.path.exists(config['dataset']['paired_' + config['common']['phase'] + 'B_lmk_folder']):
        dir_A = config['dataset']['paired_' + config['common']['phase'] + 'A_folder']
        dir_A_lmk = config['dataset']['paired_' + config['common']['phase'] + 'A_lmk_folder']
        dir_B_lmk = config['dataset']['paired_' + config['common']['phase'] + 'B_lmk_folder']
        filenames = os.listdir(dir_A)
        paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
        check_equal_length(filenames, paired_data_order, data)
        for i in paired_data_order:
            filename = filenames[i]
            A_lmk_path = os.path.join(dir_A_lmk, os.path.splitext(filename)[0] + '.npy')
            B_lmk_path = os.path.join(dir_B_lmk, os.path.splitext(filename)[0] + '.npy')
            if os.path.exists(A_lmk_path) and os.path.exists(B_lmk_path):
                A_lmk_paths.append(A_lmk_path)
                B_lmk_paths.append(B_lmk_path)
    else:
        dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedA')
        dir_A_lmk = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedA_lmk')
        dir_B_lmk = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedB_lmk')
        if os.path.exists(dir_A_lmk) and os.path.exists(dir_B_lmk):
            filenames = os.listdir(dir_A)
            paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
            check_equal_length(filenames, paired_data_order, data)
            for i in paired_data_order:
                filename = filenames[i]
                A_lmk_path = os.path.join(dir_A_lmk, os.path.splitext(filename)[0] + '.npy')
                B_lmk_path = os.path.join(dir_B_lmk, os.path.splitext(filename)[0] + '.npy')
                if os.path.exists(A_lmk_path) and os.path.exists(B_lmk_path):
                    A_lmk_paths.append(A_lmk_path)
                    B_lmk_paths.append(B_lmk_path)
        else:
            print(dir_A_lmk + " or " + dir_B_lmk + " doesn't exist. Skipping landmark data.")

    data['A_lmk_path'] = A_lmk_paths
    data['B_lmk_path'] = B_lmk_paths

    return paired_data_order


def apply_landmark_transforms(index, data, return_dict):
    if len(data['A_lmk_path']) > 0:
        return_dict['A_lmk'] = landmark_path_to_numpy(data['A_lmk_path'][index], data['paired_A_path'][index], return_dict['paired_A'])
        return_dict['B_lmk'] = landmark_path_to_numpy(data['B_lmk_path'][index], data['paired_B_path'][index], return_dict['paired_B'])
        return_dict['A_lmk_path'] = data['A_lmk_path'][index]
        return_dict['B_lmk_path'] = data['B_lmk_path'][index]

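The scaling in landmark_path_to_numpy simply multiplies x coordinates by new_width/original_width and y coordinates by new_height/original_height. A standalone numpy sketch of the same arithmetic, without touching the filesystem (the sizes and points are invented):

# Illustrative only: a 1024x768 (w x h) source image resized to a 512x512
# tensor scales x by 512/1024 and y by 512/768.
import numpy as np

lmk = np.array([[512.0, 384.0], [100.0, 200.0]])  # (x, y) landmark pairs
ow, oh = 1024, 768          # original image size from PIL (width, height)
h, w = 512, 512             # transformed tensor size (C, H, W) -> H, W
lmk[:, 0] *= w / ow
lmk[:, 1] *= h / oh
print(lmk)  # [[256.  256.  ] [ 50.   133.33]]
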
data/deprecated/numpy_paired_data.py
ADDED
@@ -0,0 +1,81 @@
import os
from utils.util import check_path_is_img
from utils.data_utils import Transforms, check_create_shuffled_order, check_equal_length
from utils.augmentation import NumpyToTensor


def add_numpy_paired_data(data, transforms, config, paired_data_order):
    A_paths = []
    B_paths = []

    if config['dataset']['paired_' + config['common']['phase'] + '_filelist'] != '':
        paired_data_file = open(config['dataset']['paired_' + config['common']['phase'] + '_filelist'], 'r')
        Lines = paired_data_file.readlines()
        paired_data_order = check_create_shuffled_order(Lines, paired_data_order)
        check_equal_length(Lines, paired_data_order, data)
        for i in paired_data_order:
            line = Lines[i]
            if not config['dataset']['use_absolute_datafile']:
                file1 = os.path.join(config['dataset']['dataroot'], line.split(" ")[0]).strip()
                file2 = os.path.join(config['dataset']['dataroot'], line.split(" ")[1]).strip()
            else:
                file1 = line.split(" ")[0].strip()
                file2 = line.split(" ")[1].strip()
            if os.path.exists(file1) and os.path.exists(file2):
                A_paths.append(file1)
                B_paths.append(file2)
        paired_data_file.close()
    elif config['dataset']['paired_' + config['common']['phase'] + 'A_folder'] != '' and \
            config['dataset']['paired_' + config['common']['phase'] + 'B_folder'] != '':
        dir_A = config['dataset']['paired_' + config['common']['phase'] + 'A_folder']
        dir_B = config['dataset']['paired_' + config['common']['phase'] + 'B_folder']
        filenames = os.listdir(dir_A)
        paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
        check_equal_length(filenames, paired_data_order, data)
        for i in paired_data_order:
            filename = filenames[i]
            if not check_path_is_img(filename):
                continue
            A_path = os.path.join(dir_A, filename)
            B_path = os.path.join(dir_B, filename)
            if os.path.exists(A_path) and os.path.exists(B_path):
                A_paths.append(A_path)
                B_paths.append(B_path)
    else:
        dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpypairedA')
        dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpypairedB')
        if os.path.exists(dir_A) and os.path.exists(dir_B):
            filenames = os.listdir(dir_A)
            paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
            check_equal_length(filenames, paired_data_order, data)
            for i in paired_data_order:
                filename = filenames[i]
                if not check_path_is_img(filename):
                    continue
                A_path = os.path.join(dir_A, filename)
                B_path = os.path.join(dir_B, filename)
                if os.path.exists(A_path) and os.path.exists(B_path):
                    A_paths.append(A_path)
                    B_paths.append(B_path)

    btoA = config['dataset']['direction'] == 'BtoA'
    # get the number of channels of input image
    input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
    output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']

    transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
    transform.transform_list.append(NumpyToTensor())
    transform = transform.compose_transforms()

    data['paired_A_path'] = A_paths
    data['paired_B_path'] = B_paths
    transforms['paired'] = transform
    return paired_data_order


def apply_numpy_paired_transforms(index, data, transforms, return_dict):
    if len(data['paired_A_path']) > 0:
        return_dict['paired_A'], return_dict['paired_B'] = transforms['paired'] \
            (data['paired_A_path'][index], data['paired_B_path'][index])
        return_dict['paired_A_path'] = data['paired_A_path'][index]
        return_dict['paired_B_path'] = data['paired_B_path'][index]

data/deprecated/numpy_unpaired_data.py
ADDED
@@ -0,0 +1,100 @@
import os
from utils.util import check_path_is_img
from utils.data_utils import Transforms
from utils.augmentation import NumpyToTensor
import random

def add_numpy_unpaired_data(data, transforms, config, shuffle=False):
    A_paths = []
    B_paths = []
    if config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'] != '':
        unpaired_data_file1 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'], 'r')
        Lines = unpaired_data_file1.readlines()
        if shuffle:
            random.shuffle(Lines)
        for line in Lines:
            if not config['dataset']['use_absolute_datafile']:
                file = os.path.join(config['dataset']['dataroot'], line.strip())
            else:
                file = line.strip()
            if os.path.exists(file):
                A_paths.append(file)
        unpaired_data_file1.close()

        unpaired_data_file2 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'B_filelist'], 'r')
        Lines = unpaired_data_file2.readlines()
        if shuffle:
            random.shuffle(Lines)
        for line in Lines:
            if not config['dataset']['use_absolute_datafile']:
                file = os.path.join(config['dataset']['dataroot'], line.strip())
            else:
                file = line.strip()
            if os.path.exists(file):
                B_paths.append(file)
        unpaired_data_file2.close()
    elif config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder'] != '' and \
            config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder'] != '':
        dir_A = config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder']
        filenames = os.listdir(dir_A)
        if shuffle:
            random.shuffle(filenames)
        for filename in filenames:
            if not check_path_is_img(filename):
                continue
            A_path = os.path.join(dir_A, filename)
            if os.path.exists(A_path):
                A_paths.append(A_path)

        dir_B = config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder']
        filenames = os.listdir(dir_B)
        if shuffle:
            random.shuffle(filenames)
        for filename in filenames:
            if not check_path_is_img(filename):
                continue
            B_path = os.path.join(dir_B, filename)
            if os.path.exists(B_path):
                B_paths.append(B_path)

    else:
        dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpyunpairedA')
        dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpyunpairedB')
        if os.path.exists(dir_A) and os.path.exists(dir_B):
            filenames = os.listdir(dir_A)
            if shuffle:
                random.shuffle(filenames)
            for filename in filenames:
                if not check_path_is_img(filename):
                    continue
                A_path = os.path.join(dir_A, filename)
                A_paths.append(A_path)
            filenames = os.listdir(dir_B)
            if shuffle:
                random.shuffle(filenames)
            for filename in filenames:
                if not check_path_is_img(filename):
                    continue
                B_path = os.path.join(dir_B, filename)
                B_paths.append(B_path)

    btoA = config['dataset']['direction'] == 'BtoA'
    input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
    output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']

    transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
    transform.transform_list.append(NumpyToTensor())
    transform = transform.compose_transforms()

    data['unpaired_A_path'] = A_paths
    data['unpaired_B_path'] = B_paths
    transforms['unpaired'] = transform

def apply_numpy_unpaired_transforms(index, data, transforms, return_dict):
    if len(data['unpaired_A_path']) > 0 and len(data['unpaired_B_path']) > 0:
        index_B = random.randint(0, len(data['unpaired_B_path']) - 1)
        return_dict['unpaired_A'], return_dict['unpaired_B'] = transforms['unpaired'] \
            (data['unpaired_A_path'][index], data['unpaired_B_path'][index_B])
        return_dict['unpaired_A_path'] = data['unpaired_A_path'][index]
        return_dict['unpaired_B_path'] = data['unpaired_B_path'][index_B]

data/deprecated/paired_data.py
ADDED
@@ -0,0 +1,80 @@
import os
from utils.util import check_path_is_img
from utils.data_utils import Transforms, check_create_shuffled_order
from utils.augmentation import ImagePathToImage


def add_paired_data(data, transforms, config, paired_data_order):
    A_paths = []
    B_paths = []

    if config['dataset']['paired_' + config['common']['phase'] + '_filelist'] != '':
        paired_data_file = open(config['dataset']['paired_' + config['common']['phase'] + '_filelist'], 'r')
        Lines = paired_data_file.readlines()
        paired_data_order = check_create_shuffled_order(Lines, paired_data_order)
        for i in paired_data_order:
            line = Lines[i]
            if not config['dataset']['use_absolute_datafile']:
                file1 = os.path.join(config['dataset']['dataroot'], line.split(" ")[0]).strip()
                file2 = os.path.join(config['dataset']['dataroot'], line.split(" ")[1]).strip()
            else:
                file1 = line.split(" ")[0].strip()
                file2 = line.split(" ")[1].strip()
            if os.path.exists(file1) and os.path.exists(file2):
                A_paths.append(file1)
                B_paths.append(file2)
        paired_data_file.close()
    elif config['dataset']['paired_' + config['common']['phase'] + 'A_folder'] != '' and \
            config['dataset']['paired_' + config['common']['phase'] + 'B_folder'] != '':
        dir_A = config['dataset']['paired_' + config['common']['phase'] + 'A_folder']
        dir_B = config['dataset']['paired_' + config['common']['phase'] + 'B_folder']
        filenames = os.listdir(dir_A)
        paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
        for i in paired_data_order:
            filename = filenames[i]
            if not check_path_is_img(filename):
                continue
            A_path = os.path.join(dir_A, filename)
            B_path = os.path.join(dir_B, filename)
            if os.path.exists(A_path) and os.path.exists(B_path):
                A_paths.append(A_path)
                B_paths.append(B_path)
    else:
        dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedA')
        dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedB')
        if os.path.exists(dir_A) and os.path.exists(dir_B):
            filenames = os.listdir(dir_A)
            paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
            for i in paired_data_order:
                filename = filenames[i]
                if not check_path_is_img(filename):
                    continue
                A_path = os.path.join(dir_A, filename)
                B_path = os.path.join(dir_B, filename)
                if os.path.exists(A_path) and os.path.exists(B_path):
                    A_paths.append(A_path)
                    B_paths.append(B_path)

    btoA = config['dataset']['direction'] == 'BtoA'
    # get the number of channels of input image
    input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
    output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']

    transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
    transform.get_transform_from_config()
    transform.get_transforms().insert(0, ImagePathToImage())
    transform = transform.compose_transforms()

    data['paired_A_path'] = A_paths
    data['paired_B_path'] = B_paths

    transforms['paired'] = transform
    return paired_data_order


def apply_paired_transforms(index, data, transforms, return_dict):
    if len(data['paired_A_path']) > 0:
        return_dict['paired_A'], return_dict['paired_B'] = transforms['paired'] \
            (data['paired_A_path'][index], data['paired_B_path'][index])
        return_dict['paired_A_path'] = data['paired_A_path'][index]
        return_dict['paired_B_path'] = data['paired_B_path'][index]

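Like the other deprecated loaders, add_paired_data tries three sources in order: an explicit filelist, explicit A/B folders, and finally phase-prefixed folders under dataroot. A small sketch of the dataroot naming convention it falls back to (the dataroot path is invented for the example):

# Illustrative only: the folder names used when neither a filelist nor
# explicit folders are configured.
import os

config = {
    'common': {'phase': 'train'},
    'dataset': {'dataroot': '/data/FFHQ',          # hypothetical path
                'paired_train_filelist': '',
                'paired_trainA_folder': '',
                'paired_trainB_folder': ''},
}
dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedA')
dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedB')
print(dir_A)  # /data/FFHQ/trainpairedA
print(dir_B)  # /data/FFHQ/trainpairedB
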
data/deprecated/patch_data.py
ADDED
@@ -0,0 +1,44 @@
import random
import torch

def load_patches(patch_batch_size, batch_size, patch_size, num_patch, diff_patch, index, data, transforms, return_dict):
    if patch_size > 0:
        assert (patch_batch_size % batch_size == 0), \
            "patch_batch_size is not divisible by batch_size."
        if 'paired_A' in return_dict or 'paired_B' in return_dict:
            if not diff_patch:
                # load patch from current image
                patchA = return_dict['paired_A'].clone()
                patchB = return_dict['paired_B'].clone()
            else:
                # load patch from a different image
                pathA = data['paired_A_path'][(index + 1) % len(data['paired_A_path'])]
                pathB = data['paired_B_path'][(index + 1) % len(data['paired_B_path'])]
                patchA, patchB = transforms['paired'](pathA, pathB)
        else:
            if not diff_patch:
                # load patch from current image
                patchA = return_dict['unpaired_A'].clone()
                patchB = return_dict['unpaired_B'].clone()
            else:
                # load patch from a different image
                pathA = data['unpaired_A_path'][(index + 1) % len(data['unpaired_A_path'])]
                pathB = data['unpaired_B_path'][(index + 1) % len(data['unpaired_B_path'])]
                patchA, patchB = transforms['unpaired'](pathA, pathB)

        # crop patch
        patchAs = []
        patchBs = []
        _, h, w = patchA.size()

        for _ in range(num_patch):
            r = random.randint(0, h - patch_size - 1)
            c = random.randint(0, w - patch_size - 1)
            patchAs.append(patchA[:, r:r + patch_size, c:c + patch_size])
            patchBs.append(patchB[:, r:r + patch_size, c:c + patch_size])

        patchAs = torch.cat(patchAs, 0)
        patchBs = torch.cat(patchBs, 0)

        return_dict['patch_A'] = patchAs
        return_dict['patch_B'] = patchBs

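Because the crops are concatenated along dimension 0 of CHW tensors, patch_A ends up with num_patch * C channels rather than a new batch dimension. A standalone sketch of that crop-and-concatenate step on a random tensor (illustrative sizes):

# Illustrative only: mirrors the cropping loop in load_patches.
import random
import torch

patchA = torch.randn(3, 256, 256)   # C, H, W
patch_size, num_patch = 64, 4
patches = []
for _ in range(num_patch):
    r = random.randint(0, patchA.size(1) - patch_size - 1)
    c = random.randint(0, patchA.size(2) - patch_size - 1)
    patches.append(patchA[:, r:r + patch_size, c:c + patch_size])
patches = torch.cat(patches, 0)
print(patches.shape)  # torch.Size([12, 64, 64]) -> num_patch * C channels
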
data/deprecated/unpaired_data.py
ADDED
@@ -0,0 +1,101 @@
import os
from utils.util import check_path_is_img
from utils.data_utils import Transforms
from utils.augmentation import ImagePathToImage
import random

def add_unpaired_data(data, transforms, config, shuffle=False):
    A_paths = []
    B_paths = []
    if config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'] != '':
        unpaired_data_file1 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'], 'r')
        Lines = unpaired_data_file1.readlines()
        if shuffle:
            random.shuffle(Lines)
        for line in Lines:
            if not config['dataset']['use_absolute_datafile']:
                file = os.path.join(config['dataset']['dataroot'], line.strip())
            else:
                file = line.strip()
            if os.path.exists(file):
                A_paths.append(file)
        unpaired_data_file1.close()

        unpaired_data_file2 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'B_filelist'], 'r')
        Lines = unpaired_data_file2.readlines()
        if shuffle:
            random.shuffle(Lines)
        for line in Lines:
            if not config['dataset']['use_absolute_datafile']:
                file = os.path.join(config['dataset']['dataroot'], line.strip())
            else:
                file = line.strip()
            if os.path.exists(file):
                B_paths.append(file)
        unpaired_data_file2.close()
    elif config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder'] != '' and \
            config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder'] != '':
        dir_A = config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder']
        filenames = os.listdir(dir_A)
        if shuffle:
            random.shuffle(filenames)
        for filename in filenames:
            if not check_path_is_img(filename):
                continue
            A_path = os.path.join(dir_A, filename)
            if os.path.exists(A_path):
                A_paths.append(A_path)

        dir_B = config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder']
        filenames = os.listdir(dir_B)
        if shuffle:
            random.shuffle(filenames)
        for filename in filenames:
            if not check_path_is_img(filename):
                continue
            B_path = os.path.join(dir_B, filename)
            if os.path.exists(B_path):
                B_paths.append(B_path)

    else:
        dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'unpairedA')
        dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'unpairedB')
        if os.path.exists(dir_A) and os.path.exists(dir_B):
            filenames = os.listdir(dir_A)
            if shuffle:
                random.shuffle(filenames)
            for filename in filenames:
                if not check_path_is_img(filename):
                    continue
                A_path = os.path.join(dir_A, filename)
                A_paths.append(A_path)
            filenames = os.listdir(dir_B)
            if shuffle:
                random.shuffle(filenames)
            for filename in filenames:
                if not check_path_is_img(filename):
                    continue
                B_path = os.path.join(dir_B, filename)
                B_paths.append(B_path)

    btoA = config['dataset']['direction'] == 'BtoA'
    input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
    output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']

    transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
    transform.get_transform_from_config()
    transform.get_transforms().insert(0, ImagePathToImage())
    transform = transform.compose_transforms()

    data['unpaired_A_path'] = A_paths
    data['unpaired_B_path'] = B_paths
    transforms['unpaired'] = transform

def apply_unpaired_transforms(index, data, transforms, return_dict):
    if len(data['unpaired_A_path']) > 0 and len(data['unpaired_B_path']) > 0:
        index_B = random.randint(0, len(data['unpaired_B_path']) - 1)
        return_dict['unpaired_A'], return_dict['unpaired_B'] = transforms['unpaired'] \
            (data['unpaired_A_path'][index], data['unpaired_B_path'][index_B])
        return_dict['unpaired_A_path'] = data['unpaired_A_path'][index]
        return_dict['unpaired_B_path'] = data['unpaired_B_path'][index_B]

data/static_data.py
ADDED
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
from utils.augmentation import ImagePathToImage, NumpyToTensor
|
5 |
+
from utils.data_utils import Transforms
|
6 |
+
from utils.util import check_path_is_static_data
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
def check_dataname_folder_correspondence(data_names, group, group_name):
|
12 |
+
for data_name in data_names:
|
13 |
+
if data_name + '_folder' not in group:
|
14 |
+
print("%s not found in config file. Going to use dataroot mode to load group %s." % (data_name + '_folder', group_name))
|
15 |
+
return False
|
16 |
+
return True
|
17 |
+
|
18 |
+
|
19 |
+
def custom_check_path_exists(str1):
|
20 |
+
return True if (str1 == "None" or os.path.exists(str1)) else False
|
21 |
+
|
22 |
+
|
23 |
+
def custom_getsize(str1):
|
24 |
+
return 1 if str1 == "None" else os.path.getsize(str1)
|
25 |
+
|
26 |
+
|
27 |
+
def check_different_extension_path_exists(str1):
|
28 |
+
acceptable_extensions = ['png', 'jpg', 'jpeg', 'npy', 'npz', 'PNG', 'JPG', 'JPEG']
|
29 |
+
curr_extension = str1.split('.')[-1]
|
30 |
+
for extension in acceptable_extensions:
|
31 |
+
str2 = str1.replace(curr_extension, extension)
|
32 |
+
if os.path.exists(str2):
|
33 |
+
return str2
|
34 |
+
return None
|
35 |
+
|
36 |
+
|
37 |
+
class StaticData(object):
|
38 |
+
|
39 |
+
def __init__(self, config, shuffle=False):
|
40 |
+
# private variables
|
41 |
+
self.file_groups = []
|
42 |
+
self.type_groups = []
|
43 |
+
self.group_names = []
|
44 |
+
self.pair_type_groups = []
|
45 |
+
self.len_of_groups = []
|
46 |
+
self.transforms = {}
|
47 |
+
# parameters
|
48 |
+
self.shuffle = shuffle
|
49 |
+
self.config = config
|
50 |
+
|
51 |
+
|
52 |
+
def load_static_data(self):
|
53 |
+
data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
|
54 |
+
print("----------------loading %s static data.---------------------" % self.config['common']['phase'])
|
55 |
+
|
56 |
+
if len(data_dict) == 0:
|
57 |
+
self.len_of_groups.append(0)
|
58 |
+
return
|
59 |
+
|
60 |
+
self.group_names = list(data_dict.keys())
|
61 |
+
for i, group in enumerate(data_dict.values()): # examples: (0, group_1), (1, group_2)
|
62 |
+
data_types = group['data_types'] # examples: 'image', 'patch'
|
63 |
+
data_names = group['data_names'] # examples: 'real_A', 'patch_A'
|
64 |
+
self.file_groups.append({})
|
65 |
+
self.type_groups.append({})
|
66 |
+
self.len_of_groups.append(0)
|
67 |
+
self.pair_type_groups.append(group['paired'])
|
68 |
+
|
69 |
+
# exclude patch data since they are not stored on disk. They will be handled later.
|
70 |
+
data_types, data_names = self.exclude_patch_data(data_types, data_names)
|
71 |
+
assert(len(data_types) == len(data_names))
|
72 |
+
|
73 |
+
if len(data_names) == 0:
|
74 |
+
continue
|
75 |
+
|
76 |
+
for data_name, data_type in zip(data_names, data_types):
|
77 |
+
self.file_groups[i][data_name] = []
|
78 |
+
self.type_groups[i][data_name] = data_type
|
79 |
+
|
80 |
+
|
81 |
+
# paired data
|
82 |
+
if group['paired']:
|
83 |
+
# First way to load data: load a file list
|
84 |
+
if 'file_list' in group:
|
85 |
+
file_list = group['file_list']
|
86 |
+
paired_file = open(file_list, 'rt')
|
87 |
+
lines = paired_file.readlines()
|
88 |
+
if self.shuffle:
|
89 |
+
random.shuffle(lines)
|
90 |
+
for line in lines:
|
91 |
+
items = line.strip().split(' ')
|
92 |
+
if len(items) == len(data_names):
|
93 |
+
ok = True
|
94 |
+
for item in items:
|
95 |
+
ok = ok and os.path.exists(item) and os.path.getsize(item) > 0
|
96 |
+
if ok:
|
97 |
+
for data_name, item in zip(data_names, items):
|
98 |
+
self.file_groups[i][data_name].append(item)
|
99 |
+
paired_file.close()
|
100 |
+
# second and third way to load data: specify one folder for each dataname, or specify a dataroot folder
|
101 |
+
elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
|
102 |
+
dataname_to_dir_dict = {}
|
103 |
+
for data_name, data_type in zip(data_names, data_types):
|
104 |
+
if 'dataroot' in group:
|
105 |
+
# In new data config format, data is stored in dataroot_name/mode/dataname. e.g. FFHQ/train/pairedA
|
106 |
+
# In old format, data is stored in dataroot_name/mode_dataname. e.g. FFHQ/train_pairedA
|
107 |
+
# So we need to check both.
|
108 |
+
dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
|
109 |
+
if not os.path.exists(dir):
|
110 |
+
old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
|
111 |
+
if 'numpy' in data_type:
|
112 |
+
old_dir += 'numpy'
|
113 |
+
if not os.path.exists(old_dir):
|
114 |
+
print("Neither %s nor %s exists. Please check." % (dir, old_dir))
|
115 |
+
sys.exit()
|
116 |
+
else:
|
117 |
+
dir = old_dir
|
118 |
+
else:
|
119 |
+
dir = group[data_name + '_folder']
|
120 |
+
if not os.path.exists(dir):
|
121 |
+
print("directory %s does not exist. Please check." % dir)
|
122 |
+
sys.exit()
|
123 |
+
dataname_to_dir_dict[data_name] = dir
|
124 |
+
|
125 |
+
filenames = os.listdir(dataname_to_dir_dict[data_names[0]])
|
126 |
+
if self.shuffle:
|
127 |
+
random.shuffle(filenames)
|
128 |
+
for filename in filenames:
|
129 |
+
if not check_path_is_static_data(filename):
|
130 |
+
continue
|
131 |
+
file_paths = []
|
132 |
+
for data_name in data_names:
|
133 |
+
file_path = os.path.join(dataname_to_dir_dict[data_name], filename)
|
134 |
+
checked_extension = check_different_extension_path_exists(file_path)
|
135 |
+
if checked_extension is not None:
|
136 |
+
file_paths.append(checked_extension)
|
137 |
+
|
138 |
+
if len(file_paths) != len(data_names):
|
139 |
+
print("For file %s, some of the paired data appears to be missing. Ignoring it and proceeding." % filename)
|
140 |
+
continue
|
141 |
+
else:
|
142 |
+
for j in range(len(data_names)):
|
143 |
+
data_name = data_names[j]
|
144 |
+
self.file_groups[i][data_name].append(file_paths[j])
|
145 |
+
else:
|
146 |
+
print("The method for loading data is incorrect or unspecified for data group %s." % self.group_names[i])
|
147 |
+
sys.exit()
|
148 |
+
|
149 |
+
self.len_of_groups[i] = len(self.file_groups[i][data_names[0]])
|
150 |
+
|
151 |
+
# unpaired data
|
152 |
+
else:
|
153 |
+
# First way to load data: load a file list
|
154 |
+
if 'file_list' in group:
|
155 |
+
file_list = group['file_list']
|
156 |
+
unpaired_file = open(file_list, 'rt')
|
157 |
+
lines = unpaired_file.readlines()
|
158 |
+
if self.shuffle:
|
159 |
+
random.shuffle(lines)
|
160 |
+
item_count = 0
|
161 |
+
for line in lines:
|
162 |
+
items = line.strip().split(' ')
|
163 |
+
if len(items) == len(data_names):
|
164 |
+
ok = True
|
165 |
+
for item in items:
|
166 |
+
ok = ok and custom_check_path_exists(item) and custom_getsize(item) > 0
|
167 |
+
if ok:
|
168 |
+
has_data = False
|
169 |
+
for data_name, item in zip(data_names, items):
|
170 |
+
if item != 'None':
|
171 |
+
self.file_groups[i][data_name].append(item)
|
172 |
+
has_data = True
|
173 |
+
if has_data:
|
174 |
+
item_count += 1
|
175 |
+
unpaired_file.close()
|
176 |
+
self.len_of_groups[i] = item_count
|
177 |
+
# second and third way to load data: specify one folder for each dataname, or specify a dataroot folder
|
178 |
+
elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
|
179 |
+
max_length = 0
|
180 |
+
for data_name, data_type in zip(data_names, data_types):
|
181 |
+
if 'dataroot' in group:
|
182 |
+
# In new data config format, data is stored in dataroot_name/mode/dataname. e.g. FFHQ/train/pairedA
|
183 |
+
# In old format, data is stored in dataroot_name/mode_dataname. e.g. FFHQ/train_pairedA
|
184 |
+
# So we need to check both.
|
185 |
+
dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
|
186 |
+
if not os.path.exists(dir):
|
187 |
+
old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
|
188 |
+
if 'numpy' in data_type:
|
189 |
+
old_dir += 'numpy'
|
190 |
+
if not os.path.exists(old_dir):
|
191 |
+
print("Neither %s nor %s exists. Please check." % (dir, old_dir))
|
192 |
+
sys.exit()
|
193 |
+
else:
|
194 |
+
dir = old_dir
|
195 |
+
else:
|
196 |
+
dir = group[data_name + '_folder']
|
197 |
+
if not os.path.exists(dir):
|
198 |
+
print("directory %s does not exist. Please check." % dir)
|
199 |
+
sys.exit()
|
200 |
+
filenames = os.listdir(dir)
|
201 |
+
if self.shuffle:
|
202 |
+
random.shuffle(filenames)
|
203 |
+
|
204 |
+
item_count = 0
|
205 |
+
for filename in filenames:
|
206 |
+
if not check_path_is_static_data(filename):
|
207 |
+
continue
|
208 |
+
fullpath = os.path.join(dir, filename)
|
209 |
+
if os.path.exists(fullpath):
|
210 |
+
self.file_groups[i][data_name].append(fullpath)
|
211 |
+
item_count += 1
|
212 |
+
max_length = max(item_count, max_length)
|
213 |
+
self.len_of_groups[i] = max_length
|
214 |
+
else:
|
215 |
+
print("The method for loading data is incorrect or unspecified for data group %s." % self.group_names[i])
|
216 |
+
sys.exit()
|
217 |
+
|
218 |
+
|
219 |
+
def create_transforms(self):
|
220 |
+
btoA = self.config['dataset']['direction'] == 'BtoA'
|
221 |
+
input_nc = self.config['model']['output_nc'] if btoA else self.config['model']['input_nc']
|
222 |
+
output_nc = self.config['model']['input_nc'] if btoA else self.config['model']['output_nc']
|
223 |
+
input_grayscale_flag = (input_nc == 1)
|
224 |
+
output_grayscale_flag = (output_nc == 1)
|
225 |
+
|
226 |
+
data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
|
227 |
+
for i, group in enumerate(data_dict.values()): # examples: (0, group_1), (1, group_2)
|
228 |
+
|
229 |
+
if i not in self.transforms:
|
230 |
+
self.transforms[i] = {}
|
231 |
+
|
232 |
+
data_types = group['data_types'] # examples: 'image', 'patch'
|
233 |
+
data_names = group['data_names'] # examples: 'real_A', 'patch_A'
|
234 |
+
data_types, data_names = self.exclude_patch_data(data_types, data_names)
|
235 |
+
for data_name, data_type in zip(data_names, data_types):
|
236 |
+
if data_type in self.transforms[i]:
|
237 |
+
continue
|
238 |
+
self.transforms[i][data_type] = Transforms(self.config, input_grayscale_flag=input_grayscale_flag,
|
239 |
+
output_grayscale_flag=output_grayscale_flag)
|
240 |
+
self.transforms[i][data_type].create_transforms_from_list(group['preprocess'])
|
241 |
+
if '.png' in self.file_groups[i][data_name][0] or '.jpg' in self.file_groups[i][data_name][0] or \
|
242 |
+
'.jpeg' in self.file_groups[i][data_name][0]:
|
243 |
+
self.transforms[i][data_type].get_transforms().insert(0, ImagePathToImage())
|
244 |
+
elif '.npy' in self.file_groups[i][data_name][0] or '.npz' in self.file_groups[i][data_name][0]:
|
245 |
+
self.transforms[i][data_type].get_transforms().insert(0, NumpyToTensor())
|
246 |
+
self.transforms[i][data_type] = self.transforms[i][data_type].compose_transforms()
|
247 |
+
|
248 |
+
|
249 |
+
def apply_transformations_to_images(self, img_list, img_dataname_list, transform, return_dict,
|
250 |
+
next_img_paths_bucket, next_img_dataname_list):
|
251 |
+
|
252 |
+
if len(img_list) == 1:
|
253 |
+
return_dict[img_dataname_list[0]], _ = transform(img_list[0], None)
|
254 |
+
elif len(img_list) > 1:
|
255 |
+
next_data_count = len(next_img_paths_bucket)
|
256 |
+
img_list += next_img_paths_bucket
|
257 |
+
img_dataname_list += next_img_dataname_list
|
258 |
+
|
259 |
+
input1, input2 = img_list[0], img_list[1:]
|
260 |
+
output1, output2 = transform(input1, input2) # output1 is one image. output2 is a list of images.
|
261 |
+
|
262 |
+
if next_data_count != 0:
|
263 |
+
output2, next_outputs = output2[:-next_data_count], output2[-next_data_count:]
|
264 |
+
for i in range(next_data_count):
|
265 |
+
return_dict[img_dataname_list[-next_data_count+i] + '_next'] = next_outputs[i]
|
266 |
+
|
267 |
+
return_dict[img_dataname_list[0]] = output1
|
268 |
+
for j in range(0, len(output2)):
|
269 |
+
return_dict[img_dataname_list[j+1]] = output2[j]
|
270 |
+
|
271 |
+
return return_dict
|
272 |
+
|
273 |
+
|
274 |
+
def calculate_landmark_scale(self, data_path, data_type, i):
|
275 |
+
if data_type == 'image':
|
276 |
+
original_image = Image.open(data_path)
|
277 |
+
original_width, original_height = original_image.size
|
278 |
+
else:
|
279 |
+
original_image = np.load(data_path)
|
280 |
+
original_height, original_width = original_image.shape[0], original_image.shape[1]
|
281 |
+
transformed_image, _ = self.transforms[i][data_type](data_path, None)
|
282 |
+
transformed_height, transformed_width = transformed_image.size()[1:]
|
283 |
+
landmark_scale = (transformed_width / original_width, transformed_height / original_height)
|
284 |
+
return landmark_scale
|
285 |
+
|
286 |
+
|
287 |
+
def get_item(self, idx):
|
288 |
+
|
289 |
+
return_dict = {}
|
290 |
+
data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
|
291 |
+
|
292 |
+
for i, group in enumerate(data_dict.values()):
|
293 |
+
if self.file_groups[i] == {}:
|
294 |
+
continue
|
295 |
+
|
296 |
+
paired_type = self.pair_type_groups[i]
|
297 |
+
inner_idx = idx if idx < self.len_of_groups[i] else random.randint(0, self.len_of_groups[i] - 1)
|
298 |
+
|
299 |
+
landmark_scale = None
|
300 |
+
|
301 |
+
# buckets for patch sources, since patches might need to be loaded from different images.
|
302 |
+
next_img_paths_bucket = []
|
303 |
+
next_img_dataname_list = []
|
304 |
+
next_numpy_paths_bucket = []
|
305 |
+
next_numpy_dataname_list = []
|
306 |
+
|
307 |
+
# First, handle all non-patch data
|
308 |
+
if paired_type:
|
309 |
+
img_paths_bucket = []
|
310 |
+
img_dataname_list = []
|
311 |
+
numpy_paths_bucket = []
|
312 |
+
numpy_dataname_list = []
|
313 |
+
|
314 |
+
for data_name, data_list in self.file_groups[i].items():
|
315 |
+
data_type = self.type_groups[i][data_name]
|
316 |
+
if data_type in ['image', 'numpy']:
|
317 |
+
if paired_type:
|
318 |
+
# augmentation is applied to all images in a paired group at once, so gather the images here.
|
319 |
+
if data_type == 'image':
|
320 |
+
img_paths_bucket.append(data_list[inner_idx])
|
321 |
+
img_dataname_list.append(data_name)
|
322 |
+
else:
|
323 |
+
numpy_paths_bucket.append(data_list[inner_idx])
|
324 |
+
numpy_dataname_list.append(data_name)
|
325 |
+
return_dict[data_name + '_path'] = data_list[inner_idx]
|
326 |
+
if landmark_scale is None:
|
327 |
+
landmark_scale = self.calculate_landmark_scale(data_list[inner_idx], data_type, i)
|
328 |
+
if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
|
329 |
+
data_name in group['patch_sources']:
|
330 |
+
next_idx = (inner_idx + 1) % (len(data_list) - 1)
|
331 |
+
if data_type == 'image':
|
332 |
+
next_img_paths_bucket.append(data_list[next_idx])
|
333 |
+
next_img_dataname_list.append(data_name)
|
334 |
+
else:
|
335 |
+
next_numpy_paths_bucket.append(data_list[next_idx])
|
336 |
+
next_numpy_dataname_list.append(data_name)
|
337 |
+
else:
|
338 |
+
unpaired_inner_idx = random.randint(0, len(data_list) - 1)
|
339 |
+
return_dict[data_name], _ = self.transforms[i][data_type](data_list[unpaired_inner_idx], None)
|
340 |
+
if landmark_scale is None:
|
341 |
+
landmark_scale = self.calculate_landmark_scale(data_list[unpaired_inner_idx], data_type, i)
|
342 |
+
if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
|
343 |
+
data_name in group['patch_sources']:
|
344 |
+
next_idx = (unpaired_inner_idx + 1) % (len(data_list) - 1)
|
345 |
+
return_dict[data_name + '_next'], _ = self.transforms[i][data_type](data_list[next_idx], None)
|
346 |
+
return_dict[data_name + '_path'] = data_list[unpaired_inner_idx]
|
347 |
+
elif self.type_groups[i][data_name] == 'landmark':
|
348 |
+
# We do not apply transformations to landmarks; we only scale them to the transformed image's size.
|
349 |
+
# Also, this numpy data is passed into the network as a numpy array, not a tensor.
|
350 |
+
lmk = np.load(data_list[inner_idx])
|
351 |
+
if self.config['dataset']['landmark_scale'] is not None:
|
352 |
+
lmk[:, 0] *= self.config['dataset']['landmark_scale'][0]
|
353 |
+
lmk[:, 1] *= self.config['dataset']['landmark_scale'][1]
|
354 |
+
else:
|
355 |
+
if landmark_scale is None:
|
356 |
+
print("landmark_scale is None. If you have not defined it in the config file, please specify "
|
357 |
+
"image and numpy data before landmark data, so that the proper scale can be calculated automatically.")
|
358 |
+
else:
|
359 |
+
lmk[:, 0] *= landmark_scale[0]
|
360 |
+
lmk[:, 1] *= landmark_scale[1]
|
361 |
+
return_dict[data_name] = lmk
|
362 |
+
return_dict[data_name + '_path'] = data_list[inner_idx]
|
363 |
+
|
364 |
+
|
365 |
+
if paired_type:
|
366 |
+
# apply augmentations to all images and numpy arrays
|
367 |
+
if 'image' in self.transforms[i]:
|
368 |
+
return_dict = self.apply_transformations_to_images(img_paths_bucket, img_dataname_list,
|
369 |
+
self.transforms[i]['image'], return_dict,
|
370 |
+
next_img_paths_bucket,
|
371 |
+
next_img_dataname_list)
|
372 |
+
if 'numpy' in self.transforms[i]:
|
373 |
+
return_dict = self.apply_transformations_to_images(numpy_paths_bucket, numpy_dataname_list,
|
374 |
+
self.transforms[i]['numpy'], return_dict,
|
375 |
+
next_numpy_paths_bucket,
|
376 |
+
next_numpy_dataname_list)
|
377 |
+
|
378 |
+
# Handle patch data
|
379 |
+
data_types = group['data_types'] # examples: 'image', 'patch'
|
380 |
+
data_names = group['data_names'] # examples: 'real_A', 'patch_A'
|
381 |
+
data_types, data_names = self.filter_patch_data(data_types, data_names)
|
382 |
+
|
383 |
+
if 'patch_sources' in group:
|
384 |
+
patch_sources = group['patch_sources']
|
385 |
+
return_dict = self.load_patches(
|
386 |
+
data_names,
|
387 |
+
self.config['dataset']['patch_batch_size'],
|
388 |
+
self.config['dataset']['batch_size'],
|
389 |
+
self.config['dataset']['patch_size'],
|
390 |
+
self.config['dataset']['patch_batch_size'] // self.config['dataset']['batch_size'],
|
391 |
+
self.config['dataset']['diff_patch'],
|
392 |
+
patch_sources,
|
393 |
+
return_dict,
|
394 |
+
)
|
395 |
+
|
396 |
+
return return_dict
|
397 |
+
|
398 |
+
|
399 |
+
def get_len(self):
|
400 |
+
if len(self.len_of_groups) == 0:
|
401 |
+
return 0
|
402 |
+
else:
|
403 |
+
return max(self.len_of_groups)
|
404 |
+
|
405 |
+
|
406 |
+
def exclude_patch_data(self, data_types, data_names):
|
407 |
+
data_types_patch_excluded = []
|
408 |
+
data_names_patch_excluded = []
|
409 |
+
for data_name, data_type in zip(data_names, data_types):
|
410 |
+
if data_type != 'patch':
|
411 |
+
data_types_patch_excluded.append(data_type)
|
412 |
+
data_names_patch_excluded.append(data_name)
|
413 |
+
return data_types_patch_excluded, data_names_patch_excluded
|
414 |
+
|
415 |
+
|
416 |
+
def filter_patch_data(self, data_types, data_names):
|
417 |
+
data_types_patch = []
|
418 |
+
data_names_patch = []
|
419 |
+
for data_name, data_type in zip(data_names, data_types):
|
420 |
+
if data_type == 'patch':
|
421 |
+
data_types_patch.append(data_type)
|
422 |
+
data_names_patch.append(data_name)
|
423 |
+
return data_types_patch, data_names_patch
|
424 |
+
|
425 |
+
|
426 |
+
def load_patches(self, data_names, patch_batch_size, batch_size, patch_size,
|
427 |
+
num_patch, diff_patch, patch_sources, return_dict):
|
428 |
+
|
429 |
+
if patch_size > 0:
|
430 |
+
assert (patch_batch_size % batch_size == 0), \
|
431 |
+
"patch_batch_size is not divisible by batch_size."
|
432 |
+
assert (len(patch_sources) == len(data_names)), \
|
433 |
+
"length of patch_sources is not the same as number of patch data specified. Please check again in config file."
|
434 |
+
|
435 |
+
rlist = [] # used for cropping patches
|
436 |
+
clist = [] # used for cropping patches
|
437 |
+
for _ in range(num_patch):
|
438 |
+
r = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
|
439 |
+
c = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
|
440 |
+
rlist.append(r)
|
441 |
+
clist.append(c)
|
442 |
+
|
443 |
+
for i in range(len(data_names)):
|
444 |
+
# load transformed image
|
445 |
+
patch = return_dict[patch_sources[i]] if not diff_patch else return_dict[patch_sources[i] + '_next']
|
446 |
+
|
447 |
+
# crop patch
|
448 |
+
patchs = []
|
449 |
+
_, h, w = patch.size()
|
450 |
+
|
451 |
+
for j in range(num_patch):
|
452 |
+
patchs.append(patch[:, rlist[j]:rlist[j] + patch_size, clist[j]:clist[j] + patch_size])
|
453 |
+
patchs = torch.cat(patchs, 0)
|
454 |
+
|
455 |
+
return_dict[data_names[i]] = patchs
|
456 |
+
|
457 |
+
return return_dict
|
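For reference, a minimal sketch (with hypothetical group, folder, and data names) of the per-group layout that StaticData.load_static_data reads from config['dataset']['<phase>_data']; a real experiment file may use keys beyond the ones shown here:

# Hypothetical sketch of the group dictionaries StaticData expects.
example_train_data = {
    'group_1': {
        'paired': True,                        # items at the same index belong together
        'data_names': ['paired_A', 'paired_B'],
        'data_types': ['image', 'image'],
        'preprocess': ['resize', 'crop'],
        # one of: 'file_list', per-name '<data_name>_folder' keys, or a 'dataroot' directory
        'paired_A_folder': '/path/to/trainA',
        'paired_B_folder': '/path/to/trainB',
    },
    'group_2': {
        'paired': False,
        'data_names': ['unpaired_B'],
        'data_types': ['image'],
        'preprocess': ['resize'],
        'dataroot': '/path/to/dataset_root',   # resolved as dataroot/<phase>/<data_name>
    },
}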
data/super_dataset.py
ADDED
@@ -0,0 +1,321 @@
1 |
+
import copy
|
2 |
+
import torch.utils.data as data
|
3 |
+
from utils.data_utils import check_img_loaded, check_numpy_loaded
|
4 |
+
|
5 |
+
from data.test_data import add_test_data, apply_test_transforms
|
6 |
+
from data.test_video_data import TestVideoData
|
7 |
+
from data.static_data import StaticData
|
8 |
+
|
9 |
+
from multiprocessing import Pool
|
10 |
+
import sys
|
11 |
+
|
12 |
+
|
13 |
+
class DataBin(object):
|
14 |
+
def __init__(self, filegroups):
|
15 |
+
self.filegroups = filegroups
|
16 |
+
|
17 |
+
|
18 |
+
class SuperDataset(data.Dataset):
|
19 |
+
def __init__(self, config, shuffle=False, check_all_data=False, DDP_device=None):
|
20 |
+
self.config = config
|
21 |
+
|
22 |
+
self.check_all_data = check_all_data
|
23 |
+
self.DDP_device = DDP_device
|
24 |
+
|
25 |
+
self.data = {} # Will be dictionary. Keys are data names, e.g. paired_A, patch_A. Values are lists containing associated data.
|
26 |
+
self.transforms = {}
|
27 |
+
|
28 |
+
if self.config['common']['phase'] == 'test':
|
29 |
+
if self.config['testing']['test_video'] is not None:
|
30 |
+
self.test_video_data = TestVideoData(self.config)
|
31 |
+
else:
|
32 |
+
add_test_data(self.data, self.transforms, self.config)
|
33 |
+
return
|
34 |
+
|
35 |
+
self.static_data = StaticData(self.config, shuffle)
|
36 |
+
|
37 |
+
|
38 |
+
def convert_old_config_to_new(self):
|
39 |
+
data_types = self.config['dataset']['data_type']
|
40 |
+
if len(data_types) == 1 and data_types[0] == 'custom':
|
41 |
+
# convert custom data configuration to new data configuration
|
42 |
+
old_dict = self.config['dataset']['custom_' + self.config['common']['phase'] + '_data']
|
43 |
+
preprocess_list = self.config['dataset']['preprocess']
|
44 |
+
new_datadict = self.config['dataset'][self.config['common']['phase'] + '_data'] = old_dict
|
45 |
+
for i, group in enumerate(new_datadict.values()): # examples: (0, group_1), (1, group_2)
|
46 |
+
group['paired'] = True
|
47 |
+
group['preprocess'] = preprocess_list
|
48 |
+
# custom data does not support patch so we skip patch logic.
|
49 |
+
else:
|
50 |
+
new_datadict = self.config['dataset'][self.config['common']['phase'] + '_data'] = {}
|
51 |
+
preprocess_list = self.config['dataset']['preprocess']
|
52 |
+
new_datadict['paired_group'] = {}
|
53 |
+
new_datadict['paired_group']['paired'] = True
|
54 |
+
new_datadict['paired_group']['data_types'] = []
|
55 |
+
new_datadict['paired_group']['data_names'] = []
|
56 |
+
new_datadict['paired_group']['preprocess'] = preprocess_list
|
57 |
+
new_datadict['unpaired_group'] = {}
|
58 |
+
new_datadict['unpaired_group']['paired'] = False
|
59 |
+
new_datadict['unpaired_group']['data_types'] = []
|
60 |
+
new_datadict['unpaired_group']['data_names'] = []
|
61 |
+
new_datadict['unpaired_group']['preprocess'] = preprocess_list
|
62 |
+
|
63 |
+
for i in range(len(self.config['dataset']['data_type'])):
|
64 |
+
data_type = self.config['dataset']['data_type'][i]
|
65 |
+
if data_type == 'paired' or data_type == 'paired_numpy':
|
66 |
+
if self.config['dataset']['paired_' + self.config['common']['phase'] + '_filelist'] != '':
|
67 |
+
new_datadict['paired_group']['file_list'] = self.config['dataset'][
|
68 |
+
'paired_' + self.config['common']['phase'] + '_filelist']
|
69 |
+
elif self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_folder'] != '' and \
|
70 |
+
self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_folder'] != '':
|
71 |
+
new_datadict['paired_group']['paired_A_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_folder']
|
72 |
+
new_datadict['paired_group']['paired_B_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_folder']
|
73 |
+
else:
|
74 |
+
new_datadict['paired_group']['dataroot'] = self.config['dataset']['dataroot']
|
75 |
+
|
76 |
+
new_datadict['paired_group']['data_names'].append('paired_A')
|
77 |
+
new_datadict['paired_group']['data_names'].append('paired_B')
|
78 |
+
if data_type == 'paired':
|
79 |
+
new_datadict['paired_group']['data_types'].append('image')
|
80 |
+
new_datadict['paired_group']['data_types'].append('image')
|
81 |
+
else:
|
82 |
+
new_datadict['paired_group']['data_types'].append('numpy')
|
83 |
+
new_datadict['paired_group']['data_types'].append('numpy')
|
84 |
+
|
85 |
+
elif data_type == 'unpaired' or data_type == 'unpaired_numpy':
|
86 |
+
if self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_filelist'] != ''\
|
87 |
+
and self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_filelist'] != '':
|
88 |
+
# combine those two filelists into one filelist
|
89 |
+
self.combine_two_filelists_into_one(
|
90 |
+
self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_filelist'],
|
91 |
+
self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_filelist']
|
92 |
+
)
|
93 |
+
new_datadict['unpaired_group']['file_list'] = './tmp_filelist.txt'
|
94 |
+
elif self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_folder'] != '' and \
|
95 |
+
self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_folder'] != '':
|
96 |
+
new_datadict['unpaired_group']['unpaired_A_folder'] = self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_folder']
|
97 |
+
new_datadict['unpaired_group']['unpaired_B_folder'] = self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_folder']
|
98 |
+
else:
|
99 |
+
new_datadict['unpaired_group']['dataroot'] = self.config['dataset']['dataroot']
|
100 |
+
|
101 |
+
new_datadict['unpaired_group']['data_names'].append('unpaired_A')
|
102 |
+
new_datadict['unpaired_group']['data_names'].append('unpaired_B')
|
103 |
+
if data_type == 'unpaired':
|
104 |
+
new_datadict['unpaired_group']['data_types'].append('image')
|
105 |
+
new_datadict['unpaired_group']['data_types'].append('image')
|
106 |
+
else:
|
107 |
+
new_datadict['unpaired_group']['data_types'].append('numpy')
|
108 |
+
new_datadict['unpaired_group']['data_types'].append('numpy')
|
109 |
+
|
110 |
+
elif data_type == 'landmark':
|
111 |
+
if self.config['dataset']['paired_' + self.config['common']['phase'] + '_filelist'] != '':
|
112 |
+
new_datadict['paired_group']['file_list'] = self.config['dataset'][
|
113 |
+
'paired_' + self.config['common']['phase'] + '_filelist']
|
114 |
+
elif 'paired_' + self.config['common']['phase'] + 'A_lmk_folder' in self.config['dataset'] and \
|
115 |
+
'paired_' + self.config['common']['phase'] + 'B_lmk_folder' in self.config['dataset'] and \
|
116 |
+
self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_lmk_folder'] != '' and \
|
117 |
+
self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_lmk_folder'] != '':
|
118 |
+
new_datadict['paired_group']['lmk_A_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_lmk_folder']
|
119 |
+
new_datadict['paired_group']['lmk_B_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_lmk_folder']
|
120 |
+
else:
|
121 |
+
new_datadict['paired_group']['dataroot'] = self.config['dataset']['dataroot']
|
122 |
+
|
123 |
+
new_datadict['paired_group']['data_names'].append('lmk_A')
|
124 |
+
new_datadict['paired_group']['data_names'].append('lmk_B')
|
125 |
+
new_datadict['paired_group']['data_types'].append('landmark')
|
126 |
+
new_datadict['paired_group']['data_types'].append('landmark')
|
127 |
+
|
128 |
+
# Handle patches. This needs to happen after all non-patch data are added first.
|
129 |
+
if 'patch' in self.config['dataset']['data_type']:
|
130 |
+
# determine if patch comes from paired or unpaired image
|
131 |
+
if 'paired_A' in new_datadict['paired_group']['data_names']:
|
132 |
+
new_datadict['paired_group']['data_types'].append('patch')
|
133 |
+
new_datadict['paired_group']['data_names'].append('patch_A')
|
134 |
+
new_datadict['paired_group']['data_types'].append('patch')
|
135 |
+
new_datadict['paired_group']['data_names'].append('patch_B')
|
136 |
+
|
137 |
+
if 'patch_sources' not in new_datadict['paired_group']:
|
138 |
+
new_datadict['paired_group']['patch_sources'] = []
|
139 |
+
new_datadict['paired_group']['patch_sources'].append('paired_A')
|
140 |
+
new_datadict['paired_group']['patch_sources'].append('paired_B')
|
141 |
+
else:
|
142 |
+
new_datadict['unpaired_group']['data_types'].append('patch')
|
143 |
+
new_datadict['unpaired_group']['data_names'].append('patch_A')
|
144 |
+
new_datadict['unpaired_group']['data_types'].append('patch')
|
145 |
+
new_datadict['unpaired_group']['data_names'].append('patch_B')
|
146 |
+
|
147 |
+
if 'patch_sources' not in new_datadict['unpaired_group']:
|
148 |
+
new_datadict['unpaired_group']['patch_sources'] = []
|
149 |
+
new_datadict['unpaired_group']['patch_sources'].append('unpaired_A')
|
150 |
+
new_datadict['unpaired_group']['patch_sources'].append('unpaired_B')
|
151 |
+
|
152 |
+
if 'diff_patch' not in self.config['dataset']:
|
153 |
+
self.config['dataset']['diff_patch'] = False
|
154 |
+
|
155 |
+
new_datadict = {key: value for key, value in new_datadict.items() if len(value['data_names']) > 0}
|
156 |
+
|
157 |
+
print('-----------------------------------------------------------------------')
|
158 |
+
print("converted %s data configuration: " % self.config['common']['phase'])
|
159 |
+
for key, value in new_datadict.items():
|
160 |
+
print(key + ': ', value)
|
161 |
+
print('-----------------------------------------------------------------------')
|
162 |
+
|
163 |
+
return self.config
|
164 |
+
|
165 |
+
|
166 |
+
def combine_two_filelists_into_one(self, filelist1, filelist2):
|
167 |
+
tmp_file = open('./tmp_filelist.txt', 'w+')
|
168 |
+
f1 = open(filelist1, 'r')
|
169 |
+
f2 = open(filelist2, 'r')
|
170 |
+
f1_lines = f1.readlines()
|
171 |
+
f2_lines = f2.readlines()
|
172 |
+
min_index = min(len(f1_lines), len(f2_lines))
|
173 |
+
for i in range(min_index):
|
174 |
+
tmp_file.write(f1_lines[i].strip() + ' ' + f2_lines[i].strip() + '\n')
|
175 |
+
if min_index == len(f1_lines):
|
176 |
+
for i in range(min_index, len(f2_lines)):
|
177 |
+
tmp_file.write('None ' + f2_lines[i].strip() + '\n')
|
178 |
+
else:
|
179 |
+
for i in range(min_index, len(f1_lines)):
|
180 |
+
tmp_file.write(f1_lines[i].strip() + ' None\n')
|
181 |
+
|
182 |
+
tmp_file.close()
|
183 |
+
f1.close()
|
184 |
+
f2.close()
|
185 |
+
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
if self.config['common']['phase'] == 'test':
|
189 |
+
if self.config['testing']['test_video'] is not None:
|
190 |
+
return self.test_video_data.get_len()
|
191 |
+
else:
|
192 |
+
if len(self.data.keys()) == 0:
|
193 |
+
return 0
|
194 |
+
else:
|
195 |
+
min_len = 999999
|
196 |
+
for k, v in self.data.items():
|
197 |
+
length = len(v)
|
198 |
+
if length < min_len:
|
199 |
+
min_len = length
|
200 |
+
return min_len
|
201 |
+
else:
|
202 |
+
return self.static_data.get_len()
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
def get_item_logic(self, index):
|
207 |
+
return_dict = {}
|
208 |
+
|
209 |
+
if self.config['common']['phase'] == 'test':
|
210 |
+
if self.config['testing']['test_video'] is not None:
|
211 |
+
return self.test_video_data.get_item()
|
212 |
+
else:
|
213 |
+
apply_test_transforms(index, self.data, self.transforms, return_dict)
|
214 |
+
return return_dict
|
215 |
+
|
216 |
+
return_dict = self.static_data.get_item(index)
|
217 |
+
|
218 |
+
return return_dict
|
219 |
+
|
220 |
+
|
221 |
+
def __getitem__(self, index):
|
222 |
+
if self.config['dataset']['accept_data_error']:
|
223 |
+
while True:
|
224 |
+
try:
|
225 |
+
return self.get_item_logic(index)
|
226 |
+
except Exception as e:
|
227 |
+
print("Exception encountered in super_dataset's getitem function: ", e)
|
228 |
+
index = (index + 1) % self.__len__()
|
229 |
+
else:
|
230 |
+
return self.get_item_logic(index)
|
231 |
+
|
232 |
+
|
233 |
+
def split_data(self, value_mode, value, mode='split'):
|
234 |
+
new_dataset = copy.deepcopy(self)
|
235 |
+
ret1, new_dataset.static_data = self.split_data_helper(self.static_data, new_dataset.static_data, value_mode, value, mode=mode)
|
236 |
+
if ret1 is not None:
|
237 |
+
self.static_data = ret1
|
238 |
+
return self, new_dataset
|
239 |
+
|
240 |
+
|
241 |
+
def split_data_helper(self, dataset, new_dataset, value_mode, value, mode='split'):
|
242 |
+
for i in range(len(dataset.file_groups)):
|
243 |
+
max_split_index = 0
|
244 |
+
for k in dataset.file_groups[i].keys():
|
245 |
+
length = len(dataset.file_groups[i][k])
|
246 |
+
if value_mode == 'count':
|
247 |
+
split_index = min(length, value)
|
248 |
+
else:
|
249 |
+
split_index = int((1 - value) * length)
|
250 |
+
max_split_index = max(max_split_index, split_index)
|
251 |
+
new_dataset.file_groups[i][k] = new_dataset.file_groups[i][k][split_index:]
|
252 |
+
if mode == 'split':
|
253 |
+
dataset.file_groups[i][k] = dataset.file_groups[i][k][:split_index]
|
254 |
+
new_dataset.len_of_groups[i] -= max_split_index
|
255 |
+
if mode == 'split':
|
256 |
+
dataset.len_of_groups[i] = max_split_index
|
257 |
+
if mode == 'split':
|
258 |
+
return dataset, new_dataset
|
259 |
+
else:
|
260 |
+
return None, new_dataset
|
261 |
+
|
262 |
+
|
263 |
+
def check_data_helper(self, databin):
|
264 |
+
all_pass = True
|
265 |
+
for group in databin.filegroups:
|
266 |
+
for data_name, data_list in group.items():
|
267 |
+
for data in data_list:
|
268 |
+
if '.npy' in data: # case: numpy array or landmark
|
269 |
+
all_pass = all_pass and check_numpy_loaded(data)
|
270 |
+
else: # case: image
|
271 |
+
all_pass = all_pass and check_img_loaded(data)
|
272 |
+
return all_pass
|
273 |
+
|
274 |
+
|
275 |
+
def check_data(self):
|
276 |
+
if self.DDP_device is None or self.DDP_device == 0:
|
277 |
+
print("-----------------------Checking all data-------------------------------")
|
278 |
+
data_ok = True
|
279 |
+
if self.config['dataset']['n_threads'] == 0:
|
280 |
+
data_ok = data_ok and self.check_data_helper(self.static_data)
|
281 |
+
else:
|
282 |
+
# start n_threads number of workers to perform data checking
|
283 |
+
with Pool(processes=self.config['dataset']['n_threads']) as pool:
|
284 |
+
checks = pool.map(self.check_data_helper,
|
285 |
+
self.split_data_into_bins(self.config['dataset']['n_threads']))
|
286 |
+
for check in checks:
|
287 |
+
data_ok = data_ok and check
|
288 |
+
if data_ok:
|
289 |
+
print("---------------------all data passed check.-----------------------")
|
290 |
+
else:
|
291 |
+
print("---------------------The above data failed the data check. "
|
292 |
+
"Please fix them first.---------------------------")
|
293 |
+
sys.exit()
|
294 |
+
|
295 |
+
|
296 |
+
def split_data_into_bins(self, num_bins):
|
297 |
+
bins = []
|
298 |
+
for i in range(num_bins):
|
299 |
+
bins.append(DataBin(filegroups=[]))
|
300 |
+
|
301 |
+
# handle static data
|
302 |
+
bins = self.split_data_into_bins_helper(bins, self.static_data)
|
303 |
+
return bins
|
304 |
+
|
305 |
+
|
306 |
+
def split_data_into_bins_helper(self, bins, dataset):
|
307 |
+
num_bins = len(bins)
|
308 |
+
for bin in bins:
|
309 |
+
for group_idx in range(len(dataset.file_groups)):
|
310 |
+
bin.filegroups.append({})
|
311 |
+
|
312 |
+
for group_idx in range(len(dataset.file_groups)):
|
313 |
+
file_group = dataset.file_groups[group_idx]
|
314 |
+
for data_name, data_list in file_group.items():
|
315 |
+
num_items_in_bin = len(data_list) // num_bins
|
316 |
+
for data_index in range(len(data_list)):
|
317 |
+
which_bin = min(data_index // num_items_in_bin, num_bins - 1)
|
318 |
+
if data_name not in bins[which_bin].filegroups[group_idx]:
|
319 |
+
bins[which_bin].filegroups[group_idx][data_name] = []
|
320 |
+
bins[which_bin].filegroups[group_idx][data_name].append(data_list[data_index])
|
321 |
+
return bins
|
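A rough usage sketch for the class above, assuming a config dict of the shape used elsewhere in this commit (only keys that SuperDataset and StaticData actually read are shown); the repo's real entry point presumably builds the loader itself, so this is illustrative only:

# Hypothetical helper: build a training DataLoader around SuperDataset.
import torch.utils.data as torch_data
from data.super_dataset import SuperDataset

def build_train_loader(config):
    dataset = SuperDataset(config, shuffle=True)
    dataset.static_data.load_static_data()    # populate file groups from the config
    dataset.static_data.create_transforms()   # build per-group transform pipelines
    return torch_data.DataLoader(
        dataset,
        batch_size=config['dataset']['batch_size'],
        shuffle=True,
        num_workers=config['dataset']['n_threads'],
        drop_last=config['dataset']['drop_last'],
    )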
data/test_data.py
ADDED
@@ -0,0 +1,51 @@
1 |
+
import os
|
2 |
+
from utils.util import check_path_is_static_data
|
3 |
+
from utils.data_utils import Transforms
|
4 |
+
from utils.augmentation import ImagePathToImage, NumpyToTensor
|
5 |
+
|
6 |
+
def add_test_data(data, transforms, config):
|
7 |
+
A_paths = []
|
8 |
+
B_paths = []
|
9 |
+
|
10 |
+
if config['testing']['test_img'] is not None:
|
11 |
+
A_paths.append(config['testing']['test_img'])
|
12 |
+
B_paths.append(config['testing']['test_img'])
|
13 |
+
else:
|
14 |
+
files = os.listdir(config['testing']['test_folder'])
|
15 |
+
for fn in files:
|
16 |
+
if not check_path_is_static_data(fn):
|
17 |
+
continue
|
18 |
+
full_path = os.path.join(config['testing']['test_folder'], fn)
|
19 |
+
A_paths.append(full_path)
|
20 |
+
B_paths.append(full_path)
|
21 |
+
|
22 |
+
btoA = config['dataset']['direction'] == 'BtoA'
|
23 |
+
# get the number of channels of input image
|
24 |
+
input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
|
25 |
+
output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']
|
26 |
+
|
27 |
+
transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
|
28 |
+
transform.create_transforms_from_list(config['testing']['preprocess'])
|
29 |
+
transform.get_transforms().insert(0, ImagePathToImage())
|
30 |
+
transform = transform.compose_transforms()
|
31 |
+
|
32 |
+
transform_np = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
|
33 |
+
transform_np.transform_list.append(NumpyToTensor())
|
34 |
+
transform_np = transform_np.compose_transforms()
|
35 |
+
|
36 |
+
data['test_A_path'] = A_paths
|
37 |
+
data['test_B_path'] = B_paths
|
38 |
+
transforms['test'] = transform
|
39 |
+
transforms['test_np'] = transform_np
|
40 |
+
|
41 |
+
def apply_test_transforms(index, data, transforms, return_dict):
|
42 |
+
if len(data['test_A_path']) > 0:
|
43 |
+
ext_name = os.path.splitext(data['test_A_path'][index])[1]
|
44 |
+
if not ext_name.lower() in ['.npy', '.npz']:
|
45 |
+
return_dict['test_A'], return_dict['test_B'] = transforms['test'] \
|
46 |
+
(data['test_A_path'][index], data['test_B_path'][index])
|
47 |
+
else:
|
48 |
+
return_dict['test_A'], return_dict['test_B'] = transforms['test_np'] \
|
49 |
+
(data['test_A_path'][index], data['test_B_path'][index])
|
50 |
+
return_dict['test_A_path'] = data['test_A_path'][index]
|
51 |
+
return_dict['test_B_path'] = data['test_B_path'][index]
|
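A minimal sketch of how these two helpers could be driven together, assuming config['testing']['test_folder'] (or test_img) is set as in the YAML files of this commit; the loop below is illustrative, not the repo's actual test driver:

# Hypothetical driver: gather test paths once, then fetch transformed items by index.
from data.test_data import add_test_data, apply_test_transforms

def iter_test_items(config):
    data, transforms = {}, {}
    add_test_data(data, transforms, config)        # fills data['test_A_path'] / data['test_B_path']
    for index in range(len(data['test_A_path'])):
        item = {}
        apply_test_transforms(index, data, transforms, item)
        yield item                                  # holds 'test_A'/'test_B' tensors and their paths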
data/test_video_data.py
ADDED
@@ -0,0 +1,28 @@
1 |
+
import imp
|
2 |
+
import cv2
|
3 |
+
from PIL import Image
|
4 |
+
from utils.data_utils import Transforms
|
5 |
+
|
6 |
+
|
7 |
+
class TestVideoData(object):
|
8 |
+
|
9 |
+
def __init__(self, config):
|
10 |
+
|
11 |
+
self.vcap = cv2.VideoCapture(config['testing']['test_video'])
|
12 |
+
self.transform = Transforms(config)
|
13 |
+
self.transform.create_transforms_from_list(config['testing']['preprocess'])
|
14 |
+
self.transform = self.transform.compose_transforms()
|
15 |
+
|
16 |
+
def __del__(self):
|
17 |
+
self.vcap.release()
|
18 |
+
|
19 |
+
def get_len(self):
|
20 |
+
return int(self.vcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
21 |
+
|
22 |
+
def get_item(self):
|
23 |
+
return_dict = {}
|
24 |
+
_, frame = self.vcap.read()
|
25 |
+
frame = Image.fromarray(frame[:,:,::-1]).convert('RGB')
|
26 |
+
return_dict['test_A'], return_dict['test_B'] = self.transform(frame, frame)
|
27 |
+
return_dict['test_A_path'], return_dict['test_B_path'] = 'A.jpg', 'B.jpg'
|
28 |
+
return return_dict
|
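A short sketch, assuming config['testing']['test_video'] points to an existing video file; frames come back in order because get_item advances the OpenCV capture on every call:

# Hypothetical usage: iterate over every frame of the test video.
from data.test_video_data import TestVideoData

def iter_video_frames(config):
    video_data = TestVideoData(config)
    for _ in range(video_data.get_len()):
        yield video_data.get_item()   # dict with per-frame 'test_A' / 'test_B' tensors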
exp/sp2pII-phase1.yaml
ADDED
@@ -0,0 +1,49 @@
1 |
+
common:
|
2 |
+
name: "sp2pII-phase1"
|
3 |
+
model: "style_based_pix2pixII"
|
4 |
+
gpu_ids: [0]
|
5 |
+
option_group:
|
6 |
+
- gpu_ids: [0]
|
7 |
+
- gpu_ids: [1]
|
8 |
+
- gpu_ids: [2]
|
9 |
+
|
10 |
+
model:
|
11 |
+
ngf: 64
|
12 |
+
|
13 |
+
dataset:
|
14 |
+
unpaired_trainA_folder: # source domain folder path(FFHQ)
|
15 |
+
unpaired_trainB_folder: # target domain folder path(AAHQ)
|
16 |
+
preprocess: ["resize"]
|
17 |
+
batch_size: 8
|
18 |
+
crop_size: 512
|
19 |
+
drop_last: true
|
20 |
+
load_size: 512
|
21 |
+
|
22 |
+
training:
|
23 |
+
epoch_as_iter: true
|
24 |
+
n_epochs: 100000
|
25 |
+
n_epochs_decay: 10
|
26 |
+
print_freq: 1000
|
27 |
+
pretrained_model: "pretrained_models/ffhq_pretrain_res512_200000.pt"
|
28 |
+
save_epoch_freq: 5000
|
29 |
+
style_mixing_prob: 0.5
|
30 |
+
lambda_GAN: 1.0
|
31 |
+
lambda_ST: 1.0
|
32 |
+
lambda_L1: 1.0
|
33 |
+
option_group:
|
34 |
+
- lambda_Feat: 4.0
|
35 |
+
- lambda_Feat: 2.0
|
36 |
+
- lambda_Feat: 1.0
|
37 |
+
lr: 0.001
|
38 |
+
lr_policy: "linear"
|
39 |
+
beta1: 0.1
|
40 |
+
|
41 |
+
testing:
|
42 |
+
num_test: 100000
|
43 |
+
preprocess: ["resize"]
|
44 |
+
load_size: 512
|
45 |
+
crop_size: 512
|
46 |
+
results_dir: "./results/sp2pII"
|
47 |
+
visual_names: ["fake_B"]
|
48 |
+
image_format: "png"
|
49 |
+
which_epoch: "latest"
|
exp/sp2pII-phase2.yaml
ADDED
@@ -0,0 +1,49 @@
1 |
+
common:
|
2 |
+
name: "sp2pII-phase2"
|
3 |
+
model: "style_based_pix2pixII"
|
4 |
+
gpu_ids: [0]
|
5 |
+
option_group:
|
6 |
+
- gpu_ids: [0]
|
7 |
+
- gpu_ids: [1]
|
8 |
+
- gpu_ids: [2]
|
9 |
+
|
10 |
+
model:
|
11 |
+
ngf: 64
|
12 |
+
|
13 |
+
dataset:
|
14 |
+
unpaired_trainA_folder: # source domain folder path(FFHQ)
|
15 |
+
unpaired_trainB_folder: # target domain folder path(AAHQ)
|
16 |
+
preprocess: ["resize"]
|
17 |
+
batch_size: 8
|
18 |
+
crop_size: 512
|
19 |
+
drop_last: true
|
20 |
+
load_size: 512
|
21 |
+
|
22 |
+
training:
|
23 |
+
epoch_as_iter: true
|
24 |
+
n_epochs: 300000
|
25 |
+
n_epochs_decay: 10
|
26 |
+
print_freq: 1000
|
27 |
+
phase: 2
|
28 |
+
pretrained_model: "pretrained_models/ffhq_pretrain_res512_200000.pt" # phase1 model
|
29 |
+
save_epoch_freq: 5000
|
30 |
+
style_mixing_prob: 0.5
|
31 |
+
lambda_GAN: 1.0
|
32 |
+
lambda_ST: 0.5 # this parameter can be tuned
|
33 |
+
option_group:
|
34 |
+
- data_aug_prob: 0.0
|
35 |
+
- data_aug_prob: 0.1
|
36 |
+
- data_aug_prob: 0.2
|
37 |
+
lr: 0.001
|
38 |
+
lr_policy: "linear"
|
39 |
+
beta1: 0.1
|
40 |
+
|
41 |
+
testing:
|
42 |
+
num_test: 100000
|
43 |
+
preprocess: ["resize"]
|
44 |
+
load_size: 512
|
45 |
+
crop_size: 512
|
46 |
+
results_dir: "./results/sp2pII"
|
47 |
+
visual_names: ["fake_B"]
|
48 |
+
image_format: "png"
|
49 |
+
which_epoch: "latest"
|
exp/sp2pII-phase3.yaml
ADDED
@@ -0,0 +1,50 @@
1 |
+
common:
|
2 |
+
name: "sp2pII-phase3"
|
3 |
+
model: "style_based_pix2pixII"
|
4 |
+
gpu_ids: [0]
|
5 |
+
option_group:
|
6 |
+
- gpu_ids: [0]
|
7 |
+
- gpu_ids: [1]
|
8 |
+
- gpu_ids: [2]
|
9 |
+
|
10 |
+
model:
|
11 |
+
ngf: 64
|
12 |
+
|
13 |
+
dataset:
|
14 |
+
unpaired_trainA_folder: # source domain folder path(FFHQ)
|
15 |
+
unpaired_trainB_folder: # target domain folder path(AAHQ)
|
16 |
+
preprocess: ["resize"]
|
17 |
+
batch_size: 8
|
18 |
+
crop_size: 512
|
19 |
+
drop_last: true
|
20 |
+
load_size: 512
|
21 |
+
|
22 |
+
training:
|
23 |
+
epoch_as_iter: true
|
24 |
+
n_epochs: 100000 # converges quickly; about 100k iters is usually enough
|
25 |
+
n_epochs_decay: 10
|
26 |
+
print_freq: 1000
|
27 |
+
phase: 3
|
28 |
+
pretrained_model: "pretrained_models/phase2_pretrain_90000.pth"
|
29 |
+
save_epoch_freq: 5000
|
30 |
+
style_mixing_prob: 0.5
|
31 |
+
lambda_GAN: 1.0
|
32 |
+
lambda_ST: 1.0
|
33 |
+
lambda_L1: 1.0
|
34 |
+
option_group:
|
35 |
+
- lambda_Feat: 4.0
|
36 |
+
- lambda_Feat: 2.0
|
37 |
+
- lambda_Feat: 1.0
|
38 |
+
lr: 0.0002
|
39 |
+
lr_policy: "linear"
|
40 |
+
beta1: 0.9
|
41 |
+
|
42 |
+
testing:
|
43 |
+
num_test: 100000
|
44 |
+
preprocess: ["resize"]
|
45 |
+
load_size: 512
|
46 |
+
crop_size: 512
|
47 |
+
results_dir: "./results/sp2pII"
|
48 |
+
visual_names: ["fake_B"]
|
49 |
+
image_format: "png"
|
50 |
+
which_epoch: "latest"
|
exp/sp2pII-phase4.yaml
ADDED
@@ -0,0 +1,49 @@
1 |
+
common:
|
2 |
+
name: sp2pII-phase4
|
3 |
+
model: style_based_pix2pixII
|
4 |
+
gpu_ids:
|
5 |
+
- 0
|
6 |
+
|
7 |
+
dataset:
|
8 |
+
batch_size: 8
|
9 |
+
crop_size: 512
|
10 |
+
drop_last: true
|
11 |
+
load_size: 512
|
12 |
+
preprocess:
|
13 |
+
- resize
|
14 |
+
unpaired_trainA_folder: "/share/group_machongyang/project/hand_drawn/data/dataset/0909/trainPA/" # source domain folder path(FFHQ)
|
15 |
+
unpaired_trainB_folder: "/share/group_machongyang/project/hand_drawn/data/dataset/0909/trainPA/" # target domain folder path(AAHQ)
|
16 |
+
model:
|
17 |
+
ngf: 64
|
18 |
+
testing:
|
19 |
+
crop_size: 512
|
20 |
+
image_format: png
|
21 |
+
load_size: 512
|
22 |
+
num_test: 100000
|
23 |
+
preprocess:
|
24 |
+
- resize
|
25 |
+
results_dir: ./results/sp2pII
|
26 |
+
visual_names:
|
27 |
+
- fake_B
|
28 |
+
which_epoch: latest
|
29 |
+
training:
|
30 |
+
beta1: 0.9
|
31 |
+
epoch_as_iter: true
|
32 |
+
lambda_Feat: 4.0
|
33 |
+
lambda_GAN: 1.0
|
34 |
+
lambda_L1: 1.0
|
35 |
+
lambda_ST: 0.5
|
36 |
+
lambda_CLIP: 1.0 # this parameter needs tuning
|
37 |
+
lambda_PROJ: 100.0 # needs tuning (only used with an image prompt)
|
38 |
+
ema: 0.99 # 1-1/n
|
39 |
+
text_prompt: "not existed"
|
40 |
+
image_prompt: "" # if this file exists, the image prompt is used; otherwise the text prompt is used
|
41 |
+
lr: 0.0002 # needs tuning, roughly between 1e-5 and 2e-4
|
42 |
+
lr_policy: linear
|
43 |
+
n_epochs: 200 # about 500 iters is usually enough
|
44 |
+
n_epochs_decay: 10
|
45 |
+
phase: 4
|
46 |
+
pretrained_model: pretrained_models/phase3_pretrain_10000.pth
|
47 |
+
print_freq: 50
|
48 |
+
save_epoch_freq: 200
|
49 |
+
style_mixing_prob: 0.5
|
logs/01_2023_09_07__18_32_26/events.out.tfevents.1694082748.aiplatform-wlf2-hi-12.idchb2az2.hb2.kwaidc.com.16044.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f0a1453d4a30c719d0a3e7b8aa8e0cecdd4d38d931606833e3fb5ce4165d171
|
3 |
+
size 38280782
|
logs/01_2023_09_12__14_54_32/events.out.tfevents.1694501684.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.76748.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8720497d0782c29971a6d1c2d5f538204f0be8cc406a1f6e2584d0450c9bd179
|
3 |
+
size 40
|
logs/01_2023_09_12__14_55_34/events.out.tfevents.1694501736.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77369.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8e810b330e416e9f922c525b985de435ba2ef25afa013b9c732372a6dac58cf8
|
3 |
+
size 31611368
|
logs/01_2023_09_12__15_03_47/events.out.tfevents.1694502229.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77940.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b91beff44afb7b30e880de6477906b5adbdd771f2514158c166093018a0ee55
|
3 |
+
size 30850690
|
models/__init__.py
ADDED
@@ -0,0 +1,68 @@
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for example usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
See our template model class 'template_model.py' for more details.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import importlib
|
22 |
+
from models.base_model import BaseModel
|
23 |
+
|
24 |
+
|
25 |
+
def find_model_using_name(model_name):
|
26 |
+
"""Import the module "models/[model_name]_model.py".
|
27 |
+
|
28 |
+
In the file, the class called [ModelName]Model will
|
29 |
+
be instantiated. It has to be a subclass of BaseModel,
|
30 |
+
and it is case-insensitive.
|
31 |
+
"""
|
32 |
+
model_filename = "models." + model_name + "_model"
|
33 |
+
modellib = importlib.import_module(model_filename)
|
34 |
+
#print(modellib)
|
35 |
+
model = None
|
36 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
37 |
+
for name, cls in modellib.__dict__.items():
|
38 |
+
if name.lower() == target_model_name.lower() \
|
39 |
+
and issubclass(cls, BaseModel):
|
40 |
+
model = cls
|
41 |
+
|
42 |
+
if model is None:
|
43 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
44 |
+
exit(0)
|
45 |
+
|
46 |
+
return model
|
47 |
+
|
48 |
+
|
49 |
+
def get_option_setter(model_name):
|
50 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
51 |
+
model_class = find_model_using_name(model_name)
|
52 |
+
return model_class.modify_commandline_options
|
53 |
+
|
54 |
+
|
55 |
+
def create_model(config, DDP_device=None):
|
56 |
+
"""Create a model given the option.
|
57 |
+
|
58 |
+
This function wraps the model class specified in the configuration.
|
59 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
60 |
+
|
61 |
+
Example:
|
62 |
+
>>> from models import create_model
|
63 |
+
>>> model = create_model(opt)
|
64 |
+
"""
|
65 |
+
model = find_model_using_name(config['common']['model'])
|
66 |
+
instance = model(config, DDP_device=DDP_device)
|
67 |
+
print("model [%s] was created" % type(instance).__name__)
|
68 |
+
return instance
|
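The module docstring above describes adding a new model named 'dummy'; a skeleton consistent with that description might look like the sketch below (the file name, placeholder bodies, and the modify_commandline_options signature are assumptions, not code from this repo):

# models/dummy_model.py -- hypothetical skeleton; bodies are placeholders.
from models.base_model import BaseModel

class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):  # signature assumed
        return parser

    def __init__(self, config, DDP_device=None):
        BaseModel.__init__(self, config, DDP_device=DDP_device)
        self.loss_names = []     # training losses to plot and save
        self.model_names = []    # networks to save and load
        self.visual_names = []   # images to display and save
        self.optimizers = []     # one optimizer per network, or chained

    def set_input(self, input):      # unpack the dataloader dict
        pass

    def forward(self):               # produce intermediate results
        pass

    def trace_jit(self, input):      # export a torchscript model for C++
        pass

    def optimize_parameters(self):   # losses, gradients, weight updates
        pass

    def eval_step(self):             # validation pass without weight updates
        pass

find_model_using_name('dummy') would then resolve this class, since 'DummyModel'.lower() matches the expected 'dummymodel'.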
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.27 kB).
|
|
models/__pycache__/base_model.cpython-38.pyc
ADDED
Binary file (12.2 kB).
|
|
models/__pycache__/style_based_pix2pixII_model.cpython-38.pyc
ADDED
Binary file (15.6 kB).
|
|
models/base_model.py
ADDED
@@ -0,0 +1,340 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from models.modules import networks
|
6 |
+
from utils.util import check_path
|
7 |
+
from utils.net_size import calc_computation
|
8 |
+
|
9 |
+
|
10 |
+
class BaseModel(ABC):
|
11 |
+
"""This class is an abstract base class (ABC) for models.
|
12 |
+
To create a subclass, you need to implement the following five functions:
|
13 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
14 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
15 |
+
-- <forward>: produce intermediate results.
|
16 |
+
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
17 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, config, DDP_device=None):
|
21 |
+
"""Initialize the BaseModel class.
|
22 |
+
|
23 |
+
Parameters:
|
24 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
25 |
+
|
26 |
+
When creating your custom class, you need to implement your own initialization.
|
27 |
+
In this function, you should first call <BaseModel.__init__(self, opt)>
|
28 |
+
Then, you need to define four lists:
|
29 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
30 |
+
-- self.model_names (str list): define networks used in our training.
|
31 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
32 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
33 |
+
"""
|
34 |
+
self.config = config
|
35 |
+
self.gpu_ids = config['common']['gpu_ids']
|
36 |
+
self.isTrain = config['common']['phase'] == 'train'
|
37 |
+
if DDP_device is None:
|
38 |
+
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
39 |
+
self.DDP_device = None
|
40 |
+
self.on_cpu = (self.device.type == 'cpu')
|
41 |
+
else:
|
42 |
+
self.device = DDP_device
|
43 |
+
self.DDP_device = DDP_device
|
44 |
+
self.on_cpu = False
|
45 |
+
self.save_dir = os.path.join(config['training']['checkpoints_dir'], config['common']['name']) # save all the checkpoints to save_dir
|
46 |
+
if config['dataset']['preprocess'] != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
47 |
+
torch.backends.cudnn.benchmark = True
|
48 |
+
self.loss_names = []
|
49 |
+
self.model_names = []
|
50 |
+
self.visual_names = []
|
51 |
+
self.optimizers = []
|
52 |
+
self.image_paths = []
|
53 |
+
self.metric = 0 # used for learning rate policy 'plateau'
|
54 |
+
self.curr_epoch = 0
|
55 |
+
self.total_iters = 0
|
56 |
+
self.best_val_loss = 999999
|
57 |
+
|
58 |
+
@abstractmethod
|
59 |
+
def set_input(self, input):
|
60 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
61 |
+
|
62 |
+
Parameters:
|
63 |
+
input (dict): includes the data itself and its metadata information.
|
64 |
+
"""
|
65 |
+
pass
|
66 |
+
|
67 |
+
@abstractmethod
|
68 |
+
def forward(self):
|
69 |
+
"""Run forward pass; called by both functions <configimize_parameters> and <test>."""
|
70 |
+
pass
|
71 |
+
|
72 |
+
@abstractmethod
|
73 |
+
def trace_jit(self, input):
|
74 |
+
"""trace torchscript model for C++. Called by <trace_jit.py>"""
|
75 |
+
pass
|
76 |
+
|
77 |
+
@abstractmethod
|
78 |
+
def optimize_parameters(self):
|
79 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
80 |
+
pass
|
81 |
+
|
82 |
+
@abstractmethod
|
83 |
+
def eval_step(self):
|
84 |
+
"""Forward and backward pass but without upgrading weights; called in every validation iteration"""
|
85 |
+
pass
|
86 |
+
|
87 |
+
def setup(self, config, DDP_device=None):
|
88 |
+
"""Load and print networks; create schedulers
|
89 |
+
|
90 |
+
Parameters:
|
91 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
92 |
+
"""
|
93 |
+
if self.isTrain:
|
94 |
+
self.schedulers = [networks.get_scheduler(optimizer, config) for optimizer in self.optimizers]
|
95 |
+
if not self.isTrain:
|
96 |
+
load_suffix = '{}'.format(config['testing']['which_epoch'])
|
97 |
+
self.load_networks(load_suffix)
|
98 |
+
elif config['training']['continue_train']:
|
99 |
+
load_suffix = '{}'.format(config['training']['which_epoch'])
|
100 |
+
self.load_networks(load_suffix)
|
101 |
+
self.print_networks(config['common']['verbose'], DDP_device=DDP_device)
|
102 |
+
|
103 |
+
def eval(self):
|
104 |
+
"""Make models eval mode during test time"""
|
105 |
+
for name in self.model_names:
|
106 |
+
if isinstance(name, str):
|
107 |
+
net = getattr(self, 'net' + name)
|
108 |
+
net.eval()
|
109 |
+
|
110 |
+
def train(self):
|
111 |
+
"""Make models train mode during train time"""
|
112 |
+
for name in self.model_names:
|
113 |
+
if isinstance(name, str):
|
114 |
+
net = getattr(self, 'net' + name)
|
115 |
+
net.train()
|
116 |
+
|
117 |
+
def test(self):
|
118 |
+
"""Forward function used in test time.
|
119 |
+
|
120 |
+
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
121 |
+
It also calls <compute_visuals> to produce additional visualization results
|
122 |
+
"""
|
123 |
+
with torch.no_grad():
|
124 |
+
self.forward()
|
125 |
+
self.compute_visuals()
|
126 |
+
|
127 |
+
def compute_visuals(self):
|
128 |
+
"""Calculate additional output images for visdom and HTML visualization"""
|
129 |
+
pass
|
130 |
+
|
131 |
+
def get_image_paths(self):
|
132 |
+
""" Return image paths that are used to load current data"""
|
133 |
+
return self.image_paths
|
134 |
+
|
135 |
+
def update_learning_rate(self):
|
136 |
+
"""Update learning rates for all the networks; called at the end of every epoch"""
|
137 |
+
for scheduler in self.schedulers:
|
138 |
+
if self.config['training']['lr_policy'] == 'plateau':
|
139 |
+
scheduler.step(self.metric, epoch=self.curr_epoch)
|
140 |
+
else:
|
141 |
+
scheduler.step(epoch=self.curr_epoch)
|
142 |
+
|
143 |
+
# lr = self.optimizers[0].param_groups[0]['lr']
|
144 |
+
# print('learning rate = %.7f' % lr)
|
145 |
+
|
146 |
+
def get_current_visuals(self):
|
147 |
+
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
148 |
+
if not self.isTrain and len(self.config['testing']['visual_names']) > 0:
|
149 |
+
visual_names = list(set(self.visual_names).intersection(set(self.config['testing']['visual_names'])))
|
150 |
+
else:
|
151 |
+
visual_names = self.visual_names
|
152 |
+
visual_ret = OrderedDict()
|
153 |
+
for name in visual_names:
|
154 |
+
if isinstance(name, str):
|
155 |
+
visual_ret[name] = getattr(self, name)
|
156 |
+
return visual_ret
|
157 |
+
|
158 |
+
def get_current_losses(self):
|
159 |
+
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
|
160 |
+
errors_ret = OrderedDict()
|
161 |
+
for name in self.loss_names:
|
162 |
+
if isinstance(name, str):
|
163 |
+
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
|
164 |
+
return errors_ret
|
165 |
+
|
166 |
+
def save_networks(self, epoch, val_loss=None):
|
167 |
+
"""Save all the networks to the disk.
|
168 |
+
|
169 |
+
Parameters:
|
170 |
+
epoch (int) -- current epoch; used in the file name 'epoch_%s.pth' % epoch (or 'best_val_epoch.pth' when val_loss is given)
|
171 |
+
"""
|
172 |
+
check_path(self.save_dir)
|
173 |
+
save_filename = 'epoch_%s.pth' % epoch if val_loss is None else 'best_val_epoch.pth'
|
174 |
+
checkpoint = {}
|
175 |
+
# save all the models
|
176 |
+
for name in self.model_names:
|
177 |
+
if isinstance(name, str):
|
178 |
+
net = getattr(self, 'net' + name)
|
179 |
+
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
180 |
+
# if using DDP, save only on rank 0; if using DataParallel, the second condition holds.
|
181 |
+
if self.DDP_device == 0 or self.DDP_device is None:
|
182 |
+
checkpoint[name+'_model'] = net.module.state_dict()
|
183 |
+
else:
|
184 |
+
checkpoint[name+'_model'] = net.state_dict()
|
185 |
+
|
186 |
+
# save all the optimizers
|
187 |
+
optimizer_index = 0
|
188 |
+
for optimizer in self.optimizers:
|
189 |
+
checkpoint['optimizer_'+str(optimizer_index)] = optimizer.state_dict()
|
190 |
+
optimizer_index += 1
|
191 |
+
|
192 |
+
# save all the schedulers
|
193 |
+
scheduler_index = 0
|
194 |
+
for scheduler in self.schedulers:
|
195 |
+
checkpoint['scheduler_' + str(scheduler_index)] = scheduler.state_dict()
|
196 |
+
scheduler_index += 1
|
197 |
+
|
198 |
+
# save other information
|
199 |
+
checkpoint['epoch'] = self.curr_epoch
|
200 |
+
checkpoint['total_iters'] = self.total_iters
|
201 |
+
checkpoint['metric'] = self.metric
|
202 |
+
if val_loss is not None:
|
203 |
+
checkpoint['best_val_loss'] = val_loss
|
204 |
+
|
205 |
+
torch.save(checkpoint, os.path.join(self.save_dir, save_filename))
|
206 |
+
|
207 |
+
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
208 |
+
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
209 |
+
key = keys[i]
|
210 |
+
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
211 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
212 |
+
(key == 'running_mean' or key == 'running_var'):
|
213 |
+
if getattr(module, key) is None:
|
214 |
+
state_dict.pop('.'.join(keys))
|
215 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
216 |
+
(key == 'num_batches_tracked'):
|
217 |
+
state_dict.pop('.'.join(keys))
|
218 |
+
else:
|
219 |
+
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
220 |
+
|
221 |
+
def load_networks(self, epoch, ckpt=None):
|
222 |
+
"""Load all the networks from the disk.
|
223 |
+
|
224 |
+
Parameters:
|
225 |
+
epoch (str) -- current epoch; used in the file name 'epoch_%s.pth' % epoch. Models in the old format
|
226 |
+
with the names '%s_net_%s.pth' % (epoch, name) are also supported. Models in the new format take priority.
|
227 |
+
"""
|
228 |
+
load_filename = 'epoch_%s.pth' % epoch
|
229 |
+
if ckpt is None:
|
230 |
+
final_load_path = os.path.join(self.save_dir, load_filename)
|
231 |
+
else:
|
232 |
+
final_load_path = ckpt
|
233 |
+
if os.path.exists(final_load_path):
|
234 |
+
# new checkpoint format.
|
235 |
+
print('loading the model in new format from %s' % final_load_path)
|
236 |
+
if self.DDP_device is not None:
|
237 |
+
# unpack the tensors on GPU 0, then transfer to whatever device it needs to be on
|
238 |
+
map_location = {'cuda:%d' % 0: 'cuda:%d' % self.DDP_device}
|
239 |
+
checkpoint = torch.load(final_load_path, map_location=map_location)
|
240 |
+
else:
|
241 |
+
checkpoint = torch.load(final_load_path)
|
242 |
+
for k, v in checkpoint.items():
|
243 |
+
# load models
|
244 |
+
if 'model' in k:
|
245 |
+
name = k.split('_model')[0]
|
246 |
+
if not self.isTrain and 'D' in name: # does not load discriminator when not training
|
247 |
+
continue
|
248 |
+
if not hasattr(self, 'net' + name):
|
249 |
+
continue
|
250 |
+
net = getattr(self, 'net' + name)
|
251 |
+
if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
|
252 |
+
net = net.module
|
253 |
+
|
254 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
255 |
+
# GitHub source), you can remove str() on self.device
|
256 |
+
if hasattr(v, '_metadata'):
|
257 |
+
del v._metadata
|
258 |
+
|
259 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
260 |
+
for key in list(v.keys()): # need to copy keys here because we mutate in loop
|
261 |
+
self.__patch_instance_norm_state_dict(v, net, key.split('.'))
|
262 |
+
net.load_state_dict(v)
|
263 |
+
# load optimizers
|
264 |
+
elif 'optimizer' in k:
|
265 |
+
if not self.isTrain:
|
266 |
+
continue
|
267 |
+
index = int(k.split('_')[-1])
|
268 |
+
self.optimizers[index].load_state_dict(v)
|
269 |
+
# load schedulers
|
270 |
+
elif 'scheduler' in k:
|
271 |
+
if not self.isTrain:
|
272 |
+
continue
|
273 |
+
index = int(k.split('_')[-1])
|
274 |
+
self.schedulers[index].load_state_dict(v)
|
275 |
+
# load other saved state
|
276 |
+
elif k == 'epoch':
|
277 |
+
self.curr_epoch = int(v) + 1
|
278 |
+
elif k == 'total_iters':
|
279 |
+
self.total_iters = int(v)
|
280 |
+
elif k == 'metric':
|
281 |
+
self.metric = float(v)
|
282 |
+
elif k == 'best_val_loss':
|
283 |
+
self.best_val_loss = float(v)
|
284 |
+
else:
|
285 |
+
print('Checkpoint load error. Unrecognized parameter saved in checkpoint: ', k)
|
286 |
+
return
|
287 |
+
|
288 |
+
# old checkpoint format.
|
289 |
+
for name in self.model_names:
|
290 |
+
if isinstance(name, str):
|
291 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
292 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
293 |
+
net = getattr(self, 'net' + name)
|
294 |
+
if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
|
295 |
+
net = net.module
|
296 |
+
print('loading the model from %s' % load_path)
|
297 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
298 |
+
# GitHub source), you can remove str() on self.device
|
299 |
+
state_dict = torch.load(load_path, map_location=str(self.device))
|
300 |
+
if hasattr(state_dict, '_metadata'):
|
301 |
+
del state_dict._metadata
|
302 |
+
|
303 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
304 |
+
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
305 |
+
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
306 |
+
net.load_state_dict(state_dict)
|
307 |
+
|
308 |
+
def print_networks(self, verbose, DDP_device=None):
|
309 |
+
"""Print the total number of parameters in the network and (if verbose) network architecture
|
310 |
+
|
311 |
+
Parameters:
|
312 |
+
verbose (bool) -- if verbose: print the network architecture
|
313 |
+
"""
|
314 |
+
if DDP_device is None or DDP_device == 0:
|
315 |
+
print('---------- Networks initialized -------------')
|
316 |
+
for name in self.model_names:
|
317 |
+
if isinstance(name, str):
|
318 |
+
net = getattr(self, 'net' + name)
|
319 |
+
num_params = 0
|
320 |
+
for param in net.parameters():
|
321 |
+
num_params += param.numel()
|
322 |
+
if verbose:
|
323 |
+
print(net)
|
324 |
+
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
325 |
+
if 'G' in name:
|
326 |
+
calc_computation(net, self.config['model']['input_nc'], self.config['dataset']['crop_size'],self.config['dataset']['crop_size'], DDP_device=DDP_device)
|
327 |
+
print('-----------------------------------------------')
|
328 |
+
|
329 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
330 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
331 |
+
Parameters:
|
332 |
+
nets (network list) -- a list of networks
|
333 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
334 |
+
"""
|
335 |
+
if not isinstance(nets, list):
|
336 |
+
nets = [nets]
|
337 |
+
for net in nets:
|
338 |
+
if net is not None:
|
339 |
+
for param in net.parameters():
|
340 |
+
param.requires_grad = requires_grad
|
models/modules/__init__.py
ADDED
File without changes
|
models/modules/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (155 Bytes). View file
|
|
models/modules/__pycache__/networks.cpython-38.pyc
ADDED
Binary file (37.8 kB). View file
|
|
models/modules/networks.py
ADDED
@@ -0,0 +1,1101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
6 |
+
from torch.nn import init
|
7 |
+
from torch.autograd import Variable
|
8 |
+
import functools
|
9 |
+
from torch.optim import lr_scheduler
|
10 |
+
from packaging import version
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
###############################################################################
|
14 |
+
# Helper Functions
|
15 |
+
###############################################################################
|
16 |
+
|
17 |
+
|
18 |
+
class Identity(nn.Module):
|
19 |
+
def forward(self, x):
|
20 |
+
return x
|
21 |
+
|
22 |
+
|
23 |
+
def get_norm_layer(norm_type='instance'):
|
24 |
+
"""Return a normalization layer
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
norm_type (str) -- the name of the normalization layer: batch | instance | none
|
28 |
+
|
29 |
+
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
|
30 |
+
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
|
31 |
+
"""
|
32 |
+
if norm_type == 'batch':
|
33 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
34 |
+
elif norm_type == 'instance':
|
35 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
36 |
+
elif norm_type == 'none':
|
37 |
+
def norm_layer(x): return Identity()
|
38 |
+
else:
|
39 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
40 |
+
return norm_layer
|
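As a hedged illustration: get_norm_layer returns a factory (a functools.partial or a small function), so callers invoke it with a channel count to build the actual layer. The channel count 64 below is arbitrary.

norm_layer = get_norm_layer('instance')
layer = norm_layer(64)  # nn.InstanceNorm2d(64, affine=False, track_running_stats=False)
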
41 |
+
|
42 |
+
|
43 |
+
def get_scheduler(optimizer, config):
|
44 |
+
"""Return a learning rate scheduler
|
45 |
+
|
46 |
+
Parameters:
|
47 |
+
optimizer -- the optimizer of the network
|
48 |
+
config (dict) -- stores all the experiment flags loaded from the experiment config file.
|
49 |
+
config['training']['lr_policy'] is the name of the learning rate policy: linear | step | plateau | cosine
|
50 |
+
|
51 |
+
For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
|
52 |
+
and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
|
53 |
+
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
|
54 |
+
See https://pytorch.org/docs/stable/optim.html for more details.
|
55 |
+
"""
|
56 |
+
if config['training']['lr_policy'] == 'linear':
|
57 |
+
def lambda_rule(epoch):
|
58 |
+
lr_l = 1.0 - max(0, epoch + 1 - config['training']['n_epochs']) / float(config['training']['n_epochs_decay'] + 1)
|
59 |
+
return lr_l
|
60 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
61 |
+
elif config['training']['lr_policy'] == 'step':
|
62 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=config['training']['lr_decay_iters'], gamma=0.1)
|
63 |
+
elif config['training']['lr_policy'] == 'plateau':
|
64 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
65 |
+
elif config['training']['lr_policy'] == 'cosine':
|
66 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['training']['n_epochs'], eta_min=0)
|
67 |
+
else:
|
68 |
+
raise NotImplementedError('learning rate policy [%s] is not implemented' % config['training']['lr_policy'])
|
69 |
+
return scheduler
|
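A small worked example of the 'linear' policy's multiplier, assuming n_epochs=100 and n_epochs_decay=100 (illustrative values, not taken from the repository's configs): the rate stays flat for the first 100 epochs and then decays linearly towards zero.

n_epochs, n_epochs_decay = 100, 100

def lr_multiplier(epoch):
    # same rule as lambda_rule above
    return 1.0 - max(0, epoch + 1 - n_epochs) / float(n_epochs_decay + 1)

print(lr_multiplier(0))    # 1.0
print(lr_multiplier(99))   # 1.0   (last epoch at the full learning rate)
print(lr_multiplier(149))  # ~0.50 (halfway through the decay phase)
print(lr_multiplier(199))  # ~0.01 (close to zero at the end)
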
70 |
+
|
71 |
+
|
72 |
+
def init_weights(net, init_type='normal', init_gain=0.02):
|
73 |
+
"""Initialize network weights.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
net (network) -- network to be initialized
|
77 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
78 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
79 |
+
|
80 |
+
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
|
81 |
+
work better for some applications. Feel free to try yourself.
|
82 |
+
"""
|
83 |
+
def init_func(m): # define the initialization function
|
84 |
+
classname = m.__class__.__name__
|
85 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
86 |
+
if init_type == 'normal':
|
87 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
88 |
+
elif init_type == 'xavier':
|
89 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
90 |
+
elif init_type == 'kaiming':
|
91 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
92 |
+
elif init_type == 'orthogonal':
|
93 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
94 |
+
else:
|
95 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
96 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
97 |
+
init.constant_(m.bias.data, 0.0)
|
98 |
+
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
99 |
+
if m.affine:
|
100 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
101 |
+
init.constant_(m.bias.data, 0.0)
|
102 |
+
|
103 |
+
if not init_type == 'none':
|
104 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
105 |
+
|
106 |
+
|
107 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], DDP_device=None, find_unused_parameters=False):
|
108 |
+
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
|
109 |
+
Parameters:
|
110 |
+
net (network) -- the network to be initialized
|
111 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
112 |
+
gain (float) -- scaling factor for normal, xavier and orthogonal.
|
113 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
114 |
+
|
115 |
+
Return an initialized network.
|
116 |
+
"""
|
117 |
+
init_weights(net, init_type, init_gain=init_gain)
|
118 |
+
if DDP_device is not None:
|
119 |
+
net.to(DDP_device)
|
120 |
+
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[DDP_device], output_device=DDP_device,
|
121 |
+
broadcast_buffers=False, find_unused_parameters=find_unused_parameters) # DDP multi-GPUs
|
122 |
+
if DDP_device == 0:
|
123 |
+
print("model initiated in DDP mode.")
|
124 |
+
elif gpu_ids is not None and len(gpu_ids) > 0:
|
125 |
+
assert(torch.cuda.is_available())
|
126 |
+
net.to(gpu_ids[0])
|
127 |
+
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
|
128 |
+
print("model initiated in dataparallel mode.")
|
129 |
+
return net
|
130 |
+
|
131 |
+
|
132 |
+
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02,
|
133 |
+
gpu_ids=[], DDP_device=None, find_unused_parameters=False):
|
134 |
+
"""Create a generator
|
135 |
+
|
136 |
+
Parameters:
|
137 |
+
input_nc (int) -- the number of channels in input images
|
138 |
+
output_nc (int) -- the number of channels in output images
|
139 |
+
ngf (int) -- the number of filters in the last conv layer
|
140 |
+
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
|
141 |
+
norm (str) -- the name of normalization layers used in the network: batch | instance | none
|
142 |
+
use_dropout (bool) -- if use dropout layers.
|
143 |
+
init_type (str) -- the name of our initialization method.
|
144 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
145 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
146 |
+
|
147 |
+
Returns a generator
|
148 |
+
|
149 |
+
Our current implementation provides two types of generators:
|
150 |
+
U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
|
151 |
+
The original U-Net paper: https://arxiv.org/abs/1505.04597
|
152 |
+
|
153 |
+
Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
|
154 |
+
Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
|
155 |
+
We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
|
156 |
+
|
157 |
+
|
158 |
+
The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
|
159 |
+
"""
|
160 |
+
net = None
|
161 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
162 |
+
|
163 |
+
if netG == 'resnet_9blocks':
|
164 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
|
165 |
+
elif netG == 'resnet_6blocks':
|
166 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
|
167 |
+
elif netG == 'unet_128':
|
168 |
+
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
169 |
+
elif netG == 'unet_256':
|
170 |
+
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
171 |
+
else:
|
172 |
+
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
173 |
+
return init_net(net, init_type, init_gain, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
|
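A hedged CPU usage sketch of define_G; the channel counts, image size and architecture name are illustrative choices, not values required by this repository.

netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                norm='instance', use_dropout=False, gpu_ids=[])
with torch.no_grad():
    fake = netG(torch.randn(1, 3, 256, 256))  # shape (1, 3, 256, 256); values in [-1, 1] from the final Tanh
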
174 |
+
|
175 |
+
def define_F(netF, netF_nc=256, channels=[], use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], DDP_device=None, find_unused_parameters=False):
|
176 |
+
if netF == 'sample':
|
177 |
+
net = PatchSampleF(use_mlp=False, nc=netF_nc)
|
178 |
+
elif netF == 'mlp_sample':
|
179 |
+
net = PatchSampleF(use_mlp=True, nc=netF_nc)
|
180 |
+
else:
|
181 |
+
raise NotImplementedError('Projection model name [%s] is not recognized' % netF)
|
182 |
+
net.create_mlp(channels)
|
183 |
+
return init_net(net, init_type, init_gain, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
|
184 |
+
|
185 |
+
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02,
|
186 |
+
gpu_ids=[], DDP_device=None, find_unused_parameters=False):
|
187 |
+
"""Create a discriminator
|
188 |
+
|
189 |
+
Parameters:
|
190 |
+
input_nc (int) -- the number of channels in input images
|
191 |
+
ndf (int) -- the number of filters in the first conv layer
|
192 |
+
netD (str) -- the architecture's name: basic | n_layers | pixel
|
193 |
+
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
|
194 |
+
norm (str) -- the type of normalization layers used in the network.
|
195 |
+
init_type (str) -- the name of the initialization method.
|
196 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
197 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
198 |
+
|
199 |
+
Returns a discriminator
|
200 |
+
|
201 |
+
Our current implementation provides three types of discriminators:
|
202 |
+
[basic]: 'PatchGAN' classifier described in the original pix2pix paper.
|
203 |
+
It can classify whether 70×70 overlapping patches are real or fake.
|
204 |
+
Such a patch-level discriminator architecture has fewer parameters
|
205 |
+
than a full-image discriminator and can work on arbitrarily-sized images
|
206 |
+
in a fully convolutional fashion.
|
207 |
+
|
208 |
+
[n_layers]: With this mode, you can specify the number of conv layers in the discriminator
|
209 |
+
with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
|
210 |
+
|
211 |
+
[pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
|
212 |
+
It encourages greater color diversity but has no effect on spatial statistics.
|
213 |
+
|
214 |
+
The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
|
215 |
+
"""
|
216 |
+
net = None
|
217 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
218 |
+
|
219 |
+
if netD == 'basic': # default PatchGAN classifier
|
220 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
|
221 |
+
elif netD == 'n_layers': # more options
|
222 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
|
223 |
+
elif netD == 'pixel': # classify if each pixel is real or fake
|
224 |
+
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
|
225 |
+
else:
|
226 |
+
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
|
227 |
+
return init_net(net, init_type, init_gain, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
|
228 |
+
|
229 |
+
def define_G_pix2pixHD(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
|
230 |
+
n_blocks_local=3, norm='instance', gpu_ids=[], DDP_device=None, find_unused_parameters=False):
|
231 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
232 |
+
if netG == 'global':
|
233 |
+
netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
|
234 |
+
elif netG == 'local':
|
235 |
+
netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
|
236 |
+
n_local_enhancers, n_blocks_local, norm_layer)
|
237 |
+
else:
|
238 |
+
raise NotImplementedError('generator not implemented!')
|
239 |
+
return init_net(netG, 'normal', 0.02, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
|
240 |
+
|
241 |
+
def define_D_pix2pixHD(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False,
|
242 |
+
gpu_ids=[], DDP_device=None, find_unused_parameters=False):
|
243 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
244 |
+
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
|
245 |
+
return init_net(netD, 'normal', 0.02, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
|
246 |
+
|
247 |
+
|
248 |
+
class Normalize(nn.Module):
|
249 |
+
|
250 |
+
def __init__(self, power=2):
|
251 |
+
super(Normalize, self).__init__()
|
252 |
+
self.power = power
|
253 |
+
|
254 |
+
def forward(self, x):
|
255 |
+
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
|
256 |
+
out = x.div(norm + 1e-7)
|
257 |
+
return out
|
258 |
+
|
259 |
+
|
260 |
+
class PatchSampleF(nn.Module):
|
261 |
+
def __init__(self, use_mlp=False, nc=256):
|
262 |
+
# potential issues: currently, we use the same patch_ids for multiple images in the batch
|
263 |
+
super(PatchSampleF, self).__init__()
|
264 |
+
self.l2norm = Normalize(2)
|
265 |
+
self.use_mlp = use_mlp
|
266 |
+
self.nc = nc
|
267 |
+
|
268 |
+
def create_mlp(self, channels):
|
269 |
+
if not self.use_mlp:
|
270 |
+
return
|
271 |
+
for mlp_id, ch in enumerate(channels):
|
272 |
+
mlp = nn.Sequential(*[nn.Linear(ch, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
|
273 |
+
setattr(self, 'mlp_%d' % mlp_id, mlp)
|
274 |
+
|
275 |
+
def forward(self, feats, num_patches=64, patch_ids=None):
|
276 |
+
return_ids = []
|
277 |
+
return_feats = []
|
278 |
+
for feat_id, feat in enumerate(feats):
|
279 |
+
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
|
280 |
+
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
|
281 |
+
if num_patches > 0:
|
282 |
+
if patch_ids is not None:
|
283 |
+
patch_id = patch_ids[feat_id]
|
284 |
+
else:
|
285 |
+
patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
|
286 |
+
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]
|
287 |
+
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)
|
288 |
+
else:
|
289 |
+
x_sample = feat_reshape
|
290 |
+
patch_id = []
|
291 |
+
if self.use_mlp:
|
292 |
+
mlp = getattr(self, 'mlp_%d' % feat_id)
|
293 |
+
x_sample = mlp(x_sample)
|
294 |
+
return_ids.append(patch_id)
|
295 |
+
x_sample = self.l2norm(x_sample)
|
296 |
+
|
297 |
+
if num_patches == 0:
|
298 |
+
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
|
299 |
+
return_feats.append(x_sample)
|
300 |
+
return return_feats, return_ids
|
301 |
+
|
302 |
+
|
303 |
+
class LocalEnhancer(nn.Module):
|
304 |
+
def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
|
305 |
+
n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
|
306 |
+
super(LocalEnhancer, self).__init__()
|
307 |
+
self.n_local_enhancers = n_local_enhancers
|
308 |
+
|
309 |
+
###### global generator model #####
|
310 |
+
ngf_global = ngf * (2**n_local_enhancers)
|
311 |
+
model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
|
312 |
+
model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
|
313 |
+
self.model = nn.Sequential(*model_global)
|
314 |
+
|
315 |
+
###### local enhancer layers #####
|
316 |
+
for n in range(1, n_local_enhancers+1):
|
317 |
+
### downsample
|
318 |
+
ngf_global = ngf * (2**(n_local_enhancers-n))
|
319 |
+
model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
|
320 |
+
norm_layer(ngf_global), nn.ReLU(True),
|
321 |
+
nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
|
322 |
+
norm_layer(ngf_global * 2), nn.ReLU(True)]
|
323 |
+
### residual blocks
|
324 |
+
model_upsample = []
|
325 |
+
for i in range(n_blocks_local):
|
326 |
+
model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer, use_dropout=False, use_bias=True)]
|
327 |
+
|
328 |
+
### upsample
|
329 |
+
model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
|
330 |
+
norm_layer(ngf_global), nn.ReLU(True)]
|
331 |
+
|
332 |
+
### final convolution
|
333 |
+
if n == n_local_enhancers:
|
334 |
+
model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
335 |
+
|
336 |
+
setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
|
337 |
+
setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
|
338 |
+
|
339 |
+
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
340 |
+
|
341 |
+
def forward(self, input):
|
342 |
+
### create input pyramid
|
343 |
+
input_downsampled = [input]
|
344 |
+
for i in range(self.n_local_enhancers):
|
345 |
+
input_downsampled.append(self.downsample(input_downsampled[-1]))
|
346 |
+
|
347 |
+
### output at coarsest level
|
348 |
+
output_prev = self.model(input_downsampled[-1])
|
349 |
+
### build up one layer at a time
|
350 |
+
for n_local_enhancers in range(1, self.n_local_enhancers+1):
|
351 |
+
model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
|
352 |
+
model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
|
353 |
+
input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
|
354 |
+
output_prev = model_upsample(model_downsample(input_i) + output_prev)
|
355 |
+
return output_prev
|
356 |
+
|
357 |
+
class GlobalGenerator(nn.Module):
|
358 |
+
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
|
359 |
+
padding_type='reflect'):
|
360 |
+
assert(n_blocks >= 0)
|
361 |
+
super(GlobalGenerator, self).__init__()
|
362 |
+
activation = nn.ReLU(True)
|
363 |
+
|
364 |
+
model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
365 |
+
### downsample
|
366 |
+
for i in range(n_downsampling):
|
367 |
+
mult = 2**i
|
368 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
369 |
+
norm_layer(ngf * mult * 2), activation]
|
370 |
+
|
371 |
+
### resnet blocks
|
372 |
+
mult = 2**n_downsampling
|
373 |
+
for i in range(n_blocks):
|
374 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=False, use_bias=True)]
|
375 |
+
|
376 |
+
### upsample
|
377 |
+
for i in range(n_downsampling):
|
378 |
+
mult = 2**(n_downsampling - i)
|
379 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
|
380 |
+
norm_layer(int(ngf * mult / 2)), activation]
|
381 |
+
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
382 |
+
self.model = nn.Sequential(*model)
|
383 |
+
|
384 |
+
def forward(self, input):
|
385 |
+
return self.model(input)
|
386 |
+
|
387 |
+
class MultiscaleDiscriminator(nn.Module):
|
388 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
|
389 |
+
use_sigmoid=False, num_D=3, getIntermFeat=False):
|
390 |
+
super(MultiscaleDiscriminator, self).__init__()
|
391 |
+
self.num_D = num_D
|
392 |
+
self.n_layers = n_layers
|
393 |
+
self.getIntermFeat = getIntermFeat
|
394 |
+
|
395 |
+
for i in range(num_D):
|
396 |
+
netD = Pix2PixHDNLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
|
397 |
+
if getIntermFeat:
|
398 |
+
for j in range(n_layers+2):
|
399 |
+
setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
|
400 |
+
else:
|
401 |
+
setattr(self, 'layer'+str(i), netD.model)
|
402 |
+
|
403 |
+
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
404 |
+
|
405 |
+
def singleD_forward(self, model, input):
|
406 |
+
if self.getIntermFeat:
|
407 |
+
result = [input]
|
408 |
+
for i in range(len(model)):
|
409 |
+
result.append(model[i](result[-1]))
|
410 |
+
return result[1:]
|
411 |
+
else:
|
412 |
+
return [model(input)]
|
413 |
+
|
414 |
+
def forward(self, input):
|
415 |
+
num_D = self.num_D
|
416 |
+
result = []
|
417 |
+
input_downsampled = input
|
418 |
+
for i in range(num_D):
|
419 |
+
if self.getIntermFeat:
|
420 |
+
model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
|
421 |
+
else:
|
422 |
+
model = getattr(self, 'layer'+str(num_D-1-i))
|
423 |
+
result.append(self.singleD_forward(model, input_downsampled))
|
424 |
+
if i != (num_D-1):
|
425 |
+
input_downsampled = self.downsample(input_downsampled)
|
426 |
+
return result
|
427 |
+
|
428 |
+
class Pix2PixHDNLayerDiscriminator(nn.Module):
|
429 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
|
430 |
+
super(Pix2PixHDNLayerDiscriminator, self).__init__()
|
431 |
+
self.getIntermFeat = getIntermFeat
|
432 |
+
self.n_layers = n_layers
|
433 |
+
|
434 |
+
kw = 4
|
435 |
+
padw = int(np.ceil((kw-1.0)/2))
|
436 |
+
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
|
437 |
+
|
438 |
+
nf = ndf
|
439 |
+
for n in range(1, n_layers):
|
440 |
+
nf_prev = nf
|
441 |
+
nf = min(nf * 2, 512)
|
442 |
+
sequence += [[
|
443 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
|
444 |
+
norm_layer(nf), nn.LeakyReLU(0.2, True)
|
445 |
+
]]
|
446 |
+
|
447 |
+
nf_prev = nf
|
448 |
+
nf = min(nf * 2, 512)
|
449 |
+
sequence += [[
|
450 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
451 |
+
norm_layer(nf),
|
452 |
+
nn.LeakyReLU(0.2, True)
|
453 |
+
]]
|
454 |
+
|
455 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
456 |
+
|
457 |
+
if use_sigmoid:
|
458 |
+
sequence += [[nn.Sigmoid()]]
|
459 |
+
|
460 |
+
if getIntermFeat:
|
461 |
+
for n in range(len(sequence)):
|
462 |
+
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
|
463 |
+
else:
|
464 |
+
sequence_stream = []
|
465 |
+
for n in range(len(sequence)):
|
466 |
+
sequence_stream += sequence[n]
|
467 |
+
self.model = nn.Sequential(*sequence_stream)
|
468 |
+
|
469 |
+
def forward(self, input):
|
470 |
+
if self.getIntermFeat:
|
471 |
+
res = [input]
|
472 |
+
for n in range(self.n_layers+2):
|
473 |
+
model = getattr(self, 'model'+str(n))
|
474 |
+
res.append(model(res[-1]))
|
475 |
+
return res[1:]
|
476 |
+
else:
|
477 |
+
return self.model(input)
|
478 |
+
|
479 |
+
|
480 |
+
class MultiGANLoss(nn.Module):
|
481 |
+
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
|
482 |
+
tensor=torch.cuda.FloatTensor):
|
483 |
+
super(MultiGANLoss, self).__init__()
|
484 |
+
self.real_label = target_real_label
|
485 |
+
self.fake_label = target_fake_label
|
486 |
+
self.real_label_var = None
|
487 |
+
self.fake_label_var = None
|
488 |
+
self.Tensor = tensor
|
489 |
+
if use_lsgan:
|
490 |
+
self.loss = nn.MSELoss()
|
491 |
+
else:
|
492 |
+
self.loss = nn.BCELoss()
|
493 |
+
|
494 |
+
def get_target_tensor(self, input, target_is_real):
|
495 |
+
target_tensor = None
|
496 |
+
if target_is_real:
|
497 |
+
create_label = ((self.real_label_var is None) or
|
498 |
+
(self.real_label_var.numel() != input.numel()))
|
499 |
+
if create_label:
|
500 |
+
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
|
501 |
+
self.real_label_var = Variable(real_tensor, requires_grad=False)
|
502 |
+
target_tensor = self.real_label_var
|
503 |
+
else:
|
504 |
+
create_label = ((self.fake_label_var is None) or
|
505 |
+
(self.fake_label_var.numel() != input.numel()))
|
506 |
+
if create_label:
|
507 |
+
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
|
508 |
+
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
|
509 |
+
target_tensor = self.fake_label_var
|
510 |
+
return target_tensor
|
511 |
+
|
512 |
+
def __call__(self, input, target_is_real):
|
513 |
+
if isinstance(input[0], list):
|
514 |
+
loss = 0
|
515 |
+
for input_i in input:
|
516 |
+
pred = input_i[-1]
|
517 |
+
target_tensor = self.get_target_tensor(pred, target_is_real)
|
518 |
+
loss += self.loss(pred, target_tensor)
|
519 |
+
return loss
|
520 |
+
else:
|
521 |
+
target_tensor = self.get_target_tensor(input[-1], target_is_real)
|
522 |
+
return self.loss(input[-1], target_tensor)
|
523 |
+
|
524 |
+
##############################################################################
|
525 |
+
# Classes
|
526 |
+
##############################################################################
|
527 |
+
class GANLoss(nn.Module):
|
528 |
+
"""Define different GAN objectives.
|
529 |
+
|
530 |
+
The GANLoss class abstracts away the need to create the target label tensor
|
531 |
+
that has the same size as the input.
|
532 |
+
"""
|
533 |
+
|
534 |
+
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
|
535 |
+
""" Initialize the GANLoss class.
|
536 |
+
|
537 |
+
Parameters:
|
538 |
+
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
|
539 |
+
target_real_label (bool) - - label for a real image
|
540 |
+
target_fake_label (bool) - - label of a fake image
|
541 |
+
|
542 |
+
Note: Do not use sigmoid as the last layer of Discriminator.
|
543 |
+
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
|
544 |
+
"""
|
545 |
+
super(GANLoss, self).__init__()
|
546 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
547 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
548 |
+
self.gan_mode = gan_mode
|
549 |
+
if gan_mode == 'lsgan':
|
550 |
+
self.loss = nn.MSELoss()
|
551 |
+
elif gan_mode == 'vanilla':
|
552 |
+
self.loss = nn.BCEWithLogitsLoss()
|
553 |
+
elif gan_mode in ['wgangp']:
|
554 |
+
self.loss = None
|
555 |
+
else:
|
556 |
+
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
|
557 |
+
|
558 |
+
def get_target_tensor(self, prediction, target_is_real):
|
559 |
+
"""Create label tensors with the same size as the input.
|
560 |
+
|
561 |
+
Parameters:
|
562 |
+
prediction (tensor) - - typically the prediction from a discriminator
|
563 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
564 |
+
|
565 |
+
Returns:
|
566 |
+
A label tensor filled with ground truth label, and with the size of the input
|
567 |
+
"""
|
568 |
+
|
569 |
+
if target_is_real:
|
570 |
+
target_tensor = self.real_label
|
571 |
+
else:
|
572 |
+
target_tensor = self.fake_label
|
573 |
+
return target_tensor.expand_as(prediction)
|
574 |
+
|
575 |
+
def __call__(self, prediction, target_is_real):
|
576 |
+
"""Calculate loss given Discriminator's output and grount truth labels.
|
577 |
+
|
578 |
+
Parameters:
|
579 |
+
prediction (tensor) - - typically the prediction output from a discriminator
|
580 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
581 |
+
|
582 |
+
Returns:
|
583 |
+
the calculated loss.
|
584 |
+
"""
|
585 |
+
if self.gan_mode in ['lsgan', 'vanilla']:
|
586 |
+
target_tensor = self.get_target_tensor(prediction, target_is_real)
|
587 |
+
loss = self.loss(prediction, target_tensor)
|
588 |
+
elif self.gan_mode == 'wgangp':
|
589 |
+
if target_is_real:
|
590 |
+
loss = nn.functional.softplus(-prediction).mean()
|
591 |
+
else:
|
592 |
+
loss = nn.functional.softplus(prediction).mean()
|
593 |
+
return loss
|
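A brief, hedged example of GANLoss in an LSGAN-style discriminator update; pred_real and pred_fake stand in for discriminator outputs on real and generated batches.

criterionGAN = GANLoss('lsgan')
pred_real = torch.randn(4, 1, 30, 30)   # placeholder for netD(real)
pred_fake = torch.randn(4, 1, 30, 30)   # placeholder for netD(fake.detach())
loss_D = 0.5 * (criterionGAN(pred_real, True) + criterionGAN(pred_fake, False))
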
594 |
+
|
595 |
+
|
596 |
+
class PatchNCELoss(nn.Module):
|
597 |
+
def __init__(self, batch_size, nce_T):
|
598 |
+
super().__init__()
|
599 |
+
self.batch_size = batch_size
|
600 |
+
self.nce_T = nce_T
|
601 |
+
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
|
602 |
+
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
|
603 |
+
|
604 |
+
def forward(self, feat_q, feat_k):
|
605 |
+
batchSize = feat_q.shape[0]
|
606 |
+
dim = feat_q.shape[1]
|
607 |
+
feat_k = feat_k.detach()
|
608 |
+
|
609 |
+
# pos logit
|
610 |
+
l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
|
611 |
+
l_pos = l_pos.view(batchSize, 1)
|
612 |
+
|
613 |
+
# neg logit
|
614 |
+
batch_dim_for_bmm = self.batch_size
|
615 |
+
|
616 |
+
# reshape features to batch size
|
617 |
+
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
|
618 |
+
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
|
619 |
+
npatches = feat_q.size(1)
|
620 |
+
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
|
621 |
+
|
622 |
+
# diagonal entries are similarity between same features, and hence meaningless.
|
623 |
+
# just fill the diagonal with very small number, which is exp(-10) and almost zero
|
624 |
+
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
|
625 |
+
l_neg_curbatch.masked_fill_(diagonal, -10.0)
|
626 |
+
l_neg = l_neg_curbatch.view(-1, npatches)
|
627 |
+
|
628 |
+
out = torch.cat((l_pos, l_neg), dim=1) / self.nce_T
|
629 |
+
|
630 |
+
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
|
631 |
+
device=feat_q.device))
|
632 |
+
|
633 |
+
return loss
|
634 |
+
|
635 |
+
|
636 |
+
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
|
637 |
+
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
|
638 |
+
|
639 |
+
Arguments:
|
640 |
+
netD (network) -- discriminator network
|
641 |
+
real_data (tensor array) -- real images
|
642 |
+
fake_data (tensor array) -- generated images from the generator
|
643 |
+
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
|
644 |
+
type (str) -- if we mix real and fake data or not [real | fake | mixed].
|
645 |
+
constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
|
646 |
+
lambda_gp (float) -- weight for this loss
|
647 |
+
|
648 |
+
Returns the gradient penalty loss
|
649 |
+
"""
|
650 |
+
if lambda_gp > 0.0:
|
651 |
+
if type == 'real': # either use real images, fake images, or a linear interpolation of two.
|
652 |
+
interpolatesv = real_data
|
653 |
+
elif type == 'fake':
|
654 |
+
interpolatesv = fake_data
|
655 |
+
elif type == 'mixed':
|
656 |
+
alpha = torch.rand(real_data.shape[0], 1, device=device)
|
657 |
+
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
|
658 |
+
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
|
659 |
+
else:
|
660 |
+
raise NotImplementedError('{} not implemented'.format(type))
|
661 |
+
interpolatesv.requires_grad_(True)
|
662 |
+
disc_interpolates = netD(interpolatesv)
|
663 |
+
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
|
664 |
+
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
|
665 |
+
create_graph=True, retain_graph=True, only_inputs=True)
|
666 |
+
gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
|
667 |
+
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
|
668 |
+
return gradient_penalty, gradients
|
669 |
+
else:
|
670 |
+
return 0.0, None
|
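A hedged sketch of where cal_gradient_penalty typically enters a WGAN-GP style discriminator step; netD, real, fake and device are placeholders built elsewhere, and criterionGAN is assumed to be GANLoss('wgangp') from above.

gp, _ = cal_gradient_penalty(netD, real, fake.detach(), device, type='mixed', lambda_gp=10.0)
loss_D = criterionGAN(netD(real), True) + criterionGAN(netD(fake.detach()), False) + gp
loss_D.backward()
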
671 |
+
|
672 |
+
|
673 |
+
class ResnetGenerator(nn.Module):
|
674 |
+
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
675 |
+
|
676 |
+
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
677 |
+
"""
|
678 |
+
|
679 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
|
680 |
+
"""Construct a Resnet-based generator
|
681 |
+
|
682 |
+
Parameters:
|
683 |
+
input_nc (int) -- the number of channels in input images
|
684 |
+
output_nc (int) -- the number of channels in output images
|
685 |
+
ngf (int) -- the number of filters in the last conv layer
|
686 |
+
norm_layer -- normalization layer
|
687 |
+
use_dropout (bool) -- if use dropout layers
|
688 |
+
n_blocks (int) -- the number of ResNet blocks
|
689 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
690 |
+
"""
|
691 |
+
assert(n_blocks >= 0)
|
692 |
+
super(ResnetGenerator, self).__init__()
|
693 |
+
if type(norm_layer) == functools.partial:
|
694 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
695 |
+
else:
|
696 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
697 |
+
|
698 |
+
model = [nn.ReflectionPad2d(3),
|
699 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
700 |
+
norm_layer(ngf),
|
701 |
+
nn.ReLU(True)]
|
702 |
+
|
703 |
+
n_downsampling = 2
|
704 |
+
for i in range(n_downsampling): # add downsampling layers
|
705 |
+
mult = 2 ** i
|
706 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
707 |
+
norm_layer(ngf * mult * 2),
|
708 |
+
nn.ReLU(True)]
|
709 |
+
|
710 |
+
mult = 2 ** n_downsampling
|
711 |
+
for i in range(n_blocks): # add ResNet blocks
|
712 |
+
|
713 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
714 |
+
|
715 |
+
for i in range(n_downsampling): # add upsampling layers
|
716 |
+
mult = 2 ** (n_downsampling - i)
|
717 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
718 |
+
kernel_size=3, stride=2,
|
719 |
+
padding=1, output_padding=1,
|
720 |
+
bias=use_bias),
|
721 |
+
norm_layer(int(ngf * mult / 2)),
|
722 |
+
nn.ReLU(True)]
|
723 |
+
model += [nn.ReflectionPad2d(3)]
|
724 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
725 |
+
model += [nn.Tanh()]
|
726 |
+
|
727 |
+
self.model = nn.Sequential(*model)
|
728 |
+
|
729 |
+
def forward(self, input, layers=[]):
|
730 |
+
if len(layers) > 0:
|
731 |
+
feat = input
|
732 |
+
feats = []
|
733 |
+
for layer_id, layer in enumerate(self.model):
|
734 |
+
feat = layer(feat)
|
735 |
+
if layer_id in layers:
|
736 |
+
feats.append(feat)
|
737 |
+
if layer_id == layers[-1]:
|
738 |
+
break
|
739 |
+
return feats
|
740 |
+
else:
|
741 |
+
"""Standard forward"""
|
742 |
+
return self.model(input)
|
743 |
+
|
744 |
+
|
745 |
+
class ResnetBlock(nn.Module):
|
746 |
+
"""Define a Resnet block"""
|
747 |
+
|
748 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
749 |
+
"""Initialize the Resnet block
|
750 |
+
|
751 |
+
A resnet block is a conv block with skip connections
|
752 |
+
We construct a conv block with build_conv_block function,
|
753 |
+
and implement skip connections in <forward> function.
|
754 |
+
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
755 |
+
"""
|
756 |
+
super(ResnetBlock, self).__init__()
|
757 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
758 |
+
|
759 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
760 |
+
"""Construct a convolutional block.
|
761 |
+
|
762 |
+
Parameters:
|
763 |
+
dim (int) -- the number of channels in the conv layer.
|
764 |
+
padding_type (str) -- the name of padding layer: reflect | replicate | zero
|
765 |
+
norm_layer -- normalization layer
|
766 |
+
use_dropout (bool) -- if use dropout layers.
|
767 |
+
use_bias (bool) -- if the conv layer uses bias or not
|
768 |
+
|
769 |
+
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
|
770 |
+
"""
|
771 |
+
conv_block = []
|
772 |
+
p = 0
|
773 |
+
if padding_type == 'reflect':
|
774 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
775 |
+
elif padding_type == 'replicate':
|
776 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
777 |
+
elif padding_type == 'zero':
|
778 |
+
p = 1
|
779 |
+
else:
|
780 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
781 |
+
|
782 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
|
783 |
+
if use_dropout:
|
784 |
+
conv_block += [nn.Dropout(0.5)]
|
785 |
+
|
786 |
+
p = 0
|
787 |
+
if padding_type == 'reflect':
|
788 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
789 |
+
elif padding_type == 'replicate':
|
790 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
791 |
+
elif padding_type == 'zero':
|
792 |
+
p = 1
|
793 |
+
else:
|
794 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
795 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
|
796 |
+
|
797 |
+
return nn.Sequential(*conv_block)
|
798 |
+
|
799 |
+
def forward(self, x):
|
800 |
+
"""Forward function (with skip connections)"""
|
801 |
+
out = x + self.conv_block(x) # add skip connections
|
802 |
+
return out
|
803 |
+
|
804 |
+
|
805 |
+
class UnetGenerator(nn.Module):
|
806 |
+
"""Create a Unet-based generator"""
|
807 |
+
|
808 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
809 |
+
"""Construct a Unet generator
|
810 |
+
Parameters:
|
811 |
+
input_nc (int) -- the number of channels in input images
|
812 |
+
output_nc (int) -- the number of channels in output images
|
813 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
|
814 |
+
an image of size 128x128 will become of size 1x1 at the bottleneck
|
815 |
+
ngf (int) -- the number of filters in the last conv layer
|
816 |
+
norm_layer -- normalization layer
|
817 |
+
|
818 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
819 |
+
It is a recursive process.
|
820 |
+
"""
|
821 |
+
super(UnetGenerator, self).__init__()
|
822 |
+
# construct unet structure
|
823 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
|
824 |
+
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
|
825 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
826 |
+
# gradually reduce the number of filters from ngf * 8 to ngf
|
827 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
828 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
829 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
830 |
+
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
|
831 |
+
|
832 |
+
def forward(self, input):
|
833 |
+
"""Standard forward"""
|
834 |
+
return self.model(input)
|
835 |
+
|
836 |
+
|
837 |
+
class UnetSkipConnectionBlock(nn.Module):
|
838 |
+
"""Defines the Unet submodule with skip connection.
|
839 |
+
X -------------------identity----------------------
|
840 |
+
|-- downsampling -- |submodule| -- upsampling --|
|
841 |
+
"""
|
842 |
+
|
843 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
844 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
845 |
+
"""Construct a Unet submodule with skip connections.
|
846 |
+
|
847 |
+
Parameters:
|
848 |
+
outer_nc (int) -- the number of filters in the outer conv layer
|
849 |
+
inner_nc (int) -- the number of filters in the inner conv layer
|
850 |
+
input_nc (int) -- the number of channels in input images/features
|
851 |
+
submodule (UnetSkipConnectionBlock) -- previously defined submodules
|
852 |
+
outermost (bool) -- if this module is the outermost module
|
853 |
+
innermost (bool) -- if this module is the innermost module
|
854 |
+
norm_layer -- normalization layer
|
855 |
+
use_dropout (bool) -- if use dropout layers.
|
856 |
+
"""
|
857 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
858 |
+
self.outermost = outermost
|
859 |
+
if type(norm_layer) == functools.partial:
|
860 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
861 |
+
else:
|
862 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
863 |
+
if input_nc is None:
|
864 |
+
input_nc = outer_nc
|
865 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
866 |
+
stride=2, padding=1, bias=use_bias)
|
867 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
868 |
+
downnorm = norm_layer(inner_nc)
|
869 |
+
uprelu = nn.ReLU(True)
|
870 |
+
upnorm = norm_layer(outer_nc)
|
871 |
+
|
872 |
+
if outermost:
|
873 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
874 |
+
kernel_size=4, stride=2,
|
875 |
+
padding=1)
|
876 |
+
down = [downconv]
|
877 |
+
up = [uprelu, upconv, nn.Tanh()]
|
878 |
+
model = down + [submodule] + up
|
879 |
+
elif innermost:
|
880 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
881 |
+
kernel_size=4, stride=2,
|
882 |
+
padding=1, bias=use_bias)
|
883 |
+
down = [downrelu, downconv]
|
884 |
+
up = [uprelu, upconv, upnorm]
|
885 |
+
model = down + up
|
886 |
+
else:
|
887 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
888 |
+
kernel_size=4, stride=2,
|
889 |
+
padding=1, bias=use_bias)
|
890 |
+
down = [downrelu, downconv, downnorm]
|
891 |
+
up = [uprelu, upconv, upnorm]
|
892 |
+
|
893 |
+
if use_dropout:
|
894 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
895 |
+
else:
|
896 |
+
model = down + [submodule] + up
|
897 |
+
|
898 |
+
self.model = nn.Sequential(*model)
|
899 |
+
|
900 |
+
def forward(self, x):
|
901 |
+
if self.outermost:
|
902 |
+
return self.model(x)
|
903 |
+
else: # add skip connections
|
904 |
+
return torch.cat([x, self.model(x)], 1)
|
905 |
+
|
906 |
+
|
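The concatenation in the non-outermost forward is why every upconv above is built with inner_nc * 2 input channels: each submodule returns its own input stacked on top of its output. A minimal sketch of the innermost block (again only assuming this module imports):

    import torch
    import torch.nn as nn
    from models.modules.networks import UnetSkipConnectionBlock

    blk = UnetSkipConnectionBlock(outer_nc=8, inner_nc=8, innermost=True,
                                  norm_layer=nn.BatchNorm2d)
    x = torch.randn(2, 8, 16, 16)
    print(blk(x).shape)   # torch.Size([2, 16, 16, 16]) -- channels doubled by the skip concat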
907 |
+
# Creates SPADE normalization layer based on the given configuration
|
908 |
+
# SPADE consists of two steps. First, it normalizes the activations using
|
909 |
+
# your favorite normalization method, such as Batch Norm or Instance Norm.
|
910 |
+
# Second, it applies scale and bias to the normalized output, conditioned on
|
911 |
+
# the segmentation map.
|
912 |
+
# The format of |config_text| is spade(norm)(ks), where
|
913 |
+
# (norm) specifies the type of parameter-free normalization.
|
914 |
+
# (e.g. syncbatch, batch, instance)
|
915 |
+
# (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
|
916 |
+
# Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
|
917 |
+
# Also, the other arguments are
|
918 |
+
# |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
|
919 |
+
# |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
|
920 |
+
class SPADE(nn.Module):
|
921 |
+
def __init__(self, config_text, norm_nc, label_nc):
|
922 |
+
super().__init__()
|
923 |
+
|
924 |
+
assert config_text.startswith('spade')
|
925 |
+
parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
|
926 |
+
param_free_norm_type = str(parsed.group(1))
|
927 |
+
ks = int(parsed.group(2))
|
928 |
+
|
929 |
+
if param_free_norm_type == 'instance':
|
930 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
931 |
+
elif param_free_norm_type == 'batch':
|
932 |
+
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
|
933 |
+
elif param_free_norm_type == 'identity':
|
934 |
+
self.param_free_norm = nn.Identity()
|
935 |
+
else:
|
936 |
+
raise ValueError('%s is not a recognized param-free norm type in SPADE'
|
937 |
+
% param_free_norm_type)
|
938 |
+
|
939 |
+
# The dimension of the intermediate embedding space. Yes, hardcoded.
|
940 |
+
nhidden = 128
|
941 |
+
|
942 |
+
pw = ks // 2
|
943 |
+
self.mlp_shared = nn.Sequential(
|
944 |
+
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
|
945 |
+
nn.ReLU()
|
946 |
+
)
|
947 |
+
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
948 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
949 |
+
|
950 |
+
def forward(self, x, segmap):
|
951 |
+
|
952 |
+
# Part 1. generate parameter-free normalized activations
|
953 |
+
normalized = self.param_free_norm(x)
|
954 |
+
|
955 |
+
# Part 2. produce scaling and bias conditioned on semantic map
|
956 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
|
957 |
+
actv = self.mlp_shared(segmap)
|
958 |
+
gamma = self.mlp_gamma(actv)
|
959 |
+
beta = self.mlp_beta(actv)
|
960 |
+
|
961 |
+
# apply scale and bias
|
962 |
+
out = normalized * (1 + gamma) + beta
|
963 |
+
|
964 |
+
return out
|
965 |
+
|
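A small usage sketch of the block above; the config string follows the spade(norm)(ks) format described in the comment, and all shapes are made up for illustration:

    import torch
    from models.modules.networks import SPADE

    norm = SPADE('spadeinstance3x3', norm_nc=64, label_nc=5)   # instance norm, 3x3 convs
    x = torch.randn(2, 64, 32, 32)        # activations to be normalized
    segmap = torch.randn(2, 5, 256, 256)  # semantic map; resized to x's size internally
    print(norm(x, segmap).shape)          # torch.Size([2, 64, 32, 32])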
966 |
+
|
967 |
+
# ResNet block that uses SPADE.
|
968 |
+
# It differs from the ResNet block of pix2pixHD in that
|
969 |
+
# it takes in the segmentation map as input, learns the skip connection if necessary,
|
970 |
+
# and applies normalization first and then convolution.
|
971 |
+
# This architecture seemed like a standard architecture for unconditional or
|
972 |
+
# class-conditional GAN architecture using residual block.
|
973 |
+
# The code was inspired from https://github.com/LMescheder/GAN_stability.
|
974 |
+
class SPADEResnetBlock(nn.Module):
|
975 |
+
def __init__(self, fin, fout, config_str, semantic_nc):
|
976 |
+
super().__init__()
|
977 |
+
# Attributes
|
978 |
+
self.learned_shortcut = (fin != fout)
|
979 |
+
fmiddle = min(fin, fout)
|
980 |
+
|
981 |
+
# create conv layers
|
982 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
|
983 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
|
984 |
+
if self.learned_shortcut:
|
985 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
986 |
+
|
987 |
+
# apply spectral norm if specified
|
988 |
+
if 'spectral' in config_str:
|
989 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
990 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
991 |
+
if self.learned_shortcut:
|
992 |
+
self.conv_s = spectral_norm(self.conv_s)
|
993 |
+
|
994 |
+
# define normalization layers
|
995 |
+
spade_config_str = config_str.replace('spectral', '')
|
996 |
+
self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
|
997 |
+
self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
|
998 |
+
if self.learned_shortcut:
|
999 |
+
self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
|
1000 |
+
|
1001 |
+
# note the resnet block with SPADE also takes in |seg|,
|
1002 |
+
# the semantic segmentation map as input
|
1003 |
+
def forward(self, x, seg):
|
1004 |
+
x_s = self.shortcut(x, seg)
|
1005 |
+
|
1006 |
+
dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
|
1007 |
+
dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
|
1008 |
+
|
1009 |
+
out = x_s + dx
|
1010 |
+
|
1011 |
+
return out
|
1012 |
+
|
1013 |
+
def shortcut(self, x, seg):
|
1014 |
+
if self.learned_shortcut:
|
1015 |
+
x_s = self.conv_s(self.norm_s(x, seg))
|
1016 |
+
else:
|
1017 |
+
x_s = x
|
1018 |
+
return x_s
|
1019 |
+
|
1020 |
+
def actvn(self, x):
|
1021 |
+
return F.leaky_relu(x, 2e-1)
|
1022 |
+
|
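Usage sketch for the block above, with fin != fout so the 1x1 learned shortcut is exercised (shapes are illustrative only):

    import torch
    from models.modules.networks import SPADEResnetBlock

    blk = SPADEResnetBlock(fin=128, fout=64, config_str='spadeinstance3x3', semantic_nc=5)
    x = torch.randn(1, 128, 16, 16)
    seg = torch.randn(1, 5, 16, 16)
    print(blk(x, seg).shape)   # torch.Size([1, 64, 16, 16]); conv_s maps the identity path to 64 channels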
1023 |
+
|
1024 |
+
class NLayerDiscriminator(nn.Module):
|
1025 |
+
"""Defines a PatchGAN discriminator"""
|
1026 |
+
|
1027 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
|
1028 |
+
"""Construct a PatchGAN discriminator
|
1029 |
+
|
1030 |
+
Parameters:
|
1031 |
+
input_nc (int) -- the number of channels in input images
|
1032 |
+
ndf (int) -- the number of filters in the last conv layer
|
1033 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
1034 |
+
norm_layer -- normalization layer
|
1035 |
+
"""
|
1036 |
+
super(NLayerDiscriminator, self).__init__()
|
1037 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1038 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1039 |
+
else:
|
1040 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1041 |
+
|
1042 |
+
kw = 4
|
1043 |
+
padw = 1
|
1044 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
1045 |
+
nf_mult = 1
|
1046 |
+
nf_mult_prev = 1
|
1047 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
1048 |
+
nf_mult_prev = nf_mult
|
1049 |
+
nf_mult = min(2 ** n, 8)
|
1050 |
+
sequence += [
|
1051 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
1052 |
+
norm_layer(ndf * nf_mult),
|
1053 |
+
nn.LeakyReLU(0.2, True)
|
1054 |
+
]
|
1055 |
+
|
1056 |
+
nf_mult_prev = nf_mult
|
1057 |
+
nf_mult = min(2 ** n_layers, 8)
|
1058 |
+
sequence += [
|
1059 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
1060 |
+
norm_layer(ndf * nf_mult),
|
1061 |
+
nn.LeakyReLU(0.2, True)
|
1062 |
+
]
|
1063 |
+
|
1064 |
+
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
1065 |
+
self.model = nn.Sequential(*sequence)
|
1066 |
+
|
1067 |
+
def forward(self, input):
|
1068 |
+
"""Standard forward."""
|
1069 |
+
return self.model(input)
|
1070 |
+
|
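With the default n_layers=3 this is the usual 70x70 PatchGAN: the output is not a single scalar but a grid of patch logits. A quick shape check (it only assumes the module imports):

    import torch
    import torch.nn as nn
    from models.modules.networks import NLayerDiscriminator

    D = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
    x = torch.randn(2, 3, 256, 256)
    print(D(x).shape)   # torch.Size([2, 1, 30, 30]) -- one logit per overlapping 70x70 patch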
1071 |
+
|
1072 |
+
class PixelDiscriminator(nn.Module):
|
1073 |
+
"""Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
|
1074 |
+
|
1075 |
+
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
|
1076 |
+
"""Construct a 1x1 PatchGAN discriminator
|
1077 |
+
|
1078 |
+
Parameters:
|
1079 |
+
input_nc (int) -- the number of channels in input images
|
1080 |
+
ndf (int) -- the number of filters in the last conv layer
|
1081 |
+
norm_layer -- normalization layer
|
1082 |
+
"""
|
1083 |
+
super(PixelDiscriminator, self).__init__()
|
1084 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1085 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1086 |
+
else:
|
1087 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1088 |
+
|
1089 |
+
self.net = [
|
1090 |
+
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
|
1091 |
+
nn.LeakyReLU(0.2, True),
|
1092 |
+
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
|
1093 |
+
norm_layer(ndf * 2),
|
1094 |
+
nn.LeakyReLU(0.2, True),
|
1095 |
+
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
|
1096 |
+
|
1097 |
+
self.net = nn.Sequential(*self.net)
|
1098 |
+
|
1099 |
+
def forward(self, input):
|
1100 |
+
"""Standard forward."""
|
1101 |
+
return self.net(input)
|
models/modules/sr/light_model_270M.py
ADDED
@@ -0,0 +1,347 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from models.modules.pix2pixMini_module import *
|
4 |
+
|
5 |
+
pack = Pack()
|
6 |
+
unpack = UnPack()
|
7 |
+
AVG = AvgQuant()
|
8 |
+
mul = SliceMul()
|
9 |
+
|
10 |
+
def get_int(channel_dict):
|
11 |
+
for i in range(len(channel_dict['down'])):
|
12 |
+
channel_dict['down'][i] = int(channel_dict['down'][i])
|
13 |
+
for i in range(len(channel_dict['backbone'])):
|
14 |
+
channel_dict['backbone'][i] = int(channel_dict['backbone'][i])
|
15 |
+
for i in range(len(channel_dict['up'])):
|
16 |
+
for j in range(len(channel_dict['up'][i])):
|
17 |
+
if channel_dict['up'][i][j] is not None:
|
18 |
+
channel_dict['up'][i][j] = int(channel_dict['up'][i][j])
|
19 |
+
return channel_dict
|
20 |
+
|
21 |
+
def get_channel_dict(dict_name, ngf):
|
22 |
+
if dict_name is None:
|
23 |
+
raise ValueError('invalid channel_dict name')  # raising a bare string is a TypeError in Python 3
|
24 |
+
|
25 |
+
if dict_name == '3G':
|
26 |
+
channel_dict = {
|
27 |
+
'n_blocks': 8,
|
28 |
+
'down': [ngf * 1, ngf * 2, ngf * 4, ngf * 8, ngf * 8],
|
29 |
+
'backbone': [ngf * 8, ngf * 8],
|
30 |
+
'hair_up': [
|
31 |
+
[None, None, None, ngf * 4, 2],
|
32 |
+
[None, None, None, ngf * 2, 2],
|
33 |
+
[None, None, None, ngf * 1, 2],
|
34 |
+
[None, None, None, ngf * 1, 1],
|
35 |
+
],
|
36 |
+
'face_up': [
|
37 |
+
[None, None, None, ngf * 4, 2],
|
38 |
+
[None, None, None, ngf * 2, 2],
|
39 |
+
[None, None, None, ngf * 2, 2],
|
40 |
+
[None, None, None, ngf * 1, 1],
|
41 |
+
],
|
42 |
+
'up': [
|
43 |
+
[None, None, None, ngf * 1, 1],
|
44 |
+
],
|
45 |
+
}
|
46 |
+
elif dict_name == '270M':
|
47 |
+
channel_dict = {
|
48 |
+
'n_blocks': 4,
|
49 |
+
'down': [ngf * 2, ngf * 2, ngf * 3, ngf * 6, ngf * 4],
|
50 |
+
'backbone': [ngf * 4, ngf * 4],
|
51 |
+
'hair_up': [
|
52 |
+
[ngf * 4, None, None, ngf * 2, 1],
|
53 |
+
[ngf * 2, None, None, ngf * 2, 1],
|
54 |
+
[ngf * 2, None, ngf * 1, ngf * 1, 1],
|
55 |
+
[None, None, None, ngf * 1, 1],
|
56 |
+
],
|
57 |
+
'face_up': [
|
58 |
+
[ngf * 4, ngf * 6, ngf * 3, ngf * 2, 1],
|
59 |
+
[ngf * 2, ngf * 3, ngf * 2, ngf * 2, 1],
|
60 |
+
[ngf * 2, ngf * 2, ngf * 1, ngf * 1, 1],
|
61 |
+
[None, None, None, ngf * 1, 1],
|
62 |
+
],
|
63 |
+
'up': [
|
64 |
+
[ngf * 2, ngf * 1, ngf * 1, ngf * 2, 1],
|
65 |
+
],
|
66 |
+
}
|
67 |
+
else:
|
68 |
+
raise ValueError('invalid_dict_name')
|
69 |
+
return get_int(channel_dict)
|
70 |
+
|
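get_channel_dict is essentially a lookup table of layer widths keyed by a compute budget ('3G' vs '270M'). A quick way to inspect the '270M' plan; this assumes the file imports, i.e. models.modules.pix2pixMini_module is on the path, since module-level Pack()/UnPack() instances are created at import time:

    from models.modules.sr.light_model_270M import get_channel_dict

    cd = get_channel_dict('270M', ngf=8)
    print(cd['down'])      # [16, 16, 24, 48, 32] -- encoder widths
    print(cd['backbone'])  # [32, 32]              -- bottleneck widths
    print(cd['up'])        # [[16, 8, 8, 16, 1]]   -- the fused hair+face up block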
71 |
+
chans = 8
|
72 |
+
in_channels = 12
|
73 |
+
out_channels = 16
|
74 |
+
face_branch_out_channels = 12
|
75 |
+
hair_branch_out_channels = 12
|
76 |
+
|
77 |
+
class hair_face_model_old(nn.Module):
|
78 |
+
def __init__(self, ngf=chans, backbone_type='resnet', use_se=True, channel_dict_name = None, with_hair_branch=False, design=5):
|
79 |
+
super().__init__()
|
80 |
+
self.design = design
|
81 |
+
self.with_hair_branch = with_hair_branch
|
82 |
+
channel_dict = get_channel_dict(channel_dict_name, ngf)
|
83 |
+
n_blocks = channel_dict['n_blocks']
|
84 |
+
|
85 |
+
self.inconv = ConvBlock(in_channels, channel_dict['down'][0], stride=1)
|
86 |
+
|
87 |
+
# Down-Sampling
|
88 |
+
self.DownBlock1 = ConvBlock(channel_dict['down'][0], channel_dict['down'][1], stride=2)
|
89 |
+
self.DownBlock2 = ConvBlock(channel_dict['down'][1], channel_dict['down'][2], stride=2)
|
90 |
+
self.DownBlock3 = ConvBlock(channel_dict['down'][2], channel_dict['down'][3], stride=2)
|
91 |
+
self.DownBlock4 = ConvBlock(channel_dict['down'][3], channel_dict['down'][4], stride=2)
|
92 |
+
|
93 |
+
# Down-Sampling Bottleneck
|
94 |
+
if backbone_type == 'resnet':
|
95 |
+
backbone_block = ResnetBlock
|
96 |
+
elif backbone_type == 'mobilenet':
|
97 |
+
backbone_block = InvertedBottleneck
|
98 |
+
n_blocks = n_blocks
|
99 |
+
else:
|
100 |
+
raise ValueError('invalid backbone type')
|
101 |
+
ResBlock = []
|
102 |
+
ResBlock += [backbone_block(channel_dict['down'][4], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
|
103 |
+
for i in range(1, n_blocks - 1):
|
104 |
+
ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
|
105 |
+
ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][1], use_bias=False, use_se=use_se)]
|
106 |
+
self.ResBlock = nn.Sequential(*ResBlock)
|
107 |
+
|
108 |
+
self.HairUpBlock4 = UpBlock(channel_dict['backbone'][1], None, channel_dict['hair_up'][0][0], None, channel_dict['hair_up'][0][2], channel_dict['hair_up'][0][3], num_conv=channel_dict['hair_up'][0][4])
|
109 |
+
self.HairUpBlock3 = UpBlock(channel_dict['hair_up'][0][3], None, channel_dict['hair_up'][1][0], None, channel_dict['hair_up'][1][2], channel_dict['hair_up'][1][3], num_conv=channel_dict['hair_up'][1][4])
|
110 |
+
self.HairUpBlock2 = UpBlock(channel_dict['hair_up'][1][3], None, channel_dict['hair_up'][2][0], None, channel_dict['hair_up'][2][2], channel_dict['hair_up'][2][3], num_conv=channel_dict['hair_up'][2][4])
|
111 |
+
|
112 |
+
self.FaceUpBlock4 = UpBlock(channel_dict['backbone'][1], channel_dict['down'][3], channel_dict['face_up'][0][0], channel_dict['face_up'][0][1], channel_dict['face_up'][0][2], channel_dict['face_up'][0][3], num_conv=channel_dict['face_up'][0][4])
|
113 |
+
self.FaceUpBlock3 = UpBlock(channel_dict['face_up'][0][3], channel_dict['down'][2], channel_dict['face_up'][1][0], channel_dict['face_up'][1][1], channel_dict['face_up'][1][2], channel_dict['face_up'][1][3], num_conv=channel_dict['face_up'][1][4])
|
114 |
+
self.FaceUpBlock2 = UpBlock(channel_dict['face_up'][1][3], channel_dict['down'][1], channel_dict['face_up'][2][0], channel_dict['face_up'][2][1], channel_dict['face_up'][2][2], channel_dict['face_up'][2][3], num_conv=channel_dict['face_up'][2][4])
|
115 |
+
|
116 |
+
self.UpBlock1 = UpBlock(channel_dict['hair_up'][2][3] + channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['up'][0][0], channel_dict['up'][0][1], channel_dict['up'][0][2], channel_dict['up'][0][3], num_conv=channel_dict['up'][0][4])
|
117 |
+
self.outconv = ConvOutBlock(channel_dict['up'][0][3], out_channels )
|
118 |
+
|
119 |
+
#self.shortcut_ratio = [1,1,1,1]
|
120 |
+
|
121 |
+
if self.with_hair_branch:
|
122 |
+
self.HairUpBlock1 = UpBlock(channel_dict['hair_up'][2][3], None, channel_dict['hair_up'][3][0], None, channel_dict['hair_up'][3][2],channel_dict['hair_up'][3][3])
|
123 |
+
self.Hairoutconv = ConvOutBlock(channel_dict['hair_up'][3][3], hair_branch_out_channels)
|
124 |
+
|
125 |
+
self.FaceUpBlock1 = UpBlock(channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['face_up'][3][0], channel_dict['face_up'][3][1], channel_dict['face_up'][3][2], channel_dict['face_up'][3][3])
|
126 |
+
self.Faceoutconv = ConvOutBlock(channel_dict['face_up'][3][3], face_branch_out_channels)
|
127 |
+
|
128 |
+
self.up = UpsampleQuant(scale_factor=1.5, mode='bilinear')
|
129 |
+
|
130 |
+
|
131 |
+
def forward(self, x , with_hair=False):
|
132 |
+
|
133 |
+
design = self.design
|
134 |
+
|
135 |
+
x0 = self.inconv(x)
|
136 |
+
x1 = self.DownBlock1(x0)
|
137 |
+
x2 = self.DownBlock2(x1)
|
138 |
+
x3 = self.DownBlock3(x2)
|
139 |
+
x4 = self.DownBlock4(x3)
|
140 |
+
x5 = self.ResBlock(x4)
|
141 |
+
|
142 |
+
x6 = x5
|
143 |
+
hair = self.HairUpBlock4(x6)
|
144 |
+
hair = self.HairUpBlock3(hair)
|
145 |
+
hair = self.HairUpBlock2( hair)
|
146 |
+
|
147 |
+
face = self.FaceUpBlock4(x6, self.up(x3) if design == 0 else x3)
|
148 |
+
face = self.FaceUpBlock3(self.up(face) if design==1 else face,
|
149 |
+
self.up(x2) if design <= 1 else x2)
|
150 |
+
face = self.FaceUpBlock2(self.up(face) if design == 2 else face,
|
151 |
+
self.up(x1) if design <= 2 else x1)
|
152 |
+
|
153 |
+
hf_cat = torch.cat([hair,face], dim=1)
|
154 |
+
|
155 |
+
x7 = self.UpBlock1(hf_cat, x0)
|
156 |
+
|
157 |
+
x7 = self.outconv( x7)
|
158 |
+
if not with_hair or not self.with_hair_branch:
|
159 |
+
return x7
|
160 |
+
else:
|
161 |
+
hair = self.HairUpBlock1(hair)
|
162 |
+
hair = self.Hairoutconv(hair)
|
163 |
+
face = self.FaceUpBlock1(face, x0)
|
164 |
+
face = self.Faceoutconv(face)
|
165 |
+
# print(design == 5) #true
|
166 |
+
return (x7 if design == 5 else x7), hair, face  # the conditional binds before the commas, so a 3-tuple is always returned here
|
167 |
+
|
168 |
+
class hair_face_model(hair_face_model_old):
|
169 |
+
def __init__(self, **kwargs):
|
170 |
+
super(hair_face_model, self).__init__(**kwargs)
|
171 |
+
|
172 |
+
self.upconv = nn.Sequential(
|
173 |
+
nn.Upsample(scale_factor=2, mode='bilinear'),
|
174 |
+
nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1, bias=True),
|
175 |
+
nn.Tanh()
|
176 |
+
)
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
x=pack(x)
|
180 |
+
x = super().forward(x)
|
181 |
+
x=unpack(x)
|
182 |
+
return self.upconv(x)
|
183 |
+
|
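The module-level in_channels = 12 and out_channels = 16, together with the pack/unpack calls wrapping super().forward and the 4-channel upconv, suggest that Pack/UnPack trade spatial resolution for channels (space-to-depth by a factor of 2). Pack/UnPack live in pix2pixMini_module, which is not part of this diff, so the following is an assumed reading sketched with stock PyTorch ops, not their actual implementation:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 256, 256)
    packed = F.pixel_unshuffle(x, 2)   # (1, 12, 128, 128): 3 channels x 4 = in_channels
    y = torch.randn(1, 16, 128, 128)   # what the backbone emits (out_channels = 16)
    unpacked = F.pixel_shuffle(y, 2)   # (1, 4, 256, 256): matches the 4-channel upconv input
    print(packed.shape, unpacked.shape)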
184 |
+
class ConvBlock(nn.Module):
|
185 |
+
def __init__(self, in_ch, out_ch, stride):
|
186 |
+
super(ConvBlock, self).__init__()
|
187 |
+
self.conv = nn.Sequential(
|
188 |
+
#nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
|
189 |
+
Conv2dQuant(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=True),
|
190 |
+
nn.BatchNorm2d(out_ch),
|
191 |
+
HardQuant(0, 4)
|
192 |
+
#nn.ReLU(False))
|
193 |
+
)
|
194 |
+
def forward(self, x):
|
195 |
+
x = self.conv(x)
|
196 |
+
return x
|
197 |
+
|
198 |
+
|
199 |
+
class ConvOutBlock(nn.Module):
|
200 |
+
def __init__(self, in_ch, out_ch):
|
201 |
+
super(ConvOutBlock, self).__init__()
|
202 |
+
self.conv = nn.Sequential(
|
203 |
+
#nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
|
204 |
+
Conv2dQuant(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
|
205 |
+
#nn.Tanh()
|
206 |
+
TanhOp(data_in_type='float', data_out_type='fixed')
|
207 |
+
)
|
208 |
+
def forward(self, x):
|
209 |
+
x = self.conv(x)
|
210 |
+
return x
|
211 |
+
|
212 |
+
|
213 |
+
class UpBlock(nn.Module):
|
214 |
+
def __init__(self, in_ch1, in_ch2, mid_ch1, mid_ch2, mid_ch, out_ch, num_conv=1, use_bn=True):
|
215 |
+
super(UpBlock, self).__init__()
|
216 |
+
|
217 |
+
#self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
218 |
+
self.up = UpsampleQuant(scale_factor=2, mode='nearest')
|
219 |
+
|
220 |
+
## branch_1
|
221 |
+
if mid_ch1 is None or in_ch1 == mid_ch1:
|
222 |
+
self.conv1 = None
|
223 |
+
else:
|
224 |
+
self.conv1 = nn.Sequential(
|
225 |
+
#nn.Conv2d(in_ch1, mid_ch1, 1, bias=False),
|
226 |
+
Conv2dQuant(in_ch1, mid_ch1, 1, bias=True),
|
227 |
+
#nn.ReLU(False),
|
228 |
+
HardQuant(0, 4)
|
229 |
+
)
|
230 |
+
if mid_ch1 is None:
|
231 |
+
mid_ch1 = in_ch1
|
232 |
+
|
233 |
+
if in_ch2 is None:
|
234 |
+
self.use_shortcut = False
|
235 |
+
self.conv2 = None
|
236 |
+
else:
|
237 |
+
self.use_shortcut = True
|
238 |
+
if mid_ch2 is None or in_ch2 == mid_ch2:
|
239 |
+
self.conv2 = None
|
240 |
+
else:
|
241 |
+
self.conv2 = nn.Sequential(
|
242 |
+
#nn.Conv2d(in_ch2, mid_ch2, 1, bias=False),
|
243 |
+
Conv2dQuant(in_ch2, mid_ch2, 1, bias=True),
|
244 |
+
#nn.ReLU(False),
|
245 |
+
HardQuant(0, 4)
|
246 |
+
)
|
247 |
+
if mid_ch2 is None:
|
248 |
+
mid_ch2 = in_ch2
|
249 |
+
#print(self.conv1 is None, self.conv2 is None)
|
250 |
+
combine_ch = mid_ch1
|
251 |
+
if self.use_shortcut:
|
252 |
+
combine_ch = combine_ch + mid_ch2
|
253 |
+
if mid_ch is None or combine_ch == mid_ch:
|
254 |
+
self.conv_combine = None
|
255 |
+
mid_ch = combine_ch
|
256 |
+
else:
|
257 |
+
self.conv_combine = nn.Sequential(
|
258 |
+
#nn.Conv2d(combine_ch, mid_ch, 1, bias=False),
|
259 |
+
Conv2dQuant(combine_ch, mid_ch, 1, bias=True),
|
260 |
+
#nn.ReLU(False),
|
261 |
+
HardQuant(0, 4)
|
262 |
+
)
|
263 |
+
|
264 |
+
conv_list = []
|
265 |
+
#conv_list.append(nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False))
|
266 |
+
conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True))
|
267 |
+
if use_bn:
|
268 |
+
conv_list.append(nn.BatchNorm2d(out_ch))
|
269 |
+
#conv_list.append(nn.ReLU(False))
|
270 |
+
conv_list.append(HardQuant(0, 4))
|
271 |
+
for n in range(1, num_conv):
|
272 |
+
conv_list.append(ResnetBlock(out_ch, out_ch, use_bias=False, use_se = False, use_bn=use_bn))
|
273 |
+
self.conv = nn.Sequential(*conv_list)
|
274 |
+
|
275 |
+
def forward(self, x1, x2=None, ratio=None):
|
276 |
+
|
277 |
+
if self.conv1 is not None:
|
278 |
+
x1 = self.conv1(x1)
|
279 |
+
x1 = self.up(x1)
|
280 |
+
|
281 |
+
if self.use_shortcut:
|
282 |
+
if self.conv2 is not None:
|
283 |
+
x2 = self.conv2(x2)
|
284 |
+
|
285 |
+
if self.use_shortcut:
|
286 |
+
if ratio is None:
|
287 |
+
x = torch.cat([x1, x2], dim=1)
|
288 |
+
else:
|
289 |
+
x = torch.cat([x1, x2 * ratio], dim=1)
|
290 |
+
else:
|
291 |
+
x = x1
|
292 |
+
|
293 |
+
if self.conv_combine is not None:
|
294 |
+
x = self.conv_combine(x)
|
295 |
+
|
296 |
+
x = self.conv(x)
|
297 |
+
return x
|
298 |
+
|
299 |
+
|
300 |
+
class ResnetBlock(nn.Module):
|
301 |
+
def __init__(self, dim, dim_out, use_bias, use_se = False, use_bn=True):
|
302 |
+
super(ResnetBlock, self).__init__()
|
303 |
+
conv_block = []
|
304 |
+
#conv_block += [nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias),]
|
305 |
+
conv_block += [Conv2dQuant(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),]
|
306 |
+
if use_bn:
|
307 |
+
conv_block += [nn.BatchNorm2d(dim),]
|
308 |
+
#conv_block += [nn.ReLU(False)]
|
309 |
+
conv_block += [HardQuant(0, 4)]
|
310 |
+
|
311 |
+
#conv_block.append(nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=use_bias))
|
312 |
+
conv_block.append(Conv2dQuant(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=True))
|
313 |
+
|
314 |
+
if use_bn:
|
315 |
+
conv_block.append(nn.BatchNorm2d(dim_out))
|
316 |
+
conv_block += [HardQuant(0, 4)]
|
317 |
+
|
318 |
+
if use_se:
|
319 |
+
conv_block.append(SqEx(dim_out, 4))
|
320 |
+
|
321 |
+
self.conv_block = nn.Sequential(*conv_block)
|
322 |
+
|
323 |
+
self.downsample = None
|
324 |
+
if dim != dim_out:
|
325 |
+
if use_bn:
|
326 |
+
self.downsample = nn.Sequential(
|
327 |
+
#nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
|
328 |
+
#nn.BatchNorm2d(dim_out),
|
329 |
+
Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
|
330 |
+
nn.BatchNorm2d(dim_out),
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
self.downsample = nn.Sequential(
|
334 |
+
#nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
|
335 |
+
Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
|
336 |
+
)
|
337 |
+
|
338 |
+
#self.relu = nn.ReLU(False)
|
339 |
+
self.relu = HardQuant(0, 4)
|
340 |
+
|
341 |
+
def forward(self, x):
|
342 |
+
if self.downsample is None:
|
343 |
+
y = AVG(x, self.conv_block(x))
|
344 |
+
else:
|
345 |
+
y = AVG(self.downsample(x), self.conv_block(x))
|
346 |
+
#y = self.relu(y)
|
347 |
+
return y
|
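Throughout this file residual connections are merged with AVG(identity, residual) rather than a plain sum. AvgQuant is defined in pix2pixMini_module (not shown here); reading it as an elementwise average would keep activations inside the same range that HardQuant(0, 4) clamps to, which is the usual motivation in fixed-point models. A sketch of that assumption with plain tensors:

    import torch

    identity = torch.full((1, 8, 4, 4), 3.5)
    residual = torch.full((1, 8, 4, 4), 2.0)
    avg_out = (identity + residual) / 2   # 2.75 -- stays inside the [0, 4] activation range
    sum_out = identity + residual         # 5.5  -- a plain residual sum would overflow the clamp
    print(avg_out.max().item(), sum_out.max().item())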
models/modules/sr/light_model_470M.py
ADDED
@@ -0,0 +1,442 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from models.modules.pix2pixMini_module import *
|
4 |
+
|
5 |
+
AVG = AvgQuant()
|
6 |
+
mul = SliceMul()
|
7 |
+
|
8 |
+
def get_int(channel_dict):
|
9 |
+
for i in range(len(channel_dict['down'])):
|
10 |
+
channel_dict['down'][i] = int(channel_dict['down'][i])
|
11 |
+
for i in range(len(channel_dict['backbone'])):
|
12 |
+
channel_dict['backbone'][i] = int(channel_dict['backbone'][i])
|
13 |
+
for i in range(len(channel_dict['up'])):
|
14 |
+
for j in range(len(channel_dict['up'][i])):
|
15 |
+
if channel_dict['up'][i][j] is not None:
|
16 |
+
channel_dict['up'][i][j] = int(channel_dict['up'][i][j])
|
17 |
+
return channel_dict
|
18 |
+
|
19 |
+
def get_channel_dict(dict_name, ngf):
|
20 |
+
if dict_name is None:
|
21 |
+
raise ValueError('invalid channel_dict name')  # raising a bare string is a TypeError in Python 3
|
22 |
+
|
23 |
+
if dict_name == '3G':
|
24 |
+
channel_dict = {
|
25 |
+
'n_blocks': 8,
|
26 |
+
'down': [ngf * 1, ngf * 2, ngf * 4, ngf * 8, ngf * 8],
|
27 |
+
'backbone': [ngf * 8, ngf * 8],
|
28 |
+
'hair_up': [
|
29 |
+
[None, None, None, ngf * 4, 2],
|
30 |
+
[None, None, None, ngf * 2, 2],
|
31 |
+
[None, None, None, ngf * 1, 2],
|
32 |
+
[None, None, None, ngf * 1, 1],
|
33 |
+
],
|
34 |
+
'face_up': [
|
35 |
+
[None, None, None, ngf * 4, 2],
|
36 |
+
[None, None, None, ngf * 2, 2],
|
37 |
+
[None, None, None, ngf * 2, 2],
|
38 |
+
[None, None, None, ngf * 1, 1],
|
39 |
+
],
|
40 |
+
'up': [
|
41 |
+
[None, None, None, ngf * 1, 1],
|
42 |
+
],
|
43 |
+
}
|
44 |
+
|
45 |
+
elif dict_name == '470M':
|
46 |
+
channel_dict = {
|
47 |
+
'n_blocks': 4,
|
48 |
+
'down': [ngf * 1, ngf * 2, ngf * 3, ngf * 6, ngf * 4],
|
49 |
+
'backbone': [ngf * 4, ngf * 4],
|
50 |
+
'hair_up': [
|
51 |
+
[None, None, None, ngf * 2, 1],
|
52 |
+
[None, None, None, ngf * 2, 1],
|
53 |
+
[None, None, None, ngf * 1, 1],
|
54 |
+
[None, None, None, ngf * 1, 1],
|
55 |
+
],
|
56 |
+
'face_up': [
|
57 |
+
[None, None, ngf * 4, ngf * 2, 1],
|
58 |
+
[None, None, ngf * 2, ngf * 2, 1],
|
59 |
+
[None, None, ngf * 2, ngf * 1, 1],
|
60 |
+
[None, None, None, ngf * 1, 1],
|
61 |
+
],
|
62 |
+
'up': [
|
63 |
+
[None, None, ngf * 2, ngf * 1, 1],
|
64 |
+
],
|
65 |
+
}
|
66 |
+
else:
|
67 |
+
raise ValueError('invalid_dict_name')
|
68 |
+
return get_int(channel_dict)
|
69 |
+
|
70 |
+
chans = 8
|
71 |
+
in_channels = 3
|
72 |
+
out_channels = 4
|
73 |
+
face_branch_out_channels = 3
|
74 |
+
hair_branch_out_channels = 3
|
75 |
+
|
76 |
+
def conv(numIn, numOut, k, s=1, p=0, relu=True, bn=False):
|
77 |
+
layers = []
|
78 |
+
layers.append(Conv2dQuant(numIn, numOut, k, s, p, bias=True))
|
79 |
+
if bn:
|
80 |
+
layers.append(nn.BatchNorm2d(numOut))
|
81 |
+
|
82 |
+
if relu is True:
|
83 |
+
layers.append(HardQuant(0, 4))
|
84 |
+
return nn.Sequential(*layers)
|
85 |
+
|
86 |
+
def mnconv(numIn, numOut, k, s=1, p=0, dilation=1, relu=True, bn = True):
|
87 |
+
if k < 2:
|
88 |
+
return conv(numIn, numOut, k, s, p, relu, bn)
|
89 |
+
layers = []
|
90 |
+
layers.append(Conv2dQuant(numIn, numIn, k, s, p, groups=numIn, dilation=dilation, bias=True))
|
91 |
+
layers.append(nn.BatchNorm2d(numIn))
|
92 |
+
layers.append(HardQuant(0, 4))
|
93 |
+
layers.append(conv(numIn, numOut, 1, 1, 0, relu, bn))
|
94 |
+
return nn.Sequential(*layers)
|
95 |
+
|
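mnconv factors a kxk convolution into a depthwise kxk convolution followed by a 1x1 pointwise convolution, which is where much of the '470M' budget saving comes from. Sketched with plain nn.Conv2d (Conv2dQuant is assumed to keep the same shapes), the parameter count drops by roughly a factor of k*k:

    import torch.nn as nn

    cin, cout, k = 32, 64, 3
    full = nn.Conv2d(cin, cout, k, padding=1)
    depthwise = nn.Conv2d(cin, cin, k, padding=1, groups=cin)  # one spatial filter per channel
    pointwise = nn.Conv2d(cin, cout, 1)                        # 1x1 channel mixing

    def n_params(*mods):
        return sum(p.numel() for m in mods for p in m.parameters())

    print(n_params(full), n_params(depthwise, pointwise))   # 18496 vs 2432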
96 |
+
class hair_face_model(nn.Module):
|
97 |
+
def __init__(self, ngf=chans, backbone_type='resnet', use_se=True, channel_dict_name = None, with_hair_branch=False, design=5):
|
98 |
+
|
99 |
+
super().__init__()
|
100 |
+
self.design = design
|
101 |
+
self.with_hair_branch = with_hair_branch
|
102 |
+
channel_dict = get_channel_dict(channel_dict_name, ngf)
|
103 |
+
n_blocks = channel_dict['n_blocks']
|
104 |
+
|
105 |
+
self.inconv = ConvBlock(in_channels, channel_dict['down'][0], stride=1)
|
106 |
+
self.shortcut_ratio = [1,1,1,1]
|
107 |
+
|
108 |
+
# Down-Sampling
|
109 |
+
self.DownBlock1 = ConvBlock(channel_dict['down'][0], channel_dict['down'][1], stride=2)
|
110 |
+
self.DownBlock2 = ConvBlock(channel_dict['down'][1], channel_dict['down'][2], stride=2)
|
111 |
+
self.DownBlock3 = ConvBlock(channel_dict['down'][2], channel_dict['down'][3], stride=2)
|
112 |
+
self.DownBlock4 = ConvBlock(channel_dict['down'][3], channel_dict['down'][4], stride=2)
|
113 |
+
|
114 |
+
# Down-Sampling Bottleneck
|
115 |
+
if backbone_type == 'resnet':
|
116 |
+
backbone_block = ResnetBlock
|
117 |
+
elif backbone_type == 'mobilenet':
|
118 |
+
backbone_block = InvertedBottleneck
|
119 |
+
n_blocks = n_blocks
|
120 |
+
else:
|
121 |
+
raise ValueError('invalid backbone type')
|
122 |
+
ResBlock = []
|
123 |
+
ResBlock += [backbone_block(channel_dict['down'][4], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
|
124 |
+
for i in range(1, n_blocks - 1):
|
125 |
+
ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
|
126 |
+
ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][1], use_bias=False, use_se=use_se)]
|
127 |
+
self.ResBlock = nn.Sequential(*ResBlock)
|
128 |
+
|
129 |
+
self.HairUpBlock4 = UpBlock(channel_dict['backbone'][1], None, channel_dict['hair_up'][0][0], None, channel_dict['hair_up'][0][2], channel_dict['hair_up'][0][3], num_conv=channel_dict['hair_up'][0][4])
|
130 |
+
self.HairUpBlock3 = UpBlock(channel_dict['hair_up'][0][3], None, channel_dict['hair_up'][1][0], None, channel_dict['hair_up'][1][2], channel_dict['hair_up'][1][3], num_conv=channel_dict['hair_up'][1][4])
|
131 |
+
self.HairUpBlock2 = UpBlock(channel_dict['hair_up'][1][3], None, channel_dict['hair_up'][2][0], None, channel_dict['hair_up'][2][2], channel_dict['hair_up'][2][3], num_conv=channel_dict['hair_up'][2][4])
|
132 |
+
|
133 |
+
self.FaceUpBlock4 = UpBlock(channel_dict['backbone'][1], channel_dict['down'][3], channel_dict['face_up'][0][0], channel_dict['face_up'][0][1], channel_dict['face_up'][0][2], channel_dict['face_up'][0][3], num_conv=channel_dict['face_up'][0][4])
|
134 |
+
self.FaceUpBlock3 = UpBlock(channel_dict['face_up'][0][3], channel_dict['down'][2], channel_dict['face_up'][1][0], channel_dict['face_up'][1][1], channel_dict['face_up'][1][2], channel_dict['face_up'][1][3], num_conv=channel_dict['face_up'][1][4])
|
135 |
+
self.FaceUpBlock2 = UpBlock(channel_dict['face_up'][1][3], channel_dict['down'][1], channel_dict['face_up'][2][0], channel_dict['face_up'][2][1], channel_dict['face_up'][2][2], channel_dict['face_up'][2][3], num_conv=channel_dict['face_up'][2][4])
|
136 |
+
|
137 |
+
self.UpBlock1 = mnUpBlock(channel_dict['hair_up'][2][3] + channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['up'][0][0], channel_dict['up'][0][1], channel_dict['up'][0][2], channel_dict['up'][0][3], num_conv=channel_dict['up'][0][4])
|
138 |
+
self.outconv = mnConvOutBlock(channel_dict['up'][0][3], out_channels)
|
139 |
+
|
140 |
+
#self.shortcut_ratio = [1,1,1,1]
|
141 |
+
|
142 |
+
if self.with_hair_branch:
|
143 |
+
self.HairUpBlock1 = UpBlock(channel_dict['hair_up'][2][3], None, channel_dict['hair_up'][3][0], None, channel_dict['hair_up'][3][2],channel_dict['hair_up'][3][3])
|
144 |
+
self.Hairoutconv = ConvOutBlock(channel_dict['hair_up'][3][3], hair_branch_out_channels)
|
145 |
+
|
146 |
+
self.FaceUpBlock1 = UpBlock(channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['face_up'][3][0], channel_dict['face_up'][3][1], channel_dict['face_up'][3][2], channel_dict['face_up'][3][3])
|
147 |
+
self.Faceoutconv = ConvOutBlock(channel_dict['face_up'][3][3], face_branch_out_channels)
|
148 |
+
|
149 |
+
self.up = UpsampleQuant(scale_factor=1.5, mode='bilinear')
|
150 |
+
|
151 |
+
|
152 |
+
def forward(self, x , with_hair=False):
|
153 |
+
x0 = self.inconv(x)
|
154 |
+
x1 = self.DownBlock1(x0)
|
155 |
+
x2 = self.DownBlock2(x1)
|
156 |
+
x3 = self.DownBlock3(x2)
|
157 |
+
x4 = self.DownBlock4(x3)
|
158 |
+
x = self.ResBlock(x4)
|
159 |
+
|
160 |
+
hair = self.HairUpBlock4(x)
|
161 |
+
hair = self.HairUpBlock3(hair)
|
162 |
+
hair = self.HairUpBlock2(hair)
|
163 |
+
face = self.FaceUpBlock4(x, x3, self.shortcut_ratio[0])
|
164 |
+
face = self.FaceUpBlock3(face, x2, self.shortcut_ratio[1])
|
165 |
+
face = self.FaceUpBlock2(face, x1, self.shortcut_ratio[2])
|
166 |
+
|
167 |
+
x = self.UpBlock1(torch.cat([hair,face], dim=1), x0, self.shortcut_ratio[3])
|
168 |
+
# print(self.outconv)
|
169 |
+
x = self.outconv(x)
|
170 |
+
if not with_hair or not self.with_hair_branch:
|
171 |
+
return x
|
172 |
+
else:
|
173 |
+
hair = self.HairUpBlock1(hair)
|
174 |
+
hair = self.Hairoutconv(hair)
|
175 |
+
|
176 |
+
face = self.FaceUpBlock1(face, x0, self.shortcut_ratio[3])
|
177 |
+
face = self.Faceoutconv(face)
|
178 |
+
|
179 |
+
return x, hair, face
|
180 |
+
|
181 |
+
|
182 |
+
class ConvBlock(nn.Module):
|
183 |
+
def __init__(self, in_ch, out_ch, stride):
|
184 |
+
super(ConvBlock, self).__init__()
|
185 |
+
self.conv = nn.Sequential(
|
186 |
+
#nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
|
187 |
+
Conv2dQuant(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=True),
|
188 |
+
nn.BatchNorm2d(out_ch),
|
189 |
+
HardQuant(0, 4)
|
190 |
+
#nn.ReLU(False))
|
191 |
+
)
|
192 |
+
def forward(self, x):
|
193 |
+
x = self.conv(x)
|
194 |
+
return x
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
class UpBlock(nn.Module):
|
199 |
+
def __init__(self, in_ch1, in_ch2, mid_ch1, mid_ch2, mid_ch, out_ch, num_conv=1, use_bn=True):
|
200 |
+
super(UpBlock, self).__init__()
|
201 |
+
|
202 |
+
#self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
203 |
+
self.up = UpsampleQuant(scale_factor=2, mode='nearest')
|
204 |
+
|
205 |
+
## branch_1
|
206 |
+
if mid_ch1 is None or in_ch1 == mid_ch1:
|
207 |
+
self.conv1 = None
|
208 |
+
else:
|
209 |
+
self.conv1 = nn.Sequential(
|
210 |
+
#nn.Conv2d(in_ch1, mid_ch1, 1, bias=False),
|
211 |
+
Conv2dQuant(in_ch1, mid_ch1, 1, bias=True),
|
212 |
+
#nn.ReLU(False),
|
213 |
+
HardQuant(0, 4)
|
214 |
+
)
|
215 |
+
if mid_ch1 is None:
|
216 |
+
mid_ch1 = in_ch1
|
217 |
+
|
218 |
+
if in_ch2 is None:
|
219 |
+
self.use_shortcut = False
|
220 |
+
self.conv2 = None
|
221 |
+
else:
|
222 |
+
self.use_shortcut = True
|
223 |
+
if mid_ch2 is None or in_ch2 == mid_ch2:
|
224 |
+
self.conv2 = None
|
225 |
+
else:
|
226 |
+
self.conv2 = nn.Sequential(
|
227 |
+
#nn.Conv2d(in_ch2, mid_ch2, 1, bias=False),
|
228 |
+
Conv2dQuant(in_ch2, mid_ch2, 1, bias=True),
|
229 |
+
#nn.ReLU(False),
|
230 |
+
HardQuant(0, 4)
|
231 |
+
)
|
232 |
+
if mid_ch2 is None:
|
233 |
+
mid_ch2 = in_ch2
|
234 |
+
#print(self.conv1 is None, self.conv2 is None)
|
235 |
+
combine_ch = mid_ch1
|
236 |
+
if self.use_shortcut:
|
237 |
+
combine_ch = combine_ch + mid_ch2
|
238 |
+
if mid_ch is None or combine_ch == mid_ch:
|
239 |
+
self.conv_combine = None
|
240 |
+
mid_ch = combine_ch
|
241 |
+
else:
|
242 |
+
self.conv_combine = nn.Sequential(
|
243 |
+
#nn.Conv2d(combine_ch, mid_ch, 1, bias=False),
|
244 |
+
Conv2dQuant(combine_ch, mid_ch, 1, bias=True),
|
245 |
+
#nn.ReLU(False),
|
246 |
+
HardQuant(0, 4)
|
247 |
+
)
|
248 |
+
|
249 |
+
conv_list = []
|
250 |
+
#conv_list.append(nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False))
|
251 |
+
conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True))
|
252 |
+
# conv_list.append(mnconv(mid_ch, out_ch, k=3, s=1, p=1))
|
253 |
+
|
254 |
+
if use_bn:
|
255 |
+
conv_list.append(nn.BatchNorm2d(out_ch))
|
256 |
+
#conv_list.append(nn.ReLU(False))
|
257 |
+
conv_list.append(HardQuant(0, 4))
|
258 |
+
for n in range(1, num_conv):
|
259 |
+
conv_list.append(ResnetBlock(out_ch, out_ch, use_bias=False, use_se = False, use_bn=use_bn))
|
260 |
+
self.conv = nn.Sequential(*conv_list)
|
261 |
+
|
262 |
+
def forward(self, x1, x2=None, ratio=None):
|
263 |
+
|
264 |
+
if self.conv1 is not None:
|
265 |
+
x1 = self.conv1(x1)
|
266 |
+
x1 = self.up(x1)
|
267 |
+
|
268 |
+
if self.use_shortcut:
|
269 |
+
if self.conv2 is not None:
|
270 |
+
x2 = self.conv2(x2)
|
271 |
+
|
272 |
+
if self.use_shortcut:
|
273 |
+
if ratio is None:
|
274 |
+
x = torch.cat([x1, x2], dim=1)
|
275 |
+
else:
|
276 |
+
x = torch.cat([x1, x2], dim=1)
|
277 |
+
else:
|
278 |
+
x = x1
|
279 |
+
|
280 |
+
if self.conv_combine is not None:
|
281 |
+
x = self.conv_combine(x)
|
282 |
+
|
283 |
+
x = self.conv(x)
|
284 |
+
return x
|
285 |
+
|
286 |
+
class mnUpBlock(nn.Module):
|
287 |
+
def __init__(self, in_ch1, in_ch2, mid_ch1, mid_ch2, mid_ch, out_ch, num_conv=1, use_bn=True):
|
288 |
+
super(mnUpBlock, self).__init__()
|
289 |
+
|
290 |
+
#self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
291 |
+
self.up = UpsampleQuant(scale_factor=2, mode='nearest')
|
292 |
+
|
293 |
+
## branch_1
|
294 |
+
if mid_ch1 is None or in_ch1 == mid_ch1:
|
295 |
+
self.conv1 = None
|
296 |
+
else:
|
297 |
+
self.conv1 = nn.Sequential(
|
298 |
+
#nn.Conv2d(in_ch1, mid_ch1, 1, bias=False),
|
299 |
+
Conv2dQuant(in_ch1, mid_ch1, 1, bias=True),
|
300 |
+
#nn.ReLU(False),
|
301 |
+
HardQuant(0, 4)
|
302 |
+
)
|
303 |
+
if mid_ch1 is None:
|
304 |
+
mid_ch1 = in_ch1
|
305 |
+
|
306 |
+
if in_ch2 is None:
|
307 |
+
self.use_shortcut = False
|
308 |
+
self.conv2 = None
|
309 |
+
else:
|
310 |
+
self.use_shortcut = True
|
311 |
+
if mid_ch2 is None or in_ch2 == mid_ch2:
|
312 |
+
self.conv2 = None
|
313 |
+
else:
|
314 |
+
self.conv2 = nn.Sequential(
|
315 |
+
#nn.Conv2d(in_ch2, mid_ch2, 1, bias=False),
|
316 |
+
Conv2dQuant(in_ch2, mid_ch2, 1, bias=True),
|
317 |
+
#nn.ReLU(False),
|
318 |
+
HardQuant(0, 4)
|
319 |
+
)
|
320 |
+
if mid_ch2 is None:
|
321 |
+
mid_ch2 = in_ch2
|
322 |
+
#print(self.conv1 is None, self.conv2 is None)
|
323 |
+
combine_ch = mid_ch1
|
324 |
+
if self.use_shortcut:
|
325 |
+
combine_ch = combine_ch + mid_ch2
|
326 |
+
if mid_ch is None or combine_ch == mid_ch:
|
327 |
+
self.conv_combine = None
|
328 |
+
mid_ch = combine_ch
|
329 |
+
else:
|
330 |
+
self.conv_combine = nn.Sequential(
|
331 |
+
#nn.Conv2d(combine_ch, mid_ch, 1, bias=False),
|
332 |
+
Conv2dQuant(combine_ch, mid_ch, 1, bias=True),
|
333 |
+
#nn.ReLU(False),
|
334 |
+
HardQuant(0, 4)
|
335 |
+
)
|
336 |
+
|
337 |
+
conv_list = []
|
338 |
+
#conv_list.append(nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False))
|
339 |
+
# conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True))
|
340 |
+
conv_list.append(Conv2dQuant(mid_ch, mid_ch, kernel_size=3, stride=1, padding=1, groups=mid_ch,bias=True))
|
341 |
+
conv_list.append(nn.BatchNorm2d(mid_ch))
|
342 |
+
conv_list.append(HardQuant(0, 4))
|
343 |
+
conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=1,stride=1,padding=0))
|
344 |
+
# conv_list.append(mnconv(mid_ch, out_ch, k=3, s=1, p=1))
|
345 |
+
|
346 |
+
if use_bn:
|
347 |
+
conv_list.append(nn.BatchNorm2d(out_ch))
|
348 |
+
#conv_list.append(nn.ReLU(False))
|
349 |
+
conv_list.append(HardQuant(0, 4))
|
350 |
+
for n in range(1, num_conv):
|
351 |
+
conv_list.append(ResnetBlock(out_ch, out_ch, use_bias=False, use_se = False, use_bn=use_bn))
|
352 |
+
self.conv = nn.Sequential(*conv_list)
|
353 |
+
|
354 |
+
def forward(self, x1, x2=None, ratio=None):
|
355 |
+
|
356 |
+
if self.conv1 is not None:
|
357 |
+
x1 = self.conv1(x1)
|
358 |
+
x1 = self.up(x1)
|
359 |
+
|
360 |
+
if self.use_shortcut:
|
361 |
+
if self.conv2 is not None:
|
362 |
+
x2 = self.conv2(x2)
|
363 |
+
|
364 |
+
if self.use_shortcut:
|
365 |
+
if ratio is None:
|
366 |
+
x = torch.cat([x1, x2], dim=1)
|
367 |
+
else:
|
368 |
+
x = torch.cat([x1, x2], dim=1)
|
369 |
+
else:
|
370 |
+
x = x1
|
371 |
+
|
372 |
+
if self.conv_combine is not None:
|
373 |
+
x = self.conv_combine(x)
|
374 |
+
|
375 |
+
x = self.conv(x)
|
376 |
+
return x
|
377 |
+
|
378 |
+
class mnConvOutBlock(nn.Module):
|
379 |
+
def __init__(self, in_ch, out_ch):
|
380 |
+
super(mnConvOutBlock, self).__init__()
|
381 |
+
self.conv = nn.Sequential(
|
382 |
+
Conv2dQuant(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch, bias=True),
|
383 |
+
nn.BatchNorm2d(in_ch),
|
384 |
+
HardQuant(0, 4),
|
385 |
+
Conv2dQuant(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False),
|
386 |
+
#nn.Tanh()
|
387 |
+
TanhOp(data_in_type='float', data_out_type='fixed'),
|
388 |
+
nn.Upsample(scale_factor=2, mode='bilinear'),
|
389 |
+
nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
|
390 |
+
nn.Tanh()
|
391 |
+
)
|
392 |
+
def forward(self, x0):
|
393 |
+
x0 = self.conv(x0)
|
394 |
+
return x0
|
395 |
+
|
396 |
+
class ResnetBlock(nn.Module):
|
397 |
+
def __init__(self, dim, dim_out, use_bias, use_se = False, use_bn=True):
|
398 |
+
super(ResnetBlock, self).__init__()
|
399 |
+
conv_block = []
|
400 |
+
#conv_block += [nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias),]
|
401 |
+
conv_block += [Conv2dQuant(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),]
|
402 |
+
if use_bn:
|
403 |
+
conv_block += [nn.BatchNorm2d(dim),]
|
404 |
+
#conv_block += [nn.ReLU(False)]
|
405 |
+
conv_block += [HardQuant(0, 4)]
|
406 |
+
|
407 |
+
#conv_block.append(nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=use_bias))
|
408 |
+
conv_block.append(Conv2dQuant(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=True))
|
409 |
+
if use_bn:
|
410 |
+
conv_block.append(nn.BatchNorm2d(dim_out))
|
411 |
+
conv_block += [HardQuant(0, 4)]
|
412 |
+
|
413 |
+
if use_se:
|
414 |
+
conv_block.append(SqEx(dim_out, 4))
|
415 |
+
|
416 |
+
self.conv_block = nn.Sequential(*conv_block)
|
417 |
+
|
418 |
+
self.downsample = None
|
419 |
+
if dim != dim_out:
|
420 |
+
if use_bn:
|
421 |
+
self.downsample = nn.Sequential(
|
422 |
+
#nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
|
423 |
+
#nn.BatchNorm2d(dim_out),
|
424 |
+
Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
|
425 |
+
nn.BatchNorm2d(dim_out),
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
self.downsample = nn.Sequential(
|
429 |
+
#nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
|
430 |
+
Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
|
431 |
+
)
|
432 |
+
|
433 |
+
#self.relu = nn.ReLU(False)
|
434 |
+
#self.relu = HardQuant(0, 4)
|
435 |
+
|
436 |
+
def forward(self, x):
|
437 |
+
if self.downsample is None:
|
438 |
+
y = AVG(x, self.conv_block(x))
|
439 |
+
else:
|
440 |
+
y = AVG(self.downsample(x), self.conv_block(x))
|
441 |
+
#y = self.relu(y)
|
442 |
+
return y
|
models/modules/stylegan2/__pycache__/model.cpython-38.pyc
ADDED
Binary file (16.3 kB)
models/modules/stylegan2/__pycache__/non_leaking.cpython-38.pyc
ADDED
Binary file (11 kB)
models/modules/stylegan2/model.py
ADDED
@@ -0,0 +1,716 @@
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import functools
|
4 |
+
import operator
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torch.autograd import Function
|
10 |
+
|
11 |
+
from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
|
12 |
+
|
13 |
+
|
14 |
+
class PixelNorm(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
def forward(self, input):
|
19 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
20 |
+
|
21 |
+
|
22 |
+
def make_kernel(k):
|
23 |
+
k = torch.tensor(k, dtype=torch.float32)
|
24 |
+
|
25 |
+
if k.ndim == 1:
|
26 |
+
k = k[None, :] * k[:, None]
|
27 |
+
|
28 |
+
k /= k.sum()
|
29 |
+
|
30 |
+
return k
|
31 |
+
|
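make_kernel turns a 1-D tap list into a normalized 2-D blur kernel by taking the outer product and dividing by the sum; for the default [1, 3, 3, 1] taps:

    from models.modules.stylegan2.model import make_kernel

    k = make_kernel([1, 3, 3, 1])
    print(k.sum())   # tensor(1.)
    print(k)         # 4x4 outer product of [1, 3, 3, 1] / 64, peak value 9/64 = 0.1406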
32 |
+
|
33 |
+
class Upsample(nn.Module):
|
34 |
+
def __init__(self, kernel, factor=2):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.factor = factor
|
38 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
39 |
+
self.register_buffer("kernel", kernel)
|
40 |
+
|
41 |
+
p = kernel.shape[0] - factor
|
42 |
+
|
43 |
+
pad0 = (p + 1) // 2 + factor - 1
|
44 |
+
pad1 = p // 2
|
45 |
+
|
46 |
+
self.pad = (pad0, pad1)
|
47 |
+
|
48 |
+
def forward(self, input):
|
49 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
50 |
+
|
51 |
+
return out
|
52 |
+
|
53 |
+
|
54 |
+
class Downsample(nn.Module):
|
55 |
+
def __init__(self, kernel, factor=2):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
self.factor = factor
|
59 |
+
kernel = make_kernel(kernel)
|
60 |
+
self.register_buffer("kernel", kernel)
|
61 |
+
|
62 |
+
p = kernel.shape[0] - factor
|
63 |
+
|
64 |
+
pad0 = (p + 1) // 2
|
65 |
+
pad1 = p // 2
|
66 |
+
|
67 |
+
self.pad = (pad0, pad1)
|
68 |
+
|
69 |
+
def forward(self, input):
|
70 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
71 |
+
|
72 |
+
return out
|
73 |
+
|
74 |
+
|
75 |
+
class Blur(nn.Module):
|
76 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
kernel = make_kernel(kernel)
|
80 |
+
|
81 |
+
if upsample_factor > 1:
|
82 |
+
kernel = kernel * (upsample_factor ** 2)
|
83 |
+
|
84 |
+
self.register_buffer("kernel", kernel)
|
85 |
+
|
86 |
+
self.pad = pad
|
87 |
+
|
88 |
+
def forward(self, input):
|
89 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class EqualConv2d(nn.Module):
|
95 |
+
def __init__(
|
96 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
97 |
+
):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.weight = nn.Parameter(
|
101 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
102 |
+
)
|
103 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
104 |
+
|
105 |
+
self.stride = stride
|
106 |
+
self.padding = padding
|
107 |
+
|
108 |
+
if bias:
|
109 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
110 |
+
|
111 |
+
else:
|
112 |
+
self.bias = None
|
113 |
+
|
114 |
+
def forward(self, input):
|
115 |
+
out = conv2d_gradfix.conv2d(
|
116 |
+
input,
|
117 |
+
self.weight * self.scale,
|
118 |
+
bias=self.bias,
|
119 |
+
stride=self.stride,
|
120 |
+
padding=self.padding,
|
121 |
+
)
|
122 |
+
|
123 |
+
return out
|
124 |
+
|
125 |
+
def __repr__(self):
|
126 |
+
return (
|
127 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
128 |
+
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
129 |
+
)
|
130 |
+
|
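EqualConv2d implements the equalized-learning-rate trick: weights are stored as unit-normal values and rescaled by 1/sqrt(fan_in) at every forward call instead of being initialized small. A small check; this assumes the fused ops under .op fall back to the native PyTorch kernels on a plain CPU build, as they do in the upstream stylegan2-pytorch code:

    import math
    import torch
    from models.modules.stylegan2.model import EqualConv2d

    conv = EqualConv2d(in_channel=64, out_channel=128, kernel_size=3, padding=1)
    print(conv.scale, 1 / math.sqrt(64 * 3 ** 2))   # identical: the scaling happens at runtime
    print(conv(torch.randn(1, 64, 8, 8)).shape)     # torch.Size([1, 128, 8, 8])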
131 |
+
|
132 |
+
class EqualLinear(nn.Module):
|
133 |
+
def __init__(
|
134 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
|
138 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
139 |
+
|
140 |
+
if bias:
|
141 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
142 |
+
|
143 |
+
else:
|
144 |
+
self.bias = None
|
145 |
+
|
146 |
+
self.activation = activation
|
147 |
+
|
148 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
149 |
+
self.lr_mul = lr_mul
|
150 |
+
|
151 |
+
def forward(self, input):
|
152 |
+
if self.activation:
|
153 |
+
out = F.linear(input, self.weight * self.scale)
|
154 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
155 |
+
|
156 |
+
else:
|
157 |
+
out = F.linear(
|
158 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
159 |
+
)
|
160 |
+
|
161 |
+
return out
|
162 |
+
|
163 |
+
def __repr__(self):
|
164 |
+
return (
|
165 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
|
166 |
+
)
|
167 |
+
|
168 |
+
|
169 |
+
class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        fused=True,
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate
        self.fused = fused

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape

        if not self.fused:
            weight = self.scale * self.weight.squeeze(0)
            style = self.modulation(style)

            if self.demodulate:
                w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
                dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()

            input = input * style.reshape(batch, in_channel, 1, 1)

            if self.upsample:
                weight = weight.transpose(0, 1)
                out = conv2d_gradfix.conv_transpose2d(
                    input, weight, padding=0, stride=2
                )
                out = self.blur(out)

            elif self.downsample:
                input = self.blur(input)
                out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)

            else:
                out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)

            if self.demodulate:
                out = out * dcoefs.view(batch, -1, 1, 1)

            return out

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = conv2d_gradfix.conv_transpose2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(
                input, weight, padding=self.padding, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out

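# NoiseInjection adds per-pixel Gaussian noise scaled by a single learned weight;
# ConstantInput is the learned constant tensor (4x4 by default) that seeds synthesis.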
class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        if type(size) is tuple:
            self.input = nn.Parameter(torch.randn(1, channel, size[0], size[1]))
        else:
            self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out

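# Generator: an n_mlp-layer mapping network (self.style) maps z to w, and the synthesis
# network grows from the 4x4 constant input up to `size`, with two StyledConv layers plus
# one ToRGB skip per resolution. It consumes n_latent = 2 * log2(size) - 2 styles, which
# is what makes style mixing via inject_index in forward() possible.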
class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None

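# ConvLayer and ResBlock are the discriminator building blocks: blur-then-stride-2
# downsampling, equalized-lr convolutions, and residual skips scaled by 1/sqrt(2).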
class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out

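# Discriminator downsamples from `size` to min_feats_size with residual blocks, appends
# the minibatch standard-deviation channel, and maps the flattened features to a single
# realness score; with rtn_feats=True it instead returns intermediate feature maps.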
class Discriminator(nn.Module):
    def __init__(self, size, min_feats_size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))
        if type(min_feats_size) is tuple:
            fsize = min_feats_size[0] * min_feats_size[1]
        else:
            fsize = min_feats_size * min_feats_size

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * fsize, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input, rtn_feats=False):
        if rtn_feats:
            feats = []
            feat = input
            for i, block in enumerate(self.convs):
                feat = block(feat)
                if i in [1, 3, 4, 5]:
                    feats.append(feat)
                if i == 5:
                    break
            return feats

        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
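A minimal usage sketch for the Generator and Discriminator defined above; the 256x256 size, batch of 2, and 512-dim latents are illustrative assumptions, and the fused CUDA ops under models/modules/stylegan2/op must be importable:

import torch
from models.modules.stylegan2.model import Generator, Discriminator

# 256x256 synthesis network: 8-layer mapping MLP, 512-dim style space
g = Generator(size=256, style_dim=512, n_mlp=8)
# discriminator that reduces 256x256 inputs to 4x4 features before the linear head
d = Discriminator(size=256, min_feats_size=4)

z = [torch.randn(2, 512)]   # a single latent batch, so no style mixing
fake, _ = g(z)              # fake: (2, 3, 256, 256)
logits = d(fake)            # logits: (2, 1) realness scores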
models/modules/stylegan2/non_leaking.py
ADDED
@@ -0,0 +1,465 @@
import math

import torch
from torch import autograd
from torch.nn import functional as F
import numpy as np

# from distributed import reduce_sum
from .op import upfirdn2d


# class AdaptiveAugment:
#     def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
#         self.ada_aug_target = ada_aug_target
#         self.ada_aug_len = ada_aug_len
#         self.update_every = update_every

#         self.ada_update = 0
#         self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
#         self.r_t_stat = 0
#         self.ada_aug_p = 0

#     @torch.no_grad()
#     def tune(self, real_pred):
#         self.ada_aug_buf += torch.tensor(
#             (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
#             device=real_pred.device,
#         )
#         self.ada_update += 1

#         if self.ada_update % self.update_every == 0:
#             self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
#             pred_signs, n_pred = self.ada_aug_buf.tolist()

#             self.r_t_stat = pred_signs / n_pred

#             if self.r_t_stat > self.ada_aug_target:
#                 sign = 1

#             else:
#                 sign = -1

#             self.ada_aug_p += sign * n_pred / self.ada_aug_len
#             self.ada_aug_p = min(1, max(0, self.ada_aug_p))
#             self.ada_aug_buf.mul_(0)
#             self.ada_update = 0

#         return self.ada_aug_p


SYM6 = (
    0.015404109327027373,
    0.0034907120842174702,
    -0.11799011114819057,
    -0.048311742585633,
    0.4910559419267466,
    0.787641141030194,
    0.3379294217276218,
    -0.07263752278646252,
    -0.021060292512300564,
    0.04472490177066578,
    0.0017677118642428036,
    -0.007800708325034148,
)


def translate_mat(t_x, t_y, device="cpu"):
    batch = t_x.shape[0]

    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    translate = torch.stack((t_x, t_y), 1)
    mat[:, :2, 2] = translate

    return mat


def rotate_mat(theta, device="cpu"):
    batch = theta.shape[0]

    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    sin_t = torch.sin(theta)
    cos_t = torch.cos(theta)
    rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
    mat[:, :2, :2] = rot

    return mat


def scale_mat(s_x, s_y, device="cpu"):
    batch = s_x.shape[0]

    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, 0, 0] = s_x
    mat[:, 1, 1] = s_y

    return mat


def translate3d_mat(t_x, t_y, t_z):
    batch = t_x.shape[0]

    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    translate = torch.stack((t_x, t_y, t_z), 1)
    mat[:, :3, 3] = translate

    return mat


def rotate3d_mat(axis, theta):
    batch = theta.shape[0]

    u_x, u_y, u_z = axis

    eye = torch.eye(3).unsqueeze(0)
    cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
    outer = torch.tensor(axis)
    outer = (outer.unsqueeze(1) * outer).unsqueeze(0)

    sin_t = torch.sin(theta).view(-1, 1, 1)
    cos_t = torch.cos(theta).view(-1, 1, 1)

    rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer

    eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    eye_4[:, :3, :3] = rot

    return eye_4


def scale3d_mat(s_x, s_y, s_z):
    batch = s_x.shape[0]

    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, 0, 0] = s_x
    mat[:, 1, 1] = s_y
    mat[:, 2, 2] = s_z

    return mat


def luma_flip_mat(axis, i):
    batch = i.shape[0]

    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    axis = torch.tensor(axis + (0,))
    flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)

    return eye - flip


def saturation_mat(axis, i):
    batch = i.shape[0]

    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    axis = torch.tensor(axis + (0,))
    axis = torch.ger(axis, axis)
    saturate = axis + (eye - axis) * i.view(-1, 1, 1)

    return saturate


def lognormal_sample(size, mean=0, std=1, device="cpu"):
    return torch.empty(size, device=device).log_normal_(mean=mean, std=std)


def category_sample(size, categories, device="cpu"):
    category = torch.tensor(categories, device=device)
    sample = torch.randint(high=len(categories), size=(size,), device=device)

    return category[sample]


def uniform_sample(size, low, high, device="cpu"):
    return torch.empty(size, device=device).uniform_(low, high)


def normal_sample(size, mean=0, std=1, device="cpu"):
    return torch.empty(size, device=device).normal_(mean, std)


def bernoulli_sample(size, p, device="cpu"):
    return torch.empty(size, device=device).bernoulli_(p)


def random_mat_apply(p, transform, prev, eye, device="cpu"):
    size = transform.shape[0]
    select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
    select_transform = select * transform + (1 - select) * eye

    return select_transform @ prev

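# sample_affine composes the ADA-style geometric pipeline: each candidate transform
# (flip, 90-degree rotation, integer/fractional translation, isotropic and anisotropic
# scaling, pre/post rotation) is applied with probability p via random_mat_apply,
# accumulating into one 3x3 homogeneous matrix G per sample.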
def sample_affine(p, size, height, width, device="cpu"):
    G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
    eye = G

    # flip
    param = category_sample(size, (0, 1))
    Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')

    # 90 rotate
    param = category_sample(size, (0, 3))
    Gc = rotate_mat(-math.pi / 2 * param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')

    # integer translate
    param = uniform_sample((2, size), -0.125, 0.125)
    param_height = torch.round(param[0] * height)
    param_width = torch.round(param[1] * width)
    Gc = translate_mat(param_width, param_height, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')

    # isotropic scale
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('isotropic scale', G, scale_mat(param, param), sep='\n')

    p_rot = 1 - math.sqrt(1 - p)

    # pre-rotate
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param, device=device)
    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
    # print('pre-rotate', G, rotate_mat(-param), sep='\n')

    # anisotropic scale
    param = lognormal_sample(size, std=0.2 * math.log(2))
    Gc = scale_mat(param, 1 / param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')

    # post-rotate
    param = uniform_sample(size, -math.pi, math.pi)
    Gc = rotate_mat(-param, device=device)
    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
    # print('post-rotate', G, rotate_mat(-param), sep='\n')

    # fractional translate
    param = normal_sample((2, size), std=0.125)
    Gc = translate_mat(param[1] * width, param[0] * height, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('fractional translate', G, translate_mat(param, param), sep='\n')

    return G


def sample_color(p, size):
    C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
    eye = C
    axis_val = 1 / math.sqrt(3)
    axis = (axis_val, axis_val, axis_val)

    # brightness
    param = normal_sample(size, std=0.2)
    Cc = translate3d_mat(param, param, param)
    C = random_mat_apply(p, Cc, C, eye)

    # contrast
    param = lognormal_sample(size, std=0.5 * math.log(2))
    Cc = scale3d_mat(param, param, param)
    C = random_mat_apply(p, Cc, C, eye)

    # luma flip
    param = category_sample(size, (0, 1))
    Cc = luma_flip_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)

    # hue rotation
    param = uniform_sample(size, -math.pi, math.pi)
    Cc = rotate3d_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)

    # saturation
    param = lognormal_sample(size, std=1 * math.log(2))
    Cc = saturation_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)

    return C


def make_grid(shape, x0, x1, y0, y1, device):
    n, c, h, w = shape
    grid = torch.empty(n, h, w, 3, device=device)
    grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
    grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
    grid[:, :, :, 2] = 1

    return grid


def affine_grid(grid, mat):
    n, h, w, _ = grid.shape
    return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)


def get_padding(G, height, width, kernel_size):
    device = G.device

    cx = (width - 1) / 2
    cy = (height - 1) / 2
    cp = torch.tensor(
        [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
    )
    cp = G @ cp.T

    pad_k = kernel_size // 4

    pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
    pad = torch.cat((-pad, pad)).max(1).values
    pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
    pad = pad.max(torch.tensor([0.0, 0.0] * 2, device=device))
    pad = pad.min(torch.tensor([width - 1.0, height - 1.0] * 2, device=device))

    pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)

    return pad_x1, pad_x2, pad_y1, pad_y2


def try_sample_affine_and_pad(img, p, kernel_size, G=None):
    batch, _, height, width = img.shape

    G_try = G

    if G is None:
        G_try = torch.inverse(sample_affine(p, batch, height, width))

    pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)

    img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")

    return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)


class GridSampleForward(autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        out = F.grid_sample(
            input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        ctx.save_for_backward(input, grid)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)

        return grad_input, grad_grid


class GridSampleBackward(autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)

        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad_grad_input, grad_grad_grid):
        (grid,) = ctx.saved_tensors
        grad_grad_output = None

        if ctx.needs_input_grad[0]:
            grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)

        return grad_grad_output, None, None


grid_sample = GridSampleForward.apply


def scale_mat_single(s_x, s_y):
    return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)


def translate_mat_single(t_x, t_y):
    return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)


def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
    kernel = antialiasing_kernel
    len_k = len(kernel)

    kernel = torch.as_tensor(kernel).to(img)
    # kernel = torch.ger(kernel, kernel).to(img)
    kernel_flip = torch.flip(kernel, (0,))

    img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
        img, p, len_k, G
    )

    G_inv = (
        translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
        @ G
    )
    up_pad = (
        (len_k + 2 - 1) // 2,
        (len_k - 2) // 2,
        (len_k + 2 - 1) // 2,
        (len_k - 2) // 2,
    )
    img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
    img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
    G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
    G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
    batch_size, channel, height, width = img.shape
    pad_k = len_k // 4
    shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
    G_inv = (
        scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
        @ G_inv
        @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
    )
    grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
    img_affine = grid_sample(img_2x, grid)
    d_p = -pad_k * 2
    down_pad = (
        d_p + (len_k - 2 + 1) // 2,
        d_p + (len_k - 2) // 2,
        d_p + (len_k - 2 + 1) // 2,
        d_p + (len_k - 2) // 2,
    )
    img_down = upfirdn2d(
        img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
    )
    img_down = upfirdn2d(
        img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
    )

    return img_down, G


def apply_color(img, mat):
    batch = img.shape[0]
    img = img.permute(0, 2, 3, 1)
    mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
    mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
    img = img @ mat_mul + mat_add
    img = img.permute(0, 3, 1, 2)

    return img


def random_apply_color(img, p, C=None):
    if C is None:
        C = sample_color(p, img.shape[0])

    img = apply_color(img, C.to(img))

    return img, C

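# augment applies the geometric and color pipelines in sequence and returns the sampled
# matrices, so the exact same augmentation can be replayed on another batch by passing
# transform_matrix=(G, C).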
def augment(img, p, transform_matrix=(None, None)):
    img, G = random_apply_affine(img, p, transform_matrix[0])
    img, C = random_apply_color(img, p, transform_matrix[1])

    return img, (G, C)
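A minimal sketch of how augment above is typically used for discriminator-side augmentation (the batch shape and p=0.5 are illustrative; upfirdn2d from the op package must be importable):

import torch
from models.modules.stylegan2.non_leaking import augment

imgs = torch.randn(4, 3, 256, 256)           # a batch of images in [-1, 1]
aug_imgs, (G, C) = augment(imgs, p=0.5)      # random geometric + color transforms
# replay exactly the same transforms on a second batch by reusing the sampled matrices
aug_pair, _ = augment(torch.randn(4, 3, 256, 256), p=0.5, transform_matrix=(G, C))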
models/modules/stylegan2/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
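For orientation, a small sketch of the fused activation exported here; the exact fallback behaviour depends on fused_act.py (not shown in this diff), so the semantics noted in the comment are an assumption based on the reference StyleGAN2 ops:

import torch
from models.modules.stylegan2.op import fused_leaky_relu

x = torch.randn(2, 64, 16, 16)
bias = torch.zeros(64)
# adds the per-channel bias, applies leaky ReLU (slope 0.2), then scales by sqrt(2)
out = fused_leaky_relu(x, bias)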
models/modules/stylegan2/op/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (275 Bytes)

models/modules/stylegan2/op/__pycache__/conv2d_gradfix.cpython-38.pyc
ADDED
Binary file (5.36 kB)

models/modules/stylegan2/op/__pycache__/fused_act.cpython-38.pyc
ADDED
Binary file (3.29 kB)