File size: 4,636 Bytes
62a2f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import json
import os.path as osp
from multiprocessing import Pool
import numpy as np
import pandas as pd
import torch
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
from rdkit import Chem
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm
class Molecule3D(InMemoryDataset):
def __init__(
self,
root,
transform=None,
pre_transform=None,
pre_filter=None,
**kwargs,
):
self.root = root
super(Molecule3D, self).__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def processed_file_names(self):
return 'molecule3d.pt'
def process(self):
data_list = []
sdf_paths = [
osp.join(self.raw_dir, 'combined_mols_0_to_1000000.sdf'),
osp.join(self.raw_dir, 'combined_mols_1000000_to_2000000.sdf'),
osp.join(self.raw_dir, 'combined_mols_2000000_to_3000000.sdf'),
osp.join(self.raw_dir, 'combined_mols_3000000_to_3899647.sdf')
]
suppl_list = [Chem.SDMolSupplier(p, removeHs=False, sanitize=True) for p in sdf_paths]
target_path = osp.join(self.raw_dir, 'properties.csv')
target_df = pd.read_csv(target_path)
abs_idx = -1
for i, suppl in enumerate(suppl_list):
with Pool(processes=120) as pool:
iter = pool.imap(self.mol2graph, suppl)
for j, graph in tqdm(enumerate(iter), total=len(suppl)):
abs_idx += 1
data = Data()
data.__num_nodes__ = int(graph['num_nodes'])
# Required by GNNs
data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
data.y = torch.FloatTensor([target_df.iloc[abs_idx, 6]]).unsqueeze(1)
# Required by ViSNet
data.pos = torch.tensor(graph['position'], dtype=torch.float32)
data.z = torch.tensor(graph['z'], dtype=torch.int64)
data_list.append(data)
torch.save(self.collate(data_list), self.processed_paths[0])
def get_idx_split(self, split_mode='random'):
assert split_mode in ['random', 'scaffold']
split_dict = json.load(open(osp.join(self.raw_dir, f'{split_mode}_split_inds.json'), 'r'))
for key, values in split_dict.items():
split_dict[key] = torch.tensor(values)
return split_dict
def mol2graph(self, mol):
# atoms
atom_features_list = []
for atom in mol.GetAtoms():
atom_features_list.append(atom_to_feature_vector(atom))
x = np.array(atom_features_list, dtype = np.int64)
coords = mol.GetConformer().GetPositions()
z = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
# bonds
num_bond_features = 3 # bond type, bond stereo, is_conjugated
if len(mol.GetBonds()) > 0: # mol has bonds
edges_list = []
edge_features_list = []
for bond in mol.GetBonds():
i = bond.GetBeginAtomIdx()
j = bond.GetEndAtomIdx()
edge_feature = bond_to_feature_vector(bond)
# add edges in both directions
edges_list.append((i, j))
edge_features_list.append(edge_feature)
edges_list.append((j, i))
edge_features_list.append(edge_feature)
# data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
edge_index = np.array(edges_list, dtype = np.int64).T
# data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
edge_attr = np.array(edge_features_list, dtype = np.int64)
else: # mol has no bonds
edge_index = np.empty((2, 0), dtype = np.int64)
edge_attr = np.empty((0, num_bond_features), dtype = np.int64)
graph = dict()
graph['edge_index'] = edge_index
graph['edge_feat'] = edge_attr
graph['node_feat'] = x
graph['num_nodes'] = len(x)
graph['position'] = coords
graph['z'] = z
return graph
|