Muddasri commited on
Commit
1737e82
·
unverified ·
2 Parent(s): 44406e57b44ae2

Merge branch 'main' into Muddasir/BackendComplete

Browse files
.gitignore CHANGED
@@ -25,6 +25,7 @@ dist/
25
  .mypy_cache/
26
  .ruff_cache/
27
  .ipynb_checkpoints/
 
28
 
29
  # IDE/editor
30
  .vscode/
 
25
  .mypy_cache/
26
  .ruff_cache/
27
  .ipynb_checkpoints/
28
+ .cache/
29
 
30
  # IDE/editor
31
  .vscode/
api.py CHANGED
@@ -1,14 +1,18 @@
1
  # Fastapi endpoints defined here
 
2
  import os
 
3
  import time
4
  from typing import Any
5
 
6
  from dotenv import load_dotenv
7
  from fastapi import FastAPI, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
 
 
9
  from pydantic import BaseModel, Field
10
 
11
- from vector_db import get_index_by_name, load_chunks_from_pinecone
12
  from retriever.retriever import HybridRetriever
13
  from retriever.generator import RAGGenerator
14
  from retriever.processor import ChunkProcessor
@@ -20,6 +24,9 @@ from models.deepseek_v3 import DeepSeek_V3
20
  from models.tiny_aya import TinyAya
21
 
22
 
 
 
 
23
  class PredictRequest(BaseModel):
24
  query: str = Field(..., min_length=1, description="User query text")
25
  model: str = Field(default="Llama-3-8B", description="Model name key")
@@ -36,6 +43,102 @@ class PredictResponse(BaseModel):
36
  metrics: dict[str, float]
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # Fastapi setup
41
  # Fastapi allows us to define python based endpoint
@@ -89,10 +192,16 @@ def _resolve_model(name: str, models: dict[str, Any]) -> tuple[str, Any]:
89
 
90
  @app.on_event("startup")
91
  def startup_event() -> None:
 
 
 
92
  load_dotenv()
 
93
 
 
94
  hf_token = os.getenv("HF_TOKEN")
95
  pinecone_api_key = os.getenv("PINECONE_API_KEY")
 
96
 
97
  if not pinecone_api_key:
98
  raise RuntimeError("PINECONE_API_KEY not found in environment variables")
@@ -101,35 +210,71 @@ def startup_event() -> None:
101
 
102
  index_name = "cbt-book-recursive"
103
  embed_model_name = "all-MiniLM-L6-v2"
 
 
 
104
 
105
- startup_start = time.perf_counter()
106
-
107
  index = get_index_by_name(
108
  api_key=pinecone_api_key,
109
  index_name=index_name
110
  )
 
111
 
112
  chunks_start = time.perf_counter()
113
- final_chunks = load_chunks_from_pinecone(index)
 
 
 
 
 
 
114
  chunk_load_time = time.perf_counter() - chunks_start
115
 
116
  if not final_chunks:
117
  raise RuntimeError("No chunks found in Pinecone metadata. Run indexing once before API mode.")
118
 
119
- proc = ChunkProcessor(model_name=embed_model_name, verbose=False)
 
 
 
 
120
  retriever = HybridRetriever(final_chunks, proc.encoder, verbose=False)
 
 
 
121
  rag_engine = RAGGenerator()
 
 
 
122
  models = _build_models(hf_token)
 
123
 
 
124
  state["index"] = index
125
  state["retriever"] = retriever
126
  state["rag_engine"] = rag_engine
127
  state["models"] = models
 
 
 
128
 
129
  startup_time = time.perf_counter() - startup_start
130
  print(
131
  f"API startup complete | chunks={len(final_chunks)} | "
132
- f"chunk_load={chunk_load_time:.3f}s | total={startup_time:.3f}s"
 
 
 
 
 
 
 
 
 
 
 
 
133
  )
134
 
135
 
@@ -139,27 +284,65 @@ def health() -> dict[str, str]:
139
  return {"status": "ok" if ready else "starting"}
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  # Predict endpoint that takes a query and returns an answer along with contexts and metrics
144
  # is called from the frontend when user clicks submits
145
  # Also resolves model based on user selection
146
  @app.post("/predict", response_model=PredictResponse)
147
  def predict(payload: PredictRequest) -> PredictResponse:
 
 
 
148
  if not state:
149
  raise HTTPException(status_code=503, detail="Service not initialized yet")
