# ICLR_FLAG/utils/data.py -- zaixizhang, commit "renew" (10efe81), 6.51 kB
import copy
import torch
import numpy as np
from torch_geometric.data import Data, Batch
# from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
# Field names that should get a companion ``<name>_batch`` assignment vector.
# NOTE(review): not referenced by the code visible in this file; presumably
# passed as ``follow_batch`` to a torch_geometric loader -- confirm at call sites.
FOLLOW_BATCH = [
    'protein_element',
    'ligand_context_element',
    'pos_real',
    'pos_fake',
]
class ProteinLigandData(object):
    """Container holding the fields of one protein-ligand sample.

    Fields are stored as plain instance attributes and can be read or written
    either as attributes (``data.protein_pos``) or with mapping syntax
    (``data['protein_pos']``).  Mapping access is required because
    ``from_protein_ligand_dicts`` fills the instance with
    ``instance[key] = value``; a bare ``object`` subclass does not support
    item assignment and would raise ``TypeError`` there.
    """

    def __init__(self, *args, **kwargs):
        # ``object.__init__`` rejects keyword arguments, so keep them here
        # and store each one as a field instead of forwarding them.
        super().__init__(*args)
        for key, value in kwargs.items():
            self[key] = value

    def __getitem__(self, key):
        """Return the field stored under ``key``; raise ``KeyError`` if absent."""
        return self.__dict__[key]

    def __setitem__(self, key, value):
        """Store ``value`` as the field (and attribute) named ``key``."""
        setattr(self, key, value)

    def __contains__(self, key):
        """Return True if a field named ``key`` has been stored."""
        return key in self.__dict__

    @staticmethod
    def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs):
        """Build an instance from separate protein and ligand dictionaries.

        Args:
            protein_dict: Optional mapping; every entry is stored under
                ``'protein_' + key``.
            ligand_dict: Optional mapping; every entry is stored under
                ``'ligand_' + key``, except ``'moltree'`` which keeps its
                name unprefixed.
            **kwargs: Extra fields stored on the instance under their own names.

        Returns:
            The populated ``ProteinLigandData`` instance.
        """
        instance = ProteinLigandData(**kwargs)

        if protein_dict is not None:
            for key, item in protein_dict.items():
                instance['protein_' + key] = item

        if ligand_dict is not None:
            for key, item in ligand_dict.items():
                # The molecule tree is stored without the ``ligand_`` prefix.
                target = 'moltree' if key == 'moltree' else 'ligand_' + key
                instance[target] = item

        # instance['ligand_nbh_list'] = {i.item():[j.item() for k, j in enumerate(instance.ligand_bond_index[1]) if instance.ligand_bond_index[0, k].item() == i] for i in instance.ligand_bond_index[0]}
        return instance
def batch_from_data_list(data_list):
    """Combine ``data_list`` into a single ``Batch``.

    ``ligand_element`` and ``protein_element`` are passed as ``follow_batch``
    keys to ``Batch.from_data_list``.
    """
    tracked_keys = ['ligand_element', 'protein_element']
    merged = Batch.from_data_list(data_list, follow_batch=tracked_keys)
    return merged
def torchify_dict(data):
    """Return a new dict in which every NumPy array value is a torch tensor.

    Values that are not ``np.ndarray`` instances are carried over unchanged,
    and the input mapping itself is not modified.
    """
    return {
        name: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
        for name, value in data.items()
    }
def _pair_indices(num_rows, num_cols):
    """Return flattened (row, column) index tensors of a ``num_rows`` x ``num_cols`` grid."""
    grid = torch.cartesian_prod(torch.arange(num_rows), torch.arange(num_cols))
    rows, cols = grid.chunk(2, dim=-1)
    return rows.squeeze(-1), cols.squeeze(-1)


