sgoodfriend's picture
PPO playing MountainCar-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
3cc5c1d
raw
history blame contribute delete
No virus
1.92 kB
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Type, Union
import gym
import numpy as np
import torch
import torch.nn as nn
from rl_algo_impls.shared.module.utils import layer_init
EncoderOutDim = Union[int, Tuple[int, ...]]
class CnnEncoder(nn.Module, ABC):
@abstractmethod
def __init__(
self,
obs_space: gym.Space,
**kwargs,
) -> None:
super().__init__()
self.range_size = np.max(obs_space.high) - np.min(obs_space.low) # type: ignore
def preprocess(self, obs: torch.Tensor) -> torch.Tensor:
if len(obs.shape) == 3:
obs = obs.unsqueeze(0)
return obs.float() / self.range_size
def forward(self, obs: torch.Tensor) -> torch.Tensor:
return self.preprocess(obs)
@property
@abstractmethod
def out_dim(self) -> EncoderOutDim:
...
class FlattenedCnnEncoder(CnnEncoder):
def __init__(
self,
obs_space: gym.Space,
activation: Type[nn.Module],
linear_init_layers_orthogonal: bool,
cnn_flatten_dim: int,
cnn: nn.Module,
**kwargs,
) -> None:
super().__init__(obs_space, **kwargs)
self.cnn = cnn
self.flattened_dim = cnn_flatten_dim
with torch.no_grad():
cnn_out = torch.flatten(
cnn(self.preprocess(torch.as_tensor(obs_space.sample()))), start_dim=1
)
self.fc = nn.Sequential(
nn.Flatten(),
layer_init(
nn.Linear(cnn_out.shape[1], cnn_flatten_dim),
linear_init_layers_orthogonal,
),
activation(),
)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
x = super().forward(obs)
x = self.cnn(x)
x = self.fc(x)
return x
@property
def out_dim(self) -> EncoderOutDim:
return self.flattened_dim