# CIFM / models_cifm/egnn_void_invariant.py
# Author: Yuning You — last commit 552cf9a ("update"); file size 2.64 kB.
import torch
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool
from models_cifm.layers.egnn_layer_void_invariant import EGNNLayer
from models_cifm.mlp_and_gnn import MLPBiasFree
class VIEGNNModel(torch.nn.Module):
    """
    E-GNN model from "E(n) Equivariant Graph Neural Networks", built on the
    void-invariant ``EGNNLayer`` from
    ``models_cifm.layers.egnn_layer_void_invariant``.

    Pipeline: bias-free linear projection of node features -> stack of
    ``EGNNLayer`` message-passing layers -> bias-free MLP head. No global
    pooling is applied, so ``forward`` returns node-level outputs.
    """

    def __init__(
        self,
        num_layers: int = 5,
        num_mlp_layers_in_module: int = 2,
        emb_dim: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        activation: str = "relu",
        norm: str = "layer",
        aggr: str = "sum",
        pool: str = "sum",
        residual: bool = False
    ):
        """
        Initialize a VIEGNNModel.

        Args:
            num_layers: Number of stacked ``EGNNLayer`` modules (default: 5).
            num_mlp_layers_in_module: Passed positionally to every
                ``EGNNLayer`` and used as ``num_layer`` of the output
                ``MLPBiasFree`` head (default: 2).
            emb_dim: Node-embedding / hidden dimension (default: 128).
            in_dim: Last dimension of the input node features ``x``
                (default: 1).
            out_dim: Output dimension of the prediction head (default: 1).
            activation: Accepted for interface compatibility; currently
                unused by this class (default: "relu").
            norm: Accepted for interface compatibility; currently unused by
                this class (default: "layer").
            aggr: Aggregation scheme passed to every ``EGNNLayer``
                (default: "sum").
            pool: Accepted for interface compatibility; unused because no
                global pooling is performed (default: "sum").
            residual: If True, each layer's feature update is added to the
                previous features; otherwise it replaces them
                (default: False).
        """
        super().__init__()
        self.residual = residual
        # Bias-free linear projection of input node features to emb_dim.
        self.emb_in = torch.nn.Linear(in_dim, emb_dim, bias=False)
        # Stack of message-passing layers.
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))
        # Bias-free prediction head applied per node.
        self.pred = MLPBiasFree(
            in_dim=emb_dim,
            out_dim=out_dim,
            hidden_dim=emb_dim,
            num_layer=num_mlp_layers_in_module,
        )

    # Batch arguments are unrolled and the global pooling step is omitted,
    # so this forward pass produces node-level outputs.
    def forward(self, x: torch.Tensor, pos: torch.Tensor, edge_index: torch.Tensor):
        """
        Run the input projection, the message-passing stack and the head.

        Args:
            x: Node features whose last dimension equals ``in_dim``.
            pos: Node coordinates fed to, and replaced by the output of,
                each layer (presumably ``(n, 3)`` — TODO confirm).
            edge_index: Connectivity forwarded unchanged to every layer
                (presumably PyG-style ``(2, num_edges)`` — TODO confirm).

        Returns:
            Tuple ``(h, pos)``: ``self.pred`` applied to the final node
            embeddings, and the coordinates returned by the last layer
            (the input ``pos`` itself when there are no layers).
        """
        h = self.emb_in(x)  # (..., in_dim) -> (..., emb_dim)
        for conv in self.convs:
            h_update, pos = conv(h, pos, edge_index)
            # Residual connection applies to features only; coordinates are
            # taken as returned by the layer.
            h = h + h_update if self.residual else h_update
        return self.pred(h), pos