3martini's picture
Upload folder using huggingface_hub
786d4da verified
raw
history blame
No virus
836 Bytes
import torch
from transformers import SamModel, SamProcessor
class SAM:
    """Point-prompted segmentation wrapper around Hugging Face's SAM model.

    The model and its processor are loaded once in ``__init__`` and reused by
    every call to ``segment``.
    """

    def __init__(self, model_name="facebook/sam-vit-large", device=None):
        """Load the SAM model and processor.

        Args:
            model_name: Hugging Face Hub id (or local path) of the SAM
                checkpoint. Used for both the model and the processor. Defaults
                to ``"facebook/sam-vit-large"``.
            device: Device to run the model on. When ``None``, uses ``"cuda"``
                if ``torch.cuda.is_available()``, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = SamModel.from_pretrained(model_name).to(self.device)
        self.processor = SamProcessor.from_pretrained(model_name)

    def segment(self, raw_image, input_points):
        """Run SAM on ``raw_image`` prompted by ``input_points``.

        Args:
            raw_image: Image passed straight to ``SamProcessor``. Presumably a
                PIL image or array; confirm against callers.
            input_points: Point prompts passed straight to ``SamProcessor`` as
                ``input_points``. This assumes the processor's nested-list
                layout (image -> points -> [x, y]); TODO confirm with callers.

        Returns:
            A ``(masks, scores)`` tuple:
              - ``masks``: the result of
                ``image_processor.post_process_masks`` computed from the
                predicted masks, the original image sizes and the resized input
                sizes, all on the CPU.
              - ``scores``: the model's ``iou_scores``, moved to the CPU.
        """
        inputs = self.processor(
            raw_image, input_points=input_points, return_tensors="pt"
        ).to(self.device)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # post_process_masks runs on CPU tensors, so bring everything it needs
        # back from the model device first.
        masks = self.processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )

        # Return scores on the same device as the masks. Previously they stayed
        # on the model device (e.g. GPU) while the masks were on the CPU.
        scores = outputs.iou_scores.cpu()
        return masks, scores