Spaces:
Runtime error
Runtime error
| """ | |
| EnggSS RAG ChatBot — HuggingFace Space (serving only) | |
| ====================================================== | |
| Loads a pre-built PRIVATE HuggingFace Dataset (embeddings already computed | |
| by preprocessing/create_dataset.py) and serves a conversational Q&A interface. | |
| No PDF loading. No chunking. No embedding of documents at runtime. | |
| Only the user query is embedded on each call (~20 ms). | |
| Required Space Secrets (Settings → Variables and Secrets): | |
| HF_TOKEN — HuggingFace token with READ access to the dataset | |
| HF_DATASET_REPO — e.g. your-org/enggss-rag-dataset | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from collections import Counter | |
| from typing import Any | |
| import gradio as gr | |
| import numpy as np | |
| from datasets import load_dataset | |
| from dotenv import load_dotenv | |
| from langchain_core.messages import AIMessage, HumanMessage | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from sentence_transformers import SentenceTransformer | |
# Pull variables from a local .env file (if any) into the process environment
# before any of them are read below.
load_dotenv()

# ─── Configuration ────────────────────────────────────────────────────────────
# Secrets / deployment settings; both fall back to the empty string when unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")
DATASET_REPO = os.getenv("HF_DATASET_REPO", "")

# Model identifiers on the HuggingFace Hub.
LLM_REPO = "Qwen/Qwen2.5-7B-Instruct"
EMBED_MODEL = "BAAI/bge-large-en-v1.5"

# Instruction text prepended to every query before it is embedded.
QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

# Retrieval knobs: FETCH_K candidates are ranked, TOP_K of them are returned.
TOP_K = 3
FETCH_K = 15
LAMBDA_MMR = 0.7  # 1.0 = pure relevance · 0.0 = pure diversity

# Process-wide logging setup and the module logger used throughout this file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
log = logging.getLogger(__name__)
| # ═════════════════════════════════════════════════════════════════════════════ | |
| # 1 ─ Embedding model (local, cached by sentence-transformers after 1st run) | |
| # ═════════════════════════════════════════════════════════════════════════════ | |
# Load the query-embedding model once at import time.  A failure is recorded
# in EMBED_ERROR instead of aborting start-up, so the app can still come up
# and report the problem when a query is attempted.
log.info("Loading embedding model: %s", EMBED_MODEL)
_embed_model = None
EMBED_ERROR = None
try:
    _embed_model = SentenceTransformer(EMBED_MODEL)
except Exception as _exc:
    EMBED_ERROR = str(_exc)
    log.error("Embedding model failed: %s", _exc)
def embed_query(text: str) -> np.ndarray:
    """
    Turn one query string into a float32 embedding vector.

    QUERY_PREFIX is prepended to *text* and the model is asked to return a
    normalised embedding (``normalize_embeddings=True``).

    Args:
        text: The raw user query.

    Returns:
        The embedding cast to ``np.float32``.

    Raises:
        RuntimeError: If the embedding model failed to load at start-up.
    """
    model = _embed_model
    if model is None:
        raise RuntimeError(f"Embedding model unavailable: {EMBED_ERROR}")

    prefixed_query = QUERY_PREFIX + text
    embedding = model.encode(
        prefixed_query,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return embedding.astype(np.float32)
| # ═════════════════════════════════════════════════════════════════════════════ | |
| # 2 ─ LLM (HF Inference API — no local model download) | |
| # ═════════════════════════════════════════════════════════════════════════════ | |
# Create the remote LLM client.  As with the embedding model, a failure is
# stored in LLM_ERROR rather than raised so the UI can still start.
_llm = None
LLM_ERROR = None
try:
    _llm = HuggingFaceEndpoint(
        repo_id=LLM_REPO,
        temperature=0.01,
        max_new_tokens=1024,
        huggingfacehub_api_token=HF_TOKEN,
    )
except Exception as _exc:
    LLM_ERROR = str(_exc)
    log.error("LLM init failed: %s", _exc)

# Prompt layout: system instructions, then prior turns, then the new question
# together with the retrieved context.
_SYSTEM_PROMPT = (
    "You are a technical expert on engineering specifications and IS/IEEE/BIS standards. "
    "Answer ONLY from the provided context. Be precise and point-wise. "
    "If the context does not contain the answer, say so clearly."
)
_HUMAN_TEMPLATE = (
    "Context from technical documents:\n{context}\n\n"
    "Question: {question}"
)
_qa_prompt = ChatPromptTemplate.from_messages([
    ("system", _SYSTEM_PROMPT),
    MessagesPlaceholder("chat_history"),
    ("human", _HUMAN_TEMPLATE),
])

# NOTE(review): the chat prompt is piped into a text-completion endpoint, so
# the messages are presumably flattened to plain text before being sent —
# confirm whether a chat wrapper is wanted for this model.
if _llm:
    _answer_chain = _qa_prompt | _llm | StrOutputParser()
else:
    # No LLM available: qa_fn reports LLM_ERROR instead of generating.
    _answer_chain = None
| # ═════════════════════════════════════════════════════════════════════════════ | |
| # 3 ─ Load dataset from HF Hub into a NumPy matrix | |
| # ═════════════════════════════════════════════════════════════════════════════ | |
# Populated together by load_knowledge_base(); both stay None until a load
# has fully succeeded.
EMB_MATRIX: np.ndarray | None = None
METADATA: list[dict] | None = None


def load_knowledge_base() -> tuple[str, str]:
    """
    Download the private HF Dataset, build the NumPy embedding matrix, and
    populate the module-level EMB_MATRIX / METADATA.

    Every failure (missing configuration, download error, empty dataset,
    missing columns, malformed embeddings) is reported through the return
    value rather than raised, so calling this at import time cannot crash
    the app.  EMB_MATRIX and METADATA are only assigned once both have been
    built successfully, so they are never left half-populated.

    Returns:
        (status_str, detail_str)  e.g. ("✅ Ready", "8,420 chunks · 35 documents")
    """
    global EMB_MATRIX, METADATA

    if not DATASET_REPO:
        return "❌ Not configured", "Set the HF_DATASET_REPO secret in Space Settings."
    if not HF_TOKEN:
        return "❌ Not configured", "Set the HF_TOKEN secret in Space Settings."

    log.info("Loading dataset from HF Hub: %s", DATASET_REPO)
    try:
        ds = load_dataset(DATASET_REPO, token=HF_TOKEN, split="train")
    except Exception as exc:
        return "❌ Load failed", str(exc)

    if len(ds) == 0:
        return "❌ Empty dataset", "Dataset has no records. Run create_dataset.py first."

    # Validate the schema up front so a wrong dataset yields a readable
    # status instead of a KeyError during start-up.
    columns = set(ds.column_names)
    missing = [
        name for name in ("embedding", "text", "source", "page")
        if name not in columns
    ]
    if missing:
        return "❌ Invalid dataset", f"Missing column(s): {', '.join(missing)}"

    # Build the row-normalised float32 matrix (one row per chunk).
    try:
        mat = np.array(ds["embedding"], dtype=np.float32)
    except (TypeError, ValueError) as exc:
        return "❌ Invalid dataset", f"Could not build embedding matrix: {exc}"
    if mat.ndim != 2:
        return (
            "❌ Invalid dataset",
            f"Embeddings must form a 2-D matrix, got shape {mat.shape}.",
        )
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    # Zero-length rows are divided by 1 so they stay all-zero instead of NaN.
    mat = mat / np.where(norms == 0, 1.0, norms)

    # Read metadata column-wise; iterating rows would decode every embedding
    # a second time.  "context" is optional and defaults to "".
    if "context" in columns:
        context_col = ds["context"]
    else:
        context_col = [""] * len(ds)
    metadata = [
        {"text": text, "source": source, "page": page, "context": context}
        for text, source, page, context in zip(
            ds["text"], ds["source"], ds["page"], context_col
        )
    ]

    # Publish both structures together, only after everything succeeded.
    EMB_MATRIX, METADATA = mat, metadata

    n_docs = len({m["source"] for m in metadata})
    detail = f"{len(metadata):,} chunks · {n_docs} documents"
    log.info("Dataset ready: %s", detail)
    return "✅ Ready", detail


# Load at startup
_status, _detail = load_knowledge_base()
log.info("Startup — %s: %s", _status, _detail)
| # ═════════════════════════════════════════════════════════════════════════════ | |
| # 4 ─ Retrieval (cosine similarity + MMR, pure NumPy) | |
| # ═════════════════════════════════════════════════════════════════════════════ | |
def _mmr(
    query_emb: np.ndarray,
    scores: np.ndarray,
    top_k: int,
    fetch_k: int,
    lambda_mult: float,
) -> list[tuple[int, float]]:
    """
    Greedy Maximum Marginal Relevance selection over the best-scoring rows.

    The *fetch_k* highest-scoring row indices form the candidate pool.  The
    top candidate is taken first; each later pick maximises

        lambda_mult * score - (1 - lambda_mult) * max_similarity_to_picked

    where similarities are dot products between rows of EMB_MATRIX.

    Args:
        query_emb: Query embedding (not used by the computation itself; the
            relevance term comes from *scores*).
        scores: Relevance score of every row of EMB_MATRIX for the query.
        top_k: Maximum number of rows to return.
        fetch_k: Size of the candidate pool.
        lambda_mult: Weight of relevance versus redundancy.

    Returns:
        ``(row_index, score)`` pairs in selection order.
    """
    descending = np.argsort(scores)[::-1]
    pool = list(descending[:fetch_k])
    picked: list[int] = []
    redundancy_weight = 1 - lambda_mult

    while pool and len(picked) < top_k:
        if picked:
            picked_vecs = EMB_MATRIX[picked]                 # (n_picked, D)
            objective = []
            for cand in pool:
                # Similarity to the closest already-picked chunk.
                overlap = float(np.max(picked_vecs @ EMB_MATRIX[cand]))
                objective.append(
                    lambda_mult * scores[cand] - redundancy_weight * overlap
                )
            position = int(np.argmax(objective))
        else:
            # Nothing picked yet: the most relevant candidate goes first.
            position = 0
        picked.append(pool.pop(position))

    return [(row, float(scores[row])) for row in picked]
def retrieve(question: str) -> list[dict[str, Any]]:
    """
    Return the MMR-selected chunks for *question*.

    The question is embedded, scored against every row of EMB_MATRIX with a
    matrix-vector product, and passed through ``_mmr`` using the module's
    TOP_K / FETCH_K / LAMBDA_MMR settings.

    Returns:
        One dict per selected chunk: a copy of its METADATA entry plus a
        ``"score"`` key holding its similarity score.
    """
    query_vec = embed_query(question)
    similarity = EMB_MATRIX @ query_vec   # dot product = cosine (unit vecs)

    results: list[dict[str, Any]] = []
    for row, score in _mmr(query_vec, similarity, TOP_K, FETCH_K, LAMBDA_MMR):
        record = dict(METADATA[row])      # copy so METADATA is not mutated
        record["score"] = score
        results.append(record)
    return results
| # ═════════════════════════════════════════════════════════════════════════════ | |
| # 5 ─ Q&A function (wired to gr.ChatInterface) | |
| # ═════════════════════════════════════════════════════════════════════════════ | |
# Text placed between the retrieved-contexts section and the answer in every
# reply.  Also used to cut that section back out of earlier replies.
_ANSWER_SEPARATOR = "\n\n---\n\n## Answer\n\n"


def _render_contexts(contexts: list[dict]) -> str:
    """Format retrieved chunks as Markdown, truncating each text to 600 chars."""
    return "\n\n".join(
        f"**[{n}] {c['source']} — Page {c['page']} "
        f"· similarity {c['score']:.3f}**\n"
        f"> *{c['context']}*\n\n"
        f"{c['text'][:600]}{'…' if len(c['text']) > 600 else ''}"
        for n, c in enumerate(contexts, start=1)
    )


def _answer_only(content):
    """Return just the answer part of an earlier reply produced by qa_fn."""
    if not isinstance(content, str):
        # Non-text content is passed through untouched.
        return content
    _, separator, answer = content.partition(_ANSWER_SEPARATOR)
    return answer if separator else content


def _history_to_messages(history: list[dict]) -> list:
    """Convert role/content history dicts into LangChain message objects."""
    messages = []
    for turn in history or []:
        if turn["role"] == "user":
            messages.append(HumanMessage(content=turn["content"]))
        else:
            # Earlier replies embed the chunks retrieved for *that* turn;
            # replaying them would bloat the prompt with stale context.
            messages.append(AIMessage(content=_answer_only(turn["content"])))
    return messages


def qa_fn(question: str, history: list[dict]) -> str:
    """
    Answer *question* from the knowledge base (wired to gr.ChatInterface).

    1. Retrieve top-k contexts via MMR.
    2. Generate an answer with the LLM using the contexts + chat history.
       Only the answer portion of earlier assistant replies is replayed.
    3. Return a formatted Markdown string with contexts + answer.

    Args:
        question: The user's new message.
        history: Previous turns as dicts with ``"role"`` and ``"content"``.

    Returns:
        Markdown text.  Problems (dataset not loaded, empty question,
        retrieval or LLM failure) are reported in the returned text rather
        than raised.
    """
    # Guard: dataset not loaded
    if EMB_MATRIX is None:
        return (
            f"⚠️ **Dataset not loaded** ({_status}).\n\n"
            f"{_detail}\n\n"
            "Run `preprocessing/create_dataset.py` locally to build the dataset, "
            "then restart this Space."
        )
    if not question.strip():
        return "Please enter a question."

    # ── Retrieve ─────────────────────────────────────────────────────────────
    try:
        contexts = retrieve(question)
    except Exception as exc:
        log.exception("Retrieval error: %s", exc)
        return f"❌ Retrieval failed: {exc}"
    ctx_display = _render_contexts(contexts)

    # ── Generate ─────────────────────────────────────────────────────────────
    if _answer_chain is None:
        answer = f"⚠️ LLM unavailable: {LLM_ERROR}"
    else:
        context_str = "\n\n---\n\n".join(
            f"[{n}] Source: {c['source']} | Page: {c['page']}\n{c['text']}"
            for n, c in enumerate(contexts, start=1)
        )
        try:
            answer = _answer_chain.invoke({
                "context": context_str,
                "question": question,
                "chat_history": _history_to_messages(history),
            })
        except Exception as exc:
            log.exception("LLM error: %s", exc)
            answer = f"❌ LLM error: {exc}"

    return f"## Retrieved Contexts\n\n{ctx_display}{_ANSWER_SEPARATOR}{answer}"
| # ═════════════════════════════════════════════════════════════════════════════ | |
| # 6 ─ Analytics | |
| # ═════════════════════════════════════════════════════════════════════════════ | |
def get_analytics() -> tuple[int, int, float, list[list]]:
    """
    Summarise the loaded knowledge base for the Analytics tab.

    Returns:
        ``(total_chunks, n_documents, avg_chunks_per_document, table)`` where
        *table* is a list of ``[source, chunk_count]`` rows sorted by source
        and the average is rounded to one decimal.  When no metadata is
        loaded the result is ``(0, 0, 0.0, [])``.
    """
    if METADATA is None:
        return 0, 0, 0.0, []

    # Tally chunks per source document.
    per_source = Counter()
    for entry in METADATA:
        per_source[entry["source"]] += 1

    total = len(METADATA)
    n_docs = len(per_source)
    if n_docs:
        avg = round(total / n_docs, 1)
    else:
        avg = 0.0
    table = [[source, count] for source, count in sorted(per_source.items())]
    return total, n_docs, avg, table
| # ═════════════════════════════════════════════════════════════════════════════ | |
| # 7 ─ Gradio UI | |
| # ═════════════════════════════════════════════════════════════════════════════ | |
# Sample questions offered as one-click examples in the chat tab.
EXAMPLES = [
    "What should be the GIB height outside the GIS hall?",
    "STATCOM station ratings and specifications",
    "Specifications of XLPE power cables",
    "Specification for Ethernet switches in SAS",
    "Type tests for HV switchgear as per IS standards",
    "Technical requirements for 765 kV class transformer",
]

with gr.Blocks(title="EnggSS RAG ChatBot", theme=gr.themes.Base()) as demo:
    # Header: shows the dataset status captured at start-up.
    gr.Markdown(
        "# ⚡ EnggSS RAG ChatBot\n"
        "Conversational Q&A over **Model Technical Specifications** & "
        "**IS / IEEE / BIS Standards**\n\n"
        f"> **Dataset:** {_status} — {_detail} | "
        f"**Embedding:** `{EMBED_MODEL}` | "
        f"**LLM:** `{LLM_REPO}`"
    )
    with gr.Tabs():
        # ── Tab 1 : Q&A ───────────────────────────────────────────────────────
        with gr.Tab("💬 Q&A"):
            gr.ChatInterface(
                fn=qa_fn,
                type="messages",
                examples=EXAMPLES,
                concurrency_limit=None,
                # fill_height removed in gradio 5.x
            )
        # ── Tab 2 : Analytics ─────────────────────────────────────────────────
        with gr.Tab("📊 Analytics"):
            gr.Markdown("### Knowledge Base Statistics")
            refresh_btn = gr.Button("🔄 Refresh", size="sm")
            with gr.Row():
                # Gradio has no `Metric` component (referencing it raises
                # AttributeError while the UI is being built); read-only
                # Number fields display the same three values.
                m_chunks = gr.Number(
                    label="Total Chunks", value=0, precision=0, interactive=False
                )
                m_docs = gr.Number(
                    label="Documents Processed", value=0, precision=0, interactive=False
                )
                m_avg = gr.Number(
                    label="Avg Chunks / Doc", value=0.0, precision=1, interactive=False
                )
            tbl = gr.Dataframe(
                headers=["Document", "Chunks"],
                datatype=["str", "number"],
                interactive=False,
                label="Chunks per Document",
            )

            def _refresh():
                """Return the current statistics in output-component order."""
                return get_analytics()

            # Fill the statistics on button press and on initial page load.
            refresh_btn.click(fn=_refresh, outputs=[m_chunks, m_docs, m_avg, tbl])
            demo.load(fn=_refresh, outputs=[m_chunks, m_docs, m_avg, tbl])
            # Two trailing spaces before "\n" make a Markdown hard line break,
            # so each setting is rendered on its own line.
            gr.Markdown(
                f"**Retrieval:** MMR · k={TOP_K} · fetch_k={FETCH_K} · λ={LAMBDA_MMR}  \n"
                f"**Embedding model:** `{EMBED_MODEL}` (1024-dim, L2-normalised)  \n"
                f"**LLM:** `{LLM_REPO}` via HF Inference API"
            )

if __name__ == "__main__":
    demo.launch(debug=True)