|
import pytest |
|
import torch |
|
import copy |
|
from unittest.mock import patch |
|
from ding.framework import OnlineRLContext, task |
|
from ding.framework.middleware import TransitionList, inferencer, rolloutor |
|
from ding.framework.middleware import StepCollector, EpisodeCollector |
|
from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG |
|
|
|
|
|
@pytest.mark.unittest
def test_inferencer():
    """
    Run the ``inferencer`` middleware once with the mock policy and mock env,
    then verify what it leaves in ``ctx.inference_output``.

    The output must be a dict whose entries ``0`` and ``1`` each hold an
    ``{'action': tensor}`` mapping with the values produced by the mocks.
    """
    ctx = OnlineRLContext()

    # Swap the real base classes for the test doubles while the middleware runs.
    policy_patch = patch("ding.policy.Policy", MockPolicy)
    env_manager_patch = patch("ding.envs.BaseEnvManagerV2", MockEnv)
    with policy_patch, env_manager_patch:
        mock_policy, mock_env = MockPolicy(), MockEnv()
        # First positional argument is presumably the seed -- confirm against inferencer().
        run_inference = inferencer(0, mock_policy, mock_env)
        run_inference(ctx)

    assert isinstance(ctx.inference_output, dict)

    # Keys 0 and 1 are presumably env ids; the expected values come from the mocks.
    expected_actions = {0: torch.Tensor([0.]), 1: torch.Tensor([4.])}
    for key, action in expected_actions.items():
        assert ctx.inference_output[key] == {'action': action}
|
|
|
|
|
@pytest.mark.unittest
def test_rolloutor():
    """
    Alternate ``inferencer`` and ``rolloutor`` for ten rounds on the mock
    policy/env and verify the episode and step counters accumulated on ``ctx``.
    """
    ctx = OnlineRLContext()
    transitions = TransitionList(2)
    n_rounds = 10

    policy_patch = patch("ding.policy.Policy", MockPolicy)
    env_manager_patch = patch("ding.envs.BaseEnvManagerV2", MockEnv)
    with policy_patch, env_manager_patch:
        mock_policy = MockPolicy()
        mock_env = MockEnv()
        for _ in range(n_rounds):
            # Both middleware objects are deliberately rebuilt on every round
            # (not hoisted) so the call sequence stays exactly as written.
            infer_step = inferencer(0, mock_policy, mock_env)
            infer_step(ctx)
            rollout_step = rolloutor(mock_policy, mock_env, transitions)
            rollout_step(ctx)

    # 20 == 10 rounds x 2; presumably the mock env manager runs 2 envs
    # (matching TransitionList(2)) -- confirm against MockEnv.
    expected_total = 20
    assert ctx.env_episode == expected_total
    assert ctx.env_step == expected_total
|
|
|
|
|
@pytest.mark.unittest
def test_step_collector():
    """
    Drive ``StepCollector`` twice on the mock policy/env -- once with its
    default constructor arguments and once with ``random_collect_size=8`` --
    and verify the trajectories written to ``ctx`` after each pass.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    # Pass 1: default construction. Pass 2: with random_collect_size.
    # NOTE(review): the same ctx object is shared by both passes, so pass 2
    # starts from whatever state pass 1 left on it -- confirm this is intended.
    for collector_kwargs in ({}, {"random_collect_size": 8}):
        policy_patch = patch("ding.policy.Policy", MockPolicy)
        env_manager_patch = patch("ding.envs.BaseEnvManagerV2", MockEnv)
        with policy_patch, env_manager_patch, task.start():
            mock_policy = MockPolicy()
            mock_env = MockEnv()
            step_collector = StepCollector(cfg, mock_policy, mock_env, **collector_kwargs)
            step_collector(ctx)

        assert len(ctx.trajectories) == 16
        assert ctx.trajectory_end_idx == [7, 15]
|
|
|
|
|
@pytest.mark.unittest
def test_episode_collector():
    """
    Drive ``EpisodeCollector`` twice on the mock policy/env -- once with its
    default constructor arguments and once with ``random_collect_size=8`` --
    and verify the number of episodes written to ``ctx`` after each pass.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    # Pass 1: default construction. Pass 2: with random_collect_size.
    # NOTE(review): the same ctx object is shared by both passes, so pass 2
    # starts from whatever state pass 1 left on it -- confirm this is intended.
    for collector_kwargs in ({}, {"random_collect_size": 8}):
        policy_patch = patch("ding.policy.Policy", MockPolicy)
        env_manager_patch = patch("ding.envs.BaseEnvManagerV2", MockEnv)
        with policy_patch, env_manager_patch, task.start():
            mock_policy = MockPolicy()
            mock_env = MockEnv()
            episode_collector = EpisodeCollector(cfg, mock_policy, mock_env, **collector_kwargs)
            episode_collector(ctx)

        assert len(ctx.episodes) == 16
|
|