| """Training curve generation — real PyTorch mini-training. |
| |
| All curves come from run_real_training() in pytorch_engine.py: |
| - Real torch.nn.Module (SimpleCNN or SimpleMLP) |
| - Real torch.autograd forward + backward passes |
| - Real torch.optim optimizer steps |
| - Real validation on held-out data |
| - 20 epochs, cached per (task_id, seed, model_type) |
| |
| Zero numpy. Zero parametric formulas. Zero synthetic curves. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
|
|
| from ml_training_debugger.scenarios import ScenarioParams |
|
|
# Number of epochs the training curves in this module span (the helpers'
# docstrings describe 20-epoch histories). Not referenced by the code in this
# section; presumably kept in sync with the epoch count used by
# pytorch_engine.run_real_training — confirm.
EPOCHS = 20
|
|
|
|
def _get_real_curves(scenario: ScenarioParams) -> dict[str, list[float]]:
    """Fetch the training curves for ``scenario`` from the PyTorch engine.

    This is a thin delegate to ``pytorch_engine.run_real_training`` and
    returns its result unchanged. That engine is documented by its author as:
    - building a SimpleCNN/SimpleMLP,
    - injecting the scenario's fault,
    - running 20 epochs of real forward/backward passes on random
      CIFAR-10-shaped (3x32x32) data,
    - caching results per (task_id, seed, model_type).
    Those details live in pytorch_engine and are not verified here.

    The public ``gen_*_history`` helpers read the ``"loss_history"``,
    ``"val_loss_history"`` and ``"val_acc_history"`` entries of the returned
    mapping.
    """
    # Imported at call time rather than at module import time (presumably to
    # defer the heavy engine import or avoid an import cycle — confirm).
    from ml_training_debugger.pytorch_engine import (
        run_real_training as train_and_record,
    )

    curves = train_and_record(scenario)
    return curves
|
|
|
|
def gen_loss_history(scenario: ScenarioParams) -> list[float]:
    """Return the per-epoch training loss from real PyTorch training.

    Args:
        scenario: Fault scenario forwarded to ``_get_real_curves``.

    Returns:
        A fresh list copied from the ``"loss_history"`` curve (described as
        20 epochs long). The engine's results are documented as cached across
        resets, so handing out the cached list itself would let a caller that
        mutates it corrupt every later reset of the same scenario.
    """
    return list(_get_real_curves(scenario)["loss_history"])
|
|
|
|
def gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]:
    """Return the per-epoch validation accuracy from real PyTorch training.

    Args:
        scenario: Fault scenario forwarded to ``_get_real_curves``.

    Returns:
        A fresh list copied from the ``"val_acc_history"`` curve (described
        as 20 epochs long). Copying keeps callers from mutating the list held
        by the engine's result cache, which is documented as reused across
        resets.
    """
    return list(_get_real_curves(scenario)["val_acc_history"])
|
|
|
|
def gen_val_loss_history(scenario: ScenarioParams) -> list[float]:
    """Return the per-epoch validation loss from real PyTorch training.

    Args:
        scenario: Fault scenario forwarded to ``_get_real_curves``.

    Returns:
        A fresh list copied from the ``"val_loss_history"`` curve (described
        as 20 epochs long). Copying keeps callers from mutating the list held
        by the engine's result cache, which is documented as reused across
        resets.
    """
    return list(_get_real_curves(scenario)["val_loss_history"])
|
|
|
|
| def _gen_confusion_matrix(scenario: ScenarioParams) -> list[list[float]]: |
| """Generate a 10x10 confusion matrix based on the fault type. |
| |
| Uses torch.Tensor operations on random data shaped by the fault scenario. |
| """ |
| torch.manual_seed(scenario.seed + 10) |
| root = scenario.root_cause.value |
| n = 10 |
|
|
| if root == "data_leakage": |
| |
| base = torch.eye(n) * 0.8 |
| noise = torch.rand(n, n) * scenario.leakage_pct * 0.3 |
| cm = base + noise |
| elif root == "overfitting": |
| |
| cm = torch.eye(n) * 0.95 + torch.rand(n, n) * 0.02 |
| else: |
| |
| cm = torch.eye(n) * 0.6 + torch.rand(n, n) * 0.08 |
|
|
| |
| row_sums = cm.sum(dim=1, keepdim=True) |
| cm = cm / row_sums |
| return cm.tolist() |
|
|
|
|
def gen_data_batch_stats(scenario: ScenarioParams) -> dict:
    """Generate synthetic data-batch statistics for the scenario.

    Every scenario reports:
    - a uniform 10-class label distribution (0.1 each),
    - ``null_count`` 0,
    - ``batch_size`` 64,
    - the confusion matrix from ``_gen_confusion_matrix``.

    The remaining fields depend on ``scenario.root_cause.value``:

    - ``"data_leakage"``: ``class_overlap_score`` is
      ``min(0.5 + 1.5 * leakage_pct, 0.92)`` and ``duplicate_ratio`` equals
      ``leakage_pct``.
    - ``"overfitting"``: overlap score and duplicate ratio are both 0.0.
    - anything else: a small non-negative random overlap score
      (``|N(0, 1)| * 0.05``) and a duplicate ratio of 0.0.

    ``feature_mean`` and ``feature_std`` are fixed per-case centers plus a
    small Gaussian jitter.

    Args:
        scenario: Fault scenario. Reads ``seed`` and ``root_cause.value``,
            plus ``leakage_pct`` for the data-leakage case; it is also
            passed to ``_gen_confusion_matrix``.

    Returns:
        Dict with keys ``label_distribution``, ``feature_mean``,
        ``feature_std``, ``null_count``, ``class_overlap_score``,
        ``batch_size``, ``duplicate_ratio`` and ``confusion_matrix``.
    """
    # Draw from a private generator so the ``seed + 3`` offset actually
    # governs these values and the global torch RNG is left alone. Seeding
    # the global RNG here is fragile: any helper that reseeds or consumes it
    # before these draws (the confusion-matrix helper runs first) changes
    # the result.
    gen = torch.Generator().manual_seed(scenario.seed + 3)

    def jitter(scale: float) -> float:
        # One standard-normal draw from the private generator, scaled.
        return torch.randn(1, generator=gen).item() * scale

    root = scenario.root_cause.value
    cm = _gen_confusion_matrix(scenario)

    # Draw order (mean, std, then overlap) is fixed so results are
    # reproducible for a given seed.
    if root == "data_leakage":
        feature_mean = 0.45 + jitter(0.05)
        feature_std = 0.22 + jitter(0.02)
        overlap = min(0.5 + scenario.leakage_pct * 1.5, 0.92)
        duplicate_ratio = scenario.leakage_pct
    elif root == "overfitting":
        feature_mean = 0.48 + jitter(0.03)
        feature_std = 0.25 + jitter(0.02)
        overlap = 0.0
        duplicate_ratio = 0.0
    else:
        feature_mean = 0.47 + jitter(0.03)
        feature_std = 0.24 + jitter(0.02)
        overlap = abs(jitter(0.05))
        duplicate_ratio = 0.0

    return {
        "label_distribution": {i: 0.1 for i in range(10)},
        "feature_mean": feature_mean,
        "feature_std": feature_std,
        "null_count": 0,
        "class_overlap_score": overlap,
        "batch_size": 64,
        "duplicate_ratio": duplicate_ratio,
        "confusion_matrix": cm,
    }
|
|