File size: 5,235 Bytes
c37ceb0
 
 
 
 
224b9f6
88c51e4
8004ebd
 
cbbced1
8004ebd
 
 
 
c37ceb0
c473504
c37ceb0
8004ebd
 
 
 
 
063786d
8004ebd
 
063786d
8004ebd
 
 
 
 
c37ceb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec8f07
d78a0b4
c37ceb0
d78a0b4
 
c37ceb0
d252b7e
c37ceb0
 
 
 
 
 
 
8004ebd
c37ceb0
 
 
 
 
 
8004ebd
c37ceb0
224b9f6
88c51e4
224b9f6
c37ceb0
 
ff9db0f
5eaf2b4
224b9f6
6d9e76a
8004ebd
6d9e76a
c37ceb0
 
 
 
3efc2d7
c37ceb0
 
 
 
 
 
 
 
 
 
 
 
 
 
d78a0b4
c37ceb0
d78a0b4
60390c6
c37ceb0
8004ebd
60390c6
b2cad71
8004ebd
c37ceb0
 
 
 
224b9f6
b2cad71
c37ceb0
8004ebd
c37ceb0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gradio as gr
import os
import torch
import pytorch_lightning as pl

import cv2
import numpy
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
from PIL import Image
import streamlit as st

# Run on the GPU whenever torch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pre-processing pipeline published with the YOLOS-small checkpoint,
# configured with size=512 and max_size=864.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "hustvl/yolos-small",
    size=512,
    max_size=864,
)

# Class id -> display name for the ten traffic categories (ids start at 1).
id2label = dict(enumerate(
    [
        'person', 'rider', 'car', 'bus', 'truck',
        'bike', 'motor', 'traffic light', 'traffic sign', 'train',
    ],
    start=1,
))

# Drawing colours, one three-channel entry per row; plot_results picks a row
# with the predicted class index.
colors = [
    [0, 113, 188],
    [216, 82, 24],
    [236, 176, 31],
    [255, 255, 0],
    [118, 171, 47],
    [76, 189, 237],
    [46, 155, 188],
    [125, 171, 141],
    [125, 76, 237],
    [0, 82, 216],
    [189, 76, 47],
]

class Detr(pl.LightningModule):
    """Lightning module wrapping the ``hustvl/yolos-small`` object detector.

    The pretrained weights are loaded with a classification head sized to
    ``len(id2label)``; training and validation steps log the total loss and
    every entry of the model's ``loss_dict``.
    """

    def __init__(self, lr, weight_decay):
        """Load the pretrained detector and store the optimizer settings.

        Args:
            lr: Learning rate passed to AdamW in ``configure_optimizers``.
            weight_decay: Weight decay passed to AdamW in
                ``configure_optimizers``.
        """
        super().__init__()
        # Swap the COCO classification head for one sized to our label set;
        # ignore_mismatched_sizes allows the head shape to differ from the
        # pretrained checkpoint.
        self.model = AutoModelForObjectDetection.from_pretrained(
            "hustvl/yolos-small",
            num_labels=len(id2label),
            ignore_mismatched_sizes=True,
        )
        # see https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, pixel_values):
        """Return the wrapped model's outputs for ``pixel_values``."""
        return self.model(pixel_values=pixel_values)

    def common_step(self, batch, batch_idx):
        """Run the model with labels and return ``(loss, loss_dict)``.

        Args:
            batch: Mapping with ``"pixel_values"`` and ``"labels"``; each
                label is a dict whose values are moved to this module's
                device.
            batch_idx: Index of the batch (unused here).
        """
        pixel_values = batch["pixel_values"]
        targets = []
        for target in batch["labels"]:
            targets.append(
                {name: value.to(self.device) for name, value in target.items()}
            )

        result = self.model(pixel_values=pixel_values, labels=targets)
        return result.loss, result.loss_dict

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it with its components, return it."""
        total, components = self.common_step(batch, batch_idx)
        self.log("training_loss", total)
        for name in components:
            self.log("train_" + name, components[name].item())
        return total

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss, log it with its components, return it."""
        total, components = self.common_step(batch, batch_idx)
        self.log("validation_loss", total)
        for name in components:
            self.log("validation_" + name, components[name].item())
        return total

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of this module."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )


# Restore the detector from its Lightning checkpoint, then move it to the
# selected device and switch it to evaluation mode.
checkpoint = './checkpoints/epoch=1-step=2184.ckpt'
model_yolos = Detr.load_from_checkpoint(
    checkpoint,
    lr=2.5e-5,
    weight_decay=1e-4,
)

