sgoodfriend's picture
PPO playing MountainCar-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
3cc5c1d
raw
history blame
No virus
1.53 kB
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Sequence, Tuple
import torch
import torch.nn as nn
from gym.spaces import Box, Discrete, Space
from rl_algo_impls.shared.actor import PiForward
class ACNForward(NamedTuple):
    """Result bundle returned by ``ActorCriticNetwork`` passes.

    Both ``ActorCriticNetwork.forward`` and
    ``ActorCriticNetwork.distribution_and_value`` are declared to return this
    type. Field order is part of the tuple interface and must not change.
    """

    # Policy-side output (see ``PiForward`` for its contents).
    pi_forward: PiForward
    # Value tensor (presumably the critic's value estimate -- confirm).
    v: torch.Tensor
class ActorCriticNetwork(nn.Module, ABC):
    """Abstract ``nn.Module`` interface for actor-critic networks.

    Concrete subclasses implement the abstract methods below. ``forward`` and
    ``distribution_and_value`` both package their results as an
    ``ACNForward`` (``pi_forward`` plus value tensor ``v``).
    """

    @abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        action_masks: Optional[torch.Tensor] = None,
    ) -> ACNForward:
        """Run the network on ``obs`` together with a given ``action``.

        Args:
            obs: Observation batch.
            action: Actions to evaluate. Presumably used to compute
                per-action policy quantities (e.g. log-probabilities) --
                confirm against the concrete implementations.
            action_masks: Optional mask tensor; may be None.

        Returns:
            An ``ACNForward`` with the policy output and value tensor.
        """
        ...

    @abstractmethod
    def distribution_and_value(
        self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
    ) -> ACNForward:
        """Compute the policy output and value for ``obs`` without an action.

        Args:
            obs: Observation batch.
            action_masks: Optional mask tensor; may be None.

        Returns:
            An ``ACNForward`` with the policy output and value tensor.
        """
        ...

    @abstractmethod
    def value(self, obs: torch.Tensor) -> torch.Tensor:
        """Return only the value tensor for ``obs``."""
        ...

    @abstractmethod
    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        """Reset the network's noise state, if any.

        Args:
            batch_size: Optional batch size; the meaning of None is left to
                the concrete implementation (NOTE(review): name suggests
                resampling exploration noise -- confirm).
        """
        ...

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Action shape reported by the concrete network.

        Subclasses are expected to override this. It is intentionally not
        marked ``@abstractmethod`` so that existing subclasses that never
        override it can still be instantiated. Accessing it on such a
        subclass now fails loudly instead of silently returning ``None``,
        which would violate the declared ``Tuple[int, ...]`` return type.

        Raises:
            NotImplementedError: If the subclass does not override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement action_shape"
        )
def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
    """Pick default hidden-layer sizes based on the observation space.

    Args:
        obs_space: The observation space to size the network for.

    Returns:
        ``[]`` for a 3-D ``Box``, ``[64, 64]`` for a 1-D ``Box`` and ``[64]``
        for a ``Discrete`` space. A new list is returned on every call.

    Raises:
        ValueError: For a ``Box`` of any other rank, or any other space type.
    """
    if isinstance(obs_space, Box):
        rank = len(obs_space.shape)  # type: ignore
        if rank == 1:
            return [64, 64]
        if rank == 3:
            # A feature extractor handles these observations, so by default
            # no hidden layers sit between it and the output.
            return []
    elif isinstance(obs_space, Discrete):
        return [64]
    # Reached by unsupported Box ranks and by every other space type.
    raise ValueError(f"Unsupported observation space: {obs_space}")