# NotebookLM / services/rag_engine.py
# internomega-terrablue
# refactor: move all prompts to prompt.poml, fix greeting handling
# 98c2768
"""Retrieval-augmented generation (RAG) engine for chat responses, with a retrieval-only fallback."""
from __future__ import annotations
import logging
import os
import re
from pathlib import Path
from ingestion_engine.embedding_generator import generate_query
from persistence.vector_store import VectorStore
logger = logging.getLogger(__name__)

# --- Retrieval / reranking ----------------------------------------------------
# Size of the candidate pool requested from the vector store (stage 1).
K_RETRIEVE: int = 40
# Maximum number of reranked, deduplicated chunks kept for answering (stage 2).
K_FINAL: int = 8
# Bonus added to the vector-similarity score per distinct matching query keyword.
ALPHA: float = 0.05
# Character cap for snippets shown in citations and in the retrieval-only answer.
MAX_SNIPPET_CHARS: int = 280

# --- Generation (Hugging Face InferenceClient.chat_completion) -----------------
GEN_MODEL: str = "Qwen/Qwen2.5-7B-Instruct"
MAX_NEW_TOKENS: int = 400  # forwarded as max_tokens
TEMPERATURE: float = 0.2
TIMEOUT_SEC: int = 45  # forwarded as the client timeout

# Prompt definitions live in artifacts/prompt.poml, two directory levels above this file.
PROMPT_FILE = Path(__file__).resolve().parent.parent.joinpath("artifacts", "prompt.poml")
def _parse_poml() -> tuple[str, str]:
"""Parse prompt.poml into (system_message, user_template)."""
raw = PROMPT_FILE.read_text(encoding="utf-8")
# System message: <role> + <item> rules inside <system>
role_m = re.search(r"<role>(.*?)</role>", raw, re.DOTALL)
role = role_m.group(1).strip() if role_m else ""
items = re.findall(r"<item>(.*?)</item>", raw, re.DOTALL)
rules = "\n".join(f"{i+1}) {it.strip()}" for i, it in enumerate(items))
system_msg = f"{role}\nRules:\n{rules}" if role else rules
# User template: content inside <template>
tmpl_m = re.search(r"<template>(.*?)</template>", raw, re.DOTALL)
user_template = tmpl_m.group(1).strip() if tmpl_m else "{{question}}\n\n{{context}}"
return system_msg, user_template
def _clean_text(text: str) -> str:
return " ".join((text or "").split())
def _tokenize_keywords(text: str) -> set[str]:
tokens = re.split(r"[^a-z0-9]+", (text or "").lower())
return {t for t in tokens if len(t) >= 3}
def _keyword_hit_count(query_keywords: set[str], chunk_text: str) -> int:
if not query_keywords:
return 0
chunk_tokens = _tokenize_keywords(chunk_text)
return len(query_keywords.intersection(chunk_tokens))
def _rerank_matches(query: str, matches: list[dict]) -> list[dict]:
    """Stage 2 rerank: order candidates by vector score plus a lexical keyword bonus.

    Every candidate is copied and annotated with ``keyword_hit_count`` (distinct
    query keywords present in its text) and ``combined_score``
    (``score + ALPHA * keyword_hit_count``).  Candidates are sorted by
    ``combined_score`` descending (equal scores keep their input order), entries
    repeating an earlier (source_filename, chunk_index, normalised text) key are
    skipped, and collection stops once ``K_FINAL`` entries have been kept.
    """
    wanted = _tokenize_keywords(query)

    def _annotate(candidate: dict) -> dict:
        # Vector-store similarity; missing/None scores count as 0.0.
        base = float(candidate.get("score", 0.0) or 0.0)
        n_hits = _keyword_hit_count(wanted, candidate.get("text", ""))
        return {
            **candidate,
            "keyword_hit_count": n_hits,
            "combined_score": base + ALPHA * n_hits,
        }

    ranked = sorted(
        (_annotate(candidate) for candidate in matches),
        key=lambda entry: entry.get("combined_score", 0.0),
        reverse=True,
    )

    kept: list[dict] = []
    fingerprints = set()
    for entry in ranked:
        fingerprint = (
            entry.get("source_filename", ""),
            int(entry.get("chunk_index", 0) or 0),
            _clean_text(entry.get("text", "")),
        )
        if fingerprint not in fingerprints:
            fingerprints.add(fingerprint)
            kept.append(entry)
            if len(kept) >= K_FINAL:
                break
    return kept
