| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import pytest |
| import torch |
|
|
| from lerobot.optim.optimizers import ( |
| AdamConfig, |
| AdamWConfig, |
| MultiAdamConfig, |
| SGDConfig, |
| load_optimizer_state, |
| save_optimizer_state, |
| ) |
| from lerobot.utils.constants import ( |
| OPTIMIZER_PARAM_GROUPS, |
| OPTIMIZER_STATE, |
| ) |
|
|
|
|
@pytest.mark.parametrize(
    "config_cls, expected_class",
    [
        (AdamConfig, torch.optim.Adam),
        (AdamWConfig, torch.optim.AdamW),
        (SGDConfig, torch.optim.SGD),
        (MultiAdamConfig, dict),
    ],
)
def test_optimizer_build(config_cls, expected_class, model_params):
    """Each optimizer config builds the expected optimizer type with its configured lr.

    ``MultiAdamConfig.build`` takes a mapping of named parameter groups and returns a
    dict of optimizers, so it is exercised through a single ``"default"`` group whose
    optimizer must be a ``torch.optim.Adam``.
    """
    config = config_cls()

    if config_cls is not MultiAdamConfig:
        built = config.build(model_params)
        assert isinstance(built, expected_class)
        assert built.defaults["lr"] == config.lr
        return

    built = config.build({"default": model_params})
    assert isinstance(built, expected_class)
    default_optimizer = built["default"]
    assert isinstance(default_optimizer, torch.optim.Adam)
    assert default_optimizer.defaults["lr"] == config.lr
|
|
|
|
def test_save_optimizer_state(optimizer, tmp_path):
    """Saving a single optimizer writes both the state file and the param-groups file."""
    save_optimizer_state(optimizer, tmp_path)
    for filename in (OPTIMIZER_STATE, OPTIMIZER_PARAM_GROUPS):
        assert (tmp_path / filename).is_file()
|
|
|
|
def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
    """A freshly built Adam optimizer matches the saved one after its state is loaded."""
    save_optimizer_state(optimizer, tmp_path)

    fresh_optimizer = AdamConfig().build(model_params)
    restored_optimizer = load_optimizer_state(fresh_optimizer, tmp_path)

    torch.testing.assert_close(optimizer.state_dict(), restored_optimizer.state_dict())
|
|
|
|
@pytest.fixture
def base_params_dict():
    """Named parameter groups for the multi-optimizer tests.

    Each group holds a single randomly initialized ``torch.nn.Parameter``; the groups
    are created in actor, critic, temperature order.
    """
    group_shapes = {"actor": (10, 10), "critic": (5, 5), "temperature": (3, 3)}
    return {name: [torch.nn.Parameter(torch.randn(*shape))] for name, shape in group_shapes.items()}
|
|
|
|
@pytest.mark.parametrize(
    "config_params, expected_values",
    [
        # Only per-group learning rates are overridden; everything else is inherited.
        (
            {
                "lr": 1e-3,
                "weight_decay": 1e-4,
                "optimizer_groups": {
                    "actor": {"lr": 1e-4},
                    "critic": {"lr": 5e-4},
                    "temperature": {"lr": 2e-3},
                },
            },
            {
                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)},
            },
        ),
        # Per-group overrides of weight_decay and betas alongside lr.
        (
            {
                "lr": 1e-3,
                "weight_decay": 1e-4,
                "optimizer_groups": {
                    "actor": {"lr": 1e-4, "weight_decay": 1e-5},
                    "critic": {"lr": 5e-4, "weight_decay": 1e-6},
                    "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)},
                },
            },
            {
                "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)},
                "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)},
                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)},
            },
        ),
        # Per-group overrides of eps alongside lr.
        (
            {
                "lr": 1e-3,
                "weight_decay": 1e-4,
                "optimizer_groups": {
                    "actor": {"lr": 1e-4, "eps": 1e-6},
                    "critic": {"lr": 5e-4, "eps": 1e-7},
                    "temperature": {"lr": 2e-3, "eps": 1e-8},
                },
            },
            {
                "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6},
                "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7},
                "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8},
            },
        ),
    ],
)
def test_multi_adam_configuration(base_params_dict, config_params, expected_values):
    """MultiAdamConfig builds one Adam per group with per-group overrides applied.

    Each group's ``defaults`` must contain the expected hyperparameters.
    """
    optimizers = MultiAdamConfig(**config_params).build(base_params_dict)

    # Exactly one optimizer per expected group, and no extras.
    assert len(optimizers) == len(expected_values)
    assert set(optimizers.keys()) == set(expected_values.keys())

    # Every group is served by an Adam optimizer.
    assert all(isinstance(group_optimizer, torch.optim.Adam) for group_optimizer in optimizers.values())

    # Compare each expected hyperparameter against the built optimizer's defaults.
    for group_name, expected_hparams in expected_values.items():
        group_defaults = optimizers[group_name].defaults
        for hparam, expected in expected_hparams.items():
            assert group_defaults[hparam] == expected
