# -*- coding: utf-8 -*-
"""
Created on Thu Jul 28 14:40:59 2022

@author: BM109X32G-10GPU-02

Utilities that convert SMILES strings / RDKit molecules into graph data:
atom and bond vocabulary ids, fingerprints, conformer coordinates and a
bond-length-weighted adjacency matrix.
"""

import os
from collections import OrderedDict

import numpy as np

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdchem

from compound_constants import DAY_LIGHT_FG_SMARTS_LIST


def get_gasteiger_partial_charges(mol, n_iter=12):
    """
    Calculates list of gasteiger partial charges for each atom in mol object.

    Args:
        mol: rdkit mol object.
        n_iter(int): number of iterations. Default 12.

    Returns:
        list of computed partial charges for each atom.
    """
    Chem.rdPartialCharges.ComputeGasteigerCharges(
        mol, nIter=n_iter, throwOnParamFailure=True)
    return [float(atom.GetProp('_GasteigerCharge')) for atom in mol.GetAtoms()]


def create_standardized_mol_id(smiles):
    """
    Build an InChI for a SMILES string after dropping stereochemistry.

    Args:
        smiles: smiles sequence.

    Returns:
        inchi of the molecule (of its largest species when the SMILES holds
        several), or None when the SMILES cannot be converted.
    """
    if not check_smiles_validity(smiles):
        return None

    # remove stereochemistry
    smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), isomericSmiles=False)
    mol = AllChem.MolFromSmiles(smiles)
    # Re-parsing can still fail, e.g. for
    # O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21
    # so test for None *before* AddHs, which raises when handed None.
    if mol is None:
        return None
    mol = Chem.AddHs(mol)

    if '.' in smiles:
        # if multiple species, pick largest molecule
        mol_species_list = split_rdkit_mol_obj(mol)
        if not mol_species_list:
            return None
        return AllChem.MolToInchi(get_largest_mol(mol_species_list))
    return AllChem.MolToInchi(mol)


def check_smiles_validity(smiles):
    """
    Return True when the smiles can be converted to an rdkit mol object,
    False when parsing yields None or raises.
    """
    try:
        return Chem.MolFromSmiles(smiles) is not None
    except Exception:
        return False


def split_rdkit_mol_obj(mol):
    """
    Split rdkit mol object containing multiple species or one species into a
    list of mol objects or a list containing a single object respectively.

    Args:
        mol: rdkit mol object.

    Returns:
        list of rdkit mol objects, one per '.'-separated species whose SMILES
        parses successfully.
    """
    smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
    return [
        AllChem.MolFromSmiles(species)
        for species in smiles.split('.')
        if check_smiles_validity(species)
    ]


def get_largest_mol(mol_list):
    """
    Given a list of rdkit mol objects, returns mol object containing the
    largest num of atoms. If multiple containing largest num of atoms,
    picks the first one.

    Args:
        mol_list(list): a list of rdkit mol object.

    Returns:
        the largest mol.

    Raises:
        ValueError: if mol_list is empty.
    """
    num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
    return mol_list[num_atoms_list.index(max(num_atoms_list))]


def rdchem_enum_to_list(values):
    """
    Turn an rdkit enum ``values`` mapping into a list ordered by its keys.

    values = {0: rdkit.Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
              1: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
              2: rdkit.Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
              3: rdkit.Chem.rdchem.ChiralType.CHI_OTHER}
    """
    return [values[i] for i in range(len(values))]


def safe_index(alist, elem):
    """
    Return index of element ``elem`` in list ``alist``.
    If ``elem`` is not present, return the last index.
    """
    try:
        return alist.index(elem)
    except ValueError:
        return len(alist) - 1


def get_atom_feature_dims(list_acquired_feature_names):
    """Return the vocabulary size of each requested atom feature name."""
    return [len(CompoundKit.atom_vocab_dict[name]) for name in list_acquired_feature_names]


def get_bond_feature_dims(list_acquired_feature_names):
    """
    Return the vocabulary size of each requested bond feature name,
    with one extra slot reserved for self loop edges.
    """
    # +1 for self loop edges
    return [len(CompoundKit.bond_vocab_dict[name]) + 1 for name in list_acquired_feature_names]


