sshi committed
Commit 8004ebd · 1 Parent(s): af2b646

Add notebook file.

.gitattributes CHANGED
```diff
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ipynb filter=lfs diff=lfs merge=lfs -text
```
Fine-tuning YOLOS for traffic object detection.ipynb ADDED
```diff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6328705e1cce9fc89a243319dc8b57997f3791298f312d2c6ac078cc8034e32
+size 15511753
```
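The three `+` lines above are not the notebook itself but a Git LFS pointer: the `.gitattributes` rule added in this commit routes `*.ipynb` through LFS, so the repository stores only the SHA-256 and size of the real ~15.5 MB file. A minimal sketch of checking a downloaded file against such a pointer (the local path and helper name are hypothetical, not part of this commit):

```python
import hashlib
import os

def verify_lfs_pointer(pointer_text: str, local_path: str) -> bool:
    """Check that a local file matches the oid/size recorded in a Git LFS pointer."""
    # Pointer lines look like "oid sha256:<hex>" and "size <bytes>".
    fields = dict(line.split(" ", 1) for line in pointer_text.strip().splitlines())
    expected_oid = fields["oid"].removeprefix("sha256:")
    expected_size = int(fields["size"])
    if os.path.getsize(local_path) != expected_size:
        return False
    sha = hashlib.sha256()
    with open(local_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    return sha.hexdigest() == expected_oid
```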
README.md CHANGED
```diff
@@ -1,5 +1,5 @@
 ---
-title: YOLOS Traffic Users
+title: YOLOS Traffic Object detection
 emoji: 🔥
 colorFrom: gray
 colorTo: yellow
```
app.py CHANGED
```diff
@@ -3,21 +3,31 @@ import os
 import torch
 import pytorch_lightning as pl
 
-# torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
-# torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
-# torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
-
-# os.system("wget https://github.com/hustvl/YOLOP/raw/main/weights/End-to-end.pth")
-
-from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
-
-from PIL import Image, ImageDraw
 import cv2
 import numpy
-import matplotlib.pyplot as plt
+from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
+from PIL import Image
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small", size=512, max_size=864)
 
 id2label = {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}
 
+# colors for visualization
+colors = [
+    [  0, 113, 188,],
+    [216,  82,  24,],
+    [236, 176,  31,],
+    [192, 202,  25,],
+    [118, 171,  47,],
+    [ 76, 189, 237,],
+    [ 46, 125, 188,],
+    [125, 171, 141,],
+    [125,  76, 237,],
+    [  0,  82, 216,],
+    [189,  76,  47,]]
+
 class Detr(pl.LightningModule):
 
     def __init__(self, lr, weight_decay):
```
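Moving `feature_extractor` (and `device`) up to module scope means they exist before any inference call. The extractor converts a PIL image into the `pixel_values` tensor YOLOS consumes; a minimal usage sketch, not code from this commit:

```python
from PIL import Image
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained(
    "hustvl/yolos-small", size=512, max_size=864)

# One of the example images the Gradio interface below references.
image = Image.open("./imgs/example1.jpg")
encoding = feature_extractor(images=image, return_tensors="pt")
print(encoding["pixel_values"].shape)  # torch.Size([1, 3, H, W]), short side resized to 512
```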
```diff
@@ -71,10 +81,6 @@ class Detr(pl.LightningModule):
         return optimizer
 
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small", size=512, max_size=864)
-
 # Build model and load checkpoint
 checkpoint = './checkpoints/epoch=1-step=2184.ckpt'
 model_yolos = Detr.load_from_checkpoint(checkpoint, lr=2.5e-5, weight_decay=1e-4)
```
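`Detr.load_from_checkpoint` rebuilds the LightningModule by calling `__init__` with the given hyperparameters and then loading the fine-tuned weights. The body of `Detr` is outside this diff; the `AutoModelForObjectDetection` import suggests an `__init__` along the lines of the usual YOLOS fine-tuning setup, sketched here with everything beyond the signature being an assumption:

```python
class Detr(pl.LightningModule):

    def __init__(self, lr, weight_decay):
        super().__init__()
        # Hypothetical reconstruction -- the real __init__ is not shown in this diff.
        self.model = AutoModelForObjectDetection.from_pretrained(
            "hustvl/yolos-small",
            num_labels=len(id2label),
            ignore_mismatched_sizes=True,
        )
        self.lr = lr
        self.weight_decay = weight_decay
```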
```diff
@@ -82,19 +88,6 @@ model_yolos = Detr.load_from_checkpoint(checkpoint, lr=2.5e-5, weight_decay=1e-4
 model_yolos.to(device)
 model_yolos.eval()
 
-# colors for visualization
-colors = [
-    [  0, 113, 188,],
-    [216,  82,  24,],
-    [236, 176,  31,],
-    [192, 202,  25,],
-    [118, 171,  47,],
-    [ 76, 189, 237,],
-    [ 46, 125, 188,],
-    [125, 171, 141,],
-    [125,  76, 237,],
-    [  0,  82, 216,],
-    [189,  76,  47,]]
 
 # for output bounding box post-processing
 def box_cxcywh_to_xyxy(x):
```
```diff
@@ -103,12 +96,14 @@ def box_cxcywh_to_xyxy(x):
          (x_c + 0.5 * w), (y_c + 0.5 * h)]
     return torch.stack(b, dim=1)
 
+
 def rescale_bboxes(out_bbox, size):
     img_w, img_h = size
     b = box_cxcywh_to_xyxy(out_bbox)
     b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
     return b
 
+
 def plot_results(pil_img, prob, boxes):
 
     img = numpy.asarray(pil_img)
```
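These are the standard DETR box utilities: the model predicts normalized `(cx, cy, w, h)` boxes, which are converted to corner format and scaled to pixel coordinates. A worked example (the `unbind` line is reconstructed from the standard helper, since the hunk starts mid-function):

```python
import torch

def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)  # reconstructed first line; not visible in the hunk
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

# A box centred in a 1280x720 frame, covering 20% of its width and 40% of its height:
box = torch.tensor([[0.5, 0.5, 0.2, 0.4]])
print(rescale_bboxes(box, (1280, 720)))  # tensor([[512., 216., 768., 504.]])
```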
```diff
@@ -119,12 +114,8 @@ def plot_results(pil_img, prob, boxes):
         c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax))
 
         cv2.rectangle(img, c1, c2, c, thickness=2, lineType=cv2.LINE_AA)
-
-        cv2.putText(img, f'{id2label[cl.item()]}: {p[cl]:0.2f}', [int(xmin), int(ymin)], cv2.FONT_HERSHEY_SIMPLEX, 0.5, c, 1)
-        # ax.text(xmin, ymin, text, fontsize=10,
-        #         bbox=dict(facecolor=c, alpha=0.5))
+        cv2.putText(img, f'{id2label[cl.item()]}: {p[cl]:0.2f}', [int(xmin), int(ymin)-5], cv2.FONT_HERSHEY_SIMPLEX, 0.7, c, 2)
     return Image.fromarray(img)
-    # return fig
 
 
 def generate_preds(processor, model, image):
```
```diff
@@ -145,20 +136,21 @@ def visualize_preds(image, preds, threshold=0.9):
 
 
 def detect(img):
-
     # Run inference
     preds = generate_preds(feature_extractor, model_yolos, img)
-
     return visualize_preds(img, preds)
 
+
+description = "This is a traffic object detector based on <a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a>. \n" + \
+              "The model can detect the following targets: {1: 'person', 2: 'rider', 3: 'car', 4: 'bus', 5: 'truck', 6: 'bike', 7: 'motor', 8: 'traffic light', 9: 'traffic sign', 10: 'train'}."
+
 
 interface = gr.Interface(
     fn=detect,
     inputs=[gr.Image(type="pil")],
     outputs=gr.Image(type="pil"),
-    # outputs = ['plot'],
     examples=[["./imgs/example1.jpg"], ["./imgs/example2.jpg"]],
     title="YOLOS for traffic object detection",
-    description="A downstream application for <a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a> which can performe traffic object detection. ")
+    description=description)
 
 interface.launch()
```
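Only the signature `visualize_preds(image, preds, threshold=0.9)` appears in this diff. Given `rescale_bboxes` and `plot_results` above, it plausibly follows the standard DETR/YOLOS post-processing, where the threshold filters per-query class scores before drawing; the body below is entirely a hedged reconstruction:

```python
import torch

def visualize_preds(image, preds, threshold=0.9):
    # Hypothetical reconstruction -- the real body is not part of this diff.
    # Softmax over class logits, dropping the trailing "no object" class.
    probas = preds.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold
    boxes = rescale_bboxes(preds.pred_boxes[0, keep].cpu(), image.size)
    return plot_results(image, probas[keep], boxes)
```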
 