JimmyBhoy committed on
Commit
506b259
·
verified ·
1 Parent(s): 51fffd7

Enhanced app.py with RAG functionality: file upload, chunking with unstructured, embedding with sentence-transformers, vector storage with ChromaDB, and semantic retrieval QA

Browse files
Files changed (1) hide show
  1. app.py +190 -3
app.py CHANGED
@@ -1,11 +1,109 @@
1
  import gradio as gr
2
  import random
 
 
 
3
  from smolagents import GradioUI, CodeAgent, HfApiModel
 
 
 
 
 
4
 
5
  # Import our custom tools from their modules
6
  from tools import DuckDuckGoSearchTool, WeatherInfoTool, HubStatsTool
7
  from retriever import load_guest_dataset
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Initialize the Hugging Face model
10
  model = HfApiModel()
11
 
@@ -21,13 +119,102 @@ hub_stats_tool = HubStatsTool()
21
  # Load the guest dataset and initialize the guest info tool
22
  guest_info_tool = load_guest_dataset()
23
 
24
- # Create Alfred with all the tools
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  alfred = CodeAgent(
26
- tools=[guest_info_tool, weather_info_tool, hub_stats_tool, search_tool],
27
  model=model,
28
  add_base_tools=True, # Add any additional base tools
29
  planning_interval=3 # Enable planning every 3 steps
30
  )
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  if __name__ == "__main__":
33
- GradioUI(alfred).launch()
 
 
 
1
  import gradio as gr
2
  import random
3
+ import os
4
+ import tempfile
5
+ from pathlib import Path
6
  from smolagents import GradioUI, CodeAgent, HfApiModel
7
+ from sentence_transformers import SentenceTransformer
8
+ import chromadb
9
+ from unstructured.partition.auto import partition
10
+ import numpy as np
11
+ from typing import List, Optional
12
 
13
  # Import our custom tools from their modules
14
  from tools import DuckDuckGoSearchTool, WeatherInfoTool, HubStatsTool
15
  from retriever import load_guest_dataset
16
 
17
# Shared sentence-transformer used to embed document chunks and queries alike.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# In-memory ChromaDB client; the "documents" collection holds every uploaded chunk.
chroma_client = chromadb.Client()
collection = chroma_client.get_or_create_collection(name="documents")
23
+
24
class RAGDocumentProcessor:
    """Chunk documents, embed the chunks, and store/query them in ChromaDB.

    The processor holds no state beyond the injected embedding model and
    Chroma collection, so a single instance can serve every upload.
    """

    def __init__(self, embedding_model, collection):
        # embedding_model: SentenceTransformer-like object exposing .encode()
        # collection: ChromaDB collection exposing .add() and .query()
        self.embedding_model = embedding_model
        self.collection = collection

    @staticmethod
    def _chunk_elements(elements, max_chars: int = 1000) -> List[str]:
        """Greedily merge element texts into chunks of roughly max_chars.

        Elements are never split: a single element longer than max_chars
        becomes its own chunk.
        """
        chunks: List[str] = []
        current = ""
        for element in elements:
            text = str(element)
            if len(current + text) > max_chars:  # would overflow the chunk
                if current:
                    chunks.append(current.strip())
                current = text
            else:
                current += " " + text
        if current:  # flush the trailing partial chunk
            chunks.append(current.strip())
        return chunks

    def process_document(self, file_path: str) -> List[str]:
        """Parse a file with unstructured and return its text chunks."""
        return self._chunk_elements(partition(filename=file_path))

    def add_document_to_vector_store(self, file_path: str, filename: str) -> int:
        """Embed a document's chunks and add them to the vector store.

        Returns the number of chunks stored.

        Bug fix: chunk IDs previously used a hard-coded prefix instead of the
        filename, so every uploaded document produced the same IDs and
        collided in the collection. IDs are now derived from the filename.
        (Re-uploading the *same* filename still reuses its IDs.)
        """
        chunks = self.process_document(file_path)

        # Encode all chunks in one batch; Chroma expects plain lists.
        embeddings = self.embedding_model.encode(chunks).tolist()

        # One stable ID + metadata record per chunk.
        ids = [f"{filename}_{i}" for i in range(len(chunks))]
        metadatas = [{"filename": filename, "chunk_id": i} for i in range(len(chunks))]

        self.collection.add(
            embeddings=embeddings,
            documents=chunks,
            metadatas=metadatas,
            ids=ids,
        )

        return len(chunks)

    def semantic_search(self, query: str, n_results: int = 5) -> List[str]:
        """Return the n_results stored chunks most similar to the query."""
        query_embedding = self.embedding_model.encode([query]).tolist()

        results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=n_results,
        )

        # Chroma returns one document list per query embedding; we sent one.
        return results['documents'][0] if results['documents'] else []
