Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright 2004-present Facebook. All Rights Reserved. | |
| from detectron2.config import configurable | |
| from detectron2.utils.registry import Registry | |
| from ..config.config import CfgNode as CfgNode_ | |
| from ..structures import Instances | |
# Tracker classes register themselves here by name; `build_tracker_head`
# looks the configured name up in this registry.
TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS")
TRACKER_HEADS_REGISTRY.__doc__ = "\nRegistry for tracking classes.\n"
class BaseTracker:
    """
    A parent class for all trackers.

    Holds the bookkeeping state shared by trackers and declares the two hooks,
    `from_config` and `update`, which both raise NotImplementedError here and
    are meant to be overridden by concrete trackers.
    """

    # `@configurable` lets the class be constructed either with explicit
    # keyword arguments or with a single config object (as
    # `build_tracker_head` does), in which case the arguments are taken from
    # `from_config`.
    @configurable
    def __init__(self, **kwargs):
        """
        Initialize empty tracking state.

        Args:
            **kwargs: accepted for signature compatibility; not used by the
                base class.
        """
        self._prev_instances = None  # (D2)instances for previous frame
        self._matched_idx = set()  # indices in prev_instances found matching
        self._matched_ID = set()  # identities in prev_instances found matching
        self._untracked_prev_idx = set()  # indices in prev_instances not found matching
        self._id_count = 0  # used to assign new id

    # Must be a classmethod: it is invoked on the class, before any instance
    # exists, to turn a config into constructor arguments.
    @classmethod
    def from_config(cls, cfg: CfgNode_):
        """
        Convert a config into the keyword arguments for `__init__`.

        Args:
            cfg: D2 CfgNode, config file with tracker information
        Raises:
            NotImplementedError: always; the base class defines no mapping.
        """
        raise NotImplementedError("Calling BaseTracker::from_config")

    def update(self, predictions: Instances) -> Instances:
        """
        Assign IDs to the predictions of the current frame.

        Args:
            predictions: D2 Instances for predictions of the current frame
        Return:
            D2 Instances for predictions of the current frame with ID assigned

        _prev_instances and instances will have the following fields:
          .pred_boxes               (shape=[N, 4])
          .scores                   (shape=[N,])
          .pred_classes             (shape=[N,])
          .pred_keypoints           (shape=[N, M, 3], Optional)
          .pred_masks               (shape=List[2D_MASK], Optional)
                                    2D_MASK: shape=[H, W]
          .ID                       (shape=[N,])

        N: # of detected bboxes
        H and W: height and width of 2D mask

        Raises:
            NotImplementedError: always; the base class has no tracking logic.
        """
        raise NotImplementedError("Calling BaseTracker::update")
def build_tracker_head(cfg: CfgNode_) -> BaseTracker:
    """
    Instantiate the tracker named by `cfg.TRACKER_HEADS.TRACKER_NAME`.

    The name is looked up in `TRACKER_HEADS_REGISTRY` and the registered
    class is called with `cfg` as its only argument.

    Args:
        cfg: D2 CfgNode, config file with tracker information
    Return:
        tracker object
    """
    tracker_name = cfg.TRACKER_HEADS.TRACKER_NAME
    return TRACKER_HEADS_REGISTRY.get(tracker_name)(cfg)