ankanghosh committed
Commit 4b849de · verified · 1 Parent(s): 0795f5a

Update rag_engine.py
Files changed (1)
  1. rag_engine.py +19 -0
rag_engine.py CHANGED

@@ -51,6 +51,7 @@ def setup_openai_client():
         print(f"❌ OpenAI client initialization error: {str(e)}")
         return False
 
+@st.cache_resource
 def load_model():
     """Load the embedding model and store in session state"""
     try:
@@ -90,6 +91,7 @@ def load_model():
         # Return None values - don't raise exception
         return None, None
 
+@st.cache_data(ttl=3600)
 def download_file_from_gcs(bucket, gcs_path, local_path):
     """Download a file from GCS to local storage."""
     try:
@@ -106,6 +108,7 @@ def download_file_from_gcs(bucket, gcs_path, local_path):
         print(f"❌ Error downloading {gcs_path}: {str(e)}")
         return False
 
+@st.cache_resource
 def load_data_files():
     """Load FAISS index, text chunks, and metadata"""
     # Check if already loaded in session state
@@ -178,6 +181,7 @@ def average_pool(last_hidden_states, attention_mask):
 # Cache for query embeddings
 query_embedding_cache = {}
 
+@st.cache_data(ttl=1800)
 def get_embedding(text):
     """Generate embeddings for a text query"""
     # Check cache first
@@ -227,6 +231,7 @@ def get_embedding(text):
         print(f"❌ Embedding error: {str(e)}")
         return np.zeros((1, 384), dtype=np.float32)
 
+@st.cache_data(ttl=900)
 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""
     try:
@@ -271,12 +276,17 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
             if len(retrieved_passages) == top_k:
                 break
 
+        # Clean up
+        del query_embedding, distances, indices
+        gc.collect()
+
         print(f"Retrieved {len(retrieved_passages)} passages")
         return retrieved_passages, retrieved_sources
     except Exception as e:
         print(f"❌ Error in retrieve_passages: {str(e)}")
         return [], []
 
+@st.cache_data(ttl=1800)
 def answer_with_llm(query, context=None, word_limit=100):
     """Generate an answer using OpenAI GPT model with formatted citations."""
     try:
@@ -338,6 +348,10 @@ def answer_with_llm(query, context=None, word_limit=100):
         if not answer.endswith((".", "!", "?")):
             answer += "."
 
+        # Clean up
+        del response, formatted_context, system_message, user_message
+        gc.collect()
+
         return answer
 
     except Exception as e:
@@ -356,6 +370,7 @@ def format_citations(sources):
 
     return "\n".join(formatted_citations)
 
+@st.cache_data(ttl=3600)
 def process_query(query, top_k=5, word_limit=100):
     """Process a query through the RAG pipeline with proper formatting."""
     print(f"\n🔍 Processing query: {query}")
@@ -390,4 +405,8 @@ def process_query(query, top_k=5, word_limit=100):
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."
 
+    # Clean up
+    del retrieved_context, retrieved_sources
+    gc.collect()
+
     return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}
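
The decorators added here are Streamlit's two caching primitives: st.cache_resource keeps one shared instance of a heavyweight object (an embedding model, a FAISS index) alive across reruns, while st.cache_data(ttl=...) memoizes a function's return value per unique set of arguments and drops it after the given number of seconds. A minimal sketch of how the two behave together, using stand-in loader and embedding functions rather than the real ones in rag_engine.py:

import time
import numpy as np
import streamlit as st

@st.cache_resource  # one shared object per process, reused across reruns and sessions
def load_demo_model():
    # Stand-in for an expensive model load; rag_engine.py's load_model() builds the real embedding model.
    time.sleep(2)
    return {"name": "demo-embedder", "dim": 384}

@st.cache_data(ttl=1800)  # result memoized per unique `text`, discarded after 30 minutes
def demo_embedding(text: str) -> np.ndarray:
    model = load_demo_model()  # cheap after the first call: returns the cached instance
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.random((1, model["dim"]), dtype=np.float32)  # placeholder vector, not a real embedding

st.write(demo_embedding("What does the text say about dharma?").shape)  # (1, 384)

The ttl values chosen in the commit (3600, 1800, and 900 seconds) bound how long a cached download, embedding, retrieval, or generated answer is reused before it is recomputed.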
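st.cache_data derives its cache key by hashing the decorated function's arguments, so everything passed in must be hashable by Streamlit; the documented convention for live objects such as a FAISS index or a GCS bucket handle is to prefix those parameters with an underscore, which excludes them from the key. A hedged sketch of that convention (retrieve_cached and toy_search below are illustrative, not part of this commit):

import streamlit as st

def toy_search(index, chunks, query, top_k):
    # Hypothetical stand-in for the FAISS lookup inside retrieve_passages().
    return chunks[:top_k]

@st.cache_data(ttl=900)
def retrieve_cached(query: str, top_k: int, _index, _chunks):
    # Only `query` and `top_k` form the cache key; Streamlit skips the
    # underscore-prefixed parameters, so unhashable objects can be passed through.
    return toy_search(_index, _chunks, query, top_k)

st.write(retrieve_cached("example question", 2, None, ["chunk a", "chunk b", "chunk c"]))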