| """PyTorch-native fault injection engine. |
| |
| Real torch.nn.Module models, real torch.autograd gradients, |
| real state_dict() weight snapshots. Zero numpy. |
| """ |
|
|
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn

from ml_training_debugger.models import GradientStats, ModelWeightStats
from ml_training_debugger.nn_models import SimpleCNN, SimpleMLP, create_model
from ml_training_debugger.scenarios import ScenarioParams
|
|
| |
| __all__ = ["SimpleCNN", "SimpleMLP", "create_model"] |
|
|
| _create_model = create_model |
|
|
|
|
| |
| _TRAINING_CACHE: dict[tuple[str, int, str], dict[str, list[float]]] = {} |
|
|
| TRAINING_EPOCHS = 20 |
| TRAINING_BATCH_SIZE = 16 |
|
|
|
|
def run_real_training(scenario: ScenarioParams) -> dict[str, list[float]]:
    """Run a real ``TRAINING_EPOCHS``-epoch mini-training and return curves.

    Trains the scenario's model on a small synthetic dataset, with the
    optimizer / train-eval mode configured according to the scenario's
    root cause, and records per-epoch training loss, validation loss, and
    validation accuracy.

    Results are cached per ``(task_id, seed, model_type)`` so subsequent
    resets are instant.  Each uncached call takes ~0.5-2s on CPU.

    Args:
        scenario: Fault scenario controlling seed, model type, optimizer
            hyperparameters, and the injected root cause.

    Returns:
        Dict with ``"loss_history"``, ``"val_loss_history"``, and
        ``"val_acc_history"``, each a list of one float per epoch.
        NOTE: the cached dict object itself is returned; callers must
        treat it as read-only.
    """
    cache_key = (scenario.task_id, scenario.seed, scenario.model_type)
    if cache_key in _TRAINING_CACHE:
        return _TRAINING_CACHE[cache_key]

    torch.manual_seed(scenario.seed)
    model = _create_model(scenario.model_type)
    criterion = nn.CrossEntropyLoss()
    root = scenario.root_cause.value

    # Configure the optimizer and train/eval mode according to the fault.
    # `scheduler` is initialized up front so it is never an unbound local.
    scheduler = None
    if root in ("lr_too_high", "vanishing_gradients"):
        # Both faults train with plain SGD at the scenario's (possibly
        # pathological) learning rate.
        optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
        model.train()
    elif root == "batchnorm_eval_mode":
        # The fault itself: training while the model (incl. BatchNorm) is
        # left in eval mode.
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        model.eval()
    elif root == "scheduler_misconfigured":
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=scenario.scheduler_step_size,
            gamma=scenario.scheduler_gamma,
        )
        model.train()
    elif root == "overfitting":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=0.001, weight_decay=scenario.weight_decay
        )
        model.train()
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        model.train()

    loss_history: list[float] = []
    val_loss_history: list[float] = []
    val_acc_history: list[float] = []

    # Synthetic dataset; seed offset so data differs from the weight init.
    torch.manual_seed(scenario.seed + 100)
    train_x = torch.randn(TRAINING_BATCH_SIZE * 4, 3, 32, 32)
    train_y = torch.randint(0, 10, (TRAINING_BATCH_SIZE * 4,))
    val_x = torch.randn(TRAINING_BATCH_SIZE, 3, 32, 32)
    val_y = torch.randint(0, 10, (TRAINING_BATCH_SIZE,))

    # Inject leakage: copy a fraction of training samples into the val set.
    if root == "data_leakage":
        leak_count = max(1, int(TRAINING_BATCH_SIZE * scenario.leakage_pct))
        val_x[:leak_count] = train_x[:leak_count]
        val_y[:leak_count] = train_y[:leak_count]

    for epoch in range(TRAINING_EPOCHS):
        # Cycle through the 4 fixed mini-batches.
        batch_idx = (epoch % 4) * TRAINING_BATCH_SIZE
        bx = train_x[batch_idx : batch_idx + TRAINING_BATCH_SIZE]
        by = train_y[batch_idx : batch_idx + TRAINING_BATCH_SIZE]

        optimizer.zero_grad()
        output = model(bx)
        loss = criterion(output, by)

        # Record loss, mapping NaN (diverged training) to inf for plotting.
        loss_val = loss.item()
        loss_history.append(float("inf") if math.isnan(loss_val) else loss_val)

        try:
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
        except RuntimeError:
            # Numerical blow-up in backward/step: mark the epoch as diverged.
            loss_history[-1] = float("inf")

        # Validation pass (no gradients).
        with torch.no_grad():
            val_out = model(val_x)
            v_loss_val = criterion(val_out, val_y).item()
            val_loss_history.append(
                float("inf") if math.isnan(v_loss_val) else v_loss_val
            )
            preds = val_out.argmax(dim=1)
            val_acc_history.append((preds == val_y).float().mean().item())

    result = {
        "loss_history": loss_history,
        "val_loss_history": val_loss_history,
        "val_acc_history": val_acc_history,
    }
    _TRAINING_CACHE[cache_key] = result
    return result
