File size: 260 Bytes
af603ca
 
c6ccb48
 
 
 
 
1
2
3
4
5
6
7
8

from segment_anything import sam_model_registry

def create_sam_model(model_type: str, checkpoint, device: str = "cpu"):
    """Build a Segment Anything model from the registry and move it to a device.

    Args:
        model_type: Key into ``segment_anything.sam_model_registry`` selecting
            which model builder to use.
        checkpoint: Passed through unchanged to the builder as its
            ``checkpoint=`` keyword argument.
        device: Target passed to the model's ``.to()`` call. Defaults to
            ``"cpu"``.

    Returns:
        The object returned by the registry builder, after ``.to(device)``.

    Raises:
        KeyError: If ``model_type`` is not a key of ``sam_model_registry``.
            The message lists the registered keys.
    """
    try:
        build_model = sam_model_registry[model_type]
    except KeyError:
        # Keep the KeyError type callers may already catch, but make the
        # message actionable instead of just echoing the bad key.
        available = ", ".join(sorted(sam_model_registry))
        raise KeyError(
            f"Unknown SAM model type {model_type!r}; expected one of: {available}"
        ) from None
    model = build_model(checkpoint=checkpoint)
    return model.to(device)