Commit 8004ebd
sshi committed · Parent: af2b646

Add notebook file.

Files changed:
- .gitattributes +1 -0
- Fine-tuning YOLOS for traffic object detection.ipynb +3 -0
- README.md +1 -1
- app.py +28 -36
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ipynb filter=lfs diff=lfs merge=lfs -text
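(This is exactly the line that `git lfs track "*.ipynb"` appends; without it, the ~15 MB notebook added in this commit would be stored as a regular Git blob rather than an LFS pointer.)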
Fine-tuning YOLOS for traffic object detection.ipynb ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6328705e1cce9fc89a243319dc8b57997f3791298f312d2c6ac078cc8034e32
+size 15511753
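For reference, the three added lines are not the notebook itself but a standard Git LFS pointer: the real ~15 MB file lives in LFS storage, keyed by its SHA-256 digest. A minimal sketch of reading such a pointer (the read_lfs_pointer helper is hypothetical, not part of this repo):

# Hypothetical helper: parse a Git LFS pointer file like the one committed above.
def read_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    # e.g. fields["oid"] -> "sha256:c63287...", fields["size"] -> "15511753"
    return fields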
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: YOLOS Traffic
+title: YOLOS Traffic Object detection
 emoji: 🔥
 colorFrom: gray
 colorTo: yellow
app.py CHANGED
@@ -3,21 +3,31 @@ import os
 import torch
 import pytorch_lightning as pl
 
-# torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
-# torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
-# torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
-
-# os.system("wget https://github.com/hustvl/YOLOP/raw/main/weights/End-to-end.pth")
-
-from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
-
-from PIL import Image, ImageDraw
 import cv2
 import numpy
-import
+from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
+from PIL import Image
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small", size=512, max_size=864)
 
 id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}
 
+# colors for visualization
+colors = [
+    [ 0, 113, 188,],
+    [216, 82, 24,],
+    [236, 176, 31,],
+    [192, 202, 25,],
+    [118, 171, 47,],
+    [ 76, 189, 237,],
+    [ 46, 125, 188,],
+    [125, 171, 141,],
+    [125, 76, 237,],
+    [ 0, 82, 216,],
+    [189, 76, 47,]]
+
 class Detr(pl.LightningModule):
 
     def __init__(self, lr, weight_decay):
@@ -71,10 +81,6 @@ class Detr(pl.LightningModule):
         return optimizer
 
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small", size=512, max_size=864)
-
 # Build model and load checkpoint
 checkpoint = './checkpoints/epoch=1-step=2184.ckpt'
 model_yolos = Detr.load_from_checkpoint(checkpoint, lr=2.5e-5, weight_decay=1e-4)
@@ -82,19 +88,6 @@ model_yolos = Detr.load_from_checkpoint(checkpoint, lr=2.5e-5, weight_decay=1e-4
 model_yolos.to(device)
 model_yolos.eval()
 
-# colors for visualization
-colors = [
-    [ 0, 113, 188,],
-    [216, 82, 24,],
-    [236, 176, 31,],
-    [192, 202, 25,],
-    [118, 171, 47,],
-    [ 76, 189, 237,],
-    [ 46, 125, 188,],
-    [125, 171, 141,],
-    [125, 76, 237,],
-    [ 0, 82, 216,],
-    [189, 76, 47,]]
 
 # for output bounding box post-processing
 def box_cxcywh_to_xyxy(x):
@@ -103,12 +96,14 @@ def box_cxcywh_to_xyxy(x):
          (x_c + 0.5 * w), (y_c + 0.5 * h)]
     return torch.stack(b, dim=1)
 
+
 def rescale_bboxes(out_bbox, size):
     img_w, img_h = size
     b = box_cxcywh_to_xyxy(out_bbox)
     b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
     return b
 
+
 def plot_results(pil_img, prob, boxes):
 
     img = numpy.asarray(pil_img)
@@ -119,12 +114,8 @@ def plot_results(pil_img, prob, boxes):
         c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax))
 
         cv2.rectangle(img, c1, c2, c, thickness=2, lineType=cv2.LINE_AA)
-
-        cv2.putText(img, f'{id2label[cl.item()]}: {p[cl]:0.2f}', [int(xmin), int(ymin)], cv2.FONT_HERSHEY_SIMPLEX, 0.5, c, 1)
-        # ax.text(xmin, ymin, text, fontsize=10,
-        #         bbox=dict(facecolor=c, alpha=0.5))
+        cv2.putText(img, f'{id2label[cl.item()]}: {p[cl]:0.2f}', [int(xmin), int(ymin)-5], cv2.FONT_HERSHEY_SIMPLEX, 0.7, c, 2)
     return Image.fromarray(img)
-    # return fig
 
 
 def generate_preds(processor, model, image):
@@ -145,20 +136,21 @@ def visualize_preds(image, preds, threshold=0.9):
 
 
 def detect(img):
-
     # Run inference
     preds = generate_preds(feature_extractor, model_yolos, img)
-
     return visualize_preds(img, preds)
 
+
+description = "This is a traffic object detector based on <a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a>. \n" + \
+              "The model can detect following targets: {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}."
+
 
 interface = gr.Interface(
     fn=detect,
     inputs=[gr.Image(type="pil")],
     outputs=gr.Image(type="pil"),
-    # outputs = ['plot'],
     examples=[["./imgs/example1.jpg"], ["./imgs/example2.jpg"]],
     title="YOLOS for traffic object detection",
-    description=
+    description=description)
 
 interface.launch()
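As a sanity check on the box post-processing kept by this refactor: YOLOS, like DETR, emits boxes as normalized (cx, cy, w, h), and rescale_bboxes maps them to absolute (xmin, ymin, xmax, ymax) pixel coordinates. A minimal standalone sketch; the bodies follow app.py, except the unbind line of box_cxcywh_to_xyxy, which the diff context does not show and is assumed here to match the standard DETR implementation:

import torch

def box_cxcywh_to_xyxy(x):
    # First line assumed from the standard DETR helper; the diff only shows this function's tail.
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    # Scale normalized corner boxes up to pixel coordinates.
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

# A box centered in a 640x480 image, covering 20% of its width and 40% of its height:
print(rescale_bboxes(torch.tensor([[0.5, 0.5, 0.2, 0.4]]), (640, 480)))
# tensor([[256., 144., 384., 336.]])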