"""Sentence-embedding wrapper around ``sentence_transformers`` models."""

from functools import lru_cache

import torch
from sentence_transformers import SentenceTransformer

# Run on the GPU when torch reports one is available, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model identifiers listed for use with ``SBert``.
list_models = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext'
]

# Upper bound on the number of memoized inputs kept per ``SBert`` instance.
_CACHE_SIZE = 10000


class SBert:
    """Callable wrapper that encodes inputs with a ``SentenceTransformer``.

    Results for hashable inputs are memoized in a bounded, per-instance LRU
    cache, so encoding the same input again on the same instance returns the
    cached tensor instead of re-running the model.
    """

    def __init__(self, path):
        """Load the model identified by ``path`` onto ``DEVICE``.

        Args:
            path: Model name or path handed to ``SentenceTransformer``.
        """
        print(f'Loading model from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        # The cache is created per instance on purpose: decorating the method
        # itself with ``lru_cache`` stores ``self`` in a class-wide cache and
        # keeps every instance (and its loaded model) alive for as long as its
        # entries stay cached. The self -> cache -> bound-method cycle created
        # here is an ordinary reference cycle the garbage collector reclaims.
        self._cached_encode = lru_cache(maxsize=_CACHE_SIZE)(self._encode)

    def _encode(self, x) -> torch.Tensor:
        """Encode ``x`` with the wrapped model, without consulting the cache."""
        return self.model.encode(x, convert_to_tensor=True)

    def __call__(self, x) -> torch.Tensor:
        """Return the embedding tensor for ``x``.

        Args:
            x: Input passed to ``SentenceTransformer.encode`` (presumably a
                sentence or a sequence of sentences; confirm against callers).

        Returns:
            The result of ``encode(x, convert_to_tensor=True)``. For hashable
            inputs the same tensor object is returned on repeated calls, so
            callers should not modify it in place.
        """
        # ``lru_cache`` needs a hashable key. Unhashable inputs such as a list
        # of sentences are encoded directly rather than raising ``TypeError``.
        try:
            hash(x)
        except TypeError:
            return self._encode(x)
        return self._cached_encode(x)