dillonlaird's picture
fixed noise in mask issue
2751f79
raw
history blame
No virus
6.8 kB
import numpy.typing as npt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
from torchvision.ops.boxes import batched_nms
from app.mobile_sam import SamPredictor
from app.mobile_sam.utils import batched_mask_to_box
from app.sam.postprocessing import clean_mask_torch
def point_selection(mask_sim, topk: int = 1):
    """Pick the ``topk`` highest-valued locations of a 2-D map as point prompts.

    Args:
        mask_sim: 2-D tensor (rows, columns) of similarity values.
        topk: Number of peak locations to return.

    Returns:
        A ``(coords, labels)`` pair: ``coords`` is a ``(topk, 2)`` NumPy array
        whose rows are ``(column, row)`` positions ordered by decreasing value,
        and ``labels`` is a length-``topk`` NumPy array of ones.
    """
    _, n_cols = mask_sim.shape
    # Indices into the row-major flattened map, best first.
    flat_idx = mask_sim.flatten(0).topk(topk)[1]
    # Recover 2-D positions from the flat index.
    rows = (flat_idx // n_cols).unsqueeze(0)
    cols = flat_idx - rows * n_cols
    # Stack as (column, row) so each output row reads as an (x, y) point.
    coords = torch.cat((cols, rows), dim=0).permute(1, 0).cpu().numpy()
    labels = np.array([1] * topk)
    return coords, labels
def mask_nms(
    masks: list[npt.NDArray], scores: list[float], iou_thresh: float = 0.2
) -> tuple[list[npt.NDArray], list[float]]:
    """Suppress redundant masks, keeping the best-scoring one of each group.

    A mask is dropped when either rule finds a better-scoring competitor:

    1. Containment: some other mask has at least 90% of its own area inside
       this mask, and the best score among those masks (this one included)
       belongs to a different mask.
    2. IoU: among all masks whose IoU with this one exceeds ``iou_thresh``
       (this one included), the best score belongs to a different mask.

    Empty (all-zero) masks have an IoU of 0 with everything and never count as
    "contained" in another mask, so they neither suppress nor are suppressed.

    Args:
        masks: Equally shaped 2-D masks; any non-zero value is foreground.
        scores: One score per mask, higher is better.
        iou_thresh: IoU above which two masks are considered duplicates.

    Returns:
        ``(kept_masks, kept_scores)`` in the original order, taken from the
        input lists.
    """
    num_masks = len(masks)
    if num_masks == 0:
        return [], []

    np_masks = np.array(masks).astype(bool)
    np_scores = np.array(scores)
    # Per-mask foreground area; loop-invariant, so computed once.
    areas = np_masks.sum(axis=(1, 2))
    ious = np.zeros((num_masks, num_masks))
    remove_indices = set()

    for i in range(num_masks):
        mask_i = np_masks[i]
        intersection = np.logical_and(mask_i, np_masks).sum(axis=(1, 2))
        union = np.logical_or(mask_i, np_masks).sum(axis=(1, 2))
        # A zero union (two empty masks) would give 0/0 = NaN; report IoU 0.
        ious[i, :] = np.divide(
            intersection, union, out=np.zeros(num_masks), where=union > 0
        )

        # Masks that lie (>= 90%) inside mask i. Empty masks are excluded
        # because "0 >= 0" would make them trivially contained everywhere.
        contained = (intersection >= areas * 0.90) & (areas > 0)
        contained[i] = True  # a mask always competes with itself
        candidates = np.where(contained)[0]
        if candidates[np_scores[candidates].argmax()] != i:
            remove_indices.add(i)

    for i in range(num_masks):
        similar = ious[i, :] > iou_thresh
        # Keep the candidate set non-empty even when iou_thresh >= 1 or the
        # mask is empty; otherwise argmax below would raise.
        similar[i] = True
        candidates = np.where(similar)[0]
        if candidates[np.argmax(np_scores[candidates])] != i:
            remove_indices.add(i)

    kept = [i for i in range(num_masks) if i not in remove_indices]
    return [masks[i] for i in kept], [scores[i] for i in kept]
class MaskWeights(nn.Module):
    """Module holding a single learnable ``(2, 1)`` weight tensor.

    Both entries start at 1/3 and are registered as an ``nn.Parameter`` under
    the attribute ``weights``.
    """

    def __init__(self):
        super().__init__()
        # Two weights, each initialised to one third.
        initial_weights = torch.ones(2, 1, requires_grad=True) / 3
        self.weights = nn.Parameter(initial_weights)
class PerSAM:
    """Callable that runs ``fast_inference`` with a fixed predictor and settings.

    The constructor only stores its arguments; all work happens in
    ``__call__``.
    """

    def __init__(
        self,
        sam: SamPredictor,
        target_feat: torch.Tensor,
        max_objects: int,
        score_thresh: float,
        nms_iou_thresh: float,
        mask_weights: torch.Tensor,
    ) -> None:
        """Store the predictor, reference feature and inference settings.

        Args:
            sam: Predictor forwarded to ``fast_inference``.
            target_feat: Reference feature compared against image features.
            max_objects: Upper bound on the number of objects searched for.
            score_thresh: Minimum similarity score for an object to be kept.
            nms_iou_thresh: IoU threshold for the final box NMS.
            mask_weights: Weights used to blend the multi-mask logits.
        """
        super().__init__()
        self.sam = sam
        self.weights = mask_weights
        self.target_feat = target_feat
        self.max_objects = max_objects
        self.score_thresh = score_thresh
        self.nms_iou_thresh = nms_iou_thresh

    def __call__(
        self, x: npt.NDArray
    ) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
        """Run inference on image ``x``.

        Returns:
            ``(masks, bboxes, scores)`` as returned by ``fast_inference``, or
            ``(None, None, None)`` when no object reaches the score threshold.
        """
        return fast_inference(
            predictor=self.sam,
            image=x,
            target_feat=self.target_feat,
            weights=self.weights,
            max_objects=self.max_objects,
            score_thresh=self.score_thresh,
            nms_iou_thresh=self.nms_iou_thresh,
        )
def _similarity_map(predictor: SamPredictor, target_feat: torch.Tensor) -> torch.Tensor:
    """Return the 2-D similarity map between ``target_feat`` and the current image."""
    test_feat = predictor.features.squeeze()
    C, h, w = test_feat.shape
    # Unit-normalise each spatial feature vector along the channel axis.
    # NOTE(review): this is a cosine similarity only if target_feat is
    # already L2-normalised - confirm against the caller.
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = (target_feat @ test_feat).reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    return predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()


def _box_prompt(mask: npt.NDArray) -> dict[str, npt.NDArray]:
    """Return the ``box=`` keyword for ``predict`` from the mask's bounding box.

    An empty dict is returned when the mask has no foreground, so the caller
    simply omits the box prompt instead of failing on ``min()`` of nothing.
    """
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return {}
    box = np.array([xs.min(), ys.min(), xs.max(), ys.max()])
    return {"box": box[None, :]}


def fast_inference(
    predictor: SamPredictor,
    image: npt.NDArray,
    target_feat: torch.Tensor,
    weights: torch.Tensor,
    max_objects: int,
    score_thresh: float,
    nms_iou_thresh: float = 0.2,
) -> tuple[npt.NDArray | None, npt.NDArray | None, npt.NDArray | None]:
    """Iteratively segment up to ``max_objects`` regions similar to ``target_feat``.

    Each iteration picks the highest-similarity location, prompts the
    predictor with it, refines the result twice using the previous mask's box
    and logits, then zeroes the similarity under the (dilated) final mask so
    the next iteration looks elsewhere. The loop stops early once the best
    remaining similarity drops below ``score_thresh``.

    Args:
        predictor: Predictor used to encode ``image`` and decode masks.
        image: Image passed to ``predictor.set_image``.
        target_feat: Reference feature matched against the image features.
        weights: Weights multiplied with the multi-mask logits before they are
            summed over the first axis.
        max_objects: Maximum number of loop iterations / objects.
        score_thresh: Minimum similarity at the selected point to accept it.
        nms_iou_thresh: IoU threshold for the final box-level NMS.

    Returns:
        ``(masks, bboxes, scores)`` as NumPy arrays for the detections that
        survive NMS, or ``(None, None, None)`` when nothing was detected.
    """
    weights_np = weights.detach().cpu().numpy()
    pred_masks = []
    pred_scores = []

    # Image feature encoding
    predictor.set_image(image)
    sim = _similarity_map(predictor, target_feat)

    for _ in range(max_objects):
        # Positive location prior
        topk_xy, topk_label = point_selection(sim, topk=1)

        # The score depends only on the selected point, so test it before
        # spending three decoder passes on a candidate that will be rejected.
        score = sim[topk_xy[0][1], topk_xy[0][0]].item()
        if score < score_thresh:
            break

        # First-step prediction
        logits_high, _, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True,
            return_logits=True,
            return_numpy=False,
        )
        logits = logits.detach().cpu().numpy()

        # Weighted sum of the multi-mask outputs
        logit_high = (logits_high * weights.unsqueeze(-1)).sum(0)
        mask = clean_mask_torch(logit_high > 0).bool()[0, 0, :, :].detach().cpu().numpy()
        logit = (logits * weights_np[..., None]).sum(0)

        # Cascaded Post-refinement-1
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            mask_input=logit[None, :, :],
            multimask_output=True,
            **_box_prompt(mask),
        )
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2
        masks, scores, logits = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            mask_input=logits[best_idx : best_idx + 1, :, :],
            multimask_output=True,
            return_numpy=False,
            **_box_prompt(masks[best_idx]),
        )
        best_idx = np.argmax(scores.detach().cpu().numpy())
        final_mask = masks[best_idx]

        # Dilate before suppressing so the mask border cannot be re-selected.
        final_mask_dilate = cv2.dilate(
            final_mask.detach().cpu().numpy().astype(np.uint8),
            np.ones((5, 5), np.uint8),
            iterations=1,
        )
        # Index with an explicit bool tensor; uint8 mask indexing is deprecated.
        suppress = torch.from_numpy(final_mask_dilate.astype(bool)).to(sim.device)
        sim[suppress] = 0

        pred_masks.append(final_mask)
        pred_scores.append(score)

    if not pred_masks:
        return None, None, None

    stacked_masks = torch.stack(pred_masks)
    bboxes = batched_mask_to_box(stacked_masks)
    keep_by_nms = batched_nms(
        bboxes.float(),
        # Scores must live on the same device as the boxes for NMS.
        torch.as_tensor(pred_scores, device=bboxes.device),
        torch.zeros_like(bboxes[:, 0]),  # single class: all in one NMS group
        iou_threshold=nms_iou_thresh,
    )
    kept_masks = stacked_masks[keep_by_nms].cpu().numpy()
    kept_scores = np.array(pred_scores)[keep_by_nms.cpu().numpy()]
    kept_bboxes = bboxes[keep_by_nms].int().cpu().numpy()
    return kept_masks, kept_bboxes, kept_scores