data to generation
Browse files- data_loader.py +41 -0
- models/deepseek_v3.py +23 -0
- models/llama_3_8b.py +20 -0
- models/mistral_7b.py +35 -0
- models/qwen_2_5.py +20 -0
- models/tiny_aya.py +31 -0
- retriever/generator.py +19 -0
- retriever/processor.py +111 -0
- retriever/retriever.py +108 -0
data_loader.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fitz # PyMuPDF
|
| 2 |
+
import requests
|
| 3 |
+
import io
|
| 4 |
+
import arxiv
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
def extract_text_from_url(pdf_url):
    """Download a PDF from *pdf_url* and return its text as one line.

    Newlines are replaced with spaces so downstream chunkers see a single
    continuous string. On any failure (network error, bad HTTP status,
    corrupt PDF) the error is printed and "" is returned, so callers can
    keep processing the remaining documents.
    """
    try:
        # Timeout stops one stalled download from hanging the whole
        # ingestion loop; raise_for_status stops an HTML error page from
        # being parsed as if it were a PDF.
        response = requests.get(pdf_url, timeout=30)
        response.raise_for_status()
        # Open the PDF directly from the in-memory byte stream.
        with fitz.open(stream=io.BytesIO(response.content), filetype="pdf") as doc:
            text = "".join(page.get_text() for page in doc)
        return text.replace('\n', ' ')
    except Exception as e:
        # Best-effort by design: log and fall through to an empty result.
        print(f"Error downloading {pdf_url}: {e}")
        return ""
|
| 20 |
+
|
| 21 |
+
def fetch_arxiv_data(category="cs.AI", limit=5):
    """Fetch the newest *limit* papers in arXiv *category* as a DataFrame.

    Each row carries the arXiv id, title, single-line abstract, the full
    PDF text and the PDF url.
    """
    search = arxiv.Search(
        query=f"cat:{category}",
        max_results=limit,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    client = arxiv.Client()

    records = []
    for paper in client.results(search):
        print(f"Downloading full text for: {paper.title[:50]}...")
        records.append({
            "id": paper.entry_id.split('/')[-1],
            "title": paper.title,
            "abstract": paper.summary.replace('\n', ' '),
            "full_text": extract_text_from_url(paper.pdf_url),
            "url": paper.pdf_url,
        })
    return pd.DataFrame(records)
|
models/deepseek_v3.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class DeepSeek_V3:
    """Streaming chat wrapper around the hosted DeepSeek-V3 model."""

    def __init__(self, token):
        # Authenticated Hugging Face inference client.
        self.client = InferenceClient(token=token)
        self.model_id = "deepseek-ai/DeepSeek-V3"

    def generate(self, prompt, max_tokens=500, temperature=0.15):
        """Return the model's answer to *prompt*, or an error string on failure."""
        pieces = []
        try:
            stream = self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )
            for event in stream:
                if not event.choices:
                    continue
                delta = event.choices[0].delta.content
                if delta:
                    pieces.append(delta)
        except Exception as e:
            return f"⚠️ DeepSeek API Busy: {e}"
        return "".join(pieces)
|
models/llama_3_8b.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class Llama3_8B:
    """Streaming chat wrapper around Meta-Llama-3-8B-Instruct on the HF Inference API."""

    def __init__(self, token):
        self.client = InferenceClient(token=token)
        self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    def generate(self, prompt, max_tokens=500, temp=0.1):
        """Return the model's answer to *prompt*.

        Fix: API/network failures are caught and reported as an error
        string — matching the other model wrappers in this package —
        instead of propagating and aborting a batch run.
        """
        response = ""
        try:
            for message in self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temp,
                stream=True,
            ):
                if message.choices:
                    content = message.choices[0].delta.content
                    if content:
                        response += content
        except Exception as e:
            return f"⚠️ Llama API Error: {e}"
        return response
