import gradio as gr
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import requests  # Only needed if you fetch images from URLs yourself; Gradio downloads the example URLs below on its own

# Load the model and processor.
# The default DETR checkpoints rely on a timm backbone, so make sure timm is listed in
# requirements.txt (some DETR checkpoints also publish a revision="no_timm" variant,
# but we stay on the default revision here).
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101-dc5")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101-dc5")
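# Inference-only setup: switch off dropout and other train-time behaviour. Moving the
# model to GPU with .to("cuda") is an optional extra (an assumption about the hardware,
# not something this Space requires); if you do, also move the processor outputs to the
# same device inside detect_objects.
model.eval()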

# Restrict detections to cats and dogs. Hard-coded COCO class IDs differ between label
# maps, so we look them up by name in the model's own id2label mapping (for COCO-trained
# DETR checkpoints this typically resolves to 17 for "cat" and 18 for "dog").
id2label = model.config.id2label
target_labels = ["cat", "dog"]
target_ids = [label_id for label_id, label in id2label.items() if label in target_labels]

# Colors for bounding boxes (simple example)
colors = {"cat": "red", "dog": "blue"}

def detect_objects(image_input):
    """
    Detects cats and dogs in the input image using DETR.

    Args:
        image_input (PIL.Image.Image): Input image.

    Returns:
        PIL.Image.Image: Image with bounding boxes drawn around detected cats/dogs.
    """
    if image_input is None:
        return None

    # Convert Gradio input (if numpy) to PIL Image, although type="pil" should handle this
    if not isinstance(image_input, Image.Image):
         image = Image.fromarray(image_input)
    else:
         image = image_input.copy() # Work on a copy

    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt")

    # Perform inference (no_grad skips the autograd bookkeeping we don't need)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the results
    # Convert outputs (bounding boxes and class logits) to COCO API format
    target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0] # Lower threshold (e.g., 0.5) might find more objects

    # Draw bounding boxes for cats and dogs
    draw = ImageDraw.Draw(image)
    try:
        # Use a default font or specify a path to a .ttf file if available in the Space
        font = ImageFont.load_default()
    except IOError:
        print("Default font not found. Using basic drawing without text.")
        font = None

    detections_found = False
    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_id = label_id.item()
        if label_id in target_ids:
            detections_found = True
            box = [round(i, 2) for i in box.tolist()]
            label = id2label[label_id]
            box_color = colors.get(label, "green") # Default to green if label not in colors dict

            print(f"Detected {label} with confidence {round(score.item(), 3)} at {box}")

            # Draw rectangle
            draw.rectangle(box, outline=box_color, width=3)

            # Draw the label text on a filled background for readability
            if font:
                text = f"{label}: {score.item():.2f}"
                # Pillow >= 8 provides ImageDraw.textbbox; font.getsize was removed in Pillow 10
                if hasattr(draw, "textbbox"):
                    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
                    text_width, text_height = right - left, bottom - top
                else:
                    text_width, text_height = font.getsize(text)
                text_bg_coords = [(box[0], box[1]), (box[0] + text_width + 4, box[1] + text_height + 4)]
                draw.rectangle(text_bg_coords, fill=box_color)
                draw.text((box[0] + 2, box[1] + 2), text, fill="white", font=font)

    if not detections_found:
        print("No cats or dogs detected with the current threshold.")
        # Optionally add text to the image saying nothing was found
        # draw.text((10, 10), "No cats or dogs detected", fill="black", font=font)


    return image
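
# Optional local sanity check (a sketch, left commented out so it does not run on the
# Space). It reuses the COCO example URL from the examples list below together with the
# requests import above:
#
#   test_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#   test_image = Image.open(requests.get(test_url, stream=True).raw)
#   detect_objects(test_image).save("detections.png")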

# Create the Gradio interface
title = "Cat & Dog Detector (using DETR ResNet-101)"
description = ("Upload an image and the model will draw bounding boxes "
               "around detected cats and dogs. Uses the facebook/detr-resnet-101-dc5 model from Hugging Face.")

iface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Output Image with Detections"),
    title=title,
    description=description,
    examples=[
        # You can add paths to example images if you upload them to your space
        # Or provide URLs
         ["http://images.cocodataset.org/val2017/000000039769.jpg"], # Example image URL with cats
         ["https://storage.googleapis.com/petbacker/images/blog/2017/dog-and-cat-cover.jpg"] # Example image with dog and cat
    ],
    allow_flagging="never"  # disable flagging; note this argument is renamed to flagging_mode in newer Gradio releases
)

# Launch the app
iface.launch()
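
# A minimal requirements.txt for this Space might look like the following (package names
# only; version pins are an assumption left to the reader, not enforced by this file):
#
#   gradio
#   torch
#   transformers
#   timm
#   Pillow
#   requests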