# Source: Hugging Face Hub file page for pipeline.py
# Last commit by radames (HF staff): "Update pipeline.py" (4d0bd64)
from typing import List
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
# Hub checkpoint used for both the processor and the model (the pipeline's
# ``path`` argument does not override it).
MODEL_ID: str = "facebook/sam-vit-base"
class PreTrainedPipeline():
    """Inference pipeline returning SAM image embeddings for a PIL image.

    The processor and model are always loaded from the module-level
    ``MODEL_ID``; the ``path`` constructor argument is accepted for interface
    compatibility but is not used.
    """

    def __init__(self, path=""):
        """Load the SAM processor and model onto the available device.

        Args:
            path: Accepted for interface compatibility; ignored.
        """
        # Prefer a GPU when torch can see one, otherwise fall back to CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SamProcessor.from_pretrained(MODEL_ID)
        # A single .to() places the weights on self.device; no second move
        # is needed after eval().
        self.model = SamModel.from_pretrained(MODEL_ID).to(self.device)
        # Inference mode for layers whose behavior differs between train/eval.
        self.model.eval()

    def __call__(self, inputs: "Image.Image") -> List[float]:
        """Compute the SAM image embedding for ``inputs``.

        Args:
            inputs: A PIL image; it is converted to RGB before preprocessing.

        Returns:
            The tensor from ``self.model.get_image_embeddings`` converted to
            Python lists via ``Tensor.tolist()``.
            NOTE(review): annotated ``List[float]``, but if the embedding
            tensor has more than one dimension ``tolist()`` yields nested
            lists — confirm the expected shape with callers.
        """
        raw_image = inputs.convert("RGB")
        # Separate name so the ``inputs`` parameter is not shadowed.
        encoded = self.processor(
            raw_image, return_tensors="pt").to(self.device)
        # Pure feature extraction: disable autograd so no computation graph
        # or saved activations are kept for the encoder forward pass.
        with torch.no_grad():
            feature_vector = self.model.get_image_embeddings(
                encoded["pixel_values"])
        return feature_vector.tolist()