| import sys | |
| sys.path.append("./BranchSBM") | |
| import torch | |
| import torch.nn as nn | |
| from typing import List, Optional | |
| from networks.mlp_base import SimpleDenseNet | |
class GeoPathMLP(nn.Module):
    """Dense network mapping an endpoint pair ``(x0, x1)`` (optionally plus a
    time value ``t``) to an ``input_dim``-dimensional output.

    ``x0`` and ``x1`` are concatenated along the feature axis. When
    ``time_geopath`` is enabled, one time column is appended. The result is fed
    to a ``SimpleDenseNet`` whose output width equals ``input_dim``.

    NOTE(review): the name suggests this parameterises a (geodesic) path
    correction between ``x0`` and ``x1`` — confirm against the callers.

    Args:
        input_dim: Feature dimension of ``x0`` and ``x1``; also the output
            dimension of the network.
        activation: Activation identifier forwarded to ``SimpleDenseNet``.
        batch_norm: Forwarded to ``SimpleDenseNet``.
        hidden_dims: Hidden-layer widths forwarded to ``SimpleDenseNet``
            (``None`` defers to that class's own handling).
        time_geopath: If True, ``forward`` appends ``t`` as an extra input
            feature and the network's input width grows by one.
    """

    def __init__(
        self,
        input_dim: int,
        activation: str,
        batch_norm: bool = True,
        hidden_dims: Optional[List[int]] = None,
        time_geopath: bool = False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.time_geopath = time_geopath
        self.mainnet = SimpleDenseNet(
            # x0 and x1 side by side, plus one extra column for t if enabled.
            input_size=2 * input_dim + (1 if time_geopath else 0),
            target_size=input_dim,
            activation=activation,
            batch_norm=batch_norm,
            hidden_dims=hidden_dims,
        )

    def forward(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        t: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the network on the concatenated endpoints (and optionally ``t``).

        Args:
            x0: First endpoint batch, concatenated with ``x1`` along ``dim=1``.
                Assumed to be ``(batch, input_dim)`` — TODO confirm.
            x1: Second endpoint batch with the same leading dimension as ``x0``.
            t: Time value(s), only used when ``time_geopath`` is True. Accepts
                a ``(batch, 1)`` tensor, a ``(batch,)`` tensor, or a single
                value (0-d tensor, 1-element tensor, or Python scalar). A single
                value is broadcast over the batch. Ignored otherwise.

        Returns:
            The output of ``self.mainnet`` on the concatenated input.

        Raises:
            ValueError: If ``time_geopath`` is True and ``t`` is None.
        """
        x = torch.cat([x0, x1], dim=1)
        if self.time_geopath:
            x = torch.cat([x, self._as_time_column(t, x)], dim=1)
        return self.mainnet(x)

    @staticmethod
    def _as_time_column(t, x: torch.Tensor) -> torch.Tensor:
        """Coerce ``t`` into a column matching ``x``'s batch, dtype and device."""
        if t is None:
            raise ValueError("t is required when time_geopath=True")
        # Match x's dtype/device so cat does not promote the whole input to a
        # dtype the dense layers cannot consume. as_tensor keeps the autograd
        # graph when t is already a tensor.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        # A scalar or a flat per-sample vector becomes a single column.
        if t.dim() <= 1:
            t = t.reshape(-1, 1)
        # Broadcast a single time value across the whole batch.
        if t.shape[0] == 1 and x.shape[0] != 1:
            t = t.expand(x.shape[0], -1)
        return t