File size: 1,129 Bytes
0e936e1 8bf4dee 0e936e1 8bf4dee 0e936e1 8bf4dee 0e936e1 8bf4dee 0e936e1 8bf4dee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from typing import Sequence, Type
import numpy as np
import torch
import torch.nn as nn
from rl_algo_impls.shared.encoder import EncoderOutDim
from rl_algo_impls.shared.module.module import mlp
class CriticHead(nn.Module):
    """Critic head: maps encoder features to a single scalar output.

    When ``in_dim`` is a tuple, the input is first passed through
    ``nn.Flatten()`` (default arguments) and the MLP input width becomes the
    product of that tuple. The MLP is built by ``mlp`` with layer sizes
    ``(input_width, *hidden_sizes, 1)``. ``forward`` removes the trailing
    size-1 dimension of the result.

    Args:
        in_dim: Encoder output dimension. An int is used directly as the MLP
            input width; a tuple is treated as a shape to flatten.
        hidden_sizes: Widths of the hidden layers between input and output.
        activation: Activation module class forwarded to ``mlp``.
        init_layers_orthogonal: Forwarded to ``mlp``. Both the hidden-layer
            and final-layer gains are fixed at 1.0.
    """

    def __init__(
        self,
        in_dim: EncoderOutDim,
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        is_shaped = isinstance(in_dim, tuple)
        # A tuple dim describes multi-dimensional features. Flatten them so
        # the MLP receives one vector whose width is the product of the shape.
        flat_width = int(np.prod(in_dim)) if is_shaped else in_dim
        widths = (flat_width, *hidden_sizes, 1)
        value_mlp = mlp(
            widths,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=1.0,
            hidden_layer_gain=1.0,
        )
        # Keep the sub-module order (Flatten first, then the MLP): it
        # determines the "_fc.<index>..." keys in the state_dict.
        modules = [nn.Flatten(), value_mlp] if is_shaped else [value_mlp]
        self._fc = nn.Sequential(*modules)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run the head and drop the trailing size-1 output dimension."""
        return self._fc(obs).squeeze(-1)
|