camilleseab commited on
Commit
d544c8f
1 Parent(s): e785fa3

Tweak image labels

Browse files
Files changed (1) hide show
  1. notebooks/app.ipynb +34 -18
notebooks/app.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 179,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
@@ -43,7 +43,7 @@
43
  },
44
  {
45
  "cell_type": "code",
46
- "execution_count": 180,
47
  "metadata": {},
48
  "outputs": [],
49
  "source": [
@@ -68,7 +68,7 @@
68
  },
69
  {
70
  "cell_type": "code",
71
- "execution_count": 181,
72
  "metadata": {},
73
  "outputs": [],
74
  "source": [
@@ -99,15 +99,17 @@
99
  " grd = widgets.VBox([widgets.Label(label), img_widget])\n",
100
  " return grd\n",
101
  "\n",
102
- "def label_img(img, model) -> Image:\n",
103
- " pred = model.predict(img, device = 'cpu')[0].plot(labels = False)\n",
104
- " pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)\n",
105
- " return Image.fromarray(pred)"
 
 
106
  ]
107
  },
108
  {
109
  "cell_type": "code",
110
- "execution_count": 182,
111
  "metadata": {},
112
  "outputs": [
113
  {
@@ -115,11 +117,11 @@
115
  "output_type": "stream",
116
  "text": [
117
  "\n",
118
- "0: 640x640 1 surveillance, 262.8ms\n",
119
- "Speed: 1.9ms preprocess, 262.8ms inference, 1.1ms postprocess per image at shape (1, 3, 640, 640)\n",
120
  "\n",
121
- "0: 640x640 3 surveillances, 988.3ms\n",
122
- "Speed: 1.8ms preprocess, 988.3ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 640)\n"
123
  ]
124
  }
125
  ],
@@ -166,8 +168,13 @@
166
  "grid[0, :] = widgets.VBox([location, size, heading, pitch, fov, button],\n",
167
  " layout = widgets.Layout(height = 'auto'))\n",
168
  "\n",
 
 
169
  "\n",
170
  "def button_click(b):\n",
 
 
 
171
  " img = get_sv_img(location=location.value,\n",
172
  " size=size.value,\n",
173
  " heading=heading.value,\n",
@@ -175,10 +182,14 @@
175
  " fov=fov.value)\n",
176
  " if img is not None:\n",
177
  " grid[1:, 0] = make_img_widget(img, 'Original image')\n",
178
- " yolo_pred = label_img(img, yolo)\n",
179
- " grid[1:, 1] = make_img_widget(yolo_pred, 'YOLO predictions')\n",
180
- " detr_pred = label_img(img, detr)\n",
181
- " grid[1:, 2] = make_img_widget(detr_pred, 'RT-DETR predictions')\n",
 
 
 
 
182
  " \n",
183
  "\n",
184
  "button.on_click(button_click)\n"
@@ -186,13 +197,13 @@
186
  },
187
  {
188
  "cell_type": "code",
189
- "execution_count": 184,
190
  "metadata": {},
191
  "outputs": [
192
  {
193
  "data": {
194
  "application/vnd.jupyter.widget-view+json": {
195
- "model_id": "9020055a57294d38b6f3e40659b12a1d",
196
  "version_major": 2,
197
  "version_minor": 0
198
  },
@@ -207,6 +218,11 @@
207
  "source": [
208
  "display(grid)"
209
  ]
 
 
 
 
 
210
  }
211
  ],
212
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
 
43
  },
44
  {
45
  "cell_type": "code",
46
+ "execution_count": 2,
47
  "metadata": {},
48
  "outputs": [],
49
  "source": [
 
68
  },
69
  {
70
  "cell_type": "code",
71
+ "execution_count": 35,
72
  "metadata": {},
73
  "outputs": [],
74
  "source": [
 
99
  " grd = widgets.VBox([widgets.Label(label), img_widget])\n",
100
  " return grd\n",
101
  "\n",
102
+ "def label_img(img, model) -> (Image, int):\n",
103
+ " pred = model.predict(img, device = 'cpu')[0]\n",
104
+ " n = pred.boxes.data.shape[0]\n",
105
+ " plot = pred.plot(labels = False)\n",
106
+ " plot = cv2.cvtColor(plot, cv2.COLOR_BGR2RGB)\n",
107
+ " return Image.fromarray(plot), n"
108
  ]
109
  },
110
  {
111
  "cell_type": "code",
112
+ "execution_count": 38,
113
  "metadata": {},
114
  "outputs": [
115
  {
 
117
  "output_type": "stream",
118
  "text": [
119
  "\n",
120
+ "0: 640x640 2 surveillances, 430.8ms\n",
121
+ "Speed: 19.0ms preprocess, 430.8ms inference, 1.3ms postprocess per image at shape (1, 3, 640, 640)\n",
122
  "\n",
123
+ "0: 640x640 2 surveillances, 1938.6ms\n",
124
+ "Speed: 2.5ms preprocess, 1938.6ms inference, 1.1ms postprocess per image at shape (1, 3, 640, 640)\n"
125
  ]
126
  }
127
  ],
 
168
  "grid[0, :] = widgets.VBox([location, size, heading, pitch, fov, button],\n",
169
  " layout = widgets.Layout(height = 'auto'))\n",
170
  "\n",
171
+ "def singular(x):\n",
172
+ " return '' if x == 1 else 's'\n",
173
  "\n",
174
  "def button_click(b):\n",
175
+ " for i in range(3):\n",
176
+ " grid[1:, i] = widgets.Label('Loading...')\n",
177
+ " \n",
178
  " img = get_sv_img(location=location.value,\n",
179
  " size=size.value,\n",
180
  " heading=heading.value,\n",
 
182
  " fov=fov.value)\n",
183
  " if img is not None:\n",
184
  " grid[1:, 0] = make_img_widget(img, 'Original image')\n",
185
+ " \n",
186
+ " yolo_pred, yolo_n = label_img(img, yolo)\n",
187
+ " yolo_suffix = singular(yolo_n)\n",
188
+ " grid[1:, 1] = make_img_widget(yolo_pred, f'YOLO predictions ({(yolo_n)} result{yolo_suffix})')\n",
189
+ " \n",
190
+ " detr_pred, detr_n = label_img(img, detr)\n",
191
+ " detr_suffix = singular(detr_n)\n",
192
+ " grid[1:, 2] = make_img_widget(detr_pred, f'RT-DETR predictions ({detr_n} result{detr_suffix})')\n",
193
  " \n",
194
  "\n",
195
  "button.on_click(button_click)\n"
 
197
  },
198
  {
199
  "cell_type": "code",
200
+ "execution_count": 37,
201
  "metadata": {},
202
  "outputs": [
203
  {
204
  "data": {
205
  "application/vnd.jupyter.widget-view+json": {
206
+ "model_id": "f3343b2adf0f49dea6c1830d0cdbc9bb",
207
  "version_major": 2,
208
  "version_minor": 0
209
  },
 
218
  "source": [
219
  "display(grid)"
220
  ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "metadata": {},
225
+ "source": []
226
  }
227
  ],
228
  "metadata": {