"""
Copyright (c) https://github.com/xingyizhou/CenterTrack
Modified by Peize Sun, Rufeng Zhang
"""
# coding: utf-8
import copy

import torch
from scipy.optimize import linear_sum_assignment

from util import box_ops


class Tracker(object):
    """Two-stage box tracker that links per-frame detections by box overlap.

    Detections and tracks are plain dicts with the keys:

    * ``"score"``: detection confidence as a Python float.
    * ``"bbox"``: ``[x1, y1, x2, y2]`` as a list of floats.
    * ``"tracking_id"``: integer identity, assigned from ``id_count``.
    * ``"active"``: 1 if matched to a detection in the latest frame, 0 if the
      track is currently lost.
    * ``"age"``: 1 when matched; incremented once per frame while lost. A
      lost track is dropped once its age reaches ``max_age``.

    Instance state:

    * ``tracks``: everything returned by the latest ``init_track``/``step``.
    * ``tracks_dict``: row index of the latest input -> detection dict that
      holds a tracking id.
    * ``unmatched_tracks``: lost tracks carried over to the next ``step``.
    """

    # Largest assignment cost (1 - GIoU) accepted by each association stage.
    # NOTE(review): if box_ops.generalized_box_iou returns values in [-1, 1]
    # the cost lies in [0, 2]; confirm against the helper's implementation.
    _FIRST_STAGE_MAX_COST = 1.2
    _SECOND_STAGE_MAX_COST = 0.8

    def __init__(self, score_thresh, max_age=32):
        """Create a tracker.

        Args:
            score_thresh: minimum score for a detection to take part in the
                first association stage (and to start a track in
                ``init_track``).
            max_age: a lost track is kept only while its age is below this.
        """
        self.score_thresh = score_thresh
        # Detections scoring in [low_thresh, score_thresh) are only used in
        # the second association stage.
        self.low_thresh = 0.2
        # Unmatched detections must reach this score to start a new track.
        self.high_thresh = score_thresh + 0.1
        self.max_age = max_age
        self.reset_all()

    def reset_all(self):
        """Forget every track and restart id numbering from zero."""
        self.id_count = 0
        self.tracks_dict = dict()
        self.tracks = list()
        self.unmatched_tracks = list()

    @staticmethod
    def _make_obj(scores, bboxes, idx):
        """Build the base detection dict (score + bbox) for row ``idx``."""
        obj = dict()
        obj["score"] = float(scores[idx])
        obj["bbox"] = bboxes[idx, :].cpu().numpy().tolist()
        return obj

    @staticmethod
    def _activate(det, tracking_id):
        """Mark ``det`` as a live track with the given id and return it."""
        det["tracking_id"] = tracking_id
        det["age"] = 1
        det["active"] = 1
        return det

    @staticmethod
    def _associate(dets, tracks, max_cost):
        """Match detections to tracks by minimum total (1 - GIoU) cost.

        Args:
            dets: list of detection dicts (rows of the cost matrix).
            tracks: list of track dicts (columns of the cost matrix).
            max_cost: assigned pairs whose cost exceeds this are rejected.

        Returns:
            ``(matches, unmatched_dets, unmatched_tracks)`` where ``matches``
            is a list of ``(det_index, track_index)`` pairs in assignment
            order. Each unmatched list holds the never-assigned indices in
            ascending order, followed by the indices of rejected pairs in
            assignment order.
        """
        num_dets = len(dets)
        num_tracks = len(tracks)
        if num_dets == 0 or num_tracks == 0:
            return [], list(range(num_dets)), list(range(num_tracks))

        det_box = torch.stack(
            [torch.tensor(obj["bbox"]) for obj in dets], dim=0)  # N x 4
        track_box = torch.stack(
            [torch.tensor(obj["bbox"]) for obj in tracks], dim=0)  # M x 4
        cost = 1.0 - box_ops.generalized_box_iou(det_box, track_box)  # N x M

        row_ind, col_ind = linear_sum_assignment(cost)
        row_ind = row_ind.tolist()
        col_ind = col_ind.tolist()

        # Sets give O(1) membership; the comprehensions keep ascending order.
        assigned_rows = set(row_ind)
        assigned_cols = set(col_ind)
        unmatched_dets = [d for d in range(num_dets) if d not in assigned_rows]
        unmatched_tracks = [
            t for t in range(num_tracks) if t not in assigned_cols]

        matches = []
        for det_idx, track_idx in zip(row_ind, col_ind):
            if cost[det_idx, track_idx] > max_cost:
                # The solver paired them, but the overlap is too poor.
                unmatched_dets.append(det_idx)
                unmatched_tracks.append(track_idx)
            else:
                matches.append((det_idx, track_idx))
        return matches, unmatched_dets, unmatched_tracks

    def init_track(self, results):
        """Start a new track for every detection scoring >= ``score_thresh``.

        Replaces ``tracks`` and ``tracks_dict``; ``id_count`` keeps counting
        and ``unmatched_tracks`` is left untouched.

        Args:
            results: mapping with ``"scores"`` (indexable, exposes
                ``shape[0]``) and ``"boxes"`` (x1y1x2y2 rows, indexed as
                ``boxes[idx, :]``). Other keys, e.g. ``"labels"``, are
                ignored and not required.

        Returns:
            A deep copy of the list of created track dicts.
        """
        scores = results["scores"]
        bboxes = results["boxes"]  # x1y1x2y2

        ret = list()
        ret_dict = dict()
        for idx in range(scores.shape[0]):
            if scores[idx] >= self.score_thresh:
                self.id_count += 1
                obj = self._make_obj(scores, bboxes, idx)
                obj["tracking_id"] = self.id_count
                obj["active"] = 1
                obj["age"] = 1
                ret.append(obj)
                ret_dict[idx] = obj

        self.tracks = ret
        self.tracks_dict = ret_dict
        return copy.deepcopy(ret)

    def step(self, output_results):
        """Associate one frame of detections with the existing tracks.

        Stages, in order:

        1. Detections scoring >= ``score_thresh`` are matched against all
           known tracks (``tracks_dict`` values plus carried-over lost
           tracks).
        2. Detections scoring in [``low_thresh``, ``score_thresh``) are
           matched against the tracks that stage 1 left unmatched and that
           were active in the previous frame.
        3. Stage-1 detections still unmatched start a new track if their
           score is >= ``high_thresh``; otherwise they are discarded.
        4. Unmatched tracks are aged and kept as lost while their age is
           below ``max_age``.

        Args:
            output_results: mapping with ``"scores"`` and ``"boxes"``
                (x1y1x2y2) and optionally ``"track_boxes"`` (x1y1x2y2). When
                ``"track_boxes"`` is given, row ``idx`` overwrites the box of
                the track stored under ``tracks_dict[idx]`` before matching.

        Returns:
            A deep copy of the frame's track list: stage-1 matches, stage-2
            matches, new tracks, then lost tracks.
        """
        scores = output_results["scores"]
        bboxes = output_results["boxes"]  # x1y1x2y2
        track_bboxes = (output_results["track_boxes"]
                        if "track_boxes" in output_results else None)  # x1y1x2y2

        results = list()         # detections for the first stage
        results_second = list()  # lower-score detections for the second stage
        results_dict = dict()    # input row index -> detection dict
        for idx in range(scores.shape[0]):
            if idx in self.tracks_dict and track_bboxes is not None:
                self.tracks_dict[idx]["bbox"] = (
                    track_bboxes[idx, :].cpu().numpy().tolist())

            if scores[idx] >= self.score_thresh:
                bucket = results
            elif scores[idx] >= self.low_thresh:
                bucket = results_second
            else:
                continue
            obj = self._make_obj(scores, bboxes, idx)
            bucket.append(obj)
            results_dict[idx] = obj

        tracks = list(self.tracks_dict.values()) + self.unmatched_tracks
        ret = list()

        # First association: high-score detections vs. every known track.
        matches, unmatched_dets, unmatched_tracks = self._associate(
            results, tracks, self._FIRST_STAGE_MAX_COST)
        for det_idx, track_idx in matches:
            ret.append(self._activate(
                results[det_idx], tracks[track_idx]["tracking_id"]))

        # Second association: low-score detections vs. the tracks that were
        # active last frame but found no first-stage match.
        unmatched_tracks_obj = [
            tracks[i] for i in unmatched_tracks if tracks[i]["active"] == 1]
        matches_second, _, unmatched_tracks_second = self._associate(
            results_second, unmatched_tracks_obj, self._SECOND_STAGE_MAX_COST)
        for det_idx, track_idx in matches_second:
            ret.append(self._activate(
                results_second[det_idx],
                unmatched_tracks_obj[track_idx]["tracking_id"]))

        # Unmatched high-score detections become new tracks only when they
        # clear the stricter high_thresh.
        for i in unmatched_dets:
            det = results[i]
            if det["score"] >= self.high_thresh:
                self.id_count += 1
                ret.append(self._activate(det, self.id_count))

        # Lost-track bookkeeping: already-lost tracks first, then the tracks
        # that just failed the second association.
        lost_candidates = [
            tracks[j] for j in unmatched_tracks if tracks[j]["active"] == 0]
        lost_candidates.extend(
            unmatched_tracks_obj[i] for i in unmatched_tracks_second)

        ret_unmatched_tracks = []
        for track in lost_candidates:
            if track["age"] < self.max_age:
                track["age"] += 1
                track["active"] = 0
                ret.append(track)
                ret_unmatched_tracks.append(track)

        self.tracks = ret
        # Only detections that ended up with an identity are addressable by
        # row index in the next step.
        self.tracks_dict = {
            idx: det for idx, det in results_dict.items()
            if "tracking_id" in det
        }
        self.unmatched_tracks = ret_unmatched_tracks
        return copy.deepcopy(ret)