"""Cached Sentence-BERT embedding wrapper."""

from functools import lru_cache

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class SBert:
    """Callable wrapper around a ``SentenceTransformer`` with memoized encoding.

    Each instance owns its own LRU cache keyed on the input, so repeated calls
    with the same text skip the model forward pass.
    """

    def __init__(self, path, *, cache_size=10000):
        """Load the model from *path* onto ``DEVICE``.

        Args:
            path: Model name or path, passed straight to ``SentenceTransformer``.
            cache_size: Maximum number of distinct inputs whose embeddings are
                kept in this instance's LRU cache (``None`` for unbounded).
        """
        self.model = SentenceTransformer(path, device=DEVICE)
        # Per-instance cache. Decorating ``__call__`` with @lru_cache at class
        # level would share a single cache between all instances and hold a
        # strong reference to every ``self`` (and its model) forever. This
        # creates a self -> cache -> bound method -> self cycle instead, which
        # the garbage collector can reclaim once the instance is unreferenced.
        self._cached_encode = lru_cache(maxsize=cache_size)(self._encode)

    def _encode(self, x):
        """Run the model on *x* without caching."""
        return self.model.encode(x)

    def __call__(self, x) -> np.ndarray:
        """Return the embedding(s) of *x*, using the cache when possible.

        Args:
            x: Input passed to ``SentenceTransformer.encode``, e.g. a single
                string or a sequence of strings. Lists are accepted and are
                cached under the equivalent tuple key.

        Returns:
            A fresh array on every call, so callers may modify it without
            affecting the cached value.

        Raises:
            TypeError: If *x* (after list-to-tuple conversion) is unhashable.
        """
        if isinstance(x, list):
            # Lists cannot be cache keys. A tuple holds the same sentences and
            # is encoded the same way by ``encode``.
            x = tuple(x)
        # Copy so that in-place edits by one caller cannot corrupt the cached
        # entry that later callers receive for the same input.
        return np.array(self._cached_encode(x), copy=True)