NavyDevilDoc committed on
Commit
7841205
·
verified ·
1 Parent(s): c88e290

Update src/rag_engine.py

Browse files
Files changed (1) hide show
  1. src/rag_engine.py +51 -52
src/rag_engine.py CHANGED
@@ -3,7 +3,7 @@ from langchain_chroma import Chroma
3
  from langchain_huggingface import HuggingFaceEmbeddings
4
  from sentence_transformers import CrossEncoder
5
  from core.ChunkingManager import ChunkingManager, ChunkingStrategy
6
- import tracker # To trigger syncs
7
 
8
  # --- CONFIGURATION ---
9
  UPLOAD_DIR = "/tmp/rag_uploads"
@@ -12,14 +12,11 @@ EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
12
  RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
13
 
14
  # --- LAZY LOADING SINGLETONS ---
15
- # We use these globals to store the models once loaded, so we don't reload them
16
- # every time a function is called, but we also don't load them on import.
17
  _embedding_fn = None
18
  _reranker = None
19
  _chunk_manager = None
20
 
21
  def get_embedding_function():
22
- """Lazy loads the embedding model only when needed."""
23
  global _embedding_fn
24
  if _embedding_fn is None:
25
  print("⚙️ Loading Embedding Model...")
@@ -27,7 +24,6 @@ def get_embedding_function():
27
  return _embedding_fn
28
 
29
  def get_reranker_model():
30
- """Lazy loads the CrossEncoder only when needed."""
31
  global _reranker
32
  if _reranker is None:
33
  print("⚙️ Loading Reranker Model...")
@@ -35,7 +31,6 @@ def get_reranker_model():
35
  return _reranker
36
 
37
  def get_chunk_manager():
38
- """Lazy loads the Chunking Manager."""
39
  global _chunk_manager
40
  if _chunk_manager is None:
41
  print("⚙️ Loading Chunk Manager...")
@@ -44,8 +39,6 @@ def get_chunk_manager():
44
 
45
  # --- DATABASE OPERATIONS ---
46
  def get_vectorstore(username):
47
- """Returns the persistent ChromaDB for a SPECIFIC USER."""
48
- # Safety: Ensure username doesn't contain path traversal characters
49
  safe_username = os.path.basename(username)
50
  user_db_path = os.path.join(DB_ROOT, safe_username)
51
 
@@ -59,14 +52,10 @@ def get_vectorstore(username):
59
  )
60
 
61
  def save_uploaded_file(uploaded_file):
62
- """Saves upload to temp, sanitizing the filename."""
63
  if not os.path.exists(UPLOAD_DIR):
64
  os.makedirs(UPLOAD_DIR)
65
-
66
- # SECURITY FIX: Sanitize filename to prevent directory traversal
67
  safe_filename = os.path.basename(uploaded_file.name)
68
  file_path = os.path.join(UPLOAD_DIR, safe_filename)
69
-
70
  with open(file_path, "wb") as f:
71
  f.write(uploaded_file.getbuffer())
72
  return file_path
@@ -82,7 +71,6 @@ def process_and_add_document(file_path, username, strategy="paragraph"):
82
  }
83
  selected_strategy = strat_map.get(strategy, ChunkingStrategy.PARAGRAPH)
84
 
85
- # Use the lazy-loaded chunk manager
86
  manager = get_chunk_manager()
