import torch
import numpy as np
from dp2 import utils
from dp2.utils import vis_utils, crop_box
from .utils import (
    cut_pad_resize, masks_to_boxes,
    get_kernel, transform_embedding, initialize_cse_boxes
)
from .box_utils import get_expanded_bbox, include_box
import torchvision
import tops
from .box_utils_fdf import expand_bbox as expand_bbox_fdf


class VehicleDetection:

    def __init__(self, segmentation: torch.BoolTensor) -> None:
        self.segmentation = segmentation
        self.boxes = masks_to_boxes(segmentation)
        assert self.boxes.shape[1] == 4, self.boxes.shape
        self.n_detections = self.segmentation.shape[0]
        area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
        # Sort by area (descending) so that the largest detections come first.
        sorted_idx = torch.argsort(area, descending=True)
        self.segmentation = self.segmentation[sorted_idx]
        self.boxes = self.boxes[sorted_idx].cpu()

    def pre_process(self):
        pass

    def get_crop(self, idx: int, im):
        assert idx < len(self)
        box = self.boxes[idx]
        # Crop the image and the corresponding segmentation to the detection box.
        im = crop_box(im, box)
        mask = crop_box(self.segmentation[idx], box)
        mask = mask == 0
        return dict(img=im, mask=mask.float(), boxes=box)

    def visualize(self, im):
        if len(self) == 0:
            return im
        im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
        return im

    def __len__(self):
        return self.n_detections

    @staticmethod
    def from_state_dict(state_dict, **kwargs):
        numel = np.prod(state_dict["shape"])
        arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
        segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
        return VehicleDetection(segmentation)

    def state_dict(self, **kwargs):
        segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
        return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
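
# Note (illustrative summary, not in the original source): all detection classes in
# this module share the same informal interface - pre_process(), get_crop(idx, im),
# visualize(im), __len__(), state_dict() and from_state_dict() - so callers can handle
# vehicle, face and person detections uniformly.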


class FaceDetection:

    def __init__(self,
                 boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool,
                 keypoints: torch.Tensor = None,
                 **kwargs) -> None:
        self.boxes = boxes_ltrb.cpu()
        assert self.boxes.shape[1] == 4, self.boxes.shape
        self.target_imsize = tuple(target_imsize)
        # Sort by area (ascending) so the largest faces are pasted last.
        area = ((self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1])).view(-1)
        idx = area.argsort(descending=False)
        self.boxes = self.boxes[idx]
        self.fdf128_expand = fdf128_expand
        self.orig_keypoints = keypoints
        if keypoints is not None:
            self.orig_keypoints = self.orig_keypoints[idx]
            assert keypoints.shape == (len(boxes_ltrb), 17, 2) or \
                keypoints.shape == (len(boxes_ltrb), 7, 2), keypoints.shape

    def visualize(self, im):
        if len(self) == 0:
            return im
        orig_device = im.device
        for box in self.boxes:
            simple_expand = not self.fdf128_expand
            e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
            im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
        im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
        if self.orig_keypoints is not None:
            im = vis_utils.draw_keypoints(im, self.orig_keypoints, radius=1)
        return im.to(device=orig_device)

    def get_crop(self, idx: int, im):
        assert idx < len(self)
        box = self.boxes[idx].numpy()
        simple_expand = not self.fdf128_expand
        expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand)
        im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
        # Find the square mask corresponding to box.
        box_mask = box.copy().astype(float)
        box_mask[[0, 2]] -= expanded_boxes[0]
        box_mask[[1, 3]] -= expanded_boxes[1]
        width = expanded_boxes[2] - expanded_boxes[0]
        resize_factor = self.target_imsize[0] / width
        box_mask = (box_mask * resize_factor).astype(int)
        mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32)
        crop_box(mask, box_mask).fill_(0)
        if self.orig_keypoints is None:
            return dict(
                img=im[None], mask=mask[None],
                boxes=torch.from_numpy(expanded_boxes).view(1, -1))
        # Shift keypoints into the expanded box and normalize to [0, 1].
        keypoint = self.orig_keypoints[idx, :7, :2].clone()
        keypoint[:, 0] -= expanded_boxes[0]
        keypoint[:, 1] -= expanded_boxes[1]
        w = expanded_boxes[2] - expanded_boxes[0]
        keypoint /= w
        keypoint = keypoint.clamp(0, 1)
        return dict(
            img=im[None], mask=mask[None],
            boxes=torch.from_numpy(expanded_boxes).view(1, -1),
            keypoints=keypoint[None])

    def __len__(self):
        return len(self.boxes)

    @staticmethod
    def from_state_dict(state_dict, **kwargs):
        return FaceDetection(
            state_dict["boxes"].cpu(),
            keypoints=state_dict["orig_keypoints"] if "orig_keypoints" in state_dict else None,
            **kwargs)

    def state_dict(self, **kwargs):
        return dict(
            boxes=self.boxes,
            cls=self.__class__,
            orig_keypoints=self.orig_keypoints)

    def pre_process(self):
        pass
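
# Note (illustrative summary, not in the original source): FaceDetection.get_crop
# returns a dict with
#   img:       the FDF-expanded face crop resized to target_imsize,
#   mask:      ones everywhere except zeros inside the shifted, rescaled detection box,
#   boxes:     the expanded box in original-image coordinates,
#   keypoints: (optional) the first 7 keypoints normalized to [0, 1] within the box.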


def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape):
    """
    Dilation happens after padding, which could place dilation in the padded area.
    Remove this.
    """
    x0, y0, x1, y1 = exp_box
    H, W = orig_imshape
    # Padding in original image space
    p_y0 = max(0, -y0)
    p_y1 = max(y1 - H, 0)
    p_x0 = max(0, -x0)
    p_x1 = max(x1 - W, 0)
    resize_ratio = mask.shape[-2] / (y1 - y0)
    p_x0, p_y0, p_x1, p_y1 = [(_ * resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]]
    mask[..., :p_y0, :] = 0
    mask[..., :p_x0] = 0
    mask[..., mask.shape[-2] - p_y1:, :] = 0
    mask[..., mask.shape[-1] - p_x1:] = 0
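
# Worked example (illustrative): with exp_box = (-10, 0, 90, 100) on a 100x100 image
# and a 128x128 dilated mask, the 10 padded pixels on the left map to
# floor(10 * 128 / 100) = 12 mask columns, so mask[..., :12] is zeroed above.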


class CSEPersonDetection:

    def __init__(self,
                 segmentation, cse_dets,
                 target_imsize,
                 exp_bbox_cfg, exp_bbox_filter,
                 dilation_percentage: float,
                 embed_map: torch.Tensor,
                 orig_imshape_CHW,
                 normalize_embedding: bool) -> None:
        self.segmentation = segmentation
        self.cse_dets = cse_dets
        self.target_imsize = list(target_imsize)
        self.pre_processed = False
        self.exp_bbox_cfg = exp_bbox_cfg
        self.exp_bbox_filter = exp_bbox_filter
        self.dilation_percentage = dilation_percentage
        self.embed_map = embed_map
        self.embed_map_cpu = embed_map.cpu()
        self.normalize_embedding = normalize_embedding
        if self.normalize_embedding:
            embed_map_mean = self.embed_map.mean(dim=0, keepdim=True)
            embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True) + 1e-8).rsqrt()
            self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd
        self.orig_imshape_CHW = orig_imshape_CHW

    def pre_process(self):
        if self.pre_processed:
            return
        boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu()
        expanded_boxes = []
        included_boxes = []
        for i in range(len(boxes)):
            exp_box = get_expanded_bbox(
                boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
                target_aspect_ratio=self.target_imsize[0] / self.target_imsize[1])
            if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
                continue
            included_boxes.append(i)
            expanded_boxes.append(exp_box)
        expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
        self.segmentation = self.segmentation[included_boxes]
        self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()}

        self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
        area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
        for i, box in enumerate(expanded_boxes):
            self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]

        dilation_kernel = get_kernel(int((self.target_imsize[0] * self.target_imsize[1])**0.5 * self.dilation_percentage))
        self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
        self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
        for i in range(len(expanded_boxes)):
            remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:])
        self.boxes = expanded_boxes.cpu()
        self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)

        self.pre_processed = True
        self.n_detections = len(self.boxes)
        self.mask = self.mask.logical_not()

        # Transfer the CSE embedding of each detection into the crop coordinate system.
        E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool)
        self.vertices = torch.zeros_like(E_mask, dtype=torch.long)
        for i in range(self.n_detections):
            E_, E_mask[i] = transform_embedding(
                self.cse_dets["instance_embedding"][i],
                self.cse_dets["instance_segmentation"][i],
                self.boxes[i],
                self.cse_dets["bbox_XYXY"][i].cpu(),
                self.target_imsize
            )
            self.vertices[i] = utils.from_E_to_vertex(
                E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None]
        self.E_mask = E_mask

        # Sort by segmentation area (ascending) so the largest detections are pasted last.
        sorted_idx = torch.argsort(area, descending=False)
        self.mask = self.mask[sorted_idx]
        self.boxes = self.boxes[sorted_idx.cpu()]
        self.vertices = self.vertices[sorted_idx]
        self.E_mask = self.E_mask[sorted_idx]
        self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]

    def get_crop(self, idx: int, im):
        self.pre_process()
        assert idx < len(self)
        box = self.boxes[idx]
        mask = self.mask[idx]
        im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)

        vertices_ = self.vertices[idx]
        E_mask_ = self.E_mask[idx].float()
        if self.normalize_embedding:
            embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
        else:
            embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_

        return dict(
            img=im,
            mask=mask.float()[None],
            boxes=box.reshape(1, -1),
            E_mask=E_mask_[None],
            vertices=vertices_[None],
            embed_map=self.embed_map,
            embedding=embedding[None],
            maskrcnn_mask=self.maskrcnn_mask[idx].float()[None]
        )

    def __len__(self):
        self.pre_process()
        return self.n_detections

    def state_dict(self, after_preprocess=False):
        """
        Serialize the detection. Note that the pre-processed annotations occupy
        more space than the original detections.
        """
        if not after_preprocess:
            return {
                "combined_segmentation": self.segmentation.bool(),
                "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(),
                "cse_instance_embedding": self.cse_dets["instance_embedding"],
                "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(),
                "cls": self.__class__,
                "orig_imshape_CHW": self.orig_imshape_CHW
            }
        self.pre_process()

        def compress_bool(x):
            return torch.from_numpy(np.packbits(x.bool().cpu().numpy()))

        return dict(
            E_mask=compress_bool(self.E_mask),
            mask=compress_bool(self.mask),
            maskrcnn_mask=compress_bool(self.maskrcnn_mask),
            vertices=self.vertices.to(torch.int16).cpu(),
            cls=self.__class__,
            boxes=self.boxes,
            orig_imshape_CHW=self.orig_imshape_CHW,
        )

    @staticmethod
    def from_state_dict(
            state_dict, embed_map,
            post_process_cfg, **kwargs):
        after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict
        if after_preprocess:
            detection = CSEPersonDetection(
                segmentation=None, cse_dets=None, embed_map=embed_map,
                orig_imshape_CHW=state_dict["orig_imshape_CHW"],
                **post_process_cfg)
            detection.vertices = tops.to_cuda(state_dict["vertices"].long())
            numel = np.prod(detection.vertices.shape)

            def unpack_bool(x):
                x = torch.from_numpy(np.unpackbits(x.numpy(), count=numel))
                return x.view(*detection.vertices.shape)

            detection.E_mask = tops.to_cuda(unpack_bool(state_dict["E_mask"]))
            detection.mask = tops.to_cuda(unpack_bool(state_dict["mask"]))
            detection.maskrcnn_mask = tops.to_cuda(unpack_bool(state_dict["maskrcnn_mask"]))
            detection.n_detections = len(detection.mask)
            detection.pre_processed = True

            if isinstance(state_dict["boxes"], np.ndarray):
                state_dict["boxes"] = torch.from_numpy(state_dict["boxes"])
            detection.boxes = state_dict["boxes"]
            return detection

        cse_dets = dict(
            instance_segmentation=state_dict["cse_instance_segmentation"],
            instance_embedding=state_dict["cse_instance_embedding"],
            embed_map=embed_map,
            bbox_XYXY=state_dict["cse_bbox_XYXY"])
        cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()}

        segmentation = state_dict["combined_segmentation"]
        return CSEPersonDetection(
            segmentation, cse_dets, embed_map=embed_map,
            orig_imshape_CHW=state_dict["orig_imshape_CHW"],
            **post_process_cfg)

    def visualize(self, im):
        self.pre_process()
        if len(self) == 0:
            return im
        im = vis_utils.draw_cropped_masks(
            im.cpu(), self.mask.cpu(), self.boxes, visualize_instances=False)
        E = self.embed_map_cpu[self.vertices.long().cpu()].squeeze(1).permute(0, 3, 1, 2)
        im = vis_utils.draw_cse_all(
            E, self.E_mask.squeeze(1).bool().cpu(), im,
            self.boxes, self.embed_map_cpu)
        im = torchvision.utils.draw_bounding_boxes(im, self.boxes, colors=(255, 0, 0), width=2)
        return im
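
# Note (illustrative summary, not in the original source): CSEPersonDetection supports
# two serialized forms. state_dict(after_preprocess=False) stores the raw combined
# segmentation and CSE outputs; state_dict(after_preprocess=True) stores the bit-packed
# crop masks and int16 vertices instead. from_state_dict detects which form it received
# by checking for the "combined_segmentation" key.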


def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes):
    keypoints = keypoints.clone()
    N = boxes.shape[0]
    tops.assert_shape(keypoints, (N, None, 3))
    tops.assert_shape(boxes, (N, 4))
    x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T]
    w = x1 - x0
    h = y1 - y0
    # Normalize keypoint coordinates to [0, 1] relative to each box.
    keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w
    keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h

    def check_outside(x):
        return (x < 0).logical_or(x > 1)

    # Mark keypoints falling outside the box as not visible.
    is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1]))
    keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not())
    return keypoints
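
# Worked example (illustrative): a keypoint at (x=60, y=40, vis=0.9) inside the box
# (x0=50, y0=30, x1=150, y1=130) becomes ((60 - 50) / 100, (40 - 30) / 100, 1)
# = (0.1, 0.1, 1); a keypoint landing outside [0, 1] after the shift keeps its
# coordinates but gets visibility 0.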


class PersonDetection:

    def __init__(
            self,
            segmentation,
            target_imsize,
            exp_bbox_cfg, exp_bbox_filter,
            dilation_percentage: float,
            orig_imshape_CHW,
            kp_vis_thr=None,
            keypoints=None,
            **kwargs) -> None:
        self.segmentation = segmentation
        self.target_imsize = list(target_imsize)
        self.pre_processed = False
        self.exp_bbox_cfg = exp_bbox_cfg
        self.exp_bbox_filter = exp_bbox_filter
        self.dilation_percentage = dilation_percentage
        self.orig_imshape_CHW = orig_imshape_CHW
        self.orig_keypoints = keypoints
        # Keep self.keypoints defined even when no keypoints are given, since
        # pre_process, get_crop and visualize compare it against None.
        self.keypoints = keypoints
        if keypoints is not None:
            assert kp_vis_thr is not None
            self.kp_vis_thr = kp_vis_thr

    def pre_process(self):
        if self.pre_processed:
            return
        boxes = masks_to_boxes(self.segmentation).cpu()
        expanded_boxes = []
        included_boxes = []
        for i in range(len(boxes)):
            exp_box = get_expanded_bbox(
                boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
                target_aspect_ratio=self.target_imsize[0] / self.target_imsize[1])
            if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
                continue
            included_boxes.append(i)
            expanded_boxes.append(exp_box)
        expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
        self.segmentation = self.segmentation[included_boxes]
        if self.orig_keypoints is not None:
            self.keypoints = self.orig_keypoints[included_boxes].clone()
            self.keypoints[:, :, 2] = self.keypoints[:, :, 2] >= self.kp_vis_thr

        area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)).cpu()
        self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
        for i, box in enumerate(expanded_boxes):
            self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
        if self.orig_keypoints is not None:
            self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes)

        dilation_kernel = get_kernel(int((self.target_imsize[0] * self.target_imsize[1])**0.5 * self.dilation_percentage))
        self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
        self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
        for i in range(len(expanded_boxes)):
            remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:])
        self.boxes = expanded_boxes
        self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)

        self.pre_processed = True
        self.n_detections = len(self.boxes)
        self.mask = self.mask.logical_not()

        # Sort by segmentation area (ascending) so the largest detections are pasted last.
        sorted_idx = torch.argsort(area, descending=False)
        self.mask = self.mask[sorted_idx]
        self.boxes = self.boxes[sorted_idx.cpu()]
        self.segmentation = self.segmentation[sorted_idx]
        self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
        if self.keypoints is not None:
            self.keypoints = self.keypoints[sorted_idx.cpu()]

    def get_crop(self, idx: int, im: torch.Tensor):
        assert idx < len(self)
        self.pre_process()
        box = self.boxes[idx]
        mask = self.mask[idx][None].float()
        im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
        batch = dict(
            img=im, mask=mask, boxes=box.reshape(1, -1),
            maskrcnn_mask=self.maskrcnn_mask[idx][None].float())
        if self.keypoints is not None:
            batch["keypoints"] = self.keypoints[idx:idx+1]
        return batch

    def __len__(self):
        self.pre_process()
        return self.n_detections

    def state_dict(self, **kwargs):
        return dict(
            segmentation=self.segmentation.bool(),
            cls=self.__class__,
            orig_imshape_CHW=self.orig_imshape_CHW,
            keypoints=self.orig_keypoints
        )

    @staticmethod
    def from_state_dict(
            state_dict,
            post_process_cfg, **kwargs):
        return PersonDetection(
            state_dict["segmentation"],
            orig_imshape_CHW=state_dict["orig_imshape_CHW"],
            **post_process_cfg,
            keypoints=state_dict["keypoints"])

    def visualize(self, im):
        self.pre_process()
        im = im.cpu()
        if len(self) == 0:
            return im
        im = vis_utils.draw_cropped_masks(im.clone(), self.mask.cpu(), self.boxes, visualize_instances=False)
        if self.keypoints is not None:
            im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes)
        return im


def get_dilated_boxes(exp_bbox: torch.LongTensor, mask):
    """
    Compute bounding boxes of the dilated masks in original-image coordinates.
    mask: dilated mask, resized to the target crop size.
    """
    assert exp_bbox.shape[0] == mask.shape[0]
    boxes = masks_to_boxes(mask.squeeze(1)).cpu()
    H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0]
    # Rescale from crop coordinates to the expanded box size, then shift to image coordinates.
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long()
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long()
    boxes[:, [0, 2]] += exp_bbox[:, 0:1]
    boxes[:, [1, 3]] += exp_bbox[:, 1:2]
    return boxes
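

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes a device
    # usable by tops.to_cuda and only exercises VehicleDetection, whose constructor
    # needs nothing beyond a boolean instance segmentation.
    dummy_segmentation = torch.zeros((2, 64, 64), dtype=torch.bool)
    dummy_segmentation[0, 10:30, 10:30] = True
    dummy_segmentation[1, 40:60, 5:25] = True
    detection = VehicleDetection(tops.to_cuda(dummy_segmentation))
    # Round-trip through the compact, bit-packed serialization.
    restored = VehicleDetection.from_state_dict(detection.state_dict())
    assert len(restored) == len(detection)
    print(f"Restored {len(restored)} vehicle detections, boxes:\n{restored.boxes}")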