"""Gradio demo for zero-shot, text-conditioned object detection with OWL-ViT."""

import torch
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import OwlViTProcessor, OwlViTForObjectDetection

MODEL_NAME = "google/owlvit-base-patch32"
FONT_PATH = "assets/Helvetica.ttf"
FONT_SIZE = 22
SCORE_THRESHOLD = 0.1

# Loaded once at import time so every request reuses the same weights.
model = OwlViTForObjectDetection.from_pretrained(MODEL_NAME).eval()
processor = OwlViTProcessor.from_pretrained(MODEL_NAME)


def _parse_queries(text_queries):
    """Split a comma separated string into clean, non-empty query strings."""
    return [query.strip() for query in text_queries.split(",") if query.strip()]


def _load_font():
    """Return the label font, falling back to Pillow's default if the asset is missing."""
    try:
        return ImageFont.truetype(FONT_PATH, size=FONT_SIZE)
    except OSError:
        # A missing or unreadable font file should not break detection itself.
        return ImageFont.load_default()


def query_image(img, text_queries):
    """Detect the queried objects in an image and draw labelled boxes on it.

    Args:
        img: PIL image to search. Boxes and labels are drawn onto this image.
        text_queries: Comma separated text descriptions of the objects to find.

    Returns:
        The annotated image as a NumPy array, or ``None`` when no image was
        supplied. When no usable query is given the image is returned
        without annotations.
    """
    if img is None:
        return None

    queries = _parse_queries(text_queries or "")
    if not queries:
        return np.array(img)

    inputs = processor(text=queries, images=img, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # post_process rescales the normalised boxes to pixel coordinates and
    # expects (height, width); PIL's ``size`` is (width, height).
    target_sizes = torch.Tensor([[img.height, img.width]])
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    draw = ImageDraw.Draw(img)
    font = _load_font()

    for box, score, label in zip(boxes, scores, labels):
        if score < SCORE_THRESHOLD:
            continue

        box = [int(coord) for coord in box.tolist()]
        draw.rectangle(box, outline="red", width=4)
        # The label is anchored just below the bottom-left corner of the box.
        text_loc = [box[0] + 5, box[3] + 10]
        draw.text(text_loc, queries[int(label)], fill="red", font=font, stroke_width=1)

    return np.array(img)


description = (
    " Gradio demo for OWL-ViT, introduced in Simple Open-Vocabulary Object Detection"
    " with Vision Transformers. \n\nYou can use OWL-ViT to query images with text"
    " descriptions of any object. To use it, simply upload an image and enter comma"
    " separated text descriptions of objects you want to query the image for."
    " \n\nColab demo "
)

demo = gr.Interface(
    query_image,
    inputs=[gr.Image(shape=(768, 768), type="pil"), "text"],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    examples=[
        ["assets/astronaut.png", "human face, rocket, flag, nasa badge"],
        ["assets/coffee.png", "coffee mug, spoon, plate"],
    ],
)
demo.launch(debug=True)