ramailkk commited on
Commit
a865c33
·
1 Parent(s): 5652550

evaluator methods added

Browse files
Files changed (6) hide show
  1. config.yaml +34 -25
  2. config_loader.py +27 -0
  3. main.py +66 -59
  4. query_only.py +1 -1
  5. retriever/evaluator.py +105 -0
  6. vector_db.py +66 -74
config.yaml CHANGED
@@ -1,37 +1,46 @@
1
- # Pipeline Configuration for ArXiv RAG
2
- project_name: "arxiv_cyber_advisor"
 
3
 
4
- # Stage 1: Data Acquisition
5
- data_ingestion:
6
  category: "cs.AI"
7
- limit: 5
8
- save_local: true
9
- raw_data_path: "data/raw_arxiv.csv"
10
 
11
- # Stage 2: Processing & Embedding
12
- embedding:
13
- model_name: "all-MiniLM-L6-v2"
14
- device: "cpu" # Change to "cuda" if testing on a GPU machine
15
-
16
- chunking:
17
- technique: "recursive"
18
  chunk_size: 500
19
  chunk_overlap: 50
20
 
21
- # Stage 3: Vector Database (Pinecone)
22
  vector_db:
23
- index_name: "arxiv-index"
24
  dimension: 384
25
  metric: "cosine"
 
26
 
27
- # Stage 4: Retrieval & Re-ranking
28
  retrieval:
29
- top_k_hybrid: 10
30
- rerank_model: "cross-encoder/ms-marco-MiniLM-L-6-v2"
31
- top_k_final: 3
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Stage 5: Generation (LLM)
34
- llm:
35
- model_id: "meta-llama/Meta-Llama-3-8B-Instruct"
36
- max_new_tokens: 500
37
- temperature: 0.1
 
 
 
1
+ # ------------------------------------------------------------------
2
+ # RAG TOURNAMENT CONFIGURATION
3
+ # ------------------------------------------------------------------
4
 
5
+ project:
6
+ name: "arxiv-research-rag"
7
  category: "cs.AI"
8
+ doc_limit: 5
 
 
9
 
10
+ processing:
11
+ # Embedding model used for both vector db and evaluator similarity
12
+ embedding_model: "all-MiniLM-L6-v2"
13
+ # Options: sentence, recursive, semantic, fixed
14
+ technique: "recursive"
15
+ # NOTE: all-MiniLM-L6-v2 truncates at 256 tokens; chunk_size below is in characters (500), so oversized chunks may be truncated at encode time — verify.
 
16
  chunk_size: 500
17
  chunk_overlap: 50
18
 
 
19
  vector_db:
20
+ base_index_name: "arxiv-tournament"
21
  dimension: 384
22
  metric: "cosine"
23
+ batch_size: 100
24
 
 
25
  retrieval:
26
+ # Options: hybrid, semantic, bm25
27
+ mode: "hybrid"
28
+ # Options: cross-encoder, rrf
29
+ rerank_strategy: "cross-encoder"
30
+ use_mmr: true
31
+ top_k: 10
32
+ final_k: 5
33
+
34
+ generation:
35
+ temperature: 0.1
36
+ max_new_tokens: 512
37
+ # The model used to Judge the others
38
+ judge_model: "Llama-3-8B"
39
 
40
+ # List of contestants in the tournament
41
+ models:
42
+ - "Llama-3-8B"
43
+ - "Mistral-7B"
44
+ - "Qwen-2.5"
45
+ - "DeepSeek-V3"
46
+ - "TinyAya"
config_loader.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import yaml
from pathlib import Path


