# NOTE(review): removed a page-scrape artifact ("Spaces: Running Running") that was not valid Python.
| """Retrieval-only RAG engine for chat responses.""" | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| import re | |
| from pathlib import Path | |
| from ingestion_engine.embedding_generator import generate_query | |
| from persistence.vector_store import VectorStore | |
logger = logging.getLogger(__name__)

# Two-stage retrieval tuning.
K_RETRIEVE = 40  # stage 1: size of the candidate pool pulled from the vector store
K_FINAL = 8  # stage 2: max chunks kept after rerank + dedup
ALPHA = 0.05  # weight of lexical keyword hits blended into the pinecone score
MAX_SNIPPET_CHARS = 280  # citations/summaries truncate chunk text to this length

# Generation (Hugging Face Inference API) settings.
GEN_MODEL = "Qwen/Qwen2.5-7B-Instruct"
MAX_NEW_TOKENS = 400  # cap on generated tokens per answer
TEMPERATURE = 0.2  # low temperature to keep answers grounded in the context
TIMEOUT_SEC = 45  # client-side timeout for the inference call

# POML prompt template shipped alongside the package (../artifacts/prompt.poml).
PROMPT_FILE = Path(__file__).resolve().parent.parent / "artifacts" / "prompt.poml"
def _parse_poml() -> tuple[str, str]:
    """Split prompt.poml into a (system_message, user_template) pair."""
    raw = PROMPT_FILE.read_text(encoding="utf-8")

    # System message = <role> text followed by a numbered list of <item> rules.
    role_match = re.search(r"<role>(.*?)</role>", raw, re.DOTALL)
    role = role_match.group(1).strip() if role_match else ""
    numbered_rules = []
    for number, item in enumerate(re.findall(r"<item>(.*?)</item>", raw, re.DOTALL), start=1):
        numbered_rules.append(f"{number}) {item.strip()}")
    rules = "\n".join(numbered_rules)
    system_msg = f"{role}\nRules:\n{rules}" if role else rules

    # User template = body of <template>, with a minimal fallback if absent.
    template_match = re.search(r"<template>(.*?)</template>", raw, re.DOTALL)
    if template_match:
        user_template = template_match.group(1).strip()
    else:
        user_template = "{{question}}\n\n{{context}}"
    return system_msg, user_template
| def _clean_text(text: str) -> str: | |
| return " ".join((text or "").split()) | |
| def _tokenize_keywords(text: str) -> set[str]: | |
| tokens = re.split(r"[^a-z0-9]+", (text or "").lower()) | |
| return {t for t in tokens if len(t) >= 3} | |
def _keyword_hit_count(query_keywords: set[str], chunk_text: str) -> int:
    """Number of query keywords that also occur among the chunk's keyword tokens."""
    if not query_keywords:
        return 0
    # Set intersection via the & operator; len gives the overlap size.
    return len(query_keywords & _tokenize_keywords(chunk_text))
def _rerank_matches(query: str, matches: list[dict]) -> list[dict]:
    """Stage 2 rerank: pinecone score + ALPHA * lexical keyword hits."""
    keywords = _tokenize_keywords(query)

    # Annotate each match with its lexical hit count and blended score.
    scored = []
    for match in matches:
        base_score = float(match.get("score", 0.0) or 0.0)
        hit_count = _keyword_hit_count(keywords, match.get("text", ""))
        scored.append(
            {
                **match,
                "keyword_hit_count": hit_count,
                "combined_score": base_score + ALPHA * hit_count,
            }
        )
    scored.sort(key=lambda entry: entry.get("combined_score", 0.0), reverse=True)

    # Walk best-first, skipping exact duplicates, until K_FINAL survive.
    survivors = []
    seen_keys = set()
    for candidate in scored:
        dedup_key = (
            candidate.get("source_filename", ""),
            int(candidate.get("chunk_index", 0) or 0),
            _clean_text(candidate.get("text", "")),
        )
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)
        survivors.append(candidate)
        if len(survivors) >= K_FINAL:
            break
    return survivors
def _build_citations(matches: list[dict]) -> list[dict]:
    """Convert vector matches into the citation format used by pages/chat.py."""
    emitted = set()
    citations = []
    for m in matches:
        source = m.get("source_filename", "Unknown source")
        chunk_index = int(m.get("chunk_index", 0) or 0)
        # One citation per (source, chunk) pair.
        if (source, chunk_index) in emitted:
            continue
        emitted.add((source, chunk_index))
        snippet = _clean_text(m.get("text", ""))
        if len(snippet) > MAX_SNIPPET_CHARS:
            snippet = snippet[:MAX_SNIPPET_CHARS].rstrip() + "..."
        citations.append({"source": source, "page": chunk_index, "text": snippet})
    return citations
