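"""
Utilities for matching Mask R-CNN segmentations with CSE detections, and for
cropping, padding, and resizing image tensors to a target shape.
"""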
import cv2
import numpy as np
import torch
import tops
from skimage.morphology import disk
from torchvision.transforms.functional import resize, InterpolationMode
from functools import lru_cache

@lru_cache  # Cache one kernel per radius to avoid rebuilding and re-uploading to GPU.
def get_kernel(n: int):
    kernel = disk(n, dtype=bool)
    return tops.to_cuda(torch.from_numpy(kernel).bool())
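# Usage sketch (assumes a CUDA device is available, since tops.to_cuda moves
# the tensor to GPU): get_kernel(3) returns a 7x7 boolean disk, reused on
# repeated calls with the same radius.
# kernel = get_kernel(3)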

def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
    """
    Transforms the detected embedding/mask directly to the target image shape.
    """
    C, HE, WE = E.shape
    assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
    assert E_bbox[2] >= exp_bbox[0]
    assert E_bbox[1] >= exp_bbox[1]
    assert E_bbox[3] >= exp_bbox[1]
    assert E_bbox[2] <= exp_bbox[2]
    assert E_bbox[3] <= exp_bbox[3]

    x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
    y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
    new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
    new_S = torch.zeros(target_imshape, device=S.device, dtype=torch.bool)
    E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
    new_E[:, y0:y1, x0:x1] = E
    S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
    new_S[y0:y1, x0:x1] = S
    return new_E, new_S
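# Sketch with hypothetical values: an embedding detected in E_bbox
# (10, 10, 20, 20) inside exp_bbox (0, 0, 40, 40), with target_imshape
# (128, 128), is resized into new_E[:, 32:64, 32:64], since
# (10 - 0) / (40 - 0) * 128 = 32 and (20 - 0) / (40 - 0) * 128 = 64.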

def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
    """
    mask1, mask2: boolean masks of shape [N, H, W].
    Returns the [N1, N2] IoU matrix on CPU.
    """
    assert len(mask1.shape) == 3
    assert len(mask2.shape) == 3
    assert mask1.device == mask2.device, (mask1.device, mask2.device)
    assert mask1.dtype == mask2.dtype
    assert mask1.dtype == torch.bool
    assert mask1.shape[1:] == mask2.shape[1:]
    N1, H1, W1 = mask1.shape
    N2, H2, W2 = mask2.shape
    iou = torch.zeros((N1, N2), dtype=torch.float32)
    # Iterate over mask1 one instance at a time to bound peak memory usage.
    for i in range(N1):
        cur = mask1[i:i+1]
        inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
        iou[i] = inter / union
    return iou
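# Note: iou[i, j] is intersection area over union area; if both masks are
# empty the union is zero and the entry becomes NaN.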

def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
    """
    Greedily matches instances in mask1 to instances in mask2 with
    IoU >= iou_threshold. Each instance is matched at most once.
    """
    N1 = mask1.shape[0]
    N2 = mask2.shape[0]
    ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
    indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
    ious = ious.flatten()
    mask = ious >= iou_threshold
    ious = ious[mask]
    indices = indices[mask]

    # Do not sort by IoU: keep the ordering from the Mask R-CNN / CSE detectors.
    taken1 = np.zeros(N1, dtype=bool)
    taken2 = np.zeros(N2, dtype=bool)
    matches = []
    for i, j in indices:
        if taken1[i] or taken2[j]:
            continue
        matches.append((i, j))
        taken1[i] = True
        taken2[j] = True
    return matches
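# Greedy example (hypothetical): with candidate pairs [(0, 1), (1, 0), (1, 1)]
# surviving the threshold, (0, 1) and (1, 0) are taken; (1, 1) is skipped
# because both instances are already matched.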

def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
    """
    Matches Mask R-CNN masks to CSE detections and merges each matched pair
    into a single segmentation (the union of the two masks).
    """
    assert 0 < iou_threshold <= 1
    matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
    H, W = segmentation.shape[1:]
    new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
    cse_im_seg = cse_dets["im_segmentation"]
    for idx, (i, j) in enumerate(matches):
        new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
    cse_indices = [j for (i, j) in matches]
    cse_dets = dict(
        instance_segmentation=cse_dets["instance_segmentation"][cse_indices],
        instance_embedding=cse_dets["instance_embedding"][cse_indices],
        bbox_XYXY=cse_dets["bbox_XYXY"][cse_indices],
        scores=cse_dets["scores"][cse_indices],
    )
    return new_seg, cse_dets, np.array(matches).reshape(-1, 2)
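# Returns, per matched pair: the union of the two masks, the CSE detections
# filtered and reordered to the matches, and an [M, 2] array of
# (segmentation_idx, cse_idx) index pairs. Unmatched detections from either
# detector are dropped.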

def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
    """
    Combines boxes derived from the segmentation with the CSE boxes.
    cse_boxes can be outside of segmentation.
    """
    boxes = masks_to_boxes(segmentation)
    assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
    combined = torch.stack((boxes, cse_boxes), dim=-1)
    boxes = torch.cat((
        combined[:, :2].min(dim=2).values,
        combined[:, 2:].max(dim=2).values,
    ), dim=1)
    return boxes
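# The combined box is the elementwise union of the two boxes: the minimum of
# the two top-left corners and the maximum of the two bottom-right corners.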

def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
    """
    Crops or pads x to fit the bbox and resizes to target shape.
    """
    C, H, W = x.shape
    x0, y0, x1, y1 = bbox

    if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H:
        new_x = x[:, y0:y1, x0:x1]
    else:
        # Zero-pad the regions where the bbox extends outside the image.
        new_x = torch.zeros((C, y1-y0, x1-x0), dtype=x.dtype, device=x.device)
        y0_t = max(0, -y0)
        y1_t = min(y1-y0, (y1-y0)-(y1-H))
        x0_t = max(0, -x0)
        x1_t = min(x1-x0, (x1-x0)-(x1-W))
        x0 = max(0, x0)
        y0 = max(0, y0)
        x1 = min(x1, W)
        y1 = min(y1, H)
        new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
    # Nearest-neighbor upsampling often generates sharper synthesized identities.
    interp = InterpolationMode.BICUBIC
    if (y1-y0) < target_shape[0] and (x1-x0) < target_shape[1]:
        interp = InterpolationMode.NEAREST
    antialias = interp == InterpolationMode.BICUBIC
    if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]:
        return new_x
    if x.dtype == torch.bool:
        new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
    elif x.dtype == torch.float32:
        new_x = resize(new_x, target_shape, interpolation=interp, antialias=antialias)
    elif x.dtype == torch.uint8:
        if fdf_resize:  # The FDF dataset is created with cv2 INTER_AREA.
            # Incorrect resizing generates noticeably poorer inpaintings.
            upsampling = ((y1-y0) * (x1-x0)) < (target_shape[0] * target_shape[1])
            if upsampling:
                new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC,
                               antialias=True).round().clamp(0, 255).byte()
            else:
                device = new_x.device
                new_x = new_x.permute(1, 2, 0).cpu().numpy()
                new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA)
                new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)
        else:
            new_x = resize(new_x.float(), target_shape, interpolation=interp,
                           antialias=antialias).round().clamp(0, 255).byte()
    else:
        raise ValueError(f"Not supported dtype: {x.dtype}")
    return new_x
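# Example (hypothetical shapes): for a 3x256x256 uint8 image and
# bbox = (-16, -16, 112, 112), the 16-pixel overhang at the top and left is
# zero-padded, then the resulting 128x128 crop is resized to target_shape.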

def masks_to_boxes(segmentation: torch.Tensor):
    """
    Computes XYXY bounding boxes for a batch of masks of shape [N, H, W].
    """
    assert len(segmentation.shape) == 3
    x = segmentation.any(dim=1).byte()  # Compress rows
    x0 = x.argmax(dim=1)
    x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)
    y = segmentation.any(dim=2).byte()  # Compress columns
    y0 = y.argmax(dim=1)
    y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)
    return torch.stack([x0, y0, x1, y1], dim=1)
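# Caveat: an all-False mask yields the full-image box (0, 0, W, H) rather than
# an empty box, since argmax over an all-zero row returns 0.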