Nikhil Mudhalwadkar commited on
Commit
b308e39
·
1 Parent(s): 5165732

Recreate the demo

Browse files
Files changed (1) hide show
  1. app.py +98 -154
app.py CHANGED
@@ -1,10 +1,8 @@
1
  from typing import Union, List
2
 
3
  import gradio as gr
4
- import torch
5
- import torch.nn as nn
6
  import matplotlib
7
- import torch.nn.functional as F
8
  from pytorch_lightning.utilities.types import EPOCH_OUTPUT
9
 
10
  matplotlib.use('Agg')
@@ -12,11 +10,8 @@ import numpy as np
12
  from PIL import Image
13
  import albumentations as A
14
  import albumentations.pytorch as al_pytorch
15
- import matplotlib.pyplot as plt
16
  import torchvision
17
  from pl_bolts.models.gans import Pix2Pix
18
- from app.generator.unetGen import Generator as gen
19
- from app.model.lit_model import Pix2PixLitModule
20
 
21
  """ Class """
22
 
@@ -49,182 +44,131 @@ class OverpoweredPix2Pix(Pix2Pix):
49
  ],
50
  normalize=True
51
  )
52
- self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)
53
-
54
-
55
- class Downsample(nn.Module):
56
- def __init__(self, in_channels, out_channels):
57
- super(Downsample, self).__init__()
58
- self.conv_relu = nn.Sequential(
59
- nn.Conv2d(in_channels, out_channels, 3, 2, 1),
60
- nn.LeakyReLU(inplace=True)
61
- )
62
- self.bn = nn.BatchNorm2d(out_channels)
63
-
64
- def forward(self, x, is_bn=True):
65
- x = self.conv_relu(x)
66
- if is_bn:
67
- x = self.bn(x)
68
- return x
69
-
70
-
71
- class Upsample(nn.Module):
72
- def __init__(self, in_channels, out_channels):
73
- super(Upsample, self).__init__()
74
- self.upconv_relu = nn.Sequential(
75
- nn.ConvTranspose2d(in_channels, out_channels, 3, 2, 1,
76
- output_padding=1),
77
- nn.LeakyReLU(inplace=True)
78
  )
79
- self.bn = nn.BatchNorm2d(out_channels)
80
-
81
- def forward(self, x, is_drop=False):
82
- x = self.upconv_relu(x)
83
- x = self.bn(x)
84
- if is_drop:
85
- x = F.dropout2d(x)
86
- return x
87
-
88
-
89
- class Generator(nn.Module):
90
- def __init__(self):
91
- super(Generator, self).__init__()
92
- self.down1 = Downsample(3, 64)
93
- self.down2 = Downsample(64, 128)
94
- self.down3 = Downsample(128, 256)
95
- self.down4 = Downsample(256, 512)
96
- self.down5 = Downsample(512, 512)
97
- self.down6 = Downsample(512, 512)
98
- self.down7 = Downsample(512, 512)
99
- self.down8 = Downsample(512, 512)
100
-
101
- self.up1 = Upsample(512, 512)
102
- self.up2 = Upsample(1024, 512)
103
- self.up3 = Upsample(1024, 512)
104
- self.up4 = Upsample(1024, 512)
105
- self.up5 = Upsample(1024, 256)
106
- self.up6 = Upsample(512, 128)
107
- self.up7 = Upsample(256, 64)
108
-
109
- self.last = nn.ConvTranspose2d(128, 3,
110
- kernel_size=3,
111
- stride=2,
112
- padding=1,
113
- output_padding=1)
114
-
115
- def forward(self, x):
116
- x1 = self.down1(x) # torch.Size([8, 64, 128, 128])
117
- x2 = self.down2(x1) # torch.Size([8, 128, 64, 64])
118
- x3 = self.down3(x2) # torch.Size([8, 256, 32, 32])
119
- x4 = self.down4(x3) # torch.Size([8, 512, 16, 16])
120
- x5 = self.down5(x4) # torch.Size([8, 512, 8, 8])
121
- x6 = self.down6(x5) # torch.Size([8, 512, 4, 4])
122
- x7 = self.down7(x6) # torch.Size([8, 512, 2, 2])
123
- x8 = self.down8(x7) # torch.Size([8, 512, 1, 1])
124
-
125
- x8 = self.up1(x8, is_drop=True) # torch.Size([8, 512, 2, 2])
126
- x8 = torch.cat([x7, x8], dim=1) # torch.Size([8, 1024, 2, 2])
127
-
128
- x8 = self.up2(x8, is_drop=True) # torch.Size([8, 512, 4, 4])
129
- x8 = torch.cat([x6, x8], dim=1) # torch.Size([8, 1024, 2, 2])
130
-
131
- x8 = self.up3(x8, is_drop=True) # torch.Size([8, 512, 8, 8])
132
- x8 = torch.cat([x5, x8], dim=1) # torch.Size([8, 1024, 8, 8])
133
-
134
- x8 = self.up4(x8) # torch.Size([8, 512, 16, 16])
135
- x8 = torch.cat([x4, x8], dim=1) # torch.Size([8, 1024, 16, 16])
136
-
137
- x8 = self.up5(x8)
138
- x8 = torch.cat([x3, x8], dim=1)
139
-
140
- x8 = self.up6(x8)
141
- x8 = torch.cat([x2, x8], dim=1)
142
-
143
- x8 = self.up7(x8)
144
- x8 = torch.cat([x1, x8], dim=1)
145
-
146
- x8 = torch.tanh(self.last(x8))
147
- return x8
148
 
