|
import uuid |
|
import pytest |
|
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage |
|
from mlagents_envs.side_channel.side_channel_manager import SideChannelManager |
|
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel |
|
from mlagents_envs.side_channel.raw_bytes_channel import RawBytesChannel |
|
from mlagents_envs.side_channel.engine_configuration_channel import ( |
|
EngineConfigurationChannel, |
|
EngineConfig, |
|
) |
|
from mlagents_envs.side_channel.environment_parameters_channel import ( |
|
EnvironmentParametersChannel, |
|
) |
|
from mlagents_envs.side_channel.stats_side_channel import ( |
|
StatsSideChannel, |
|
StatsAggregationMethod, |
|
) |
|
from mlagents_envs.exception import ( |
|
UnitySideChannelException, |
|
UnityCommunicationException, |
|
) |
|
|
|
|
|
class IntChannel(SideChannel): |
|
def __init__(self): |
|
self.list_int = [] |
|
super().__init__(uuid.UUID("a85ba5c0-4f87-11ea-a517-784f4387d1f7")) |
|
|
|
def on_message_received(self, msg: IncomingMessage) -> None: |
|
val = msg.read_int32() |
|
self.list_int += [val] |
|
|
|
def send_int(self, value): |
|
msg = OutgoingMessage() |
|
msg.write_int32(value) |
|
super().queue_message_to_send(msg) |
|
|
|
|
|
def test_int_channel(): |
|
sender = IntChannel() |
|
receiver = IntChannel() |
|
sender.send_int(5) |
|
sender.send_int(6) |
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
assert receiver.list_int[0] == 5 |
|
assert receiver.list_int[1] == 6 |
|
|
|
|
|
def test_float_properties(): |
|
sender = FloatPropertiesChannel() |
|
receiver = FloatPropertiesChannel() |
|
|
|
sender.set_property("prop1", 1.0) |
|
|
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
|
|
val = receiver.get_property("prop1") |
|
assert val == 1.0 |
|
val = receiver.get_property("prop2") |
|
assert val is None |
|
sender.set_property("prop2", 2.0) |
|
|
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
|
|
val = receiver.get_property("prop1") |
|
assert val == 1.0 |
|
val = receiver.get_property("prop2") |
|
assert val == 2.0 |
|
assert len(receiver.list_properties()) == 2 |
|
assert "prop1" in receiver.list_properties() |
|
assert "prop2" in receiver.list_properties() |
|
val = sender.get_property("prop1") |
|
assert val == 1.0 |
|
|
|
assert receiver.get_property_dict_copy() == {"prop1": 1.0, "prop2": 2.0} |
|
assert receiver.get_property_dict_copy() == sender.get_property_dict_copy() |
|
|
|
|
|
def test_raw_bytes(): |
|
guid = uuid.uuid4() |
|
sender = RawBytesChannel(guid) |
|
receiver = RawBytesChannel(guid) |
|
|
|
sender.send_raw_data(b"foo") |
|
sender.send_raw_data(b"bar") |
|
|
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
|
|
messages = receiver.get_and_clear_received_messages() |
|
assert len(messages) == 2 |
|
assert messages[0].decode("ascii") == "foo" |
|
assert messages[1].decode("ascii") == "bar" |
|
|
|
messages = receiver.get_and_clear_received_messages() |
|
assert len(messages) == 0 |
|
|
|
|
|
def test_message_bool(): |
|
vals = [True, False] |
|
msg_out = OutgoingMessage() |
|
for v in vals: |
|
msg_out.write_bool(v) |
|
|
|
msg_in = IncomingMessage(msg_out.buffer) |
|
read_vals = [] |
|
for _ in range(len(vals)): |
|
read_vals.append(msg_in.read_bool()) |
|
assert vals == read_vals |
|
|
|
|
|
assert msg_in.read_bool() is False |
|
assert msg_in.read_bool(default_value=True) is True |
|
|
|
|
|
def test_message_int32(): |
|
val = 1337 |
|
msg_out = OutgoingMessage() |
|
msg_out.write_int32(val) |
|
|
|
msg_in = IncomingMessage(msg_out.buffer) |
|
read_val = msg_in.read_int32() |
|
assert val == read_val |
|
|
|
|
|
assert 0 == msg_in.read_int32() |
|
assert val == msg_in.read_int32(default_value=val) |
|
|
|
|
|
def test_message_float32(): |
|
val = 42.0 |
|
msg_out = OutgoingMessage() |
|
msg_out.write_float32(val) |
|
|
|
msg_in = IncomingMessage(msg_out.buffer) |
|
read_val = msg_in.read_float32() |
|
|
|
assert val == read_val |
|
|
|
|
|
assert 0.0 == msg_in.read_float32() |
|
assert val == msg_in.read_float32(default_value=val) |
|
|
|
|
|
def test_message_string(): |
|
val = "mlagents!" |
|
msg_out = OutgoingMessage() |
|
msg_out.write_string(val) |
|
|
|
msg_in = IncomingMessage(msg_out.buffer) |
|
read_val = msg_in.read_string() |
|
assert val == read_val |
|
|
|
|
|
assert "" == msg_in.read_string() |
|
assert val == msg_in.read_string(default_value=val) |
|
|
|
|
|
def test_message_float_list(): |
|
val = [1.0, 3.0, 9.0] |
|
msg_out = OutgoingMessage() |
|
msg_out.write_float32_list(val) |
|
|
|
msg_in = IncomingMessage(msg_out.buffer) |
|
read_val = msg_in.read_float32_list() |
|
|
|
assert val == read_val |
|
|
|
|
|
assert [] == msg_in.read_float32_list() |
|
assert val == msg_in.read_float32_list(default_value=val) |
|
|
|
|
|
def test_engine_configuration(): |
|
sender = EngineConfigurationChannel() |
|
|
|
receiver = RawBytesChannel(sender.channel_id) |
|
|
|
config = EngineConfig.default_config() |
|
sender.set_configuration(config) |
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
|
|
received_data = receiver.get_and_clear_received_messages() |
|
assert len(received_data) == 5 |
|
|
|
sent_time_scale = 4.5 |
|
sender.set_configuration_parameters(time_scale=sent_time_scale) |
|
|
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
|
|
message = IncomingMessage(receiver.get_and_clear_received_messages()[0]) |
|
message.read_int32() |
|
time_scale = message.read_float32() |
|
assert time_scale == sent_time_scale |
|
|
|
with pytest.raises(UnitySideChannelException): |
|
sender.set_configuration_parameters(width=None, height=42) |
|
|
|
with pytest.raises(UnityCommunicationException): |
|
|
|
sender.set_configuration_parameters(time_scale=sent_time_scale) |
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([sender]).process_side_channel_message(data) |
|
|
|
|
|
def test_environment_parameters(): |
|
sender = EnvironmentParametersChannel() |
|
|
|
receiver = RawBytesChannel(sender.channel_id) |
|
|
|
sender.set_float_parameter("param-1", 0.1) |
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
|
|
message = IncomingMessage(receiver.get_and_clear_received_messages()[0]) |
|
key = message.read_string() |
|
dtype = message.read_int32() |
|
value = message.read_float32() |
|
assert key == "param-1" |
|
assert dtype == EnvironmentParametersChannel.EnvironmentDataTypes.FLOAT |
|
assert value - 0.1 < 1e-8 |
|
|
|
sender.set_float_parameter("param-1", 0.1) |
|
sender.set_float_parameter("param-2", 0.1) |
|
sender.set_float_parameter("param-3", 0.1) |
|
|
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([receiver]).process_side_channel_message(data) |
|
|
|
assert len(receiver.get_and_clear_received_messages()) == 3 |
|
|
|
with pytest.raises(UnityCommunicationException): |
|
|
|
sender.set_float_parameter("param-1", 0.1) |
|
data = SideChannelManager([sender]).generate_side_channel_messages() |
|
SideChannelManager([sender]).process_side_channel_message(data) |
|
|
|
|
|
def test_stats_channel(): |
|
receiver = StatsSideChannel() |
|
message = OutgoingMessage() |
|
message.write_string("stats-1") |
|
message.write_float32(42.0) |
|
message.write_int32(1) |
|
|
|
receiver.on_message_received(IncomingMessage(message.buffer)) |
|
|
|
stats = receiver.get_and_reset_stats() |
|
|
|
assert len(stats) == 1 |
|
val, method = stats["stats-1"][0] |
|
assert val - 42.0 < 1e-8 |
|
assert method == StatsAggregationMethod.MOST_RECENT |
|
|