class CompoundKit(object):
    """
    CompoundKit: static helpers for atom / bond features, fingerprints and
    functional group counts.
    """
    atom_vocab_dict = {
        "atomic_num": list(range(1, 119)) + ['misc'],
        "chiral_tag": rdchem_enum_to_list(rdchem.ChiralType.values),
    }
    bond_vocab_dict = {
        "bond_dir": rdchem_enum_to_list(rdchem.BondDir.values),
        "bond_type": rdchem_enum_to_list(rdchem.BondType.values),
    }
    # float features
    atom_float_names = ["van_der_waals_radis", "partial_charge", 'mass']
    # bond_float_feats= ["bond_length", "bond_angle"]     # optional

    # Vocabulary used for the 'in_num_ring_with_sizeK' features when
    # atom_vocab_dict has no entry for them; get_ring_size caps counts at 9,
    # which safe_index maps onto the trailing 'misc' slot.
    _ring_count_vocab = list(range(9)) + ['misc']

    ### functional groups
    day_light_fg_smarts_list = DAY_LIGHT_FG_SMARTS_LIST
    day_light_fg_mo_list = [Chem.MolFromSmarts(smarts) for smarts in day_light_fg_smarts_list]

    morgan_fp_N = 200
    morgan2048_fp_N = 2048
    maccs_fp_N = 167

    period_table = Chem.GetPeriodicTable()

    ### atom

    @staticmethod
    def get_atom_value(atom, name):
        """
        Get the raw value of the atom property called ``name``.

        Raises:
            ValueError: if ``name`` is not a supported property.
        """
        if name == 'atomic_num':
            return atom.GetAtomicNum()
        if name == 'chiral_tag':
            return atom.GetChiralTag()
        if name == 'degree':
            return atom.GetDegree()
        if name == 'explicit_valence':
            return atom.GetExplicitValence()
        if name == 'formal_charge':
            return atom.GetFormalCharge()
        if name == 'hybridization':
            return atom.GetHybridization()
        if name == 'implicit_valence':
            return atom.GetImplicitValence()
        if name == 'is_aromatic':
            return int(atom.GetIsAromatic())
        if name == 'mass':
            # truncated to an integer
            return int(atom.GetMass())
        if name == 'total_numHs':
            return atom.GetTotalNumHs()
        if name == 'num_radical_e':
            return atom.GetNumRadicalElectrons()
        if name == 'atom_is_in_ring':
            return int(atom.IsInRing())
        if name == 'valence_out_shell':
            return CompoundKit.period_table.GetNOuterElecs(atom.GetAtomicNum())
        raise ValueError(name)

    @staticmethod
    def get_atom_feature_id(atom, name):
        """Get the index of the atom's ``name`` value in atom_vocab_dict[name]."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return safe_index(CompoundKit.atom_vocab_dict[name], CompoundKit.get_atom_value(atom, name))

    @staticmethod
    def get_atom_feature_size(name):
        """Get the vocabulary size of atom feature ``name``."""
        assert name in CompoundKit.atom_vocab_dict, "%s not found in atom_vocab_dict" % name
        return len(CompoundKit.atom_vocab_dict[name])

    ### bond

    @staticmethod
    def get_bond_value(bond, name):
        """
        Get the raw value of the bond property called ``name``.

        Raises:
            ValueError: if ``name`` is not a supported property.
        """
        if name == 'bond_dir':
            return bond.GetBondDir()
        if name == 'bond_type':
            return bond.GetBondType()
        if name == 'is_in_ring':
            return int(bond.IsInRing())
        if name == 'is_conjugated':
            return int(bond.GetIsConjugated())
        if name == 'bond_stereo':
            return bond.GetStereo()
        raise ValueError(name)

    @staticmethod
    def get_bond_feature_id(bond, name):
        """Get the index of the bond's ``name`` value in bond_vocab_dict[name]."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return safe_index(CompoundKit.bond_vocab_dict[name], CompoundKit.get_bond_value(bond, name))

    @staticmethod
    def get_bond_feature_size(name):
        """Get the vocabulary size of bond feature ``name``."""
        assert name in CompoundKit.bond_vocab_dict, "%s not found in bond_vocab_dict" % name
        return len(CompoundKit.bond_vocab_dict[name])

    ### fingerprint

    @staticmethod
    def get_morgan_fingerprint(mol, radius=2):
        """Get the morgan fingerprint as a list of morgan_fp_N 0/1 ints."""
        nBits = CompoundKit.morgan_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_morgan2048_fingerprint(mol, radius=2):
        """Get the morgan fingerprint as a list of morgan2048_fp_N 0/1 ints."""
        nBits = CompoundKit.morgan2048_fp_N
        mfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return [int(b) for b in mfp.ToBitString()]

    @staticmethod
    def get_maccs_fingerprint(mol):
        """Get the maccs fingerprint as a list of 0/1 ints."""
        fp = AllChem.GetMACCSKeysFingerprint(mol)
        return [int(b) for b in fp.ToBitString()]

    ### functional groups

    @staticmethod
    def get_daylight_functional_group_counts(mol):
        """Get, for each daylight functional group pattern, its number of unique matches."""
        fg_counts = []
        for fg_mol in CompoundKit.day_light_fg_mo_list:
            sub_structs = Chem.Mol.GetSubstructMatches(mol, fg_mol, uniquify=True)
            fg_counts.append(len(sub_structs))
        return fg_counts

    @staticmethod
    def get_ring_size(mol):
        """
        Count, per atom, the rings of size 3..8 the atom belongs to.

        Returns:
            (N,6) list; entry [i][k] is the number of rings of size k + 3
            containing atom i, replaced by 9 when it exceeds 8.
        """
        atom_rings = [set(ring) for ring in mol.GetRingInfo().AtomRings()]
        ring_list = []
        for atom in mol.GetAtoms():
            atom_idx = atom.GetIdx()
            atom_result = []
            for ringsize in range(3, 9):
                num_of_ring_at_ringsize = sum(
                    1 for ring in atom_rings
                    if len(ring) == ringsize and atom_idx in ring)
                if num_of_ring_at_ringsize > 8:
                    num_of_ring_at_ringsize = 9
                atom_result.append(num_of_ring_at_ringsize)
            ring_list.append(atom_result)
        return ring_list

    @staticmethod
    def atom_to_feat_vector(atom):
        """Return a dict holding the atom's 'atomic_num' vocabulary id."""
        return {
            "atomic_num": safe_index(CompoundKit.atom_vocab_dict["atomic_num"], atom.GetAtomicNum()),
        }

    @staticmethod
    def get_atom_names(mol):
        """
        Get one feature dict per atom.

        Each dict holds the id of every feature in atom_vocab_dict, the float
        features listed in atom_float_names and the six
        'in_num_ring_with_sizeK' (K = 3..8) ring-count ids.
        Gasteiger charges are computed on ``mol`` as a side effect.

        TODO: to be remove in the future
        """
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)
        ring_list = CompoundKit.get_ring_size(mol)

        atom_features_dicts = []
        for i, atom in enumerate(mol.GetAtoms()):
            feats = CompoundKit.atom_to_feat_vector(atom)
            # Fill every id feature declared in the vocabulary so callers can
            # look up any name from atom_vocab_dict.
            for name in CompoundKit.atom_vocab_dict:
                if name not in feats:
                    feats[name] = CompoundKit.get_atom_feature_id(atom, name)
            # Float features announced by atom_float_names.
            feats.setdefault(
                'van_der_waals_radis',
                CompoundKit.period_table.GetRvdw(atom.GetAtomicNum()))
            feats.setdefault('partial_charge', CompoundKit.check_partial_charge(atom))
            feats.setdefault('mass', atom.GetMass())
            # Ring membership counts; the vocabulary may be absent from
            # atom_vocab_dict, in which case the default count vocab is used.
            for size_pos, ringsize in enumerate(range(3, 9)):
                key = 'in_num_ring_with_size%d' % ringsize
                vocab = CompoundKit.atom_vocab_dict.get(key, CompoundKit._ring_count_vocab)
                feats[key] = safe_index(vocab, ring_list[i][size_pos])
            atom_features_dicts.append(feats)
        return atom_features_dicts

    @staticmethod
    def check_partial_charge(atom):
        """
        Read the atom's '_GasteigerCharge' property, mapping nan to 0 and
        +inf to 10.
        """
        pc = atom.GetDoubleProp('_GasteigerCharge')
        if pc != pc:
            # unsupported atom, replace nan with 0
            pc = 0
        if pc == float('inf'):
            # max 4 for other atoms, set to 10 here if inf is get
            pc = 10
        return pc


