# app.py
# RAG app for chatting with research papers (optimized for Hugging Face Spaces)
import os, sys, subprocess, re, json, uuid, gc
from typing import List, Dict, Tuple

# -----------------------------
# Auto-install deps if missing
# -----------------------------
def ensure(pkg, pip_name=None):
    try:
        __import__(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or pkg])

ensure("torch")
ensure("transformers")
ensure("accelerate")
ensure("gradio")
ensure("faiss", "faiss-cpu")
ensure("sentence_transformers", "sentence-transformers")
ensure("pypdf")
ensure("docx", "python-docx")

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import faiss, gradio as gr
from pypdf import PdfReader
# -----------------------------
# Config
# -----------------------------
DATA_DIR = "rag_data"
os.makedirs(DATA_DIR, exist_ok=True)
INDEX_PATH = os.path.join(DATA_DIR, "faiss.index")
DOCS_PATH = os.path.join(DATA_DIR, "docs.jsonl")

# Default models (overridable via environment variables)
default_emb_model = "allenai/specter2_base"
default_llm_model = "microsoft/Phi-3-mini-4k-instruct"
EMB_MODEL_ID = os.environ.get("EMB_MODEL_ID", default_emb_model)
LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", default_llm_model)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
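# Usage note (illustrative): both models can be swapped without editing this file,
# e.g. when running locally:
#   EMB_MODEL_ID=sentence-transformers/all-MiniLM-L6-v2 LLM_MODEL_ID=Qwen/Qwen2.5-1.5B-Instruct python app.py
# On Spaces, the same names can be set as Space variables.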
# -----------------------------
# File loaders
# -----------------------------
def read_txt(path):
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()

def read_pdf(path):
    r = PdfReader(path)
    return "\n".join([p.extract_text() or "" for p in r.pages])

def read_docx(path):
    import docx
    d = docx.Document(path)
    return "\n".join([p.text for p in d.paragraphs])

def load_file(path):
    ext = os.path.splitext(path)[1].lower()
    if ext in [".txt", ".md"]:
        return read_txt(path)
    if ext == ".pdf":
        return read_pdf(path)
    if ext == ".docx":
        return read_docx(path)
    # Fall back to plain-text reading for unknown extensions
    return read_txt(path)
# -----------------------------
# Chunking
# -----------------------------
def normalize_ws(s: str):
    return re.sub(r"\s+", " ", s).strip()

def chunk_text(text, chunk_size=900, overlap=150):
    text = normalize_ws(text)
    # Guard against overlap >= chunk_size, which would make the step non-positive
    step = max(1, chunk_size - overlap)
    chunks = []
    for i in range(0, len(text), step):
        chunks.append(text[i:i + chunk_size])
    return chunks
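# Example (illustrative): with the defaults chunk_size=900, overlap=150 the window
# starts at character offsets 0, 750, 1500, ..., so consecutive chunks share 150
# characters and a sentence that straddles a boundary stays retrievable from both sides.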
# -----------------------------
# VectorStore
# -----------------------------
class VectorStore:
    def __init__(self, emb_model):
        self.emb_model = emb_model
        self.dim = emb_model.get_sentence_embedding_dimension()
        if os.path.exists(INDEX_PATH):
            # Resume from a previously persisted index + metadata
            self.index = faiss.read_index(INDEX_PATH)
            self.meta = [json.loads(l) for l in open(DOCS_PATH, "r", encoding="utf-8")]
        else:
            # Inner-product index over L2-normalized vectors = cosine similarity
            self.index = faiss.IndexFlatIP(self.dim)
            self.meta = []

    def _embed(self, texts):
        embs = self.emb_model.encode(texts, convert_to_tensor=True, normalize_embeddings=True)
        return embs.cpu().numpy()

    def add(self, chunks, source):
        if not chunks:
            return 0
        embs = self._embed(chunks)
        faiss.normalize_L2(embs)
        self.index.add(embs)
        recs = []
        for c in chunks:
            rec = {"id": str(uuid.uuid4()), "source": source, "text": c}
            self.meta.append(rec)
            recs.append(json.dumps(rec))
        with open(DOCS_PATH, "a", encoding="utf-8") as f:
            f.write("\n".join(recs) + "\n")
        faiss.write_index(self.index, INDEX_PATH)
        return len(chunks)

    def search(self, query, k=5):
        q = self._embed([query])
        faiss.normalize_L2(q)
        D, I = self.index.search(q, k)
        return [(float(d), self.meta[i]) for d, i in zip(D[0], I[0]) if i != -1]

    def clear(self):
        self.index = faiss.IndexFlatIP(self.dim)
        self.meta = []
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(DOCS_PATH):
            os.remove(DOCS_PATH)
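# Minimal usage sketch (illustrative, outside the Gradio app):
#   store = VectorStore(SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"))
#   store.add(chunk_text("some paper text ..."), source="example.txt")
#   hits = store.search("what is the main contribution?", k=3)
#   # hits -> [(cosine_similarity, {"id": ..., "source": ..., "text": ...}), ...], best first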
# -----------------------------
# Load models
# -----------------------------
print(f"[RAG] Loading embeddings: {EMB_MODEL_ID}")
EMB = SentenceTransformer(EMB_MODEL_ID, device=DEVICE)
VEC = VectorStore(EMB)

print(f"[RAG] Loading LLM: {LLM_MODEL_ID}")
bnb_config = None
if DEVICE == "cuda":
    # 4-bit quantization needs bitsandbytes, which is only required on GPU
    ensure("bitsandbytes")
    from transformers import BitsAndBytesConfig
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True, trust_remote_code=True)
LLM = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
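# Note: NF4 4-bit quantization cuts the LLM's weight memory to roughly a quarter of
# fp16, which is what lets Phi-3-mini fit on the small GPUs typically available to
# Spaces; on CPU-only hardware the model is loaded unquantized in float32 instead.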
# -----------------------------
# Prompt + Generate
# -----------------------------
SYSTEM_PROMPT = "You are a helpful assistant. Use the provided context from research papers to answer questions."

def build_prompt(query, history, retrieved):
    ctx = "\n\n".join([f"[{i+1}] {m['text']}" for i, (_, m) in enumerate(retrieved)])
    # Prefer the tokenizer's chat template when one is defined
    if getattr(TOKENIZER, "chat_template", None):
        messages = [{"role": "system", "content": SYSTEM_PROMPT + "\nContext:\n" + ctx}]
        for u, a in history[-3:]:
            messages.append({"role": "user", "content": u})
            messages.append({"role": "assistant", "content": a})
        messages.append({"role": "user", "content": query})
        return TOKENIZER.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Fallback manual prompt
        hist = "".join([f"<user>{u}</user><assistant>{a}</assistant>" for u, a in history[-3:]])
        return f"<system>{SYSTEM_PROMPT}\nContext:\n{ctx}</system>{hist}<user>{query}</user><assistant>"
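# Shape of the message list handed to the chat template (illustrative):
#   [{"role": "system",    "content": SYSTEM_PROMPT + "\nContext:\n[1] <chunk> ..."},
#    {"role": "user",      "content": "<previous question>"},
#    {"role": "assistant", "content": "<previous answer>"},
#    {"role": "user",      "content": "<current question>"}]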
def generate_answer(prompt, temperature=0.3, max_new_tokens=512):
    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
    inputs = TOKENIZER([prompt], return_tensors="pt").to(LLM.device)
    kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        streamer=streamer,
    )
    # Only pass temperature when sampling; greedy decoding ignores it
    if temperature > 0:
        kwargs["temperature"] = temperature
    # Run generation in a background thread and stream tokens back to the caller
    import threading
    t = threading.Thread(target=LLM.generate, kwargs=kwargs)
    t.start()
    out = ""
    for token in streamer:
        out += token
        yield out
    t.join()
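# Usage sketch (illustrative): generate_answer is a generator that yields the text
# accumulated so far after each decoded token, so the Gradio handlers below can
# re-yield it to stream partial answers into the Chatbot:
#   for partial in generate_answer(prompt, temperature=0.3, max_new_tokens=128):
#       ...  # partial grows token by token; the last value is the full answer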
# -----------------------------
# Gradio UI
# -----------------------------
def ui_ingest(files, chunk_size, overlap):
    total = 0
    names = []
    for f in files or []:
        # gr.File may hand back plain paths or tempfile-like objects depending on version
        path = f if isinstance(f, str) else f.name
        text = load_file(path)
        chunks = chunk_text(text, int(chunk_size), int(overlap))
        n = VEC.add(chunks, os.path.basename(path))
        total += n
        names.append(os.path.basename(path))
    return f"Added {total} chunks", "\n".join(names) or "-", VEC.index.ntotal

def ui_clear():
    VEC.clear()
    gc.collect()
    return "Index cleared", "-", 0

def ui_chat(msg, history, top_k, temperature, max_tokens):
    if not msg.strip():
        yield history, ""
        return
    retrieved = VEC.search(msg, int(top_k))
    prompt = build_prompt(msg, history, retrieved)
    reply = ""
    for partial in generate_answer(prompt, temperature, int(max_tokens)):
        reply = partial
        yield history + [(msg, reply)], ""
    yield history + [(msg, reply)], ""
with gr.Blocks() as demo:
    gr.Markdown("# Research Paper RAG Chat (Phi-3-mini + Specter2)")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(placeholder="Ask a question...")
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clearc = gr.Button("Clear Chat")
        with gr.Column():
            files = gr.File(label="Upload PDFs/DOCX/TXT", file_types=[".pdf", ".docx", ".txt", ".md"], file_count="multiple")
            chunk_size = gr.Slider(200, 2000, 900, step=50, label="Chunk Size")
            overlap = gr.Slider(0, 400, 150, step=10, label="Overlap")
            ingest_btn = gr.Button("Index Documents")
            status = gr.Textbox(label="Status", value="-")
            added = gr.Textbox(label="Files", value="-")
            total = gr.Number(label="Total Chunks", value=VEC.index.ntotal)
            clear_idx = gr.Button("Clear Index", variant="stop")
            top_k = gr.Slider(1, 10, 5, step=1, label="Top-K")
            temperature = gr.Slider(0.0, 1.5, 0.3, step=0.1, label="Temperature")
            max_tokens = gr.Slider(64, 2048, 512, step=64, label="Max New Tokens")

    ingest_btn.click(ui_ingest, [files, chunk_size, overlap], [status, added, total])
    clear_idx.click(ui_clear, [], [status, added, total])
    send.click(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    msg.submit(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    clearc.click(lambda: ([], ""), [], [chatbot, msg])
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))