#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
from detectron2.config import configurable
from detectron2.utils.registry import Registry

from ..config.config import CfgNode as CfgNode_
from ..structures import Instances

TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = """
Registry for tracking classes.
"""


class BaseTracker:
    """
    Common base class for trackers.

    It only sets up the per-tracker bookkeeping attributes; `from_config`
    and `update` raise NotImplementedError here and must be overridden by
    concrete trackers.
    """

    @configurable
    def __init__(self, **kwargs):
        """
        Initialise empty tracking state.

        Args:
            **kwargs: accepted but not used by the base class.
        """
        # counter used to assign new IDs
        self._id_count = 0
        # (D2) Instances of the previous frame; None until one is stored
        self._prev_instances = None
        # indices in _prev_instances that were matched
        self._matched_idx = set()
        # identities in _prev_instances that were matched
        self._matched_ID = set()
        # indices in _prev_instances that were not matched
        self._untracked_prev_idx = set()

    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Translate a config node into constructor arguments.

        Not implemented on the base class.

        Args:
            cfg: D2 CfgNode holding the tracker settings

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Assign IDs to the predictions of the current frame.

        Not implemented on the base class.

        Args:
            predictions: D2 Instances for predictions of the current frame
        Return:
            D2 Instances for predictions of the current frame with ID assigned

        Both _prev_instances and the returned instances are expected to
        carry the following fields:
          .pred_boxes               (shape=[N, 4])
          .scores                   (shape=[N,])
          .pred_classes             (shape=[N,])
          .pred_keypoints           (shape=[N, M, 3], Optional)
          .pred_masks               (shape=List[2D_MASK], Optional)
                                    2D_MASK: shape=[H, W]
          .ID                       (shape=[N,])

        N: # of detected bboxes
        H and W: height and width of 2D mask

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Calling BaseTracker::update")


def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Instantiate the tracker named by `cfg.TRACKER_HEADS.TRACKER_NAME`.

    The name is looked up in TRACKER_HEADS_REGISTRY and the registered
    class is called with `cfg` as its only argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    tracker_cls = TRACKER_HEADS_REGISTRY.get(cfg.TRACKER_HEADS.TRACKER_NAME)
    return tracker_cls(cfg)