Update rag_engine.py
rag_engine.py (CHANGED, +19 -0)
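Summary of the change (derived from the hunks below): Streamlit caching decorators are added to the heavy functions (st.cache_resource for the embedding model and data files, st.cache_data with TTLs for GCS downloads, embeddings, retrieval, answers, and query processing), and explicit del/gc.collect() cleanup is added at the end of retrieve_passages, answer_with_llm, and process_query. A condensed sketch of the pattern follows the diff.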
@@ -51,6 +51,7 @@ def setup_openai_client():
         print(f"❌ OpenAI client initialization error: {str(e)}")
         return False
 
+@st.cache_resource
 def load_model():
     """Load the embedding model and store in session state"""
     try:
@@ -90,6 +91,7 @@ def load_model():
         # Return None values - don't raise exception
         return None, None
 
+@st.cache_data(ttl=3600)
 def download_file_from_gcs(bucket, gcs_path, local_path):
     """Download a file from GCS to local storage."""
     try:
@@ -106,6 +108,7 @@ def download_file_from_gcs(bucket, gcs_path, local_path):
         print(f"❌ Error downloading {gcs_path}: {str(e)}")
         return False
 
+@st.cache_resource
 def load_data_files():
     """Load FAISS index, text chunks, and metadata"""
     # Check if already loaded in session state
@@ -178,6 +181,7 @@ def average_pool(last_hidden_states, attention_mask):
 # Cache for query embeddings
 query_embedding_cache = {}
 
+@st.cache_data(ttl=1800)
 def get_embedding(text):
     """Generate embeddings for a text query"""
     # Check cache first
@@ -227,6 +231,7 @@ def get_embedding(text):
         print(f"❌ Embedding error: {str(e)}")
         return np.zeros((1, 384), dtype=np.float32)
 
+@st.cache_data(ttl=900)
 def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, similarity_threshold=0.5):
     """Retrieve top-k most relevant passages using FAISS with metadata."""
     try:
@@ -271,12 +276,17 @@ def retrieve_passages(query, faiss_index, text_chunks, metadata_dict, top_k=5, s
             if len(retrieved_passages) == top_k:
                 break
 
+        # Clean up
+        del query_embedding, distances, indices
+        gc.collect()
+
         print(f"Retrieved {len(retrieved_passages)} passages")
         return retrieved_passages, retrieved_sources
     except Exception as e:
         print(f"❌ Error in retrieve_passages: {str(e)}")
         return [], []
 
+@st.cache_data(ttl=1800)
 def answer_with_llm(query, context=None, word_limit=100):
     """Generate an answer using OpenAI GPT model with formatted citations."""
     try:
@@ -338,6 +348,10 @@ def answer_with_llm(query, context=None, word_limit=100):
         if not answer.endswith((".", "!", "?")):
             answer += "."
 
+        # Clean up
+        del response, formatted_context, system_message, user_message
+        gc.collect()
+
         return answer
 
     except Exception as e:
@@ -356,6 +370,7 @@ def format_citations(sources):
 
     return "\n".join(formatted_citations)
 
+@st.cache_data(ttl=3600)
 def process_query(query, top_k=5, word_limit=100):
     """Process a query through the RAG pipeline with proper formatting."""
     print(f"\n🔍 Processing query: {query}")
@@ -390,4 +405,8 @@ def process_query(query, top_k=5, word_limit=100):
     else:
         llm_answer_with_rag = "⚠️ No relevant context found."
 
+    # Clean up
+    del retrieved_context, retrieved_sources
+    gc.collect()
+
     return {"query": query, "answer_with_rag": llm_answer_with_rag, "citations": sources}
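The 19 added lines combine two techniques: Streamlit's caches (st.cache_resource for objects that are expensive and not serializable, such as the model and the FAISS data files; st.cache_data with TTLs for results that can be memoized per argument) so heavy work is not repeated on every rerun, plus manual del and gc.collect() passes to keep the Space's memory footprint bounded. The sketch below shows the pattern in isolation; it is a minimal illustration, not the actual rag_engine.py implementation, and the function names and bodies are placeholders (only the decorators, TTL values, and cleanup calls mirror the diff).

import gc

import numpy as np
import streamlit as st


@st.cache_resource
def load_model():
    """Heavy, non-serializable objects (embedding model, FAISS index) are
    built once per server process and shared across sessions and reruns."""
    return object()  # placeholder for the real model-loading code


@st.cache_data(ttl=1800)
def get_embedding(text: str) -> np.ndarray:
    """Serializable results are memoized per argument and expire after
    ttl seconds (30 minutes here, matching the decorator in the commit)."""
    model = load_model()  # cheap: returns the already-cached instance
    embedding = np.zeros((1, 384), dtype=np.float32)  # placeholder result

    # Explicit cleanup, as added to retrieve_passages, answer_with_llm, and
    # process_query in the diff: drop local temporaries and force a
    # garbage-collection pass before returning.
    del model
    gc.collect()
    return embedding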