eaglelandsonce committed on
Commit
d105d84
·
verified ·
1 Parent(s): e813723

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +712 -0
app.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import json
4
+ from typing import List, Tuple, Dict, Any, Optional
5
+
6
+ import chromadb
7
+ from chromadb.config import Settings
8
+ from openai import OpenAI
9
+ import gradio as gr
10
+ from pypdf import PdfReader
11
+
12
+ # Cross-encoder (Hugging Face / sentence-transformers)
13
+ # pip install sentence-transformers torch
14
+ from sentence_transformers import CrossEncoder
15
+
16
+
17
+ # =========================
18
+ # Chroma Client (Persistent)
19
+ # =========================
20
+
21
# Persistent Chroma store at the relative path "chroma_db", with anonymized
# telemetry switched off.
_chroma_settings = Settings(anonymized_telemetry=False)
chroma_client = chromadb.PersistentClient(path="chroma_db", settings=_chroma_settings)

# Single shared collection for all indexed chunks; created on first run and
# configured with cosine space for the HNSW index.
collection = chroma_client.get_or_create_collection(
    name="rag_docs",
    metadata={"hnsw:space": "cosine"},
)
30
+
31
+
32
+ # =========================
33
+ # Cross-Encoder (lazy global)
34
+ # =========================
35
+
36
# Hugging Face model id used for reranking, and the module-level slot that
# caches the loaded model after the first call to get_cross_encoder().
CROSS_ENCODER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
_CROSS_ENCODER: Optional[CrossEncoder] = None


def get_cross_encoder() -> CrossEncoder:
    """Return the shared CrossEncoder, constructing it on first use.

    The model is built from CROSS_ENCODER_MODEL_NAME the first time this is
    called and the same instance is handed back on every later call.
    """
    global _CROSS_ENCODER
    if _CROSS_ENCODER is not None:
        return _CROSS_ENCODER
    _CROSS_ENCODER = CrossEncoder(CROSS_ENCODER_MODEL_NAME)
    return _CROSS_ENCODER
45
+
46
+
47
+ # =========================
48
+ # Helper Functions
49
+ # =========================
50
+
51
def get_openai_client(api_key: str) -> OpenAI:
    """Build an OpenAI client from *api_key* after trimming whitespace.

    Raises:
        ValueError: if the key is missing, empty, or only whitespace.
    """
    trimmed = api_key.strip() if api_key else ""
    if not trimmed:
        raise ValueError("OpenAI API key is missing.")
    return OpenAI(api_key=trimmed)
55
+
56
+
57
def extract_text_from_file(file_path: str) -> str:
    """Return the text content of the file at *file_path*.

    Files whose extension is ".pdf" (case-insensitive) are read page by page
    with PdfReader; pages that yield no text are dropped and the rest are
    joined with newlines. Every other extension is read as UTF-8 text with
    undecodable bytes ignored.
    """
    _, suffix = os.path.splitext(file_path)

    if suffix.lower() == ".pdf":
        page_texts = (page.extract_text() for page in PdfReader(file_path).pages)
        return "\n".join(piece for piece in page_texts if piece)

    # .txt, .md and any unrecognised extension are all treated as plain text.
    with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
        return handle.read()