87
  chunks = manager.process_document(
88
  file_path=file_path,
@@ -93,11 +81,14 @@ def process_and_add_document(file_path, username, strategy="paragraph"):
93
  if not chunks:
94
  return False, "No text extracted. Is the file empty/scanned?"
95
 
 
 
 
 
96
  print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
97
  db = get_vectorstore(username)
98
  db.add_documents(chunks)
99
 
100
- # Sync immediately
101
  tracker.upload_user_db(username)
102
 
103
  if os.path.exists(file_path):
@@ -110,50 +101,56 @@ def process_and_add_document(file_path, username, strategy="paragraph"):
110
  return False, str(e)
111
 
112
  # --- RETRIEVAL ENGINE ---
113
- def search_knowledge_base(query, username, k=10):
114
  """
115
  Two-Stage Retrieval System (RAG):
116
  1. Retrieval: Get 10 candidates via fast Vector Search.
117
  2. Reranking: Sort them via Cross-Encoder (Slow/Precise).
118
  3. Return top k.
119
  """
120
- db = get_vectorstore(username)
121
- reranker = get_reranker_model()
122
-
123
- # 1. Broad Search (Retrieve more than needed to filter later)
124
- results = db.similarity_search(query, k=10)
125
-
126
- if not results:
127
- return []
128
 
129
- # 2. Reranking
130
- # Prepare pairs: [[Query, Text1], [Query, Text2]...]
131
- passages = [doc.page_content for doc in results]
132
- ranks = reranker.rank(query, passages)
133
-
134
- # 3. Sort and Filter
135
- # Reranker returns list of dicts: {'corpus_id': 0, 'score': 0.9}
136
- top_results = []
137
-
138
- # Sort ranks by score descending just to be safe (though .rank() usually sorts)
139
- sorted_ranks = sorted(ranks, key=lambda x: x['score'], reverse=True)
140
-
141
- for rank in sorted_ranks[:k]:
142
- doc_index = rank['corpus_id']
143
- doc = results[doc_index]
144
- # Append score for transparency
145
- doc.metadata["relevance_score"] = round(rank['score'], 4)
146
- top_results.append(doc)
147
 
148
- return top_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  def list_documents(username):
151
- """
152
- Returns a list of unique files currently in the user's database.
153
- WARNING: This pulls all metadata. Performance degrades >10k chunks.
154
- """
155
  try:
156
  db = get_vectorstore(username)
 
 
 
 
157
  data = db.get()
158
  metadatas = data['metadatas']
159
 
@@ -162,9 +159,16 @@ def list_documents(username):
162
  for meta in metadatas:
163
  src = meta.get('source', 'unknown')
164
  filename = os.path.basename(src)
 
 
165
 
166
  if src not in file_stats:
167
- file_stats[src] = {'source': src, 'filename': filename, 'chunks': 0}
 
 
 
 
 
168
  file_stats[src]['chunks'] += 1
169
 
170
  return list(file_stats.values())
@@ -174,21 +178,16 @@ def list_documents(username):
174
  return []
175
 
176
  def delete_document(username, source_path):
177
- """Removes all chunks associated with a specific source file."""
178
  try:
179
  print(f"🗑️ Deleting {source_path} for {username}...")
180
  db = get_vectorstore(username)
181
-
182
  db.delete(where={"source": source_path})
183
-
184
  tracker.upload_user_db(username)
185
  return True, f"Deleted {os.path.basename(source_path)}"
186
-
187
  except Exception as e:
188
  return False, str(e)
189
 
190
  def reset_knowledge_base(username):
191
- """Nuke option: Clears the entire database for the user."""
192
  try:
193
  db = get_vectorstore(username)
194
  db.delete_collection()
 
3
  from langchain_huggingface import HuggingFaceEmbeddings
4
  from sentence_transformers import CrossEncoder
5
  from core.ChunkingManager import ChunkingManager, ChunkingStrategy
6
+ import tracker
7
 
8
  # --- CONFIGURATION ---
9
  UPLOAD_DIR = "/tmp/rag_uploads"
 
12
  RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
13
 
14
  # --- LAZY LOADING SINGLETONS ---
 
 
15
  _embedding_fn = None
16
  _reranker = None
17
  _chunk_manager = None
18
 
19
  def get_embedding_function():
 
20
  global _embedding_fn
21
  if _embedding_fn is None:
22
  print("⚙️ Loading Embedding Model...")
 
24
  return _embedding_fn
25
 
26
  def get_reranker_model():
 
27
  global _reranker
28
  if _reranker is None:
29
  print("⚙️ Loading Reranker Model...")
 
31
  return _reranker
32
 
33
  def get_chunk_manager():
 
34
  global _chunk_manager
35
  if _chunk_manager is None:
36
  print("⚙️ Loading Chunk Manager...")
 
39
 
40
  # --- DATABASE OPERATIONS ---
41
  def get_vectorstore(username):
 
 
42
  safe_username = os.path.basename(username)
43
  user_db_path = os.path.join(DB_ROOT, safe_username)
44
 
 
52
  )
53
 
54
  def save_uploaded_file(uploaded_file):
 
55
  if not os.path.exists(UPLOAD_DIR):
