# detectobject / app.py
# Author: quydoan
# Commit 20b303f: "Add cat/dog detector app"
import gradio as gr
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import requests # To handle image URLs if needed, but we focus on uploads
# Load the model and processor from a single checkpoint name so the two can
# never drift apart.
# NOTE: the default revision of this checkpoint is loaded (no `revision=`
# argument is passed); its backbone may require the `timm` package, so keep
# `timm` in requirements.txt.
MODEL_NAME = "facebook/detr-resnet-101-dc5"
processor = DetrImageProcessor.from_pretrained(MODEL_NAME)
model = DetrForObjectDetection.from_pretrained(MODEL_NAME)

# Class-id -> class-name mapping used by this checkpoint. Target ids are
# looked up by name instead of hard-coding COCO numbers, because COCO ids
# differ between the contiguous 80-class numbering and the original gapped
# category ids.
id2label = model.config.id2label
target_labels = ["cat", "dog"]
target_ids = [label_id for label_id, label in id2label.items() if label in target_labels]

# Fail fast at start-up if the checkpoint has no class with one of the target
# names; otherwise the app would silently never draw a box for it.
_missing_labels = set(target_labels).difference(id2label.values())
if _missing_labels:
    raise ValueError(
        f"{MODEL_NAME} has no class named {sorted(_missing_labels)}; "
        "check model.config.id2label"
    )

# Box outline colour per class (unlisted classes fall back to green when drawn).
colors = {"cat": "red", "dog": "blue"}
def _text_extent(draw, text, font):
    """Return the (width, height) that *text* covers when drawn at the origin.

    Pillow 10 removed ``getsize`` from its font classes, so the modern
    ``ImageDraw.textbbox`` is tried first. Pillow < 8.0 has no ``textbbox``
    (AttributeError) and Pillow < 9.2 only accepts TrueType fonts there
    (ValueError); both of those versions still provide ``font.getsize``.
    """
    try:
        # Measured from (0, 0) so the extent includes any top/left bearing.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    except (AttributeError, ValueError):
        return font.getsize(text)


def detect_objects(image_input, threshold=0.7):
    """
    Detect cats and dogs in an image using DETR and draw labelled boxes.

    Args:
        image_input (PIL.Image.Image | numpy.ndarray | None): Input image.
            The Gradio input component uses type="pil"; anything that is not
            a PIL image is converted with Image.fromarray.
        threshold (float): Minimum confidence score passed to
            post_process_object_detection. Defaults to 0.7; lower values
            (e.g. 0.5) find more objects but also more false positives.

    Returns:
        PIL.Image.Image | None: An RGB copy of the input with a box and a
        "label: score" caption drawn for every detected cat/dog, or None if
        no image was given.
    """
    if image_input is None:
        return None

    source = image_input
    if not isinstance(source, Image.Image):
        source = Image.fromarray(source)
    # convert() always returns a new image, so the caller's input is never
    # drawn on. Forcing RGB also normalises RGBA/greyscale/palette uploads,
    # which a 3-channel image processor cannot normalise.
    image = source.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    draw = ImageDraw.Draw(image)
    try:
        # Use a default font or specify a path to a .ttf file if available in the Space
        font = ImageFont.load_default()
    except IOError:
        print("Default font not found. Using basic drawing without text.")
        font = None

    detections_found = False
    for score, label_id, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        label_id = label_id.item()
        if label_id not in target_ids:
            continue
        detections_found = True
        box = [round(coord, 2) for coord in box.tolist()]
        label = id2label[label_id]
        box_color = colors.get(label, "green")  # Default to green if label not in colors dict
        print(f"Detected {label} with confidence {round(score.item(), 3)} at {box}")

        draw.rectangle(box, outline=box_color, width=3)

        if font is not None:
            text = f"{label}: {score.item():.2f}"
            text_width, text_height = _text_extent(draw, text, font)
            # Filled caption background with 2 px padding around the text.
            text_bg_coords = [
                (box[0], box[1]),
                (box[0] + text_width + 4, box[1] + text_height + 4),
            ]
            draw.rectangle(text_bg_coords, fill=box_color)
            draw.text((box[0] + 2, box[1] + 2), text, fill="white", font=font)

    if not detections_found:
        print("No cats or dogs detected with the current threshold.")
    return image
# --- Gradio user interface ---------------------------------------------------
title = "Cat & Dog Detector (using DETR ResNet-101)"
description = (
    "Upload an image and the model will draw bounding boxes "
    "around detected cats and dogs. Uses the facebook/detr-resnet-101-dc5 model from Hugging Face."
)

# Example inputs listed under the interface, one row per example and one entry
# per input component. Remote URLs are used here; paths to images uploaded to
# the Space would work as well.
example_rows = [
    ["http://images.cocodataset.org/val2017/000000039769.jpg"],  # image with cats
    ["https://storage.googleapis.com/petbacker/images/blog/2017/dog-and-cat-cover.jpg"],  # image with a dog and a cat
]

# Both sides exchange PIL images with detect_objects.
input_component = gr.Image(type="pil", label="Upload Image")
output_component = gr.Image(type="pil", label="Output Image with Detections")

iface = gr.Interface(
    fn=detect_objects,
    inputs=input_component,
    outputs=output_component,
    title=title,
    description=description,
    examples=example_rows,
    allow_flagging="never",  # flagging disabled; change if flagging is wanted
)

# Start the web server.
iface.launch()