SkalskiP's picture
SAM2 added
2fbf361
raw
history blame
438 Bytes
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Default inputs for load_sam_model: a checkpoint path (relative, so it is
# presumably resolved against the working directory — confirm) and the name of
# the model config handed to build_sam2 (presumably the config matching the
# large Hiera checkpoint — verify when swapping either one).
SAM_CHECKPOINT: str = "checkpoints/sam2_hiera_large.pt"
SAM_CONFIG: str = "sam2_hiera_l.yaml"
def load_sam_model(
    device: torch.device,
    config: str = SAM_CONFIG,
    checkpoint: str = SAM_CHECKPOINT
) -> SAM2ImagePredictor:
    """Build a SAM 2 model and return it wrapped in an image predictor.

    Args:
        device: Forwarded unchanged to ``build_sam2`` as its ``device``
            keyword argument.
        config: Model config name passed as the first argument to
            ``build_sam2``. Defaults to the module-level ``SAM_CONFIG``
            (bound once, when this function is defined).
        checkpoint: Checkpoint path passed as the second argument to
            ``build_sam2``. Defaults to the module-level ``SAM_CHECKPOINT``
            (bound once, when this function is defined).

    Returns:
        A new ``SAM2ImagePredictor`` whose ``sam_model`` is the model just
        built. Each call builds a fresh model; nothing is cached here.
    """
    sam2_model = build_sam2(config, checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam_model=sam2_model)
    return predictor