kxqt committed
Commit 1f28384
Parent(s): c46b2fc

fix bugs and add sam time box

app.py CHANGED

@@ -1,10 +1,18 @@
 import os
+import time
 import torch
 import numpy as np
 
 import gradio as gr
 
 from segment_anything import build_sam, SamAutomaticMaskGenerator
+from segment_anything.utils.amg import (
+    batch_iterator,
+    MaskData,
+    calculate_stability_score,
+    batched_mask_to_box,
+    is_box_near_crop_edge,
+)
 
 os.system(r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth')
 
@@ -22,14 +30,76 @@ hourglass_args = {
     },
 }
 
+def generate_mask(image, generator: SamAutomaticMaskGenerator):
+    start = time.perf_counter()
+    generator.predictor.set_image(image)
+    eta1 = time.perf_counter() - start
+
+    image_size = image.shape[:2]
+    points_scale = np.array(image_size)[None, ::-1]
+    points_for_image = generator.point_grids[0] * points_scale
+    for (points,) in batch_iterator(generator.points_per_batch, points_for_image):
+        transformed_points = generator.predictor.transform.apply_coords(points, image_size)
+        in_points = torch.as_tensor(transformed_points, device=generator.predictor.device)
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+        start = time.perf_counter()
+        masks, iou_preds, _ = generator.predictor.predict_torch(
+            in_points[:, None, :],
+            in_labels[:, None],
+            multimask_output=True,
+            return_logits=True,
+        )
+        eta2 = time.perf_counter() - start
+
+        # Serialize predictions and store in MaskData
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+        )
+        del masks
+
+    # Filter by predicted IoU
+    if generator.pred_iou_thresh > 0.0:
+        keep_mask = data["iou_preds"] > generator.pred_iou_thresh
+        data.filter(keep_mask)
+
+    # Calculate stability score
+    data["stability_score"] = calculate_stability_score(
+        data["masks"], generator.predictor.model.mask_threshold, generator.stability_score_offset
+    )
+    if generator.stability_score_thresh > 0.0:
+        keep_mask = data["stability_score"] >= generator.stability_score_thresh
+        data.filter(keep_mask)
+
+    # Threshold masks and calculate boxes
+    data["masks"] = data["masks"] > generator.predictor.model.mask_threshold
+
+    # Write mask records
+    curr_anns = []
+    for idx in range(len(data["masks"])):
+        ann = {
+            "segmentation": data["masks"][idx].numpy(),
+            "area": data["masks"][idx].sum().item(),
+        }
+        curr_anns.append(ann)
+
+    return curr_anns
+
+
 def predict(image, speed_mode, points_per_side):
     points_per_side = int(points_per_side)
     mask_generator = SamAutomaticMaskGenerator(
-        build_sam(checkpoint="sam_vit_h_4b8939.pth", hourglass_kwargs=hourglass_args[speed_mode]),
+        build_sam(checkpoint="sam_vit_h_4b8939.pth", **hourglass_args[speed_mode]),
         points_per_side=points_per_side,
         points_per_batch=64 if points_per_side > 12 else points_per_side * points_per_side
     )
-    masks = mask_generator.generate(image)
+    start = time.perf_counter()
+    with torch.no_grad():
+        # masks = mask_generator.generate(image)
+        masks = generate_mask(image, mask_generator)
+    eta = time.perf_counter() - start
+    eta_text = f"Time of generation: {eta:.2f} seconds"
 
     if len(masks) == 0:
         return image
@@ -41,7 +111,7 @@ def predict(image, speed_mode, points_per_side):
         img = img * (1 - m[..., None]) + color_mask * m[..., None]
 
     image = ((image + img * 255) / 2).astype(np.uint8)
-    return image
+    return image, eta_text
 
 description = """
 # <center>Expedit-SAM (Expedite Segment Anything Model without any training)</center>
@@ -73,7 +143,9 @@ def main():
         with gr.Row():
             run_btn = gr.Button(label="Run", id="run", value="Run")
             clear_btn = gr.Button(label="Clear", id="clear", value="Clear")
-        output_image = gr.Image(label="Output Image")
+        with gr.Column():
+            output_image = gr.Image(label="Output Image")
+            eta_label = gr.Label(label="ETA")
     gr.Examples(
         examples=[
             ["./notebooks/images/dog.jpg"],
@@ -88,7 +160,7 @@ def main():
     run_btn.click(
         fn=predict,
         inputs=[input_image, speed_mode, points_per_side],
-        outputs=output_image
+        outputs=[output_image, eta_label]
     )
     clear_btn.click(
         fn=lambda: [None, None],
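
Note on the batching in the new `generate_mask`: as reconstructed here, the `MaskData` object is rebuilt on each pass through `batch_iterator`, so if the point grid is split across several batches only the last batch's predictions would survive the loop. The demo sidesteps this in practice because `points_per_batch` is set to `points_per_side * points_per_side` whenever `points_per_side <= 12`, which yields exactly one batch. The upstream `SamAutomaticMaskGenerator._process_crop` instead accumulates batches with `MaskData.cat`. A minimal sketch of that pattern, where `process_batch` is a hypothetical stand-in for the `predict_torch` call plus `MaskData` construction done inside the loop above:

    # Sketch only: accumulate per-batch results rather than overwriting them.
    # `process_batch` is hypothetical; it stands in for the predict_torch +
    # MaskData construction performed inside the loop above.
    data = MaskData()
    for (points,) in batch_iterator(generator.points_per_batch, points_for_image):
        batch_data = process_batch(points)  # a MaskData for this batch
        data.cat(batch_data)                # concatenate instead of replace
        del batch_data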
segment_anything/modeling/hourglass_image_encoder.py CHANGED

@@ -203,7 +203,7 @@ class TokenReconstructionBlock(UnpoolingBase):
         mink = torch.min(topk, dim=-1).values
         mink = mink.unsqueeze(-1).repeat(1, 1, weight.shape[-1])
         mask = torch.ge(weight, mink)
-        zero = Variable(torch.zeros_like(weight)).cuda()
+        zero = Variable(torch.zeros_like(weight)).to(weight.device)
         attention = torch.where(mask, weight, zero)
         attention = F.normalize(attention, dim=2)
         ret = torch.einsum("bnm, bmc -> bnc", attention, x)
@@ -233,10 +233,10 @@ class HourglassImageEncoderViT(ImageEncoderViT):
         global_attn_indexes: Tuple[int, ...] = (),
         hourglass_clustering_location: int = -1,
         hourglass_num_cluster: int = None,
-        hourglass_cluster_iters: int = 3,
-        hourglass_temperture: float = 0.1,
-        hourglass_cluster_window_size: int = 12,
-        hourglass_reconstruction_k: int = 36,
+        hourglass_cluster_iters: int = 5,
+        hourglass_temperture: float = 0.01,
+        hourglass_cluster_window_size: int = 5,
+        hourglass_reconstruction_k: int = 20,
     ) -> None:
         """
         Args:
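
Note on the `.cuda()` fix above: hard-coding `.cuda()` crashes on CPU-only hosts, while `.to(weight.device)` keeps the zero tensor on whatever device `weight` already lives on. Since `torch.zeros_like(weight)` already allocates on the device (and with the dtype) of `weight`, and `torch.autograd.Variable` has been a no-op wrapper since PyTorch 0.4, an equivalent and simpler form (a sketch, not what the commit does) would be:

    # zeros_like already allocates on weight's device, so no transfer is
    # needed, and the deprecated Variable wrapper can be dropped entirely.
    zero = torch.zeros_like(weight)
    attention = torch.where(mask, weight, zero)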