init

- .gitattributes +36 -0
- .gitignore +1 -0
- 0.png +0 -0
- 1.jpg +0 -0
- 2.png +0 -0
- 3.png +0 -0
- 4.jpg +0 -0
- 5.png +0 -0
- 6.jpg +0 -0
- 7.png +0 -0
- 8.png +0 -0
- README.md +13 -0
- Scenimefy/data/__init__.py +153 -0
- Scenimefy/data/base_dataset.py +230 -0
- Scenimefy/data/image_folder.py +66 -0
- Scenimefy/data/unaligned_dataset.py +79 -0
- Scenimefy/models/SRC.py +79 -0
- Scenimefy/models/__init__.py +67 -0
- Scenimefy/models/base_model.py +258 -0
- Scenimefy/models/cut_model.py +370 -0
- Scenimefy/models/hDCE.py +53 -0
- Scenimefy/models/networks.py +1513 -0
- Scenimefy/models/patchnce.py +57 -0
- Scenimefy/models/stylegan_networks.py +914 -0
- Scenimefy/options/__init__.py +3 -0
- Scenimefy/options/base_options.py +165 -0
- Scenimefy/options/test_options.py +22 -0
- Scenimefy/pretrained_models/huggingface/Shinkai_net_G.pth +3 -0
- Scenimefy/pretrained_models/huggingface/test_opt.txt +63 -0
- Scenimefy/utils/__init__.py +4 -0
- Scenimefy/utils/html.py +86 -0
- Scenimefy/utils/util.py +168 -0
- Scenimefy/utils/visualizer.py +246 -0
- app.py +158 -0
- packages.txt +0 -0
- requirements.txt +6 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
6.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
*/**/__pycache__
0.png
ADDED
1.jpg
ADDED
2.png
ADDED
3.png
ADDED
4.jpg
ADDED
5.png
ADDED
6.jpg
ADDED
7.png
ADDED
8.png
ADDED
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Scenimefy
emoji: 🦀
colorFrom: gray
colorTo: blue
sdk: gradio
sdk_version: 3.41.1
app_file: app.py
pinned: false
license: other
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Scenimefy/data/__init__.py
ADDED
@@ -0,0 +1,153 @@
"""This package includes all the modules related to data loading and preprocessing

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point from data loader.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from Scenimefy.data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "Scenimefy.data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            drop_last=True if opt.isTrain else False,
        )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data


# TODO: add paired dataset (stupid implementation)
def create_paired_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomPairedDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomPairedDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset


class CustomPairedDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.paired_dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            drop_last=True if opt.isTrain else False,
        )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
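To make the naming convention used by find_dataset_using_name concrete, here is a minimal sketch of a hypothetical Scenimefy/data/dummy_dataset.py (not part of this commit); the file name, class name, and the reuse of make_dataset/get_transform are assumptions that follow the convention described in the package docstring. With such a file in place, '--dataset_mode dummy' would resolve to DummyDataset.

# Scenimefy/data/dummy_dataset.py -- hypothetical example, not in this commit
from Scenimefy.data.base_dataset import BaseDataset, get_transform
from Scenimefy.data.image_folder import make_dataset
from PIL import Image


class DummyDataset(BaseDataset):
    """Minimal dataset that find_dataset_using_name('dummy') would discover."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path).convert('RGB')
        return {'A': self.transform(img), 'A_paths': path}

    def __len__(self):
        return len(self.paths)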
Scenimefy/data/base_dataset.py
ADDED
@@ -0,0 +1,230 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess == 'resize_and_crop':
        new_h = new_w = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    flip = random.random() > 0.5

    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'fixsize' in opt.preprocess:
        transform_list.append(transforms.Resize(params["size"], method))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        if "gta2cityscapes" in opt.dataroot:
            osize[0] = opt.load_size // 2
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    elif 'scale_shortside' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))

    if 'zoom' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
        else:
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))

    if 'crop' in opt.preprocess:
        if params is None or 'crop_pos' not in params:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if 'patch' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))

    if 'trim' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))

    # if opt.preprocess == 'none':
    transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None or 'flip' not in params:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif 'flip' in params:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img

    return img.resize((w, h), method)


def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
    if factor is None:
        zoom_level = np.random.uniform(0.8, 1.0, size=[2])
    else:
        zoom_level = (factor[0], factor[1])
    iw, ih = img.size
    zoomw = max(crop_width, iw * zoom_level[0])
    zoomh = max(crop_width, ih * zoom_level[1])
    img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
    return img


def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
    ow, oh = img.size
    shortside = min(ow, oh)
    if shortside >= target_width:
        return img
    else:
        scale = target_width / shortside
        return img.resize((round(ow * scale), round(oh * scale)), method)


def __trim(img, trim_width):
    ow, oh = img.size
    if ow > trim_width:
        xstart = np.random.randint(ow - trim_width)
        xend = xstart + trim_width
    else:
        xstart = 0
        xend = ow
    if oh > trim_width:
        ystart = np.random.randint(oh - trim_width)
        yend = ystart + trim_width
    else:
        ystart = 0
        yend = oh
    return img.crop((xstart, ystart, xend, yend))


def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_width and oh >= crop_width:
        return img
    w = target_width
    h = int(max(target_width * oh / ow, crop_width))
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __patch(img, index, size):
    ow, oh = img.size
    nw, nh = ow // size, oh // size
    roomx = ow - nw * size
    roomy = oh - nh * size
    startx = np.random.randint(int(roomx) + 1)
    starty = np.random.randint(int(roomy) + 1)

    index = index % (nw * nh)
    ix = index // nh
    iy = index % nh
    gridx = startx + ix * size
    gridy = starty + iy * size
    return img.crop((gridx, gridy, gridx + size, gridy + size))


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size (only print once)"""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True
Scenimefy/data/image_folder.py
ADDED
@@ -0,0 +1,66 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
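A short usage sketch (not in this commit): make_dataset walks the given directory recursively and returns image paths, while ImageFolder wraps them as a PyTorch dataset. The directory path below is an assumption for illustration.

import torchvision.transforms as transforms
from Scenimefy.data.image_folder import ImageFolder, make_dataset

paths = make_dataset('./datasets/example/trainA')   # assumed directory
print(len(paths), 'images found')

folder = ImageFolder('./datasets/example/trainA',
                     transform=transforms.ToTensor(),
                     return_paths=True)
img, path = folder[0]    # tensor in [0, 1] plus the file path it was loaded from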
Scenimefy/data/unaligned_dataset.py
ADDED
@@ -0,0 +1,79 @@
import os.path
from Scenimefy.data.base_dataset import BaseDataset, get_transform
from Scenimefy.data.image_folder import make_dataset
from PIL import Image
import random
import Scenimefy.utils.util as util


class UnalignedDataset(BaseDataset):
    """
    This dataset class can load unaligned/unpaired datasets.

    It requires two directories to host training images from domain A '/path/to/data/trainA'
    and from domain B '/path/to/data/trainB' respectively.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    Similarly, you need to prepare two directories:
    '/path/to/data/testA' and '/path/to/data/testB' during test time.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

        if opt.phase == "test" and not os.path.exists(self.dir_A) \
           and os.path.exists(os.path.join(opt.dataroot, "valA")):
            self.dir_A = os.path.join(opt.dataroot, "valA")
            self.dir_B = os.path.join(opt.dataroot, "valB")

        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))   # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))   # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor)    -- an image in the input domain
            B (tensor)    -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
        if self.opt.serial_batches:   # make sure index is within the range
            index_B = index % self.B_size
        else:   # randomize the index for domain B to avoid fixed pairs.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')

        # Apply image transformation
        # For FastCUT mode, if in finetuning phase (learning rate is decaying),
        # do not perform resize-crop data augmentation of CycleGAN.
        # print('current_epoch', self.current_epoch)
        is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
        modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
        transform = get_transform(modified_opt)
        A = transform(A_img)
        B = transform(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different number of images,
        we take a maximum of the two.
        """
        return max(self.A_size, self.B_size)
Scenimefy/models/SRC.py
ADDED
@@ -0,0 +1,79 @@
from packaging import version
import torch
from torch import nn


class Normalize(nn.Module):

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        out = x.div(norm + 1e-7)
        return out


class SRC_Loss(nn.Module):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k, only_weight=False, epoch=None):
        '''
        :param feat_q: target
        :param feat_k: source
        :return: SRC loss, weights for hDCE
        '''

        batchSize = feat_q.shape[0]
        dim = feat_q.shape[1]
        feat_k = feat_k.detach()
        batch_dim_for_bmm = 1  # self.opt.batch_size
        feat_k = Normalize()(feat_k)
        feat_q = Normalize()(feat_q)

        ## SRC
        feat_q_v = feat_q.view(batch_dim_for_bmm, -1, dim)
        feat_k_v = feat_k.view(batch_dim_for_bmm, -1, dim)

        spatial_q = torch.bmm(feat_q_v, feat_q_v.transpose(2, 1))
        spatial_k = torch.bmm(feat_k_v, feat_k_v.transpose(2, 1))

        weight_seed = spatial_k.clone().detach()
        diagonal = torch.eye(self.opt.num_patches, device=feat_k_v.device, dtype=self.mask_dtype)[None, :, :]

        HDCE_gamma = self.opt.HDCE_gamma
        if self.opt.use_curriculum:
            HDCE_gamma = HDCE_gamma + (self.opt.HDCE_gamma_min - HDCE_gamma) * (epoch) / (self.opt.n_epochs + self.opt.n_epochs_decay)
            if (self.opt.step_gamma) & (epoch > self.opt.step_gamma_epoch):
                HDCE_gamma = 1

        ## weights by semantic relation
        weight_seed.masked_fill_(diagonal, -10.0)
        weight_out = nn.Softmax(dim=2)(weight_seed.clone() / HDCE_gamma).detach()
        wmax_out, _ = torch.max(weight_out, dim=2, keepdim=True)
        weight_out /= wmax_out

        if only_weight:
            return 0, weight_out

        spatial_q = nn.Softmax(dim=1)(spatial_q)
        spatial_k = nn.Softmax(dim=1)(spatial_k).detach()

        loss_src = self.get_jsd(spatial_q, spatial_k)

        return loss_src, weight_out

    def get_jsd(self, p1, p2):
        '''
        :param p1: n X C
        :param p2: n X C
        :return: n X 1
        '''
        m = 0.5 * (p1 + p2)
        out = 0.5 * (nn.KLDivLoss(reduction='sum', log_target=True)(torch.log(m), torch.log(p1))
                     + nn.KLDivLoss(reduction='sum', log_target=True)(torch.log(m), torch.log(p2)))
        return out
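A minimal sketch (not in this commit) of driving SRC_Loss with random patch features: it computes the Jensen-Shannon divergence between the source and target self-similarity maps and returns the hard-negative weights used by the HDCE loss. The opt namespace and feature dimensions below are assumptions for illustration; in training, the features come from netF and opt from the option parser.

import torch
from argparse import Namespace
from Scenimefy.models.SRC import SRC_Loss

opt = Namespace(num_patches=256, HDCE_gamma=1.0, use_curriculum=False)  # assumed values
criterion = SRC_Loss(opt)

feat_q = torch.randn(opt.num_patches, 64)   # target-side patch features
feat_k = torch.randn(opt.num_patches, 64)   # source-side patch features
loss_src, weight = criterion(feat_q, feat_k, only_weight=False, epoch=0)
print(loss_src.item(), weight.shape)        # weight: 1 x num_patches x num_patches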
Scenimefy/models/__init__.py
ADDED
@@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list): specify the training losses that you want to plot and save.
    -- self.model_names (str list): define networks used in our training.
    -- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from Scenimefy.models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "Scenimefy.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the model class selected by opt.model.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
Scenimefy/models/base_model.py
ADDED
@@ -0,0 +1,258 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from Scenimefy.models import networks


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def dict_grad_hook_factory(add_func=lambda x: x):
        saved_dict = dict()

        def hook_gen(name):
            def grad_hook(grad):
                saved_vals = add_func(grad)
                saved_dict[name] = saved_vals
            return grad_hook
        return hook_gen, saved_dict

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = opt.epoch
            self.load_networks(load_suffix)

        self.print_networks(opt.verbose)

    def parallelize(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))

    def data_dependent_initialize(self, data):
        pass

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                if self.opt.isTrain and self.opt.pretrained_name is not None:
                    load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
                else:
                    load_dir = self.save_dir

                load_path = os.path.join(load_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                # for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                #     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def generate_visuals_for_evaluation(self, data, mode):
        return {}
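For orientation, here is a minimal sketch (not in this commit) of the inference-time call sequence BaseModel supports. It assumes the options package follows the usual CUT/CycleGAN convention of TestOptions().parse(); the exact entry point lives in Scenimefy/options/test_options.py, which is added in this commit but not shown here.

from Scenimefy.options.test_options import TestOptions
from Scenimefy.data import create_dataset
from Scenimefy.models import create_model

opt = TestOptions().parse()          # assumed entry point; sets isTrain = False
dataset = create_dataset(opt)        # CustomDatasetDataLoader wrapper
model = create_model(opt)            # e.g. opt.model == 'cut'
model.setup(opt)                     # loads '<epoch>_net_G.pth' and prints the networks
model.eval()

for data in dataset:
    model.set_input(data)            # move A/B tensors onto model.device
    model.test()                     # forward() under torch.no_grad()
    visuals = model.get_current_visuals()   # OrderedDict of output image tensors
    paths = model.get_image_paths()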
Scenimefy/models/cut_model.py
ADDED
@@ -0,0 +1,370 @@
import numpy as np
import torch
import matplotlib.pyplot as plt  # used by the debug helpers below; missing in the original imports
from Scenimefy.models.base_model import BaseModel
from Scenimefy.models import networks
from Scenimefy.models.patchnce import PatchNCELoss
import Scenimefy.utils.util as util
from torch.distributions.beta import Beta
from torch.nn import functional as F
from Scenimefy.models.hDCE import PatchHDCELoss
from Scenimefy.models.SRC import SRC_Loss
import torch.nn as nn


def show_np_r(array, min, max, num):
    plt.figure(num)
    plt.imshow(array, norm=None, cmap='gray', vmin=min, vmax=max)
    plt.axis('off')
    plt.show()

def show_hot_r(array, num):
    plt.figure(num)
    plt.imshow(array, norm=None, cmap='hot')
    plt.axis('off')
    plt.show()

def show_torch_rgb(array, min, max, num):
    plt.figure(num)
    plt.imshow(array.detach().cpu()[0].permute(1, 2, 0).numpy() * 255, norm=None, cmap='gray', vmin=min, vmax=max)
    plt.axis('off')
    plt.show()


class Normalize(nn.Module):

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        out = x.div(norm + 1e-7)
        return out

def get_lambda(alpha=1.0, size=None, device=None):
    '''Return lambda'''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
        # lam = Beta()
    else:
        lam = 1.
    return lam

def get_spa_lambda(alpha=1.0, size=None, device=None):
    '''Return lambda'''
    if alpha > 0.:
        lam = torch.from_numpy(np.random.beta(alpha, alpha, size=size)).float().to(device)
        # lam = Beta()
    else:
        lam = 1.
    return lam

class CUTModel(BaseModel):
    """ This class implements CUT and FastCUT model, described in the paper
    Contrastive Learning for Unpaired Image-to-Image Translation
    Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
    ECCV, 2020

    The code borrows heavily from the PyTorch implementation of CycleGAN
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """ Configures options specific for CUT model
        """
        parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')

        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
        parser.add_argument('--lambda_HDCE', type=float, default=1.0, help='weight for HDCE loss: HDCE(G(X), X)')
        parser.add_argument('--lambda_SRC', type=float, default=1.0, help='weight for SRC loss: SRC(G(X), X)')
        parser.add_argument('--dce_idt', action='store_true')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--netF_nc', type=int, default=256)
        parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
        parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
        parser.add_argument('--flip_equivariance',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
        parser.add_argument('--alpha', type=float, default=0.2)
        parser.add_argument('--use_curriculum', action='store_true')
        parser.add_argument('--HDCE_gamma', type=float, default=1)
        parser.add_argument('--HDCE_gamma_min', type=float, default=1)
        parser.add_argument('--step_gamma', action='store_true')
        parser.add_argument('--step_gamma_epoch', type=int, default=200)
        parser.add_argument('--no_Hneg', action='store_true')

        parser.set_defaults(pool_size=0)  # no image pooling

        opt, _ = parser.parse_known_args()

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)

        self.train_epoch = None

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G']

        if opt.lambda_HDCE > 0.0:
            self.loss_names.append('HDCE')
            if opt.dce_idt and self.isTrain:
                self.loss_names += ['HDCE_Y']

        if opt.lambda_SRC > 0.0:
            self.loss_names.append('SRC')

        self.visual_names = ['real_A', 'fake_B', 'real_B']
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
        self.alpha = opt.alpha
        if opt.dce_idt and self.isTrain:
            self.visual_names += ['idt_B']

        if self.isTrain:
            self.model_names = ['G', 'F', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
        self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

        if self.isTrain:
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionNCE = []
            self.criterionHDCE = []

            for i, nce_layer in enumerate(self.nce_layers):
                self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
                self.criterionHDCE.append(PatchHDCELoss(opt=opt).to(self.device))

            self.criterionIdt = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.criterionR = []
            for nce_layer in self.nce_layers:
                self.criterionR.append(SRC_Loss(opt).to(self.device))

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        self.set_input(data)
        bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
        self.real_A = self.real_A[:bs_per_gpu]
        self.real_B = self.real_B[:bs_per_gpu]
        self.forward()                     # compute fake images: G(A)
        if self.opt.isTrain:
            self.compute_D_loss().backward()    # calculate gradients for D
            self.compute_G_loss().backward()    # calculate gradients for G
            # if self.opt.lambda_NCE > 0.0:
            #     self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
            #     self.optimizers.append(self.optimizer_F)
            #
            # elif self.opt.lambda_HDCE > 0.0:
            self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
            self.optimizers.append(self.optimizer_F)

    def optimize_parameters(self):
        # forward
        self.forward()

        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        if self.opt.netF == 'mlp_sample':
            # if self.opt.lambda_NCE > 0.0:
            #     self.optimizer_F.zero_grad()
            # elif self.opt.lambda_HDCE > 0.0:
            self.optimizer_F.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()
        if self.opt.netF == 'mlp_sample':
            # if self.opt.lambda_NCE > 0.0:
            #     self.optimizer_F.step()
            # elif self.opt.lambda_HDCE > 0.0:
            self.optimizer_F.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.dce_idt and self.opt.isTrain else self.real_A
        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])

        self.fake = self.netG(self.real)
        self.fake_B = self.fake[:self.real_A.size(0)]
        if self.opt.dce_idt:
            self.idt_B = self.fake[self.real_A.size(0):]

    def set_epoch(self, epoch):
        self.train_epoch = epoch

    def compute_D_loss(self):
        """Calculate GAN loss for the discriminator"""
        fake = self.fake_B.detach()
        # Fake; stop backprop to the generator by detaching fake_B
        pred_fake = self.netD(fake)
        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
        # Real
        self.pred_real = self.netD(self.real_B)
        loss_D_real = self.criterionGAN(self.pred_real, True)
        self.loss_D_real = loss_D_real.mean()

        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def compute_G_loss(self):
        """Calculate GAN and NCE loss for the generator"""
        fake = self.fake_B
        # First, G(A) should fake the discriminator
        if self.opt.lambda_GAN > 0.0:
            pred_fake = self.netD(fake)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
|
262 |
+
else:
|
263 |
+
self.loss_G_GAN = 0.0
|
264 |
+
|
265 |
+
## get feat
|
266 |
+
fake_B_feat = self.netG(self.fake_B, self.nce_layers, encode_only=True)
|
267 |
+
if self.opt.flip_equivariance and self.flipped_for_equivariance:
|
268 |
+
fake_B_feat = [torch.flip(fq, [3]) for fq in fake_B_feat]
|
269 |
+
real_A_feat = self.netG(self.real_A, self.nce_layers, encode_only=True)
|
270 |
+
|
271 |
+
fake_B_pool, sample_ids = self.netF(fake_B_feat, self.opt.num_patches, None)
|
272 |
+
real_A_pool, _ = self.netF(real_A_feat, self.opt.num_patches, sample_ids)
|
273 |
+
|
274 |
+
if self.opt.dce_idt:
|
275 |
+
idt_B_feat = self.netG(self.idt_B, self.nce_layers, encode_only=True)
|
276 |
+
if self.opt.flip_equivariance and self.flipped_for_equivariance:
|
277 |
+
idt_B_feat = [torch.flip(fq, [3]) for fq in idt_B_feat]
|
278 |
+
real_B_feat = self.netG(self.real_B, self.nce_layers, encode_only=True)
|
279 |
+
|
280 |
+
idt_B_pool, _ = self.netF(idt_B_feat, self.opt.num_patches, sample_ids)
|
281 |
+
real_B_pool, _ = self.netF(real_B_feat, self.opt.num_patches, sample_ids)
|
282 |
+
|
283 |
+
|
284 |
+
## Relation Loss
|
285 |
+
self.loss_SRC, weight = self.calculate_R_loss(real_A_pool, fake_B_pool, epoch=self.train_epoch)
|
286 |
+
|
287 |
+
|
288 |
+
## HDCE
|
289 |
+
if self.opt.lambda_HDCE > 0.0:
|
290 |
+
self.loss_HDCE = self.calculate_HDCE_loss(real_A_pool, fake_B_pool, weight)
|
291 |
+
else:
|
292 |
+
self.loss_HDCE, self.loss_HDCE_bd = 0.0, 0.0
|
293 |
+
|
294 |
+
self.loss_HDCE_Y = 0
|
295 |
+
if self.opt.dce_idt and self.opt.lambda_HDCE > 0.0:
|
296 |
+
_, weight_idt = self.calculate_R_loss(real_B_pool, idt_B_pool, only_weight=True, epoch=self.train_epoch)
|
297 |
+
self.loss_HDCE_Y = self.calculate_HDCE_loss(real_B_pool, idt_B_pool, weight_idt)
|
298 |
+
loss_HDCE_both = (self.loss_HDCE + self.loss_HDCE_Y) * 0.5
|
299 |
+
else:
|
300 |
+
loss_HDCE_both = self.loss_HDCE
|
301 |
+
|
302 |
+
self.loss_G = self.loss_G_GAN + loss_HDCE_both + self.loss_SRC
|
303 |
+
return self.loss_G
|
304 |
+
|
305 |
+
|
306 |
+
def calculate_HDCE_loss(self, src, tgt, weight=None):
|
307 |
+
n_layers = len(self.nce_layers)
|
308 |
+
|
309 |
+
feat_q_pool = tgt
|
310 |
+
feat_k_pool = src
|
311 |
+
|
312 |
+
total_HDCE_loss = 0.0
|
313 |
+
for f_q, f_k, crit, nce_layer, w in zip(feat_q_pool, feat_k_pool, self.criterionHDCE, self.nce_layers, weight):
|
314 |
+
if self.opt.no_Hneg:
|
315 |
+
w = None
|
316 |
+
loss = crit(f_q, f_k, w) * self.opt.lambda_HDCE
|
317 |
+
total_HDCE_loss += loss.mean()
|
318 |
+
|
319 |
+
return total_HDCE_loss / n_layers
|
320 |
+
|
321 |
+
|
322 |
+
def calculate_R_loss(self, src, tgt, only_weight=False, epoch=None):
|
323 |
+
n_layers = len(self.nce_layers)
|
324 |
+
|
325 |
+
feat_q_pool = tgt
|
326 |
+
feat_k_pool = src
|
327 |
+
|
328 |
+
total_SRC_loss = 0.0
|
329 |
+
weights=[]
|
330 |
+
for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionR, self.nce_layers):
|
331 |
+
loss_SRC, weight = crit(f_q, f_k, only_weight, epoch)
|
332 |
+
total_SRC_loss += loss_SRC * self.opt.lambda_SRC
|
333 |
+
weights.append(weight)
|
334 |
+
return total_SRC_loss / n_layers, weights
|
335 |
+
|
336 |
+
|
337 |
+
#--------------------------------------------------------------------------------------------------------
|
338 |
+
def calculate_Patchloss(self, src, tgt, num_patch=4):
|
339 |
+
|
340 |
+
feat_org = self.netG(src, mode='encoder')
|
341 |
+
if self.opt.flip_equivariance and self.flipped_for_equivariance:
|
342 |
+
feat_org = torch.flip(feat_org, [3])
|
343 |
+
|
344 |
+
N,C,H,W = feat_org.size()
|
345 |
+
|
346 |
+
ps = H//num_patch
|
347 |
+
lam = get_spa_lambda(self.alpha,size=(1,1,num_patch**2),device = feat_org.device)
|
348 |
+
feat_org_unfold = F.unfold(feat_org,kernel_size=(ps,ps),padding=0,stride=ps)
|
349 |
+
|
350 |
+
rndperm = torch.randperm(feat_org_unfold.size(2))
|
351 |
+
feat_prm = feat_org_unfold[:,:,rndperm]
|
352 |
+
feat_mix = lam*feat_org_unfold + (1-lam)*feat_prm
|
353 |
+
feat_mix = F.fold(feat_mix,output_size=(H,W),kernel_size=(ps,ps),padding=0,stride=ps)
|
354 |
+
|
355 |
+
out_mix = self.netG(feat_mix,mode='decoder')
|
356 |
+
feat_mix_rec = self.netG(out_mix,mode='encoder')
|
357 |
+
|
358 |
+
fake_feat = self.netG(tgt,mode='encoder')
|
359 |
+
|
360 |
+
fake_feat_unfold = F.unfold(fake_feat,kernel_size=(ps,ps),padding=0,stride=ps)
|
361 |
+
fake_feat_prm = fake_feat_unfold[:,:,rndperm]
|
362 |
+
fake_feat_mix = lam*fake_feat_unfold + (1-lam)*fake_feat_prm
|
363 |
+
fake_feat_mix = F.fold(fake_feat_mix,output_size=(H,W),kernel_size=(ps,ps),padding=0,stride=ps)
|
364 |
+
|
365 |
+
|
366 |
+
PM_loss = torch.mean(torch.abs(fake_feat_mix - feat_mix_rec))
|
367 |
+
|
368 |
+
return 10*PM_loss
|
369 |
+
|
370 |
+
#--------------------------------------------------------------------------------------------------------
|
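For orientation, here is a minimal sketch (not part of the committed files) of how a training script would typically drive the model methods listed above. The function name, `dataset`, and `n_epochs` are assumptions for illustration; they mirror the usual CUT-style training loop rather than any script in this Space.

def train_loop(model, dataset, n_epochs):
    # Illustrative only: model is assumed to be an initialized CUT-style model
    # exposing set_epoch / data_dependent_initialize / set_input / optimize_parameters.
    for epoch in range(1, n_epochs + 1):
        model.set_epoch(epoch)                         # used by calculate_R_loss for epoch-dependent weighting
        for i, data in enumerate(dataset):
            if epoch == 1 and i == 0:
                model.data_dependent_initialize(data)  # builds netF's MLPs on the first forward pass
            model.set_input(data)                      # unpack real_A / real_B and move them to the device
            model.optimize_parameters()                # D step, then G (and F when netF == 'mlp_sample')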
Scenimefy/models/hDCE.py
ADDED
@@ -0,0 +1,53 @@
from packaging import version
import torch
from torch import nn


class PatchHDCELoss(nn.Module):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k, weight=None):
        batchSize = feat_q.shape[0]
        dim = feat_q.shape[1]
        feat_k = feat_k.detach()

        # positive logit
        l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
        l_pos = l_pos.view(batchSize, 1)

        if self.opt.nce_includes_all_negatives_from_minibatch:
            # reshape features as if they are all negatives of minibatch of size 1.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = self.opt.batch_size

        # reshape features to batch size
        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # weighted by semantic relation
        if weight is not None:
            l_neg_curbatch *= weight

        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        logits = (l_neg - l_pos) / self.opt.nce_T
        v = torch.logsumexp(logits, dim=1)
        loss_vec = torch.exp(v - v.detach())

        # for monitoring
        out_dummy = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
        CELoss_dummy = self.cross_entropy_loss(out_dummy, torch.zeros(out_dummy.size(0), dtype=torch.long, device=feat_q.device))

        loss = loss_vec.mean() - 1 + CELoss_dummy.detach()

        return loss
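Reading the forward pass: the `exp(v - v.detach())` term always evaluates to 1 but carries the gradient of the (optionally weighted) logsumexp over negatives relative to the positive logit, while the detached cross-entropy term is added only so the reported value tracks a conventional InfoNCE loss. As a quick smoke test, the loss can be evaluated on random patch embeddings; this sketch is illustrative only, and the option values below are assumptions rather than the Space's actual settings.

# Minimal usage sketch for PatchHDCELoss on random features (illustrative only).
import torch
from types import SimpleNamespace
from Scenimefy.models.hDCE import PatchHDCELoss

# Hypothetical option values for the example; the real ones come from the training options.
opt = SimpleNamespace(nce_includes_all_negatives_from_minibatch=False,
                      batch_size=1,
                      nce_T=0.07)
criterion = PatchHDCELoss(opt)

feat_q = torch.randn(256, 64)    # 256 sampled patches, 64-dim embeddings (generated side)
feat_k = torch.randn(256, 64)    # corresponding patches from the source side
loss = criterion(feat_q, feat_k)  # per-patch loss vector; weight=None disables hard-negative weighting
print(loss.mean().item())         # cut_model averages this vector in calculate_HDCE_loss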
Scenimefy/models/networks.py
ADDED
@@ -0,0 +1,1513 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import init
|
5 |
+
import functools
|
6 |
+
from torch.optim import lr_scheduler
|
7 |
+
import numpy as np
|
8 |
+
from Scenimefy.models.stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
|
9 |
+
|
10 |
+
###############################################################################
|
11 |
+
# Helper Functions
|
12 |
+
###############################################################################
|
13 |
+
|
14 |
+
|
15 |
+
def get_filter(filt_size=3):
|
16 |
+
if(filt_size == 1):
|
17 |
+
a = np.array([1., ])
|
18 |
+
elif(filt_size == 2):
|
19 |
+
a = np.array([1., 1.])
|
20 |
+
elif(filt_size == 3):
|
21 |
+
a = np.array([1., 2., 1.])
|
22 |
+
elif(filt_size == 4):
|
23 |
+
a = np.array([1., 3., 3., 1.])
|
24 |
+
elif(filt_size == 5):
|
25 |
+
a = np.array([1., 4., 6., 4., 1.])
|
26 |
+
elif(filt_size == 6):
|
27 |
+
a = np.array([1., 5., 10., 10., 5., 1.])
|
28 |
+
elif(filt_size == 7):
|
29 |
+
a = np.array([1., 6., 15., 20., 15., 6., 1.])
|
30 |
+
|
31 |
+
filt = torch.Tensor(a[:, None] * a[None, :])
|
32 |
+
filt = filt / torch.sum(filt)
|
33 |
+
|
34 |
+
return filt
|
35 |
+
|
36 |
+
|
37 |
+
class Downsample(nn.Module):
|
38 |
+
def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
|
39 |
+
super(Downsample, self).__init__()
|
40 |
+
self.filt_size = filt_size
|
41 |
+
self.pad_off = pad_off
|
42 |
+
self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
|
43 |
+
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
|
44 |
+
self.stride = stride
|
45 |
+
self.off = int((self.stride - 1) / 2.)
|
46 |
+
self.channels = channels
|
47 |
+
|
48 |
+
filt = get_filter(filt_size=self.filt_size)
|
49 |
+
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
50 |
+
|
51 |
+
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
|
52 |
+
|
53 |
+
def forward(self, inp):
|
54 |
+
if(self.filt_size == 1):
|
55 |
+
if(self.pad_off == 0):
|
56 |
+
return inp[:, :, ::self.stride, ::self.stride]
|
57 |
+
else:
|
58 |
+
return self.pad(inp)[:, :, ::self.stride, ::self.stride]
|
59 |
+
else:
|
60 |
+
return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
|
61 |
+
|
62 |
+
|
63 |
+
class Upsample2(nn.Module):
|
64 |
+
def __init__(self, scale_factor, mode='nearest'):
|
65 |
+
super().__init__()
|
66 |
+
self.factor = scale_factor
|
67 |
+
self.mode = mode
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
|
71 |
+
|
72 |
+
|
73 |
+
class Upsample(nn.Module):
|
74 |
+
def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
|
75 |
+
super(Upsample, self).__init__()
|
76 |
+
self.filt_size = filt_size
|
77 |
+
self.filt_odd = np.mod(filt_size, 2) == 1
|
78 |
+
self.pad_size = int((filt_size - 1) / 2)
|
79 |
+
self.stride = stride
|
80 |
+
self.off = int((self.stride - 1) / 2.)
|
81 |
+
self.channels = channels
|
82 |
+
|
83 |
+
filt = get_filter(filt_size=self.filt_size) * (stride**2)
|
84 |
+
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
85 |
+
|
86 |
+
self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])
|
87 |
+
|
88 |
+
def forward(self, inp):
|
89 |
+
ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
|
90 |
+
if(self.filt_odd):
|
91 |
+
return ret_val
|
92 |
+
else:
|
93 |
+
return ret_val[:, :, :-1, :-1]
|
94 |
+
|
95 |
+
|
96 |
+
def get_pad_layer(pad_type):
|
97 |
+
if(pad_type in ['refl', 'reflect']):
|
98 |
+
PadLayer = nn.ReflectionPad2d
|
99 |
+
elif(pad_type in ['repl', 'replicate']):
|
100 |
+
PadLayer = nn.ReplicationPad2d
|
101 |
+
elif(pad_type == 'zero'):
|
102 |
+
PadLayer = nn.ZeroPad2d
|
103 |
+
else:
|
104 |
+
print('Pad type [%s] not recognized' % pad_type)
|
105 |
+
return PadLayer
|
106 |
+
|
107 |
+
|
108 |
+
class Identity(nn.Module):
|
109 |
+
def forward(self, x):
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
def get_norm_layer(norm_type='instance'):
|
114 |
+
"""Return a normalization layer
|
115 |
+
|
116 |
+
Parameters:
|
117 |
+
norm_type (str) -- the name of the normalization layer: batch | instance | none
|
118 |
+
|
119 |
+
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
|
120 |
+
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
|
121 |
+
"""
|
122 |
+
if norm_type == 'batch':
|
123 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
124 |
+
elif norm_type == 'instance':
|
125 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
126 |
+
elif norm_type == 'none':
|
127 |
+
def norm_layer(x):
|
128 |
+
return Identity()
|
129 |
+
else:
|
130 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
131 |
+
return norm_layer
|
132 |
+
|
133 |
+
|
134 |
+
def get_scheduler(optimizer, opt):
|
135 |
+
"""Return a learning rate scheduler
|
136 |
+
|
137 |
+
Parameters:
|
138 |
+
optimizer -- the optimizer of the network
|
139 |
+
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
|
140 |
+
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
|
141 |
+
|
142 |
+
For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
|
143 |
+
and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
|
144 |
+
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
|
145 |
+
See https://pytorch.org/docs/stable/optim.html for more details.
|
146 |
+
"""
|
147 |
+
if opt.lr_policy == 'linear':
|
148 |
+
def lambda_rule(epoch):
|
149 |
+
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
|
150 |
+
return lr_l
|
151 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
152 |
+
elif opt.lr_policy == 'step':
|
153 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
154 |
+
elif opt.lr_policy == 'plateau':
|
155 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
156 |
+
elif opt.lr_policy == 'cosine':
|
157 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
|
158 |
+
else:
|
159 |
+
raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
|
160 |
+
return scheduler
|
161 |
+
|
162 |
+
|
163 |
+
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
|
164 |
+
"""Initialize network weights.
|
165 |
+
|
166 |
+
Parameters:
|
167 |
+
net (network) -- network to be initialized
|
168 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
169 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
170 |
+
|
171 |
+
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
|
172 |
+
work better for some applications. Feel free to try yourself.
|
173 |
+
"""
|
174 |
+
def init_func(m): # define the initialization function
|
175 |
+
classname = m.__class__.__name__
|
176 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
177 |
+
if debug:
|
178 |
+
print(classname)
|
179 |
+
if init_type == 'normal':
|
180 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
181 |
+
elif init_type == 'xavier':
|
182 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
183 |
+
elif init_type == 'kaiming':
|
184 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
185 |
+
elif init_type == 'orthogonal':
|
186 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
187 |
+
else:
|
188 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
189 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
190 |
+
init.constant_(m.bias.data, 0.0)
|
191 |
+
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
192 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
193 |
+
init.constant_(m.bias.data, 0.0)
|
194 |
+
|
195 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
196 |
+
|
197 |
+
|
198 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
|
199 |
+
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
|
200 |
+
Parameters:
|
201 |
+
net (network) -- the network to be initialized
|
202 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
203 |
+
gain (float) -- scaling factor for normal, xavier and orthogonal.
|
204 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
205 |
+
|
206 |
+
Return an initialized network.
|
207 |
+
"""
|
208 |
+
if len(gpu_ids) > 0:
|
209 |
+
assert(torch.cuda.is_available())
|
210 |
+
net.to(gpu_ids[0])
|
211 |
+
# if not amp:
|
212 |
+
# net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
|
213 |
+
if initialize_weights:
|
214 |
+
init_weights(net, init_type, init_gain=init_gain, debug=debug)
|
215 |
+
return net
|
216 |
+
|
217 |
+
|
218 |
+
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
|
219 |
+
init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
|
220 |
+
"""Create a generator
|
221 |
+
|
222 |
+
Parameters:
|
223 |
+
input_nc (int) -- the number of channels in input images
|
224 |
+
output_nc (int) -- the number of channels in output images
|
225 |
+
ngf (int) -- the number of filters in the last conv layer
|
226 |
+
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
|
227 |
+
norm (str) -- the name of normalization layers used in the network: batch | instance | none
|
228 |
+
use_dropout (bool) -- if use dropout layers.
|
229 |
+
init_type (str) -- the name of our initialization method.
|
230 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
231 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
232 |
+
|
233 |
+
Returns a generator
|
234 |
+
|
235 |
+
Our current implementation provides two types of generators:
|
236 |
+
U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
|
237 |
+
The original U-Net paper: https://arxiv.org/abs/1505.04597
|
238 |
+
|
239 |
+
Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
|
240 |
+
Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
|
241 |
+
We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
|
242 |
+
|
243 |
+
|
244 |
+
The generator has been initialized by <init_net>. It uses RELU for non-linearity.
|
245 |
+
"""
|
246 |
+
net = None
|
247 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
248 |
+
|
249 |
+
if netG == 'resnet_9blocks':
|
250 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
|
251 |
+
elif netG == 'resnet_6blocks':
|
252 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt)
|
253 |
+
elif netG == 'resnet_4blocks':
|
254 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt)
|
255 |
+
elif netG == 'unet_128':
|
256 |
+
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
257 |
+
elif netG == 'unet_256':
|
258 |
+
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
259 |
+
elif netG == 'stylegan2':
|
260 |
+
net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
|
261 |
+
elif netG == 'smallstylegan2':
|
262 |
+
net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
|
263 |
+
elif netG == 'resnet_cat':
|
264 |
+
n_blocks = 8
|
265 |
+
net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
|
266 |
+
else:
|
267 |
+
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
268 |
+
return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
|
269 |
+
|
270 |
+
|
271 |
+
def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
|
272 |
+
if netF == 'global_pool':
|
273 |
+
net = PoolingF()
|
274 |
+
elif netF == 'reshape':
|
275 |
+
net = ReshapeF()
|
276 |
+
elif netF == 'sample':
|
277 |
+
net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
|
278 |
+
elif netF == 'mlp_sample':
|
279 |
+
net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
|
280 |
+
elif netF == 'strided_conv':
|
281 |
+
net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
|
282 |
+
else:
|
283 |
+
raise NotImplementedError('projection model name [%s] is not recognized' % netF)
|
284 |
+
return init_net(net, init_type, init_gain, gpu_ids)
|
285 |
+
|
286 |
+
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
|
287 |
+
"""Create a discriminator
|
288 |
+
|
289 |
+
Parameters:
|
290 |
+
input_nc (int) -- the number of channels in input images
|
291 |
+
ndf (int) -- the number of filters in the first conv layer
|
292 |
+
netD (str) -- the architecture's name: basic | n_layers | pixel
|
293 |
+
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
|
294 |
+
norm (str) -- the type of normalization layers used in the network.
|
295 |
+
init_type (str) -- the name of the initialization method.
|
296 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
297 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
298 |
+
|
299 |
+
Returns a discriminator
|
300 |
+
|
301 |
+
Our current implementation provides three types of discriminators:
|
302 |
+
[basic]: 'PatchGAN' classifier described in the original pix2pix paper.
|
303 |
+
It can classify whether 70×70 overlapping patches are real or fake.
|
304 |
+
Such a patch-level discriminator architecture has fewer parameters
|
305 |
+
than a full-image discriminator and can work on arbitrarily-sized images
|
306 |
+
in a fully convolutional fashion.
|
307 |
+
|
308 |
+
[n_layers]: With this mode, you can specify the number of conv layers in the discriminator
|
309 |
+
with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
|
310 |
+
|
311 |
+
[pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
|
312 |
+
It encourages greater color diversity but has no effect on spatial statistics.
|
313 |
+
|
314 |
+
The discriminator has been initialized by <init_net>. It uses Leaky RELU for non-linearity.
|
315 |
+
"""
|
316 |
+
net = None
|
317 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
318 |
+
|
319 |
+
if netD == 'basic': # default PatchGAN classifier
|
320 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias,)
|
321 |
+
elif netD == 'n_layers': # more options
|
322 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
|
323 |
+
elif netD == 'pixel': # classify if each pixel is real or fake
|
324 |
+
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
|
325 |
+
elif 'stylegan2' in netD:
|
326 |
+
net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
|
327 |
+
else:
|
328 |
+
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
|
329 |
+
return init_net(net, init_type, init_gain, gpu_ids,
|
330 |
+
initialize_weights=('stylegan2' not in netD))
|
331 |
+
|
332 |
+
|
333 |
+
##############################################################################
|
334 |
+
# Classes
|
335 |
+
##############################################################################
|
336 |
+
class GANLoss(nn.Module):
|
337 |
+
"""Define different GAN objectives.
|
338 |
+
|
339 |
+
The GANLoss class abstracts away the need to create the target label tensor
|
340 |
+
that has the same size as the input.
|
341 |
+
"""
|
342 |
+
|
343 |
+
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
|
344 |
+
""" Initialize the GANLoss class.
|
345 |
+
|
346 |
+
Parameters:
|
347 |
+
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
|
348 |
+
target_real_label (bool) - - label for a real image
|
349 |
+
target_fake_label (bool) - - label of a fake image
|
350 |
+
|
351 |
+
Note: Do not use sigmoid as the last layer of Discriminator.
|
352 |
+
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
|
353 |
+
"""
|
354 |
+
super(GANLoss, self).__init__()
|
355 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
356 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
357 |
+
self.gan_mode = gan_mode
|
358 |
+
if gan_mode == 'lsgan':
|
359 |
+
self.loss = nn.MSELoss()
|
360 |
+
elif gan_mode == 'vanilla':
|
361 |
+
self.loss = nn.BCEWithLogitsLoss()
|
362 |
+
elif gan_mode in ['wgangp', 'nonsaturating']:
|
363 |
+
self.loss = None
|
364 |
+
else:
|
365 |
+
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
|
366 |
+
|
367 |
+
def get_target_tensor(self, prediction, target_is_real):
|
368 |
+
"""Create label tensors with the same size as the input.
|
369 |
+
|
370 |
+
Parameters:
|
371 |
+
prediction (tensor) - - typically the prediction from a discriminator
|
372 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
A label tensor filled with ground truth label, and with the size of the input
|
376 |
+
"""
|
377 |
+
|
378 |
+
if target_is_real:
|
379 |
+
target_tensor = self.real_label
|
380 |
+
else:
|
381 |
+
target_tensor = self.fake_label
|
382 |
+
return target_tensor.expand_as(prediction)
|
383 |
+
|
384 |
+
def __call__(self, prediction, target_is_real):
|
385 |
+
"""Calculate loss given Discriminator's output and grount truth labels.
|
386 |
+
|
387 |
+
Parameters:
|
388 |
+
prediction (tensor) - - typically the prediction output from a discriminator
|
389 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
the calculated loss.
|
393 |
+
"""
|
394 |
+
bs = prediction.size(0)
|
395 |
+
if self.gan_mode in ['lsgan', 'vanilla']:
|
396 |
+
target_tensor = self.get_target_tensor(prediction, target_is_real)
|
397 |
+
loss = self.loss(prediction, target_tensor)
|
398 |
+
elif self.gan_mode == 'wgangp':
|
399 |
+
if target_is_real:
|
400 |
+
loss = -prediction.mean()
|
401 |
+
else:
|
402 |
+
loss = prediction.mean()
|
403 |
+
elif self.gan_mode == 'nonsaturating':
|
404 |
+
if target_is_real:
|
405 |
+
loss = F.softplus(-prediction).view(bs, -1).mean(dim=1)
|
406 |
+
else:
|
407 |
+
loss = F.softplus(prediction).view(bs, -1).mean(dim=1)
|
408 |
+
return loss
|
409 |
+
|
410 |
+
|
411 |
+
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
|
412 |
+
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
|
413 |
+
|
414 |
+
Arguments:
|
415 |
+
netD (network) -- discriminator network
|
416 |
+
real_data (tensor array) -- real images
|
417 |
+
fake_data (tensor array) -- generated images from the generator
|
418 |
+
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
|
419 |
+
type (str) -- if we mix real and fake data or not [real | fake | mixed].
|
420 |
+
constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
|
421 |
+
lambda_gp (float) -- weight for this loss
|
422 |
+
|
423 |
+
Returns the gradient penalty loss
|
424 |
+
"""
|
425 |
+
if lambda_gp > 0.0:
|
426 |
+
if type == 'real':   # either use real images, fake images, or a linear interpolation of the two.
|
427 |
+
interpolatesv = real_data
|
428 |
+
elif type == 'fake':
|
429 |
+
interpolatesv = fake_data
|
430 |
+
elif type == 'mixed':
|
431 |
+
alpha = torch.rand(real_data.shape[0], 1, device=device)
|
432 |
+
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
|
433 |
+
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
|
434 |
+
else:
|
435 |
+
raise NotImplementedError('{} not implemented'.format(type))
|
436 |
+
interpolatesv.requires_grad_(True)
|
437 |
+
disc_interpolates = netD(interpolatesv)
|
438 |
+
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
|
439 |
+
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
|
440 |
+
create_graph=True, retain_graph=True, only_inputs=True)
|
441 |
+
gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
|
442 |
+
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
|
443 |
+
return gradient_penalty, gradients
|
444 |
+
else:
|
445 |
+
return 0.0, None
|
446 |
+
|
447 |
+
|
448 |
+
class Normalize(nn.Module):
|
449 |
+
|
450 |
+
def __init__(self, power=2):
|
451 |
+
super(Normalize, self).__init__()
|
452 |
+
self.power = power
|
453 |
+
|
454 |
+
def forward(self, x):
|
455 |
+
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
|
456 |
+
out = x.div(norm + 1e-7)
|
457 |
+
return out
|
458 |
+
|
459 |
+
|
460 |
+
class PoolingF(nn.Module):
|
461 |
+
def __init__(self):
|
462 |
+
super(PoolingF, self).__init__()
|
463 |
+
model = [nn.AdaptiveMaxPool2d(1)]
|
464 |
+
self.model = nn.Sequential(*model)
|
465 |
+
self.l2norm = Normalize(2)
|
466 |
+
|
467 |
+
def forward(self, x):
|
468 |
+
return self.l2norm(self.model(x))
|
469 |
+
|
470 |
+
|
471 |
+
class ReshapeF(nn.Module):
|
472 |
+
def __init__(self):
|
473 |
+
super(ReshapeF, self).__init__()
|
474 |
+
model = [nn.AdaptiveAvgPool2d(4)]
|
475 |
+
self.model = nn.Sequential(*model)
|
476 |
+
self.l2norm = Normalize(2)
|
477 |
+
|
478 |
+
def forward(self, x):
|
479 |
+
x = self.model(x)
|
480 |
+
x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2)
|
481 |
+
return self.l2norm(x_reshape)
|
482 |
+
|
483 |
+
|
484 |
+
class StridedConvF(nn.Module):
|
485 |
+
def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
486 |
+
super().__init__()
|
487 |
+
# self.conv1 = nn.Conv2d(256, 128, 3, stride=2)
|
488 |
+
# self.conv2 = nn.Conv2d(128, 64, 3, stride=1)
|
489 |
+
self.l2_norm = Normalize(2)
|
490 |
+
self.mlps = {}
|
491 |
+
self.moving_averages = {}
|
492 |
+
self.init_type = init_type
|
493 |
+
self.init_gain = init_gain
|
494 |
+
self.gpu_ids = gpu_ids
|
495 |
+
|
496 |
+
def create_mlp(self, x):
|
497 |
+
C, H = x.shape[1], x.shape[2]
|
498 |
+
n_down = int(np.rint(np.log2(H / 32)))
|
499 |
+
mlp = []
|
500 |
+
for i in range(n_down):
|
501 |
+
mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2))
|
502 |
+
mlp.append(nn.ReLU())
|
503 |
+
C = max(C // 2, 64)
|
504 |
+
mlp.append(nn.Conv2d(C, 64, 3))
|
505 |
+
mlp = nn.Sequential(*mlp)
|
506 |
+
init_net(mlp, self.init_type, self.init_gain, self.gpu_ids)
|
507 |
+
return mlp
|
508 |
+
|
509 |
+
def update_moving_average(self, key, x):
|
510 |
+
if key not in self.moving_averages:
|
511 |
+
self.moving_averages[key] = x.detach()
|
512 |
+
|
513 |
+
self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001
|
514 |
+
|
515 |
+
def forward(self, x, use_instance_norm=False):
|
516 |
+
C, H = x.shape[1], x.shape[2]
|
517 |
+
key = '%d_%d' % (C, H)
|
518 |
+
if key not in self.mlps:
|
519 |
+
self.mlps[key] = self.create_mlp(x)
|
520 |
+
self.add_module("child_%s" % key, self.mlps[key])
|
521 |
+
mlp = self.mlps[key]
|
522 |
+
x = mlp(x)
|
523 |
+
self.update_moving_average(key, x)
|
524 |
+
x = x - self.moving_averages[key]
|
525 |
+
if use_instance_norm:
|
526 |
+
x = F.instance_norm(x)
|
527 |
+
return self.l2_norm(x)
|
528 |
+
|
529 |
+
|
530 |
+
class PatchSampleF(nn.Module):
|
531 |
+
def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
|
532 |
+
# potential issues: currently, we use the same patch_ids for multiple images in the batch
|
533 |
+
super(PatchSampleF, self).__init__()
|
534 |
+
self.l2norm = Normalize(2)
|
535 |
+
self.use_mlp = use_mlp
|
536 |
+
self.nc = nc # hard-coded
|
537 |
+
self.mlp_init = False
|
538 |
+
self.init_type = init_type
|
539 |
+
self.init_gain = init_gain
|
540 |
+
self.gpu_ids = gpu_ids
|
541 |
+
|
542 |
+
def create_mlp(self, feats):
|
543 |
+
for mlp_id, feat in enumerate(feats):
|
544 |
+
input_nc = feat.shape[1]
|
545 |
+
mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
|
546 |
+
if len(self.gpu_ids) > 0:
|
547 |
+
mlp.cuda()
|
548 |
+
setattr(self, 'mlp_%d' % mlp_id, mlp)
|
549 |
+
init_net(self, self.init_type, self.init_gain, self.gpu_ids)
|
550 |
+
self.mlp_init = True
|
551 |
+
|
552 |
+
def forward(self, feats, num_patches=64, patch_ids=None):
|
553 |
+
return_ids = []
|
554 |
+
return_feats = []
|
555 |
+
if self.use_mlp and not self.mlp_init:
|
556 |
+
self.create_mlp(feats)
|
557 |
+
for feat_id, feat in enumerate(feats):
|
558 |
+
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
|
559 |
+
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
|
560 |
+
if num_patches > 0:
|
561 |
+
if patch_ids is not None:
|
562 |
+
patch_id = patch_ids[feat_id]
|
563 |
+
else:
|
564 |
+
patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
|
565 |
+
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
|
566 |
+
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
|
567 |
+
else:
|
568 |
+
x_sample = feat_reshape
|
569 |
+
patch_id = []
|
570 |
+
if self.use_mlp:
|
571 |
+
mlp = getattr(self, 'mlp_%d' % feat_id)
|
572 |
+
x_sample = mlp(x_sample)
|
573 |
+
return_ids.append(patch_id)
|
574 |
+
x_sample = self.l2norm(x_sample)
|
575 |
+
|
576 |
+
if num_patches == 0:
|
577 |
+
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
|
578 |
+
return_feats.append(x_sample)
|
579 |
+
return return_feats, return_ids
|
580 |
+
|
581 |
+
|
582 |
+
class G_Resnet(nn.Module):
|
583 |
+
def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64,
|
584 |
+
norm=None, nl_layer=None):
|
585 |
+
super(G_Resnet, self).__init__()
|
586 |
+
n_downsample = num_downs
|
587 |
+
pad_type = 'reflect'
|
588 |
+
self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type)
|
589 |
+
if nz == 0:
|
590 |
+
self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
|
591 |
+
else:
|
592 |
+
self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
|
593 |
+
|
594 |
+
def decode(self, content, style=None):
|
595 |
+
return self.dec(content, style)
|
596 |
+
|
597 |
+
def forward(self, image, style=None, nce_layers=[], encode_only=False):
|
598 |
+
content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only)
|
599 |
+
if encode_only:
|
600 |
+
return feats
|
601 |
+
else:
|
602 |
+
images_recon = self.decode(content, style)
|
603 |
+
if len(nce_layers) > 0:
|
604 |
+
return images_recon, feats
|
605 |
+
else:
|
606 |
+
return images_recon
|
607 |
+
|
608 |
+
##################################################################################
|
609 |
+
# Encoder and Decoders
|
610 |
+
##################################################################################
|
611 |
+
|
612 |
+
|
613 |
+
class E_adaIN(nn.Module):
|
614 |
+
def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4,
|
615 |
+
norm=None, nl_layer=None, vae=False):
|
616 |
+
# style encoder
|
617 |
+
super(E_adaIN, self).__init__()
|
618 |
+
self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae)
|
619 |
+
|
620 |
+
def forward(self, image):
|
621 |
+
style = self.enc_style(image)
|
622 |
+
return style
|
623 |
+
|
624 |
+
|
625 |
+
class StyleEncoder(nn.Module):
|
626 |
+
def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False):
|
627 |
+
super(StyleEncoder, self).__init__()
|
628 |
+
self.vae = vae
|
629 |
+
self.model = []
|
630 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
|
631 |
+
for i in range(2):
|
632 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
633 |
+
dim *= 2
|
634 |
+
for i in range(n_downsample - 2):
|
635 |
+
self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
636 |
+
self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
|
637 |
+
if self.vae:
|
638 |
+
self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0)
|
639 |
+
self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0)
|
640 |
+
else:
|
641 |
+
self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
|
642 |
+
|
643 |
+
self.model = nn.Sequential(*self.model)
|
644 |
+
self.output_dim = dim
|
645 |
+
|
646 |
+
def forward(self, x):
|
647 |
+
if self.vae:
|
648 |
+
output = self.model(x)
|
649 |
+
output = output.view(x.size(0), -1)
|
650 |
+
output_mean = self.fc_mean(output)
|
651 |
+
output_var = self.fc_var(output)
|
652 |
+
return output_mean, output_var
|
653 |
+
else:
|
654 |
+
return self.model(x).view(x.size(0), -1)
|
655 |
+
|
656 |
+
|
657 |
+
class ContentEncoder(nn.Module):
|
658 |
+
def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'):
|
659 |
+
super(ContentEncoder, self).__init__()
|
660 |
+
self.model = []
|
661 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
|
662 |
+
# downsampling blocks
|
663 |
+
for i in range(n_downsample):
|
664 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
665 |
+
dim *= 2
|
666 |
+
# residual blocks
|
667 |
+
self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
|
668 |
+
self.model = nn.Sequential(*self.model)
|
669 |
+
self.output_dim = dim
|
670 |
+
|
671 |
+
def forward(self, x, nce_layers=[], encode_only=False):
|
672 |
+
if len(nce_layers) > 0:
|
673 |
+
feat = x
|
674 |
+
feats = []
|
675 |
+
for layer_id, layer in enumerate(self.model):
|
676 |
+
feat = layer(feat)
|
677 |
+
if layer_id in nce_layers:
|
678 |
+
feats.append(feat)
|
679 |
+
if layer_id == nce_layers[-1] and encode_only:
|
680 |
+
return None, feats
|
681 |
+
return feat, feats
|
682 |
+
else:
|
683 |
+
return self.model(x), None
|
684 |
+
|
685 |
+
for layer_id, layer in enumerate(self.model):
|
686 |
+
print(layer_id, layer)
|
687 |
+
|
688 |
+
|
689 |
+
class Decoder_all(nn.Module):
|
690 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
|
691 |
+
super(Decoder_all, self).__init__()
|
692 |
+
# AdaIN residual blocks
|
693 |
+
self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)
|
694 |
+
self.n_blocks = 0
|
695 |
+
# upsampling blocks
|
696 |
+
for i in range(n_upsample):
|
697 |
+
block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
|
698 |
+
setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block))
|
699 |
+
self.n_blocks += 1
|
700 |
+
dim //= 2
|
701 |
+
# use reflection padding in the last conv layer
|
702 |
+
setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect'))
|
703 |
+
self.n_blocks += 1
|
704 |
+
|
705 |
+
def forward(self, x, y=None):
|
706 |
+
if y is not None:
|
707 |
+
output = self.resnet_block(cat_feature(x, y))
|
708 |
+
for n in range(self.n_blocks):
|
709 |
+
block = getattr(self, 'block_{:d}'.format(n))
|
710 |
+
if n > 0:
|
711 |
+
output = block(cat_feature(output, y))
|
712 |
+
else:
|
713 |
+
output = block(output)
|
714 |
+
return output
|
715 |
+
|
716 |
+
|
717 |
+
class Decoder(nn.Module):
|
718 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
|
719 |
+
super(Decoder, self).__init__()
|
720 |
+
|
721 |
+
self.model = []
|
722 |
+
# AdaIN residual blocks
|
723 |
+
self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)]
|
724 |
+
# upsampling blocks
|
725 |
+
for i in range(n_upsample):
|
726 |
+
if i == 0:
|
727 |
+
input_dim = dim + nz
|
728 |
+
else:
|
729 |
+
input_dim = dim
|
730 |
+
self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
|
731 |
+
dim //= 2
|
732 |
+
# use reflection padding in the last conv layer
|
733 |
+
self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')]
|
734 |
+
self.model = nn.Sequential(*self.model)
|
735 |
+
|
736 |
+
def forward(self, x, y=None):
|
737 |
+
if y is not None:
|
738 |
+
return self.model(cat_feature(x, y))
|
739 |
+
else:
|
740 |
+
return self.model(x)
|
741 |
+
|
742 |
+
##################################################################################
|
743 |
+
# Sequential Models
|
744 |
+
##################################################################################
|
745 |
+
|
746 |
+
|
747 |
+
class ResBlocks(nn.Module):
|
748 |
+
def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
|
749 |
+
super(ResBlocks, self).__init__()
|
750 |
+
self.model = []
|
751 |
+
for i in range(num_blocks):
|
752 |
+
self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)]
|
753 |
+
self.model = nn.Sequential(*self.model)
|
754 |
+
|
755 |
+
def forward(self, x):
|
756 |
+
return self.model(x)
|
757 |
+
|
758 |
+
|
759 |
+
##################################################################################
|
760 |
+
# Basic Blocks
|
761 |
+
##################################################################################
|
762 |
+
def cat_feature(x, y):
|
763 |
+
y_expand = y.view(y.size(0), y.size(1), 1, 1).expand(
|
764 |
+
y.size(0), y.size(1), x.size(2), x.size(3))
|
765 |
+
x_cat = torch.cat([x, y_expand], 1)
|
766 |
+
return x_cat
|
767 |
+
|
768 |
+
|
769 |
+
class ResBlock(nn.Module):
|
770 |
+
def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
|
771 |
+
super(ResBlock, self).__init__()
|
772 |
+
|
773 |
+
model = []
|
774 |
+
model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
|
775 |
+
model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
|
776 |
+
self.model = nn.Sequential(*model)
|
777 |
+
|
778 |
+
def forward(self, x):
|
779 |
+
residual = x
|
780 |
+
out = self.model(x)
|
781 |
+
out += residual
|
782 |
+
return out
|
783 |
+
|
784 |
+
|
785 |
+
class Conv2dBlock(nn.Module):
|
786 |
+
def __init__(self, input_dim, output_dim, kernel_size, stride,
|
787 |
+
padding=0, norm='none', activation='relu', pad_type='zero'):
|
788 |
+
super(Conv2dBlock, self).__init__()
|
789 |
+
self.use_bias = True
|
790 |
+
# initialize padding
|
791 |
+
if pad_type == 'reflect':
|
792 |
+
self.pad = nn.ReflectionPad2d(padding)
|
793 |
+
elif pad_type == 'zero':
|
794 |
+
self.pad = nn.ZeroPad2d(padding)
|
795 |
+
else:
|
796 |
+
assert 0, "Unsupported padding type: {}".format(pad_type)
|
797 |
+
|
798 |
+
# initialize normalization
|
799 |
+
norm_dim = output_dim
|
800 |
+
if norm == 'batch':
|
801 |
+
self.norm = nn.BatchNorm2d(norm_dim)
|
802 |
+
elif norm == 'inst':
|
803 |
+
self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False)
|
804 |
+
elif norm == 'ln':
|
805 |
+
self.norm = LayerNorm(norm_dim)
|
806 |
+
elif norm == 'none':
|
807 |
+
self.norm = None
|
808 |
+
else:
|
809 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
810 |
+
|
811 |
+
# initialize activation
|
812 |
+
if activation == 'relu':
|
813 |
+
self.activation = nn.ReLU(inplace=True)
|
814 |
+
elif activation == 'lrelu':
|
815 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
816 |
+
elif activation == 'prelu':
|
817 |
+
self.activation = nn.PReLU()
|
818 |
+
elif activation == 'selu':
|
819 |
+
self.activation = nn.SELU(inplace=True)
|
820 |
+
elif activation == 'tanh':
|
821 |
+
self.activation = nn.Tanh()
|
822 |
+
elif activation == 'none':
|
823 |
+
self.activation = None
|
824 |
+
else:
|
825 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
826 |
+
|
827 |
+
# initialize convolution
|
828 |
+
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
|
829 |
+
|
830 |
+
def forward(self, x):
|
831 |
+
x = self.conv(self.pad(x))
|
832 |
+
if self.norm:
|
833 |
+
x = self.norm(x)
|
834 |
+
if self.activation:
|
835 |
+
x = self.activation(x)
|
836 |
+
return x
|
837 |
+
|
838 |
+
|
839 |
+
class LinearBlock(nn.Module):
|
840 |
+
def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
|
841 |
+
super(LinearBlock, self).__init__()
|
842 |
+
use_bias = True
|
843 |
+
# initialize fully connected layer
|
844 |
+
self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
|
845 |
+
|
846 |
+
# initialize normalization
|
847 |
+
norm_dim = output_dim
|
848 |
+
if norm == 'batch':
|
849 |
+
self.norm = nn.BatchNorm1d(norm_dim)
|
850 |
+
elif norm == 'inst':
|
851 |
+
self.norm = nn.InstanceNorm1d(norm_dim)
|
852 |
+
elif norm == 'ln':
|
853 |
+
self.norm = LayerNorm(norm_dim)
|
854 |
+
elif norm == 'none':
|
855 |
+
self.norm = None
|
856 |
+
else:
|
857 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
858 |
+
|
859 |
+
# initialize activation
|
860 |
+
if activation == 'relu':
|
861 |
+
self.activation = nn.ReLU(inplace=True)
|
862 |
+
elif activation == 'lrelu':
|
863 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
864 |
+
elif activation == 'prelu':
|
865 |
+
self.activation = nn.PReLU()
|
866 |
+
elif activation == 'selu':
|
867 |
+
self.activation = nn.SELU(inplace=True)
|
868 |
+
elif activation == 'tanh':
|
869 |
+
self.activation = nn.Tanh()
|
870 |
+
elif activation == 'none':
|
871 |
+
self.activation = None
|
872 |
+
else:
|
873 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
874 |
+
|
875 |
+
def forward(self, x):
|
876 |
+
out = self.fc(x)
|
877 |
+
if self.norm:
|
878 |
+
out = self.norm(out)
|
879 |
+
if self.activation:
|
880 |
+
out = self.activation(out)
|
881 |
+
return out
|
882 |
+
|
883 |
+
##################################################################################
|
884 |
+
# Normalization layers
|
885 |
+
##################################################################################
|
886 |
+
|
887 |
+
|
888 |
+
class LayerNorm(nn.Module):
|
889 |
+
def __init__(self, num_features, eps=1e-5, affine=True):
|
890 |
+
super(LayerNorm, self).__init__()
|
891 |
+
self.num_features = num_features
|
892 |
+
self.affine = affine
|
893 |
+
self.eps = eps
|
894 |
+
|
895 |
+
if self.affine:
|
896 |
+
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
|
897 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
898 |
+
|
899 |
+
def forward(self, x):
|
900 |
+
shape = [-1] + [1] * (x.dim() - 1)
|
901 |
+
mean = x.view(x.size(0), -1).mean(1).view(*shape)
|
902 |
+
std = x.view(x.size(0), -1).std(1).view(*shape)
|
903 |
+
x = (x - mean) / (std + self.eps)
|
904 |
+
|
905 |
+
if self.affine:
|
906 |
+
shape = [1, -1] + [1] * (x.dim() - 2)
|
907 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
908 |
+
return x
|
909 |
+
|
910 |
+
|
911 |
+
class ResnetGenerator(nn.Module):
|
912 |
+
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
913 |
+
|
914 |
+
We adapt the Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
|
915 |
+
"""
|
916 |
+
|
917 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
|
918 |
+
"""Construct a Resnet-based generator
|
919 |
+
|
920 |
+
Parameters:
|
921 |
+
input_nc (int) -- the number of channels in input images
|
922 |
+
output_nc (int) -- the number of channels in output images
|
923 |
+
ngf (int) -- the number of filters in the last conv layer
|
924 |
+
norm_layer -- normalization layer
|
925 |
+
use_dropout (bool) -- if use dropout layers
|
926 |
+
n_blocks (int) -- the number of ResNet blocks
|
927 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
928 |
+
"""
|
929 |
+
assert(n_blocks >= 0)
|
930 |
+
super(ResnetGenerator, self).__init__()
|
931 |
+
self.opt = opt
|
932 |
+
if type(norm_layer) == functools.partial:
|
933 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
934 |
+
else:
|
935 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
936 |
+
|
937 |
+
model = [nn.ReflectionPad2d(3),
|
938 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
939 |
+
norm_layer(ngf),
|
940 |
+
nn.ReLU(True)]
|
941 |
+
|
942 |
+
n_downsampling = 2
|
943 |
+
for i in range(n_downsampling): # add downsampling layers
|
944 |
+
mult = 2 ** i
|
945 |
+
if(no_antialias):
|
946 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
947 |
+
norm_layer(ngf * mult * 2),
|
948 |
+
nn.ReLU(True)]
|
949 |
+
else:
|
950 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
951 |
+
norm_layer(ngf * mult * 2),
|
952 |
+
nn.ReLU(True),
|
953 |
+
Downsample(ngf * mult * 2)]
|
954 |
+
|
955 |
+
mult = 2 ** n_downsampling
|
956 |
+
for i in range(n_blocks): # add ResNet blocks
|
957 |
+
|
958 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
959 |
+
|
960 |
+
for i in range(n_downsampling): # add upsampling layers
|
961 |
+
mult = 2 ** (n_downsampling - i)
|
962 |
+
if no_antialias_up:
|
963 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
964 |
+
kernel_size=3, stride=2,
|
965 |
+
padding=1, output_padding=1,
|
966 |
+
bias=use_bias),
|
967 |
+
norm_layer(int(ngf * mult / 2)),
|
968 |
+
nn.ReLU(True)]
|
969 |
+
else:
|
970 |
+
model += [Upsample(ngf * mult),
|
971 |
+
nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
972 |
+
kernel_size=3, stride=1,
|
973 |
+
padding=1, # output_padding=1,
|
974 |
+
bias=use_bias),
|
975 |
+
norm_layer(int(ngf * mult / 2)),
|
976 |
+
nn.ReLU(True)]
|
977 |
+
model += [nn.ReflectionPad2d(3)]
|
978 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
979 |
+
model += [nn.Tanh()]
|
980 |
+
|
981 |
+
self.model = nn.Sequential(*model)
|
982 |
+
|
983 |
+
def forward(self, input, layers=[], encode_only=False,mode='all',stop_layer=16):
|
984 |
+
if -1 in layers:
|
985 |
+
layers.append(len(self.model))
|
986 |
+
if len(layers) > 0:
|
987 |
+
feat = input
|
988 |
+
feats = []
|
989 |
+
for layer_id, layer in enumerate(self.model):
|
990 |
+
# print(layer_id, layer)
|
991 |
+
feat = layer(feat)
|
992 |
+
if layer_id in layers:
|
993 |
+
# print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
994 |
+
feats.append(feat)
|
995 |
+
else:
|
996 |
+
# print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
997 |
+
pass
|
998 |
+
if layer_id == layers[-1] and encode_only:
|
999 |
+
# print('encoder only return features')
|
1000 |
+
return feats # return intermediate features alone; stop in the last layers
|
1001 |
+
|
1002 |
+
return feat, feats # return both output and intermediate features
|
1003 |
+
else:
|
1004 |
+
"""Standard forward"""
|
1005 |
+
if mode=='encoder':
|
1006 |
+
feat=input
|
1007 |
+
for layer_id, layer in enumerate(self.model):
|
1008 |
+
feat = layer(feat)
|
1009 |
+
if layer_id == stop_layer:
|
1010 |
+
# print('encoder only return features')
|
1011 |
+
return feat # return intermediate features alone; stop in the last layers
|
1012 |
+
elif mode =='decoder':
|
1013 |
+
feat=input
|
1014 |
+
for layer_id, layer in enumerate(self.model):
|
1015 |
+
|
1016 |
+
if layer_id > stop_layer:
|
1017 |
+
feat = layer(feat)
|
1018 |
+
else:
|
1019 |
+
pass
|
1020 |
+
# print('encoder only return features')
|
1021 |
+
return feat # return intermediate features alone; stop in the last layers
|
1022 |
+
else:
|
1023 |
+
fake = self.model(input)
|
1024 |
+
return fake
|
1025 |
+
|
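The layered forward above doubles as a feature extractor for the contrastive branch. A minimal usage sketch (illustrative only; netG, img and the layer ids are assumed placeholders, not values from this repo):

# netG: an instance of the ResnetGenerator above; img: a normalized image batch.
nce_layers = [0, 4, 8, 12, 16]                            # assumed layer ids
feats = netG(img, layers=nce_layers, encode_only=True)    # intermediate features only
out, feats = netG(img, layers=nce_layers)                 # final image plus features
enc = netG(img, mode='encoder', stop_layer=16)            # encoder half only
fake = netG(img)                                          # plain translation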
1026 |
+
# class ResnetGenerator(nn.Module):
|
1027 |
+
# """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
1028 |
+
|
1029 |
+
# We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
1030 |
+
# """
|
1031 |
+
|
1032 |
+
# def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
|
1033 |
+
# """Construct a Resnet-based generator
|
1034 |
+
|
1035 |
+
# Parameters:
|
1036 |
+
# input_nc (int) -- the number of channels in input images
|
1037 |
+
# output_nc (int) -- the number of channels in output images
|
1038 |
+
# ngf (int) -- the number of filters in the last conv layer
|
1039 |
+
# norm_layer -- normalization layer
|
1040 |
+
# use_dropout (bool) -- if use dropout layers
|
1041 |
+
# n_blocks (int) -- the number of ResNet blocks
|
1042 |
+
# padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
1043 |
+
# """
|
1044 |
+
# assert(n_blocks >= 0)
|
1045 |
+
# super(ResnetGenerator, self).__init__()
|
1046 |
+
# self.opt = opt
|
1047 |
+
# if type(norm_layer) == functools.partial:
|
1048 |
+
# use_bias = norm_layer.func == nn.InstanceNorm2d
|
1049 |
+
# else:
|
1050 |
+
# use_bias = norm_layer == nn.InstanceNorm2d
|
1051 |
+
|
1052 |
+
# model = [nn.ReflectionPad2d(3),
|
1053 |
+
# nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
1054 |
+
# norm_layer(ngf),
|
1055 |
+
# nn.ReLU(True)]
|
1056 |
+
|
1057 |
+
# n_downsampling = 2
|
1058 |
+
# for i in range(n_downsampling): # add downsampling layers
|
1059 |
+
# mult = 2 ** i
|
1060 |
+
# if(no_antialias):
|
1061 |
+
# model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
1062 |
+
# norm_layer(ngf * mult * 2),
|
1063 |
+
# nn.ReLU(True)]
|
1064 |
+
# else:
|
1065 |
+
# model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
1066 |
+
# norm_layer(ngf * mult * 2),
|
1067 |
+
# nn.ReLU(True),
|
1068 |
+
# Downsample(ngf * mult * 2)]
|
1069 |
+
|
1070 |
+
# mult = 2 ** n_downsampling
|
1071 |
+
# for i in range(n_blocks): # add ResNet blocks
|
1072 |
+
|
1073 |
+
# model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1074 |
+
|
1075 |
+
# for i in range(n_downsampling): # add upsampling layers
|
1076 |
+
# mult = 2 ** (n_downsampling - i)
|
1077 |
+
# if no_antialias_up:
|
1078 |
+
# model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
1079 |
+
# kernel_size=3, stride=2,
|
1080 |
+
# padding=1, output_padding=1,
|
1081 |
+
# bias=use_bias),
|
1082 |
+
# norm_layer(int(ngf * mult / 2)),
|
1083 |
+
# nn.ReLU(True)]
|
1084 |
+
# else:
|
1085 |
+
# model += [Upsample(ngf * mult),
|
1086 |
+
# nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
1087 |
+
# kernel_size=3, stride=1,
|
1088 |
+
# padding=1, # output_padding=1,
|
1089 |
+
# bias=use_bias),
|
1090 |
+
# norm_layer(int(ngf * mult / 2)),
|
1091 |
+
# nn.ReLU(True)]
|
1092 |
+
# model += [nn.ReflectionPad2d(3)]
|
1093 |
+
# model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
1094 |
+
# model += [nn.Tanh()]
|
1095 |
+
|
1096 |
+
# self.model = nn.Sequential(*model)
|
1097 |
+
|
1098 |
+
# def forward(self, input, layers=[], encode_only=False):
|
1099 |
+
# if -1 in layers:
|
1100 |
+
# layers.append(len(self.model))
|
1101 |
+
# if len(layers) > 0:
|
1102 |
+
# feat = input
|
1103 |
+
# feats = []
|
1104 |
+
# for layer_id, layer in enumerate(self.model):
|
1105 |
+
# # print(layer_id, layer)
|
1106 |
+
# feat = layer(feat)
|
1107 |
+
# if layer_id in layers:
|
1108 |
+
# # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
1109 |
+
# feats.append(feat)
|
1110 |
+
# else:
|
1111 |
+
# # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
1112 |
+
# pass
|
1113 |
+
# if layer_id == layers[-1] and encode_only:
|
1114 |
+
# # print('encoder only return features')
|
1115 |
+
# return feats # return intermediate features alone; stop in the last layers
|
1116 |
+
|
1117 |
+
# return feat, feats # return both output and intermediate features
|
1118 |
+
# else:
|
1119 |
+
# """Standard forward"""
|
1120 |
+
# fake = self.model(input)
|
1121 |
+
# return fake
|
1122 |
+
|
1123 |
+
class ResnetDecoder(nn.Module):
|
1124 |
+
"""Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations.
|
1125 |
+
"""
|
1126 |
+
|
1127 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
|
1128 |
+
"""Construct a Resnet-based decoder
|
1129 |
+
|
1130 |
+
Parameters:
|
1131 |
+
input_nc (int) -- the number of channels in input images
|
1132 |
+
output_nc (int) -- the number of channels in output images
|
1133 |
+
ngf (int) -- the number of filters in the last conv layer
|
1134 |
+
norm_layer -- normalization layer
|
1135 |
+
use_dropout (bool) -- if use dropout layers
|
1136 |
+
n_blocks (int) -- the number of ResNet blocks
|
1137 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
1138 |
+
"""
|
1139 |
+
assert(n_blocks >= 0)
|
1140 |
+
super(ResnetDecoder, self).__init__()
|
1141 |
+
if type(norm_layer) == functools.partial:
|
1142 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1143 |
+
else:
|
1144 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1145 |
+
model = []
|
1146 |
+
n_downsampling = 2
|
1147 |
+
mult = 2 ** n_downsampling
|
1148 |
+
for i in range(n_blocks): # add ResNet blocks
|
1149 |
+
|
1150 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1151 |
+
|
1152 |
+
for i in range(n_downsampling): # add upsampling layers
|
1153 |
+
mult = 2 ** (n_downsampling - i)
|
1154 |
+
if(no_antialias):
|
1155 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
1156 |
+
kernel_size=3, stride=2,
|
1157 |
+
padding=1, output_padding=1,
|
1158 |
+
bias=use_bias),
|
1159 |
+
norm_layer(int(ngf * mult / 2)),
|
1160 |
+
nn.ReLU(True)]
|
1161 |
+
else:
|
1162 |
+
model += [Upsample(ngf * mult),
|
1163 |
+
nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
1164 |
+
kernel_size=3, stride=1,
|
1165 |
+
padding=1,
|
1166 |
+
bias=use_bias),
|
1167 |
+
norm_layer(int(ngf * mult / 2)),
|
1168 |
+
nn.ReLU(True)]
|
1169 |
+
model += [nn.ReflectionPad2d(3)]
|
1170 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
1171 |
+
model += [nn.Tanh()]
|
1172 |
+
|
1173 |
+
self.model = nn.Sequential(*model)
|
1174 |
+
|
1175 |
+
def forward(self, input):
|
1176 |
+
"""Standard forward"""
|
1177 |
+
return self.model(input)
|
1178 |
+
|
1179 |
+
|
1180 |
+
class ResnetEncoder(nn.Module):
|
1181 |
+
"""Resnet-based encoder that consists of a few downsampling + several Resnet blocks
|
1182 |
+
"""
|
1183 |
+
|
1184 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
|
1185 |
+
"""Construct a Resnet-based encoder
|
1186 |
+
|
1187 |
+
Parameters:
|
1188 |
+
input_nc (int) -- the number of channels in input images
|
1189 |
+
output_nc (int) -- the number of channels in output images
|
1190 |
+
ngf (int) -- the number of filters in the last conv layer
|
1191 |
+
norm_layer -- normalization layer
|
1192 |
+
use_dropout (bool) -- if use dropout layers
|
1193 |
+
n_blocks (int) -- the number of ResNet blocks
|
1194 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
1195 |
+
"""
|
1196 |
+
assert(n_blocks >= 0)
|
1197 |
+
super(ResnetEncoder, self).__init__()
|
1198 |
+
if type(norm_layer) == functools.partial:
|
1199 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1200 |
+
else:
|
1201 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1202 |
+
|
1203 |
+
model = [nn.ReflectionPad2d(3),
|
1204 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
1205 |
+
norm_layer(ngf),
|
1206 |
+
nn.ReLU(True)]
|
1207 |
+
|
1208 |
+
n_downsampling = 2
|
1209 |
+
for i in range(n_downsampling): # add downsampling layers
|
1210 |
+
mult = 2 ** i
|
1211 |
+
if(no_antialias):
|
1212 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
1213 |
+
norm_layer(ngf * mult * 2),
|
1214 |
+
nn.ReLU(True)]
|
1215 |
+
else:
|
1216 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
1217 |
+
norm_layer(ngf * mult * 2),
|
1218 |
+
nn.ReLU(True),
|
1219 |
+
Downsample(ngf * mult * 2)]
|
1220 |
+
|
1221 |
+
mult = 2 ** n_downsampling
|
1222 |
+
for i in range(n_blocks): # add ResNet blocks
|
1223 |
+
|
1224 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1225 |
+
|
1226 |
+
self.model = nn.Sequential(*model)
|
1227 |
+
|
1228 |
+
def forward(self, input):
|
1229 |
+
"""Standard forward"""
|
1230 |
+
return self.model(input)
|
1231 |
+
|
1232 |
+
|
1233 |
+
class ResnetBlock(nn.Module):
|
1234 |
+
"""Define a Resnet block"""
|
1235 |
+
|
1236 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
1237 |
+
"""Initialize the Resnet block
|
1238 |
+
|
1239 |
+
A resnet block is a conv block with skip connections
|
1240 |
+
We construct a conv block with build_conv_block function,
|
1241 |
+
and implement skip connections in <forward> function.
|
1242 |
+
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
1243 |
+
"""
|
1244 |
+
super(ResnetBlock, self).__init__()
|
1245 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
1246 |
+
|
1247 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
1248 |
+
"""Construct a convolutional block.
|
1249 |
+
|
1250 |
+
Parameters:
|
1251 |
+
dim (int) -- the number of channels in the conv layer.
|
1252 |
+
padding_type (str) -- the name of padding layer: reflect | replicate | zero
|
1253 |
+
norm_layer -- normalization layer
|
1254 |
+
use_dropout (bool) -- if use dropout layers.
|
1255 |
+
use_bias (bool) -- if the conv layer uses bias or not
|
1256 |
+
|
1257 |
+
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
|
1258 |
+
"""
|
1259 |
+
conv_block = []
|
1260 |
+
p = 0
|
1261 |
+
if padding_type == 'reflect':
|
1262 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
1263 |
+
elif padding_type == 'replicate':
|
1264 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
1265 |
+
elif padding_type == 'zero':
|
1266 |
+
p = 1
|
1267 |
+
else:
|
1268 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
1269 |
+
|
1270 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
|
1271 |
+
if use_dropout:
|
1272 |
+
conv_block += [nn.Dropout(0.5)]
|
1273 |
+
|
1274 |
+
p = 0
|
1275 |
+
if padding_type == 'reflect':
|
1276 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
1277 |
+
elif padding_type == 'replicate':
|
1278 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
1279 |
+
elif padding_type == 'zero':
|
1280 |
+
p = 1
|
1281 |
+
else:
|
1282 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
1283 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
|
1284 |
+
|
1285 |
+
return nn.Sequential(*conv_block)
|
1286 |
+
|
1287 |
+
def forward(self, x):
|
1288 |
+
"""Forward function (with skip connections)"""
|
1289 |
+
out = x + self.conv_block(x) # add skip connections
|
1290 |
+
return out
|
1291 |
+
|
1292 |
+
|
1293 |
+
class UnetGenerator(nn.Module):
|
1294 |
+
"""Create a Unet-based generator"""
|
1295 |
+
|
1296 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
1297 |
+
"""Construct a Unet generator
|
1298 |
+
Parameters:
|
1299 |
+
input_nc (int) -- the number of channels in input images
|
1300 |
+
output_nc (int) -- the number of channels in output images
|
1301 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
|
1302 |
+
an image of size 128x128 will become of size 1x1 at the bottleneck
|
1303 |
+
ngf (int) -- the number of filters in the last conv layer
|
1304 |
+
norm_layer -- normalization layer
|
1305 |
+
|
1306 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
1307 |
+
It is a recursive process.
|
1308 |
+
"""
|
1309 |
+
super(UnetGenerator, self).__init__()
|
1310 |
+
# construct unet structure
|
1311 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
|
1312 |
+
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
|
1313 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
1314 |
+
# gradually reduce the number of filters from ngf * 8 to ngf
|
1315 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1316 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1317 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1318 |
+
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
|
1319 |
+
|
1320 |
+
def forward(self, input):
|
1321 |
+
"""Standard forward"""
|
1322 |
+
return self.model(input)
|
1323 |
+
|
1324 |
+
|
1325 |
+
class UnetSkipConnectionBlock(nn.Module):
|
1326 |
+
"""Defines the Unet submodule with skip connection.
|
1327 |
+
X -------------------identity----------------------
|
1328 |
+
|-- downsampling -- |submodule| -- upsampling --|
|
1329 |
+
"""
|
1330 |
+
|
1331 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
1332 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
1333 |
+
"""Construct a Unet submodule with skip connections.
|
1334 |
+
|
1335 |
+
Parameters:
|
1336 |
+
outer_nc (int) -- the number of filters in the outer conv layer
|
1337 |
+
inner_nc (int) -- the number of filters in the inner conv layer
|
1338 |
+
input_nc (int) -- the number of channels in input images/features
|
1339 |
+
submodule (UnetSkipConnectionBlock) -- previously defined submodules
|
1340 |
+
outermost (bool) -- if this module is the outermost module
|
1341 |
+
innermost (bool) -- if this module is the innermost module
|
1342 |
+
norm_layer -- normalization layer
|
1343 |
+
use_dropout (bool) -- if use dropout layers.
|
1344 |
+
"""
|
1345 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
1346 |
+
self.outermost = outermost
|
1347 |
+
if type(norm_layer) == functools.partial:
|
1348 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1349 |
+
else:
|
1350 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1351 |
+
if input_nc is None:
|
1352 |
+
input_nc = outer_nc
|
1353 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
1354 |
+
stride=2, padding=1, bias=use_bias)
|
1355 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
1356 |
+
downnorm = norm_layer(inner_nc)
|
1357 |
+
uprelu = nn.ReLU(True)
|
1358 |
+
upnorm = norm_layer(outer_nc)
|
1359 |
+
|
1360 |
+
if outermost:
|
1361 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1362 |
+
kernel_size=4, stride=2,
|
1363 |
+
padding=1)
|
1364 |
+
down = [downconv]
|
1365 |
+
up = [uprelu, upconv, nn.Tanh()]
|
1366 |
+
model = down + [submodule] + up
|
1367 |
+
elif innermost:
|
1368 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
1369 |
+
kernel_size=4, stride=2,
|
1370 |
+
padding=1, bias=use_bias)
|
1371 |
+
down = [downrelu, downconv]
|
1372 |
+
up = [uprelu, upconv, upnorm]
|
1373 |
+
model = down + up
|
1374 |
+
else:
|
1375 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1376 |
+
kernel_size=4, stride=2,
|
1377 |
+
padding=1, bias=use_bias)
|
1378 |
+
down = [downrelu, downconv, downnorm]
|
1379 |
+
up = [uprelu, upconv, upnorm]
|
1380 |
+
|
1381 |
+
if use_dropout:
|
1382 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
1383 |
+
else:
|
1384 |
+
model = down + [submodule] + up
|
1385 |
+
|
1386 |
+
self.model = nn.Sequential(*model)
|
1387 |
+
|
1388 |
+
def forward(self, x):
|
1389 |
+
if self.outermost:
|
1390 |
+
return self.model(x)
|
1391 |
+
else: # add skip connections
|
1392 |
+
return torch.cat([x, self.model(x)], 1)
|
1393 |
+
|
1394 |
+
|
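A quick shape check for the U-Net pair above (illustrative only; batch size and resolution are assumed): with num_downs=7, a 128x128 input reaches 1x1 at the bottleneck and is decoded back to the input resolution.

import torch
import torch.nn as nn

# Hypothetical check, not repo code: seven downsamplings take 128 -> 1 at the bottleneck.
netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=7, ngf=64,
                     norm_layer=nn.BatchNorm2d)
y = netG(torch.randn(2, 3, 128, 128))
print(y.shape)  # torch.Size([2, 3, 128, 128])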
1395 |
+
class NLayerDiscriminator(nn.Module):
|
1396 |
+
"""Defines a PatchGAN discriminator"""
|
1397 |
+
|
1398 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
|
1399 |
+
"""Construct a PatchGAN discriminator
|
1400 |
+
|
1401 |
+
Parameters:
|
1402 |
+
input_nc (int) -- the number of channels in input images
|
1403 |
+
ndf (int) -- the number of filters in the last conv layer
|
1404 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
1405 |
+
norm_layer -- normalization layer
|
1406 |
+
"""
|
1407 |
+
super(NLayerDiscriminator, self).__init__()
|
1408 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1409 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1410 |
+
else:
|
1411 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1412 |
+
|
1413 |
+
kw = 4
|
1414 |
+
padw = 1
|
1415 |
+
if(no_antialias):
|
1416 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
1417 |
+
else:
|
1418 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
|
1419 |
+
nf_mult = 1
|
1420 |
+
nf_mult_prev = 1
|
1421 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
1422 |
+
nf_mult_prev = nf_mult
|
1423 |
+
nf_mult = min(2 ** n, 8)
|
1424 |
+
if(no_antialias):
|
1425 |
+
sequence += [
|
1426 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
1427 |
+
norm_layer(ndf * nf_mult),
|
1428 |
+
nn.LeakyReLU(0.2, True)
|
1429 |
+
]
|
1430 |
+
else:
|
1431 |
+
sequence += [
|
1432 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
1433 |
+
norm_layer(ndf * nf_mult),
|
1434 |
+
nn.LeakyReLU(0.2, True),
|
1435 |
+
Downsample(ndf * nf_mult)]
|
1436 |
+
|
1437 |
+
nf_mult_prev = nf_mult
|
1438 |
+
nf_mult = min(2 ** n_layers, 8)
|
1439 |
+
sequence += [
|
1440 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
1441 |
+
norm_layer(ndf * nf_mult),
|
1442 |
+
nn.LeakyReLU(0.2, True)
|
1443 |
+
]
|
1444 |
+
|
1445 |
+
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
1446 |
+
self.model = nn.Sequential(*sequence)
|
1447 |
+
|
1448 |
+
def forward(self, input):
|
1449 |
+
"""Standard forward."""
|
1450 |
+
return self.model(input)
|
1451 |
+
|
1452 |
+
|
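The PatchGAN above emits a grid of per-patch logits rather than a single score. A minimal shape check (illustrative only; no_antialias=True keeps the sketch free of the custom Downsample layer):

import torch
import torch.nn as nn

# Hypothetical check: a 3-layer PatchGAN maps a 256x256 image to roughly 30x30 logits,
# one per overlapping receptive-field patch.
netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3,
                           norm_layer=nn.BatchNorm2d, no_antialias=True)
pred = netD(torch.randn(1, 3, 256, 256))
print(pred.shape)  # torch.Size([1, 1, 30, 30])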
1453 |
+
class PixelDiscriminator(nn.Module):
|
1454 |
+
"""Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
|
1455 |
+
|
1456 |
+
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
|
1457 |
+
"""Construct a 1x1 PatchGAN discriminator
|
1458 |
+
|
1459 |
+
Parameters:
|
1460 |
+
input_nc (int) -- the number of channels in input images
|
1461 |
+
ndf (int) -- the number of filters in the last conv layer
|
1462 |
+
norm_layer -- normalization layer
|
1463 |
+
"""
|
1464 |
+
super(PixelDiscriminator, self).__init__()
|
1465 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1466 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1467 |
+
else:
|
1468 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1469 |
+
|
1470 |
+
self.net = [
|
1471 |
+
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
|
1472 |
+
nn.LeakyReLU(0.2, True),
|
1473 |
+
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
|
1474 |
+
norm_layer(ndf * 2),
|
1475 |
+
nn.LeakyReLU(0.2, True),
|
1476 |
+
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
|
1477 |
+
|
1478 |
+
self.net = nn.Sequential(*self.net)
|
1479 |
+
|
1480 |
+
def forward(self, input):
|
1481 |
+
"""Standard forward."""
|
1482 |
+
return self.net(input)
|
1483 |
+
|
1484 |
+
|
1485 |
+
class PatchDiscriminator(NLayerDiscriminator):
|
1486 |
+
"""Defines a PatchGAN discriminator"""
|
1487 |
+
|
1488 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
|
1489 |
+
super().__init__(input_nc, ndf, 2, norm_layer, no_antialias)
|
1490 |
+
|
1491 |
+
def forward(self, input):
|
1492 |
+
B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
|
1493 |
+
size = 16
|
1494 |
+
Y = H // size
|
1495 |
+
X = W // size
|
1496 |
+
input = input.view(B, C, Y, size, X, size)
|
1497 |
+
input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
|
1498 |
+
return super().forward(input)
|
1499 |
+
|
1500 |
+
|
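The patch-folding reshape used by PatchDiscriminator (and later by TileStyleGAN2Discriminator) can be checked in isolation; an illustrative sketch with assumed dummy shapes:

import torch

# Fold a batch of images into non-overlapping 16x16 crops; each crop is then scored
# independently by the parent NLayerDiscriminator.
B, C, H, W, size = 2, 3, 64, 64, 16            # assumed shapes
x = torch.randn(B, C, H, W)
Y, X = H // size, W // size
patches = x.view(B, C, Y, size, X, size)
patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
print(patches.shape)  # torch.Size([32, 3, 16, 16])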
1501 |
+
class GroupedChannelNorm(nn.Module):
|
1502 |
+
def __init__(self, num_groups):
|
1503 |
+
super().__init__()
|
1504 |
+
self.num_groups = num_groups
|
1505 |
+
|
1506 |
+
def forward(self, x):
|
1507 |
+
shape = list(x.shape)
|
1508 |
+
new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:]
|
1509 |
+
x = x.view(*new_shape)
|
1510 |
+
mean = x.mean(dim=2, keepdim=True)
|
1511 |
+
std = x.std(dim=2, keepdim=True)
|
1512 |
+
x_norm = (x - mean) / (std + 1e-7)
|
1513 |
+
return x_norm.view(*shape)
|
Scenimefy/models/patchnce.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
from packaging import version
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class PatchNCELoss(nn.Module):
|
7 |
+
def __init__(self, opt):
|
8 |
+
super().__init__()
|
9 |
+
self.opt = opt
|
10 |
+
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
|
11 |
+
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
|
12 |
+
|
13 |
+
def forward(self, feat_q, feat_k, weight=None):
|
14 |
+
batchSize = feat_q.shape[0]
|
15 |
+
dim = feat_q.shape[1]
|
16 |
+
feat_k = feat_k.detach()
|
17 |
+
|
18 |
+
# pos logit
|
19 |
+
l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
|
20 |
+
l_pos = l_pos.view(batchSize, 1)
|
21 |
+
|
22 |
+
# neg logit
|
23 |
+
|
24 |
+
# Should the negatives from the other samples of a minibatch be utilized?
|
25 |
+
# In CUT and FastCUT, we found that it's best to only include negatives
|
26 |
+
# from the same image. Therefore, we set
|
27 |
+
# --nce_includes_all_negatives_from_minibatch as False
|
28 |
+
# However, for single-image translation, the minibatch consists of
|
29 |
+
# crops from the "same" high-resolution image.
|
30 |
+
# Therefore, we will include the negatives from the entire minibatch.
|
31 |
+
if self.opt.nce_includes_all_negatives_from_minibatch:
|
32 |
+
# reshape features as if they are all negatives of minibatch of size 1.
|
33 |
+
batch_dim_for_bmm = 1
|
34 |
+
else:
|
35 |
+
batch_dim_for_bmm = self.opt.batch_size
|
36 |
+
|
37 |
+
# reshape features to batch size
|
38 |
+
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
|
39 |
+
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
|
40 |
+
npatches = feat_q.size(1)
|
41 |
+
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
|
42 |
+
|
43 |
+
if weight is not None:
|
44 |
+
l_neg_curbatch *= weight
|
45 |
+
|
46 |
+
# diagonal entries are similarity between same features, and hence meaningless.
|
47 |
+
# just fill the diagonal with a large negative number (-10), whose exp is almost zero
|
48 |
+
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
|
49 |
+
l_neg_curbatch.masked_fill_(diagonal, -10.0)
|
50 |
+
l_neg = l_neg_curbatch.view(-1, npatches)
|
51 |
+
|
52 |
+
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
|
53 |
+
|
54 |
+
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
|
55 |
+
device=feat_q.device))
|
56 |
+
|
57 |
+
return loss
|
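The logit construction in PatchNCELoss can also be written out standalone; an illustrative sketch with assumed patch count, feature dimension and temperature (not values from the repo's options):

import torch
import torch.nn.functional as F

# One positive per query (its own key) against all other patches as negatives.
num_patches, dim, tau = 8, 4, 0.07             # assumed values
feat_q = F.normalize(torch.randn(num_patches, dim), dim=1)
feat_k = F.normalize(torch.randn(num_patches, dim), dim=1).detach()

l_pos = (feat_q * feat_k).sum(dim=1, keepdim=True)    # (8, 1) positive logits
l_neg = feat_q @ feat_k.t()                           # (8, 8) negative logits
l_neg.fill_diagonal_(-10.0)                           # mask self-similarity
logits = torch.cat([l_pos, l_neg], dim=1) / tau
loss = F.cross_entropy(logits, torch.zeros(num_patches, dtype=torch.long))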
Scenimefy/models/stylegan_networks.py
ADDED
@@ -0,0 +1,914 @@
1 |
+
"""
|
2 |
+
The network architectures are based on a PyTorch implementation of StyleGAN2.
|
3 |
+
Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch
|
4 |
+
Original StyleGAN2 paper: https://github.com/NVlabs/stylegan2
|
5 |
+
We use this network architecture for our single-image training setting.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
import random
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
|
17 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
18 |
+
return F.leaky_relu(input + bias, negative_slope) * scale
|
19 |
+
|
20 |
+
|
21 |
+
class FusedLeakyReLU(nn.Module):
|
22 |
+
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
23 |
+
super().__init__()
|
24 |
+
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
25 |
+
self.negative_slope = negative_slope
|
26 |
+
self.scale = scale
|
27 |
+
|
28 |
+
def forward(self, input):
|
29 |
+
# print("FusedLeakyReLU: ", input.abs().mean())
|
30 |
+
out = fused_leaky_relu(input, self.bias,
|
31 |
+
self.negative_slope,
|
32 |
+
self.scale)
|
33 |
+
# print("FusedLeakyReLU: ", out.abs().mean())
|
34 |
+
return out
|
35 |
+
|
36 |
+
|
37 |
+
def upfirdn2d_native(
|
38 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
39 |
+
):
|
40 |
+
_, minor, in_h, in_w = input.shape
|
41 |
+
kernel_h, kernel_w = kernel.shape
|
42 |
+
|
43 |
+
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
44 |
+
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
45 |
+
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
46 |
+
|
47 |
+
out = F.pad(
|
48 |
+
out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
49 |
+
)
|
50 |
+
out = out[
|
51 |
+
:,
|
52 |
+
:,
|
53 |
+
max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
|
54 |
+
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),
|
55 |
+
]
|
56 |
+
|
57 |
+
# out = out.permute(0, 3, 1, 2)
|
58 |
+
out = out.reshape(
|
59 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
60 |
+
)
|
61 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
62 |
+
out = F.conv2d(out, w)
|
63 |
+
out = out.reshape(
|
64 |
+
-1,
|
65 |
+
minor,
|
66 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
67 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
68 |
+
)
|
69 |
+
# out = out.permute(0, 2, 3, 1)
|
70 |
+
|
71 |
+
return out[:, :, ::down_y, ::down_x]
|
72 |
+
|
73 |
+
|
74 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
75 |
+
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
76 |
+
|
77 |
+
|
78 |
+
class PixelNorm(nn.Module):
|
79 |
+
def __init__(self):
|
80 |
+
super().__init__()
|
81 |
+
|
82 |
+
def forward(self, input):
|
83 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
84 |
+
|
85 |
+
|
86 |
+
def make_kernel(k):
|
87 |
+
k = torch.tensor(k, dtype=torch.float32)
|
88 |
+
|
89 |
+
if len(k.shape) == 1:
|
90 |
+
k = k[None, :] * k[:, None]
|
91 |
+
|
92 |
+
k /= k.sum()
|
93 |
+
|
94 |
+
return k
|
95 |
+
|
96 |
+
|
97 |
+
class Upsample(nn.Module):
|
98 |
+
def __init__(self, kernel, factor=2):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
self.factor = factor
|
102 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
103 |
+
self.register_buffer('kernel', kernel)
|
104 |
+
|
105 |
+
p = kernel.shape[0] - factor
|
106 |
+
|
107 |
+
pad0 = (p + 1) // 2 + factor - 1
|
108 |
+
pad1 = p // 2
|
109 |
+
|
110 |
+
self.pad = (pad0, pad1)
|
111 |
+
|
112 |
+
def forward(self, input):
|
113 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
114 |
+
|
115 |
+
return out
|
116 |
+
|
117 |
+
|
118 |
+
class Downsample(nn.Module):
|
119 |
+
def __init__(self, kernel, factor=2):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.factor = factor
|
123 |
+
kernel = make_kernel(kernel)
|
124 |
+
self.register_buffer('kernel', kernel)
|
125 |
+
|
126 |
+
p = kernel.shape[0] - factor
|
127 |
+
|
128 |
+
pad0 = (p + 1) // 2
|
129 |
+
pad1 = p // 2
|
130 |
+
|
131 |
+
self.pad = (pad0, pad1)
|
132 |
+
|
133 |
+
def forward(self, input):
|
134 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
135 |
+
|
136 |
+
return out
|
137 |
+
|
138 |
+
|
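The blur-kernel handling above can be sanity-checked on its own; an illustrative sketch (channel count and resolution are assumed):

import torch

# make_kernel turns the separable [1, 3, 3, 1] taps into a normalized 4x4 filter;
# the Downsample module above blurs with it and then halves the resolution.
k = make_kernel([1, 3, 3, 1])
down = Downsample([1, 3, 3, 1])
y = down(torch.randn(1, 8, 32, 32))
print(k.shape, y.shape)  # torch.Size([4, 4]) torch.Size([1, 8, 16, 16])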
139 |
+
class Blur(nn.Module):
|
140 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
141 |
+
super().__init__()
|
142 |
+
|
143 |
+
kernel = make_kernel(kernel)
|
144 |
+
|
145 |
+
if upsample_factor > 1:
|
146 |
+
kernel = kernel * (upsample_factor ** 2)
|
147 |
+
|
148 |
+
self.register_buffer('kernel', kernel)
|
149 |
+
|
150 |
+
self.pad = pad
|
151 |
+
|
152 |
+
def forward(self, input):
|
153 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
154 |
+
|
155 |
+
return out
|
156 |
+
|
157 |
+
|
158 |
+
class EqualConv2d(nn.Module):
|
159 |
+
def __init__(
|
160 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
161 |
+
):
|
162 |
+
super().__init__()
|
163 |
+
|
164 |
+
self.weight = nn.Parameter(
|
165 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
166 |
+
)
|
167 |
+
self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))
|
168 |
+
|
169 |
+
self.stride = stride
|
170 |
+
self.padding = padding
|
171 |
+
|
172 |
+
if bias:
|
173 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
174 |
+
|
175 |
+
else:
|
176 |
+
self.bias = None
|
177 |
+
|
178 |
+
def forward(self, input):
|
179 |
+
# print("Before EqualConv2d: ", input.abs().mean())
|
180 |
+
out = F.conv2d(
|
181 |
+
input,
|
182 |
+
self.weight * self.scale,
|
183 |
+
bias=self.bias,
|
184 |
+
stride=self.stride,
|
185 |
+
padding=self.padding,
|
186 |
+
)
|
187 |
+
# print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean())
|
188 |
+
|
189 |
+
return out
|
190 |
+
|
191 |
+
def __repr__(self):
|
192 |
+
return (
|
193 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
194 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
195 |
+
)
|
196 |
+
|
197 |
+
|
198 |
+
class EqualLinear(nn.Module):
|
199 |
+
def __init__(
|
200 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
201 |
+
):
|
202 |
+
super().__init__()
|
203 |
+
|
204 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
205 |
+
|
206 |
+
if bias:
|
207 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
208 |
+
|
209 |
+
else:
|
210 |
+
self.bias = None
|
211 |
+
|
212 |
+
self.activation = activation
|
213 |
+
|
214 |
+
self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul
|
215 |
+
self.lr_mul = lr_mul
|
216 |
+
|
217 |
+
def forward(self, input):
|
218 |
+
if self.activation:
|
219 |
+
out = F.linear(input, self.weight * self.scale)
|
220 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
221 |
+
|
222 |
+
else:
|
223 |
+
out = F.linear(
|
224 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
225 |
+
)
|
226 |
+
|
227 |
+
return out
|
228 |
+
|
229 |
+
def __repr__(self):
|
230 |
+
return (
|
231 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
232 |
+
)
|
233 |
+
|
234 |
+
|
235 |
+
class ScaledLeakyReLU(nn.Module):
|
236 |
+
def __init__(self, negative_slope=0.2):
|
237 |
+
super().__init__()
|
238 |
+
|
239 |
+
self.negative_slope = negative_slope
|
240 |
+
|
241 |
+
def forward(self, input):
|
242 |
+
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
243 |
+
|
244 |
+
return out * math.sqrt(2)
|
245 |
+
|
246 |
+
|
247 |
+
class ModulatedConv2d(nn.Module):
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
in_channel,
|
251 |
+
out_channel,
|
252 |
+
kernel_size,
|
253 |
+
style_dim,
|
254 |
+
demodulate=True,
|
255 |
+
upsample=False,
|
256 |
+
downsample=False,
|
257 |
+
blur_kernel=[1, 3, 3, 1],
|
258 |
+
):
|
259 |
+
super().__init__()
|
260 |
+
|
261 |
+
self.eps = 1e-8
|
262 |
+
self.kernel_size = kernel_size
|
263 |
+
self.in_channel = in_channel
|
264 |
+
self.out_channel = out_channel
|
265 |
+
self.upsample = upsample
|
266 |
+
self.downsample = downsample
|
267 |
+
|
268 |
+
if upsample:
|
269 |
+
factor = 2
|
270 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
271 |
+
pad0 = (p + 1) // 2 + factor - 1
|
272 |
+
pad1 = p // 2 + 1
|
273 |
+
|
274 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
275 |
+
|
276 |
+
if downsample:
|
277 |
+
factor = 2
|
278 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
279 |
+
pad0 = (p + 1) // 2
|
280 |
+
pad1 = p // 2
|
281 |
+
|
282 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
283 |
+
|
284 |
+
fan_in = in_channel * kernel_size ** 2
|
285 |
+
self.scale = math.sqrt(1) / math.sqrt(fan_in)
|
286 |
+
self.padding = kernel_size // 2
|
287 |
+
|
288 |
+
self.weight = nn.Parameter(
|
289 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
290 |
+
)
|
291 |
+
|
292 |
+
if style_dim is not None and style_dim > 0:
|
293 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
294 |
+
|
295 |
+
self.demodulate = demodulate
|
296 |
+
|
297 |
+
def __repr__(self):
|
298 |
+
return (
|
299 |
+
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
300 |
+
f'upsample={self.upsample}, downsample={self.downsample})'
|
301 |
+
)
|
302 |
+
|
303 |
+
def forward(self, input, style):
|
304 |
+
batch, in_channel, height, width = input.shape
|
305 |
+
|
306 |
+
if style is not None:
|
307 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
308 |
+
else:
|
309 |
+
style = torch.ones(batch, 1, in_channel, 1, 1).cuda()
|
310 |
+
weight = self.scale * self.weight * style
|
311 |
+
|
312 |
+
if self.demodulate:
|
313 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
314 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
315 |
+
|
316 |
+
weight = weight.view(
|
317 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
318 |
+
)
|
319 |
+
|
320 |
+
if self.upsample:
|
321 |
+
input = input.view(1, batch * in_channel, height, width)
|
322 |
+
weight = weight.view(
|
323 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
324 |
+
)
|
325 |
+
weight = weight.transpose(1, 2).reshape(
|
326 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
327 |
+
)
|
328 |
+
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
329 |
+
_, _, height, width = out.shape
|
330 |
+
out = out.view(batch, self.out_channel, height, width)
|
331 |
+
out = self.blur(out)
|
332 |
+
|
333 |
+
elif self.downsample:
|
334 |
+
input = self.blur(input)
|
335 |
+
_, _, height, width = input.shape
|
336 |
+
input = input.view(1, batch * in_channel, height, width)
|
337 |
+
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
338 |
+
_, _, height, width = out.shape
|
339 |
+
out = out.view(batch, self.out_channel, height, width)
|
340 |
+
|
341 |
+
else:
|
342 |
+
input = input.view(1, batch * in_channel, height, width)
|
343 |
+
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
344 |
+
_, _, height, width = out.shape
|
345 |
+
out = out.view(batch, self.out_channel, height, width)
|
346 |
+
|
347 |
+
return out
|
348 |
+
|
349 |
+
|
350 |
+
class NoiseInjection(nn.Module):
|
351 |
+
def __init__(self):
|
352 |
+
super().__init__()
|
353 |
+
|
354 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
355 |
+
|
356 |
+
def forward(self, image, noise=None):
|
357 |
+
if noise is None:
|
358 |
+
batch, _, height, width = image.shape
|
359 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
360 |
+
|
361 |
+
return image + self.weight * noise
|
362 |
+
|
363 |
+
|
364 |
+
class ConstantInput(nn.Module):
|
365 |
+
def __init__(self, channel, size=4):
|
366 |
+
super().__init__()
|
367 |
+
|
368 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
369 |
+
|
370 |
+
def forward(self, input):
|
371 |
+
batch = input.shape[0]
|
372 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
373 |
+
|
374 |
+
return out
|
375 |
+
|
376 |
+
|
377 |
+
class StyledConv(nn.Module):
|
378 |
+
def __init__(
|
379 |
+
self,
|
380 |
+
in_channel,
|
381 |
+
out_channel,
|
382 |
+
kernel_size,
|
383 |
+
style_dim=None,
|
384 |
+
upsample=False,
|
385 |
+
blur_kernel=[1, 3, 3, 1],
|
386 |
+
demodulate=True,
|
387 |
+
inject_noise=True,
|
388 |
+
):
|
389 |
+
super().__init__()
|
390 |
+
|
391 |
+
self.inject_noise = inject_noise
|
392 |
+
self.conv = ModulatedConv2d(
|
393 |
+
in_channel,
|
394 |
+
out_channel,
|
395 |
+
kernel_size,
|
396 |
+
style_dim,
|
397 |
+
upsample=upsample,
|
398 |
+
blur_kernel=blur_kernel,
|
399 |
+
demodulate=demodulate,
|
400 |
+
)
|
401 |
+
|
402 |
+
self.noise = NoiseInjection()
|
403 |
+
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
404 |
+
# self.activate = ScaledLeakyReLU(0.2)
|
405 |
+
self.activate = FusedLeakyReLU(out_channel)
|
406 |
+
|
407 |
+
def forward(self, input, style=None, noise=None):
|
408 |
+
out = self.conv(input, style)
|
409 |
+
if self.inject_noise:
|
410 |
+
out = self.noise(out, noise=noise)
|
411 |
+
# out = out + self.bias
|
412 |
+
out = self.activate(out)
|
413 |
+
|
414 |
+
return out
|
415 |
+
|
416 |
+
|
417 |
+
class ToRGB(nn.Module):
|
418 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
419 |
+
super().__init__()
|
420 |
+
|
421 |
+
if upsample:
|
422 |
+
self.upsample = Upsample(blur_kernel)
|
423 |
+
|
424 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
425 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
426 |
+
|
427 |
+
def forward(self, input, style, skip=None):
|
428 |
+
out = self.conv(input, style)
|
429 |
+
out = out + self.bias
|
430 |
+
|
431 |
+
if skip is not None:
|
432 |
+
skip = self.upsample(skip)
|
433 |
+
|
434 |
+
out = out + skip
|
435 |
+
|
436 |
+
return out
|
437 |
+
|
438 |
+
|
439 |
+
class Generator(nn.Module):
|
440 |
+
def __init__(
|
441 |
+
self,
|
442 |
+
size,
|
443 |
+
style_dim,
|
444 |
+
n_mlp,
|
445 |
+
channel_multiplier=2,
|
446 |
+
blur_kernel=[1, 3, 3, 1],
|
447 |
+
lr_mlp=0.01,
|
448 |
+
):
|
449 |
+
super().__init__()
|
450 |
+
|
451 |
+
self.size = size
|
452 |
+
|
453 |
+
self.style_dim = style_dim
|
454 |
+
|
455 |
+
layers = [PixelNorm()]
|
456 |
+
|
457 |
+
for i in range(n_mlp):
|
458 |
+
layers.append(
|
459 |
+
EqualLinear(
|
460 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
461 |
+
)
|
462 |
+
)
|
463 |
+
|
464 |
+
self.style = nn.Sequential(*layers)
|
465 |
+
|
466 |
+
self.channels = {
|
467 |
+
4: 512,
|
468 |
+
8: 512,
|
469 |
+
16: 512,
|
470 |
+
32: 512,
|
471 |
+
64: 256 * channel_multiplier,
|
472 |
+
128: 128 * channel_multiplier,
|
473 |
+
256: 64 * channel_multiplier,
|
474 |
+
512: 32 * channel_multiplier,
|
475 |
+
1024: 16 * channel_multiplier,
|
476 |
+
}
|
477 |
+
|
478 |
+
self.input = ConstantInput(self.channels[4])
|
479 |
+
self.conv1 = StyledConv(
|
480 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
481 |
+
)
|
482 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
483 |
+
|
484 |
+
self.log_size = int(math.log(size, 2))
|
485 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
486 |
+
|
487 |
+
self.convs = nn.ModuleList()
|
488 |
+
self.upsamples = nn.ModuleList()
|
489 |
+
self.to_rgbs = nn.ModuleList()
|
490 |
+
self.noises = nn.Module()
|
491 |
+
|
492 |
+
in_channel = self.channels[4]
|
493 |
+
|
494 |
+
for layer_idx in range(self.num_layers):
|
495 |
+
res = (layer_idx + 5) // 2
|
496 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
497 |
+
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
498 |
+
|
499 |
+
for i in range(3, self.log_size + 1):
|
500 |
+
out_channel = self.channels[2 ** i]
|
501 |
+
|
502 |
+
self.convs.append(
|
503 |
+
StyledConv(
|
504 |
+
in_channel,
|
505 |
+
out_channel,
|
506 |
+
3,
|
507 |
+
style_dim,
|
508 |
+
upsample=True,
|
509 |
+
blur_kernel=blur_kernel,
|
510 |
+
)
|
511 |
+
)
|
512 |
+
|
513 |
+
self.convs.append(
|
514 |
+
StyledConv(
|
515 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
516 |
+
)
|
517 |
+
)
|
518 |
+
|
519 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
520 |
+
|
521 |
+
in_channel = out_channel
|
522 |
+
|
523 |
+
self.n_latent = self.log_size * 2 - 2
|
524 |
+
|
525 |
+
def make_noise(self):
|
526 |
+
device = self.input.input.device
|
527 |
+
|
528 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
529 |
+
|
530 |
+
for i in range(3, self.log_size + 1):
|
531 |
+
for _ in range(2):
|
532 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
533 |
+
|
534 |
+
return noises
|
535 |
+
|
536 |
+
def mean_latent(self, n_latent):
|
537 |
+
latent_in = torch.randn(
|
538 |
+
n_latent, self.style_dim, device=self.input.input.device
|
539 |
+
)
|
540 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
541 |
+
|
542 |
+
return latent
|
543 |
+
|
544 |
+
def get_latent(self, input):
|
545 |
+
return self.style(input)
|
546 |
+
|
547 |
+
def forward(
|
548 |
+
self,
|
549 |
+
styles,
|
550 |
+
return_latents=False,
|
551 |
+
inject_index=None,
|
552 |
+
truncation=1,
|
553 |
+
truncation_latent=None,
|
554 |
+
input_is_latent=False,
|
555 |
+
noise=None,
|
556 |
+
randomize_noise=True,
|
557 |
+
):
|
558 |
+
if not input_is_latent:
|
559 |
+
styles = [self.style(s) for s in styles]
|
560 |
+
|
561 |
+
if noise is None:
|
562 |
+
if randomize_noise:
|
563 |
+
noise = [None] * self.num_layers
|
564 |
+
else:
|
565 |
+
noise = [
|
566 |
+
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
567 |
+
]
|
568 |
+
|
569 |
+
if truncation < 1:
|
570 |
+
style_t = []
|
571 |
+
|
572 |
+
for style in styles:
|
573 |
+
style_t.append(
|
574 |
+
truncation_latent + truncation * (style - truncation_latent)
|
575 |
+
)
|
576 |
+
|
577 |
+
styles = style_t
|
578 |
+
|
579 |
+
if len(styles) < 2:
|
580 |
+
inject_index = self.n_latent
|
581 |
+
|
582 |
+
if len(styles[0].shape) < 3:
|
583 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
584 |
+
|
585 |
+
else:
|
586 |
+
latent = styles[0]
|
587 |
+
|
588 |
+
else:
|
589 |
+
if inject_index is None:
|
590 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
591 |
+
|
592 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
593 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
594 |
+
|
595 |
+
latent = torch.cat([latent, latent2], 1)
|
596 |
+
|
597 |
+
out = self.input(latent)
|
598 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
599 |
+
|
600 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
601 |
+
|
602 |
+
i = 1
|
603 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
604 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
605 |
+
):
|
606 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
607 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
608 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
609 |
+
|
610 |
+
i += 2
|
611 |
+
|
612 |
+
image = skip
|
613 |
+
|
614 |
+
if return_latents:
|
615 |
+
return image, latent
|
616 |
+
|
617 |
+
else:
|
618 |
+
return image, None
|
619 |
+
|
620 |
+
|
621 |
+
class ConvLayer(nn.Sequential):
|
622 |
+
def __init__(
|
623 |
+
self,
|
624 |
+
in_channel,
|
625 |
+
out_channel,
|
626 |
+
kernel_size,
|
627 |
+
downsample=False,
|
628 |
+
blur_kernel=[1, 3, 3, 1],
|
629 |
+
bias=True,
|
630 |
+
activate=True,
|
631 |
+
):
|
632 |
+
layers = []
|
633 |
+
|
634 |
+
if downsample:
|
635 |
+
factor = 2
|
636 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
637 |
+
pad0 = (p + 1) // 2
|
638 |
+
pad1 = p // 2
|
639 |
+
|
640 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
641 |
+
|
642 |
+
stride = 2
|
643 |
+
self.padding = 0
|
644 |
+
|
645 |
+
else:
|
646 |
+
stride = 1
|
647 |
+
self.padding = kernel_size // 2
|
648 |
+
|
649 |
+
layers.append(
|
650 |
+
EqualConv2d(
|
651 |
+
in_channel,
|
652 |
+
out_channel,
|
653 |
+
kernel_size,
|
654 |
+
padding=self.padding,
|
655 |
+
stride=stride,
|
656 |
+
bias=bias and not activate,
|
657 |
+
)
|
658 |
+
)
|
659 |
+
|
660 |
+
if activate:
|
661 |
+
if bias:
|
662 |
+
layers.append(FusedLeakyReLU(out_channel))
|
663 |
+
|
664 |
+
else:
|
665 |
+
layers.append(ScaledLeakyReLU(0.2))
|
666 |
+
|
667 |
+
super().__init__(*layers)
|
668 |
+
|
669 |
+
|
670 |
+
class ResBlock(nn.Module):
|
671 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):
|
672 |
+
super().__init__()
|
673 |
+
|
674 |
+
self.skip_gain = skip_gain
|
675 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
676 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)
|
677 |
+
|
678 |
+
if in_channel != out_channel or downsample:
|
679 |
+
self.skip = ConvLayer(
|
680 |
+
in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
|
681 |
+
)
|
682 |
+
else:
|
683 |
+
self.skip = nn.Identity()
|
684 |
+
|
685 |
+
def forward(self, input):
|
686 |
+
out = self.conv1(input)
|
687 |
+
out = self.conv2(out)
|
688 |
+
|
689 |
+
skip = self.skip(input)
|
690 |
+
out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)
|
691 |
+
|
692 |
+
return out
|
693 |
+
|
694 |
+
|
695 |
+
class StyleGAN2Discriminator(nn.Module):
|
696 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):
|
697 |
+
super().__init__()
|
698 |
+
self.opt = opt
|
699 |
+
self.stddev_group = 16
|
700 |
+
if size is None:
|
701 |
+
size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
|
702 |
+
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
|
703 |
+
size = 2 ** int(np.log2(self.opt.D_patch_size))
|
704 |
+
|
705 |
+
blur_kernel = [1, 3, 3, 1]
|
706 |
+
channel_multiplier = ndf / 64
|
707 |
+
channels = {
|
708 |
+
4: min(384, int(4096 * channel_multiplier)),
|
709 |
+
8: min(384, int(2048 * channel_multiplier)),
|
710 |
+
16: min(384, int(1024 * channel_multiplier)),
|
711 |
+
32: min(384, int(512 * channel_multiplier)),
|
712 |
+
64: int(256 * channel_multiplier),
|
713 |
+
128: int(128 * channel_multiplier),
|
714 |
+
256: int(64 * channel_multiplier),
|
715 |
+
512: int(32 * channel_multiplier),
|
716 |
+
1024: int(16 * channel_multiplier),
|
717 |
+
}
|
718 |
+
|
719 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
720 |
+
|
721 |
+
log_size = int(math.log(size, 2))
|
722 |
+
|
723 |
+
in_channel = channels[size]
|
724 |
+
|
725 |
+
if "smallpatch" in self.opt.netD:
|
726 |
+
final_res_log2 = 4
|
727 |
+
elif "patch" in self.opt.netD:
|
728 |
+
final_res_log2 = 3
|
729 |
+
else:
|
730 |
+
final_res_log2 = 2
|
731 |
+
|
732 |
+
for i in range(log_size, final_res_log2, -1):
|
733 |
+
out_channel = channels[2 ** (i - 1)]
|
734 |
+
|
735 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
736 |
+
|
737 |
+
in_channel = out_channel
|
738 |
+
|
739 |
+
self.convs = nn.Sequential(*convs)
|
740 |
+
|
741 |
+
if False and "tile" in self.opt.netD:
|
742 |
+
in_channel += 1
|
743 |
+
self.final_conv = ConvLayer(in_channel, channels[4], 3)
|
744 |
+
if "patch" in self.opt.netD:
|
745 |
+
self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)
|
746 |
+
else:
|
747 |
+
self.final_linear = nn.Sequential(
|
748 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
749 |
+
EqualLinear(channels[4], 1),
|
750 |
+
)
|
751 |
+
|
752 |
+
def forward(self, input, get_minibatch_features=False):
|
753 |
+
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
|
754 |
+
h, w = input.size(2), input.size(3)
|
755 |
+
y = torch.randint(h - self.opt.D_patch_size, ())
|
756 |
+
x = torch.randint(w - self.opt.D_patch_size, ())
|
757 |
+
input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]
|
758 |
+
out = input
|
759 |
+
for i, conv in enumerate(self.convs):
|
760 |
+
out = conv(out)
|
761 |
+
# print(i, out.abs().mean())
|
762 |
+
# out = self.convs(input)
|
763 |
+
|
764 |
+
batch, channel, height, width = out.shape
|
765 |
+
|
766 |
+
if False and "tile" in self.opt.netD:
|
767 |
+
group = min(batch, self.stddev_group)
|
768 |
+
stddev = out.view(
|
769 |
+
group, -1, 1, channel // 1, height, width
|
770 |
+
)
|
771 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
772 |
+
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
|
773 |
+
stddev = stddev.repeat(group, 1, height, width)
|
774 |
+
out = torch.cat([out, stddev], 1)
|
775 |
+
|
776 |
+
out = self.final_conv(out)
|
777 |
+
# print(out.abs().mean())
|
778 |
+
|
779 |
+
if "patch" not in self.opt.netD:
|
780 |
+
out = out.view(batch, -1)
|
781 |
+
out = self.final_linear(out)
|
782 |
+
|
783 |
+
return out
|
784 |
+
|
785 |
+
|
786 |
+
class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
|
787 |
+
def forward(self, input):
|
788 |
+
B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
|
789 |
+
size = self.opt.D_patch_size
|
790 |
+
Y = H // size
|
791 |
+
X = W // size
|
792 |
+
input = input.view(B, C, Y, size, X, size)
|
793 |
+
input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
|
794 |
+
return super().forward(input)
|
795 |
+
|
796 |
+
|
797 |
+
class StyleGAN2Encoder(nn.Module):
|
798 |
+
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
799 |
+
super().__init__()
|
800 |
+
assert opt is not None
|
801 |
+
self.opt = opt
|
802 |
+
channel_multiplier = ngf / 32
|
803 |
+
channels = {
|
804 |
+
4: min(512, int(round(4096 * channel_multiplier))),
|
805 |
+
8: min(512, int(round(2048 * channel_multiplier))),
|
806 |
+
16: min(512, int(round(1024 * channel_multiplier))),
|
807 |
+
32: min(512, int(round(512 * channel_multiplier))),
|
808 |
+
64: int(round(256 * channel_multiplier)),
|
809 |
+
128: int(round(128 * channel_multiplier)),
|
810 |
+
256: int(round(64 * channel_multiplier)),
|
811 |
+
512: int(round(32 * channel_multiplier)),
|
812 |
+
1024: int(round(16 * channel_multiplier)),
|
813 |
+
}
|
814 |
+
|
815 |
+
blur_kernel = [1, 3, 3, 1]
|
816 |
+
|
817 |
+
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
|
818 |
+
convs = [nn.Identity(),
|
819 |
+
ConvLayer(3, channels[cur_res], 1)]
|
820 |
+
|
821 |
+
num_downsampling = self.opt.stylegan2_G_num_downsampling
|
822 |
+
for i in range(num_downsampling):
|
823 |
+
in_channel = channels[cur_res]
|
824 |
+
out_channel = channels[cur_res // 2]
|
825 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))
|
826 |
+
cur_res = cur_res // 2
|
827 |
+
|
828 |
+
for i in range(n_blocks // 2):
|
829 |
+
n_channel = channels[cur_res]
|
830 |
+
convs.append(ResBlock(n_channel, n_channel, downsample=False))
|
831 |
+
|
832 |
+
self.convs = nn.Sequential(*convs)
|
833 |
+
|
834 |
+
def forward(self, input, layers=[], get_features=False):
|
835 |
+
feat = input
|
836 |
+
feats = []
|
837 |
+
if -1 in layers:
|
838 |
+
layers.append(len(self.convs) - 1)
|
839 |
+
for layer_id, layer in enumerate(self.convs):
|
840 |
+
feat = layer(feat)
|
841 |
+
# print(layer_id, " features ", feat.abs().mean())
|
842 |
+
if layer_id in layers:
|
843 |
+
feats.append(feat)
|
844 |
+
|
845 |
+
if get_features:
|
846 |
+
return feat, feats
|
847 |
+
else:
|
848 |
+
return feat
|
849 |
+
|
850 |
+
|
851 |
+
class StyleGAN2Decoder(nn.Module):
|
852 |
+
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
853 |
+
super().__init__()
|
854 |
+
assert opt is not None
|
855 |
+
self.opt = opt
|
856 |
+
|
857 |
+
blur_kernel = [1, 3, 3, 1]
|
858 |
+
|
859 |
+
channel_multiplier = ngf / 32
|
860 |
+
channels = {
|
861 |
+
4: min(512, int(round(4096 * channel_multiplier))),
|
862 |
+
8: min(512, int(round(2048 * channel_multiplier))),
|
863 |
+
16: min(512, int(round(1024 * channel_multiplier))),
|
864 |
+
32: min(512, int(round(512 * channel_multiplier))),
|
865 |
+
64: int(round(256 * channel_multiplier)),
|
866 |
+
128: int(round(128 * channel_multiplier)),
|
867 |
+
256: int(round(64 * channel_multiplier)),
|
868 |
+
512: int(round(32 * channel_multiplier)),
|
869 |
+
1024: int(round(16 * channel_multiplier)),
|
870 |
+
}
|
871 |
+
|
872 |
+
num_downsampling = self.opt.stylegan2_G_num_downsampling
|
873 |
+
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)
|
874 |
+
convs = []
|
875 |
+
|
876 |
+
for i in range(n_blocks // 2):
|
877 |
+
n_channel = channels[cur_res]
|
878 |
+
convs.append(ResBlock(n_channel, n_channel, downsample=False))
|
879 |
+
|
880 |
+
for i in range(num_downsampling):
|
881 |
+
in_channel = channels[cur_res]
|
882 |
+
out_channel = channels[cur_res * 2]
|
883 |
+
inject_noise = "small" not in self.opt.netG
|
884 |
+
convs.append(
|
885 |
+
StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)
|
886 |
+
)
|
887 |
+
cur_res = cur_res * 2
|
888 |
+
|
889 |
+
convs.append(ConvLayer(channels[cur_res], 3, 1))
|
890 |
+
|
891 |
+
self.convs = nn.Sequential(*convs)
|
892 |
+
|
893 |
+
def forward(self, input):
|
894 |
+
return self.convs(input)
|
895 |
+
|
896 |
+
|
897 |
+
class StyleGAN2Generator(nn.Module):
|
898 |
+
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
899 |
+
super().__init__()
|
900 |
+
self.opt = opt
|
901 |
+
self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
|
902 |
+
self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
|
903 |
+
|
904 |
+
def forward(self, input, layers=[], encode_only=False):
|
905 |
+
feat, feats = self.encoder(input, layers, True)
|
906 |
+
if encode_only:
|
907 |
+
return feats
|
908 |
+
else:
|
909 |
+
fake = self.decoder(feat)
|
910 |
+
|
911 |
+
if len(layers) > 0:
|
912 |
+
return fake, feats
|
913 |
+
else:
|
914 |
+
return fake
|
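The encoder/decoder split above is what lets this generator double as a feature extractor. Below is a hedged sketch of both call paths, assuming the module is importable as shown and using a throwaway `Namespace` in place of the real option object (the field names mirror the flags defined in `base_options.py` further below):

```python
# Hedged sketch: exercise StyleGAN2Generator's two forward modes on random data.
from argparse import Namespace

import torch

from Scenimefy.models.stylegan_networks import StyleGAN2Generator

# Minimal stand-in for the parsed options (assumed fields only).
opt = Namespace(load_size=256, crop_size=256, stylegan2_G_num_downsampling=1, netG='stylegan2')
G = StyleGAN2Generator(input_nc=3, output_nc=3, ngf=64, n_blocks=6, opt=opt)

x = torch.rand(1, 3, 256, 256) * 2 - 1              # images are assumed to live in [-1, 1]
fake = G(x)                                          # full encode -> decode translation
feats = G(x, layers=[0, 2, 4], encode_only=True)     # encoder features only, e.g. for patch-wise losses
print(fake.shape, [f.shape for f in feats])
```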
Scenimefy/options/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""
This package includes the option modules: training options, test options, and basic options (used in both training and test).
"""
Scenimefy/options/base_options.py
ADDED
@@ -0,0 +1,165 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import Scenimefy.models as models
|
5 |
+
import Scenimefy.data as data
|
6 |
+
from Scenimefy.utils import util
|
7 |
+
|
8 |
+
class BaseOptions():
|
9 |
+
"""
|
10 |
+
This class defines options used during both training and test time.
|
11 |
+
|
12 |
+
It also implements several helper functions such as parsing, printing, and saving the options.
|
13 |
+
It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, cmd_line=None):
|
17 |
+
"""Reset the class; indicates the class hasn't been initailized"""
|
18 |
+
self.initialized = False
|
19 |
+
self.cmd_line = None
|
20 |
+
if cmd_line is not None:
|
21 |
+
self.cmd_line = cmd_line.split()
|
22 |
+
|
23 |
+
def initialize(self, parser):
|
24 |
+
"""Define the common options that are used in both training and test."""
|
25 |
+
# basic parameters
|
26 |
+
# load unpaired dataset
|
27 |
+
parser.add_argument('--dataroot', default='Scenimefy/datasets/Sample', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
|
28 |
+
parser.add_argument('--name', type=str, default='huggingface', help='name of the experiment. It decides where to store samples and models')
|
29 |
+
parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name')
|
30 |
+
parser.add_argument('--gpu_ids', type=str, default='-1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
31 |
+
parser.add_argument('--checkpoints_dir', type=str, default='Scenimefy/pretrained_models', help='models are saved here')
|
32 |
+
# model parameters
|
33 |
+
parser.add_argument('--model', type=str, default='cut', help='chooses which model to use.')
|
34 |
+
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
|
35 |
+
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
|
36 |
+
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
37 |
+
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
38 |
+
parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
39 |
+
parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture')
|
40 |
+
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
41 |
+
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
42 |
+
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
43 |
+
parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
|
44 |
+
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
|
45 |
+
parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
|
46 |
+
help='no dropout for the generator')
|
47 |
+
parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)')
|
48 |
+
parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]')
|
49 |
+
# dataset parameters
|
50 |
+
parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
|
51 |
+
parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
|
52 |
+
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
|
53 |
+
parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data')
|
54 |
+
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
|
55 |
+
parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
|
56 |
+
parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
|
57 |
+
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
|
58 |
+
parser.add_argument('--preprocess', type=str, default='none', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
|
59 |
+
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
|
60 |
+
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
|
61 |
+
parser.add_argument('--random_scale_max', type=float, default=3.0,
|
62 |
+
help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.')
|
63 |
+
# additional parameters
|
64 |
+
parser.add_argument('--epoch', type=str, default='Shinkai', help='which epoch to load? set to latest to use latest cached model')
|
65 |
+
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
|
66 |
+
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
|
67 |
+
|
68 |
+
# parameters related to StyleGAN2-based networks
|
69 |
+
parser.add_argument('--stylegan2_G_num_downsampling',
|
70 |
+
default=1, type=int,
|
71 |
+
help='Number of downsampling layers used by StyleGAN2Generator')
|
72 |
+
|
73 |
+
self.initialized = True
|
74 |
+
return parser
|
75 |
+
|
76 |
+
def gather_options(self):
|
77 |
+
"""Initialize our parser with basic options(only once).
|
78 |
+
Add additional model-specific and dataset-specific options.
|
79 |
+
These options are defined in the <modify_commandline_options> function
|
80 |
+
in model and dataset classes.
|
81 |
+
"""
|
82 |
+
if not self.initialized: # check if it has been initialized
|
83 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
84 |
+
parser = self.initialize(parser)
|
85 |
+
|
86 |
+
# get the basic options
|
87 |
+
if self.cmd_line is None:
|
88 |
+
opt, _ = parser.parse_known_args()
|
89 |
+
else:
|
90 |
+
opt, _ = parser.parse_known_args(self.cmd_line)
|
91 |
+
|
92 |
+
# modify model-related parser options
|
93 |
+
model_name = opt.model
|
94 |
+
model_option_setter = models.get_option_setter(model_name)
|
95 |
+
parser = model_option_setter(parser, self.isTrain)
|
96 |
+
if self.cmd_line is None:
|
97 |
+
opt, _ = parser.parse_known_args() # parse again with new defaults
|
98 |
+
else:
|
99 |
+
opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
|
100 |
+
|
101 |
+
# modify dataset-related parser options
|
102 |
+
dataset_name = opt.dataset_mode
|
103 |
+
dataset_option_setter = data.get_option_setter(dataset_name)
|
104 |
+
parser = dataset_option_setter(parser, self.isTrain)
|
105 |
+
|
106 |
+
# save and return the parser
|
107 |
+
self.parser = parser
|
108 |
+
if self.cmd_line is None:
|
109 |
+
return parser.parse_args()
|
110 |
+
else:
|
111 |
+
return parser.parse_args(self.cmd_line)
|
112 |
+
|
113 |
+
def print_options(self, opt):
|
114 |
+
"""Print and save options
|
115 |
+
|
116 |
+
It will print both current options and default values(if different).
|
117 |
+
It will save options into a text file / [checkpoints_dir] / opt.txt
|
118 |
+
"""
|
119 |
+
message = ''
|
120 |
+
message += '----------------- Options ---------------\n'
|
121 |
+
for k, v in sorted(vars(opt).items()):
|
122 |
+
comment = ''
|
123 |
+
default = self.parser.get_default(k)
|
124 |
+
if v != default:
|
125 |
+
comment = '\t[default: %s]' % str(default)
|
126 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
127 |
+
message += '----------------- End -------------------'
|
128 |
+
print(message)
|
129 |
+
|
130 |
+
# save to the disk
|
131 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
132 |
+
util.mkdirs(expr_dir)
|
133 |
+
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
|
134 |
+
try:
|
135 |
+
with open(file_name, 'wt') as opt_file:
|
136 |
+
opt_file.write(message)
|
137 |
+
opt_file.write('\n')
|
138 |
+
except PermissionError as error:
|
139 |
+
print("permission error {}".format(error))
|
140 |
+
pass
|
141 |
+
|
142 |
+
def parse(self):
|
143 |
+
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
|
144 |
+
opt = self.gather_options()
|
145 |
+
opt.isTrain = self.isTrain # train or test
|
146 |
+
|
147 |
+
# process opt.suffix
|
148 |
+
if opt.suffix:
|
149 |
+
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
|
150 |
+
opt.name = opt.name + suffix
|
151 |
+
|
152 |
+
self.print_options(opt)
|
153 |
+
|
154 |
+
# set gpu ids
|
155 |
+
str_ids = opt.gpu_ids.split(',')
|
156 |
+
opt.gpu_ids = []
|
157 |
+
for str_id in str_ids:
|
158 |
+
id = int(str_id)
|
159 |
+
if id >= 0:
|
160 |
+
opt.gpu_ids.append(id)
|
161 |
+
if len(opt.gpu_ids) > 0:
|
162 |
+
torch.cuda.set_device(opt.gpu_ids[0])
|
163 |
+
|
164 |
+
self.opt = opt
|
165 |
+
return self.opt
|
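One detail in `parse()` that is easy to miss is the `--suffix` template: it is formatted with the parsed option values before being appended to `opt.name`. A small worked example with the default values listed above (plain Python, no repo imports needed; the template string is the example from the `--suffix` help text):

```python
# Hedged sketch of the --suffix expansion performed in BaseOptions.parse().
defaults = {'model': 'cut', 'netG': 'resnet_9blocks', 'load_size': 256}
suffix_template = '{model}_{netG}_size{load_size}'   # example value from the --suffix help text

suffix = '_' + suffix_template.format(**defaults)
print('huggingface' + suffix)   # -> huggingface_cut_resnet_9blocks_size256
```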
Scenimefy/options/test_options.py
ADDED
@@ -0,0 +1,22 @@
from Scenimefy.options.base_options import BaseOptions


class TestOptions(BaseOptions):
    """
    This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--results_dir', type=str, default='Scenimefy/results/', help='saves results here.')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and BatchNorm have different behaviour during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=1000, help='how many test images to run')

        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser
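A hedged sketch of how these test options can be obtained outside the Gradio demo; the empty `cmd_line` keeps the parser from reading unrelated command-line arguments, and the post-parse overrides mirror what `app.py` does further below:

```python
# Hedged sketch: build test-time options programmatically with the defaults above.
from Scenimefy.options.test_options import TestOptions

opt = TestOptions(cmd_line='').parse()   # isTrain=False, phase='test', epoch='Shinkai', gpu_ids=[]
opt.num_threads = 0                      # the demo pins these after parsing
opt.batch_size = 1
opt.serial_batches = True
opt.no_flip = True
print(opt.checkpoints_dir, opt.name, opt.netG)
```

Note that `parse()` also prints the option table and writes it to `<checkpoints_dir>/<name>/test_opt.txt`, which is exactly the file reproduced two sections below.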
Scenimefy/pretrained_models/huggingface/Shinkai_net_G.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3bdeced133287fbb95832f4342aaff399f6f73507b515118918cc27ccd98ad8c
size 45570633
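The weights themselves are stored through Git LFS; only the pointer above lives in the repo. The file name is not arbitrary: with `checkpoints_dir`, `name`, and `epoch` from the options, the loader (in `base_model.py`, not shown here) appears to resolve checkpoints as `<checkpoints_dir>/<name>/<epoch>_net_<net name>.pth`, which is exactly where this pointer sits. A small sketch of that path assembly (the pattern is inferred from the file name, so treat it as an assumption):

```python
# Hedged sketch: reconstruct the expected generator checkpoint path from the test options.
import os

checkpoints_dir = 'Scenimefy/pretrained_models'   # --checkpoints_dir default
name = 'huggingface'                              # --name default
epoch = 'Shinkai'                                 # --epoch default
net_name = 'G'                                    # generator weights

path = os.path.join(checkpoints_dir, name, '%s_net_%s.pth' % (epoch, net_name))
print(path)   # Scenimefy/pretrained_models/huggingface/Shinkai_net_G.pth
```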
Scenimefy/pretrained_models/huggingface/test_opt.txt
ADDED
@@ -0,0 +1,63 @@
1 |
+
----------------- Options ---------------
|
2 |
+
CUT_mode: CUT
|
3 |
+
HDCE_gamma: 1
|
4 |
+
HDCE_gamma_min: 1
|
5 |
+
alpha: 0.2
|
6 |
+
batch_size: 1
|
7 |
+
checkpoints_dir: Scenimefy/pretrained_models
|
8 |
+
crop_size: 256
|
9 |
+
dataroot: Scenimefy\datasets\Sample
|
10 |
+
dataset_mode: unaligned
|
11 |
+
dce_idt: False
|
12 |
+
direction: AtoB
|
13 |
+
display_winsize: 256
|
14 |
+
easy_label: experiment_name
|
15 |
+
epoch: Shinkai
|
16 |
+
eval: False
|
17 |
+
flip_equivariance: False
|
18 |
+
gpu_ids: -1
|
19 |
+
init_gain: 0.02
|
20 |
+
init_type: xavier
|
21 |
+
input_nc: 3
|
22 |
+
isTrain: False [default: None]
|
23 |
+
lambda_GAN: 1.0
|
24 |
+
lambda_HDCE: 1.0
|
25 |
+
lambda_SRC: 1.0
|
26 |
+
load_size: 256
|
27 |
+
max_dataset_size: inf
|
28 |
+
model: cut
|
29 |
+
n_layers_D: 3
|
30 |
+
name: huggingface
|
31 |
+
nce_T: 0.07
|
32 |
+
nce_includes_all_negatives_from_minibatch: False
|
33 |
+
nce_layers: 0,4,8,12,16
|
34 |
+
ndf: 64
|
35 |
+
netD: basic
|
36 |
+
netF: mlp_sample
|
37 |
+
netF_nc: 256
|
38 |
+
netG: resnet_9blocks
|
39 |
+
ngf: 64
|
40 |
+
no_Hneg: False
|
41 |
+
no_antialias: False
|
42 |
+
no_antialias_up: False
|
43 |
+
no_dropout: True
|
44 |
+
no_flip: False
|
45 |
+
normD: instance
|
46 |
+
normG: instance
|
47 |
+
num_patches: 256
|
48 |
+
num_test: 1000
|
49 |
+
num_threads: 0
|
50 |
+
output_nc: 3
|
51 |
+
phase: test
|
52 |
+
pool_size: 0
|
53 |
+
preprocess: none
|
54 |
+
random_scale_max: 3.0
|
55 |
+
results_dir: Scenimefy/results/
|
56 |
+
serial_batches: False
|
57 |
+
step_gamma: False
|
58 |
+
step_gamma_epoch: 200
|
59 |
+
stylegan2_G_num_downsampling: 1
|
60 |
+
suffix:
|
61 |
+
use_curriculum: False
|
62 |
+
verbose: False
|
63 |
+
----------------- End -------------------
|
Scenimefy/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
"""
This package includes a miscellaneous collection of useful helper functions.
"""
from Scenimefy.utils import *
Scenimefy/utils/html.py
ADDED
@@ -0,0 +1,86 @@
1 |
+
import dominate
|
2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class HTML:
|
7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
8 |
+
|
9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, web_dir, title, refresh=0):
|
15 |
+
"""Initialize the HTML classes
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
|
19 |
+
title (str) -- the webpage name
|
20 |
+
refresh (int) -- how often the website refreshes itself; if 0, no refreshing
|
21 |
+
"""
|
22 |
+
self.title = title
|
23 |
+
self.web_dir = web_dir
|
24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
25 |
+
if not os.path.exists(self.web_dir):
|
26 |
+
os.makedirs(self.web_dir)
|
27 |
+
if not os.path.exists(self.img_dir):
|
28 |
+
os.makedirs(self.img_dir)
|
29 |
+
|
30 |
+
self.doc = dominate.document(title=title)
|
31 |
+
if refresh > 0:
|
32 |
+
with self.doc.head:
|
33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
34 |
+
|
35 |
+
def get_image_dir(self):
|
36 |
+
"""Return the directory that stores images"""
|
37 |
+
return self.img_dir
|
38 |
+
|
39 |
+
def add_header(self, text):
|
40 |
+
"""Insert a header to the HTML file
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
text (str) -- the header text
|
44 |
+
"""
|
45 |
+
with self.doc:
|
46 |
+
h3(text)
|
47 |
+
|
48 |
+
def add_images(self, ims, txts, links, width=400):
|
49 |
+
"""add images to the HTML file
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
ims (str list) -- a list of image paths
|
53 |
+
txts (str list) -- a list of image names shown on the website
|
54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
55 |
+
"""
|
56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
57 |
+
self.doc.add(self.t)
|
58 |
+
with self.t:
|
59 |
+
with tr():
|
60 |
+
for im, txt, link in zip(ims, txts, links):
|
61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
62 |
+
with p():
|
63 |
+
with a(href=os.path.join('images', link)):
|
64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
65 |
+
br()
|
66 |
+
p(txt)
|
67 |
+
|
68 |
+
def save(self):
|
69 |
+
"""save the current content to the HMTL file"""
|
70 |
+
html_file = '%s/index.html' % self.web_dir
|
71 |
+
f = open(html_file, 'wt')
|
72 |
+
f.write(self.doc.render())
|
73 |
+
f.close()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__': # we show an example usage here.
|
77 |
+
html = HTML('web/', 'test_html')
|
78 |
+
html.add_header('hello world')
|
79 |
+
|
80 |
+
ims, txts, links = [], [], []
|
81 |
+
for n in range(4):
|
82 |
+
ims.append('image_%d.png' % n)
|
83 |
+
txts.append('text_%d' % n)
|
84 |
+
links.append('image_%d.png' % n)
|
85 |
+
html.add_images(ims, txts, links)
|
86 |
+
html.save()
|
Scenimefy/utils/util.py
ADDED
@@ -0,0 +1,168 @@
1 |
+
"""This module contains simple helper functions """
|
2 |
+
from __future__ import print_function
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
import importlib
|
8 |
+
import argparse
|
9 |
+
from argparse import Namespace
|
10 |
+
import torchvision
|
11 |
+
|
12 |
+
|
13 |
+
def str2bool(v):
|
14 |
+
if isinstance(v, bool):
|
15 |
+
return v
|
16 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
17 |
+
return True
|
18 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
19 |
+
return False
|
20 |
+
else:
|
21 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
22 |
+
|
23 |
+
|
24 |
+
def copyconf(default_opt, **kwargs):
|
25 |
+
conf = Namespace(**vars(default_opt))
|
26 |
+
for key in kwargs:
|
27 |
+
setattr(conf, key, kwargs[key])
|
28 |
+
return conf
|
29 |
+
|
30 |
+
|
31 |
+
def find_class_in_module(target_cls_name, module):
|
32 |
+
target_cls_name = target_cls_name.replace('_', '').lower()
|
33 |
+
clslib = importlib.import_module(module)
|
34 |
+
cls = None
|
35 |
+
for name, clsobj in clslib.__dict__.items():
|
36 |
+
if name.lower() == target_cls_name:
|
37 |
+
cls = clsobj
|
38 |
+
|
39 |
+
assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
|
40 |
+
|
41 |
+
return cls
|
42 |
+
|
43 |
+
|
44 |
+
def tensor2im(input_image, imtype=np.uint8):
|
45 |
+
""""Converts a Tensor array into a numpy image array.
|
46 |
+
|
47 |
+
Parameters:
|
48 |
+
input_image (tensor) -- the input image tensor array
|
49 |
+
imtype (type) -- the desired type of the converted numpy array
|
50 |
+
"""
|
51 |
+
if not isinstance(input_image, np.ndarray):
|
52 |
+
if isinstance(input_image, torch.Tensor): # get the data from a variable
|
53 |
+
image_tensor = input_image.data
|
54 |
+
else:
|
55 |
+
return input_image
|
56 |
+
image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array
|
57 |
+
if image_numpy.shape[0] == 1: # grayscale to RGB
|
58 |
+
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
59 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
|
60 |
+
else: # if it is a numpy array, do nothing
|
61 |
+
image_numpy = input_image
|
62 |
+
return image_numpy.astype(imtype)
|
63 |
+
|
64 |
+
|
65 |
+
def diagnose_network(net, name='network'):
|
66 |
+
"""Calculate and print the mean of average absolute(gradients)
|
67 |
+
|
68 |
+
Parameters:
|
69 |
+
net (torch network) -- Torch network
|
70 |
+
name (str) -- the name of the network
|
71 |
+
"""
|
72 |
+
mean = 0.0
|
73 |
+
count = 0
|
74 |
+
for param in net.parameters():
|
75 |
+
if param.grad is not None:
|
76 |
+
mean += torch.mean(torch.abs(param.grad.data))
|
77 |
+
count += 1
|
78 |
+
if count > 0:
|
79 |
+
mean = mean / count
|
80 |
+
print(name)
|
81 |
+
print(mean)
|
82 |
+
|
83 |
+
|
84 |
+
def save_image(image_numpy, image_path, aspect_ratio=1.0):
|
85 |
+
"""Save a numpy image to the disk
|
86 |
+
|
87 |
+
Parameters:
|
88 |
+
image_numpy (numpy array) -- input numpy array
|
89 |
+
image_path (str) -- the path of the image
|
90 |
+
"""
|
91 |
+
|
92 |
+
image_pil = Image.fromarray(image_numpy)
|
93 |
+
h, w, _ = image_numpy.shape
|
94 |
+
|
95 |
+
if aspect_ratio is None:
|
96 |
+
pass
|
97 |
+
elif aspect_ratio > 1.0:
|
98 |
+
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
|
99 |
+
elif aspect_ratio < 1.0:
|
100 |
+
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
|
101 |
+
# TODO: TEST
|
102 |
+
# print(image_path)
|
103 |
+
image_pil.save(image_path)
|
104 |
+
|
105 |
+
|
106 |
+
def print_numpy(x, val=True, shp=False):
|
107 |
+
"""Print the mean, min, max, median, std, and size of a numpy array
|
108 |
+
|
109 |
+
Parameters:
|
110 |
+
val (bool) -- if print the values of the numpy array
|
111 |
+
shp (bool) -- if print the shape of the numpy array
|
112 |
+
"""
|
113 |
+
x = x.astype(np.float64)
|
114 |
+
if shp:
|
115 |
+
print('shape,', x.shape)
|
116 |
+
if val:
|
117 |
+
x = x.flatten()
|
118 |
+
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
119 |
+
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
120 |
+
|
121 |
+
|
122 |
+
def mkdirs(paths):
|
123 |
+
"""create empty directories if they don't exist
|
124 |
+
|
125 |
+
Parameters:
|
126 |
+
paths (str list) -- a list of directory paths
|
127 |
+
"""
|
128 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
129 |
+
for path in paths:
|
130 |
+
mkdir(path)
|
131 |
+
else:
|
132 |
+
mkdir(paths)
|
133 |
+
|
134 |
+
|
135 |
+
def mkdir(path):
|
136 |
+
"""create a single empty directory if it didn't exist
|
137 |
+
|
138 |
+
Parameters:
|
139 |
+
path (str) -- a single directory path
|
140 |
+
"""
|
141 |
+
if not os.path.exists(path):
|
142 |
+
os.makedirs(path)
|
143 |
+
|
144 |
+
|
145 |
+
def correct_resize_label(t, size):
|
146 |
+
device = t.device
|
147 |
+
t = t.detach().cpu()
|
148 |
+
resized = []
|
149 |
+
for i in range(t.size(0)):
|
150 |
+
one_t = t[i, :1]
|
151 |
+
one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
|
152 |
+
one_np = one_np[:, :, 0]
|
153 |
+
one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
|
154 |
+
resized_t = torch.from_numpy(np.array(one_image)).long()
|
155 |
+
resized.append(resized_t)
|
156 |
+
return torch.stack(resized, dim=0).to(device)
|
157 |
+
|
158 |
+
|
159 |
+
def correct_resize(t, size, mode=Image.BICUBIC):
|
160 |
+
device = t.device
|
161 |
+
t = t.detach().cpu()
|
162 |
+
resized = []
|
163 |
+
for i in range(t.size(0)):
|
164 |
+
one_t = t[i:i + 1]
|
165 |
+
one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
|
166 |
+
resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
|
167 |
+
resized.append(resized_t)
|
168 |
+
return torch.stack(resized, dim=0).to(device)
|
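The `tensor2im` / `save_image` pair is the conversion path every result takes before it reaches disk or the demo UI. A minimal round-trip sketch (the output path is illustrative):

```python
# Hedged sketch: convert a fake generator output from [-1, 1] NCHW to a saved PNG.
import torch

from Scenimefy.utils.util import mkdir, save_image, tensor2im

fake_B = torch.rand(1, 3, 256, 256) * 2 - 1        # stand-in for a generator output
img = tensor2im(fake_B)                            # uint8 numpy array of shape (256, 256, 3)

mkdir('Scenimefy/results/sketch')                  # creates the folder if missing
save_image(img, 'Scenimefy/results/sketch/fake_B.png')
```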
Scenimefy/utils/visualizer.py
ADDED
@@ -0,0 +1,246 @@
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import ntpath
|
5 |
+
import time
|
6 |
+
from . import util, html
|
7 |
+
from subprocess import Popen, PIPE
|
8 |
+
import math
|
9 |
+
|
10 |
+
if sys.version_info[0] == 2:
|
11 |
+
VisdomExceptionBase = Exception
|
12 |
+
else:
|
13 |
+
VisdomExceptionBase = ConnectionError
|
14 |
+
|
15 |
+
|
16 |
+
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
|
17 |
+
"""Save images to the disk.
|
18 |
+
|
19 |
+
Parameters:
|
20 |
+
webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
|
21 |
+
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
|
22 |
+
image_path (str) -- the string is used to create image paths
|
23 |
+
aspect_ratio (float) -- the aspect ratio of saved images
|
24 |
+
width (int) -- the images will be resized to width x width
|
25 |
+
|
26 |
+
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
|
27 |
+
"""
|
28 |
+
image_dir = webpage.get_image_dir()
|
29 |
+
short_path = ntpath.basename(image_path[0])
|
30 |
+
name = os.path.splitext(short_path)[0]
|
31 |
+
|
32 |
+
webpage.add_header(name)
|
33 |
+
ims, txts, links = [], [], []
|
34 |
+
|
35 |
+
for label, im_data in visuals.items():
|
36 |
+
im = util.tensor2im(im_data)
|
37 |
+
image_name = '%s/%s.png' % (label, name)
|
38 |
+
os.makedirs(os.path.join(image_dir, label), exist_ok=True)
|
39 |
+
save_path = os.path.join(image_dir, image_name)
|
40 |
+
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
|
41 |
+
ims.append(image_name)
|
42 |
+
txts.append(label)
|
43 |
+
links.append(image_name)
|
44 |
+
webpage.add_images(ims, txts, links, width=width)
|
45 |
+
|
46 |
+
|
47 |
+
class Visualizer():
|
48 |
+
"""This class includes several functions that can display/save images and print/save logging information.
|
49 |
+
|
50 |
+
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, opt):
|
54 |
+
"""Initialize the Visualizer class
|
55 |
+
|
56 |
+
Parameters:
|
57 |
+
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
58 |
+
Step 1: Cache the training/test options
|
59 |
+
Step 2: connect to a visdom server
|
60 |
+
Step 3: create an HTML object for saving HTML files
|
61 |
+
Step 4: create a logging file to store training losses
|
62 |
+
"""
|
63 |
+
self.opt = opt # cache the option
|
64 |
+
if opt.display_id is None:
|
65 |
+
self.display_id = np.random.randint(100000) * 10 # just a random display id
|
66 |
+
else:
|
67 |
+
self.display_id = opt.display_id
|
68 |
+
self.use_html = opt.isTrain and not opt.no_html
|
69 |
+
self.win_size = opt.display_winsize
|
70 |
+
self.name = opt.name
|
71 |
+
self.port = opt.display_port
|
72 |
+
self.saved = False
|
73 |
+
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
|
74 |
+
import visdom
|
75 |
+
self.plot_data = {}
|
76 |
+
self.ncols = opt.display_ncols
|
77 |
+
if "tensorboard_base_url" not in os.environ:
|
78 |
+
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
|
79 |
+
else:
|
80 |
+
self.vis = visdom.Visdom(port=2004,
|
81 |
+
base_url=os.environ['tensorboard_base_url'] + '/visdom')
|
82 |
+
if not self.vis.check_connection():
|
83 |
+
self.create_visdom_connections()
|
84 |
+
|
85 |
+
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
|
86 |
+
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
87 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
88 |
+
print('create web directory %s...' % self.web_dir)
|
89 |
+
util.mkdirs([self.web_dir, self.img_dir])
|
90 |
+
# create a logging file to store training losses
|
91 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
92 |
+
with open(self.log_name, "a") as log_file:
|
93 |
+
now = time.strftime("%c")
|
94 |
+
log_file.write('================ Training Loss (%s) ================\n' % now)
|
95 |
+
|
96 |
+
def reset(self):
|
97 |
+
"""Reset the self.saved status"""
|
98 |
+
self.saved = False
|
99 |
+
|
100 |
+
def create_visdom_connections(self):
|
101 |
+
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
|
102 |
+
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
|
103 |
+
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
|
104 |
+
print('Command: %s' % cmd)
|
105 |
+
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
|
106 |
+
|
107 |
+
def display_current_results(self, visuals, epoch, save_result):
|
108 |
+
"""Display current results on visdom; save current results to an HTML file.
|
109 |
+
|
110 |
+
Parameters:
|
111 |
+
visuals (OrderedDict) - - dictionary of images to display or save
|
112 |
+
epoch (int) - - the current epoch
|
113 |
+
save_result (bool) - - if save the current results to an HTML file
|
114 |
+
"""
|
115 |
+
if self.display_id > 0: # show images in the browser using visdom
|
116 |
+
ncols = self.ncols
|
117 |
+
if ncols > 0: # show all the images in one visdom panel
|
118 |
+
ncols = min(ncols, len(visuals))
|
119 |
+
h, w = next(iter(visuals.values())).shape[:2]
|
120 |
+
table_css = """<style>
|
121 |
+
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
|
122 |
+
table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
|
123 |
+
</style>""" % (w, h) # create a table css
|
124 |
+
# create a table of images.
|
125 |
+
title = self.name
|
126 |
+
label_html = ''
|
127 |
+
label_html_row = ''
|
128 |
+
images = []
|
129 |
+
idx = 0
|
130 |
+
for label, image in visuals.items():
|
131 |
+
image_numpy = util.tensor2im(image)
|
132 |
+
label_html_row += '<td>%s</td>' % label
|
133 |
+
images.append(image_numpy.transpose([2, 0, 1]))
|
134 |
+
idx += 1
|
135 |
+
if idx % ncols == 0:
|
136 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
137 |
+
label_html_row = ''
|
138 |
+
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
|
139 |
+
while idx % ncols != 0:
|
140 |
+
images.append(white_image)
|
141 |
+
label_html_row += '<td></td>'
|
142 |
+
idx += 1
|
143 |
+
if label_html_row != '':
|
144 |
+
label_html += '<tr>%s</tr>' % label_html_row
|
145 |
+
try:
|
146 |
+
self.vis.images(images, ncols, 2, self.display_id + 1,
|
147 |
+
None, dict(title=title + ' images'))
|
148 |
+
label_html = '<table>%s</table>' % label_html
|
149 |
+
self.vis.text(table_css + label_html, win=self.display_id + 2,
|
150 |
+
opts=dict(title=title + ' labels'))
|
151 |
+
except VisdomExceptionBase:
|
152 |
+
self.create_visdom_connections()
|
153 |
+
|
154 |
+
else: # show each image in a separate visdom panel;
|
155 |
+
idx = 1
|
156 |
+
try:
|
157 |
+
for label, image in visuals.items():
|
158 |
+
image_numpy = util.tensor2im(image)
|
159 |
+
self.vis.image(
|
160 |
+
image_numpy.transpose([2, 0, 1]),
|
161 |
+
self.display_id + idx,
|
162 |
+
None,
|
163 |
+
dict(title=label)
|
164 |
+
)
|
165 |
+
idx += 1
|
166 |
+
except VisdomExceptionBase:
|
167 |
+
self.create_visdom_connections()
|
168 |
+
|
169 |
+
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
|
170 |
+
self.saved = True
|
171 |
+
# save images to the disk
|
172 |
+
for label, image in visuals.items():
|
173 |
+
image_numpy = util.tensor2im(image)
|
174 |
+
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
|
175 |
+
util.save_image(image_numpy, img_path)
|
176 |
+
|
177 |
+
# update website
|
178 |
+
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
|
179 |
+
for n in range(epoch, 0, -1):
|
180 |
+
webpage.add_header('epoch [%d]' % n)
|
181 |
+
ims, txts, links = [], [], []
|
182 |
+
|
183 |
+
for label, image_numpy in visuals.items():
|
184 |
+
image_numpy = util.tensor2im(image_numpy)  # use the tensor from this loop iteration (was: tensor2im(image))
|
185 |
+
img_path = 'epoch%.3d_%s.png' % (n, label)
|
186 |
+
ims.append(img_path)
|
187 |
+
txts.append(label)
|
188 |
+
links.append(img_path)
|
189 |
+
webpage.add_images(ims, txts, links, width=self.win_size)
|
190 |
+
webpage.save()
|
191 |
+
|
192 |
+
def plot_current_losses(self, epoch, counter_ratio, losses):
|
193 |
+
"""display the current losses on visdom display: dictionary of error labels and values
|
194 |
+
|
195 |
+
Parameters:
|
196 |
+
epoch (int) -- current epoch
|
197 |
+
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
|
198 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
199 |
+
"""
|
200 |
+
if len(losses) == 0:
|
201 |
+
return
|
202 |
+
|
203 |
+
plot_name = '_'.join(list(losses.keys()))
|
204 |
+
|
205 |
+
if plot_name not in self.plot_data:
|
206 |
+
self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}
|
207 |
+
|
208 |
+
plot_data = self.plot_data[plot_name]
|
209 |
+
plot_id = list(self.plot_data.keys()).index(plot_name)
|
210 |
+
|
211 |
+
plot_data['X'].append(epoch + counter_ratio)
|
212 |
+
plot_data['Y'].append([losses[k] for k in plot_data['legend']])
|
213 |
+
try:
|
214 |
+
self.vis.line(
|
215 |
+
X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
|
216 |
+
Y=np.array(plot_data['Y']),
|
217 |
+
opts={
|
218 |
+
'title': self.name,
|
219 |
+
'legend': plot_data['legend'],
|
220 |
+
'xlabel': 'epoch',
|
221 |
+
'ylabel': 'loss'},
|
222 |
+
win=self.display_id - plot_id)
|
223 |
+
except VisdomExceptionBase:
|
224 |
+
self.create_visdom_connections()
|
225 |
+
|
226 |
+
# losses: same format as |losses| of plot_current_losses
|
227 |
+
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
|
228 |
+
"""print current losses on console; also save the losses to the disk
|
229 |
+
|
230 |
+
Parameters:
|
231 |
+
epoch (int) -- current epoch
|
232 |
+
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
|
233 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
234 |
+
t_comp (float) -- computational time per data point (normalized by batch_size)
|
235 |
+
t_data (float) -- data loading time per data point (normalized by batch_size)
|
236 |
+
"""
|
237 |
+
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
|
238 |
+
# TODO:
|
239 |
+
# lambda_pair = math.cos(math.pi/40 * (epoch - 1))
|
240 |
+
# message += '[paired weight: %d] ' % lambda_pair
|
241 |
+
for k, v in losses.items():
|
242 |
+
message += '%s: %.3f ' % (k, v)
|
243 |
+
|
244 |
+
print(message) # print the message
|
245 |
+
with open(self.log_name, "a") as log_file:
|
246 |
+
log_file.write('%s\n' % message) # save the message
|
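At test time, `save_images` and the `HTML` helper from `utils/html.py` work together: tensors are written under `<web_dir>/images/<label>/` and indexed in a single `index.html`. A hedged sketch with dummy tensors (paths and labels are illustrative):

```python
# Hedged sketch: dump a visuals dict to disk and build the HTML index for it.
from collections import OrderedDict

import torch

from Scenimefy.utils import html
from Scenimefy.utils.visualizer import save_images

webpage = html.HTML('Scenimefy/results/sketch_web', 'sketch results')
visuals = OrderedDict([
    ('real_A', torch.rand(1, 3, 256, 256) * 2 - 1),   # stand-ins for model outputs
    ('fake_B', torch.rand(1, 3, 256, 256) * 2 - 1),
])
save_images(webpage, visuals, ['Scenimefy/datasets/Sample/testA/0001.png'], width=256)
webpage.save()   # writes Scenimefy/results/sketch_web/index.html
```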
app.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import torch
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from Scenimefy.options.test_options import TestOptions
|
10 |
+
from Scenimefy.models import create_model
|
11 |
+
from Scenimefy.utils.util import tensor2im
|
12 |
+
|
13 |
+
from PIL import Image
|
14 |
+
import torchvision.transforms as transforms
|
15 |
+
|
16 |
+
|
17 |
+
def parse_args() -> argparse.Namespace:
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument('--device', type=str, default='cpu')
|
20 |
+
parser.add_argument('--theme', type=str)
|
21 |
+
parser.add_argument('--live', action='store_true')
|
22 |
+
parser.add_argument('--share', action='store_true')
|
23 |
+
parser.add_argument('--port', type=int)
|
24 |
+
parser.add_argument('--disable-queue',
|
25 |
+
dest='enable_queue',
|
26 |
+
action='store_false')
|
27 |
+
parser.add_argument('--allow-flagging', type=str, default='never')
|
28 |
+
parser.add_argument('--allow-screenshot', action='store_true')
|
29 |
+
return parser.parse_args()
|
30 |
+
|
31 |
+
TITLE = '''
|
32 |
+
Scene Stylization with <a href="https://github.com/Yuxinn-J/Scenimefy">Scenimefy</a>
|
33 |
+
'''
|
34 |
+
DESCRIPTION = '''
|
35 |
+
<div align=center>
|
36 |
+
<p>
|
37 |
+
Gradio Demo for Scenimefy.
|
38 |
+
To use it, simply upload your image, or click one of the examples to load them.
|
39 |
+
For best outcomes, please pick a natural scene image similar to the examples below.
|
40 |
+
Kindly note that our model is trained on 256x256 resolution images, so using much higher resolutions may affect its performance.
|
41 |
+
Read more at the links below.
|
42 |
+
</p>
|
43 |
+
</div>
|
44 |
+
'''
|
45 |
+
EXAMPLES = [['0.png'], ['1.jpg'], ['2.png'], ['3.png'], ['4.jpg'], ['5.png'], ['6.jpg'], ['7.png'], ['8.png']]
|
46 |
+
ARTICLE = r"""
|
47 |
+
If Scenimefy is helpful, please help to ⭐ the <a href='https://github.com/Yuxinn-J/Scenimefy' target='_blank'>Github Repo</a>. Thank you!
|
48 |
+
🤟 **Citation**
|
49 |
+
If our work is useful for your research, please consider citing:
|
50 |
+
```bibtex
|
51 |
+
@inproceedings{jiang2023scenimefy,
|
52 |
+
title={Scenimefy: Learning to Craft Anime Scene via Semi-Supervised Image-to-Image Translation},
|
53 |
+
author={Jiang, Yuxin and Jiang, Liming and Yang, Shuai and Loy, Chen Change},
|
54 |
+
booktitle={ICCV},
|
55 |
+
year={2023}
|
56 |
+
}
|
57 |
+
```
|
58 |
+
🗞️ **License**
|
59 |
+
This project is licensed under <a rel="license" href="https://github.com/Yuxinn-J/Scenimefy/blob/main/LICENSE.md">S-Lab License 1.0</a>.
|
60 |
+
Redistribution and use for non-commercial purposes should follow this license.
|
61 |
+
"""
|
62 |
+
|
63 |
+
|
64 |
+
model = None
|
65 |
+
|
66 |
+
|
67 |
+
def initialize():
|
68 |
+
opt = TestOptions().parse() # get test options
|
69 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = str(1)
|
70 |
+
# hard-code some parameters for test
|
71 |
+
opt.num_threads = 0 # test code only supports num_threads = 1
|
72 |
+
opt.batch_size = 1 # test code only supports batch_size = 1
|
73 |
+
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
|
74 |
+
opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
|
75 |
+
opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
|
76 |
+
|
77 |
+
# dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
|
78 |
+
global model
|
79 |
+
model = create_model(opt) # create a model given opt.model and other options
|
80 |
+
|
81 |
+
dummy_data = {
|
82 |
+
'A': torch.ones(1, 3, 256, 256),
|
83 |
+
'B': torch.ones(1, 3, 256, 256),
|
84 |
+
'A_paths': 'upload.jpg'
|
85 |
+
}
|
86 |
+
|
87 |
+
model.data_dependent_initialize(dummy_data)
|
88 |
+
model.setup(opt) # regular setup: load and print networks; create schedulers
|
89 |
+
model.parallelize()
|
90 |
+
return model
|
91 |
+
|
92 |
+
|
93 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
|
94 |
+
ow, oh = img.size
|
95 |
+
h = int(round(oh / base) * base)
|
96 |
+
w = int(round(ow / base) * base)
|
97 |
+
if h == oh and w == ow:
|
98 |
+
return img
|
99 |
+
|
100 |
+
return img.resize((w, h), method)
|
101 |
+
|
102 |
+
|
103 |
+
def get_transform():
|
104 |
+
method=Image.BICUBIC
|
105 |
+
transform_list = []
|
106 |
+
# if opt.preprocess == 'none':
|
107 |
+
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
|
108 |
+
transform_list += [transforms.ToTensor()]
|
109 |
+
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
110 |
+
return transforms.Compose(transform_list)
|
111 |
+
|
112 |
+
|
113 |
+
def inference(img):
|
114 |
+
transform = get_transform()
|
115 |
+
A = transform(img.convert('RGB')) # A.shape: torch.Size([3, 260, 460])
|
116 |
+
A = A.unsqueeze(0) # A.shape: torch.Size([1, 3, 260, 460])
|
117 |
+
|
118 |
+
upload_data = {
|
119 |
+
'A': A,
|
120 |
+
'B': torch.ones_like(A),
|
121 |
+
'A_paths': 'upload.jpg'
|
122 |
+
}
|
123 |
+
|
124 |
+
global model
|
125 |
+
model.set_input(upload_data) # unpack data from data loader
|
126 |
+
model.test() # run inference
|
127 |
+
visuals = model.get_current_visuals()
|
128 |
+
return tensor2im(visuals['fake_B'])
|
129 |
+
|
130 |
+
|
131 |
+
def main():
|
132 |
+
args = parse_args()
|
133 |
+
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
134 |
+
print('*** Now using %s.'%(args.device))
|
135 |
+
|
136 |
+
global model
|
137 |
+
model = initialize()
|
138 |
+
|
139 |
+
gr.Interface(
|
140 |
+
inference,
|
141 |
+
gr.Image(type="pil", label='Input'),
|
142 |
+
gr.Image(type="pil", label='Output').style(height=300),
|
143 |
+
theme=args.theme,
|
144 |
+
title=TITLE,
|
145 |
+
description=DESCRIPTION,
|
146 |
+
article=ARTICLE,
|
147 |
+
examples=EXAMPLES,
|
148 |
+
allow_screenshot=args.allow_screenshot,
|
149 |
+
allow_flagging=args.allow_flagging,
|
150 |
+
live=args.live
|
151 |
+
).launch(
|
152 |
+
enable_queue=args.enable_queue,
|
153 |
+
server_port=args.port,
|
154 |
+
share=args.share
|
155 |
+
)
|
156 |
+
|
157 |
+
if __name__ == '__main__':
|
158 |
+
main()
|
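The same pipeline can be driven without the Gradio front end by reusing `initialize()` and `inference()` directly. A hedged sketch (input/output file names are illustrative; `TestOptions().parse()` still reads `sys.argv`, so run it from a script without extra flags):

```python
# Hedged sketch: stylize one image offline with the functions defined in app.py.
from PIL import Image

import app   # this module; importing it does not launch the Gradio interface

app.model = app.initialize()                  # builds the CUT model and loads the Shinkai generator
out = app.inference(Image.open('1.jpg'))      # uint8 numpy array (H, W, 3)
Image.fromarray(out).save('1_scenimefy.png')
```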
packages.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch
torchvision
numpy
Pillow
scipy
dominate