Spaces:
Sleeping
Sleeping
| # Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license | |
| import argparse | |
| import cv2 | |
| import numpy as np | |
| from functools import partial | |
| from pathlib import Path | |
| import torch | |
| from boxmot import TRACKERS | |
| from boxmot.tracker_zoo import create_tracker | |
| from boxmot.utils import ROOT, WEIGHTS, TRACKER_CONFIGS | |
| from boxmot.utils.checks import RequirementsChecker | |
| from tracking.detectors import (get_yolo_inferer, default_imgsz, | |
| is_ultralytics_model, is_yolox_model) | |
| # checker = RequirementsChecker() | |
| # checker.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install | |
| from ultralytics import YOLO | |
| from ultralytics.utils.plotting import Annotator, colors | |
| from ultralytics.data.utils import VID_FORMATS | |
| from ultralytics.utils.plotting import save_one_box | |
def on_predict_start(predictor, persist=False):
    """
    Initialize BoxMOT trackers on the predictor when prediction starts.

    Intended to be registered as an Ultralytics ``on_predict_start`` callback.
    One tracker is created per item of the dataset batch
    (``predictor.dataset.bs``) and the resulting list is assigned to
    ``predictor.trackers``, replacing whatever was stored there before.

    Args:
        predictor (object): The predictor to attach trackers to. It must expose
            ``custom_args`` (the parsed CLI options), ``dataset.bs`` and
            ``device``.
        persist (bool, optional): Accepted for compatibility with the callback
            signature but currently ignored: trackers are always (re)created.
            Defaults to False.

    Raises:
        ValueError: If ``custom_args.tracking_method`` is not one of ``TRACKERS``.
    """
    args = predictor.custom_args
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    if args.tracking_method not in TRACKERS:
        raise ValueError(
            f"'{args.tracking_method}' is not supported. Supported ones are {TRACKERS}"
        )

    tracking_config = TRACKER_CONFIGS / (args.tracking_method + '.yaml')
    trackers = []
    for _ in range(predictor.dataset.bs):
        tracker = create_tracker(
            args.tracking_method,
            tracking_config,
            args.reid_model,
            predictor.device,
            args.half,
            args.per_class,
        )
        # Only trackers that carry a ``model`` (presumably the ReID network)
        # need a warm-up; motion-only trackers have no such attribute.
        if hasattr(tracker, 'model'):
            tracker.model.warmup()
        trackers.append(tracker)
    predictor.trackers = trackers
def _use_custom_detector(yolo, model_path):
    """
    Swap a non-Ultralytics detector into an already-created Ultralytics predictor.

    The inferer built from ``get_yolo_inferer(model_path)`` replaces
    ``yolo.predictor.model``. The predictor's ``preprocess`` and ``postprocess``
    are rerouted through it. An ``on_predict_batch_start`` callback lets it
    update its image paths at the start of every batch.

    Args:
        yolo (YOLO): Model wrapper whose ``predictor`` already exists.
        model_path (Path): Weights of the custom detector.
    """
    inferer_cls = get_yolo_inferer(model_path)
    detector = inferer_cls(model=model_path, device=yolo.predictor.device,
                           args=yolo.predictor.args)
    yolo.predictor.model = detector

    # The callback is invoked with the predictor; the detector records the
    # current batch's image paths for its own pre/post-processing.
    yolo.add_callback("on_predict_batch_start", detector.update_im_paths)
    yolo.predictor.preprocess = lambda imgs: detector.preprocess(im=imgs)
    yolo.predictor.postprocess = (
        lambda preds, im, im0s: detector.postprocess(preds=preds, im=im, im0s=im0s))


def _save_id_crops(result, save_dir):
    """
    Save a crop of every tracked box in one tracking result.

    Crops are written to ``save_dir/crops/<class id>/<track id>/<source stem>.jpg``.
    Boxes that carry no track id are skipped.
    NOTE(review): relies on ``save_one_box`` not overwriting an existing file
    when the same source stem repeats across video frames — confirm.

    Args:
        result (Results): One result yielded by ``YOLO.track``.
        save_dir (Path): Output directory of the predictor run.
    """
    if result.boxes is None:
        return
    stem = Path(result.path).stem
    for box in result.boxes:
        if box.id is None:  # detection not associated with a track
            continue
        crop_file = (Path(save_dir) / 'crops' / str(int(box.cls.item()))
                     / str(int(box.id.item())) / f'{stem}.jpg')
        save_one_box(box.xyxy, result.orig_img, file=crop_file, BGR=True)


