MerveA commited on
Commit
27cfd4d
·
1 Parent(s): e649674

Fix langchain dependency for HF Space

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. README.md +1 -1
  3. app.py +201 -96
  4. requirements.txt +2 -1
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: "InsightRAG_Chatbot"
3
  emoji: "🤖"
4
  colorFrom: "purple"
5
  colorTo: "blue"
 
1
  ---
2
+ title: "RAG Chatbot: ML/AI Assistant"
3
  emoji: "🤖"
4
  colorFrom: "purple"
5
  colorTo: "blue"
app.py CHANGED
@@ -1,10 +1,23 @@
1
  import streamlit as st
2
  import os
3
  import json
 
 
 
 
 
4
  import time
5
  from datetime import datetime
 
 
 
 
 
 
 
 
 
6
 
7
- # Page configuration
8
  st.set_page_config(
9
  page_title="🤖 RAG Chatbot: ML/AI Assistant",
10
  page_icon="🤖",
@@ -12,13 +25,8 @@ st.set_page_config(
12
  initial_sidebar_state="expanded"
13
  )
14
 
15
- # Initialize session state immediately
16
- if 'messages' not in st.session_state:
17
- st.session_state.messages = []
18
- if 'rag_system' not in st.session_state:
19
- st.session_state.rag_system = None
20
- if 'initialized' not in st.session_state:
21
- st.session_state.initialized = False
22
 
23
  # Custom CSS for better styling
24
  st.markdown("""
@@ -58,24 +66,88 @@ st.markdown("""
58
  </style>
59
  """, unsafe_allow_html=True)
60
 
61
- # RAG System Functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def initialize_rag_system(api_key):
63
  """Initialize the RAG system with all components"""
64
  try:
65
  # Set API key
66
  os.environ['GOOGLE_API_KEY'] = api_key
67
 
68
- # Import required libraries with error handling
69
- try:
70
- from sentence_transformers import SentenceTransformer
71
- import chromadb
72
- from chromadb.config import Settings
73
- import google.generativeai as genai
74
- import re
75
- except ImportError as e:
76
- st.error(f"Import error: {e}")
77
- return None
78
-
79
  # Initialize embedding model
80
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
81
 
@@ -88,117 +160,148 @@ def initialize_rag_system(api_key):
88
  collection_name = "ml_ai_knowledge"
89
  try:
90
  collection = chroma_client.get_collection(collection_name)
 
91
  except:
92
  collection = chroma_client.create_collection(
93
  name=collection_name,
94
- metadata={"description": "ML/AI knowledge base"}
95
  )
 
96
 
97
  # Check if collection already has data
98
  existing_count = collection.count()
 
99
 
100
  if existing_count == 0:
101
- # Load sample data for demo
102
- sample_texts = [
103
- "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data. Deep learning uses neural networks with multiple layers to process complex patterns in data.",
104
- "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes that process information using a connectionist approach.",
105
- "Supervised learning uses labeled training data to learn a mapping from inputs to outputs. Common algorithms include linear regression, decision trees, and support vector machines.",
106
- "Unsupervised learning finds hidden patterns in data without labeled examples. Clustering algorithms like K-means group similar data points together.",
107
- "Natural language processing combines computational linguistics with machine learning to help computers understand human language. It includes tasks like text classification and sentiment analysis.",
108
- "Computer vision enables machines to interpret and understand visual information from the world. It uses deep learning models like convolutional neural networks.",
109
- "Reinforcement learning is a type of machine learning where agents learn to make decisions by interacting with an environment and receiving rewards or penalties.",
110
- "Feature engineering is the process of selecting and transforming raw data into features that can be used by machine learning algorithms. Good features can significantly improve model performance.",
111
- "Cross-validation is a technique used to assess how well a machine learning model generalizes to new data. It involves splitting data into training and validation sets multiple times.",
112
- "Overfitting occurs when a model learns the training data too well and performs poorly on new data. Regularization techniques help prevent overfitting.",
113
- "Gradient descent is an optimization algorithm used to minimize the cost function in machine learning models. It iteratively adjusts parameters to find the minimum of the function.",
114
- "Backpropagation is a method used to train neural networks by calculating gradients and updating weights. It works by propagating errors backward through the network layers.",
115
- "Convolutional Neural Networks (CNNs) are specialized neural networks designed for processing grid-like data such as images. They use convolutional layers to detect local features.",
116
- "Transformers are a type of neural network architecture that uses attention mechanisms to process sequential data. They are the foundation of modern language models like GPT.",
117
- "Large Language Models (LLMs) are AI systems trained on vast amounts of text data to understand and generate human-like text. They can perform various language tasks.",
118
- "Generative AI refers to AI systems that can create new content, such as text, images, or code. It differs from predictive AI which focuses on making predictions.",
119
- "Transfer learning is a technique where a model trained on one task is adapted for a different but related task. It can significantly reduce training time and improve performance.",
120
- "Hyperparameter tuning is the process of finding the optimal hyperparameters for a machine learning model. Common methods include grid search and random search.",
121
- "Regularization techniques like L1 and L2 regularization help prevent overfitting by adding penalty terms to the loss function. They encourage simpler models.",
122
- "Activation functions introduce non-linearity into neural networks. Common activation functions include ReLU, sigmoid, and tanh."
123
- ]
124
 
125
- # Add sample documents to Chroma
126
  all_chunks = []
127
  chunk_ids = []
128
  chunk_metadatas = []
129
 
130
- for i, text in enumerate(sample_texts):
131
- chunk_id = f"sample_doc_{i}"
132
- metadata = {
133
- "source": f"sample_doc_{i}",
134
- "chunk_index": 0,
135
- "total_chunks": 1,
136
- "text_length": len(text)
137
- }
138
 
139
- all_chunks.append(text)
140
- chunk_ids.append(chunk_id)
141
- chunk_metadatas.append(metadata)
 
 
 
 
 
 
 
 
 
142
 
143
- # Add documents to Chroma
144
- collection.add(
145
- documents=all_chunks,
146
- ids=chunk_ids,
147
- metadatas=chunk_metadatas
148
- )
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
- # Initialize Gemini using direct API
151
- genai.configure(api_key=api_key)
 
 
 
 
 
152
 
153
  return {
154
  'embedding_model': embedding_model,
155
  'chroma_client': chroma_client,
156
  'collection': collection,
157
- 'genai': genai
158
  }
159
  except Exception as e:
160
  st.error(f"Error initializing RAG system: {e}")
161
  return None
162
 
163
- def rag_pipeline(query, rag_system, n_results=5):
164
- """Complete RAG pipeline using direct Gemini API"""
165
  try:
166
- collection = rag_system['collection']
167
- genai = rag_system['genai']
168
-
169
- # Retrieve relevant documents
170
  results = collection.query(
171
  query_texts=[query],
172
  n_results=n_results
173
  )
174
 
 
175
  documents = results['documents'][0]
 
176
  distances = results['distances'][0]
177
 
178
- if not documents:
179
- return "I couldn't find relevant information for your query. Please try asking about machine learning, deep learning, or AI topics."
180
-
181
- # Create context
182
- context = "\n\n".join(documents)
183
-
184
- # Generate answer using direct Gemini API
185
- model = genai.GenerativeModel('gemini-2.0-flash-exp')
186
-
187
- prompt = f"""You are an AI assistant specialized in machine learning, deep learning, and artificial intelligence.
188
- Use the provided context to answer questions accurately and comprehensively. If the context doesn't contain enough
189
- information, you can supplement with your general knowledge, but always prioritize the provided context.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
- Provide clear, well-structured answers with examples when appropriate.
 
 
 
 
 
 
 
 
 
192
 
193
- Context:
194
- {context}
195
 
196
- Question: {query}
 
197
 
198
- Please provide a comprehensive answer based on the context above."""
 
199
 
200
- response = model.generate_content(prompt)
201
- return response.text, documents, distances
 
202
 
203
  except Exception as e:
204
  return f"Error generating response: {e}", [], []
@@ -207,7 +310,7 @@ def rag_pipeline(query, rag_system, n_results=5):
207
  st.markdown("""
208
  <div class="main-header">
209
  <h1>🤖 RAG Chatbot: ML/AI Assistant</h1>
210
- <p>Powered by Google Gemini 2.5 Flash + Chroma + Direct API</p>
211
  </div>
212
  """, unsafe_allow_html=True)
213
 
@@ -221,7 +324,8 @@ with st.sidebar:
221
  type="password",
222
  help="Get your API key from Google AI Studio"
223
  )
224
-
 
225
  if api_key:
226
  os.environ['GOOGLE_API_KEY'] = api_key
227
 
@@ -281,13 +385,14 @@ if not st.session_state.initialized:
281
  deep learning, AI, and related topics using:
282
 
283
  - **🤖 Generation Model**: Google Gemini 2.5 Flash
 
284
  - **🗄️ Vector Database**: Chroma
285
- - **📚 Dataset**: Sample ML/AI knowledge base
286
  - **🌐 Interface**: Streamlit
287
 
288
  ### 🚀 How It Works
289
 
290
- 1. **Data Loading**: Sample ML/AI content is loaded
291
  2. **Embedding**: Text is processed and embedded using sentence transformers
292
  3. **Storage**: Embeddings are stored in Chroma vector database
293
  4. **Retrieval**: Relevant context is retrieved for user queries
@@ -360,7 +465,7 @@ else:
360
  st.markdown("---")
361
  st.markdown("""
362
  <div style="text-align: center; color: #666; padding: 1rem;">
363
- <p>🤖 RAG Chatbot | Powered by Google Gemini 2.5 Flash + Chroma</p>
364
- <p>📚 Knowledge Base: ML/AI Sample Dataset</p>
365
  </div>
366
  """, unsafe_allow_html=True)
 
1
  import streamlit as st
2
  import os
3
  import json
4
+ import chromadb
5
+ from chromadb.config import Settings
6
+ from sentence_transformers import SentenceTransformer
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain.schema import HumanMessage, SystemMessage
9
  import time
10
  from datetime import datetime
11
+ import uuid
12
+ import pandas as pd
13
+ import numpy as np
14
+ from datasets import load_dataset
15
+ from tqdm import tqdm
16
+ import re
17
+ from dotenv import load_dotenv
18
+ import os
19
+
20
 
 
21
  st.set_page_config(
22
  page_title="🤖 RAG Chatbot: ML/AI Assistant",
23
  page_icon="🤖",
 
25
  initial_sidebar_state="expanded"
26
  )
27
 
28
+ load_dotenv()
29
+ api_key = os.environ.get("GOOGLE_API_KEY")
 
 
 
 
 
30
 
31
  # Custom CSS for better styling
32
  st.markdown("""
 
66
  </style>
67
  """, unsafe_allow_html=True)
68
 
69
+ # Initialize session state
70
+ if 'messages' not in st.session_state:
71
+ st.session_state.messages = []
72
+ if 'rag_system' not in st.session_state:
73
+ st.session_state.rag_system = None
74
+ if 'initialized' not in st.session_state:
75
+ st.session_state.initialized = False
76
+
77
+ # RAG System Functions (from notebook)
78
+ def chunk_text(text, chunk_size=500, overlap=50):
79
+ """Split text into overlapping chunks"""
80
+ words = text.split()
81
+ chunks = []
82
+
83
+ for i in range(0, len(words), chunk_size - overlap):
84
+ chunk = ' '.join(words[i:i + chunk_size])
85
+ if len(chunk.strip()) > 50: # Only keep substantial chunks
86
+ chunks.append(chunk)
87
+
88
+ return chunks
89
+
90
+ def load_and_process_dataset():
91
+ """Load and process The Pile dataset"""
92
+ print("📚 Loading The Pile dataset...")
93
+
94
+ try:
95
+ # Load a specific subset that contains ML/AI content
96
+ dataset = load_dataset("EleutherAI/the_pile", split="train", streaming=True)
97
+
98
+ # Take first 1000 samples for demonstration
99
+ texts = []
100
+ ml_keywords = ['machine learning', 'deep learning', 'neural network', 'artificial intelligence',
101
+ 'algorithm', 'model', 'training', 'data', 'feature', 'classification',
102
+ 'regression', 'clustering', 'optimization', 'gradient', 'tensor']
103
+
104
+ print("🔍 Filtering ML/AI related content...")
105
+ count = 0
106
+ for sample in tqdm(dataset, desc="Processing samples"):
107
+ if count >= 1000: # Limit to 1000 samples for demo
108
+ break
109
+
110
+ text = sample['text']
111
+ # Check if text contains ML/AI keywords
112
+ if any(keyword in text.lower() for keyword in ml_keywords):
113
+ # Clean and preprocess text
114
+ text = re.sub(r'\s+', ' ', text) # Remove extra whitespace
115
+ text = text.strip()
116
+
117
+ # Only keep texts that are reasonable length (not too short or too long)
118
+ if 100 <= len(text) <= 2000:
119
+ texts.append(text)
120
+ count += 1
121
+
122
+ print(f"✅ Loaded {len(texts)} ML/AI related text samples")
123
+ return texts
124
+
125
+ except Exception as e:
126
+ print(f"❌ Error loading dataset: {e}")
127
+ print("🔄 Using fallback sample data...")
128
+
129
+ # Fallback sample data if The Pile is not accessible
130
+ texts = [
131
+ "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data. Deep learning uses neural networks with multiple layers to process complex patterns in data.",
132
+ "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes that process information using a connectionist approach.",
133
+ "Supervised learning uses labeled training data to learn a mapping from inputs to outputs. Common algorithms include linear regression, decision trees, and support vector machines.",
134
+ "Unsupervised learning finds hidden patterns in data without labeled examples. Clustering algorithms like K-means group similar data points together.",
135
+ "Natural language processing combines computational linguistics with machine learning to help computers understand human language. It includes tasks like text classification and sentiment analysis.",
136
+ "Computer vision enables machines to interpret and understand visual information from the world. It uses deep learning models like convolutional neural networks.",
137
+ "Reinforcement learning is a type of machine learning where agents learn to make decisions by interacting with an environment and receiving rewards or penalties.",
138
+ "Feature engineering is the process of selecting and transforming raw data into features that can be used by machine learning algorithms. Good features can significantly improve model performance.",
139
+ "Cross-validation is a technique used to assess how well a machine learning model generalizes to new data. It involves splitting data into training and validation sets multiple times.",
140
+ "Overfitting occurs when a model learns the training data too well and performs poorly on new data. Regularization techniques help prevent overfitting."
141
+ ]
142
+ print(f"✅ Using {len(texts)} sample texts")
143
+ return texts
144
+
145
  def initialize_rag_system(api_key):
146
  """Initialize the RAG system with all components"""
147
  try:
148
  # Set API key
149
  os.environ['GOOGLE_API_KEY'] = api_key
150
 
 
 
 
 
 
 
 
 
 
 
 
151
  # Initialize embedding model
152
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
153
 
 
160
  collection_name = "ml_ai_knowledge"
161
  try:
162
  collection = chroma_client.get_collection(collection_name)
163
+ print(f"✅ Found existing collection: {collection_name}")
164
  except:
165
  collection = chroma_client.create_collection(
166
  name=collection_name,
167
+ metadata={"description": "ML/AI knowledge base from The Pile dataset"}
168
  )
169
+ print(f"✅ Created new collection: {collection_name}")
170
 
171
  # Check if collection already has data
172
  existing_count = collection.count()
173
+ print(f"📊 Current documents in collection: {existing_count}")
174
 
175
  if existing_count == 0:
176
+ print("🔄 Adding new documents to collection...")
177
+
178
+ # Load and process dataset
179
+ texts = load_and_process_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
 
181
  all_chunks = []
182
  chunk_ids = []
183
  chunk_metadatas = []
184
 
185
+ for i, text in enumerate(tqdm(texts, desc="Processing texts")):
186
+ chunks = chunk_text(text)
 
 
 
 
 
 
187
 
188
+ for j, chunk in enumerate(chunks):
189
+ chunk_id = f"doc_{i}_chunk_{j}"
190
+ metadata = {
191
+ "source": f"the_pile_doc_{i}",
192
+ "chunk_index": j,
193
+ "total_chunks": len(chunks),
194
+ "text_length": len(chunk)
195
+ }
196
+
197
+ all_chunks.append(chunk)
198
+ chunk_ids.append(chunk_id)
199
+ chunk_metadatas.append(metadata)
200
 
201
+ print(f"📊 Created {len(all_chunks)} text chunks")
202
+
203
+ # Add documents to Chroma in batches to avoid memory issues
204
+ batch_size = 100
205
+ for i in tqdm(range(0, len(all_chunks), batch_size), desc="Adding to Chroma"):
206
+ batch_chunks = all_chunks[i:i + batch_size]
207
+ batch_ids = chunk_ids[i:i + batch_size]
208
+ batch_metadatas = chunk_metadatas[i:i + batch_size]
209
+
210
+ collection.add(
211
+ documents=batch_chunks,
212
+ ids=batch_ids,
213
+ metadatas=batch_metadatas
214
+ )
215
+
216
+ print("✅ All documents added to Chroma!")
217
+ else:
218
+ print("✅ Collection already contains data, skipping addition")
219
 
220
+ # Initialize Gemini
221
+ llm = ChatGoogleGenerativeAI(
222
+ model="gemini-2.0-flash-exp",
223
+ temperature=0.7,
224
+ max_output_tokens=1024,
225
+ convert_system_message_to_human=True
226
+ )
227
 
228
  return {
229
  'embedding_model': embedding_model,
230
  'chroma_client': chroma_client,
231
  'collection': collection,
232
+ 'llm': llm
233
  }
234
  except Exception as e:
235
  st.error(f"Error initializing RAG system: {e}")
236
  return None
237
 
238
+ def retrieve_relevant_docs(query, collection, n_results=5):
239
+ """Retrieve relevant documents from Chroma"""
240
  try:
 
 
 
 
241
  results = collection.query(
242
  query_texts=[query],
243
  n_results=n_results
244
  )
245
 
246
+ # Extract documents and metadata
247
  documents = results['documents'][0]
248
+ metadatas = results['metadatas'][0]
249
  distances = results['distances'][0]
250
 
251
+ return documents, metadatas, distances
252
+ except Exception as e:
253
+ print(f"Error retrieving documents: {e}")
254
+ return [], [], []
255
+
256
+ def create_context(documents):
257
+ """Create context string from retrieved documents"""
258
+ context = "\n\n".join(documents)
259
+ return context
260
+
261
+ def generate_answer(query, context, llm):
262
+ """Generate answer using Gemini with retrieved context"""
263
+ system_prompt = """You are an AI assistant specialized in machine learning, deep learning, and artificial intelligence.
264
+ Use the provided context to answer questions accurately and comprehensively. If the context doesn't contain enough
265
+ information, you can supplement with your general knowledge, but always prioritize the provided context.
266
+
267
+ Provide clear, well-structured answers with examples when appropriate."""
268
+
269
+ user_prompt = f"""Context:
270
+ {context}
271
+
272
+ Question: {query}
273
+
274
+ Please provide a comprehensive answer based on the context above."""
275
+
276
+ try:
277
+ messages = [
278
+ SystemMessage(content=system_prompt),
279
+ HumanMessage(content=user_prompt)
280
+ ]
281
 
282
+ response = llm.invoke(messages)
283
+ return response.content
284
+ except Exception as e:
285
+ return f"Error generating answer: {e}"
286
+
287
+ def rag_pipeline(query, rag_system, n_results=5):
288
+ """Complete RAG pipeline"""
289
+ try:
290
+ collection = rag_system['collection']
291
+ llm = rag_system['llm']
292
 
293
+ # Retrieve relevant documents
294
+ documents, metadatas, distances = retrieve_relevant_docs(query, collection, n_results)
295
 
296
+ if not documents:
297
+ return "I couldn't find relevant information for your query. Please try asking about machine learning, deep learning, or AI topics."
298
 
299
+ # Create context
300
+ context = create_context(documents)
301
 
302
+ # Generate answer
303
+ answer = generate_answer(query, context, llm)
304
+ return answer, documents, distances
305
 
306
  except Exception as e:
307
  return f"Error generating response: {e}", [], []
 
310
  st.markdown("""
311
  <div class="main-header">
312
  <h1>🤖 RAG Chatbot: ML/AI Assistant</h1>
313
+ <p>Powered by Google Gemini 2.5 Flash + LangChain + Chroma</p>
314
  </div>
315
  """, unsafe_allow_html=True)
316
 
 
324
  type="password",
325
  help="Get your API key from Google AI Studio"
326
  )
327
+
328
+
329
  if api_key:
330
  os.environ['GOOGLE_API_KEY'] = api_key
331
 
 
385
  deep learning, AI, and related topics using:
386
 
387
  - **🤖 Generation Model**: Google Gemini 2.5 Flash
388
+ - **🔗 RAG Framework**: LangChain
389
  - **🗄️ Vector Database**: Chroma
390
+ - **📚 Dataset**: The Pile (EleutherAI/the_pile) from Hugging Face
391
  - **🌐 Interface**: Streamlit
392
 
393
  ### 🚀 How It Works
394
 
395
+ 1. **Data Loading**: Text data from The Pile dataset is loaded and filtered for ML/AI content
396
  2. **Embedding**: Text is processed and embedded using sentence transformers
397
  3. **Storage**: Embeddings are stored in Chroma vector database
398
  4. **Retrieval**: Relevant context is retrieved for user queries
 
465
  st.markdown("---")
466
  st.markdown("""
467
  <div style="text-align: center; color: #666; padding: 1rem;">
468
+ <p>🤖 RAG Chatbot | Powered by Google Gemini 2.5 Flash + LangChain + Chroma</p>
469
+ <p>📚 Knowledge Base: The Pile Dataset (EleutherAI/the_pile)</p>
470
  </div>
471
  """, unsafe_allow_html=True)
requirements.txt CHANGED
@@ -6,4 +6,5 @@ google-generativeai==0.3.2
6
  numpy==1.24.3
7
  pandas==2.0.3
8
  tqdm==4.66.1
9
- huggingface-hub>=0.16.4,<1.0.0
 
 
6
  numpy==1.24.3
7
  pandas==2.0.3
8
  tqdm==4.66.1
9
+ huggingface-hub>=0.16.4,<1.0.0
10
+ gradio