import streamlit as st
import os
import tempfile
import pickle
import re
from typing import List, Dict, Any
import numpy as np
from pathlib import Path

import PyPDF2
import docx
from sentence_transformers import SentenceTransformer
import faiss

from groq import Groq

import nltk
from nltk.tokenize import sent_tokenize
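# Third-party dependencies (approximate pip install line, assuming CPU-only FAISS):
#   pip install streamlit PyPDF2 python-docx sentence-transformers faiss-cpu groq nltk numpy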

# Make sure the NLTK sentence tokenizer data is available.
# (Newer NLTK releases may also look for the 'punkt_tab' resource, so both are
# checked here; the download is best-effort.)
for resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource)

class DocumentProcessor:
    """Handles document upload and text extraction"""

    @staticmethod
    def extract_text_from_pdf(file_path: str) -> str:
        """Extract text from PDF file"""
        text = ""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can return None for pages without a text layer
                    text += (page.extract_text() or "") + "\n"
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
        return text

    @staticmethod
    def extract_text_from_docx(file_path: str) -> str:
        """Extract text from DOCX file"""
        text = ""
        try:
            doc = docx.Document(file_path)
            for paragraph in doc.paragraphs:
                text += paragraph.text + "\n"
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
        return text

    @staticmethod
    def extract_text_from_txt(file_path: str) -> str:
        """Extract text from TXT file"""
        text = ""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                text = file.read()
        except Exception as e:
            st.error(f"Error reading TXT: {str(e)}")
        return text

    def process_uploaded_file(self, uploaded_file) -> str:
        """Process uploaded file and extract text"""
        if uploaded_file is None:
            return ""

        # Write the upload to a temporary file so the extractors can work with a path
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_file_path = tmp_file.name

        try:
            file_extension = uploaded_file.name.split('.')[-1].lower()

            if file_extension == 'pdf':
                text = self.extract_text_from_pdf(tmp_file_path)
            elif file_extension == 'docx':
                text = self.extract_text_from_docx(tmp_file_path)
            elif file_extension == 'txt':
                text = self.extract_text_from_txt(tmp_file_path)
            else:
                st.error(f"Unsupported file type: {file_extension}")
                return ""

            return text
        finally:
            # Clean up the temporary file
            os.unlink(tmp_file_path)

class TextChunker:
    """Handles text chunking and preprocessing"""

    def __init__(self, chunk_size: int = 1000, overlap: int = 200):
        self.chunk_size = chunk_size
        self.overlap = overlap
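        # Note: create_chunks() applies overlap by carrying over roughly the last
        # 20 words of the previous chunk; this `overlap` value is not used as an
        # exact character count.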

    def clean_text(self, text: str) -> str:
        """Clean and preprocess text"""
        # Collapse runs of whitespace into single spaces
        text = re.sub(r'\s+', ' ', text)
        # Drop characters other than word characters, whitespace, and basic punctuation
        text = re.sub(r'[^\w\s\.\!\?\,\;\:\-\(\)]', '', text)
        return text.strip()

    def create_chunks(self, text: str) -> List[str]:
        """Create overlapping chunks from text"""
        cleaned_text = self.clean_text(text)

        # Split into sentences so chunks end on sentence boundaries
        sentences = sent_tokenize(cleaned_text)

        chunks = []
        current_chunk = ""

        for sentence in sentences:
            # Start a new chunk once the current one would exceed chunk_size
            if len(current_chunk) + len(sentence) > self.chunk_size:
                if current_chunk:
                    chunks.append(current_chunk.strip())

                    # Carry over the last 20 words of the previous chunk so that
                    # context is preserved across chunk boundaries
                    words = current_chunk.split()
                    if len(words) > 20:
                        current_chunk = " ".join(words[-20:]) + " " + sentence
                    else:
                        current_chunk = sentence
                else:
                    current_chunk = sentence
            else:
                current_chunk += " " + sentence

        # Don't lose the final partial chunk
        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks

class VectorDatabase:
    """Handles vector embeddings and FAISS operations"""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.embedding_model = SentenceTransformer(model_name)
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatIP(self.dimension)
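        # IndexFlatIP does exact inner-product search; because the embeddings are
        # L2-normalized before indexing (see create_embeddings), inner product is
        # equivalent to cosine similarity here.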
        self.chunks = []
        self.embeddings = None

    def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """Create embeddings for text chunks"""
        with st.spinner("Creating embeddings..."):
            embeddings = self.embedding_model.encode(chunks, show_progress_bar=True)
        # Normalize so that inner-product search behaves like cosine similarity
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings

    def add_documents(self, chunks: List[str]):
        """Add documents to the vector database"""
        if not chunks:
            return

        self.chunks.extend(chunks)
        embeddings = self.create_embeddings(chunks)

        if self.embeddings is None:
            self.embeddings = embeddings
        else:
            self.embeddings = np.vstack([self.embeddings, embeddings])

        # FAISS expects float32 vectors
        self.index.add(embeddings.astype(np.float32))

        st.success(f"Added {len(chunks)} chunks to vector database")

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar documents"""
        if self.index.ntotal == 0:
            return []

        # Embed and normalize the query the same way as the indexed chunks
        query_embedding = self.embedding_model.encode([query])
        query_embedding = query_embedding / np.linalg.norm(query_embedding)

        # Clamp k to the index size; FAISS pads missing results with index -1
        k = min(k, self.index.ntotal)
        scores, indices = self.index.search(query_embedding.astype(np.float32), k)

        results = []
        for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
            if 0 <= idx < len(self.chunks):
                results.append({
                    'chunk': self.chunks[idx],
                    'score': float(score),
                    'rank': i + 1
                })

        return results

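    # NOTE: the persistence helpers below use pickle. pickle.load can execute
    # arbitrary code, so only load database files created by this application.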
    def save_database(self, filepath: str):
        """Save the vector database to disk"""
        data = {
            'chunks': self.chunks,
            'embeddings': self.embeddings,
            'index': faiss.serialize_index(self.index)
        }

        with open(filepath, 'wb') as f:
            pickle.dump(data, f)

    def load_database(self, filepath: str):
        """Load the vector database from disk"""
        try:
            with open(filepath, 'rb') as f:
                data = pickle.load(f)

            self.chunks = data['chunks']
            self.embeddings = data['embeddings']
            self.index = faiss.deserialize_index(data['index'])

            return True
        except Exception as e:
            st.error(f"Error loading database: {str(e)}")
            return False

class RAGSystem:
    """Main RAG system that combines retrieval and generation"""

    def __init__(self, groq_api_key: str):
        self.groq_client = Groq(api_key=groq_api_key)
        self.vector_db = VectorDatabase()
        self.doc_processor = DocumentProcessor()
        self.text_chunker = TextChunker()

    def process_document(self, uploaded_file):
        """Process uploaded document and add to vector database"""
        # Extract raw text from the uploaded file
        text = self.doc_processor.process_uploaded_file(uploaded_file)

        if not text:
            st.error("No text extracted from document")
            return False

        # Split the text into overlapping chunks
        chunks = self.text_chunker.create_chunks(text)

        if not chunks:
            st.error("No chunks created from text")
            return False

        # Embed the chunks and add them to the FAISS index
        self.vector_db.add_documents(chunks)

        return True

    def generate_response(self, query: str, context: str, model: str = "llama-3.3-70b-versatile") -> str:
        """Generate response using Groq API"""
        # Ground the model in the retrieved context and ask it to admit when
        # the answer is not present
        prompt = f"""
Based on the following context, please answer the question. If the answer is not in the context, say "I don't have enough information to answer this question based on the provided documents."

Context:
{context}

Question: {query}

Answer:
"""

        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided context. Be accurate and concise."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=model,
                temperature=0.1,  # low temperature keeps answers close to the context
                max_tokens=1000
            )

            return chat_completion.choices[0].message.content
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def query(self, question: str, model: str = "llama-3.3-70b-versatile") -> Dict[str, Any]:
        """Query the RAG system"""
        # Retrieve the most relevant chunks for the question
        search_results = self.vector_db.search(question, k=3)

        if not search_results:
            return {
                'answer': "No relevant documents found. Please upload some documents first.",
                'sources': []
            }

        # Concatenate the retrieved chunks into a single context block
        context = "\n\n".join([result['chunk'] for result in search_results])

        # Generate the answer from the retrieved context
        answer = self.generate_response(question, context, model)

        return {
            'answer': answer,
            'sources': search_results
        }

def main():
    st.set_page_config(
        page_title="RAG Application",
        page_icon="📚",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    st.title("📚 RAG Application")
    st.markdown("**Upload documents and ask questions using AI-powered search and generation**")

    # Initialize session state
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = None
    if 'documents_processed' not in st.session_state:
        st.session_state.documents_processed = 0
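    # Entries in st.session_state persist across Streamlit reruns, so the RAG
    # system (embedding model + FAISS index) is built only once per session.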

    # Sidebar: configuration and statistics
    with st.sidebar:
        st.header("⚙️ Configuration")

        # API key input
        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            help="Enter your Groq API key"
        )

        if not groq_api_key:
            st.warning("Please enter your Groq API key to continue")
            st.stop()

        # Model selection
        model_options = [
            "llama-3.3-70b-versatile",
            "llama-3.2-90b-text-preview",
            "llama-3.1-70b-versatile",
            "mixtral-8x7b-32768",
            "gemma2-9b-it"
        ]

        selected_model = st.selectbox(
            "Select Model",
            model_options,
            index=0
        )

        # Initialize the RAG system once a key is available
        if st.session_state.rag_system is None:
            try:
                st.session_state.rag_system = RAGSystem(groq_api_key)
                st.success("RAG system initialized!")
            except Exception as e:
                st.error(f"Error initializing RAG system: {str(e)}")
                st.stop()

        st.header("📊 Statistics")
        st.metric("Documents Processed", st.session_state.documents_processed)
        st.metric("Chunks in Database", len(st.session_state.rag_system.vector_db.chunks))

    # Main layout: document upload on the left, Q&A on the right
    col1, col2 = st.columns([1, 2])

    with col1:
        st.header("📄 Document Upload")

        uploaded_files = st.file_uploader(
            "Upload documents",
            accept_multiple_files=True,
            type=['pdf', 'docx', 'txt'],
            help="Upload PDF, DOCX, or TXT files"
        )

        if uploaded_files:
            for uploaded_file in uploaded_files:
                if st.button(f"Process {uploaded_file.name}"):
                    with st.spinner(f"Processing {uploaded_file.name}..."):
                        success = st.session_state.rag_system.process_document(uploaded_file)
                        if success:
                            st.session_state.documents_processed += 1
                            st.success(f"Successfully processed {uploaded_file.name}")
                        else:
                            st.error(f"Failed to process {uploaded_file.name}")

    with col2:
        st.header("💬 Ask Questions")

        if len(st.session_state.rag_system.vector_db.chunks) == 0:
            st.info("Please upload and process documents before asking questions.")
        else:
            question = st.text_input(
                "Enter your question:",
                placeholder="What is this document about?"
            )

            if st.button("Ask Question") and question:
                with st.spinner("Generating answer..."):
                    response = st.session_state.rag_system.query(question, selected_model)

                st.subheader("Answer:")
                st.write(response['answer'])

                # Show the retrieved chunks that were used as context
                if response['sources']:
                    st.subheader("Sources:")
                    for i, source in enumerate(response['sources']):
                        with st.expander(f"Source {i+1} (Score: {source['score']:.3f})"):
                            st.write(source['chunk'])

    # Additional features
    st.header("🔧 Additional Features")

    col3, col4 = st.columns(2)

    with col3:
        if st.button("Clear Database"):
            # Replace the vector database with a fresh, empty one
            st.session_state.rag_system.vector_db = VectorDatabase()
            st.session_state.documents_processed = 0
            st.success("Database cleared successfully!")

    with col4:
        if st.button("Save Database"):
            if len(st.session_state.rag_system.vector_db.chunks) > 0:
                st.session_state.rag_system.vector_db.save_database("rag_database.pkl")
                st.success("Database saved successfully!")
            else:
                st.warning("No data to save")


if __name__ == "__main__":
    main()
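
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py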