import torch
import gradio as gr
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# Run on the first CUDA device when PyTorch can see one; fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hub checkpoint shared by the model and its processor; keeping it in one
# place guarantees both are always loaded from the same weights/config.
MODEL_ID = "nielsr/owlv2-base-patch16-ensemble"

# NOTE(review): this is an OWLv2 checkpoint loaded through the OwlViT* classes.
# Newer transformers releases provide Owlv2Processor / Owlv2ForObjectDetection;
# confirm which classes this checkpoint's config expects before switching.
model = OwlViTForObjectDetection.from_pretrained(MODEL_ID).to(device)
model.eval()  # inference only
processor = OwlViTProcessor.from_pretrained(MODEL_ID)
def query_image(img, text_queries, score_threshold):
    """Detect objects in ``img`` matching comma-separated text queries.

    Args:
        img: Image from the ``gr.Image`` input. ``img.shape[:2]`` is read as
            (height, width), so an array-like image is expected.
        text_queries: Comma-separated object descriptions, e.g.
            ``"coffee mug, spoon"``. Surrounding whitespace is ignored and empty
            entries are dropped.
        score_threshold: Minimum detection score to keep.

    Returns:
        ``(img, annotations)`` where each annotation is
        ``((x1, y1, x2, y2), query)`` with integer pixel coordinates clamped to
        the image. Returns ``None`` when no image was provided.
    """
    if img is None:
        # Nothing to run on (e.g. submit pressed before uploading an image).
        return None

    # Clean the queries so labels shown in the UI have no stray spaces and
    # a trailing comma does not produce an empty query.
    queries = [q.strip() for q in text_queries.split(",") if q.strip()]
    if not queries:
        return img, []

    height, width = img.shape[:2]
    # Boxes are rescaled to a square of side max(height, width); this assumes
    # the processor pads the image to a square (OWLv2-style) — TODO confirm.
    size = max(height, width)
    target_sizes = torch.Tensor([[size, size]])

    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process on the CPU, next to target_sizes.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    # Forward the user's threshold: the post-processor's default (0.1) would
    # otherwise drop detections the slider allows below 0.1.
    results = processor.post_process_object_detection(
        outputs=outputs, threshold=score_threshold, target_sizes=target_sizes
    )
    result = results[0]

    annotations = []
    for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
        if score < score_threshold:
            continue
        x1, y1, x2, y2 = (int(v) for v in box.tolist())
        # Clamp to the visible image: boxes are scaled to the padded square,
        # so they can extend past the edges or carry negative coordinates.
        x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
        y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
        if x2 <= x1 or y2 <= y1:
            continue  # box lies entirely outside the visible image
        annotations.append(((x1, y1, x2, y2), queries[label.item()]))
    return img, annotations
# Markdown rendered under the demo title.
# TODO(review): the "Colab demo" line appears to have lost its link — restore the URL.
description = """
Try this demo for OWLv2, introduced in *Scaling Open-Vocabulary Object Detection*.

Compared to OWL-ViT, OWLv2 performs better in both yield and performance (average precision).
You can use OWLv2 to query images with text descriptions of any object.
To use it, simply upload an image and enter comma-separated text descriptions of the objects you want to find in the image.
You can also use the score threshold slider to filter out low-probability predictions.

OWL-ViT was trained with text templates, so you may get better predictions by querying the image with the templates
used to train the original model, e.g. *"photo of a star-spangled banner"* or *"image of a shoe"*.
Refer to the CLIP paper for the full list of text templates used to augment the training data.

Colab demo
"""
# Wire query_image into a Gradio UI: image + text queries + score slider in,
# annotated image out.
demo = gr.Interface(
    query_image,
    inputs=[
        gr.Image(),
        gr.Textbox(label="Text queries (comma-separated)"),
        gr.Slider(0, 1, value=0.1, label="Score threshold"),
    ],
    outputs="annotatedimage",
    title="Zero-Shot Object Detection with OWLv2",
    description=description,
    examples=[
        ["assets/astronaut.png", "human face, rocket, star-spangled banner, nasa badge", 0.11],
        ["assets/coffee.png", "coffee mug, spoon, plate", 0.1],
        ["assets/butterflies.jpeg", "orange butterfly", 0.3],
    ],
)

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.launch()