|
import copy |
|
import torch |
|
import numpy as np |
|
from torch_geometric.data import Data, Batch |
|
|
|
from torch.utils.data import Dataset |
|
|
|
# Attribute names for which per-element batch-assignment vectors are wanted.
# NOTE(review): nothing in the visible part of this file reads this constant
# (batch_from_data_list passes its own list) -- confirm where it is consumed.
FOLLOW_BATCH = [
    'protein_element',
    'ligand_context_element',
    'pos_real',
    'pos_fake',
]
|
|
|
|
|
class ProteinLigandData(object):
    """Field container for one protein-ligand complex.

    Fields live in the instance ``__dict__`` and can be accessed either as
    attributes (``data.protein_pos``) or with item syntax
    (``data['protein_pos']``).

    NOTE(review): this file imports ``Data`` from torch_geometric and
    ``batch_from_data_list`` hands its items to ``Batch.from_data_list``; if
    instances of this class are meant to be batched that way, the intended
    base class is presumably ``Data`` -- confirm against callers.
    """

    def __init__(self, *args, **kwargs):
        """Create a container, storing every keyword argument as a field.

        Positional arguments are still forwarded to ``object.__init__`` and
        therefore raise ``TypeError``, exactly as before.
        """
        # ``object.__init__`` accepts no extra arguments, so forwarding
        # ``**kwargs`` to it always raised TypeError; keep them as fields.
        super().__init__(*args)
        for key, item in kwargs.items():
            self[key] = item

    def __getitem__(self, key):
        """Return the field stored under ``key``; raise ``KeyError`` if absent."""
        return vars(self)[key]

    def __setitem__(self, key, item):
        """Store ``item`` under ``key`` (also visible as an attribute)."""
        vars(self)[key] = item

    def __contains__(self, key):
        """Return True if a field named ``key`` has been stored."""
        return key in vars(self)

    @staticmethod
    def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs):
        """Build an instance from separate protein and ligand dictionaries.

        Args:
            protein_dict: optional mapping; each entry is stored under
                ``'protein_' + key``.
            ligand_dict: optional mapping; each entry is stored under
                ``'ligand_' + key``, except ``'moltree'`` which keeps its
                name unprefixed.
            **kwargs: extra fields passed to the constructor.

        Returns:
            The populated ``ProteinLigandData`` instance.
        """
        instance = ProteinLigandData(**kwargs)

        if protein_dict is not None:
            for key, item in protein_dict.items():
                instance['protein_' + key] = item

        if ligand_dict is not None:
            for key, item in ligand_dict.items():
                if key == 'moltree':
                    # The molecule tree is stored without the ligand prefix.
                    instance['moltree'] = item
                else:
                    instance['ligand_' + key] = item

        return instance
|
|
|
|
|
def batch_from_data_list(data_list):
    """Merge ``data_list`` into a single torch_geometric ``Batch``.

    ``follow_batch`` is set to the ligand and protein element attributes, so
    the batch is asked to track batch assignment for those two keys.
    """
    tracked_keys = ['ligand_element', 'protein_element']
    merged = Batch.from_data_list(data_list, follow_batch=tracked_keys)
    return merged
|
|
|
|
|
def torchify_dict(data):
    """Return a new dict in which every NumPy array of ``data`` is a tensor.

    Values that are ``np.ndarray`` instances are converted with
    ``torch.from_numpy``; all other values are passed through unchanged.
    Keys and their order are preserved and ``data`` itself is not modified.
    """
    return {
        name: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
        for name, value in data.items()
    }
