|
|
|
from copy import deepcopy |
|
import json |
|
|
|
import os |
|
import argparse |
|
import torchvision.transforms.functional as F |
|
import torch |
|
import cv2 |
|
import numpy as np |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
import sys |
|
sys.path.append('VISAM') |
|
from main import get_args_parser |
|
from models import build_model |
|
from util.tool import load_model |
|
from models.structures import Instances |
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
|
|
sys.path.append('segment_anything') |
|
from segment_anything import build_sam, SamPredictor |
|
|
|
|
|
class Colors:
    """Cyclic 20-entry color palette.

    Calling an instance with any integer-like ``i`` returns the RGB tuple at
    ``int(i) % 20``; pass ``bgr=True`` to get the channels reversed for OpenCV.
    """

    def __init__(self):
        hex_codes = (
            'FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17',
            '3DDB86', '1A9334', '00D4BB', '2C99A8', '00C2FF', '344593', '6473FF',
            '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7',
        )
        self.palette = list(map(self.hex2rgb, ('#' + code for code in hex_codes)))
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        red, green, blue = self.palette[int(i) % self.n]
        if bgr:
            return (blue, green, red)
        return (red, green, blue)

    @staticmethod
    def hex2rgb(h):
        """Convert a ``'#RRGGBB'`` string into an ``(r, g, b)`` tuple of ints."""
        digits = h[1:]
        return tuple(int(digits[k:k + 2], 16) for k in range(0, 6, 2))