75
+
76
+
77
def chunk_text(text: str, chunk_size: int = 800, overlap: int = 200) -> List[str]:
    """Split *text* into overlapping fixed-width character windows.

    Carriage-return line endings are normalised to bare newlines first.
    Consecutive chunks share *overlap* characters. Splitting stops as soon as
    a window reaches the end of the text, so no trailing chunk is emitted that
    would be entirely contained in the previous one.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Characters shared between neighbouring chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        The chunks in document order; an empty list for empty text.

    Raises:
        ValueError: if *chunk_size* or *overlap* is out of range. Without this
            check a non-positive step would make the loop never terminate.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size.")

    text = text.replace("\r\n", "\n").replace("\r", "\n")
    step = chunk_size - overlap
    chunks: List[str] = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            # This window already covers the tail; another step would only
            # yield a suffix of the chunk just appended.
            break
        start += step
    return chunks
86
+
87
+
88
def embed_texts(texts: List[str], api_key: str, batch_size: int = 256) -> List[List[float]]:
    """Embed *texts* with the "text-embedding-3-small" model.

    The inputs are sent in slices of at most *batch_size* items rather than in
    one request, so that a large document does not exceed the per-request
    input limits of the embeddings endpoint (see the OpenAI API reference).

    Args:
        texts: Strings to embed.
        api_key: OpenAI API key used to build the client.
        batch_size: Maximum number of inputs per API request; must be positive.

    Returns:
        One embedding vector per input, in the same order as *texts*; an empty
        list when *texts* is empty (no client is created in that case).

    Raises:
        ValueError: if *batch_size* is not positive, or (from
            get_openai_client) if the API key is missing.
    """
    if not texts:
        return []
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")

    client = get_openai_client(api_key)
    embeddings: List[List[float]] = []
    for offset in range(0, len(texts), batch_size):
        resp = client.embeddings.create(
            model="text-embedding-3-small",
            input=texts[offset:offset + batch_size],
        )
        embeddings.extend(d.embedding for d in resp.data)
    return embeddings
97
+
98
+
99
def add_documents_to_chroma(file_paths: List[str], api_key: str) -> str:
    """Extract, chunk, embed and store each file in the shared collection.

    Files whose extracted text is empty or whitespace-only are skipped. Every
    chunk is stored with a fresh ``"<file name>-<uuid4>"`` id and a
    ``{"source": <file name>}`` metadata entry.

    Args:
        file_paths: Paths of the files to index.
        api_key: OpenAI API key used for embedding.

    Returns:
        A human-readable status line. The file count reflects files actually
        indexed (not merely provided), and skipped files are called out.
    """
    if not file_paths:
        return "⚠️ No files were provided."

    total_chunks = 0
    indexed_files = 0
    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        raw_text = extract_text_from_file(file_path)

        if not raw_text.strip():
            continue

        chunks = chunk_text(raw_text)
        embeddings = embed_texts(chunks, api_key)

        ids = [f"{file_name}-{uuid.uuid4()}" for _ in chunks]
        metadatas = [{"source": file_name} for _ in chunks]

        collection.add(
            ids=ids,
            documents=chunks,
            metadatas=metadatas,
            embeddings=embeddings,
        )

        total_chunks += len(chunks)
        indexed_files += 1

    count = collection.count()
    if indexed_files == 0:
        return (
            "⚠️ No extractable text was found in the provided file(s). "
            f"Collection still has {count} vectors."
        )

    status = (
        f"✅ Indexed {indexed_files} file(s) into Chroma with {total_chunks} chunks. "
        f"Collection now has {count} vectors."
    )
    skipped = len(file_paths) - indexed_files
    if skipped:
        status += f" Skipped {skipped} file(s) with no extractable text."
    return status
131
+
132
+
133
+ # =========================
134
+ # Query Expansion
135
+ # =========================
136
+
137
def query_expansion(user_query: str, api_key: str) -> List[str]:
    """Ask the chat model for alternative phrasings of *user_query*.

    The model is requested to reply with a JSON object of the form
    ``{"expanded": [...]}``. Anything else (missing content, invalid JSON, a
    non-object payload, a non-list "expanded" value, non-string entries) is
    tolerated and simply contributes no expansions.

    Args:
        user_query: The question typed by the user.
        api_key: OpenAI API key.

    Returns:
        The stripped original query first, followed by the distinct non-empty
        expansions in the order the model gave them. An empty list when the
        query is empty or whitespace-only (no API call is made then).
    """
    user_query = (user_query or "").strip()
    if not user_query:
        return []

    client = get_openai_client(api_key)

    system_prompt = (
        "You are an expert in information retrieval systems, particularly skilled in enhancing queries "
        "for document search efficiency."
    )

    user_prompt = f"""
