Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import List, Optional

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import bbox_overlaps

from .base_tracker import BaseTracker
# NOTE(review): ``MODELS`` is imported by this file but no
# ``@MODELS.register_module()`` decorator is visible on this class; if the
# tracker is meant to be built from a config, confirm the registration was
# not lost.
class MaskTrackRCNNTracker(BaseTracker):
    """Tracker for MaskTrack R-CNN.

    Each current detection is matched either to one of the tracks kept in
    the tracker memory or to a dummy "new object" entry. The match score
    combines the track head's similarity logits with the detection score,
    the box IoU and label agreement.

    Args:
        match_weights (dict[str, float], optional): Weighting factors used
            when computing the match score. Recognised keys:

            - det_score (float): Coefficient of the log detection score.
            - iou (float): Coefficient of the box IoU.
            - det_label (float): Coefficient of the label-agreement term.

            Omitted keys fall back to ``det_score=1.0``, ``iou=2.0`` and
            ``det_label=10.0``. Defaults to None (use all defaults).
        **kwargs: Forwarded to ``BaseTracker.__init__``.
    """

    def __init__(self, match_weights: Optional[dict] = None, **kwargs):
        super().__init__(**kwargs)
        # Build a fresh dict per instance: a dict literal used as the default
        # argument would be a single object shared by every tracker.
        weights = dict(det_score=1.0, iou=2.0, det_label=10.0)
        if match_weights is not None:
            weights.update(match_weights)
        self.match_weights = weights

    def get_match_score(self, bboxes: Tensor, labels: Tensor, scores: Tensor,
                        prev_bboxes: Tensor, prev_labels: Tensor,
                        similarity_logits: Tensor) -> Tensor:
        """Get the match score between current and previous detections.

        Column 0 of the result is the dummy "new object" entry (IoU term 0,
        label term 1); column ``j + 1`` corresponds to ``prev_bboxes[j]``.

        Args:
            bboxes (torch.Tensor): of shape (num_current_bboxes, 4) in
                [tl_x, tl_y, br_x, br_y] format. Denoting the detection
                bboxes of current frame.
            labels (torch.Tensor): of shape (num_current_bboxes, )
            scores (torch.Tensor): of shape (num_current_bboxes, )
            prev_bboxes (torch.Tensor): of shape (num_previous_bboxes, 4) in
                [tl_x, tl_y, br_x, br_y] format. Denoting the detection bboxes
                of previous frame.
            prev_labels (torch.Tensor): of shape (num_previous_bboxes, )
            similarity_logits (torch.Tensor): of shape (num_current_bboxes,
                num_previous_bboxes + 1). Denoting the similarity logits from
                track head.

        Returns:
            torch.Tensor: The matching score of shape (num_current_bboxes,
            num_previous_bboxes + 1).
        """
        # log_softmax is the numerically stable form of softmax().log(); the
        # latter yields -inf as soon as a probability underflows to zero.
        log_similarity = similarity_logits.log_softmax(dim=1)

        # The dummy "new object" column has zero IoU with the detection.
        ious = bbox_overlaps(bboxes, prev_bboxes)
        ious = torch.cat((ious.new_zeros(ious.shape[0], 1), ious), dim=1)

        # 1.0 where the current and previous labels agree; the dummy column
        # always counts as agreeing.
        label_deltas = (labels.view(-1, 1) == prev_labels).float()
        label_deltas = torch.cat(
            (label_deltas.new_ones(label_deltas.shape[0], 1), label_deltas),
            dim=1)

        weights = self.match_weights
        # Summed out of place so the log_softmax output is never modified
        # in place.
        return (log_similarity +
                weights['det_score'] * scores.view(-1, 1).log() +
                weights['iou'] * ious +
                weights['det_label'] * label_deltas)

    def assign_ids(self, match_scores: Tensor):
        """Assign track ids to the current detections.

        Every detection picks the column with the highest match score.
        Detections picking column 0 start new tracks and receive ids
        ``self.num_tracks, self.num_tracks + 1, ...`` (``self.num_tracks``
        is advanced accordingly). Detections picking column ``j + 1`` compete
        for the previous track ``self.ids[j]``: only the detection with the
        strictly highest score keeps that id, every other candidate gets -1,
        independent of detection order.

        Args:
            match_scores (Tensor): of shape (num_current_bboxes,
                num_previous_bboxes + 1), e.g. from ``get_match_score``.

        Returns:
            tuple[Tensor, Tensor]: ``ids`` of shape (num_current_bboxes, ),
            with -1 for detections that lost their match, and
            ``best_match_scores`` of shape (num_previous_bboxes, ) holding
            the winning score per previous track (-1e6 if unmatched).
        """
        num_prev_bboxes = match_scores.shape[1] - 1
        _, match_ids = match_scores.max(dim=1)
        ids = match_ids.new_full((match_ids.shape[0], ), -1)
        best_match_scores = match_scores.new_full((num_prev_bboxes, ), -1e6)
        # Index of the current detection holding each previous track, so the
        # holder can be released when a better candidate shows up later.
        holders = [-1] * num_prev_bboxes
        for idx, match_id in enumerate(match_ids.tolist()):
            if match_id == 0:
                ids[idx] = self.num_tracks
                self.num_tracks += 1
                continue
            prev_idx = match_id - 1
            match_score = match_scores[idx, match_id]
            if match_score > best_match_scores[prev_idx]:
                # Several detections may choose the same previous object;
                # release the weaker earlier holder so ids stay unique.
                if holders[prev_idx] >= 0:
                    ids[holders[prev_idx]] = -1
                ids[idx] = self.ids[prev_idx]
                best_match_scores[prev_idx] = match_score
                holders[prev_idx] = idx
        return ids, best_match_scores

    def track(self,
              model: torch.nn.Module,
              feats: List[torch.Tensor],
              data_sample: DetDataSample,
              rescale=True,
              **kwargs) -> InstanceData:
        """Tracking forward function.

        Args:
            model (nn.Module): VIS model. Its ``track_head`` must provide
                ``extract_roi_feats`` and ``predict``.
            feats (list[Tensor]): Multi level feature maps of the input
                image.
            data_sample (:obj:`DetDataSample`): The data sample. Its
                ``pred_instances`` must contain ``bboxes``, ``masks``,
                ``labels`` and ``scores``. ``metainfo['frame_id']`` is read
                if present (default -1), and ``metainfo['scale_factor']`` is
                required when ``rescale`` is True and detections exist.
            rescale (bool, optional): If True, ``bboxes`` are multiplied by
                ``scale_factor`` before RoI feature extraction. The returned
                boxes are never rescaled. Defaults to True.
            **kwargs: Unused here.

        Returns:
            :obj:`InstanceData`: Tracking results with ``bboxes``, ``masks``,
            ``labels``, ``scores`` and ``instances_id``, restricted to the
            detections that received a valid (non-negative) id. With no
            detections, a clone of ``pred_instances`` with all-zero
            ``instances_id`` is returned and the tracker memory is untouched.
        """
        metainfo = data_sample.metainfo
        pred_instances = data_sample.pred_instances
        bboxes = pred_instances.bboxes
        masks = pred_instances.masks
        labels = pred_instances.labels
        scores = pred_instances.scores
        frame_id = metainfo.get('frame_id', -1)

        if bboxes.shape[0] == 0:
            pred_track_instances = pred_instances.clone()
            pred_track_instances.instances_id = torch.zeros_like(labels)
            return pred_track_instances

        # Presumably maps original-image boxes back to the scale of ``feats``
        # (TODO confirm); only the copy used for RoI extraction is scaled.
        rescaled_bboxes = bboxes.clone()
        if rescale:
            scale_factor = rescaled_bboxes.new_tensor(
                metainfo['scale_factor']).repeat((1, 2))
            rescaled_bboxes = rescaled_bboxes * scale_factor
        roi_feats, _ = model.track_head.extract_roi_feats(
            feats, [rescaled_bboxes])

        if self.empty:
            # No tracks in memory: every detection starts a new track. The
            # ids are created on the detections' device, as they are in the
            # matching branch below.
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(
                self.num_tracks,
                self.num_tracks + num_new_tracks,
                dtype=torch.long,
                device=bboxes.device)
            self.num_tracks += num_new_tracks
        else:
            prev_bboxes = self.get('bboxes')
            prev_labels = self.get('labels')
            prev_roi_feats = self.get('roi_feats')
            similarity_logits = model.track_head.predict(
                roi_feats, prev_roi_feats)
            match_scores = self.get_match_score(bboxes, labels, scores,
                                                prev_bboxes, prev_labels,
                                                similarity_logits)
            ids, _ = self.assign_ids(match_scores)

        # Drop detections that lost their match (id -1).
        valid_inds = ids > -1
        ids = ids[valid_inds]
        bboxes = bboxes[valid_inds]
        labels = labels[valid_inds]
        scores = scores[valid_inds]
        masks = masks[valid_inds]
        roi_feats = roi_feats[valid_inds]

        self.update(
            ids=ids,
            bboxes=bboxes,
            labels=labels,
            scores=scores,
            masks=masks,
            roi_feats=roi_feats,
            frame_ids=frame_id)

        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.masks = masks
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids
        return pred_track_instances