Babyloncoder's picture
Create app.py
396610c verified
raw
history blame
2.35 kB
import streamlit as st
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_MODEL_CHECKPOINT = "google/owlv2-base-patch16-ensemble"


def _load_detector():
    """Load the OWLv2 detection model (moved to ``device``) and its processor."""
    detector = Owlv2ForObjectDetection.from_pretrained(_MODEL_CHECKPOINT).to(device)
    preprocessor = Owlv2Processor.from_pretrained(_MODEL_CHECKPOINT)
    return detector, preprocessor


# Streamlit re-executes this script on every widget interaction, so an
# uncached load would re-read the checkpoint each time the slider or text
# box changes. Cache the loaded objects when the installed Streamlit offers
# ``st.cache_resource``; otherwise fall back to a plain load per run.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_detector = _cache_resource(_load_detector)

model, processor = _load_detector()

# --- User interface: title and input widgets -------------------------------
st.title("Zero-Shot Object Detection with OWLv2")
uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
text_queries = st.text_input("Enter text queries (comma-separated):")
score_threshold = st.slider("Score Threshold", min_value=0.0, max_value=1.0, value=0.1, step=0.01)
def query_image(img, text_queries, score_threshold):
    """Detect the objects named in ``text_queries`` inside an image.

    Args:
        img: Anything ``PIL.Image.open`` accepts (a path or a file-like
            object such as a Streamlit upload).
        text_queries: Comma-separated query string, e.g. ``"cat, dog"``.
            Whitespace around each query is ignored and empty entries are
            dropped.
        score_threshold: Minimum score a detection needs to be kept.

    Returns:
        A ``(image, detections)`` pair. ``image`` is the opened RGB image and
        ``detections`` is a list of ``(box, label)`` tuples, where ``box`` is
        a list of four ints and ``label`` is the matching query string.
        If anything fails, the error is shown with ``st.error`` and
        ``(None, [])`` is returned so callers can always unpack the result.
    """
    try:
        img = Image.open(img).convert("RGB")

        # "cat, dog" must yield "dog", not " dog"; blank input yields no queries.
        queries = [query.strip() for query in text_queries.split(",") if query.strip()]
        if not queries:
            # Nothing to look for: hand back the image without running the model.
            return img, []

        img_np = np.array(img)
        # Boxes are rescaled against a square whose side is the longer image
        # edge rather than against (height, width).
        # NOTE(review): this presumably mirrors OWLv2 preprocessing padding the
        # image to a square - confirm against the transformers OWLv2 docs.
        side = max(img_np.shape[:2])
        target_sizes = torch.Tensor([[side, side]])

        inputs = processor(text=queries, images=img_np, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-processing runs on the CPU alongside ``target_sizes``.
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()

        # Pass the threshold explicitly: without it the post-processor applies
        # its own default cut-off, so lower user-selected thresholds would
        # silently have no effect.
        results = processor.post_process_object_detection(
            outputs=outputs, threshold=score_threshold, target_sizes=target_sizes
        )
        detections = results[0]

        result_labels = [
            ([int(coord) for coord in box.tolist()], queries[label.item()])
            for box, label in zip(detections["boxes"], detections["labels"])
        ]
        return img, result_labels
    except Exception as e:
        # Top-level UI boundary: report the problem instead of crashing the app.
        st.error(f"Error performing object detection: {e}")
        return None, []
# Run detection and render the annotated image once a file has been uploaded.
if uploaded_image is not None:
    detection = query_image(uploaded_image, text_queries, score_threshold)
    # A failed detection has already been reported to the user by query_image
    # and yields no usable image; tolerate a bare None as well as a pair so
    # the unpacking below can never raise.
    annotated_image, detected_objects = detection if detection is not None else (None, [])

    if annotated_image is not None:
        draw = ImageDraw.Draw(annotated_image)
        font = ImageFont.load_default()
        for box, label in detected_objects:
            # Random outline colour per box so neighbouring boxes are distinguishable.
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            draw.rectangle(box, outline=color, width=3)
            # Label is anchored at the first two box coordinates.
            draw.text((box[0], box[1]), label, fill="black", font=font)
        st.image(annotated_image, caption="Annotated Image", use_column_width=True)