07Codex07 committed on
Commit
41d23d8
·
1 Parent(s): 7fdfc47

Changed the context handling and retrieval logic

Browse files
Files changed (2) hide show
  1. chatbot_retriever.py +63 -9
  2. main_api.py +39 -5
chatbot_retriever.py CHANGED
@@ -100,12 +100,13 @@ def ensure_data_dir():
100
  ]
101
 
102
  local_paths = []
 
103
  for f in files:
104
  dest_path = os.path.join(data_dir, f) # ✅ keep real folder structure
105
  os.makedirs(os.path.dirname(dest_path), exist_ok=True)
106
 
107
  if not os.path.exists(dest_path):
108
- print(f"📥 Downloading {f} from Hugging Face (public dataset)...")
109
  downloaded = hf_hub_download(
110
  repo_id=DATASET_REPO,
111
  filename=f,
@@ -113,13 +114,17 @@ def ensure_data_dir():
113
  force_download=True,
114
  )
115
  shutil.copy(downloaded, dest_path) # ✅ copy instead of rename (works inside HF Spaces)
 
116
 
117
  local_paths.append(dest_path)
118
 
119
- # Debug info for verification
120
- print(f"✅ Total files ensured: {len(local_paths)}")
121
- for p in local_paths[:3]:
122
- print(f" {p}")
 
 
 
123
 
124
  return local_paths
125
 
@@ -206,14 +211,24 @@ def load_all_docs(base_dir: str = DATA_DIR) -> List:
206
 
207
  # ---------- Build / load FAISS + BM25 ----------
208
  def build_or_load_indexes(force_reindex: bool = False):
 
209
  if os.getenv("FORCE_REINDEX", "0").lower() in ("1", "true", "yes"):
210
  force_reindex = True
211
 
212
- ensure_data_dir()
 
 
 
 
 
 
 
213
  docs = load_all_docs(DATA_DIR)
214
  if not docs:
215
- logger.warning("No documents found. Returning empty indexes.")
216
  return [], None, [], [], None
 
 
217
 
218
  # chunking
219
  if os.path.exists(CHUNKS_CACHE) and not force_reindex:
@@ -362,9 +377,12 @@ def build_or_load_indexes(force_reindex: bool = False):
362
 
363
  # ---------- Hybrid retrieve ----------
364
  def _ensure_index_built():
 
365
  if not hasattr(hybrid_retrieve, "_index_built") or not hybrid_retrieve._index_built:
 
366
  hybrid_retrieve._chunks, hybrid_retrieve._bm25, hybrid_retrieve._tokenized, hybrid_retrieve._corpus, hybrid_retrieve._faiss = build_or_load_indexes()
367
  hybrid_retrieve._index_built = True
 
368
 
369
 
370
  def _faiss_search(query: str, top_k: int = TOP_K_DOCS, subject: Optional[str] = None):
@@ -408,12 +426,19 @@ def _faiss_search(query: str, top_k: int = TOP_K_DOCS, subject: Optional[str] =
408
 
409
  def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_K_DOCS, max_chars: int = MAX_CONTEXT_CHARS) -> Dict[str, Any]:
410
  if not query:
 
411
  return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []}
412
 
413
  _ensure_index_built()
414
 
415
  chunks = hybrid_retrieve._chunks
416
  bm25 = hybrid_retrieve._bm25
 
 
 
 
 
 
417
 
418
  # BM25
419
  results_bm25 = []
@@ -423,7 +448,11 @@ def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_
423
  scores = bm25.get_scores(q_tokens)
424
  ranked_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
425
  for i in ranked_idx:
426
- results_bm25.append((float(scores[i]), chunks[i].metadata, chunks[i].page_content))
 
 
 
 
427
  except Exception:
428
  logger.exception("BM25 search failed")
429
 
@@ -431,6 +460,7 @@ def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_
431
  results_faiss = []
432
  try:
433
  results_faiss = _faiss_search(query, top_k=top_k, subject=subject)
 
434
  except Exception:
435
  logger.exception("FAISS search failed")
436
 
