|
from sentence_transformers import SentenceTransformer
|
|
import faiss
|
|
import numpy as np
|
|
|
|
|
|
# Hugging Face identifier of the sentence-embedding model used by embed_texts.
_EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# Shared encoder instance, created once at import time and reused for every
# call to embed_texts.
embedder = SentenceTransformer(_EMBEDDING_MODEL_ID)
|
|
|
|
def embed_texts(texts):
    """Encode ``texts`` with the module-level ``embedder``.

    Args:
        texts: The input(s) handed unchanged to ``embedder.encode``.

    Returns:
        The result of ``embedder.encode`` with tensor conversion disabled
        (presumably a NumPy array of shape ``(len(texts), dim)`` for a list
        input -- confirm against the installed sentence-transformers version).
    """
    vectors = embedder.encode(texts, convert_to_tensor=False)
    return vectors
|
|
|
|
def build_faiss_index(texts):
    """Embed ``texts`` and load the vectors into an exact L2 FAISS index.

    Row ``i`` of the index corresponds to ``texts[i]``, so the same sequence
    can later be passed as ``docs`` to ``retrieve``.

    Args:
        texts: Non-empty sequence of strings to index. A single string is
            also accepted and indexed as one vector.

    Returns:
        A ``(index, embeddings)`` tuple: the populated ``faiss.IndexFlatL2``
        and the 2-D float32 embedding matrix that was added to it.

    Raises:
        ValueError: If ``texts`` is empty, since there is no vector from
            which to infer the index dimension.
    """
    if len(texts) == 0:
        raise ValueError("build_faiss_index() requires at least one text")

    # FAISS expects a C-contiguous float32 matrix of shape (n, dim).
    # atleast_2d also covers a bare string, which the encoder returns as a
    # 1-D vector (the old `embeddings[0].shape[0]` crashed on that case).
    embeddings = np.ascontiguousarray(
        np.atleast_2d(np.asarray(embed_texts(texts), dtype=np.float32))
    )

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
|
|
|
|
def retrieve(query, index, docs, k=3):
    """Return up to ``k`` documents whose embeddings are nearest to ``query``.

    Args:
        query: The query string; it is embedded with ``embed_texts``.
        index: A FAISS index whose row ``i`` is the embedding of ``docs[i]``
            (for example the index produced by ``build_faiss_index``).
        docs: Sequence mapping index row ids back to documents.
        k: Maximum number of documents to return.

    Returns:
        A list of at most ``k`` documents, in the order FAISS reports them
        (nearest first for ``IndexFlatL2``). Fewer than ``k`` are returned
        when the index holds fewer than ``k`` vectors, and ``[]`` when
        ``k <= 0``.
    """
    if k <= 0:
        return []

    query_matrix = np.asarray(embed_texts([query]), dtype=np.float32)
    _, neighbor_ids = index.search(query_matrix, k)

    # FAISS pads missing neighbours with id -1 when k exceeds the number of
    # indexed vectors. Without this filter, docs[-1] (the last document)
    # would be returned as a spurious hit.
    return [docs[int(i)] for i in neighbor_ids[0] if i >= 0]
|
|
|