#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import copy
import numpy as np
from typing import Dict
import torch
from scipy.optimize import linear_sum_assignment

from detectron2.config import configurable
from detectron2.structures import Boxes, Instances

from ..config.config import CfgNode as CfgNode_
from .base_tracker import BaseTracker


class BaseHungarianTracker(BaseTracker):
    """
    A base class for all Hungarian trackers.

    Subclasses supply ``build_cost_matrix`` (and ``from_config``); this class
    solves the resulting assignment problem with
    ``scipy.optimize.linear_sum_assignment`` and maintains the per-instance
    ``ID``, ``ID_period`` and ``lost_frame_count`` fields across frames.
    """

    @configurable
    def __init__(
        self,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        **kwargs
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of IDs allowed to be tracked
                (stored only; not enforced anywhere in this class)
            max_lost_frame_count: maximum number of frames an ID can lose
                tracking; once exceeded, the ID is considered lost forever
            min_box_rel_dim: a fraction of the frame dimension; a bbox whose
                relative width or height is smaller than this is removed
                from tracking
            min_instance_period: an instance will be shown after this number
                of periods since it first showed up in the video
            **kwargs: forwarded unchanged to the parent class constructor
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period

    @classmethod
    def from_config(cls, cfg: CfgNode_) -> Dict:
        """
        Abstract hook: subclasses must override this.

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Calling BaseHungarianTracker::from_config")

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Abstract hook: subclasses must override this to build the cost matrix
        that ``update`` passes to ``linear_sum_assignment``.

        Args:
            instances: D2 Instances, for predictions of the current frame
            prev_instances: D2 Instances, for predictions of the previous frame
        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Calling BaseHungarianTracker::build_cost_matrix")

    def update(self, instances: Instances) -> Instances:
        """
        Assign tracking fields to the current frame's predictions.

        On the first call (no previous instances stored) the instances only
        get freshly initialized fields. On later calls, the current and
        previous instances are matched through the cost matrix, and the
        matched, unmatched-current and unmatched-previous cases are processed
        in that order.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with ID, ID_period and lost_frame_count populated;
            may also contain previous-frame instances that were carried over
        Raises:
            NotImplementedError: if ``instances`` has a "pred_keypoints" field
        """
        if instances.has("pred_keypoints"):
            raise NotImplementedError("Need to add support for keypoints")
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            self._untracked_prev_idx = set(range(len(self._prev_instances)))
            cost_matrix = self.build_cost_matrix(instances, self._prev_instances)
            # Row indices refer to the first axis of cost_matrix, column
            # indices to the second; they are used below as indices into the
            # current and the previous instances respectively.
            matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix)
            instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx)
            instances = self._process_unmatched_idx(instances, matched_idx)
            instances = self._process_unmatched_prev_idx(instances, matched_prev_idx)
        # Keep an independent snapshot so later mutation of the returned
        # object does not affect the state used for the next frame.
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        When no previous instances are stored, IDs are assigned as
        ``0..len(instances)-1``, ``self._id_count`` is advanced by the number
        of instances, every ID_period is set to 1 and every lost_frame_count
        to 0.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        num_instances = len(instances)
        for field in ("ID", "ID_period", "lost_frame_count"):
            if not instances.has(field):
                instances.set(field, [None] * num_instances)
        if self._prev_instances is None:
            instances.ID = list(range(num_instances))
            self._id_count += num_instances
            instances.ID_period = [1] * num_instances
            instances.lost_frame_count = [0] * num_instances
        return instances

    def _process_matched_idx(
        self, instances: Instances, matched_idx: np.ndarray, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Propagate tracking state from matched previous instances.

        Each matched current instance inherits the ID of its previous
        counterpart, gets that counterpart's ID_period plus one, and has its
        lost_frame_count reset to 0.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_idx: indices into ``instances`` that were matched
            matched_prev_idx: indices into the previous instances, aligned
                element-wise with ``matched_idx``
        Return:
            The same D2 Instances, updated in place
        """
        assert matched_idx.size == matched_prev_idx.size
        prev_instances = self._prev_instances
        for cur_idx, prev_idx in zip(matched_idx, matched_prev_idx):
            instances.ID[cur_idx] = prev_instances.ID[prev_idx]
            instances.ID_period[cur_idx] = prev_instances.ID_period[prev_idx] + 1
            instances.lost_frame_count[cur_idx] = 0
        return instances

    def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances:
        """
        Start a new track for every current instance that was not matched.

        Each such instance receives the next unused ID (``self._id_count``,
        which is then incremented), an ID_period of 1 and a lost_frame_count
        of 0.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_idx: indices into ``instances`` that were matched
        Return:
            The same D2 Instances, updated in place
        """
        untracked_idx = set(range(len(instances))).difference(set(matched_idx))
        for idx in untracked_idx:
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _process_unmatched_prev_idx(
        self, instances: Instances, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Carry over previous-frame instances that found no match in this frame.

        An unmatched previous instance is dropped when its box is too small
        relative to the frame, when it has already been lost for at least
        ``max_lost_frame_count`` frames, or when its ID_period does not exceed
        ``min_instance_period``. Every other unmatched previous instance is
        appended to the output with its lost_frame_count incremented by one.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_prev_idx: indices into the previous instances that were
                matched
        Return:
            D2 Instances: ``instances`` concatenated with the carried-over
            previous instances
        """
        prev_instances = self._prev_instances
        # Whether masks are carried over is decided by the current frame.
        # NOTE(review): assumes the previous instances then have pred_masks
        # as well — confirm against callers.
        has_masks = instances.has("pred_masks")

        # Fields start as plain lists so they can be appended to; the tensor
        # fields are converted once the loop is done.
        untracked_instances = Instances(
            image_size=instances.image_size,
            pred_boxes=[],
            pred_masks=[],
            pred_classes=[],
            scores=[],
            ID=[],
            ID_period=[],
            lost_frame_count=[],
        )
        prev_bboxes = list(prev_instances.pred_boxes)
        prev_classes = list(prev_instances.pred_classes)
        prev_scores = list(prev_instances.scores)
        prev_ID_period = prev_instances.ID_period
        if has_masks:
            prev_masks = list(prev_instances.pred_masks)

        untracked_prev_idx = set(range(len(prev_instances))).difference(set(matched_prev_idx))
        for idx in untracked_prev_idx:
            x_left, y_top, x_right, y_bot = prev_bboxes[idx]
            rel_width = float(x_right - x_left) / self._video_width
            rel_height = float(y_bot - y_top) / self._video_height
            if (
                rel_width < self._min_box_rel_dim
                or rel_height < self._min_box_rel_dim
                or prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count
                or prev_ID_period[idx] <= self._min_instance_period
            ):
                continue
            untracked_instances.pred_boxes.append(prev_bboxes[idx].tolist())
            untracked_instances.pred_classes.append(int(prev_classes[idx]))
            untracked_instances.scores.append(float(prev_scores[idx]))
            untracked_instances.ID.append(prev_instances.ID[idx])
            untracked_instances.ID_period.append(prev_instances.ID_period[idx])
            untracked_instances.lost_frame_count.append(prev_instances.lost_frame_count[idx] + 1)
            if has_masks:
                untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8))

        untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes))
        untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes)
        untracked_instances.scores = torch.FloatTensor(untracked_instances.scores)
        if has_masks:
            # Stack into a single ndarray first: building a tensor directly
            # from a list of ndarrays is very slow. The result is still int32,
            # and an empty list still yields a tensor of shape (0,).
            untracked_instances.pred_masks = torch.as_tensor(
                np.array(untracked_instances.pred_masks), dtype=torch.int32
            )
        else:
            untracked_instances.remove("pred_masks")

        return Instances.cat(
            [
                instances,
                untracked_instances,
            ]
        )