class RAGConfig:
    """Thin read-only wrapper around config.yaml.

    Each top-level section of the YAML file is exposed as a property so
    callers write ``cfg.retrieval['top_k']`` instead of digging through
    the raw dict.
    """

    def __init__(self, config_path="config.yaml"):
        """Load and parse the YAML config.

        Raises FileNotFoundError if the file is missing and
        yaml.YAMLError if it is malformed.
        """
        # pathlib accepts both str and Path; explicit encoding avoids
        # platform-dependent defaults.
        path = Path(config_path)
        with path.open('r', encoding='utf-8') as f:
            self.data = yaml.safe_load(f)

    @property
    def project(self):
        """The `project` section (name, category, doc_limit)."""
        return self.data['project']

    @property
    def processing(self):
        """The `processing` section (embedding model + chunking params)."""
        return self.data['processing']

    @property
    def db(self):
        """The `vector_db` section (base index name, dimension, metric, batch size)."""
        return self.data['vector_db']

    @property
    def retrieval(self):
        """The `retrieval` section (mode, rerank strategy, k values)."""
        return self.data['retrieval']

    @property
    def gen(self):
        """The `generation` section (temperature, max tokens, judge model)."""
        return self.data['generation']

    @property
    def model_list(self):
        """The list of contestant model names for the tournament."""
        return self.data['models']


# Shared singleton; consumers do `from config_loader import cfg`.
cfg = RAGConfig()
main.py CHANGED
@@ -1,96 +1,103 @@
1
  import os
2
  from dotenv import load_dotenv
 
3
 
4
  from vector_db import get_pinecone_index, refresh_pinecone_index
5
  from retriever.retriever import HybridRetriever
6
  from retriever.generator import RAGGenerator
7
  from retriever.processor import ChunkProcessor
 
8
  import data_loader as dl
9
 
 
10
  from models.llama_3_8b import Llama3_8B
11
  from models.mistral_7b import Mistral_7b
12
  from models.qwen_2_5 import Qwen2_5
13
  from models.deepseek_v3 import DeepSeek_V3
14
  from models.tiny_aya import TinyAya
15
 
 
 
 
 
 
 
 
 
16
  load_dotenv()
17
 
18
  def main():
 
 
 
 
 
 
 
 
 
19
 
20
- # ------------------------------------------------------------------
21
- # 0. Configuration
22
- # Query defined here
23
- # ------------------------------------------------------------------
24
- hf_token = os.getenv("HF_TOKEN")
25
- pinecone_api_key = os.getenv("PINECONE_API_KEY")
26
- if not pinecone_api_key:
27
- raise ValueError("PINECONE_API_KEY not found in environment variables")
28
-
29
- query = "How do transformers handle long sequences?"
30
-
31
- # ------------------------------------------------------------------
32
- # 1. Data Ingestion
33
- # ------------------------------------------------------------------
34
- raw_data = dl.fetch_arxiv_data(category="cs.AI", limit=5)
35
-
36
- # ------------------------------------------------------------------
37
  # 2. Chunking & Embedding
38
- # ------------------------------------------------------------------
39
- proc = ChunkProcessor(model_name='all-MiniLM-L6-v2', verbose=True)
40
  final_chunks = proc.process(
41
  raw_data,
42
- technique="sentence", # options: fixed, recursive, character, sentence, semantic
43
- chunk_size=500,
44
- chunk_overlap=50
45
  )
46
 
47
- # ------------------------------------------------------------------
48
- # 3. Vector DB
49
- # ------------------------------------------------------------------
50
- index_name = "arxiv-index"
51
- index = get_pinecone_index(pinecone_api_key, index_name, dimension=384, metric="cosine")
52
- refresh_pinecone_index(index, final_chunks, batch_size=100)
 
 
53
 
54
- # ------------------------------------------------------------------
55
  # 4. Retrieval
56
- # ------------------------------------------------------------------
57
- retriever = HybridRetriever(final_chunks, proc.encoder, verbose=True)
58
  context_chunks = retriever.search(
59
- query,
60
- index,
61
- mode="hybrid", # options: bm25, semantic, hybrid
62
- rerank_strategy="cross-encoder", # options: cross-encoder, rrf
63
- use_mmr=True,
64
- top_k=10,
65
- final_k=5
66
  )
