"""
SAM2 Base Segmenter
Adapted from MatAnyone demo
"""
import sys

# NOTE: hardcoded path to a local SAM2 checkout; adjust for your environment
sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
import torch
import numpy as np
from sam2.build_sam import build_sam2_video_predictor


class BaseSegmenter:
def __init__(self, SAM_checkpoint, model_type, device):
"""
Initialize SAM2 segmenter
Args:
SAM_checkpoint: Path to SAM2 checkpoint
model_type: SAM2 model config file
device: Device to run on
"""
self.device = device
self.model_type = model_type
# Build SAM2 video predictor
self.sam_predictor = build_sam2_video_predictor(
config_file=model_type,
ckpt_path=SAM_checkpoint,
device=device
)
        self.original_image = None
        self.inference_state = None

    def set_image(self, image: np.ndarray):
"""Set the current image for segmentation"""
        self.original_image = image

    def reset_image(self):
"""Reset the current image"""
        self.original_image = None
        self.inference_state = None

    def predict(self, prompts, prompt_type, multimask=True):
"""
Predict mask from prompts
Args:
prompts: Dictionary with point_coords, point_labels, mask_input
prompt_type: 'point' or 'both'
multimask: Whether to return multiple masks
Returns:
masks, scores, logits
"""
        # SAM2 exposes prompting through its video predictor rather than a
        # per-image API, so this method is currently a placeholder that returns
        # an all-background mask. A sketch of the intended video-predictor flow
        # is given after the class definition.
        h, w = self.original_image.shape[:2]
        dummy_mask = np.zeros((h, w), dtype=bool)
        dummy_score = np.array([1.0])
        dummy_logit = np.zeros((h, w), dtype=np.float32)
        return np.array([dummy_mask]), dummy_score, np.array([dummy_logit])
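

# --------------------------------------------------------------------------- #
# Hedged sketch (not part of the original placeholder above).
#
# The predict() placeholder notes that real inference should go through the
# SAM2 video predictor. The helper below sketches what that could look like,
# assuming the standard SAM2 video API (init_state, add_new_points_or_box,
# propagate_in_video). The frames directory, object id, and prompt layout are
# illustrative assumptions, not values taken from this repository.
#
# Illustrative usage (config/checkpoint names from the public SAM2 release;
# adjust paths to your local files):
#   segmenter = BaseSegmenter(
#       SAM_checkpoint="checkpoints/sam2_hiera_large.pt",
#       model_type="sam2_hiera_l.yaml",
#       device="cuda",
#   )
#   masks = _example_point_prompt(segmenter, "frames/", [[210, 350]], [1])
# --------------------------------------------------------------------------- #
def _example_point_prompt(segmenter, frames_dir, points, labels, frame_idx=0):
    """Sketch: prompt SAM2 with clicks on one frame and propagate over a clip.

    Args:
        segmenter: a BaseSegmenter instance (its sam_predictor is used)
        frames_dir: directory of JPEG frames (the input format SAM2 expects)
        points: (N, 2) click coordinates in (x, y) order
        labels: (N,) click labels, 1 = foreground, 0 = background
        frame_idx: index of the frame that receives the clicks
    Returns:
        dict mapping frame index -> boolean mask for the prompted object
    """
    predictor = segmenter.sam_predictor
    with torch.inference_mode():
        # Load the clip and create the per-video inference state
        state = predictor.init_state(video_path=frames_dir)
        # Register the point prompts for a single object on one frame
        predictor.add_new_points_or_box(
            inference_state=state,
            frame_idx=frame_idx,
            obj_id=1,
            points=np.asarray(points, dtype=np.float32),
            labels=np.asarray(labels, dtype=np.int32),
        )
        # Propagate the prompted object through every frame in the clip
        masks_per_frame = {}
        for f_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
            masks_per_frame[f_idx] = (mask_logits[0, 0] > 0.0).cpu().numpy()
    return masks_per_frame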