Perform query expansion on the received question by considering alternative phrasings or synonyms commonly used in document retrieval contexts.
If there are multiple ways to phrase the user's question or common synonyms for key terms, provide several reworded versions.
If there are acronyms or words you are not familiar with, do not try to rephrase them.
Return at least 3 versions of the question.
Return ONLY valid JSON with this exact shape:
{{
"expanded": ["q1", "q2", "q3"]
}}
Question:
{user_query}
""".strip()

    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    raw = completion.choices[0].message.content
    try:
        data = json.loads(raw)
    except (TypeError, json.JSONDecodeError):
        # TypeError covers a response whose content is None.
        data = None

    candidates = data.get("expanded") if isinstance(data, dict) else None
    if not isinstance(candidates, list):
        # Guards against e.g. a bare string, which would otherwise be
        # iterated one character at a time.
        candidates = []

    expanded = [q.strip() for q in candidates if isinstance(q, str) and q.strip()]

    # Original query always leads; dict.fromkeys de-duplicates while keeping
    # first-seen order.
    return list(dict.fromkeys([user_query, *expanded]))
196
+
197
+
198
def format_expansions_md(expanded: List[str]) -> str:
    """Render the expanded queries as a numbered markdown list.

    Returns a placeholder line when *expanded* is empty; otherwise a heading
    followed by one "N. query" line per entry, numbered from 1.
    """
    if expanded:
        numbered = "\n".join(
            f"{position}. {text}" for position, text in enumerate(expanded, start=1)
        )
        return "### 🧠 Expanded Queries\n\n" + numbered
    return "*(No expansions yet — type a question and press Enter.)*"
203
+
204
+
205
+ # =========================
206
+ # LLM Self-Evaluation Helper
207
+ # =========================
208
+
209
def evaluate_answer(question: str, context: str, answer: str, api_key: str) -> dict:
    """Have the chat model grade *answer* against *question* and *context*.

    The model is asked for a JSON object scoring groundedness, relevance,
    faithfulness, context precision and context recall from 1 to 5.

    Args:
        question: The user query that was answered.
        context: The retrieved context shown to the answering model.
        answer: The answer to evaluate.
        api_key: OpenAI API key.

    Returns:
        The parsed JSON object from the model. If the reply is missing, is not
        valid JSON, or is not a JSON object, a placeholder dict with the same
        keys is returned instead: every score is None and the raw reply is
        placed in the relevance justification.
    """
    client = get_openai_client(api_key)

    system_prompt = (
        "You are an impartial evaluator for a Retrieval-Augmented Generation (RAG) system. "
        "You will receive: (1) the user query, (2) the retrieved context, and (3) the model's answer. "
        "You must evaluate the answer on five metrics, each scored from 1 (very poor) to 5 (excellent):\n"
        "- Groundedness: Is the answer supported by the retrieved CONTEXT (not outside knowledge)?\n"
        "- Relevance: Does the answer address the USER QUERY directly and appropriately?\n"
        "- Faithfulness: Are the statements logically valid and consistent with the context (no contradictions)?\n"
        "- Context Precision: Does the answer avoid including irrelevant details from the context?\n"
        "- Context Recall: Does the answer capture all IMPORTANT information from the context needed to answer well?\n\n"
        "Return ONLY a single JSON object with this exact structure:\n"
        "{\n"
        '  "query": string,\n'
        '  "response": string,\n'
        '  "groundedness_evaluation": {"score": int, "justification": string},\n'
        '  "relevance_evaluation": {"score": int, "justification": string},\n'
        '  "faithfulness_evaluation": {"score": int, "justification": string},\n'
        '  "context_precision_evaluation": {"score": int, "justification": string},\n'
        '  "context_recall_evaluation": {"score": int, "justification": string}\n'
        "}"
    )

    user_prompt = (
        f"USER QUERY:\n{question}\n\n"
        f"RETRIEVED CONTEXT:\n{context}\n\n"
        f"MODEL ANSWER:\n{answer}"
    )

    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        temperature=0.0,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    raw = completion.choices[0].message.content
    try:
        parsed = json.loads(raw)
    except (TypeError, json.JSONDecodeError):
        # TypeError covers a response whose content is None.
        parsed = None

    if isinstance(parsed, dict):
        return parsed

    # Callers use .get() on the result, so anything but a dict is replaced by
    # a placeholder of the expected shape.
    return {
        "query": question,
        "response": answer,
        "groundedness_evaluation": {"score": None, "justification": "Failed to parse JSON evaluation."},
        "relevance_evaluation": {"score": None, "justification": raw if raw is not None else ""},
        "faithfulness_evaluation": {"score": None, "justification": ""},
        "context_precision_evaluation": {"score": None, "justification": ""},
        "context_recall_evaluation": {"score": None, "justification": ""},
    }
262
+
263
+
264
+ # =========================================================
265
+ # REQUIRED: Chroma Retrieval + Cross-Encoder Rerank + Prompt
266
+ # =========================================================
267
+
268
def retrieve_from_chroma(query: str, top_k: int, api_key: str) -> List[Dict[str, Any]]:
    """Retrieve up to *top_k* passages from the shared collection.

    The query is embedded with embed_texts and matched against the collection.
    The number of requested results is capped at the collection size.

    Args:
        query: Search text; surrounding whitespace is ignored.
        top_k: Maximum number of passages to return.
        api_key: OpenAI API key used to embed the query.

    Returns:
        A list of dicts, in the order Chroma returned them, with keys:
          - id: str
          - text: str
          - metadata: dict (empty when Chroma returned none)
          - distance: float | None
        Empty when the query is blank, *top_k* is not positive, or the
        collection holds no vectors.
    """
    query = (query or "").strip()
    if not query or top_k <= 0:
        return []

    available = collection.count()
    if available == 0:
        return []

    q_emb = embed_texts([query], api_key)[0]
    results = collection.query(
        query_embeddings=[q_emb],
        # Never ask for more neighbours than there are stored vectors.
        n_results=min(top_k, available),
    )

    def _first_row(key: str) -> list:
        # Results are nested one list per query; a field may also be absent
        # or None, in which case it is treated as empty.
        rows = results.get(key) or [[]]
        return rows[0] or []

    ids = _first_row("ids")
    docs = _first_row("documents")
    metas = _first_row("metadatas")
    dists = _first_row("distances")

    out = []
    for i, (cid, text, meta) in enumerate(zip(ids, docs, metas)):
        out.append({
            "id": cid,
            "text": text,
            "metadata": meta or {},
            "distance": dists[i] if i < len(dists) else None,
        })
    return out
306
+
307
+
308
def cross_encoder_rerank(query: str, docs: List[Dict[str, Any]], top_n: int) -> List[Dict[str, Any]]:
    """Score each passage against *query* with the cross-encoder and keep the best.

    Args:
        query: The text every passage is scored against.
        docs: Passage dicts (as produced by retrieve_from_chroma or a merge of
            such lists); the "text" entry of each is what gets scored.
        top_n: Number of passages to keep after sorting.

    Returns:
        Shallow copies of the input dicts with an added float "score" field,
        sorted by score descending (higher is better) and cut to *top_n*.
        Empty when the query is blank or *docs* is empty.
    """
    cleaned_query = (query or "").strip()
    if not (cleaned_query and docs):
        return []

    encoder = get_cross_encoder()
    relevance = encoder.predict(
        [(cleaned_query, passage.get("text", "")) for passage in docs]
    )

    # Copy each passage so the caller's dicts are left untouched.
    scored = [
        {**passage, "score": float(value)}
        for passage, value in zip(docs, relevance)
    ]
    ranked = sorted(scored, key=lambda item: item["score"], reverse=True)
    return ranked[:top_n]
338
+
339
+
340
def build_prompt(query: str, reranked_docs: List[Dict[str, Any]]) -> Tuple[str, str]:
    """Assemble the context string and the full LLM prompt.

    Each passage becomes a "Source: ..." header (with "| Page: ..." appended
    when the metadata carries a page under "page", "page_number" or "pageno")
    followed by the passage text. Passages are separated by a "---" rule.

    Returns:
        (context, prompt): the joined context and the prompt that embeds it
        together with *query*.
    """

    def _render(passage: Dict[str, Any]) -> str:
        meta = passage.get("metadata", {}) or {}

        # First page-like key present in the metadata wins, even if its value
        # is None; absence of all three means "no page".
        for key in ("page", "page_number", "pageno"):
            if key in meta:
                page = meta[key]
                break
        else:
            page = ""

        heading = f"Source: {meta.get('source', 'unknown')}"
        if page is not None and page != "":
            heading = f"{heading} | Page: {page}"

        body = passage.get("text", "")
        return f"{heading}\n{body}".strip()

    context = "\n\n---\n\n".join(_render(p) for p in reranked_docs).strip()

    instructions = (
        "You are a helpful assistant that answers questions ONLY using the provided document context. "
        "If the context does not contain the answer, say you do not know.\n\n"
    )
    prompt = (
        instructions
        + f"Context from documents:\n\n{context}\n\n"
        + f"Question: {query}\n\n"
        + "Answer based only on the context above."
    )

    return context, prompt
371
+
372
+
373
+ # =========================
374
+ # Existing Multi-Query RAG (unchanged behavior)
375
+ # =========================
376
+
377
+ def _merge_docs_by_id(doc_lists: List[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
378
+ """
379
+ Merge/dedupe docs (dicts) by Chroma chunk id. Keeps the best (lowest) distance if present.
380
+ """
381
+ merged: Dict[str, Dict[str, Any]] = {}
382
+ for docs in doc_lists:
383
+ for d in docs:
384
+ cid = d.get("id")
385
+ if not cid:
386
+ continue
387
+ if cid not in merged:
388
+ merged[cid] = d
389
+ else:
390
+ # keep best distance if both have it
391
+ old_dist = merged[cid].get("distance")
392
+ new_dist = d.get("distance")
393
+ if old_dist is not None and new_dist is not None and new_dist < old_dist:
394
+ merged[cid] = d
395
+ return list(merged.values())
396
+
397
+
398
def query_rag_multi(selected_queries: List[str], api_key: str) -> str:
    """Answer from the indexed documents using several query phrasings.

    Every selected query is embedded, the 5 nearest chunks are fetched for
    each, the results are merged by chunk id, and the 5 closest chunks overall
    form the context for the answering model. The answer is then passed to
    evaluate_answer and the scores are appended to the output.

    Args:
        selected_queries: Query phrasings chosen by the user; blank and
            non-string entries are ignored.
        api_key: OpenAI API key.

    Returns:
        Markdown containing the answer followed by either the self-evaluation
        as JSON or a note that evaluation failed; or a short warning/notice
        when there is nothing to search with or nothing was found.
    """
    selected_queries = [q.strip() for q in (selected_queries or []) if isinstance(q, str) and q.strip()]
    if not selected_queries:
        return "⚠️ Please select at least one expanded query."

    if collection.count() == 0:
        return "⚠️ No documents in the database yet. Upload and index some documents first."

    # Embed each selected query, retrieve 5 per query, merge, keep top 5 overall.
    q_embs = embed_texts(selected_queries, api_key)
    results = collection.query(
        query_embeddings=q_embs,
        n_results=5,
    )

    # Results are nested: one inner list per query embedding.
    all_ids = results.get("ids") or []
    all_docs = results.get("documents") or []
    all_metas = results.get("metadatas") or []
    all_dist = results.get("distances")

    doc_lists: List[List[Dict[str, Any]]] = []
    for qi, docs_i in enumerate(all_docs):
        ids_i = all_ids[qi] if qi < len(all_ids) else []
        metas_i = all_metas[qi] if qi < len(all_metas) else []
        if isinstance(all_dist, list) and qi < len(all_dist):
            dist_i = all_dist[qi]
        else:
            dist_i = [None] * len(docs_i)

        doc_lists.append([
            {"id": cid, "text": doc, "metadata": meta or {}, "distance": dist}
            for cid, doc, meta, dist in zip(ids_i, docs_i, metas_i, dist_i)
        ])

    merged = _merge_docs_by_id(doc_lists)
    if not merged:
        return "I couldn't find any relevant context in the indexed documents."

    def _by_distance(d: Dict[str, Any]):
        # Closest first; entries without a distance sort last.
        dist = d.get("distance")
        return (dist is None, 1e9 if dist is None else dist)

    top = sorted(merged, key=_by_distance)[:5]

    context_parts = []
    for d in top:
        md = d.get("metadata", {}) or {}
        context_parts.append(f"Source: {md.get('source','unknown')}\n{d.get('text','')}")
    context = "\n\n---\n\n".join(context_parts)

    client = get_openai_client(api_key)
    system_prompt = (
        "You are a helpful assistant that answers questions ONLY using the provided document context. "
        "If the context does not contain the answer, say you do not know."
    )
    user_prompt = (
        f"Context from documents:\n\n{context}\n\n"
        "Selected expanded queries:\n- " + "\n- ".join(selected_queries) + "\n\n"
        "Answer based only on the context above."
    )

    completion = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.1,
    )

    # The content field can be None; treat that as an empty answer instead of
    # failing on .strip().
    response_text = (completion.choices[0].message.content or "").strip()

    # Evaluation is best-effort: the answer is still shown if it fails.
    try:
        eval_dict = evaluate_answer(
            question=" | ".join(selected_queries),
            context=context,
            answer=response_text,
            api_key=api_key,
        )

        log_record = {
            "query": eval_dict.get("query"),
            "response": eval_dict.get("response"),
            "groundedness_evaluation": eval_dict.get("groundedness_evaluation"),
            "relevance_evaluation": eval_dict.get("relevance_evaluation"),
            "faithfulness_evaluation": eval_dict.get("faithfulness_evaluation"),
            "context_precision_evaluation": eval_dict.get("context_precision_evaluation"),
            "context_recall_evaluation": eval_dict.get("context_recall_evaluation"),
        }

        # ensure_ascii=False keeps non-ASCII answer text readable rather than
        # rendering it as \uXXXX escapes.
        evaluation_json = json.dumps(log_record, indent=2, ensure_ascii=False)
        return (
            f"### 💬 Answer\n\n{response_text}\n\n"
            f"---\n\n"
            f"### 🔍 Self-evaluation (1–5)\n\n"
            f"```json\n{evaluation_json}\n```"
        )
    except Exception as e:
        return (
            f"### 💬 Answer\n\n{response_text}\n\n"
            f"---\n\n"
            f"⚠️ Self-evaluation failed: {e}"
        )
498
+
499
+
500
+ # =========================
501
+ # Cross-Encode Stage UI Helpers
502
+ # =========================
503
+
504
def format_rerank_results_md(query: str, reranked: List[Dict[str, Any]], top_n: int) -> str:
    """Render reranked passages as a markdown table.

    Args:
        query: Accepted for signature compatibility; not used in the output.
        reranked: Passage dicts, expected to carry "score", "text" and
            "metadata" entries.
        top_n: Number shown in the table heading.

    Returns:
        A placeholder line when *reranked* is empty; otherwise a heading and a
        table with rank, score (4 decimals, blank when missing), source, page
        and a snippet of at most 160 characters. Newlines in the snippet are
        flattened and "|" is escaped so cell content cannot break the table.
    """
    if not reranked:
        return "*(No reranked results to display.)*"

    def _cell(value: Any) -> str:
        # A literal pipe would otherwise start a new table column.
        return str(value).replace("|", "\\|")

    lines = [
        f"### 🎯 Cross-Encoder Rerank Results (top {top_n})",
        "",
        "| Rank | Score | Source | Page | Snippet |",
        "|---:|---:|---|---:|---|",
    ]

    for i, d in enumerate(reranked, start=1):
        md = d.get("metadata", {}) or {}
        source = md.get("source", "unknown")
        page = md.get("page", md.get("page_number", md.get("pageno", "")))
        score = d.get("score")

        snippet = (d.get("text", "") or "").replace("\n", " ").strip()
        if len(snippet) > 160:
            snippet = snippet[:160] + "…"

        # A passage that was never scored gets an empty cell rather than
        # failing on the float format spec.
        score_text = "" if score is None else f"{score:.4f}"
        page_text = "" if page is None else page

        lines.append(
            f"| {i} | {score_text} | {_cell(source)} | {_cell(page_text)} | {_cell(snippet)} |"
        )

    return "\n".join(lines)
525
+
526
+
527
+ # =========================
528
+ # Gradio Wrappers
529
+ # =========================
530
+
531
def gradio_ingest(files, api_key):
    """Gradio handler for the "Index documents" button.

    Validates that an API key and at least one file were supplied, then
    indexes the files. Any exception raised during indexing is reported as an
    error string instead of propagating.
    """
    key_supplied = bool(api_key and api_key.strip())
    if not key_supplied:
        return "❌ Please enter your OpenAI API key before indexing."

    if not files:
        return "⚠️ Please drop at least one document."

    # A single upload may arrive as a bare value rather than a list.
    paths = [files] if not isinstance(files, list) else files

    try:
        return add_documents_to_chroma(paths, api_key)
    except Exception as exc:
        return f"❌ Error during indexing: {exc}"
545
+
546
+
547
def gradio_expand(question: str, api_key: str):
    """Gradio handler that fills the checkbox group with expanded queries.

    Returns a pair: an update for the checkbox group (choices set to the
    expansions, with the first one pre-selected) and the markdown preview.
    A missing API key or a failure during expansion clears the choices and
    reports the problem in the preview, matching how the other handlers in
    this file surface errors.
    """
    if not api_key or not api_key.strip():
        return gr.update(choices=[], value=[]), "❌ Please enter your OpenAI API key first."

    try:
        expanded = query_expansion(question, api_key)
    except Exception as e:
        return gr.update(choices=[], value=[]), f"❌ Error during query expansion: {e}"

    md = format_expansions_md(expanded)
    default_value = expanded[:1] if expanded else []
    return gr.update(choices=expanded, value=default_value), md
555
+
556
+
557
def gradio_run_selected(selected_queries: List[str], api_key: str) -> str:
    """Gradio handler for the "Run Search" button.

    Checks for an API key and at least one selected query, then delegates to
    query_rag_multi. Any exception from the pipeline is returned as an error
    string.
    """
    key_supplied = bool(api_key and api_key.strip())
    if not key_supplied:
        return "❌ Please enter your OpenAI API key before searching."

    if not selected_queries:
        return "⚠️ Please expand a question and select one or more to run."

    try:
        answer_md = query_rag_multi(selected_queries, api_key)
    except Exception as exc:
        return f"❌ Error during question answering: {exc}"
    return answer_md
567
+
568
+
569
def gradio_cross_encode(original_question: str, selected_queries: List[str], api_key: str) -> Tuple[str, str]:
    """Gradio handler for the "Cross Encode" button.

    Pipeline:
      - retrieve 20 passages from Chroma for each selected expansion (or for
        the original question when none is selected) and merge them by id;
      - rerank the merged passages with the cross-encoder and keep 5;
      - build the final context string from those 5.

    Returns:
        (results_markdown, context_markdown). On a validation problem, an
        empty retrieval, or an exception in the pipeline, the first element
        carries the message and the second is an empty string.
    """
    if not api_key or not api_key.strip():
        return "❌ Please enter your OpenAI API key first.", ""

    if collection.count() == 0:
        return "⚠️ No documents in the database yet. Upload and index some documents first.", ""

    original_question = (original_question or "").strip()
    selected_queries = [q.strip() for q in (selected_queries or []) if isinstance(q, str) and q.strip()]

    if not original_question and not selected_queries:
        return "⚠️ Please type a question and/or select expansions first.", ""

    # Retrieval uses the selected expansions if present, otherwise the
    # original question.
    retrieval_queries = selected_queries if selected_queries else [original_question]

    # Scoring uses the original user question if available, else the first
    # retrieval query.
    scoring_query = original_question if original_question else retrieval_queries[0]

    # Report embedding / retrieval / model failures as a message, the same
    # way the other handlers in this file do.
    try:
        retrieved_lists = [retrieve_from_chroma(q, top_k=20, api_key=api_key) for q in retrieval_queries]
        merged_docs = _merge_docs_by_id(retrieved_lists)

        if not merged_docs:
            return "I couldn't find any relevant context in the indexed documents.", ""

        reranked = cross_encoder_rerank(scoring_query, merged_docs, top_n=5)

        # Only the context half of build_prompt's result is displayed.
        context, _prompt = build_prompt(scoring_query, reranked)

        md = format_rerank_results_md(scoring_query, reranked, top_n=5)
    except Exception as e:
        return f"❌ Error during cross-encoder reranking: {e}", ""

    return md, f"### 🧩 Final Context (for LLM)\n\n{context}"
616
+
617
+
618
+ # =========================
619
+ # Gradio Interface
620
+ # =========================
621
+
622
# NOTE(review): the nesting of the layout below was reconstructed from a view
# without indentation; confirm in particular that the rerank output boxes
# belong inside the right-hand column.
with gr.Blocks(title="RAG with ChromaDB") as demo:
    gr.Markdown(
        """
