Spaces:
Runtime error
Runtime error
from typing import Any | |
import cv2 | |
import numpy as np | |
from yolov7 import YOLOv7 | |
import gradio as gr | |
from PIL import Image | |
class Inference:
    """Callable wrapper that owns one YOLOv7 detector.

    The detector is built once, in ``__init__``, through :meth:`setup_models`
    and kept on ``self.model``.  Calling the instance forwards a frame and the
    two thresholds to that detector and returns its ``(boxes, scores,
    class_ids)`` triple.
    """

    def __init__(self, model_path, labels_path, engine_path):
        self.model = self.setup_models(model_path, labels_path, engine_path)

    def setup_models(self, model_path, labels_path, engine_path):
        """Build and return the ``YOLOv7`` detector for the given files.

        The three paths are passed to ``YOLOv7`` positionally; what each file
        must contain is defined by that project class, which is not shown here.
        """
        detector = YOLOv7(model_path, labels_path, engine_path)
        return detector

    def __call__(self, frame: np.ndarray, conf_threshold: float, nms_threshold: float, *args: Any, **kwds: Any) -> Any:
        """Detect objects in *frame* with the wrapped model.

        Args:
            frame: image handed unchanged to the detector.
            conf_threshold: confidence threshold forwarded to the detector.
            nms_threshold: NMS threshold forwarded to the detector.
            *args, **kwds: accepted for call compatibility but ignored.

        Returns:
            A ``(boxes, scores, class_ids)`` tuple.  Unpacking the detector's
            result first means it must yield exactly three items.
        """
        detections = self.model(frame, conf_threshold, nms_threshold)
        boxes, scores, class_ids = detections
        return boxes, scores, class_ids
# Both detectors are constructed at import time so every request can run the
# two models side by side.  They share one label file but use different ONNX
# weights and engine paths.
# NOTE(review): the engine paths are relative to the working directory, not to
# models/ like the other two files. Presumably YOLOv7 creates or loads them
# there; confirm.
infer1 = Inference(
    model_path="models/firesmoke.onnx",
    labels_path="models/labels.txt",
    engine_path="firesmoke.trt",
)
infer2 = Inference(
    model_path="models/firesmoke-henry.onnx",
    labels_path="models/labels.txt",
    engine_path="firesmoke-henry.trt",
)
def _draw_detections(image_bgr, boxes, scores, class_ids):
    """Return an RGB PIL copy of *image_bgr* annotated with the given detections.

    Each detection gets a red box plus a filled red label reading
    ``"<class_id>:<score>"``.  *image_bgr* itself is not modified.

    Args:
        image_bgr: BGR image array (``run`` converts the RGB input to BGR first).
        boxes: per-detection boxes whose first four entries are used as
            x1, y1, x2, y2. Assumed to be pixel corner coordinates;
            TODO confirm against YOLOv7.
        scores: per-detection confidences, parallel to *boxes*.
        class_ids: per-detection class ids, parallel to *boxes*.
    """
    color = (0, 0, 255)  # red, in BGR order
    font, font_scale, thickness = cv2.FONT_HERSHEY_COMPLEX, 0.5, 1
    canvas = image_bgr.copy()
    # zip() yields nothing for empty detections, so no length guard is needed.
    for box, score, class_id in zip(boxes, scores, class_ids):
        x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        label = "{}:{:.2f}".format(class_id, score)
        # Fit the filled label background to the rendered text rather than a
        # fixed 100x20 patch, so longer labels are not clipped.
        (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
        cv2.rectangle(canvas, (x1, y1 - text_h - baseline), (x1 + text_w, y1), color, -1)
        cv2.putText(canvas, label, (x1, y1 - baseline), font, font_scale, (255, 255, 255), thickness)
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))


def run(content_img, conf_threshold, nms_threshold):
    """Run both detectors on one image and return their annotated results.

    Args:
        content_img: RGB image from the Gradio input (anything ``np.array``
            accepts), or ``None`` when the user submits without an image.
        conf_threshold: confidence threshold forwarded to both detectors.
        nms_threshold: NMS threshold forwarded to both detectors.

    Returns:
        ``(img1, img2)``: RGB PIL images annotated with the detections of
        ``infer1`` and ``infer2`` respectively. Returns ``(None, None)`` when
        no image was supplied.
    """
    # Gradio can pass None for an empty image input. np.array(None) would
    # otherwise reach cvtColor and raise.
    if content_img is None:
        return None, None
    # The detectors are fed BGR, OpenCV's native channel order.
    frame_bgr = cv2.cvtColor(np.array(content_img), cv2.COLOR_RGB2BGR)
    img1 = _draw_detections(frame_bgr, *infer1(frame_bgr, conf_threshold, nms_threshold))
    img2 = _draw_detections(frame_bgr, *infer2(frame_bgr, conf_threshold, nms_threshold))
    return img1, img2
if __name__ == '__main__':
    # Side-by-side comparison UI: one input image and two threshold sliders go
    # to run(), which returns one annotated image per detector.
    style = gr.Interface(
        fn=run,
        inputs=[
            # run() treats the input as an RGB array, which is what
            # type="numpy" delivers.
            gr.Image(type="numpy", label='Input Image'),
            # Gradio components take the initial value via `value=`. The old
            # `default=` keyword is no longer accepted and breaks startup.
            gr.Slider(minimum=0.05, maximum=1, step=0.05, label="Confidence Threshold", value=0.2),
            gr.Slider(minimum=0.05, maximum=1, step=0.05, label="NMS Threshold", value=0.5),
        ],
        outputs=[
            gr.Image(
                type="pil",
                label="Finetuned"
            ),
            gr.Image(
                type="pil",
                label="Finetuned + New Data"
            ),
        ],
        # Each example row supplies (image path, confidence, NMS) in input order.
        examples=[
            ['examples/fire1.jpg', 0.2, 0.5],
            ['examples/fire2.jpg', 0.2, 0.5],
            ['examples/fire3.jpg', 0.2, 0.5]
        ]
    )
    style.launch()