Spaces:
Runtime error
Runtime error
from typing import Type | |
from torch import nn | |
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLPBlock(nn.Module):
    """Stack of ``num_layers`` (Linear -> activation) layers plus a Linear head.

    Hidden layer 0 maps ``input_dim -> hidden_dim``. Every later hidden layer
    maps ``hidden_dim -> hidden_dim``. Each hidden layer is an
    ``nn.Sequential(nn.Linear(...), act())`` with its own ``act()`` instance.
    A final ``fc`` projects ``hidden_dim -> output_dim`` with no activation.

    Note: if ``num_layers <= 0``, no hidden layers are built and ``fc``
    consumes the input directly. The input feature size must then equal
    ``hidden_dim``.

    Args:
        input_dim: Feature size of the input (the last dim, as consumed by
            ``nn.Linear``).
        hidden_dim: Width of every hidden layer.
        output_dim: Feature size produced by the final projection.
        num_layers: Number of (Linear, activation) hidden layers.
        act: Activation module class (e.g. ``nn.ReLU``). It is called with
            no arguments, once per hidden layer.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act: Type[nn.Module],
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # Input sizes: first layer reads input_dim, the rest read hidden_dim.
        # Output sizes: every hidden layer emits hidden_dim.
        in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        out_dims = [hidden_dim] * num_layers
        # Keep the attribute names and the Sequential nesting unchanged so that
        # state_dict keys ("layers.{i}.0.weight", "fc.weight", ...) match the
        # original layout.
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(n_in, n_out), act())
            for n_in, n_out in zip(in_dims, out_dims)
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply each hidden layer in order, then the output projection ``fc``."""
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)