@@ -458,16 +488,26 @@ def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_
458
 
459
  # compose context parts with headers
460
  context_parts = []
 
461
  for i, t in enumerate(merged_texts):
 
 
 
 
462
  header = f"\n\n===== DOC {i+1} =====\n"
463
  context_parts.append(header + t)
 
464
  context = "\n".join(context_parts).strip()
465
  if not context:
 
 
466
  return {"context": None, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}
467
 
468
  if len(context) > max_chars:
469
  context = context[:max_chars].rstrip() + "..."
 
470
 
 
471
  return {"context": context, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}
472
 
473
 
@@ -477,14 +517,28 @@ def _last_n_user_messages(rows: List[tuple], n: int = 1) -> List[str]:
477
  return users[-1:] # always return ONLY the latest user query # only keep the last one
478
 
479
def retrieve_node_from_rows(rows: List[tuple], top_k: int = TOP_K_DOCS) -> Dict[str, Any]:
    """Run hybrid retrieval for the most recent user message found in *rows*."""
    latest = _last_n_user_messages(rows, n=1)
    query_text = " ".join(latest).strip() if latest else ""
    if not query_text:
        return {"context": None, "direct": False}

    # Subject detection is best-effort; retrieval proceeds without it on failure.
    subject_hint = None
    try:
        subject_hint = detect_subject(query_text)
    except Exception:
        subject_hint = None

    hits = hybrid_retrieve(query_text, subject=subject_hint, top_k=top_k, max_chars=MAX_CONTEXT_CHARS)
    return {"context": hits.get("context"), "direct": False}
 
 
 
 
 
 
 
100
  ]
101
 
102
  local_paths = []
103
+ downloaded_count = 0
104
  for f in files:
105
  dest_path = os.path.join(data_dir, f) # ✅ keep real folder structure
106
  os.makedirs(os.path.dirname(dest_path), exist_ok=True)
107
 
108
  if not os.path.exists(dest_path):
109
+ logger.info(f"📥 Downloading {f} from Hugging Face (public dataset)...")
110
  downloaded = hf_hub_download(
111
  repo_id=DATASET_REPO,
112
  filename=f,
 
114
  force_download=True,
115
  )
116
  shutil.copy(downloaded, dest_path) # ✅ copy instead of rename (works inside HF Spaces)
117
+ downloaded_count += 1
118
 
119
  local_paths.append(dest_path)
120
 
121
+ # Only print summary if files were actually downloaded
122
+ if downloaded_count > 0:
123
+ logger.info(f"✅ Downloaded {downloaded_count} new file(s). Total files ensured: {len(local_paths)}")
124
+ for p in local_paths[:3]:
125
+ logger.debug(f" → {p}")
126
+ else:
127
+ logger.debug(f"✅ All {len(local_paths)} data files already exist")
128
 
129
  return local_paths
130
 
 
211
 
212
  # ---------- Build / load FAISS + BM25 ----------
213
  def build_or_load_indexes(force_reindex: bool = False):
214
+ """Build or load FAISS and BM25 indexes. Returns (chunks, bm25, tokenized, corpus_texts, faiss_data)."""
215
  if os.getenv("FORCE_REINDEX", "0").lower() in ("1", "true", "yes"):
216
  force_reindex = True
217
 
218
+ # Only ensure data dir if files don't exist (check a sample file to avoid repeated calls)
219
+ sample_file = os.path.join(DATA_DIR, "cn.pdf")
220
+ if not os.path.exists(sample_file) or force_reindex:
221
+ logger.info("Data files missing or force_reindex=True, ensuring data directory...")
222
+ ensure_data_dir()
223
+ else:
224
+ logger.debug("Data files already exist, skipping ensure_data_dir()")
225
+
226
  docs = load_all_docs(DATA_DIR)
227
  if not docs:
228
+ logger.warning("No documents found in %s. Returning empty indexes.", DATA_DIR)
229
  return [], None, [], [], None
230
+
231
+ logger.info("Loaded %d document pages from %s", len(docs), DATA_DIR)
232
 
233
  # chunking
