import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class TransformerFeatureExtractor(BaseFeaturesExtractor):
    """
    A custom feature extractor that uses a Transformer Encoder.

    It takes a flattened observation of shape (window_size * n_features_per_step,)
    and processes it as a sequence of ``window_size`` timesteps, each with
    ``n_features_per_step`` raw features.

    Args:
        observation_space: Flat (1-D) Box observation space; its single
            dimension must equal ``window_size * n_features_per_step``.
        features_dim: Size of the extracted feature vector returned by
            :meth:`forward`.
        n_features_per_step: Number of raw features per timestep.
        window_size: Number of timesteps in the observation window.
        d_model: Transformer embedding dimension (must be divisible by
            ``n_head`` — enforced by ``nn.TransformerEncoderLayer``).
        n_head: Number of attention heads.
        n_layers: Number of stacked encoder layers.
        dropout: Dropout rate used inside the encoder layers.

    Raises:
        ValueError: If the observation space is not 1-D or its flat size
            does not match ``window_size * n_features_per_step``.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        features_dim: int = 256,
        n_features_per_step: int = 8,
        window_size: int = 30,
        d_model: int = 64,
        n_head: int = 4,
        n_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__(observation_space, features_dim)

        self.window_size = window_size
        self.n_features_per_step = n_features_per_step

        # Fail fast if the observation cannot be reshaped into
        # (window_size, n_features_per_step). Checking the rank as well as
        # the size prevents a multi-dimensional space from slipping through
        # and being silently mangled by the reshape in forward().
        if len(observation_space.shape) != 1:
            raise ValueError(
                f"Expected a flat (1-D) observation space, got shape "
                f"{observation_space.shape}."
            )
        expected_flat_dim = window_size * n_features_per_step
        if observation_space.shape[0] != expected_flat_dim:
            raise ValueError(
                f"Observation space flat dimension {observation_space.shape[0]} "
                f"does not match expected {expected_flat_dim} "
                f"(window_size={window_size}, n_features_per_step={n_features_per_step})."
            )

        # Project each timestep's raw features into the model dimension.
        self.input_projection = nn.Linear(n_features_per_step, d_model)

        # Learned positional encoding; leading 1 broadcasts over the batch.
        self.positional_encoding = nn.Parameter(torch.randn(1, window_size, d_model))

        # batch_first=True so tensors stay (batch, seq, d_model) throughout.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Flatten the encoded sequence and map it to the requested feature size.
        self.flatten = nn.Flatten()
        self.linear_out = nn.Linear(window_size * d_model, features_dim)
        self.relu = nn.ReLU()

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a batch of flattened observations.

        Args:
            observations: Tensor of shape (batch, window_size * n_features_per_step).

        Returns:
            Tensor of shape (batch, features_dim).
        """
        # (batch, window * features) -> (batch, window, features)
        x = observations.reshape(-1, self.window_size, self.n_features_per_step)

        # Per-timestep projection into the model dimension.
        x = self.input_projection(x)

        # Add the learned positional encoding (broadcast over batch).
        x = x + self.positional_encoding

        x = self.transformer_encoder(x)

        # Flatten the full encoded sequence and project to features_dim.
        x = self.flatten(x)
        x = self.relu(self.linear_out(x))

        return x