File size: 991 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
# Unit tests for the priority_calculator middleware.
import unittest
import pytest
import numpy as np
from unittest.mock import Mock, patch
from ding.framework import OnlineRLContext, OfflineRLContext
from ding.framework import task, Parallel
from ding.framework.middleware.functional import priority_calculator
class MockPolicy(Mock):
    """Lightweight policy stand-in that only provides ``priority_fun``."""

    def priority_fun(self, data):
        """Produce one random priority per item in ``data``.

        Values come from ``np.random.rand`` (uniform over [0, 1)), so the result is a
        1-D array whose length equals ``len(data)``.
        """
        num_items = len(data)
        priorities = np.random.rand(num_items)
        return priorities
@pytest.mark.unittest
def test_priority_calculator():
    """Check that ``priority_calculator`` assigns each trajectory its computed priority.

    The policy's priority function is wrapped in a recorder so the test can compare
    what the middleware wrote into ``ctx.trajectories`` against what the priority
    function actually returned. Checking only the value's type would not catch
    misassigned or duplicated priorities.
    """
    policy = MockPolicy()
    ctx = OnlineRLContext()
    ctx.trajectories = [
        {
            'obs': np.random.rand(2, 2),
            'next_obs': np.random.rand(2, 2),
            'reward': np.random.rand(1),
            'info': {}
        } for _ in range(10)
    ]
    original_keys = [set(traj.keys()) for traj in ctx.trajectories]

    # MockPolicy.priority_fun returns random values, so record each call's input and
    # output to know which priorities the middleware should have written.
    calls = []

    def recording_priority_fn(data):
        priority = policy.priority_fun(data)
        calls.append((data, priority))
        return priority

    priority_calculator_middleware = priority_calculator(priority_calculation_fn=recording_priority_fn)
    priority_calculator_middleware(ctx)

    assert len(ctx.trajectories) == 10
    # The priority function should be invoked once, over all trajectories.
    assert len(calls) == 1
    data, priorities = calls[0]
    assert len(data) == len(ctx.trajectories)

    for traj, keys, expected in zip(ctx.trajectories, original_keys, priorities):
        assert isinstance(traj['priority'], float)
        # Each trajectory must receive the priority computed for its own position.
        assert traj['priority'] == expected
        # Only the 'priority' field should be added; existing fields stay in place.
        assert set(traj.keys()) == keys | {'priority'}
|