hylee committed on
Commit
13a8cfb
1 Parent(s): b9f4814
Files changed (1) hide show
  1. modnet.py +93 -0
modnet.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import argparse
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ import onnx
8
+ import onnxruntime
9
+
10
+
11
class ModNet:
    """Portrait matting backed by an ONNX model run through onnxruntime.

    ``segment`` reads an image from disk, predicts an alpha matte with the
    loaded model and returns the image composited onto a white background.
    """

    def __init__(self, model_path):
        """Create the ONNX Runtime inference session for *model_path*."""
        self.session = onnxruntime.InferenceSession(model_path, None)

    def get_scale_factor(self, im_h, im_w, ref_size):
        """Return ``(x_scale_factor, y_scale_factor)`` used to resize the input.

        When both sides are smaller than *ref_size*, or both are larger, the
        shorter side is rescaled to *ref_size* and the other side follows the
        aspect ratio; otherwise the original size is kept.  The target size
        is then snapped down to a multiple of 32 (never below 32).

        Args:
            im_h: Height of the source image in pixels.
            im_w: Width of the source image in pixels.
            ref_size: Reference length for the shorter side.

        Returns:
            A tuple ``(target_width / im_w, target_height / im_h)``.
        """
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            else:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh, im_rw = im_h, im_w

        # Snap down to a multiple of 32 (presumably the network's total
        # downsampling stride -- confirm against the model).  Clamp to 32 so
        # that a side shorter than 32 px does not collapse to a zero scale
        # factor, which cv2.resize cannot handle.
        im_rw = max(32, im_rw - im_rw % 32)
        im_rh = max(32, im_rh - im_rh % 32)

        return im_rw / im_w, im_rh / im_h

    def segment(self, image_path, ref_size=512):
        """Predict a matte for the image and composite it onto white.

        Args:
            image_path: Path of the image file to read with ``cv2.imread``.
            ref_size: Reference size forwarded to ``get_scale_factor``.

        Returns:
            A float ``numpy`` array of shape ``(height, width, 3)`` in RGB
            channel order and 0-255 value range: the source pixels weighted
            by the predicted matte, blended with a white (255) background.

        Raises:
            cv2.error: If the image cannot be read from *image_path*.
        """
        # With its default flag cv2.imread always yields a 3-channel BGR
        # array, or None when the file is missing or cannot be decoded.
        bgr = cv2.imread(image_path)
        if bgr is None:
            # Same exception type cvtColor would raise on an empty input,
            # but with a message that names the offending path.
            raise cv2.error("could not read image: {!r}".format(image_path))
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        im_h, im_w = rgb.shape[:2]
        x, y = self.get_scale_factor(im_h, im_w, ref_size)

        # Normalize pixel values to the range [-1, 1] for the network input;
        # the untouched `rgb` array is kept for the final composite.
        net_input = (rgb - 127.5) / 127.5
        net_input = cv2.resize(net_input, None, fx=x, fy=y,
                               interpolation=cv2.INTER_AREA)

        # (H, W, C) -> (1, C, H, W)
        net_input = np.transpose(net_input, (2, 0, 1))
        net_input = np.expand_dims(net_input, axis=0).astype('float32')

        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        result = self.session.run([output_name], {input_name: net_input})

        # Bring the predicted matte back to the source resolution.
        matte = (np.squeeze(result[0]) * 255).astype('uint8')
        matte = cv2.resize(matte, dsize=(im_w, im_h),
                           interpolation=cv2.INTER_AREA)

        # Blend using the un-normalized RGB pixels so that the foreground and
        # the white background are both expressed on the 0-255 scale.  The
        # trailing axis lets the matte broadcast over the three channels.
        alpha = matte[:, :, None] / 255
        foreground = rgb * alpha + 255 * (1 - alpha)

        return foreground
92
+
93
+