Qar-Raz committed on
Commit
23e3c5c
·
1 Parent(s): 7b44ae2

added chunk details

Browse files
api.py CHANGED
@@ -12,6 +12,7 @@ from fastapi.responses import StreamingResponse
12
  from huggingface_hub import InferenceClient
13
  from pydantic import BaseModel, Field
14
 
 
15
  from vector_db import get_index_by_name, load_chunks_with_local_cache
16
  from retriever.retriever import HybridRetriever
17
  from retriever.generator import RAGGenerator
@@ -40,7 +41,7 @@ class PredictResponse(BaseModel):
40
  model: str
41
  answer: str
42
  contexts: list[str]
43
- metrics: dict[str, float]
44
 
45
 
46
  class TitleRequest(BaseModel):
@@ -139,6 +140,34 @@ def _parse_title_model_candidates() -> list[str]:
139
  return models or ["meta-llama/Meta-Llama-3-8B-Instruct"]
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  # Fastapi setup
144
  # Fastapi allows us to define python based endpoint
@@ -209,7 +238,9 @@ def startup_event() -> None:
209
  raise RuntimeError("HF_TOKEN not found in environment variables")
210
 
211
  index_name = "cbt-book-recursive"
212
- embed_model_name = "all-MiniLM-L6-v2"
 
 
213
  project_root = os.path.dirname(os.path.abspath(__file__))
214
  cache_dir = os.getenv("BM25_CACHE_DIR", os.path.join(project_root, ".cache"))
215
  force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}
@@ -251,10 +282,23 @@ def startup_event() -> None:
251
  models_time = time.perf_counter() - models_start
252
 
253
  state_start = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
254
  state["index"] = index
255
  state["retriever"] = retriever
256
  state["rag_engine"] = rag_engine
257
  state["models"] = models
 
258
  state["title_model_ids"] = _parse_title_model_candidates()
259
  state["title_client"] = InferenceClient(token=hf_token)
260
  state_time = time.perf_counter() - state_start
@@ -338,6 +382,7 @@ def predict(payload: PredictRequest) -> PredictResponse:
338
  index = state["index"]
339
  rag_engine: RAGGenerator = state["rag_engine"]
340
  models: dict[str, Any] = state["models"]
 
341
  state_access_time = time.perf_counter() - state_access_start
342
 
343
  model_resolve_start = time.perf_counter()
@@ -364,33 +409,28 @@ def predict(payload: PredictRequest) -> PredictResponse:
364
  answer = rag_engine.get_answer(model_instance, query, contexts, temperature=0.1)
365
  inference_time = time.perf_counter() - inference_start
366
 
367
- response_start = time.perf_counter()
368
- metrics = {
369
- "precheck_s": round(precheck_time, 3),
370
- "state_access_s": round(state_access_time, 3),
371
- "model_resolve_s": round(model_resolve_time, 3),
372
- "retrieval_s": round(retrieval_time, 3),
373
- "inference_s": round(inference_time, 3),
374
- }
375
- response_build_time = time.perf_counter() - response_start
376
 
377
  total_time = time.perf_counter() - req_start
378
- metrics["response_build_s"] = round(response_build_time, 3)
379
- metrics["total_s"] = round(total_time, 3)
380
 
381
  print(
382
  f"Predict timing | model={model_name} | mode={payload.mode} | "
383
  f"rerank={payload.rerank_strategy} | precheck={precheck_time:.3f}s | "
384
  f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
385
  f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
386
- f"response_build={response_build_time:.3f}s | total={total_time:.3f}s"
387
  )
388
 
389
  return PredictResponse(
390
  model=model_name,
391
  answer=answer,
392
  contexts=contexts,
393
- metrics=metrics,
394
  )
395
 
396
  # new endpoint for streaming response, allows frontend to render tokens as they come in instead of waiting for full answer
@@ -412,6 +452,7 @@ def predict_stream(payload: PredictRequest) -> StreamingResponse:
412
  index = state["index"]
413
  rag_engine: RAGGenerator = state["rag_engine"]
414
  models: dict[str, Any] = state["models"]
 
415
  state_access_time = time.perf_counter() - state_access_start
416
 
417
  model_resolve_start = time.perf_counter()
@@ -443,23 +484,19 @@ def predict_stream(payload: PredictRequest) -> StreamingResponse:
443
  yield _to_ndjson({"type": "token", "token": token})
444
 
