"""
Compare trained models (base vs medium) on expression generation complexity.
Runs REINFORCE for a few epochs on Nguyen-5 to see which model explores better.
"""

import os
import sys
import json
import argparse
from pathlib import Path

import numpy as np
import torch

# Make the repo root and its `classes/` directory importable when this script
# is run directly (the file lives one level below the project root).
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "classes"))

# NOTE(review): must stay below the sys.path manipulation above, otherwise the
# `scripts` package may not be resolvable when run outside the repo root.
from scripts.debug_reinforce import DebugREINFORCE
|
| |
|
def _expression_stats(expressions):
    """Summarize validity/complexity statistics for one model's sampled expressions.

    Parameters:
        expressions: list of dicts produced by DebugREINFORCE; each record has
            at least "is_valid" and "expression", and scored entries carry "r2".
            (Schema assumed from usage here — confirm against DebugREINFORCE.)

    Returns:
        dict with counts, percentages, depth stats, and best R² — the
        per-model payload written to the results JSON (minus "model_path").
    """
    valid = [e for e in expressions if e["is_valid"]]

    # Cheap syntactic probes for structural richness of the sampled formulas.
    has_power = sum(1 for e in valid
                    if '**' in e['expression'] or 'pow(' in e['expression'])
    has_nested_trig = sum(1 for e in valid
                          if any(nested in e['expression']
                                 for nested in ['sin(sin', 'sin(cos', 'cos(sin', 'cos(cos']))

    # Parenthesis count is a crude proxy for expression-tree depth.
    depths = [max(e['expression'].count('('), 1) for e in valid]

    # Use .get: invalid expressions may never have been scored, so indexing
    # e['r2'] directly could raise KeyError.
    best_r2 = max((e.get('r2', -1.0) for e in expressions), default=-1.0)

    return {
        "total_expressions": len(expressions),
        "valid_count": len(valid),
        "valid_pct": 100 * len(valid) / len(expressions) if expressions else 0,
        "has_power": has_power,
        "has_power_pct": 100 * has_power / len(valid) if valid else 0,
        "has_nested_trig": has_nested_trig,
        "has_nested_trig_pct": 100 * has_nested_trig / len(valid) if valid else 0,
        "avg_depth": sum(depths) / len(depths) if depths else 0,
        "max_depth": max(depths) if depths else 0,
        "best_r2": float(best_r2),
    }


def _print_comparison(results):
    """Print a base-vs-medium side-by-side table with relative improvement.

    Parameters:
        results: dict with "base" and "medium" keys, each holding the stats
            dict produced by _expression_stats.
    """
    print("="*80)
    print("COMPARISON SUMMARY")
    print("="*80)
    print(f"{'Metric':<25} {'Base':>15} {'Medium':>15} {'Improvement':>15}")
    print("-"*80)

    metrics = [
        ("Valid %", "valid_pct", "%"),
        ("Power %", "has_power_pct", "%"),
        ("Nested Trig %", "has_nested_trig_pct", "%"),
        ("Avg Depth", "avg_depth", ""),
        ("Best R2", "best_r2", ""),
    ]

    for label, key, unit in metrics:
        base_val = results["base"][key]
        medium_val = results["medium"][key]

        # Relative improvement is undefined for a zero baseline.
        if base_val != 0:
            improvement = ((medium_val - base_val) / abs(base_val)) * 100
            print(f"{label:<25} {base_val:>14.2f}{unit} {medium_val:>14.2f}{unit} {improvement:>+14.1f}%")
        else:
            print(f"{label:<25} {base_val:>14.2f}{unit} {medium_val:>14.2f}{unit} {'N/A':>15}")


def main():
    """Compare two trained models' exploration behavior on a benchmark dataset.

    Loads the benchmark CSV, runs DebugREINFORCE with each model, collects
    per-model complexity statistics, writes them to a JSON file, and prints a
    side-by-side summary table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_base", type=str, required=True, help="Path to trained base model")
    parser.add_argument("--model_medium", type=str, required=True, help="Path to trained medium model")
    parser.add_argument("--dataset", type=str, default="data/benchmarks/nguyen/nguyen_5.csv")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--output_file", type=str, default="model_comparison.json")
    args = parser.parse_args()

    # Local import preserved from the original: keeps pandas optional for
    # anything that imports this module without running it.
    import pandas as pd
    df = pd.read_csv(args.dataset)
    x_cols = [c for c in df.columns if c.startswith('x_')]
    X = df[x_cols].values
    y = df['y'].values

    print("="*80)
    print("COMPARING TRAINED MODELS")
    print("="*80)
    print(f"Dataset: {args.dataset}")
    print(f" Samples: {len(df)}, Variables: {len(x_cols)}")
    print(f" Target range: [{y.min():.4f}, {y.max():.4f}]")
    print()

    results = {}

    for model_name, model_path in [("base", args.model_base), ("medium", args.model_medium)]:
        print("="*80)
        print(f"Testing {model_name.upper()} model: {model_path}")
        print("="*80)

        reinforce = DebugREINFORCE(model_path, X, y)
        reinforce.run(epochs=args.epochs)

        stats = _expression_stats(reinforce.all_expressions)
        results[model_name] = {"model_path": model_path, **stats}

        print()
        print(f"Results for {model_name}:")
        print(f" Valid: {stats['valid_count']}/{stats['total_expressions']} ({stats['valid_pct']:.1f}%)")
        print(f" With power: {stats['has_power']} ({stats['has_power_pct']:.1f}%)")
        print(f" With nested trig: {stats['has_nested_trig']} ({stats['has_nested_trig_pct']:.1f}%)")
        print(f" Avg depth: {stats['avg_depth']:.2f}")
        print(f" Best R2: {stats['best_r2']:.4f}")
        print()

    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)

    _print_comparison(results)

    print()
    print(f"Results saved to: {args.output_file}")
| |
|
| |
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
| |
|