# sgoodfriend's picture
# A2C playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# 05b94c0
from typing import Sequence, Type
import numpy as np
import torch.nn as nn
def mlp(
    layer_sizes: Sequence[int],
    activation: Type[nn.Module],
    output_activation: Type[nn.Module] = nn.Identity,
    init_layers_orthogonal: bool = False,
    final_layer_gain: float = np.sqrt(2),
    hidden_layer_gain: float = np.sqrt(2),
) -> nn.Module:
    """Build a fully connected network as an ``nn.Sequential``.

    Consecutive entries of ``layer_sizes`` define the ``nn.Linear`` layers:
    ``[in, h1, ..., out]`` produces
    ``Linear(in, h1), activation(), ..., Linear(h_last, out), output_activation()``.

    Args:
        layer_sizes: Feature sizes from input to output; must contain at
            least two entries (input size and output size).
        activation: Module class instantiated after every hidden Linear layer.
        output_activation: Module class instantiated after the final Linear
            layer (``nn.Identity`` by default).
        init_layers_orthogonal: If True, every Linear layer is passed through
            ``layer_init`` for orthogonal weight / zero bias initialization;
            if False, PyTorch's default initialization is left in place.
        final_layer_gain: Gain passed to ``layer_init`` for the final Linear layer.
        hidden_layer_gain: Gain passed to ``layer_init`` for each hidden Linear layer.

    Returns:
        An ``nn.Sequential`` containing the layers described above.

    Raises:
        ValueError: If ``layer_sizes`` has fewer than two entries.
    """
    if len(layer_sizes) < 2:
        # Fail early with a clear message instead of an IndexError on layer_sizes[-2].
        raise ValueError(
            "mlp requires at least 2 layer sizes (input and output), "
            f"got {len(layer_sizes)}"
        )
    layers = []
    # Hidden layers: every consecutive (in, out) size pair except the last one.
    for in_size, out_size in zip(layer_sizes[:-2], layer_sizes[1:-1]):
        layers.append(
            layer_init(
                nn.Linear(in_size, out_size),
                init_layers_orthogonal,
                std=hidden_layer_gain,
            )
        )
        layers.append(activation())
    # The final layer uses final_layer_gain instead of hidden_layer_gain.
    layers.append(
        layer_init(
            nn.Linear(layer_sizes[-2], layer_sizes[-1]),
            init_layers_orthogonal,
            std=final_layer_gain,
        )
    )
    layers.append(output_activation())
    return nn.Sequential(*layers)
def layer_init(
    layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Module:
    """Optionally apply orthogonal weight initialization to ``layer`` in place.

    Args:
        layer: Module with a ``weight`` tensor (e.g. ``nn.Linear``); a
            ``bias`` attribute is optional.
        init_layers_orthogonal: If False, ``layer`` is returned untouched.
        std: Gain passed to ``nn.init.orthogonal_`` for the weight.

    Returns:
        The same ``layer`` object, for call chaining.
    """
    if not init_layers_orthogonal:
        return layer
    nn.init.orthogonal_(layer.weight, std)  # type: ignore
    # Layers created with bias=False have ``bias is None``; skip them instead of crashing.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)  # type: ignore
    return layer