# pytorch-training-debugger / validation / run_all_validations.py
# Author: omkarrr88 — "Real training curves added" (commit aa0bed2)
#!/usr/bin/env python3
"""Run all validation checks and produce a fidelity report.
Validates that real PyTorch mini-training produces qualitatively correct
behaviors for each fault type. Uses behavioral checks appropriate for
real training on tiny random-data models (not parametric formula checks).
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
import torch
import torch.nn as nn
sys.path.insert(0, str(Path(__file__).parent.parent))
from ml_training_debugger.pytorch_engine import (
SimpleCNN,
SimpleMLP,
create_model_and_inject_fault,
extract_gradient_stats,
extract_model_modes,
extract_weight_stats,
run_real_training,
)
from ml_training_debugger.scenarios import sample_scenario
from ml_training_debugger.simulation import gen_data_batch_stats
def validate_exploding_gradients() -> dict:
    """Task 1: High LR produces gradient instability.

    Checks that at least one layer reports an exploding gradient and that
    the real training loss curve shows instability (peak loss above 5.0).
    """
    scenario = sample_scenario("task_001", seed=42)
    model, _ = create_model_and_inject_fault(scenario)
    layer_stats = extract_gradient_stats(model, scenario)
    curves = run_real_training(scenario)
    peak_loss = max(curves["loss_history"])
    exploding_seen = any(layer.is_exploding for layer in layer_stats)
    loss_is_unstable = peak_loss > 5.0
    peak_grad_norm = max(layer.mean_norm for layer in layer_stats)
    checks = {
        "gradient_instability_detected": exploding_seen,
        "loss_shows_instability": loss_is_unstable,
        "max_gradient_norm": round(peak_grad_norm, 2),
        "max_loss": round(peak_loss, 2),
        "real_pytorch_training": True,
    }
    return {
        "task": "task_001",
        "fault": "exploding_gradients",
        "checks": checks,
        "pass": exploding_seen and loss_is_unstable,
    }
def validate_vanishing_gradients() -> dict:
    """Task 2: Low LR + scaled gradients produce vanishing.

    Checks that at least one layer's real PyTorch gradients are flagged as
    vanishing; also reports the smallest mean gradient norm observed.
    """
    scenario = sample_scenario("task_002", seed=42)
    model, _ = create_model_and_inject_fault(scenario)
    layer_stats = extract_gradient_stats(model, scenario)
    vanishing_seen = any(layer.is_vanishing for layer in layer_stats)
    smallest_norm = min(layer.mean_norm for layer in layer_stats)
    return {
        "task": "task_002",
        "fault": "vanishing_gradients",
        "checks": {
            "vanishing_detected": vanishing_seen,
            "min_gradient_norm": round(smallest_norm, 10),
            "real_pytorch_gradients": True,
        },
        "pass": vanishing_seen,
    }
def validate_data_leakage() -> dict:
    """Task 3: Data leakage produces high overlap score.

    Checks that the generated batch statistics show class overlap above 0.5
    and that real training completed a full 20-epoch loss history.
    """
    scenario = sample_scenario("task_003", seed=42)
    batch_stats = gen_data_batch_stats(scenario)
    curves = run_real_training(scenario)
    overlap = batch_stats["class_overlap_score"]
    overlap_is_high = overlap > 0.5
    ran_all_epochs = len(curves["loss_history"]) == 20
    return {
        "task": "task_003",
        "fault": "data_leakage",
        "checks": {
            "class_overlap_above_0.5": overlap_is_high,
            "class_overlap_score": round(overlap, 4),
            "real_training_runs": ran_all_epochs,
            "has_confusion_matrix": "confusion_matrix" in batch_stats,
        },
        "pass": overlap_is_high and ran_all_epochs,
    }
def validate_overfitting() -> dict:
    """Task 4: Overfitting scenario runs real training.

    Checks that real training completed all 20 epochs on data with zero
    class overlap (clean data); reports final train/val losses.
    """
    scenario = sample_scenario("task_004", seed=42)
    curves = run_real_training(scenario)
    batch_stats = gen_data_batch_stats(scenario)
    ran_all_epochs = len(curves["loss_history"]) == 20
    data_is_clean = batch_stats["class_overlap_score"] == 0.0
    checks = {
        "real_training_runs": ran_all_epochs,
        "clean_data": data_is_clean,
        "final_train_loss": round(curves["loss_history"][-1], 4),
        "final_val_loss": round(curves["val_loss_history"][-1], 4),
    }
    return {
        "task": "task_004",
        "fault": "overfitting",
        "checks": checks,
        "pass": ran_all_epochs and data_is_clean,
    }
def validate_batchnorm_eval() -> dict:
    """Task 5: BatchNorm eval mode + red herrings.

    Checks that every reported layer mode is "eval", that no layer shows
    exploding gradients (the spike is a red herring), that real training ran
    a full 20 epochs, and that the model object itself is not in train mode.
    """
    scenario = sample_scenario("task_005", seed=42)
    model, _ = create_model_and_inject_fault(scenario)
    layer_stats = extract_gradient_stats(model, scenario)
    layer_modes = extract_model_modes(model)
    curves = run_real_training(scenario)
    everything_eval = all(mode == "eval" for mode in layer_modes.values())
    nothing_exploding = not any(layer.is_exploding for layer in layer_stats)
    ran_all_epochs = len(curves["loss_history"]) == 20
    checks = {
        "all_layers_in_eval_mode": everything_eval,
        "no_layer_is_exploding": nothing_exploding,
        "real_training_runs": ran_all_epochs,
        "real_model_eval_mode": not model.training,
        "red_herring_spike_layer": scenario.red_herring_spike_layer,
    }
    return {
        "task": "task_005",
        "fault": "batchnorm_eval_mode",
        "checks": checks,
        "pass": everything_eval and nothing_exploding and ran_all_epochs,
    }
def validate_code_bugs() -> dict:
    """Task 6: Code bug variants.

    For each bug variant, confirms that snippet generation succeeds, that
    the canonical fix from the template is accepted by the fix validator,
    and that a trivially wrong fix ("pass") is rejected.
    """
    from ml_training_debugger.code_templates import (
        _TEMPLATES,
        generate_code_snippet,
        validate_fix,
    )
    variants = ["eval_mode", "detach_loss", "zero_grad_missing", "inplace_relu"]
    results = {}
    for variant in variants:
        # Smoke-check that snippet generation works; the snippet text
        # itself is not asserted, so the result is deliberately unbound.
        generate_code_snippet(variant, seed=42)
        _, correct_line, correct_replacement = _TEMPLATES[variant]
        fix_accepted = validate_fix(variant, correct_line, correct_replacement)
        wrong_rejected = not validate_fix(variant, correct_line, "pass")
        results[variant] = {
            "correct_fix_accepted": fix_accepted,
            "wrong_fix_rejected": wrong_rejected,
        }
    all_pass = all(
        r["correct_fix_accepted"] and r["wrong_fix_rejected"]
        for r in results.values()
    )
    return {
        "task": "task_006",
        "fault": "code_bug",
        "checks": {
            "variants_tested": len(variants),
            "variant_results": results,
            "fix_validation_pipeline": "normalize -> tokenize -> semantic -> AST",
        },
        "pass": all_pass,
    }
def validate_scheduler() -> dict:
    """Task 7: Scheduler misconfigured.

    Checks that real training completed 20 epochs; reports the scheduler
    hyperparameters and the final loss for the fidelity report.
    """
    scenario = sample_scenario("task_007", seed=42)
    curves = run_real_training(scenario)
    losses = curves["loss_history"]
    ran_all_epochs = len(losses) == 20
    return {
        "task": "task_007",
        "fault": "scheduler_misconfigured",
        "checks": {
            "real_training_runs": ran_all_epochs,
            "scheduler_gamma": scenario.scheduler_gamma,
            "scheduler_step_size": scenario.scheduler_step_size,
            "final_loss": round(losses[-1], 4),
        },
        "pass": ran_all_epochs,
    }
def validate_dual_architecture() -> dict:
    """Verify both CNN and MLP architectures work.

    Runs a batch of 4 random 3x32x32 inputs through each model and checks
    that both produce a (4, 10) output (10-class logits).
    """
    batch = torch.randn(4, 3, 32, 32)
    cnn_model = SimpleCNN()
    mlp_model = SimpleMLP()
    logits_cnn = cnn_model(batch)
    logits_mlp = mlp_model(batch)
    ten_classes = logits_cnn.shape[1] == 10 and logits_mlp.shape[1] == 10
    checks = {
        "cnn_output_shape": list(logits_cnn.shape),
        "mlp_output_shape": list(logits_mlp.shape),
        "cnn_params": sum(p.numel() for p in cnn_model.parameters()),
        "mlp_params": sum(p.numel() for p in mlp_model.parameters()),
        "both_produce_10_classes": ten_classes,
    }
    return {
        "task": "architecture",
        "fault": "dual_model_support",
        "checks": checks,
        "pass": logits_cnn.shape == (4, 10) and logits_mlp.shape == (4, 10),
    }
def main() -> None:
    """Run every validation, write the JSON fidelity report, print a summary.

    Side effects: creates validation/reports/ if missing, writes
    fidelity_report.json there, and prints per-task PASS/FAIL lines plus a
    totals line to stdout.
    """
    validations = [
        validate_exploding_gradients(),
        validate_vanishing_gradients(),
        validate_data_leakage(),
        validate_overfitting(),
        validate_batchnorm_eval(),
        validate_code_bugs(),
        validate_scheduler(),
        validate_dual_architecture(),
    ]
    report = {
        "methodology": "Real PyTorch 20-epoch mini-training with fault injection",
        "torch_version": torch.__version__,
        "models": ["SimpleCNN (~50K params)", "SimpleMLP (~20K params)"],
        "training_approach": "Real forward+backward passes on random CIFAR-10 style data, cached per (task_id, seed)",
        "results": validations,
        "summary": {
            "total": len(validations),
            "passed": sum(1 for v in validations if v["pass"]),
            "failed": sum(1 for v in validations if not v["pass"]),
        },
    }
    report_path = Path(__file__).parent / "reports" / "fidelity_report.json"
    report_path.parent.mkdir(parents=True, exist_ok=True)
    # default=str stringifies non-JSON values (e.g. tensors) deliberately.
    report_path.write_text(json.dumps(report, indent=2, default=str))
    for v in validations:
        status = "PASS" if v["pass"] else "FAIL"
        # Bug fix: the original f-string concatenated task and fault with no
        # separator (e.g. "task_001exploding_gradients").
        print(f" {status}: {v['task']} - {v['fault']}")
    print(f"\n{report['summary']['passed']}/{report['summary']['total']} validations passed")
    print(f"Report saved to {report_path}")
if __name__ == "__main__":
    main()