import cv2
import numpy as np
import torch
import tops
from skimage.morphology import disk
from torchvision.transforms.functional import resize, InterpolationMode
from functools import lru_cache


@lru_cache(maxsize=200)
def get_kernel(n: int):
    # Cache morphological disk kernels on the GPU; the same radii are requested repeatedly.
    kernel = disk(n, dtype=bool)
    return tops.to_cuda(torch.from_numpy(kernel).bool())


def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
    """
    Transforms the detected embedding/mask directly to the target image shape.
    """
    C, HE, WE = E.shape
    # The embedding bbox must lie inside the expanded bbox.
    assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
    assert E_bbox[2] >= exp_bbox[0]
    assert E_bbox[1] >= exp_bbox[1]
    assert E_bbox[3] >= exp_bbox[1]
    assert E_bbox[2] <= exp_bbox[2]
    assert E_bbox[3] <= exp_bbox[3]

    # Map the embedding bbox into target image coordinates.
    x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
    new_S = torch.zeros(target_imshape, device=S.device, dtype=torch.bool)

    E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
    new_E[:, y0:y1, x0:x1] = E
    S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
    new_S[y0:y1, x0:x1] = S
    return new_E, new_S


def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
    """
    mask: shape [N, H, W]
    """
    assert len(mask1.shape) == 3
    assert len(mask2.shape) == 3
    assert mask1.device == mask2.device, (mask1.device, mask2.device)
    assert mask1.dtype == mask2.dtype
    assert mask1.dtype == torch.bool
    assert mask1.shape[1:] == mask2.shape[1:]
    N1, H1, W1 = mask1.shape
    N2, H2, W2 = mask2.shape
    iou = torch.zeros((N1, N2), dtype=torch.float32)
    for i in range(N1):
        cur = mask1[i:i+1]
        inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        iou[i] = inter / union
    return iou


def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
    N1 = mask1.shape[0]
    N2 = mask2.shape[0]
    ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
    indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
    ious = ious.flatten()
    mask = ious >= iou_threshold
    ious = ious[mask]
    indices = indices[mask]

    # Do not sort by IoU, to keep the original ordering of the Mask R-CNN / CSE detections.
    taken1 = np.zeros((N1), dtype=bool)
    taken2 = np.zeros((N2), dtype=bool)
    matches = []
    for i, j in indices:
        # Greedily match each detection at most once.
        if taken1[i] or taken2[j]:
            continue
        matches.append((i, j))
        taken1[i] = True
        taken2[j] = True
    return matches
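
# Illustrative sketch (not part of the original module): how pairwise_mask_iou
# and find_best_matches behave on toy CPU masks. The mask layouts and the 0.5
# threshold below are made-up values for demonstration only.
def _example_find_best_matches():
    masks_a = torch.zeros((2, 4, 4), dtype=torch.bool)
    masks_b = torch.zeros((2, 4, 4), dtype=torch.bool)
    masks_a[0, :2, :2] = True   # top-left 2x2 square
    masks_a[1, 2:, 2:] = True   # bottom-right 2x2 square
    masks_b[0, 2:, 2:] = True   # identical to masks_a[1] -> IoU = 1
    masks_b[1, :2, 1:3] = True  # overlaps masks_a[0] by 2 px -> IoU = 2/6
    print(pairwise_mask_iou(masks_a, masks_b))
    # With a 0.5 threshold only the exact overlap survives: [(1, 0)].
    print(find_best_matches(masks_a, masks_b, iou_threshold=0.5))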
def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
    assert 0 < iou_threshold <= 1
    matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
    H, W = segmentation.shape[1:]
    new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
    cse_im_seg = cse_dets["im_segmentation"]
    for idx, (i, j) in enumerate(matches):
        new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
    cse_indices = [j for (i, j) in matches]
    cse_dets = dict(
        instance_segmentation=cse_dets["instance_segmentation"][cse_indices],
        instance_embedding=cse_dets["instance_embedding"][cse_indices],
        bbox_XYXY=cse_dets["bbox_XYXY"][cse_indices],
        scores=cse_dets["scores"][cse_indices],
    )
    return new_seg, cse_dets, np.array(matches).reshape(-1, 2)


def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
    """
    Combines the mask-derived boxes with cse_boxes, which can extend outside of the segmentation.
    """
    boxes = masks_to_boxes(segmentation)
    assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
    combined = torch.stack((boxes, cse_boxes), dim=-1)
    # Elementwise union of the two boxes: take the smaller x0/y0 and the larger x1/y1.
    boxes = torch.cat((
        combined[:, :2].min(dim=2).values,
        combined[:, 2:].max(dim=2).values,
    ), dim=1)
    return boxes
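
# Illustrative sketch (not part of the original module): initialize_cse_boxes
# returns the union of the mask-derived box and the CSE box, so a CSE box that
# extends past the segmentation enlarges the result. All values are made up.
def _example_initialize_cse_boxes():
    segmentation = torch.zeros((1, 8, 8), dtype=torch.bool)
    segmentation[0, 2:5, 2:5] = True          # mask-derived box: [2, 2, 5, 5]
    cse_boxes = torch.tensor([[1, 3, 4, 7]])  # extends left of and below the mask
    # Union of the two boxes: [min(2, 1), min(2, 3), max(5, 4), max(5, 7)] = [1, 2, 5, 7].
    print(initialize_cse_boxes(segmentation, cse_boxes))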
def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
    """
    Crops or pads x to fit the bbox, then resizes to target shape.
    """
    C, H, W = x.shape
    x0, y0, x1, y1 = bbox

    if y0 >= 0 and x0 >= 0 and x1 <= W and y1 <= H:
        # The bbox lies fully inside the image; crop directly.
        new_x = x[:, y0:y1, x0:x1]
    else:
        # The bbox extends outside the image; zero-pad the out-of-image region.
        new_x = torch.zeros((C, y1-y0, x1-x0), dtype=x.dtype, device=x.device)
        y0_t = max(0, -y0)
        y1_t = min(y1-y0, (y1-y0)-(y1-H))
        x0_t = max(0, -x0)
        x1_t = min(x1-x0, (x1-x0)-(x1-W))
        x0 = max(0, x0)
        y0 = max(0, y0)
        x1 = min(x1, W)
        y1 = min(y1, H)
        new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
    if new_x.shape[1] == target_shape[0] and new_x.shape[2] == target_shape[1]:
        return new_x
    # Nearest-neighbor upsampling often yields sharper synthesized identities.
    interp = InterpolationMode.BICUBIC
    if new_x.shape[1] < target_shape[0] and new_x.shape[2] < target_shape[1]:
        interp = InterpolationMode.NEAREST
    antialias = interp == InterpolationMode.BICUBIC
    if x.dtype == torch.bool:
        new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
    elif x.dtype == torch.float32:
        new_x = resize(new_x, target_shape, interpolation=interp, antialias=antialias)
    elif x.dtype == torch.uint8:
        if fdf_resize:
            # The FDF dataset is created with cv2 INTER_AREA.
            # Resizing differently generates noticeably poorer inpaintings.
            upsampling = new_x.shape[1] * new_x.shape[2] < target_shape[0] * target_shape[1]
            if upsampling:
                new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC,
                               antialias=True).round().clamp(0, 255).byte()
            else:
                device = new_x.device
                new_x = new_x.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for cv2
                new_x = cv2.resize(new_x, tuple(target_shape[::-1]), interpolation=cv2.INTER_AREA)
                new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)  # HWC -> CHW
        else:
            new_x = resize(new_x.float(), target_shape, interpolation=interp,
                           antialias=antialias).round().clamp(0, 255).byte()
    else:
        raise ValueError(f"Unsupported dtype: {x.dtype}")
    return new_x


def masks_to_boxes(segmentation: torch.Tensor):
    """
    Computes tight XYXY boxes from boolean masks of shape [N, H, W].
    """
    assert len(segmentation.shape) == 3
    x = segmentation.any(dim=1).byte()  # Project onto the x-axis (reduce over rows)
    x0 = x.argmax(dim=1)                # First occupied column
    x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)  # One past the last occupied column
    y = segmentation.any(dim=2).byte()  # Project onto the y-axis (reduce over columns)
    y0 = y.argmax(dim=1)
    y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)
    return torch.stack([x0, y0, x1, y1], dim=1)
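
# Illustrative sketch (not part of the original module): cut_pad_resize with a
# bbox extending outside the image, which exercises the zero-padding branch.
# The image size, bbox, and target shape are made-up values.
def _example_cut_pad_resize():
    img = torch.rand((3, 16, 16), dtype=torch.float32)
    # bbox (-4, -4, 12, 12): the 4-pixel overhang at the top-left is zero-padded,
    # giving a 16x16 crop, which is then downsampled (bicubic) to the 8x8 target.
    out = cut_pad_resize(img, (-4, -4, 12, 12), (8, 8))
    print(out.shape)  # torch.Size([3, 8, 8])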