class Compound3DKit(object):
    """the 3Dkit of Compound"""

    @staticmethod
    def get_atom_poses(mol, conf):
        """
        Collect the [x, y, z] position of every atom from ``conf``.

        If the molecule contains an atom with atomic number 0, an all-zero
        position is returned for every atom instead.
        """
        atom_poses = []
        for i, atom in enumerate(mol.GetAtoms()):
            if atom.GetAtomicNum() == 0:
                # independent rows, so editing one position cannot affect the others
                return [[0.0, 0.0, 0.0] for _ in range(len(mol.GetAtoms()))]
            pos = conf.GetAtomPosition(i)
            atom_poses.append([pos.x, pos.y, pos.z])
        return atom_poses

    @staticmethod
    def get_MMFF_atom_poses(mol, numConfs=None, return_energy=False):
        """
        Embed conformers, MMFF-optimise them and return the lowest-energy one.

        The atoms of mol will be changed in some cases: hydrogens are added
        and the returned molecule is that new molecule. When embedding or
        optimisation raises, 2D coordinates are used instead and the energy
        is reported as 0.

        Returns:
            (new_mol, atom_poses), or (new_mol, atom_poses, energy) when
            ``return_energy`` is true.
        """
        try:
            new_mol = Chem.AddHs(mol)
            AllChem.EmbedMultipleConfs(new_mol, numConfs=numConfs)
            ### MMFF generates multiple conformations
            res = AllChem.MMFFOptimizeMoleculeConfs(new_mol)
            # new_mol = Chem.RemoveHs(new_mol)
            index = np.argmin([x[1] for x in res])
            energy = res[index][1]
            conf = new_mol.GetConformer(id=int(index))
        except Exception:
            # Best-effort fallback; Exception (not a bare except) so that
            # KeyboardInterrupt / SystemExit still propagate.
            new_mol = Chem.AddHs(mol)
            AllChem.Compute2DCoords(new_mol)
            energy = 0
            conf = new_mol.GetConformer()

        atom_poses = Compound3DKit.get_atom_poses(new_mol, conf)
        if return_energy:
            return new_mol, atom_poses, energy
        return new_mol, atom_poses

    @staticmethod
    def get_2d_atom_poses(mol):
        """Compute 2D coordinates on ``mol`` and return the atom positions."""
        AllChem.Compute2DCoords(mol)
        conf = mol.GetConformer()
        return Compound3DKit.get_atom_poses(mol, conf)

    @staticmethod
    def get_bond_lengths(edges, atom_poses):
        """
        Get the Euclidean distance between the two end atoms of every edge.

        Returns:
            float32 ndarray with one length per edge.
        """
        poses = np.asarray(atom_poses)
        bond_lengths = [
            np.linalg.norm(poses[tar_node_j] - poses[src_node_i])
            for src_node_i, tar_node_j in edges
        ]
        return np.array(bond_lengths, 'float32')

    @staticmethod
    def get_superedge_angles(edges, atom_poses, dir_type='HT'):
        """
        Get the angles between pairs of edges that share an atom.

        Args:
            edges: (E, 2) array of (source, target) atom indices.
            atom_poses: atom coordinates, indexable by atom index.
            dir_type: 'HT' pairs each edge with the edges ending at its
                source atom, 'HH' with the edges ending at its target atom.

        Returns:
            super_edges ((S, 2) int64 array of [src_edge, tar_edge]),
            bond_angles ((S,) float32 array) and bond_angle_dirs (list of
            booleans).

        Raises:
            ValueError: if dir_type is neither 'HT' nor 'HH'.
        """
        edges = np.asarray(edges)
        atom_poses = np.asarray(atom_poses)

        def _get_vec(atom_poses, edge):
            return atom_poses[edge[1]] - atom_poses[edge[0]]

        def _get_angle(vec1, vec2):
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)
            if norm1 == 0 or norm2 == 0:
                return 0
            vec1 = vec1 / (norm1 + 1e-5)    # 1e-5: prevent numerical errors
            vec2 = vec2 / (norm2 + 1e-5)
            # clip so rounding can never push the cosine outside arccos' domain
            return np.arccos(np.clip(np.dot(vec1, vec2), -1.0, 1.0))

        E = len(edges)
        edge_indices = np.arange(E)
        super_edges = []
        bond_angles = []
        bond_angle_dirs = []
        for tar_edge_i in range(E):
            tar_edge = edges[tar_edge_i]
            if dir_type == 'HT':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[0]]
            elif dir_type == 'HH':
                src_edge_indices = edge_indices[edges[:, 1] == tar_edge[1]]
            else:
                raise ValueError(dir_type)
            for src_edge_i in src_edge_indices:
                if src_edge_i == tar_edge_i:
                    continue
                src_edge = edges[src_edge_i]
                src_vec = _get_vec(atom_poses, src_edge)
                tar_vec = _get_vec(atom_poses, tar_edge)
                super_edges.append([src_edge_i, tar_edge_i])
                bond_angles.append(_get_angle(src_vec, tar_vec))
                bond_angle_dirs.append(src_edge[1] == tar_edge[0])  # H -> H or H -> T

        if len(super_edges) == 0:
            super_edges = np.zeros([0, 2], 'int64')
            bond_angles = np.zeros([0, ], 'float32')
        else:
            super_edges = np.array(super_edges, 'int64')
            bond_angles = np.array(bond_angles, 'float32')
        return super_edges, bond_angles, bond_angle_dirs