|
|
|
|
def create_model_and_inject_fault(
    scenario: ScenarioParams,
) -> tuple[nn.Module, dict]:
    """Instantiate a real PyTorch model and inject the specified fault.

    Runs genuine forward/backward passes configured so the requested root
    cause leaves its fingerprint on the model's gradients and/or weights.

    Args:
        scenario: Fault scenario; ``scenario.root_cause.value`` selects
            the injection strategy.

    Returns:
        (model, info_dict) where info_dict contains computed artifacts
        (currently only ``"final_lr"`` for scheduler_misconfigured).
    """
    torch.manual_seed(scenario.seed)

    model = _create_model(scenario.model_type)
    criterion = nn.CrossEntropyLoss()
    info: dict = {}

    # Small fixed synthetic batch shared by every injection path.
    batch_x = torch.randn(8, 3, 32, 32)
    batch_y = torch.randint(0, 10, (8,))

    root = scenario.root_cause.value

    if root == "lr_too_high":
        # Amplify the LR 10x and take a few SGD steps so the weights blow
        # up, then do a final backward WITHOUT a step so fresh gradients
        # from the exploded weights remain on the model for inspection.
        model.train()
        optimizer = torch.optim.SGD(
            model.parameters(), lr=scenario.learning_rate * 10.0
        )
        for _ in range(3):
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()

    elif root == "vanishing_gradients":
        # Compute real gradients, then exponentially shrink them layer by
        # layer to mimic a very deep network's vanishing-gradient profile.
        model.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        depth_mult = scenario.depth_multiplier
        layer_idx = 0
        for _name, param in model.named_parameters():
            if param.grad is not None:
                # Base scale 1e-7, decaying further with parameter depth.
                decay = 1e-7 * math.exp(-depth_mult * layer_idx)
                param.grad.mul_(decay)
                layer_idx += 1

    elif root in ("data_leakage", "overfitting", "code_bug"):
        # These faults need only one ordinary Adam training step; only
        # "overfitting" adds the scenario's weight decay (Adam's default
        # weight_decay is 0.0, so the merge is behavior-preserving).
        model.train()
        weight_decay = scenario.weight_decay if root == "overfitting" else 0.0
        optimizer = torch.optim.Adam(
            model.parameters(), lr=0.001, weight_decay=weight_decay
        )
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()

    elif root == "batchnorm_eval_mode":
        # The fault: forward/backward while the model is in eval mode, so
        # BatchNorm uses running stats.  Grads start as None on a fresh
        # model, so no zero_grad() is needed before this single backward.
        model.eval()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()

    elif root == "scheduler_misconfigured":
        # Step an aggressive StepLR alongside training and record the
        # learning rate it decayed to.
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=scenario.scheduler_step_size,
            gamma=scenario.scheduler_gamma,
        )
        for _ in range(3):
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            scheduler.step()
        info["final_lr"] = optimizer.param_groups[0]["lr"]

    return model, info
