Spaces:
Sleeping
Sleeping
File size: 2,481 Bytes
758b199 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import pickle
import os
from document_processing import load_document
class EmbeddingEngine:
    """Thin wrapper around a SentenceTransformer model that produces NumPy embeddings."""

    def __init__(self, model_name: str):
        """Load the sentence-transformers model identified by ``model_name``."""
        self.model = SentenceTransformer(model_name)

    def encode(self, texts):
        """Encode ``texts`` (a string or sequence of strings) into a NumPy array."""
        embeddings = self.model.encode(texts, convert_to_numpy=True)
        return embeddings
class FAISSVectorStore:
    """Persistent flat-L2 FAISS index paired with a parallel metadata list.

    Vectors and metadata stay aligned by insertion order: ``metadata[i]``
    describes the vector with FAISS internal id ``i``.
    """

    def __init__(self, dim=None, index_path="faiss_index.index", metadata_path="faiss_metadata.pkl", overwrite=False):
        """Load an existing index from disk, or create a fresh one.

        Args:
            dim: Embedding dimensionality; required only when a new index
                must be created (no files on disk, or ``overwrite=True``).
            index_path: Path of the serialized FAISS index file.
            metadata_path: Path of the pickled metadata list.
            overwrite: If True, ignore any files on disk and start fresh.

        Raises:
            ValueError: If a new index is needed but ``dim`` is None.
        """
        self.index_path = index_path
        self.metadata_path = metadata_path
        self.metadata = []
        if not overwrite and os.path.exists(index_path) and os.path.exists(metadata_path):
            self.load()
        else:
            if dim is None:
                raise ValueError("Must provide dimension if creating new index")
            self.index = faiss.IndexFlatL2(dim)

    def add(self, embeddings, metadata_list):
        """Append a 2-D batch of embeddings and their matching metadata entries."""
        # Convert once up front; FAISS requires contiguous float32 input.
        vectors = np.asarray(embeddings, dtype="float32")
        # Auto-create index if not already loaded (defensive; __init__ normally sets it)
        if not hasattr(self, 'index'):
            self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index.add(vectors)
        self.metadata.extend(metadata_list)

    def search(self, query_embedding, top_k=5):
        """Return metadata for the ``top_k`` nearest neighbors of ``query_embedding``.

        FAISS pads the result ids with -1 when the index holds fewer than
        ``top_k`` vectors. The original check ``i < len(self.metadata)`` let
        -1 through, so ``metadata[-1]`` — the *last* entry — was returned as
        a bogus hit; the lower bound below fixes that.
        """
        query = np.asarray([query_embedding], dtype="float32")
        _, indices = self.index.search(query, top_k)
        results = []
        for i in indices[0]:
            if 0 <= i < len(self.metadata):
                results.append(self.metadata[i])
        return results

    def save(self):
        """Persist the index and metadata to their configured paths."""
        faiss.write_index(self.index, self.index_path)
        with open(self.metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)

    def load(self):
        """Restore the index and metadata from disk.

        NOTE(review): ``pickle.load`` on an untrusted metadata file can run
        arbitrary code — only load files this application wrote itself.
        """
        self.index = faiss.read_index(self.index_path)
        with open(self.metadata_path, "rb") as f:
            self.metadata = pickle.load(f)
class ClauseQueryEngine:
    """High-level search facade: embeds a query string and retrieves clause metadata."""

    def __init__(self, model_name: str, vector_store_path="faiss_index.index", metadata_path="faiss_metadata.pkl"):
        """Wire an embedding model to a FAISS vector store persisted at the given paths."""
        self.embedder = EmbeddingEngine(model_name=model_name)
        self.store = FAISSVectorStore(index_path=vector_store_path, metadata_path=metadata_path)

    def search(self, query: str, top_k: int = 5):
        """Embed ``query`` and return metadata for its ``top_k`` nearest stored clauses."""
        prefixed = f"query: {query}"  # BGE-style models expect the "query: " prefix at search time
        query_vector = self.embedder.encode([prefixed])[0]
        return self.store.search(query_vector, top_k)
|