"""
EWU RAG Server — Memory-efficient edition for HuggingFace Spaces
=================================================================
Stack:
Retrieval : BM25 (sparse) + intfloat/multilingual-e5-small (dense) → RRF fusion
Generation : TinyLlama-1.1B-Chat (local, float32, CPU)
             If HF_API_TOKEN is set, the HF Inference API is tried first and
             local TinyLlama serves as the fallback
Extras : No reranker · No KG · No multi-query · No MMR
RAM budget (free-tier, 16 GB):
TinyLlama float32 ........... ~4.4 GB
multilingual-e5-small ........ ~120 MB
FAISS flat index (5k docs) ...... ~8 MB
BM25 + doc store ............. ~100 MB
FastAPI + overhead ........... ~300 MB
───────────────────────────── ~5 GB ✓
Env vars:
HF_API_TOKEN optional — enables HF Inference API fallback
HF_GEN_MODEL default: TinyLlama/TinyLlama-1.1B-Chat-v1.0
API_BASE EWU API base URL
API_KEY bearer token for API_BASE (optional)
GITHUB_BASE raw GitHub URL prefix for JSON files
"""
import asyncio
import gc
import logging
import os
import pickle
import re
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Tuple
import httpx
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# ── optional deps ──────────────────────────────────────────────────────────────
try:
import faiss
FAISS_OK = True
except ImportError:
FAISS_OK = False
try:
from sentence_transformers import SentenceTransformer
ST_OK = True
except ImportError:
ST_OK = False
try:
from rank_bm25 import BM25Okapi
BM25_OK = True
except ImportError:
BM25_OK = False
# ── config ─────────────────────────────────────────────────────────────────────
HF_API_TOKEN = os.getenv("HF_API_TOKEN", "")
HF_GEN_MODEL = os.getenv("HF_GEN_MODEL", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
HF_API_URL = f"https://api-inference.huggingface.co/models/{HF_GEN_MODEL}"
HF_API_TIMEOUT = int(os.getenv("HF_API_TIMEOUT", "300"))
HF_MAX_NEW_TOKENS = int(os.getenv("HF_MAX_NEW_TOKENS", "300"))
EMBED_MODEL = os.getenv("EMBED_MODEL", "intfloat/multilingual-e5-small")
# Keep generation short to limit CPU time
GEN_MAX_NEW_TOKENS = 30
GEN_TIMEOUT_S = 40
GEN_PROMPT_MAX_CHARS = 120
API_BASE = os.getenv("API_BASE", "https://ewu-server.onrender.com/api")
API_KEY = os.getenv("API_KEY", "i6EDytaX4E2jI6GvZQc0b1RSZHTI5_wVRa2rfL7rLpk")
API_HEADERS = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}
GITHUB_BASE = os.getenv(
"GITHUB_BASE",
"https://raw.githubusercontent.com/Atkiya/jsonfiles/main/",
)
TOP_K_RETRIEVE = int(os.getenv("TOP_K_RETRIEVE", "12"))
TOP_K_FINAL = int(os.getenv("TOP_K_FINAL", "5"))
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "500"))
CACHE_DIR = os.getenv("CACHE_DIR", "./cache")
CACHE_TTL_H = int(os.getenv("CACHE_TTL_H", "24"))
CACHE_VERSION = "v1_tinyllama_e5"
os.makedirs(CACHE_DIR, exist_ok=True)
# ── data sources ───────────────────────────────────────────────────────────────
API_LIST_ENDPOINTS = [
("admission-deadlines", "Admission Deadlines", {}),
("academic-calendar", "Academic Calendar", {}),
("grade-scale", "Grade Scale", {}),
("departments", "Departments", {}),
("programs", "Programs", {}),
("tuition-fees", "Tuition Fees", {}),
("scholarships", "Scholarships", {}),
("clubs", "Clubs", {}),
("events", "Events", {}),
("notices", "Notices", {"limit": 200}),
("helpdesk", "Helpdesk", {}),
("faculty", "Faculty", {"limit": 200}),
("courses", "Courses", {"limit": 500}),
("policies", "Policies", {}),
("governance", "Governance", {}),
("alumni", "Alumni", {}),
("partnerships", "Partnerships", {}),
]
GITHUB_FILES = [
"dynamic_admission_process.json",
"dynamic_admission_requirements.json",
"dynamic_facilites.json",
"ma_english.json", "mba_emba.json", "ms_cse.json", "ms_dsa.json",
"mds.json", "mphil_pharmacy.json", "mss_eco.json",
"scholarships_and_financial_aids.json",
"st_ba.json", "st_ce.json", "st_cse.json", "st_ece.json",
"st_economics.json", "st_eee.json", "st_english.json",
"st_geb.json", "st_information_studies.json", "st_law.json",
"st_math.json", "st_pharmacy.json", "st_social_relations.json",
"st_sociology.json", "syndicate.json", "tesol.json",
"ewu_board_of_trustees.json", "admission_deadlines.json",
"ewu_faculty_complete.json", "dynamic_grading.json",
"ewu_proctor_schedule.json", "ewu_newsletters_complete.json",
"static_aboutEWU.json", "static_Admin.json",
"static_AllAvailablePrograms.json", "static_alumni.json",
"static_campus_life.json", "static_Career_Counseling_Center.json",
"static_clubs.json", "static_depts.json", "static_facilities.json",
"static_helpdesk.json", "static_payment_procedure.json",
"static_Policy.json", "static_Programs.json", "static_Rules.json",
"static_Sexual_harassment.json", "static_Tuition_fees.json",
]
# ── app state ──────────────────────────────────────────────────────────────────
class AppState:
embedder = None
gen_model = None
gen_tokenizer = None
documents: List[Dict] = []
doc_embeddings: Optional[np.ndarray] = None
faiss_index = None
bm25 = None
ready: bool = False
error: str = ""
state = AppState()
# ── cache helpers ──────────────────────────────────────────────────────────────
def _cp(name: str) -> str:
return os.path.join(CACHE_DIR, f"{CACHE_VERSION}_{name}")
def _cache_fresh(name: str) -> bool:
p = _cp(name)
return os.path.exists(p) and (time.time() - os.path.getmtime(p)) / 3600 < CACHE_TTL_H
def _save(name: str, obj: Any) -> None:
try:
with open(_cp(name), "wb") as f:
pickle.dump(obj, f, protocol=5)
logger.info("[cache] saved %s", name)
except Exception as e:
logger.warning("[cache] save %s failed: %s", name, e)
def _load(name: str) -> Optional[Any]:
try:
with open(_cp(name), "rb") as f:
return pickle.load(f)
except Exception:
return None
def _save_faiss(idx) -> None:
if FAISS_OK and idx is not None:
try:
faiss.write_index(idx, _cp("faiss.index"))
except Exception as e:
logger.warning("[cache] faiss save: %s", e)
def _load_faiss():
p = _cp("faiss.index")
if FAISS_OK and os.path.exists(p):
try:
return faiss.read_index(p)
except Exception:
pass
return None
# ── language ───────────────────────────────────────────────────────────────────
_BANGLA_MIN, _BANGLA_MAX = 0x0980, 0x09FF
_BANGLISH_KW = {
"ami","tumi","apni","ki","koto","kothay","ache","hobe","vorti","bhorti",
"fee","britti","bibhag","thikana","shesh","tarikh","kibhabe","kivabe",
}
BANGLISH_MAP = {
"vorti":"admission","bhorti":"admission","britti":"scholarship",
"bibhag":"department","thikana":"address","shesh":"last",
"tarikh":"date","fee":"fee","fees":"fee","kibhabe":"how","kivabe":"how",
}
NORMALIZE_RULES: List[Tuple[str, str]] = [
(r"\bvc\b|উপাচার্য|ভিসি", "vice chancellor"),
(r"ভর্তি|এডমিশন|\bvorti\b|\bbhorti\b", "admission"),
(r"বৃত্তি|স্কলারশিপ|\bbritti\b", "scholarship"),
(r"ফি|টিউশন|\bfee\b|\bfees\b|\btuition\b", "tuition fee"),
(r"ঠিকানা|কোথায়|\bthikana\b|\bkothay\b", "address location"),
(r"ডেডলাইন|শেষ তারিখ|\bdeadline\b", "deadline"),
(r"বিভাগ|\bbibhag\b|\bdepartment\b|\bdept\b", "department"),
(r"কোর্স|\bcourse\b|\bsubject\b", "course"),
]
def detect_language(text: str) -> str:
if any(_BANGLA_MIN <= ord(ch) <= _BANGLA_MAX for ch in text):
return "bangla"
if set(re.findall(r"\w+", text.lower())) & _BANGLISH_KW:
return "banglish"
return "english"
def normalize_query(text: str) -> str:
q = text.strip().lower()
    # Replace whole Banglish tokens only, so a mapping never rewrites part of a longer word.
    for tok in re.findall(r"\w+", q):
        if tok in BANGLISH_MAP:
            q = re.sub(rf"\b{re.escape(tok)}\b", BANGLISH_MAP[tok], q)
for pat, repl in NORMALIZE_RULES:
q = re.sub(pat, repl, q, flags=re.IGNORECASE)
q = re.sub(r"[^\w\s\u0980-\u09FF]", " ", q)
return re.sub(r"\s+", " ", q).strip()
def tokenize(text: str) -> List[str]:
return [t for t in re.findall(r"[\w\u0980-\u09FF]+", normalize_query(text)) if len(t) > 1]
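# Illustrative behaviour of the helpers above (hypothetical inputs):
#   detect_language("ভর্তি কবে?")        → "bangla"
#   detect_language("vorti kobe?")       → "banglish"  ("vorti" is a Banglish keyword)
#   normalize_query("vorti fee koto?")   → "admission tuition fee koto"
#   tokenize("vorti fee koto?")          → ["admission", "tuition", "fee", "koto"]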
# ── fetching ───────────────────────────────────────────────────────────────────
async def fetch_json(url: str, headers: dict = None,
params: dict = None, timeout: int = 60) -> Optional[Any]:
try:
async with httpx.AsyncClient(timeout=timeout) as client:
r = await client.get(url, headers=headers or {}, params=params or {})
if r.status_code == 200:
return r.json()
logger.warning("HTTP %s → %s", r.status_code, url)
except Exception as e:
logger.warning("Fetch %s: %s", url, e)
return None
def _unwrap(data: Any) -> list:
if isinstance(data, list):
return data
if isinstance(data, dict):
for key in ("data", "results", "items"):
if isinstance(data.get(key), list):
return data[key]
return [data]
return []
async def load_api() -> list:
# Wake the Render.com server (it sleeps on free tier)
for attempt in range(3):
if await fetch_json(f"{API_BASE}/grade-scale", API_HEADERS, timeout=60):
logger.info("API server awake.")
break
logger.info("Wake attempt %d/3 …", attempt + 1)
await asyncio.sleep(10)
else:
logger.warning("API server unreachable — skipping.")
return []
results = await asyncio.gather(
*[fetch_json(f"{API_BASE}/{s}", API_HEADERS, params=p, timeout=60)
for s, _, p in API_LIST_ENDPOINTS],
return_exceptions=True,
)
docs = []
for (suffix, label, _), data in zip(API_LIST_ENDPOINTS, results):
if not data or isinstance(data, Exception):
continue
items = _unwrap(data)
for item in items:
docs.append({"raw": item, "source": f"api:{suffix}"})
logger.info("%-28s → %d records", label, len(items))
logger.info("API total: %d docs", len(docs))
return docs
async def load_github() -> list:
responses = await asyncio.gather(
*[fetch_json(GITHUB_BASE + f, timeout=60) for f in GITHUB_FILES],
return_exceptions=True,
)
docs = []
for fname, data in zip(GITHUB_FILES, responses):
if not data or isinstance(data, Exception):
continue
for item in (data if isinstance(data, list) else [data]):
docs.append({"raw": item, "source": f"github:{fname}"})
logger.info("GitHub total: %d docs", len(docs))
return docs
# ── chunking ───────────────────────────────────────────────────────────────────
FIELD_PRIORITY = [
"title","name","department","program_name","course_code","course_title",
"credit","deadline","date","email","phone","amount","fee","description",
"eligibility","requirements","address","location",
]
def _flatten(obj: Any, path: str = "") -> List[Tuple[str, str]]:
out = []
if isinstance(obj, dict):
for k, v in obj.items():
p = f"{path} > {k}" if path else k
if isinstance(v, (dict, list)):
out.extend(_flatten(v, p))
else:
s = str(v).strip()
if s and s.lower() not in ("null", "none", "", "[]", "{}"):
out.append((p, s))
elif isinstance(obj, list):
for i, item in enumerate(obj):
if isinstance(item, (dict, list)):
out.extend(_flatten(item, f"{path}[{i}]"))
else:
s = str(item).strip()
if s:
out.append((f"{path}[{i}]", s))
return out
def make_chunk(raw: Any, source: str) -> str:
if not isinstance(raw, (dict, list)):
return f"source: {source} | {str(raw).strip()}"[:CHUNK_SIZE]
pairs = _flatten(raw)
def _pri(k: str) -> int:
kl = k.lower()
for i, f in enumerate(FIELD_PRIORITY):
if f in kl:
return i
return 999
pairs.sort(key=lambda kv: _pri(kv[0]))
lines = [f"source: {source}"] + [f"{k}: {v}" for k, v in pairs[:20]]
return " | ".join(lines)[:CHUNK_SIZE]
def chunk_documents(raw_docs: List[Dict]) -> List[Dict]:
out = []
for doc in raw_docs:
text = make_chunk(doc["raw"], doc["source"])
if text.strip():
out.append({"content": text, "source": doc["source"]})
return out
# ── indexing ───────────────────────────────────────────────────────────────────
def build_bm25(documents: List[Dict]) -> Optional[Any]:
if not BM25_OK:
return None
    # Keep one token list per chunk (even if empty) so BM25 row i always maps to
    # documents[i], which sparse_retrieve relies on when it indexes state.documents.
    toks = [tokenize(d["content"]) for d in documents]
    if not any(toks):
        return None
obj = BM25Okapi(toks)
logger.info("[BM25] %d docs indexed.", len(toks))
return obj
def build_faiss(documents: List[Dict], embedder) -> Tuple[Optional[Any], Optional[np.ndarray]]:
if not FAISS_OK or embedder is None:
return None, None
try:
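        # e5-family models are trained with "passage: " / "query: " prefixes; since the
        # embeddings are L2-normalized, inner product (IndexFlatIP) equals cosine similarity.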
texts = [f"passage: {d['content']}" for d in documents]
emb = embedder.encode(
texts,
normalize_embeddings=True,
show_progress_bar=True,
batch_size=32, # smaller batch → lower peak RAM
)
emb = np.array(emb, dtype="float32")
idx = faiss.IndexFlatIP(emb.shape[1])
idx.add(emb)
logger.info("[FAISS] %d vectors, dim=%d.", idx.ntotal, emb.shape[1])
return idx, emb
except Exception as e:
logger.error("[FAISS] build failed: %s", e)
return None, None
# ── retrieval ──────────────────────────────────────────────────────────────────
def dense_retrieve(query: str, k: int = TOP_K_RETRIEVE) -> List[Dict]:
if state.faiss_index is None or state.embedder is None:
return []
try:
q_vec = state.embedder.encode(
[f"query: {query}"], normalize_embeddings=True
)
q_vec = np.array(q_vec, dtype="float32")
k_a = min(k, state.faiss_index.ntotal)
scores, ids = state.faiss_index.search(q_vec, k_a)
return [
{**state.documents[i], "dense_score": float(s)}
for s, i in zip(scores[0], ids[0]) if i >= 0
]
except Exception as e:
logger.error("[dense] %s", e)
return []
def sparse_retrieve(query: str, k: int = TOP_K_RETRIEVE) -> List[Dict]:
if state.bm25 is None:
return []
try:
toks = tokenize(query)
if not toks:
return []
scores = np.array(state.bm25.get_scores(toks), dtype="float32")
idxs = np.argsort(scores)[::-1][:k]
return [
{**state.documents[i], "sparse_score": float(scores[i])}
for i in idxs if scores[i] > 0
]
except Exception as e:
logger.error("[sparse] %s", e)
return []
def rrf_fuse(dense: List[Dict], sparse: List[Dict],
w_dense: float = 0.6, w_sparse: float = 0.4,
rrf_k: int = 60) -> List[Dict]:
"""Reciprocal-rank fusion of two ranked lists."""
scores: Dict[str, float] = {}
doc_map: Dict[str, Dict] = {}
for rank, d in enumerate(dense):
key = d["source"] + "||" + d["content"]
scores[key] = scores.get(key, 0.0) + w_dense / (rrf_k + rank + 1)
doc_map[key] = d
for rank, d in enumerate(sparse):
key = d["source"] + "||" + d["content"]
scores[key] = scores.get(key, 0.0) + w_sparse / (rrf_k + rank + 1)
doc_map[key] = d
return [
{**doc_map[k], "rrf_score": round(s, 6)}
for k, s in sorted(scores.items(), key=lambda x: x[1], reverse=True)
]
async def retrieve(query: str, k: int = TOP_K_FINAL) -> List[Dict]:
norm = normalize_query(query)
dense = await asyncio.to_thread(dense_retrieve, norm, TOP_K_RETRIEVE)
sparse = await asyncio.to_thread(sparse_retrieve, norm, TOP_K_RETRIEVE)
fused = rrf_fuse(dense, sparse)
return fused[:k]
# ── generation ─────────────────────────────────────────────────────────────────
SYSTEM_PROMPT = (
"You are EWU Assistant for East West University, Dhaka, Bangladesh. "
"Answer ONLY from the provided context. Be concise and clear. "
"Use bullet points for lists. Never invent fees, dates, names, or codes. "
"If the context does not contain the answer, say so honestly. "
"You support English, Bangla (বাংলা), and Banglish."
)
def _lang_instruction(lang: str) -> str:
if lang == "bangla":
return "Respond entirely in Bengali (বাংলা script)."
if lang == "banglish":
return "Respond in Banglish — Bengali meaning written with English letters."
return "Respond in clear English."
def _fallback(lang: str) -> str:
url = "https://www.ewubd.edu/"
if lang == "bangla":
return f"দুঃখিত, উত্তর দিতে পারছি না। EWU ওয়েবসাইট দেখুন: {url}"
if lang == "banglish":
return f"Sorry, ekhon answer dite parcchi na. EWU website dekhun: {url}"
return f"Sorry, I couldn't generate a response. Please visit: {url}"
# ── HF Inference API (optional fallback) ──────────────────────────────────────
async def _generate_hf_api(query: str, context: str, lang: str) -> str:
if not HF_API_TOKEN:
return ""
prompt = (
f"<|system|>\n{SYSTEM_PROMPT}\n{_lang_instruction(lang)}</s>\n"
f"<|user|>\nContext:\n{context[:GEN_PROMPT_MAX_CHARS]}\n\nQuestion: {query}</s>\n"
f"<|assistant|>\n"
)
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": HF_MAX_NEW_TOKENS,
"temperature": 0.3,
"top_p": 0.9,
"repetition_penalty": 1.1,
"return_full_text": False,
},
}
headers = {"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"}
async with httpx.AsyncClient(timeout=HF_API_TIMEOUT) as client:
for attempt in range(2):
try:
r = await client.post(HF_API_URL, headers=headers, json=payload)
if r.status_code == 200:
data = r.json()
if isinstance(data, list) and data:
return data[0].get("generated_text", "").strip()
return ""
if r.status_code == 503 and attempt == 0:
wait = min(int(r.json().get("estimated_time", 20)), 30)
logger.info("[HF API] model loading, waiting %ds …", wait)
await asyncio.sleep(wait)
continue
logger.error("[HF API] HTTP %s: %s", r.status_code, r.text[:200])
return ""
except httpx.TimeoutException:
logger.error("[HF API] timeout.")
return ""
except Exception as e:
logger.error("[HF API] %s", e)
return ""
return ""
# ── local TinyLlama generation ─────────────────────────────────────────────────
def _generate_local_sync(query: str, context: str, lang: str) -> str:
import torch
model = state.gen_model
tokenizer = state.gen_tokenizer
if model is None or tokenizer is None:
return ""
messages = [
{"role": "system", "content": SYSTEM_PROMPT + "\n" + _lang_instruction(lang)},
{"role": "user", "content": f"Context:\n{context[:GEN_PROMPT_MAX_CHARS]}\n\nQuestion: {query}"},
]
try:
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=1024, # keep prompt short to save RAM
)
prompt_len = inputs["input_ids"].shape[-1]
logger.info("[local] prompt_tokens=%d max_new=%d", prompt_len, GEN_MAX_NEW_TOKENS)
with torch.inference_mode():
logger.info("[local] starting generate")
output_ids = model.generate(
**inputs,
max_new_tokens = GEN_MAX_NEW_TOKENS,
do_sample = True,
temperature = 0.3,
top_p = 0.9,
repetition_penalty = 1.1,
pad_token_id = tokenizer.eos_token_id,
eos_token_id = tokenizer.eos_token_id,
)
logger.info("[local] generate finished")
new_ids = output_ids[0][prompt_len:]
answer = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
logger.info("[local] generated %d tokens → %d chars.", len(new_ids), len(answer))
return answer
except Exception as e:
logger.error("[local] generation failed: %s", e)
return ""
finally:
# Free activation memory after each call
gc.collect()
async def generate(query: str, context: str, lang: str) -> str:
# 1. Try HF Inference API first (costs 0 local RAM)
if HF_API_TOKEN:
answer = await _generate_hf_api(query, context, lang)
if answer:
return answer
logger.warning("[gen] HF API failed — falling back to local TinyLlama.")
# 2. Local TinyLlama
if state.gen_model is not None:
try:
answer = await asyncio.wait_for(
asyncio.to_thread(_generate_local_sync, query, context, lang),
timeout=GEN_TIMEOUT_S,
)
if answer:
return answer
except asyncio.TimeoutError:
logger.error("[gen] TinyLlama timed out after %ds.", GEN_TIMEOUT_S)
except Exception as e:
logger.error("[gen] TinyLlama error: %s", e)
return _fallback(lang)
# ── model loading ──────────────────────────────────────────────────────────────
def _load_embedder():
if not ST_OK:
logger.warning("[embed] sentence-transformers not installed.")
return None
try:
logger.info("[embed] Loading %s …", EMBED_MODEL)
model = SentenceTransformer(EMBED_MODEL, device="cpu")
logger.info("[embed] ✓ Ready.")
return model
except Exception as e:
logger.error("[embed] Failed: %s", e)
return None
def _load_tinyllama():
try:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
logger.info("[gen] Loading TinyLlama (%s) …", HF_GEN_MODEL)
tokenizer = AutoTokenizer.from_pretrained(HF_GEN_MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
HF_GEN_MODEL,
torch_dtype = torch.float32, # float16 unstable on CPU
low_cpu_mem_usage = True, # stream weights → lower peak RAM
trust_remote_code = True,
)
model.eval()
# Limit CPU threads so we don't starve other workers
torch.set_num_threads(2)
logger.info("[gen] ✓ TinyLlama ready (float32, CPU).")
return model, tokenizer
except Exception as e:
logger.error("[gen] TinyLlama load failed: %s", e)
return None, None
# ── boot ───────────────────────────────────────────────────────────────────────
async def _boot():
try:
logger.info("=== EWU RAG (TinyLlama + e5-small, lean pipeline) ===")
# Step 1 — load embedding model (small, fast)
state.embedder = await asyncio.to_thread(_load_embedder)
# Step 2 — try loading indexes from cache
cache_ready = (
_cache_fresh("documents.pkl")
and _cache_fresh("bm25.pkl")
and (_cache_fresh("faiss.index") or not FAISS_OK)
)
if cache_ready:
docs = _load("documents.pkl")
bm25 = _load("bm25.pkl")
if docs and bm25:
state.documents = docs
state.bm25 = bm25
state.faiss_index, state.doc_embeddings = None, None
idx = _load_faiss()
if idx is not None:
state.faiss_index = idx
logger.info("[cache] FAISS loaded (%d vectors).", idx.ntotal)
logger.info("[cache] %d chunks loaded from disk.", len(docs))
                # Load TinyLlama only after the data is in place (heaviest component goes last,
                # so any OOM is isolated to the model load)
state.gen_model, state.gen_tokenizer = await asyncio.to_thread(_load_tinyllama)
state.ready = True
return
# Step 3 — fetch fresh data
logger.info("Fetching data (API + GitHub) …")
api_docs, gh_docs = await asyncio.gather(load_api(), load_github())
raw_docs = api_docs + gh_docs
logger.info("Total raw docs: %d", len(raw_docs))
if not raw_docs:
logger.warning("No documents fetched — starting empty.")
state.ready = True
return
# Step 4 — chunk
logger.info("Chunking …")
state.documents = await asyncio.to_thread(chunk_documents, raw_docs)
logger.info("Total chunks: %d", len(state.documents))
# Step 5 — BM25 index
state.bm25 = await asyncio.to_thread(build_bm25, state.documents)
_save("documents.pkl", state.documents)
if state.bm25:
_save("bm25.pkl", state.bm25)
# Step 6 — FAISS index (only if embedder loaded successfully)
if state.embedder:
logger.info("Building FAISS index …")
idx, emb = await asyncio.to_thread(
build_faiss, state.documents, state.embedder
)
state.faiss_index = idx
state.doc_embeddings = emb
if idx is not None:
_save_faiss(idx)
# Step 7 — load TinyLlama last (heaviest, ~4.4 GB)
state.gen_model, state.gen_tokenizer = await asyncio.to_thread(_load_tinyllama)
state.ready = True
logger.info("✓ EWU RAG fully ready.")
except Exception as e:
state.error = str(e)
state.ready = False
logger.exception("Boot failed")
@asynccontextmanager
async def lifespan(app: FastAPI):
task = asyncio.create_task(_boot())
try:
yield
finally:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# ── FastAPI ────────────────────────────────────────────────────────────────────
app = FastAPI(title="EWU RAG (TinyLlama + e5-small)", lifespan=lifespan)
class Query(BaseModel):
query: str
top_k: int = TOP_K_FINAL
@app.get("/")
async def root():
return {
"service": "EWU RAG Server",
"status": "ready" if state.ready else ("error" if state.error else "loading"),
"docs_loaded": len(state.documents),
"faiss": state.faiss_index is not None,
"bm25": state.bm25 is not None,
"gen_model": HF_GEN_MODEL,
"gen_loaded": state.gen_model is not None,
"endpoints": {
"POST /rag": "Submit a question",
"GET /health": "Detailed health check",
},
}
@app.post("/rag")
async def rag_endpoint(q: Query):
if not state.ready:
raise HTTPException(503, detail=state.error or "Still initializing — retry shortly.")
raw_query = q.query.strip()
if not raw_query:
raise HTTPException(400, detail="Query must not be empty.")
lang = detect_language(raw_query)
top_k = max(1, min(q.top_k, 8))
results = await retrieve(raw_query, k=top_k)
if not results:
return {
"answer": _fallback(lang),
"lang": lang,
"sources": [],
}
context = "\n\n---\n\n".join(r["content"] for r in results)
answer = await generate(raw_query, context, lang)
return {
"answer": answer,
"lang": lang,
"sources": [
{
"source": r["source"],
"rrf_score": round(r.get("rrf_score", 0), 6),
"dense_score": round(r.get("dense_score", 0), 4) if "dense_score" in r else None,
"sparse_score": round(r.get("sparse_score", 0), 4) if "sparse_score" in r else None,
}
for r in results
],
}
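# Illustrative /rag response shape (values are hypothetical):
#   {"answer": "...", "lang": "english",
#    "sources": [{"source": "api:tuition-fees", "rrf_score": 0.0162,
#                 "dense_score": 0.83, "sparse_score": 7.4}]}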
@app.get("/health")
async def health():
return {
"status": "ready" if state.ready else ("error" if state.error else "loading"),
"docs": len(state.documents),
"embed_model": EMBED_MODEL,
"gen_model": HF_GEN_MODEL,
"gen_loaded": state.gen_model is not None,
"hf_api_token": bool(HF_API_TOKEN),
"faiss": state.faiss_index is not None,
"bm25": state.bm25 is not None,
"max_new_tokens": GEN_MAX_NEW_TOKENS,
"gen_timeout_s": GEN_TIMEOUT_S,
"error": state.error or None,
}
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)