|
|
|
|
|
import unittest |
|
import pytest |
|
import numpy as np |
|
from unittest.mock import Mock, patch |
|
from ding.framework import OnlineRLContext, OfflineRLContext |
|
from ding.framework import task, Parallel |
|
from ding.framework.middleware.functional import priority_calculator |
|
|
|
|
|
class MockPolicy(Mock): |
|
|
|
def priority_fun(self, data): |
|
return np.random.rand(len(data)) |
|
|
|
|
|
@pytest.mark.unittest |
|
def test_priority_calculator(): |
|
policy = MockPolicy() |
|
ctx = OnlineRLContext() |
|
ctx.trajectories = [ |
|
{ |
|
'obs': np.random.rand(2, 2), |
|
'next_obs': np.random.rand(2, 2), |
|
'reward': np.random.rand(1), |
|
'info': {} |
|
} for _ in range(10) |
|
] |
|
priority_calculator_middleware = priority_calculator(priority_calculation_fn=policy.priority_fun) |
|
priority_calculator_middleware(ctx) |
|
assert len(ctx.trajectories) == 10 |
|
assert all([isinstance(traj['priority'], float) for traj in ctx.trajectories]) |
|
|