56
  os.makedirs(UPLOAD_DIR)
 
 
57
  safe_filename = os.path.basename(uploaded_file.name)
58
  file_path = os.path.join(UPLOAD_DIR, safe_filename)
 
59
  with open(file_path, "wb") as f:
60
  f.write(uploaded_file.getbuffer())
61
  return file_path
 
71
  }
72
  selected_strategy = strat_map.get(strategy, ChunkingStrategy.PARAGRAPH)
73
 
 
74
  manager = get_chunk_manager()
75
  chunks = manager.process_document(
76
  file_path=file_path,
 
81
  if not chunks:
82
  return False, "No text extracted. Is the file empty/scanned?"
83
 
84
+ # FIX #1: Tag every chunk with the strategy used
85
+ for chunk in chunks:
86
+ chunk.metadata["strategy"] = strategy
87
+
88
  print(f"💾 Indexing {len(chunks)} chunks into Vector DB...")
89
  db = get_vectorstore(username)
90
  db.add_documents(chunks)
91
 
 
92
  tracker.upload_user_db(username)
93
 
94
  if os.path.exists(file_path):
 
101
  return False, str(e)
102
 
103
  # --- RETRIEVAL ENGINE ---
104
+ def search_knowledge_base(query, username, k=3):
105
  """
106
  Two-Stage Retrieval System (RAG):
107
  1. Retrieval: Get 10 candidates via fast Vector Search.
108
  2. Reranking: Sort them via Cross-Encoder (Slow/Precise).
109
  3. Return top k.
110
  """
111
+ try:
112
+ db = get_vectorstore(username)
113
+
114
+ # FIX #3: Graceful handling for empty/missing DB
115
+ # If the collection is empty, Chroma sometimes throws an error or returns nothing.
116
+ # We check count first to be safe.
117
+ if db._collection.count() == 0:
118
+ return []
119
 
120
+ reranker = get_reranker_model()
121
+
122
+ # 1. Broad Search
123
+ results = db.similarity_search(query, k=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ if not results:
126
+ return []
127
+
128
+ # 2. Reranking
129
+ passages = [doc.page_content for doc in results]
130
+ ranks = reranker.rank(query, passages)
131
+
132
+ top_results = []
133
+ sorted_ranks = sorted(ranks, key=lambda x: x['score'], reverse=True)
134
+
135
+ for rank in sorted_ranks[:k]:
136
+ doc_index = rank['corpus_id']
137
+ doc = results[doc_index]
138
+ doc.metadata["relevance_score"] = round(rank['score'], 4)
139
+ top_results.append(doc)
140
+
141
+ return top_results
142
+
143
+ except Exception as e:
144
+ print(f"⚠️ Search Error (likely empty DB): {e}")
145
+ return []
146
 
147
  def list_documents(username):
 
 
 
 
148
  try:
149
  db = get_vectorstore(username)
150
+ # Check if empty before fetching to prevent errors
151
+ if db._collection.count() == 0:
152
+ return []
153
+
154
  data = db.get()
155
  metadatas = data['metadatas']
156
 
 
159
  for meta in metadatas:
160
  src = meta.get('source', 'unknown')
161
  filename = os.path.basename(src)
162
+ # FIX #2: Retrieve the strategy (Default to 'unknown' for old docs)
163
+ strat = meta.get('strategy', 'unknown')
164
 
165
  if src not in file_stats:
166
+ file_stats[src] = {
167
+ 'source': src,
168
+ 'filename': filename,
169
+ 'chunks': 0,
170
+ 'strategy': strat
171
+ }
172
  file_stats[src]['chunks'] += 1
173
 
174
  return list(file_stats.values())
 
178
  return []
179
 
180
  def delete_document(username, source_path):
 
181
  try:
182
  print(f"🗑️ Deleting {source_path} for {username}...")
183
  db = get_vectorstore(username)
 
184
  db.delete(where={"source": source_path})
 
185
  tracker.upload_user_db(username)
186
  return True, f"Deleted {os.path.basename(source_path)}"
 
187
  except Exception as e:
188
  return False, str(e)
189
 
190
  def reset_knowledge_base(username):
 
191
  try:
192
  db = get_vectorstore(username)
193
  db.delete_collection()