ramailkk commited on
Commit
6d56aa1
·
1 Parent(s): 643df21

cleaning code -> making phase 1 pipeline ready

Browse files
config.yaml CHANGED
@@ -4,7 +4,7 @@ project_name: "arxiv_cyber_advisor"
4
  # Stage 1: Data Acquisition
5
  data_ingestion:
6
  category: "cs.AI"
7
- limit: 20
8
  save_local: true
9
  raw_data_path: "data/raw_arxiv.csv"
10
 
 
4
  # Stage 1: Data Acquisition
5
  data_ingestion:
6
  category: "cs.AI"
7
+ limit: 5
8
  save_local: true
9
  raw_data_path: "data/raw_arxiv.csv"
10
 
data_loader.py CHANGED
@@ -35,7 +35,7 @@ def fetch_arxiv_data(category="cs.AI", limit=5):
35
  "id": r.entry_id.split('/')[-1],
36
  "title": r.title,
37
  "abstract": r.summary.replace('\n', ' '),
38
- "full_text": full_text, # <--- NEW FIELD
39
  "url": r.pdf_url
40
  })
41
  return pd.DataFrame(results)
 
35
  "id": r.entry_id.split('/')[-1],
36
  "title": r.title,
37
  "abstract": r.summary.replace('\n', ' '),
38
+ "full_text": full_text, # <--- Main part of the data
39
  "url": r.pdf_url
40
  })
41
  return pd.DataFrame(results)
main.py CHANGED
@@ -1,28 +1,95 @@
1
- import yaml
2
- from data_processor import fetch_arxiv_data, process_to_chunks
3
 
4
- def load_config():
5
- with open("config.yaml", "r") as f:
6
- return yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
 
7
 
8
  def main():
9
- config = load_config()
10
-
11
- # Run Stage 1
12
- raw_data = fetch_arxiv_data(
13
- category=config['data_ingestion']['category'],
14
- limit=config['data_ingestion']['limit']
15
- )
16
 
17
- # Run Stage 2 using YAML defaults
18
- final_chunks = process_to_chunks(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  raw_data,
20
- model_name=config['embedding']['model_name'],
21
- chunk_size=config['chunking']['chunk_size'],
22
- chunk_overlap=config['chunking']['chunk_overlap']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  )
24
 