def collate_mols(mol_dicts):
    """Collate a list of per-sample dictionaries into one batched dictionary.

    Args:
        mol_dicts: Sequence of dicts, one per sample.  Each must provide the
            keys read below (concatenated tensors, torsion tensors,
            ``dm_protein_idx`` / ``dm_ligand_idx``,
            ``protein_alpha_carbon_index`` and, when ``cand_labels`` is
            non-empty, ``cand_mols``).

    Returns:
        dict with:
          * the concatenated per-sample tensors (same key names);
          * ``xn_pos``, ``yn_pos``, ``ligand_torsion_xy_index``, ``y_pos``
            stacked along a new leading axis over the non-empty samples
            (an empty tensor when every sample is empty);
          * ``<key>_batch`` sample-assignment vectors for ``protein_element``,
            ``ligand_context_element``, ``current_atoms`` and
            ``ligand_element_torsion``;
          * ``dm_*`` and ``sr_*`` index tensors shifted into batch-wide
            numbering (present only if some sample has a non-empty ``true_dm``);
          * ``cand_mols`` as a ``Batch`` when any sample has candidate labels.

    Notes:
        Per-sample counts are taken with ``bincount(minlength=...)`` so that
        trailing samples contributing zero atoms still get an entry; without
        it the offset table is too short and indexing / ``repeat_interleave``
        fails for such batches.
    """
    batch_size = len(mol_dicts)
    data_batch = {}

    # Tensors that are simply concatenated along dim 0 across samples.
    concat_keys = (
        'protein_pos', 'protein_atom_feature', 'ligand_context_pos', 'ligand_context_feature_full',
        'ligand_frontier', 'num_atoms', 'next_wid', 'current_wid', 'current_atoms', 'cand_labels',
        'ligand_pos_torsion', 'ligand_feature_torsion', 'true_sin', 'true_cos', 'true_three_hop',
        'dihedral_mask', 'protein_contact', 'true_dm', 'alpha_carbon_indicator',
    )
    for key in concat_keys:
        data_batch[key] = torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0)

    # Tensors that get a new leading axis; samples with an empty entry are skipped.
    for key in ('xn_pos', 'yn_pos', 'ligand_torsion_xy_index', 'y_pos'):
        stacked = [mol_dict[key].unsqueeze(0) for mol_dict in mol_dicts if len(mol_dict[key]) > 0]
        data_batch[key] = torch.cat(stacked, dim=0) if len(stacked) > 0 else torch.tensor([])

    # Sample-assignment ("follow batch") vectors.
    sample_ids = torch.arange(batch_size)
    for key in ('protein_element', 'ligand_context_element', 'current_atoms'):
        counts = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts])
        data_batch[key + '_batch'] = torch.repeat_interleave(sample_ids, counts)

    # Torsion elements are numbered over the samples that actually have some.
    torsion_counts = torch.tensor([
        len(mol_dict['ligand_element_torsion'])
        for mol_dict in mol_dicts
        if len(mol_dict['ligand_element_torsion']) > 0
    ])
    if len(torsion_counts) > 0:
        data_batch['ligand_element_torsion_batch'] = torch.repeat_interleave(
            torch.arange(len(torsion_counts)), torsion_counts)
    else:
        data_batch['ligand_element_torsion_batch'] = torch.tensor([])

    # Start offset of every sample's protein / ligand-context atoms in the batch.
    # ``minlength`` keeps one entry per sample even when trailing samples are empty.
    zero = torch.tensor([0])
    protein_counts = data_batch['protein_element_batch'].bincount(minlength=batch_size)
    ligand_counts = data_batch['ligand_context_element_batch'].bincount(minlength=batch_size)
    protein_offsets = torch.cat([zero, torch.cumsum(protein_counts, dim=0)])
    ligand_offsets = torch.cat([zero, torch.cumsum(ligand_counts, dim=0)])

    # Distance-matrix pairs: every combination of 4 protein slots x 2 ligand slots.
    dm_rows, dm_cols = _pair_indices(4, 2)

    dm_ligand_idx, dm_protein_idx = [], []
    sr_ligand_idx, sr_protein_idx = [], []
    sr_ligand_idx0, sr_ligand_idx1 = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) == 0:
            continue
        protein_shift, ligand_shift = protein_offsets[i], ligand_offsets[i]

        # distance matrix prediction
        dm_protein_idx.append(mol_dict['dm_protein_idx'][dm_rows] + protein_shift)
        dm_ligand_idx.append(mol_dict['dm_ligand_idx'][dm_cols] + ligand_shift)

        # structure refinement (alpha carbon - ligand atom)
        num_ligand = len(mol_dict['ligand_context_pos'])
        alpha_carbons = mol_dict['protein_alpha_carbon_index']
        lig_rows, ca_cols = _pair_indices(num_ligand, len(alpha_carbons))
        sr_ligand_idx.append(lig_rows + ligand_shift)
        sr_protein_idx.append(alpha_carbons[ca_cols] + protein_shift)

        # structure refinement (ligand atom - ligand atom)
        lig_a, lig_b = _pair_indices(num_ligand, num_ligand)
        sr_ligand_idx0.append(lig_a + ligand_shift)
        sr_ligand_idx1.append(lig_b + ligand_shift)

    if len(dm_ligand_idx) > 0:
        data_batch['dm_ligand_idx'] = torch.cat(dm_ligand_idx)
        data_batch['dm_protein_idx'] = torch.cat(dm_protein_idx)
    if len(sr_ligand_idx) > 0:
        data_batch['sr_ligand_idx'] = torch.cat(sr_ligand_idx).long()
        data_batch['sr_protein_idx'] = torch.cat(sr_protein_idx).long()
    # Guard on the list actually being concatenated (was ``ligand_idx``).
    if len(sr_ligand_idx0) > 0:
        data_batch['sr_ligand_idx0'] = torch.cat(sr_ligand_idx0).long()
        data_batch['sr_ligand_idx1'] = torch.cat(sr_ligand_idx1).long()

    # Shift torsion x/y indices by the number of torsion elements in earlier samples.
    if len(data_batch['y_pos']) > 0:
        torsion_offsets = torch.cat([torch.tensor([0]), torch.cumsum(torsion_counts, dim=0)])[:-1]
        data_batch['ligand_torsion_xy_index'] += torsion_offsets.unsqueeze(1)

    # Shift ``current_atoms`` by the cumulative ``num_atoms`` of earlier samples.
    atom_offsets = torch.cat([torch.tensor([0]), torch.cumsum(data_batch['num_atoms'], dim=0)])[:-1]
    current_counts = data_batch['current_atoms_batch'].bincount(minlength=len(atom_offsets))
    data_batch['current_atoms'] += torch.repeat_interleave(atom_offsets, current_counts)

    # cand mols: torch geometric Data
    cand_mol_list = []
    for mol_dict in mol_dicts:
        if len(mol_dict['cand_labels']) > 0:
            cand_mol_list.extend(mol_dict['cand_mols'])
    if len(cand_mol_list) > 0:
        data_batch['cand_mols'] = Batch.from_data_list(cand_mol_list)

    return data_batch