149
 
150
  """ Load the model """
151
- model_checkpoint_path = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
 
 
152
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
153
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
154
 
155
- model = OverpoweredPix2Pix.load_from_checkpoint(
156
- model_checkpoint_path
 
157
  )
 
158
 
159
- model_chk = torch.load(
160
- model_checkpoint_path, map_location=torch.device('cpu')
 
161
  )
162
- # model = gen().load_state_dict(model_chk)
163
 
164
- model.eval()
165
 
166
-
167
- def greet(name):
168
- return "Hello " + name + "!!"
169
-
170
-
171
- def predict(img: Image):
172
  # transform img
173
  image = np.asarray(img)
174
- # image = image[:, image.shape[1] // 2:, :]
175
  # use on inference
176
  inference_transform = A.Compose([
177
  A.Resize(width=256, height=256),
178
  A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
179
  al_pytorch.ToTensorV2(),
180
  ])
181
- # inverse_transform = A.Compose([
182
- # A.Normalize(
183
- # mean=[0.485, 0.456, 0.406],
184
- # std=[0.229, 0.224, 0.225]
185
- # ),
186
- # ])
187
  inference_img = inference_transform(
188
  image=image
189
  )['image'].unsqueeze(0)
 
 
 
 
 
 
 
 
 
190
  with torch.no_grad():
191
  result = model.gen(inference_img)
192
- # torchvision.utils.save_image(inference_img, "inference_image.png", normalize=True)
193
  torchvision.utils.save_image(result, "inference_image.png", normalize=True)
194
-
195
- """
196
- result_grid = torchvision.utils.make_grid(
197
- [result[0]],
198
- normalize=True
199
- )
200
- # plt.imsave("coloured_grid.png", (result_grid.permute(1,2,0).detach().numpy()*255).astype(int))
201
- torchvision.utils.save_image(
202
- result_grid, "coloured_image.png", normalize=True
203
- )
204
- """
205
  return "inference_image.png" # 'coloured_image.png',
206
 
207
 
208
- iface = gr.Interface(
209
- fn=predict,
210
- inputs=gr.inputs.Image(type="pil"),
211
- #inputs="sketchpad",
212
- examples=[
213
- "examples/thesis_test.png",
214
- "examples/thesis_test2.png",
215
- "examples/thesis1.png",
216
- "examples/thesis4.png",
217
- "examples/thesis5.png",
218
- "examples/thesis6.png",
219
- # "examples/1000000.png"
 
220
  ],
221
- outputs=gr.outputs.Image(type="pil",),
222
- #outputs=[
223
- # "image",
224
- # # "image"
225
- #],
226
- title="Colour your sketches!",
227
- description=" Upload a sketch and the conditional gan will colour it for you!",
228
- article="WIP repo lives here - https://github.com/nmud19/thesisGAN "
229
  )
230
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Union, List
2
 
3
  import gradio as gr
 
 
4
  import matplotlib
5
+ import torch
6
  from pytorch_lightning.utilities.types import EPOCH_OUTPUT
7
 
8
  matplotlib.use('Agg')
 
10
  from PIL import Image
11
  import albumentations as A
12
  import albumentations.pytorch as al_pytorch
 
13
  import torchvision
14
  from pl_bolts.models.gans import Pix2Pix
 
 
15
 
16
  """ Class """
17
 
 
44
  ],
45
  normalize=True
46
  )
