Nikhil Mudhalwadkar commited on
Commit
025bf23
1 Parent(s): 67c17bf

Working lightning bolts

Browse files
Files changed (1) hide show
  1. app.py +185 -15
app.py CHANGED
@@ -1,6 +1,12 @@
 
 
1
  import gradio as gr
2
  import torch
 
3
  import matplotlib
 
 
 
4
  matplotlib.use('Agg')
5
  import numpy as np
6
  from PIL import Image
@@ -8,14 +14,153 @@ import albumentations as A
8
  import albumentations.pytorch as al_pytorch
9
  import matplotlib.pyplot as plt
10
  import torchvision
11
-
 
12
  from app.model.lit_model import Pix2PixLitModule
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """ Load the model """
15
- model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=9-step=17780.ckpt"
16
- model = Pix2PixLitModule.load_from_checkpoint(
 
 
 
17
  model_checkpoint_path
18
  )
 
 
 
 
 
 
19
  model.eval()
20
 
21
 
@@ -23,32 +168,57 @@ def greet(name):
23
  return "Hello " + name + "!!"
24
 
25
 
26
- def predict(image: Image):
 
 
 
27
  # use on inference
28
  inference_transform = A.Compose([
29
  A.Resize(width=256, height=256),
30
  A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
31
  al_pytorch.ToTensorV2(),
32
  ])
 
 
 
 
 
 
33
  inference_img = inference_transform(
34
- image=np.asarray(image)
35
  )['image'].unsqueeze(0)
36
- result = model(inference_img)
37
- result_grid = torchvision.utils.make_grid(
38
- [result[0].permute(1, 2, 0).detach()],
39
- normalize=True
40
- )
41
- plt.imsave("coloured_grid.png", result_grid.numpy())
42
- torchvision.utils.save_image(result, "coloured_image.png", normalize=True)
43
- return 'coloured_image.png', 'coloured_grid.png'
44
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  iface = gr.Interface(
48
  fn=predict,
49
  inputs=gr.inputs.Image(type="pil"),
50
- examples=["examples/thesis_test.png", "examples/thesis_test2.png"],
51
- outputs=["image","image"],
 
 
 
 
 
 
 
 
 
52
  title="Colour your sketches!",
53
  description=" Upload a sketch and the conditional gan will colour it for you!",
54
  article="WIP repo lives here - https://github.com/nmud19/thesisGAN "
 
1
+ from typing import Union, List
2
+
3
  import gradio as gr
4
  import torch
5
+ import torch.nn as nn
6
  import matplotlib
7
+ import torch.nn.functional as F
8
+ from pytorch_lightning.utilities.types import EPOCH_OUTPUT
9
+
10
  matplotlib.use('Agg')
11
  import numpy as np
12
  from PIL import Image
 
14
  import albumentations.pytorch as al_pytorch
15
  import matplotlib.pyplot as plt
16
  import torchvision
17
+ from pl_bolts.models.gans import Pix2Pix
18
+ from app.generator.unetGen import Generator as gen
19
  from app.model.lit_model import Pix2PixLitModule
20
 
21
+ """ Class """
22
+
23
+
24
+ class OverpoweredPix2Pix(Pix2Pix):
25
+
26
+ def validation_step(self, batch, batch_idx):
27
+ """ Validation step """
28
+ real, condition = batch
29
+ with torch.no_grad():
30
+ loss = self._disc_step(real, condition)
31
+ self.log("val_PatchGAN_loss", loss)
32
+
33
+ loss = self._gen_step(real, condition)
34
+ self.log("val_generator_loss", loss)
35
+
36
+ return {
37
+ 'sketch': real,
38
+ 'colour': condition
39
+ }
40
+
41
+ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
42
+ sketch = outputs[0]['sketch']
43
+ colour = outputs[0]['colour']
44
+ with torch.no_grad():
45
+ gen_coloured = self.gen(sketch)
46
+ grid_image = torchvision.utils.make_grid(
47
+ [
48
+ sketch[0], colour[0], gen_coloured[0],
49
+ ],
50
+ normalize=True
51
+ )
52
+ self.logger.experiment.add_image(f'Image Grid {str(self.current_epoch)}', grid_image, self.current_epoch)
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, in_channels, out_channels):
57
+ super(Downsample, self).__init__()
58
+ self.conv_relu = nn.Sequential(
59
+ nn.Conv2d(in_channels, out_channels, 3, 2, 1),
60
+ nn.LeakyReLU(inplace=True)
61
+ )
62
+ self.bn = nn.BatchNorm2d(out_channels)
63
+
64
+ def forward(self, x, is_bn=True):
65
+ x = self.conv_relu(x)
66
+ if is_bn:
67
+ x = self.bn(x)
68
+ return x
69
+
70
+
71
+ class Upsample(nn.Module):
72
+ def __init__(self, in_channels, out_channels):
73
+ super(Upsample, self).__init__()
74
+ self.upconv_relu = nn.Sequential(
75
+ nn.ConvTranspose2d(in_channels, out_channels, 3, 2, 1,
76
+ output_padding=1),
77
+ nn.LeakyReLU(inplace=True)
78
+ )
79
+ self.bn = nn.BatchNorm2d(out_channels)
80
+
81
+ def forward(self, x, is_drop=False):
82
+ x = self.upconv_relu(x)
83
+ x = self.bn(x)
84
+ if is_drop:
85
+ x = F.dropout2d(x)
86
+ return x
87
+
88
+
89
+ class Generator(nn.Module):
90
+ def __init__(self):
91
+ super(Generator, self).__init__()
92
+ self.down1 = Downsample(3, 64)
93
+ self.down2 = Downsample(64, 128)
94
+ self.down3 = Downsample(128, 256)
95
+ self.down4 = Downsample(256, 512)
96
+ self.down5 = Downsample(512, 512)
97
+ self.down6 = Downsample(512, 512)
98
+ self.down7 = Downsample(512, 512)
99
+ self.down8 = Downsample(512, 512)
100
+
101
+ self.up1 = Upsample(512, 512)
102
+ self.up2 = Upsample(1024, 512)
103
+ self.up3 = Upsample(1024, 512)
104
+ self.up4 = Upsample(1024, 512)
105
+ self.up5 = Upsample(1024, 256)
106
+ self.up6 = Upsample(512, 128)
107
+ self.up7 = Upsample(256, 64)
108
+
109
+ self.last = nn.ConvTranspose2d(128, 3,
110
+ kernel_size=3,
111
+ stride=2,
112
+ padding=1,
113
+ output_padding=1)
114
+
115
+ def forward(self, x):
116
+ x1 = self.down1(x) # torch.Size([8, 64, 128, 128])
117
+ x2 = self.down2(x1) # torch.Size([8, 128, 64, 64])
118
+ x3 = self.down3(x2) # torch.Size([8, 256, 32, 32])
119
+ x4 = self.down4(x3) # torch.Size([8, 512, 16, 16])
120
+ x5 = self.down5(x4) # torch.Size([8, 512, 8, 8])
121
+ x6 = self.down6(x5) # torch.Size([8, 512, 4, 4])
122
+ x7 = self.down7(x6) # torch.Size([8, 512, 2, 2])
123
+ x8 = self.down8(x7) # torch.Size([8, 512, 1, 1])
124
+
125
+ x8 = self.up1(x8, is_drop=True) # torch.Size([8, 512, 2, 2])
126
+ x8 = torch.cat([x7, x8], dim=1) # torch.Size([8, 1024, 2, 2])
127
+
128
+ x8 = self.up2(x8, is_drop=True) # torch.Size([8, 512, 4, 4])
129
+ x8 = torch.cat([x6, x8], dim=1) # torch.Size([8, 1024, 2, 2])
130
+
131
+ x8 = self.up3(x8, is_drop=True) # torch.Size([8, 512, 8, 8])
132
+ x8 = torch.cat([x5, x8], dim=1) # torch.Size([8, 1024, 8, 8])
133
+
134
+ x8 = self.up4(x8) # torch.Size([8, 512, 16, 16])
135
+ x8 = torch.cat([x4, x8], dim=1) # torch.Size([8, 1024, 16, 16])
136
+
137
+ x8 = self.up5(x8)
138
+ x8 = torch.cat([x3, x8], dim=1)
139
+
140
+ x8 = self.up6(x8)
141
+ x8 = torch.cat([x2, x8], dim=1)
142
+
143
+ x8 = self.up7(x8)
144
+ x8 = torch.cat([x1, x8], dim=1)
145
+
146
+ x8 = torch.tanh(self.last(x8))
147
+ return x8
148
+
149
+
150
  """ Load the model """
