import pytest
from easydict import EasyDict
from ding.framework import OnlineRLContext
from ding.framework.middleware.ckpt_handler import CkptSaver

import torch.nn as nn
import torch.optim as optim
import os
import shutil

from unittest.mock import Mock, patch
from ding.framework import task
from ding.policy.base_policy import Policy


class TheModelClass(nn.Module):
    """Stub model whose ``state_dict`` returns a fixed placeholder string."""

    def state_dict(self):
        """Return a constant placeholder instead of real parameters."""
        return 'fake_state_dict'


class MockPolicy(Mock):
    """Mock-based policy stand-in exposing ``learn_mode`` and ``eval_mode``."""

    def __init__(self, model, **kwargs) -> None:
        """Build the mock around ``model`` and expose it as ``learn_mode``.

        Arguments:
            - model: Object passed to ``Mock.__init__`` as its first positional \
                argument and stored as ``learn_mode``.
            - kwargs: Accepted for signature compatibility; not forwarded.
        """
        super().__init__(model)
        self.learn_mode = model

    @property
    def eval_mode(self):
        """Return an ``EasyDict`` whose ``state_dict`` callable yields ``{}``."""
        return EasyDict({"state_dict": lambda: {}})


@pytest.mark.unittest
def test_ckpt_saver():
    """Check the file paths ``CkptSaver`` hands to ``save_file``.

    ``save_file`` is patched in three phases; each replacement asserts the path
    it receives (eval checkpoint, iteration checkpoint, final checkpoint).
    The experiment directory is removed even when an assertion fails.
    """
    exp_name = 'test_ckpt_saver_exp'

    ctx = OnlineRLContext()

    train_freq = 100
    model = TheModelClass()

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(exp_name, exist_ok=True)

    prefix = '{}/ckpt'.format(exp_name)

    save_file_target = "ding.framework.middleware.ckpt_handler.save_file"

    try:
        with patch("ding.policy.Policy", MockPolicy), task.start():
            policy = MockPolicy(model)

            # Phase 1: train_iter=1, eval_value=9.4 -> any save must target eval.pth.tar.
            def expect_eval_path(path, data, fs_type=None, use_lock=False):
                assert path == "{}/eval.pth.tar".format(prefix)

            with patch(save_file_target, expect_eval_path):
                ctx.train_iter = 1
                ctx.eval_value = 9.4
                ckpt_saver = CkptSaver(policy, exp_name, train_freq)
                ckpt_saver(ctx)

            # Phase 2: train_iter=100, eval_value=1 -> any save must target the
            # iteration file. ctx.train_iter is read when the mock is called.
            def expect_iteration_path(path, data, fs_type=None, use_lock=False):
                assert path == "{}/iteration_{}.pth.tar".format(prefix, ctx.train_iter)

            with patch(save_file_target, expect_iteration_path):
                ctx.train_iter = 100
                ctx.eval_value = 1
                ckpt_saver(ctx)

            # Phase 3: task.finish set -> any save must target final.pth.tar.
            def expect_final_path(path, data, fs_type=None, use_lock=False):
                assert path == "{}/final.pth.tar".format(prefix)

            with patch(save_file_target, expect_final_path):
                task.finish = True
                ckpt_saver(ctx)
    finally:
        # Always clean up so a failing assertion does not leak the directory
        # into the working tree or into subsequent test runs.
        shutil.rmtree(exp_name)