File size: 836 Bytes
786d4da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from transformers import SamModel, SamProcessor

class SAM():
    """Thin wrapper around HuggingFace's Segment Anything Model (SAM).

    Loads the model and processor once at construction and exposes a single
    ``segment`` method that maps point prompts on an image to segmentation
    masks.
    """

    def __init__(self, model_name="facebook/sam-vit-large", device=None):
        """Load the SAM checkpoint and its matching processor.

        Args:
            model_name: HuggingFace Hub id of the SAM checkpoint.
                Defaults to the large ViT variant, matching prior behavior.
            device: explicit torch device string (e.g. "cpu", "cuda:1").
                When None, auto-selects CUDA if available (original behavior).
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SamModel.from_pretrained(model_name).to(self.device)
        self.processor = SamProcessor.from_pretrained(model_name)

    def segment(self, raw_image, input_points):
        """Segment ``raw_image`` at the prompted points.

        Args:
            raw_image: image accepted by ``SamProcessor`` (e.g. a PIL image).
            input_points: nested list of 2D point prompts in the format
                ``SamProcessor`` expects — presumably ``[[[x, y], ...]]``
                per image; confirm against the caller.

        Returns:
            Tuple ``(masks, scores)``: post-processed masks (moved to CPU by
            ``post_process_masks``) and the model's IoU quality scores.
            NOTE(review): ``scores`` is left on ``self.device`` while masks
            are on CPU — callers appear to rely on this; kept as-is.
        """
        inputs = self.processor(
            raw_image, input_points=input_points, return_tensors="pt"
        ).to(self.device)

        # Inference only: disable autograd bookkeeping to save memory.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Upscale the low-resolution predicted masks back to the original
        # image geometry recorded by the processor.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
        scores = outputs.iou_scores

        return masks, scores