# buffalo_rag/vector_store/db.py
import os
import pickle

import numpy as np
import faiss
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer, CrossEncoder

class VectorStore:
    """FAISS-backed vector store with cross-encoder reranking."""

    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.embedding_dir = embedding_dir
        self.index = None       # FAISS index over chunk embeddings
        self.chunk_ids = []     # position in the index -> chunk id
        self.chunks = {}        # chunk id -> chunk data
        self.model = SentenceTransformer(model_name)  # bi-encoder for embedding queries
        self.reranker = CrossEncoder(reranker_name)   # cross-encoder for reranking
        self.load_or_create_index()
    def load_or_create_index(self) -> None:
        """Load a previously pickled index, or build one from embeddings.pkl."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')

        if os.path.exists(index_path):
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
            self.index = data['index']
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")
    def create_index(self) -> None:
        """Create a FAISS index from embeddings and persist it.

        Expects embeddings.pkl to map
        chunk_id -> {'embedding': vector, 'chunk': {...}}
        (inferred from the lookups below).
        """
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)

        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}

        # Exact (brute-force) L2 index; adequate for small-to-medium corpora.
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))

        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks

        # Persist so subsequent runs can skip the rebuild.
        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': index,
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)
        print(f"Created index with {len(chunk_ids)} chunks")
    def search(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Vector search with optional category filtering and reranking."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []

        # Over-fetch (2x) so category filtering still leaves enough candidates.
        query_embedding = self.model.encode([query])[0]
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32),
                                 min(k * 2, len(self.chunk_ids)))

        results = []
        for i, idx in enumerate(I[0]):
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            results.append({
                'chunk_id': chunk_id,
                'score': float(D[0][i]),  # L2 distance: lower is better
                'chunk': chunk
            })

        # Rerank candidates with the cross-encoder: higher score is better.
        if rerank and results:
            pairs = [(query, result['chunk']['content']) for result in results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)
            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)

        # Truncate to k on both paths, since up to 2k candidates were fetched.
        return results[:k]
    def hybrid_search(self,
                      query: str,
                      k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine vector search with simple keyword counting, then rerank."""
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)

        # Keyword pass: score each chunk by how often the query terms occur.
        keywords = query.lower().split()
        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            content = (chunk.get('title', '') + " " + chunk['content']).lower()
            keyword_scores[chunk_id] = sum(content.count(keyword) for keyword in keywords)

        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]

        # Merge the two lists, vector hits first, dropping duplicates.
        seen_ids = set()
        combined_results = []
        for result in vector_results + keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])
        combined_results = combined_results[:k]

        # Final ordering comes from the cross-encoder.
        if combined_results:
            pairs = [(query, result['chunk']['content']) for result in combined_results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)
            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)

        return combined_results
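

# A minimal usage sketch (not part of the original file). It assumes
# data/embeddings/embeddings.pkl already exists in the schema create_index()
# expects, i.e. {chunk_id: {'embedding': vector, 'chunk': {'title': ...,
# 'content': ..., 'categories': [...]}}}. The query strings and the
# "housing" category below are made-up examples.
if __name__ == "__main__":
    store = VectorStore(embedding_dir="data/embeddings")

    # Pure vector search with cross-encoder reranking.
    for hit in store.search("how do I apply for on-campus housing?", k=3):
        print(hit['chunk_id'], hit.get('rerank_score'), hit['chunk'].get('title'))

    # Hybrid search restricted to a (hypothetical) category.
    for hit in store.hybrid_search("housing deadlines", k=3,
                                   filter_categories=["housing"]):
        print(hit['chunk_id'], hit.get('rerank_score'))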