151
+ model_checkpoint_path = "model/lightning_bolts_model/epoch=8-step=8010.ckpt"
152
+ # model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
153
+ # model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
154
+
155
+ model = OverpoweredPix2Pix.load_from_checkpoint(
156
  model_checkpoint_path
157
  )
158
+
159
+ model_chk = torch.load(
160
+ model_checkpoint_path, map_location=torch.device('cpu')
161
+ )
162
+ # model = gen().load_state_dict(model_chk)
163
+
164
  model.eval()
165
 
166
 
 
168
  return "Hello " + name + "!!"
169
 
170
 
171
+ def predict(img: Image):
172
+ # transform img
173
+ image = np.asarray(img)
174
+ # image = image[:, image.shape[1] // 2:, :]
175
  # use on inference
176
  inference_transform = A.Compose([
177
  A.Resize(width=256, height=256),
178
  A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
179
  al_pytorch.ToTensorV2(),
180
  ])
181
+ # inverse_transform = A.Compose([
182
+ # A.Normalize(
183
+ # mean=[0.485, 0.456, 0.406],
184
+ # std=[0.229, 0.224, 0.225]
185
+ # ),
186
+ # ])
187
  inference_img = inference_transform(
188
+ image=image
189
  )['image'].unsqueeze(0)
190
+ with torch.no_grad():
191
+ result = model.gen(inference_img)
192
+ # torchvision.utils.save_image(inference_img, "inference_image.png", normalize=True)
193
+ torchvision.utils.save_image(result, "inference_image.png", normalize=True)
 
 
 
 
194
 
195
+ """
196
+ result_grid = torchvision.utils.make_grid(
197
+ [result[0]],
198
+ normalize=True
199
+ )
200
+ # plt.imsave("coloured_grid.png", (result_grid.permute(1,2,0).detach().numpy()*255).astype(int))
201
+ torchvision.utils.save_image(
202
+ result_grid, "coloured_image.png", normalize=True
203
+ )
204
+ """
205
+ return "inference_image.png" # 'coloured_image.png',
206
 
207
 
208
  iface = gr.Interface(
209
  fn=predict,
210
  inputs=gr.inputs.Image(type="pil"),
211
+ #inputs="sketchpad",
212
+ examples=[
213
+ "examples/thesis_test.png",
214
+ "examples/thesis_test2.png",
215
+ # "examples/1000000.png"
216
+ ],
217
+ outputs=gr.outputs.Image(type="pil",),
218
+ #outputs=[
219
+ # "image",
220
+ # # "image"
221
+ #],
222
  title="Colour your sketches!",
223
  description=" Upload a sketch and the conditional gan will colour it for you!",
224
  article="WIP repo lives here - https://github.com/nmud19/thesisGAN "