import pickle
import os.path as osp
import re

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit import RDLogger
from torch_geometric.data import Data, InMemoryDataset

RDLogger.DisableLog('rdApp.*')


class DruggenDataset(InMemoryDataset):

    def __init__(self, root, dataset_file, raw_files, max_atom, features,
                 transform=None, pre_transform=None, pre_filter=None):
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features
        super().__init__(root, transform, pre_transform, pre_filter)
        # Load the processed dataset from the processed folder (where process() saves it).
        self.data, self.slices = torch.load(osp.join(self.processed_dir, dataset_file))

    @property
    def raw_file_names(self):
        return self.raw_files

    @property
    def processed_file_names(self):
        '''
        Return the processed file name. If this file is not present in the
        processed folder, it will be generated automatically by the process
        function of this class.
        '''
        return self.dataset_file

    def _generate_encoders_decoders(self, data):
        """
        Generates the encoders and decoders for atoms and bonds and pickles them
        to the DrugGEN/data/encoders and DrugGEN/data/decoders folders.
        """
        self.data = data
        print('Creating atoms encoder and decoder..')

        atom_labels = set()
        self.max_atom_size_in_data = 0
        for smile in data:
            mol = Chem.MolFromSmiles(smile)
            atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
            self.max_atom_size_in_data = max(self.max_atom_size_in_data, mol.GetNumAtoms())

        atom_labels.add(0)  # add the PAD symbol (atomic number 0, used for padded/unknown atoms)
        atom_labels = sorted(atom_labels)  # turn the set into a sorted list

        self.atom_encoder_m = {l: i for i, l in enumerate(atom_labels)}
        self.atom_decoder_m = {i: l for i, l in enumerate(atom_labels)}
        self.atom_num_types = len(atom_labels)
        print(f'Created atoms encoder and decoder with {self.atom_num_types - 1} atom types and 1 PAD symbol!')
        print("atom_labels", atom_labels)

        print('Creating bonds encoder and decoder..')
        bond_labels = [
            Chem.rdchem.BondType.ZERO,
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC,
        ]
        print("bond labels", bond_labels)

        self.bond_encoder_m = {l: i for i, l in enumerate(bond_labels)}
        self.bond_decoder_m = {i: l for i, l in enumerate(bond_labels)}
        self.bond_num_types = len(bond_labels)
        print(f'Created bonds encoder and decoder with {self.bond_num_types - 1} bond types and 1 PAD symbol!')

        # Persist the encoders/decoders so that generated matrices can be decoded later.
        with open("DrugGEN/data/encoders/" + "atom_" + self.dataset_name + ".pkl", "wb") as atom_encoders:
            pickle.dump(self.atom_encoder_m, atom_encoders)

        with open("DrugGEN/data/decoders/" + "atom_" + self.dataset_name + ".pkl", "wb") as atom_decoders:
            pickle.dump(self.atom_decoder_m, atom_decoders)

        with open("DrugGEN/data/encoders/" + "bond_" + self.dataset_name + ".pkl", "wb") as bond_encoders:
            pickle.dump(self.bond_encoder_m, bond_encoders)

        with open("DrugGEN/data/decoders/" + "bond_" + self.dataset_name + ".pkl", "wb") as bond_decoders:
            pickle.dump(self.bond_decoder_m, bond_decoders)
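    # Illustrative sketch (toy input, not part of the original pipeline): for a SMILES
    # list such as ["CCO", "c1ccccc1"], the atom labels collected above would be the
    # sorted set {0, 6, 8}, so the pickled atom encoder/decoder would read
    #     atom_encoder_m = {0: 0, 6: 1, 8: 2}
    #     atom_decoder_m = {0: 0, 1: 6, 2: 8}
    # with index 0 acting as the PAD symbol, while the bond encoder always maps the
    # five fixed RDKit bond types (ZERO, SINGLE, DOUBLE, TRIPLE, AROMATIC) to 0..4.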
    def generate_adjacency_matrix(self, mol, connected=True, max_length=None):
        """
        Generates the adjacency matrix for a molecule, with bond-type indices as entries.

        Args:
            mol (Mol): The RDKit molecule object.
            connected (bool): Whether to require the molecule to be connected
                (every atom bonded to at least one other atom). Defaults to True.
            max_length (int): The size of the (square) adjacency matrix. Defaults
                to the number of atoms in the molecule.

        Returns:
            numpy.ndarray or None: The adjacency matrix if connected is True and
            every atom has a degree greater than 0, otherwise None.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        A = np.zeros(shape=(max_length, max_length))

        begin = [b.GetBeginAtomIdx() for b in mol.GetBonds()]
        end = [b.GetEndAtomIdx() for b in mol.GetBonds()]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]

        A[begin, end] = bond_type
        A[end, begin] = bond_type

        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
        return A if connected and (degree > 0).all() else None

    def generate_node_features(self, mol, max_length=None):
        """
        Generates the node features for a molecule.

        Args:
            mol (Mol): The RDKit molecule object.
            max_length (int): The length of the feature vector. Defaults to the
                number of atoms in the molecule.

        Returns:
            numpy.ndarray: The atom-label indices of the molecule, zero-padded to max_length.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
                        + [0] * (max_length - mol.GetNumAtoms()))

    def generate_additional_features(self, mol, max_length=None):
        """
        Generates additional per-atom features (degree, valence, hybridization,
        hydrogen counts, ring membership, etc.) for a molecule.

        Args:
            mol (Mol): The RDKit molecule object.
            max_length (int): The number of rows of the feature matrix. Defaults
                to the number of atoms in the molecule.

        Returns:
            numpy.ndarray: The additional feature matrix, zero-padded to max_length rows.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                              *[a.GetExplicitValence() == i for i in range(9)],
                              *[int(a.GetHybridization()) == i for i in range(1, 7)],
                              *[a.GetImplicitValence() == i for i in range(9)],
                              a.GetIsAromatic(),
                              a.GetNoImplicit(),
                              *[a.GetNumExplicitHs() == i for i in range(5)],
                              *[a.GetNumImplicitHs() == i for i in range(5)],
                              *[a.GetNumRadicalElectrons() == i for i in range(5)],
                              a.IsInRing(),
                              *[a.IsInRingSize(i) for i in range(2, 9)]]
                             for a in mol.GetAtoms()], dtype=np.int32)
        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name):
        '''
        Loads the atom or bond decoder pickled for this dataset.
        '''
        with open("DrugGEN/data/decoders/" + dictionary_name + "_" + self.dataset_name + '.pkl', 'rb') as f:
            return pickle.load(f)

    def drugs_decoder_load(self, dictionary_name):
        '''
        Loads a decoder pickled without a dataset suffix.
        '''
        with open("DrugGEN/data/decoders/" + dictionary_name + '.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol(self, node_labels, edge_labels, strict=True):
        '''
        Converts a vector of atom labels and a matrix of bond labels back into an
        RDKit molecule, using the decoders pickled for this dataset. Returns None
        if strict is True and the reconstructed molecule fails sanitization.
        '''
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')
        atom_decoders = self.decoder_load("atom")
        bond_decoders = self.decoder_load("bond")

        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))

        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])

        mol = self.correct_mol(mol)
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                mol = None
        return mol
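    # Illustrative sketch (toy input, an assumption for clarity): with the example
    # atom decoder {0: 0, 1: 6, 2: 8} from above, node_labels = [1, 1, 2] and a
    # symmetric edge_labels matrix whose (1, 0) and (2, 1) entries equal 1 (SINGLE)
    # would be rebuilt by matrices2mol as the chain C-C-O, i.e. the SMILES "CCO".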
    def drug_decoder_load(self, dictionary_name):
        '''
        Loads the atom or bond decoder of the drug (akt_train) dataset.
        '''
        with open("DrugGEN/data/decoders/" + dictionary_name + "_" + "akt_train" + '.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol_drugs(self, node_labels, edge_labels, strict=True):
        '''
        Same as matrices2mol, but uses the decoders of the drug (akt_train) dataset.
        '''
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')
        atom_decoders = self.drug_decoder_load("atom")
        bond_decoders = self.drug_decoder_load("bond")

        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))

        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])

        mol = self.correct_mol(mol)
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                mol = None
        return mol

    def check_valency(self, mol):
        """
        Checks that no atom in the molecule has exceeded its allowed valency.

        Returns:
            (True, None) if there are no valency issues, otherwise
            (False, [atom_id, valence]) parsed from the RDKit error message.
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            e = str(e)
            p = e.find('#')
            e_sub = e[p:]
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self, x):
        """
        Repeatedly removes the highest-order bond of an atom that exceeds its
        valency until the molecule passes the valency check.
        """
        mol = x
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            else:
                assert len(atomid_valence) == 2
                idx = atomid_valence[0]
                v = atomid_valence[1]
                queue = []
                # Collect the offending atom's bonds and sort them by bond order, highest first.
                for b in mol.GetAtomWithIdx(idx).GetBonds():
                    queue.append((b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
                queue.sort(key=lambda tup: tup[1], reverse=True)
                if len(queue) > 0:
                    start = queue[0][2]
                    end = queue[0][3]
                    mol.RemoveBond(start, end)
        return mol

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        out = torch.zeros(list(labels.size()) + [dim])
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out.float()

    def process(self, size=None):
        '''
        Process the dataset. This function only runs if the file returned by
        processed_file_names does not already exist in the processed folder.
        '''
        # Read the raw SMILES list and build the atom/bond encoders and decoders.
        smiles = pd.read_csv(self.raw_files, header=None)[0].tolist()
        self._generate_encoders_decoders(smiles)

        data_list = []

        max_length = min(self.max_atom_size_in_data, self.max_atom)
        self.m_dim = len(self.atom_decoder_m)

        for smile in tqdm(smiles, desc='Processing chembl dataset', total=len(smiles)):
            mol = Chem.MolFromSmiles(smile)

            # Filter out molecules with more atoms than the allowed maximum.
            if mol.GetNumAtoms() > max_length:
                continue

            A = self.generate_adjacency_matrix(mol, connected=True, max_length=max_length)
            if A is not None:
                x = torch.from_numpy(self.generate_node_features(mol, max_length=max_length)).to(torch.long).view(1, -1)
                x = self.label2onehot(x, self.m_dim).squeeze()

                if self.features:
                    f = torch.from_numpy(self.generate_additional_features(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                    x = torch.concat((x, f), dim=-1)

                adjacency = torch.from_numpy(A)
                edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
                edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)

                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))


if __name__ == '__main__':
    # Example invocation; dataset_file, raw_files, max_atom and features below are
    # placeholder values, not values taken from the original script.
    data = DruggenDataset(root="DrugGEN/data",
                          dataset_file="chembl_train.pt",
                          raw_files="DrugGEN/data/chembl_train.smi",
                          max_atom=45,
                          features=False)
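    # Minimal inspection sketch (an assumed usage, relying only on methods defined
    # above; it requires that processing has produced the decoder pickles). It prints
    # the number of stored graphs and decodes the atom types of the first graph back
    # to atomic numbers, where 0 entries correspond to PAD nodes.
    print(f"Loaded {len(data)} molecular graphs from {data.dataset_file}")
    atom_decoder = data.decoder_load("atom")
    first = data[0]
    atom_type_indices = first.x[:, :len(atom_decoder)].argmax(dim=-1).tolist()
    print("Atomic numbers of the first graph:", [atom_decoder[i] for i in atom_type_indices])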