81
+
82
+ # Initialize RAG processor
83
+ rag_processor = RAGDocumentProcessor(embedding_model, collection)
84
+
85
class SemanticRAGTool:
    """Tool for semantic retrieval and QA using uploaded documents"""

    name = "semantic_rag_search"
    description = "Search through uploaded documents using semantic similarity and provide context-aware responses"

    def __call__(self, query: str) -> str:
        """Look up the top document chunks for the query and format them as context."""
        # Delegate the vector lookup to the module-level RAG processor.
        hits = rag_processor.semantic_search(query, n_results=3)

        if not hits:
            return "No relevant documents found. Please upload documents first."

        context = "\n\n".join(hits)
        return (
            f"Based on the uploaded documents, here's the relevant information:\n\n"
            f"{context}\n\nThis information can help answer your query: {query}"
        )
103
+
104
+ # Initialize the semantic RAG tool
105
+ semantic_rag_tool = SemanticRAGTool()
106
+
107
  # Initialize the Hugging Face model
108
  model = HfApiModel()
109
 
 
119
  # Load the guest dataset and initialize the guest info tool
120
  guest_info_tool = load_guest_dataset()
121
 
122
def upload_and_process_file(file):
    """Handle a single Gradio file upload: chunk, embed, and index it.

    Args:
        file: Gradio file object exposing .name (the temp-file path),
            or None when nothing was uploaded.

    Returns:
        A human-readable status string. Errors are reported in the string
        rather than raised, so the UI always receives a message.
    """
    if file is None:
        return "No file uploaded."

    try:
        # Gradio exposes the uploaded temp file's path via .name.
        file_path = file.name
        filename = Path(file_path).name

        # Chunk, embed, and index the document.
        num_chunks = rag_processor.add_document_to_vector_store(file_path, filename)

        # Bug fix: the status message previously contained the literal
        # garbled text "(unknown)" instead of the actual filename.
        return f"Successfully processed '{filename}' into {num_chunks} chunks and added to vector store."

    except Exception as e:
        return f"Error processing file: {str(e)}"
139
+
140
# Assemble Alfred, the CodeAgent, with every tool including the new RAG tool.
alfred = CodeAgent(
    tools=[
        guest_info_tool,
        weather_info_tool,
        hub_stats_tool,
        search_tool,
        semantic_rag_tool,
    ],
    model=model,
    add_base_tools=True,   # also register the framework's base tools
    planning_interval=3,   # re-plan every 3 steps
)
147
 
148
+ # Create custom Gradio interface with file upload
149
def create_rag_interface():
    """Build the Gradio Blocks app: document upload on the left, chat on the right.

    Returns the assembled (un-launched) gr.Blocks demo.
    """

    with gr.Blocks(title="Production RAG Agent") as demo:
        gr.Markdown("# Production RAG Agent with Document Upload")
        gr.Markdown("Upload documents and ask questions using semantic search and AI reasoning.")

        with gr.Row():
            with gr.Column(scale=1):
                doc_files = gr.File(
                    label="Upload Documents",
                    file_types=[".pdf", ".docx", ".txt", ".md", ".html"],
                    file_count="multiple",
                )
                process_btn = gr.Button("Process Documents")
                status_box = gr.Textbox(
                    label="Upload Status",
                    interactive=False,
                )

            with gr.Column(scale=2):
                # Chat side: history view, message box, and action buttons.
                chat_history = gr.Chatbot(label="AI Assistant")
                user_msg = gr.Textbox(
                    label="Message",
                    placeholder="Ask questions about uploaded documents or anything else...",
                )
                submit_btn = gr.Button("Send")
                reset_btn = gr.Button("Clear")

        def _process_all(files):
            # Index every selected file; one status line per file.
            if not files:
                return "No files selected."
            return "\n".join(upload_and_process_file(f) for f in files)

        process_btn.click(
            fn=_process_all,
            inputs=[doc_files],
            outputs=[status_box],
        )

        def respond(message, history):
            """Run Alfred on the message and append the exchange to the history."""
            try:
                answer = alfred.run(message)
                history.append((message, str(answer)))
            except Exception as exc:
                # Surface failures in the chat instead of crashing the UI.
                history.append((message, f"Error: {str(exc)}"))
            return history, ""

        submit_btn.click(
            respond,
            inputs=[user_msg, chat_history],
            outputs=[chat_history, user_msg],
        )

        user_msg.submit(
            respond,
            inputs=[user_msg, chat_history],
            outputs=[chat_history, user_msg],
        )

        reset_btn.click(
            lambda: ([], ""),
            outputs=[chat_history, user_msg],
        )

    return demo
216
+
217
if __name__ == "__main__":
    # Entry point: build the enhanced RAG interface and serve it.
    app = create_rag_interface()
    app.launch()