# NOTE: the lines below are Hugging Face file-viewer chrome captured with the
# file (uploader, commit, raw/history/blame links, size) — not Python code.
# Kept as a comment so the module remains importable.
# zjowowen's picture | init space | commit 079c32c | raw / history / blame | 2.59 kB
import pytest
import torch
from ding.framework import OnlineRLContext
from ding.data.buffer import DequeBuffer
from typing import Any
import numpy as np
import copy
from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer
from unittest.mock import Mock, patch
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
# Shared fixture: 20 transition dicts of random 2x2 tensors, deep-copied into
# ctx.train_data by test_reward_estimator and checked inside MockRewardModel.estimate.
DATA = [{'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2)} for _ in range(20)]
class MockRewardModel(Mock):
    """Stand-in reward model whose ``estimate`` checks it received the full fixture.

    Subclasses :class:`unittest.mock.Mock` so any other attribute/method access
    made by the code under test is absorbed silently; only ``estimate`` is real.
    """

    def estimate(self, data: list) -> Any:
        # The enhancer must forward every transition of the DATA fixture untouched.
        expected = DATA
        assert len(data) == len(expected)
        assert torch.equal(data[0]['obs'], expected[0]['obs'])
@pytest.mark.unittest
def test_reward_estimator():
    """reward_estimator should pass ctx.train_data to the reward model's estimate()."""
    context = OnlineRLContext()
    context.train_data = copy.deepcopy(DATA)
    # Patch HerRewardModel so building the middleware does not touch the real HER model.
    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        estimator = reward_estimator(cfg=None, reward_model=MockRewardModel())
        estimator(context)
@pytest.mark.unittest
def test_her_data_enhancer():
    """her_data_enhancer should fill ctx.train_data with HER-processed transitions.

    Two branches are exercised:
    1. the reward model provides ``episode_size`` -> buffer sized from it;
    2. ``episode_size is None`` -> sample count falls back to
       ``cfg.policy.learn.batch_size``.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()
    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        mock_her_reward_model = MockHerRewardModel()
        buffer_ = DequeBuffer(mock_her_reward_model.episode_size)
        # Episodes of random length (1, 4 or 5 transitions), each transition a
        # minimal dict of 6 keys: action/collect_train_iter/done/next_obs/obs/reward.
        train_data = [
            [
                {
                    'action': torch.randint(low=0, high=5, size=(1, )),
                    'collect_train_iter': torch.tensor([0]),
                    'done': torch.tensor(False),
                    'next_obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'reward': torch.randint(low=0, high=2, size=(1, ), dtype=torch.float32),
                } for _ in range(np.random.choice([1, 4, 5], size=1)[0])
            ] for _ in range(mock_her_reward_model.episode_size)
        ]
        for episode in train_data:
            buffer_.push(episode)

        # Branch 1: episode_size taken from the reward model.
        her_data_enhancer(cfg=cfg, buffer_=buffer_, her_reward_model=mock_her_reward_model)(ctx)
        assert len(ctx.train_data) == mock_her_reward_model.episode_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6  # one entry per transition key

        # Branch 2: episode_size is None -> fall back to cfg.policy.learn.batch_size.
        buffer_ = DequeBuffer(cfg.policy.learn.batch_size)
        for episode in train_data:
            buffer_.push(episode)
        mock_her_reward_model.episode_size = None
        # BUGFIX: the original passed a *fresh* MockHerRewardModel() here, so the
        # episode_size = None set on the local instance never reached the enhancer
        # and this branch was never actually tested. Pass the configured instance.
        her_data_enhancer(cfg=cfg, buffer_=buffer_, her_reward_model=mock_her_reward_model)(ctx)
        assert len(ctx.train_data) == cfg.policy.learn.batch_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6