SMD00 committed
Commit 81c5831
Parent: 6395ab7

Upload 4 files

Files changed (4)
  1. app.py +365 -0
  2. final_model_weights.pt +3 -0
  3. requirements.txt +21 -0
  4. res18-unet.pt +3 -0
app.py ADDED
@@ -0,0 +1,365 @@
+ import streamlit as st
+ from PIL import Image
+ import cv2 as cv
+
+ import os
+ import glob
+ import time
+ import numpy as np
+ from pathlib import Path
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+ from skimage.color import rgb2lab, lab2rgb
+
+ # pip install fastai==2.4
+
+ import torch
+ from torch import nn, optim
+ from torchvision import transforms
+ from torchvision.utils import make_grid
+ from torch.utils.data import Dataset, DataLoader
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ use_colab = None
+
+ SIZE = 256
+ class ColorizationDataset(Dataset):
+     def __init__(self, paths, split='train'):
+         if split == 'train':
+             self.transforms = transforms.Compose([
+                 transforms.Resize((SIZE, SIZE), Image.BICUBIC),
+                 transforms.RandomHorizontalFlip(), # A little data augmentation!
+             ])
+         elif split == 'val':
+             self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
+
+         self.split = split
+         self.size = SIZE
+         self.paths = paths
+
+     def __getitem__(self, idx):
+         img = Image.open(self.paths[idx]).convert("RGB")
+         img = self.transforms(img)
+         img = np.array(img)
+         img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b
+         img_lab = transforms.ToTensor()(img_lab)
+         L = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
+         ab = img_lab[[1, 2], ...] / 110. # Between -1 and 1
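+         # (In CIE L*a*b, L spans [0, 100] and a/b stay within about +/-110 for typical
+         # sRGB images, hence the 50 and 110 divisors.)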
+
+         return {'L': L, 'ab': ab}
+
+     def __len__(self):
+         return len(self.paths)
+
+ def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs): # A handy function to make our dataloaders
+     dataset = ColorizationDataset(**kwargs)
+     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
+                             pin_memory=pin_memory)
+     return dataloader
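+ # Usage sketch (train_paths is a hypothetical list of image file paths):
+ #   train_dl = make_dataloaders(paths=train_paths, split='train')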
+
+ class UnetBlock(nn.Module):
+     def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
+                  innermost=False, outermost=False):
+         super().__init__()
+         self.outermost = outermost
+         if input_c is None: input_c = nf
+         downconv = nn.Conv2d(input_c, ni, kernel_size=4,
+                              stride=2, padding=1, bias=False)
+         downrelu = nn.LeakyReLU(0.2, True)
+         downnorm = nn.BatchNorm2d(ni)
+         uprelu = nn.ReLU(True)
+         upnorm = nn.BatchNorm2d(nf)
+
+         if outermost:
+             upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
+                                         stride=2, padding=1)
+             down = [downconv]
+             up = [uprelu, upconv, nn.Tanh()]
+             model = down + [submodule] + up
+         elif innermost:
+             upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
+                                         stride=2, padding=1, bias=False)
+             down = [downrelu, downconv]
+             up = [uprelu, upconv, upnorm]
+             model = down + up
+         else:
+             upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
+                                         stride=2, padding=1, bias=False)
+             down = [downrelu, downconv, downnorm]
+             up = [uprelu, upconv, upnorm]
+             if dropout: up += [nn.Dropout(0.5)]
+             model = down + [submodule] + up
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x):
+         if self.outermost:
+             return self.model(x)
+         else:
+             return torch.cat([x, self.model(x)], 1)
+
+ class Unet(nn.Module):
+     def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
+         super().__init__()
+         unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
+         for _ in range(n_down - 5):
+             unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
+         out_filters = num_filters * 8
+         for _ in range(3):
+             unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
+             out_filters //= 2
+         self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
+
+     def forward(self, x):
+         return self.model(x)
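+ # Shape check: a (B, 1, 256, 256) grayscale L input yields a (B, 2, 256, 256) ab prediction,
+ # squashed to [-1, 1] by the Tanh in the outermost block.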
+
+ class PatchDiscriminator(nn.Module):
+     def __init__(self, input_c, num_filters=64, n_down=3):
+         super().__init__()
+         model = [self.get_layers(input_c, num_filters, norm=False)]
+         model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2)
+                   for i in range(n_down)] # the 'if' avoids a stride of 2 for the last block in this loop
+         model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)] # no normalization or activation on the last layer
+         self.model = nn.Sequential(*model)
+
+     def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True): # when making repetitive blocks of layers,
+         layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]          # a separate method keeps things tidy
+         if norm: layers += [nn.BatchNorm2d(nf)]
+         if act: layers += [nn.LeakyReLU(0.2, True)]
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.model(x)
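+ # The output is a map of per-patch logits rather than a single scalar: for a (B, 3, 256, 256)
+ # L+ab input with n_down=3, the stack above produces a (B, 1, 30, 30) grid, so each logit
+ # judges one local patch of the image (the "Patch" in PatchGAN).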
+
+ class GANLoss(nn.Module):
+     def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
+         super().__init__()
+         self.register_buffer('real_label', torch.tensor(real_label))
+         self.register_buffer('fake_label', torch.tensor(fake_label))
+         if gan_mode == 'vanilla':
+             self.loss = nn.BCEWithLogitsLoss()
+         elif gan_mode == 'lsgan':
+             self.loss = nn.MSELoss()
+
+     def get_labels(self, preds, target_is_real):
+         if target_is_real:
+             labels = self.real_label
+         else:
+             labels = self.fake_label
+         return labels.expand_as(preds)
+
+     def __call__(self, preds, target_is_real):
+         labels = self.get_labels(preds, target_is_real)
+         loss = self.loss(preds, labels)
+         return loss
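+ # Note: in 'vanilla' mode preds are raw logits; BCEWithLogitsLoss applies the sigmoid internally.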
+
+ def init_weights(net, init='norm', gain=0.02):
+
+     def init_func(m):
+         classname = m.__class__.__name__
+         if hasattr(m, 'weight') and 'Conv' in classname:
+             if init == 'norm':
+                 nn.init.normal_(m.weight.data, mean=0.0, std=gain)
+             elif init == 'xavier':
+                 nn.init.xavier_normal_(m.weight.data, gain=gain)
+             elif init == 'kaiming':
+                 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+
+             if hasattr(m, 'bias') and m.bias is not None:
+                 nn.init.constant_(m.bias.data, 0.0)
+         elif 'BatchNorm2d' in classname:
+             nn.init.normal_(m.weight.data, 1., gain)
+             nn.init.constant_(m.bias.data, 0.)
+
+     net.apply(init_func)
+     print(f"model initialized with {init} initialization")
+     return net
+
+ def init_model(model, device):
+     model = model.to(device)
+     model = init_weights(model)
+     return model
+
+ class MainModel(nn.Module):
+     def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
+                  beta1=0.5, beta2=0.999, lambda_L1=100.):
+         super().__init__()
+
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.lambda_L1 = lambda_L1
+
+         if net_G is None:
+             self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
+         else:
+             self.net_G = net_G.to(self.device)
+         self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
+         self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
+         self.L1criterion = nn.L1Loss()
+         self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
+         self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
+
+     def set_requires_grad(self, model, requires_grad=True):
+         for p in model.parameters():
+             p.requires_grad = requires_grad
+
+     def setup_input(self, data):
+         self.L = data['L'].to(self.device)
+         self.ab = data['ab'].to(self.device)
+
+     def forward(self):
+         self.fake_color = self.net_G(self.L)
+
+     def backward_D(self):
+         fake_image = torch.cat([self.L, self.fake_color], dim=1)
+         fake_preds = self.net_D(fake_image.detach()) # detach: D's update must not backprop into G
+         self.loss_D_fake = self.GANcriterion(fake_preds, False)
+         real_image = torch.cat([self.L, self.ab], dim=1)
+         real_preds = self.net_D(real_image)
+         self.loss_D_real = self.GANcriterion(real_preds, True)
+         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+         self.loss_D.backward()
+
+     def backward_G(self):
+         fake_image = torch.cat([self.L, self.fake_color], dim=1)
+         fake_preds = self.net_D(fake_image)
+         self.loss_G_GAN = self.GANcriterion(fake_preds, True)
+         self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
+         self.loss_G = self.loss_G_GAN + self.loss_G_L1
+         self.loss_G.backward()
+
+     def optimize(self):
+         # One training step: update D on a real/fake pair, then update G with GAN + L1 losses
+         self.forward()
+         self.net_D.train()
+         self.set_requires_grad(self.net_D, True)
+         self.opt_D.zero_grad()
+         self.backward_D()
+         self.opt_D.step()
+
+         self.net_G.train()
+         self.set_requires_grad(self.net_D, False)
+         self.opt_G.zero_grad()
+         self.backward_G()
+         self.opt_G.step()
+
+ class AverageMeter:
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.count, self.avg, self.sum = [0.] * 3
+
+     def update(self, val, count=1):
+         self.count += count
+         self.sum += count * val
+         self.avg = self.sum / self.count
+
+ def create_loss_meters():
+     loss_D_fake = AverageMeter()
+     loss_D_real = AverageMeter()
+     loss_D = AverageMeter()
+     loss_G_GAN = AverageMeter()
+     loss_G_L1 = AverageMeter()
+     loss_G = AverageMeter()
+
+     return {'loss_D_fake': loss_D_fake,
+             'loss_D_real': loss_D_real,
+             'loss_D': loss_D,
+             'loss_G_GAN': loss_G_GAN,
+             'loss_G_L1': loss_G_L1,
+             'loss_G': loss_G}
+
+ def update_losses(model, loss_meter_dict, count):
+     for loss_name, loss_meter in loss_meter_dict.items():
+         loss = getattr(model, loss_name)
+         loss_meter.update(loss.item(), count=count)
+
+ def lab_to_rgb(L, ab):
+     """
+     Takes a batch of images
+     """
+
+     L = (L + 1.) * 50.
+     ab = ab * 110.
+     Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
+     rgb_imgs = []
+     for img in Lab:
+         img_rgb = lab2rgb(img)
+         rgb_imgs.append(img_rgb)
+     return np.stack(rgb_imgs, axis=0)
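+ # This undoes the dataset normalization (L: [-1, 1] -> [0, 100], ab: [-1, 1] -> [-110, 110]);
+ # lab2rgb then returns float arrays in [0, 1].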
+
+ def visualize(model, data, dims):
+     model.net_G.eval()
+     with torch.no_grad():
+         model.setup_input(data)
+         model.forward()
+     model.net_G.train()
+     fake_color = model.fake_color.detach()
+     real_color = model.ab
+     L = model.L
+     fake_imgs = lab_to_rgb(L, fake_color)
+     real_imgs = lab_to_rgb(L, real_color)
+     for i in range(1): # a single uploaded image -> batch of one
+         # resize the colorized output back to the original (height, width)
+         img = cv.resize(fake_imgs[i], dsize=(dims[1], dims[0]), interpolation=cv.INTER_CUBIC)
+         st.image(img, caption="Output image", use_column_width='auto', clamp=True)
+
+ def log_results(loss_meter_dict):
+     for loss_name, loss_meter in loss_meter_dict.items():
+         print(f"{loss_name}: {loss_meter.avg:.5f}")
+
+ # pip install fastai==2.4
+ from fastai.vision.learner import create_body
+ from torchvision.models.resnet import resnet18
+ from fastai.vision.models.unet import DynamicUnet
+
+ def build_res_unet(n_input=1, n_output=2, size=256):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     # fastai 2.4's create_body expects the architecture callable (resnet18), not an instance
+     body = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
+     net_G = DynamicUnet(body, n_output, (size, size)).to(device)
+     return net_G
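+ # A ResNet-18 encoder (cut before its pooling/head) wrapped by fastai's DynamicUnet decoder,
+ # mapping a 1-channel L input to a 2-channel ab output at 256x256.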
+
+ net_G = build_res_unet(n_input=1, n_output=2, size=256)
+ net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
+ model = MainModel(net_G=net_G)
+ model.load_state_dict(torch.load("final_model_weights.pt", map_location=device))
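+ # Both checkpoints ship with this repo as Git LFS files (the .pt entries in this commit);
+ # map_location=device lets CUDA-trained weights load on a CPU-only host.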
+
+ class MyDataset(torch.utils.data.Dataset):
+     def __init__(self, img_list):
+         super().__init__()
+         self.img_list = img_list
+         self.augmentations = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
+
+     def __len__(self):
+         return len(self.img_list)
+
+     def __getitem__(self, idx):
+         img = self.img_list[idx]
+         img = self.augmentations(img)
+         img = np.array(img)
+         img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b
+         img_lab = transforms.ToTensor()(img_lab)
+         L = img_lab[[0], ...] / 50. - 1. # Between -1 and 1
+         ab = img_lab[[1, 2], ...] / 110.
+         return {'L': L, 'ab': ab}
+
+ def make_dataloaders2(batch_size=16, n_workers=4, pin_memory=True, **kwargs): # same as make_dataloaders, but over in-memory PIL images
+     dataset = MyDataset(**kwargs)
+     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
+                             pin_memory=pin_memory)
+     return dataloader
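+ # The app below calls this with a single uploaded image, e.g.
+ #   test_dl = make_dataloaders2(img_list=[im]), yielding one batch of size 1.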
+
+ file_up = st.file_uploader("Upload a jpg image", type="jpg")
+ if file_up is not None:
+     im = Image.open(file_up).convert("RGB") # rgb2lab needs three channels
+     st.text(body=f"Size of uploaded image {im.size}") # PIL reports (width, height)
+     dims = (im.height, im.width) # visualize()/cv.resize expect (height, width)
+     st.image(im, caption="Uploaded Image.", use_column_width='auto')
+     test_dl = make_dataloaders2(img_list=[im])
+     for data in test_dl:
+         model.setup_input(data)
+         # inference only: visualize() runs the forward pass under no_grad,
+         # so no optimize() (training step) is needed at serve time
+         visualize(model, data, dims)
final_model_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90a13caa255b03eebf2c1372ea3c9a019539ef0cccd0f32e8cf2f33442d9fe29
+ size 135592356
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ streamlit
+ Pillow
+ glob2
+ numpy
+ torch
+ tqdm
+ matplotlib==3.2.2
+ matplotlib-venn==0.11.7
+ scikit-image
+ torchvision
+ torchsummary
+ fastai==2.4
+ fastcore
+ fastdownload
+ fastdtw
+ fastjsonschema
+ fastprogress
+ fastrlock
+ opencv-contrib-python==4.6.0.66
+ opencv-python==4.6.0.66
+ opencv-python-headless==4.6.0.66
res18-unet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56efe72585e8e0ff83d8a6e9900a40e95a422a5d2556a2db2b9b6d4faae0f7f8
+ size 124507223