import torch
from torchvision.transforms.functional import to_pil_image
from segment_anything import SamPredictor, sam_model_registry
from PIL import Image


class SegmentationModel:
    """Abstract base for backends that turn an image tensor into a PIL mask."""

    def __init__(self) -> None:
        pass

    def generate(self, image: torch.Tensor) -> Image.Image:
        """Produce a segmentation mask for *image*. Subclasses override this."""
        pass


class SamSegmentationModel(SegmentationModel):
    """Segment Anything (SAM) backed segmentation.

    Loads a SAM checkpoint and segments the object located at the center
    of the input image.
    """

    def __init__(
        self,
        model_type: str,
        checkpoint_path: str,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Load SAM weights of *model_type* from *checkpoint_path* onto *device*.

        Args:
            model_type: key into ``sam_model_registry`` (e.g. ``"vit_b"``).
            checkpoint_path: path to the SAM checkpoint file.
            device: torch device the model is moved to (default CPU).
        """
        super().__init__()
        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        sam.to(device)
        self.device = device
        self.model = SamPredictor(sam)

    def generate(self, image: torch.Tensor) -> Image.Image:
        """Segment the object at the image center; return an inverted 3-channel mask.

        Args:
            image: (C, H, W) tensor.
                NOTE(review): ``SamPredictor.set_torch_image`` expects an image
                already resized/normalized with ``ResizeLongestSide`` — confirm
                the caller applies that transform before passing the tensor in.

        Returns:
            PIL image in which the selected object is black (0.0) and the
            background is white (1.0), replicated across 3 channels.
        """
        _, H, W = image.size()
        image = image.unsqueeze(0)  # add batch dim expected by set_torch_image
        self.model.set_torch_image(image, original_image_size=(H, W))

        # BUG FIX: SAM point prompts are (x, y) pixel coordinates, so the
        # image center is (W/2, H/2). The original code passed (H/2, W/2),
        # which only coincides with the center for square images.
        center_point = [W / 2, H / 2]
        input_point = torch.tensor([[center_point]]).to(self.device)
        input_label = torch.tensor([[1]]).to(self.device)  # 1 = foreground point

        masks, scores, logits = self.model.predict_torch(
            point_coords=input_point,
            point_labels=input_label,
            boxes=None,
            multimask_output=True,
        )

        # Drop the batch dimension, then keep the highest-scoring candidate mask.
        masks = masks.squeeze(0)
        scores = scores.squeeze(0)
        bmask = masks[torch.argmax(scores).item()]

        # Invert the boolean mask (object -> 0.0, background -> 1.0) and
        # replicate it into 3 channels so to_pil_image yields an RGB image.
        mask_float = 1.0 - bmask.float()
        final = torch.stack([mask_float, mask_float, mask_float])
        return to_pil_image(final)