import torch
import torch.nn.functional as F

from .sam2_implementation.modeling.sam2_base import SAM2Base as _SAM2Base


class SAM2Base(_SAM2Base):
    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        # Whether to track in reverse time order.
        track_in_reverse=False,
        # Whether to run the memory encoder on the predicted masks; this can be
        # skipped, e.g. until the prompts on a frame are finalized.
        run_mem_encoder=True,
        # Previously predicted SAM mask logits, which can be fed back in
        # together with new clicks.
        prev_sam_mask_logits=None,
        # Optional language embedding used as an extra sparse prompt.
        language_embd=None,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}

        if len(current_vision_feats) > 1:
            # Reshape the higher-resolution (earlier) feature levels from
            # (HW, B, C) back to (B, C, H, W) for the SAM decoder.
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When a mask is provided as input and the model is configured to
            # use it directly, skip the SAM heads and return it as the output.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
        else:
            # Fuse the current frame's visual features with memories of past
            # frames before running the SAM heads.
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            # Previously predicted low-res SAM mask logits (e.g. from an
            # earlier interaction) can be fed back in as the mask prompt.
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
                language_embd=language_embd,
            )
        _, _, _, low_res_masks, high_res_masks, obj_ptr, _ = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        # Finally, run the memory encoder on the predicted mask so that it can
        # be attended to as memory in future frames.
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

        return current_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
        # Optional language embedding used as an extra sparse prompt.
        language_embd=None,
    ):
""" |
|
Forward SAM prompt encoders and mask heads. |
|
|
|
Inputs: |
|
- backbone_features: image features of [B, C, H, W] shape |
|
- point_inputs: a dictionary with "point_coords" and "point_labels", where |
|
1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the |
|
absolute pixel-unit coordinate in (x, y) format of the P input points |
|
2) "point_labels" has shape [B, P] and int32 dtype, where 1 means |
|
positive clicks, 0 means negative clicks, and -1 means padding |
|
- mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the |
|
same spatial size as the image. |
|
- high_res_features: either 1) None or 2) or a list of length 2 containing |
|
two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, |
|
which will be used as high-resolution feature maps for SAM decoder. |
|
- multimask_output: if it's True, we output 3 candidate masks and their 3 |
|
corresponding IoU estimates, and if it's False, we output only 1 mask and |
|
its corresponding IoU estimate. |
|
|
|
Outputs: |
|
- low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if |
|
`multimask_output=True` and M = 1 if `multimask_output=False`), the SAM |
|
output mask logits (before sigmoid) for the low-resolution masks, with 4x |
|
the resolution (1/4 stride) of the input backbone_features. |
|
- high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 |
|
if `multimask_output=True` and M = 1 if `multimask_output=False`), |
|
upsampled from the low-resolution masks, with shape size as the image |
|
(stride is 1 pixel). |
|
- ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 |
|
if `multimask_output=False`), the estimated IoU of each output mask. |
|
- low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. |
|
If `multimask_output=True`, it's the mask with the highest IoU estimate. |
|
If `multimask_output=False`, it's the same as `low_res_multimasks`. |
|
- high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. |
|
If `multimask_output=True`, it's the mask with the highest IoU estimate. |
|
If `multimask_output=False`, it's the same as `high_res_multimasks`. |
|
- obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted |
|
based on the output token from the SAM mask decoder. |
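
        Example (illustrative sketch, not part of the reference API; `m`,
        `feats`, `coords`, and `labels` are hypothetical stand-ins for a
        constructed model, its backbone features, and point prompts as
        described above):

            outs = m._forward_sam_heads(
                backbone_features=feats,
                point_inputs={"point_coords": coords, "point_labels": labels},
                multimask_output=True,
            )
            low_res_masks, high_res_masks = outs[3], outs[4]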
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts.
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (label -1).
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts.
        if mask_inputs is not None:
            # If needed, downsize the mask to the size expected by the prompt
            # encoder, using antialiased bilinear interpolation.
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, leave it as None and let SAM's prompt encoder add a
            # learned "no mask" embedding.
            sam_mask_prompt = None
        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )

        if language_embd is not None:
            # Append the language embedding as extra sparse prompt tokens; its
            # batch size and embedding dim must match the point embeddings.
            assert sparse_embeddings.size(0) == language_embd.size(0)
            assert sparse_embeddings.size(2) == language_embd.size(2)
            sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)

        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

        # Convert the masks to float32 before upsampling, since `F.interpolate`
        # does not support bf16 inputs on older PyTorch versions.
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )
        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # Take the mask (and its output token) with the highest IoU estimate.
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
        # Extract the object pointer from the SAM output token, with occlusion
        # handling based on the predicted object scores.
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow a *soft* no-object pointer, unlike for the masks.
            if self.soft_no_obj_ptr:
                assert not self.teacher_force_obj_scores_for_mem
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )
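

# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the model code): `_forward_sam_heads` picks,
# per batch element, the candidate mask with the highest predicted IoU using
# `argmax` plus advanced indexing. The tensors below are random stand-ins used
# purely to demonstrate that indexing pattern; run this file as a module so the
# relative import at the top resolves.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    B, M, H, W = 2, 3, 8, 8  # hypothetical batch size, candidate count, mask size
    multimasks = torch.randn(B, M, H, W)  # stand-in for low_res_multimasks
    ious = torch.rand(B, M)  # stand-in for the decoder's IoU estimates

    best_iou_inds = torch.argmax(ious, dim=-1)  # [B], best candidate per item
    batch_inds = torch.arange(B)
    best_masks = multimasks[batch_inds, best_iou_inds].unsqueeze(1)  # [B, 1, H, W]
    assert best_masks.shape == (B, 1, H, W)
    print("selected mask shape:", tuple(best_masks.shape))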