25
- print(f"✅ Pipeline finished with {len(final_chunks)} chunks.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  if __name__ == "__main__":
28
  main()
 
import os
from dotenv import load_dotenv

from vector_db import get_pinecone_index, refresh_pinecone_index
from retriever.retriever import HybridRetriever
from retriever.generator import RAGGenerator
from retriever.processor import ChunkProcessor
import data_loader as dl

from models.llama_3_8b import Llama3_8B
from models.mistral_7b import Mistral_7b
from models.qwen_2_5 import Qwen2_5
from models.deepseek_v3 import DeepSeek_V3
from models.tiny_aya import TinyAya

# Load HF_TOKEN / PINECONE_API_KEY from a local .env file, if present.
load_dotenv()

def main():
    """Run the phase-1 RAG pipeline end to end.

    Stages: data ingestion -> chunking & embedding -> vector DB upsert ->
    hybrid retrieval -> answer generation with several hosted LLMs so their
    outputs can be compared side by side.

    Raises:
        ValueError: if PINECONE_API_KEY is not set in the environment.
    """
    # ------------------------------------------------------------------
    # 0. Configuration
    # ------------------------------------------------------------------
    hf_token = os.getenv("HF_TOKEN")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    if not pinecone_api_key:
        raise ValueError("PINECONE_API_KEY not found in environment variables")

    # NOTE(review): query, category and limit are hardcoded here while
    # config.yaml still defines data_ingestion.category/limit — confirm
    # whether these should instead be read from the config file.
    query = "How do transformers handle long sequences?"

    # ------------------------------------------------------------------
    # 1. Data Ingestion
    # ------------------------------------------------------------------
    raw_data = dl.fetch_arxiv_data(category="cs.AI", limit=5)

    # ------------------------------------------------------------------
    # 2. Chunking & Embedding
    # ------------------------------------------------------------------
    proc = ChunkProcessor(model_name='all-MiniLM-L6-v2', verbose=True)
    final_chunks = proc.process(
        raw_data,
        technique="sentence",  # options: fixed, recursive, character, sentence, semantic
        chunk_size=500,
        chunk_overlap=50
    )

    # ------------------------------------------------------------------
    # 3. Vector DB
    # ------------------------------------------------------------------
    index_name = "arxiv-index"
    # dimension must match the encoder's embedding size (384 for
    # all-MiniLM-L6-v2 — verify if the model is changed).
    index = get_pinecone_index(pinecone_api_key, index_name, dimension=384, metric="cosine")
    refresh_pinecone_index(index, final_chunks, batch_size=100)

    # ------------------------------------------------------------------
    # 4. Retrieval
    # ------------------------------------------------------------------
    retriever = HybridRetriever(final_chunks, proc.encoder, verbose=True)
    context_chunks = retriever.search(
        query,
        index,
        mode="hybrid",                    # options: bm25, semantic, hybrid
        rerank_strategy="cross-encoder",  # options: cross-encoder, rrf
        use_mmr=True,
        top_k=10,
        final_k=5
    )

    if not context_chunks:
        print("No context chunks retrieved. Check your index and query.")
        return

    # ------------------------------------------------------------------
    # 5. Generation
    # ------------------------------------------------------------------
    rag_engine = RAGGenerator()

    # Each entry wraps a hosted HF inference endpoint; all share the token.
    models = {
        "Llama-3-8B": Llama3_8B(token=hf_token),
        "Mistral-7B": Mistral_7b(token=hf_token),
        "Qwen-2.5": Qwen2_5(token=hf_token),
        "DeepSeek-V3": DeepSeek_V3(token=hf_token),
        "TinyAya": TinyAya(token=hf_token)
    }

    # Run the same query through every model; a failure in one model must
    # not abort the comparison loop, so errors are printed and skipped.
    for name, model in models.items():
        print(f"\n--- {name} ---")
        try:
            print(rag_engine.get_answer(model, query, context_chunks, temperature=0.1))
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()
models/deepseek_v3.py CHANGED
@@ -5,7 +5,7 @@ class DeepSeek_V3:
5
  self.client = InferenceClient(token=token)
6
  self.model_id = "deepseek-ai/DeepSeek-V3"
7
 
8
- def generate(self, prompt, max_tokens=500, temperature=0.15):
9
  response = ""
10
  try:
11
  for message in self.client.chat_completion(
@@ -19,5 +19,5 @@ class DeepSeek_V3:
19
  content = message.choices[0].delta.content
20
  if content: response += content
21
  except Exception as e:
22
- return f"⚠️ DeepSeek API Busy: {e}"
23
  return response
 
5
  self.client = InferenceClient(token=token)
6
  self.model_id = "deepseek-ai/DeepSeek-V3"
7
 
8
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
9
  response = ""
10
  try:
11
  for message in self.client.chat_completion(
 
19
  content = message.choices[0].delta.content
20
  if content: response += content
21
  except Exception as e:
22
+ return f" DeepSeek API Busy: {e}"
23
  return response
models/llama_3_8b.py CHANGED
@@ -5,13 +5,13 @@ class Llama3_8B:
5
  self.client = InferenceClient(token=token)
6
  self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
7
 
8
- def generate(self, prompt, max_tokens=500, temp=0.1):
9
  response = ""
10
  for message in self.client.chat_completion(
11
  model=self.model_id,
12
  messages=[{"role": "user", "content": prompt}],
13
  max_tokens=max_tokens,
14
- temperature=temp,
15
  stream=True,
16
  ):
17
  if message.choices:
 
5
  self.client = InferenceClient(token=token)
6
  self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
7
 
8
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
9
  response = ""
10
  for message in self.client.chat_completion(
11
  model=self.model_id,
12
  messages=[{"role": "user", "content": prompt}],
13
  max_tokens=max_tokens,
14
+ temperature=temperature,
15
  stream=True,
16
  ):
17
  if message.choices:
models/mistral_7b.py CHANGED
@@ -1,20 +1,13 @@
1
- import os
2
  from huggingface_hub import InferenceClient
3
 
4
  class Mistral_7b:
5
  def __init__(self, token):
6
- # Initializing with api_key as per latest documentation
7
  self.client = InferenceClient(api_key=token)
8
- # Using the specific provider suffix
9
  self.model_id = "mistralai/Mistral-7B-Instruct-v0.2:featherless-ai"
10
 
11
- def generate(self, prompt, max_tokens=500, **kwargs):
12
- # Extract temperature, defaulting to 0.2 if not provided
13
- temperature = kwargs.get('temperature', kwargs.get('temp', 0.2))
14
-
15
  response = ""
16
  try:
17
- # Using the new .chat.completions.create syntax for Featherless
18
  stream = self.client.chat.completions.create(
19
  model=self.model_id,
20
  messages=[{"role": "user", "content": prompt}],
@@ -22,14 +15,12 @@ class Mistral_7b:
22
  temperature=temperature,
23
  stream=True,
24
  )
25
-
26
  for chunk in stream:
27
- # Accessing content through the standard completion object structure
28
  if chunk.choices and chunk.choices[0].delta.content:
29
  content = chunk.choices[0].delta.content
30
  response += content
31
 
32
  except Exception as e:
33
- return f" Mistral Featherless Error: {e}"
34
 
35
  return response
 
 
1
  from huggingface_hub import InferenceClient
2
 
3
  class Mistral_7b:
4
  def __init__(self, token):
 
5
  self.client = InferenceClient(api_key=token)
 
6
  self.model_id = "mistralai/Mistral-7B-Instruct-v0.2:featherless-ai"
7
 
8
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
 
 
 
9
  response = ""
10
  try:
 
11
  stream = self.client.chat.completions.create(
12
  model=self.model_id,
13
  messages=[{"role": "user", "content": prompt}],
 
15
  temperature=temperature,
16
  stream=True,
17
  )
 
18
  for chunk in stream:
 
19
  if chunk.choices and chunk.choices[0].delta.content:
20
  content = chunk.choices[0].delta.content
21
  response += content
22
 
23
  except Exception as e:
24
+ return f" Mistral Featherless Error: {e}"
25
 
26
  return response
models/qwen_2_5.py CHANGED
@@ -5,7 +5,7 @@ class Qwen2_5:
5
  self.client = InferenceClient(token=token)
6
  self.model_id = "Qwen/Qwen2.5-72B-Instruct"
7
 
8
- def generate(self, prompt, max_tokens=500, temperature=0.3):
9
  response = ""
10
  for message in self.client.chat_completion(
11
  model=self.model_id,
 
5
  self.client = InferenceClient(token=token)
6
  self.model_id = "Qwen/Qwen2.5-72B-Instruct"
7
 
8
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
9
  response = ""
10
  for message in self.client.chat_completion(
11
  model=self.model_id,
models/tiny_aya.py CHANGED
@@ -3,16 +3,10 @@ from huggingface_hub import InferenceClient
3
  class TinyAya:
4
  def __init__(self, token):
5
  self.client = InferenceClient(token=token)
6
- # 3.3B parameter model, great for multilingual/efficient RAG
7
  self.model_id = "CohereLabs/tiny-aya-global"
8
 
9
- def generate(self, prompt, max_tokens=400, **kwargs):
10
- """
11
- Using **kwargs makes this compatible with calls using 'temp' or 'temperature'.
12
- """
13
- # This line looks for 'temperature', then 'temp', and defaults to 0.3 if neither exist
14
- temperature = kwargs.get('temperature', kwargs.get('temp', 0.3))
15
-
16
  response = ""
17
  try:
18
  for message in self.client.chat_completion(
@@ -26,6 +20,6 @@ class TinyAya:
26
  content = message.choices[0].delta.content
27
  if content: response += content
28
  except Exception as e:
29
- return f" TinyAya Error: {e}"
30
 
31
  return response
 
3
  class TinyAya:
4
  def __init__(self, token):
5
  self.client = InferenceClient(token=token)
 
6
  self.model_id = "CohereLabs/tiny-aya-global"
7
 
8
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
9
+
 
 
 
 
 
10
  response = ""
11
  try:
12
  for message in self.client.chat_completion(
 
20
  content = message.choices[0].delta.content
21
  if content: response += content
22
  except Exception as e:
23
+ return f" TinyAya Error: {e}"
24
 
25
  return response
retriever/generator.py CHANGED
@@ -15,5 +15,5 @@ Answer:"""
15
 
16
  def get_answer(self, model_instance, query, retrieved_contexts, **kwargs):
17
  """Uses a specific model instance to generate the final answer."""
18
- prompt = self.generate_prompt(query, retrieved_contexts)
19
  return model_instance.generate(prompt, **kwargs)
 
15
 
16
  def get_answer(self, model_instance, query, retrieved_contexts, **kwargs):
17
  """Uses a specific model instance to generate the final answer."""
18
+ prompt = self.generate_prompt(query, retrieved_contexts)
19
  return model_instance.generate(prompt, **kwargs)
retriever/processor.py CHANGED
@@ -1,169 +1,133 @@
1
  from langchain_text_splitters import (
2
  RecursiveCharacterTextSplitter,
3
  CharacterTextSplitter,
4
- SentenceTransformersTokenTextSplitter
 
5
  )
6
  from langchain_experimental.text_splitter import SemanticChunker
7
  from langchain_huggingface import HuggingFaceEmbeddings
8
  from sentence_transformers import SentenceTransformer
9
  from typing import List, Dict, Any, Optional
 
 
10
  import pandas as pd
11
 
 
12
  class ChunkProcessor:
13
  def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True):
14
  self.model_name = model_name
15
  self.encoder = SentenceTransformer(model_name)
16
  self.verbose = verbose
17
- # Required for Semantic Chunking
18
  self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name)
19
 
20
- def _print(self, *args, **kwargs):
21
- """Helper method to conditionally print"""
22
- if self.verbose:
23
- print(*args, **kwargs)
24
 
25
  def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
26
  """
27
  Factory method to return different chunking strategies.
28
-
29
  Strategies:
30
- - "fixed": Simple character-based splitting with empty separator (can split mid-sentence)
31
- - "recursive": Recursive character splitting with hierarchical separators (preserves semantics)
32
- - "character": Character-based splitting with paragraph separator
33
- - "sentence": Recursive splitting optimized for sentence boundaries
34
- - "semantic": Embedding-based semantic chunking
35
- - "token": Token-based splitting for transformer models
36
  """
37
  if technique == "fixed":
38
- # FIXED: Simple character-based splitter - WILL split mid-sentence
39
  return CharacterTextSplitter(
40
- separator=kwargs.get('separator', ""),
41
- chunk_size=chunk_size,
42
  chunk_overlap=chunk_overlap,
43
  length_function=len,
44
  is_separator_regex=False
45
  )
46
-
47
  elif technique == "recursive":
48
- # FIXED: Proper recursive splitter with default separators that preserve semantics
49
- separators = kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""])
50
  return RecursiveCharacterTextSplitter(
51
- chunk_size=chunk_size,
52
  chunk_overlap=chunk_overlap,
53
- separators=separators,
54
  length_function=len,
55
  keep_separator=kwargs.get('keep_separator', True)
56
  )
57
-
58
  elif technique == "character":
59
- # FIXED: Character splitter with paragraph separator
60
  return CharacterTextSplitter(
61
- separator=kwargs.get('separator', "\n\n"),
62
- chunk_size=chunk_size,
63
  chunk_overlap=chunk_overlap,
64
  length_function=len,
65
  is_separator_regex=False
66
  )
67
-
68
  elif technique == "sentence":
69
- # FIXED: Using Recursive Splitter with comprehensive sentence boundaries
70
- # This preserves full sentences whenever possible
71
- return RecursiveCharacterTextSplitter(
72
  chunk_size=chunk_size,
73
  chunk_overlap=chunk_overlap,
74
- separators=kwargs.get('separators', ["\n\n", "\n", ". ", "? ", "! ", ".\n", "?\n", "!\n", "; ", ": ", ", ", " ", ""]),
75
- length_function=len,
76
- keep_separator=kwargs.get('keep_separator', True)
77
- )
78
-
79
  elif technique == "semantic":
80
- # FIXED: Semantic chunker with proper configuration
81
  return SemanticChunker(
82
- self.hf_embeddings,
83
  breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
84
- breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 95),
85
- min_chunk_size=kwargs.get('min_chunk_size', chunk_size // 10),
86
- max_chunk_size=kwargs.get('max_chunk_size', chunk_size)
87
- )
88
-
89
- elif technique == "token":
90
- # FIXED: Token-based splitter with proper token counting
91
- return SentenceTransformersTokenTextSplitter(
92
- model_name=self.model_name,
93
- tokens_per_chunk=chunk_size,
94
- chunk_overlap=chunk_overlap,
95
- length_function=kwargs.get('length_function', lambda x: len(self.encoder.encode(x)))
96
  )
97
-
98
  else:
99
- raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, sentence, semantic, token")
 
 
 
 
100
 
101
- def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
102
- chunk_overlap: int = 50, max_docs: Optional[int] = 5, verbose: Optional[bool] = None,
103
- **kwargs) -> List[Dict[str, Any]]:
104
  """
105
- Processes a DataFrame into vector-ready chunks with full output for documents.
106
-
107
  Args:
108
- df: DataFrame containing documents with columns: id, title, url, full_text
109
- technique: Chunking strategy to use
110
- chunk_size: Maximum size of each chunk (characters for most, tokens for token splitter)
111
  chunk_overlap: Overlap between consecutive chunks
112
- max_docs: Maximum number of documents to process (None for all)
113
- verbose: Override the instance's verbose setting (if None, uses instance setting)
114
- **kwargs: Additional arguments to pass to splitter
115
-
116
  Returns:
117
- List of processed chunks with embeddings and metadata
118
  """
119
- # Determine if we should print
120
  should_print = verbose if verbose is not None else self.verbose
121
-
122
- splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
123
- processed_chunks = []
124
-
125
- # Select documents to process
126
- if max_docs:
127
- subset_df = df.head(max_docs)
128
- else:
129
- subset_df = df
130
-
131
- # Validate required columns exist
132
  required_cols = ['id', 'title', 'url', 'full_text']
133
- missing_cols = [col for col in required_cols if col not in subset_df.columns]
134
  if missing_cols:
135
  raise ValueError(f"DataFrame missing required columns: {missing_cols}")
136
-
 
 
 
 
137
  for _, row in subset_df.iterrows():
138
  if should_print:
139
- self._print("\n" + "="*80)
140
- self._print(f"📄 DOCUMENT: {row['title']}")
141
- self._print(f"🔗 URL: {row['url']}")
142
- self._print(f"📏 Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
143
- self._print("-" * 80)
144
-
145
- # Split the text
146
  raw_chunks = splitter.split_text(row['full_text'])
147
-
148
- if should_print:
149
- self._print(f"🎯 Total Chunks Generated: {len(raw_chunks)}")
150
-
151
  for i, text in enumerate(raw_chunks):
152
- # Standardize output (handle both string and Document objects)
153
  content = text.page_content if hasattr(text, 'page_content') else text
154
-
155
  if should_print:
156
- # Print chunk preview
157
- self._print(f"\n[Chunk {i}] ({len(content)} chars):")
158
- preview = content[:200] + "..." if len(content) > 200 else content
159
- self._print(f" {preview}")
160
-
161
- # Generate embedding
162
- embedding = self.encoder.encode(content).tolist()
163
-
164
  processed_chunks.append({
165
  "id": f"{row['id']}-chunk-{i}",
166
- "values": embedding,
167
  "metadata": {
168
  "title": row['title'],
169
  "text": content,
@@ -174,67 +138,37 @@ class ChunkProcessor:
174
  "total_chunks": len(raw_chunks)
175
  }
176
  })
177
-
178
  if should_print:
179
- self._print("="*80)
180
-
181
  if should_print:
182
- self._print(f"\n✅ Finished processing {len(subset_df)} documents into {len(processed_chunks)} chunks.")
183
- if len(processed_chunks) > 0:
184
- self._print(f"📊 Average chunk size: {sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks):.0f} chars")
185
-
186
  return processed_chunks
187
 
188
- def compare_strategies(self, df: pd.DataFrame, text_column: str = 'full_text',
189
- chunk_size: int = 500, max_docs: int = 1,
190
- verbose: Optional[bool] = None) -> Dict[str, Any]:
191
- """
192
- Compare different chunking strategies on the same document.
193
-
194
- Returns:
195
- Dictionary with comparison metrics for each strategy
196
- """
197
- # Determine if we should print
198
- should_print = verbose if verbose is not None else self.verbose
199
-
200
- strategies = ['fixed', 'recursive', 'character', 'sentence', 'semantic', 'token']
201
- results = {}
202
-
203
- # Get first document
204
- sample_text = df.iloc[0][text_column]
205
-
206
- for technique in strategies:
207
- try:
208
- if should_print:
209
- self._print(f"\n🔍 Testing {technique.upper()} strategy...")
210
-
211
- splitter = self.get_splitter(technique, chunk_size=chunk_size)
212
- chunks = splitter.split_text(sample_text)
213
-
214
- # Analyze chunks
215
- chunk_lengths = [len(c.page_content if hasattr(c, 'page_content') else c) for c in chunks]
216
- avg_chunk_size = sum(chunk_lengths) / len(chunk_lengths) if chunk_lengths else 0
217
-
218
- # Count how many chunks end with sentence boundaries
219
- sentence_enders = ['.', '!', '?', '"', "'"]
220
- complete_sentences = sum(1 for c in chunks
221
- if (c.page_content if hasattr(c, 'page_content') else c).strip()[-1] in sentence_enders)
222
-
223
- results[technique] = {
224
- 'num_chunks': len(chunks),
225
- 'avg_chunk_size': avg_chunk_size,
226
- 'min_chunk_size': min(chunk_lengths) if chunk_lengths else 0,
227
- 'max_chunk_size': max(chunk_lengths) if chunk_lengths else 0,
228
- 'complete_sentences_ratio': complete_sentences / len(chunks) if chunks else 0,
229
- 'chunk_lengths': chunk_lengths
230
- }
231
-
232
- if should_print:
233
- self._print(f" ✓ Generated {len(chunks)} chunks, avg size: {avg_chunk_size:.0f} chars")
234
-
235
- except Exception as e:
236
- results[technique] = {'error': str(e)}
237
- if should_print:
238
- self._print(f" ✗ Error: {str(e)}")
239
-
240
- return results
 
1
  from langchain_text_splitters import (
2
  RecursiveCharacterTextSplitter,
3
  CharacterTextSplitter,
4
+ SentenceTransformersTokenTextSplitter,
5
+ NLTKTextSplitter
6
  )
7
  from langchain_experimental.text_splitter import SemanticChunker
8
  from langchain_huggingface import HuggingFaceEmbeddings
9
  from sentence_transformers import SentenceTransformer
10
  from typing import List, Dict, Any, Optional
11
+ import nltk
12
+ nltk.download('punkt_tab', quiet=True)
13
  import pandas as pd
14
 
15
+
16
  class ChunkProcessor:
17
  def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True):
18
  self.model_name = model_name
19
  self.encoder = SentenceTransformer(model_name)
20
  self.verbose = verbose
 
21
  self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name)
22
 
23
+ # ------------------------------------------------------------------
24
+ # Splitters
25
+ # ------------------------------------------------------------------
 
26
 
27
  def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
28
  """
29
  Factory method to return different chunking strategies.
30
+
31
  Strategies:
32
+ - "fixed": Character-based, may split mid-sentence
33
+ - "recursive": Recursive character splitting with hierarchical separators
34
+ - "character": Character-based splitting on paragraph boundaries
35
+ - "sentence": Sliding window over NLTK sentences
36
+ - "semantic": Embedding-based semantic chunking
 
37
  """
38
  if technique == "fixed":
 
39
  return CharacterTextSplitter(
40
+ separator=kwargs.get('separator', ""),
41
+ chunk_size=chunk_size,
42
  chunk_overlap=chunk_overlap,
43
  length_function=len,
44
  is_separator_regex=False
45
  )
46
+
47
  elif technique == "recursive":
 
 
48
  return RecursiveCharacterTextSplitter(
49
+ chunk_size=chunk_size,
50
  chunk_overlap=chunk_overlap,
51
+ separators=kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""]),
52
  length_function=len,
53
  keep_separator=kwargs.get('keep_separator', True)
54
  )
55
+
56
  elif technique == "character":
 
57
  return CharacterTextSplitter(
58
+ separator=kwargs.get('separator', "\n\n"),
59
+ chunk_size=chunk_size,
60
  chunk_overlap=chunk_overlap,
61
  length_function=len,
62
  is_separator_regex=False
63
  )
64
+
65
  elif technique == "sentence":
66
+ # sentence-level chunking using NLTK
67
+ return NLTKTextSplitter(
 
68
  chunk_size=chunk_size,
69
  chunk_overlap=chunk_overlap,
70
+ separator="\n"
71
+ )
72
+
 
 
73
  elif technique == "semantic":
 
74
  return SemanticChunker(
75
+ self.hf_embeddings,
76
  breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
77
+ breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 95)
 
 
 
 
 
 
 
 
 
 
 
78
  )
79
+
80
  else:
81
+ raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, sentence, semantic")
82
+
83
+ # ------------------------------------------------------------------
84
+ # Processing
85
+ # ------------------------------------------------------------------
86
 
87
+ def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
88
+ chunk_overlap: int = 50, max_docs: Optional[int] = 5,
89
+ verbose: Optional[bool] = None, **kwargs) -> List[Dict[str, Any]]:
90
  """
91
+ Processes a DataFrame into vector-ready chunks.
92
+
93
  Args:
94
+ df: DataFrame with columns: id, title, url, full_text
95
+ technique: Chunking strategy to use
96
+ chunk_size: Maximum size of each chunk in characters
97
  chunk_overlap: Overlap between consecutive chunks
98
+ max_docs: Number of documents to process (None for all)
99
+ verbose: Override instance verbose setting
100
+ **kwargs: Additional arguments passed to the splitter
101
+
102
  Returns:
103
+ List of chunk dicts with embeddings and metadata
104
  """
 
105
  should_print = verbose if verbose is not None else self.verbose
106
+
 
 
 
 
 
 
 
 
 
 
107
  required_cols = ['id', 'title', 'url', 'full_text']
108
+ missing_cols = [col for col in required_cols if col not in df.columns]
109
  if missing_cols:
110
  raise ValueError(f"DataFrame missing required columns: {missing_cols}")
111
+
112
+ splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
113
+ subset_df = df.head(max_docs) if max_docs else df
114
+ processed_chunks = []
115
+
116
  for _, row in subset_df.iterrows():
117
  if should_print:
118
+ self._print_document_header(row['title'], row['url'], technique, chunk_size, chunk_overlap)
119
+
 
 
 
 
 
120
  raw_chunks = splitter.split_text(row['full_text'])
121
+
 
 
 
122
  for i, text in enumerate(raw_chunks):
 
123
  content = text.page_content if hasattr(text, 'page_content') else text
124
+
125
  if should_print:
126
+ self._print_chunk(i, content)
127
+
 
 
 
 
 
 
128
  processed_chunks.append({
129
  "id": f"{row['id']}-chunk-{i}",
130
+ "values": self.encoder.encode(content).tolist(),
131
  "metadata": {
132
  "title": row['title'],
133
  "text": content,
 
138
  "total_chunks": len(raw_chunks)
139
  }
140
  })
141
+
142
  if should_print:
143
+ self._print_document_summary(len(raw_chunks))
144
+
145
  if should_print:
146
+ self._print_processing_summary(len(subset_df), processed_chunks)
147
+
 
 
148
  return processed_chunks
149
 
150
+
151
+ # ------------------------------------------------------------------
152
+ # Printing
153
+ # ------------------------------------------------------------------
154
+
155
    def _print_document_header(self, title, url, technique, chunk_size, chunk_overlap):
        # Banner printed before each document is chunked.
        print("\n" + "="*80)
        print(f"DOCUMENT: {title}")
        print(f"URL: {url}")
        print(f"Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
        print("-" * 80)

    def _print_chunk(self, index, content):
        # One line per chunk: its index and character count, then the text.
        print(f"\n[Chunk {index}] ({len(content)} chars):")
        print(f" {content}")

    def _print_document_summary(self, num_chunks):
        # Footer printed after a single document has been fully chunked.
        print(f"Total Chunks Generated: {num_chunks}")
        print("="*80)

    def _print_processing_summary(self, num_docs, processed_chunks):
        # Overall run summary; average uses the chunk_size stored in metadata.
        print(f"\nFinished processing {num_docs} documents into {len(processed_chunks)} chunks.")
        if processed_chunks:
            avg = sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks)
            print(f"Average chunk size: {avg:.0f} chars")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
retriever/retriever.py CHANGED
@@ -2,194 +2,166 @@ import numpy as np
2
  from rank_bm25 import BM25Okapi
3
  from sentence_transformers import CrossEncoder
4
  from sklearn.metrics.pairwise import cosine_similarity
5
- from typing import Optional
6
 
7
  class HybridRetriever:
8
  def __init__(self, final_chunks, embed_model, rerank_model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', verbose: bool = True):
9
- """
10
- :param final_chunks: The list of chunk dictionaries with metadata.
11
- :param embed_model: The SentenceTransformer model used for query and chunk embedding.
12
- :param verbose: Whether to print retrieval details and final results.
13
- """
14
  self.final_chunks = final_chunks
15
  self.embed_model = embed_model
16
  self.rerank_model = CrossEncoder(rerank_model_name)
17
  self.verbose = verbose
18
 
19
- # Initialize BM25 corpus
20
  self.tokenized_corpus = [chunk['metadata']['text'].lower().split() for chunk in final_chunks]
21
  self.bm25 = BM25Okapi(self.tokenized_corpus)
22
 
23
- def _print(self, *args, **kwargs):
24
- """Helper method to conditionally print"""
25
- if self.verbose:
26
- print(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def _rrf_score(self, semantic_results, bm25_results, k=60):
29
- """Reciprocal Rank Fusion (RRF) Implementation."""
30
  scores = {}
31
  for rank, chunk in enumerate(semantic_results):
32
  scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
33
  for rank, chunk in enumerate(bm25_results):
34
  scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
35
-
36
- sorted_chunks = sorted(scores.items(), key=lambda x: x[1], reverse=True)
37
- return [item[0] for item in sorted_chunks]
38
 
39
- def _maximal_marginal_relevance(self, query_embedding, chunk_texts, lambda_param=0.5, top_k=3):
40
- """
41
- MMR Re-ranking to balance relevance and diversity.
42
- """
43
- if not chunk_texts: return []
44
-
45
- chunk_embeddings = self.embed_model.encode(chunk_texts)
46
- query_embedding = query_embedding.reshape(1, -1)
 
 
 
 
 
 
 
 
 
47
 
48
- # Initial relevance scores
 
49
  relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]
50
 
51
- selected_indices = []
52
- unselected_indices = list(range(len(chunk_texts)))
53
 
54
- # First pick: most relevant
55
- idx = np.argmax(relevance_scores)
56
- selected_indices.append(idx)
57
- unselected_indices.remove(idx)
58
 
59
- while len(selected_indices) < min(top_k, len(chunk_texts)):
60
- mmr_scores = []
61
- for un_idx in unselected_indices:
62
- # Similarity to query
63
- rel = relevance_scores[un_idx]
64
- # Max similarity to already selected chunks (redundancy)
65
- sim_to_selected = max([cosine_similarity(chunk_embeddings[un_idx].reshape(1, -1),
66
- chunk_embeddings[sel_idx].reshape(1, -1))[0][0]
67
- for sel_idx in selected_indices])
68
-
69
- mmr_score = lambda_param * rel - (1 - lambda_param) * sim_to_selected
70
- mmr_scores.append((un_idx, mmr_score))
71
-
72
- next_idx = max(mmr_scores, key=lambda x: x[1])[0]
73
- selected_indices.append(next_idx)
74
- unselected_indices.remove(next_idx)
75
-
76
- return [chunk_texts[i] for i in selected_indices]
77
-
78
- def search(self, query, index, top_k=10, final_k=3, mode="hybrid", rerank_strategy="cross-encoder",
79
- verbose: Optional[bool] = None):
 
80
  """
81
- :param mode: "semantic", "bm25", or "hybrid"
82
- :param rerank_strategy: "cross-encoder", "rrf", "mmr", or "none"
83
- :param verbose: Override the instance's verbose setting (if None, uses instance setting)
 
84
  """
85
- # Determine if we should print
86
  should_print = verbose if verbose is not None else self.verbose
87
 
88
  if should_print:
89
- self._print("\n" + "="*80)
90
- self._print(f"🔍 SEARCH QUERY: {query}")
91
- self._print(f"📊 Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
92
- self._print(f"🎯 Top-K: {top_k} | Final-K: {final_k}")
93
- self._print("-" * 80)
94
-
95
- semantic_chunks = []
96
- bm25_chunks = []
97
  query_vector = None
 
98
 
99
- # 1. Fetch Candidates
100
  if mode in ["semantic", "hybrid"]:
 
101
  if should_print:
102
- self._print(f"📚 Semantic Search: Retrieving top {top_k} candidates...")
103
-
104
- query_vector = self.embed_model.encode(query)
105
- res = index.query(vector=query_vector.tolist(), top_k=top_k, include_metadata=True)
106
- semantic_chunks = [match['metadata']['text'] for match in res['matches']]
107
-
108
- if should_print:
109
- self._print(f" ✓ Retrieved {len(semantic_chunks)} semantic candidates")
110
- for i, chunk in enumerate(semantic_chunks[:3]): # Show first 3
111
- preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
112
- self._print(f" [{i}] {preview}")
113
 
114
  if mode in ["bm25", "hybrid"]:
 
115
  if should_print:
116
- self._print(f"📚 BM25 Search: Retrieving top {top_k} candidates...")
117
-
118
- tokenized_query = query.lower().split()
119
- bm25_scores = self.bm25.get_scores(tokenized_query)
120
- top_indices = np.argsort(bm25_scores)[::-1][:top_k]
121
- bm25_chunks = [self.final_chunks[i]['metadata']['text'] for i in top_indices]
122
-
123
- if should_print:
124
- self._print(f" ✓ Retrieved {len(bm25_chunks)} BM25 candidates")
125
- for i, chunk in enumerate(bm25_chunks[:3]): # Show first 3
126
- preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
127
- self._print(f" [{i}] {preview}")
128
 
129
- # 2. Re-Ranking / Fusion
130
- if mode == "hybrid" and rerank_strategy == "rrf":
131
- if should_print:
132
- self._print(f"🔄 Applying Reciprocal Rank Fusion (RRF)...")
133
-
134
- results = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
135
-
136
- if should_print:
137
- self._print(f"✅ Final {final_k} Results:")
138
- for i, chunk in enumerate(results):
139
- preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
140
- self._print(f" [{i+1}] {preview}")
141
- self._print("="*80)
142
-
143
- return results
144
-
145
- # Standard combination for other methods
146
- combined = list(dict.fromkeys(semantic_chunks + bm25_chunks)) # Deduplicate keep order
147
-
148
- if should_print:
149
- self._print(f"🔄 Combined unique candidates: {len(combined)}")
150
- self._print(f"🔄 Applying {rerank_strategy.upper()} reranking...")
151
-
152
- if rerank_strategy == "cross-encoder" and combined:
153
-
154
- pairs = [[query, chunk] for chunk in combined]
155
- scores = self.rerank_model.predict(pairs)
156
- results = sorted(zip(combined, scores), key=lambda x: x[1], reverse=True)
157
- results = [res[0] for res in results[:final_k]]
158
-
159
- if should_print:
160
- self._print(f"\n✅ Final {final_k} Results (Cross-Encoder Reranked):")
161
- for i, chunk in enumerate(results):
162
- preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
163
- self._print(f" [{i+1}] {preview}")
164
- self._print("="*80)
165
-
166
- return results
167
-
168
- elif rerank_strategy == "mmr" and combined:
169
- if should_print:
170
- self._print(f" Using MMR with λ=0.5 to balance relevance and diversity")
171
-
172
- if query_vector is None:
173
  query_vector = self.embed_model.encode(query)
174
- results = self._maximal_marginal_relevance(query_vector, combined, top_k=final_k)
175
-
176
- if should_print:
177
- self._print(f"\n✅ Final {final_k} Results (MMR Reranked):")
178
- for i, chunk in enumerate(results):
179
- preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
180
- self._print(f" [{i+1}] {preview}")
181
- self._print("="*80)
182
-
183
- return results
184
-
185
- else: # "none" or fallback
186
- results = combined[:final_k]
187
-
188
- if should_print:
189
- self._print(f"\n✅ Final {final_k} Results (No Reranking):")
190
- for i, chunk in enumerate(results):
191
- preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
192
- self._print(f" [{i+1}] {preview}")
193
- self._print("="*80)
194
-
195
- return results
 
 
 
 
 
 
 
 
 
 
 
2
  from rank_bm25 import BM25Okapi
3
  from sentence_transformers import CrossEncoder
4
  from sklearn.metrics.pairwise import cosine_similarity
5
+ from typing import Optional, List
6
 
7
class HybridRetriever:
    """Hybrid (semantic + BM25) retriever with optional fusion, cross-encoder
    reranking, and an MMR diversity filter applied to the final candidates."""

    def __init__(self, final_chunks, embed_model, rerank_model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', verbose: bool = True):
        """
        :param final_chunks: List of chunk dicts; each needs metadata['text'].
        :param embed_model: SentenceTransformer-style model exposing .encode().
        :param rerank_model_name: CrossEncoder checkpoint used for reranking.
        :param verbose: Default verbosity for search() printing.
        """
        self.final_chunks = final_chunks
        self.embed_model = embed_model
        self.rerank_model = CrossEncoder(rerank_model_name)
        self.verbose = verbose

        # Build the BM25 corpus once up front (lowercased whitespace tokens).
        self.tokenized_corpus = [chunk['metadata']['text'].lower().split() for chunk in final_chunks]
        self.bm25 = BM25Okapi(self.tokenized_corpus)

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def _semantic_search(self, query, index, top_k) -> tuple[np.ndarray, List[str]]:
        """Encode the query and fetch the top_k chunk texts from the vector index.

        Returns the query embedding as well so callers can reuse it for MMR.
        """
        query_vector = self.embed_model.encode(query)
        res = index.query(vector=query_vector.tolist(), top_k=top_k, include_metadata=True)
        chunks = [match['metadata']['text'] for match in res['matches']]
        return query_vector, chunks

    def _bm25_search(self, query, top_k) -> List[str]:
        """Return the top_k chunk texts ranked by BM25 lexical score."""
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [self.final_chunks[i]['metadata']['text'] for i in top_indices]

    # ------------------------------------------------------------------
    # Fusion
    # ------------------------------------------------------------------

    def _rrf_score(self, semantic_results, bm25_results, k=60) -> List[str]:
        """Fuse two ranked lists via Reciprocal Rank Fusion (best score first).

        Each appearance contributes 1 / (k + rank + 1); k dampens the impact
        of top ranks so both lists get a say.
        """
        scores = {}
        for rank, chunk in enumerate(semantic_results):
            scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
        for rank, chunk in enumerate(bm25_results):
            scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
        return [chunk for chunk, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]

    # ------------------------------------------------------------------
    # Reranking
    # ------------------------------------------------------------------

    def _cross_encoder_rerank(self, query, chunks, final_k) -> List[str]:
        """Score (query, chunk) pairs with the cross-encoder; keep the best final_k."""
        pairs = [[query, chunk] for chunk in chunks]
        scores = self.rerank_model.predict(pairs)
        ranked = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
        return [chunk for chunk, _ in ranked[:final_k]]

    # ------------------------------------------------------------------
    # MMR (applied after reranking as a diversity filter)
    # ------------------------------------------------------------------

    def _maximal_marginal_relevance(self, query_vector, chunks, lambda_param=0.5, top_k=3) -> List[str]:
        """Greedy MMR selection balancing query relevance against redundancy.

        :param lambda_param: 1.0 = pure relevance, 0.0 = pure diversity.
        """
        if not chunks:
            return []

        chunk_embeddings = self.embed_model.encode(chunks)
        query_embedding = query_vector.reshape(1, -1)
        relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]

        selected, unselected = [], list(range(len(chunks)))

        # Seed with the single most relevant chunk.
        first = int(np.argmax(relevance_scores))
        selected.append(first)
        unselected.remove(first)

        while len(selected) < min(top_k, len(chunks)):
            # MMR = λ·relevance − (1−λ)·max similarity to anything already picked.
            mmr_scores = [
                (i, lambda_param * relevance_scores[i] - (1 - lambda_param) * max(
                    cosine_similarity(chunk_embeddings[i].reshape(1, -1),
                                      chunk_embeddings[s].reshape(1, -1))[0][0]
                    for s in selected
                ))
                for i in unselected
            ]
            best = max(mmr_scores, key=lambda x: x[1])[0]
            selected.append(best)
            unselected.remove(best)

        return [chunks[i] for i in selected]

    # ------------------------------------------------------------------
    # Main search
    # ------------------------------------------------------------------

    def search(self, query, index, top_k=10, final_k=3, mode="hybrid",
               rerank_strategy="cross-encoder", use_mmr=True, lambda_param=0.5,
               verbose: Optional[bool] = None) -> List[str]:
        """
        :param mode: "semantic", "bm25", or "hybrid"
        :param rerank_strategy: "cross-encoder", "rrf", or "none"
        :param use_mmr: Whether to apply MMR diversity filter after reranking
        :param lambda_param: MMR trade-off between relevance (1.0) and diversity (0.0)
        :param verbose: Per-call override of the instance verbosity (None = use instance setting)
        """
        should_print = verbose if verbose is not None else self.verbose

        if should_print:
            self._print_search_header(query, mode, rerank_strategy, top_k, final_k)

        # 1. Retrieve candidates
        query_vector = None
        semantic_chunks, bm25_chunks = [], []

        if mode in ["semantic", "hybrid"]:
            query_vector, semantic_chunks = self._semantic_search(query, index, top_k)
            if should_print:
                self._print_candidates("Semantic Search", semantic_chunks)

        if mode in ["bm25", "hybrid"]:
            bm25_chunks = self._bm25_search(query, top_k)
            if should_print:
                self._print_candidates("BM25 Search", bm25_chunks)

        # 2. Fuse / rerank
        if rerank_strategy == "rrf":
            candidates = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
            label = "RRF"
        elif rerank_strategy == "cross-encoder":
            combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))  # dedupe, keep order
            candidates = self._cross_encoder_rerank(query, combined, final_k)
            label = "Cross-Encoder"
        else:  # "none"
            candidates = list(dict.fromkeys(semantic_chunks + bm25_chunks))[:final_k]
            label = "No Reranking"

        # 3. MMR diversity filter (applied after reranking)
        if use_mmr and candidates:
            if query_vector is None:
                query_vector = self.embed_model.encode(query)
            # BUGFIX: top_k was hard-coded to 3, silently ignoring the caller's
            # final_k. Honor final_k so callers get the number of results asked for.
            candidates = self._maximal_marginal_relevance(query_vector, candidates,
                                                          lambda_param=lambda_param, top_k=final_k)
            label += " + MMR"

        if should_print:
            self._print_final_results(candidates, label)

        return candidates

    # ------------------------------------------------------------------
    # Printing
    # ------------------------------------------------------------------

    def _print_search_header(self, query, mode, rerank_strategy, top_k, final_k):
        """Print the banner describing the search about to run."""
        print("\n" + "="*80)
        print(f" SEARCH QUERY: {query}")
        print(f"Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
        print(f"Top-K: {top_k} | Final-K: {final_k}")
        print("-" * 80)

    def _print_candidates(self, label, chunks, preview_n=3):
        """Print the candidate count plus a short preview of the first few."""
        print(f"{label}: Retrieved {len(chunks)} candidates")
        for i, chunk in enumerate(chunks[:preview_n]):
            preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
            print(f" [{i}] {preview}")

    def _print_final_results(self, results, strategy_label):
        """Print the final ranked results with the strategy that produced them."""
        print(f"\n Final {len(results)} Results ({strategy_label}):")
        for i, chunk in enumerate(results):
            preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
            print(f" [{i+1}] {preview}")
        print("="*80)
vector_db.py CHANGED
@@ -22,12 +22,81 @@ def get_pinecone_index(api_key, index_name, dimension=384, metric="cosine"):
22
 
23
  return pc.Index(index_name)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def upsert_to_pinecone(index, chunks, batch_size=100):
26
- """Upserts chunks to Pinecone in manageable batches."""
 
 
 
 
 
 
27
  print(f"Uploading {len(chunks)} chunks to Pinecone...")
28
 
29
  for i in range(0, len(chunks), batch_size):
30
  batch = chunks[i : i + batch_size]
31
  index.upsert(vectors=batch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- print("✅ Upsert complete.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  return pc.Index(index_name)
24
 
25
def prepare_vectors_for_upsert(final_chunks):
    """Convert final_chunks into the record format Pinecone's upsert expects.

    Only the known metadata keys are carried over; anything else in a chunk's
    metadata is dropped.
    """
    metadata_keys = ('text', 'title', 'url', 'chunk_index',
                     'technique', 'chunk_size', 'total_chunks')
    return [
        {
            'id': chunk['id'],
            'values': chunk['values'],  # the embedding vector
            'metadata': {key: chunk['metadata'][key] for key in metadata_keys},
        }
        for chunk in final_chunks
    ]
43
+
44
def upsert_to_pinecone(index, chunks, batch_size=100):
    """Upserts chunks to Pinecone in manageable batches.

    Args:
        index: Pinecone index object.
        chunks: List of vector dicts (as returned by prepare_vectors_for_upsert).
        batch_size: Number of vectors to upsert per request.
    """
    total = len(chunks)
    print(f"Uploading {total} chunks to Pinecone...")

    num_batches = (total - 1) // batch_size + 1 if total else 0
    for batch_num, start in enumerate(range(0, total, batch_size), start=1):
        batch = chunks[start:start + batch_size]
        index.upsert(vectors=batch)
        print(f" Uploaded batch {batch_num}/{num_batches} ({len(batch)} vectors)")

    print(" Upsert complete.")
60
+
61
def refresh_pinecone_index(index, final_chunks, batch_size=100):
    """Helper function to refresh index with new chunks.

    Compares the index's vector count to the expected count and upserts only
    when the index is empty. A mismatch is reported but not repaired.

    Args:
        index: Pinecone index object.
        final_chunks: List of chunk dictionaries with embeddings.
        batch_size: Batch size for upsert.

    Returns:
        True if an upsert was performed, False otherwise.
    """
    try:
        observed = index.describe_index_stats().get('total_vector_count', 0)
        wanted = len(final_chunks)

        print(f"\n Current vectors in index: {observed}")
        print(f" Expected vectors: {wanted}")

        if observed == 0:
            # Empty index: upload everything, then read stats again to confirm.
            print(" Index is empty. Preparing vectors for upsert...")
            upsert_to_pinecone(index, prepare_vectors_for_upsert(final_chunks), batch_size)
            refreshed = index.describe_index_stats()
            print(f" After upsert - Total vectors: {refreshed.get('total_vector_count', 0)}")
            return True

        if observed != wanted:
            print(f" Vector count mismatch. Expected {wanted}, got {observed}")
            print(" Consider recreating the index if you want to refresh.")
            return False

        print(f"ℹ Index already has {observed} vectors. Ready for search.")
        return False

    except Exception as e:
        # Best-effort boundary: report and leave the index untouched.
        print(f"Error checking index stats: {e}")
        return False