DeepAcceptor / utils.py
jinysun's picture
Upload 2 files
fc06566
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 28 14:40:59 2022
@author: BM109X32G-10GPU-02
"""
import os
from collections import OrderedDict
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdchem
from compound_constants import DAY_LIGHT_FG_SMARTS_LIST
def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Compute Gasteiger partial charges and return them in atom order.

    Args:
        mol: rdkit mol object. The charges are computed in place and read
            back from each atom's ``_GasteigerCharge`` property.
        n_iter(int): number of iterations. Default 12.

    Returns:
        list of float, one partial charge per atom of ``mol``.
    """
    Chem.rdPartialCharges.ComputeGasteigerCharges(
        mol, nIter=n_iter, throwOnParamFailure=True)
    charges = []
    for atom in mol.GetAtoms():
        charges.append(float(atom.GetProp('_GasteigerCharge')))
    return charges
def create_standardized_mol_id(smiles):
    """
    Build a stereochemistry-free InChI identifier for a SMILES string.

    The SMILES is re-written without stereochemistry, re-parsed, and
    converted to InChI. When the SMILES describes several species
    (contains '.'), only the species with the most atoms is used.

    Args:
        smiles: smiles sequence.

    Returns:
        inchi, or None when the SMILES (or its stereo-free rewrite) cannot
        be parsed.
    """
    if not check_smiles_validity(smiles):
        return None

    # remove stereochemistry
    smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles),
                                 isomericSmiles=False)

    # Check the re-parse result BEFORE adding hydrogens: AddHs does not accept
    # None, so the None check has to come first to be effective.
    # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
    mol = AllChem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)

    if '.' in smiles:  # if multiple species, pick largest molecule
        mol_species_list = split_rdkit_mol_obj(mol)
        if not mol_species_list:
            # no fragment could be parsed back, so there is nothing to identify
            return None
        largest_mol = get_largest_mol(mol_species_list)
        return AllChem.MolToInchi(largest_mol)
    return AllChem.MolToInchi(mol)
def check_smiles_validity(smiles):
    """
    Tell whether a SMILES string can be converted to an rdkit mol object.

    Returns:
        True when parsing yields a truthy mol object; False when it yields a
        falsy result or raises any exception.
    """
    try:
        parsed = Chem.MolFromSmiles(smiles)
        return bool(parsed)
    except Exception:
        return False
def split_rdkit_mol_obj(mol):
    """
    Break a mol object into one mol object per species.

    The mol is written out as an isomeric SMILES string, split on '.', and
    every fragment that passes ``check_smiles_validity`` is parsed back into
    its own mol object. A single-species input gives a one-element list.

    Args:
        mol: rdkit mol object.

    Returns:
        list of rdkit mol objects, one per parsable fragment.
    """
    fragments = AllChem.MolToSmiles(mol, isomericSmiles=True).split('.')
    return [
        AllChem.MolFromSmiles(fragment)
        for fragment in fragments
        if check_smiles_validity(fragment)
    ]
def get_largest_mol(mol_list):
    """
    Pick the mol object with the most atoms.

    Ties are resolved in favour of the earliest entry in ``mol_list``.

    Args:
        mol_list(list): a list of rdkit mol object.

    Returns:
        the largest mol.

    Raises:
        ValueError: if ``mol_list`` is empty.
    """
    def _atom_count(candidate):
        return len(candidate.GetAtoms())

    # max() returns the first maximal element, which keeps the tie rule.
    return max(mol_list, key=_atom_count)
def rdchem_enum_to_list(values):
    """
    Flatten an int-keyed mapping into a list ordered by key.

    Example input::

        values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
                  1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
                  2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
                  3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}

    The keys are expected to be exactly 0 .. len(values) - 1.
    """
    return list(map(values.__getitem__, range(len(values))))
def safe_index(alist, elem):
    """
    Look up the position of ``elem`` in ``alist``.

    When ``elem`` is not present, the last index (``len(alist) - 1``) is
    returned instead of raising.
    """
    try:
        position = alist.index(elem)
    except ValueError:
        position = len(alist) - 1
    return position
