File size: 1,495 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pytest
import torch
from ding.torch_utils import is_differentiable
from ding.model.template.vae import VanillaVAE


@pytest.mark.unittest
def test_vae():
    """Smoke-test ``VanillaVAE`` forward, ``decode_with_obs`` and ``loss_function``.

    Checks the output shapes of the forward pass and of decoding a latent action
    conditioned on an observation, then checks that the total loss is
    differentiable w.r.t. the model parameters.
    """
    batch_size = 32
    latent_size = 6  # dimension of the latent action ``z`` (3rd constructor argument)
    original_action_shape = 2  # dimension of the raw action fed to the encoder
    obs_shape = 6
    hidden_size_list = [256, 256]
    inputs = {
        'action': torch.randn(batch_size, original_action_shape),
        'obs': torch.randn(batch_size, obs_shape),
        'next_obs': torch.randn(batch_size, obs_shape),
    }

    vae_model = VanillaVAE(original_action_shape, obs_shape, latent_size, hidden_size_list)
    outputs = vae_model(inputs)

    assert outputs['recons_action'].shape == (batch_size, original_action_shape)
    assert outputs['prediction_residual'].shape == (batch_size, obs_shape)
    assert isinstance(outputs['input'], dict)
    # mu / log_var parameterize the distribution that z is drawn from, so they
    # share z's (latent) dimension. Asserting them against obs_shape would pass
    # only because obs_shape == latent_size in this test.
    assert outputs['mu'].shape == (batch_size, latent_size)
    assert outputs['log_var'].shape == (batch_size, latent_size)
    assert outputs['z'].shape == (batch_size, latent_size)

    outputs_decode = vae_model.decode_with_obs(outputs['z'], inputs['obs'])
    assert outputs_decode['reconstruction_action'].shape == (batch_size, original_action_shape)
    # NOTE: 'predition_residual' (sic) presumably mirrors the key spelling that
    # decode_with_obs returns; do not "fix" it here without changing the model.
    assert outputs_decode['predition_residual'].shape == (batch_size, obs_shape)

    # Ground-truth targets are added to the forward outputs before computing the
    # loss (presumably read by loss_function under these keys — verify in model).
    outputs['original_action'] = inputs['action']
    outputs['true_residual'] = inputs['next_obs'] - inputs['obs']
    vae_loss = vae_model.loss_function(outputs, kld_weight=0.01, predict_weight=0.01)
    is_differentiable(vae_loss['loss'], vae_model)