Arulkumar03 committed on
Commit 4b21d34
1 Parent(s): c25e3bd

Update app.py

Files changed (1): app.py (+19, -59)
app.py CHANGED
@@ -23,7 +23,7 @@ from groundingdino.util.slconfig import SLConfig
 from groundingdino.util.utils import clean_state_dict
 from groundingdino.util.inference import annotate, load_image, predict
 import groundingdino.datasets.transforms as T
-from groundingdino.util import box_ops
+
 from huggingface_hub import hf_hub_download


@@ -64,62 +64,23 @@ def image_transform_grounding_for_vis(init_image):

 model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)

-def segment(image, sam_model, boxes):
-    H, W, _ = image.shape
-    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
-
-    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(device), image.shape[:2])
-    masks, _, _ = sam_model.predict_torch(
-        point_coords=None,
-        point_labels=None,
-        boxes=transformed_boxes,
-        multimask_output=False,
-    )
-    return masks.cpu()
-
-
-def draw_mask(mask, image, random_color=True):
-    if random_color:
-        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
-    else:
-        color = np.array([30/255, 144/255, 255/255, 0.6])
-    h, w = mask.shape[-2:]
-    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-
-    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
-    mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")
-
-    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
-
-
-def run_grounding(input_image, choice, grounding_caption, box_threshold, text_threshold, do_segmentation):
+def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
     init_image = input_image.convert("RGB")
     original_size = init_image.size

     _, image_tensor = image_transform_grounding(init_image)
     image_pil: Image = image_transform_grounding_for_vis(init_image)

-    if choice == 'segment':
-        boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
-        segmented_frame_masks = segment(image_tensor, model, boxes=boxes)
-        annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)
-    else:
-        # run grounding
-        boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
-        annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
-
+    # run grounding
+    boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
+    annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
+
     image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))

-    return image_with_box
-
+    return image_with_box

 if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
-    parser.add_argument("--debug", action="store_true", help="using debug mode")
-    parser.add_argument("--share", action="store_true", help="share the app")
-    args = parser.parse_args()
+
     css = """
     #mkd {
         height: 500px;
@@ -133,13 +94,9 @@ if __name__ == "__main__":
     gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/Arulkumar03/SOTA-Grounding-DINO.ipynb'>Grounding DINO</a><h3><center>")
     gr.Markdown("<h3><center>Note the model runs on CPU, so it may take a while to run the model.<h3><center>")

-
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(source='upload', type="pil")
-            choice = gr.Radio(
-                ["segment", "classify"], default="segment", label="Choose Operation"
-            )
             grounding_caption = gr.Textbox(label="Detection Prompt")
             run_button = gr.Button(label="Run")
             with gr.Accordion("Advanced options", open=False):
@@ -155,15 +112,18 @@ if __name__ == "__main__":
                 type="pil",
                 # label="grounding results"
             ).style(full_width=True, full_height=True)
+            # gallery = gr.Gallery(label="Generated images", show_label=False).style(
+            #     grid=[1], height="auto", container=True, full_width=True, full_height=True)

             run_button.click(fn=run_grounding, inputs=[
-                input_image, choice, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
+                input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
             gr.Examples(
-                [["watermelon.jpg", "segment", "watermelon", 0.25, 0.25]],
-                inputs=[input_image, choice, grounding_caption, box_threshold, text_threshold],
-                outputs=[gallery],
-                fn=run_grounding,
-                cache_examples=True,
-                label='Try this example input!'
-            )
-    block.launch(share=False, show_api=False, show_error=True)
+                [["watermelon.jpg", "watermelon", 0.25, 0.25]],
+                inputs=[input_image, grounding_caption, box_threshold, text_threshold],
+                outputs=[gallery],
+                fn=run_grounding,
+                cache_examples=True,
+                label='Try this example input!'
+            )
+    block.launch(share=True, show_api=False, show_error=True)
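Review note: the removed segmentation path appears to have been dead code. `run_grounding` passed the Grounding DINO `model` where `segment` expected a SAM predictor, the `segment` branch drew masks onto `annotated_frame` before anything assigned it, and `device` was never defined in app.py, so dropping the path (together with the now-unused `box_ops` import and the argparse block) is a coherent simplification. For reference, a minimal sketch of how a box-prompted SAM call is usually wired, assuming the `segment_anything` package; the `vit_h` variant and the `sam_vit_h_4b8939.pth` checkpoint filename are assumptions, not part of this repo:

    # Hypothetical wiring, NOT part of this commit: box-prompted SAM masks
    # from Grounding DINO detections. Model variant and checkpoint path are assumed.
    import numpy as np
    import torch
    from segment_anything import sam_model_registry, SamPredictor
    from groundingdino.util import box_ops

    device = "cpu"
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(device)
    sam_predictor = SamPredictor(sam)

    def segment(image_rgb: np.ndarray, boxes: torch.Tensor) -> torch.Tensor:
        """image_rgb: HxWx3 uint8 RGB array; boxes: normalized cxcywh boxes from predict()."""
        sam_predictor.set_image(image_rgb)  # embed the image once
        H, W, _ = image_rgb.shape
        # normalized cxcywh -> absolute xyxy pixel coordinates, as the removed helper did
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.tensor([W, H, W, H])
        transformed = sam_predictor.transform.apply_boxes_torch(boxes_xyxy.to(device), image_rgb.shape[:2])
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed,
            multimask_output=False,
        )
        return masks.cpu()

The key differences from the removed helper: SAM gets its own predictor with `set_image` called on the raw RGB frame (not the normalized `image_tensor`), and the masks would then be composited onto the frame returned by `annotate`.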
 
 
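Two smaller observations on the new wiring: the Examples row and the `run_button.click` inputs now agree with the four-argument `run_grounding` signature, and `block.launch(share=True, ...)` is effectively a no-op on a Hugging Face Space (Gradio warns and ignores the flag there), so it only matters for local runs. Since the handler is now a plain image-to-image function, it is easy to smoke-test outside Gradio; a minimal check, assuming the repo's bundled watermelon.jpg and the thresholds from the cached example:

    # Standalone smoke test (not part of the commit). Assumes watermelon.jpg
    # sits next to app.py, as the Examples row implies.
    from PIL import Image

    img = Image.open("watermelon.jpg")
    annotated = run_grounding(img, "watermelon", box_threshold=0.25, text_threshold=0.25)
    annotated.save("watermelon_annotated.jpg")  # boxes + phrases drawn by annotate()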