eireneai's picture
Update app.py
945a642
raw
history blame contribute delete
No virus
3.25 kB
from typing import Any
import cv2
import numpy as np
from yolov7 import YOLOv7
import gradio as gr
from PIL import Image
class Inference:
    """Callable wrapper around a single ``YOLOv7`` detector.

    The detector is created once in ``__init__`` (through ``setup_models``)
    and then invoked on every call with per-request thresholds.
    """

    def __init__(self, model_path, labels_path, engine_path):
        # Build the detector eagerly so each call only runs inference.
        self.model = self.setup_models(model_path, labels_path, engine_path)

    def setup_models(self, model_path, labels_path, engine_path):
        """Create and return the ``YOLOv7`` detector.

        The three paths are forwarded positionally, in this order, to the
        ``YOLOv7`` constructor.
        """
        detector = YOLOv7(model_path, labels_path, engine_path)
        return detector

    def __call__(self, frame: np.ndarray, conf_threshold: float, nms_threshold: float, *args: Any, **kwds: Any) -> Any:
        """Detect objects in ``frame`` and return ``(boxes, scores, class_ids)``.

        Any extra positional or keyword arguments are accepted and ignored.
        The model's output is unpacked into exactly three parts and repacked
        as a tuple, so a model returning a different number of items raises
        ``ValueError``.
        """
        detections = self.model(frame, conf_threshold, nms_threshold)
        boxes, scores, class_ids = detections
        return (boxes, scores, class_ids)
# Two detectors are constructed once, at import time, and shared by every
# request. Both use the same label file; only the weights and engine path
# differ. infer1 feeds the first output image of ``run`` and infer2 the second.
infer1 = Inference(
    model_path="models/firesmoke.onnx",
    labels_path="models/labels.txt",
    engine_path="firesmoke.trt",
)
infer2 = Inference(
    model_path="models/firesmoke-henry.onnx",
    labels_path="models/labels.txt",
    engine_path="firesmoke-henry.trt",
)
def _draw_detections(img, boxes, scores, class_ids):
    """Annotate ``img`` in place with one box and score label per detection.

    Each box is read as ``(x1, y1, x2, y2)`` from its first four entries and
    truncated to ``int`` for OpenCV. Each label reads ``"<class_id>:<score>"``
    on a filled background. Colours are given in BGR order, matching the BGR
    frame that ``run`` passes in. Returns ``img`` for convenience.
    """
    box_color = (0, 0, 255)  # red in BGR
    text_color = (255, 255, 255)  # white
    font = cv2.FONT_HERSHEY_COMPLEX
    font_scale = 0.5
    thickness = 1
    for box, score, class_id in zip(boxes, scores, class_ids):
        x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
        label = "{}:{:.2f}".format(class_id, score)
        (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
        # Keep the original 100x20 label background as a minimum, but grow it
        # when the rendered text is larger so the label is never cut off.
        label_w = max(100, text_w)
        label_h = max(20, text_h + baseline)
        if y1 - label_h >= 0:
            # Normal case: the label sits directly above the box's top edge.
            top, bottom = y1 - label_h, y1
        else:
            # The box touches the top of the image, so a label above it would
            # be drawn off-image. Place it just inside the box instead.
            top = max(y1, 0)
            bottom = top + label_h
        cv2.rectangle(img, (x1, y1), (x2, y2), box_color, 2)
        cv2.rectangle(img, (x1, top), (x1 + label_w, bottom), box_color, -1)
        # putText anchors at the text baseline. Lift the text by the baseline
        # so descenders stay on the filled background.
        cv2.putText(img, label, (x1, bottom - baseline), font, font_scale, text_color, thickness)
    return img


def run(content_img, conf_threshold, nms_threshold):
    """Run both detectors on one image and return two annotated copies.

    Args:
        content_img: RGB input image from the UI (anything ``np.array``
            accepts, e.g. a numpy array or PIL image), or ``None`` when no
            image was supplied.
        conf_threshold: confidence threshold forwarded to both detectors.
        nms_threshold: NMS threshold forwarded to both detectors.

    Returns:
        ``(img1, img2)``: RGB ``PIL.Image`` objects annotated with the
        detections of ``infer1`` and ``infer2`` respectively. Returns
        ``(None, None)`` when ``content_img`` is ``None``.
    """
    if content_img is None:
        # With no image, np.array(None) would make cv2.cvtColor fail. Return
        # empty outputs instead of raising.
        return None, None
    # Convert to OpenCV's BGR channel order before inference and drawing.
    frame = cv2.cvtColor(np.array(content_img), cv2.COLOR_RGB2BGR)
    # Run both models before drawing, on the same un-annotated frame.
    detections1 = infer1(frame, conf_threshold, nms_threshold)
    detections2 = infer2(frame, conf_threshold, nms_threshold)
    outputs = []
    for boxes, scores, class_ids in (detections1, detections2):
        annotated = _draw_detections(frame.copy(), boxes, scores, class_ids)
        # Convert back to RGB for PIL / display.
        outputs.append(Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)))
    return outputs[0], outputs[1]
if __name__ == '__main__':
    # Side-by-side comparison UI: one image and two thresholds in, one
    # annotated image per model out (order matches run's return value).
    demo = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(label='Input Image'),
            # The Slider's initial position is set with ``value``. ``default``
            # is not a Slider parameter, so the intended 0.2 / 0.5 starting
            # thresholds were never applied.
            gr.Slider(minimum=0.05, maximum=1, step=0.05, value=0.2, label="Confidence Threshold"),
            gr.Slider(minimum=0.05, maximum=1, step=0.05, value=0.5, label="NMS Threshold"),
        ],
        outputs=[
            gr.Image(
                type="pil",
                label="Finetuned"
            ),
            gr.Image(
                type="pil",
                label="Finetuned + New Data"
            ),
        ],
        # Each example row is [image path, confidence threshold, NMS threshold].
        examples=[
            ['examples/fire1.jpg', 0.2, 0.5],
            ['examples/fire2.jpg', 0.2, 0.5],
            ['examples/fire3.jpg', 0.15, 0.5]
        ]
    )
    demo.launch()