| import numpy as np |
| import torch |
|
|
| from convgru_ensemble.lightning_model import RadarLightningModel |
|
|
|
|
def test_predict_handles_unpadded_inputs():
    """Run ``predict`` on a zero-filled (4, 8, 8) float32 array.

    The model is built with one input channel, one block, two forecast
    steps, a single ensemble member and ``noisy_decoder=False``. The
    returned forecast must have shape (1, 2, 8, 8) and contain only
    finite values.
    """
    model_kwargs = {
        "input_channels": 1,
        "num_blocks": 1,
        "forecast_steps": 2,
        "ensemble_size": 1,
        "noisy_decoder": False,
    }
    net = RadarLightningModel(**model_kwargs)

    # Input frames are passed as-is; no padding is applied by the test.
    history = np.zeros((4, 8, 8), dtype=np.float32)
    forecast = net.predict(history, forecast_steps=2, ensemble_size=1)

    expected_shape = (1, 2, 8, 8)
    assert forecast.shape == expected_shape
    assert np.isfinite(forecast).all()
|
|
|
|
def test_from_checkpoint_delegates_to_lightning_loader(monkeypatch):
    """Check that ``from_checkpoint`` forwards to ``load_from_checkpoint``.

    ``load_from_checkpoint`` is swapped for a stub that records the
    arguments it receives and returns a sentinel string. The test then
    verifies that the sentinel is returned unchanged, that the checkpoint
    path is forwarded, that ``map_location`` arrives as a CPU
    ``torch.device``, and that ``strict`` is ``True``.
    """
    seen = {}

    def _recording_loader(cls, checkpoint_path, map_location=None, strict=None, weights_only=None):
        # Parameter names mirror the keywords the caller may pass.
        seen.update(
            checkpoint_path=checkpoint_path,
            map_location=map_location,
            strict=strict,
            weights_only=weights_only,
        )
        return "loaded-model"

    monkeypatch.setattr(
        RadarLightningModel,
        "load_from_checkpoint",
        classmethod(_recording_loader),
    )

    result = RadarLightningModel.from_checkpoint("/tmp/model.ckpt", device="cpu")

    assert result == "loaded-model"
    assert seen["checkpoint_path"] == "/tmp/model.ckpt"

    forwarded_device = seen["map_location"]
    assert isinstance(forwarded_device, torch.device)
    assert forwarded_device.type == "cpu"

    assert seen["strict"] is True
|
|