# AnnaMats's picture
# Second Push
# 05c9ac2
import os
import unittest
import json
from enum import Enum
import time
from mlagents.trainers.training_status import (
StatusType,
StatusMetaData,
GlobalTrainingStatus,
)
from mlagents.trainers.policy.checkpoint_manager import (
ModelCheckpointManager,
ModelCheckpoint,
)
def test_globaltrainingstatus(tmpdir):
    """Round-trip one parameter through GlobalTrainingStatus save/load.

    Stores a lesson number under "Category1", saves the state to a JSON
    file inside ``tmpdir``, inspects the written JSON, reloads it, and
    finally checks that lookups for an unknown category or an unknown
    status key both come back as ``None``.
    """

    class FakeStatusType(Enum):
        # A status key that GlobalTrainingStatus has never been given.
        NOTAREALKEY = "notarealkey"

    category = "Category1"
    lesson_key = StatusType.LESSON_NUM
    status_file = os.path.join(tmpdir, "test.json")

    GlobalTrainingStatus.set_parameter_state(category, lesson_key, 3)
    GlobalTrainingStatus.save_state(status_file)

    # Inspect the file on disk directly rather than going through the API.
    with open(status_file) as handle:
        saved = json.load(handle)

    assert category in saved
    assert lesson_key.value in saved[category]
    assert saved[category][lesson_key.value] == 3
    assert "metadata" in saved

    GlobalTrainingStatus.load_state(status_file)
    reloaded = GlobalTrainingStatus.get_parameter_state(category, lesson_key)
    assert reloaded == 3

    # Test unknown categories and status types (keys)
    missing_category = GlobalTrainingStatus.get_parameter_state(
        "Category3", lesson_key
    )
    missing_key = GlobalTrainingStatus.get_parameter_state(
        category, FakeStatusType.NOTAREALKEY
    )
    assert missing_category is None
    assert missing_key is None
def test_model_management(tmpdir):
    """Exercise checkpoint bookkeeping through ModelCheckpointManager.

    Seeds three checkpoint records for one behavior in
    GlobalTrainingStatus, adds two more through ``add_checkpoint`` while
    passing 4 as the third argument, registers a final checkpoint, and
    verifies that the tracked checkpoint count stays at 4 and that both
    the checkpoint list and the final checkpoint are recorded for the
    behavior.
    """
    results_path = os.path.join(tmpdir, "results")
    brain_name = "Mock_brain"
    final_model_path = os.path.join(results_path, brain_name)

    # NOTE: the "auxillary_file_paths" spelling is a runtime key and is
    # kept exactly as the rest of the project spells it.
    test_checkpoint_list = [
        {
            "steps": 1,
            "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
            "reward": 1.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 2,
            "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"),
            "reward": 1.912,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 3,
            "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"),
            "reward": 2.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
    ]
    GlobalTrainingStatus.set_parameter_state(
        brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
    )

    # Fourth checkpoint: the tracked list grows from 3 to 4 entries.
    new_checkpoint_4 = ModelCheckpoint(
        4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # Fifth checkpoint: the count is expected to remain 4 (presumably the
    # third argument is the retention limit -- confirm against
    # ModelCheckpointManager.add_checkpoint).
    new_checkpoint_5 = ModelCheckpoint(
        5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    final_model_path = f"{final_model_path}.nn"
    final_model_time = time.time()
    current_step = 6
    final_model = ModelCheckpoint(
        current_step, final_model_path, 3.294, final_model_time
    )
    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
    # Tracking the final model must not change the regular checkpoint list.
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    check_checkpoints = GlobalTrainingStatus.get_parameter_state(
        brain_name, StatusType.CHECKPOINTS
    )
    assert check_checkpoints is not None

    # Look the final checkpoint up under this behavior's category, the same
    # way the checkpoint list is addressed. Indexing saved_state with the
    # status key alone would not inspect the entry registered for
    # brain_name, so the assertion could pass without checking anything.
    final_model_state = GlobalTrainingStatus.get_parameter_state(
        brain_name, StatusType.FINAL_CHECKPOINT
    )
    assert final_model_state is not None
class StatsMetaDataTest(unittest.TestCase):
    """Checks the log output produced by StatusMetaData.check_compatibility."""

    def test_metadata_compare(self):
        """Two mismatched metadata comparisons should yield two log records.

        A default StatusMetaData is compared against one built with
        ``mlagents_version="test"`` and one built with
        ``torch_version="test"``; records logged on "mlagents.trainers"
        at WARNING level or above are captured and counted.
        """
        # Each override sets a single field to the placeholder value "test".
        overrides = (
            {"mlagents_version": "test"},
            {"torch_version": "test"},
        )
        with self.assertLogs("mlagents.trainers", level="WARNING") as captured:
            baseline = StatusMetaData()
            for fields in overrides:
                baseline.check_compatibility(StatusMetaData(**fields))
        # Assert that 2 warnings have been thrown
        assert len(captured.output) == 2