import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from PIL import Image, ImageDraw
import gradio as gr

# Load the pre-trained Owl-ViT model and processor once at import time so
# every request reuses the same weights.
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
model.eval()  # inference only — disable dropout / training-mode behavior
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")


def detect_objects(image: Image.Image, texts: str, threshold: float = 0.1):
    """Run zero-shot object detection on *image* for comma-separated queries.

    Args:
        image: Input PIL image.
        texts: Comma-separated text queries, e.g. ``"a cat, a dog"``.
        threshold: Minimum confidence score for a detection to be kept.

    Returns:
        A list of ``(box_tensor, label_str, score_float)`` tuples, where
        ``box_tensor`` holds ``(xmin, ymin, xmax, ymax)`` pixel coordinates.
    """
    # Drop empty fragments so "a, ,b" or a blank textbox cannot send
    # empty-string queries to the processor.
    text_queries = [text.strip() for text in texts.split(',') if text.strip()]
    if not text_queries:
        return []

    inputs = processor(text=text_queries, images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); the post-processor expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    # processor.post_process() is deprecated (removed in recent transformers);
    # post_process_object_detection is the supported API and applies the
    # confidence threshold for us.
    results = processor.post_process_object_detection(
        outputs=outputs, threshold=threshold, target_sizes=target_sizes
    )

    detected_boxes = []
    for box, score, label in zip(
        results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    ):
        detected_boxes.append((box, text_queries[label.item()], score.item()))
    return detected_boxes


def visualize(image, texts):
    """Return a copy of *image* with detected boxes and labels drawn in red."""
    boxes = detect_objects(image, texts)

    # Draw on a copy so the caller's image is left untouched.
    image = image.copy()
    draw = ImageDraw.Draw(image)
    for box, label, score in boxes:
        box = [round(coord) for coord in box.tolist()]
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"{label}: {score:.2f}", fill="red")
    return image


# Gradio Interface
def gradio_interface(image, texts):
    """Gradio callback: thin wrapper around :func:`visualize`."""
    return visualize(image, texts)


interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Comma-separated Text Queries"),
    ],
    outputs=gr.Image(type="pil", label="Object Detection Output"),
    title="Owl-ViT Object Detection",
    description="Upload an image and provide comma-separated text queries for object detection.",
    allow_flagging="never",
)

if __name__ == "__main__":
    interface.launch()