usiddiquee
hi
e1832f4
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
import argparse
import cv2
import numpy as np
from functools import partial
from pathlib import Path
import torch
from boxmot import TRACKERS
from boxmot.tracker_zoo import create_tracker
from boxmot.utils import ROOT, WEIGHTS, TRACKER_CONFIGS
from boxmot.utils.checks import RequirementsChecker
from tracking.detectors import (get_yolo_inferer, default_imgsz,
is_ultralytics_model, is_yolox_model)
# checker = RequirementsChecker()
# checker.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors
from ultralytics.data.utils import VID_FORMATS
from ultralytics.utils.plotting import save_one_box
def on_predict_start(predictor, persist=False):
    """
    Create one BoxMOT tracker per batch slot and attach them to the predictor.

    Intended to be registered as an Ultralytics ``on_predict_start`` callback.
    The tracker type, ReID weights, half-precision flag and per-class flag are
    read from ``predictor.custom_args``; the tracker YAML config is looked up
    in ``TRACKER_CONFIGS`` as ``<tracking_method>.yaml``.

    Args:
        predictor (object): Predictor exposing ``custom_args``, ``dataset.bs``
            and ``device``. Its ``trackers`` attribute is (re)assigned.
        persist (bool, optional): Accepted for callback-signature
            compatibility. NOTE: it is not consulted here -- trackers are
            always rebuilt. Defaults to False.

    Raises:
        AssertionError: if the requested tracking method is not in TRACKERS.
    """
    opts = predictor.custom_args
    method = opts.tracking_method
    assert method in TRACKERS, \
        f"'{method}' is not supported. Supported ones are {TRACKERS}"

    config_file = TRACKER_CONFIGS / (method + '.yaml')

    def _make_tracker():
        new_tracker = create_tracker(
            method,
            config_file,
            opts.reid_model,
            predictor.device,
            opts.half,
            opts.per_class,
        )
        # Motion-only trackers carry no appearance model, so only warm up
        # when a ``model`` attribute is present.
        if hasattr(new_tracker, 'model'):
            new_tracker.model.warmup()
        return new_tracker

    # Only assigned once every tracker has been built successfully.
    predictor.trackers = [_make_tracker() for _ in range(predictor.dataset.bs)]
def _install_custom_detector(yolo, args):
    """
    Replace the placeholder Ultralytics model with a non-Ultralytics detector.

    The detector wrapper class comes from ``get_yolo_inferer(args.yolo_model)``.
    The predictor's ``preprocess``/``postprocess`` hooks are redirected to the
    wrapper so its own input/output handling is used.

    Args:
        yolo (YOLO): Ultralytics model whose ``predictor`` must already exist
            (it is created by the preceding ``yolo.track`` call).
        args (argparse.Namespace): parsed CLI options; ``yolo_model`` is read.
    """
    inferer_cls = get_yolo_inferer(args.yolo_model)
    detector = inferer_cls(model=args.yolo_model, device=yolo.predictor.device,
                           args=yolo.predictor.args)
    yolo.predictor.model = detector
    # Let the wrapper record each batch's image paths (presumably needed by
    # its own pre/postprocessing -- see update_im_paths).
    yolo.add_callback(
        "on_predict_batch_start",
        lambda p: detector.update_im_paths(p)
    )
    yolo.predictor.preprocess = lambda imgs: detector.preprocess(im=imgs)
    yolo.predictor.postprocess = (
        lambda preds, im, im0s:
        detector.postprocess(preds=preds, im=im, im0s=im0s))