def _build_citations(matches: list[dict]) -> list[dict]:
"""Convert vector matches into the citation format used by pages/chat.py."""
citations = []
seen = set()
for match in matches:
source = match.get("source_filename", "Unknown source")
chunk_index = int(match.get("chunk_index", 0) or 0)
key = (source, chunk_index)
if key in seen:
continue
seen.add(key)
snippet = _clean_text(match.get("text", ""))
if len(snippet) > MAX_SNIPPET_CHARS:
snippet = snippet[:MAX_SNIPPET_CHARS].rstrip() + "..."
citations.append(
{
"source": source,
"page": chunk_index,
"text": snippet,
}
)
return citations
def _build_content(matches: list[dict]) -> str:
if not matches:
return (
"I couldn't find relevant information in your uploaded sources for that question. "
"Try rephrasing the question or adding more sources."
)
lines = ["Based on your uploaded sources, here are the most relevant passages:", ""]
for idx, match in enumerate(matches, start=1):
source = match.get("source_filename", "Unknown source")
chunk_index = int(match.get("chunk_index", 0) or 0)
score = float(match.get("score", 0.0) or 0.0)
combined = float(match.get("combined_score", score) or score)
hits = int(match.get("keyword_hit_count", 0) or 0)
snippet = _clean_text(match.get("text", ""))
if len(snippet) > MAX_SNIPPET_CHARS:
snippet = snippet[:MAX_SNIPPET_CHARS].rstrip() + "..."
lines.append(
f"{idx}. **{source}** (chunk {chunk_index}, pinecone: {score:.3f}, hits: {hits}, combined: {combined:.3f})"
)
lines.append(f" {snippet}")
lines.append("")
lines.append("This is a two-stage retrieval-only response (no LLM synthesis yet).")
return "\n".join(lines)
def _build_context_text(reranked_matches: list[dict]) -> str:
"""Build formatted context from reranked chunks."""
blocks = []
for idx, match in enumerate(reranked_matches, start=1):
source = match.get("source_filename", "Unknown source")
chunk_index = int(match.get("chunk_index", 0) or 0)
text = _clean_text(match.get("text", ""))
blocks.append(f"[S{idx}] source={source} chunk={chunk_index}\n{text}")
return "\n\n".join(blocks)
def _build_user_message(question: str, reranked_matches: list[dict], history: list[dict], max_history=5) -> str:
    """Build the user-turn message sent to the LLM.

    The ``<template>`` from prompt.poml is filled by replacing ``{{question}}`` with
    *question* and ``{{context}}`` with the ``[S<n>]`` context built from
    *reranked_matches*.  Substitution happens in a single pass, so placeholder-like
    text inside the question or the context is inserted literally.

    When *history* is non-empty and ``max_history > 0``, its last *max_history*
    messages are prepended as a ``CHAT HISTORY:`` section with one
    ``Role: content`` line per message.

    Args:
        question: The user's question.
        reranked_matches: Final reranked chunks used as grounding context.
        history: Prior chat messages, read via ``.get("role")`` / ``.get("content")``.
        max_history: Maximum number of trailing history messages to include;
            0 or less omits the history section.

    Returns:
        The complete user message text.
    """
    _, user_template = _parse_poml()
    context_text = _build_context_text(reranked_matches)
    values = {"question": question, "context": context_text}
    # One pass with a function replacement: chained str.replace() would expand a
    # literal "{{context}}" typed by the user, and a string repl would interpret
    # backslashes in the user's text.
    prompt = re.sub(r"\{\{(question|context)\}\}", lambda m: values[m.group(1)], user_template)

    # history[-0:] is the whole list, so non-positive limits are handled explicitly.
    recent = history[-max_history:] if history and max_history > 0 else []
    if not recent:
        return prompt
    history_text = "".join(
        f"{(msg.get('role') or '').capitalize()}: {msg.get('content') or ''}\n" for msg in recent
    )
    return f"CHAT HISTORY:\n{history_text}\n{prompt}"
def _generate_answer(question: str, context_chunks: list[dict], chat_history: list[dict]) -> str:
    """Ask the Hugging Face Inference API (``GEN_MODEL``) for a grounded answer.

    The system prompt comes from prompt.poml; the user turn is built from the
    question, the reranked chunks and the chat history.  ``HF_TOKEN`` is read from
    the environment and may be absent.  Returns the stripped reply text, or ``""``
    when the response has no choices or empty content.  Client/network errors
    propagate to the caller (``rag_answer`` catches them and falls back).
    """
    from huggingface_hub import InferenceClient

    client = InferenceClient(token=os.environ.get("HF_TOKEN"), timeout=TIMEOUT_SEC)
    system_prompt, _template = _parse_poml()
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": _build_user_message(question, context_chunks, chat_history)},
    ]
    reply = client.chat_completion(
        model=GEN_MODEL,
        messages=conversation,
        max_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
    )
    if not (reply and reply.choices):
        return ""
    return (reply.choices[0].message.content or "").strip()
def rag_answer(question: str, notebook_id: str, chat_history: list[dict] = None) -> dict:
"""Return a retrieval-only answer object: {"content": str, "citations": list}, using chat history"""
q = (question or "").strip()
if not q:
return {"content": "Please enter a question.", "citations": []}
try:
query_vector = generate_query(q)
# Stage 1: retrieve candidate pool
matches = VectorStore().query(query_vector=query_vector, namespace=notebook_id, top_k=K_RETRIEVE)
candidates = [m for m in matches if m.get("text")]
if not candidates:
return {
"content": (
"I couldn't find relevant information in your uploaded sources for that question. "
"Try rephrasing the question or adding more sources."
),
"citations": [],
}
# Stage 2: rerank and keep top K_FINAL
final_matches = _rerank_matches(q, candidates)
citations = _build_citations(final_matches)
retrieval_only = _build_content(final_matches)
try:
generated = _generate_answer(q, final_matches, chat_history or [])
content = generated or retrieval_only
except Exception as e:
logger.warning("Generation failed, falling back to retrieval-only content: %s", e)
content = retrieval_only
return {
"content": content,
"citations": citations,
}
except Exception as e:
logger.error("RAG retrieval failed: %s", e)
return {
"content": f"I ran into an error while retrieving from sources: {e}",
"citations": [],
}