sgoodfriend's picture
A2C playing LunarLander-v2 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
de6a584
raw
history blame contribute delete
No virus
1.35 kB
from typing import Optional, Sequence, Type
import gym
import torch as th
import torch.nn as nn
from gym.spaces import Discrete
from rl_algo_impls.shared.encoder import Encoder
from rl_algo_impls.shared.module.utils import mlp
class QNetwork(nn.Module):
def __init__(
self,
observation_space: gym.Space,
action_space: gym.Space,
hidden_sizes: Sequence[int] = [],
activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
cnn_flatten_dim: int = 512,
cnn_style: str = "nature",
cnn_layers_init_orthogonal: Optional[bool] = None,
impala_channels: Sequence[int] = (16, 32, 32),
) -> None:
super().__init__()
assert isinstance(action_space, Discrete)
self._feature_extractor = Encoder(
observation_space,
activation,
cnn_flatten_dim=cnn_flatten_dim,
cnn_style=cnn_style,
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
impala_channels=impala_channels,
)
layer_sizes = (
(self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
)
self._fc = mlp(layer_sizes, activation)
def forward(self, obs: th.Tensor) -> th.Tensor:
x = self._feature_extractor(obs)
return self._fc(x)