#!/usr/bin/env python3
"""
Embedding Model Evaluator for Medical Content
Tests different free embedding models to find the best for maternal health guidelines
"""
import json
import logging
import time
from pathlib import Path
from typing import List, Dict, Any, Tuple

import matplotlib.pyplot as plt
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
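
# Expected input (produced by the comprehensive chunking step): a JSON list at
# comprehensive_chunks/langchain_documents_comprehensive.json in which each item
# carries a 'page_content' string and a 'metadata' dict with keys such as
# 'chunk_type', 'clinical_importance', 'source', 'has_dosage_info',
# 'is_maternal_specific' and 'has_clinical_protocols' (see load_medical_chunks).
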
class MedicalEmbeddingEvaluator:
    """Evaluates different embedding models for medical content quality"""

    def __init__(self, chunks_dir: Path = Path("comprehensive_chunks")):
        self.chunks_dir = chunks_dir
        self.medical_chunks = []
        self.evaluation_results = {}

        # Free embedding models to test
        self.embedding_models = {
            'all-MiniLM-L6-v2': 'sentence-transformers/all-MiniLM-L6-v2',
            'all-mpnet-base-v2': 'sentence-transformers/all-mpnet-base-v2',
            'all-MiniLM-L12-v2': 'sentence-transformers/all-MiniLM-L12-v2',
            'multi-qa-MiniLM-L6-cos-v1': 'sentence-transformers/multi-qa-MiniLM-L6-cos-v1',
            'all-distilroberta-v1': 'sentence-transformers/all-distilroberta-v1'
        }

        # Medical test queries for evaluation
        self.test_queries = [
            "What is the recommended dosage of magnesium sulfate for preeclampsia?",
            "How to manage postpartum hemorrhage in emergency situations?",
            "Normal ranges for fetal heart rate during labor",
            "Contraindications for vaginal delivery in breech presentation",
            "Signs and symptoms of puerperal sepsis",
            "Management of gestational diabetes during pregnancy",
            "Emergency cesarean section indications",
            "Postpartum care guidelines for mother and baby",
            "Rhesus incompatibility management protocol",
            "Antepartum monitoring guidelines for high-risk pregnancy"
        ]

    def load_medical_chunks(self) -> List[Dict]:
        """Load medical chunks from comprehensive chunking results"""
        logger.info("Loading medical chunks for embedding evaluation...")

        langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
        if not langchain_file.exists():
            raise FileNotFoundError(f"LangChain documents not found: {langchain_file}")

        with open(langchain_file) as f:
            chunks_data = json.load(f)

        # Filter and prepare chunks for evaluation
        medical_chunks = []
        for chunk in chunks_data:
            content = chunk['page_content']
            metadata = chunk['metadata']

            # Skip very short chunks
            if len(content.strip()) < 100:
                continue

            medical_chunks.append({
                'content': content,
                'chunk_type': metadata.get('chunk_type', 'text'),
                'clinical_importance': metadata.get('clinical_importance', 0.5),
                'source': metadata.get('source', ''),
                'has_dosage_info': metadata.get('has_dosage_info', False),
                'is_maternal_specific': metadata.get('is_maternal_specific', False),
                'has_clinical_protocols': metadata.get('has_clinical_protocols', False)
            })

        logger.info(f"Loaded {len(medical_chunks)} medical chunks for evaluation")
        return medical_chunks

    def evaluate_embedding_model(self, model_name: str, model_path: str) -> Dict[str, Any]:
        """Evaluate a single embedding model"""
        logger.info(f"Evaluating embedding model: {model_name}")

        try:
            # Load model
            start_time = time.time()
            model = SentenceTransformer(model_path)
            load_time = time.time() - start_time

            # Sample chunks for evaluation (use subset for speed)
            sample_chunks = self.medical_chunks[:100]  # Use first 100 chunks
            chunk_texts = [chunk['content'] for chunk in sample_chunks]

            # Generate embeddings for chunks
            logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
            start_time = time.time()
            chunk_embeddings = model.encode(chunk_texts, show_progress_bar=True)
            chunk_embed_time = time.time() - start_time

            # Generate embeddings for test queries
            start_time = time.time()
            query_embeddings = model.encode(self.test_queries)
            query_embed_time = time.time() - start_time

            # Evaluation metrics
            results = {
                'model_name': model_name,
                'model_path': model_path,
                'load_time': load_time,
                'chunk_embed_time': chunk_embed_time,
                'query_embed_time': query_embed_time,
                'embedding_dimension': chunk_embeddings.shape[1],
                'chunks_processed': len(chunk_texts),
                'queries_processed': len(self.test_queries)
            }

            # Test semantic search quality
            search_results = self._evaluate_search_quality(
                query_embeddings, chunk_embeddings, sample_chunks
            )
            results.update(search_results)

            # Test clustering quality
            cluster_results = self._evaluate_clustering_quality(
                chunk_embeddings, sample_chunks
            )
            results.update(cluster_results)

            # Calculate overall score
            results['overall_score'] = self._calculate_overall_score(results)

            logger.info(f"✅ {model_name} evaluation complete - Overall Score: {results['overall_score']:.3f}")
            return results

        except Exception as e:
            logger.error(f"❌ Failed to evaluate {model_name}: {e}")
            return {
                'model_name': model_name,
                'model_path': model_path,
                'error': str(e),
                'overall_score': 0.0
            }
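
    # Example (sketch, not executed): scoring a single model without running the
    # full sweep, assuming the comprehensive_chunks/ directory is present:
    #
    #     evaluator = MedicalEmbeddingEvaluator()
    #     evaluator.medical_chunks = evaluator.load_medical_chunks()
    #     result = evaluator.evaluate_embedding_model(
    #         'all-MiniLM-L6-v2', 'sentence-transformers/all-MiniLM-L6-v2')
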
    def _evaluate_search_quality(self, query_embeddings: np.ndarray,
                                 chunk_embeddings: np.ndarray,
                                 chunks: List[Dict]) -> Dict[str, float]:
        """Evaluate semantic search quality"""
        # Calculate similarities between queries and chunks
        similarities = cosine_similarity(query_embeddings, chunk_embeddings)
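        # Note: cosine_similarity returns an (n_queries, n_chunks) matrix, so
        # row i holds the similarity of test query i against every sampled
        # chunk; the metrics below are computed over the top-5 chunks per query.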

        search_metrics = {
            'avg_max_similarity': 0.0,
            'medical_content_precision': 0.0,
            'dosage_query_accuracy': 0.0,
            'emergency_query_accuracy': 0.0
        }

        total_queries = len(self.test_queries)
        dosage_queries = 0
        emergency_queries = 0

        for i, query in enumerate(self.test_queries):
            query_similarities = similarities[i]
            top_indices = np.argsort(query_similarities)[::-1][:5]  # Top 5 results

            # Max similarity for this query (cast to a plain float so the saved
            # results stay JSON-serializable)
            max_sim = float(np.max(query_similarities))
            search_metrics['avg_max_similarity'] += max_sim

            # Check if top results contain relevant medical content
            top_chunks = [chunks[idx] for idx in top_indices]
            medical_relevant = sum(1 for chunk in top_chunks
                                   if chunk['clinical_importance'] > 0.7)
            search_metrics['medical_content_precision'] += medical_relevant / 5

            # Specific query type accuracy
            if 'dosage' in query.lower() or 'dose' in query.lower():
                dosage_queries += 1
                dosage_relevant = sum(1 for chunk in top_chunks
                                      if chunk['has_dosage_info'])
                search_metrics['dosage_query_accuracy'] += dosage_relevant / 5

            if 'emergency' in query.lower() or 'urgent' in query.lower():
                emergency_queries += 1
                emergency_relevant = sum(1 for chunk in top_chunks
                                         if chunk['chunk_type'] == 'emergency')
                search_metrics['emergency_query_accuracy'] += emergency_relevant / 5

        # Average the general metrics over all queries, and the query-type
        # metrics over the number of queries of that type (so a single dosage
        # query can still score up to 1.0 rather than being diluted).
        search_metrics['avg_max_similarity'] /= total_queries
        search_metrics['medical_content_precision'] /= total_queries
        if dosage_queries:
            search_metrics['dosage_query_accuracy'] /= dosage_queries
        if emergency_queries:
            search_metrics['emergency_query_accuracy'] /= emergency_queries

        return search_metrics

    def _evaluate_clustering_quality(self, embeddings: np.ndarray,
                                     chunks: List[Dict]) -> Dict[str, float]:
        """Evaluate how well embeddings cluster similar medical content"""
        # Perform clustering (guarded so that small samples, where
        # len(chunks) // 10 could fall below 2, still give a valid KMeans fit)
        n_clusters = max(2, min(8, len(chunks) // 10))
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(embeddings)

        # Calculate cluster purity based on chunk types
        cluster_metrics = {
            'cluster_purity': 0.0,
            'dosage_cluster_coherence': 0.0,
            'maternal_cluster_coherence': 0.0
        }

        # Calculate cluster purity
        total_items = len(chunks)
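        # Purity is accumulated as a size-weighted sum: each cluster contributes
        # (fraction of its items sharing the dominant chunk_type) * (cluster
        # size / total items), so the final value is the overall fraction of
        # chunks that agree with their cluster's dominant type.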
        for cluster_id in range(n_clusters):
            cluster_indices = np.where(cluster_labels == cluster_id)[0]
            if len(cluster_indices) == 0:
                continue

            cluster_chunks = [chunks[i] for i in cluster_indices]

            # Find dominant chunk type in this cluster
            chunk_types = [chunk['chunk_type'] for chunk in cluster_chunks]
            if chunk_types:
                dominant_type = max(set(chunk_types), key=chunk_types.count)
                purity = chunk_types.count(dominant_type) / len(chunk_types)
                cluster_metrics['cluster_purity'] += purity * len(cluster_indices) / total_items

            # Check dosage content clustering
            dosage_chunks = [chunk for chunk in cluster_chunks if chunk['has_dosage_info']]
            if len(cluster_chunks) > 0:
                dosage_ratio = len(dosage_chunks) / len(cluster_chunks)
                if dosage_ratio > 0.5:  # If majority are dosage chunks
                    cluster_metrics['dosage_cluster_coherence'] += dosage_ratio

            # Check maternal content clustering
            maternal_chunks = [chunk for chunk in cluster_chunks if chunk['is_maternal_specific']]
            if len(cluster_chunks) > 0:
                maternal_ratio = len(maternal_chunks) / len(cluster_chunks)
                if maternal_ratio > 0.5:  # If majority are maternal chunks
                    cluster_metrics['maternal_cluster_coherence'] += maternal_ratio

        return cluster_metrics

    def _calculate_overall_score(self, results: Dict[str, Any]) -> float:
        """Calculate overall score for the embedding model"""
        if 'error' in results:
            return 0.0

        # Weighted scoring components
        weights = {
            'search_quality': 0.4,
            'clustering_quality': 0.2,
            'speed': 0.2,
            'medical_relevance': 0.2
        }
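        # Illustrative arithmetic (made-up component scores): with search=0.6,
        # cluster=0.5, speed=0.8 and medical=0.55, the weighted total is
        # 0.6*0.4 + 0.5*0.2 + 0.8*0.2 + 0.55*0.2 = 0.61.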

        # Search quality score (0-1)
        search_score = (
            results.get('avg_max_similarity', 0) * 0.4 +
            results.get('medical_content_precision', 0) * 0.3 +
            results.get('dosage_query_accuracy', 0) * 0.15 +
            results.get('emergency_query_accuracy', 0) * 0.15
        )

        # Clustering quality score (0-1)
        cluster_score = (
            results.get('cluster_purity', 0) * 0.5 +
            results.get('dosage_cluster_coherence', 0) * 0.25 +
            results.get('maternal_cluster_coherence', 0) * 0.25
        )

        # Speed score (inverse of time, normalized)
        total_time = results.get('chunk_embed_time', 1) + results.get('query_embed_time', 1)
        speed_score = max(0, 1 - (total_time / 100))  # Normalize to 0-1

        # Medical relevance (based on search accuracy for medical queries)
        medical_score = (
            results.get('medical_content_precision', 0) * 0.6 +
            results.get('dosage_query_accuracy', 0) * 0.4
        )

        # Calculate weighted overall score
        overall = (
            search_score * weights['search_quality'] +
            cluster_score * weights['clustering_quality'] +
            speed_score * weights['speed'] +
            medical_score * weights['medical_relevance']
        )

        return min(1.0, max(0.0, overall))

    def run_comprehensive_evaluation(self) -> Dict[str, Any]:
        """Run comprehensive evaluation of all embedding models"""
        logger.info("Starting comprehensive embedding model evaluation...")

        # Load medical chunks
        self.medical_chunks = self.load_medical_chunks()
        if len(self.medical_chunks) == 0:
            raise ValueError("No medical chunks loaded for evaluation")

        # Evaluate each model
        results = {}
        for model_name, model_path in self.embedding_models.items():
            logger.info(f"\n🔍 Evaluating: {model_name}")
            results[model_name] = self.evaluate_embedding_model(model_name, model_path)

        # Generate summary report
        summary = self._generate_evaluation_summary(results)

        # Save results (create the output directory if it does not exist yet)
        output_file = Path("src/embedding_evaluation_results.json")
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w') as f:
            json.dump({
                'evaluation_summary': summary,
                'detailed_results': results,
                'test_queries': self.test_queries,
                'chunks_evaluated': len(self.medical_chunks)
            }, f, indent=2)

        logger.info(f"📊 Evaluation results saved to: {output_file}")
        return summary

    def _generate_evaluation_summary(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """Generate evaluation summary with recommendations"""
        valid_results = {k: v for k, v in results.items() if 'error' not in v}

        if not valid_results:
            return {'error': 'No models evaluated successfully'}

        # Find best model
        best_model = max(valid_results.items(), key=lambda x: x[1]['overall_score'])

        # Calculate averages
        avg_scores = {}
        for metric in ['overall_score', 'avg_max_similarity', 'medical_content_precision']:
            scores = [r.get(metric, 0) for r in valid_results.values()]
            avg_scores[f'avg_{metric}'] = sum(scores) / len(scores) if scores else 0

        summary = {
            'best_model': {
                'name': best_model[0],
                'path': best_model[1]['model_path'],
                'score': best_model[1]['overall_score'],
                'strengths': []
            },
            'model_rankings': sorted(
                [(name, res['overall_score']) for name, res in valid_results.items()],
                key=lambda x: x[1], reverse=True
            ),
            'evaluation_metrics': avg_scores,
            'recommendation': '',
            'models_tested': len(results),
            'successful_evaluations': len(valid_results)
        }

        # Add strengths and recommendation
        best_result = best_model[1]
        strengths = []
        if best_result.get('medical_content_precision', 0) > 0.7:
            strengths.append("High medical content precision")
        if best_result.get('dosage_query_accuracy', 0) > 0.6:
            strengths.append("Good dosage information retrieval")
        if best_result.get('cluster_purity', 0) > 0.6:
            strengths.append("Effective content clustering")
        if best_result.get('chunk_embed_time', 100) < 30:
            strengths.append("Fast embedding generation")

        summary['best_model']['strengths'] = strengths
        summary['recommendation'] = (
            f"Recommended model: {best_model[0]} with overall score {best_result['overall_score']:.3f}. "
            f"This model shows {', '.join(strengths)} and is well-suited for maternal health content."
        )

        return summary

def main():
    """Main evaluation function"""
    evaluator = MedicalEmbeddingEvaluator()

    try:
        summary = evaluator.run_comprehensive_evaluation()

        # Print summary
        logger.info("=" * 80)
        logger.info("EMBEDDING MODEL EVALUATION COMPLETE!")
        logger.info("=" * 80)
        logger.info(f"🏆 Best Model: {summary['best_model']['name']}")
        logger.info(f"📊 Overall Score: {summary['best_model']['score']:.3f}")
        logger.info(f"💪 Strengths: {', '.join(summary['best_model']['strengths'])}")
        logger.info(f"💡 Recommendation: {summary['recommendation']}")
        logger.info("\n📋 Model Rankings:")
        for i, (model, score) in enumerate(summary['model_rankings'], 1):
            logger.info(f"{i}. {model}: {score:.3f}")
        logger.info("=" * 80)

        return summary

    except Exception as e:
        logger.error(f"❌ Evaluation failed: {e}")
        return None


if __name__ == "__main__":
    main()
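
# Usage sketch (assumes run_comprehensive_evaluation() has completed and written
# src/embedding_evaluation_results.json): load the recommended model and embed a
# query for retrieval.
#
#     import json
#     from sentence_transformers import SentenceTransformer
#
#     with open("src/embedding_evaluation_results.json") as f:
#         best = json.load(f)["evaluation_summary"]["best_model"]
#     model = SentenceTransformer(best["path"])
#     query_vector = model.encode(["Management of postpartum hemorrhage"])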