Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
import pickle
|
5 |
+
from typing import List, Dict, Any
|
6 |
+
import numpy as np
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
# Document processing
|
10 |
+
import PyPDF2
|
11 |
+
import docx
|
12 |
+
from sentence_transformers import SentenceTransformer
|
13 |
+
import faiss
|
14 |
+
|
15 |
+
# Groq API
|
16 |
+
from groq import Groq
|
17 |
+
|
18 |
+
# Text processing
|
19 |
+
import nltk
|
20 |
+
from nltk.tokenize import sent_tokenize
|
21 |
+
import re
|
22 |
+
|
23 |
+
# Download required NLTK data
|
24 |
+
try:
|
25 |
+
nltk.data.find('tokenizers/punkt')
|
26 |
+
except LookupError:
|
27 |
+
nltk.download('punkt')
|
28 |
+
|
29 |
+
class DocumentProcessor:
|
30 |
+
"""Handles document upload and text extraction"""
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def extract_text_from_pdf(file_path: str) -> str:
|
34 |
+
"""Extract text from PDF file"""
|
35 |
+
text = ""
|
36 |
+
try:
|
37 |
+
with open(file_path, 'rb') as file:
|
38 |
+
pdf_reader = PyPDF2.PdfReader(file)
|
39 |
+
for page in pdf_reader.pages:
|
40 |
+
text += page.extract_text() + "\n"
|
41 |
+
except Exception as e:
|
42 |
+
st.error(f"Error reading PDF: {str(e)}")
|
43 |
+
return text
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
def extract_text_from_docx(file_path: str) -> str:
|
47 |
+
"""Extract text from DOCX file"""
|
48 |
+
text = ""
|
49 |
+
try:
|
50 |
+
doc = docx.Document(file_path)
|
51 |
+
for paragraph in doc.paragraphs:
|
52 |
+
text += paragraph.text + "\n"
|
53 |
+
except Exception as e:
|
54 |
+
st.error(f"Error reading DOCX: {str(e)}")
|
55 |
+
return text
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def extract_text_from_txt(file_path: str) -> str:
|
59 |
+
"""Extract text from TXT file"""
|
60 |
+
text = ""
|
61 |
+
try:
|
62 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
63 |
+
text = file.read()
|
64 |
+
except Exception as e:
|
65 |
+
st.error(f"Error reading TXT: {str(e)}")
|
66 |
+
return text
|
67 |
+
|
68 |
+
def process_uploaded_file(self, uploaded_file) -> str:
|
69 |
+
"""Process uploaded file and extract text"""
|
70 |
+
if uploaded_file is None:
|
71 |
+
return ""
|
72 |
+
|
73 |
+
# Save uploaded file temporarily
|
74 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp_file:
|
75 |
+
tmp_file.write(uploaded_file.getvalue())
|
76 |
+
tmp_file_path = tmp_file.name
|
77 |
+
|
78 |
+
try:
|
79 |
+
file_extension = uploaded_file.name.split('.')[-1].lower()
|
80 |
+
|
81 |
+
if file_extension == 'pdf':
|
82 |
+
text = self.extract_text_from_pdf(tmp_file_path)
|
83 |
+
elif file_extension == 'docx':
|
84 |
+
text = self.extract_text_from_docx(tmp_file_path)
|
85 |
+
elif file_extension == 'txt':
|
86 |
+
text = self.extract_text_from_txt(tmp_file_path)
|
87 |
+
else:
|
88 |
+
st.error(f"Unsupported file type: {file_extension}")
|
89 |
+
return ""
|
90 |
+
|
91 |
+
return text
|
92 |
+
finally:
|
93 |
+
# Clean up temporary file
|
94 |
+
os.unlink(tmp_file_path)
|
95 |
+
|
96 |
+
class TextChunker:
|
97 |
+
"""Handles text chunking and preprocessing"""
|
98 |
+
|
99 |
+
def __init__(self, chunk_size: int = 1000, overlap: int = 200):
|
100 |
+
self.chunk_size = chunk_size
|
101 |
+
self.overlap = overlap
|
102 |
+
|
103 |
+
def clean_text(self, text: str) -> str:
|
104 |
+
"""Clean and preprocess text"""
|
105 |
+
# Remove extra whitespace
|
106 |
+
text = re.sub(r'\s+', ' ', text)
|
107 |
+
# Remove special characters but keep punctuation
|
108 |
+
text = re.sub(r'[^\w\s\.\!\?\,\;\:\-\(\)]', '', text)
|
109 |
+
return text.strip()
|
110 |
+
|
111 |
+
def create_chunks(self, text: str) -> List[str]:
|
112 |
+
"""Create overlapping chunks from text"""
|
113 |
+
cleaned_text = self.clean_text(text)
|
114 |
+
|
115 |
+
# Split into sentences first
|
116 |
+
sentences = sent_tokenize(cleaned_text)
|
117 |
+
|
118 |
+
chunks = []
|
119 |
+
current_chunk = ""
|
120 |
+
|
121 |
+
for sentence in sentences:
|
122 |
+
# If adding this sentence would exceed chunk size, start a new chunk
|
123 |
+
if len(current_chunk) + len(sentence) > self.chunk_size:
|
124 |
+
if current_chunk:
|
125 |
+
chunks.append(current_chunk.strip())
|
126 |
+
|
127 |
+
# Create overlap by keeping last part of current chunk
|
128 |
+
words = current_chunk.split()
|
129 |
+
if len(words) > 20: # Keep last 20 words for overlap
|
130 |
+
current_chunk = " ".join(words[-20:]) + " " + sentence
|
131 |
+
else:
|
132 |
+
current_chunk = sentence
|
133 |
+
else:
|
134 |
+
current_chunk = sentence
|
135 |
+
else:
|
136 |
+
current_chunk += " " + sentence
|
137 |
+
|
138 |
+
# Add the last chunk
|
139 |
+
if current_chunk:
|
140 |
+
chunks.append(current_chunk.strip())
|
141 |
+
|
142 |
+
return chunks
|
143 |
+
|
144 |
+
class VectorDatabase:
|
145 |
+
"""Handles vector embeddings and FAISS operations"""
|
146 |
+
|
147 |
+
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
|
148 |
+
self.embedding_model = SentenceTransformer(model_name)
|
149 |
+
self.dimension = self.embedding_model.get_sentence_embedding_dimension()
|
150 |
+
self.index = faiss.IndexFlatIP(self.dimension) # Inner product for similarity
|
151 |
+
self.chunks = []
|
152 |
+
self.embeddings = None
|
153 |
+
|
154 |
+
def create_embeddings(self, chunks: List[str]) -> np.ndarray:
|
155 |
+
"""Create embeddings for text chunks"""
|
156 |
+
with st.spinner("Creating embeddings..."):
|
157 |
+
embeddings = self.embedding_model.encode(chunks, show_progress_bar=True)
|
158 |
+
# Normalize embeddings for cosine similarity
|
159 |
+
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
|
160 |
+
return embeddings
|
161 |
+
|
162 |
+
def add_documents(self, chunks: List[str]):
|
163 |
+
"""Add documents to the vector database"""
|
164 |
+
if not chunks:
|
165 |
+
return
|
166 |
+
|
167 |
+
self.chunks.extend(chunks)
|
168 |
+
embeddings = self.create_embeddings(chunks)
|
169 |
+
|
170 |
+
if self.embeddings is None:
|
171 |
+
self.embeddings = embeddings
|
172 |
+
else:
|
173 |
+
self.embeddings = np.vstack([self.embeddings, embeddings])
|
174 |
+
|
175 |
+
# Add to FAISS index
|
176 |
+
self.index.add(embeddings.astype(np.float32))
|
177 |
+
|
178 |
+
st.success(f"Added {len(chunks)} chunks to vector database")
|
179 |
+
|
180 |
+
def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
181 |
+
"""Search for similar documents"""
|
182 |
+
if self.index.ntotal == 0:
|
183 |
+
return []
|
184 |
+
|
185 |
+
# Create query embedding
|
186 |
+
query_embedding = self.embedding_model.encode([query])
|
187 |
+
query_embedding = query_embedding / np.linalg.norm(query_embedding)
|
188 |
+
|
189 |
+
# Search in FAISS
|
190 |
+
scores, indices = self.index.search(query_embedding.astype(np.float32), k)
|
191 |
+
|
192 |
+
results = []
|
193 |
+
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
|
194 |
+
if idx < len(self.chunks):
|
195 |
+
results.append({
|
196 |
+
'chunk': self.chunks[idx],
|
197 |
+
'score': float(score),
|
198 |
+
'rank': i + 1
|
199 |
+
})
|
200 |
+
|
201 |
+
return results
|
202 |
+
|
203 |
+
def save_database(self, filepath: str):
|
204 |
+
"""Save the vector database to disk"""
|
205 |
+
data = {
|
206 |
+
'chunks': self.chunks,
|
207 |
+
'embeddings': self.embeddings,
|
208 |
+
'index': faiss.serialize_index(self.index)
|
209 |
+
}
|
210 |
+
|
211 |
+
with open(filepath, 'wb') as f:
|
212 |
+
pickle.dump(data, f)
|
213 |
+
|
214 |
+
def load_database(self, filepath: str):
|
215 |
+
"""Load the vector database from disk"""
|
216 |
+
try:
|
217 |
+
with open(filepath, 'rb') as f:
|
218 |
+
data = pickle.load(f)
|
219 |
+
|
220 |
+
self.chunks = data['chunks']
|
221 |
+
self.embeddings = data['embeddings']
|
222 |
+
self.index = faiss.deserialize_index(data['index'])
|
223 |
+
|
224 |
+
return True
|
225 |
+
except Exception as e:
|
226 |
+
st.error(f"Error loading database: {str(e)}")
|
227 |
+
return False
|
228 |
+
|
229 |
+
class RAGSystem:
|
230 |
+
"""Main RAG system that combines retrieval and generation"""
|
231 |
+
|
232 |
+
def __init__(self, groq_api_key: str):
|
233 |
+
self.groq_client = Groq(api_key=groq_api_key)
|
234 |
+
self.vector_db = VectorDatabase()
|
235 |
+
self.doc_processor = DocumentProcessor()
|
236 |
+
self.text_chunker = TextChunker()
|
237 |
+
|
238 |
+
def process_document(self, uploaded_file):
|
239 |
+
"""Process uploaded document and add to vector database"""
|
240 |
+
# Extract text from document
|
241 |
+
text = self.doc_processor.process_uploaded_file(uploaded_file)
|
242 |
+
|
243 |
+
if not text:
|
244 |
+
st.error("No text extracted from document")
|
245 |
+
return False
|
246 |
+
|
247 |
+
# Create chunks
|
248 |
+
chunks = self.text_chunker.create_chunks(text)
|
249 |
+
|
250 |
+
if not chunks:
|
251 |
+
st.error("No chunks created from text")
|
252 |
+
return False
|
253 |
+
|
254 |
+
# Add to vector database
|
255 |
+
self.vector_db.add_documents(chunks)
|
256 |
+
|
257 |
+
return True
|
258 |
+
|
259 |
+
def generate_response(self, query: str, context: str, model: str = "llama-3.3-70b-versatile") -> str:
|
260 |
+
"""Generate response using Groq API"""
|
261 |
+
|
262 |
+
prompt = f"""
|
263 |
+
Based on the following context, please answer the question. If the answer is not in the context, say "I don't have enough information to answer this question based on the provided documents."
|
264 |
+
|
265 |
+
Context:
|
266 |
+
{context}
|
267 |
+
|
268 |
+
Question: {query}
|
269 |
+
|
270 |
+
Answer:
|
271 |
+
"""
|
272 |
+
|
273 |
+
try:
|
274 |
+
chat_completion = self.groq_client.chat.completions.create(
|
275 |
+
messages=[
|
276 |
+
{
|
277 |
+
"role": "system",
|
278 |
+
"content": "You are a helpful assistant that answers questions based on provided context. Be accurate and concise."
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"role": "user",
|
282 |
+
"content": prompt
|
283 |
+
}
|
284 |
+
],
|
285 |
+
model=model,
|
286 |
+
temperature=0.1,
|
287 |
+
max_tokens=1000
|
288 |
+
)
|
289 |
+
|
290 |
+
return chat_completion.choices[0].message.content
|
291 |
+
except Exception as e:
|
292 |
+
return f"Error generating response: {str(e)}"
|
293 |
+
|
294 |
+
def query(self, question: str, model: str = "llama-3.3-70b-versatile") -> Dict[str, Any]:
|
295 |
+
"""Query the RAG system"""
|
296 |
+
# Retrieve relevant documents
|
297 |
+
search_results = self.vector_db.search(question, k=3)
|
298 |
+
|
299 |
+
if not search_results:
|
300 |
+
return {
|
301 |
+
'answer': "No relevant documents found. Please upload some documents first.",
|
302 |
+
'sources': []
|
303 |
+
}
|
304 |
+
|
305 |
+
# Combine contexts
|
306 |
+
context = "\n\n".join([result['chunk'] for result in search_results])
|
307 |
+
|
308 |
+
# Generate response
|
309 |
+
answer = self.generate_response(question, context, model)
|
310 |
+
|
311 |
+
return {
|
312 |
+
'answer': answer,
|
313 |
+
'sources': search_results
|
314 |
+
}
|
315 |
+
|
316 |
+
def main():
|
317 |
+
st.set_page_config(
|
318 |
+
page_title="RAG Application",
|
319 |
+
page_icon="π",
|
320 |
+
layout="wide",
|
321 |
+
initial_sidebar_state="expanded"
|
322 |
+
)
|
323 |
+
|
324 |
+
st.title("π RAG Application")
|
325 |
+
st.markdown("**Upload documents and ask questions using AI-powered search and generation**")
|
326 |
+
|
327 |
+
# Initialize session state
|
328 |
+
if 'rag_system' not in st.session_state:
|
329 |
+
st.session_state.rag_system = None
|
330 |
+
if 'documents_processed' not in st.session_state:
|
331 |
+
st.session_state.documents_processed = 0
|
332 |
+
|
333 |
+
# Sidebar for configuration
|
334 |
+
with st.sidebar:
|
335 |
+
st.header("βοΈ Configuration")
|
336 |
+
|
337 |
+
# API Key input
|
338 |
+
groq_api_key = st.text_input(
|
339 |
+
"Groq API Key",
|
340 |
+
type="password",
|
341 |
+
help="Enter your Groq API key"
|
342 |
+
)
|
343 |
+
|
344 |
+
if not groq_api_key:
|
345 |
+
st.warning("Please enter your Groq API key to continue")
|
346 |
+
st.stop()
|
347 |
+
|
348 |
+
# Model selection
|
349 |
+
model_options = [
|
350 |
+
"llama-3.3-70b-versatile",
|
351 |
+
"llama-3.2-90b-text-preview",
|
352 |
+
"llama-3.1-70b-versatile",
|
353 |
+
"mixtral-8x7b-32768",
|
354 |
+
"gemma2-9b-it"
|
355 |
+
]
|
356 |
+
|
357 |
+
selected_model = st.selectbox(
|
358 |
+
"Select Model",
|
359 |
+
model_options,
|
360 |
+
index=0
|
361 |
+
)
|
362 |
+
|
363 |
+
# Initialize RAG system
|
364 |
+
if st.session_state.rag_system is None:
|
365 |
+
try:
|
366 |
+
st.session_state.rag_system = RAGSystem(groq_api_key)
|
367 |
+
st.success("RAG system initialized!")
|
368 |
+
except Exception as e:
|
369 |
+
st.error(f"Error initializing RAG system: {str(e)}")
|
370 |
+
st.stop()
|
371 |
+
|
372 |
+
st.header("π Statistics")
|
373 |
+
st.metric("Documents Processed", st.session_state.documents_processed)
|
374 |
+
st.metric("Chunks in Database", len(st.session_state.rag_system.vector_db.chunks))
|
375 |
+
|
376 |
+
# Main content area
|
377 |
+
col1, col2 = st.columns([1, 2])
|
378 |
+
|
379 |
+
with col1:
|
380 |
+
st.header("π Document Upload")
|
381 |
+
|
382 |
+
uploaded_files = st.file_uploader(
|
383 |
+
"Upload documents",
|
384 |
+
accept_multiple_files=True,
|
385 |
+
type=['pdf', 'docx', 'txt'],
|
386 |
+
help="Upload PDF, DOCX, or TXT files"
|
387 |
+
)
|
388 |
+
|
389 |
+
if uploaded_files:
|
390 |
+
for uploaded_file in uploaded_files:
|
391 |
+
if st.button(f"Process {uploaded_file.name}"):
|
392 |
+
with st.spinner(f"Processing {uploaded_file.name}..."):
|
393 |
+
success = st.session_state.rag_system.process_document(uploaded_file)
|
394 |
+
if success:
|
395 |
+
st.session_state.documents_processed += 1
|
396 |
+
st.success(f"Successfully processed {uploaded_file.name}")
|
397 |
+
else:
|
398 |
+
st.error(f"Failed to process {uploaded_file.name}")
|
399 |
+
|
400 |
+
with col2:
|
401 |
+
st.header("π¬ Ask Questions")
|
402 |
+
|
403 |
+
if len(st.session_state.rag_system.vector_db.chunks) == 0:
|
404 |
+
st.info("Please upload and process documents before asking questions.")
|
405 |
+
else:
|
406 |
+
question = st.text_input(
|
407 |
+
"Enter your question:",
|
408 |
+
placeholder="What is this document about?"
|
409 |
+
)
|
410 |
+
|
411 |
+
if st.button("Ask Question") and question:
|
412 |
+
with st.spinner("Generating answer..."):
|
413 |
+
response = st.session_state.rag_system.query(question, selected_model)
|
414 |
+
|
415 |
+
st.subheader("Answer:")
|
416 |
+
st.write(response['answer'])
|
417 |
+
|
418 |
+
if response['sources']:
|
419 |
+
st.subheader("Sources:")
|
420 |
+
for i, source in enumerate(response['sources']):
|
421 |
+
with st.expander(f"Source {i+1} (Score: {source['score']:.3f})"):
|
422 |
+
st.write(source['chunk'])
|
423 |
+
|
424 |
+
# Additional features
|
425 |
+
st.header("π§ Additional Features")
|
426 |
+
|
427 |
+
col3, col4 = st.columns(2)
|
428 |
+
|
429 |
+
with col3:
|
430 |
+
if st.button("Clear Database"):
|
431 |
+
st.session_state.rag_system.vector_db = VectorDatabase()
|
432 |
+
st.session_state.documents_processed = 0
|
433 |
+
st.success("Database cleared successfully!")
|
434 |
+
|
435 |
+
with col4:
|
436 |
+
if st.button("Save Database"):
|
437 |
+
if len(st.session_state.rag_system.vector_db.chunks) > 0:
|
438 |
+
st.session_state.rag_system.vector_db.save_database("rag_database.pkl")
|
439 |
+
st.success("Database saved successfully!")
|
440 |
+
else:
|
441 |
+
st.warning("No data to save")
|
442 |
+
|
443 |
+
if __name__ == "__main__":
|
444 |
+
main()
|