# Shared palette instance; used to tint per-track SAM masks.
colors = Colors()
|
|
|
|
|
class ListImgDataset(Dataset):
    """Dataset over a list of frame paths that also yields detection proposals.

    Each item is ``(normalized_tensor, original_rgb_image, proposals)``.
    ``proposals`` is an ``(N, 5)`` tensor of ``[cx, cy, w, h, score]``; the box
    terms are divided by the original frame width/height. Proposals come from
    ``det_db[<frame path without its 4-char extension> + '.txt']``, a list of
    ``"left,top,width,height,score"`` strings.
    """

    def __init__(self, mot_path, img_list, det_db) -> None:
        super().__init__()
        self.mot_path = mot_path
        self.img_list = img_list
        self.det_db = det_db

        # Common preprocessing settings: target short side, cap on the long
        # side, and the per-channel normalization mean/std.
        self.img_height = 800
        self.img_width = 1536
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def load_img_from_file(self, f_path):
        """Read ``f_path`` (relative to ``mot_path``) as RGB and build its proposal tensor."""
        bgr = cv2.imread(os.path.join(self.mot_path, f_path))
        assert bgr is not None, f_path
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        frame_h, frame_w = rgb.shape[:2]

        rows = []
        for record in self.det_db[f_path[:-4] + '.txt']:
            left, top, box_w, box_h, score = map(float, record.split(','))
            rows.append([
                (left + box_w / 2) / frame_w,
                (top + box_h / 2) / frame_h,
                box_w / frame_w,
                box_h / frame_h,
                score,
            ])
        # reshape keeps the (N, 5) shape even when there are no proposals.
        return rgb, torch.as_tensor(rows).reshape(-1, 5)

    def init_img(self, img, proposals):
        """Resize and normalize ``img``; return ``(batched_tensor, original_copy, proposals)``."""
        ori_img = img.copy()
        self.seq_h, self.seq_w = img.shape[:2]

        # Scale the short side to img_height, unless that pushes the long side
        # past img_width, in which case scale the long side to img_width.
        short_side = min(self.seq_h, self.seq_w)
        long_side = max(self.seq_h, self.seq_w)
        scale = self.img_height / short_side
        if long_side * scale > self.img_width:
            scale = self.img_width / long_side

        resized = cv2.resize(img, (int(self.seq_w * scale), int(self.seq_h * scale)))
        tensor = F.normalize(F.to_tensor(resized), self.mean, self.std).unsqueeze(0)
        return tensor, ori_img, proposals

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        return self.init_img(*self.load_img_from_file(self.img_list[index]))
|
|
|
|
|
class Detector(object):
    """Track one image sequence with VISAM and optionally render SAM masks.

    Frames are read from ``<args.mot_path>/<vid>/img1``. Per-frame tracks are
    written in MOT text format to ``<args.output_dir>/<args.exp_name>/<seq>.txt``.
    When ``sam_predictor`` is given, each frame is also segmented with
    box-prompted SAM and appended to ``visam.avi`` in the working directory.
    """

    def __init__(self, args, model, vid, sam_predictor=None):
        self.args = args
        self.detr = model

        self.vid = vid
        # normpath drops a trailing separator so basename is never empty
        # (otherwise the results file would be named ".txt").
        self.seq_num = os.path.basename(os.path.normpath(vid))
        img_dir = os.path.join(self.args.mot_path, vid, 'img1')
        # Match the real extension rather than a substring: ListImgDataset
        # derives each frame's detection key by stripping exactly 4 characters.
        self.img_list = sorted(
            os.path.join(vid, 'img1', name)
            for name in os.listdir(img_dir)
            if name.endswith('.jpg')
        )
        self.img_len = len(self.img_list)

        self.predict_path = os.path.join(self.args.output_dir, args.exp_name)
        os.makedirs(self.predict_path, exist_ok=True)

        fps = 25
        # (width, height) the writer is opened with. cv2.VideoWriter does not
        # write frames of a different size, so _write_frame resizes to this.
        self.video_size = (1920, 1080)
        self.videowriter = cv2.VideoWriter(
            'visam.avi', cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, self.video_size)

        self.sam_predictor = sam_predictor

    @staticmethod
    def filter_dt_by_score(dt_instances: Instances, prob_threshold: float) -> Instances:
        """Keep instances scoring above ``prob_threshold`` that carry a track id (>= 0)."""
        keep = dt_instances.scores > prob_threshold
        keep &= dt_instances.obj_idxes >= 0
        return dt_instances[keep]

    @staticmethod
    def filter_dt_by_area(dt_instances: Instances, area_threshold: float) -> Instances:
        """Keep instances whose xyxy box area is greater than ``area_threshold``."""
        wh = dt_instances.boxes[:, 2:4] - dt_instances.boxes[:, 0:2]
        areas = wh[:, 0] * wh[:, 1]
        keep = areas > area_threshold
        return dt_instances[keep]

    def _render_sam_frame(self, ori_img, bbox_xyxy, identities):
        """Return a BGR uint8 frame with one SAM mask and one red box per detection.

        Each box is used as a SAM prompt. The candidate mask with the highest
        predicted IoU is tinted with the track's palette color and blended
        over the half-intensity frame.
        """
        img = ori_img.to(torch.device('cpu')).numpy().copy()[..., ::-1]  # RGB -> BGR for cv2
        self.sam_predictor.set_image(ori_img.to(torch.device('cpu')).numpy().copy())

        masks_sum = None
        for bbox, track_id in zip(np.array(bbox_xyxy), identities):
            masks, iou_predictions, _ = self.sam_predictor.predict(box=bbox)
            # Use the best-scoring candidate; argsort()[0] would select the worst.
            best = int(np.argmax(iou_predictions))
            mask = np.repeat(masks[best:best + 1], 3, axis=0)  # one copy per color channel
            tinted = mask.astype(np.int32) * np.array(colors(track_id))[:, None, None]
            if masks_sum is None:
                masks_sum = tinted
            else:
                masks_sum += tinted
        self.sam_predictor.reset_image()

        if masks_sum is None:
            masks_sum = np.zeros_like(img).transpose(2, 0, 1)

        frame = (img * 0.5 + (masks_sum.transpose(1, 2, 0) * 30) % 128).astype(np.uint8)
        for bbox in bbox_xyxy:
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])),
                          (0, 0, 255), thickness=3)
        return frame

    def _write_frame(self, frame):
        """Append ``frame`` to the output video, resizing it to the writer's size if needed."""
        if (frame.shape[1], frame.shape[0]) != tuple(self.video_size):
            frame = cv2.resize(frame, tuple(self.video_size))
        self.videowriter.write(frame)

    def detect(self, prob_threshold=0.6, area_threshold=100, vis=False):
        """Track every frame of the sequence and write the MOT results file.

        Args:
            prob_threshold: instances must score strictly above this to be output.
            area_threshold: instances must have an xyxy box area strictly above this
                (in the units of ``dt_instances.boxes``).
            vis: unused; kept for interface compatibility.

        The video writer is released when this returns, so a Detector should
        only run ``detect`` once.
        """
        total_dts = 0
        total_occlusion_dts = 0  # not updated anywhere below; always reported as 0

        track_instances = None
        with open(os.path.join(self.args.mot_path, 'DanceTrack', self.args.det_db)) as f:
            det_db = json.load(f)
        loader = DataLoader(ListImgDataset(self.args.mot_path, self.img_list, det_db), 1, num_workers=2)

        save_format = '{frame},{id},{x1:.2f},{y1:.2f},{w:.2f},{h:.2f},1,-1,-1,-1\n'
        lines = []
        try:
            for i, data in enumerate(tqdm(loader)):
                cur_img, ori_img, proposals = [d[0] for d in data]
                cur_img, proposals = cur_img.cuda(), proposals.cuda()

                # Strip the 'boxes'/'labels' fields left on the previous frame's
                # tracks before feeding them back into the model.
                if track_instances is not None:
                    track_instances.remove('boxes')
                    track_instances.remove('labels')
                seq_h, seq_w, _ = ori_img.shape

                res = self.detr.inference_single_image(cur_img, (seq_h, seq_w), track_instances, proposals)
                track_instances = res['track_instances']

                # Filter a deep copy so the tracker state carried to the next frame is untouched.
                dt_instances = deepcopy(track_instances)
                dt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)
                dt_instances = self.filter_dt_by_area(dt_instances, area_threshold)
                total_dts += len(dt_instances)

                bbox_xyxy = dt_instances.boxes.tolist()
                identities = dt_instances.obj_idxes.tolist()

                if self.sam_predictor is not None:
                    self._write_frame(self._render_sam_frame(ori_img, bbox_xyxy, identities))

                for xyxy, track_id in zip(bbox_xyxy, identities):
                    # Test None first: `None < 0` raises TypeError in Python 3.
                    if track_id is None or track_id < 0:
                        continue
                    x1, y1, x2, y2 = xyxy
                    lines.append(save_format.format(frame=i + 1, id=track_id,
                                                    x1=x1, y1=y1, w=x2 - x1, h=y2 - y1))
        finally:
            # Finalize the AVI container; an unreleased writer can leave the file incomplete.
            self.videowriter.release()

        with open(os.path.join(self.predict_path, f'{self.seq_num}.txt'), 'w') as f:
            f.writelines(lines)
        print("totally {} dts {} occlusion dts".format(total_dts, total_occlusion_dts))
