import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    PNAConv,
    global_mean_pool,
    global_max_pool,
    global_add_pool,
)
from torch_geometric.utils import degree


class PolyatomicNet(nn.Module):
    """PNA-based GNN with a virtual node and fused graph-level descriptors.

    Expects PyG Data objects carrying node features ``x``, edge features
    ``edge_attr``, and a per-graph descriptor vector ``graph_feats``.
    """

    def __init__(
        self,
        node_feat_dim,
        edge_feat_dim,
        graph_feat_dim,
        deg,
        hidden_dim=128,
        num_layers=5,
        dropout=0.1,
    ):
        super().__init__()
        self.graph_feat_dim = graph_feat_dim
        self.node_emb = nn.Linear(node_feat_dim, hidden_dim)
        self.deg = deg

        # Single learned virtual-node embedding, shared by every graph and
        # refined after each message-passing layer.
        self.virtualnode_emb = nn.Embedding(1, hidden_dim)
        self.vn_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Projects the per-graph descriptor vector into the hidden space.
        self.graph_proj = nn.Sequential(
            nn.Linear(graph_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Node-degree embedding; degrees are clamped to [0, 19] in forward().
        self.deg_emb = nn.Embedding(20, hidden_dim)

        aggregators = ["mean", "min", "max", "std"]
        scalers = ["identity", "amplification", "attenuation"]

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        for _ in range(num_layers):
            conv = PNAConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                aggregators=aggregators,
                scalers=scalers,
                edge_dim=edge_feat_dim,
                towers=4,  # hidden_dim must be divisible by the tower count
                pre_layers=1,
                post_layers=1,
                divide_input=True,
                deg=deg,
            )
            self.convs.append(conv)
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        self.dropout = nn.Dropout(dropout)

        # Readout over [mean pool | max pool | projected graph descriptors].
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # `batch` is None for a single (unbatched) Data object; treat that
        # case as one graph so the pooling calls below still work.
        batch = data.batch
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Degree-aware input embedding (clamped to the embedding table size).
        deg = degree(edge_index[0], x.size(0), dtype=torch.long).clamp(max=19)
        h = self.node_emb(x) + self.deg_emb(deg)

        # One virtual-node state per graph in the batch.
        vn = self.virtualnode_emb(
            torch.zeros(batch.max().item() + 1, dtype=torch.long, device=x.device)
        )

        for conv, bn in zip(self.convs, self.bns):
            h = h + vn[batch]  # broadcast each graph's virtual node to its nodes
            h = conv(h, edge_index, edge_attr)
            h = bn(h)
            h = F.relu(h)
            h = self.dropout(h)
            vn = vn + self.vn_mlp(global_mean_pool(h, batch))  # update virtual node

        mean_pool = global_mean_pool(h, batch)
        max_pool = global_max_pool(h, batch)

        # Project the per-graph descriptor vectors, one row per graph.
        if isinstance(data, torch_geometric.data.Batch) and hasattr(
            data, "graph_feats"
        ):
            g_proj = torch.stack(
                [
                    self.graph_proj(self._prep_graph_feats(g.graph_feats.to(x.device)))
                    for g in data.to_data_list()
                ],
                dim=0,
            )
        else:
            g_feat = self._prep_graph_feats(data.graph_feats.to(x.device))
            g_proj = self.graph_proj(g_feat).unsqueeze(0)

        final_input = torch.cat([mean_pool, max_pool, g_proj], dim=1)
        return self.readout(final_input).view(-1)

    def _prep_graph_feats(self, g_feat):
        # Zero-pad or truncate to exactly graph_feat_dim entries, then replace
        # NaN/inf values so the projection stays finite.
        if g_feat.size(0) < self.graph_feat_dim:
            padded = torch.zeros(self.graph_feat_dim, device=g_feat.device)
            padded[: g_feat.size(0)] = g_feat
            g_feat = padded
        elif g_feat.size(0) > self.graph_feat_dim:
            g_feat = g_feat[: self.graph_feat_dim]
        return torch.nan_to_num(g_feat, nan=0.0, posinf=1e5, neginf=-1e5)
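

# A minimal sketch (an assumption, not part of the original code) of the
# standard PyG recipe for building the in-degree histogram that PNAConv's
# `deg` argument expects. `dataset` is any iterable of Data objects; the
# helper name `compute_degree_histogram` is hypothetical.
def compute_degree_histogram(dataset):
    max_degree = -1
    for data in dataset:
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        max_degree = max(max_degree, int(d.max()))

    deg_hist = torch.zeros(max_degree + 1, dtype=torch.long)
    for data in dataset:
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        deg_hist += torch.bincount(d, minlength=deg_hist.numel())
    return deg_hist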
					
						
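# Hypothetical smoke test (dimensions are illustrative, not taken from any
# original training setup): batch two random graphs, build the degree
# histogram with the sketch above, and confirm the model emits one scalar
# prediction per graph.
if __name__ == "__main__":
    from torch_geometric.data import Batch, Data

    node_feat_dim, edge_feat_dim, graph_feat_dim = 16, 4, 8

    def random_graph(num_nodes=6, num_edges=12):
        # Random features only; duplicate edges or self-loops are fine for
        # a pure shape check like this one.
        return Data(
            x=torch.randn(num_nodes, node_feat_dim),
            edge_index=torch.randint(0, num_nodes, (2, num_edges)),
            edge_attr=torch.randn(num_edges, edge_feat_dim),
            graph_feats=torch.randn(graph_feat_dim),
        )

    graphs = [random_graph(), random_graph()]
    model = PolyatomicNet(
        node_feat_dim,
        edge_feat_dim,
        graph_feat_dim,
        deg=compute_degree_histogram(graphs),
    )
    model.eval()  # freeze BatchNorm statistics and disable dropout
    with torch.no_grad():
        out = model(Batch.from_data_list(graphs))
    print(out.shape)  # expected: torch.Size([2])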