Upload 7 files
update model and v0.1 pipeline
- .gitattributes +13 -0
- app.py +271 -0
- models/matting.pt +3 -0
- models/sod.pt +3 -0
- models/trimap.pt +3 -0
- requirements.txt +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
app.py
ADDED
@@ -0,0 +1,271 @@
+from hashlib import sha1
+from pathlib import Path
+
+import cv2
+import gradio as gr
+import numpy as np
+from PIL import Image
+import PIL
+import torch
+from torchvision import transforms
+import torch.nn.functional as F
+
+
+def estimate_foreground_ml(image, alpha, return_background=False):
+    """
+    Estimates the foreground and background of an image based on an alpha mask.
+
+    Parameters:
+    - image: numpy array of shape (H, W, 3), the input RGB image with values in [0, 1].
+    - alpha: numpy array of shape (H, W, 3), the alpha mask with values ranging from 0 to 1.
+    - return_background: boolean, if True, both foreground and background are returned.
+
+    Returns:
+    - If return_background is False, returns only the foreground.
+    - If return_background is True, returns a tuple (foreground, background).
+    """
+    # Estimate the foreground by attenuating the image with the alpha mask.
+    foreground = image * alpha
+
+    if return_background:
+        # Invert the alpha mask to isolate the background.
+        background_alpha = 1 - alpha
+        # Assume a white background (1.0 in normalized space). This can be
+        # changed based on the application or estimated from the image.
+        background = image * background_alpha + (1 - background_alpha)
+        return foreground, background
+
+    return foreground
+
+
+def load_entire_model(taskname):
+    """Load the TorchScript model(s) needed for the given task."""
+    model_ls = []
+    if taskname == "mask":
+        model = torch.jit.load(Path("./models/sod.pt"))
+        model.eval()
+        model_ls.append(model)
+    elif taskname == "matting":
+        # Matting is a two-stage pipeline: trimap prediction, then alpha matting.
+        model = torch.jit.load(Path("./models/trimap.pt"))
+        model.eval()
+        model_ls.append(model)
+
+        model = torch.jit.load(Path("./models/matting.pt"))
+        model.eval()
+        model_ls.append(model)
+
+    return model_ls
+
+
+model_names = [
+    "matting",
+    "mask",
+]
+model_dict = {name: None for name in model_names}
+
+# Cache of the most recent prediction, keyed by image hash and algorithm.
+last_result = {
+    "cache_key": None,
+    "algorithm": None,
+}
+
+
+def image_matting(
+    image: PIL.Image.Image,
+    result_type: str,
+    bg_color: str,
+    algorithm: str,
+    morph_op: str,
+    morph_op_factor: float,
+) -> np.ndarray:
+    image_np = np.ascontiguousarray(image)
+    width, height = image_np.shape[1], image_np.shape[0]
+    cache_key = sha1(image_np).hexdigest()
+    if cache_key == last_result["cache_key"] and algorithm == last_result["algorithm"]:
+        alpha = last_result["alpha"]
+    else:
+        model = load_entire_model(algorithm)
+        transform = transforms.Compose([
+            transforms.Resize((798, 798)),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ])
+        if algorithm == "mask":
+            input_tensor = transform(image).unsqueeze(0)
+            with torch.no_grad():
+                alpha = model[0](input_tensor).float()
+            alpha = F.interpolate(alpha, [height, width], mode="bilinear")
+            alpha = (alpha.numpy() * 255.).astype(np.uint8)[0][0]
+            alpha = np.stack((alpha, alpha, alpha), axis=2)
+        else:
+            transform2 = transforms.Compose([
+                transforms.Resize((800, 800)),
+                transforms.ToTensor(),
+                # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+            ])
+
+            input_tensor = transform(image).unsqueeze(0)
+            with torch.no_grad():
+                output = model[0](input_tensor).float()
+            output = F.interpolate(output, [height, width], mode="bilinear")
+
+            trimap = np.array(output[0][0])
+
+            # Crop to the detected region, padded by 5% on each side.
+            ratio = 0.05
+            site = np.where(trimap > 0)
+            try:
+                bbox = [np.min(site[1]), np.min(site[0]), np.max(site[1]), np.max(site[0])]
+            except ValueError:
+                # Nothing detected: fall back to the whole image.
+                bbox = [0, 0, width, height]
+
+            x0, y0, x1, y1 = bbox
+            H = y1 - y0
+            W = x1 - x0
+            x0 = int(max(0, x0 - ratio * W))
+            x1 = int(min(width, x1 + ratio * W))
+            y0 = int(max(0, y0 - ratio * H))
+            y1 = int(min(height, y1 + ratio * H))
+
+            image_input = image.crop((x0, y0, x1, y1))
+            input_tensor = transform2(image_input).unsqueeze(0)
+
+            # Map predicted classes to trimap values:
+            # 0 = background, 128 = unknown, 255 = foreground.
+            trimap = trimap[y0:y1, x0:x1]
+            trimap = np.where(trimap < 1, 0, trimap)
+            trimap = np.where(trimap > 1, 255, trimap)
+            trimap = np.where(trimap == 1, 128, trimap)
+
+            trimap = Image.fromarray(np.uint8(trimap)).convert('L')
+            input_tensor2 = transform2(trimap).unsqueeze(0)
+            with torch.no_grad():
+                output = model[1]({'image': input_tensor, 'trimap': input_tensor2})['phas']
+            output = F.interpolate(output, [image_input.size[1], image_input.size[0]], mode="bilinear")[0].numpy()
+
+            numpy_image = (output * 255.).astype(np.uint8)  # Scale to [0, 255] and convert to uint8
+
+            # Drop the channel dimension: (1, H, W) -> (H, W) grayscale.
+            numpy_image = numpy_image.squeeze(0)
+            pil_image = Image.fromarray(numpy_image, mode='L')
+            # Paste the cropped alpha back into a full-size canvas.
+            alpha = Image.new(mode='RGB', size=image.size)
+            alpha.paste(pil_image, (x0, y0))
+
+            alpha = np.array(alpha).astype(np.uint8)
+        last_result["cache_key"] = cache_key
+        last_result["algorithm"] = algorithm
+        last_result["alpha"] = alpha
+
+    image = np.array(image)
+    # The slider returns a float; kernel sizes and iteration counts must be ints.
+    factor = int(morph_op_factor)
+    kernel = np.ones((factor, factor), np.uint8)
+    if morph_op == "Dilate":
+        alpha = cv2.dilate(alpha, kernel, iterations=factor)
+    elif morph_op == "Erode":
+        alpha = cv2.erode(alpha, kernel, iterations=factor)
+    alpha = (alpha / 255).astype("float32")
+
+    image = (image / 255.0).astype("float32")
+    fg = estimate_foreground_ml(image, alpha)
+
+    if result_type == "Remove BG":
+        result = fg
+    elif result_type == "Replace BG":
+        # Parse the "#RRGGBB" hex string from the color picker.
+        bg_r = int(bg_color[1:3], base=16)
+        bg_g = int(bg_color[3:5], base=16)
+        bg_b = int(bg_color[5:7], base=16)
+
+        bg = np.zeros_like(fg)
+        bg[:, :, 0] = bg_r / 255.
+        bg[:, :, 1] = bg_g / 255.
+        bg[:, :, 2] = bg_b / 255.
+
+        result = alpha * image + (1 - alpha) * bg
+        result = np.clip(result, 0, 1)
+    else:
+        result = alpha
+
+    return result
+
+
+def main():
+    with gr.Blocks() as app:
+        gr.Markdown("Salient Object Matting")
+
+        with gr.Row(variant="panel"):
+            image_input = gr.Image(type='pil')
+            image_output = gr.Image()
+
+        with gr.Row(variant="panel"):
+            result_type = gr.Radio(
+                label="Mode",
+                show_label=True,
+                choices=[
+                    "Remove BG",
+                    "Replace BG",
+                    "Generate Mask",
+                ],
+                value="Remove BG",
+            )
+            bg_color = gr.ColorPicker(
+                label="BG Color",
+                show_label=True,
+                value="#000000",
+            )
+            algorithm = gr.Dropdown(
+                label="Algorithm",
+                show_label=True,
+                choices=model_names,
+                value="matting",
+            )
+
+        with gr.Row(variant="panel"):
+            morph_op = gr.Radio(
+                label="Post-process",
+                show_label=True,
+                choices=[
+                    "None",
+                    "Erode",
+                    "Dilate",
+                ],
+                value="None",
+            )
+
+            morph_op_factor = gr.Slider(
+                label="Factor",
+                show_label=True,
+                minimum=3,
+                maximum=20,
+                value=3,
+                step=2,
+            )
+
+        run_button = gr.Button("Run")
+
+        run_button.click(
+            image_matting,
+            inputs=[
+                image_input,
+                result_type,
+                bg_color,
+                algorithm,
+                morph_op,
+                morph_op_factor,
+            ],
+            outputs=image_output,
+        )
+
+    app.launch()
+
+
+if __name__ == "__main__":
+    main()
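For reference, the pipeline above can be smoke-tested without the Gradio UI. A minimal sketch, assuming the LFS weights below have been pulled into `./models/`, that `app.py` is importable from the repo root, and that `sample.jpg` is a hypothetical local image:

```python
# Hypothetical driver for app.py's pipeline (not part of this commit).
from PIL import Image
from app import image_matting  # assumes app.py sits on the import path

img = Image.open("sample.jpg").convert("RGB")
result = image_matting(
    img,
    result_type="Replace BG",
    bg_color="#00ff00",      # composite onto green
    algorithm="matting",     # or "mask" for the SOD-only path
    morph_op="None",
    morph_op_factor=3,
)
# image_matting returns a float array in [0, 1]; rescale before saving.
Image.fromarray((result * 255).astype("uint8")).save("out.png")
```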
models/matting.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00f2ab77b8f35af8509410df12f0dd14645b49d540da16ab84f78d9497a48d61
+size 387204217
models/sod.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4783bd4a1fd075d43e486ec81224ad831772dd178817dafd251af4016f9048ca
+size 356605803
models/trimap.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69173b0fe662c967c53e2274ea128a4d4ed68f88e60e05af4ca540d99b95e450
+size 356607339
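The three `.pt` entries are Git LFS pointers, so the repository stores only the `oid`/`size` metadata shown above. A small sketch for checking pulled weight files against their pointers (paths and digests come from this commit; the check only passes after `git lfs pull` has replaced the pointers with the real files):

```python
# Verify Git LFS objects against the sha256 oids recorded in their pointers.
import hashlib

def sha256_of(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
            h.update(chunk)
    return h.hexdigest()

expected = {
    "models/matting.pt": "00f2ab77b8f35af8509410df12f0dd14645b49d540da16ab84f78d9497a48d61",
    "models/sod.pt": "4783bd4a1fd075d43e486ec81224ad831772dd178817dafd251af4016f9048ca",
    "models/trimap.pt": "69173b0fe662c967c53e2274ea128a4d4ed68f88e60e05af4ca540d99b95e450",
}
for path, oid in expected.items():
    assert sha256_of(path) == oid, f"{path} does not match its LFS pointer"
```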
requirements.txt
ADDED
Binary file (222 Bytes).