ntsc207 commited on
Commit
1935572
1 Parent(s): 8e8faee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -33,7 +33,7 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
33
  img.save(img_path)
34
  input_path = img_path
35
 
36
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
37
  elif vid_path is not None:
38
  vid_name = 'output.mp4'
39
 
@@ -70,12 +70,12 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
70
  out.release()
71
  input_path = vid_name
72
  if tracking_algorithm == 'deep_sort':
73
- output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', draw_trails=True)
74
  elif tracking_algorithm == 'strong_sort':
75
- device_strongsort = torch.device('cpu')
76
  output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
77
  else:
78
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
79
  # Assuming output_path is the path to the output file
80
  _, output_extension = os.path.splitext(output_path)
81
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}
 
33
  img.save(img_path)
34
  input_path = img_path
35
 
36
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
37
  elif vid_path is not None:
38
  vid_name = 'output.mp4'
39
 
 
70
  out.release()
71
  input_path = vid_name
72
  if tracking_algorithm == 'deep_sort':
73
+ output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
74
  elif tracking_algorithm == 'strong_sort':
75
+ device_strongsort = torch.device('cuda:0')
76
  output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
77
  else:
78
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
79
  # Assuming output_path is the path to the output file
80
  _, output_extension = os.path.splitext(output_path)
81
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}