Spaces:
Sleeping
Sleeping
import cv2, os | |
import gradio as gr | |
import numpy as np | |
import torch | |
from models.common import DetectMultiBackend | |
from utils.augmentations import letterbox | |
from utils.general import non_max_suppression, scale_boxes | |
from utils.plots import Annotator, colors, save_one_box | |
from utils.torch_utils import select_device | |
# ---- Inference settings (read by inference() below) ----
# Passed to letterbox() as the target size, stride multiple and `auto` flag.
img_size, stride, auto = 640, 32, True
max_det = 1000  # passed to non_max_suppression() as max_det
classes = None  # class filter setting for NMS; None here
agnostic_nms = False  # passed to non_max_suppression() as the agnostic flag
line_thickness = 3  # passed to Annotator() as line_width

# ---- Model: loaded once at import time, on CPU, with fp16 disabled ----
device = select_device('cpu')
dnn = False  # presumably the OpenCV-DNN backend switch of DetectMultiBackend — confirm
data, weights = "data/custom_data.yaml", "weights/best.pt"
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=False)
def _preprocess(image):
    """Letterbox ``image`` and turn it into a normalized 1xCxHxW tensor on the model's device."""
    im = letterbox(image, img_size, stride=stride, auto=auto)[0]  # padded resize
    # HWC -> CHW only: the Gradio image is already RGB, which is what the model
    # expects, so no channel reversal is done (upstream detect.py reverses
    # because it reads BGR with OpenCV).
    im = np.ascontiguousarray(im.transpose((2, 0, 1)))
    im = torch.from_numpy(im).to(model.device)
    im = im.half() if model.fp16 else im.float()  # uint8 -> fp16/32
    im /= 255  # 0 - 255 -> 0.0 - 1.0
    return im[None]  # add batch dimension


def _draw_detections(image, det, model_shape):
    """Return a copy of ``image`` with each detection in ``det`` drawn as a labelled box.

    ``det`` holds one image's NMS output, rows of (x1, y1, x2, y2, conf, cls),
    in the coordinates of the letterboxed tensor whose spatial size is ``model_shape``.
    """
    annotated = image.copy()  # draw on a copy so the caller's array is left untouched
    annotator = Annotator(annotated, line_width=line_thickness, example=str({0: 'buffalo'}))
    if len(det):
        # Rescale boxes from the letterboxed model input back to the original image size.
        det[:, :4] = scale_boxes(model_shape, det[:, :4], annotated.shape).round()
        for *xyxy, conf, cls in reversed(det):
            # Every class is labelled 'buffalo' (kept from the original; single-class model assumed).
            # bgr=False: the image is RGB, so request RGB colour tuples.
            annotator.box_label(xyxy, 'buffalo', color=colors(int(cls), False))
    return annotator.result()


def inference(image, iou_thres=0.5, conf_thres=0.5):
    """Detect objects in ``image`` and return an annotated copy.

    Args:
        image: Image array from the ``gr.Image`` input component (Gradio's
            default numpy output, assumed RGB HxWx3 uint8), or ``None`` when
            the user submits without uploading anything.
        iou_thres: IoU threshold used by non-maximum suppression.
        conf_thres: Minimum confidence for a detection to be kept.

    Returns:
        A new annotated image array (``image`` itself is not modified), or
        ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    im = _preprocess(image)
    with torch.no_grad():  # pure inference: skip autograd bookkeeping
        pred = model(im, augment=False, visualize=False)
    # Dual-output checkpoints return a list as the first element; index 1 is
    # used, as in the original (presumably the main branch — confirm).
    pred = pred[0][1] if isinstance(pred[0], list) else pred[0]
    # Batch size is 1, so take the single per-image result.
    det = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
    return _draw_detections(image, det, im.shape[2:])
title = "YOLO V9 trained on Custom Dataset"
description = "Gradio interface to show yoloV9 object detection."

# Every regular, non-hidden file in ./examples becomes a clickable example.
# Sorted so the example order is stable (os.listdir order is arbitrary), and
# dotfiles such as .DS_Store/.gitkeep or subdirectories are not offered as images.
examples_dir = "examples"
examples = [
    [os.path.join(examples_dir, name)]
    for name in sorted(os.listdir(examples_dir))
    if not name.startswith(".") and os.path.isfile(os.path.join(examples_dir, name))
]

# Input order must match inference(image, iou_thres, conf_thres).
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(height=640, width=640, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="IOU Value"),
        gr.Slider(0, 1, value=0.5, label="Threshold Value"),
    ],
    outputs=[gr.Image(label="YoloV9 Output", height=640, width=640)],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()