pyedward commited on
Commit
6e95c86
·
verified ·
1 Parent(s): 7ce333b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +67 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ import requests
5
+ import random
6
+ from IPython.display import display
7
+ import gradio as gr
8
+
9
+ # you can specify the revision tag if you don't want the timm dependency
10
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
11
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
12
+
13
+ def draw_detections(image, outputs, processor, model, threshold=0.9):
14
+ """
15
+ Draw bounding boxes and labels on an image using detection results.
16
+
17
+ Args:
18
+ image (PIL.Image): Input image.
19
+ outputs (dict): Model output.
20
+ processor: The processor used for post-processing.
21
+ model: The object detection model.
22
+ threshold (float): Confidence threshold.
23
+
24
+ Returns:
25
+ PIL.Image: The image with bounding boxes drawn.
26
+ """
27
+ target_sizes = torch.tensor([image.size[::-1]])
28
+ results = processor.post_process_object_detection(
29
+ outputs, target_sizes=target_sizes, threshold=threshold
30
+ )[0]
31
+
32
+ draw_image = image.copy()
33
+ draw = ImageDraw.Draw(draw_image, "RGBA")
34
+
35
+ # define fixed colors per label for consistency
36
+ COLORS = {}
37
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
38
+ box = [round(i, 2) for i in box.tolist()]
39
+ label_name = model.config.id2label[label.item()]
40
+
41
+ # assign consistent random color for each label type
42
+ if label_name not in COLORS:
43
+ COLORS[label_name] = tuple(random.choices(range(256), k=3))
44
+ color = COLORS[label_name]
45
+
46
+ # draw translucent box
47
+ draw.rectangle(box, fill=color + (80,), outline=color, width=3)
48
+ draw.text((box[0] + 3, box[1] + 3),
49
+ f"{label_name} {round(score.item(), 2)}",
50
+ fill=(255, 255, 255, 255))
51
+
52
+ return draw_image
53
+
54
+
55
+ def detect_and_draw(img):
56
+ inputs = processor(images=img, return_tensors="pt")
57
+ outputs = model(**inputs)
58
+ return draw_detections(img, outputs, processor, model)
59
+
60
+ demo = gr.Interface(
61
+ fn=detect_and_draw,
62
+ inputs=gr.Image(type="pil"),
63
+ outputs="image",
64
+ title="Object Detection Viewer"
65
+ )
66
+
67
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ Pillow
4
+ requests
5
+ ipython
6
+ gradio