import functools

import torch.nn as nn

from .build import HEAD_REGISTRY


class MLP(nn.Module):
    """Multi-layer perceptron head: stacked Linear blocks, each followed by
    optional BatchNorm1d, an activation, and optional Dropout."""

    def __init__(
        self,
        in_features=2048,
        hidden_layers=[],
        activation="relu",
        bn=True,
        dropout=0.0,
    ):
        super().__init__()
        # Accept a single int as shorthand for one hidden layer.
        if isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]
        assert len(hidden_layers) > 0
        self.out_features = hidden_layers[-1]

        if activation == "relu":
            act_fn = functools.partial(nn.ReLU, inplace=True)
        elif activation == "leaky_relu":
            act_fn = functools.partial(nn.LeakyReLU, inplace=True)
        else:
            raise NotImplementedError

        # Build Linear -> [BatchNorm1d] -> activation -> [Dropout] per hidden dim.
        mlp = []
        for hidden_dim in hidden_layers:
            mlp += [nn.Linear(in_features, hidden_dim)]
            if bn:
                mlp += [nn.BatchNorm1d(hidden_dim)]
            mlp += [act_fn()]
            if dropout > 0:
                mlp += [nn.Dropout(dropout)]
            in_features = hidden_dim
        self.mlp = nn.Sequential(*mlp)

    def forward(self, x):
        return self.mlp(x)
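

# Minimal usage sketch (shapes here are illustrative, not from the source):
#
#   head = MLP(in_features=2048, hidden_layers=[512, 256], dropout=0.5)
#   y = head(torch.randn(8, 2048))  # y.shape == (8, 256); head.out_features == 256
#
# Note: with bn=True, BatchNorm1d requires batch size > 1 in training mode.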


@HEAD_REGISTRY.register()
def mlp(**kwargs):
    return MLP(**kwargs)
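

# Sketch of name-based construction, assuming HEAD_REGISTRY follows the common
# fvcore/detectron2-style Registry API with a .get() lookup (the actual
# implementation lives in .build and may differ):
#
#   build_fn = HEAD_REGISTRY.get("mlp")
#   head = build_fn(in_features=512, hidden_layers=[256])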