# (Removed: "Spaces: Sleeping" banner — Hugging Face Spaces page chrome, not source code.)
| """ | |
| RAG Document Q&A Assistant | |
| Upload documents, ask questions, get answers with source citations. | |
| """ | |
| import os | |
| import tempfile | |
| from typing import Optional | |
| import chromadb | |
from pypdf import PdfReader  # pypdf: pure-Python PDF text extraction
| import gradio as gr | |
| from chromadb.utils import embedding_functions | |
| from openai import OpenAI | |
# OpenAI client used for answer generation; reads OPENAI_API_KEY from the env.
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Local sentence-transformer model Chroma uses to embed both chunks and queries.
embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
# Global state for the current session (rebuilt on every document upload).
chroma_client = None    # chromadb.Client instance, created in process_document
collection = None       # active Chroma collection holding the embedded chunks
current_chunks = []     # chunk dicts ({"id", "text", "start", "end"}) of the last document
def extract_text_from_pdf(file_path: str) -> str:
    """Return the concatenated text of every page in the PDF at *file_path*."""
    reader = PdfReader(file_path)
    # extract_text() can return None for image-only pages; treat those as empty.
    return "".join(page.extract_text() or "" for page in reader.pages)
def extract_text_from_txt(file_path: str) -> str:
    """Read a plain-text file as UTF-8, silently dropping undecodable bytes."""
    with open(file_path, mode="r", encoding="utf-8", errors="ignore") as handle:
        contents = handle.read()
    return contents
def chunk_fixed_size(text: str, chunk_size: int = 500, overlap: int = 100) -> list[dict]:
    """Split text into fixed-size chunks with overlap.

    Args:
        text: Raw document text.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared between consecutive chunks; must be
            smaller than chunk_size.

    Returns:
        List of dicts with keys "id", "text", "start", "end". Whitespace-only
        windows are skipped; "end" never exceeds len(text).

    Raises:
        ValueError: If chunk_size <= 0, or overlap >= chunk_size (the start
            cursor would stop advancing and the loop would never terminate).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    n = len(text)
    chunks = []
    start = 0
    chunk_id = 0
    while start < n:
        # Clamp so the recorded span never overshoots the end of the text.
        end = min(start + chunk_size, n)
        chunk_text = text[start:end].strip()
        if chunk_text:
            chunks.append({
                "id": f"chunk_{chunk_id}",
                "text": chunk_text,
                "start": start,
                "end": end,
            })
            chunk_id += 1
        if end >= n:
            break  # last window reached; stepping back by overlap would re-scan the tail
        start = end - overlap
    return chunks
def chunk_by_paragraph(text: str) -> list[dict]:
    """Split text on blank lines, keeping only paragraphs longer than 50 chars.

    Chunk ids index into the list of non-empty paragraphs, so ids may be
    non-contiguous when short paragraphs are filtered out.
    """
    nonempty = [piece.strip() for piece in text.split("\n\n") if piece.strip()]
    return [
        {"id": f"chunk_{idx}", "text": paragraph, "start": 0, "end": 0}
        for idx, paragraph in enumerate(nonempty)
        if len(paragraph) > 50
    ]
def process_document(file, chunking_strategy: str) -> str:
    """Extract, chunk, embed, and index an uploaded document.

    Args:
        file: Gradio file object exposing a ``.name`` filesystem path, or None.
        chunking_strategy: UI dropdown label; "Fixed-size (500 chars)" selects
            fixed-window chunking, anything else falls back to paragraph-based.

    Returns:
        A markdown status string: either an error message or a success
        summary with character/chunk statistics.

    Side effects:
        Rebuilds the module-level ``chroma_client``, ``collection``, and
        ``current_chunks`` so each upload starts from a clean index.
    """
    global chroma_client, collection, current_chunks
    if file is None:
        return "β Please upload a document first."
    file_path = file.name
    file_ext = os.path.splitext(file_path)[1].lower()
    try:
        if file_ext == ".pdf":
            text = extract_text_from_pdf(file_path)
        elif file_ext in [".txt", ".md"]:
            text = extract_text_from_txt(file_path)
        else:
            return f"β Unsupported file type: {file_ext}. Please upload PDF or TXT."
    except Exception as e:
        return f"β Error reading file: {str(e)}"
    if not text.strip():
        return "β No text could be extracted from the document."
    if chunking_strategy == "Fixed-size (500 chars)":
        current_chunks = chunk_fixed_size(text, chunk_size=500, overlap=100)
    else:
        current_chunks = chunk_by_paragraph(text)
    if not current_chunks:
        return "β No chunks could be created from the document."
    # Fresh in-memory Chroma client; drop any stale collection from a previous
    # upload so query results never mix documents.
    chroma_client = chromadb.Client()
    try:
        chroma_client.delete_collection(name="documents")
    except Exception:
        # Collection may simply not exist yet. Narrowed from a bare `except:`
        # so KeyboardInterrupt/SystemExit are no longer swallowed.
        pass
    collection = chroma_client.create_collection(
        name="documents",
        embedding_function=embedding_func
    )
    collection.add(
        documents=[c["text"] for c in current_chunks],
        ids=[c["id"] for c in current_chunks]
    )
    return f"β Document processed successfully!\n\nπ **Stats:**\n- Characters: {len(text):,}\n- Chunks created: {len(current_chunks)}\n- Chunking strategy: {chunking_strategy}"
def retrieve_context(query: str, top_k: int = 3) -> list[dict]:
    """Return up to *top_k* chunks most similar to *query*.

    Returns an empty list when no document has been processed yet. Each
    result dict has "text", "similarity" (in (0, 1]), and 1-based "rank".
    """
    if collection is None:
        return []
    results = collection.query(query_texts=[query], n_results=top_k)
    docs = results["documents"][0]
    distances = results["distances"][0]
    return [
        {
            "text": doc,
            # Map distance into (0, 1]: smaller distance -> higher similarity.
            "similarity": 1 / (1 + dist),
            "rank": rank,
        }
        for rank, (doc, dist) in enumerate(zip(docs, distances), start=1)
    ]
def generate_answer(query: str, context_docs: list[dict]) -> str:
    """Ask GPT-4o-mini to answer *query* grounded in the retrieved chunks.

    Returns a fallback message when *context_docs* is empty, and an error
    string (rather than raising) if the OpenAI call fails.
    """
    if not context_docs:
        return "I don't have any context to answer this question. Please upload a document first."
    source_sections = [
        f"[Source {doc['rank']}] (relevance: {doc['similarity']:.0%})\n{doc['text']}"
        for doc in context_docs
    ]
    context = "\n\n".join(source_sections)
    prompt = f"""Answer the question based on the provided context.
