File size: 475 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import pytest
import torch
from ding.policy.mbpolicy.utils import q_evaluation


@pytest.mark.unittest
def test_q_evaluation():
    """Check that ``q_evaluation`` flattens trajectory inputs for the critic and restores them.

    The fake critic asserts that it receives flattened ``(T * B, O)`` observations and
    ``(T * B, A)`` actions, and its output depends on both. That lets the final value
    check catch three kinds of bug:

    - observations and actions flattened inconsistently (i.e. misaligned);
    - the critic output reshaped back to ``(T, B)`` in the wrong order;
    - inputs passed to the critic without being flattened.
    """
    T, B, O, A = 10, 20, 100, 30
    obss = torch.randn(T, B, O)
    actions = torch.randn(T, B, A)

    def fake_q_fn(obss, actions):
        # obss:    flatten_B * O   (flatten_B == T * B)
        # actions: flatten_B * A
        # return:  flatten_B
        # Enforce the flattened-input contract documented above.
        assert obss.shape == (T * B, O)
        assert actions.shape == (T * B, A)
        # Use both inputs so a misalignment between obs and action rows is detectable.
        return obss.sum(-1) + actions.sum(-1)

    q_value = q_evaluation(obss, actions, fake_q_fn)
    assert q_value.shape == (T, B)
    # Each entry must equal the fake Q of the (obs, action) pair at the same (t, b) position.
    expected = obss.sum(-1) + actions.sum(-1)
    assert torch.allclose(q_value, expected)