HaoFeng2019 commited on
Commit
2ddeb02
1 Parent(s): 63e518b

Delete demo.py

Browse files
Files changed (1) hide show
  1. demo.py +0 -178
demo.py DELETED
@@ -1,178 +0,0 @@
1
- #origin
2
-
3
- from seg import U2NETP
4
- from GeoTr import GeoTr
5
- from IllTr import IllTr
6
- from inference_ill import rec_ill
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- import skimage.io as io
12
- import numpy as np
13
- import cv2
14
- import glob
15
- import os
16
- from PIL import Image
17
- import argparse
18
- import warnings
19
- warnings.filterwarnings('ignore')
20
-
21
-
22
-
23
-
24
-
25
- import gradio as gr
26
-
27
-
28
- class GeoTr_Seg(nn.Module):
29
- def __init__(self):
30
- super(GeoTr_Seg, self).__init__()
31
- self.msk = U2NETP(3, 1)
32
- self.GeoTr = GeoTr(num_attn_layers=6)
33
-
34
- def forward(self, x):
35
- msk, _1,_2,_3,_4,_5,_6 = self.msk(x)
36
- msk = (msk > 0.5).float()
37
- x = msk * x
38
-
39
- bm = self.GeoTr(x)
40
- bm = (2 * (bm / 286.8) - 1) * 0.99
41
-
42
- return bm
43
-
44
-
45
- def reload_model(model, path=""):
46
- if not bool(path):
47
- return model
48
- else:
49
- model_dict = model.state_dict()
50
- pretrained_dict = torch.load(path, map_location='cpu')
51
- #print(len(pretrained_dict.keys()))
52
- pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
53
- #print(len(pretrained_dict.keys()))
54
- model_dict.update(pretrained_dict)
55
- model.load_state_dict(model_dict)
56
-
57
- return model
58
-
59
-
60
- def reload_segmodel(model, path=""):
61
- if not bool(path):
62
- return model
63
- else:
64
- model_dict = model.state_dict()
65
- pretrained_dict = torch.load(path, map_location='cpu')
66
- #print(len(pretrained_dict.keys()))
67
- pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
68
- #print(len(pretrained_dict.keys()))
69
- model_dict.update(pretrained_dict)
70
- model.load_state_dict(model_dict)
71
-
72
- return model
73
-
74
-
75
- def rec(opt):
76
- # print(torch.__version__) # 1.5.1
77
- img_list = os.listdir(opt.distorrted_path) # distorted images list
78
-
79
- if not os.path.exists(opt.gsave_path): # create save path
80
- os.mkdir(opt.gsave_path)
81
- if not os.path.exists(opt.isave_path): # create save path
82
- os.mkdir(opt.isave_path)
83
-
84
- GeoTr_Seg_model = GeoTr_Seg()#.cuda()
85
- # reload segmentation model
86
- reload_segmodel(GeoTr_Seg_model.msk, opt.Seg_path)
87
- # reload geometric unwarping model
88
- reload_model(GeoTr_Seg_model.GeoTr, opt.GeoTr_path)
89
-
90
- IllTr_model = IllTr()#.cuda()
91
- # reload illumination rectification model
92
- reload_model(IllTr_model, opt.IllTr_path)
93
-
94
- # To eval mode
95
- GeoTr_Seg_model.eval()
96
- IllTr_model.eval()
97
-
98
- for img_path in img_list:
99
- name = img_path.split('.')[-2] # image name
100
-
101
- img_path = opt.distorrted_path + img_path # read image and to tensor
102
- im_ori = np.array(Image.open(img_path))[:, :, :3] / 255.
103
- h, w, _ = im_ori.shape
104
- im = cv2.resize(im_ori, (288, 288))
105
- im = im.transpose(2, 0, 1)
106
- im = torch.from_numpy(im).float().unsqueeze(0)
107
-
108
- with torch.no_grad():
109
- # geometric unwarping
110
- bm = GeoTr_Seg_model(im)
111
- bm = bm.cpu()
112
- bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow
113
- bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow
114
- bm0 = cv2.blur(bm0, (3, 3))
115
- bm1 = cv2.blur(bm1, (3, 3))
116
- lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
117
-
118
- out = F.grid_sample(torch.from_numpy(im_ori).permute(2,0,1).unsqueeze(0).float(), lbl, align_corners=True)
119
- img_geo = ((out[0]*255).permute(1, 2, 0).numpy())[:,:,::-1].astype(np.uint8)
120
- cv2.imwrite(opt.gsave_path + name + '_geo' + '.png', img_geo) # save
121
-
122
- # illumination rectification
123
- if opt.ill_rec:
124
- ill_savep = opt.isave_path + name + '_ill' + '.png'
125
- rec_ill(IllTr_model, img_geo, saveRecPath=ill_savep)
126
-
127
- print('Done: ', img_path)
128
-
129
-
130
-
131
-
132
-
133
-
134
- def process_image(input_image):
135
- GeoTr_Seg_model = GeoTr_Seg()#.cuda()
136
- reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
137
- reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
138
-
139
- IllTr_model = IllTr()#.cuda()
140
- reload_model(IllTr_model, './model_pretrained/illtr.pth')
141
-
142
- GeoTr_Seg_model.eval()
143
- IllTr_model.eval()
144
-
145
- im_ori = np.array(input_image)[:, :, :3] / 255.
146
- h, w, _ = im_ori.shape
147
- im = cv2.resize(im_ori, (288, 288))
148
- im = im.transpose(2, 0, 1)
149
- im = torch.from_numpy(im).float().unsqueeze(0)
150
-
151
- with torch.no_grad():
152
- bm = GeoTr_Seg_model(im)
153
- bm = bm.cpu()
154
- bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
155
- bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
156
- bm0 = cv2.blur(bm0, (3, 3))
157
- bm1 = cv2.blur(bm1, (3, 3))
158
- lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
159
-
160
- out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
161
- img_geo = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)
162
-
163
- ill_rec=False
164
-
165
- if ill_rec:
166
- img_ill = rec_ill(IllTr_model, img_geo)
167
- return Image.fromarray(img_ill)
168
- else:
169
- return Image.fromarray(img_geo)
170
-
171
- # Define Gradio interface
172
- input_image = gr.inputs.Image()
173
- output_image = gr.outputs.Image(type='pil')
174
-
175
-
176
- iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="Image Correction")
177
- iface.launch(server_port=1234, server_name="0.0.0.0")
178
-