1mg / utils.py
ftkd99's picture
Upload 2 files
027da24 verified
raw
history blame contribute delete
665 Bytes
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# Shared sentence embedder (MiniLM checkpoint named below). It is constructed
# once, at import time, and reused by embed_texts() for every call.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def embed_texts(texts):
    """Encode ``texts`` with the module-level ``embedder``.

    Args:
        texts: Input handed unchanged to ``embedder.encode``.

    Returns:
        The result of ``embedder.encode`` called with
        ``convert_to_tensor=False``.
    """
    encode_options = {"convert_to_tensor": False}
    vectors = embedder.encode(texts, **encode_options)
    return vectors
def build_faiss_index(texts):
    """Embed ``texts`` and load the vectors into a flat L2 FAISS index.

    Args:
        texts: Non-empty sequence of strings to index. The position of each
            text becomes the id of its vector in the index.

    Returns:
        A ``(index, embeddings)`` tuple: the populated ``faiss.IndexFlatL2``
        and the embeddings exactly as returned by ``embed_texts``.

    Raises:
        ValueError: If the embeddings do not form a non-empty 2-D matrix
            (for example when ``texts`` is empty), since the index
            dimensionality cannot be determined in that case.
    """
    embeddings = embed_texts(texts)

    # FAISS expects a C-contiguous float32 matrix of shape (n, dim); coerce
    # explicitly instead of relying on the encoder's output dtype/layout.
    matrix = np.ascontiguousarray(embeddings, dtype=np.float32)
    if matrix.ndim != 2 or matrix.shape[0] == 0:
        raise ValueError(
            "build_faiss_index needs a non-empty sequence of texts; "
            f"got embeddings with shape {matrix.shape}"
        )

    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)
    return index, embeddings
def retrieve(query, index, docs, k=3):
    """Return the documents whose vectors are nearest to ``query``.

    Args:
        query: Query string to embed and search for.
        index: FAISS index to search; its vector ids are used as positions
            into ``docs``.
        docs: Sequence of documents, aligned with the vectors in ``index``.
        k: Maximum number of documents to return.

    Returns:
        A list of at most ``k`` entries from ``docs``, nearest first. The list
        is shorter than ``k`` when the index holds fewer vectors, and empty
        when the index is empty or ``k`` is not positive.
    """
    # Asking FAISS for more neighbours than it stores pads the result with
    # label -1, which would silently alias to docs[-1]; clamp k up front.
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    # FAISS expects a C-contiguous float32 matrix of shape (1, dim).
    query_matrix = np.ascontiguousarray(embed_texts([query]), dtype=np.float32)
    _distances, indices = index.search(query_matrix, k)

    # Some index types can still report -1 for missing neighbours; skip them.
    return [docs[i] for i in indices[0] if i >= 0]