KenjieDec commited on
Commit
87c57a3
·
1 Parent(s): 2b6bc23
rembg/__init__.py CHANGED
@@ -1,11 +1,6 @@
1
- import sys
2
- import warnings
3
-
4
- if not (sys.version_info.major == 3 and sys.version_info.minor == 9):
5
- warnings.warn("This library is only for Python 3.9", RuntimeWarning)
6
-
7
  from . import _version
8
 
9
  __version__ = _version.get_versions()["version"]
10
 
11
  from .bg import remove
 
 
 
 
 
 
 
 
1
  from . import _version
2
 
3
  __version__ = _version.get_versions()["version"]
4
 
5
  from .bg import remove
6
+ from .session_factory import new_session
rembg/_version.py CHANGED
@@ -24,8 +24,8 @@ def get_keywords():
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
  git_refnames = " (HEAD -> main)"
27
- git_full = "3bc1c1af99ebd47dd08d02763fc754d70d42afea"
28
- git_date = "2022-06-16 23:00:14 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
 
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
  git_refnames = " (HEAD -> main)"
27
+ git_full = "edc9fe27dff030cf6c2f29ef9a66c32d6e3f4658"
28
+ git_date = "2022-11-28 08:14:19 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
rembg/bg.py CHANGED
@@ -3,16 +3,26 @@ from enum import Enum
3
  from typing import List, Optional, Union
4
 
5
  import numpy as np
 
 
 
 
 
 
 
 
6
  from PIL import Image
7
  from PIL.Image import Image as PILImage
8
  from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
9
  from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
10
  from pymatting.util.util import stack_images
11
- from scipy.ndimage.morphology import binary_erosion
12
 
13
  from .session_base import BaseSession
14
  from .session_factory import new_session
15
 
 
 
16
 
17
  class ReturnType(Enum):
18
  BYTES = 0
@@ -27,6 +37,10 @@ def alpha_matting_cutout(
27
  background_threshold: int,
28
  erode_structure_size: int,
29
  ) -> PILImage:
 
 
 
 
30
  img = np.asarray(img)
31
  mask = np.asarray(mask)
32
 
@@ -79,6 +93,19 @@ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
79
  return dst
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def remove(
83
  data: Union[bytes, PILImage, np.ndarray],
84
  alpha_matting: bool = False,
@@ -87,6 +114,7 @@ def remove(
87
  alpha_matting_erode_size: int = 10,
88
  session: Optional[BaseSession] = None,
89
  only_mask: bool = False,
 
90
  ) -> Union[bytes, PILImage, np.ndarray]:
91
 
92
  if isinstance(data, PILImage):
@@ -108,6 +136,9 @@ def remove(
108
  cutouts = []
109
 
110
  for mask in masks:
 
 
 
111
  if only_mask:
112
  cutout = mask
113
 
 
3
  from typing import List, Optional, Union
4
 
5
  import numpy as np
6
+ from cv2 import (
7
+ BORDER_DEFAULT,
8
+ MORPH_ELLIPSE,
9
+ MORPH_OPEN,
10
+ GaussianBlur,
11
+ getStructuringElement,
12
+ morphologyEx,
13
+ )
14
  from PIL import Image
15
  from PIL.Image import Image as PILImage
16
  from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
17
  from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
18
  from pymatting.util.util import stack_images
19
+ from scipy.ndimage import binary_erosion
20
 
21
  from .session_base import BaseSession
22
  from .session_factory import new_session
23
 
24
+ kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
25
+
26
 
27
  class ReturnType(Enum):
28
  BYTES = 0
 
37
  background_threshold: int,
38
  erode_structure_size: int,
39
  ) -> PILImage:
40
+
41
+ if img.mode == "RGBA" or img.mode == "CMYK":
42
+ img = img.convert("RGB")
43
+
44
  img = np.asarray(img)
45
  mask = np.asarray(mask)
46
 
 
93
  return dst
94
 
95
 
96
+ def post_process(mask: np.ndarray) -> np.ndarray:
97
+ """
98
+ Post Process the mask for a smooth boundary by applying Morphological Operations
99
+ Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757
100
+ args:
101
+ mask: Binary Numpy Mask
102
+ """
103
+ mask = morphologyEx(mask, MORPH_OPEN, kernel)
104
+ mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
105
+ mask = np.where(mask < 127, 0, 255).astype(np.uint8) # convert again to binary
106
+ return mask
107
+
108
+
109
  def remove(
110
  data: Union[bytes, PILImage, np.ndarray],
111
  alpha_matting: bool = False,
 
114
  alpha_matting_erode_size: int = 10,
115
  session: Optional[BaseSession] = None,
116
  only_mask: bool = False,
117
+ post_process_mask: bool = False,
118
  ) -> Union[bytes, PILImage, np.ndarray]:
119
 
120
  if isinstance(data, PILImage):
 
136
  cutouts = []
137
 
138
  for mask in masks:
139
+ if post_process_mask:
140
+ mask = Image.fromarray(post_process(np.array(mask)))
141
+
142
  if only_mask:
143
  cutout = mask
144
 
rembg/cli.py CHANGED
@@ -33,7 +33,9 @@ def main() -> None:
33
  "-m",
34
  "--model",
35
  default="u2net",
36
- type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
 
 
37
  show_default=True,
38
  show_choices=True,
39
  help="model name",
@@ -76,6 +78,13 @@ def main() -> None:
76
  show_default=True,
77
  help="output only the mask",
78
  )
 
 
 
 
 
 
 