150
 
151
  query = payload.query.strip()
152
  if not query:
153
  raise HTTPException(status_code=400, detail="Query cannot be empty")
 
154
 
155
- total_start = time.perf_counter()
156
-
157
  retriever: HybridRetriever = state["retriever"]
158
  index = state["index"]
159
  rag_engine: RAGGenerator = state["rag_engine"]
160
  models: dict[str, Any] = state["models"]
 
161
 
 
162
  model_name, model_instance = _resolve_model(payload.model, models)
 
163
 
164
  retrieval_start = time.perf_counter()
165
  contexts = retriever.search(
@@ -177,19 +360,116 @@ def predict(payload: PredictRequest) -> PredictResponse:
177
  if not contexts:
178
  raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")
179
 
180
- generation_start = time.perf_counter()
181
  answer = rag_engine.get_answer(model_instance, query, contexts, temperature=0.1)
182
- generation_time = time.perf_counter() - generation_start
 
 
 
 
 
 
 
 
 
 
183
 
184
- total_time = time.perf_counter() - total_start
 
 
 
 
 
 
 
 
 
 
185
 
186
  return PredictResponse(
187
  model=model_name,
188
  answer=answer,
189
  contexts=contexts,
190
- metrics={
191
- "retrieval_s": round(retrieval_time, 3),
192
- "generation_s": round(generation_time, 3),
193
- "total_s": round(total_time, 3),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  },
195
  )
 
1
  # Fastapi endpoints defined here
2
+ import json
3
  import os
4
+ import re
5
  import time
6
  from typing import Any
7
 
8
  from dotenv import load_dotenv
9
  from fastapi import FastAPI, HTTPException
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import StreamingResponse
12
+ from huggingface_hub import InferenceClient
13
  from pydantic import BaseModel, Field
14
 
15
+ from vector_db import get_index_by_name, load_chunks_with_local_cache
16
  from retriever.retriever import HybridRetriever
17
  from retriever.generator import RAGGenerator
18
  from retriever.processor import ChunkProcessor
 
24
  from models.tiny_aya import TinyAya
25
 
26
 
27
+ #Added cacheing and time logging to track every stages time
28
+
29
+
30
  class PredictRequest(BaseModel):
31
  query: str = Field(..., min_length=1, description="User query text")
32
  model: str = Field(default="Llama-3-8B", description="Model name key")
 
43
  metrics: dict[str, float]
44
 
45
 
46
+ class TitleRequest(BaseModel):
47
+ query: str = Field(..., min_length=1, description="First user message")
48
+
49
+
50
+ class TitleResponse(BaseModel):
51
+ title: str
52
+ source: str
53
+
54
+
55
+ def _to_ndjson(payload: dict[str, Any]) -> str:
56
+ return json.dumps(payload, ensure_ascii=False) + "\n"
57
+
58
+
59
+
60
+ # simpliest possible implementation to determine chat title
61
+ # is fallback incase hf generation fails.
62
+
63
+ def _title_from_query(query: str) -> str:
64
+ stop_words = {
65
+ "a", "an", "and", "are", "as", "at", "be", "by", "can", "do", "for", "from", "how",
66
+ "i", "in", "is", "it", "me", "my", "of", "on", "or", "please", "show", "tell", "that",
67
+ "the", "this", "to", "we", "what", "when", "where", "which", "why", "with", "you", "your",
68
+ }
69
+
70
+ words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/+]*", query)
71
+ if not words:
72
+ return "New Chat"
73
+
74
+ filtered: list[str] = []
75
+ for word in words:
76
+ cleaned = word.strip("-_/+")
77
+ if not cleaned:
78
+ continue
79
+ if cleaned.lower() in stop_words:
80
+ continue
81
+ filtered.append(cleaned)
82
+ if len(filtered) >= 6:
83
+ break
84
+
85
+ chosen = filtered if filtered else words[:6]
86
+ normalized = [w.capitalize() if w.islower() else w for w in chosen]
87
+ title = " ".join(normalized).strip()
88
+ return title[:80] if title else "New Chat"
89
+
90
+
91
+ #actual code for title generation using hf model, uses a simple prompt to generate a concise title based on user query, with some formatting rules to ensure clean output. If generation fails or returns an empty title, falls back to rule-based method.
92
+ # is called in the /predict/title endpoint
93
+
94
+ def _clean_title_text(raw: str) -> str:
95
+ text = (raw or "").strip()
96
+ text = text.replace("\n", " ").replace("\r", " ")
97
+ text = re.sub(r"^[\"'`\s]+|[\"'`\s]+$", "", text)
98
+ text = re.sub(r"\s+", " ", text).strip()
99
+ words = text.split()
100
+ if len(words) > 8:
101
+ text = " ".join(words[:8])
102
+ return text[:80]
103
+
104
+
105
+ def _title_from_hf(query: str, client: InferenceClient, model_id: str) -> str | None:
106
+ system_prompt = (
107
+ "You generate short chat titles. Return only a title, no punctuation at the end, no quotes."
108
+ )
109
+ user_prompt = (
110
+ "Create a concise 3-7 word title for this user request:\n"
111
+ f"{query}"
112
+ )
113
+
114
+ response = client.chat_completion(
115
+ model=model_id,
116
+ messages=[
117
+ {"role": "system", "content": system_prompt},
118
+ {"role": "user", "content": user_prompt},
119
+ ],
120
+ max_tokens=24,
121
+ temperature=0.3,
122
+ )
123
+ if not response or not response.choices:
124
+ return None
125
+
126
+ raw_title = response.choices[0].message.content or ""
127
+ title = _clean_title_text(raw_title)
128
+ if not title or title.lower() == "new chat":
129
+ return None
130
+ return title
131
+
132
+
133
+ def _parse_title_model_candidates() -> list[str]:
134
+ raw = os.getenv(
135
+ "TITLE_MODEL_IDS",
136
+ "Qwen/Qwen2.5-1.5B-Instruct,CohereLabs/tiny-aya-global,meta-llama/Meta-Llama-3-8B-Instruct",
137
+ )
138
+ models = [m.strip() for m in raw.split(",") if m.strip()]
139
+ return models or ["meta-llama/Meta-Llama-3-8B-Instruct"]
140
+
141
+
142
 
