# macadeliccc's picture
# theme
# 5355cb8
import gradio as gr
from pydantic import BaseModel
from typing import List, Optional
from PIL import Image, ImageDraw, ImageFont
import random
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import logging
from logging.handlers import RotatingFileHandler
import base64
import io
import os
import numpy as np
class DetectionRequest(BaseModel):
    """Request schema: an encoded image plus groups of text queries.

    Not referenced elsewhere in this file; its fields appear to mirror the
    ``image_data`` / ``texts`` arguments of ``detect_objects_logic``.
    """

    # Encoded image payload (presumably base64, like detect_objects_logic's input).
    image_data: str
    # Nested list of query strings, e.g. [["cat", "dog"]].
    texts: List[List[str]]
class DetectionResult(BaseModel):
    """Response schema: detection descriptions plus an annotated image.

    Not referenced elsewhere in this file.
    """

    # Human-readable description strings, one per detection.
    detections: List[str]
    # Annotated image as a string (presumably base64-encoded — confirm with the producer).
    image_with_boxes: str
# Hugging Face checkpoint used for both the processor and the model.
MODEL_ID = "google/owlv2-base-patch16-ensemble"

# Load once at import time so every request reuses the same processor/model.
processor = Owlv2Processor.from_pretrained(MODEL_ID)
model = Owlv2ForObjectDetection.from_pretrained(MODEL_ID)

# Create logs directory if it doesn't exist. exist_ok=True replaces the
# racy exists()/makedirs() pair with a single call.
# NOTE(review): nothing in this file attaches a logging handler
# (RotatingFileHandler is imported but never used), so the logging.info calls
# below are dropped at the default WARNING level — confirm whether a file
# handler under logs/ was intended.
os.makedirs("logs", exist_ok=True)
def draw_bounding_boxes(image: "Image.Image", boxes, scores, labels, text_labels):
    """Draw labelled bounding boxes onto ``image`` in place and return it.

    Args:
        image: PIL image to annotate; it is modified in place.
        boxes: Iterable of boxes, each a ``torch.Tensor`` or a list/tuple of
            ``[x0, y0, x1, y1]`` coordinates in image pixels.
        scores: Confidence scores aligned with ``boxes`` (tensor or float).
        labels: Integer indices into ``text_labels`` aligned with ``boxes``.
        text_labels: Sequence of label strings for this query group.

    Returns:
        The same ``image`` object with rectangles and captions drawn on it.

    Raises:
        TypeError: If a box is neither a tensor nor a list/tuple.
    """
    draw = ImageDraw.Draw(image)
    color_bank = ["#0AC2FF", "#47FF0A", "#FF0AC2", "#ADD8E6", "#FF0A47"]
    font = ImageFont.load_default()

    for box, score, label in zip(boxes, scores, labels):
        # Convert the box to a Python list if it's not already
        if isinstance(box, torch.Tensor):
            box = box.tolist()
        elif not isinstance(box, (list, tuple)):
            raise TypeError("Box must be a list or tuple of coordinates.")

        label_index = int(label)
        # Color keyed on the label so every box for the same query matches
        # (and output is deterministic), instead of a random pick per box.
        color = color_bank[label_index % len(color_bank)]

        draw.rectangle(box, outline=color, width=2)

        display_text = f"{text_labels[label_index]}: {float(score):.2f}"
        # Place the caption just above the box, but never above the image top.
        text_position = (box[0], max(box[1] - 10, 0))
        draw.text(text_position, display_text, fill=color, font=font)

    return image