67
 
68
- if not context_chunks:
69
- print("No context chunks retrieved. Check your index and query.")
70
- return
71
-
72
- # ------------------------------------------------------------------
73
- # 5. Generation
74
- # ------------------------------------------------------------------
75
  rag_engine = RAGGenerator()
76
-
77
 
78
-
79
- models = {
80
- "Llama-3-8B": Llama3_8B(token=hf_token),
81
- "Mistral-7B": Mistral_7b(token=hf_token),
82
- "Qwen-2.5": Qwen2_5(token=hf_token),
83
- "DeepSeek-V3": DeepSeek_V3(token=hf_token),
84
- "TinyAya": TinyAya(token=hf_token)
85
- }
86
-
87
- for name, model in models.items():
88
- print(f"\n--- {name} ---")
89
  try:
90
- print(rag_engine.get_answer(model, query, context_chunks, temperature=0.1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  except Exception as e:
92
- print(f"Error: {e}")
93
 
 
94
 
95
  if __name__ == "__main__":
96
  main()
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from config_loader import cfg # Import the Mother Config
4
 
5
  from vector_db import get_pinecone_index, refresh_pinecone_index
6
  from retriever.retriever import HybridRetriever
7
  from retriever.generator import RAGGenerator
8
  from retriever.processor import ChunkProcessor
9
+ from retriever.evaluator import RAGEvaluator
10
  import data_loader as dl
11
 
12
+ # Import fleet mapping
13
  from models.llama_3_8b import Llama3_8B
14
  from models.mistral_7b import Mistral_7b
15
  from models.qwen_2_5 import Qwen2_5
16
  from models.deepseek_v3 import DeepSeek_V3
17
  from models.tiny_aya import TinyAya
18
 
19
# Maps config model names to their wrapper classes (the "fleet").
MODEL_MAP = {
    "Llama-3-8B": Llama3_8B,
    "Mistral-7B": Mistral_7b,
    "Qwen-2.5": Qwen2_5,
    "DeepSeek-V3": DeepSeek_V3,
    "TinyAya": TinyAya
}

load_dotenv()

def main():
    """Run the full RAG tournament.

    Pipeline: ingest arXiv data -> chunk & embed -> (re)populate the
    technique-specific Pinecone index -> retrieve context -> generate an
    answer with every configured model and score it with the judge.
    """
    hf_token = os.getenv("HF_TOKEN")
    pinecone_key = os.getenv("PINECONE_API_KEY")
    if not pinecone_key:
        # Fail fast: every later stage depends on the vector DB.
        raise ValueError("PINECONE_API_KEY not found in environment variables")

    query = "How do transformers handle long sequences?"

    # 1. Data Ingestion (controlled by config)
    raw_data = dl.fetch_arxiv_data(
        category=cfg.project['category'],
        limit=cfg.project['doc_limit']
    )

    # 2. Chunking & Embedding
    proc = ChunkProcessor(model_name=cfg.processing['embedding_model'])
    final_chunks = proc.process(
        raw_data,
        technique=cfg.processing['technique'],
        chunk_size=cfg.processing['chunk_size'],
        chunk_overlap=cfg.processing['chunk_overlap']
    )

    # 3. Vector DB (auto-names index based on technique)
    index = get_pinecone_index(
        pinecone_key,
        cfg.db['base_index_name'],
        technique=cfg.processing['technique'],
        dimension=cfg.db['dimension']
    )
    refresh_pinecone_index(index, final_chunks, batch_size=cfg.db['batch_size'])

    # 4. Retrieval
    retriever = HybridRetriever(final_chunks, proc.encoder)
    context_chunks = retriever.search(
        query, index,
        mode=cfg.retrieval['mode'],
        rerank_strategy=cfg.retrieval['rerank_strategy'],
        use_mmr=cfg.retrieval['use_mmr'],
        top_k=cfg.retrieval['top_k'],
        final_k=cfg.retrieval['final_k']
    )
    if not context_chunks:
        # Without context there is nothing to generate or judge.
        print("No context chunks retrieved. Check your index and query.")
        return

    # 5. Initialization of Contestants — validate names before instantiating
    # so a typo in config.yaml produces a clear error, not a KeyError.
    unknown = [n for n in cfg.model_list if n not in MODEL_MAP]
    if unknown:
        raise ValueError(f"Unknown model(s) in config: {unknown}")
    rag_engine = RAGGenerator()
    models = {name: MODEL_MAP[name](token=hf_token) for name in cfg.model_list}

    # Setup Evaluator with the designated Judge.
    judge_name = cfg.gen['judge_model']
    if judge_name not in models:
        raise ValueError(f"Judge model '{judge_name}' is not in the configured model list")
    evaluator = RAGEvaluator(models[judge_name], proc.encoder)
    tournament_results = {}

    # 6. Tournament Loop — one contestant failing must not stop the rest.
    for name, model_inst in models.items():
        print(f"\n--- Processing {name} ---")
        try:
            # Generation
            answer = rag_engine.get_answer(
                model_inst, query, context_chunks,
                temperature=cfg.gen['temperature']
            )

            # Batch Evaluation
            faith = evaluator.evaluate_faithfulness(answer, context_chunks)
            rel = evaluator.evaluate_relevancy(query, answer)

            tournament_results[name] = {
                "Faithfulness": faith['score'],
                "Relevancy": rel['score'],
                "Claims": faith['details']
            }
        except Exception as e:
            print(f"Error evaluating {name}: {e}")

    # 7. Final scoreboard
    print("\n=== Tournament Results ===")
    for name, scores in tournament_results.items():
        print(f"{name}: Faithfulness={scores['Faithfulness']:.1f} "
              f"Relevancy={scores['Relevancy']:.2f}")

if __name__ == "__main__":
    main()
query_only.py CHANGED
@@ -1,6 +1,6 @@
1
  # This file is for inference without actually embedding documents
2
  # Main does embedding everytime, is redundant for querying.
3
- # made this just to test querying part --@Qamar
4
 
5
  import os
6
  import time
 
1
  # This file is for inference without actually embedding documents
2
  # Main does embedding every time, which is redundant for querying.
3
+ # made this just to test querying part
4
 
5
  import os
6
  import time
retriever/evaluator.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.metrics.pairwise import cosine_similarity
3
+
4
+ class RAGEvaluator:
5
+ def __init__(self, judge_model, embedding_model, verbose=True):
6
+ """
7
+ judge_model: An instance of an LLM class.
8
+ embedding_model: The proc.encoder for similarity checks.
9
+ verbose: If True, uses internal printer functions to show progress.
10
+ """
11
+ self.judge = judge_model
12
+ self.encoder = embedding_model
13
+ self.verbose = verbose
14
+
15
+ # ------------------------------------------------------------------
16
+ # 1. FAITHFULNESS: Claim Extraction & Verification
17
+ # ------------------------------------------------------------------
18
+ def evaluate_faithfulness(self, answer, context_list):
19
+ if self.verbose:
20
+ self._print_extraction_header(len(answer))
21
+
22
+ # --- Step A: Extraction ---
23
+ extraction_prompt = f"Extract a list of independent factual claims from the following answer. Respond ONLY with the claims, one per line. Do not include any introductory text.\nAnswer: {answer}"
24
+ raw_claims = self.judge.generate(extraction_prompt)
25
+ claims = [c.strip() for c in raw_claims.split('\n') if len(c.strip()) > 5]
26
+
27
+ if not claims:
28
+ return {"score": 0, "details": []}
29
+
30
+ # --- Step B: Batch Verification ---
31
+ combined_context = "\n".join(context_list)
32
+ claims_formatted = "\n".join([f"{i+1}. {c}" for i, c in enumerate(claims)])
33
+
34
+ batch_prompt = f"Context: {combined_context}\nClaims: {claims_formatted}\nRespond YES/NO for each."
35
+ raw_verdicts = self.judge.generate(batch_prompt)
36
+ verdict_lines = [v.strip().upper() for v in raw_verdicts.split('\n') if v.strip()]
37
+
38
+ # --- Step C: Scoring & Details ---
39
+ verified_count = 0
40
+ details = []
41
+ for i, claim in enumerate(claims):
42
+ is_supported = "YES" in verdict_lines[i] if i < len(verdict_lines) else False
43
+ if is_supported: verified_count += 1
44
+
45
+ details.append({
46
+ "claim": claim,
47
+ "verdict": "Supported" if is_supported else "Not Supported"
48
+ })
49
+
50
+ score = (verified_count / len(claims)) * 100
51
+
52
+ if self.verbose:
53
+ self._print_faithfulness_results(claims, details, score)
54
+
55
+ return {"score": score, "details": details}
56
+
57
+ # ------------------------------------------------------------------
58
+ # 2. RELEVANCY: Alternate Query Generation
59
+ # ------------------------------------------------------------------
60
+ def evaluate_relevancy(self, query, answer):
61
+ if self.verbose:
62
+ self._print_relevancy_header()
63
+
64
+ # --- Step A: Generation ---
65
+ gen_prompt = f"Generate 3 distinct questions this answer addresses.\nAnswer: {answer}"
66
+ raw_gen = self.judge.generate(gen_prompt)
67
+ gen_queries = [q.strip() for q in raw_gen.split('\n') if '?' in q][:3]
68
+
69
+ if not gen_queries:
70
+ return {"score": 0, "queries": []}
71
+
72
+ # --- Step B: Similarity Logic ---
73
+ original_vec = self.encoder.encode([query])
74
+ generated_vecs = self.encoder.encode(gen_queries)
75
+ similarities = cosine_similarity(original_vec, generated_vecs)[0]
76
+ avg_score = np.mean(similarities)
77
+
78
+ if self.verbose:
79
+ self._print_relevancy_results(query, gen_queries, similarities, avg_score)
80
+
81
+ return {"score": avg_score, "queries": gen_queries}
82
+
83
+ # ------------------------------------------------------------------
84
+ # 3. PRINT HELPERS (Keep the logic above clean)
85
+ # ------------------------------------------------------------------
86
+ def _print_extraction_header(self, length):
87
+ print(f"\n[EVAL] Analyzing Faithfulness...")
88
+ print(f" - Extracting claims from answer ({length} chars)")
89
+
90
+ def _print_faithfulness_results(self, claims, details, score):
91
+ print(f" - Verifying {len(claims)} claims against context...")
92
+ for i, detail in enumerate(details):
93
+ status = "✅" if "Supported" in detail['verdict'] else "❌"
94
+ print(f" {status} Claim {i+1}: {detail['claim'][:75]}...")
95
+ print(f" 🎯 Faithfulness Score: {score:.1f}%")
96
+
97
+ def _print_relevancy_header(self):
98
+ print(f"\n[EVAL] Analyzing Relevancy...")
99
+ print(f" - Generating 3 sample questions addressed by the answer")
100
+
101
+ def _print_relevancy_results(self, query, gen_queries, similarities, avg):
102
+ print(f" - Comparing to original query: '{query}'")
103
+ for i, (q, sim) in enumerate(zip(gen_queries, similarities)):
104
+ print(f" Q{i+1}: {q} (Sim: {sim:.2f})")
105
+ print(f" 🎯 Average Relevancy: {avg:.2f}")
vector_db.py CHANGED
@@ -1,102 +1,94 @@
1
  import time
 
2
  from pinecone import Pinecone, ServerlessSpec
3
 
4
- def get_pinecone_index(api_key, index_name, dimension=384, metric="cosine"):
5
- """Initializes Pinecone and returns the index object, creating it if necessary."""
 
 
 
 
 
 
 
6
  pc = Pinecone(api_key=api_key)
 
 
7
 
8
- # Check if index exists
9
  existing_indexes = [idx.name for idx in pc.list_indexes()]
10
 
11
- if index_name not in existing_indexes:
12
- print(f"Creating new Pinecone index: {index_name}...")
13
  pc.create_index(
14
- name=index_name,
15
  dimension=dimension,
16
  metric=metric,
17
  spec=ServerlessSpec(cloud="aws", region="us-east-1")
18
  )
19
- # Wait for index to be ready
20
- while not pc.describe_index(index_name).status['ready']:
21
  time.sleep(1)
22
 
23
- return pc.Index(index_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def prepare_vectors_for_upsert(final_chunks):
26
- """Convert final_chunks to the format expected by Pinecone upsert"""
27
  vectors = []
28
  for chunk in final_chunks:
 
29
  vectors.append({
30
  'id': chunk['id'],
31
- 'values': chunk['values'], # The embedding vector
32
  'metadata': {
33
- 'text': chunk['metadata']['text'],
34
- 'title': chunk['metadata']['title'],
35
- 'url': chunk['metadata']['url'],
36
- 'chunk_index': chunk['metadata']['chunk_index'],
37
- 'technique': chunk['metadata']['technique'],
38
- 'chunk_size': chunk['metadata']['chunk_size'],
39
- 'total_chunks': chunk['metadata']['total_chunks']
40
  }
41
  })
42
  return vectors
43
 
44
  def upsert_to_pinecone(index, chunks, batch_size=100):
45
- """Upserts chunks to Pinecone in manageable batches.
46
-
47
- Args:
48
- index: Pinecone index object
49
- chunks: List of chunk dictionaries (as returned by prepare_vectors_for_upsert)
50
- batch_size: Number of vectors to upsert in each batch
51
- """
52
- print(f"Uploading {len(chunks)} chunks to Pinecone...")
53
-
54
  for i in range(0, len(chunks), batch_size):
55
  batch = chunks[i : i + batch_size]
56
- index.upsert(vectors=batch)
57
- print(f" Uploaded batch {i//batch_size + 1}/{(len(chunks)-1)//batch_size + 1} ({len(batch)} vectors)")
58
-
59
- print(" Upsert complete.")
60
-
61
- def refresh_pinecone_index(index, final_chunks, batch_size=100):
62
- """Helper function to refresh index with new chunks.
63
-
64
- This function checks if the index has the expected number of vectors,
65
- and upserts if necessary.
66
-
67
- Args:
68
- index: Pinecone index object
69
- final_chunks: List of chunk dictionaries with embeddings
70
- batch_size: Batch size for upsert
71
-
72
- Returns:
73
- Boolean indicating if upsert was performed
74
- """
75
- try:
76
- stats = index.describe_index_stats()
77
- current_vector_count = stats.get('total_vector_count', 0)
78
- expected_vector_count = len(final_chunks)
79
-
80
- print(f"\n Current vectors in index: {current_vector_count}")
81
- print(f" Expected vectors: {expected_vector_count}")
82
-
83
- if current_vector_count == 0:
84
- print(" Index is empty. Preparing vectors for upsert...")
85
- vectors_to_upsert = prepare_vectors_for_upsert(final_chunks)
86
- upsert_to_pinecone(index, vectors_to_upsert, batch_size)
87
-
88
- # Verify upsert
89
- stats = index.describe_index_stats()
90
- print(f" After upsert - Total vectors: {stats.get('total_vector_count', 0)}")
91
- return True
92
- elif current_vector_count != expected_vector_count:
93
- print(f" Vector count mismatch. Expected {expected_vector_count}, got {current_vector_count}")
94
- print(" Consider recreating the index if you want to refresh.")
95
- return False
96
- else:
97
- print(f"ℹ Index already has {current_vector_count} vectors. Ready for search.")
98
- return False
99
-
100
- except Exception as e:
101
- print(f"Error checking index stats: {e}")
102
- return False
 
1
  import time
2
+ import re
3
  from pinecone import Pinecone, ServerlessSpec
4
 
5
def slugify_technique(name):
    """Normalize a technique name into a Pinecone-safe slug.

    Lowercases the input, collapses every run of non-alphanumeric
    characters into a single hyphen, and trims hyphens from the ends
    (e.g. 'Sentence Splitter' -> 'sentence-splitter').
    """
    lowered = name.lower()
    hyphenated = re.sub(r'[^a-z0-9]+', '-', lowered)
    return hyphenated.strip('-')
8
+
9
def get_pinecone_index(api_key, base_name, technique, dimension=384, metric="cosine"):
    """Create (if needed) and return a technique-specific Pinecone index.

    The index name is f"{base_name}-{slug}" where the slug is derived
    from `technique`, so each chunking technique gets its own index.
    """
    client = Pinecone(api_key=api_key)
    slug = slugify_technique(technique)
    full_index_name = f"{base_name}-{slug}"

    known_names = {idx.name for idx in client.list_indexes()}
    if full_index_name not in known_names:
        print(f" Creating specialized index: {full_index_name}...")
        client.create_index(
            name=full_index_name,
            dimension=dimension,
            metric=metric,
            spec=ServerlessSpec(cloud="aws", region="us-east-1")
        )
        # Block until Pinecone reports the fresh index as ready.
        while not client.describe_index(full_index_name).status['ready']:
            time.sleep(1)

    print(f" Using Index: {full_index_name}")
    return client.Index(full_index_name)
+
35
def refresh_pinecone_index(index, final_chunks, batch_size=100):
    """Populate a technique-specific index from `final_chunks`.

    Upserts when the index is empty or holds fewer vectors than
    expected; otherwise leaves it untouched. Returns True if an upsert
    was performed, False otherwise (including on error).
    """
    if not final_chunks:
        print("No chunks provided to refresh.")
        return False

    try:
        stats = index.describe_index_stats()
        existing = stats.get('total_vector_count', 0)
        expected = len(final_chunks)

        print(f" Index Stats -> Existing: {existing} | New Chunks: {expected}")

        if existing == 0:
            print(f"➕ Index is empty. Upserting {expected} vectors...")
        elif existing < expected:
            # Top up: upserts are keyed by id, so re-sending is safe.
            print(f" Vector count mismatch ({existing} < {expected}). Updating index...")
        else:
            print(f" Index is already populated with {existing} vectors. Ready for search.")
            return False

        payload = prepare_vectors_for_upsert(final_chunks)
        upsert_to_pinecone(index, payload, batch_size)
        return True

    except Exception as e:
        print(f" Error refreshing index: {e}")
        return False
+ return False
72
+
73
+ # Utility functions remain the same as previous version
74
def prepare_vectors_for_upsert(final_chunks):
    """Convert processed chunks into Pinecone upsert records.

    Each record keeps the chunk's id and embedding, plus a flattened
    metadata dict with safe defaults for any missing fields.
    """
    records = []
    for entry in final_chunks:
        info = entry.get('metadata', {})
        record = {
            'id': entry['id'],
            'values': entry['values'],
            'metadata': {
                'text': info.get('text', ""),
                'title': info.get('title', ""),
                'url': info.get('url', ""),
                'chunk_index': info.get('chunk_index', 0),
                'technique': info.get('technique', "unknown")
            }
        }
        records.append(record)
    return records
90
 
91
def upsert_to_pinecone(index, chunks, batch_size=100):
    """Upload prepared vectors to `index` in batches of `batch_size`."""
    total = len(chunks)
    start = 0
    while start < total:
        # Slice never overruns: Python clamps the end of the slice.
        index.upsert(vectors=chunks[start:start + batch_size])
        start += batch_size