hylee commited on
Commit
b9f4814
1 Parent(s): 090f070
Files changed (2) hide show
  1. app.py +11 -0
  2. requirements.txt +3 -1
app.py CHANGED
@@ -27,6 +27,8 @@ from data_loader import SalObjDataset
27
  from model import U2NET # full size version 173.6 MB
28
  from model import U2NETP # small version u2net 4.7 MB
29
 
 
 
30
 
31
  # normalize the predicted SOD probability map
32
  def normPRED(d):
@@ -58,6 +60,12 @@ def save_output(image_name,pred,d_dir):
58
  return d_dir+'/'+imidx+'.png'
59
 
60
 
 
 
 
 
 
 
61
  # --------- 1. get image path and name ---------
62
  model_name='u2net_portrait'#u2netp
63
 
@@ -82,6 +90,9 @@ net.eval()
82
 
83
 
84
  def process(im):
 
 
 
85
  img_name_list = glob.glob(im.name)
86
  print("Number of images: ", len(img_name_list))
87
  # --------- 2. dataloader ---------
 
27
  from model import U2NET # full size version 173.6 MB
28
  from model import U2NETP # small version u2net 4.7 MB
29
 
30
+ from modnet import ModNet
31
+ import huggingface_hub
32
 
33
  # normalize the predicted SOD probability map
34
  def normPRED(d):
 
60
  return d_dir+'/'+imidx+'.png'
61
 
62
 
63
+
64
+ modnet_path = huggingface_hub.hf_hub_download('hylee/apdrawing_model',
65
+ 'modnet.onnx',
66
+ force_filename='modnet.onnx')
67
+ modnet = ModNet(modnet_path)
68
+
69
  # --------- 1. get image path and name ---------
70
  model_name='u2net_portrait'#u2netp
71
 
 
90
 
91
 
92
  def process(im):
93
+ image = modnet.segment(im.name)
94
+ Image.fromarray(np.uint8(image)).save(im.name)
95
+
96
  img_name_list = glob.glob(im.name)
97
  print("Number of images: ", len(img_name_list))
98
  # --------- 2. dataloader ---------
requirements.txt CHANGED
@@ -3,4 +3,6 @@ scikit-image
3
  torch
4
  torchvision
5
  pillow
6
- opencv-python-headless
 
 
 
3
  torch
4
  torchvision
5
  pillow
6
+ opencv-python-headless
7
+ onnx==1.8.1
8
+ onnxruntime==1.6.0