from typing import Optional

import torch
import torch.nn as nn

from .sam2_predictor import SAM2VideoPredictor
from .sam2_implementation.modeling.backbones.hieradet import Hiera
from .sam2_implementation.modeling.backbones.image_encoder import FpnNeck, ImageEncoder
from .sam2_implementation.modeling.position_encoding import PositionEmbeddingSine
from .sam2_implementation.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
from .sam2_implementation.modeling.memory_encoder import (
    CXBlock,
    Fuser,
    MaskDownSampler,
    MemoryEncoder,
)
from .sam2_implementation.modeling.sam.transformer import RoPEAttention


def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a pretrained checkpoint, optionally keeping only one sub-module.

    Args:
        filename (str): Local filepath of the checkpoint, passed directly
            to :func:`torch.load`.
        prefix (str, optional): If given, only keys starting with this
            sub-module prefix are kept, with the prefix stripped.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to 'cpu'.
        logger: Logger handle (currently unused).

    Returns:
        dict or OrderedDict: The loaded state dict.
    """
    checkpoint = torch.load(filename, map_location=map_location)

    # Checkpoints may store weights at the top level or nested under a
    # 'state_dict' or 'model' key.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    if not prefix:
        return state_dict
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)

    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }

    assert state_dict, f'{prefix} is not in the pretrained model'
    return state_dict
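
# A minimal usage sketch for the helper above (the checkpoint path and the
# 'image_encoder' prefix are hypothetical, for illustration only):
#
#     state_dict = load_checkpoint_with_prefix(
#         'checkpoints/sam2.pt', prefix='image_encoder')
#     # A key like 'image_encoder.trunk.blocks.0.norm1.weight' comes back
#     # as 'trunk.blocks.0.norm1.weight'.
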
def load_state_dict_to_model(model, state_dict, logger='current'):
    # strict=False returns the mismatched keys instead of raising inside
    # load_state_dict, so we can report them explicitly below.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        print(missing_keys)
        raise RuntimeError(f'Missing keys when loading state dict: {missing_keys}')
    if unexpected_keys:
        print(unexpected_keys)
        raise RuntimeError(f'Unexpected keys when loading state dict: {unexpected_keys}')
    print("Loaded checkpoint successfully")


class SAM2(nn.Module):
    """Wrapper that assembles the SAM2 video predictor.

    Builds the image encoder, memory attention, and memory encoder with
    hard-coded hyperparameters, and optionally loads pretrained weights.
    """

    def __init__(
        self,
        ckpt_path: Optional[str] = None,
    ):
        super().__init__()

        image_encoder = self.build_image_encoder()
        memory_attention = self.build_memory_attention()
        memory_encoder = self.build_memory_encoder()
        sam2_model = SAM2VideoPredictor(
            image_encoder=image_encoder,
            memory_attention=memory_attention,
            memory_encoder=memory_encoder,
            num_maskmem=7,
            image_size=1024,

            # Mask scaling for the memory encoder.
            sigmoid_scale_for_mem_enc=20.0,
            sigmoid_bias_for_mem_enc=-10.0,
            use_mask_input_as_output_without_sam=True,

            directly_add_no_mem_embed=True,

            use_high_res_features_in_sam=True,

            multimask_output_in_sam=True,

            iou_prediction_use_sigmoid=True,

            # Object-pointer handling in the encoder.
            use_obj_ptrs_in_encoder=True,
            add_tpos_enc_to_obj_ptrs=False,
            only_obj_ptrs_in_the_past_for_eval=True,

            # Object-score prediction.
            pred_obj_scores=True,
            pred_obj_scores_mlp=True,
            fixed_no_obj_ptr=True,

            # Multi-mask tracking.
            multimask_output_for_tracking=True,
            use_multimask_token_for_obj_ptr=True,
            multimask_min_pt_num=0,
            multimask_max_pt_num=1,
            use_mlp_for_obj_ptr_proj=True,

            compile_image_encoder=False,
            sam_mask_decoder_extra_args={
                'dynamic_multimask_via_stability': True,
                'dynamic_multimask_stability_delta': 0.05,
                'dynamic_multimask_stability_thresh': 0.98,
            },
        )
        if ckpt_path is not None:
            state_dict = load_checkpoint_with_prefix(ckpt_path)
            load_state_dict_to_model(sam2_model, state_dict)

        self.sam2_model = sam2_model

        self.hidden_dim = self.sam2_model.hidden_dim

        # ImageNet normalization statistics (RGB).
        self.img_mean = (0.485, 0.456, 0.406)
        self.img_std = (0.229, 0.224, 0.225)

    def build_image_encoder(self):
        def build_trunk():
            embed_dim = 144
            num_heads = 2
            stages = [2, 6, 36, 4]
            global_att_blocks = [23, 33, 43]
            window_pos_embed_bkg_spatial_size = [7, 7]
            window_spec = [8, 4, 16, 8]
            ret = Hiera(
                embed_dim=embed_dim,
                num_heads=num_heads,
                stages=stages,
                global_att_blocks=global_att_blocks,
                window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
                window_spec=window_spec,
            )
            return ret

        def build_neck():
            def build_position_encoding():
                num_pos_feats = 256
                normalize = True
                scale = None
                temperature = 10000
                ret = PositionEmbeddingSine(
                    num_pos_feats=num_pos_feats,
                    normalize=normalize,
                    scale=scale,
                    temperature=temperature,
                )
                return ret

            d_model = 256
            backbone_channel_list = [1152, 576, 288, 144]
            fpn_top_down_levels = [2, 3]
            fpn_interp_model = 'nearest'
            position_encoding = build_position_encoding()
            ret = FpnNeck(
                d_model=d_model,
                position_encoding=position_encoding,
                backbone_channel_list=backbone_channel_list,
                fpn_top_down_levels=fpn_top_down_levels,
                fpn_interp_model=fpn_interp_model,
            )
            return ret

        scalp = 1
        trunk = build_trunk()
        neck = build_neck()
        ret = ImageEncoder(scalp=scalp, trunk=trunk, neck=neck)
        return ret
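
    # Note on the numbers above: with embed_dim=144 and four stages, a Hiera
    # trunk doubles its channel width per stage (144 -> 288 -> 576 -> 1152),
    # which is why backbone_channel_list is that sequence reversed; the FPN
    # neck then projects every pyramid level to d_model=256. This is an
    # explanatory note, not additional configuration.
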
    def build_memory_attention(self):
        def build_layer():
            def build_self_attention():
                rope_theta = 10000.0
                feat_sizes = [32, 32]
                embedding_dim = 256
                num_heads = 1
                downsample_rate = 1
                dropout = 0.1
                ret = RoPEAttention(
                    rope_theta=rope_theta,
                    feat_sizes=feat_sizes,
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    downsample_rate=downsample_rate,
                    dropout=dropout,
                )
                return ret

            def build_cross_attention():
                rope_theta = 10000.0
                feat_sizes = [32, 32]
                rope_k_repeat = True
                embedding_dim = 256
                num_heads = 1
                downsample_rate = 1
                dropout = 0.1
                kv_in_dim = 64
                ret = RoPEAttention(
                    rope_theta=rope_theta,
                    feat_sizes=feat_sizes,
                    rope_k_repeat=rope_k_repeat,
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    downsample_rate=downsample_rate,
                    dropout=dropout,
                    kv_in_dim=kv_in_dim,
                )
                return ret

            activation = 'relu'
            dim_feedforward = 2048
            dropout = 0.1
            pos_enc_at_attn = False
            d_model = 256
            pos_enc_at_cross_attn_keys = True
            pos_enc_at_cross_attn_queries = False
            self_attention = build_self_attention()
            cross_attention = build_cross_attention()
            ret = MemoryAttentionLayer(
                activation=activation,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                pos_enc_at_attn=pos_enc_at_attn,
                d_model=d_model,
                pos_enc_at_cross_attn_queries=pos_enc_at_cross_attn_queries,
                pos_enc_at_cross_attn_keys=pos_enc_at_cross_attn_keys,
                self_attention=self_attention,
                cross_attention=cross_attention,
            )
            return ret

        d_model = 256
        pos_enc_at_input = True
        num_layers = 4
        layer = build_layer()
        ret = MemoryAttention(
            d_model=d_model,
            pos_enc_at_input=pos_enc_at_input,
            num_layers=num_layers,
            layer=layer,
        )
        return ret

    def build_memory_encoder(self):
        def build_position_encoding():
            num_pos_feats = 64
            normalize = True
            scale = None
            temperature = 10000
            ret = PositionEmbeddingSine(
                num_pos_feats=num_pos_feats,
                normalize=normalize,
                scale=scale,
                temperature=temperature,
            )
            return ret

        def build_mask_downsampler():
            kernel_size = 3
            stride = 2
            padding = 1
            ret = MaskDownSampler(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
            return ret

        def build_fuser():
            def build_layer():
                dim = 256
                kernel_size = 7
                padding = 3
                layer_scale_init_value = 1e-6
                use_dwconv = True
                ret = CXBlock(
                    dim=dim,
                    kernel_size=kernel_size,
                    padding=padding,
                    layer_scale_init_value=layer_scale_init_value,
                    use_dwconv=use_dwconv,
                )
                return ret

            num_layers = 2
            layer = build_layer()
            ret = Fuser(
                layer=layer,
                num_layers=num_layers,
            )
            return ret

        out_dim = 64
        position_encoding = build_position_encoding()
        mask_downsampler = build_mask_downsampler()
        fuser = build_fuser()
        ret = MemoryEncoder(
            out_dim=out_dim,
            position_encoding=position_encoding,
            mask_downsampler=mask_downsampler,
            fuser=fuser,
        )
        return ret
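
    # Note: out_dim=64 here matches kv_in_dim=64 of the cross-attention in
    # build_memory_attention above; the memory features produced by this
    # encoder are what that layer attends over as keys/values.
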
    def inject_language_embd(self, inference_state, language_embd):
        # language_embd is a nested list indexed as [frame][object], each
        # entry a language embedding tensor for that frame/object pair.
        num_frame = len(language_embd)
        num_obj = len(language_embd[0])
        mask_out = []
        for frame_idx in range(num_frame):
            frame_mask_out = []
            for obj_idx in range(num_obj):
                # Prepend two singleton dims, e.g. (C,) -> (1, 1, C).
                _language_embd = language_embd[frame_idx][obj_idx][None][None]
                _, _, out_mask_logits = self.sam2_model.add_language_embd(
                    inference_state, frame_idx, obj_idx + 100, _language_embd)
                frame_mask_out.append(out_mask_logits)
            frame_mask_out = torch.cat(frame_mask_out, dim=1)
            mask_out.append(frame_mask_out)
        mask_out = torch.cat(mask_out, dim=0)
        return mask_out
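
    # A minimal usage sketch (shapes are illustrative; assumes init_state
    # accepts a (T, 3, H, W) float tensor of preprocessed frames, here for
    # 2 frames and 1 object with hidden_dim-sized embeddings):
    #
    #     sam2 = SAM2(ckpt_path=None)
    #     frames = torch.zeros(2, 3, 1024, 1024)  # pixel values in [0, 255]
    #     state = sam2.get_sam2_embeddings(sam2.preprocess_image(frames))
    #     language_embd = [[torch.zeros(sam2.hidden_dim)] for _ in range(2)]
    #     masks = sam2.inject_language_embd(state, language_embd)
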
    def language_embd_inference(self, inference_state, language_embd):
        num_frame = len(language_embd)
        num_obj = len(language_embd[0])
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # First pass: register every language embedding as a prompt. The
            # per-call mask logits are discarded; the final masks come from
            # the propagation pass below.
            for frame_idx in range(num_frame):
                for obj_idx in range(num_obj):
                    _language_embd = language_embd[frame_idx][obj_idx][None][None]
                    self.sam2_model.add_language_embd(
                        inference_state,
                        frame_idx,
                        obj_idx + 100,
                        _language_embd,
                        inference=True,
                    )

            # Second pass: propagate the prompts through the video and
            # collect per-frame mask logits.
            mask_out = []
            for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_model.propagate_in_video(inference_state):
                mask_out.append(out_mask_logits)
            mask_out = torch.cat(mask_out, dim=0)
        return mask_out

    def get_sam2_embeddings(self, images):
        # Initialize a fresh SAM2 inference state from the given frames.
        return self.sam2_model.init_state(images)

    def forward(self, batch):
        raise NotImplementedError

    def preprocess_image(self, image: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        # Scale to [0, 1] in the requested dtype, then normalize with the
        # ImageNet statistics stored in __init__.
        image = image.to(dtype) / 255.

        img_mean = torch.tensor(self.img_mean, dtype=dtype, device=image.device)[:, None, None]
        img_std = torch.tensor(self.img_std, dtype=dtype, device=image.device)[:, None, None]
        image -= img_mean
        image /= img_std

        return image
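
    # A minimal sketch of the preprocessing contract (the 0-255 input range
    # is an assumption based on the division by 255 above):
    #
    #     frame = torch.randint(0, 256, (3, 1024, 1024), dtype=torch.uint8)
    #     normalized = SAM2().preprocess_image(frame)  # (3, H, W), float32
    #     # Channel means are now roughly zero for natural images.
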