143
  # Fastapi setup
144
  # Fastapi allows us to define python based endpoint
 
192
 
193
  @app.on_event("startup")
194
  def startup_event() -> None:
195
+ startup_start = time.perf_counter()
196
+
197
+ dotenv_start = time.perf_counter()
198
  load_dotenv()
199
+ dotenv_time = time.perf_counter() - dotenv_start
200
 
201
+ env_start = time.perf_counter()
202
  hf_token = os.getenv("HF_TOKEN")
203
  pinecone_api_key = os.getenv("PINECONE_API_KEY")
204
+ env_time = time.perf_counter() - env_start
205
 
206
  if not pinecone_api_key:
207
  raise RuntimeError("PINECONE_API_KEY not found in environment variables")
 
210
 
211
  index_name = "cbt-book-recursive"
212
  embed_model_name = "all-MiniLM-L6-v2"
213
+ project_root = os.path.dirname(os.path.abspath(__file__))
214
+ cache_dir = os.getenv("BM25_CACHE_DIR", os.path.join(project_root, ".cache"))
215
+ force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}
216
 
217
+ index_start = time.perf_counter()
 
218
  index = get_index_by_name(
219
  api_key=pinecone_api_key,
220
  index_name=index_name
221
  )
222
+ index_time = time.perf_counter() - index_start
223
 
224
  chunks_start = time.perf_counter()
225
+ final_chunks, chunk_source = load_chunks_with_local_cache(
226
+ index=index,
227
+ index_name=index_name,
228
+ cache_dir=cache_dir,
229
+ batch_size=100,
230
+ force_refresh=force_cache_refresh,
231
+ )
232
  chunk_load_time = time.perf_counter() - chunks_start
233
 
234
  if not final_chunks:
235
  raise RuntimeError("No chunks found in Pinecone metadata. Run indexing once before API mode.")
236
 
237
+ processor_start = time.perf_counter()
238
+ proc = ChunkProcessor(model_name=embed_model_name, verbose=False, load_hf_embeddings=False)
239
+ processor_time = time.perf_counter() - processor_start
240
+
241
+ retriever_start = time.perf_counter()
242
  retriever = HybridRetriever(final_chunks, proc.encoder, verbose=False)
243
+ retriever_time = time.perf_counter() - retriever_start
244
+
245
+ rag_start = time.perf_counter()
246
  rag_engine = RAGGenerator()
