#!/usr/bin/env python3
"""
Vector Store Manager for Maternal Health RAG Chatbot
Uses FAISS with the all-MiniLM-L6-v2 sentence-embedding model
"""
import json
import numpy as np
import faiss
from pathlib import Path
from typing import List, Dict, Any, Optional
import logging
from sentence_transformers import SentenceTransformer
import time
from dataclasses import dataclass
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchResult:
"""Container for search results"""
content: str
score: float
metadata: Dict[str, Any]
chunk_index: int
source_document: str
chunk_type: str
clinical_importance: float
class MaternalHealthVectorStore:
"""Vector store for maternal health guidelines with clinical context filtering"""
def __init__(self,
vector_store_dir: str = "vector_store",
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
chunks_dir: str = "comprehensive_chunks"):
self.vector_store_dir = Path(vector_store_dir)
self.vector_store_dir.mkdir(exist_ok=True)
self.chunks_dir = Path(chunks_dir)
self.embedding_model_name = embedding_model
# Initialize components
        self.embedding_model = None
        self.embedding_dimension = None
        self.index = None
        self.documents = []
        self.metadata = []
# Vector store files
self.index_file = self.vector_store_dir / "faiss_index.bin"
self.documents_file = self.vector_store_dir / "documents.json"
self.metadata_file = self.vector_store_dir / "metadata.json"
self.config_file = self.vector_store_dir / "config.json"
# Search parameters
self.default_k = 5
self.similarity_threshold = 0.3
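        # Scores from the normalized inner-product index are cosine
        # similarities in [-1, 1]; search() drops results below this threshold.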
def initialize_embedding_model(self):
"""Initialize the optimal embedding model"""
logger.info(f"Initializing embedding model: {self.embedding_model_name}")
try:
self.embedding_model = SentenceTransformer(self.embedding_model_name)
logger.info("✅ Embedding model loaded successfully")
# Get embedding dimension
test_embedding = self.embedding_model.encode(["test"])
self.embedding_dimension = test_embedding.shape[1]
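            # all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings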
logger.info(f"📏 Embedding dimension: {self.embedding_dimension}")
except Exception as e:
logger.error(f"❌ Failed to load embedding model: {e}")
raise
def load_medical_documents(self) -> List[Dict[str, Any]]:
"""Load processed medical documents"""
logger.info("Loading medical documents for vector store...")
langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
if not langchain_file.exists():
raise FileNotFoundError(f"Medical documents not found: {langchain_file}")
with open(langchain_file, 'r', encoding='utf-8') as f:
documents = json.load(f)
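        # Each entry is expected to follow LangChain's serialized document
        # shape (inferred from how create_vector_index() consumes it):
        #   {"page_content": "<chunk text>",
        #    "metadata": {"source": ..., "chunk_type": ..., "clinical_importance": ...}}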
logger.info(f"📚 Loaded {len(documents)} medical document chunks")
return documents
def create_vector_index(self, force_rebuild: bool = False) -> bool:
"""Create or load FAISS vector index"""
        # Try to reuse an existing index. load_existing_index() catches its own
        # exceptions and returns False on failure, so check the return value
        # rather than wrapping the call in try/except.
        if not force_rebuild and self.index_file.exists():
            if self.load_existing_index():
                return True
            logger.info("Failed to load existing index, rebuilding from scratch...")
# Initialize embedding model if not done
if self.embedding_model is None:
self.initialize_embedding_model()
# Load documents
documents = self.load_medical_documents()
logger.info("Creating vector embeddings for all medical chunks...")
# Extract content and metadata
texts = []
metadata = []
for doc in documents:
content = doc['page_content']
meta = doc['metadata']
            # Skip very short chunks (fewer than 50 characters)
            if len(content.strip()) < 50:
continue
texts.append(content)
metadata.append(meta)
# Generate embeddings in batches
logger.info(f"Generating embeddings for {len(texts)} chunks...")
start_time = time.time()
embeddings = self.embedding_model.encode(
texts,
batch_size=32,
show_progress_bar=True,
convert_to_numpy=True
)
embed_time = time.time() - start_time
logger.info(f"⚡ Embeddings generated in {embed_time:.2f} seconds")
# Create FAISS index
logger.info("Building FAISS index...")
# Use IndexFlatIP for inner product (cosine similarity)
# Normalize embeddings for cosine similarity
faiss.normalize_L2(embeddings)
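        # On unit-norm vectors the inner product equals the cosine of the angle
        # between them, so the flat IP index below returns cosine similarities.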
# Create index
index = faiss.IndexFlatIP(self.embedding_dimension)
index.add(embeddings.astype('float32'))
# Store components
self.index = index
self.documents = texts
self.metadata = metadata
# Save to disk
self.save_index()
logger.info(f"✅ Vector store created with {index.ntotal} embeddings")
return True
def load_existing_index(self) -> bool:
"""Load existing FAISS index from disk"""
logger.info("Loading existing vector store...")
try:
# Load FAISS index
self.index = faiss.read_index(str(self.index_file))
# Load documents
with open(self.documents_file, 'r', encoding='utf-8') as f:
self.documents = json.load(f)
# Load metadata
with open(self.metadata_file, 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
# Load config
with open(self.config_file, 'r') as f:
config = json.load(f)
self.embedding_model_name = config['embedding_model']
self.embedding_dimension = config['embedding_dimension']
# Initialize embedding model
self.initialize_embedding_model()
logger.info(f"✅ Loaded existing vector store with {self.index.ntotal} embeddings")
return True
except Exception as e:
logger.error(f"❌ Failed to load existing index: {e}")
return False
def save_index(self):
"""Save FAISS index and metadata to disk"""
logger.info("Saving vector store to disk...")
try:
# Save FAISS index
faiss.write_index(self.index, str(self.index_file))
# Save documents
with open(self.documents_file, 'w', encoding='utf-8') as f:
json.dump(self.documents, f, ensure_ascii=False, indent=2)
# Save metadata
with open(self.metadata_file, 'w', encoding='utf-8') as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
# Save config
config = {
'embedding_model': self.embedding_model_name,
'embedding_dimension': self.embedding_dimension,
'total_chunks': len(self.documents),
'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
}
with open(self.config_file, 'w') as f:
json.dump(config, f, indent=2)
logger.info(f"💾 Vector store saved to {self.vector_store_dir}")
except Exception as e:
logger.error(f"❌ Failed to save vector store: {e}")
raise
    def search(self,
               query: str,
               k: Optional[int] = None,
               filters: Optional[Dict[str, Any]] = None,
               min_score: Optional[float] = None) -> List[SearchResult]:
"""Search for relevant medical content"""
if self.index is None:
raise ValueError("Vector store not initialized. Call create_vector_index() first.")
if k is None:
k = self.default_k
if min_score is None:
min_score = self.similarity_threshold
# Generate query embedding
query_embedding = self.embedding_model.encode([query])
faiss.normalize_L2(query_embedding)
# Search in FAISS index
scores, indices = self.index.search(query_embedding.astype('float32'), k * 2) # Get more for filtering
# Process results
results = []
for score, idx in zip(scores[0], indices[0]):
if idx == -1 or score < min_score:
continue
# Get document and metadata
content = self.documents[idx]
metadata = self.metadata[idx]
# Apply filters if specified
if filters and not self._matches_filters(metadata, filters):
continue
# Create search result
result = SearchResult(
content=content,
score=float(score),
metadata=metadata,
                chunk_index=int(idx),  # numpy integer -> plain int
source_document=metadata.get('source', ''),
chunk_type=metadata.get('chunk_type', 'text'),
clinical_importance=metadata.get('clinical_importance', 0.5)
)
results.append(result)
# Stop when we have enough results
if len(results) >= k:
break
return results
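    # A hypothetical usage sketch for search(), given a built store as in the
    # module-level example above; the "guideline" chunk_type and the 0.7
    # importance floor are illustrative values, not labels from the real corpus:
    #   hits = store.search(
    #       "magnesium sulfate loading dose for eclampsia",
    #       k=3,
    #       filters={"chunk_type": ["guideline"],
    #                "clinical_importance": {"min": 0.7}},
    #   )
    #   for hit in hits:
    #       print(f"{hit.score:.3f} {hit.source_document} {hit.content[:80]}")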
def _matches_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
"""Check if metadata matches the specified filters"""
for key, value in filters.items():
if key not in metadata:
return False
meta_value = metadata[key]
# Handle different filter types
if isinstance(value, list):
if meta_value not in value:
return False
elif isinstance(value, dict):
if 'min' in value and meta_value < value['min']:
return False
if 'max' in value and meta_value > value['max']:
return False
else:
if meta_value != value:
return False
return True
    def search_by_medical_context(self,
                                  query: str,
                                  content_types: Optional[List[str]] = None,
                                  min_importance: float = 0.5,
                                  k: int = 5) -> List[SearchResult]:
"""Search with medical context filtering"""
filters = {}
# Filter by content types
if content_types:
filters['chunk_type'] = content_types
# Filter by clinical importance
if min_importance > 0:
filters['clinical_importance'] = {'min': min_importance}
return self.search(query, k=k, filters=filters)
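    # For example (hypothetical content_types; the actual chunk_type labels
    # depend on the upstream chunking pipeline):
    #   store.search_by_medical_context(
    #       "postpartum hemorrhage first-line uterotonics",
    #       content_types=["protocol", "table"],
    #       min_importance=0.7,
    #   )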
def get_statistics(self) -> Dict[str, Any]:
"""Get vector store statistics"""
if self.index is None:
return {'error': 'Vector store not initialized'}
# Calculate statistics from metadata
chunk_types = {}
importance_distribution = {'low': 0, 'medium': 0, 'high': 0, 'critical': 0}
sources = {}
for meta in self.metadata:
# Chunk types
chunk_type = meta.get('chunk_type', 'unknown')
chunk_types[chunk_type] = chunk_types.get(chunk_type, 0) + 1
# Importance distribution
importance = meta.get('clinical_importance', 0)
if importance >= 0.9:
importance_distribution['critical'] += 1
elif importance >= 0.7:
importance_distribution['high'] += 1
elif importance >= 0.5:
importance_distribution['medium'] += 1
else:
importance_distribution['low'] += 1
# Sources
source = meta.get('source', 'unknown')
sources[source] = sources.get(source, 0) + 1
return {
'total_chunks': self.index.ntotal,
'embedding_dimension': self.embedding_dimension,
'embedding_model': self.embedding_model_name,
'chunk_type_distribution': chunk_types,
'clinical_importance_distribution': importance_distribution,
            'source_distribution': dict(sorted(sources.items(), key=lambda kv: kv[1], reverse=True)[:10]),  # top 10 sources by chunk count
'vector_store_size_mb': self.index_file.stat().st_size / (1024*1024) if self.index_file.exists() else 0
}
def main():
"""Main function to create and test vector store"""
logger.info("🚀 Creating Maternal Health Vector Store...")
# Create vector store manager
vector_store = MaternalHealthVectorStore()
# Create the vector index
success = vector_store.create_vector_index()
if not success:
logger.error("❌ Failed to create vector store")
return
# Test searches
logger.info("\n🔍 Testing search functionality...")
test_queries = [
"What is the recommended dosage of magnesium sulfate for preeclampsia?",
"How to manage postpartum hemorrhage in emergency situations?",
"Signs and symptoms of puerperal sepsis",
"Normal fetal heart rate during labor"
]
for query in test_queries:
logger.info(f"\n📝 Query: {query}")
results = vector_store.search(query, k=3)
for i, result in enumerate(results, 1):
logger.info(f" {i}. Score: {result.score:.3f} | Type: {result.chunk_type} | "
f"Importance: {result.clinical_importance:.2f}")
logger.info(f" Content: {result.content[:100]}...")
# Get statistics
stats = vector_store.get_statistics()
logger.info("\n📊 Vector Store Statistics:")
logger.info(f" Total chunks: {stats['total_chunks']}")
logger.info(f" Embedding dimension: {stats['embedding_dimension']}")
logger.info(f" High importance chunks: {stats['clinical_importance_distribution']['high'] + stats['clinical_importance_distribution']['critical']}")
logger.info(f" Vector store size: {stats['vector_store_size_mb']:.1f} MB")
logger.info("\n✅ Vector store creation and testing complete!")
if __name__ == "__main__":
main()