|
|
|
|
|
def collate_mols(mol_dicts):
    """Collate a list of per-molecule dicts into one dict of batched tensors.

    Args:
        mol_dicts: non-empty list of dicts. Each dict must provide every key
            read below (the concatenated keys, the four torsion keys,
            ``protein_element``, ``ligand_context_element``,
            ``ligand_element_torsion``, ``dm_protein_idx``, ``dm_ligand_idx``,
            ``protein_alpha_carbon_index`` and, when ``cand_labels`` is
            non-empty, ``cand_mols``).

    Returns:
        dict with the concatenated tensors, ``*_batch`` assignment vectors,
        index tensors shifted into batch-global numbering, and (if any
        candidates exist) ``cand_mols`` as a torch_geometric ``Batch``.
    """
    data_batch = {}
    batch_size = len(mol_dicts)

    # Per-molecule tensors that are simply concatenated along dim 0.
    for key in ['protein_pos', 'protein_atom_feature', 'ligand_context_pos', 'ligand_context_feature_full',
                'ligand_frontier', 'num_atoms', 'next_wid', 'current_wid', 'current_atoms', 'cand_labels',
                'ligand_pos_torsion', 'ligand_feature_torsion', 'true_sin', 'true_cos', 'true_three_hop',
                'dihedral_mask', 'protein_contact', 'true_dm', 'alpha_carbon_indicator']:
        data_batch[key] = torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0)

    # Torsion-related entries: molecules with an empty entry are skipped and
    # the remaining ones are stacked along a new leading dimension.
    for key in ['xn_pos', 'yn_pos', 'ligand_torsion_xy_index', 'y_pos']:
        cat_list = [mol_dict[key].unsqueeze(0) for mol_dict in mol_dicts if len(mol_dict[key]) > 0]
        if cat_list:
            data_batch[key] = torch.cat(cat_list, dim=0)
        else:
            data_batch[key] = torch.tensor([])

    # Batch-assignment vectors: entry j holds the index of the molecule that
    # element j came from.
    for key in ['protein_element', 'ligand_context_element', 'current_atoms']:
        repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts])
        data_batch[key + '_batch'] = torch.repeat_interleave(torch.arange(batch_size), repeats)

    # Same for torsion elements, but numbered only over the molecules that
    # actually have torsion elements (matching the filtering above).
    for key in ['ligand_element_torsion']:
        repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts if len(mol_dict[key]) > 0])
        if len(repeats) > 0:
            data_batch[key + '_batch'] = torch.repeat_interleave(torch.arange(len(repeats)), repeats)
        else:
            data_batch[key + '_batch'] = torch.tensor([])

    # Start offset of every molecule inside the concatenated protein / ligand
    # context tensors. ``minlength`` keeps one count per molecule even when
    # trailing molecules contribute no elements; without it the offsets
    # tensor would be too short to index with ``i`` below.
    zero = torch.tensor([0])
    protein_counts = data_batch['protein_element_batch'].bincount(minlength=batch_size)
    ligand_counts = data_batch['ligand_context_element_batch'].bincount(minlength=batch_size)
    protein_offsets = torch.cat([zero, torch.cumsum(protein_counts, dim=0)])
    ligand_offsets = torch.cat([zero, torch.cumsum(ligand_counts, dim=0)])

    # 'dm' index pairs: all 4 x 2 combinations of the per-molecule
    # dm_protein_idx / dm_ligand_idx entries, shifted to batch-global indices.
    # NOTE(review): 'dm' presumably means distance matrix -- confirm.
    dm_p, dm_q = torch.cartesian_prod(torch.arange(4), torch.arange(2)).chunk(2, dim=-1)
    dm_p, dm_q = dm_p.squeeze(-1), dm_q.squeeze(-1)
    ligand_idx, protein_idx = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            protein_idx.append(mol_dict['dm_protein_idx'][dm_p] + protein_offsets[i])
            ligand_idx.append(mol_dict['dm_ligand_idx'][dm_q] + ligand_offsets[i])
    if ligand_idx:
        data_batch['dm_ligand_idx'] = torch.cat(ligand_idx)
        data_batch['dm_protein_idx'] = torch.cat(protein_idx)

    # Every (ligand context atom, protein alpha-carbon index) pair of each
    # molecule that has a non-empty 'true_dm'.
    sr_ligand_idx, sr_protein_idx = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            num_ligand = len(mol_dict['ligand_context_pos'])
            alpha_index = mol_dict['protein_alpha_carbon_index']
            lig_i, ca_i = torch.cartesian_prod(
                torch.arange(num_ligand), torch.arange(len(alpha_index))).chunk(2, dim=-1)
            lig_i, ca_i = lig_i.squeeze(-1), ca_i.squeeze(-1)
            sr_ligand_idx.append(lig_i + ligand_offsets[i])
            sr_protein_idx.append(alpha_index[ca_i] + protein_offsets[i])
    if sr_ligand_idx:
        data_batch['sr_ligand_idx'] = torch.cat(sr_ligand_idx).long()
        data_batch['sr_protein_idx'] = torch.cat(sr_protein_idx).long()

    # Every ordered (ligand atom, ligand atom) pair within each such molecule.
    sr_ligand_idx0, sr_ligand_idx1 = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            num_ligand = len(mol_dict['ligand_context_pos'])
            row_i, col_i = torch.cartesian_prod(
                torch.arange(num_ligand), torch.arange(num_ligand)).chunk(2, dim=-1)
            row_i, col_i = row_i.squeeze(-1), col_i.squeeze(-1)
            sr_ligand_idx0.append(row_i + ligand_offsets[i])
            sr_ligand_idx1.append(col_i + ligand_offsets[i])
    # Guard on the list that is actually concatenated (previously this
    # tested the unrelated ``ligand_idx`` list).
    if sr_ligand_idx0:
        data_batch['sr_ligand_idx0'] = torch.cat(sr_ligand_idx0).long()
        data_batch['sr_ligand_idx1'] = torch.cat(sr_ligand_idx1).long()

    # Shift the stacked torsion x/y indices by the number of torsion elements
    # contributed by the preceding molecules.
    if len(data_batch['y_pos']) > 0:
        repeats = torch.tensor([len(mol_dict['ligand_element_torsion']) for mol_dict in mol_dicts
                                if len(mol_dict['ligand_element_torsion']) > 0])
        offsets = torch.cat([torch.tensor([0]), torch.cumsum(repeats, dim=0)])[:-1]
        data_batch['ligand_torsion_xy_index'] += offsets.unsqueeze(1)

    # Shift 'current_atoms' by the cumulative 'num_atoms' of the preceding
    # molecules. ``minlength`` keeps the per-molecule counts aligned with
    # ``offsets1`` when trailing molecules have no current atoms.
    offsets1 = torch.cat([torch.tensor([0]), torch.cumsum(data_batch['num_atoms'], dim=0)])[:-1]
    current_counts = data_batch['current_atoms_batch'].bincount(minlength=batch_size)
    data_batch['current_atoms'] += torch.repeat_interleave(offsets1, current_counts)

    # Candidate molecules are collected only from entries that have labels
    # and are batched with torch_geometric.
    cand_mol_list = []
    for data in mol_dicts:
        if len(data['cand_labels']) > 0:
            cand_mol_list.extend(data['cand_mols'])
    if cand_mol_list:
        data_batch['cand_mols'] = Batch.from_data_list(cand_mol_list)
    return data_batch
|
|
|
|