79
  @click.argument(
80
  "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
81
  )
@@ -93,7 +102,9 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
93
  "-m",
94
  "--model",
95
  default="u2net",
96
- type=click.Choice(["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg"]),
 
 
97
  show_default=True,
98
  show_choices=True,
99
  help="model name",
@@ -136,6 +147,13 @@ def i(model: str, input: IO, output: IO, **kwargs) -> None:
136
  show_default=True,
137
  help="output only the mask",
138
  )
 
 
 
 
 
 
 
139
  @click.option(
140
  "-w",
141
  "--watch",
@@ -243,7 +261,15 @@ def p(
243
  show_default=True,
244
  help="log level",
245
  )
246
- def s(port: int, log_level: str) -> None:
 
 
 
 
 
 
 
 
247
  sessions: dict[str, BaseSession] = {}
248
  tags_metadata = [
249
  {
@@ -284,6 +310,7 @@ def s(port: int, log_level: str) -> None:
284
  u2netp = "u2netp"
285
  u2net_human_seg = "u2net_human_seg"
286
  u2net_cloth_seg = "u2net_cloth_seg"
 
287
 
288
  class CommonQueryParams:
289
  def __init__(
@@ -309,6 +336,7 @@ def s(port: int, log_level: str) -> None:
309
  default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
310
  ),
311
  om: bool = Query(default=False, description="Only Mask"),
 
312
  ):
313
  self.model = model
314
  self.a = a
@@ -316,6 +344,7 @@ def s(port: int, log_level: str) -> None:
316
  self.ab = ab
317
  self.ae = ae
318
  self.om = om
 
319
 
320
  class CommonQueryPostParams:
321
  def __init__(
@@ -341,6 +370,7 @@ def s(port: int, log_level: str) -> None:
341
  default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
342
  ),
343
  om: bool = Form(default=False, description="Only Mask"),
 
344
  ):
345
  self.model = model
346
  self.a = a
@@ -348,6 +378,7 @@ def s(port: int, log_level: str) -> None:
348
  self.ab = ab
349
  self.ae = ae
350
  self.om = om
 
351
 
352
  def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
353
  return Response(
@@ -361,10 +392,19 @@ def s(port: int, log_level: str) -> None:
361
  alpha_matting_background_threshold=commons.ab,
362
  alpha_matting_erode_size=commons.ae,
363
  only_mask=commons.om,
 
364
  ),
365
  media_type="image/png",
366
  )
367
 
 
 
 
 
 
 
 
 
368
  @app.get(
369
  path="/",
370
  tags=["Background Removal"],
 
33
  "-m",
34
  "--model",
35
  default="u2net",
36
+ type=click.Choice(
37
+ ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
38
+ ),
39
  show_default=True,
40
  show_choices=True,
41
  help="model name",
 
78
  show_default=True,
79
  help="output only the mask",
80
  )
