# DrugGEN / new_dataloader.py
# Last change: commit e55a2c8 ("Update new_dataloader.py") by mgyigit (file size 12.2 kB).
import pickle
import numpy as np
import torch
from rdkit import Chem
from torch_geometric.data import (Data, InMemoryDataset)
import os.path as osp
from tqdm import tqdm
import re
from rdkit import RDLogger
import pandas as pd
# Silence every RDKit logger ('rdApp.*') for the whole process, so RDKit parse/sanitize
# messages are not printed while the dataset is built or molecules are decoded.
RDLogger.DisableLog('rdApp.*')
class DruggenDataset(InMemoryDataset):
    """In-memory PyTorch Geometric dataset of molecular graphs built from SMILES.

    If the processed file is missing, ``process`` reads SMILES from ``raw_files``.
    It then builds atom/bond label vocabularies (pickled under ``data/encoders/``
    and ``data/decoders/``) and converts each kept molecule into a padded
    ``torch_geometric.data.Data`` graph. The collated result is saved as
    ``<root>/<dataset_file>``, which every instantiation then loads.

    The class also provides helpers that turn (node label, edge label) matrices
    back into RDKit molecules, and a heuristic valence repair (``correct_mol``).
    """

    def __init__(self, root, dataset_file, raw_files, max_atom, features,
                 transform=None, pre_transform=None, pre_filter=None):
        """Create the dataset (processing it first if needed) and load it.

        Args:
            root: Directory that holds the processed file (see ``processed_dir``).
            dataset_file: File name of the processed dataset inside ``root``.
                Its stem (the text before the first ".") names the pickled
                encoder/decoder files.
            raw_files: Path of the SMILES file that ``process`` reads. It is a
                headerless CSV, and only the first column is used.
            max_atom: Molecules with more atoms than this are dropped.
            features: If truthy, extra per-atom features from ``_genF`` are
                concatenated to the one-hot atom types.
            transform, pre_transform, pre_filter: Standard PyG dataset hooks.
        """
        # Set these before super().__init__(). PyG's base initializer runs
        # process() when the processed file is missing, and process() reads them.
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features
        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, dataset_file)
        # NOTE(review): torch >= 2.6 defaults to torch.load(weights_only=True),
        # which may refuse the pickled (data, slices) tuple. Confirm against the
        # pinned torch version.
        self.data, self.slices = torch.load(path)
        self.root = root

    @property
    def processed_dir(self):
        """Keep processed files directly in ``root``, not in PyG's default ``root/processed``."""
        return self.root

    @property
    def raw_file_names(self):
        """Raw input file name(s), exactly as passed to the constructor."""
        return self.raw_files

    @property
    def processed_file_names(self):
        """Name of the single collated dataset file inside ``processed_dir``."""
        return self.dataset_file

    def _generate_encoders_decoders(self, data):
        """Build the atom/bond label vocabularies from SMILES and pickle them.

        Unparseable SMILES (``Chem.MolFromSmiles`` returns None) and molecules
        with more than ``self.max_atom`` atoms are skipped.

        Sets ``atom_encoder_m``/``atom_decoder_m``/``atom_num_types`` and
        ``bond_encoder_m``/``bond_decoder_m``/``bond_num_types``. It writes them to
        ``data/encoders/{atom,bond}_<dataset_name>.pkl`` and
        ``data/decoders/{atom,bond}_<dataset_name>.pkl``.

        Args:
            data: Iterable of SMILES strings.

        Returns:
            ``(max_length, smiles_list)``: the largest atom count among the kept
            molecules, and the kept SMILES in input order.
        """
        print('Creating atoms and bonds encoder and decoder..')
        atom_labels = set()
        bond_labels = set()
        max_length = 0
        smiles_list = []
        for smiles in tqdm(data):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # Invalid SMILES. Skip it instead of failing on mol.GetNumAtoms().
                continue
            molecule_size = mol.GetNumAtoms()
            if molecule_size > self.max_atom:
                continue
            smiles_list.append(smiles)
            atom_labels.update(atom.GetAtomicNum() for atom in mol.GetAtoms())
            bond_labels.update(bond.GetBondType() for bond in mol.GetBonds())
            max_length = max(max_length, molecule_size)

        # Atomic number 0 is the PAD symbol (also used for unknown atoms), so it
        # sorts first and gets index 0.
        atom_labels.add(0)
        atom_labels = sorted(atom_labels)
        # BondType.ZERO takes index 0, which means "no bond" in the adjacency matrix.
        bond_labels = [Chem.rdchem.BondType.ZERO] + sorted(bond_labels)

        self.atom_encoder_m = {label: i for i, label in enumerate(atom_labels)}
        self.atom_decoder_m = {i: label for i, label in enumerate(atom_labels)}
        self.atom_num_types = len(atom_labels)
        print('Created atoms encoder and decoder with {} atom types and 1 PAD symbol!'.format(
            self.atom_num_types - 1))
        print("atom_labels", atom_labels)

        print("bond labels", bond_labels)
        self.bond_encoder_m = {label: i for i, label in enumerate(bond_labels)}
        self.bond_decoder_m = {i: label for i, label in enumerate(bond_labels)}
        self.bond_num_types = len(bond_labels)
        print('Created bonds encoder and decoder with {} bond types and 1 PAD symbol!'.format(
            self.bond_num_types - 1))

        for prefix, encoder, decoder in (
                ("atom_", self.atom_encoder_m, self.atom_decoder_m),
                ("bond_", self.bond_encoder_m, self.bond_decoder_m)):
            with open("data/encoders/" + prefix + self.dataset_name + ".pkl", "wb") as handle:
                pickle.dump(encoder, handle)
            with open("data/decoders/" + prefix + self.dataset_name + ".pkl", "wb") as handle:
                pickle.dump(decoder, handle)
        return max_length, smiles_list  # smiles_list is the filtered data

    def _genA(self, mol, connected=True, max_length=None):
        """Return the padded bond-type adjacency matrix of ``mol``.

        The result is a ``(max_length, max_length)`` float array. Entry ``[i, j]``
        holds ``self.bond_encoder_m[bond_type]`` for a bond between atoms i and j,
        and 0 otherwise. It is filled symmetrically, and rows/columns past
        ``mol.GetNumAtoms()`` are zero padding.

        Args:
            mol: RDKit molecule.
            connected: If true, return None when any real atom has no bond.
                This rejects isolated atoms; it does not check that the whole
                graph is one connected component. If false, return the matrix
                unconditionally.
            max_length: Padded size. Defaults to the molecule's atom count.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        A = np.zeros(shape=(max_length, max_length))
        bonds = list(mol.GetBonds())
        begin = [b.GetBeginAtomIdx() for b in bonds]
        end = [b.GetEndAtomIdx() for b in bonds]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in bonds]
        A[begin, end] = bond_type
        A[end, begin] = bond_type
        if not connected:
            # Previously `connected and ...` made this branch always return None.
            return A
        n_atoms = mol.GetNumAtoms()
        degree = np.sum(A[:n_atoms, :n_atoms], axis=-1)
        return A if (degree > 0).all() else None

    def _genX(self, mol, max_length=None):
        """Return the encoded atom labels of ``mol``, right-padded with 0 (PAD) to ``max_length``."""
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] + [0] * (
            max_length - mol.GetNumAtoms()))

    def _genF(self, mol, max_length=None):
        """Return a per-atom binary descriptor matrix, zero-padded to ``max_length`` rows.

        There are 54 columns per atom:
          * one-hot degree (0-4), explicit valence (0-8), hybridization (1-6),
            implicit valence (0-8), explicit Hs (0-4), implicit Hs (0-4) and
            radical electrons (0-4);
          * aromatic, no-implicit, in-ring flags;
          * ring membership for ring sizes 2-8.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                              *[a.GetExplicitValence() == i for i in range(9)],
                              *[int(a.GetHybridization()) == i for i in range(1, 7)],
                              *[a.GetImplicitValence() == i for i in range(9)],
                              a.GetIsAromatic(),
                              a.GetNoImplicit(),
                              *[a.GetNumExplicitHs() == i for i in range(5)],
                              *[a.GetNumImplicitHs() == i for i in range(5)],
                              *[a.GetNumRadicalElectrons() == i for i in range(5)],
                              a.IsInRing(),
                              *[a.IsInRingSize(i) for i in range(2, 9)]] for a in mol.GetAtoms()],
                            dtype=np.int32)
        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name, file):
        """Load the pickled decoder ``data/decoders/<dictionary_name>_<file>.pkl``."""
        with open("data/decoders/" + dictionary_name + "_" + file + '.pkl', 'rb') as f:
            return pickle.load(f)

    def drugs_decoder_load(self, dictionary_name):
        """Load the pickled decoder ``data/decoders/<dictionary_name>.pkl``."""
        with open("data/decoders/" + dictionary_name + '.pkl', 'rb') as f:
            return pickle.load(f)

    def _matrices_to_mol(self, node_labels, edge_labels, strict, atom_decoders, bond_decoders):
        """Build an RWMol from label matrices using the given decoders (shared by matrices2mol*)."""
        mol = Chem.RWMol()
        RDLogger.DisableLog('rdApp.*')
        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(atom_decoders[node_label]))
        # Only entries with start > end are read, so a symmetric edge matrix
        # adds each bond exactly once.
        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), bond_decoders[edge_labels[start, end]])
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Sanitization failed: report the molecule as invalid. Unlike the
                # previous bare `except:`, this lets KeyboardInterrupt/SystemExit through.
                mol = None
        return mol

    def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
        """Decode node/edge label matrices into an RDKit molecule.

        Args:
            node_labels: Iterable of atom label indices (keys of the atom decoder).
            edge_labels: 2-D array of bond label indices (0 = no bond).
            strict: If true, sanitize the molecule and return None on failure.
            file_name: Dataset name used to locate ``data/decoders/{atom,bond}_<file_name>.pkl``.

        Returns:
            A ``Chem.RWMol``, or None if ``strict`` and sanitization fails.
        """
        atom_decoders = self.decoder_load("atom", file_name)
        bond_decoders = self.decoder_load("bond", file_name)
        return self._matrices_to_mol(node_labels, edge_labels, strict, atom_decoders, bond_decoders)

    def drug_decoder_load(self, dictionary_name, file):
        ''' Loading the atom and bond decoders '''
        with open("data/decoders/" + dictionary_name + "_" + file + '.pkl', 'rb') as f:
            return pickle.load(f)

    def matrices2mol_drugs(self, node_labels, edge_labels, strict=True, file_name=None):
        """Same as ``matrices2mol``, but the decoders are loaded via ``drug_decoder_load``."""
        atom_decoders = self.drug_decoder_load("atom", file_name)
        bond_decoders = self.drug_decoder_load("bond", file_name)
        return self._matrices_to_mol(node_labels, edge_labels, strict, atom_decoders, bond_decoders)

    def check_valency(self, mol):
        """
        Checks that no atoms in the mol have exceeded their possible
        valency
        :return: (True, None) if no valency issues, otherwise (False, numbers),
            where numbers are the integers parsed from the RDKit error message
            after its '#' (typically [atom_index, valence]). The list is empty
            when the message contains no '#'.
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            message = str(e)
            p = message.find('#')
            # find() returns -1 when '#' is absent; message[-1:] would then parse
            # only the last character.
            e_sub = message[p:] if p >= 0 else ''
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self, x):
        """Heuristically fix valence violations in ``x`` by deleting bonds.

        While ``check_valency`` reports a violating atom, the bond on that atom
        with the highest ``int(BondType)`` is removed entirely; it is not
        downgraded. ``x`` is modified in place (``RemoveBond`` needs a
        ``Chem.RWMol``) and returned. The loop stops early if the offending atom
        has no bonds left, where it previously looped forever.
        """
        mol = x
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            assert len(atomid_valence) == 2
            idx = atomid_valence[0]
            queue = [
                (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                for b in mol.GetAtomWithIdx(idx).GetBonds()
            ]
            if not queue:
                # Nothing left to remove on this atom, so further iterations
                # could not change the molecule.
                break
            queue.sort(key=lambda tup: tup[1], reverse=True)
            _, _, start, end = queue[0]
            mol.RemoveBond(start, end)
        return mol

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors (new trailing axis of size ``dim``, float)."""
        out = torch.zeros(list(labels.size()) + [dim])
        out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)
        return out.float()

    def process(self, size=None):
        """Convert the raw SMILES file into a collated graph dataset on disk.

        Each kept molecule becomes a ``Data`` object with these fields:
          * ``x``: one-hot atom types for ``max_length`` (padded) nodes. When
            ``self.features`` is set, the ``_genF`` features are concatenated.
          * ``edge_index``/``edge_attr``: non-zero entries of the padded
            adjacency matrix (both directions) and their bond-type indices.
          * ``smiles``: the source SMILES string.

        Molecules for which ``_genA`` returns None (an atom without bonds) and
        those rejected by ``pre_filter`` are skipped. ``size`` is accepted but
        unused.
        """
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
        max_length, smiles_list = self._generate_encoders_decoders(smiles_list)
        data_list = []
        self.m_dim = len(self.atom_decoder_m)
        for smiles in tqdm(smiles_list, desc='Processing chembl dataset', total=len(smiles_list)):
            mol = Chem.MolFromSmiles(smiles)
            A = self._genA(mol, connected=True, max_length=max_length)
            if A is None:
                continue
            x = torch.from_numpy(self._genX(mol, max_length=max_length)).to(torch.long).view(1, -1)
            # squeeze(0) drops only the leading batch dim. A bare squeeze() would
            # also drop a size-1 node or type dim and break the concat below.
            x = self.label2onehot(x, self.m_dim).squeeze(0)
            if self.features:
                f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                x = torch.concat((x, f), dim=-1)
            adjacency = torch.from_numpy(A)
            edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
            edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
if __name__ == '__main__':
    # The previous call DruggenDataset("data") omitted four required constructor
    # arguments and always raised TypeError. Take them from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Build (or load) a DruggenDataset from a SMILES file.")
    parser.add_argument("--root", default="data",
                        help="Directory where the processed dataset file is stored.")
    parser.add_argument("--dataset_file", required=True,
                        help="Processed dataset file name inside --root (e.g. 'mydata.pt').")
    parser.add_argument("--raw_file", required=True,
                        help="Headerless CSV/SMILES file; the first column is read as SMILES.")
    parser.add_argument("--max_atom", type=int, required=True,
                        help="Drop molecules with more atoms than this.")
    parser.add_argument("--features", action="store_true",
                        help="Append per-atom descriptor features to the node features.")
    args = parser.parse_args()

    data = DruggenDataset(args.root, args.dataset_file, args.raw_file,
                          args.max_atom, args.features)