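"""Neural-network building blocks: observation feature extractors and an MLP
factory, with optional orthogonal weight initialization."""
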
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gym.spaces import Box, Discrete
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from typing import Sequence, Type


class FeatureExtractor(nn.Module):
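    """Maps gym observations to flat feature vectors.

    3-D Box (image) observations go through the Nature-DQN CNN, 1-D Box
    observations are flattened, and Discrete observations are one-hot
    encoded. The width of the output features is exposed as ``out_dim``.
    """
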
    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_feature_dim: int = 512,
    ) -> None:
        super().__init__()
        if isinstance(obs_space, Box):
            # Image observations: (channels, height, width), as nn.Conv2d expects
            if len(obs_space.shape) == 3:
                # CNN from DQN Nature paper: Mnih, Volodymyr, et al.
                # "Human-level control through deep reinforcement learning."
                # Nature 518.7540 (2015): 529-533.
                cnn = nn.Sequential(
                    layer_init(
                        nn.Conv2d(obs_space.shape[0], 32, kernel_size=8, stride=4),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    layer_init(
                        nn.Conv2d(32, 64, kernel_size=4, stride=2),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    layer_init(
                        nn.Conv2d(64, 64, kernel_size=3, stride=1),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    nn.Flatten(),
                )

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Add a batch dimension to single frames and scale uint8
                    # pixels to [0, 1]
                    if len(obs.shape) == 3:
                        obs = obs.unsqueeze(0)
                    return obs.float() / 255.0

                # Infer the CNN's flattened output size by running a sample
                # observation through it once
                with torch.no_grad():
                    cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
                self.preprocess = preprocess
                self.feature_extractor = nn.Sequential(
                    cnn,
                    layer_init(
                        nn.Linear(cnn_out.shape[1], cnn_feature_dim),
                        init_layers_orthogonal,
                    ),
                    activation(),
                )
                self.out_dim = cnn_feature_dim
            elif len(obs_space.shape) == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Add a batch dimension to single observations
                    if len(obs.shape) == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            # F.one_hot requires an int64 tensor, so cast before encoding
            self.preprocess = lambda x: F.one_hot(x.long(), obs_space.n).float()
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n
        else:
            raise NotImplementedError(f"Unsupported observation space: {obs_space}")

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # preprocess is always assigned in __init__ (unsupported spaces raise),
        # so it can be applied unconditionally
        obs = self.preprocess(obs)
        return self.feature_extractor(obs)


def mlp(
    layer_sizes: Sequence[int],
    activation: Type[nn.Module],
    output_activation: Type[nn.Module] = nn.Identity,
    init_layers_orthogonal: bool = False,
    final_layer_gain: float = np.sqrt(2),
) -> nn.Module:
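    """Builds an MLP from ``layer_sizes``, which includes the input and output
    widths, so N sizes yield N - 1 linear layers. ``final_layer_gain`` lets the
    last layer use a different orthogonal-init gain than the hidden layers
    (e.g. a small gain such as 0.01 is common for policy heads).
    """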
    layers = []
    for i in range(len(layer_sizes) - 2):
        layers.append(
            layer_init(
                nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
            )
        )
        layers.append(activation())
    layers.append(
        layer_init(
            nn.Linear(layer_sizes[-2], layer_sizes[-1]),
            init_layers_orthogonal,
            std=final_layer_gain,
        )
    )
    layers.append(output_activation())
    return nn.Sequential(*layers)


def layer_init(
    layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Module:
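    """Applies orthogonal initialization (``std`` is the gain passed to
    ``nn.init.orthogonal_``) with zero bias, or returns the layer untouched
    when ``init_layers_orthogonal`` is False.
    """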
    if not init_layers_orthogonal:
        return layer
    nn.init.orthogonal_(layer.weight, std)  # type: ignore
    nn.init.constant_(layer.bias, 0.0)  # type: ignore
    return layer
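

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the module): build
    # a feature extractor for a hypothetical 4-dimensional Box observation
    # space (CartPole-like) and attach a 2-output MLP head to it.
    obs_space = Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
    extractor = FeatureExtractor(obs_space, nn.Tanh, init_layers_orthogonal=True)
    head = mlp([extractor.out_dim, 64, 2], nn.Tanh, init_layers_orthogonal=True)

    obs = torch.as_tensor(obs_space.sample())
    features = extractor(obs)  # preprocess adds the batch dim: (1, 4)
    print(head(features).shape)  # torch.Size([1, 2])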