Nikhil Mudhalwadkar commited on
Commit
0604f1a
1 Parent(s): aea7215

added new model and black

Browse files
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # Hack for spaces
2
  import os
 
3
  os.system("pip uninstall -y gradio")
4
  os.system("pip install -r requirements.txt")
5
 
@@ -19,7 +20,7 @@ import albumentations as A
19
  import albumentations.pytorch as al_pytorch
20
  import torchvision
21
  from pl_bolts.models.gans import Pix2Pix
22
-
23
 
24
 
25
  """ Class """
@@ -58,6 +59,24 @@ class OverpoweredPix2Pix(Pix2Pix):
58
  )
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  """ Load the model """
62
  # train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
63
  train_64_val_16_plbolts_model_chkpt = (
@@ -66,6 +85,7 @@ train_64_val_16_plbolts_model_chkpt = (
66
  train_16_val_1_plbolts_model_chkpt = (
67
  "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
68
  )
 
69
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
70
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
71
 
@@ -81,6 +101,10 @@ train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
81
  )
82
  train_16_val_1_plbolts_model.eval()
83
 
 
 
 
 
84
 
85
  def predict(img: Image, type_of_model: str):
86
  """Create predictions"""
@@ -104,7 +128,7 @@ def predict(img: Image, type_of_model: str):
104
  elif type_of_model == "train batch size 64, val batch size 16":
105
  model = train_64_val_16_plbolts_model
106
  else:
107
- raise Exception("NOT YET SUPPORTED")
108
 
109
  with torch.no_grad():
110
  result = model.gen(inference_img)
@@ -120,6 +144,13 @@ def predict2(img: Image):
120
  return predict(img=img, type_of_model="train batch size 64, val batch size 16")
121
 
122
 
 
 
 
 
 
 
 
123
  model_input = gr.inputs.Radio(
124
  [
125
  "train batch size 16, val batch size 1",
@@ -177,6 +208,20 @@ with gr.Blocks() as demo:
177
  outputs=image_output2,
178
  fn=predict2,
179
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  colour_1.click(
182
  fn=predict1,
@@ -188,6 +233,11 @@ with gr.Blocks() as demo:
188
  inputs=image_input2,
189
  outputs=image_output2,
190
  )
 
 
 
 
 
191
 
192
  demo.title = "Colour your sketches!"
193
  demo.launch()
 
1
  # Hack for spaces
2
  import os
3
+
4
  os.system("pip uninstall -y gradio")
5
  os.system("pip install -r requirements.txt")
6
 
 
20
  import albumentations.pytorch as al_pytorch
21
  import torchvision
22
  from pl_bolts.models.gans import Pix2Pix
23
+ from pl_bolts.models.gans.pix2pix.components import PatchGAN
24
 
25
 
26
  """ Class """
 
59
  )
60
 
61
 
62
+ class PatchGanChanged(OverpoweredPix2Pix):
63
+ def __init__(self, in_channels, out_channels):
64
+ super(PatchGanChanged, self).__init__(
65
+ in_channels=in_channels, out_channels=out_channels
66
+ )
67
+ self.patch_gan = self.get_dense_PatchGAN(self.patch_gan)
68
+
69
+ @staticmethod
70
+ def get_dense_PatchGAN(disc: PatchGAN) -> PatchGAN:
71
+ """Add final layer to gan"""
72
+ disc.final = torch.nn.Sequential(
73
+ disc.final,
74
+ torch.nn.Flatten(),
75
+ torch.nn.Linear(16 * 16, 1),
76
+ )
77
+ return disc
78
+
79
+
80
  """ Load the model """
81
  # train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
82
  train_64_val_16_plbolts_model_chkpt = (
 
85
  train_16_val_1_plbolts_model_chkpt = (
86
  "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
87
  )
88
+ modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt"
89
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
90
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
91
 
 
101
  )
102
  train_16_val_1_plbolts_model.eval()
103
 
104
+ #
105
+ modified_patchgan_model = PatchGanChanged.load_from_checkpoint(modified_patchgan_chkpt)
106
+ modified_patchgan_model.eval()
107
+
108
 
109
  def predict(img: Image, type_of_model: str):
110
  """Create predictions"""
 
128
  elif type_of_model == "train batch size 64, val batch size 16":
129
  model = train_64_val_16_plbolts_model
130
  else:
131
+ model = modified_patchgan_model
132
 
133
  with torch.no_grad():
134
  result = model.gen(inference_img)
 
144
  return predict(img=img, type_of_model="train batch size 64, val batch size 16")
145
 
146
 
147
+ def predict3(img: Image):
148
+ return predict(
149
+ img=img,
150
+ type_of_model="train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
151
+ )
152
+
153
+
154
  model_input = gr.inputs.Radio(
155
  [
156
  "train batch size 16, val batch size 1",
 
208
  outputs=image_output2,
209
  fn=predict2,
210
  )
211
+ with gr.TabItem("Single Value Discriminator"):
212
+ with gr.Row():
213
+ image_input3 = gr.inputs.Image(type="pil")
214
+ image_output3 = gr.outputs.Image(
215
+ type="pil",
216
+ )
217
+ colour_3 = gr.Button("Colour it!")
218
+ with gr.Row():
219
+ gr.Examples(
220
+ examples=img_examples,
221
+ inputs=image_input3,
222
+ outputs=image_output3,
223
+ fn=predict3,
224
+ )
225
 
226
  colour_1.click(
227
  fn=predict1,
 
233
  inputs=image_input2,
234
  outputs=image_output2,
235
  )
236
+ colour_3.click(
237
+ fn=predict3,
238
+ inputs=image_input3,
239
+ outputs=image_output3,
240
+ )
241
 
242
  demo.title = "Colour your sketches!"
243
  demo.launch()
model/lightning_bolts_model/modified_patchgan.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b6b85940399eb68eca7a62b603cd62ac2bc813bbec70a16df83842da73dd14a
3
+ size 686280151