"""Streamlit retrieval-augmented-generation (RAG) app.

Users upload PDF/DOCX/TXT documents, which are split into overlapping text
chunks, embedded with a SentenceTransformer model and stored in a FAISS
inner-product index. Questions are answered by retrieving the most similar
chunks and passing them as context to a Groq-hosted chat model.
"""
import streamlit as st
import os
import tempfile
import pickle
from typing import List, Dict, Any
import numpy as np
from pathlib import Path

# Document processing
import PyPDF2
import docx
from sentence_transformers import SentenceTransformer
import faiss

# Groq API
from groq import Groq

# Text processing
import nltk
from nltk.tokenize import sent_tokenize
import re


def _ensure_nltk_data() -> None:
    """Download NLTK's sentence-tokenizer data if it is not installed yet.

    Older NLTK releases read ``punkt`` while recent ones read ``punkt_tab``
    inside ``sent_tokenize``; fetching both keeps the app working on either.
    """
    for resource, package in (
        ("tokenizers/punkt", "punkt"),
        ("tokenizers/punkt_tab", "punkt_tab"),
    ):
        try:
            nltk.data.find(resource)
        except LookupError:
            nltk.download(package, quiet=True)


# Download required NLTK data
_ensure_nltk_data()


class DocumentProcessor:
    """Handles document upload and text extraction"""

    @staticmethod
    def extract_text_from_pdf(file_path: str) -> str:
        """Extract text from a PDF file, one page per line block.

        Errors are reported in the UI and whatever text was gathered before
        the failure (possibly "") is returned.
        """
        pages: List[str] = []
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() may return None for image-only pages.
                    pages.append((page.extract_text() or "") + "\n")
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
        return "".join(pages)

    @staticmethod
    def extract_text_from_docx(file_path: str) -> str:
        """Extract paragraph text from a DOCX file (one paragraph per line).

        Errors are reported in the UI and the text gathered so far is returned.
        """
        paragraphs: List[str] = []
        try:
            doc = docx.Document(file_path)
            for paragraph in doc.paragraphs:
                paragraphs.append(paragraph.text + "\n")
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
        return "".join(paragraphs)

    @staticmethod
    def extract_text_from_txt(file_path: str) -> str:
        """Read a text file as UTF-8.

        A leading BOM is stripped, and undecodable bytes are replaced rather
        than failing the whole file. Errors are reported in the UI and "" is
        returned.
        """
        text = ""
        try:
            with open(file_path, 'r', encoding='utf-8-sig', errors='replace') as file:
                text = file.read()
        except Exception as e:
            st.error(f"Error reading TXT: {str(e)}")
        return text

    def process_uploaded_file(self, uploaded_file) -> str:
        """Extract text from a Streamlit upload.

        Args:
            uploaded_file: Streamlit ``UploadedFile`` (or None).

        Returns:
            The extracted text, or "" when there is no file, the type is
            unsupported, or extraction failed.
        """
        if uploaded_file is None:
            return ""

        extension = Path(uploaded_file.name).suffix.lower().lstrip(".")
        extractors = {
            'pdf': self.extract_text_from_pdf,
            'docx': self.extract_text_from_docx,
            'txt': self.extract_text_from_txt,
        }
        extractor = extractors.get(extension)
        if extractor is None:
            st.error(f"Unsupported file type: {extension or uploaded_file.name}")
            return ""

        # The extractors need a real path, so spool the upload to disk.
        # delete=False lets the file be reopened by name (required on Windows);
        # the finally clause removes it even if the write itself fails.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}")
        try:
            with tmp_file:
                tmp_file.write(uploaded_file.getvalue())
            return extractor(tmp_file.name)
        finally:
            os.unlink(tmp_file.name)


