# NOTE: page-status residue captured with this file ("Spaces: Runtime error"); not part of the program.
import numpy as np | |
import torchvision | |
import time | |
import math | |
import os | |
import copy | |
import pdb | |
import argparse | |
import sys | |
import cv2 | |
import skimage.io | |
import skimage.transform | |
import skimage.color | |
import skimage | |
import torch | |
import model | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import datasets, models, transforms | |
from dataloader import CSVDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, UnNormalizer, Normalizer, RGB_MEAN, RGB_STD | |
from scipy.optimize import linear_sum_assignment | |
from tracker import BYTETracker | |
def write_results(filename, results):
    """Write tracking results, including scores, to a comma-separated text file.

    One row is written per track in the form
    ``frame,id,x1,y1,w,h,score,-1,-1,-1``. Box values are rounded to one
    decimal place and the score to two. Entries whose track id is negative
    are skipped.

    Args:
        filename: Path of the file to create (an existing file is overwritten).
        results: Iterable of ``(frame_id, tlwhs, track_ids, scores)`` tuples.
            ``tlwhs``, ``track_ids`` and ``scores`` are parallel sequences and
            every ``tlwh`` unpacks to ``(x1, y1, w, h)``.
    """
    save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids, scores in results:
            for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                # Negative ids are not written out.
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(
                    frame=frame_id,
                    id=track_id,
                    x1=round(x1, 1),
                    y1=round(y1, 1),
                    w=round(w, 1),
                    h=round(h, 1),
                    s=round(score, 2),
                )
                f.write(line)
def write_results_no_score(filename, results):
    """Write tracking results without scores to a comma-separated text file.

    One row is written per track in the form
    ``frame,id,x1,y1,w,h,-1,-1,-1,-1``. Box values are rounded to one decimal
    place. Entries whose track id is negative are skipped.

    Args:
        filename: Path of the file to create (an existing file is overwritten).
        results: Iterable of ``(frame_id, tlwhs, track_ids)`` tuples.
            ``tlwhs`` and ``track_ids`` are parallel sequences and every
            ``tlwh`` unpacks to ``(x1, y1, w, h)``.
    """
    save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids in results:
            for tlwh, track_id in zip(tlwhs, track_ids):
                # Negative ids are not written out.
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                line = save_format.format(
                    frame=frame_id,
                    id=track_id,
                    x1=round(x1, 1),
                    y1=round(y1, 1),
                    w=round(w, 1),
                    h=round(h, 1),
                )
                f.write(line)
def _load_network_input(image_path):
    """Read one image and return it as a normalized ``(1, 3, H, W)`` tensor.

    The image is zero-padded on the bottom/right so that both sides become
    multiples of 32, scaled to ``[0, 1]``, then normalized with ``RGB_MEAN``
    and ``RGB_STD``. Padding happens before normalization, so the padded
    area is normalized as well.
    """
    frame = skimage.io.imread(image_path)
    img_h, img_w, _ = frame.shape
    # NOTE(review): the multiple-of-32 padding is presumably required by the
    # network's downsampling — confirm against the model definition.
    padded_h = math.ceil(img_h / 32) * 32
    padded_w = math.ceil(img_w / 32) * 32
    padded = np.zeros((padded_h, padded_w, 3), dtype=frame.dtype)
    padded[:img_h, :img_w, :] = frame
    normalized = (padded.astype(np.float32) / 255.0 - np.array([[RGB_MEAN]])) / np.array([[RGB_STD]])
    return torch.from_numpy(normalized).permute(2, 0, 1).view(1, 3, padded_h, padded_w)


