# PPO playing impala-PongNoFrameskip-v4, by sgoodfriend.
# Source: https://github.com/sgoodfriend/rl-algo-impls/tree/e47a44c4d891f48885af0b1605b30d19fc67b5af
# Revision: be9c115
import gym
import torch as th
import torch.nn as nn
from gym.spaces import Discrete
from typing import Sequence, Type
from shared.module.feature_extractor import FeatureExtractor
from shared.module.module import mlp
class QNetwork(nn.Module):
    """Network producing one output per discrete action from an observation.

    Observations are first encoded by ``FeatureExtractor``. The encoding is then
    fed to an MLP built from the layer sizes
    ``(feature_extractor.out_dim, *hidden_sizes, action_space.n)``.

    Args:
        observation_space: Observation space handed to ``FeatureExtractor``.
        action_space: Must be a ``gym.spaces.Discrete``; its ``n`` sets the
            width of the final layer.
        hidden_sizes: Widths of the hidden layers between the feature
            extractor output and the final layer. Defaults to no hidden layers.
        activation: Activation module class passed to both
            ``FeatureExtractor`` and ``mlp``.

    Raises:
        AssertionError: If ``action_space`` is not ``Discrete``.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        # Immutable default: a shared ``[]`` default is a classic pitfall, and
        # the value is only read (converted with ``tuple``) below.
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
    ) -> None:
        super().__init__()
        # The output width comes from ``action_space.n``, which only Discrete
        # spaces provide.
        assert isinstance(action_space, Discrete), (
            "QNetwork requires a Discrete action space, got "
            f"{type(action_space).__name__}"
        )
        self._feature_extractor = FeatureExtractor(observation_space, activation)
        layer_sizes = (
            (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
        )
        # NOTE(review): confirm ``mlp`` does not apply ``activation`` after the
        # final layer, since action values are usually left unbounded.
        self._fc = mlp(layer_sizes, activation)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Return the network outputs for ``obs``.

        ``obs`` is encoded by the feature extractor and then passed through the
        MLP. The last dimension of the result is expected to be
        ``action_space.n``. Presumably ``obs`` is batched with the batch on
        dim 0; confirm against ``FeatureExtractor``.
        """
        x = self._feature_extractor(obs)
        return self._fc(x)