File size: 3,635 Bytes
184fd64 99cb033 184fd64 b8bd0d3 184fd64 84fc022 184fd64 477e5a7 184fd64 b8bd0d3 184fd64 99cb033 477e5a7 e789313 477e5a7 184fd64 e3d0d39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import torch
import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image
from io import BytesIO
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import os
# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for the checkpoint so model and processor always match.
_OWLVIT_CHECKPOINT = "google/owlvit-large-patch14"

# Load the detector onto the chosen device and switch it to inference mode
# (eval() disables training-only behaviour such as dropout).
model = OwlViTForObjectDetection.from_pretrained(_OWLVIT_CHECKPOINT).to(device)
model.eval()

# The processor prepares the text queries and the image for the model.
processor = OwlViTProcessor.from_pretrained(_OWLVIT_CHECKPOINT)
def query_image(img, text_queries, score_threshold):
    """Run zero-shot OWL-ViT detection on *img* and draw the matching boxes.

    Args:
        img: Image from the Gradio ``gr.Image`` input. It is converted with
            ``np.array``; presumably an H x W x C uint8 RGB array (the
            gr.Image default) -- confirm if the component type changes.
            ``None`` (no image supplied) is passed straight through.
        text_queries: Comma-separated object descriptions, e.g.
            ``"cat, remote control"``. Whitespace around each query is
            ignored and empty entries are dropped.
        score_threshold: Minimum detection score for a box to be drawn.

    Returns:
        A copy of the image with a red rectangle and the matching query text
        drawn for every detection scoring at least ``score_threshold``.
        Returns ``None`` when no image is given. When there is no non-empty
        query, returns the image unchanged.
    """
    if img is None:
        return None
    img = np.array(img)

    # Strip whitespace so " cat" matches "cat", and skip empty entries such as
    # those produced by a trailing comma.
    queries = [q.strip() for q in (text_queries or "").split(",") if q.strip()]
    if not queries:
        return img

    height = img.shape[0]
    # (rows, cols) of the input image, so predicted boxes are rescaled to pixels.
    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Move predictions back to the CPU, where target_sizes lives, before
    # post-processing.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 5)
        # Write the label just below the box. If that would run past the
        # bottom edge of *this* image, write it just inside the box instead.
        text_y = y1 - 10 if y1 + 25 > height else y1 + 25
        img = cv2.putText(
            img, queries[int(label)], (x0, text_y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
        )
    return img
# Description text passed to both gr.Interface tabs (upload and webcam) below.
# It is rendered to users, so its exact wording and whitespace are runtime
# behaviour, not documentation.
description = """
\n\nYou can use OWL-ViT to query images with text descriptions of any object.
To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
can also use the score threshold slider to set a threshold to filter out low probability predictions.
"""
# Example image offered in both interfaces. Gradio expects `examples` as a
# nested list: one inner list per example, with one entry per input component.
# Only the image input is filled here, so users still type their own queries.
_EXAMPLE_IMAGE = os.path.join("examples", "IMGP0178.jpg")

# Interface for images uploaded from disk.
upload = gr.Interface(
    query_image,
    inputs=[gr.Image(source="upload"), "text", gr.Slider(0, 1, value=0.1)],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    examples=[[_EXAMPLE_IMAGE]],
    # The example fills only the image input, so it must not be pre-run
    # through query_image without any text queries.
    cache_examples=False,
)

# Same pipeline, but the image is captured from the user's webcam.
web = gr.Interface(
    query_image,
    inputs=[gr.Image(source="webcam"), "text", gr.Slider(0, 1, value=0.1)],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    # A bare string would be read as a *directory* of examples, so use the
    # same nested-list form as the upload interface.
    examples=[[_EXAMPLE_IMAGE]],
    cache_examples=False,
)

# Present both interfaces as tabs in a single app.
demo = gr.TabbedInterface(
    interface_list=[upload, web],
    tab_names=["From a File", "From your Webcam"],
)
# NOTE(review): launch() presumably blocks when run as a script, so the
# Blocks-based app defined afterwards only starts once this server stops --
# confirm which of the two apps is intended to ship.
demo.launch()
# Blocks-based variant of the app. It has one tab for uploaded files and one
# for webcam captures; each tab has its own inputs, Submit button and output.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Tab("Upload image"):
            with gr.Row():
                with gr.Column():
                    # Inside gr.Blocks, inputs must be real components: the
                    # Interface-only "text" shortcut is replaced with gr.Textbox.
                    inputs_file = [
                        gr.Image(source="upload"),
                        gr.Textbox(label="Text queries (comma separated)"),
                        gr.Slider(0, 1, value=0.1, label="Score threshold"),
                    ]
                    submit_btn = gr.Button("Submit")
                im_output = gr.Image()
            # gr.Examples must be told which component each example fills in.
            # Here that is the upload image.
            gr.Examples(
                examples=[["examples/IMGP0178.jpg"]],
                inputs=[inputs_file[0]],
            )
        with gr.Tab("Capture image with webcam"):
            with gr.Row():
                with gr.Column():
                    inputs_web = [
                        gr.Image(source="webcam"),
                        gr.Textbox(label="Text queries (comma separated)"),
                        gr.Slider(0, 1, value=0.1, label="Score threshold"),
                    ]
                    submit_btn_web = gr.Button("Submit")
                web_output = gr.Image()
    # Event listeners are registered inside the Blocks context that owns the
    # components.
    submit_btn.click(fn=query_image, inputs=inputs_file, outputs=im_output)
    submit_btn_web.click(fn=query_image, inputs=inputs_web, outputs=web_output)
demo.launch()
|