445
  inference_time = time.perf_counter() - inference_start
446
- total_time = time.perf_counter() - req_start
447
  answer = "".join(answer_parts)
448
- metrics = {
449
- "precheck_s": round(precheck_time, 3),
450
- "state_access_s": round(state_access_time, 3),
451
- "model_resolve_s": round(model_resolve_time, 3),
452
- "retrieval_s": round(retrieval_time, 3),
453
- "inference_s": round(inference_time, 3),
454
- "total_s": round(total_time, 3),
455
- }
456
 
457
  yield _to_ndjson(
458
  {
459
  "type": "done",
460
  "model": model_name,
461
  "answer": answer,
462
- "metrics": metrics,
 
463
  }
464
  )
465
  except Exception as exc:
 
12
  from huggingface_hub import InferenceClient
13
  from pydantic import BaseModel, Field
14
 
15
+ from config_loader import cfg
16
  from vector_db import get_index_by_name, load_chunks_with_local_cache
17
  from retriever.retriever import HybridRetriever
18
  from retriever.generator import RAGGenerator
 
41
  model: str
42
  answer: str
43
  contexts: list[str]
44
+ retrieved_chunks: list[dict[str, Any]]
45
 
46
 
47
  class TitleRequest(BaseModel):
 
140
  return models or ["meta-llama/Meta-Llama-3-8B-Instruct"]
141
 
142
 
143
+ def _build_retrieved_chunks(
144
+ contexts: list[str],
145
+ chunk_lookup: dict[str, dict[str, Any]],
146
+ ) -> list[dict[str, Any]]:
147
+ if not contexts:
148
+ return []
149
+
150
+ retrieved_chunks: list[dict[str, Any]] = []
151
+
152
+ for idx, text in enumerate(contexts, start=1):
153
+ meta = chunk_lookup.get(text, {})
154
+ title = meta.get("title") or "Untitled"
155
+ url = meta.get("url") or ""
156
+ chunk_index = meta.get("chunk_index")
157
+
158
+ retrieved_chunks.append(
159
+ {
160
+ "rank": idx,
161
+ "text": text,
162
+ "source_title": title,
163
+ "source_url": url,
164
+ "chunk_index": chunk_index,
165
+ }
166
+ )
167
+
168
+ return retrieved_chunks
169
+
170
+
171
 
172
  # Fastapi setup
173
  # Fastapi allows us to define python based endpoint
 
238
  raise RuntimeError("HF_TOKEN not found in environment variables")
239
 
240
  index_name = "cbt-book-recursive"
241
+ # Keep retrieval embedding model aligned with the one used at ingest time
242
+ # to avoid Pinecone dimension mismatch errors (e.g., 384 vs 512).
243
+ embed_model_name = cfg.processing.get("embedding_model", "all-MiniLM-L6-v2")
244
  project_root = os.path.dirname(os.path.abspath(__file__))
245
  cache_dir = os.getenv("BM25_CACHE_DIR", os.path.join(project_root, ".cache"))
246
  force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}
 
282
  models_time = time.perf_counter() - models_start
283
 
284
  state_start = time.perf_counter()
285
+ chunk_lookup: dict[str, dict[str, Any]] = {}
286
+ for chunk in final_chunks:
287
+ metadata = chunk.get("metadata", {})
288
+ text = metadata.get("text")
289
+ if not text or text in chunk_lookup:
290
+ continue
291
+ chunk_lookup[text] = {
292
+ "title": metadata.get("title", "Untitled"),
293
+ "url": metadata.get("url", ""),
294
+ "chunk_index": metadata.get("chunk_index"),
295
+ }
296
+
297
  state["index"] = index
298
  state["retriever"] = retriever
299
  state["rag_engine"] = rag_engine
300
  state["models"] = models
301
+ state["chunk_lookup"] = chunk_lookup
302
  state["title_model_ids"] = _parse_title_model_candidates()
303
  state["title_client"] = InferenceClient(token=hf_token)
304
  state_time = time.perf_counter() - state_start
 
382
  index = state["index"]
383
  rag_engine: RAGGenerator = state["rag_engine"]
384
  models: dict[str, Any] = state["models"]
385
+ chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
386
  state_access_time = time.perf_counter() - state_access_start
387
 
388
  model_resolve_start = time.perf_counter()
 
409
  answer = rag_engine.get_answer(model_instance, query, contexts, temperature=0.1)
410
  inference_time = time.perf_counter() - inference_start
