AEYE_INSPECTOR / app.py
ausawin's picture
Upload app.py
749d98f
raw
history blame contribute delete
No virus
4.9 kB
import io
import gradio as gr
import matplotlib.pyplot as plt
import requests, validators
from sqlalchemy import true
import torch
import pathlib
from PIL import Image
import os
from detecto import core, utils, visualize
from detecto.core import Model as DetectoModel
# Page heading rendered through gr.Markdown; the CSS below centres it.
title = '<h1 id="title">AEYE INSPECTOR</h1>'

# Stylesheet handed to gr.Blocks: centres the element with id "title".
css = '''
h1#title {
text-align: center;
}
'''

# RGB triples (components in [0, 1]) used as box edge colours; the drawing
# code cycles through them by detection index.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# Entries for the model dropdown. Only names containing "Detecto" are
# handled by detect_objects.
models = [
    "Detecto (Faster-RCNN)",
    "YOLOv100",
]

# Sample URLs for the "Image URL" tab's example gallery. The former entries
# (http://fbbbb.ddns.net:4080/static/images/ai_img(1).jpg through
# ai_img(10).jpg) are disabled, so the gallery is empty.
urls = []
# Class names given to Detecto. Presumably this must be the same order used
# when the checkpoint was trained (TODO confirm against the training script).
DETECTO_CLASSES = ['heltmet_safe', 'face_mask', 'safety_vest', 'safety_belts', 'safety_shoes']
# Detecto checkpoint, resolved relative to the process working directory.
DETECTO_WEIGHTS = 'ai_30ep.pth'
# Loaded Detecto models keyed by weights path, so the checkpoint is read once
# per process instead of on every request.
_DETECTO_MODEL_CACHE = {}


def _get_detecto_model():
    """Return the Detecto model for ``DETECTO_WEIGHTS``, loading it on first use."""
    model = _DETECTO_MODEL_CACHE.get(DETECTO_WEIGHTS)
    if model is None:
        # ``Model.load`` is a static factory that *returns* the loaded model.
        # Calling it on a fresh instance and ignoring the result would leave
        # that instance with its freshly initialised, not fine-tuned, weights.
        model = DetectoModel.load(DETECTO_WEIGHTS, DETECTO_CLASSES)
        _DETECTO_MODEL_CACHE[DETECTO_WEIGHTS] = model
    return model


def _load_input_image(url_input, image_input):
    """Return the image to analyse: the URL's image if the URL is valid, else the upload.

    Raises:
        ValueError: if there is neither a valid URL nor an uploaded image.
        requests.HTTPError: if the URL responds with an error status.
    """
    if url_input and validators.url(url_input):
        # Bounded wait so a dead host cannot hang the worker forever.
        response = requests.get(url_input, timeout=30)
        response.raise_for_status()
        # ``.content`` is decoded per Content-Encoding, unlike the raw stream.
        image = Image.open(io.BytesIO(response.content))
    elif image_input is not None:
        image = image_input
    else:
        raise ValueError('Enter a valid image URL or upload an image.')
    # Detecto's default transforms normalise with 3-channel statistics, so
    # RGBA / palette / grayscale images are converted to RGB first.
    return image.convert('RGB')


def detect_objects(model_name, url_input, image_input, threshold):
    """Run the selected detector on the input image and return an annotated image.

    Args:
        model_name: Dropdown choice. Only names containing "Detecto" are supported.
        url_input: Image URL text. Used when it is a valid URL.
        image_input: Uploaded PIL image. Used when the URL is missing or invalid.
        threshold: Minimum score (exclusive) for a detection to be drawn.

    Returns:
        The PIL image produced by ``visualize_prediction``.

    Raises:
        ValueError: if no supported model is selected, or no usable image is given.
    """
    # Validate the model choice first so nothing is downloaded for a bad request.
    if not model_name or 'Detecto' not in model_name:
        raise ValueError(f'Unsupported or missing model selection: {model_name!r}')
    image = _load_input_image(url_input, image_input)
    model = _get_detecto_model()
    labels, boxes, scores = model.predict(image)
    return visualize_prediction(image, labels, boxes, scores, threshold)
