import os
from typing import List

from sentence_transformers import SentenceTransformer


class PreTrainedPipeline:
    """Feature-extraction pipeline wrapping a ``SentenceTransformer`` model.

    The model is loaded once at construction time and reused for every call.
    """

    def __init__(self, path=""):
        """Load the sentence-transformers model.

        Args:
            path: Location handed to ``SentenceTransformer``; accepts a
                ``str`` or ``os.PathLike``. Defaults to ``""``.
        """
        # os.fspath normalises PathLike objects to a plain string, which is
        # all the former single-argument os.path.join(path) call did.
        self.model = SentenceTransformer(os.fspath(path))

    def __call__(self, inputs: str) -> List[float]:
        """Compute the embedding of *inputs*.

        Args:
            inputs: The string to get the features of.

        Returns:
            The result of ``self.model.encode(inputs)`` converted to a plain
            Python list via ``.tolist()`` (a list of floats for a single
            string input).
        """
        return self.model.encode(inputs).tolist()


if __name__ == "__main__":
    # Manual smoke test. It is guarded so that importing this module does not
    # load a model with the default path and print to stdout as a side effect.
    pipeline = PreTrainedPipeline()
    print(pipeline("hei"))