ngthanhtinqn commited on
Commit
c199bab
β€’
1 Parent(s): 127eb07

fix threshold

Browse files
Files changed (2) hide show
  1. app.py +5 -5
  2. demo.py +2 -2
app.py CHANGED
@@ -17,15 +17,15 @@ hence you can get better predictions by querying the image with text templates u
17
  """
18
  demo = gr.Interface(
19
  query_image,
20
- inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
21
  outputs="image",
22
  title="Zero-Shot Object Detection with OWL-ViT",
23
  description=description,
24
  examples=[
25
- ["./demo_images/cats.png", "cats,ears", 0.11],
26
- ["./demo_images/demo1.jpg", "bear,soil,sea", 0.1],
27
- ["./demo_images/demo2.jpg", "dog,ear,leg,eyes,tail", 0.1],
28
- ["./demo_images/tanager.jpg", "wing,eyes,back,legs,tail", 0.01]
29
  ],
30
  )
31
 
 
17
  """
18
  demo = gr.Interface(
19
  query_image,
20
+ inputs=[gr.Image(), "text"],
21
  outputs="image",
22
  title="Zero-Shot Object Detection with OWL-ViT",
23
  description=description,
24
  examples=[
25
+ ["./demo_images/cats.png", "cats,ears"],
26
+ ["./demo_images/demo1.jpg", "bear,soil,sea"],
27
+ ["./demo_images/demo2.jpg", "dog,ear,leg,eyes,tail"],
28
+ ["./demo_images/tanager.jpg", "wing,eyes,back,legs,tail"]
29
  ],
30
  )
31
 
demo.py CHANGED
@@ -81,7 +81,7 @@ owlvit_processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
81
  # run segment anything (SAM)
82
  sam_predictor = SamPredictor(build_sam(checkpoint="./sam_vit_h_4b8939.pth"))
83
 
84
- def query_image(img, text_prompt, box_threshold):
85
  # load image
86
  if not isinstance(img, PIL.Image.Image):
87
  pil_img = Image.fromarray(np.uint8(img)).convert('RGB')
@@ -89,7 +89,7 @@ def query_image(img, text_prompt, box_threshold):
89
  text_prompt = text_prompt
90
  texts = text_prompt.split(",")
91
 
92
- box_threshold = box_threshold
93
 
94
  # run object detection model
95
  with torch.no_grad():
 
81
  # run segment anything (SAM)
82
  sam_predictor = SamPredictor(build_sam(checkpoint="./sam_vit_h_4b8939.pth"))
83
 
84
+ def query_image(img, text_prompt):
85
  # load image
86
  if not isinstance(img, PIL.Image.Image):
87
  pil_img = Image.fromarray(np.uint8(img)).convert('RGB')
 
89
  text_prompt = text_prompt
90
  texts = text_prompt.split(",")
91
 
92
+ box_threshold = 0.0
93
 
94
  # run object detection model
95
  with torch.no_grad():