limoran committed on
Commit
7e2a2a5
1 Parent(s): 8220eea

add basic files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. configs/__init__.py +30 -0
  2. configs/__pycache__/__init__.cpython-38.pyc +0 -0
  3. configs/__pycache__/base_config.cpython-38.pyc +0 -0
  4. configs/__pycache__/style_based_pix2pixII_config.cpython-38.pyc +0 -0
  5. configs/base_config.py +160 -0
  6. configs/style_based_pix2pixII_config.py +42 -0
  7. data/__init__.py +58 -0
  8. data/__pycache__/__init__.cpython-38.pyc +0 -0
  9. data/__pycache__/static_data.cpython-38.pyc +0 -0
  10. data/__pycache__/super_dataset.cpython-38.pyc +0 -0
  11. data/__pycache__/test_data.cpython-38.pyc +0 -0
  12. data/__pycache__/test_video_data.cpython-38.pyc +0 -0
  13. data/deprecated/custom_data.py +121 -0
  14. data/deprecated/landmark_data.py +89 -0
  15. data/deprecated/numpy_paired_data.py +81 -0
  16. data/deprecated/numpy_unpaired_data.py +100 -0
  17. data/deprecated/paired_data.py +80 -0
  18. data/deprecated/patch_data.py +44 -0
  19. data/deprecated/unpaired_data.py +101 -0
  20. data/static_data.py +457 -0
  21. data/super_dataset.py +321 -0
  22. data/test_data.py +51 -0
  23. data/test_video_data.py +28 -0
  24. exp/sp2pII-phase1.yaml +49 -0
  25. exp/sp2pII-phase2.yaml +49 -0
  26. exp/sp2pII-phase3.yaml +50 -0
  27. exp/sp2pII-phase4.yaml +49 -0
  28. logs/01_2023_09_07__18_32_26/events.out.tfevents.1694082748.aiplatform-wlf2-hi-12.idchb2az2.hb2.kwaidc.com.16044.0 +3 -0
  29. logs/01_2023_09_12__14_54_32/events.out.tfevents.1694501684.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.76748.0 +3 -0
  30. logs/01_2023_09_12__14_55_34/events.out.tfevents.1694501736.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77369.0 +3 -0
  31. logs/01_2023_09_12__15_03_47/events.out.tfevents.1694502229.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77940.0 +3 -0
  32. models/__init__.py +68 -0
  33. models/__pycache__/__init__.cpython-38.pyc +0 -0
  34. models/__pycache__/base_model.cpython-38.pyc +0 -0
  35. models/__pycache__/style_based_pix2pixII_model.cpython-38.pyc +0 -0
  36. models/base_model.py +340 -0
  37. models/modules/__init__.py +0 -0
  38. models/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  39. models/modules/__pycache__/networks.cpython-38.pyc +0 -0
  40. models/modules/networks.py +1101 -0
  41. models/modules/sr/light_model_270M.py +347 -0
  42. models/modules/sr/light_model_470M.py +442 -0
  43. models/modules/stylegan2/__pycache__/model.cpython-38.pyc +0 -0
  44. models/modules/stylegan2/__pycache__/non_leaking.cpython-38.pyc +0 -0
  45. models/modules/stylegan2/model.py +716 -0
  46. models/modules/stylegan2/non_leaking.py +465 -0
  47. models/modules/stylegan2/op/__init__.py +2 -0
  48. models/modules/stylegan2/op/__pycache__/__init__.cpython-38.pyc +0 -0
  49. models/modules/stylegan2/op/__pycache__/conv2d_gradfix.cpython-38.pyc +0 -0
  50. models/modules/stylegan2/op/__pycache__/fused_act.cpython-38.pyc +0 -0
configs/__init__.py ADDED
@@ -0,0 +1,30 @@
import importlib
from configs.base_config import BaseConfig

def find_config_by_name(config_name):
    # load config lib by config name
    config_file = "configs." + config_name + '_config'
    config_lib = importlib.import_module(config_file)
    print(config_lib)

    # find the subclass of BaseConfig
    config = None
    target_config_name = config_name.replace('_', '') + 'config'
    target_config_name = target_config_name.lower()
    for name, cls in config_lib.__dict__.items():
        if name.lower() == target_config_name and issubclass(cls, BaseConfig):
            config = cls

    if config is None:
        raise Exception('No valid config found.')

    return config

def parse_config(cfg_file):
    # parse config using BaseConfig
    cfg = BaseConfig().parse_config(cfg_file)
    model_name = cfg['common']['model']

    # re-parse using the specified Config class
    config = find_config_by_name(model_name)
    return config().parse_config(cfg_file)
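
A rough usage sketch (not part of this commit): a training script would call parse_config on one of the experiment files added here. It assumes exp/sp2pII-phase1.yaml sets common.model to a name that follows the <model>_config.py / <Model>Config naming convention used by find_config_by_name; the file's actual contents are not shown in this diff.

    from configs import parse_config

    # The first pass uses BaseConfig to read common.model; the second pass re-parses the
    # same file with the matching subclass (e.g. StyleBasedPix2PixIIConfig below).
    cfg = parse_config('exp/sp2pII-phase1.yaml')
    print(cfg['common']['model'])
    print(cfg['training']['lr'])   # options the YAML omits keep their BaseConfig defaults
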
configs/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (900 Bytes).
 
configs/__pycache__/base_config.cpython-38.pyc ADDED
Binary file (4.96 kB).
 
configs/__pycache__/style_based_pix2pixII_config.cpython-38.pyc ADDED
Binary file (2.13 kB).
 
configs/base_config.py ADDED
@@ -0,0 +1,160 @@
import yaml
import copy
from typing import Union

class BaseConfig():

    def __init__(self):
        self.__config_dict = {}
        self.__check_func_dict = {}

        is_greater_than_0 = lambda x: x > 0

        # common config
        self._add_option('common', 'name', str, 'style_master')
        self._add_option('common', 'model', str, 'cycle_gan')
        self._add_option('common', 'phase', str, 'train', check_func=lambda x: x in ['train', 'test'])
        self._add_option('common', 'gpu_ids', list, [0])
        self._add_option('common', 'verbose', bool, False)

        # model config
        self._add_option('model', 'input_nc', int, 3, check_func=is_greater_than_0)
        self._add_option('model', 'output_nc', int, 3, check_func=is_greater_than_0)

        # dataset config
        # common dataset options
        self._add_option('dataset', 'use_absolute_datafile', bool, True)
        self._add_option('dataset', 'batch_size', int, 1, check_func=is_greater_than_0)
        self._add_option('dataset', 'n_threads', int, 4, check_func=is_greater_than_0)
        self._add_option('dataset', 'dataroot', str, './')
        self._add_option('dataset', 'drop_last', bool, False)
        self._add_option('dataset', 'landmark_scale', list, None)
        self._add_option('dataset', 'check_all_data', bool, False)
        # Upon loading a bad sample: if True, the dataloader catches the exception and
        # loads the next good sample instead; if False, the process will crash.
        self._add_option('dataset', 'accept_data_error', bool, True)

        self._add_option('dataset', 'train_data', dict, {})
        self._add_option('dataset', 'val_data', dict, {})

        # paired data config
        self._add_option('dataset', 'paired_trainA_folder', str, '')
        self._add_option('dataset', 'paired_trainB_folder', str, '')
        self._add_option('dataset', 'paired_train_filelist', str, '')
        self._add_option('dataset', 'paired_valA_folder', str, '')
        self._add_option('dataset', 'paired_valB_folder', str, '')
        self._add_option('dataset', 'paired_val_filelist', str, '')

        # unpaired data config
        self._add_option('dataset', 'unpaired_trainA_folder', str, '')
        self._add_option('dataset', 'unpaired_trainB_folder', str, '')
        self._add_option('dataset', 'unpaired_trainA_filelist', str, '')
        self._add_option('dataset', 'unpaired_trainB_filelist', str, '')
        self._add_option('dataset', 'unpaired_valA_folder', str, '')
        self._add_option('dataset', 'unpaired_valB_folder', str, '')
        self._add_option('dataset', 'unpaired_valA_filelist', str, '')
        self._add_option('dataset', 'unpaired_valB_filelist', str, '')

        # custom data
        self._add_option('dataset', 'custom_train_data', dict, {})
        self._add_option('dataset', 'custom_val_data', dict, {})

        # training config
        self._add_option('training', 'checkpoints_dir', str, './checkpoints')
        self._add_option('training', 'log_dir', str, './logs')
        self._add_option('training', 'use_new_log', bool, False)
        self._add_option('training', 'continue_train', bool, False)
        self._add_option('training', 'which_epoch', str, 'latest')
        self._add_option('training', 'n_epochs', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'n_epochs_decay', int, 100, check_func=is_greater_than_0)
        self._add_option('training', 'save_latest_freq', int, 5000, check_func=is_greater_than_0)
        self._add_option('training', 'print_freq', int, 200, check_func=is_greater_than_0)
        self._add_option('training', 'save_epoch_freq', int, 5, check_func=is_greater_than_0)
        self._add_option('training', 'epoch_as_iter', bool, False)
        self._add_option('training', 'lr', float, 2e-4, check_func=is_greater_than_0)
        self._add_option('training', 'lr_policy', str, 'linear',
                         check_func=lambda x: x in ['linear', 'step', 'plateau', 'cosine'])
        self._add_option('training', 'lr_decay_iters', int, 50, check_func=is_greater_than_0)
        self._add_option('training', 'DDP', bool, False)
        self._add_option('training', 'num_nodes', int, 1, check_func=is_greater_than_0)
        self._add_option('training', 'DDP_address', str, '127.0.0.1')
        self._add_option('training', 'DDP_port', str, '29700')
        self._add_option('training', 'find_unused_parameters', bool, False)  # a DDP option that allows backward on a subgraph of the model
        self._add_option('training', 'val_percent', float, 5.0, check_func=is_greater_than_0)  # use x% of the training data to validate
        self._add_option('training', 'val', bool, True)  # perform validation every epoch
        self._add_option('training', 'save_training_progress', bool, False)  # save images to create a training progression video

        # testing config
        self._add_option('testing', 'results_dir', str, './results')
        self._add_option('testing', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('testing', 'preprocess', list, ['scale_width'])
        self._add_option('testing', 'visual_names', list, [])
        self._add_option('testing', 'num_test', int, 999999, check_func=is_greater_than_0)
        self._add_option('testing', 'image_format', str, 'jpg', check_func=lambda x: x in ['input', 'jpg', 'jpeg', 'png'])

    def _add_option(self, group_name, option_name, value_type, default_value, check_func=None):
        # check name types
        if not type(group_name) is str or not type(option_name) is str:
            raise Exception('Type of {} and {} must be str.'.format(group_name, option_name))

        # add group
        if not group_name in self.__config_dict:
            self.__config_dict[group_name] = {}
            self.__check_func_dict[group_name] = {}

        # check type & default value
        if not type(value_type) is type:
            try:
                if value_type.__origin__ is not Union:
                    raise Exception('{} is not a type.'.format(value_type))
            except Exception as e:
                print(e)
        if not type(default_value) is value_type:
            try:
                if value_type.__origin__ is not Union:
                    raise Exception('Type of {} must be {}.'.format(default_value, value_type))
            except Exception as e:
                print(e)

        # add option to dict
        if not option_name in self.__config_dict[group_name]:
            if not check_func is None and not check_func(default_value):
                raise Exception('Checking {}/{} failed.'.format(group_name, option_name))
            self.__config_dict[group_name][option_name] = default_value
            self.__check_func_dict[group_name][option_name] = check_func
        else:
            raise Exception('{} has already been added.'.format(option_name))

    def parse_config(self, cfg_file):
        # load config from yaml file
        with open(cfg_file, 'r') as f:
            yaml_config = yaml.safe_load(f)
        if not type(yaml_config) is dict:
            raise Exception('Loading yaml file failed.')

        # replace default options
        config_dict = copy.deepcopy(self.__config_dict)
        for group in config_dict:
            if group in yaml_config:
                for option in config_dict[group]:
                    if option in yaml_config[group]:
                        value = yaml_config[group][option]
                        if not type(value) is type(config_dict[group][option]):
                            # A non-Union default has no __origin__ attribute, so the access
                            # below raises and the except branch reports the problem.
                            try:
                                if config_dict[group][option].__origin__ is Union:
                                    # check whether the type of <value> is one of the types in the Union
                                    if not isinstance(value, config_dict[group][option].__args__):
                                        raise Exception('Type of {}/{} must be {}.'.format(group, option,
                                                        config_dict[group][option].__args__))
                            except Exception as e:  # if an error was thrown, the type check failed
                                print(e)
                        else:
                            check_func = self.__check_func_dict[group][option]
                            if not check_func is None and not check_func(value):
                                raise Exception('Checking {}/{} failed.'.format(group, option))
                        config_dict[group][option] = value
        return config_dict
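
parse_config only overrides options that were registered with _add_option, so a config file mirrors the group/option layout above. A minimal illustrative example follows (the values are made up; they are not the contents of the exp/sp2pII-*.yaml files in this commit), loaded the same way the method does:

    import yaml

    example = '''
    common:
      name: style_master
      model: style_based_pix2pixII
      phase: train
      gpu_ids: [0]
    dataset:
      batch_size: 4
      dataroot: ./datasets/portraits
    training:
      lr: 0.0002
      lr_policy: cosine
    '''
    cfg = yaml.safe_load(example)   # same loader used by parse_config
    assert set(cfg) <= {'common', 'model', 'dataset', 'training', 'testing'}
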
configs/style_based_pix2pixII_config.py ADDED
@@ -0,0 +1,42 @@
from .base_config import BaseConfig
from typing import Union as Union

class StyleBasedPix2PixIIConfig(BaseConfig):

    def __init__(self):
        super(StyleBasedPix2PixIIConfig, self).__init__()

        is_greater_than_0 = lambda x: x > 0

        # model config
        self._add_option('model', 'ngf', int, 64, check_func=is_greater_than_0)
        self._add_option('model', 'min_feats_size', list, [4, 4])

        # dataset config
        self._add_option('dataset', 'data_type', list, ['unpaired'])
        self._add_option('dataset', 'direction', str, 'AtoB')
        self._add_option('dataset', 'serial_batches', bool, False)
        self._add_option('dataset', 'load_size', int, 512, check_func=is_greater_than_0)
        self._add_option('dataset', 'crop_size', int, 512, check_func=is_greater_than_0)
        self._add_option('dataset', 'preprocess', Union[list, str], ['resize'])
        self._add_option('dataset', 'no_flip', bool, True)

        # training config
        self._add_option('training', 'beta1', float, 0.1, check_func=is_greater_than_0)
        self._add_option('training', 'data_aug_prob', float, 0.0, check_func=lambda x: x >= 0.0)
        self._add_option('training', 'style_mixing_prob', float, 0.0, check_func=lambda x: x >= 0.0)
        self._add_option('training', 'phase', int, 1, check_func=lambda x: x in [1, 2, 3, 4])
        self._add_option('training', 'pretrained_model', str, 'model.pth')
        self._add_option('training', 'src_text_prompt', str, 'photo')
        self._add_option('training', 'text_prompt', str, 'a portrait in style of sketch')
        self._add_option('training', 'image_prompt', str, 'style.png')
        self._add_option('training', 'lambda_L1', float, 1.0)
        self._add_option('training', 'lambda_Feat', float, 4.0)
        self._add_option('training', 'lambda_ST', float, 1.0)
        self._add_option('training', 'lambda_GAN', float, 1.0)
        self._add_option('training', 'lambda_CLIP', float, 1.0)
        self._add_option('training', 'lambda_PROJ', float, 1.0)
        self._add_option('training', 'ema', float, 0.999)

        # testing config
        self._add_option('testing', 'aspect_ratio', float, 1.0, check_func=is_greater_than_0)
data/__init__.py ADDED
@@ -0,0 +1,58 @@
"""This package includes all the modules related to data loading and preprocessing.

To add a custom dataset class called 'dummy', add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>: initialize the class; first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of the dataset.
    -- <__getitem__>: get a data point from the data loader.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying the flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler


class CustomDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, config, dataset, DDP_gpu=None, drop_last=False):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.config = config
        self.dataset = dataset

        if DDP_gpu is None:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=config['dataset']['batch_size'],
                shuffle=not config['dataset']['serial_batches'],
                num_workers=int(config['dataset']['n_threads']), drop_last=drop_last)
        else:
            sampler = DistributedSampler(self.dataset, num_replicas=self.config['training']['world_size'],
                                         rank=DDP_gpu)
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=config['dataset']['batch_size'],
                shuffle=False,
                num_workers=int(config['dataset']['n_threads']),
                sampler=sampler,
                drop_last=drop_last)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), 1e9)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.config['dataset']['batch_size'] >= 1e9:
                break
            yield data
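
A minimal usage sketch for CustomDataLoader (illustrative, not from this commit): the wrapper only reads batch_size, serial_batches and n_threads from the config, so any map-style torch dataset works; in this repo the dataset would normally come from data/super_dataset.py.

    import torch
    from data import CustomDataLoader

    class ToyDataset(torch.utils.data.Dataset):   # stand-in for the repo's real dataset classes
        def __len__(self):
            return 8
        def __getitem__(self, idx):
            return {'A': torch.zeros(3, 64, 64), 'B': torch.ones(3, 64, 64)}

    config = {'dataset': {'batch_size': 2, 'serial_batches': False, 'n_threads': 0}}
    loader = CustomDataLoader(config, ToyDataset()).load_data()
    for batch in loader:
        print(batch['A'].shape)   # torch.Size([2, 3, 64, 64])
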
data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.64 kB).
 
data/__pycache__/static_data.cpython-38.pyc ADDED
Binary file (10.8 kB).
 
data/__pycache__/super_dataset.cpython-38.pyc ADDED
Binary file (9.17 kB).
 
data/__pycache__/test_data.cpython-38.pyc ADDED
Binary file (1.74 kB).
 
data/__pycache__/test_video_data.cpython-38.pyc ADDED
Binary file (1.43 kB).
 
data/deprecated/custom_data.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ from utils.augmentation import ImagePathToImage
5
+ from utils.data_utils import Transforms, check_img_loaded, check_numpy_loaded
6
+
7
+
8
+ class CustomData(object):
9
+
10
+ def __init__(self, config, shuffle=False):
11
+ self.paired_file_groups = []
12
+ self.paired_type_groups = []
13
+ self.len_of_groups = []
14
+ self.landmark_scale = config['dataset']['landmark_scale']
15
+ self.shuffle = shuffle
16
+ self.config = config
17
+
18
+ data_dict = config['dataset']['custom_' + config['common']['phase'] + '_data']
19
+ if len(data_dict) == 0:
20
+ self.len_of_groups.append(0)
21
+ return
22
+
23
+ for i, group in enumerate(data_dict.values()): # one example: (0, group_1), (1, group_2)
24
+ data_types = group['data_types'] # one example: 'image', 'patch'
25
+ data_names = group['data_names'] # one example: 'real_A', 'patch_A'
26
+ file_list = group['file_list'] # one example: "lmt/data/trainA.txt"
27
+ assert(len(data_types) == len(data_names))
28
+
29
+ self.paired_file_groups.append({})
30
+ self.paired_type_groups.append({})
31
+ for data_name, data_type in zip(data_names, data_types):
32
+ self.paired_file_groups[i][data_name] = []
33
+ self.paired_type_groups[i][data_name] = data_type
34
+
35
+ paired_file = open(file_list, 'rt')
36
+ lines = paired_file.readlines()
37
+ if self.shuffle:
38
+ random.shuffle(lines)
39
+ for line in lines:
40
+ items = line.strip().split(' ')
41
+ if len(items) == len(data_names):
42
+ ok = True
43
+ for item in items:
44
+ ok = ok and os.path.exists(item) and os.path.getsize(item) > 0
45
+ if ok:
46
+ for data_name, item in zip(data_names, items):
47
+ self.paired_file_groups[i][data_name].append(item)
48
+ paired_file.close()
49
+
50
+ self.len_of_groups.append(len(self.paired_file_groups[i][data_names[0]]))
51
+
52
+ self.transform = Transforms(config)
53
+ self.transform.get_transform_from_config()
54
+ self.transform.get_transforms().insert(0, ImagePathToImage())
55
+ self.transform = self.transform.compose_transforms()
56
+
57
+ def get_len(self):
58
+ return max(self.len_of_groups)
59
+
60
+ def get_item(self, idx):
61
+ return_dict = {}
62
+ for i in range(len(self.paired_file_groups)):
63
+ inner_idx = idx if idx < self.len_of_groups[i] else random.randint(0, self.len_of_groups[i] - 1)
64
+ img_list = []
65
+ img_k_list = []
66
+ for k, v in self.paired_file_groups[i].items():
67
+ if self.paired_type_groups[i][k] == 'image':
68
+ # gather images for processing later
69
+ img_k_list.append(k)
70
+ img_list.append(v[inner_idx])
71
+ elif self.paired_type_groups[i][k] == 'landmark':
72
+ # different from images, landmark doesn't use data augmentation. So process them directly here.
73
+ lmk = np.load(v[inner_idx])
74
+ lmk[:, 0] *= self.landmark_scale[0]
75
+ lmk[:, 1] *= self.landmark_scale[1]
76
+ return_dict[k] = lmk
77
+ return_dict[k + '_path'] = v[inner_idx]
78
+
79
+ # transform all images
80
+ if len(img_list) == 1:
81
+ return_dict[img_k_list[0]], _ = self.transform(img_list[0], None)
82
+ elif len(img_list) > 1:
83
+ input1, input2 = img_list[0], img_list[1:]
84
+ output1, output2 = self.transform(input1, input2) # output1 is one image. output2 is a list of images.
85
+ return_dict[img_k_list[0]] = output1
86
+ for j in range(1, len(img_list)):
87
+ return_dict[img_k_list[j]] = output2[j-1]
88
+
89
+ return return_dict
90
+
91
+ def split_data_into_bins(self, num_bins):
92
+ bins = []
93
+ for i in range(0, num_bins):
94
+ bins.append([])
95
+ for i in range(0, len(self.paired_file_groups)):
96
+ for b in range(0, num_bins):
97
+ bins[b].append({})
98
+ for dataname, item_list in self.paired_file_groups[i].items():
99
+ if len(item_list) < self.config['dataset']['n_threads']:
100
+ bins[0][i][dataname] = item_list
101
+ else:
102
+ num_items_in_bin = len(item_list) // num_bins
103
+ for j in range(0, len(item_list)):
104
+ which_bin = min(j // num_items_in_bin, num_bins - 1)
105
+ if dataname not in bins[which_bin][i]:
106
+ bins[which_bin][i][dataname] = []
107
+ else:
108
+ bins[which_bin][i][dataname].append(item_list[j])
109
+ return bins
110
+
111
+ def check_data_helper(self, data):
112
+ all_pass = True
113
+ for paired_file_group in data:
114
+ for k, v in paired_file_group.items():
115
+ if len(v) > 0:
116
+ for v1 in v:
117
+ if '.npy' in v1: # case: numpy array or landmark
118
+ all_pass = all_pass and check_numpy_loaded(v1)
119
+ else: # case: image
120
+ all_pass = all_pass and check_img_loaded(v1)
121
+ return all_pass
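
For reference, the structure this (deprecated) loader expects under config['dataset']['custom_train_data'] looks roughly like the sketch below; the names and file-list path are only examples taken from the comments in __init__, and each line of the file list holds one whitespace-separated path per data name.

    custom_train_data = {
        'group_1': {
            'data_types': ['image', 'image', 'landmark'],
            'data_names': ['real_A', 'real_B', 'lmk_A'],
            'file_list': 'lmt/data/trainA.txt',   # each line: <real_A path> <real_B path> <lmk_A .npy path>
        },
    }
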
data/deprecated/landmark_data.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ from PIL import Image
3
+ import numpy as np
4
+ from utils.data_utils import check_create_shuffled_order, check_equal_length
5
+
6
+ def landmark_path_to_numpy(lmk_path, image_path, image_tensor):
7
+ """Convert an landmark path to the actual landmarks in a numpy array. Also applies scaling to the landmarks
8
+ according to final image' size.
9
+
10
+ Parameters:
11
+ lmk_path -- the landmark file path.
12
+ image_path -- the original image file path.
13
+ image_tensor -- the image tensor after all transformations.
14
+ """
15
+ lmk = np.load(lmk_path)
16
+ ow, oh = Image.open(image_path).size
17
+ h, w = image_tensor.size()[1:]
18
+ lmk[:, 0] *= w / ow
19
+ lmk[:, 1] *= h / oh
20
+ return lmk
21
+
22
+ def add_landmark_data(data, config, paired_data_order):
23
+ A_lmk_paths = []
24
+ B_lmk_paths = []
25
+
26
+ if config['dataset']['paired_' + config['common']['phase'] + '_filelist'] != '':
27
+ paired_data_file = open(config['dataset']['paired_' + config['common']['phase'] + '_filelist'], 'r')
28
+ Lines = paired_data_file.readlines()
29
+ paired_data_order = check_create_shuffled_order(Lines, paired_data_order)
30
+ check_equal_length(Lines, paired_data_order, data)
31
+ for i in paired_data_order:
32
+ line = Lines[i]
33
+ if not config['dataset']['use_absolute_datafile']:
34
+ file3 = os.path.join(config['dataset']['dataroot'], line.split(" ")[2]).strip()
35
+ file4 = os.path.join(config['dataset']['dataroot'], line.split(" ")[3]).strip()
36
+ else:
37
+ file3 = line.split(" ")[2].strip()
38
+ file4 = line.split(" ")[3].strip()
39
+ if os.path.exists(file3) and os.path.exists(file4):
40
+ A_lmk_paths.append(file3)
41
+ B_lmk_paths.append(file4)
42
+ paired_data_file.close()
43
+ elif config['dataset']['paired_' + config['common']['phase'] + 'A_folder'] != '' and \
44
+ config['dataset']['paired_' + config['common']['phase'] + 'B_folder'] != '' and \
45
+ os.path.exists(config['dataset']['paired_' + config['common']['phase'] + 'A_lmk_folder']) and \
46
+ os.path.exists(config['dataset']['paired_' + config['common']['phase'] + 'B_lmk_folder']):
47
+ dir_A = config['dataset']['paired_' + config['common']['phase'] + 'A_folder']
48
+ dir_A_lmk = config['dataset']['paired_' + config['common']['phase'] + 'A_lmk_folder']
49
+ dir_B_lmk = config['dataset']['paired_' + config['common']['phase'] + 'B_lmk_folder']
50
+ filenames = os.listdir(dir_A)
51
+ paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
52
+ check_equal_length(filenames, paired_data_order, data)
53
+ for i in paired_data_order:
54
+ filename = filenames[i]
55
+ A_lmk_path = os.path.join(dir_A_lmk, os.path.splitext(filename)[0] + '.npy')
56
+ B_lmk_path = os.path.join(dir_B_lmk, os.path.splitext(filename)[0] + '.npy')
57
+ if os.path.exists(A_lmk_path) and os.path.exists(B_lmk_path):
58
+ A_lmk_paths.append(A_lmk_path)
59
+ B_lmk_paths.append(B_lmk_path)
60
+ else:
61
+ dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedA')
62
+ dir_A_lmk = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedA_lmk')
63
+ dir_B_lmk = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedB_lmk')
64
+ if os.path.exists(dir_A_lmk) and os.path.exists(dir_B_lmk):
65
+ filenames = os.listdir(dir_A)
66
+ paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
67
+ check_equal_length(filenames, paired_data_order, data)
68
+ for i in paired_data_order:
69
+ filename = filenames[i]
70
+ A_lmk_path = os.path.join(dir_A_lmk, os.path.splitext(filename)[0] + '.npy')
71
+ B_lmk_path = os.path.join(dir_B_lmk, os.path.splitext(filename)[0] + '.npy')
72
+ if os.path.exists(A_lmk_path) and os.path.exists(B_lmk_path):
73
+ A_lmk_paths.append(A_lmk_path)
74
+ B_lmk_paths.append(B_lmk_path)
75
+ else:
76
+ print(dir_A_lmk + " or " + dir_B_lmk + " doesn't exist. Skipping landmark data.")
77
+
78
+ data['A_lmk_path'] = A_lmk_paths
79
+ data['B_lmk_path'] = B_lmk_paths
80
+
81
+ return paired_data_order
82
+
83
+
84
+ def apply_landmark_transforms(index, data, return_dict):
85
+ if len(data['A_lmk_path']) > 0:
86
+ return_dict['A_lmk'] = landmark_path_to_numpy(data['A_lmk_path'][index], data['paired_A_path'][index], return_dict['paired_A'])
87
+ return_dict['B_lmk'] = landmark_path_to_numpy(data['B_lmk_path'][index], data['paired_B_path'][index], return_dict['paired_B'])
88
+ return_dict['A_lmk_path'] = data['A_lmk_path'][index]
89
+ return_dict['B_lmk_path'] = data['B_lmk_path'][index]
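
A quick numeric check of the scaling rule in landmark_path_to_numpy (illustrative numbers): landmarks are stored in the source image's pixel coordinates, so x is multiplied by w/ow and y by h/oh to follow the transformed tensor.

    import numpy as np

    lmk = np.array([[256.0, 128.0]])     # (x, y) in a 1024x512 source image
    ow, oh, w, h = 1024, 512, 512, 256   # image resized to 512x256
    lmk[:, 0] *= w / ow
    lmk[:, 1] *= h / oh
    print(lmk)                           # [[128.  64.]]
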
data/deprecated/numpy_paired_data.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+ from utils.util import check_path_is_img
3
+ from utils.data_utils import Transforms, check_create_shuffled_order, check_equal_length
4
+ from utils.augmentation import NumpyToTensor
5
+
6
+
7
+ def add_numpy_paired_data(data, transforms, config, paired_data_order):
8
+ A_paths = []
9
+ B_paths = []
10
+
11
+ if config['dataset']['paired_' + config['common']['phase'] + '_filelist'] != '':
12
+ paired_data_file = open(config['dataset']['paired_' + config['common']['phase'] + '_filelist'], 'r')
13
+ Lines = paired_data_file.readlines()
14
+ paired_data_order = check_create_shuffled_order(Lines, paired_data_order)
15
+ check_equal_length(Lines, paired_data_order, data)
16
+ for i in paired_data_order:
17
+ line = Lines[i]
18
+ if not config['dataset']['use_absolute_datafile']:
19
+ file1 = os.path.join(config['dataset']['dataroot'], line.split(" ")[0]).strip()
20
+ file2 = os.path.join(config['dataset']['dataroot'], line.split(" ")[1]).strip()
21
+ else:
22
+ file1 = line.split(" ")[0].strip()
23
+ file2 = line.split(" ")[1].strip()
24
+ if os.path.exists(file1) and os.path.exists(file2):
25
+ A_paths.append(file1)
26
+ B_paths.append(file2)
27
+ paired_data_file.close()
28
+ elif config['dataset']['paired_' + config['common']['phase'] + 'A_folder'] != '' and \
29
+ config['dataset']['paired_' + config['common']['phase'] + 'B_folder'] != '':
30
+ dir_A = config['dataset']['paired_' + config['common']['phase'] + 'A_folder']
31
+ dir_B = config['dataset']['paired_' + config['common']['phase'] + 'B_folder']
32
+ filenames = os.listdir(dir_A)
33
+ paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
34
+ check_equal_length(filenames, paired_data_order, data)
35
+ for i in paired_data_order:
36
+ filename = filenames[i]
37
+ if not check_path_is_img(filename):
38
+ continue
39
+ A_path = os.path.join(dir_A, filename)
40
+ B_path = os.path.join(dir_B, filename)
41
+ if os.path.exists(A_path) and os.path.exists(B_path):
42
+ A_paths.append(A_path)
43
+ B_paths.append(B_path)
44
+ else:
45
+ dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpypairedA')
46
+ dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpypairedB')
47
+ if os.path.exists(dir_A) and os.path.exists(dir_B):
48
+ filenames = os.listdir(dir_A)
49
+ paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
50
+ check_equal_length(filenames, paired_data_order, data)
51
+ for i in paired_data_order:
52
+ filename = filenames[i]
53
+ if not check_path_is_img(filename):
54
+ continue
55
+ A_path = os.path.join(dir_A, filename)
56
+ B_path = os.path.join(dir_B, filename)
57
+ if os.path.exists(A_path) and os.path.exists(B_path):
58
+ A_paths.append(A_path)
59
+ B_paths.append(B_path)
60
+
61
+ btoA = config['dataset']['direction'] == 'BtoA'
62
+ # get the number of channels of input image
63
+ input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
64
+ output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']
65
+
66
+ transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
67
+ transform.transform_list.append(NumpyToTensor())
68
+ transform = transform.compose_transforms()
69
+
70
+ data['paired_A_path'] = A_paths
71
+ data['paired_B_path'] = B_paths
72
+ transforms['paired'] = transform
73
+ return paired_data_order
74
+
75
+
76
+ def apply_numpy_paired_transforms(index, data, transforms, return_dict):
77
+ if len(data['paired_A_path']) > 0:
78
+ return_dict['paired_A'], return_dict['paired_B'] = transforms['paired'] \
79
+ (data['paired_A_path'][index], data['paired_B_path'][index])
80
+ return_dict['paired_A_path'] = data['paired_A_path'][index]
81
+ return_dict['paired_B_path'] = data['paired_B_path'][index]
data/deprecated/numpy_unpaired_data.py ADDED
@@ -0,0 +1,100 @@
1
+ import os
2
+ from utils.util import check_path_is_img
3
+ from utils.data_utils import Transforms
4
+ from utils.augmentation import NumpyToTensor
5
+ import random
6
+
7
+ def add_numpy_unpaired_data(data, transforms, config, shuffle=False):
8
+ A_paths = []
9
+ B_paths = []
10
+ if config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'] != '':
11
+ unpaired_data_file1 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'], 'r')
12
+ Lines = unpaired_data_file1.readlines()
13
+ if shuffle:
14
+ random.shuffle(Lines)
15
+ for line in Lines:
16
+ if not config['dataset']['use_absolute_datafile']:
17
+ file = os.path.join(config['dataset']['dataroot'], line.strip())
18
+ else:
19
+ file = line.strip()
20
+ if os.path.exists(file):
21
+ A_paths.append(file)
22
+ unpaired_data_file1.close()
23
+
24
+ unpaired_data_file2 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'B_filelist'], 'r')
25
+ Lines = unpaired_data_file2.readlines()
26
+ if shuffle:
27
+ random.shuffle(Lines)
28
+ for line in Lines:
29
+ if not config['dataset']['use_absolute_datafile']:
30
+ file = os.path.join(config['dataset']['dataroot'], line.strip())
31
+ else:
32
+ file = line.strip()
33
+ if os.path.exists(file):
34
+ B_paths.append(file)
35
+ unpaired_data_file2.close()
36
+ elif config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder'] != '' and \
37
+ config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder'] != '':
38
+ dir_A = config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder']
39
+ filenames = os.listdir(dir_A)
40
+ if shuffle:
41
+ random.shuffle(filenames)
42
+ for filename in filenames:
43
+ if not check_path_is_img(filename):
44
+ continue
45
+ A_path = os.path.join(dir_A, filename)
46
+ if os.path.exists(A_path):
47
+ A_paths.append(A_path)
48
+
49
+ dir_B = config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder']
50
+ filenames = os.listdir(dir_B)
51
+ if shuffle:
52
+ random.shuffle(filenames)
53
+ for filename in filenames:
54
+ if not check_path_is_img(filename):
55
+ continue
56
+ B_path = os.path.join(dir_B, filename)
57
+ if os.path.exists(B_path):
58
+ B_paths.append(B_path)
59
+
60
+ else:
61
+ dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpyunpairedA')
62
+ dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'numpyunpairedB')
63
+ if os.path.exists(dir_A) and os.path.exists(dir_B):
64
+ filenames = os.listdir(dir_A)
65
+ if shuffle:
66
+ random.shuffle(filenames)
67
+ for filename in filenames:
68
+ if not check_path_is_img(filename):
69
+ continue
70
+ A_path = os.path.join(dir_A, filename)
71
+ A_paths.append(A_path)
72
+ filenames = os.listdir(dir_B)
73
+ if shuffle:
74
+ random.shuffle(filenames)
75
+ for filename in filenames:
76
+ if not check_path_is_img(filename):
77
+ continue
78
+ B_path = os.path.join(dir_B, filename)
79
+ B_paths.append(B_path)
80
+
81
+
82
+ btoA = config['dataset']['direction'] == 'BtoA'
83
+ input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
84
+ output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']
85
+
86
+ transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
87
+ transform.transform_list.append(NumpyToTensor())
88
+ transform = transform.compose_transforms()
89
+
90
+ data['unpaired_A_path'] = A_paths
91
+ data['unpaired_B_path'] = B_paths
92
+ transforms['unpaired'] = transform
93
+
94
+ def apply_numpy_unpaired_transforms(index, data, transforms, return_dict):
95
+ if len(data['unpaired_A_path']) > 0 and len(data['unpaired_B_path']) > 0:
96
+ index_B = random.randint(0, len(data['unpaired_B_path']) - 1)
97
+ return_dict['unpaired_A'], return_dict['unpaired_B'] = transforms['unpaired'] \
98
+ (data['unpaired_A_path'][index], data['unpaired_B_path'][index_B])
99
+ return_dict['unpaired_A_path'] = data['unpaired_A_path'][index]
100
+ return_dict['unpaired_B_path'] = data['unpaired_B_path'][index_B]
data/deprecated/paired_data.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+ from utils.util import check_path_is_img
3
+ from utils.data_utils import Transforms, check_create_shuffled_order
4
+ from utils.augmentation import ImagePathToImage
5
+
6
+
7
+ def add_paired_data(data, transforms, config, paired_data_order):
8
+ A_paths = []
9
+ B_paths = []
10
+
11
+ if config['dataset']['paired_' + config['common']['phase'] + '_filelist'] != '':
12
+ paired_data_file = open(config['dataset']['paired_' + config['common']['phase'] + '_filelist'], 'r')
13
+ Lines = paired_data_file.readlines()
14
+ paired_data_order = check_create_shuffled_order(Lines, paired_data_order)
15
+ for i in paired_data_order:
16
+ line = Lines[i]
17
+ if not config['dataset']['use_absolute_datafile']:
18
+ file1 = os.path.join(config['dataset']['dataroot'], line.split(" ")[0]).strip()
19
+ file2 = os.path.join(config['dataset']['dataroot'], line.split(" ")[1]).strip()
20
+ else:
21
+ file1 = line.split(" ")[0].strip()
22
+ file2 = line.split(" ")[1].strip()
23
+ if os.path.exists(file1) and os.path.exists(file2):
24
+ A_paths.append(file1)
25
+ B_paths.append(file2)
26
+ paired_data_file.close()
27
+ elif config['dataset']['paired_' + config['common']['phase'] + 'A_folder'] != '' and \
28
+ config['dataset']['paired_' + config['common']['phase'] + 'B_folder'] != '':
29
+ dir_A = config['dataset']['paired_' + config['common']['phase'] + 'A_folder']
30
+ dir_B = config['dataset']['paired_' + config['common']['phase'] + 'B_folder']
31
+ filenames = os.listdir(dir_A)
32
+ paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
33
+ for i in paired_data_order:
34
+ filename = filenames[i]
35
+ if not check_path_is_img(filename):
36
+ continue
37
+ A_path = os.path.join(dir_A, filename)
38
+ B_path = os.path.join(dir_B, filename)
39
+ if os.path.exists(A_path) and os.path.exists(B_path):
40
+ A_paths.append(A_path)
41
+ B_paths.append(B_path)
42
+ else:
43
+ dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedA')
44
+ dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'pairedB')
45
+ if os.path.exists(dir_A) and os.path.exists(dir_B):
46
+ filenames = os.listdir(dir_A)
47
+ paired_data_order = check_create_shuffled_order(filenames, paired_data_order)
48
+ for i in paired_data_order:
49
+ filename = filenames[i]
50
+ if not check_path_is_img(filename):
51
+ continue
52
+ A_path = os.path.join(dir_A, filename)
53
+ B_path = os.path.join(dir_B, filename)
54
+ if os.path.exists(A_path) and os.path.exists(B_path):
55
+ A_paths.append(A_path)
56
+ B_paths.append(B_path)
57
+
58
+ btoA = config['dataset']['direction'] == 'BtoA'
59
+ # get the number of channels of input image
60
+ input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
61
+ output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']
62
+
63
+ transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
64
+ transform.get_transform_from_config()
65
+ transform.get_transforms().insert(0, ImagePathToImage())
66
+ transform = transform.compose_transforms()
67
+
68
+ data['paired_A_path'] = A_paths
69
+ data['paired_B_path'] = B_paths
70
+
71
+ transforms['paired'] = transform
72
+ return paired_data_order
73
+
74
+
75
+ def apply_paired_transforms(index, data, transforms, return_dict):
76
+ if len(data['paired_A_path']) > 0:
77
+ return_dict['paired_A'], return_dict['paired_B'] = transforms['paired'] \
78
+ (data['paired_A_path'][index], data['paired_B_path'][index])
79
+ return_dict['paired_A_path'] = data['paired_A_path'][index]
80
+ return_dict['paired_B_path'] = data['paired_B_path'][index]
data/deprecated/patch_data.py ADDED
@@ -0,0 +1,44 @@
1
+ import random
2
+ import torch
3
+
4
+ def load_patches(patch_batch_size, batch_size, patch_size, num_patch, diff_patch, index, data, transforms, return_dict):
5
+ if patch_size > 0:
6
+ assert (patch_batch_size % batch_size == 0), \
7
+ "patch_batch_size is not divisible by batch_size."
8
+ if 'paired_A' in return_dict or 'paired_B' in return_dict:
9
+ if not diff_patch:
10
+ # load patch from current image
11
+ patchA = return_dict['paired_A'].clone()
12
+ patchB = return_dict['paired_B'].clone()
13
+ else:
14
+ # load patch from a different image
15
+ pathA = data['paired_A_path'][(index + 1) % len(data['paired_A_path'])]
16
+ pathB = data['paired_B_path'][(index + 1) % len(data['paired_B_path'])]
17
+ patchA, patchB = transforms['paired'](pathA, pathB)
18
+ else:
19
+ if not diff_patch:
20
+ # load patch from current image
21
+ patchA = return_dict['unpaired_A'].clone()
22
+ patchB = return_dict['unpaired_B'].clone()
23
+ else:
24
+ # load patch from a different image
25
+ pathA = data['unpaired_A_path'][(index + 1) % len(data['unpaired_A_path'])]
26
+ pathB = data['unpaired_B_path'][(index + 1) % len(data['unpaired_B_path'])]
27
+ patchA, patchB = transforms['unpaired'](pathA, pathB)
28
+
29
+ # crop patch
30
+ patchAs = []
31
+ patchBs = []
32
+ _, h, w = patchA.size()
33
+
34
+ for _ in range(num_patch):
35
+ r = random.randint(0, h - patch_size - 1)
36
+ c = random.randint(0, w - patch_size - 1)
37
+ patchAs.append(patchA[:, r:r + patch_size, c:c + patch_size])
38
+ patchBs.append(patchB[:, r:r + patch_size, c:c + patch_size])
39
+
40
+ patchAs = torch.cat(patchAs, 0)
41
+ patchBs = torch.cat(patchBs, 0)
42
+
43
+ return_dict['patch_A'] = patchAs
44
+ return_dict['patch_B'] = patchBs
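
A small shape check for load_patches (illustrative): the num_patch crops are concatenated along dimension 0 of the CHW tensors, so with num_patch=4 and 3-channel images, return_dict['patch_A'] and return_dict['patch_B'] come out with 12 channels.

    import torch

    patch_size, num_patch = 64, 4
    patches = [torch.zeros(3, patch_size, patch_size) for _ in range(num_patch)]
    print(torch.cat(patches, 0).shape)   # torch.Size([12, 64, 64])
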
data/deprecated/unpaired_data.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ from utils.util import check_path_is_img
3
+ from utils.data_utils import Transforms
4
+ from utils.augmentation import ImagePathToImage
5
+ import random
6
+
7
+ def add_unpaired_data(data, transforms, config, shuffle=False):
8
+ A_paths = []
9
+ B_paths = []
10
+ if config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'] != '':
11
+ unpaired_data_file1 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'A_filelist'], 'r')
12
+ Lines = unpaired_data_file1.readlines()
13
+ if shuffle:
14
+ random.shuffle(Lines)
15
+ for line in Lines:
16
+ if not config['dataset']['use_absolute_datafile']:
17
+ file = os.path.join(config['dataset']['dataroot'], line.strip())
18
+ else:
19
+ file = line.strip()
20
+ if os.path.exists(file):
21
+ A_paths.append(file)
22
+ unpaired_data_file1.close()
23
+
24
+ unpaired_data_file2 = open(config['dataset']['unpaired_' + config['common']['phase'] + 'B_filelist'], 'r')
25
+ Lines = unpaired_data_file2.readlines()
26
+ if shuffle:
27
+ random.shuffle(Lines)
28
+ for line in Lines:
29
+ if not config['dataset']['use_absolute_datafile']:
30
+ file = os.path.join(config['dataset']['dataroot'], line.strip())
31
+ else:
32
+ file = line.strip()
33
+ if os.path.exists(file):
34
+ B_paths.append(file)
35
+ unpaired_data_file2.close()
36
+ elif config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder'] != '' and \
37
+ config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder'] != '':
38
+ dir_A = config['dataset']['unpaired_' + config['common']['phase'] + 'A_folder']
39
+ filenames = os.listdir(dir_A)
40
+ if shuffle:
41
+ random.shuffle(filenames)
42
+ for filename in filenames:
43
+ if not check_path_is_img(filename):
44
+ continue
45
+ A_path = os.path.join(dir_A, filename)
46
+ if os.path.exists(A_path):
47
+ A_paths.append(A_path)
48
+
49
+ dir_B = config['dataset']['unpaired_' + config['common']['phase'] + 'B_folder']
50
+ filenames = os.listdir(dir_B)
51
+ if shuffle:
52
+ random.shuffle(filenames)
53
+ for filename in filenames:
54
+ if not check_path_is_img(filename):
55
+ continue
56
+ B_path = os.path.join(dir_B, filename)
57
+ if os.path.exists(B_path):
58
+ B_paths.append(B_path)
59
+
60
+ else:
61
+ dir_A = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'unpairedA')
62
+ dir_B = os.path.join(config['dataset']['dataroot'], config['common']['phase'] + 'unpairedB')
63
+ if os.path.exists(dir_A) and os.path.exists(dir_B):
64
+ filenames = os.listdir(dir_A)
65
+ if shuffle:
66
+ random.shuffle(filenames)
67
+ for filename in filenames:
68
+ if not check_path_is_img(filename):
69
+ continue
70
+ A_path = os.path.join(dir_A, filename)
71
+ A_paths.append(A_path)
72
+ filenames = os.listdir(dir_B)
73
+ if shuffle:
74
+ random.shuffle(filenames)
75
+ for filename in filenames:
76
+ if not check_path_is_img(filename):
77
+ continue
78
+ B_path = os.path.join(dir_B, filename)
79
+ B_paths.append(B_path)
80
+
81
+
82
+ btoA = config['dataset']['direction'] == 'BtoA'
83
+ input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
84
+ output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']
85
+
86
+ transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
87
+ transform.get_transform_from_config()
88
+ transform.get_transforms().insert(0, ImagePathToImage())
89
+ transform = transform.compose_transforms()
90
+
91
+ data['unpaired_A_path'] = A_paths
92
+ data['unpaired_B_path'] = B_paths
93
+ transforms['unpaired'] = transform
94
+
95
+ def apply_unpaired_transforms(index, data, transforms, return_dict):
96
+ if len(data['unpaired_A_path']) > 0 and len(data['unpaired_B_path']) > 0:
97
+ index_B = random.randint(0, len(data['unpaired_B_path']) - 1)
98
+ return_dict['unpaired_A'], return_dict['unpaired_B'] = transforms['unpaired'] \
99
+ (data['unpaired_A_path'][index], data['unpaired_B_path'][index_B])
100
+ return_dict['unpaired_A_path'] = data['unpaired_A_path'][index]
101
+ return_dict['unpaired_B_path'] = data['unpaired_B_path'][index_B]
data/static_data.py ADDED
@@ -0,0 +1,457 @@
1
+ import os, sys
2
+ import random
3
+ import numpy as np
4
+ from utils.augmentation import ImagePathToImage, NumpyToTensor
5
+ from utils.data_utils import Transforms
6
+ from utils.util import check_path_is_static_data
7
+ import torch
8
+ from PIL import Image
9
+
10
+
11
+ def check_dataname_folder_correspondence(data_names, group, group_name):
12
+ for data_name in data_names:
13
+ if data_name + '_folder' not in group:
14
+ print("%s not found in config file. Going to use dataroot mode to load group %s." % (data_name + '_folder', group_name))
15
+ return False
16
+ return True
17
+
18
+
19
+ def custom_check_path_exists(str1):
20
+ return True if (str1 == "None" or os.path.exists(str1)) else False
21
+
22
+
23
+ def custom_getsize(str1):
24
+ return 1 if str1 == "None" else os.path.getsize(str1)
25
+
26
+
27
+ def check_different_extension_path_exists(str1):
28
+ acceptable_extensions = ['png', 'jpg', 'jpeg', 'npy', 'npz', 'PNG', 'JPG', 'JPEG']
29
+ curr_extension = str1.split('.')[-1]
30
+ for extension in acceptable_extensions:
31
+ str2 = str1.replace(curr_extension, extension)
32
+ if os.path.exists(str2):
33
+ return str2
34
+ return None
35
+
36
+
37
+ class StaticData(object):
38
+
39
+ def __init__(self, config, shuffle=False):
40
+ # private variables
41
+ self.file_groups = []
42
+ self.type_groups = []
43
+ self.group_names = []
44
+ self.pair_type_groups = []
45
+ self.len_of_groups = []
46
+ self.transforms = {}
47
+ # parameters
48
+ self.shuffle = shuffle
49
+ self.config = config
50
+
51
+
52
+ def load_static_data(self):
53
+ data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
54
+ print("----------------loading %s static data.---------------------" % self.config['common']['phase'])
55
+
56
+ if len(data_dict) == 0:
57
+ self.len_of_groups.append(0)
58
+ return
59
+
60
+ self.group_names = list(data_dict.keys())
61
+ for i, group in enumerate(data_dict.values()): # examples: (0, group_1), (1, group_2)
62
+ data_types = group['data_types'] # examples: 'image', 'patch'
63
+ data_names = group['data_names'] # examples: 'real_A', 'patch_A'
64
+ self.file_groups.append({})
65
+ self.type_groups.append({})
66
+ self.len_of_groups.append(0)
67
+ self.pair_type_groups.append(group['paired'])
68
+
69
+ # exclude patch data since they are not stored on disk. They will be handled later.
70
+ data_types, data_names = self.exclude_patch_data(data_types, data_names)
71
+ assert(len(data_types) == len(data_names))
72
+
73
+ if len(data_names) == 0:
74
+ continue
75
+
76
+ for data_name, data_type in zip(data_names, data_types):
77
+ self.file_groups[i][data_name] = []
78
+ self.type_groups[i][data_name] = data_type
79
+
80
+
81
+ # paired data
82
+ if group['paired']:
83
+ # First way to load data: load a file list
84
+ if 'file_list' in group:
85
+ file_list = group['file_list']
86
+ paired_file = open(file_list, 'rt')
87
+ lines = paired_file.readlines()
88
+ if self.shuffle:
89
+ random.shuffle(lines)
90
+ for line in lines:
91
+ items = line.strip().split(' ')
92
+ if len(items) == len(data_names):
93
+ ok = True
94
+ for item in items:
95
+ ok = ok and os.path.exists(item) and os.path.getsize(item) > 0
96
+ if ok:
97
+ for data_name, item in zip(data_names, items):
98
+ self.file_groups[i][data_name].append(item)
99
+ paired_file.close()
100
+ # second and third way to load data: specify one folder for each dataname, or specify a dataroot folder
101
+ elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
102
+ dataname_to_dir_dict = {}
103
+ for data_name, data_type in zip(data_names, data_types):
104
+ if 'dataroot' in group:
105
+ # In new data config format, data is stored in dataroot_name/mode/dataname. e.g. FFHQ/train/pairedA
106
+ # In old format, data is stored in dataroot_name/mode_dataname. e.g. FFHQ/train_pairedA
107
+ # So we need to check both.
108
+ dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
109
+ if not os.path.exists(dir):
110
+ old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
111
+ if 'numpy' in data_type:
112
+ old_dir += 'numpy'
113
+ if not os.path.exists(old_dir):
114
+ print("Both %s and %s does not exist. Please check." % (dir, old_dir))
115
+ sys.exit()
116
+ else:
117
+ dir = old_dir
118
+ else:
119
+ dir = group[data_name + '_folder']
120
+ if not os.path.exists(dir):
121
+ print("directory %s does not exist. Please check." % dir)
122
+ sys.exit()
123
+ dataname_to_dir_dict[data_name] = dir
124
+
125
+ filenames = os.listdir(dataname_to_dir_dict[data_names[0]])
126
+ if self.shuffle:
127
+ random.shuffle(filenames)
128
+ for filename in filenames:
129
+ if not check_path_is_static_data(filename):
130
+ continue
131
+ file_paths = []
132
+ for data_name in data_names:
133
+ file_path = os.path.join(dataname_to_dir_dict[data_name], filename)
134
+ checked_extension = check_different_extension_path_exists(file_path)
135
+ if checked_extension is not None:
136
+ file_paths.append(checked_extension)
137
+
138
+ if len(file_paths) != len(data_names):
139
+ print("for file %s , looks like some of the other pair data is missing. Ignoring and proceeding." % filename)
140
+ continue
141
+ else:
142
+ for j in range(len(data_names)):
143
+ data_name = data_names[j]
144
+ self.file_groups[i][data_name].append(file_paths[j])
145
+ else:
146
+ print("method for loading data is incorrect/unspecified for data group %s." % self.group_names)
147
+ sys.exit()
148
+
149
+ self.len_of_groups[i] = len(self.file_groups[i][data_names[0]])
150
+
151
+ # unpaired data
152
+ else:
153
+ # First way to load data: load a file list
154
+ if 'file_list' in group:
155
+ file_list = group['file_list']
156
+ unpaired_file = open(file_list, 'rt')
157
+ lines = unpaired_file.readlines()
158
+ if self.shuffle:
159
+ random.shuffle(lines)
160
+ item_count = 0
161
+ for line in lines:
162
+ items = line.strip().split(' ')
163
+ if len(items) == len(data_names):
164
+ ok = True
165
+ for item in items:
166
+ ok = ok and custom_check_path_exists(item) and custom_getsize(item) > 0
167
+ if ok:
168
+ has_data = False
169
+ for data_name, item in zip(data_names, items):
170
+ if item != 'None':
171
+ self.file_groups[i][data_name].append(item)
172
+ has_data = True
173
+ if has_data:
174
+ item_count += 1
175
+ unpaired_file.close()
176
+ self.len_of_groups[i] = item_count
177
+ # second and third way to load data: specify one folder for each dataname, or specify a dataroot folder
178
+ elif check_dataname_folder_correspondence(data_names, group, self.group_names[i]) or 'dataroot' in group:
179
+ max_length = 0
180
+ for data_name, data_type in zip(data_names, data_types):
181
+ if 'dataroot' in group:
182
+ # In new data config format, data is stored in dataroot_name/mode/dataname. e.g. FFHQ/train/pairedA
183
+ # In old format, data is stored in dataroot_name/mode_dataname. e.g. FFHQ/train_pairedA
184
+ # So we need to check both.
185
+ dir = os.path.join(group['dataroot'], self.config['common']['phase'], data_name)
186
+ if not os.path.exists(dir):
187
+ old_dir = os.path.join(group['dataroot'], self.config['common']['phase'] + data_name.replace('_', ''))
188
+ if 'numpy' in data_type:
189
+ old_dir += 'numpy'
190
+ if not os.path.exists(old_dir):
191
+ print("Both %s and %s does not exist. Please check." % (dir, old_dir))
192
+ sys.exit()
193
+ else:
194
+ dir = old_dir
195
+ else:
196
+ dir = group[data_name + '_folder']
197
+ if not os.path.exists(dir):
198
+ print("directory %s does not exist. Please check." % dir)
199
+ sys.exit()
200
+ filenames = os.listdir(dir)
201
+ if self.shuffle:
202
+ random.shuffle(filenames)
203
+
204
+ item_count = 0
205
+ for filename in filenames:
206
+ if not check_path_is_static_data(filename):
207
+ continue
208
+ fullpath = os.path.join(dir, filename)
209
+ if os.path.exists(fullpath):
210
+ self.file_groups[i][data_name].append(fullpath)
211
+ item_count += 1
212
+ max_length = max(item_count, max_length)
213
+ self.len_of_groups[i] = max_length
214
+ else:
215
+ print("method for loading data is incorrect/unspecified for data group %s." % self.group_names)
216
+ sys.exit()
217
+
218
+
219
+ def create_transforms(self):
220
+ btoA = self.config['dataset']['direction'] == 'BtoA'
221
+ input_nc = self.config['model']['output_nc'] if btoA else self.config['model']['input_nc']
222
+ output_nc = self.config['model']['input_nc'] if btoA else self.config['model']['output_nc']
223
+ input_grayscale_flag = (input_nc == 1)
224
+ output_grayscale_flag = (output_nc == 1)
225
+
226
+ data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
227
+ for i, group in enumerate(data_dict.values()): # examples: (0, group_1), (1, group_2)
228
+
229
+ if i not in self.transforms:
230
+ self.transforms[i] = {}
231
+
232
+ data_types = group['data_types'] # examples: 'image', 'patch'
233
+ data_names = group['data_names'] # examples: 'real_A', 'patch_A'
234
+ data_types, data_names = self.exclude_patch_data(data_types, data_names)
235
+ for data_name, data_type in zip(data_names, data_types):
236
+ if data_type in self.transforms[i]:
237
+ continue
238
+ self.transforms[i][data_type] = Transforms(self.config, input_grayscale_flag=input_grayscale_flag,
239
+ output_grayscale_flag=output_grayscale_flag)
240
+ self.transforms[i][data_type].create_transforms_from_list(group['preprocess'])
241
+ if '.png' in self.file_groups[i][data_name][0] or '.jpg' in self.file_groups[i][data_name][0] or \
242
+ '.jpeg' in self.file_groups[i][data_name][0]:
243
+ self.transforms[i][data_type].get_transforms().insert(0, ImagePathToImage())
244
+ elif '.npy' in self.file_groups[i][data_name][0] or '.npz' in self.file_groups[i][data_name][0]:
245
+ self.transforms[i][data_type].get_transforms().insert(0, NumpyToTensor())
246
+ self.transforms[i][data_type] = self.transforms[i][data_type].compose_transforms()
247
+
248
+
249
+ def apply_transformations_to_images(self, img_list, img_dataname_list, transform, return_dict,
250
+ next_img_paths_bucket, next_img_dataname_list):
251
+
252
+ if len(img_list) == 1:
253
+ return_dict[img_dataname_list[0]], _ = transform(img_list[0], None)
254
+ elif len(img_list) > 1:
255
+ next_data_count = len(next_img_paths_bucket)
256
+ img_list += next_img_paths_bucket
257
+ img_dataname_list += next_img_dataname_list
258
+
259
+ input1, input2 = img_list[0], img_list[1:]
260
+ output1, output2 = transform(input1, input2) # output1 is one image. output2 is a list of images.
261
+
262
+ if next_data_count != 0:
263
+ output2, next_outputs = output2[:-next_data_count], output2[-next_data_count:]
264
+ for i in range(next_data_count):
265
+ return_dict[img_dataname_list[-next_data_count+i] + '_next'] = next_outputs[i]
266
+
267
+ return_dict[img_dataname_list[0]] = output1
268
+ for j in range(0, len(output2)):
269
+ return_dict[img_dataname_list[j+1]] = output2[j]
270
+
271
+ return return_dict
272
+
273
+
274
+ def calculate_landmark_scale(self, data_path, data_type, i):
275
+ if data_type == 'image':
276
+ original_image = Image.open(data_path)
277
+ original_width, original_height = original_image.size
278
+ else:
279
+ original_image = np.load(data_path)
280
+ original_height, original_width = original_image.shape[0], original_image.shape[1]
281
+ transformed_image, _ = self.transforms[i][data_type](data_path, None)
282
+ transformed_height, transformed_width = transformed_image.size()[1:]
283
+ landmark_scale = (transformed_width / original_width, transformed_height / original_height)
284
+ return landmark_scale
285
+
286
+
287
+ def get_item(self, idx):
288
+
289
+ return_dict = {}
290
+ data_dict = self.config['dataset'][self.config['common']['phase'] + '_data']
291
+
292
+ for i, group in enumerate(data_dict.values()):
293
+ if self.file_groups[i] == {}:
294
+ continue
295
+
296
+ paired_type = self.pair_type_groups[i]
297
+ inner_idx = idx if idx < self.len_of_groups[i] else random.randint(0, self.len_of_groups[i] - 1)
298
+
299
+ landmark_scale = None
300
+
301
+ # for patches since they might need to be loaded from different images.
302
+ next_img_paths_bucket = []
303
+ next_img_dataname_list = []
304
+ next_numpy_paths_bucket = []
305
+ next_numpy_dataname_list = []
306
+
307
+ # First, handle all non-patch data
308
+ if paired_type:
309
+ img_paths_bucket = []
310
+ img_dataname_list = []
311
+ numpy_paths_bucket = []
312
+ numpy_dataname_list = []
313
+
314
+ for data_name, data_list in self.file_groups[i].items():
315
+ data_type = self.type_groups[i][data_name]
316
+ if data_type in ['image', 'numpy']:
317
+ if paired_type:
318
+ # augmentations are applied to all images in a paired group at once, so gather the image paths here.
319
+ if data_type == 'image':
320
+ img_paths_bucket.append(data_list[inner_idx])
321
+ img_dataname_list.append(data_name)
322
+ else:
323
+ numpy_paths_bucket.append(data_list[inner_idx])
324
+ numpy_dataname_list.append(data_name)
325
+ return_dict[data_name + '_path'] = data_list[inner_idx]
326
+ if landmark_scale is None:
327
+ landmark_scale = self.calculate_landmark_scale(data_list[inner_idx], data_type, i)
328
+ if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
329
+ data_name in group['patch_sources']:
330
+ next_idx = (inner_idx + 1) % (len(data_list) - 1)
331
+ if data_type == 'image':
332
+ next_img_paths_bucket.append(data_list[next_idx])
333
+ next_img_dataname_list.append(data_name)
334
+ else:
335
+ next_numpy_paths_bucket.append(data_list[next_idx])
336
+ next_numpy_dataname_list.append(data_name)
337
+ else:
338
+ unpaired_inner_idx = random.randint(0, len(data_list) - 1)
339
+ return_dict[data_name], _ = self.transforms[i][data_type](data_list[unpaired_inner_idx], None)
340
+ if landmark_scale is None:
341
+ landmark_scale = self.calculate_landmark_scale(data_list[unpaired_inner_idx], data_type, i)
342
+ if 'diff_patch' in self.config['dataset'] and self.config['dataset']['diff_patch'] and \
343
+ data_name in group['patch_sources']:
344
+ next_idx = (unpaired_inner_idx + 1) % (len(data_list) - 1)
345
+ return_dict[data_name + '_next'], _ = self.transforms[i][data_type](data_list[next_idx], None)
346
+ return_dict[data_name + '_path'] = data_list[unpaired_inner_idx]
347
+ elif self.type_groups[i][data_name] == 'landmark':
348
+ # We do not apply transformations on landmarks. Only scales landmarks to transformed image's size.
349
+ # Also numpy data is passed into network as numpy array and not tensor.
350
+ lmk = np.load(data_list[inner_idx])
351
+ if self.config['dataset']['landmark_scale'] is not None:
352
+ lmk[:, 0] *= self.config['dataset']['landmark_scale'][0]
353
+ lmk[:, 1] *= self.config['dataset']['landmark_scale'][1]
354
+ else:
355
+ if landmark_scale is None:
356
+ print("landmark_scale is None. If it is not defined in the config file, please list "
358
+ "image or numpy data before landmark data so that the proper scale can be calculated automatically.")
358
+ else:
359
+ lmk[:, 0] *= landmark_scale[0]
360
+ lmk[:, 1] *= landmark_scale[1]
361
+ return_dict[data_name] = lmk
362
+ return_dict[data_name + '_path'] = data_list[inner_idx]
363
+
364
+
365
+ if paired_type:
366
+ # apply augmentations to all images and numpy arrays
367
+ if 'image' in self.transforms[i]:
368
+ return_dict = self.apply_transformations_to_images(img_paths_bucket, img_dataname_list,
369
+ self.transforms[i]['image'], return_dict,
370
+ next_img_paths_bucket,
371
+ next_img_dataname_list)
372
+ if 'numpy' in self.transforms[i]:
373
+ return_dict = self.apply_transformations_to_images(numpy_paths_bucket, numpy_dataname_list,
374
+ self.transforms[i]['numpy'], return_dict,
375
+ next_numpy_paths_bucket,
376
+ next_numpy_dataname_list)
377
+
378
+ # Handle patch data
379
+ data_types = group['data_types'] # examples: 'image', 'patch'
380
+ data_names = group['data_names'] # examples: 'real_A', 'patch_A'
381
+ data_types, data_names = self.filter_patch_data(data_types, data_names)
382
+
383
+ if 'patch_sources' in group:
384
+ patch_sources = group['patch_sources']
385
+ return_dict = self.load_patches(
386
+ data_names,
387
+ self.config['dataset']['patch_batch_size'],
388
+ self.config['dataset']['batch_size'],
389
+ self.config['dataset']['patch_size'],
390
+ self.config['dataset']['patch_batch_size'] // self.config['dataset']['batch_size'],
391
+ self.config['dataset']['diff_patch'],
392
+ patch_sources,
393
+ return_dict,
394
+ )
395
+
396
+ return return_dict
397
+
398
+
399
+ def get_len(self):
400
+ if len(self.len_of_groups) == 0:
401
+ return 0
402
+ else:
403
+ return max(self.len_of_groups)
404
+
405
+
406
+ def exclude_patch_data(self, data_types, data_names):
407
+ data_types_patch_excluded = []
408
+ data_names_patch_excluded = []
409
+ for data_name, data_type in zip(data_names, data_types):
410
+ if data_type != 'patch':
411
+ data_types_patch_excluded.append(data_type)
412
+ data_names_patch_excluded.append(data_name)
413
+ return data_types_patch_excluded, data_names_patch_excluded
414
+
415
+
416
+ def filter_patch_data(self, data_types, data_names):
417
+ data_types_patch = []
418
+ data_names_patch = []
419
+ for data_name, data_type in zip(data_names, data_types):
420
+ if data_type == 'patch':
421
+ data_types_patch.append(data_type)
422
+ data_names_patch.append(data_name)
423
+ return data_types_patch, data_names_patch
424
+
425
+
426
+ def load_patches(self, data_names, patch_batch_size, batch_size, patch_size,
427
+ num_patch, diff_patch, patch_sources, return_dict):
428
+
429
+ if patch_size > 0:
430
+ assert (patch_batch_size % batch_size == 0), \
431
+ "patch_batch_size is not divisible by batch_size."
432
+ assert (len(patch_sources) == len(data_names)), \
433
+ "length of patch_sources is not the same as number of patch data specified. Please check again in config file."
434
+
435
+ rlist = [] # used for cropping patches
436
+ clist = [] # used for cropping patches
437
+ for _ in range(num_patch):
438
+ r = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
439
+ c = random.randint(0, self.config['dataset']['crop_size'] - patch_size - 1)
440
+ rlist.append(r)
441
+ clist.append(c)
442
+
443
+ for i in range(len(data_names)):
444
+ # load transformed image
445
+ patch = return_dict[patch_sources[i]] if not diff_patch else return_dict[patch_sources[i] + '_next']
446
+
447
+ # crop patch
448
+ patchs = []
449
+ _, h, w = patch.size()
450
+
451
+ for j in range(num_patch):
452
+ patchs.append(patch[:, rlist[j]:rlist[j] + patch_size, clist[j]:clist[j] + patch_size])
453
+ patchs = torch.cat(patchs, 0)
454
+
455
+ return_dict[data_names[i]] = patchs
456
+
457
+ return return_dict
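Note on the loading logic above: a data group can point at its files in three ways — an explicit file list, one folder per data name (keys of the form <data_name>_folder), or a single dataroot laid out as dataroot/<phase>/<data_name>. A minimal sketch of such a group as StaticData expects it in the parsed config; all names and paths here are illustrative, not taken from a shipped config:

# Hypothetical entry under config['dataset']['train_data'], mirroring the checks above.
example_train_data = {
    'group_1': {
        'paired': True,
        'data_names': ['paired_A', 'paired_B'],
        'data_types': ['image', 'image'],
        'preprocess': ['resize'],
        # Way 1: a file list with one whitespace-separated line per sample.
        # 'file_list': '/path/to/pairs.txt',
        # Way 2: one folder per data name.
        # 'paired_A_folder': '/path/to/trainA',
        # 'paired_B_folder': '/path/to/trainB',
        # Way 3: a dataroot containing <phase>/<data_name> subfolders.
        'dataroot': '/path/to/FFHQ',
    },
}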
data/super_dataset.py ADDED
@@ -0,0 +1,321 @@
1
+ import copy
2
+ import torch.utils.data as data
3
+ from utils.data_utils import check_img_loaded, check_numpy_loaded
4
+
5
+ from data.test_data import add_test_data, apply_test_transforms
6
+ from data.test_video_data import TestVideoData
7
+ from data.static_data import StaticData
8
+
9
+ from multiprocessing import Pool
10
+ import sys
11
+
12
+
13
+ class DataBin(object):
14
+ def __init__(self, filegroups):
15
+ self.filegroups = filegroups
16
+
17
+
18
+ class SuperDataset(data.Dataset):
19
+ def __init__(self, config, shuffle=False, check_all_data=False, DDP_device=None):
20
+ self.config = config
21
+
22
+ self.check_all_data = check_all_data
23
+ self.DDP_device = DDP_device
24
+
25
+ self.data = {} # Will be dictionary. Keys are data names, e.g. paired_A, patch_A. Values are lists containing associated data.
26
+ self.transforms = {}
27
+
28
+ if self.config['common']['phase'] == 'test':
29
+ if not self.config['testing']['test_video'] is None:
30
+ self.test_video_data = TestVideoData(self.config)
31
+ else:
32
+ add_test_data(self.data, self.transforms, self.config)
33
+ return
34
+
35
+ self.static_data = StaticData(self.config, shuffle)
36
+
37
+
38
+ def convert_old_config_to_new(self):
39
+ data_types = self.config['dataset']['data_type']
40
+ if len(data_types) == 1 and data_types[0] == 'custom':
41
+ # convert custom data configuration to new data configuration
42
+ old_dict = self.config['dataset']['custom_' + self.config['common']['phase'] + '_data']
43
+ preprocess_list = self.config['dataset']['preprocess']
44
+ new_datadict = self.config['dataset'][self.config['common']['phase'] + '_data'] = old_dict
45
+ for i, group in enumerate(new_datadict.values()): # examples: (0, group_1), (1, group_2)
46
+ group['paired'] = True
47
+ group['preprocess'] = preprocess_list
48
+ # custom data does not support patch so we skip patch logic.
49
+ else:
50
+ new_datadict = self.config['dataset'][self.config['common']['phase'] + '_data'] = {}
51
+ preprocess_list = self.config['dataset']['preprocess']
52
+ new_datadict['paired_group'] = {}
53
+ new_datadict['paired_group']['paired'] = True
54
+ new_datadict['paired_group']['data_types'] = []
55
+ new_datadict['paired_group']['data_names'] = []
56
+ new_datadict['paired_group']['preprocess'] = preprocess_list
57
+ new_datadict['unpaired_group'] = {}
58
+ new_datadict['unpaired_group']['paired'] = False
59
+ new_datadict['unpaired_group']['data_types'] = []
60
+ new_datadict['unpaired_group']['data_names'] = []
61
+ new_datadict['unpaired_group']['preprocess'] = preprocess_list
62
+
63
+ for i in range(len(self.config['dataset']['data_type'])):
64
+ data_type = self.config['dataset']['data_type'][i]
65
+ if data_type == 'paired' or data_type == 'paired_numpy':
66
+ if self.config['dataset']['paired_' + self.config['common']['phase'] + '_filelist'] != '':
67
+ new_datadict['paired_group']['file_list'] = self.config['dataset'][
68
+ 'paired_' + self.config['common']['phase'] + '_filelist']
69
+ elif self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_folder'] != '' and \
70
+ self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_folder'] != '':
71
+ new_datadict['paired_group']['paired_A_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_folder']
72
+ new_datadict['paired_group']['paired_B_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_folder']
73
+ else:
74
+ new_datadict['paired_group']['dataroot'] = self.config['dataset']['dataroot']
75
+
76
+ new_datadict['paired_group']['data_names'].append('paired_A')
77
+ new_datadict['paired_group']['data_names'].append('paired_B')
78
+ if data_type == 'paired':
79
+ new_datadict['paired_group']['data_types'].append('image')
80
+ new_datadict['paired_group']['data_types'].append('image')
81
+ else:
82
+ new_datadict['paired_group']['data_types'].append('numpy')
83
+ new_datadict['paired_group']['data_types'].append('numpy')
84
+
85
+ elif data_type == 'unpaired' or data_type == 'unpaired_numpy':
86
+ if self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_filelist'] != ''\
87
+ and self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_filelist'] != '':
88
+ # combine those two filelists into one filelist
89
+ self.combine_two_filelists_into_one(
90
+ self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_filelist'],
91
+ self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_filelist']
92
+ )
93
+ new_datadict['unpaired_group']['file_list'] = './tmp_filelist.txt'
94
+ elif self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_folder'] != '' and \
95
+ self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_folder'] != '':
96
+ new_datadict['unpaired_group']['unpaired_A_folder'] = self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'A_folder']
97
+ new_datadict['unpaired_group']['unpaired_B_folder'] = self.config['dataset']['unpaired_' + self.config['common']['phase'] + 'B_folder']
98
+ else:
99
+ new_datadict['unpaired_group']['dataroot'] = self.config['dataset']['dataroot']
100
+
101
+ new_datadict['unpaired_group']['data_names'].append('unpaired_A')
102
+ new_datadict['unpaired_group']['data_names'].append('unpaired_B')
103
+ if data_type == 'unpaired':
104
+ new_datadict['unpaired_group']['data_types'].append('image')
105
+ new_datadict['unpaired_group']['data_types'].append('image')
106
+ else:
107
+ new_datadict['unpaired_group']['data_types'].append('numpy')
108
+ new_datadict['unpaired_group']['data_types'].append('numpy')
109
+
110
+ elif data_type == 'landmark':
111
+ if self.config['dataset']['paired_' + self.config['common']['phase'] + '_filelist'] != '':
112
+ new_datadict['paired_group']['file_list'] = self.config['dataset'][
113
+ 'paired_' + self.config['common']['phase'] + '_filelist']
114
+ elif 'paired_' + self.config['common']['phase'] + 'A_lmk_folder' in self.config['dataset'] and \
115
+ 'paired_' + self.config['common']['phase'] + 'B_lmk_folder' in self.config['dataset'] and \
116
+ self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_lmk_folder'] != '' and \
117
+ self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_lmk_folder'] != '':
118
+ new_datadict['paired_group']['lmk_A_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'A_lmk_folder']
119
+ new_datadict['paired_group']['lmk_B_folder'] = self.config['dataset']['paired_' + self.config['common']['phase'] + 'B_lmk_folder']
120
+ else:
121
+ new_datadict['paired_group']['dataroot'] = self.config['dataset']['dataroot']
122
+
123
+ new_datadict['paired_group']['data_names'].append('lmk_A')
124
+ new_datadict['paired_group']['data_names'].append('lmk_B')
125
+ new_datadict['paired_group']['data_types'].append('landmark')
126
+ new_datadict['paired_group']['data_types'].append('landmark')
127
+
128
+ # Handle patches. This needs to happen after all non-patch data are added first.
129
+ if 'patch' in self.config['dataset']['data_type']:
130
+ # determine if patch comes from paired or unpaired image
131
+ if 'paired_A' in new_datadict['paired_group']['data_names']:
132
+ new_datadict['paired_group']['data_types'].append('patch')
133
+ new_datadict['paired_group']['data_names'].append('patch_A')
134
+ new_datadict['paired_group']['data_types'].append('patch')
135
+ new_datadict['paired_group']['data_names'].append('patch_B')
136
+
137
+ if 'patch_sources' not in new_datadict['paired_group']:
138
+ new_datadict['paired_group']['patch_sources'] = []
139
+ new_datadict['paired_group']['patch_sources'].append('paired_A')
140
+ new_datadict['paired_group']['patch_sources'].append('paired_B')
141
+ else:
142
+ new_datadict['unpaired_group']['data_types'].append('patch')
143
+ new_datadict['unpaired_group']['data_names'].append('patch_A')
144
+ new_datadict['unpaired_group']['data_types'].append('patch')
145
+ new_datadict['unpaired_group']['data_names'].append('patch_B')
146
+
147
+ if 'patch_sources' not in new_datadict['unpaired_group']:
148
+ new_datadict['unpaired_group']['patch_sources'] = []
149
+ new_datadict['unpaired_group']['patch_sources'].append('unpaired_A')
150
+ new_datadict['unpaired_group']['patch_sources'].append('unpaired_B')
151
+
152
+ if 'diff_patch' not in self.config['dataset']:
153
+ self.config['dataset']['diff_patch'] = False
154
+
155
+ new_datadict = {key: value for key, value in new_datadict.items() if len(value['data_names']) > 0}
156
+
157
+ print('-----------------------------------------------------------------------')
158
+ print("converted %s data configuration: " % self.config['common']['phase'])
159
+ for key, value in new_datadict.items():
160
+ print(key + ': ', value)
161
+ print('-----------------------------------------------------------------------')
162
+
163
+ return self.config
164
+
165
+
166
+ def combine_two_filelists_into_one(self, filelist1, filelist2):
167
+ tmp_file = open('./tmp_filelist.txt', 'w+')
168
+ f1 = open(filelist1, 'r')
169
+ f2 = open(filelist2, 'r')
170
+ f1_lines = f1.readlines()
171
+ f2_lines = f2.readlines()
172
+ min_index = min(len(f1_lines), len(f2_lines))
173
+ for i in range(min_index):
174
+ tmp_file.write(f1_lines[i].strip() + ' ' + f2_lines[i].strip() + '\n')
175
+ if min_index == len(f1_lines):
176
+ for i in range(min_index, len(f2_lines)):
177
+ tmp_file.write('None ' + f2_lines[i].strip() + '\n')
178
+ else:
179
+ for i in range(min_index, len(f1_lines)):
180
+ tmp_file.write(f1_lines[i].strip() + ' None\n')
181
+
182
+ tmp_file.close()
183
+ f1.close()
184
+ f2.close()
185
+
186
+
187
+ def __len__(self):
188
+ if self.config['common']['phase'] == 'test':
189
+ if self.config['testing']['test_video'] is not None:
190
+ return self.test_video_data.get_len()
191
+ else:
192
+ if len(self.data.keys()) == 0:
193
+ return 0
194
+ else:
195
+ min_len = 999999
196
+ for k, v in self.data.items():
197
+ length = len(v)
198
+ if length < min_len:
199
+ min_len = length
200
+ return min_len
201
+ else:
202
+ return self.static_data.get_len()
203
+
204
+
205
+
206
+ def get_item_logic(self, index):
207
+ return_dict = {}
208
+
209
+ if self.config['common']['phase'] == 'test':
210
+ if not self.config['testing']['test_video'] is None:
211
+ return self.test_video_data.get_item()
212
+ else:
213
+ apply_test_transforms(index, self.data, self.transforms, return_dict)
214
+ return return_dict
215
+
216
+ return_dict = self.static_data.get_item(index)
217
+
218
+ return return_dict
219
+
220
+
221
+ def __getitem__(self, index):
222
+ if self.config['dataset']['accept_data_error']:
223
+ while True:
224
+ try:
225
+ return self.get_item_logic(index)
226
+ except Exception as e:
227
+ print("Exception encountered in super_dataset's getitem function: ", e)
228
+ index = (index + 1) % self.__len__()
229
+ else:
230
+ return self.get_item_logic(index)
231
+
232
+
233
+ def split_data(self, value_mode, value, mode='split'):
234
+ new_dataset = copy.deepcopy(self)
235
+ ret1, new_dataset.static_data = self.split_data_helper(self.static_data, new_dataset.static_data, value_mode, value, mode=mode)
236
+ if ret1 is not None:
237
+ self.static_data = ret1
238
+ return self, new_dataset
239
+
240
+
241
+ def split_data_helper(self, dataset, new_dataset, value_mode, value, mode='split'):
242
+ for i in range(len(dataset.file_groups)):
243
+ max_split_index = 0
244
+ for k in dataset.file_groups[i].keys():
245
+ length = len(dataset.file_groups[i][k])
246
+ if value_mode == 'count':
247
+ split_index = min(length, value)
248
+ else:
249
+ split_index = int((1 - value) * length)
250
+ max_split_index = max(max_split_index, split_index)
251
+ new_dataset.file_groups[i][k] = new_dataset.file_groups[i][k][split_index:]
252
+ if mode == 'split':
253
+ dataset.file_groups[i][k] = dataset.file_groups[i][k][:split_index]
254
+ new_dataset.len_of_groups[i] -= max_split_index
255
+ if mode == 'split':
256
+ dataset.len_of_groups[i] = max_split_index
257
+ if mode == 'split':
258
+ return dataset, new_dataset
259
+ else:
260
+ return None, new_dataset
261
+
262
+
263
+ def check_data_helper(self, databin):
264
+ all_pass = True
265
+ for group in databin.filegroups:
266
+ for data_name, data_list in group.items():
267
+ for data in data_list:
268
+ if '.npy' in data: # case: numpy array or landmark
269
+ all_pass = all_pass and check_numpy_loaded(data)
270
+ else: # case: image
271
+ all_pass = all_pass and check_img_loaded(data)
272
+ return all_pass
273
+
274
+
275
+ def check_data(self):
276
+ if self.DDP_device is None or self.DDP_device == 0:
277
+ print("-----------------------Checking all data-------------------------------")
278
+ data_ok = True
279
+ if self.config['dataset']['n_threads'] == 0:
280
+ data_ok = data_ok and self.check_data_helper(self.static_data)
281
+ else:
282
+ # start n_threads number of workers to perform data checking
283
+ with Pool(processes=self.config['dataset']['n_threads']) as pool:
284
+ checks = pool.map(self.check_data_helper,
285
+ self.split_data_into_bins(self.config['dataset']['n_threads']))
286
+ for check in checks:
287
+ data_ok = data_ok and check
288
+ if data_ok:
289
+ print("---------------------all data passed check.-----------------------")
290
+ else:
291
+ print("---------------------Some data failed the data check. "
292
+ "Please fix the files listed above first.---------------------------")
293
+ sys.exit()
294
+
295
+
296
+ def split_data_into_bins(self, num_bins):
297
+ bins = []
298
+ for i in range(num_bins):
299
+ bins.append(DataBin(filegroups=[]))
300
+
301
+ # handle static data
302
+ bins = self.split_data_into_bins_helper(bins, self.static_data)
303
+ return bins
304
+
305
+
306
+ def split_data_into_bins_helper(self, bins, dataset):
307
+ num_bins = len(bins)
308
+ for bin in bins:
309
+ for group_idx in range(len(dataset.file_groups)):
310
+ bin.filegroups.append({})
311
+
312
+ for group_idx in range(len(dataset.file_groups)):
313
+ file_group = dataset.file_groups[group_idx]
314
+ for data_name, data_list in file_group.items():
315
+ num_items_in_bin = max(1, len(data_list) // num_bins)  # guard against having fewer items than bins
316
+ for data_index in range(len(data_list)):
317
+ which_bin = min(data_index // num_items_in_bin, num_bins - 1)
318
+ if data_name not in bins[which_bin].filegroups[group_idx]:
319
+ bins[which_bin].filegroups[group_idx][data_name] = []
320
+ bins[which_bin].filegroups[group_idx][data_name].append(data_list[data_index])
321
+ return bins
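Since SuperDataset subclasses torch.utils.data.Dataset, downstream code can wrap it directly in a DataLoader. A minimal sketch, assuming `config` is the parsed YAML config dict used throughout this repo:

import torch.utils.data

dataset = SuperDataset(config, shuffle=True)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config['dataset']['batch_size'],
    shuffle=True,
    num_workers=config['dataset']['n_threads'],
    drop_last=config['dataset']['drop_last'],
)
for batch in loader:
    # each batch is a dict keyed by data names, e.g. batch['unpaired_A'],
    # plus the corresponding '*_path' entries added in get_item()
    break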
data/test_data.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ from utils.util import check_path_is_static_data
3
+ from utils.data_utils import Transforms
4
+ from utils.augmentation import ImagePathToImage, NumpyToTensor
5
+
6
+ def add_test_data(data, transforms, config):
7
+ A_paths = []
8
+ B_paths = []
9
+
10
+ if not config['testing']['test_img'] is None:
11
+ A_paths.append(config['testing']['test_img'])
12
+ B_paths.append(config['testing']['test_img'])
13
+ else:
14
+ files = os.listdir(config['testing']['test_folder'])
15
+ for fn in files:
16
+ if not check_path_is_static_data(fn):
17
+ continue
18
+ full_path = os.path.join(config['testing']['test_folder'], fn)
19
+ A_paths.append(full_path)
20
+ B_paths.append(full_path)
21
+
22
+ btoA = config['dataset']['direction'] == 'BtoA'
23
+ # get the number of channels of input image
24
+ input_nc = config['model']['output_nc'] if btoA else config['model']['input_nc']
25
+ output_nc = config['model']['input_nc'] if btoA else config['model']['output_nc']
26
+
27
+ transform = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
28
+ transform.create_transforms_from_list(config['testing']['preprocess'])
29
+ transform.get_transforms().insert(0, ImagePathToImage())
30
+ transform = transform.compose_transforms()
31
+
32
+ transform_np = Transforms(config, input_grayscale_flag=(input_nc == 1), output_grayscale_flag=(output_nc == 1))
33
+ transform_np.transform_list.append(NumpyToTensor())
34
+ transform_np = transform_np.compose_transforms()
35
+
36
+ data['test_A_path'] = A_paths
37
+ data['test_B_path'] = B_paths
38
+ transforms['test'] = transform
39
+ transforms['test_np'] = transform_np
40
+
41
+ def apply_test_transforms(index, data, transforms, return_dict):
42
+ if len(data['test_A_path']) > 0:
43
+ ext_name = os.path.splitext(data['test_A_path'][index])[1]
44
+ if not ext_name.lower() in ['.npy', '.npz']:
45
+ return_dict['test_A'], return_dict['test_B'] = transforms['test'] \
46
+ (data['test_A_path'][index], data['test_B_path'][index])
47
+ else:
48
+ return_dict['test_A'], return_dict['test_B'] = transforms['test_np'] \
49
+ (data['test_A_path'][index], data['test_B_path'][index])
50
+ return_dict['test_A_path'] = data['test_A_path'][index]
51
+ return_dict['test_B_path'] = data['test_B_path'][index]
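A short usage sketch of the two helpers above at test time, assuming config['testing']['test_img'] or config['testing']['test_folder'] is set:

data, transforms = {}, {}
add_test_data(data, transforms, config)              # collect test paths and build transforms
sample = {}
apply_test_transforms(0, data, transforms, sample)   # process the first test item
# sample now holds 'test_A'/'test_B' tensors and 'test_A_path'/'test_B_path' strings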
data/test_video_data.py ADDED
@@ -0,0 +1,28 @@
1
+ import imp
2
+ import cv2
3
+ from PIL import Image
4
+ from utils.data_utils import Transforms
5
+
6
+
7
+ class TestVideoData(object):
8
+
9
+ def __init__(self, config):
10
+
11
+ self.vcap = cv2.VideoCapture(config['testing']['test_video'])
12
+ self.transform = Transforms(config)
13
+ self.transform.create_transforms_from_list(config['testing']['preprocess'])
14
+ self.transform = self.transform.compose_transforms()
15
+
16
+ def __del__(self):
17
+ self.vcap.release()
18
+
19
+ def get_len(self):
20
+ return int(self.vcap.get(cv2.CAP_PROP_FRAME_COUNT))
21
+
22
+ def get_item(self):
23
+ return_dict = {}
24
+ _, frame = self.vcap.read()
25
+ frame = Image.fromarray(frame[:,:,::-1]).convert('RGB')
26
+ return_dict['test_A'], return_dict['test_B'] = self.transform(frame, frame)
27
+ return_dict['test_A_path'], return_dict['test_B_path'] = 'A.jpg', 'B.jpg'
28
+ return return_dict
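A minimal usage sketch, assuming config['testing']['test_video'] points to a readable video file:

video_data = TestVideoData(config)
for _ in range(video_data.get_len()):
    item = video_data.get_item()   # one decoded and transformed frame per call
    frame = item['test_A']         # tensor ready to feed to the generator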
exp/sp2pII-phase1.yaml ADDED
@@ -0,0 +1,49 @@
1
+ common:
2
+ name: "sp2pII-phase1"
3
+ model: "style_based_pix2pixII"
4
+ gpu_ids: [0]
5
+ option_group:
6
+ - gpu_ids: [0]
7
+ - gpu_ids: [1]
8
+ - gpu_ids: [2]
9
+
10
+ model:
11
+ ngf: 64
12
+
13
+ dataset:
14
+ unpaired_trainA_folder: # source domain folder path(FFHQ)
15
+ unpaired_trainB_folder: # target domain folder path(AAHQ)
16
+ preprocess: ["resize"]
17
+ batch_size: 8
18
+ crop_size: 512
19
+ drop_last: true
20
+ load_size: 512
21
+
22
+ training:
23
+ epoch_as_iter: true
24
+ n_epochs: 100000
25
+ n_epochs_decay: 10
26
+ print_freq: 1000
27
+ pretrained_model: "pretrained_models/ffhq_pretrain_res512_200000.pt"
28
+ save_epoch_freq: 5000
29
+ style_mixing_prob: 0.5
30
+ lambda_GAN: 1.0
31
+ lambda_ST: 1.0
32
+ lambda_L1: 1.0
33
+ option_group:
34
+ - lambda_Feat: 4.0
35
+ - lambda_Feat: 2.0
36
+ - lambda_Feat: 1.0
37
+ lr: 0.001
38
+ lr_policy: "linear"
39
+ beta1: 0.1
40
+
41
+ testing:
42
+ num_test: 100000
43
+ preprocess: ["resize"]
44
+ load_size: 512
45
+ crop_size: 512
46
+ results_dir: "./results/sp2pII"
47
+ visual_names: ["fake_B"]
48
+ image_format: "png"
49
+ which_epoch: "latest"
exp/sp2pII-phase2.yaml ADDED
@@ -0,0 +1,49 @@
1
+ common:
2
+ name: "sp2pII-phase2"
3
+ model: "style_based_pix2pixII"
4
+ gpu_ids: [0]
5
+ option_group:
6
+ - gpu_ids: [0]
7
+ - gpu_ids: [1]
8
+ - gpu_ids: [2]
9
+
10
+ model:
11
+ ngf: 64
12
+
13
+ dataset:
14
+ unpaired_trainA_folder: # source domain folder path(FFHQ)
15
+ unpaired_trainB_folder: # target domain folder path(AAHQ)
16
+ preprocess: ["resize"]
17
+ batch_size: 8
18
+ crop_size: 512
19
+ drop_last: true
20
+ load_size: 512
21
+
22
+ training:
23
+ epoch_as_iter: true
24
+ n_epochs: 300000
25
+ n_epochs_decay: 10
26
+ print_freq: 1000
27
+ phase: 2
28
+ pretrained_model: "pretrained_models/ffhq_pretrain_res512_200000.pt" # phase1 model
29
+ save_epoch_freq: 5000
30
+ style_mixing_prob: 0.5
31
+ lambda_GAN: 1.0
32
+ lambda_ST: 0.5 # this weight can be tuned
33
+ option_group:
34
+ - data_aug_prob: 0.0
35
+ - data_aug_prob: 0.1
36
+ - data_aug_prob: 0.2
37
+ lr: 0.001
38
+ lr_policy: "linear"
39
+ beta1: 0.1
40
+
41
+ testing:
42
+ num_test: 100000
43
+ preprocess: ["resize"]
44
+ load_size: 512
45
+ crop_size: 512
46
+ results_dir: "./results/sp2pII"
47
+ visual_names: ["fake_B"]
48
+ image_format: "png"
49
+ which_epoch: "latest"
exp/sp2pII-phase3.yaml ADDED
@@ -0,0 +1,50 @@
1
+ common:
2
+ name: "sp2pII-phase3"
3
+ model: "style_based_pix2pixII"
4
+ gpu_ids: [0]
5
+ option_group:
6
+ - gpu_ids: [0]
7
+ - gpu_ids: [1]
8
+ - gpu_ids: [2]
9
+
10
+ model:
11
+ ngf: 64
12
+
13
+ dataset:
14
+ unpaired_trainA_folder: # source domain folder path(FFHQ)
15
+ unpaired_trainB_folder: # target domain folder path(AAHQ)
16
+ preprocess: ["resize"]
17
+ batch_size: 8
18
+ crop_size: 512
19
+ drop_last: true
20
+ load_size: 512
21
+
22
+ training:
23
+ epoch_as_iter: true
24
+ n_epochs: 100000 # this phase converges quickly; about 100k iterations is enough
25
+ n_epochs_decay: 10
26
+ print_freq: 1000
27
+ phase: 3
28
+ pretrained_model: "pretrained_models/phase2_pretrain_90000.pth"
29
+ save_epoch_freq: 5000
30
+ style_mixing_prob: 0.5
31
+ lambda_GAN: 1.0
32
+ lambda_ST: 1.0
33
+ lambda_L1: 1.0
34
+ option_group:
35
+ - lambda_Feat: 4.0
36
+ - lambda_Feat: 2.0
37
+ - lambda_Feat: 1.0
38
+ lr: 0.0002
39
+ lr_policy: "linear"
40
+ beta1: 0.9
41
+
42
+ testing:
43
+ num_test: 100000
44
+ preprocess: ["resize"]
45
+ load_size: 512
46
+ crop_size: 512
47
+ results_dir: "./results/sp2pII"
48
+ visual_names: ["fake_B"]
49
+ image_format: "png"
50
+ which_epoch: "latest"
exp/sp2pII-phase4.yaml ADDED
@@ -0,0 +1,49 @@
1
+ common:
2
+ name: sp2pII-phase4
3
+ model: style_based_pix2pixII
4
+ gpu_ids:
5
+ - 0
6
+
7
+ dataset:
8
+ batch_size: 8
9
+ crop_size: 512
10
+ drop_last: true
11
+ load_size: 512
12
+ preprocess:
13
+ - resize
14
+ unpaired_trainA_folder: "/share/group_machongyang/project/hand_drawn/data/dataset/0909/trainPA/" # source domain folder path(FFHQ)
15
+ unpaired_trainB_folder: "/share/group_machongyang/project/hand_drawn/data/dataset/0909/trainPA/" # target domain folder path(AAHQ)
16
+ model:
17
+ ngf: 64
18
+ testing:
19
+ crop_size: 512
20
+ image_format: png
21
+ load_size: 512
22
+ num_test: 100000
23
+ preprocess:
24
+ - resize
25
+ results_dir: ./results/sp2pII
26
+ visual_names:
27
+ - fake_B
28
+ which_epoch: latest
29
+ training:
30
+ beta1: 0.9
31
+ epoch_as_iter: true
32
+ lambda_Feat: 4.0
33
+ lambda_GAN: 1.0
34
+ lambda_L1: 1.0
35
+ lambda_ST: 0.5
36
+ lambda_CLIP: 1.0 # this weight needs tuning
37
+ lambda_PROJ: 100.0 # this weight needs tuning (only used with an image prompt)
38
+ ema: 0.99 # 1-1/n
39
+ text_prompt: "not existed"
40
+ image_prompt: "" # if this file exists, the image prompt is used; otherwise the text prompt is used
41
+ lr: 0.0002 # needs tuning, roughly between 1e-5 and 2e-4
42
+ lr_policy: linear
43
+ n_epochs: 200 # generally about 500 iterations is enough
44
+ n_epochs_decay: 10
45
+ phase: 4
46
+ pretrained_model: pretrained_models/phase3_pretrain_10000.pth
47
+ print_freq: 50
48
+ save_epoch_freq: 200
49
+ style_mixing_prob: 0.5
logs/01_2023_09_07__18_32_26/events.out.tfevents.1694082748.aiplatform-wlf2-hi-12.idchb2az2.hb2.kwaidc.com.16044.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f0a1453d4a30c719d0a3e7b8aa8e0cecdd4d38d931606833e3fb5ce4165d171
3
+ size 38280782
logs/01_2023_09_12__14_54_32/events.out.tfevents.1694501684.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.76748.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8720497d0782c29971a6d1c2d5f538204f0be8cc406a1f6e2584d0450c9bd179
3
+ size 40
logs/01_2023_09_12__14_55_34/events.out.tfevents.1694501736.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77369.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e810b330e416e9f922c525b985de435ba2ef25afa013b9c732372a6dac58cf8
3
+ size 31611368
logs/01_2023_09_12__15_03_47/events.out.tfevents.1694502229.aiplatform-wlf2-ge4-22.idchb2az2.hb2.kwaidc.com.77940.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b91beff44afb7b30e880de6477906b5adbdd771f2514158c166093018a0ee55
3
+ size 30850690
models/__init__.py ADDED
@@ -0,0 +1,68 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example of usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ #print(modellib)
35
+ model = None
36
+ target_model_name = model_name.replace('_', '') + 'model'
37
+ for name, cls in modellib.__dict__.items():
38
+ if name.lower() == target_model_name.lower() \
39
+ and issubclass(cls, BaseModel):
40
+ model = cls
41
+
42
+ if model is None:
43
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
44
+ exit(0)
45
+
46
+ return model
47
+
48
+
49
+ def get_option_setter(model_name):
50
+ """Return the static method <modify_commandline_options> of the model class."""
51
+ model_class = find_model_using_name(model_name)
52
+ return model_class.modify_commandline_options
53
+
54
+
55
+ def create_model(config, DDP_device=None):
56
+ """Create a model given the option.
57
+
58
+ This function wraps model class lookup and instantiation.
59
+ This is the main interface between this package and 'train.py'/'test.py'
60
+
61
+ Example:
62
+ >>> from models import create_model
63
+ >>> model = create_model(opt)
64
+ """
65
+ model = find_model_using_name(config['common']['model'])
66
+ instance = model(config, DDP_device=DDP_device)
67
+ print("model [%s] was created" % type(instance).__name__)
68
+ return instance
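To make the 'dummy' convention from the module docstring concrete, a hypothetical models/dummy_model.py could look like the sketch below. It only illustrates the required structure (data keys such as 'paired_A'/'paired_B' are assumptions), not a useful model:

import torch
from models.base_model import BaseModel
from models.modules import networks


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser  # no model-specific options in this sketch

    def __init__(self, config, DDP_device=None):
        BaseModel.__init__(self, config, DDP_device=DDP_device)
        self.loss_names = ['G']            # reported as loss_G
        self.model_names = ['G']           # saved/loaded as netG
        self.visual_names = ['real_A', 'fake_B']
        self.netG = networks.init_net(torch.nn.Conv2d(3, 3, 1), gpu_ids=self.gpu_ids)
        self.criterion = torch.nn.L1Loss()
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=config['training']['lr'])
        self.optimizers = [self.optimizer_G]

    def set_input(self, input):
        self.real_A = input['paired_A'].to(self.device)
        self.real_B = input['paired_B'].to(self.device)

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def optimize_parameters(self):
        self.forward()
        self.optimizer_G.zero_grad()
        self.loss_G = self.criterion(self.fake_B, self.real_B)
        self.loss_G.backward()
        self.optimizer_G.step()

    def eval_step(self):
        with torch.no_grad():
            self.forward()
            self.loss_G = self.criterion(self.fake_B, self.real_B)

    def trace_jit(self, input):
        # sketch only: tracing a DataParallel-wrapped module may require .module
        return torch.jit.trace(self.netG, input)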
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (3.27 kB). View file
 
models/__pycache__/base_model.cpython-38.pyc ADDED
Binary file (12.2 kB). View file
 
models/__pycache__/style_based_pix2pixII_model.cpython-38.pyc ADDED
Binary file (15.6 kB). View file
 
models/base_model.py ADDED
@@ -0,0 +1,340 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from models.modules import networks
6
+ from utils.util import check_path
7
+ from utils.net_size import calc_computation
8
+
9
+
10
+ class BaseModel(ABC):
11
+ """This class is an abstract base class (ABC) for models.
12
+ To create a subclass, you need to implement the following five functions:
13
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
14
+ -- <set_input>: unpack data from dataset and apply preprocessing.
15
+ -- <forward>: produce intermediate results.
16
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
17
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
18
+ """
19
+
20
+ def __init__(self, config, DDP_device=None):
21
+ """Initialize the BaseModel class.
22
+
23
+ Parameters:
24
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
25
+
26
+ When creating your custom class, you need to implement your own initialization.
27
+ In this function, you should first call <BaseModel.__init__(self, opt)>
28
+ Then, you need to define four lists:
29
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
30
+ -- self.model_names (str list): define networks used in our training.
31
+ -- self.visual_names (str list): specify the images that you want to display and save.
32
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
33
+ """
34
+ self.config = config
35
+ self.gpu_ids = config['common']['gpu_ids']
36
+ self.isTrain = config['common']['phase'] == 'train'
37
+ if DDP_device is None:
38
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
39
+ self.DDP_device = None
40
+ self.on_cpu = (self.device.type == 'cpu')
41
+ else:
42
+ self.device = DDP_device
43
+ self.DDP_device = DDP_device
44
+ self.on_cpu = False
45
+ self.save_dir = os.path.join(config['training']['checkpoints_dir'], config['common']['name']) # save all the checkpoints to save_dir
46
+ if config['dataset']['preprocess'] != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
47
+ torch.backends.cudnn.benchmark = True
48
+ self.loss_names = []
49
+ self.model_names = []
50
+ self.visual_names = []
51
+ self.optimizers = []
52
+ self.image_paths = []
53
+ self.metric = 0 # used for learning rate policy 'plateau'
54
+ self.curr_epoch = 0
55
+ self.total_iters = 0
56
+ self.best_val_loss = 999999
57
+
58
+ @abstractmethod
59
+ def set_input(self, input):
60
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
61
+
62
+ Parameters:
63
+ input (dict): includes the data itself and its metadata information.
64
+ """
65
+ pass
66
+
67
+ @abstractmethod
68
+ def forward(self):
69
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
70
+ pass
71
+
72
+ @abstractmethod
73
+ def trace_jit(self, input):
74
+ """trace torchscript model for C++. Called by <trace_jit.py>"""
75
+ pass
76
+
77
+ @abstractmethod
78
+ def optimize_parameters(self):
79
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
80
+ pass
81
+
82
+ @abstractmethod
83
+ def eval_step(self):
84
+ """Forward and backward pass but without upgrading weights; called in every validation iteration"""
85
+ pass
86
+
87
+ def setup(self, config, DDP_device=None):
88
+ """Load and print networks; create schedulers
89
+
90
+ Parameters:
91
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
92
+ """
93
+ if self.isTrain:
94
+ self.schedulers = [networks.get_scheduler(optimizer, config) for optimizer in self.optimizers]
95
+ if not self.isTrain:
96
+ load_suffix = '{}'.format(config['testing']['which_epoch'])
97
+ self.load_networks(load_suffix)
98
+ elif config['training']['continue_train']:
99
+ load_suffix = '{}'.format(config['training']['which_epoch'])
100
+ self.load_networks(load_suffix)
101
+ self.print_networks(config['common']['verbose'], DDP_device=DDP_device)
102
+
103
+ def eval(self):
104
+ """Make models eval mode during test time"""
105
+ for name in self.model_names:
106
+ if isinstance(name, str):
107
+ net = getattr(self, 'net' + name)
108
+ net.eval()
109
+
110
+ def train(self):
111
+ """Make models train mode during train time"""
112
+ for name in self.model_names:
113
+ if isinstance(name, str):
114
+ net = getattr(self, 'net' + name)
115
+ net.train()
116
+
117
+ def test(self):
118
+ """Forward function used in test time.
119
+
120
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
121
+ It also calls <compute_visuals> to produce additional visualization results
122
+ """
123
+ with torch.no_grad():
124
+ self.forward()
125
+ self.compute_visuals()
126
+
127
+ def compute_visuals(self):
128
+ """Calculate additional output images for visdom and HTML visualization"""
129
+ pass
130
+
131
+ def get_image_paths(self):
132
+ """ Return image paths that are used to load current data"""
133
+ return self.image_paths
134
+
135
+ def update_learning_rate(self):
136
+ """Update learning rates for all the networks; called at the end of every epoch"""
137
+ for scheduler in self.schedulers:
138
+ if self.config['training']['lr_policy'] == 'plateau':
139
+ scheduler.step(self.metric, epoch=self.curr_epoch)
140
+ else:
141
+ scheduler.step(epoch=self.curr_epoch)
142
+
143
+ # lr = self.optimizers[0].param_groups[0]['lr']
144
+ # print('learning rate = %.7f' % lr)
145
+
146
+ def get_current_visuals(self):
147
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
148
+ if not self.isTrain and len(self.config['testing']['visual_names']) > 0:
149
+ visual_names = list(set(self.visual_names).intersection(set(self.config['testing']['visual_names'])))
150
+ else:
151
+ visual_names = self.visual_names
152
+ visual_ret = OrderedDict()
153
+ for name in visual_names:
154
+ if isinstance(name, str):
155
+ visual_ret[name] = getattr(self, name)
156
+ return visual_ret
157
+
158
+ def get_current_losses(self):
159
+ """Return training losses / errors. train.py will print out these errors on the console, and save them to a file"""
160
+ errors_ret = OrderedDict()
161
+ for name in self.loss_names:
162
+ if isinstance(name, str):
163
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
164
+ return errors_ret
165
+
166
+ def save_networks(self, epoch, val_loss=None):
167
+ """Save all the networks to the disk.
168
+
169
+ Parameters:
170
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
171
+ """
172
+ check_path(self.save_dir)
173
+ save_filename = 'epoch_%s.pth' % epoch if val_loss is None else 'best_val_epoch.pth'
174
+ checkpoint = {}
175
+ # save all the models
176
+ for name in self.model_names:
177
+ if isinstance(name, str):
178
+ net = getattr(self, 'net' + name)
179
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
180
+ # if use DDP, save only on rank 0. If using dataparallel, second condition meets.
181
+ if self.DDP_device == 0 or self.DDP_device is None:
182
+ checkpoint[name+'_model'] = net.module.state_dict()
183
+ else:
184
+ checkpoint[name+'_model'] = net.state_dict()
185
+
186
+ # save all the optimizers
187
+ optimizer_index = 0
188
+ for optimizer in self.optimizers:
189
+ checkpoint['optimizer_'+str(optimizer_index)] = optimizer.state_dict()
190
+ optimizer_index += 1
191
+
192
+ # save all the schedulers
193
+ scheduler_index = 0
194
+ for scheduler in self.schedulers:
195
+ checkpoint['scheduler_' + str(scheduler_index)] = scheduler.state_dict()
196
+ scheduler_index += 1
197
+
198
+ # save other information
199
+ checkpoint['epoch'] = self.curr_epoch
200
+ checkpoint['total_iters'] = self.total_iters
201
+ checkpoint['metric'] = self.metric
202
+ if val_loss is not None:
203
+ checkpoint['best_val_loss'] = val_loss
204
+
205
+ torch.save(checkpoint, os.path.join(self.save_dir, save_filename))
206
+
207
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
208
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
209
+ key = keys[i]
210
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
211
+ if module.__class__.__name__.startswith('InstanceNorm') and \
212
+ (key == 'running_mean' or key == 'running_var'):
213
+ if getattr(module, key) is None:
214
+ state_dict.pop('.'.join(keys))
215
+ if module.__class__.__name__.startswith('InstanceNorm') and \
216
+ (key == 'num_batches_tracked'):
217
+ state_dict.pop('.'.join(keys))
218
+ else:
219
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
220
+
221
+ def load_networks(self, epoch, ckpt=None):
222
+ """Load all the networks from the disk.
223
+
224
+ Parameters:
225
+ epoch (str) -- current epoch; used in the file name 'epoch_%s.pth' % epoch. Models in the old format
226
+ with the names '%s_net_%s.pth' % (epoch, name) are also supported. Models in the new format takes priority.
227
+ """
228
+ load_filename = 'epoch_%s.pth' % epoch
229
+ if ckpt is None:
230
+ final_load_path = os.path.join(self.save_dir, load_filename)
231
+ else:
232
+ final_load_path = ckpt
233
+ if os.path.exists(final_load_path):
234
+ # new checkpoint format.
235
+ print('loading the model in new format from %s' % final_load_path)
236
+ if self.DDP_device is not None:
237
+ # unpack the tensors on GPU 0, then transfer to whatever device it needs to be on
238
+ map_location = {'cuda:%d' % 0: 'cuda:%d' % self.DDP_device}
239
+ checkpoint = torch.load(final_load_path, map_location=map_location)
240
+ else:
241
+ checkpoint = torch.load(final_load_path)
242
+ for k, v in checkpoint.items():
243
+ # load models
244
+ if 'model' in k:
245
+ name = k.split('_model')[0]
246
+ if not self.isTrain and 'D' in name: # does not load discriminator when not training
247
+ continue
248
+ if not hasattr(self, 'net' + name):
249
+ continue
250
+ net = getattr(self, 'net' + name)
251
+ if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
252
+ net = net.module
253
+
254
+ # if you are using PyTorch newer than 0.4 (e.g., built from
255
+ # GitHub source), you can remove str() on self.device
256
+ if hasattr(v, '_metadata'):
257
+ del v._metadata
258
+
259
+ # patch InstanceNorm checkpoints prior to 0.4
260
+ for key in list(v.keys()): # need to copy keys here because we mutate in loop
261
+ self.__patch_instance_norm_state_dict(v, net, key.split('.'))
262
+ net.load_state_dict(v)
263
+ # load optimizers
264
+ elif 'optimizer' in k:
265
+ if not self.isTrain:
266
+ continue
267
+ index = int(k.split('_')[-1])
268
+ self.optimizers[index].load_state_dict(v)
269
+ # load schedulers
270
+ elif 'scheduler' in k:
271
+ if not self.isTrain:
272
+ continue
273
+ index = int(k.split('_')[-1])
274
+ self.schedulers[index].load_state_dict(v)
275
+ # load other stuffs
276
+ elif k == 'epoch':
277
+ self.curr_epoch = int(v) + 1
278
+ elif k == 'total_iters':
279
+ self.total_iters = int(v)
280
+ elif k == 'metric':
281
+ self.metric = float(v)
282
+ elif k == 'best_val_loss':
283
+ self.best_val_loss = float(v)
284
+ else:
285
+ print('Checkpoint load error. Unrecognized parameter saved in checkpoint: ', k)
286
+ return
287
+
288
+ # old checkpoint format.
289
+ for name in self.model_names:
290
+ if isinstance(name, str):
291
+ load_filename = '%s_net_%s.pth' % (epoch, name)
292
+ load_path = os.path.join(self.save_dir, load_filename)
293
+ net = getattr(self, 'net' + name)
294
+ if isinstance(net, torch.nn.DataParallel) or isinstance(net, torch.nn.parallel.DistributedDataParallel):
295
+ net = net.module
296
+ print('loading the model from %s' % load_path)
297
+ # if you are using PyTorch newer than 0.4 (e.g., built from
298
+ # GitHub source), you can remove str() on self.device
299
+ state_dict = torch.load(load_path, map_location=str(self.device))
300
+ if hasattr(state_dict, '_metadata'):
301
+ del state_dict._metadata
302
+
303
+ # patch InstanceNorm checkpoints prior to 0.4
304
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
305
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
306
+ net.load_state_dict(state_dict)
307
+
308
+ def print_networks(self, verbose, DDP_device=None):
309
+ """Print the total number of parameters in the network and (if verbose) network architecture
310
+
311
+ Parameters:
312
+ verbose (bool) -- if verbose: print the network architecture
313
+ """
314
+ if DDP_device is None or DDP_device == 0:
315
+ print('---------- Networks initialized -------------')
316
+ for name in self.model_names:
317
+ if isinstance(name, str):
318
+ net = getattr(self, 'net' + name)
319
+ num_params = 0
320
+ for param in net.parameters():
321
+ num_params += param.numel()
322
+ if verbose:
323
+ print(net)
324
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
325
+ if 'G' in name:
326
+ calc_computation(net, self.config['model']['input_nc'], self.config['dataset']['crop_size'],self.config['dataset']['crop_size'], DDP_device=DDP_device)
327
+ print('-----------------------------------------------')
328
+
329
+ def set_requires_grad(self, nets, requires_grad=False):
330
+ """Set requires_grad=False for all the networks to avoid unnecessary computations
331
+ Parameters:
332
+ nets (network list) -- a list of networks
333
+ requires_grad (bool) -- whether the networks require gradients or not
334
+ """
335
+ if not isinstance(nets, list):
336
+ nets = [nets]
337
+ for net in nets:
338
+ if net is not None:
339
+ for param in net.parameters():
340
+ param.requires_grad = requires_grad
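For reference, save_networks above writes everything into a single dict per checkpoint. A hedged sketch of inspecting such a file offline (the path is only an example; the key names follow the code above):

import torch

ckpt = torch.load('checkpoints/sp2pII-phase1/epoch_latest.pth', map_location='cpu')
print(sorted(ckpt.keys()))
# expected keys: '<name>_model' state dicts, 'optimizer_<i>', 'scheduler_<i>',
# plus 'epoch', 'total_iters', 'metric' and, for the best-validation snapshot, 'best_val_loss'
generator_state = ckpt['G_model']  # assumes a network named 'G' was saved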
models/modules/__init__.py ADDED
File without changes
models/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (155 Bytes). View file
 
models/modules/__pycache__/networks.cpython-38.pyc ADDED
Binary file (37.8 kB). View file
 
models/modules/networks.py ADDED
@@ -0,0 +1,1101 @@
1
+ import re
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.utils.spectral_norm as spectral_norm
6
+ from torch.nn import init
7
+ from torch.autograd import Variable
8
+ import functools
9
+ from torch.optim import lr_scheduler
10
+ from packaging import version
11
+ import numpy as np
12
+
13
+ ###############################################################################
14
+ # Helper Functions
15
+ ###############################################################################
16
+
17
+
18
+ class Identity(nn.Module):
19
+ def forward(self, x):
20
+ return x
21
+
22
+
23
+ def get_norm_layer(norm_type='instance'):
24
+ """Return a normalization layer
25
+
26
+ Parameters:
27
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
28
+
29
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
30
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
31
+ """
32
+ if norm_type == 'batch':
33
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
34
+ elif norm_type == 'instance':
35
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
36
+ elif norm_type == 'none':
37
+ def norm_layer(x): return Identity()
38
+ else:
39
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
40
+ return norm_layer
41
+
42
+
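A minimal usage sketch (not part of the commit): the helper returns a factory that is then instantiated once per channel count; the import path and channel sizes below are assumptions for illustration.

import torch.nn as nn
from models.modules.networks import get_norm_layer  # assumed import path for this file

norm_layer = get_norm_layer('instance')   # functools.partial over nn.InstanceNorm2d
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    norm_layer(64),                       # InstanceNorm2d(64, affine=False, track_running_stats=False)
    nn.ReLU(True),
)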
43
+ def get_scheduler(optimizer, config):
44
+ """Return a learning rate scheduler
45
+
46
+ Parameters:
47
+ optimizer -- the optimizer of the network
48
+ config (dict) -- the experiment configuration; stores all the experiment flags
49
+ config['training']['lr_policy'] is the name of the learning rate policy: linear | step | plateau | cosine
50
+
51
+ For 'linear', we keep the same learning rate for the first config['training']['n_epochs'] epochs
52
+ and linearly decay the rate to zero over the next config['training']['n_epochs_decay'] epochs.
53
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
54
+ See https://pytorch.org/docs/stable/optim.html for more details.
55
+ """
56
+ if config['training']['lr_policy'] == 'linear':
57
+ def lambda_rule(epoch):
58
+ lr_l = 1.0 - max(0, epoch + 1 - config['training']['n_epochs']) / float(config['training']['n_epochs_decay'] + 1)
59
+ return lr_l
60
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
61
+ elif config['training']['lr_policy'] == 'step':
62
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=config['training']['lr_decay_iters'], gamma=0.1)
63
+ elif config['training']['lr_policy'] == 'plateau':
64
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
65
+ elif config['training']['lr_policy'] == 'cosine':
66
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['training']['n_epochs'], eta_min=0)
67
+ else:
68
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % config['training']['lr_policy'])
69
+ return scheduler
70
+
71
+
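A hedged sketch of driving the 'linear' policy from a nested config dict; the epoch counts and learning rate are illustrative values only.

import torch
from models.modules.networks import get_scheduler  # assumed import path

config = {'training': {'lr_policy': 'linear', 'n_epochs': 100, 'n_epochs_decay': 100}}
optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)
scheduler = get_scheduler(optimizer, config)
for epoch in range(config['training']['n_epochs'] + config['training']['n_epochs_decay']):
    # ... train one epoch, calling optimizer.step() on each batch ...
    scheduler.step()  # LR stays at 2e-4 for 100 epochs, then decays linearly toward zero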
72
+ def init_weights(net, init_type='normal', init_gain=0.02):
73
+ """Initialize network weights.
74
+
75
+ Parameters:
76
+ net (network) -- network to be initialized
77
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
78
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
79
+
80
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
81
+ work better for some applications. Feel free to try yourself.
82
+ """
83
+ def init_func(m): # define the initialization function
84
+ classname = m.__class__.__name__
85
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
86
+ if init_type == 'normal':
87
+ init.normal_(m.weight.data, 0.0, init_gain)
88
+ elif init_type == 'xavier':
89
+ init.xavier_normal_(m.weight.data, gain=init_gain)
90
+ elif init_type == 'kaiming':
91
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
92
+ elif init_type == 'orthogonal':
93
+ init.orthogonal_(m.weight.data, gain=init_gain)
94
+ else:
95
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
96
+ if hasattr(m, 'bias') and m.bias is not None:
97
+ init.constant_(m.bias.data, 0.0)
98
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
99
+ if m.affine:
100
+ init.normal_(m.weight.data, 1.0, init_gain)
101
+ init.constant_(m.bias.data, 0.0)
102
+
103
+ if not init_type == 'none':
104
+ net.apply(init_func) # apply the initialization function <init_func>
105
+
106
+
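A short sketch of initializing a freshly built network; any of normal | xavier | kaiming | orthogonal can be passed, and 'none' skips the pass entirely.

import torch.nn as nn
from models.modules.networks import init_weights  # assumed import path

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(True))
init_weights(net, init_type='kaiming', init_gain=0.02)  # kaiming ignores init_gain for conv layers; BatchNorm still uses it as std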
107
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], DDP_device=None, find_unused_parameters=False):
108
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
109
+ Parameters:
110
+ net (network) -- the network to be initialized
111
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
112
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
113
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
114
+
115
+ Return an initialized network.
116
+ """
117
+ init_weights(net, init_type, init_gain=init_gain)
118
+ if DDP_device is not None:
119
+ net.to(DDP_device)
120
+ net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[DDP_device], output_device=DDP_device,
121
+ broadcast_buffers=False, find_unused_parameters=find_unused_parameters) # DDP multi-GPUs
122
+ if DDP_device == 0:
123
+ print("model initiated in DDP mode.")
124
+ elif gpu_ids is not None and len(gpu_ids) > 0:
125
+ assert(torch.cuda.is_available())
126
+ net.to(gpu_ids[0])
127
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
128
+ print("model initiated in dataparallel mode.")
129
+ return net
130
+
131
+
132
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02,
133
+ gpu_ids=[], DDP_device=None, find_unused_parameters=False):
134
+ """Create a generator
135
+
136
+ Parameters:
137
+ input_nc (int) -- the number of channels in input images
138
+ output_nc (int) -- the number of channels in output images
139
+ ngf (int) -- the number of filters in the last conv layer
140
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
141
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
142
+ use_dropout (bool) -- if use dropout layers.
143
+ init_type (str) -- the name of our initialization method.
144
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
145
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
146
+
147
+ Returns a generator
148
+
149
+ Our current implementation provides two types of generators:
150
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
151
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
152
+
153
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
154
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
155
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
156
+
157
+
158
+ The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
159
+ """
160
+ net = None
161
+ norm_layer = get_norm_layer(norm_type=norm)
162
+
163
+ if netG == 'resnet_9blocks':
164
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
165
+ elif netG == 'resnet_6blocks':
166
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
167
+ elif netG == 'unet_128':
168
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
169
+ elif netG == 'unet_256':
170
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
171
+ else:
172
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
173
+ return init_net(net, init_type, init_gain, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
174
+
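A minimal CPU sketch of building and running one of the generators listed above; with gpu_ids=[] and no DDP_device, init_net only applies the weight initialization.

import torch
from models.modules.networks import define_G  # assumed import path

netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                norm='instance', use_dropout=False, gpu_ids=[])
fake = netG(torch.randn(1, 3, 256, 256))  # -> (1, 3, 256, 256), in [-1, 1] via the final Tanh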
175
+ def define_F(netF, netF_nc=256, channels=[], use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], DDP_device=None, find_unused_parameters=False):
176
+ if netF == 'sample':
177
+ net = PatchSampleF(use_mlp=False, nc=netF_nc)
178
+ elif netF == 'mlp_sample':
179
+ net = PatchSampleF(use_mlp=True, nc=netF_nc)
180
+ else:
181
+ raise NotImplementedError('Projection model name [%s] is not recognized' % netF)
182
+ net.create_mlp(channels)
183
+ return init_net(net, init_type, init_gain, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
184
+
185
+ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02,
186
+ gpu_ids=[], DDP_device=None, find_unused_parameters=False):
187
+ """Create a discriminator
188
+
189
+ Parameters:
190
+ input_nc (int) -- the number of channels in input images
191
+ ndf (int) -- the number of filters in the first conv layer
192
+ netD (str) -- the architecture's name: basic | n_layers | pixel
193
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
194
+ norm (str) -- the type of normalization layers used in the network.
195
+ init_type (str) -- the name of the initialization method.
196
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
197
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
198
+
199
+ Returns a discriminator
200
+
201
+ Our current implementation provides three types of discriminators:
202
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
203
+ It can classify whether 70×70 overlapping patches are real or fake.
204
+ Such a patch-level discriminator architecture has fewer parameters
205
+ than a full-image discriminator and can work on arbitrarily-sized images
206
+ in a fully convolutional fashion.
207
+
208
+ [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
209
+ with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
210
+
211
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
212
+ It encourages greater color diversity but has no effect on spatial statistics.
213
+
214
+ The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
215
+ """
216
+ net = None
217
+ norm_layer = get_norm_layer(norm_type=norm)
218
+
219
+ if netD == 'basic': # default PatchGAN classifier
220
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
221
+ elif netD == 'n_layers': # more options
222
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
223
+ elif netD == 'pixel': # classify if each pixel is real or fake
224
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
225
+ else:
226
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
227
+ return init_net(net, init_type, init_gain, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
228
+
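A sketch of the conditional PatchGAN setup from pix2pix: the discriminator sees the condition and the image stacked on the channel dimension (6 input channels here is an assumption for an RGB pair).

import torch
from models.modules.networks import define_D  # assumed import path

netD = define_D(input_nc=6, ndf=64, netD='basic', norm='instance', gpu_ids=[])
cond, img = torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)
pred = netD(torch.cat([cond, img], dim=1))  # one logit per patch: shape (1, 1, 30, 30) for 256x256 inputs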
229
+ def define_G_pix2pixHD(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
230
+ n_blocks_local=3, norm='instance', gpu_ids=[], DDP_device=None, find_unused_parameters=False):
231
+ norm_layer = get_norm_layer(norm_type=norm)
232
+ if netG == 'global':
233
+ netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
234
+ elif netG == 'local':
235
+ netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
236
+ n_local_enhancers, n_blocks_local, norm_layer)
237
+ else:
238
+ raise NotImplementedError('generator not implemented!')
239
+ return init_net(netG, 'normal', 0.02, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
240
+
241
+ def define_D_pix2pixHD(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False,
242
+ gpu_ids=[], DDP_device=None, find_unused_parameters=False):
243
+ norm_layer = get_norm_layer(norm_type=norm)
244
+ netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
245
+ return init_net(netD, 'normal', 0.02, gpu_ids, DDP_device=DDP_device, find_unused_parameters=find_unused_parameters)
246
+
247
+
248
+ class Normalize(nn.Module):
249
+
250
+ def __init__(self, power=2):
251
+ super(Normalize, self).__init__()
252
+ self.power = power
253
+
254
+ def forward(self, x):
255
+ norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
256
+ out = x.div(norm + 1e-7)
257
+ return out
258
+
259
+
260
+ class PatchSampleF(nn.Module):
261
+ def __init__(self, use_mlp=False, nc=256):
262
+ # potential issues: currently, we use the same patch_ids for multiple images in the batch
263
+ super(PatchSampleF, self).__init__()
264
+ self.l2norm = Normalize(2)
265
+ self.use_mlp = use_mlp
266
+ self.nc = nc
267
+
268
+ def create_mlp(self, channels):
269
+ if not self.use_mlp:
270
+ return
271
+ for mlp_id, ch in enumerate(channels):
272
+ mlp = nn.Sequential(*[nn.Linear(ch, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
273
+ setattr(self, 'mlp_%d' % mlp_id, mlp)
274
+
275
+ def forward(self, feats, num_patches=64, patch_ids=None):
276
+ return_ids = []
277
+ return_feats = []
278
+ for feat_id, feat in enumerate(feats):
279
+ B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
280
+ feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
281
+ if num_patches > 0:
282
+ if patch_ids is not None:
283
+ patch_id = patch_ids[feat_id]
284
+ else:
285
+ patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
286
+ patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]
287
+ x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)
288
+ else:
289
+ x_sample = feat_reshape
290
+ patch_id = []
291
+ if self.use_mlp:
292
+ mlp = getattr(self, 'mlp_%d' % feat_id)
293
+ x_sample = mlp(x_sample)
294
+ return_ids.append(patch_id)
295
+ x_sample = self.l2norm(x_sample)
296
+
297
+ if num_patches == 0:
298
+ x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
299
+ return_feats.append(x_sample)
300
+ return return_feats, return_ids
301
+
302
+
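A sketch of the CUT-style patch sampling this class implements: the same random locations are reused for a second feature stack by passing back the returned ids (the channel counts below are assumptions).

import torch
from models.modules.networks import PatchSampleF  # assumed import path

netF = PatchSampleF(use_mlp=True, nc=256)
netF.create_mlp([128, 256])                                  # one MLP per feature level
feats_fake = [torch.randn(1, 128, 64, 64), torch.randn(1, 256, 32, 32)]
feats_real = [torch.randn(1, 128, 64, 64), torch.randn(1, 256, 32, 32)]
feat_q, ids = netF(feats_fake, num_patches=64)               # each entry: (64, 256), L2-normalized
feat_k, _ = netF(feats_real, num_patches=64, patch_ids=ids)  # same spatial locations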
303
+ class LocalEnhancer(nn.Module):
304
+ def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
305
+ n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
306
+ super(LocalEnhancer, self).__init__()
307
+ self.n_local_enhancers = n_local_enhancers
308
+
309
+ ###### global generator model #####
310
+ ngf_global = ngf * (2**n_local_enhancers)
311
+ model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
312
+ model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
313
+ self.model = nn.Sequential(*model_global)
314
+
315
+ ###### local enhancer layers #####
316
+ for n in range(1, n_local_enhancers+1):
317
+ ### downsample
318
+ ngf_global = ngf * (2**(n_local_enhancers-n))
319
+ model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
320
+ norm_layer(ngf_global), nn.ReLU(True),
321
+ nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
322
+ norm_layer(ngf_global * 2), nn.ReLU(True)]
323
+ ### residual blocks
324
+ model_upsample = []
325
+ for i in range(n_blocks_local):
326
+ model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer, use_dropout=False, use_bias=True)]
327
+
328
+ ### upsample
329
+ model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
330
+ norm_layer(ngf_global), nn.ReLU(True)]
331
+
332
+ ### final convolution
333
+ if n == n_local_enhancers:
334
+ model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
335
+
336
+ setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
337
+ setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
338
+
339
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
340
+
341
+ def forward(self, input):
342
+ ### create input pyramid
343
+ input_downsampled = [input]
344
+ for i in range(self.n_local_enhancers):
345
+ input_downsampled.append(self.downsample(input_downsampled[-1]))
346
+
347
+ ### output at coarsest level
348
+ output_prev = self.model(input_downsampled[-1])
349
+ ### build up one layer at a time
350
+ for n_local_enhancers in range(1, self.n_local_enhancers+1):
351
+ model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
352
+ model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
353
+ input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
354
+ output_prev = model_upsample(model_downsample(input_i) + output_prev)
355
+ return output_prev
356
+
357
+ class GlobalGenerator(nn.Module):
358
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
359
+ padding_type='reflect'):
360
+ assert(n_blocks >= 0)
361
+ super(GlobalGenerator, self).__init__()
362
+ activation = nn.ReLU(True)
363
+
364
+ model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
365
+ ### downsample
366
+ for i in range(n_downsampling):
367
+ mult = 2**i
368
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
369
+ norm_layer(ngf * mult * 2), activation]
370
+
371
+ ### resnet blocks
372
+ mult = 2**n_downsampling
373
+ for i in range(n_blocks):
374
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=False, use_bias=True)]
375
+
376
+ ### upsample
377
+ for i in range(n_downsampling):
378
+ mult = 2**(n_downsampling - i)
379
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
380
+ norm_layer(int(ngf * mult / 2)), activation]
381
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
382
+ self.model = nn.Sequential(*model)
383
+
384
+ def forward(self, input):
385
+ return self.model(input)
386
+
387
+ class MultiscaleDiscriminator(nn.Module):
388
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
389
+ use_sigmoid=False, num_D=3, getIntermFeat=False):
390
+ super(MultiscaleDiscriminator, self).__init__()
391
+ self.num_D = num_D
392
+ self.n_layers = n_layers
393
+ self.getIntermFeat = getIntermFeat
394
+
395
+ for i in range(num_D):
396
+ netD = Pix2PixHDNLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
397
+ if getIntermFeat:
398
+ for j in range(n_layers+2):
399
+ setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
400
+ else:
401
+ setattr(self, 'layer'+str(i), netD.model)
402
+
403
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
404
+
405
+ def singleD_forward(self, model, input):
406
+ if self.getIntermFeat:
407
+ result = [input]
408
+ for i in range(len(model)):
409
+ result.append(model[i](result[-1]))
410
+ return result[1:]
411
+ else:
412
+ return [model(input)]
413
+
414
+ def forward(self, input):
415
+ num_D = self.num_D
416
+ result = []
417
+ input_downsampled = input
418
+ for i in range(num_D):
419
+ if self.getIntermFeat:
420
+ model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
421
+ else:
422
+ model = getattr(self, 'layer'+str(num_D-1-i))
423
+ result.append(self.singleD_forward(model, input_downsampled))
424
+ if i != (num_D-1):
425
+ input_downsampled = self.downsample(input_downsampled)
426
+ return result
427
+
428
+ class Pix2PixHDNLayerDiscriminator(nn.Module):
429
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
430
+ super(Pix2PixHDNLayerDiscriminator, self).__init__()
431
+ self.getIntermFeat = getIntermFeat
432
+ self.n_layers = n_layers
433
+
434
+ kw = 4
435
+ padw = int(np.ceil((kw-1.0)/2))
436
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
437
+
438
+ nf = ndf
439
+ for n in range(1, n_layers):
440
+ nf_prev = nf
441
+ nf = min(nf * 2, 512)
442
+ sequence += [[
443
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
444
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
445
+ ]]
446
+
447
+ nf_prev = nf
448
+ nf = min(nf * 2, 512)
449
+ sequence += [[
450
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
451
+ norm_layer(nf),
452
+ nn.LeakyReLU(0.2, True)
453
+ ]]
454
+
455
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
456
+
457
+ if use_sigmoid:
458
+ sequence += [[nn.Sigmoid()]]
459
+
460
+ if getIntermFeat:
461
+ for n in range(len(sequence)):
462
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
463
+ else:
464
+ sequence_stream = []
465
+ for n in range(len(sequence)):
466
+ sequence_stream += sequence[n]
467
+ self.model = nn.Sequential(*sequence_stream)
468
+
469
+ def forward(self, input):
470
+ if self.getIntermFeat:
471
+ res = [input]
472
+ for n in range(self.n_layers+2):
473
+ model = getattr(self, 'model'+str(n))
474
+ res.append(model(res[-1]))
475
+ return res[1:]
476
+ else:
477
+ return self.model(input)
478
+
479
+
480
+ class MultiGANLoss(nn.Module):
481
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
482
+ tensor=torch.cuda.FloatTensor):
483
+ super(MultiGANLoss, self).__init__()
484
+ self.real_label = target_real_label
485
+ self.fake_label = target_fake_label
486
+ self.real_label_var = None
487
+ self.fake_label_var = None
488
+ self.Tensor = tensor
489
+ if use_lsgan:
490
+ self.loss = nn.MSELoss()
491
+ else:
492
+ self.loss = nn.BCELoss()
493
+
494
+ def get_target_tensor(self, input, target_is_real):
495
+ target_tensor = None
496
+ if target_is_real:
497
+ create_label = ((self.real_label_var is None) or
498
+ (self.real_label_var.numel() != input.numel()))
499
+ if create_label:
500
+ real_tensor = self.Tensor(input.size()).fill_(self.real_label)
501
+ self.real_label_var = Variable(real_tensor, requires_grad=False)
502
+ target_tensor = self.real_label_var
503
+ else:
504
+ create_label = ((self.fake_label_var is None) or
505
+ (self.fake_label_var.numel() != input.numel()))
506
+ if create_label:
507
+ fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
508
+ self.fake_label_var = Variable(fake_tensor, requires_grad=False)
509
+ target_tensor = self.fake_label_var
510
+ return target_tensor
511
+
512
+ def __call__(self, input, target_is_real):
513
+ if isinstance(input[0], list):
514
+ loss = 0
515
+ for input_i in input:
516
+ pred = input_i[-1]
517
+ target_tensor = self.get_target_tensor(pred, target_is_real)
518
+ loss += self.loss(pred, target_tensor)
519
+ return loss
520
+ else:
521
+ target_tensor = self.get_target_tensor(input[-1], target_is_real)
522
+ return self.loss(input[-1], target_tensor)
523
+
524
+ ##############################################################################
525
+ # Classes
526
+ ##############################################################################
527
+ class GANLoss(nn.Module):
528
+ """Define different GAN objectives.
529
+
530
+ The GANLoss class abstracts away the need to create the target label tensor
531
+ that has the same size as the input.
532
+ """
533
+
534
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
535
+ """ Initialize the GANLoss class.
536
+
537
+ Parameters:
538
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
539
+ target_real_label (bool) - - label for a real image
540
+ target_fake_label (bool) - - label of a fake image
541
+
542
+ Note: Do not use sigmoid as the last layer of Discriminator.
543
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
544
+ """
545
+ super(GANLoss, self).__init__()
546
+ self.register_buffer('real_label', torch.tensor(target_real_label))
547
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
548
+ self.gan_mode = gan_mode
549
+ if gan_mode == 'lsgan':
550
+ self.loss = nn.MSELoss()
551
+ elif gan_mode == 'vanilla':
552
+ self.loss = nn.BCEWithLogitsLoss()
553
+ elif gan_mode in ['wgangp']:
554
+ self.loss = None
555
+ else:
556
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
557
+
558
+ def get_target_tensor(self, prediction, target_is_real):
559
+ """Create label tensors with the same size as the input.
560
+
561
+ Parameters:
562
+ prediction (tensor) - - typically the prediction from a discriminator
563
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
564
+
565
+ Returns:
566
+ A label tensor filled with ground truth label, and with the size of the input
567
+ """
568
+
569
+ if target_is_real:
570
+ target_tensor = self.real_label
571
+ else:
572
+ target_tensor = self.fake_label
573
+ return target_tensor.expand_as(prediction)
574
+
575
+ def __call__(self, prediction, target_is_real):
576
+ """Calculate loss given Discriminator's output and grount truth labels.
577
+
578
+ Parameters:
579
+ prediction (tensor) - - typically the prediction output from a discriminator
580
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
581
+
582
+ Returns:
583
+ the calculated loss.
584
+ """
585
+ if self.gan_mode in ['lsgan', 'vanilla']:
586
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
587
+ loss = self.loss(prediction, target_tensor)
588
+ elif self.gan_mode == 'wgangp':
589
+ if target_is_real:
590
+ loss = nn.functional.softplus(-prediction).mean()
591
+ else:
592
+ loss = nn.functional.softplus(prediction).mean()
593
+ return loss
594
+
595
+
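A short usage sketch of the abstraction: one loss object serves both networks because the real/fake target is expanded to whatever shape the discriminator emits.

import torch
from models.modules.networks import GANLoss  # assumed import path

criterionGAN = GANLoss('lsgan')
pred_fake = torch.randn(1, 1, 30, 30)                  # e.g. a PatchGAN prediction map
loss_G = criterionGAN(pred_fake, True)                 # generator: fakes should look real
loss_D_fake = criterionGAN(pred_fake.detach(), False)  # discriminator: fakes are fake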
596
+ class PatchNCELoss(nn.Module):
597
+ def __init__(self, batch_size, nce_T):
598
+ super().__init__()
599
+ self.batch_size = batch_size
600
+ self.nce_T = nce_T
601
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
602
+ self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
603
+
604
+ def forward(self, feat_q, feat_k):
605
+ batchSize = feat_q.shape[0]
606
+ dim = feat_q.shape[1]
607
+ feat_k = feat_k.detach()
608
+
609
+ # pos logit
610
+ l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
611
+ l_pos = l_pos.view(batchSize, 1)
612
+
613
+ # neg logit
614
+ batch_dim_for_bmm = self.batch_size
615
+
616
+ # reshape features to batch size
617
+ feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
618
+ feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
619
+ npatches = feat_q.size(1)
620
+ l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
621
+
622
+ # diagonal entries are similarity between same features, and hence meaningless.
623
+ # just fill the diagonal with very small number, which is exp(-10) and almost zero
624
+ diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
625
+ l_neg_curbatch.masked_fill_(diagonal, -10.0)
626
+ l_neg = l_neg_curbatch.view(-1, npatches)
627
+
628
+ out = torch.cat((l_pos, l_neg), dim=1) / self.nce_T
629
+
630
+ loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
631
+ device=feat_q.device))
632
+
633
+ return loss
634
+
635
+
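A sketch of feeding this loss with the (num_patches x dim) features produced by PatchSampleF for one layer; batch_size is the image batch size, and since reduction='none' the caller averages the per-patch losses.

import torch
from models.modules.networks import PatchNCELoss  # assumed import path

criterionNCE = PatchNCELoss(batch_size=1, nce_T=0.07)
feat_q = torch.randn(64, 256)   # 64 patches sampled from the generated image
feat_k = torch.randn(64, 256)   # the same 64 locations sampled from the input image
loss = criterionNCE(feat_q, feat_k).mean()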
636
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
637
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
638
+
639
+ Arguments:
640
+ netD (network) -- discriminator network
641
+ real_data (tensor array) -- real images
642
+ fake_data (tensor array) -- generated images from the generator
643
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
644
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
645
+ constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
646
+ lambda_gp (float) -- weight for this loss
647
+
648
+ Returns the gradient penalty loss
649
+ """
650
+ if lambda_gp > 0.0:
651
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
652
+ interpolatesv = real_data
653
+ elif type == 'fake':
654
+ interpolatesv = fake_data
655
+ elif type == 'mixed':
656
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
657
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
658
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
659
+ else:
660
+ raise NotImplementedError('{} not implemented'.format(type))
661
+ interpolatesv.requires_grad_(True)
662
+ disc_interpolates = netD(interpolatesv)
663
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
664
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
665
+ create_graph=True, retain_graph=True, only_inputs=True)
666
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
667
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
668
+ return gradient_penalty, gradients
669
+ else:
670
+ return 0.0, None
671
+
672
+
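A self-contained sketch of the gradient-penalty term inside a discriminator update; the tiny discriminator below is only a stand-in for illustration.

import torch
import torch.nn as nn
from models.modules.networks import cal_gradient_penalty  # assumed import path

netD = nn.Sequential(nn.Conv2d(3, 8, 4, 2, 1), nn.LeakyReLU(0.2), nn.Conv2d(8, 1, 4, 1, 1))
device = torch.device('cpu')
real, fake = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
gp, _ = cal_gradient_penalty(netD, real, fake.detach(), device,
                             type='mixed', constant=1.0, lambda_gp=10.0)
gp.backward()  # adds lambda_gp * (||grad||_2 - constant)^2 to the discriminator step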
673
+ class ResnetGenerator(nn.Module):
674
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
675
+
676
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
677
+ """
678
+
679
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
680
+ """Construct a Resnet-based generator
681
+
682
+ Parameters:
683
+ input_nc (int) -- the number of channels in input images
684
+ output_nc (int) -- the number of channels in output images
685
+ ngf (int) -- the number of filters in the last conv layer
686
+ norm_layer -- normalization layer
687
+ use_dropout (bool) -- if use dropout layers
688
+ n_blocks (int) -- the number of ResNet blocks
689
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
690
+ """
691
+ assert(n_blocks >= 0)
692
+ super(ResnetGenerator, self).__init__()
693
+ if type(norm_layer) == functools.partial:
694
+ use_bias = norm_layer.func == nn.InstanceNorm2d
695
+ else:
696
+ use_bias = norm_layer == nn.InstanceNorm2d
697
+
698
+ model = [nn.ReflectionPad2d(3),
699
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
700
+ norm_layer(ngf),
701
+ nn.ReLU(True)]
702
+
703
+ n_downsampling = 2
704
+ for i in range(n_downsampling): # add downsampling layers
705
+ mult = 2 ** i
706
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
707
+ norm_layer(ngf * mult * 2),
708
+ nn.ReLU(True)]
709
+
710
+ mult = 2 ** n_downsampling
711
+ for i in range(n_blocks): # add ResNet blocks
712
+
713
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
714
+
715
+ for i in range(n_downsampling): # add upsampling layers
716
+ mult = 2 ** (n_downsampling - i)
717
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
718
+ kernel_size=3, stride=2,
719
+ padding=1, output_padding=1,
720
+ bias=use_bias),
721
+ norm_layer(int(ngf * mult / 2)),
722
+ nn.ReLU(True)]
723
+ model += [nn.ReflectionPad2d(3)]
724
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
725
+ model += [nn.Tanh()]
726
+
727
+ self.model = nn.Sequential(*model)
728
+
729
+ def forward(self, input, layers=[]):
730
+ if len(layers) > 0:
731
+ feat = input
732
+ feats = []
733
+ for layer_id, layer in enumerate(self.model):
734
+ feat = layer(feat)
735
+ if layer_id in layers:
736
+ feats.append(feat)
737
+ if layer_id == layers[-1]:
738
+ break
739
+ return feats
740
+ else:
741
+ """Standard forward"""
742
+ return self.model(input)
743
+
744
+
745
+ class ResnetBlock(nn.Module):
746
+ """Define a Resnet block"""
747
+
748
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
749
+ """Initialize the Resnet block
750
+
751
+ A resnet block is a conv block with skip connections
752
+ We construct a conv block with build_conv_block function,
753
+ and implement skip connections in <forward> function.
754
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
755
+ """
756
+ super(ResnetBlock, self).__init__()
757
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
758
+
759
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
760
+ """Construct a convolutional block.
761
+
762
+ Parameters:
763
+ dim (int) -- the number of channels in the conv layer.
764
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
765
+ norm_layer -- normalization layer
766
+ use_dropout (bool) -- if use dropout layers.
767
+ use_bias (bool) -- if the conv layer uses bias or not
768
+
769
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
770
+ """
771
+ conv_block = []
772
+ p = 0
773
+ if padding_type == 'reflect':
774
+ conv_block += [nn.ReflectionPad2d(1)]
775
+ elif padding_type == 'replicate':
776
+ conv_block += [nn.ReplicationPad2d(1)]
777
+ elif padding_type == 'zero':
778
+ p = 1
779
+ else:
780
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
781
+
782
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
783
+ if use_dropout:
784
+ conv_block += [nn.Dropout(0.5)]
785
+
786
+ p = 0
787
+ if padding_type == 'reflect':
788
+ conv_block += [nn.ReflectionPad2d(1)]
789
+ elif padding_type == 'replicate':
790
+ conv_block += [nn.ReplicationPad2d(1)]
791
+ elif padding_type == 'zero':
792
+ p = 1
793
+ else:
794
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
795
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
796
+
797
+ return nn.Sequential(*conv_block)
798
+
799
+ def forward(self, x):
800
+ """Forward function (with skip connections)"""
801
+ out = x + self.conv_block(x) # add skip connections
802
+ return out
803
+
804
+
805
+ class UnetGenerator(nn.Module):
806
+ """Create a Unet-based generator"""
807
+
808
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
809
+ """Construct a Unet generator
810
+ Parameters:
811
+ input_nc (int) -- the number of channels in input images
812
+ output_nc (int) -- the number of channels in output images
813
+ num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
814
+ an image of size 128x128 will become of size 1x1 at the bottleneck
815
+ ngf (int) -- the number of filters in the last conv layer
816
+ norm_layer -- normalization layer
817
+
818
+ We construct the U-Net from the innermost layer to the outermost layer.
819
+ It is a recursive process.
820
+ """
821
+ super(UnetGenerator, self).__init__()
822
+ # construct unet structure
823
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
824
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
825
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
826
+ # gradually reduce the number of filters from ngf * 8 to ngf
827
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
828
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
829
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
830
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
831
+
832
+ def forward(self, input):
833
+ """Standard forward"""
834
+ return self.model(input)
835
+
836
+
837
+ class UnetSkipConnectionBlock(nn.Module):
838
+ """Defines the Unet submodule with skip connection.
839
+ X -------------------identity----------------------
840
+ |-- downsampling -- |submodule| -- upsampling --|
841
+ """
842
+
843
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
844
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
845
+ """Construct a Unet submodule with skip connections.
846
+
847
+ Parameters:
848
+ outer_nc (int) -- the number of filters in the outer conv layer
849
+ inner_nc (int) -- the number of filters in the inner conv layer
850
+ input_nc (int) -- the number of channels in input images/features
851
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
852
+ outermost (bool) -- if this module is the outermost module
853
+ innermost (bool) -- if this module is the innermost module
854
+ norm_layer -- normalization layer
855
+ use_dropout (bool) -- if use dropout layers.
856
+ """
857
+ super(UnetSkipConnectionBlock, self).__init__()
858
+ self.outermost = outermost
859
+ if type(norm_layer) == functools.partial:
860
+ use_bias = norm_layer.func == nn.InstanceNorm2d
861
+ else:
862
+ use_bias = norm_layer == nn.InstanceNorm2d
863
+ if input_nc is None:
864
+ input_nc = outer_nc
865
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
866
+ stride=2, padding=1, bias=use_bias)
867
+ downrelu = nn.LeakyReLU(0.2, True)
868
+ downnorm = norm_layer(inner_nc)
869
+ uprelu = nn.ReLU(True)
870
+ upnorm = norm_layer(outer_nc)
871
+
872
+ if outermost:
873
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
874
+ kernel_size=4, stride=2,
875
+ padding=1)
876
+ down = [downconv]
877
+ up = [uprelu, upconv, nn.Tanh()]
878
+ model = down + [submodule] + up
879
+ elif innermost:
880
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
881
+ kernel_size=4, stride=2,
882
+ padding=1, bias=use_bias)
883
+ down = [downrelu, downconv]
884
+ up = [uprelu, upconv, upnorm]
885
+ model = down + up
886
+ else:
887
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
888
+ kernel_size=4, stride=2,
889
+ padding=1, bias=use_bias)
890
+ down = [downrelu, downconv, downnorm]
891
+ up = [uprelu, upconv, upnorm]
892
+
893
+ if use_dropout:
894
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
895
+ else:
896
+ model = down + [submodule] + up
897
+
898
+ self.model = nn.Sequential(*model)
899
+
900
+ def forward(self, x):
901
+ if self.outermost:
902
+ return self.model(x)
903
+ else: # add skip connections
904
+ return torch.cat([x, self.model(x)], 1)
905
+
906
+
907
+ # Creates SPADE normalization layer based on the given configuration
908
+ # SPADE consists of two steps. First, it normalizes the activations using
909
+ # your favorite normalization method, such as Batch Norm or Instance Norm.
910
+ # Second, it applies scale and bias to the normalized output, conditioned on
911
+ # the segmentation map.
912
+ # The format of |config_text| is spade(norm)(ks), where
913
+ # (norm) specifies the type of parameter-free normalization.
914
+ # (e.g. syncbatch, batch, instance)
915
+ # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
916
+ # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
917
+ # Also, the other arguments are
918
+ # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
919
+ # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
920
+ class SPADE(nn.Module):
921
+ def __init__(self, config_text, norm_nc, label_nc):
922
+ super().__init__()
923
+
924
+ assert config_text.startswith('spade')
925
+ parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
926
+ param_free_norm_type = str(parsed.group(1))
927
+ ks = int(parsed.group(2))
928
+
929
+ if param_free_norm_type == 'instance':
930
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
931
+ elif param_free_norm_type == 'batch':
932
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
933
+ elif param_free_norm_type == 'identity':
934
+ self.param_free_norm = nn.Identity()
935
+ else:
936
+ raise ValueError('%s is not a recognized param-free norm type in SPADE'
937
+ % param_free_norm_type)
938
+
939
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
940
+ nhidden = 128
941
+
942
+ pw = ks // 2
943
+ self.mlp_shared = nn.Sequential(
944
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
945
+ nn.ReLU()
946
+ )
947
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
948
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
949
+
950
+ def forward(self, x, segmap):
951
+
952
+ # Part 1. generate parameter-free normalized activations
953
+ normalized = self.param_free_norm(x)
954
+
955
+ # Part 2. produce scaling and bias conditioned on semantic map
956
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
957
+ actv = self.mlp_shared(segmap)
958
+ gamma = self.mlp_gamma(actv)
959
+ beta = self.mlp_beta(actv)
960
+
961
+ # apply scale and bias
962
+ out = normalized * (1 + gamma) + beta
963
+
964
+ return out
965
+
966
+
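A sketch of the config string and shapes: 'spadeinstance3x3' selects InstanceNorm as the parameter-free step and a 3x3 kernel for the gamma/beta convolutions; the label map is resized to the activation size inside forward().

import torch
from models.modules.networks import SPADE  # assumed import path

spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=10)
x = torch.randn(1, 64, 32, 32)          # activations to modulate
segmap = torch.randn(1, 10, 256, 256)   # label map at image resolution
out = spade(x, segmap)                  # -> (1, 64, 32, 32)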
967
+ # ResNet block that uses SPADE.
968
+ # It differs from the ResNet block of pix2pixHD in that
969
+ # it takes in the segmentation map as input, learns the skip connection if necessary,
970
+ # and applies normalization first and then convolution.
971
+ # This architecture seemed like a standard architecture for unconditional or
972
+ # class-conditional GAN architecture using residual block.
973
+ # The code was inspired from https://github.com/LMescheder/GAN_stability.
974
+ class SPADEResnetBlock(nn.Module):
975
+ def __init__(self, fin, fout, config_str, semantic_nc):
976
+ super().__init__()
977
+ # Attributes
978
+ self.learned_shortcut = (fin != fout)
979
+ fmiddle = min(fin, fout)
980
+
981
+ # create conv layers
982
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
983
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
984
+ if self.learned_shortcut:
985
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
986
+
987
+ # apply spectral norm if specified
988
+ if 'spectral' in config_str:
989
+ self.conv_0 = spectral_norm(self.conv_0)
990
+ self.conv_1 = spectral_norm(self.conv_1)
991
+ if self.learned_shortcut:
992
+ self.conv_s = spectral_norm(self.conv_s)
993
+
994
+ # define normalization layers
995
+ spade_config_str = config_str.replace('spectral', '')
996
+ self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
997
+ self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
998
+ if self.learned_shortcut:
999
+ self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
1000
+
1001
+ # note the resnet block with SPADE also takes in |seg|,
1002
+ # the semantic segmentation map as input
1003
+ def forward(self, x, seg):
1004
+ x_s = self.shortcut(x, seg)
1005
+
1006
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
1007
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
1008
+
1009
+ out = x_s + dx
1010
+
1011
+ return out
1012
+
1013
+ def shortcut(self, x, seg):
1014
+ if self.learned_shortcut:
1015
+ x_s = self.conv_s(self.norm_s(x, seg))
1016
+ else:
1017
+ x_s = x
1018
+ return x_s
1019
+
1020
+ def actvn(self, x):
1021
+ return F.leaky_relu(x, 2e-1)
1022
+
1023
+
1024
+ class NLayerDiscriminator(nn.Module):
1025
+ """Defines a PatchGAN discriminator"""
1026
+
1027
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
1028
+ """Construct a PatchGAN discriminator
1029
+
1030
+ Parameters:
1031
+ input_nc (int) -- the number of channels in input images
1032
+ ndf (int) -- the number of filters in the last conv layer
1033
+ n_layers (int) -- the number of conv layers in the discriminator
1034
+ norm_layer -- normalization layer
1035
+ """
1036
+ super(NLayerDiscriminator, self).__init__()
1037
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1038
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1039
+ else:
1040
+ use_bias = norm_layer == nn.InstanceNorm2d
1041
+
1042
+ kw = 4
1043
+ padw = 1
1044
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
1045
+ nf_mult = 1
1046
+ nf_mult_prev = 1
1047
+ for n in range(1, n_layers): # gradually increase the number of filters
1048
+ nf_mult_prev = nf_mult
1049
+ nf_mult = min(2 ** n, 8)
1050
+ sequence += [
1051
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1052
+ norm_layer(ndf * nf_mult),
1053
+ nn.LeakyReLU(0.2, True)
1054
+ ]
1055
+
1056
+ nf_mult_prev = nf_mult
1057
+ nf_mult = min(2 ** n_layers, 8)
1058
+ sequence += [
1059
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1060
+ norm_layer(ndf * nf_mult),
1061
+ nn.LeakyReLU(0.2, True)
1062
+ ]
1063
+
1064
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
1065
+ self.model = nn.Sequential(*sequence)
1066
+
1067
+ def forward(self, input):
1068
+ """Standard forward."""
1069
+ return self.model(input)
1070
+
1071
+
1072
+ class PixelDiscriminator(nn.Module):
1073
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
1074
+
1075
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
1076
+ """Construct a 1x1 PatchGAN discriminator
1077
+
1078
+ Parameters:
1079
+ input_nc (int) -- the number of channels in input images
1080
+ ndf (int) -- the number of filters in the last conv layer
1081
+ norm_layer -- normalization layer
1082
+ """
1083
+ super(PixelDiscriminator, self).__init__()
1084
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1085
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1086
+ else:
1087
+ use_bias = norm_layer == nn.InstanceNorm2d
1088
+
1089
+ self.net = [
1090
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
1091
+ nn.LeakyReLU(0.2, True),
1092
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
1093
+ norm_layer(ndf * 2),
1094
+ nn.LeakyReLU(0.2, True),
1095
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
1096
+
1097
+ self.net = nn.Sequential(*self.net)
1098
+
1099
+ def forward(self, input):
1100
+ """Standard forward."""
1101
+ return self.net(input)
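Not part of the file above, but a compact sketch of how these factories are meant to compose into one pix2pix-style adversarial step; every name and weight below is illustrative, not the repository's actual training loop.

import torch
from models.modules.networks import define_G, define_D, GANLoss  # assumed import path

netG = define_G(3, 3, 64, 'resnet_9blocks', norm='instance', gpu_ids=[])
netD = define_D(6, 64, 'basic', norm='instance', gpu_ids=[])
criterionGAN, criterionL1 = GANLoss('lsgan'), torch.nn.L1Loss()
opt_G = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

real_A, real_B = torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256)
fake_B = netG(real_A)

# discriminator step: fakes are detached so only netD receives gradients
opt_D.zero_grad()
loss_D = 0.5 * (criterionGAN(netD(torch.cat([real_A, fake_B.detach()], 1)), False)
                + criterionGAN(netD(torch.cat([real_A, real_B], 1)), True))
loss_D.backward(); opt_D.step()

# generator step: adversarial term plus an L1 reconstruction term (weight 100 is illustrative)
opt_G.zero_grad()
loss_G = criterionGAN(netD(torch.cat([real_A, fake_B], 1)), True) + 100.0 * criterionL1(fake_B, real_B)
loss_G.backward(); opt_G.step()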
models/modules/sr/light_model_270M.py ADDED
@@ -0,0 +1,347 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from models.modules.pix2pixMini_module import *
4
+
5
+ pack = Pack()
6
+ unpack = UnPack()
7
+ AVG = AvgQuant()
8
+ mul = SliceMul()
9
+
10
+ def get_int(channel_dict):
11
+ for i in range(len(channel_dict['down'])):
12
+ channel_dict['down'][i] = int(channel_dict['down'][i])
13
+ for i in range(len(channel_dict['backbone'])):
14
+ channel_dict['backbone'][i] = int(channel_dict['backbone'][i])
15
+ for i in range(len(channel_dict['up'])):
16
+ for j in range(len(channel_dict['up'][i])):
17
+ if channel_dict['up'][i][j] is not None:
18
+ channel_dict['up'][i][j] = int(channel_dict['up'][i][j])
19
+ return channel_dict
20
+
21
+ def get_channel_dict(dict_name, ngf):
22
+ if dict_name is None:
23
+ raise ValueError('invalid channel_dict name')
24
+
25
+ if dict_name == '3G':
26
+ channel_dict = {
27
+ 'n_blocks': 8,
28
+ 'down': [ngf * 1, ngf * 2, ngf * 4, ngf * 8, ngf * 8],
29
+ 'backbone': [ngf * 8, ngf * 8],
30
+ 'hair_up': [
31
+ [None, None, None, ngf * 4, 2],
32
+ [None, None, None, ngf * 2, 2],
33
+ [None, None, None, ngf * 1, 2],
34
+ [None, None, None, ngf * 1, 1],
35
+ ],
36
+ 'face_up': [
37
+ [None, None, None, ngf * 4, 2],
38
+ [None, None, None, ngf * 2, 2],
39
+ [None, None, None, ngf * 2, 2],
40
+ [None, None, None, ngf * 1, 1],
41
+ ],
42
+ 'up': [
43
+ [None, None, None, ngf * 1, 1],
44
+ ],
45
+ }
46
+ elif dict_name == '270M':
47
+ channel_dict = {
48
+ 'n_blocks': 4,
49
+ 'down': [ngf * 2, ngf * 2, ngf * 3, ngf * 6, ngf * 4],
50
+ 'backbone': [ngf * 4, ngf * 4],
51
+ 'hair_up': [
52
+ [ngf * 4, None, None, ngf * 2, 1],
53
+ [ngf * 2, None, None, ngf * 2, 1],
54
+ [ngf * 2, None, ngf * 1, ngf * 1, 1],
55
+ [None, None, None, ngf * 1, 1],
56
+ ],
57
+ 'face_up': [
58
+ [ngf * 4, ngf * 6, ngf * 3, ngf * 2, 1],
59
+ [ngf * 2, ngf * 3, ngf * 2, ngf * 2, 1],
60
+ [ngf * 2, ngf * 2, ngf * 1, ngf * 1, 1],
61
+ [None, None, None, ngf * 1, 1],
62
+ ],
63
+ 'up': [
64
+ [ngf * 2, ngf * 1, ngf * 1, ngf * 2, 1],
65
+ ],
66
+ }
67
+ else:
68
+ raise ValueError('invalid_dict_name')
69
+ return get_int(channel_dict)
70
+
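A hedged sketch of what the '270M' plan expands to for ngf=8; note that importing this module also requires models.modules.pix2pixMini_module to be importable, since it is used at module level.

from models.modules.sr.light_model_270M import get_channel_dict  # assumed import path

plan = get_channel_dict('270M', ngf=8)
print(plan['n_blocks'])   # 4
print(plan['down'])       # [16, 16, 24, 48, 32]
print(plan['backbone'])   # [32, 32]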
71
+ chans = 8
72
+ in_channels = 12
73
+ out_channels = 16
74
+ face_branch_out_channels = 12
75
+ hair_branch_out_channels = 12
76
+
77
+ class hair_face_model_old(nn.Module):
78
+ def __init__(self, ngf=chans, backbone_type='resnet', use_se=True, channel_dict_name = None, with_hair_branch=False, design=5):
79
+ super().__init__()
80
+ self.design = design
81
+ self.with_hair_branch = with_hair_branch
82
+ channel_dict = get_channel_dict(channel_dict_name, ngf)
83
+ n_blocks = channel_dict['n_blocks']
84
+
85
+ self.inconv = ConvBlock(in_channels, channel_dict['down'][0], stride=1)
86
+
87
+ # Down-Sampling
88
+ self.DownBlock1 = ConvBlock(channel_dict['down'][0], channel_dict['down'][1], stride=2)
89
+ self.DownBlock2 = ConvBlock(channel_dict['down'][1], channel_dict['down'][2], stride=2)
90
+ self.DownBlock3 = ConvBlock(channel_dict['down'][2], channel_dict['down'][3], stride=2)
91
+ self.DownBlock4 = ConvBlock(channel_dict['down'][3], channel_dict['down'][4], stride=2)
92
+
93
+ # Down-Sampling Bottleneck
94
+ if backbone_type == 'resnet':
95
+ backbone_block = ResnetBlock
96
+ elif backbone_type == 'mobilenet':
97
+ backbone_block = InvertedBottleneck
98
+ n_blocks = n_blocks
99
+ else:
100
+ raise ValueError('invalid backbone type')
101
+ ResBlock = []
102
+ ResBlock += [backbone_block(channel_dict['down'][4], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
103
+ for i in range(1, n_blocks - 1):
104
+ ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
105
+ ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][1], use_bias=False, use_se=use_se)]
106
+ self.ResBlock = nn.Sequential(*ResBlock)
107
+
108
+ self.HairUpBlock4 = UpBlock(channel_dict['backbone'][1], None, channel_dict['hair_up'][0][0], None, channel_dict['hair_up'][0][2], channel_dict['hair_up'][0][3], num_conv=channel_dict['hair_up'][0][4])
109
+ self.HairUpBlock3 = UpBlock(channel_dict['hair_up'][0][3], None, channel_dict['hair_up'][1][0], None, channel_dict['hair_up'][1][2], channel_dict['hair_up'][1][3], num_conv=channel_dict['hair_up'][1][4])
110
+ self.HairUpBlock2 = UpBlock(channel_dict['hair_up'][1][3], None, channel_dict['hair_up'][2][0], None, channel_dict['hair_up'][2][2], channel_dict['hair_up'][2][3], num_conv=channel_dict['hair_up'][2][4])
111
+
112
+ self.FaceUpBlock4 = UpBlock(channel_dict['backbone'][1], channel_dict['down'][3], channel_dict['face_up'][0][0], channel_dict['face_up'][0][1], channel_dict['face_up'][0][2], channel_dict['face_up'][0][3], num_conv=channel_dict['face_up'][0][4])
113
+ self.FaceUpBlock3 = UpBlock(channel_dict['face_up'][0][3], channel_dict['down'][2], channel_dict['face_up'][1][0], channel_dict['face_up'][1][1], channel_dict['face_up'][1][2], channel_dict['face_up'][1][3], num_conv=channel_dict['face_up'][1][4])
114
+ self.FaceUpBlock2 = UpBlock(channel_dict['face_up'][1][3], channel_dict['down'][1], channel_dict['face_up'][2][0], channel_dict['face_up'][2][1], channel_dict['face_up'][2][2], channel_dict['face_up'][2][3], num_conv=channel_dict['face_up'][2][4])
115
+
116
+ self.UpBlock1 = UpBlock(channel_dict['hair_up'][2][3] + channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['up'][0][0], channel_dict['up'][0][1], channel_dict['up'][0][2], channel_dict['up'][0][3], num_conv=channel_dict['up'][0][4])
117
+ self.outconv = ConvOutBlock(channel_dict['up'][0][3], out_channels )
118
+
119
+ #self.shortcut_ratio = [1,1,1,1]
120
+
121
+ if self.with_hair_branch:
122
+ self.HairUpBlock1 = UpBlock(channel_dict['hair_up'][2][3], None, channel_dict['hair_up'][3][0], None, channel_dict['hair_up'][3][2],channel_dict['hair_up'][3][3])
123
+ self.Hairoutconv = ConvOutBlock(channel_dict['hair_up'][3][3], hair_branch_out_channels)
124
+
125
+ self.FaceUpBlock1 = UpBlock(channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['face_up'][3][0], channel_dict['face_up'][3][1], channel_dict['face_up'][3][2], channel_dict['face_up'][3][3])
126
+ self.Faceoutconv = ConvOutBlock(channel_dict['face_up'][3][3], face_branch_out_channels)
127
+
128
+ self.up = UpsampleQuant(scale_factor=1.5, mode='bilinear')
129
+
130
+
131
+ def forward(self, x , with_hair=False):
132
+
133
+ design = self.design
134
+
135
+ x0 = self.inconv(x)
136
+ x1 = self.DownBlock1(x0)
137
+ x2 = self.DownBlock2(x1)
138
+ x3 = self.DownBlock3(x2)
139
+ x4 = self.DownBlock4(x3)
140
+ x5 = self.ResBlock(x4)
141
+
142
+ x6 = x5
143
+ hair = self.HairUpBlock4(x6)
144
+ hair = self.HairUpBlock3(hair)
145
+ hair = self.HairUpBlock2( hair)
146
+
147
+ face = self.FaceUpBlock4(x6, self.up(x3) if design == 0 else x3)
148
+ face = self.FaceUpBlock3(self.up(face) if design==1 else face,
149
+ self.up(x2) if design <= 1 else x2)
150
+ face = self.FaceUpBlock2(self.up(face) if design == 2 else face,
151
+ self.up(x1) if design <= 2 else x1)
152
+
153
+ hf_cat = torch.cat([hair,face], dim=1)
154
+
155
+ x7 = self.UpBlock1(hf_cat, x0)
156
+
157
+ x7 = self.outconv( x7)
158
+ if not with_hair or not self.with_hair_branch:
159
+ return x7
160
+ else:
161
+ hair = self.HairUpBlock1(hair)
162
+ hair = self.Hairoutconv(hair)
163
+ face = self.FaceUpBlock1(face, x0)
164
+ face = self.Faceoutconv(face)
165
+ # print(design == 5) #true
166
+ return x7 if design == 5 else x7, hair, face
167
+
168
+ class hair_face_model(hair_face_model_old):
169
+ def __init__(self, **kwargs):
170
+ super(hair_face_model, self).__init__(**kwargs)
171
+
172
+ self.upconv = nn.Sequential(
173
+ nn.Upsample(scale_factor=2, mode='bilinear'),
174
+ nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1, bias=True),
175
+ nn.Tanh()
176
+ )
177
+
178
+ def forward(self, x):
179
+ x=pack(x)
180
+ x = super().forward(x)
181
+ x=unpack(x)
182
+ return self.upconv(x)
183
+
184
+ class ConvBlock(nn.Module):
185
+ def __init__(self, in_ch, out_ch, stride):
186
+ super(ConvBlock, self).__init__()
187
+ self.conv = nn.Sequential(
188
+ #nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
189
+ Conv2dQuant(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=True),
190
+ nn.BatchNorm2d(out_ch),
191
+ HardQuant(0, 4)
192
+ #nn.ReLU(False))
193
+ )
194
+ def forward(self, x):
195
+ x = self.conv(x)
196
+ return x
197
+
198
+
199
+ class ConvOutBlock(nn.Module):
200
+ def __init__(self, in_ch, out_ch):
201
+ super(ConvOutBlock, self).__init__()
202
+ self.conv = nn.Sequential(
203
+ #nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
204
+ Conv2dQuant(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
205
+ #nn.Tanh()
206
+ TanhOp(data_in_type='float', data_out_type='fixed')
207
+ )
208
+ def forward(self, x):
209
+ x = self.conv(x)
210
+ return x
211
+
212
+
213
+ class UpBlock(nn.Module):
214
+ def __init__(self, in_ch1, in_ch2, mid_ch1, mid_ch2, mid_ch, out_ch, num_conv=1, use_bn=True):
215
+ super(UpBlock, self).__init__()
216
+
217
+ #self.up = nn.Upsample(scale_factor=2, mode='nearest')
218
+ self.up = UpsampleQuant(scale_factor=2, mode='nearest')
219
+
220
+ ## branch_1
221
+ if mid_ch1 is None or in_ch1 == mid_ch1:
222
+ self.conv1 = None
223
+ else:
224
+ self.conv1 = nn.Sequential(
225
+ #nn.Conv2d(in_ch1, mid_ch1, 1, bias=False),
226
+ Conv2dQuant(in_ch1, mid_ch1, 1, bias=True),
227
+ #nn.ReLU(False),
228
+ HardQuant(0, 4)
229
+ )
230
+ if mid_ch1 is None:
231
+ mid_ch1 = in_ch1
232
+
233
+ if in_ch2 is None:
234
+ self.use_shortcut = False
235
+ self.conv2 = None
236
+ else:
237
+ self.use_shortcut = True
238
+ if mid_ch2 is None or in_ch2 == mid_ch2:
239
+ self.conv2 = None
240
+ else:
241
+ self.conv2 = nn.Sequential(
242
+ #nn.Conv2d(in_ch2, mid_ch2, 1, bias=False),
243
+ Conv2dQuant(in_ch2, mid_ch2, 1, bias=True),
244
+ #nn.ReLU(False),
245
+ HardQuant(0, 4)
246
+ )
247
+ if mid_ch2 is None:
248
+ mid_ch2 = in_ch2
249
+ #print(self.conv1 is None, self.conv2 is None)
250
+ combine_ch = mid_ch1
251
+ if self.use_shortcut:
252
+ combine_ch = combine_ch + mid_ch2
253
+ if mid_ch is None or combine_ch == mid_ch:
254
+ self.conv_combine = None
255
+ mid_ch = combine_ch
256
+ else:
257
+ self.conv_combine = nn.Sequential(
258
+ #nn.Conv2d(combine_ch, mid_ch, 1, bias=False),
259
+ Conv2dQuant(combine_ch, mid_ch, 1, bias=True),
260
+ #nn.ReLU(False),
261
+ HardQuant(0, 4)
262
+ )
263
+
264
+ conv_list = []
265
+ #conv_list.append(nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False))
266
+ conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True))
267
+ if use_bn:
268
+ conv_list.append(nn.BatchNorm2d(out_ch))
269
+ #conv_list.append(nn.ReLU(False))
270
+ conv_list.append(HardQuant(0, 4))
271
+ for n in range(1, num_conv):
272
+ conv_list.append(ResnetBlock(out_ch, out_ch, use_bias=False, use_se = False, use_bn=use_bn))
273
+ self.conv = nn.Sequential(*conv_list)
274
+
275
+ def forward(self, x1, x2=None, ratio=None):
276
+
277
+ if self.conv1 is not None:
278
+ x1 = self.conv1(x1)
279
+ x1 = self.up(x1)
280
+
281
+ if self.use_shortcut:
282
+ if self.conv2 is not None:
283
+ x2 = self.conv2(x2)
284
+
285
+ if self.use_shortcut:
286
+ if ratio is None:
287
+ x = torch.cat([x1, x2], dim=1)
288
+ else:
289
+ x = torch.cat([x1, x2 * ratio], dim=1)
290
+ else:
291
+ x = x1
292
+
293
+ if self.conv_combine is not None:
294
+ x = self.conv_combine(x)
295
+
296
+ x = self.conv(x)
297
+ return x
298
+
299
+
300
+ class ResnetBlock(nn.Module):
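+ # Quantization-friendly residual block: two Conv2dQuant 3x3 + BatchNorm + HardQuant(0, 4)
+ # stages (optionally followed by squeeze-and-excitation), with a 1x1 projection when the
+ # channel count changes; the residual sum is taken with AVG (AvgQuant) instead of a plain add.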
301
+ def __init__(self, dim, dim_out, use_bias, use_se = False, use_bn=True):
302
+ super(ResnetBlock, self).__init__()
303
+ conv_block = []
304
+ #conv_block += [nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias),]
305
+ conv_block += [Conv2dQuant(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),]
306
+ if use_bn:
307
+ conv_block += [nn.BatchNorm2d(dim),]
308
+ #conv_block += [nn.ReLU(False)]
309
+ conv_block += [HardQuant(0, 4)]
310
+
311
+ #conv_block.append(nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=use_bias))
312
+ conv_block.append(Conv2dQuant(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=True))
313
+
314
+ if use_bn:
315
+ conv_block.append(nn.BatchNorm2d(dim_out))
316
+ conv_block += [HardQuant(0, 4)]
317
+
318
+ if use_se:
319
+ conv_block.append(SqEx(dim_out, 4))
320
+
321
+ self.conv_block = nn.Sequential(*conv_block)
322
+
323
+ self.downsample = None
324
+ if dim != dim_out:
325
+ if use_bn:
326
+ self.downsample = nn.Sequential(
327
+ #nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
328
+ #nn.BatchNorm2d(dim_out),
329
+ Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
330
+ nn.BatchNorm2d(dim_out),
331
+ )
332
+ else:
333
+ self.downsample = nn.Sequential(
334
+ #nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
335
+ Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
336
+ )
337
+
338
+ #self.relu = nn.ReLU(False)
339
+ self.relu = HardQuant(0, 4)
340
+
341
+ def forward(self, x):
342
+ if self.downsample is None:
343
+ y = AVG(x, self.conv_block(x))
344
+ else:
345
+ y = AVG(self.downsample(x), self.conv_block(x))
346
+ #y = self.relu(y)
347
+ return y
models/modules/sr/light_model_470M.py ADDED
@@ -0,0 +1,442 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from models.modules.pix2pixMini_module import *
4
+
5
+ AVG = AvgQuant()
6
+ mul = SliceMul()
7
+
8
+ def get_int(channel_dict):
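+ # Casts the 'down', 'backbone' and 'up' channel counts to int in case a preset uses
+ # fractional multiples of ngf; note the 'hair_up'/'face_up' entries are left as-is.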
9
+ for i in range(len(channel_dict['down'])):
10
+ channel_dict['down'][i] = int(channel_dict['down'][i])
11
+ for i in range(len(channel_dict['backbone'])):
12
+ channel_dict['backbone'][i] = int(channel_dict['backbone'][i])
13
+ for i in range(len(channel_dict['up'])):
14
+ for j in range(len(channel_dict['up'][i])):
15
+ if channel_dict['up'][i][j] is not None:
16
+ channel_dict['up'][i][j] = int(channel_dict['up'][i][j])
17
+ return channel_dict
18
+
19
+ def get_channel_dict(dict_name, ngf):
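+ # Channel-width presets keyed by a rough compute budget ('3G', '470M', presumably FLOPs).
+ # Each 5-element entry in 'hair_up'/'face_up'/'up' is [mid_ch1, mid_ch2, mid_ch, out_ch,
+ # num_conv], matching the UpBlock constructor arguments used below.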
20
+ if dict_name is None:
21
+ raise ValueError('invalid channel_dict name')
22
+
23
+ if dict_name == '3G':
24
+ channel_dict = {
25
+ 'n_blocks': 8,
26
+ 'down': [ngf * 1, ngf * 2, ngf * 4, ngf * 8, ngf * 8],
27
+ 'backbone': [ngf * 8, ngf * 8],
28
+ 'hair_up': [
29
+ [None, None, None, ngf * 4, 2],
30
+ [None, None, None, ngf * 2, 2],
31
+ [None, None, None, ngf * 1, 2],
32
+ [None, None, None, ngf * 1, 1],
33
+ ],
34
+ 'face_up': [
35
+ [None, None, None, ngf * 4, 2],
36
+ [None, None, None, ngf * 2, 2],
37
+ [None, None, None, ngf * 2, 2],
38
+ [None, None, None, ngf * 1, 1],
39
+ ],
40
+ 'up': [
41
+ [None, None, None, ngf * 1, 1],
42
+ ],
43
+ }
44
+
45
+ elif dict_name == '470M':
46
+ channel_dict = {
47
+ 'n_blocks': 4,
48
+ 'down': [ngf * 1, ngf * 2, ngf * 3, ngf * 6, ngf * 4],
49
+ 'backbone': [ngf * 4, ngf * 4],
50
+ 'hair_up': [
51
+ [None, None, None, ngf * 2, 1],
52
+ [None, None, None, ngf * 2, 1],
53
+ [None, None, None, ngf * 1, 1],
54
+ [None, None, None, ngf * 1, 1],
55
+ ],
56
+ 'face_up': [
57
+ [None, None, ngf * 4, ngf * 2, 1],
58
+ [None, None, ngf * 2, ngf * 2, 1],
59
+ [None, None, ngf * 2, ngf * 1, 1],
60
+ [None, None, None, ngf * 1, 1],
61
+ ],
62
+ 'up': [
63
+ [None, None, ngf * 2, ngf * 1, 1],
64
+ ],
65
+ }
66
+ else:
67
+ raise ValueError('invalid_dict_name')
68
+ return get_int(channel_dict)
69
+
70
+ chans = 8
71
+ in_channels = 3
72
+ out_channels = 4
73
+ face_branch_out_channels = 3
74
+ hair_branch_out_channels = 3
75
+
76
+ def conv(numIn, numOut, k, s=1, p=0, relu=True, bn=False):
77
+ layers = []
78
+ layers.append(Conv2dQuant(numIn, numOut, k, s, p, bias=True))
79
+ if bn:
80
+ layers.append(nn.BatchNorm2d(numOut))
81
+
82
+ if relu is True:
83
+ layers.append(HardQuant(0, 4))
84
+ return nn.Sequential(*layers)
85
+
86
+ def mnconv(numIn, numOut, k, s=1, p=0, dilation=1, relu=True, bn = True):
87
+ if k < 2:
88
+ return conv(numIn, numOut, k, s, p, relu, bn)
89
+ layers = []
90
+ layers.append(Conv2dQuant(numIn, numIn, k, s, p, groups=numIn, dilation=dilation, bias=True))
91
+ layers.append(nn.BatchNorm2d(numIn))
92
+ layers.append(HardQuant(0, 4))
93
+ layers.append(conv(numIn, numOut, 1, 1, 0, relu, bn))
94
+ return nn.Sequential(*layers)
95
+
96
+ class hair_face_model(nn.Module):
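+ # Lighter variant (named '470M', presumably ~470 MFLOPs): a shared encoder (inconv plus
+ # four stride-2 ConvBlocks) and a residual bottleneck feed two decoders - a hair branch
+ # without skip connections and a U-Net style face branch with skips from x3/x2/x1 - whose
+ # outputs are concatenated and fused by mnUpBlock/mnConvOutBlock; optional hair/face
+ # preview heads are built only when with_hair_branch=True.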
97
+ def __init__(self, ngf=chans, backbone_type='resnet', use_se=True, channel_dict_name = None, with_hair_branch=False, design=5):
98
+
99
+ super().__init__()
100
+ self.design = design
101
+ self.with_hair_branch = with_hair_branch
102
+ channel_dict = get_channel_dict(channel_dict_name, ngf)
103
+ n_blocks = channel_dict['n_blocks']
104
+
105
+ self.inconv = ConvBlock(in_channels, channel_dict['down'][0], stride=1)
106
+ self.shortcut_ratio = [1,1,1,1]
107
+
108
+ # Down-Sampling
109
+ self.DownBlock1 = ConvBlock(channel_dict['down'][0], channel_dict['down'][1], stride=2)
110
+ self.DownBlock2 = ConvBlock(channel_dict['down'][1], channel_dict['down'][2], stride=2)
111
+ self.DownBlock3 = ConvBlock(channel_dict['down'][2], channel_dict['down'][3], stride=2)
112
+ self.DownBlock4 = ConvBlock(channel_dict['down'][3], channel_dict['down'][4], stride=2)
113
+
114
+ # Down-Sampling Bottleneck
115
+ if backbone_type == 'resnet':
116
+ backbone_block = ResnetBlock
117
+ elif backbone_type == 'mobilenet':
118
+ backbone_block = InvertedBottleneck
119
+ n_blocks = n_blocks
120
+ else:
121
+ raise ValueError('invalid backbone type')
122
+ ResBlock = []
123
+ ResBlock += [backbone_block(channel_dict['down'][4], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
124
+ for i in range(1, n_blocks - 1):
125
+ ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][0], use_bias=False, use_se=use_se)]
126
+ ResBlock += [backbone_block(channel_dict['backbone'][0], channel_dict['backbone'][1], use_bias=False, use_se=use_se)]
127
+ self.ResBlock = nn.Sequential(*ResBlock)
128
+
129
+ self.HairUpBlock4 = UpBlock(channel_dict['backbone'][1], None, channel_dict['hair_up'][0][0], None, channel_dict['hair_up'][0][2], channel_dict['hair_up'][0][3], num_conv=channel_dict['hair_up'][0][4])
130
+ self.HairUpBlock3 = UpBlock(channel_dict['hair_up'][0][3], None, channel_dict['hair_up'][1][0], None, channel_dict['hair_up'][1][2], channel_dict['hair_up'][1][3], num_conv=channel_dict['hair_up'][1][4])
131
+ self.HairUpBlock2 = UpBlock(channel_dict['hair_up'][1][3], None, channel_dict['hair_up'][2][0], None, channel_dict['hair_up'][2][2], channel_dict['hair_up'][2][3], num_conv=channel_dict['hair_up'][2][4])
132
+
133
+ self.FaceUpBlock4 = UpBlock(channel_dict['backbone'][1], channel_dict['down'][3], channel_dict['face_up'][0][0], channel_dict['face_up'][0][1], channel_dict['face_up'][0][2], channel_dict['face_up'][0][3], num_conv=channel_dict['face_up'][0][4])
134
+ self.FaceUpBlock3 = UpBlock(channel_dict['face_up'][0][3], channel_dict['down'][2], channel_dict['face_up'][1][0], channel_dict['face_up'][1][1], channel_dict['face_up'][1][2], channel_dict['face_up'][1][3], num_conv=channel_dict['face_up'][1][4])
135
+ self.FaceUpBlock2 = UpBlock(channel_dict['face_up'][1][3], channel_dict['down'][1], channel_dict['face_up'][2][0], channel_dict['face_up'][2][1], channel_dict['face_up'][2][2], channel_dict['face_up'][2][3], num_conv=channel_dict['face_up'][2][4])
136
+
137
+ self.UpBlock1 = mnUpBlock(channel_dict['hair_up'][2][3] + channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['up'][0][0], channel_dict['up'][0][1], channel_dict['up'][0][2], channel_dict['up'][0][3], num_conv=channel_dict['up'][0][4])
138
+ self.outconv = mnConvOutBlock(channel_dict['up'][0][3], out_channels)
139
+
140
+ #self.shortcut_ratio = [1,1,1,1]
141
+
142
+ if self.with_hair_branch:
143
+ self.HairUpBlock1 = UpBlock(channel_dict['hair_up'][2][3], None, channel_dict['hair_up'][3][0], None, channel_dict['hair_up'][3][2],channel_dict['hair_up'][3][3])
144
+ self.Hairoutconv = ConvOutBlock(channel_dict['hair_up'][3][3], hair_branch_out_channels)
145
+
146
+ self.FaceUpBlock1 = UpBlock(channel_dict['face_up'][2][3], channel_dict['down'][0], channel_dict['face_up'][3][0], channel_dict['face_up'][3][1], channel_dict['face_up'][3][2], channel_dict['face_up'][3][3])
147
+ self.Faceoutconv = ConvOutBlock(channel_dict['face_up'][3][3], face_branch_out_channels)
148
+
149
+ self.up = UpsampleQuant(scale_factor=1.5, mode='bilinear')
150
+
151
+
152
+ def forward(self, x, with_hair=False):
153
+ x0 = self.inconv(x)
154
+ x1 = self.DownBlock1(x0)
155
+ x2 = self.DownBlock2(x1)
156
+ x3 = self.DownBlock3(x2)
157
+ x4 = self.DownBlock4(x3)
158
+ x = self.ResBlock(x4)
159
+
160
+ hair = self.HairUpBlock4(x)
161
+ hair = self.HairUpBlock3(hair)
162
+ hair = self.HairUpBlock2(hair)
163
+ face = self.FaceUpBlock4(x, x3, self.shortcut_ratio[0])
164
+ face = self.FaceUpBlock3(face, x2, self.shortcut_ratio[1])
165
+ face = self.FaceUpBlock2(face, x1, self.shortcut_ratio[2])
166
+
167
+ x = self.UpBlock1(torch.cat([hair,face], dim=1), x0, self.shortcut_ratio[3])
168
+ # print(self.outconv)
169
+ x = self.outconv(x)
170
+ if not with_hair or not self.with_hair_branch:
171
+ return x
172
+ else:
173
+ hair = self.HairUpBlock1(hair)
174
+ hair = self.Hairoutconv(hair)
175
+
176
+ face = self.FaceUpBlock1(face, x0, self.shortcut_ratio[3])
177
+ face = self.Faceoutconv(face)
178
+
179
+ return x, hair, face
180
+
181
+
182
+ class ConvBlock(nn.Module):
183
+ def __init__(self, in_ch, out_ch, stride):
184
+ super(ConvBlock, self).__init__()
185
+ self.conv = nn.Sequential(
186
+ #nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
187
+ Conv2dQuant(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=True),
188
+ nn.BatchNorm2d(out_ch),
189
+ HardQuant(0, 4)
190
+ #nn.ReLU(False))
191
+ )
192
+ def forward(self, x):
193
+ x = self.conv(x)
194
+ return x
195
+
196
+
197
+
198
+ class UpBlock(nn.Module):
199
+ def __init__(self, in_ch1, in_ch2, mid_ch1, mid_ch2, mid_ch, out_ch, num_conv=1, use_bn=True):
200
+ super(UpBlock, self).__init__()
201
+
202
+ #self.up = nn.Upsample(scale_factor=2, mode='nearest')
203
+ self.up = UpsampleQuant(scale_factor=2, mode='nearest')
204
+
205
+ ## branch_1
206
+ if mid_ch1 is None or in_ch1 == mid_ch1:
207
+ self.conv1 = None
208
+ else:
209
+ self.conv1 = nn.Sequential(
210
+ #nn.Conv2d(in_ch1, mid_ch1, 1, bias=False),
211
+ Conv2dQuant(in_ch1, mid_ch1, 1, bias=True),
212
+ #nn.ReLU(False),
213
+ HardQuant(0, 4)
214
+ )
215
+ if mid_ch1 is None:
216
+ mid_ch1 = in_ch1
217
+
218
+ if in_ch2 is None:
219
+ self.use_shortcut = False
220
+ self.conv2 = None
221
+ else:
222
+ self.use_shortcut = True
223
+ if mid_ch2 is None or in_ch2 == mid_ch2:
224
+ self.conv2 = None
225
+ else:
226
+ self.conv2 = nn.Sequential(
227
+ #nn.Conv2d(in_ch2, mid_ch2, 1, bias=False),
228
+ Conv2dQuant(in_ch2, mid_ch2, 1, bias=True),
229
+ #nn.ReLU(False),
230
+ HardQuant(0, 4)
231
+ )
232
+ if mid_ch2 is None:
233
+ mid_ch2 = in_ch2
234
+ #print(self.conv1 is None, self.conv2 is None)
235
+ combine_ch = mid_ch1
236
+ if self.use_shortcut:
237
+ combine_ch = combine_ch + mid_ch2
238
+ if mid_ch is None or combine_ch == mid_ch:
239
+ self.conv_combine = None
240
+ mid_ch = combine_ch
241
+ else:
242
+ self.conv_combine = nn.Sequential(
243
+ #nn.Conv2d(combine_ch, mid_ch, 1, bias=False),
244
+ Conv2dQuant(combine_ch, mid_ch, 1, bias=True),
245
+ #nn.ReLU(False),
246
+ HardQuant(0, 4)
247
+ )
248
+
249
+ conv_list = []
250
+ #conv_list.append(nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False))
251
+ conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True))
252
+ # conv_list.append(mnconv(mid_ch, out_ch, k=3, s=1, p=1))
253
+
254
+ if use_bn:
255
+ conv_list.append(nn.BatchNorm2d(out_ch))
256
+ #conv_list.append(nn.ReLU(False))
257
+ conv_list.append(HardQuant(0, 4))
258
+ for n in range(1, num_conv):
259
+ conv_list.append(ResnetBlock(out_ch, out_ch, use_bias=False, use_se = False, use_bn=use_bn))
260
+ self.conv = nn.Sequential(*conv_list)
261
+
262
+ def forward(self, x1, x2=None, ratio=None):
263
+
264
+ if self.conv1 is not None:
265
+ x1 = self.conv1(x1)
266
+ x1 = self.up(x1)
267
+
268
+ if self.use_shortcut:
269
+ if self.conv2 is not None:
270
+ x2 = self.conv2(x2)
271
+
272
+ if self.use_shortcut:
273
+ # Unlike the 270M variant, `ratio` is accepted but not applied here: both branches of
+ # the original if/else concatenated the unscaled shortcut, so the check is dropped.
+ x = torch.cat([x1, x2], dim=1)
277
+ else:
278
+ x = x1
279
+
280
+ if self.conv_combine is not None:
281
+ x = self.conv_combine(x)
282
+
283
+ x = self.conv(x)
284
+ return x
285
+
286
+ class mnUpBlock(nn.Module):
287
+ def __init__(self, in_ch1, in_ch2, mid_ch1, mid_ch2, mid_ch, out_ch, num_conv=1, use_bn=True):
288
+ super(mnUpBlock, self).__init__()
289
+
290
+ #self.up = nn.Upsample(scale_factor=2, mode='nearest')
291
+ self.up = UpsampleQuant(scale_factor=2, mode='nearest')
292
+
293
+ ## branch_1
294
+ if mid_ch1 is None or in_ch1 == mid_ch1:
295
+ self.conv1 = None
296
+ else:
297
+ self.conv1 = nn.Sequential(
298
+ #nn.Conv2d(in_ch1, mid_ch1, 1, bias=False),
299
+ Conv2dQuant(in_ch1, mid_ch1, 1, bias=True),
300
+ #nn.ReLU(False),
301
+ HardQuant(0, 4)
302
+ )
303
+ if mid_ch1 is None:
304
+ mid_ch1 = in_ch1
305
+
306
+ if in_ch2 is None:
307
+ self.use_shortcut = False
308
+ self.conv2 = None
309
+ else:
310
+ self.use_shortcut = True
311
+ if mid_ch2 is None or in_ch2 == mid_ch2:
312
+ self.conv2 = None
313
+ else:
314
+ self.conv2 = nn.Sequential(
315
+ #nn.Conv2d(in_ch2, mid_ch2, 1, bias=False),
316
+ Conv2dQuant(in_ch2, mid_ch2, 1, bias=True),
317
+ #nn.ReLU(False),
318
+ HardQuant(0, 4)
319
+ )
320
+ if mid_ch2 is None:
321
+ mid_ch2 = in_ch2
322
+ #print(self.conv1 is None, self.conv2 is None)
323
+ combine_ch = mid_ch1
324
+ if self.use_shortcut:
325
+ combine_ch = combine_ch + mid_ch2
326
+ if mid_ch is None or combine_ch == mid_ch:
327
+ self.conv_combine = None
328
+ mid_ch = combine_ch
329
+ else:
330
+ self.conv_combine = nn.Sequential(
331
+ #nn.Conv2d(combine_ch, mid_ch, 1, bias=False),
332
+ Conv2dQuant(combine_ch, mid_ch, 1, bias=True),
333
+ #nn.ReLU(False),
334
+ HardQuant(0, 4)
335
+ )
336
+
337
+ conv_list = []
338
+ #conv_list.append(nn.Conv2d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False))
339
+ # conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True))
340
+ conv_list.append(Conv2dQuant(mid_ch, mid_ch, kernel_size=3, stride=1, padding=1, groups=mid_ch,bias=True))
341
+ conv_list.append(nn.BatchNorm2d(mid_ch))
342
+ conv_list.append(HardQuant(0, 4))
343
+ conv_list.append(Conv2dQuant(mid_ch, out_ch, kernel_size=1,stride=1,padding=0))
344
+ # conv_list.append(mnconv(mid_ch, out_ch, k=3, s=1, p=1))
345
+
346
+ if use_bn:
347
+ conv_list.append(nn.BatchNorm2d(out_ch))
348
+ #conv_list.append(nn.ReLU(False))
349
+ conv_list.append(HardQuant(0, 4))
350
+ for n in range(1, num_conv):
351
+ conv_list.append(ResnetBlock(out_ch, out_ch, use_bias=False, use_se = False, use_bn=use_bn))
352
+ self.conv = nn.Sequential(*conv_list)
353
+
354
+ def forward(self, x1, x2=None, ratio=None):
355
+
356
+ if self.conv1 is not None:
357
+ x1 = self.conv1(x1)
358
+ x1 = self.up(x1)
359
+
360
+ if self.use_shortcut:
361
+ if self.conv2 is not None:
362
+ x2 = self.conv2(x2)
363
+
364
+ if self.use_shortcut:
365
+ # As in UpBlock above, `ratio` is accepted but not applied: both branches of the
+ # original if/else concatenated the unscaled shortcut, so the check is dropped.
+ x = torch.cat([x1, x2], dim=1)
369
+ else:
370
+ x = x1
371
+
372
+ if self.conv_combine is not None:
373
+ x = self.conv_combine(x)
374
+
375
+ x = self.conv(x)
376
+ return x
377
+
378
+ class mnConvOutBlock(nn.Module):
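+ # Depthwise-separable output head: the quantized part ends at TanhOp, then a float
+ # 2x bilinear upsample + 3x3 conv + Tanh produces the final output (presumably the
+ # 2x super-resolution step of this sr module).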
379
+ def __init__(self, in_ch, out_ch):
380
+ super(mnConvOutBlock, self).__init__()
381
+ self.conv = nn.Sequential(
382
+ Conv2dQuant(in_ch, in_ch, kernel_size=3, stride=1, padding=1, groups=in_ch, bias=True),
383
+ nn.BatchNorm2d(in_ch),
384
+ HardQuant(0, 4),
385
+ Conv2dQuant(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False),
386
+ #nn.Tanh()
387
+ TanhOp(data_in_type='float', data_out_type='fixed'),
388
+ nn.Upsample(scale_factor=2, mode='bilinear'),
389
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
390
+ nn.Tanh()
391
+ )
392
+ def forward(self, x0):
393
+ x0 = self.conv(x0)
394
+ return x0
395
+
396
+ class ResnetBlock(nn.Module):
397
+ def __init__(self, dim, dim_out, use_bias, use_se = False, use_bn=True):
398
+ super(ResnetBlock, self).__init__()
399
+ conv_block = []
400
+ #conv_block += [nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias),]
401
+ conv_block += [Conv2dQuant(dim, dim, kernel_size=3, stride=1, padding=1, bias=True),]
402
+ if use_bn:
403
+ conv_block += [nn.BatchNorm2d(dim),]
404
+ #conv_block += [nn.ReLU(False)]
405
+ conv_block += [HardQuant(0, 4)]
406
+
407
+ #conv_block.append(nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=use_bias))
408
+ conv_block.append(Conv2dQuant(dim, dim_out, kernel_size=3, stride=1, padding=1, bias=True))
409
+ if use_bn:
410
+ conv_block.append(nn.BatchNorm2d(dim_out))
411
+ conv_block += [HardQuant(0, 4)]
412
+
413
+ if use_se:
414
+ conv_block.append(SqEx(dim_out, 4))
415
+
416
+ self.conv_block = nn.Sequential(*conv_block)
417
+
418
+ self.downsample = None
419
+ if dim != dim_out:
420
+ if use_bn:
421
+ self.downsample = nn.Sequential(
422
+ #nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
423
+ #nn.BatchNorm2d(dim_out),
424
+ Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
425
+ nn.BatchNorm2d(dim_out),
426
+ )
427
+ else:
428
+ self.downsample = nn.Sequential(
429
+ #nn.Conv2d(dim, dim_out, kernel_size=1, stride=1, bias=use_bias),
430
+ Conv2dQuant(dim, dim_out, kernel_size=1, stride=1, bias=True),
431
+ )
432
+
433
+ #self.relu = nn.ReLU(False)
434
+ #self.relu = HardQuant(0, 4)
435
+
436
+ def forward(self, x):
437
+ if self.downsample is None:
438
+ y = AVG(x, self.conv_block(x))
439
+ else:
440
+ y = AVG(self.downsample(x), self.conv_block(x))
441
+ #y = self.relu(y)
442
+ return y
models/modules/stylegan2/__pycache__/model.cpython-38.pyc ADDED
Binary file (16.3 kB)
models/modules/stylegan2/__pycache__/non_leaking.cpython-38.pyc ADDED
Binary file (11 kB)
models/modules/stylegan2/model.py ADDED
@@ -0,0 +1,716 @@
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.autograd import Function
10
+
11
+ from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12
+
13
+
14
+ class PixelNorm(nn.Module):
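+ # Normalizes each spatial position's feature vector to unit RMS across channels;
+ # used as the first layer of the generator's mapping network.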
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, input):
19
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20
+
21
+
22
+ def make_kernel(k):
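+ # Expands a 1-D filter such as [1, 3, 3, 1] into a normalized 2-D blur kernel via an
+ # outer product, for use by the upfirdn2d-based Upsample/Downsample/Blur modules.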
23
+ k = torch.tensor(k, dtype=torch.float32)
24
+
25
+ if k.ndim == 1:
26
+ k = k[None, :] * k[:, None]
27
+
28
+ k /= k.sum()
29
+
30
+ return k
31
+
32
+
33
+ class Upsample(nn.Module):
34
+ def __init__(self, kernel, factor=2):
35
+ super().__init__()
36
+
37
+ self.factor = factor
38
+ kernel = make_kernel(kernel) * (factor ** 2)
39
+ self.register_buffer("kernel", kernel)
40
+
41
+ p = kernel.shape[0] - factor
42
+
43
+ pad0 = (p + 1) // 2 + factor - 1
44
+ pad1 = p // 2
45
+
46
+ self.pad = (pad0, pad1)
47
+
48
+ def forward(self, input):
49
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50
+
51
+ return out
52
+
53
+
54
+ class Downsample(nn.Module):
55
+ def __init__(self, kernel, factor=2):
56
+ super().__init__()
57
+
58
+ self.factor = factor
59
+ kernel = make_kernel(kernel)
60
+ self.register_buffer("kernel", kernel)
61
+
62
+ p = kernel.shape[0] - factor
63
+
64
+ pad0 = (p + 1) // 2
65
+ pad1 = p // 2
66
+
67
+ self.pad = (pad0, pad1)
68
+
69
+ def forward(self, input):
70
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71
+
72
+ return out
73
+
74
+
75
+ class Blur(nn.Module):
76
+ def __init__(self, kernel, pad, upsample_factor=1):
77
+ super().__init__()
78
+
79
+ kernel = make_kernel(kernel)
80
+
81
+ if upsample_factor > 1:
82
+ kernel = kernel * (upsample_factor ** 2)
83
+
84
+ self.register_buffer("kernel", kernel)
85
+
86
+ self.pad = pad
87
+
88
+ def forward(self, input):
89
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
90
+
91
+ return out
92
+
93
+
94
+ class EqualConv2d(nn.Module):
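+ # Convolution with equalized learning rate: weights are sampled from N(0, 1) and
+ # rescaled at runtime by 1/sqrt(fan_in) instead of being scaled at initialization.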
95
+ def __init__(
96
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97
+ ):
98
+ super().__init__()
99
+
100
+ self.weight = nn.Parameter(
101
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102
+ )
103
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104
+
105
+ self.stride = stride
106
+ self.padding = padding
107
+
108
+ if bias:
109
+ self.bias = nn.Parameter(torch.zeros(out_channel))
110
+
111
+ else:
112
+ self.bias = None
113
+
114
+ def forward(self, input):
115
+ out = conv2d_gradfix.conv2d(
116
+ input,
117
+ self.weight * self.scale,
118
+ bias=self.bias,
119
+ stride=self.stride,
120
+ padding=self.padding,
121
+ )
122
+
123
+ return out
124
+
125
+ def __repr__(self):
126
+ return (
127
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129
+ )
130
+
131
+
132
+ class EqualLinear(nn.Module):
133
+ def __init__(
134
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135
+ ):
136
+ super().__init__()
137
+
138
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139
+
140
+ if bias:
141
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142
+
143
+ else:
144
+ self.bias = None
145
+
146
+ self.activation = activation
147
+
148
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149
+ self.lr_mul = lr_mul
150
+
151
+ def forward(self, input):
152
+ if self.activation:
153
+ out = F.linear(input, self.weight * self.scale)
154
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
155
+
156
+ else:
157
+ out = F.linear(
158
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
159
+ )
160
+
161
+ return out
162
+
163
+ def __repr__(self):
164
+ return (
165
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166
+ )
167
+
168
+
169
+ class ModulatedConv2d(nn.Module):
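+ # StyleGAN2 modulated convolution: the style vector scales the weight's input channels
+ # per sample, the weight is then demodulated (renormalized per output filter), and the
+ # fused path folds the per-sample weights into a single grouped convolution.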
170
+ def __init__(
171
+ self,
172
+ in_channel,
173
+ out_channel,
174
+ kernel_size,
175
+ style_dim,
176
+ demodulate=True,
177
+ upsample=False,
178
+ downsample=False,
179
+ blur_kernel=[1, 3, 3, 1],
180
+ fused=True,
181
+ ):
182
+ super().__init__()
183
+
184
+ self.eps = 1e-8
185
+ self.kernel_size = kernel_size
186
+ self.in_channel = in_channel
187
+ self.out_channel = out_channel
188
+ self.upsample = upsample
189
+ self.downsample = downsample
190
+
191
+ if upsample:
192
+ factor = 2
193
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
194
+ pad0 = (p + 1) // 2 + factor - 1
195
+ pad1 = p // 2 + 1
196
+
197
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
198
+
199
+ if downsample:
200
+ factor = 2
201
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
202
+ pad0 = (p + 1) // 2
203
+ pad1 = p // 2
204
+
205
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
206
+
207
+ fan_in = in_channel * kernel_size ** 2
208
+ self.scale = 1 / math.sqrt(fan_in)
209
+ self.padding = kernel_size // 2
210
+
211
+ self.weight = nn.Parameter(
212
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
213
+ )
214
+
215
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216
+
217
+ self.demodulate = demodulate
218
+ self.fused = fused
219
+
220
+ def __repr__(self):
221
+ return (
222
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
223
+ f"upsample={self.upsample}, downsample={self.downsample})"
224
+ )
225
+
226
+ def forward(self, input, style):
227
+ batch, in_channel, height, width = input.shape
228
+
229
+ if not self.fused:
230
+ weight = self.scale * self.weight.squeeze(0)
231
+ style = self.modulation(style)
232
+
233
+ if self.demodulate:
234
+ w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
235
+ dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
236
+
237
+ input = input * style.reshape(batch, in_channel, 1, 1)
238
+
239
+ if self.upsample:
240
+ weight = weight.transpose(0, 1)
241
+ out = conv2d_gradfix.conv_transpose2d(
242
+ input, weight, padding=0, stride=2
243
+ )
244
+ out = self.blur(out)
245
+
246
+ elif self.downsample:
247
+ input = self.blur(input)
248
+ out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
249
+
250
+ else:
251
+ out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
252
+
253
+ if self.demodulate:
254
+ out = out * dcoefs.view(batch, -1, 1, 1)
255
+
256
+ return out
257
+
258
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
259
+ weight = self.scale * self.weight * style
260
+
261
+ if self.demodulate:
262
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
263
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
264
+
265
+ weight = weight.view(
266
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
267
+ )
268
+
269
+ if self.upsample:
270
+ input = input.view(1, batch * in_channel, height, width)
271
+ weight = weight.view(
272
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
273
+ )
274
+ weight = weight.transpose(1, 2).reshape(
275
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
276
+ )
277
+ out = conv2d_gradfix.conv_transpose2d(
278
+ input, weight, padding=0, stride=2, groups=batch
279
+ )
280
+ _, _, height, width = out.shape
281
+ out = out.view(batch, self.out_channel, height, width)
282
+ out = self.blur(out)
283
+
284
+ elif self.downsample:
285
+ input = self.blur(input)
286
+ _, _, height, width = input.shape
287
+ input = input.view(1, batch * in_channel, height, width)
288
+ out = conv2d_gradfix.conv2d(
289
+ input, weight, padding=0, stride=2, groups=batch
290
+ )
291
+ _, _, height, width = out.shape
292
+ out = out.view(batch, self.out_channel, height, width)
293
+
294
+ else:
295
+ input = input.view(1, batch * in_channel, height, width)
296
+ out = conv2d_gradfix.conv2d(
297
+ input, weight, padding=self.padding, groups=batch
298
+ )
299
+ _, _, height, width = out.shape
300
+ out = out.view(batch, self.out_channel, height, width)
301
+
302
+ return out
303
+
304
+
305
+ class NoiseInjection(nn.Module):
306
+ def __init__(self):
307
+ super().__init__()
308
+
309
+ self.weight = nn.Parameter(torch.zeros(1))
310
+
311
+ def forward(self, image, noise=None):
312
+ if noise is None:
313
+ batch, _, height, width = image.shape
314
+ noise = image.new_empty(batch, 1, height, width).normal_()
315
+
316
+ return image + self.weight * noise
317
+
318
+
319
+ class ConstantInput(nn.Module):
320
+ def __init__(self, channel, size=4):
321
+ super().__init__()
322
+
323
+ if type(size) is tuple:
324
+ self.input = nn.Parameter(torch.randn(1, channel, size[0], size[1]))
325
+ else:
326
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
327
+
328
+ def forward(self, input):
329
+ batch = input.shape[0]
330
+ out = self.input.repeat(batch, 1, 1, 1)
331
+
332
+ return out
333
+
334
+
335
+ class StyledConv(nn.Module):
336
+ def __init__(
337
+ self,
338
+ in_channel,
339
+ out_channel,
340
+ kernel_size,
341
+ style_dim,
342
+ upsample=False,
343
+ blur_kernel=[1, 3, 3, 1],
344
+ demodulate=True,
345
+ ):
346
+ super().__init__()
347
+
348
+ self.conv = ModulatedConv2d(
349
+ in_channel,
350
+ out_channel,
351
+ kernel_size,
352
+ style_dim,
353
+ upsample=upsample,
354
+ blur_kernel=blur_kernel,
355
+ demodulate=demodulate,
356
+ )
357
+
358
+ self.noise = NoiseInjection()
359
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
360
+ # self.activate = ScaledLeakyReLU(0.2)
361
+ self.activate = FusedLeakyReLU(out_channel)
362
+
363
+ def forward(self, input, style, noise=None):
364
+ out = self.conv(input, style)
365
+ out = self.noise(out, noise=noise)
366
+ # out = out + self.bias
367
+ out = self.activate(out)
368
+
369
+ return out
370
+
371
+
372
+ class ToRGB(nn.Module):
373
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
374
+ super().__init__()
375
+
376
+ if upsample:
377
+ self.upsample = Upsample(blur_kernel)
378
+
379
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
380
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
381
+
382
+ def forward(self, input, style, skip=None):
383
+ out = self.conv(input, style)
384
+ out = out + self.bias
385
+
386
+ if skip is not None:
387
+ skip = self.upsample(skip)
388
+
389
+ out = out + skip
390
+
391
+ return out
392
+
393
+
394
+ class Generator(nn.Module):
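+ # StyleGAN2 generator: an n_mlp-layer mapping network, a learned constant 4x4 input,
+ # and pairs of StyledConv blocks with ToRGB skip summation up to `size`; supports style
+ # mixing via inject_index and the truncation trick via truncation/truncation_latent.
+ # Minimal usage sketch (assuming the fused ops in .op are available on this machine):
+ #   g = Generator(size=256, style_dim=512, n_mlp=8)
+ #   img, _ = g([torch.randn(4, 512)], truncation=0.7, truncation_latent=g.mean_latent(4096))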
395
+ def __init__(
396
+ self,
397
+ size,
398
+ style_dim,
399
+ n_mlp,
400
+ channel_multiplier=2,
401
+ blur_kernel=[1, 3, 3, 1],
402
+ lr_mlp=0.01,
403
+ ):
404
+ super().__init__()
405
+
406
+ self.size = size
407
+
408
+ self.style_dim = style_dim
409
+
410
+ layers = [PixelNorm()]
411
+
412
+ for i in range(n_mlp):
413
+ layers.append(
414
+ EqualLinear(
415
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
416
+ )
417
+ )
418
+
419
+ self.style = nn.Sequential(*layers)
420
+
421
+ self.channels = {
422
+ 4: 512,
423
+ 8: 512,
424
+ 16: 512,
425
+ 32: 512,
426
+ 64: 256 * channel_multiplier,
427
+ 128: 128 * channel_multiplier,
428
+ 256: 64 * channel_multiplier,
429
+ 512: 32 * channel_multiplier,
430
+ 1024: 16 * channel_multiplier,
431
+ }
432
+
433
+ self.input = ConstantInput(self.channels[4])
434
+ self.conv1 = StyledConv(
435
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
436
+ )
437
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
438
+
439
+ self.log_size = int(math.log(size, 2))
440
+ self.num_layers = (self.log_size - 2) * 2 + 1
441
+
442
+ self.convs = nn.ModuleList()
443
+ self.upsamples = nn.ModuleList()
444
+ self.to_rgbs = nn.ModuleList()
445
+ self.noises = nn.Module()
446
+
447
+ in_channel = self.channels[4]
448
+
449
+ for layer_idx in range(self.num_layers):
450
+ res = (layer_idx + 5) // 2
451
+ shape = [1, 1, 2 ** res, 2 ** res]
452
+ self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
453
+
454
+ for i in range(3, self.log_size + 1):
455
+ out_channel = self.channels[2 ** i]
456
+
457
+ self.convs.append(
458
+ StyledConv(
459
+ in_channel,
460
+ out_channel,
461
+ 3,
462
+ style_dim,
463
+ upsample=True,
464
+ blur_kernel=blur_kernel,
465
+ )
466
+ )
467
+
468
+ self.convs.append(
469
+ StyledConv(
470
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
471
+ )
472
+ )
473
+
474
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
475
+
476
+ in_channel = out_channel
477
+
478
+ self.n_latent = self.log_size * 2 - 2
479
+
480
+ def make_noise(self):
481
+ device = self.input.input.device
482
+
483
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
484
+
485
+ for i in range(3, self.log_size + 1):
486
+ for _ in range(2):
487
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
488
+
489
+ return noises
490
+
491
+ def mean_latent(self, n_latent):
492
+ latent_in = torch.randn(
493
+ n_latent, self.style_dim, device=self.input.input.device
494
+ )
495
+ latent = self.style(latent_in).mean(0, keepdim=True)
496
+
497
+ return latent
498
+
499
+ def get_latent(self, input):
500
+ return self.style(input)
501
+
502
+ def forward(
503
+ self,
504
+ styles,
505
+ return_latents=False,
506
+ inject_index=None,
507
+ truncation=1,
508
+ truncation_latent=None,
509
+ input_is_latent=False,
510
+ noise=None,
511
+ randomize_noise=True,
512
+ ):
513
+ if not input_is_latent:
514
+ styles = [self.style(s) for s in styles]
515
+
516
+ if noise is None:
517
+ if randomize_noise:
518
+ noise = [None] * self.num_layers
519
+ else:
520
+ noise = [
521
+ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
522
+ ]
523
+
524
+ if truncation < 1:
525
+ style_t = []
526
+
527
+ for style in styles:
528
+ style_t.append(
529
+ truncation_latent + truncation * (style - truncation_latent)
530
+ )
531
+
532
+ styles = style_t
533
+
534
+ if len(styles) < 2:
535
+ inject_index = self.n_latent
536
+
537
+ if styles[0].ndim < 3:
538
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
539
+
540
+ else:
541
+ latent = styles[0]
542
+
543
+ else:
544
+ if inject_index is None:
545
+ inject_index = random.randint(1, self.n_latent - 1)
546
+
547
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
548
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
549
+
550
+ latent = torch.cat([latent, latent2], 1)
551
+
552
+ out = self.input(latent)
553
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
554
+
555
+ skip = self.to_rgb1(out, latent[:, 1])
556
+
557
+ i = 1
558
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
559
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
560
+ ):
561
+ out = conv1(out, latent[:, i], noise=noise1)
562
+ out = conv2(out, latent[:, i + 1], noise=noise2)
563
+ skip = to_rgb(out, latent[:, i + 2], skip)
564
+
565
+ i += 2
566
+
567
+ image = skip
568
+
569
+ if return_latents:
570
+ return image, latent
571
+
572
+ else:
573
+ return image, None
574
+
575
+
576
+ class ConvLayer(nn.Sequential):
577
+ def __init__(
578
+ self,
579
+ in_channel,
580
+ out_channel,
581
+ kernel_size,
582
+ downsample=False,
583
+ blur_kernel=[1, 3, 3, 1],
584
+ bias=True,
585
+ activate=True,
586
+ ):
587
+ layers = []
588
+
589
+ if downsample:
590
+ factor = 2
591
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
592
+ pad0 = (p + 1) // 2
593
+ pad1 = p // 2
594
+
595
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
596
+
597
+ stride = 2
598
+ self.padding = 0
599
+
600
+ else:
601
+ stride = 1
602
+ self.padding = kernel_size // 2
603
+
604
+ layers.append(
605
+ EqualConv2d(
606
+ in_channel,
607
+ out_channel,
608
+ kernel_size,
609
+ padding=self.padding,
610
+ stride=stride,
611
+ bias=bias and not activate,
612
+ )
613
+ )
614
+
615
+ if activate:
616
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
617
+
618
+ super().__init__(*layers)
619
+
620
+
621
+ class ResBlock(nn.Module):
622
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
623
+ super().__init__()
624
+
625
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
626
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
627
+
628
+ self.skip = ConvLayer(
629
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
630
+ )
631
+
632
+ def forward(self, input):
633
+ out = self.conv1(input)
634
+ out = self.conv2(out)
635
+
636
+ skip = self.skip(input)
637
+ out = (out + skip) / math.sqrt(2)
638
+
639
+ return out
640
+
641
+
642
+ class Discriminator(nn.Module):
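+ # Residual downsampling discriminator with a minibatch standard-deviation feature before
+ # the final layers; min_feats_size sets the spatial size fed to the final linear head, and
+ # forward(..., rtn_feats=True) returns intermediate features, presumably for a
+ # feature-matching loss.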
643
+ def __init__(self, size, min_feats_size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
644
+ super().__init__()
645
+
646
+ channels = {
647
+ 4: 512,
648
+ 8: 512,
649
+ 16: 512,
650
+ 32: 512,
651
+ 64: 256 * channel_multiplier,
652
+ 128: 128 * channel_multiplier,
653
+ 256: 64 * channel_multiplier,
654
+ 512: 32 * channel_multiplier,
655
+ 1024: 16 * channel_multiplier,
656
+ }
657
+
658
+ convs = [ConvLayer(3, channels[size], 1)]
659
+
660
+ log_size = int(math.log(size, 2))
661
+ if type(min_feats_size) is tuple:
662
+ fsize = min_feats_size[0] * min_feats_size[1]
663
+ else:
664
+ fsize = min_feats_size * min_feats_size
665
+
666
+ in_channel = channels[size]
667
+
668
+ for i in range(log_size, 2, -1):
669
+ out_channel = channels[2 ** (i - 1)]
670
+
671
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
672
+
673
+ in_channel = out_channel
674
+
675
+ self.convs = nn.Sequential(*convs)
676
+
677
+ self.stddev_group = 4
678
+ self.stddev_feat = 1
679
+
680
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
681
+ self.final_linear = nn.Sequential(
682
+ EqualLinear(channels[4] * fsize, channels[4], activation="fused_lrelu"),
683
+ EqualLinear(channels[4], 1),
684
+ )
685
+
686
+ def forward(self, input, rtn_feats=False):
687
+ if rtn_feats:
688
+ feats = []
689
+ feat = input
690
+ for i, block in enumerate(self.convs):
691
+ feat = block(feat)
692
+ if i in [ 1, 3, 4, 5 ]:
693
+ feats.append(feat)
694
+ if i == 5:
695
+ break
696
+ return feats
697
+
698
+ out = self.convs(input)
699
+
700
+ batch, channel, height, width = out.shape
701
+ group = min(batch, self.stddev_group)
702
+ stddev = out.view(
703
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
704
+ )
705
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
706
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
707
+ stddev = stddev.repeat(group, 1, height, width)
708
+ out = torch.cat([out, stddev], 1)
709
+
710
+ out = self.final_conv(out)
711
+
712
+ out = out.view(batch, -1)
713
+ out = self.final_linear(out)
714
+
715
+ return out
716
+
models/modules/stylegan2/non_leaking.py ADDED
@@ -0,0 +1,465 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import autograd
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+
8
+ # from distributed import reduce_sum
9
+ from .op import upfirdn2d
10
+
11
+
12
+ # class AdaptiveAugment:
13
+ # def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
14
+ # self.ada_aug_target = ada_aug_target
15
+ # self.ada_aug_len = ada_aug_len
16
+ # self.update_every = update_every
17
+
18
+ # self.ada_update = 0
19
+ # self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
20
+ # self.r_t_stat = 0
21
+ # self.ada_aug_p = 0
22
+
23
+ # @torch.no_grad()
24
+ # def tune(self, real_pred):
25
+ # self.ada_aug_buf += torch.tensor(
26
+ # (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
27
+ # device=real_pred.device,
28
+ # )
29
+ # self.ada_update += 1
30
+
31
+ # if self.ada_update % self.update_every == 0:
32
+ # self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
33
+ # pred_signs, n_pred = self.ada_aug_buf.tolist()
34
+
35
+ # self.r_t_stat = pred_signs / n_pred
36
+
37
+ # if self.r_t_stat > self.ada_aug_target:
38
+ # sign = 1
39
+
40
+ # else:
41
+ # sign = -1
42
+
43
+ # self.ada_aug_p += sign * n_pred / self.ada_aug_len
44
+ # self.ada_aug_p = min(1, max(0, self.ada_aug_p))
45
+ # self.ada_aug_buf.mul_(0)
46
+ # self.ada_update = 0
47
+
48
+ # return self.ada_aug_p
49
+
50
+
51
+ SYM6 = (
52
+ 0.015404109327027373,
53
+ 0.0034907120842174702,
54
+ -0.11799011114819057,
55
+ -0.048311742585633,
56
+ 0.4910559419267466,
57
+ 0.787641141030194,
58
+ 0.3379294217276218,
59
+ -0.07263752278646252,
60
+ -0.021060292512300564,
61
+ 0.04472490177066578,
62
+ 0.0017677118642428036,
63
+ -0.007800708325034148,
64
+ )
65
+
66
+
67
+ def translate_mat(t_x, t_y, device="cpu"):
68
+ batch = t_x.shape[0]
69
+
70
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
71
+ translate = torch.stack((t_x, t_y), 1)
72
+ mat[:, :2, 2] = translate
73
+
74
+ return mat
75
+
76
+
77
+ def rotate_mat(theta, device="cpu"):
78
+ batch = theta.shape[0]
79
+
80
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
81
+ sin_t = torch.sin(theta)
82
+ cos_t = torch.cos(theta)
83
+ rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
84
+ mat[:, :2, :2] = rot
85
+
86
+ return mat
87
+
88
+
89
+ def scale_mat(s_x, s_y, device="cpu"):
90
+ batch = s_x.shape[0]
91
+
92
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
93
+ mat[:, 0, 0] = s_x
94
+ mat[:, 1, 1] = s_y
95
+
96
+ return mat
97
+
98
+
99
+ def translate3d_mat(t_x, t_y, t_z):
100
+ batch = t_x.shape[0]
101
+
102
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
103
+ translate = torch.stack((t_x, t_y, t_z), 1)
104
+ mat[:, :3, 3] = translate
105
+
106
+ return mat
107
+
108
+
109
+ def rotate3d_mat(axis, theta):
110
+ batch = theta.shape[0]
111
+
112
+ u_x, u_y, u_z = axis
113
+
114
+ eye = torch.eye(3).unsqueeze(0)
115
+ cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
116
+ outer = torch.tensor(axis)
117
+ outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
118
+
119
+ sin_t = torch.sin(theta).view(-1, 1, 1)
120
+ cos_t = torch.cos(theta).view(-1, 1, 1)
121
+
122
+ rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
123
+
124
+ eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
125
+ eye_4[:, :3, :3] = rot
126
+
127
+ return eye_4
128
+
129
+
130
+ def scale3d_mat(s_x, s_y, s_z):
131
+ batch = s_x.shape[0]
132
+
133
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
134
+ mat[:, 0, 0] = s_x
135
+ mat[:, 1, 1] = s_y
136
+ mat[:, 2, 2] = s_z
137
+
138
+ return mat
139
+
140
+
141
+ def luma_flip_mat(axis, i):
142
+ batch = i.shape[0]
143
+
144
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
145
+ axis = torch.tensor(axis + (0,))
146
+ flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
147
+
148
+ return eye - flip
149
+
150
+
151
+ def saturation_mat(axis, i):
152
+ batch = i.shape[0]
153
+
154
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
155
+ axis = torch.tensor(axis + (0,))
156
+ axis = torch.ger(axis, axis)
157
+ saturate = axis + (eye - axis) * i.view(-1, 1, 1)
158
+
159
+ return saturate
160
+
161
+
162
+ def lognormal_sample(size, mean=0, std=1, device="cpu"):
163
+ return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
164
+
165
+
166
+ def category_sample(size, categories, device="cpu"):
167
+ category = torch.tensor(categories, device=device)
168
+ sample = torch.randint(high=len(categories), size=(size,), device=device)
169
+
170
+ return category[sample]
171
+
172
+
173
+ def uniform_sample(size, low, high, device="cpu"):
174
+ return torch.empty(size, device=device).uniform_(low, high)
175
+
176
+
177
+ def normal_sample(size, mean=0, std=1, device="cpu"):
178
+ return torch.empty(size, device=device).normal_(mean, std)
179
+
180
+
181
+ def bernoulli_sample(size, p, device="cpu"):
182
+ return torch.empty(size, device=device).bernoulli_(p)
183
+
184
+
185
+ def random_mat_apply(p, transform, prev, eye, device="cpu"):
186
+ size = transform.shape[0]
187
+ select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
188
+ select_transform = select * transform + (1 - select) * eye
189
+
190
+ return select_transform @ prev
191
+
192
+
193
+ def sample_affine(p, size, height, width, device="cpu"):
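+ # Builds a random affine augmentation matrix by composing flips, quarter-turn rotations,
+ # integer and fractional translations, isotropic/anisotropic scalings and free rotations,
+ # each applied with probability p (the free rotations with p_rot = 1 - sqrt(1 - p)).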
194
+ G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
195
+ eye = G
196
+
197
+ # flip
198
+ param = category_sample(size, (0, 1))
199
+ Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
200
+ G = random_mat_apply(p, Gc, G, eye, device=device)
201
+ # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
202
+
203
+ # 90 rotate
204
+ param = category_sample(size, (0, 3))
205
+ Gc = rotate_mat(-math.pi / 2 * param, device=device)
206
+ G = random_mat_apply(p, Gc, G, eye, device=device)
207
+ # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
208
+
209
+ # integer translate
210
+ param = uniform_sample((2, size), -0.125, 0.125)
211
+ param_height = torch.round(param[0] * height)
212
+ param_width = torch.round(param[1] * width)
213
+ Gc = translate_mat(param_width, param_height, device=device)
214
+ G = random_mat_apply(p, Gc, G, eye, device=device)
215
+ # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
216
+
217
+ # isotropic scale
218
+ param = lognormal_sample(size, std=0.2 * math.log(2))
219
+ Gc = scale_mat(param, param, device=device)
220
+ G = random_mat_apply(p, Gc, G, eye, device=device)
221
+ # print('isotropic scale', G, scale_mat(param, param), sep='\n')
222
+
223
+ p_rot = 1 - math.sqrt(1 - p)
224
+
225
+ # pre-rotate
226
+ param = uniform_sample(size, -math.pi, math.pi)
227
+ Gc = rotate_mat(-param, device=device)
228
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
229
+ # print('pre-rotate', G, rotate_mat(-param), sep='\n')
230
+
231
+ # anisotropic scale
232
+ param = lognormal_sample(size, std=0.2 * math.log(2))
233
+ Gc = scale_mat(param, 1 / param, device=device)
234
+ G = random_mat_apply(p, Gc, G, eye, device=device)
235
+ # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
236
+
237
+ # post-rotate
238
+ param = uniform_sample(size, -math.pi, math.pi)
239
+ Gc = rotate_mat(-param, device=device)
240
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
241
+ # print('post-rotate', G, rotate_mat(-param), sep='\n')
242
+
243
+ # fractional translate
244
+ param = normal_sample((2, size), std=0.125)
245
+ Gc = translate_mat(param[1] * width, param[0] * height, device=device)
246
+ G = random_mat_apply(p, Gc, G, eye, device=device)
247
+ # print('fractional translate', G, translate_mat(param, param), sep='\n')
248
+
249
+ return G
250
+
251
+
252
+ def sample_color(p, size):
253
+ C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
254
+ eye = C
255
+ axis_val = 1 / math.sqrt(3)
256
+ axis = (axis_val, axis_val, axis_val)
257
+
258
+ # brightness
259
+ param = normal_sample(size, std=0.2)
260
+ Cc = translate3d_mat(param, param, param)
261
+ C = random_mat_apply(p, Cc, C, eye)
262
+
263
+ # contrast
264
+ param = lognormal_sample(size, std=0.5 * math.log(2))
265
+ Cc = scale3d_mat(param, param, param)
266
+ C = random_mat_apply(p, Cc, C, eye)
267
+
268
+ # luma flip
269
+ param = category_sample(size, (0, 1))
270
+ Cc = luma_flip_mat(axis, param)
271
+ C = random_mat_apply(p, Cc, C, eye)
272
+
273
+ # hue rotation
274
+ param = uniform_sample(size, -math.pi, math.pi)
275
+ Cc = rotate3d_mat(axis, param)
276
+ C = random_mat_apply(p, Cc, C, eye)
277
+
278
+ # saturation
279
+ param = lognormal_sample(size, std=1 * math.log(2))
280
+ Cc = saturation_mat(axis, param)
281
+ C = random_mat_apply(p, Cc, C, eye)
282
+
283
+ return C
284
+
285
+
286
+ def make_grid(shape, x0, x1, y0, y1, device):
287
+ n, c, h, w = shape
288
+ grid = torch.empty(n, h, w, 3, device=device)
289
+ grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
290
+ grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
291
+ grid[:, :, :, 2] = 1
292
+
293
+ return grid
294
+
295
+
296
+ def affine_grid(grid, mat):
297
+ n, h, w, _ = grid.shape
298
+ return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
299
+
300
+
301
+ def get_padding(G, height, width, kernel_size):
302
+ device = G.device
303
+
304
+ cx = (width - 1) / 2
305
+ cy = (height - 1) / 2
306
+ cp = torch.tensor(
307
+ [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
308
+ )
309
+ cp = G @ cp.T
310
+
311
+ pad_k = kernel_size // 4
312
+
313
+ pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
314
+ pad = torch.cat((-pad, pad)).max(1).values
315
+ pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
316
+ pad = pad.max(torch.tensor([0.0, 0.0] * 2, device=device))
317
+ pad = pad.min(torch.tensor([width - 1.0, height - 1.0] * 2, device=device))
318
+
319
+ pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
320
+
321
+ return pad_x1, pad_x2, pad_y1, pad_y2
322
+
323
+
324
+ def try_sample_affine_and_pad(img, p, kernel_size, G=None):
325
+ batch, _, height, width = img.shape
326
+
327
+ G_try = G
328
+
329
+ if G is None:
330
+ G_try = torch.inverse(sample_affine(p, batch, height, width))
331
+
332
+ pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
333
+
334
+ img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
335
+
336
+ return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
337
+
338
+
339
+ class GridSampleForward(autograd.Function):
340
+ @staticmethod
341
+ def forward(ctx, input, grid):
342
+ out = F.grid_sample(
343
+ input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
344
+ )
345
+ ctx.save_for_backward(input, grid)
346
+
347
+ return out
348
+
349
+ @staticmethod
350
+ def backward(ctx, grad_output):
351
+ input, grid = ctx.saved_tensors
352
+ grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
353
+
354
+ return grad_input, grad_grid
355
+
356
+
357
+ class GridSampleBackward(autograd.Function):
358
+ @staticmethod
359
+ def forward(ctx, grad_output, input, grid):
360
+ op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
361
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
362
+ ctx.save_for_backward(grid)
363
+
364
+ return grad_input, grad_grid
365
+
366
+ @staticmethod
367
+ def backward(ctx, grad_grad_input, grad_grad_grid):
368
+ (grid,) = ctx.saved_tensors
369
+ grad_grad_output = None
370
+
371
+ if ctx.needs_input_grad[0]:
372
+ grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
373
+
374
+ return grad_grad_output, None, None
375
+
376
+
377
+ grid_sample = GridSampleForward.apply
378
+
379
+
380
+ def scale_mat_single(s_x, s_y):
381
+ return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
382
+
383
+
384
+ def translate_mat_single(t_x, t_y):
385
+ return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
386
+
387
+
388
+ def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
389
+ kernel = antialiasing_kernel
390
+ len_k = len(kernel)
391
+
392
+ kernel = torch.as_tensor(kernel).to(img)
393
+ # kernel = torch.ger(kernel, kernel).to(img)
394
+ kernel_flip = torch.flip(kernel, (0,))
395
+
396
+ img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
397
+ img, p, len_k, G
398
+ )
399
+
400
+ G_inv = (
401
+ translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
402
+ @ G
403
+ )
404
+ up_pad = (
405
+ (len_k + 2 - 1) // 2,
406
+ (len_k - 2) // 2,
407
+ (len_k + 2 - 1) // 2,
408
+ (len_k - 2) // 2,
409
+ )
410
+ img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
411
+ img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
412
+ G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
413
+ G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
414
+ batch_size, channel, height, width = img.shape
415
+ pad_k = len_k // 4
416
+ shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
417
+ G_inv = (
418
+ scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
419
+ @ G_inv
420
+ @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
421
+ )
422
+ grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
423
+ img_affine = grid_sample(img_2x, grid)
424
+ d_p = -pad_k * 2
425
+ down_pad = (
426
+ d_p + (len_k - 2 + 1) // 2,
427
+ d_p + (len_k - 2) // 2,
428
+ d_p + (len_k - 2 + 1) // 2,
429
+ d_p + (len_k - 2) // 2,
430
+ )
431
+ img_down = upfirdn2d(
432
+ img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
433
+ )
434
+ img_down = upfirdn2d(
435
+ img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
436
+ )
437
+
438
+ return img_down, G
439
+
440
+
441
+ def apply_color(img, mat):
442
+ batch = img.shape[0]
443
+ img = img.permute(0, 2, 3, 1)
444
+ mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
445
+ mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
446
+ img = img @ mat_mul + mat_add
447
+ img = img.permute(0, 3, 1, 2)
448
+
449
+ return img
450
+
451
+
452
+ def random_apply_color(img, p, C=None):
453
+ if C is None:
454
+ C = sample_color(p, img.shape[0])
455
+
456
+ img = apply_color(img, C.to(img))
457
+
458
+ return img, C
459
+
460
+
461
+ def augment(img, p, transform_matrix=(None, None)):
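+ # Entry point for the non-leaking (ADA-style) augmentation: a random geometric warp
+ # followed by a random color transform. Usage sketch: img_aug, (G, C) = augment(img, p=0.3);
+ # passing transform_matrix=(G, C) on a later call replays the exact same augmentation.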
462
+ img, G = random_apply_affine(img, p, transform_matrix[0])
463
+ img, C = random_apply_color(img, p, transform_matrix[1])
464
+
465
+ return img, (G, C)
models/modules/stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
models/modules/stylegan2/op/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (275 Bytes)
models/modules/stylegan2/op/__pycache__/conv2d_gradfix.cpython-38.pyc ADDED
Binary file (5.36 kB)
models/modules/stylegan2/op/__pycache__/fused_act.cpython-38.pyc ADDED
Binary file (3.29 kB)