HaoFeng2019 commited on
Commit
4cbf82a
1 Parent(s): 1b147a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -14,11 +14,13 @@ import os
14
  from PIL import Image
15
  import argparse
16
  import warnings
 
17
  warnings.filterwarnings('ignore')
18
 
19
  import gradio as gr
20
 
21
- example_img_list = ['51_1 copy.png','48_2 copy.png','25.jpg']
 
22
 
23
  def reload_model(model, path=""):
24
  if not bool(path):
@@ -49,51 +51,54 @@ def reload_segmodel(model, path=""):
49
 
50
  return model
51
 
 
52
  class GeoTr_Seg(nn.Module):
53
  def __init__(self):
54
  super(GeoTr_Seg, self).__init__()
55
  self.msk = U2NETP(3, 1)
56
  self.GeoTr = GeoTr(num_attn_layers=6)
57
-
58
  def forward(self, x):
59
- msk, _1,_2,_3,_4,_5,_6 = self.msk(x)
60
  msk = (msk > 0.5).float()
61
  x = msk * x
62
 
63
  bm = self.GeoTr(x)
64
  bm = (2 * (bm / 286.8) - 1) * 0.99
65
-
66
  return bm
67
 
68
 
69
  # Initialize models
70
  GeoTr_Seg_model = GeoTr_Seg()
71
- #IllTr_model = IllTr()
72
 
73
  # Load models only once
74
  reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
75
  reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
76
- #reload_model(IllTr_model, './model_pretrained/illtr.pth')
77
 
78
  # Compile models (assuming PyTorch 2.0)
79
  GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)
80
- #IllTr_model = torch.compile(IllTr_model)
 
 
81
 
82
  def process_image(input_image):
83
  GeoTr_Seg_model.eval()
 
84
 
85
  im_ori = np.array(input_image)[:, :, :3] / 255.
86
  h, w, _ = im_ori.shape
87
- new_height = int(h * (288 / w))
88
- im = cv2.resize(im_ori, (288, new_height))
89
  im = im.transpose(2, 0, 1)
90
  im = torch.from_numpy(im).float().unsqueeze(0)
91
 
92
  with torch.no_grad():
93
  bm = GeoTr_Seg_model(im)
94
  bm = bm.cpu()
95
- bm0 = cv2.resize(bm[0, 0].numpy(), (288, new_height))
96
- bm1 = cv2.resize(bm[0, 1].numpy(), (288, new_height))
97
  bm0 = cv2.blur(bm0, (3, 3))
98
  bm1 = cv2.blur(bm1, (3, 3))
99
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
@@ -114,5 +119,6 @@ def process_image(input_image):
114
  input_image = gr.inputs.Image()
115
  output_image = gr.outputs.Image(type='pil')
116
 
117
- iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr", examples=example_img_list)
118
- iface.launch()
 
 
14
  from PIL import Image
15
  import argparse
16
  import warnings
17
+
18
  warnings.filterwarnings('ignore')
19
 
20
  import gradio as gr
21
 
22
+ example_img_list = ['51_1 copy.png', '48_2 copy.png', '25.jpg']
23
+
24
 
25
  def reload_model(model, path=""):
26
  if not bool(path):
 
51
 
52
  return model
53
 
54
+
55
  class GeoTr_Seg(nn.Module):
56
  def __init__(self):
57
  super(GeoTr_Seg, self).__init__()
58
  self.msk = U2NETP(3, 1)
59
  self.GeoTr = GeoTr(num_attn_layers=6)
60
+
61
  def forward(self, x):
62
+ msk, _1, _2, _3, _4, _5, _6 = self.msk(x)
63
  msk = (msk > 0.5).float()
64
  x = msk * x
65
 
66
  bm = self.GeoTr(x)
67
  bm = (2 * (bm / 286.8) - 1) * 0.99
68
+
69
  return bm
70
 
71
 
72
  # Initialize models
73
  GeoTr_Seg_model = GeoTr_Seg()
74
+ # IllTr_model = IllTr()
75
 
76
  # Load models only once
77
  reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
78
  reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
79
+ # reload_model(IllTr_model, './model_pretrained/illtr.pth')
80
 
81
  # Compile models (assuming PyTorch 2.0)
82
  GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)
83
+
84
+
85
+ # IllTr_model = torch.compile(IllTr_model)
86
 
87
  def process_image(input_image):
88
  GeoTr_Seg_model.eval()
89
+ # IllTr_model.eval()
90
 
91
  im_ori = np.array(input_image)[:, :, :3] / 255.
92
  h, w, _ = im_ori.shape
93
+ im = cv2.resize(im_ori, (288, 288))
 
94
  im = im.transpose(2, 0, 1)
95
  im = torch.from_numpy(im).float().unsqueeze(0)
96
 
97
  with torch.no_grad():
98
  bm = GeoTr_Seg_model(im)
99
  bm = bm.cpu()
100
+ bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
101
+ bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
102
  bm0 = cv2.blur(bm0, (3, 3))
103
  bm1 = cv2.blur(bm1, (3, 3))
104
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
 
119
  input_image = gr.inputs.Image()
120
  output_image = gr.outputs.Image(type='pil')
121
 
122
+ iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr",
123
+ examples=example_img_list)
124
+ iface.launch()