Commit 4a5aa3d
Parent(s): a0c5927
added image_colourization_cgan module
image_colourization_cgan/__init__.py
ADDED
@@ -0,0 +1,3 @@
+import os, sys
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
image_colourization_cgan/image_utils.py
ADDED
@@ -0,0 +1,196 @@
+import os
+import skimage
+import numpy as np
+from skimage.transform import resize
+from skimage.io import imread, imsave
+from skimage.color import rgb2lab, lab2rgb, rgb2gray
+
+
+def resize_image(img, image_size=(320, 320)):
+    """
+    ---------
+    Arguments
+    ---------
+    img : ndarray
+        ndarray of shape (H, W, 3) or (H, W) i.e. RGB and grayscale respectively
+    image_size : tuple of ints
+        image size to be used for resizing
+
+    -------
+    Returns
+    -------
+    resized image ndarray of shape (H_resized, W_resized, 3) or (H_resized, W_resized),
+    as float in range [0, 1] (skimage's resize converts integer inputs to float by default)
+    """
+    img_resized = resize(img, image_size)
+    return img_resized
+
+
+def convert_rgb2gray(img_rgb):
+    """
+    ---------
+    Arguments
+    ---------
+    img_rgb : ndarray
+        ndarray of shape (H, W, 3) i.e. RGB
+
+    -------
+    Returns
+    -------
+    grayscale image ndarray of shape (H, W)
+    """
+    img_gray = rgb2gray(img_rgb)
+    return img_gray
+
+
+def convert_lab2rgb(img_lab):
+    """
+    ---------
+    Arguments
+    ---------
+    img_lab : ndarray
+        ndarray of shape (H, W, 3) i.e. Lab
+
+    -------
+    Returns
+    -------
+    RGB image ndarray of shape (H, W, 3) i.e. RGB space
+    """
+    img_rgb = lab2rgb(img_lab)
+    return img_rgb
+
+
+def convert_rgb2lab(img_rgb):
+    """
+    ---------
+    Arguments
+    ---------
+    img_rgb : ndarray
+        ndarray of shape (H, W, 3) i.e. RGB
+
+    -------
+    Returns
+    -------
+    Lab image ndarray of shape (H, W, 3) i.e. Lab space
+    """
+    img_lab = rgb2lab(img_rgb)
+    return img_lab
+
+
+def apply_image_ab_post_processing(img_ab):
+    """
+    ---------
+    Arguments
+    ---------
+    img_ab : ndarray
+        pre-processed ndarray of shape (H, W, 2) i.e. ab channels in Lab space in range [-1, 1]
+
+    -------
+    Returns
+    -------
+    post-processed ab channels ndarray of shape (H, W, 2) in range [-110, 110]
+    """
+    img_ab = img_ab * 110.0
+    return img_ab
+
+
+def apply_image_l_pre_processing(img_l):
+    """
+    ---------
+    Arguments
+    ---------
+    img_l : ndarray
+        ndarray of shape (H, W) i.e. L channel in Lab space in range [0, 100]
+
+    -------
+    Returns
+    -------
+    pre-processed L channel ndarray of shape (H, W) in range [-1, 1]
+    """
+    img_l = (img_l / 50.0) - 1
+    return img_l
+
+
+def apply_image_ab_pre_processing(img_ab):
+    """
+    ---------
+    Arguments
+    ---------
+    img_ab : ndarray
+        ndarray of shape (H, W, 2) i.e. ab channels in Lab space in range [-110, 110]
+
+    -------
+    Returns
+    -------
+    pre-processed ab channels ndarray of shape (H, W, 2) in range [-1, 1]
+    """
+    img_ab = img_ab / 110.0
+    return img_ab
+
+
+def concat_images_l_ab(img_l, img_ab):
+    """
+    ---------
+    Arguments
+    ---------
+    img_l : ndarray
+        ndarray of shape (H, W, 1) i.e. L channel
+    img_ab : ndarray
+        ndarray of shape (H, W, 2) i.e. ab channels
+
+    -------
+    Returns
+    -------
+    Lab space ndarray of shape (H, W, 3)
+    """
+    img_lab = np.concatenate((img_l, img_ab), axis=-1)
+    return img_lab
+
+
+def read_image(file_img):
+    """
+    ---------
+    Arguments
+    ---------
+    file_img : str
+        full path of the image
+
+    -------
+    Returns
+    -------
+    ndarray of shape (H, W, 3) for RGB or (H, W) for grayscale
+    """
+    img = imread(file_img)
+    return img
+
+
+def save_image_rgb(file_img, img_arr):
+    """
+    ---------
+    Arguments
+    ---------
+    file_img : str
+        full path of the image
+    img_arr : ndarray
+        image ndarray to be saved, of shape (H, W, 3) for RGB or (H, W) for grayscale
+    """
+    imsave(file_img, img_arr)
+    return
+
+
+def rescale_grayscale_image_l_channel(img_gray):
+    """
+    ---------
+    Arguments
+    ---------
+    img_gray : ndarray
+        grayscale image of shape (H, W) in range [0, 1]
+
+    -------
+    Returns
+    -------
+    L channel ndarray of shape (H, W) in range [0, 100]
+    """
+    img_l_rescaled = img_gray * 100.0
+    return img_l_rescaled
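
For reference, a minimal usage sketch (not part of the commit) showing how these helpers compose for a single image; "photo.jpg" is a placeholder path, and the inverse L-channel scaling is written inline since the module defines no post-processing helper for L:

from image_colourization_cgan.image_utils import (
    apply_image_ab_post_processing,
    apply_image_ab_pre_processing,
    apply_image_l_pre_processing,
    concat_images_l_ab,
    convert_lab2rgb,
    convert_rgb2lab,
    read_image,
    resize_image,
)

img_rgb = resize_image(read_image("photo.jpg"))            # float RGB in [0, 1]
img_lab = convert_rgb2lab(img_rgb)                         # L in [0, 100], ab roughly in [-110, 110]
img_l = apply_image_l_pre_processing(img_lab[:, :, 0])     # network input, range [-1, 1]
img_ab = apply_image_ab_pre_processing(img_lab[:, :, 1:])  # network target, range [-1, 1]

# invert the scaling and rebuild an RGB image
img_ab_restored = apply_image_ab_post_processing(img_ab)   # back to [-110, 110]
img_l_restored = (img_l + 1.0) * 50.0                      # inline inverse of apply_image_l_pre_processing
img_rgb_restored = convert_lab2rgb(
    concat_images_l_ab(img_l_restored[:, :, None], img_ab_restored)
)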
image_colourization_cgan/loss.py
ADDED
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GANLoss(nn.Module):
+    """
+    Define different GAN objectives.
+    The GANLoss class abstracts away the need to create the target label tensor
+    that has the same size as the input.
+    """
+
+    def __init__(self, loss_mode="vanilla", real_label=1.0, fake_label=0.0):
+        """
+        ---------
+        Arguments
+        ---------
+        loss_mode : str
+            GAN loss mode (default="vanilla")
+        real_label : float
+            label for a real image
+        fake_label : float
+            label for a fake image
+        """
+        super().__init__()
+        self.loss_mode = loss_mode
+        self.register_buffer("real_label", torch.tensor(real_label))
+        self.register_buffer("fake_label", torch.tensor(fake_label))
+
+        self.loss = None
+        if self.loss_mode == "vanilla":
+            self.loss = nn.BCEWithLogitsLoss()
+        else:
+            raise NotImplementedError(
+                f"GANLoss with {self.loss_mode} mode - not implemented yet"
+            )
+
+    def get_target_tensor(self, prediction, target_is_real):
+        """
+        ---------
+        Arguments
+        ---------
+        prediction : tensor
+            prediction from a discriminator
+        target_is_real : bool
+            whether the groundtruth label is for a real image or a fake image
+
+        -------
+        Returns
+        -------
+        tensor : a label tensor filled with the groundtruth label, with the same size as the input
+        """
+        if target_is_real:
+            target_tensor = self.real_label
+        else:
+            target_tensor = self.fake_label
+        return target_tensor.expand_as(prediction)
+
+    def __call__(self, prediction, target_is_real):
+        """
+        ---------
+        Arguments
+        ---------
+        prediction : tensor
+            prediction from a discriminator
+        target_is_real : bool
+            whether the groundtruth label is for a real image or a fake image
+
+        -------
+        Returns
+        -------
+        loss : the computed loss
+        """
+        if self.loss_mode == "vanilla":
+            target_tensor = self.get_target_tensor(prediction, target_is_real)
+            loss = self.loss(prediction, target_tensor)
+        else:
+            loss = 0
+        return loss
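
A quick usage sketch (not part of the commit): GANLoss builds the target tensor itself, so callers pass raw discriminator logits and a real/fake flag; the logits shape below is illustrative.

import torch

from image_colourization_cgan.loss import GANLoss

criterion = GANLoss(loss_mode="vanilla")
logits = torch.randn(4, 1, 30, 30)    # e.g. per-patch discriminator output
loss_real = criterion(logits, True)   # BCE-with-logits against an all-ones target
loss_fake = criterion(logits, False)  # BCE-with-logits against an all-zeros target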
image_colourization_cgan/model.py
ADDED
@@ -0,0 +1,461 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.models import resnet34
+
+from loss import GANLoss
+
+
+class ResNetEncoder(nn.Module):
+    """
+    Defines a ResNet-34 encoder; intermediate blocks are cached as attributes
+    for the decoder's skip connections
+    """
+
+    def __init__(self, pretrained=True):
+        """
+        ---------
+        Arguments
+        ---------
+        pretrained : bool (default=True)
+            boolean to control whether to use a pretrained ResNet model or not
+        """
+        super().__init__()
+        self.resnet34 = resnet34(pretrained=pretrained)
+
+    def forward(self, x):
+        self.block1 = self.resnet34.conv1(x)
+        self.block1 = self.resnet34.bn1(self.block1)
+        self.block1 = self.resnet34.relu(self.block1)  # [64, H/2, W/2]
+
+        self.block2 = self.resnet34.maxpool(self.block1)
+        self.block2 = self.resnet34.layer1(self.block2)  # [64, H/4, W/4]
+        self.block3 = self.resnet34.layer2(self.block2)  # [128, H/8, W/8]
+        self.block4 = self.resnet34.layer3(self.block3)  # [256, H/16, W/16]
+        self.block5 = self.resnet34.layer4(self.block4)  # [512, H/32, W/32]
+        return self.block5
+
+
+class UNetDecoder(nn.Module):
+    """
+    Defines a UNet decoder
+    """
+
+    def __init__(self, encoder_net, out_channels=2):
+        """
+        ---------
+        Arguments
+        ---------
+        encoder_net : nn.Module
+            PyTorch model object of the encoder
+        out_channels : int (default=2)
+            number of output channels of the UNet decoder
+        """
+        super().__init__()
+        self.encoder_net = encoder_net
+        self.up_block1 = self.up_conv_block(512, 256, use_dropout=True)
+        self.conv_reduction_1 = nn.Conv2d(512, 256, kernel_size=1)
+
+        self.up_block2 = self.up_conv_block(256, 128, use_dropout=True)
+        self.conv_reduction_2 = nn.Conv2d(256, 128, kernel_size=1)
+
+        self.up_block3 = self.up_conv_block(128, 64)
+        self.conv_reduction_3 = nn.Conv2d(128, 64, kernel_size=1)
+
+        self.up_block4 = self.up_conv_block(64, 64)
+        self.conv_reduction_4 = nn.Conv2d(128, 64, kernel_size=1)
+
+        self.up_block5 = self.final_up_conv_block(
+            conv_tr_in_channels=64, conv_tr_out_channels=32, out_channels=out_channels
+        )
+
+    def forward(self, x):
+        self.up_1 = self.up_block1(x)  # [256, H/16, W/16]
+        self.up_1 = torch.cat(
+            [self.encoder_net.block4, self.up_1], dim=1
+        )  # [512, H/16, W/16]
+        self.up_1 = self.conv_reduction_1(self.up_1)  # [256, H/16, W/16]
+
+        self.up_2 = self.up_block2(self.up_1)  # [128, H/8, W/8]
+        self.up_2 = torch.cat(
+            [self.encoder_net.block3, self.up_2], dim=1
+        )  # [256, H/8, W/8]
+        self.up_2 = self.conv_reduction_2(self.up_2)  # [128, H/8, W/8]
+
+        self.up_3 = self.up_block3(self.up_2)  # [64, H/4, W/4]
+        self.up_3 = torch.cat(
+            [self.encoder_net.block2, self.up_3], dim=1
+        )  # [128, H/4, W/4]
+        self.up_3 = self.conv_reduction_3(self.up_3)  # [64, H/4, W/4]
+
+        self.up_4 = self.up_block4(self.up_3)  # [64, H/2, W/2]
+        self.up_4 = torch.cat(
+            [self.encoder_net.block1, self.up_4], dim=1
+        )  # [128, H/2, W/2]
+        self.up_4 = self.conv_reduction_4(self.up_4)  # [64, H/2, W/2]
+
+        self.out_features = self.up_block5(self.up_4)  # [2, H, W]
+        return self.out_features
+
+    def final_up_conv_block(
+        self,
+        conv_tr_in_channels,
+        conv_tr_out_channels,
+        out_channels,
+        conv_tr_kernel_size=4,
+    ):
+        """
+        ---------
+        Arguments
+        ---------
+        conv_tr_in_channels : int
+            number of input channels for conv transpose
+        conv_tr_out_channels : int
+            number of output channels for conv transpose
+        out_channels : int
+            number of output channels in the final layer
+        conv_tr_kernel_size : int (default=4)
+            kernel size for the convolution transpose layer
+
+        -------
+        Returns
+        -------
+        A sequential block depending on the input arguments
+        """
+        final_block = nn.Sequential(
+            nn.ReLU(),
+            nn.ConvTranspose2d(
+                conv_tr_in_channels,
+                conv_tr_out_channels,
+                kernel_size=conv_tr_kernel_size,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.Conv2d(conv_tr_out_channels, out_channels, kernel_size=1),
+            nn.Tanh(),
+        )
+        return final_block
+
+    def up_conv_block(
+        self, in_channels, out_channels, conv_tr_kernel_size=4, use_dropout=False
+    ):
+        """
+        ---------
+        Arguments
+        ---------
+        in_channels : int
+            number of input channels
+        out_channels : int
+            number of output channels
+        conv_tr_kernel_size : int (default=4)
+            kernel size for the convolution transpose layer
+        use_dropout : bool (default=False)
+            boolean to control whether to use dropout or not [induces randomness -
+            used instead of a random noise vector as input to the Generator]
+
+        -------
+        Returns
+        -------
+        A sequential block depending on the input arguments
+        """
+        if use_dropout:
+            block = nn.Sequential(
+                nn.ReLU(),
+                nn.ConvTranspose2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size=conv_tr_kernel_size,
+                    stride=2,
+                    padding=1,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(out_channels),
+                nn.Dropout(0.5),
+            )
+        else:
+            block = nn.Sequential(
+                nn.ReLU(),
+                nn.ConvTranspose2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size=conv_tr_kernel_size,
+                    stride=2,
+                    padding=1,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(out_channels),
+            )
+        return block
+
+
+class ResUNet(nn.Module):
+    """
+    Defines a Residual UNet model
+    """
+
+    def __init__(self, pretrained=True):
+        super().__init__()
+        self.encoder_net = ResNetEncoder(pretrained=pretrained)
+        self.decoder_net = UNetDecoder(self.encoder_net)
+
+    def forward(self, x):
+        self.encoder_features = self.encoder_net(x)
+        self.decoder_features = self.decoder_net(self.encoder_features)
+        return self.decoder_features
+
+
+class Generator(nn.Module):
+    """
+    Defines a Generator in a GAN
+    """
+
+    def __init__(self, pretrained=True):
+        super().__init__()
+        self.res_u_net = ResUNet(pretrained=pretrained)
+
+    def forward(self, x):
+        return self.res_u_net(x)
+
+
+class PatchDiscriminatorGAN(nn.Module):
+    """
+    Defines a patch discriminator for a GAN
+    """
+
+    def __init__(self, in_channels, num_filters=64, num_blocks=3):
+        """
+        ---------
+        Arguments
+        ---------
+        in_channels : int
+            number of input channels for the Discriminator
+        num_filters : int (default=64)
+            number of filters in the first layer of the Discriminator
+        num_blocks : int (default=3)
+            number of blocks to be used in the Discriminator
+        """
+        super().__init__()
+        model_blocks = [
+            self.get_conv_block(in_channels, num_filters, is_batch_norm=False)
+        ]
+        for i in range(num_blocks):
+            if i != num_blocks - 1:
+                model_blocks += [
+                    self.get_conv_block(
+                        num_filters * (2**i), num_filters * (2 ** (i + 1))
+                    )
+                ]
+            else:
+                model_blocks += [
+                    self.get_conv_block(
+                        num_filters * (2**i), num_filters * (2 ** (i + 1)), stride=1
+                    )
+                ]
+        # final layer maps to 1 channel of per-patch real/fake logits
+        model_blocks += [
+            self.get_conv_block(
+                num_filters * (2**num_blocks),
+                1,
+                stride=1,
+                is_batch_norm=False,
+                is_activation=False,
+            )
+        ]
+        self.model = nn.Sequential(*model_blocks)
+
+    def get_conv_block(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=4,
+        stride=2,
+        padding=1,
+        is_batch_norm=True,
+        is_activation=True,
+    ):
+        """
+        ---------
+        Arguments
+        ---------
+        in_channels : int
+            number of input channels
+        out_channels : int
+            number of output channels
+        kernel_size : int
+            convolution kernel size
+        stride : int
+            stride to be used for convolution
+        padding : int
+            padding to be used for convolution
+        is_batch_norm : bool
+            boolean to control whether to add a batchnorm layer to the block
+        is_activation : bool
+            boolean to control whether to add an activation function to the block
+
+        -------
+        Returns
+        -------
+        A sequential block depending on the input arguments
+        """
+        block = [
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                bias=not is_batch_norm,
+            )
+        ]
+        if is_batch_norm:
+            block += [nn.BatchNorm2d(out_channels)]
+        if is_activation:
+            block += [nn.ELU()]
+        return nn.Sequential(*block)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class ImageToImageConditionalGAN(nn.Module):
+    """
+    Defines an Image (domain A) to Image (domain B) Conditional Adversarial Network
+    """
+
+    def __init__(
+        self,
+        device,
+        pretrained=False,
+        lr_gen=2e-4,
+        lr_dis=2e-4,
+        beta1=0.5,
+        beta2=0.999,
+        lambda_=100.0,
+    ):
+        super().__init__()
+        self.device = device
+        self.loss_names = ["gen_gan", "gen_l1", "dis_real", "dis_fake"]
+        self.lambda_ = lambda_
+        self.net_gen = Generator(pretrained=pretrained)
+        self.net_dis = PatchDiscriminatorGAN(in_channels=3)
+
+        self.criterion_GAN = GANLoss().to(self.device)
+        self.criterion_l1 = nn.L1Loss()
+
+        self.optimizer_gen = torch.optim.Adam(
+            self.net_gen.parameters(), lr=lr_gen, betas=(beta1, beta2)
+        )
+        self.optimizer_dis = torch.optim.Adam(
+            self.net_dis.parameters(), lr=lr_dis, betas=(beta1, beta2)
+        )
+
+    def set_requires_grad(self, model, requires_grad=True):
+        """
+        ---------
+        Arguments
+        ---------
+        model : model object
+            PyTorch model object
+        requires_grad : bool (default=True)
+            boolean to control whether the model requires gradients or not
+        """
+        for param in model.parameters():
+            param.requires_grad = requires_grad
+
+    def setup_input(self, data):
+        """
+        ---------
+        Arguments
+        ---------
+        data : dict
+            dictionary object containing image data of domains 1 and 2
+        """
+        self.real_domain_1 = data["domain_1"].to(self.device)
+        self.real_domain_2 = data["domain_2"].to(self.device)
+
+        # extract the single L channel from domain_1, keeping a channel axis
+        if self.device == torch.device("cuda"):  # channels-first layout on GPU
+            self.real_domain_1_1_ch = self.real_domain_1[:, 0, :, :]
+            self.real_domain_1_1_ch = self.real_domain_1_1_ch[:, None, :, :]
+        else:  # channels-last layout on CPU
+            self.real_domain_1_1_ch = self.real_domain_1[:, :, :, 0]
+            self.real_domain_1_1_ch = self.real_domain_1_1_ch[:, :, :, None]
+
+    def forward(self):
+        # compute fake image in domain_2: Generator(domain_1)
+        self.fake_domain_2 = self.net_gen(self.real_domain_1)
+
+    def backward_gen(self):
+        """
+        Calculate GAN and L1 loss for the generator
+        """
+        # first, Generator(domain_1) should try to fool the Discriminator
+        fake_domain_12 = torch.cat((self.real_domain_1_1_ch, self.fake_domain_2), dim=1)
+        pred_fake = self.net_dis(fake_domain_12)
+        self.loss_gen_gan = self.criterion_GAN(pred_fake, True)
+
+        # second, Generator(domain_1) = domain_2,
+        # i.e. the output predicted by the Generator should be close to domain_2
+        self.loss_gen_l1 = (
+            self.criterion_l1(self.fake_domain_2, self.real_domain_2) * self.lambda_
+        )
+
+        # compute the combined loss
+        self.loss_gen = self.loss_gen_gan + self.loss_gen_l1
+        self.loss_gen.backward()
+
+    def backward_dis(self):
+        """
+        Calculate GAN loss for the discriminator
+        """
+        # Fake
+        fake_domain_12 = torch.cat((self.real_domain_1_1_ch, self.fake_domain_2), dim=1)
+        # stop backprop to the generator by detaching fake_domain_12
+        pred_fake = self.net_dis(fake_domain_12.detach())
+        # the Discriminator should identify the fake image
+        self.loss_dis_fake = self.criterion_GAN(pred_fake, False)
+
+        # Real
+        real_domain_12 = torch.cat((self.real_domain_1_1_ch, self.real_domain_2), dim=1)
+        pred_real = self.net_dis(real_domain_12)
+        # the Discriminator should identify the real image
+        self.loss_dis_real = self.criterion_GAN(pred_real, True)
+
+        # compute the combined loss
+        self.loss_dis = (self.loss_dis_fake + self.loss_dis_real) * 0.5
+        self.loss_dis.backward()
+
+    def optimize_params(self):
+        # compute fake image in domain_2: Generator(domain_1)
+        self.forward()
+
+        """
+        --------------------
+        Update Discriminator
+        --------------------
+        # enable backprop for Discriminator
+        # set Discriminator's gradients to zero
+        # compute gradients for Discriminator
+        # update Discriminator's weights
+        """
+        self.set_requires_grad(self.net_dis, True)
+        self.optimizer_dis.zero_grad()
+        self.backward_dis()
+        self.optimizer_dis.step()
+
+        """
+        ----------------
+        Update Generator
+        ----------------
+        # Discriminator requires no gradients when optimizing Generator
+        # set Generator's gradients to zero
+        # calculate gradients for Generator
+        # update Generator's weights
+        """
+        self.set_requires_grad(self.net_dis, False)
+        self.optimizer_gen.zero_grad()
+        self.backward_gen()
+        self.optimizer_gen.step()
+
+    def get_current_losses(self):
+        all_losses = dict()
+        for loss_name in self.loss_names:
+            all_losses["loss_" + loss_name] = float(getattr(self, "loss_" + loss_name))
+        return all_losses
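
Finally, a minimal sketch (not part of the commit) of one training step with the classes above; the batch keys follow setup_input(), the shapes are illustrative, and a CUDA device is assumed because setup_input's CUDA branch expects channels-first batches.

import torch

from image_colourization_cgan.model import ImageToImageConditionalGAN

device = torch.device("cuda")
model = ImageToImageConditionalGAN(device, pretrained=False).to(device)

batch = {
    "domain_1": torch.rand(2, 3, 320, 320),  # conditioning input, e.g. L channel replicated to 3 channels
    "domain_2": torch.rand(2, 2, 320, 320),  # target ab channels
}
model.setup_input(batch)  # moves tensors to the device and slices out the single L channel
model.optimize_params()   # one discriminator update, then one generator update
print(model.get_current_losses())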