81
+ @click.option(
82
+ "-ppm",
83
+ "--post-process-mask",
84
+ is_flag=True,
85
+ show_default=True,
86
+ help="post process the mask",
87
+ )
88
  @click.argument(
89
  "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
90
  )
 
102
  "-m",
103
  "--model",
104
  default="u2net",
105
+ type=click.Choice(
106
+ ["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
107
+ ),
108
  show_default=True,
109
  show_choices=True,
110
  help="model name",
 
147
  show_default=True,
148
  help="output only the mask",
149
  )
150
+ @click.option(
151
+ "-ppm",
152
+ "--post-process-mask",
153
+ is_flag=True,
154
+ show_default=True,
155
+ help="post process the mask",
156
+ )
157
  @click.option(
158
  "-w",
159
  "--watch",
 
261
  show_default=True,
262
  help="log level",
263
  )
264
+ @click.option(
265
+ "-t",
266
+ "--threads",
267
+ default=None,
268
+ type=int,
269
+ show_default=True,
270
+ help="number of worker threads",
271
+ )
272
+ def s(port: int, log_level: str, threads: int) -> None:
273
  sessions: dict[str, BaseSession] = {}
274
  tags_metadata = [
275
  {
 
310
  u2netp = "u2netp"
311
  u2net_human_seg = "u2net_human_seg"
312
  u2net_cloth_seg = "u2net_cloth_seg"
313
+ silueta = "silueta"
314
 
315
  class CommonQueryParams:
316
  def __init__(
 
336
  default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
337
  ),
338
  om: bool = Query(default=False, description="Only Mask"),
339
+ ppm: bool = Query(default=False, description="Post Process Mask"),
340
  ):
341
  self.model = model
342
  self.a = a
 
344
  self.ab = ab
345
  self.ae = ae
346
  self.om = om
347
+ self.ppm = ppm
348
 
349
  class CommonQueryPostParams:
350
  def __init__(
 
370
  default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
371
  ),
372
  om: bool = Form(default=False, description="Only Mask"),
373
+ ppm: bool = Form(default=False, description="Post Process Mask"),
374
  ):
375
  self.model = model
376
  self.a = a
 
378
  self.ab = ab
379
  self.ae = ae
380
  self.om = om
381
+ self.ppm = ppm
382
 
383
  def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
384
  return Response(
 
392
  alpha_matting_background_threshold=commons.ab,
393
  alpha_matting_erode_size=commons.ae,
394
  only_mask=commons.om,
395
+ post_process_mask=commons.ppm,
396
  ),
397
  media_type="image/png",
398
  )
399
 
400
+ @app.on_event("startup")
401
+ def startup():
402
+ if threads is not None:
403
+ from anyio import CapacityLimiter
404
+ from anyio.lowlevel import RunVar
405
+
406
+ RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
407
+
408
  @app.get(
409
  path="/",
410
  tags=["Background Removal"],
rembg/session_base.py CHANGED
@@ -18,7 +18,7 @@ class BaseSession:
18
  std: Tuple[float, float, float],
19
  size: Tuple[int, int],
20
  ) -> Dict[str, np.ndarray]:
21
- im = img.convert("RGB").resize(size, Image.LANCZOS)
22
 
23
  im_ary = np.array(im)
24
  im_ary = im_ary / np.max(im_ary)
 
18
  std: Tuple[float, float, float],
19
  size: Tuple[int, int],
20
  ) -> Dict[str, np.ndarray]:
21
+ im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
22
 
23
  im_ary = np.array(im)
24
  im_ary = im_ary / np.max(im_ary)
rembg/session_factory.py CHANGED
@@ -5,50 +5,56 @@ from contextlib import redirect_stdout
5
  from pathlib import Path
6
  from typing import Type
7
 
8
- import gdown
9
  import onnxruntime as ort
 
10
 
11
  from .session_base import BaseSession
12
  from .session_cloth import ClothSession
13
  from .session_simple import SimpleSession
14
 
15
 
16
- def new_session(model_name: str) -> BaseSession:
17
  session_class: Type[BaseSession]
 
 
 
18
 
19
  if model_name == "u2netp":
20
  md5 = "8e83ca70e441ab06c318d82300c84806"
