# ERAV2_S15 / app.py
# Author: Vasudevakrishna — commit f267ae4: "S15 added."
# (Hugging Face file page metadata: 2.79 kB, "No virus" scan result.)
import cv2, os
import gradio as gr
import numpy as np
import torch
from models.common import DetectMultiBackend
from utils.augmentations import letterbox
from utils.general import non_max_suppression, scale_boxes
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device
# ---------------------------------------------------------------------------
# Inference settings
# ---------------------------------------------------------------------------
img_size = 640  # target size passed to ``letterbox``
stride = 32  # stride passed to ``letterbox``
auto = True  # ``auto`` flag passed to ``letterbox``
max_det = 1000  # passed to ``non_max_suppression`` as ``max_det``
classes = None  # optional class-index filter intended for NMS (None = no filter)
agnostic_nms = False  # class-agnostic NMS disabled
line_thickness = 3  # box line width used when drawing detections
# ---------------------------------------------------------------------------
# Model loading (executed once, at module import, on the CPU)
# ---------------------------------------------------------------------------
weights = "weights/best.pt"  # trained checkpoint to load
data = "data/custom_data.yaml"  # dataset YAML handed to the backend
dnn = False  # backend ``dnn`` flag, left disabled
device = select_device('cpu')
model = DetectMultiBackend(
    weights,
    device=device,
    dnn=dnn,
    data=data,
    fp16=False,  # keep full (fp32) precision
)
def _preprocess(image):
    """Turn an RGB ``HxWx3`` uint8 array into a normalized ``1x3xH'xW'`` model batch.

    The image is letterboxed (resize + pad) with ``letterbox``, transposed
    HWC -> CHW, moved to the model's device, cast to the model's precision,
    scaled to [0, 1] and given a leading batch dimension.
    """
    im = letterbox(image, img_size, stride=stride, auto=auto)[0]  # padded resize
    # HWC -> CHW only, no channel flip: Gradio's ``gr.Image`` already delivers
    # RGB, which is the order the model is fed in YOLO's own pipeline (the
    # BGR->RGB flip there exists only because cv2.imread returns BGR).
    im = np.ascontiguousarray(im.transpose((2, 0, 1)))
    im = torch.from_numpy(im).to(model.device)
    im = im.half() if model.fp16 else im.float()  # uint8 -> fp16/32
    im /= 255  # 0 - 255 -> 0.0 - 1.0
    if im.ndim == 3:
        im = im[None]  # add batch dimension
    return im


def inference(image, iou_thres, conf_thres):
    """Run the detector on one image and return it with detection boxes drawn.

    Args:
        image: RGB image as an ``HxWx3`` uint8 numpy array (what ``gr.Image``
            produces), or ``None`` when the user submitted no image.
        iou_thres: IoU threshold for non-maximum suppression, in [0, 1].
        conf_thres: Minimum confidence for a detection to be kept, in [0, 1].

    Returns:
        The annotated image array, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to detect on; Gradio renders ``None`` as an empty output.
        return None

    im = _preprocess(image)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        pred = model(im, augment=False, visualize=False)
        # When the first output is a list (multi-branch head) element 1 is
        # used, as in the original code — presumably the main branch; TODO confirm.
        pred = pred[0][1] if isinstance(pred[0], list) else pred[0]
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

    annotator = Annotator(image, line_width=line_thickness, example=str({0: 'buffalo'}))
    det = pred[0]  # exactly one image in the batch
    if len(det):
        # Rescale boxes from the letterboxed size back to the original image size.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image.shape).round()
        for *xyxy, _conf, cls in reversed(det):
            c = int(cls)
            # ``image`` is RGB, so ask the palette for RGB (bgr=False) colors.
            annotator.box_label(xyxy, 'buffalo', color=colors(c, False))
    return annotator.result()
title = "YOLO V9 trained on Custom Dataset"
description = "Gradio interface to show yoloV9 object detection."

# File extensions treated as example images for the image input.
_EXAMPLE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff")


def _list_examples(folder="examples"):
    """Return ``[[path], ...]`` for each image file in *folder*, sorted by name.

    ``os.listdir`` returns entries in arbitrary order, so the list is sorted
    to keep the example gallery stable. Non-image entries (e.g. ``.gitkeep``)
    are skipped so they are never fed to the image input. A missing folder
    yields an empty list instead of raising.
    """
    if not os.path.isdir(folder):
        return []
    return [
        [f"{folder}/{name}"]
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(_EXAMPLE_EXTS)
    ]


examples = _list_examples()

demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(height=640, width=640, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="IOU Value"),
        gr.Slider(0, 1, value=0.5, label="Threshold Value"),
    ],
    outputs=[gr.Image(label="YoloV9 Output", height=640, width=640)],
    title=title,
    description=description,
    # Pass None rather than an empty list when no example images exist.
    examples=examples or None,
)

if __name__ == "__main__":
    demo.launch()