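# NOTE: the lines below are web-page header residue that was captured together
# with this file. They are kept as comments so the module parses; they are not
# part of the code.
#   sgoodfriend's picture
#   PPO playing LunarLander-v2 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
#   479abbd
#   raw
#   history blame contribute delete
#   No virus
#   3.93 kB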
from typing import Optional, Tuple, Type
import gym
import torch.nn as nn
from gym.spaces import Box, Discrete, MultiDiscrete
from rl_algo_impls.shared.actor.actor import Actor
from rl_algo_impls.shared.actor.categorical import CategoricalActorHead
from rl_algo_impls.shared.actor.gaussian import GaussianActorHead
from rl_algo_impls.shared.actor.gridnet import GridnetActorHead
from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder
from rl_algo_impls.shared.actor.multi_discrete import MultiDiscreteActorHead
from rl_algo_impls.shared.actor.state_dependent_noise import (
StateDependentNoiseActorHead,
)
from rl_algo_impls.shared.encoder import EncoderOutDim
def actor_head(
    action_space: gym.Space,
    in_dim: EncoderOutDim,
    hidden_sizes: Tuple[int, ...],
    init_layers_orthogonal: bool,
    activation: Type[nn.Module],
    log_std_init: float = -0.5,
    use_sde: bool = False,
    full_std: bool = True,
    squash_output: bool = False,
    actor_head_style: str = "single",
    action_plane_space: Optional[gym.Space] = None,
) -> Actor:
    """Build the actor head that matches ``action_space``.

    Dispatch rules:

    * ``Discrete`` -> ``CategoricalActorHead`` sized by ``action_space.n``.
    * ``Box`` -> ``StateDependentNoiseActorHead`` when ``use_sde`` is set,
      otherwise ``GaussianActorHead``; both are sized by
      ``action_space.shape[0]``.
    * ``MultiDiscrete`` -> selected by ``actor_head_style``:
      ``"single"`` -> ``MultiDiscreteActorHead``,
      ``"gridnet"`` -> ``GridnetActorHead``,
      ``"gridnet_decoder"`` -> ``GridnetDecoder``.

    Args:
        action_space: Environment action space used to select the head.
        in_dim: Encoder output dimension forwarded to the head. Must be an
            ``int`` for ``Discrete`` and ``Box`` action spaces.
        hidden_sizes: Forwarded to every head except ``GridnetDecoder``.
        init_layers_orthogonal: Forwarded to the head unchanged.
        activation: Activation module class forwarded to the head.
        log_std_init: Forwarded to the ``Box`` heads only.
        use_sde: Select the state-dependent-noise head; only valid for
            ``Box`` action spaces.
        full_std: Forwarded to ``StateDependentNoiseActorHead`` only.
        squash_output: Forwarded to ``StateDependentNoiseActorHead`` only;
            only valid together with ``use_sde``.
        actor_head_style: Head variant for ``MultiDiscrete`` action spaces;
            ignored for other action spaces.
        action_plane_space: Must be a ``MultiDiscrete`` space when
            ``actor_head_style`` is ``"gridnet"`` or ``"gridnet_decoder"``;
            unused otherwise. Presumably the per-cell action space of the
            grid -- confirm against the gridnet heads.

    Returns:
        The constructed actor head.

    Raises:
        AssertionError: If ``use_sde`` is set for a non-``Box`` space, if
            ``squash_output`` is set without ``use_sde``, if ``in_dim`` is
            not an ``int`` for ``Discrete``/``Box`` spaces, or if a gridnet
            style is requested without a ``MultiDiscrete``
            ``action_plane_space``.
        ValueError: If ``actor_head_style`` or the type of ``action_space``
            is not supported.
    """
    assert not use_sde or isinstance(
        action_space, Box
    ), "use_sde only valid if Box action_space"
    assert not squash_output or use_sde, "squash_output only valid if use_sde"

    if isinstance(action_space, Discrete):
        assert isinstance(in_dim, int)
        return CategoricalActorHead(
            action_space.n,  # type: ignore
            in_dim=in_dim,
            hidden_sizes=hidden_sizes,
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
        )
    elif isinstance(action_space, Box):
        assert isinstance(in_dim, int)
        if use_sde:
            return StateDependentNoiseActorHead(
                action_space.shape[0],  # type: ignore
                in_dim=in_dim,
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
                log_std_init=log_std_init,
                full_std=full_std,
                squash_output=squash_output,
            )
        else:
            return GaussianActorHead(
                action_space.shape[0],  # type: ignore
                in_dim=in_dim,
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
                log_std_init=log_std_init,
            )
    elif isinstance(action_space, MultiDiscrete):
        if actor_head_style == "single":
            return MultiDiscreteActorHead(
                action_space.nvec,  # type: ignore
                in_dim=in_dim,
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
            )
        elif actor_head_style == "gridnet":
            assert isinstance(
                action_plane_space, MultiDiscrete
            ), "gridnet actor_head_style requires a MultiDiscrete action_plane_space"
            return GridnetActorHead(
                # Entries in the full action vector divided by entries per
                # plane; presumably the number of grid cells -- confirm.
                len(action_space.nvec) // len(action_plane_space.nvec),  # type: ignore
                action_plane_space.nvec,  # type: ignore
                in_dim=in_dim,
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
            )
        elif actor_head_style == "gridnet_decoder":
            assert isinstance(
                action_plane_space, MultiDiscrete
            ), "gridnet_decoder actor_head_style requires a MultiDiscrete action_plane_space"
            # GridnetDecoder is constructed without hidden_sizes, unlike the
            # other heads.
            return GridnetDecoder(
                len(action_space.nvec) // len(action_plane_space.nvec),  # type: ignore
                action_plane_space.nvec,  # type: ignore
                in_dim=in_dim,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
            )
        else:
            raise ValueError(f"Doesn't support actor_head_style {actor_head_style}")
    else:
        raise ValueError(f"Unsupported action space: {action_space}")