Thomas.Chaigneau committed on
Commit
6de6ae4
1 Parent(s): 52c4b0e
Files changed (1)
  1. model.py +368 -0
model.py ADDED
@@ -0,0 +1,368 @@
+import pytorch_lightning as pl
+
+import torch
+import torch.nn as nn
+
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+
+class Discriminator(nn.Module):
+    def __init__(
+        self,
+        hidden_size: Optional[int] = 64,
+        channels: Optional[int] = 3,
+        kernel_size: Optional[int] = 4,
+        stride: Optional[int] = 2,
+        padding: Optional[int] = 1,
+        negative_slope: Optional[float] = 0.2,
+        bias: Optional[bool] = False,
+    ):
+        """
+        Initializes the discriminator.
+
+        Parameters
+        ----------
+        hidden_size : int, optional
+            The base number of feature maps. (default: 64)
+        channels : int, optional
+            The number of image channels. (default: 3)
+        kernel_size : int, optional
+            The kernel size. (default: 4)
+        stride : int, optional
+            The stride. (default: 2)
+        padding : int, optional
+            The padding. (default: 1)
+        negative_slope : float, optional
+            The negative slope of the LeakyReLU activations. (default: 0.2)
+        bias : bool, optional
+            Whether to use bias in the convolutions. (default: False)
+        """
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.negative_slope = negative_slope
+        self.bias = bias
+
+        # Spectral-normalized DCGAN-style discriminator: (channels, 64, 64) -> per-image probability.
+        self.model = nn.Sequential(
+            nn.utils.spectral_norm(
+                nn.Conv2d(
+                    self.channels, self.hidden_size,
+                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+                ),
+            ),
+            nn.LeakyReLU(self.negative_slope, inplace=True),
+
+            nn.utils.spectral_norm(
+                nn.Conv2d(
+                    self.hidden_size, self.hidden_size * 2,
+                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+                ),
+            ),
+            nn.BatchNorm2d(self.hidden_size * 2),
+            nn.LeakyReLU(self.negative_slope, inplace=True),
+
+            nn.utils.spectral_norm(
+                nn.Conv2d(
+                    self.hidden_size * 2, self.hidden_size * 4,
+                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+                ),
+            ),
+            nn.BatchNorm2d(self.hidden_size * 4),
+            nn.LeakyReLU(self.negative_slope, inplace=True),
+
+            nn.utils.spectral_norm(
+                nn.Conv2d(
+                    self.hidden_size * 4, self.hidden_size * 8,
+                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+                ),
+            ),
+            nn.BatchNorm2d(self.hidden_size * 8),
+            nn.LeakyReLU(self.negative_slope, inplace=True),
+
+            nn.utils.spectral_norm(
+                nn.Conv2d(self.hidden_size * 8, 1, kernel_size=4, stride=1, padding=0, bias=self.bias),  # output size: (1, 1, 1)
+            ),
+            nn.Flatten(),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
+        """
+        Forward propagation.
+
+        Parameters
+        ----------
+        input_img : torch.Tensor
+            The input image batch.
+
+        Returns
+        -------
+        torch.Tensor
+            The predicted probability that each image is real.
+        """
+        preds = self.model(input_img)
+        return preds
+
+
+class Generator(nn.Module):
+    def __init__(
+        self,
+        hidden_size: Optional[int] = 64,
+        latent_size: Optional[int] = 128,
+        channels: Optional[int] = 3,
+        kernel_size: Optional[int] = 4,
+        stride: Optional[int] = 2,
+        padding: Optional[int] = 1,
+        bias: Optional[bool] = False,
+    ):
+        """
+        Initializes the generator.
+
+        Parameters
+        ----------
+        hidden_size : int, optional
+            The base number of feature maps. (default: 64)
+        latent_size : int, optional
+            The size of the latent vector. (default: 128)
+        channels : int, optional
+            The number of image channels. (default: 3)
+        kernel_size : int, optional
+            The kernel size. (default: 4)
+        stride : int, optional
+            The stride. (default: 2)
+        padding : int, optional
+            The padding. (default: 1)
+        bias : bool, optional
+            Whether to use bias in the transposed convolutions. (default: False)
+        """
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.latent_size = latent_size
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.bias = bias
+
+        # DCGAN-style generator: (latent_size, 1, 1) -> (channels, 64, 64).
+        self.model = nn.Sequential(
+            nn.ConvTranspose2d(
+                self.latent_size, self.hidden_size * 8,
+                kernel_size=self.kernel_size, stride=1, padding=0, bias=self.bias,
+            ),
+            nn.BatchNorm2d(self.hidden_size * 8),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(
+                self.hidden_size * 8, self.hidden_size * 4,
+                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+            ),
+            nn.BatchNorm2d(self.hidden_size * 4),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(
+                self.hidden_size * 4, self.hidden_size * 2,
+                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+            ),
+            nn.BatchNorm2d(self.hidden_size * 2),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(
+                self.hidden_size * 2, self.hidden_size,
+                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+            ),
+            nn.BatchNorm2d(self.hidden_size),
+            nn.ReLU(inplace=True),
+
+            nn.ConvTranspose2d(
+                self.hidden_size, self.channels,
+                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
+            ),
+            nn.Tanh(),  # output size: (channels, 64, 64)
+        )
+
+    def forward(self, input_noise: torch.Tensor) -> torch.Tensor:
+        """
+        Forward propagation.
+
+        Parameters
+        ----------
+        input_noise : torch.Tensor
+            The input noise (latent vector).
+
+        Returns
+        -------
+        torch.Tensor
+            The generated image batch.
+        """
+        fake_img = self.model(input_noise)
+        return fake_img
+
+
+class DocuGAN(pl.LightningModule):
+    def __init__(
+        self,
+        hidden_size: Optional[int] = 64,
+        latent_size: Optional[int] = 128,
+        num_channel: Optional[int] = 3,
+        learning_rate: Optional[float] = 0.0002,
+        batch_size: Optional[int] = 128,
+        bias1: Optional[float] = 0.5,
+        bias2: Optional[float] = 0.999,
+    ):
+        """
+        Initializes the DocuGAN LightningModule.
+
+        Parameters
+        ----------
+        hidden_size : int, optional
+            The hidden size. (default: 64)
+        latent_size : int, optional
+            The latent size. (default: 128)
+        num_channel : int, optional
+            The number of channels. (default: 3)
+        learning_rate : float, optional
+            The learning rate. (default: 0.0002)
+        batch_size : int, optional
+            The batch size. (default: 128)
+        bias1 : float, optional
+            The first beta coefficient passed to the Adam optimizers. (default: 0.5)
+        bias2 : float, optional
+            The second beta coefficient passed to the Adam optimizers. (default: 0.999)
+        """
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.latent_size = latent_size
+        self.num_channel = num_channel
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
+        self.bias1 = bias1
+        self.bias2 = bias2
+        self.criterion = nn.BCELoss()
+        self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)  # Fixed noise (e.g. for sampling validation images)
+        self.save_hyperparameters()
+
+        self.generator = Generator(
+            latent_size=self.latent_size, channels=self.num_channel, hidden_size=self.hidden_size
+        )
+        self.generator.apply(self.weights_init)
+
+        self.discriminator = Discriminator(channels=self.num_channel, hidden_size=self.hidden_size)
+        self.discriminator.apply(self.weights_init)
+
+        # self.model = InceptionV3()  # For FID metric
+
+    def weights_init(self, m: nn.Module) -> None:
+        """
+        Initializes the weights, following the DCGAN initialization scheme.
+
+        Parameters
+        ----------
+        m : nn.Module
+            The module to initialize.
+        """
+        classname = m.__class__.__name__
+        if classname.find("Conv") != -1:
+            nn.init.normal_(m.weight.data, 0.0, 0.02)
+        elif classname.find("BatchNorm") != -1:
+            nn.init.normal_(m.weight.data, 1.0, 0.02)
+            nn.init.constant_(m.bias.data, 0)
+
+    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]:
+        """
+        Configures the optimizers.
+
+        Returns
+        -------
+        Tuple[List[torch.optim.Optimizer], List]
+            The optimizers and the LR schedulers.
+        """
+        opt_generator = torch.optim.Adam(
+            self.generator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
+        )
+        opt_discriminator = torch.optim.Adam(
+            self.discriminator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
+        )
+        return [opt_generator, opt_discriminator], []
+
+    def forward(self, z: torch.Tensor) -> torch.Tensor:
+        """
+        Forward propagation.
+
+        Parameters
+        ----------
+        z : torch.Tensor
+            The latent vector.
+
+        Returns
+        -------
+        torch.Tensor
+            The generated image batch.
+        """
+        return self.generator(z)
+
+    def training_step(
+        self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int
+    ) -> Dict:
+        """
+        Training step.
+
+        Parameters
+        ----------
+        batch : Dict[str, torch.Tensor]
+            The batch, with the real images stored under the "tr_image" key.
+        batch_idx : int
+            The batch index.
+        optimizer_idx : int
+            The optimizer index (0: generator, 1: discriminator).
+
+        Returns
+        -------
+        Dict
+            The training loss.
+        """
+        real_images = batch["tr_image"]
+
+        if optimizer_idx == 0:  # Only train the generator
+            fake_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1)
+            fake_random_noise = fake_random_noise.type_as(real_images)
+            fake_images = self(fake_random_noise)
+
+            # Try to fool the discriminator
+            preds = self.discriminator(fake_images)
+            loss = self.criterion(preds, torch.ones_like(preds))
+            self.log("g_loss", loss, on_step=False, on_epoch=True)
+
+            tqdm_dict = {"g_loss": loss}
+            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
+            return output
+
+        elif optimizer_idx == 1:  # Only train the discriminator
+            real_preds = self.discriminator(real_images)
+            real_loss = self.criterion(real_preds, torch.ones_like(real_preds))
+
+            # Generate fake images
+            real_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1)
+            real_random_noise = real_random_noise.type_as(real_images)
+            fake_images = self(real_random_noise)
+
+            # Pass fake images through the discriminator
+            fake_preds = self.discriminator(fake_images)
+            fake_loss = self.criterion(fake_preds, torch.zeros_like(fake_preds))
+
+            # Update discriminator weights
+            loss = real_loss + fake_loss
+            self.log("d_loss", loss, on_step=False, on_epoch=True)
+
+            tqdm_dict = {"d_loss": loss}
+            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
+            return output
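
For context, here is a minimal training sketch showing one way this module could be wired up. It is not part of the commit: RandomDocDataset is a hypothetical stand-in for a real document-image dataset, and the sketch only assumes what model.py itself requires, namely batches that are dictionaries with a "tr_image" key holding (3, 64, 64) tensors scaled to the Tanh range [-1, 1].

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset

from model import DocuGAN


class RandomDocDataset(Dataset):
    """Hypothetical stand-in; replace with a real document-image dataset."""

    def __len__(self):
        return 1024

    def __getitem__(self, idx):
        # Random images in [-1, 1], matching the generator's Tanh output range.
        return {"tr_image": torch.rand(3, 64, 64) * 2 - 1}


if __name__ == "__main__":
    batch_size = 128
    # drop_last=True keeps every batch at exactly batch_size, since training_step
    # always samples `self.batch_size` noise vectors.
    loader = DataLoader(RandomDocDataset(), batch_size=batch_size, shuffle=True, drop_last=True)

    model = DocuGAN(batch_size=batch_size)  # defaults: latent_size=128, 3-channel 64x64 images
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model, loader)

Note that the training_step(..., optimizer_idx) signature follows the pre-2.0 PyTorch Lightning API, where the two optimizers returned by configure_optimizers are stepped automatically; newer Lightning releases would require manual optimization for this two-optimizer setup.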