247
+ rag_time = time.perf_counter() - rag_start
248
+
249
+ models_start = time.perf_counter()
250
  models = _build_models(hf_token)
251
+ models_time = time.perf_counter() - models_start
252
 
253
+ state_start = time.perf_counter()
254
  state["index"] = index
255
  state["retriever"] = retriever
256
  state["rag_engine"] = rag_engine
257
  state["models"] = models
258
+ state["title_model_ids"] = _parse_title_model_candidates()
259
+ state["title_client"] = InferenceClient(token=hf_token)
260
+ state_time = time.perf_counter() - state_start
261
 
262
  startup_time = time.perf_counter() - startup_start
263
  print(
264
  f"API startup complete | chunks={len(final_chunks)} | "
265
+ f"dotenv={dotenv_time:.3f}s | "
266
+ f"env={env_time:.3f}s | "
267
+ f"index={index_time:.3f}s | "
268
+ f"cache_dir={cache_dir} | "
269
+ f"force_cache_refresh={force_cache_refresh} | "
270
+ f"chunk_source={chunk_source} | "
271
+ f"chunk_load={chunk_load_time:.3f}s | "
272
+ f"processor={processor_time:.3f}s | "
273
+ f"retriever={retriever_time:.3f}s | "
274
+ f"rag={rag_time:.3f}s | "
275
+ f"models={models_time:.3f}s | "
276
+ f"state={state_time:.3f}s | "
277
+ f"total={startup_time:.3f}s"
278
  )
279
 
280
 
 
284
  return {"status": "ok" if ready else "starting"}
285
 
286
 
287
+ #title generation endpoint
288
+ # is called only once when we create a new chat, after first prompt
289
+ @app.post("/predict/title", response_model=TitleResponse)
290
+ def suggest_title(payload: TitleRequest) -> TitleResponse:
291
+ query = payload.query.strip()
292
+ if not query:
293
+ raise HTTPException(status_code=400, detail="Query cannot be empty")
294
+
295
+ fallback_title = _title_from_query(query)
296
+
297
+ title_client: InferenceClient | None = state.get("title_client")
298
+ title_model_ids: list[str] = state.get("title_model_ids", _parse_title_model_candidates())
299
+
300
+ if title_client is not None:
301
+ for title_model_id in title_model_ids:
302
+ try:
303
+ hf_title = _title_from_hf(query, title_client, title_model_id)
304
+ if hf_title:
305
+ return TitleResponse(title=hf_title, source=f"hf:{title_model_id}")
306
+ except Exception as exc:
307
+ err_text = str(exc)
308
+ # Provider/model availability differs across HF accounts; skip unsupported models.
309
+ if "model_not_supported" in err_text or "not supported by any provider" in err_text:
310
+ continue
311
+ print(f"Title generation model failed ({title_model_id}): {exc}")
312
+ continue
313
+
314
+ print("Title generation fallback triggered: no title model available/successful")
315
+
316
+ return TitleResponse(title=fallback_title, source="rule-based")
317
+
318
+
319
 
320
  # Predict endpoint that takes a query and returns an answer along with contexts and metrics
321
  # is called from the frontend when user clicks submits
322
  # Also resolves model based on user selection
323
  @app.post("/predict", response_model=PredictResponse)
324
  def predict(payload: PredictRequest) -> PredictResponse:
325
+ req_start = time.perf_counter()
326
+
327
+ precheck_start = time.perf_counter()
328
  if not state:
329
  raise HTTPException(status_code=503, detail="Service not initialized yet")
330
 
331
  query = payload.query.strip()
332
  if not query:
333
  raise HTTPException(status_code=400, detail="Query cannot be empty")
334
+ precheck_time = time.perf_counter() - precheck_start
335
 
336
+ state_access_start = time.perf_counter()
 
337
  retriever: HybridRetriever = state["retriever"]
338
  index = state["index"]
339
  rag_engine: RAGGenerator = state["rag_engine"]
340
  models: dict[str, Any] = state["models"]
341
+ state_access_time = time.perf_counter() - state_access_start
342
 
343
+ model_resolve_start = time.perf_counter()
344
  model_name, model_instance = _resolve_model(payload.model, models)
345
+ model_resolve_time = time.perf_counter() - model_resolve_start
346
 
347
  retrieval_start = time.perf_counter()
