agnixcode committed on
Commit
b6df80f
·
verified ·
1 Parent(s): ef4b314

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +444 -0
app.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import pickle
5
+ from typing import List, Dict, Any
6
+ import numpy as np
7
+ from pathlib import Path
8
+
9
+ # Document processing
10
+ import PyPDF2
11
+ import docx
12
+ from sentence_transformers import SentenceTransformer
13
+ import faiss
14
+
15
+ # Groq API
16
+ from groq import Groq
17
+
18
+ # Text processing
19
+ import nltk
20
+ from nltk.tokenize import sent_tokenize
21
+ import re
22
+
23
+ # Download required NLTK data
24
# Make sure the sentence-tokenizer data used by `sent_tokenize` is installed.
# Older NLTK releases load the `punkt` models while newer releases look for
# `punkt_tab` instead, so both resources are checked and fetched on demand.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_resource}")
    except LookupError:
        nltk.download(_resource)
28
+
29
class DocumentProcessor:
    """Handles document upload and text extraction.

    The ``extract_text_from_*`` helpers do not raise on read failures: the
    error is reported through ``st.error`` and whatever text was collected
    before the failure (possibly the empty string) is returned.
    """

    @staticmethod
    def extract_text_from_pdf(file_path: str) -> str:
        """Extract text from a PDF file.

        Args:
            file_path: Path of the PDF file on disk.

        Returns:
            The text of every page, each page followed by a newline.
        """
        page_texts = []
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # Guard against a page yielding no text object at all
                    # (e.g. image-only pages) so one such page cannot abort
                    # the whole document with a TypeError.
                    page_texts.append((page.extract_text() or "") + "\n")
        except Exception as e:
            st.error(f"Error reading PDF: {str(e)}")
        return "".join(page_texts)

    @staticmethod
    def extract_text_from_docx(file_path: str) -> str:
        """Extract text from a DOCX file.

        Args:
            file_path: Path of the DOCX file on disk.

        Returns:
            The text of every paragraph, each followed by a newline.
        """
        paragraph_texts = []
        try:
            doc = docx.Document(file_path)
            for paragraph in doc.paragraphs:
                paragraph_texts.append(paragraph.text + "\n")
        except Exception as e:
            st.error(f"Error reading DOCX: {str(e)}")
        return "".join(paragraph_texts)

    @staticmethod
    def extract_text_from_txt(file_path: str) -> str:
        """Extract text from a UTF-8 encoded TXT file.

        Args:
            file_path: Path of the text file on disk.

        Returns:
            The file content, without a leading byte-order mark.
        """
        text = ""
        try:
            # 'utf-8-sig' reads plain UTF-8 as well as UTF-8 with a BOM, so the
            # BOM never leaks into the extracted text.
            with open(file_path, 'r', encoding='utf-8-sig') as file:
                text = file.read()
        except Exception as e:
            st.error(f"Error reading TXT: {str(e)}")
        return text

    def process_uploaded_file(self, uploaded_file) -> str:
        """Process an uploaded file and extract its text.

        Args:
            uploaded_file: Object exposing ``name`` (file name) and
                ``getvalue()`` (raw bytes), or ``None``.

        Returns:
            The extracted text; ``""`` when *uploaded_file* is ``None`` or its
            extension is not pdf/docx/txt.
        """
        if uploaded_file is None:
            return ""

        file_name = uploaded_file.name
        # A name without a dot has no extension (instead of treating the whole
        # name as one).
        file_extension = file_name.rsplit('.', 1)[-1].lower() if '.' in file_name else ""

        extractors = {
            'pdf': self.extract_text_from_pdf,
            'docx': self.extract_text_from_docx,
            'txt': self.extract_text_from_txt,
        }
        extractor = extractors.get(file_extension)
        if extractor is None:
            # Reject unsupported types before anything is written to disk.
            st.error(f"Unsupported file type: {file_extension or file_name}")
            return ""

        # The extraction helpers work on paths, so spill the bytes to a
        # temporary file that is removed again below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_extension}") as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_file_path = tmp_file.name

        try:
            return extractor(tmp_file_path)
        finally:
            # Clean up temporary file
            os.unlink(tmp_file_path)
