| """ | |
| Standalone RAG Chatbot with Gemma 3n | |
| A simple PDF chatbot using Retrieval-Augmented Generation | |
| """ | |
| import gradio as gr | |
| import torch | |
| import os | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| import pymupdf # PyMuPDF for PDF processing | |
# RAG dependencies
try:
    from sentence_transformers import SentenceTransformer
    from sklearn.metrics.pairwise import cosine_similarity
    from transformers import Gemma3nForConditionalGeneration, AutoProcessor
    RAG_AVAILABLE = True
except ImportError as e:
    print(f"Missing dependencies: {e}")
    RAG_AVAILABLE = False
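# A minimal dependency sketch (assumed package names; pin versions as needed):
#   pip install gradio torch numpy pillow pymupdf sentence-transformers scikit-learn transformers
# Gemma 3n support requires a recent transformers release, and access to
# google/gemma-3n-e4b-it must be granted on the Hugging Face Hub.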
# Global state shared across Gradio callbacks
embedding_model = None
chatbot_model = None
chatbot_processor = None
document_chunks = []
document_embeddings = None
processed_text = ""
def initialize_models():
    """Initialize the embedding model and the chatbot model."""
    global embedding_model, chatbot_model, chatbot_processor
    if not RAG_AVAILABLE:
        return False, "Required dependencies not installed"
    try:
        # Initialize embedding model (kept on CPU to save GPU memory)
        if embedding_model is None:
            print("Loading embedding model...")
            embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
            print("✅ Embedding model loaded successfully")
        # Initialize chatbot model
        if chatbot_model is None or chatbot_processor is None:
            hf_token = os.getenv('HF_TOKEN')
            if not hf_token:
                return False, "HF_TOKEN not found in environment"
            print("Loading Gemma 3n model...")
            chatbot_model = Gemma3nForConditionalGeneration.from_pretrained(
                "google/gemma-3n-e4b-it",
                device_map="auto",
                torch_dtype=torch.bfloat16,
                token=hf_token
            ).eval()
            chatbot_processor = AutoProcessor.from_pretrained(
                "google/gemma-3n-e4b-it",
                token=hf_token
            )
            print("✅ Gemma 3n model loaded successfully")
        return True, "All models loaded successfully"
    except Exception as e:
        print(f"Error loading models: {e}")
        import traceback
        traceback.print_exc()
        return False, f"Error: {str(e)}"
def extract_text_from_pdf(pdf_file):
    """Extract text from an uploaded PDF file."""
    try:
        if isinstance(pdf_file, str):
            # File path
            pdf_document = pymupdf.open(pdf_file)
        else:
            # File-like object
            pdf_bytes = pdf_file.read()
            pdf_document = pymupdf.open(stream=pdf_bytes, filetype="pdf")
        text_content = ""
        for page_num in range(len(pdf_document)):
            page = pdf_document[page_num]
            text_content += f"\n--- Page {page_num + 1} ---\n"
            text_content += page.get_text()
        pdf_document.close()
        return text_content
    except Exception as e:
        raise Exception(f"Error extracting text from PDF: {str(e)}")
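# Note: for PDFs with multi-column layouts, PyMuPDF can return text in an order
# closer to the visual reading order via page.get_text(sort=True); whether that
# helps depends on the document, so it is left as an option rather than used above.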
def chunk_text(text, chunk_size=500, overlap=50):
    """Split text into overlapping chunks of roughly chunk_size words."""
    words = text.split()
    chunks = []
    # Step by chunk_size - overlap so consecutive chunks share `overlap` words
    for i in range(0, len(words), chunk_size - overlap):
        chunk = ' '.join(words[i:i + chunk_size])
        if chunk.strip():
            chunks.append(chunk)
    return chunks
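# Illustration of the chunking arithmetic (not called by the app): with the
# defaults, chunks start at word offsets 0, 450, 900, ..., so each consecutive
# pair overlaps by 50 words. A sentence split across a chunk boundary therefore
# still appears intact in at least one chunk.
#   >>> len(chunk_text(" ".join(str(i) for i in range(1000))))  # 1000 words
#   3   # chunks starting at words 0, 450 and 900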
def create_embeddings(chunks):
    """Create embeddings for text chunks."""
    if embedding_model is None:
        return None
    try:
        print(f"Creating embeddings for {len(chunks)} chunks...")
        embeddings = embedding_model.encode(chunks, show_progress_bar=True)
        return np.array(embeddings)
    except Exception as e:
        print(f"Error creating embeddings: {e}")
        return None
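# all-MiniLM-L6-v2 maps each chunk to a 384-dimensional vector, so the result
# is an array of shape (num_chunks, 384). encode() batches internally; for very
# large documents, batch_size=... can be passed to trade speed for memory.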
def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
    """Retrieve the chunks most relevant to a question."""
    if embedding_model is None or embeddings is None:
        return chunks[:top_k]
    try:
        question_embedding = embedding_model.encode([question])
        similarities = cosine_similarity(question_embedding, embeddings)[0]
        # Get the top-k most similar chunks, best match first
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        relevant_chunks = [chunks[i] for i in top_indices]
        return relevant_chunks
    except Exception as e:
        print(f"Error retrieving chunks: {e}")
        return chunks[:top_k]
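# Sketch of an equivalent retrieval step without scikit-learn (an alternative,
# not what the app uses): when both sides are L2-normalized, cosine similarity
# reduces to a dot product.
#   q = embedding_model.encode([question], normalize_embeddings=True)[0]
#   e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
#   similarities = e @ q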
def process_pdf(pdf_file, progress=gr.Progress()):
    """Process an uploaded PDF and prepare it for Q&A."""
    global document_chunks, document_embeddings, processed_text
    if pdf_file is None:
        return "❌ Please upload a PDF file first"
    try:
        # Extract text from PDF
        progress(0.2, desc="Extracting text from PDF...")
        text = extract_text_from_pdf(pdf_file)
        if not text.strip():
            return "❌ No text found in PDF"
        processed_text = text
        # Create chunks
        progress(0.4, desc="Creating text chunks...")
        document_chunks = chunk_text(text)
        # Create embeddings
        progress(0.6, desc="Creating embeddings...")
        document_embeddings = create_embeddings(document_chunks)
        if document_embeddings is None:
            return "❌ Failed to create embeddings"
        progress(1.0, desc="PDF processed successfully!")
        return f"✅ PDF processed successfully! Created {len(document_chunks)} chunks. You can now ask questions about the document."
    except Exception as e:
        return f"❌ Error processing PDF: {str(e)}"
def chat_with_pdf(message, history):
    """Generate a response using RAG."""
    global chatbot_model, chatbot_processor
    if not message.strip():
        return history
    if not processed_text:
        return history + [[message, "❌ Please upload and process a PDF first"]]
    # Check if models are loaded
    if chatbot_model is None or chatbot_processor is None:
        print("Models not loaded, attempting to reload...")
        success, error_msg = initialize_models()
        if not success:
            return history + [[message, f"❌ Failed to load models: {error_msg}"]]
    try:
        # Retrieve relevant chunks
        if document_chunks and document_embeddings is not None:
            relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings)
            context = "\n\n".join(relevant_chunks)
        else:
            # Fall back to the (truncated) raw text
            context = processed_text[:2000] + "..." if len(processed_text) > 2000 else processed_text
        # Build the chat messages for Gemma
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": f"Context:\n{context}\n\nQuestion: {message}"}]
            }
        ]
        # Process with Gemma: apply the chat template and generate
        inputs = chatbot_processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(chatbot_model.device)
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generation = chatbot_model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=False,  # greedy decoding; a temperature would have no effect with sampling off
                pad_token_id=chatbot_processor.tokenizer.pad_token_id,
                use_cache=True
            )
        # Keep only the newly generated tokens, dropping the prompt
        generation = generation[0][input_len:]
        response = chatbot_processor.decode(generation, skip_special_tokens=True)
        return history + [[message, response]]
    except Exception as e:
        error_msg = f"❌ Error generating response: {str(e)}"
        return history + [[message, error_msg]]
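# The generate() call above blocks until the full answer is ready. For
# token-by-token streaming in the UI, one option (a sketch, not wired up here)
# is transformers' TextIteratorStreamer: run generate() in a background thread
# with streamer=TextIteratorStreamer(chatbot_processor.tokenizer) and yield
# partial history updates from the Gradio callback as text arrives.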
def clear_chat():
    """Clear the chat history and processed document data."""
    global document_chunks, document_embeddings, processed_text
    document_chunks = []
    document_embeddings = None
    processed_text = ""
    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return [], "Ready to process a new PDF"

def get_model_status():
    """Report the current model loading status."""
    global chatbot_model, chatbot_processor, embedding_model
    statuses = []
    if embedding_model is not None:
        statuses.append("✅ Embedding model loaded")
    else:
        statuses.append("❌ Embedding model not loaded")
    if chatbot_model is not None and chatbot_processor is not None:
        statuses.append("✅ Chatbot model loaded")
    else:
        statuses.append("❌ Chatbot model not loaded")
    return " | ".join(statuses)

# Initialize models on startup
model_status = "⏳ Initializing models..."
if RAG_AVAILABLE:
    success, message = initialize_models()
    model_status = "✅ Models ready" if success else f"❌ {message}"
else:
    model_status = "❌ Dependencies not installed"
# Create Gradio interface
with gr.Blocks(
    title="RAG Chatbot with Gemma 3n",
    theme=gr.themes.Soft(),
    css="""
    .main-container { max-width: 1200px; margin: 0 auto; }
    .status-box { padding: 15px; margin: 10px 0; border-radius: 8px; }
    .chat-container { height: 500px; }
    """
) as demo:
    gr.Markdown("# 🤖 RAG Chatbot with Gemma 3n")
    gr.Markdown("### Upload a PDF and ask questions about it using Retrieval-Augmented Generation")
    with gr.Row():
        status_display = gr.Markdown(f"**Status:** {model_status}")
        # Refresh button for the status display
        refresh_btn = gr.Button("🔄 Refresh Status", size="sm")

    def update_status():
        return get_model_status()

    refresh_btn.click(
        fn=update_status,
        outputs=[status_display]
    )
    with gr.Row():
        # Left column - PDF upload
        with gr.Column(scale=1):
            gr.Markdown("## 📄 Upload PDF")
            pdf_input = gr.File(
                file_types=[".pdf"],
                label="Upload PDF Document"
            )
            process_btn = gr.Button(
                "🔄 Process PDF",
                variant="primary",
                size="lg"
            )
            status_output = gr.Markdown(
                "Upload a PDF to get started",
                elem_classes="status-box"
            )
            clear_btn = gr.Button(
                "🗑️ Clear All",
                variant="secondary"
            )
        # Right column - Chat
        with gr.Column(scale=2):
            gr.Markdown("## 💬 Ask Questions")
            chatbot = gr.Chatbot(
                value=[],
                height=400,
                elem_classes="chat-container"
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    placeholder="Ask a question about your PDF...",
                    scale=4,
                    container=False
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
    # Event handlers
    process_btn.click(
        fn=process_pdf,
        inputs=[pdf_input],
        outputs=[status_output],
        show_progress=True
    )
    send_btn.click(
        fn=chat_with_pdf,
        inputs=[msg_input, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",
        outputs=[msg_input]
    )
    msg_input.submit(
        fn=chat_with_pdf,
        inputs=[msg_input, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",
        outputs=[msg_input]
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, status_output]
    )
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
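# To run locally (a sketch; package names are assumptions, see the note near
# the imports): install the dependencies, export a Hub token with access to
# the gated Gemma weights, then start the app.
#   export HF_TOKEN=...   # token value elided
#   python app.py         # assumes this file is saved as app.py
# The UI is then served at http://localhost:7860.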