File size: 3,178 Bytes
4723159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import spaces
from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
import torch
import gradio as gr

# Pick GPU when available; all models and inputs are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# OWLv2: open-vocabulary detector that takes a list of label strings.
owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")

# Grounding DINO: zero-shot detector that takes a single caption string.
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

@spaces.GPU
def infer(img, text_queries, score_threshold, model):
  """Run zero-shot object detection on ``img`` with OWLv2 or Grounding DINO.

  Args:
    img: image as a numpy array, shape (height, width, channels).
    text_queries: list of label strings to detect.
    score_threshold: minimum confidence score for a detection to be kept.
    model: which detector to use — "owl" or "dino".

  Returns:
    List of ``(box, label)`` tuples, where ``box`` is ``[x0, y0, x1, y1]``
    as ints and ``label`` is the matched query string.

  Raises:
    ValueError: if ``model`` is neither "owl" nor "dino".
  """
  if model == "dino":
    # Grounding DINO expects one caption: queries separated by ". "
    # (trailing space kept to preserve the original caption format).
    queries = "".join(f"{query}. " for query in text_queries)

    # numpy images are (height, width, ...); the post-processor expects
    # target_sizes entries as (height, width). The original code named
    # these the other way around, which produced the right tuple only
    # because both names were swapped.
    height, width = img.shape[:2]
    target_sizes = [(height, width)]
    inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
      outputs = dino_model(**inputs)
      outputs.logits = outputs.logits.cpu()
      outputs.pred_boxes = outputs.pred_boxes.cpu()
      results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
                                                                    box_threshold=score_threshold,
                                                                    target_sizes=target_sizes)
  elif model == "owl":
    # OWLv2 preprocessing pads the image to a square; rescale boxes
    # against the longest side so coordinates match the original image.
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = owl_processor(text=text_queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
      outputs = owl_model(**inputs)
      outputs.logits = outputs.logits.cpu()
      outputs.pred_boxes = outputs.pred_boxes.cpu()
      results = owl_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
  else:
    # Previously an unknown model name fell through to an unbound
    # `results` and crashed with a confusing NameError.
    raise ValueError(f"Unknown model {model!r}; expected 'owl' or 'dino'.")

  boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
  result_labels = []

  for box, score, label in zip(boxes, scores, labels):
      box = [int(i) for i in box.tolist()]
      if score < score_threshold:
          continue
      if model == "owl":
        # OWL returns integer class indices into text_queries;
        # Grounding DINO already returns label strings.
        label = text_queries[label.cpu().item()]
      result_labels.append((box, label))
  return result_labels

def query_image(img, text_queries, owl_threshold, dino_threshold):
    """Run both detectors on the image and return annotated results.

    Args:
        img: input image (numpy array from the Gradio Image component).
        text_queries: comma-separated label string, e.g. "bee, flower".
        owl_threshold: score threshold for the OWLv2 detector.
        dino_threshold: score threshold for the Grounding DINO detector.

    Returns:
        Two (image, annotations) tuples — one per detector — in the format
        expected by gr.AnnotatedImage.
    """
    # NOTE: labels after a comma keep their leading space ("bee, flower"
    # -> [" flower"]); preserved here to keep behavior identical.
    text_queries = text_queries.split(",")
    owl_output = infer(img, text_queries, owl_threshold, "owl")
    # Bug fix: the DINO call previously passed owl_threshold, silently
    # ignoring the dino_threshold slider.
    dino_output = infer(img, text_queries, dino_threshold, "dino")

    return (img, owl_output), (img, dino_output)


# UI wiring: one slider per detector, one annotated-image output per detector.
owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
owl_output = gr.AnnotatedImage(label="OWL Output")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    # Bug fix: gr.Textbox("Candidate Labels") set the textbox *value*
    # (first positional arg), not its label — users saw "Candidate Labels"
    # pre-filled as input text. Use the label= keyword instead.
    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
    outputs=[owl_output, dino_output],
    # Title updated to reflect that the app compares both detectors.
    title="Zero-Shot Object Detection with OWLv2 and Grounding DINO",
    examples=[["./bee.jpg", "bee, flower", 0.16, 0.12]]
)
demo.launch(debug=True)