def get_atom_feature_dims(list_acquired_feature_names):
    """
    Return the vocabulary size of each requested atom feature.

    Sizes come from ``CompoundKit.atom_vocab_dict`` and are returned in the
    same order as ``list_acquired_feature_names``.
    """
    return [
        len(CompoundKit.atom_vocab_dict[feature_name])
        for feature_name in list_acquired_feature_names
    ]
def get_bond_feature_dims(list_acquired_feature_names):
    """
    Return the vocabulary size of each requested bond feature, plus one.

    Sizes come from ``CompoundKit.bond_vocab_dict`` and are returned in the
    same order as ``list_acquired_feature_names``. The extra slot is
    reserved for self loop edges.
    """
    return [
        len(CompoundKit.bond_vocab_dict[feature_name]) + 1  # +1 for self loop edges
        for feature_name in list_acquired_feature_names
    ]
class CompoundKit(object):
    """
    CompoundKit

    Static helpers that turn rdkit atoms, bonds and molecules into feature
    values, vocabulary indices, fingerprints and functional-group counts.
    """
    # Categorical atom features: feature name -> ordered list of values.
    # A value's position in the list is its feature id (see safe_index).
    atom_vocab_dict = {
        "atomic_num": list(range(1, 119)) + ['misc'],
        "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
    }
    # Categorical bond features, same layout as atom_vocab_dict.
    bond_vocab_dict = {
        "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
        "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
    }
    # float features
    atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
    # bond_float_feats= ["bond_length", "bond_angle"] # optional

    # Fallback vocabulary for the 'in_num_ring_with_size<k>' features built by
    # get_atom_names. atom_vocab_dict defines no entry for them, and
    # get_ring_size emits counts 0..8 plus 9 as the overflow marker, which
    # safe_index maps onto the trailing 'misc' slot.
    _ring_count_vocab = list(range(9)) + ['misc']

    ### functional groups
    day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
    day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

    # bit lengths of the fingerprints
    morgan_fp_N = 200
    morgan2048_fp_N = 2048
    maccs_fp_N = 167

    period_table = Chem.GetPeriodicTable()

    ### atom

    @staticmethod
    def get_atom_value(atom, name):
        """
        get atom values

        Args:
            atom: rdkit atom.
            name(str): name of the property to read.

        Returns:
            the raw property value. Boolean properties are returned as int,
            and 'mass' is truncated to int.

        Raises:
            ValueError: if ``name`` is not a supported property.
        """
        if name == 'atomic_num':
            return atom.GetAtomicNum()
        elif name == 'chiral_tag':
            return atom.GetChiralTag()
        elif name == 'degree':
            return atom.GetDegree()
        elif name == 'explicit_valence':
            return atom.GetExplicitValence()
        elif name == 'formal_charge':
            return atom.GetFormalCharge()
        elif name == 'hybridization':
            return atom.GetHybridization()
        elif name == 'implicit_valence':
            return atom.GetImplicitValence()
        elif name == 'is_aromatic':
            return int(atom.GetIsAromatic())
        elif name == 'mass':
            return int(atom.GetMass())
        elif name == 'total_numHs':
            return atom.GetTotalNumHs()
        elif name == 'num_radical_e':
            return atom.GetNumRadicalElectrons()
        elif name == 'atom_is_in_ring':
            return int(atom.IsInRing())
        elif name == 'valence_out_shell':
            return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
        else:
            raise ValueError(name)

    @staticmethod
    def get_atom_feature_id(atom, name):
        """
        get atom features id

        Returns the index of the atom's value in ``atom_vocab_dict[name]``;
        values missing from the vocabulary map to the last index.
        """
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

    @staticmethod
    def get_atom_feature_size(name):
        """get atom features size (length of ``atom_vocab_dict[name]``)"""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return len(CompoundKit.atom_vocab_dict[name])

    ### bond

    @staticmethod
    def get_bond_value(bond, name):
        """
        get bond values

        Args:
            bond: rdkit bond.
            name(str): name of the property to read.

        Returns:
            the raw property value. Boolean properties are returned as int.

        Raises:
            ValueError: if ``name`` is not a supported property.
        """
        if name == 'bond_dir':
            return bond.GetBondDir()
        elif name == 'bond_type':
            return bond.GetBondType()
        elif name == 'is_in_ring':
            return int(bond.IsInRing())
        elif name == 'is_conjugated':
            return int(bond.GetIsConjugated())
        elif name == 'bond_stereo':
            return bond.GetStereo()
        else:
            raise ValueError(name)

    @staticmethod
    def get_bond_feature_id(bond, name):
        """
        get bond features id

        Returns the index of the bond's value in ``bond_vocab_dict[name]``;
        values missing from the vocabulary map to the last index.
        """
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

    @staticmethod
    def get_bond_feature_size(name):
        """get bond features size (length of ``bond_vocab_dict[name]``)"""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return len(CompoundKit.bond_vocab_dict[name])

    ### fingerprint

    @staticmethod
    def get_morgan_fingerprint(mol, radius=2):
        """get morgan fingerprint as a list of ``morgan_fp_N`` 0/1 ints"""
        nBits = CompoundKit.morgan_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_morgan2048_fingerprint(mol, radius=2):
        """get morgan2048 fingerprint as a list of ``morgan2048_fp_N`` 0/1 ints"""
        nBits = CompoundKit.morgan2048_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_maccs_fingerprint(mol):
        """get maccs fingerprint as a list of 0/1 ints"""
        fp = AllChem.GetMACCSKeysFingerprint(mol)
        return [int(b) for b in fp.ToBitString()]

    ### functional groups

    @staticmethod
    def get_daylight_functional_group_counts(mol):
        """
        get daylight functional group counts

        Returns one count per pattern in ``day_light_fg_mo_list``: the
        number of unique substructure matches of that pattern in ``mol``.
        """
        fg_counts = []
        for fg_mol in CompoundKit.day_light_fg_mo_list:
            sub_structs = mol.GetSubstructMatches(fg_mol, uniquify=True)
            fg_counts.append(len(sub_structs))
        return fg_counts

    @staticmethod
    def get_ring_size(mol):
        """
        return (N,6) list

        Entry ``[i][k]`` is the number of rings of size ``k + 3`` (sizes 3
        to 8) that contain atom ``i``. Counts above 8 are stored as 9.
        """
        min_size, max_size = 3, 8
        ring_counts = [[0] * (max_size - min_size + 1) for _ in range(mol.GetNumAtoms())]
        # One pass over the rings instead of re-scanning every ring for each
        # (atom, size) pair.
        for ring in mol.GetRingInfo().AtomRings():
            ring_size = len(ring)
            if min_size <= ring_size <= max_size:
                for atom_idx in ring:
                    ring_counts[atom_idx][ring_size - min_size] += 1
        # 9 marks "more than 8 rings of this size"
        return [[9 if count > 8 else count for count in row] for row in ring_counts]

    @staticmethod
    def atom_to_feat_vector(atom):
        """Return a dict holding the 'atomic_num' feature id of ``atom``."""
        atom_names = {
            "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
        }
        return atom_names

    @staticmethod
    def get_atom_names(mol):
        """
        get atom name list

        Returns one dict per atom holding the 'atomic_num' feature id and
        the six 'in_num_ring_with_size<k>' feature ids (k = 3..8). Gasteiger
        charges are computed on ``mol`` as a side effect.

        TODO: to be remove in the future
        """
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
        atom_features_dicts = [
            CompoundKit.atom_to_feat_vector(atom) for atom in mol.GetAtoms()
        ]
        ring_list = CompoundKit.get_ring_size(mol)
        for atom_features, ring_counts in zip(atom_features_dicts, ring_list):
            for offset, ring_size in enumerate(range(3, 9)):
                key = 'in_num_ring_with_size%d' % ring_size
                # atom_vocab_dict has no ring-count entries, so indexing it
                # directly raised KeyError; use its entry when one exists and
                # the built-in ring-count vocabulary otherwise.
                vocab = CompoundKit.atom_vocab_dict.get(key, CompoundKit._ring_count_vocab)
                atom_features[key] = safe_index(vocab, ring_counts[offset])
        return atom_features_dicts

    @staticmethod
    def check_partial_charge(atom):
        """
        Read the atom's ``_GasteigerCharge`` and sanitise it.

        NaN becomes 0 and +inf becomes 10; any other value (including -inf)
        is returned unchanged.
        """
        pc = atom.GetDoubleProp('_GasteigerCharge')
        if pc != pc:
            # unsupported atom, replace nan with 0
            pc = 0
        if pc == float('inf'):
            # max 4 for other atoms, set to 10 here if inf is get
            pc = 10
        return pc
