import math
import numpy as np
import torch
import torch.nn as nn

from src import utils
from pdb import set_trace


class GCL(nn.Module):
    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 activation, edges_in_d=0, nodes_att_dim=0, attention=False, normalization=None):
        super(GCL, self).__init__()
        input_edge = input_nf * 2
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edges_in_d, hidden_nf),
            activation,
            nn.Linear(hidden_nf, hidden_nf),
            activation)

        if normalization is None:
            self.node_mlp = nn.Sequential(
                nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
                activation,
                nn.Linear(hidden_nf, output_nf)
            )
        elif normalization == 'batch_norm':
            self.node_mlp = nn.Sequential(
                nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
                nn.BatchNorm1d(hidden_nf),
                activation,
                nn.Linear(hidden_nf, output_nf),
                nn.BatchNorm1d(output_nf),
            )
        else:
            raise NotImplementedError

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target], dim=1)
        else:
            out = torch.cat([source, target, edge_attr], dim=1)
        mij = self.edge_mlp(out)

        if self.attention:
            att_val = self.att_mlp(mij)
            out = mij * att_val
        else:
            out = mij

        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = x + self.node_mlp(agg)
        return out, agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        row, col = edge_index
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij

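# Note added for clarity (not part of the original interface): GCL performs one round of
# message passing on invariant node features only,
#   m_ij  = edge_mlp([h_i, h_j, a_ij])        (optionally gated by att_mlp(m_ij) and edge_mask)
#   h_i'  = h_i + node_mlp([h_i, sum_j m_ij])
# where a_ij are optional edge attributes of dimension `edges_in_d` and the sum over j is
# computed with unsorted_segment_sum defined further below.
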
class EquivariantUpdate(nn.Module):
    def __init__(self, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=1, activation=nn.SiLU(), tanh=False, coords_range=10.0):
        super(EquivariantUpdate, self).__init__()
        self.tanh = tanh
        self.coords_range = coords_range
        input_edge = hidden_nf * 2 + edges_in_d
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        self.coord_mlp = nn.Sequential(
            nn.Linear(input_edge, hidden_nf),
            activation,
            nn.Linear(hidden_nf, hidden_nf),
            activation,
            layer)
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

    def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask):
        row, col = edge_index
        input_tensor = torch.cat([h[row], h[col], edge_attr], dim=1)
        if self.tanh:
            trans = coord_diff * torch.tanh(self.coord_mlp(input_tensor)) * self.coords_range
        else:
            trans = coord_diff * self.coord_mlp(input_tensor)
        if edge_mask is not None:
            trans = trans * edge_mask
        agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
                                   normalization_factor=self.normalization_factor,
                                   aggregation_method=self.aggregation_method)
        if linker_mask is not None:
            agg = agg * linker_mask

        coord = coord + agg
        return coord

    def forward(self, h, coord, edge_index, coord_diff, edge_attr=None, linker_mask=None,
                node_mask=None, edge_mask=None):
        coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask)
        if node_mask is not None:
            coord = coord * node_mask
        return coord


class EquivariantBlock(nn.Module):
    def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', activation=nn.SiLU(), n_layers=2, attention=True,
                 norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
                 normalization_factor=100, aggregation_method='sum'):
        super(EquivariantBlock, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range)
        self.norm_diff = norm_diff
        self.norm_constant = norm_constant
        self.sin_embedding = sin_embedding
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf,
                                              edges_in_d=edge_feat_nf, activation=activation, attention=attention,
                                              normalization_factor=self.normalization_factor,
                                              aggregation_method=self.aggregation_method))
        self.add_module("gcl_equiv", EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, activation=activation,
                                                       tanh=tanh, coords_range=self.coords_range_layer,
                                                       normalization_factor=self.normalization_factor,
                                                       aggregation_method=self.aggregation_method))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None, edge_attr=None):
        # Edit Emiel: Remove velocity as input
        distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        edge_attr = torch.cat([distances, edge_attr], dim=1)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edge_index, edge_attr=edge_attr,
                                               node_mask=node_mask, edge_mask=edge_mask)
        x = self._modules["gcl_equiv"](
            h, x,
            edge_index=edge_index,
            coord_diff=coord_diff,
            edge_attr=edge_attr,
            linker_mask=linker_mask,
            node_mask=node_mask,
            edge_mask=edge_mask,
        )

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h, x

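# Note added for clarity: EquivariantUpdate displaces coordinates along the normalised pairwise
# difference vectors, which keeps the update E(3)-equivariant. Up to the 'sum' normalization
# factor (and an optional tanh clamp scaled by coords_range):
#   x_i' = x_i + sum_j (x_i - x_j) / (||x_i - x_j|| + norm_constant) * coord_mlp([h_i, h_j, a_ij])
# When `linker_mask` is given, the aggregated displacement is zeroed for non-linker atoms, so
# fragment and pocket coordinates stay fixed.
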
class EGNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', activation=nn.SiLU(), n_layers=3,
                 attention=False, norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1,
                 inv_sublayers=2, sin_embedding=False, normalization_factor=100, aggregation_method='sum'):
        super(EGNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range / n_layers)
        self.norm_diff = norm_diff
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        if sin_embedding:
            self.sin_embedding = SinusoidsEmbeddingNew()
            edge_feat_nf = self.sin_embedding.dim * 2
        else:
            self.sin_embedding = None
            edge_feat_nf = 2

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
                                                               activation=activation, n_layers=inv_sublayers,
                                                               attention=attention, norm_diff=norm_diff, tanh=tanh,
                                                               coords_range=coords_range, norm_constant=norm_constant,
                                                               sin_embedding=self.sin_embedding,
                                                               normalization_factor=self.normalization_factor,
                                                               aggregation_method=self.aggregation_method))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        distances, _ = coord2diff(x, edge_index)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, x = self._modules["e_block_%d" % i](
                h, x, edge_index,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask,
                edge_attr=distances
            )

        # Important, the bias of the last linear might be non-zero
        h = self.embedding_out(h)
        if node_mask is not None:
            h = h * node_mask
        return h, x


class GNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, aggregation_method='sum', device='cpu',
                 activation=nn.SiLU(), n_layers=4, attention=False, normalization_factor=1,
                 out_node_nf=None, normalization=None):
        super(GNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        # Encoder
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                edges_in_d=in_edge_nf, activation=activation,
                attention=attention, normalization=normalization))
        self.to(self.device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr,
                                               node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h


class SinusoidsEmbeddingNew(nn.Module):
    def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
        super().__init__()
        self.n_frequencies = int(math.log(max_res / min_res, div_factor)) + 1
        self.frequencies = 2 * math.pi * div_factor ** torch.arange(self.n_frequencies) / max_res
        self.dim = len(self.frequencies) * 2

    def forward(self, x):
        x = torch.sqrt(x + 1e-8)
        emb = x * self.frequencies[None, :].to(x.device)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb.detach()


def coord2diff(x, edge_index, norm_constant=1):
    row, col = edge_index
    coord_diff = x[row] - x[col]
    radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
    norm = torch.sqrt(radial + 1e-8)
    coord_diff = coord_diff / (norm + norm_constant)
    return radial, coord_diff

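# Note added for clarity: coord2diff returns the *squared* pairwise distances (shape (E, 1)) and
# the normalised difference vectors consumed by EquivariantUpdate. SinusoidsEmbeddingNew expands
# the squared distances into sin/cos features; with the defaults (max_res=15, min_res=15/2000,
# div_factor=4) it uses 6 frequencies, i.e. dim = 12. Inside EGNN the block-level distances and
# the outer edge_attr are both embedded and concatenated, hence edge_feat_nf = dim * 2 = 24.
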
def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.
    Normalization: 'sum' or 'mean'.
    """
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    if aggregation_method == 'sum':
        result = result / normalization_factor

    if aggregation_method == 'mean':
        norm = data.new_zeros(result.shape)
        norm.scatter_add_(0, segment_ids, data.new_ones(data.shape))
        norm[norm == 0] = 1
        result = result / norm
    return result


class Dynamics(nn.Module):
    def __init__(
            self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
            n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
            sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
            normalization=None, centering=False, graph_type='FC',
    ):
        super().__init__()
        self.device = device
        self.n_dims = n_dims
        self.context_node_nf = context_node_nf
        self.condition_time = condition_time
        self.model = model
        self.centering = centering
        self.graph_type = graph_type

        in_node_nf = in_node_nf + context_node_nf + condition_time
        if self.model == 'egnn_dynamics':
            self.dynamics = EGNN(
                in_node_nf=in_node_nf,
                in_edge_nf=1,
                hidden_nf=hidden_nf,
                device=device,
                activation=activation,
                n_layers=n_layers,
                attention=attention,
                tanh=tanh,
                norm_constant=norm_constant,
                inv_sublayers=inv_sublayers,
                sin_embedding=sin_embedding,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
            )
        elif self.model == 'gnn_dynamics':
            self.dynamics = GNN(
                in_node_nf=in_node_nf + 3,
                in_edge_nf=0,
                hidden_nf=hidden_nf,
                out_node_nf=in_node_nf + 3,
                device=device,
                activation=activation,
                n_layers=n_layers,
                attention=attention,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                normalization=normalization,
            )
        else:
            raise NotImplementedError

        self.edge_cache = {}

    def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
        """
        - t: (B)
        - xh: (B, N, D), where D = 3 + nf
        - node_mask: (B, N, 1)
        - edge_mask: (B*N*N, 1)
        - context: (B, N, C)
        """
        assert self.graph_type == 'FC'

        bs, n_nodes = xh.shape[0], xh.shape[1]
        edges = self.get_edges(n_nodes, bs)  # 2 index tensors, each of length B*N*N
        node_mask = node_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        if linker_mask is not None:
            linker_mask = linker_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        # Reshaping node features & adding time feature
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask  # (B*N, D)
        x = xh[:, :self.n_dims].clone()  # (B*N, 3)
        h = xh[:, self.n_dims:].clone()  # (B*N, nf)
        if self.condition_time:
            if np.prod(t.size()) == 1:
                # t is the same for all elements in batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t.view(bs, 1).repeat(1, n_nodes)
                h_time = h_time.view(bs * n_nodes, 1)
            h = torch.cat([h, h_time], dim=1)  # (B*N, nf+1)
        if context is not None:
            context = context.view(bs * n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        # Forward EGNN
        # Output: h_final (B*N, nf), x_final (B*N, 3), vel (B*N, 3)
        if self.model == 'egnn_dynamics':
            h_final, x_final = self.dynamics(
                h,
                x,
                edges,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask
            )
            vel = (x_final - x) * node_mask  # This masking operation is redundant but just in case
        elif self.model == 'gnn_dynamics':
            xh = torch.cat([x, h], dim=1)
            output = self.dynamics(xh, edges, node_mask=node_mask)
            vel = output[:, 0:3] * node_mask
            h_final = output[:, 3:]
        else:
            raise NotImplementedError

        # Slice off context size
        if context is not None:
            h_final = h_final[:, :-self.context_node_nf]

        # Slice off last dimension which represented time.
        if self.condition_time:
            h_final = h_final[:, :-1]

        vel = vel.view(bs, n_nodes, -1)  # (B, N, 3)
        h_final = h_final.view(bs, n_nodes, -1)  # (B, N, D)
        node_mask = node_mask.view(bs, n_nodes, 1)  # (B, N, 1)

        if self.centering:
            vel = utils.remove_mean_with_mask(vel, node_mask)

        return torch.cat([vel, h_final], dim=2)

    def get_edges(self, n_nodes, batch_size):
        if n_nodes in self.edge_cache:
            edges_dic_b = self.edge_cache[n_nodes]
            if batch_size in edges_dic_b:
                return edges_dic_b[batch_size]
            else:
                # Build fully-connected edges (with self-loops) for the whole batch
                rows, cols = [], []
                for batch_idx in range(batch_size):
                    for i in range(n_nodes):
                        for j in range(n_nodes):
                            rows.append(i + batch_idx * n_nodes)
                            cols.append(j + batch_idx * n_nodes)
                edges = [torch.LongTensor(rows).to(self.device), torch.LongTensor(cols).to(self.device)]
                edges_dic_b[batch_size] = edges
                return edges
        else:
            self.edge_cache[n_nodes] = {}
            return self.get_edges(n_nodes, batch_size)

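# Hypothetical, vectorised equivalent of Dynamics.get_edges, added as an illustrative sketch only
# (it is not used by any class in this module). It produces the same ordered (row, col) pairs,
# including self-loops, for a batch of fully-connected graphs with per-graph node offsets.
def _fully_connected_edges_vectorized(n_nodes, batch_size, device='cpu'):
    idx = torch.arange(n_nodes, device=device)
    rows = idx.repeat_interleave(n_nodes)   # slow index i: 0,0,...,0,1,1,...
    cols = idx.repeat(n_nodes)              # fast index j: 0,1,...,N-1,0,1,...
    offsets = torch.arange(batch_size, device=device).repeat_interleave(n_nodes * n_nodes) * n_nodes
    return [rows.repeat(batch_size) + offsets, cols.repeat(batch_size) + offsets]
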
class DynamicsWithPockets(Dynamics):
    def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
        """
        - t: (B)
        - xh: (B, N, D), where D = 3 + nf
        - node_mask: (B, N, 1)
        - edge_mask: (B*N*N, 1)
        - context: (B, N, C)
        """
        bs, n_nodes = xh.shape[0], xh.shape[1]
        node_mask = node_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        if linker_mask is not None:
            linker_mask = linker_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        fragment_only_mask = context[..., -2].view(bs * n_nodes, 1)  # (B*N, 1)
        pocket_only_mask = context[..., -1].view(bs * n_nodes, 1)  # (B*N, 1)
        assert torch.all(fragment_only_mask.bool() | pocket_only_mask.bool() | linker_mask.bool() == node_mask.bool())

        # Reshaping node features & adding time feature
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask  # (B*N, D)
        x = xh[:, :self.n_dims].clone()  # (B*N, 3)
        h = xh[:, self.n_dims:].clone()  # (B*N, nf)

        assert self.graph_type in ['4A', 'FC-4A', 'FC-10A-4A']
        if self.graph_type == '4A' or self.graph_type is None:
            edges = self.get_dist_edges_4A(x, node_mask, edge_mask)
        else:
            edges = self.get_dist_edges(x, node_mask, edge_mask, linker_mask, fragment_only_mask, pocket_only_mask)

        if self.condition_time:
            if np.prod(t.size()) == 1:
                # t is the same for all elements in batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t is different over the batch dimension.
                h_time = t.view(bs, 1).repeat(1, n_nodes)
                h_time = h_time.view(bs * n_nodes, 1)
            h = torch.cat([h, h_time], dim=1)  # (B*N, nf+1)
        if context is not None:
            context = context.view(bs * n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        # Forward EGNN
        # Output: h_final (B*N, nf), x_final (B*N, 3), vel (B*N, 3)
        if self.model == 'egnn_dynamics':
            h_final, x_final = self.dynamics(
                h,
                x,
                edges,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=None
            )
            vel = (x_final - x) * node_mask  # This masking operation is redundant but just in case
        elif self.model == 'gnn_dynamics':
            xh = torch.cat([x, h], dim=1)
            output = self.dynamics(xh, edges, node_mask=node_mask)
            vel = output[:, 0:3] * node_mask
            h_final = output[:, 3:]
        else:
            raise NotImplementedError

        # Slice off context size
        if context is not None:
            h_final = h_final[:, :-self.context_node_nf]

        # Slice off last dimension which represented time.
        if self.condition_time:
            h_final = h_final[:, :-1]

        vel = vel.view(bs, n_nodes, -1)  # (B, N, 3)
        h_final = h_final.view(bs, n_nodes, -1)  # (B, N, D)
        node_mask = node_mask.view(bs, n_nodes, 1)  # (B, N, 1)

        if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final)):
            raise utils.FoundNaNException(vel, h_final)

        if self.centering:
            vel = utils.remove_mean_with_mask(vel, node_mask)

        return torch.cat([vel, h_final], dim=2)
    @staticmethod
    def get_dist_edges_4A(x, node_mask, batch_mask):
        node_mask = node_mask.squeeze().bool()
        batch_adj = (batch_mask[:, None] == batch_mask[None, :])
        nodes_adj = (node_mask[:, None] & node_mask[None, :])
        dists_adj = (torch.cdist(x, x) <= 4)
        rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
        adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
        edges = torch.stack(torch.where(adj))
        return edges

    def get_dist_edges(self, x, node_mask, batch_mask, linker_mask, fragment_only_mask, pocket_only_mask):
        node_mask = node_mask.squeeze().bool()
        linker_mask = linker_mask.squeeze().bool() & node_mask
        fragment_only_mask = fragment_only_mask.squeeze().bool() & node_mask
        pocket_only_mask = pocket_only_mask.squeeze().bool() & node_mask
        ligand_mask = linker_mask | fragment_only_mask

        # General constraints:
        batch_adj = (batch_mask[:, None] == batch_mask[None, :])
        nodes_adj = (node_mask[:, None] & node_mask[None, :])
        rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
        constraints = batch_adj & nodes_adj & rm_self_loops

        # Ligand atoms – fully-connected graph
        ligand_adj = (ligand_mask[:, None] & ligand_mask[None, :])
        ligand_interactions = ligand_adj & constraints

        # Pocket atoms – within 4A
        pocket_adj = (pocket_only_mask[:, None] & pocket_only_mask[None, :])
        pocket_dists_adj = (torch.cdist(x, x) <= 4)
        pocket_interactions = pocket_adj & pocket_dists_adj & constraints

        # Pocket-ligand atoms – within 10A (4A for the 'FC-4A' graph type)
        pocket_ligand_cutoff = 4 if self.graph_type == 'FC-4A' else 10
        pocket_ligand_adj = (ligand_mask[:, None] & pocket_only_mask[None, :])
        pocket_ligand_adj = pocket_ligand_adj | (pocket_only_mask[:, None] & ligand_mask[None, :])
        pocket_ligand_dists_adj = (torch.cdist(x, x) <= pocket_ligand_cutoff)
        pocket_ligand_interactions = pocket_ligand_adj & pocket_ligand_dists_adj & constraints

        adj = ligand_interactions | pocket_interactions | pocket_ligand_interactions
        edges = torch.stack(torch.where(adj))
        return edges
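
# Illustrative smoke test, added as a minimal sketch (assumed shapes and hyperparameters, not part
# of the original training code). It runs a single denoising step of the fully-connected EGNN
# dynamics on random inputs, assuming the module-level imports (e.g. `src.utils`) resolve.
if __name__ == '__main__':
    bs, n_nodes, nf = 2, 5, 8
    dynamics = Dynamics(n_dims=3, in_node_nf=nf, context_node_nf=0, hidden_nf=32, n_layers=2)
    xh = torch.randn(bs, n_nodes, 3 + nf)             # coordinates (3) + node features (nf)
    node_mask = torch.ones(bs, n_nodes, 1)
    edge_mask = torch.ones(bs * n_nodes * n_nodes, 1)
    t = torch.rand(bs, 1)                             # diffusion time, one value per sample
    out = dynamics(t, xh, node_mask, linker_mask=None, edge_mask=edge_mask, context=None)
    print(out.shape)                                  # expected: (bs, n_nodes, 3 + nf)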