JuanLozada97's picture
Upload 2 files
af603ca
raw
history blame
324 Bytes
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry
import torch.nn.functional as F
def create_sam_model(model_type, checkpoint, device: str = "cpu"):
    """Build a model from ``sam_model_registry`` and move it to ``device``.

    Args:
        model_type: Key used to look up a builder in ``sam_model_registry``.
        checkpoint: Forwarded unchanged to the builder as its ``checkpoint``
            keyword argument (presumably a path to saved weights -- confirm
            against the ``segment_anything`` builder in use).
        device: Target handed to the model's ``.to()`` call. Defaults to
            ``"cpu"``.

    Returns:
        Whatever the built model's ``.to(device)`` call returns.
    """
    # Resolve the builder first so an unknown ``model_type`` fails on the
    # registry lookup, before any checkpoint is touched.
    build_model = sam_model_registry[model_type]
    return build_model(checkpoint=checkpoint).to(device)