# embedding/handler.py
# Last change: "Update handler.py" (commit 78b1028) by spcewalker; 746 bytes.
from typing import Dict, List

import torch
from sentence_transformers import SentenceTransformer
class EndpointHandler:
    """Inference-endpoint handler that returns sentence embeddings.

    Wraps a ``SentenceTransformer`` model and exposes it through ``__call__``,
    which takes a deserialized request payload and returns a dict holding the
    embeddings as nested Python lists.
    """

    def __init__(self, path: str = "", *, model_name: str = "all-MiniLM-L6-v2"):
        """Load the embedding model.

        Args:
            path: Accepted to match the handler constructor signature; it is
                not used. The model is loaded from ``model_name``, not from
                ``path``.
            model_name: Model name or path handed to ``SentenceTransformer``.
                Defaults to ``"all-MiniLM-L6-v2"``, the model this handler
                previously hard-coded.
        """
        # NOTE(review): ``path`` is ignored, so the model is resolved from
        # ``model_name`` rather than from the directory the handler was given
        # — confirm this is intended.
        self.model = SentenceTransformer(model_name)

    def __call__(
        self, data: Dict[str, List[str]]
    ) -> Dict[str, List[List[float]]]:
        """Embed the sentences contained in a request payload.

        Args:
            data: Deserialized request body. The sentences are read from the
                ``"inputs"`` key. If that key is absent, ``data`` itself is
                passed to the model. ``data`` is not modified.

        Returns:
            ``{"embeddings": ...}``, where the value is the result of
            ``self.model.encode(inputs)`` converted with ``tolist()``. For a
            list of sentences this is one embedding (list of floats) per
            sentence.
        """
        # ``get`` rather than ``pop``: reading the payload must not strip the
        # "inputs" key out of the caller's dict.
        inputs = data.get("inputs", data)

        # NOTE(review): if ``inputs`` is a single string, ``encode`` likely
        # returns one flat vector, so the result would be List[float] rather
        # than the annotated List[List[float]] — confirm what clients expect.
        embeddings = self.model.encode(inputs)

        # Convert the array returned by ``encode`` to plain lists so the
        # response is JSON-serializable.
        return {"embeddings": embeddings.tolist()}