# Grounded-Segment-Anything / grounded_sam_visam.py
# Source: Hugging Face upload by hikerxu ("Upload folder using huggingface_hub", revision 483de47, 10.5 kB).
from copy import deepcopy
import json
import os
import argparse
import torchvision.transforms.functional as F
import torch
import cv2
import numpy as np
from tqdm import tqdm
from pathlib import Path
import sys
sys.path.append('VISAM')
from main import get_args_parser
from models import build_model
from util.tool import load_model
from models.structures import Instances
from torch.utils.data import Dataset, DataLoader
# segment anything
sys.path.append('segment_anything')
from segment_anything import build_sam, SamPredictor
class Colors:
    """Fixed 20-entry colour palette (Ultralytics palette, https://ultralytics.com/).

    Calling an instance with an index returns the palette entry at
    ``int(index) % 20`` as an ``(r, g, b)`` tuple, or ``(b, g, r)`` when
    ``bgr=True``.
    """

    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        palette = []
        for code in hexs:
            palette.append(self.hex2rgb(f'#{code}'))
        self.palette = palette
        self.n = len(palette)

    def __call__(self, i, bgr=False):
        """Return the colour for index ``i`` (wrapped modulo the palette size)."""
        rgb = self.palette[int(i) % self.n]
        if bgr:
            # Reverse the channel order for OpenCV-style consumers.
            return (rgb[2], rgb[1], rgb[0])
        return rgb

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        """Convert a ``'#RRGGBB'`` string into an ``(r, g, b)`` tuple of ints."""
        digits = h[1:]  # drop the leading marker character
        return tuple(int(digits[pos:pos + 2], 16) for pos in range(0, 6, 2))


colors = Colors()  # create instance for 'from utils.plots import colors'
class ListImgDataset(Dataset):
    """Dataset yielding a network-ready frame, the untouched frame and its proposals.

    Args:
        mot_path: Root directory that every entry of ``img_list`` is joined onto.
        img_list: Image paths relative to ``mot_path``.
        det_db: Mapping from ``<image path without extension>.txt`` to a list
            of ``"left,top,width,height,score"`` strings (pixel units).
    """

    def __init__(self, mot_path, img_list, det_db) -> None:
        super().__init__()
        self.mot_path = mot_path
        self.img_list = img_list
        self.det_db = det_db

        # Common settings: short-side target, long-side cap, and the
        # per-channel statistics handed to ``F.normalize``.
        self.img_height = 800
        self.img_width = 1536
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    @staticmethod
    def _parse_proposals(det_lines, im_w, im_h):
        """Turn ``l,t,w,h,score`` strings into normalised ``cx,cy,w,h,score`` rows.

        Returns a tensor of shape ``(len(det_lines), 5)``; an empty input
        yields shape ``(0, 5)``.
        """
        proposals = []
        for line in det_lines:
            left, top, w, h, score = map(float, line.split(','))
            proposals.append([(left + w / 2) / im_w,
                              (top + h / 2) / im_h,
                              w / im_w,
                              h / im_h,
                              score])
        return torch.as_tensor(proposals).reshape(-1, 5)

    def load_img_from_file(self, f_path):
        """Read one image as RGB and look up its proposals in ``det_db``.

        Raises:
            OSError: if OpenCV cannot read the image.
            KeyError: if ``det_db`` has no entry for the image.
        """
        cur_img = cv2.imread(os.path.join(self.mot_path, f_path))
        if cur_img is None:
            # cv2.imread signals failure by returning None; raise explicitly
            # so the check survives ``python -O`` (asserts are stripped).
            raise OSError(f'failed to read image: {f_path}')
        cur_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2RGB)
        im_h, im_w = cur_img.shape[:2]
        # splitext copes with extensions of any length (e.g. ``.jpeg``).
        det_key = os.path.splitext(f_path)[0] + '.txt'
        return cur_img, self._parse_proposals(self.det_db[det_key], im_w, im_h)

    def init_img(self, img, proposals):
        """Resize and normalise ``img``; return ``(tensor, original image, proposals)``.

        The short side is scaled to ``img_height`` unless that would push the
        long side past ``img_width``, in which case the long side is scaled to
        ``img_width`` instead. ``proposals`` is passed through unchanged.
        """
        ori_img = img.copy()
        self.seq_h, self.seq_w = img.shape[:2]
        scale = self.img_height / min(self.seq_h, self.seq_w)
        if max(self.seq_h, self.seq_w) * scale > self.img_width:
            scale = self.img_width / max(self.seq_h, self.seq_w)
        target_h = int(self.seq_h * scale)
        target_w = int(self.seq_w * scale)
        img = cv2.resize(img, (target_w, target_h))
        img = F.normalize(F.to_tensor(img), self.mean, self.std)
        img = img.unsqueeze(0)  # add a leading batch dimension
        return img, ori_img, proposals

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img, proposals = self.load_img_from_file(self.img_list[index])
        return self.init_img(img, proposals)