# Both calls act on the module in place and hand it back, so they can be chained.
model_yolos.to(device).eval()


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (xmin, ymin, xmax, ymax).

    Args:
        x: Tensor whose dimension 1 holds the four centre-format values.

    Returns:
        Tensor of the same layout holding the corner-format values.
    """
    center_x, center_y, width, height = x.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=1)


def rescale_bboxes(out_bbox, size):
    """Convert relative centre-format boxes to absolute corner coordinates.

    Args:
        out_bbox: Tensor of (cx, cy, w, h) rows, expressed relative to the
            image size.
        size: ``(width, height)`` of the target image.

    Returns:
        Tensor of (xmin, ymin, xmax, ymax) rows scaled by the image width
        and height.
    """
    img_w, img_h = size
    corners = box_cxcywh_to_xyxy(out_bbox)
    # Create the scale on the boxes' own device so the multiply also works
    # when the caller has not moved the predictions to the CPU first.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h],
        dtype=torch.float32,
        device=corners.device,
    )
    return corners * scale


def plot_results(pil_img, prob, boxes):
    """Draw labelled bounding boxes on a copy of an image.

    Args:
        pil_img: PIL image to annotate; it is not modified.
        prob: Per-detection class scores, one row per box; the highest
            scoring entry of each row selects the class.
        boxes: Tensor of (xmin, ymin, xmax, ymax) rows in pixel coordinates,
            aligned with ``prob``.

    Returns:
        A new PIL image with one rectangle and one ``"label: score"`` caption
        per detection.
    """
    # numpy.array always yields a private, writable buffer; numpy.asarray on a
    # PIL image may hand back a read-only array that OpenCV refuses to draw on.
    img = numpy.array(pil_img)

    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
        cl = int(p.argmax())
        c = colors[cl]
        c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax))
        # id2label has no entry for every index the model can emit (e.g. 0),
        # so fall back to the raw index instead of raising KeyError.
        label = id2label.get(cl, str(cl))

        cv2.rectangle(img, c1, c2, c, thickness=2, lineType=cv2.LINE_AA)
        # The caption sits 5 px above the box's top-left corner; the origin is
        # passed as a tuple, which every OpenCV binding accepts as a point.
        cv2.putText(img, f'{label}: {float(p[cl]):0.2f}', (int(xmin), int(ymin) - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, c, 2)
    return Image.fromarray(img)


def generate_preds(processor, model, image):
    """Run the detector on one image and return the model's raw outputs.

    Args:
        processor: Callable that turns ``images=image`` into PyTorch inputs
            exposing ``pixel_values``.
        model: Callable detector accepting ``pixel_values``.
        image: The image to analyse.

    Returns:
        Whatever ``model`` returns for the pre-processed image.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    # Inference only: disable autograd so no computation graph is built or
    # retained for the forward pass.
    with torch.no_grad():
        preds = model(pixel_values=inputs.pixel_values)
    return preds


def visualize_preds(image, preds, threshold=0.9):
    """Draw every sufficiently confident detection onto ``image``.

    Args:
        image: PIL image the predictions were computed for.
        preds: Model outputs exposing ``logits`` and ``pred_boxes``.
        threshold: A detection is kept only when its best class score is
            strictly greater than this value.

    Returns:
        The annotated image produced by ``plot_results``.
    """
    # Class scores of the first batch item; the last column is dropped
    # (presumably the "no object" class - confirm against the model docs).
    class_scores = preds.logits.softmax(-1)[0, :, :-1]
    confident = class_scores.max(-1).values > threshold

    # Boxes are predicted relative to the image ([0; 1]); bring the kept ones
    # to the CPU and scale them to pixel coordinates.
    kept_boxes = preds.pred_boxes[0, confident].cpu()
    pixel_boxes = rescale_bboxes(kept_boxes, image.size)

    return plot_results(image, class_scores[confident], pixel_boxes)


def detect(img):
    """Run the YOLOS detector on ``img`` and return the annotated image."""
    return visualize_preds(
        img,
        generate_preds(feature_extractor, model_yolos, img),
    )


# Text shown in the demo UI: a short welcome followed by the list of
# detectable classes.
description = (
    "Welcome to this space! 🤗this is a traffic object detector based on "
    "<a href='https://huggingface.co/docs/transformers/model_doc/yolos' style='text-decoration: underline' target='_blank'>YOLOS</a>. \n\n"
    "The model can detect following targets: 🚶‍♂️person, 🚴‍♀️rider, 🚗car, 🚌bus, 🚚truck, 🚲bike, 🏍️motor, 🚦traffic light, ⛔traffic sign, 🚄train."
)

   
# Gradio front end: one PIL image in, the annotated PIL image out, with three
# bundled example images.
interface = gr.Interface(
    fn=detect,
    inputs=[gr.Image(type="pil")],
    outputs=gr.Image(type="pil"),
    examples=[
        [path]
        for path in (
            "./imgs/example1.jpg",
            "./imgs/example2.jpg",
            "./imgs/example3.png",
        )
    ],
    title="YOLOS for traffic object detection",
    description=description,
)

# Start serving the demo as soon as the script runs.
interface.launch()