import torch
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool

from models_cifm.layers.egnn_layer_void_invariant import EGNNLayer
from models_cifm.mlp_and_gnn import MLPBiasFree


class VIEGNNModel(torch.nn.Module):
    """
    E-GNN model from "E(n) Equivariant Graph Neural Networks".
    """
    def __init__(
        self,
        num_layers: int = 5,
        num_mlp_layers_in_module: int = 2,
        emb_dim: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        activation: str = "relu",
        norm: str = "layer",
        aggr: str = "sum",
        pool: str = "sum",
        residual: bool = False,
    ):
""" |
|
Initializes an instance of the EGNNModel class with the provided parameters. |
|
|
|
Parameters: |
|
- num_layers (int): Number of layers in the model (default: 5) |
|
- emb_dim (int): Dimension of the node embeddings (default: 128) |
|
- in_dim (int): Input dimension of the model (default: 1) |
|
- out_dim (int): Output dimension of the model (default: 1) |
|
- activation (str): Activation function to be used (default: "relu") |
|
- norm (str): Normalization method to be used (default: "layer") |
|
- aggr (str): Aggregation method to be used (default: "sum") |
|
- pool (str): Global pooling method to be used (default: "sum") |
|
- residual (bool): Whether to use residual connections (default: True) |
|
- equivariant_pred (bool): Whether it is an equivariant prediction task (default: False) |
|
""" |
|
        super().__init__()
        self.residual = residual

        # Linear projection of input node features into the embedding space
        self.emb_in = torch.nn.Linear(in_dim, emb_dim, bias=False)

        # Stack of E(n)-equivariant message-passing layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))

        # Bias-free MLP prediction head mapping node embeddings to outputs
        self.pred = MLPBiasFree(
            in_dim=emb_dim, out_dim=out_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers_in_module
        )
    def forward(self, x, pos, edge_index):
        pos_init = pos  # initial coordinates (not used further here)
        h = self.emb_in(x)  # embed input node features

        for conv in self.convs:
            # Each EGNN layer returns updated node features and coordinates
            h_update, pos_update = conv(h, pos, edge_index)
            h = h + h_update if self.residual else h_update
            pos = pos_update

        h = self.pred(h)  # per-node predictions
        return h, pos
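

# Minimal usage sketch: shows how the model is instantiated and called. The
# shapes below are illustrative assumptions, and it presumes EGNNLayer accepts
# (h, pos, edge_index) and returns (h_update, pos_update) as used in forward().
if __name__ == "__main__":
    num_nodes, in_dim = 10, 16
    model = VIEGNNModel(num_layers=3, emb_dim=64, in_dim=in_dim, out_dim=1)

    x = torch.randn(num_nodes, in_dim)                  # node features
    pos = torch.randn(num_nodes, 3)                     # 3D node coordinates
    edge_index = torch.randint(0, num_nodes, (2, 40))   # random edges in COO format

    h, pos_out = model(x, pos, edge_index)
    print(h.shape, pos_out.shape)  # e.g. torch.Size([10, 1]) torch.Size([10, 3])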