class TextChunker:
    """Handles text chunking and preprocessing"""

    _WHITESPACE_RE = re.compile(r'\s+')
    # Anything that is not a word char, whitespace or basic punctuation.
    _DISALLOWED_RE = re.compile(r'[^\w\s\.\!\?\,\;\:\-\(\)]')

    def __init__(self, chunk_size: int = 1000, overlap: int = 200):
        """
        Args:
            chunk_size: soft upper bound on chunk length in characters. A
                single sentence longer than this still becomes its own chunk.
            overlap: maximum number of characters (whole words only) carried
                from the end of one chunk into the start of the next.
        """
        self.chunk_size = chunk_size
        self.overlap = overlap

    def clean_text(self, text: str) -> str:
        """Collapse whitespace and drop characters other than word chars and basic punctuation."""
        # Remove extra whitespace
        text = self._WHITESPACE_RE.sub(' ', text)
        # Remove special characters but keep punctuation
        text = self._DISALLOWED_RE.sub('', text)
        return text.strip()

    def _overlap_tail(self, chunk: str) -> str:
        """Return the trailing whole words of ``chunk`` totalling at most ``self.overlap`` characters."""
        if self.overlap <= 0:
            return ""
        tail: List[str] = []
        length = 0
        for word in reversed(chunk.split()):
            added = len(word) + (1 if tail else 0)  # +1 for the joining space
            if length + added > self.overlap:
                break
            tail.append(word)
            length += added
        return " ".join(reversed(tail))

    def create_chunks(self, text: str) -> List[str]:
        """Split ``text`` into sentence-aligned chunks with overlapping context.

        Returns:
            Chunks in document order; [] for empty or whitespace-only text.
        """
        sentences = sent_tokenize(self.clean_text(text))

        chunks: List[str] = []
        current_chunk = ""
        for sentence in sentences:
            # If adding this sentence would exceed chunk size, start a new chunk
            # that begins with the tail of the previous one for context.
            if current_chunk and len(current_chunk) + len(sentence) > self.chunk_size:
                chunks.append(current_chunk.strip())
                tail = self._overlap_tail(current_chunk)
                current_chunk = f"{tail} {sentence}" if tail else sentence
            else:
                current_chunk = f"{current_chunk} {sentence}" if current_chunk else sentence

        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks


