liuyuan-pal commited on
Commit
a14768e
1 Parent(s): 741df38
Files changed (4) hide show
  1. app.py +21 -13
  2. detection_test.py +35 -0
  3. hf_demo/examples/basket.png +0 -0
  4. sam_utils.py +2 -3
app.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  import os
8
  import fire
9
  from omegaconf import OmegaConf
 
10
 
11
  from ldm.util import add_margin, instantiate_from_config
12
  from sam_utils import sam_init, sam_out_nosave
@@ -95,21 +96,28 @@ def white_background(img):
95
  return Image.fromarray(rgb)
96
 
97
  def sam_predict(predictor, raw_im):
98
- raw_im = np.asarray(raw_im)
99
- raw_rgb = white_background(raw_im)
100
- h, w = raw_rgb.height, raw_rgb.width
101
- raw_rgb = add_margin(raw_rgb, color=255, size=max(h, w))
102
-
103
- raw_rgb.thumbnail([512, 512], Image.Resampling.LANCZOS)
104
- image_sam = sam_out_nosave(predictor, raw_rgb.convert("RGB"))
105
-
106
- image_sam = np.asarray(image_sam)
107
- out_mask = image_sam[:,:,3:]>0
108
- out_rgb = image_sam[:,:,:3] * out_mask + 1 - out_mask
109
- out_mask = out_mask.astype(np.uint8) * 255
110
- out_img = np.concatenate([out_rgb, out_mask], 2)
 
 
 
 
 
 
111
 
112
  image_sam = Image.fromarray(out_img, mode='RGBA')
 
113
  torch.cuda.empty_cache()
114
  return image_sam
115
 
 
7
  import os
8
  import fire
9
  from omegaconf import OmegaConf
10
+ from rembg import remove
11
 
12
  from ldm.util import add_margin, instantiate_from_config
13
  from sam_utils import sam_init, sam_out_nosave
 
96
  return Image.fromarray(rgb)
97
 
98
  def sam_predict(predictor, raw_im):
99
+ raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
100
+ image_nobg = remove(raw_im.convert('RGBA'), alpha_matting=True)
101
+ arr = np.asarray(image_nobg)[:, :, -1]
102
+ x_nonzero = np.nonzero(arr.sum(axis=0))
103
+ y_nonzero = np.nonzero(arr.sum(axis=1))
104
+ x_min = int(x_nonzero[0].min())
105
+ y_min = int(y_nonzero[0].min())
106
+ x_max = int(x_nonzero[0].max())
107
+ y_max = int(y_nonzero[0].max())
108
+ # image_nobg.save('./nobg.png')
109
+
110
+ image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
111
+ image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))
112
+
113
+ # imsave('./mask.png', np.asarray(image_sam)[:,:,3]*255)
114
+ image_sam = np.asarray(image_sam, np.float32) / 255
115
+ out_mask = image_sam[:, :, 3:]
116
+ out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
117
+ out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
118
 
119
  image_sam = Image.fromarray(out_img, mode='RGBA')
120
+ # image_sam.save('./output.png')
121
  torch.cuda.empty_cache()
122
  return image_sam
123
 
detection_test.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ from skimage.io import imsave
4
+
5
+ from app import white_background
6
+ from ldm.util import add_margin
7
+ from sam_utils import sam_out_nosave, sam_init
8
+ from rembg import remove
9
+
10
+ raw_im = Image.open('hf_demo/examples/basket.png')
11
+ predictor = sam_init()
12
+
13
+ raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
14
+ width, height = raw_im.size
15
+ image_nobg = remove(raw_im.convert('RGBA'), alpha_matting=True)
16
+ arr = np.asarray(image_nobg)[:, :, -1]
17
+ x_nonzero = np.nonzero(arr.sum(axis=0))
18
+ y_nonzero = np.nonzero(arr.sum(axis=1))
19
+ x_min = int(x_nonzero[0].min())
20
+ y_min = int(y_nonzero[0].min())
21
+ x_max = int(x_nonzero[0].max())
22
+ y_max = int(y_nonzero[0].max())
23
+ # image_nobg.save('./nobg.png')
24
+
25
+ image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
26
+ image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))
27
+
28
+ # imsave('./mask.png', np.asarray(image_sam)[:,:,3]*255)
29
+ image_sam = np.asarray(image_sam, np.float32) / 255
30
+ out_mask = image_sam[:, :, 3:]
31
+ out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
32
+ out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
33
+
34
+ image_sam = Image.fromarray(out_img, mode='RGBA')
35
+ # image_sam.save('./output.png')
hf_demo/examples/basket.png ADDED
sam_utils.py CHANGED
@@ -16,10 +16,9 @@ def sam_init(device_id=0):
16
  predictor = SamPredictor(sam)
17
  return predictor
18
 
19
- def sam_out_nosave(predictor, input_image, ):
 
20
  image = np.asarray(input_image)
21
- h, w, _ = image.shape
22
- bbox = np.array([0, 0, h, w])
23
 
24
  start_time = time.time()
25
  predictor.set_image(image)
 
16
  predictor = SamPredictor(sam)
17
  return predictor
18
 
19
+ def sam_out_nosave(predictor, input_image, bbox):
20
+ bbox = np.array(bbox)
21
  image = np.asarray(input_image)
 
 
22
 
23
  start_time = time.time()
24
  predictor.set_image(image)