# pytorch-training-debugger/tests/test_pytorch_engine.py
# omkarrr88
# Major fixes + gap fixes
# 4f58e42
"""Test real PyTorch model instantiation and fault injection."""
from __future__ import annotations
import torch
import torch.nn as nn
from ml_training_debugger.pytorch_engine import (
SimpleCNN,
create_model_and_inject_fault,
extract_gradient_stats,
extract_model_modes,
extract_weight_stats,
)
from ml_training_debugger.scenarios import sample_scenario
class TestSimpleCNN:
    """Sanity checks on a freshly constructed ``SimpleCNN``."""

    def test_is_nn_module(self):
        """The network is an instance of ``torch.nn.Module``."""
        assert isinstance(SimpleCNN(), nn.Module)

    def test_param_count(self):
        """The total parameter count lies strictly between 30K and 100K."""
        net = SimpleCNN()
        total = 0
        for param in net.parameters():
            total += param.numel()
        assert 30_000 < total < 100_000  # ~50K params

    def test_forward_pass(self):
        """A random (2, 3, 32, 32) input produces an output of shape (2, 10)."""
        net = SimpleCNN()
        batch = torch.randn(2, 3, 32, 32)
        output = net(batch)
        assert output.shape == (2, 10)
class TestFaultInjection:
    """Check model state and gradient stats after fault injection."""

    def test_task_001_exploding_gradients(self):
        """task_001 yields at least one layer whose mean gradient norm exceeds 1.0."""
        scenario = sample_scenario("task_001", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        assert len(stats) > 0
        # At least some layers should have elevated gradients
        any_high = any(s.mean_norm > 1.0 for s in stats)
        assert any_high

    def test_task_005_eval_mode(self):
        """task_005 returns a model whose ``training`` flag is False."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        assert not model.training  # model.eval() was called

    def test_task_005_gradients_not_exploding(self):
        """No layer in the task_005 gradient stats is flagged as exploding."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        # The loop below asserts nothing when stats is empty, so require at
        # least one entry to keep this test from passing vacuously.
        assert len(stats) > 0
        # ALL layers must have is_exploding=False
        for s in stats:
            assert not s.is_exploding, f"Layer {s.layer_name} should not be exploding"
class TestExtractGradientStats:
    """Count and type checks on the output of ``extract_gradient_stats``."""

    def test_returns_gradient_stats(self):
        """task_001 yields four entries, each with a float norm and a 5-item history."""
        scenario = sample_scenario("task_001", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        per_layer = extract_gradient_stats(faulty_model, scenario)
        assert len(per_layer) == 4  # conv1, conv2, conv3, fc
        for entry in per_layer:
            assert isinstance(entry.mean_norm, float)
            assert isinstance(entry.norm_history, list)
            assert len(entry.norm_history) == 5
class TestExtractWeightStats:
    """Presence and type checks on the output of ``extract_weight_stats``."""

    def test_returns_weight_stats(self):
        """task_001 yields a non-empty result with float norms and bool NaN flags."""
        scenario = sample_scenario("task_001", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        weight_entries = extract_weight_stats(faulty_model)
        assert len(weight_entries) > 0
        for entry in weight_entries:
            assert isinstance(entry.weight_norm, float)
            assert isinstance(entry.has_nan, bool)
class TestExtractModelModes:
    """Check that ``extract_model_modes`` reflects the model's train/eval state."""

    def test_train_mode(self):
        """After ``model.train()`` every reported mode equals ``"train"``."""
        model = SimpleCNN()
        model.train()
        modes = extract_model_modes(model)
        # all() is True for an empty iterable; require at least one entry so
        # the check below cannot pass without inspecting anything.
        assert len(modes) > 0, "extract_model_modes returned no entries"
        assert all(v == "train" for v in modes.values())

    def test_eval_mode(self):
        """After ``model.eval()`` every reported mode equals ``"eval"``."""
        model = SimpleCNN()
        model.eval()
        modes = extract_model_modes(model)
        # Same guard as above: an empty mapping would satisfy all() trivially.
        assert len(modes) > 0, "extract_model_modes returned no entries"
        assert all(v == "eval" for v in modes.values())
class TestTask005RedHerrings:
    """Test Task 5 red herring injection — conv1 near-vanishing, FC spike."""

    def test_conv1_near_vanishing_red_herring(self):
        """When the spike layer is not conv1, conv1 shows a near-vanishing gradient."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        # Fail with a readable message instead of a bare StopIteration when
        # the expected layer is absent from the stats.
        conv1 = next((s for s in stats if s.layer_name == "conv1"), None)
        assert conv1 is not None, "no gradient stats entry named 'conv1'"
        # NOTE(review): when the spike layer IS conv1 this test makes no
        # further assertion and is reported as a pass — confirm that is intended.
        if scenario.red_herring_spike_layer != "conv1":
            # conv1 should be near-vanishing (but not is_vanishing since 0.0003 > 1e-6)
            # NOTE(review): 0.0003 and 1e-6 come from the original comment; the
            # actual values live in pytorch_engine and are not visible here.
            assert conv1.mean_norm < 0.01
            assert not conv1.is_vanishing  # 0.0003 > 1e-6

    def test_fc_spike_not_exploding(self):
        """The red-herring spike layer has a positive norm but is_exploding=False.

        The layer is whichever one ``scenario.red_herring_spike_layer`` names.
        NOTE(review): the original docstring cited an exploding threshold of
        mean < 10.0; that threshold is not visible in this file.
        """
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        spike_name = scenario.red_herring_spike_layer
        # Explicit assertion gives a clearer failure than StopIteration.
        spike_layer = next((s for s in stats if s.layer_name == spike_name), None)
        assert spike_layer is not None, f"no gradient stats entry named {spike_name!r}"
        assert not spike_layer.is_exploding
        # Should have non-trivial norm from the spike
        assert spike_layer.mean_norm > 0

    def test_all_layers_not_exploding(self):
        """All layers is_exploding=False — this gates gradients_were_normal."""
        scenario = sample_scenario("task_005", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        stats = extract_gradient_stats(model, scenario)
        # The loop asserts nothing on empty stats; require entries so the
        # test cannot pass vacuously.
        assert len(stats) > 0
        for s in stats:
            assert not s.is_exploding, f"{s.layer_name} should not be exploding"
class TestVanishingGradientInjection:
    """Checks on the gradient stats and model mode produced for task_002."""

    def test_task_002_vanishing(self):
        """At least one layer in the task_002 stats is flagged as vanishing."""
        scenario = sample_scenario("task_002", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        per_layer = extract_gradient_stats(faulty_model, scenario)
        # Deeper layers should have vanishing gradients
        found_vanishing = any(entry.is_vanishing for entry in per_layer)
        assert found_vanishing

    def test_task_002_model_in_train_mode(self):
        """The task_002 model is returned with its ``training`` flag set."""
        scenario = sample_scenario("task_002", seed=42)
        faulty_model, _ = create_model_and_inject_fault(scenario)
        assert faulty_model.training
class TestCodeBugFaultInjection:
    """Test code bug fault injection — model should be normal."""

    def test_task_006_model_trains_normally(self):
        """task_006 leaves the model in train mode with no layer flagged as exploding."""
        scenario = sample_scenario("task_006", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        assert model.training  # Should be in train mode
        stats = extract_gradient_stats(model, scenario)
        # not any(...) is True for an empty sequence; require entries so the
        # check below is meaningful.
        assert len(stats) > 0
        # Only the exploding flag is asserted here.
        # NOTE(review): the original comment also mentioned "vanishing", which
        # is not checked — confirm whether it should be.
        assert not any(s.is_exploding for s in stats)
class TestDataLeakageFaultInjection:
    """Test data leakage scenario — model should be normal."""

    def test_task_003_normal_model(self):
        """task_003 leaves the model in train mode with no layer flagged as exploding."""
        scenario = sample_scenario("task_003", seed=42)
        model, _ = create_model_and_inject_fault(scenario)
        assert model.training
        stats = extract_gradient_stats(model, scenario)
        # not any(...) is True for an empty sequence; require entries so the
        # check below cannot pass without inspecting any layer.
        assert len(stats) > 0
        assert not any(s.is_exploding for s in stats)