mellow commited on
Commit
fed222b
1 Parent(s): 3d84e38

premiere commit

Browse files

spaces core pour une demo

Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import numpy as np
4
+ import onnxruntime
5
+ import requests
6
+ from huggingface_hub import hf_hub_download
7
+ from PIL import Image
8
+
9
+
10
+ # Get x_scale_factor & y_scale_factor to resize image
11
+ def get_scale_factor(im_h, im_w, ref_size=512):
12
+
13
+ if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
14
+ if im_w >= im_h:
15
+ im_rh = ref_size
16
+ im_rw = int(im_w / im_h * ref_size)
17
+ elif im_w < im_h:
18
+ im_rw = ref_size
19
+ im_rh = int(im_h / im_w * ref_size)
20
+ else:
21
+ im_rh = im_h
22
+ im_rw = im_w
23
+
24
+ im_rw = im_rw - im_rw % 32
25
+ im_rh = im_rh - im_rh % 32
26
+
27
+ x_scale_factor = im_rw / im_w
28
+ y_scale_factor = im_rh / im_h
29
+
30
+ return x_scale_factor, y_scale_factor
31
+
32
+
33
+ MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
34
+
35
+
36
+ def main(image_path, threshold):
37
+
38
+ # read image
39
+ im = cv2.imread(image_path)
40
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
41
+
42
+ # unify image channels to 3
43
+ if len(im.shape) == 2:
44
+ im = im[:, :, None]
45
+ if im.shape[2] == 1:
46
+ im = np.repeat(im, 3, axis=2)
47
+ elif im.shape[2] == 4:
48
+ im = im[:, :, 0:3]
49
+
50
+ # normalize values to scale it between -1 to 1
51
+ im = (im - 127.5) / 127.5
52
+
53
+ im_h, im_w, im_c = im.shape
54
+ x, y = get_scale_factor(im_h, im_w)
55
+
56
+ # resize image
57
+ im = cv2.resize(im, None, fx=x, fy=y, interpolation=cv2.INTER_AREA)
58
+
59
+ # prepare input shape
60
+ im = np.transpose(im)
61
+ im = np.swapaxes(im, 1, 2)
62
+ im = np.expand_dims(im, axis=0).astype('float32')
63
+
64
+ # Initialize session and get prediction
65
+ session = onnxruntime.InferenceSession(MODEL_PATH, None)
66
+ input_name = session.get_inputs()[0].name
67
+ output_name = session.get_outputs()[0].name
68
+ result = session.run([output_name], {input_name: im})
69
+
70
+ # refine matte
71
+ matte = (np.squeeze(result[0]) * 255).astype('uint8')
72
+ matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation=cv2.INTER_AREA)
73
+
74
+ # HACK - Could probably just convert this to PIL instead of writing
75
+ cv2.imwrite('out.png', matte)
76
+
77
+ image = Image.open(image_path)
78
+ matte = Image.open('out.png')
79
+
80
+ # obtain predicted foreground
81
+ image = np.asarray(image)
82
+ if len(image.shape) == 2:
83
+ image = image[:, :, None]
84
+ if image.shape[2] == 1:
85
+ image = np.repeat(image, 3, axis=2)
86
+ elif image.shape[2] == 4:
87
+ image = image[:, :, 0:3]
88
+
89
+ b, g, r = cv2.split(image)
90
+
91
+ mask = np.asarray(matte)
92
+ a = np.ones(mask.shape, dtype='uint8') * 255
93
+ alpha_im = cv2.merge([b, g, r, a], 4)
94
+ bg = np.zeros(alpha_im.shape)
95
+ new_mask = np.stack([mask, mask, mask, mask], axis=2)
96
+ foreground = np.where(new_mask > threshold, alpha_im, bg).astype(np.uint8)
97
+
98
+ return Image.fromarray(foreground)
99
+
100
+
101
+ title = "Groupe 12 background remover"
102
+ description = "Groupe 12 background remover est un modèle capable de supprimer l'arrière-plan d'une image donnée. Pour l'utiliser, il suffit de télécharger votre image, ou de cliquer sur l'un des exemples pour les charger. Pour en savoir plus, cliquez sur les liens ci-dessous.."
103
+ article = "<div style='text-align: center;'> <a href='https://github.com/ZHKKKe/MODNet' target='_blank'>Github Repo</a> | <a href='https://arxiv.org/abs/2011.11961' target='_blank'>MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition</a> </div>"
104
+
105
+ url = "https://huggingface.co/datasets/nateraw/background-remover-files/resolve/main/twitter_profile_pic.jpeg"
106
+ image = Image.open(requests.get(url, stream=True).raw)
107
+ image.save('twitter_profile_pic.jpg')
108
+
109
+ url = "https://upload.wikimedia.org/wikipedia/commons/8/8d/President_Barack_Obama.jpg"
110
+ image = Image.open(requests.get(url, stream=True).raw)
111
+ image.save('obama.jpg')
112
+
113
+ interface = gr.Interface(
114
+ fn=main,
115
+ inputs=[
116
+ gr.inputs.Image(type='filepath'),
117
+ gr.inputs.Slider(minimum=0, maximum=250, default=100, step=5, label='Mask Cutoff Threshold'),
118
+ ],
119
+ outputs='image',
120
+ examples=[['twitter_profile_pic.jpg', 120], ['obama.jpg', 155]],
121
+ title=title,
122
+ description=description,
123
+ article=article,
124
+ )
125
+
126
+ if __name__ == '__main__':
127
+ interface.launch(debug=True)
128
+