AK391
all files
7734d5b
raw history blame
No virus
6.01 kB
import numpy as np
from mmdet.core import bbox2result
from mmdet.models import TwoStageDetector
from qdtrack.core import track2result
from ..builder import MODELS, build_tracker
from qdtrack.core import imshow_tracks, restore_result
from tracker import BYTETracker
@MODELS.register_module()
class QDTrack(TwoStageDetector):
    """Two-stage detector whose ROI head also provides track features.

    At test time the detections of every frame are handed to a
    ``BYTETracker`` instance that is (re)created by :meth:`init_tracker`.

    Args:
        tracker (dict | None): Tracker config. It is stored as
            ``self.tracker_cfg``; note that :meth:`init_tracker` currently
            does not read it.
        freeze_detector (bool): If True, the backbone, neck, RPN head and
            the bbox head of the ROI head are frozen right after
            construction (see :meth:`_freeze_detector`).
        *args, **kwargs: Forwarded to ``TwoStageDetector.__init__`` after
            :meth:`prepare_cfg` has adjusted ``kwargs``.
    """

    def __init__(self, tracker=None, freeze_detector=False, *args, **kwargs):
        # Must run before the parent constructor so that the ROI head is
        # built from the already-adjusted config.
        self.prepare_cfg(kwargs)
        super().__init__(*args, **kwargs)
        self.tracker_cfg = tracker
        self.freeze_detector = freeze_detector
        if self.freeze_detector:
            self._freeze_detector()

    def _freeze_detector(self):
        """Put the detection modules in eval mode and stop their gradients.

        The frozen modules are also kept in ``self.detector`` as a list.
        """
        # NOTE(review): eval mode is only set here, once; presumably a later
        # ``train()`` call switches these modules back -- confirm.
        self.detector = [
            self.backbone, self.neck, self.rpn_head, self.roi_head.bbox_head
        ]
        for model in self.detector:
            model.eval()
            for param in model.parameters():
                param.requires_grad = False

    def prepare_cfg(self, kwargs):
        """Copy the ``embed`` training config into the ROI head config.

        When ``kwargs['train_cfg']`` is present and truthy, its ``'embed'``
        entry (or None) is written to ``kwargs['roi_head']['track_train_cfg']``.
        ``kwargs`` is modified in place.
        """
        if kwargs.get('train_cfg', False):
            kwargs['roi_head']['track_train_cfg'] = kwargs['train_cfg'].get(
                'embed', None)

    def init_tracker(self):
        """Create a fresh ``BYTETracker`` and store it as ``self.tracker``."""
        # NOTE: ``self.tracker_cfg`` is not used here; the config-driven
        # ``build_tracker(self.tracker_cfg)`` path is disabled in favour of
        # a default-constructed BYTETracker.
        self.tracker = BYTETracker()

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_match_indices,
                      ref_img,
                      ref_img_metas,
                      ref_gt_bboxes,
                      ref_gt_labels,
                      ref_gt_match_indices,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      ref_gt_bboxes_ignore=None,
                      ref_gt_masks=None,
                      **kwargs):
        """Compute the training losses for an image / reference-image pair.

        Features are extracted for ``img`` and ``ref_img``. The RPN is
        trained on ``img`` only; proposals for ``ref_img`` come from
        ``rpn_head.simple_test_rpn``. Both feature sets and proposal sets
        are then passed to ``roi_head.forward_train``.

        Args:
            img, img_metas, gt_bboxes, gt_labels, gt_match_indices: Inputs
                and annotations of the current image batch.
            ref_img, ref_img_metas, ref_gt_bboxes, ref_gt_labels: The
                counterparts for the reference image batch.
            ref_gt_match_indices: Accepted but not used by this method.
            gt_bboxes_ignore, gt_masks, ref_gt_bboxes_ignore: Optional
                extras forwarded to the ROI head.
            ref_gt_masks: Accepted but not used by this method.
            **kwargs: Forwarded to ``roi_head.forward_train``.

        Returns:
            dict: The RPN losses merged with the ROI head losses.
        """
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        rpn_losses, proposal_list = self.rpn_head.forward_train(
            x,
            img_metas,
            gt_bboxes,
            gt_labels=None,
            gt_bboxes_ignore=gt_bboxes_ignore,
            proposal_cfg=proposal_cfg)
        losses.update(rpn_losses)

        # Reference image: proposals only, no RPN loss.
        ref_x = self.extract_feat(ref_img)
        ref_proposals = self.rpn_head.simple_test_rpn(ref_x, ref_img_metas)

        roi_losses = self.roi_head.forward_train(
            x, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_match_indices, ref_x, ref_img_metas, ref_proposals,
            ref_gt_bboxes, ref_gt_labels, gt_bboxes_ignore, gt_masks,
            ref_gt_bboxes_ignore, **kwargs)
        losses.update(roi_losses)

        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Detect and track objects in a single frame.

        Args:
            img: Input passed to ``extract_feat``.
            img_metas (list[dict]): Image meta information. ``frame_id`` is
                read from the first entry and defaults to -1 when missing.
            rescale (bool): Forwarded to ``roi_head.simple_test``.

        Returns:
            dict: ``bbox_results`` (output of ``bbox2result``) and
            ``track_results`` (output of ``track2result``, or one empty
            ``(0, 6)`` float32 array per class when the ROI head returned
            no track features).
        """
        # TODO inherit from a base tracker
        assert self.roi_head.with_track, 'Track head must be implemented.'
        frame_id = img_metas[0].get('frame_id', -1)
        # Start a new tracker at the first frame of a sequence. Also create
        # one if none exists yet, so inputs without a ``frame_id`` (or that
        # do not start at 0) do not fail with an AttributeError below.
        if frame_id == 0 or not hasattr(self, 'tracker'):
            self.init_tracker()

        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels, track_feats = self.roi_head.simple_test(
            x, img_metas, proposal_list, rescale)

        # NOTE(review): the tracker is updated even when ``track_feats`` is
        # None, but its output is only used when track features exist --
        # confirm this is intended.
        bboxes, labels, ids = self.tracker.update(det_bboxes, det_labels,
                                                  frame_id, track_feats)

        num_classes = self.roi_head.bbox_head.num_classes
        bbox_result = bbox2result(det_bboxes, det_labels, num_classes)

        if track_feats is not None:
            track_result = track2result(bboxes, labels, ids, num_classes)
        else:
            track_result = [
                np.zeros((0, 6), dtype=np.float32) for _ in range(num_classes)
            ]
        return dict(bbox_results=bbox_result, track_results=track_result)

    def show_result(self,
                    img,
                    result,
                    thickness=1,
                    font_scale=0.5,
                    show=False,
                    out_file=None,
                    wait_time=0,
                    backend='cv2',
                    **kwargs):
        """Visualize tracking results.

        Args:
            img (str | ndarray): Image filename or loaded image.
            result (dict): Tracking result.
                The value of key 'track_results' is ndarray with shape (n, 6)
                in [id, tl_x, tl_y, br_x, br_y, score] format.
                The value of key 'bbox_results' is ndarray with shape (n, 5)
                in [tl_x, tl_y, br_x, br_y, score] format.
            thickness (int, optional): Thickness of lines. Defaults to 1.
            font_scale (float, optional): Font scales of texts. Defaults
                to 0.5.
            show (bool, optional): Whether show the visualizations on the
                fly. Defaults to False.
            out_file (str | None, optional): Output filename. Defaults to None.
            wait_time (int, optional): Forwarded to ``imshow_tracks``.
                Defaults to 0.
            backend (str, optional): Backend to draw the bounding boxes,
                options are `cv2` and `plt`. Defaults to 'cv2'.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ndarray: Visualized image.
        """
        assert isinstance(result, dict)
        # Only the tracking output is drawn; 'bbox_results' is not read.
        track_result = result.get('track_results', None)
        bboxes, labels, ids = restore_result(track_result, return_ids=True)
        img = imshow_tracks(
            img,
            bboxes,
            labels,
            ids,
            classes=self.CLASSES,
            thickness=thickness,
            font_scale=font_scale,
            show=show,
            out_file=out_file,
            wait_time=wait_time,
            backend=backend)
        return img