If the context doesn't contain enough information to answer fully, say so.
Always reference which source(s) you used.
CONTEXT:
{context}
QUESTION: {query}
ANSWER:"""
    try:
        system_msg = "You are a helpful assistant that answers questions based on provided document context. Be concise and cite your sources."
        response = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=500
        )
    except Exception as e:
        return f"β Error generating answer: {str(e)}"
    return response.choices[0].message.content
def ask_question(query: str) -> tuple[str, str]:
    """Handle one user question; return (answer_markdown, sources_markdown)."""
    if not query.strip():
        return "Please enter a question.", ""
    if collection is None:
        return "Please upload and process a document first.", ""
    retrieved = retrieve_context(query, top_k=3)
    answer = generate_answer(query, retrieved)
    # Render each retrieved chunk as a ranked, truncated preview block.
    parts = ["\n\n---\n\n**π Retrieved Sources:**\n\n"]
    for doc in retrieved:
        snippet = doc["text"][:300]
        suffix = "..." if len(doc["text"]) > 300 else ""
        parts.append(f"**[Source {doc['rank']}]** (relevance: {doc['similarity']:.0%})\n")
        parts.append(f"```\n{snippet}{suffix}\n```\n\n")
    return answer, "".join(parts)
# ---------------------------------------------------------------------------
# Gradio UI: two-column layout — document upload/processing on the left,
# question answering on the right. Event wiring is at the bottom.
# ---------------------------------------------------------------------------
with gr.Blocks(title="RAG Document Q&A", theme=gr.themes.Soft()) as demo:
    # Header / usage explanation shown above both columns.
    gr.Markdown("""
# π RAG Document Q&A Assistant
Upload a document (PDF or TXT), choose a chunking strategy, and ask questions!
**How it works:**
1. Your document is split into chunks using the selected strategy
2. Chunks are embedded using Sentence Transformers (all-MiniLM-L6-v2)
3. When you ask a question, relevant chunks are retrieved using semantic search
4. GPT-4o-mini generates an answer based on the retrieved context
---
""")
    with gr.Row():
        # Left column: upload + chunking controls feeding process_document.
        with gr.Column(scale=1):
            gr.Markdown("### π€ Step 1: Upload Document")
            file_input = gr.File(
                label="Upload PDF or TXT",
                file_types=[".pdf", ".txt", ".md"]
            )
            # Choices must match the string compared in process_document.
            chunking_dropdown = gr.Dropdown(
                choices=["Fixed-size (500 chars)", "Paragraph-based"],
                value="Paragraph-based",
                label="Chunking Strategy"
            )
            process_btn = gr.Button("Process Document", variant="primary")
            process_output = gr.Markdown(label="Processing Status")
        # Right column: question box plus answer/sources panes.
        with gr.Column(scale=2):
            gr.Markdown("### π¬ Step 2: Ask Questions")
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="What is this document about?",
                lines=2
            )
            ask_btn = gr.Button("Ask", variant="primary")
            answer_output = gr.Markdown(label="Answer")
            sources_output = gr.Markdown(label="Sources")
    # Footer with literature references.
    gr.Markdown("""
---
**π References:**
- [RAG Original Paper (Lewis et al., 2020)](https://arxiv.org/abs/2005.11401)
- [RAG Survey (Gao et al., 2023)](https://arxiv.org/pdf/2312.10997)
- [Chunking Strategies for RAG (Merola & Singh, 2025)](https://arxiv.org/abs/2504.19754)
Built as part of an AI/ML Engineering portfolio project.
""")
    # Event wiring: button clicks plus Enter-to-submit on the question box.
    process_btn.click(
        fn=process_document,
        inputs=[file_input, chunking_dropdown],
        outputs=[process_output]
    )
    ask_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, sources_output]
    )
    question_input.submit(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, sources_output]
    )
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()