# api/rag_engine.py
"""
RAG engine:
- build_rag_chunks_from_file(path, doc_type) -> List[chunk]
- retrieve_relevant_chunks(query, chunks) -> (context_text, used_chunks)
Chunk format (MVP):
{
"text": str,
"source_file": str,
"section": str,
"doc_type": str
}
"""
import os
import re
from typing import Dict, List, Tuple
from pypdf import PdfReader
from docx import Document
from pptx import Presentation
# ============================
# Token helpers (optional tiktoken)
# ============================
def _safe_import_tiktoken():
    try:
        import tiktoken  # type: ignore
        return tiktoken
    except Exception:
        return None


def _approx_tokens(text: str) -> int:
    if not text:
        return 0
    return max(1, int(len(text) / 4))


def _count_text_tokens(text: str, model: str = "") -> int:
    tk = _safe_import_tiktoken()
    if tk is None:
        return _approx_tokens(text)
    try:
        enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
    except Exception:
        enc = tk.get_encoding("cl100k_base")
    return len(enc.encode(text or ""))


def _truncate_to_tokens(text: str, max_tokens: int, model: str = "") -> str:
    """
    Deterministic truncation. Uses tiktoken if available; otherwise approximates by char ratio.
    """
    if not text:
        return text
    tk = _safe_import_tiktoken()
    if tk is None:
        # approximate by chars (~4 chars per token)
        total = _approx_tokens(text)
        if total <= max_tokens:
            return text
        ratio = max_tokens / max(1, total)
        cut = max(50, min(len(text), int(len(text) * ratio)))
        s = text[:cut]
        # tighten until the estimate fits (never below 50 chars)
        while _approx_tokens(s) > max_tokens and len(s) > 50:
            s = s[: int(len(s) * 0.9)]
        return s
    try:
        enc = tk.encoding_for_model(model) if model else tk.get_encoding("cl100k_base")
    except Exception:
        enc = tk.get_encoding("cl100k_base")
    ids = enc.encode(text or "")
    if len(ids) <= max_tokens:
        return text
    return enc.decode(ids[:max_tokens])
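
# Illustrative behavior of the helpers above (doctest-style, assuming the
# approximate path with tiktoken NOT installed; with tiktoken the counts are
# exact for the cl100k_base encoding):
#
#   >>> _approx_tokens("abcdefgh")                              # 8 chars / 4
#   2
#   >>> len(_truncate_to_tokens("x " * 600, max_tokens=100)) <= 400
#   True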
# ============================
# RAG hard limits
# ============================
RAG_TOPK_LIMIT = 4
RAG_CHUNK_TOKEN_LIMIT = 500
RAG_CONTEXT_TOKEN_LIMIT = 2000 # 4 * 500
# ----------------------------
# Helpers
# ----------------------------
def _clean_text(s: str) -> str:
    s = (s or "").replace("\r", "\n")
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s.strip()
def _split_into_chunks(text: str, max_chars: int = 1400) -> List[str]:
    """
    Simple deterministic chunker:
    - split by blank lines
    - hard-split any paragraph longer than max_chars
    - then pack into <= max_chars
    """
    text = _clean_text(text)
    if not text:
        return []
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    # hard-split oversized paragraphs so the max_chars cap actually holds
    pieces: List[str] = []
    for p in paras:
        while len(p) > max_chars:
            pieces.append(p[:max_chars])
            p = p[max_chars:]
        if p:
            pieces.append(p)
    chunks: List[str] = []
    buf = ""
    for p in pieces:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= max_chars:
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf)
            buf = p
    if buf:
        chunks.append(buf)
    return chunks
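
# Packing example (hypothetical strings): paragraphs are greedily packed
# up to max_chars, then a new chunk starts.
#
#   >>> _split_into_chunks("aaa\n\nbbb\n\nccc", max_chars=8)
#   ['aaa\n\nbbb', 'ccc']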
def _file_label(path: str) -> str:
    return os.path.basename(path) if path else "uploaded_file"
# ----------------------------
# Parsers
# ----------------------------
def _parse_pdf_to_text(path: str) -> List[Tuple[str, str]]:
    """
    Returns list of (section_label, text);
    section_label uses page numbers.
    """
    reader = PdfReader(path)
    out: List[Tuple[str, str]] = []
    for i, page in enumerate(reader.pages):
        t = page.extract_text() or ""
        t = _clean_text(t)
        if t:
            out.append((f"p{i+1}", t))
    return out
def _parse_docx_to_text(path: str) -> List[Tuple[str, str]]:
    doc = Document(path)
    paras = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
    if not paras:
        return []
    full = "\n\n".join(paras)
    return [("docx", _clean_text(full))]
def _parse_pptx_to_text(path: str) -> List[Tuple[str, str]]:
    prs = Presentation(path)
    out: List[Tuple[str, str]] = []
    for idx, slide in enumerate(prs.slides, start=1):
        lines: List[str] = []
        for shape in slide.shapes:
            if hasattr(shape, "text") and shape.text:
                txt = shape.text.strip()
                if txt:
                    lines.append(txt)
        if lines:
            out.append((f"slide{idx}", _clean_text("\n".join(lines))))
    return out
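
# All three parsers return the same (section_label, text) shape, e.g. for a
# hypothetical two-slide deck:
#
#   [("slide1", "Title\nBullet one"), ("slide2", "Closing notes")]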
# ----------------------------
# Public API
# ----------------------------
def build_rag_chunks_from_file(path: str, doc_type: str) -> List[Dict]:
    """
    Build RAG chunks from a local file path.
    Supports: .pdf / .docx / .pptx / .txt / .md
    """
    if not path or not os.path.exists(path):
        return []
    ext = os.path.splitext(path)[1].lower()
    source_file = _file_label(path)
    sections: List[Tuple[str, str]] = []
    try:
        if ext == ".pdf":
            sections = _parse_pdf_to_text(path)
        elif ext == ".docx":
            sections = _parse_docx_to_text(path)
        elif ext == ".pptx":
            sections = _parse_pptx_to_text(path)
        elif ext in [".txt", ".md"]:
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                sections = [("text", _clean_text(f.read()))]
        else:
            print(f"[rag_engine] unsupported file type: {ext}")
            return []
    except Exception as e:
        print(f"[rag_engine] parse error for {source_file}: {repr(e)}")
        return []
    chunks: List[Dict] = []
    for section, text in sections:
        for j, piece in enumerate(_split_into_chunks(text), start=1):
            chunks.append(
                {
                    "text": piece,
                    "source_file": source_file,
                    "section": f"{section}#{j}",
                    "doc_type": doc_type,
                }
            )
    return chunks
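
# Example chunk for page 2 of a hypothetical "report.pdf" uploaded as a
# "report" doc_type:
#
#   {"text": "...", "source_file": "report.pdf",
#    "section": "p2#1", "doc_type": "report"}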
def retrieve_relevant_chunks(
    query: str,
    chunks: List[Dict],
    k: int = RAG_TOPK_LIMIT,
    max_context_chars: int = 600,  # kept for backward compatibility; note it usually binds before the token limits
    min_score: int = 6,
    chunk_token_limit: int = RAG_CHUNK_TOKEN_LIMIT,
    max_context_tokens: int = RAG_CONTEXT_TOKEN_LIMIT,
    model_for_tokenizer: str = "",
) -> Tuple[str, List[Dict]]:
    """
    Deterministic lightweight retrieval (no embeddings):
    - score by token overlap
    - return top-k chunks concatenated as context
    Hard limits implemented:
    - top-k <= 4 (default)
    - each chunk <= 500 tokens
    - total context <= 2000 tokens (default)
    """
    query = _clean_text(query)
    if not query or not chunks:
        return "", []
    # ✅ Short query gate: avoid wasting time on RAG for greetings / tiny inputs
    q_tokens_list = re.findall(r"[a-zA-Z0-9]+", query.lower())
    if (len(q_tokens_list) < 3) and (len(query) < 20):
        return "", []
    q_tokens = set(q_tokens_list)
    if not q_tokens:
        return "", []
    # a query can never overlap on more tokens than it has, so cap min_score;
    # otherwise short-but-valid queries are filtered out unconditionally
    min_score = max(1, min(int(min_score), len(q_tokens)))
    scored: List[Tuple[int, Dict]] = []
    for c in chunks:
        text = (c.get("text") or "")
        if not text:
            continue
        t_tokens = set(re.findall(r"[a-zA-Z0-9]+", text.lower()))
        score = len(q_tokens.intersection(t_tokens))
        if score >= min_score:
            scored.append((score, c))
    if not scored:
        return "", []
    scored.sort(key=lambda x: x[0], reverse=True)
    # hard cap k
    k = min(int(k or RAG_TOPK_LIMIT), RAG_TOPK_LIMIT)
    top = [c for _, c in scored[:k]]
    used: List[Dict] = []
    truncated_texts: List[str] = []
    total_tokens = 0
    for c in top:
        raw = c.get("text") or ""
        if not raw:
            continue
        # truncate each chunk to <= chunk_token_limit
        t = _truncate_to_tokens(raw, max_tokens=chunk_token_limit, model=model_for_tokenizer)
        # enforce total context tokens cap
        t_tokens = _count_text_tokens(t, model=model_for_tokenizer)
        if total_tokens + t_tokens > max_context_tokens:
            remaining = max_context_tokens - total_tokens
            if remaining <= 0:
                break
            t = _truncate_to_tokens(t, max_tokens=remaining, model=model_for_tokenizer)
        # legacy char cap safety (previous behavior kept as an extra guard)
        if max_context_chars and max_context_chars > 0:
            # approximate: don't let the total string blow up
            current_chars = sum(len(x) for x in truncated_texts)
            if current_chars + len(t) > max_context_chars:
                t = t[: max(0, max_context_chars - current_chars)]
        t = _clean_text(t)
        if not t:
            continue
        truncated_texts.append(t)
        used.append(c)
        # recount after the char cap and cleanup so the token budget stays accurate
        total_tokens += _count_text_tokens(t, model=model_for_tokenizer)
        if total_tokens >= max_context_tokens:
            break
    if not truncated_texts:
        return "", []
    context = "\n\n---\n\n".join(truncated_texts)
    return context, used
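

if __name__ == "__main__":
    # Minimal smoke test of the two public entry points. "sample.txt" is a
    # hypothetical local file; point this at any .pdf/.docx/.pptx/.txt/.md.
    demo_chunks = build_rag_chunks_from_file("sample.txt", doc_type="notes")
    print(f"built {len(demo_chunks)} chunks")
    ctx, used = retrieve_relevant_chunks(
        "what are the quarterly revenue figures for the product line",
        demo_chunks,
    )
    print(f"retrieved {len(used)} chunks, context length {len(ctx)} chars")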