|
models/mistral_7b.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from huggingface_hub import InferenceClient
|
| 3 |
+
|
| 4 |
+
class Mistral_7b:
    """Streaming chat wrapper for Mistral-7B-Instruct served via Featherless."""

    def __init__(self, token):
        # Newer InferenceClient versions take the token as `api_key`.
        self.client = InferenceClient(api_key=token)
        # The model id carries the provider suffix required for routing.
        self.model_id = "mistralai/Mistral-7B-Instruct-v0.2:featherless-ai"

    def generate(self, prompt, max_tokens=500, **kwargs):
        """Return the model's answer, or an error string on API failure.

        Accepts either ``temperature`` or ``temp`` in kwargs (default 0.2)
        so it is call-compatible with the other wrappers in this package.
        """
        sampling_temp = kwargs.get('temperature', kwargs.get('temp', 0.2))

        pieces = []
        try:
            # OpenAI-style .chat.completions.create syntax for Featherless.
            stream = self.client.chat.completions.create(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=sampling_temp,
                stream=True,
            )
            for chunk in stream:
                delta = chunk.choices[0].delta.content if chunk.choices else None
                if delta:
                    pieces.append(delta)
        except Exception as e:
            return f"❌ Mistral Featherless Error: {e}"
        return "".join(pieces)
|
models/qwen_2_5.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class Qwen2_5:
    """Streaming chat wrapper around Qwen2.5-72B-Instruct on the HF Inference API."""

    def __init__(self, token):
        self.client = InferenceClient(token=token)
        self.model_id = "Qwen/Qwen2.5-72B-Instruct"

    def generate(self, prompt, max_tokens=500, temperature=0.3):
        """Return the model's answer to *prompt*.

        Fix: API/network failures are caught and reported as an error
        string — matching the other model wrappers in this package —
        instead of propagating and aborting a batch run.
        """
        response = ""
        try:
            for message in self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            ):
                if message.choices:
                    content = message.choices[0].delta.content
                    if content:
                        response += content
        except Exception as e:
            return f"⚠️ Qwen API Error: {e}"
        return response
|
models/tiny_aya.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
|
| 3 |
+
class TinyAya:
    """Streaming chat wrapper around CohereLabs/tiny-aya-global."""

    def __init__(self, token):
        self.client = InferenceClient(token=token)
        # 3.3B parameter model, great for multilingual/efficient RAG.
        self.model_id = "CohereLabs/tiny-aya-global"

    def generate(self, prompt, max_tokens=400, **kwargs):
        """Return the model's answer, or an error string on API failure.

        Accepts either ``temperature`` or ``temp`` in kwargs (default 0.3)
        so it is call-compatible with the other wrappers in this package.
        """
        temperature = kwargs.get('temperature', kwargs.get('temp', 0.3))

        parts = []
        try:
            for event in self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            ):
                if not event.choices:
                    continue
                token_text = event.choices[0].delta.content
                if token_text:
                    parts.append(token_text)
        except Exception as e:
            return f"❌ TinyAya Error: {e}"
        return "".join(parts)
|
retriever/generator.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class RAGGenerator:
    """Builds RAG prompts and delegates answer generation to a model wrapper."""

    def generate_prompt(self, query, retrieved_contexts):
        """Assemble the academic QA prompt from *query* and retrieved chunks."""
        sources = []
        for idx, chunk in enumerate(retrieved_contexts, start=1):
            sources.append(f"--- Source {idx} ---\n{chunk}")
        context_text = "\n\n".join(sources)

        return (
            "You are an expert academic assistant. Use the following pieces of retrieved context to answer the question.\n"
            "If the answer isn't in the context, say you don't know based on the provided documents.\n"
            "\n"
            "Context:\n"
            f"{context_text}\n"
            "\n"
            f"Question: {query}\n"
            "\n"
            "Answer:"
        )

    def get_answer(self, model_instance, query, retrieved_contexts, **kwargs):
        """Build the prompt and let *model_instance* produce the final answer."""
        return model_instance.generate(self.generate_prompt(query, retrieved_contexts), **kwargs)