def _build_content(matches: list[dict]) -> str:
    """Render the retrieval-only answer text from the reranked matches."""
    if not matches:
        return (
            "I couldn't find relevant information in your uploaded sources for that question. "
            "Try rephrasing the question or adding more sources."
        )
    output = ["Based on your uploaded sources, here are the most relevant passages:", ""]
    for rank, m in enumerate(matches, start=1):
        source = m.get("source_filename", "Unknown source")
        chunk_index = int(m.get("chunk_index", 0) or 0)
        score = float(m.get("score", 0.0) or 0.0)
        combined = float(m.get("combined_score", score) or score)
        hits = int(m.get("keyword_hit_count", 0) or 0)
        snippet = _clean_text(m.get("text", ""))
        if len(snippet) > MAX_SNIPPET_CHARS:
            snippet = snippet[:MAX_SNIPPET_CHARS].rstrip() + "..."
        output.append(
            f"{rank}. **{source}** (chunk {chunk_index}, pinecone: {score:.3f}, hits: {hits}, combined: {combined:.3f})"
        )
        output.append(f" {snippet}")
        output.append("")
    output.append("This is a two-stage retrieval-only response (no LLM synthesis yet).")
    return "\n".join(output)
def _build_context_text(reranked_matches: list[dict]) -> str:
    """Build formatted context from reranked chunks."""
    sections = []
    for position, chunk in enumerate(reranked_matches, start=1):
        name = chunk.get("source_filename", "Unknown source")
        index = int(chunk.get("chunk_index", 0) or 0)
        body = _clean_text(chunk.get("text", ""))
        # Each block is labelled [S1], [S2], ... so the model can cite sources.
        sections.append(f"[S{position}] source={name} chunk={index}\n{body}")
    return "\n\n".join(sections)
def _build_user_message(question: str, reranked_matches: list[dict], history: list[dict], max_history=5) -> str:
    """Build the user message from POML template + chat history + context."""
    _, user_template = _parse_poml()
    context_text = _build_context_text(reranked_matches)
    prompt = user_template.replace("{{question}}", question).replace("{{context}}", context_text)
    if not history:
        return prompt
    # Prepend the last max_history turns as "Role: content" lines.
    turns = [
        f"{msg.get('role', '').capitalize()}: {msg.get('content', '')}\n"
        for msg in history[-max_history:]
    ]
    history_text = "".join(turns)
    return f"CHAT HISTORY:\n{history_text}\n{prompt}"
def _generate_answer(question: str, context_chunks: list[dict], chat_history: list[dict]) -> str:
    """Generate a grounded response using Hugging Face Inference API + chat history."""
    from huggingface_hub import InferenceClient

    client = InferenceClient(token=os.environ.get("HF_TOKEN"), timeout=TIMEOUT_SEC)
    system_msg, _ = _parse_poml()
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": _build_user_message(question, context_chunks, chat_history)},
    ]
    response = client.chat_completion(
        model=GEN_MODEL,
        messages=messages,
        max_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
    )
    if response and response.choices:
        text = response.choices[0].message.content or ""
    else:
        text = ""
    return text.strip()
def rag_answer(question: str, notebook_id: str, chat_history: list[dict] | None = None) -> dict:
    """Answer *question* against the sources ingested under *notebook_id*.

    Returns {"content": str, "citations": list}. Retrieval always runs; LLM
    synthesis is best-effort and falls back to the retrieval-only summary
    when generation fails.

    Fixes vs. original: the chat_history annotation now admits None (the
    default), and the outer handler uses logger.exception so the traceback
    is preserved in the logs.
    """
    q = (question or "").strip()
    if not q:
        return {"content": "Please enter a question.", "citations": []}
    try:
        query_vector = generate_query(q)
        # Stage 1: retrieve candidate pool
        matches = VectorStore().query(query_vector=query_vector, namespace=notebook_id, top_k=K_RETRIEVE)
        candidates = [m for m in matches if m.get("text")]
        if not candidates:
            return {
                "content": (
                    "I couldn't find relevant information in your uploaded sources for that question. "
                    "Try rephrasing the question or adding more sources."
                ),
                "citations": [],
            }
        # Stage 2: rerank and keep top K_FINAL
        final_matches = _rerank_matches(q, candidates)
        citations = _build_citations(final_matches)
        retrieval_only = _build_content(final_matches)
        try:
            generated = _generate_answer(q, final_matches, chat_history or [])
            content = generated or retrieval_only
        except Exception as e:
            # Generation is best-effort: degrade to the retrieval-only summary.
            logger.warning("Generation failed, falling back to retrieval-only content: %s", e)
            content = retrieval_only
        return {
            "content": content,
            "citations": citations,
        }
    except Exception as e:
        # logger.exception keeps the traceback; plain logger.error dropped it.
        logger.exception("RAG retrieval failed: %s", e)
        return {
            "content": f"I ran into an error while retrieving from sources: {e}",
            "citations": [],
        }