def new_smiles_to_graph_data(smiles, **kwargs):
    """
    Convert smiles to graph data.

    Returns:
        the dict built by new_mol_to_graph_data, or None when the SMILES
        cannot be parsed.
    """
    mol = AllChem.MolFromSmiles(smiles)
    # check before AddHs: AddHs raises when handed None
    if mol is None:
        return None
    return new_mol_to_graph_data(Chem.AddHs(mol))


def new_mol_to_graph_data(mol):
    """
    Build graph data (atom features, bond features, fingerprints and
    functional group counts) for an rdkit mol.

    Returns:
        dict of ndarrays, or None when the molecule has no atoms. Every bond
        contributes two directed edges and every atom one self loop edge.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = list(CompoundKit.atom_vocab_dict.keys()) + CompoundKit.atom_float_names
    bond_id_names = list(CompoundKit.bond_vocab_dict.keys())

    ### atom features
    data = {name: [] for name in atom_id_names}
    for atom_feat in CompoundKit.get_atom_names(mol):
        for name in atom_id_names:
            data[name].append(atom_feat[name])

    ### bond and bond features
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # i->j and j->i
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name)
            data[name] += [bond_feature_id] * 2

    #### self loop
    N = len(data[atom_id_names[0]])
    data['edges'] += [(i, i) for i in range(N)]
    for name in bond_id_names:
        bond_feature_id = get_bond_feature_dims([name])[0] - 1   # self loop: value = len - 1
        data[name] += [bond_feature_id] * N

    ### make ndarray and check length
    for name in CompoundKit.atom_vocab_dict:
        data[name] = np.array(data[name], 'int64')
    for name in CompoundKit.atom_float_names:
        data[name] = np.array(data[name], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')

    ### morgan fingerprint
    data['morgan_fp'] = np.array(CompoundKit.get_morgan_fingerprint(mol), 'int64')
    # data['morgan2048_fp'] = np.array(CompoundKit.get_morgan2048_fingerprint(mol), 'int64')
    data['maccs_fp'] = np.array(CompoundKit.get_maccs_fingerprint(mol), 'int64')
    data['daylight_fg_counts'] = np.array(
        CompoundKit.get_daylight_functional_group_counts(mol), 'int64')
    return data


def mol_to_graph_data(mol):
    """
    Build the reduced graph data used by the geometry pipeline.

    Returns:
        dict with 'atomic_num' (vocab id + 1), 'mass' (integer mass * 0.01),
        'bond_dir' / 'bond_type' (vocab id + 1, self loops use vocab size
        + 2), 'edges' (two directed edges per bond followed by one self loop
        per atom) and 'atoms' (element symbols); or None when the molecule
        has no atoms or contains an atom with atomic number 0.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    atom_id_names = ["atomic_num"]
    bond_id_names = ["bond_dir", "bond_type"]

    data = {name: [] for name in atom_id_names}
    data['mass'] = []
    for name in bond_id_names:
        data[name] = []
    data['edges'] = []
    atoms_list = []

    ### atom features
    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() == 0:
            return None
        for name in atom_id_names:
            data[name].append(CompoundKit.get_atom_feature_id(atom, name) + 1)  # 0: OOV
        data['mass'].append(CompoundKit.get_atom_value(atom, 'mass') * 0.01)
        atoms_list.append(atom.GetSymbol())

    ### bond features
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        # i->j and j->i
        data['edges'] += [(i, j), (j, i)]
        for name in bond_id_names:
            bond_feature_id = CompoundKit.get_bond_feature_id(bond, name) + 1   # 0: OOV
            data[name] += [bond_feature_id] * 2

    ### self loop (+2)
    # At least one atom exists here, so 'edges' is never empty afterwards.
    N = len(data[atom_id_names[0]])
    data['edges'] += [(i, i) for i in range(N)]
    for name in bond_id_names:
        bond_feature_id = CompoundKit.get_bond_feature_size(name) + 2   # N + 2: self loop
        data[name] += [bond_feature_id] * N

    ### make ndarray and check length
    for name in atom_id_names:
        data[name] = np.array(data[name], 'int64')
    data['mass'] = np.array(data['mass'], 'float32')
    for name in bond_id_names:
        data[name] = np.array(data[name], 'int64')
    data['edges'] = np.array(data['edges'], 'int64')
    data['atoms'] = np.array(atoms_list)

    # Fingerprints and functional group counts are not computed here; see
    # new_mol_to_graph_data for those.
    return data