# 📚 RAG Q&A with ChromaDB + Gradio (Multi-Select Query Expansion + Cross-Encoder Rerank)
1. Paste your **OpenAI API key** below.
2. **Drag & drop** one or more documents into the upload box.
3. Click **"Index documents"** to store them in a Chroma vector database.
4. Type a question and press **Enter** (or click **Expand**) to generate expanded queries.
5. Select **one or more** expanded queries.
6. Click **Run Search** for the normal pipeline, or **Cross Encode** to view reranked passages + scores + final context.
"""
    )

    with gr.Row():
        # Left column: credentials and document ingestion.
        with gr.Column(scale=1):
            api_key_box = gr.Textbox(
                label="OpenAI API Key",
                placeholder="sk-... (this is kept in memory only for this session)",
                type="password",
            )
            file_input = gr.File(
                label="Drop your document(s) here",
                file_count="multiple",
                type="filepath",
            )
            ingest_button = gr.Button("Index documents")
            ingest_status = gr.Markdown("⚙️ Waiting for documents...")

        # Right column: question entry, expansion choices and outputs.
        with gr.Column(scale=1):
            question_box = gr.Textbox(
                label="Type a question, then press Enter to expand",
                placeholder="e.g., What are the main findings in the report?",
                lines=3,
            )

            with gr.Row():
                expand_button = gr.Button("Expand")
                run_button = gr.Button("Run Search")
                cross_button = gr.Button("Cross Encode")

            expanded_checks = gr.CheckboxGroup(
                label="Choose one or more expanded queries to run",
                choices=[],
                value=[],
                interactive=True,
            )

            expansions_preview = gr.Markdown("*(No expansions yet — type a question and press Enter.)*")
            answer_box = gr.Markdown("💬 Answer will appear here (with self-evaluation).")

            gr.Markdown("---")
            rerank_results_box = gr.Markdown("*(Cross-encoder rerank results will appear here.)*")
            rerank_context_box = gr.Markdown("*(Final context for the LLM will appear here.)*")

    # --- Event wiring ---

    ingest_button.click(
        fn=gradio_ingest,
        inputs=[file_input, api_key_box],
        outputs=[ingest_status],
    )

    # Pressing Enter in the question box and clicking "Expand" run the same
    # handler with the same inputs and outputs.
    for register in (question_box.submit, expand_button.click):
        register(
            fn=gradio_expand,
            inputs=[question_box, api_key_box],
            outputs=[expanded_checks, expansions_preview],
        )

    # Answer using the selected expanded queries.
    run_button.click(
        fn=gradio_run_selected,
        inputs=[expanded_checks, api_key_box],
        outputs=[answer_box],
    )

    # Show cross-encoder rerank results and the final context.
    cross_button.click(
        fn=gradio_cross_encode,
        inputs=[question_box, expanded_checks, api_key_box],
        outputs=[rerank_results_box, rerank_context_box],
    )

if __name__ == "__main__":
    demo.launch()