class Compound3DKit(object):
    """the 3Dkit of Compound: atom coordinates, bond lengths and bond angles"""

    @staticmethod
    def get_atom_poses(mol, conf):
        """
        Read the coordinates of every atom from a conformer.

        Args:
            mol: rdkit mol object.
            conf: conformer providing ``GetAtomPosition(i)``.

        Returns:
            list of ``[x, y, z]`` lists, one per atom. If any atom has atomic
            number 0, all positions are returned as zeros instead.
        """
        atom_poses = []
        for i, atom in enumerate(mol.GetAtoms()):
            if atom.GetAtomicNum() == 0:
                # Build independent rows: ``[[0.0] * 3] * n`` would repeat one
                # shared inner list n times.
                return [[0.0, 0.0, 0.0] for _ in range(len(mol.GetAtoms()))]
            pos = conf.GetAtomPosition(i)
            atom_poses.append([pos.x, pos.y, pos.z])
        return atom_poses

    @staticmethod
    def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
        """
        Embed conformers, MMFF-optimise them and return the lowest-energy one.

        the atoms of mol will be changed in some cases (hydrogens are added).
        If embedding or optimisation fails, 2D coordinates are used instead
        and the energy is reported as 0.

        Args:
            mol: rdkit mol object.
            numConfs: number of conformers to embed; None uses rdkit's own
                default.
            return_energy(bool): also return the energy of the chosen
                conformer.

        Returns:
            ``(new_mol, atom_poses)`` or ``(new_mol, atom_poses, energy)``.
        """
        new_mol = Chem.AddHs(mol)
        try:
            # Forwarding numConfs=None is rejected by rdkit and used to send
            # every default call down the 2D fallback silently, so only pass
            # the argument when the caller supplied one.
            if numConfs is None:
                AllChem.EmbedMultipleConfs(new_mol)
            else:
                AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
            ### MMFF generates multiple conformations
            res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
            #new_mol = Chem.RemoveHs(new_mol)
            index = np.argmin([x[1] for x in res])
            energy = res[index][1]
            conf = new_mol.GetConformer(id=int(index))
        except Exception:
            # Best-effort fallback, but without swallowing KeyboardInterrupt /
            # SystemExit as a bare ``except:`` would.
            new_mol = Chem.AddHs(mol)
            AllChem.Compute2DCoords(new_mol)
            energy = 0
            conf = new_mol.GetConformer()

        atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
        if return_energy:
            return new_mol, atom_poses, energy
        else:
            return new_mol, atom_poses

    @staticmethod
    def get_2d_atom_poses(mol):
        """get 2d atom poses (2D coordinates are computed on ``mol`` in place)"""
        AllChem.Compute2DCoords(mol)
        conf = mol.GetConformer()
        atom_poses = Compound3DKit.get_atom_poses(mol, conf)
        return atom_poses

    @staticmethod
    def get_bond_lengths(edges, atom_poses):
        """
        get bond lengths

        Args:
            edges: iterable of ``(src, tar)`` atom index pairs.
            atom_poses: array of atom coordinates indexed by atom index.

        Returns:
            float32 array with the Euclidean distance of every edge.
        """
        bond_lengths = [
            np.linalg.norm(atom_poses[tar_node_j] - atom_poses[src_node_i])
            for src_node_i, tar_node_j in edges
        ]
        return np.array(bond_lengths, 'float32')

    @staticmethod
    def get_superedge_angles(edges, atom_poses, dir_type='HT'):
        """
        get superedge angles

        A super edge connects a source edge to a target edge. With
        ``dir_type='HT'`` the source edge's end atom (column 1) must equal
        the target edge's start atom (column 0); with ``'HH'`` both edges
        must share the same end atom. An edge is never paired with itself.

        Args:
            edges: integer array of shape (E, 2).
            atom_poses: array of atom coordinates indexed by atom index.
            dir_type(str): 'HT' or 'HH'.

        Returns:
            ``(super_edges, bond_angles, bond_angle_dirs)``: an int64 array of
            ``[src_edge_i, tar_edge_i]`` pairs, a float32 array of angles in
            radians, and a list of flags that are true when the source edge's
            end atom is the target edge's start atom.

        Raises:
            ValueError: if ``dir_type`` is neither 'HT' nor 'HH'.
        """
        def _get_vec(atom_poses, edge):
            return atom_poses[edge[1]] - atom_poses[edge[0]]

        def _get_angle(vec1, vec2):
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)
            if norm1 == 0 or norm2 == 0:
                return 0
            vec1 = vec1 / (norm1 + 1e-5)  # 1e-5: prevent numerical errors
            vec2 = vec2 / (norm2 + 1e-5)
            # Rounding can push the cosine just outside [-1, 1], where arccos
            # returns NaN; clipping leaves in-range values untouched.
            angle = np.arccos(np.clip(np.dot(vec1, vec2), -1.0, 1.0))
            return angle

        E = len(edges)
        edge_indices = np.arange(E)
        super_edges = []
        bond_angles = []
        bond_angle_dirs = []
        for tar_edge_i in range(E):
            tar_edge = edges[tar_edge_i]
            if dir_type == 'HT':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
            elif dir_type == 'HH':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
            else:
                raise ValueError(dir_type)
            for src_edge_i in src_edge_indices:
                if src_edge_i == tar_edge_i:
                    continue
                src_edge = edges[src_edge_i]
                src_vec = _get_vec(atom_poses, src_edge)
                tar_vec = _get_vec(atom_poses, tar_edge)
                super_edges.append([src_edge_i, tar_edge_i])
                angle = _get_angle(src_vec, tar_vec)
                bond_angles.append(angle)
                bond_angle_dirs.append(src_edge[1] == tar_edge[0])  # H -> H or H -> T

        if len(super_edges) == 0:
            # keep the (K, 2) / (K,) shapes even when nothing was paired
            super_edges = np.zeros([0, 2], 'int64')
            bond_angles = np.zeros([0,], 'float32')
        else:
            super_edges = np.array(super_edges, 'int64')
            bond_angles = np.array(bond_angles, 'float32')
        return super_edges, bond_angles, bond_angle_dirs
