Bhaskar Saranga commited on
Commit
d2228f2
1 Parent(s): ec6d159

cuda validation for v8 predictions

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -125,7 +125,10 @@ def detectv8(img,model,device,iou_threshold=0.45,confidence_threshold=0.25):
125
  results= model.predict(img,conf=confidence_threshold, iou=iou_threshold)
126
  fps_inference = 1/(time.time()-start)
127
 
128
- boxes=results[0].boxes.numpy()
 
 
 
129
  for bbox in boxes:
130
  #print(f'{colors[names[int(bbox.cls[0])]]}')
131
  label = f'{names[int(bbox.cls[0])]} {bbox.conf[0]:.2f}'
 
125
  results= model.predict(img,conf=confidence_threshold, iou=iou_threshold)
126
  fps_inference = 1/(time.time()-start)
127
 
128
+ if torch.cuda.is_available():
129
+ boxes= results[0].boxes.cpu().detach().numpy()
130
+ else:
131
+ boxes=results[0].boxes.numpy()
132
  for bbox in boxes:
133
  #print(f'{colors[names[int(bbox.cls[0])]]}')
134
  label = f'{names[int(bbox.cls[0])]} {bbox.conf[0]:.2f}'