glenn-jocher committed
Commit 2e53844
1 Parent(s): 39ef6c7

ONNX inference update (#4073)

Files changed (1)
  1. detect.py +32 -22
detect.py CHANGED
@@ -64,18 +64,23 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
     half &= device.type != 'cpu'  # half precision only supported on CUDA
 
     # Load model
-    model = attempt_load(weights, map_location=device)  # load FP32 model
-    stride = int(model.stride.max())  # model stride
+    w = weights[0] if isinstance(weights, list) else weights
+    classify, pt, onnx = False, w.endswith('.pt'), w.endswith('.onnx')  # inference type
+    stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
+    if pt:
+        model = attempt_load(weights, map_location=device)  # load FP32 model
+        stride = int(model.stride.max())  # model stride
+        names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+        if half:
+            model.half()  # to FP16
+        if classify:  # second-stage classifier
+            modelc = load_classifier(name='resnet50', n=2)  # initialize
+            modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
+    elif onnx:
+        check_requirements(('onnx', 'onnxruntime'))
+        import onnxruntime
+        session = onnxruntime.InferenceSession(w, None)
     imgsz = check_img_size(imgsz, s=stride)  # check image size
-    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
-    if half:
-        model.half()  # to FP16
-
-    # Second-stage classifier
-    classify = False
-    if classify:
-        modelc = load_classifier(name='resnet50', n=2)  # initialize
-        modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
 
     # Dataloader
     if webcam:
@@ -89,31 +94,36 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
     vid_path, vid_writer = [None] * bs, [None] * bs
 
     # Run inference
-    if device.type != 'cpu':
+    if pt and device.type != 'cpu':
         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
     t0 = time.time()
     for path, img, im0s, vid_cap in dataset:
-        img = torch.from_numpy(img).to(device)
-        img = img.half() if half else img.float()  # uint8 to fp16/32
+        if pt:
+            img = torch.from_numpy(img).to(device)
+            img = img.half() if half else img.float()  # uint8 to fp16/32
+        elif onnx:
+            img = img.astype('float32')
         img /= 255.0  # 0 - 255 to 0.0 - 1.0
-        if img.ndimension() == 3:
-            img = img.unsqueeze(0)
+        if len(img.shape) == 3:
+            img = img[None]  # expand for batch dim
 
         # Inference
         t1 = time_sync()
-        pred = model(img,
-                     augment=augment,
-                     visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]
+        if pt:
+            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+            pred = model(img, augment=augment, visualize=visualize)[0]
+        elif onnx:
+            pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
 
-        # Apply NMS
+        # NMS
         pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
         t2 = time_sync()
 
-        # Apply Classifier
+        # Second-stage classifier (optional)
         if classify:
             pred = apply_classifier(pred, modelc, img, im0s)
 
-        # Process detections
+        # Process predictions
         for i, det in enumerate(pred):  # detections per image
             if webcam:  # batch_size >= 1
                 p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
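
For reference, the onnx branch added above can be exercised on its own with onnxruntime. A minimal sketch, assuming a yolov5s.onnx export with a 640x640 input (the file name, input shape, and dummy data are illustrative assumptions, not part of this commit):

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('yolov5s.onnx', None)  # load exported ONNX model (assumed to exist)
img = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy BCHW float32 image, values scaled to 0.0 - 1.0
pred = session.run([session.get_outputs()[0].name],  # first output: raw predictions
                   {session.get_inputs()[0].name: img})[0]
print(pred.shape)  # raw predictions; detect.py wraps this in torch.tensor before non_max_suppression

With this change, detect.py accepts either a .pt or a .onnx path for --weights and selects the inference backend from the file suffix.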