import torch
import cv2
import gradio as gr
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# Run on the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the OWL-ViT checkpoint once, at import time; image_guided_detection
# reuses these module-level objects on every call.
_OWLVIT_CHECKPOINT = "google/owlvit-base-patch32"
model = OwlViTForObjectDetection.from_pretrained(_OWLVIT_CHECKPOINT).to(device)
model.eval()  # inference only: switch the module to evaluation mode
processor = OwlViTProcessor.from_pretrained(_OWLVIT_CHECKPOINT)
def image_guided_detection(img, query_img, score_threshold, nms_threshold):
    """Find regions of ``img`` that look like the object shown in ``query_img``.

    Runs OWL-ViT image-guided (one-shot) detection, post-processes the
    predictions with score thresholding and non-maximum suppression, and draws
    a box for every kept prediction on a copy of ``img``.

    Args:
        img: Target image; must expose PIL's ``.size`` (the Gradio input is
            declared with ``type="pil"``).
        query_img: Image containing only the object to search for.
        score_threshold: Score threshold passed to post-processing; each
            returned score must also reach it before its box is drawn.
        nms_threshold: IoU threshold passed to post-processing for
            non-maximum suppression.

    Returns:
        A writable NumPy copy of ``img`` with the detections drawn as boxes
        of colour (255, 0, 0) and thickness 5. If post-processing returns no
        result entries, the copy is returned without any boxes.
    """
    # PIL reports size as (width, height); post-processing expects (height, width).
    target_sizes = torch.Tensor([img.size[::-1]])
    inputs = processor(query_images=query_img, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    # Image-guided post-processing consumes ``logits`` and ``target_pred_boxes``.
    # Move both to the CPU so they share a device with ``target_sizes``.
    # Writing to a new ``pred_boxes`` attribute would leave the boxes on the GPU.
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()

    results = processor.post_process_image_guided_detection(
        outputs=outputs,
        threshold=score_threshold,
        nms_threshold=nms_threshold,
        target_sizes=target_sizes,
    )

    # np.asarray() on a PIL image can yield a read-only array, which
    # cv2.rectangle cannot draw into; np.array() always makes a writable copy.
    canvas = np.array(img)
    if not results:
        return canvas

    boxes, scores = results[0]["boxes"], results[0]["scores"]
    for box, score in zip(boxes, scores):
        if score >= score_threshold:
            x0, y0, x1, y1 = (int(v) for v in box.tolist())
            canvas = cv2.rectangle(canvas, (x0, y0), (x1, y1), (255, 0, 0), 5)
    return canvas
# Text shown under the demo title; the "\n\n" escapes add blank lines that
# separate the paragraphs.
description = """
Gradio demo for image-guided / one-shot object detection with OWL-ViT,
introduced in Simple Open-Vocabulary Object Detection with Vision Transformers.
\n\nYou can use OWL-ViT to query images with text descriptions of any object or alternatively with an
example / query image of the target object. To use it, simply upload an image and a query image that only contains the object
you're looking for. You can also use the score and non-maximum suppression threshold sliders to set a threshold to filter out
low probability and overlapping bounding box predictions.
\n\nFor an in-depth tutorial on how to use OWL-ViT with transformers, check out our Colab notebook
and our HF spaces demo for zero-shot / text-guided object detection.
"""
# Example rows: target image path, query image path, score threshold, NMS threshold.
_example_rows = [
    ["assets/image2.jpeg", "assets/query2.jpeg", 0.7, 0.3],
    ["assets/image1.jpeg", "assets/query1.jpeg", 0.6, 0.3],
]

# Wire image_guided_detection into a Gradio UI and start serving it immediately.
demo = gr.Interface(
    fn=image_guided_detection,
    inputs=[
        gr.Image(type="pil"),        # img: target image
        gr.Image(type="pil"),        # query_img: example of the object to find
        gr.Slider(0, 1, value=0.6),  # score_threshold
        gr.Slider(0, 1, value=0.3),  # nms_threshold
    ],
    outputs="image",
    title="Image-Guided Object Detection with OWL-ViT",
    description=description,
    examples=_example_rows,
)
demo.launch()