team-149-project / utils /semantic_similarity.py
knguyen471's picture
Upload 11 files
888aba6 verified
raw
history blame contribute delete
827 Bytes
from typing import List, Optional

from sentence_transformers import SentenceTransformer
class Encoder:
    """Thin wrapper that loads a SentenceTransformer model and embeds text.

    The model is loaded once at construction time, asked to use the
    ``sdpa`` attention implementation, and converted with ``.half()``.
    """

    # Model identifier used when the caller does not supply one.
    DEFAULT_MODEL_NAME = (
        "KaLM-Embedding/KaLM-embedding-multilingual-mini-instruct-v2.5"
    )

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the embedding model.

        Args:
            model_name: Model identifier or path handed to
                ``SentenceTransformer``. Defaults to the KaLM multilingual
                mini-instruct model this project was written against.
        """
        print("Loading embedding model...")
        model = SentenceTransformer(
            model_name,
            model_kwargs={"attn_implementation": "sdpa"},
        )
        # NOTE(review): the half-precision cast is applied unconditionally;
        # presumably the deployment target is a GPU -- confirm, since fp16
        # inference on CPU may be slow or unsupported for some ops.
        self.model = model.half()

    def encode(
        self,
        texts: List[str],
        batch_size: int = 8,
        show_progress_bar: bool = False,
        save_path: Optional[str] = None,
    ):
        """Embed ``texts`` and return the embeddings as a tensor.

        Args:
            texts: Strings to embed.
            batch_size: Batch size forwarded to the model's ``encode``.
            show_progress_bar: Whether the model should display progress.
            save_path: Optional file path. When given (and non-empty) the
                embeddings are also written there with ``torch.save``.

        Returns:
            Whatever the underlying model's ``encode`` returns with
            ``convert_to_tensor=True`` (presumably one row per input text --
            TODO confirm shape against the model in use).
        """
        embeddings = self.model.encode(
            texts,
            convert_to_tensor=True,
            show_progress_bar=show_progress_bar,
            batch_size=batch_size,
        )
        if save_path:
            # Imported lazily: torch is only needed on the save path, and the
            # file does not import it at module level.
            import torch

            torch.save(embeddings, save_path)
        return embeddings