# pipeline.py — last change "Update pipeline.py" by radames (commit ea671c2, 851 bytes).
from typing import List
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
# Hugging Face Hub checkpoint that both the SAM processor and the SAM model
# are loaded from.
MODEL_ID: str = "facebook/sam-vit-large"
class PreTrainedPipeline:
    """Image-embedding pipeline backed by the SAM vision encoder.

    The processor and model named by the module-level ``MODEL_ID`` are loaded
    once at construction. Each call maps one image to the output of
    ``SamModel.get_image_embeddings``, converted to Python lists.
    """

    def __init__(self, path=""):
        """Load the SAM processor and model onto the best available device.

        Args:
            path: Accepted for interface compatibility but not used; weights
                are always loaded from ``MODEL_ID``.
        """
        # Prefer the GPU when one is visible; otherwise run on CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SamProcessor.from_pretrained(MODEL_ID)
        # Move to the target device exactly once, then switch to eval mode.
        self.model = SamModel.from_pretrained(MODEL_ID).to(self.device)
        self.model.eval()

    def __call__(self, inputs: "Image.Image") -> List[float]:
        """Return the SAM image embeddings for ``inputs``.

        Args:
            inputs: Image to embed; it is converted to RGB before being
                passed to the SAM processor.

        Returns:
            The tensor from ``get_image_embeddings`` converted with
            ``Tensor.tolist()``.
            NOTE(review): ``tolist()`` on a multi-dimensional tensor yields
            nested lists, so the ``List[float]`` annotation is probably looser
            than the real structure — confirm the embedding shape.
        """
        rgb_image = inputs.convert("RGB")
        encoded = self.processor(rgb_image, return_tensors="pt").to(self.device)
        # Pure inference: disable autograd so no computation graph is built
        # (and retained) for the encoder's forward pass.
        with torch.no_grad():
            embeddings = self.model.get_image_embeddings(
                encoded["pixel_values"])
        return embeddings.tolist()