Arulkumar03 commited on
Commit
96b3f69
1 Parent(s): 3fefe22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -8
app.py CHANGED
@@ -62,8 +62,37 @@ def image_transform_grounding_for_vis(init_image):
62
  image, _ = transform(init_image, None) # 3, h, w
63
  return image
64
 
65
- model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
 
 
67
  def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
68
  init_image = input_image.convert("RGB")
69
  original_size = init_image.size
@@ -72,12 +101,21 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
72
  image_pil: Image = image_transform_grounding_for_vis(init_image)
73
 
74
  # run grounidng
75
- boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
76
- annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
77
- image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
78
-
79
-
80
- return image_with_box
 
 
 
 
 
 
 
 
 
81
 
82
  if __name__ == "__main__":
83
 
@@ -124,7 +162,7 @@ if __name__ == "__main__":
124
  gr.Examples(
125
  [["watermelon.jpg", "watermelon", 0.25, 0.25]],
126
  inputs = [input_image, grounding_caption, box_threshold, text_threshold],
127
- outputs = [gallery],
128
  fn=run_grounding,
129
  cache_examples=True,
130
  label='Try this example input!'
 
62
  image, _ = transform(init_image, None) # 3, h, w
63
  return image
64
 
65
+ model = load_model_hf(task, config_file, ckpt_repo_id, ckpt_filenmae)
66
+
67
+ def segment(image, sam_model, boxes):
68
+ sam_model.set_image(image)
69
+ H, W, _ = image.shape
70
+ boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
71
+
72
+ transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(device), image.shape[:2])
73
+ masks, _, _ = sam_model.predict_torch(
74
+ point_coords = None,
75
+ point_labels = None,
76
+ boxes = transformed_boxes,
77
+ multimask_output = False,
78
+ )
79
+ return masks.cpu()
80
+
81
+
82
+ def draw_mask(mask, image, random_color=True):
83
+ if random_color:
84
+ color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
85
+ else:
86
+ color = np.array([30/255, 144/255, 255/255, 0.6])
87
+ h, w = mask.shape[-2:]
88
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
89
+
90
+ annotated_frame_pil = Image.fromarray(image).convert("RGBA")
91
+ mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")
92
 
93
+ return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
94
+
95
+
96
  def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
97
  init_image = input_image.convert("RGB")
98
  original_size = init_image.size
 
101
  image_pil: Image = image_transform_grounding_for_vis(init_image)
102
 
103
  # run grounidng
104
+ if task=='predict':
105
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
106
+ annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
107
+ image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
108
+
109
+ return image_with_box
110
+
111
+ elif task=='segment':
112
+ boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
113
+ segmented_frame_masks = segment(image_tensor, model, boxes=boxes)
114
+ annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)
115
+ seg_with_bbox=Image.fromarray(annotated_frame_with_mask)
116
+
117
+ return seg_with_bbox
118
+
119
 
120
  if __name__ == "__main__":
121
 
 
162
  gr.Examples(
163
  [["watermelon.jpg", "watermelon", 0.25, 0.25]],
164
  inputs = [input_image, grounding_caption, box_threshold, text_threshold],
165
+ outputs = [gallery],gr.Choice(["segment", "classify"], label="Select Task")],
166
  fn=run_grounding,
167
  cache_examples=True,
168
  label='Try this example input!'