def mol_to_geognn_graph_data(mol, atom_poses, dir_type):
    """
    Build the atom symbols and bond-length adjacency matrix of a molecule.

    Args:
        mol: rdkit molecule
        atom_poses: one [x, y, z] position per atom.
        dir_type: direction type for bond_angle graph (currently unused,
            the bond-angle graph is not built).

    Returns:
        (atoms, adj_node): the array of element symbols and the matrix from
        gen_adj; or None when the molecule has no atoms or mol_to_graph_data
        rejects it.
    """
    if len(mol.GetAtoms()) == 0:
        return None

    data = mol_to_graph_data(mol)
    if data is None:
        # e.g. the molecule contains an atom with atomic number 0
        return None

    data['atom_pos'] = np.array(atom_poses, 'float32')
    data['bond_length'] = Compound3DKit.get_bond_lengths(data['edges'], data['atom_pos'])
    data['adj_node'] = gen_adj(len(data['atoms']), data['edges'], data['bond_length'])
    return data['atoms'], data['adj_node']


def mol_to_geognn_graph_data_MMFF3d(smiles):
    """
    Build (atoms, adj_node) for a SMILES string.

    Molecules with at most 400 atoms (hydrogens included) get MMFF-optimised
    coordinates from 10 embedded conformers, larger ones get 2D coordinates.

    Returns:
        the result of mol_to_geognn_graph_data, or None when the SMILES
        cannot be parsed.
    """
    mol = AllChem.MolFromSmiles(smiles)
    # check before AddHs: AddHs raises when handed None
    if mol is None:
        return None
    mol = Chem.AddHs(mol)
    if len(mol.GetAtoms()) <= 400:
        mol, atom_poses = Compound3DKit.get_MMFF_atom_poses(mol, numConfs=10)
    else:
        atom_poses = Compound3DKit.get_2d_atom_poses(mol)
    return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')


