# Page header captured along with this file (not part of the program):
# AK391 | all files | revision 7734d5b | raw | history blame | 6.01 kB
import numpy as np
from mmdet.core import bbox2result
from mmdet.models import TwoStageDetector
from qdtrack.core import track2result
from ..builder import MODELS, build_tracker
from qdtrack.core import imshow_tracks, restore_result
from tracker import BYTETracker
@MODELS.register_module()
class QDTrack(TwoStageDetector):
    """Two-stage detector that also associates detections across frames.

    Detection uses the components inherited from ``TwoStageDetector``
    (backbone, neck, RPN head, RoI head). Association is delegated to a
    ``BYTETracker`` created by :meth:`init_tracker`.

    Args:
        tracker (dict | None): Tracker config, stored on ``self.tracker_cfg``.
            It is currently not used to build the tracker; see
            :meth:`init_tracker`.
        freeze_detector (bool): If True, the detector parts are put in eval
            mode and their gradients are disabled (see
            :meth:`_freeze_detector`).
        *args, **kwargs: Forwarded to ``TwoStageDetector.__init__`` after
            :meth:`prepare_cfg` has updated ``kwargs``.
    """

    def __init__(self, tracker=None, freeze_detector=False, *args, **kwargs):
        # Must run before the parent constructor: it edits the roi_head
        # config inside ``kwargs`` that the parent will consume.
        self.prepare_cfg(kwargs)
        super().__init__(*args, **kwargs)
        self.tracker_cfg = tracker
        self.freeze_detector = freeze_detector
        if self.freeze_detector:
            self._freeze_detector()

    def _freeze_detector(self):
        """Put detector sub-modules in eval mode and stop their gradients.

        Affects the backbone, neck, RPN head and the RoI head's bbox head.
        The list is kept on ``self.detector``.

        NOTE(review): ``nn.Module.train()`` on the parent model would switch
        these sub-modules back to training mode; only ``requires_grad=False``
        survives such a call. Confirm this is acceptable.
        """
        self.detector = [
            self.backbone, self.neck, self.rpn_head, self.roi_head.bbox_head
        ]
        for model in self.detector:
            model.eval()
            for param in model.parameters():
                param.requires_grad = False

    def prepare_cfg(self, kwargs):
        """Inject the tracking-embedding train config into the RoI head config.

        If ``kwargs['train_cfg']`` is truthy, its ``'embed'`` entry (or None)
        is stored as ``kwargs['roi_head']['track_train_cfg']``. ``kwargs`` is
        modified in place.

        Args:
            kwargs (dict): Constructor keyword arguments. Presumably
                ``kwargs['roi_head']`` is a mutable config dict whenever
                ``train_cfg`` is given; verify against the configs.
        """
        if kwargs.get('train_cfg', False):
            kwargs['roi_head']['track_train_cfg'] = kwargs['train_cfg'].get(
                'embed', None)

    def init_tracker(self):
        """(Re)create the tracker, discarding any previous tracker state.

        A ``BYTETracker`` with default arguments is used. ``self.tracker_cfg``
        is deliberately not passed to ``build_tracker`` here.
        """
        self.tracker = BYTETracker()

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_match_indices,
                      ref_img,
                      ref_img_metas,
                      ref_gt_bboxes,
                      ref_gt_labels,
                      ref_gt_match_indices,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      ref_gt_bboxes_ignore=None,
                      ref_gt_masks=None,
                      **kwargs):
        """Compute training losses for a key frame and its reference frame.

        RPN losses are computed on ``img`` only. Proposals for ``ref_img``
        come from ``rpn_head.simple_test_rpn``. Both frames' features,
        proposals and annotations are then passed to
        ``roi_head.forward_train``.

        NOTE(review): ``ref_gt_match_indices`` and ``ref_gt_masks`` are
        accepted but not forwarded to the RoI head. Confirm this is intended.

        Returns:
            dict: RPN losses merged with the RoI-head losses.
        """
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss. If no 'rpn_proposal' train config is set,
        # fall back to the test-time RPN config.
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        rpn_losses, proposal_list = self.rpn_head.forward_train(
            x,
            img_metas,
            gt_bboxes,
            gt_labels=None,  # labels are not given to the RPN head
            gt_bboxes_ignore=gt_bboxes_ignore,
            proposal_cfg=proposal_cfg)
        losses.update(rpn_losses)

        # Reference frame: features and proposals only, no RPN loss.
        ref_x = self.extract_feat(ref_img)
        ref_proposals = self.rpn_head.simple_test_rpn(ref_x, ref_img_metas)

        roi_losses = self.roi_head.forward_train(
            x, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_match_indices, ref_x, ref_img_metas, ref_proposals,
            ref_gt_bboxes, ref_gt_labels, gt_bboxes_ignore, gt_masks,
            ref_gt_bboxes_ignore, **kwargs)
        losses.update(roi_losses)

        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Detect objects in one frame and associate them with tracks.

        Args:
            img (Tensor): Input image batch for the current frame.
            img_metas (list[dict]): Image meta info. ``img_metas[0]`` may
                carry ``'frame_id'``; frame 0 resets the tracker.
            rescale (bool): Passed through to ``roi_head.simple_test``.

        Returns:
            dict:
                ``bbox_results`` holds per-class detections from
                ``bbox2result``. ``track_results`` holds per-class tracks from
                ``track2result``, or one empty ``(0, 6)`` float32 array per
                class when the RoI head returns no track features.
        """
        # TODO inherit from a base tracker
        assert self.roi_head.with_track, 'Track head must be implemented.'
        frame_id = img_metas[0].get('frame_id', -1)
        # Reset at the first frame of a video. Also create a tracker lazily
        # if none exists yet (e.g. 'frame_id' missing from the metas).
        # Previously that case crashed with AttributeError on self.tracker.
        if frame_id == 0 or getattr(self, 'tracker', None) is None:
            self.init_tracker()

        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels, track_feats = self.roi_head.simple_test(
            x, img_metas, proposal_list, rescale)

        # Called on every frame, even when track_feats is None. The result
        # is only used below when track features exist.
        bboxes, labels, ids = self.tracker.update(
            det_bboxes, det_labels, frame_id, track_feats)

        num_classes = self.roi_head.bbox_head.num_classes
        bbox_result = bbox2result(det_bboxes, det_labels, num_classes)

        if track_feats is not None:
            track_result = track2result(bboxes, labels, ids, num_classes)
        else:
            track_result = [
                np.zeros((0, 6), dtype=np.float32) for _ in range(num_classes)
            ]
        return dict(bbox_results=bbox_result, track_results=track_result)

    def show_result(self,
                    img,
                    result,
                    thickness=1,
                    font_scale=0.5,
                    show=False,
                    out_file=None,
                    wait_time=0,
                    backend='cv2',
                    **kwargs):
        """Visualize tracking results.

        Args:
            img (str | ndarray): Filename of loaded image.
            result (dict): Tracking result as returned by ``simple_test``.
                Only the ``'track_results'`` key is drawn. Its value is a
                per-class list of ndarrays of shape (n, 6), each row in
                [id, tl_x, tl_y, br_x, br_y, score] format. The
                ``'bbox_results'`` key (rows in
                [tl_x, tl_y, br_x, br_y, score] format) is not used here.
            thickness (int, optional): Thickness of lines. Defaults to 1.
            font_scale (float, optional): Font scale of texts. Defaults
                to 0.5.
            show (bool, optional): Whether to show the visualization on the
                fly. Defaults to False.
            out_file (str | None, optional): Output filename. Defaults to None.
            wait_time (int, optional): Passed through to ``imshow_tracks``.
                Defaults to 0.
            backend (str, optional): Backend to draw the bounding boxes,
                options are `cv2` and `plt`. Defaults to 'cv2'.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ndarray: Visualized image.
        """
        assert isinstance(result, dict)
        track_result = result.get('track_results', None)
        bboxes, labels, ids = restore_result(track_result, return_ids=True)
        img = imshow_tracks(
            img,
            bboxes,
            labels,
            ids,
            classes=self.CLASSES,
            thickness=thickness,
            font_scale=font_scale,
            show=show,
            out_file=out_file,
            wait_time=wait_time,
            backend=backend)
        return img