premiere commit
Browse filesspaces core pour une demo
app.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
import onnxruntime
|
5 |
+
import requests
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
# Get x_scale_factor & y_scale_factor to resize image
|
11 |
+
def get_scale_factor(im_h, im_w, ref_size=512):
|
12 |
+
|
13 |
+
if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
|
14 |
+
if im_w >= im_h:
|
15 |
+
im_rh = ref_size
|
16 |
+
im_rw = int(im_w / im_h * ref_size)
|
17 |
+
elif im_w < im_h:
|
18 |
+
im_rw = ref_size
|
19 |
+
im_rh = int(im_h / im_w * ref_size)
|
20 |
+
else:
|
21 |
+
im_rh = im_h
|
22 |
+
im_rw = im_w
|
23 |
+
|
24 |
+
im_rw = im_rw - im_rw % 32
|
25 |
+
im_rh = im_rh - im_rh % 32
|
26 |
+
|
27 |
+
x_scale_factor = im_rw / im_w
|
28 |
+
y_scale_factor = im_rh / im_h
|
29 |
+
|
30 |
+
return x_scale_factor, y_scale_factor
|
31 |
+
|
32 |
+
|
33 |
+
MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
|
34 |
+
|
35 |
+
|
36 |
+
def main(image_path, threshold):
|
37 |
+
|
38 |
+
# read image
|
39 |
+
im = cv2.imread(image_path)
|
40 |
+
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
41 |
+
|
42 |
+
# unify image channels to 3
|
43 |
+
if len(im.shape) == 2:
|
44 |
+
im = im[:, :, None]
|
45 |
+
if im.shape[2] == 1:
|
46 |
+
im = np.repeat(im, 3, axis=2)
|
47 |
+
elif im.shape[2] == 4:
|
48 |
+
im = im[:, :, 0:3]
|
49 |
+
|
50 |
+
# normalize values to scale it between -1 to 1
|
51 |
+
im = (im - 127.5) / 127.5
|
52 |
+
|
53 |
+
im_h, im_w, im_c = im.shape
|
54 |
+
x, y = get_scale_factor(im_h, im_w)
|
55 |
+
|
56 |
+
# resize image
|
57 |
+
im = cv2.resize(im, None, fx=x, fy=y, interpolation=cv2.INTER_AREA)
|
58 |
+
|
59 |
+
# prepare input shape
|
60 |
+
im = np.transpose(im)
|
61 |
+
im = np.swapaxes(im, 1, 2)
|
62 |
+
im = np.expand_dims(im, axis=0).astype('float32')
|
63 |
+
|
64 |
+
# Initialize session and get prediction
|
65 |
+
session = onnxruntime.InferenceSession(MODEL_PATH, None)
|
66 |
+
input_name = session.get_inputs()[0].name
|
67 |
+
output_name = session.get_outputs()[0].name
|
68 |
+
result = session.run([output_name], {input_name: im})
|
69 |
+
|
70 |
+
# refine matte
|
71 |
+
matte = (np.squeeze(result[0]) * 255).astype('uint8')
|
72 |
+
matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation=cv2.INTER_AREA)
|
73 |
+
|
74 |
+
# HACK - Could probably just convert this to PIL instead of writing
|
75 |
+
cv2.imwrite('out.png', matte)
|
76 |
+
|
77 |
+
image = Image.open(image_path)
|
78 |
+
matte = Image.open('out.png')
|
79 |
+
|
80 |
+
# obtain predicted foreground
|
81 |
+
image = np.asarray(image)
|
82 |
+
if len(image.shape) == 2:
|
83 |
+
image = image[:, :, None]
|
84 |
+
if image.shape[2] == 1:
|
85 |
+
image = np.repeat(image, 3, axis=2)
|
86 |
+
elif image.shape[2] == 4:
|
87 |
+
image = image[:, :, 0:3]
|
88 |
+
|
89 |
+
b, g, r = cv2.split(image)
|
90 |
+
|
91 |
+
mask = np.asarray(matte)
|
92 |
+
a = np.ones(mask.shape, dtype='uint8') * 255
|
93 |
+
alpha_im = cv2.merge([b, g, r, a], 4)
|
94 |
+
bg = np.zeros(alpha_im.shape)
|
95 |
+
new_mask = np.stack([mask, mask, mask, mask], axis=2)
|
96 |
+
foreground = np.where(new_mask > threshold, alpha_im, bg).astype(np.uint8)
|
97 |
+
|
98 |
+
return Image.fromarray(foreground)
|
99 |
+
|
100 |
+
|
101 |
+
title = "Groupe 12 background remover"
|
102 |
+
description = "Groupe 12 background remover est un modèle capable de supprimer l'arrière-plan d'une image donnée. Pour l'utiliser, il suffit de télécharger votre image, ou de cliquer sur l'un des exemples pour les charger. Pour en savoir plus, cliquez sur les liens ci-dessous.."
|
103 |
+
article = "<div style='text-align: center;'> <a href='https://github.com/ZHKKKe/MODNet' target='_blank'>Github Repo</a> | <a href='https://arxiv.org/abs/2011.11961' target='_blank'>MODNet: Real-Time Trimap-Free Portrait Matting via Objective Decomposition</a> </div>"
|
104 |
+
|
105 |
+
url = "https://huggingface.co/datasets/nateraw/background-remover-files/resolve/main/twitter_profile_pic.jpeg"
|
106 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
107 |
+
image.save('twitter_profile_pic.jpg')
|
108 |
+
|
109 |
+
url = "https://upload.wikimedia.org/wikipedia/commons/8/8d/President_Barack_Obama.jpg"
|
110 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
111 |
+
image.save('obama.jpg')
|
112 |
+
|
113 |
+
interface = gr.Interface(
|
114 |
+
fn=main,
|
115 |
+
inputs=[
|
116 |
+
gr.inputs.Image(type='filepath'),
|
117 |
+
gr.inputs.Slider(minimum=0, maximum=250, default=100, step=5, label='Mask Cutoff Threshold'),
|
118 |
+
],
|
119 |
+
outputs='image',
|
120 |
+
examples=[['twitter_profile_pic.jpg', 120], ['obama.jpg', 155]],
|
121 |
+
title=title,
|
122 |
+
description=description,
|
123 |
+
article=article,
|
124 |
+
)
|
125 |
+
|
126 |
+
if __name__ == '__main__':
|
127 |
+
interface.launch(debug=True)
|
128 |
+
|