YuxinJ committed on
Commit
5bfe353
1 Parent(s): 6662324
.gitattributes ADDED
@@ -0,0 +1,36 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 6.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ */**/__pycache__
0.png ADDED
1.jpg ADDED
2.png ADDED
3.png ADDED
4.jpg ADDED
5.png ADDED
6.jpg ADDED
7.png ADDED
8.png ADDED
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: Scenimefy
3
+ emoji: 🦀
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.41.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Scenimefy/data/__init__.py ADDED
@@ -0,0 +1,153 @@
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from the data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import importlib
14
+ import torch.utils.data
15
+ from Scenimefy.data.base_dataset import BaseDataset
16
+
17
+
18
+ def find_dataset_using_name(dataset_name):
19
+ """Import the module "data/[dataset_name]_dataset.py".
20
+
21
+ In the file, the class called DatasetNameDataset() will
22
+ be instantiated. It has to be a subclass of BaseDataset,
23
+ and it is case-insensitive.
24
+ """
25
+ dataset_filename = "Scenimefy.data." + dataset_name + "_dataset"
26
+ datasetlib = importlib.import_module(dataset_filename)
27
+
28
+ dataset = None
29
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30
+ for name, cls in datasetlib.__dict__.items():
31
+ if name.lower() == target_dataset_name.lower() \
32
+ and issubclass(cls, BaseDataset):
33
+ dataset = cls
34
+
35
+ if dataset is None:
36
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37
+
38
+ return dataset
39
+
40
+
41
+ def get_option_setter(dataset_name):
42
+ """Return the static method <modify_commandline_options> of the dataset class."""
43
+ dataset_class = find_dataset_using_name(dataset_name)
44
+ return dataset_class.modify_commandline_options
45
+
46
+
47
+ def create_dataset(opt):
48
+ """Create a dataset given the option.
49
+
50
+ This function wraps the class CustomDatasetDataLoader.
51
+ This is the main interface between this package and 'train.py'/'test.py'
52
+
53
+ Example:
54
+ >>> from data import create_dataset
55
+ >>> dataset = create_dataset(opt)
56
+ """
57
+ data_loader = CustomDatasetDataLoader(opt)
58
+ dataset = data_loader.load_data()
59
+ return dataset
60
+
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ print("dataset [%s] was created" % type(self.dataset).__name__)
75
+ self.dataloader = torch.utils.data.DataLoader(
76
+ self.dataset,
77
+ batch_size=opt.batch_size,
78
+ shuffle=not opt.serial_batches,
79
+ num_workers=int(opt.num_threads),
80
+ drop_last=True if opt.isTrain else False,
81
+ )
82
+
83
+ def set_epoch(self, epoch):
84
+ self.dataset.current_epoch = epoch
85
+
86
+ def load_data(self):
87
+ return self
88
+
89
+ def __len__(self):
90
+ """Return the number of data in the dataset"""
91
+ return min(len(self.dataset), self.opt.max_dataset_size)
92
+
93
+ def __iter__(self):
94
+ """Return a batch of data"""
95
+ for i, data in enumerate(self.dataloader):
96
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
97
+ break
98
+ yield data
99
+
100
+
101
+ # TODO: paired dataset support (currently duplicates CustomDatasetDataLoader; refactor later)
102
+ def create_paired_dataset(opt):
103
+ """Create a dataset given the option.
104
+
105
+ This function wraps the class CustomDatasetDataLoader.
106
+ This is the main interface between this package and 'train.py'/'test.py'
107
+
108
+ Example:
109
+ >>> from data import create_dataset
110
+ >>> dataset = create_dataset(opt)
111
+ """
112
+ data_loader = CustomPairedDatasetDataLoader(opt)
113
+ dataset = data_loader.load_data()
114
+ return dataset
115
+
116
+
117
+ class CustomPairedDatasetDataLoader():
118
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
119
+
120
+ def __init__(self, opt):
121
+ """Initialize this class
122
+
123
+ Step 1: create a dataset instance given the name [dataset_mode]
124
+ Step 2: create a multi-threaded data loader.
125
+ """
126
+ self.opt = opt
127
+ dataset_class = find_dataset_using_name(opt.paired_dataset_mode)
128
+ self.dataset = dataset_class(opt)
129
+ print("dataset [%s] was created" % type(self.dataset).__name__)
130
+ self.dataloader = torch.utils.data.DataLoader(
131
+ self.dataset,
132
+ batch_size=opt.batch_size,
133
+ shuffle=not opt.serial_batches,
134
+ num_workers=int(opt.num_threads),
135
+ drop_last=True if opt.isTrain else False,
136
+ )
137
+
138
+ def set_epoch(self, epoch):
139
+ self.dataset.current_epoch = epoch
140
+
141
+ def load_data(self):
142
+ return self
143
+
144
+ def __len__(self):
145
+ """Return the number of data in the dataset"""
146
+ return min(len(self.dataset), self.opt.max_dataset_size)
147
+
148
+ def __iter__(self):
149
+ """Return a batch of data"""
150
+ for i, data in enumerate(self.dataloader):
151
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
152
+ break
153
+ yield data
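The module docstring above describes a file-name convention: a `--dataset_mode dummy` flag resolves to `Scenimefy/data/dummy_dataset.py` and a `DummyDataset` subclass of `BaseDataset`. A minimal sketch of such a file, following the four required methods (the `dummy` dataset and its contents are hypothetical, not part of this commit):

```python
# Hypothetical Scenimefy/data/dummy_dataset.py -- illustrates the naming convention
# that find_dataset_using_name() relies on; not part of this commit.
from PIL import Image

from Scenimefy.data.base_dataset import BaseDataset, get_transform
from Scenimefy.data.image_folder import make_dataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # optionally add dataset-specific flags here
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path).convert('RGB')
        return {'A': self.transform(img), 'A_paths': path}
```

With this file in place, `create_dataset(opt)` with `opt.dataset_mode == 'dummy'` imports `Scenimefy.data.dummy_dataset`, finds `DummyDataset` by its lowercased name, and wraps it in `CustomDatasetDataLoader`.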
Scenimefy/data/base_dataset.py ADDED
@@ -0,0 +1,230 @@
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ self.root = opt.dataroot
31
+ self.current_epoch = 0
32
+
33
+ @staticmethod
34
+ def modify_commandline_options(parser, is_train):
35
+ """Add new dataset-specific options, and rewrite default values for existing options.
36
+
37
+ Parameters:
38
+ parser -- original option parser
39
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
+
41
+ Returns:
42
+ the modified parser.
43
+ """
44
+ return parser
45
+
46
+ @abstractmethod
47
+ def __len__(self):
48
+ """Return the total number of images in the dataset."""
49
+ return 0
50
+
51
+ @abstractmethod
52
+ def __getitem__(self, index):
53
+ """Return a data point and its metadata information.
54
+
55
+ Parameters:
56
+ index - - a random integer for data indexing
57
+
58
+ Returns:
59
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
60
+ """
61
+ pass
62
+
63
+
64
+ def get_params(opt, size):
65
+ w, h = size
66
+ new_h = h
67
+ new_w = w
68
+ if opt.preprocess == 'resize_and_crop':
69
+ new_h = new_w = opt.load_size
70
+ elif opt.preprocess == 'scale_width_and_crop':
71
+ new_w = opt.load_size
72
+ new_h = opt.load_size * h // w
73
+
74
+ x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75
+ y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76
+
77
+ flip = random.random() > 0.5
78
+
79
+ return {'crop_pos': (x, y), 'flip': flip}
80
+
81
+
82
+ def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83
+ transform_list = []
84
+ if grayscale:
85
+ transform_list.append(transforms.Grayscale(1))
86
+ if 'fixsize' in opt.preprocess:
87
+ transform_list.append(transforms.Resize(params["size"], method))
88
+ if 'resize' in opt.preprocess:
89
+ osize = [opt.load_size, opt.load_size]
90
+ if "gta2cityscapes" in opt.dataroot:
91
+ osize[0] = opt.load_size // 2
92
+ transform_list.append(transforms.Resize(osize, method))
93
+ elif 'scale_width' in opt.preprocess:
94
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
95
+ elif 'scale_shortside' in opt.preprocess:
96
+ transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))
97
+
98
+ if 'zoom' in opt.preprocess:
99
+ if params is None:
100
+ transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
101
+ else:
102
+ transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))
103
+
104
+ if 'crop' in opt.preprocess:
105
+ if params is None or 'crop_pos' not in params:
106
+ transform_list.append(transforms.RandomCrop(opt.crop_size))
107
+ else:
108
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
109
+
110
+ if 'patch' in opt.preprocess:
111
+ transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))
112
+
113
+ if 'trim' in opt.preprocess:
114
+ transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))
115
+
116
+ # if opt.preprocess == 'none':
117
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
118
+
119
+ if not opt.no_flip:
120
+ if params is None or 'flip' not in params:
121
+ transform_list.append(transforms.RandomHorizontalFlip())
122
+ elif 'flip' in params:
123
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
124
+
125
+ if convert:
126
+ transform_list += [transforms.ToTensor()]
127
+ if grayscale:
128
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
129
+ else:
130
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
131
+ return transforms.Compose(transform_list)
132
+
133
+
134
+ def __make_power_2(img, base, method=Image.BICUBIC):
135
+ ow, oh = img.size
136
+ h = int(round(oh / base) * base)
137
+ w = int(round(ow / base) * base)
138
+ if h == oh and w == ow:
139
+ return img
140
+
141
+ return img.resize((w, h), method)
142
+
143
+
144
+ def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
145
+ if factor is None:
146
+ zoom_level = np.random.uniform(0.8, 1.0, size=[2])
147
+ else:
148
+ zoom_level = (factor[0], factor[1])
149
+ iw, ih = img.size
150
+ zoomw = max(crop_width, iw * zoom_level[0])
151
+ zoomh = max(crop_width, ih * zoom_level[1])
152
+ img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
153
+ return img
154
+
155
+
156
+ def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
157
+ ow, oh = img.size
158
+ shortside = min(ow, oh)
159
+ if shortside >= target_width:
160
+ return img
161
+ else:
162
+ scale = target_width / shortside
163
+ return img.resize((round(ow * scale), round(oh * scale)), method)
164
+
165
+
166
+ def __trim(img, trim_width):
167
+ ow, oh = img.size
168
+ if ow > trim_width:
169
+ xstart = np.random.randint(ow - trim_width)
170
+ xend = xstart + trim_width
171
+ else:
172
+ xstart = 0
173
+ xend = ow
174
+ if oh > trim_width:
175
+ ystart = np.random.randint(oh - trim_width)
176
+ yend = ystart + trim_width
177
+ else:
178
+ ystart = 0
179
+ yend = oh
180
+ return img.crop((xstart, ystart, xend, yend))
181
+
182
+
183
+ def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
184
+ ow, oh = img.size
185
+ if ow == target_width and oh >= crop_width:
186
+ return img
187
+ w = target_width
188
+ h = int(max(target_width * oh / ow, crop_width))
189
+ return img.resize((w, h), method)
190
+
191
+
192
+ def __crop(img, pos, size):
193
+ ow, oh = img.size
194
+ x1, y1 = pos
195
+ tw = th = size
196
+ if (ow > tw or oh > th):
197
+ return img.crop((x1, y1, x1 + tw, y1 + th))
198
+ return img
199
+
200
+
201
+ def __patch(img, index, size):
202
+ ow, oh = img.size
203
+ nw, nh = ow // size, oh // size
204
+ roomx = ow - nw * size
205
+ roomy = oh - nh * size
206
+ startx = np.random.randint(int(roomx) + 1)
207
+ starty = np.random.randint(int(roomy) + 1)
208
+
209
+ index = index % (nw * nh)
210
+ ix = index // nh
211
+ iy = index % nh
212
+ gridx = startx + ix * size
213
+ gridy = starty + iy * size
214
+ return img.crop((gridx, gridy, gridx + size, gridy + size))
215
+
216
+
217
+ def __flip(img, flip):
218
+ if flip:
219
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
220
+ return img
221
+
222
+
223
+ def __print_size_warning(ow, oh, w, h):
224
+ """Print warning information about image size(only print once)"""
225
+ if not hasattr(__print_size_warning, 'has_printed'):
226
+ print("The image size needs to be a multiple of 4. "
227
+ "The loaded image size was (%d, %d), so it was adjusted to "
228
+ "(%d, %d). This adjustment will be done to all images "
229
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
230
+ __print_size_warning.has_printed = True
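The `get_params`/`get_transform` pair above is written so that two images can share one random draw of crop position and flip. A small sketch of that intended call pattern (the option values are illustrative assumptions; the image names reuse files added in this commit):

```python
# Sketch: applying the same randomized crop/flip to two images.
# The Namespace only carries the flags read by get_params()/get_transform().
from argparse import Namespace

from PIL import Image

from Scenimefy.data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256,
                no_flip=False, dataroot='./datasets/example')  # illustrative values

img_A = Image.open('0.png').convert('RGB')
img_B = Image.open('2.png').convert('RGB')

params = get_params(opt, img_A.size)           # one draw of crop_pos and flip
transform = get_transform(opt, params=params)  # both images reuse that draw
tensor_A = transform(img_A)                    # 3 x 256 x 256, scaled to [-1, 1]
tensor_B = transform(img_B)
```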
Scenimefy/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
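`make_dataset` walks a directory tree (following symlinks) and keeps only files whose extensions appear in `IMG_EXTENSIONS`. A brief illustrative call (the folder path is a placeholder):

```python
# Sketch: collecting image paths recursively with make_dataset().
from Scenimefy.data.image_folder import is_image_file, make_dataset

paths = make_dataset('./datasets/example/trainA', max_dataset_size=100)  # hypothetical folder
print('%d images found' % len(paths))
assert all(is_image_file(p) for p in paths)
```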
Scenimefy/data/unaligned_dataset.py ADDED
@@ -0,0 +1,79 @@
1
+ import os.path
2
+ from Scenimefy.data.base_dataset import BaseDataset, get_transform
3
+ from Scenimefy.data.image_folder import make_dataset
4
+ from PIL import Image
5
+ import random
6
+ import Scenimefy.utils.util as util
7
+
8
+
9
+ class UnalignedDataset(BaseDataset):
10
+ """
11
+ This dataset class can load unaligned/unpaired datasets.
12
+
13
+ It requires two directories to host training images from domain A '/path/to/data/trainA'
14
+ and from domain B '/path/to/data/trainB' respectively.
15
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
16
+ Similarly, you need to prepare two directories:
17
+ '/path/to/data/testA' and '/path/to/data/testB' during test time.
18
+ """
19
+
20
+ def __init__(self, opt):
21
+ """Initialize this dataset class.
22
+
23
+ Parameters:
24
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
25
+ """
26
+ BaseDataset.__init__(self, opt)
27
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
28
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
29
+
30
+ if opt.phase == "test" and not os.path.exists(self.dir_A) \
31
+ and os.path.exists(os.path.join(opt.dataroot, "valA")):
32
+ self.dir_A = os.path.join(opt.dataroot, "valA")
33
+ self.dir_B = os.path.join(opt.dataroot, "valB")
34
+
35
+ self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
36
+ self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
37
+ self.A_size = len(self.A_paths) # get the size of dataset A
38
+ self.B_size = len(self.B_paths) # get the size of dataset B
39
+
40
+ def __getitem__(self, index):
41
+ """Return a data point and its metadata information.
42
+
43
+ Parameters:
44
+ index (int) -- a random integer for data indexing
45
+
46
+ Returns a dictionary that contains A, B, A_paths and B_paths
47
+ A (tensor) -- an image in the input domain
48
+ B (tensor) -- its corresponding image in the target domain
49
+ A_paths (str) -- image paths
50
+ B_paths (str) -- image paths
51
+ """
52
+ A_path = self.A_paths[index % self.A_size] # make sure index is within the range
53
+ if self.opt.serial_batches: # use a fixed, ordered index for domain B
54
+ index_B = index % self.B_size
55
+ else: # randomize the index for domain B to avoid fixed pairs.
56
+ index_B = random.randint(0, self.B_size - 1)
57
+ B_path = self.B_paths[index_B]
58
+ A_img = Image.open(A_path).convert('RGB')
59
+ B_img = Image.open(B_path).convert('RGB')
60
+
61
+ # Apply image transformation
62
+ # For FastCUT mode, if in finetuning phase (learning rate is decaying),
63
+ # do not perform resize-crop data augmentation of CycleGAN.
64
+ # print('current_epoch', self.current_epoch)
65
+ is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
66
+ modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
67
+ transform = get_transform(modified_opt)
68
+ A = transform(A_img)
69
+ B = transform(B_img)
70
+
71
+ return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images in the dataset.
75
+
76
+ As the two datasets may contain different numbers of images,
77
+ we take the maximum of the two sizes.
78
+ """
79
+ return max(self.A_size, self.B_size)
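A sketch of driving this dataset through `create_dataset` from `Scenimefy/data/__init__.py`. The `Namespace` lists only the flags this class and `CustomDatasetDataLoader` actually read; in the real scripts they come from the option parser, and the data root is a placeholder:

```python
# Sketch: iterating UnalignedDataset via the create_dataset() factory.
from argparse import Namespace

from Scenimefy.data import create_dataset

opt = Namespace(dataset_mode='unaligned', dataroot='/path/to/data', phase='test',
                max_dataset_size=float('inf'), serial_batches=True, batch_size=1,
                num_threads=0, isTrain=False, preprocess='resize_and_crop',
                load_size=286, crop_size=256, no_flip=True)  # illustrative values

dataset = create_dataset(opt)          # expects /path/to/data/testA and testB
for batch in dataset:                  # dict with 'A', 'B', 'A_paths', 'B_paths'
    real_A, real_B = batch['A'], batch['B']
    break
```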
Scenimefy/models/SRC.py ADDED
@@ -0,0 +1,79 @@
1
+ from packaging import version
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class Normalize(nn.Module):
7
+
8
+ def __init__(self, power=2):
9
+ super(Normalize, self).__init__()
10
+ self.power = power
11
+
12
+ def forward(self, x):
13
+ norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
14
+ out = x.div(norm + 1e-7)
15
+ return out
16
+
17
+ class SRC_Loss(nn.Module):
18
+ def __init__(self, opt):
19
+ super().__init__()
20
+ self.opt = opt
21
+ self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
22
+
23
+ def forward(self, feat_q, feat_k, only_weight=False, epoch=None):
24
+ '''
25
+ :param feat_q: target
26
+ :param feat_k: source
27
+ :return: SRC loss, weights for hDCE
28
+ '''
29
+
30
+ batchSize = feat_q.shape[0]
31
+ dim = feat_q.shape[1]
32
+ feat_k = feat_k.detach()
33
+ batch_dim_for_bmm = 1 # self.opt.batch_size
34
+ feat_k = Normalize()(feat_k)
35
+ feat_q = Normalize()(feat_q)
36
+
37
+ ## SRC
38
+ feat_q_v = feat_q.view(batch_dim_for_bmm, -1, dim)
39
+ feat_k_v = feat_k.view(batch_dim_for_bmm, -1, dim)
40
+
41
+ spatial_q = torch.bmm(feat_q_v, feat_q_v.transpose(2, 1))
42
+ spatial_k = torch.bmm(feat_k_v, feat_k_v.transpose(2, 1))
43
+
44
+ weight_seed = spatial_k.clone().detach()
45
+ diagonal = torch.eye(self.opt.num_patches, device=feat_k_v.device, dtype=self.mask_dtype)[None, :, :]
46
+
47
+ HDCE_gamma = self.opt.HDCE_gamma
48
+ if self.opt.use_curriculum:
49
+ HDCE_gamma = HDCE_gamma + (self.opt.HDCE_gamma_min - HDCE_gamma) * (epoch) / (self.opt.n_epochs + self.opt.n_epochs_decay)
50
+ if (self.opt.step_gamma)&(epoch>self.opt.step_gamma_epoch):
51
+ HDCE_gamma = 1
52
+
53
+
54
+ ## weights by semantic relation
55
+ weight_seed.masked_fill_(diagonal, -10.0)
56
+ weight_out = nn.Softmax(dim=2)(weight_seed.clone() / HDCE_gamma).detach()
57
+ wmax_out, _ = torch.max(weight_out, dim=2, keepdim=True)
58
+ weight_out /= wmax_out
59
+
60
+ if only_weight:
61
+ return 0, weight_out
62
+
63
+ spatial_q = nn.Softmax(dim=1)(spatial_q)
64
+ spatial_k = nn.Softmax(dim=1)(spatial_k).detach()
65
+
66
+ loss_src = self.get_jsd(spatial_q, spatial_k)
67
+
68
+ return loss_src, weight_out
69
+
70
+ def get_jsd(self, p1, p2):
71
+ '''
72
+ :param p1: n X C
73
+ :param p2: n X C
74
+ :return: n X 1
75
+ '''
76
+ m = 0.5 * (p1 + p2)
77
+ out = 0.5 * (nn.KLDivLoss(reduction='sum', log_target=True)(torch.log(m), torch.log(p1))
78
+ + nn.KLDivLoss(reduction='sum', log_target=True)(torch.log(m), torch.log(p2)))
79
+ return out
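`SRC_Loss` compares the patchwise self-similarity maps of the query and key features (a Jensen-Shannon divergence between their softmax-normalized Gram matrices) and also returns the softmax weights that `PatchHDCELoss` later consumes. A toy invocation on random features (option values are illustrative; the `num_patches x dim` shape matches what the patch sampler produces):

```python
# Sketch: SRC_Loss on random patch features of shape (num_patches, dim).
from argparse import Namespace

import torch

from Scenimefy.models.SRC import SRC_Loss

opt = Namespace(num_patches=256, HDCE_gamma=1.0, use_curriculum=False,
                step_gamma=False, step_gamma_epoch=200,
                n_epochs=100, n_epochs_decay=100)  # illustrative values

feat_q = torch.randn(256, 128)   # patches sampled from the translated image
feat_k = torch.randn(256, 128)   # patches sampled from the source image

loss_src, weight = SRC_Loss(opt)(feat_q, feat_k, only_weight=False, epoch=0)
print(float(loss_src), weight.shape)   # weight: torch.Size([1, 256, 256])
```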
Scenimefy/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from Scenimefy.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "Scenimefy.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(1)  # non-zero exit code: the requested model class was not found
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function instantiates the model class found by <find_model_using_name>.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
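The factory above mirrors the dataset package: `--model dummy` resolves to `Scenimefy/models/dummy_model.py` and a `DummyModel` subclass of `BaseModel`. A hedged sketch of such a file (the `dummy` model is hypothetical, not part of this commit):

```python
# Hypothetical Scenimefy/models/dummy_model.py -- shows the class-name convention
# find_model_using_name() expects; not part of this commit.
from Scenimefy.models.base_model import BaseModel


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names, self.model_names, self.visual_names = [], [], []

    def set_input(self, input):
        self.image_paths = input.get('A_paths', [])

    def forward(self):
        pass

    def optimize_parameters(self):
        pass
```

With this file in place, `create_model(opt)` with `opt.model == 'dummy'` imports `Scenimefy.models.dummy_model`, finds `DummyModel` by its lowercased name, and instantiates it.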
Scenimefy/models/base_model.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from Scenimefy.models import networks
6
+
7
+
8
+ class BaseModel(ABC):
9
+ """This class is an abstract base class (ABC) for models.
10
+ To create a subclass, you need to implement the following five functions:
11
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12
+ -- <set_input>: unpack data from dataset and apply preprocessing.
13
+ -- <forward>: produce intermediate results.
14
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16
+ """
17
+
18
+ def __init__(self, opt):
19
+ """Initialize the BaseModel class.
20
+
21
+ Parameters:
22
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23
+
24
+ When creating your custom class, you need to implement your own initialization.
25
+ In this function, you should first call <BaseModel.__init__(self, opt)>
26
+ Then, you need to define four lists:
27
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
28
+ -- self.model_names (str list): specify the images that you want to display and save.
29
+ -- self.visual_names (str list): define networks used in our training.
30
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31
+ """
32
+ self.opt = opt
33
+ self.gpu_ids = opt.gpu_ids
34
+ self.isTrain = opt.isTrain
35
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38
+ torch.backends.cudnn.benchmark = True
39
+ self.loss_names = []
40
+ self.model_names = []
41
+ self.visual_names = []
42
+ self.optimizers = []
43
+ self.image_paths = []
44
+ self.metric = 0 # used for learning rate policy 'plateau'
45
+
46
+ @staticmethod
47
+ def dict_grad_hook_factory(add_func=lambda x: x):
48
+ saved_dict = dict()
49
+
50
+ def hook_gen(name):
51
+ def grad_hook(grad):
52
+ saved_vals = add_func(grad)
53
+ saved_dict[name] = saved_vals
54
+ return grad_hook
55
+ return hook_gen, saved_dict
56
+
57
+ @staticmethod
58
+ def modify_commandline_options(parser, is_train):
59
+ """Add new model-specific options, and rewrite default values for existing options.
60
+
61
+ Parameters:
62
+ parser -- original option parser
63
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
64
+
65
+ Returns:
66
+ the modified parser.
67
+ """
68
+ return parser
69
+
70
+ @abstractmethod
71
+ def set_input(self, input):
72
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
73
+
74
+ Parameters:
75
+ input (dict): includes the data itself and its metadata information.
76
+ """
77
+ pass
78
+
79
+ @abstractmethod
80
+ def forward(self):
81
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
82
+ pass
83
+
84
+ @abstractmethod
85
+ def optimize_parameters(self):
86
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
87
+ pass
88
+
89
+ def setup(self, opt):
90
+ """Load and print networks; create schedulers
91
+
92
+ Parameters:
93
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
94
+ """
95
+ if self.isTrain:
96
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
97
+ if not self.isTrain or opt.continue_train:
98
+ load_suffix = opt.epoch
99
+ self.load_networks(load_suffix)
100
+
101
+ self.print_networks(opt.verbose)
102
+
103
+ def parallelize(self):
104
+ for name in self.model_names:
105
+ if isinstance(name, str):
106
+ net = getattr(self, 'net' + name)
107
+ setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))
108
+
109
+ def data_dependent_initialize(self, data):
110
+ pass
111
+
112
+ def eval(self):
113
+ """Make models eval mode during test time"""
114
+ for name in self.model_names:
115
+ if isinstance(name, str):
116
+ net = getattr(self, 'net' + name)
117
+ net.eval()
118
+
119
+ def test(self):
120
+ """Forward function used in test time.
121
+
122
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
123
+ It also calls <compute_visuals> to produce additional visualization results
124
+ """
125
+ with torch.no_grad():
126
+ self.forward()
127
+ self.compute_visuals()
128
+
129
+ def compute_visuals(self):
130
+ """Calculate additional output images for visdom and HTML visualization"""
131
+ pass
132
+
133
+ def get_image_paths(self):
134
+ """ Return image paths that are used to load current data"""
135
+ return self.image_paths
136
+
137
+ def update_learning_rate(self):
138
+ """Update learning rates for all the networks; called at the end of every epoch"""
139
+ for scheduler in self.schedulers:
140
+ if self.opt.lr_policy == 'plateau':
141
+ scheduler.step(self.metric)
142
+ else:
143
+ scheduler.step()
144
+
145
+ lr = self.optimizers[0].param_groups[0]['lr']
146
+ print('learning rate = %.7f' % lr)
147
+
148
+ def get_current_visuals(self):
149
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
150
+ visual_ret = OrderedDict()
151
+ for name in self.visual_names:
152
+ if isinstance(name, str):
153
+ visual_ret[name] = getattr(self, name)
154
+ return visual_ret
155
+
156
+ def get_current_losses(self):
157
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
158
+ errors_ret = OrderedDict()
159
+ for name in self.loss_names:
160
+ if isinstance(name, str):
161
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
162
+ return errors_ret
163
+
164
+ def save_networks(self, epoch):
165
+ """Save all the networks to the disk.
166
+
167
+ Parameters:
168
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
169
+ """
170
+ for name in self.model_names:
171
+ if isinstance(name, str):
172
+ save_filename = '%s_net_%s.pth' % (epoch, name)
173
+ save_path = os.path.join(self.save_dir, save_filename)
174
+ net = getattr(self, 'net' + name)
175
+
176
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
177
+ torch.save(net.module.cpu().state_dict(), save_path)
178
+ net.cuda(self.gpu_ids[0])
179
+ else:
180
+ torch.save(net.cpu().state_dict(), save_path)
181
+
182
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
183
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
184
+ key = keys[i]
185
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
186
+ if module.__class__.__name__.startswith('InstanceNorm') and \
187
+ (key == 'running_mean' or key == 'running_var'):
188
+ if getattr(module, key) is None:
189
+ state_dict.pop('.'.join(keys))
190
+ if module.__class__.__name__.startswith('InstanceNorm') and \
191
+ (key == 'num_batches_tracked'):
192
+ state_dict.pop('.'.join(keys))
193
+ else:
194
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
195
+
196
+ def load_networks(self, epoch):
197
+ """Load all the networks from the disk.
198
+
199
+ Parameters:
200
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
201
+ """
202
+ for name in self.model_names:
203
+ if isinstance(name, str):
204
+ load_filename = '%s_net_%s.pth' % (epoch, name)
205
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
206
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
207
+ else:
208
+ load_dir = self.save_dir
209
+
210
+ load_path = os.path.join(load_dir, load_filename)
211
+ net = getattr(self, 'net' + name)
212
+ if isinstance(net, torch.nn.DataParallel):
213
+ net = net.module
214
+ print('loading the model from %s' % load_path)
215
+ # if you are using PyTorch newer than 0.4 (e.g., built from
216
+ # GitHub source), you can remove str() on self.device
217
+ state_dict = torch.load(load_path, map_location=str(self.device))
218
+ if hasattr(state_dict, '_metadata'):
219
+ del state_dict._metadata
220
+
221
+ # patch InstanceNorm checkpoints prior to 0.4
222
+ # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
223
+ # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
224
+ net.load_state_dict(state_dict)
225
+
226
+ def print_networks(self, verbose):
227
+ """Print the total number of parameters in the network and (if verbose) network architecture
228
+
229
+ Parameters:
230
+ verbose (bool) -- if verbose: print the network architecture
231
+ """
232
+ print('---------- Networks initialized -------------')
233
+ for name in self.model_names:
234
+ if isinstance(name, str):
235
+ net = getattr(self, 'net' + name)
236
+ num_params = 0
237
+ for param in net.parameters():
238
+ num_params += param.numel()
239
+ if verbose:
240
+ print(net)
241
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
242
+ print('-----------------------------------------------')
243
+
244
+ def set_requires_grad(self, nets, requires_grad=False):
245
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
246
+ Parameters:
247
+ nets (network list) -- a list of networks
248
+ requires_grad (bool) -- whether the networks require gradients or not
249
+ """
250
+ if not isinstance(nets, list):
251
+ nets = [nets]
252
+ for net in nets:
253
+ if net is not None:
254
+ for param in net.parameters():
255
+ param.requires_grad = requires_grad
256
+
257
+ def generate_visuals_for_evaluation(self, data, mode):
258
+ return {}
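`set_requires_grad` is the switch the training loop uses to freeze the discriminator while the generator is updated, and vice versa. A standalone sketch of that freeze/unfreeze pattern on a throwaway module (the toy discriminator is illustrative only):

```python
# Sketch: the freeze/unfreeze pattern enabled by BaseModel.set_requires_grad(),
# reproduced on a throwaway module.
import torch.nn as nn

netD = nn.Sequential(nn.Conv2d(3, 64, 4, stride=2), nn.LeakyReLU(0.2),
                     nn.Conv2d(64, 1, 4))  # toy stand-in for the discriminator

def set_requires_grad(nets, requires_grad=False):
    # same logic as BaseModel.set_requires_grad, shown outside the class
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

set_requires_grad(netD, False)   # freeze D while the generator step runs
assert all(not p.requires_grad for p in netD.parameters())
set_requires_grad(netD, True)    # unfreeze D for its own update
```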
Scenimefy/models/cut_model.py ADDED
@@ -0,0 +1,370 @@
1
+ import numpy as np
2
+ import torch
3
+ from Scenimefy.models.base_model import BaseModel
4
+ from Scenimefy.models import networks
5
+ from Scenimefy.models.patchnce import PatchNCELoss
6
+ import Scenimefy.utils.util as util
7
+ from torch.distributions.beta import Beta
8
+ from torch.nn import functional as F
9
+ from Scenimefy.models.hDCE import PatchHDCELoss
10
+ from Scenimefy.models.SRC import SRC_Loss
11
+ import torch.nn as nn
12
+
13
+ import matplotlib.pyplot as plt  # needed by the debug plotting helpers show_*_r below
14
+ def show_np_r(array, min, max, num):
15
+ plt.figure(num)
16
+ plt.imshow(array, norm=None, cmap='gray', vmin= min, vmax=max)
17
+ plt.axis('off')
18
+ plt.show()
19
+
20
+ def show_hot_r(array, num):
21
+ plt.figure(num)
22
+ plt.imshow(array, norm=None, cmap='hot')
23
+ plt.axis('off')
24
+ plt.show()
25
+
26
+ def show_torch_rgb(array, min, max, num):
27
+ plt.figure(num)
28
+ plt.imshow(array.detach().cpu()[0].permute(1,2,0).numpy()*255, norm=None, cmap='gray', vmin= min, vmax=max)
29
+ plt.axis('off')
30
+ plt.show()
31
+
32
+
33
+ class Normalize(nn.Module):
34
+
35
+ def __init__(self, power=2):
36
+ super(Normalize, self).__init__()
37
+ self.power = power
38
+
39
+ def forward(self, x):
40
+ norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
41
+ out = x.div(norm + 1e-7)
42
+ return out
43
+
44
+ def get_lambda(alpha=1.0,size=None,device=None):
45
+ '''Return lambda'''
46
+ if alpha > 0.:
47
+ lam = np.random.beta(alpha, alpha)
48
+ # lam = Beta()
49
+ else:
50
+ lam = 1.
51
+ return lam
52
+ def get_spa_lambda(alpha=1.0,size=None,device=None):
53
+ '''Return lambda'''
54
+ if alpha > 0.:
55
+ lam = torch.from_numpy(np.random.beta(alpha, alpha,size=size)).float().to(device)
56
+ # lam = Beta()
57
+ else:
58
+ lam = 1.
59
+ return lam
60
+ class CUTModel(BaseModel):
61
+ """ This class implements CUT and FastCUT model, described in the paper
62
+ Contrastive Learning for Unpaired Image-to-Image Translation
63
+ Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
64
+ ECCV, 2020
65
+
66
+ The code borrows heavily from the PyTorch implementation of CycleGAN
67
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
68
+ """
69
+ @staticmethod
70
+ def modify_commandline_options(parser, is_train=True):
71
+ """ Configures options specific for CUT model
72
+ """
73
+ parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')
74
+
75
+ parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
76
+ parser.add_argument('--lambda_HDCE', type=float, default=1.0, help='weight for HDCE loss: HDCE(G(X), X)')
77
+ parser.add_argument('--lambda_SRC', type=float, default=1.0, help='weight for SRC loss: SRC(G(X), X)')
78
+ parser.add_argument('--dce_idt', action='store_true')
79
+ parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
80
+ parser.add_argument('--nce_includes_all_negatives_from_minibatch',
81
+ type=util.str2bool, nargs='?', const=True, default=False,
82
+ help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
83
+ parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
84
+ parser.add_argument('--netF_nc', type=int, default=256)
85
+ parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
86
+ parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
87
+ parser.add_argument('--flip_equivariance',
88
+ type=util.str2bool, nargs='?', const=True, default=False,
89
+ help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
90
+ parser.add_argument('--alpha', type=float, default=0.2)
91
+ parser.add_argument('--use_curriculum', action='store_true')
92
+ parser.add_argument('--HDCE_gamma', type=float, default=1)
93
+ parser.add_argument('--HDCE_gamma_min', type=float, default=1)
94
+ parser.add_argument('--step_gamma', action='store_true')
95
+ parser.add_argument('--step_gamma_epoch', type=int, default=200)
96
+ parser.add_argument('--no_Hneg', action='store_true')
97
+
98
+ parser.set_defaults(pool_size=0) # no image pooling
99
+
100
+ opt, _ = parser.parse_known_args()
101
+
102
+ return parser
103
+
104
+ def __init__(self, opt):
105
+ BaseModel.__init__(self, opt)
106
+
107
+ self.train_epoch = None
108
+
109
+ # specify the training losses you want to print out.
110
+ # The training/test scripts will call <BaseModel.get_current_losses>
111
+ self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G']
112
+
113
+ if opt.lambda_HDCE > 0.0:
114
+ self.loss_names.append('HDCE')
115
+ if opt.dce_idt and self.isTrain:
116
+ self.loss_names += ['HDCE_Y']
117
+
118
+ if opt.lambda_SRC > 0.0:
119
+ self.loss_names.append('SRC')
120
+
121
+
122
+ self.visual_names = ['real_A', 'fake_B', 'real_B']
123
+ self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
124
+ self.alpha = opt.alpha
125
+ if opt.dce_idt and self.isTrain:
126
+ self.visual_names += ['idt_B']
127
+
128
+ if self.isTrain:
129
+ self.model_names = ['G', 'F', 'D']
130
+ else: # during test time, only load G
131
+ self.model_names = ['G']
132
+ # define networks (both generator and discriminator)
133
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
134
+ self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
135
+
136
+
137
+ if self.isTrain:
138
+ self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
139
+
140
+ # define loss functions
141
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
142
+ self.criterionNCE = []
143
+ self.criterionHDCE = []
144
+
145
+ for i, nce_layer in enumerate(self.nce_layers):
146
+ self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
147
+ self.criterionHDCE.append(PatchHDCELoss(opt=opt).to(self.device))
148
+
149
+ self.criterionIdt = torch.nn.L1Loss().to(self.device)
150
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
151
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
152
+ self.optimizers.append(self.optimizer_G)
153
+ self.optimizers.append(self.optimizer_D)
154
+
155
+ self.criterionR = []
156
+ for nce_layer in self.nce_layers:
157
+ self.criterionR.append(SRC_Loss(opt).to(self.device))
158
+
159
+
160
+ def data_dependent_initialize(self, data):
161
+ """
162
+ The feature network netF is defined in terms of the shape of the intermediate, extracted
163
+ features of the encoder portion of netG. Because of this, the weights of netF are
164
+ initialized at the first feedforward pass with some input images.
165
+ Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
166
+ """
167
+ self.set_input(data)
168
+ bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
169
+ self.real_A = self.real_A[:bs_per_gpu]
170
+ self.real_B = self.real_B[:bs_per_gpu]
171
+ self.forward() # compute fake images: G(A)
172
+ if self.opt.isTrain:
173
+ self.compute_D_loss().backward() # calculate gradients for D
174
+ self.compute_G_loss().backward() # calculate gradients for G
175
+ # if self.opt.lambda_NCE > 0.0:
176
+ # self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
177
+ # self.optimizers.append(self.optimizer_F)
178
+ #
179
+ # elif self.opt.lambda_HDCE > 0.0:
180
+ self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
181
+ self.optimizers.append(self.optimizer_F)
182
+
183
+
184
+ def optimize_parameters(self):
185
+ # forward
186
+ self.forward()
187
+
188
+ # update D
189
+ self.set_requires_grad(self.netD, True)
190
+ self.optimizer_D.zero_grad()
191
+ self.loss_D = self.compute_D_loss()
192
+ self.loss_D.backward()
193
+ self.optimizer_D.step()
194
+
195
+ # update G
196
+ self.set_requires_grad(self.netD, False)
197
+ self.optimizer_G.zero_grad()
198
+ if self.opt.netF == 'mlp_sample':
199
+ # if self.opt.lambda_NCE > 0.0:
200
+ # self.optimizer_F.zero_grad()
201
+ # elif self.opt.lambda_HDCE > 0.0:
202
+ self.optimizer_F.zero_grad()
203
+ self.loss_G = self.compute_G_loss()
204
+ self.loss_G.backward()
205
+ self.optimizer_G.step()
206
+ if self.opt.netF == 'mlp_sample':
207
+ # if self.opt.lambda_NCE > 0.0:
208
+ # self.optimizer_F.step()
209
+ # elif self.opt.lambda_HDCE > 0.0:
210
+ self.optimizer_F.step()
211
+
212
+ def set_input(self, input):
213
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
214
+ Parameters:
215
+ input (dict): include the data itself and its metadata information.
216
+ The option 'direction' can be used to swap domain A and domain B.
217
+ """
218
+ AtoB = self.opt.direction == 'AtoB'
219
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
220
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
221
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
222
+
223
+ def forward(self):
224
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
225
+ self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.dce_idt and self.opt.isTrain else self.real_A
226
+ if self.opt.flip_equivariance:
227
+ self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
228
+ if self.flipped_for_equivariance:
229
+ self.real = torch.flip(self.real, [3])
230
+
231
+ self.fake = self.netG(self.real)
232
+ self.fake_B = self.fake[:self.real_A.size(0)]
233
+ if self.opt.dce_idt:
234
+ self.idt_B = self.fake[self.real_A.size(0):]
235
+
236
+
237
+ def set_epoch(self, epoch):
238
+ self.train_epoch = epoch
239
+
240
+ def compute_D_loss(self):
241
+ """Calculate GAN loss for the discriminator"""
242
+ fake = self.fake_B.detach()
243
+ # Fake; stop backprop to the generator by detaching fake_B
244
+ pred_fake = self.netD(fake)
245
+ self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
246
+ # Real
247
+ self.pred_real = self.netD(self.real_B)
248
+ loss_D_real = self.criterionGAN(self.pred_real, True)
249
+ self.loss_D_real = loss_D_real.mean()
250
+
251
+ # combine loss and calculate gradients
252
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
253
+ return self.loss_D
254
+
255
+ def compute_G_loss(self):
256
+ """Calculate GAN and NCE loss for the generator"""
257
+ fake = self.fake_B
258
+ # First, G(A) should fake the discriminator
259
+ if self.opt.lambda_GAN > 0.0:
260
+ pred_fake = self.netD(fake)
261
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
262
+ else:
263
+ self.loss_G_GAN = 0.0
264
+
265
+ ## get feat
266
+ fake_B_feat = self.netG(self.fake_B, self.nce_layers, encode_only=True)
267
+ if self.opt.flip_equivariance and self.flipped_for_equivariance:
268
+ fake_B_feat = [torch.flip(fq, [3]) for fq in fake_B_feat]
269
+ real_A_feat = self.netG(self.real_A, self.nce_layers, encode_only=True)
270
+
271
+ fake_B_pool, sample_ids = self.netF(fake_B_feat, self.opt.num_patches, None)
272
+ real_A_pool, _ = self.netF(real_A_feat, self.opt.num_patches, sample_ids)
273
+
274
+ if self.opt.dce_idt:
275
+ idt_B_feat = self.netG(self.idt_B, self.nce_layers, encode_only=True)
276
+ if self.opt.flip_equivariance and self.flipped_for_equivariance:
277
+ idt_B_feat = [torch.flip(fq, [3]) for fq in idt_B_feat]
278
+ real_B_feat = self.netG(self.real_B, self.nce_layers, encode_only=True)
279
+
280
+ idt_B_pool, _ = self.netF(idt_B_feat, self.opt.num_patches, sample_ids)
281
+ real_B_pool, _ = self.netF(real_B_feat, self.opt.num_patches, sample_ids)
282
+
283
+
284
+ ## Relation Loss
285
+ self.loss_SRC, weight = self.calculate_R_loss(real_A_pool, fake_B_pool, epoch=self.train_epoch)
286
+
287
+
288
+ ## HDCE
289
+ if self.opt.lambda_HDCE > 0.0:
290
+ self.loss_HDCE = self.calculate_HDCE_loss(real_A_pool, fake_B_pool, weight)
291
+ else:
292
+ self.loss_HDCE, self.loss_HDCE_bd = 0.0, 0.0
293
+
294
+ self.loss_HDCE_Y = 0
295
+ if self.opt.dce_idt and self.opt.lambda_HDCE > 0.0:
296
+ _, weight_idt = self.calculate_R_loss(real_B_pool, idt_B_pool, only_weight=True, epoch=self.train_epoch)
297
+ self.loss_HDCE_Y = self.calculate_HDCE_loss(real_B_pool, idt_B_pool, weight_idt)
298
+ loss_HDCE_both = (self.loss_HDCE + self.loss_HDCE_Y) * 0.5
299
+ else:
300
+ loss_HDCE_both = self.loss_HDCE
301
+
302
+ self.loss_G = self.loss_G_GAN + loss_HDCE_both + self.loss_SRC
303
+ return self.loss_G
304
+
305
+
306
+ def calculate_HDCE_loss(self, src, tgt, weight=None):
307
+ n_layers = len(self.nce_layers)
308
+
309
+ feat_q_pool = tgt
310
+ feat_k_pool = src
311
+
312
+ total_HDCE_loss = 0.0
313
+ for f_q, f_k, crit, nce_layer, w in zip(feat_q_pool, feat_k_pool, self.criterionHDCE, self.nce_layers, weight):
314
+ if self.opt.no_Hneg:
315
+ w = None
316
+ loss = crit(f_q, f_k, w) * self.opt.lambda_HDCE
317
+ total_HDCE_loss += loss.mean()
318
+
319
+ return total_HDCE_loss / n_layers
320
+
321
+
322
+ def calculate_R_loss(self, src, tgt, only_weight=False, epoch=None):
323
+ n_layers = len(self.nce_layers)
324
+
325
+ feat_q_pool = tgt
326
+ feat_k_pool = src
327
+
328
+ total_SRC_loss = 0.0
329
+ weights=[]
330
+ for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionR, self.nce_layers):
331
+ loss_SRC, weight = crit(f_q, f_k, only_weight, epoch)
332
+ total_SRC_loss += loss_SRC * self.opt.lambda_SRC
333
+ weights.append(weight)
334
+ return total_SRC_loss / n_layers, weights
335
+
336
+
337
+ #--------------------------------------------------------------------------------------------------------
338
+ def calculate_Patchloss(self, src, tgt, num_patch=4):
339
+
340
+ feat_org = self.netG(src, mode='encoder')
341
+ if self.opt.flip_equivariance and self.flipped_for_equivariance:
342
+ feat_org = torch.flip(feat_org, [3])
343
+
344
+ N,C,H,W = feat_org.size()
345
+
346
+ ps = H//num_patch
347
+ lam = get_spa_lambda(self.alpha,size=(1,1,num_patch**2),device = feat_org.device)
348
+ feat_org_unfold = F.unfold(feat_org,kernel_size=(ps,ps),padding=0,stride=ps)
349
+
350
+ rndperm = torch.randperm(feat_org_unfold.size(2))
351
+ feat_prm = feat_org_unfold[:,:,rndperm]
352
+ feat_mix = lam*feat_org_unfold + (1-lam)*feat_prm
353
+ feat_mix = F.fold(feat_mix,output_size=(H,W),kernel_size=(ps,ps),padding=0,stride=ps)
354
+
355
+ out_mix = self.netG(feat_mix,mode='decoder')
356
+ feat_mix_rec = self.netG(out_mix,mode='encoder')
357
+
358
+ fake_feat = self.netG(tgt,mode='encoder')
359
+
360
+ fake_feat_unfold = F.unfold(fake_feat,kernel_size=(ps,ps),padding=0,stride=ps)
361
+ fake_feat_prm = fake_feat_unfold[:,:,rndperm]
362
+ fake_feat_mix = lam*fake_feat_unfold + (1-lam)*fake_feat_prm
363
+ fake_feat_mix = F.fold(fake_feat_mix,output_size=(H,W),kernel_size=(ps,ps),padding=0,stride=ps)
364
+
365
+
366
+ PM_loss = torch.mean(torch.abs(fake_feat_mix - feat_mix_rec))
367
+
368
+ return 10*PM_loss
369
+
370
+ #--------------------------------------------------------------------------------------------------------
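`calculate_Patchloss` blends shuffled feature patches with a Beta-distributed per-patch weight from `get_spa_lambda`, then penalizes the L1 difference between that blend built from the source features (decoded and re-encoded) and the same blend applied to the translated image's features. A toy reproduction of just the patch-mixing step on a random feature map (sizes and alpha are illustrative):

```python
# Sketch: the unfold -> shuffle -> blend -> fold patch mixing used in
# calculate_Patchloss(), on a random feature map.
import numpy as np
import torch
from torch.nn import functional as F

feat = torch.randn(1, 64, 32, 32)              # N, C, H, W
num_patch = 4
ps = feat.size(2) // num_patch                 # patch side: 8
# per-patch blend factors, mirroring get_spa_lambda(alpha, size, device)
lam = torch.from_numpy(np.random.beta(0.2, 0.2, size=(1, 1, num_patch ** 2))).float()

unfolded = F.unfold(feat, kernel_size=(ps, ps), padding=0, stride=ps)  # (1, C*ps*ps, 16)
shuffled = unfolded[:, :, torch.randperm(unfolded.size(2))]            # permute patch order
mixed = lam * unfolded + (1 - lam) * shuffled                          # Beta-weighted blend
mixed = F.fold(mixed, output_size=(32, 32), kernel_size=(ps, ps), padding=0, stride=ps)
print(mixed.shape)                                                     # (1, 64, 32, 32)
```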
Scenimefy/models/hDCE.py ADDED
@@ -0,0 +1,53 @@
1
+ from packaging import version
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+
7
+ class PatchHDCELoss(nn.Module):
8
+ def __init__(self, opt):
9
+ super().__init__()
10
+ self.opt = opt
11
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
12
+ self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
13
+
14
+ def forward(self, feat_q, feat_k, weight=None):
15
+ batchSize = feat_q.shape[0]
16
+ dim = feat_q.shape[1]
17
+ feat_k = feat_k.detach()
18
+
19
+ # positive logit
20
+ l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
21
+ l_pos = l_pos.view(batchSize, 1)
22
+
23
+ if self.opt.nce_includes_all_negatives_from_minibatch:
24
+ # reshape features as if they are all negatives of minibatch of size 1.
25
+ batch_dim_for_bmm = 1
26
+ else:
27
+ batch_dim_for_bmm = self.opt.batch_size
28
+
29
+ # reshape features to batch size
30
+ feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
31
+ feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
32
+ npatches = feat_q.size(1)
33
+ l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
34
+
35
+ # weighted by semantic relation
36
+ if weight is not None:
37
+ l_neg_curbatch *= weight
38
+
39
+ diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
40
+ l_neg_curbatch.masked_fill_(diagonal, -10.0)
41
+ l_neg = l_neg_curbatch.view(-1, npatches)
42
+
43
+ logits = (l_neg-l_pos)/self.opt.nce_T
44
+ v = torch.logsumexp(logits, dim=1)
45
+ loss_vec = torch.exp(v-v.detach())
46
+
47
+ # for monitoring
48
+ out_dummy = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
49
+ CELoss_dummy = self.cross_entropy_loss(out_dummy, torch.zeros(out_dummy.size(0), dtype=torch.long, device=feat_q.device))
50
+
51
+ loss = loss_vec.mean()-1+CELoss_dummy.detach()
52
+
53
+ return loss
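A hedged usage sketch for PatchHDCELoss: it expects L2-normalised patch features of shape (patches, dim) for the query and key images and returns a per-patch loss vector, which the model averages per layer. Only the option fields read in forward() are set, and the values are illustrative.

from types import SimpleNamespace
import torch
import torch.nn.functional as F
from Scenimefy.models.hDCE import PatchHDCELoss

opt = SimpleNamespace(nce_includes_all_negatives_from_minibatch=False,
                      batch_size=1, nce_T=0.07)
crit = PatchHDCELoss(opt)

feat_q = F.normalize(torch.randn(64, 256), dim=1)   # 64 sampled patches, 256 dims each
feat_k = F.normalize(torch.randn(64, 256), dim=1)
per_patch = crit(feat_q, feat_k)                    # optional third argument: semantic-relation weights
print(per_patch.mean().item())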
Scenimefy/models/networks.py ADDED
@@ -0,0 +1,1513 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import init
5
+ import functools
6
+ from torch.optim import lr_scheduler
7
+ import numpy as np
8
+ from Scenimefy.models.stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
9
+
10
+ ###############################################################################
11
+ # Helper Functions
12
+ ###############################################################################
13
+
14
+
15
+ def get_filter(filt_size=3):
16
+ if(filt_size == 1):
17
+ a = np.array([1., ])
18
+ elif(filt_size == 2):
19
+ a = np.array([1., 1.])
20
+ elif(filt_size == 3):
21
+ a = np.array([1., 2., 1.])
22
+ elif(filt_size == 4):
23
+ a = np.array([1., 3., 3., 1.])
24
+ elif(filt_size == 5):
25
+ a = np.array([1., 4., 6., 4., 1.])
26
+ elif(filt_size == 6):
27
+ a = np.array([1., 5., 10., 10., 5., 1.])
28
+ elif(filt_size == 7):
29
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
30
+
31
+ filt = torch.Tensor(a[:, None] * a[None, :])
32
+ filt = filt / torch.sum(filt)
33
+
34
+ return filt
35
+
36
+
37
+ class Downsample(nn.Module):
38
+ def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
39
+ super(Downsample, self).__init__()
40
+ self.filt_size = filt_size
41
+ self.pad_off = pad_off
42
+ self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
43
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
44
+ self.stride = stride
45
+ self.off = int((self.stride - 1) / 2.)
46
+ self.channels = channels
47
+
48
+ filt = get_filter(filt_size=self.filt_size)
49
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
50
+
51
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
52
+
53
+ def forward(self, inp):
54
+ if(self.filt_size == 1):
55
+ if(self.pad_off == 0):
56
+ return inp[:, :, ::self.stride, ::self.stride]
57
+ else:
58
+ return self.pad(inp)[:, :, ::self.stride, ::self.stride]
59
+ else:
60
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
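A quick, illustrative check of the blur-pool layer above: a 3-tap binomial filter is applied per channel (groups equal to the channel count) before stride-2 subsampling, so H and W are halved.

import torch
from Scenimefy.models.networks import Downsample

down = Downsample(channels=64)
x = torch.randn(1, 64, 128, 128)
print(down(x).shape)            # torch.Size([1, 64, 64, 64])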
61
+
62
+
63
+ class Upsample2(nn.Module):
64
+ def __init__(self, scale_factor, mode='nearest'):
65
+ super().__init__()
66
+ self.factor = scale_factor
67
+ self.mode = mode
68
+
69
+ def forward(self, x):
70
+ return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
71
+
72
+
73
+ class Upsample(nn.Module):
74
+ def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
75
+ super(Upsample, self).__init__()
76
+ self.filt_size = filt_size
77
+ self.filt_odd = np.mod(filt_size, 2) == 1
78
+ self.pad_size = int((filt_size - 1) / 2)
79
+ self.stride = stride
80
+ self.off = int((self.stride - 1) / 2.)
81
+ self.channels = channels
82
+
83
+ filt = get_filter(filt_size=self.filt_size) * (stride**2)
84
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
85
+
86
+ self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])
87
+
88
+ def forward(self, inp):
89
+ ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
90
+ if(self.filt_odd):
91
+ return ret_val
92
+ else:
93
+ return ret_val[:, :, :-1, :-1]
94
+
95
+
96
+ def get_pad_layer(pad_type):
97
+ if(pad_type in ['refl', 'reflect']):
98
+ PadLayer = nn.ReflectionPad2d
99
+ elif(pad_type in ['repl', 'replicate']):
100
+ PadLayer = nn.ReplicationPad2d
101
+ elif(pad_type == 'zero'):
102
+ PadLayer = nn.ZeroPad2d
103
+ else:
104
+ print('Pad type [%s] not recognized' % pad_type)
105
+ return PadLayer
106
+
107
+
108
+ class Identity(nn.Module):
109
+ def forward(self, x):
110
+ return x
111
+
112
+
113
+ def get_norm_layer(norm_type='instance'):
114
+ """Return a normalization layer
115
+
116
+ Parameters:
117
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
118
+
119
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
120
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
121
+ """
122
+ if norm_type == 'batch':
123
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
124
+ elif norm_type == 'instance':
125
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
126
+ elif norm_type == 'none':
127
+ def norm_layer(x):
128
+ return Identity()
129
+ else:
130
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
131
+ return norm_layer
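For reference, the returned factory is called with a channel count, just like the nn.*Norm2d constructors it wraps (values illustrative):

from Scenimefy.models.networks import get_norm_layer

norm_layer = get_norm_layer('instance')   # a functools.partial
layer = norm_layer(64)                    # InstanceNorm2d(64, affine=False, track_running_stats=False)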
132
+
133
+
134
+ def get_scheduler(optimizer, opt):
135
+ """Return a learning rate scheduler
136
+
137
+ Parameters:
138
+ optimizer -- the optimizer of the network
139
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
140
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
141
+
142
+ For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
143
+ and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
144
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
145
+ See https://pytorch.org/docs/stable/optim.html for more details.
146
+ """
147
+ if opt.lr_policy == 'linear':
148
+ def lambda_rule(epoch):
149
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
150
+ return lr_l
151
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
152
+ elif opt.lr_policy == 'step':
153
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
154
+ elif opt.lr_policy == 'plateau':
155
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
156
+ elif opt.lr_policy == 'cosine':
157
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
158
+ else:
159
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
160
+ return scheduler
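A small sanity check of the 'linear' policy with illustrative numbers: the multiplier stays at 1.0 for the first n_epochs epochs and then decays roughly linearly to zero over the next n_epochs_decay epochs.

n_epochs, n_epochs_decay, epoch_count = 100, 100, 1

def lambda_rule(epoch):
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

print(lambda_rule(0), lambda_rule(99), lambda_rule(199))   # 1.0, 1.0, ~0.0099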
161
+
162
+
163
+ def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
164
+ """Initialize network weights.
165
+
166
+ Parameters:
167
+ net (network) -- network to be initialized
168
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
169
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
170
+
171
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
172
+ work better for some applications. Feel free to try yourself.
173
+ """
174
+ def init_func(m): # define the initialization function
175
+ classname = m.__class__.__name__
176
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
177
+ if debug:
178
+ print(classname)
179
+ if init_type == 'normal':
180
+ init.normal_(m.weight.data, 0.0, init_gain)
181
+ elif init_type == 'xavier':
182
+ init.xavier_normal_(m.weight.data, gain=init_gain)
183
+ elif init_type == 'kaiming':
184
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
185
+ elif init_type == 'orthogonal':
186
+ init.orthogonal_(m.weight.data, gain=init_gain)
187
+ else:
188
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
189
+ if hasattr(m, 'bias') and m.bias is not None:
190
+ init.constant_(m.bias.data, 0.0)
191
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
192
+ init.normal_(m.weight.data, 1.0, init_gain)
193
+ init.constant_(m.bias.data, 0.0)
194
+
195
+ net.apply(init_func) # apply the initialization function <init_func>
196
+
197
+
198
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
199
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
200
+ Parameters:
201
+ net (network) -- the network to be initialized
202
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
203
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
204
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
205
+
206
+ Return an initialized network.
207
+ """
208
+ if len(gpu_ids) > 0:
209
+ assert(torch.cuda.is_available())
210
+ net.to(gpu_ids[0])
211
+ # if not amp:
212
+ # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
213
+ if initialize_weights:
214
+ init_weights(net, init_type, init_gain=init_gain, debug=debug)
215
+ return net
216
+
217
+
218
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
219
+ init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
220
+ """Create a generator
221
+
222
+ Parameters:
223
+ input_nc (int) -- the number of channels in input images
224
+ output_nc (int) -- the number of channels in output images
225
+ ngf (int) -- the number of filters in the last conv layer
226
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
227
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
228
+ use_dropout (bool) -- if use dropout layers.
229
+ init_type (str) -- the name of our initialization method.
230
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
231
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
232
+
233
+ Returns a generator
234
+
235
+ Our current implementation provides two types of generators:
236
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
237
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
238
+
239
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
240
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
241
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
242
+
243
+
244
+ The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
245
+ """
246
+ net = None
247
+ norm_layer = get_norm_layer(norm_type=norm)
248
+
249
+ if netG == 'resnet_9blocks':
250
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
251
+ elif netG == 'resnet_6blocks':
252
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt)
253
+ elif netG == 'resnet_4blocks':
254
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt)
255
+ elif netG == 'unet_128':
256
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
257
+ elif netG == 'unet_256':
258
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
259
+ elif netG == 'stylegan2':
260
+ net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
261
+ elif netG == 'smallstylegan2':
262
+ net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
263
+ elif netG == 'resnet_cat':
264
+ n_blocks = 8
265
+ net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
266
+ else:
267
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
268
+ return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
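A hedged example of building the default ResNet generator on CPU; the argument values are illustrative, and opt is only stored on the module in this configuration.

import torch
from Scenimefy.models.networks import define_G

netG = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                norm='instance', gpu_ids=[], opt=None)
fake = netG(torch.randn(1, 3, 256, 256))    # standard full forward pass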
269
+
270
+
271
+ def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
272
+ if netF == 'global_pool':
273
+ net = PoolingF()
274
+ elif netF == 'reshape':
275
+ net = ReshapeF()
276
+ elif netF == 'sample':
277
+ net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
278
+ elif netF == 'mlp_sample':
279
+ net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
280
+ elif netF == 'strided_conv':
281
+ net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
282
+ else:
283
+ raise NotImplementedError('projection model name [%s] is not recognized' % netF)
284
+ return init_net(net, init_type, init_gain, gpu_ids)
285
+
286
+ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
287
+ """Create a discriminator
288
+
289
+ Parameters:
290
+ input_nc (int) -- the number of channels in input images
291
+ ndf (int) -- the number of filters in the first conv layer
292
+ netD (str) -- the architecture's name: basic | n_layers | pixel
293
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
294
+ norm (str) -- the type of normalization layers used in the network.
295
+ init_type (str) -- the name of the initialization method.
296
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
297
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
298
+
299
+ Returns a discriminator
300
+
301
+ Our current implementation provides three types of discriminators:
302
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
303
+ It can classify whether 70×70 overlapping patches are real or fake.
304
+ Such a patch-level discriminator architecture has fewer parameters
305
+ than a full-image discriminator and can work on arbitrarily-sized images
306
+ in a fully convolutional fashion.
307
+
308
+ [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
309
+ with the parameter <n_layers_D> (default=3, as used in [basic] PatchGAN).
310
+
311
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
312
+ It encourages greater color diversity but has no effect on spatial statistics.
313
+
314
+ The discriminator has been initialized by <init_net>. It uses LeakyReLU for non-linearity.
315
+ """
316
+ net = None
317
+ norm_layer = get_norm_layer(norm_type=norm)
318
+
319
+ if netD == 'basic': # default PatchGAN classifier
320
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias,)
321
+ elif netD == 'n_layers': # more options
322
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
323
+ elif netD == 'pixel': # classify if each pixel is real or fake
324
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
325
+ elif 'stylegan2' in netD:
326
+ net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
327
+ else:
328
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
329
+ return init_net(net, init_type, init_gain, gpu_ids,
330
+ initialize_weights=('stylegan2' not in netD))
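An illustrative call building the default PatchGAN discriminator on CPU (it relies on the NLayerDiscriminator defined later in this file):

import torch
from Scenimefy.models.networks import define_D

netD = define_D(input_nc=3, ndf=64, netD='basic', norm='instance', gpu_ids=[])
score_map = netD(torch.randn(1, 3, 256, 256))   # per-patch real/fake logits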
331
+
332
+
333
+ ##############################################################################
334
+ # Classes
335
+ ##############################################################################
336
+ class GANLoss(nn.Module):
337
+ """Define different GAN objectives.
338
+
339
+ The GANLoss class abstracts away the need to create the target label tensor
340
+ that has the same size as the input.
341
+ """
342
+
343
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
344
+ """ Initialize the GANLoss class.
345
+
346
+ Parameters:
347
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
348
+ target_real_label (bool) - - label for a real image
349
+ target_fake_label (bool) - - label of a fake image
350
+
351
+ Note: Do not use sigmoid as the last layer of Discriminator.
352
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
353
+ """
354
+ super(GANLoss, self).__init__()
355
+ self.register_buffer('real_label', torch.tensor(target_real_label))
356
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
357
+ self.gan_mode = gan_mode
358
+ if gan_mode == 'lsgan':
359
+ self.loss = nn.MSELoss()
360
+ elif gan_mode == 'vanilla':
361
+ self.loss = nn.BCEWithLogitsLoss()
362
+ elif gan_mode in ['wgangp', 'nonsaturating']:
363
+ self.loss = None
364
+ else:
365
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
366
+
367
+ def get_target_tensor(self, prediction, target_is_real):
368
+ """Create label tensors with the same size as the input.
369
+
370
+ Parameters:
371
+ prediction (tensor) - - typically the prediction from a discriminator
372
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
373
+
374
+ Returns:
375
+ A label tensor filled with ground truth label, and with the size of the input
376
+ """
377
+
378
+ if target_is_real:
379
+ target_tensor = self.real_label
380
+ else:
381
+ target_tensor = self.fake_label
382
+ return target_tensor.expand_as(prediction)
383
+
384
+ def __call__(self, prediction, target_is_real):
385
+ """Calculate loss given Discriminator's output and ground truth labels.
386
+
387
+ Parameters:
388
+ prediction (tensor) - - typically the prediction output from a discriminator
389
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
390
+
391
+ Returns:
392
+ the calculated loss.
393
+ """
394
+ bs = prediction.size(0)
395
+ if self.gan_mode in ['lsgan', 'vanilla']:
396
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
397
+ loss = self.loss(prediction, target_tensor)
398
+ elif self.gan_mode == 'wgangp':
399
+ if target_is_real:
400
+ loss = -prediction.mean()
401
+ else:
402
+ loss = prediction.mean()
403
+ elif self.gan_mode == 'nonsaturating':
404
+ if target_is_real:
405
+ loss = F.softplus(-prediction).view(bs, -1).mean(dim=1)
406
+ else:
407
+ loss = F.softplus(prediction).view(bs, -1).mean(dim=1)
408
+ return loss
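An illustrative use of GANLoss with the lsgan objective: raw (no-sigmoid) discriminator outputs are compared against broadcast 1.0 / 0.0 target maps. The prediction shape is made up for the example.

import torch
from Scenimefy.models.networks import GANLoss

criterion = GANLoss('lsgan')
pred = torch.randn(2, 1, 30, 30)         # a PatchGAN-style prediction map
loss_real = criterion(pred, True)        # MSE against a tensor of ones
loss_fake = criterion(pred, False)       # MSE against a tensor of zeros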
409
+
410
+
411
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
412
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
413
+
414
+ Arguments:
415
+ netD (network) -- discriminator network
416
+ real_data (tensor array) -- real images
417
+ fake_data (tensor array) -- generated images from the generator
418
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
419
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
420
+ constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
421
+ lambda_gp (float) -- weight for this loss
422
+
423
+ Returns the gradient penalty loss
424
+ """
425
+ if lambda_gp > 0.0:
426
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
427
+ interpolatesv = real_data
428
+ elif type == 'fake':
429
+ interpolatesv = fake_data
430
+ elif type == 'mixed':
431
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
432
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
433
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
434
+ else:
435
+ raise NotImplementedError('{} not implemented'.format(type))
436
+ interpolatesv.requires_grad_(True)
437
+ disc_interpolates = netD(interpolatesv)
438
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
439
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
440
+ create_graph=True, retain_graph=True, only_inputs=True)
441
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
442
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
443
+ return gradient_penalty, gradients
444
+ else:
445
+ return 0.0, None
446
+
447
+
448
+ class Normalize(nn.Module):
449
+
450
+ def __init__(self, power=2):
451
+ super(Normalize, self).__init__()
452
+ self.power = power
453
+
454
+ def forward(self, x):
455
+ norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
456
+ out = x.div(norm + 1e-7)
457
+ return out
458
+
459
+
460
+ class PoolingF(nn.Module):
461
+ def __init__(self):
462
+ super(PoolingF, self).__init__()
463
+ model = [nn.AdaptiveMaxPool2d(1)]
464
+ self.model = nn.Sequential(*model)
465
+ self.l2norm = Normalize(2)
466
+
467
+ def forward(self, x):
468
+ return self.l2norm(self.model(x))
469
+
470
+
471
+ class ReshapeF(nn.Module):
472
+ def __init__(self):
473
+ super(ReshapeF, self).__init__()
474
+ model = [nn.AdaptiveAvgPool2d(4)]
475
+ self.model = nn.Sequential(*model)
476
+ self.l2norm = Normalize(2)
477
+
478
+ def forward(self, x):
479
+ x = self.model(x)
480
+ x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2)
481
+ return self.l2norm(x_reshape)
482
+
483
+
484
+ class StridedConvF(nn.Module):
485
+ def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]):
486
+ super().__init__()
487
+ # self.conv1 = nn.Conv2d(256, 128, 3, stride=2)
488
+ # self.conv2 = nn.Conv2d(128, 64, 3, stride=1)
489
+ self.l2_norm = Normalize(2)
490
+ self.mlps = {}
491
+ self.moving_averages = {}
492
+ self.init_type = init_type
493
+ self.init_gain = init_gain
494
+ self.gpu_ids = gpu_ids
495
+
496
+ def create_mlp(self, x):
497
+ C, H = x.shape[1], x.shape[2]
498
+ n_down = int(np.rint(np.log2(H / 32)))
499
+ mlp = []
500
+ for i in range(n_down):
501
+ mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2))
502
+ mlp.append(nn.ReLU())
503
+ C = max(C // 2, 64)
504
+ mlp.append(nn.Conv2d(C, 64, 3))
505
+ mlp = nn.Sequential(*mlp)
506
+ init_net(mlp, self.init_type, self.init_gain, self.gpu_ids)
507
+ return mlp
508
+
509
+ def update_moving_average(self, key, x):
510
+ if key not in self.moving_averages:
511
+ self.moving_averages[key] = x.detach()
512
+
513
+ self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001
514
+
515
+ def forward(self, x, use_instance_norm=False):
516
+ C, H = x.shape[1], x.shape[2]
517
+ key = '%d_%d' % (C, H)
518
+ if key not in self.mlps:
519
+ self.mlps[key] = self.create_mlp(x)
520
+ self.add_module("child_%s" % key, self.mlps[key])
521
+ mlp = self.mlps[key]
522
+ x = mlp(x)
523
+ self.update_moving_average(key, x)
524
+ x = x - self.moving_averages[key]
525
+ if use_instance_norm:
526
+ x = F.instance_norm(x)
527
+ return self.l2_norm(x)
528
+
529
+
530
+ class PatchSampleF(nn.Module):
531
+ def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
532
+ # potential issues: currently, we use the same patch_ids for multiple images in the batch
533
+ super(PatchSampleF, self).__init__()
534
+ self.l2norm = Normalize(2)
535
+ self.use_mlp = use_mlp
536
+ self.nc = nc # hard-coded
537
+ self.mlp_init = False
538
+ self.init_type = init_type
539
+ self.init_gain = init_gain
540
+ self.gpu_ids = gpu_ids
541
+
542
+ def create_mlp(self, feats):
543
+ for mlp_id, feat in enumerate(feats):
544
+ input_nc = feat.shape[1]
545
+ mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
546
+ if len(self.gpu_ids) > 0:
547
+ mlp.cuda()
548
+ setattr(self, 'mlp_%d' % mlp_id, mlp)
549
+ init_net(self, self.init_type, self.init_gain, self.gpu_ids)
550
+ self.mlp_init = True
551
+
552
+ def forward(self, feats, num_patches=64, patch_ids=None):
553
+ return_ids = []
554
+ return_feats = []
555
+ if self.use_mlp and not self.mlp_init:
556
+ self.create_mlp(feats)
557
+ for feat_id, feat in enumerate(feats):
558
+ B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
559
+ feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
560
+ if num_patches > 0:
561
+ if patch_ids is not None:
562
+ patch_id = patch_ids[feat_id]
563
+ else:
564
+ patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
565
+ patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
566
+ x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
567
+ else:
568
+ x_sample = feat_reshape
569
+ patch_id = []
570
+ if self.use_mlp:
571
+ mlp = getattr(self, 'mlp_%d' % feat_id)
572
+ x_sample = mlp(x_sample)
573
+ return_ids.append(patch_id)
574
+ x_sample = self.l2norm(x_sample)
575
+
576
+ if num_patches == 0:
577
+ x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
578
+ return_feats.append(x_sample)
579
+ return return_feats, return_ids
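A hedged sketch of the patch sampler: given a list of encoder feature maps, it draws num_patches random spatial locations per map, L2-normalises the sampled vectors, and can reuse the same patch_ids on a second call so query and key patches stay aligned. Feature shapes are illustrative.

import torch
from Scenimefy.models.networks import PatchSampleF

sampler = PatchSampleF(use_mlp=False)    # no projection MLP in this sketch
feats_k = [torch.randn(1, 128, 64, 64), torch.randn(1, 256, 32, 32)]
feats_q = [torch.randn(1, 128, 64, 64), torch.randn(1, 256, 32, 32)]
feat_k_pool, ids = sampler(feats_k, num_patches=64)
feat_q_pool, _ = sampler(feats_q, num_patches=64, patch_ids=ids)   # reuse the same locations
print(feat_k_pool[0].shape)              # torch.Size([64, 128])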
580
+
581
+
582
+ class G_Resnet(nn.Module):
583
+ def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64,
584
+ norm=None, nl_layer=None):
585
+ super(G_Resnet, self).__init__()
586
+ n_downsample = num_downs
587
+ pad_type = 'reflect'
588
+ self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type)
589
+ if nz == 0:
590
+ self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
591
+ else:
592
+ self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
593
+
594
+ def decode(self, content, style=None):
595
+ return self.dec(content, style)
596
+
597
+ def forward(self, image, style=None, nce_layers=[], encode_only=False):
598
+ content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only)
599
+ if encode_only:
600
+ return feats
601
+ else:
602
+ images_recon = self.decode(content, style)
603
+ if len(nce_layers) > 0:
604
+ return images_recon, feats
605
+ else:
606
+ return images_recon
607
+
608
+ ##################################################################################
609
+ # Encoder and Decoders
610
+ ##################################################################################
611
+
612
+
613
+ class E_adaIN(nn.Module):
614
+ def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4,
615
+ norm=None, nl_layer=None, vae=False):
616
+ # style encoder
617
+ super(E_adaIN, self).__init__()
618
+ self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae)
619
+
620
+ def forward(self, image):
621
+ style = self.enc_style(image)
622
+ return style
623
+
624
+
625
+ class StyleEncoder(nn.Module):
626
+ def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False):
627
+ super(StyleEncoder, self).__init__()
628
+ self.vae = vae
629
+ self.model = []
630
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
631
+ for i in range(2):
632
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
633
+ dim *= 2
634
+ for i in range(n_downsample - 2):
635
+ self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
636
+ self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
637
+ if self.vae:
638
+ self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0)
639
+ self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0)
640
+ else:
641
+ self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
642
+
643
+ self.model = nn.Sequential(*self.model)
644
+ self.output_dim = dim
645
+
646
+ def forward(self, x):
647
+ if self.vae:
648
+ output = self.model(x)
649
+ output = output.view(x.size(0), -1)
650
+ output_mean = self.fc_mean(output)
651
+ output_var = self.fc_var(output)
652
+ return output_mean, output_var
653
+ else:
654
+ return self.model(x).view(x.size(0), -1)
655
+
656
+
657
+ class ContentEncoder(nn.Module):
658
+ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'):
659
+ super(ContentEncoder, self).__init__()
660
+ self.model = []
661
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
662
+ # downsampling blocks
663
+ for i in range(n_downsample):
664
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
665
+ dim *= 2
666
+ # residual blocks
667
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
668
+ self.model = nn.Sequential(*self.model)
669
+ self.output_dim = dim
670
+
671
+ def forward(self, x, nce_layers=[], encode_only=False):
672
+ if len(nce_layers) > 0:
673
+ feat = x
674
+ feats = []
675
+ for layer_id, layer in enumerate(self.model):
676
+ feat = layer(feat)
677
+ if layer_id in nce_layers:
678
+ feats.append(feat)
679
+ if layer_id == nce_layers[-1] and encode_only:
680
+ return None, feats
681
+ return feat, feats
682
+ else:
683
+ return self.model(x), None
684
+
687
+
688
+
689
+ class Decoder_all(nn.Module):
690
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
691
+ super(Decoder_all, self).__init__()
692
+ # AdaIN residual blocks
693
+ self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)
694
+ self.n_blocks = 0
695
+ # upsampling blocks
696
+ for i in range(n_upsample):
697
+ block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
698
+ setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block))
699
+ self.n_blocks += 1
700
+ dim //= 2
701
+ # use reflection padding in the last conv layer
702
+ setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect'))
703
+ self.n_blocks += 1
704
+
705
+ def forward(self, x, y=None):
706
+ if y is not None:
707
+ output = self.resnet_block(cat_feature(x, y))
708
+ for n in range(self.n_blocks):
709
+ block = getattr(self, 'block_{:d}'.format(n))
710
+ if n > 0:
711
+ output = block(cat_feature(output, y))
712
+ else:
713
+ output = block(output)
714
+ return output
715
+
716
+
717
+ class Decoder(nn.Module):
718
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
719
+ super(Decoder, self).__init__()
720
+
721
+ self.model = []
722
+ # AdaIN residual blocks
723
+ self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)]
724
+ # upsampling blocks
725
+ for i in range(n_upsample):
726
+ if i == 0:
727
+ input_dim = dim + nz
728
+ else:
729
+ input_dim = dim
730
+ self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
731
+ dim //= 2
732
+ # use reflection padding in the last conv layer
733
+ self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')]
734
+ self.model = nn.Sequential(*self.model)
735
+
736
+ def forward(self, x, y=None):
737
+ if y is not None:
738
+ return self.model(cat_feature(x, y))
739
+ else:
740
+ return self.model(x)
741
+
742
+ ##################################################################################
743
+ # Sequential Models
744
+ ##################################################################################
745
+
746
+
747
+ class ResBlocks(nn.Module):
748
+ def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
749
+ super(ResBlocks, self).__init__()
750
+ self.model = []
751
+ for i in range(num_blocks):
752
+ self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)]
753
+ self.model = nn.Sequential(*self.model)
754
+
755
+ def forward(self, x):
756
+ return self.model(x)
757
+
758
+
759
+ ##################################################################################
760
+ # Basic Blocks
761
+ ##################################################################################
762
+ def cat_feature(x, y):
763
+ y_expand = y.view(y.size(0), y.size(1), 1, 1).expand(
764
+ y.size(0), y.size(1), x.size(2), x.size(3))
765
+ x_cat = torch.cat([x, y_expand], 1)
766
+ return x_cat
767
+
768
+
769
+ class ResBlock(nn.Module):
770
+ def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
771
+ super(ResBlock, self).__init__()
772
+
773
+ model = []
774
+ model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
775
+ model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
776
+ self.model = nn.Sequential(*model)
777
+
778
+ def forward(self, x):
779
+ residual = x
780
+ out = self.model(x)
781
+ out += residual
782
+ return out
783
+
784
+
785
+ class Conv2dBlock(nn.Module):
786
+ def __init__(self, input_dim, output_dim, kernel_size, stride,
787
+ padding=0, norm='none', activation='relu', pad_type='zero'):
788
+ super(Conv2dBlock, self).__init__()
789
+ self.use_bias = True
790
+ # initialize padding
791
+ if pad_type == 'reflect':
792
+ self.pad = nn.ReflectionPad2d(padding)
793
+ elif pad_type == 'zero':
794
+ self.pad = nn.ZeroPad2d(padding)
795
+ else:
796
+ assert 0, "Unsupported padding type: {}".format(pad_type)
797
+
798
+ # initialize normalization
799
+ norm_dim = output_dim
800
+ if norm == 'batch':
801
+ self.norm = nn.BatchNorm2d(norm_dim)
802
+ elif norm == 'inst':
803
+ self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False)
804
+ elif norm == 'ln':
805
+ self.norm = LayerNorm(norm_dim)
806
+ elif norm == 'none':
807
+ self.norm = None
808
+ else:
809
+ assert 0, "Unsupported normalization: {}".format(norm)
810
+
811
+ # initialize activation
812
+ if activation == 'relu':
813
+ self.activation = nn.ReLU(inplace=True)
814
+ elif activation == 'lrelu':
815
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
816
+ elif activation == 'prelu':
817
+ self.activation = nn.PReLU()
818
+ elif activation == 'selu':
819
+ self.activation = nn.SELU(inplace=True)
820
+ elif activation == 'tanh':
821
+ self.activation = nn.Tanh()
822
+ elif activation == 'none':
823
+ self.activation = None
824
+ else:
825
+ assert 0, "Unsupported activation: {}".format(activation)
826
+
827
+ # initialize convolution
828
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
829
+
830
+ def forward(self, x):
831
+ x = self.conv(self.pad(x))
832
+ if self.norm:
833
+ x = self.norm(x)
834
+ if self.activation:
835
+ x = self.activation(x)
836
+ return x
837
+
838
+
839
+ class LinearBlock(nn.Module):
840
+ def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
841
+ super(LinearBlock, self).__init__()
842
+ use_bias = True
843
+ # initialize fully connected layer
844
+ self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
845
+
846
+ # initialize normalization
847
+ norm_dim = output_dim
848
+ if norm == 'batch':
849
+ self.norm = nn.BatchNorm1d(norm_dim)
850
+ elif norm == 'inst':
851
+ self.norm = nn.InstanceNorm1d(norm_dim)
852
+ elif norm == 'ln':
853
+ self.norm = LayerNorm(norm_dim)
854
+ elif norm == 'none':
855
+ self.norm = None
856
+ else:
857
+ assert 0, "Unsupported normalization: {}".format(norm)
858
+
859
+ # initialize activation
860
+ if activation == 'relu':
861
+ self.activation = nn.ReLU(inplace=True)
862
+ elif activation == 'lrelu':
863
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
864
+ elif activation == 'prelu':
865
+ self.activation = nn.PReLU()
866
+ elif activation == 'selu':
867
+ self.activation = nn.SELU(inplace=True)
868
+ elif activation == 'tanh':
869
+ self.activation = nn.Tanh()
870
+ elif activation == 'none':
871
+ self.activation = None
872
+ else:
873
+ assert 0, "Unsupported activation: {}".format(activation)
874
+
875
+ def forward(self, x):
876
+ out = self.fc(x)
877
+ if self.norm:
878
+ out = self.norm(out)
879
+ if self.activation:
880
+ out = self.activation(out)
881
+ return out
882
+
883
+ ##################################################################################
884
+ # Normalization layers
885
+ ##################################################################################
886
+
887
+
888
+ class LayerNorm(nn.Module):
889
+ def __init__(self, num_features, eps=1e-5, affine=True):
890
+ super(LayerNorm, self).__init__()
891
+ self.num_features = num_features
892
+ self.affine = affine
893
+ self.eps = eps
894
+
895
+ if self.affine:
896
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
897
+ self.beta = nn.Parameter(torch.zeros(num_features))
898
+
899
+ def forward(self, x):
900
+ shape = [-1] + [1] * (x.dim() - 1)
901
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
902
+ std = x.view(x.size(0), -1).std(1).view(*shape)
903
+ x = (x - mean) / (std + self.eps)
904
+
905
+ if self.affine:
906
+ shape = [1, -1] + [1] * (x.dim() - 2)
907
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
908
+ return x
909
+
910
+
911
+ class ResnetGenerator(nn.Module):
912
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
913
+
914
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
915
+ """
916
+
917
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
918
+ """Construct a Resnet-based generator
919
+
920
+ Parameters:
921
+ input_nc (int) -- the number of channels in input images
922
+ output_nc (int) -- the number of channels in output images
923
+ ngf (int) -- the number of filters in the last conv layer
924
+ norm_layer -- normalization layer
925
+ use_dropout (bool) -- if use dropout layers
926
+ n_blocks (int) -- the number of ResNet blocks
927
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
928
+ """
929
+ assert(n_blocks >= 0)
930
+ super(ResnetGenerator, self).__init__()
931
+ self.opt = opt
932
+ if type(norm_layer) == functools.partial:
933
+ use_bias = norm_layer.func == nn.InstanceNorm2d
934
+ else:
935
+ use_bias = norm_layer == nn.InstanceNorm2d
936
+
937
+ model = [nn.ReflectionPad2d(3),
938
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
939
+ norm_layer(ngf),
940
+ nn.ReLU(True)]
941
+
942
+ n_downsampling = 2
943
+ for i in range(n_downsampling): # add downsampling layers
944
+ mult = 2 ** i
945
+ if(no_antialias):
946
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
947
+ norm_layer(ngf * mult * 2),
948
+ nn.ReLU(True)]
949
+ else:
950
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
951
+ norm_layer(ngf * mult * 2),
952
+ nn.ReLU(True),
953
+ Downsample(ngf * mult * 2)]
954
+
955
+ mult = 2 ** n_downsampling
956
+ for i in range(n_blocks): # add ResNet blocks
957
+
958
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
959
+
960
+ for i in range(n_downsampling): # add upsampling layers
961
+ mult = 2 ** (n_downsampling - i)
962
+ if no_antialias_up:
963
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
964
+ kernel_size=3, stride=2,
965
+ padding=1, output_padding=1,
966
+ bias=use_bias),
967
+ norm_layer(int(ngf * mult / 2)),
968
+ nn.ReLU(True)]
969
+ else:
970
+ model += [Upsample(ngf * mult),
971
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2),
972
+ kernel_size=3, stride=1,
973
+ padding=1, # output_padding=1,
974
+ bias=use_bias),
975
+ norm_layer(int(ngf * mult / 2)),
976
+ nn.ReLU(True)]
977
+ model += [nn.ReflectionPad2d(3)]
978
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
979
+ model += [nn.Tanh()]
980
+
981
+ self.model = nn.Sequential(*model)
982
+
983
+ def forward(self, input, layers=[], encode_only=False, mode='all', stop_layer=16):
984
+ if -1 in layers:
985
+ layers.append(len(self.model))
986
+ if len(layers) > 0:
987
+ feat = input
988
+ feats = []
989
+ for layer_id, layer in enumerate(self.model):
990
+ # print(layer_id, layer)
991
+ feat = layer(feat)
992
+ if layer_id in layers:
993
+ # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
994
+ feats.append(feat)
995
+ else:
996
+ # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
997
+ pass
998
+ if layer_id == layers[-1] and encode_only:
999
+ # print('encoder only return features')
1000
+ return feats # return intermediate features alone; stop in the last layers
1001
+
1002
+ return feat, feats # return both output and intermediate features
1003
+ else:
1004
+ """Standard forward"""
1005
+ if mode=='encoder':
1006
+ feat=input
1007
+ for layer_id, layer in enumerate(self.model):
1008
+ feat = layer(feat)
1009
+ if layer_id == stop_layer:
1010
+ # print('encoder only return features')
1011
+ return feat # return intermediate features alone; stop in the last layers
1012
+ elif mode =='decoder':
1013
+ feat=input
1014
+ for layer_id, layer in enumerate(self.model):
1015
+
1016
+ if layer_id > stop_layer:
1017
+ feat = layer(feat)
1018
+ else:
1019
+ pass
1020
+ # print('encoder only return features')
1021
+ return feat # return intermediate features alone; stop in the last layers
1022
+ else:
1023
+ fake = self.model(input)
1024
+ return fake
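A hedged sketch of the encoder/decoder split that calculate_Patchloss relies on: mode='encoder' stops after stop_layer (by default layer 16, inside the ResNet blocks) and mode='decoder' resumes from the next layer, so chaining the two matches the standard forward pass. Shapes are illustrative.

import torch
from Scenimefy.models.networks import ResnetGenerator

G = ResnetGenerator(3, 3, ngf=64, n_blocks=9)
x = torch.randn(1, 3, 256, 256)
h = G(x, mode='encoder')    # intermediate feature map, 1 x 256 x 64 x 64 here
y = G(h, mode='decoder')    # same composition of layers as the standard forward G(x)
print(h.shape, y.shape)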
1025
+
1026
+ # class ResnetGenerator(nn.Module):
1027
+ # """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
1028
+
1029
+ # We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
1030
+ # """
1031
+
1032
+ # def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
1033
+ # """Construct a Resnet-based generator
1034
+
1035
+ # Parameters:
1036
+ # input_nc (int) -- the number of channels in input images
1037
+ # output_nc (int) -- the number of channels in output images
1038
+ # ngf (int) -- the number of filters in the last conv layer
1039
+ # norm_layer -- normalization layer
1040
+ # use_dropout (bool) -- if use dropout layers
1041
+ # n_blocks (int) -- the number of ResNet blocks
1042
+ # padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1043
+ # """
1044
+ # assert(n_blocks >= 0)
1045
+ # super(ResnetGenerator, self).__init__()
1046
+ # self.opt = opt
1047
+ # if type(norm_layer) == functools.partial:
1048
+ # use_bias = norm_layer.func == nn.InstanceNorm2d
1049
+ # else:
1050
+ # use_bias = norm_layer == nn.InstanceNorm2d
1051
+
1052
+ # model = [nn.ReflectionPad2d(3),
1053
+ # nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
1054
+ # norm_layer(ngf),
1055
+ # nn.ReLU(True)]
1056
+
1057
+ # n_downsampling = 2
1058
+ # for i in range(n_downsampling): # add downsampling layers
1059
+ # mult = 2 ** i
1060
+ # if(no_antialias):
1061
+ # model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
1062
+ # norm_layer(ngf * mult * 2),
1063
+ # nn.ReLU(True)]
1064
+ # else:
1065
+ # model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
1066
+ # norm_layer(ngf * mult * 2),
1067
+ # nn.ReLU(True),
1068
+ # Downsample(ngf * mult * 2)]
1069
+
1070
+ # mult = 2 ** n_downsampling
1071
+ # for i in range(n_blocks): # add ResNet blocks
1072
+
1073
+ # model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1074
+
1075
+ # for i in range(n_downsampling): # add upsampling layers
1076
+ # mult = 2 ** (n_downsampling - i)
1077
+ # if no_antialias_up:
1078
+ # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
1079
+ # kernel_size=3, stride=2,
1080
+ # padding=1, output_padding=1,
1081
+ # bias=use_bias),
1082
+ # norm_layer(int(ngf * mult / 2)),
1083
+ # nn.ReLU(True)]
1084
+ # else:
1085
+ # model += [Upsample(ngf * mult),
1086
+ # nn.Conv2d(ngf * mult, int(ngf * mult / 2),
1087
+ # kernel_size=3, stride=1,
1088
+ # padding=1, # output_padding=1,
1089
+ # bias=use_bias),
1090
+ # norm_layer(int(ngf * mult / 2)),
1091
+ # nn.ReLU(True)]
1092
+ # model += [nn.ReflectionPad2d(3)]
1093
+ # model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
1094
+ # model += [nn.Tanh()]
1095
+
1096
+ # self.model = nn.Sequential(*model)
1097
+
1098
+ # def forward(self, input, layers=[], encode_only=False):
1099
+ # if -1 in layers:
1100
+ # layers.append(len(self.model))
1101
+ # if len(layers) > 0:
1102
+ # feat = input
1103
+ # feats = []
1104
+ # for layer_id, layer in enumerate(self.model):
1105
+ # # print(layer_id, layer)
1106
+ # feat = layer(feat)
1107
+ # if layer_id in layers:
1108
+ # # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
1109
+ # feats.append(feat)
1110
+ # else:
1111
+ # # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
1112
+ # pass
1113
+ # if layer_id == layers[-1] and encode_only:
1114
+ # # print('encoder only return features')
1115
+ # return feats # return intermediate features alone; stop in the last layers
1116
+
1117
+ # return feat, feats # return both output and intermediate features
1118
+ # else:
1119
+ # """Standard forward"""
1120
+ # fake = self.model(input)
1121
+ # return fake
1122
+
1123
+ class ResnetDecoder(nn.Module):
1124
+ """Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations.
1125
+ """
1126
+
1127
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
1128
+ """Construct a Resnet-based decoder
1129
+
1130
+ Parameters:
1131
+ input_nc (int) -- the number of channels in input images
1132
+ output_nc (int) -- the number of channels in output images
1133
+ ngf (int) -- the number of filters in the last conv layer
1134
+ norm_layer -- normalization layer
1135
+ use_dropout (bool) -- if use dropout layers
1136
+ n_blocks (int) -- the number of ResNet blocks
1137
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1138
+ """
1139
+ assert(n_blocks >= 0)
1140
+ super(ResnetDecoder, self).__init__()
1141
+ if type(norm_layer) == functools.partial:
1142
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1143
+ else:
1144
+ use_bias = norm_layer == nn.InstanceNorm2d
1145
+ model = []
1146
+ n_downsampling = 2
1147
+ mult = 2 ** n_downsampling
1148
+ for i in range(n_blocks): # add ResNet blocks
1149
+
1150
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1151
+
1152
+ for i in range(n_downsampling): # add upsampling layers
1153
+ mult = 2 ** (n_downsampling - i)
1154
+ if(no_antialias):
1155
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
1156
+ kernel_size=3, stride=2,
1157
+ padding=1, output_padding=1,
1158
+ bias=use_bias),
1159
+ norm_layer(int(ngf * mult / 2)),
1160
+ nn.ReLU(True)]
1161
+ else:
1162
+ model += [Upsample(ngf * mult),
1163
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2),
1164
+ kernel_size=3, stride=1,
1165
+ padding=1,
1166
+ bias=use_bias),
1167
+ norm_layer(int(ngf * mult / 2)),
1168
+ nn.ReLU(True)]
1169
+ model += [nn.ReflectionPad2d(3)]
1170
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
1171
+ model += [nn.Tanh()]
1172
+
1173
+ self.model = nn.Sequential(*model)
1174
+
1175
+ def forward(self, input):
1176
+ """Standard forward"""
1177
+ return self.model(input)
1178
+
1179
+
1180
+ class ResnetEncoder(nn.Module):
1181
+ """Resnet-based encoder that consists of a few downsampling + several Resnet blocks
1182
+ """
1183
+
1184
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
1185
+ """Construct a Resnet-based encoder
1186
+
1187
+ Parameters:
1188
+ input_nc (int) -- the number of channels in input images
1189
+ output_nc (int) -- the number of channels in output images
1190
+ ngf (int) -- the number of filters in the last conv layer
1191
+ norm_layer -- normalization layer
1192
+ use_dropout (bool) -- if use dropout layers
1193
+ n_blocks (int) -- the number of ResNet blocks
1194
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1195
+ """
1196
+ assert(n_blocks >= 0)
1197
+ super(ResnetEncoder, self).__init__()
1198
+ if type(norm_layer) == functools.partial:
1199
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1200
+ else:
1201
+ use_bias = norm_layer == nn.InstanceNorm2d
1202
+
1203
+ model = [nn.ReflectionPad2d(3),
1204
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
1205
+ norm_layer(ngf),
1206
+ nn.ReLU(True)]
1207
+
1208
+ n_downsampling = 2
1209
+ for i in range(n_downsampling): # add downsampling layers
1210
+ mult = 2 ** i
1211
+ if(no_antialias):
1212
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
1213
+ norm_layer(ngf * mult * 2),
1214
+ nn.ReLU(True)]
1215
+ else:
1216
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
1217
+ norm_layer(ngf * mult * 2),
1218
+ nn.ReLU(True),
1219
+ Downsample(ngf * mult * 2)]
1220
+
1221
+ mult = 2 ** n_downsampling
1222
+ for i in range(n_blocks): # add ResNet blocks
1223
+
1224
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1225
+
1226
+ self.model = nn.Sequential(*model)
1227
+
1228
+ def forward(self, input):
1229
+ """Standard forward"""
1230
+ return self.model(input)
1231
+
1232
+
1233
+ class ResnetBlock(nn.Module):
1234
+ """Define a Resnet block"""
1235
+
1236
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
1237
+ """Initialize the Resnet block
1238
+
1239
+ A resnet block is a conv block with skip connections
1240
+ We construct a conv block with build_conv_block function,
1241
+ and implement skip connections in <forward> function.
1242
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
1243
+ """
1244
+ super(ResnetBlock, self).__init__()
1245
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
1246
+
1247
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
1248
+ """Construct a convolutional block.
1249
+
1250
+ Parameters:
1251
+ dim (int) -- the number of channels in the conv layer.
1252
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
1253
+ norm_layer -- normalization layer
1254
+ use_dropout (bool) -- if use dropout layers.
1255
+ use_bias (bool) -- if the conv layer uses bias or not
1256
+
1257
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
1258
+ """
1259
+ conv_block = []
1260
+ p = 0
1261
+ if padding_type == 'reflect':
1262
+ conv_block += [nn.ReflectionPad2d(1)]
1263
+ elif padding_type == 'replicate':
1264
+ conv_block += [nn.ReplicationPad2d(1)]
1265
+ elif padding_type == 'zero':
1266
+ p = 1
1267
+ else:
1268
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1269
+
1270
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
1271
+ if use_dropout:
1272
+ conv_block += [nn.Dropout(0.5)]
1273
+
1274
+ p = 0
1275
+ if padding_type == 'reflect':
1276
+ conv_block += [nn.ReflectionPad2d(1)]
1277
+ elif padding_type == 'replicate':
1278
+ conv_block += [nn.ReplicationPad2d(1)]
1279
+ elif padding_type == 'zero':
1280
+ p = 1
1281
+ else:
1282
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1283
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
1284
+
1285
+ return nn.Sequential(*conv_block)
1286
+
1287
+ def forward(self, x):
1288
+ """Forward function (with skip connections)"""
1289
+ out = x + self.conv_block(x) # add skip connections
1290
+ return out
1291
+
1292
+
1293
+ class UnetGenerator(nn.Module):
1294
+ """Create a Unet-based generator"""
1295
+
1296
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
1297
+ """Construct a Unet generator
1298
+ Parameters:
1299
+ input_nc (int) -- the number of channels in input images
1300
+ output_nc (int) -- the number of channels in output images
1301
+ num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
1302
+ an image of size 128x128 will become 1x1 at the bottleneck
1303
+ ngf (int) -- the number of filters in the last conv layer
1304
+ norm_layer -- normalization layer
1305
+
1306
+ We construct the U-Net from the innermost layer to the outermost layer.
1307
+ It is a recursive process.
1308
+ """
1309
+ super(UnetGenerator, self).__init__()
1310
+ # construct unet structure
1311
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
1312
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
1313
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
1314
+ # gradually reduce the number of filters from ngf * 8 to ngf
1315
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1316
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1317
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1318
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
1319
+
1320
+ def forward(self, input):
1321
+ """Standard forward"""
1322
+ return self.model(input)
1323
+
1324
+
1325
+ class UnetSkipConnectionBlock(nn.Module):
1326
+ """Defines the Unet submodule with skip connection.
1327
+ X -------------------identity----------------------
1328
+ |-- downsampling -- |submodule| -- upsampling --|
1329
+ """
1330
+
1331
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
1332
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
1333
+ """Construct a Unet submodule with skip connections.
1334
+
1335
+ Parameters:
1336
+ outer_nc (int) -- the number of filters in the outer conv layer
1337
+ inner_nc (int) -- the number of filters in the inner conv layer
1338
+ input_nc (int) -- the number of channels in input images/features
1339
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
1340
+ outermost (bool) -- if this module is the outermost module
1341
+ innermost (bool) -- if this module is the innermost module
1342
+ norm_layer -- normalization layer
1343
+ use_dropout (bool) -- whether to use dropout layers.
1344
+ """
1345
+ super(UnetSkipConnectionBlock, self).__init__()
1346
+ self.outermost = outermost
1347
+ if type(norm_layer) == functools.partial:
1348
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1349
+ else:
1350
+ use_bias = norm_layer == nn.InstanceNorm2d
1351
+ if input_nc is None:
1352
+ input_nc = outer_nc
1353
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
1354
+ stride=2, padding=1, bias=use_bias)
1355
+ downrelu = nn.LeakyReLU(0.2, True)
1356
+ downnorm = norm_layer(inner_nc)
1357
+ uprelu = nn.ReLU(True)
1358
+ upnorm = norm_layer(outer_nc)
1359
+
1360
+ if outermost:
1361
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1362
+ kernel_size=4, stride=2,
1363
+ padding=1)
1364
+ down = [downconv]
1365
+ up = [uprelu, upconv, nn.Tanh()]
1366
+ model = down + [submodule] + up
1367
+ elif innermost:
1368
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
1369
+ kernel_size=4, stride=2,
1370
+ padding=1, bias=use_bias)
1371
+ down = [downrelu, downconv]
1372
+ up = [uprelu, upconv, upnorm]
1373
+ model = down + up
1374
+ else:
1375
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1376
+ kernel_size=4, stride=2,
1377
+ padding=1, bias=use_bias)
1378
+ down = [downrelu, downconv, downnorm]
1379
+ up = [uprelu, upconv, upnorm]
1380
+
1381
+ if use_dropout:
1382
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
1383
+ else:
1384
+ model = down + [submodule] + up
1385
+
1386
+ self.model = nn.Sequential(*model)
1387
+
1388
+ def forward(self, x):
1389
+ if self.outermost:
1390
+ return self.model(x)
1391
+ else: # add skip connections
1392
+ return torch.cat([x, self.model(x)], 1)
1393
+
1394
+
1395
+ class NLayerDiscriminator(nn.Module):
1396
+ """Defines a PatchGAN discriminator"""
1397
+
1398
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
1399
+ """Construct a PatchGAN discriminator
1400
+
1401
+ Parameters:
1402
+ input_nc (int) -- the number of channels in input images
1403
+ ndf (int) -- the number of filters in the last conv layer
1404
+ n_layers (int) -- the number of conv layers in the discriminator
1405
+ norm_layer -- normalization layer
1406
+ """
1407
+ super(NLayerDiscriminator, self).__init__()
1408
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1409
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1410
+ else:
1411
+ use_bias = norm_layer == nn.InstanceNorm2d
1412
+
1413
+ kw = 4
1414
+ padw = 1
1415
+ if(no_antialias):
1416
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
1417
+ else:
1418
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
1419
+ nf_mult = 1
1420
+ nf_mult_prev = 1
1421
+ for n in range(1, n_layers): # gradually increase the number of filters
1422
+ nf_mult_prev = nf_mult
1423
+ nf_mult = min(2 ** n, 8)
1424
+ if(no_antialias):
1425
+ sequence += [
1426
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1427
+ norm_layer(ndf * nf_mult),
1428
+ nn.LeakyReLU(0.2, True)
1429
+ ]
1430
+ else:
1431
+ sequence += [
1432
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1433
+ norm_layer(ndf * nf_mult),
1434
+ nn.LeakyReLU(0.2, True),
1435
+ Downsample(ndf * nf_mult)]
1436
+
1437
+ nf_mult_prev = nf_mult
1438
+ nf_mult = min(2 ** n_layers, 8)
1439
+ sequence += [
1440
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1441
+ norm_layer(ndf * nf_mult),
1442
+ nn.LeakyReLU(0.2, True)
1443
+ ]
1444
+
1445
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
1446
+ self.model = nn.Sequential(*sequence)
1447
+
1448
+ def forward(self, input):
1449
+ """Standard forward."""
1450
+ return self.model(input)
1451
+
1452
+
1453
+ class PixelDiscriminator(nn.Module):
1454
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
1455
+
1456
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
1457
+ """Construct a 1x1 PatchGAN discriminator
1458
+
1459
+ Parameters:
1460
+ input_nc (int) -- the number of channels in input images
1461
+ ndf (int) -- the number of filters in the last conv layer
1462
+ norm_layer -- normalization layer
1463
+ """
1464
+ super(PixelDiscriminator, self).__init__()
1465
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1466
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1467
+ else:
1468
+ use_bias = norm_layer == nn.InstanceNorm2d
1469
+
1470
+ self.net = [
1471
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
1472
+ nn.LeakyReLU(0.2, True),
1473
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
1474
+ norm_layer(ndf * 2),
1475
+ nn.LeakyReLU(0.2, True),
1476
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
1477
+
1478
+ self.net = nn.Sequential(*self.net)
1479
+
1480
+ def forward(self, input):
1481
+ """Standard forward."""
1482
+ return self.net(input)
1483
+
1484
+
1485
+ class PatchDiscriminator(NLayerDiscriminator):
1486
+ """Defines a PatchGAN discriminator"""
1487
+
1488
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
1489
+ super().__init__(input_nc, ndf, 2, norm_layer, no_antialias)
1490
+
1491
+ def forward(self, input):
1492
+ B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
1493
+ size = 16
1494
+ Y = H // size
1495
+ X = W // size
1496
+ input = input.view(B, C, Y, size, X, size)
1497
+ input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
1498
+ return super().forward(input)
1499
+
1500
+
1501
+ class GroupedChannelNorm(nn.Module):
1502
+ def __init__(self, num_groups):
1503
+ super().__init__()
1504
+ self.num_groups = num_groups
1505
+
1506
+ def forward(self, x):
1507
+ shape = list(x.shape)
1508
+ new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:]
1509
+ x = x.view(*new_shape)
1510
+ mean = x.mean(dim=2, keepdim=True)
1511
+ std = x.std(dim=2, keepdim=True)
1512
+ x_norm = (x - mean) / (std + 1e-7)
1513
+ return x_norm.view(*shape)
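A minimal shape sanity check for the generator and discriminator blocks above. It assumes these classes are importable as Scenimefy.models.networks (the module path and the chosen sizes are illustrative, not the project's own test code):

```python
# Hedged sketch: quick shape check of UnetGenerator and NLayerDiscriminator.
# Assumes the classes above live in Scenimefy/models/networks.py and the package is importable.
import torch
from Scenimefy.models.networks import UnetGenerator, NLayerDiscriminator

netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=7).eval()      # 128x128 -> 1x1 at the bottleneck
netD = NLayerDiscriminator(input_nc=3, n_layers=3, no_antialias=True)  # classic stride-2 PatchGAN

with torch.no_grad():
    fake = netG(torch.randn(1, 3, 128, 128))   # output keeps the input resolution
    pred = netD(torch.randn(1, 3, 256, 256))   # one-channel patch prediction map

print(fake.shape)   # torch.Size([1, 3, 128, 128])
print(pred.shape)   # torch.Size([1, 1, 30, 30])
```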
Scenimefy/models/patchnce.py ADDED
@@ -0,0 +1,57 @@
1
+ from packaging import version
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PatchNCELoss(nn.Module):
7
+ def __init__(self, opt):
8
+ super().__init__()
9
+ self.opt = opt
10
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
11
+ self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
12
+
13
+ def forward(self, feat_q, feat_k, weight=None):
14
+ batchSize = feat_q.shape[0]
15
+ dim = feat_q.shape[1]
16
+ feat_k = feat_k.detach()
17
+
18
+ # pos logit
19
+ l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
20
+ l_pos = l_pos.view(batchSize, 1)
21
+
22
+ # neg logit
23
+
24
+ # Should the negatives from the other samples of a minibatch be utilized?
25
+ # In CUT and FastCUT, we found that it's best to only include negatives
26
+ # from the same image. Therefore, we set
27
+ # --nce_includes_all_negatives_from_minibatch as False
28
+ # However, for single-image translation, the minibatch consists of
29
+ # crops from the "same" high-resolution image.
30
+ # Therefore, we will include the negatives from the entire minibatch.
31
+ if self.opt.nce_includes_all_negatives_from_minibatch:
32
+ # reshape features as if they are all negatives of minibatch of size 1.
33
+ batch_dim_for_bmm = 1
34
+ else:
35
+ batch_dim_for_bmm = self.opt.batch_size
36
+
37
+ # reshape features to batch size
38
+ feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
39
+ feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
40
+ npatches = feat_q.size(1)
41
+ l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
42
+
43
+ if weight is not None:
44
+ l_neg_curbatch *= weight
45
+
46
+ # diagonal entries are similarity between same features, and hence meaningless.
47
+ # just fill the diagonal with very small number, which is exp(-10) and almost zero
48
+ diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
49
+ l_neg_curbatch.masked_fill_(diagonal, -10.0)
50
+ l_neg = l_neg_curbatch.view(-1, npatches)
51
+
52
+ out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
53
+
54
+ loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
55
+ device=feat_q.device))
56
+
57
+ return loss
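A hedged usage sketch for PatchNCELoss: the Namespace fields below (nce_T, batch_size, nce_includes_all_negatives_from_minibatch) are the attributes the loss reads from `opt`, while the feature shapes and the L2 normalization are illustrative assumptions:

```python
# Hedged sketch: computing the patch-wise contrastive loss on dummy features.
from argparse import Namespace
import torch
import torch.nn.functional as F
from Scenimefy.models.patchnce import PatchNCELoss

opt = Namespace(nce_T=0.07, batch_size=1,
                nce_includes_all_negatives_from_minibatch=False)
criterion = PatchNCELoss(opt)

num_patches, dim = 256, 256
feat_q = F.normalize(torch.randn(num_patches, dim), dim=1)  # query patch features
feat_k = F.normalize(torch.randn(num_patches, dim), dim=1)  # key patch features (detached inside)

loss = criterion(feat_q, feat_k)        # one loss value per query patch (reduction='none')
print(loss.shape, loss.mean().item())   # torch.Size([256]) and the averaged NCE loss
```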
Scenimefy/models/stylegan_networks.py ADDED
@@ -0,0 +1,914 @@
1
+ """
2
+ The network architectures are based on the PyTorch implementation of StyleGAN2.
3
+ Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch
4
+ Original StyleGAN2 paper: https://github.com/NVlabs/stylegan2
5
+ We use this network architecture for our single-image training setting.
6
+ """
7
+
8
+ import math
9
+ import numpy as np
10
+ import random
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+
17
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
18
+ return F.leaky_relu(input + bias, negative_slope) * scale
19
+
20
+
21
+ class FusedLeakyReLU(nn.Module):
22
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
23
+ super().__init__()
24
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
25
+ self.negative_slope = negative_slope
26
+ self.scale = scale
27
+
28
+ def forward(self, input):
29
+ # print("FusedLeakyReLU: ", input.abs().mean())
30
+ out = fused_leaky_relu(input, self.bias,
31
+ self.negative_slope,
32
+ self.scale)
33
+ # print("FusedLeakyReLU: ", out.abs().mean())
34
+ return out
35
+
36
+
37
+ def upfirdn2d_native(
38
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
39
+ ):
40
+ _, minor, in_h, in_w = input.shape
41
+ kernel_h, kernel_w = kernel.shape
42
+
43
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
44
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
45
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
46
+
47
+ out = F.pad(
48
+ out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
49
+ )
50
+ out = out[
51
+ :,
52
+ :,
53
+ max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
54
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),
55
+ ]
56
+
57
+ # out = out.permute(0, 3, 1, 2)
58
+ out = out.reshape(
59
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
60
+ )
61
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
62
+ out = F.conv2d(out, w)
63
+ out = out.reshape(
64
+ -1,
65
+ minor,
66
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
67
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
68
+ )
69
+ # out = out.permute(0, 2, 3, 1)
70
+
71
+ return out[:, :, ::down_y, ::down_x]
72
+
73
+
74
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
75
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
76
+
77
+
78
+ class PixelNorm(nn.Module):
79
+ def __init__(self):
80
+ super().__init__()
81
+
82
+ def forward(self, input):
83
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
84
+
85
+
86
+ def make_kernel(k):
87
+ k = torch.tensor(k, dtype=torch.float32)
88
+
89
+ if len(k.shape) == 1:
90
+ k = k[None, :] * k[:, None]
91
+
92
+ k /= k.sum()
93
+
94
+ return k
95
+
96
+
97
+ class Upsample(nn.Module):
98
+ def __init__(self, kernel, factor=2):
99
+ super().__init__()
100
+
101
+ self.factor = factor
102
+ kernel = make_kernel(kernel) * (factor ** 2)
103
+ self.register_buffer('kernel', kernel)
104
+
105
+ p = kernel.shape[0] - factor
106
+
107
+ pad0 = (p + 1) // 2 + factor - 1
108
+ pad1 = p // 2
109
+
110
+ self.pad = (pad0, pad1)
111
+
112
+ def forward(self, input):
113
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
114
+
115
+ return out
116
+
117
+
118
+ class Downsample(nn.Module):
119
+ def __init__(self, kernel, factor=2):
120
+ super().__init__()
121
+
122
+ self.factor = factor
123
+ kernel = make_kernel(kernel)
124
+ self.register_buffer('kernel', kernel)
125
+
126
+ p = kernel.shape[0] - factor
127
+
128
+ pad0 = (p + 1) // 2
129
+ pad1 = p // 2
130
+
131
+ self.pad = (pad0, pad1)
132
+
133
+ def forward(self, input):
134
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
135
+
136
+ return out
137
+
138
+
139
+ class Blur(nn.Module):
140
+ def __init__(self, kernel, pad, upsample_factor=1):
141
+ super().__init__()
142
+
143
+ kernel = make_kernel(kernel)
144
+
145
+ if upsample_factor > 1:
146
+ kernel = kernel * (upsample_factor ** 2)
147
+
148
+ self.register_buffer('kernel', kernel)
149
+
150
+ self.pad = pad
151
+
152
+ def forward(self, input):
153
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
154
+
155
+ return out
156
+
157
+
158
+ class EqualConv2d(nn.Module):
159
+ def __init__(
160
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
161
+ ):
162
+ super().__init__()
163
+
164
+ self.weight = nn.Parameter(
165
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
166
+ )
167
+ self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))
168
+
169
+ self.stride = stride
170
+ self.padding = padding
171
+
172
+ if bias:
173
+ self.bias = nn.Parameter(torch.zeros(out_channel))
174
+
175
+ else:
176
+ self.bias = None
177
+
178
+ def forward(self, input):
179
+ # print("Before EqualConv2d: ", input.abs().mean())
180
+ out = F.conv2d(
181
+ input,
182
+ self.weight * self.scale,
183
+ bias=self.bias,
184
+ stride=self.stride,
185
+ padding=self.padding,
186
+ )
187
+ # print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean())
188
+
189
+ return out
190
+
191
+ def __repr__(self):
192
+ return (
193
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
194
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
195
+ )
196
+
197
+
198
+ class EqualLinear(nn.Module):
199
+ def __init__(
200
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
201
+ ):
202
+ super().__init__()
203
+
204
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
205
+
206
+ if bias:
207
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
208
+
209
+ else:
210
+ self.bias = None
211
+
212
+ self.activation = activation
213
+
214
+ self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul
215
+ self.lr_mul = lr_mul
216
+
217
+ def forward(self, input):
218
+ if self.activation:
219
+ out = F.linear(input, self.weight * self.scale)
220
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
221
+
222
+ else:
223
+ out = F.linear(
224
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
225
+ )
226
+
227
+ return out
228
+
229
+ def __repr__(self):
230
+ return (
231
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
232
+ )
233
+
234
+
235
+ class ScaledLeakyReLU(nn.Module):
236
+ def __init__(self, negative_slope=0.2):
237
+ super().__init__()
238
+
239
+ self.negative_slope = negative_slope
240
+
241
+ def forward(self, input):
242
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
243
+
244
+ return out * math.sqrt(2)
245
+
246
+
247
+ class ModulatedConv2d(nn.Module):
248
+ def __init__(
249
+ self,
250
+ in_channel,
251
+ out_channel,
252
+ kernel_size,
253
+ style_dim,
254
+ demodulate=True,
255
+ upsample=False,
256
+ downsample=False,
257
+ blur_kernel=[1, 3, 3, 1],
258
+ ):
259
+ super().__init__()
260
+
261
+ self.eps = 1e-8
262
+ self.kernel_size = kernel_size
263
+ self.in_channel = in_channel
264
+ self.out_channel = out_channel
265
+ self.upsample = upsample
266
+ self.downsample = downsample
267
+
268
+ if upsample:
269
+ factor = 2
270
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
271
+ pad0 = (p + 1) // 2 + factor - 1
272
+ pad1 = p // 2 + 1
273
+
274
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
275
+
276
+ if downsample:
277
+ factor = 2
278
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
279
+ pad0 = (p + 1) // 2
280
+ pad1 = p // 2
281
+
282
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
283
+
284
+ fan_in = in_channel * kernel_size ** 2
285
+ self.scale = math.sqrt(1) / math.sqrt(fan_in)
286
+ self.padding = kernel_size // 2
287
+
288
+ self.weight = nn.Parameter(
289
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
290
+ )
291
+
292
+ if style_dim is not None and style_dim > 0:
293
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
294
+
295
+ self.demodulate = demodulate
296
+
297
+ def __repr__(self):
298
+ return (
299
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
300
+ f'upsample={self.upsample}, downsample={self.downsample})'
301
+ )
302
+
303
+ def forward(self, input, style):
304
+ batch, in_channel, height, width = input.shape
305
+
306
+ if style is not None:
307
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
308
+ else:
309
+ style = torch.ones(batch, 1, in_channel, 1, 1).cuda()
310
+ weight = self.scale * self.weight * style
311
+
312
+ if self.demodulate:
313
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
314
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
315
+
316
+ weight = weight.view(
317
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
318
+ )
319
+
320
+ if self.upsample:
321
+ input = input.view(1, batch * in_channel, height, width)
322
+ weight = weight.view(
323
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
324
+ )
325
+ weight = weight.transpose(1, 2).reshape(
326
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
327
+ )
328
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
329
+ _, _, height, width = out.shape
330
+ out = out.view(batch, self.out_channel, height, width)
331
+ out = self.blur(out)
332
+
333
+ elif self.downsample:
334
+ input = self.blur(input)
335
+ _, _, height, width = input.shape
336
+ input = input.view(1, batch * in_channel, height, width)
337
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
338
+ _, _, height, width = out.shape
339
+ out = out.view(batch, self.out_channel, height, width)
340
+
341
+ else:
342
+ input = input.view(1, batch * in_channel, height, width)
343
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
344
+ _, _, height, width = out.shape
345
+ out = out.view(batch, self.out_channel, height, width)
346
+
347
+ return out
348
+
349
+
350
+ class NoiseInjection(nn.Module):
351
+ def __init__(self):
352
+ super().__init__()
353
+
354
+ self.weight = nn.Parameter(torch.zeros(1))
355
+
356
+ def forward(self, image, noise=None):
357
+ if noise is None:
358
+ batch, _, height, width = image.shape
359
+ noise = image.new_empty(batch, 1, height, width).normal_()
360
+
361
+ return image + self.weight * noise
362
+
363
+
364
+ class ConstantInput(nn.Module):
365
+ def __init__(self, channel, size=4):
366
+ super().__init__()
367
+
368
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
369
+
370
+ def forward(self, input):
371
+ batch = input.shape[0]
372
+ out = self.input.repeat(batch, 1, 1, 1)
373
+
374
+ return out
375
+
376
+
377
+ class StyledConv(nn.Module):
378
+ def __init__(
379
+ self,
380
+ in_channel,
381
+ out_channel,
382
+ kernel_size,
383
+ style_dim=None,
384
+ upsample=False,
385
+ blur_kernel=[1, 3, 3, 1],
386
+ demodulate=True,
387
+ inject_noise=True,
388
+ ):
389
+ super().__init__()
390
+
391
+ self.inject_noise = inject_noise
392
+ self.conv = ModulatedConv2d(
393
+ in_channel,
394
+ out_channel,
395
+ kernel_size,
396
+ style_dim,
397
+ upsample=upsample,
398
+ blur_kernel=blur_kernel,
399
+ demodulate=demodulate,
400
+ )
401
+
402
+ self.noise = NoiseInjection()
403
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
404
+ # self.activate = ScaledLeakyReLU(0.2)
405
+ self.activate = FusedLeakyReLU(out_channel)
406
+
407
+ def forward(self, input, style=None, noise=None):
408
+ out = self.conv(input, style)
409
+ if self.inject_noise:
410
+ out = self.noise(out, noise=noise)
411
+ # out = out + self.bias
412
+ out = self.activate(out)
413
+
414
+ return out
415
+
416
+
417
+ class ToRGB(nn.Module):
418
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
419
+ super().__init__()
420
+
421
+ if upsample:
422
+ self.upsample = Upsample(blur_kernel)
423
+
424
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
425
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
426
+
427
+ def forward(self, input, style, skip=None):
428
+ out = self.conv(input, style)
429
+ out = out + self.bias
430
+
431
+ if skip is not None:
432
+ skip = self.upsample(skip)
433
+
434
+ out = out + skip
435
+
436
+ return out
437
+
438
+
439
+ class Generator(nn.Module):
440
+ def __init__(
441
+ self,
442
+ size,
443
+ style_dim,
444
+ n_mlp,
445
+ channel_multiplier=2,
446
+ blur_kernel=[1, 3, 3, 1],
447
+ lr_mlp=0.01,
448
+ ):
449
+ super().__init__()
450
+
451
+ self.size = size
452
+
453
+ self.style_dim = style_dim
454
+
455
+ layers = [PixelNorm()]
456
+
457
+ for i in range(n_mlp):
458
+ layers.append(
459
+ EqualLinear(
460
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
461
+ )
462
+ )
463
+
464
+ self.style = nn.Sequential(*layers)
465
+
466
+ self.channels = {
467
+ 4: 512,
468
+ 8: 512,
469
+ 16: 512,
470
+ 32: 512,
471
+ 64: 256 * channel_multiplier,
472
+ 128: 128 * channel_multiplier,
473
+ 256: 64 * channel_multiplier,
474
+ 512: 32 * channel_multiplier,
475
+ 1024: 16 * channel_multiplier,
476
+ }
477
+
478
+ self.input = ConstantInput(self.channels[4])
479
+ self.conv1 = StyledConv(
480
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
481
+ )
482
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
483
+
484
+ self.log_size = int(math.log(size, 2))
485
+ self.num_layers = (self.log_size - 2) * 2 + 1
486
+
487
+ self.convs = nn.ModuleList()
488
+ self.upsamples = nn.ModuleList()
489
+ self.to_rgbs = nn.ModuleList()
490
+ self.noises = nn.Module()
491
+
492
+ in_channel = self.channels[4]
493
+
494
+ for layer_idx in range(self.num_layers):
495
+ res = (layer_idx + 5) // 2
496
+ shape = [1, 1, 2 ** res, 2 ** res]
497
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
498
+
499
+ for i in range(3, self.log_size + 1):
500
+ out_channel = self.channels[2 ** i]
501
+
502
+ self.convs.append(
503
+ StyledConv(
504
+ in_channel,
505
+ out_channel,
506
+ 3,
507
+ style_dim,
508
+ upsample=True,
509
+ blur_kernel=blur_kernel,
510
+ )
511
+ )
512
+
513
+ self.convs.append(
514
+ StyledConv(
515
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
516
+ )
517
+ )
518
+
519
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
520
+
521
+ in_channel = out_channel
522
+
523
+ self.n_latent = self.log_size * 2 - 2
524
+
525
+ def make_noise(self):
526
+ device = self.input.input.device
527
+
528
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
529
+
530
+ for i in range(3, self.log_size + 1):
531
+ for _ in range(2):
532
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
533
+
534
+ return noises
535
+
536
+ def mean_latent(self, n_latent):
537
+ latent_in = torch.randn(
538
+ n_latent, self.style_dim, device=self.input.input.device
539
+ )
540
+ latent = self.style(latent_in).mean(0, keepdim=True)
541
+
542
+ return latent
543
+
544
+ def get_latent(self, input):
545
+ return self.style(input)
546
+
547
+ def forward(
548
+ self,
549
+ styles,
550
+ return_latents=False,
551
+ inject_index=None,
552
+ truncation=1,
553
+ truncation_latent=None,
554
+ input_is_latent=False,
555
+ noise=None,
556
+ randomize_noise=True,
557
+ ):
558
+ if not input_is_latent:
559
+ styles = [self.style(s) for s in styles]
560
+
561
+ if noise is None:
562
+ if randomize_noise:
563
+ noise = [None] * self.num_layers
564
+ else:
565
+ noise = [
566
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
567
+ ]
568
+
569
+ if truncation < 1:
570
+ style_t = []
571
+
572
+ for style in styles:
573
+ style_t.append(
574
+ truncation_latent + truncation * (style - truncation_latent)
575
+ )
576
+
577
+ styles = style_t
578
+
579
+ if len(styles) < 2:
580
+ inject_index = self.n_latent
581
+
582
+ if len(styles[0].shape) < 3:
583
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
584
+
585
+ else:
586
+ latent = styles[0]
587
+
588
+ else:
589
+ if inject_index is None:
590
+ inject_index = random.randint(1, self.n_latent - 1)
591
+
592
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
593
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
594
+
595
+ latent = torch.cat([latent, latent2], 1)
596
+
597
+ out = self.input(latent)
598
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
599
+
600
+ skip = self.to_rgb1(out, latent[:, 1])
601
+
602
+ i = 1
603
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
604
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
605
+ ):
606
+ out = conv1(out, latent[:, i], noise=noise1)
607
+ out = conv2(out, latent[:, i + 1], noise=noise2)
608
+ skip = to_rgb(out, latent[:, i + 2], skip)
609
+
610
+ i += 2
611
+
612
+ image = skip
613
+
614
+ if return_latents:
615
+ return image, latent
616
+
617
+ else:
618
+ return image, None
619
+
620
+
621
+ class ConvLayer(nn.Sequential):
622
+ def __init__(
623
+ self,
624
+ in_channel,
625
+ out_channel,
626
+ kernel_size,
627
+ downsample=False,
628
+ blur_kernel=[1, 3, 3, 1],
629
+ bias=True,
630
+ activate=True,
631
+ ):
632
+ layers = []
633
+
634
+ if downsample:
635
+ factor = 2
636
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
637
+ pad0 = (p + 1) // 2
638
+ pad1 = p // 2
639
+
640
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
641
+
642
+ stride = 2
643
+ self.padding = 0
644
+
645
+ else:
646
+ stride = 1
647
+ self.padding = kernel_size // 2
648
+
649
+ layers.append(
650
+ EqualConv2d(
651
+ in_channel,
652
+ out_channel,
653
+ kernel_size,
654
+ padding=self.padding,
655
+ stride=stride,
656
+ bias=bias and not activate,
657
+ )
658
+ )
659
+
660
+ if activate:
661
+ if bias:
662
+ layers.append(FusedLeakyReLU(out_channel))
663
+
664
+ else:
665
+ layers.append(ScaledLeakyReLU(0.2))
666
+
667
+ super().__init__(*layers)
668
+
669
+
670
+ class ResBlock(nn.Module):
671
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):
672
+ super().__init__()
673
+
674
+ self.skip_gain = skip_gain
675
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
676
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)
677
+
678
+ if in_channel != out_channel or downsample:
679
+ self.skip = ConvLayer(
680
+ in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
681
+ )
682
+ else:
683
+ self.skip = nn.Identity()
684
+
685
+ def forward(self, input):
686
+ out = self.conv1(input)
687
+ out = self.conv2(out)
688
+
689
+ skip = self.skip(input)
690
+ out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)
691
+
692
+ return out
693
+
694
+
695
+ class StyleGAN2Discriminator(nn.Module):
696
+ def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):
697
+ super().__init__()
698
+ self.opt = opt
699
+ self.stddev_group = 16
700
+ if size is None:
701
+ size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
702
+ if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
703
+ size = 2 ** int(np.log2(self.opt.D_patch_size))
704
+
705
+ blur_kernel = [1, 3, 3, 1]
706
+ channel_multiplier = ndf / 64
707
+ channels = {
708
+ 4: min(384, int(4096 * channel_multiplier)),
709
+ 8: min(384, int(2048 * channel_multiplier)),
710
+ 16: min(384, int(1024 * channel_multiplier)),
711
+ 32: min(384, int(512 * channel_multiplier)),
712
+ 64: int(256 * channel_multiplier),
713
+ 128: int(128 * channel_multiplier),
714
+ 256: int(64 * channel_multiplier),
715
+ 512: int(32 * channel_multiplier),
716
+ 1024: int(16 * channel_multiplier),
717
+ }
718
+
719
+ convs = [ConvLayer(3, channels[size], 1)]
720
+
721
+ log_size = int(math.log(size, 2))
722
+
723
+ in_channel = channels[size]
724
+
725
+ if "smallpatch" in self.opt.netD:
726
+ final_res_log2 = 4
727
+ elif "patch" in self.opt.netD:
728
+ final_res_log2 = 3
729
+ else:
730
+ final_res_log2 = 2
731
+
732
+ for i in range(log_size, final_res_log2, -1):
733
+ out_channel = channels[2 ** (i - 1)]
734
+
735
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
736
+
737
+ in_channel = out_channel
738
+
739
+ self.convs = nn.Sequential(*convs)
740
+
741
+ if False and "tile" in self.opt.netD:
742
+ in_channel += 1
743
+ self.final_conv = ConvLayer(in_channel, channels[4], 3)
744
+ if "patch" in self.opt.netD:
745
+ self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)
746
+ else:
747
+ self.final_linear = nn.Sequential(
748
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
749
+ EqualLinear(channels[4], 1),
750
+ )
751
+
752
+ def forward(self, input, get_minibatch_features=False):
753
+ if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
754
+ h, w = input.size(2), input.size(3)
755
+ y = torch.randint(h - self.opt.D_patch_size, ())
756
+ x = torch.randint(w - self.opt.D_patch_size, ())
757
+ input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]
758
+ out = input
759
+ for i, conv in enumerate(self.convs):
760
+ out = conv(out)
761
+ # print(i, out.abs().mean())
762
+ # out = self.convs(input)
763
+
764
+ batch, channel, height, width = out.shape
765
+
766
+ if False and "tile" in self.opt.netD:
767
+ group = min(batch, self.stddev_group)
768
+ stddev = out.view(
769
+ group, -1, 1, channel // 1, height, width
770
+ )
771
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
772
+ stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
773
+ stddev = stddev.repeat(group, 1, height, width)
774
+ out = torch.cat([out, stddev], 1)
775
+
776
+ out = self.final_conv(out)
777
+ # print(out.abs().mean())
778
+
779
+ if "patch" not in self.opt.netD:
780
+ out = out.view(batch, -1)
781
+ out = self.final_linear(out)
782
+
783
+ return out
784
+
785
+
786
+ class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
787
+ def forward(self, input):
788
+ B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
789
+ size = self.opt.D_patch_size
790
+ Y = H // size
791
+ X = W // size
792
+ input = input.view(B, C, Y, size, X, size)
793
+ input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
794
+ return super().forward(input)
795
+
796
+
797
+ class StyleGAN2Encoder(nn.Module):
798
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
799
+ super().__init__()
800
+ assert opt is not None
801
+ self.opt = opt
802
+ channel_multiplier = ngf / 32
803
+ channels = {
804
+ 4: min(512, int(round(4096 * channel_multiplier))),
805
+ 8: min(512, int(round(2048 * channel_multiplier))),
806
+ 16: min(512, int(round(1024 * channel_multiplier))),
807
+ 32: min(512, int(round(512 * channel_multiplier))),
808
+ 64: int(round(256 * channel_multiplier)),
809
+ 128: int(round(128 * channel_multiplier)),
810
+ 256: int(round(64 * channel_multiplier)),
811
+ 512: int(round(32 * channel_multiplier)),
812
+ 1024: int(round(16 * channel_multiplier)),
813
+ }
814
+
815
+ blur_kernel = [1, 3, 3, 1]
816
+
817
+ cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
818
+ convs = [nn.Identity(),
819
+ ConvLayer(3, channels[cur_res], 1)]
820
+
821
+ num_downsampling = self.opt.stylegan2_G_num_downsampling
822
+ for i in range(num_downsampling):
823
+ in_channel = channels[cur_res]
824
+ out_channel = channels[cur_res // 2]
825
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))
826
+ cur_res = cur_res // 2
827
+
828
+ for i in range(n_blocks // 2):
829
+ n_channel = channels[cur_res]
830
+ convs.append(ResBlock(n_channel, n_channel, downsample=False))
831
+
832
+ self.convs = nn.Sequential(*convs)
833
+
834
+ def forward(self, input, layers=[], get_features=False):
835
+ feat = input
836
+ feats = []
837
+ if -1 in layers:
838
+ layers.append(len(self.convs) - 1)
839
+ for layer_id, layer in enumerate(self.convs):
840
+ feat = layer(feat)
841
+ # print(layer_id, " features ", feat.abs().mean())
842
+ if layer_id in layers:
843
+ feats.append(feat)
844
+
845
+ if get_features:
846
+ return feat, feats
847
+ else:
848
+ return feat
849
+
850
+
851
+ class StyleGAN2Decoder(nn.Module):
852
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
853
+ super().__init__()
854
+ assert opt is not None
855
+ self.opt = opt
856
+
857
+ blur_kernel = [1, 3, 3, 1]
858
+
859
+ channel_multiplier = ngf / 32
860
+ channels = {
861
+ 4: min(512, int(round(4096 * channel_multiplier))),
862
+ 8: min(512, int(round(2048 * channel_multiplier))),
863
+ 16: min(512, int(round(1024 * channel_multiplier))),
864
+ 32: min(512, int(round(512 * channel_multiplier))),
865
+ 64: int(round(256 * channel_multiplier)),
866
+ 128: int(round(128 * channel_multiplier)),
867
+ 256: int(round(64 * channel_multiplier)),
868
+ 512: int(round(32 * channel_multiplier)),
869
+ 1024: int(round(16 * channel_multiplier)),
870
+ }
871
+
872
+ num_downsampling = self.opt.stylegan2_G_num_downsampling
873
+ cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)
874
+ convs = []
875
+
876
+ for i in range(n_blocks // 2):
877
+ n_channel = channels[cur_res]
878
+ convs.append(ResBlock(n_channel, n_channel, downsample=False))
879
+
880
+ for i in range(num_downsampling):
881
+ in_channel = channels[cur_res]
882
+ out_channel = channels[cur_res * 2]
883
+ inject_noise = "small" not in self.opt.netG
884
+ convs.append(
885
+ StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)
886
+ )
887
+ cur_res = cur_res * 2
888
+
889
+ convs.append(ConvLayer(channels[cur_res], 3, 1))
890
+
891
+ self.convs = nn.Sequential(*convs)
892
+
893
+ def forward(self, input):
894
+ return self.convs(input)
895
+
896
+
897
+ class StyleGAN2Generator(nn.Module):
898
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
899
+ super().__init__()
900
+ self.opt = opt
901
+ self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
902
+ self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
903
+
904
+ def forward(self, input, layers=[], encode_only=False):
905
+ feat, feats = self.encoder(input, layers, True)
906
+ if encode_only:
907
+ return feats
908
+ else:
909
+ fake = self.decoder(feat)
910
+
911
+ if len(layers) > 0:
912
+ return fake, feats
913
+ else:
914
+ return fake
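An illustrative sketch of pulling intermediate features from StyleGAN2Encoder, the way the patch-based losses tap encoder layers. The Namespace carries only the fields the encoder reads; the values are assumptions for a 256x256 setting:

```python
# Hedged sketch: feature extraction with the StyleGAN2-based encoder (CPU-friendly).
from argparse import Namespace
import torch
from Scenimefy.models.stylegan_networks import StyleGAN2Encoder

opt = Namespace(load_size=256, crop_size=256, stylegan2_G_num_downsampling=1)
enc = StyleGAN2Encoder(input_nc=3, output_nc=3, ngf=64, n_blocks=6, opt=opt)

x = torch.randn(1, 3, 256, 256)
feat, feats = enc(x, layers=[0, 2, 4], get_features=True)

print(feat.shape)                 # final encoder feature map, e.g. (1, 256, 128, 128)
print([f.shape for f in feats])   # features tapped at the requested layer ids
```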
Scenimefy/options/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ This package options includes option modules: training options, test options, and basic options (used in both training and test).
3
+ """
Scenimefy/options/base_options.py ADDED
@@ -0,0 +1,165 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import Scenimefy.models as models
5
+ import Scenimefy.data as data
6
+ from Scenimefy.utils import util
7
+
8
+ class BaseOptions():
9
+ """
10
+ This class defines options used during both training and test time.
11
+
12
+ It also implements several helper functions such as parsing, printing, and saving the options.
13
+ It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
14
+ """
15
+
16
+ def __init__(self, cmd_line=None):
17
+ """Reset the class; indicates the class hasn't been initailized"""
18
+ self.initialized = False
19
+ self.cmd_line = None
20
+ if cmd_line is not None:
21
+ self.cmd_line = cmd_line.split()
22
+
23
+ def initialize(self, parser):
24
+ """Define the common options that are used in both training and test."""
25
+ # basic parameters
26
+ # load unpaired dataset
27
+ parser.add_argument('--dataroot', default='Scenimefy\datasets\Sample', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
28
+ parser.add_argument('--name', type=str, default='huggingface', help='name of the experiment. It decides where to store samples and models')
29
+ parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name')
30
+ parser.add_argument('--gpu_ids', type=str, default='-1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
31
+ parser.add_argument('--checkpoints_dir', type=str, default='Scenimefy/pretrained_models', help='models are saved here')
32
+ # model parameters
33
+ parser.add_argument('--model', type=str, default='cut', help='chooses which model to use.')
34
+ parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
35
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
36
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
37
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
38
+ parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
39
+ parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture')
40
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
41
+ parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
42
+ parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
43
+ parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
44
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
45
+ parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
46
+ help='no dropout for the generator')
47
+ parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)')
48
+ parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]')
49
+ # dataset parameters
50
+ parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
51
+ parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
52
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
53
+ parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data')
54
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
55
+ parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
56
+ parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
57
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
58
+ parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
59
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
60
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
61
+ parser.add_argument('--random_scale_max', type=float, default=3.0,
62
+ help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.')
63
+ # additional parameters
64
+ parser.add_argument('--epoch', type=str, default='Shinkai', help='which epoch to load? set to latest to use latest cached model')
65
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
66
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
67
+
68
+ # parameters related to StyleGAN2-based networks
69
+ parser.add_argument('--stylegan2_G_num_downsampling',
70
+ default=1, type=int,
71
+ help='Number of downsampling layers used by StyleGAN2Generator')
72
+
73
+ self.initialized = True
74
+ return parser
75
+
76
+ def gather_options(self):
77
+ """Initialize our parser with basic options(only once).
78
+ Add additional model-specific and dataset-specific options.
79
+ These options are defined in the <modify_commandline_options> function
80
+ in model and dataset classes.
81
+ """
82
+ if not self.initialized: # check if it has been initialized
83
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
84
+ parser = self.initialize(parser)
85
+
86
+ # get the basic options
87
+ if self.cmd_line is None:
88
+ opt, _ = parser.parse_known_args()
89
+ else:
90
+ opt, _ = parser.parse_known_args(self.cmd_line)
91
+
92
+ # modify model-related parser options
93
+ model_name = opt.model
94
+ model_option_setter = models.get_option_setter(model_name)
95
+ parser = model_option_setter(parser, self.isTrain)
96
+ if self.cmd_line is None:
97
+ opt, _ = parser.parse_known_args() # parse again with new defaults
98
+ else:
99
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
100
+
101
+ # modify dataset-related parser options
102
+ dataset_name = opt.dataset_mode
103
+ dataset_option_setter = data.get_option_setter(dataset_name)
104
+ parser = dataset_option_setter(parser, self.isTrain)
105
+
106
+ # save and return the parser
107
+ self.parser = parser
108
+ if self.cmd_line is None:
109
+ return parser.parse_args()
110
+ else:
111
+ return parser.parse_args(self.cmd_line)
112
+
113
+ def print_options(self, opt):
114
+ """Print and save options
115
+
116
+ It will print both current options and default values (if different).
117
+ It will save options into a text file: [checkpoints_dir] / opt.txt
118
+ """
119
+ message = ''
120
+ message += '----------------- Options ---------------\n'
121
+ for k, v in sorted(vars(opt).items()):
122
+ comment = ''
123
+ default = self.parser.get_default(k)
124
+ if v != default:
125
+ comment = '\t[default: %s]' % str(default)
126
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
127
+ message += '----------------- End -------------------'
128
+ print(message)
129
+
130
+ # save to the disk
131
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
132
+ util.mkdirs(expr_dir)
133
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
134
+ try:
135
+ with open(file_name, 'wt') as opt_file:
136
+ opt_file.write(message)
137
+ opt_file.write('\n')
138
+ except PermissionError as error:
139
+ print("permission error {}".format(error))
140
+ pass
141
+
142
+ def parse(self):
143
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
144
+ opt = self.gather_options()
145
+ opt.isTrain = self.isTrain # train or test
146
+
147
+ # process opt.suffix
148
+ if opt.suffix:
149
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
150
+ opt.name = opt.name + suffix
151
+
152
+ self.print_options(opt)
153
+
154
+ # set gpu ids
155
+ str_ids = opt.gpu_ids.split(',')
156
+ opt.gpu_ids = []
157
+ for str_id in str_ids:
158
+ id = int(str_id)
159
+ if id >= 0:
160
+ opt.gpu_ids.append(id)
161
+ if len(opt.gpu_ids) > 0:
162
+ torch.cuda.set_device(opt.gpu_ids[0])
163
+
164
+ self.opt = opt
165
+ return self.opt
Scenimefy/options/test_options.py ADDED
@@ -0,0 +1,22 @@
1
+ from Scenimefy.options.base_options import BaseOptions
2
+
3
+
4
+ class TestOptions(BaseOptions):
5
+ """
6
+ This class includes test options.
7
+
8
+ It also includes shared options defined in BaseOptions.
9
+ """
10
+
11
+ def initialize(self, parser):
12
+ parser = BaseOptions.initialize(self, parser) # define shared options
13
+ parser.add_argument('--results_dir', type=str, default='Scenimefy/results/', help='saves results here.')
14
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
15
+ # Dropout and Batchnorm have different behavior during training and test.
16
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
17
+ parser.add_argument('--num_test', type=int, default=1000, help='how many test images to run')
18
+
19
+ # To avoid cropping, the load_size should be the same as crop_size
20
+ parser.set_defaults(load_size=parser.get_default('crop_size'))
21
+ self.isTrain = False
22
+ return parser
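A hedged example of driving these options programmatically, e.g. from a demo script: TestOptions accepts a command-line string, folds in the model- and dataset-specific defaults, and returns a Namespace. The flag values are illustrative; parse() also writes <checkpoints_dir>/<name>/test_opt.txt, the file recorded later in this commit:

```python
# Hedged sketch: building test-time options without a real command line.
from Scenimefy.options.test_options import TestOptions

cmd = '--dataroot Scenimefy/datasets/Sample --name huggingface --model cut --epoch Shinkai'
opt = TestOptions(cmd_line=cmd).parse()

# Expected to match the recorded test_opt.txt: resnet_9blocks 256 False
print(opt.netG, opt.crop_size, opt.isTrain)
```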
Scenimefy/pretrained_models/huggingface/Shinkai_net_G.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bdeced133287fbb95832f4342aaff399f6f73507b515118918cc27ccd98ad8c
3
+ size 45570633
Scenimefy/pretrained_models/huggingface/test_opt.txt ADDED
@@ -0,0 +1,63 @@
1
+ ----------------- Options ---------------
2
+ CUT_mode: CUT
3
+ HDCE_gamma: 1
4
+ HDCE_gamma_min: 1
5
+ alpha: 0.2
6
+ batch_size: 1
7
+ checkpoints_dir: Scenimefy/pretrained_models
8
+ crop_size: 256
9
+ dataroot: Scenimefy\datasets\Sample
10
+ dataset_mode: unaligned
11
+ dce_idt: False
12
+ direction: AtoB
13
+ display_winsize: 256
14
+ easy_label: experiment_name
15
+ epoch: Shinkai
16
+ eval: False
17
+ flip_equivariance: False
18
+ gpu_ids: -1
19
+ init_gain: 0.02
20
+ init_type: xavier
21
+ input_nc: 3
22
+ isTrain: False [default: None]
23
+ lambda_GAN: 1.0
24
+ lambda_HDCE: 1.0
25
+ lambda_SRC: 1.0
26
+ load_size: 256
27
+ max_dataset_size: inf
28
+ model: cut
29
+ n_layers_D: 3
30
+ name: huggingface
31
+ nce_T: 0.07
32
+ nce_includes_all_negatives_from_minibatch: False
33
+ nce_layers: 0,4,8,12,16
34
+ ndf: 64
35
+ netD: basic
36
+ netF: mlp_sample
37
+ netF_nc: 256
38
+ netG: resnet_9blocks
39
+ ngf: 64
40
+ no_Hneg: False
41
+ no_antialias: False
42
+ no_antialias_up: False
43
+ no_dropout: True
44
+ no_flip: False
45
+ normD: instance
46
+ normG: instance
47
+ num_patches: 256
48
+ num_test: 1000
49
+ num_threads: 0
50
+ output_nc: 3
51
+ phase: test
52
+ pool_size: 0
53
+ preprocess: none
54
+ random_scale_max: 3.0
55
+ results_dir: Scenimefy/results/
56
+ serial_batches: False
57
+ step_gamma: False
58
+ step_gamma_epoch: 200
59
+ stylegan2_G_num_downsampling: 1
60
+ suffix:
61
+ use_curriculum: False
62
+ verbose: False
63
+ ----------------- End -------------------
Scenimefy/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ This package includes a miscellaneous collection of useful helper functions.
3
+ """
4
+ from Scenimefy.utils import *
Scenimefy/utils/html.py ADDED
@@ -0,0 +1,86 @@
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ """This HTML class allows us to save images and write texts into a single HTML file.
8
+
9
+ It consists of functions such as <add_header> (add a text header to the HTML file),
10
+ <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11
+ It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12
+ """
13
+
14
+ def __init__(self, web_dir, title, refresh=0):
15
+ """Initialize the HTML classes
16
+
17
+ Parameters:
18
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19
+ title (str) -- the webpage name
20
+ refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21
+ """
22
+ self.title = title
23
+ self.web_dir = web_dir
24
+ self.img_dir = os.path.join(self.web_dir, 'images')
25
+ if not os.path.exists(self.web_dir):
26
+ os.makedirs(self.web_dir)
27
+ if not os.path.exists(self.img_dir):
28
+ os.makedirs(self.img_dir)
29
+
30
+ self.doc = dominate.document(title=title)
31
+ if refresh > 0:
32
+ with self.doc.head:
33
+ meta(http_equiv="refresh", content=str(refresh))
34
+
35
+ def get_image_dir(self):
36
+ """Return the directory that stores images"""
37
+ return self.img_dir
38
+
39
+ def add_header(self, text):
40
+ """Insert a header to the HTML file
41
+
42
+ Parameters:
43
+ text (str) -- the header text
44
+ """
45
+ with self.doc:
46
+ h3(text)
47
+
48
+ def add_images(self, ims, txts, links, width=400):
49
+ """add images to the HTML file
50
+
51
+ Parameters:
52
+ ims (str list) -- a list of image paths
53
+ txts (str list) -- a list of image names shown on the website
54
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55
+ """
56
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57
+ self.doc.add(self.t)
58
+ with self.t:
59
+ with tr():
60
+ for im, txt, link in zip(ims, txts, links):
61
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
62
+ with p():
63
+ with a(href=os.path.join('images', link)):
64
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
65
+ br()
66
+ p(txt)
67
+
68
+ def save(self):
69
+ """save the current content to the HMTL file"""
70
+ html_file = '%s/index.html' % self.web_dir
71
+ f = open(html_file, 'wt')
72
+ f.write(self.doc.render())
73
+ f.close()
74
+
75
+
76
+ if __name__ == '__main__': # we show an example usage here.
77
+ html = HTML('web/', 'test_html')
78
+ html.add_header('hello world')
79
+
80
+ ims, txts, links = [], [], []
81
+ for n in range(4):
82
+ ims.append('image_%d.png' % n)
83
+ txts.append('text_%d' % n)
84
+ links.append('image_%d.png' % n)
85
+ html.add_images(ims, txts, links)
86
+ html.save()
Scenimefy/utils/util.py ADDED
@@ -0,0 +1,168 @@
1
+ """This module contains simple helper functions """
2
+ from __future__ import print_function
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+ import importlib
8
+ import argparse
9
+ from argparse import Namespace
10
+ import torchvision
11
+
12
+
13
+ def str2bool(v):
14
+ if isinstance(v, bool):
15
+ return v
16
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
17
+ return True
18
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
19
+ return False
20
+ else:
21
+ raise argparse.ArgumentTypeError('Boolean value expected.')
22
+
23
+
24
+ def copyconf(default_opt, **kwargs):
25
+ conf = Namespace(**vars(default_opt))
26
+ for key in kwargs:
27
+ setattr(conf, key, kwargs[key])
28
+ return conf
29
+
30
+
31
+ def find_class_in_module(target_cls_name, module):
32
+ target_cls_name = target_cls_name.replace('_', '').lower()
33
+ clslib = importlib.import_module(module)
34
+ cls = None
35
+ for name, clsobj in clslib.__dict__.items():
36
+ if name.lower() == target_cls_name:
37
+ cls = clsobj
38
+
39
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
40
+
41
+ return cls
42
+
43
+
44
+ def tensor2im(input_image, imtype=np.uint8):
45
+ """Converts a Tensor array into a numpy image array.
46
+
47
+ Parameters:
48
+ input_image (tensor) -- the input image tensor array
49
+ imtype (type) -- the desired type of the converted numpy array
50
+ """
51
+ if not isinstance(input_image, np.ndarray):
52
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
53
+ image_tensor = input_image.data
54
+ else:
55
+ return input_image
56
+ image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array
57
+ if image_numpy.shape[0] == 1: # grayscale to RGB
58
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
59
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
60
+ else: # if it is a numpy array, do nothing
61
+ image_numpy = input_image
62
+ return image_numpy.astype(imtype)
63
+
64
+
65
+ def diagnose_network(net, name='network'):
66
+ """Calculate and print the mean of the average absolute gradients
67
+
68
+ Parameters:
69
+ net (torch network) -- Torch network
70
+ name (str) -- the name of the network
71
+ """
72
+ mean = 0.0
73
+ count = 0
74
+ for param in net.parameters():
75
+ if param.grad is not None:
76
+ mean += torch.mean(torch.abs(param.grad.data))
77
+ count += 1
78
+ if count > 0:
79
+ mean = mean / count
80
+ print(name)
81
+ print(mean)
82
+
83
+
84
+ def save_image(image_numpy, image_path, aspect_ratio=1.0):
85
+ """Save a numpy image to the disk
86
+
87
+ Parameters:
88
+ image_numpy (numpy array) -- input numpy array
89
+ image_path (str) -- the path of the image
90
+ """
91
+
92
+ image_pil = Image.fromarray(image_numpy)
93
+ h, w, _ = image_numpy.shape
94
+
95
+ if aspect_ratio is None:
96
+ pass
97
+ elif aspect_ratio > 1.0:
98
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
99
+ elif aspect_ratio < 1.0:
100
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
101
+ # TODO: TEST
102
+ # print(image_path)
103
+ image_pil.save(image_path)
104
+
105
+
106
+ def print_numpy(x, val=True, shp=False):
107
+ """Print the mean, min, max, median, std, and size of a numpy array
108
+
109
+ Parameters:
110
+ val (bool) -- whether to print the values of the numpy array
111
+ shp (bool) -- whether to print the shape of the numpy array
112
+ """
113
+ x = x.astype(np.float64)
114
+ if shp:
115
+ print('shape,', x.shape)
116
+ if val:
117
+ x = x.flatten()
118
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
119
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
120
+
121
+
122
+ def mkdirs(paths):
123
+ """create empty directories if they don't exist
124
+
125
+ Parameters:
126
+ paths (str list) -- a list of directory paths
127
+ """
128
+ if isinstance(paths, list) and not isinstance(paths, str):
129
+ for path in paths:
130
+ mkdir(path)
131
+ else:
132
+ mkdir(paths)
133
+
134
+
135
+ def mkdir(path):
136
+ """create a single empty directory if it doesn't exist
137
+
138
+ Parameters:
139
+ path (str) -- a single directory path
140
+ """
141
+ if not os.path.exists(path):
142
+ os.makedirs(path)
143
+
144
+
145
+ def correct_resize_label(t, size):
146
+ device = t.device
147
+ t = t.detach().cpu()
148
+ resized = []
149
+ for i in range(t.size(0)):
150
+ one_t = t[i, :1]
151
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
152
+ one_np = one_np[:, :, 0]
153
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
154
+ resized_t = torch.from_numpy(np.array(one_image)).long()
155
+ resized.append(resized_t)
156
+ return torch.stack(resized, dim=0).to(device)
157
+
158
+
159
+ def correct_resize(t, size, mode=Image.BICUBIC):
160
+ device = t.device
161
+ t = t.detach().cpu()
162
+ resized = []
163
+ for i in range(t.size(0)):
164
+ one_t = t[i:i + 1]
165
+ one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
166
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
167
+ resized.append(resized_t)
168
+ return torch.stack(resized, dim=0).to(device)
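Taken together, tensor2im and save_image cover the usual tensor-to-PNG path at test time. A minimal sketch (not part of this commit), assuming a generator output in [-1, 1] with NCHW layout:

```python
# Illustrative only: round-tripping a model output with the helpers above.
import torch
from Scenimefy.utils.util import tensor2im, save_image, mkdir

fake = torch.rand(1, 3, 256, 256) * 2 - 1   # stand-in for a generator output in [-1, 1]
img = tensor2im(fake)                       # -> HxWx3 uint8 numpy array
mkdir('Scenimefy/results')                  # create the output folder if needed
save_image(img, 'Scenimefy/results/fake_B.png')
```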
Scenimefy/utils/visualizer.py ADDED
@@ -0,0 +1,246 @@
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+ import ntpath
5
+ import time
6
+ from . import util, html
7
+ from subprocess import Popen, PIPE
8
+ import math
9
+
10
+ if sys.version_info[0] == 2:
11
+ VisdomExceptionBase = Exception
12
+ else:
13
+ VisdomExceptionBase = ConnectionError
14
+
15
+
16
+ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
17
+ """Save images to the disk.
18
+
19
+ Parameters:
20
+ webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
21
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
22
+ image_path (str) -- the string is used to create image paths
23
+ aspect_ratio (float) -- the aspect ratio of saved images
24
+ width (int) -- the images will be resized to width x width
25
+
26
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
27
+ """
28
+ image_dir = webpage.get_image_dir()
29
+ short_path = ntpath.basename(image_path[0])
30
+ name = os.path.splitext(short_path)[0]
31
+
32
+ webpage.add_header(name)
33
+ ims, txts, links = [], [], []
34
+
35
+ for label, im_data in visuals.items():
36
+ im = util.tensor2im(im_data)
37
+ image_name = '%s/%s.png' % (label, name)
38
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
39
+ save_path = os.path.join(image_dir, image_name)
40
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
41
+ ims.append(image_name)
42
+ txts.append(label)
43
+ links.append(image_name)
44
+ webpage.add_images(ims, txts, links, width=width)
45
+
46
+
47
+ class Visualizer():
48
+ """This class includes several functions that can display/save images and print/save logging information.
49
+
50
+ It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
51
+ """
52
+
53
+ def __init__(self, opt):
54
+ """Initialize the Visualizer class
55
+
56
+ Parameters:
57
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
58
+ Step 1: Cache the training/test options
59
+ Step 2: connect to a visdom server
60
+ Step 3: create an HTML object for saving HTML files
61
+ Step 4: create a logging file to store training losses
62
+ """
63
+ self.opt = opt # cache the option
64
+ if opt.display_id is None:
65
+ self.display_id = np.random.randint(100000) * 10 # just a random display id
66
+ else:
67
+ self.display_id = opt.display_id
68
+ self.use_html = opt.isTrain and not opt.no_html
69
+ self.win_size = opt.display_winsize
70
+ self.name = opt.name
71
+ self.port = opt.display_port
72
+ self.saved = False
73
+ if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
74
+ import visdom
75
+ self.plot_data = {}
76
+ self.ncols = opt.display_ncols
77
+ if "tensorboard_base_url" not in os.environ:
78
+ self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
79
+ else:
80
+ self.vis = visdom.Visdom(port=2004,
81
+ base_url=os.environ['tensorboard_base_url'] + '/visdom')
82
+ if not self.vis.check_connection():
83
+ self.create_visdom_connections()
84
+
85
+ if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
86
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
87
+ self.img_dir = os.path.join(self.web_dir, 'images')
88
+ print('create web directory %s...' % self.web_dir)
89
+ util.mkdirs([self.web_dir, self.img_dir])
90
+ # create a logging file to store training losses
91
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
92
+ with open(self.log_name, "a") as log_file:
93
+ now = time.strftime("%c")
94
+ log_file.write('================ Training Loss (%s) ================\n' % now)
95
+
96
+ def reset(self):
97
+ """Reset the self.saved status"""
98
+ self.saved = False
99
+
100
+ def create_visdom_connections(self):
101
+ """If the program could not connect to the Visdom server, this function will start a new server at port <self.port>"""
102
+ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
103
+ print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
104
+ print('Command: %s' % cmd)
105
+ Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
106
+
107
+ def display_current_results(self, visuals, epoch, save_result):
108
+ """Display current results on visdom; save current results to an HTML file.
109
+
110
+ Parameters:
111
+ visuals (OrderedDict) - - dictionary of images to display or save
112
+ epoch (int) - - the current epoch
113
+ save_result (bool) - - whether to save the current results to an HTML file
114
+ """
115
+ if self.display_id > 0: # show images in the browser using visdom
116
+ ncols = self.ncols
117
+ if ncols > 0: # show all the images in one visdom panel
118
+ ncols = min(ncols, len(visuals))
119
+ h, w = next(iter(visuals.values())).shape[:2]
120
+ table_css = """<style>
121
+ table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
122
+ table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
123
+ </style>""" % (w, h) # create a table css
124
+ # create a table of images.
125
+ title = self.name
126
+ label_html = ''
127
+ label_html_row = ''
128
+ images = []
129
+ idx = 0
130
+ for label, image in visuals.items():
131
+ image_numpy = util.tensor2im(image)
132
+ label_html_row += '<td>%s</td>' % label
133
+ images.append(image_numpy.transpose([2, 0, 1]))
134
+ idx += 1
135
+ if idx % ncols == 0:
136
+ label_html += '<tr>%s</tr>' % label_html_row
137
+ label_html_row = ''
138
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
139
+ while idx % ncols != 0:
140
+ images.append(white_image)
141
+ label_html_row += '<td></td>'
142
+ idx += 1
143
+ if label_html_row != '':
144
+ label_html += '<tr>%s</tr>' % label_html_row
145
+ try:
146
+ self.vis.images(images, ncols, 2, self.display_id + 1,
147
+ None, dict(title=title + ' images'))
148
+ label_html = '<table>%s</table>' % label_html
149
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
150
+ opts=dict(title=title + ' labels'))
151
+ except VisdomExceptionBase:
152
+ self.create_visdom_connections()
153
+
154
+ else: # show each image in a separate visdom panel;
155
+ idx = 1
156
+ try:
157
+ for label, image in visuals.items():
158
+ image_numpy = util.tensor2im(image)
159
+ self.vis.image(
160
+ image_numpy.transpose([2, 0, 1]),
161
+ self.display_id + idx,
162
+ None,
163
+ dict(title=label)
164
+ )
165
+ idx += 1
166
+ except VisdomExceptionBase:
167
+ self.create_visdom_connections()
168
+
169
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
170
+ self.saved = True
171
+ # save images to the disk
172
+ for label, image in visuals.items():
173
+ image_numpy = util.tensor2im(image)
174
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
175
+ util.save_image(image_numpy, img_path)
176
+
177
+ # update website
178
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
179
+ for n in range(epoch, 0, -1):
180
+ webpage.add_header('epoch [%d]' % n)
181
+ ims, txts, links = [], [], []
182
+
183
+ for label, image_numpy in visuals.items():
184
+ image_numpy = util.tensor2im(image_numpy)
185
+ img_path = 'epoch%.3d_%s.png' % (n, label)
186
+ ims.append(img_path)
187
+ txts.append(label)
188
+ links.append(img_path)
189
+ webpage.add_images(ims, txts, links, width=self.win_size)
190
+ webpage.save()
191
+
192
+ def plot_current_losses(self, epoch, counter_ratio, losses):
193
+ """display the current losses on visdom display: dictionary of error labels and values
194
+
195
+ Parameters:
196
+ epoch (int) -- current epoch
197
+ counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
198
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
199
+ """
200
+ if len(losses) == 0:
201
+ return
202
+
203
+ plot_name = '_'.join(list(losses.keys()))
204
+
205
+ if plot_name not in self.plot_data:
206
+ self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}
207
+
208
+ plot_data = self.plot_data[plot_name]
209
+ plot_id = list(self.plot_data.keys()).index(plot_name)
210
+
211
+ plot_data['X'].append(epoch + counter_ratio)
212
+ plot_data['Y'].append([losses[k] for k in plot_data['legend']])
213
+ try:
214
+ self.vis.line(
215
+ X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
216
+ Y=np.array(plot_data['Y']),
217
+ opts={
218
+ 'title': self.name,
219
+ 'legend': plot_data['legend'],
220
+ 'xlabel': 'epoch',
221
+ 'ylabel': 'loss'},
222
+ win=self.display_id - plot_id)
223
+ except VisdomExceptionBase:
224
+ self.create_visdom_connections()
225
+
226
+ # losses: same format as |losses| of plot_current_losses
227
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
228
+ """print current losses on console; also save the losses to the disk
229
+
230
+ Parameters:
231
+ epoch (int) -- current epoch
232
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
233
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
234
+ t_comp (float) -- computational time per data point (normalized by batch_size)
235
+ t_data (float) -- data loading time per data point (normalized by batch_size)
236
+ """
237
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
238
+ # TODO:
239
+ # lambda_pair = math.cos(math.pi/40 * (epoch - 1))
240
+ # message += '[paired weight: %d] ' % lambda_pair
241
+ for k, v in losses.items():
242
+ message += '%s: %.3f ' % (k, v)
243
+
244
+ print(message) # print the message
245
+ with open(self.log_name, "a") as log_file:
246
+ log_file.write('%s\n' % message) # save the message
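save_images ties the HTML helper and the util functions together when writing test results. A minimal sketch (illustrative only), assuming an OrderedDict of tensors like the one model.get_current_visuals() returns:

```python
# Illustrative only: saving one set of model visuals to a browsable HTML page.
from collections import OrderedDict
import torch
from Scenimefy.utils import html
from Scenimefy.utils.visualizer import save_images

visuals = OrderedDict([
    ('real_A', torch.rand(1, 3, 256, 256) * 2 - 1),   # stand-ins for model outputs
    ('fake_B', torch.rand(1, 3, 256, 256) * 2 - 1),
])
webpage = html.HTML('Scenimefy/results/web', 'test results')
save_images(webpage, visuals, ['upload.jpg'], aspect_ratio=1.0, width=256)
webpage.save()   # writes Scenimefy/results/web/index.html
```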
app.py ADDED
@@ -0,0 +1,158 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import torch
7
+ import gradio as gr
8
+
9
+ from Scenimefy.options.test_options import TestOptions
10
+ from Scenimefy.models import create_model
11
+ from Scenimefy.utils.util import tensor2im
12
+
13
+ from PIL import Image
14
+ import torchvision.transforms as transforms
15
+
16
+
17
+ def parse_args() -> argparse.Namespace:
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument('--device', type=str, default='cpu')
20
+ parser.add_argument('--theme', type=str)
21
+ parser.add_argument('--live', action='store_true')
22
+ parser.add_argument('--share', action='store_true')
23
+ parser.add_argument('--port', type=int)
24
+ parser.add_argument('--disable-queue',
25
+ dest='enable_queue',
26
+ action='store_false')
27
+ parser.add_argument('--allow-flagging', type=str, default='never')
28
+ parser.add_argument('--allow-screenshot', action='store_true')
29
+ return parser.parse_args()
30
+
31
+ TITLE = '''
32
+ Scene Stylization with <a href="https://github.com/Yuxinn-J/Scenimefy">Scenimefy</a>
33
+ '''
34
+ DESCRIPTION = '''
35
+ <div align=center>
36
+ <p>
37
+ Gradio Demo for Scenimefy.
38
+ To use it, simply upload your image, or click one of the examples to load them.
39
+ For best outcomes, please pick a natural scene image similar to the examples below.
40
+ Kindly note that our model is trained on 256x256 resolution images; using much higher resolutions might affect its performance.
41
+ Read more at the links below.
42
+ </p>
43
+ </div>
44
+ '''
45
+ EXAMPLES = [['0.png'], ['1.jpg'], ['2.png'], ['3.png'], ['4.jpg'], ['5.png'], ['6.jpg'], ['7.png'], ['8.png']]
46
+ ARTICLE = r"""
47
+ If Scenimefy is helpful, please help to ⭐ the <a href='https://github.com/Yuxinn-J/Scenimefy' target='_blank'>Github Repo</a>. Thank you!
48
+ 🤟 **Citation**
49
+ If our work is useful for your research, please consider citing:
50
+ ```bibtex
51
+ @inproceedings{jiang2023scenimefy,
52
+ title={Scenimefy: Learning to Craft Anime Scene via Semi-Supervised Image-to-Image Translation},
53
+ author={Jiang, Yuxin and Jiang, Liming and Yang, Shuai and Loy, Chen Change},
54
+ booktitle={ICCV},
55
+ year={2023}
56
+ }
57
+ ```
58
+ 🗞️ **License**
59
+ This project is licensed under <a rel="license" href="https://github.com/Yuxinn-J/Scenimefy/blob/main/LICENSE.md">S-Lab License 1.0</a>.
60
+ Redistribution and use for non-commercial purposes should follow this license.
61
+ """
62
+
63
+
64
+ model = None
65
+
66
+
67
+ def initialize():
68
+ opt = TestOptions().parse() # get test options
69
+ # os.environ["CUDA_VISIBLE_DEVICES"] = str(1)
70
+ # hard-code some parameters for test
71
+ opt.num_threads = 0 # test code only supports num_threads = 0
72
+ opt.batch_size = 1 # test code only supports batch_size = 1
73
+ opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
74
+ opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
75
+ opt.display_id = -1 # no visdom display; the test code saves the results to an HTML file.
76
+
77
+ # dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
78
+ global model
79
+ model = create_model(opt) # create a model given opt.model and other options
80
+
81
+ dummy_data = {
82
+ 'A': torch.ones(1, 3, 256, 256),
83
+ 'B': torch.ones(1, 3, 256, 256),
84
+ 'A_paths': 'upload.jpg'
85
+ }
86
+
87
+ model.data_dependent_initialize(dummy_data)
88
+ model.setup(opt) # regular setup: load and print networks; create schedulers
89
+ model.parallelize()
90
+ return model
91
+
92
+
93
+ def __make_power_2(img, base, method=Image.BICUBIC):
94
+ ow, oh = img.size
95
+ h = int(round(oh / base) * base)
96
+ w = int(round(ow / base) * base)
97
+ if h == oh and w == ow:
98
+ return img
99
+
100
+ return img.resize((w, h), method)
101
+
102
+
103
+ def get_transform():
104
+ method=Image.BICUBIC
105
+ transform_list = []
106
+ # if opt.preprocess == 'none':
107
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
108
+ transform_list += [transforms.ToTensor()]
109
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
110
+ return transforms.Compose(transform_list)
111
+
112
+
113
+ def inference(img):
114
+ transform = get_transform()
115
+ A = transform(img.convert('RGB')) # A.shape: torch.Size([3, 260, 460])
116
+ A = A.unsqueeze(0) # A.shape: torch.Size([1, 3, 260, 460])
117
+
118
+ upload_data = {
119
+ 'A': A,
120
+ 'B': torch.ones_like(A),
121
+ 'A_paths': 'upload.jpg'
122
+ }
123
+
124
+ global model
125
+ model.set_input(upload_data) # unpack data from data loader
126
+ model.test() # run inference
127
+ visuals = model.get_current_visuals()
128
+ return tensor2im(visuals['fake_B'])
129
+
130
+
131
+ def main():
132
+ args = parse_args()
133
+ args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
134
+ print('*** Now using %s.'%(args.device))
135
+
136
+ global model
137
+ model = initialize()
138
+
139
+ gr.Interface(
140
+ inference,
141
+ gr.Image(type="pil", label='Input'),
142
+ gr.Image(type="pil", label='Output').style(height=300),
143
+ theme=args.theme,
144
+ title=TITLE,
145
+ description=DESCRIPTION,
146
+ article=ARTICLE,
147
+ examples=EXAMPLES,
148
+ allow_screenshot=args.allow_screenshot,
149
+ allow_flagging=args.allow_flagging,
150
+ live=args.live
151
+ ).launch(
152
+ enable_queue=args.enable_queue,
153
+ server_port=args.port,
154
+ share=args.share
155
+ )
156
+
157
+ if __name__ == '__main__':
158
+ main()
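Besides launching the Gradio UI with `python app.py`, the same pipeline can be exercised directly. A minimal sketch (not part of this commit), assuming the pretrained weights expected by TestOptions are available and its defaults parse without extra CLI flags; the output filename is hypothetical:

```python
# Illustrative only: reuse the demo's inference pipeline without the Gradio UI.
from PIL import Image
import app

app.initialize()                              # builds the model, loads weights, sets app.model
out = app.inference(Image.open('0.png'))      # returns a numpy uint8 stylized image
Image.fromarray(out).save('0_scenimefy.png')  # hypothetical output path
```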
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ torch
2
+ torchvision
3
+ numpy
4
+ Pillow
5
+ scipy
6
+ dominate