234
  if os.path.exists(CHUNKS_CACHE) and not force_reindex:
 
377
 
378
  # ---------- Hybrid retrieve ----------
379
def _ensure_index_built():
    """Ensure indexes are built. Only rebuilds if not already initialized."""
    # Fast path: indexes were already attached to hybrid_retrieve.
    if getattr(hybrid_retrieve, "_index_built", False):
        return

    logger.info("Initializing indexes for hybrid_retrieve...")
    (
        hybrid_retrieve._chunks,
        hybrid_retrieve._bm25,
        hybrid_retrieve._tokenized,
        hybrid_retrieve._corpus,
        hybrid_retrieve._faiss,
    ) = build_or_load_indexes()
    hybrid_retrieve._index_built = True

    chunk_count = len(hybrid_retrieve._chunks) if hybrid_retrieve._chunks else 0
    logger.info("Indexes initialized: %d chunks available", chunk_count)
386
 
387
 
388
  def _faiss_search(query: str, top_k: int = TOP_K_DOCS, subject: Optional[str] = None):
 
426
 
427
  def hybrid_retrieve(query: str, subject: Optional[str] = None, top_k: int = TOP_K_DOCS, max_chars: int = MAX_CONTEXT_CHARS) -> Dict[str, Any]:
428
  if not query:
429
+ logger.warning("hybrid_retrieve called with empty query")
430
  return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []}
431
 
432
  _ensure_index_built()
433
 
434
  chunks = hybrid_retrieve._chunks
435
  bm25 = hybrid_retrieve._bm25
436
+
437
+ if not chunks:
438
+ logger.error("No chunks available for retrieval. Indexes may not be built correctly.")
439
+ return {"context": None, "bm25_docs": [], "faiss_docs": [], "meta": []}
440
+
441
+ logger.debug("Retrieving for query: %s (top_k=%d)", query[:50], top_k)
442
 
443
  # BM25
444
  results_bm25 = []
 
448
  scores = bm25.get_scores(q_tokens)
449
  ranked_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
450
  for i in ranked_idx:
451
+ if i < len(chunks):
452
+ results_bm25.append((float(scores[i]), chunks[i].metadata, chunks[i].page_content))
453
+ logger.debug("BM25 found %d results", len(results_bm25))
454
+ else:
455
+ logger.warning("BM25 index is None")
456
  except Exception:
457
  logger.exception("BM25 search failed")
458
 
 
460
  results_faiss = []
461
  try:
462
  results_faiss = _faiss_search(query, top_k=top_k, subject=subject)
463
+ logger.debug("FAISS found %d results", len(results_faiss))
464
  except Exception:
465
  logger.exception("FAISS search failed")
466
 
 
488
 
489
  # compose context parts with headers
490
  context_parts = []
491
+ seen_texts = set() # Deduplicate by text content
492
  for i, t in enumerate(merged_texts):
493
+ # Deduplicate: skip if we've seen this text before
494
+ if t in seen_texts:
495
+ continue
496
+ seen_texts.add(t)
497
  header = f"\n\n===== DOC {i+1} =====\n"
498
  context_parts.append(header + t)
499
+
500
  context = "\n".join(context_parts).strip()
501
  if not context:
502
+ logger.warning("No context generated from retrieval for query: %s (BM25: %d, FAISS: %d results)",
503
+ query[:50], len(results_bm25), len(results_faiss))
504
  return {"context": None, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}
505
 
506
  if len(context) > max_chars:
507
  context = context[:max_chars].rstrip() + "..."
508
+ logger.debug("Context truncated from %d to %d characters", len("\n".join(context_parts)), max_chars)
509
 
510
+ logger.info("Retrieved context: %d characters from %d documents", len(context), len(context_parts))
511
  return {"context": context, "bm25_docs": results_bm25, "faiss_docs": results_faiss, "meta": merged_meta}
512
 
513
 
 
517
  return users[-1:] # always return ONLY the latest user query # only keep the last one
518
 
