File size: 4,823 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import io
import os
from unittest import mock
import numpy as np
import pytest
import tempfile
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
from mlagents.trainers.tests.mock_brain import (
create_mock_3dball_behavior_specs,
setup_test_behavior_specs,
)
from mlagents.trainers.demo_loader import (
load_demonstration,
demo_to_buffer,
get_demo_files,
write_delimited,
)
from mlagents.trainers.buffer import BufferKey
# Shared expected behavior spec, built by the mock_brain 3DBall helper.
# The load tests pass it to demo_to_buffer as the expected spec and expect
# no mismatch error for test.demo / test_demo_dir.
BEHAVIOR_SPEC = create_mock_3dball_behavior_specs()
def test_load_demo():
    """Load the single-file ``test.demo`` fixture and convert it to a buffer.

    Checks that:
    - the first observation spec's shape sums to 8;
    - one info/action pair is returned per recorded step;
    - the buffer built by ``demo_to_buffer`` holds one action fewer than the
      recorded step count, in either the continuous or the discrete action
      field.
    """
    demo_path = os.path.dirname(os.path.abspath(__file__)) + "/test.demo"

    spec, pairs, step_count = load_demonstration(demo_path)
    assert np.sum(spec.observation_specs[0].shape) == 8
    assert len(pairs) == step_count

    _, buffer = demo_to_buffer(demo_path, 1, BEHAVIOR_SPEC)
    expected_actions = step_count - 1
    # Keep the short-circuit: the discrete field is only consulted when the
    # continuous one does not match.
    assert (
        len(buffer[BufferKey.CONTINUOUS_ACTION]) == expected_actions
        or len(buffer[BufferKey.DISCRETE_ACTION]) == expected_actions
    )
def test_load_demo_dir():
    """Load every demo in the ``test_demo_dir`` fixture directory.

    Checks the same properties as the single-file test:
    - the first observation spec's shape sums to 8;
    - the pair count equals the reported step total;
    - the resulting buffer holds ``total - 1`` actions (continuous or
      discrete).
    """
    demo_dir = os.path.dirname(os.path.abspath(__file__)) + "/test_demo_dir"

    spec, pairs, step_count = load_demonstration(demo_dir)
    assert np.sum(spec.observation_specs[0].shape) == 8
    assert len(pairs) == step_count

    _, buffer = demo_to_buffer(demo_dir, 1, BEHAVIOR_SPEC)
    expected_actions = step_count - 1
    # Either action field may carry the data; the discrete one is only
    # checked if the continuous one does not match.
    assert (
        len(buffer[BufferKey.CONTINUOUS_ACTION]) == expected_actions
        or len(buffer[BufferKey.DISCRETE_ACTION]) == expected_actions
    )
def test_demo_mismatch():
    """``demo_to_buffer`` must raise RuntimeError when the expected spec differs.

    Each expected spec below differs from the recorded ``test.demo`` in one
    respect only (observation size, action size, action type, or number of
    observations). A passing case therefore shows that this particular
    mismatch is detected, rather than some other mismatch the spec happens
    to carry as well.
    """
    demo_path = os.path.dirname(os.path.abspath(__file__)) + "/test.demo"

    # Build every spec *outside* pytest.raises. If spec construction itself
    # raised RuntimeError inside the block, the assertion would pass without
    # demo_to_buffer ever being exercised.

    # Observation size mismatch: 9-dim vector obs vs. the demo's 8 (see
    # test_load_demo). Continuous 2-dim actions are assumed to match the
    # demo; this case is intended to differ only in obs size.
    mismatch_obs = setup_test_behavior_specs(
        False, False, vector_action_space=2, vector_obs_space=9
    )
    # Action size mismatch: 3 continuous actions. The obs size is kept at 8
    # so the action check is the only thing that can trigger the error.
    mismatch_act = setup_test_behavior_specs(
        False, False, vector_action_space=3, vector_obs_space=8
    )
    # Action type mismatch: discrete actions instead of continuous.
    mismatch_act_type = setup_test_behavior_specs(
        True, False, vector_action_space=[2], vector_obs_space=8
    )
    # Number-of-observations mismatch: the second flag requests an extra
    # observation. NOTE(review): assumed to be the visual-obs flag of
    # setup_test_behavior_specs — confirm in mock_brain.
    mismatch_obs_number = setup_test_behavior_specs(
        False, True, vector_action_space=2, vector_obs_space=8
    )

    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_obs)
    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_act)
    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_act_type)
    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_obs_number)
def test_edge_cases():
    """Exercise ``get_demo_files`` on missing, empty, invalid and valid inputs."""
    here = os.path.dirname(os.path.abspath(__file__))

    # Paths that do not exist at all, whether file-like or directory-like.
    for missing in ("nonexistent_file.demo", "nonexistent_directory"):
        with pytest.raises(FileNotFoundError):
            get_demo_files(os.path.join(here, missing))

    with tempfile.TemporaryDirectory() as workdir:
        # A directory with nothing in it is rejected.
        with pytest.raises(ValueError):
            get_demo_files(workdir)

        # A file without the ".demo" extension is rejected on its own...
        bad_path = os.path.join(workdir, "mydemo.notademo")
        with open(bad_path, "w") as handle:
            handle.write("I'm not a demo")
        with pytest.raises(ValueError):
            get_demo_files(bad_path)
        # ...and a directory containing only such a file is rejected too.
        with pytest.raises(ValueError):
            get_demo_files(workdir)

        # A ".demo" file is accepted; its contents are not inspected here.
        good_path = os.path.join(workdir, "mydemo.demo")
        with open(good_path, "w") as handle:
            handle.write("I'm a demo file")
        assert get_demo_files(good_path) == [good_path]
        # Scanning the directory reports only the ".demo" file.
        assert get_demo_files(workdir) == [good_path]
@mock.patch("mlagents.trainers.demo_loader.get_demo_files", return_value=["foo.demo"])
def test_unsupported_version_raises_error(mock_get_demo_files):
    """``load_demonstration`` must reject metadata with an unknown API version.

    ``get_demo_files`` is patched to report a single file. ``builtins.open``
    is patched to serve a delimited ``DemonstrationMetaProto`` whose
    ``api_version`` is 1337.
    """
    meta = DemonstrationMetaProto(api_version=1337)

    # Serialize the metadata with the loader's own writer, then hand those
    # bytes to a mocked open() so the loader reads them back.
    serialized = io.BytesIO()
    write_delimited(serialized, meta)
    fake_open = mock.mock_open(read_data=serialized.getvalue())

    with mock.patch("builtins.open", fake_open), pytest.raises(RuntimeError):
        load_demonstration("foo")
|