working first phase
Browse files- retriever/processor.py +158 -29
- retriever/retriever.py +95 -8
retriever/processor.py
CHANGED
|
@@ -6,92 +6,159 @@ from langchain_text_splitters import (
|
|
| 6 |
from langchain_experimental.text_splitter import SemanticChunker
|
| 7 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 8 |
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class ChunkProcessor:
|
| 11 |
-
def __init__(self, model_name='all-MiniLM-L6-v2'):
|
| 12 |
self.model_name = model_name
|
| 13 |
self.encoder = SentenceTransformer(model_name)
|
|
|
|
| 14 |
# Required for Semantic Chunking
|
| 15 |
self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
| 16 |
|
| 17 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
Factory method to return different chunking strategies.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
"""
|
| 21 |
if technique == "fixed":
|
|
|
|
| 22 |
return CharacterTextSplitter(
|
| 23 |
separator=kwargs.get('separator', ""),
|
| 24 |
chunk_size=chunk_size,
|
| 25 |
-
chunk_overlap=chunk_overlap
|
|
|
|
|
|
|
| 26 |
)
|
| 27 |
|
| 28 |
elif technique == "recursive":
|
|
|
|
|
|
|
| 29 |
return RecursiveCharacterTextSplitter(
|
| 30 |
chunk_size=chunk_size,
|
| 31 |
-
chunk_overlap=chunk_overlap
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
elif technique == "character":
|
|
|
|
| 35 |
return CharacterTextSplitter(
|
| 36 |
separator=kwargs.get('separator', "\n\n"),
|
| 37 |
chunk_size=chunk_size,
|
| 38 |
-
chunk_overlap=chunk_overlap
|
|
|
|
|
|
|
| 39 |
)
|
| 40 |
|
| 41 |
elif technique == "sentence":
|
| 42 |
-
# Using Recursive Splitter
|
| 43 |
-
# This
|
| 44 |
return RecursiveCharacterTextSplitter(
|
| 45 |
chunk_size=chunk_size,
|
| 46 |
chunk_overlap=chunk_overlap,
|
| 47 |
-
separators=["\n\n", "\n", ". ", "? ", "! ", " ", ""]
|
|
|
|
|
|
|
| 48 |
)
|
| 49 |
|
| 50 |
elif technique == "semantic":
|
|
|
|
| 51 |
return SemanticChunker(
|
| 52 |
self.hf_embeddings,
|
| 53 |
-
breakpoint_threshold_type="percentile"
|
|
|
|
|
|
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
elif technique == "token":
|
|
|
|
| 57 |
return SentenceTransformersTokenTextSplitter(
|
| 58 |
model_name=self.model_name,
|
| 59 |
tokens_per_chunk=chunk_size,
|
| 60 |
-
chunk_overlap=chunk_overlap
|
|
|
|
| 61 |
)
|
|
|
|
| 62 |
else:
|
| 63 |
-
raise ValueError(f"Technique '{technique}' is not supported.")
|
| 64 |
|
| 65 |
-
def process(self, df, technique="recursive", chunk_size=500,
|
|
|
|
|
|
|
| 66 |
"""
|
| 67 |
-
Processes a DataFrame into vector-ready chunks with full output for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
"""
|
|
|
|
|
|
|
|
|
|
| 69 |
splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
|
| 70 |
processed_chunks = []
|
| 71 |
|
| 72 |
-
#
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
for _, row in subset_df.iterrows():
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# Split the text
|
| 82 |
raw_chunks = splitter.split_text(row['full_text'])
|
| 83 |
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
for i, text in enumerate(raw_chunks):
|
| 87 |
-
# Standardize output
|
| 88 |
content = text.page_content if hasattr(text, 'page_content') else text
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
|
| 94 |
-
#
|
| 95 |
embedding = self.encoder.encode(content).tolist()
|
| 96 |
|
| 97 |
processed_chunks.append({
|
|
@@ -102,10 +169,72 @@ class ChunkProcessor:
|
|
| 102 |
"text": content,
|
| 103 |
"url": row['url'],
|
| 104 |
"chunk_index": i,
|
| 105 |
-
"technique": technique
|
|
|
|
|
|
|
| 106 |
}
|
| 107 |
})
|
| 108 |
-
print("="*80)
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from langchain_experimental.text_splitter import SemanticChunker
|
| 7 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
import pandas as pd
|
| 11 |
|
| 12 |
class ChunkProcessor:
|
| 13 |
+
def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True):
    """
    Load the embedding models used for chunking and for chunk vectors.

    Args:
        model_name: Model identifier handed to both SentenceTransformer and
            HuggingFaceEmbeddings.
        verbose: Default for whether progress output is printed.
    """
    self.verbose = verbose
    self.model_name = model_name

    # Encoder used to turn each chunk's text into an embedding vector.
    self.encoder = SentenceTransformer(self.model_name)

    # Embeddings wrapper handed to SemanticChunker by the "semantic" strategy.
    self.hf_embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
|
| 19 |
|
| 20 |
+
def _print(self, *args, **kwargs):
    """Forward the arguments to print() only when the instance is verbose."""
    if not self.verbose:
        return
    print(*args, **kwargs)
|
| 24 |
+
|
| 25 |
+
def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
    """
    Factory method to return different chunking strategies.

    Strategies:
    - "fixed": Simple character-based splitting with empty separator (can split mid-sentence)
    - "recursive": Recursive character splitting with hierarchical separators
    - "character": Character-based splitting with paragraph separator
    - "sentence": Recursive splitting with sentence-boundary separators
    - "semantic": Embedding-based semantic chunking
    - "token": Token-based splitting for transformer models

    Args:
        technique: Name of the strategy (one of the names listed above).
        chunk_size: Passed as ``chunk_size`` to the character-based splitters and as
            ``tokens_per_chunk`` to the token splitter; the semantic strategy only
            uses it to derive the default ``min_chunk_size``.
        chunk_overlap: Overlap passed to every splitter except the semantic one.
        **kwargs: Optional per-strategy overrides: ``separator``, ``separators``,
            ``keep_separator``, ``breakpoint_threshold_type``,
            ``breakpoint_threshold_amount``, ``min_chunk_size``, ``max_chunk_size``,
            ``length_function``.

    Returns:
        A configured splitter instance for the requested strategy.

    Raises:
        ValueError: If ``technique`` is not a supported strategy name.
    """
    if technique == "fixed":
        # Empty separator by default, so chunks are cut purely by length.
        return CharacterTextSplitter(
            separator=kwargs.get('separator', ""),
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            is_separator_regex=False,
        )

    elif technique == "recursive":
        separators = kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""])
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=separators,
            length_function=len,
            keep_separator=kwargs.get('keep_separator', True),
        )

    elif technique == "character":
        # Same splitter as "fixed", but splitting on paragraph breaks by default.
        return CharacterTextSplitter(
            separator=kwargs.get('separator', "\n\n"),
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            is_separator_regex=False,
        )

    elif technique == "sentence":
        # Recursive splitter whose separator list puts sentence punctuation
        # ahead of the finer-grained fallbacks.
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=kwargs.get('separators', ["\n\n", "\n", ". ", "? ", "! ", ".\n", "?\n", "!\n", "; ", ": ", ", ", " ", ""]),
            length_function=len,
            keep_separator=kwargs.get('keep_separator', True),
        )

    elif technique == "semantic":
        semantic_kwargs = {
            'breakpoint_threshold_type': kwargs.get('breakpoint_threshold_type', "percentile"),
            'breakpoint_threshold_amount': kwargs.get('breakpoint_threshold_amount', 95),
            'min_chunk_size': kwargs.get('min_chunk_size', chunk_size // 10),
        }
        # max_chunk_size is no longer injected by default: always passing it made
        # this branch depend on a keyword the chunker may not accept.
        # NOTE(review): SemanticChunker's documented constructor does not list
        # max_chunk_size -- confirm against the installed langchain_experimental.
        if 'max_chunk_size' in kwargs:
            semantic_kwargs['max_chunk_size'] = kwargs['max_chunk_size']
        return SemanticChunker(self.hf_embeddings, **semantic_kwargs)

    elif technique == "token":
        # The previous default length_function measured len(self.encoder.encode(x)),
        # i.e. the size of the embedding vector, not a token count. Only forward a
        # length_function the caller explicitly asked for.
        token_kwargs = {}
        if 'length_function' in kwargs:
            token_kwargs['length_function'] = kwargs['length_function']
        # NOTE(review): the default chunk_size of 500 may exceed the model's token
        # limit, which this splitter is believed to reject -- confirm.
        return SentenceTransformersTokenTextSplitter(
            model_name=self.model_name,
            tokens_per_chunk=chunk_size,
            chunk_overlap=chunk_overlap,
            **token_kwargs,
        )

    else:
        raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, sentence, semantic, token")
|
| 100 |
|
| 101 |
+
def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
|
| 102 |
+
chunk_overlap: int = 50, max_docs: Optional[int] = 5, verbose: Optional[bool] = None,
|
| 103 |
+
**kwargs) -> List[Dict[str, Any]]:
|
| 104 |
"""
|
| 105 |
+
Processes a DataFrame into vector-ready chunks with full output for documents.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
df: DataFrame containing documents with columns: id, title, url, full_text
|
| 109 |
+
technique: Chunking strategy to use
|
| 110 |
+
chunk_size: Maximum size of each chunk (characters for most, tokens for token splitter)
|
| 111 |
+
chunk_overlap: Overlap between consecutive chunks
|
| 112 |
+
max_docs: Maximum number of documents to process (None for all)
|
| 113 |
+
verbose: Override the instance's verbose setting (if None, uses instance setting)
|
| 114 |
+
**kwargs: Additional arguments to pass to splitter
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
List of processed chunks with embeddings and metadata
|
| 118 |
"""
|
| 119 |
+
# Determine if we should print
|
| 120 |
+
should_print = verbose if verbose is not None else self.verbose
|
| 121 |
+
|
| 122 |
splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
|
| 123 |
processed_chunks = []
|
| 124 |
|
| 125 |
+
# Select documents to process
|
| 126 |
+
if max_docs:
|
| 127 |
+
subset_df = df.head(max_docs)
|
| 128 |
+
else:
|
| 129 |
+
subset_df = df
|
| 130 |
+
|
| 131 |
+
# Validate required columns exist
|
| 132 |
+
required_cols = ['id', 'title', 'url', 'full_text']
|
| 133 |
+
missing_cols = [col for col in required_cols if col not in subset_df.columns]
|
| 134 |
+
if missing_cols:
|
| 135 |
+
raise ValueError(f"DataFrame missing required columns: {missing_cols}")
|
| 136 |
|
| 137 |
for _, row in subset_df.iterrows():
|
| 138 |
+
if should_print:
|
| 139 |
+
self._print("\n" + "="*80)
|
| 140 |
+
self._print(f"π DOCUMENT: {row['title']}")
|
| 141 |
+
self._print(f"π URL: {row['url']}")
|
| 142 |
+
self._print(f"π Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
|
| 143 |
+
self._print("-" * 80)
|
| 144 |
|
| 145 |
# Split the text
|
| 146 |
raw_chunks = splitter.split_text(row['full_text'])
|
| 147 |
|
| 148 |
+
if should_print:
|
| 149 |
+
self._print(f"π― Total Chunks Generated: {len(raw_chunks)}")
|
| 150 |
|
| 151 |
for i, text in enumerate(raw_chunks):
|
| 152 |
+
# Standardize output (handle both string and Document objects)
|
| 153 |
content = text.page_content if hasattr(text, 'page_content') else text
|
| 154 |
|
| 155 |
+
if should_print:
|
| 156 |
+
# Print chunk preview
|
| 157 |
+
self._print(f"\n[Chunk {i}] ({len(content)} chars):")
|
| 158 |
+
preview = content[:200] + "..." if len(content) > 200 else content
|
| 159 |
+
self._print(f" {preview}")
|
| 160 |
|
| 161 |
+
# Generate embedding
|
| 162 |
embedding = self.encoder.encode(content).tolist()
|
| 163 |
|
| 164 |
processed_chunks.append({
|
|
|
|
| 169 |
"text": content,
|
| 170 |
"url": row['url'],
|
| 171 |
"chunk_index": i,
|
| 172 |
+
"technique": technique,
|
| 173 |
+
"chunk_size": len(content),
|
| 174 |
+
"total_chunks": len(raw_chunks)
|
| 175 |
}
|
| 176 |
})
|
|
|
|
| 177 |
|
| 178 |
+
if should_print:
|
| 179 |
+
self._print("="*80)
|
| 180 |
+
|
| 181 |
+
if should_print:
|
| 182 |
+
self._print(f"\nβ
Finished processing {len(subset_df)} documents into {len(processed_chunks)} chunks.")
|
| 183 |
+
if len(processed_chunks) > 0:
|
| 184 |
+
self._print(f"π Average chunk size: {sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks):.0f} chars")
|
| 185 |
+
|
| 186 |
+
return processed_chunks
|
| 187 |
+
|
| 188 |
+
def compare_strategies(self, df: pd.DataFrame, text_column: str = 'full_text',
                       chunk_size: int = 500, max_docs: int = 1,
                       verbose: Optional[bool] = None) -> Dict[str, Any]:
    """
    Compare different chunking strategies on the same document.

    Only the first row of ``df`` is analysed (it is read with ``df.iloc[0]``, so
    ``df`` must contain at least one row and a ``text_column`` column).

    Args:
        df: DataFrame holding the documents.
        text_column: Column that contains the document text.
        chunk_size: Chunk size passed to ``get_splitter`` for every strategy.
        max_docs: Accepted for interface compatibility; currently unused.
        verbose: Override the instance's verbose setting (if None, uses instance setting).

    Returns:
        Dictionary keyed by strategy name. Each value holds ``num_chunks``,
        ``avg_chunk_size``, ``min_chunk_size``, ``max_chunk_size``,
        ``complete_sentences_ratio`` and ``chunk_lengths``, or ``{'error': message}``
        when that strategy raised.
    """
    should_print = verbose if verbose is not None else self.verbose

    def log(*args, **kwargs):
        # Print directly instead of via self._print, which re-checks self.verbose
        # and would silently drop an explicit verbose=True override.
        if should_print:
            print(*args, **kwargs)

    strategies = ['fixed', 'recursive', 'character', 'sentence', 'semantic', 'token']
    sentence_enders = ('.', '!', '?', '"', "'")
    results = {}

    # Get first document
    sample_text = df.iloc[0][text_column]

    for technique in strategies:
        # Best-effort per strategy: a failure is recorded and the comparison goes on.
        try:
            log(f"\nπ Testing {technique.upper()} strategy...")

            splitter = self.get_splitter(technique, chunk_size=chunk_size)
            chunks = splitter.split_text(sample_text)

            # Splitters may return plain strings or Document-like objects.
            texts = [c.page_content if hasattr(c, 'page_content') else c for c in chunks]
            chunk_lengths = [len(t) for t in texts]
            avg_chunk_size = sum(chunk_lengths) / len(chunk_lengths) if chunk_lengths else 0

            # endswith() is False for an empty string, so an empty or
            # whitespace-only chunk counts as "not complete" instead of raising.
            complete_sentences = sum(1 for t in texts if t.strip().endswith(sentence_enders))

            results[technique] = {
                'num_chunks': len(chunks),
                'avg_chunk_size': avg_chunk_size,
                'min_chunk_size': min(chunk_lengths) if chunk_lengths else 0,
                'max_chunk_size': max(chunk_lengths) if chunk_lengths else 0,
                'complete_sentences_ratio': complete_sentences / len(chunks) if chunks else 0,
                'chunk_lengths': chunk_lengths
            }

            log(f" β Generated {len(chunks)} chunks, avg size: {avg_chunk_size:.0f} chars")

        except Exception as e:
            results[technique] = {'error': str(e)}
            log(f" β Error: {str(e)}")

    return results
|
retriever/retriever.py
CHANGED
|
@@ -2,21 +2,29 @@ import numpy as np
|
|
| 2 |
from rank_bm25 import BM25Okapi
|
| 3 |
from sentence_transformers import CrossEncoder
|
| 4 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
| 5 |
|
| 6 |
class HybridRetriever:
|
| 7 |
-
def __init__(self, final_chunks, embed_model, rerank_model_name='cross-encoder/ms-marco-MiniLM-L-6-v2'):
|
| 8 |
"""
|
| 9 |
:param final_chunks: The list of chunk dictionaries with metadata.
|
| 10 |
:param embed_model: The SentenceTransformer model used for query and chunk embedding.
|
|
|
|
| 11 |
"""
|
| 12 |
self.final_chunks = final_chunks
|
| 13 |
self.embed_model = embed_model
|
| 14 |
self.rerank_model = CrossEncoder(rerank_model_name)
|
|
|
|
| 15 |
|
| 16 |
# Initialize BM25 corpus
|
| 17 |
self.tokenized_corpus = [chunk['metadata']['text'].lower().split() for chunk in final_chunks]
|
| 18 |
self.bm25 = BM25Okapi(self.tokenized_corpus)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def _rrf_score(self, semantic_results, bm25_results, k=60):
|
| 21 |
"""Reciprocal Rank Fusion (RRF) Implementation."""
|
| 22 |
scores = {}
|
|
@@ -67,42 +75,121 @@ class HybridRetriever:
|
|
| 67 |
|
| 68 |
return [chunk_texts[i] for i in selected_indices]
|
| 69 |
|
| 70 |
-
def search(self, query, index, top_k=10, final_k=3, mode="hybrid", rerank_strategy="cross-encoder"
|
|
|
|
| 71 |
"""
|
| 72 |
:param mode: "semantic", "bm25", or "hybrid"
|
| 73 |
:param rerank_strategy: "cross-encoder", "rrf", "mmr", or "none"
|
|
|
|
| 74 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
semantic_chunks = []
|
| 76 |
bm25_chunks = []
|
| 77 |
query_vector = None
|
| 78 |
|
| 79 |
# 1. Fetch Candidates
|
| 80 |
if mode in ["semantic", "hybrid"]:
|
|
|
|
|
|
|
|
|
|
| 81 |
query_vector = self.embed_model.encode(query)
|
| 82 |
res = index.query(vector=query_vector.tolist(), top_k=top_k, include_metadata=True)
|
| 83 |
semantic_chunks = [match['metadata']['text'] for match in res['matches']]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
if mode in ["bm25", "hybrid"]:
|
|
|
|
|
|
|
|
|
|
| 86 |
tokenized_query = query.lower().split()
|
| 87 |
bm25_scores = self.bm25.get_scores(tokenized_query)
|
| 88 |
top_indices = np.argsort(bm25_scores)[::-1][:top_k]
|
| 89 |
bm25_chunks = [self.final_chunks[i]['metadata']['text'] for i in top_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# 2. Re-Ranking / Fusion
|
| 92 |
if mode == "hybrid" and rerank_strategy == "rrf":
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
# Standard combination for other methods
|
| 96 |
combined = list(dict.fromkeys(semantic_chunks + bm25_chunks)) # Deduplicate keep order
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
if rerank_strategy == "cross-encoder" and combined:
|
|
|
|
| 99 |
pairs = [[query, chunk] for chunk in combined]
|
| 100 |
scores = self.rerank_model.predict(pairs)
|
| 101 |
results = sorted(zip(combined, scores), key=lambda x: x[1], reverse=True)
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
elif rerank_strategy == "mmr" and combined:
|
| 105 |
-
if
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from rank_bm25 import BM25Okapi
|
| 3 |
from sentence_transformers import CrossEncoder
|
| 4 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
+
from typing import Optional
|
| 6 |
|
| 7 |
class HybridRetriever:
|
| 8 |
+
def __init__(self, final_chunks, embed_model, rerank_model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', verbose: bool = True):
    """
    :param final_chunks: The list of chunk dictionaries with metadata.
    :param embed_model: The SentenceTransformer model used for query and chunk embedding.
    :param rerank_model_name: Model identifier passed to CrossEncoder for reranking.
    :param verbose: Whether to print retrieval details and final results.
    """
    self.verbose = verbose
    self.final_chunks = final_chunks
    self.embed_model = embed_model
    self.rerank_model = CrossEncoder(rerank_model_name)

    # BM25 corpus: one lower-cased, whitespace-split token list per chunk text.
    corpus_tokens = []
    for chunk in final_chunks:
        corpus_tokens.append(chunk['metadata']['text'].lower().split())
    self.tokenized_corpus = corpus_tokens
    self.bm25 = BM25Okapi(self.tokenized_corpus)
|
| 22 |
|
| 23 |
+
def _print(self, *args, **kwargs):
    """Forward the arguments to print() only when the instance is verbose."""
    if not self.verbose:
        return
    print(*args, **kwargs)
|
| 27 |
+
|
| 28 |
def _rrf_score(self, semantic_results, bm25_results, k=60):
|
| 29 |
"""Reciprocal Rank Fusion (RRF) Implementation."""
|
| 30 |
scores = {}
|
|
|
|
| 75 |
|
| 76 |
return [chunk_texts[i] for i in selected_indices]
|
| 77 |
|
| 78 |
+
def search(self, query, index, top_k=10, final_k=3, mode="hybrid", rerank_strategy="cross-encoder",
           verbose: Optional[bool] = None):
    """
    Retrieve candidate chunks for ``query`` and return the top ``final_k`` texts.

    :param query: The query string.
    :param index: Vector index; queried with ``index.query(vector=..., top_k=..., include_metadata=True)``.
    :param top_k: Number of candidates fetched from each retrieval source.
    :param final_k: Maximum number of chunk texts returned.
    :param mode: "semantic", "bm25", or "hybrid"
    :param rerank_strategy: "cross-encoder", "rrf", "mmr", or "none". "rrf" is only
        applied in "hybrid" mode; any other combination (or an empty candidate list)
        falls back to returning the first ``final_k`` deduplicated candidates.
    :param verbose: Override the instance's verbose setting (if None, uses instance setting)
    :return: List of chunk texts, at most ``final_k`` long.
    """
    should_print = verbose if verbose is not None else self.verbose

    def log(*args, **kwargs):
        # Print directly instead of via self._print, which re-checks self.verbose
        # and would silently drop an explicit verbose=True override.
        if should_print:
            print(*args, **kwargs)

    def shorten(text, limit):
        return text[:limit] + "..." if len(text) > limit else text

    def show_candidates(chunks):
        for i, chunk in enumerate(chunks[:3]):  # Show first 3
            log(f" [{i}] {shorten(chunk, 100)}")

    def finish(results, lead, label):
        # Report the number actually returned, which can be smaller than final_k.
        log(f"{lead}✅ Final {len(results)} Results{label}:")
        for i, chunk in enumerate(results):
            log(f" [{i+1}] {shorten(chunk, 150)}")
        log("="*80)
        return results

    log("\n" + "="*80)
    log(f"π SEARCH QUERY: {query}")
    log(f"π Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
    log(f"π― Top-K: {top_k} | Final-K: {final_k}")
    log("-" * 80)

    semantic_chunks = []
    bm25_chunks = []
    query_vector = None

    # 1. Fetch Candidates
    if mode in ["semantic", "hybrid"]:
        log(f"π Semantic Search: Retrieving top {top_k} candidates...")

        query_vector = self.embed_model.encode(query)
        res = index.query(vector=query_vector.tolist(), top_k=top_k, include_metadata=True)
        semantic_chunks = [match['metadata']['text'] for match in res['matches']]

        log(f" β Retrieved {len(semantic_chunks)} semantic candidates")
        show_candidates(semantic_chunks)

    if mode in ["bm25", "hybrid"]:
        log(f"π BM25 Search: Retrieving top {top_k} candidates...")

        tokenized_query = query.lower().split()
        bm25_scores = self.bm25.get_scores(tokenized_query)
        # Indices of the top_k highest BM25 scores, best first.
        top_indices = np.argsort(bm25_scores)[::-1][:top_k]
        bm25_chunks = [self.final_chunks[i]['metadata']['text'] for i in top_indices]

        log(f" β Retrieved {len(bm25_chunks)} BM25 candidates")
        show_candidates(bm25_chunks)

    # 2. Re-Ranking / Fusion
    if mode == "hybrid" and rerank_strategy == "rrf":
        log(f"π Applying Reciprocal Rank Fusion (RRF)...")
        results = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
        return finish(results, "", "")

    # Standard combination for other methods
    combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))  # Deduplicate keep order

    log(f"π Combined unique candidates: {len(combined)}")
    log(f"π Applying {rerank_strategy.upper()} reranking...")

    if rerank_strategy == "cross-encoder" and combined:
        pairs = [[query, chunk] for chunk in combined]
        scores = self.rerank_model.predict(pairs)
        ranked = sorted(zip(combined, scores), key=lambda x: x[1], reverse=True)
        results = [item[0] for item in ranked[:final_k]]
        return finish(results, "\n", " (Cross-Encoder Reranked)")

    elif rerank_strategy == "mmr" and combined:
        log(f" Using MMR with λ=0.5 to balance relevance and diversity")

        # In "bm25" mode the query has not been embedded yet.
        if query_vector is None:
            query_vector = self.embed_model.encode(query)
        results = self._maximal_marginal_relevance(query_vector, combined, top_k=final_k)
        return finish(results, "\n", " (MMR Reranked)")

    else:  # "none" or fallback
        results = combined[:final_k]
        return finish(results, "\n", " (No Reranking)")
|