"""Streamlit demo: zero-shot object detection with OWLv2.

The user uploads an image, enters comma-separated text queries and picks a
score threshold; detected boxes are drawn on the image together with the
query text that matched.
"""
import random

import numpy as np
import streamlit as st
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import Owlv2ForObjectDetection, Owlv2Processor

MODEL_NAME = "google/owlv2-base-patch16-ensemble"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Streamlit re-executes this script on every widget interaction; without
# caching the model weights would be reloaded each time. Fall back to a no-op
# decorator on Streamlit versions that do not provide ``cache_resource``.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model_and_processor():
    """Load the OWLv2 model (moved to ``device``) and its processor once."""
    loaded_model = Owlv2ForObjectDetection.from_pretrained(MODEL_NAME).to(device)
    loaded_processor = Owlv2Processor.from_pretrained(MODEL_NAME)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor()

st.title("Zero-Shot Object Detection with OWLv2")

uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
text_queries = st.text_input("Enter text queries (comma-separated):")
score_threshold = st.slider(
    "Score Threshold", min_value=0.0, max_value=1.0, value=0.1, step=0.01
)


def query_image(img, text_queries: str, score_threshold: float):
    """Run OWLv2 on an uploaded image and collect detections above a score.

    Args:
        img: File-like object or path accepted by ``PIL.Image.open``.
        text_queries: Comma-separated query phrases. Surrounding whitespace
            is stripped and empty entries are dropped.
        score_threshold: Minimum score a detection needs to be kept.

    Returns:
        A ``(image, detections)`` tuple. ``image`` is the RGB ``PIL.Image``
        and ``detections`` is a list of ``([x0, y0, x1, y1], query_text)``
        pairs. On failure the error is reported with ``st.error`` and
        ``(None, [])`` is returned, so callers can always unpack the result.
    """
    try:
        img = Image.open(img).convert("RGB")

        queries = [query.strip() for query in text_queries.split(",")]
        queries = [query for query in queries if query]
        if not queries:
            # Nothing to search for: skip inference instead of sending an
            # empty prompt to the model.
            st.warning("Enter at least one text query to run detection.")
            return img, []

        img_np = np.array(img)
        # Boxes are rescaled against a square whose side is the longer image
        # edge; per the Hugging Face OWLv2 docs the preprocessing pads the
        # image to a square, so this keeps boxes aligned with the original
        # pixels.
        size = max(img_np.shape[:2])
        target_sizes = torch.Tensor([[size, size]])

        inputs = processor(text=queries, images=img_np, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-processing runs on CPU tensors (target_sizes lives on the CPU).
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()

        # Forward the slider value: the post-processor otherwise applies its
        # own default threshold (0.1), which silently ignores lower settings.
        results = processor.post_process_object_detection(
            outputs=outputs,
            threshold=score_threshold,
            target_sizes=target_sizes,
        )
        detections = results[0]

        result_labels = []
        for box, score, label in zip(
            detections["boxes"], detections["scores"], detections["labels"]
        ):
            # Kept as a guard in case the post-processor's own filtering
            # differs from the requested threshold.
            if score < score_threshold:
                continue
            pixel_box = [int(coord) for coord in box.tolist()]
            result_labels.append((pixel_box, queries[label.item()]))
        return img, result_labels
    except Exception as e:
        # UI boundary: surface the problem to the user rather than crash the app.
        st.error(f"Error performing object detection: {e}")
        return None, []


if uploaded_image is not None:
    annotated_image, detected_objects = query_image(
        uploaded_image, text_queries, score_threshold
    )
    if annotated_image is not None:
        draw = ImageDraw.Draw(annotated_image)
        font = ImageFont.load_default()
        for box, label in detected_objects:
            # A random colour per box keeps neighbouring detections distinguishable.
            color = (
                random.randint(0, 255),
                random.randint(0, 255),
                random.randint(0, 255),
            )
            draw.rectangle(box, outline=color, width=3)
            draw.text((box[0], box[1]), label, fill="black", font=font)
        st.image(annotated_image, caption="Annotated Image", use_column_width=True)