from typing import List
from transformers import CLIPTokenizer, CLIPModel
import numpy as np
import torch


class PreTrainedPipeline:
    def __init__(self, path: str = "."):
        # Load the CLIP model and tokenizer from the repository path and move
        # the model to the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPModel.from_pretrained(path).to(self.device).eval()
        self.tokenizer = CLIPTokenizer.from_pretrained(path)

    def __call__(self, data: str) -> List[float]:
        """
        Args:
            data (:obj:`str`):
                The query text to embed.
        Return:
            A :obj:`list` of :obj:`float`: the L2-normalized CLIP text embedding
            of the query, which will be serialized and returned.
        """
        # Tokenize the query and compute its CLIP text embedding.
        inputs = self.tokenizer(data, padding=True, return_tensors="pt").to(
            self.device
        )
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
        text_features = text_features.detach().cpu().numpy()
        input_embedding = text_features[0]
        # L2-normalize the embedding so cosine similarity reduces to a dot product.
        input_embedding /= np.linalg.norm(input_embedding)
        return input_embedding.tolist()
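

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the Inference API contract).
    # It assumes a CLIP checkpoint is available at the given path, e.g. this
    # repository cloned locally; the example query string is arbitrary.
    pipeline = PreTrainedPipeline(path=".")
    embedding = pipeline("a photo of a cat")
    print(len(embedding))  # embedding dimensionality, e.g. 512 for CLIP ViT-B/32
    print(sum(x * x for x in embedding))  # ~1.0, since the embedding is L2-normalized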