import torch
from transformers import SamModel, SamProcessor


class SAM:
    """Thin wrapper around the Hugging Face Segment Anything (SAM) model.

    Loads a pretrained ``SamModel`` and its matching ``SamProcessor`` once,
    then exposes :meth:`segment` for point-prompted mask prediction.
    """

    def __init__(self, *, model_name="facebook/sam-vit-large", device=None):
        """Load the SAM model and processor.

        Args:
            model_name: Hub id (or local path) passed to both
                ``SamModel.from_pretrained`` and
                ``SamProcessor.from_pretrained``, so the model and processor
                always come from the same checkpoint. Defaults to
                ``"facebook/sam-vit-large"``.
            device: Device the model runs on. When ``None`` (the default),
                ``"cuda"`` is used if ``torch.cuda.is_available()``,
                otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = SamModel.from_pretrained(model_name).to(self.device)
        self.processor = SamProcessor.from_pretrained(model_name)

    def segment(self, raw_image, input_points):
        """Predict masks for ``raw_image`` prompted by ``input_points``.

        Args:
            raw_image: Image accepted by ``SamProcessor`` (presumably a PIL
                image or array -- see the SamProcessor docs).
            input_points: Point prompts forwarded unchanged to the processor's
                ``input_points`` argument (assumed nested-list form such as
                ``[[[x, y]]]`` -- TODO confirm against caller).

        Returns:
            ``(masks, scores)``. ``masks`` is the output of the processor's
            ``post_process_masks`` applied to the predicted masks, using the
            original and reshaped input sizes. ``scores`` is the model's
            ``iou_scores`` tensor. Both are returned on the CPU.
        """
        inputs = self.processor(
            raw_image, input_points=input_points, return_tensors="pt"
        ).to(self.device)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**inputs)

        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
        # Return scores on CPU like the masks. Otherwise, on CUDA, indexing
        # the CPU masks with e.g. ``scores.argmax()`` fails with a
        # device-mismatch error.
        scores = outputs.iou_scores.cpu()
        return masks, scores