def run_each_dataset(model_dir, retinanet, dataset_path, subset, cur_dataset):
    """Track the second half of one image sequence and write the results file.

    Images are read from ``<dataset_path>/<subset>/<cur_dataset>/img1``. The
    loop starts at the middle of the sorted image list; the first iteration
    only runs the network to obtain ``last_feat`` and records nothing. Every
    later iteration feeds the detections scoring above 0.1 to a
    ``BYTETracker`` and stores the tracker output under frame id ``idx``.
    The loop runs one step past the last image index; that final step
    re-reads the last image.

    Results are written to ``<model_dir>/results/<cur_dataset>.txt``.

    Args:
        model_dir: Directory that holds (or will hold) the ``results`` folder.
        retinanet: Callable network; invoked as
            ``retinanet(image, last_feat=...)`` and expected to return
            ``(scores, transformed_anchors, last_feat)``. Inputs are moved
            to CUDA before the call.
        dataset_path: Root folder of the dataset.
        subset: Sub-folder of ``dataset_path`` (e.g. ``'train'``).
        cur_dataset: Sequence folder name; also used as the output file stem.

    Raises:
        FileNotFoundError: If the image folder contains no jpg/png files.
    """
    print(cur_dataset)

    img_dir = os.path.join(dataset_path, subset, cur_dataset, 'img1')
    img_list = sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if ('jpg' in name) or ('png' in name)
    )
    img_len = len(img_list)
    if img_len == 0:
        # Without this guard the loop below would index an empty list.
        raise FileNotFoundError('no jpg/png images found in ' + img_dir)

    start_idx = int(img_len / 2)
    last_feat = None
    results = []
    tracker = BYTETracker()

    with torch.no_grad():
        for idx in range(start_idx, img_len + 1):
            print('tracking: ', idx - 1)
            # Clamp so the extra final iteration re-uses the last image.
            img1 = _load_network_input(img_list[min(idx, img_len - 1)])
            scores, transformed_anchors, last_feat = retinanet(img1.cuda().float(), last_feat=last_feat)

            if idx <= start_idx:
                # First pass only primes last_feat; nothing is tracked yet.
                continue

            idxs = np.where(scores > 0.1)
            # run tracking
            online_targets = tracker.update(transformed_anchors[idxs[0], :4], scores[idxs[0]])
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                online_tlwhs.append(t.tlwh)
                online_ids.append(t.track_id)
                online_scores.append(t.score)
            results.append((idx, online_tlwhs, online_ids, online_scores))

    results_dir = os.path.join(model_dir, 'results')
    # Make the function usable on its own, not only after main() has created the folder.
    os.makedirs(results_dir, exist_ok=True)
    fout_tracking = os.path.join(results_dir, cur_dataset + '.txt')
    write_results(fout_tracking, results)
def _strip_module_prefix(state_dict):
    """Rename, in place, every key that starts with ``'module.'`` to drop that prefix.

    Keys without the prefix are left untouched. The same mapping object is
    returned so any attributes attached to it are preserved.
    """
    prefix = 'module.'
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    return state_dict


def main(args=None):
    """Load the trained network and run tracking on the selected MOT17 sequences.

    Args:
        args: Optional list of command-line arguments; ``None`` lets argparse
            read ``sys.argv``.

    Command-line options:
        --dataset_path: Root folder of the image sequences.
        --model_dir: Directory in which the ``results`` folder is created.
        --model_path: Checkpoint file to load.
        --seq_nums: A single sequence number to run; ``0`` (default) runs the
            built-in list of training sequences.
    """
    arg_parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
    arg_parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str,
                            help='Dataset path, location of the images sequence.')
    arg_parser.add_argument('--model_dir', default='./trained_model/',
                            help='Directory in which the results folder is created.')
    arg_parser.add_argument('--model_path', default='./trained_model/model_final.pth',
                            help='Path to model (.pt) file.')
    arg_parser.add_argument('--seq_nums', default=0, type=int)
    opts = arg_parser.parse_args(args)

    os.makedirs(os.path.join(opts.model_dir, 'results'), exist_ok=True)

    retinanet = model.resnet50(num_classes=1, pretrained=True)

    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source. The checkpoint must be a whole pickled module, because
    # .state_dict() is called on the loaded object.
    # NOTE(review): recent PyTorch releases may refuse full-module pickles by
    # default — confirm against the installed version.
    retinanet_save = torch.load(opts.model_path)

    # Drop the 'module.' prefix from parameter names so they match `retinanet`.
    state_dict = _strip_module_prefix(retinanet_save.state_dict())
    retinanet.load_state_dict(state_dict)

    # run_each_dataset moves its inputs to CUDA, so the model must live there too.
    retinanet = retinanet.cuda()
    retinanet.eval()

    if opts.seq_nums > 0:
        seq_nums = [opts.seq_nums]
    else:
        seq_nums = [2, 4, 5, 9, 10, 11, 13]

    # NOTE: the 'test' subset (sequences 1, 3, 6, 7, 8, 12, 14) can be run the
    # same way by passing 'test' instead of 'train'.
    for seq_num in seq_nums:
        run_each_dataset(opts.model_dir, retinanet, opts.dataset_path, 'train', 'MOT17-{:02d}'.format(seq_num))
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()