import os
from copy import deepcopy
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint

from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid


def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
    """Write `out_value` under `out[out_name]`.

    When `auxiliary` is True, `out_value` is a per-decoder-layer stack: the last
    entry becomes the main output and, if `update_aux` is also True, the earlier
    entries are written into the corresponding `out["aux_outputs"]` dicts.
    """
    out[out_name] = out_value[-1] if auxiliary else out_value
    if auxiliary and update_aux:
        if "aux_outputs" not in out:
            out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
        assert len(out["aux_outputs"]) == len(out_value) - 1
        for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
            aux_output[out_name] = aux_value


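# Illustrative example: if `val` stacks the per-decoder-layer predictions (one entry
# per layer, last layer last), then after `_update_out(out, "pred_logits", val)`:
#   out["pred_logits"]                   == val[-1]
#   out["aux_outputs"][i]["pred_logits"] == val[i]   (for the earlier layers)

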
class Sam3Image(torch.nn.Module):
    TEXT_ID_FOR_TEXT = 0
    TEXT_ID_FOR_VISUAL = 1
    TEXT_ID_FOR_GEOMETRIC = 2

    def __init__(
        self,
        backbone: SAM3VLBackbone,
        transformer,
        input_geometry_encoder,
        segmentation_head=None,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=None,
        use_instance_query: bool = True,
        multimask_output: bool = True,
        use_act_checkpoint_seg_head: bool = True,
        interactivity_in_encoder: bool = True,
        matcher=None,
        use_dot_prod_scoring=True,
        supervise_joint_box_scores: bool = False,
        detach_presence_in_joint_score: bool = False,
        separate_scorer_for_instance: bool = False,
        num_interactive_steps_val: int = 0,
        inst_interactive_predictor: Optional[SAM3InteractiveImagePredictor] = None,
        **kwargs,
    ):
        super().__init__()
        self.backbone = backbone
        self.geometry_encoder = input_geometry_encoder
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.segmentation_head = segmentation_head

        self.o2m_mask_predict = o2m_mask_predict

        self.dot_prod_scoring = dot_prod_scoring
        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
        self.interactivity_in_encoder = interactivity_in_encoder
        self.matcher = matcher

        self.num_interactive_steps_val = num_interactive_steps_val
        self.use_dot_prod_scoring = use_dot_prod_scoring

        # Scoring head: either a dot-product head against the prompt tokens or a plain
        # linear classifier, optionally duplicated for instance prompts.
        if self.use_dot_prod_scoring:
            assert dot_prod_scoring is not None
            self.dot_prod_scoring = dot_prod_scoring
            self.instance_dot_prod_scoring = None
            if separate_scorer_for_instance:
                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
        else:
            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
            self.instance_class_embed = None
            if separate_scorer_for_instance:
                self.instance_class_embed = deepcopy(self.class_embed)

        self.supervise_joint_box_scores = supervise_joint_box_scores
        self.detach_presence_in_joint_score = detach_presence_in_joint_score

        # With DAC, the decoder carries one-to-many (o2m) queries in addition to an
        # equal number of one-to-one (o2o) queries.
        num_o2o_static = self.transformer.decoder.num_queries
        num_o2m_static = self.transformer.decoder.num_o2m_queries
        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
        self.dac = self.transformer.decoder.dac

        self.use_instance_query = use_instance_query
        self.multimask_output = multimask_output

        self.inst_interactive_predictor = inst_interactive_predictor

    @property
    def device(self):
        self._device = getattr(self, "_device", None) or next(self.parameters()).device
        return self._device

    def to(self, *args, **kwargs):
        # Invalidate the cached device; .to() may move the module to another device.
        self._device = None
        return super().to(*args, **kwargs)

    def _get_img_feats(self, backbone_out, img_ids):
        """Retrieve the correct image features from the backbone output."""
        if "backbone_fpn" in backbone_out:
            if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None:
                img_ids = backbone_out["id_mapping"][img_ids]
                # All requested images must be present in the cached features
                # (a mapped id of -1 means the image was never encoded).
                torch._assert_async((img_ids >= 0).all())

            vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :]
            vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
            vis_feat_sizes = [x.shape[-2:] for x in vis_pos_enc]

            # Select the requested images and flatten each level to (HW, B, C).
            img_feats = [x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_feats]
            img_pos_embeds = [
                x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_pos_enc
            ]
            return backbone_out, img_feats, img_pos_embeds, vis_feat_sizes

        # Backbone features are not available yet: encode the (unique) requested images
        # now and record the mapping from image id to feature index.
        img_batch = backbone_out["img_batch_all_stages"]
        if img_ids.numel() > 1:
            unique_ids, _ = torch.unique(img_ids, return_inverse=True)
        else:
            unique_ids, _ = img_ids, slice(None)

        if isinstance(img_batch, torch.Tensor):
            image = img_batch[unique_ids]
        elif unique_ids.numel() == 1:
            image = img_batch[unique_ids.item()].unsqueeze(0)
        else:
            image = torch.stack([img_batch[i] for i in unique_ids.tolist()])

        image = image.to(dtype=torch.float32, device=self.device)

        id_mapping = torch.full(
            (len(img_batch),), -1, dtype=torch.long, device=self.device
        )
        id_mapping[unique_ids] = torch.arange(len(unique_ids), device=self.device)
        backbone_out = {
            **backbone_out,
            **self.backbone.forward_image(image),
            "id_mapping": id_mapping,
        }
        assert "backbone_fpn" in backbone_out
        return self._get_img_feats(backbone_out, img_ids=img_ids)

    def _encode_prompt(
        self,
        backbone_out,
        find_input,
        geometric_prompt,
        visual_prompt_embed=None,
        visual_prompt_mask=None,
        encode_text=True,
        prev_mask_pred=None,
    ):
        """Build the prompt token sequence (text + geometric + visual) and its mask."""
        # Text features of the prompts for this batch.
        txt_ids = find_input.text_ids
        txt_feats = backbone_out["language_features"][:, txt_ids]
        txt_masks = backbone_out["language_mask"][txt_ids]

        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple

        if prev_mask_pred is not None:
            img_feats = [img_feats[-1] + prev_mask_pred]

        geo_feats, geo_masks = self.geometry_encoder(
            geo_prompt=geometric_prompt,
            img_feats=img_feats,
            img_sizes=vis_feat_sizes,
            img_pos_embeds=img_pos_embeds,
        )
        if visual_prompt_embed is None:
            # No visual prompt: use empty (zero-length) placeholders.
            visual_prompt_embed = torch.zeros(
                (0, *geo_feats.shape[1:]), device=geo_feats.device
            )
            visual_prompt_mask = torch.zeros(
                (*geo_masks.shape[:-1], 0),
                device=geo_masks.device,
                dtype=geo_masks.dtype,
            )
        if encode_text:
            prompt = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1)
        else:
            prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
        return prompt, prompt_mask, backbone_out

    def _run_encoder(
        self,
        backbone_out,
        find_input,
        prompt,
        prompt_mask,
        encoder_extra_kwargs: Optional[Dict] = None,
    ):
        """Fuse the image features with the prompt tokens in the transformer encoder."""
        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple

        # The prompt tokens carry no positional encoding of their own.
        prompt_pos_embed = torch.zeros_like(prompt)

        memory = self.transformer.encoder(
            src=img_feats.copy(),
            src_key_padding_mask=None,
            src_pos=img_pos_embeds.copy(),
            prompt=prompt,
            prompt_pos=prompt_pos_embed,
            prompt_key_padding_mask=prompt_mask,
            feat_sizes=vis_feat_sizes,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        encoder_out = {
            # Fused image memory and its metadata.
            "encoder_hidden_states": memory["memory"],
            "pos_embed": memory["pos_embed"],
            "padding_mask": memory["padding_mask"],
            "level_start_index": memory["level_start_index"],
            "spatial_shapes": memory["spatial_shapes"],
            "valid_ratios": memory["valid_ratios"],
            "vis_feat_sizes": vis_feat_sizes,
            # Prompt tokens before and after fusion with the image features.
            "prompt_before_enc": prompt,
            "prompt_after_enc": memory.get("memory_text", prompt),
            "prompt_mask": prompt_mask,
        }
        return backbone_out, encoder_out, feat_tuple

    def _run_decoder(
        self,
        pos_embed,
        memory,
        src_mask,
        out,
        prompt,
        prompt_mask,
        encoder_out,
    ):
        """Decode object queries against the fused memory and update `out` in place."""
        bs = memory.shape[1]
        query_embed = self.transformer.decoder.query_embed.weight
        tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)

        # The extra one-to-many (DAC) queries are only used during training.
        apply_dac = self.transformer.decoder.dac and self.training
        hs, reference_boxes, dec_presence_out, dec_presence_feats = (
            self.transformer.decoder(
                tgt=tgt,
                memory=memory,
                memory_key_padding_mask=src_mask,
                pos=pos_embed,
                reference_boxes=None,
                level_start_index=encoder_out["level_start_index"],
                spatial_shapes=encoder_out["spatial_shapes"],
                valid_ratios=encoder_out["valid_ratios"],
                tgt_mask=None,
                memory_text=prompt,
                text_attention_mask=prompt_mask,
                apply_dac=apply_dac,
            )
        )
        hs = hs.transpose(1, 2)
        reference_boxes = reference_boxes.transpose(1, 2)
        if dec_presence_out is not None:
            dec_presence_out = dec_presence_out.transpose(1, 2)

        out["presence_feats"] = dec_presence_feats
        self._update_scores_and_boxes(
            out,
            hs,
            reference_boxes,
            prompt,
            prompt_mask,
            dec_presence_out=dec_presence_out,
        )
        return out, hs

    def _update_scores_and_boxes(
        self,
        out,
        hs,
        reference_boxes,
        prompt,
        prompt_mask,
        dec_presence_out=None,
        is_instance_prompt=False,
    ):
        apply_dac = self.transformer.decoder.dac and self.training
        num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
        num_o2m = hs.size(2) - num_o2o
        assert num_o2m == (num_o2o if apply_dac else 0)
        out["queries"] = hs[-1][:, :num_o2o]

        # Classification scores: dot-product against the prompt tokens or a linear
        # head, with optional instance-specific variants.
        if self.use_dot_prod_scoring:
            dot_prod_scoring_head = self.dot_prod_scoring
            if is_instance_prompt and self.instance_dot_prod_scoring is not None:
                dot_prod_scoring_head = self.instance_dot_prod_scoring
            outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
        else:
            class_embed_head = self.class_embed
            if is_instance_prompt and self.instance_class_embed is not None:
                class_embed_head = self.instance_class_embed
            outputs_class = class_embed_head(hs)

        # Boxes are predicted as offsets relative to the decoder's reference boxes.
        box_head = self.transformer.decoder.bbox_embed
        if (
            is_instance_prompt
            and self.transformer.decoder.instance_bbox_embed is not None
        ):
            box_head = self.transformer.decoder.instance_bbox_embed
        anchor_box_offsets = box_head(hs)
        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
        outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
        outputs_boxes_xyxy = box_cxcywh_to_xyxy(outputs_coord)

        if dec_presence_out is not None:
            _update_out(
                out, "presence_logit_dec", dec_presence_out, update_aux=self.training
            )

        if self.supervise_joint_box_scores:
            assert dec_presence_out is not None
            prob_dec_presence_out = dec_presence_out.clone().sigmoid()
            if self.detach_presence_in_joint_score:
                prob_dec_presence_out = prob_dec_presence_out.detach()
            # Combine per-query scores with the presence probability in probability
            # space, then map back to logits (clamped for numerical stability).
            outputs_class = inverse_sigmoid(
                outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)
            ).clamp(min=-10.0, max=10.0)

        _update_out(
            out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out,
            "pred_boxes_xyxy",
            outputs_boxes_xyxy[:, :, :num_o2o],
            update_aux=self.training,
        )
        if num_o2m > 0 and self.training:
            _update_out(
                out,
                "pred_logits_o2m",
                outputs_class[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_o2m",
                outputs_coord[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_xyxy_o2m",
                outputs_boxes_xyxy[:, :, num_o2o:],
                update_aux=self.training,
            )

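    # Note on joint box scoring: with a per-query logit s and presence logit q, the
    # combined score used above is logit(sigmoid(s) * sigmoid(q)), clamped to
    # [-10, 10]; the presence term may be detached first so only the per-query
    # scores receive the joint supervision gradient.
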
    def _run_segmentation_heads(
        self,
        out,
        backbone_out,
        img_ids,
        vis_feat_sizes,
        encoder_hidden_states,
        prompt,
        prompt_mask,
        hs,
    ):
        apply_dac = self.transformer.decoder.dac and self.training
        if self.segmentation_head is not None:
            num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
            num_o2m = hs.size(2) - num_o2o
            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
            seg_head_outputs = activation_ckpt_wrapper(self.segmentation_head)(
                backbone_feats=backbone_out["backbone_fpn"],
                obj_queries=obj_queries,
                image_ids=img_ids,
                encoder_hidden_states=encoder_hidden_states,
                act_ckpt_enable=self.training and self.use_act_checkpoint_seg_head,
                prompt=prompt,
                prompt_mask=prompt_mask,
            )
            aux_masks = False
            for k, v in seg_head_outputs.items():
                if k in self.segmentation_head.instance_keys:
                    _update_out(out, k, v[:, :num_o2o], auxiliary=aux_masks)
                    if self.o2m_mask_predict and num_o2m > 0:
                        _update_out(
                            out, f"{k}_o2m", v[:, num_o2o:], auxiliary=aux_masks
                        )
                else:
                    out[k] = v
        else:
            backbone_out.pop("backbone_fpn", None)

    def _get_best_mask(self, out):
        """Select the highest-scoring predicted mask and encode it as a dense prompt."""
        prev_mask_idx = out["pred_logits"].argmax(dim=1).squeeze(1)
        batch_idx = torch.arange(
            out["pred_logits"].shape[0], device=prev_mask_idx.device
        )
        prev_mask_pred = out["pred_masks"][batch_idx, prev_mask_idx][:, None]

        # Downsample the selected mask and flatten it to (HW, B, C) so it can be added
        # to the lowest-resolution image features in `_encode_prompt`.
        prev_mask_pred = self.geometry_encoder.mask_encoder.mask_downsampler(
            prev_mask_pred
        )
        prev_mask_pred = prev_mask_pred.flatten(-2).permute(2, 0, 1)

        return prev_mask_pred

    def forward_grounding(
        self,
        backbone_out,
        find_input,
        find_target,
        geometric_prompt: Prompt,
    ):
        """
        Run one grounding pass: encode the prompt, fuse it with the image features in
        the encoder, decode object queries into scores and boxes, and run the
        segmentation heads. Matching against `find_target` is computed for training
        and for interactive evaluation.
        """
        with torch.profiler.record_function("SAM3Image._encode_prompt"):
            prompt, prompt_mask, backbone_out = self._encode_prompt(
                backbone_out, find_input, geometric_prompt
            )

        with torch.profiler.record_function("SAM3Image._run_encoder"):
            backbone_out, encoder_out, _ = self._run_encoder(
                backbone_out, find_input, prompt, prompt_mask
            )
        out = {
            "encoder_hidden_states": encoder_out["encoder_hidden_states"],
            "prev_encoder_out": {
                "encoder_out": encoder_out,
                "backbone_out": backbone_out,
            },
        }

        with torch.profiler.record_function("SAM3Image._run_decoder"):
            out, hs = self._run_decoder(
                memory=out["encoder_hidden_states"],
                pos_embed=encoder_out["pos_embed"],
                src_mask=encoder_out["padding_mask"],
                out=out,
                prompt=prompt,
                prompt_mask=prompt_mask,
                encoder_out=encoder_out,
            )

        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
            self._run_segmentation_heads(
                out=out,
                backbone_out=backbone_out,
                img_ids=find_input.img_ids,
                vis_feat_sizes=encoder_out["vis_feat_sizes"],
                encoder_hidden_states=out["encoder_hidden_states"],
                prompt=prompt,
                prompt_mask=prompt_mask,
                hs=hs,
            )

        if self.training or self.num_interactive_steps_val > 0:
            self._compute_matching(out, self.back_convert(find_target))
        return out

    def _postprocess_out(self, out: Dict, multimask_output: bool = False):
        """At eval time, optionally reduce multi-query outputs to the single best one."""
        num_mask_boxes = out["pred_boxes"].size(1)
        if not self.training and multimask_output and num_mask_boxes > 1:
            # Keep the full multi-query predictions before selecting the best query.
            out["multi_pred_logits"] = out["pred_logits"]
            if "pred_masks" in out:
                out["multi_pred_masks"] = out["pred_masks"]
            out["multi_pred_boxes"] = out["pred_boxes"]
            out["multi_pred_boxes_xyxy"] = out["pred_boxes_xyxy"]

            best_mask_idx = out["pred_logits"].argmax(1).squeeze(1)
            batch_idx = torch.arange(len(best_mask_idx), device=best_mask_idx.device)

            out["pred_logits"] = out["pred_logits"][batch_idx, best_mask_idx].unsqueeze(
                1
            )
            if "pred_masks" in out:
                out["pred_masks"] = out["pred_masks"][
                    batch_idx, best_mask_idx
                ].unsqueeze(1)
            out["pred_boxes"] = out["pred_boxes"][batch_idx, best_mask_idx].unsqueeze(1)
            out["pred_boxes_xyxy"] = out["pred_boxes_xyxy"][
                batch_idx, best_mask_idx
            ].unsqueeze(1)

        return out

    def _get_dummy_prompt(self, num_prompts=1):
        device = self.device
        geometric_prompt = Prompt(
            box_embeddings=torch.zeros(0, num_prompts, 4, device=device),
            box_mask=torch.zeros(num_prompts, 0, device=device, dtype=torch.bool),
        )
        return geometric_prompt

    def forward(self, input: BatchedDatapoint):
        device = self.device
        backbone_out = {"img_batch_all_stages": input.img_batch}
        backbone_out.update(self.backbone.forward_image(input.img_batch))
        num_frames = len(input.find_inputs)
        assert num_frames == 1

        text_outputs = self.backbone.forward_text(input.find_text_batch, device=device)
        backbone_out.update(text_outputs)

        previous_stages_out = SAM3Output(
            iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE
        )

        find_input = input.find_inputs[0]
        find_target = input.find_targets[0]

        if find_input.input_points is not None and find_input.input_points.numel() > 0:
            print("Warning: Point prompts are ignored in PCS.")

        num_interactive_steps = 0 if self.training else self.num_interactive_steps_val
        geometric_prompt = Prompt(
            box_embeddings=find_input.input_boxes,
            box_mask=find_input.input_boxes_mask,
            box_labels=find_input.input_boxes_label,
        )

        stage_outs = []
        for cur_step in range(num_interactive_steps + 1):
            if cur_step > 0:
                # Refine the geometric prompt based on the previous step's predictions.
                geometric_prompt, _ = self.interactive_prompt_sampler.sample(
                    geo_prompt=geometric_prompt,
                    find_target=find_target,
                    previous_out=stage_outs[-1],
                )
            out = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_input,
                find_target=find_target,
                geometric_prompt=geometric_prompt.clone(),
            )
            stage_outs.append(out)

        previous_stages_out.append(stage_outs)
        return previous_stages_out

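    # During training, forward() runs a single (non-interactive) grounding pass per
    # image; at eval time, `num_interactive_steps_val` additional passes can be run,
    # each refining the geometric prompt from the previous step's outputs.
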
    def _compute_matching(self, out, targets):
        out["indices"] = self.matcher(out, targets)
        for aux_out in out.get("aux_outputs", []):
            aux_out["indices"] = self.matcher(aux_out, targets)

    def back_convert(self, targets):
        batched_targets = {
            "boxes": targets.boxes.view(-1, 4),
            "boxes_xyxy": box_cxcywh_to_xyxy(targets.boxes.view(-1, 4)),
            "boxes_padded": targets.boxes_padded,
            "positive_map": targets.boxes.new_ones(len(targets.boxes), 1),
            "num_boxes": targets.num_boxes,
            "masks": targets.segments,
            "semantic_masks": targets.semantic_segments,
            "is_valid_mask": targets.is_valid_segment,
            "is_exhaustive": targets.is_exhaustive,
            "object_ids_packed": targets.object_ids,
            "object_ids_padded": targets.object_ids_padded,
        }
        return batched_targets

    def predict_inst(
        self,
        inference_state,
        **kwargs,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        orig_h, orig_w = (
            inference_state["original_height"],
            inference_state["original_width"],
        )
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )

        # Inject the precomputed backbone features into the interactive predictor.
        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._orig_hw = [(orig_h, orig_w)]
        res = self.inst_interactive_predictor.predict(**kwargs)
        # Reset the predictor state so stale features are not reused on the next call.
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        return res

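    # Hypothetical usage sketch (the keyword arguments are forwarded unchanged to
    # SAM3InteractiveImagePredictor.predict; the exact names below, e.g. point
    # prompts, are assumptions and depend on that predictor's interface):
    #   masks, scores, low_res_masks = model.predict_inst(
    #       inference_state, point_coords=coords, point_labels=labels
    #   )
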
    def predict_inst_batch(
        self,
        inference_state,
        *args,
        **kwargs,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = self.inst_interactive_predictor.model._prepare_backbone_features(
            backbone_out
        )

        vision_feats[-1] = (
            vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
        )
        batch_size = vision_feats[-1].shape[1]
        orig_heights, orig_widths = (
            inference_state["original_heights"],
            inference_state["original_widths"],
        )
        assert (
            batch_size == len(orig_heights) == len(orig_widths)
        ), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        self.inst_interactive_predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        self.inst_interactive_predictor._is_image_set = True
        self.inst_interactive_predictor._is_batch = True
        self.inst_interactive_predictor._orig_hw = [
            (orig_h, orig_w) for orig_h, orig_w in zip(orig_heights, orig_widths)
        ]
        res = self.inst_interactive_predictor.predict_batch(*args, **kwargs)
        self.inst_interactive_predictor._features = None
        self.inst_interactive_predictor._is_image_set = False
        self.inst_interactive_predictor._is_batch = False
        return res


class Sam3ImageOnVideoMultiGPU(Sam3Image):
    def __init__(
        self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs
    ):
        super().__init__(*args, **kwargs)
        # RANK / WORLD_SIZE are set by the torch.distributed launcher; single-process
        # runs fall back to rank 0 with a world size of 1.
        self.rank = int(os.getenv("RANK", "0"))
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.async_all_gather = async_all_gather

        # By default, also all-gather the SAM 2 backbone features when the backbone is
        # a SAM3VLBackbone, so the tracker can reuse them on every rank.
        if gather_backbone_out is None:
            gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone)
        self.gather_backbone_out = gather_backbone_out

    def forward_video_grounding_multigpu(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx,
        num_frames,
        # dict mapping frame index -> {name: (tensor, all_gather handle)}, shared across calls
        multigpu_buffer,
        track_in_reverse=False,
        # whether to keep `sam2_backbone_*` buffer entries in the returned outputs
        return_sam2_backbone_feats=False,
        # optional mask NMS applied to each frame's detections before gathering
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
        **kwargs,
    ):
        """
        Compute the detector's outputs in a distributed manner: all GPUs process a
        chunk of frames (one frame per GPU) at once and cache the gathered results
        in `multigpu_buffer`.
        """
        # Boundaries of the chunk containing the current frame.
        frame_idx_curr_b = frame_idx - frame_idx % self.world_size
        frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames)

        # Compute the current chunk if it is not in the buffer yet.
        if frame_idx not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_curr_b,
                    frame_idx_end=frame_idx_curr_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )

        # Read this frame's outputs from the buffer, waiting on any pending
        # asynchronous all-gather handles.
        out = {}
        for k, (v, handle) in multigpu_buffer[frame_idx].items():
            if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats:
                continue
            if handle is not None:
                handle.wait()
            out[k] = v

        # Drop the previous chunk (in tracking order) from the buffer to bound memory use.
        if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_prev_e = frame_idx_curr_b
            frame_idx_prev_b = frame_idx_curr_b - self.world_size
        elif track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_prev_b = frame_idx_curr_e
            frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames)
        else:
            frame_idx_prev_b = frame_idx_prev_e = None
        if frame_idx_prev_b is not None:
            for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e):
                multigpu_buffer.pop(frame_idx_rm, None)

        # Precompute the next chunk (in tracking order) ahead of time so that its
        # (possibly asynchronous) all-gather can complete in the background.
        if not track_in_reverse and frame_idx_curr_e < num_frames:
            frame_idx_next_b = frame_idx_curr_e
            frame_idx_next_e = min(frame_idx_next_b + self.world_size, num_frames)
        elif track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
            frame_idx_next_e = frame_idx_curr_b
            frame_idx_next_b = frame_idx_curr_b - self.world_size
        else:
            frame_idx_next_b = frame_idx_next_e = None
        if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"):
                self._build_multigpu_buffer_next_chunk(
                    backbone_out=backbone_out,
                    find_inputs=find_inputs,
                    geometric_prompt=geometric_prompt,
                    frame_idx_begin=frame_idx_next_b,
                    frame_idx_end=frame_idx_next_e,
                    num_frames=num_frames,
                    multigpu_buffer=multigpu_buffer,
                    run_nms=run_nms,
                    nms_prob_thresh=nms_prob_thresh,
                    nms_iou_thresh=nms_iou_thresh,
                )

        return out, backbone_out

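    # Illustrative chunking example: with world_size=4 and num_frames=10, frames are
    # processed in chunks [0..3], [4..7], [8..9]. For frame_idx=6 the current chunk
    # covers frames 4..7 and rank r computes frame 4 + r; in the final chunk, ranks 2
    # and 3 are clamped to the last frame (9) and their duplicate outputs are simply
    # not stored in the buffer.
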
    def _build_multigpu_buffer_next_chunk(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx_begin,
        frame_idx_end,
        num_frames,
        multigpu_buffer,
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
    ):
        """Compute detection outputs on a chunk of frames and store the results in `multigpu_buffer`."""
        # Each rank processes one frame of the chunk (clamped to the last valid frame).
        frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1)

        with torch.profiler.record_function("forward_grounding"):
            out_local = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_inputs[frame_idx_local_gpu],
                find_target=None,
                geometric_prompt=geometric_prompt,
            )
        if run_nms:
            with torch.profiler.record_function("nms_masks"):
                assert nms_prob_thresh is not None and nms_iou_thresh is not None
                pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid()
                pred_masks = out_local["pred_masks"]

                for prompt_idx in range(pred_probs.size(0)):
                    keep = nms_masks(
                        pred_probs=pred_probs[prompt_idx],
                        pred_masks=pred_masks[prompt_idx],
                        prob_threshold=nms_prob_thresh,
                        iou_threshold=nms_iou_thresh,
                    )
                    # Suppress (rather than drop) removed detections by pushing their
                    # logits very low, keeping tensor shapes fixed across ranks.
                    out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float()

        if self.gather_backbone_out:
            feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"]
            assert len(feats["backbone_fpn"]) == 3

            # Gather the FPN levels in bfloat16 to reduce communication volume.
            backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
            fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0])
            fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1])
            fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2])

            # Positional encodings are not gathered; the local copy is reused.
            vision_pos_enc = feats["vision_pos_enc"]

        # Only keep the prediction tensors that need to be shared across ranks.
        out_local = {
            "pred_logits": out_local["pred_logits"],
            "pred_boxes": out_local["pred_boxes"],
            "pred_boxes_xyxy": out_local["pred_boxes_xyxy"],
            "pred_masks": out_local["pred_masks"],
        }

        # All-gather each output tensor; each entry is (per-rank tensors, async handle).
        out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()}
        for rank in range(self.world_size):
            frame_idx_to_save = frame_idx_begin + rank
            if frame_idx_to_save >= num_frames:
                continue
            frame_buffer = {
                k: (v[rank], handle) for k, (v, handle) in out_gathered.items()
            }
            if self.gather_backbone_out:
                frame_buffer["tracker_backbone_fpn_0"] = (fpn0[rank], fpn_handle0)
                frame_buffer["tracker_backbone_fpn_1"] = (fpn1[rank], fpn_handle1)
                frame_buffer["tracker_backbone_fpn_2"] = (fpn2[rank], fpn_handle2)
                frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None)

            multigpu_buffer[frame_idx_to_save] = frame_buffer

    def _gather_tensor(self, x):
        if self.world_size == 1:
            return [x], None

        async_op = self.async_all_gather

        # all_gather requires contiguous inputs; with async_op=True the returned
        # handle must be waited on before the gathered tensors are read.
        x = x.contiguous()
        output_list = [torch.empty_like(x) for _ in range(self.world_size)]
        handle = torch.distributed.all_gather(output_list, x, async_op=async_op)
        return output_list, handle

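    # Usage note: _gather_tensor returns (per_rank_tensors, handle). With
    # async_all_gather enabled, callers must call handle.wait() before reading the
    # gathered tensors (as done in forward_video_grounding_multigpu above); in a
    # single-process run, the local tensor is returned as-is with handle=None.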