# sgoodfriend's picture
# PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# a861765
from typing import Dict, Optional, Sequence, Type
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym.spaces import Box, Discrete
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from rl_algo_impls.shared.encoder.cnn import CnnEncoder
from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder
from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
from rl_algo_impls.shared.module.utils import layer_init
# Registry of CNN encoder classes, keyed by the ``cnn_style`` string that
# selects which one is instantiated for image-like observations.
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
    "nature": NatureCnn,
    "impala": ImpalaCnn,
    "microrts": MicrortsCnn,
    "gridnet_encoder": GridnetEncoder,
}
class Encoder(nn.Module):
    """Map raw observations to a flat feature tensor.

    The feature extractor is chosen from the observation space:

    * 3-D ``Box``: the CNN class registered under ``cnn_style`` in
      ``CNN_EXTRACTORS_BY_STYLE``; no preprocessing is applied here.
    * 1-D ``Box``: observations are cast to float (a single unbatched
      vector gets a leading batch dimension) and passed through
      ``nn.Flatten``.
    * ``Discrete``: observations are one-hot encoded into ``obs_space.n``
      classes, cast to float and passed through ``nn.Flatten``.

    Attributes:
        preprocess: Callable applied to observations before the feature
            extractor, or ``None`` when no preprocessing is needed.
        feature_extractor: Module that produces the features.
        out_dim: Feature size for the selected extractor.

    Args:
        obs_space: Observation space that determines the extractor.
        activation: Activation module class, forwarded to the CNN.
        init_layers_orthogonal: Forwarded to the CNN as
            ``linear_init_layers_orthogonal``.
        cnn_flatten_dim: Forwarded to the CNN.
        cnn_style: Key into ``CNN_EXTRACTORS_BY_STYLE``.
        cnn_layers_init_orthogonal: Forwarded to the CNN as
            ``cnn_init_layers_orthogonal``.
        impala_channels: Forwarded to the CNN.

    Raises:
        ValueError: If ``obs_space`` is a ``Box`` whose shape is neither
            1-D nor 3-D.
        NotImplementedError: If ``obs_space`` is neither ``Box`` nor
            ``Discrete``.
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_flatten_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
    ) -> None:
        super().__init__()
        if isinstance(obs_space, Box):
            obs_rank = len(obs_space.shape)  # type: ignore
            if obs_rank == 3:
                # Conv2D: (channels, height, width)
                self.preprocess = None
                cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
                    obs_space,
                    activation=activation,
                    cnn_init_layers_orthogonal=cnn_layers_init_orthogonal,
                    linear_init_layers_orthogonal=init_layers_orthogonal,
                    cnn_flatten_dim=cnn_flatten_dim,
                    impala_channels=impala_channels,
                )
                self.feature_extractor = cnn
                self.out_dim = cnn.out_dim
            elif obs_rank == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # A lone observation vector gets a batch dimension so
                    # nn.Flatten (which flattens from dim 1) accepts it.
                    if len(obs.shape) == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            # Plain int: obs_space.n may be a NumPy integer depending on the
            # gym version, and F.one_hot expects a Python int class count.
            num_classes = int(obs_space.n)

            def one_hot(obs: torch.Tensor) -> torch.Tensor:
                # Mirror the 1-D Box branch: a single scalar observation
                # gets a batch dimension, otherwise its one-hot vector is
                # 1-D and nn.Flatten has no dim 1 to flatten from.
                if obs.dim() == 0:
                    obs = obs.unsqueeze(0)
                # F.one_hot only accepts int64 index tensors; observations
                # may arrive as int32/uint8/float.
                return F.one_hot(obs.long(), num_classes).float()

            self.preprocess = one_hot
            self.feature_extractor = nn.Flatten()
            self.out_dim = num_classes
        else:
            raise NotImplementedError(
                f"Unsupported observation space: {obs_space}"
            )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply the optional preprocessing, then the feature extractor."""
        if self.preprocess is not None:
            obs = self.preprocess(obs)
        return self.feature_extractor(obs)