47
+ self.logger.experiment.add_image(
48
+ f'Image Grid {str(self.current_epoch)}',
49
+ grid_image,
50
+ self.current_epoch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  """ Load the model """
55
+ # train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
56
+ train_64_val_16_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
57
+ train_16_val_1_plbolts_model_chkpt = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
58
  # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
59
  # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
60
 
61
+ # Load the models
62
+ train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
63
+ train_64_val_16_plbolts_model_chkpt
64
  )
65
+ train_64_val_16_plbolts_model.eval()
66
 
67
+ #
68
+ train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
69
+ train_16_val_1_plbolts_model_chkpt
70
  )
71
+ train_16_val_1_plbolts_model.eval()
72
 
 
73
 
74
+ def predict(img: Image, type_of_model: str):
75
+ """ Create predictions """
 
 
 
 
76
  # transform img
77
  image = np.asarray(img)
 
78
  # use on inference
79
  inference_transform = A.Compose([
80
  A.Resize(width=256, height=256),
81
  A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
82
  al_pytorch.ToTensorV2(),
83
  ])
 
 
 
 
 
 
84
  inference_img = inference_transform(
85
  image=image
86
  )['image'].unsqueeze(0)
87
+
88
+ # Choose model
89
+ if type_of_model == "train batch size 16, val batch size 1":
90
+ model = train_16_val_1_plbolts_model
91
+ elif type_of_model == "train batch size 64, val batch size 16":
92
+ model = train_64_val_16_plbolts_model
93
+ else:
94
+ raise Exception("NOT YET SUPPORTED")
95
+
96
  with torch.no_grad():
97
  result = model.gen(inference_img)
 
98
  torchvision.utils.save_image(result, "inference_image.png", normalize=True)
 
 
 
 
 
 
 
 
 
 
 
99
  return "inference_image.png" # 'coloured_image.png',
100
 
101
 
102
+ def predict1(img: Image):
103
+ return predict(img=img, type_of_model="train batch size 16, val batch size 1")
104
+
105
+
106
+ def predict2(img: Image):
107
+ return predict(img=img, type_of_model="train batch size 64, val batch size 16")
108
+
109
+
110
+ model_input = gr.inputs.Radio(
111
+ [
112
+ "train batch size 16, val batch size 1",
113
+ "train batch size 64, val batch size 16",
114
+ "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
115
  ],
116
+ label="Type of Pix2Pix model to use : "
 
 
 
 
 
 
 
117
  )
118
+ image_input = gr.inputs.Image(type="pil")
119
+ img_examples = [
120
+ "examples/thesis_test.png",
121
+ "examples/thesis_test2.png",
122
+ "examples/thesis1.png",
123
+ "examples/thesis4.png",
124
+ "examples/thesis5.png",
125
+ "examples/thesis6.png",
126
+ ]
127
+
128
+
129
+ with gr.Blocks() as demo:
130
+ gr.Markdown(" # Colour your sketches!")
131
+ gr.Markdown(" ## Description :")
132
+ gr.Markdown(" There are three Pix2Pix models in this example:")
133
+ gr.Markdown(" 1. Training batch size is 16 , validation is 1")
134
+ gr.Markdown(" 2. Training batch size is 64 , validation is 16")
135
+ gr.Markdown(" 3. PatchGAN is changed, 1 value only instead of 16*16 ;"
136
+ "training batch size is 64 , validation is 16")
137
+ with gr.Tabs():
138
+ with gr.TabItem("tr_16_val_1"):
139
+ with gr.Row():
140
+ image_input1 = gr.inputs.Image(type="pil")
141
+ image_output1 = gr.outputs.Image(type="pil", )
142
+ colour_1 = gr.Button("Colour it!")
143
+ gr.Examples(
144
+ examples=img_examples,
145
+ inputs=image_input1,
146
+ outputs=image_output1,
147
+ fn=predict1,
148
+ )
149
+ with gr.TabItem("tr_64_val_14"):
150
+ with gr.Row():
151
+ image_input2 = gr.inputs.Image(type="pil")
152
+ image_output2 = gr.outputs.Image(type="pil", )
153
+ colour_2 = gr.Button("Colour it!")
154
+ with gr.Row():
155
+ gr.Examples(
156
+ examples=img_examples,
157
+ inputs=image_input2,
158
+ outputs=image_output2,
159
+ fn=predict2,
160
+ )
161
+
162
+ colour_1.click(
163
+ fn=predict1,
164
+ inputs=image_input1,
165
+ outputs=image_output1,
166
+ )
167
+ colour_2.click(
168
+ fn=predict2,
169
+ inputs=image_input2,
170
+ outputs=image_output2,
171
+ )
172
+
173
+ demo.title = "Colour your sketches!"
174
+ demo.launch()