Darius Morawiec committed
Commit 9401db3 · 1 Parent(s): 15693ed

Refactor object detection logic and update UI components for improved usability

Files changed (1)
  1. app.py +45 -99
app.py CHANGED
@@ -2,9 +2,9 @@ import gradio as gr
 import PIL.Image
 import torch
 from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
+from transformers.image_utils import load_image
 
-# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-DEVICE = "cpu"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 class Detector:
@@ -67,116 +67,62 @@ def _postprocess(detections):
 
 def detect_objects(image, labels, confidence_threshold):
     labels = [label.strip() for label in labels.split(",")]
-    return (
-        (
-            image,
-            _postprocess(
-                models["tiny"].detect(
-                    image,
-                    labels,
-                    threshold=confidence_threshold,
-                )
-            ),
-        ),
-        (
-            image,
-            _postprocess(
-                models["base"].detect(
-                    image,
-                    labels,
-                    threshold=confidence_threshold,
-                )
-            ),
-        ),
-        (
+
+    detections = []
+    for model_name in models.keys():
+        detection = models[model_name].detect(
             image,
-            _postprocess(
-                models["large"].detect(
-                    image,
-                    labels,
-                    threshold=confidence_threshold,
-                )
-            ),
-        ),
-    )
+            labels,
+            threshold=confidence_threshold,
+        )
+        detections.append(_postprocess(detection))
+
+    return tuple((image, det) for det in detections)
 
 
 with gr.Blocks() as demo:
-    gr.Markdown("# LLMDet Open Vocabulary Object Detection")
-
-    confidence_slider = gr.Slider(
-        0,
-        1,
-        value=0.4,
-        step=0.01,
-        interactive=True,
-        label="Confidence threshold",
-    )
+    gr.Markdown("# [LLMDet](https://arxiv.org/abs/2501.18954) Arena ")
 
-    labels = [
-        "backpack",
-        "bag",
-        "belt",
-        "blouse",
-        "boot",
-        "bracelet",
-        "cap",
-        "cardigan",
-        "coat",
-        "dress",
-        "earring",
-        "flipflop",
-        "glasses",
-        "glove",
-        "handbag",
-        "hat",
-        "heels",
-        "jacket",
-        "jeans",
-        "loafer",
-        "necklace",
-        "pullover",
-        "raincoat",
-        "ring",
-        "sandal",
-        "scarf",
-        "shirt",
-        "shoe",
-        "shorts",
-        "skirt",
-        "slippers",
-        "sneaker",
-        "socks",
-        "suitcase",
-        "sunglasses",
-        "sweater",
-        "tshirt",
-        "tie",
-        "top",
-        "trouser",
-        "umbrella",
-        "vest",
-        "watch",
-    ]
-
-    # Requested labels
-    text_input = gr.Textbox(
-        label="Object labels (comma separated)!",
-        placeholder="shirt, jeans, shoe",
-        lines=1,
-        value=",".join(labels),
-    )
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("## Input Image")
+
+            image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = load_image(image_url)
+            image_input = gr.Image(type="pil", image_mode="RGB", value=image)
+
+        with gr.Column():
+            gr.Markdown("## Settings")
+
+            confidence_slider = gr.Slider(
+                0,
+                1,
+                value=0.4,
+                step=0.01,
+                interactive=True,
+                label="Confidence threshold:",
+            )
+
+            labels = ["a cat", "a remote control"]
+
+            text_input = gr.Textbox(
+                label="Object labels (comma separated):",
+                placeholder=",".join(labels),
+                lines=1,
+                value=",".join(labels),
+            )
+
+    with gr.Row():
+        detect_button = gr.Button("Run Object Detection")
 
     with gr.Row():
-        image_input = gr.Image(type="pil", image_mode="RGB")
+        gr.Markdown("## Output Annotated Images")
 
     with gr.Row():
         output_annotated_image_tiny = gr.AnnotatedImage(label="TINY")
         output_annotated_image_base = gr.AnnotatedImage(label="BASE")
         output_annotated_image_large = gr.AnnotatedImage(label="LARGE")
 
-    detect_button = gr.Button("Detect")
-
     # Connect the button to the detection function
     detect_button.click(
         fn=detect_objects,
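To see what the refactored detect_objects now hands to the three gr.AnnotatedImage outputs, here is a minimal, self-contained sketch. The Detector internals and the models registry are not part of this diff (only class Detector: and the call sites are visible), so they are stubbed out below as assumptions; the detect_objects body is copied from the new app.py, and the ((x1, y1, x2, y2), label) annotation format is simply what gr.AnnotatedImage accepts for bounding boxes.

```python
# Minimal sketch, not the repository's code: the Detector class and the real
# LLMDet checkpoints are not shown in this diff, so they are replaced with a
# stub that only mimics the data shapes flowing into gr.AnnotatedImage.
import PIL.Image


class StubDetector:
    """Stand-in for Detector: returns one fake box per requested label."""

    def detect(self, image, labels, threshold=0.4):
        w, h = image.size
        return [
            {"box": (10 * i, 10 * i, w // 2, h // 2), "label": label}
            for i, label in enumerate(labels)
        ]


def _postprocess(detections):
    # gr.AnnotatedImage accepts annotations as ((x1, y1, x2, y2), label) pairs.
    return [(d["box"], d["label"]) for d in detections]


# Assumed registry; the diff only shows the keys "tiny", "base", "large".
models = {"tiny": StubDetector(), "base": StubDetector(), "large": StubDetector()}


def detect_objects(image, labels, confidence_threshold):
    # Body copied from the refactored app.py above.
    labels = [label.strip() for label in labels.split(",")]

    detections = []
    for model_name in models.keys():
        detection = models[model_name].detect(
            image,
            labels,
            threshold=confidence_threshold,
        )
        detections.append(_postprocess(detection))

    return tuple((image, det) for det in detections)


if __name__ == "__main__":
    img = PIL.Image.new("RGB", (640, 480))
    tiny, base, large = detect_objects(img, "a cat, a remote control", 0.4)
    print(tiny[1])
    # [((0, 0, 320, 240), 'a cat'), ((10, 10, 320, 240), 'a remote control')]
```

Since models is a plain dict, the order of the returned tuple follows insertion order (tiny, base, large); this has to match the order of the outputs= list passed to detect_button.click(...), which is cut off at the end of the hunk above.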