def new_smiles_to_graph_data(smiles, **kwargs):
    """
    Convert smiles to graph data.

    Args:
        smiles: smiles sequence.
        **kwargs: accepted for call compatibility; not used.

    Returns:
        the dict built by ``new_mol_to_graph_data`` for the molecule with
        explicit hydrogens, or None when the SMILES cannot be parsed.
    """
    mol = AllChem.MolFromSmiles(smiles)
    # The None check must come before AddHs, which does not accept None.
    if mol is None:
        return None
    return new_mol_to_graph_data(Chem.AddHs(mol))
def new_mol_to_graph_data(mol):
    """
    mol_to_graph_data

    Builds a dict of numpy arrays describing ``mol``:

    * one int64 array per name in ``CompoundKit.atom_vocab_dict`` (atom
      feature ids) and one float32 array per name in
      ``CompoundKit.atom_float_names``;
    * one int64 array per name in ``CompoundKit.bond_vocab_dict`` and an
      int64 ``edges`` array. Every bond is stored in both directions and
      every atom gets a self loop whose feature id is the vocabulary length;
    * ``morgan_fp``, ``maccs_fp`` and ``daylight_fg_counts``.

    Gasteiger charges are computed on ``mol`` as a side effect.

    Args:
        mol: rdkit mol object.

    Returns:
        the feature dict, or None when ``mol`` has no atoms.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = list(CompoundKit.atom_vocab_dict.keys())
    atom_float_names = list(CompoundKit.atom_float_names)
    bond_id_names = list(CompoundKit.bond_vocab_dict.keys())

    def _atom_float_value(atom, name):
        # NOTE(review): how each float feature is derived is a best reading of
        # the helpers CompoundKit exposes; confirm against the consumer of
        # these arrays (mol_to_graph_data additionally scales mass by 0.01).
        if name == 'van_der_waals_radis':
            return CompoundKit.period_table.GetRvdw(atom.GetAtomicNum())
        if name == 'partial_charge':
            return CompoundKit.check_partial_charge(atom)
        return CompoundKit.get_atom_value(atom, name)

    ### atom features
    # Computed directly per atom: CompoundKit.get_atom_names does not provide
    # every name requested here, so reading its dicts raised KeyError.
    data = {name: [] for name in atom_id_names + atom_float_names}
    Chem.rdPartialCharges.ComputeGasteigerCharges(mol)  # needed by check_partial_charge
    for atom in mol.GetAtoms():
        for name in atom_id_names:
            data[name].append(CompoundKit.get_atom_feature_id(atom, name))
        for name in atom_float_names:
            data[name].append(_atom_float_value(atom, name))

    ### bond and bond features
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # i->j and j->i
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
            data[name] += [bond_feature_id] * 2

    #### self loop
    N = len(mol.GetAtoms())
    for i in range(N):
        data['edges'] += [(i, i)]
    for name in bond_id_names:
        bond_feature_id = get_bond_feature_dims([name])[0] - 1  # self loop: value = len - 1
        data[name] += [bond_feature_id] * N

    ### make ndarray and check length
    for name in atom_id_names:
        data[name] = np.array(data[name], 'int64')
    for name in atom_float_names:
        data[name] = np.array(data[name], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')

    ### morgan fingerprint
    data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
    data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    data['daylight_fg_counts'] = np.array(CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    return data
def mol_to_graph_data(mol):
    """
    mol_to_graph_data

    Builds a dict of numpy arrays describing ``mol``:

    * ``atomic_num``: int64 atom feature ids shifted by +1 (0: OOV);
    * ``mass``: float32, ``CompoundKit.get_atom_value(atom, 'mass') * 0.01``;
    * ``bond_dir`` / ``bond_type``: int64 bond feature ids shifted by +1,
      each bond stored twice (both directions), followed by one self-loop
      id per atom (feature size + 2);
    * ``edges``: int64 pairs, in the same order as the bond features;
    * ``atoms``: array of atom symbols.

    Args:
        mol: rdkit mol object.

    Returns:
        the dict, or None when ``mol`` has no atoms or contains an atom with
        atomic number 0.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = ["atomic_num"]
    bond_id_names = ["bond_dir", "bond_type"]

    data = {feat: [] for feat in atom_id_names}
    data['mass'] = []
    for feat in bond_id_names:
        data[feat] = []
    data['edges'] = []

    ### atom features
    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() == 0:
            return None
        for feat in atom_id_names:
            data[feat].append(CompoundKit.get_atom_feature_id(atom, feat) + 1)  # 0: OOV
        data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)

    ### bond features
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        # begin->end and end->begin
        data['edges'].extend([(begin, end), (end, begin)])
        for feat in bond_id_names:
            feat_id = CompoundKit.get_bond_feature_id(bond, feat) + 1  # 0: OOV
            data[feat].extend([feat_id, feat_id])

    atoms_list = [
        mol.GetAtomWithIdx(idx).GetSymbol() for idx in range(mol.GetNumAtoms())
    ]

    ### self loop (+2)
    num_nodes = len(data[atom_id_names[0]])
    data['edges'].extend((node, node) for node in range(num_nodes))
    for feat in bond_id_names:
        self_loop_id = CompoundKit.get_bond_feature_size(feat) + 2  # N + 2: self loop
        data[feat].extend([self_loop_id] * num_nodes)

    ### check whether edge exists
    if not data['edges']:  # mol has no bonds
        for feat in bond_id_names:
            data[feat] = np.zeros((0,), dtype="int64")
        data['edges'] = np.zeros((0, 2), dtype="int64")

    ### make ndarray and check length
    for feat in atom_id_names:
        data[feat] = np.array(data[feat], 'int64')
    data['mass'] = np.array(data['mass'], 'float32')
    for feat in bond_id_names:
        data[feat] = np.array(data[feat], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')
    data['atoms'] = np.array(atoms_list)
    return data
def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
    """
    Build the atom symbols and length-weighted adjacency matrix of a mol.

    Args:
        mol: rdkit molecule
        atom_poses: per-atom coordinates, in atom order.
        dir_type: direction type for the bond-angle graph. Currently unused
            because the bond-angle graph is not built here.

    Returns:
        ``(atoms, adj_node)``: the array of atom symbols and the matrix from
        ``gen_adj`` (bond lengths off the diagonal, ones on it). None when
        ``mol`` has no atoms or ``mol_to_graph_data`` rejects it.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    data = mol_to_graph_data(mol)
    if data is None:
        # mol_to_graph_data returns None for molecules it cannot encode;
        # subscripting that result used to raise TypeError.
        return None

    data['atom_pos'] = np.array(atom_poses, 'float32')
    data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
    # The bond-angle graph (Compound3DKit.get_superedge_angles) is
    # intentionally not computed here.
    data['adj_node'] = gen_adj(len(data['atoms']), data['edges'], data['bond_length'])
    return data['atoms'], data['adj_node']
def mol_to_geognn_graph_data_MMFF3d(smiles):
    """
    Build ``(atoms, adj_node)`` for a SMILES string.

    Molecules with at most 400 atoms (hydrogens included) get MMFF-optimised
    coordinates chosen from 10 conformers; larger ones get 2D coordinates.
    """
    parsed = AllChem.MolFromSmiles(smiles)
    mol = Chem.AddHs(parsed)
    if len(mol.GetAtoms()) > 400:
        atom_poses = Compound3DKit.get_2d_atom_poses(mol)
    else:
        mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
    return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')
def mol_to_geognn_graph_data_raw3d(mol):
    """Build ``(atoms, adj_node)`` from the conformer already stored on ``mol``."""
    conformer = mol.GetConformer()
    coordinates = Compound3DKit.get_atom_poses(mol, conformer)
    return mol_to_geognn_graph_data(mol, coordinates, dir_type='HT')
def gen_adj(shape, edges, length):
    """
    Build a dense adjacency matrix weighted by edge length.

    Starts from the ``shape`` x ``shape`` identity matrix and, for the i-th
    entry of ``length``, writes it at ``[edges[i, 0], edges[i, 1]]``. Self
    loops (equal endpoints) are skipped, so the diagonal stays at 1.

    Args:
        shape(int): number of nodes.
        edges: 2-D integer array of ``(src, dst)`` rows.
        length: per-edge values, aligned with the rows of ``edges``.

    Returns:
        the ``shape`` x ``shape`` float matrix.
    """
    adjacency = np.eye(shape)
    for row, edge_length in enumerate(length):
        src, dst = edges[row, 0], edges[row, 1]
        if src == dst:
            continue
        adjacency[src, dst] = float(edge_length)
    return adjacency
if __name__ == "__main__":
    # Script entry: featurise the rows of the screening CSV from START_ROW
    # onwards, save one adjacency matrix per molecule and write an index CSV
    # of (atoms, adjacency path, PCE). The train/test splits were produced
    # with this same loop, pointed at other input files and output folders.
    import pandas as pd

    START_ROW = 22000
    OUT_DIR = 'data/reg/val'

    f = pd.read_csv(r"J:\screenacc\new4.csv")
    pce = f['PCE']

    # np.save does not create missing folders, so make sure the target exists.
    os.makedirs(OUT_DIR, exist_ok=True)

    records = []
    # ``ind`` is the row's position in the full file. It is also used as the
    # label for ``pce[ind]``, which assumes the default integer index that
    # read_csv assigns.
    for ind, smile in enumerate(f.iloc[START_ROW:, 0], start=START_ROW):
        print(ind)
        atom, adj = mol_to_geognn_graph_data_MMFF3d(smile)
        adj_path = 'data/reg/val/adj' + str(ind) + '.npy'
        np.save(adj_path, np.array(adj))
        records.append([atom, adj_path, pce[ind]])

    r = pd.DataFrame(records)
    r.to_csv('data/reg/val/val22000.csv')