kasper-boy commited on
Commit
8cb992f
1 Parent(s): 264a1ba

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.patches as patches
7
+ import io
8
+
9
+ # Load the processor and model
10
+ processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-101')
11
+ model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')
12
+
13
+ def object_detection(image):
14
+ # Preprocess the image
15
+ inputs = processor(images=image, return_tensors="pt")
16
+
17
+ # Perform object detection
18
+ outputs = model(**inputs)
19
+
20
+ # Extract bounding boxes and labels
21
+ target_sizes = torch.tensor([image.size[::-1]])
22
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
23
+
24
+ # Plot the image with bounding boxes
25
+ plt.figure(figsize=(16, 10))
26
+ plt.imshow(image)
27
+ ax = plt.gca()
28
+
29
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
30
+ box = [round(i, 2) for i in box.tolist()]
31
+ xmin, ymin, xmax, ymax = box
32
+ width, height = xmax - xmin, ymax - ymin
33
+
34
+ ax.add_patch(plt.Rectangle((xmin, ymin), width, height, fill=False, color='red', linewidth=3))
35
+ text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
36
+ ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
37
+
38
+ plt.axis('off')
39
+
40
+ # Save the plot to an image buffer
41
+ buf = io.BytesIO()
42
+ plt.savefig(buf, format='png')
43
+ buf.seek(0)
44
+ plt.close()
45
+
46
+ return buf
47
+
48
+ # Define the Gradio interface
49
+ demo = gr.Interface(
50
+ fn=object_detection,
51
+ inputs=gr.Image(type="pil", label="Upload an Image"),
52
+ outputs=gr.Image(type="pil", label="Detected Objects"),
53
+ title="Object Detection with DETR (ResNet-101)",
54
+ description="Upload an image and get object detection results using the DETR model with a ResNet-101 backbone.",
55
+ )
56
+
57
+ # Launch the Gradio interface
58
+ if __name__ == "__main__":
59
+ demo.launch()