developer0hye committed
Commit ea17cac · verified · 1 Parent(s): 077bef8

Update app.py

Files changed (1): app.py +30 -14
app.py CHANGED
@@ -14,16 +14,28 @@ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 # Add supervision for better visualization
 import supervision as sv
 
-# Model ID for Hugging Face
-model_id = "rziga/mm_grounding_dino_base_all"
+# Model IDs for Hugging Face
+MODEL_IDS = {
+    "MM Grounding DINO Large": "rziga/mm_grounding_dino_large_all",
+    "MM Grounding DINO Base": "rziga/mm_grounding_dino_base_all"
+}
 
-# Load model and processor using Transformers
+# Global variables for model caching
 device = "cuda" if torch.cuda.is_available() else "cpu"
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+loaded_model_name = None
+processor = None
+model = None
 
 @spaces.GPU
-def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
+def run_grounding(input_image, grounding_caption, model_choice, box_threshold, text_threshold):
+    global loaded_model_name, processor, model
+
+    # Load or reload model if changed
+    if loaded_model_name != model_choice:
+        model_id = MODEL_IDS[model_choice]
+        processor = AutoProcessor.from_pretrained(model_id)
+        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+        loaded_model_name = model_choice
     # Convert numpy array to PIL Image if needed
     if isinstance(input_image, np.ndarray):
         if input_image.ndim == 3:
@@ -63,8 +75,6 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
 
     for i, (box, score, label) in enumerate(zip(result["boxes"], result["scores"], result["labels"])):
         # box is xyxy format [xmin, ymin, xmax, ymax]
-        if label.strip() == "":
-            continue
         xyxy = box.tolist()
         boxes.append(xyxy)
         labels.append(label)
@@ -144,12 +154,18 @@ if __name__ == "__main__":
     }
     """
     with gr.Blocks(css=css) as demo:
-        gr.Markdown("<h1><center>MM Grounding DINO Base<h1><center>")
-        gr.Markdown("<h3><center>Open-World Detection with <a href='https://huggingface.co/openmmlab-community/mm_grounding_dino_base_all'>MM Grounding DINO</a><h3><center>")
+        gr.Markdown("<h1><center>MM Grounding DINO (Large & Base)<h1><center>")
+        gr.Markdown("<h3><center>Open-World Detection with MM Grounding DINO Models<h3><center>")
 
         with gr.Row():
             with gr.Column():
                 input_image = gr.Image(label="Input Image", type="pil")
+                model_choice = gr.Radio(
+                    choices=list(MODEL_IDS.keys()),
+                    value="MM Grounding DINO Large",
+                    label="Select Model",
+                    info="Choose between Large (better performance) or Base (faster) model"
+                )
                 grounding_caption = gr.Textbox(
                     label="Detection Prompt (lowercase + each ends with a dot)",
                     value="a person. a car."
@@ -181,16 +197,16 @@ if __name__ == "__main__":
 
         run_button.click(
             fn=run_grounding,
-            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
+            inputs=[input_image, grounding_caption, model_choice, box_threshold, text_threshold],
             outputs=[gallery, det_text]
         )
 
         gr.Examples(
             examples=[
-                ["000000039769.jpg", "a cat. a remote control.", 0.3, 0.25],
-                ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", 0.3, 0.25]
+                ["000000039769.jpg", "a cat. a remote control.", "MM Grounding DINO Large", 0.3, 0.25],
+                ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", "MM Grounding DINO Base", 0.3, 0.25]
             ],
-            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
+            inputs=[input_image, grounding_caption, model_choice, box_threshold, text_threshold],
             outputs=[gallery, det_text],
             fn=run_grounding,
             cache_examples=True,
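
For reference, below is a minimal standalone sketch of the inference flow that run_grounding wraps, reusing the loading pattern and example inputs from this diff. The post_process_grounded_object_detection call is an assumption inferred from the result["boxes"]/["scores"]/["labels"] usage above (that part of app.py is not shown in this diff), and its threshold keyword names vary between transformers releases, so treat this as an illustration rather than the Space's exact code:

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

# Loading pattern as in the diff above (Base checkpoint shown here).
model_id = "rziga/mm_grounding_dino_base_all"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

# Example image and prompt taken from the diff's gr.Examples;
# prompts are lowercase phrases, each ending with a dot.
image = Image.open("000000039769.jpg")
text = "a cat. a remote control."

inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# Assumed post-processing step; in some transformers versions the
# keyword is `threshold` rather than `box_threshold`, so check your
# installed version before copying this call.
results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.3,
    text_threshold=0.25,
    target_sizes=[image.size[::-1]],  # (height, width)
)
result = results[0]
for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
    # box is xyxy format [xmin, ymin, xmax, ymax]
    print(label, round(score.item(), 3), [round(v, 1) for v in box.tolist()])

Caching the loaded model in module-level globals, as the commit does, means switching the radio button only pays the download/load cost once per model choice rather than on every request.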