def run(args):
    """
    Run detection plus BoxMOT tracking over ``args.source``.

    Ultralytics drives the detection loop. BoxMOT trackers are installed
    through the ``on_predict_start`` callback. When ``args.yolo_model`` is not
    an Ultralytics model, a stock ``yolov8n.pt`` is loaded only so that a
    predictor can be built, and the real detector is swapped in afterwards.
    Optionally shows each annotated frame (press space or ``q`` to stop) and
    saves per-track crops.

    Args:
        args (argparse.Namespace): Options as produced by ``parse_opt``.
    """
    if args.imgsz is None:
        args.imgsz = default_imgsz(args.yolo_model)

    custom_detector = not is_ultralytics_model(args.yolo_model)
    # Placeholder weights for non-Ultralytics detectors; replaced below.
    yolo = YOLO('yolov8n.pt' if custom_detector else args.yolo_model)

    # With stream=True results are produced lazily, which is why the callbacks
    # and the model swap below can still be installed after this call.
    results = yolo.track(
        source=args.source,
        conf=args.conf,
        iou=args.iou,
        agnostic_nms=args.agnostic_nms,
        show=False,
        stream=True,
        device=args.device,
        show_conf=args.show_conf,
        save_txt=args.save_txt,
        show_labels=args.show_labels,
        save=args.save,
        verbose=args.verbose,
        exist_ok=args.exist_ok,
        project=args.project,
        name=args.name,
        classes=args.classes,
        imgsz=args.imgsz,
        vid_stride=args.vid_stride,
        line_width=args.line_width
    )

    yolo.add_callback('on_predict_start', partial(on_predict_start, persist=True))
    if custom_detector:
        _use_custom_detector(yolo, args.yolo_model)
    # on_predict_start reads the CLI options from here.
    yolo.predictor.custom_args = args

    save_crops = getattr(args, 'save_id_crops', False)
    try:
        for r in results:
            if save_crops:
                # Crop before plotting, in case plot_results draws on
                # r.orig_img in place.
                _save_id_crops(r, yolo.predictor.save_dir)

            tracker = yolo.predictor.trackers[0]
            if hasattr(tracker, "plot_results"):
                img = tracker.plot_results(r.orig_img, args.show_trajectories)
            else:
                # Ultralytics Results handles its own image internally
                img = r.plot()

            if args.show:
                cv2.imshow('BoxMOT', img)
                key = cv2.waitKey(1) & 0xFF
                if key in (ord(' '), ord('q')):
                    break
    finally:
        # Release the display window whether the loop finished, was stopped by
        # the user, or raised.
        if args.show:
            cv2.destroyAllWindows()
def parse_opt(argv=None):
    """
    Build the command-line parser and parse the tracking options.

    Args:
        argv (list[str] | None, optional): Arguments to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options, consumed by ``run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo-model', type=Path, default=WEIGHTS / 'yolov8n',
                        help='yolo model path')
    parser.add_argument('--reid-model', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt',
                        help='reid model path')
    parser.add_argument('--tracking-method', type=str, default='deepocsort',
                        help='deepocsort, botsort, strongsort, ocsort, bytetrack, imprassoc, boosttrack')
    parser.add_argument('--source', type=str, default='0',
                        help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=None,
                        help='inference size h,w')
    parser.add_argument('--conf', type=float, default=0.5,
                        help='confidence threshold')
    parser.add_argument('--iou', type=float, default=0.7,
                        help='intersection over union (IoU) threshold for NMS')
    parser.add_argument('--device', default='',
                        help='cuda device, e.g. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show', action='store_true',
                        help='display tracking video results')
    parser.add_argument('--save', action='store_true',
                        help='save video tracking results')
    # Class ids of COCO-trained models: 0 is person, 1 is bicycle, 2 is car ... 79 is toothbrush
    parser.add_argument('--classes', nargs='+', type=int,
                        help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--project', default=ROOT / 'runs' / 'track',
                        help='save results to project/name')
    parser.add_argument('--name', default='exp',
                        help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true',
                        help='existing project/name ok, do not increment')
    parser.add_argument('--half', action='store_true',
                        help='use FP16 half-precision inference')
    parser.add_argument('--vid-stride', type=int, default=1,
                        help='video frame-rate stride')
    # The next two options are store_false: passing them HIDES labels /
    # confidences. The historical --show-* spellings are kept for backward
    # compatibility; the --hide-* aliases say what the flags actually do.
    parser.add_argument('--show-labels', '--hide-labels', dest='show_labels',
                        action='store_false',
                        help='hide class labels, drawing only the bounding boxes')
    parser.add_argument('--show-conf', '--hide-conf', dest='show_conf',
                        action='store_false',
                        help='hide confidence scores')
    parser.add_argument('--show-trajectories', action='store_true',
                        help='draw the trajectory of each track')
    parser.add_argument('--save-txt', action='store_true',
                        help='save tracking results in a txt file')
    parser.add_argument('--save-id-crops', action='store_true',
                        help='save each crop to its respective id folder')
    parser.add_argument('--line-width', default=None, type=int,
                        help='The line width of the bounding boxes. If None, it is scaled to the image size.')
    parser.add_argument('--per-class', default=False, action='store_true',
                        help='not mix up classes when tracking')
    parser.add_argument('--verbose', default=True, action='store_true',
                        help='print results per frame (default: on)')
    # store_true with default=True alone could never switch verbose output off.
    parser.add_argument('--no-verbose', dest='verbose', action='store_false',
                        help='do not print results per frame')
    parser.add_argument('--agnostic-nms', default=False, action='store_true',
                        help='class-agnostic NMS')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, then hand them to the tracking loop.
    opt = parse_opt()
    run(opt)