348
  contexts = retriever.search(
 
360
  if not contexts:
361
  raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")
362
 
363
+ inference_start = time.perf_counter()
364
  answer = rag_engine.get_answer(model_instance, query, contexts, temperature=0.1)
365
+ inference_time = time.perf_counter() - inference_start
366
+
367
+ response_start = time.perf_counter()
368
+ metrics = {
369
+ "precheck_s": round(precheck_time, 3),
370
+ "state_access_s": round(state_access_time, 3),
371
+ "model_resolve_s": round(model_resolve_time, 3),
372
+ "retrieval_s": round(retrieval_time, 3),
373
+ "inference_s": round(inference_time, 3),
374
+ }
375
+ response_build_time = time.perf_counter() - response_start
376
 
377
+ total_time = time.perf_counter() - req_start
378
+ metrics["response_build_s"] = round(response_build_time, 3)
379
+ metrics["total_s"] = round(total_time, 3)
380
+
381
+ print(
382
+ f"Predict timing | model={model_name} | mode={payload.mode} | "
383
+ f"rerank={payload.rerank_strategy} | precheck={precheck_time:.3f}s | "
384
+ f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
385
+ f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
386
+ f"response_build={response_build_time:.3f}s | total={total_time:.3f}s"
387
+ )
388
 
389
  return PredictResponse(
390
  model=model_name,
391
  answer=answer,
392
  contexts=contexts,
393
+ metrics=metrics,
394
+ )
395
+
396
+ # new endpoint for streaming response, allows frontend to render tokens as they come in instead of waiting for full answer
397
+ @app.post("/predict/stream")
398
+ def predict_stream(payload: PredictRequest) -> StreamingResponse:
399
+ req_start = time.perf_counter()
400
+
401
+ precheck_start = time.perf_counter()
402
+ if not state:
403
+ raise HTTPException(status_code=503, detail="Service not initialized yet")
404
+
405
+ query = payload.query.strip()
406
+ if not query:
407
+ raise HTTPException(status_code=400, detail="Query cannot be empty")
408
+ precheck_time = time.perf_counter() - precheck_start
409
+
410
+ state_access_start = time.perf_counter()
411
+ retriever: HybridRetriever = state["retriever"]
412
+ index = state["index"]
413
+ rag_engine: RAGGenerator = state["rag_engine"]
414
+ models: dict[str, Any] = state["models"]
415
+ state_access_time = time.perf_counter() - state_access_start
416
+
417
+ model_resolve_start = time.perf_counter()
418
+ model_name, model_instance = _resolve_model(payload.model, models)
419
+ model_resolve_time = time.perf_counter() - model_resolve_start
420
+
421
+ retrieval_start = time.perf_counter()
422
+ contexts = retriever.search(
423
+ query,
424
+ index,
425
+ mode=payload.mode,
426
+ rerank_strategy=payload.rerank_strategy,
427
+ use_mmr=True,
428
+ top_k=payload.top_k,
429
+ final_k=payload.final_k,
430
+ verbose=False,
431
+ )
432
+ retrieval_time = time.perf_counter() - retrieval_start
433
+
434
+ if not contexts:
435
+ raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")
436
+
437
+ def stream_events():
438
+ inference_start = time.perf_counter()
439
+ answer_parts: list[str] = []
440
+ try:
441
+ for token in rag_engine.get_answer_stream(model_instance, query, contexts, temperature=0.1):
442
+ answer_parts.append(token)
443
+ yield _to_ndjson({"type": "token", "token": token})
444
+
445
+ inference_time = time.perf_counter() - inference_start
446
+ total_time = time.perf_counter() - req_start
447
+ answer = "".join(answer_parts)
448
+ metrics = {
449
+ "precheck_s": round(precheck_time, 3),
450
+ "state_access_s": round(state_access_time, 3),
451
+ "model_resolve_s": round(model_resolve_time, 3),
452
+ "retrieval_s": round(retrieval_time, 3),
453
+ "inference_s": round(inference_time, 3),
454
+ "total_s": round(total_time, 3),
455
+ }
456
+
457
+ yield _to_ndjson(
458
+ {
459
+ "type": "done",
460
+ "model": model_name,
461
+ "answer": answer,
462
+ "metrics": metrics,
463
+ }
464
+ )
465
+ except Exception as exc:
466
+ yield _to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})
467
+
468
+ return StreamingResponse(
469
+ stream_events(),
470
+ media_type="application/x-ndjson",
471
+ headers={
472
+ "Cache-Control": "no-cache",
473
+ "X-Accel-Buffering": "no",
474
  },
475
  )