def mol_to_geognn_graph_data_raw3d(mol):
    """Build (atoms, adj_node) from the conformer already stored on ``mol``."""
    atom_poses = Compound3DKit.get_atom_poses(mol, mol.GetConformer())
    return mol_to_geognn_graph_data(mol, atom_poses, dir_type='HT')


def gen_adj(shape, edges, length):
    """
    Build a weighted adjacency matrix.

    Args:
        shape: number of nodes; the result is shape x shape.
        edges: sequence of (source, target) node index pairs.
        length: one weight per edge, aligned with ``edges``.

    Returns:
        identity matrix in which entry [source, target] of every non
        self loop edge is replaced by that edge's weight.
    """
    adj = np.eye(shape)
    for (src, dst), edge_length in zip(edges, length):
        if src != dst:
            adj[src, dst] = float(edge_length)
    return adj


if __name__ == "__main__":
    import pandas as pd

    f = pd.read_csv(r"J:\screenacc\new4.csv")
    # NOTE: earlier variants of this script (previously kept as commented-out
    # code) ran the same loop over column 1 of the frame, writing to
    # 'data/reg/train/' + 'train.csv' and, for 'data/reg/test3.csv', to
    # 'data/reg/test/' + 'test.csv'.
    start = 22000
    records = []
    pce = f['PCE']
    for ind, smile in enumerate(f.iloc[start:, 0], start=start):
        print(ind)
        result = mol_to_geognn_graph_data_MMFF3d(smile)
        if result is None:
            # unparsable SMILES or unsupported atoms: skip instead of aborting the run
            print('skipped', ind)
            continue
        atom, adj = result
        adj_path = 'data/reg/val/adj' + str(ind) + '.npy'
        np.save(adj_path, np.array(adj))
        records.append([atom, adj_path, pce[ind]])
    r = pd.DataFrame(records)
    r.to_csv('data/reg/val/val22000.csv')