from typing import Type

import torch
from torch import nn

# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa


class MLPBlock(nn.Module):
    """Multi-layer perceptron: ``num_layers`` Linear+activation stages, then a Linear head.

    The first hidden stage maps ``input_dim -> hidden_dim``; each further
    hidden stage maps ``hidden_dim -> hidden_dim``. A final projection
    ``self.fc`` (with no activation) maps to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden stage.
        output_dim: Size of the last dimension of the output tensor.
        num_layers: Number of hidden Linear+activation stages. ``0`` yields a
            single linear projection ``input_dim -> output_dim``.
        act: Activation module *class* (e.g. ``nn.ReLU``). It is instantiated
            once per hidden stage, so every stage owns a separate instance.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act: Type[nn.Module],
    ) -> None:
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {num_layers}")
        self.num_layers = num_layers
        # Stage i: Linear(in, hidden_dim) followed by a fresh activation instance.
        # Only the first stage consumes input_dim; later stages consume hidden_dim.
        self.layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(input_dim if i == 0 else hidden_dim, hidden_dim),
                act(),
            )
            for i in range(num_layers)
        )
        # With no hidden stages, the head sees the raw input width, not hidden_dim.
        fc_in = hidden_dim if num_layers > 0 else input_dim
        self.fc = nn.Linear(fc_in, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each hidden stage in order, then the final linear projection."""
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)