Spaces:
Runtime error
Runtime error
File size: 6,005 Bytes
7734d5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import numpy as np
from mmdet.core import bbox2result
from mmdet.models import TwoStageDetector
from qdtrack.core import track2result
from ..builder import MODELS, build_tracker
from qdtrack.core import imshow_tracks, restore_result
from tracker import BYTETracker
@MODELS.register_module()
class QDTrack(TwoStageDetector):
    """Two-stage detector with a track head, tracked with ``BYTETracker``.

    Training delegates to the RPN head and the RoI head (which also receives
    the reference-frame features and proposals). At test time detections and
    track features from the RoI head are handed to a ``BYTETracker``.

    Args:
        tracker (dict, optional): Tracker config. It is stored on
            ``self.tracker_cfg``; ``init_tracker`` currently does not read it.
        freeze_detector (bool): If True, the backbone, neck, RPN head and
            bbox head are switched to eval mode and their parameters stop
            requiring gradients. Defaults to False.
        *args: Forwarded to ``TwoStageDetector.__init__``.
        **kwargs: Forwarded to ``TwoStageDetector.__init__`` after
            ``prepare_cfg`` has had a chance to amend them.
    """

    def __init__(self, tracker=None, freeze_detector=False, *args, **kwargs):
        # The RoI head config must be amended before the parent builds it.
        self.prepare_cfg(kwargs)
        super().__init__(*args, **kwargs)
        self.tracker_cfg = tracker

        self.freeze_detector = freeze_detector
        if self.freeze_detector:
            self._freeze_detector()

    def _freeze_detector(self):
        """Put the detection sub-modules in eval mode and freeze weights."""
        self.detector = [
            self.backbone, self.neck, self.rpn_head, self.roi_head.bbox_head
        ]
        for model in self.detector:
            model.eval()
            for param in model.parameters():
                param.requires_grad = False

    def prepare_cfg(self, kwargs):
        """Copy the ``embed`` training config into the RoI head config.

        Args:
            kwargs (dict): Constructor keyword arguments. When it holds a
                truthy ``train_cfg``, ``kwargs['roi_head']`` is modified in
                place: its ``track_train_cfg`` entry is set to
                ``train_cfg.get('embed', None)``.
        """
        train_cfg = kwargs.get('train_cfg', False)
        if train_cfg:
            kwargs['roi_head']['track_train_cfg'] = train_cfg.get(
                'embed', None)

    def init_tracker(self):
        """Create a fresh tracker, discarding any existing one."""
        # NOTE(review): ``self.tracker_cfg`` is not used here; the original
        # ``build_tracker(self.tracker_cfg)`` call was replaced by a
        # default-constructed BYTETracker.
        self.tracker = BYTETracker()

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_match_indices,
                      ref_img,
                      ref_img_metas,
                      ref_gt_bboxes,
                      ref_gt_labels,
                      ref_gt_match_indices,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      ref_gt_bboxes_ignore=None,
                      ref_gt_masks=None,
                      **kwargs):
        """Compute the training losses for a key/reference image pair.

        The key image goes through the RPN in training mode (producing
        losses and proposals); the reference image only gets test-time RPN
        proposals. Both are then passed to ``self.roi_head.forward_train``.

        ``ref_gt_match_indices`` and ``ref_gt_masks`` are accepted but are
        not forwarded to any head by this method.

        Returns:
            dict: RPN losses merged with the RoI head losses.
        """
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        rpn_losses, proposal_list = self.rpn_head.forward_train(
            x,
            img_metas,
            gt_bboxes,
            gt_labels=None,
            gt_bboxes_ignore=gt_bboxes_ignore,
            proposal_cfg=proposal_cfg)
        losses.update(rpn_losses)

        # Reference frame: proposals only, no RPN loss.
        ref_x = self.extract_feat(ref_img)
        ref_proposals = self.rpn_head.simple_test_rpn(ref_x, ref_img_metas)

        roi_losses = self.roi_head.forward_train(
            x, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_match_indices, ref_x, ref_img_metas, ref_proposals,
            ref_gt_bboxes, ref_gt_labels, gt_bboxes_ignore, gt_masks,
            ref_gt_bboxes_ignore, **kwargs)
        losses.update(roi_losses)

        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Detect and track objects in a single frame.

        Args:
            img: Input image batch passed to ``extract_feat``.
            img_metas (list[dict]): Image metas; ``img_metas[0]`` may carry
                a ``frame_id`` (``-1`` is used when it is missing).
            rescale (bool): Forwarded to ``self.roi_head.simple_test``.
                Defaults to False.

        Returns:
            dict: ``bbox_results`` (from ``bbox2result``) and
            ``track_results`` (from ``track2result``, or one empty
            ``(0, 6)`` float32 array per class when the RoI head returned
            no track features).
        """
        # TODO inherit from a base tracker
        assert self.roi_head.with_track, 'Track head must be implemented.'
        frame_id = img_metas[0].get('frame_id', -1)
        # Reset the tracker at the start of a sequence. Also create it when
        # it does not exist yet, so a first frame whose id is not 0 (or whose
        # meta lacks ``frame_id``) no longer fails with AttributeError.
        if frame_id == 0 or getattr(self, 'tracker', None) is None:
            self.init_tracker()

        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels, track_feats = self.roi_head.simple_test(
            x, img_metas, proposal_list, rescale)

        # NOTE(review): the tracker is updated on every frame, even when
        # ``track_feats`` is None (the output is then discarded below).
        # Presumably this keeps the tracker's internal state advancing;
        # confirm BYTETracker.update accepts ``track_feats=None``.
        bboxes, labels, ids = self.tracker.update(det_bboxes, det_labels,
                                                  frame_id, track_feats)

        num_classes = self.roi_head.bbox_head.num_classes
        bbox_result = bbox2result(det_bboxes, det_labels, num_classes)

        if track_feats is not None:
            track_result = track2result(bboxes, labels, ids, num_classes)
        else:
            track_result = [
                np.zeros((0, 6), dtype=np.float32) for _ in range(num_classes)
            ]
        return dict(bbox_results=bbox_result, track_results=track_result)

    def show_result(self,
                    img,
                    result,
                    thickness=1,
                    font_scale=0.5,
                    show=False,
                    out_file=None,
                    wait_time=0,
                    backend='cv2',
                    **kwargs):
        """Visualize tracking results.

        Args:
            img (str | ndarray): Filename of loaded image.
            result (dict): Tracking result.
                The value of key 'track_results' is ndarray with shape (n, 6)
                in [id, tl_x, tl_y, br_x, br_y, score] format.
                The value of key 'bbox_results' is ndarray with shape (n, 5)
                in [tl_x, tl_y, br_x, br_y, score] format.
                Only 'track_results' is read by this method.
            thickness (int, optional): Thickness of lines. Defaults to 1.
            font_scale (float, optional): Font scales of texts. Defaults
                to 0.5.
            show (bool, optional): Whether show the visualizations on the
                fly. Defaults to False.
            out_file (str | None, optional): Output filename. Defaults to None.
            wait_time (int, optional): Forwarded to ``imshow_tracks``.
                Defaults to 0.
            backend (str, optional): Backend to draw the bounding boxes,
                options are `cv2` and `plt`. Defaults to 'cv2'.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ndarray: Visualized image.
        """
        assert isinstance(result, dict)
        track_result = result.get('track_results', None)
        bboxes, labels, ids = restore_result(track_result, return_ids=True)
        img = imshow_tracks(
            img,
            bboxes,
            labels,
            ids,
            classes=self.CLASSES,
            thickness=thickness,
            font_scale=font_scale,
            show=show,
            out_file=out_file,
            wait_time=wait_time,
            backend=backend)
        return img
|