|
|
|
|
|
import copy |
|
import numpy as np |
|
from typing import Dict |
|
import torch |
|
from scipy.optimize import linear_sum_assignment |
|
|
|
from detectron2.config import configurable |
|
from detectron2.structures import Boxes, Instances |
|
|
|
from ..config.config import CfgNode as CfgNode_ |
|
from .base_tracker import BaseTracker |
|
|
|
|
|
class BaseHungarianTracker(BaseTracker):
    """
    A base class for all Hungarian trackers.

    Subclasses provide :meth:`build_cost_matrix`. :meth:`update` then solves the
    frame-to-frame assignment with ``scipy.optimize.linear_sum_assignment``,
    propagates IDs along matched pairs, allocates fresh IDs for unmatched
    current detections, and carries forward unmatched previous detections that
    pass the filters configured in ``__init__``.
    """

    @configurable
    def __init__(
        self,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        **kwargs
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of IDs allowed to be tracked
                (stored only; this base class does not enforce it)
            max_lost_frame_count: maximum number of consecutive frames an ID may
                go unmatched; past this number the ID is considered lost forever
                and is no longer carried forward
            min_box_rel_dim: a fraction of the frame size; an unmatched previous
                bbox whose width (relative to the video width) or height
                (relative to the video height) is below it is dropped
            min_instance_period: an unmatched previous instance is only carried
                forward if its ``ID_period`` (number of frames in which the ID
                was detected) is greater than this value
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period

    @classmethod
    def from_config(cls, cfg: CfgNode_) -> Dict:
        """Build constructor kwargs from a config; must be provided by subclasses."""
        raise NotImplementedError("Calling HungarianTracker::from_config")

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Build the assignment cost matrix; must be provided by subclasses.

        Args:
            instances: D2 Instances for the current frame
            prev_instances: D2 Instances kept from the previous frame
        Return:
            a 2D array whose rows index ``instances`` and whose columns index
            ``prev_instances`` (this is how :meth:`update` interprets the
            output of ``linear_sum_assignment``)
        """
        raise NotImplementedError("Calling HungarianTracker::build_cost_matrix")

    def update(self, instances: Instances) -> Instances:
        """
        Assign track IDs to the predictions of the current frame.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            the current instances with ``ID``, ``ID_period`` and
            ``lost_frame_count`` filled in. After the first frame, these are
            concatenated with the unmatched previous-frame instances that are
            still being carried forward.
        Raises:
            NotImplementedError: if ``instances`` has ``pred_keypoints``
        """
        if instances.has("pred_keypoints"):
            raise NotImplementedError("Need to add support for keypoints")
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            # Attribute presumably declared by BaseTracker; not read in this
            # class -- kept for compatibility.
            self._untracked_prev_idx = set(range(len(self._prev_instances)))
            cost_matrix = self.build_cost_matrix(instances, self._prev_instances)
            # Row indices refer to current instances, column indices to previous ones.
            matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix)
            instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx)
            instances = self._process_unmatched_idx(instances, matched_idx)
            instances = self._process_unmatched_prev_idx(instances, matched_prev_idx)
        # Deep copy so that later in-place edits by the caller cannot alter the
        # tracker's reference state for the next frame.
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        On the very first frame (no previous instances) every detection starts
        a new track, so IDs are allocated here.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        if not instances.has("ID"):
            instances.set("ID", [None] * len(instances))
        if not instances.has("ID_period"):
            instances.set("ID_period", [None] * len(instances))
        if not instances.has("lost_frame_count"):
            instances.set("lost_frame_count", [None] * len(instances))
        if self._prev_instances is None:
            # Allocate from the running counter (rather than from 0) so IDs can
            # never collide with IDs this tracker has already handed out.
            first_id = self._id_count
            instances.ID = list(range(first_id, first_id + len(instances)))
            self._id_count += len(instances)
            instances.ID_period = [1] * len(instances)
            instances.lost_frame_count = [0] * len(instances)
        return instances

    def _process_matched_idx(
        self, instances: Instances, matched_idx: np.ndarray, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Propagate IDs along matched (current, previous) pairs.

        Each matched current instance inherits the previous ID, has its
        ``ID_period`` incremented and its ``lost_frame_count`` reset to 0.
        """
        assert matched_idx.size == matched_prev_idx.size
        prev = self._prev_instances
        for cur_i, prev_i in zip(matched_idx, matched_prev_idx):
            instances.ID[cur_i] = prev.ID[prev_i]
            instances.ID_period[cur_i] = prev.ID_period[prev_i] + 1
            instances.lost_frame_count[cur_i] = 0
        return instances

    def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances:
        """
        Start a new track for every current instance that was not matched.

        Fresh IDs are taken from the running counter in ascending index order,
        so ID allocation is deterministic.
        """
        matched = {int(i) for i in matched_idx}
        for idx in range(len(instances)):
            if idx in matched:
                continue
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _keep_unmatched_prev(self, box, lost_frame_count: int, id_period: int) -> bool:
        """
        Decide whether an unmatched previous instance is carried forward.

        Args:
            box: ``(x_left, y_top, x_right, y_bot)`` of the previous bbox
            lost_frame_count: frames this ID has already gone unmatched
            id_period: the ID's ``ID_period``
        Return:
            False if the bbox is too small relative to the frame, the ID has
            been lost for too long, or the ID is too young; True otherwise.
        """
        x_left, y_top, x_right, y_bot = box
        if (x_right - x_left) / self._video_width < self._min_box_rel_dim:
            return False
        if (y_bot - y_top) / self._video_height < self._min_box_rel_dim:
            return False
        if lost_frame_count >= self._max_lost_frame_count:
            return False
        return id_period > self._min_instance_period

    def _process_unmatched_prev_idx(
        self, instances: Instances, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Append the unmatched previous-frame instances that are still tracked.

        Survivors (see :meth:`_keep_unmatched_prev`) keep their ID and
        ``ID_period`` and have ``lost_frame_count`` incremented. They are
        concatenated after the current instances in ascending previous-index
        order.

        Args:
            instances: D2 Instances of the current frame (IDs already assigned)
            matched_prev_idx: indices into the previous instances that matched
        Return:
            ``Instances.cat([instances, surviving_previous_instances])``
        """
        prev = self._prev_instances
        has_masks = instances.has("pred_masks")
        matched = {int(i) for i in matched_prev_idx}
        # Plain Python floats: avoids .numpy() and 0-d tensor truthiness.
        prev_boxes = [box.tolist() for box in prev.pred_boxes]

        kept = [
            idx
            for idx in range(len(prev))
            if idx not in matched
            and self._keep_unmatched_prev(
                prev_boxes[idx], prev.lost_frame_count[idx], prev.ID_period[idx]
            )
        ]

        fields = {
            "pred_boxes": Boxes(torch.FloatTensor([prev_boxes[idx] for idx in kept])),
            "pred_classes": torch.IntTensor([int(prev.pred_classes[idx]) for idx in kept]),
            "scores": torch.FloatTensor([float(prev.scores[idx]) for idx in kept]),
            "ID": [prev.ID[idx] for idx in kept],
            "ID_period": [prev.ID_period[idx] for idx in kept],
            "lost_frame_count": [prev.lost_frame_count[idx] + 1 for idx in kept],
        }
        if has_masks:
            prev_masks = prev.pred_masks
            if kept:
                # Stack once in numpy; building a tensor from a list of
                # ndarrays is very slow.
                stacked = np.stack(
                    [prev_masks[idx].numpy().astype(np.uint8) for idx in kept]
                )
                fields["pred_masks"] = torch.as_tensor(stacked, dtype=torch.int32)
            else:
                # Keep trailing dims so concatenation with (N, H, W) masks is well-formed.
                fields["pred_masks"] = torch.zeros(
                    (0,) + tuple(prev_masks.shape[1:]), dtype=torch.int32
                )

        untracked_instances = Instances(image_size=instances.image_size, **fields)
        return Instances.cat([instances, untracked_instances])
|
|