|
|
|
|
def extract_gradient_stats(
    model: nn.Module,
    scenario: Optional[ScenarioParams] = None,
) -> list[GradientStats]:
    """Build per-layer gradient statistics from real ``param.grad`` tensors.

    For the batchnorm_eval_mode scenario, a red-herring spike is injected
    into the configured layer's synthesized norm history, and conv1 (when
    it is not the spike layer) is given a near-vanishing history.
    """
    # Pick the named layers to report based on the concrete model class.
    if isinstance(model, SimpleMLP):
        layer_pairs = [("fc1", model.fc1), ("fc2", model.fc2), ("fc3", model.fc3)]
    else:
        layer_pairs = [
            ("conv1", model.conv1),
            ("conv2", model.conv2),
            ("conv3", model.conv3),
            ("fc", model.fc),
        ]

    results: list[GradientStats] = []
    for layer_name, layer in layer_pairs:
        # Real gradient norms for every parameter that has a gradient;
        # fall back to a single zero when no grads exist.
        norms = [
            torch.norm(p.grad).item()
            for p in layer.parameters()
            if p.grad is not None
        ] or [0.0]

        mean_norm = sum(norms) / len(norms)
        max_norm = max(norms)

        # Synthesized 5-point history ramping 0.9x -> 1.1x around the mean.
        norm_history = [mean_norm * (0.9 + 0.2 * step / 4) for step in range(5)]

        if scenario is not None and scenario.root_cause.value == "batchnorm_eval_mode":
            if layer_name == scenario.red_herring_spike_layer:
                # Red-herring spike on the configured layer.
                spike = scenario.red_herring_intensity
                norm_history = [
                    mean_norm,
                    mean_norm,
                    mean_norm * spike,
                    mean_norm * spike * 1.2,
                    mean_norm,
                ]
                mean_norm = sum(norm_history) / len(norm_history)
                max_norm = max(norm_history)
            elif layer_name == "conv1":
                # conv1 is not the spike layer (that case took the branch
                # above): give it a near-vanishing history instead.
                near_vanish = 0.0003
                norm_history = [
                    near_vanish * (0.95 + 0.1 * step / 4) for step in range(5)
                ]
                mean_norm = near_vanish
                max_norm = max(norm_history)

        results.append(
            GradientStats(
                layer_name=layer_name,
                norm_history=norm_history,
                mean_norm=mean_norm,
                max_norm=max_norm,
                is_exploding=mean_norm > 10.0,
                is_vanishing=mean_norm < 1e-6,
            )
        )

    return results
|
|
|
|
def extract_weight_stats(model: nn.Module) -> list[ModelWeightStats]:
    """Collect statistics for every parameter whose name contains "weight".

    Reads real parameter tensors via ``model.named_parameters()``; bias
    parameters are skipped.
    """
    collected: list[ModelWeightStats] = []
    for param_name, tensor in model.named_parameters():
        if "weight" not in param_name:
            continue  # skip biases
        collected.append(
            ModelWeightStats(
                layer_name=param_name,
                weight_norm=torch.norm(tensor).item(),
                weight_mean=tensor.mean().item(),
                weight_std=tensor.std().item(),
                weight_min=tensor.min().item(),
                weight_max=tensor.max().item(),
                dead_neuron_pct=0.0,
                has_nan=bool(torch.isnan(tensor).any().item()),
                has_inf=bool(torch.isinf(tensor).any().item()),
            )
        )
    return collected
|
|
|
|
def extract_model_modes(model: nn.Module) -> dict[str, str]:
    """Map each named submodule to "train" or "eval" per its training flag.

    The root module (empty name) is excluded.
    """
    return {
        name: ("train" if module.training else "eval")
        for name, module in model.named_modules()
        if name
    }
|
|