"""Gradio front-end for the AEYE safety-equipment object detector (Detecto)."""
import functools
import io
import logging
import os
import pathlib

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
import validators
from PIL import Image
# NOTE(review): `true` (and `os`, `torch`, `core`, `utils`, `visualize`) are not
# used in this file; the SQLAlchemy import in particular looks like an
# accidental IDE auto-import that adds a runtime dependency. Kept to avoid
# removing existing imports — consider deleting them.
from sqlalchemy import true
from detecto import core, utils, visualize
from detecto.core import Model as DetectoModel

# The CSS below styles ``h1#title``; the heading carries that id so the rule
# actually matches an element.
title = """
<h1 id="title">AEYE INSPECTOR</h1>
"""

css = '''
h1#title {
  text-align: center;
}
'''

# RGB colours (0-1 floats) cycled over the drawn boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

models = ["Detecto (Faster-RCNN)", "YOLOv100"]

# Weights file and class list for the Detecto model. Detecto's ``Model.load``
# requires the classes in the same order as used for training, so the list
# (including the "heltmet" spelling) must not be edited without retraining.
DETECTO_WEIGHTS = 'ai_30ep.pth'
DETECTO_CLASSES = ['heltmet_safe', 'face_mask', 'safety_vest', 'safety_belts', 'safety_shoes']

# Seconds to wait for a remote image before giving up.
URL_TIMEOUT_SECONDS = 15

# Exception type for problems the user should see in the UI: ``gr.Error`` when
# this Gradio version provides it, otherwise a plain ValueError.
_UserError = getattr(gr, "Error", ValueError)

logger = logging.getLogger(__name__)

urls = [
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(1).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(2).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(3).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(4).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(5).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(6).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(7).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(8).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(9).jpg',
    # 'http://fbbbb.ddns.net:4080/static/images/ai_img(10).jpg',
]


@functools.lru_cache(maxsize=1)
def _load_detecto_model():
    """Load the fine-tuned Detecto model once and reuse it for every request."""
    # ``DetectoModel.load`` is a static factory that *returns* the loaded
    # model; calling it on an instance and ignoring the result leaves that
    # instance without the trained weights.
    return DetectoModel.load(DETECTO_WEIGHTS, DETECTO_CLASSES)


def detect_objects(model_name, url_input, image_input, threshold):
    """Run object detection and return an annotated image.

    Args:
        model_name: Entry from ``models``; only "Detecto" models are implemented.
        url_input: Image URL. When it is a valid URL it takes precedence over
            ``image_input``.
        image_input: PIL image from the upload widget; used when no valid URL
            is given.
        threshold: Detections with a score above this value are drawn.

    Returns:
        A PIL image with the kept detections drawn on it.

    Raises:
        gr.Error (ValueError on Gradio versions without it): unsupported model
            or no usable image.
        requests.RequestException: the URL could not be downloaded.
    """
    if not model_name or 'Detecto' not in model_name:
        raise _UserError(f"Model {model_name!r} is not supported yet.")

    if url_input and validators.url(url_input):
        # NOTE(review): fetching arbitrary user-supplied URLs from a server bound
        # to 0.0.0.0 is an SSRF risk if the app is publicly reachable.
        response = requests.get(url_input, timeout=URL_TIMEOUT_SECONDS)
        response.raise_for_status()
        # Decode from the fully-read body (``.raw`` bypasses gzip/deflate
        # decoding) and normalise to 3-channel RGB (URLs may be RGBA/palette).
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    elif image_input is not None:
        image = image_input
    else:
        raise _UserError("Please enter a valid image URL or upload an image.")

    model = _load_detecto_model()
    labels, boxes, scores = model.predict(image)
    logger.debug("labels=%s scores=%s", labels, scores)
    return visualize_prediction(image, labels, boxes, scores, threshold)


def _as_list(values):
    """Return *values* as a plain list (accepts tensors, arrays or sequences)."""
    return values.tolist() if hasattr(values, "tolist") else list(values)


def visualize_prediction(pil_img, labels, boxes, scores, threshold=0.7):
    """Draw the detections scoring above *threshold* on *pil_img*.

    Args:
        pil_img: Image to annotate.
        labels: One label per detection.
        boxes: One ``(xmin, ymin, xmax, ymax)`` box per detection (tensor or
            sequence).
        scores: One confidence score per detection (tensor or sequence).
        threshold: Only detections with ``score > threshold`` are drawn.

    Returns:
        The annotated figure rendered as a PIL image.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        # Walk labels/boxes/scores in lockstep so every drawn box keeps its own
        # label and score (indexing the *filtered* box list by the *unfiltered*
        # position mismatched them and could raise IndexError).
        detections = zip(labels, _as_list(boxes), _as_list(scores))
        for idx, (label, box, score) in enumerate(detections):
            if not score > threshold:
                continue
            xmin, ymin, xmax, ymax = box
            ax.add_patch(plt.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                fill=False, color=COLORS[idx % len(COLORS)], linewidth=3,
            ))
            ax.text(xmin, ymin, f'{label}: {score:0.2f}', fontsize=15,
                    bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leak one figure.
        plt.close(fig)


def set_example_image(example: list) -> dict:
    """Copy the clicked example image path into the upload widget."""
    return gr.Image.update(value=example[0])


def set_example_url(example: list) -> dict:
    """Copy the clicked example URL into the URL textbox."""
    return gr.Textbox.update(value=example[0])


def fig2img(fig):
    """Render a matplotlib figure to an in-memory PNG and open it with PIL."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img


def _detect_from_url(model_name, url_input, threshold):
    """URL tab handler: detect on the URL only, ignoring any uploaded image."""
    return detect_objects(model_name, url_input, None, threshold)


def _detect_from_upload(model_name, image_input, threshold):
    """Upload tab handler: detect on the uploaded image, ignoring any typed URL."""
    return detect_objects(model_name, "", image_input, threshold)


app = gr.Blocks(css=css)

with app:
    gr.Markdown(title)

    # Default to the only implemented model so an untouched dropdown works.
    options = gr.Dropdown(choices=models, value=models[0],
                          label='Select Object Detection Model', show_label=True)
    slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.7, label='Prediction Threshold')

    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                img_output_from_url = gr.Image(shape=(650, 650))

            with gr.Row():
                example_url = gr.Dataset(components=[url_input],
                                         samples=[[str(url)] for url in urls])

            url_but = gr.Button('Detect')

        with gr.TabItem('Image Upload'):
            with gr.Row():
                img_input = gr.Image(type='pil')
                img_output_from_upload = gr.Image(shape=(650, 650))

            with gr.Row():
                example_images = gr.Dataset(
                    components=[img_input],
                    samples=[[path.as_posix()]
                             for path in sorted(pathlib.Path('images').rglob('*.JPG'))],
                )

            img_but = gr.Button('Detect')

    url_but.click(_detect_from_url, inputs=[options, url_input, slider_input],
                  outputs=img_output_from_url, queue=True)
    img_but.click(_detect_from_upload, inputs=[options, img_input, slider_input],
                  outputs=img_output_from_upload, queue=True)
    example_images.click(fn=set_example_image, inputs=[example_images], outputs=[img_input])
    example_url.click(fn=set_example_url, inputs=[example_url], outputs=[url_input])

app.launch(enable_queue=True, server_name='0.0.0.0', show_error=True)