Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from mmcv.transforms import Compose | |
from mmdet.registry import VISUALIZERS | |
from mmdet.apis import init_detector, inference_detector | |
import torch | |
from torchvision.io import read_image | |
from torchvision.utils import draw_bounding_boxes | |
import torchvision.transforms.functional as TF | |
# Class names in label-index order. traffic_sign_inference maps each predicted
# integer label to a name via SIGNS_CLASSES[label], so the order of this tuple
# must match the class order the checkpoint was trained with (the names are
# listed in lexicographic order). Do not re-sort or insert entries.
SIGNS_CLASSES = (
    'A10', 'A11', 'A12', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A1a',
    'A1b', 'A22', 'A24', 'A28', 'A29', 'A2a', 'A2b', 'A30', 'A31a', 'A31b',
    'A31c', 'A32a', 'A32b', 'A4', 'A5a', 'A6a', 'A6b', 'A7a', 'A8', 'A9',
    'B1', 'B11', 'B12', 'B13', 'B14', 'B15', 'B16', 'B17', 'B19', 'B2',
    'B20a', 'B20b', 'B21a', 'B21b', 'B24a', 'B24b', 'B26', 'B28', 'B29', 'B32',
    'B4', 'B5', 'B6',
    'C1', 'C10a', 'C10b', 'C13a', 'C14a', 'C2a', 'C2b', 'C2c', 'C2d', 'C2e',
    'C2f', 'C3a', 'C3b', 'C4a', 'C4b', 'C4c', 'C7a', 'C9a', 'C9b',
    'E1', 'E11', 'E11c', 'E12', 'E13', 'E2a', 'E2b', 'E2c', 'E2d', 'E3a',
    'E3b', 'E4', 'E5', 'E6', 'E7a', 'E7b', 'E8a', 'E8b', 'E8c', 'E8d',
    'E8e', 'E9',
    'I2', 'IJ1', 'IJ10', 'IJ11a', 'IJ11b', 'IJ14c', 'IJ15', 'IJ2', 'IJ3', 'IJ4a',
    'IJ4b', 'IJ4c', 'IJ4d', 'IJ4e', 'IJ5', 'IJ6', 'IJ7', 'IJ8', 'IJ9',
    'IP10a', 'IP10b', 'IP11a', 'IP11b', 'IP11c', 'IP11e', 'IP11g', 'IP12', 'IP13a', 'IP13b',
    'IP13c', 'IP13d', 'IP14a', 'IP15a', 'IP15b', 'IP16', 'IP17', 'IP18a', 'IP18b', 'IP19',
    'IP2', 'IP21', 'IP21a', 'IP22', 'IP25a', 'IP25b', 'IP26a', 'IP26b', 'IP27a', 'IP3',
    'IP31a', 'IP4a', 'IP4b', 'IP5', 'IP6', 'IP7', 'IP8a', 'IP8b',
    'IS10b', 'IS11a', 'IS11b', 'IS11c', 'IS12a', 'IS12b', 'IS12c', 'IS13', 'IS14', 'IS15a',
    'IS15b', 'IS16b', 'IS16c', 'IS16d', 'IS17', 'IS18a', 'IS18b', 'IS19a', 'IS19b', 'IS19c',
    'IS19d', 'IS1a', 'IS1b', 'IS1c', 'IS1d', 'IS20', 'IS21a', 'IS21b', 'IS21c', 'IS22a',
    'IS22c', 'IS22d', 'IS22e', 'IS22f', 'IS23', 'IS24a', 'IS24b', 'IS24c', 'IS2a', 'IS2b',
    'IS2c', 'IS2d', 'IS3a', 'IS3b', 'IS3c', 'IS3d', 'IS4a', 'IS4b', 'IS4c', 'IS4d',
    'IS5', 'IS6a', 'IS6b', 'IS6c', 'IS6e', 'IS6f', 'IS6g', 'IS7a', 'IS8a', 'IS8b',
    'IS9a', 'IS9b', 'IS9c', 'IS9d',
    'O2', 'P1', 'P2', 'P3', 'P4', 'P6', 'P7', 'P8',
    'UNKNOWN', 'X1', 'X2', 'X3', 'XXX', 'Z2', 'Z3', 'Z4a', 'Z4b', 'Z4c',
    'Z4d', 'Z4e', 'Z7', 'Z9',
)
# Specify the path to model config and checkpoint file (relative to the
# current working directory, i.e. the directory the app is started from).
config_file = 'configs/config_cascade_rcnn_traffic_signs.py'
checkpoint_file = 'checkpoints/traffic_signs_cascade_2v2.pth'
def draw_coco_bboxes(img, bboxes, color=(255, 255, 0), width=5, show=False, export_p=None,
                     labels=None, resize_to=None, font_size=150, font=None):
    """Draw bounding boxes (and optional text labels) onto an image tensor.

    Args:
        img: Image tensor in ``(C, H, W)`` layout; torchvision's
            ``draw_bounding_boxes`` requires ``uint8`` values.
        bboxes: Boxes as ``(x_min, y_min, x_max, y_max)`` pixel coordinates,
            one row per box. Anything ``torch.as_tensor`` accepts works
            (tensor, NumPy array, nested list); an empty collection draws
            nothing.
        color: Colour for all boxes, or a per-box list, forwarded to
            ``draw_bounding_boxes(colors=...)``.
        width: Box outline width in pixels.
        show: If true, open the annotated image with PIL's ``Image.show()``.
        export_p: Optional path; if given, the annotated image is saved there.
        labels: Optional list of label strings, one per box.
        resize_to: Optional size passed to ``TF.resize``. It is applied only
            when the image is shown or exported, and the returned tensor is
            then the resized one.
        font_size: Label font size. torchvision ignores it unless ``font`` is
            also given (it falls back to PIL's default font).
        font: Optional TrueType font file forwarded to ``draw_bounding_boxes``.

    Returns:
        The annotated image tensor in ``(C, H, W)`` layout.
    """
    # torch.Tensor(x) is the legacy constructor: it rejects tensors of other
    # dtypes (e.g. float64) and turns an empty list into shape (0,).
    # as_tensor yields a float32 CPU tensor for any input, and the empty case
    # is reshaped to the (N, 4) layout torchvision documents.
    boxes = torch.as_tensor(bboxes, dtype=torch.float32, device="cpu")
    if boxes.numel() == 0:
        boxes = boxes.reshape(0, 4)

    img = draw_bounding_boxes(img, boxes, colors=color, width=width,
                              labels=labels, font=font, font_size=font_size)

    if not (show or export_p):
        return img

    if resize_to is not None:
        img = TF.resize(img, resize_to)
    # The PIL image is built for both preview and export, so export_p also
    # works when show is False (previously this raised NameError).
    img_pil = TF.to_pil_image(img)
    if show:
        img_pil.show()
    if export_p:
        img_pil.save(export_p)
    return img
