File size: 2,074 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import pytest
from easydict import EasyDict
from ding.framework import OnlineRLContext
from ding.framework.middleware.ckpt_handler import CkptSaver
import torch.nn as nn
import torch.optim as optim
import os
import shutil
from unittest.mock import Mock, patch
from ding.framework import task
from ding.policy.base_policy import Policy
class TheModelClass(nn.Module):
    """Minimal ``nn.Module`` stand-in whose state dict is a fixed placeholder."""

    def state_dict(self):
        """Return a constant placeholder string instead of real parameters."""
        placeholder = 'fake_state_dict'
        return placeholder
class MockPolicy(Mock):
    """``Mock``-based policy double exposing ``learn_mode`` and ``eval_mode``."""

    def __init__(self, model, **kwargs) -> None:
        """Initialise the mock and expose ``model`` as ``learn_mode``.

        Args:
            model: Object forwarded as the first positional argument of
                ``Mock.__init__`` (its ``spec`` parameter) and stored on the
                instance as ``learn_mode``.
            **kwargs: Accepted for call compatibility; not used.
        """
        super().__init__(model)
        self.learn_mode = model

    @property
    def eval_mode(self):
        """Return a new ``EasyDict`` whose ``state_dict`` entry yields ``{}``."""
        eval_stub = {"state_dict": lambda: {}}
        return EasyDict(eval_stub)
@pytest.mark.unittest
def test_ckpt_saver():
    """Check which checkpoint path ``CkptSaver`` hands to ``save_file``.

    ``save_file`` is patched inside ``ding.framework.middleware.ckpt_handler``
    with a stand-in that records every requested path and asserts it equals
    the expected one. Three consecutive calls of the same saver are checked:

    1. ``train_iter=1``, ``eval_value=9.4``  -> ``<prefix>/eval.pth.tar``
    2. ``train_iter=100``, ``eval_value=1``  -> ``<prefix>/iteration_100.pth.tar``
    3. ``task.finish = True``                -> ``<prefix>/final.pth.tar``

    Each phase also asserts that ``save_file`` was called at least once, so a
    saver that silently skips saving cannot pass. The experiment directory is
    removed in a ``finally`` clause so a failing assertion does not leave it
    behind.
    """
    exp_name = 'test_ckpt_saver_exp'
    prefix = '{}/ckpt'.format(exp_name)
    save_file_target = "ding.framework.middleware.ckpt_handler.save_file"
    ctx = OnlineRLContext()
    train_freq = 100
    model = TheModelClass()

    def make_save_checker(expected_path):
        """Build a ``save_file`` replacement plus the list of paths it received."""
        saved_paths = []

        def mock_save_file(path, data, fs_type=None, use_lock=False):
            saved_paths.append(path)
            assert path == expected_path

        return mock_save_file, saved_paths

    os.makedirs(exp_name, exist_ok=True)
    try:
        with patch("ding.policy.Policy", MockPolicy), task.start():
            policy = MockPolicy(model)

            # Phase 1: the eval checkpoint is the only expected save.
            mock_save_file, saved_paths = make_save_checker("{}/eval.pth.tar".format(prefix))
            with patch(save_file_target, mock_save_file):
                ctx.train_iter = 1
                ctx.eval_value = 9.4
                ckpt_saver = CkptSaver(policy, exp_name, train_freq)
                ckpt_saver(ctx)
            assert saved_paths, "save_file was not called for eval.pth.tar"

            # Phase 2: the per-iteration checkpoint is the only expected save.
            ctx.train_iter = 100
            ctx.eval_value = 1
            mock_save_file, saved_paths = make_save_checker(
                "{}/iteration_{}.pth.tar".format(prefix, ctx.train_iter)
            )
            with patch(save_file_target, mock_save_file):
                ckpt_saver(ctx)
            assert saved_paths, "save_file was not called for the iteration checkpoint"

            # Phase 3: with the task flagged as finished, the final checkpoint is expected.
            mock_save_file, saved_paths = make_save_checker("{}/final.pth.tar".format(prefix))
            with patch(save_file_target, mock_save_file):
                task.finish = True
                ckpt_saver(ctx)
            assert saved_paths, "save_file was not called for final.pth.tar"
    finally:
        # Always remove the experiment directory, even when an assertion fails.
        shutil.rmtree(exp_name)
|