| """ | |
| EWU RAG Server — Memory-efficient edition for HuggingFace Spaces | |
| ================================================================= | |
| Stack: | |
| Retrieval : BM25 (sparse) + intfloat/multilingual-e5-small (dense) → RRF fusion | |
| Generation : TinyLlama-1.1B-Chat (local, float32, CPU) | |
| Falls back to HF Inference API if HF_API_TOKEN is set | |
| Extras : No reranker · No KG · No multi-query · No MMR | |
| RAM budget (free-tier, 16 GB): | |
| TinyLlama float32 ........... ~4.4 GB | |
| multilingual-e5-small ........ ~120 MB | |
| FAISS flat index (5k docs) ...... ~8 MB | |
| BM25 + doc store ............. ~100 MB | |
| FastAPI + overhead ........... ~300 MB | |
| ───────────────────────────── ~5 GB ✓ | |
| Env vars: | |
| HF_API_TOKEN optional — enables HF Inference API fallback | |
| HF_GEN_MODEL default: TinyLlama/TinyLlama-1.1B-Chat-v1.0 | |
| API_BASE EWU API base URL | |
| API_KEY bearer token for API_BASE (optional) | |
| GITHUB_BASE raw GitHub URL prefix for JSON files | |
| """ | |
import asyncio
import gc
import logging
import os
import pickle
import re
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Tuple

import httpx
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# ── optional deps ──────────────────────────────────────────────────────────────
try:
    import faiss
    FAISS_OK = True
except ImportError:
    FAISS_OK = False

try:
    from sentence_transformers import SentenceTransformer
    ST_OK = True
except ImportError:
    ST_OK = False

try:
    from rank_bm25 import BM25Okapi
    BM25_OK = True
except ImportError:
    BM25_OK = False

# ── config ─────────────────────────────────────────────────────────────────────
HF_API_TOKEN = os.getenv("HF_API_TOKEN", "")
HF_GEN_MODEL = os.getenv("HF_GEN_MODEL", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
HF_API_URL = f"https://api-inference.huggingface.co/models/{HF_GEN_MODEL}"
HF_API_TIMEOUT = int(os.getenv("HF_API_TIMEOUT", "300"))
HF_MAX_NEW_TOKENS = int(os.getenv("HF_MAX_NEW_TOKENS", "300"))

EMBED_MODEL = os.getenv("EMBED_MODEL", "intfloat/multilingual-e5-small")

# Keep generation short to limit CPU time
GEN_MAX_NEW_TOKENS = 30
GEN_TIMEOUT_S = 40
GEN_PROMPT_MAX_CHARS = 120

API_BASE = os.getenv("API_BASE", "https://ewu-server.onrender.com/api")
API_KEY = os.getenv("API_KEY", "i6EDytaX4E2jI6GvZQc0b1RSZHTI5_wVRa2rfL7rLpk")
API_HEADERS = {"Authorization": f"Bearer {API_KEY}"} if API_KEY else {}

GITHUB_BASE = os.getenv(
    "GITHUB_BASE",
    "https://raw.githubusercontent.com/Atkiya/jsonfiles/main/",
)

TOP_K_RETRIEVE = int(os.getenv("TOP_K_RETRIEVE", "12"))
TOP_K_FINAL = int(os.getenv("TOP_K_FINAL", "5"))
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "500"))

CACHE_DIR = os.getenv("CACHE_DIR", "./cache")
CACHE_TTL_H = int(os.getenv("CACHE_TTL_H", "24"))
CACHE_VERSION = "v1_tinyllama_e5"
os.makedirs(CACHE_DIR, exist_ok=True)

# ── data sources ───────────────────────────────────────────────────────────────
API_LIST_ENDPOINTS = [
    ("admission-deadlines", "Admission Deadlines", {}),
    ("academic-calendar", "Academic Calendar", {}),
    ("grade-scale", "Grade Scale", {}),
    ("departments", "Departments", {}),
    ("programs", "Programs", {}),
    ("tuition-fees", "Tuition Fees", {}),
    ("scholarships", "Scholarships", {}),
    ("clubs", "Clubs", {}),
    ("events", "Events", {}),
    ("notices", "Notices", {"limit": 200}),
    ("helpdesk", "Helpdesk", {}),
    ("faculty", "Faculty", {"limit": 200}),
    ("courses", "Courses", {"limit": 500}),
    ("policies", "Policies", {}),
    ("governance", "Governance", {}),
    ("alumni", "Alumni", {}),
    ("partnerships", "Partnerships", {}),
]

GITHUB_FILES = [
    "dynamic_admission_process.json",
    "dynamic_admission_requirements.json",
    "dynamic_facilites.json",
    "ma_english.json", "mba_emba.json", "ms_cse.json", "ms_dsa.json",
    "mds.json", "mphil_pharmacy.json", "mss_eco.json",
    "scholarships_and_financial_aids.json",
    "st_ba.json", "st_ce.json", "st_cse.json", "st_ece.json",
    "st_economics.json", "st_eee.json", "st_english.json",
    "st_geb.json", "st_information_studies.json", "st_law.json",
    "st_math.json", "st_pharmacy.json", "st_social_relations.json",
    "st_sociology.json", "syndicate.json", "tesol.json",
    "ewu_board_of_trustees.json", "admission_deadlines.json",
    "ewu_faculty_complete.json", "dynamic_grading.json",
    "ewu_proctor_schedule.json", "ewu_newsletters_complete.json",
    "static_aboutEWU.json", "static_Admin.json",
    "static_AllAvailablePrograms.json", "static_alumni.json",
    "static_campus_life.json", "static_Career_Counseling_Center.json",
    "static_clubs.json", "static_depts.json", "static_facilities.json",
    "static_helpdesk.json", "static_payment_procedure.json",
    "static_Policy.json", "static_Programs.json", "static_Rules.json",
    "static_Sexual_harassment.json", "static_Tuition_fees.json",
]

# ── app state ──────────────────────────────────────────────────────────────────
class AppState:
    embedder = None
    gen_model = None
    gen_tokenizer = None
    documents: List[Dict] = []
    doc_embeddings: Optional[np.ndarray] = None
    faiss_index = None
    bm25 = None
    ready: bool = False
    error: str = ""

state = AppState()

# ── cache helpers ──────────────────────────────────────────────────────────────
def _cp(name: str) -> str:
    return os.path.join(CACHE_DIR, f"{CACHE_VERSION}_{name}")

def _cache_fresh(name: str) -> bool:
    p = _cp(name)
    return os.path.exists(p) and (time.time() - os.path.getmtime(p)) / 3600 < CACHE_TTL_H

def _save(name: str, obj: Any) -> None:
    try:
        with open(_cp(name), "wb") as f:
            pickle.dump(obj, f, protocol=5)
        logger.info("[cache] saved %s", name)
    except Exception as e:
        logger.warning("[cache] save %s failed: %s", name, e)

def _load(name: str) -> Optional[Any]:
    try:
        with open(_cp(name), "rb") as f:
            return pickle.load(f)
    except Exception:
        return None

def _save_faiss(idx) -> None:
    if FAISS_OK and idx is not None:
        try:
            faiss.write_index(idx, _cp("faiss.index"))
        except Exception as e:
            logger.warning("[cache] faiss save: %s", e)

def _load_faiss():
    p = _cp("faiss.index")
    if FAISS_OK and os.path.exists(p):
        try:
            return faiss.read_index(p)
        except Exception:
            pass
    return None

# ── language ───────────────────────────────────────────────────────────────────
_BANGLA_MIN, _BANGLA_MAX = 0x0980, 0x09FF

_BANGLISH_KW = {
    "ami", "tumi", "apni", "ki", "koto", "kothay", "ache", "hobe", "vorti", "bhorti",
    "fee", "britti", "bibhag", "thikana", "shesh", "tarikh", "kibhabe", "kivabe",
}

BANGLISH_MAP = {
    "vorti": "admission", "bhorti": "admission", "britti": "scholarship",
    "bibhag": "department", "thikana": "address", "shesh": "last",
    "tarikh": "date", "fee": "fee", "fees": "fee", "kibhabe": "how", "kivabe": "how",
}

NORMALIZE_RULES: List[Tuple[str, str]] = [
    (r"\bvc\b|উপাচার্য|ভিসি", "vice chancellor"),
    (r"ভর্তি|এডমিশন|\bvorti\b|\bbhorti\b", "admission"),
    (r"বৃত্তি|স্কলারশিপ|\bbritti\b", "scholarship"),
    (r"ফি|টিউশন|\bfee\b|\bfees\b|\btuition\b", "tuition fee"),
    (r"ঠিকানা|কোথায়|\bthikana\b|\bkothay\b", "address location"),
    (r"ডেডলাইন|শেষ তারিখ|\bdeadline\b", "deadline"),
    (r"বিভাগ|\bbibhag\b|\bdepartment\b|\bdept\b", "department"),
    (r"কোর্স|\bcourse\b|\bsubject\b", "course"),
]

def detect_language(text: str) -> str:
    if any(_BANGLA_MIN <= ord(ch) <= _BANGLA_MAX for ch in text):
        return "bangla"
    if set(re.findall(r"\w+", text.lower())) & _BANGLISH_KW:
        return "banglish"
    return "english"

def normalize_query(text: str) -> str:
    q = text.strip().lower()
    for tok in re.findall(r"\w+", q):
        if tok in BANGLISH_MAP:
            q = q.replace(tok, BANGLISH_MAP[tok])
    for pat, repl in NORMALIZE_RULES:
        q = re.sub(pat, repl, q, flags=re.IGNORECASE)
    q = re.sub(r"[^\w\s\u0980-\u09FF]", " ", q)
    return re.sub(r"\s+", " ", q).strip()

def tokenize(text: str) -> List[str]:
    return [t for t in re.findall(r"[\w\u0980-\u09FF]+", normalize_query(text)) if len(t) > 1]
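# Illustration (not exhaustive): a Banglish query such as "vorti fee koto" would
# normalize to roughly "admission tuition fee koto": BANGLISH_MAP rewrites
# "vorti" → "admission", NORMALIZE_RULES then expand "fee" to "tuition fee",
# and untranslated tokens like "koto" pass through unchanged.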

# ── fetching ───────────────────────────────────────────────────────────────────
async def fetch_json(url: str, headers: dict = None,
                     params: dict = None, timeout: int = 60) -> Optional[Any]:
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            r = await client.get(url, headers=headers or {}, params=params or {})
            if r.status_code == 200:
                return r.json()
            logger.warning("HTTP %s → %s", r.status_code, url)
    except Exception as e:
        logger.warning("Fetch %s: %s", url, e)
    return None

def _unwrap(data: Any) -> list:
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        for key in ("data", "results", "items"):
            if isinstance(data.get(key), list):
                return data[key]
        return [data]
    return []

async def load_api() -> list:
    # Wake the Render.com server (it sleeps on free tier)
    for attempt in range(3):
        if await fetch_json(f"{API_BASE}/grade-scale", API_HEADERS, timeout=60):
            logger.info("API server awake.")
            break
        logger.info("Wake attempt %d/3 …", attempt + 1)
        await asyncio.sleep(10)
    else:
        logger.warning("API server unreachable — skipping.")
        return []

    results = await asyncio.gather(
        *[fetch_json(f"{API_BASE}/{s}", API_HEADERS, params=p, timeout=60)
          for s, _, p in API_LIST_ENDPOINTS],
        return_exceptions=True,
    )
    docs = []
    for (suffix, label, _), data in zip(API_LIST_ENDPOINTS, results):
        if not data or isinstance(data, Exception):
            continue
        items = _unwrap(data)
        for item in items:
            docs.append({"raw": item, "source": f"api:{suffix}"})
        logger.info("%-28s → %d records", label, len(items))
    logger.info("API total: %d docs", len(docs))
    return docs

async def load_github() -> list:
    responses = await asyncio.gather(
        *[fetch_json(GITHUB_BASE + f, timeout=60) for f in GITHUB_FILES],
        return_exceptions=True,
    )
    docs = []
    for fname, data in zip(GITHUB_FILES, responses):
        if not data or isinstance(data, Exception):
            continue
        for item in (data if isinstance(data, list) else [data]):
            docs.append({"raw": item, "source": f"github:{fname}"})
    logger.info("GitHub total: %d docs", len(docs))
    return docs

# ── chunking ───────────────────────────────────────────────────────────────────
FIELD_PRIORITY = [
    "title", "name", "department", "program_name", "course_code", "course_title",
    "credit", "deadline", "date", "email", "phone", "amount", "fee", "description",
    "eligibility", "requirements", "address", "location",
]

def _flatten(obj: Any, path: str = "") -> List[Tuple[str, str]]:
    out = []
    if isinstance(obj, dict):
        for k, v in obj.items():
            p = f"{path} > {k}" if path else k
            if isinstance(v, (dict, list)):
                out.extend(_flatten(v, p))
            else:
                s = str(v).strip()
                if s and s.lower() not in ("null", "none", "", "[]", "{}"):
                    out.append((p, s))
    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            if isinstance(item, (dict, list)):
                out.extend(_flatten(item, f"{path}[{i}]"))
            else:
                s = str(item).strip()
                if s:
                    out.append((f"{path}[{i}]", s))
    return out

def make_chunk(raw: Any, source: str) -> str:
    if not isinstance(raw, (dict, list)):
        return f"source: {source} | {str(raw).strip()}"[:CHUNK_SIZE]
    pairs = _flatten(raw)

    def _pri(k: str) -> int:
        kl = k.lower()
        for i, f in enumerate(FIELD_PRIORITY):
            if f in kl:
                return i
        return 999

    pairs.sort(key=lambda kv: _pri(kv[0]))
    lines = [f"source: {source}"] + [f"{k}: {v}" for k, v in pairs[:20]]
    return " | ".join(lines)[:CHUNK_SIZE]

def chunk_documents(raw_docs: List[Dict]) -> List[Dict]:
    out = []
    for doc in raw_docs:
        text = make_chunk(doc["raw"], doc["source"])
        if text.strip():
            out.append({"content": text, "source": doc["source"]})
    return out

# ── indexing ───────────────────────────────────────────────────────────────────
def build_bm25(documents: List[Dict]) -> Optional[Any]:
    if not BM25_OK:
        return None
    toks = [tokenize(d["content"]) for d in documents]
    # NOTE: sparse_retrieve indexes state.documents by BM25 row, so this filter
    # assumes no chunk tokenizes to an empty list (every chunk starts with "source: …").
    toks = [t for t in toks if t]
    if not toks:
        return None
    obj = BM25Okapi(toks)
    logger.info("[BM25] %d docs indexed.", len(toks))
    return obj

def build_faiss(documents: List[Dict], embedder) -> Tuple[Optional[Any], Optional[np.ndarray]]:
    if not FAISS_OK or embedder is None:
        return None, None
    try:
        texts = [f"passage: {d['content']}" for d in documents]
        emb = embedder.encode(
            texts,
            normalize_embeddings=True,
            show_progress_bar=True,
            batch_size=32,  # smaller batch → lower peak RAM
        )
        emb = np.array(emb, dtype="float32")
        idx = faiss.IndexFlatIP(emb.shape[1])
        idx.add(emb)
        logger.info("[FAISS] %d vectors, dim=%d.", idx.ntotal, emb.shape[1])
        return idx, emb
    except Exception as e:
        logger.error("[FAISS] build failed: %s", e)
        return None, None
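# Note on prefixes: the intfloat/e5 model family expects asymmetric prefixes,
# "passage: " on indexed text (build_faiss above) and "query: " at search time
# (dense_retrieve below); dropping them tends to degrade retrieval quality.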

# ── retrieval ──────────────────────────────────────────────────────────────────
def dense_retrieve(query: str, k: int = TOP_K_RETRIEVE) -> List[Dict]:
    if state.faiss_index is None or state.embedder is None:
        return []
    try:
        q_vec = state.embedder.encode(
            [f"query: {query}"], normalize_embeddings=True
        )
        q_vec = np.array(q_vec, dtype="float32")
        k_a = min(k, state.faiss_index.ntotal)
        scores, ids = state.faiss_index.search(q_vec, k_a)
        return [
            {**state.documents[i], "dense_score": float(s)}
            for s, i in zip(scores[0], ids[0]) if i >= 0
        ]
    except Exception as e:
        logger.error("[dense] %s", e)
        return []

def sparse_retrieve(query: str, k: int = TOP_K_RETRIEVE) -> List[Dict]:
    if state.bm25 is None:
        return []
    try:
        toks = tokenize(query)
        if not toks:
            return []
        scores = np.array(state.bm25.get_scores(toks), dtype="float32")
        idxs = np.argsort(scores)[::-1][:k]
        return [
            {**state.documents[i], "sparse_score": float(scores[i])}
            for i in idxs if scores[i] > 0
        ]
    except Exception as e:
        logger.error("[sparse] %s", e)
        return []

def rrf_fuse(dense: List[Dict], sparse: List[Dict],
             w_dense: float = 0.6, w_sparse: float = 0.4,
             rrf_k: int = 60) -> List[Dict]:
    """Reciprocal-rank fusion of two ranked lists."""
    scores: Dict[str, float] = {}
    doc_map: Dict[str, Dict] = {}
    for rank, d in enumerate(dense):
        key = d["source"] + "||" + d["content"]
        scores[key] = scores.get(key, 0.0) + w_dense / (rrf_k + rank + 1)
        doc_map[key] = d
    for rank, d in enumerate(sparse):
        key = d["source"] + "||" + d["content"]
        scores[key] = scores.get(key, 0.0) + w_sparse / (rrf_k + rank + 1)
        doc_map[key] = d
    return [
        {**doc_map[k], "rrf_score": round(s, 6)}
        for k, s in sorted(scores.items(), key=lambda x: x[1], reverse=True)
    ]
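# Each list contributes weight / (rrf_k + rank + 1) per document, so with the
# defaults a rank-0 dense hit adds 0.6 / 61 ≈ 0.0098 and a rank-0 sparse hit adds
# 0.4 / 61 ≈ 0.0066; documents found by both retrievers accumulate both terms.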

async def retrieve(query: str, k: int = TOP_K_FINAL) -> List[Dict]:
    norm = normalize_query(query)
    dense = await asyncio.to_thread(dense_retrieve, norm, TOP_K_RETRIEVE)
    sparse = await asyncio.to_thread(sparse_retrieve, norm, TOP_K_RETRIEVE)
    fused = rrf_fuse(dense, sparse)
    return fused[:k]

# ── generation ─────────────────────────────────────────────────────────────────
SYSTEM_PROMPT = (
    "You are EWU Assistant for East West University, Dhaka, Bangladesh. "
    "Answer ONLY from the provided context. Be concise and clear. "
    "Use bullet points for lists. Never invent fees, dates, names, or codes. "
    "If the context does not contain the answer, say so honestly. "
    "You support English, Bangla (বাংলা), and Banglish."
)

def _lang_instruction(lang: str) -> str:
    if lang == "bangla":
        return "Respond entirely in Bengali (বাংলা script)."
    if lang == "banglish":
        return "Respond in Banglish — Bengali meaning written with English letters."
    return "Respond in clear English."

def _fallback(lang: str) -> str:
    url = "https://www.ewubd.edu/"
    if lang == "bangla":
        return f"দুঃখিত, উত্তর দিতে পারছি না। EWU ওয়েবসাইট দেখুন: {url}"
    if lang == "banglish":
        return f"Sorry, ekhon answer dite parcchi na. EWU website dekhun: {url}"
    return f"Sorry, I couldn't generate a response. Please visit: {url}"

# ── HF Inference API (optional fallback) ─────────────────────────────────────
async def _generate_hf_api(query: str, context: str, lang: str) -> str:
    if not HF_API_TOKEN:
        return ""
    prompt = (
        f"<|system|>\n{SYSTEM_PROMPT}\n{_lang_instruction(lang)}</s>\n"
        f"<|user|>\nContext:\n{context[:GEN_PROMPT_MAX_CHARS]}\n\nQuestion: {query}</s>\n"
        f"<|assistant|>\n"
    )
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": HF_MAX_NEW_TOKENS,
            "temperature": 0.3,
            "top_p": 0.9,
            "repetition_penalty": 1.1,
            "return_full_text": False,
        },
    }
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}", "Content-Type": "application/json"}
    async with httpx.AsyncClient(timeout=HF_API_TIMEOUT) as client:
        for attempt in range(2):
            try:
                r = await client.post(HF_API_URL, headers=headers, json=payload)
                if r.status_code == 200:
                    data = r.json()
                    if isinstance(data, list) and data:
                        return data[0].get("generated_text", "").strip()
                    return ""
                if r.status_code == 503 and attempt == 0:
                    wait = min(int(r.json().get("estimated_time", 20)), 30)
                    logger.info("[HF API] model loading, waiting %ds …", wait)
                    await asyncio.sleep(wait)
                    continue
                logger.error("[HF API] HTTP %s: %s", r.status_code, r.text[:200])
                return ""
            except httpx.TimeoutException:
                logger.error("[HF API] timeout.")
                return ""
            except Exception as e:
                logger.error("[HF API] %s", e)
                return ""
    return ""

# ── local TinyLlama generation ─────────────────────────────────────────────────
def _generate_local_sync(query: str, context: str, lang: str) -> str:
    import torch

    model = state.gen_model
    tokenizer = state.gen_tokenizer
    if model is None or tokenizer is None:
        return ""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT + "\n" + _lang_instruction(lang)},
        {"role": "user", "content": f"Context:\n{context[:GEN_PROMPT_MAX_CHARS]}\n\nQuestion: {query}"},
    ]
    try:
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=1024,  # keep prompt short to save RAM
        )
        prompt_len = inputs["input_ids"].shape[-1]
        logger.info("[local] prompt_tokens=%d max_new=%d", prompt_len, GEN_MAX_NEW_TOKENS)
        with torch.inference_mode():
            logger.info("[local] starting generate")
            output_ids = model.generate(
                **inputs,
                max_new_tokens=GEN_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            logger.info("[local] generate finished")
        new_ids = output_ids[0][prompt_len:]
        answer = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        logger.info("[local] generated %d tokens → %d chars.", len(new_ids), len(answer))
        return answer
    except Exception as e:
        logger.error("[local] generation failed: %s", e)
        return ""
    finally:
        # Free activation memory after each call
        gc.collect()

async def generate(query: str, context: str, lang: str) -> str:
    # 1. Try HF Inference API first (costs 0 local RAM)
    if HF_API_TOKEN:
        answer = await _generate_hf_api(query, context, lang)
        if answer:
            return answer
        logger.warning("[gen] HF API failed — falling back to local TinyLlama.")
    # 2. Local TinyLlama
    if state.gen_model is not None:
        try:
            answer = await asyncio.wait_for(
                asyncio.to_thread(_generate_local_sync, query, context, lang),
                timeout=GEN_TIMEOUT_S,
            )
            if answer:
                return answer
        except asyncio.TimeoutError:
            logger.error("[gen] TinyLlama timed out after %ds.", GEN_TIMEOUT_S)
        except Exception as e:
            logger.error("[gen] TinyLlama error: %s", e)
    return _fallback(lang)

# ── model loading ──────────────────────────────────────────────────────────────
def _load_embedder():
    if not ST_OK:
        logger.warning("[embed] sentence-transformers not installed.")
        return None
    try:
        logger.info("[embed] Loading %s …", EMBED_MODEL)
        model = SentenceTransformer(EMBED_MODEL, device="cpu")
        logger.info("[embed] ✓ Ready.")
        return model
    except Exception as e:
        logger.error("[embed] Failed: %s", e)
        return None

def _load_tinyllama():
    try:
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        logger.info("[gen] Loading TinyLlama (%s) …", HF_GEN_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(HF_GEN_MODEL, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            HF_GEN_MODEL,
            torch_dtype=torch.float32,      # float16 unstable on CPU
            low_cpu_mem_usage=True,         # stream weights → lower peak RAM
            trust_remote_code=True,
        )
        model.eval()
        # Limit CPU threads so we don't starve other workers
        torch.set_num_threads(2)
        logger.info("[gen] ✓ TinyLlama ready (float32, CPU).")
        return model, tokenizer
    except Exception as e:
        logger.error("[gen] TinyLlama load failed: %s", e)
        return None, None
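# Rough weight-memory check behind the docstring's ~4.4 GB figure: a 1.1B-parameter
# model at float32 needs about 1.1e9 × 4 bytes ≈ 4.4 GB for weights alone, before
# activations and tokenizer overhead.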

# ── boot ───────────────────────────────────────────────────────────────────────
async def _boot():
    try:
        logger.info("=== EWU RAG (TinyLlama + e5-small, lean pipeline) ===")

        # Step 1 — load embedding model (small, fast)
        state.embedder = await asyncio.to_thread(_load_embedder)

        # Step 2 — try loading indexes from cache
        cache_ready = (
            _cache_fresh("documents.pkl")
            and _cache_fresh("bm25.pkl")
            and (_cache_fresh("faiss.index") or not FAISS_OK)
        )
        if cache_ready:
            docs = _load("documents.pkl")
            bm25 = _load("bm25.pkl")
            if docs and bm25:
                state.documents = docs
                state.bm25 = bm25
                state.faiss_index, state.doc_embeddings = None, None
                idx = _load_faiss()
                if idx is not None:
                    state.faiss_index = idx
                    logger.info("[cache] FAISS loaded (%d vectors).", idx.ntotal)
                logger.info("[cache] %d chunks loaded from disk.", len(docs))
                # Load TinyLlama AFTER data is ready (so we see OOM early if needed)
                state.gen_model, state.gen_tokenizer = await asyncio.to_thread(_load_tinyllama)
                state.ready = True
                return

        # Step 3 — fetch fresh data
        logger.info("Fetching data (API + GitHub) …")
        api_docs, gh_docs = await asyncio.gather(load_api(), load_github())
        raw_docs = api_docs + gh_docs
        logger.info("Total raw docs: %d", len(raw_docs))
        if not raw_docs:
            logger.warning("No documents fetched — starting empty.")
            state.ready = True
            return

        # Step 4 — chunk
        logger.info("Chunking …")
        state.documents = await asyncio.to_thread(chunk_documents, raw_docs)
        logger.info("Total chunks: %d", len(state.documents))

        # Step 5 — BM25 index
        state.bm25 = await asyncio.to_thread(build_bm25, state.documents)
        _save("documents.pkl", state.documents)
        if state.bm25:
            _save("bm25.pkl", state.bm25)

        # Step 6 — FAISS index (only if embedder loaded successfully)
        if state.embedder:
            logger.info("Building FAISS index …")
            idx, emb = await asyncio.to_thread(
                build_faiss, state.documents, state.embedder
            )
            state.faiss_index = idx
            state.doc_embeddings = emb
            if idx is not None:
                _save_faiss(idx)

        # Step 7 — load TinyLlama last (heaviest, ~4.4 GB)
        state.gen_model, state.gen_tokenizer = await asyncio.to_thread(_load_tinyllama)
        state.ready = True
        logger.info("✓ EWU RAG fully ready.")
    except Exception as e:
        state.error = str(e)
        state.ready = False
        logger.exception("Boot failed")

@asynccontextmanager
async def lifespan(app: FastAPI):
    task = asyncio.create_task(_boot())
    try:
        yield
    finally:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

# ── FastAPI ────────────────────────────────────────────────────────────────────
app = FastAPI(title="EWU RAG (TinyLlama + e5-small)", lifespan=lifespan)

class Query(BaseModel):
    query: str
    top_k: int = TOP_K_FINAL

@app.get("/")
async def root():
    return {
        "service": "EWU RAG Server",
        "status": "ready" if state.ready else ("error" if state.error else "loading"),
        "docs_loaded": len(state.documents),
        "faiss": state.faiss_index is not None,
        "bm25": state.bm25 is not None,
        "gen_model": HF_GEN_MODEL,
        "gen_loaded": state.gen_model is not None,
        "endpoints": {
            "POST /rag": "Submit a question",
            "GET /health": "Detailed health check",
        },
    }

@app.post("/rag")
async def rag_endpoint(q: Query):
    if not state.ready:
        raise HTTPException(503, detail=state.error or "Still initializing — retry shortly.")
    raw_query = q.query.strip()
    if not raw_query:
        raise HTTPException(400, detail="Query must not be empty.")
    lang = detect_language(raw_query)
    top_k = max(1, min(q.top_k, 8))
    results = await retrieve(raw_query, k=top_k)
    if not results:
        return {
            "answer": _fallback(lang),
            "lang": lang,
            "sources": [],
        }
    context = "\n\n---\n\n".join(r["content"] for r in results)
    answer = await generate(raw_query, context, lang)
    return {
        "answer": answer,
        "lang": lang,
        "sources": [
            {
                "source": r["source"],
                "rrf_score": round(r.get("rrf_score", 0), 6),
                "dense_score": round(r.get("dense_score", 0), 4) if "dense_score" in r else None,
                "sparse_score": round(r.get("sparse_score", 0), 4) if "sparse_score" in r else None,
            }
            for r in results
        ],
    }
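# A successful /rag response has this shape (values illustrative):
#   {"answer": "...", "lang": "english",
#    "sources": [{"source": "api:tuition-fees", "rrf_score": 0.009836,
#                 "dense_score": 0.8123, "sparse_score": None}]}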

@app.get("/health")
async def health():
    return {
        "status": "ready" if state.ready else ("error" if state.error else "loading"),
        "docs": len(state.documents),
        "embed_model": EMBED_MODEL,
        "gen_model": HF_GEN_MODEL,
        "gen_loaded": state.gen_model is not None,
        "hf_api_token": bool(HF_API_TOKEN),
        "faiss": state.faiss_index is not None,
        "bm25": state.bm25 is not None,
        "max_new_tokens": GEN_MAX_NEW_TOKENS,
        "gen_timeout_s": GEN_TIMEOUT_S,
        "error": state.error or None,
    }

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)