grantpitt's picture
add custom handler
7755815
raw
history blame
No virus
1.23 kB
from typing import Dict, List, Any
from transformers import CLIPTokenizer, CLIPModel
import numpy as np
import os
import torch
class EndpointHandler:
    """Custom inference handler that returns a normalized CLIP text embedding.

    A ``CLIPModel`` and ``CLIPTokenizer`` are loaded once from ``path``; each
    request's text (``data["inputs"]``) is encoded with the model's text tower
    and returned as a unit-length list of floats.
    """

    def __init__(self, path="."):
        """Load the CLIP model and tokenizer.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        # Use the GPU when one is visible; otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # eval() switches the module out of training mode for inference.
        self.model = CLIPModel.from_pretrained(path).to(self.device).eval()
        self.tokenizer = CLIPTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """Embed the request text with CLIP and return a unit-length vector.

        Args:
            data: Request payload. ``data["inputs"]`` holds the text to embed:
                a ``str``, or a list of ``str`` (in which case only the first
                entry's embedding is returned).

        Returns:
            The L2-normalized text embedding as a list of Python floats.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        try:
            query = data["inputs"]
        except KeyError as err:
            raise KeyError('request payload must contain an "inputs" field') from err

        # truncation=True clips prompts to the tokenizer's model_max_length.
        # Without it, text longer than the text encoder's position-embedding
        # table (typically 77 tokens for CLIP) fails at inference time.
        inputs = self.tokenizer(
            query, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)

        # Only the first row is returned, even when a batch of strings is sent.
        embedding = text_features[0].cpu().numpy()
        return self._l2_normalize(embedding).tolist()

    @staticmethod
    def _l2_normalize(vector: np.ndarray) -> np.ndarray:
        """Return ``vector`` scaled to unit L2 norm; a zero vector is returned as-is."""
        norm = np.linalg.norm(vector)
        # Avoid 0/0, which would put NaNs into the serialized response.
        if norm == 0:
            return vector
        return vector / norm