models/deepseek_v3.py CHANGED
@@ -17,7 +17,10 @@ class DeepSeek_V3:
17
  ):
18
  if message.choices:
19
  content = message.choices[0].delta.content
20
- if content: response += content
 
21
  except Exception as e:
22
- return f" DeepSeek API Busy: {e}"
23
- return response
 
 
 
17
  ):
18
  if message.choices:
19
  content = message.choices[0].delta.content
20
+ if content:
21
+ yield content
22
  except Exception as e:
23
+ yield f" DeepSeek API Busy: {e}"
24
+
25
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
26
+ return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
models/llama_3_8b.py CHANGED
@@ -16,5 +16,8 @@ class Llama3_8B:
16
  ):
17
  if message.choices:
18
  content = message.choices[0].delta.content
19
- if content: response += content
20
- return response
 
 
 
 
16
  ):
17
  if message.choices:
18
  content = message.choices[0].delta.content
19
+ if content:
20
+ yield content
21
+
22
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
23
+ return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
models/mistral_7b.py CHANGED
@@ -18,9 +18,10 @@ class Mistral_7b:
18
  for chunk in stream:
19
  if chunk.choices and chunk.choices[0].delta.content:
20
  content = chunk.choices[0].delta.content
21
- response += content
22
 
23
  except Exception as e:
24
- return f" Mistral Featherless Error: {e}"
25
-
26
- return response
 
 
18
  for chunk in stream:
19
  if chunk.choices and chunk.choices[0].delta.content:
20
  content = chunk.choices[0].delta.content
21
+ yield content
22
 
23
  except Exception as e:
24
+ yield f" Mistral Featherless Error: {e}"
25
+
26
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
27
+ return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
models/qwen_2_5.py CHANGED
@@ -16,5 +16,8 @@ class Qwen2_5:
16
  ):
17
  if message.choices:
18
  content = message.choices[0].delta.content
19
- if content: response += content
20
- return response
 
 
 
 
16
  ):
17
  if message.choices:
18
  content = message.choices[0].delta.content
19
+ if content:
20
+ yield content
21
+
22
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
23
+ return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
models/tiny_aya.py CHANGED
@@ -18,8 +18,10 @@ class TinyAya:
18
  ):
19
  if message.choices:
20
  content = message.choices[0].delta.content
21
- if content: response += content
 
22
  except Exception as e:
23
- return f" TinyAya Error: {e}"
24
-
25
- return response
 
 
18
  ):
19
  if message.choices:
20
  content = message.choices[0].delta.content
21
+ if content:
22
+ yield content
23
  except Exception as e:
24
+ yield f" TinyAya Error: {e}"
25
+
26
+ def generate(self, prompt, max_tokens=500, temperature=0.1):
27
+ return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
requirements.txt CHANGED
@@ -16,6 +16,7 @@ fastapi==0.121.1
16
  filelock==3.25.2
17
  frozenlist==1.8.0
18
  fsspec==2026.2.0
 
19
  greenlet==3.3.2
20
  h11==0.16.0
21
  hf-xet==1.4.2
 
16
  filelock==3.25.2
17
  frozenlist==1.8.0
18
  fsspec==2026.2.0
19
+ groq
20
  greenlet==3.3.2
21
  h11==0.16.0
22
  hf-xet==1.4.2
retriever/generator.py CHANGED
@@ -1,8 +1,10 @@
 
 
 
1
  class RAGGenerator:
2
  def generate_prompt(self, query, retrieved_contexts):
3
  """Prepares the academic prompt template."""
4
  context_text = "\n\n".join([f"--- Source {i+1} ---\n{c}" for i, c in enumerate(retrieved_contexts)])
5
-
6
  return f"""You are an expert academic assistant. Use the following pieces of retrieved context to answer the question.
7
  If the answer isn't in the context, say you don't know based on the provided documents.
8
 
@@ -16,4 +18,9 @@ Answer:"""
16
  def get_answer(self, model_instance, query, retrieved_contexts, **kwargs):
17
  """Uses a specific model instance to generate the final answer."""
18
  prompt = self.generate_prompt(query, retrieved_contexts)
