from copy import deepcopy
import json
import os
import argparse
import torchvision.transforms.functional as F
import torch
import cv2
import numpy as np
from tqdm import tqdm
from pathlib import Path
import sys

sys.path.append('VISAM')
from main import get_args_parser
from models import build_model
from util.tool import load_model
from models.structures import Instances
from torch.utils.data import Dataset, DataLoader

# segment anything
sys.path.append('segment_anything')
from segment_anything import build_sam, SamPredictor


class Colors:
    """Ultralytics colour palette (https://ultralytics.com/), indexed cyclically."""

    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return palette entry ``int(i) % n`` as an (r, g, b) tuple, or (b, g, r) if ``bgr``."""
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):
        """Convert a '#RRGGBB' string to an (r, g, b) tuple of ints (PIL channel order)."""
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))


# Module-level palette instance; used below to colour SAM masks by track id.
colors = Colors()


class ListImgDataset(Dataset):
    """Dataset over a list of frame paths.

    Each item is ``(network_input, original_rgb_image, proposals)`` as produced
    by :meth:`init_img`.

    Args:
        mot_path: root directory that the paths in ``img_list`` are joined to.
        img_list: frame paths relative to ``mot_path``.
        det_db: mapping from '<frame path without extension>.txt' to a list of
            'left,top,width,height,score' strings (divided by the image size
            below, so presumably pixel units).
    """

    def __init__(self, mot_path, img_list, det_db) -> None:
        super().__init__()
        self.mot_path = mot_path
        self.img_list = img_list
        self.det_db = det_db

        # common settings: resize bounds and per-channel mean/std for F.normalize
        self.img_height = 800
        self.img_width = 1536
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def load_img_from_file(self, f_path):
        """Read one frame as RGB and return it with its detection proposals.

        Proposals are converted from (left, top, w, h, score) to
        (cx, cy, w, h, score) normalised by the image width/height and returned
        as an (N, 5) tensor; N may be 0.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image.
            KeyError: if ``det_db`` has no entry for the frame.
        """
        full_path = os.path.join(self.mot_path, f_path)
        cur_img = cv2.imread(full_path)
        if cur_img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"cannot read image: {full_path}")
        cur_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2RGB)

        proposals = []
        im_h, im_w = cur_img.shape[:2]
        det_key = os.path.splitext(f_path)[0] + '.txt'
        for line in self.det_db[det_key]:
            l, t, w, h, s = map(float, line.split(','))
            proposals.append([(l + w / 2) / im_w, (t + h / 2) / im_h, w / im_w, h / im_h, s])
        return cur_img, torch.as_tensor(proposals).reshape(-1, 5)

    def init_img(self, img, proposals):
        """Resize and normalise ``img`` for the network.

        The image is scaled so that its shorter side becomes ``img_height``,
        unless that would push its longer side past ``img_width``, in which
        case the longer side is capped at ``img_width``.

        Returns:
            (normalised tensor with a leading batch dim of 1, untouched copy of
            the input image, proposals unchanged).
        """
        ori_img = img.copy()
        self.seq_h, self.seq_w = img.shape[:2]
        scale = self.img_height / min(self.seq_h, self.seq_w)
        if max(self.seq_h, self.seq_w) * scale > self.img_width:
            scale = self.img_width / max(self.seq_h, self.seq_w)
        target_h = int(self.seq_h * scale)
        target_w = int(self.seq_w * scale)
        img = cv2.resize(img, (target_w, target_h))
        img = F.normalize(F.to_tensor(img), self.mean, self.std)
        img = img.unsqueeze(0)
        return img, ori_img, proposals

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img, proposals = self.load_img_from_file(self.img_list[index])
        return self.init_img(img, proposals)


class Detector(object):
    """Run the tracker over one image sequence, writing MOT-style results and a visualisation video.

    Args:
        args: parsed CLI namespace; reads ``mot_path``, ``output_dir``,
            ``exp_name`` and ``det_db``.
        model: tracker exposing ``inference_single_image``.
        vid: sequence directory, relative to ``args.mot_path``, containing an
            ``img1/`` folder of ``.jpg`` frames.
        sam_predictor: optional SamPredictor; when given, masks are drawn.
        video_out: path of the MJPG visualisation video.
        fps: frame rate of the visualisation video.
    """

    def __init__(self, args, model, vid, sam_predictor=None, video_out='visam.avi', fps=25):
        self.args = args
        self.detr = model

        self.vid = vid
        # normpath so a trailing slash does not yield an empty sequence name
        self.seq_num = os.path.basename(os.path.normpath(vid))
        img_list = os.listdir(os.path.join(self.args.mot_path, vid, 'img1'))
        img_list = [os.path.join(vid, 'img1', i) for i in img_list if i.endswith('.jpg')]

        self.img_list = sorted(img_list)
        self.img_len = len(self.img_list)

        self.predict_path = os.path.join(self.args.output_dir, args.exp_name)
        os.makedirs(self.predict_path, exist_ok=True)

        self.video_out = video_out
        self.fps = fps
        # Created lazily from the first rendered frame so the writer's frame
        # size matches the sequence (OpenCV drops frames of a different size).
        self.videowriter = None
        self.sam_predictor = sam_predictor

    @staticmethod
    def filter_dt_by_score(dt_instances: Instances, prob_threshold: float) -> Instances:
        """Keep instances with score > ``prob_threshold`` and an assigned (>= 0) object id."""
        keep = dt_instances.scores > prob_threshold
        keep &= dt_instances.obj_idxes >= 0
        return dt_instances[keep]

    @staticmethod
    def filter_dt_by_area(dt_instances: Instances, area_threshold: float) -> Instances:
        """Keep instances whose (x2 - x1) * (y2 - y1) box area exceeds ``area_threshold``."""
        wh = dt_instances.boxes[:, 2:4] - dt_instances.boxes[:, 0:2]
        areas = wh[:, 0] * wh[:, 1]
        keep = areas > area_threshold
        return dt_instances[keep]

    def _sam_mask_overlay(self, rgb_img, bbox_xyxy, identities):
        """Sum one SAM mask per box, each tinted with its track colour, into a (3, H, W) int64 array."""
        masks_sum = np.zeros((3,) + rgb_img.shape[:2], dtype=np.int64)
        if not bbox_xyxy:
            # Nothing to segment: skip the costly image embedding.
            return masks_sum
        self.sam_predictor.set_image(rgb_img.copy())
        for bbox, track_id in zip(np.asarray(bbox_xyxy), identities):
            masks, iou_predictions, _ = self.sam_predictor.predict(box=bbox)
            # Use the candidate with the largest iou_predictions score.
            best = masks[int(np.argmax(iou_predictions))].astype(np.int64)
            masks_sum += best[None] * np.array(colors(track_id))[:, None, None]
        self.sam_predictor.reset_image()
        return masks_sum

    def _render_frame(self, ori_img, bbox_xyxy, identities):
        """Return a BGR uint8 visualisation: optional SAM mask tint plus red (BGR 0,0,255) boxes."""
        rgb_img = ori_img.cpu().numpy()
        # RGB -> BGR for OpenCV; made contiguous because cv2 drawing functions
        # reject the negative-stride view produced by [..., ::-1].
        img = np.ascontiguousarray(rgb_img[..., ::-1])
        if self.sam_predictor is not None:
            masks_sum = self._sam_mask_overlay(rgb_img, bbox_xyxy, identities)
            img = (img * 0.5 + (masks_sum.transpose(1, 2, 0) * 30) % 128).astype(np.uint8)
        for bbox in bbox_xyxy:
            cv2.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 0, 255),
                          thickness=3)
        return img

    def _write_frame(self, img):
        """Append ``img`` to the visualisation video, opening the writer on first use."""
        if self.videowriter is None:
            h, w = img.shape[:2]
            self.videowriter = cv2.VideoWriter(self.video_out, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
                                               self.fps, (w, h))
        self.videowriter.write(img)

    def detect(self, prob_threshold=0.6, area_threshold=100, vis=False):
        """Track every frame of the sequence.

        Writes '<predict_path>/<seq_num>.txt' with lines
        'frame,id,x1,y1,w,h,1,-1,-1,-1' (frame numbers start at 1). It also
        writes the visualisation video to ``self.video_out``.

        Args:
            prob_threshold: minimum score for an instance to be reported.
            area_threshold: minimum box area for an instance to be reported.
            vis: unused; kept for interface compatibility.
        """
        total_dts = 0
        total_occlusion_dts = 0  # never incremented; only reported in the summary line
        track_instances = None
        with open(os.path.join(self.args.mot_path, 'DanceTrack', self.args.det_db)) as f:
            det_db = json.load(f)
        loader = DataLoader(ListImgDataset(self.args.mot_path, self.img_list, det_db), 1, num_workers=2)
        save_format = '{frame},{id},{x1:.2f},{y1:.2f},{w:.2f},{h:.2f},1,-1,-1,-1\n'
        lines = []
        try:
            for i, data in enumerate(tqdm(loader)):
                # Batch size is 1: strip the batch dimension added by the DataLoader.
                cur_img, ori_img, proposals = [d[0] for d in data]
                cur_img, proposals = cur_img.cuda(), proposals.cuda()

                # Drop the previous frame's boxes/labels before passing the tracks on.
                if track_instances is not None:
                    track_instances.remove('boxes')
                    track_instances.remove('labels')
                seq_h, seq_w, _ = ori_img.shape

                res = self.detr.inference_single_image(cur_img, (seq_h, seq_w), track_instances, proposals)
                track_instances = res['track_instances']

                # Filter a copy so the tracker state carried to the next frame is untouched.
                dt_instances = deepcopy(track_instances)
                dt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)
                dt_instances = self.filter_dt_by_area(dt_instances, area_threshold)

                total_dts += len(dt_instances)
                bbox_xyxy = dt_instances.boxes.tolist()
                identities = dt_instances.obj_idxes.tolist()

                self._write_frame(self._render_frame(ori_img, bbox_xyxy, identities))

                for xyxy, track_id in zip(bbox_xyxy, identities):
                    if track_id is None or track_id < 0:
                        continue
                    x1, y1, x2, y2 = xyxy
                    lines.append(save_format.format(frame=i + 1, id=track_id, x1=x1, y1=y1, w=x2 - x1, h=y2 - y1))
        finally:
            # Release so the video container is finalised even if tracking fails.
            if self.videowriter is not None:
                self.videowriter.release()
                self.videowriter = None

        with open(os.path.join(self.predict_path, f'{self.seq_num}.txt'), 'w') as f:
            f.writelines(lines)
        print("totally {} dts {} occlusion dts".format(total_dts, total_occlusion_dts))


class RuntimeTrackerBase(object):
    """Assign and retire object ids on ``Instances`` based on per-frame scores.

    Args:
        score_thresh: score at or above which a track counts as seen (and an
            unassigned one gets a new id).
        filter_score_thresh: score below which an assigned track counts as missed.
        miss_tolerance: consecutive misses after which the id is reset to -1.
    """

    def __init__(self, score_thresh=0.6, filter_score_thresh=0.5, miss_tolerance=10):
        self.score_thresh = score_thresh
        self.filter_score_thresh = filter_score_thresh
        self.miss_tolerance = miss_tolerance
        self.max_obj_id = 0

    def clear(self):
        """Restart id numbering from 0."""
        self.max_obj_id = 0

    def update(self, track_instances: Instances):
        """Update ``obj_idxes`` and ``disappear_time`` of ``track_instances`` in place."""
        device = track_instances.obj_idxes.device

        track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0
        new_obj = (track_instances.obj_idxes == -1) & (track_instances.scores >= self.score_thresh)
        disappeared_obj = (track_instances.obj_idxes >= 0) & (track_instances.scores < self.filter_score_thresh)
        num_new_objs = new_obj.sum().item()

        # New tracks get consecutive ids continuing from the running maximum.
        track_instances.obj_idxes[new_obj] = self.max_obj_id + torch.arange(num_new_objs, device=device)
        self.max_obj_id += num_new_objs

        track_instances.disappear_time[disappeared_obj] += 1
        to_del = disappeared_obj & (track_instances.disappear_time >= self.miss_tolerance)
        track_instances.obj_idxes[to_del] = -1


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded-Segment-Anything VISAM Demo", parents=[get_args_parser()])
    parser.add_argument('--score_threshold', default=0.5, type=float)
    parser.add_argument('--update_score_threshold', default=0.5, type=float)
    parser.add_argument('--miss_tolerance', default=20, type=int)
    parser.add_argument(
        "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument(
        "--video_path", type=str, required=True,
        help="sequence directory (relative to --mot_path) containing an img1/ folder of .jpg frames",
    )
    args = parser.parse_args()

    # make dir
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    sam_predictor = SamPredictor(build_sam(checkpoint=args.sam_checkpoint))
    _ = sam_predictor.model.to(device='cuda')

    # load model and weights (load_model reads the args.resume checkpoint itself)
    detr, _, _ = build_model(args)
    detr.track_embed.score_thr = args.update_score_threshold
    detr.track_base = RuntimeTrackerBase(args.score_threshold, args.score_threshold, args.miss_tolerance)
    detr = load_model(detr, args.resume)
    detr.eval()
    detr = detr.cuda()

    det = Detector(args, model=detr, vid=args.video_path, sam_predictor=sam_predictor)
    det.detect(args.score_threshold)