import os
import pickle
from typing import List, Dict, Any, Optional

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder

class VectorStore:
    """FAISS-backed vector store with optional cross-encoder reranking."""

    def __init__(self,
                 embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}
        self.model = SentenceTransformer(model_name)
        self.reranker = CrossEncoder(reranker_name)
        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load a previously pickled index, or build one from saved embeddings."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
        if os.path.exists(index_path):
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
            self.index = data['index']
            self.chunk_ids = data['chunk_ids']
            self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Create a FAISS index from precomputed embeddings and persist it."""
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)

        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}

        # Exact (brute-force) L2 index; adequate for small-to-medium corpora.
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))

        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks

        # Persist the index alongside its id/chunk lookup tables.
        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': index,
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)
        print(f"Created index with {len(chunk_ids)} chunks")

    def search(self,
               query: str,
               k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Embed the query, retrieve nearest chunks, and optionally rerank them."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []

        query_embedding = self.model.encode([query])[0]
        # Over-fetch (2k) so category filtering still leaves enough candidates.
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32),
                                 min(k * 2, len(self.chunk_ids)))

        results = []
        for i, idx in enumerate(I[0]):
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            results.append({
                'chunk_id': chunk_id,
                'score': float(D[0][i]),  # squared L2 distance: lower is better
                'chunk': chunk
            })

        if rerank and results:
            # The cross-encoder scores each (query, passage) pair jointly.
            pairs = [(query, result['chunk']['content']) for result in results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)
            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)

        return results[:k]

    def hybrid_search(self,
                      query: str,
                      k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine vector search with simple keyword matching, then rerank."""
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)

        # Keyword pass: score every chunk by raw term frequency of query words.
        keywords = query.lower().split()
        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue
            content = (chunk.get('title', '') + " " + chunk['content']).lower()
            keyword_scores[chunk_id] = sum(content.count(keyword) for keyword in keywords)

        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]

        # Merge the two lists, vector hits first, deduplicating by chunk id.
        seen_ids = set()
        combined_results = []
        for result in vector_results + keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])
        combined_results = combined_results[:k]

        # Rerank the merged pool so vector and keyword hits share one score scale.
        if combined_results:
            pairs = [(query, result['chunk']['content']) for result in combined_results]
            rerank_scores = self.reranker.predict(pairs)
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)
            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)

        return combined_results
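

# A minimal usage sketch. The query string is hypothetical, and this assumes
# the chunker has already written data/embeddings/embeddings.pkl in the
# {chunk_id: {'embedding': ..., 'chunk': ...}} format loaded above.
if __name__ == "__main__":
    store = VectorStore()
    for hit in store.hybrid_search("how do I reset my password?", k=3):
        # Fall back to the raw vector score if reranking produced no score.
        print(hit['chunk_id'], hit.get('rerank_score', hit['score']))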