sgoodfriend's picture
PPO playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4
85e4a43
raw
history blame
884 Bytes
import gym
import torch as th
import torch.nn as nn
from gym.spaces import Discrete
from typing import Sequence, Type
from shared.module import FeatureExtractor, mlp
class QNetwork(nn.Module):
    """Q-value network for environments with a discrete action space.

    Observations are passed through a ``FeatureExtractor`` and then through an
    MLP built by ``mlp`` whose final layer size is ``action_space.n`` (one
    output per discrete action).
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
    ) -> None:
        """Build the feature extractor and the Q-value head.

        Args:
            observation_space: Environment observation space; forwarded to
                ``FeatureExtractor``.
            action_space: Environment action space. Must be a ``Discrete``
                space, since its ``n`` sets the output layer size.
            hidden_sizes: Sizes of the hidden layers placed between the
                extracted features and the output layer. Empty by default
                (features map directly to the output layer).
            activation: Activation module class, forwarded to both
                ``FeatureExtractor`` and ``mlp``.

        Raises:
            TypeError: If ``action_space`` is not a ``Discrete`` space.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O`` and reports what was actually passed.
        if not isinstance(action_space, Discrete):
            raise TypeError(
                "QNetwork requires a Discrete action space, got "
                f"{type(action_space).__name__}"
            )
        # Attribute names are kept stable: they determine state_dict keys.
        self._feature_extractor = FeatureExtractor(observation_space, activation)
        # Layer sizes: extractor output -> optional hidden layers -> one unit
        # per action. ``tuple(...)`` accepts any sequence for hidden_sizes.
        layer_sizes = (
            (self._feature_extractor.out_dim,)
            + tuple(hidden_sizes)
            + (action_space.n,)
        )
        self._fc = mlp(layer_sizes, activation)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Return the Q-value head's output for ``obs``.

        ``obs`` is passed to the feature extractor and the result to the MLP.
        The expected batch layout of ``obs`` is whatever ``FeatureExtractor``
        accepts (not visible here; presumably batched observations — confirm).
        """
        features = self._feature_extractor(obs)
        return self._fc(features)