# Author: Yvan
# minor update
# commit: 6428e5b
from PIL import ImageFont, ImageDraw
import numpy as np
import gradio as gr
import mim
mim.install('mmcv-full==1.7.1')
from mmcv.runner import load_checkpoint
from mmdet.apis import inference_detector
from mmdet.models import build_detector
from cfg import cfg
def show_results(org_image, results, classes, conf_threshold=0.2, classes_to_show=None):
    """Annotate *org_image* in place with detection boxes and class labels.

    Parameters
    ----------
    org_image : PIL.Image.Image
        Image to draw on; it is mutated and also returned.
    results : sequence
        Per-class detections; ``results[i]`` is an iterable of
        ``[x1, y1, x2, y2, score]`` rows for class ``i``.
    classes : sequence
        Class names, indexed in parallel with ``results``.
    conf_threshold : float
        Minimum score a detection needs to be drawn.
    classes_to_show : list[int] | None
        Class indices to render; ``None`` renders every class.

    Returns
    -------
    PIL.Image.Image
        The same (now annotated) image object.
    """
    font = ImageFont.truetype("fonts/Bohemian_typewriter.ttf", 15)
    draw = ImageDraw.Draw(org_image)
    if classes_to_show is None:
        classes_to_show = list(np.arange(len(classes)))
    for class_idx, detections in enumerate(results):
        if class_idx not in classes_to_show:
            continue
        for det in detections:
            # det layout: [x1, y1, x2, y2, confidence]
            if det[4] > conf_threshold:
                box = (det[0], det[1], det[2], det[3])
                draw.rectangle(box, fill=None, outline="green", width=5)
                # Label at the box's top-left corner.
                draw.text((det[0], det[1]), str(classes[class_idx]), (0, 255, 0), font=font)
    return org_image
def query_image(input_img):
    """Run the detector on a PIL image and return the annotated result.

    Parameters
    ----------
    input_img : PIL.Image.Image
        Input image from the Gradio widget (RGB channel order).

    Returns
    -------
    PIL.Image.Image
        The input image with boxes/labels drawn when at least one class
        produced a detection, otherwise the untouched input image.
    """
    # The model expects BGR input, so reverse the RGB channel axis; .copy()
    # makes the array contiguous after the negative-stride slice.
    np_img = np.array(input_img)[:, :, ::-1].copy()
    result = inference_detector(model, np_img)
    # `result` holds one detection array per class; only annotate if any
    # class fired at all (replaces the original manual `pivot` flag loop).
    if any(len(res) > 0 for res in result):
        return show_results(input_img, result, model.CLASSES, conf_threshold=0.2)
    return input_img
# --- Model setup (runs at import time) ---
checkpoint = 'checkpoints/latest.pth'  # path to trained weights on disk
device = 'cpu'
# Build the detector from the project config, then load the trained weights.
model = build_detector(cfg.model)
# NOTE(review): `checkpoint` is rebound here from the path string to the
# loaded checkpoint dict returned by load_checkpoint.
checkpoint = load_checkpoint(model, checkpoint, map_location=device)
# Class names were stored in the checkpoint's meta section at training time.
model.CLASSES = checkpoint['meta']['CLASSES']
model.cfg = cfg
model.to(device)
model.eval()
# Example images offered under the input widget.
examples = ["images/DOPANAR_2_00009199.jpg", "images/2013_05_26_A_00000-00007_4_00016247.jpg",
"images/2013-07-04_A-1_00014-00016_8_00009857.jpg"]
# Earlier simple gr.Interface version, kept for reference:
# iface = gr.Interface(query_image, inputs=gr.inputs.Image(type="pil"), outputs=gr.outputs.Image(type="pil"),
# examples=examples)
# iface.launch()
# --- Gradio Blocks UI: input image, a Detect button, and the annotated output ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Traffic sign detection
"""
    )
    with gr.Row().style(equal_height=True):
        img1 = gr.Image(type="pil", label="Input image")
    with gr.Row():
        process_button = gr.Button("Detect", visible=True)
    with gr.Row():
        # NOTE(review): label says "predicted pose" but the app detects
        # traffic signs — likely copied from another demo; confirm intent.
        img2 = gr.Image(interactive=False, label="Output image with predicted pose")
    # Wire the button to the inference function.
    process_button.click(fn=query_image, inputs=[img1], outputs=[img2])
    examples = gr.Examples(examples=examples, inputs=[img1])
# queue() enables request queuing for the (CPU-bound) inference calls.
demo.queue().launch(show_api=False)