| from typing import Dict, Optional, Sequence, Type |
|
|
| import gym |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from gym.spaces import Box, Discrete |
| from stable_baselines3.common.preprocessing import get_flattened_obs_dim |
|
|
| from rl_algo_impls.shared.encoder.cnn import CnnEncoder |
| from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder |
| from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn |
| from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn |
| from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn |
| from rl_algo_impls.shared.module.utils import layer_init |
|
|
| CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = { |
| "nature": NatureCnn, |
| "impala": ImpalaCnn, |
| "microrts": MicrortsCnn, |
| "gridnet_encoder": GridnetEncoder, |
| } |
|
|
|
|
| class Encoder(nn.Module): |
| def __init__( |
| self, |
| obs_space: gym.Space, |
| activation: Type[nn.Module], |
| init_layers_orthogonal: bool = False, |
| cnn_flatten_dim: int = 512, |
| cnn_style: str = "nature", |
| cnn_layers_init_orthogonal: Optional[bool] = None, |
| impala_channels: Sequence[int] = (16, 32, 32), |
| ) -> None: |
| super().__init__() |
| if isinstance(obs_space, Box): |
| |
| if len(obs_space.shape) == 3: |
| self.preprocess = None |
| cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style]( |
| obs_space, |
| activation=activation, |
| cnn_init_layers_orthogonal=cnn_layers_init_orthogonal, |
| linear_init_layers_orthogonal=init_layers_orthogonal, |
| cnn_flatten_dim=cnn_flatten_dim, |
| impala_channels=impala_channels, |
| ) |
| self.feature_extractor = cnn |
| self.out_dim = cnn.out_dim |
| elif len(obs_space.shape) == 1: |
|
|
| def preprocess(obs: torch.Tensor) -> torch.Tensor: |
| if len(obs.shape) == 1: |
| obs = obs.unsqueeze(0) |
| return obs.float() |
|
|
| self.preprocess = preprocess |
| self.feature_extractor = nn.Flatten() |
| self.out_dim = get_flattened_obs_dim(obs_space) |
| else: |
| raise ValueError(f"Unsupported observation space: {obs_space}") |
| elif isinstance(obs_space, Discrete): |
| self.preprocess = lambda x: F.one_hot(x, obs_space.n).float() |
| self.feature_extractor = nn.Flatten() |
| self.out_dim = obs_space.n |
| else: |
| raise NotImplementedError |
|
|
| def forward(self, obs: torch.Tensor) -> torch.Tensor: |
| if self.preprocess: |
| obs = self.preprocess(obs) |
| return self.feature_extractor(obs) |
|
|