519
def retrieve_node_from_rows(rows: List[tuple], top_k: int = TOP_K_DOCS) -> Dict[str, Any]:
    """Retrieve context from documents based on the last user message in rows."""
    users = _last_n_user_messages(rows, n=1)
    query = " ".join(users).strip() if users else ""
    if not query:
        logger.warning("retrieve_node_from_rows: No user query found in rows")
        return {"context": None, "direct": False}

    logger.debug("retrieve_node_from_rows: Query='%s'", query[:50])

    # Best-effort subject detection; any failure falls back to no subject filter.
    subject = None
    try:
        subject = detect_subject(query)
        if subject:
            logger.debug("Detected subject: %s", subject)
    except Exception:
        subject = None

    retrieval = hybrid_retrieve(query, subject=subject, top_k=top_k, max_chars=MAX_CONTEXT_CHARS)
    ctx = retrieval.get("context")
    if ctx:
        logger.info("retrieve_node_from_rows: Successfully retrieved %d characters of context", len(ctx))
    else:
        logger.warning("retrieve_node_from_rows: No context retrieved for query: %s", query[:50])

    return {"context": ctx, "direct": False}
main_api.py CHANGED
@@ -16,10 +16,8 @@ from memory_store import init_db, save_message, get_last_messages, clear_user_me
16
  from chatbot_retriever import build_or_load_indexes, hybrid_retrieve, retrieve_node_from_rows, load_all_docs, ensure_data_dir # :contentReference[oaicite:5]{index=5}
17
  from chatbot_graph import SYSTEM_PROMPT, call_llm, _extract_answer_from_response # :contentReference[oaicite:6]{index=6}
18
 
19
- ensure_data_dir()
20
  # ----------------- CORS SETUP -----------------
21
- from fastapi.middleware.cors import CORSMiddleware
22
- ensure_data_dir()
23
  app = FastAPI(title="RAG Chat Backend", version="1.0")
24
 
25
  from fastapi.middleware.cors import CORSMiddleware
@@ -73,13 +71,42 @@ class RetrieveResponse(BaseModel):
73
  def ensure_indexes(force_reindex: bool = False):
74
  """
75
  Build or load indexes synchronously. This wraps build_or_load_indexes from chatbot_retriever.
 
76
  """
 
 
 
 
 
77
  if INDEXES["built"] and not force_reindex:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return INDEXES["info"]
 
79
  try:
80
  chunks, bm25, tokenized, corpus_texts, faiss_data = build_or_load_indexes(force_reindex=force_reindex)
 
 
 
 
 
 
 
81
  INDEXES["built"] = True
82
  INDEXES["info"] = {"chunks_len": len(chunks) if chunks else 0, "corpus_len": len(corpus_texts) if corpus_texts else 0}
 
83
  return INDEXES["info"]
84
  except Exception:
85
  logger.exception("Index build/load failed")
@@ -221,8 +248,12 @@ def chat(req: ChatRequest):
221
  try:
222
  retrieved = retrieve_node_from_rows(rows)
223
  context = retrieved.get("context")
224
- except Exception:
225
- logger.exception("retriever call failed")
 
 
 
 
226
  context = None
227
 
228
  # 5) build system prompt content
@@ -242,6 +273,9 @@ def chat(req: ChatRequest):
242
  system_content = SYSTEM_PROMPT
243
  if trimmed_context:
244
  system_content += "\n\n===== RETRIEVED CONTEXT =====\n" + trimmed_context
 
 
 
245
  # build prompt messages as list of simple dicts (call_llm expects same message format as in chatbot_graph)
246
  # chatbot_graph.call_llm expects langchain messages (SystemMessage/HumanMessage) — we built that in original file.
247
  # create messages as minimal objects that call_llm can accept (we rely on original call_llm).
 
16
  from chatbot_retriever import build_or_load_indexes, hybrid_retrieve, retrieve_node_from_rows, load_all_docs, ensure_data_dir # :contentReference[oaicite:5]{index=5}
17
  from chatbot_graph import SYSTEM_PROMPT, call_llm, _extract_answer_from_response # :contentReference[oaicite:6]{index=6}
