# MSaadTariq's picture
# Create app.py
# 389e8f6 verified
# raw
# history blame contribute delete
# No virus
# 2.05 kB
import torch
import gradio as gr
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import spaces
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub checkpoint shared by the detection model and its processor.
_CHECKPOINT = "google/owlv2-base-patch16-ensemble"

# The model is moved to `device`; the processor only prepares inputs and
# post-processes outputs, so it stays device-agnostic.
model = Owlv2ForObjectDetection.from_pretrained(_CHECKPOINT).to(device)
processor = Owlv2Processor.from_pretrained(_CHECKPOINT)
def query_image(Upload_Image, Text, score_threshold):
    """Run zero-shot OWLv2 object detection for comma-separated text queries.

    Args:
        Upload_Image: Image from the ``gr.Image()`` input. It is indexed with
            ``.shape[:2]`` as (height, width), so it is presumably an HxWxC
            numpy array (Gradio's default ``type="numpy"``); confirm this if
            the input component changes.
        Text: Comma-separated object descriptions, e.g. ``"cat, dog"``.
            Whitespace around each query is stripped and empty entries are
            ignored.
        score_threshold: Minimum detection confidence to report.

    Returns:
        ``(Upload_Image, annotations)`` for the ``annotatedimage`` output.
        Each annotation is ``([x_min, y_min, x_max, y_max], query)``, with
        integer pixel coordinates clipped to the image bounds.

    Raises:
        gr.Error: If no image was uploaded or no non-empty query was given.
    """
    if Upload_Image is None:
        raise gr.Error("Please upload an image.")

    # Strip whitespace so "cat, dog" queries "dog" rather than " dog", and drop
    # the empty entries produced by stray or trailing commas.
    queries = [q.strip() for q in (Text or "").split(",") if q.strip()]
    if not queries:
        raise gr.Error("Please enter at least one comma-separated text query.")

    height, width = Upload_Image.shape[:2]
    # The OWLv2 image processor pads inputs to a square, so predicted boxes are
    # rescaled against the longer side on both axes.
    size = max(height, width)
    target_sizes = torch.Tensor([[size, size]])

    inputs = processor(text=queries, images=Upload_Image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # target_sizes lives on the CPU; move the tensors post-processing reads to match.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()

    # Pass the user's threshold through. Otherwise the post-processor applies
    # its own default cut-off (0.1), and slider values below it have no effect.
    results = processor.post_process_object_detection(
        outputs=outputs, threshold=score_threshold, target_sizes=target_sizes
    )
    result = results[0]

    annotations = []
    for box, label in zip(result["boxes"].tolist(), result["labels"].tolist()):
        x_min, y_min, x_max, y_max = box
        # Boxes may extend into the square padding or below zero; clip them to
        # the real image so the annotations stay within its bounds.
        clipped = [
            min(max(int(x_min), 0), width),
            min(max(int(y_min), 0), height),
            min(max(int(x_max), 0), width),
            min(max(int(y_max), 0), height),
        ]
        # Labels index into the query list that was passed to the processor.
        annotations.append((clipped, queries[label]))
    return Upload_Image, annotations
# Markdown shown under the title in the Gradio UI.
description = """
You can use AnyVision to query images with text descriptions of any object.
To use it, simply upload an image and enter comma-separated text descriptions of objects you want to query the image for.
You can also use the score threshold slider to set a threshold to filter out low probability predictions.
You can get better predictions by querying the image with text templates used in training the original model: e.g. *"photo of a star-spangled banner"*,
*"image of a shoe"*.
"""

# Without explicit labels, Gradio labels the inputs with query_image's
# parameter names (Upload_Image, Text, score_threshold).
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
    outputs="annotatedimage",
    title="AnyVision - Zero-Shot Object Detector with OWLv2",
    description=description,
)

# Launch only when run as a script (as Spaces does with `python app.py`), so
# importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()