Arulkumar03
committed on
Commit
•
96b3f69
1
Parent(s):
3fefe22
Update app.py
Browse files
app.py
CHANGED
@@ -62,8 +62,37 @@ def image_transform_grounding_for_vis(init_image):
|
|
62 |
image, _ = transform(init_image, None) # 3, h, w
|
63 |
return image
|
64 |
|
65 |
-
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
|
|
|
|
|
|
67 |
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
|
68 |
init_image = input_image.convert("RGB")
|
69 |
original_size = init_image.size
|
@@ -72,12 +101,21 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
|
|
72 |
image_pil: Image = image_transform_grounding_for_vis(init_image)
|
73 |
|
74 |
# run grounding
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
|
@@ -124,7 +162,7 @@ if __name__ == "__main__":
|
|
124 |
gr.Examples(
|
125 |
[["watermelon.jpg", "watermelon", 0.25, 0.25]],
|
126 |
inputs = [input_image, grounding_caption, box_threshold, text_threshold],
|
127 |
-
outputs = [gallery],
|
128 |
fn=run_grounding,
|
129 |
cache_examples=True,
|
130 |
label='Try this example input!'
|
|
|
62 |
image, _ = transform(init_image, None) # 3, h, w
|
63 |
return image
|
64 |
|
65 |
+
model = load_model_hf(task, config_file, ckpt_repo_id, ckpt_filenmae)
|
66 |
+
|
67 |
+
def segment(image, sam_model, boxes):
    """Predict a mask for every box in ``boxes`` and return the masks on the CPU.

    ``image`` is registered with ``sam_model`` via ``set_image`` and must
    expose a three-element ``shape`` whose first two entries are used as
    height and width. ``boxes`` are converted with
    ``box_ops.box_cxcywh_to_xyxy`` and multiplied by
    ``[width, height, width, height]`` before being passed to the predictor.
    NOTE(review): this presumes ``boxes`` hold normalised centre/size
    coordinates -- confirm against the caller.
    """
    sam_model.set_image(image)

    height, width, _ = image.shape
    pixel_scale = torch.Tensor([width, height, width, height])
    pixel_boxes = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    # Map the boxes into the predictor's input frame on the module-level device.
    prompt_boxes = sam_model.transform.apply_boxes_torch(
        pixel_boxes.to(device),
        image.shape[:2],
    )

    # Only the masks are used; the two remaining return values are discarded.
    masks, _, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=prompt_boxes,
        multimask_output=False,
    )
    return masks.cpu()
|
80 |
+
|
81 |
+
|
82 |
+
def draw_mask(mask, image, random_color=True):
    """Alpha-composite a coloured rendering of ``mask`` over ``image``.

    Args:
        mask: Object whose last two ``shape`` entries are height and width;
            it must support ``reshape`` and, after multiplication by the
            colour vector, ``.cpu().numpy()``.
            NOTE(review): presumably a torch tensor -- confirm with caller.
        image: Array accepted by ``Image.fromarray``; it is converted to RGBA.
        random_color: When true, use a random RGB colour with alpha 0.8;
            otherwise use the fixed colour (30, 144, 255) / 255 with alpha 0.6.

    Returns:
        ``numpy`` array of the composited RGBA image.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    rows, cols = mask.shape[-2:]
    # Broadcast the (rows, cols, 1) mask against a (1, 1, 4) colour vector.
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)

    base_layer = Image.fromarray(image).convert("RGBA")
    overlay_bytes = (overlay.cpu().numpy() * 255).astype(np.uint8)
    mask_layer = Image.fromarray(overlay_bytes).convert("RGBA")

    return np.array(Image.alpha_composite(base_layer, mask_layer))
|
94 |
+
|
95 |
+
|
96 |
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
|
97 |
init_image = input_image.convert("RGB")
|
98 |
original_size = init_image.size
|
|
|
101 |
image_pil: Image = image_transform_grounding_for_vis(init_image)
|
102 |
|
103 |
# run grounding
|
104 |
+
if task=='predict':
|
105 |
+
boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
|
106 |
+
annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
|
107 |
+
image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
|
108 |
+
|
109 |
+
return image_with_box
|
110 |
+
|
111 |
+
elif task=='segment':
|
112 |
+
boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
|
113 |
+
segmented_frame_masks = segment(image_tensor, model, boxes=boxes)
|
114 |
+
annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)
|
115 |
+
seg_with_bbox=Image.fromarray(annotated_frame_with_mask)
|
116 |
+
|
117 |
+
return seg_with_bbox
|
118 |
+
|
119 |
|
120 |
if __name__ == "__main__":
|
121 |
|
|
|
162 |
gr.Examples(
|
163 |
[["watermelon.jpg", "watermelon", 0.25, 0.25]],
|
164 |
inputs = [input_image, grounding_caption, box_threshold, text_threshold],
|
165 |
+
outputs = [gallery],gr.Choice(["segment", "classify"], label="Select Task")],
|
166 |
fn=run_grounding,
|
167 |
cache_examples=True,
|
168 |
label='Try this example input!'
|