File size: 991 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Unit test for the priority_calculator middleware.

import unittest
import pytest
import numpy as np
from unittest.mock import Mock, patch
from ding.framework import OnlineRLContext, OfflineRLContext
from ding.framework import task, Parallel
from ding.framework.middleware.functional import priority_calculator


class MockPolicy(Mock):
    """Mock-based stand-in for a policy that supplies a priority function."""

    def priority_fun(self, data):
        """Return one random priority per element of ``data``.

        ``data`` only needs to support ``len()``; its contents are never read.
        The result is a 1-D NumPy array of ``len(data)`` samples drawn by
        ``np.random.rand``.
        """
        sample_count = len(data)
        return np.random.rand(sample_count)


@pytest.mark.unittest
def test_priority_calculator():
    """Run priority_calculator on ten random trajectories and check the result.

    The middleware is built from ``MockPolicy.priority_fun`` and applied to an
    ``OnlineRLContext``. Afterwards the context must still hold ten
    trajectories, and each one must carry a ``'priority'`` entry that is an
    instance of ``float``.
    """
    num_trajectories = 10

    def _random_transition():
        # Fields are generated in a fixed order (obs, next_obs, reward) so the
        # random number generator is consumed in the same sequence every time.
        return {
            'obs': np.random.rand(2, 2),
            'next_obs': np.random.rand(2, 2),
            'reward': np.random.rand(1),
            'info': {},
        }

    policy = MockPolicy()
    ctx = OnlineRLContext()
    ctx.trajectories = [_random_transition() for _ in range(num_trajectories)]

    middleware = priority_calculator(priority_calculation_fn=policy.priority_fun)
    middleware(ctx)

    assert len(ctx.trajectories) == num_trajectories
    for traj in ctx.trajectories:
        assert isinstance(traj['priority'], float)