import pytest
import random
import copy
import torch
import treetensor.torch as ttorch
from unittest.mock import Mock, patch
from ding.data.buffer import DequeBuffer
from ding.framework import OnlineRLContext, task
from ding.framework.middleware import trainer, multistep_trainer, OffPolicyLearner, HERLearner
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
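
# Unit tests for the trainer / learner middleware in ding.framework.middleware.
# Policies and the HER reward model are Mock-based, so no real network update
# happens; the tests only check context bookkeeping (train_iter, train_output).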
class MockPolicy(Mock):
    _device = 'cpu'

    # MockPolicy class for train mode
    def forward(self, train_data, **kwargs):
        res = {
            'total_loss': 0.1,
        }
        return res

class MultiStepMockPolicy(Mock):
    _device = 'cpu'

    # MockPolicy class for multi-step train mode
    def forward(self, train_data, **kwargs):
        res = [
            {
                'total_loss': 0.1,
            },
            {
                'total_loss': 1.0,
            },
        ]
        return res

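# Build one fake transition (obs, next_obs, reward, info) as a treetensor;
# this is the data layout the trainer middleware consumes in these tests.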
def get_mock_train_input():
    data = {'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2), 'reward': random.random(), 'info': {}}
    return ttorch.as_tensor(data)

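# test_trainer: with ctx.train_data unset the trainer middleware should be a
# no-op (train_iter stays 0); once data is provided, every call increments
# train_iter and stores the mocked loss dict in ctx.train_output.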
@pytest.mark.unittest
def test_trainer():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    ctx.train_data = None
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _ in range(10):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 0

    ctx.train_data = get_mock_train_input()
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _ in range(30):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 30
    assert ctx.train_output["total_loss"] == 0.1

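# test_multistep_trainer: a multi-step policy returns one loss dict per step,
# so multistep_trainer advances train_iter by the number of returned steps
# (30 calls x 2 steps = 60) and exposes the full output list on the context.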
@pytest.mark.unittest
def test_multistep_trainer():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    ctx.train_data = None
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _ in range(10):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 0

    ctx.train_data = get_mock_train_input()
    with patch("ding.policy.Policy", MultiStepMockPolicy):
        policy = MultiStepMockPolicy()
        for _ in range(30):
            multistep_trainer(policy, 10)(ctx)
    assert ctx.train_iter == 60
    assert ctx.train_output[0]["total_loss"] == 0.1
    assert ctx.train_output[1]["total_loss"] == 1.0

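# test_offpolicy_learner: OffPolicyLearner samples from the buffer and runs
# several train steps per call; with this CONFIG one call yields 4 outputs.
# The call is wrapped in task.start() so the middleware has an active task
# context.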
@pytest.mark.unittest
def test_offpolicy_learner():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()
    buffer = DequeBuffer(size=10)
    for _ in range(10):
        buffer.push(get_mock_train_input())
    with patch("ding.policy.Policy", MockPolicy):
        with task.start():
            policy = MockPolicy()
            learner = OffPolicyLearner(cfg, policy, buffer)
            learner(ctx)
    assert len(ctx.train_output) == 4

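# test_her_learner: HERLearner passes sampled episodes through the (mocked)
# HER reward model before training; each buffer entry is a two-transition
# episode, and one call again produces 4 train outputs.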
@pytest.mark.unittest
def test_her_learner():
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()
    buffer = DequeBuffer(size=10)
    for _ in range(10):
        buffer.push([get_mock_train_input(), get_mock_train_input()])
    with patch("ding.policy.Policy", MockPolicy), patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        with task.start():
            policy = MockPolicy()
            her_reward_model = MockHerRewardModel()
            learner = HERLearner(cfg, policy, buffer, her_reward_model)
            learner(ctx)
    assert len(ctx.train_output) == 4