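"""Nyxion Labs — grounded RAG demo on SQuAD v2.

Builds a FAISS index over chunked SQuAD v2 contexts, answers questions with
inline citations via the OpenAI API, and optionally serves a FastAPI endpoint.
Run with --build-index, --q "question", or --serve.
"""
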
# --- quiet TensorFlow / tf-keras noise (must run before TensorFlow is imported) -----
import os, warnings

# Hide TF C++ INFO/WARNING/ERROR logs (level "3" keeps only FATAL)
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
# Stop the oneDNN notice you're seeing
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
# Silence specific deprecation chatter from tf_keras
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    message=r".*tf\.losses\.sparse_softmax_cross_entropy.*",
)
# Blanket-ignore DeprecationWarnings/UserWarnings originating from tensorflow / tf_keras modules
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"^(tensorflow|tf_keras)\b")
warnings.filterwarnings("ignore", category=UserWarning, module=r"^(tensorflow|tf_keras)\b")
# -----------------------------------------------------------------------------

# (now your existing imports follow)
from contextlib import asynccontextmanager
import argparse
import sys
import time
import pickle
import pathlib
from typing import List, Tuple, Dict, Any

import numpy as np

# --- DEV-ONLY TOKENS (you asked to avoid .env) --------------------------------
OPENAI_API_KEY = "sk-your-dev-key-here"  # <<< put your dev key here; never commit a real key
OPENAI_MODEL = "gpt-4o-mini"  # solid + cost-effective for demo

# --- Heavy deps ----------------------------------------------------------------
try:
    import faiss  # type: ignore
except Exception:
    print("FAISS is required. pip install faiss-cpu", file=sys.stderr)
    raise
try:
    from datasets import load_dataset  # type: ignore
except Exception:
    print("HuggingFace datasets is required. pip install datasets", file=sys.stderr)
    raise
try:
    from sentence_transformers import SentenceTransformer  # type: ignore
except Exception:
    print("sentence-transformers is required. pip install sentence-transformers", file=sys.stderr)
    raise
try:
    from openai import OpenAI  # type: ignore
except Exception:
    print("openai>=1.0 is required. pip install openai", file=sys.stderr)
    raise

# --- Optional API mode ---------------------------------------------------------
try:
    from fastapi import FastAPI
    from pydantic import BaseModel
    import uvicorn
    FASTAPI_AVAILABLE = True
except Exception:
    FASTAPI_AVAILABLE = False

# --- Paths ----------------------------------------------------------------------
ROOT = pathlib.Path(__file__).resolve().parent
ART = ROOT / "artifacts"
ART.mkdir(exist_ok=True)
INDEX_FILE = ART / "squad_v2.faiss"
META_FILE = ART / "squad_v2_meta.pkl"

# --- Chunking params -------------------------------------------------------------
# SQuAD contexts can be long. We chunk for better retrieval quality.
CHUNK_SIZE = 500     # characters per chunk
CHUNK_OVERLAP = 100  # overlap to preserve context across boundaries

# --- Minimal logging --------------------------------------------------------------
def log(msg: str):
    print(f"[RAG] {msg}", flush=True)

# --- Data prep --------------------------------------------------------------------
def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    if not text:
        return []
    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + chunk_size)
        chunks.append(text[start:end])
        if end == len(text):
            break
        start = end - overlap
        if start < 0:
            start = 0
    return chunks
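
# Worked example: chunk_text("abcdefghij", chunk_size=4, overlap=1)
# -> ["abcd", "defg", "ghij"]  # each window starts `overlap` chars before the previous end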

def load_and_prepare_squad() -> List[Dict[str, Any]]:
    """
    Returns a list of dicts:
      {
        'id': str,            # synthetic id per chunk
        'title': str,
        'context': str,       # chunk text
        'source_meta': {'split': 'train|validation', 'orig_example_id': ..., 'title': ...}
      }
    """
    log("Downloading SQuAD v2 via datasets ...")
    ds = load_dataset("rajpurkar/squad_v2")
    prepared: List[Dict[str, Any]] = []
    for split in ["train", "validation"]:
        rows = ds[split]
        log(f"Processing split: {split} (n={len(rows)})")
        for i, ex in enumerate(rows):
            title = ex.get("title") or ""
            context = ex.get("context") or ""
            ex_id = ex.get("id") or f"{split}-{i}"
            chunks = chunk_text(context, CHUNK_SIZE, CHUNK_OVERLAP)
            for j, chunk in enumerate(chunks):
                prepared.append({
                    "id": f"{ex_id}::chunk{j}",
                    "title": title,
                    "context": chunk.strip(),
                    "source_meta": {"split": split, "orig_example_id": ex_id, "title": title},
                })
    log(f"Prepared {len(prepared)} chunks total.")
    return prepared
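
# Note: SQuAD stores one row per *question*, and many questions share the same
# context, so this pass produces duplicate chunks. Deduplicating by context
# before chunking would shrink the index considerably; left as-is for simplicity.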

# --- Embeddings & FAISS -----------------------------------------------------------
def build_index(prepared: List[Dict[str, Any]], model_name: str = "all-MiniLM-L6-v2"):
    log(f"Loading embedding model: {model_name}")
    st_model = SentenceTransformer(model_name)
    texts = [r["context"] for r in prepared]
    log("Encoding chunks -> embeddings (this can take a while) ...")
    embs = st_model.encode(texts, show_progress_bar=True, convert_to_numpy=True, batch_size=256)
    embs = embs.astype("float32")
    dim = embs.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embs)
    log(f"Built FAISS index with {index.ntotal} vectors. Saving to disk ...")
    faiss.write_index(index, str(INDEX_FILE))
    meta = {
        "records": prepared,
        "embedding_model": model_name,
        "dim": dim,
        "created_at": time.time(),
        "chunk_size": CHUNK_SIZE,
        "chunk_overlap": CHUNK_OVERLAP,
    }
    with open(META_FILE, "wb") as f:
        pickle.dump(meta, f)
    log("Index + metadata saved.")
    return index, meta, st_model
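
# IndexFlatL2 is exact (brute-force) nearest-neighbor search over squared-L2
# distance. A common alternative (a sketch, not what this script does) is
# cosine similarity via normalized vectors and an inner-product index:
#   faiss.normalize_L2(embs)
#   index = faiss.IndexFlatIP(dim)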

def load_index():
    if not INDEX_FILE.exists() or not META_FILE.exists():
        raise FileNotFoundError("Index or metadata not found. Run with --build-index first.")
    index = faiss.read_index(str(INDEX_FILE))
    with open(META_FILE, "rb") as f:
        meta = pickle.load(f)
    # lazily load the embedding model recorded in the metadata
    st_model = SentenceTransformer(meta.get("embedding_model", "all-MiniLM-L6-v2"))
    return index, meta, st_model

# --- RAG core ----------------------------------------------------------------------
class GroundedQA:
    def __init__(self, index, records: List[Dict[str, Any]], embed_model, openai_api_key: str):
        self.index = index
        self.records = records
        self.embed_model = embed_model
        self.client = OpenAI(api_key=openai_api_key)

    def retrieve(self, question: str, k: int = 5) -> List[Tuple[Dict[str, Any], float]]:
        q_emb = self.embed_model.encode([question], convert_to_numpy=True).astype("float32")
        distances, indices = self.index.search(q_emb, k)
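        # search() returns two (n_queries, k) arrays: squared-L2 distances
        # (smaller = closer) and row indices into self.records.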
        out = []
        for rank, idx in enumerate(indices[0]):
            rec = self.records[idx]
            dist = float(distances[0][rank])
            out.append((rec, dist))
        return out

    def _build_prompt(self, question: str, retrieved: List[Tuple[Dict[str, Any], float]]) -> str:
        context_blocks = []
        for i, (rec, _) in enumerate(retrieved, start=1):
            title = rec.get("title") or "Untitled"
            ctx = rec["context"]
            context_blocks.append(f"[Source {i} | {title}] {ctx}")
        context_text = "\n\n".join(context_blocks)
        prompt = (
            "You are a precise, grounded Q&A assistant. "
            "Answer ONLY using the provided context. If the answer is not in the context, say you don't know.\n"
            "Add citations like [Source X] inline where relevant.\n\n"
            f"Context:\n{context_text}\n\n"
            f"Question: {question}\n\n"
            "Answer (with citations):"
        )
        return prompt

    def answer_with_citations(self, question: str, k: int = 5) -> Dict[str, Any]:
        retrieved = self.retrieve(question, k=k)
        prompt = self._build_prompt(question, retrieved)
        resp = self.client.chat.completions.create(
            model=OPENAI_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=400,
        )
        answer = (resp.choices[0].message.content or "").strip()
        return {
            "answer": answer,
            "sources": [
                {
                    "rank": i + 1,
                    "distance": d,
                    "id": rec["id"],
                    "title": rec.get("title"),
                    "split": rec["source_meta"]["split"],
                    "excerpt": rec["context"][:240] + ("..." if len(rec["context"]) > 240 else ""),
                }
                for i, (rec, d) in enumerate(retrieved)
            ],
        }

# --- Simple confidence heuristic -----------------------------------------------------
def should_review(rag_result: Dict[str, Any], threshold: float = 1.2) -> bool:
    # Lower L2 distance -> closer match. We flag for human review if the average distance is high.
    if not rag_result.get("sources"):
        return True
    avg = float(np.mean([s["distance"] for s in rag_result["sources"]]))
    return avg > threshold
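
# If the embedding model L2-normalizes its outputs (all-MiniLM-L6-v2 does), the
# squared-L2 distance equals 2 - 2*cosine, so values lie in [0, 4] and the
# default threshold of 1.2 corresponds to cosine similarity ~0.4. Tune per corpus.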

# --- CLI ------------------------------------------------------------------------------
def cli_build_index():
    prepared = load_and_prepare_squad()
    build_index(prepared)

def cli_query(question: str, k: int = 5):
    index, meta, st_model = load_index()
    qa = GroundedQA(index, meta["records"], st_model, OPENAI_API_KEY)
    result = qa.answer_with_citations(question, k=k)
    print("\n=== Answer ===")
    print(result["answer"])
    print("\n=== Sources ===")
    for s in result["sources"]:
        print(f"[{s['rank']}] ({s['distance']:.4f}) {s['title']} :: {s['id']}")
        print(f"    {s['excerpt']}")
    print("\nReview flag:", "YES" if should_review(result) else "NO")

# --- API (optional) --------------------------------------------------------------------
if FASTAPI_AVAILABLE:
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Startup: warm the RAG pipeline once
        index, meta, st_model = load_index()
        app.state.qa = GroundedQA(index, meta["records"], st_model, OPENAI_API_KEY)
        yield
        # Teardown (optional): nothing to clean up

    app = FastAPI(title="Nyxion Labs RAG — SQuAD v2", lifespan=lifespan)

    class AskBody(BaseModel):
        question: str
        k: int = 5

    @app.post("/ask")
    def query_api(body: AskBody):
        qa: GroundedQA = app.state.qa
        res = qa.answer_with_citations(body.question, k=body.k)
        # Keep types JSON-safe + attach a quick review flag
        res["needs_review"] = bool(should_review(res))
        return res
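
# Example request once the server is running (uses the /ask route defined above;
# the sample question is illustrative):
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Who wrote the context you indexed?", "k": 3}'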

# --- main ---------------------------------------------------------------------------
def parse_args():
    p = argparse.ArgumentParser(description="Nyxion Labs — RAG on SQuAD v2")
    p.add_argument("--build-index", action="store_true", help="Download SQuAD and build FAISS index")
    p.add_argument("--q", "--question", dest="question", type=str, help="Ask a question")
    p.add_argument("-k", type=int, default=5, help="Top-k contexts to retrieve")
    p.add_argument("--serve", action="store_true", help="Run FastAPI server on :8000")
    return p.parse_args()

def main():
    args = parse_args()
    if args.build_index:
        cli_build_index()
        return
    if args.serve:
        if not FASTAPI_AVAILABLE:
            print("FastAPI not installed. pip install fastapi uvicorn pydantic", file=sys.stderr)
            sys.exit(1)
        # The import string assumes this file is saved as rag_demo.py
        uvicorn.run("rag_demo:app", host="0.0.0.0", port=8000, reload=False)
        return
    if args.question:
        if OPENAI_API_KEY.startswith("sk-your-dev-key-here"):
            log("WARNING: Set your OPENAI_API_KEY at the top of the file.")
        cli_query(args.question, k=args.k)
        return
    print(__doc__)

if __name__ == "__main__":
    main()
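
# Usage:
#   python rag_demo.py --build-index              # one-time: download SQuAD v2 + build the FAISS index
#   python rag_demo.py --q "Your question" -k 5   # ask from the CLI
#   python rag_demo.py --serve                    # launch the FastAPI server on :8000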