import os
import re
import hashlib
import tempfile
from io import StringIO
from urllib.parse import urljoin

import gradio as gr
from pypdf import PdfReader
import trafilatura
import requests
from bs4 import BeautifulSoup
import docx
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors
import torch
import brotli
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    NllbTokenizer,
    M2M100ForConditionalGeneration,
)

# =========================
# GLOBAL CACHE
# =========================
CACHE = {
    "last_text_hash": None,
    "chunks": None,
    "embeddings": None,
    "knn": None,
}
# =========================
# CONFIG
# =========================
EMBED_MODEL = "intfloat/e5-small-v2"
LLM_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
# Translation model (open, no auth required)
TRANS_MODEL_ID = "facebook/nllb-200-distilled-600M"
CHUNK_SIZE = 1500      # characters per chunk
CHUNK_OVERLAP = 300    # character overlap used by the sliding-window fallback
MIN_SECTION_LEN = 300  # minimum characters for a detected section to be kept
# =========================
# CLEAN TEXT
# =========================
def clean_text(text: str) -> str:
    # Collapse runs of spaces/tabs but keep newlines, so the section- and
    # paragraph-aware chunking below still has line structure to work with.
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    return "\n".join(re.sub(r"[ \t]+", " ", line).strip() for line in text.split("\n"))
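# Illustrative example of the behaviour above (kept as a comment so it does
# not run at import time): clean_text("Results\r\n\r\nAccuracy   improved\tby 4%")
# yields "Results\n\nAccuracy improved by 4%", i.e. whitespace is collapsed
# within each line while blank lines between paragraphs survive.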
# =========================
# PDF INGEST
# =========================
def extract_text_from_pdf(path: str) -> str:
    reader = PdfReader(path)
    text = ""
    for page in reader.pages:
        page_text = page.extract_text() or ""
        text += "\n" + page_text
    return clean_text(text)


def extract_pdf_from_url(url: str) -> str:
    r = requests.get(url, timeout=20)
    r.raise_for_status()
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    tmp.write(r.content)
    tmp.close()
    try:
        return extract_text_from_pdf(tmp.name)
    finally:
        os.remove(tmp.name)
# =========================
# DOCX / TXT / CSV INGEST
# =========================
def extract_docx_from_url(url: str) -> str:
    r = requests.get(url, timeout=20)
    r.raise_for_status()
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".docx")
    tmp.write(r.content)
    tmp.close()
    try:
        document = docx.Document(tmp.name)
        return clean_text("\n".join(p.text for p in document.paragraphs))
    finally:
        os.remove(tmp.name)


def extract_txt_from_url(url: str) -> str:
    return clean_text(requests.get(url, timeout=20).text)


def extract_csv_from_url(url: str) -> str:
    df = pd.read_csv(StringIO(requests.get(url, timeout=20).text))
    return clean_text(df.to_string())
# =========================
# ROBUST HTML + IN-PAGE PDF HANDLER
# =========================
def extract_html_from_url(url: str) -> str:
    """
    Robust extractor for research sites:
    - Handles brotli (br) encoding
    - Detects <a href="...pdf"> links inside HTML and downloads the PDF
    - Falls back to cleaned HTML text
    """
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/120.0 Safari/537.36"
        ),
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
    }

    # 1) Fetch HTML
    try:
        resp = requests.get(url, headers=headers, timeout=20)
        if resp.headers.get("Content-Encoding") == "br":
            try:
                html = brotli.decompress(resp.content).decode("utf-8", errors="ignore")
            except Exception:
                # requests may already have decoded the body when brotli
                # support is installed, in which case .text is correct.
                html = resp.text
        else:
            html = resp.text
    except Exception as e:
        return f"Error loading HTML: {e}"

    soup = BeautifulSoup(html, "html.parser")

    # 2) Try to find a PDF link inside the page
    pdf_links = [a["href"] for a in soup.find_all("a", href=True)
                 if ".pdf" in a["href"].lower()]
    if pdf_links:
        # urljoin resolves both relative and absolute hrefs against the page URL
        pdf_url = urljoin(url, pdf_links[0])
        try:
            pdf_resp = requests.get(pdf_url, headers=headers, timeout=20)
            if pdf_resp.status_code == 200:
                tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
                tmp.write(pdf_resp.content)
                tmp.close()
                try:
                    return extract_text_from_pdf(tmp.name)
                finally:
                    os.remove(tmp.name)
        except Exception:
            # If the PDF download fails, fall back to HTML extraction
            pass

    # 3) Try trafilatura for the main text
    extracted = trafilatura.extract(html)
    if extracted and len(extracted) > 200:
        return clean_text(extracted)

    # 4) Raw HTML fallback
    for bad in soup(["script", "style", "noscript"]):
        bad.decompose()
    return clean_text(soup.get_text(" ", strip=True))
# =========================
# FILE TYPE DETECTION
# =========================
def detect_filetype(url: str, headers) -> str:
    u = url.lower()
    c = headers.get("Content-Type", "").lower()
    if u.endswith(".pdf") or "pdf" in c:
        return "pdf"
    if u.endswith(".docx") or "word" in c:
        return "docx"
    if u.endswith(".txt") or "text/plain" in c:
        return "txt"
    if u.endswith(".csv") or "csv" in c:
        return "csv"
    return "html"
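# Illustrative checks (comment only): a URL ending in ".pdf", or any response
# whose Content-Type contains "pdf" (e.g. "application/pdf"), is routed to the
# PDF loader; anything unrecognised falls back to the HTML extractor above.
#   detect_filetype("https://example.org/paper.pdf", {})                          -> "pdf"
#   detect_filetype("https://example.org/post", {"Content-Type": "text/html"})    -> "html"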
| # ========================= | |
| # SECTION-AWARE CHUNKING | |
| # ========================= | |
| SECTION_KEYWORDS = [ | |
| "introduction", "method", "methodology", "materials and methods", | |
| "results", "discussion", "conclusion", "conclusions", "abstract", | |
| "background", "analysis" | |
| ] | |
| def is_heading(line: str) -> bool: | |
| line = line.strip() | |
| if not line or len(line) > 120: | |
| return False | |
| lower = line.lower() | |
| if lower in SECTION_KEYWORDS: | |
| return True | |
| if line == line.upper() and any(c.isalpha() for c in line): | |
| return True | |
| if re.match(r"^\d+(\.\d+)*\s+[A-Za-z]", line): | |
| return True | |
| return False | |
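# Illustrative checks of the heading heuristic (comment only):
#   is_heading("Results")              -> True  (known section keyword)
#   is_heading("RELATED WORK")         -> True  (all-caps line)
#   is_heading("2.1 Data collection")  -> True  (numbered outline prefix)
#   is_heading("We trained the model on 10k samples.")  -> False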
def split_into_sections(text: str):
    lines = text.split("\n")
    sections, title, buf = [], "Document", []
    for line in lines:
        if is_heading(line):
            if buf and len("\n".join(buf)) > MIN_SECTION_LEN:
                sections.append((title, "\n".join(buf)))
            title = line.strip()
            buf = []
        else:
            buf.append(line)
    if buf:
        body = "\n".join(buf)
        if len(body) > MIN_SECTION_LEN:
            sections.append((title, body))
    if not sections:
        return [("Document", text)]
    return sections
def chunk_text(text: str):
    sections = split_into_sections(text)

    # Fallback: sliding character window if no usable sections were found
    if len(sections) == 1:
        chunks = []
        start = 0
        while start < len(text):
            end = min(start + CHUNK_SIZE, len(text))
            chunks.append(text[start:end])
            start += CHUNK_SIZE - CHUNK_OVERLAP
        return chunks

    chunks = []
    for _, body in sections:
        paragraphs = [p.strip() for p in re.split(r"\n\s*\n", body) if p.strip()]
        current = ""
        for para in paragraphs:
            if not current:
                current = para
            elif len(current) + len(para) + 2 <= CHUNK_SIZE:
                current += "\n\n" + para
            else:
                chunks.append(current)
                current = para
        if current:
            chunks.append(current)
    return chunks
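# Usage sketch (comment only): for a plain blob of text with no detected
# headings, chunk_text falls back to 1500-character windows with a
# 300-character overlap, so
#   chunk_text("x" * 4000)
# yields chunks starting at offsets 0, 1200, 2400 and 3600.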
# =========================
# SEMANTIC SEARCH (KNN)
# =========================
class SemanticSearch:
    def __init__(self, model: str):
        self.embedder = SentenceTransformer(model)
        self.knn = None
        self.chunks = []

    def build(self, chunks):
        global CACHE
        h = hashlib.md5("".join(chunks).encode()).hexdigest()
        if CACHE["last_text_hash"] == h:
            print("Using cached embeddings")
            self.chunks = CACHE["chunks"]
            self.knn = CACHE["knn"]
            return
        print("Rebuilding embeddings…")
        self.chunks = chunks
        emb = self.embedder.encode(chunks, convert_to_numpy=True)
        self.knn = NearestNeighbors(metric="cosine")
        self.knn.fit(emb)
        CACHE["last_text_hash"] = h
        CACHE["chunks"] = chunks
        CACHE["embeddings"] = emb
        CACHE["knn"] = self.knn

    def search(self, q, k=5):
        if len(self.chunks) == 1:
            return [(self.chunks[0], 0.0)]
        q_emb = self.embedder.encode([q], convert_to_numpy=True)
        k = min(k, len(self.chunks))
        dist, ids = self.knn.kneighbors(q_emb, n_neighbors=k)
        return [(self.chunks[i], float(dist[0][j])) for j, i in enumerate(ids[0])]


vs = None
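# Usage sketch (comment only, illustrative names; `vs` itself is populated by
# the loaders further below):
#   searcher = SemanticSearch(EMBED_MODEL)
#   searcher.build(chunk_text(some_document_text))
#   for chunk, dist in searcher.search("What were the key findings?", k=3):
#       print(round(1 - dist, 3), chunk[:80])
# Note: intfloat/e5 models are usually fed "query: " / "passage: " prefixes;
# this app embeds raw text, which still works but may retrieve slightly less
# accurately.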
# =========================
# LOAD QWEN FOR RAG
# =========================
print("Loading Qwen 0.5B…")
q_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
q_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL).to("cpu")
q_model.eval()


def run_llm(system: str, user: str) -> str:
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    text = q_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inp = q_tokenizer(text, return_tensors="pt").to("cpu")
    out = q_model.generate(
        **inp,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.85,
        temperature=0.6,
        eos_token_id=q_tokenizer.eos_token_id,
    )
    gen = out[0][inp["input_ids"].shape[1]:]
    return q_tokenizer.decode(gen, skip_special_tokens=True).strip()
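# Usage sketch (comment only): run_llm decodes and returns only the newly
# generated tokens, e.g.
#   run_llm("You are a concise assistant.", "In one sentence, what is RAG?")
# produces a short free-text answer. Because sampling is enabled
# (top_p=0.85, temperature=0.6), repeated calls can give different wordings.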
# =========================
# LOAD NLLB TRANSLATOR
# =========================
print("Loading NLLB-200 translator…")
trans_tokenizer = NllbTokenizer.from_pretrained(TRANS_MODEL_ID)
trans_model = M2M100ForConditionalGeneration.from_pretrained(TRANS_MODEL_ID).to("cpu")

LANG_CODES = {
    "English": "eng_Latn",
    "Hindi": "hin_Deva",
    "Telugu": "tel_Telu",
    "Tamil": "tam_Taml",
    "Kannada": "kan_Knda",
    "Malayalam": "mal_Mlym",
    "Bengali": "ben_Beng",
    "Marathi": "mar_Deva",
    "Gujarati": "guj_Gujr",
    "Odia": "ory_Orya",
    "Punjabi": "pan_Guru",
    "Assamese": "asm_Beng",
}


def translate_to_indic(text: str, lang: str) -> str:
    if lang in ("English", "auto"):
        return text
    try:
        tgt = LANG_CODES[lang]
        inputs = trans_tokenizer(text, return_tensors="pt").to("cpu")
        output = trans_model.generate(
            **inputs,
            forced_bos_token_id=trans_tokenizer.convert_tokens_to_ids(tgt),
            max_new_tokens=300,
        )
        return trans_tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    except Exception as e:
        print("Translation error:", e)
        return text
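# Usage sketch (comment only): the tokenizer's source language defaults to
# English (the RAG answer is always generated in English first), and the
# target is forced via its NLLB language-code token, e.g.
#   translate_to_indic("The study reports a 12% improvement.", "Hindi")
# returns the Hindi rendering, while "English", "auto", or an unknown
# language name returns the text unchanged.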
# =========================
# RAG PROMPT
# =========================
def build_prompt(question, retrieved):
    ctx = "\n\n---\n\n".join([c for c, _ in retrieved])
    return f"""
You are a precise and factual RAG system.
Your task is to answer the question strictly using the information found in the context.
Follow these rules:
1. Use ONLY the context. Do not add external knowledge.
2. If the context does not contain the answer, say:
   "I don't know based on this document."
3. When possible, structure your answer into short, clear points.
4. Keep the answer concise, factual, and in English.
CONTEXT:
{ctx}
QUESTION:
{question}
Write your answer below (in English):
""".strip()
# =========================
# SOURCE DISPLAY
# =========================
def highlight_sources(retrieved):
    html = "<h4>Source Passages</h4>"
    for i, (chunk, score) in enumerate(retrieved):
        html += f"""
        <div style='padding:10px; background:#eef6ff; margin-bottom:10px;'>
            <b>[{i+1}] Score: {1-score:.3f}</b><br>
            {chunk[:400]}...
        </div>
        """
    return html
# =========================
# ANSWER FUNCTION
# =========================
def answer_question(q, lang):
    global vs
    if vs is None:
        return "Please load a document first.", ""
    retrieved = vs.search(q)
    prompt = build_prompt(q, retrieved)
    english_answer = run_llm("You are a reliable factual RAG assistant.", prompt)
    final = translate_to_indic(english_answer, lang)
    return final, highlight_sources(retrieved)
# =========================
# LOADERS
# =========================
def load_pdf_ui(file, lang):
    global vs
    if not file:
        return "Upload a PDF."
    # gr.File may hand back a path string or an object with a .name attribute,
    # depending on the Gradio version.
    path = file if isinstance(file, str) else file.name
    text = extract_text_from_pdf(path)
    chunks = chunk_text(text)
    vs = SemanticSearch(EMBED_MODEL)
    vs.build(chunks)
    return f"PDF loaded with {len(chunks)} chunks."


def load_url_ui(url, lang):
    global vs
    if not url:
        return "Enter a URL."
    try:
        # Some servers reject HEAD; fall back to extension-only detection.
        try:
            head = requests.head(url, allow_redirects=True, timeout=20)
            headers = head.headers
        except Exception:
            headers = {}
        ftype = detect_filetype(url, headers)
        if ftype == "pdf":
            text = extract_pdf_from_url(url)
        elif ftype == "docx":
            text = extract_docx_from_url(url)
        elif ftype == "txt":
            text = extract_txt_from_url(url)
        elif ftype == "csv":
            text = extract_csv_from_url(url)
        else:
            text = extract_html_from_url(url)
    except Exception as e:
        return f"Error loading URL: {e}"
    chunks = chunk_text(text)
    vs = SemanticSearch(EMBED_MODEL)
    vs.build(chunks)
    return f"URL loaded with {len(chunks)} chunks."
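# End-to-end sketch (comment only, illustrative URL): the loaders populate the
# module-level `vs`, after which answer_question can be called without the UI:
#   print(load_url_ui("https://arxiv.org/abs/1706.03762", "English"))
#   answer, sources_html = answer_question("What is the main contribution?", "Telugu")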
# =========================
# UI
# =========================
def create_app():
    with gr.Blocks() as demo:
        gr.Markdown("<h1>Multilingual Chat with PDF / URL</h1>")

        lang = gr.Dropdown(
            [
                "auto", "English", "Hindi", "Telugu", "Tamil",
                "Kannada", "Malayalam", "Bengali", "Marathi",
                "Gujarati", "Odia", "Punjabi", "Assamese"
            ],
            value="auto",
            label="Answer Language"
        )

        with gr.Tab("Load Document"):
            pdf = gr.File(label="Upload PDF")
            pdf_status = gr.HTML()
            gr.Button("Load PDF").click(load_pdf_ui, [pdf, lang], pdf_status)

            url = gr.Textbox(label="Enter URL (PDF, DOCX, TXT, CSV, Website)")
            url_status = gr.HTML()
            gr.Button("Load URL").click(load_url_ui, [url, lang], url_status)

        with gr.Tab("Chat"):
            q = gr.Textbox(label="Your Question")
            a = gr.HTML()
            cits = gr.HTML()
            gr.Button("Ask").click(answer_question, [q, lang], [a, cits])

            # Example Questions
            gr.Markdown("### Example Questions")
            with gr.Row():
                ex1 = gr.Button("Give a summary of this document")
                ex2 = gr.Button("What are the key findings?")
                ex3 = gr.Button("Explain the methodology used")
                ex4 = gr.Button("List the main conclusions")
                ex5 = gr.Button("Explain in simple terms")
                ex6 = gr.Button("What is the significance of this study?")

            ex1.click(lambda: "Give a summary of this document", None, q)
            ex2.click(lambda: "What are the key findings?", None, q)
            ex3.click(lambda: "Explain the methodology used", None, q)
            ex4.click(lambda: "List the main conclusions", None, q)
            ex5.click(lambda: "Explain this in simple terms", None, q)
            ex6.click(lambda: "What is the significance of this study?", None, q)

    return demo


demo = create_app()

if __name__ == "__main__":
    demo.launch()