|
|
|
|
|
class RuntimeTrackerBase(object):
    """Assign and retire track ids from per-query scores, in place.

    - A query whose score is >= ``score_thresh`` has its ``disappear_time`` reset to 0.
    - Such a query without an id (``obj_idxes == -1``) receives the next fresh id.
    - A tracked query (id >= 0) scoring below ``filter_score_thresh`` has its
      ``disappear_time`` incremented.
    - Once that counter reaches ``miss_tolerance``, its id is reset to -1.
    """

    def __init__(self, score_thresh=0.6, filter_score_thresh=0.5, miss_tolerance=10):
        self.score_thresh = score_thresh
        self.filter_score_thresh = filter_score_thresh
        self.miss_tolerance = miss_tolerance
        self.max_obj_id = 0

    def clear(self):
        """Restart id allocation from 0."""
        self.max_obj_id = 0

    def update(self, track_instances: Instances):
        scores = track_instances.scores
        ids = track_instances.obj_idxes

        confident = scores >= self.score_thresh
        track_instances.disappear_time[confident] = 0

        # Both masks are computed from the ids as they were before this update.
        fresh = confident & (ids == -1)
        lost = (ids >= 0) & (scores < self.filter_score_thresh)

        count = fresh.sum().item()
        ids[fresh] = self.max_obj_id + torch.arange(count, device=ids.device)
        self.max_obj_id += count

        track_instances.disappear_time[lost] += 1
        expired = lost & (track_instances.disappear_time >= self.miss_tolerance)
        ids[expired] = -1
|
|
|
|
|
if __name__ == "__main__":
    # CLI: VISAM model/dataset arguments come from the imported get_args_parser();
    # the options below add tracker thresholds and the SAM checkpoint.
    parser = argparse.ArgumentParser("Grounded-Segment-Anything VISAM Demo", parents=[get_args_parser()])
    parser.add_argument('--score_threshold', default=0.5, type=float)
    parser.add_argument('--update_score_threshold', default=0.5, type=float)
    parser.add_argument('--miss_tolerance', default=20, type=int)
    parser.add_argument(
        "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument(
        "--video_path", type=str, required=True,
        help="sequence directory (relative to --mot_path) containing an img1/ folder of frames",
    )
    args = parser.parse_args()

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    sam_predictor = SamPredictor(build_sam(checkpoint=args.sam_checkpoint))
    sam_predictor.model.to(device='cuda')

    detr, _, _ = build_model(args)
    detr.track_embed.score_thr = args.update_score_threshold
    # NOTE: --score_threshold is used both to spawn tracks (score_thresh) and
    # to mark them as missing (filter_score_thresh).
    detr.track_base = RuntimeTrackerBase(args.score_threshold, args.score_threshold, args.miss_tolerance)
    # load_model receives the checkpoint path directly; no separate torch.load
    # (a full pickle load of the same file) is needed here.
    detr = load_model(detr, args.resume)
    detr.eval()
    detr = detr.cuda()

    det = Detector(args, model=detr, vid=args.video_path, sam_predictor=sam_predictor)
    det.detect(args.score_threshold)
|
|