# Scraper/VCS residue preserved as comments (was a SyntaxError as bare text):
# omkarrr88
# minor changes
# 206438f
"""Training curve generation — real PyTorch mini-training.
All curves come from run_real_training() in pytorch_engine.py:
- Real torch.nn.Module (SimpleCNN or SimpleMLP)
- Real torch.autograd forward + backward passes
- Real torch.optim optimizer steps
- Real validation on held-out data
- 20 epochs, cached per (task_id, seed, model_type)
Zero numpy. Zero parametric formulas. Zero synthetic curves.
"""
from __future__ import annotations
import torch
from ml_training_debugger.scenarios import ScenarioParams
EPOCHS = 20
def _get_real_curves(scenario: ScenarioParams) -> dict[str, list[float]]:
    """Execute real PyTorch training for *scenario* and return its curves.

    Delegates to pytorch_engine.run_real_training(), which:
    - Creates a real SimpleCNN or SimpleMLP model
    - Generates random CIFAR-10 style data (3x32x32)
    - Runs 20 epochs of real forward/backward passes
    - Injects the actual fault (wrong LR, eval mode, data leakage, etc.)
    - Returns real loss_history, val_loss_history, val_acc_history
    Results are cached per (task_id, seed, model_type) for instant resets.
    """
    # NOTE(review): import is local in the original — presumably to avoid an
    # import cycle or keep module import light; kept as-is.
    from ml_training_debugger.pytorch_engine import run_real_training

    curves = run_real_training(scenario)
    return curves
def gen_loss_history(scenario: ScenarioParams) -> list[float]:
    """Return the 20-epoch training-loss curve from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["loss_history"]
def gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]:
    """Return the 20-epoch validation-accuracy curve from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["val_acc_history"]
def gen_val_loss_history(scenario: ScenarioParams) -> list[float]:
    """Return the 20-epoch validation-loss curve from real PyTorch training."""
    curves = _get_real_curves(scenario)
    return curves["val_loss_history"]
def _gen_confusion_matrix(scenario: ScenarioParams) -> list[list[float]]:
    """Build a 10x10 row-normalized confusion matrix shaped by the fault type.

    Uses torch.Tensor operations on seeded random data; the diagonal weight
    and off-diagonal noise scale are chosen by the scenario's root cause.
    """
    torch.manual_seed(scenario.seed + 10)
    n = 10
    root = scenario.root_cause.value
    if root == "data_leakage":
        # Strong diagonal degraded by leakage-proportional off-diagonal noise.
        diag_weight, noise_scale = 0.8, scenario.leakage_pct * 0.3
    elif root == "overfitting":
        # Near-perfect diagonal: the model has memorized the training data.
        diag_weight, noise_scale = 0.95, 0.02
    else:
        # Healthy training: moderate accuracy with mild confusion.
        diag_weight, noise_scale = 0.6, 0.08
    matrix = torch.eye(n) * diag_weight + torch.rand(n, n) * noise_scale
    # Normalize each row so its entries sum to 1.0.
    matrix = matrix / matrix.sum(dim=1, keepdim=True)
    return matrix.tolist()
def gen_data_batch_stats(scenario: ScenarioParams) -> dict:
    """Generate data batch statistics for the scenario.

    Returns a dict with a balanced label distribution, feature moments,
    overlap/duplicate indicators, a fixed batch size, and a 10x10 confusion
    matrix. Leakage- and overfitting-specific fields depend on the fault type.
    """
    # Build the confusion matrix FIRST: _gen_confusion_matrix reseeds the
    # global torch RNG with (seed + 10), which previously clobbered the
    # manual_seed(seed + 3) call made before it — making that seed dead and
    # drawing the randn() values below from the wrong stream. Seeding after
    # the matrix is built restores the intended determinism.
    cm = _gen_confusion_matrix(scenario)
    torch.manual_seed(scenario.seed + 3)

    # Fields shared by every fault type.
    stats = {
        "label_distribution": {i: 0.1 for i in range(10)},  # balanced 10-class data
        "null_count": 0,
        "batch_size": 64,
        "confusion_matrix": cm,
    }

    root = scenario.root_cause.value
    if root == "data_leakage":
        # Overlap grows with the leakage fraction, capped below 1.0.
        overlap = min(0.5 + scenario.leakage_pct * 1.5, 0.92)
        stats.update(
            feature_mean=0.45 + torch.randn(1).item() * 0.05,
            feature_std=0.22 + torch.randn(1).item() * 0.02,
            class_overlap_score=overlap,
            duplicate_ratio=scenario.leakage_pct,
        )
    elif root == "overfitting":
        stats.update(
            feature_mean=0.48 + torch.randn(1).item() * 0.03,
            feature_std=0.25 + torch.randn(1).item() * 0.02,
            class_overlap_score=0.0,
            duplicate_ratio=0.0,
        )
    else:
        # Default: normal, healthy data with a small random overlap score.
        stats.update(
            feature_mean=0.47 + torch.randn(1).item() * 0.03,
            feature_std=0.24 + torch.randn(1).item() * 0.02,
            class_overlap_score=torch.randn(1).abs().item() * 0.05,
            duplicate_ratio=0.0,
        )
    return stats