def visualize_prediction(pil_img, labels, boxes, scores, threshold=0.7):
    """Draw the detections scoring above ``threshold`` onto ``pil_img``.

    Args:
        pil_img: Image to annotate (anything ``Axes.imshow`` accepts).
        labels: Class name for each detection.
        boxes: One ``(xmin, ymin, xmax, ymax)`` box per detection, aligned
            with ``labels`` and ``scores``.
        scores: Confidence score for each detection.
        threshold: Detections whose score is not strictly greater than this
            are skipped.

    Returns:
        The annotated figure rendered to a PIL image by ``fig2img``.
    """
    # Use a dedicated figure via the object-oriented API and always close it,
    # so repeated requests do not pile up open pyplot figures.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        # Walk all detections in lockstep so each box stays paired with its
        # own label and score, whatever gets filtered out.
        for idx, (label, box, score) in enumerate(zip(labels, boxes, scores)):
            score = float(score)
            if score <= threshold:
                continue
            xmin, ymin, xmax, ymax = (float(v) for v in box)
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=COLORS[idx % len(COLORS)], linewidth=3))
            ax.text(xmin, ymin, f'{label}: {score:0.2f}', fontsize=15,
                    bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        plt.close(fig)
def set_example_image(example: list) -> dict:
    """Load a clicked gallery sample into the upload image component.

    ``example`` is the sample row the Dataset hands to its click handler. Its
    first element becomes the new image value.
    """
    chosen = example[0]
    return gr.Image.update(value=chosen)
def set_example_url(example: list) -> dict:
    """Copy a clicked URL sample into the URL textbox.

    ``example`` is the sample row the Dataset hands to its click handler. Its
    first element becomes the new textbox value.
    """
    chosen = example[0]
    return gr.Textbox.update(value=chosen)
def fig2img(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    Args:
        fig: Figure to render.

    Returns:
        A fully decoded ``PIL.Image.Image`` of the rendered figure.
    """
    buf = io.BytesIO()
    # Force PNG so the output does not depend on rcParams['savefig.format'],
    # which could be set to a vector format PIL cannot open.
    fig.savefig(buf, format='png')
    buf.seek(0)
    img = Image.open(buf)
    # Decode now, so any error surfaces here rather than later in the caller.
    img.load()
    return img
def _detect_from_url(model_name, url_input, threshold):
    """Click handler for the 'Image URL' tab: detect on the entered URL only."""
    # Pass no uploaded image, so an image left in the other tab is never used
    # as a silent fallback.
    return detect_objects(model_name, url_input, None, threshold)


def _detect_from_upload(model_name, image_input, threshold):
    """Click handler for the 'Image Upload' tab: detect on the uploaded image only."""
    # An empty string is not a valid URL, so a URL typed in the other tab
    # cannot take precedence over the uploaded image.
    return detect_objects(model_name, '', image_input, threshold)


# Two-tab UI: one tab takes an image URL, the other an uploaded image. Both
# share the model dropdown and the threshold slider.
app = gr.Blocks(css=css)
with app:
    gr.Markdown(title)
    # Preselect the first model so "Detect" works without touching the dropdown.
    options = gr.Dropdown(choices=models, value=models[0], label='Select Object Detection Model', show_label=True)
    slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.7, label='Prediction Threshold')
    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                img_output_from_url = gr.Image(shape=(650, 650))
            with gr.Row():
                example_url = gr.Dataset(components=[url_input], samples=[[str(url)] for url in urls])
            url_but = gr.Button('Detect')
        with gr.TabItem('Image Upload'):
            with gr.Row():
                img_input = gr.Image(type='pil')
                img_output_from_upload = gr.Image(shape=(650, 650))
            with gr.Row():
                example_images = gr.Dataset(components=[img_input],
                                            samples=[[path.as_posix()]
                                                     for path in sorted(pathlib.Path('images').rglob('*.JPG'))])
            img_but = gr.Button('Detect')
    # Each button only receives the input that belongs to its own tab.
    url_but.click(_detect_from_url, inputs=[options, url_input, slider_input], outputs=img_output_from_url, queue=True)
    img_but.click(_detect_from_upload, inputs=[options, img_input, slider_input], outputs=img_output_from_upload, queue=True)
    example_images.click(fn=set_example_image, inputs=[example_images], outputs=[img_input])
    example_url.click(fn=set_example_url, inputs=[example_url], outputs=[url_input])
app.launch(enable_queue=True, server_name='0.0.0.0', show_error=True)