# ICLR_FLAG/utils/mol_tree.py
# Author: zaixizhang (commit 10efe81, "renew"; 7.19 kB)
import sys
sys.path.append("..")
import rdkit
import rdkit.Chem as Chem
import copy
import pickle
from tqdm.auto import tqdm
import numpy as np
import torch
import random
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, get_clique_mol_simple
from collections import defaultdict
def get_slots(smiles):
    """Describe every atom of *smiles* as a ``(symbol, formal_charge, num_hs)`` tuple.

    The SMILES is parsed with ``sanitize=False``; the list follows RDKit's atom
    order for the parsed molecule. ``num_hs`` is RDKit's ``GetTotalNumHs()``.
    """
    def _describe(atom):
        return (atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())

    parsed = Chem.MolFromSmiles(smiles, sanitize=False)
    return list(map(_describe, parsed.GetAtoms()))
class Vocab(object):
    """Bidirectional mapping between motif SMILES strings and integer ids.

    The id of a SMILES is its position in the list given to the constructor.
    """

    def __init__(self, smiles_list):
        """Build the vocabulary from *smiles_list*; list order defines the ids."""
        self.vocab = smiles_list
        # If smiles_list contains duplicates, the last occurrence wins here.
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        # Atom slots are computed lazily in get_slots(): parsing every motif
        # with RDKit up front is wasted work when slots are never requested.
        self.slots = None

    def get_index(self, smiles):
        """Return the id of *smiles*, or 0 if it is not in the vocabulary.

        NOTE: 0 is also the id of the first vocabulary entry, so an unknown
        motif is indistinguishable from ``self.vocab[0]``.
        """
        return self.vmap.get(smiles, 0)

    def get_smiles(self, idx):
        """Return the SMILES string stored at id *idx*."""
        return self.vocab[idx]

    def get_slots(self, idx):
        """Return a private (deep) copy of the atom slots of the motif at id *idx*.

        Slots are built on first use for the whole vocabulary with the
        module-level ``get_slots`` function. Previously ``self.slots`` was never
        assigned, so this method always raised AttributeError.
        """
        # getattr() also covers Vocab objects unpickled from before `slots` existed.
        if getattr(self, 'slots', None) is None:
            # Bare `get_slots` resolves to the module-level function, not this method.
            self.slots = [get_slots(smiles) for smiles in self.vocab]
        return copy.deepcopy(self.slots[idx])

    def size(self):
        """Return the number of motifs in the vocabulary."""
        return len(self.vocab)
class MolTreeNode(object):
    """One clique (motif) node of a molecule's junction tree.

    Set in ``__init__``: ``smiles`` (SMILES of the clique fragment ``cmol``),
    ``mol`` (the fragment itself), ``clique`` (atom indices into the parent
    molecule), ``neighbors`` and ``rotatable``.

    ``nid`` and ``is_leaf`` are NOT set by this class; ``recover`` reads them
    on this node and on every non-leaf neighbor, so they must be assigned
    externally first.
    """
    def __init__(self, mol, cmol, clique):
        """Create a node.

        Args:
            mol: the full parent molecule; only used here to read atom degrees.
            cmol: the fragment molecule for this clique.
            clique: iterable of atom indices of *mol* forming this clique.
        """
        self.smiles = Chem.MolToSmiles(cmol, canonical=True)
        self.mol = cmol
        self.clique = [x for x in clique] # copy
        self.neighbors = []
        self.rotatable = False
        # A two-atom clique (a bond, per the note below; bond order is not
        # checked) is flagged rotatable when both end atoms have degree >= 2 in
        # the parent molecule, i.e. neither end is a terminal atom.
        if len(self.clique) == 2:
            if mol.GetAtomWithIdx(self.clique[0]).GetDegree() >= 2 and mol.GetAtomWithIdx(self.clique[1]).GetDegree() >= 2:
                self.rotatable = True
                # should restrict to single bond, but double bond is ok
    def add_neighbor(self, nei_node):
        """Append *nei_node* to this node's neighbor list (one direction only)."""
        self.neighbors.append(nei_node)
    def recover(self, original_mol):
        """Compute the label SMILES of this node together with its neighbors.

        This method temporarily writes node ids onto atoms of *original_mol* as
        atom-map numbers. A non-leaf node tags its own atoms. Each non-leaf
        neighbor tags the atoms it does not share with this node; a single-atom
        neighbor also tags shared atoms.

        The fragment spanning this clique plus all neighbor cliques is then
        extracted. Its SMILES, round-tripped through RDKit, is stored in
        ``self.label``, and the parsed molecule in ``self.label_mol``.
        Finally the atom-map numbers of those atoms are reset to 0.

        Returns:
            ``self.label``.
        """
        clique = []
        clique.extend(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)
        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf: # Leaf node, no need to mark
                continue
            for cidx in nei_node.clique:
                # allow singleton node override the atom mapping
                # (atoms shared with this node keep this node's id unless the
                # neighbor clique has exactly one atom)
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)
        clique = list(set(clique))
        label_mol = get_clique_mol_simple(original_mol, clique)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
        self.label_mol = get_mol(self.label)
        # Undo the tagging so original_mol is left without atom-map numbers.
        # NOTE(review): if anything above raises, this reset is skipped and the
        # map numbers stay on original_mol.
        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
        return self.label
    def assemble(self):
        """Compute assembly candidates for this node via ``enum_assemble``.

        Neighbors are passed ordered by fragment atom count, largest first.
        The results are unzipped into ``self.cands`` and ``self.cand_mols``,
        discarding the third component of each result. Both attributes are
        empty lists when there are no candidates.

        NOTE(review): ``enum_assemble`` is neither defined nor imported in this
        file, so this raises NameError unless it is injected elsewhere. Confirm
        before relying on it.
        """
        # neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(self.neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        # singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        # neighbors = singletons + neighbors
        cands = enum_assemble(self, neighbors)
        if len(cands) > 0:
            self.cands, self.cand_mols, _ = zip(*cands)
            self.cands = list(self.cands)
            self.cand_mols = list(self.cand_mols)
        else:
            self.cands = []
            self.cand_mols = []
