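"""Observation encoder: routes image observations (3-D Box spaces) to a
configurable CNN backbone and flattens vector (1-D Box) or Discrete
observations into a flat feature vector of size out_dim."""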
from typing import Dict, Optional, Sequence, Type

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym.spaces import Box, Discrete
from stable_baselines3.common.preprocessing import get_flattened_obs_dim

from rl_algo_impls.shared.encoder.cnn import CnnEncoder
from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder
from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn

# Registry mapping the cnn_style config string to its CNN encoder class.
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
    "nature": NatureCnn,
    "impala": ImpalaCnn,
    "microrts": MicrortsCnn,
    "gridnet_encoder": GridnetEncoder,
}
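# A project-specific style could be registered the same way, assuming the
# class subclasses CnnEncoder and accepts the keyword arguments passed in
# Encoder.__init__ below (hypothetical example, not part of this module):
#   CNN_EXTRACTORS_BY_STYLE["my_style"] = MyCnn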

class Encoder(nn.Module):
    """Maps observations from obs_space to flat feature vectors of size out_dim."""

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_flatten_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
    ) -> None:
        super().__init__()
        if isinstance(obs_space, Box):
            # Conv2D: image observations shaped (channels, height, width)
            if len(obs_space.shape) == 3:  # type: ignore
                self.preprocess = None
                cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
                    obs_space,
                    activation=activation,
                    cnn_init_layers_orthogonal=cnn_layers_init_orthogonal,
                    linear_init_layers_orthogonal=init_layers_orthogonal,
                    cnn_flatten_dim=cnn_flatten_dim,
                    impala_channels=impala_channels,
                )
                self.feature_extractor = cnn
                self.out_dim = cnn.out_dim
            elif len(obs_space.shape) == 1:  # type: ignore
                # Flat vector observations: cast to float and add a batch
                # dimension if a single unbatched observation is passed.
                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    if len(obs.shape) == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            # Discrete observations are one-hot encoded, so the feature size
            # equals the number of discrete values.
            self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n  # type: ignore
        else:
            raise NotImplementedError

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        if self.preprocess:
            obs = self.preprocess(obs)
        return self.feature_extractor(obs)
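
# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal smoke test of the vector and discrete branches, assuming this
# module is importable from within the rl_algo_impls package.
if __name__ == "__main__":
    # 1-D Box observations are flattened; out_dim is the flattened size.
    vector_encoder = Encoder(Box(low=-1.0, high=1.0, shape=(8,)), nn.Tanh)
    assert vector_encoder(torch.zeros(4, 8)).shape == (4, vector_encoder.out_dim)

    # Discrete observations are one-hot encoded, so out_dim == n.
    discrete_encoder = Encoder(Discrete(5), nn.Tanh)
    assert discrete_encoder(torch.tensor([0, 3])).shape == (2, 5)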