|
retriever/processor.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_text_splitters import (
|
| 2 |
+
RecursiveCharacterTextSplitter,
|
| 3 |
+
CharacterTextSplitter,
|
| 4 |
+
SentenceTransformersTokenTextSplitter
|
| 5 |
+
)
|
| 6 |
+
from langchain_experimental.text_splitter import SemanticChunker
|
| 7 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
+
|
| 10 |
+
class ChunkProcessor:
    """Splits documents into chunks with a pluggable strategy and embeds them.

    Emits Pinecone-style records: ``{"id", "values", "metadata"}``.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        self.model_name = model_name
        # Embedder used to vectorize every produced chunk.
        self.encoder = SentenceTransformer(model_name)
        # LangChain-compatible embeddings object, required by SemanticChunker.
        self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name)

    def get_splitter(self, technique, chunk_size=500, chunk_overlap=50, **kwargs):
        """Factory method to return different chunking strategies.

        :param technique: one of "fixed", "recursive", "character",
            "sentence", "semantic", "token".
        :raises ValueError: for an unknown technique.
        """
        if technique == "fixed":
            # Hard character windows (empty separator = cut anywhere).
            return CharacterTextSplitter(
                separator=kwargs.get('separator', ""),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )

        elif technique == "recursive":
            return RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )

        elif technique == "character":
            return CharacterTextSplitter(
                separator=kwargs.get('separator', "\n\n"),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap
            )

        elif technique == "sentence":
            # Recursive splitter configured for sentence boundaries. This
            # avoids the spaCy [E050] model-download error while still
            # respecting full sentences.
            return RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separators=["\n\n", "\n", ". ", "? ", "! ", " ", ""]
            )

        elif technique == "semantic":
            return SemanticChunker(
                self.hf_embeddings,
                breakpoint_threshold_type="percentile"
            )

        elif technique == "token":
            return SentenceTransformersTokenTextSplitter(
                model_name=self.model_name,
                tokens_per_chunk=chunk_size,
                chunk_overlap=chunk_overlap
            )
        else:
            raise ValueError(f"Technique '{technique}' is not supported.")

    def process(self, df, technique="recursive", chunk_size=500, chunk_overlap=50, max_docs=5, **kwargs):
        """Process a DataFrame into vector-ready chunks, with verbose output.

        Generalization: the document cap is now the *max_docs* parameter
        (default 5 preserves the previous behavior), and the summary line
        reports the actual number of documents processed instead of a
        hard-coded "5".
        """
        splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
        processed_chunks = []

        # Cap the number of documents processed.
        subset_df = df.head(max_docs)

        for _, row in subset_df.iterrows():
            print("\n" + "=" * 80)
            print(f"📄 DOCUMENT: {row['title']}")
            print(f"🔗 URL: {row['url']}")
            print("-" * 80)

            # Split the text.
            raw_chunks = splitter.split_text(row['full_text'])

            print(f"🎯 Technique: {technique.upper()} | Total Chunks: {len(raw_chunks)}")

            for i, text in enumerate(raw_chunks):
                # Some splitters yield Document objects, others plain strings.
                content = text.page_content if hasattr(text, 'page_content') else text

                # Print the full content of every chunk.
                print(f"\n[Chunk {i}] ({len(content)} chars):")
                print(f"   {content}")

                # Embedding for the vector store.
                embedding = self.encoder.encode(content).tolist()

                processed_chunks.append({
                    "id": f"{row['id']}-chunk-{i}",
                    "values": embedding,
                    "metadata": {
                        "title": row['title'],
                        "text": content,
                        "url": row['url'],
                        "chunk_index": i,
                        "technique": technique
                    }
                })
            print("=" * 80)

        print(f"\n✅ Finished processing {len(subset_df)} documents into {len(processed_chunks)} chunks.")
        return processed_chunks
|
retriever/retriever.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from rank_bm25 import BM25Okapi
|
| 3 |
+
from sentence_transformers import CrossEncoder
|
| 4 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
+
|
| 6 |
+
class HybridRetriever:
    """Hybrid (semantic + BM25) retriever with pluggable re-ranking.

    Candidates come from a vector index (semantic) and/or an in-memory
    BM25 corpus built over the chunk texts; results can be fused with
    RRF or re-ranked with a cross-encoder or MMR.
    """

    def __init__(self, final_chunks, embed_model, rerank_model_name='cross-encoder/ms-marco-MiniLM-L-6-v2'):
        """
        :param final_chunks: The list of chunk dictionaries with metadata.
        :param embed_model: The SentenceTransformer model used for query and chunk embedding.
        :param rerank_model_name: Cross-encoder checkpoint used by the
            "cross-encoder" re-rank strategy.
        """
        self.final_chunks = final_chunks
        self.embed_model = embed_model
        self.rerank_model = CrossEncoder(rerank_model_name)

        # Initialize BM25 corpus: naive lowercase whitespace tokenization
        # over each chunk's metadata text.
        self.tokenized_corpus = [chunk['metadata']['text'].lower().split() for chunk in final_chunks]
        self.bm25 = BM25Okapi(self.tokenized_corpus)

    def _rrf_score(self, semantic_results, bm25_results, k=60):
        """Reciprocal Rank Fusion (RRF) Implementation.

        Each chunk accumulates 1/(k + rank + 1) from each list it appears
        in; chunks present in both lists therefore score higher. Returns
        chunk texts sorted by fused score, best first.
        """
        scores = {}
        for rank, chunk in enumerate(semantic_results):
            scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
        for rank, chunk in enumerate(bm25_results):
            scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)

        sorted_chunks = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return [item[0] for item in sorted_chunks]

    def _maximal_marginal_relevance(self, query_embedding, chunk_texts, lambda_param=0.5, top_k=3):
        """
        MMR Re-ranking to balance relevance and diversity.

        Greedily selects chunks maximizing
        ``lambda * relevance - (1 - lambda) * max_similarity_to_selected``.
        Returns up to *top_k* chunk texts in selection order.
        """
        if not chunk_texts: return []

        chunk_embeddings = self.embed_model.encode(chunk_texts)
        query_embedding = query_embedding.reshape(1, -1)

        # Initial relevance scores (cosine similarity of each chunk to the query).
        relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]

        selected_indices = []
        unselected_indices = list(range(len(chunk_texts)))

        # First pick: most relevant
        idx = np.argmax(relevance_scores)
        selected_indices.append(idx)
        unselected_indices.remove(idx)

        while len(selected_indices) < min(top_k, len(chunk_texts)):
            mmr_scores = []
            for un_idx in unselected_indices:
                # Similarity to query
                rel = relevance_scores[un_idx]
                # Max similarity to already selected chunks (redundancy)
                sim_to_selected = max([cosine_similarity(chunk_embeddings[un_idx].reshape(1, -1),
                                                         chunk_embeddings[sel_idx].reshape(1, -1))[0][0]
                                       for sel_idx in selected_indices])

                mmr_score = lambda_param * rel - (1 - lambda_param) * sim_to_selected
                mmr_scores.append((un_idx, mmr_score))

            # Pick the candidate with the best relevance/diversity trade-off.
            next_idx = max(mmr_scores, key=lambda x: x[1])[0]
            selected_indices.append(next_idx)
            unselected_indices.remove(next_idx)

        return [chunk_texts[i] for i in selected_indices]

    def search(self, query, index, top_k=10, final_k=3, mode="hybrid", rerank_strategy="cross-encoder"):
        """Retrieve up to *final_k* chunk texts for *query*.

        :param index: vector index with a ``query(vector=..., top_k=...,
            include_metadata=True)`` API (Pinecone-style — assumed from
            usage; confirm against caller).
        :param top_k: candidates fetched per retrieval mode before fusion.
        :param final_k: number of chunks returned.
        :param mode: "semantic", "bm25", or "hybrid"
        :param rerank_strategy: "cross-encoder", "rrf", "mmr", or "none"
        """
        semantic_chunks = []
        bm25_chunks = []
        query_vector = None

        # 1. Fetch Candidates
        if mode in ["semantic", "hybrid"]:
            query_vector = self.embed_model.encode(query)
            res = index.query(vector=query_vector.tolist(), top_k=top_k, include_metadata=True)
            semantic_chunks = [match['metadata']['text'] for match in res['matches']]

        if mode in ["bm25", "hybrid"]:
            tokenized_query = query.lower().split()
            bm25_scores = self.bm25.get_scores(tokenized_query)
            # Highest-scoring BM25 indices first.
            top_indices = np.argsort(bm25_scores)[::-1][:top_k]
            bm25_chunks = [self.final_chunks[i]['metadata']['text'] for i in top_indices]

        # 2. Re-Ranking / Fusion
        # RRF consumes the two ranked lists directly; handled before merging.
        if mode == "hybrid" and rerank_strategy == "rrf":
            return self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]

        # Standard combination for other methods
        combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))  # Deduplicate keep order

        if rerank_strategy == "cross-encoder" and combined:
            # Score each (query, chunk) pair with the cross-encoder.
            pairs = [[query, chunk] for chunk in combined]
            scores = self.rerank_model.predict(pairs)
            results = sorted(zip(combined, scores), key=lambda x: x[1], reverse=True)
            return [res[0] for res in results[:final_k]]

        elif rerank_strategy == "mmr" and combined:
            # BM25-only mode reaches here without a query vector; encode lazily.
            if query_vector is None: query_vector = self.embed_model.encode(query)
            return self._maximal_marginal_relevance(query_vector, combined, top_k=final_k)

        return combined[:final_k]
|