# RFP_Analyzer_Agent / utils/persistence.py
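"""Filesystem persistence for the RFP Analyzer Streamlit agent.

PersistenceManager saves and restores three kinds of per-session state under a
base data directory: FAISS vector stores, chat histories, and document chunks.
"""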
import json
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import faiss
import streamlit as st
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
class PersistenceManager:
def __init__(self, data_dir: str = "data"):
"""Initialize the persistence manager with paths for storing data.
Args:
data_dir: Base directory for data storage
"""
self.base_dir = Path(data_dir)
self.vector_store_dir = self.base_dir / "vector_stores"
self.chat_history_dir = self.base_dir / "chat_histories"
self.chunks_dir = self.base_dir / "chunks"
# Create necessary directories
for directory in [self.vector_store_dir, self.chat_history_dir, self.chunks_dir]:
directory.mkdir(parents=True, exist_ok=True)
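        # Resulting layout under data_dir:
        #   vector_stores/<session_id>/index.faiss, docstore.pkl
        #   chat_histories/<session_id>.json
        #   chunks/<session_id>_chunks.pkl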
    def save_vector_store(self, vector_store: Any, session_id: str) -> bool:
        """Save FAISS vector store and related metadata.
        Args:
            vector_store: FAISS vector store instance
            session_id: Unique identifier for the session
        """
        try:
            # Create session-specific directory
            store_path = self.vector_store_dir / session_id
            store_path.mkdir(exist_ok=True)
            # Save the raw FAISS index
            faiss.write_index(vector_store.index, str(store_path / "index.faiss"))
            # Save the document store together with the index-to-document-id
            # mapping; both are needed to rebuild the LangChain FAISS wrapper.
            with open(store_path / "docstore.pkl", "wb") as f:
                pickle.dump(
                    {
                        "docstore": vector_store.docstore,
                        "index_to_docstore_id": vector_store.index_to_docstore_id,
                    },
                    f,
                )
            return True
        except Exception as e:
            st.error(f"Error saving vector store: {str(e)}")
            return False
    def load_vector_store(self, session_id: str) -> Optional[Any]:
        """Load FAISS vector store and related metadata.
        Args:
            session_id: Unique identifier for the session
        """
        try:
            store_path = self.vector_store_dir / session_id
            if not store_path.exists():
                return None
            # Load the raw FAISS index
            index = faiss.read_index(str(store_path / "index.faiss"))
            # Load the document store and the index-to-document-id mapping
            with open(store_path / "docstore.pkl", "rb") as f:
                saved = pickle.load(f)
            # Recreate the LangChain vector store around the raw index
            from langchain_community.vectorstores import FAISS
            vector_store = FAISS(
                embedding_function=st.session_state.embeddings,
                index=index,
                docstore=saved["docstore"],
                index_to_docstore_id=saved["index_to_docstore_id"],
            )
            return vector_store
        except Exception as e:
            st.error(f"Error loading vector store: {str(e)}")
            return None
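    # Note: load_vector_store expects the embedding model used to build the
    # index to already be available as st.session_state.embeddings; callers
    # should set it before loading a saved store.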
def save_chat_history(
self,
messages: List[BaseMessage],
session_id: str,
        metadata: Optional[Dict[str, Any]] = None
) -> bool:
"""Save chat history with metadata.
Args:
messages: List of chat messages
session_id: Unique identifier for the chat session
metadata: Additional metadata about the chat session
"""
try:
# Convert messages to serializable format
serialized_messages = []
for msg in messages:
if isinstance(msg, (HumanMessage, AIMessage)):
serialized_messages.append({
'type': msg.__class__.__name__,
'content': msg.content,
'timestamp': datetime.now().isoformat()
})
# Prepare chat data
chat_data = {
'messages': serialized_messages,
'metadata': metadata or {},
'last_updated': datetime.now().isoformat()
}
# Save to JSON file
chat_file = self.chat_history_dir / f"{session_id}.json"
with open(chat_file, 'w') as f:
json.dump(chat_data, f, indent=2)
return True
except Exception as e:
st.error(f"Error saving chat history: {str(e)}")
return False
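    # On-disk format written by save_chat_history, one JSON file per session:
    #   {
    #     "messages": [{"type": "HumanMessage" | "AIMessage",
    #                   "content": "...", "timestamp": "<ISO 8601>"}, ...],
    #     "metadata": {...},
    #     "last_updated": "<ISO 8601>"
    #   }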
def load_chat_history(self, session_id: str) -> List[BaseMessage]:
"""Load chat history for a session.
Args:
session_id: Unique identifier for the chat session
"""
try:
chat_file = self.chat_history_dir / f"{session_id}.json"
if not chat_file.exists():
return []
with open(chat_file, 'r') as f:
chat_data = json.load(f)
# Convert back to message objects
messages = []
for msg in chat_data['messages']:
if msg['type'] == 'HumanMessage':
messages.append(HumanMessage(content=msg['content']))
elif msg['type'] == 'AIMessage':
messages.append(AIMessage(content=msg['content']))
return messages
except Exception as e:
st.error(f"Error loading chat history: {str(e)}")
return []
def save_chunks(
self,
chunks: List[str],
chunk_metadatas: List[Dict],
session_id: str
) -> bool:
"""Save document chunks and their metadata.
Args:
chunks: List of text chunks
chunk_metadatas: List of metadata dictionaries for each chunk
session_id: Unique identifier for the session
"""
try:
chunk_data = {
'chunks': chunks,
'metadatas': chunk_metadatas,
'created_at': datetime.now().isoformat()
}
chunk_file = self.chunks_dir / f"{session_id}_chunks.pkl"
with open(chunk_file, 'wb') as f:
pickle.dump(chunk_data, f)
return True
except Exception as e:
st.error(f"Error saving chunks: {str(e)}")
return False
    def load_chunks(self, session_id: str) -> Tuple[Optional[List[str]], Optional[List[Dict]]]:
"""Load document chunks and their metadata.
Args:
session_id: Unique identifier for the session
"""
try:
chunk_file = self.chunks_dir / f"{session_id}_chunks.pkl"
if not chunk_file.exists():
return None, None
with open(chunk_file, 'rb') as f:
chunk_data = pickle.load(f)
return chunk_data['chunks'], chunk_data['metadatas']
except Exception as e:
st.error(f"Error loading chunks: {str(e)}")
return None, None
def list_available_sessions(self) -> List[Dict[str, Any]]:
"""List all available chat sessions with their metadata."""
try:
sessions = []
for chat_file in self.chat_history_dir.glob("*.json"):
with open(chat_file, 'r') as f:
chat_data = json.load(f)
session_id = chat_file.stem
sessions.append({
'session_id': session_id,
'last_updated': chat_data['last_updated'],
'metadata': chat_data['metadata']
})
# Sort by last updated time
sessions.sort(key=lambda x: x['last_updated'], reverse=True)
return sessions
except Exception as e:
st.error(f"Error listing sessions: {str(e)}")
return []
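# ---------------------------------------------------------------------------
# Usage sketches (not part of the original module). They assume the packages
# imported above are installed; the "data_demo" directory, the session ids,
# and the example messages are illustrative only.
# ---------------------------------------------------------------------------
def _demo_vector_store_round_trip(pm: "PersistenceManager", session_id: str = "demo") -> None:
    """Hedged sketch: round-trip a tiny FAISS store through the manager.

    FakeEmbeddings stands in for the app's real embedding model so the sketch
    has no external model dependency. Intended to be called from a running
    Streamlit app, where st.session_state is available.
    """
    from langchain_community.embeddings import FakeEmbeddings
    from langchain_community.vectorstores import FAISS

    # The loader reads the embedding model from session state, so set it first.
    st.session_state.embeddings = FakeEmbeddings(size=32)
    store = FAISS.from_texts(["alpha", "beta"], embedding=st.session_state.embeddings)
    pm.save_vector_store(store, session_id)
    restored = pm.load_vector_store(session_id)
    print("restored docs:", len(restored.index_to_docstore_id) if restored else 0)


if __name__ == "__main__":
    # Minimal sketch exercising the parts that need no embedding model:
    # chat history and chunk persistence for a throwaway session.
    manager = PersistenceManager(data_dir="data_demo")

    session = "example-session"
    manager.save_chat_history(
        [
            HumanMessage(content="example user question"),
            AIMessage(content="example assistant reply"),
        ],
        session,
        metadata={"rfp_name": "sample.pdf"},
    )
    print("restored messages:", len(manager.load_chat_history(session)))

    manager.save_chunks(["chunk one", "chunk two"], [{"page": 1}, {"page": 2}], session)
    chunks, metas = manager.load_chunks(session)
    print("restored chunks:", chunks, metas)

    print("available sessions:", manager.list_available_sessions())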