Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| CSV Prediction Script for SWAN Menopause Stage Forecasting | |
| This script demonstrates how to use the trained forecasting module to make predictions | |
| on a batch of individuals from a CSV file and save results with confidence scores | |
| and performance metrics. | |
| Usage: | |
| python predict_csv.py --input demo_individuals.csv --model RandomForest | |
| python predict_csv.py --input individuals.csv --output results.csv --model LogisticRegression | |
| The script will: | |
| 1. Read input CSV with individual feature values | |
| 2. Make predictions using trained model | |
| 3. Save results with predicted stage, confidence, and probabilities | |
| 4. Display summary statistics | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| import pandas as pd | |
| import numpy as np | |
| from pathlib import Path | |
def _build_parser():
    """Return the argument parser describing the script's command line."""
    parser = argparse.ArgumentParser(
        description='Make menopause stage predictions from CSV file'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Path to input CSV file with individual feature values'
    )
    parser.add_argument(
        '--output', '-o',
        default=None,
        help='Path to output CSV file (default: input_predictions.csv)'
    )
    parser.add_argument(
        '--model', '-m',
        choices=['RandomForest', 'LogisticRegression'],
        default='RandomForest',
        help='Which model to use for predictions'
    )
    parser.add_argument(
        '--forecast-dir',
        default='swan_ml_output',
        help='Directory containing trained forecast models'
    )
    return parser


def _percent(count, total):
    """Return ``count`` as a percentage of ``total``, or 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


def _print_summary(results):
    """Print the stage distribution and confidence statistics of ``results``.

    ``results`` must provide ``predicted_stage`` and ``confidence`` columns.
    Confidence is bucketed as high (> 0.8), medium (> 0.6 and <= 0.8) and
    low (<= 0.6). An empty frame prints only the total, because every
    statistic would otherwise come out as ``nan``.
    """
    total = len(results)

    print("\n" + "="*80)
    print("PERFORMANCE SUMMARY")
    print("="*80)
    print(f"\nTotal Individuals: {total}")

    if total == 0:
        # Mean/min/max/std of an empty column are nan and the bucket
        # percentages would be 0/0, so stop here instead of printing noise.
        print("\nNo predictions were produced; skipping summary statistics.")
        return

    print("\nStage Distribution:")
    for stage, count in results['predicted_stage'].value_counts().items():
        print(f" {stage}: {count} ({_percent(count, total):.1f}%)")

    confidence = results['confidence']
    print("\nConfidence Scores:")
    print(f" Mean: {confidence.mean():.3f}")
    print(f" Min: {confidence.min():.3f}")
    print(f" Max: {confidence.max():.3f}")
    print(f" Std Dev: {confidence.std():.3f}")

    # The three buckets partition the rows: > 0.8, (0.6, 0.8], <= 0.6.
    high_conf = (confidence > 0.8).sum()
    med_conf = ((confidence > 0.6) & (confidence <= 0.8)).sum()
    low_conf = (confidence <= 0.6).sum()
    print("\nConfidence Distribution:")
    print(f" High (>0.80): {high_conf}/{total} ({_percent(high_conf, total):.1f}%)")
    print(f" Medium (0.60-0.80): {med_conf}/{total} ({_percent(med_conf, total):.1f}%)")
    print(f" Low (≤0.60): {low_conf}/{total} ({_percent(low_conf, total):.1f}%)")


def main(argv=None):
    """Run batch menopause-stage prediction for the individuals in a CSV file.

    Parses the command line, validates the input file and the forecast
    directory, loads the forecaster, runs ``predict_from_csv`` and prints the
    results table followed by summary statistics.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so ``main()`` behaves as before.

    Raises:
        SystemExit: With code 1 when the input file is missing, the forecast
            artefact is missing, the ``menopause`` module cannot be imported,
            or prediction returns ``None``. argparse exits on its own for
            invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    # Validate the paths before importing the project module so that a bad
    # path is reported even when the import itself would fail.
    # isfile (not exists) so a directory is not accepted as the input CSV.
    if not os.path.isfile(args.input):
        print(f"ERROR: Input file not found: {args.input}")
        sys.exit(1)

    forecast_dir = args.forecast_dir
    # NOTE(review): only the RandomForest artefact is checked, even when
    # --model LogisticRegression is selected. The other pipeline's filename
    # is not visible from this script -- confirm and extend the check.
    if not os.path.exists(os.path.join(forecast_dir, 'rf_pipeline.pkl')):
        print(f"ERROR: Forecast models not found in {forecast_dir}")
        print("Please run 'python menopause.py' first to train models.")
        sys.exit(1)

    # Imported lazily, after argument parsing and path validation.
    try:
        from menopause import load_forecast_model, predict_from_csv
    except ImportError as exc:
        print("ERROR: Could not import menopause module.")
        print("Make sure you're in the correct directory and menopause.py is available.")
        # The failure may come from a dependency of menopause rather than
        # from menopause.py itself; show the real cause.
        print(f"Details: {exc}")
        sys.exit(1)

    print("="*80)
    print("MENOPAUSE STAGE PREDICTION FROM CSV")
    print("="*80)

    print(f"\nLoading forecaster from {forecast_dir}...")
    forecast = load_forecast_model(forecast_dir)

    print(f"\nUsing model: {args.model}")
    results = predict_from_csv(
        args.input,
        forecast,
        output_csv=args.output,
        model=args.model,
        output_dir='.'
    )

    if results is None:
        print("ERROR: Prediction failed.")
        sys.exit(1)

    print("\n" + "="*80)
    print("PREDICTION RESULTS")
    print("="*80)

    print("\nDetailed Results:")
    print(results.to_string(index=False))

    _print_summary(results)

    # The file is written by predict_from_csv; this only reports the name.
    # NOTE(review): the fallback assumes predict_from_csv names its default
    # output "<input stem>_predictions.csv" -- confirm against menopause.py.
    output_path = args.output if args.output else f"{Path(args.input).stem}_predictions.csv"
    print(f"\n✅ Results saved to: {output_path}")

    print("\n" + "="*80)
# Run the command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()