def detect_objects_logic(image_data, texts):
    """Run OWLv2 open-vocabulary detection on a base64-encoded image.

    Args:
        image_data: Base64-encoded image bytes in any format PIL can open.
        texts: List of text-query groups, e.g. ``[["cat", "dog"]]``. Group ``i``
            is paired with post-processed result ``i``.

    Returns:
        Tuple ``(image_with_boxes, detection_strings)``:
        - a copy of the decoded image with boxes drawn on it, and
        - one human-readable line per detection.

        With no query groups, returns the unannotated copy and an empty list.

    Raises:
        Exception: Any error is logged and then re-raised unchanged.
    """
    try:
        image_data_bytes = base64.b64decode(image_data)
        # convert() forces PIL's lazy decode and normalises palette / alpha /
        # grayscale inputs to 3-channel RGB for the processor and for drawing.
        image = Image.open(io.BytesIO(image_data_bytes)).convert("RGB")
        image_with_boxes = image.copy()  # Copy the image only once

        if not texts:
            # The processor cannot handle an empty query list; nothing to detect.
            logging.info("No text queries supplied; skipping detection.")
            return image_with_boxes, []

        inputs = processor(text=texts, images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-processing expects (height, width); PIL's size is (width, height).
        # NOTE(review): OWLv2 pads inputs to a square; depending on the installed
        # transformers version, boxes may be relative to the padded image —
        # verify placement on a non-square image.
        target_sizes = torch.Tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs=outputs, threshold=0.1, target_sizes=target_sizes
        )

        detection_strings = []
        for group_index, text_group in enumerate(texts):
            if group_index >= len(results):
                logging.error("Text group index %d exceeds results length.", group_index)
                continue
            logging.info("Processing texts: %s", text_group)

            group_result = results[group_index]
            boxes = group_result["boxes"]
            scores = group_result["scores"]
            labels = group_result["labels"]

            image_with_boxes = draw_bounding_boxes(
                image_with_boxes, boxes, scores, labels, text_group
            )

            for box, score, label in zip(boxes, scores, labels):
                # Boxes are already in pixel coordinates because target_sizes
                # was passed to post-processing; do not rescale them again.
                pixel_box = [round(coord, 2) for coord in box.tolist()]
                detection_strings.append(
                    f"Detected {text_group[int(label)]} with confidence "
                    f"{round(score.item(), 3)} at location {pixel_box}"
                )
            logging.info("Bounding boxes and labels have been drawn on the image.")

        return image_with_boxes, detection_strings
    except IndexError as e:
        logging.error(
            "Index error: %s. Check if the number of text groups matches the model's output.", e
        )
        raise
    except Exception as e:
        logging.error("An unexpected error occurred: %s", e, exc_info=True)
        raise
def gradio_detect_and_draw(image, text_labels):
    """Gradio callback: detect comma-separated ``text_labels`` in ``image``.

    Args:
        image: PIL image (or numpy array) from the ``gr.Image`` input.
        text_labels: Comma-separated labels, e.g. ``"taxi, traffic light"``.
            Surrounding whitespace and empty entries are ignored.

    Returns:
        Tuple ``(processed_image, detections)`` as produced by
        ``detect_objects_logic``.

    Raises:
        ValueError: If no image or no non-empty label is provided.
    """
    if image is None:
        raise ValueError("No image was provided.")

    # Convert the input image to PIL Image if it's a numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'), 'RGB')

    # Strip whitespace so "taxi, traffic light" queries "traffic light",
    # not " traffic light", and drop empty entries such as trailing commas.
    labels = [label.strip() for label in (text_labels or "").split(",") if label.strip()]
    if not labels:
        raise ValueError("No labels were provided.")

    # Lossless PNG round-trip: JPEG would degrade the image before detection
    # and cannot store alpha channels, hence the RGB conversion as well.
    buffered = io.BytesIO()
    image.convert("RGB").save(buffered, format="PNG")
    image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")

    processed_image, detections = detect_objects_logic(image_data, [labels])

    # Convert the output image to PIL Image if it's a numpy array
    if isinstance(processed_image, np.ndarray):
        processed_image = Image.fromarray(processed_image.astype('uint8'), 'RGB')
    return processed_image, detections
# Build the Gradio UI: image + labels in, annotated image + detection text out.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Owlv2 Object Detection Demo")
    gr.Markdown(
        'Run this space on your own hardware with this command: ```docker run -it -p 7860:7860 --platform=linux/amd64 '
        'registry.hf.space/macadeliccc-owlv2-base-patch-16-ensemble-demo:latest python app.py```'
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload or draw an image")
            text_input = gr.Textbox(label="Enter comma-separated labels for detection")
            submit_button = gr.Button("Detect")
        with gr.Column():
            image_output = gr.Image(label="Processed Image")
            text_output = gr.Text(label="Detections")

    submit_button.click(
        gradio_detect_and_draw,
        inputs=[image_input, text_input],
        outputs=[image_output, text_output],
    )

    # Add examples (clicking one only fills the inputs; no fn/outputs given).
    examples = [
        ["assets/snowman.jpg", "snowman"],
        ["assets/traffic.jpg", "taxi,traffic light"],
        ["assets/umbrellas.jpg", "umbrella"],
    ]
    gr.Examples(examples, inputs=[image_input, text_input])

# Only start the server when run as a script (`python app.py`), so importing
# this module (e.g. for tests or `gradio` reload mode) does not block.
if __name__ == "__main__":
    demo.launch()