import sys
sys.path.append("..")
import torch
import torch.nn as nn
from torch.nn import Module, Linear, Embedding
from torch.nn import functional as F
from torch_scatter import scatter_add, scatter_mean
from torch_geometric.data import Data, Batch
from copy import deepcopy

from .encoders import get_encoder, GNN_graphpred, MLP
from .common import *
from utils import dihedral_utils, chemutils


class FLAG(Module):
    """Protein/ligand context model with several prediction heads.

    A shared encoder embeds the composed protein + ligand context; on top of
    it sit heads for focal-atom prediction, next-motif prediction, attachment
    scoring, torsion (alpha) prediction, pairwise distance prediction and,
    when ``config.refinement`` is set, structure refinement.
    """

    def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim, vocab):
        """Build all sub-modules and loss functions.

        :param config: model configuration; this class reads ``hidden_channels``,
            ``encoder``, ``encoder.edge_channels``, ``random_alpha`` and ``refinement``
        :param protein_atom_feature_dim: width of the raw protein atom features
        :param ligand_atom_feature_dim: width of the raw ligand atom features
        :param vocab: motif vocabulary; only ``vocab.size()`` is used here
        """
        super().__init__()
        self.config = config
        self.vocab = vocab
        hidden = config.hidden_channels

        # NOTE: sub-modules are created in the original order so that
        # seeded parameter initialisation stays reproducible.
        self.protein_atom_emb = Linear(protein_atom_feature_dim, hidden)
        self.ligand_atom_emb = Linear(ligand_atom_feature_dim, hidden)
        self.embedding = nn.Embedding(vocab.size() + 1, hidden)
        self.W = nn.Linear(2 * hidden, hidden)
        self.W_o = nn.Linear(hidden, self.vocab.size())
        self.encoder = get_encoder(config.encoder)
        self.comb_head = GNN_graphpred(num_layer=3, emb_dim=hidden, num_tasks=1, JK='last',
                                       drop_ratio=0.5, graph_pooling='mean', gnn_type='gin')

        # With random_alpha a noise vector of width ``hidden`` is appended to
        # the (hx, hy, h_mol) features, hence 4 * hidden instead of 3 * hidden.
        alpha_in_dim = hidden * 4 if config.random_alpha else hidden * 3
        self.alpha_mlp = MLP(in_dim=alpha_in_dim, out_dim=1, num_layers=2)

        self.focal_mlp_ligand = MLP(in_dim=hidden, out_dim=1, num_layers=1)
        self.focal_mlp_protein = MLP(in_dim=hidden, out_dim=1, num_layers=1)
        # The distance head works on raw (un-embedded) atom features.
        self.dist_mlp = MLP(in_dim=protein_atom_feature_dim + ligand_atom_feature_dim,
                            out_dim=1, num_layers=2)

        if config.refinement:
            refine_in_dim = hidden * 2 + config.encoder.edge_channels
            self.refine_protein = MLP(in_dim=refine_in_dim, out_dim=1, num_layers=2)
            self.refine_ligand = MLP(in_dim=refine_in_dim, out_dim=1, num_layers=2)

        self.smooth_cross_entropy = SmoothCrossEntropyLoss(reduction='mean', smoothing=0.1)
        self.pred_loss = nn.CrossEntropyLoss()
        self.comb_loss = nn.BCEWithLogitsLoss()
        self.three_hop_loss = torch.nn.MSELoss()
        self.focal_loss = nn.BCEWithLogitsLoss()
        self.dist_loss = torch.nn.MSELoss(reduction='mean')

    # ------------------------------------------------------------------
    # Private helpers shared by the public entry points
    # ------------------------------------------------------------------
    def _encode_context(self, h_protein, h_ligand, protein_pos, ligand_pos, batch_protein, batch_ligand):
        """Compose embedded protein/ligand atoms into one context and encode it.

        :return: ``(h_ctx, protein_mask)`` where ``h_ctx`` is the encoder output
            (N_p+N_l, H) and ``protein_mask`` is the mask returned by
            ``compose_context_stable``
        """
        h_ctx, pos_ctx, batch_ctx, protein_mask = compose_context_stable(
            h_protein=h_protein, h_ligand=h_ligand,
            pos_protein=protein_pos, pos_ligand=ligand_pos,
            batch_protein=batch_protein, batch_ligand=batch_ligand)
        h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)  # (N_p+N_l, H)
        return h_ctx, protein_mask

    def _motif_logits(self, h_ctx_focal, current_wid, current_atoms_batch):
        """Return un-normalised next-motif scores for each current motif."""
        node_hiddens = scatter_add(h_ctx_focal, dim=0, index=current_atoms_batch)
        motif_hiddens = self.embedding(current_wid)
        pred_vecs = F.relu(self.W(torch.cat([node_hiddens, motif_hiddens], dim=1)))
        return self.W_o(pred_vecs)

    def _predict_alpha(self, hx, hy, h_mol):
        """Run the alpha head on (hx, hy, h_mol), adding noise if configured."""
        features = [hx, hy, h_mol]
        if self.config.random_alpha:
            rand_dist = torch.distributions.normal.Normal(loc=0, scale=1)
            # Sampled on CPU (as before) and moved next to the features.
            features.append(rand_dist.sample(hx.shape).to(hx.device))
        return self.alpha_mlp(torch.cat(features, dim=-1))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def forward(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein, batch_ligand):
        """Encode the context and predict focal scores.

        :return: ``(focal_pred, protein_mask, h_ctx)``; ``focal_pred`` holds the
            protein-atom predictions first, followed by the ligand-atom ones
        """
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)
        h_ctx, protein_mask = self._encode_context(h_protein, h_ligand, protein_pos, ligand_pos,
                                                   batch_protein, batch_ligand)
        focal_pred = torch.cat([self.focal_mlp_protein(h_ctx[protein_mask]),
                                self.focal_mlp_ligand(h_ctx[~protein_mask])], dim=0)
        return focal_pred, protein_mask, h_ctx

    def forward_motif(self, h_ctx_focal, current_wid, current_atoms_batch, n_samples=1):
        """Sample next-motif ids from the top-scoring candidates.

        For every row, ``n_samples`` ids are drawn uniformly (with replacement)
        from the ``5 * n_samples`` highest-probability motifs.

        :return: ``(preds, prob)``, both flattened to ``rows * n_samples``;
            ``prob`` is the softmax probability of each sampled id
        """
        pred_scores = F.softmax(self._motif_logits(h_ctx_focal, current_wid, current_atoms_batch), dim=-1)

        # random select n_samples in topk; clamp k so that small vocabularies
        # do not make topk fail.
        k = min(5 * n_samples, pred_scores.shape[1])
        select_pool = torch.topk(pred_scores, k, dim=1)[1]
        index = torch.randint(k, (select_pool.shape[0], n_samples)).to(select_pool.device)
        preds = select_pool.gather(1, index).reshape(-1)

        idx_parent = torch.repeat_interleave(torch.arange(pred_scores.shape[0]), n_samples, dim=0).to(pred_scores.device)
        prob = pred_scores[idx_parent, preds]
        return preds, prob

    def forward_attach(self, mol_list, next_motif_smiles, device):
        """Assemble candidate attachments and keep the best-scoring one per group.

        Candidates come from ``chemutils.assemble``; ``cand_batch`` assigns each
        candidate to a group and the candidate with the highest ``comb_head``
        score in each group is selected.

        :return: ``(select_mols, new_atoms, one_atom_attach, intersection, attach_fail)``
        """
        cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail = \
            chemutils.assemble(mol_list, next_motif_smiles)
        graph_data = Batch.from_data_list(
            [chemutils.mol_to_graph_data_obj_simple(mol) for mol in cand_mols]).to(device)
        comb_pred = self.comb_head(graph_data.x, graph_data.edge_index,
                                   graph_data.edge_attr, graph_data.batch).reshape(-1)

        # Group boundaries inside the flat candidate list.
        slice_idx = torch.cat([torch.tensor([0]), torch.cumsum(cand_batch.bincount(), dim=0)], dim=0)
        bounds = slice_idx.tolist()
        # Greedy choice: arg-max score within every group, as a global index.
        select = [torch.argmax(comb_pred[start:end]).item() + start
                  for start, end in zip(bounds[:-1], bounds[1:])]

        select_mols = [cand_mols[i] for i in select]
        new_atoms = [new_atoms[i] for i in select]
        one_atom_attach = [one_atom_attach[i] for i in select]
        intersection = [intersection[i] for i in select]
        return select_mols, new_atoms, one_atom_attach, intersection, attach_fail

    def forward_alpha(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein,
                      batch_ligand, xy_index, rotatable):
        """Re-encode the context and predict the torsion parameter alpha.

        :param xy_index: two-column index into the ligand atoms (x and y atoms)
        :param rotatable: index/mask selecting rows of the per-molecule summary
        :return: output of ``alpha_mlp``
        """
        # encode again
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)
        h_ctx, protein_mask = self._encode_context(h_protein, h_ligand, protein_pos, ligand_pos,
                                                   batch_protein, batch_ligand)
        h_ctx_ligand = h_ctx[~protein_mask]
        hx, hy = h_ctx_ligand[xy_index[:, 0]], h_ctx_ligand[xy_index[:, 1]]
        h_mol = scatter_add(h_ctx_ligand, dim=0, index=batch_ligand)[rotatable]
        return self._predict_alpha(hx, hy, h_mol)

    def get_loss(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, ligand_pos_torsion,
                 ligand_atom_feature_torsion, batch_protein, batch_ligand, batch_ligand_torsion, batch):
        """Compute the total training loss and its components.

        :return: ``(loss, loss_list)`` where ``loss`` is the sum of all terms and
            ``loss_list`` holds Python floats in the order
            ``[motif, attachment, focal, distance, torsion, refinement]``
            (entries stay 0 for terms that were skipped)
        """
        self.device = protein_pos.device
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)
        loss_list = [0, 0, 0, 0, 0, 0]

        # Encode for motif prediction
        h_ctx, mask_protein = self._encode_context(h_protein, h_ligand, protein_pos, ligand_pos,
                                                   batch_protein, batch_ligand)
        h_ctx_ligand = h_ctx[~mask_protein]
        h_ctx_protein = h_ctx[mask_protein]
        h_ctx_focal = h_ctx[batch['current_atoms']]

        # Encode for torsion prediction (separate ligand context; uses its own
        # mask so the motif-context mask above is not overwritten)
        has_torsion = len(batch['y_pos']) > 0
        if has_torsion:
            h_ligand_torsion = self.ligand_atom_emb(ligand_atom_feature_torsion)
            h_ctx_torsion, mask_protein_torsion = self._encode_context(
                h_protein, h_ligand_torsion, protein_pos, ligand_pos_torsion,
                batch_protein, batch_ligand_torsion)
            h_ctx_ligand_torsion = h_ctx_torsion[~mask_protein_torsion]

        # next motif prediction
        pred_scores = self._motif_logits(h_ctx_focal, batch['current_wid'], batch['current_atoms_batch'])
        pred_loss = self.pred_loss(pred_scores, batch['next_wid'])
        loss_list[0] = pred_loss.item()

        # attachment prediction
        if len(batch['cand_labels']) > 0:
            cand_mols = batch['cand_mols']
            comb_pred = self.comb_head(cand_mols.x, cand_mols.edge_index, cand_mols.edge_attr, cand_mols.batch)
            comb_loss = self.comb_loss(comb_pred, batch['cand_labels'].view(comb_pred.shape).float())
            loss_list[1] = comb_loss.item()
        else:
            comb_loss = 0

        # focal prediction
        focal_ligand_pred = self.focal_mlp_ligand(h_ctx_ligand)
        focal_protein_pred = self.focal_mlp_protein(h_ctx_protein)
        focal_loss = self.focal_loss(focal_ligand_pred.reshape(-1), batch['ligand_frontier'].float()) + \
            self.focal_loss(focal_protein_pred.reshape(-1), batch['protein_contact'].float())
        loss_list[2] = focal_loss.item()

        # distance matrix prediction
        if len(batch['true_dm']) > 0:
            dist_input = torch.cat([protein_atom_feature[batch['dm_protein_idx']],
                                    ligand_atom_feature[batch['dm_ligand_idx']]], dim=-1)
            pred_dist = self.dist_mlp(dist_input)
            dm_target = batch['true_dm'].unsqueeze(-1)
            dm_loss = self.dist_loss(pred_dist, dm_target)
            loss_list[3] = dm_loss.item()
        else:
            dm_loss = 0

        # structure refinement loss
        if self.config.refinement and len(batch['true_dm']) > 0:
            sr_lig, sr_prot = batch['sr_ligand_idx'], batch['sr_protein_idx']
            sr_lig0, sr_lig1 = batch['sr_ligand_idx0'], batch['sr_ligand_idx1']

            true_distance_alpha = torch.norm(
                batch['ligand_context_pos'][sr_lig] - batch['protein_pos'][sr_prot], dim=1)
            true_distance_intra = torch.norm(
                batch['ligand_context_pos'][sr_lig0] - batch['ligand_context_pos'][sr_lig1], dim=1)
            input_distance_alpha = ligand_pos[sr_lig] - protein_pos[sr_prot]
            input_distance_intra = ligand_pos[sr_lig0] - ligand_pos[sr_lig1]
            distance_emb1 = self.encoder.distance_expansion(torch.norm(input_distance_alpha, dim=1))
            distance_emb2 = self.encoder.distance_expansion(torch.norm(input_distance_intra, dim=1))

            # distance cut_off: only pairs whose true distance is within the
            # cutoff contribute a force
            cutoff = 10.0
            near_alpha = true_distance_alpha <= cutoff
            near_intra = true_distance_intra <= cutoff

            input1 = torch.cat([h_ctx_ligand[sr_lig], h_ctx_protein[sr_prot], distance_emb1], dim=-1)[near_alpha]
            input2 = torch.cat([h_ctx_ligand[sr_lig0], h_ctx_ligand[sr_lig1], distance_emb2], dim=-1)[near_intra]
            norm_dir1 = F.normalize(input_distance_alpha, p=2, dim=1)[near_alpha]
            norm_dir2 = F.normalize(input_distance_intra, p=2, dim=1)[near_intra]

            # Per-ligand-atom mean of (predicted magnitude * unit direction).
            force1 = scatter_mean(self.refine_protein(input1) * norm_dir1, dim=0,
                                  index=sr_lig[near_alpha], dim_size=ligand_pos.size(0))
            force2 = scatter_mean(self.refine_ligand(input2) * norm_dir2, dim=0,
                                  index=sr_lig0[near_intra], dim_size=ligand_pos.size(0))

            # Out-of-place update; the caller's ligand_pos is left untouched.
            new_ligand_pos = ligand_pos + force1 + force2

            refine_dist1 = torch.norm(new_ligand_pos[sr_lig] - protein_pos[sr_prot], dim=1)
            refine_dist2 = torch.norm(new_ligand_pos[sr_lig0] - new_ligand_pos[sr_lig1], dim=1)
            sr_loss = (self.dist_loss(refine_dist1, true_distance_alpha)
                       + self.dist_loss(refine_dist2, true_distance_intra))
            loss_list[5] = sr_loss.item()
        else:
            sr_loss = 0

        # torsion prediction
        if has_torsion:
            Hx = dihedral_utils.rotation_matrix_v2(batch['y_pos'])
            xn_pos = torch.matmul(Hx, batch['xn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
            yn_pos = torch.matmul(Hx, batch['yn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
            y_pos = torch.matmul(Hx, batch['y_pos'].unsqueeze(1).permute(0, 2, 1)).squeeze(-1)

            xy_index = batch['ligand_torsion_xy_index']
            hx, hy = h_ctx_ligand_torsion[xy_index[:, 0]], h_ctx_ligand_torsion[xy_index[:, 1]]
            h_mol = scatter_add(h_ctx_ligand_torsion, dim=0, index=batch['ligand_element_torsion_batch'])
            alpha = self._predict_alpha(hx, hy, h_mol)

            # rotate xn
            R_alpha = self.build_alpha_rotation(torch.sin(alpha).squeeze(-1), torch.cos(alpha).squeeze(-1))
            xn_pos = torch.matmul(R_alpha, xn_pos.permute(0, 2, 1)).permute(0, 2, 1)

            # All 3 x 3 combinations of x-neighbours and y-neighbours.
            p_idx, q_idx = torch.cartesian_prod(torch.arange(3), torch.arange(3)).chunk(2, dim=-1)
            p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
            pred_sin, pred_cos = dihedral_utils.batch_dihedrals(
                xn_pos[:, p_idx],
                torch.zeros_like(y_pos).unsqueeze(1).repeat(1, 9, 1),
                y_pos.unsqueeze(1).repeat(1, 9, 1),
                yn_pos[:, q_idx])

            # Fixed: the predicted *sine* is paired with the true sine (the
            # original passed pred_cos twice and never used pred_sin).
            # NOTE(review): assumes von_Mises_loss takes
            # (true_cos, pred_cos, true_sin, pred_sin) — confirm in dihedral_utils.
            dihedral_loss = torch.mean(
                dihedral_utils.von_Mises_loss(batch['true_cos'], pred_cos.reshape(-1),
                                              batch['true_sin'], pred_sin.reshape(-1))[batch['dihedral_mask']])
            torsion_loss = -dihedral_loss
            loss_list[4] = torsion_loss.item()
        else:
            torsion_loss = 0

        # dm: distance matrix
        loss = pred_loss + comb_loss + focal_loss + dm_loss + torsion_loss + sr_loss

        return loss, loss_list

    def build_alpha_rotation(self, alpha, alpha_cos=None):
        """
        Builds the alpha rotation matrix (rotation about the first axis)

        :param alpha: predicted values of torsion parameter alpha (n_dihedral_pairs);
            when ``alpha_cos`` is a tensor this argument is used directly as sin(alpha)
        :param alpha_cos: optional cos(alpha) values (n_dihedral_pairs); when omitted,
            sin and cos are computed from ``alpha``
        :return: alpha rotation matrix (n_dihedral_pairs, 3, 3)
        """
        # Allocate next to the input rather than on self.device, which only
        # exists after get_loss() has been called.
        H_alpha = torch.zeros(alpha.shape[0], 3, 3, dtype=torch.float32, device=alpha.device)
        H_alpha[:, 0, 0] = 1

        if torch.is_tensor(alpha_cos):
            sin_alpha, cos_alpha = alpha, alpha_cos
        else:
            sin_alpha, cos_alpha = torch.sin(alpha), torch.cos(alpha)

        H_alpha[:, 1, 1] = cos_alpha
        H_alpha[:, 1, 2] = -sin_alpha
        H_alpha[:, 2, 1] = sin_alpha
        H_alpha[:, 2, 2] = cos_alpha

        return H_alpha