@torch.no_grad()
def run(args):
    """
    Detect and track objects in ``args.source``, optionally showing the result.

    Builds an Ultralytics ``YOLO`` object, starts a streaming ``track`` run,
    registers ``on_predict_start`` so BoxMOT trackers are created, and, for
    non-Ultralytics detectors, installs the custom model and its
    pre/postprocessing. Frames are then consumed one by one. When
    ``args.show`` is set, each frame is rendered in an OpenCV window, and
    pressing space or ``q`` stops early.

    Args:
        args (argparse.Namespace): options produced by ``parse_opt``.
            ``args.imgsz`` is filled in via ``default_imgsz`` when it is None,
            and ``args`` is attached to the predictor as ``custom_args``.
    """
    if args.imgsz is None:
        args.imgsz = default_imgsz(args.yolo_model)

    ultralytics_model = is_ultralytics_model(args.yolo_model)
    # Non-Ultralytics detectors still need a YOLO object to drive the
    # predict loop; 'yolov8n.pt' is only a placeholder whose model is
    # replaced by _install_custom_detector below.
    yolo = YOLO(args.yolo_model if ultralytics_model else 'yolov8n.pt')

    # With stream=True, track() hands back a generator: no frame is processed
    # until it is iterated, so the callbacks and overrides installed below are
    # in place before the first frame.
    results = yolo.track(
        source=args.source,
        conf=args.conf,
        iou=args.iou,
        agnostic_nms=args.agnostic_nms,
        show=False,
        stream=True,
        device=args.device,
        show_conf=args.show_conf,
        save_txt=args.save_txt,
        show_labels=args.show_labels,
        save=args.save,
        verbose=args.verbose,
        exist_ok=args.exist_ok,
        project=args.project,
        name=args.name,
        classes=args.classes,
        imgsz=args.imgsz,
        vid_stride=args.vid_stride,
        line_width=args.line_width
    )

    yolo.add_callback('on_predict_start', partial(on_predict_start, persist=True))

    if not ultralytics_model:
        _install_custom_detector(yolo, args)

    # on_predict_start reads its configuration from predictor.custom_args.
    yolo.predictor.custom_args = args

    try:
        for r in results:
            if not args.show:
                # Iterating the generator is what runs detection/tracking;
                # an annotated frame is only needed for display.
                continue
            tracker = yolo.predictor.trackers[0]
            if hasattr(tracker, "plot_results"):
                img = tracker.plot_results(r.orig_img, args.show_trajectories)
            else:
                # Ultralytics Results renders its own annotated image.
                img = r.plot()
            cv2.imshow('BoxMOT', img)
            key = cv2.waitKey(1) & 0xFF
            if key in (ord(' '), ord('q')):
                break
    finally:
        # Close the preview window. Guarded so runs that never open a
        # window make no GUI calls.
        if args.show:
            cv2.destroyAllWindows()
def parse_opt():
    """
    Build the tracking CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace: options consumed by ``run`` (and, through
        ``predictor.custom_args``, by ``on_predict_start``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo-model', type=Path, default=WEIGHTS / 'yolov8n',
                        help='yolo model path')
    parser.add_argument('--reid-model', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt',
                        help='reid model path')
    parser.add_argument('--tracking-method', type=str, default='deepocsort',
                        help='deepocsort, botsort, strongsort, ocsort, bytetrack, imprassoc, boosttrack')
    parser.add_argument('--source', type=str, default='0',
                        help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=None,
                        help='inference size h,w')
    parser.add_argument('--conf', type=float, default=0.5,
                        help='confidence threshold')
    parser.add_argument('--iou', type=float, default=0.7,
                        help='intersection over union (IoU) threshold for NMS')
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show', action='store_true',
                        help='display tracking video results')
    parser.add_argument('--save', action='store_true',
                        help='save video tracking results')
    # For COCO-trained models: class 0 is person, 1 is bicycle, 2 is car ... 79 is toothbrush
    parser.add_argument('--classes', nargs='+', type=int,
                        help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--project', default=ROOT / 'runs' / 'track',
                        help='save results to project/name')
    parser.add_argument('--name', default='exp',
                        help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true',
                        help='existing project/name ok, do not increment')
    parser.add_argument('--half', action='store_true',
                        help='use FP16 half-precision inference')
    parser.add_argument('--vid-stride', type=int, default=1,
                        help='video frame-rate stride')
    # These are store_false: passing the flag HIDES labels/confidences.
    # --hide-* is the clearer spelling; --show-* is kept for backward compatibility.
    parser.add_argument('--show-labels', '--hide-labels', dest='show_labels', action='store_false',
                        help='hide class labels on plotted boxes (shown by default)')
    parser.add_argument('--show-conf', '--hide-conf', dest='show_conf', action='store_false',
                        help='hide confidence scores on plotted boxes (shown by default)')
    parser.add_argument('--show-trajectories', action='store_true',
                        help='show track trajectories')
    parser.add_argument('--save-txt', action='store_true',
                        help='save tracking results in a txt file')
    # NOTE(review): parsed but not read by run() in this file.
    parser.add_argument('--save-id-crops', action='store_true',
                        help='save each crop to its respective id folder')
    parser.add_argument('--line-width', default=None, type=int,
                        help='The line width of the bounding boxes. If None, it is scaled to the image size.')
    parser.add_argument('--per-class', default=False, action='store_true',
                        help='do not mix up classes when tracking')
    # Verbose output is on by default; --no-verbose is the way to turn it off
    # (--verbose is kept so existing command lines keep working).
    parser.add_argument('--verbose', default=True, action='store_true',
                        help='print results per frame (default: on)')
    parser.add_argument('--no-verbose', dest='verbose', action='store_false',
                        help='do not print results per frame')
    parser.add_argument('--agnostic-nms', default=False, action='store_true',
                        help='class-agnostic NMS')
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run detection + tracking.
    opt = parse_opt()
    run(opt)