# ERAV2_S15 / app.py
import os

import gradio as gr
import numpy as np
import torch

from models.common import DetectMultiBackend
from utils.augmentations import letterbox
from utils.general import non_max_suppression, scale_boxes
from utils.plots import Annotator, colors
from utils.torch_utils import select_device
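
# models/ and utils/ come from the YOLOv9 codebase, which this Space is assumed to vendor
# alongside app.py together with data/custom_data.yaml and weights/best.pt.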

# Inference settings
img_size = 640        # model input size (pixels)
stride = 32           # model stride, used by letterbox for padding
auto = True           # minimum-rectangle padding in letterbox
max_det = 1000        # maximum detections per image
classes = None        # optional class-index filter for NMS (None keeps all classes)
agnostic_nms = False  # class-agnostic NMS
line_thickness = 3    # bounding-box line width (pixels)

# Load model
device = select_device('cpu')
dnn = False                     # use OpenCV DNN backend for ONNX inference (irrelevant for .pt weights)
data = "data/custom_data.yaml"  # dataset yaml providing class names
weights = "weights/best.pt"     # trained YOLOv9 checkpoint
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=False)
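
# Optional warm-up forward pass. Assumes DetectMultiBackend exposes warmup() as in the
# upstream YOLO codebase; it may be a no-op on CPU, and the guard makes it safe either way.
if hasattr(model, "warmup"):
    model.warmup(imgsz=(1, 3, img_size, img_size))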


def inference(image, iou_thres=0.5, conf_thres=0.5):
    """Run YOLOv9 detection on an RGB image (H, W, 3 numpy array) and return the annotated image."""
    # Pre-process: letterbox to img_size, HWC -> CHW, scale pixel values to [0, 1]
    im = letterbox(image, img_size, stride=stride, auto=auto)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, reverse channel order (as in the original detect.py)
    im = np.ascontiguousarray(im)  # contiguous memory layout
    im = torch.from_numpy(im).to(model.device)
    im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # add batch dimension

    # Forward pass and non-maximum suppression
    pred = model(im, augment=False, visualize=False)
    pred = pred[0][1] if isinstance(pred[0], list) else pred[0]  # handle YOLOv9's dual-branch output
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
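    # Each tensor in pred has shape (n, 6): [x1, y1, x2, y2, confidence, class],
    # with box coordinates still in the letterboxed model-input space.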

    # Process predictions (non_max_suppression returns one tensor per image; batch size is 1 here)
    for det in pred:  # per image
        annotator = Annotator(image, line_width=line_thickness, example=str({0: 'buffalo'}))
        if len(det):
            # Rescale boxes from the padded model-input size back to the original image size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image.shape).round()
            # Draw one labelled box per detection
            for *xyxy, conf, cls in reversed(det):
                c = int(cls)  # integer class index
                label = 'buffalo'  # single-class model, so the label text is fixed
                annotator.box_label(xyxy, label, color=colors(c, True))

        # Return the annotated image (unchanged if there were no detections)
        im0 = annotator.result()
    return im0
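
# A small helper (a sketch; not used by the Gradio app) showing how inference() can be called
# directly on an image loaded from disk. The example path is hypothetical, and Pillow is assumed
# to be available (Gradio already depends on it).
def _local_sanity_check(path="examples/sample.jpg"):
    from PIL import Image
    img = np.array(Image.open(path).convert("RGB"))  # load as RGB, matching gr.Image input
    annotated = inference(img, iou_thres=0.45, conf_thres=0.25)
    Image.fromarray(annotated).save("annotated.jpg")  # write the annotated copy
    return annotated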

title = "YOLOv9 trained on a custom dataset"
description = "Gradio interface demonstrating YOLOv9 object detection on a custom (buffalo) dataset."
examples = [[f'examples/{i}'] for i in os.listdir("examples")]
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(height=640, width=640, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="IOU Threshold"),
        gr.Slider(0, 1, value=0.5, label="Confidence Threshold"),
    ],
    outputs=[gr.Image(label="YOLOv9 Output", height=640, width=640)],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()