KenjieDec committed on
Commit
c8f8b0e
1 Parent(s): 551d7cc

Update to latest version + sam support?

Browse files
.gitattributes CHANGED
@@ -1,27 +1,27 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ftz filter=lfs diff=lfs merge=lfs -text
6
- *.gz filter=lfs diff=lfs merge=lfs -text
7
- *.h5 filter=lfs diff=lfs merge=lfs -text
8
- *.joblib filter=lfs diff=lfs merge=lfs -text
9
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
- *.model filter=lfs diff=lfs merge=lfs -text
11
- *.msgpack filter=lfs diff=lfs merge=lfs -text
12
- *.onnx filter=lfs diff=lfs merge=lfs -text
13
- *.ot filter=lfs diff=lfs merge=lfs -text
14
- *.parquet filter=lfs diff=lfs merge=lfs -text
15
- *.pb filter=lfs diff=lfs merge=lfs -text
16
- *.pt filter=lfs diff=lfs merge=lfs -text
17
- *.pth filter=lfs diff=lfs merge=lfs -text
18
- *.rar filter=lfs diff=lfs merge=lfs -text
19
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
- *.tar.* filter=lfs diff=lfs merge=lfs -text
21
- *.tflite filter=lfs diff=lfs merge=lfs -text
22
- *.tgz filter=lfs diff=lfs merge=lfs -text
23
- *.wasm filter=lfs diff=lfs merge=lfs -text
24
- *.xz filter=lfs diff=lfs merge=lfs -text
25
- *.zip filter=lfs diff=lfs merge=lfs -text
26
- *.zstandard filter=lfs diff=lfs merge=lfs -text
27
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE.txt CHANGED
@@ -1,21 +1,21 @@
1
- MIT License
2
-
3
- Copyright (c) 2020 Daniel Gatis
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
  SOFTWARE.
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Daniel Gatis
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
  SOFTWARE.
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: Rembg
3
- emoji: 👀
4
- colorFrom: pink
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.6.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Rembg
3
+ emoji: 👀
4
+ colorFrom: pink
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.6.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,70 +1,109 @@
1
- ## Modified from Akhaliq Hugging Face Demo
2
- ## https://huggingface.co/akhaliq
3
-
4
- import gradio as gr
5
- import os
6
- import cv2
7
-
8
- def inference(file, mask, model):
9
- im = cv2.imread(file, cv2.IMREAD_COLOR)
10
- cv2.imwrite(os.path.join("input.png"), im)
11
-
12
- from rembg import new_session, remove
13
-
14
- input_path = 'input.png'
15
- output_path = 'output.png'
16
-
17
- with open(input_path, 'rb') as i:
18
- with open(output_path, 'wb') as o:
19
- input = i.read()
20
- output = remove(
21
- input,
22
- session = new_session(model),
23
- only_mask = (True if mask == "Mask only" else False)
24
- )
25
-
26
-
27
-
28
- o.write(output)
29
- return os.path.join("output.png")
30
-
31
- title = "RemBG"
32
- description = "Gradio demo for RemBG. To use it, simply upload your image and wait. Read more at the link below."
33
- article = "<p style='text-align: center;'><a href='https://github.com/danielgatis/rembg' target='_blank'>Github Repo</a></p>"
34
-
35
-
36
- gr.Interface(
37
- inference,
38
- [
39
- gr.inputs.Image(type="filepath", label="Input"),
40
- gr.inputs.Radio(
41
- [
42
- "Default",
43
- "Mask only"
44
- ],
45
- type="value",
46
- default="Default",
47
- label="Choices"
48
- ),
49
- gr.inputs.Dropdown([
50
- "u2net",
51
- "u2netp",
52
- "u2net_human_seg",
53
- "u2net_cloth_seg",
54
- "silueta",
55
- "isnet-general-use",
56
- "isnet-anime",
57
- "sam",
58
- ],
59
- type="value",
60
- default="isnet-general-use",
61
- label="Models"
62
- ),
63
- ],
64
- gr.outputs.Image(type="filepath", label="Output"),
65
- title=title,
66
- description=description,
67
- article=article,
68
- examples=[["lion.png", "Default", "u2net"], ["girl.jpg", "Default", "u2net"], ["anime-girl.jpg", "Default", "isnet-anime"]],
69
- enable_queue=True
70
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import cv2
4
+ from rembg import new_session, remove
5
+ from rembg.sessions import sessions_class
6
+
7
+ def inference(file, mask, model, x, y):
8
+ im = cv2.imread(file, cv2.IMREAD_COLOR)
9
+ input_path = "input.png"
10
+ output_path = "output.png"
11
+ cv2.imwrite(input_path, im)
12
+
13
+ with open(input_path, 'rb') as i:
14
+ with open(output_path, 'wb') as o:
15
+ input = i.read()
16
+ session = new_session(model)
17
+
18
+ output = remove(
19
+ input,
20
+ session=session,
21
+ **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
22
+ only_mask=(mask == "Mask only")
23
+ )
24
+ o.write(output)
25
+
26
+ return output_path
27
+
28
+ title = "RemBG"
29
+ description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
30
+ badge = """
31
+ <div style="position: fixed; left: 50%; text-align: center;">
32
+ <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
33
+ <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
34
+ </a>
35
+ </div>
36
+ """
37
+ def get_coords(evt: gr.SelectData) -> tuple:
38
+ return evt.index[0], evt.index[1]
39
+
40
+ def show_coords(model: str):
41
+ visible = model == "sam"
42
+ return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
43
+
44
+ for session in sessions_class:
45
+ session.download_models()
46
+
47
+ with gr.Blocks() as app:
48
+ gr.Markdown(f"# {title}")
49
+ gr.Markdown(description)
50
+
51
+ with gr.Row():
52
+ inputs = gr.Image(type="filepath", label="Input Image")
53
+ outputs = gr.Image(type="filepath", label="Output Image")
54
+
55
+ with gr.Row():
56
+ mask_option = gr.Radio(
57
+ ["Default", "Mask only"],
58
+ value="Default",
59
+ label="Output Type"
60
+ )
61
+ model_selector = gr.Dropdown(
62
+ [
63
+ "u2net",
64
+ "u2netp",
65
+ "u2net_human_seg",
66
+ "u2net_cloth_seg",
67
+ "silueta",
68
+ "isnet-general-use",
69
+ "isnet-anime",
70
+ "sam",
71
+ "birefnet-general",
72
+ "birefnet-general-lite",
73
+ "birefnet-portrait",
74
+ "birefnet-dis",
75
+ "birefnet-hrsod",
76
+ "birefnet-cod",
77
+ "birefnet-massive"
78
+ ],
79
+ value="isnet-general-use",
80
+ label="Model Selection"
81
+ )
82
+
83
+ extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
84
+
85
+ x = gr.Number(label="Mouse X Coordinate", visible=False)
86
+ y = gr.Number(label="Mouse Y Coordinate", visible=False)
87
+
88
+ model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
89
+ inputs.select(get_coords, None, [x, y])
90
+
91
+
92
+ gr.Button("Process Image").click(
93
+ inference,
94
+ inputs=[inputs, mask_option, model_selector, x, y],
95
+ outputs=outputs
96
+ )
97
+
98
+ gr.Examples(
99
+ examples=[
100
+ ["lion.png", "Default", "u2net", None, None],
101
+ ["girl.jpg", "Default", "u2net", None, None],
102
+ ["anime-girl.jpg", "Default", "isnet-anime", None, None]
103
+ ],
104
+ inputs=[inputs, mask_option, model_selector, x, y],
105
+ outputs=outputs
106
+ )
107
+ gr.HTML(badge)
108
+
109
+ app.launch()
rembg/_version.py CHANGED
@@ -23,9 +23,9 @@ def get_keywords():
23
  # setup.py/versioneer.py will grep for the variable names, so they must
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
- git_refnames = " (HEAD -> main, tag: v2.0.43)"
27
- git_full = "848a38e4cc5cf41522974dea00848596105b1dfa"
28
- git_date = "2023-06-02 09:20:57 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
 
23
  # setup.py/versioneer.py will grep for the variable names, so they must
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
+ git_refnames = " (HEAD -> main)"
27
+ git_full = "e740a9681ea32f5c34adce52aa7cc0b4b85bbb11"
28
+ git_date = "2024-11-20 09:41:13 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
rembg/bg.py CHANGED
@@ -1,8 +1,9 @@
1
  import io
2
  from enum import Enum
3
- from typing import Any, List, Optional, Tuple, Union
4
 
5
  import numpy as np
 
6
  from cv2 import (
7
  BORDER_DEFAULT,
8
  MORPH_ELLIPSE,
@@ -22,6 +23,8 @@ from .session_factory import new_session
22
  from .sessions import sessions_class
23
  from .sessions.base import BaseSession
24
 
 
 
25
  kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
26
 
27
 
@@ -38,14 +41,25 @@ def alpha_matting_cutout(
38
  background_threshold: int,
39
  erode_structure_size: int,
40
  ) -> PILImage:
 
 
 
 
 
 
 
 
 
 
 
41
  if img.mode == "RGBA" or img.mode == "CMYK":
42
  img = img.convert("RGB")
43
 
44
- img = np.asarray(img)
45
- mask = np.asarray(mask)
46
 
47
- is_foreground = mask > foreground_threshold
48
- is_background = mask < background_threshold
49
 
50
  structure = None
51
  if erode_structure_size > 0:
@@ -56,11 +70,11 @@ def alpha_matting_cutout(
56
  is_foreground = binary_erosion(is_foreground, structure=structure)
57
  is_background = binary_erosion(is_background, structure=structure, border_value=1)
58
 
59
- trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
60
  trimap[is_foreground] = 255
61
  trimap[is_background] = 0
62
 
63
- img_normalized = img / 255.0
64
  trimap_normalized = trimap / 255.0
65
 
66
  alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
@@ -74,12 +88,46 @@ def alpha_matting_cutout(
74
 
75
 
76
  def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
 
 
 
 
 
 
 
 
 
 
77
  empty = Image.new("RGBA", (img.size), 0)
78
  cutout = Image.composite(img, empty, mask)
79
  return cutout
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
 
 
 
 
 
 
 
 
 
83
  pivot = imgs.pop(0)
84
  for im in imgs:
85
  pivot = get_concat_v(pivot, im)
@@ -87,6 +135,16 @@ def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
87
 
88
 
89
  def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
 
 
 
 
 
 
 
 
 
 
90
  dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
91
  dst.paste(img1, (0, 0))
92
  dst.paste(img2, (0, img1.height))
@@ -102,11 +160,21 @@ def post_process(mask: np.ndarray) -> np.ndarray:
102
  """
103
  mask = morphologyEx(mask, MORPH_OPEN, kernel)
104
  mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
105
- mask = np.where(mask < 127, 0, 255).astype(np.uint8) # convert again to binary
106
  return mask
107
 
108
 
109
  def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
 
 
 
 
 
 
 
 
 
 
110
  r, g, b, a = color
111
  colored_image = Image.new("RGBA", img.size, (r, g, b, a))
112
  colored_image.paste(img, mask=img)
@@ -115,10 +183,22 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> P
115
 
116
 
117
  def fix_image_orientation(img: PILImage) -> PILImage:
118
- return ImageOps.exif_transpose(img)
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  def download_models() -> None:
 
 
 
122
  for session in sessions_class:
123
  session.download_models()
124
 
@@ -133,20 +213,49 @@ def remove(
133
  only_mask: bool = False,
134
  post_process_mask: bool = False,
135
  bgcolor: Optional[Tuple[int, int, int, int]] = None,
 
136
  *args: Optional[Any],
137
  **kwargs: Optional[Any]
138
  ) -> Union[bytes, PILImage, np.ndarray]:
139
- if isinstance(data, PILImage):
140
- return_type = ReturnType.PILLOW
141
- img = data
142
- elif isinstance(data, bytes):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  return_type = ReturnType.BYTES
144
- img = Image.open(io.BytesIO(data))
 
 
 
145
  elif isinstance(data, np.ndarray):
146
  return_type = ReturnType.NDARRAY
147
- img = Image.fromarray(data)
148
  else:
149
- raise ValueError("Input type {} is not supported.".format(type(data)))
 
 
 
 
 
 
150
 
151
  # Fix image orientation
152
  img = fix_image_orientation(img)
@@ -174,10 +283,15 @@ def remove(
174
  alpha_matting_erode_size,
175
  )
176
  except ValueError:
177
- cutout = naive_cutout(img, mask)
178
-
 
 
179
  else:
180
- cutout = naive_cutout(img, mask)
 
 
 
181
 
182
  cutouts.append(cutout)
183
 
 
1
  import io
2
  from enum import Enum
3
+ from typing import Any, List, Optional, Tuple, Union, cast
4
 
5
  import numpy as np
6
+ import onnxruntime as ort
7
  from cv2 import (
8
  BORDER_DEFAULT,
9
  MORPH_ELLIPSE,
 
23
  from .sessions import sessions_class
24
  from .sessions.base import BaseSession
25
 
26
+ ort.set_default_logger_severity(3)
27
+
28
  kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
29
 
30
 
 
41
  background_threshold: int,
42
  erode_structure_size: int,
43
  ) -> PILImage:
44
+ """
45
+ Perform alpha matting on an image using a given mask and threshold values.
46
+
47
+ This function takes a PIL image `img` and a PIL image `mask` as input, along with
48
+ the `foreground_threshold` and `background_threshold` values used to determine
49
+ foreground and background pixels. The `erode_structure_size` parameter specifies
50
+ the size of the erosion structure to be applied to the mask.
51
+
52
+ The function returns a PIL image representing the cutout of the foreground object
53
+ from the original image.
54
+ """
55
  if img.mode == "RGBA" or img.mode == "CMYK":
56
  img = img.convert("RGB")
57
 
58
+ img_array = np.asarray(img)
59
+ mask_array = np.asarray(mask)
60
 
61
+ is_foreground = mask_array > foreground_threshold
62
+ is_background = mask_array < background_threshold
63
 
64
  structure = None
65
  if erode_structure_size > 0:
 
70
  is_foreground = binary_erosion(is_foreground, structure=structure)
71
  is_background = binary_erosion(is_background, structure=structure, border_value=1)
72
 
73
+ trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
74
  trimap[is_foreground] = 255
75
  trimap[is_background] = 0
76
 
77
+ img_normalized = img_array / 255.0
78
  trimap_normalized = trimap / 255.0
79
 
80
  alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
 
88
 
89
 
90
  def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
91
+ """
92
+ Perform a simple cutout operation on an image using a mask.
93
+
94
+ This function takes a PIL image `img` and a PIL image `mask` as input.
95
+ It uses the mask to create a new image where the pixels from `img` are
96
+ cut out based on the mask.
97
+
98
+ The function returns a PIL image representing the cutout of the original
99
+ image using the mask.
100
+ """
101
  empty = Image.new("RGBA", (img.size), 0)
102
  cutout = Image.composite(img, empty, mask)
103
  return cutout
104
 
105
 
106
+ def putalpha_cutout(img: PILImage, mask: PILImage) -> PILImage:
107
+ """
108
+ Apply the specified mask to the image as an alpha cutout.
109
+
110
+ Args:
111
+ img (PILImage): The image to be modified.
112
+ mask (PILImage): The mask to be applied.
113
+
114
+ Returns:
115
+ PILImage: The modified image with the alpha cutout applied.
116
+ """
117
+ img.putalpha(mask)
118
+ return img
119
+
120
+
121
  def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
122
+ """
123
+ Concatenate multiple images vertically.
124
+
125
+ Args:
126
+ imgs (List[PILImage]): The list of images to be concatenated.
127
+
128
+ Returns:
129
+ PILImage: The concatenated image.
130
+ """
131
  pivot = imgs.pop(0)
132
  for im in imgs:
133
  pivot = get_concat_v(pivot, im)
 
135
 
136
 
137
  def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
138
+ """
139
+ Concatenate two images vertically.
140
+
141
+ Args:
142
+ img1 (PILImage): The first image.
143
+ img2 (PILImage): The second image to be concatenated below the first image.
144
+
145
+ Returns:
146
+ PILImage: The concatenated image.
147
+ """
148
  dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
149
  dst.paste(img1, (0, 0))
150
  dst.paste(img2, (0, img1.height))
 
160
  """
161
  mask = morphologyEx(mask, MORPH_OPEN, kernel)
162
  mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
163
+ mask = np.where(mask < 127, 0, 255).astype(np.uint8) # type: ignore
164
  return mask
165
 
166
 
167
  def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
168
+ """
169
+ Apply the specified background color to the image.
170
+
171
+ Args:
172
+ img (PILImage): The image to be modified.
173
+ color (Tuple[int, int, int, int]): The RGBA color to be applied.
174
+
175
+ Returns:
176
+ PILImage: The modified image with the background color applied.
177
+ """
178
  r, g, b, a = color
179
  colored_image = Image.new("RGBA", img.size, (r, g, b, a))
180
  colored_image.paste(img, mask=img)
 
183
 
184
 
185
  def fix_image_orientation(img: PILImage) -> PILImage:
186
+ """
187
+ Fix the orientation of the image based on its EXIF data.
188
+
189
+ Args:
190
+ img (PILImage): The image to be fixed.
191
+
192
+ Returns:
193
+ PILImage: The fixed image.
194
+ """
195
+ return cast(PILImage, ImageOps.exif_transpose(img))
196
 
197
 
198
  def download_models() -> None:
199
+ """
200
+ Download models for image processing.
201
+ """
202
  for session in sessions_class:
203
  session.download_models()
204
 
 
213
  only_mask: bool = False,
214
  post_process_mask: bool = False,
215
  bgcolor: Optional[Tuple[int, int, int, int]] = None,
216
+ force_return_bytes: bool = False,
217
  *args: Optional[Any],
218
  **kwargs: Optional[Any]
219
  ) -> Union[bytes, PILImage, np.ndarray]:
220
+ """
221
+ Remove the background from an input image.
222
+
223
+ This function takes in various parameters and returns a modified version of the input image with the background removed. The function can handle input data in the form of bytes, a PIL image, or a numpy array. The function first checks the type of the input data and converts it to a PIL image if necessary. It then fixes the orientation of the image and proceeds to perform background removal using the 'u2net' model. The result is a list of binary masks representing the foreground objects in the image. These masks are post-processed and combined to create a final cutout image. If a background color is provided, it is applied to the cutout image. The function returns the resulting cutout image in the format specified by the input 'return_type' parameter or as python bytes if force_return_bytes is true.
224
+
225
+ Parameters:
226
+ data (Union[bytes, PILImage, np.ndarray]): The input image data.
227
+ alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
228
+ alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
229
+ alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
230
+ alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
231
+ session (Optional[BaseSession], optional): A session object for the 'u2net' model. Defaults to None.
232
+ only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
233
+ post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
234
+ bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.
235
+ force_return_bytes (bool, optional): Flag indicating whether to return the cutout image as bytes. Defaults to False.
236
+ *args (Optional[Any]): Additional positional arguments.
237
+ **kwargs (Optional[Any]): Additional keyword arguments.
238
+
239
+ Returns:
240
+ Union[bytes, PILImage, np.ndarray]: The cutout image with the background removed.
241
+ """
242
+ if isinstance(data, bytes) or force_return_bytes:
243
  return_type = ReturnType.BYTES
244
+ img = cast(PILImage, Image.open(io.BytesIO(cast(bytes, data))))
245
+ elif isinstance(data, PILImage):
246
+ return_type = ReturnType.PILLOW
247
+ img = cast(PILImage, data)
248
  elif isinstance(data, np.ndarray):
249
  return_type = ReturnType.NDARRAY
250
+ img = cast(PILImage, Image.fromarray(data))
251
  else:
252
+ raise ValueError(
253
+ "Input type {} is not supported. Try using force_return_bytes=True to force python bytes output".format(
254
+ type(data)
255
+ )
256
+ )
257
+
258
+ putalpha = kwargs.pop("putalpha", False)
259
 
260
  # Fix image orientation
261
  img = fix_image_orientation(img)
 
283
  alpha_matting_erode_size,
284
  )
285
  except ValueError:
286
+ if putalpha:
287
+ cutout = putalpha_cutout(img, mask)
288
+ else:
289
+ cutout = naive_cutout(img, mask)
290
  else:
291
+ if putalpha:
292
+ cutout = putalpha_cutout(img, mask)
293
+ else:
294
+ cutout = naive_cutout(img, mask)
295
 
296
  cutouts.append(cutout)
297
 
rembg/cli.py CHANGED
@@ -6,9 +6,11 @@ from .commands import command_functions
6
 
7
  @click.group()
8
  @click.version_option(version=_version.get_versions()["version"])
9
- def main() -> None:
10
  pass
11
 
12
 
13
  for command in command_functions:
14
- main.add_command(command)
 
 
 
6
 
7
  @click.group()
8
  @click.version_option(version=_version.get_versions()["version"])
9
+ def _main() -> None:
10
  pass
11
 
12
 
13
  for command in command_functions:
14
+ _main.add_command(command)
15
+
16
+ _main()
rembg/commands/__init__.py CHANGED
@@ -1,13 +1,13 @@
1
- from importlib import import_module
2
- from pathlib import Path
3
- from pkgutil import iter_modules
4
-
5
  command_functions = []
6
 
7
- package_dir = Path(__file__).resolve().parent
8
- for _b, module_name, _p in iter_modules([str(package_dir)]):
9
- module = import_module(f"{__name__}.{module_name}")
10
- for attribute_name in dir(module):
11
- attribute = getattr(module, attribute_name)
12
- if attribute_name.endswith("_command"):
13
- command_functions.append(attribute)
 
 
 
 
 
 
 
 
 
1
  command_functions = []
2
 
3
+ from .b_command import b_command
4
+ from .d_command import d_command
5
+ from .i_command import i_command
6
+ from .p_command import p_command
7
+ from .s_command import s_command
8
+
9
+ command_functions.append(b_command)
10
+ command_functions.append(d_command)
11
+ command_functions.append(i_command)
12
+ command_functions.append(p_command)
13
+ command_functions.append(s_command)
rembg/commands/b_command.py CHANGED
@@ -6,14 +6,14 @@ import sys
6
  from typing import IO
7
 
8
  import click
9
- from PIL import Image
10
 
11
  from ..bg import remove
12
  from ..session_factory import new_session
13
  from ..sessions import sessions_names
14
 
15
 
16
- @click.command(
17
  name="b",
18
  help="for a byte stream as input",
19
  )
@@ -74,7 +74,7 @@ from ..sessions import sessions_names
74
  @click.option(
75
  "-bgc",
76
  "--bgcolor",
77
- default=None,
78
  type=(int, int, int, int),
79
  nargs=4,
80
  help="Background color (R G B A) to replace the removed background with",
@@ -94,7 +94,7 @@ from ..sessions import sessions_names
94
  "image_height",
95
  type=int,
96
  )
97
- def rs_command(
98
  model: str,
99
  extras: str,
100
  image_width: int,
@@ -102,12 +102,28 @@ def rs_command(
102
  output_specifier: str,
103
  **kwargs
104
  ) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  try:
106
  kwargs.update(json.loads(extras))
107
  except Exception:
108
  pass
109
 
110
- session = new_session(model)
111
  bytes_per_img = image_width * image_height * 3
112
 
113
  if output_specifier:
@@ -118,7 +134,7 @@ def rs_command(
118
  if not os.path.isdir(output_dir):
119
  os.makedirs(output_dir, exist_ok=True)
120
 
121
- def img_to_byte_array(img: Image) -> bytes:
122
  buff = io.BytesIO()
123
  img.save(buff, format="PNG")
124
  return buff.getvalue()
@@ -146,7 +162,7 @@ def rs_command(
146
  if not img_bytes:
147
  break
148
 
149
- img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
150
  output = remove(img, session=session, **kwargs)
151
 
152
  if output_specifier:
 
6
  from typing import IO
7
 
8
  import click
9
+ from PIL.Image import Image as PILImage
10
 
11
  from ..bg import remove
12
  from ..session_factory import new_session
13
  from ..sessions import sessions_names
14
 
15
 
16
+ @click.command( # type: ignore
17
  name="b",
18
  help="for a byte stream as input",
19
  )
 
74
  @click.option(
75
  "-bgc",
76
  "--bgcolor",
77
+ default=(0, 0, 0, 0),
78
  type=(int, int, int, int),
79
  nargs=4,
80
  help="Background color (R G B A) to replace the removed background with",
 
94
  "image_height",
95
  type=int,
96
  )
97
+ def b_command(
98
  model: str,
99
  extras: str,
100
  image_width: int,
 
102
  output_specifier: str,
103
  **kwargs
104
  ) -> None:
105
+ """
106
+ Command-line interface for processing images by removing the background using a specified model and generating a mask.
107
+
108
+ This CLI command takes several options and arguments to configure the background removal process and save the processed images.
109
+
110
+ Parameters:
111
+ model (str): The name of the model to use for background removal.
112
+ extras (str): Additional options in JSON format that can be passed to customize the background removal process.
113
+ image_width (int): The width of the input images in pixels.
114
+ image_height (int): The height of the input images in pixels.
115
+ output_specifier (str): A printf-style specifier for the output filenames. If specified, the processed images will be saved to the specified output directory with filenames generated using the specifier.
116
+ **kwargs: Additional keyword arguments that can be used to customize the background removal process.
117
+
118
+ Returns:
119
+ None
120
+ """
121
  try:
122
  kwargs.update(json.loads(extras))
123
  except Exception:
124
  pass
125
 
126
+ session = new_session(model, **kwargs)
127
  bytes_per_img = image_width * image_height * 3
128
 
129
  if output_specifier:
 
134
  if not os.path.isdir(output_dir):
135
  os.makedirs(output_dir, exist_ok=True)
136
 
137
+ def img_to_byte_array(img: PILImage) -> bytes:
138
  buff = io.BytesIO()
139
  img.save(buff, format="PNG")
140
  return buff.getvalue()
 
162
  if not img_bytes:
163
  break
164
 
165
+ img = PILImage.frombytes("RGB", (image_width, image_height), img_bytes)
166
  output = remove(img, session=session, **kwargs)
167
 
168
  if output_specifier:
rembg/commands/d_command.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+
3
+ from ..bg import download_models
4
+
5
+
6
+ @click.command( # type: ignore
7
+ name="d",
8
+ help="download all models",
9
+ )
10
+ def d_command(*args, **kwargs) -> None:
11
+ """
12
+ Download all models
13
+ """
14
+ download_models()
rembg/commands/i_command.py CHANGED
@@ -9,7 +9,7 @@ from ..session_factory import new_session
9
  from ..sessions import sessions_names
10
 
11
 
12
- @click.command(
13
  name="i",
14
  help="for a file as input",
15
  )
@@ -70,7 +70,7 @@ from ..sessions import sessions_names
70
  @click.option(
71
  "-bgc",
72
  "--bgcolor",
73
- default=None,
74
  type=(int, int, int, int),
75
  nargs=4,
76
  help="Background color (R G B A) to replace the removed background with",
@@ -85,9 +85,24 @@ from ..sessions import sessions_names
85
  type=click.File("wb", lazy=True),
86
  )
87
  def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  try:
89
  kwargs.update(json.loads(extras))
90
  except Exception:
91
  pass
92
 
93
- output.write(remove(input.read(), session=new_session(model), **kwargs))
 
9
  from ..sessions import sessions_names
10
 
11
 
12
+ @click.command( # type: ignore
13
  name="i",
14
  help="for a file as input",
15
  )
 
70
  @click.option(
71
  "-bgc",
72
  "--bgcolor",
73
+ default=(0, 0, 0, 0),
74
  type=(int, int, int, int),
75
  nargs=4,
76
  help="Background color (R G B A) to replace the removed background with",
 
85
  type=click.File("wb", lazy=True),
86
  )
87
  def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
88
+ """
89
+ Click command line interface function to process an input file based on the provided options.
90
+
91
+ This function is the entry point for the CLI program. It reads an input file, applies image processing operations based on the provided options, and writes the output to a file.
92
+
93
+ Parameters:
94
+ model (str): The name of the model to use for image processing.
95
+ extras (str): Additional options in JSON format.
96
+ input: The input file to process.
97
+ output: The output file to write the processed image to.
98
+ **kwargs: Additional keyword arguments corresponding to the command line options.
99
+
100
+ Returns:
101
+ None
102
+ """
103
  try:
104
  kwargs.update(json.loads(extras))
105
  except Exception:
106
  pass
107
 
108
+ output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))
rembg/commands/p_command.py CHANGED
@@ -14,7 +14,7 @@ from ..session_factory import new_session
14
  from ..sessions import sessions_names
15
 
16
 
17
- @click.command(
18
  name="p",
19
  help="for a folder as input",
20
  )
@@ -80,10 +80,18 @@ from ..sessions import sessions_names
80
  show_default=True,
81
  help="watches a folder for changes",
82
  )
 
 
 
 
 
 
 
 
83
  @click.option(
84
  "-bgc",
85
  "--bgcolor",
86
- default=None,
87
  type=(int, int, int, int),
88
  nargs=4,
89
  help="Background color (R G B A) to replace the removed background with",
@@ -115,14 +123,36 @@ def p_command(
115
  input: pathlib.Path,
116
  output: pathlib.Path,
117
  watch: bool,
 
118
  **kwargs,
119
  ) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  try:
121
  kwargs.update(json.loads(extras))
122
  except Exception:
123
  pass
124
 
125
- session = new_session(model)
126
 
127
  def process(each_input: pathlib.Path) -> None:
128
  try:
@@ -147,33 +177,48 @@ def p_command(
147
  print(
148
  f"processed: {each_input.absolute()} -> {each_output.absolute()}"
149
  )
 
 
 
 
150
  except Exception as e:
151
  print(e)
152
 
153
  inputs = list(input.glob("**/*"))
154
  if not watch:
155
- inputs = tqdm(inputs)
156
 
157
- for each_input in inputs:
158
  if not each_input.is_dir():
159
  process(each_input)
160
 
161
  if watch:
 
162
  observer = Observer()
163
 
164
  class EventHandler(FileSystemEventHandler):
165
  def on_any_event(self, event: FileSystemEvent) -> None:
166
- if not (
167
- event.is_directory or event.event_type in ["deleted", "closed"]
 
 
 
 
168
  ):
169
- process(pathlib.Path(event.src_path))
 
 
 
 
 
 
170
 
171
  event_handler = EventHandler()
172
- observer.schedule(event_handler, input, recursive=False)
173
  observer.start()
174
 
175
  try:
176
- while True:
177
  time.sleep(1)
178
 
179
  finally:
 
14
  from ..sessions import sessions_names
15
 
16
 
17
+ @click.command( # type: ignore
18
  name="p",
19
  help="for a folder as input",
20
  )
 
80
  show_default=True,
81
  help="watches a folder for changes",
82
  )
83
+ @click.option(
84
+ "-d",
85
+ "--delete_input",
86
+ default=False,
87
+ is_flag=True,
88
+ show_default=True,
89
+ help="delete input file after processing",
90
+ )
91
  @click.option(
92
  "-bgc",
93
  "--bgcolor",
94
+ default=(0, 0, 0, 0),
95
  type=(int, int, int, int),
96
  nargs=4,
97
  help="Background color (R G B A) to replace the removed background with",
 
123
  input: pathlib.Path,
124
  output: pathlib.Path,
125
  watch: bool,
126
+ delete_input: bool,
127
  **kwargs,
128
  ) -> None:
129
+ """
130
+ Command-line interface (CLI) program for performing background removal on images in a folder.
131
+
132
+ This program takes a folder as input and uses a specified model to remove the background from the images in the folder.
133
+ It provides various options for configuration, such as choosing the model, enabling alpha matting, setting trimap thresholds, erode size, etc.
134
+ Additional options include outputting only the mask and post-processing the mask.
135
+ The program can also watch the input folder for changes and automatically process new images.
136
+ The resulting images with the background removed are saved in the specified output folder.
137
+
138
+ Parameters:
139
+ model (str): The name of the model to use for background removal.
140
+ extras (str): Additional options in JSON format.
141
+ input (pathlib.Path): The path to the input folder.
142
+ output (pathlib.Path): The path to the output folder.
143
+ watch (bool): Whether to watch the input folder for changes.
144
+ delete_input (bool): Whether to delete the input file after processing.
145
+ **kwargs: Additional keyword arguments.
146
+
147
+ Returns:
148
+ None
149
+ """
150
  try:
151
  kwargs.update(json.loads(extras))
152
  except Exception:
153
  pass
154
 
155
+ session = new_session(model, **kwargs)
156
 
157
  def process(each_input: pathlib.Path) -> None:
158
  try:
 
177
  print(
178
  f"processed: {each_input.absolute()} -> {each_output.absolute()}"
179
  )
180
+
181
+ if delete_input:
182
+ each_input.unlink()
183
+
184
  except Exception as e:
185
  print(e)
186
 
187
  inputs = list(input.glob("**/*"))
188
  if not watch:
189
+ inputs_tqdm = tqdm(inputs)
190
 
191
+ for each_input in inputs_tqdm:
192
  if not each_input.is_dir():
193
  process(each_input)
194
 
195
  if watch:
196
+ should_watch = True
197
  observer = Observer()
198
 
199
  class EventHandler(FileSystemEventHandler):
200
  def on_any_event(self, event: FileSystemEvent) -> None:
201
+ src_path = cast(str, event.src_path)
202
+ if (
203
+ not (
204
+ event.is_directory or event.event_type in ["deleted", "closed"]
205
+ )
206
+ and pathlib.Path(src_path).exists()
207
  ):
208
+ if src_path.endswith("stop.txt"):
209
+ nonlocal should_watch
210
+ should_watch = False
211
+ pathlib.Path(src_path).unlink()
212
+ return
213
+
214
+ process(pathlib.Path(src_path))
215
 
216
  event_handler = EventHandler()
217
+ observer.schedule(event_handler, str(input), recursive=False)
218
  observer.start()
219
 
220
  try:
221
+ while should_watch:
222
  time.sleep(1)
223
 
224
  finally:
rembg/commands/s_command.py CHANGED
@@ -19,18 +19,26 @@ from ..sessions import sessions_names
19
  from ..sessions.base import BaseSession
20
 
21
 
22
- @click.command(
23
  name="s",
24
  help="for a http server",
25
  )
26
  @click.option(
27
  "-p",
28
  "--port",
29
- default=5000,
30
  type=int,
31
  show_default=True,
32
  help="port",
33
  )
 
 
 
 
 
 
 
 
34
  @click.option(
35
  "-l",
36
  "--log_level",
@@ -47,7 +55,13 @@ from ..sessions.base import BaseSession
47
  show_default=True,
48
  help="number of worker threads",
49
  )
50
- def s_command(port: int, log_level: str, threads: int) -> None:
 
 
 
 
 
 
51
  sessions: dict[str, BaseSession] = {}
52
  tags_metadata = [
53
  {
@@ -186,7 +200,9 @@ def s_command(port: int, log_level: str, threads: int) -> None:
186
  return Response(
187
  remove(
188
  content,
189
- session=sessions.setdefault(commons.model, new_session(commons.model)),
 
 
190
  alpha_matting=commons.a,
191
  alpha_matting_foreground_threshold=commons.af,
192
  alpha_matting_background_threshold=commons.ab,
@@ -245,12 +261,27 @@ def s_command(port: int, log_level: str, threads: int) -> None:
245
  return await asyncify(im_without_bg)(file, commons) # type: ignore
246
 
247
  def gr_app(app):
248
- def inference(input_path, model):
249
  output_path = "output.png"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  with open(input_path, "rb") as i:
251
  with open(output_path, "wb") as o:
252
  input = i.read()
253
- output = remove(input, session=new_session(model))
254
  o.write(output)
255
  return os.path.join(output_path)
256
 
@@ -258,28 +289,33 @@ def s_command(port: int, log_level: str, threads: int) -> None:
258
  inference,
259
  [
260
  gr.components.Image(type="filepath", label="Input"),
261
- gr.components.Dropdown(
262
- [
263
- "u2net",
264
- "u2netp",
265
- "u2net_human_seg",
266
- "u2net_cloth_seg",
267
- "silueta",
268
- "isnet-general-use",
269
- "isnet-anime",
270
- ],
271
- value="u2net",
272
- label="Models",
273
  ),
 
 
 
 
 
 
274
  ],
275
  gr.components.Image(type="filepath", label="Output"),
 
276
  )
277
 
278
- interface.queue(concurrency_count=3)
279
  app = gr.mount_gradio_app(app, interface, path="/")
280
  return app
281
 
282
- print(f"To access the API documentation, go to http://localhost:{port}/api")
283
- print(f"To access the UI, go to http://localhost:{port}")
 
 
 
 
284
 
285
- uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)
 
19
  from ..sessions.base import BaseSession
20
 
21
 
22
+ @click.command( # type: ignore
23
  name="s",
24
  help="for a http server",
25
  )
26
  @click.option(
27
  "-p",
28
  "--port",
29
+ default=7000,
30
  type=int,
31
  show_default=True,
32
  help="port",
33
  )
34
+ @click.option(
35
+ "-h",
36
+ "--host",
37
+ default="0.0.0.0",
38
+ type=str,
39
+ show_default=True,
40
+ help="host",
41
+ )
42
  @click.option(
43
  "-l",
44
  "--log_level",
 
55
  show_default=True,
56
  help="number of worker threads",
57
  )
58
+ def s_command(port: int, host: str, log_level: str, threads: int) -> None:
59
+ """
60
+ Command-line interface for running the FastAPI web server.
61
+
62
+ This function starts the FastAPI web server with the specified port and log level.
63
+ If the number of worker threads is specified, it sets the thread limiter accordingly.
64
+ """
65
  sessions: dict[str, BaseSession] = {}
66
  tags_metadata = [
67
  {
 
200
  return Response(
201
  remove(
202
  content,
203
+ session=sessions.setdefault(
204
+ commons.model, new_session(commons.model, **kwargs)
205
+ ),
206
  alpha_matting=commons.a,
207
  alpha_matting_foreground_threshold=commons.af,
208
  alpha_matting_background_threshold=commons.ab,
 
261
  return await asyncify(im_without_bg)(file, commons) # type: ignore
262
 
263
  def gr_app(app):
264
+ def inference(input_path, model, *args):
265
  output_path = "output.png"
266
+ a, af, ab, ae, om, ppm, cmd_args = args
267
+
268
+ kwargs = {
269
+ "alpha_matting": a,
270
+ "alpha_matting_foreground_threshold": af,
271
+ "alpha_matting_background_threshold": ab,
272
+ "alpha_matting_erode_size": ae,
273
+ "only_mask": om,
274
+ "post_process_mask": ppm,
275
+ }
276
+
277
+ if cmd_args:
278
+ kwargs.update(json.loads(cmd_args))
279
+ kwargs["session"] = new_session(model, **kwargs)
280
+
281
  with open(input_path, "rb") as i:
282
  with open(output_path, "wb") as o:
283
  input = i.read()
284
+ output = remove(input, **kwargs)
285
  o.write(output)
286
  return os.path.join(output_path)
287
 
 
289
  inference,
290
  [
291
  gr.components.Image(type="filepath", label="Input"),
292
+ gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
293
+ gr.components.Checkbox(value=True, label="Alpha matting"),
294
+ gr.components.Slider(
295
+ value=240, minimum=0, maximum=255, label="Foreground threshold"
296
+ ),
297
+ gr.components.Slider(
298
+ value=10, minimum=0, maximum=255, label="Background threshold"
 
 
 
 
 
299
  ),
300
+ gr.components.Slider(
301
+ value=40, minimum=0, maximum=255, label="Erosion size"
302
+ ),
303
+ gr.components.Checkbox(value=False, label="Only mask"),
304
+ gr.components.Checkbox(value=True, label="Post process mask"),
305
+ gr.components.Textbox(label="Arguments"),
306
  ],
307
  gr.components.Image(type="filepath", label="Output"),
308
+ concurrency_limit=3,
309
  )
310
 
 
311
  app = gr.mount_gradio_app(app, interface, path="/")
312
  return app
313
 
314
+ print(
315
+ f"To access the API documentation, go to http://{'localhost' if host == '0.0.0.0' else host}:{port}/api"
316
+ )
317
+ print(
318
+ f"To access the UI, go to http://{'localhost' if host == '0.0.0.0' else host}:{port}"
319
+ )
320
 
321
+ uvicorn.run(gr_app(app), host=host, port=port, log_level=log_level)
rembg/session_factory.py CHANGED
@@ -11,6 +11,23 @@ from .sessions.u2net import U2netSession
11
  def new_session(
12
  model_name: str = "u2net", providers=None, *args, **kwargs
13
  ) -> BaseSession:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  session_class: Type[BaseSession] = U2netSession
15
 
16
  for sc in sessions_class:
@@ -22,5 +39,6 @@ def new_session(
22
 
23
  if "OMP_NUM_THREADS" in os.environ:
24
  sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
 
25
 
26
  return session_class(model_name, sess_opts, providers, *args, **kwargs)
 
11
  def new_session(
12
  model_name: str = "u2net", providers=None, *args, **kwargs
13
  ) -> BaseSession:
14
+ """
15
+ Create a new session object based on the specified model name.
16
+
17
+ This function searches for the session class based on the model name in the 'sessions_class' list.
18
+ It then creates an instance of the session class with the provided arguments.
19
+ The 'sess_opts' object is created using the 'ort.SessionOptions()' constructor.
20
+ If the 'OMP_NUM_THREADS' environment variable is set, the 'inter_op_num_threads' option of 'sess_opts' is set to its value.
21
+
22
+ Parameters:
23
+ model_name (str): The name of the model.
24
+ providers: The providers for the session.
25
+ *args: Additional positional arguments.
26
+ **kwargs: Additional keyword arguments.
27
+
28
+ Returns:
29
+ BaseSession: The created session object.
30
+ """
31
  session_class: Type[BaseSession] = U2netSession
32
 
33
  for sc in sessions_class:
 
39
 
40
  if "OMP_NUM_THREADS" in os.environ:
41
  sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
42
+ sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
43
 
44
  return session_class(model_name, sess_opts, providers, *args, **kwargs)
rembg/sessions/__init__.py CHANGED
@@ -1,22 +1,88 @@
1
- from importlib import import_module
2
- from inspect import isclass
3
- from pathlib import Path
4
- from pkgutil import iter_modules
5
 
6
  from .base import BaseSession
7
 
8
- sessions_class = []
9
- sessions_names = []
10
-
11
- package_dir = Path(__file__).resolve().parent
12
- for _b, module_name, _p in iter_modules([str(package_dir)]):
13
- module = import_module(f"{__name__}.{module_name}")
14
- for attribute_name in dir(module):
15
- attribute = getattr(module, attribute_name)
16
- if (
17
- isclass(attribute)
18
- and issubclass(attribute, BaseSession)
19
- and attribute != BaseSession
20
- ):
21
- sessions_class.append(attribute)
22
- sessions_names.append(attribute.name())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List
 
4
 
5
  from .base import BaseSession
6
 
7
+ sessions_class: List[type[BaseSession]] = []
8
+ sessions_names: List[str] = []
9
+
10
+ from .birefnet_general import BiRefNetSessionGeneral
11
+
12
+ sessions_class.append(BiRefNetSessionGeneral)
13
+ sessions_names.append(BiRefNetSessionGeneral.name())
14
+
15
+ from .birefnet_general_lite import BiRefNetSessionGeneralLite
16
+
17
+ sessions_class.append(BiRefNetSessionGeneralLite)
18
+ sessions_names.append(BiRefNetSessionGeneralLite.name())
19
+
20
+ from .birefnet_portrait import BiRefNetSessionPortrait
21
+
22
+ sessions_class.append(BiRefNetSessionPortrait)
23
+ sessions_names.append(BiRefNetSessionPortrait.name())
24
+
25
+ from .birefnet_dis import BiRefNetSessionDIS
26
+
27
+ sessions_class.append(BiRefNetSessionDIS)
28
+ sessions_names.append(BiRefNetSessionDIS.name())
29
+
30
+ from .birefnet_hrsod import BiRefNetSessionHRSOD
31
+
32
+ sessions_class.append(BiRefNetSessionHRSOD)
33
+ sessions_names.append(BiRefNetSessionHRSOD.name())
34
+
35
+ from .birefnet_cod import BiRefNetSessionCOD
36
+
37
+ sessions_class.append(BiRefNetSessionCOD)
38
+ sessions_names.append(BiRefNetSessionCOD.name())
39
+
40
+ from .birefnet_massive import BiRefNetSessionMassive
41
+
42
+ sessions_class.append(BiRefNetSessionMassive)
43
+ sessions_names.append(BiRefNetSessionMassive.name())
44
+
45
+ from .dis_anime import DisSession
46
+
47
+ sessions_class.append(DisSession)
48
+ sessions_names.append(DisSession.name())
49
+
50
+ from .dis_general_use import DisSession as DisSessionGeneralUse
51
+
52
+ sessions_class.append(DisSessionGeneralUse)
53
+ sessions_names.append(DisSessionGeneralUse.name())
54
+
55
+ from .sam import SamSession
56
+
57
+ sessions_class.append(SamSession)
58
+ sessions_names.append(SamSession.name())
59
+
60
+ from .silueta import SiluetaSession
61
+
62
+ sessions_class.append(SiluetaSession)
63
+ sessions_names.append(SiluetaSession.name())
64
+
65
+ from .u2net_cloth_seg import Unet2ClothSession
66
+
67
+ sessions_class.append(Unet2ClothSession)
68
+ sessions_names.append(Unet2ClothSession.name())
69
+
70
+ from .u2net_custom import U2netCustomSession
71
+
72
+ sessions_class.append(U2netCustomSession)
73
+ sessions_names.append(U2netCustomSession.name())
74
+
75
+ from .u2net_human_seg import U2netHumanSegSession
76
+
77
+ sessions_class.append(U2netHumanSegSession)
78
+ sessions_names.append(U2netHumanSegSession.name())
79
+
80
+ from .u2net import U2netSession
81
+
82
+ sessions_class.append(U2netSession)
83
+ sessions_names.append(U2netSession.name())
84
+
85
+ from .u2netp import U2netpSession
86
+
87
+ sessions_class.append(U2netpSession)
88
+ sessions_names.append(U2netpSession.name())
rembg/sessions/base.py CHANGED
@@ -8,6 +8,8 @@ from PIL.Image import Image as PILImage
8
 
9
 
10
  class BaseSession:
 
 
11
  def __init__(
12
  self,
13
  model_name: str,
@@ -16,6 +18,7 @@ class BaseSession:
16
  *args,
17
  **kwargs
18
  ):
 
19
  self.model_name = model_name
20
 
21
  self.providers = []
@@ -29,7 +32,7 @@ class BaseSession:
29
  self.providers.extend(_providers)
30
 
31
  self.inner_session = ort.InferenceSession(
32
- str(self.__class__.download_models()),
33
  providers=self.providers,
34
  sess_options=sess_opts,
35
  )
@@ -43,7 +46,7 @@ class BaseSession:
43
  *args,
44
  **kwargs
45
  ) -> Dict[str, np.ndarray]:
46
- im = img.convert("RGB").resize(size, Image.LANCZOS)
47
 
48
  im_ary = np.array(im)
49
  im_ary = im_ary / np.max(im_ary)
 
8
 
9
 
10
  class BaseSession:
11
+ """This is a base class for managing a session with a machine learning model."""
12
+
13
  def __init__(
14
  self,
15
  model_name: str,
 
18
  *args,
19
  **kwargs
20
  ):
21
+ """Initialize an instance of the BaseSession class."""
22
  self.model_name = model_name
23
 
24
  self.providers = []
 
32
  self.providers.extend(_providers)
33
 
34
  self.inner_session = ort.InferenceSession(
35
+ str(self.__class__.download_models(*args, **kwargs)),
36
  providers=self.providers,
37
  sess_options=sess_opts,
38
  )
 
46
  *args,
47
  **kwargs
48
  ) -> Dict[str, np.ndarray]:
49
+ im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
50
 
51
  im_ary = np.array(im)
52
  im_ary = im_ary / np.max(im_ary)
rembg/sessions/birefnet_cod.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pooch
4
+
5
+ from . import BiRefNetSessionGeneral
6
+
7
+
8
class BiRefNetSessionCOD(BiRefNetSessionGeneral):
    """
    BiRefNet session bound to the "BiRefNet-COD" ONNX checkpoint.

    Only the model location and the session name differ from
    BiRefNetSessionGeneral; everything else is inherited.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Retrieve the BiRefNet-COD ONNX file through pooch and return its path.

        Parameters:
            *args: Positional arguments forwarded to the class helpers.
            **kwargs: Keyword arguments forwarded to the class helpers.

        Returns:
            str: The model directory joined with "<session name>.onnx".
        """
        model_file = "{}.onnx".format(cls.name(*args, **kwargs))

        # No hash is handed to pooch when checksum verification is disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            str: The constant session name.
        """
        return "birefnet-cod"
rembg/sessions/birefnet_dis.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pooch
4
+
5
+ from . import BiRefNetSessionGeneral
6
+
7
+
8
class BiRefNetSessionDIS(BiRefNetSessionGeneral):
    """
    BiRefNet session bound to the "BiRefNet-DIS" ONNX checkpoint.

    Only the model location and the session name differ from
    BiRefNetSessionGeneral; everything else is inherited.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Retrieve the BiRefNet-DIS ONNX file through pooch and return its path.

        Parameters:
            *args: Positional arguments forwarded to the class helpers.
            **kwargs: Keyword arguments forwarded to the class helpers.

        Returns:
            str: The model directory joined with "<session name>.onnx".
        """
        model_file = "{}.onnx".format(cls.name(*args, **kwargs))

        # No hash is handed to pooch when checksum verification is disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:2d4d44102b446f33a4ebb2e56c051f2b"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            str: The constant session name.
        """
        return "birefnet-dis"
rembg/sessions/birefnet_general.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
class BiRefNetSessionGeneral(BaseSession):
    """
    Session for the BiRefNet-General model, built on BaseSession.

    The other BiRefNet sessions subclass this one and override only
    `download_models` and `name`.
    """

    def sigmoid(self, mat):
        """
        Apply the logistic function element-wise.

        Parameters:
            mat: A NumPy array (or scalar) of raw model outputs.

        Returns:
            The values of `mat` mapped through 1 / (1 + exp(-x)).
        """
        return 1 / (1 + np.exp(-mat))

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Run the model on an image and return its mask.

        The image is normalized to 1024x1024, run through the inner ONNX
        session, and the first output is passed through a sigmoid. The result
        is min-max rescaled, converted to an 8-bit "L" image, and resized back
        to the size of the input image.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments (unused here).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            List[PILImage]: A single-element list holding the mask.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        pred = self.sigmoid(ort_outs[0][:, 0, :, :])

        ma = np.max(pred)
        mi = np.min(pred)
        value_range = ma - mi

        # A perfectly uniform prediction has a zero range; dividing by it
        # would fill the mask with NaN, which cannot be cast to uint8
        # meaningfully. In that case keep the sigmoid output, which already
        # lies in [0, 1].
        if value_range > 0:
            pred = (pred - mi) / value_range
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Retrieve the BiRefNet-General ONNX file through pooch and return its path.

        Parameters:
            *args: Positional arguments forwarded to the class helpers.
            **kwargs: Keyword arguments forwarded to the class helpers.

        Returns:
            str: The model directory joined with "<session name>.onnx".
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
            (
                # No hash is handed to pooch when checksum verification is disabled.
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:7a35a0141cbbc80de11d9c9a28f52697"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            str: The constant session name.
        """
        return "birefnet-general"
rembg/sessions/birefnet_general_lite.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pooch
4
+
5
+ from . import BiRefNetSessionGeneral
6
+
7
+
8
class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):
    """
    BiRefNet session bound to the "BiRefNet-General-Lite" ONNX checkpoint.

    Only the model location and the session name differ from
    BiRefNetSessionGeneral; everything else is inherited.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Retrieve the BiRefNet-General-Lite ONNX file through pooch and return its path.

        Parameters:
            *args: Positional arguments forwarded to the class helpers.
            **kwargs: Keyword arguments forwarded to the class helpers.

        Returns:
            str: The model directory joined with "<session name>.onnx".
        """
        model_file = "{}.onnx".format(cls.name(*args, **kwargs))

        # No hash is handed to pooch when checksum verification is disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:4fab47adc4ff364be1713e97b7e66334"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            str: The constant session name.
        """
        return "birefnet-general-lite"
rembg/sessions/birefnet_hrsod.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pooch
4
+
5
+ from . import BiRefNetSessionGeneral
6
+
7
+
8
class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
    """
    BiRefNet session bound to the "BiRefNet-HRSOD" ONNX checkpoint.

    Only the model location and the session name differ from
    BiRefNetSessionGeneral; everything else is inherited.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Retrieve the BiRefNet-HRSOD ONNX file through pooch and return its path.

        Parameters:
            *args: Positional arguments forwarded to the class helpers.
            **kwargs: Keyword arguments forwarded to the class helpers.

        Returns:
            str: The model directory joined with "<session name>.onnx".
        """
        model_file = "{}.onnx".format(cls.name(*args, **kwargs))

        # No hash is handed to pooch when checksum verification is disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:c017ade5de8a50ff0fd74d790d268dda"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            str: The constant session name.
        """
        return "birefnet-hrsod"
rembg/sessions/birefnet_massive.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pooch
4
+
5
+ from . import BiRefNetSessionGeneral
6
+
7
+
8
class BiRefNetSessionMassive(BiRefNetSessionGeneral):
    """
    BiRefNet session bound to the "BiRefNet-Massive" ONNX checkpoint.

    Only the model location and the session name differ from
    BiRefNetSessionGeneral; everything else is inherited.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Retrieve the BiRefNet-Massive ONNX file through pooch and return its path.

        Parameters:
            *args: Positional arguments forwarded to the class helpers.
            **kwargs: Keyword arguments forwarded to the class helpers.

        Returns:
            str: The model directory joined with "<session name>.onnx".
        """
        model_file = "{}.onnx".format(cls.name(*args, **kwargs))

        # No hash is handed to pooch when checksum verification is disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:33e726a2136a3d59eb0fdf613e31e3e9"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            str: The constant session name.
        """
        return "birefnet-massive"
rembg/sessions/birefnet_portrait.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pooch
4
+
5
+ from . import BiRefNetSessionGeneral
6
+
7
+
8
class BiRefNetSessionPortrait(BiRefNetSessionGeneral):
    """
    BiRefNet session bound to the "BiRefNet-Portrait" ONNX checkpoint.

    Only the model location and the session name differ from
    BiRefNetSessionGeneral; everything else is inherited.
    """

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Retrieve the BiRefNet-Portrait ONNX file through pooch and return its path.

        Parameters:
            *args: Positional arguments forwarded to the class helpers.
            **kwargs: Keyword arguments forwarded to the class helpers.

        Returns:
            str: The model directory joined with "<session name>.onnx".
        """
        model_file = "{}.onnx".format(cls.name(*args, **kwargs))

        # No hash is handed to pooch when checksum verification is disabled.
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:c3a64a6abf20250d090cd055f12a3b67"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Unused positional arguments.
            **kwargs: Unused keyword arguments.

        Returns:
            str: The constant session name.
        """
        return "birefnet-portrait"
rembg/sessions/dis_anime.py CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession
10
 
11
 
12
  class DisSession(BaseSession):
 
 
 
 
13
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
 
 
 
 
 
 
 
 
 
 
 
14
  ort_outs = self.inner_session.run(
15
  None,
16
  self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
@@ -25,25 +40,47 @@ class DisSession(BaseSession):
25
  pred = np.squeeze(pred)
26
 
27
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
- mask = mask.resize(img.size, Image.LANCZOS)
29
 
30
  return [mask]
31
 
32
  @classmethod
33
  def download_models(cls, *args, **kwargs):
34
- fname = f"{cls.name()}.onnx"
 
 
 
 
 
 
 
 
 
 
35
  pooch.retrieve(
36
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
37
- None
38
- if cls.checksum_disabled(*args, **kwargs)
39
- else "md5:6f184e756bb3bd901c8849220a83e38e",
 
 
40
  fname=fname,
41
  path=cls.u2net_home(*args, **kwargs),
42
  progressbar=True,
43
  )
44
 
45
- return os.path.join(cls.u2net_home(), fname)
46
 
47
  @classmethod
48
  def name(cls, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
49
  return "isnet-anime"
 
10
 
11
 
12
  class DisSession(BaseSession):
13
+ """
14
+ This class represents a session for object detection.
15
+ """
16
+
17
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
18
+ """
19
+ Use a pre-trained model to predict the object in the given image.
20
+
21
+ Parameters:
22
+ img (PILImage): The input image.
23
+ *args: Variable length argument list.
24
+ **kwargs: Arbitrary keyword arguments.
25
+
26
+ Returns:
27
+ List[PILImage]: A list of predicted mask images.
28
+ """
29
  ort_outs = self.inner_session.run(
30
  None,
31
  self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
 
40
  pred = np.squeeze(pred)
41
 
42
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
43
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
44
 
45
  return [mask]
46
 
47
  @classmethod
48
  def download_models(cls, *args, **kwargs):
49
+ """
50
+ Download the pre-trained models.
51
+
52
+ Parameters:
53
+ *args: Variable length argument list.
54
+ **kwargs: Arbitrary keyword arguments.
55
+
56
+ Returns:
57
+ str: The path of the downloaded model file.
58
+ """
59
+ fname = f"{cls.name(*args, **kwargs)}.onnx"
60
  pooch.retrieve(
61
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
62
+ (
63
+ None
64
+ if cls.checksum_disabled(*args, **kwargs)
65
+ else "md5:6f184e756bb3bd901c8849220a83e38e"
66
+ ),
67
  fname=fname,
68
  path=cls.u2net_home(*args, **kwargs),
69
  progressbar=True,
70
  )
71
 
72
+ return os.path.join(cls.u2net_home(*args, **kwargs), fname)
73
 
74
  @classmethod
75
  def name(cls, *args, **kwargs):
76
+ """
77
+ Get the name of the pre-trained model.
78
+
79
+ Parameters:
80
+ *args: Variable length argument list.
81
+ **kwargs: Arbitrary keyword arguments.
82
+
83
+ Returns:
84
+ str: The name of the pre-trained model.
85
+ """
86
  return "isnet-anime"
rembg/sessions/dis_general_use.py CHANGED
@@ -11,6 +11,17 @@ from .base import BaseSession
11
 
12
  class DisSession(BaseSession):
13
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
 
 
 
 
 
 
 
 
 
 
 
14
  ort_outs = self.inner_session.run(
15
  None,
16
  self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
@@ -25,25 +36,51 @@ class DisSession(BaseSession):
25
  pred = np.squeeze(pred)
26
 
27
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
- mask = mask.resize(img.size, Image.LANCZOS)
29
 
30
  return [mask]
31
 
32
  @classmethod
33
  def download_models(cls, *args, **kwargs):
34
- fname = f"{cls.name()}.onnx"
 
 
 
 
 
 
 
 
 
 
 
 
35
  pooch.retrieve(
36
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
37
- None
38
- if cls.checksum_disabled(*args, **kwargs)
39
- else "md5:fc16ebd8b0c10d971d3513d564d01e29",
 
 
40
  fname=fname,
41
  path=cls.u2net_home(*args, **kwargs),
42
  progressbar=True,
43
  )
44
 
45
- return os.path.join(cls.u2net_home(), fname)
46
 
47
  @classmethod
48
  def name(cls, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
49
  return "isnet-general-use"
 
11
 
12
  class DisSession(BaseSession):
13
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ """
15
+ Predicts the mask image for the input image.
16
+
17
+ This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
18
+
19
+ Parameters:
20
+ img (PILImage): The input image.
21
+
22
+ Returns:
23
+ List[PILImage]: A list of PILImage objects representing the generated mask image.
24
+ """
25
  ort_outs = self.inner_session.run(
26
  None,
27
  self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
 
36
  pred = np.squeeze(pred)
37
 
38
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
39
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
40
 
41
  return [mask]
42
 
43
  @classmethod
44
  def download_models(cls, *args, **kwargs):
45
+ """
46
+ Downloads the pre-trained model file.
47
+
48
+ This class method downloads the pre-trained model file from a specified URL using the pooch library.
49
+
50
+ Parameters:
51
+ args: Additional positional arguments.
52
+ kwargs: Additional keyword arguments.
53
+
54
+ Returns:
55
+ str: The path to the downloaded model file.
56
+ """
57
+ fname = f"{cls.name(*args, **kwargs)}.onnx"
58
  pooch.retrieve(
59
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
60
+ (
61
+ None
62
+ if cls.checksum_disabled(*args, **kwargs)
63
+ else "md5:fc16ebd8b0c10d971d3513d564d01e29"
64
+ ),
65
  fname=fname,
66
  path=cls.u2net_home(*args, **kwargs),
67
  progressbar=True,
68
  )
69
 
70
+ return os.path.join(cls.u2net_home(*args, **kwargs), fname)
71
 
72
  @classmethod
73
  def name(cls, *args, **kwargs):
74
+ """
75
+ Returns the name of the model.
76
+
77
+ This class method returns the name of the model.
78
+
79
+ Parameters:
80
+ args: Additional positional arguments.
81
+ kwargs: Additional keyword arguments.
82
+
83
+ Returns:
84
+ str: The name of the model.
85
+ """
86
  return "isnet-general-use"
rembg/sessions/sam.py CHANGED
@@ -1,9 +1,12 @@
1
  import os
2
- from typing import List
 
3
 
 
4
  import numpy as np
5
  import onnxruntime as ort
6
  import pooch
 
7
  from PIL import Image
8
  from PIL.Image import Image as PILImage
9
 
@@ -15,104 +18,213 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
15
  newh, neww = oldh * scale, oldw * scale
16
  neww = int(neww + 0.5)
17
  newh = int(newh + 0.5)
 
18
  return (newh, neww)
19
 
20
 
21
- def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
22
  old_h, old_w = original_size
23
  new_h, new_w = get_preprocess_shape(
24
  original_size[0], original_size[1], target_length
25
  )
26
- coords = coords.copy().astype(float)
 
27
  coords[..., 0] = coords[..., 0] * (new_w / old_w)
28
  coords[..., 1] = coords[..., 1] * (new_h / old_h)
 
29
  return coords
30
 
31
 
32
- def resize_longes_side(img: PILImage, size=1024):
33
- w, h = img.size
34
- if h > w:
35
- new_h, new_w = size, int(w * size / h)
36
- else:
37
- new_h, new_w = int(h * size / w), size
 
 
 
 
 
 
 
38
 
39
- return img.resize((new_w, new_h))
 
40
 
41
 
42
- def pad_to_square(img: np.ndarray, size=1024):
43
- h, w = img.shape[:2]
44
- padh = size - h
45
- padw = size - w
46
- img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
47
- img = img.astype(np.float32)
48
- return img
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  class SamSession(BaseSession):
52
- def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  self.model_name = model_name
54
- paths = self.__class__.download_models()
 
 
 
 
 
 
 
 
 
 
55
  self.encoder = ort.InferenceSession(
56
  str(paths[0]),
57
- providers=ort.get_available_providers(),
58
  sess_options=sess_opts,
59
  )
60
  self.decoder = ort.InferenceSession(
61
  str(paths[1]),
62
- providers=ort.get_available_providers(),
63
  sess_options=sess_opts,
64
  )
65
 
66
- def normalize(
67
- self,
68
- img: np.ndarray,
69
- mean=(123.675, 116.28, 103.53),
70
- std=(58.395, 57.12, 57.375),
71
- size=(1024, 1024),
72
- *args,
73
- **kwargs,
74
- ):
75
- pixel_mean = np.array([*mean]).reshape(1, 1, -1)
76
- pixel_std = np.array([*std]).reshape(1, 1, -1)
77
- x = (img - pixel_mean) / pixel_std
78
- return x
79
-
80
  def predict(
81
  self,
82
  img: PILImage,
83
  *args,
84
  **kwargs,
85
  ) -> List[PILImage]:
86
- # Preprocess image
87
- image = resize_longes_side(img)
88
- image = np.array(image)
89
- image = self.normalize(image)
90
- image = pad_to_square(image)
91
-
92
- input_labels = kwargs.get("input_labels")
93
- input_points = kwargs.get("input_points")
94
-
95
- if input_labels is None:
96
- raise ValueError("input_labels is required")
97
- if input_points is None:
98
- raise ValueError("input_points is required")
99
-
100
- # Transpose
101
- image = image.transpose(2, 0, 1)[None, :, :, :]
102
- # Run encoder (Image embedding)
103
- encoded = self.encoder.run(None, {"x": image})
104
- image_embedding = encoded[0]
105
-
106
- # Add a batch index, concatenate a padding point, and transform.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
108
  None, :, :
109
  ]
110
  onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
111
  None, :
112
  ].astype(np.float32)
113
- onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- # Create an empty mask input and an indicator for no mask.
116
  onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
117
  onnx_has_mask_input = np.zeros(1, dtype=np.float32)
118
 
@@ -122,48 +234,110 @@ class SamSession(BaseSession):
122
  "point_labels": onnx_label,
123
  "mask_input": onnx_mask_input,
124
  "has_mask_input": onnx_has_mask_input,
125
- "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
126
  }
127
 
128
- masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
129
- masks = masks > 0.0
130
- masks = [
131
- Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
132
- for i in range(masks.shape[0])
133
- ]
 
134
 
135
- return masks
136
 
137
  @classmethod
138
  def download_models(cls, *args, **kwargs):
139
- fname_encoder = f"{cls.name()}_encoder.onnx"
140
- fname_decoder = f"{cls.name()}_decoder.onnx"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  pooch.retrieve(
143
- "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
144
- None
145
- if cls.checksum_disabled(*args, **kwargs)
146
- else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
147
  fname=fname_encoder,
148
  path=cls.u2net_home(*args, **kwargs),
149
  progressbar=True,
150
  )
151
 
152
  pooch.retrieve(
153
- "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
154
- None
155
- if cls.checksum_disabled(*args, **kwargs)
156
- else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
157
  fname=fname_decoder,
158
  path=cls.u2net_home(*args, **kwargs),
159
  progressbar=True,
160
  )
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  return (
163
- os.path.join(cls.u2net_home(), fname_encoder),
164
- os.path.join(cls.u2net_home(), fname_decoder),
165
  )
166
 
167
  @classmethod
168
  def name(cls, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  return "sam"
 
1
  import os
2
+ from copy import deepcopy
3
+ from typing import Dict, List, Tuple
4
 
5
+ import cv2
6
  import numpy as np
7
  import onnxruntime as ort
8
  import pooch
9
+ from jsonschema import validate
10
  from PIL import Image
11
  from PIL.Image import Image as PILImage
12
 
 
18
  newh, neww = oldh * scale, oldw * scale
19
  neww = int(neww + 0.5)
20
  newh = int(newh + 0.5)
21
+
22
  return (newh, neww)
23
 
24
 
25
+ def apply_coords(coords: np.ndarray, original_size, target_length):
26
  old_h, old_w = original_size
27
  new_h, new_w = get_preprocess_shape(
28
  original_size[0], original_size[1], target_length
29
  )
30
+
31
+ coords = deepcopy(coords).astype(float)
32
  coords[..., 0] = coords[..., 0] * (new_w / old_w)
33
  coords[..., 1] = coords[..., 1] * (new_h / old_h)
34
+
35
  return coords
36
 
37
 
38
+ def get_input_points(prompt):
39
+ points = []
40
+ labels = []
41
+
42
+ for mark in prompt:
43
+ if mark["type"] == "point":
44
+ points.append(mark["data"])
45
+ labels.append(mark["label"])
46
+ elif mark["type"] == "rectangle":
47
+ points.append([mark["data"][0], mark["data"][1]])
48
+ points.append([mark["data"][2], mark["data"][3]])
49
+ labels.append(2)
50
+ labels.append(3)
51
 
52
+ points, labels = np.array(points), np.array(labels)
53
+ return points, labels
54
 
55
 
56
+ def transform_masks(masks, original_size, transform_matrix):
57
+ output_masks = []
58
+
59
+ for batch in range(masks.shape[0]):
60
+ batch_masks = []
61
+ for mask_id in range(masks.shape[1]):
62
+ mask = masks[batch, mask_id]
63
+ mask = cv2.warpAffine(
64
+ mask,
65
+ transform_matrix[:2],
66
+ (original_size[1], original_size[0]),
67
+ flags=cv2.INTER_LINEAR,
68
+ )
69
+ batch_masks.append(mask)
70
+ output_masks.append(batch_masks)
71
+
72
+ return np.array(output_masks)
73
 
74
 
75
  class SamSession(BaseSession):
76
+ """
77
+ This class represents a session for the Sam model.
78
+
79
+ Args:
80
+ model_name (str): The name of the model.
81
+ sess_opts (ort.SessionOptions): The session options.
82
+ *args: Variable length argument list.
83
+ **kwargs: Arbitrary keyword arguments.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ model_name: str,
89
+ sess_opts: ort.SessionOptions,
90
+ providers=None,
91
+ *args,
92
+ **kwargs,
93
+ ):
94
+ """
95
+ Initialize a new SamSession with the given model name and session options.
96
+
97
+ Args:
98
+ model_name (str): The name of the model.
99
+ sess_opts (ort.SessionOptions): The session options.
100
+ *args: Variable length argument list.
101
+ **kwargs: Arbitrary keyword arguments.
102
+ """
103
  self.model_name = model_name
104
+
105
+ valid_providers = []
106
+ available_providers = ort.get_available_providers()
107
+
108
+ for provider in providers or []:
109
+ if provider in available_providers:
110
+ valid_providers.append(provider)
111
+ else:
112
+ valid_providers.extend(available_providers)
113
+
114
+ paths = self.__class__.download_models(*args, **kwargs)
115
  self.encoder = ort.InferenceSession(
116
  str(paths[0]),
117
+ providers=valid_providers,
118
  sess_options=sess_opts,
119
  )
120
  self.decoder = ort.InferenceSession(
121
  str(paths[1]),
122
+ providers=valid_providers,
123
  sess_options=sess_opts,
124
  )
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def predict(
127
  self,
128
  img: PILImage,
129
  *args,
130
  **kwargs,
131
  ) -> List[PILImage]:
132
+ """
133
+ Predict masks for an input image.
134
+
135
+ This function takes an image as input and performs various preprocessing steps on the image. It then runs the image through an encoder to obtain an image embedding. The function also takes input labels and points as additional arguments. It concatenates the input points and labels with padding and transforms them. It creates an empty mask input and an indicator for no mask. The function then passes the image embedding, point coordinates, point labels, mask input, and has mask input to a decoder. The decoder generates masks based on the input and returns them as a list of images.
136
+
137
+ Parameters:
138
+ img (PILImage): The input image.
139
+ *args: Additional arguments.
140
+ **kwargs: Additional keyword arguments.
141
+
142
+ Returns:
143
+ List[PILImage]: A list of masks generated by the decoder.
144
+ """
145
+ prompt = kwargs.get("sam_prompt", "{}")
146
+ schema = {
147
+ "type": "array",
148
+ "items": {
149
+ "type": "object",
150
+ "properties": {
151
+ "type": {"type": "string"},
152
+ "label": {"type": "integer"},
153
+ "data": {
154
+ "type": "array",
155
+ "items": {"type": "number"},
156
+ },
157
+ },
158
+ },
159
+ }
160
+
161
+ validate(instance=prompt, schema=schema)
162
+
163
+ target_size = 1024
164
+ input_size = (684, 1024)
165
+ encoder_input_name = self.encoder.get_inputs()[0].name
166
+
167
+ img = img.convert("RGB")
168
+ cv_image = np.array(img)
169
+ original_size = cv_image.shape[:2]
170
+
171
+ scale_x = input_size[1] / cv_image.shape[1]
172
+ scale_y = input_size[0] / cv_image.shape[0]
173
+ scale = min(scale_x, scale_y)
174
+
175
+ transform_matrix = np.array(
176
+ [
177
+ [scale, 0, 0],
178
+ [0, scale, 0],
179
+ [0, 0, 1],
180
+ ]
181
+ )
182
+
183
+ cv_image = cv2.warpAffine(
184
+ cv_image,
185
+ transform_matrix[:2],
186
+ (input_size[1], input_size[0]),
187
+ flags=cv2.INTER_LINEAR,
188
+ )
189
+
190
+ ## encoder
191
+
192
+ encoder_inputs = {
193
+ encoder_input_name: cv_image.astype(np.float32),
194
+ }
195
+
196
+ encoder_output = self.encoder.run(None, encoder_inputs)
197
+ image_embedding = encoder_output[0]
198
+
199
+ embedding = {
200
+ "image_embedding": image_embedding,
201
+ "original_size": original_size,
202
+ "transform_matrix": transform_matrix,
203
+ }
204
+
205
+ ## decoder
206
+
207
+ input_points, input_labels = get_input_points(prompt)
208
  onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
209
  None, :, :
210
  ]
211
  onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
212
  None, :
213
  ].astype(np.float32)
214
+ onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
215
+ np.float32
216
+ )
217
+
218
+ onnx_coord = np.concatenate(
219
+ [
220
+ onnx_coord,
221
+ np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
222
+ ],
223
+ axis=2,
224
+ )
225
+ onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
226
+ onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
227
 
 
228
  onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
229
  onnx_has_mask_input = np.zeros(1, dtype=np.float32)
230
 
 
234
  "point_labels": onnx_label,
235
  "mask_input": onnx_mask_input,
236
  "has_mask_input": onnx_has_mask_input,
237
+ "orig_im_size": np.array(input_size, dtype=np.float32),
238
  }
239
 
240
+ masks, _, _ = self.decoder.run(None, decoder_inputs)
241
+ inv_transform_matrix = np.linalg.inv(transform_matrix)
242
+ masks = transform_masks(masks, original_size, inv_transform_matrix)
243
+
244
+ mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
245
+ for m in masks[0, :, :, :]:
246
+ mask[m > 0.0] = [255, 255, 255]
247
 
248
+ return [Image.fromarray(mask).convert("L")]
249
 
250
  @classmethod
251
  def download_models(cls, *args, **kwargs):
252
+ """
253
+ Class method to download ONNX model files.
254
+
255
+ This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.
256
+
257
+ Parameters:
258
+ cls: The class object.
259
+ *args: Variable length argument list.
260
+ **kwargs: Arbitrary keyword arguments.
261
+
262
+ Returns:
263
+ tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
264
+ """
265
+ model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
266
+ quant = kwargs.get("sam_quant", False)
267
+
268
+ fname_encoder = f"{model_name}.encoder.onnx"
269
+ fname_decoder = f"{model_name}.decoder.onnx"
270
+
271
+ if quant:
272
+ fname_encoder = f"{model_name}.encoder.quant.onnx"
273
+ fname_decoder = f"{model_name}.decoder.quant.onnx"
274
 
275
  pooch.retrieve(
276
+ f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
277
+ None,
 
 
278
  fname=fname_encoder,
279
  path=cls.u2net_home(*args, **kwargs),
280
  progressbar=True,
281
  )
282
 
283
  pooch.retrieve(
284
+ f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
285
+ None,
 
 
286
  fname=fname_decoder,
287
  path=cls.u2net_home(*args, **kwargs),
288
  progressbar=True,
289
  )
290
 
291
+ if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
292
+ os.path.join(
293
+ cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
294
+ )
295
+ ):
296
+ content = bytearray()
297
+
298
+ for i in range(1, 4):
299
+ pooch.retrieve(
300
+ f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
301
+ None,
302
+ fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
303
+ path=cls.u2net_home(*args, **kwargs),
304
+ progressbar=True,
305
+ )
306
+
307
+ fbin = os.path.join(
308
+ cls.u2net_home(*args, **kwargs),
309
+ f"sam_vit_h_4b8939.encoder_data.{i}.bin",
310
+ )
311
+ content.extend(open(fbin, "rb").read())
312
+ os.remove(fbin)
313
+
314
+ with open(
315
+ os.path.join(
316
+ cls.u2net_home(*args, **kwargs),
317
+ "sam_vit_h_4b8939.encoder_data.bin",
318
+ ),
319
+ "wb",
320
+ ) as fp:
321
+ fp.write(content)
322
+
323
  return (
324
+ os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
325
+ os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
326
  )
327
 
328
  @classmethod
329
  def name(cls, *args, **kwargs):
330
+ """
331
+ Class method to return a string value.
332
+
333
+ This method returns the string value 'sam'.
334
+
335
+ Parameters:
336
+ cls: The class object.
337
+ *args: Variable length argument list.
338
+ **kwargs: Arbitrary keyword arguments.
339
+
340
+ Returns:
341
+ str: The string value 'sam'.
342
+ """
343
  return "sam"
rembg/sessions/silueta.py CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession
10
 
11
 
12
  class SiluetaSession(BaseSession):
 
 
13
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  ort_outs = self.inner_session.run(
15
  None,
16
  self.normalize(
@@ -27,25 +42,51 @@ class SiluetaSession(BaseSession):
27
  pred = np.squeeze(pred)
28
 
29
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
- mask = mask.resize(img.size, Image.LANCZOS)
31
 
32
  return [mask]
33
 
34
  @classmethod
35
  def download_models(cls, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
36
  fname = f"{cls.name()}.onnx"
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
39
- None
40
- if cls.checksum_disabled(*args, **kwargs)
41
- else "md5:55e59e0d8062d2f5d013f4725ee84782",
 
 
42
  fname=fname,
43
  path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46
 
47
- return os.path.join(cls.u2net_home(), fname)
48
 
49
  @classmethod
50
  def name(cls, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
51
  return "silueta"
 
10
 
11
 
12
  class SiluetaSession(BaseSession):
13
+ """This is a class representing a SiluetaSession object."""
14
+
15
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
16
+ """
17
+ Predict the mask of the input image.
18
+
19
+ This method takes an image as input, preprocesses it, and performs a prediction to generate a mask. The generated mask is then post-processed and returned as a list of PILImage objects.
20
+
21
+ Parameters:
22
+ img (PILImage): The input image to be processed.
23
+ *args: Variable length argument list.
24
+ **kwargs: Arbitrary keyword arguments.
25
+
26
+ Returns:
27
+ List[PILImage]: A list of post-processed masks.
28
+ """
29
  ort_outs = self.inner_session.run(
30
  None,
31
  self.normalize(
 
42
  pred = np.squeeze(pred)
43
 
44
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
45
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
46
 
47
  return [mask]
48
 
49
  @classmethod
50
  def download_models(cls, *args, **kwargs):
51
+ """
52
+ Download the pre-trained model file.
53
+
54
+ This method downloads the pre-trained model file from a specified URL. The file is saved to the U2NET home directory.
55
+
56
+ Parameters:
57
+ *args: Variable length argument list.
58
+ **kwargs: Arbitrary keyword arguments.
59
+
60
+ Returns:
61
+ str: The path to the downloaded model file.
62
+ """
63
  fname = f"{cls.name()}.onnx"
64
  pooch.retrieve(
65
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
66
+ (
67
+ None
68
+ if cls.checksum_disabled(*args, **kwargs)
69
+ else "md5:55e59e0d8062d2f5d013f4725ee84782"
70
+ ),
71
  fname=fname,
72
  path=cls.u2net_home(*args, **kwargs),
73
  progressbar=True,
74
  )
75
 
76
+ return os.path.join(cls.u2net_home(*args, **kwargs), fname)
77
 
78
  @classmethod
79
  def name(cls, *args, **kwargs):
80
+ """
81
+ Return the name of the model.
82
+
83
+ This method returns the name of the Silueta model.
84
+
85
+ Parameters:
86
+ *args: Variable length argument list.
87
+ **kwargs: Arbitrary keyword arguments.
88
+
89
+ Returns:
90
+ str: The name of the model.
91
+ """
92
  return "silueta"
rembg/sessions/u2net.py CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession
10
 
11
 
12
  class U2netSession(BaseSession):
 
 
 
 
13
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
 
 
 
 
 
 
 
 
 
 
 
14
  ort_outs = self.inner_session.run(
15
  None,
16
  self.normalize(
@@ -27,25 +42,47 @@ class U2netSession(BaseSession):
27
  pred = np.squeeze(pred)
28
 
29
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
- mask = mask.resize(img.size, Image.LANCZOS)
31
 
32
  return [mask]
33
 
34
  @classmethod
35
  def download_models(cls, *args, **kwargs):
36
- fname = f"{cls.name()}.onnx"
 
 
 
 
 
 
 
 
 
 
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
39
- None
40
- if cls.checksum_disabled(*args, **kwargs)
41
- else "md5:60024c5c889badc19c04ad937298a77b",
 
 
42
  fname=fname,
43
  path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46
 
47
- return os.path.join(cls.u2net_home(), fname)
48
 
49
  @classmethod
50
  def name(cls, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
51
  return "u2net"
 
10
 
11
 
12
  class U2netSession(BaseSession):
13
+ """
14
+ This class represents a U2net session, which is a subclass of BaseSession.
15
+ """
16
+
17
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
18
+ """
19
+ Predicts the output masks for the input image using the inner session.
20
+
21
+ Parameters:
22
+ img (PILImage): The input image.
23
+ *args: Additional positional arguments.
24
+ **kwargs: Additional keyword arguments.
25
+
26
+ Returns:
27
+ List[PILImage]: The list of output masks.
28
+ """
29
  ort_outs = self.inner_session.run(
30
  None,
31
  self.normalize(
 
42
  pred = np.squeeze(pred)
43
 
44
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
45
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
46
 
47
  return [mask]
48
 
49
  @classmethod
50
  def download_models(cls, *args, **kwargs):
51
+ """
52
+ Downloads the U2net model file from a specific URL and saves it.
53
+
54
+ Parameters:
55
+ *args: Additional positional arguments.
56
+ **kwargs: Additional keyword arguments.
57
+
58
+ Returns:
59
+ str: The path to the downloaded model file.
60
+ """
61
+ fname = f"{cls.name(*args, **kwargs)}.onnx"
62
  pooch.retrieve(
63
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
64
+ (
65
+ None
66
+ if cls.checksum_disabled(*args, **kwargs)
67
+ else "md5:60024c5c889badc19c04ad937298a77b"
68
+ ),
69
  fname=fname,
70
  path=cls.u2net_home(*args, **kwargs),
71
  progressbar=True,
72
  )
73
 
74
+ return os.path.join(cls.u2net_home(*args, **kwargs), fname)
75
 
76
  @classmethod
77
  def name(cls, *args, **kwargs):
78
+ """
79
+ Returns the name of the U2net session.
80
+
81
+ Parameters:
82
+ *args: Additional positional arguments.
83
+ **kwargs: Additional keyword arguments.
84
+
85
+ Returns:
86
+ str: The name of the session.
87
+ """
88
  return "u2net"
rembg/sessions/u2net_cloth_seg.py CHANGED
@@ -9,7 +9,7 @@ from scipy.special import log_softmax
9
 
10
  from .base import BaseSession
11
 
12
- pallete1 = [
13
  0,
14
  0,
15
  0,
@@ -24,7 +24,7 @@ pallete1 = [
24
  0,
25
  ]
26
 
27
- pallete2 = [
28
  0,
29
  0,
30
  0,
@@ -39,7 +39,7 @@ pallete2 = [
39
  0,
40
  ]
41
 
42
- pallete3 = [
43
  0,
44
  0,
45
  0,
@@ -57,6 +57,22 @@ pallete3 = [
57
 
58
  class Unet2ClothSession(BaseSession):
59
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  ort_outs = self.inner_session.run(
61
  None,
62
  self.normalize(
@@ -71,41 +87,59 @@ class Unet2ClothSession(BaseSession):
71
  pred = np.squeeze(pred, 0)
72
 
73
  mask = Image.fromarray(pred.astype("uint8"), mode="L")
74
- mask = mask.resize(img.size, Image.LANCZOS)
75
 
76
  masks = []
77
 
78
- mask1 = mask.copy()
79
- mask1.putpalette(pallete1)
80
- mask1 = mask1.convert("RGB").convert("L")
81
- masks.append(mask1)
82
-
83
- mask2 = mask.copy()
84
- mask2.putpalette(pallete2)
85
- mask2 = mask2.convert("RGB").convert("L")
86
- masks.append(mask2)
87
-
88
- mask3 = mask.copy()
89
- mask3.putpalette(pallete3)
90
- mask3 = mask3.convert("RGB").convert("L")
91
- masks.append(mask3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  return masks
94
 
95
  @classmethod
96
  def download_models(cls, *args, **kwargs):
97
- fname = f"{cls.name()}.onnx"
98
  pooch.retrieve(
99
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
100
- None
101
- if cls.checksum_disabled(*args, **kwargs)
102
- else "md5:2434d1f3cb744e0e49386c906e5a08bb",
 
 
103
  fname=fname,
104
  path=cls.u2net_home(*args, **kwargs),
105
  progressbar=True,
106
  )
107
 
108
- return os.path.join(cls.u2net_home(), fname)
109
 
110
  @classmethod
111
  def name(cls, *args, **kwargs):
 
9
 
10
  from .base import BaseSession
11
 
12
+ palette1 = [
13
  0,
14
  0,
15
  0,
 
24
  0,
25
  ]
26
 
27
+ palette2 = [
28
  0,
29
  0,
30
  0,
 
39
  0,
40
  ]
41
 
42
+ palette3 = [
43
  0,
44
  0,
45
  0,
 
57
 
58
  class Unet2ClothSession(BaseSession):
59
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
60
+ """
61
+ Predict the cloth category of an image.
62
+
63
+ This method takes an image as input and predicts the cloth category of the image.
64
+ The method uses the inner_session to make predictions using a pre-trained model.
65
+ The predicted mask is then converted to an image and resized to match the size of the input image.
66
+ Depending on the cloth category specified in the method arguments, the method applies different color palettes to the mask and appends the resulting images to a list.
67
+
68
+ Parameters:
69
+ img (PILImage): The input image.
70
+ *args: Additional positional arguments.
71
+ **kwargs: Additional keyword arguments.
72
+
73
+ Returns:
74
+ List[PILImage]: A list of images representing the predicted masks.
75
+ """
76
  ort_outs = self.inner_session.run(
77
  None,
78
  self.normalize(
 
87
  pred = np.squeeze(pred, 0)
88
 
89
  mask = Image.fromarray(pred.astype("uint8"), mode="L")
90
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
91
 
92
  masks = []
93
 
94
+ cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")
95
+
96
+ def upper_cloth():
97
+ mask1 = mask.copy()
98
+ mask1.putpalette(palette1)
99
+ mask1 = mask1.convert("RGB").convert("L")
100
+ masks.append(mask1)
101
+
102
+ def lower_cloth():
103
+ mask2 = mask.copy()
104
+ mask2.putpalette(palette2)
105
+ mask2 = mask2.convert("RGB").convert("L")
106
+ masks.append(mask2)
107
+
108
+ def full_cloth():
109
+ mask3 = mask.copy()
110
+ mask3.putpalette(palette3)
111
+ mask3 = mask3.convert("RGB").convert("L")
112
+ masks.append(mask3)
113
+
114
+ if cloth_category == "upper":
115
+ upper_cloth()
116
+ elif cloth_category == "lower":
117
+ lower_cloth()
118
+ elif cloth_category == "full":
119
+ full_cloth()
120
+ else:
121
+ upper_cloth()
122
+ lower_cloth()
123
+ full_cloth()
124
 
125
  return masks
126
 
127
  @classmethod
128
  def download_models(cls, *args, **kwargs):
129
+ fname = f"{cls.name(*args, **kwargs)}.onnx"
130
  pooch.retrieve(
131
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
132
+ (
133
+ None
134
+ if cls.checksum_disabled(*args, **kwargs)
135
+ else "md5:2434d1f3cb744e0e49386c906e5a08bb"
136
+ ),
137
  fname=fname,
138
  path=cls.u2net_home(*args, **kwargs),
139
  progressbar=True,
140
  )
141
 
142
+ return os.path.join(cls.u2net_home(*args, **kwargs), fname)
143
 
144
  @classmethod
145
  def name(cls, *args, **kwargs):
rembg/sessions/u2net_custom.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ import pooch
7
+ from PIL import Image
8
+ from PIL.Image import Image as PILImage
9
+
10
+ from .base import BaseSession
11
+
12
+
13
+ class U2netCustomSession(BaseSession):
14
+ """This is a class representing a custom session for the U2net model."""
15
+
16
+ def __init__(
17
+ self,
18
+ model_name: str,
19
+ sess_opts: ort.SessionOptions,
20
+ providers=None,
21
+ *args,
22
+ **kwargs
23
+ ):
24
+ """
25
+ Initialize a new U2netCustomSession object.
26
+
27
+ Parameters:
28
+ model_name (str): The name of the model.
29
+ sess_opts (ort.SessionOptions): The session options.
30
+ providers: The providers.
31
+ *args: Additional positional arguments.
32
+ **kwargs: Additional keyword arguments.
33
+
34
+ Raises:
35
+ ValueError: If model_path is None.
36
+ """
37
+ model_path = kwargs.get("model_path")
38
+ if model_path is None:
39
+ raise ValueError("model_path is required")
40
+
41
+ super().__init__(model_name, sess_opts, providers, *args, **kwargs)
42
+
43
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
44
+ """
45
+ Predict the segmentation mask for the input image.
46
+
47
+ Parameters:
48
+ img (PILImage): The input image.
49
+ *args: Additional positional arguments.
50
+ **kwargs: Additional keyword arguments.
51
+
52
+ Returns:
53
+ List[PILImage]: A list of PILImage objects representing the segmentation mask.
54
+ """
55
+ ort_outs = self.inner_session.run(
56
+ None,
57
+ self.normalize(
58
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
59
+ ),
60
+ )
61
+
62
+ pred = ort_outs[0][:, 0, :, :]
63
+
64
+ ma = np.max(pred)
65
+ mi = np.min(pred)
66
+
67
+ pred = (pred - mi) / (ma - mi)
68
+ pred = np.squeeze(pred)
69
+
70
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
71
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
72
+
73
+ return [mask]
74
+
75
+ @classmethod
76
+ def download_models(cls, *args, **kwargs):
77
+ """
78
+ Download the model files.
79
+
80
+ Parameters:
81
+ *args: Additional positional arguments.
82
+ **kwargs: Additional keyword arguments.
83
+
84
+ Returns:
85
+ str: The absolute path to the model files.
86
+ """
87
+ model_path = kwargs.get("model_path")
88
+ if model_path is None:
89
+ return
90
+
91
+ return os.path.abspath(os.path.expanduser(model_path))
92
+
93
+ @classmethod
94
+ def name(cls, *args, **kwargs):
95
+ """
96
+ Get the name of the model.
97
+
98
+ Parameters:
99
+ *args: Additional positional arguments.
100
+ **kwargs: Additional keyword arguments.
101
+
102
+ Returns:
103
+ str: The name of the model.
104
+ """
105
+ return "u2net_custom"
rembg/sessions/u2net_human_seg.py CHANGED
@@ -10,7 +10,22 @@ from .base import BaseSession
10
 
11
 
12
  class U2netHumanSegSession(BaseSession):
 
 
 
 
13
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
 
 
 
 
 
 
 
 
 
 
 
14
  ort_outs = self.inner_session.run(
15
  None,
16
  self.normalize(
@@ -27,25 +42,47 @@ class U2netHumanSegSession(BaseSession):
27
  pred = np.squeeze(pred)
28
 
29
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
- mask = mask.resize(img.size, Image.LANCZOS)
31
 
32
  return [mask]
33
 
34
  @classmethod
35
  def download_models(cls, *args, **kwargs):
36
- fname = f"{cls.name()}.onnx"
 
 
 
 
 
 
 
 
 
 
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
39
- None
40
- if cls.checksum_disabled(*args, **kwargs)
41
- else "md5:c09ddc2e0104f800e3e1bb4652583d1f",
 
 
42
  fname=fname,
43
  path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46
 
47
- return os.path.join(cls.u2net_home(), fname)
48
 
49
  @classmethod
50
  def name(cls, *args, **kwargs):
 
 
 
 
 
 
 
 
 
 
51
  return "u2net_human_seg"
 
10
 
11
 
12
  class U2netHumanSegSession(BaseSession):
13
+ """
14
+ This class represents a session for performing human segmentation using the U2Net model.
15
+ """
16
+
17
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
18
+ """
19
+ Predicts human segmentation masks for the input image.
20
+
21
+ Parameters:
22
+ img (PILImage): The input image.
23
+ *args: Variable length argument list.
24
+ **kwargs: Arbitrary keyword arguments.
25
+
26
+ Returns:
27
+ List[PILImage]: A list of predicted masks.
28
+ """
29
  ort_outs = self.inner_session.run(
30
  None,
31
  self.normalize(
 
42
  pred = np.squeeze(pred)
43
 
44
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
45
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
46
 
47
  return [mask]
48
 
49
  @classmethod
50
  def download_models(cls, *args, **kwargs):
51
+ """
52
+ Downloads the U2Net model weights.
53
+
54
+ Parameters:
55
+ *args: Variable length argument list.
56
+ **kwargs: Arbitrary keyword arguments.
57
+
58
+ Returns:
59
+ str: The path to the downloaded model weights.
60
+ """
61
+ fname = f"{cls.name(*args, **kwargs)}.onnx"
62
  pooch.retrieve(
63
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
64
+ (
65
+ None
66
+ if cls.checksum_disabled(*args, **kwargs)
67
+ else "md5:c09ddc2e0104f800e3e1bb4652583d1f"
68
+ ),
69
  fname=fname,
70
  path=cls.u2net_home(*args, **kwargs),
71
  progressbar=True,
72
  )
73
 
74
+ return os.path.join(cls.u2net_home(*args, **kwargs), fname)
75
 
76
  @classmethod
77
  def name(cls, *args, **kwargs):
78
+ """
79
+ Returns the name of the U2Net model.
80
+
81
+ Parameters:
82
+ *args: Variable length argument list.
83
+ **kwargs: Arbitrary keyword arguments.
84
+
85
+ Returns:
86
+ str: The name of the model.
87
+ """
88
  return "u2net_human_seg"
rembg/sessions/u2netp.py CHANGED
@@ -10,7 +10,18 @@ from .base import BaseSession
10
 
11
 
12
  class U2netpSession(BaseSession):
 
 
13
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
 
 
 
 
 
 
 
 
 
14
  ort_outs = self.inner_session.run(
15
  None,
16
  self.normalize(
@@ -27,25 +38,39 @@ class U2netpSession(BaseSession):
27
  pred = np.squeeze(pred)
28
 
29
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
- mask = mask.resize(img.size, Image.LANCZOS)
31
 
32
  return [mask]
33
 
34
  @classmethod
35
  def download_models(cls, *args, **kwargs):
36
- fname = f"{cls.name()}.onnx"
 
 
 
 
 
 
37
  pooch.retrieve(
38
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
39
- None
40
- if cls.checksum_disabled(*args, **kwargs)
41
- else "md5:8e83ca70e441ab06c318d82300c84806",
 
 
42
  fname=fname,
43
  path=cls.u2net_home(*args, **kwargs),
44
  progressbar=True,
45
  )
46
 
47
- return os.path.join(cls.u2net_home(), fname)
48
 
49
  @classmethod
50
  def name(cls, *args, **kwargs):
 
 
 
 
 
 
51
  return "u2netp"
 
10
 
11
 
12
  class U2netpSession(BaseSession):
13
+ """This class represents a session for using the U2netp model."""
14
+
15
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
16
+ """
17
+ Predicts the mask for the given image using the U2netp model.
18
+
19
+ Parameters:
20
+ img (PILImage): The input image.
21
+
22
+ Returns:
23
+ List[PILImage]: The predicted mask.
24
+ """
25
  ort_outs = self.inner_session.run(
26
  None,
27
  self.normalize(
 
38
  pred = np.squeeze(pred)
39
 
40
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
41
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
42
 
43
  return [mask]
44
 
45
  @classmethod
46
  def download_models(cls, *args, **kwargs):
47
+ """
48
+ Downloads the U2netp model.
49
+
50
+ Returns:
51
+ str: The path to the downloaded model.
52
+ """
53
+ fname = f"{cls.name(*args, **kwargs)}.onnx"
54
  pooch.retrieve(
55
  "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
56
+ (
57
+ None
58
+ if cls.checksum_disabled(*args, **kwargs)
59
+ else "md5:8e83ca70e441ab06c318d82300c84806"
60
+ ),
61
  fname=fname,
62
  path=cls.u2net_home(*args, **kwargs),
63
  progressbar=True,
64
  )
65
 
66
+ return os.path.join(cls.u2net_home(*args, **kwargs), fname)
67
 
68
  @classmethod
69
  def name(cls, *args, **kwargs):
70
+ """
71
+ Returns the name of the U2netp model.
72
+
73
+ Returns:
74
+ str: The name of the model.
75
+ """
76
  return "u2netp"
requirements.txt CHANGED
@@ -1,18 +1,20 @@
1
- aiohttp==3.8.1
2
- asyncer==0.0.2
3
- click==8.1.3
4
- fastapi==0.87.0
5
- filetype==1.2.0
6
- pooch==1.6.0
7
- imagehash==4.3.1
8
- numpy==1.23.5
9
- onnxruntime==1.13.1
10
- opencv-python-headless==4.6.0.66
11
- pillow==9.3.0
12
- pymatting==1.1.8
13
- python-multipart==0.0.5
14
- scikit-image==0.19.3
15
- scipy==1.9.3
16
- tqdm==4.64.1
17
- uvicorn==0.20.0
18
- watchdog==2.1.9
 
 
 
1
+ filetype==1.2.0
2
+ pooch==1.6.0
3
+ imagehash==4.3.1
4
+ numpy==1.23.5
5
+ onnxruntime==1.13.1
6
+ opencv-python-headless==4.6.0.66
7
+ pillow==9.3.0
8
+ pymatting==1.1.8
9
+ python-multipart==0.0.5
10
+ scikit-image==0.19.3
11
+ scipy==1.9.3
12
+ tqdm==4.64.1
13
+ uvicorn==0.20.0
14
+ watchdog==2.1.9
15
+ click==8.1.3
16
+ fastapi==0.85.0
17
+ aiohttp==3.8.3
18
+ asyncer==0.0.2
19
+ gradio==3.0.20
20
+ jsonschema==4.16.0