cleaning code -> making phase 1 pipeline ready
Browse files- config.yaml +1 -1
- data_loader.py +1 -1
- main.py +85 -18
- models/deepseek_v3.py +2 -2
- models/llama_3_8b.py +2 -2
- models/mistral_7b.py +2 -11
- models/qwen_2_5.py +1 -1
- models/tiny_aya.py +3 -9
- retriever/generator.py +1 -1
- retriever/processor.py +93 -159
- retriever/retriever.py +126 -154
- vector_db.py +71 -2
config.yaml
CHANGED
|
@@ -4,7 +4,7 @@ project_name: "arxiv_cyber_advisor"
|
|
| 4 |
# Stage 1: Data Acquisition
|
| 5 |
data_ingestion:
|
| 6 |
category: "cs.AI"
|
| 7 |
-
limit:
|
| 8 |
save_local: true
|
| 9 |
raw_data_path: "data/raw_arxiv.csv"
|
| 10 |
|
|
|
|
| 4 |
# Stage 1: Data Acquisition
|
| 5 |
data_ingestion:
|
| 6 |
category: "cs.AI"
|
| 7 |
+
limit: 5
|
| 8 |
save_local: true
|
| 9 |
raw_data_path: "data/raw_arxiv.csv"
|
| 10 |
|
data_loader.py
CHANGED
|
@@ -35,7 +35,7 @@ def fetch_arxiv_data(category="cs.AI", limit=5):
|
|
| 35 |
"id": r.entry_id.split('/')[-1],
|
| 36 |
"title": r.title,
|
| 37 |
"abstract": r.summary.replace('\n', ' '),
|
| 38 |
-
"full_text": full_text, # <---
|
| 39 |
"url": r.pdf_url
|
| 40 |
})
|
| 41 |
return pd.DataFrame(results)
|
|
|
|
| 35 |
"id": r.entry_id.split('/')[-1],
|
| 36 |
"title": r.title,
|
| 37 |
"abstract": r.summary.replace('\n', ' '),
|
| 38 |
+
"full_text": full_text, # <--- Main part of the data
|
| 39 |
"url": r.pdf_url
|
| 40 |
})
|
| 41 |
return pd.DataFrame(results)
|
main.py
CHANGED
|
@@ -1,28 +1,95 @@
|
|
| 1 |
-
import
|
| 2 |
-
from
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def main():
|
| 9 |
-
config = load_config()
|
| 10 |
-
|
| 11 |
-
# Run Stage 1
|
| 12 |
-
raw_data = fetch_arxiv_data(
|
| 13 |
-
category=config['data_ingestion']['category'],
|
| 14 |
-
limit=config['data_ingestion']['limit']
|
| 15 |
-
)
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
raw_data,
|
| 20 |
-
|
| 21 |
-
chunk_size=
|
| 22 |
-
chunk_overlap=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
if __name__ == "__main__":
|
| 28 |
main()
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
|
| 4 |
+
from vector_db import get_pinecone_index, refresh_pinecone_index
|
| 5 |
+
from retriever.retriever import HybridRetriever
|
| 6 |
+
from retriever.generator import RAGGenerator
|
| 7 |
+
from retriever.processor import ChunkProcessor
|
| 8 |
+
import data_loader as dl
|
| 9 |
+
|
| 10 |
+
from models.llama_3_8b import Llama3_8B
|
| 11 |
+
from models.mistral_7b import Mistral_7b
|
| 12 |
+
from models.qwen_2_5 import Qwen2_5
|
| 13 |
+
from models.deepseek_v3 import DeepSeek_V3
|
| 14 |
+
from models.tiny_aya import TinyAya
|
| 15 |
+
|
| 16 |
+
load_dotenv()
|
| 17 |
|
| 18 |
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
# ------------------------------------------------------------------
|
| 21 |
+
# 0. Configuration
|
| 22 |
+
# ------------------------------------------------------------------
|
| 23 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 24 |
+
pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
| 25 |
+
if not pinecone_api_key:
|
| 26 |
+
raise ValueError("PINECONE_API_KEY not found in environment variables")
|
| 27 |
+
|
| 28 |
+
query = "How do transformers handle long sequences?"
|
| 29 |
+
|
| 30 |
+
# ------------------------------------------------------------------
|
| 31 |
+
# 1. Data Ingestion
|
| 32 |
+
# ------------------------------------------------------------------
|
| 33 |
+
raw_data = dl.fetch_arxiv_data(category="cs.AI", limit=5)
|
| 34 |
+
|
| 35 |
+
# ------------------------------------------------------------------
|
| 36 |
+
# 2. Chunking & Embedding
|
| 37 |
+
# ------------------------------------------------------------------
|
| 38 |
+
proc = ChunkProcessor(model_name='all-MiniLM-L6-v2', verbose=True)
|
| 39 |
+
final_chunks = proc.process(
|
| 40 |
raw_data,
|
| 41 |
+
technique="sentence", # options: fixed, recursive, character, sentence, semantic
|
| 42 |
+
chunk_size=500,
|
| 43 |
+
chunk_overlap=50
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# ------------------------------------------------------------------
|
| 47 |
+
# 3. Vector DB
|
| 48 |
+
# ------------------------------------------------------------------
|
| 49 |
+
index_name = "arxiv-index"
|
| 50 |
+
index = get_pinecone_index(pinecone_api_key, index_name, dimension=384, metric="cosine")
|
| 51 |
+
refresh_pinecone_index(index, final_chunks, batch_size=100)
|
| 52 |
+
|
| 53 |
+
# ------------------------------------------------------------------
|
| 54 |
+
# 4. Retrieval
|
| 55 |
+
# ------------------------------------------------------------------
|
| 56 |
+
retriever = HybridRetriever(final_chunks, proc.encoder, verbose=True)
|
| 57 |
+
context_chunks = retriever.search(
|
| 58 |
+
query,
|
| 59 |
+
index,
|
| 60 |
+
mode="hybrid", # options: bm25, semantic, hybrid
|
| 61 |
+
rerank_strategy="cross-encoder", # options: cross-encoder, rrf
|
| 62 |
+
use_mmr=True,
|
| 63 |
+
top_k=10,
|
| 64 |
+
final_k=5
|
| 65 |
)
|
| 66 |
|
| 67 |
+
if not context_chunks:
|
| 68 |
+
print("No context chunks retrieved. Check your index and query.")
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
# ------------------------------------------------------------------
|
| 72 |
+
# 5. Generation
|
| 73 |
+
# ------------------------------------------------------------------
|
| 74 |
+
rag_engine = RAGGenerator()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
models = {
|
| 79 |
+
"Llama-3-8B": Llama3_8B(token=hf_token),
|
| 80 |
+
"Mistral-7B": Mistral_7b(token=hf_token),
|
| 81 |
+
"Qwen-2.5": Qwen2_5(token=hf_token),
|
| 82 |
+
"DeepSeek-V3": DeepSeek_V3(token=hf_token),
|
| 83 |
+
"TinyAya": TinyAya(token=hf_token)
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
for name, model in models.items():
|
| 87 |
+
print(f"\n--- {name} ---")
|
| 88 |
+
try:
|
| 89 |
+
print(rag_engine.get_answer(model, query, context_chunks, temperature=0.1))
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print(f"Error: {e}")
|
| 92 |
+
|
| 93 |
|
| 94 |
if __name__ == "__main__":
|
| 95 |
main()
|
models/deepseek_v3.py
CHANGED
|
@@ -5,7 +5,7 @@ class DeepSeek_V3:
|
|
| 5 |
self.client = InferenceClient(token=token)
|
| 6 |
self.model_id = "deepseek-ai/DeepSeek-V3"
|
| 7 |
|
| 8 |
-
def generate(self, prompt, max_tokens=500, temperature=0.
|
| 9 |
response = ""
|
| 10 |
try:
|
| 11 |
for message in self.client.chat_completion(
|
|
@@ -19,5 +19,5 @@ class DeepSeek_V3:
|
|
| 19 |
content = message.choices[0].delta.content
|
| 20 |
if content: response += content
|
| 21 |
except Exception as e:
|
| 22 |
-
return f"
|
| 23 |
return response
|
|
|
|
| 5 |
self.client = InferenceClient(token=token)
|
| 6 |
self.model_id = "deepseek-ai/DeepSeek-V3"
|
| 7 |
|
| 8 |
+
def generate(self, prompt, max_tokens=500, temperature=0.1):
|
| 9 |
response = ""
|
| 10 |
try:
|
| 11 |
for message in self.client.chat_completion(
|
|
|
|
| 19 |
content = message.choices[0].delta.content
|
| 20 |
if content: response += content
|
| 21 |
except Exception as e:
|
| 22 |
+
return f" DeepSeek API Busy: {e}"
|
| 23 |
return response
|
models/llama_3_8b.py
CHANGED
|
@@ -5,13 +5,13 @@ class Llama3_8B:
|
|
| 5 |
self.client = InferenceClient(token=token)
|
| 6 |
self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
| 7 |
|
| 8 |
-
def generate(self, prompt, max_tokens=500,
|
| 9 |
response = ""
|
| 10 |
for message in self.client.chat_completion(
|
| 11 |
model=self.model_id,
|
| 12 |
messages=[{"role": "user", "content": prompt}],
|
| 13 |
max_tokens=max_tokens,
|
| 14 |
-
temperature=
|
| 15 |
stream=True,
|
| 16 |
):
|
| 17 |
if message.choices:
|
|
|
|
| 5 |
self.client = InferenceClient(token=token)
|
| 6 |
self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
| 7 |
|
| 8 |
+
def generate(self, prompt, max_tokens=500, temperature=0.1):
|
| 9 |
response = ""
|
| 10 |
for message in self.client.chat_completion(
|
| 11 |
model=self.model_id,
|
| 12 |
messages=[{"role": "user", "content": prompt}],
|
| 13 |
max_tokens=max_tokens,
|
| 14 |
+
temperature=temperature,
|
| 15 |
stream=True,
|
| 16 |
):
|
| 17 |
if message.choices:
|
models/mistral_7b.py
CHANGED
|
@@ -1,20 +1,13 @@
|
|
| 1 |
-
import os
|
| 2 |
from huggingface_hub import InferenceClient
|
| 3 |
|
| 4 |
class Mistral_7b:
|
| 5 |
def __init__(self, token):
|
| 6 |
-
# Initializing with api_key as per latest documentation
|
| 7 |
self.client = InferenceClient(api_key=token)
|
| 8 |
-
# Using the specific provider suffix
|
| 9 |
self.model_id = "mistralai/Mistral-7B-Instruct-v0.2:featherless-ai"
|
| 10 |
|
| 11 |
-
def generate(self, prompt, max_tokens=500,
|
| 12 |
-
# Extract temperature, defaulting to 0.2 if not provided
|
| 13 |
-
temperature = kwargs.get('temperature', kwargs.get('temp', 0.2))
|
| 14 |
-
|
| 15 |
response = ""
|
| 16 |
try:
|
| 17 |
-
# Using the new .chat.completions.create syntax for Featherless
|
| 18 |
stream = self.client.chat.completions.create(
|
| 19 |
model=self.model_id,
|
| 20 |
messages=[{"role": "user", "content": prompt}],
|
|
@@ -22,14 +15,12 @@ class Mistral_7b:
|
|
| 22 |
temperature=temperature,
|
| 23 |
stream=True,
|
| 24 |
)
|
| 25 |
-
|
| 26 |
for chunk in stream:
|
| 27 |
-
# Accessing content through the standard completion object structure
|
| 28 |
if chunk.choices and chunk.choices[0].delta.content:
|
| 29 |
content = chunk.choices[0].delta.content
|
| 30 |
response += content
|
| 31 |
|
| 32 |
except Exception as e:
|
| 33 |
-
return f"
|
| 34 |
|
| 35 |
return response
|
|
|
|
|
|
|
| 1 |
from huggingface_hub import InferenceClient
|
| 2 |
|
| 3 |
class Mistral_7b:
|
| 4 |
def __init__(self, token):
|
|
|
|
| 5 |
self.client = InferenceClient(api_key=token)
|
|
|
|
| 6 |
self.model_id = "mistralai/Mistral-7B-Instruct-v0.2:featherless-ai"
|
| 7 |
|
| 8 |
+
def generate(self, prompt, max_tokens=500, temperature=0.1):
|
|
|
|
|
|
|
|
|
|
| 9 |
response = ""
|
| 10 |
try:
|
|
|
|
| 11 |
stream = self.client.chat.completions.create(
|
| 12 |
model=self.model_id,
|
| 13 |
messages=[{"role": "user", "content": prompt}],
|
|
|
|
| 15 |
temperature=temperature,
|
| 16 |
stream=True,
|
| 17 |
)
|
|
|
|
| 18 |
for chunk in stream:
|
|
|
|
| 19 |
if chunk.choices and chunk.choices[0].delta.content:
|
| 20 |
content = chunk.choices[0].delta.content
|
| 21 |
response += content
|
| 22 |
|
| 23 |
except Exception as e:
|
| 24 |
+
return f" Mistral Featherless Error: {e}"
|
| 25 |
|
| 26 |
return response
|
models/qwen_2_5.py
CHANGED
|
@@ -5,7 +5,7 @@ class Qwen2_5:
|
|
| 5 |
self.client = InferenceClient(token=token)
|
| 6 |
self.model_id = "Qwen/Qwen2.5-72B-Instruct"
|
| 7 |
|
| 8 |
-
def generate(self, prompt, max_tokens=500, temperature=0.
|
| 9 |
response = ""
|
| 10 |
for message in self.client.chat_completion(
|
| 11 |
model=self.model_id,
|
|
|
|
| 5 |
self.client = InferenceClient(token=token)
|
| 6 |
self.model_id = "Qwen/Qwen2.5-72B-Instruct"
|
| 7 |
|
| 8 |
+
def generate(self, prompt, max_tokens=500, temperature=0.1):
|
| 9 |
response = ""
|
| 10 |
for message in self.client.chat_completion(
|
| 11 |
model=self.model_id,
|
models/tiny_aya.py
CHANGED
|
@@ -3,16 +3,10 @@ from huggingface_hub import InferenceClient
|
|
| 3 |
class TinyAya:
|
| 4 |
def __init__(self, token):
|
| 5 |
self.client = InferenceClient(token=token)
|
| 6 |
-
# 3.3B parameter model, great for multilingual/efficient RAG
|
| 7 |
self.model_id = "CohereLabs/tiny-aya-global"
|
| 8 |
|
| 9 |
-
def generate(self, prompt, max_tokens=
|
| 10 |
-
|
| 11 |
-
Using **kwargs makes this compatible with calls using 'temp' or 'temperature'.
|
| 12 |
-
"""
|
| 13 |
-
# This line looks for 'temperature', then 'temp', and defaults to 0.3 if neither exist
|
| 14 |
-
temperature = kwargs.get('temperature', kwargs.get('temp', 0.3))
|
| 15 |
-
|
| 16 |
response = ""
|
| 17 |
try:
|
| 18 |
for message in self.client.chat_completion(
|
|
@@ -26,6 +20,6 @@ class TinyAya:
|
|
| 26 |
content = message.choices[0].delta.content
|
| 27 |
if content: response += content
|
| 28 |
except Exception as e:
|
| 29 |
-
return f"
|
| 30 |
|
| 31 |
return response
|
|
|
|
| 3 |
class TinyAya:
|
| 4 |
def __init__(self, token):
|
| 5 |
self.client = InferenceClient(token=token)
|
|
|
|
| 6 |
self.model_id = "CohereLabs/tiny-aya-global"
|
| 7 |
|
| 8 |
+
def generate(self, prompt, max_tokens=500, temperature=0.1):
|
| 9 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
response = ""
|
| 11 |
try:
|
| 12 |
for message in self.client.chat_completion(
|
|
|
|
| 20 |
content = message.choices[0].delta.content
|
| 21 |
if content: response += content
|
| 22 |
except Exception as e:
|
| 23 |
+
return f" TinyAya Error: {e}"
|
| 24 |
|
| 25 |
return response
|
retriever/generator.py
CHANGED
|
@@ -15,5 +15,5 @@ Answer:"""
|
|
| 15 |
|
| 16 |
def get_answer(self, model_instance, query, retrieved_contexts, **kwargs):
|
| 17 |
"""Uses a specific model instance to generate the final answer."""
|
| 18 |
-
prompt = self.generate_prompt(query, retrieved_contexts)
|
| 19 |
return model_instance.generate(prompt, **kwargs)
|
|
|
|
| 15 |
|
| 16 |
def get_answer(self, model_instance, query, retrieved_contexts, **kwargs):
|
| 17 |
"""Uses a specific model instance to generate the final answer."""
|
| 18 |
+
prompt = self.generate_prompt(query, retrieved_contexts)
|
| 19 |
return model_instance.generate(prompt, **kwargs)
|
retriever/processor.py
CHANGED
|
@@ -1,169 +1,133 @@
|
|
| 1 |
from langchain_text_splitters import (
|
| 2 |
RecursiveCharacterTextSplitter,
|
| 3 |
CharacterTextSplitter,
|
| 4 |
-
SentenceTransformersTokenTextSplitter
|
|
|
|
| 5 |
)
|
| 6 |
from langchain_experimental.text_splitter import SemanticChunker
|
| 7 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
|
| 10 |
import pandas as pd
|
| 11 |
|
|
|
|
| 12 |
class ChunkProcessor:
|
| 13 |
def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True):
|
| 14 |
self.model_name = model_name
|
| 15 |
self.encoder = SentenceTransformer(model_name)
|
| 16 |
self.verbose = verbose
|
| 17 |
-
# Required for Semantic Chunking
|
| 18 |
self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
print(*args, **kwargs)
|
| 24 |
|
| 25 |
def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
|
| 26 |
"""
|
| 27 |
Factory method to return different chunking strategies.
|
| 28 |
-
|
| 29 |
Strategies:
|
| 30 |
-
- "fixed":
|
| 31 |
-
- "recursive": Recursive character splitting with hierarchical separators
|
| 32 |
-
- "character": Character-based splitting
|
| 33 |
-
- "sentence":
|
| 34 |
-
- "semantic":
|
| 35 |
-
- "token": Token-based splitting for transformer models
|
| 36 |
"""
|
| 37 |
if technique == "fixed":
|
| 38 |
-
# FIXED: Simple character-based splitter - WILL split mid-sentence
|
| 39 |
return CharacterTextSplitter(
|
| 40 |
-
separator=kwargs.get('separator', ""),
|
| 41 |
-
chunk_size=chunk_size,
|
| 42 |
chunk_overlap=chunk_overlap,
|
| 43 |
length_function=len,
|
| 44 |
is_separator_regex=False
|
| 45 |
)
|
| 46 |
-
|
| 47 |
elif technique == "recursive":
|
| 48 |
-
# FIXED: Proper recursive splitter with default separators that preserve semantics
|
| 49 |
-
separators = kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""])
|
| 50 |
return RecursiveCharacterTextSplitter(
|
| 51 |
-
chunk_size=chunk_size,
|
| 52 |
chunk_overlap=chunk_overlap,
|
| 53 |
-
separators=separators,
|
| 54 |
length_function=len,
|
| 55 |
keep_separator=kwargs.get('keep_separator', True)
|
| 56 |
)
|
| 57 |
-
|
| 58 |
elif technique == "character":
|
| 59 |
-
# FIXED: Character splitter with paragraph separator
|
| 60 |
return CharacterTextSplitter(
|
| 61 |
-
separator=kwargs.get('separator', "\n\n"),
|
| 62 |
-
chunk_size=chunk_size,
|
| 63 |
chunk_overlap=chunk_overlap,
|
| 64 |
length_function=len,
|
| 65 |
is_separator_regex=False
|
| 66 |
)
|
| 67 |
-
|
| 68 |
elif technique == "sentence":
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
return RecursiveCharacterTextSplitter(
|
| 72 |
chunk_size=chunk_size,
|
| 73 |
chunk_overlap=chunk_overlap,
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
elif technique == "semantic":
|
| 80 |
-
# FIXED: Semantic chunker with proper configuration
|
| 81 |
return SemanticChunker(
|
| 82 |
-
self.hf_embeddings,
|
| 83 |
breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
|
| 84 |
-
breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 95)
|
| 85 |
-
min_chunk_size=kwargs.get('min_chunk_size', chunk_size // 10),
|
| 86 |
-
max_chunk_size=kwargs.get('max_chunk_size', chunk_size)
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
elif technique == "token":
|
| 90 |
-
# FIXED: Token-based splitter with proper token counting
|
| 91 |
-
return SentenceTransformersTokenTextSplitter(
|
| 92 |
-
model_name=self.model_name,
|
| 93 |
-
tokens_per_chunk=chunk_size,
|
| 94 |
-
chunk_overlap=chunk_overlap,
|
| 95 |
-
length_function=kwargs.get('length_function', lambda x: len(self.encoder.encode(x)))
|
| 96 |
)
|
| 97 |
-
|
| 98 |
else:
|
| 99 |
-
raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, sentence, semantic
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
-
def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
|
| 102 |
-
chunk_overlap: int = 50, max_docs: Optional[int] = 5,
|
| 103 |
-
**kwargs) -> List[Dict[str, Any]]:
|
| 104 |
"""
|
| 105 |
-
Processes a DataFrame into vector-ready chunks
|
| 106 |
-
|
| 107 |
Args:
|
| 108 |
-
df:
|
| 109 |
-
technique:
|
| 110 |
-
chunk_size:
|
| 111 |
chunk_overlap: Overlap between consecutive chunks
|
| 112 |
-
max_docs:
|
| 113 |
-
verbose:
|
| 114 |
-
**kwargs:
|
| 115 |
-
|
| 116 |
Returns:
|
| 117 |
-
List of
|
| 118 |
"""
|
| 119 |
-
# Determine if we should print
|
| 120 |
should_print = verbose if verbose is not None else self.verbose
|
| 121 |
-
|
| 122 |
-
splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
|
| 123 |
-
processed_chunks = []
|
| 124 |
-
|
| 125 |
-
# Select documents to process
|
| 126 |
-
if max_docs:
|
| 127 |
-
subset_df = df.head(max_docs)
|
| 128 |
-
else:
|
| 129 |
-
subset_df = df
|
| 130 |
-
|
| 131 |
-
# Validate required columns exist
|
| 132 |
required_cols = ['id', 'title', 'url', 'full_text']
|
| 133 |
-
missing_cols = [col for col in required_cols if col not in
|
| 134 |
if missing_cols:
|
| 135 |
raise ValueError(f"DataFrame missing required columns: {missing_cols}")
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
for _, row in subset_df.iterrows():
|
| 138 |
if should_print:
|
| 139 |
-
self.
|
| 140 |
-
|
| 141 |
-
self._print(f"🔗 URL: {row['url']}")
|
| 142 |
-
self._print(f"📏 Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
|
| 143 |
-
self._print("-" * 80)
|
| 144 |
-
|
| 145 |
-
# Split the text
|
| 146 |
raw_chunks = splitter.split_text(row['full_text'])
|
| 147 |
-
|
| 148 |
-
if should_print:
|
| 149 |
-
self._print(f"🎯 Total Chunks Generated: {len(raw_chunks)}")
|
| 150 |
-
|
| 151 |
for i, text in enumerate(raw_chunks):
|
| 152 |
-
# Standardize output (handle both string and Document objects)
|
| 153 |
content = text.page_content if hasattr(text, 'page_content') else text
|
| 154 |
-
|
| 155 |
if should_print:
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
preview = content[:200] + "..." if len(content) > 200 else content
|
| 159 |
-
self._print(f" {preview}")
|
| 160 |
-
|
| 161 |
-
# Generate embedding
|
| 162 |
-
embedding = self.encoder.encode(content).tolist()
|
| 163 |
-
|
| 164 |
processed_chunks.append({
|
| 165 |
"id": f"{row['id']}-chunk-{i}",
|
| 166 |
-
"values":
|
| 167 |
"metadata": {
|
| 168 |
"title": row['title'],
|
| 169 |
"text": content,
|
|
@@ -174,67 +138,37 @@ class ChunkProcessor:
|
|
| 174 |
"total_chunks": len(raw_chunks)
|
| 175 |
}
|
| 176 |
})
|
| 177 |
-
|
| 178 |
if should_print:
|
| 179 |
-
self.
|
| 180 |
-
|
| 181 |
if should_print:
|
| 182 |
-
self.
|
| 183 |
-
|
| 184 |
-
self._print(f"📊 Average chunk size: {sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks):.0f} chars")
|
| 185 |
-
|
| 186 |
return processed_chunks
|
| 187 |
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
""
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
# Analyze chunks
|
| 215 |
-
chunk_lengths = [len(c.page_content if hasattr(c, 'page_content') else c) for c in chunks]
|
| 216 |
-
avg_chunk_size = sum(chunk_lengths) / len(chunk_lengths) if chunk_lengths else 0
|
| 217 |
-
|
| 218 |
-
# Count how many chunks end with sentence boundaries
|
| 219 |
-
sentence_enders = ['.', '!', '?', '"', "'"]
|
| 220 |
-
complete_sentences = sum(1 for c in chunks
|
| 221 |
-
if (c.page_content if hasattr(c, 'page_content') else c).strip()[-1] in sentence_enders)
|
| 222 |
-
|
| 223 |
-
results[technique] = {
|
| 224 |
-
'num_chunks': len(chunks),
|
| 225 |
-
'avg_chunk_size': avg_chunk_size,
|
| 226 |
-
'min_chunk_size': min(chunk_lengths) if chunk_lengths else 0,
|
| 227 |
-
'max_chunk_size': max(chunk_lengths) if chunk_lengths else 0,
|
| 228 |
-
'complete_sentences_ratio': complete_sentences / len(chunks) if chunks else 0,
|
| 229 |
-
'chunk_lengths': chunk_lengths
|
| 230 |
-
}
|
| 231 |
-
|
| 232 |
-
if should_print:
|
| 233 |
-
self._print(f" ✓ Generated {len(chunks)} chunks, avg size: {avg_chunk_size:.0f} chars")
|
| 234 |
-
|
| 235 |
-
except Exception as e:
|
| 236 |
-
results[technique] = {'error': str(e)}
|
| 237 |
-
if should_print:
|
| 238 |
-
self._print(f" ✗ Error: {str(e)}")
|
| 239 |
-
|
| 240 |
-
return results
|
|
|
|
| 1 |
from langchain_text_splitters import (
|
| 2 |
RecursiveCharacterTextSplitter,
|
| 3 |
CharacterTextSplitter,
|
| 4 |
+
SentenceTransformersTokenTextSplitter,
|
| 5 |
+
NLTKTextSplitter
|
| 6 |
)
|
| 7 |
from langchain_experimental.text_splitter import SemanticChunker
|
| 8 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 9 |
from sentence_transformers import SentenceTransformer
|
| 10 |
from typing import List, Dict, Any, Optional
|
| 11 |
+
import nltk
|
| 12 |
+
nltk.download('punkt_tab', quiet=True)
|
| 13 |
import pandas as pd
|
| 14 |
|
| 15 |
+
|
| 16 |
class ChunkProcessor:
|
| 17 |
def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True):
|
| 18 |
self.model_name = model_name
|
| 19 |
self.encoder = SentenceTransformer(model_name)
|
| 20 |
self.verbose = verbose
|
|
|
|
| 21 |
self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
| 22 |
|
| 23 |
+
# ------------------------------------------------------------------
|
| 24 |
+
# Splitters
|
| 25 |
+
# ------------------------------------------------------------------
|
|
|
|
| 26 |
|
| 27 |
def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
|
| 28 |
"""
|
| 29 |
Factory method to return different chunking strategies.
|
| 30 |
+
|
| 31 |
Strategies:
|
| 32 |
+
- "fixed": Character-based, may split mid-sentence
|
| 33 |
+
- "recursive": Recursive character splitting with hierarchical separators
|
| 34 |
+
- "character": Character-based splitting on paragraph boundaries
|
| 35 |
+
- "sentence": Sliding window over NLTK sentences
|
| 36 |
+
- "semantic": Embedding-based semantic chunking
|
|
|
|
| 37 |
"""
|
| 38 |
if technique == "fixed":
|
|
|
|
| 39 |
return CharacterTextSplitter(
|
| 40 |
+
separator=kwargs.get('separator', ""),
|
| 41 |
+
chunk_size=chunk_size,
|
| 42 |
chunk_overlap=chunk_overlap,
|
| 43 |
length_function=len,
|
| 44 |
is_separator_regex=False
|
| 45 |
)
|
| 46 |
+
|
| 47 |
elif technique == "recursive":
|
|
|
|
|
|
|
| 48 |
return RecursiveCharacterTextSplitter(
|
| 49 |
+
chunk_size=chunk_size,
|
| 50 |
chunk_overlap=chunk_overlap,
|
| 51 |
+
separators=kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""]),
|
| 52 |
length_function=len,
|
| 53 |
keep_separator=kwargs.get('keep_separator', True)
|
| 54 |
)
|
| 55 |
+
|
| 56 |
elif technique == "character":
|
|
|
|
| 57 |
return CharacterTextSplitter(
|
| 58 |
+
separator=kwargs.get('separator', "\n\n"),
|
| 59 |
+
chunk_size=chunk_size,
|
| 60 |
chunk_overlap=chunk_overlap,
|
| 61 |
length_function=len,
|
| 62 |
is_separator_regex=False
|
| 63 |
)
|
| 64 |
+
|
| 65 |
elif technique == "sentence":
|
| 66 |
+
# sentence-level chunking using NLTK
|
| 67 |
+
return NLTKTextSplitter(
|
|
|
|
| 68 |
chunk_size=chunk_size,
|
| 69 |
chunk_overlap=chunk_overlap,
|
| 70 |
+
separator="\n"
|
| 71 |
+
)
|
| 72 |
+
|
|
|
|
|
|
|
| 73 |
elif technique == "semantic":
|
|
|
|
| 74 |
return SemanticChunker(
|
| 75 |
+
self.hf_embeddings,
|
| 76 |
breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
|
| 77 |
+
breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 95)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
)
|
| 79 |
+
|
| 80 |
else:
|
| 81 |
+
raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, sentence, semantic")
|
| 82 |
+
|
| 83 |
+
# ------------------------------------------------------------------
|
| 84 |
+
# Processing
|
| 85 |
+
# ------------------------------------------------------------------
|
| 86 |
|
| 87 |
+
def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
|
| 88 |
+
chunk_overlap: int = 50, max_docs: Optional[int] = 5,
|
| 89 |
+
verbose: Optional[bool] = None, **kwargs) -> List[Dict[str, Any]]:
|
| 90 |
"""
|
| 91 |
+
Processes a DataFrame into vector-ready chunks.
|
| 92 |
+
|
| 93 |
Args:
|
| 94 |
+
df: DataFrame with columns: id, title, url, full_text
|
| 95 |
+
technique: Chunking strategy to use
|
| 96 |
+
chunk_size: Maximum size of each chunk in characters
|
| 97 |
chunk_overlap: Overlap between consecutive chunks
|
| 98 |
+
max_docs: Number of documents to process (None for all)
|
| 99 |
+
verbose: Override instance verbose setting
|
| 100 |
+
**kwargs: Additional arguments passed to the splitter
|
| 101 |
+
|
| 102 |
Returns:
|
| 103 |
+
List of chunk dicts with embeddings and metadata
|
| 104 |
"""
|
|
|
|
| 105 |
should_print = verbose if verbose is not None else self.verbose
|
| 106 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
required_cols = ['id', 'title', 'url', 'full_text']
|
| 108 |
+
missing_cols = [col for col in required_cols if col not in df.columns]
|
| 109 |
if missing_cols:
|
| 110 |
raise ValueError(f"DataFrame missing required columns: {missing_cols}")
|
| 111 |
+
|
| 112 |
+
splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
|
| 113 |
+
subset_df = df.head(max_docs) if max_docs else df
|
| 114 |
+
processed_chunks = []
|
| 115 |
+
|
| 116 |
for _, row in subset_df.iterrows():
|
| 117 |
if should_print:
|
| 118 |
+
self._print_document_header(row['title'], row['url'], technique, chunk_size, chunk_overlap)
|
| 119 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
raw_chunks = splitter.split_text(row['full_text'])
|
| 121 |
+
|
|
|
|
|
|
|
|
|
|
| 122 |
for i, text in enumerate(raw_chunks):
|
|
|
|
| 123 |
content = text.page_content if hasattr(text, 'page_content') else text
|
| 124 |
+
|
| 125 |
if should_print:
|
| 126 |
+
self._print_chunk(i, content)
|
| 127 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
processed_chunks.append({
|
| 129 |
"id": f"{row['id']}-chunk-{i}",
|
| 130 |
+
"values": self.encoder.encode(content).tolist(),
|
| 131 |
"metadata": {
|
| 132 |
"title": row['title'],
|
| 133 |
"text": content,
|
|
|
|
| 138 |
"total_chunks": len(raw_chunks)
|
| 139 |
}
|
| 140 |
})
|
| 141 |
+
|
| 142 |
if should_print:
|
| 143 |
+
self._print_document_summary(len(raw_chunks))
|
| 144 |
+
|
| 145 |
if should_print:
|
| 146 |
+
self._print_processing_summary(len(subset_df), processed_chunks)
|
| 147 |
+
|
|
|
|
|
|
|
| 148 |
return processed_chunks
|
| 149 |
|
| 150 |
+
|
| 151 |
+
# ------------------------------------------------------------------
|
| 152 |
+
# Printing
|
| 153 |
+
# ------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
def _print_document_header(self, title, url, technique, chunk_size, chunk_overlap):
|
| 156 |
+
print("\n" + "="*80)
|
| 157 |
+
print(f"DOCUMENT: {title}")
|
| 158 |
+
print(f"URL: {url}")
|
| 159 |
+
print(f"Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
|
| 160 |
+
print("-" * 80)
|
| 161 |
+
|
| 162 |
+
def _print_chunk(self, index, content):
|
| 163 |
+
print(f"\n[Chunk {index}] ({len(content)} chars):")
|
| 164 |
+
print(f" {content}")
|
| 165 |
+
|
| 166 |
+
def _print_document_summary(self, num_chunks):
|
| 167 |
+
print(f"Total Chunks Generated: {num_chunks}")
|
| 168 |
+
print("="*80)
|
| 169 |
+
|
| 170 |
+
def _print_processing_summary(self, num_docs, processed_chunks):
|
| 171 |
+
print(f"\nFinished processing {num_docs} documents into {len(processed_chunks)} chunks.")
|
| 172 |
+
if processed_chunks:
|
| 173 |
+
avg = sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks)
|
| 174 |
+
print(f"Average chunk size: {avg:.0f} chars")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever/retriever.py
CHANGED
|
@@ -2,194 +2,166 @@ import numpy as np
|
|
| 2 |
from rank_bm25 import BM25Okapi
|
| 3 |
from sentence_transformers import CrossEncoder
|
| 4 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
-
from typing import Optional
|
| 6 |
|
| 7 |
class HybridRetriever:
|
| 8 |
def __init__(self, final_chunks, embed_model, rerank_model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', verbose: bool = True):
|
| 9 |
-
"""
|
| 10 |
-
:param final_chunks: The list of chunk dictionaries with metadata.
|
| 11 |
-
:param embed_model: The SentenceTransformer model used for query and chunk embedding.
|
| 12 |
-
:param verbose: Whether to print retrieval details and final results.
|
| 13 |
-
"""
|
| 14 |
self.final_chunks = final_chunks
|
| 15 |
self.embed_model = embed_model
|
| 16 |
self.rerank_model = CrossEncoder(rerank_model_name)
|
| 17 |
self.verbose = verbose
|
| 18 |
|
| 19 |
-
# Initialize BM25 corpus
|
| 20 |
self.tokenized_corpus = [chunk['metadata']['text'].lower().split() for chunk in final_chunks]
|
| 21 |
self.bm25 = BM25Okapi(self.tokenized_corpus)
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
def _rrf_score(self, semantic_results, bm25_results, k=60):
|
| 29 |
-
"""Reciprocal Rank Fusion (RRF) Implementation."""
|
| 30 |
scores = {}
|
| 31 |
for rank, chunk in enumerate(semantic_results):
|
| 32 |
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
|
| 33 |
for rank, chunk in enumerate(bm25_results):
|
| 34 |
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
|
| 35 |
-
|
| 36 |
-
sorted_chunks = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
| 37 |
-
return [item[0] for item in sorted_chunks]
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
|
|
|
| 49 |
relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]
|
| 50 |
|
| 51 |
-
|
| 52 |
-
unselected_indices = list(range(len(chunk_texts)))
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
unselected_indices.remove(idx)
|
| 58 |
|
| 59 |
-
while len(
|
| 60 |
-
mmr_scores = [
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def search(self, query, index, top_k=10, final_k=3, mode="hybrid",
|
| 79 |
-
|
|
|
|
| 80 |
"""
|
| 81 |
-
:param mode:
|
| 82 |
-
:param rerank_strategy:
|
| 83 |
-
:param
|
|
|
|
| 84 |
"""
|
| 85 |
-
# Determine if we should print
|
| 86 |
should_print = verbose if verbose is not None else self.verbose
|
| 87 |
|
| 88 |
if should_print:
|
| 89 |
-
self.
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
self._print(f"🎯 Top-K: {top_k} | Final-K: {final_k}")
|
| 93 |
-
self._print("-" * 80)
|
| 94 |
-
|
| 95 |
-
semantic_chunks = []
|
| 96 |
-
bm25_chunks = []
|
| 97 |
query_vector = None
|
|
|
|
| 98 |
|
| 99 |
-
# 1. Fetch Candidates
|
| 100 |
if mode in ["semantic", "hybrid"]:
|
|
|
|
| 101 |
if should_print:
|
| 102 |
-
self.
|
| 103 |
-
|
| 104 |
-
query_vector = self.embed_model.encode(query)
|
| 105 |
-
res = index.query(vector=query_vector.tolist(), top_k=top_k, include_metadata=True)
|
| 106 |
-
semantic_chunks = [match['metadata']['text'] for match in res['matches']]
|
| 107 |
-
|
| 108 |
-
if should_print:
|
| 109 |
-
self._print(f" ✓ Retrieved {len(semantic_chunks)} semantic candidates")
|
| 110 |
-
for i, chunk in enumerate(semantic_chunks[:3]): # Show first 3
|
| 111 |
-
preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
|
| 112 |
-
self._print(f" [{i}] {preview}")
|
| 113 |
|
| 114 |
if mode in ["bm25", "hybrid"]:
|
|
|
|
| 115 |
if should_print:
|
| 116 |
-
self.
|
| 117 |
-
|
| 118 |
-
tokenized_query = query.lower().split()
|
| 119 |
-
bm25_scores = self.bm25.get_scores(tokenized_query)
|
| 120 |
-
top_indices = np.argsort(bm25_scores)[::-1][:top_k]
|
| 121 |
-
bm25_chunks = [self.final_chunks[i]['metadata']['text'] for i in top_indices]
|
| 122 |
-
|
| 123 |
-
if should_print:
|
| 124 |
-
self._print(f" ✓ Retrieved {len(bm25_chunks)} BM25 candidates")
|
| 125 |
-
for i, chunk in enumerate(bm25_chunks[:3]): # Show first 3
|
| 126 |
-
preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
|
| 127 |
-
self._print(f" [{i}] {preview}")
|
| 128 |
|
| 129 |
-
# 2.
|
| 130 |
-
if
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
# Standard combination for other methods
|
| 146 |
-
combined = list(dict.fromkeys(semantic_chunks + bm25_chunks)) # Deduplicate keep order
|
| 147 |
-
|
| 148 |
-
if should_print:
|
| 149 |
-
self._print(f"🔄 Combined unique candidates: {len(combined)}")
|
| 150 |
-
self._print(f"🔄 Applying {rerank_strategy.upper()} reranking...")
|
| 151 |
-
|
| 152 |
-
if rerank_strategy == "cross-encoder" and combined:
|
| 153 |
-
|
| 154 |
-
pairs = [[query, chunk] for chunk in combined]
|
| 155 |
-
scores = self.rerank_model.predict(pairs)
|
| 156 |
-
results = sorted(zip(combined, scores), key=lambda x: x[1], reverse=True)
|
| 157 |
-
results = [res[0] for res in results[:final_k]]
|
| 158 |
-
|
| 159 |
-
if should_print:
|
| 160 |
-
self._print(f"\n✅ Final {final_k} Results (Cross-Encoder Reranked):")
|
| 161 |
-
for i, chunk in enumerate(results):
|
| 162 |
-
preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
|
| 163 |
-
self._print(f" [{i+1}] {preview}")
|
| 164 |
-
self._print("="*80)
|
| 165 |
-
|
| 166 |
-
return results
|
| 167 |
-
|
| 168 |
-
elif rerank_strategy == "mmr" and combined:
|
| 169 |
-
if should_print:
|
| 170 |
-
self._print(f" Using MMR with λ=0.5 to balance relevance and diversity")
|
| 171 |
-
|
| 172 |
-
if query_vector is None:
|
| 173 |
query_vector = self.embed_model.encode(query)
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from rank_bm25 import BM25Okapi
|
| 3 |
from sentence_transformers import CrossEncoder
|
| 4 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
+
from typing import Optional, List
|
| 6 |
|
| 7 |
class HybridRetriever:
|
| 8 |
    def __init__(self, final_chunks, embed_model, rerank_model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', verbose: bool = True):
        """Hybrid (semantic + BM25) retriever with cross-encoder reranking.

        :param final_chunks: List of chunk dicts; each must provide metadata['text'].
        :param embed_model: SentenceTransformer-style model used for query/chunk embedding.
        :param rerank_model_name: Model name passed to CrossEncoder for reranking.
        :param verbose: Default for printing retrieval details (overridable per search() call).
        """
        self.final_chunks = final_chunks
        self.embed_model = embed_model
        # Loads the cross-encoder once up front so search() calls don't pay for it.
        self.rerank_model = CrossEncoder(rerank_model_name)
        self.verbose = verbose

        # Build the BM25 corpus once: lowercased, whitespace-tokenized chunk texts.
        self.tokenized_corpus = [chunk['metadata']['text'].lower().split() for chunk in final_chunks]
        self.bm25 = BM25Okapi(self.tokenized_corpus)
| 16 |
|
| 17 |
+
# ------------------------------------------------------------------
|
| 18 |
+
# Retrieval
|
| 19 |
+
# ------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
def _semantic_search(self, query, index, top_k) -> tuple[np.ndarray, List[str]]:
|
| 22 |
+
query_vector = self.embed_model.encode(query)
|
| 23 |
+
res = index.query(vector=query_vector.tolist(), top_k=top_k, include_metadata=True)
|
| 24 |
+
chunks = [match['metadata']['text'] for match in res['matches']]
|
| 25 |
+
return query_vector, chunks
|
| 26 |
+
|
| 27 |
+
def _bm25_search(self, query, top_k) -> List[str]:
|
| 28 |
+
tokenized_query = query.lower().split()
|
| 29 |
+
scores = self.bm25.get_scores(tokenized_query)
|
| 30 |
+
top_indices = np.argsort(scores)[::-1][:top_k]
|
| 31 |
+
return [self.final_chunks[i]['metadata']['text'] for i in top_indices]
|
| 32 |
+
|
| 33 |
+
# ------------------------------------------------------------------
|
| 34 |
+
# Fusion
|
| 35 |
+
# ------------------------------------------------------------------
|
| 36 |
|
| 37 |
+
def _rrf_score(self, semantic_results, bm25_results, k=60) -> List[str]:
|
|
|
|
| 38 |
scores = {}
|
| 39 |
for rank, chunk in enumerate(semantic_results):
|
| 40 |
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
|
| 41 |
for rank, chunk in enumerate(bm25_results):
|
| 42 |
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
|
| 43 |
+
return [chunk for chunk, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
# ------------------------------------------------------------------
|
| 46 |
+
# Reranking
|
| 47 |
+
# ------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
def _cross_encoder_rerank(self, query, chunks, final_k) -> List[str]:
|
| 50 |
+
pairs = [[query, chunk] for chunk in chunks]
|
| 51 |
+
scores = self.rerank_model.predict(pairs)
|
| 52 |
+
ranked = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
| 53 |
+
return [chunk for chunk, _ in ranked[:final_k]]
|
| 54 |
+
|
| 55 |
+
# ------------------------------------------------------------------
|
| 56 |
+
# MMR (applied after reranking as a diversity filter)
|
| 57 |
+
# ------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
    def _maximal_marginal_relevance(self, query_vector, chunks, lambda_param=0.5, top_k=3) -> List[str]:
        """Greedy MMR selection: pick up to top_k chunks balancing query relevance
        against redundancy with already-selected chunks.

        :param query_vector: Embedding of the query (1-D; reshaped to a row below).
        :param chunks: Candidate chunk texts to diversify.
        :param lambda_param: Trade-off weight — 1.0 is pure relevance, 0.0 pure diversity.
        :param top_k: Maximum number of chunks to return.
        :return: Selected chunk texts, in the order they were picked.
        """
        if not chunks:
            return []

        # Re-embed the candidate texts so relevance/redundancy use the same space.
        chunk_embeddings = self.embed_model.encode(chunks)
        query_embedding = query_vector.reshape(1, -1)
        relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]

        selected, unselected = [], list(range(len(chunks)))

        # Seed with the single most query-relevant chunk.
        first = int(np.argmax(relevance_scores))
        selected.append(first)
        unselected.remove(first)

        # Greedily add the candidate maximizing: λ·relevance − (1−λ)·max similarity
        # to anything already selected.
        while len(selected) < min(top_k, len(chunks)):
            mmr_scores = [
                (i, lambda_param * relevance_scores[i] - (1 - lambda_param) * max(
                    cosine_similarity(chunk_embeddings[i].reshape(1, -1),
                                      chunk_embeddings[s].reshape(1, -1))[0][0]
                    for s in selected
                ))
                for i in unselected
            ]
            best = max(mmr_scores, key=lambda x: x[1])[0]
            selected.append(best)
            unselected.remove(best)

        return [chunks[i] for i in selected]
|
| 87 |
+
|
| 88 |
+
# ------------------------------------------------------------------
|
| 89 |
+
# Main search
|
| 90 |
+
# ------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def search(self, query, index, top_k=10, final_k=3, mode="hybrid",
|
| 93 |
+
rerank_strategy="cross-encoder", use_mmr=True, lambda_param=0.5,
|
| 94 |
+
verbose: Optional[bool] = None) -> List[str]:
|
| 95 |
"""
|
| 96 |
+
:param mode: "semantic", "bm25", or "hybrid"
|
| 97 |
+
:param rerank_strategy: "cross-encoder", "rrf", or "none"
|
| 98 |
+
:param use_mmr: Whether to apply MMR diversity filter after reranking
|
| 99 |
+
:param lambda_param: MMR trade-off between relevance (1.0) and diversity (0.0)
|
| 100 |
"""
|
|
|
|
| 101 |
should_print = verbose if verbose is not None else self.verbose
|
| 102 |
|
| 103 |
if should_print:
|
| 104 |
+
self._print_search_header(query, mode, rerank_strategy, top_k, final_k)
|
| 105 |
+
|
| 106 |
+
# 1. Retrieve candidates
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
query_vector = None
|
| 108 |
+
semantic_chunks, bm25_chunks = [], []
|
| 109 |
|
|
|
|
| 110 |
if mode in ["semantic", "hybrid"]:
|
| 111 |
+
query_vector, semantic_chunks = self._semantic_search(query, index, top_k)
|
| 112 |
if should_print:
|
| 113 |
+
self._print_candidates("Semantic Search", semantic_chunks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
if mode in ["bm25", "hybrid"]:
|
| 116 |
+
bm25_chunks = self._bm25_search(query, top_k)
|
| 117 |
if should_print:
|
| 118 |
+
self._print_candidates("BM25 Search", bm25_chunks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
+
# 2. Fuse / rerank
|
| 121 |
+
if rerank_strategy == "rrf":
|
| 122 |
+
candidates = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
|
| 123 |
+
label = "RRF"
|
| 124 |
+
elif rerank_strategy == "cross-encoder":
|
| 125 |
+
combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))
|
| 126 |
+
candidates = self._cross_encoder_rerank(query, combined, final_k)
|
| 127 |
+
label = "Cross-Encoder"
|
| 128 |
+
else: # "none"
|
| 129 |
+
candidates = list(dict.fromkeys(semantic_chunks + bm25_chunks))[:final_k]
|
| 130 |
+
label = "No Reranking"
|
| 131 |
+
|
| 132 |
+
# 3. MMR diversity filter (applied after reranking)
|
| 133 |
+
if use_mmr and candidates:
|
| 134 |
+
if query_vector is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
query_vector = self.embed_model.encode(query)
|
| 136 |
+
candidates = self._maximal_marginal_relevance(query_vector, candidates,
|
| 137 |
+
lambda_param=lambda_param, top_k=3)
|
| 138 |
+
label += " + MMR"
|
| 139 |
+
|
| 140 |
+
if should_print:
|
| 141 |
+
self._print_final_results(candidates, label)
|
| 142 |
+
|
| 143 |
+
return candidates
|
| 144 |
+
|
| 145 |
+
# ------------------------------------------------------------------
|
| 146 |
+
# Printing
|
| 147 |
+
# ------------------------------------------------------------------
|
| 148 |
+
|
| 149 |
+
def _print_search_header(self, query, mode, rerank_strategy, top_k, final_k):
|
| 150 |
+
print("\n" + "="*80)
|
| 151 |
+
print(f" SEARCH QUERY: {query}")
|
| 152 |
+
print(f"Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
|
| 153 |
+
print(f"Top-K: {top_k} | Final-K: {final_k}")
|
| 154 |
+
print("-" * 80)
|
| 155 |
+
|
| 156 |
+
def _print_candidates(self, label, chunks, preview_n=3):
|
| 157 |
+
print(f"{label}: Retrieved {len(chunks)} candidates")
|
| 158 |
+
for i, chunk in enumerate(chunks[:preview_n]):
|
| 159 |
+
preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
|
| 160 |
+
print(f" [{i}] {preview}")
|
| 161 |
+
|
| 162 |
+
def _print_final_results(self, results, strategy_label):
|
| 163 |
+
print(f"\n Final {len(results)} Results ({strategy_label}):")
|
| 164 |
+
for i, chunk in enumerate(results):
|
| 165 |
+
preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
|
| 166 |
+
print(f" [{i+1}] {preview}")
|
| 167 |
+
print("="*80)
|
vector_db.py
CHANGED
|
@@ -22,12 +22,81 @@ def get_pinecone_index(api_key, index_name, dimension=384, metric="cosine"):
|
|
| 22 |
|
| 23 |
return pc.Index(index_name)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def upsert_to_pinecone(index, chunks, batch_size=100):
|
| 26 |
-
"""Upserts chunks to Pinecone in manageable batches.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
print(f"Uploading {len(chunks)} chunks to Pinecone...")
|
| 28 |
|
| 29 |
for i in range(0, len(chunks), batch_size):
|
| 30 |
batch = chunks[i : i + batch_size]
|
| 31 |
index.upsert(vectors=batch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
return pc.Index(index_name)
|
| 24 |
|
| 25 |
+
def prepare_vectors_for_upsert(final_chunks):
    """Convert final_chunks into the vector dicts Pinecone's upsert expects.

    Only the known metadata keys are copied; anything else in a chunk's
    metadata is dropped.
    """
    meta_keys = ('text', 'title', 'url', 'chunk_index',
                 'technique', 'chunk_size', 'total_chunks')
    return [
        {
            'id': chunk['id'],
            'values': chunk['values'],  # the embedding vector
            'metadata': {key: chunk['metadata'][key] for key in meta_keys},
        }
        for chunk in final_chunks
    ]
|
| 43 |
+
|
| 44 |
def upsert_to_pinecone(index, chunks, batch_size=100):
    """Upserts chunks to Pinecone in manageable batches.

    Args:
        index: Pinecone index object
        chunks: List of chunk dictionaries (as returned by prepare_vectors_for_upsert)
        batch_size: Number of vectors to upsert in each batch
    """
    total = len(chunks)
    print(f"Uploading {total} chunks to Pinecone...")

    num_batches = (total - 1) // batch_size + 1
    for batch_num, start in enumerate(range(0, total, batch_size), start=1):
        batch = chunks[start : start + batch_size]
        index.upsert(vectors=batch)
        print(f" Uploaded batch {batch_num}/{num_batches} ({len(batch)} vectors)")

    print(" Upsert complete.")
|
| 60 |
+
|
| 61 |
+
def refresh_pinecone_index(index, final_chunks, batch_size=100):
    """Helper function to refresh index with new chunks.

    Compares the index's current vector count against the number of chunks
    and upserts everything only when the index is empty.

    Args:
        index: Pinecone index object
        final_chunks: List of chunk dictionaries with embeddings
        batch_size: Batch size for upsert

    Returns:
        Boolean indicating if upsert was performed
    """
    try:
        stats = index.describe_index_stats()
        current = stats.get('total_vector_count', 0)
        expected = len(final_chunks)

        print(f"\n Current vectors in index: {current}")
        print(f" Expected vectors: {expected}")

        if current == 0:
            # Empty index: push every chunk up in batches.
            print(" Index is empty. Preparing vectors for upsert...")
            upsert_to_pinecone(index, prepare_vectors_for_upsert(final_chunks), batch_size)

            # Re-read stats to confirm the upsert landed.
            stats = index.describe_index_stats()
            print(f" After upsert - Total vectors: {stats.get('total_vector_count', 0)}")
            return True

        if current != expected:
            print(f" Vector count mismatch. Expected {expected}, got {current}")
            print(" Consider recreating the index if you want to refresh.")
            return False

        print(f"ℹ Index already has {current} vectors. Ready for search.")
        return False
    except Exception as e:
        # Best-effort helper: any API failure is reported, not raised.
        print(f"Error checking index stats: {e}")
        return False
|