19
- return model_instance.generate(prompt, **kwargs)
 
 
 
 
 
 
1
+ #changed the prompt to output as markdown, plus some formating details
2
+ #also added get answer stream for incremental token rendering on the frontend
3
+ # --@Qamar
4
  class RAGGenerator:
5
  def generate_prompt(self, query, retrieved_contexts):
6
  """Prepares the academic prompt template."""
7
  context_text = "\n\n".join([f"--- Source {i+1} ---\n{c}" for i, c in enumerate(retrieved_contexts)])
 
8
  return f"""You are an expert academic assistant. Use the following pieces of retrieved context to answer the question.
9
  If the answer isn't in the context, say you don't know based on the provided documents.
10
 
 
18
  def get_answer(self, model_instance, query, retrieved_contexts, **kwargs):
19
  """Uses a specific model instance to generate the final answer."""
20
  prompt = self.generate_prompt(query, retrieved_contexts)
21
+ return model_instance.generate(prompt, **kwargs)
22
+
23
+ def get_answer_stream(self, model_instance, query, retrieved_contexts, **kwargs):
24
+ """Yields model output chunks so the frontend can render incremental tokens."""
25
+ prompt = self.generate_prompt(query, retrieved_contexts)
26
+ return model_instance.generate_stream(prompt, **kwargs)
retriever/processor.py CHANGED
@@ -14,11 +14,16 @@ import pandas as pd
14
 
15
 
16
  class ChunkProcessor:
17
- def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True):
18
  self.model_name = model_name
19
  self.encoder = SentenceTransformer(model_name)
20
  self.verbose = verbose
21
- self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name)
 
 
 
 
 
22
 
23
  # ------------------------------------------------------------------
24
  # Splitters
@@ -84,7 +89,7 @@ class ChunkProcessor:
84
 
85
  elif technique == "semantic":
86
  return SemanticChunker(
87
- self.hf_embeddings,
88
  breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
89
  # Using 70 because 95 was giving way too big chunks
90
  breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 70)
 
14
 
15
 
16
  class ChunkProcessor:
17
+ def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True, load_hf_embeddings: bool = False):
18
  self.model_name = model_name
19
  self.encoder = SentenceTransformer(model_name)
20
  self.verbose = verbose
21
+ self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name) if load_hf_embeddings else None
22
+
23
+ def _get_hf_embeddings(self):
24
+ if self.hf_embeddings is None:
25
+ self.hf_embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
26
+ return self.hf_embeddings
27
 
28
  # ------------------------------------------------------------------
29
  # Splitters
 
89
 
90
  elif technique == "semantic":
91
  return SemanticChunker(
92
+ self._get_hf_embeddings(),
93
  breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
94
  # Using 70 because 95 was giving way too big chunks
95
  breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 70)
retriever/retriever.py CHANGED
@@ -30,6 +30,17 @@ class HybridRetriever:
30
  # Better tokenization for BM25 (strips punctuation)
31
  self.tokenized_corpus = [self._tokenize(chunk['metadata']['text']) for chunk in final_chunks]
32
  self.bm25 = BM25Okapi(self.tokenized_corpus)
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def _tokenize(self, text: str) -> List[str]:
35
  """Tokenize text using regex to strip punctuation."""
 
30
  # Better tokenization for BM25 (strips punctuation)
31
  self.tokenized_corpus = [self._tokenize(chunk['metadata']['text']) for chunk in final_chunks]
32
  self.bm25 = BM25Okapi(self.tokenized_corpus)
33
+ bm25_time = time.perf_counter() - bm25_start
34
+
35
+ total_time = time.perf_counter() - init_start
36
+ print(
37
+ "HybridRetriever init complete | "
38
+ f"chunks={len(final_chunks)} | "
39
+ f"reranker_load={reranker_time:.3f}s | "
40
+ f"tokenize={tokenization_time:.3f}s | "
41
+ f"bm25_build={bm25_time:.3f}s | "
42
+ f"total={total_time:.3f}s"
43
+ )
44
 
45
  def _tokenize(self, text: str) -> List[str]:
46
  """Tokenize text using regex to strip punctuation."""
vector_db.py CHANGED
@@ -1,7 +1,14 @@
1
  import time
2
  import re
 
 
 
3
  from pinecone import Pinecone, ServerlessSpec
4
 
 
 
 
 
5
  def slugify_technique(name):
6
  """Converts 'Sentence Splitter' to 'sentence-splitter' for Pinecone naming."""
7
  return re.sub(r'[^a-z0-9]+', '-', name.lower()).strip('-')
@@ -109,6 +116,80 @@ def upsert_to_pinecone(index, chunks, batch_size=100):
109
  batch = chunks[i : i + batch_size]
110
  index.upsert(vectors=batch)
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def load_chunks_from_pinecone(index, batch_size: int = 100) -> list[dict[str, any]]:
114
  """
 
