import torch
import torch.nn as nn

from UltraFlow import layers, losses


class IGN_basic(nn.Module):
    def __init__(self, config):
        super(IGN_basic, self).__init__()
        self.config = config
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description

        # Shared encoder applied to both the ligand graph and the protein graph.
        self.graph_conv = layers.ModifiedAttentiveFPGNNV2(config.model.lig_node_dim, config.model.lig_edge_dim,
                                                          config.model.num_layers, config.model.hidden_dim,
                                                          config.model.dropout, config.model.jk)

        # Jumping-knowledge 'concat' multiplies the node feature width by the number of layers.
        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer_IGN_basic(
                config.model.hidden_dim * config.model.num_layers + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer_IGN_basic(
                config.model.hidden_dim + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)

        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1),
                                config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2,
                                config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)

        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout,
                                            config.model.num_head, config.model.attn_merge)
        self.softmax = nn.Softmax(dim=1)

        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim,
                                                      config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim,
                                                                config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim,
                                                               config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.graph_conv(bg_lig)
        node_feats_prot = self.graph_conv(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)

        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))  # placeholder when no assay description is used

        return affinity_pred, graph_embedding, ranking_assay_embedding

    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        # Re-order the concatenated node features so that, for each complex in the batch,
        # its ligand nodes come first and are immediately followed by its protein nodes.
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = \
                node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = \
                node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
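# The sketch below is illustrative and not part of the original module: it mimics what
# `alignfeature` computes for a toy batch of two complexes, using plain tensors in place of DGL
# batched graphs. The node counts and feature width are made up; only the interleaving pattern
# (each complex's ligand rows followed by its protein rows) reflects the method above.
def _example_alignfeature_layout():
    lig_num = torch.tensor([2, 3])       # ligand nodes per complex (hypothetical)
    prot_num = torch.tensor([4, 1])      # protein nodes per complex (hypothetical)
    node_feats_lig = torch.arange(5).float().unsqueeze(1)         # 2 + 3 ligand rows: 0..4
    node_feats_prot = 10 + torch.arange(5).float().unsqueeze(1)   # 4 + 1 protein rows: 10..14

    inter_feature = torch.cat((node_feats_lig, node_feats_prot))
    lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
    inter_start = lig_start + prot_start
    for i in range(lig_num.shape[0]):
        inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = \
            node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
        inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = \
            node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]

    # Row order is now [lig_0, prot_0, lig_1, prot_1]:
    # complex 0 -> rows 0-1 (ligand 0, 1), then rows 2-5 (protein 10..13)
    # complex 1 -> rows 6-8 (ligand 2, 3, 4), then row 9 (protein 14)
    return inter_feature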
class IGN(nn.Module):
    def __init__(self, config):
        super(IGN, self).__init__()
        self.config = config
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description

        # Separate encoders for the ligand graph and the protein graph.
        self.ligand_conv = layers.ModifiedAttentiveFPGNNV2(config.model.lig_node_dim, config.model.lig_edge_dim,
                                                           config.model.num_layers, config.model.hidden_dim,
                                                           config.model.dropout, config.model.jk)
        self.protein_conv = layers.ModifiedAttentiveFPGNNV2(config.model.pro_node_dim, config.model.pro_edge_dim,
                                                            config.model.num_layers, config.model.hidden_dim,
                                                            config.model.dropout, config.model.jk)

        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(
                config.model.hidden_dim * (config.model.num_layers + config.model.num_layers) + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(
                config.model.hidden_dim * 2 + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)

        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1),
                                config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2,
                                config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)

        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout,
                                            config.model.num_head, config.model.attn_merge)
        self.softmax = nn.Softmax(dim=1)

        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim,
                                                      config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim,
                                                                config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim,
                                                               config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.ligand_conv(bg_lig)
        node_feats_prot = self.protein_conv(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)

        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))

        return affinity_pred, graph_embedding, ranking_assay_embedding

    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = \
                node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = \
                node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
class GNNs(nn.Module):
    def __init__(self, nLigNode, nLigEdge, nLayer, nHid, JK, GNN):
        super(GNNs, self).__init__()
        # Pluggable graph encoder selected by the GNN_type string.
        if GNN == 'GCN':
            self.Encoder = layers.GCN(nLigNode, hidden_feats=[nHid] * nLayer)
        elif GNN == 'GAT':
            self.Encoder = layers.GAT(nLigNode, hidden_feats=[nHid] * nLayer)
        elif GNN == 'GIN':
            self.Encoder = layers.GIN(nLigNode, nHid, nLayer, num_mlp_layers=2, dropout=0.1,
                                      learn_eps=False, neighbor_pooling_type='sum', JK=JK)
        elif GNN == 'EGNN':
            self.Encoder = layers.EGNN(nLigNode, nLigEdge, nHid, nLayer, dropout=0.1, JK=JK)
        elif GNN == 'AttentiveFP':
            self.Encoder = layers.ModifiedAttentiveFPGNNV2(nLigNode, nLigEdge, nLayer, nHid, 0.1, JK)

    def forward(self, Graph, Perturb=None):
        Node_Rep = self.Encoder(Graph, Perturb)
        return Node_Rep


class Affinity_GNNs(nn.Module):
    def __init__(self, config):
        super(Affinity_GNNs, self).__init__()
        lig_node_dim = config.model.lig_node_dim
        lig_edge_dim = config.model.lig_edge_dim
        pro_node_dim = config.model.pro_node_dim
        pro_edge_dim = config.model.pro_edge_dim
        layer_num = config.model.num_layers
        hidden_dim = config.model.hidden_dim
        jk = config.model.jk
        GNN = config.model.GNN_type

        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description

        self.lig_encoder = GNNs(lig_node_dim, lig_edge_dim, layer_num, hidden_dim, jk, GNN)
        self.pro_encoder = GNNs(pro_node_dim, pro_edge_dim, layer_num, hidden_dim, jk, GNN)

        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(
                hidden_dim * (layer_num + layer_num) + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(
                hidden_dim * 2 + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)

        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout,
                                            config.model.num_head, config.model.attn_merge)

        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1),
                                config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2,
                                config.model.fc_hidden_dim, config.model.dropout, config.model.out_dim)

        self.softmax = nn.Softmax(dim=1)

        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim,
                                                      config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim,
                                                                config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim,
                                                               config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

    def forward(self, batch):
        bg_lig, bg_prot, bg_inter, labels, _, ass_des = batch
        node_feats_lig = self.lig_encoder(bg_lig)
        node_feats_prot = self.pro_encoder(bg_prot)
        bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
        bond_feats_inter = self.noncov_graph(bg_inter)
        graph_embedding = self.readout(bg_inter, bond_feats_inter)

        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))

        return affinity_pred, graph_embedding, ranking_assay_embedding

    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = \
                node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = \
                node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature
class affinity_head(nn.Module):
    def __init__(self, config):
        super(affinity_head, self).__init__()
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description

        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim,
                                                      config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim,
                                                                config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim,
                                                               config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1),
                                config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2,
                                config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)

    def forward(self, graph_embedding, ass_des):
        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))

        return affinity_pred
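# Illustrative sketch (not part of the original module) of the assay-description fusion used by
# the heads above: the assay descriptor is mapped to the same width as the interaction-graph
# embedding and simply added to it before the final FC. `nn.Linear` stands in for `layers.FC`,
# and all dimensions are made up.
def _example_assay_fusion():
    graph_dim, assay_des_dim, batch = 8, 5, 3
    graph_embedding = torch.randn(batch, graph_dim)   # stand-in for the ReadsOutLayer output
    ass_des = torch.randn(batch, assay_des_dim)       # stand-in for the assay description features

    assay_mlp = nn.Linear(assay_des_dim, graph_dim)   # plays the role of assay_info_aggre_mlp
    final_fc = nn.Linear(graph_dim, 1)                # plays the role of self.FC

    assay_embedding = assay_mlp(ass_des)
    affinity_pred = final_fc(graph_embedding + assay_embedding)  # additive fusion, as above
    return affinity_pred                                         # shape (batch, 1)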
class ASRP_head(nn.Module):
    def __init__(self, config):
        super(ASRP_head, self).__init__()
        self.readout = layers.ReadsOutLayer(config.model.inter_out_dim, config.model.readout,
                                            config.model.num_head, config.model.attn_merge)
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description

        if self.pretrain_use_assay_description:
            print(f'use assay description type: {config.data.assay_des_type}')
            if self.pretrain_assay_mlp_share:
                self.assay_info_aggre_mlp = layers.FC(config.data.assay_des_dim,
                                                      config.model.assay_des_fc_hidden_dim,
                                                      config.model.dropout, config.model.inter_out_dim * 2)
            else:
                self.assay_info_aggre_mlp_pointwise = layers.FC(config.data.assay_des_dim,
                                                                config.model.assay_des_fc_hidden_dim,
                                                                config.model.dropout, config.model.inter_out_dim * 2)
                self.assay_info_aggre_mlp_pairwise = layers.FC(config.data.assay_des_dim,
                                                               config.model.assay_des_fc_hidden_dim,
                                                               config.model.dropout, config.model.inter_out_dim * 2)

        if config.model.readout.startswith('multi_head') and config.model.attn_merge == 'concat':
            self.FC = layers.FC(config.model.inter_out_dim * (config.model.num_head + 1),
                                config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)
        else:
            self.FC = layers.FC(config.model.inter_out_dim * 2,
                                config.model.fintune_fc_hidden_dim, config.model.dropout, config.model.out_dim)

        # Element-wise losses: reduction is applied manually after masking with select_flag.
        self.regression_loss_fn = nn.MSELoss(reduction='none')
        self.ranking_loss_fn = losses.pairwise_BCE_loss(config)

        self.pairwise_two_tower_regression_loss = config.train.pairwise_two_tower_regression_loss
        if self.pairwise_two_tower_regression_loss:
            print('use two tower regression loss')

    def forward(self, bg_inter, bond_feats_inter, ass_des, labels, select_flag):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)

        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))

        # The batch is assumed to hold 2k samples forming k ranking pairs:
        # sample i in the first half is paired with sample i + k in the second half.
        y_pred_num = len(affinity_pred)
        assert y_pred_num % 2 == 0

        if self.pairwise_two_tower_regression_loss:
            # Regression loss over the whole batch.
            regression_loss = self.regression_loss_fn(affinity_pred, labels)
            labels_select = labels[select_flag]
            affinity_pred_select = affinity_pred[select_flag]
            regression_loss_select = regression_loss[select_flag].sum()
        else:
            # Regression loss over the first half of each pair only.
            regression_loss = self.regression_loss_fn(affinity_pred[:y_pred_num // 2], labels[:y_pred_num // 2])
            labels_select = labels[:y_pred_num // 2][select_flag[:y_pred_num // 2]]
            affinity_pred_select = affinity_pred[:y_pred_num // 2][select_flag[:y_pred_num // 2]]
            regression_loss_select = regression_loss[select_flag[:y_pred_num // 2]].sum()

        ranking_loss, relation, relation_pred = self.ranking_loss_fn(graph_embedding, labels, ranking_assay_embedding)
        ranking_loss_select = ranking_loss[select_flag[:y_pred_num // 2]].sum()
        relation_select = relation[select_flag[:y_pred_num // 2]]
        relation_pred_select = relation_pred[select_flag[:y_pred_num // 2]]

        return regression_loss_select, ranking_loss_select, \
               labels_select, affinity_pred_select, \
               relation_select, relation_pred_select

    def forward_pointwise(self, bg_inter, bond_feats_inter, ass_des, labels, select_flag):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)
        affinity_pred = self.FC(graph_embedding)
        regression_loss = self.regression_loss_fn(affinity_pred, labels)
        regression_loss_select = regression_loss[select_flag].sum()
        labels_select = labels[select_flag]
        affinity_pred_select = affinity_pred[select_flag]
        return regression_loss_select, labels_select, affinity_pred_select

    def evaluate_mtl(self, bg_inter, bond_feats_inter, ass_des, labels):
        graph_embedding = self.readout(bg_inter, bond_feats_inter)

        if self.pretrain_use_assay_description:
            if self.pretrain_assay_mlp_share:
                ranking_assay_embedding = self.assay_info_aggre_mlp(ass_des)
                affinity_pred = self.FC(graph_embedding + ranking_assay_embedding)
            else:
                regression_assay_embedding = self.assay_info_aggre_mlp_pointwise(ass_des)
                affinity_pred = self.FC(graph_embedding + regression_assay_embedding)
                ranking_assay_embedding = self.assay_info_aggre_mlp_pairwise(ass_des)
        else:
            affinity_pred = self.FC(graph_embedding)
            ranking_assay_embedding = torch.zeros(len(affinity_pred))

        # Build every ordered pair (i, j), i != j, to evaluate the ranking head on all relations.
        n = graph_embedding.shape[0]
        pair_a_index, pair_b_index = [], []
        for i in range(n):
            pair_a_index.extend([i] * (n - 1))
            pair_b_index.extend([j for j in range(n) if i != j])
        pair_index = pair_a_index + pair_b_index

        _, relation, relation_pred = self.ranking_loss_fn(graph_embedding[pair_index], labels[pair_index],
                                                          ranking_assay_embedding[pair_index])

        return affinity_pred, relation, relation_pred
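# Illustrative sketch (not part of the original module) of the pair convention assumed by
# ASRP_head.forward: a batch of 2k samples encodes k ranking pairs (sample i vs. sample i + k),
# the regression loss is taken on the first half only, and `select_flag` masks which samples
# belong to this head's measurement type before the losses are summed. Shapes and values are made up.
def _example_pairwise_selection():
    k = 3
    affinity_pred = torch.randn(2 * k, 1)                  # predictions for k pairs
    labels = torch.randn(2 * k, 1)
    select_flag = torch.tensor([True, False, True,         # first halves of the pairs
                                True, False, True])        # second halves of the same pairs

    regression_loss = nn.MSELoss(reduction='none')(affinity_pred[:k], labels[:k])
    regression_loss_select = regression_loss[select_flag[:k]].sum()   # only selected first-half samples
    labels_select = labels[:k][select_flag[:k]]
    affinity_pred_select = affinity_pred[:k][select_flag[:k]]
    return regression_loss_select, labels_select, affinity_pred_select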
class Affinity_GNNs_MTL(nn.Module):
    def __init__(self, config):
        super(Affinity_GNNs_MTL, self).__init__()
        lig_node_dim = config.model.lig_node_dim
        lig_edge_dim = config.model.lig_edge_dim
        pro_node_dim = config.model.pro_node_dim
        pro_edge_dim = config.model.pro_edge_dim
        layer_num = config.model.num_layers
        hidden_dim = config.model.hidden_dim
        jk = config.model.jk
        GNN = config.model.GNN_type

        self.multi_task = config.train.multi_task
        self.pretrain_assay_mlp_share = config.train.pretrain_assay_mlp_share
        self.pretrain_use_assay_description = config.train.pretrain_use_assay_description

        self.lig_encoder = GNNs(lig_node_dim, lig_edge_dim, layer_num, hidden_dim, jk, GNN)
        self.pro_encoder = GNNs(pro_node_dim, pro_edge_dim, layer_num, hidden_dim, jk, GNN)

        if config.model.jk == 'concat':
            self.noncov_graph = layers.DTIConvGraph3Layer(
                hidden_dim * (layer_num + layer_num) + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)
        else:
            self.noncov_graph = layers.DTIConvGraph3Layer(
                hidden_dim * 2 + config.model.inter_edge_dim,
                config.model.inter_out_dim, config.model.dropout)

        self.softmax = nn.Softmax(dim=1)

        # One affinity-and-ranking (ASRP) head per measurement type.
        if self.multi_task == 'IC50KdKi':
            self.IC50_ASRP_head = ASRP_head(config)
            self.Kd_ASRP_head = ASRP_head(config)
            self.Ki_ASRP_head = ASRP_head(config)
        elif self.multi_task == 'IC50K':
            self.IC50_ASRP_head = ASRP_head(config)
            self.K_ASRP_head = ASRP_head(config)

        self.config = config

    def forward(self, batch, ASRP=True, Perturb=None, Perturb_v=None):
        if self.multi_task == 'IC50KdKi':
            bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, Kd_f, Ki_f = batch
            lig_node_feats_init = bg_lig.ndata['h']
            pro_node_feats_init = bg_prot.ndata['h']

            if Perturb is not None and Perturb_v == 'v_intra':
                node_feats_lig = self.lig_encoder(bg_lig, Perturb[:bg_lig.number_of_nodes()])
                node_feats_prot = self.pro_encoder(bg_prot, Perturb[bg_lig.number_of_nodes():])
            else:
                node_feats_lig = self.lig_encoder(bg_lig)
                node_feats_prot = self.pro_encoder(bg_prot)

            if self.config.train.encoder_ablation == 'interact':
                return node_feats_lig, node_feats_prot
            elif self.config.train.encoder_ablation == 'ligand':
                node_feats_lig = node_feats_lig.zero_()
                node_feats_lig[:, :self.config.model.lig_node_dim] = lig_node_feats_init
            elif self.config.train.encoder_ablation == 'protein':
                node_feats_prot = node_feats_prot.zero_()
                node_feats_prot[:, :self.config.model.pro_node_dim] = pro_node_feats_init

            bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
            if Perturb is not None and Perturb_v == 'v_inter':
                bg_inter.ndata['h'] = bg_inter.ndata['h'] + Perturb

            bond_feats_inter = self.noncov_graph(bg_inter)

            if ASRP:
                return self.multi_head_pred(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f)
            else:
                return self.multi_head_pointwise(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f)

        elif self.multi_task == 'IC50K':
            bg_lig, bg_prot, bg_inter, labels, _, ass_des, IC50_f, K_f = batch
            lig_node_feats_init = bg_lig.ndata['h']
            pro_node_feats_init = bg_prot.ndata['h']

            if Perturb is not None and Perturb_v == 'v_intra':
                node_feats_lig = self.lig_encoder(bg_lig, Perturb[:bg_lig.number_of_nodes()])
                node_feats_prot = self.pro_encoder(bg_prot, Perturb[bg_lig.number_of_nodes():])
            else:
                node_feats_lig = self.lig_encoder(bg_lig)
                node_feats_prot = self.pro_encoder(bg_prot)

            if self.config.train.encoder_ablation == 'interact':
                return node_feats_lig, node_feats_prot
            elif self.config.train.encoder_ablation == 'ligand':
                node_feats_lig = node_feats_lig.zero_()
                node_feats_lig[:, :self.config.model.lig_node_dim] = lig_node_feats_init
            elif self.config.train.encoder_ablation == 'protein':
                node_feats_prot = node_feats_prot.zero_()
                node_feats_prot[:, :self.config.model.pro_node_dim] = pro_node_feats_init

            bg_inter.ndata['h'] = self.alignfeature(bg_lig, bg_prot, node_feats_lig, node_feats_prot)
            if Perturb is not None and Perturb_v == 'v_inter':
                bg_inter.ndata['h'] = bg_inter.ndata['h'] + Perturb

            bond_feats_inter = self.noncov_graph(bg_inter)

            if ASRP:
                return self.multi_head_pred_v2(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f)
            else:
                return self.multi_head_pointwise_v2(bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f)
    def multi_head_pointwise(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        regression_loss_IC50, affinity_IC50, affinity_pred_IC50 = \
            self.IC50_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_Kd, affinity_Kd, affinity_pred_Kd = \
            self.Kd_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, Kd_f)
        regression_loss_Ki, affinity_Ki, affinity_pred_Ki = \
            self.Ki_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, Ki_f)

        return (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki), \
               (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
               (affinity_IC50, affinity_Kd, affinity_Ki)

    def multi_head_pointwise_v2(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f):
        regression_loss_IC50, affinity_IC50, affinity_pred_IC50 = \
            self.IC50_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)
        regression_loss_K, affinity_K, affinity_pred_K = \
            self.K_ASRP_head.forward_pointwise(bg_inter, bond_feats_inter, ass_des, labels, K_f)

        return (regression_loss_IC50, regression_loss_K), \
               (affinity_pred_IC50, affinity_pred_K), \
               (affinity_IC50, affinity_K)

    def multi_head_pred(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        regression_loss_IC50, ranking_loss_IC50, \
        affinity_IC50, affinity_pred_IC50, \
        relation_IC50, relation_pred_IC50 = self.IC50_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)

        regression_loss_Kd, ranking_loss_Kd, \
        affinity_Kd, affinity_pred_Kd, \
        relation_Kd, relation_pred_Kd = self.Kd_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, Kd_f)

        regression_loss_Ki, ranking_loss_Ki, \
        affinity_Ki, affinity_pred_Ki, \
        relation_Ki, relation_pred_Ki = self.Ki_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, Ki_f)

        return (regression_loss_IC50, regression_loss_Kd, regression_loss_Ki), \
               (ranking_loss_IC50, ranking_loss_Kd, ranking_loss_Ki), \
               (affinity_pred_IC50, affinity_pred_Kd, affinity_pred_Ki), \
               (relation_pred_IC50, relation_pred_Kd, relation_pred_Ki), \
               (affinity_IC50, affinity_Kd, affinity_Ki), \
               (relation_IC50, relation_Kd, relation_Ki)
    def multi_head_pred_v2(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, K_f):
        regression_loss_IC50, ranking_loss_IC50, \
        affinity_IC50, affinity_pred_IC50, \
        relation_IC50, relation_pred_IC50 = self.IC50_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, IC50_f)

        regression_loss_K, ranking_loss_K, \
        affinity_K, affinity_pred_K, \
        relation_K, relation_pred_K = self.K_ASRP_head(bg_inter, bond_feats_inter, ass_des, labels, K_f)

        return (regression_loss_IC50, regression_loss_K), \
               (ranking_loss_IC50, ranking_loss_K), \
               (affinity_pred_IC50, affinity_pred_K), \
               (relation_pred_IC50, relation_pred_K), \
               (affinity_IC50, affinity_K), \
               (relation_IC50, relation_K)

    def multi_head_evaluate(self, bg_inter, bond_feats_inter, labels, ass_des, IC50_f, Kd_f, Ki_f):
        # Each evaluation batch is expected to contain a single measurement type.
        if sum(IC50_f):
            assert sum(Kd_f) == 0 and sum(Ki_f) == 0
            return self.IC50_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)
        elif sum(Kd_f):
            assert sum(IC50_f) == 0 and sum(Ki_f) == 0
            return self.Kd_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)
        elif sum(Ki_f):
            assert sum(IC50_f) == 0 and sum(Kd_f) == 0
            return self.Ki_ASRP_head.evaluate_mtl(bg_inter, bond_feats_inter, ass_des, labels)

    def alignfeature(self, bg_lig, bg_prot, node_feats_lig, node_feats_prot):
        inter_feature = torch.cat((node_feats_lig, node_feats_prot))
        lig_num, prot_num = bg_lig.batch_num_nodes(), bg_prot.batch_num_nodes()
        lig_start, prot_start = lig_num.cumsum(0) - lig_num, prot_num.cumsum(0) - prot_num
        inter_start = lig_start + prot_start
        for i in range(lig_num.shape[0]):
            inter_feature[inter_start[i]:inter_start[i] + lig_num[i]] = \
                node_feats_lig[lig_start[i]:lig_start[i] + lig_num[i]]
            inter_feature[inter_start[i] + lig_num[i]:inter_start[i] + lig_num[i] + prot_num[i]] = \
                node_feats_prot[prot_start[i]:prot_start[i] + prot_num[i]]
        return inter_feature


class interact_ablation(nn.Module):
    def __init__(self, config):
        super(interact_ablation, self).__init__()
        self.IC50_ASRP_head = interact_ablation_head(config)
        self.K_ASRP_head = interact_ablation_head(config)
        self.config = config

    def forward(self, graph_embedding, labels, IC50_f, K_f):
        regression_loss_IC50, \
        affinity_IC50, affinity_pred_IC50 = self.IC50_ASRP_head(graph_embedding, labels, IC50_f)

        regression_loss_K, \
        affinity_K, affinity_pred_K = self.K_ASRP_head(graph_embedding, labels, K_f)

        return (regression_loss_IC50, regression_loss_K), \
               (affinity_pred_IC50, affinity_pred_K), \
               (affinity_IC50, affinity_K)


class interact_ablation_head(nn.Module):
    def __init__(self, config):
        super(interact_ablation_head, self).__init__()
        self.FC = layers.FC(config.model.inter_out_dim * 2, config.model.fintune_fc_hidden_dim,
                            config.model.dropout, config.model.out_dim)
        self.regression_loss_fn = nn.MSELoss(reduction='none')

    def forward(self, graph_embedding, labels, select_flag):
        affinity_pred = self.FC(graph_embedding)
        regression_loss = self.regression_loss_fn(affinity_pred, labels)
        regression_loss_select = regression_loss[select_flag].sum()
        labels_select = labels[select_flag]
        affinity_pred_select = affinity_pred[select_flag]
        return regression_loss_select, labels_select, affinity_pred_select
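# Illustrative sketch (not part of the original module) of the all-ordered-pairs index that
# ASRP_head.evaluate_mtl builds before calling the ranking loss: for n samples it enumerates every
# (i, j) with i != j, and the concatenated index list is then used to gather embeddings, labels,
# and assay embeddings for relation prediction. Shown for a small n just to make the pattern visible.
def _example_all_pairs_index(n=3):
    pair_a_index, pair_b_index = [], []
    for i in range(n):
        pair_a_index.extend([i] * (n - 1))
        pair_b_index.extend([j for j in range(n) if i != j])
    # n = 3 -> pair_a_index = [0, 0, 1, 1, 2, 2], pair_b_index = [1, 2, 0, 2, 0, 1]
    return pair_a_index + pair_b_index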