411
 
412
+ mapping_start = time.perf_counter()
413
+ retrieved_chunks = _build_retrieved_chunks(
414
+ contexts=contexts,
415
+ chunk_lookup=chunk_lookup,
416
+ )
417
+ mapping_time = time.perf_counter() - mapping_start
 
 
 
418
 
419
  total_time = time.perf_counter() - req_start
 
 
420
 
421
  print(
422
  f"Predict timing | model={model_name} | mode={payload.mode} | "
423
  f"rerank={payload.rerank_strategy} | precheck={precheck_time:.3f}s | "
424
  f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
425
  f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
426
+ f"context_map={mapping_time:.3f}s | total={total_time:.3f}s"
427
  )
428
 
429
  return PredictResponse(
430
  model=model_name,
431
  answer=answer,
432
  contexts=contexts,
433
+ retrieved_chunks=retrieved_chunks,
434
  )
435
 
436
  # new endpoint for streaming response, allows frontend to render tokens as they come in instead of waiting for full answer
 
452
  index = state["index"]
453
  rag_engine: RAGGenerator = state["rag_engine"]
454
  models: dict[str, Any] = state["models"]
455
+ chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
456
  state_access_time = time.perf_counter() - state_access_start
457
 
458
  model_resolve_start = time.perf_counter()
 
484
  yield _to_ndjson({"type": "token", "token": token})
485
 
486
  inference_time = time.perf_counter() - inference_start
 
487
  answer = "".join(answer_parts)
488
+ retrieved_chunks = _build_retrieved_chunks(
489
+ contexts=contexts,
490
+ chunk_lookup=chunk_lookup,
491
+ )
 
 
 
 
492
 
493
  yield _to_ndjson(
494
  {
495
  "type": "done",
496
  "model": model_name,
497
  "answer": answer,
498
+ "contexts": contexts,
499
+ "retrieved_chunks": retrieved_chunks,
500
  }
501
  )
502
  except Exception as exc:
frontend/components/AIAssistantUI.jsx CHANGED
@@ -320,7 +320,7 @@ export default function AIAssistantUI() {
320
  )
321
  }
322
 
