File size: 2,157 Bytes
1e58367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import OwlViTProcessor, OwlViTForObjectDetection

# Detector and processor are loaded once, at import time, from the same
# public checkpoint so tokenization/preprocessing match the model weights.
_CHECKPOINT = "google/owlvit-base-patch32"
processor = OwlViTProcessor.from_pretrained(_CHECKPOINT)
model = OwlViTForObjectDetection.from_pretrained(_CHECKPOINT)
# Inference only: put the module in eval mode (disables dropout etc.).
model.eval()


def _load_font(size=22):
    """Return a TrueType font of *size* points, or PIL's built-in font.

    The Helvetica path only exists on macOS; elsewhere (e.g. a Linux server)
    ``ImageFont.truetype`` raises OSError, so try a common Linux font next and
    finally fall back to ``ImageFont.load_default()``.
    """
    for font_path in ("/System/Library/Fonts/Helvetica.ttc", "DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_path, size=size)
        except OSError:
            continue
    return ImageFont.load_default()


def query_image(img, text_queries, score_threshold=0.1):
    """Detect objects matching comma-separated text queries and draw them.

    Args:
        img: PIL image to search.
        text_queries: comma-separated object descriptions, e.g.
            ``"human face, rocket, flag"``. Whitespace around each entry is
            ignored, as are empty entries produced by stray commas.
        score_threshold: minimum detection score for a box to be drawn.

    Returns:
        A numpy array of the image with a red box and text label drawn for
        every detection scoring at least ``score_threshold``. If no non-empty
        query is given, the image is returned without running the model.
        The caller's image object is never modified.

    Raises:
        ValueError: if no image was provided.
    """
    if img is None:
        raise ValueError("Please upload an image.")

    # convert() returns a copy, so drawing below never mutates the caller's
    # image; RGB also guarantees three channels for the model and red outlines.
    img = img.convert("RGB")

    # Strip whitespace so " rocket" is queried and labelled as "rocket", and
    # drop empty entries (e.g. from a trailing comma).
    queries = [q.strip() for q in text_queries.split(",") if q.strip()]
    if not queries:
        return np.array(img)

    inputs = processor(text=queries, images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # post_process expects one (height, width) row per image; PIL's size is
    # (width, height). Using the real size keeps boxes aligned with the image
    # instead of assuming it is exactly 768x768.
    target_sizes = torch.tensor([img.size[::-1]], dtype=torch.float32)
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    draw = ImageDraw.Draw(img)
    font = _load_font(size=22)

    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        box = [int(coord) for coord in box.tolist()]
        draw.rectangle(box, outline="red", width=4)
        # Label sits just below the box's bottom-left corner.
        text_loc = [box[0] + 5, box[3] + 10]
        # labels index into the query list that was sent to the processor.
        draw.text(text_loc, queries[int(label)], fill="red", font=font, stroke_width=1)

    return np.array(img)


# Long-form HTML description of the demo, with links to the model docs and
# the paper plus brief usage instructions.
description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlvit">OWL-ViT</a>, 
introduced in <a href="https://arxiv.org/abs/2205.06230">Simple Open-Vocabulary Object Detection
with Vision Transformers</a>. 
\n\nYou can use OWL-ViT to query images with text descriptions of any object. 
To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. 
"""
# Web UI: an image upload (delivered to query_image as a PIL image) plus a
# free-text box of comma-separated queries; the annotated image is shown.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(shape=(768, 768), type="pil"), "text"],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    # Show the full `description` defined above (links + how to format the
    # queries) rather than a shorter duplicate that left it unused.
    description=description,
    examples=[
        ["astronaut.png", "human face, rocket, flag, nasa badge"],
        ["coffee.png", "coffee mug, spoon, plate"],
    ],
)
demo.launch(debug=True)