File size: 1,169 Bytes
c5b2f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from typing import Dict, List, Any
from FlagEmbedding import BGEM3FlagModel
class EndpointHandler:
    """Inference handler that embeds text with a ``BGEM3FlagModel``.

    The ``__init__(path)`` / ``__call__(data)`` shape and the ``"inputs"`` /
    ``"parameters"`` payload keys appear to follow the Hugging Face Inference
    Endpoints custom-handler convention (NOTE(review): inferred from the
    class/method shape; confirm against the deployment).
    """

    def __init__(self, path: str = "", *, use_fp16: bool = True) -> None:
        """Load the model once so every request reuses the same instance.

        Args:
            path: Model location, passed straight to ``BGEM3FlagModel``.
            use_fp16: Forwarded to ``BGEM3FlagModel``. ``True`` (the previously
                hard-coded value) speeds up computation with a slight
                performance degradation.
        """
        self.model = BGEM3FlagModel(path, use_fp16=use_fp16)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Encode the request's inputs and return the model output unchanged.

        Args:
            data: Request payload. ``data["inputs"]`` is the text to encode
                (presumably a string or list of strings — TODO confirm).
                The optional ``data["parameters"]`` mapping is forwarded as
                keyword arguments to ``self.model.encode``, e.g.
                ``{"batch_size": 12, "max_length": 8192}``. If ``"inputs"``
                is absent, the rest of the payload (without ``"parameters"``)
                is used as the inputs. ``data`` itself is never modified.

        Returns:
            Whatever ``self.model.encode`` returns — a mapping keyed by output
            name, such as ``"dense_vecs"``.
        """
        # Read without pop(): the caller's payload must not be mutated.
        if "inputs" in data:
            inputs = data["inputs"]
        else:
            # Same fallback as before: treat the remaining payload as inputs.
            inputs = {
                key: value for key, value in data.items() if key != "parameters"
            }

        parameters = data.get("parameters")
        # Compare with `is None` rather than truthiness so a malformed,
        # non-mapping value still fails in the ** expansion instead of
        # being silently ignored.
        kwargs = {} if parameters is None else parameters
        return self.model.encode(inputs, **kwargs)
|