class Detector(object):
    """Track one image sequence, writing MOT-format results and a preview video.

    Args:
        args: Parsed command-line namespace; ``mot_path``, ``output_dir``,
            ``exp_name`` and ``det_db`` are read from it.
        model: Tracking model exposing ``inference_single_image``.
        vid: Sequence directory relative to ``args.mot_path``; frames are
            listed from its ``img1`` sub-directory.
        sam_predictor: Optional predictor providing ``set_image``, ``predict``
            and ``reset_image``. When given, one mask per tracked box is
            blended into the preview frames.
    """

    def __init__(self, args, model, vid, sam_predictor=None):
        self.args = args
        self.detr = model

        self.vid = vid
        self.seq_num = os.path.basename(vid)
        img_dir = os.path.join(self.args.mot_path, vid, 'img1')
        img_list = [os.path.join(vid, 'img1', name) for name in os.listdir(img_dir) if 'jpg' in name]
        self.img_list = sorted(img_list)
        self.img_len = len(self.img_list)

        self.predict_path = os.path.join(self.args.output_dir, args.exp_name)
        os.makedirs(self.predict_path, exist_ok=True)

        fps = 25
        size = (1920, 1080)
        # Remembered so frames of a different size can be resized before writing.
        self.video_size = size
        self.videowriter = cv2.VideoWriter('visam.avi', cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, size)
        self.sam_predictor = sam_predictor

    @staticmethod
    def filter_dt_by_score(dt_instances: "Instances", prob_threshold: float) -> "Instances":
        """Keep instances scoring above ``prob_threshold`` that have a non-negative ``obj_idxes``."""
        keep = dt_instances.scores > prob_threshold
        keep &= dt_instances.obj_idxes >= 0
        return dt_instances[keep]

    @staticmethod
    def filter_dt_by_area(dt_instances: "Instances", area_threshold: float) -> "Instances":
        """Keep instances whose ``x1,y1,x2,y2`` box area exceeds ``area_threshold``."""
        wh = dt_instances.boxes[:, 2:4] - dt_instances.boxes[:, 0:2]
        areas = wh[:, 0] * wh[:, 1]
        keep = areas > area_threshold
        return dt_instances[keep]

    @staticmethod
    def _colorize_best_mask(masks, iou_predictions, color):
        """Return the highest-scoring mask as a ``(3, H, W)`` integer array tinted with ``color``."""
        # argmax picks the best predicted quality; ``argsort()[0]`` would pick the worst.
        best = int(np.argmax(iou_predictions))
        tinted = np.repeat(masks[best:best + 1], 3, axis=0)
        return tinted.astype(np.int32) * np.array(color)[:, None, None]

    def _overlay_masks(self, img, frame_rgb, bbox_xyxy, identities):
        """Blend one tinted mask per tracked box into ``img`` and return the new frame."""
        masks_all = []
        self.sam_predictor.set_image(frame_rgb)
        for bbox, track_id in zip(np.array(bbox_xyxy), identities):
            masks, iou_predictions, _ = self.sam_predictor.predict(box=bbox)
            masks_all.append(self._colorize_best_mask(masks, iou_predictions, colors(track_id)))
        self.sam_predictor.reset_image()
        if masks_all:
            masks_sum = np.sum(masks_all, axis=0)
        else:
            masks_sum = np.zeros_like(img).transpose(2, 0, 1)
        return (img * 0.5 + (masks_sum.transpose(1, 2, 0) * 30) % 128).astype(np.uint8)

    def detect(self, prob_threshold=0.6, area_threshold=100, vis=False):
        """Run tracking over every frame of the sequence.

        Writes ``<predict_path>/<seq_num>.txt`` with one line per tracked box
        and appends an annotated frame to the video writer for every image.
        The writer is released when this method finishes, so video output is
        produced only by the first call on a given ``Detector``.

        Args:
            prob_threshold: Minimum score for a detection to be kept.
            area_threshold: Minimum box area for a detection to be kept.
            vis: Unused; retained for interface compatibility.
        """
        total_dts = 0
        total_occlusion_dts = 0  # never updated; kept so the summary line is unchanged

        track_instances = None
        with open(os.path.join(self.args.mot_path, 'DanceTrack', self.args.det_db)) as f:
            det_db = json.load(f)
        loader = DataLoader(ListImgDataset(self.args.mot_path, self.img_list, det_db), 1, num_workers=2)
        lines = []
        save_format = '{frame},{id},{x1:.2f},{y1:.2f},{w:.2f},{h:.2f},1,-1,-1,-1\n'
        try:
            for i, data in enumerate(tqdm(loader)):
                # Batch size is 1: strip the batch dimension from each item.
                cur_img, ori_img, proposals = [d[0] for d in data]
                cur_img, proposals = cur_img.cuda(), proposals.cuda()

                # track_instances = None
                if track_instances is not None:
                    track_instances.remove('boxes')
                    track_instances.remove('labels')
                seq_h, seq_w, _ = ori_img.shape

                res = self.detr.inference_single_image(cur_img, (seq_h, seq_w), track_instances, proposals)
                track_instances = res['track_instances']

                dt_instances = deepcopy(track_instances)

                # filter det instances by score.
                dt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)
                dt_instances = self.filter_dt_by_area(dt_instances, area_threshold)

                total_dts += len(dt_instances)

                bbox_xyxy = dt_instances.boxes.tolist()
                identities = dt_instances.obj_idxes.tolist()

                frame_rgb = ori_img.to(torch.device('cpu')).numpy()
                img = frame_rgb[..., ::-1]  # reversed-channel view used for OpenCV output
                if self.sam_predictor is not None:
                    img = self._overlay_masks(img, frame_rgb.copy(), bbox_xyxy, identities)
                # OpenCV drawing needs a contiguous buffer; the reversed view
                # has a negative stride and the blend may not be C-ordered.
                img = np.ascontiguousarray(img)

                for bbox in bbox_xyxy:
                    cv2.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])),
                                  (0, 0, 255), thickness=3)

                # The writer expects frames of the size it was opened with.
                if (img.shape[1], img.shape[0]) != self.video_size:
                    self.videowriter.write(cv2.resize(img, self.video_size))
                else:
                    self.videowriter.write(img)

                for xyxy, track_id in zip(bbox_xyxy, identities):
                    # Test for None first so the comparison can never raise.
                    if track_id is None or track_id < 0:
                        continue
                    x1, y1, x2, y2 = xyxy
                    w, h = x2 - x1, y2 - y1
                    lines.append(save_format.format(frame=i + 1, id=track_id, x1=x1, y1=y1, w=w, h=h))
        finally:
            # Finalise the video file even if a frame fails part-way through.
            self.videowriter.release()

        with open(os.path.join(self.predict_path, f'{self.seq_num}.txt'), 'w') as f:
            f.writelines(lines)
        print("totally {} dts {} occlusion dts".format(total_dts, total_occlusion_dts))