18
 
 
19
  # ----------------- CORS SETUP -----------------
20
+ from fastapi.middleware.cors import CORSMiddleware
 
21
  app = FastAPI(title="RAG Chat Backend", version="1.0")
22
 
23
  from fastapi.middleware.cors import CORSMiddleware
 
71
  def ensure_indexes(force_reindex: bool = False):
72
  """
73
  Build or load indexes synchronously. This wraps build_or_load_indexes from chatbot_retriever.
74
+ Also initializes hybrid_retrieve module variables to avoid reindexing.
75
  """
76
+ # Check if hybrid_retrieve already has indexes built (avoid duplicate work)
77
+ if hasattr(hybrid_retrieve, "_index_built") and hybrid_retrieve._index_built and not force_reindex:
78
+ if INDEXES["built"]:
79
+ return INDEXES["info"]
80
+
81
  if INDEXES["built"] and not force_reindex:
82
+ # Ensure hybrid_retrieve module variables are also set
83
+ if not hasattr(hybrid_retrieve, "_index_built") or not hybrid_retrieve._index_built:
84
+ # Indexes exist but hybrid_retrieve wasn't initialized, reload them
85
+ try:
86
+ chunks, bm25, tokenized, corpus_texts, faiss_data = build_or_load_indexes(force_reindex=False)
87
+ hybrid_retrieve._chunks = chunks
88
+ hybrid_retrieve._bm25 = bm25
89
+ hybrid_retrieve._tokenized = tokenized
90
+ hybrid_retrieve._corpus = corpus_texts
91
+ hybrid_retrieve._faiss = faiss_data
92
+ hybrid_retrieve._index_built = True
93
+ logger.info("Initialized hybrid_retrieve module variables from existing indexes")
94
+ except Exception:
95
+ logger.exception("Failed to initialize hybrid_retrieve variables")
96
  return INDEXES["info"]
97
+
98
  try:
99
  chunks, bm25, tokenized, corpus_texts, faiss_data = build_or_load_indexes(force_reindex=force_reindex)
100
+ # Set module-level variables in hybrid_retrieve to avoid rebuilding
101
+ hybrid_retrieve._chunks = chunks
102
+ hybrid_retrieve._bm25 = bm25
103
+ hybrid_retrieve._tokenized = tokenized
104
+ hybrid_retrieve._corpus = corpus_texts
105
+ hybrid_retrieve._faiss = faiss_data
106
+ hybrid_retrieve._index_built = True
107
  INDEXES["built"] = True
108
  INDEXES["info"] = {"chunks_len": len(chunks) if chunks else 0, "corpus_len": len(corpus_texts) if corpus_texts else 0}
109
+ logger.info("Indexes built/loaded: %d chunks, %d corpus texts", INDEXES["info"]["chunks_len"], INDEXES["info"]["corpus_len"])
110
  return INDEXES["info"]
111
  except Exception:
112
  logger.exception("Index build/load failed")
 
248
  try:
249
  retrieved = retrieve_node_from_rows(rows)
250
  context = retrieved.get("context")
251
+ if context:
252
+ logger.info("Retrieved context: %d characters from documents", len(context))
253
+ else:
254
+ logger.warning("Retriever returned empty context for query: %s", req.message)
255
+ except Exception as e:
256
+ logger.exception("retriever call failed: %s", e)
257
  context = None
258
 
259
  # 5) build system prompt content
 
273
  system_content = SYSTEM_PROMPT
274
  if trimmed_context:
275
  system_content += "\n\n===== RETRIEVED CONTEXT =====\n" + trimmed_context
276
+ logger.debug("Added context to system prompt: %d characters", len(trimmed_context))
277
+ else:
278
+ logger.warning("No context to add to system prompt for query: %s", req.message)
279
  # build prompt messages as list of simple dicts (call_llm expects same message format as in chatbot_graph)
280
  # chatbot_graph.call_llm expects langchain messages (SystemMessage/HumanMessage) — we built that in original file.
281
  # create messages as minimal objects that call_llm can accept (we rely on original call_llm).