zjowowen's picture
init space
079c32c
raw
history blame
2.07 kB
import pytest
from easydict import EasyDict
from ding.framework import OnlineRLContext
from ding.framework.middleware.ckpt_handler import CkptSaver
import torch.nn as nn
import torch.optim as optim
import os
import shutil
from unittest.mock import Mock, patch
from ding.framework import task
from ding.policy.base_policy import Policy
class TheModelClass(nn.Module):
    """Parameter-free ``nn.Module`` used as the policy's model in the checkpoint test.

    ``state_dict`` is overridden to hand back a fixed placeholder string
    rather than a mapping of real tensors.
    """

    _PLACEHOLDER_STATE = 'fake_state_dict'

    def state_dict(self):
        """Return the placeholder string instead of the module's real state."""
        return self._PLACEHOLDER_STATE
class MockPolicy(Mock):
    """``Mock``-based stand-in for a ding policy in the checkpoint test.

    ``learn_mode`` is the model passed at construction. ``eval_mode`` exposes
    a ``state_dict`` callable that always returns an empty dict.
    """

    def __init__(self, model, **kwargs) -> None:
        # NOTE: ``model`` is forwarded positionally, so ``Mock`` uses it as its
        # ``spec``. Reading attributes the model lacks raises AttributeError,
        # and ``isinstance(policy, type(model))`` holds. Any extra keyword
        # arguments are accepted but ignored.
        super().__init__(model)
        self.learn_mode = model

    @property
    def eval_mode(self):
        """Build a fresh EasyDict whose ``state_dict()`` returns ``{}``."""
        eval_view = EasyDict({"state_dict": lambda: {}})
        return eval_view
@pytest.mark.unittest
def test_ckpt_saver():
    """Check that ``CkptSaver`` requests the expected checkpoint path at each stage.

    ``save_file`` inside ``ckpt_handler`` is patched with a recorder. As a
    result, no checkpoint is written and every requested path can be
    inspected. Three calls on one saver are checked, each expecting exactly
    one save:

    1. ``train_iter=1`` with the first ``eval_value`` -> ``eval.pth.tar``.
    2. ``train_iter=100`` (equal to ``train_freq``) with a lower
       ``eval_value`` -> ``iteration_100.pth.tar``.
    3. ``task.finish`` set -> ``final.pth.tar``.
    """
    exp_name = 'test_ckpt_saver_exp'
    ctx = OnlineRLContext()
    train_freq = 100
    model = TheModelClass()
    os.makedirs(exp_name, exist_ok=True)
    prefix = '{}/ckpt'.format(exp_name)

    saved_paths = []

    def mock_save_file(path, data, fs_type=None, use_lock=False):
        # Record rather than assert in place: an assert inside the callback
        # can never fail when the saver skips the save entirely.
        saved_paths.append(path)

    try:
        with patch("ding.policy.Policy", MockPolicy), task.start():
            policy = MockPolicy(model)
            with patch("ding.framework.middleware.ckpt_handler.save_file", mock_save_file):
                # Stage 1: first evaluation result.
                ctx.train_iter = 1
                ctx.eval_value = 9.4
                ckpt_saver = CkptSaver(policy, exp_name, train_freq)
                ckpt_saver(ctx)
                assert saved_paths == ["{}/eval.pth.tar".format(prefix)]

                # Stage 2: train_freq iterations reached, eval value did not improve.
                saved_paths.clear()
                ctx.train_iter = 100
                ctx.eval_value = 1
                ckpt_saver(ctx)
                assert saved_paths == ["{}/iteration_{}.pth.tar".format(prefix, ctx.train_iter)]

                # Stage 3: task marked as finished.
                saved_paths.clear()
                # NOTE(review): this mutates shared task state; whether leaving
                # task.start() resets it is not visible here — confirm.
                task.finish = True
                ckpt_saver(ctx)
                assert saved_paths == ["{}/final.pth.tar".format(prefix)]
    finally:
        # Clean up even when an assertion above fails. ignore_errors keeps a
        # cleanup problem from masking the original failure.
        shutil.rmtree(exp_name, ignore_errors=True)