class RuntimeTrackerBase(object):
    """Hand out and retire track identities from per-instance scores.

    Args:
        score_thresh: Score at or above which an instance counts as seen
            (its ``disappear_time`` is reset, and an unassigned instance
            receives a new id).
        filter_score_thresh: Score below which an already-assigned instance
            counts as missing for this update.
        miss_tolerance: Number of accumulated misses at which an assigned
            instance has its id reset to ``-1``.
    """

    def __init__(self, score_thresh=0.6, filter_score_thresh=0.5, miss_tolerance=10):
        self.max_obj_id = 0
        self.miss_tolerance = miss_tolerance
        self.filter_score_thresh = filter_score_thresh
        self.score_thresh = score_thresh

    def clear(self):
        """Restart id numbering from zero."""
        self.max_obj_id = 0

    def update(self, track_instances: "Instances"):
        """Update ``obj_idxes`` and ``disappear_time`` of ``track_instances`` in place."""
        scores = track_instances.scores
        ids = track_instances.obj_idxes

        # Confident instances are considered visible again.
        confident = scores >= self.score_thresh
        track_instances.disappear_time[confident] = 0

        # Both masks are computed before any id is rewritten below.
        born = (ids == -1) & confident
        fading = (ids >= 0) & (scores < self.filter_score_thresh)

        # Give consecutive fresh ids to the newly confident, unassigned instances.
        born_count = born.sum().item()
        track_instances.obj_idxes[born] = self.max_obj_id + torch.arange(born_count, device=ids.device)
        self.max_obj_id += born_count

        # Count another miss, then drop ids that have been missing for too long.
        track_instances.disappear_time[fading] += 1
        expired = fading & (track_instances.disappear_time >= self.miss_tolerance)
        track_instances.obj_idxes[expired] = -1
if __name__ == "__main__":
    # Extend the tracker's own argument parser with the demo-specific options.
    parser = argparse.ArgumentParser("Grounded-Segment-Anything VISAM Demo", parents=[get_args_parser()])
    parser.add_argument('--score_threshold', default=0.5, type=float)
    parser.add_argument('--update_score_threshold', default=0.5, type=float)
    parser.add_argument('--miss_tolerance', default=20, type=int)
    parser.add_argument(
        "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument(
        "--video_path", type=str, required=True,
        help="sequence directory (relative to mot_path) whose img1/ folder holds the frames",
    )
    args = parser.parse_args()

    # make dir
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Build the SAM predictor and move its model to the GPU.
    sam_predictor = SamPredictor(build_sam(checkpoint=args.sam_checkpoint))
    sam_predictor.model.to(device='cuda')

    # load model and weights
    detr, _, _ = build_model(args)
    detr.track_embed.score_thr = args.update_score_threshold
    # The same threshold is used both for creating and for filtering tracks.
    detr.track_base = RuntimeTrackerBase(args.score_threshold, args.score_threshold, args.miss_tolerance)
    # load_model reads the checkpoint from the path itself, so a separate
    # torch.load here would only read the file a second time.
    detr = load_model(detr, args.resume)
    detr.eval()
    detr = detr.cuda()

    det = Detector(args, model=detr, vid=args.video_path, sam_predictor=sam_predictor)
    det.detect(args.score_threshold)