File size: 4,422 Bytes
a2e1978
41aa81f
a2e1978
 
 
 
e830065
 
 
 
3a1d843
e830065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3794adc
9927f5d
e830065
 
 
 
 
 
 
 
 
 
 
de27acf
e830065
 
 
 
 
de27acf
 
e830065
 
 
 
 
 
 
a2e1978
233a435
2493b3c
 
 
3794adc
7cb214c
1d40bf1
3794adc
 
 
 
 
a508b4c
49fd176
 
a508b4c
2493b3c
 
a2e1978
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import functools

import gradio as gr
import numpy as np
from mmcv.transforms import Compose
from mmdet.registry import VISUALIZERS
from mmdet.apis import init_detector, inference_detector

import torch
from torchvision.io import read_image
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as TF

# Human-readable names for the detector's classes: traffic_sign_inference maps
# each predicted integer label to a name via SIGNS_CLASSES[label].
# NOTE(review): the order presumably has to match the `classes` tuple used when
# training the checkpoint below — confirm against the config file.
SIGNS_CLASSES = ('A10', 'A11', 'A12', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A1a', 'A1b', 'A22', 'A24', 'A28', 'A29', 'A2a', 'A2b', 'A30', 'A31a', 'A31b', 'A31c', 'A32a', 'A32b', 'A4', 'A5a', 'A6a', 'A6b', 'A7a', 'A8', 'A9', 'B1', 'B11', 'B12', 'B13', 'B14', 'B15', 'B16', 'B17', 'B19', 'B2', 'B20a', 'B20b', 'B21a', 'B21b', 'B24a', 'B24b', 'B26', 'B28', 'B29', 'B32', 'B4', 'B5', 'B6', 'C1', 'C10a', 'C10b', 'C13a', 'C14a', 'C2a', 'C2b', 'C2c', 'C2d', 'C2e', 'C2f', 'C3a', 'C3b', 'C4a', 'C4b', 'C4c', 'C7a', 'C9a', 'C9b', 'E1', 'E11', 'E11c', 'E12', 'E13', 'E2a', 'E2b', 'E2c', 'E2d', 'E3a', 'E3b', 'E4', 'E5', 'E6', 'E7a', 'E7b', 'E8a', 'E8b', 'E8c', 'E8d', 'E8e', 'E9', 'I2', 'IJ1', 'IJ10', 'IJ11a', 'IJ11b', 'IJ14c', 'IJ15', 'IJ2', 'IJ3', 'IJ4a', 'IJ4b', 'IJ4c', 'IJ4d', 'IJ4e', 'IJ5', 'IJ6', 'IJ7', 'IJ8', 'IJ9', 'IP10a', 'IP10b', 'IP11a', 'IP11b', 'IP11c', 'IP11e', 'IP11g', 'IP12', 'IP13a', 'IP13b', 'IP13c', 'IP13d', 'IP14a', 'IP15a', 'IP15b', 'IP16', 'IP17', 'IP18a', 'IP18b', 'IP19', 'IP2', 'IP21', 'IP21a', 'IP22', 'IP25a', 'IP25b', 'IP26a', 'IP26b', 'IP27a', 'IP3', 'IP31a', 'IP4a', 'IP4b', 'IP5', 'IP6', 'IP7', 'IP8a', 'IP8b', 'IS10b', 'IS11a', 'IS11b', 'IS11c', 'IS12a', 'IS12b', 'IS12c', 'IS13', 'IS14', 'IS15a', 'IS15b', 'IS16b', 'IS16c', 'IS16d', 'IS17', 'IS18a', 'IS18b', 'IS19a', 'IS19b', 'IS19c', 'IS19d', 'IS1a', 'IS1b', 'IS1c', 'IS1d', 'IS20', 'IS21a', 'IS21b', 'IS21c', 'IS22a', 'IS22c', 'IS22d', 'IS22e', 'IS22f', 'IS23', 'IS24a', 'IS24b', 'IS24c', 'IS2a', 'IS2b', 'IS2c', 'IS2d', 'IS3a', 'IS3b', 'IS3c', 'IS3d', 'IS4a', 'IS4b', 'IS4c', 'IS4d', 'IS5', 'IS6a', 'IS6b', 'IS6c', 'IS6e', 'IS6f', 'IS6g', 'IS7a', 'IS8a', 'IS8b', 'IS9a', 'IS9b', 'IS9c', 'IS9d', 'O2', 'P1', 'P2', 'P3', 'P4', 'P6', 'P7', 'P8', 'UNKNOWN', 'X1', 'X2', 'X3', 'XXX', 'Z2', 'Z3', 'Z4a', 'Z4b', 'Z4c', 'Z4d', 'Z4e', 'Z7', 'Z9')

# Paths to the MMDetection model config and the trained checkpoint that
# init_detector loads. Both are relative, so they are presumably resolved
# against the process's current working directory.
config_file = 'configs/config_cascade_rcnn_traffic_signs.py'
checkpoint_file = 'checkpoints/traffic_signs_cascade_2v2.pth'

def draw_coco_bboxes(img, bboxes, color=(255,255,0), width=5, show=False, export_p=None,
                     labels=None, resize_to=None, font=None, font_size=150):
    """Draw bounding boxes (and optional labels) onto an image tensor.

    Despite the name, no COCO ``(x, y, w, h)`` conversion is performed: the
    boxes are handed to ``torchvision.utils.draw_bounding_boxes``, which
    expects ``(xmin, ymin, xmax, ymax)`` pixel coordinates.

    Args:
        img: Image tensor forwarded to ``draw_bounding_boxes`` (torchvision
            requires a ``uint8`` tensor shaped ``(C, H, W)``).
        bboxes: Boxes as a tensor or nested sequence with 4 values per box;
            an empty sequence means "no boxes".
        color: Box colour(s), forwarded as ``colors``.
        width: Box line width in pixels.
        show: If true, open the (optionally resized) result in the local
            PIL image viewer.
        export_p: Optional path; if given, the result is saved there.
        labels: Optional list of label strings, one per box.
        resize_to: Size forwarded to ``TF.resize``. It is only applied when
            ``show`` is true, and then the resized image is also the one
            returned and exported.
        font: Optional TrueType font path. torchvision ignores ``font_size``
            unless a font is given.
        font_size: Label font size (only effective together with ``font``).

    Returns:
        The image tensor with the boxes drawn on it.
    """
    # reshape(-1, 4) turns an empty input into a (0, 4) tensor; torchvision
    # indexes boxes[:, 0] before its own "no boxes" check, so a bare (0,) fails.
    boxes = torch.as_tensor(bboxes, dtype=torch.float32).cpu().reshape(-1, 4)

    img = draw_bounding_boxes(img,
                              boxes,
                              colors=color,
                              width=width,
                              labels=labels,
                              font=font,
                              font_size=font_size)

    if show or export_p:
        if show and resize_to is not None:
            img = TF.resize(img, resize_to)
        # Build the PIL image for either purpose; previously it only existed
        # when show=True, so export_p alone raised UnboundLocalError.
        img_pil = TF.to_pil_image(img)
        if show:
            img_pil.show()
        if export_p:
            img_pil.save(export_p)

    return img


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the detector from the config/checkpoint once and reuse it."""
    return init_detector(config_file, checkpoint_file, device='cpu')


def traffic_sign_inference(img, score_thr=0.3):
    """Detect traffic signs in an image and return it with boxes drawn on it.

    Args:
        img: Image from the ``gr.Image`` input, an ``(H, W, 3)`` RGB numpy
            array, or ``None`` when nothing was uploaded.
        score_thr: Detections whose score is not above this value are
            dropped before drawing (0.3 mirrors MMDetection's visualizer
            default). Pass ``None`` to draw every prediction.

    Returns:
        The annotated image as an ``(H, W, 3)`` numpy array, or ``None`` if
        ``img`` was ``None``.
    """
    if img is None:
        # Submitting without uploading an image gives no input; show nothing.
        return None

    # The detector is built once and cached; loading it is far too
    # expensive to repeat on every request.
    model = _load_model()

    # gr.Image delivers RGB, whereas MMDetection treats ndarray input like
    # mmcv.imread output (BGR), so flip the channel axis before inference.
    result = inference_detector(model, np.ascontiguousarray(img[..., ::-1]))

    instances = result.pred_instances
    if score_thr is not None:
        instances = instances[instances.scores > score_thr]

    bboxes = instances.bboxes.cpu()
    labels = [SIGNS_CLASSES[int(label)] for label in instances.labels]

    # Draw on the original RGB pixels (HWC -> CHW for torchvision).
    img_t = torch.from_numpy(img).permute(2, 0, 1)
    # show stays False: this runs on the server, where a PIL viewer popup is
    # useless to the web user.
    img_res_vis = draw_coco_bboxes(img_t, bboxes, labels=labels)
    return img_res_vis.permute(1, 2, 0).numpy()

# Single-input / single-output UI: the uploaded image goes through
# traffic_sign_inference and the annotated result is shown as an image.
demo = gr.Interface(fn=traffic_sign_inference, inputs=gr.Image(), outputs="image")

# Entering the Interface context appends the description below the UI.
with demo:
    gr.Markdown('''
        # Czech Traffic Signs Detector
        Using [Cascade R-CNN](https://arxiv.org/abs/1712.00726) pretrained on COCO, finetuned on dataset of 39425 images provided by kky.zcu.cz, running on [MMDetection](https://github.com/open-mmlab/mmdetection). 
        Report (in Czech): https://drive.google.com/file/d/1bFafvrTdd6Gs9-uwIia8R1CZ1a-Fn9KR/view?usp=drive_link  
        ## Run prediction
        1. Upload an image (left box)
        2. Press submit
        3. See the detection result (on the right)
        
        ## Some of the classes
        ![Czech traffic signs](sdz.jpg "Czech traffic signs")

        https://www.znackydubi.cz/images/5bb48f0d7fa21/original
    ''')

# Start the server only when run as a script, so importing this module
# (e.g. to reuse `demo` or the inference function) does not block.
if __name__ == "__main__":
    demo.launch()