add new demo interface

#2
by nmud19 - opened
app.py CHANGED
@@ -1,11 +1,3 @@
1
- # Hack for spaces
2
- import os
3
-
4
- os.system("pip uninstall -y gradio")
5
- os.system("pip install -r requirements.txt")
6
-
7
- # Real code begins
8
-
9
  from typing import Union, List
10
 
11
  import gradio as gr
@@ -13,22 +5,21 @@ import matplotlib
13
  import torch
14
  from pytorch_lightning.utilities.types import EPOCH_OUTPUT
15
 
16
- matplotlib.use("Agg")
17
  import numpy as np
18
  from PIL import Image
19
  import albumentations as A
20
  import albumentations.pytorch as al_pytorch
21
  import torchvision
22
  from pl_bolts.models.gans import Pix2Pix
23
- from pl_bolts.models.gans.pix2pix.components import PatchGAN
24
- import torchvision.models as models
25
 
26
  """ Class """
27
 
28
 
29
  class OverpoweredPix2Pix(Pix2Pix):
 
30
  def validation_step(self, batch, batch_idx):
31
- """Validation step"""
32
  real, condition = batch
33
  with torch.no_grad():
34
  loss = self._disc_step(real, condition)
@@ -37,56 +28,33 @@ class OverpoweredPix2Pix(Pix2Pix):
37
  loss = self._gen_step(real, condition)
38
  self.log("val_generator_loss", loss)
39
 
40
- return {"sketch": real, "colour": condition}
 
 
 
41
 
42
- def validation_epoch_end(
43
- self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
44
- ) -> None:
45
- sketch = outputs[0]["sketch"]
46
- colour = outputs[0]["colour"]
47
  with torch.no_grad():
48
  gen_coloured = self.gen(sketch)
49
  grid_image = torchvision.utils.make_grid(
50
  [
51
- sketch[0],
52
- colour[0],
53
- gen_coloured[0],
54
  ],
55
- normalize=True,
56
  )
57
  self.logger.experiment.add_image(
58
- f"Image Grid {str(self.current_epoch)}", grid_image, self.current_epoch
59
- )
60
-
61
-
62
- class PatchGanChanged(OverpoweredPix2Pix):
63
- def __init__(self, in_channels, out_channels):
64
- super(PatchGanChanged, self).__init__(
65
- in_channels=in_channels, out_channels=out_channels
66
- )
67
- self.patch_gan = self.get_dense_PatchGAN(self.patch_gan)
68
-
69
- @staticmethod
70
- def get_dense_PatchGAN(disc: PatchGAN) -> PatchGAN:
71
- """Add final layer to gan"""
72
- disc.final = torch.nn.Sequential(
73
- disc.final,
74
- torch.nn.Flatten(),
75
- torch.nn.Linear(16 * 16, 1),
76
  )
77
- return disc
78
 
79
 
80
  """ Load the model """
81
  # train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
82
- train_64_val_16_plbolts_model_chkpt = (
83
- "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
84
- )
85
- train_16_val_1_plbolts_model_chkpt = (
86
- "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
87
- )
88
- modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt"
89
-
90
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
91
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
92
 
@@ -102,142 +70,28 @@ train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
102
  )
103
  train_16_val_1_plbolts_model.eval()
104
 
105
- #
106
- modified_patchgan_model = PatchGanChanged.load_from_checkpoint(modified_patchgan_chkpt)
107
- modified_patchgan_model.eval()
108
-
109
-
110
- # Create new class
111
- class OverpoweredPix2Pix(Pix2Pix):
112
- def __init__(self, in_channels, out_channels):
113
- super(OverpoweredPix2Pix, self).__init__(
114
- in_channels=in_channels, out_channels=out_channels
115
- )
116
- self._create_inception_score()
117
-
118
- def _gen_step(self, real_images, conditioned_images):
119
- # Pix2Pix has adversarial and a reconstruction loss
120
- # First calculate the adversarial loss
121
- fake_images = self.gen(conditioned_images)
122
- disc_logits = self.patch_gan(fake_images, conditioned_images)
123
- adversarial_loss = self.adversarial_criterion(
124
- disc_logits, torch.ones_like(disc_logits)
125
- )
126
-
127
- # calculate reconstruction loss
128
- recon_loss = self.recon_criterion(fake_images, real_images)
129
- lambda_recon = self.hparams.lambda_recon
130
-
131
- # calculate cosine similarity
132
- representations_real = self.feature_extractor(real_images).flatten(1)
133
- representations_fake = self.feature_extractor(fake_images).flatten(1)
134
- similarity_score_list = self.cosine_similarity(
135
- representations_real, representations_fake
136
- )
137
- cosine_sim = sum(similarity_score_list) / len(similarity_score_list)
138
-
139
- self.log("Gen Cosine Sim Loss ", 1 - cosine_sim.cpu().detach().numpy())
140
- # print(adversarial_loss,1-cosine_sim, lambda_recon, recon_loss, )
141
-
142
- return (
143
- (adversarial_loss)
144
- + (lambda_recon * recon_loss)
145
- + (lambda_recon * (1 - cosine_sim))
146
- )
147
-
148
- def _create_inception_score(self):
149
- # init a pretrained resnet
150
- backbone = models.resnet50(pretrained=True)
151
- num_filters = backbone.fc.in_features
152
- layers = list(backbone.children())[:-1]
153
- self.feature_extractor = torch.nn.Sequential(*layers)
154
- self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
155
-
156
- def validation_step(self, batch, batch_idx):
157
- """Validation step"""
158
- real, condition = batch
159
- with torch.no_grad():
160
- disc_loss = self._disc_step(real, condition)
161
- self.log("Valid PatchGAN Loss", disc_loss)
162
-
163
- gan_loss = self._gen_step(real, condition)
164
- self.log("Valid Generator Loss", gan_loss)
165
-
166
- #
167
- fake_images = self.gen(condition)
168
- representations_real = self.feature_extractor(real).flatten(1)
169
- representations_fake = self.feature_extractor(fake_images).flatten(1)
170
- similarity_score_list = self.cosine_similarity(
171
- representations_real, representations_fake
172
- )
173
- cosine_sim = sum(similarity_score_list) / len(similarity_score_list)
174
-
175
- self.log("Valid Cosine Sim", cosine_sim)
176
-
177
- return {"sketch": condition, "colour": real}
178
-
179
- def validation_epoch_end(
180
- self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
181
- ) -> None:
182
- sketch = outputs[0]["sketch"]
183
- colour = outputs[0]["colour"]
184
- self.feature_extractor.eval()
185
- with torch.no_grad():
186
- gen_coloured = self.gen(sketch)
187
- representations_gen = self.feature_extractor(gen_coloured).flatten(1)
188
- representations_fake = self.feature_extractor(colour).flatten(1)
189
-
190
- similarity_score_list = self.cosine_similarity(
191
- representations_gen, representations_fake
192
- )
193
- similarity_score = sum(similarity_score_list) / len(similarity_score_list)
194
-
195
- grid_image = torchvision.utils.make_grid(
196
- [
197
- sketch[0],
198
- colour[0],
199
- gen_coloured[0],
200
- ],
201
- normalize=True,
202
- )
203
- self.logger.experiment.add_image(
204
- f"Image Grid {str(self.current_epoch)} __ {str(similarity_score)} ",
205
- grid_image,
206
- self.current_epoch,
207
- )
208
-
209
-
210
- cosine_sim_model_chkpt = "model/lightning_bolts_model/cosine_sim_model.ckpt"
211
-
212
- cosine_sim_model = OverpoweredPix2Pix.load_from_checkpoint(cosine_sim_model_chkpt)
213
- cosine_sim_model.eval()
214
-
215
 
216
  def predict(img: Image, type_of_model: str):
217
- """Create predictions"""
218
  # transform img
219
  image = np.asarray(img)
220
  # use on inference
221
- inference_transform = A.Compose(
222
- [
223
- A.Resize(width=256, height=256),
224
- A.Normalize(
225
- mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0
226
- ),
227
- al_pytorch.ToTensorV2(),
228
- ]
229
- )
230
- inference_img = inference_transform(image=image)["image"].unsqueeze(0)
231
 
232
  # Choose model
233
  if type_of_model == "train batch size 16, val batch size 1":
234
  model = train_16_val_1_plbolts_model
235
  elif type_of_model == "train batch size 64, val batch size 16":
236
  model = train_64_val_16_plbolts_model
237
- elif type_of_model == "cosine similarity":
238
- model = cosine_sim_model
239
  else:
240
- model = modified_patchgan_model
241
 
242
  with torch.no_grad():
243
  result = model.gen(inference_img)
@@ -253,27 +107,13 @@ def predict2(img: Image):
253
  return predict(img=img, type_of_model="train batch size 64, val batch size 16")
254
 
255
 
256
- def predict3(img: Image):
257
- return predict(
258
- img=img,
259
- type_of_model="train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
260
- )
261
-
262
-
263
- def predict4(img: Image):
264
- return predict(
265
- img=img,
266
- type_of_model="cosine similarity",
267
- )
268
-
269
-
270
  model_input = gr.inputs.Radio(
271
  [
272
  "train batch size 16, val batch size 1",
273
  "train batch size 64, val batch size 16",
274
  "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
275
  ],
276
- label="Type of Pix2Pix model to use : ",
277
  )
278
  image_input = gr.inputs.Image(type="pil")
279
  img_examples = [
@@ -285,26 +125,20 @@ img_examples = [
285
  "examples/thesis6.png",
286
  ]
287
 
 
288
  with gr.Blocks() as demo:
289
  gr.Markdown(" # Colour your sketches!")
290
  gr.Markdown(" ## Description :")
291
- gr.Markdown(" There are 4 Pix2Pix models in this example:")
292
  gr.Markdown(" 1. Training batch size is 16 , validation is 1")
293
  gr.Markdown(" 2. Training batch size is 64 , validation is 16")
294
- gr.Markdown(
295
- " 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
296
- "training batch size is 64 , validation is 16"
297
- )
298
- gr.Markdown(
299
- " 4. cosine similarity is also added as a metric in this experiment for the generator. "
300
- )
301
  with gr.Tabs():
302
  with gr.TabItem("tr_16_val_1"):
303
  with gr.Row():
304
  image_input1 = gr.inputs.Image(type="pil")
305
- image_output1 = gr.outputs.Image(
306
- type="pil",
307
- )
308
  colour_1 = gr.Button("Colour it!")
309
  gr.Examples(
310
  examples=img_examples,
@@ -315,9 +149,7 @@ with gr.Blocks() as demo:
315
  with gr.TabItem("tr_64_val_14"):
316
  with gr.Row():
317
  image_input2 = gr.inputs.Image(type="pil")
318
- image_output2 = gr.outputs.Image(
319
- type="pil",
320
- )
321
  colour_2 = gr.Button("Colour it!")
322
  with gr.Row():
323
  gr.Examples(
@@ -326,34 +158,6 @@ with gr.Blocks() as demo:
326
  outputs=image_output2,
327
  fn=predict2,
328
  )
329
- with gr.TabItem("Single Value Discriminator"):
330
- with gr.Row():
331
- image_input3 = gr.inputs.Image(type="pil")
332
- image_output3 = gr.outputs.Image(
333
- type="pil",
334
- )
335
- colour_3 = gr.Button("Colour it!")
336
- with gr.Row():
337
- gr.Examples(
338
- examples=img_examples,
339
- inputs=image_input3,
340
- outputs=image_output3,
341
- fn=predict3,
342
- )
343
- with gr.TabItem("Cosine similarity loss"):
344
- with gr.Row():
345
- image_input4 = gr.inputs.Image(type="pil")
346
- image_output4 = gr.outputs.Image(
347
- type="pil",
348
- )
349
- colour_4 = gr.Button("Colour it!")
350
- with gr.Row():
351
- gr.Examples(
352
- examples=img_examples,
353
- inputs=image_input4,
354
- outputs=image_output4,
355
- fn=predict4,
356
- )
357
 
358
  colour_1.click(
359
  fn=predict1,
@@ -365,16 +169,6 @@ with gr.Blocks() as demo:
365
  inputs=image_input2,
366
  outputs=image_output2,
367
  )
368
- colour_3.click(
369
- fn=predict3,
370
- inputs=image_input3,
371
- outputs=image_output3,
372
- )
373
- colour_4.click(
374
- fn=predict4,
375
- inputs=image_input4,
376
- outputs=image_output4,
377
- )
378
 
379
  demo.title = "Colour your sketches!"
380
  demo.launch()
 
1
  from typing import Union, List
2
 
3
  import gradio as gr
 
5
  import torch
6
  from pytorch_lightning.utilities.types import EPOCH_OUTPUT
7
 
8
+ matplotlib.use('Agg')
9
  import numpy as np
10
  from PIL import Image
11
  import albumentations as A
12
  import albumentations.pytorch as al_pytorch
13
  import torchvision
14
  from pl_bolts.models.gans import Pix2Pix
 
 
15
 
16
  """ Class """
17
 
18
 
19
  class OverpoweredPix2Pix(Pix2Pix):
20
+
21
  def validation_step(self, batch, batch_idx):
22
+ """ Validation step """
23
  real, condition = batch
24
  with torch.no_grad():
25
  loss = self._disc_step(real, condition)
 
28
  loss = self._gen_step(real, condition)
29
  self.log("val_generator_loss", loss)
30
 
31
+ return {
32
+ 'sketch': real,
33
+ 'colour': condition
34
+ }
35
 
36
+ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
37
+ sketch = outputs[0]['sketch']
38
+ colour = outputs[0]['colour']
 
 
39
  with torch.no_grad():
40
  gen_coloured = self.gen(sketch)
41
  grid_image = torchvision.utils.make_grid(
42
  [
43
+ sketch[0], colour[0], gen_coloured[0],
 
 
44
  ],
45
+ normalize=True
46
  )
47
  self.logger.experiment.add_image(
48
+ f'Image Grid {str(self.current_epoch)}',
49
+ grid_image,
50
+ self.current_epoch
51
  )
 
52
 
53
 
54
  """ Load the model """
55
  # train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
56
+ train_64_val_16_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
57
+ train_16_val_1_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
58
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
59
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
60
 
 
70
  )
71
  train_16_val_1_plbolts_model.eval()
72
 
73
 
74
  def predict(img: Image, type_of_model: str):
75
+ """ Create predictions """
76
  # transform img
77
  image = np.asarray(img)
78
  # use on inference
79
+ inference_transform = A.Compose([
80
+ A.Resize(width=256, height=256),
81
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
82
+ al_pytorch.ToTensorV2(),
83
+ ])
84
+ inference_img = inference_transform(
85
+ image=image
86
+ )['image'].unsqueeze(0)
 
 
87
 
88
  # Choose model
89
  if type_of_model == "train batch size 16, val batch size 1":
90
  model = train_16_val_1_plbolts_model
91
  elif type_of_model == "train batch size 64, val batch size 16":
92
  model = train_64_val_16_plbolts_model
 
 
93
  else:
94
+ raise Exception("NOT YET SUPPORTED")
95
 
96
  with torch.no_grad():
97
  result = model.gen(inference_img)
 
107
  return predict(img=img, type_of_model="train batch size 64, val batch size 16")
108
 
109
 
110
  model_input = gr.inputs.Radio(
111
  [
112
  "train batch size 16, val batch size 1",
113
  "train batch size 64, val batch size 16",
114
  "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
115
  ],
116
+ label="Type of Pix2Pix model to use : "
117
  )
118
  image_input = gr.inputs.Image(type="pil")
119
  img_examples = [
 
125
  "examples/thesis6.png",
126
  ]
127
 
128
+
129
  with gr.Blocks() as demo:
130
  gr.Markdown(" # Colour your sketches!")
131
  gr.Markdown(" ## Description :")
132
+ gr.Markdown(" There are three Pix2Pix models in this example:")
133
  gr.Markdown(" 1. Training batch size is 16 , validation is 1")
134
  gr.Markdown(" 2. Training batch size is 64 , validation is 16")
135
+ gr.Markdown(" 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
136
+ "training batch size is 64 , validation is 16")
137
  with gr.Tabs():
138
  with gr.TabItem("tr_16_val_1"):
139
  with gr.Row():
140
  image_input1 = gr.inputs.Image(type="pil")
141
+ image_output1 = gr.outputs.Image(type="pil", )
 
 
142
  colour_1 = gr.Button("Colour it!")
143
  gr.Examples(
144
  examples=img_examples,
 
149
  with gr.TabItem("tr_64_val_14"):
150
  with gr.Row():
151
  image_input2 = gr.inputs.Image(type="pil")
152
+ image_output2 = gr.outputs.Image(type="pil", )
 
 
153
  colour_2 = gr.Button("Colour it!")
154
  with gr.Row():
155
  gr.Examples(
 
158
  outputs=image_output2,
159
  fn=predict2,
160
  )
161
 
162
  colour_1.click(
163
  fn=predict1,
 
169
  inputs=image_input2,
170
  outputs=image_output2,
171
  )
172
 
173
  demo.title = "Colour your sketches!"
174
  demo.launch()
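
For a quick local sanity check of the new, slimmed-down inference path, something along these lines should work outside of Gradio. This is a sketch, not part of the PR: it assumes the `epoch=99-step=89000.ckpt` checkpoint and the `examples/thesis6.png` sample referenced above are present, and that loading with the base `Pix2Pix` class is equivalent here because `OverpoweredPix2Pix` only overrides validation hooks.

```python
# Sketch: run the generator from a checkpoint on one example sketch (mirrors predict() above).
import albumentations as A
import albumentations.pytorch as al_pytorch
import numpy as np
import torch
from PIL import Image
from pl_bolts.models.gans import Pix2Pix

ckpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
model = Pix2Pix.load_from_checkpoint(ckpt)  # assumption: base class is enough for inference
model.eval()

inference_transform = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    al_pytorch.ToTensorV2(),
])

image = np.asarray(Image.open("examples/thesis6.png").convert("RGB"))
batch = inference_transform(image=image)["image"].unsqueeze(0)

with torch.no_grad():
    coloured = model.gen(batch)
print(coloured.shape)  # expected: torch.Size([1, 3, 256, 256])
```
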
app/__init__.py ADDED
File without changes
app/config.py ADDED
@@ -0,0 +1,3 @@
1
+ num_workers = 4
2
+ train_batch_size = 32
3
+ val_batch_size = 1
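
These are plain module-level constants; `AnimeSketchDataModule` below uses the two batch sizes as its defaults. A minimal usage sketch:

```python
from app import config

# batch-size defaults are picked up by AnimeSketchDataModule; num_workers is not used there yet
print(config.num_workers, config.train_batch_size, config.val_batch_size)  # 4 32 1
```
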
app/consume_data/__init__.py ADDED
File without changes
app/consume_data/consume_data.py ADDED
@@ -0,0 +1,165 @@
1
+ import torch
2
+ import os
3
+ from typing import List, Optional
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from torchvision import transforms
7
+ import albumentations as A
8
+ import numpy as np
9
+ import albumentations.pytorch as al_pytorch
10
+ from typing import Dict, Tuple
11
+ from app import config
12
+ import pytorch_lightning as pl
13
+
14
+ torch.__version__
15
+
16
+
17
+ class AnimeDataset(torch.utils.data.Dataset):
18
+ """ Sketchs and Colored Image dataset """
19
+
20
+ def __init__(self, imgs_path: List[str], transforms: transforms.Compose) -> None:
21
+ """ Set the transforms and file path """
22
+ self.list_files = imgs_path
23
+ self.transform = transforms
24
+
25
+ def __len__(self) -> int:
26
+ """ Should return number of files """
27
+ return len(self.list_files)
28
+
29
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
30
+ """ Get image and mask by index """
31
+ # read image file
32
+ img_file = self.list_files[index]
33
+ # img_path = os.path.join(self.root_dir, img_file)
34
+ image = np.array(Image.open(img_file))
35
+
36
+ # split the image into sketchs and colored_imgs: the right half is the sketch, the left half is the colored image
37
+ sketchs = image[:, image.shape[1] // 2:, :]
38
+ colored_imgs = image[:, :image.shape[1] // 2, :]
39
+
40
+ # data augmentation on both sketchs and colored_imgs
41
+ augmentations = self.transform.both_transform(image=sketchs, image0=colored_imgs)
42
+ sketchs, colored_imgs = augmentations['image'], augmentations['image0']
43
+
44
+ # conduct data augmentation respectively
45
+ sketchs = self.transform.transform_only_input(image=sketchs)['image']
46
+ colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
47
+ return sketchs, colored_imgs
48
+
49
+
50
+ # Data Augmentation
51
+ class Transforms:
52
+ def __init__(self):
53
+ # use on both sketchs and colored images
54
+ self.both_transform = A.Compose([
55
+ A.Resize(width=256, height=256),
56
+ A.HorizontalFlip(p=.5)
57
+ ], additional_targets={'image0': 'image'})
58
+
59
+ # use on sketchs only
60
+ self.transform_only_input = A.Compose([
61
+ A.ColorJitter(p=.1),
62
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
63
+ al_pytorch.ToTensorV2(),
64
+ ])
65
+
66
+ # use on colored images
67
+ self.transform_only_mask = A.Compose([
68
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
69
+ al_pytorch.ToTensorV2(),
70
+ ])
71
+
72
+
73
+ class Transforms_v1:
74
+ """ Class to hold transforms """
75
+
76
+ def __init__(self):
77
+ # use on both sketchs and colored images
78
+ self.resize_572 = A.Compose([
79
+ A.Resize(width=572, height=572)
80
+ ])
81
+
82
+ self.resize_388 = A.Compose([
83
+ A.Resize(width=388, height=388)
84
+ ])
85
+
86
+ self.resize_256 = A.Compose([
87
+ A.Resize(width=256, height=256)
88
+ ])
89
+
90
+ # use on sketchs only
91
+ self.transform_only_input = A.Compose([
92
+ # A.ColorJitter(p=.1),
93
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
94
+ al_pytorch.ToTensorV2(),
95
+ ])
96
+
97
+ # use on colored images
98
+ self.transform_only_mask = A.Compose([
99
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
100
+ al_pytorch.ToTensorV2(),
101
+ ])
102
+
103
+
104
+ class AnimeSketchDataModule(pl.LightningDataModule):
105
+ """ Class to hold the Anime sketch Data"""
106
+
107
+ def __init__(
108
+ self,
109
+ data_dir: str,
110
+ train_folder_name: str = "train/",
111
+ val_folder_name: str = "val/",
112
+ train_batch_size: int = config.train_batch_size,
113
+ val_batch_size: int = config.val_batch_size,
114
+ train_num_images: int = 0,
115
+ val_num_images: int = 0,
116
+ ):
117
+ super().__init__()
118
+ self.val_dataset = None
119
+ self.train_dataset = None
120
+ self.data_dir: str = data_dir
121
+ # Set train and val images folder
122
+ train_path: str = f"{self.data_dir}{train_folder_name}/"
123
+ train_images: List[str] = [f"{train_path}{x}" for x in os.listdir(train_path)]
124
+ val_path: str = f"{self.data_dir}{val_folder_name}"
125
+ val_images: List[str] = [f"{val_path}{x}" for x in os.listdir(val_path)]
126
+ #
127
+ self.train_images = train_images[:train_num_images] if train_num_images else train_images
128
+ self.val_images = val_images[:val_num_images] if val_num_images else val_images
129
+ #
130
+ self.train_batch_size = train_batch_size
131
+ self.val_batch_size = val_batch_size
132
+
133
+ def set_datasets(self) -> None:
134
+ """ Get the train and test datasets """
135
+ self.train_dataset = AnimeDataset(
136
+ imgs_path=self.train_images,
137
+ transforms=Transforms()
138
+ )
139
+ self.val_dataset = AnimeDataset(
140
+ imgs_path=self.val_images,
141
+ transforms=Transforms()
142
+ )
143
+ print("The train test dataset lengths are : ", len(self.train_dataset), len(self.val_dataset))
144
+ return None
145
+
146
+ def setup(self, stage: Optional[str] = None) -> None:
147
+ self.set_datasets()
148
+
149
+ def train_dataloader(self):
150
+ return torch.utils.data.DataLoader(
151
+ self.train_dataset,
152
+ batch_size=self.train_batch_size,
153
+ shuffle=False,
154
+ num_workers=2,
155
+ pin_memory=True
156
+ )
157
+
158
+ def val_dataloader(self):
159
+ return torch.utils.data.DataLoader(
160
+ self.val_dataset,
161
+ batch_size=self.val_batch_size,
162
+ shuffle=False,
163
+ num_workers=2,
164
+ pin_memory=True
165
+ )
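
A minimal sketch of how this data module is meant to be used; the `dataset/` path is an assumption and must contain `train/` and `val/` subfolders of the paired side-by-side sketch|colour images the dataset class expects:

```python
from app.consume_data.consume_data import AnimeSketchDataModule

# hypothetical local path: dataset/train/ and dataset/val/ hold the paired images
dm = AnimeSketchDataModule(data_dir="dataset/", train_batch_size=16, val_batch_size=1)
dm.setup()

sketches, colours = next(iter(dm.train_dataloader()))
print(sketches.shape, colours.shape)  # e.g. torch.Size([16, 3, 256, 256]) for both
```
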
app/data.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import os
3
+ from typing import List
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from torchvision import transforms
7
+ import albumentations as A
8
+ import numpy as np
9
+ import albumentations.pytorch as al_pytorch
10
+ from typing import Dict, Tuple
11
+
12
+
13
+ class AnimeDataset(torch.utils.data.Dataset):
14
+ """ Sketchs and Colored Image dataset """
15
+
16
+ def __init__(self, imgs_path: List[str], transforms: transforms.Compose) -> None:
17
+ """ Set the transforms and file path """
18
+ self.list_files = imgs_path
19
+ self.transform = transforms
20
+
21
+ def __len__(self) -> int:
22
+ """ Should return number of files """
23
+ return len(self.list_files)
24
+
25
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
26
+ """ Get image and mask by index """
27
+ # read image file
28
+ img_path = img_file = self.list_files[index]
29
+ image = np.array(Image.open(img_path))
30
+
31
+ # split the image into sketchs and colored_imgs: the right half is the sketch, the left half is the colored image,
32
+ # as laid out in the dataset
33
+ sketchs = image[:, image.shape[1] // 2:, :]
34
+ colored_imgs = image[:, :image.shape[1] // 2, :]
35
+
36
+ # data augmentation on both sketchs and colored_imgs
37
+ augmentations = self.transform.both_transform(image=sketchs, image0=colored_imgs)
38
+ sketchs, colored_imgs = augmentations['image'], augmentations['image0']
39
+
40
+ # conduct data augmentation respectively
41
+ sketchs = self.transform.transform_only_input(image=sketchs)['image']
42
+ colored_imgs = self.transform.transform_only_mask(image=colored_imgs)['image']
43
+ return sketchs, colored_imgs
44
+
45
+
46
+ class Transforms:
47
+ """ Class to hold transforms """
48
+
49
+ def __init__(self):
50
+ # use on both sketchs and colored images
51
+ self.both_transform = A.Compose([
52
+ A.Resize(width=1024, height=1024),
53
+ A.HorizontalFlip(p=.5)
54
+ ],
55
+ additional_targets={'image0': 'image'}
56
+ )
57
+
58
+ # use on sketchs only
59
+ self.transform_only_input = A.Compose([
60
+ # A.ColorJitter(p=.1),
61
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
62
+ al_pytorch.ToTensorV2(),
63
+ ])
64
+
65
+ # use on colored images
66
+ self.transform_only_mask = A.Compose([
67
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
68
+ al_pytorch.ToTensorV2(),
69
+ ])
app/discriminator/__init__.py ADDED
File without changes
app/discriminator/patch_gan.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import albumentations as A
4
+
5
+
6
+ # CNN block used repeatedly in the Discriminator below
7
+ class CNNBlock(nn.Module):
8
+ def __init__(self, in_channels, out_channels, stride=2):
9
+ super().__init__()
10
+ self.conv = nn.Sequential(
11
+ nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode='reflect'),
12
+ nn.BatchNorm2d(out_channels),
13
+ nn.LeakyReLU(0.2)
14
+ )
15
+
16
+ def forward(self, x):
17
+ return self.conv(x)
18
+
19
+
20
+ class PatchGan(torch.nn.Module):
21
+ """ Patch GAN Architecture """
22
+
23
+ @staticmethod
24
+ def create_contracting_block(in_channels: int, out_channels: int):
25
+ """
26
+ Create encoding layer
27
+ :param in_channels:
28
+ :param out_channels:
29
+ :return:
30
+ """
31
+ conv_layer = torch.nn.Sequential(
32
+ torch.nn.Conv2d(
33
+ in_channels=in_channels,
34
+ out_channels=out_channels,
35
+ kernel_size=3,
36
+ padding=1,
37
+ ),
38
+ torch.nn.ReLU(),
39
+ torch.nn.Conv2d(
40
+ in_channels=out_channels,
41
+ out_channels=out_channels,
42
+ kernel_size=3,
43
+ padding=1,
44
+ ),
45
+ torch.nn.ReLU(),
46
+ )
47
+ max_pool = torch.nn.Sequential(
48
+ torch.nn.MaxPool2d(
49
+ stride=2,
50
+ kernel_size=2,
51
+ ),
52
+ )
53
+ layer = torch.nn.Sequential(
54
+ conv_layer,
55
+ max_pool,
56
+ )
57
+ return layer
58
+
59
+ def __init__(self, input_channels: int, hidden_channels: int) -> None:
60
+ super().__init__()
61
+ self.resize_channels = torch.nn.Conv2d(
62
+ in_channels=input_channels,
63
+ out_channels=hidden_channels,
64
+ kernel_size=1,
65
+ )
66
+
67
+ self.enc1 = self.create_contracting_block(
68
+ in_channels=hidden_channels,
69
+ out_channels=hidden_channels * 2
70
+ )
71
+
72
+ self.enc2 = self.create_contracting_block(
73
+ in_channels=hidden_channels * 2,
74
+ out_channels=hidden_channels * 4
75
+ )
76
+
77
+ self.enc3 = self.create_contracting_block(
78
+ in_channels=hidden_channels * 4,
79
+ out_channels=hidden_channels * 8
80
+ )
81
+ self.enc4 = self.create_contracting_block(
82
+ in_channels=hidden_channels * 8,
83
+ out_channels=hidden_channels * 16
84
+ )
85
+
86
+ self.final_layer = torch.nn.Conv2d(
87
+ in_channels=hidden_channels * 16,
88
+ out_channels=1,
89
+ kernel_size=1,
90
+ )
91
+
92
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
93
+ """ Forward patch gan layer """
94
+ inpt = torch.cat([x, y], axis=1)
95
+ resize_img = self.resize_channels(inpt)
96
+ enc1 = self.enc1(resize_img)
97
+ enc2 = self.enc2(enc1)
98
+ enc3 = self.enc3(enc2)
99
+ enc4 = self.enc4(enc3)
100
+ final_layer = self.final_layer(enc4)
101
+ return final_layer
102
+
103
+
104
+ # x, y <- concatenate the gen image and the input image to determin the gen image is real or not
105
+ class Discriminator(nn.Module):
106
+ def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
107
+ super().__init__()
108
+ self.initial = nn.Sequential(
109
+ nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
110
+ nn.LeakyReLU(.2)
111
+ )
112
+
113
+ # save layers into a list
114
+ layers = []
115
+ in_channels = features[0]
116
+ for feature in features[1:]:
117
+ layers.append(
118
+ CNNBlock(
119
+ in_channels,
120
+ feature,
121
+ stride=1 if feature == features[-1] else 2
122
+ ),
123
+ )
124
+ in_channels = feature
125
+
126
+ # append last conv layer
127
+ layers.append(
128
+ nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect')
129
+ )
130
+
131
+ # create a model using the list of layers
132
+ self.model = nn.Sequential(*layers)
133
+
134
+ def forward(self, x, y):
135
+ x = torch.cat([x, y], dim=1)
136
+ x = self.initial(x)
137
+ return self.model(x)
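
The 16*16 patch grid mentioned in the demo description above matches this architecture: each of the four contracting blocks halves a 256x256 pair, so the 1x1 output convolution yields a 16x16 grid of patch logits. A quick shape check (a sketch; the channel counts are illustrative):

```python
import torch

from app.discriminator.patch_gan import PatchGan, Discriminator

sketch = torch.randn(1, 3, 256, 256)
colour = torch.randn(1, 3, 256, 256)

# input_channels must match the concatenated pair (3 + 3 channels here)
patch_gan = PatchGan(input_channels=6, hidden_channels=8)
print(patch_gan(sketch, colour).shape)  # torch.Size([1, 1, 16, 16])

# the alternative Discriminator also returns a grid of per-patch logits
disc = Discriminator(in_channels=3)
print(disc(sketch, colour).shape)
```
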
app/generator/__init__.py ADDED
File without changes
app/generator/unetGen.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from app.generator import unetParts
4
+
5
+
6
+ class UNET(torch.nn.Module):
7
+ """ Implementation of unet """
8
+
9
+ def __init__(
10
+ self,
11
+ ) -> None:
12
+ """
13
+ Create the UNET here
14
+ """
15
+ super().__init__()
16
+ self.enc_layer1: unetParts.EncoderLayer = unetParts.EncoderLayer(
17
+ in_channels=3,
18
+ out_channels=64
19
+ )
20
+ self.enc_layer2: unetParts.EncoderLayer = unetParts.EncoderLayer(
21
+ in_channels=64,
22
+ out_channels=128
23
+ )
24
+ self.enc_layer3: unetParts.EncoderLayer = unetParts.EncoderLayer(
25
+ in_channels=128,
26
+ out_channels=256
27
+ )
28
+ self.enc_layer4: unetParts.EncoderLayer = unetParts.EncoderLayer(
29
+ in_channels=256,
30
+ out_channels=512
31
+ )
32
+ # Middle layer
33
+ self.middle_layer: unetParts.MiddleLayer = unetParts.MiddleLayer(
34
+ in_channels=512,
35
+ out_channels=1024,
36
+ )
37
+ # Decoding layer
38
+ self.dec_layer1: unetParts.DecoderLayer = unetParts.DecoderLayer(
39
+ in_channels=1024,
40
+ out_channels=512,
41
+ )
42
+ self.dec_layer2: unetParts.DecoderLayer = unetParts.DecoderLayer(
43
+ in_channels=512,
44
+ out_channels=256,
45
+ )
46
+
47
+ self.dec_layer3: unetParts.DecoderLayer = unetParts.DecoderLayer(
48
+ in_channels=256,
49
+ out_channels=128,
50
+ )
51
+ self.dec_layer4: unetParts.DecoderLayer = unetParts.DecoderLayer(
52
+ in_channels=128,
53
+ out_channels=64,
54
+ )
55
+ self.final_layer: torch.nn.Conv2d = torch.nn.Conv2d(
56
+ in_channels=64,
57
+ out_channels=3,
58
+ kernel_size=1
59
+ )
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ """
63
+ Forward function
64
+ :param x:
65
+ :return:
66
+ """
67
+ # enc layers
68
+ enc1, conv1 = self.enc_layer1(x=x) # 64
69
+ enc2, conv2 = self.enc_layer2(x=enc1) # 128
70
+ enc3, conv3 = self.enc_layer3(x=enc2) # 256
71
+ enc4, conv4 = self.enc_layer4(x=enc3) # 512
72
+ # middle layers
73
+ mid = self.middle_layer(x=enc4) # 1024
74
+ # expanding layers
75
+ # 512
76
+ dec1 = self.dec_layer1(
77
+ input_layer=mid,
78
+ cropping_layer=conv4,
79
+ )
80
+ # 256
81
+ dec2 = self.dec_layer2(
82
+ input_layer=dec1,
83
+ cropping_layer=conv3,
84
+ )
85
+ # 128
86
+ dec3 = self.dec_layer3(
87
+ input_layer=dec2,
88
+ cropping_layer=conv2,
89
+ )
90
+ # 64
91
+ dec4 = self.dec_layer4(
92
+ input_layer=dec3,
93
+ cropping_layer=conv1,
94
+ )
95
+ # 3
96
+ fin_layer = self.final_layer(
97
+ dec4,
98
+ )
99
+ # Interpolate to retain size
100
+ fin_layer_resized = torch.nn.functional.interpolate(fin_layer, 572)
101
+ return fin_layer_resized
102
+
103
+
104
+ class Generator(nn.Module):
105
+ def __init__(self, in_channels=3, features=64):
106
+ super().__init__()
107
+ # Encoder
108
+ self.initial_down = nn.Sequential(
109
+ nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode='reflect'),
110
+ nn.LeakyReLU(.2),
111
+ )
112
+ self.down1 = Block(features, features * 2, down=True, act='leaky', use_dropout=False) # 64
113
+ self.down2 = Block(features * 2, features * 4, down=True, act='leaky', use_dropout=False) # 32
114
+ self.down3 = Block(features * 4, features * 8, down=True, act='leaky', use_dropout=False) # 16
115
+ self.down4 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False) # 8
116
+ self.down5 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False) # 4
117
+ self.down6 = Block(features * 8, features * 8, down=True, act='leaky', use_dropout=False) # 2
118
+ self.bottleneck = nn.Sequential(
119
+ nn.Conv2d(features * 8, features * 8, 4, 2, 1, padding_mode='reflect'),
120
+ nn.ReLU(), # 1x1
121
+ )
122
+ # Decoder
123
+ self.up1 = Block(features * 8, features * 8, down=False, act='relu', use_dropout=True)
124
+ self.up2 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=True)
125
+ self.up3 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=True)
126
+ self.up4 = Block(features * 8 * 2, features * 8, down=False, act='relu', use_dropout=False)
127
+ self.up5 = Block(features * 8 * 2, features * 4, down=False, act='relu', use_dropout=False)
128
+ self.up6 = Block(features * 4 * 2, features * 2, down=False, act='relu', use_dropout=False)
129
+ self.up7 = Block(features * 2 * 2, features, down=False, act='relu', use_dropout=False)
130
+ self.final_up = nn.Sequential(
131
+ nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
132
+ nn.Tanh()
133
+ )
134
+
135
+ def forward(self, x):
136
+ # Encoder
137
+ d1 = self.initial_down(x)
138
+ d2 = self.down1(d1)
139
+ d3 = self.down2(d2)
140
+ d4 = self.down3(d3)
141
+ d5 = self.down4(d4)
142
+ d6 = self.down5(d5)
143
+ d7 = self.down6(d6)
144
+ bottleneck = self.bottleneck(d7)
145
+
146
+ # Decoder
147
+ u1 = self.up1(bottleneck)
148
+ u2 = self.up2(torch.cat([u1, d7], 1))
149
+ u3 = self.up3(torch.cat([u2, d6], 1))
150
+ u4 = self.up4(torch.cat([u3, d5], 1))
151
+ u5 = self.up5(torch.cat([u4, d4], 1))
152
+ u6 = self.up6(torch.cat([u5, d3], 1))
153
+ u7 = self.up7(torch.cat([u6, d2], 1))
154
+ return self.final_up(torch.cat([u7, d1], 1))
155
+
156
+
157
+ # block will be use repeatly later
158
+ class Block(nn.Module):
159
+ def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
160
+ super().__init__()
161
+ self.conv = nn.Sequential(
162
+ # the block will be use on both encoder (down=True) and decoder (down=False)
163
+ nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode='reflect')
164
+ if down
165
+ else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
166
+ nn.BatchNorm2d(out_channels),
167
+ nn.ReLU() if act == 'relu' else nn.LeakyReLU(.2)
168
+ )
169
+ self.use_dropout = use_dropout
170
+ self.dropout = nn.Dropout(.5)
171
+
172
+ def forward(self, x):
173
+ x = self.conv(x)
174
+ return self.dropout(x) if self.use_dropout else x
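
`Generator` is the standard pix2pix U-Net: eight stride-2 downsampling steps take a 256x256 input to a 1x1 bottleneck, and the mirrored ConvTranspose path with skip connections brings it back to the input resolution, ending in a Tanh. A minimal forward-pass sketch:

```python
import torch

from app.generator.unetGen import Generator

gen = Generator(in_channels=3, features=64)
sketch = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    out = gen(sketch)
print(out.shape)  # torch.Size([1, 3, 256, 256]); Tanh keeps outputs in (-1, 1)
```
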
app/generator/unetParts.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+ from typing import Tuple
3
+
4
+
5
+ class DecoderLayer(torch.nn.Module):
6
+ """Decoder model"""
7
+
8
+ def __init__(self, in_channels: int, out_channels: int):
9
+ super().__init__()
10
+ self.up_sample_layer = torch.nn.Sequential(
11
+ torch.nn.ConvTranspose2d(
12
+ in_channels=in_channels,
13
+ out_channels=out_channels,
14
+ kernel_size=2,
15
+ stride=2,
16
+ bias=False,
17
+ )
18
+ )
19
+ self.conv_layer = EncoderLayer(
20
+ in_channels=in_channels,
21
+ out_channels=out_channels,
22
+ ).conv_layer
23
+
24
+ @staticmethod
25
+ def _get_cropping_shape(previous_layer_shape: torch.Size, current_layer_shape: torch.Size) -> int:
26
+ """ Get the shape to crop """
27
+ return (previous_layer_shape[2] - current_layer_shape[2]) // 2 * -1
28
+
29
+ def forward(
30
+ self,
31
+ input_layer: torch.Tensor,
32
+ cropping_layer: torch.Tensor
33
+ ) -> torch.Tensor:
34
+ """
35
+ Forward function to concatenate and conv the figure
36
+ :param cropping_layer:
37
+ :param input_layer:
38
+ :return:
39
+ """
40
+ input_layer = self.up_sample_layer(input_layer)
41
+
42
+ cropping_shape = self._get_cropping_shape(
43
+ current_layer_shape=input_layer.shape,
44
+ previous_layer_shape=cropping_layer.shape,
45
+ )
46
+
47
+ cropping_layer = torch.nn.functional.pad(
48
+ input=cropping_layer,
49
+ pad=[cropping_shape for _ in range(4)]
50
+ )
51
+ combined_layer = torch.cat(
52
+ tensors=[input_layer, cropping_layer],
53
+ dim=1
54
+ )
55
+ result = self.conv_layer(combined_layer)
56
+ return result
57
+
58
+
59
+ class EncoderLayer(torch.nn.Module):
60
+ """Encoder Layer"""
61
+
62
+ def __init__(self, in_channels: int, out_channels: int) -> None:
63
+ super().__init__()
64
+ self.conv_layer = torch.nn.Sequential(
65
+ torch.nn.Conv2d(
66
+ in_channels=in_channels,
67
+ out_channels=out_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=1,
71
+ ),
72
+ torch.nn.LeakyReLU(),
73
+ torch.nn.Conv2d(
74
+ in_channels=out_channels,
75
+ out_channels=out_channels,
76
+ kernel_size=3,
77
+ stride=2,
78
+ padding=1,
79
+ ),
80
+ torch.nn.LeakyReLU(),
81
+ )
82
+ self.max_pool = torch.nn.Sequential(
83
+ torch.nn.MaxPool2d(2),
84
+ )
85
+ self.layer = torch.nn.Sequential(
86
+ self.conv_layer,
87
+ self.max_pool,
88
+ )
89
+
90
+ def get_conv_layers(self, x: torch.Tensor) -> torch.Tensor:
91
+ """Need to concatenate the layer"""
92
+ return self.conv_layer(x)
93
+
94
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
95
+ """Forward pass to return conv layer and the max pool layer"""
96
+ conv_output: torch.tensor = self.conv_layer(x)
97
+ fin_out: torch.Tensor = self.max_pool(conv_output)
98
+ return fin_out, conv_output
99
+
100
+
101
+ class MiddleLayer(EncoderLayer):
102
+ """Middle layer only"""
103
+
104
+ def forward(self, x: torch.tensor) -> torch.tensor:
105
+ """Forward pass"""
106
+ return self.conv_layer(x)
app/model/__init__.py ADDED
File without changes
app/model/lit_model.py ADDED
@@ -0,0 +1,145 @@
1
+ import matplotlib.pyplot as plt
2
+ import pytorch_lightning as pl
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision
6
+
7
+
8
+ class Pix2PixLitModule(pl.LightningModule):
9
+ """ Lightning Module for pix2pix """
10
+
11
+ @staticmethod
12
+ def _weights_init(m):
13
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
14
+ torch.nn.init.normal_(m.weight, 0.0, 0.02)
15
+ if isinstance(m, nn.BatchNorm2d):
16
+ torch.nn.init.normal_(m.weight, 0.0, 0.02)
17
+ torch.nn.init.constant_(m.bias, 0)
18
+
19
+ def __init__(
20
+ self,
21
+ generator,
22
+ discriminator,
23
+ use_gpu: bool,
24
+ lambda_recon=100
25
+ ):
26
+ super().__init__()
27
+ self.save_hyperparameters()
28
+
29
+ self.gen = generator
30
+ self.disc = discriminator
31
+
32
+ # initialize weights
33
+ self.gen = self.gen.apply(self._weights_init)
34
+ self.disc = self.disc.apply(self._weights_init)
35
+ #
36
+ self.adversarial_criterion = nn.BCEWithLogitsLoss()
37
+ self.recon_criterion = nn.L1Loss()
38
+ self.lambda_l1 = lambda_recon
39
+
40
+ def _gen_step(self, sketch, coloured_sketches):
41
+ # Pix2Pix has adversarial and a reconstruction loss
42
+ # First calculate the adversarial loss
43
+ gen_coloured_sketches = self.gen(sketch)
44
+ # disc_logits = self.disc(gen_coloured_sketches, coloured_sketches)
45
+ disc_logits = self.disc(sketch, gen_coloured_sketches)
46
+ adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))
47
+ # calculate reconstruction loss
48
+ recon_loss = self.recon_criterion(gen_coloured_sketches, coloured_sketches) * self.lambda_l1
49
+ #
50
+ self.log("Gen recon_loss", recon_loss)
51
+ self.log("Gen adversarial_loss", adversarial_loss)
52
+ #
53
+ return adversarial_loss + recon_loss
54
+
55
+ def _disc_step(self, sketch, coloured_sketches):
56
+ gen_coloured_sketches = self.gen(sketch).detach()
57
+ #
58
+ # fake_logits = self.disc(gen_coloured_sketches, coloured_sketches)
59
+ fake_logits = self.disc(sketch, gen_coloured_sketches)
60
+ real_logits = self.disc(sketch, coloured_sketches)
61
+ #
62
+ fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
63
+ real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
64
+ #
65
+ self.log("PatchGAN fake_loss", fake_loss)
66
+ self.log("PatchGAN real_loss", real_loss)
67
+ return (real_loss + fake_loss) / 2
68
+
69
+ def forward(self, x):
70
+ return self.gen(x)
71
+
72
+ def training_step(self, batch, batch_idx, optimizer_idx):
73
+ real, condition = batch
74
+ loss = None
75
+ if optimizer_idx == 0:
76
+ loss = self._disc_step(real, condition)
77
+ self.log("TRAIN_PatchGAN Loss", loss)
78
+ elif optimizer_idx == 1:
79
+ loss = self._gen_step(real, condition)
80
+ self.log("TRAIN_Generator Loss", loss)
81
+ return loss
82
+
83
+ def validation_epoch_end(self, outputs) -> None:
84
+ """ Log the images"""
85
+ sketch = outputs[0]['sketch']
86
+ colour = outputs[0]['colour']
87
+ gen_coloured = self.gen(sketch)
88
+ grid_image = torchvision.utils.make_grid(
89
+ [sketch[0], colour[0], gen_coloured[0]],
90
+ normalize=True
91
+ )
92
+ self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)
93
+ #plt.imshow(grid_image.permute(1, 2, 0))
94
+
95
+ def validation_step(self, batch, batch_idx):
96
+ """ Validation step """
97
+ real, condition = batch
98
+ return {
99
+ 'sketch': real,
100
+ 'colour': condition
101
+ }
102
+
103
+ def configure_optimizers(self, lr=2e-4):
104
+ gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999))
105
+ disc_opt = torch.optim.Adam(self.disc.parameters(), lr=lr, betas=(0.5, 0.999))
106
+ return disc_opt, gen_opt
107
+
108
+ # class EpochInference(pl.Callback):
109
+ # """
110
+ # Callback on each end of training epoch
111
+ # The callback will do inference on test dataloader based on corresponding checkpoints
112
+ # The results will be saved as an image with 4-rows:
113
+ # 1 - Input image e.g. grayscale edged input
114
+ # 2 - Ground-truth
115
+ # 3 - Single inference
116
+ # 4 - Mean of hundred accumulated inference
117
+ # Note that the inference have a noise factor that will generate different output on each execution
118
+ # """
119
+ #
120
+ # def __init__(self, dataloader, use_gpu: bool, *args, **kwargs):
121
+ # super().__init__(*args, **kwargs)
122
+ # self.dataloader = dataloader
123
+ # self.use_gpu = use_gpu
124
+ #
125
+ # def on_train_epoch_end(self, trainer, pl_module):
126
+ # super().on_train_epoch_end(trainer, pl_module)
127
+ # data = next(iter(self.dataloader))
128
+ # image, target = data
129
+ # if self.use_gpu:
130
+ # image = image.cuda()
131
+ # target = target.cuda()
132
+ # with torch.no_grad():
133
+ # # Take average of multiple inference as there is a random noise
134
+ # # Single
135
+ # reconstruction_init = pl_module(image)
136
+ # reconstruction_init = torch.clip(reconstruction_init, 0, 1)
137
+ # # # Mean
138
+ # # reconstruction_mean = torch.stack([pl_module(image) for _ in range(10)])
139
+ # # reconstruction_mean = torch.clip(reconstruction_mean, 0, 1)
140
+ # # reconstruction_mean = torch.mean(reconstruction_mean, dim=0)
141
+ # # Grayscale 1-D to 3-D
142
+ # # image = torch.stack([image for _ in range(3)], dim=1)
143
+ # # image = torch.squeeze(image)
144
+ # grid_image = torchvision.utils.make_grid([image[0], target[0], reconstruction_init[0]])
145
+ # torchvision.utils.save_image(grid_image, fp=f'{trainer.default_root_dir}/epoch-{trainer.current_epoch:04}.png')
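
A sketch of how the pieces added in this PR could be wired together for training; the dataset path and trainer settings are illustrative, not part of the PR:

```python
import pytorch_lightning as pl

from app.consume_data.consume_data import AnimeSketchDataModule
from app.discriminator.patch_gan import Discriminator
from app.generator.unetGen import Generator
from app.model.lit_model import Pix2PixLitModule

# hypothetical dataset location with train/ and val/ subfolders of paired images
dm = AnimeSketchDataModule(data_dir="dataset/")

module = Pix2PixLitModule(
    generator=Generator(in_channels=3, features=64),
    discriminator=Discriminator(in_channels=3),
    use_gpu=False,
    lambda_recon=100,
)

# PL calls training_step once per optimizer (discriminator, then generator) on each batch
trainer = pl.Trainer(max_epochs=1)
trainer.fit(module, datamodule=dm)
```
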
app/scratch.py ADDED
@@ -0,0 +1,34 @@
1
+
2
+ class GANInference:
3
+ def __init__(
4
+ self,
5
+ model: Pix2PixLitModule,
6
+ img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
7
+ ) -> None:
8
+ self.img_file = img_file
9
+ self.model = model
10
+
11
+ def _get_image_from_path(self) -> torch.Tensor:
12
+ """ gets the tensor from filepath """
13
+ image = np.array(Image.open(self.img_file))
14
+ # use on inference
15
+ inference_transform = A.Compose([
16
+ A.Resize(width=256, height=256),
17
+ A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
18
+ al_pytorch.ToTensorV2(),
19
+ ])
20
+ inference_img = inference_transform(image=image)['image'].unsqueeze(0)
21
+ return inference_img
22
+
23
+ def _create_grid(self, result: torch.Tensor) -> np.array:
24
+ return torchvision.utils.make_grid(
25
+ [result[0].permute(1, 2, 0).detach()],
26
+ normalize=True
27
+ )
28
+
29
+ def run(self) -> np.array:
30
+ """ Returns a plottable image """
31
+ inference_img = self._get_image_from_path()
32
+ result = self.model(inference_img)
33
+ adjusted_result = self._create_grid(result=result)
34
+ return adjusted_result
model/lightning_bolts_model/cosine_sim_model.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2987394cad6890877faaf61ade50eada5397c2d1447a48049e8ad3197ea461cc
3
- size 780630439
model/lightning_bolts_model/modified_patchgan.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7b6b85940399eb68eca7a62b603cd62ac2bc813bbec70a16df83842da73dd14a
3
- size 686280151