HaoFeng2019 committed
Commit 532251c
Parent: 059ec7f

Upload 8 files

Files changed (6):
  1. GeoTr.py +2 -2
  2. app.py +15 -15
  3. demo.py +12 -12
  4. position_encoding.py +1 -1
  5. requirements.txt +1 -1
  6. seg.py +3 -2
GeoTr.py CHANGED
@@ -107,7 +107,7 @@ class TransDecoder(nn.Module):
         self.position_embedding = build_position_encoding(hidden_dim)
 
     def forward(self, imgf, query_embed):
-        pos = self.position_embedding(torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda())  # torch.Size([1, 128, 36, 36])
+        pos = self.position_embedding(torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool())  #.cuda())  # torch.Size([1, 128, 36, 36])
 
         bs, c, h, w = imgf.shape
         imgf = imgf.flatten(2).permute(2, 0, 1)
@@ -129,7 +129,7 @@ class TransEncoder(nn.Module):
         self.position_embedding = build_position_encoding(hidden_dim)
 
     def forward(self, imgf):
-        pos = self.position_embedding(torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool().cuda())  # torch.Size([1, 128, 36, 36])
+        pos = self.position_embedding(torch.ones(imgf.shape[0], imgf.shape[2], imgf.shape[3]).bool())  #.cuda())  # torch.Size([1, 128, 36, 36])
         bs, c, h, w = imgf.shape
         imgf = imgf.flatten(2).permute(2, 0, 1)
         pos = pos.flatten(2).permute(2, 0, 1)
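The GeoTr.py change keeps the boolean mask for the positional encoding on the CPU instead of forcing it onto a GPU. Below is a minimal sketch of a device-agnostic alternative, assuming `imgf` is the [B, C, H, W] feature map passed to forward(); the helper `make_pos_mask` is illustrative and not part of the repository.

```python
import torch

def make_pos_mask(imgf: torch.Tensor) -> torch.Tensor:
    # Build the all-ones mask on whatever device imgf already lives on,
    # so the same code runs unchanged on CPU and GPU.
    b, _, h, w = imgf.shape
    return torch.ones(b, h, w, dtype=torch.bool, device=imgf.device)

# e.g. inside forward(): pos = self.position_embedding(make_pos_mask(imgf))
```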
app.py CHANGED
@@ -11,10 +11,10 @@ import torch.nn.functional as F
 import skimage.io as io
 import numpy as np
 import cv2
-#import glob
+import glob
 import os
 from PIL import Image
-#import argparse
+import argparse
 import warnings
 warnings.filterwarnings('ignore')
 
@@ -47,10 +47,10 @@ def reload_model(model, path=""):
         return model
     else:
         model_dict = model.state_dict()
-        pretrained_dict = torch.load(path, map_location='cuda:0')
-        print(len(pretrained_dict.keys()))
+        pretrained_dict = torch.load(path, map_location='cpu')
+        #print(len(pretrained_dict.keys()))
         pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
-        print(len(pretrained_dict.keys()))
+        #print(len(pretrained_dict.keys()))
         model_dict.update(pretrained_dict)
         model.load_state_dict(model_dict)
 
@@ -62,10 +62,10 @@ def reload_segmodel(model, path=""):
         return model
     else:
         model_dict = model.state_dict()
-        pretrained_dict = torch.load(path, map_location='cuda:0')
-        print(len(pretrained_dict.keys()))
+        pretrained_dict = torch.load(path, map_location='cpu')
+        #print(len(pretrained_dict.keys()))
         pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
-        print(len(pretrained_dict.keys()))
+        #print(len(pretrained_dict.keys()))
         model_dict.update(pretrained_dict)
         model.load_state_dict(model_dict)
 
@@ -81,13 +81,13 @@ def rec(opt):
     if not os.path.exists(opt.isave_path):  # create save path
         os.mkdir(opt.isave_path)
 
-    GeoTr_Seg_model = GeoTr_Seg().cuda()
+    GeoTr_Seg_model = GeoTr_Seg()#.cuda()
     # reload segmentation model
     reload_segmodel(GeoTr_Seg_model.msk, opt.Seg_path)
     # reload geometric unwarping model
    reload_model(GeoTr_Seg_model.GeoTr, opt.GeoTr_path)
 
-    IllTr_model = IllTr().cuda()
+    IllTr_model = IllTr()#.cuda()
     # reload illumination rectification model
     reload_model(IllTr_model, opt.IllTr_path)
 
@@ -107,7 +107,7 @@ def rec(opt):
 
     with torch.no_grad():
         # geometric unwarping
-        bm = GeoTr_Seg_model(im.cuda())
+        bm = GeoTr_Seg_model(im)
         bm = bm.cpu()
         bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
         bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
@@ -132,11 +132,11 @@ def rec(opt):
 
 
 def process_image(input_image):
-    GeoTr_Seg_model = GeoTr_Seg().cuda()
+    GeoTr_Seg_model = GeoTr_Seg()#.cuda()
     reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
     reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
 
-    IllTr_model = IllTr().cuda()
+    IllTr_model = IllTr()#.cuda()
     reload_model(IllTr_model, './model_pretrained/illtr.pth')
 
     GeoTr_Seg_model.eval()
@@ -149,7 +149,7 @@ def process_image(input_image):
     im = torch.from_numpy(im).float().unsqueeze(0)
 
     with torch.no_grad():
-        bm = GeoTr_Seg_model(im.cuda())
+        bm = GeoTr_Seg_model(im)
         bm = bm.cpu()
         bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
         bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
@@ -173,6 +173,6 @@ input_image = gr.inputs.Image()
 output_image = gr.outputs.Image(type='pil')
 
 
-iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="Image Correction")
+iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr")
 iface.launch()
 
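app.py now loads both checkpoints with map_location='cpu' and builds the models and inputs without .cuda(), so the demo runs on CPU-only hardware. A minimal sketch of a device-aware variant, assuming GPU support should still be used when available; the names `device` and `load_pretrained` are illustrative and not part of the commit.

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_pretrained(model, path, prefix_len=7):
    # Strip a fixed-length key prefix (e.g. 'module.' from DataParallel checkpoints)
    # and copy only the entries that exist in the target model.
    model_dict = model.state_dict()
    pretrained = torch.load(path, map_location=device)
    pretrained = {k[prefix_len:]: v for k, v in pretrained.items() if k[prefix_len:] in model_dict}
    model_dict.update(pretrained)
    model.load_state_dict(model_dict)
    return model

# e.g. GeoTr_Seg_model = GeoTr_Seg().to(device)
#      load_pretrained(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
#      bm = GeoTr_Seg_model(im.to(device))
```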
 
demo.py CHANGED
@@ -47,10 +47,10 @@ def reload_model(model, path=""):
         return model
     else:
         model_dict = model.state_dict()
-        pretrained_dict = torch.load(path, map_location='cuda:0')
-        print(len(pretrained_dict.keys()))
+        pretrained_dict = torch.load(path, map_location='cpu')
+        #print(len(pretrained_dict.keys()))
         pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
-        print(len(pretrained_dict.keys()))
+        #print(len(pretrained_dict.keys()))
         model_dict.update(pretrained_dict)
         model.load_state_dict(model_dict)
 
@@ -62,10 +62,10 @@ def reload_segmodel(model, path=""):
         return model
     else:
         model_dict = model.state_dict()
-        pretrained_dict = torch.load(path, map_location='cuda:0')
-        print(len(pretrained_dict.keys()))
+        pretrained_dict = torch.load(path, map_location='cpu')
+        #print(len(pretrained_dict.keys()))
         pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
-        print(len(pretrained_dict.keys()))
+        #print(len(pretrained_dict.keys()))
         model_dict.update(pretrained_dict)
         model.load_state_dict(model_dict)
 
@@ -81,13 +81,13 @@ def rec(opt):
     if not os.path.exists(opt.isave_path):  # create save path
         os.mkdir(opt.isave_path)
 
-    GeoTr_Seg_model = GeoTr_Seg().cuda()
+    GeoTr_Seg_model = GeoTr_Seg()#.cuda()
     # reload segmentation model
     reload_segmodel(GeoTr_Seg_model.msk, opt.Seg_path)
     # reload geometric unwarping model
     reload_model(GeoTr_Seg_model.GeoTr, opt.GeoTr_path)
 
-    IllTr_model = IllTr().cuda()
+    IllTr_model = IllTr()#.cuda()
     # reload illumination rectification model
     reload_model(IllTr_model, opt.IllTr_path)
 
@@ -107,7 +107,7 @@ def rec(opt):
 
     with torch.no_grad():
         # geometric unwarping
-        bm = GeoTr_Seg_model(im.cuda())
+        bm = GeoTr_Seg_model(im)
         bm = bm.cpu()
         bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
         bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
@@ -132,11 +132,11 @@ def rec(opt):
 
 
 def process_image(input_image):
-    GeoTr_Seg_model = GeoTr_Seg().cuda()
+    GeoTr_Seg_model = GeoTr_Seg()#.cuda()
     reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
     reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
 
-    IllTr_model = IllTr().cuda()
+    IllTr_model = IllTr()#.cuda()
     reload_model(IllTr_model, './model_pretrained/illtr.pth')
 
     GeoTr_Seg_model.eval()
@@ -149,7 +149,7 @@ def process_image(input_image):
     im = torch.from_numpy(im).float().unsqueeze(0)
 
     with torch.no_grad():
-        bm = GeoTr_Seg_model(im.cuda())
+        bm = GeoTr_Seg_model(im)
         bm = bm.cpu()
         bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
         bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
position_encoding.py CHANGED
@@ -58,7 +58,7 @@ class PositionEmbeddingSine(nn.Module):
         y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32).cuda()
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32)#.cuda()
         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
 
         pos_x = x_embed[:, :, :, None] / dim_t
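Here the frequency vector dim_t is simply created on the CPU on every call. A sketch of an alternative, assumed rather than taken from the repository, that precomputes the terms once and registers them as a buffer so model.to(device) moves them along with the rest of the embedding:

```python
import torch
import torch.nn as nn

class SineFreqs(nn.Module):
    def __init__(self, num_pos_feats: int = 64, temperature: float = 10000.0):
        super().__init__()
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # Buffers move with .to(device)/.cuda() but are not trained.
        self.register_buffer('dim_t', dim_t)

    def forward(self, x_embed: torch.Tensor) -> torch.Tensor:
        # x_embed: [..., H, W]; broadcast the frequencies over a new last dim.
        return x_embed[..., None] / self.dim_t
```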
requirements.txt CHANGED
@@ -2,7 +2,7 @@ gradio
 numpy
 opencv_python
 Pillow
-skimage
+scikit_image
 timm
 torch
 torchvision
seg.py CHANGED
@@ -40,6 +40,7 @@ class REBNCONV(nn.Module):
         self.relu_s1 = nn.ReLU(inplace=True)
 
     def forward(self, x):
+        #print(x.device)
         hx = x
         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
 
@@ -559,9 +560,9 @@ def get_parameter_number(net):
 
 
 if __name__ == '__main__':
-    net = U2NET(4, 1).cuda()
+    net = U2NET(4, 1)#.cuda()
     print(get_parameter_number(net))  # 69090500; 69442032 after adding attention
     with torch.no_grad():
-        inputs = torch.zeros(1, 3, 256, 256).cuda()
+        inputs = torch.zeros(1, 3, 256, 256)#.cuda()
         outs = net(inputs)
         print(outs[0].shape)  # torch.Size([2, 3, 256, 256]) torch.Size([2, 2, 256, 256])