Space: Rembg (status: Running)

Commit: "Update to latest version + sam support?"

Files changed:
- .gitattributes +27 -27
- LICENSE.txt +20 -20
- README.md +11 -11
- app.py +109 -70
- rembg/_version.py +3 -3
- rembg/bg.py +133 -19
- rembg/cli.py +4 -2
- rembg/commands/__init__.py +11 -11
- rembg/commands/b_command.py +23 -7
- rembg/commands/d_command.py +14 -0
- rembg/commands/i_command.py +18 -3
- rembg/commands/p_command.py +55 -10
- rembg/commands/s_command.py +58 -22
- rembg/session_factory.py +18 -0
- rembg/sessions/__init__.py +85 -19
- rembg/sessions/base.py +5 -2
- rembg/sessions/birefnet_cod.py +52 -0
- rembg/sessions/birefnet_dis.py +52 -0
- rembg/sessions/birefnet_general.py +91 -0
- rembg/sessions/birefnet_general_lite.py +52 -0
- rembg/sessions/birefnet_hrsod.py +52 -0
- rembg/sessions/birefnet_massive.py +52 -0
- rembg/sessions/birefnet_portrait.py +52 -0
- rembg/sessions/dis_anime.py +43 -6
- rembg/sessions/dis_general_use.py +43 -6
- rembg/sessions/sam.py +252 -78
- rembg/sessions/silueta.py +46 -5
- rembg/sessions/u2net.py +43 -6
- rembg/sessions/u2net_cloth_seg.py +57 -23
- rembg/sessions/u2net_custom.py +105 -0
- rembg/sessions/u2net_human_seg.py +43 -6
- rembg/sessions/u2netp.py +31 -6
- requirements.txt +20 -18
.gitattributes
CHANGED
@@ -1,27 +1,27 @@
All 27 lines were removed and re-added with identical visible content (most likely a line-ending or whitespace-only change). The file keeps the standard Git LFS attribute rules, one per pattern:

*.7z, *.arrow, *.bin, *.bz2, *.ftz, *.gz, *.h5, *.joblib, *.lfs.*, *.model,
*.msgpack, *.onnx, *.ot, *.parquet, *.pb, *.pt, *.pth, *.rar, saved_model/**/*,
*.tar.*, *.tflite, *.tgz, *.wasm, *.xz, *.zip, *.zstandard, *tfevents*

each with the attributes: filter=lfs diff=lfs merge=lfs -text
LICENSE.txt
CHANGED
@@ -1,21 +1,21 @@
Lines 1-20 were removed and re-added with identical visible content (most likely a line-ending or whitespace-only change); line 21 ("SOFTWARE.") is unchanged. The file remains the standard MIT License, "Copyright (c) 2020 Daniel Gatis".
README.md
CHANGED
@@ -1,12 +1,12 @@
Lines 1-11 were removed and re-added with identical visible content (most likely a line-ending or whitespace-only change); line 12 is unchanged. The Space metadata front matter remains:

---
title: Rembg
emoji: 👀
colorFrom: pink
colorTo: indigo
sdk: gradio
sdk_version: 5.6.0
app_file: app.py
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED
@@ -1,70 +1,109 @@
The previous 70-line app.py was replaced in full; its old content is not recoverable from this view. The new 109-line version is:

import gradio as gr
import os
import cv2
from rembg import new_session, remove
from rembg.sessions import sessions_class

def inference(file, mask, model, x, y):
    im = cv2.imread(file, cv2.IMREAD_COLOR)
    input_path = "input.png"
    output_path = "output.png"
    cv2.imwrite(input_path, im)

    with open(input_path, 'rb') as i:
        with open(output_path, 'wb') as o:
            input = i.read()
            session = new_session(model)

            output = remove(
                input,
                session=session,
                **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
                only_mask=(mask == "Mask only")
            )
            o.write(output)

    return output_path

title = "RemBG"
description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
badge = """
<div style="position: fixed; left: 50%; text-align: center;">
    <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
        <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
    </a>
</div>
"""

def get_coords(evt: gr.SelectData) -> tuple:
    return evt.index[0], evt.index[1]

def show_coords(model: str):
    visible = model == "sam"
    return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)

for session in sessions_class:
    session.download_models()

with gr.Blocks() as app:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        inputs = gr.Image(type="filepath", label="Input Image")
        outputs = gr.Image(type="filepath", label="Output Image")

    with gr.Row():
        mask_option = gr.Radio(
            ["Default", "Mask only"],
            value="Default",
            label="Output Type"
        )
        model_selector = gr.Dropdown(
            [
                "u2net",
                "u2netp",
                "u2net_human_seg",
                "u2net_cloth_seg",
                "silueta",
                "isnet-general-use",
                "isnet-anime",
                "sam",
                "birefnet-general",
                "birefnet-general-lite",
                "birefnet-portrait",
                "birefnet-dis",
                "birefnet-hrsod",
                "birefnet-cod",
                "birefnet-massive"
            ],
            value="isnet-general-use",
            label="Model Selection"
        )

    extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)

    x = gr.Number(label="Mouse X Coordinate", visible=False)
    y = gr.Number(label="Mouse Y Coordinate", visible=False)

    model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
    inputs.select(get_coords, None, [x, y])


    gr.Button("Process Image").click(
        inference,
        inputs=[inputs, mask_option, model_selector, x, y],
        outputs=outputs
    )

    gr.Examples(
        examples=[
            ["lion.png", "Default", "u2net", None, None],
            ["girl.jpg", "Default", "u2net", None, None],
            ["anime-girl.jpg", "Default", "isnet-anime", None, None]
        ],
        inputs=[inputs, mask_option, model_selector, x, y],
        outputs=outputs
    )
    gr.HTML(badge)

app.launch()
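The inference() function above is the SAM entry point of this commit: it forwards a single point prompt to remove() through the sam_prompt keyword. A minimal standalone sketch of that call path, assuming this commit's rembg package with the "sam" model available; the file names and the point (100, 200) are placeholder values:

from rembg import new_session, remove

with open("input.png", "rb") as f:          # placeholder input file
    data = f.read()

session = new_session("sam")
output = remove(
    data,
    session=session,
    # label=1 marks the clicked point as foreground, mirroring app.py above
    sam_prompt=[{"type": "point", "data": [100, 200], "label": 1}],
)

with open("output.png", "wb") as f:         # placeholder output file
    f.write(output)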
rembg/_version.py
CHANGED
@@ -23,9 +23,9 @@ def get_keywords():
     # setup.py/versioneer.py will grep for the variable names, so they must
     # each be defined on a line of their own. _version.py will just call
     # get_keywords().
-    git_refnames = " (HEAD -> main…   (old value truncated in this view)
-    git_full = "…
-    git_date = "…
+    git_refnames = " (HEAD -> main)"
+    git_full = "e740a9681ea32f5c34adce52aa7cc0b4b85bbb11"
+    git_date = "2024-11-20 09:41:13 -0300"
     keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
     return keywords
rembg/bg.py
CHANGED
(Removed lines whose content was cut off in this view are shown with a trailing "…".)

@@ -1,8 +1,9 @@
 import io
 from enum import Enum
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, cast
 
 import numpy as np
+import onnxruntime as ort
 from cv2 import (
     BORDER_DEFAULT,
     MORPH_ELLIPSE,
@@ -22,6 +23,8 @@ from .session_factory import new_session
 from .sessions import sessions_class
 from .sessions.base import BaseSession
 
+ort.set_default_logger_severity(3)
+
 kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
 
@@ -38,14 +41,25 @@
     background_threshold: int,
     erode_structure_size: int,
 ) -> PILImage:
+    """
+    Perform alpha matting on an image using a given mask and threshold values.
+
+    This function takes a PIL image `img` and a PIL image `mask` as input, along with
+    the `foreground_threshold` and `background_threshold` values used to determine
+    foreground and background pixels. The `erode_structure_size` parameter specifies
+    the size of the erosion structure to be applied to the mask.
+
+    The function returns a PIL image representing the cutout of the foreground object
+    from the original image.
+    """
     if img.mode == "RGBA" or img.mode == "CMYK":
         img = img.convert("RGB")
 
-    …
-    …
 
-    is_foreground = …
-    is_background = …
+    img_array = np.asarray(img)
+    mask_array = np.asarray(mask)
 
+    is_foreground = mask_array > foreground_threshold
+    is_background = mask_array < background_threshold
 
     structure = None
     if erode_structure_size > 0:
@@ -56,11 +70,11 @@ alpha_matting_cutout(
     is_foreground = binary_erosion(is_foreground, structure=structure)
     is_background = binary_erosion(is_background, structure=structure, border_value=1)
 
-    trimap = np.full(…
+    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
     trimap[is_foreground] = 255
     trimap[is_background] = 0
 
-    img_normalized = …
+    img_normalized = img_array / 255.0
     trimap_normalized = trimap / 255.0
 
     alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
@@ -74,12 +88,46 @@
 
 
 def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
+    """
+    Perform a simple cutout operation on an image using a mask.
+
+    This function takes a PIL image `img` and a PIL image `mask` as input.
+    It uses the mask to create a new image where the pixels from `img` are
+    cut out based on the mask.
+
+    The function returns a PIL image representing the cutout of the original
+    image using the mask.
+    """
     empty = Image.new("RGBA", (img.size), 0)
     cutout = Image.composite(img, empty, mask)
     return cutout
 
 
+def putalpha_cutout(img: PILImage, mask: PILImage) -> PILImage:
+    """
+    Apply the specified mask to the image as an alpha cutout.
+
+    Args:
+        img (PILImage): The image to be modified.
+        mask (PILImage): The mask to be applied.
+
+    Returns:
+        PILImage: The modified image with the alpha cutout applied.
+    """
+    img.putalpha(mask)
+    return img
+
+
 def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
+    """
+    Concatenate multiple images vertically.
+
+    Args:
+        imgs (List[PILImage]): The list of images to be concatenated.
+
+    Returns:
+        PILImage: The concatenated image.
+    """
     pivot = imgs.pop(0)
     for im in imgs:
         pivot = get_concat_v(pivot, im)
@@ -87,6 +135,16 @@ def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
 
 
 def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
+    """
+    Concatenate two images vertically.
+
+    Args:
+        img1 (PILImage): The first image.
+        img2 (PILImage): The second image to be concatenated below the first image.
+
+    Returns:
+        PILImage: The concatenated image.
+    """
     dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
     dst.paste(img1, (0, 0))
     dst.paste(img2, (0, img1.height))
@@ -102,11 +160,21 @@ def post_process(mask: np.ndarray) -> np.ndarray:
     """
     mask = morphologyEx(mask, MORPH_OPEN, kernel)
     mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
-    mask = np.where(mask < 127, 0, 255).astype(np.uint8)  # …
+    mask = np.where(mask < 127, 0, 255).astype(np.uint8)  # type: ignore
     return mask
 
 
 def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
+    """
+    Apply the specified background color to the image.
+
+    Args:
+        img (PILImage): The image to be modified.
+        color (Tuple[int, int, int, int]): The RGBA color to be applied.
+
+    Returns:
+        PILImage: The modified image with the background color applied.
+    """
     r, g, b, a = color
     colored_image = Image.new("RGBA", img.size, (r, g, b, a))
     colored_image.paste(img, mask=img)
@@ -115,10 +183,22 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
 
 
 def fix_image_orientation(img: PILImage) -> PILImage:
-    …
+    """
+    Fix the orientation of the image based on its EXIF data.
+
+    Args:
+        img (PILImage): The image to be fixed.
+
+    Returns:
+        PILImage: The fixed image.
+    """
+    return cast(PILImage, ImageOps.exif_transpose(img))
 
 
 def download_models() -> None:
+    """
+    Download models for image processing.
+    """
     for session in sessions_class:
         session.download_models()
 
@@ -133,20 +213,49 @@ def remove(
     only_mask: bool = False,
     post_process_mask: bool = False,
     bgcolor: Optional[Tuple[int, int, int, int]] = None,
+    force_return_bytes: bool = False,
     *args: Optional[Any],
     **kwargs: Optional[Any]
 ) -> Union[bytes, PILImage, np.ndarray]:
-    …
-    …
-    …
-    …
+    """
+    Remove the background from an input image.
+
+    This function takes in various parameters and returns a modified version of the input image with the background removed. The function can handle input data in the form of bytes, a PIL image, or a numpy array. The function first checks the type of the input data and converts it to a PIL image if necessary. It then fixes the orientation of the image and proceeds to perform background removal using the 'u2net' model. The result is a list of binary masks representing the foreground objects in the image. These masks are post-processed and combined to create a final cutout image. If a background color is provided, it is applied to the cutout image. The function returns the resulting cutout image in the format specified by the input 'return_type' parameter or as python bytes if force_return_bytes is true.
+
+    Parameters:
+        data (Union[bytes, PILImage, np.ndarray]): The input image data.
+        alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
+        alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
+        alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
+        alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
+        session (Optional[BaseSession], optional): A session object for the 'u2net' model. Defaults to None.
+        only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
+        post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
+        bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.
+        force_return_bytes (bool, optional): Flag indicating whether to return the cutout image as bytes. Defaults to False.
+        *args (Optional[Any]): Additional positional arguments.
+        **kwargs (Optional[Any]): Additional keyword arguments.

+    Returns:
+        Union[bytes, PILImage, np.ndarray]: The cutout image with the background removed.
+    """
+    if isinstance(data, bytes) or force_return_bytes:
         return_type = ReturnType.BYTES
-        img = Image.open(io.BytesIO(data))
+        img = cast(PILImage, Image.open(io.BytesIO(cast(bytes, data))))
+    elif isinstance(data, PILImage):
+        return_type = ReturnType.PILLOW
+        img = cast(PILImage, data)
     elif isinstance(data, np.ndarray):
         return_type = ReturnType.NDARRAY
-        img = Image.fromarray(data)
+        img = cast(PILImage, Image.fromarray(data))
     else:
-        raise ValueError(…
+        raise ValueError(
+            "Input type {} is not supported. Try using force_return_bytes=True to force python bytes output".format(
+                type(data)
+            )
+        )
+
+    putalpha = kwargs.pop("putalpha", False)
 
     # Fix image orientation
     img = fix_image_orientation(img)
@@ -174,10 +283,15 @@ def remove(
                 alpha_matting_erode_size,
             )
         except ValueError:
-            …
-            …
+            if putalpha:
+                cutout = putalpha_cutout(img, mask)
+            else:
+                cutout = naive_cutout(img, mask)
     else:
-        …
+        if putalpha:
+            cutout = putalpha_cutout(img, mask)
+        else:
+            cutout = naive_cutout(img, mask)
 
         cutouts.append(cutout)
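Two behavioural additions in the remove() hunks above are force_return_bytes and the putalpha keyword popped from **kwargs. A minimal sketch of both, assuming this commit's rembg version; the file names are placeholders:

import numpy as np
from PIL import Image
from rembg import remove

img = Image.open("input.png")               # placeholder input file

# putalpha=True applies the mask via Image.putalpha() instead of compositing;
# with a PIL input the result should come back as a PIL image.
cutout = remove(img, putalpha=True)

# force_return_bytes=True lets bytes-like inputs other than plain `bytes`
# (e.g. a bytearray of encoded PNG data) be decoded and returned as bytes.
with open("input.png", "rb") as f:
    raw = bytearray(f.read())
png_bytes = remove(raw, force_return_bytes=True)
with open("output.png", "wb") as f:
    f.write(png_bytes)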
rembg/cli.py
CHANGED
@@ -6,9 +6,11 @@ from .commands import command_functions
 
 @click.group()
 @click.version_option(version=_version.get_versions()["version"])
-def …   (old function name truncated in this view)
+def _main() -> None:
     pass
 
 
 for command in command_functions:
-    …   (old line truncated in this view)
+    _main.add_command(command)
+
+_main()
rembg/commands/__init__.py
CHANGED
@@ -1,13 +1,13 @@
-from …   (truncated in this view)
-
-from …   (truncated in this view)
-from pkgutil import iter_modules
 
-sessions? no — the list definition stays, the old dynamic command-discovery block (lines 7-13) was removed; its content is not captured in this view.
 command_functions = []
 
+from .b_command import b_command
+from .d_command import d_command
+from .i_command import i_command
+from .p_command import p_command
+from .s_command import s_command
+
+command_functions.append(b_command)
+command_functions.append(d_command)
+command_functions.append(i_command)
+command_functions.append(p_command)
+command_functions.append(s_command)
rembg/commands/b_command.py
CHANGED
@@ -6,14 +6,14 @@ import sys
 from typing import IO
 
 import click
-from PIL import Image
+from PIL.Image import Image as PILImage
 
 from ..bg import remove
 from ..session_factory import new_session
 from ..sessions import sessions_names
 
 
-@click.command(
+@click.command(  # type: ignore
     name="b",
     help="for a byte stream as input",
 )
@@ -74,7 +74,7 @@ from ..sessions import sessions_names
 @click.option(
     "-bgc",
     "--bgcolor",
-    default=…   (old default truncated in this view)
+    default=(0, 0, 0, 0),
     type=(int, int, int, int),
     nargs=4,
     help="Background color (R G B A) to replace the removed background with",
@@ -94,7 +94,7 @@ from ..sessions import sessions_names
     "image_height",
     type=int,
 )
-def rs_command(
+def b_command(
     model: str,
     extras: str,
     image_width: int,
@@ -102,12 +102,28 @@ def rs_command(
     output_specifier: str,
     **kwargs
 ) -> None:
+    """
+    Command-line interface for processing images by removing the background using a specified model and generating a mask.
+
+    This CLI command takes several options and arguments to configure the background removal process and save the processed images.
+
+    Parameters:
+        model (str): The name of the model to use for background removal.
+        extras (str): Additional options in JSON format that can be passed to customize the background removal process.
+        image_width (int): The width of the input images in pixels.
+        image_height (int): The height of the input images in pixels.
+        output_specifier (str): A printf-style specifier for the output filenames. If specified, the processed images will be saved to the specified output directory with filenames generated using the specifier.
+        **kwargs: Additional keyword arguments that can be used to customize the background removal process.
+
+    Returns:
+        None
+    """
     try:
         kwargs.update(json.loads(extras))
     except Exception:
         pass
 
-    session = new_session(model)
+    session = new_session(model, **kwargs)
     bytes_per_img = image_width * image_height * 3
 
     if output_specifier:
@@ -118,7 +134,7 @@ def rs_command(
     if not os.path.isdir(output_dir):
         os.makedirs(output_dir, exist_ok=True)
 
-    def img_to_byte_array(img: …   (old annotation truncated in this view)
+    def img_to_byte_array(img: PILImage) -> bytes:
         buff = io.BytesIO()
         img.save(buff, format="PNG")
         return buff.getvalue()
@@ -146,7 +162,7 @@ def rs_command(
         if not img_bytes:
             break
 
-        img = …   (old line truncated in this view)
+        img = PILImage.frombytes("RGB", (image_width, image_height), img_bytes)
         output = remove(img, session=session, **kwargs)
 
         if output_specifier:
rembg/commands/d_command.py
ADDED
@@ -0,0 +1,14 @@
+import click
+
+from ..bg import download_models
+
+
+@click.command(  # type: ignore
+    name="d",
+    help="download all models",
+)
+def d_command(*args, **kwargs) -> None:
+    """
+    Download all models
+    """
+    download_models()
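The new "d" subcommand is a thin wrapper around rembg.bg.download_models(); the same pre-fetch can be done programmatically, a minimal sketch assuming this commit's package layout:

# Pre-download every registered model so later remove() calls start without a download.
from rembg.bg import download_models

download_models()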
rembg/commands/i_command.py
CHANGED
@@ -9,7 +9,7 @@ from ..session_factory import new_session
 from ..sessions import sessions_names
 
 
-@click.command(
+@click.command(  # type: ignore
     name="i",
     help="for a file as input",
 )
@@ -70,7 +70,7 @@ from ..sessions import sessions_names
 @click.option(
     "-bgc",
     "--bgcolor",
-    default=…   (old default truncated in this view)
+    default=(0, 0, 0, 0),
     type=(int, int, int, int),
     nargs=4,
     help="Background color (R G B A) to replace the removed background with",
@@ -85,9 +85,24 @@ from ..sessions import sessions_names
     type=click.File("wb", lazy=True),
 )
 def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
+    """
+    Click command line interface function to process an input file based on the provided options.
+
+    This function is the entry point for the CLI program. It reads an input file, applies image processing operations based on the provided options, and writes the output to a file.
+
+    Parameters:
+        model (str): The name of the model to use for image processing.
+        extras (str): Additional options in JSON format.
+        input: The input file to process.
+        output: The output file to write the processed image to.
+        **kwargs: Additional keyword arguments corresponding to the command line options.
+
+    Returns:
+        None
+    """
     try:
         kwargs.update(json.loads(extras))
     except Exception:
         pass
 
-    output.write(remove(input.read(), session=new_session(model), **kwargs))
+    output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))
rembg/commands/p_command.py
CHANGED
@@ -14,7 +14,7 @@ from ..session_factory import new_session
 from ..sessions import sessions_names
 
 
-@click.command(
+@click.command(  # type: ignore
     name="p",
     help="for a folder as input",
 )
@@ -80,10 +80,18 @@ from ..sessions import sessions_names
     show_default=True,
     help="watches a folder for changes",
 )
+@click.option(
+    "-d",
+    "--delete_input",
+    default=False,
+    is_flag=True,
+    show_default=True,
+    help="delete input file after processing",
+)
 @click.option(
     "-bgc",
     "--bgcolor",
-    default=…   (old default truncated in this view)
+    default=(0, 0, 0, 0),
     type=(int, int, int, int),
     nargs=4,
     help="Background color (R G B A) to replace the removed background with",
@@ -115,14 +123,36 @@ def p_command(
     input: pathlib.Path,
     output: pathlib.Path,
     watch: bool,
+    delete_input: bool,
     **kwargs,
 ) -> None:
+    """
+    Command-line interface (CLI) program for performing background removal on images in a folder.
+
+    This program takes a folder as input and uses a specified model to remove the background from the images in the folder.
+    It provides various options for configuration, such as choosing the model, enabling alpha matting, setting trimap thresholds, erode size, etc.
+    Additional options include outputting only the mask and post-processing the mask.
+    The program can also watch the input folder for changes and automatically process new images.
+    The resulting images with the background removed are saved in the specified output folder.
+
+    Parameters:
+        model (str): The name of the model to use for background removal.
+        extras (str): Additional options in JSON format.
+        input (pathlib.Path): The path to the input folder.
+        output (pathlib.Path): The path to the output folder.
+        watch (bool): Whether to watch the input folder for changes.
+        delete_input (bool): Whether to delete the input file after processing.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        None
+    """
     try:
         kwargs.update(json.loads(extras))
     except Exception:
         pass
 
-    session = new_session(model)
+    session = new_session(model, **kwargs)
 
     def process(each_input: pathlib.Path) -> None:
         try:
@@ -147,33 +177,48 @@ def p_command(
             print(
                 f"processed: {each_input.absolute()} -> {each_output.absolute()}"
             )
+
+            if delete_input:
+                each_input.unlink()
+
         except Exception as e:
             print(e)
 
     inputs = list(input.glob("**/*"))
     if not watch:
-        …   (old line truncated in this view)
+        inputs_tqdm = tqdm(inputs)
 
-        for each_input in …   (old line truncated in this view)
+        for each_input in inputs_tqdm:
             if not each_input.is_dir():
                 process(each_input)
 
     if watch:
+        should_watch = True
        observer = Observer()
 
        class EventHandler(FileSystemEventHandler):
            def on_any_event(self, event: FileSystemEvent) -> None:
-                …   (old condition truncated in this view)
-                …
+                src_path = cast(str, event.src_path)
+                if (
+                    not (
+                        event.is_directory or event.event_type in ["deleted", "closed"]
+                    )
+                    and pathlib.Path(src_path).exists()
                 ):
-                    …   (old body truncated in this view)
+                    if src_path.endswith("stop.txt"):
+                        nonlocal should_watch
+                        should_watch = False
+                        pathlib.Path(src_path).unlink()
+                        return
+
+                    process(pathlib.Path(src_path))
 
        event_handler = EventHandler()
-        observer.schedule(event_handler, input, recursive=False)
+        observer.schedule(event_handler, str(input), recursive=False)
        observer.start()
 
        try:
-            while …   (old condition truncated in this view)
+            while should_watch:
                time.sleep(1)
 
        finally:
rembg/commands/s_command.py
CHANGED
@@ -19,18 +19,26 @@ from ..sessions import sessions_names
 from ..sessions.base import BaseSession
 
 
-@click.command(
+@click.command(  # type: ignore
     name="s",
     help="for a http server",
 )
 @click.option(
     "-p",
     "--port",
-    default=…   (old default truncated in this view)
+    default=7000,
     type=int,
     show_default=True,
     help="port",
 )
+@click.option(
+    "-h",
+    "--host",
+    default="0.0.0.0",
+    type=str,
+    show_default=True,
+    help="host",
+)
 @click.option(
     "-l",
     "--log_level",
@@ -47,7 +55,13 @@ from ..sessions.base import BaseSession
     show_default=True,
     help="number of worker threads",
 )
-def s_command(port: int, log_level: str, threads: int) -> None:
+def s_command(port: int, host: str, log_level: str, threads: int) -> None:
+    """
+    Command-line interface for running the FastAPI web server.
+
+    This function starts the FastAPI web server with the specified port and log level.
+    If the number of worker threads is specified, it sets the thread limiter accordingly.
+    """
     sessions: dict[str, BaseSession] = {}
     tags_metadata = [
         {
@@ -186,7 +200,9 @@ def s_command(port: int, log_level: str, threads: int) -> None:
         return Response(
             remove(
                 content,
-                session=sessions.setdefault(…   (old line truncated in this view)
+                session=sessions.setdefault(
+                    commons.model, new_session(commons.model, **kwargs)
+                ),
                 alpha_matting=commons.a,
                 alpha_matting_foreground_threshold=commons.af,
                 alpha_matting_background_threshold=commons.ab,
@@ -245,12 +261,27 @@ def s_command(port: int, log_level: str, threads: int) -> None:
         return await asyncify(im_without_bg)(file, commons)  # type: ignore
 
     def gr_app(app):
-        def inference(input_path, model):
+        def inference(input_path, model, *args):
             output_path = "output.png"
+            a, af, ab, ae, om, ppm, cmd_args = args
+
+            kwargs = {
+                "alpha_matting": a,
+                "alpha_matting_foreground_threshold": af,
+                "alpha_matting_background_threshold": ab,
+                "alpha_matting_erode_size": ae,
+                "only_mask": om,
+                "post_process_mask": ppm,
+            }
+
+            if cmd_args:
+                kwargs.update(json.loads(cmd_args))
+            kwargs["session"] = new_session(model, **kwargs)
+
             with open(input_path, "rb") as i:
                 with open(output_path, "wb") as o:
                     input = i.read()
-                    output = remove(input, …   (old call truncated in this view)
+                    output = remove(input, **kwargs)
                     o.write(output)
             return os.path.join(output_path)
 
@@ -258,28 +289,33 @@ def s_command(port: int, log_level: str, threads: int) -> None:
             inference,
             [
                 gr.components.Image(type="filepath", label="Input"),
-                gr.components.Dropdown(
-                    …   (the old hard-coded model list, ending with
-                    "isnet-general-use",
-                    "isnet-anime",
-                ],
-                value="u2net",
-                label="Models",
                 ),
+                gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
+                gr.components.Checkbox(value=True, label="Alpha matting"),
+                gr.components.Slider(
+                    value=240, minimum=0, maximum=255, label="Foreground threshold"
+                ),
+                gr.components.Slider(
+                    value=10, minimum=0, maximum=255, label="Background threshold"
+                ),
+                gr.components.Slider(
+                    value=40, minimum=0, maximum=255, label="Erosion size"
+                ),
+                gr.components.Checkbox(value=False, label="Only mask"),
+                gr.components.Checkbox(value=True, label="Post process mask"),
+                gr.components.Textbox(label="Arguments"),
             ],
             gr.components.Image(type="filepath", label="Output"),
+            concurrency_limit=3,
         )
 
-        interface.queue(concurrency_count=3)
         app = gr.mount_gradio_app(app, interface, path="/")
         return app
 
-    print(…   (old print lines truncated in this view)
-    …
+    print(
+        f"To access the API documentation, go to http://{'localhost' if host == '0.0.0.0' else host}:{port}/api"
+    )
+    print(
+        f"To access the UI, go to http://{'localhost' if host == '0.0.0.0' else host}:{port}"
+    )
 
-    uvicorn.run(gr_app(app), host=…   (old call truncated in this view)
+    uvicorn.run(gr_app(app), host=host, port=port, log_level=log_level)
rembg/session_factory.py
CHANGED
@@ -11,6 +11,23 @@ from .sessions.u2net import U2netSession
 def new_session(
     model_name: str = "u2net", providers=None, *args, **kwargs
 ) -> BaseSession:
+    """
+    Create a new session object based on the specified model name.
+
+    This function searches for the session class based on the model name in the 'sessions_class' list.
+    It then creates an instance of the session class with the provided arguments.
+    The 'sess_opts' object is created using the 'ort.SessionOptions()' constructor.
+    If the 'OMP_NUM_THREADS' environment variable is set, the 'inter_op_num_threads' option of 'sess_opts' is set to its value.
+
+    Parameters:
+        model_name (str): The name of the model.
+        providers: The providers for the session.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        BaseSession: The created session object.
+    """
     session_class: Type[BaseSession] = U2netSession
 
     for sc in sessions_class:
@@ -22,5 +39,6 @@ def new_session(
 
     if "OMP_NUM_THREADS" in os.environ:
         sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
+        sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
 
     return session_class(model_name, sess_opts, providers, *args, **kwargs)
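With this change, OMP_NUM_THREADS now caps both the inter-op and intra-op ONNX Runtime thread pools built by new_session(). A minimal sketch; the value 4 is an arbitrary example:

import os

# Must be set before new_session() builds its ort.SessionOptions.
os.environ["OMP_NUM_THREADS"] = "4"

from rembg import new_session

session = new_session("u2net")  # runs with 4 inter-op and 4 intra-op threads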
rembg/sessions/__init__.py
CHANGED
@@ -1,22 +1,88 @@
-from …   (truncated in this view)
-
-from …   (truncated in this view)
-from pkgutil import iter_modules
 
 from .base import BaseSession
 
-sessions_class = []
-sessions_names = []
-…   (the rest of the old dynamic session-discovery block, lines 10-22, is not captured in this view)
+from __future__ import annotations
+
+from typing import List
+
+sessions_class: List[type[BaseSession]] = []
+sessions_names: List[str] = []
+
+from .birefnet_general import BiRefNetSessionGeneral
+
+sessions_class.append(BiRefNetSessionGeneral)
+sessions_names.append(BiRefNetSessionGeneral.name())
+
+from .birefnet_general_lite import BiRefNetSessionGeneralLite
+
+sessions_class.append(BiRefNetSessionGeneralLite)
+sessions_names.append(BiRefNetSessionGeneralLite.name())
+
+from .birefnet_portrait import BiRefNetSessionPortrait
+
+sessions_class.append(BiRefNetSessionPortrait)
+sessions_names.append(BiRefNetSessionPortrait.name())
+
+from .birefnet_dis import BiRefNetSessionDIS
+
+sessions_class.append(BiRefNetSessionDIS)
+sessions_names.append(BiRefNetSessionDIS.name())
+
+from .birefnet_hrsod import BiRefNetSessionHRSOD
+
+sessions_class.append(BiRefNetSessionHRSOD)
+sessions_names.append(BiRefNetSessionHRSOD.name())
+
+from .birefnet_cod import BiRefNetSessionCOD
+
+sessions_class.append(BiRefNetSessionCOD)
+sessions_names.append(BiRefNetSessionCOD.name())
+
+from .birefnet_massive import BiRefNetSessionMassive
+
+sessions_class.append(BiRefNetSessionMassive)
+sessions_names.append(BiRefNetSessionMassive.name())
+
+from .dis_anime import DisSession
+
+sessions_class.append(DisSession)
+sessions_names.append(DisSession.name())
+
+from .dis_general_use import DisSession as DisSessionGeneralUse
+
+sessions_class.append(DisSessionGeneralUse)
+sessions_names.append(DisSessionGeneralUse.name())
+
+from .sam import SamSession
+
+sessions_class.append(SamSession)
+sessions_names.append(SamSession.name())
+
+from .silueta import SiluetaSession
+
+sessions_class.append(SiluetaSession)
+sessions_names.append(SiluetaSession.name())
+
+from .u2net_cloth_seg import Unet2ClothSession
+
+sessions_class.append(Unet2ClothSession)
+sessions_names.append(Unet2ClothSession.name())
+
+from .u2net_custom import U2netCustomSession
+
+sessions_class.append(U2netCustomSession)
+sessions_names.append(U2netCustomSession.name())
+
+from .u2net_human_seg import U2netHumanSegSession
+
+sessions_class.append(U2netHumanSegSession)
+sessions_names.append(U2netHumanSegSession.name())
+
+from .u2net import U2netSession
+
+sessions_class.append(U2netSession)
+sessions_names.append(U2netSession.name())
+
+from .u2netp import U2netpSession
+
+sessions_class.append(U2netpSession)
+sessions_names.append(U2netpSession.name())
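Each class registered above becomes selectable through the string returned by its name() classmethod. A minimal sketch of picking one of the new BiRefNet variants; the file names are placeholders:

from rembg import new_session, remove

session = new_session("birefnet-general")  # or "birefnet-portrait", "birefnet-hrsod", ...

with open("input.png", "rb") as f:
    cutout_bytes = remove(f.read(), session=session)

with open("output.png", "wb") as f:
    f.write(cutout_bytes)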
rembg/sessions/base.py
CHANGED
@@ -8,6 +8,8 @@ from PIL.Image import Image as PILImage
 
 
 class BaseSession:
+    """This is a base class for managing a session with a machine learning model."""
+
     def __init__(
         self,
         model_name: str,
@@ -16,6 +18,7 @@ class BaseSession:
         *args,
         **kwargs
     ):
+        """Initialize an instance of the BaseSession class."""
         self.model_name = model_name
 
         self.providers = []
@@ -29,7 +32,7 @@ class BaseSession:
         self.providers.extend(_providers)
 
         self.inner_session = ort.InferenceSession(
-            str(self.__class__.download_models()),
+            str(self.__class__.download_models(*args, **kwargs)),
             providers=self.providers,
             sess_options=sess_opts,
         )
@@ -43,7 +46,7 @@ class BaseSession:
         *args,
         **kwargs
     ) -> Dict[str, np.ndarray]:
-        im = img.convert("RGB").resize(size, Image.LANCZOS)
+        im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
 
         im_ary = np.array(im)
         im_ary = im_ary / np.max(im_ary)
rembg/sessions/birefnet_cod.py
ADDED
@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionCOD(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-COD session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-COD model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-COD session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-cod"
rembg/sessions/birefnet_dis.py
ADDED
@@ -0,0 +1,52 @@
+import os
+
+import pooch
+
+from . import BiRefNetSessionGeneral
+
+
+class BiRefNetSessionDIS(BiRefNetSessionGeneral):
+    """
+    This class represents a BiRefNet-DIS session, which is a subclass of BiRefNetSessionGeneral.
+    """
+
+    @classmethod
+    def download_models(cls, *args, **kwargs):
+        """
+        Downloads the BiRefNet-DIS model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
+        pooch.retrieve(
+            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:2d4d44102b446f33a4ebb2e56c051f2b"
+            ),
+            fname=fname,
+            path=cls.u2net_home(*args, **kwargs),
+            progressbar=True,
+        )
+
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)
+
+    @classmethod
+    def name(cls, *args, **kwargs):
+        """
+        Returns the name of the BiRefNet-DIS session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
+        return "birefnet-dis"
rembg/sessions/birefnet_general.py
ADDED
@@ -0,0 +1,91 @@
import os
from typing import List

import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage

from .base import BaseSession


class BiRefNetSessionGeneral(BaseSession):
    """
    This class represents a BiRefNet-General session, which is a subclass of BaseSession.
    """

    def sigmoid(self, mat):
        return 1 / (1 + np.exp(-mat))

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the output masks for the input image using the inner session.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[PILImage]: The list of output masks.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        pred = self.sigmoid(ort_outs[0][:, 0, :, :])

        ma = np.max(pred)
        mi = np.min(pred)

        pred = (pred - mi) / (ma - mi)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BiRefNet-General model file from a specific URL and saves it.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:7a35a0141cbbc80de11d9c9a28f52697"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the BiRefNet-General session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "birefnet-general"
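The name() strings above are what the session factory matches on. A hedged usage sketch, assuming rembg's public new_session/remove API (not shown in this diff) wires the new BiRefNet classes in:

# Hedged sketch, not part of the commit.
from rembg import new_session, remove

session = new_session("birefnet-general")  # or "birefnet-portrait", "birefnet-cod", ...

with open("input.png", "rb") as f:
    output = remove(f.read(), session=session)

with open("output.png", "wb") as f:
    f.write(output)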
rembg/sessions/birefnet_general_lite.py
ADDED
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):
    """
    This class represents a BiRefNet-General-Lite session, which is a subclass of BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BiRefNet-General-Lite model file from a specific URL and saves it.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:4fab47adc4ff364be1713e97b7e66334"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the BiRefNet-General-Lite session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "birefnet-general-lite"
rembg/sessions/birefnet_hrsod.py
ADDED
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
    """
    This class represents a BiRefNet-HRSOD session, which is a subclass of BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BiRefNet-HRSOD model file from a specific URL and saves it.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:c017ade5de8a50ff0fd74d790d268dda"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the BiRefNet-HRSOD session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "birefnet-hrsod"
rembg/sessions/birefnet_massive.py
ADDED
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionMassive(BiRefNetSessionGeneral):
    """
    This class represents a BiRefNet-Massive session, which is a subclass of BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BiRefNet-Massive model file from a specific URL and saves it.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:33e726a2136a3d59eb0fdf613e31e3e9"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the BiRefNet-Massive session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "birefnet-massive"
rembg/sessions/birefnet_portrait.py
ADDED
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionPortrait(BiRefNetSessionGeneral):
    """
    This class represents a BiRefNet-Portrait session, which is a subclass of BiRefNetSessionGeneral.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BiRefNet-Portrait model file from a specific URL and saves it.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:c3a64a6abf20250d090cd055f12a3b67"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the BiRefNet-Portrait session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "birefnet-portrait"
rembg/sessions/dis_anime.py
CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession


 class DisSession(BaseSession):
+    """
+    This class represents a session for object detection.
+    """
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Use a pre-trained model to predict the object in the given image.
+
+        Parameters:
+            img (PILImage): The input image.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[PILImage]: A list of predicted mask images.
+        """
         ort_outs = self.inner_session.run(
             None,
             self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
@@ -25,25 +40,47 @@ class DisSession(BaseSession):
         pred = np.squeeze(pred)

         mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
-        mask = mask.resize(img.size, Image.LANCZOS)
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

         return [mask]

     @classmethod
     def download_models(cls, *args, **kwargs):
+        """
+        Download the pre-trained models.
+
+        Parameters:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            str: The path of the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:6f184e756bb3bd901c8849220a83e38e"
+            ),
             fname=fname,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

     @classmethod
     def name(cls, *args, **kwargs):
+        """
+        Get the name of the pre-trained model.
+
+        Parameters:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            str: The name of the pre-trained model.
+        """
         return "isnet-anime"
rembg/sessions/dis_general_use.py
CHANGED
@@ -11,6 +11,17 @@ from .base import BaseSession

 class DisSession(BaseSession):
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the mask image for the input image.
+
+        This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
+
+        Parameters:
+            img (PILImage): The input image.
+
+        Returns:
+            List[PILImage]: A list of PILImage objects representing the generated mask image.
+        """
         ort_outs = self.inner_session.run(
             None,
             self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
@@ -25,25 +36,51 @@ class DisSession(BaseSession):
         pred = np.squeeze(pred)

         mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
-        mask = mask.resize(img.size, Image.LANCZOS)
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

         return [mask]

     @classmethod
     def download_models(cls, *args, **kwargs):
+        """
+        Downloads the pre-trained model file.
+
+        This class method downloads the pre-trained model file from a specified URL using the pooch library.
+
+        Parameters:
+            args: Additional positional arguments.
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:fc16ebd8b0c10d971d3513d564d01e29"
+            ),
             fname=fname,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

     @classmethod
     def name(cls, *args, **kwargs):
+        """
+        Returns the name of the model.
+
+        This class method returns the name of the model.
+
+        Parameters:
+            args: Additional positional arguments.
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the model.
+        """
         return "isnet-general-use"
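The md5 pins introduced above are skipped whenever checksum_disabled() is true. A hedged sketch of how that switch is typically flipped; the MODEL_CHECKSUM_DISABLED environment variable name is an assumption about the base session, not something shown in this diff:

# Hedged sketch, not part of the commit.
import os

os.environ["MODEL_CHECKSUM_DISABLED"] = "1"  # assumed name of the checksum switch

from rembg import new_session

session = new_session("isnet-general-use")  # pooch.retrieve() then runs with checksum=None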
rembg/sessions/sam.py
CHANGED
@@ -1,9 +1,12 @@
 import os
+from copy import deepcopy
+from typing import Dict, List, Tuple

+import cv2
 import numpy as np
 import onnxruntime as ort
 import pooch
+from jsonschema import validate
 from PIL import Image
 from PIL.Image import Image as PILImage

@@ -15,104 +18,213 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
     newh, neww = oldh * scale, oldw * scale
     neww = int(neww + 0.5)
     newh = int(newh + 0.5)
+
     return (newh, neww)


+def apply_coords(coords: np.ndarray, original_size, target_length):
     old_h, old_w = original_size
     new_h, new_w = get_preprocess_shape(
         original_size[0], original_size[1], target_length
     )
+
+    coords = deepcopy(coords).astype(float)
     coords[..., 0] = coords[..., 0] * (new_w / old_w)
     coords[..., 1] = coords[..., 1] * (new_h / old_h)
+
     return coords


+def get_input_points(prompt):
+    points = []
+    labels = []
+
+    for mark in prompt:
+        if mark["type"] == "point":
+            points.append(mark["data"])
+            labels.append(mark["label"])
+        elif mark["type"] == "rectangle":
+            points.append([mark["data"][0], mark["data"][1]])
+            points.append([mark["data"][2], mark["data"][3]])
+            labels.append(2)
+            labels.append(3)

+    points, labels = np.array(points), np.array(labels)
+    return points, labels


+def transform_masks(masks, original_size, transform_matrix):
+    output_masks = []
+
+    for batch in range(masks.shape[0]):
+        batch_masks = []
+        for mask_id in range(masks.shape[1]):
+            mask = masks[batch, mask_id]
+            mask = cv2.warpAffine(
+                mask,
+                transform_matrix[:2],
+                (original_size[1], original_size[0]),
+                flags=cv2.INTER_LINEAR,
+            )
+            batch_masks.append(mask)
+        output_masks.append(batch_masks)
+
+    return np.array(output_masks)


 class SamSession(BaseSession):
+    """
+    This class represents a session for the Sam model.
+
+    Args:
+        model_name (str): The name of the model.
+        sess_opts (ort.SessionOptions): The session options.
+        *args: Variable length argument list.
+        **kwargs: Arbitrary keyword arguments.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        sess_opts: ort.SessionOptions,
+        providers=None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initialize a new SamSession with the given model name and session options.
+
+        Args:
+            model_name (str): The name of the model.
+            sess_opts (ort.SessionOptions): The session options.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
         self.model_name = model_name
+
+        valid_providers = []
+        available_providers = ort.get_available_providers()
+
+        for provider in providers or []:
+            if provider in available_providers:
+                valid_providers.append(provider)
+            else:
+                valid_providers.extend(available_providers)
+
+        paths = self.__class__.download_models(*args, **kwargs)
         self.encoder = ort.InferenceSession(
             str(paths[0]),
+            providers=valid_providers,
             sess_options=sess_opts,
         )
         self.decoder = ort.InferenceSession(
             str(paths[1]),
+            providers=valid_providers,
             sess_options=sess_opts,
         )

-    def normalize(
-        self,
-        img: np.ndarray,
-        mean=(123.675, 116.28, 103.53),
-        std=(58.395, 57.12, 57.375),
-        size=(1024, 1024),
-        *args,
-        **kwargs,
-    ):
-        pixel_mean = np.array([*mean]).reshape(1, 1, -1)
-        pixel_std = np.array([*std]).reshape(1, 1, -1)
-        x = (img - pixel_mean) / pixel_std
-        return x
     def predict(
         self,
         img: PILImage,
         *args,
         **kwargs,
     ) -> List[PILImage]:
+        """
+        Predict masks for an input image.
+
+        This function takes an image as input and performs various preprocessing steps on the image. It then runs the image through an encoder to obtain an image embedding. The function also takes input labels and points as additional arguments. It concatenates the input points and labels with padding and transforms them. It creates an empty mask input and an indicator for no mask. The function then passes the image embedding, point coordinates, point labels, mask input, and has mask input to a decoder. The decoder generates masks based on the input and returns them as a list of images.
+
+        Parameters:
+            img (PILImage): The input image.
+            *args: Additional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            List[PILImage]: A list of masks generated by the decoder.
+        """
+        prompt = kwargs.get("sam_prompt", "{}")
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"type": "string"},
+                    "label": {"type": "integer"},
+                    "data": {
+                        "type": "array",
+                        "items": {"type": "number"},
+                    },
+                },
+            },
+        }
+
+        validate(instance=prompt, schema=schema)
+
+        target_size = 1024
+        input_size = (684, 1024)
+        encoder_input_name = self.encoder.get_inputs()[0].name
+
+        img = img.convert("RGB")
+        cv_image = np.array(img)
+        original_size = cv_image.shape[:2]
+
+        scale_x = input_size[1] / cv_image.shape[1]
+        scale_y = input_size[0] / cv_image.shape[0]
+        scale = min(scale_x, scale_y)
+
+        transform_matrix = np.array(
+            [
+                [scale, 0, 0],
+                [0, scale, 0],
+                [0, 0, 1],
+            ]
+        )
+
+        cv_image = cv2.warpAffine(
+            cv_image,
+            transform_matrix[:2],
+            (input_size[1], input_size[0]),
+            flags=cv2.INTER_LINEAR,
+        )
+
+        ## encoder
+
+        encoder_inputs = {
+            encoder_input_name: cv_image.astype(np.float32),
+        }
+
+        encoder_output = self.encoder.run(None, encoder_inputs)
+        image_embedding = encoder_output[0]
+
+        embedding = {
+            "image_embedding": image_embedding,
+            "original_size": original_size,
+            "transform_matrix": transform_matrix,
+        }
+
+        ## decoder
+
+        input_points, input_labels = get_input_points(prompt)
         onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
             None, :, :
         ]
         onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
             None, :
         ].astype(np.float32)
+        onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
+            np.float32
+        )
+
+        onnx_coord = np.concatenate(
+            [
+                onnx_coord,
+                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
+            ],
+            axis=2,
+        )
+        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
+        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

-        # Create an empty mask input and an indicator for no mask.
         onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
         onnx_has_mask_input = np.zeros(1, dtype=np.float32)

@@ -122,48 +234,110 @@ class SamSession(BaseSession):
             "point_labels": onnx_label,
             "mask_input": onnx_mask_input,
             "has_mask_input": onnx_has_mask_input,
+            "orig_im_size": np.array(input_size, dtype=np.float32),
         }

+        masks, _, _ = self.decoder.run(None, decoder_inputs)
+        inv_transform_matrix = np.linalg.inv(transform_matrix)
+        masks = transform_masks(masks, original_size, inv_transform_matrix)
+
+        mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
+        for m in masks[0, :, :, :]:
+            mask[m > 0.0] = [255, 255, 255]

+        return [Image.fromarray(mask).convert("L")]

     @classmethod
     def download_models(cls, *args, **kwargs):
+        """
+        Class method to download ONNX model files.
+
+        This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
+
+        Parameters:
+            cls: The class object.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
+        """
+        model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
+        quant = kwargs.get("sam_quant", False)
+
+        fname_encoder = f"{model_name}.encoder.onnx"
+        fname_decoder = f"{model_name}.decoder.onnx"
+
+        if quant:
+            fname_encoder = f"{model_name}.encoder.quant.onnx"
+            fname_decoder = f"{model_name}.decoder.quant.onnx"

         pooch.retrieve(
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
+            None,
             fname=fname_encoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

         pooch.retrieve(
-            None
-            if cls.checksum_disabled(*args, **kwargs)
-            else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
+            f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
+            None,
             fname=fname_decoder,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

+        if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
+            os.path.join(
+                cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
+            )
+        ):
+            content = bytearray()
+
+            for i in range(1, 4):
+                pooch.retrieve(
+                    f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    None,
+                    fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                    path=cls.u2net_home(*args, **kwargs),
+                    progressbar=True,
+                )
+
+                fbin = os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    f"sam_vit_h_4b8939.encoder_data.{i}.bin",
+                )
+                content.extend(open(fbin, "rb").read())
+                os.remove(fbin)
+
+            with open(
+                os.path.join(
+                    cls.u2net_home(*args, **kwargs),
+                    "sam_vit_h_4b8939.encoder_data.bin",
+                ),
+                "wb",
+            ) as fp:
+                fp.write(content)

         return (
-            os.path.join(cls.u2net_home(), fname_encoder),
-            os.path.join(cls.u2net_home(), fname_decoder),
+            os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
+            os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
        )

     @classmethod
     def name(cls, *args, **kwargs):
+        """
+        Class method to return a string value.
+
+        This method returns the string value 'sam'.
+
+        Parameters:
+            cls: The class object.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            str: The string value 'sam'.
+        """
         return "sam"
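A hedged usage sketch for the reworked SAM session. The prompt layout follows the jsonschema in predict() above; passing sam_prompt, sam_model, and sam_quant through new_session/remove assumes the kwargs forwarding that the rest of this commit (bg.py, s_command.py) appears to add.

# Hedged sketch, not part of the commit.
from rembg import new_session, remove

session = new_session("sam", sam_model="sam_vit_b_01ec64", sam_quant=False)

prompt = [
    {"type": "point", "data": [300, 200], "label": 1},  # foreground click
    {"type": "rectangle", "data": [50, 50, 500, 400], "label": 1},  # box; corners get labels 2/3 internally
]

with open("input.png", "rb") as f:
    output = remove(f.read(), session=session, sam_prompt=prompt)

with open("output.png", "wb") as f:
    f.write(output)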
rembg/sessions/silueta.py
CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession


 class SiluetaSession(BaseSession):
+    """This is a class representing a SiluetaSession object."""
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predict the mask of the input image.
+
+        This method takes an image as input, preprocesses it, and performs a prediction to generate a mask. The generated mask is then post-processed and returned as a list of PILImage objects.
+
+        Parameters:
+            img (PILImage): The input image to be processed.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[PILImage]: A list of post-processed masks.
+        """
         ort_outs = self.inner_session.run(
             None,
             self.normalize(
@@ -27,25 +42,51 @@ class SiluetaSession(BaseSession):
         pred = np.squeeze(pred)

         mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
-        mask = mask.resize(img.size, Image.LANCZOS)
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

         return [mask]

     @classmethod
     def download_models(cls, *args, **kwargs):
+        """
+        Download the pre-trained model file.
+
+        This method downloads the pre-trained model file from a specified URL. The file is saved to the U2NET home directory.
+
+        Parameters:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
         fname = f"{cls.name()}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:55e59e0d8062d2f5d013f4725ee84782"
+            ),
             fname=fname,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

     @classmethod
     def name(cls, *args, **kwargs):
+        """
+        Return the name of the model.
+
+        This method returns the name of the Silueta model.
+
+        Parameters:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            str: The name of the model.
+        """
         return "silueta"
rembg/sessions/u2net.py
CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession


 class U2netSession(BaseSession):
+    """
+    This class represents a U2net session, which is a subclass of BaseSession.
+    """
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the output masks for the input image using the inner session.
+
+        Parameters:
+            img (PILImage): The input image.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            List[PILImage]: The list of output masks.
+        """
         ort_outs = self.inner_session.run(
             None,
             self.normalize(
@@ -27,25 +42,47 @@ class U2netSession(BaseSession):
         pred = np.squeeze(pred)

         mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
-        mask = mask.resize(img.size, Image.LANCZOS)
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

         return [mask]

     @classmethod
     def download_models(cls, *args, **kwargs):
+        """
+        Downloads the U2net model file from a specific URL and saves it.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model file.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:60024c5c889badc19c04ad937298a77b"
+            ),
             fname=fname,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

     @classmethod
     def name(cls, *args, **kwargs):
+        """
+        Returns the name of the U2net session.
+
+        Parameters:
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The name of the session.
+        """
         return "u2net"
rembg/sessions/u2net_cloth_seg.py
CHANGED
@@ -9,7 +9,7 @@ from scipy.special import log_softmax

 from .base import BaseSession

-pallete1 = [
+palette1 = [
     0,
     0,
     0,
@@ -24,7 +24,7 @@ pallete1 = [
     0,
 ]

-pallete2 = [
+palette2 = [
     0,
     0,
     0,
@@ -39,7 +39,7 @@ pallete2 = [
     0,
 ]

-pallete3 = [
+palette3 = [
     0,
     0,
     0,
@@ -57,6 +57,22 @@ pallete3 = [

 class Unet2ClothSession(BaseSession):
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predict the cloth category of an image.
+
+        This method takes an image as input and predicts the cloth category of the image.
+        The method uses the inner_session to make predictions using a pre-trained model.
+        The predicted mask is then converted to an image and resized to match the size of the input image.
+        Depending on the cloth category specified in the method arguments, the method applies different color palettes to the mask and appends the resulting images to a list.
+
+        Parameters:
+            img (PILImage): The input image.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            List[PILImage]: A list of images representing the predicted masks.
+        """
         ort_outs = self.inner_session.run(
             None,
             self.normalize(
@@ -71,41 +87,59 @@ class Unet2ClothSession(BaseSession):
         pred = np.squeeze(pred, 0)

         mask = Image.fromarray(pred.astype("uint8"), mode="L")
-        mask = mask.resize(img.size, Image.LANCZOS)
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

         masks = []

+        cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")
+
+        def upper_cloth():
+            mask1 = mask.copy()
+            mask1.putpalette(palette1)
+            mask1 = mask1.convert("RGB").convert("L")
+            masks.append(mask1)
+
+        def lower_cloth():
+            mask2 = mask.copy()
+            mask2.putpalette(palette2)
+            mask2 = mask2.convert("RGB").convert("L")
+            masks.append(mask2)
+
+        def full_cloth():
+            mask3 = mask.copy()
+            mask3.putpalette(palette3)
+            mask3 = mask3.convert("RGB").convert("L")
+            masks.append(mask3)
+
+        if cloth_category == "upper":
+            upper_cloth()
+        elif cloth_category == "lower":
+            lower_cloth()
+        elif cloth_category == "full":
+            full_cloth()
+        else:
+            upper_cloth()
+            lower_cloth()
+            full_cloth()

         return masks

     @classmethod
     def download_models(cls, *args, **kwargs):
-        fname = f"{cls.name()}.onnx"
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:2434d1f3cb744e0e49386c906e5a08bb"
+            ),
             fname=fname,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

     @classmethod
     def name(cls, *args, **kwargs):
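A hedged usage sketch for the cloth-category selection added above; forwarding "cc" through remove() to predict() is assumed from rembg's kwargs plumbing rather than shown here.

# Hedged sketch, not part of the commit.
from rembg import new_session, remove

session = new_session("u2net_cloth_seg")

with open("person.jpg", "rb") as f:
    data = f.read()

upper_only = remove(data, session=session, cc="upper")  # only the upper-body mask
everything = remove(data, session=session)              # upper, lower and full masks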
rembg/sessions/u2net_custom.py
ADDED
@@ -0,0 +1,105 @@
import os
from typing import List

import numpy as np
import onnxruntime as ort
import pooch
from PIL import Image
from PIL.Image import Image as PILImage

from .base import BaseSession


class U2netCustomSession(BaseSession):
    """This is a class representing a custom session for the U2net model."""

    def __init__(
        self,
        model_name: str,
        sess_opts: ort.SessionOptions,
        providers=None,
        *args,
        **kwargs
    ):
        """
        Initialize a new U2netCustomSession object.

        Parameters:
            model_name (str): The name of the model.
            sess_opts (ort.SessionOptions): The session options.
            providers: The providers.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If model_path is None.
        """
        model_path = kwargs.get("model_path")
        if model_path is None:
            raise ValueError("model_path is required")

        super().__init__(model_name, sess_opts, providers, *args, **kwargs)

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predict the segmentation mask for the input image.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[PILImage]: A list of PILImage objects representing the segmentation mask.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
            ),
        )

        pred = ort_outs[0][:, 0, :, :]

        ma = np.max(pred)
        mi = np.min(pred)

        pred = (pred - mi) / (ma - mi)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Download the model files.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The absolute path to the model files.
        """
        model_path = kwargs.get("model_path")
        if model_path is None:
            return

        return os.path.abspath(os.path.expanduser(model_path))

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Get the name of the model.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the model.
        """
        return "u2net_custom"
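A hedged usage sketch for the new custom session; the model path below is a placeholder, and routing model_path through new_session is assumed from the session-factory changes listed in this commit.

# Hedged sketch, not part of the commit.
from rembg import new_session, remove

session = new_session("u2net_custom", model_path="~/models/my_u2net.onnx")  # placeholder path

with open("input.png", "rb") as f:
    output = remove(f.read(), session=session)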
rembg/sessions/u2net_human_seg.py
CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession


 class U2netHumanSegSession(BaseSession):
+    """
+    This class represents a session for performing human segmentation using the U2Net model.
+    """
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts human segmentation masks for the input image.
+
+        Parameters:
+            img (PILImage): The input image.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[PILImage]: A list of predicted masks.
+        """
         ort_outs = self.inner_session.run(
             None,
             self.normalize(
@@ -27,25 +42,47 @@ class U2netHumanSegSession(BaseSession):
         pred = np.squeeze(pred)

         mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
-        mask = mask.resize(img.size, Image.LANCZOS)
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

         return [mask]

     @classmethod
     def download_models(cls, *args, **kwargs):
+        """
+        Downloads the U2Net model weights.
+
+        Parameters:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            str: The path to the downloaded model weights.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:c09ddc2e0104f800e3e1bb4652583d1f"
+            ),
             fname=fname,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

     @classmethod
     def name(cls, *args, **kwargs):
+        """
+        Returns the name of the U2Net model.
+
+        Parameters:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            str: The name of the model.
+        """
         return "u2net_human_seg"
rembg/sessions/u2netp.py
CHANGED
@@ -10,7 +10,18 @@ from .base import BaseSession


 class U2netpSession(BaseSession):
+    """This class represents a session for using the U2netp model."""
+
     def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
+        """
+        Predicts the mask for the given image using the U2netp model.
+
+        Parameters:
+            img (PILImage): The input image.
+
+        Returns:
+            List[PILImage]: The predicted mask.
+        """
         ort_outs = self.inner_session.run(
             None,
             self.normalize(
@@ -27,25 +38,39 @@ class U2netpSession(BaseSession):
         pred = np.squeeze(pred)

         mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
-        mask = mask.resize(img.size, Image.LANCZOS)
+        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

         return [mask]

     @classmethod
     def download_models(cls, *args, **kwargs):
+        """
+        Downloads the U2netp model.
+
+        Returns:
+            str: The path to the downloaded model.
+        """
+        fname = f"{cls.name(*args, **kwargs)}.onnx"
         pooch.retrieve(
             "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
+            (
+                None
+                if cls.checksum_disabled(*args, **kwargs)
+                else "md5:8e83ca70e441ab06c318d82300c84806"
+            ),
             fname=fname,
             path=cls.u2net_home(*args, **kwargs),
             progressbar=True,
         )

-        return os.path.join(cls.u2net_home(), fname)
+        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

     @classmethod
     def name(cls, *args, **kwargs):
+        """
+        Returns the name of the U2netp model.
+
+        Returns:
+            str: The name of the model.
+        """
         return "u2netp"
requirements.txt
CHANGED
@@ -1,18 +1,20 @@
+filetype==1.2.0
+pooch==1.6.0
+imagehash==4.3.1
+numpy==1.23.5
+onnxruntime==1.13.1
+opencv-python-headless==4.6.0.66
+pillow==9.3.0
+pymatting==1.1.8
+python-multipart==0.0.5
+scikit-image==0.19.3
+scipy==1.9.3
+tqdm==4.64.1
+uvicorn==0.20.0
+watchdog==2.1.9
+click==8.1.3
+fastapi==0.85.0
+aiohttp==3.8.3
+asyncer==0.0.2
+gradio==3.0.20
+jsonschema==4.16.0