| | """ |
| | Compare two models: band-aided vs properly trained. |
| | Evaluates both on same test set and reports metrics. |
| | |
| | Usage: |
| | python scripts/compare_models.py \ |
| | --model1 ./output/Se124M_700K_infix \ |
| | --model2 ./output/Se124M_700K_infix_v2 \ |
| | --num_samples 500 |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import sys |
| | from datetime import datetime |
| |
|
| | |
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from evaluate import evaluate_model |
| |
|
| |
|
def format_metric(value, metric_type):
    """Render a metric value as a fixed-width string for table display.

    Supported metric_type values: "rate" (fraction shown as percent),
    "float", "int"; anything else falls back to plain width-7 formatting.
    """
    renderers = {
        "rate": lambda v: f"{v * 100:5.1f}%",
        "float": lambda v: f"{v:7.2f}",
        "int": lambda v: f"{int(v):7d}",
    }
    render = renderers.get(metric_type, lambda v: f"{v:7}")
    return render(value)
| |
|
| |
|
def print_comparison_table(metrics1, metrics2, model1_name, model2_name):
    """Print a side-by-side metrics table, rate improvements, and a verdict.

    Missing metric keys default to 0.  The verdict compares valid_rate of
    model 2 against model 1 at 20pp / 5pp / >0 thresholds.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    def _render(raw, kind):
        # Inline equivalent of format_metric: fixed-width cell per metric kind.
        if kind == "rate":
            return f"{raw * 100:5.1f}%"
        if kind == "float":
            return f"{raw:7.2f}"
        if kind == "int":
            return f"{int(raw):7d}"
        return f"{raw:7}"

    print("\n" + heavy_rule)
    print("COMPARISON RESULTS")
    print(heavy_rule)
    print(f"{'Metric':<35} {model1_name:>20} {model2_name:>20}")
    print(light_rule)

    rows = (
        ("valid_rate", "Valid Rate", "rate"),
        ("parseable_rate", "Parseable Rate", "rate"),
        ("constraints_met_rate", "Constraints Met", "rate"),
        ("diversity_rate", "Diversity", "rate"),
        ("avg_expression_length", "Avg Expression Length", "float"),
        ("total_samples", "Total Samples", "int"),
        ("total_valid", "Total Valid", "int"),
    )

    rate_deltas = []
    for key, label, kind in rows:
        left = metrics1.get(key, 0)
        right = metrics2.get(key, 0)
        print(f"{label:<35} {_render(left, kind):>20} {_render(right, kind):>20}")
        # Relative improvement is only meaningful for nonzero baseline rates.
        if kind == "rate" and left > 0:
            rate_deltas.append((label, ((right - left) / left) * 100, right - left))

    print(heavy_rule)

    print("\nIMPROVEMENTS (Model 2 vs Model 1):")
    print(light_rule)
    for label, rel_pct, abs_diff in rate_deltas:
        rel_sign = "+" if rel_pct > 0 else ""
        abs_sign = "+" if abs_diff > 0 else ""
        print(f"{label:<35} {rel_sign}{rel_pct:>6.1f}% ({abs_sign}{abs_diff * 100:>5.1f} pp)")
    print(light_rule)

    delta = metrics2.get("valid_rate", 0) - metrics1.get("valid_rate", 0)

    print("\n" + heavy_rule)
    if delta > 0.20:
        print(f"🎯 SIGNIFICANT IMPROVEMENT: Model 2 wins by {delta * 100:.1f} percentage points")
        print("   The properly trained model significantly outperforms the band-aided version!")
    elif delta > 0.05:
        print(f"✅ IMPROVEMENT: Model 2 wins by {delta * 100:.1f} percentage points")
        print("   The properly trained model shows clear improvement.")
    elif delta > 0:
        print(f"📈 SLIGHT IMPROVEMENT: Model 2 wins by {delta * 100:.1f} percentage points")
        print("   The properly trained model shows modest improvement.")
    elif delta == 0:
        print("⚖️ TIE: Both models perform equally")
        print("   No significant difference between models.")
    else:
        print(f"⚠️ REGRESSION: Model 1 wins by {-delta * 100:.1f} percentage points")
        print("   The band-aided model performs better - retraining may need adjustment.")

    print(heavy_rule)
