Hu commited on
Commit
2f110b2
1 Parent(s): ef48324

initial commit

Browse files
Files changed (5) hide show
  1. LR_image.png +0 -0
  2. SRCNNmodel_trained.pt +3 -0
  3. barbara.png +0 -0
  4. demo.py +54 -0
  5. model.py +80 -0
LR_image.png ADDED
SRCNNmodel_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1c33c257abf0eef36eb73c60fbc1863ebf7612cefb07c6a7aea85b283b03ddb
3
+ size 34455
barbara.png ADDED
demo.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from model import SRCNNModel, pred_SRCNN
7
+ from PIL import Image
8
+
9
+
10
+ title = "Super Resolution with CNN"
11
+ description = """
12
+
13
+ Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!
14
+
15
+ CNN output on the left, bicubic interpolation output on the right.
16
+
17
+
18
+ """
19
+
20
+ article = "Check out the origianl [paper](https://arxiv.org/abs/1501.00092) proposed by Dong *et al*."
21
+
22
+ # load model
23
+ print("Loading SRCNN model...")
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+
26
+ model = SRCNNModel().to(device)
27
+ model.load_state_dict(torch.load('SRCNNmodel_trained.pt'))
28
+ model.eval()
29
+ print("SRCNN model loaded!")
30
+
31
+ def image_grid(imgs, rows, cols):
32
+ '''
33
+ imgs:list of PILImage
34
+ '''
35
+ assert len(imgs) == rows*cols
36
+
37
+ w, h = imgs[0].size
38
+ grid = Image.new('RGB', size=(cols*w, rows*h))
39
+ grid_w, grid_h = grid.size
40
+
41
+ for i, img in enumerate(imgs):
42
+ grid.paste(img, box=(i%cols*w, i//cols*h))
43
+ return grid
44
+
45
+ def sepia(image_path):
46
+ # gradio open image as np array
47
+ image = Image.fromarray(image_path,mode='RGB')
48
+ out_final,image_bicubic,image = pred_SRCNN(model=model,image=image,device=device)
49
+ grid = image_grid([out_final,image_bicubic],1,2)
50
+ return grid
51
+
52
+ demo = gr.Interface(fn = sepia, inputs=gr.Image(shape=(200, 200)), outputs="image",title=title,description = description,article = article,examples=['LR_image.png','barbara.png'])
53
+
54
+ demo.launch(share=True)
model.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ from torchvision.transforms import transforms
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ class SRCNNModel(nn.Module):
10
+ def __init__(self):
11
+ super(SRCNNModel, self).__init__()
12
+ self.conv1=nn.Conv2d(1,64,9,padding=4)
13
+ self.conv2=nn.Conv2d(64,32,1,padding=0)
14
+ self.conv3=nn.Conv2d(32,1,5,padding=2)
15
+
16
+ def forward(self,x):
17
+ out = F.relu(self.conv1(x))
18
+ out = F.relu(self.conv2(out))
19
+ out = self.conv3(out)
20
+ return out
21
+
22
+ def pred_SRCNN(model,image,device,scale_factor=2):
23
+ """
24
+ model: SRCNN model
25
+ image: low resolution image PILLOW image
26
+ scale_factor: scale factor for resolution
27
+ device: cuda or cpu
28
+ """
29
+ model.to(device)
30
+ model.eval()
31
+
32
+ # open image
33
+ # image = Image.open(image_path)
34
+ # split channels
35
+ y, cb, cr= image.convert('YCbCr').split()
36
+ # size will be used in image transform
37
+ original_size = y.size
38
+
39
+ # bicubic interpolate it to the original size
40
+ y_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(y)
41
+ cb_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(cb)
42
+ cr_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(cr)
43
+ # turn it into tensor and add batch dimension
44
+ y_bicubic = transforms.ToTensor()(y_bicubic).to(device).unsqueeze(0)
45
+ # get the y channel SRCNN prediction
46
+ y_pred = model(y_bicubic)
47
+ # convert it to numpy image
48
+ y_pred = y_pred[0].cpu().detach().numpy()
49
+
50
+ # convert it into regular image pixel values
51
+ y_pred = y_pred*255
52
+ y_pred.clip(0,255)
53
+ # conver y channel from array to PIL image format for merging
54
+ y_pred_PIL = Image.fromarray(np.uint8(y_pred[0]),mode='L')
55
+ # merge the SRCNN y channel with cb cr channels
56
+ out_final = Image.merge('YCbCr',[y_pred_PIL,cb_bicubic,cr_bicubic]).convert('RGB')
57
+
58
+ image_bicubic = transforms.Resize((original_size[1]*scale_factor,original_size[0]*scale_factor),interpolation=Image.BICUBIC)(image)
59
+ return out_final,image_bicubic,image
60
+
61
+
62
+ def main():
63
+ print("Loading SRCNN model...")
64
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
+
66
+ model = SRCNNModel().to(device)
67
+ model.load_state_dict(torch.load('SRCNNmodel_trained.pt'))
68
+ model.eval()
69
+ print("SRCNN model loaded!")
70
+
71
+ image_path = "LR_image.png"
72
+
73
+ out_final,image_bicubic,image = pred_SRCNN(model=model,image_path=image_path,device=device)
74
+ image.show()
75
+ out_final.show()
76
+ image_bicubic.show()
77
+
78
+
79
+ if __name__=="__main__":
80
+ main()