323
- const finalizeAssistant = (finalText) => {
324
  upsertAssistantMessage(
325
  (m) => {
326
  const fallbackContent = m.content || "Sorry, I encountered an error."
@@ -328,6 +328,8 @@ export default function AIAssistantUI() {
328
  ...m,
329
  content: finalText != null ? finalText : fallbackContent,
330
  isStreaming: false,
 
 
331
  }
332
  },
333
  finalText || "Sorry, I encountered an error.",
@@ -385,6 +387,7 @@ export default function AIAssistantUI() {
385
  let buffer = ""
386
  let firstTokenReceived = false
387
  let finalAnswer = null
 
388
 
389
  while (true) {
390
  const { value, done } = await reader.read()
@@ -414,6 +417,7 @@ export default function AIAssistantUI() {
414
  }
415
 
416
  if (evt.type === "done") {
 
417
  finalAnswer = typeof evt.answer === "string" ? evt.answer : null
418
  }
419
 
@@ -428,6 +432,7 @@ export default function AIAssistantUI() {
428
  try {
429
  const evt = JSON.parse(remainder)
430
  if (evt.type === "done") {
 
431
  finalAnswer = typeof evt.answer === "string" ? evt.answer : null
432
  }
433
  if (evt.type === "token") {
@@ -441,7 +446,7 @@ export default function AIAssistantUI() {
441
  }
442
  }
443
 
444
- finalizeAssistant(finalAnswer)
445
  } catch (err) {
446
  console.error("predict request failed:", err)
447
  if (err?.name === "AbortError") {
 
320
  )
321
  }
322
 
323
+ const finalizeAssistant = (finalText, details = null) => {
324
  upsertAssistantMessage(
325
  (m) => {
326
  const fallbackContent = m.content || "Sorry, I encountered an error."
 
328
  ...m,
329
  content: finalText != null ? finalText : fallbackContent,
330
  isStreaming: false,
331
+ retrievedChunks: details?.retrieved_chunks || [],
332
+ contexts: details?.contexts || [],
333
  }
334
  },
335
  finalText || "Sorry, I encountered an error.",
 
387
  let buffer = ""
388
  let firstTokenReceived = false
389
  let finalAnswer = null
390
+ let donePayload = null
391
 
392
  while (true) {
393
  const { value, done } = await reader.read()
 
417
  }
418
 
419
  if (evt.type === "done") {
420
+ donePayload = evt
421
  finalAnswer = typeof evt.answer === "string" ? evt.answer : null
422
  }
423
 
 
432
  try {
433
  const evt = JSON.parse(remainder)
434
  if (evt.type === "done") {
435
+ donePayload = evt
436
  finalAnswer = typeof evt.answer === "string" ? evt.answer : null
437
  }
438
  if (evt.type === "token") {
 
446
  }
447
  }
448
 
449
+ finalizeAssistant(finalAnswer, donePayload)
450
  } catch (err) {
451
  console.error("predict request failed:", err)
452
  if (err?.name === "AbortError") {
frontend/components/ChatPane.jsx CHANGED
@@ -1,7 +1,7 @@
1
  "use client"
2
 
3
  import { useState, forwardRef, useImperativeHandle, useRef } from "react"
4
- import { RefreshCw, Check, X, Square } from "lucide-react"
5
  import Message from "./Message"
6
  import MarkdownMessage from "./MarkdownMessage"
7
  import Composer from "./Composer"
@@ -28,6 +28,41 @@ function ThinkingMessage({ onPause }) {
28
  )
29
  }
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  const ChatPane = forwardRef(function ChatPane(
32
  {
33
  conversation,
@@ -42,6 +77,7 @@ const ChatPane = forwardRef(function ChatPane(
42
  const [editingId, setEditingId] = useState(null)
43
  const [draft, setDraft] = useState("")
44
  const [busy, setBusy] = useState(false)
 
45
  const composerRef = useRef(null)
46
 
47
  useImperativeHandle(
@@ -174,8 +210,31 @@ const ChatPane = forwardRef(function ChatPane(
174
  ) : (
175
  <Message role={m.role} streaming={Boolean(m.isStreaming)}>
176
  {m.role === "assistant" ? (
177
- <div className={cls(m.isStreaming && "streaming-text-reveal")}>
178
- <MarkdownMessage content={m.content} isStreaming={Boolean(m.isStreaming)} />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  </div>
180
  ) : (
181
  <div className="whitespace-pre-wrap">{m.content}</div>
 
1
  "use client"
2
 
3
  import { useState, forwardRef, useImperativeHandle, useRef } from "react"
4
+ import { RefreshCw, Check, X, Square, ChevronDown, ChevronUp } from "lucide-react"
5
  import Message from "./Message"
6
  import MarkdownMessage from "./MarkdownMessage"
7
  import Composer from "./Composer"
 
28
  )
29
  }
30
 
31
+ function AssistantDetails({ message }) {
32
+ const chunks = Array.isArray(message?.retrievedChunks) ? message.retrievedChunks : []
33
+
34
+ if (!chunks.length) {
35
+ return (
36
+ <div className="rounded-xl border border-zinc-200 bg-zinc-50 p-3 text-xs text-zinc-500 dark:border-zinc-800 dark:bg-zinc-900/50 dark:text-zinc-400">
37
+ No retrieved chunks were returned for this response.
38
+ </div>
39
+ )
40
+ }
41
+
42
+ return (
43
+ <div className="space-y-3 rounded-xl border border-zinc-200 bg-zinc-50 p-3 text-xs dark:border-zinc-800 dark:bg-zinc-900/50">
44
+ <div className="space-y-2">
45
+ <div className="text-[10px] uppercase tracking-wide text-zinc-500">Retrieved Chunks ({chunks.length})</div>
46
+ <div className="max-h-80 space-y-2 overflow-y-auto pr-1">
47
+ {chunks.map((chunk) => (
48
+ <div
49
+ key={`${message.id}-${chunk.rank}`}
50
+ className="rounded-lg border border-zinc-200 bg-white p-2 dark:border-zinc-800 dark:bg-zinc-950"
51
+ >
52
+ <div className="mb-1 flex flex-wrap items-center gap-2 text-[10px] text-zinc-500">
53
+ <span className="rounded-full border border-zinc-300 px-1.5 py-0.5 dark:border-zinc-700">Chunk #{chunk.rank}</span>
54
+ {chunk.source_title && <span>{chunk.source_title}</span>}
55
+ {chunk.chunk_index !== null && chunk.chunk_index !== undefined && <span>Part {chunk.chunk_index}</span>}
56
+ </div>
57
+ <div className="whitespace-pre-wrap text-xs text-zinc-700 dark:text-zinc-300">{chunk.text}</div>
58
+ </div>
59
+ ))}
60
+ </div>
61
+ </div>
62
+ </div>
63
+ )
64
+ }
65
+
66
  const ChatPane = forwardRef(function ChatPane(
67
  {
68
  conversation,
 
77
  const [editingId, setEditingId] = useState(null)
78
  const [draft, setDraft] = useState("")
79
  const [busy, setBusy] = useState(false)
80
+ const [openDetailsId, setOpenDetailsId] = useState(null)
81
  const composerRef = useRef(null)
82
 
83
  useImperativeHandle(
 
210
  ) : (
211
  <Message role={m.role} streaming={Boolean(m.isStreaming)}>
212
  {m.role === "assistant" ? (
213
+ <div className="space-y-2">
214
+ <div className={cls(m.isStreaming && "streaming-text-reveal")}>
215
+ <MarkdownMessage content={m.content} isStreaming={Boolean(m.isStreaming)} />
216
+ </div>
217
+ {!m.isStreaming && Array.isArray(m.retrievedChunks) && m.retrievedChunks.length > 0 && (
218
+ <>
219
+ <button
220
+ onClick={() => setOpenDetailsId((prev) => (prev === m.id ? null : m.id))}
221
+ className="inline-flex items-center gap-1 rounded-full border border-zinc-300 px-2 py-1 text-xs text-zinc-600 hover:bg-zinc-100 dark:border-zinc-700 dark:text-zinc-300 dark:hover:bg-zinc-800"
222
+ >
223
+ <span>{openDetailsId === m.id ? "Hide Details" : "Details"}</span>
224
+ {openDetailsId === m.id ? <ChevronUp className="h-3.5 w-3.5" /> : <ChevronDown className="h-3.5 w-3.5" />}
225
+ </button>
226
+ <div
227
+ className={cls(
228
+ "grid overflow-hidden transition-all duration-300 ease-out",
229
+ openDetailsId === m.id ? "grid-rows-[1fr] opacity-100" : "grid-rows-[0fr] opacity-0",
230
+ )}
231
+ >
232
+ <div className="min-h-0 overflow-hidden">
233
+ <AssistantDetails message={m} />
234
+ </div>
235
+ </div>
236
+ </>
237
+ )}
238
  </div>
239
  ) : (
240
  <div className="whitespace-pre-wrap">{m.content}</div>
frontend/components/Header.tsx CHANGED
@@ -21,15 +21,15 @@ export default function Header() {
21
  // see https://reui.io/docs/github-button for more variables
22
  initialStars={1}
23
  label=""
24
- targetStars={5}
25
 
26
- repoUrl="https://github.com/Qar-Raz/mlops_project.git"
27
 
28
  filled = {true}
29
  animationDuration= {5}
30
  roundStars={true}
31
  // below line can be commented out for clear black button --@Qamar
32
- className="bg-gray-900/50 border-gray-700 text-gray-200 hover:bg-gray-800/50 hover:border-gray-600"
33
  />
34
 
35
  </nav>
 
21
  // see https://reui.io/docs/github-button for more variables
22
  initialStars={1}
23
  label=""
24
+ targetStars={3}
25
 
26
+ repoUrl="https://github.com/ramailkk/RAG-AS3-NLP"
27
 
28
  filled = {true}
29
  animationDuration= {5}
30
  roundStars={true}
31
  // below line can be commented out for clear black button --@Qamar
32
+ className="bg-gray-900/50 border-gray-700 text-gray-200 hover:bg-gray-800/50 hover:border-gray-600 dark:bg-gray-900/50 dark:border-gray-700 dark:text-gray-200 dark:hover:bg-gray-800/50 dark:hover:border-gray-600"
33
  />
34
 
35
  </nav>
frontend/components/ui/github-button.tsx CHANGED
@@ -160,7 +160,6 @@ function GithubButton({
160
  const ref = React.useRef(null);
161
  const isInView = useInView(ref, inViewOptions);
162
 
163
- // Reset animation state when targetStars changes
164
  useEffect(() => {
165
  setHasAnimated(false);
166
  setCurrentStars(initialStars);
 
160
  const ref = React.useRef(null);
161
  const isInView = useInView(ref, inViewOptions);
162
 
 
163
  useEffect(() => {
164
  setHasAnimated(false);
165
  setCurrentStars(initialStars);