import functools

import torch.nn as nn

from .build import HEAD_REGISTRY


class MLP(nn.Module):
    """Multi-layer perceptron head built from stacked fully-connected blocks.

    Each entry of ``hidden_layers`` contributes one block, in this order:
    ``nn.Linear -> [nn.BatchNorm1d] -> activation -> [nn.Dropout]``.
    The module's output width is the last entry of ``hidden_layers``.

    Args:
        in_features: Input width of the first ``nn.Linear`` layer.
        hidden_layers: Output width of each block. Either a single int (one
            block) or a sequence of ints. Must contain at least one entry.
        activation: Name of the activation, ``"relu"`` or ``"leaky_relu"``.
        bn: If True, insert an ``nn.BatchNorm1d`` after every ``nn.Linear``.
        dropout: Dropout probability applied after each activation. No
            ``nn.Dropout`` layer is added when this is ``<= 0``.

    Attributes:
        out_features: Output width of the final block (``hidden_layers[-1]``).
        mlp: ``nn.Sequential`` containing every layer, in order.

    Raises:
        ValueError: If ``hidden_layers`` is None or empty.
        NotImplementedError: If ``activation`` is not a supported name.
    """

    def __init__(
        self,
        in_features=2048,
        hidden_layers=None,
        activation="relu",
        bn=True,
        dropout=0.0,
    ):
        super().__init__()

        # Normalize to a fresh list: a bare int means a single hidden layer,
        # and copying a passed sequence avoids aliasing the caller's object.
        if hidden_layers is None:
            hidden_layers = []
        elif isinstance(hidden_layers, int):
            hidden_layers = [hidden_layers]
        else:
            hidden_layers = list(hidden_layers)

        # Explicit check rather than ``assert`` so validation survives
        # ``python -O`` instead of failing later with an IndexError.
        if not hidden_layers:
            raise ValueError(
                "MLP requires at least one hidden layer size in `hidden_layers`"
            )
        self.out_features = hidden_layers[-1]

        if activation == "relu":
            act_fn = functools.partial(nn.ReLU, inplace=True)
        elif activation == "leaky_relu":
            act_fn = functools.partial(nn.LeakyReLU, inplace=True)
        else:
            raise NotImplementedError(f"Unsupported activation: {activation!r}")

        layers = []
        prev_dim = in_features
        for hidden_dim in hidden_layers:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            if bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            # A new activation module per block. inplace=True makes it
            # overwrite its input tensor rather than allocate a new one.
            layers.append(act_fn())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim

        # Attribute name kept as ``mlp`` so state_dict keys stay stable.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result.

        NOTE(review): with ``bn=True``, nn.BatchNorm1d normalizes dim 1, so
        ``x`` is presumably shaped (batch, in_features) — confirm at callers.
        """
        return self.mlp(x)


@HEAD_REGISTRY.register()
def mlp(**kwargs):
    """Build an :class:`MLP` head, forwarding all keyword arguments to it."""
    return MLP(**kwargs)