import sys

sys.path.append("..")

import torch
import torch.nn as nn
from torch.nn import Module, Linear, Embedding
from torch.nn import functional as F
from torch_scatter import scatter_add, scatter_mean
from torch_geometric.data import Data, Batch
from copy import deepcopy

from .encoders import get_encoder, GNN_graphpred, MLP
from .common import *
from utils import dihedral_utils, chemutils


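# FLAG builds ligands motif by motif inside a protein pocket. The module below wires
# together: atom-feature embeddings for protein and ligand, a shared context encoder,
# a focal-atom classifier (where to grow next), a motif-vocabulary classifier (what to
# attach), a GIN head that ranks candidate attachment configurations, a torsion-angle
# predictor for rotatable bonds, and optional distance-based coordinate refinement.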
class FLAG(Module):

    def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim, vocab):
        super().__init__()
        self.config = config
        self.vocab = vocab
        self.protein_atom_emb = Linear(protein_atom_feature_dim, config.hidden_channels)
        self.ligand_atom_emb = Linear(ligand_atom_feature_dim, config.hidden_channels)
        self.embedding = nn.Embedding(vocab.size() + 1, config.hidden_channels)
        self.W = nn.Linear(2 * config.hidden_channels, config.hidden_channels)
        self.W_o = nn.Linear(config.hidden_channels, self.vocab.size())
        self.encoder = get_encoder(config.encoder)
        self.comb_head = GNN_graphpred(num_layer=3, emb_dim=config.hidden_channels, num_tasks=1, JK='last',
                                       drop_ratio=0.5, graph_pooling='mean', gnn_type='gin')
        if config.random_alpha:
            self.alpha_mlp = MLP(in_dim=config.hidden_channels * 4, out_dim=1, num_layers=2)
        else:
            self.alpha_mlp = MLP(in_dim=config.hidden_channels * 3, out_dim=1, num_layers=2)
        self.focal_mlp_ligand = MLP(in_dim=config.hidden_channels, out_dim=1, num_layers=1)
        self.focal_mlp_protein = MLP(in_dim=config.hidden_channels, out_dim=1, num_layers=1)
        self.dist_mlp = MLP(in_dim=protein_atom_feature_dim + ligand_atom_feature_dim, out_dim=1, num_layers=2)
        if config.refinement:
            self.refine_protein = MLP(in_dim=config.hidden_channels * 2 + config.encoder.edge_channels, out_dim=1, num_layers=2)
            self.refine_ligand = MLP(in_dim=config.hidden_channels * 2 + config.encoder.edge_channels, out_dim=1, num_layers=2)

        self.smooth_cross_entropy = SmoothCrossEntropyLoss(reduction='mean', smoothing=0.1)
        self.pred_loss = nn.CrossEntropyLoss()
        self.comb_loss = nn.BCEWithLogitsLoss()
        self.three_hop_loss = torch.nn.MSELoss()
        self.focal_loss = nn.BCEWithLogitsLoss()
        self.dist_loss = torch.nn.MSELoss(reduction='mean')

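    # forward encodes the composed protein-ligand context and returns the focal-atom
    # logits (protein atoms first, then ligand atoms), the boolean protein mask over the
    # composed context, and the per-atom context embeddings h_ctx.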
    def forward(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein, batch_ligand):
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)

        h_ctx, pos_ctx, batch_ctx, protein_mask = compose_context_stable(h_protein=h_protein, h_ligand=h_ligand,
                                                                         pos_protein=protein_pos, pos_ligand=ligand_pos,
                                                                         batch_protein=batch_protein,
                                                                         batch_ligand=batch_ligand)
        h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)
        focal_pred = torch.cat([self.focal_mlp_protein(h_ctx[protein_mask]), self.focal_mlp_ligand(h_ctx[~protein_mask])], dim=0)

        return focal_pred, protein_mask, h_ctx

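    # forward_motif predicts the type of the next motif: focal-atom embeddings are pooled
    # per graph, concatenated with the current motif embedding, and mapped to a softmax
    # over the motif vocabulary. n_samples motifs are then drawn uniformly from the top
    # 5 * n_samples candidates, and their predicted probabilities are returned as well.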
    def forward_motif(self, h_ctx_focal, current_wid, current_atoms_batch, n_samples=1):
        node_hiddens = scatter_add(h_ctx_focal, dim=0, index=current_atoms_batch)
        motif_hiddens = self.embedding(current_wid)
        pred_vecs = torch.cat([node_hiddens, motif_hiddens], dim=1)
        pred_vecs = nn.ReLU()(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_scores = F.softmax(pred_scores, dim=-1)
        _, preds = torch.max(pred_scores, dim=1)

        k = 5 * n_samples
        select_pool = torch.topk(pred_scores, k, dim=1)[1]
        index = torch.randint(k, (select_pool.shape[0], n_samples))
        preds = torch.cat([select_pool[i][index[i]] for i in range(len(index))])

        idx_parent = torch.repeat_interleave(torch.arange(pred_scores.shape[0]), n_samples, dim=0).to(pred_scores.device)
        prob = pred_scores[idx_parent, preds]
        return preds, prob

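    # forward_attach enumerates the candidate ways of attaching the chosen motif to each
    # partial molecule (chemutils.assemble), scores every candidate graph with comb_head,
    # and keeps the highest-scoring candidate per molecule, together with its bookkeeping
    # (new atoms, single-atom attachments, intersections) and the attachment-failure flags.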
    def forward_attach(self, mol_list, next_motif_smiles, device):
        cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail = chemutils.assemble(mol_list, next_motif_smiles)
        graph_data = Batch.from_data_list([chemutils.mol_to_graph_data_obj_simple(mol) for mol in cand_mols]).to(device)
        comb_pred = self.comb_head(graph_data.x, graph_data.edge_index, graph_data.edge_attr, graph_data.batch).reshape(-1)
        slice_idx = torch.cat([torch.tensor([0]), torch.cumsum(cand_batch.bincount(), dim=0)], dim=0)
        select = [(torch.argmax(comb_pred[slice_idx[i]:slice_idx[i + 1]]) + slice_idx[i]).item() for i in
                  range(len(slice_idx) - 1)]
        # Alternative: sample the attachment index from a softmax over the candidate scores
        # select = []
        # for k in range(len(slice_idx) - 1):
        #     id = torch.multinomial(torch.exp(comb_pred[slice_idx[k]:slice_idx[k + 1]]).reshape(-1).float(), 1)
        #     select.append((id + slice_idx[k]).item())

        select_mols = [cand_mols[i] for i in select]
        new_atoms = [new_atoms[i] for i in select]
        one_atom_attach = [one_atom_attach[i] for i in select]
        intersection = [intersection[i] for i in select]
        return select_mols, new_atoms, one_atom_attach, intersection, attach_fail

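    # forward_alpha predicts the torsion angle alpha for each rotatable bond (x, y): hx and
    # hy are the context embeddings of the bond endpoints and h_mol is the summed ligand
    # embedding of the corresponding molecule, restricted to the rotatable entries.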
    def forward_alpha(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein,
                      batch_ligand, xy_index, rotatable):
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)

        h_ctx, pos_ctx, batch_ctx, protein_mask = compose_context_stable(h_protein=h_protein, h_ligand=h_ligand,
                                                                         pos_protein=protein_pos, pos_ligand=ligand_pos,
                                                                         batch_protein=batch_protein,
                                                                         batch_ligand=batch_ligand)
        h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)
        h_ctx_ligand = h_ctx[~protein_mask]
        hx, hy = h_ctx_ligand[xy_index[:, 0]], h_ctx_ligand[xy_index[:, 1]]
        h_mol = scatter_add(h_ctx_ligand, dim=0, index=batch_ligand)
        h_mol = h_mol[rotatable]
        if self.config.random_alpha:
            rand_dist = torch.distributions.normal.Normal(loc=0, scale=1)
            rand_alpha = rand_dist.sample(hx.shape).to(hx.device)
            alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol, rand_alpha], dim=-1))
        else:
            alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol], dim=-1))
        return alpha

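    # get_loss computes the multi-task training objective; the returned loss is the sum of
    # all terms and loss_list collects the individual values in the order:
    #   [0] next-motif prediction (cross entropy)
    #   [1] attachment / combination prediction (BCE)
    #   [2] focal atom prediction (BCE over ligand frontier and protein contact atoms)
    #   [3] protein-ligand distance regression (MSE)
    #   [4] torsion / dihedral term (negative von Mises agreement)
    #   [5] structure refinement distance term (MSE)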
    def get_loss(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, ligand_pos_torsion,
                 ligand_atom_feature_torsion, batch_protein, batch_ligand, batch_ligand_torsion, batch):
        self.device = protein_pos.device
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)

        loss_list = [0, 0, 0, 0, 0, 0]

        # encode the composed protein-ligand context
        h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(h_protein=h_protein, h_ligand=h_ligand,
                                                                         pos_protein=protein_pos, pos_ligand=ligand_pos,
                                                                         batch_protein=batch_protein,
                                                                         batch_ligand=batch_ligand)
        h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)
        h_ctx_ligand = h_ctx[~mask_protein]
        h_ctx_protein = h_ctx[mask_protein]
        h_ctx_focal = h_ctx[batch['current_atoms']]

        # encode a second context built from the torsion-specific ligand conformation
        if len(batch['y_pos']) > 0:
            h_ligand_torsion = self.ligand_atom_emb(ligand_atom_feature_torsion)
            h_ctx_torsion, pos_ctx_torsion, batch_ctx_torsion, mask_protein = compose_context_stable(h_protein=h_protein,
                                                                                                     h_ligand=h_ligand_torsion,
                                                                                                     pos_protein=protein_pos,
                                                                                                     pos_ligand=ligand_pos_torsion,
                                                                                                     batch_protein=batch_protein,
                                                                                                     batch_ligand=batch_ligand_torsion)
            h_ctx_torsion = self.encoder(node_attr=h_ctx_torsion, pos=pos_ctx_torsion, batch=batch_ctx_torsion)
            h_ctx_ligand_torsion = h_ctx_torsion[~mask_protein]

        # next-motif type prediction loss
        node_hiddens = scatter_add(h_ctx_focal, dim=0, index=batch['current_atoms_batch'])
        motif_hiddens = self.embedding(batch['current_wid'])
        pred_vecs = torch.cat([node_hiddens, motif_hiddens], dim=1)
        pred_vecs = nn.ReLU()(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_loss = self.pred_loss(pred_scores, batch['next_wid'])
        loss_list[0] = pred_loss.item()

        # attachment (combination) prediction loss over the candidate molecules
        if len(batch['cand_labels']) > 0:
            cand_mols = batch['cand_mols']
            comb_pred = self.comb_head(cand_mols.x, cand_mols.edge_index, cand_mols.edge_attr, cand_mols.batch)
            comb_loss = self.comb_loss(comb_pred, batch['cand_labels'].view(comb_pred.shape).float())
            loss_list[1] = comb_loss.item()
        else:
            comb_loss = 0

        # focal atom prediction loss (ligand frontier atoms and protein contact atoms)
        focal_ligand_pred, focal_protein_pred = self.focal_mlp_ligand(h_ctx_ligand), self.focal_mlp_protein(h_ctx_protein)
        focal_loss = self.focal_loss(focal_ligand_pred.reshape(-1), batch['ligand_frontier'].float()) + \
                     self.focal_loss(focal_protein_pred.reshape(-1), batch['protein_contact'].float())
        loss_list[2] = focal_loss.item()

        # protein-ligand distance regression loss
        if len(batch['true_dm']) > 0:
            dist_input = torch.cat([protein_atom_feature[batch['dm_protein_idx']], ligand_atom_feature[batch['dm_ligand_idx']]], dim=-1)
            pred_dist = self.dist_mlp(dist_input)
            dm_target = batch['true_dm'].unsqueeze(-1)
            dm_loss = self.dist_loss(pred_dist, dm_target)
            loss_list[3] = dm_loss.item()
        else:
            dm_loss = 0

        # structure refinement loss: predict displacement "forces" for the ligand atoms from
        # protein-ligand and intra-ligand pairs (only pairs with true distance <= 10.0), apply
        # them, and regress the refined distances onto the ground-truth distances
        if self.config.refinement and len(batch['true_dm']) > 0:
            true_distance_alpha = torch.norm(batch['ligand_context_pos'][batch['sr_ligand_idx']] - batch['protein_pos'][batch['sr_protein_idx']], dim=1)
            true_distance_intra = torch.norm(batch['ligand_context_pos'][batch['sr_ligand_idx0']] - batch['ligand_context_pos'][batch['sr_ligand_idx1']], dim=1)
            input_distance_alpha = ligand_pos[batch['sr_ligand_idx']] - protein_pos[batch['sr_protein_idx']]
            input_distance_intra = ligand_pos[batch['sr_ligand_idx0']] - ligand_pos[batch['sr_ligand_idx1']]
            distance_emb1 = self.encoder.distance_expansion(torch.norm(input_distance_alpha, dim=1))
            distance_emb2 = self.encoder.distance_expansion(torch.norm(input_distance_intra, dim=1))
            input1 = torch.cat([h_ctx_ligand[batch['sr_ligand_idx']], h_ctx_protein[batch['sr_protein_idx']], distance_emb1], dim=-1)[true_distance_alpha <= 10.0]
            input2 = torch.cat([h_ctx_ligand[batch['sr_ligand_idx0']], h_ctx_ligand[batch['sr_ligand_idx1']], distance_emb2], dim=-1)[true_distance_intra <= 10.0]

            norm_dir1 = F.normalize(input_distance_alpha, p=2, dim=1)[true_distance_alpha <= 10.0]
            norm_dir2 = F.normalize(input_distance_intra, p=2, dim=1)[true_distance_intra <= 10.0]
            force1 = scatter_mean(self.refine_protein(input1) * norm_dir1, dim=0, index=batch['sr_ligand_idx'][true_distance_alpha <= 10.0], dim_size=ligand_pos.size(0))
            force2 = scatter_mean(self.refine_ligand(input2) * norm_dir2, dim=0, index=batch['sr_ligand_idx0'][true_distance_intra <= 10.0], dim_size=ligand_pos.size(0))
            new_ligand_pos = deepcopy(ligand_pos)
            new_ligand_pos += force1
            new_ligand_pos += force2
            refine_dist1 = torch.norm(new_ligand_pos[batch['sr_ligand_idx']] - protein_pos[batch['sr_protein_idx']], dim=1)
            refine_dist2 = torch.norm(new_ligand_pos[batch['sr_ligand_idx0']] - new_ligand_pos[batch['sr_ligand_idx1']], dim=1)
            sr_loss = (self.dist_loss(refine_dist1, true_distance_alpha) + self.dist_loss(refine_dist2, true_distance_intra))
            loss_list[5] = sr_loss.item()
        else:
            sr_loss = 0

        # torsion (dihedral) prediction loss
        if len(batch['y_pos']) > 0:
            Hx = dihedral_utils.rotation_matrix_v2(batch['y_pos'])
            xn_pos = torch.matmul(Hx, batch['xn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
            yn_pos = torch.matmul(Hx, batch['yn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
            y_pos = torch.matmul(Hx, batch['y_pos'].unsqueeze(1).permute(0, 2, 1)).squeeze(-1)

            hx, hy = h_ctx_ligand_torsion[batch['ligand_torsion_xy_index'][:, 0]], h_ctx_ligand_torsion[batch['ligand_torsion_xy_index'][:, 1]]
            h_mol = scatter_add(h_ctx_ligand_torsion, dim=0, index=batch['ligand_element_torsion_batch'])
            if self.config.random_alpha:
                rand_dist = torch.distributions.normal.Normal(loc=0, scale=1)
                rand_alpha = rand_dist.sample(hx.shape).to(self.device)
                alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol, rand_alpha], dim=-1))
            else:
                alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol], dim=-1))

            # rotate the x-side neighbours by the predicted angle, then compare the resulting
            # dihedrals with the ground truth (pred_cos pairs with true_cos, pred_sin with true_sin)
            R_alpha = self.build_alpha_rotation(torch.sin(alpha).squeeze(-1), torch.cos(alpha).squeeze(-1))
            xn_pos = torch.matmul(R_alpha, xn_pos.permute(0, 2, 1)).permute(0, 2, 1)

            p_idx, q_idx = torch.cartesian_prod(torch.arange(3), torch.arange(3)).chunk(2, dim=-1)
            p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
            pred_sin, pred_cos = dihedral_utils.batch_dihedrals(xn_pos[:, p_idx],
                                                                torch.zeros_like(y_pos).unsqueeze(1).repeat(1, 9, 1),
                                                                y_pos.unsqueeze(1).repeat(1, 9, 1),
                                                                yn_pos[:, q_idx])
            dihedral_loss = torch.mean(dihedral_utils.von_Mises_loss(batch['true_cos'], pred_cos.reshape(-1), batch['true_sin'], pred_sin.reshape(-1))[batch['dihedral_mask']])
            torsion_loss = -dihedral_loss
            loss_list[4] = torsion_loss.item()
        else:
            torsion_loss = 0

        # total objective; terms whose inputs were empty above stay at 0
        loss = pred_loss + comb_loss + focal_loss + dm_loss + torsion_loss + sr_loss

        return loss, loss_list

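    # build_alpha_rotation constructs, for each dihedral pair, a rotation about the x-axis:
    #   [[1, 0,           0         ],
    #    [0, cos(alpha), -sin(alpha)],
    #    [0, sin(alpha),  cos(alpha)]]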
    def build_alpha_rotation(self, alpha, alpha_cos=None):
        """
        Builds the alpha rotation matrix

        :param alpha: predicted torsion angle alpha (n_dihedral_pairs,); if alpha_cos is given,
                      alpha is interpreted as sin(alpha) and alpha_cos as cos(alpha)
        :param alpha_cos: optional cos(alpha) values (n_dihedral_pairs,)
        :return: alpha rotation matrix (n_dihedral_pairs, 3, 3)

        Note: uses self.device, which is set at the start of get_loss.
        """
        H_alpha = torch.FloatTensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]).repeat(alpha.shape[0], 1, 1).to(self.device)

        if torch.is_tensor(alpha_cos):
            H_alpha[:, 1, 1] = alpha_cos
            H_alpha[:, 1, 2] = -alpha
            H_alpha[:, 2, 1] = alpha
            H_alpha[:, 2, 2] = alpha_cos
        else:
            H_alpha[:, 1, 1] = torch.cos(alpha)
            H_alpha[:, 1, 2] = -torch.sin(alpha)
            H_alpha[:, 2, 1] = torch.sin(alpha)
            H_alpha[:, 2, 2] = torch.cos(alpha)

        return H_alpha
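
# A minimal usage sketch (assumptions: `config` provides the attributes accessed above,
# e.g. config.hidden_channels, config.encoder, config.random_alpha, config.refinement,
# and `vocab` exposes vocab.size(); tensors follow the torch_geometric batching
# convention, with batch_protein / batch_ligand mapping each atom to its graph):
#
#   model = FLAG(config, protein_atom_feature_dim, ligand_atom_feature_dim, vocab)
#   focal_pred, protein_mask, h_ctx = model(protein_pos, protein_atom_feature,
#                                           ligand_pos, ligand_atom_feature,
#                                           batch_protein, batch_ligand)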