| |
|
| |
|
def save_comparison_report(metrics1, metrics2, model1_name, model2_name, output_dir):
    """Write a timestamped JSON comparison report and return its path.

    The report contains both metric dicts plus signed model2-minus-model1
    differences for the four rate metrics.  *output_dir* is created if needed.
    """
    os.makedirs(output_dir, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    destination = os.path.join(output_dir, f"comparison_{stamp}.json")

    def _delta(key):
        # Signed difference (model 2 minus model 1); missing keys count as 0.
        return metrics2.get(key, 0) - metrics1.get(key, 0)

    payload = {
        "timestamp": stamp,
        "model1": {"name": model1_name, "metrics": metrics1},
        "model2": {"name": model2_name, "metrics": metrics2},
        "comparison": {
            "valid_rate_diff": _delta("valid_rate"),
            "parseable_rate_diff": _delta("parseable_rate"),
            "constraints_met_diff": _delta("constraints_met_rate"),
            "diversity_diff": _delta("diversity_rate"),
        },
    }

    with open(destination, "w") as handle:
        json.dump(payload, handle, indent=2)

    print(f"\n📄 Detailed comparison report saved to: {destination}")
    return destination
| |
|
| |
|
def _build_eval_args(model_path, dataset_repo_id, data_dir, data_column,
                     num_samples, output_dir):
    """Build the argparse.Namespace consumed by evaluate_model.

    Generation settings (1 generation, 128 new tokens, temperature 0.7,
    top-p 0.9, seed 42, device "auto") are fixed so both models are
    evaluated under identical conditions.
    """
    return argparse.Namespace(
        model_path=model_path,
        base_model=None,
        dataset_repo_id=dataset_repo_id,
        data_dir=data_dir,
        data_column=data_column,
        num_samples=num_samples,
        num_generations=1,
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
        output_dir=output_dir,
        seed=42,
        device="auto",
    )


def _evaluate_or_exit(index, model_name, eval_args):
    """Run evaluate_model for one model; print a traceback and exit(1) on failure."""
    print(f"\n[{index}/2] Evaluating Model {index}: {model_name}")
    print("-" * 80)
    try:
        return evaluate_model(eval_args)
    except Exception as e:
        print(f"\n❌ Error evaluating Model {index}: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


def compare_models(model1_path, model2_path, model1_name, model2_name,
                   num_samples=500, dataset_repo_id="augustocsc/sintetico_natural",
                   data_dir="700K", data_column="i_prompt_n", output_dir="./evaluation_results/comparison"):
    """Compare two models on same test set.

    Evaluates both checkpoints with identical generation settings, prints
    the comparison table, and saves a JSON report under *output_dir*.
    Exits the process (status 1) if either evaluation raises.

    Returns:
        (metrics1, metrics2): the metric dicts produced by evaluate_model.
    """
    print("=" * 80)
    print("MODEL COMPARISON")
    print("=" * 80)
    print(f"Model 1 ({model1_name}): {model1_path}")
    print(f"Model 2 ({model2_name}): {model2_path}")
    print(f"Samples: {num_samples}")
    print(f"Dataset: {dataset_repo_id}/{data_dir}")
    print("=" * 80)

    os.makedirs(output_dir, exist_ok=True)

    # The two evaluation runs previously duplicated ~30 lines each; the shared
    # helpers guarantee both models are run with exactly the same settings.
    metrics1 = _evaluate_or_exit(
        1, model1_name,
        _build_eval_args(model1_path, dataset_repo_id, data_dir, data_column,
                         num_samples, os.path.join(output_dir, "model1")))
    metrics2 = _evaluate_or_exit(
        2, model2_name,
        _build_eval_args(model2_path, dataset_repo_id, data_dir, data_column,
                         num_samples, os.path.join(output_dir, "model2")))

    print_comparison_table(metrics1, metrics2, model1_name, model2_name)

    save_comparison_report(metrics1, metrics2, model1_name, model2_name, output_dir)

    return metrics1, metrics2
| |
|
| |
|
def main():
    """CLI entry point: parse arguments, run the comparison, report outcome."""
    arg_parser = argparse.ArgumentParser(
        description="Compare two models on the same test set"
    )
    arg_parser.add_argument("--model1", type=str, required=True,
                            help="Path to first model (band-aided)")
    arg_parser.add_argument("--model2", type=str, required=True,
                            help="Path to second model (properly trained)")
    arg_parser.add_argument("--model1_name", type=str, default="Band-Aided",
                            help="Display name for model 1")
    arg_parser.add_argument("--model2_name", type=str, default="Proper",
                            help="Display name for model 2")
    arg_parser.add_argument("--num_samples", type=int, default=500,
                            help="Number of samples to evaluate")
    arg_parser.add_argument("--dataset_repo_id", type=str, default="augustocsc/sintetico_natural",
                            help="HuggingFace dataset repository")
    arg_parser.add_argument("--data_dir", type=str, default="700K",
                            help="Data directory within dataset")
    arg_parser.add_argument("--data_column", type=str, default="i_prompt_n",
                            help="Column name for prompts")
    arg_parser.add_argument("--output_dir", type=str, default="./evaluation_results/comparison",
                            help="Directory to save comparison results")
    opts = arg_parser.parse_args()

    try:
        compare_models(
            model1_path=opts.model1,
            model2_path=opts.model2,
            model1_name=opts.model1_name,
            model2_name=opts.model2_name,
            num_samples=opts.num_samples,
            dataset_repo_id=opts.dataset_repo_id,
            data_dir=opts.data_dir,
            data_column=opts.data_column,
            output_dir=opts.output_dir,
        )
        print("\n✅ Comparison complete!")
    except Exception as e:
        print(f"\n❌ Error during comparison: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|