kasper-boy's picture
Rename mainapp.py to app.py
790a64b verified
raw
history blame
No virus
2.08 kB
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
# Load the DETR checkpoint a single time at import so that every call to
# object_detection() reuses the same processor and model objects.
_DETR_CHECKPOINT = 'facebook/detr-resnet-101'
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
def object_detection(image):
    """Detect objects in ``image`` with DETR and return an annotated rendering.

    Args:
        image: PIL image to analyse (the Gradio input component is declared
            with ``type="pil"``).

    Returns:
        A PIL image of the input drawn on a 16x10 inch figure, with a red
        rectangle and a ``"<label>: <score>"`` caption for every detection
        kept by the 0.9 confidence threshold.

    Raises:
        gr.Error: if ``image`` is ``None`` (nothing was uploaded).
    """
    # Local import: the object-oriented Figure API avoids pyplot's global
    # figure registry, which leaks figures if a request fails mid-render.
    from matplotlib.figure import Figure

    if image is None:
        raise gr.Error("Please upload an image before running detection.")

    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt")

    # Perform object detection; gradients are never needed for inference, so
    # skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); reverse it to (height, width) so the boxes
    # are rescaled to the original image's pixel coordinates.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Plot the image with bounding boxes on a figure private to this call
    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    ax.imshow(image)
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        xmin, ymin, xmax, ymax = (round(coord, 2) for coord in box.tolist())
        width, height = xmax - xmin, ymax - ymin
        ax.add_patch(
            patches.Rectangle((xmin, ymin), width, height, fill=False, color='red', linewidth=3)
        )
        text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')

    # Save the plot to an in-memory PNG and hand it back as a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)
# Describe the UI: a single uploaded PIL image goes in, and the annotated
# PIL image produced by object_detection() comes out.
_image_input = gr.Image(type="pil", label="Upload an Image")
_image_output = gr.Image(type="pil", label="Detected Objects")

demo = gr.Interface(
    fn=object_detection,
    inputs=_image_input,
    outputs=_image_output,
    title="Object Detection with DETR (ResNet-101)",
    description="Upload an image and get object detection results using the DETR model with a ResNet-101 backbone.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()