abhishekrs4 committed on
Commit 4a5aa3d · 1 Parent(s): a0c5927

added image_colourization_cgan module

image_colourization_cgan/__init__.py ADDED
@@ -0,0 +1,3 @@
import os
import sys

sys.path.append(os.path.dirname(os.path.realpath(__file__)))
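(The `sys.path` append makes the package directory itself importable, which is what allows the flat intra-package imports used below, e.g. `from loss import GANLoss` in model.py.)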
image_colourization_cgan/image_utils.py ADDED
@@ -0,0 +1,196 @@
import os
import skimage
import numpy as np
from skimage.transform import resize
from skimage.io import imread, imsave
from skimage.color import rgb2lab, lab2rgb, rgb2gray


def resize_image(img, image_size=(320, 320)):
    """
    ---------
    Arguments
    ---------
    img : ndarray
        ndarray of shape (H, W, 3) or (H, W) i.e. RGB and grayscale respectively
    image_size : tuple of ints
        image size to be used for resizing

    -------
    Returns
    -------
    resized image ndarray of shape (H_resized, W_resized, 3) or (H_resized, W_resized)
    if RGB, returns resized image ndarray in range [0, 255]
    if grayscale, returns resized image ndarray in range [0, 1]
    """
    img_resized = resize(img, image_size)
    return img_resized


def convert_rgb2gray(img_rgb):
    """
    ---------
    Arguments
    ---------
    img_rgb : ndarray
        ndarray of shape (H, W, 3) i.e. RGB

    -------
    Returns
    -------
    grayscale image ndarray of shape (H, W)
    """
    img_gray = rgb2gray(img_rgb)
    return img_gray


def convert_lab2rgb(img_lab):
    """
    ---------
    Arguments
    ---------
    img_lab : ndarray
        ndarray of shape (H, W, 3) i.e. Lab

    -------
    Returns
    -------
    RGB image ndarray of shape (H, W, 3) i.e. RGB space
    """
    img_rgb = lab2rgb(img_lab)
    return img_rgb


def convert_rgb2lab(img_rgb):
    """
    ---------
    Arguments
    ---------
    img_rgb : ndarray
        ndarray of shape (H, W, 3) i.e. RGB

    -------
    Returns
    -------
    Lab image ndarray of shape (H, W, 3) i.e. Lab space
    """
    img_lab = rgb2lab(img_rgb)
    return img_lab


def apply_image_ab_post_processing(img_ab):
    """
    ---------
    Arguments
    ---------
    img_ab : ndarray
        pre-processed ndarray of shape (H, W, 2) i.e. ab channels in Lab space in range [-1, 1]

    -------
    Returns
    -------
    post-processed ab channels ndarray of shape (H, W, 2) in range [-110, 110]
    """
    img_ab = img_ab * 110.0
    return img_ab


def apply_image_l_pre_processing(img_l):
    """
    ---------
    Arguments
    ---------
    img_l : ndarray
        ndarray of shape (H, W) i.e. L channel in Lab space in range [0, 100]

    -------
    Returns
    -------
    pre-processed L channel ndarray of shape (H, W) in range [-1, 1]
    """
    img_l = (img_l / 50.0) - 1
    return img_l


def apply_image_ab_pre_processing(img_ab):
    """
    ---------
    Arguments
    ---------
    img_ab : ndarray
        ndarray of shape (H, W, 2) i.e. ab channels in Lab space in range [-110, 110]

    -------
    Returns
    -------
    pre-processed ab channels ndarray of shape (H, W, 2) in range [-1, 1]
    """
    img_ab = img_ab / 110.0
    return img_ab


def concat_images_l_ab(img_l, img_ab):
    """
    ---------
    Arguments
    ---------
    img_l : ndarray
        ndarray of shape (H, W, 1) i.e. L channel
    img_ab : ndarray
        ndarray of shape (H, W, 2) i.e. ab channels

    -------
    Returns
    -------
    Lab space ndarray of shape (H, W, 3)
    """
    img_lab = np.concatenate((img_l, img_ab), axis=-1)
    return img_lab


def read_image(file_img):
    """
    ---------
    Arguments
    ---------
    file_img : str
        full path of the image

    -------
    Returns
    -------
    ndarray of shape (H, W, 3) for RGB or (H, W) for grayscale
    """
    img = imread(file_img)
    return img


def save_image_rgb(file_img, img_arr):
    """
    ---------
    Arguments
    ---------
    file_img : str
        full path of the image
    img_arr : ndarray
        image ndarray to be saved, of shape (H, W, 3) for RGB or (H, W) for grayscale
    """
    imsave(file_img, img_arr)
    return


def rescale_grayscale_image_l_channel(img_gray):
    """
    ---------
    Arguments
    ---------
    img_gray : ndarray
        grayscale image of shape (H, W) in range [0, 1]

    -------
    Returns
    -------
    L channel ndarray of shape (H, W) in range [0, 100]
    """
    img_l_rescaled = img_gray * 100.0
    return img_l_rescaled
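A minimal sketch (not part of the commit) of the intended pre/post-processing round trip; "example.jpg" is a placeholder path, and a zero tensor stands in for the generator's ab prediction:

    import numpy as np

    img_rgb = read_image("example.jpg")                     # placeholder path
    img_rgb = resize_image(img_rgb, image_size=(320, 320))  # float RGB after resize

    # grayscale in [0, 1] -> L channel in [0, 100] -> generator input in [-1, 1]
    img_gray = convert_rgb2gray(img_rgb)
    img_l = rescale_grayscale_image_l_channel(img_gray)
    img_l_pre = apply_image_l_pre_processing(img_l)  # what the generator would consume

    # stand-in for a generator prediction: ab channels in [-1, 1]
    img_ab_pred = np.zeros(img_l.shape + (2,), dtype=np.float32)

    # undo the scaling, assemble the Lab image, and convert back to RGB
    img_ab = apply_image_ab_post_processing(img_ab_pred)
    img_lab = concat_images_l_ab(img_l[..., None], img_ab)
    img_out = convert_lab2rgb(img_lab)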
image_colourization_cgan/loss.py ADDED
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class GANLoss(nn.Module):
    """
    Define different GAN objectives.
    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, loss_mode="vanilla", real_label=1.0, fake_label=0.0):
        """
        ---------
        Arguments
        ---------
        loss_mode : str
            GAN loss mode (default="vanilla")
        real_label : float
            label for a real image
        fake_label : float
            label for a fake image
        """
        super().__init__()
        self.loss_mode = loss_mode
        self.register_buffer("real_label", torch.tensor(real_label))
        self.register_buffer("fake_label", torch.tensor(fake_label))

        self.loss = None
        if self.loss_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        else:
            raise NotImplementedError(
                f"GANLoss with {self.loss_mode} mode - not implemented yet"
            )

    def get_target_tensor(self, prediction, target_is_real):
        """
        ---------
        Arguments
        ---------
        prediction : tensor
            prediction from a discriminator
        target_is_real : bool
            whether the ground-truth label is for a real image or a fake image

        -------
        Returns
        -------
        tensor : a label tensor filled with the ground-truth label, with the same size as the prediction
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """
        ---------
        Arguments
        ---------
        prediction : tensor
            prediction from a discriminator
        target_is_real : bool
            whether the ground-truth label is for a real image or a fake image

        -------
        Returns
        -------
        loss : the computed loss
        """
        if self.loss_mode == "vanilla":
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        else:
            loss = 0
        return loss
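A minimal sketch (not part of the commit) of how GANLoss is driven from raw discriminator logits; the random tensor is a placeholder for a PatchDiscriminatorGAN output:

    import torch

    criterion = GANLoss(loss_mode="vanilla")

    # placeholder logits with the shape of a patch discriminator output (N, 1, H', W')
    pred_logits = torch.randn(4, 1, 30, 30)

    # the target tensor of ones/zeros is built internally via get_target_tensor
    loss_real = criterion(pred_logits, True)   # targets filled with real_label = 1.0
    loss_fake = criterion(pred_logits, False)  # targets filled with fake_label = 0.0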
image_colourization_cgan/model.py ADDED
@@ -0,0 +1,461 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34

from loss import GANLoss


class ResNetEncoder(nn.Module):
    """
    Defines the ResNet-34 encoder
    """

    def __init__(self, pretrained=True):
        """
        ---------
        Arguments
        ---------
        pretrained : bool (default=True)
            boolean to control whether to use a pretrained resnet model or not
        """
        super().__init__()
        self.resnet34 = resnet34(pretrained=pretrained)

    def forward(self, x):
        self.block1 = self.resnet34.conv1(x)
        self.block1 = self.resnet34.bn1(self.block1)
        self.block1 = self.resnet34.relu(self.block1)  # [64, H/2, W/2]

        self.block2 = self.resnet34.maxpool(self.block1)
        self.block2 = self.resnet34.layer1(self.block2)  # [64, H/4, W/4]
        self.block3 = self.resnet34.layer2(self.block2)  # [128, H/8, W/8]
        self.block4 = self.resnet34.layer3(self.block3)  # [256, H/16, W/16]
        self.block5 = self.resnet34.layer4(self.block4)  # [512, H/32, W/32]
        return self.block5


class UNetDecoder(nn.Module):
    """
    Defines the UNet decoder
    """

    def __init__(self, encoder_net, out_channels=2):
        """
        ---------
        Arguments
        ---------
        encoder_net : nn.Module
            PyTorch model object of the encoder
        out_channels : int (default=2)
            number of output channels of the UNet decoder
        """
        super().__init__()
        self.encoder_net = encoder_net
        self.up_block1 = self.up_conv_block(512, 256, use_dropout=True)
        self.conv_reduction_1 = nn.Conv2d(512, 256, kernel_size=1)

        self.up_block2 = self.up_conv_block(256, 128, use_dropout=True)
        self.conv_reduction_2 = nn.Conv2d(256, 128, kernel_size=1)

        self.up_block3 = self.up_conv_block(128, 64)
        self.conv_reduction_3 = nn.Conv2d(128, 64, kernel_size=1)

        self.up_block4 = self.up_conv_block(64, 64)
        self.conv_reduction_4 = nn.Conv2d(128, 64, kernel_size=1)

        self.up_block5 = self.final_up_conv_block(
            conv_tr_in_channels=64, conv_tr_out_channels=32, out_channels=out_channels
        )

    def forward(self, x):
        self.up_1 = self.up_block1(x)  # [256, H/16, W/16]
        self.up_1 = torch.cat(
            [self.encoder_net.block4, self.up_1], dim=1
        )  # [512, H/16, W/16]
        self.up_1 = self.conv_reduction_1(self.up_1)  # [256, H/16, W/16]

        self.up_2 = self.up_block2(self.up_1)  # [128, H/8, W/8]
        self.up_2 = torch.cat(
            [self.encoder_net.block3, self.up_2], dim=1
        )  # [256, H/8, W/8]
        self.up_2 = self.conv_reduction_2(self.up_2)  # [128, H/8, W/8]

        self.up_3 = self.up_block3(self.up_2)  # [64, H/4, W/4]
        self.up_3 = torch.cat(
            [self.encoder_net.block2, self.up_3], dim=1
        )  # [128, H/4, W/4]
        self.up_3 = self.conv_reduction_3(self.up_3)  # [64, H/4, W/4]

        self.up_4 = self.up_block4(self.up_3)  # [64, H/2, W/2]
        self.up_4 = torch.cat(
            [self.encoder_net.block1, self.up_4], dim=1
        )  # [128, H/2, W/2]
        self.up_4 = self.conv_reduction_4(self.up_4)  # [64, H/2, W/2]

        self.out_features = self.up_block5(self.up_4)  # [2, H, W]
        return self.out_features

    def final_up_conv_block(
        self,
        conv_tr_in_channels,
        conv_tr_out_channels,
        out_channels,
        conv_tr_kernel_size=4,
    ):
        """
        ---------
        Arguments
        ---------
        conv_tr_in_channels : int
            number of input channels for conv transpose
        conv_tr_out_channels : int
            number of output channels for conv transpose
        out_channels : int
            number of output channels in the final layer
        conv_tr_kernel_size : int (default=4)
            kernel size for the convolution transpose layer

        -------
        Returns
        -------
        A sequential block depending on the input arguments
        """
        final_block = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(
                conv_tr_in_channels,
                conv_tr_out_channels,
                kernel_size=conv_tr_kernel_size,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.Conv2d(conv_tr_out_channels, out_channels, kernel_size=1),
            nn.Tanh(),
        )
        return final_block

    def up_conv_block(
        self, in_channels, out_channels, conv_tr_kernel_size=4, use_dropout=False
    ):
        """
        ---------
        Arguments
        ---------
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        conv_tr_kernel_size : int (default=4)
            kernel size for the convolution transpose layer
        use_dropout : bool (default=False)
            boolean to control whether to use dropout or not [induces randomness - used instead of a random noise vector as input to the Generator]

        -------
        Returns
        -------
        A sequential block depending on the input arguments
        """
        if use_dropout:
            block = nn.Sequential(
                nn.ReLU(),
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=conv_tr_kernel_size,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(0.5),
            )
        else:
            block = nn.Sequential(
                nn.ReLU(),
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=conv_tr_kernel_size,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        return block


class ResUNet(nn.Module):
    """
    Defines the residual UNet model
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.encoder_net = ResNetEncoder(pretrained=pretrained)
        self.decoder_net = UNetDecoder(self.encoder_net)

    def forward(self, x):
        self.encoder_features = self.encoder_net(x)
        self.decoder_features = self.decoder_net(self.encoder_features)
        return self.decoder_features


class Generator(nn.Module):
    """
    Defines a Generator in a GAN
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.res_u_net = ResUNet(pretrained=pretrained)

    def forward(self, x):
        return self.res_u_net(x)


class PatchDiscriminatorGAN(nn.Module):
    """
    Defines a patch discriminator for a GAN
    """

    def __init__(self, in_channels, num_filters=64, num_blocks=3):
        """
        ---------
        Arguments
        ---------
        in_channels : int
            number of input channels for the Discriminator
        num_filters : int (default=64)
            number of filters in the first layer of the Discriminator
        num_blocks : int (default=3)
            number of blocks to be used in the Discriminator
        """
        super().__init__()
        model_blocks = [
            self.get_conv_block(in_channels, num_filters, is_batch_norm=False)
        ]
        for i in range(num_blocks):
            if i != num_blocks - 1:
                model_blocks += [
                    self.get_conv_block(
                        num_filters * (2**i), num_filters * (2 ** (i + 1))
                    )
                ]
            else:
                model_blocks += [
                    self.get_conv_block(
                        num_filters * (2**i), num_filters * (2 ** (i + 1)), stride=1
                    )
                ]
        model_blocks += [
            self.get_conv_block(
                num_filters * (2**num_blocks),
                1,
                stride=1,
                is_batch_norm=False,
                is_activation=False,
            )
        ]
        self.model = nn.Sequential(*model_blocks)

    def get_conv_block(
        self,
        in_channels,
        out_channels,
        kernel_size=4,
        stride=2,
        padding=1,
        is_batch_norm=True,
        is_activation=True,
    ):
        """
        ---------
        Arguments
        ---------
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        kernel_size : int
            convolution kernel size
        stride : int
            stride to be used for the convolution
        padding : int
            padding to be used for the convolution
        is_batch_norm : bool
            boolean to control whether to add a batchnorm layer to the block
        is_activation : bool
            boolean to control whether to add an activation function to the block

        -------
        Returns
        -------
        A sequential block depending on the input arguments
        """
        block = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=not is_batch_norm,
            )
        ]
        if is_batch_norm:
            block += [nn.BatchNorm2d(out_channels)]
        if is_activation:
            block += [nn.ELU()]
        return nn.Sequential(*block)

    def forward(self, x):
        return self.model(x)


class ImageToImageConditionalGAN(nn.Module):
    """
    Defines an Image (domain A) to Image (domain B) conditional adversarial network
    """

    def __init__(
        self,
        device,
        pretrained=False,
        lr_gen=2e-4,
        lr_dis=2e-4,
        beta1=0.5,
        beta2=0.999,
        lambda_=100.0,
    ):
        super().__init__()
        self.device = device
        self.loss_names = ["gen_gan", "gen_l1", "dis_real", "dis_fake"]
        self.lambda_ = lambda_
        self.net_gen = Generator(pretrained=pretrained)
        self.net_dis = PatchDiscriminatorGAN(in_channels=3)

        self.criterion_GAN = GANLoss().to(self.device)
        self.criterion_l1 = nn.L1Loss()

        self.optimizer_gen = torch.optim.Adam(
            self.net_gen.parameters(), lr=lr_gen, betas=(beta1, beta2)
        )
        self.optimizer_dis = torch.optim.Adam(
            self.net_dis.parameters(), lr=lr_dis, betas=(beta1, beta2)
        )

    def set_requires_grad(self, model, requires_grad=True):
        """
        ---------
        Arguments
        ---------
        model : model object
            PyTorch model object
        requires_grad : bool (default=True)
            boolean to control whether the model requires gradients or not
        """
        for param in model.parameters():
            param.requires_grad = requires_grad

    def setup_input(self, data):
        """
        ---------
        Arguments
        ---------
        data : dict
            dictionary object containing image data of domains 1 and 2
        """
        self.real_domain_1 = data["domain_1"].to(self.device)
        self.real_domain_2 = data["domain_2"].to(self.device)

        if self.device == torch.device("cuda"):
            self.real_domain_1_1_ch = self.real_domain_1[:, 0, :, :]
            self.real_domain_1_1_ch = self.real_domain_1_1_ch[:, None, :, :]
        else:
            self.real_domain_1_1_ch = self.real_domain_1[:, :, :, 0]
            self.real_domain_1_1_ch = self.real_domain_1_1_ch[:, :, :, None]

    def forward(self):
        # compute the fake image in domain_2: Generator(domain_1)
        self.fake_domain_2 = self.net_gen(self.real_domain_1)

    def backward_gen(self):
        """
        Calculate the GAN and L1 losses for the generator
        """
        # first, Generator(domain_1) should try to fool the Discriminator
        fake_domain_12 = torch.cat((self.real_domain_1_1_ch, self.fake_domain_2), dim=1)
        pred_fake = self.net_dis(fake_domain_12)
        self.loss_gen_gan = self.criterion_GAN(pred_fake, True)

        # second, Generator(domain_1) = domain_2,
        # i.e. the output predicted by the Generator should be close to domain_2
        self.loss_gen_l1 = (
            self.criterion_l1(self.fake_domain_2, self.real_domain_2) * self.lambda_
        )

        # compute the combined loss
        self.loss_gen = self.loss_gen_gan + self.loss_gen_l1
        self.loss_gen.backward()

    def backward_dis(self):
        """
        Calculate the GAN loss for the discriminator
        """
        # Fake
        fake_domain_12 = torch.cat((self.real_domain_1_1_ch, self.fake_domain_2), dim=1)
        # stop backprop to the generator by detaching fake_domain_12
        pred_fake = self.net_dis(fake_domain_12.detach())
        # the Discriminator should identify the fake image
        self.loss_dis_fake = self.criterion_GAN(pred_fake, False)

        # Real
        real_domain_12 = torch.cat((self.real_domain_1_1_ch, self.real_domain_2), dim=1)
        pred_real = self.net_dis(real_domain_12)
        # the Discriminator should identify the real image
        self.loss_dis_real = self.criterion_GAN(pred_real, True)

        # compute the combined loss
        self.loss_dis = (self.loss_dis_fake + self.loss_dis_real) * 0.5
        self.loss_dis.backward()

    def optimize_params(self):
        # compute the fake image in domain_2: Generator(domain_1)
        self.forward()

        """
        --------------------
        Update Discriminator
        --------------------
        # enable backprop for the Discriminator
        # set the Discriminator's gradients to zero
        # compute gradients for the Discriminator
        # update the Discriminator's weights
        """
        self.set_requires_grad(self.net_dis, True)
        self.optimizer_dis.zero_grad()
        self.backward_dis()
        self.optimizer_dis.step()

        """
        ----------------
        Update Generator
        ----------------
        # the Discriminator requires no gradients when optimizing the Generator
        # set the Generator's gradients to zero
        # calculate gradients for the Generator
        # update the Generator's weights
        """
        self.set_requires_grad(self.net_dis, False)
        self.optimizer_gen.zero_grad()
        self.backward_gen()
        self.optimizer_gen.step()

    def get_current_losses(self):
        all_losses = dict()
        for loss_name in self.loss_names:
            all_losses["loss_" + loss_name] = float(getattr(self, "loss_" + loss_name))
        return all_losses
+ return all_losses