bluestyle97 commited on
Commit
c070b3c
·
verified ·
1 Parent(s): d70dae0

Update freesplatter/utils/infer_util.py

Browse files
Files changed (1) hide show
  1. freesplatter/utils/infer_util.py +31 -31
freesplatter/utils/infer_util.py CHANGED
@@ -68,36 +68,10 @@ def get_obj_from_str(string, reload=False):
68
  # return image
69
 
70
 
71
- # @torch.inference_mode()
72
- # def remove_background(
73
- # image: PIL.Image.Image,
74
- # rembg: Any = None,
75
- # force: bool = False,
76
- # **rembg_kwargs,
77
- # ) -> PIL.Image.Image:
78
- # do_remove = True
79
- # if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
80
- # do_remove = False
81
- # do_remove = do_remove or force
82
- # if do_remove:
83
- # transform_image = transforms.Compose([
84
- # transforms.Resize((1024, 1024)),
85
- # transforms.ToTensor(),
86
- # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
87
- # ])
88
- # image = image.convert('RGB')
89
- # input_images = transform_image(image).unsqueeze(0).to(rembg.device)
90
- # with torch.no_grad():
91
- # preds = rembg(input_images)[-1].sigmoid().cpu()
92
- # pred = preds[0].squeeze()
93
- # pred_pil = transforms.ToPILImage()(pred)
94
- # mask = pred_pil.resize(image.size)
95
- # image.putalpha(mask)
96
- # return image
97
-
98
-
99
- def remove_background(image: PIL.Image.Image,
100
- rembg_session: Any = None,
101
  force: bool = False,
102
  **rembg_kwargs,
103
  ) -> PIL.Image.Image:
@@ -106,10 +80,36 @@ def remove_background(image: PIL.Image.Image,
106
  do_remove = False
107
  do_remove = do_remove or force
108
  if do_remove:
109
- image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
110
  return image
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def resize_foreground(
114
  image: PIL.Image.Image,
115
  ratio: float,
 
68
  # return image
69
 
70
 
71
+ @torch.inference_mode()
72
+ def remove_background(
73
+ image: PIL.Image.Image,
74
+ rembg: Any = None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  force: bool = False,
76
  **rembg_kwargs,
77
  ) -> PIL.Image.Image:
 
80
  do_remove = False
81
  do_remove = do_remove or force
82
  if do_remove:
83
+ transform_image = transforms.Compose([
84
+ transforms.Resize((1024, 1024)),
85
+ transforms.ToTensor(),
86
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
87
+ ])
88
+ image = image.convert('RGB')
89
+ input_images = transform_image(image).unsqueeze(0).to(rembg.device)
90
+ with torch.no_grad():
91
+ preds = rembg(input_images)[-1].sigmoid().cpu()
92
+ pred = preds[0].squeeze()
93
+ pred_pil = transforms.ToPILImage()(pred)
94
+ mask = pred_pil.resize(image.size)
95
+ image.putalpha(mask)
96
  return image
97
 
98
 
99
+ # def remove_background(image: PIL.Image.Image,
100
+ # rembg_session: Any = None,
101
+ # force: bool = False,
102
+ # **rembg_kwargs,
103
+ # ) -> PIL.Image.Image:
104
+ # do_remove = True
105
+ # if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
106
+ # do_remove = False
107
+ # do_remove = do_remove or force
108
+ # if do_remove:
109
+ # image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
110
+ # return image
111
+
112
+
113
  def resize_foreground(
114
  image: PIL.Image.Image,
115
  ratio: float,