jiayueru
Add app
7352753
raw
history blame
839 Bytes
from typing import Type
from torch import nn
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLPBlock(nn.Module):
    """Feed-forward stack of (Linear -> activation) stages followed by a linear head.

    The first stage maps ``input_dim`` features to ``hidden_dim``. Every later
    stage maps ``hidden_dim`` to ``hidden_dim``. A final ``nn.Linear`` named
    ``fc`` then projects ``hidden_dim`` to ``output_dim`` with no activation.
    The block therefore holds ``num_layers + 1`` linear layers in total. The
    activation follows every hidden stage, including the last one.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every hidden stage and the input width of ``fc``.
        output_dim: Size of the last dimension of the output.
        num_layers: Number of (Linear -> activation) stages. Values <= 0
            build no stages. In that case the input goes straight into ``fc``
            and must already have ``hidden_dim`` features.
        act: Activation *class* (not an instance). A fresh instance is
            created for each stage.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act: Type[nn.Module],
    ) -> None:
        super().__init__()
        self.num_layers = num_layers

        # Build the stages one at a time. Only the first stage reads
        # `input_dim` features; each later stage reads the previous
        # stage's `hidden_dim` output.
        stages = []
        fan_in = input_dim
        for _ in range(num_layers):
            stages.append(nn.Sequential(nn.Linear(fan_in, hidden_dim), act()))
            fan_in = hidden_dim

        # Keep the ModuleList-of-Sequential layout and register `layers`
        # before `fc`. This layout fixes the parameter names
        # (layers.<i>.0.weight, fc.weight, ...) that saved weights rely on.
        self.layers = nn.ModuleList(stages)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply each stage to ``x`` in order, then project with ``fc``."""
        hidden = x
        for stage in self.layers:
            hidden = stage(hidden)
        return self.fc(hidden)