Nikhil Mudhalwadkar commited on
Commit
7337bea
1 Parent(s): 0604f1a

added new model with cosine similarity

Browse files
app.py CHANGED
@@ -21,7 +21,7 @@ import albumentations.pytorch as al_pytorch
21
  import torchvision
22
  from pl_bolts.models.gans import Pix2Pix
23
  from pl_bolts.models.gans.pix2pix.components import PatchGAN
24
-
25
 
26
  """ Class """
27
 
@@ -86,6 +86,7 @@ train_16_val_1_plbolts_model_chkpt = (
86
  "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
87
  )
88
  modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt"
 
89
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
90
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
91
 
@@ -106,6 +107,112 @@ modified_patchgan_model = PatchGanChanged.load_from_checkpoint(modified_patchgan
106
  modified_patchgan_model.eval()
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def predict(img: Image, type_of_model: str):
110
  """Create predictions"""
111
  # transform img
@@ -127,6 +234,8 @@ def predict(img: Image, type_of_model: str):
127
  model = train_16_val_1_plbolts_model
128
  elif type_of_model == "train batch size 64, val batch size 16":
129
  model = train_64_val_16_plbolts_model
 
 
130
  else:
131
  model = modified_patchgan_model
132
 
@@ -151,6 +260,13 @@ def predict3(img: Image):
151
  )
152
 
153
 
 
 
 
 
 
 
 
154
  model_input = gr.inputs.Radio(
155
  [
156
  "train batch size 16, val batch size 1",
@@ -169,17 +285,19 @@ img_examples = [
169
  "examples/thesis6.png",
170
  ]
171
 
172
-
173
  with gr.Blocks() as demo:
174
  gr.Markdown(" # Colour your sketches!")
175
  gr.Markdown(" ## Description :")
176
- gr.Markdown(" There are three Pix2Pix models in this example:")
177
  gr.Markdown(" 1. Training batch size is 16 , validation is 1")
178
  gr.Markdown(" 2. Training batch size is 64 , validation is 16")
179
  gr.Markdown(
180
  " 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
181
  "training batch size is 64 , validation is 16"
182
  )
 
 
 
183
  with gr.Tabs():
184
  with gr.TabItem("tr_16_val_1"):
185
  with gr.Row():
@@ -222,6 +340,20 @@ with gr.Blocks() as demo:
222
  outputs=image_output3,
223
  fn=predict3,
224
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  colour_1.click(
227
  fn=predict1,
@@ -238,6 +370,11 @@ with gr.Blocks() as demo:
238
  inputs=image_input3,
239
  outputs=image_output3,
240
  )
 
 
 
 
 
241
 
242
  demo.title = "Colour your sketches!"
243
  demo.launch()
 
21
  import torchvision
22
  from pl_bolts.models.gans import Pix2Pix
23
  from pl_bolts.models.gans.pix2pix.components import PatchGAN
24
+ import torchvision.models as models
25
 
26
  """ Class """
27
 
 
86
  "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
87
  )
88
  modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt"
89
+
90
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
91
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
92
 
 
107
  modified_patchgan_model.eval()
108
 
109
 
110
+ # Create new class
111
+ class OverpoweredPix2Pix(Pix2Pix):
112
+ def __init__(self, in_channels, out_channels):
113
+ super(OverpoweredPix2Pix, self).__init__(
114
+ in_channels=in_channels, out_channels=out_channels
115
+ )
116
+ self._create_inception_score()
117
+
118
+ def _gen_step(self, real_images, conditioned_images):
119
+ # Pix2Pix has adversarial and a reconstruction loss
120
+ # First calculate the adversarial loss
121
+ fake_images = self.gen(conditioned_images)
122
+ disc_logits = self.patch_gan(fake_images, conditioned_images)
123
+ adversarial_loss = self.adversarial_criterion(
124
+ disc_logits, torch.ones_like(disc_logits)
125
+ )
126
+
127
+ # calculate reconstruction loss
128
+ recon_loss = self.recon_criterion(fake_images, real_images)
129
+ lambda_recon = self.hparams.lambda_recon
130
+
131
+ # calculate cosine similarity
132
+ representations_real = self.feature_extractor(real_images).flatten(1)
133
+ representations_fake = self.feature_extractor(fake_images).flatten(1)
134
+ similarity_score_list = self.cosine_similarity(
135
+ representations_real, representations_fake
136
+ )
137
+ cosine_sim = sum(similarity_score_list) / len(similarity_score_list)
138
+
139
+ self.log("Gen Cosine Sim Loss ", 1 - cosine_sim.cpu().detach().numpy())
140
+ # print(adversarial_loss,1-cosine_sim, lambda_recon, recon_loss, )
141
+
142
+ return (
143
+ (adversarial_loss)
144
+ + (lambda_recon * recon_loss)
145
+ + (lambda_recon * (1 - cosine_sim))
146
+ )
147
+
148
+ def _create_inception_score(self):
149
+ # init a pretrained resnet
150
+ backbone = models.resnet50(pretrained=True)
151
+ num_filters = backbone.fc.in_features
152
+ layers = list(backbone.children())[:-1]
153
+ self.feature_extractor = torch.nn.Sequential(*layers)
154
+ self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
155
+
156
+ def validation_step(self, batch, batch_idx):
157
+ """Validation step"""
158
+ real, condition = batch
159
+ with torch.no_grad():
160
+ disc_loss = self._disc_step(real, condition)
161
+ self.log("Valid PatchGAN Loss", disc_loss)
162
+
163
+ gan_loss = self._gen_step(real, condition)
164
+ self.log("Valid Generator Loss", gan_loss)
165
+
166
+ #
167
+ fake_images = self.gen(condition)
168
+ representations_real = self.feature_extractor(real).flatten(1)
169
+ representations_fake = self.feature_extractor(fake_images).flatten(1)
170
+ similarity_score_list = self.cosine_similarity(
171
+ representations_real, representations_fake
172
+ )
173
+ cosine_sim = sum(similarity_score_list) / len(similarity_score_list)
174
+
175
+ self.log("Valid Cosine Sim", cosine_sim)
176
+
177
+ return {"sketch": condition, "colour": real}
178
+
179
+ def validation_epoch_end(
180
+ self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
181
+ ) -> None:
182
+ sketch = outputs[0]["sketch"]
183
+ colour = outputs[0]["colour"]
184
+ self.feature_extractor.eval()
185
+ with torch.no_grad():
186
+ gen_coloured = self.gen(sketch)
187
+ representations_gen = self.feature_extractor(gen_coloured).flatten(1)
188
+ representations_fake = self.feature_extractor(colour).flatten(1)
189
+
190
+ similarity_score_list = self.cosine_similarity(
191
+ representations_gen, representations_fake
192
+ )
193
+ similarity_score = sum(similarity_score_list) / len(similarity_score_list)
194
+
195
+ grid_image = torchvision.utils.make_grid(
196
+ [
197
+ sketch[0],
198
+ colour[0],
199
+ gen_coloured[0],
200
+ ],
201
+ normalize=True,
202
+ )
203
+ self.logger.experiment.add_image(
204
+ f"Image Grid {str(self.current_epoch)} __ {str(similarity_score)} ",
205
+ grid_image,
206
+ self.current_epoch,
207
+ )
208
+
209
+
210
+ cosine_sim_model_chkpt = "model/lightning_bolts_model/cosine_sim_model.ckpt"
211
+
212
+ cosine_sim_model = OverpoweredPix2Pix.load_from_checkpoint(cosine_sim_model_chkpt)
213
+ cosine_sim_model.eval()
214
+
215
+
216
  def predict(img: Image, type_of_model: str):
217
  """Create predictions"""
218
  # transform img
 
234
  model = train_16_val_1_plbolts_model
235
  elif type_of_model == "train batch size 64, val batch size 16":
236
  model = train_64_val_16_plbolts_model
237
+ elif type_of_model == "cosine similarity":
238
+ model = cosine_sim_model
239
  else:
240
  model = modified_patchgan_model
241
 
 
260
  )
