import spaces
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Grounding DINO once at startup; fall back to CPU when no GPU is available.
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

@spaces.GPU
def infer(img, text_queries, score_threshold):
    # Grounding DINO expects a single lowercase prompt with queries separated by dots.
    queries = ". ".join(q.strip().lower() for q in text_queries) + "."

    # NumPy images are (height, width, channels); the post-processor expects (height, width).
    height, width = img.shape[:2]
    target_sizes = [(height, width)]

    inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = dino_model(**inputs)
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        results = dino_processor.post_process_grounded_object_detection(
            outputs=outputs,
            input_ids=inputs.input_ids,
            box_threshold=score_threshold,
            target_sizes=target_sizes,
        )

    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        box = [int(i) for i in box.tolist()]
        # Post-processing already filters by box_threshold; also drop empty labels
        # from boxes that matched no query phrase.
        if score < score_threshold or label == "":
            continue
        result_labels.append((box, label))
    return result_labels

def query_image(img, text_queries, dino_threshold):
    # Split the comma-separated textbox input into individual queries.
    text_queries = [q.strip() for q in text_queries.split(",")]
    dino_output = infer(img, text_queries, dino_threshold)
    return (img, dino_output)


dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
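# Wire the components into a Gradio Interface; the examples assume deer.jpg and
# zebra.jpg ship alongside this file in the Space repo.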
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), dino_threshold],
    outputs=[dino_output],
    title="Grounding DINO DSA2024",
    description="DSA2024 Space to evaluate the state-of-the-art [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) zero-shot object detection model. Upload an image and enter a comma-separated list of the objects you want to detect, or try one of the examples. Adjust the threshold to filter out low-confidence predictions.",
    examples=[["./deer.jpg", "zebra, deer, goat", 0.16], ["./zebra.jpg", "zebra, lion, deer", 0.16]]
)
demo.launch(debug=True)
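
# Optional local smoke test: a minimal sketch, assuming ./zebra.jpg exists and the
# script runs outside the hosted Space. Uncomment to exercise the pipeline without the UI.
#
# from PIL import Image
# import numpy as np
#
# test_img = np.array(Image.open("./zebra.jpg").convert("RGB"))
# _, annotations = query_image(test_img, "zebra, lion, deer", 0.16)
# print(annotations)  # e.g. [([x0, y0, x1, y1], "zebra"), ...]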