class VectorDatabase:
    """Handles vector embeddings and FAISS operations"""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.embedding_model = SentenceTransformer(model_name)
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()
        # Inner product on L2-normalised vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)
        self.chunks: List[str] = []
        self.embeddings = None

    @staticmethod
    def _normalize(vectors: np.ndarray) -> np.ndarray:
        """L2-normalise each row as float32; all-zero rows are left as zeros instead of producing NaN."""
        vectors = np.asarray(vectors, dtype=np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        return vectors / np.where(norms == 0, 1.0, norms)

    def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """Create embeddings for text chunks"""
        with st.spinner("Creating embeddings..."):
            embeddings = self.embedding_model.encode(chunks, show_progress_bar=True)
            # Normalize embeddings for cosine similarity
            return self._normalize(embeddings)

    def add_documents(self, chunks: List[str]):
        """Embed ``chunks`` and append them to the index.

        Embedding happens before any state is touched, so a failure leaves
        ``chunks``, ``embeddings`` and the FAISS index consistent.
        """
        if not chunks:
            return

        embeddings = self.create_embeddings(chunks)

        self.chunks.extend(chunks)
        if self.embeddings is None:
            self.embeddings = embeddings
        else:
            self.embeddings = np.vstack([self.embeddings, embeddings])
        # Add to FAISS index
        self.index.add(embeddings)

        st.success(f"Added {len(chunks)} chunks to vector database")

    def clear(self) -> None:
        """Remove all stored chunks and vectors, keeping the loaded embedding model."""
        self.index = faiss.IndexFlatIP(self.dimension)
        self.chunks = []
        self.embeddings = None

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``k`` chunks most similar to ``query``.

        Each result is ``{'chunk', 'score', 'rank'}``; the score is cosine
        similarity and results are ordered best-first.
        """
        if self.index.ntotal == 0 or k <= 0:
            return []

        query_embedding = self._normalize(self.embedding_model.encode([query]))

        # Never ask FAISS for more neighbours than exist: it pads the surplus
        # with index -1, which would otherwise alias self.chunks[-1].
        k = min(k, self.index.ntotal)
        scores, indices = self.index.search(query_embedding, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if 0 <= idx < len(self.chunks):
                results.append({
                    'chunk': self.chunks[idx],
                    'score': float(score),
                    'rank': len(results) + 1
                })
        return results

    def save_database(self, filepath: str):
        """Pickle the chunks, embeddings and serialised FAISS index to ``filepath``."""
        data = {
            'chunks': self.chunks,
            'embeddings': self.embeddings,
            'index': faiss.serialize_index(self.index)
        }
        with open(filepath, 'wb') as f:
            pickle.dump(data, f)

    def load_database(self, filepath: str):
        """Load a database written by ``save_database``.

        Returns:
            True on success, False (with a UI error) otherwise.
        """
        try:
            with open(filepath, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code embedded in
                # a crafted file; only load files this application wrote.
                data = pickle.load(f)
            index = faiss.deserialize_index(data['index'])
            chunks = data['chunks']
            embeddings = data['embeddings']
        except Exception as e:
            st.error(f"Error loading database: {str(e)}")
            return False

        if index.d != self.dimension:
            st.error(
                f"Error loading database: index dimension {index.d} does not "
                f"match embedding model dimension {self.dimension}"
            )
            return False

        self.chunks = chunks
        self.embeddings = embeddings
        self.index = index
        return True


class RAGSystem:
    """Main RAG system that combines retrieval and generation"""

    def __init__(self, groq_api_key: str):
        self.groq_client = Groq(api_key=groq_api_key)
        self.vector_db = VectorDatabase()
        self.doc_processor = DocumentProcessor()
        self.text_chunker = TextChunker()

    def set_api_key(self, groq_api_key: str) -> None:
        """Replace the Groq client with one using a new key; indexed documents are kept."""
        self.groq_client = Groq(api_key=groq_api_key)

    def process_document(self, uploaded_file):
        """Extract, chunk and index an uploaded document.

        Returns:
            True if chunks were added, False if no text/chunks were produced.
        """
        text = self.doc_processor.process_uploaded_file(uploaded_file)
        if not text:
            st.error("No text extracted from document")
            return False

        chunks = self.text_chunker.create_chunks(text)
        if not chunks:
            st.error("No chunks created from text")
            return False

        self.vector_db.add_documents(chunks)
        return True

    def generate_response(self, query: str, context: str, model: str = "llama-3.3-70b-versatile") -> str:
        """Ask the Groq chat model to answer ``query`` from ``context``.

        Returns:
            The model's answer, or an "Error generating response: ..." string
            if the API call fails (errors are not raised).
        """
        prompt = (
            "Based on the following context, please answer the question. "
            "If the answer is not in the context, say \"I don't have enough "
            "information to answer this question based on the provided documents.\"\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {query}\n\n"
            "Answer:"
        )

        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided context. Be accurate and concise."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=model,
                temperature=0.1,
                max_tokens=1000
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def query(self, question: str, model: str = "llama-3.3-70b-versatile", top_k: int = 3) -> Dict[str, Any]:
        """Answer ``question`` from the ``top_k`` most similar indexed chunks.

        Returns:
            ``{'answer': str, 'sources': list of search results}``.
        """
        search_results = self.vector_db.search(question, k=top_k)

        if not search_results:
            return {
                'answer': "No relevant documents found. Please upload some documents first.",
                'sources': []
            }

        context = "\n\n".join(result['chunk'] for result in search_results)
        answer = self.generate_response(question, context, model)

        return {
            'answer': answer,
            'sources': search_results
        }


def main():
    """Render the Streamlit UI: configuration sidebar, upload panel and Q&A panel."""
    st.set_page_config(
        page_title="RAG Application",
        page_icon="🔍",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    st.title("🔍 RAG Application")
    st.markdown("**Upload documents and ask questions using AI-powered search and generation**")

    # Initialize session state
    for key, default in (("rag_system", None), ("documents_processed", 0), ("groq_api_key", None)):
        if key not in st.session_state:
            st.session_state[key] = default
    if 'processed_files' not in st.session_state:
        # (name, size) pairs already indexed, so re-clicking doesn't duplicate chunks.
        st.session_state.processed_files = set()

    # Sidebar for configuration
    with st.sidebar:
        st.header("⚙️ Configuration")

        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            help="Enter your Groq API key"
        )

        if not groq_api_key:
            st.warning("Please enter your Groq API key to continue")
            st.stop()

        # NOTE(review): Groq retires models periodically; confirm these IDs are
        # still served before relying on them.
        model_options = [
            "llama-3.3-70b-versatile",
            "llama-3.2-90b-text-preview",
            "llama-3.1-70b-versatile",
            "mixtral-8x7b-32768",
            "gemma2-9b-it"
        ]

        selected_model = st.selectbox(
            "Select Model",
            model_options,
            index=0
        )

        if st.session_state.rag_system is None:
            try:
                st.session_state.rag_system = RAGSystem(groq_api_key)
                st.session_state.groq_api_key = groq_api_key
                st.success("RAG system initialized!")
            except Exception as e:
                st.error(f"Error initializing RAG system: {str(e)}")
                st.stop()
        elif groq_api_key != st.session_state.groq_api_key:
            # Swap only the API client so already-indexed documents survive a key change.
            st.session_state.rag_system.set_api_key(groq_api_key)
            st.session_state.groq_api_key = groq_api_key

    rag_system = st.session_state.rag_system

    # Main content area
    col1, col2 = st.columns([1, 2])

    with col1:
        st.header("📄 Document Upload")

        uploaded_files = st.file_uploader(
            "Upload documents",
            accept_multiple_files=True,
            type=['pdf', 'docx', 'txt'],
            help="Upload PDF, DOCX, or TXT files"
        )

        for position, uploaded_file in enumerate(uploaded_files or []):
            file_key = (uploaded_file.name, uploaded_file.size)
            if file_key in st.session_state.processed_files:
                st.caption(f"✅ {uploaded_file.name} already processed")
                continue
            # Explicit key: two uploads with the same name would otherwise
            # produce duplicate widget IDs and crash the page.
            if st.button(f"Process {uploaded_file.name}", key=f"process_{position}_{uploaded_file.name}"):
                with st.spinner(f"Processing {uploaded_file.name}..."):
                    success = rag_system.process_document(uploaded_file)
                if success:
                    st.session_state.documents_processed += 1
                    st.session_state.processed_files.add(file_key)
                    st.success(f"Successfully processed {uploaded_file.name}")
                else:
                    st.error(f"Failed to process {uploaded_file.name}")

    with col2:
        st.header("💬 Ask Questions")

        if not rag_system.vector_db.chunks:
            st.info("Please upload and process documents before asking questions.")
        else:
            question = st.text_input(
                "Enter your question:",
                placeholder="What is this document about?"
            )

            if st.button("Ask Question"):
                if not question:
                    st.warning("Please enter a question.")
                else:
                    with st.spinner("Generating answer..."):
                        response = rag_system.query(question, selected_model)

                    st.subheader("Answer:")
                    st.write(response['answer'])

                    if response['sources']:
                        st.subheader("Sources:")
                        for i, source in enumerate(response['sources'], start=1):
                            with st.expander(f"Source {i} (Score: {source['score']:.3f})"):
                                st.write(source['chunk'])

    # Additional features
    st.header("🔧 Additional Features")

    col3, col4 = st.columns(2)

    with col3:
        if st.button("Clear Database"):
            # Reset in place instead of building a new VectorDatabase, which
            # would reload the embedding model.
            rag_system.vector_db.clear()
            st.session_state.documents_processed = 0
            st.session_state.processed_files = set()
            st.success("Database cleared successfully!")

    with col4:
        if st.button("Save Database"):
            if rag_system.vector_db.chunks:
                rag_system.vector_db.save_database("rag_database.pkl")
                st.success("Database saved successfully!")
            else:
                st.warning("No data to save")

    # Rendered last so the numbers reflect anything processed or cleared in this run.
    with st.sidebar:
        st.header("📊 Statistics")
        st.metric("Documents Processed", st.session_state.documents_processed)
        st.metric("Chunks in Database", len(rag_system.vector_db.chunks))


if __name__ == "__main__":
    main()