import os
import tempfile
import gradio as gr
import json
import pandas as pd
import requests
from bs4 import BeautifulSoup
from docx import Document
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import pipeline
import logging
import io

# PDF libraries
try:
    from pypdf import PdfReader
    HAS_PYPDF = True
except ImportError:
    HAS_PYPDF = False

try:
    import pdfplumber
    HAS_PDFPLUMBER = True
except ImportError:
    HAS_PDFPLUMBER = False

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==============================
# CONFIG
# ==============================
HF_GENERATION_MODEL = os.environ.get("HF_GENERATION_MODEL", "google/flan-t5-large")
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2"
INDEX_PATH = "faiss_index.index"
METADATA_PATH = "metadata.json"

# Initialize models
embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
gen_pipeline = pipeline("text2text-generation", model=HF_GENERATION_MODEL, device=-1)

# ==============================
# SIMPLE TEXT SPLITTER
# ==============================
def simple_text_splitter(text, chunk_size=1000, chunk_overlap=100):
    """Split text into overlapping character chunks, dropping tiny fragments."""
    if len(text) <= chunk_size:
        return [text.strip()]
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunk = text[start:end].strip()
        if len(chunk) > 50:
            chunks.append(chunk)
        if end == len(text):
            break  # reached the end; stepping back by the overlap here would loop forever
        start = end - chunk_overlap
    return [c for c in chunks if len(c) > 20]
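
# Quick sanity check of the splitter (hypothetical values, not part of the app flow):
#   simple_text_splitter("x" * 2500, chunk_size=1000, chunk_overlap=100)
#   -> three chunks covering [0:1000], [900:1900], [1800:2500], each stripped and > 50 chars.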

# ==============================
# CORRECTED FILE HANDLING FOR GRADIO
# ==============================
def get_file_data(file_obj):
    """Handle different Gradio file formats correctly.

    Returns (data, source_type) on success, where source_type says how the data
    should be interpreted, or (None, debug_messages) when nothing usable is found.
    """
    debug = []

    # Method 1: File has .name attribute (temp file path)
    if hasattr(file_obj, 'name') and file_obj.name:
        debug.append(f"Using file path: {file_obj.name}")
        return file_obj.name, "path"

    # Method 2: File has .data attribute (base64 or bytes)
    if hasattr(file_obj, 'data') and file_obj.data:
        debug.append(f"Using file.data: {len(file_obj.data)} bytes")
        return file_obj.data, "bytes"

    # Method 3: Try to read as bytes
    try:
        if hasattr(file_obj, 'read'):
            file_obj.seek(0)  # Reset file pointer
            data = file_obj.read()
            if data:
                debug.append(f"Read {len(data)} bytes from file object")
                return data, "read"
    except Exception as e:
        debug.append(f"Read failed: {e}")

    # Method 4: Check if it's a dict with content
    if isinstance(file_obj, dict):
        if 'data' in file_obj and file_obj['data']:
            debug.append(f"Using dict data: {len(file_obj['data'])} bytes")
            return file_obj['data'], "dict"
        if 'name' in file_obj and file_obj['name']:
            debug.append(f"Using dict path: {file_obj['name']}")
            return file_obj['name'], "dict_path"

    # Method 5: String path
    if isinstance(file_obj, str) and os.path.exists(file_obj):
        debug.append(f"Using string path: {file_obj}")
        return file_obj, "string_path"

    debug.append("❌ No valid file data found")
    return None, debug

# ==============================
# PDF EXTRACTION
# ==============================
def extract_pdf_text(file_data, source_type, debug_info):
    """Extract text from PDF using multiple methods"""
    temp_path = None
    try:
        # If we have a file path, use it directly
        if source_type in ["path", "string_path", "dict_path"]:
            file_path = file_data
            if not os.path.exists(file_path):
                debug_info.append(f"❌ File path doesn't exist: {file_path}")
                return "File not found"

            # Try pdftotext first (if available)
            try:
                import subprocess
                result = subprocess.run(['pdftotext', file_path, '-'],
                                        capture_output=True, text=True, timeout=15)
                if result.returncode == 0 and len(result.stdout.strip()) > 30:
                    debug_info.append(f"✅ pdftotext: {len(result.stdout)} chars")
                    return result.stdout
            except Exception:
                pass

        # Create temp file from bytes
        if source_type in ["bytes", "read", "dict"]:
            temp_path = tempfile.NamedTemporaryFile(delete=False, suffix='.pdf').name
            with open(temp_path, 'wb') as f:
                if isinstance(file_data, str):
                    f.write(file_data.encode('latin1'))  # PDFs are binary
                else:
                    f.write(file_data)
            file_path = temp_path
            debug_info.append(f"Created temp file: {temp_path}")

        # Try pdfplumber
        if HAS_PDFPLUMBER:
            try:
                with pdfplumber.open(file_path) as pdf:
                    text = ""
                    for i, page in enumerate(pdf.pages[:5]):
                        page_text = page.extract_text()
                        if page_text:
                            text += page_text + "\n"
                    if len(text.strip()) > 50:
                        debug_info.append(f"✅ pdfplumber: {len(text)} chars")
                        return text
            except Exception as e:
                debug_info.append(f"pdfplumber failed: {e}")

        # Try pypdf
        if HAS_PYPDF:
            try:
                reader = PdfReader(file_path)
                text = ""
                for i, page in enumerate(reader.pages[:3]):
                    try:
                        page_text = page.extract_text()
                        if page_text and page_text.strip():
                            text += page_text + "\n"
                    except Exception:
                        continue
                if len(text.strip()) > 30:
                    debug_info.append(f"✅ pypdf: {len(text)} chars")
                    return text
            except Exception as e:
                debug_info.append(f"pypdf failed: {e}")

        return "No text extracted - likely scanned PDF images"
    finally:
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError:
                pass

# ==============================
# OTHER EXTRACTIONS
# ==============================
def extract_docx_text(file_data, source_type, debug_info):
    """Extract paragraph text from a DOCX given a file path or raw bytes."""
    try:
        if source_type in ["path", "string_path", "dict_path"]:
            doc = Document(file_data)
        else:
            # Write bytes to a temp file so python-docx can open it
            with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp:
                tmp.write(file_data if isinstance(file_data, bytes) else str(file_data).encode('utf-8'))
                tmp_path = tmp.name
            doc = Document(tmp_path)
            os.unlink(tmp_path)
        text = "\n\n".join([p.text.strip() for p in doc.paragraphs if p.text.strip()])
        if len(text) > 20:
            return text
        return "No text in DOCX"
    except Exception as e:
        return f"DOCX error: {e}"


def extract_text_file(file_data, source_type, debug_info):
    """Read plain-text or markdown content from a file path or raw bytes."""
    try:
        if source_type in ["path", "string_path", "dict_path"]:
            with open(file_data, 'r', encoding='utf-8', errors='ignore') as f:
                return f.read()
        # Decode bytes
        if isinstance(file_data, bytes):
            return file_data.decode('utf-8', errors='ignore')
        return str(file_data)
    except Exception:
        return "Text extraction failed"

# ==============================
# MAIN INGESTION
# ==============================
def ingest_sources(files, urls=""):
    """Ingest uploaded files, chunk them, and rebuild the FAISS index.

    The urls argument is accepted for the UI wiring but is not used here.
    """
    docs = []
    metadata = []
    debug_info = []

    # Clear existing index
    for path in [INDEX_PATH, METADATA_PATH]:
        if os.path.exists(path):
            os.remove(path)

    # Process files
    for i, file_obj in enumerate(files or []):
        debug_info.append(f"\n📄 Processing file {i+1}")

        # Get file data correctly
        file_data, source_info = get_file_data(file_obj)
        if isinstance(source_info, list):
            debug_info.extend(source_info)
            continue
        if not file_data:
            debug_info.append("❌ No file data")
            continue

        # Get filename and extension (string paths and dicts carry the name themselves)
        if isinstance(file_obj, str):
            filename = file_obj
        elif isinstance(file_obj, dict):
            filename = file_obj.get('name', f'file_{i+1}')
        else:
            filename = getattr(file_obj, 'name', f'file_{i+1}')
        if isinstance(filename, bytes):
            filename = filename.decode('utf-8', errors='ignore')
        ext = os.path.splitext(filename.lower())[1] if filename else ''
        debug_info.append(f"File: {filename}, Type: {source_info}")

        # Extract text
        text = ""
        if ext == '.pdf':
            text = extract_pdf_text(file_data, source_info, debug_info)
        elif ext in ['.docx', '.doc']:
            text = extract_docx_text(file_data, source_info, debug_info)
        elif ext in ['.txt', '.md']:
            text = extract_text_file(file_data, source_info, debug_info)
        else:
            debug_info.append(f"Unknown extension: {ext}")
            continue

        # Preview
        preview = text[:100].replace('\n', ' ').strip()
        if len(preview) > 80:
            preview = preview[:80] + "..."
        debug_info.append(f"Extracted {len(text)} chars")
        debug_info.append(f"Preview: '{preview}'")

        # Create chunks
        if len(text.strip()) > 30:
            chunks = simple_text_splitter(text)
            for j, chunk in enumerate(chunks):
                docs.append(chunk)
                metadata.append({
                    "source": filename,
                    "chunk": j,
                    "text": chunk
                })
            debug_info.append(f"✅ {len(chunks)} chunks created")
        else:
            debug_info.append("⚠️ Insufficient content")

    debug_info.append(f"\n📊 Total: {len(docs)} chunks")

    if docs:
        embeddings = embed_model.encode(docs)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(np.asarray(embeddings, dtype="float32"))
        faiss.write_index(index, INDEX_PATH)
        with open(METADATA_PATH, 'w') as f:
            json.dump(metadata, f)
        return f"✅ SUCCESS: {len(docs)} chunks!"

    return "❌ No content.\n\n" + "\n".join(debug_info[-15:])

# ==============================
# RETRIEVAL & GENERATION
# ==============================
def retrieve_topk(query, k=3):
    """Return the k nearest chunks (by L2 distance) for a query."""
    if not os.path.exists(INDEX_PATH):
        return []
    q_emb = embed_model.encode([query])
    index = faiss.read_index(INDEX_PATH)
    D, I = index.search(np.asarray(q_emb, dtype="float32"), k)
    with open(METADATA_PATH, 'r') as f:
        metadata = json.load(f)
    # FAISS pads missing results with -1, so guard the lower bound too
    return [metadata[i] for i in I[0] if 0 <= i < len(metadata)]


def ask_prompt(query):
    """Answer a question using the retrieved chunks as context."""
    hits = retrieve_topk(query)
    if not hits:
        return "No documents found."
    context = "\n\n".join([h['text'][:600] for h in hits])
    prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
    result = gen_pipeline(prompt, max_length=300)[0]['generated_text']
    sources = [f"{h['source']} (chunk {h['chunk']})" for h in hits]
    return f"{result}\n\nSources:\n" + "\n".join(sources)

# ==============================
# UI
# ==============================
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Document QA")
    with gr.Row():
        with gr.Column():
            file_input = gr.File(file_count="multiple")
            ingest_btn = gr.Button("Ingest", variant="primary")
            status = gr.Textbox(lines=15)
        with gr.Column():
            query_input = gr.Textbox(label="Question")
            ask_btn = gr.Button("Ask")
            answer = gr.Textbox(lines=10)
    ingest_btn.click(ingest_sources, [file_input, gr.State("")], status)
    ask_btn.click(ask_prompt, query_input, answer)

if __name__ == "__main__":
    demo.launch()
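
# Rough local smoke test (hypothetical; "./sample.pdf" is a placeholder path, and this
# bypasses the Gradio UI entirely):
#
#   if ingest_sources(["./sample.pdf"]).startswith("✅"):
#       print(ask_prompt("What is this document about?"))
#
# ingest_sources accepts plain string paths via get_file_data's "string_path" branch,
# so this is one way to exercise the pipeline without uploading through the browser.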