# Detector shared across requests. It is built lazily on first use, so the
# expensive config + checkpoint load happens once instead of per image.
_MODEL = None


def _get_model():
    """Return the cached detector, building it on the first call."""
    global _MODEL
    if _MODEL is None:
        # Build the model from a config file and a checkpoint file
        _MODEL = init_detector(config_file, checkpoint_file, device='cpu')
    return _MODEL


def traffic_sign_inference(img, score_thr=0.0):
    """Detect traffic signs in an image and return it with boxes drawn.

    Args:
        img: Image as an ``(H, W, 3)`` NumPy array, as delivered by
            ``gr.Image()`` (``None`` when the user submitted no image).
        score_thr: Only detections whose score is ``>= score_thr`` are drawn.
            The default 0.0 draws every detection the model returns.

    Returns:
        The annotated image as an ``(H, W, 3)`` NumPy array.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a message
            instead of a traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # NOTE(review): MMDetection/mmcv treat ndarray input as BGR, while Gradio
    # supplies RGB. If the config's data_preprocessor expects BGR input, this
    # should pass img[..., ::-1].copy() instead — TODO confirm against config.
    result = inference_detector(_get_model(), img)

    instances = result.pred_instances
    keep = instances.scores >= score_thr
    bboxes = instances.bboxes[keep]
    labels = [SIGNS_CLASSES[idx] for idx in instances.labels[keep].tolist()]

    # HWC numpy -> CHW tensor, the layout draw_coco_bboxes expects.
    img_t = torch.from_numpy(img).permute(2, 0, 1)
    # No show=True here: this runs inside a web server, where opening a
    # desktop image viewer via PIL is pointless.
    img_res_vis = draw_coco_bboxes(img_t, bboxes, labels=labels)
    return img_res_vis.permute(1, 2, 0).numpy()
# Image-in / image-out UI. gr.Image() hands the function a NumPy array by
# default, which is what traffic_sign_inference expects.
demo = gr.Interface(traffic_sign_inference, gr.Image(), "image")
with demo:
    # NOTE(review): the bare URL under "Some of the classes" is rendered as text,
    # not as an image; an image embed may have been intended — confirm.
    gr.Markdown('''
# Czech Traffic Signs Detector
Using [Cascade R-CNN](https://arxiv.org/abs/1712.00726) pretrained on COCO, finetuned on dataset of 39425 images provided by kky.zcu.cz, running on [MMDetection](https://github.com/open-mmlab/mmdetection).
Report (in Czech): https://drive.google.com/file/d/1bFafvrTdd6Gs9-uwIia8R1CZ1a-Fn9KR/view?usp=drive_link
## Run prediction
1. Upload an image (left box)
2. Press submit
3. See the detection result (on the right)
## Some of the classes

https://www.znackydubi.cz/images/5bb48f0d7fa21/original
''')

# Launch only when executed as a script, so importing this module (e.g. from
# tests or tooling) does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()