class MolTree(object):
    """Junction-tree decomposition of a molecule into clique nodes.

    After construction, ``nodes[0]`` is the root: the clique containing atom 0.
    Every node has a 1-based ``nid`` in list order and an ``is_leaf`` flag,
    which is True when the node has exactly one neighbor.
    """

    def __init__(self, mol):
        """Decompose *mol* into cliques and link them into a tree.

        Args:
            mol: RDKit molecule to decompose (the script in this file loads it
                with ``sanitize=False``).

        Sets ``smiles``, ``mol``, ``nodes`` and ``num_rotatable_bond`` (the
        number of nodes flagged ``rotatable``).
        """
        self.smiles = Chem.MolToSmiles(mol)
        self.mol = mol

        # Vanilla tree decomposition is used for simplicity. The vocab size can
        # instead be controlled with a reference vocabulary plus a threshold,
        # e.g. (presumably handed to tree_decomp's `reference_vocab` argument):
        #   reference_vocab = np.load('./utils/reference.npy', allow_pickle=True).item()
        #   reference = defaultdict(int, reference_vocab)
        cliques, edges = tree_decomp(self.mol, reference_vocab=None)

        self.nodes = []
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol_simple(self.mol, c)
            self.nodes.append(MolTreeNode(self.mol, cmol, c))
            # The clique containing atom 0 becomes the root.
            if min(c) == 0:
                root = i
        self.num_rotatable_bond = sum(1 for node in self.nodes if node.rotatable)

        for x, y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])

        # Move the root to the front. Neighbors already hold node objects, so
        # reordering the list does not invalidate the edges.
        if root > 0:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for i, node in enumerate(self.nodes):
            node.nid = i + 1
            # MolTreeNode.recover() reads is_leaf on nodes; it used to be set
            # only inside disabled code, so recover() raised AttributeError.
            # Node fragment mols are still deliberately not atom-mapped.
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        """Return the number of nodes in the tree."""
        return len(self.nodes)

    def recover(self):
        """Run ``MolTreeNode.recover`` on every node against ``self.mol``."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Run ``MolTreeNode.assemble`` on every node.

        NOTE(review): MolTreeNode.assemble calls ``enum_assemble``, which this
        file does not import. Confirm it is available before calling this.
        """
        for node in self.nodes:
            node.assemble()
if __name__ == "__main__":
    # Build the motif vocabulary (clique SMILES -> occurrence count) over the
    # PDBbind pocket10 ligands and write it to ./vocab.txt, most frequent first.
    # An optional first command-line argument overrides the dataset directory.
    # (An earlier version of this script built the vocabulary from CrossDocked's
    # ./data/crossdocked_pocket10/index.pkl instead.)
    import os  # was used below but never imported

    seed = 2023
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    default_root = '/n/holyscratch01/mzitnik_lab/zaixizhang/pdbbind_pocket10/'
    data_root = sys.argv[1] if len(sys.argv) > 1 else default_root

    vocab = defaultdict(int)
    cnt = 0     # molecules successfully decomposed
    rot = 0     # of those, molecules with at least one rotatable bond
    failed = 0  # ligands that could not be read or decomposed

    # NOTE(review): index entries are assumed to be PDB id strings naming
    # per-complex subdirectories of data_root; confirm against the index builder.
    index = torch.load(os.path.join(data_root, 'index.pt'))
    for pdbid in tqdm(index):
        if pdbid is None:
            continue
        try:
            # The old code referenced an undefined `item` here; the NameError
            # was swallowed by a bare except, so every ligand was skipped.
            ligand_path = os.path.join(data_root, pdbid, pdbid + '_ligand.sdf')
            mol = Chem.MolFromMolFile(ligand_path, sanitize=False)
            moltree = MolTree(mol)
        except Exception:
            # Best effort: skip unreadable or undecomposable ligands, but let
            # KeyboardInterrupt/SystemExit through (unlike a bare except).
            failed += 1
            continue
        cnt += 1
        if moltree.num_rotatable_bond > 0:
            rot += 1
        for c in moltree.nodes:
            vocab[c.smiles] += 1

    # Highest count first; ties broken by reverse lexicographic SMILES.
    vocab = dict(sorted(vocab.items(), key=lambda kv: (kv[1], kv[0]), reverse=True))
    with open('./vocab.txt', 'w') as f:
        for k, v in vocab.items():
            f.write(k + ':' + str(v) + '\n')

    # number of molecules and vocab
    print('Size of the motif vocab:', len(vocab))
    print('Total number of molecules', cnt)
    print('Number of skipped ligands', failed)
    if cnt:
        print('Percent of molecules with rotatable bonds:', rot / cnt)
    else:
        # Avoid ZeroDivisionError when nothing could be processed.
        print('No molecules were processed.')