21
- url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
22
- session_class = SimpleSession
23
- elif model_name == "u2net":
24
- md5 = "60024c5c889badc19c04ad937298a77b"
25
- url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
26
  session_class = SimpleSession
27
  elif model_name == "u2net_human_seg":
28
  md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
29
- url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
30
  session_class = SimpleSession
31
  elif model_name == "u2net_cloth_seg":
32
  md5 = "2434d1f3cb744e0e49386c906e5a08bb"
33
- url = "https://drive.google.com/uc?id=15rKbQSXQzrKCQurUjZFg8HqzZad8bcyz"
34
  session_class = ClothSession
35
- else:
36
- assert AssertionError(
37
- "Choose between u2net, u2netp, u2net_human_seg or u2net_cloth_seg"
 
38
  )
 
39
 
40
- home = os.getenv("U2NET_HOME", os.path.join("~", ".u2net"))
41
- path = Path(home).expanduser() / f"{model_name}.onnx"
42
- path.parents[0].mkdir(parents=True, exist_ok=True)
 
 
 
 
43
 
44
- if not path.exists():
45
- with redirect_stdout(sys.stderr):
46
- gdown.download(url, str(path), use_cookies=False)
47
- else:
48
- hashing = hashlib.new("md5", path.read_bytes(), usedforsecurity=False)
49
- if hashing.hexdigest() != md5:
50
- with redirect_stdout(sys.stderr):
51
- gdown.download(url, str(path), use_cookies=False)
52
 
53
  sess_opts = ort.SessionOptions()
54
 
@@ -58,6 +64,8 @@ def new_session(model_name: str) -> BaseSession:
58
  return session_class(
59
  model_name,
60
  ort.InferenceSession(
61
- str(path), providers=ort.get_available_providers(), sess_options=sess_opts
 
 
62
  ),
63
  )
 
5
  from pathlib import Path
6
  from typing import Type
7
 
 
8
  import onnxruntime as ort
9
+ import pooch
10
 
11
  from .session_base import BaseSession
12
  from .session_cloth import ClothSession
13
  from .session_simple import SimpleSession
14
 
15
 
16
+ def new_session(model_name: str = "u2net") -> BaseSession:
17
  session_class: Type[BaseSession]
18
+ md5 = "60024c5c889badc19c04ad937298a77b"
19
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
20
+ session_class = SimpleSession
21
 
22
  if model_name == "u2netp":
23
  md5 = "8e83ca70e441ab06c318d82300c84806"
24
+ url = (
25
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
26
+ )
 
 
27
  session_class = SimpleSession
28
  elif model_name == "u2net_human_seg":
29
  md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
30
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
31
  session_class = SimpleSession
32
  elif model_name == "u2net_cloth_seg":
33
  md5 = "2434d1f3cb744e0e49386c906e5a08bb"
34
+ url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
35
  session_class = ClothSession
36
+ elif model_name == "silueta":
37
+ md5 = "55e59e0d8062d2f5d013f4725ee84782"
38
+ url = (
39
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
40
  )
41
+ session_class = SimpleSession
42
 
43
+ u2net_home = os.getenv(
44
+ "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
45
+ )
46
+
47
+ fname = f"{model_name}.onnx"
48
+ path = Path(u2net_home).expanduser()
49
+ full_path = Path(u2net_home).expanduser() / fname
50
 
51
+ pooch.retrieve(
52
+ url,
53
+ f"md5:{md5}",
54
+ fname=fname,
55
+ path=Path(u2net_home).expanduser(),
56
+ progressbar=True,
57
+ )
 
58
 
59
  sess_opts = ort.SessionOptions()
60
 
 
64
  return session_class(
65
  model_name,
66
  ort.InferenceSession(
67
+ str(full_path),
68
+ providers=ort.get_available_providers(),
69
+ sess_options=sess_opts,
70
  ),
71
  )
rembg/session_simple.py CHANGED
@@ -25,6 +25,6 @@ class SimpleSession(BaseSession):
25
  pred = np.squeeze(pred)
26
 
27
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
- mask = mask.resize(img.size, Image.LANCZOS)
29
 
30
  return [mask]
 
25
  pred = np.squeeze(pred)
26
 
27
  mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.Resampling.LANCZOS)
29
 
30
  return [mask]