import os
import pickle
from typing import List, Dict, Any, Optional

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
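
# Expected on-disk layout, inferred from how this module consumes its inputs;
# the "chunker" named in the messages below is assumed to produce these files:
#   <embedding_dir>/embeddings.pkl   pickled dict: chunk_id -> {'embedding': vector,
#       'chunk': {'title': str, 'content': str, 'categories': [str], ...}}
#   <embedding_dir>/faiss_index.pkl  index cache written by create_index()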


class VectorStore:
    """FAISS-backed vector store with cross-encoder reranking and keyword-hybrid search."""

    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}
        self.model = SentenceTransformer(model_name)
        self.reranker = CrossEncoder(reranker_name)
        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load a cached index if one exists; otherwise build it from stored embeddings."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
        if os.path.exists(index_path):
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
            # FAISS indexes are SWIG objects and cannot be pickled directly,
            # so the cache stores a serialized byte buffer instead.
            self.index = faiss.deserialize_index(data['index'])
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Create a FAISS index from embeddings and cache it to disk."""
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)
        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}
        dimension = embeddings.shape[1]
        # Exact L2 search: a lower distance means a closer match.
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))
        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks
        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                # serialize_index turns the index into a picklable uint8 array.
                'index': faiss.serialize_index(index),
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)
        print(f"Created index with {len(chunk_ids)} chunks")

    def search(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Embed the query, retrieve nearest chunks, and optionally rerank with the cross-encoder."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []
        query_embedding = self.model.encode([query])[0]
        # Over-fetch (2k) so category filtering still leaves enough candidates.
        distances, indices = self.index.search(
            np.array([query_embedding]).astype(np.float32),
            min(k * 2, len(self.chunk_ids))
        )
        results = []
        for i, idx in enumerate(indices[0]):
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            results.append({
                'chunk_id': chunk_id,
                'score': float(distances[0][i]),  # L2 distance: lower is better
                'chunk': chunk
            })
        if rerank and results:
            pairs = [(query, result['chunk']['content']) for result in results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)
            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
        return results[:k]

    def hybrid_search(self,
                      query: str,
                      k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine vector and keyword candidates, then rerank the union with the cross-encoder."""
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)
        # Simple keyword scoring: count query-term occurrences in title + content.
        keywords = query.lower().split()
        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            content = (chunk.get('title', '') + " " + chunk['content']).lower()
            keyword_scores[chunk_id] = sum(content.count(keyword) for keyword in keywords)
        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]
        # Merge the two candidate lists, deduplicating by chunk_id.
        seen_ids = set()
        combined_results = []
        for result in vector_results + keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])
        # Rerank the full candidate pool before truncating, so keyword-only
        # hits can displace weaker vector hits in the final top k.
        if combined_results:
            pairs = [(query, result['chunk']['content']) for result in combined_results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)
            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)
        return combined_results[:k]
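

# A minimal usage sketch, assuming the chunker has already written
# <embedding_dir>/embeddings.pkl. The query and k below are made-up examples.
if __name__ == "__main__":
    store = VectorStore()
    if store.index is not None:
        for hit in store.hybrid_search("example query about the corpus", k=3):
            title = hit['chunk'].get('title', '<untitled>')
            print(f"{hit['chunk_id']}  rerank={hit.get('rerank_score', 0.0):.3f}  {title}")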