import gradio as gr
from pydantic import BaseModel
from typing import List, Optional
from PIL import Image, ImageDraw, ImageFont
import random
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import logging
from logging.handlers import RotatingFileHandler
import base64
import io
import os
import numpy as np
# Pydantic schemas for the detection API (not consumed by the Gradio UI below).
class DetectionRequest(BaseModel):
    image_data: str
    texts: List[List[str]]


class DetectionResult(BaseModel):
    detections: List[str]
    image_with_boxes: str
# Load the OWLv2 processor and detection model once at startup.
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
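
# Inference-only setup (a small addition, not in the original): eval mode
# disables dropout and similar training-time behavior; the model stays on
# CPU to match the rest of the code.
model.eval()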
# Create the logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)
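
# Wire the imported RotatingFileHandler to the logs/ directory so it is
# actually used (a minimal sketch; the file name, size limit, and log level
# are assumptions, not from the original).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[RotatingFileHandler("logs/app.log", maxBytes=1_000_000, backupCount=3)],
)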
def draw_bounding_boxes(image: Image.Image, boxes, scores, labels, text_labels):
    """Draw one labelled rectangle per detection onto `image` and return it."""
    draw = ImageDraw.Draw(image)
    # Define the color bank
    color_bank = ["#0AC2FF", "#47FF0A", "#FF0AC2", "#ADD8E6", "#FF0A47"]
    # Use the default font
    font = ImageFont.load_default()
    for box, score, label in zip(boxes, scores, labels):
        # Choose a random color
        color = random.choice(color_bank)
        # Convert the box to a Python list if it isn't one already
        if isinstance(box, torch.Tensor):
            box = box.tolist()
        elif not isinstance(box, (list, tuple)):
            raise TypeError("Box must be a list or tuple of coordinates.")
        # Draw the rectangle
        draw.rectangle(box, outline=color, width=2)
        # Compose the label text, e.g. "cat: 0.42"
        display_text = f"{text_labels[int(label)]}: {score:.2f}"
        # Place the text just above the box, clamped to the top edge
        text_position = (box[0], max(box[1] - 10, 0))
        # Draw the text
        draw.text(text_position, display_text, fill=color, font=font)
    return image
def detect_objects_logic(image_data, texts):
    """Run OWLv2 on a base64-encoded image; return (annotated image, detection strings)."""
    try:
        # Decode the base64 image
        image_data_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_data_bytes)).convert("RGB")
        inputs = processor(text=texts, images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # target_sizes is (height, width); post-processing then returns
        # boxes in absolute pixel coordinates
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(outputs=outputs, threshold=0.1, target_sizes=target_sizes)
        detection_strings = []
        image_with_boxes = image.copy()  # Copy the image only once
        logging.info(f"Processing texts: {texts}")
        for i, text_group in enumerate(texts):
            if i >= len(results):
                logging.error(f"Text group index {i} exceeds results length.")
                continue
            results_per_group = results[i]
            boxes = results_per_group["boxes"]
            scores = results_per_group["scores"]
            labels = results_per_group["labels"]
            image_with_boxes = draw_bounding_boxes(image_with_boxes, boxes, scores, labels, text_group)
            for box, score, label in zip(boxes, scores, labels):
                # Boxes are already in pixel coordinates; just round them
                pixel_box = [round(coord, 2) for coord in box.tolist()]
                detection_string = f"Detected {text_group[int(label)]} with confidence {round(score.item(), 3)} at location {pixel_box}"
                detection_strings.append(detection_string)
        logging.info("Bounding boxes and labels have been drawn on the image.")
        return image_with_boxes, detection_strings
    except IndexError as e:
        logging.error(f"Index error: {e}. Check if the number of text groups matches the model's output.")
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}", exc_info=True)
        raise
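
# Standalone usage sketch for detect_objects_logic (illustrative only; the
# file name and query labels below are assumptions, not from the original):
#
#   with open("example.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode("utf-8")
#   boxed, lines = detect_objects_logic(b64, [["a cat", "a remote control"]])
#   boxed.save("example_boxes.jpg")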
def gradio_detect_and_draw(image, text_labels):
    """Adapter between the Gradio UI and detect_objects_logic."""
    # Reject empty submissions
    if image is None:
        raise ValueError("No image was provided.")
    # Convert the input image to a PIL Image if it's a numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'), 'RGB')
    # Convert the PIL Image to base64 for the detection logic
    buffered = io.BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")  # JPEG can't store alpha
    image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
    # Split the comma-separated labels into one query group
    text_labels = [[t.strip() for t in text_labels.split(',') if t.strip()]] if text_labels else []
    # Run detection
    processed_image, detections = detect_objects_logic(image_data, text_labels)
    # Convert the output image to a PIL Image if it's a numpy array
    if isinstance(processed_image, np.ndarray):
        processed_image = Image.fromarray(processed_image.astype('uint8'), 'RGB')
    return processed_image, "\n".join(detections)
with gr.Blocks() as demo:
    gr.Markdown("## OWLv2 Object Detection Demo")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload or draw an image")
            text_input = gr.Textbox(label="Enter comma-separated labels for detection")
            submit_button = gr.Button("Detect")
        with gr.Column():
            image_output = gr.Image(label="Processed Image")
            text_output = gr.Text(label="Detections")
    submit_button.click(
        gradio_detect_and_draw,
        inputs=[image_input, text_input],
        outputs=[image_output, text_output]
    )
    # Add examples
    examples = [
        ["https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", "snowman"],
        ["https://history.iowa.gov/sites/default/files/primary-sources/images/history-education-pss-transportation-centralpark-source.jpg", "taxi,traffic light"],
        ["https://i.pinimg.com/1200x/51/e1/a1/51e1a12517e95725590d3a4b1a7575d7.jpg", "umbrella"],
    ]
    gr.Examples(examples, inputs=[image_input, text_input])
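
# Note: on Spaces the launch call below is sufficient; when running locally,
# demo.launch(share=True) can additionally expose a temporary public URL
# (optional; not in the original).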
demo.launch()