95
+
96
class TextChunker:
    """Handles text cleaning and sentence-aligned chunking with overlap."""

    # Compiled once; both patterns are applied to every document.
    _WHITESPACE_RUN = re.compile(r'\s+')
    # Everything except word characters, whitespace and . ! ? , ; : - ( )
    # is dropped. Note that this also removes quotes and apostrophes.
    _DISALLOWED_CHARS = re.compile(r'[^\w\s\.\!\?\,\;\:\-\(\)]')

    def __init__(self, chunk_size: int = 1000, overlap: int = 200, sentence_splitter: Any = None):
        """Configure the chunker.

        Args:
            chunk_size: Target maximum chunk length in characters. A single
                sentence longer than this is kept whole, and the overlap
                prefix may push a chunk slightly over the target.
            overlap: Maximum number of trailing characters (whole words only)
                of a finished chunk that are repeated at the start of the
                next chunk. ``0`` disables overlap.
            sentence_splitter: Optional callable ``str -> list of str`` used
                to split cleaned text into sentences. Defaults to NLTK's
                ``sent_tokenize``.
        """
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sentence_splitter = sentence_splitter

    def clean_text(self, text: str) -> str:
        """Collapse whitespace runs and drop disallowed characters.

        Returns:
            The cleaned text with leading/trailing whitespace removed.
        """
        text = self._WHITESPACE_RUN.sub(' ', text)
        text = self._DISALLOWED_CHARS.sub('', text)
        return text.strip()

    def _overlap_tail(self, chunk: str) -> str:
        """Return the trailing whole words of *chunk* fitting in ``overlap`` characters."""
        # A chunk no longer than the overlap budget would be repeated in full,
        # which only duplicates it; skip the overlap in that case.
        if self.overlap <= 0 or len(chunk) <= self.overlap:
            return ""

        kept_words = []
        budget = self.overlap
        for word in reversed(chunk.split()):
            cost = len(word) + (1 if kept_words else 0)  # +1 for the joining space
            if cost > budget:
                break
            kept_words.append(word)
            budget -= cost
        return " ".join(reversed(kept_words))

    def create_chunks(self, text: str) -> List[str]:
        """Create overlapping chunks from text.

        The text is cleaned, split into sentences, and sentences are packed
        greedily into chunks of at most ``chunk_size`` characters. Each new
        chunk starts with the overlap tail of the previous one.

        Returns:
            The list of chunks; empty when the cleaned text is empty.
        """
        cleaned_text = self.clean_text(text)
        if not cleaned_text:
            return []

        split_sentences = self.sentence_splitter or sent_tokenize

        chunks = []
        current_chunk = ""
        for sentence in split_sentences(cleaned_text):
            if not current_chunk:
                current_chunk = sentence
                continue

            # +1 accounts for the space that joins the sentence to the chunk.
            if len(current_chunk) + 1 + len(sentence) <= self.chunk_size:
                current_chunk = f"{current_chunk} {sentence}"
                continue

            # The sentence does not fit: close the chunk and seed the next one
            # with the configured overlap.
            chunks.append(current_chunk.strip())
            tail = self._overlap_tail(current_chunk)
            current_chunk = f"{tail} {sentence}" if tail else sentence

        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk.strip())

        return chunks
