# menopause-ml / predict_csv.py
# Page metadata: techatcreated · v1 · revision 66d45ea (verified)
#!/usr/bin/env python
"""
CSV Prediction Script for SWAN Menopause Stage Forecasting
This script demonstrates how to use the trained forecasting module to make predictions
on a batch of individuals from a CSV file and save results with confidence scores
and performance metrics.
Usage:
python predict_csv.py --input demo_individuals.csv --model RandomForest
python predict_csv.py --input individuals.csv --output results.csv --model LogisticRegression
The script will:
1. Read input CSV with individual feature values
2. Make predictions using trained model
3. Save results with predicted stage, confidence, and probabilities
4. Display summary statistics
"""
import os
import sys
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
def _build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description='Make menopause stage predictions from CSV file'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Path to input CSV file with individual feature values'
    )
    parser.add_argument(
        '--output', '-o',
        default=None,
        help='Path to output CSV file (default: input_predictions.csv)'
    )
    parser.add_argument(
        '--model', '-m',
        choices=['RandomForest', 'LogisticRegression'],
        default='RandomForest',
        help='Which model to use for predictions'
    )
    parser.add_argument(
        '--forecast-dir',
        default='swan_ml_output',
        help='Directory containing trained forecast models'
    )
    return parser


def _fail(*lines):
    """Print each message line to stderr and exit with status 1."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(1)


def _percent(part, total):
    """Return ``part`` as a percentage of ``total`` (``total`` must be non-zero)."""
    return part / total * 100


def _print_summary(results):
    """Print the results table, stage distribution and confidence statistics.

    Args:
        results: DataFrame returned by ``predict_from_csv``. It is expected to
            contain ``predicted_stage`` and ``confidence`` columns.
    """
    print("\n" + "=" * 80)
    print("PREDICTION RESULTS")
    print("=" * 80)
    print("\nDetailed Results:")
    print(results.to_string(index=False))

    print("\n" + "=" * 80)
    print("PERFORMANCE SUMMARY")
    print("=" * 80)
    total = len(results)
    print(f"\nTotal Individuals: {total}")
    if total == 0:
        # Nothing to summarise: the statistics would be NaN and every
        # percentage below would divide by zero.
        print("\nNo predictions returned; skipping summary statistics.")
        return

    print("\nStage Distribution:")
    for stage, count in results['predicted_stage'].value_counts().items():
        print(f" {stage}: {count} ({_percent(count, total):.1f}%)")

    confidence = results['confidence']
    print("\nConfidence Scores:")
    print(f" Mean: {confidence.mean():.3f}")
    print(f" Min: {confidence.min():.3f}")
    print(f" Max: {confidence.max():.3f}")
    # pandas' Series.std defaults to ddof=1, so this is NaN for a single row.
    print(f" Std Dev: {confidence.std():.3f}")

    # Bucket the scores: high > 0.8, medium in (0.6, 0.8], low <= 0.6.
    high_conf = int((confidence > 0.8).sum())
    med_conf = int(((confidence > 0.6) & (confidence <= 0.8)).sum())
    low_conf = int((confidence <= 0.6).sum())
    print("\nConfidence Distribution:")
    print(f" High (>0.80): {high_conf}/{total} ({_percent(high_conf, total):.1f}%)")
    print(f" Medium (0.60-0.80): {med_conf}/{total} ({_percent(med_conf, total):.1f}%)")
    print(f" Low (≤0.60): {low_conf}/{total} ({_percent(low_conf, total):.1f}%)")


def main(argv=None):
    """Parse arguments, run batch predictions from a CSV and print a summary.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) means ``sys.argv[1:]``, so ``main()`` behaves exactly
            as a plain script entry point.

    Exits with status 1, after printing an error to stderr, in any of these
    cases:
        - the ``menopause`` module cannot be imported;
        - the input file is missing;
        - the trained models are missing;
        - ``predict_from_csv`` returns ``None``.
    argparse itself exits with status 2 on invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    # Imported only after parsing, so that --help and argument errors are
    # reported even when the menopause module is unavailable.
    try:
        from menopause import load_forecast_model, predict_from_csv
    except ImportError:
        _fail("ERROR: Could not import menopause module.",
              "Make sure you're in the correct directory and menopause.py is available.")

    # isfile (not exists) so a directory passed as --input is rejected here.
    if not os.path.isfile(args.input):
        _fail(f"ERROR: Input file not found: {args.input}")

    forecast_dir = args.forecast_dir
    # NOTE(review): only rf_pipeline.pkl is probed, even for
    # --model LogisticRegression. Presumably training writes both models
    # together; confirm against menopause.py.
    if not os.path.exists(os.path.join(forecast_dir, 'rf_pipeline.pkl')):
        _fail(f"ERROR: Forecast models not found in {forecast_dir}",
              "Please run 'python menopause.py' first to train models.")

    print("=" * 80)
    print("MENOPAUSE STAGE PREDICTION FROM CSV")
    print("=" * 80)

    print(f"\nLoading forecaster from {forecast_dir}...")
    forecast = load_forecast_model(forecast_dir)

    print(f"\nUsing model: {args.model}")
    results = predict_from_csv(
        args.input,
        forecast,
        output_csv=args.output,
        model=args.model,
        output_dir='.',
    )
    if results is None:
        _fail("ERROR: Prediction failed.")

    _print_summary(results)

    # NOTE(review): assumes predict_from_csv writes
    # "<input stem>_predictions.csv" into output_dir ('.') when output_csv is
    # None. Confirm against menopause.predict_from_csv.
    output_path = args.output or f"{Path(args.input).stem}_predictions.csv"
    print(f"\n✅ Results saved to: {output_path}")
    print("\n" + "=" * 80)
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()