import os.path

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

from mmengine.model import BaseModule

from vlm.utils import load_checkpoint_with_prefix, load_state_dict_to_model

# Directory holding the pretrained SAM2 checkpoint.
BASE_DIR = 'work_dirs/ckpt'


class SAM2(BaseModule):
    """Thin wrapper around SAM2 that drives the mask decoder with language embeddings."""

    def __init__(
        self,
        cfg_path: str = "sam2_hiera_l.yaml",
        ckpt_path: str = "sam2_hiera_large.pt",
        hydra_overrides_extra=None,
        apply_postprocessing=True,
    ):
        super().__init__(init_cfg=None)

        # Importing the package registers the SAM2 hydra configs, so the
        # `compose` call below can resolve `cfg_path`.
        import third_parts.sam2  # noqa: F401

        if hydra_overrides_extra is None:
            hydra_overrides_extra = []
        hydra_overrides = [
            # Use the predictor variant that accepts language embeddings as prompts.
            "++model._target_=projects.llava_sam2.models.predictor.SAM2VideoPredictor",
        ]

        if apply_postprocessing:
            hydra_overrides_extra = hydra_overrides_extra.copy()
            hydra_overrides_extra += [
                # Dynamically fall back to the best single-mask output when
                # the multi-mask outputs are unstable.
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
                "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            ]
        hydra_overrides.extend(hydra_overrides_extra)

        # Compose the config, build the model, and load the pretrained weights.
        cfg = compose(config_name=cfg_path, overrides=hydra_overrides)
        OmegaConf.resolve(cfg)
        sam2_model = instantiate(cfg.model, _recursive_=True)
        state_dict = load_checkpoint_with_prefix(os.path.join(BASE_DIR, ckpt_path))
        load_state_dict_to_model(sam2_model, state_dict)

        self.sam2_model = sam2_model
        self.hidden_dim = self.sam2_model.hidden_dim

        # ImageNet normalization statistics expected by SAM2 inputs.
        self.img_mean = (0.485, 0.456, 0.406)
        self.img_std = (0.229, 0.224, 0.225)

    def inject_language_embd(self, inference_state, language_embd):
        """Feed per-frame, per-object language embeddings into SAM2 and return
        the mask logits, concatenated over objects (dim 1) and frames (dim 0).

        `language_embd` is a nested list indexed as [frame][object], each entry
        a tensor of shape (hidden_dim,).
        """
        num_frame = len(language_embd)
        num_obj = len(language_embd[0])
        mask_out = []
        for frame_idx in range(num_frame):
            frame_mask_out = []
            for obj_idx in range(num_obj):
                # Add batch and sequence dims: (hidden_dim,) -> (1, 1, hidden_dim).
                _language_embd = language_embd[frame_idx][obj_idx][None][None]
                # Object ids are offset by 100, presumably to keep them distinct
                # from any ids already registered in the inference state.
                _, _, out_mask_logits = self.sam2_model.add_language_embd(
                    inference_state, frame_idx, obj_idx + 100, _language_embd)
                frame_mask_out.append(out_mask_logits)
            frame_mask_out = torch.cat(frame_mask_out, dim=1)
            mask_out.append(frame_mask_out)
        mask_out = torch.cat(mask_out, dim=0)
        return mask_out

    def language_embd_inference(self, inference_state, language_embd):
        """Two-pass video inference: first register the language embeddings for
        every frame and object, then propagate through the video and return the
        propagated mask logits."""
        num_frame = len(language_embd)
        num_obj = len(language_embd[0])
        mask_out = []
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # Pass 1: prime SAM2's memory with the language embeddings. The
            # per-frame logits collected here are discarded below in favor of
            # the temporally propagated ones.
            for frame_idx in range(num_frame):
                frame_mask_out = []
                for obj_idx in range(num_obj):
                    _language_embd = language_embd[frame_idx][obj_idx][None][None]
                    _, _, out_mask_logits = self.sam2_model.add_language_embd(
                        inference_state,
                        frame_idx,
                        obj_idx + 100,
                        _language_embd,
                        inference=True,
                    )
                    frame_mask_out.append(out_mask_logits)
                frame_mask_out = torch.cat(frame_mask_out, dim=1)
                mask_out.append(frame_mask_out)

            # Pass 2: propagate through the whole video and keep these logits.
            mask_out = []
            for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_model.propagate_in_video(inference_state):
                mask_out.append(out_mask_logits)
            mask_out = torch.cat(mask_out, dim=0)
        return mask_out

    def get_sam2_embeddings(self, images):
        """Build a SAM2 inference state (image features and memory) for the frames."""
        return self.sam2_model.init_state(images)

    def forward(self, batch):
        raise NotImplementedError

    def preprocess_image(self, image: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        """Scale a (3, H, W) image from [0, 255] to [0, 1] and normalize it
        with the ImageNet mean/std."""
        image = image / 255.

        img_mean = torch.tensor(self.img_mean, dtype=dtype, device=image.device)[:, None, None]
        img_std = torch.tensor(self.img_std, dtype=dtype, device=image.device)[:, None, None]
        image -= img_mean
        image /= img_std

        return image
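

# A minimal usage sketch, not part of the original module: it assumes hydra is
# initialized as a side effect of `import third_parts.sam2`, that the SAM2
# checkpoint exists under BASE_DIR, and that `frames` (a list of (3, H, W)
# image tensors) and `embeds` (a [frame][object] nested list of (hidden_dim,)
# language embeddings) are produced upstream; both names are hypothetical.
#
#     sam2 = SAM2()
#     images = torch.stack([sam2.preprocess_image(f.float()) for f in frames])
#     state = sam2.get_sam2_embeddings(images)
#     masks = sam2.language_embd_inference(state, embeds)  # propagated mask logits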