Spaces:
Sleeping
Sleeping
from collections.abc import Sequence

from torch import nn, cat
class MLP1(nn.Sequential):
    """Fully connected network built from eagerly sized ``nn.Linear`` layers.

    Each hidden layer contributes ``Linear -> Dropout -> activation`` to the
    sequence; the output layer is a bare ``Linear`` with no dropout or
    activation after it.

    Args:
        input_channels: Number of input features of the first layer.
        hidden_channels: Widths of the hidden layers, in order. Any sequence
            of ints (list, tuple, ...) is accepted. An empty sequence yields a
            single ``Linear(input_channels, out_channels)``.
        out_channels: Number of output features of the last layer.
        activation: Module class, instantiated with no arguments, placed
            after each hidden layer's dropout.
        dropout: Dropout probability applied after each hidden ``Linear``.
    """

    def __init__(self,
                 input_channels: int,
                 hidden_channels: Sequence[int],
                 out_channels: int,
                 activation: type[nn.Module] = nn.ReLU,
                 dropout: float = 0.0):
        # Unpacking (rather than list concatenation) accepts any sequence,
        # including tuples.
        dims = [input_channels, *hidden_channels, out_channels]
        layers: list[nn.Module] = []
        # Hidden stages: every consecutive (in, out) pair except the last.
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers += [nn.Linear(in_dim, out_dim), nn.Dropout(dropout), activation()]
        # Output projection: no dropout or activation after it.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        super().__init__(*layers)
class MLP2(nn.Sequential):
    """Fully connected network whose ReLU and dropout are applied in ``forward``.

    Layers are held in ``self.layers`` (an ``nn.ModuleList`` of ``nn.Linear``).
    Every layer except the last is followed by the single shared
    ``self.dropout`` module and then ``relu``; the last layer's output is
    returned unchanged.

    NOTE: ``super().__init__()`` is called with no modules, so the registered
    children are ``dropout`` and ``layers``. Iterating or indexing an instance
    as a Sequential therefore yields those two modules, not the individual
    ``Linear`` layers.

    Args:
        input_channels: Number of input features of the first layer.
        hidden_channels: Widths of the hidden layers, in order. Any sequence
            of ints (list, tuple, ...) is accepted. An empty sequence yields a
            single ``Linear(input_channels, out_channels)``.
        out_channels: Number of output features of the last layer.
        dropout: Dropout probability applied after each hidden ``Linear``.
    """

    def __init__(self,
                 input_channels: int,
                 hidden_channels: Sequence[int],
                 out_channels: int,
                 dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Unpacking (rather than list concatenation) accepts any sequence,
        # including tuples.
        dims = [input_channels, *hidden_channels, out_channels]
        self.layers = nn.ModuleList(
            nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        """Apply the hidden layers (Linear -> dropout -> ReLU), then the output Linear."""
        # There is always at least one layer (the output projection).
        *hidden, last = self.layers
        for layer in hidden:
            x = nn.functional.relu(self.dropout(layer(x)))
        return last(x)
class LazyMLP(nn.Sequential):
    """MLP made of ``nn.LazyLinear`` layers, so only output widths are specified.

    For every width in ``hidden_channels`` the sequence gains
    ``LazyLinear -> Dropout -> activation``. A final ``LazyLinear`` then maps
    to ``out_channels``. Input sizes are inferred by ``LazyLinear`` when data
    first flows through.

    Args:
        out_channels: Output width of the final layer.
        hidden_channels: Output widths of the hidden layers, in order.
        activation: Module class, instantiated with no arguments, placed
            after each hidden layer's dropout.
        dropout: Dropout probability used after each hidden layer.
    """

    def __init__(
        self,
        out_channels: int,
        hidden_channels: list[int],
        activation: type[nn.Module] = nn.ReLU,
        dropout: float = 0.0
    ):
        def hidden_stage(width):
            # Modules are created in the same order they are registered.
            return (nn.LazyLinear(out_features=width), nn.Dropout(dropout), activation())

        stack = [module for width in hidden_channels for module in hidden_stage(width)]
        stack.append(nn.LazyLinear(out_features=out_channels))
        super().__init__(*stack)
class ConcatMLP(LazyMLP):
    """LazyMLP that takes several tensors and joins them before the layer stack.

    ``forward(*inputs)`` concatenates all positional tensors along dimension 1
    and runs the result through the inherited sequence of layers.
    """

    def forward(self, *inputs):
        joined = cat(inputs, dim=1)
        return super().forward(joined)
| # class ConcatMLP(MLP1): | |
| # def forward(self, *inputs): | |
| # x = cat([*inputs], 1) | |
| # for module in self: | |
| # x = module(x) | |
| # return x | |