import cv2
import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops.boxes import batched_nms

from app.mobile_sam import SamPredictor
from app.mobile_sam.utils import batched_mask_to_box
from app.sam.postprocessing import clean_mask_torch


def point_selection(
    mask_sim: torch.Tensor, topk: int = 1
) -> tuple[npt.NDArray, npt.NDArray]:
    """Select the top-k most similar locations as positive point prompts."""
    _, w = mask_sim.shape
    topk_idx = mask_sim.flatten(0).topk(topk)[1]
    topk_row = (topk_idx // w).unsqueeze(0)
    topk_col = topk_idx - topk_row * w
    # SAM expects point prompts as (x, y), i.e. (column, row).
    topk_xy = torch.cat((topk_col, topk_row), dim=0).permute(1, 0)
    topk_label = np.array([1] * topk)
    topk_xy = topk_xy.cpu().numpy()
    return topk_xy, topk_label


def mask_nms(
    masks: list[npt.NDArray], scores: list[float], iou_thresh: float = 0.2
) -> tuple[list[npt.NDArray], list[float]]:
    """Greedy non-maximum suppression over binary masks using mask IoU."""
    ious = np.zeros((len(masks), len(masks)))
    np_masks = np.array(masks).astype(bool)
    np_scores = np.array(scores)
    remove_indices = set()
    for i in range(len(masks)):
        mask_i = np_masks[i, :, :]
        intersection_sum = np.logical_and(mask_i, np_masks).sum(axis=(1, 2))
        union = np.logical_or(mask_i, np_masks)
        ious[i, :] = intersection_sum / union.sum(axis=(1, 2))
        # If the current mask almost completely covers another mask, keep only
        # the highest-scoring of the overlapping masks and drop the current one
        # if it is not the best.
        overlap = intersection_sum >= np_masks.sum(axis=(1, 2)) * 0.90
        argmax_idx = np_scores[overlap].argmax()
        max_idx = np.where(overlap)[0][argmax_idx]
        if max_idx != i:
            remove_indices.add(i)
    for i in range(ious.shape[0]):
        # Among all masks whose IoU with mask i exceeds the threshold,
        # keep only the highest-scoring one.
        idxs = np.where(ious[i, :] > iou_thresh)[0]
        keep = idxs[np.argmax(np_scores[idxs])]
        if keep != i:
            remove_indices.add(i)
    kept = [i for i in range(len(masks)) if i not in remove_indices]
    return [masks[i] for i in kept], [scores[i] for i in kept]


class MaskWeights(nn.Module):
    """Learnable weights for blending SAM's multi-scale mask logits."""

    def __init__(self):
        super().__init__()
        # nn.Parameter already enables gradients.
        self.weights = nn.Parameter(torch.ones(2, 1) / 3)


class PerSAM:
    def __init__(
        self,
        sam: SamPredictor,
        target_feat: torch.Tensor,
        max_objects: int,
        score_thresh: float,
        nms_iou_thresh: float,
        mask_weights: torch.Tensor,
    ) -> None:
        self.sam = sam
        self.weights = mask_weights
        self.target_feat = target_feat
        self.max_objects = max_objects
        self.score_thresh = score_thresh
        self.nms_iou_thresh = nms_iou_thresh

    def __call__(
        self, x: npt.NDArray
    ) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
        return fast_inference(
            self.sam,
            x,
            self.target_feat,
            self.weights,
            self.max_objects,
            self.score_thresh,
            self.nms_iou_thresh,
        )


def fast_inference(
    predictor: SamPredictor,
    image: npt.NDArray,
    target_feat: torch.Tensor,
    weights: torch.Tensor,
    max_objects: int,
    score_thresh: float,
    nms_iou_thresh: float = 0.2,
) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
    """Segment up to `max_objects` instances matching `target_feat` in `image`."""
    weights_np = weights.detach().cpu().numpy()
    pred_masks = []
    pred_scores = []

    # Image feature encoding
    predictor.set_image(image)
    test_feat = predictor.features.squeeze()

    # Cosine similarity between the target feature and every image location
    C, h, w = test_feat.shape
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = target_feat @ test_feat

    # Upsample the similarity map back to the original image resolution
    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()

    for _ in range(max_objects):
        # Positive location prior: the most similar remaining point
        topk_xy, topk_label = point_selection(sim, topk=1)

        # First-step prediction
        logits_high, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True,
            return_logits=True,
            return_numpy=False,
        )
        logits = logits.detach().cpu().numpy()

        # Weighted sum of the three multi-scale masks
        logits_high = logits_high * weights.unsqueeze(-1)
        logit_high = logits_high.sum(0)
        mask = logit_high > 0
        mask = clean_mask_torch(mask).bool()[0, 0, :, :].detach().cpu().numpy()

        logits = logits * weights_np[..., None]
        logit = logits.sum(0)

        # Cascaded post-refinement 1: prompt again with the mask's bounding box
        y, x = np.nonzero(mask)
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logit[None, :, :],
            multimask_output=True,
        )
        best_idx = np.argmax(scores)

        # Cascaded post-refinement 2: refine once more from the best mask
        y, x = np.nonzero(masks[best_idx])
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx : best_idx + 1, :, :],
            multimask_output=True,
            return_numpy=False,
        )
        best_idx = np.argmax(scores.detach().cpu().numpy())
        final_mask = masks[best_idx]
        score = sim[topk_xy[0][1], topk_xy[0][0]].item()

        if score < score_thresh:
            break

        # Suppress the detected region (slightly dilated) in the similarity
        # map so the next iteration picks a different object.
        final_mask_dilate = cv2.dilate(
            final_mask.detach().cpu().numpy().astype(np.uint8),
            np.ones((5, 5), np.uint8),
            iterations=1,
        )
        sim[torch.from_numpy(final_mask_dilate > 0).to(sim.device)] = 0
        pred_masks.append(final_mask)
        pred_scores.append(score)

    if len(pred_masks) == 0:
        return None, None, None

    # Box-level NMS across the collected instances
    pred_masks = torch.stack(pred_masks)
    bboxes = batched_mask_to_box(pred_masks)
    keep_by_nms = batched_nms(
        bboxes.float(),
        torch.as_tensor(pred_scores, device=bboxes.device),
        torch.zeros_like(bboxes[:, 0]),
        iou_threshold=nms_iou_thresh,
    )
    pred_masks = pred_masks[keep_by_nms].cpu().numpy()
    pred_scores = np.array(pred_scores)[keep_by_nms.cpu().numpy()]
    bboxes = bboxes[keep_by_nms].int().cpu().numpy()
    return pred_masks, bboxes, pred_scores
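

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (illustrative only, not part of the module).
# It assumes the app.mobile_sam fork keeps MobileSAM's `sam_model_registry`
# factory, and that `target_feat.pt` / `mask_weights.pt` were produced by a
# separate one-shot fitting step; the checkpoint paths, file names, and
# hyperparameter values below are placeholders.
#
#   from app.mobile_sam import sam_model_registry
#
#   sam = sam_model_registry["vit_t"](checkpoint="weights/mobile_sam.pt").eval()
#   predictor = SamPredictor(sam)
#   persam = PerSAM(
#       sam=predictor,
#       target_feat=torch.load("target_feat.pt"),
#       max_objects=5,
#       score_thresh=0.4,
#       nms_iou_thresh=0.2,
#       mask_weights=torch.load("mask_weights.pt"),
#   )
#   image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
#   masks, boxes, scores = persam(image)
# ---------------------------------------------------------------------------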