File size: 2,804 Bytes
e62d449
 
 
 
 
 
 
 
 
 
 
 
 
223ea38
 
 
 
 
 
 
e62d449
223ea38
 
e62d449
223ea38
 
 
e62d449
223ea38
 
 
 
e62d449
223ea38
 
 
 
 
e62d449
223ea38
 
 
 
e62d449
223ea38
e62d449
223ea38
 
 
 
 
e62d449
223ea38
 
e62d449
223ea38
 
e62d449
223ea38
 
 
 
e62d449
 
 
8bff2fc
66136c9
8c5514e
e62d449
8c5514e
e62d449
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io

# The image processor and the detection model are both loaded from the same
# pretrained checkpoint, so the checkpoint name is spelled out only once.
_CHECKPOINT = 'facebook/detr-resnet-101'
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)

def _placeholder_result(message):
    """Return the (blank image, message) pair shown when detection cannot run."""
    return Image.new("RGB", (224, 224), color="white"), message


def _to_rgb_image(image):
    """Return *image* as an RGB PIL image.

    Accepts a PIL image, an array-like object exposing ``__array_interface__``
    (e.g. a NumPy array), or a bytes-like object holding an encoded image file.
    """
    if isinstance(image, Image.Image):
        pil_image = image
    elif hasattr(image, "__array_interface__"):
        # An array holds raw pixels, not an encoded file, so it has to go
        # through fromarray(); Image.open() cannot identify such data.
        pil_image = Image.fromarray(image)
    else:
        pil_image = Image.open(io.BytesIO(image))
    # Normalise grayscale / RGBA / palette inputs so the processor always
    # receives three channels.
    return pil_image.convert("RGB")


def _render_detections(image, results):
    """Draw the detections in *results* over *image*.

    Returns a tuple ``(annotated PIL image, list of "label: score" strings)``.
    """
    # Use an explicit figure instead of pyplot's implicit "current figure" so
    # overlapping requests cannot draw on each other's plot.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(image)

        captions = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            xmin, ymin, xmax, ymax = [round(coord, 2) for coord in box.tolist()]
            width, height = xmax - xmin, ymax - ymin

            ax.add_patch(patches.Rectangle((xmin, ymin), width, height, fill=False, color='red', linewidth=3))
            text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
            captions.append(text)

        ax.axis('off')

        # Render the plot into an in-memory PNG.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf), captions


def object_detection(image, confidence_threshold):
    """Detect objects in *image* and return an annotated picture plus a summary.

    Args:
        image: The input picture: a PIL image, an array-like object exposing
            ``__array_interface__`` (e.g. a NumPy array), or a bytes-like
            object containing an encoded image file. ``None`` means that no
            image was supplied.
        confidence_threshold: Minimum score a detection needs in order to be
            kept; passed as ``threshold`` to the processor's post-processing.

    Returns:
        A tuple ``(result_image, detected_objects_text)``. ``result_image`` is
        a PIL image with one red box and one "label: score" caption per
        detection; ``detected_objects_text`` holds the same captions, one per
        line. If anything fails, a blank white 224x224 image and the error
        message are returned instead of raising.
    """
    if image is None:
        return _placeholder_result("No image provided.")

    try:
        image = _to_rgb_image(image)

        # Preprocess the image
        inputs = processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); post-processing wants (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=confidence_threshold
        )[0]

        result_image, detected_objects = _render_detections(image, results)

        # Join detected objects into a single string
        return result_image, "\n".join(detected_objects)

    except Exception as e:
        # UI boundary: report the failure in the text output instead of
        # letting the exception propagate to the caller.
        return _placeholder_result(str(e))

# Define the Gradio interface
demo = gr.Interface(
    fn=object_detection,
    inputs=[
        # Deliver the upload as a PIL image, the type object_detection works
        # with natively (Gradio's default delivery type is a NumPy array).
        gr.Image(type="pil", label="Upload an Image"),
        # Without an explicit value the slider starts at its minimum (0.0),
        # a threshold at which no detection is filtered out.
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Confidence Threshold"),
    ],
    outputs=[gr.Image(label="Detected Objects"), gr.Textbox(label="Detected Objects List")],
    title="Object Detection with DETR (ResNet-101)",
    description="Upload an image and get object detection results using the DETR model with a ResNet-101 backbone."
)

# Launch the Gradio interface only when run as a script, not on import
if __name__ == "__main__":
    demo.launch()