A2C playing LunarLander-v2 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
The `CriticHead` module, as of commit de6a584:
```python
from typing import Sequence, Type

import numpy as np
import torch
import torch.nn as nn

from rl_algo_impls.shared.encoder import EncoderOutDim
from rl_algo_impls.shared.module.utils import mlp


class CriticHead(nn.Module):
    """Value head mapping encoder features to a scalar state-value estimate."""

    def __init__(
        self,
        in_dim: EncoderOutDim,
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        seq = []
        if isinstance(in_dim, tuple):
            # Multi-dimensional encoder output (e.g. conv feature maps):
            # flatten before the fully connected layers.
            seq.append(nn.Flatten())
            in_channels = int(np.prod(in_dim))
        else:
            in_channels = in_dim
        # MLP from the (flattened) features through the hidden layers
        # down to a single value output.
        layer_sizes = (in_channels,) + tuple(hidden_sizes) + (1,)
        seq.append(
            mlp(
                layer_sizes,
                activation,
                init_layers_orthogonal=init_layers_orthogonal,
                final_layer_gain=1.0,
                hidden_layer_gain=1.0,
            )
        )
        self._fc = nn.Sequential(*seq)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Squeeze the trailing value dimension: (batch, 1) -> (batch,)
        v = self._fc(obs)
        return v.squeeze(-1)
```
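
A minimal usage sketch of the class above. It assumes the `rl_algo_impls` package from the pinned commit is importable (so `mlp` and `EncoderOutDim` resolve); the batch size, hidden sizes, and feature shapes are illustrative, with the 8-dimensional case matching LunarLander-v2's observation vector.

```python
import torch

# Flat encoder output (e.g. LunarLander-v2's 8-dim observation passed
# through an identity/MLP encoder). Hidden size (64,) is an assumption.
head_flat = CriticHead(in_dim=8, hidden_sizes=(64,))
obs = torch.randn(4, 8)  # batch of 4 observations
values = head_flat(obs)
print(values.shape)  # torch.Size([4]); squeeze(-1) drops the value dim

# Tuple in_dim (e.g. CNN feature maps): the head flattens before the MLP.
# The (32, 4, 4) shape is hypothetical.
head_conv = CriticHead(in_dim=(32, 4, 4))
feats = torch.randn(4, 32, 4, 4)
print(head_conv(feats).shape)  # torch.Size([4])
```

Returning shape `(batch,)` rather than `(batch, 1)` keeps the value estimates aligned with per-step returns and advantages, avoiding accidental broadcasting in the loss.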