import pytest
import numpy as np
import torch
from itertools import product
from easydict import EasyDict

from ding.model import VAC, DREAMERVAC, ConvEncoder
from ding.torch_utils import is_differentiable

# Hybrid action space: a discrete action type plus continuous action arguments.
hybrid_shape = EasyDict({'action_args_shape': (3, ), 'action_type_shape': 4})

B = 4  # batch size
obs_shape = [4, (8, ), (4, 64, 64)]
act_args = [
    [6, 'discrete'],
    [(3, ), 'continuous'],
    [[2, 3, 6], 'discrete'],  # multi-discrete, produces multi-head logits
    [hybrid_shape, 'hybrid'],
]
args = list(product(*[obs_shape, act_args, [False, True]]))


def output_check(model, outputs, action_shape):
    # Multi-head logits arrive as a sequence of tensors, one per action dimension.
    if isinstance(action_shape, (tuple, list)):
        loss = sum([t.sum() for t in outputs])
    # Scalar action shapes and hybrid (dict) shapes yield a single tensor.
    elif np.isscalar(action_shape) or isinstance(action_shape, dict):
        loss = outputs.sum()
    is_differentiable(loss, model)


def model_check(model, inputs):
    # Joint actor-critic forward pass: both heads must be differentiable.
    outputs = model(inputs, mode='compute_actor_critic')
    value, logit = outputs['value'], outputs['logit']
    if model.action_space == 'continuous':
        outputs = value.sum() + logit['mu'].sum() + logit['sigma'].sum()
    elif model.action_space == 'hybrid':
        outputs = value.sum() + logit['action_type'].sum() + \
            logit['action_args']['mu'].sum() + logit['action_args']['sigma'].sum()
    else:
        if model.multi_head:
            outputs = value.sum() + sum([t.sum() for t in logit])
        else:
            outputs = value.sum() + logit.sum()
    output_check(model, outputs, 1)

    # Actor-only forward pass; clear the gradients from the previous check first.
    for p in model.parameters():
        p.grad = None
    logit = model(inputs, mode='compute_actor')['logit']
    if model.action_space == 'continuous':
        logit = logit['mu'].sum() + logit['sigma'].sum()
    elif model.action_space == 'hybrid':
        logit = logit['action_type'].sum() + logit['action_args']['mu'].sum() + logit['action_args']['sigma'].sum()
    output_check(model.actor, logit, model.action_shape)

    # Critic-only forward pass: one scalar value per batch element.
    for p in model.parameters():
        p.grad = None
    value = model(inputs, mode='compute_critic')['value']
    assert value.shape == (B, )
    output_check(model.critic, value, 1)


@pytest.mark.unittest
class TestDREAMERVAC:

    def test_DREAMERVAC(self):
        # Smoke test: constructing the model should not raise.
        obs_shape = 8
        act_shape = 6
        model = DREAMERVAC(obs_shape, act_shape)


@pytest.mark.unittest
@pytest.mark.parametrize('obs_shape, act_args, share_encoder', args)
class TestVACGeneral:

    def test_vac(self, obs_shape, act_args, share_encoder):
        if isinstance(obs_shape, int):
            inputs = torch.randn(B, obs_shape)
        else:
            inputs = torch.randn(B, *obs_shape)
        model = VAC(obs_shape, action_shape=act_args[0], action_space=act_args[1], share_encoder=share_encoder)
        model_check(model, inputs)


@pytest.mark.unittest
@pytest.mark.parametrize('share_encoder', [False, True])
class TestVACEncoder:

    def test_vac_with_impala_encoder(self, share_encoder):
        inputs = torch.randn(B, 4, 64, 64)
        model = VAC(
            obs_shape=(4, 64, 64),
            action_shape=6,
            action_space='discrete',
            share_encoder=share_encoder,
            impala_cnn_encoder=True
        )
        model_check(model, inputs)

    def test_encoder_assignment(self, share_encoder):
        inputs = torch.randn(B, 4, 64, 64)
        # Pass a pre-built encoder instead of letting VAC construct its own.
        special_encoder = ConvEncoder(obs_shape=(4, 64, 64), hidden_size_list=[16, 32, 32, 64])
        model = VAC(
            obs_shape=(4, 64, 64),
            action_shape=6,
            action_space='discrete',
            share_encoder=share_encoder,
            actor_head_hidden_size=64,
            critic_head_hidden_size=64,
            encoder=special_encoder
        )
        model_check(model, inputs)
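

# For local debugging, the suite can also be launched without the pytest CLI.
# A minimal sketch using pytest's programmatic entry point, running this
# module's own tests (assumes pytest is importable in the environment):
if __name__ == '__main__':
    pytest.main(['-v', __file__])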