Ara Yeroyan committed
Commit 3fc1b5f · 1 Parent(s): f8a1d41

finalize gemini version

Files changed (2):

  1. app.py +60 -7
  2. src/vectorstore.py +27 -8
app.py CHANGED
```diff
@@ -33,6 +33,7 @@ from src.config.paths import (
     CONVERSATIONS_DIR,
 )
 
+
 # ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY other imports =====
 # Some libraries load at import time and will fail if OMP_NUM_THREADS is invalid
 omp_threads = os.environ.get("OMP_NUM_THREADS", "")
```
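The context around this hunk sanitizes `OMP_NUM_THREADS` before any thread-sensitive library is imported, since libraries that read the variable at import time fail on an invalid value. The actual guard is outside this diff; a minimal sketch of the pattern, with the fallback value `"1"` as an assumption:

```python
import os

# A minimal sketch (not the app's actual guard, which is outside this hunk):
# sanitize OMP_NUM_THREADS before importing libraries that read it at import
# time (torch, numpy, etc.). The fallback "1" is an assumption.
omp_threads = os.environ.get("OMP_NUM_THREADS", "")
if not omp_threads.isdigit() or int(omp_threads) < 1:
    os.environ["OMP_NUM_THREADS"] = "1"

import torch  # safe: the env var is now guaranteed to be a valid positive int
```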
```diff
@@ -72,6 +73,9 @@ if IS_DEPLOYED and HF_CACHE_DIR:
     except (PermissionError, OSError):
         # If we can't create it, log but continue (might already exist from Dockerfile)
         pass
+else:
+    from dotenv import load_dotenv
+    load_dotenv()
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
```
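The new `else:` branch loads a local `.env` only when the app is not deployed, so secrets such as `GEMINI_API_KEY` come from the hosting environment in production and from a `.env` file in development. A sketch of the same pattern; how the app actually derives `IS_DEPLOYED` is outside this diff, so a `SPACE_ID`-style check stands in for it here:

```python
import os

# Sketch: IS_DEPLOYED is derived elsewhere in the app; a SPACE_ID-style env
# var is assumed here as a stand-in for "running on a hosted platform".
IS_DEPLOYED = bool(os.environ.get("SPACE_ID"))

if not IS_DEPLOYED:
    # Local development: pull GEMINI_API_KEY etc. from a .env file.
    from dotenv import load_dotenv
    load_dotenv()

api_key = os.environ.get("GEMINI_API_KEY")  # None if neither source set it
```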
```diff
@@ -191,18 +195,35 @@ def main():
     if 'chatbot_version' not in st.session_state:
         st.session_state.chatbot_version = "v1"
 
-    # Initialize chatbot based on version (reinitialize if version changes)
+    # Initialize chatbot based on version (only if not already initialized for this version)
     chatbot_version_key = f"chatbot_{st.session_state.chatbot_version}"
-    if chatbot_version_key not in st.session_state or st.session_state.get('_last_version') != st.session_state.chatbot_version:
+
+    # Check if we need to initialize: chatbot doesn't exist OR version changed
+    needs_init = (
+        chatbot_version_key not in st.session_state or
+        st.session_state.get('_last_version') != st.session_state.chatbot_version
+    )
+
+    if needs_init:
         try:
-            with st.spinner("🔄 Loading AI models and connecting to database..."):
+            # Different spinner messages for different versions
+            if st.session_state.chatbot_version == "beta":
+                spinner_msg = "🔄 Initializing Gemini File Search..."
+            else:
+                spinner_msg = "🔄 Loading AI models and connecting to database..."
+
+            with st.spinner(spinner_msg):
                 st.session_state[chatbot_version_key] = get_chatbot(st.session_state.chatbot_version)
                 st.session_state['_last_version'] = st.session_state.chatbot_version
                 st.session_state.chatbot = st.session_state[chatbot_version_key]
-            st.success("✅ AI system ready!")
+            print("✅ AI system ready!")
         except Exception as e:
             st.error(f"❌ Failed to initialize chatbot: {str(e)}")
-            st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
+            # Only show Gemini-specific error message for beta version
+            if st.session_state.chatbot_version == "beta":
+                st.error("Please check your environment variables (GEMINI_API_KEY, GEMINI_FILESTORE_NAME for beta)")
+            else:
+                st.error("Please check your configuration and ensure all required models and databases are accessible.")
             # Reset to v1 to prevent infinite loop
             st.session_state.chatbot_version = "v1"
             st.session_state['_last_version'] = "v1"
```
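The `needs_init` refactor makes initialization idempotent across Streamlit reruns: each version gets its own session-state key, and an already-built chatbot is never rebuilt. The same cache-per-key idea, stripped of Streamlit, looks roughly like this; `get_chatbot` below is a stub standing in for the app's real factory:

```python
# Sketch of the caching contract, with a plain dict in place of
# st.session_state and a stub in place of the app's get_chatbot factory.
session_state = {}

def get_chatbot(version: str) -> str:
    print(f"building chatbot for {version}")  # expensive in the real app
    return f"<chatbot:{version}>"

def ensure_chatbot(version: str) -> str:
    key = f"chatbot_{version}"
    needs_init = (
        key not in session_state or
        session_state.get("_last_version") != version
    )
    if needs_init:
        session_state[key] = get_chatbot(version)
        session_state["_last_version"] = version
    return session_state[key]

ensure_chatbot("v1")    # builds
ensure_chatbot("v1")    # cached, no rebuild
ensure_chatbot("beta")  # builds the second version alongside the first
```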
```diff
@@ -210,6 +231,7 @@ def main():
             del st.session_state['chatbot']
             st.stop()  # Stop execution to prevent infinite loop
     else:
+        # Chatbot already initialized for this version, just use it
         st.session_state.chatbot = st.session_state[chatbot_version_key]
 
     # Reset conversation history if needed (but keep chatbot cached)
```
```diff
@@ -223,7 +245,38 @@ def main():
         st.rerun()
 
 
-    st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
+    # Version selection radio button (top right)
+    col1, col2 = st.columns([3, 1])
+    with col1:
+        st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
+    with col2:
+        st.markdown("<br>", unsafe_allow_html=True)  # Add some spacing
+        selected_version = st.radio(
+            "**Version:**",
+            options=["v1", "beta"],
+            index=0 if st.session_state.chatbot_version == "v1" else 1,
+            horizontal=True,
+            key="version_selector",
+            help="Select v1 (default RAG system) or beta (Gemini File Search)"
+        )
+
+    # Update version if changed
+    if selected_version != st.session_state.chatbot_version:
+        # Store the old version to check if we need to switch
+        old_version = st.session_state.chatbot_version
+        st.session_state.chatbot_version = selected_version
+
+        # If chatbot for new version already exists, just switch to it
+        new_chatbot_key = f"chatbot_{selected_version}"
+        if new_chatbot_key in st.session_state:
+            # Chatbot already exists, just switch
+            st.session_state.chatbot = st.session_state[new_chatbot_key]
+            st.session_state['_last_version'] = selected_version
+        else:
+            # Need to initialize new version - will be handled by initialization logic above
+            st.session_state['_last_version'] = old_version  # Set to old to trigger init check
+
+        st.rerun()
 
     # Show version info
     if st.session_state.chatbot_version == "beta":
```
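The radio handler distinguishes two switch paths: if the target version's chatbot is already cached it swaps instantly, otherwise it deliberately leaves `_last_version` stale so the `needs_init` check above fires on the next rerun. A condensed sketch of that decision, again with a plain dict standing in for `st.session_state`:

```python
# Sketch: the two switch paths, with a dict in place of st.session_state.
def switch_version(state: dict, selected: str) -> None:
    old = state["chatbot_version"]
    if selected == old:
        return  # nothing to do
    state["chatbot_version"] = selected
    key = f"chatbot_{selected}"
    if key in state:
        # Already built: swap instantly and mark the version as current.
        state["chatbot"] = state[key]
        state["_last_version"] = selected
    else:
        # Not built yet: keep _last_version stale so needs_init
        # triggers initialization on the next rerun.
        state["_last_version"] = old

state = {"chatbot_version": "v1", "chatbot_v1": "<bot v1>",
         "chatbot": "<bot v1>", "_last_version": "v1"}
switch_version(state, "beta")
assert state["_last_version"] == "v1"  # stale on purpose: init will run
```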
```diff
@@ -289,7 +342,7 @@ def main():
         # Determine if filename filter is active
         filename_mode = len(selected_filenames) > 0
         # Sources filter
-        st.markdown('<div class="filter-section">', unsafe_allow_html=True)
+        # st.markdown('<div class="filter-section">', unsafe_allow_html=True)
         st.markdown('<div class="filter-title">📊 Sources</div>', unsafe_allow_html=True)
         selected_sources = st.multiselect(
             "Select sources:",
```
src/vectorstore.py CHANGED
```diff
@@ -1,9 +1,20 @@
 """Vector store management and operations."""
+import os
+# Disable MPS before importing torch to prevent meta tensor issues on Mac
+os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
+os.environ.setdefault("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0")
+
 from pathlib import Path
 from typing import Dict, Any, List, Optional
 
 
 import torch
+# Disable MPS backend explicitly to prevent meta tensor issues
+if hasattr(torch.backends, 'mps'):
+    # Monkey patch to disable MPS
+    original_mps_available = torch.backends.mps.is_available
+    torch.backends.mps.is_available = lambda: False
+
 from langchain_qdrant import QdrantVectorStore
 from langchain.docstore.document import Document
 from langchain_core.embeddings import Embeddings
```
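The module-level patch makes `torch.backends.mps.is_available()` report `False` before any downstream library can auto-select the MPS device. A quick sketch of the effect; the env vars are real PyTorch knobs, but treating this combination as sufficient to avoid meta tensor errors is this commit's working assumption:

```python
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

if hasattr(torch.backends, "mps"):
    torch.backends.mps.is_available = lambda: False  # monkey patch, as in the diff

# Downstream device auto-detection now falls back to CPU on a Mac:
device = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu"
)
print(device)  # "cpu" on a Mac after the patch (or "cuda" on a GPU box)
```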
```diff
@@ -28,19 +39,23 @@ class MatryoshkaEmbeddings(Embeddings):
 
         if truncate_dim and "matryoshka" in model_name.lower():
             # Use SentenceTransformer directly for Matryoshka models
-            # Explicitly load on CPU first to avoid meta tensor issues
+            # Fix for meta tensor issue: Explicitly force CPU
+            # MPS is already disabled at module level
+            # Explicitly pass device="cpu" to prevent MPS/CUDA detection
             self.model = SentenceTransformer(
                 model_name,
                 truncate_dim=truncate_dim,
-                device="cpu"  # Load on CPU first, prevents meta tensor error
+                device="cpu"  # Force CPU to prevent meta tensor issues
             )
             print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
         else:
             # Use standard HuggingFaceEmbeddings
-            # Pass device="cpu" to prevent meta tensor issues
+            # Don't pass device parameter - let it load naturally on CPU
+            # This prevents the meta tensor error
             if "model_kwargs" not in kwargs:
                 kwargs["model_kwargs"] = {}
-            kwargs["model_kwargs"]["device"] = "cpu"
+            # Remove device from model_kwargs if present to prevent meta tensor issues
+            kwargs["model_kwargs"].pop("device", None)
             self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
```
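For context, `truncate_dim` is the sentence-transformers mechanism this class relies on: the model emits full-width embeddings and the library truncates them to the requested Matryoshka dimension. A hedged usage sketch; the checkpoint name below is an illustrative assumption, any Matryoshka-trained model works:

```python
from sentence_transformers import SentenceTransformer

# The model name is an assumption: substitute the repo's actual checkpoint.
model = SentenceTransformer(
    "nomic-ai/nomic-embed-text-v1.5",  # example Matryoshka-trained model
    truncate_dim=256,                  # keep only the first 256 dimensions
    device="cpu",                      # mirror the diff's CPU-forcing fix
    trust_remote_code=True,            # this particular checkpoint needs it
)
vec = model.encode("audit finding about procurement")
assert len(vec) == 256  # truncated, not the model's native width
```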
```diff
@@ -87,12 +102,14 @@ class VectorStoreManager:
         model_name = self.config["retriever"]["model"]
         normalize = self.config["retriever"]["normalize"]
 
-        # Fix for meta tensor issue: explicitly load on CPU first
-        # This prevents HuggingFaceEmbeddings from trying to move meta tensors
-        # The model will be loaded on CPU and can be moved later if needed
+        # Fix for meta tensor issue: Force CPU usage to prevent MPS/CUDA detection
+        # The error occurs when SentenceTransformer detects MPS/CUDA and tries to move meta tensors
+        # MPS is already disabled at module level, now we explicitly force CPU in model_kwargs
         model_kwargs = {
-            "device": "cpu"  # Load on CPU first to avoid meta tensor issues
+            "device": "cpu",  # Explicitly force CPU to prevent MPS/CUDA detection
+            "trust_remote_code": True,  # Some models need this
         }
+
         encode_kwargs = {
             "normalize_embeddings": normalize,
             "batch_size": 100,
```
```diff
@@ -119,6 +136,8 @@ class VectorStoreManager:
             return embeddings
 
         # Use standard HuggingFaceEmbeddings for non-Matryoshka models
+        # Don't pass device in model_kwargs - let HuggingFaceEmbeddings handle it
+        # but ensure we're not using meta device
         embeddings = HuggingFaceEmbeddings(
             model_name=model_name,
             model_kwargs=model_kwargs,
```