261
 
262
 
263
+ def predict4(img: Image):
264
+ return predict(
265
+ img=img,
266
+ type_of_model="cosine similarity",
267
+ )
268
+
269
+
270
  model_input = gr.inputs.Radio(
271
  [
272
  "train batch size 16, val batch size 1",
 
285
  "examples/thesis6.png",
286
  ]
287
 
 
288
  with gr.Blocks() as demo:
289
  gr.Markdown(" # Colour your sketches!")
290
  gr.Markdown(" ## Description :")
291
+ gr.Markdown(" There are 4 Pix2Pix models in this example:")
292
  gr.Markdown(" 1. Training batch size is 16 , validation is 1")
293
  gr.Markdown(" 2. Training batch size is 64 , validation is 16")
294
  gr.Markdown(
295
  " 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
296
  "training batch size is 64 , validation is 16"
297
  )
298
+ gr.Markdown(
299
+ " 4. cosine similarity is also added as a metric in this experiment for the generator. "
300
+ )
301
  with gr.Tabs():
302
  with gr.TabItem("tr_16_val_1"):
303
  with gr.Row():
 
340
  outputs=image_output3,
341
  fn=predict3,
342
  )
343
+ with gr.TabItem("Cosine similarity loss"):
344
+ with gr.Row():
345
+ image_input4 = gr.inputs.Image(type="pil")
346
+ image_output4 = gr.outputs.Image(
347
+ type="pil",
348
+ )
349
+ colour_4 = gr.Button("Colour it!")
350
+ with gr.Row():
351
+ gr.Examples(
352
+ examples=img_examples,
353
+ inputs=image_input4,
354
+ outputs=image_output4,
355
+ fn=predict4,
356
+ )
357
 
358
  colour_1.click(
359
  fn=predict1,
 
370
  inputs=image_input3,
371
  outputs=image_output3,
372
  )
373
+ colour_4.click(
374
+ fn=predict4,
375
+ inputs=image_input4,
376
+ outputs=image_output4,
377
+ )
378
 
379
  demo.title = "Colour your sketches!"
380
  demo.launch()
model/lightning_bolts_model/cosine_sim_model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2987394cad6890877faaf61ade50eada5397c2d1447a48049e8ad3197ea461cc
3
+ size 780630439