File size: 4,636 Bytes
62a2f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import json
import os.path as osp
from multiprocessing import Pool

import numpy as np
import pandas as pd
import torch
from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector
from rdkit import Chem
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm


class Molecule3D(InMemoryDataset):
    """PyG in-memory dataset built from the Molecule3D SDF dumps.

    On first construction ``process`` reads the four ``combined_mols_*.sdf``
    files and ``properties.csv`` from ``raw_dir``. It converts every molecule
    into a :class:`torch_geometric.data.Data` graph with OGB atom/bond
    features, 3D coordinates and atomic numbers, and saves the collated
    result to ``processed_dir/molecule3d.pt``. Later constructions just load
    that file.

    Args:
        root: Dataset root directory.
        transform: Optional callable applied to each ``Data`` on access.
        pre_transform: Optional callable applied once to each ``Data``
            before saving.
        pre_filter: Optional predicate. ``Data`` objects for which it
            returns ``False`` are dropped before saving. Note that filtering
            shifts dataset indices relative to the split files read by
            :meth:`get_idx_split`.
        num_workers: Number of worker processes used by ``process`` to
            featurize molecules. Defaults to 120, the value previously
            hard-coded.
        **kwargs: Accepted for compatibility and ignored.
    """

    # Raw SDF shards, processed in this order. Graph i is aligned with row i
    # of properties.csv.
    _SDF_FILE_NAMES = (
        'combined_mols_0_to_1000000.sdf',
        'combined_mols_1000000_to_2000000.sdf',
        'combined_mols_2000000_to_3000000.sdf',
        'combined_mols_3000000_to_3899647.sdf',
    )
    # Positional column of properties.csv used as the regression target.
    # NOTE(review): which property this is cannot be determined here; confirm
    # against the CSV header.
    _TARGET_COLUMN = 6

    def __init__(
        self,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        num_workers=120,
        **kwargs,
    ):
        # Must be set before super().__init__(), which may invoke process().
        self.num_workers = num_workers
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True),
        # which may reject pickled PyG objects; confirm against installed
        # versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'molecule3d.pt'

    def process(self):
        """Featurize every molecule in the raw SDF shards and save the result.

        Applies ``pre_filter`` and ``pre_transform`` (if given) before
        collating and writing to ``processed_paths[0]``.
        """
        sdf_paths = [osp.join(self.raw_dir, name) for name in self._SDF_FILE_NAMES]
        # Open every supplier up front so a missing shard fails immediately,
        # not after hours of processing the earlier ones.
        suppliers = [
            Chem.SDMolSupplier(path, removeHs=False, sanitize=True)
            for path in sdf_paths
        ]

        target_df = pd.read_csv(osp.join(self.raw_dir, 'properties.csv'))
        # Extract the target column once; a per-row .iloc lookup is slow.
        targets = target_df.iloc[:, self._TARGET_COLUMN].to_numpy()

        data_list = []
        abs_idx = 0  # global molecule index across all shards
        # A single pool is reused for all shards. imap preserves input order,
        # which keeps graphs aligned with the rows of `targets`.
        with Pool(processes=self.num_workers) as pool:
            for suppl in suppliers:
                graphs = pool.imap(self.mol2graph, suppl)
                for graph in tqdm(graphs, total=len(suppl)):
                    data_list.append(self._graph_to_data(graph, targets[abs_idx]))
                    abs_idx += 1

        # Standard InMemoryDataset hooks. These were previously accepted
        # but silently ignored.
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        torch.save(self.collate(data_list), self.processed_paths[0])

    @staticmethod
    def _graph_to_data(graph, target):
        """Build a ``Data`` object from a ``mol2graph`` dict and a scalar target."""
        data = Data()
        # Public setter (works in PyG 1.x and 2.x). In PyG 2.x the legacy
        # `__num_nodes__` attribute is kept out of the data store.
        data.num_nodes = int(graph['num_nodes'])

        # Required by GNNs
        data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        # Shape (1, 1) float32, the same as FloatTensor([t]).unsqueeze(1).
        data.y = torch.tensor([[target]], dtype=torch.float32)

        # Required by ViSNet
        data.pos = torch.tensor(graph['position'], dtype=torch.float32)
        data.z = torch.tensor(graph['z'], dtype=torch.int64)
        return data

    def get_idx_split(self, split_mode='random'):
        """Load train/valid/test indices from ``raw_dir``.

        Args:
            split_mode: ``'random'`` or ``'scaffold'``. Selects the file
                ``{split_mode}_split_inds.json``.

        Returns:
            A dict mapping each key in the JSON file to a tensor of indices.

        Raises:
            ValueError: If ``split_mode`` is not a supported value.
        """
        if split_mode not in ('random', 'scaffold'):
            raise ValueError(
                f"split_mode must be 'random' or 'scaffold', got {split_mode!r}"
            )
        path = osp.join(self.raw_dir, f'{split_mode}_split_inds.json')
        with open(path, 'r') as f:
            split_dict = json.load(f)
        return {key: torch.tensor(values) for key, values in split_dict.items()}

    def mol2graph(self, mol):
        """Convert an RDKit molecule into a dict of numpy graph arrays.

        Runs inside pool workers, so it only returns plain picklable data.

        Args:
            mol: An RDKit ``Mol`` with a conformer, as yielded by
                ``Chem.SDMolSupplier``.

        Returns:
            A dict with the following keys:

            - ``edge_index``: int64 array of shape ``[2, num_edges]``. Each
              bond appears in both directions.
            - ``edge_feat``: int64 array of shape
              ``[num_edges, num_bond_features]``.
            - ``node_feat``: int64 array of per-atom OGB feature vectors.
            - ``num_nodes``: the atom count.
            - ``position``: coordinates from
              ``mol.GetConformer().GetPositions()``.
            - ``z``: a list of atomic numbers.

        Raises:
            ValueError: If ``mol`` is ``None``. ``SDMolSupplier`` yields
                ``None`` for records it fails to parse or sanitize.
        """
        if mol is None:
            raise ValueError('RDKit failed to parse/sanitize a molecule in the SDF input')

        atoms = list(mol.GetAtoms())
        x = np.array([atom_to_feature_vector(atom) for atom in atoms], dtype=np.int64)
        z = [atom.GetAtomicNum() for atom in atoms]
        coords = mol.GetConformer().GetPositions()

        num_bond_features = 3  # bond type, bond stereo, is_conjugated
        if mol.GetNumBonds() > 0:
            edges_list = []
            edge_features_list = []
            for bond in mol.GetBonds():
                i = bond.GetBeginAtomIdx()
                j = bond.GetEndAtomIdx()
                edge_feature = bond_to_feature_vector(bond)
                # Add the edge in both directions (undirected graph in COO form).
                edges_list.append((i, j))
                edge_features_list.append(edge_feature)
                edges_list.append((j, i))
                edge_features_list.append(edge_feature)

            # Graph connectivity in COO format with shape [2, num_edges].
            edge_index = np.array(edges_list, dtype=np.int64).T
            # Edge feature matrix with shape [num_edges, num_edge_features].
            edge_attr = np.array(edge_features_list, dtype=np.int64)
        else:  # molecule has no bonds
            edge_index = np.empty((2, 0), dtype=np.int64)
            edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

        return {
            'edge_index': edge_index,
            'edge_feat': edge_attr,
            'node_feat': x,
            'num_nodes': len(x),
            'position': coords,
            'z': z,
        }