Spaces:
Runtime error
Runtime error
File size: 2,527 Bytes
6748ade |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from .yolov7_utils import prepare_input, process_output
import onnxruntime
import os
class YOLOv7Detector:
    """YOLOv7 object detector backed by an ONNX Runtime inference session.

    Only the ONNX Runtime backend is implemented. ``class_names`` holds the
    80 COCO category names; presumably the class indices produced by
    ``process_output`` index into this list — TODO confirm for custom weights.
    """

    # 80 COCO category names (assumed to match the model's class index
    # order — verify if loading non-stock weights).
    COCO_CLASS_NAMES = (
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
        'hair drier', 'toothbrush',
    )

    def __init__(self,
                 weights=os.path.join(os.path.dirname(
                     os.path.abspath(__file__)), './weights/yolov7.onnx'),
                 use_cuda=True, use_onnx=True) -> None:
        """Load the ONNX model and create the inference session.

        Args:
            weights: Path to the ``.onnx`` model file. Defaults to
                ``weights/yolov7.onnx`` next to this module.
            use_cuda: Prefer the CUDA execution provider when this
                onnxruntime build offers it; CPU is always kept as fallback.
            use_onnx: Must be True; no other backend is implemented.

        Raises:
            NotImplementedError: If ``use_onnx`` is False.
        """
        if not use_onnx:
            # Fail fast instead of leaving ``self.model`` unset, which would
            # otherwise only surface as an AttributeError inside ``detect``.
            raise NotImplementedError(
                "Only the ONNX Runtime backend is implemented; "
                "use_onnx=False is not supported.")
        providers = self._select_providers(
            use_cuda, onnxruntime.get_available_providers())
        self.model = onnxruntime.InferenceSession(weights, providers=providers)
        self.class_names = list(self.COCO_CLASS_NAMES)
        # Report the device the session actually ended up on: ONNX Runtime
        # may fall back to CPU even when CUDA was requested.
        active = self.model.get_providers()
        self.device = 'cuda' if 'CUDAExecutionProvider' in active else 'cpu'

    @staticmethod
    def _select_providers(use_cuda, available):
        """Return the ONNX Runtime provider list to request, CPU always last.

        ``CUDAExecutionProvider`` is only requested when ``use_cuda`` is true
        and the provider appears in ``available``.
        """
        providers = ['CPUExecutionProvider']
        if use_cuda and 'CUDAExecutionProvider' in available:
            providers.insert(0, 'CUDAExecutionProvider')
        return providers

    def detect(self, image,
               conf_thres: float = 0.25,
               iou_thres: float = 0.45,
               classes: "list | None" = None,
               agnostic_nms: bool = False,
               input_shape=(640, 640),
               max_det: int = 1000) -> tuple:
        """Run detection on a single image.

        Args:
            image: Image array whose ``.shape`` starts with
                ``(height, width)``; presumably an HxWxC numpy array —
                TODO confirm the layout ``prepare_input`` expects.
            conf_thres: Confidence threshold forwarded to ``process_output``.
            iou_thres: IoU threshold forwarded to ``process_output``.
            classes: Optional class filter forwarded to ``process_output``.
            agnostic_nms: Accepted for API compatibility; currently ignored.
            input_shape: Network input size passed to ``prepare_input`` and
                ``process_output``.
            max_det: Accepted for API compatibility; currently ignored.

        Returns:
            ``(dets, image_info)``: ``dets`` is whatever ``process_output``
            returns, and ``image_info`` is a dict holding the original
            image's ``'width'`` and ``'height'``.
        """
        # Capture the original size before preprocessing instead of copying
        # the whole frame just to read its shape afterwards.
        image_hw = image.shape[:2]
        input_tensor = prepare_input(image, input_shape)
        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name
        outputs = self.model.run([output_name], {input_name: input_tensor})
        # NOTE(review): agnostic_nms and max_det are not forwarded; check
        # whether process_output supports them before wiring them through.
        dets = process_output(
            outputs, image_hw, input_shape, conf_thres, iou_thres,
            classes=classes)
        image_info = {
            'width': image_hw[1],
            'height': image_hw[0],
        }
        return dets, image_info
|