File size: 2,483 Bytes
28398e1
 
128c377
88036f9
1d61059
28398e1
128c377
28398e1
 
 
128c377
28398e1
 
8eed7ff
28398e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f33291b
28398e1
 
 
 
 
 
 
 
 
 
bfe8092
 
28398e1
75e37e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from PIL import ImageFont, ImageDraw
import numpy as np
import gradio as gr
import mim
mim.install('mmcv-full==1.7.1')
from mmcv.runner import load_checkpoint

from mmdet.apis import inference_detector
from mmdet.models import build_detector
from cfg import cfg


def show_results(org_image, results, classes, conf_threshold=0.2, classes_to_show=None):
    """Draw detection boxes and class labels onto *org_image* in place.

    Args:
        org_image: PIL image to annotate (mutated and also returned).
        results: per-class detection lists; each detection is indexable as
            (x1, y1, x2, y2, score) — the mmdet bbox result format.
        classes: sequence of class names, indexed in parallel with *results*.
        conf_threshold: minimum score for a detection to be drawn.
        classes_to_show: optional iterable of class indices to draw;
            all classes when None.

    Returns:
        The annotated *org_image*.
    """
    font = ImageFont.truetype("fonts/Bohemian_typewriter.ttf", 15)
    draw = ImageDraw.Draw(org_image)
    if classes_to_show is None:
        classes_to_show = range(len(classes))
    # Set gives O(1) membership in the loop; plain range avoids building
    # a list of numpy ints just to hold indices.
    classes_to_show = set(classes_to_show)
    for i, res_class in enumerate(results):
        if i not in classes_to_show:
            continue
        for res in res_class:
            if res[4] > conf_threshold:
                draw.rectangle((res[0], res[1], res[2], res[3]), fill=None, outline="green", width=5)
                # Label is drawn at the box's top-left corner.
                draw.text((res[0], res[1]), str(classes[i]), (0, 255, 0), font=font)
    return org_image


def query_image(input_img):
    """Run the detector on a PIL image and return it annotated.

    Args:
        input_img: PIL image as delivered by the Gradio image component.

    Returns:
        *input_img* with boxes drawn when anything was detected,
        otherwise the untouched input image.
    """
    # PIL arrays are RGB; the detector was fed BGR here, hence the
    # channel flip. .copy() gives a contiguous array after the
    # negative-stride slice.
    np_img = np.array(input_img)[:, :, ::-1].copy()
    result = inference_detector(model, np_img)
    # Replace the hand-rolled pivot-flag loop with any(): draw only
    # when at least one class produced a detection.
    if any(len(res) > 0 for res in result):
        return show_results(input_img, result, model.CLASSES, conf_threshold=0.2)
    return input_img


# --- Model setup (runs at import time) -------------------------------------
# Path to the trained detector weights; assumes the checkpoint exists
# relative to the working directory — TODO confirm deployment layout.
checkpoint = 'checkpoints/latest.pth'

# CPU-only inference; no GPU assumed in the demo environment.
device = 'cpu'
model = build_detector(cfg.model)
# Note: `checkpoint` is rebound from the path string to the loaded
# checkpoint dict returned by load_checkpoint.
checkpoint = load_checkpoint(model, checkpoint, map_location=device)

# Class names are taken from the checkpoint metadata so labels match
# whatever the model was trained on.
model.CLASSES = checkpoint['meta']['CLASSES']
model.cfg = cfg
model.to(device)
model.eval()

# Bundled example images shown below the input component.
examples = ["images/DOPANAR_2_00009199.jpg", "images/2013_05_26_A_00000-00007_4_00016247.jpg",
            "images/2013-07-04_A-1_00014-00016_8_00009857.jpg"]

# Older single-call Interface version, kept for reference.
# iface = gr.Interface(query_image, inputs=gr.inputs.Image(type="pil"), outputs=gr.outputs.Image(type="pil"),
#                      examples=examples)
# iface.launch()

# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Traffic sign detection
    """
    )
    # NOTE(review): Row().style(equal_height=True) is the pre-4.x Gradio
    # API — confirm the pinned gradio version supports it.
    with gr.Row().style(equal_height=True):
        img1 = gr.Image(type="pil", label="Input image")
    with gr.Row():
        process_button = gr.Button("Detect", visible=True)
    with gr.Row():
        img2 = gr.Image(interactive=False, label="Output image with predicted pose")
    # Clicking "Detect" runs query_image on img1 and shows the result in img2.
    process_button.click(fn=query_image, inputs=[img1], outputs=[img2])
    # `examples` is rebound from the path list to the gr.Examples component.
    examples = gr.Examples(examples=examples, inputs=[img1])

# Queueing serializes requests (CPU inference is slow); API page disabled.
demo.queue().launch(show_api=False)