1
  import time
2
  import re
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List
6
  from pinecone import Pinecone, ServerlessSpec
7
 
8
+
9
+ # Added cacheing to reduce consecutive startup time
10
+ # --@Qamar
11
+
12
  def slugify_technique(name):
13
  """Converts 'Sentence Splitter' to 'sentence-splitter' for Pinecone naming."""
14
  return re.sub(r'[^a-z0-9]+', '-', name.lower()).strip('-')
 
116
  batch = chunks[i : i + batch_size]
117
  index.upsert(vectors=batch)
118
 
119
+ # Some methods for loading chunks back from Pinecone with local caching to speed up BM25 initialization
120
+
121
+ def _sanitize_index_name(index_name: str) -> str:
122
+ return re.sub(r'[^a-zA-Z0-9._-]+', '-', index_name).strip('-') or 'default-index'
123
+
124
+
125
+ def _chunk_cache_path(cache_dir: str, index_name: str) -> Path:
126
+ cache_root = Path(cache_dir)
127
+ cache_root.mkdir(parents=True, exist_ok=True)
128
+ safe_name = _sanitize_index_name(index_name)
129
+ return cache_root / f"bm25_chunks_{safe_name}.json"
130
+
131
+
132
+ def _read_chunk_cache(path: Path) -> Dict[str, Any]:
133
+ with path.open("r", encoding="utf-8") as f:
134
+ return json.load(f)
135
+
136
+
137
+ def _write_chunk_cache(path: Path, payload: Dict[str, Any]) -> None:
138
+ with path.open("w", encoding="utf-8") as f:
139
+ json.dump(payload, f)
140
+
141
+
142
+ def load_chunks_with_local_cache(
143
+ index,
144
+ index_name: str,
145
+ cache_dir: str = ".cache",
146
+ batch_size: int = 100,
147
+ force_refresh: bool = False,
148
+ ) -> tuple[List[Dict[str, Any]], str]:
149
+
150
+ cache_file = _chunk_cache_path(cache_dir=cache_dir, index_name=index_name)
151
+ stats = index.describe_index_stats()
152
+ current_count = stats.get("total_vector_count", 0)
153
+
154
+ if not force_refresh and cache_file.exists():
155
+ try:
156
+ cached_payload = _read_chunk_cache(cache_file)
157
+ cached_meta = cached_payload.get("meta", {})
158
+ cached_count = cached_meta.get("vector_count", -1)
159
+ cached_chunks = cached_payload.get("chunks", [])
160
+
161
+ if cached_count == current_count and cached_chunks:
162
+ print(
163
+ f" Loaded BM25 chunk cache: {cache_file} "
164
+ f"(chunks={len(cached_chunks)}, vectors={cached_count})"
165
+ )
166
+ return cached_chunks, "cache"
167
+
168
+ print(
169
+ " BM25 cache stale or empty. "
170
+ f"cache_vectors={cached_count}, pinecone_vectors={current_count}. Refreshing..."
171
+ )
172
+ except Exception as e:
173
+ print(f" Failed to read BM25 cache ({cache_file}): {e}. Refreshing from Pinecone...")
174
+
175
+ chunks = load_chunks_from_pinecone(index=index, batch_size=batch_size)
176
+ payload = {
177
+ "meta": {
178
+ "index_name": index_name,
179
+ "vector_count": current_count,
180
+ "updated_at_epoch_s": int(time.time()),
181
+ },
182
+ "chunks": chunks,
183
+ }
184
+
185
+ try:
186
+ _write_chunk_cache(cache_file, payload)
187
+ print(f" Saved BM25 chunk cache: {cache_file} (chunks={len(chunks)})")
188
+ except Exception as e:
189
+ print(f" Failed to write BM25 cache ({cache_file}): {e}")
190
+
191
+ return chunks, "pinecone"
192
+
193
 
194
  def load_chunks_from_pinecone(index, batch_size: int = 100) -> list[dict[str, any]]:
195
  """