File size: 3,222 Bytes
f761808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from torch import nn
from torch.nn import functional as F
from agent.helpers import init_weights


class VAE(nn.Module):
    """Conditional variational autoencoder over actions, conditioned on state.

    The encoder maps a concatenated ``(action, state)`` pair to the mean and
    log-variance of a diagonal Gaussian latent whose size equals
    ``hidden_size``. The decoder maps ``(z, state)`` back to an
    ``action_dim``-sized output. ``forward`` decodes a latent taken from the
    standard-normal prior (random draw, or the prior mean ``0`` when
    ``eval=True``) and clamps the result to ``[-1, 1]``.
    """

    def __init__(self, state_dim: int, action_dim: int, device, hidden_size: int = 256) -> None:
        """
        :param state_dim: Size of the last dimension of ``state`` tensors.
        :param action_dim: Size of the last dimension of ``action`` tensors
            and of the decoder output.
        :param device: Stored as ``self.device``. ``forward`` places its latent
            noise on ``state``'s device rather than this one; the attribute is
            kept because other code may read it (NOTE(review): confirm callers).
        :param hidden_size: Width of every hidden layer; also the latent size.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.action_dim = action_dim

        input_dim = state_dim + action_dim

        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_size),
                                     nn.Mish(),
                                     nn.Linear(hidden_size, hidden_size),
                                     nn.Mish(),
                                     nn.Linear(hidden_size, hidden_size),
                                     nn.Mish())

        # Heads producing the mean and log-variance of q(z | action, state).
        self.fc_mu = nn.Linear(hidden_size, hidden_size)
        self.fc_var = nn.Linear(hidden_size, hidden_size)

        self.decoder = nn.Sequential(nn.Linear(hidden_size + state_dim, hidden_size),
                                     nn.Mish(),
                                     nn.Linear(hidden_size, hidden_size),
                                     nn.Mish(),
                                     nn.Linear(hidden_size, hidden_size),
                                     nn.Mish())

        # Kept inside a Sequential: replacing it with a bare Linear would
        # rename its state_dict keys (``final_layer.0.*``).
        self.final_layer = nn.Sequential(nn.Linear(hidden_size, action_dim))

        # Run the project's ``init_weights`` helper on every submodule.
        self.apply(init_weights)

        self.device = device

    def encode(self, action, state):
        """Compute the parameters of the approximate posterior q(z | action, state).

        :param action: (Tensor) Actions; presumably (batch, action_dim).
        :param state: (Tensor) States; presumably (batch, state_dim).
        :return: ``(mu, log_var)``, each (batch, hidden_size) for 2-D inputs.
        """
        x = torch.cat([action, state], dim=-1)
        result = self.encoder(x)
        # A no-op for 2-D input; kept to preserve the original behaviour.
        result = torch.flatten(result, start_dim=1)

        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return mu, log_var

    def decode(self, z, state):
        """Map a latent code and a state to an (unclamped) action.

        :param z: (Tensor) Latent codes with last dimension ``hidden_size``.
        :param state: (Tensor) States with last dimension ``state_dim``.
        :return: (Tensor) Output with last dimension ``action_dim``.
        """
        x = torch.cat([z, state], dim=-1)
        result = self.decoder(x)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        """Draw ``z ~ N(mu, exp(logvar))`` with the reparameterization trick.

        Only one sample is drawn per call. The expectation in the training
        objective is therefore a single-draw Monte-Carlo estimate. Averaging
        several draws would lower its variance at extra cost.

        :param mu: (Tensor) Mean of the latent Gaussian.
        :param logvar: (Tensor) Log-variance of the latent Gaussian (same shape as ``mu``).
        :return: (Tensor) ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def loss(self, action, state, *, kld_weight=0.1):
        """Return the weighted VAE training loss for a minibatch.

        ``loss = MSE(recons, action) + kld_weight * KL(q(z | action, state) || N(0, I))``

        The MSE is averaged over all elements. The KL term is summed over
        latent dimensions and then averaged over the batch.

        :param action: (Tensor) Target actions; presumably (batch, action_dim).
        :param state: (Tensor) Conditioning states; presumably (batch, state_dim).
        :param kld_weight: Weight of the KL term. The default 0.1 is the value
            that was previously hard-coded.
        :return: (Tensor) Scalar loss.
        """
        mu, log_var = self.encode(action, state)
        z = self.reparameterize(mu, log_var)
        recons = self.decode(z, state)

        recons_loss = F.mse_loss(recons, action)
        # Closed-form KL(N(mu, exp(log_var)) || N(0, 1)) per latent dimension.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        return recons_loss + kld_weight * kld_loss

    def forward(self, state, eval=False):
        """Decode actions for ``state`` from the latent prior.

        :param state: (Tensor) States; presumably (batch, state_dim).
        :param eval: If True, use the prior mean ``z = 0`` (deterministic).
            Otherwise sample ``z ~ N(0, I)``. (The name shadows the builtin
            but is part of the public signature.)
        :return: (Tensor) Decoded actions clamped to ``[-1, 1]``.
        """
        shape = (state.shape[0], self.hidden_size)

        # Build z on state's device (and dtype, when floating) rather than the
        # device fixed at construction, so the module keeps working after
        # .to()/.half()/.double(). Non-floating states fall back to the default
        # dtype, as before.
        dtype = state.dtype if torch.is_floating_point(state) else None
        if eval:
            z = torch.zeros(shape, device=state.device, dtype=dtype)
        else:
            z = torch.randn(shape, device=state.device, dtype=dtype)
        samples = self.decode(z, state)

        return samples.clamp(-1., 1.)