# AnnaMats's picture
# Second Push
# 05c9ac2
import pytest
from unittest import mock
import torch # noqa I201
from mlagents.torch_utils import set_torch_config, default_device
from mlagents.trainers.settings import TorchSettings
@pytest.mark.parametrize(
    "device_str, expected_type, expected_index, expected_tensor_type",
    [
        ("cpu", "cpu", None, torch.FloatTensor),
        ("cuda", "cuda", None, torch.cuda.FloatTensor),
        ("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
        ("opengl", "opengl", None, torch.FloatTensor),
    ],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
    mock_set_default_tensor_type,
    device_str,
    expected_type,
    expected_index,
    expected_tensor_type,
):
    """Check that ``set_torch_config`` applies the requested device string.

    For each parametrized case, the test configures torch through
    ``TorchSettings(device=device_str)`` and verifies that:

    * ``default_device()`` reports the expected device type and index, and
    * ``torch.set_default_tensor_type`` was called exactly once, with the
      expected tensor class.

    ``torch.set_default_tensor_type`` is patched for the duration of the test
    and the patch is injected as ``mock_set_default_tensor_type``.

    :param mock_set_default_tensor_type: Mock standing in for
        ``torch.set_default_tensor_type`` (supplied by ``mock.patch.object``).
    :param device_str: Device string handed to ``TorchSettings``.
    :param expected_type: Expected ``type`` attribute of the default device.
    :param expected_index: Expected ``index`` attribute of the default device,
        or ``None`` when the device string carries no index.
    :param expected_tensor_type: Tensor class expected to be passed to
        ``torch.set_default_tensor_type``.
    """
    try:
        set_torch_config(TorchSettings(device=device_str))

        device = default_device()
        assert device.type == expected_type
        # A single equality check covers both the "no index" (None) case and
        # the explicit-index case.
        assert device.index == expected_index

        mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
    finally:
        # Restore the defaults, whether or not the assertions above passed,
        # so the configuration does not leak into other tests.
        set_torch_config(TorchSettings(device=None))