143
+
144
class VectorDatabase:
    """Handles vector embeddings and FAISS operations.

    Chunks are embedded with a sentence-transformers model, L2-normalised and
    stored in an inner-product FAISS index, so search scores are cosine
    similarities. ``chunks[i]`` is the text of the i-th vector in the index.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the embedding model and create an empty index."""
        self.embedding_model = SentenceTransformer(model_name)
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatIP(self.dimension)  # Inner product for similarity
        self.chunks = []
        self.embeddings = None

    @staticmethod
    def _normalize_rows(vectors) -> np.ndarray:
        """Return *vectors* as float32 with every row scaled to unit length."""
        vectors = np.asarray(vectors, dtype=np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # A zero vector stays zero instead of turning into NaN.
        norms[norms == 0] = 1.0
        return vectors / norms

    def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """Create unit-length embeddings for text chunks (one row per chunk)."""
        with st.spinner("Creating embeddings..."):
            embeddings = self.embedding_model.encode(chunks, show_progress_bar=True)
            # Normalize embeddings for cosine similarity
            embeddings = self._normalize_rows(embeddings)
        return embeddings

    def add_documents(self, chunks: List[str]):
        """Embed *chunks* and add them to the vector database.

        Nothing is stored when *chunks* is empty. The chunk list is only
        extended after embedding and indexing succeeded, so a failure cannot
        leave ``chunks`` and the index out of step.
        """
        if not chunks:
            return

        embeddings = self.create_embeddings(chunks)

        # Add to FAISS index
        self.index.add(embeddings.astype(np.float32))
        self.chunks.extend(chunks)

        if self.embeddings is None:
            self.embeddings = embeddings
        else:
            self.embeddings = np.vstack([self.embeddings, embeddings])

        st.success(f"Added {len(chunks)} chunks to vector database")

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search for the chunks most similar to *query*.

        Args:
            query: Free-text query.
            k: Maximum number of results.

        Returns:
            Up to *k* dicts with keys ``chunk``, ``score`` and ``rank``
            (1-based), best match first. Empty when the index is empty.
        """
        # Never ask for more neighbours than there are vectors.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        # Create query embedding
        query_embedding = self._normalize_rows(self.embedding_model.encode([query]))

        # Search in FAISS
        scores, indices = self.index.search(query_embedding.astype(np.float32), k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with id -1; a negative id must not
            # be used as a Python index (it would silently pick the last chunk).
            if 0 <= idx < len(self.chunks):
                results.append({
                    'chunk': self.chunks[idx],
                    'score': float(score),
                    'rank': len(results) + 1
                })

        return results

    def save_database(self, filepath: str):
        """Save chunks, embeddings and the serialized index to *filepath*."""
        data = {
            'chunks': self.chunks,
            'embeddings': self.embeddings,
            'index': faiss.serialize_index(self.index)
        }

        with open(filepath, 'wb') as f:
            pickle.dump(data, f)

    def load_database(self, filepath: str):
        """Load the vector database from disk.

        Returns:
            ``True`` on success; ``False`` after reporting the error through
            ``st.error``. The in-memory state is only replaced on success.
        """
        try:
            # SECURITY: pickle.load executes arbitrary code from the file.
            # Only load files written by save_database from a trusted location.
            with open(filepath, 'rb') as f:
                data = pickle.load(f)

            chunks = data['chunks']
            embeddings = data['embeddings']
            index = faiss.deserialize_index(data['index'])
        except Exception as e:
            st.error(f"Error loading database: {str(e)}")
            return False

        self.chunks = chunks
        self.embeddings = embeddings
        self.index = index
        return True
228
+
229
class RAGSystem:
    """Main RAG system that combines retrieval and generation."""

    # Model used when the caller does not choose one.
    DEFAULT_MODEL = "llama-3.3-70b-versatile"

    def __init__(self, groq_api_key: str):
        """Create the Groq client and the retrieval components."""
        self.groq_client = Groq(api_key=groq_api_key)
        self.vector_db = VectorDatabase()
        self.doc_processor = DocumentProcessor()
        self.text_chunker = TextChunker()

    def process_document(self, uploaded_file):
        """Extract, chunk and index an uploaded document.

        Returns:
            ``True`` when chunks were added to the vector database; ``False``
            (after an ``st.error`` message) when no text or no chunks resulted.
        """
        # Extract text from document
        text = self.doc_processor.process_uploaded_file(uploaded_file)
        if not text:
            st.error("No text extracted from document")
            return False

        # Create chunks
        chunks = self.text_chunker.create_chunks(text)
        if not chunks:
            st.error("No chunks created from text")
            return False

        # Add to vector database
        self.vector_db.add_documents(chunks)
        return True

    def generate_response(self, query: str, context: str, model: str = DEFAULT_MODEL) -> str:
        """Generate an answer to *query* from *context* using the Groq API.

        Args:
            query: The user's question.
            context: Retrieved text the answer must be based on.
            model: Groq model identifier.

        Returns:
            The model's answer. API failures are not raised; they are returned
            as a string starting with ``"Error generating response: "``.
        """
        # Built explicitly so the prompt does not depend on source indentation.
        prompt = (
            "Based on the following context, please answer the question. "
            "If the answer is not in the context, say \"I don't have enough "
            "information to answer this question based on the provided documents.\"\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {query}\n\n"
            "Answer:"
        )
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that answers questions based on provided context. Be accurate and concise."
            },
            {
                "role": "user",
                "content": prompt
            }
        ]

        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=messages,
                model=model,
                temperature=0.1,
                max_tokens=1000
            )
            # Keep the declared ``str`` return type even if no content came back.
            return chat_completion.choices[0].message.content or ""
        except Exception as e:
            return f"Error generating response: {str(e)}"

    def query(self, question: str, model: str = DEFAULT_MODEL, top_k: int = 3) -> Dict[str, Any]:
        """Query the RAG system.

        Args:
            question: The user's question.
            model: Groq model identifier used for generation.
            top_k: Number of chunks to retrieve as context.

        Returns:
            A dict with ``answer`` (str) and ``sources`` (the search results
            used as context; empty when nothing was retrieved).
        """
        # Retrieve relevant documents
        search_results = self.vector_db.search(question, k=top_k)
        if not search_results:
            return {
                'answer': "No relevant documents found. Please upload some documents first.",
                'sources': []
            }

        # Combine contexts
        context = "\n\n".join(result['chunk'] for result in search_results)

        # Generate response
        answer = self.generate_response(question, context, model)
        return {
            'answer': answer,
            'sources': search_results
        }
315
+
316
def _init_session_state():
    """Create the session-state entries the app relies on, if missing."""
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = None
    if 'documents_processed' not in st.session_state:
        st.session_state.documents_processed = 0
    if 'active_api_key' not in st.session_state:
        # API key the current Groq client was built with.
        st.session_state.active_api_key = None


def _render_configuration() -> str:
    """Render the sidebar configuration and return the selected model name.

    Halts the script run (``st.stop``) while no API key has been entered or
    when the RAG system cannot be created.
    """
    with st.sidebar:
        st.header("⚙️ Configuration")

        # API Key input
        groq_api_key = st.text_input(
            "Groq API Key",
            type="password",
            help="Enter your Groq API key"
        )

        if not groq_api_key:
            st.warning("Please enter your Groq API key to continue")
            st.stop()

        # Model selection
        model_options = [
            "llama-3.3-70b-versatile",
            "llama-3.2-90b-text-preview",
            "llama-3.1-70b-versatile",
            "mixtral-8x7b-32768",
            "gemma2-9b-it"
        ]

        selected_model = st.selectbox(
            "Select Model",
            model_options,
            index=0
        )

        # Initialize RAG system
        if st.session_state.rag_system is None:
            try:
                st.session_state.rag_system = RAGSystem(groq_api_key)
                st.session_state.active_api_key = groq_api_key
                st.success("RAG system initialized!")
            except Exception as e:
                st.error(f"Error initializing RAG system: {str(e)}")
                st.stop()
        elif st.session_state.active_api_key != groq_api_key:
            # The key was edited after start-up: swap only the API client so
            # the already indexed documents are kept.
            st.session_state.rag_system.groq_client = Groq(api_key=groq_api_key)
            st.session_state.active_api_key = groq_api_key

    return selected_model


def _render_upload_column():
    """Render the uploader and one "Process" button per uploaded file."""
    st.header("📄 Document Upload")

    uploaded_files = st.file_uploader(
        "Upload documents",
        accept_multiple_files=True,
        type=['pdf', 'docx', 'txt'],
        help="Upload PDF, DOCX, or TXT files"
    )

    for position, uploaded_file in enumerate(uploaded_files or []):
        # An explicit key keeps the buttons distinct even when two uploads
        # share the same file name.
        clicked = st.button(
            f"Process {uploaded_file.name}",
            key=f"process_{position}_{uploaded_file.name}"
        )
        if not clicked:
            continue

        with st.spinner(f"Processing {uploaded_file.name}..."):
            success = st.session_state.rag_system.process_document(uploaded_file)
            if success:
                st.session_state.documents_processed += 1
                st.success(f"Successfully processed {uploaded_file.name}")
            else:
                st.error(f"Failed to process {uploaded_file.name}")


def _render_question_column(selected_model: str):
    """Render the question box and show the answer with its sources."""
    st.header("💬 Ask Questions")

    if len(st.session_state.rag_system.vector_db.chunks) == 0:
        st.info("Please upload and process documents before asking questions.")
        return

    question = st.text_input(
        "Enter your question:",
        placeholder="What is this document about?"
    )

    if st.button("Ask Question") and question:
        with st.spinner("Generating answer..."):
            response = st.session_state.rag_system.query(question, selected_model)

        st.subheader("Answer:")
        st.write(response['answer'])

        if response['sources']:
            st.subheader("Sources:")
            for i, source in enumerate(response['sources']):
                with st.expander(f"Source {i+1} (Score: {source['score']:.3f})"):
                    st.write(source['chunk'])


def _render_maintenance():
    """Render the "Clear Database" and "Save Database" actions."""
    st.header("🔧 Additional Features")

    col3, col4 = st.columns(2)

    with col3:
        if st.button("Clear Database"):
            st.session_state.rag_system.vector_db = VectorDatabase()
            st.session_state.documents_processed = 0
            st.success("Database cleared successfully!")

    with col4:
        if st.button("Save Database"):
            if len(st.session_state.rag_system.vector_db.chunks) > 0:
                st.session_state.rag_system.vector_db.save_database("rag_database.pkl")
                st.success("Database saved successfully!")
            else:
                st.warning("No data to save")


def _render_statistics():
    """Show the document and chunk counters in the sidebar."""
    with st.sidebar:
        st.header("📊 Statistics")
        st.metric("Documents Processed", st.session_state.documents_processed)
        st.metric("Chunks in Database", len(st.session_state.rag_system.vector_db.chunks))


def main():
    """Run the Streamlit RAG application (one script run per user interaction)."""
    st.set_page_config(
        page_title="RAG Application",
        page_icon="🔍",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    st.title("🔍 RAG Application")
    st.markdown("**Upload documents and ask questions using AI-powered search and generation**")

    # Initialize session state
    _init_session_state()

    # Sidebar for configuration
    selected_model = _render_configuration()

    # Main content area
    col1, col2 = st.columns([1, 2])
    with col1:
        _render_upload_column()
    with col2:
        _render_question_column(selected_model)

    # Additional features
    _render_maintenance()

    # Rendered last so the counters already include documents processed or
    # cleared during this run.
    _render_statistics()


if __name__ == "__main__":
    main()