JuanLozada97's picture
Change to vit b
c90b2b6
raw
history blame
260 Bytes
from segment_anything import sam_model_registry
def create_sam_model(model_type, checkpoint, device: str = "cpu"):
    """Build a Segment Anything (SAM) model and move it to ``device``.

    Args:
        model_type: Key into ``segment_anything.sam_model_registry`` selecting
            the model builder (presumably one of ``"vit_b"``, ``"vit_l"``,
            ``"vit_h"`` or ``"default"`` -- confirm against the installed
            package version).
        checkpoint: Passed through unchanged as ``checkpoint=`` to the
            selected registry builder.
        device: Passed to the model's ``.to()`` call; defaults to ``"cpu"``.

    Returns:
        The object returned by ``.to(device)`` on the built model.

    Raises:
        KeyError: If ``model_type`` is not a key of ``sam_model_registry``;
            the message lists the registered keys.
    """
    try:
        build_sam = sam_model_registry[model_type]
    except KeyError:
        # A bare KeyError only echoes the bad key; list the valid choices so
        # a typo like "vit-b" vs "vit_b" is obvious. Same exception type, so
        # existing callers that catch KeyError are unaffected.
        available = ", ".join(sorted(map(str, sam_model_registry)))
        raise KeyError(
            f"Unknown SAM model_type {model_type!r}; expected one of: {available}"
        ) from None

    medsam_model = build_sam(checkpoint=checkpoint)
    # Return what .to() returns rather than assuming it moved the model in place.
    return medsam_model.to(device)