import torch import numpy as np from dp2 import utils from dp2.utils import vis_utils, crop_box from .utils import ( cut_pad_resize, masks_to_boxes, get_kernel, transform_embedding, initialize_cse_boxes ) from .box_utils import get_expanded_bbox, include_box import torchvision import tops from .box_utils_fdf import expand_bbox as expand_bbox_fdf class VehicleDetection: def __init__(self, segmentation: torch.BoolTensor) -> None: self.segmentation = segmentation self.boxes = masks_to_boxes(segmentation) assert self.boxes.shape[1] == 4, self.boxes.shape self.n_detections = self.segmentation.shape[0] area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0]) sorted_idx = torch.argsort(area, descending=True) self.segmentation = self.segmentation[sorted_idx] self.boxes = self.boxes[sorted_idx].cpu() def pre_process(self): pass def get_crop(self, idx: int, im): assert idx < len(self) box = self.boxes[idx] im = crop_box(self.im, box) mask = crop_box(self.segmentation[idx]) mask = mask == 0 return dict(img=im, mask=mask.float(), boxes=box) def visualize(self, im): if len(self) == 0: return im im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not()) return im def __len__(self): return self.n_detections @staticmethod def from_state_dict(state_dict, **kwargs): numel = np.prod(state_dict["shape"]) arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel) segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"]) return VehicleDetection(segmentation) def state_dict(self, **kwargs): segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy())) return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape) class FaceDetection: def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, keypoints: torch.Tensor = None, **kwargs) -> None: self.boxes = boxes_ltrb.cpu() assert self.boxes.shape[1] == 4, self.boxes.shape self.target_imsize = tuple(target_imsize) # Sory by area to paste in largest faces last area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1) idx = area.argsort(descending=False) self.boxes = self.boxes[idx] self.fdf128_expand = fdf128_expand self.orig_keypoints = keypoints if keypoints is not None: self.orig_keypoints = self.orig_keypoints[idx] assert keypoints.shape == (len(boxes_ltrb), 17, 2) or \ keypoints.shape == (len(boxes_ltrb), 7, 2), keypoints.shape def visualize(self, im): if len(self) == 0: return im orig_device = im.device for box in self.boxes: simple_expand = False if self.fdf128_expand else True e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand)) im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2) im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2) if self.orig_keypoints is not None: im = vis_utils.draw_keypoints(im, self.orig_keypoints, radius=1) return im.to(device=orig_device) def get_crop(self, idx: int, im): assert idx < len(self) box = self.boxes[idx].numpy() simple_expand = False if self.fdf128_expand else True expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand) im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True) # Find the square mask corresponding to box. box_mask = box.copy().astype(float) box_mask[[0, 2]] -= expanded_boxes[0] box_mask[[1, 3]] -= expanded_boxes[1] width = expanded_boxes[2] - expanded_boxes[0] resize_factor = self.target_imsize[0] / width box_mask = (box_mask * resize_factor).astype(int) mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32) crop_box(mask, box_mask).fill_(0) if self.orig_keypoints is None: return dict( img=im[None], mask=mask[None], boxes=torch.from_numpy(expanded_boxes).view(1, -1)) keypoint = self.orig_keypoints[idx, :7, :2].clone() keypoint[:, 0] -= expanded_boxes[0] keypoint[:, 1] -= expanded_boxes[1] w = expanded_boxes[2] - expanded_boxes[0] keypoint /= w keypoint = keypoint.clamp(0, 1) return dict( img=im[None], mask=mask[None], boxes=torch.from_numpy(expanded_boxes).view(1, -1), keypoints=keypoint[None]) def __len__(self): return len(self.boxes) @staticmethod def from_state_dict(state_dict, **kwargs): return FaceDetection( state_dict["boxes"].cpu(), keypoints=state_dict["orig_keypoints"] if "orig_keypoints" in state_dict else None, **kwargs) def state_dict(self, **kwargs): return dict( boxes=self.boxes, cls=self.__class__, orig_keypoints=self.orig_keypoints) def pre_process(self): pass def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape): """ Dilation happens after padding, which could place dilation in the padded area. Remove this. """ x0, y0, x1, y1 = exp_box H, W = orig_imshape # Padding in original image space p_y0 = max(0, -y0) p_y1 = max(y1 - H, 0) p_x0 = max(0, -x0) p_x1 = max(x1 - W, 0) resize_ratio = mask.shape[-2] / (y1-y0) p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]] mask[..., :p_y0, :] = 0 mask[..., :p_x0] = 0 mask[..., mask.shape[-2] - p_y1:, :] = 0 mask[..., mask.shape[-1] - p_x1:] = 0 class CSEPersonDetection: def __init__(self, segmentation, cse_dets, target_imsize, exp_bbox_cfg, exp_bbox_filter, dilation_percentage: float, embed_map: torch.Tensor, orig_imshape_CHW, normalize_embedding: bool) -> None: self.segmentation = segmentation self.cse_dets = cse_dets self.target_imsize = list(target_imsize) self.pre_processed = False self.exp_bbox_cfg = exp_bbox_cfg self.exp_bbox_filter = exp_bbox_filter self.dilation_percentage = dilation_percentage self.embed_map = embed_map self.embed_map_cpu = embed_map.cpu() self.normalize_embedding = normalize_embedding if self.normalize_embedding: embed_map_mean = self.embed_map.mean(dim=0, keepdim=True) embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd self.orig_imshape_CHW = orig_imshape_CHW @torch.no_grad() def pre_process(self): if self.pre_processed: return boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu() expanded_boxes = [] included_boxes = [] for i in range(len(boxes)): exp_box = get_expanded_bbox( boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): continue included_boxes.append(i) expanded_boxes.append(exp_box) expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) self.segmentation = self.segmentation[included_boxes] self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()} self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)) for i, box in enumerate(expanded_boxes): self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) self.maskrcnn_mask = self.mask.clone().logical_not()[:, None] self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) for i in range(len(expanded_boxes)): remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) self.boxes = expanded_boxes.cpu() self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) self.pre_processed = True self.n_detections = len(self.boxes) self.mask = self.mask.logical_not() E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool) self.vertices = torch.zeros_like(E_mask, dtype=torch.long) for i in range(self.n_detections): E_, E_mask[i] = transform_embedding( self.cse_dets["instance_embedding"][i], self.cse_dets["instance_segmentation"][i], self.boxes[i], self.cse_dets["bbox_XYXY"][i].cpu(), self.target_imsize ) self.vertices[i] = utils.from_E_to_vertex( E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None] self.E_mask = E_mask sorted_idx = torch.argsort(area, descending=False) self.mask = self.mask[sorted_idx] self.boxes = self.boxes[sorted_idx.cpu()] self.vertices = self.vertices[sorted_idx] self.E_mask = self.E_mask[sorted_idx] self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] def get_crop(self, idx: int, im): self.pre_process() assert idx < len(self) box = self.boxes[idx] mask = self.mask[idx] im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) vertices_ = self.vertices[idx] E_mask_ = self.E_mask[idx].float() if self.normalize_embedding: embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ else: embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ return dict( img=im, mask=mask.float()[None], boxes=box.reshape(1, -1), E_mask=E_mask_[None], vertices=vertices_[None], embed_map=self.embed_map, embedding=embedding[None], maskrcnn_mask=self.maskrcnn_mask[idx].float()[None] ) def __len__(self): self.pre_process() return self.n_detections def state_dict(self, after_preprocess=False): """ The processed annotations occupy more space than the original detections. """ if not after_preprocess: return { "combined_segmentation": self.segmentation.bool(), "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(), "cse_instance_embedding": self.cse_dets["instance_embedding"], "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(), "cls": self.__class__, "orig_imshape_CHW": self.orig_imshape_CHW } self.pre_process() def compress_bool(x): return torch.from_numpy(np.packbits(x.bool().cpu().numpy())) return dict( E_mask=compress_bool(self.E_mask), mask=compress_bool(self.mask), maskrcnn_mask=compress_bool(self.maskrcnn_mask), vertices=self.vertices.to(torch.int16).cpu(), cls=self.__class__, boxes=self.boxes, orig_imshape_CHW=self.orig_imshape_CHW, ) @staticmethod def from_state_dict( state_dict, embed_map, post_process_cfg, **kwargs): after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict if after_preprocess: detection = CSEPersonDetection( segmentation=None, cse_dets=None, embed_map=embed_map, orig_imshape_CHW=state_dict["orig_imshape_CHW"], **post_process_cfg) detection.vertices = tops.to_cuda(state_dict["vertices"].long()) numel = np.prod(detection.vertices.shape) def unpack_bool(x): x = torch.from_numpy(np.unpackbits(x.numpy(), count=numel)) return x.view(*detection.vertices.shape) detection.E_mask = tops.to_cuda(unpack_bool(state_dict["E_mask"])) detection.mask = tops.to_cuda(unpack_bool(state_dict["mask"])) detection.maskrcnn_mask = tops.to_cuda(unpack_bool(state_dict["maskrcnn_mask"])) detection.n_detections = len(detection.mask) detection.pre_processed = True if isinstance(state_dict["boxes"], np.ndarray): state_dict["boxes"] = torch.from_numpy(state_dict["boxes"]) detection.boxes = state_dict["boxes"] return detection cse_dets = dict( instance_segmentation=state_dict["cse_instance_segmentation"], instance_embedding=state_dict["cse_instance_embedding"], embed_map=embed_map, bbox_XYXY=state_dict["cse_bbox_XYXY"]) cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()} segmentation = state_dict["combined_segmentation"] return CSEPersonDetection( segmentation, cse_dets, embed_map=embed_map, orig_imshape_CHW=state_dict["orig_imshape_CHW"], **post_process_cfg) def visualize(self, im): self.pre_process() if len(self) == 0: return im im = vis_utils.draw_cropped_masks( im.cpu(), self.mask.cpu(), self.boxes, visualize_instances=False) E = self.embed_map_cpu[self.vertices.long().cpu()].squeeze(1).permute(0, 3, 1, 2) im = vis_utils.draw_cse_all( E, self.E_mask.squeeze(1).bool().cpu(), im, self.boxes, self.embed_map_cpu) im = torchvision.utils.draw_bounding_boxes(im, self.boxes, colors=(255, 0, 0), width=2) return im def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes): keypoints = keypoints.clone() N = boxes.shape[0] tops.assert_shape(keypoints, (N, None, 3)) tops.assert_shape(boxes, (N, 4)) x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T] w = x1 - x0 h = y1 - y0 keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h def check_outside(x): return (x < 0).logical_or(x > 1) is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1])) keypoints[:, :, 2] = keypoints[:, :, 2] > 0 keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not()) return keypoints class PersonDetection: def __init__( self, segmentation, target_imsize, exp_bbox_cfg, exp_bbox_filter, dilation_percentage: float, orig_imshape_CHW, kp_vis_thr=None, keypoints=None, **kwargs) -> None: self.segmentation = segmentation self.target_imsize = list(target_imsize) self.pre_processed = False self.exp_bbox_cfg = exp_bbox_cfg self.exp_bbox_filter = exp_bbox_filter self.dilation_percentage = dilation_percentage self.orig_imshape_CHW = orig_imshape_CHW self.orig_keypoints = keypoints if keypoints is not None: assert kp_vis_thr is not None self.kp_vis_thr = kp_vis_thr @torch.no_grad() def pre_process(self): if self.pre_processed: return boxes = masks_to_boxes(self.segmentation).cpu() expanded_boxes = [] included_boxes = [] for i in range(len(boxes)): exp_box = get_expanded_bbox( boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): continue included_boxes.append(i) expanded_boxes.append(exp_box) expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) self.segmentation = self.segmentation[included_boxes] if self.orig_keypoints is not None: self.keypoints = self.orig_keypoints[included_boxes].clone() self.keypoints[:, :, 2] = self.keypoints[:, :, 2] >= self.kp_vis_thr area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)).cpu() self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) for i, box in enumerate(expanded_boxes): self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] if self.orig_keypoints is not None: self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes) dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) self.maskrcnn_mask = self.mask.clone().logical_not()[:, None] self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) for i in range(len(expanded_boxes)): remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) self.boxes = expanded_boxes self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) self.pre_processed = True self.n_detections = len(self.boxes) self.mask = self.mask.logical_not() sorted_idx = torch.argsort(area, descending=False) self.mask = self.mask[sorted_idx] self.boxes = self.boxes[sorted_idx.cpu()] self.segmentation = self.segmentation[sorted_idx] self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] if self.keypoints is not None: self.keypoints = self.keypoints[sorted_idx.cpu()] def get_crop(self, idx: int, im: torch.Tensor): assert idx < len(self) self.pre_process() box = self.boxes[idx] mask = self.mask[idx][None].float() im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) batch = dict( img=im, mask=mask, boxes=box.reshape(1, -1), maskrcnn_mask=self.maskrcnn_mask[idx][None].float()) if self.keypoints is not None: batch["keypoints"] = self.keypoints[idx:idx+1] return batch def __len__(self): self.pre_process() return self.n_detections def state_dict(self, **kwargs): return dict( segmentation=self.segmentation.bool(), cls=self.__class__, orig_imshape_CHW=self.orig_imshape_CHW, keypoints=self.orig_keypoints ) @staticmethod def from_state_dict( state_dict, post_process_cfg, **kwargs): return PersonDetection( state_dict["segmentation"], orig_imshape_CHW=state_dict["orig_imshape_CHW"], **post_process_cfg, keypoints=state_dict["keypoints"]) def visualize(self, im): self.pre_process() im = im.cpu() if len(self) == 0: return im im = vis_utils.draw_cropped_masks(im.clone(), self.mask.cpu(), self.boxes, visualize_instances=False) if self.keypoints is not None: im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes) return im def get_dilated_boxes(exp_bbox: torch.LongTensor, mask): """ mask: resized mask """ assert exp_bbox.shape[0] == mask.shape[0] boxes = masks_to_boxes(mask.squeeze(1)).cpu() H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0] boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long() boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long() boxes[:, [0, 2]] += exp_bbox[:, 0:1] boxes[:, [1, 3]] += exp_bbox[:, 1:2] return boxes