|
|
|
|
@pytest.fixture
def multi_optimizers(base_params_dict):
    """Optimizers for the actor/critic/temperature groups, each with its own learning rate.

    The base lr of 1e-3 is overridden per group.
    """
    group_lrs = {"actor": 1e-4, "critic": 5e-4, "temperature": 2e-3}
    config = MultiAdamConfig(
        lr=1e-3,
        optimizer_groups={name: {"lr": lr} for name, lr in group_lrs.items()},
    )
    return config.build(base_params_dict)
|
|
|
|
def test_save_multi_optimizer_state(multi_optimizers, tmp_path):
    """Saving a dict of optimizers writes one sub-directory per group with both state files."""
    save_optimizer_state(multi_optimizers, tmp_path)

    for group_name in multi_optimizers:
        group_dir = tmp_path / group_name
        assert group_dir.is_dir()
        assert (group_dir / OPTIMIZER_STATE).is_file()
        assert (group_dir / OPTIMIZER_PARAM_GROUPS).is_file()
|
|
|
|
def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path):
    """Populated per-group optimizer state survives a save/load round trip."""
    # Take one optimization step per group so each optimizer has state to serialize.
    for group_name, group_params in base_params_dict.items():
        if group_name not in multi_optimizers:
            continue
        group_params[0].sum().backward()
        group_optimizer = multi_optimizers[group_name]
        group_optimizer.step()
        group_optimizer.zero_grad()

    save_optimizer_state(multi_optimizers, tmp_path)

    # Rebuild optimizers with the same configuration as the `multi_optimizers` fixture,
    # then restore the saved state into them.
    group_lrs = {"actor": 1e-4, "critic": 5e-4, "temperature": 2e-3}
    fresh_config = MultiAdamConfig(
        lr=1e-3,
        optimizer_groups={name: {"lr": lr} for name, lr in group_lrs.items()},
    )
    restored = load_optimizer_state(fresh_config.build(base_params_dict), tmp_path)

    for group_name, saved_optimizer in multi_optimizers.items():
        torch.testing.assert_close(saved_optimizer.state_dict(), restored[group_name].state_dict())
|
|
|
|
def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path):
    """Test saving and loading optimizer states even when the state is empty (no backward pass).

    No optimizer step is taken, so every group's ``state`` is empty. The round trip must
    cover exactly the saved groups and preserve their hyperparameters and param groups.
    It must also leave the loaded optimizers with empty state.
    """
    config = MultiAdamConfig(
        lr=1e-3,
        optimizer_groups={
            "actor": {"lr": 1e-4},
            "critic": {"lr": 5e-4},
            "temperature": {"lr": 2e-3},
        },
    )
    optimizers = config.build(base_params_dict)

    save_optimizer_state(optimizers, tmp_path)

    # Load into a second, independently built set of optimizers.
    new_optimizers = config.build(base_params_dict)
    loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path)

    # The loaded mapping must contain exactly the groups that were saved.
    assert set(loaded_optimizers) == set(optimizers)

    for name, optimizer in optimizers.items():
        loaded = loaded_optimizers[name]
        for hparam in ("lr", "weight_decay", "betas"):
            assert optimizer.defaults[hparam] == loaded.defaults[hparam]

        loaded_state_dict = loaded.state_dict()
        # Nothing was stepped, so loading must not introduce any per-parameter state.
        assert not loaded_state_dict["state"]
        torch.testing.assert_close(optimizer.state_dict()["param_groups"], loaded_state_dict["param_groups"])
|
|