# crazy_bot / rag_engine.py — RAG engine: load a PDF/URL/text, chunk and
# embed it into a FAISS index, and answer questions with a hosted LLM.
from __future__ import annotations
import os
import re
import textwrap
import subprocess
import sys
from pathlib import Path
import faiss
import numpy as np
import requests
import spacy
from bs4 import BeautifulSoup
from huggingface_hub import InferenceClient
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
# ── Config ─────────────────────────────────────────────
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # sentence-embedding model for chunks and queries
LLM_MODEL = "HuggingFaceH4/zephyr-7b-beta"  # hosted LLM used for answer generation
CHUNK_SIZE = 400  # characters per chunk
CHUNK_OVERLAP = 80  # characters shared between consecutive chunks
TOP_K = 4  # number of chunks retrieved per query
# ── Engine ─────────────────────────────────────────────
class RAGEngine:
    """Retrieval-augmented generation over a single loaded document.

    Load a source (PDF, URL, or raw text), split it into overlapping
    character chunks, embed them into a FAISS cosine-similarity index,
    then answer questions with a hosted LLM constrained to the retrieved
    context.
    """

    def __init__(self):
        print("Loading embedding model...")
        self.embedder = SentenceTransformer(EMBED_MODEL)
        # Requires HF_TOKEN in the environment for the hosted inference API.
        self.hf_client = InferenceClient(token=os.getenv("HF_TOKEN"))
        self._load_spacy()
        self.reset()

    def _load_spacy(self):
        """Load the small English spaCy pipeline, downloading it on first use."""
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # spacy.load raises OSError when the model isn't installed.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            subprocess.run(
                [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
                check=True,
            )
            self.nlp = spacy.load("en_core_web_sm")

    def reset(self):
        """Forget any previously loaded document and its index."""
        self.chunks = []
        self.index = None

    @property
    def ready(self):
        """True once a document has been chunked and indexed."""
        return self.index is not None and len(self.chunks) > 0

    # ── Loaders ─────────────────────────────────────
    def load_pdf(self, path):
        """Extract text from the PDF at *path* and index it.

        Raises ValueError when the PDF yields no extractable text.
        """
        reader = PdfReader(path)
        text = " ".join(p.extract_text() or "" for p in reader.pages)
        if not text.strip():
            raise ValueError("No text found in PDF")
        self._build_index(text)
        return f"✅ PDF loaded ({len(self.chunks)} chunks)"

    def load_url(self, url):
        """Fetch *url*, strip scripts/styles, and index the visible text.

        Raises for HTTP errors, and ValueError when no text is found.
        """
        r = requests.get(url, timeout=15, headers={"User-Agent": "Mozilla/5.0"})
        r.raise_for_status()
        soup = BeautifulSoup(r.text, "html.parser")
        for tag in soup(["script", "style"]):
            tag.decompose()
        text = soup.get_text(" ", strip=True)
        if not text.strip():
            raise ValueError("No text found in URL")
        self._build_index(text)
        return f"✅ URL loaded ({len(self.chunks)} chunks)"

    def load_text(self, text):
        """Index raw *text*. Raises ValueError on empty/whitespace input."""
        if not text.strip():
            raise ValueError("Empty text")
        self._build_index(text)
        return f"✅ Text loaded ({len(self.chunks)} chunks)"

    # ── Chunking ─────────────────────────────────────
    def _chunk(self, text):
        """Split *text* into CHUNK_SIZE-char windows overlapping by CHUNK_OVERLAP."""
        text = re.sub(r"\s+", " ", text)
        chunks, i = [], 0
        while i < len(text):
            chunks.append(text[i:i + CHUNK_SIZE])
            i += CHUNK_SIZE - CHUNK_OVERLAP
        # Drop trailing fragments too short to carry useful context.
        return [c for c in chunks if len(c.strip()) > 30]

    # ── Indexing ─────────────────────────────────────
    def _build_index(self, text):
        """Embed the chunks of *text* into a fresh FAISS index."""
        self.chunks = self._chunk(text)
        emb = self.embedder.encode(self.chunks, show_progress_bar=False)
        emb = np.array(emb).astype("float32")
        # L2-normalised vectors + inner product == cosine similarity.
        faiss.normalize_L2(emb)
        self.index = faiss.IndexFlatIP(emb.shape[1])
        self.index.add(emb)

    # ── Retrieval ─────────────────────────────────────
    def _retrieve(self, query):
        """Return up to TOP_K chunks most similar to *query*."""
        emb = self.embedder.encode([query], show_progress_bar=False)
        emb = np.array(emb).astype("float32")
        faiss.normalize_L2(emb)
        _, idx = self.index.search(emb, TOP_K)
        # FAISS pads missing results with -1; guard both ends of the range
        # (the old `i < len(...)` let -1 wrap around to the last chunk).
        return [self.chunks[i] for i in idx[0] if 0 <= i < len(self.chunks)]

    # ── Answer ───────────────────────────────────────
    def answer(self, query):
        """Answer *query* using only retrieved context; never raises.

        Returns a warning string when no data is loaded, and falls back
        to the top retrieved chunk when the inference API errors.
        """
        if not self.ready:
            return "⚠️ Please load data first."
        chunks = self._retrieve(query)
        # Join chunks as plain text; interpolating the list directly would
        # leak Python repr noise (brackets/quotes) into the prompt.
        context = "\n\n".join(chunks)
        prompt = f"""
Use ONLY this context to answer:
{context}
Question: {query}
"""
        try:
            res = self.hf_client.text_generation(
                prompt,
                model=LLM_MODEL,
                max_new_tokens=300,
                temperature=0.3,
            )
            return res.strip()
        except Exception as e:
            # Best-effort fallback so the user still sees something; guard
            # against an empty retrieval result (old code did chunks[0]).
            fallback = chunks[0] if chunks else ""
            return f"⚠️ API Error: {e}\n\n{fallback}"