zaixizhang
committed on
Commit
•
10efe81
1
Parent(s):
465c18c
renew
Browse files- .gitattributes +1 -0
- app.py +48 -4
- checkpoints/pretrained.pt +3 -0
- configs/sample.yml +18 -0
- data/index.pt +3 -0
- data/pdbbind_pocket10_name2id.pt +3 -0
- data/pdbbind_pocket10_processed.lmdb +3 -0
- data/pdbbind_pocket10_processed.lmdb-lock +0 -0
- data/split_by_name.pt +3 -0
- evaluation/prepare_receptor4.py +183 -0
- evaluation/vina_score.py +35 -0
- models/common.py +282 -0
- models/encoders/__init__.py +26 -0
- models/encoders/gnn.py +441 -0
- models/encoders/schnet.py +105 -0
- models/encoders/tf.py +152 -0
- models/flag.py +268 -0
- motif_sample.py +660 -0
- requirements.txt +390 -0
- utils/__init__.py +3 -0
- utils/chem.py +119 -0
- utils/chemutils.py +597 -0
- utils/data.py +127 -0
- utils/datasets/__init__.py +21 -0
- utils/datasets/pl.py +176 -0
- utils/dihedral_utils.py +383 -0
- utils/docking.py +183 -0
- utils/fpscores.pkl.gz +3 -0
- utils/misc.py +78 -0
- utils/mol_tree.py +220 -0
- utils/protein_ligand.py +283 -0
- utils/reconstruct.py +498 -0
- utils/sascorer.py +163 -0
- utils/similarity.py +20 -0
- utils/train.py +102 -0
- utils/transforms.py +684 -0
- utils/warmup.py +86 -0
- vocab.txt +549 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.lmdb filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -1,7 +1,51 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import os
|
3 |
+
from rdkit import Chem
|
4 |
+
from motif_sample import demo
|
5 |
+
import py3Dmol
|
6 |
+
import tempfile
|
7 |
+
from rdkit.Chem import AllChem
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image
|
10 |
+
import io
|
11 |
+
from rdkit.Chem import Draw
|
12 |
|
|
|
|
|
13 |
|
14 |
+
# Function to serve the file via Gradio
|
15 |
+
def create_and_return_sdf(Protein_index: int):
    """Generate a ligand for one protein index and return it for display.

    Args:
        Protein_index: Index handed to ``demo``. The Gradio text box delivers
            it as a string, so anything accepted by ``int()`` is allowed.

    Returns:
        A ``(image, sdf_path)`` tuple: a numpy array holding the first three
        channels of a 2D depiction of the first molecule in the generated SDF
        file, and the SDF path exactly as returned by ``demo``.

    Raises:
        ValueError: If the index is not an integer, or the generated SDF file
            contains no molecule that RDKit can parse.
    """
    # Ensure input is an integer, as it's coming from the interface.
    try:
        number = int(Protein_index)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Protein index must be an integer, got {Protein_index!r}"
        ) from err

    sdf_filename = demo(number)

    # The supplier yields None for records it cannot parse and nothing at all
    # for an empty file; both cases are reported explicitly instead of
    # failing later inside RDKit.
    suppl = Chem.SDMolSupplier(sdf_filename)
    mol = next(iter(suppl), None)
    if mol is None:
        raise ValueError(f"No readable molecule found in {sdf_filename!r}")

    # Round-trip through SMILES so the depiction is drawn from the molecular
    # graph only.
    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if mol is None:
        raise ValueError(
            f"Molecule in {sdf_filename!r} could not be rebuilt from SMILES"
        )

    # Clear atom-map numbers so they are not rendered as labels.
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)

    mol_image = Draw.MolToImage(mol)

    # Keep only the first three channels of the rendered image.
    np_image = np.array(mol_image)[:, :, :3]

    return np_image, sdf_filename
|
37 |
+
|
38 |
+
|
39 |
+
# Define Gradio interface
|
40 |
+
# Older Gradio releases expose output components under ``gr.outputs``; that
# namespace is absent in Gradio 4, where the same components live at the top
# level. Use the legacy namespace when it exists so behaviour is unchanged
# there, and fall back to the top-level components otherwise.
_output_components = getattr(gr, "outputs", gr)

iface = gr.Interface(
    fn=create_and_return_sdf,
    inputs="text",
    outputs=[
        _output_components.Image(type="numpy", label="Molecule Image"),
        _output_components.File(label="Download SDF"),
    ],
    live=False,  # The function should only be called when the user submits the form
)

# Launch the interface
iface.launch(share=True)
|
checkpoints/pretrained.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a61f9bee0a6ce3101d8df5b55d71831d8428c4d6ff82ab81ada8bb5277babd42
|
3 |
+
size 44147405
|
configs/sample.yml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
name: pl
|
3 |
+
path: ./data
|
4 |
+
split: ./data/split_by_name.pt
|
5 |
+
|
6 |
+
model:
|
7 |
+
checkpoint: ./checkpoints/pretrained.pt
|
8 |
+
hidden_channels: 256
|
9 |
+
random_alpha: False
|
10 |
+
|
11 |
+
sample:
|
12 |
+
seed: 2024
|
13 |
+
num_samples: 100
|
14 |
+
num_retry: 5
|
15 |
+
max_steps: 12
|
16 |
+
batch_size: 10
|
17 |
+
num_workers: 4
|
18 |
+
n_samples: 5
|
data/index.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c095461a584f03838af99d6411483040f641659d1521dc5247bcefd72e36ca8e
|
3 |
+
size 226859
|
data/pdbbind_pocket10_name2id.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0aa8a46e2cd77abb5b59221ece57e607ba30427d964d9d5b08c0d02e3449399c
|
3 |
+
size 237265
|
data/pdbbind_pocket10_processed.lmdb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a21c6a2231e3394d3c8e27b1e547a9376f4ee0bdc1c246fbbf32a8fc076710eb
|
3 |
+
size 386912256
|
data/pdbbind_pocket10_processed.lmdb-lock
ADDED
Binary file (8.19 kB). View file
|
|
data/split_by_name.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d406b7bd5189c35cbe73943b5c8e0d61ae1fdc79b40d24dbb45c474915892b56
|
3 |
+
size 227451
|
evaluation/prepare_receptor4.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
#
|
3 |
+
#
|
4 |
+
#
|
5 |
+
# $Header: /opt/cvs/python/packages/share1.5/AutoDockTools/Utilities24/prepare_receptor4.py,v 1.11 2007/11/28 22:40:22 rhuey Exp $
|
6 |
+
#
|
7 |
+
import os
|
8 |
+
|
9 |
+
from MolKit import Read
|
10 |
+
import MolKit.molecule
|
11 |
+
import MolKit.protein
|
12 |
+
from AutoDockTools.MoleculePreparation import AD4ReceptorPreparation
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == '__main__':
    import sys
    import getopt


    def _say(*parts):
        """Write the parts to stdout separated by single spaces, then a newline.

        Reproduces the output of a Python 2 ``print a, b, c`` statement while
        also being valid Python 3, so the script runs under either interpreter.
        """
        sys.stdout.write(' '.join([str(p) for p in parts]) + '\n')


    def usage():
        "Print helpful, accurate usage statement to stdout."
        _say("Usage: prepare_receptor4.py -r filename")
        _say()
        _say(" Description of command...")
        _say(" -r receptor_filename ")
        _say(" supported file types include pdb,mol2,pdbq,pdbqs,pdbqt, possibly pqr,cif")
        _say(" Optional parameters:")
        _say(" [-v] verbose output (default is minimal output)")
        _say(" [-o pdbqt_filename] (default is 'molecule_name.pdbqt')")
        _say(" [-A] type(s) of repairs to make: ")
        _say(" 'bonds_hydrogens': build bonds and add hydrogens ")
        _say(" 'bonds': build a single bond from each atom with no bonds to its closest neighbor")
        _say(" 'hydrogens': add hydrogens")
        _say(" 'checkhydrogens': add hydrogens only if there are none already")
        _say(" 'None': do not make any repairs ")
        _say(" (default is 'checkhydrogens')")
        _say(" [-C] preserve all input charges ie do not add new charges ")
        _say(" (default is addition of gasteiger charges)")
        _say(" [-p] preserve input charges on specific atom types, eg -p Zn -p Fe")
        _say(" [-U] cleanup type:")
        _say(" 'nphs': merge charges and remove non-polar hydrogens")
        _say(" 'lps': merge charges and remove lone pairs")
        _say(" 'waters': remove water residues")
        _say(" 'nonstdres': remove chains composed entirely of residues of")
        _say(" types other than the standard 20 amino acids")
        _say(" 'deleteAltB': remove XX@B atoms and rename XX@A atoms->XX")
        _say(" (default is 'nphs_lps_waters_nonstdres') ")
        _say(" [-e] delete every nonstd residue from any chain")
        _say(" 'True': any residue whose name is not in this list:")
        _say(" ['CYS','ILE','SER','VAL','GLN','LYS','ASN', ")
        _say(" 'PRO','THR','PHE','ALA','HIS','GLY','ASP', ")
        _say(" 'LEU', 'ARG', 'TRP', 'GLU', 'TYR','MET', ")
        _say(" 'HID', 'HSP', 'HIE', 'HIP', 'CYX', 'CSS']")
        _say(" will be deleted from any chain. ")
        _say(" NB: there are no nucleic acid residue names at all ")
        _say(" in the list and no metals. ")
        _say(" (default is False which means not to do this)")
        _say(" [-M] interactive ")
        _say(" (default is 'automatic': outputfile is written with no further user input)")


    # process command arguments
    try:
        # 'h' is part of the spec so that the '-h' branch further down is
        # reachable; without it getopt rejects '-h' as an unknown option.
        opt_list, args = getopt.getopt(sys.argv[1:], 'r:vo:A:Cp:U:eM:h')

    except getopt.GetoptError as msg:
        _say('prepare_receptor4.py: %s' % msg)
        usage()
        sys.exit(2)

    # initialize required parameters
    #-r: receptor
    receptor_filename = None

    # optional parameters
    verbose = None
    #-A: repairs to make: add bonds and/or hydrogens or checkhydrogens
    repairs = ''
    #-C default: add gasteiger charges
    charges_to_add = 'gasteiger'
    #-p preserve charges on specific atom types
    preserve_charge_types = None
    #-U: cleanup by merging nphs_lps, nphs, lps, waters, nonstdres
    cleanup = "nphs_lps_waters_nonstdres"
    #-o outputfilename
    outputfilename = None
    #-M mode
    mode = 'automatic'
    #-e delete every nonstd residue from each chain
    delete_single_nonstd_residues = None

    for o, a in opt_list:
        if o in ('-r', '--r'):
            receptor_filename = a
            if verbose: _say('set receptor_filename to ', a)
        if o in ('-v', '--v'):
            verbose = True
            if verbose: _say('set verbose to ', True)
        if o in ('-o', '--o'):
            outputfilename = a
            if verbose: _say('set outputfilename to ', a)
        if o in ('-A', '--A'):
            repairs = a
            if verbose: _say('set repairs to ', a)
        if o in ('-C', '--C'):
            charges_to_add = None
            if verbose: _say('do not add charges')
        if o in ('-p', '--p'):
            # Repeated -p options accumulate into one comma-separated string.
            if not preserve_charge_types:
                preserve_charge_types = a
            else:
                preserve_charge_types = preserve_charge_types + ',' + a
            if verbose: _say('preserve initial charges on ', preserve_charge_types)
        if o in ('-U', '--U'):
            cleanup = a
            if verbose: _say('set cleanup to ', a)
        if o in ('-e', '--e'):
            delete_single_nonstd_residues = True
            if verbose: _say('set delete_single_nonstd_residues to True')
        if o in ('-M', '--M'):
            mode = a
            if verbose: _say('set mode to ', a)
        if o in ('-h', '--'):
            usage()
            sys.exit()


    if not receptor_filename:
        _say('prepare_receptor4: receptor filename must be specified.')
        usage()
        sys.exit()

    #what about nucleic acids???

    mols = Read(receptor_filename)
    if verbose: _say('read ', receptor_filename)
    mol = mols[0]
    # Maps atom -> [chargeSet, charge] for atoms whose charges must survive
    # the charge assignment done during preparation.
    # NOTE(review): this is collected from the first molecule, before the
    # largest molecule is selected below -- confirm that is intended.
    preserved = {}
    if charges_to_add is not None and preserve_charge_types is not None:
        preserved_types = preserve_charge_types.split(',')
        if verbose: _say("preserved_types=", preserved_types)
        for t in preserved_types:
            if verbose: _say('preserving charges on type->', t)
            if not len(t): continue
            ats = mol.allAtoms.get(lambda x: x.autodock_element==t)
            if verbose: _say("preserving charges on ", ats.name)
            for a in ats:
                if a.chargeSet is not None:
                    preserved[a] = [a.chargeSet, a.charge]

    if len(mols)>1:
        if verbose: _say("more than one molecule in file")
        #use the molecule with the most atoms
        ctr = 1
        for m in mols[1:]:
            ctr += 1
            if len(m.allAtoms)>len(mol.allAtoms):
                mol = m
                if verbose: _say("mol set to ", ctr, "th molecule with", len(mol.allAtoms), "atoms")
    mol.buildBondsByDistance()

    if verbose:
        _say("setting up RPO with mode=", mode, "and outputfilename= ", outputfilename)
        _say("charges_to_add=", charges_to_add)
        _say("delete_single_nonstd_residues=", delete_single_nonstd_residues)

    RPO = AD4ReceptorPreparation(mol, mode, repairs, charges_to_add,
                        cleanup, outputfilename=outputfilename,
                        preserved=preserved,
                        delete_single_nonstd_residues=delete_single_nonstd_residues)

    if charges_to_add is not None:
        #restore any previous charges
        for atom, chargeList in preserved.items():
            atom._charges[chargeList[0]] = chargeList[1]
            atom.chargeSet = chargeList[0]
|
179 |
+
|
180 |
+
|
181 |
+
# To execute this command type:
|
182 |
+
# prepare_receptor4.py -r pdb_file -o outputfilename -A checkhydrogens
|
183 |
+
|
evaluation/vina_score.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from vina import Vina
|
2 |
+
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
|
3 |
+
from rdkit import Chem
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Walk over candidate files ./0.sdf ... ./99.sdf and run the Vina workflow
# for every one that exists.
for i in range(100):
    path = f'./{i}.sdf'
    if not os.path.exists(path):
        continue
    print(path)

    v = Vina(sf_name='vina')
    v.set_receptor('2rma_protein.pdbqt')
    v.set_ligand_from_file('2rma_ligand' + '.pdbqt')

    # Calculate the docking center: mean atom position of the SDF molecule
    # after adding hydrogens and a UFF optimisation.
    mol = Chem.AddHs(Chem.MolFromMolFile(path, sanitize=True), addCoords=True)
    UFFOptimizeMolecule(mol)
    pos = mol.GetConformer(0).GetPositions()
    center = np.mean(pos, 0)

    v.compute_vina_maps(center=center, box_size=[20, 20, 20])

    # Score the current pose
    energy = v.score()
    print('Score before minimization: %.3f (kcal/mol)' % energy[0])

    # Minimized locally the current pose
    energy_minimized = v.optimize()
    print('Score after minimization : %.3f (kcal/mol)' % energy_minimized[0])
    v.write_pose('ligand_minimized.pdbqt', overwrite=True)

    # Dock the ligand
    v.dock(exhaustiveness=64, n_poses=30)
    v.write_poses('out.pdbqt', n_poses=5, overwrite=True)
|
models/common.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.nn.modules.loss import _WeightedLoss
|
6 |
+
from torch_scatter import scatter_mean, scatter_add
|
7 |
+
|
8 |
+
|
9 |
+
def split_tensor_by_batch(x, batch, num_graphs=None):
    """
    Split node-level data into one tensor per graph.

    Args:
        x:      (N, ...), one row per node.
        batch:  (N, ), graph index of every node.
        num_graphs: Number of graphs B. Inferred as ``batch.max() + 1`` when
            omitted, or 0 when ``batch`` is empty.
    Returns:
        [(N_1, ...), (N_2, ...) ..., (N_B, ...)]; a graph index that does
        not occur in ``batch`` yields an empty tensor.
    """
    if num_graphs is None:
        # ``max()`` raises on an empty tensor, so "no nodes" means "no graphs".
        num_graphs = batch.max().item() + 1 if batch.numel() > 0 else 0
    return [x[batch == i] for i in range(num_graphs)]
|
24 |
+
|
25 |
+
|
26 |
+
def concat_tensors_to_batch(x_split):
    """
    Concatenate per-graph tensors and build the matching batch index.

    Args:
        x_split: Sequence of tensors, one per graph, joined along dim 0.
    Returns:
        x:      All tensors concatenated along dim 0.
        batch:  For every row of ``x``, the position in ``x_split`` of the
                tensor it came from (on the same device as ``x``).
    """
    x = torch.cat(x_split, dim=0)
    rows_per_graph = torch.LongTensor([part.size(0) for part in x_split])
    graph_ids = torch.arange(len(x_split))
    batch = torch.repeat_interleave(graph_ids, repeats=rows_per_graph)
    return x, batch.to(device=x.device)
|
33 |
+
|
34 |
+
|
35 |
+
def split_tensor_to_segments(x, segsize):
    """
    Cut ``x`` along dim 0 into consecutive chunks of at most ``segsize`` rows.

    The last chunk is shorter when ``x.size(0)`` is not a multiple of
    ``segsize``.
    """
    count = math.ceil(x.size(0) / segsize)
    return [x[k * segsize : (k + 1) * segsize] for k in range(count)]
|
41 |
+
|
42 |
+
|
43 |
+
def split_tensor_by_lengths(x, lengths):
    """
    Cut ``x`` along dim 0 into consecutive pieces of the given lengths.

    Pieces are taken from the front one after another; rows left over after
    the last length are dropped.
    """
    pieces = []
    remaining = x
    for length in lengths:
        head, remaining = remaining[:length], remaining[length:]
        pieces.append(head)
    return pieces
|
49 |
+
|
50 |
+
|
51 |
+
def batch_intersection_mask(batch, batch_filter):
    """
    Mark the entries of ``batch`` whose value occurs in ``batch_filter``.

    Returns a boolean tensor with one element per entry of ``batch``.
    """
    wanted = batch_filter.unique()
    # Compare every entry against every wanted value: (len(batch), len(wanted)).
    hits = batch.view(-1, 1) == wanted.view(1, -1)
    return hits.any(dim=1)
|
55 |
+
|
56 |
+
|
57 |
+
class MeanReadout(nn.Module):
    """Mean readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Average the node representations of every graph.

        Parameters:
            input (Tensor): node representations
            batch (Tensor): index used to group the rows of ``input``
            num_graphs (int): size of the output along dim 0

        Returns:
            Tensor: graph representations
        """
        return scatter_mean(src=input, index=batch, dim=0, dim_size=num_graphs)
|
71 |
+
|
72 |
+
|
73 |
+
class SumReadout(nn.Module):
    """Sum readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Sum the node representations of every graph.

        Parameters:
            input (Tensor): node representations
            batch (Tensor): index used to group the rows of ``input``
            num_graphs (int): size of the output along dim 0

        Returns:
            Tensor: graph representations
        """
        return scatter_add(src=input, index=batch, dim=0, dim_size=num_graphs)
|
87 |
+
|
88 |
+
|
89 |
+
class MultiLayerPerceptron(nn.Module):
    """
    Multi-layer Perceptron.
    Note there is no activation or dropout in the last layer.

    Parameters:
        input_dim (int): input dimension
        hidden_dims (list of int): output dimension of every layer, in order
        activation (str or function, optional): activation function; a string
            is looked up in ``torch.nn.functional``, a callable is used as
            given, and ``None`` disables the activation
        dropout (float, optional): dropout rate
    """

    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
        super(MultiLayerPerceptron, self).__init__()

        # ``list(...)`` lets callers pass a tuple as well as a list.
        self.dims = [input_dim] + list(hidden_dims)
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            # Keep a caller-supplied callable instead of discarding it.
            self.activation = activation
        self.dropout = nn.Dropout(dropout) if dropout else None

        self.layers = nn.ModuleList()
        for d_in, d_out in zip(self.dims[:-1], self.dims[1:]):
            self.layers.append(nn.Linear(d_in, d_out))

    def forward(self, input):
        """Apply every layer; activation and dropout follow all but the last."""
        x = input
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                # Compare with None: the activation may be an nn.Module whose
                # truth value is not a reliable "is set" check.
                if self.activation is not None:
                    x = self.activation(x)
                if self.dropout is not None:
                    x = self.dropout(x)
        return x
|
128 |
+
|
129 |
+
|
130 |
+
class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy on logits with label smoothing and optional class weights."""

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        # The base class stores both ``weight`` and ``reduction``.
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    @staticmethod
    def _smooth_one_hot(targets:torch.Tensor, n_classes:int, smoothing=0.0):
        """
        Turn class indices into smoothed one-hot rows.

        The target class receives ``1 - smoothing``; every other class
        receives ``smoothing / (n_classes - 1)``.
        """
        assert 0 <= smoothing < 1
        with torch.no_grad():
            off_value = smoothing / (n_classes - 1)
            smoothed = torch.full((targets.size(0), n_classes), off_value,
                                  device=targets.device)
            smoothed.scatter_(1, targets.data.unsqueeze(1), 1. - smoothing)
        return smoothed

    def forward(self, inputs, targets):
        """
        Compute the loss for logits ``inputs`` and class indices ``targets``.

        Returns a scalar for reduction 'sum' or 'mean', otherwise the
        per-sample losses.
        """
        soft_targets = SmoothCrossEntropyLoss._smooth_one_hot(
            targets, inputs.size(-1), self.smoothing)
        log_probs = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_probs = log_probs * self.weight.unsqueeze(0)

        per_sample = -(soft_targets * log_probs).sum(-1)

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample
|
163 |
+
|
164 |
+
|
165 |
+
class GaussianSmearing(nn.Module):
    """Expand scalar distances on a grid of evenly spaced Gaussians."""

    def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        # The Gaussian width is the spacing between neighbouring centers.
        spacing = (centers[1] - centers[0]).item()
        self.coeff = -0.5 / spacing**2
        self.register_buffer('offset', centers)

    def forward(self, dist):
        """Return exp(coeff * (d - center)^2), one row per flattened distance."""
        delta = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * delta.pow(2))
|
175 |
+
|
176 |
+
|
177 |
+
class ShiftedSoftplus(nn.Module):
    """Softplus lowered by log(2), so that the output is zero at x = 0."""

    def __init__(self):
        super().__init__()
        # log(2) evaluated by torch and stored as a plain Python float.
        self.shift = torch.tensor(2.0).log().item()

    def forward(self, x):
        """Return softplus(x) - log(2)."""
        shifted = F.softplus(x) - self.shift
        return shifted
|
184 |
+
|
185 |
+
|
186 |
+
def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """
    Merge protein and ligand nodes into one context, grouped by graph.

    Protein rows are concatenated before ligand rows and everything is then
    reordered with ``argsort`` of the combined batch index, so the order of
    nodes inside one graph follows whatever ``argsort`` returns.

    Args:
        h_protein, h_ligand:         Node features, concatenated along dim 0.
        pos_protein, pos_ligand:     Node positions, concatenated along dim 0.
        batch_protein, batch_ligand: Graph index of every node.
    Returns:
        h_ctx:     (N_protein+N_ligand, H)
        pos_ctx:   (N_protein+N_ligand, 3)
        batch_ctx: (N_protein+N_ligand, ), sorted in ascending order.
    """
    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
    sort_idx = batch_ctx.argsort()

    # The same permutation is applied to all three tensors so rows stay aligned.
    batch_ctx = batch_ctx[sort_idx]
    h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx]         # (N_protein+N_ligand, H)
    pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx]   # (N_protein+N_ligand, 3)

    return h_ctx, pos_ctx, batch_ctx
|
200 |
+
|
201 |
+
|
202 |
+
def get_complete_graph(batch):
    """
    Build all directed edges between distinct nodes of the same graph.

    Node ids are generated as one consecutive range per graph, offset by the
    cumulative node counts.
    NOTE(review): this presumably expects ``batch`` to be sorted -- confirm
    against callers.

    Args:
        batch:  Batch index.
    Returns:
        edge_index: (2, sum_i N_i * (N_i - 1)), where N_i is the number of nodes of the i-th graph.
        neighbors:  (B, ), number of edges per graph.
    """
    nodes_per_graph = scatter_add(torch.ones_like(batch), index=batch, dim=0)
    pairs_per_graph = (nodes_per_graph ** 2).long()
    total_pairs = torch.sum(pairs_per_graph)

    # For every ordered pair: the size of its graph and the id of the first
    # node of its graph.
    size_of_pair = torch.repeat_interleave(nodes_per_graph, pairs_per_graph)
    first_node = torch.cumsum(nodes_per_graph, dim=0) - nodes_per_graph
    first_node_of_pair = torch.repeat_interleave(first_node, pairs_per_graph)

    # Position of every pair inside its own graph's N_i * N_i block.
    first_pair = torch.cumsum(pairs_per_graph, dim=0) - pairs_per_graph
    first_pair_of_pair = torch.repeat_interleave(first_pair, pairs_per_graph)
    local_pair = torch.arange(total_pairs, device=total_pairs.device) - first_pair_of_pair

    # Decode the position into (row, column) and shift to global node ids.
    src = (local_pair // size_of_pair).long() + first_node_of_pair
    dst = (local_pair % size_of_pair).long() + first_node_of_pair
    edge_index = torch.stack([src, dst], dim=0)

    # Drop self-pairs.
    edge_index = edge_index[:, src != dst]

    num_edges = pairs_per_graph - nodes_per_graph   # Number of edges per graph
    return edge_index, num_edges
|
233 |
+
|
234 |
+
|
235 |
+
def compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """
    Merge protein and ligand nodes graph by graph, keeping input order.

    For every graph index up to ``batch_protein.max()``, the protein nodes of
    that graph come first (in their original order), followed by its ligand
    nodes.

    Returns:
        h_ctx, pos_ctx, batch_ctx: merged features, positions and graph index.
        mask_protein: boolean tensor, True where the merged row is a protein node.
    """
    num_graphs = batch_protein.max().item() + 1

    batch_parts, h_parts, pos_parts, mask_parts = [], [], [], []
    for graph_id in range(num_graphs):
        sel_p = batch_protein == graph_id
        sel_l = batch_ligand == graph_id
        batch_p = batch_protein[sel_p]
        batch_l = batch_ligand[sel_l]

        batch_parts.extend((batch_p, batch_l))
        h_parts.extend((h_protein[sel_p], h_ligand[sel_l]))
        pos_parts.extend((pos_protein[sel_p], pos_ligand[sel_l]))
        mask_parts.extend((
            torch.ones([batch_p.size(0)], device=batch_p.device, dtype=torch.bool),
            torch.zeros([batch_l.size(0)], device=batch_l.device, dtype=torch.bool),
        ))

    return (
        torch.cat(h_parts, dim=0),
        torch.cat(pos_parts, dim=0),
        torch.cat(batch_parts, dim=0),
        torch.cat(mask_parts, dim=0),
    )
|
261 |
+
|
262 |
+
if __name__ == '__main__':
    # Self-check for compose_context_stable: three graphs with 10/20/30
    # protein nodes and 11 ligand nodes each.
    h_protein = torch.randn([60, 64])
    h_ligand = -torch.randn([33, 64])
    pos_protein = torch.clamp(torch.randn([60, 3]), 0, float('inf'))
    pos_ligand = torch.clamp(torch.randn([33, 3]), float('-inf'), 0)
    batch_protein = torch.LongTensor([0]*10 + [1]*20 + [2]*30)
    batch_ligand = torch.LongTensor([0]*11 + [1]*11 + [2]*11)

    h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(
        h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand)
    mask_ligand = torch.logical_not(mask_protein)

    # Selecting by mask must give back each input in its original order.
    assert (batch_ctx[mask_protein] == batch_protein).all()
    assert (batch_ctx[mask_ligand] == batch_ligand).all()

    assert torch.allclose(h_ctx[mask_ligand], h_ligand)
    assert torch.allclose(h_ctx[mask_protein], h_protein)

    assert torch.allclose(pos_ctx[mask_ligand], pos_ligand)
    assert torch.allclose(pos_ctx[mask_protein], pos_protein)
|
280 |
+
|
281 |
+
|
282 |
+
|
models/encoders/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .schnet import SchNetEncoder
|
2 |
+
from .tf import TransformerEncoder
|
3 |
+
from .gnn import GNN_graphpred, MLP
|
4 |
+
|
5 |
+
|
6 |
+
def get_encoder(config):
    """
    Build the encoder selected by ``config.name``.

    ``'schnet'`` yields a ``SchNetEncoder`` and ``'tf'`` a
    ``TransformerEncoder``, each configured from attributes of ``config``.

    Raises:
        NotImplementedError: for any other name.
    """
    kind = config.name

    if kind == 'schnet':
        return SchNetEncoder(
            hidden_channels = config.hidden_channels,
            num_filters = config.num_filters,
            num_interactions = config.num_interactions,
            edge_channels = config.edge_channels,
            cutoff = config.cutoff,
        )

    if kind == 'tf':
        return TransformerEncoder(
            hidden_channels = config.hidden_channels,
            edge_channels = config.edge_channels,
            key_channels = config.key_channels,
            num_heads = config.num_heads,
            num_interactions = config.num_interactions,
            k = config.knn,
            cutoff = config.cutoff,
        )

    raise NotImplementedError('Unknown encoder: %s' % config.name)
|
models/encoders/gnn.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch_geometric.nn import MessagePassing
|
4 |
+
from torch_geometric.utils import add_self_loops, degree, softmax
|
5 |
+
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch_scatter import scatter_add
|
8 |
+
from torch_geometric.nn.inits import glorot, zeros
|
9 |
+
|
10 |
+
num_atom_type = 120 #including the extra mask tokens
|
11 |
+
num_chirality_tag = 3
|
12 |
+
|
13 |
+
num_bond_type = 6 #including aromatic and self-loop edge, and extra masked tokens
|
14 |
+
num_bond_direction = 3
|
15 |
+
|
16 |
+
class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by adding
    edge embeddings to the neighbour features.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): aggregation scheme, handed to ``MessagePassing``.

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        # Pass the scheme to the base class rather than only setting an
        # attribute after the base class has been initialised.
        super(GINConv, self).__init__(aggr=aggr)
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        """
        Run one message-passing step.

        ``edge_attr`` is indexed as ``[:, 0]`` (bond type) and ``[:, 1]``
        (bond direction), and both columns are fed to embedding layers.
        """
        #add self loops in the edge space
        # add_self_loops returns an (edge_index, edge_attr) pair; only the
        # index is needed because the self-loop attributes are built below.
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2, dtype=edge_attr.dtype, device=edge_attr.device)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        """Message along an edge: source node feature plus edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """Transform the aggregated messages with the MLP."""
        return self.mlp(aggr_out)
|
57 |
+
|
58 |
+
|
59 |
+
class GCNConv(MessagePassing):
    """
    GCN-style convolution with edge features.

    Node embeddings are linearly transformed, each message is the neighbour
    embedding plus the edge embedding, and messages are scaled by the symmetric
    degree normalisation computed in norm().

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbourhood aggregation scheme forwarded to MessagePassing.
    """

    def __init__(self, emb_dim, aggr = "add"):
        # Forward aggr to the base class instead of assigning self.aggr afterwards:
        # torch_geometric releases that resolve the aggregator inside
        # MessagePassing.__init__ would otherwise ignore the requested scheme.
        super(GCNConv, self).__init__(aggr = aggr)

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def norm(self, edge_index, num_nodes, dtype):
        """Return the per-edge weight deg(row)^-1/2 * deg(col)^-1/2.

        Args:
            edge_index: either the (edge_index, edge_attr) pair returned by
                add_self_loops (the historical calling convention of this
                method) or the bare edge index tensor.
            num_nodes (int): number of nodes, used as the size of the degree vector.
            dtype: dtype of the returned weights.
        """
        ### assuming that self-loops have been already added in edge_index
        if isinstance(edge_index, (tuple, list)):
            edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # nodes with zero degree produce inf; zero them so they contribute nothing
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        """Run one GCN layer.

        Args:
            x: node embeddings; x.size(0) is used as the number of nodes.
            edge_index: edge connectivity handed to add_self_loops / propagate.
            edge_attr: integer edge features; column 0 indexes edge_embedding1
                and column 1 indexes edge_embedding2.
        """
        #add self loops in the edge space; add_self_loops returns
        #(edge_index, edge_attr) and only the index part is needed here
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2, dtype = edge_attr.dtype, device = edge_attr.device)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr, norm):
        """Scale (neighbour embedding + edge embedding) by the edge's norm weight."""
        return norm.view(-1, 1) * (x_j + edge_attr)
|
107 |
+
|
108 |
+
|
109 |
+
class GATConv(MessagePassing):
    """
    Multi-head attention convolution with edge features.

    Node embeddings are projected to `heads` copies of size emb_dim, the edge
    embedding is added to each neighbour projection, attention scores are
    computed from the concatenated (x_i, x_j) pair, and the heads are averaged
    in update().

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        heads (int): number of attention heads.
        negative_slope (float): slope of the LeakyReLU applied to attention logits.
        aggr (str): neighbourhood aggregation scheme forwarded to MessagePassing.
    """
    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr = "add"):
        # Forward aggr to the base class instead of assigning self.aggr afterwards:
        # torch_geometric releases that resolve the aggregator inside
        # MessagePassing.__init__ would otherwise ignore the requested scheme.
        super(GATConv, self).__init__(aggr = aggr)

        self.emb_dim = emb_dim
        self.heads = heads
        self.negative_slope = negative_slope

        self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
        self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))

        self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))

        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        self.reset_parameters()

    def norm(self, edge_index, num_nodes, dtype):
        """Return the per-edge weight deg(row)^-1/2 * deg(col)^-1/2.

        Args:
            edge_index: either the (edge_index, edge_attr) pair returned by
                add_self_loops (the historical calling convention of this
                method) or the bare edge index tensor.
            num_nodes (int): number of nodes, used as the size of the degree vector.
            dtype: dtype of the returned weights.
        """
        ### assuming that self-loops have been already added in edge_index
        if isinstance(edge_index, (tuple, list)):
            edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # nodes with zero degree produce inf; zero them so they contribute nothing
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def reset_parameters(self):
        """Glorot-initialise the attention vector and zero the output bias."""
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_attr):
        """Run one attention layer.

        Args:
            x: node embeddings; x.size(0) is used as the number of nodes.
            edge_index: edge connectivity handed to add_self_loops / propagate.
            edge_attr: integer edge features; column 0 indexes edge_embedding1
                and column 1 indexes edge_embedding2.
        """
        #add self loops in the edge space; add_self_loops returns
        #(edge_index, edge_attr) and only the index part is needed here
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))

        # norm is still handed to propagate() as before, although message()
        # below does not consume it.
        norm = self.norm(edge_index, x.size(0), x.dtype)

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2, dtype = edge_attr.dtype, device = edge_attr.device)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, edge_index, x_i, x_j, edge_attr):
        """Weight each (neighbour + edge) embedding by its attention coefficient."""
        edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
        # out-of-place add: same values as the former `x_j += edge_attr` without
        # mutating the tensor handed in by propagate()
        x_j = x_j + edge_attr

        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        # logits are normalised over edges sharing the same edge_index[0] entry
        alpha = softmax(alpha, edge_index[0])

        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        """Average the heads and add the learned bias."""
        aggr_out = aggr_out.mean(dim=1)
        aggr_out = aggr_out + self.bias

        return aggr_out
|
183 |
+
|
184 |
+
|
185 |
+
class GraphSAGEConv(MessagePassing):
    """
    GraphSAGE-style convolution with edge features.

    Node embeddings are linearly transformed, each message is the neighbour
    embedding plus the edge embedding, and the aggregated result is
    L2-normalised along the feature dimension.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbourhood aggregation scheme forwarded to MessagePassing
            (defaults to "mean").
    """
    def __init__(self, emb_dim, aggr = "mean"):
        # Forward aggr to the base class instead of assigning self.aggr afterwards:
        # torch_geometric releases that resolve the aggregator inside
        # MessagePassing.__init__ would otherwise fall back to the base-class
        # default and silently drop the requested "mean" aggregation.
        super(GraphSAGEConv, self).__init__(aggr = aggr)

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def norm(self, edge_index, num_nodes, dtype):
        """Return the per-edge weight deg(row)^-1/2 * deg(col)^-1/2.

        Args:
            edge_index: either the (edge_index, edge_attr) pair returned by
                add_self_loops (the historical calling convention of this
                method) or the bare edge index tensor.
            num_nodes (int): number of nodes, used as the size of the degree vector.
            dtype: dtype of the returned weights.
        """
        ### assuming that self-loops have been already added in edge_index
        if isinstance(edge_index, (tuple, list)):
            edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # nodes with zero degree produce inf; zero them so they contribute nothing
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        """Run one GraphSAGE layer.

        Args:
            x: node embeddings; x.size(0) is used as the number of nodes.
            edge_index: edge connectivity handed to add_self_loops / propagate.
            edge_attr: integer edge features; column 0 indexes edge_embedding1
                and column 1 indexes edge_embedding2.
        """
        #add self loops in the edge space; add_self_loops returns
        #(edge_index, edge_attr) and only the index part is needed here
        edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2, dtype = edge_attr.dtype, device = edge_attr.device)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        # norm is still handed to propagate() as before, although message()
        # below does not consume it.
        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr):
        """Combine each neighbour embedding with its edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """L2-normalise the aggregated messages along the feature dimension."""
        return F.normalize(aggr_out, p = 2, dim = -1)
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
class GNN(torch.nn.Module):
    """
    Stack of message-passing layers producing node representations.

    Args:
        num_layer (int): the number of GNN layers (must be at least 2)
        emb_dim (int): dimensionality of embeddings
        JK (str): how per-layer outputs are combined: last, concat, max or sum.
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    Raises:
        ValueError: if num_layer < 2, gnn_type is unknown, forward() receives
            an unsupported number of arguments, or JK is unknown.
    """
    def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        ###List of message-passing layers
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr = "add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))
            elif gnn_type == "gat":
                self.gnns.append(GATConv(emb_dim))
            elif gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(emb_dim))
            else:
                # Previously an unknown type left self.gnns empty and only
                # failed later with an IndexError inside forward().
                raise ValueError("Invalid gnn_type: {!r}.".format(gnn_type))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    #def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        """Compute node representations.

        Accepts either (x, edge_index, edge_attr) or a single data object that
        exposes .x, .edge_index and .edge_attr. Column 0 of x indexes
        x_embedding1 and column 1 indexes x_embedding2.
        """
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            # element-wise maximum over the layer axis
            node_representation = torch.stack(h_list, dim = 0).max(dim = 0)[0]
        elif self.JK == "sum":
            # torch.sum returns a tensor (not a (values, indices) pair), so the
            # former trailing [0] kept only the first node's row.
            node_representation = torch.stack(h_list, dim = 0).sum(dim = 0)
        else:
            raise ValueError("Invalid JK mode: {!r}.".format(self.JK))

        return node_representation
|
322 |
+
|
323 |
+
|
324 |
+
class GNN_graphpred(torch.nn.Module):
    """
    Graph-level predictor: a GNN encoder, a graph pooling step and a linear head.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, or "set2set" followed by
            a single digit giving the number of Set2Set processing steps
        gnn_type: gin, gcn, graphsage, gat

    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """
    def __init__(self, num_layer, emb_dim, num_tasks, JK = "last", drop_ratio = 0, graph_pooling = "mean", gnn_type = "gin"):
        super(GNN_graphpred, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type = gnn_type)

        # With JK == "concat" the encoder output is the concatenation of the
        # input embedding and every layer output, i.e. (num_layer + 1) * emb_dim.
        if self.JK == "concat":
            node_dim = (self.num_layer + 1) * emb_dim
        else:
            node_dim = emb_dim

        #Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Linear(node_dim, 1))
        elif graph_pooling[:-1] == "set2set":
            set2set_iter = int(graph_pooling[-1])
            self.pool = Set2Set(node_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        #For graph-level binary classification
        # the head is sized for a pooled vector twice as wide when Set2Set is used
        if graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1

        if self.JK == "concat":
            self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)

    def from_pretrained(self, model_file):
        """Load encoder weights from model_file into self.gnn.

        The checkpoint is mapped to CPU first so that a file saved from a GPU
        process can be loaded on a CPU-only host; load_state_dict then copies
        the values into the encoder's existing parameters.
        """
        #self.gnn = GNN(self.num_layer, self.emb_dim, JK = self.JK, drop_ratio = self.drop_ratio)
        self.gnn.load_state_dict(torch.load(model_file, map_location = "cpu"))

    def forward(self, *argv):
        """Predict graph-level outputs.

        Accepts either (x, edge_index, edge_attr, batch) or a single data object
        exposing .x, .edge_index, .edge_attr and .batch.
        """
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.gnn(x, edge_index, edge_attr)

        return self.graph_pred_linear(self.pool(node_representation, batch))
|
401 |
+
|
402 |
+
|
403 |
+
class MLP(nn.Module):
    """
    Feed-forward network stored as a flat nn.ModuleList.

    The network is `num_layers` hidden stages followed by one output Linear.
    Each hidden stage is Linear -> [LayerNorm] -> [BatchNorm1d] -> activation.

    Inputs:
        in_dim (int): number of input features.
        out_dim (int): number of output features. It also selects the hidden
            width: hidden stages use out_dim when out_dim >= 10 and in_dim otherwise.
        num_layers (int): number of hidden stages placed before the output Linear.
        activation (torch module): activation appended after every hidden stage
            (the same module object is reused for all stages).
        layer_norm (bool): insert nn.LayerNorm after each hidden Linear.
        batch_norm (bool): insert nn.BatchNorm1d after each hidden Linear
            (placed after the LayerNorm when both are enabled).
    """

    def __init__(self, in_dim, out_dim, num_layers, activation=torch.nn.ReLU(), layer_norm=False, batch_norm=False):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()

        hidden = out_dim if out_dim >= 10 else in_dim

        # hidden stages: only the first one reads in_dim features
        fan_in = in_dim
        for _ in range(num_layers):
            stage = [nn.Linear(fan_in, hidden)]
            if layer_norm:
                stage.append(nn.LayerNorm(hidden))
            if batch_norm:
                stage.append(nn.BatchNorm1d(hidden))
            stage.append(activation)
            self.layers.extend(stage)
            fan_in = hidden

        # output projection
        self.layers.append(nn.Linear(hidden, out_dim))

    def forward(self, x):
        """Apply every stored module to x in order and return the result."""
        for module in self.layers:
            x = module(x)
        return x
|
437 |
+
|
438 |
+
|
439 |
+
if __name__ == "__main__":
|
440 |
+
pass
|
441 |
+
|
models/encoders/schnet.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch.nn import Module, Sequential, ModuleList, Linear
|
4 |
+
from torch_geometric.nn import MessagePassing, radius_graph
|
5 |
+
from math import pi as PI
|
6 |
+
|
7 |
+
from ..common import GaussianSmearing, ShiftedSoftplus
|
8 |
+
|
9 |
+
|
10 |
+
class CFConv(MessagePassing):
    """Continuous-filter convolution.

    Edge attributes are turned into per-edge filter weights by a small network;
    each message is the (linearly projected) neighbour feature multiplied
    element-wise by that filter, and messages are summed per node.

    Args:
        in_channels (int): size of the incoming node features.
        out_channels (int): size of the produced node features.
        num_filters (int): width of the filter space the messages live in.
        edge_channels (int): size of the edge attribute vectors.
        cutoff (float or None): distance beyond which filters are zeroed; when
            None no cosine cutoff is applied.
    """

    def __init__(self, in_channels, out_channels, num_filters, edge_channels, cutoff=10.0):
        super().__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = Sequential(
            Linear(edge_channels, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )  # Network for generating filter weights
        self.cutoff = cutoff
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all Linear weights and zero their biases."""
        torch.nn.init.xavier_uniform_(self.nn[0].weight)
        self.nn[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.nn[2].weight)
        # Zero the bias of the *second* filter layer; the original zeroed
        # self.nn[0].bias twice and never touched self.nn[2].bias.
        self.nn[2].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_length, edge_attr):
        """Apply the convolution.

        Args:
            x: node features fed to lin1.
            edge_index: edge connectivity handed to propagate().
            edge_length: per-edge distances used for the cosine cutoff.
            edge_attr: per-edge attributes fed to the filter network.
        """
        W = self.nn(edge_attr)

        if self.cutoff is not None:
            # smooth cosine envelope that reaches zero at the cutoff ...
            C = 0.5 * (torch.cos(edge_length * PI / self.cutoff) + 1.0)
            # ... and is forced to zero outside [0, cutoff]
            C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0)  # Modification: cutoff
            W = W * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j, W):
        """Modulate each neighbour feature by its edge filter."""
        return x_j * W
|
48 |
+
|
49 |
+
|
50 |
+
class InteractionBlock(Module):
    """One interaction step: CFConv, then ShiftedSoftplus, then a Linear layer.

    Args:
        hidden_channels (int): node feature size going in and out of the block.
        num_gaussians (int): passed to CFConv as its edge_channels argument.
        num_filters (int): passed to CFConv as its num_filters argument.
        cutoff: passed to CFConv as its cutoff argument.
    """

    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super(InteractionBlock, self).__init__()
        self.conv = CFConv(
            hidden_channels,
            hidden_channels,
            num_filters,
            num_gaussians,
            cutoff,
        )
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution, then the trailing Linear layer."""
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_length, edge_attr):
        """Return lin(act(conv(x, edge_index, edge_length, edge_attr)))."""
        convolved = self.conv(x, edge_index, edge_length, edge_attr)
        return self.lin(self.act(convolved))
|
69 |
+
|
70 |
+
|
71 |
+
class SchNetEncoder(Module):
    """Encoder made of residual InteractionBlocks over a radius graph.

    Args:
        hidden_channels (int): node feature size (also reported by out_channels).
        num_filters (int): filter width handed to every InteractionBlock.
        num_interactions (int): number of InteractionBlocks.
        edge_channels (int): number of Gaussians used to expand edge lengths.
        cutoff (float): radius used to build the graph and upper end of the
            Gaussian expansion.
    """

    def __init__(self, hidden_channels=128, num_filters=128,
                 num_interactions=6, edge_channels=64, cutoff=10.0):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
        self.cutoff = cutoff

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            self.interactions.append(
                InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff)
            )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every interaction block."""
        for block in self.interactions:
            block.reset_parameters()

    @property
    def out_channels(self):
        """Size of the node features returned by forward()."""
        return self.hidden_channels

    def forward(self, node_attr, pos, batch):
        """Encode nodes.

        Args:
            node_attr: initial node features.
            pos: node positions used to build the radius graph.
            batch: batch assignment handed to radius_graph.

        Returns:
            Node features after all residual interaction blocks.
        """
        edge_index = radius_graph(pos, self.cutoff, batch=batch, loop=False)
        src, dst = edge_index[0], edge_index[1]
        edge_length = torch.norm(pos[src] - pos[dst], dim=1)
        edge_attr = self.distance_expansion(edge_length)

        hidden = node_attr
        for block in self.interactions:
            # residual update
            hidden = hidden + block(hidden, edge_index, edge_length, edge_attr)
        return hidden
|
models/encoders/tf.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import Module, Sequential, ModuleList, Linear, Conv1d, LeakyReLU
|
5 |
+
from torch_geometric.nn import radius_graph, knn_graph
|
6 |
+
from torch_scatter import scatter_sum, scatter_softmax
|
7 |
+
import math
|
8 |
+
from math import pi as PI
|
9 |
+
|
10 |
+
from ..common import GaussianSmearing, ShiftedSoftplus
|
11 |
+
|
12 |
+
|
13 |
+
class AttentionInteractionBlock(Module):
    """Multi-head attention message passing with edge-conditioned keys and values.

    Args:
        hidden_channels (int): node feature size; must be divisible by num_heads.
        edge_channels (int): size of the edge attribute vectors.
        key_channels (int): total key/query size; must be divisible by num_heads.
        num_heads (int): number of attention heads.
    """

    def __init__(self, hidden_channels, edge_channels, key_channels, num_heads=1):
        super().__init__()

        assert hidden_channels % num_heads == 0
        assert key_channels % num_heads == 0

        self.hidden_channels = hidden_channels
        self.key_channels = key_channels
        self.num_heads = num_heads

        key_per_head = key_channels // num_heads
        hid_per_head = hidden_channels // num_heads

        # grouped 1x1 convolutions project node features to keys/queries/values
        self.k_lin = Conv1d(hidden_channels, key_channels, 1, groups=num_heads, bias=False)
        self.q_lin = Conv1d(hidden_channels, key_channels, 1, groups=num_heads, bias=False)
        self.v_lin = Conv1d(hidden_channels, hidden_channels, 1, groups=num_heads, bias=False)

        # edge attributes -> multiplicative weights for the keys
        self.weight_k_net = Sequential(
            Linear(edge_channels, key_per_head),
            LeakyReLU(),
            Linear(key_per_head, key_per_head),
        )
        self.weight_k_lin = Linear(key_per_head, key_per_head)

        # edge attributes -> multiplicative weights for the values
        self.weight_v_net = Sequential(
            Linear(edge_channels, hid_per_head),
            LeakyReLU(),
            Linear(hid_per_head, hid_per_head),
        )
        self.weight_v_lin = Linear(hid_per_head, hid_per_head)

        self.centroid_lin = Linear(hidden_channels, hidden_channels)
        self.act = LeakyReLU()
        self.out_transform = Linear(hidden_channels, hidden_channels)
        self.layernorm_ffn = nn.LayerNorm(hidden_channels)

    def forward(self, x, edge_index, edge_attr):
        """
        Args:
            x:  Node features, (N, hidden_channels).
            edge_index: (2, E).
            edge_attr:  per-edge attributes, (E, edge_channels).
        Returns:
            Updated node features, (N, hidden_channels).
        """
        num_nodes = x.size(0)
        idx_i, idx_j = edge_index   # (E,) , (E,)
        heads = self.num_heads

        # Project to multiple key, query and value spaces
        x_col = x.unsqueeze(-1)
        h_keys = self.k_lin(x_col).view(num_nodes, heads, -1)     # (N, heads, K_per_head)
        h_queries = self.q_lin(x_col).view(num_nodes, heads, -1)  # (N, heads, K_per_head)
        h_values = self.v_lin(x_col).view(num_nodes, heads, -1)   # (N, heads, H_per_head)

        # Keys of the j-side node are modulated by the edge; queries come from the i-side
        key_gate = self.weight_k_net(edge_attr)                                 # (E, K_per_head)
        keys_j = self.weight_k_lin(key_gate.unsqueeze(1) * h_keys[idx_j])       # (E, heads, K_per_head)
        queries_i = h_queries[idx_i]                                            # (E, heads, K_per_head)

        # Attention weights, normalised over all edges sharing the same i-side node
        scale = math.sqrt(self.hidden_channels // self.num_heads)
        logits = (queries_i * keys_j).sum(-1) / scale                           # (E, heads)
        alpha = scatter_softmax(logits, idx_i, dim=0)

        # Edge-modulated values, weighted by attention
        value_gate = self.weight_v_net(edge_attr)                               # (E, H_per_head)
        messages = self.weight_v_lin(value_gate.unsqueeze(1) * h_values[idx_j]) # (E, heads, H_per_head)
        messages = alpha.unsqueeze(-1) * messages                               # (E, heads, H_per_head)

        # Sum messages per i-side node and merge the heads
        pooled = scatter_sum(messages, idx_i, dim=0, dim_size=num_nodes).view(num_nodes, -1)
        out = self.layernorm_ffn(self.centroid_lin(x) + pooled)
        return self.out_transform(self.act(out))
|
84 |
+
|
85 |
+
|
86 |
+
class TransformerEncoder(Module):
    """Encoder made of residual AttentionInteractionBlocks over a k-NN graph.

    Args:
        hidden_channels (int): node feature size (also reported by out_channels).
        edge_channels (int): number of Gaussians used to expand edge lengths.
        key_channels (int): total key/query size of each attention block.
        num_heads (int): attention heads per block.
        num_interactions (int): number of attention blocks.
        k (int): number of neighbours requested from knn_graph.
        cutoff (float): upper end of the Gaussian distance expansion.
    """

    def __init__(self, hidden_channels=256, edge_channels=64, key_channels=128, num_heads=4, num_interactions=6, k=32,
                 cutoff=10.0):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.edge_channels = edge_channels
        self.key_channels = key_channels
        self.num_heads = num_heads
        self.num_interactions = num_interactions
        self.k = k
        self.cutoff = cutoff

        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
        self.interactions = ModuleList()
        for _ in range(num_interactions):
            self.interactions.append(
                AttentionInteractionBlock(
                    hidden_channels=hidden_channels,
                    edge_channels=edge_channels,
                    key_channels=key_channels,
                    num_heads=num_heads,
                )
            )

    @property
    def out_channels(self):
        """Size of the node features returned by forward()."""
        return self.hidden_channels

    def forward(self, node_attr, pos, batch):
        """Encode nodes.

        Args:
            node_attr: initial node features.
            pos: node positions used to build the k-NN graph.
            batch: batch assignment handed to knn_graph.

        Returns:
            Node features after all residual attention blocks.
        """
        # edge_index = radius_graph(pos, self.cutoff, batch=batch, loop=False)
        edge_index = knn_graph(pos, k=self.k, batch=batch, flow='target_to_source')
        first, second = edge_index[0], edge_index[1]
        edge_length = torch.norm(pos[first] - pos[second], dim=1)
        edge_attr = self.distance_expansion(edge_length)

        hidden = node_attr
        for block in self.interactions:
            # residual update
            hidden = hidden + block(hidden, edge_index, edge_attr)
        return hidden
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == '__main__':
    # Smoke test: run the encoder on a small random batch and print the output
    # tensor and its size.
    from torch_geometric.data import Data, Batch

    hidden_channels = 64
    edge_channels = 48
    key_channels = 32
    num_heads = 4

    data_list = []
    for num_nodes in [11, 13, 15]:
        data_list.append(Data(
            x=torch.randn([num_nodes, hidden_channels]),
            pos=torch.randn([num_nodes, 3]) * 2
        ))
    batch = Batch.from_data_list(data_list)

    # The encoder defined in this module is TransformerEncoder; the previous
    # name CFTransformerEncoder is not defined here and raised NameError.
    model = TransformerEncoder(
        hidden_channels=hidden_channels,
        edge_channels=edge_channels,
        key_channels=key_channels,
        num_heads=num_heads,
    )
    out = model(batch.x, batch.pos, batch.batch)

    print(out)
    print(out.size())
|
models/flag.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import Module, Linear, Embedding
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torch_scatter import scatter_add, scatter_mean
|
8 |
+
from torch_geometric.data import Data, Batch
|
9 |
+
from copy import deepcopy
|
10 |
+
|
11 |
+
from .encoders import get_encoder, GNN_graphpred, MLP
|
12 |
+
from .common import *
|
13 |
+
from utils import dihedral_utils, chemutils
|
14 |
+
|
15 |
+
|
16 |
+
class FLAG(Module):
|
17 |
+
|
18 |
+
def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim, vocab):
|
19 |
+
super().__init__()
|
20 |
+
self.config = config
|
21 |
+
self.vocab = vocab
|
22 |
+
self.protein_atom_emb = Linear(protein_atom_feature_dim, config.hidden_channels)
|
23 |
+
self.ligand_atom_emb = Linear(ligand_atom_feature_dim, config.hidden_channels)
|
24 |
+
self.embedding = nn.Embedding(vocab.size() + 1, config.hidden_channels)
|
25 |
+
self.W = nn.Linear(2 * config.hidden_channels, config.hidden_channels)
|
26 |
+
self.W_o = nn.Linear(config.hidden_channels, self.vocab.size())
|
27 |
+
self.encoder = get_encoder(config.encoder)
|
28 |
+
self.comb_head = GNN_graphpred(num_layer=3, emb_dim=config.hidden_channels, num_tasks=1, JK='last',
|
29 |
+
drop_ratio=0.5, graph_pooling='mean', gnn_type='gin')
|
30 |
+
if config.random_alpha:
|
31 |
+
self.alpha_mlp = MLP(in_dim=config.hidden_channels * 4, out_dim=1, num_layers=2)
|
32 |
+
else:
|
33 |
+
self.alpha_mlp = MLP(in_dim=config.hidden_channels * 3, out_dim=1, num_layers=2)
|
34 |
+
self.focal_mlp_ligand = MLP(in_dim=config.hidden_channels, out_dim=1, num_layers=1)
|
35 |
+
self.focal_mlp_protein = MLP(in_dim=config.hidden_channels, out_dim=1, num_layers=1)
|
36 |
+
self.dist_mlp = MLP(in_dim=protein_atom_feature_dim + ligand_atom_feature_dim, out_dim=1, num_layers=2)
|
37 |
+
if config.refinement:
|
38 |
+
self.refine_protein = MLP(in_dim=config.hidden_channels * 2 + config.encoder.edge_channels, out_dim=1, num_layers=2)
|
39 |
+
self.refine_ligand = MLP(in_dim=config.hidden_channels * 2 + config.encoder.edge_channels, out_dim=1, num_layers=2)
|
40 |
+
|
41 |
+
self.smooth_cross_entropy = SmoothCrossEntropyLoss(reduction='mean', smoothing=0.1)
|
42 |
+
self.pred_loss = nn.CrossEntropyLoss()
|
43 |
+
self.comb_loss = nn.BCEWithLogitsLoss()
|
44 |
+
self.three_hop_loss = torch.nn.MSELoss()
|
45 |
+
self.focal_loss = nn.BCEWithLogitsLoss()
|
46 |
+
self.dist_loss = torch.nn.MSELoss(reduction='mean')
|
47 |
+
|
48 |
+
def forward(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein, batch_ligand):
    """Encode the protein/ligand context and score every atom as a focal candidate.

    :return: ``(focal_pred, protein_mask, h_ctx)``. ``focal_pred`` holds the
        protein-atom logits first and the ligand-atom logits after them;
        ``protein_mask`` and ``h_ctx`` are the composed-context mask and the
        encoder output.
    """
    protein_emb = self.protein_atom_emb(protein_atom_feature)
    ligand_emb = self.ligand_atom_emb(ligand_atom_feature)

    h_ctx, pos_ctx, batch_ctx, protein_mask = compose_context_stable(
        h_protein=protein_emb,
        h_ligand=ligand_emb,
        pos_protein=protein_pos,
        pos_ligand=ligand_pos,
        batch_protein=batch_protein,
        batch_ligand=batch_ligand,
    )
    h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)  # (N_p+N_l, H)

    # Each atom type has its own focal head; protein logits are stacked first.
    protein_logits = self.focal_mlp_protein(h_ctx[protein_mask])
    ligand_logits = self.focal_mlp_ligand(h_ctx[~protein_mask])
    focal_pred = torch.cat([protein_logits, ligand_logits], dim=0)

    return focal_pred, protein_mask, h_ctx
|
60 |
+
|
61 |
+
def forward_motif(self, h_ctx_focal, current_wid, current_atoms_batch, n_samples=1):
    """Sample next-motif ids for each partial molecule.

    The hidden states of the focal atoms are sum-pooled per molecule,
    concatenated with the embedding of the current motif and turned into a
    probability distribution over the vocabulary. For every molecule,
    ``n_samples`` ids are then drawn uniformly (with replacement) from its
    ``5 * n_samples`` most probable motifs.

    :param h_ctx_focal: hidden states of the focal atoms.
    :param current_wid: id of the current motif of each molecule.
    :param current_atoms_batch: molecule index of each row of ``h_ctx_focal``.
    :param n_samples: number of motif ids to draw per molecule.
    :return: ``(preds, prob)``, both flattened molecule-major: the sampled
        motif ids and their softmax probabilities.
    """
    node_hiddens = scatter_add(h_ctx_focal, dim=0, index=current_atoms_batch)
    motif_hiddens = self.embedding(current_wid)
    pred_vecs = torch.cat([node_hiddens, motif_hiddens], dim=1)
    pred_vecs = F.relu(self.W(pred_vecs))
    pred_scores = F.softmax(self.W_o(pred_vecs), dim=-1)

    # Candidate pool: top-k motifs per molecule. k is capped at the number of
    # classes because torch.topk raises when asked for more entries than exist.
    k = min(5 * n_samples, pred_scores.shape[1])
    select_pool = torch.topk(pred_scores, k, dim=1)[1]
    # The positions are drawn on the CPU generator, then moved to the pool's device.
    index = torch.randint(k, (select_pool.shape[0], n_samples))
    # Row-wise lookup, flattened so that samples of the same molecule are adjacent.
    preds = torch.gather(select_pool, 1, index.to(select_pool.device)).reshape(-1)

    idx_parent = torch.repeat_interleave(torch.arange(pred_scores.shape[0]), n_samples, dim=0).to(pred_scores.device)
    prob = pred_scores[idx_parent, preds]
    return preds, prob
|
78 |
+
|
79 |
+
def forward_attach(self, mol_list, next_motif_smiles, device):
    """Attach the next motif to each molecule using the best-scoring candidate.

    ``chemutils.assemble`` enumerates candidate attachments; ``comb_head``
    scores every candidate graph and the arg-max candidate of each group
    (groups are given by ``cand_batch``) is kept.

    :return: ``(select_mols, new_atoms, one_atom_attach, intersection,
        attach_fail)`` where the first four lists are restricted to the
        selected candidates and ``attach_fail`` is passed through unchanged.
    """
    cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail = chemutils.assemble(mol_list, next_motif_smiles)

    graphs = [chemutils.mol_to_graph_data_obj_simple(mol) for mol in cand_mols]
    graph_data = Batch.from_data_list(graphs).to(device)
    comb_pred = self.comb_head(graph_data.x, graph_data.edge_index, graph_data.edge_attr, graph_data.batch).reshape(-1)

    # Cumulative group sizes give the [start, end) slice of every group.
    bounds = torch.cat([torch.tensor([0]), torch.cumsum(cand_batch.bincount(), dim=0)], dim=0)
    select = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        best = torch.argmax(comb_pred[start:end]) + start
        select.append(best.item())
    # A disabled alternative in the original sampled the candidate with
    # torch.multinomial over exp(comb_pred) instead of taking the arg-max.

    select_mols = [cand_mols[i] for i in select]
    new_atoms = [new_atoms[i] for i in select]
    one_atom_attach = [one_atom_attach[i] for i in select]
    intersection = [intersection[i] for i in select]
    return select_mols, new_atoms, one_atom_attach, intersection, attach_fail
|
97 |
+
|
98 |
+
def forward_alpha(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein,
                  batch_ligand, xy_index, rotatable):
    """Predict the torsion parameter alpha for each selected bond.

    The context is re-encoded, then the hidden states of the two bond atoms
    (``xy_index[:, 0]`` and ``xy_index[:, 1]``, indices into the ligand atoms)
    are concatenated with the sum-pooled ligand representation of the
    molecules selected by ``rotatable`` and fed to ``alpha_mlp``. With
    ``config.random_alpha`` a standard-normal vector is appended as well.

    :return: output of ``alpha_mlp``, one value per row.
    """
    # encode again
    protein_emb = self.protein_atom_emb(protein_atom_feature)
    ligand_emb = self.ligand_atom_emb(ligand_atom_feature)
    h_ctx, pos_ctx, batch_ctx, protein_mask = compose_context_stable(
        h_protein=protein_emb,
        h_ligand=ligand_emb,
        pos_protein=protein_pos,
        pos_ligand=ligand_pos,
        batch_protein=batch_protein,
        batch_ligand=batch_ligand,
    )
    h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)  # (N_p+N_l, H)

    ligand_hidden = h_ctx[~protein_mask]
    hx = ligand_hidden[xy_index[:, 0]]
    hy = ligand_hidden[xy_index[:, 1]]
    h_mol = scatter_add(ligand_hidden, dim=0, index=batch_ligand)[rotatable]

    features = [hx, hy, h_mol]
    if self.config.random_alpha:
        noise = torch.distributions.normal.Normal(loc=0, scale=1).sample(hx.shape).to(hx.device)
        features.append(noise)
    return self.alpha_mlp(torch.cat(features, dim=-1))
|
120 |
+
|
121 |
+
def get_loss(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, ligand_pos_torsion,
             ligand_atom_feature_torsion, batch_protein, batch_ligand, batch_ligand_torsion, batch):
    """Compute the total training loss and its per-task breakdown.

    Six terms are summed: next-motif prediction, attachment scoring, focal
    atom prediction, distance-matrix prediction, torsion prediction and
    structure refinement. A term whose inputs are empty (or, for refinement,
    disabled by ``config.refinement``) contributes 0.

    Side effect: stores ``protein_pos.device`` in ``self.device``.

    :param batch: dict-like batch providing the supervision tensors read
        below (``current_atoms``, ``next_wid``, ``cand_mols``, ``true_dm``,
        ``y_pos``, ...).
    :return: ``(loss, loss_list)`` where ``loss`` is the differentiable sum
        and ``loss_list`` holds the Python floats
        ``[motif, attachment, focal, distance, torsion, refinement]``.
    """
    self.device = protein_pos.device
    h_protein = self.protein_atom_emb(protein_atom_feature)
    h_ligand = self.ligand_atom_emb(ligand_atom_feature)

    loss_list = [0, 0, 0, 0, 0, 0]

    # Encode for motif prediction
    h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(h_protein=h_protein, h_ligand=h_ligand,
                                                                     pos_protein=protein_pos, pos_ligand=ligand_pos,
                                                                     batch_protein=batch_protein,
                                                                     batch_ligand=batch_ligand)
    h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)  # (N_p+N_l, H)
    h_ctx_ligand = h_ctx[~mask_protein]
    h_ctx_protein = h_ctx[mask_protein]
    h_ctx_focal = h_ctx[batch['current_atoms']]

    # Encode for torsion prediction (separate ligand context)
    if len(batch['y_pos']) > 0:
        h_ligand_torsion = self.ligand_atom_emb(ligand_atom_feature_torsion)
        h_ctx_torsion, pos_ctx_torsion, batch_ctx_torsion, mask_protein_torsion = compose_context_stable(
            h_protein=h_protein,
            h_ligand=h_ligand_torsion,
            pos_protein=protein_pos,
            pos_ligand=ligand_pos_torsion,
            batch_protein=batch_protein,
            batch_ligand=batch_ligand_torsion)
        h_ctx_torsion = self.encoder(node_attr=h_ctx_torsion, pos=pos_ctx_torsion, batch=batch_ctx_torsion)  # (N_p+N_l, H)
        h_ctx_ligand_torsion = h_ctx_torsion[~mask_protein_torsion]

    # next motif prediction
    node_hiddens = scatter_add(h_ctx_focal, dim=0, index=batch['current_atoms_batch'])
    motif_hiddens = self.embedding(batch['current_wid'])
    pred_vecs = torch.cat([node_hiddens, motif_hiddens], dim=1)
    pred_vecs = nn.ReLU()(self.W(pred_vecs))
    pred_scores = self.W_o(pred_vecs)
    pred_loss = self.pred_loss(pred_scores, batch['next_wid'])
    loss_list[0] = pred_loss.item()

    # attachment prediction
    if len(batch['cand_labels']) > 0:
        cand_mols = batch['cand_mols']
        comb_pred = self.comb_head(cand_mols.x, cand_mols.edge_index, cand_mols.edge_attr, cand_mols.batch)
        comb_loss = self.comb_loss(comb_pred, batch['cand_labels'].view(comb_pred.shape).float())
        loss_list[1] = comb_loss.item()
    else:
        comb_loss = 0

    # focal prediction
    focal_ligand_pred, focal_protein_pred = self.focal_mlp_ligand(h_ctx_ligand), self.focal_mlp_protein(h_ctx_protein)
    focal_loss = self.focal_loss(focal_ligand_pred.reshape(-1), batch['ligand_frontier'].float()) +\
                 self.focal_loss(focal_protein_pred.reshape(-1), batch['protein_contact'].float())
    loss_list[2] = focal_loss.item()

    # distance matrix prediction (from raw atom features of the paired atoms)
    if len(batch['true_dm']) > 0:
        dm_input = torch.cat([protein_atom_feature[batch['dm_protein_idx']], ligand_atom_feature[batch['dm_ligand_idx']]], dim=-1)
        pred_dist = self.dist_mlp(dm_input)
        dm_target = batch['true_dm'].unsqueeze(-1)
        dm_loss = self.dist_loss(pred_dist, dm_target)
        loss_list[3] = dm_loss.item()
    else:
        dm_loss = 0

    # structure refinement loss
    if self.config.refinement and len(batch['true_dm']) > 0:
        true_distance_alpha = torch.norm(batch['ligand_context_pos'][batch['sr_ligand_idx']] - batch['protein_pos'][batch['sr_protein_idx']], dim=1)
        true_distance_intra = torch.norm(batch['ligand_context_pos'][batch['sr_ligand_idx0']] - batch['ligand_context_pos'][batch['sr_ligand_idx1']], dim=1)
        input_distance_alpha = ligand_pos[batch['sr_ligand_idx']] - protein_pos[batch['sr_protein_idx']]
        input_distance_intra = ligand_pos[batch['sr_ligand_idx0']] - ligand_pos[batch['sr_ligand_idx1']]
        distance_emb1 = self.encoder.distance_expansion(torch.norm(input_distance_alpha, dim=1))
        distance_emb2 = self.encoder.distance_expansion(torch.norm(input_distance_intra, dim=1))
        # distance cut-off: only pairs whose TRUE distance is within 10.0 exert a force
        near_alpha = true_distance_alpha <= 10.0
        near_intra = true_distance_intra <= 10.0
        input1 = torch.cat([h_ctx_ligand[batch['sr_ligand_idx']], h_ctx_protein[batch['sr_protein_idx']], distance_emb1], dim=-1)[near_alpha]
        input2 = torch.cat([h_ctx_ligand[batch['sr_ligand_idx0']], h_ctx_ligand[batch['sr_ligand_idx1']], distance_emb2], dim=-1)[near_intra]
        norm_dir1 = F.normalize(input_distance_alpha, p=2, dim=1)[near_alpha]
        norm_dir2 = F.normalize(input_distance_intra, p=2, dim=1)[near_intra]
        # Predicted scalar magnitude times unit direction, averaged per ligand atom.
        force1 = scatter_mean(self.refine_protein(input1)*norm_dir1, dim=0, index=batch['sr_ligand_idx'][near_alpha], dim_size=ligand_pos.size(0))
        force2 = scatter_mean(self.refine_ligand(input2)*norm_dir2, dim=0, index=batch['sr_ligand_idx0'][near_intra], dim_size=ligand_pos.size(0))
        # Work on a copy so the caller's ligand_pos is left untouched.
        new_ligand_pos = deepcopy(ligand_pos)
        new_ligand_pos += force1
        new_ligand_pos += force2
        refine_dist1 = torch.norm(new_ligand_pos[batch['sr_ligand_idx']] - protein_pos[batch['sr_protein_idx']], dim=1)
        refine_dist2 = torch.norm(new_ligand_pos[batch['sr_ligand_idx0']] - new_ligand_pos[batch['sr_ligand_idx1']], dim=1)
        sr_loss = (self.dist_loss(refine_dist1, true_distance_alpha) + self.dist_loss(refine_dist2, true_distance_intra))
        loss_list[5] = sr_loss.item()
    else:
        sr_loss = 0

    # torsion prediction
    if len(batch['y_pos']) > 0:
        Hx = dihedral_utils.rotation_matrix_v2(batch['y_pos'])
        xn_pos = torch.matmul(Hx, batch['xn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
        yn_pos = torch.matmul(Hx, batch['yn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
        y_pos = torch.matmul(Hx, batch['y_pos'].unsqueeze(1).permute(0, 2, 1)).squeeze(-1)

        hx, hy = h_ctx_ligand_torsion[batch['ligand_torsion_xy_index'][:, 0]], h_ctx_ligand_torsion[batch['ligand_torsion_xy_index'][:, 1]]
        h_mol = scatter_add(h_ctx_ligand_torsion, dim=0, index=batch['ligand_element_torsion_batch'])
        if self.config.random_alpha:
            rand_dist = torch.distributions.normal.Normal(loc=0, scale=1)
            rand_alpha = rand_dist.sample(hx.shape).to(self.device)
            alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol, rand_alpha], dim=-1))
        else:
            alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol], dim=-1))
        # rotate xn
        R_alpha = self.build_alpha_rotation(torch.sin(alpha).squeeze(-1), torch.cos(alpha).squeeze(-1))
        xn_pos = torch.matmul(R_alpha, xn_pos.permute(0, 2, 1)).permute(0, 2, 1)

        # All 3 x 3 combinations of x-neighbours and y-neighbours.
        p_idx, q_idx = torch.cartesian_prod(torch.arange(3), torch.arange(3)).chunk(2, dim=-1)
        p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
        pred_sin, pred_cos = dihedral_utils.batch_dihedrals(xn_pos[:, p_idx],
                                                            torch.zeros_like(y_pos).unsqueeze(1).repeat(1, 9, 1),
                                                            y_pos.unsqueeze(1).repeat(1, 9, 1),
                                                            yn_pos[:, q_idx])
        # The sine target must be paired with the predicted sine; the previous
        # code passed pred_cos twice, leaving pred_sin unused.
        # NOTE(review): assumes the argument order (true_cos, pred_cos,
        # true_sin, pred_sin) - confirm against utils/dihedral_utils.
        dihedral_loss = torch.mean(dihedral_utils.von_Mises_loss(batch['true_cos'], pred_cos.reshape(-1), batch['true_sin'], pred_sin.reshape(-1))[batch['dihedral_mask']])
        torsion_loss = -dihedral_loss
        loss_list[4] = torsion_loss.item()
    else:
        torsion_loss = 0

    # dm: distance matrix
    loss = pred_loss + comb_loss + focal_loss + dm_loss + torsion_loss + sr_loss

    return loss, loss_list
|
246 |
+
|
247 |
+
def build_alpha_rotation(self, alpha, alpha_cos=None):
    """
    Builds the alpha rotation matrix (a rotation about the x axis)

    :param alpha: predicted values of torsion parameter alpha (n_dihedral_pairs).
        When ``alpha_cos`` is a tensor, this argument is interpreted as
        sin(alpha) instead of the angle itself.
    :param alpha_cos: optional cos(alpha) values (n_dihedral_pairs)
    :return: alpha rotation matrix (n_dihedral_pairs, 3, 3), float32, on the
        same device as ``alpha``
    """
    if torch.is_tensor(alpha_cos):
        sin_alpha, cos_alpha = alpha, alpha_cos
    else:
        sin_alpha, cos_alpha = torch.sin(alpha), torch.cos(alpha)

    # Take the device from the input rather than from self.device, which is
    # only assigned inside get_loss and may not exist yet.
    H_alpha = torch.zeros(alpha.shape[0], 3, 3, dtype=torch.float32, device=alpha.device)
    H_alpha[:, 0, 0] = 1
    H_alpha[:, 1, 1] = cos_alpha
    H_alpha[:, 1, 2] = -sin_alpha
    H_alpha[:, 2, 1] = sin_alpha
    H_alpha[:, 2, 2] = cos_alpha

    return H_alpha
|
268 |
+
|
motif_sample.py
ADDED
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import math
|
8 |
+
from vina import Vina
|
9 |
+
from openbabel import pybel
|
10 |
+
import subprocess
|
11 |
+
import multiprocessing as mp
|
12 |
+
from functools import partial
|
13 |
+
from torch_geometric.data import Batch
|
14 |
+
from tqdm.auto import tqdm
|
15 |
+
from rdkit import Chem
|
16 |
+
from rdkit.Geometry import Point3D
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from rdkit.Chem.rdchem import BondType
|
19 |
+
from rdkit.Chem import ChemicalFeatures, rdMolDescriptors
|
20 |
+
from rdkit import RDConfig
|
21 |
+
from rdkit.Chem.Descriptors import MolLogP, qed
|
22 |
+
from copy import deepcopy
|
23 |
+
import tempfile
|
24 |
+
import AutoDockTools
|
25 |
+
import contextlib
|
26 |
+
from torch_scatter import scatter_add, scatter_mean
|
27 |
+
from rdkit.Geometry import Point3D
|
28 |
+
from meeko import MoleculePreparation
|
29 |
+
from meeko import obutils
|
30 |
+
from models.flag import FLAG
|
31 |
+
from utils.transforms import *
|
32 |
+
from utils.datasets import get_dataset
|
33 |
+
from utils.misc import *
|
34 |
+
from utils.data import *
|
35 |
+
from utils.mol_tree import *
|
36 |
+
from utils.chemutils import *
|
37 |
+
from utils.dihedral_utils import *
|
38 |
+
from utils.sascorer import compute_sa_score
|
39 |
+
from rdkit.Chem import AllChem
|
40 |
+
|
41 |
+
# NOTE(review): not referenced by the code visible here; presumably a lazily
# filled score cache - confirm before removing.
_fscores = None

# Feature family names used as the columns of the per-atom feature matrix in get_feat().
ATOM_FAMILIES = ['Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe', 'NegIonizable', 'PosIonizable',
                 'ZnBinder']
# Family name -> column index in that feature matrix.
ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)}

# Status labels.
STATUS_RUNNING = 'running'
STATUS_FINISHED = 'finished'
STATUS_FAILED = 'failed'
|
50 |
+
|
51 |
+
|
52 |
+
def supress_stdout(func):
    """Decorator that silences everything ``func`` writes to ``sys.stdout``.

    While the wrapped function runs, stdout is redirected to ``os.devnull``;
    the return value and any raised exception pass through unchanged. The
    wrapper keeps the name, docstring and other metadata of ``func``.
    """
    # Local import: the module only imports `partial` from functools.
    from functools import wraps

    @wraps(func)
    def wrapper(*a, **ka):
        with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
            return func(*a, **ka)

    return wrapper
|
58 |
+
|
59 |
+
|
60 |
+
class PrepLig(object):
    """Ligand preparation helper built around an Open Babel (pybel) molecule."""

    def __init__(self, input_mol, mol_format):
        """Load the ligand from a SMILES string ('smi') or an SDF file path ('sdf').

        :raises ValueError: for any other ``mol_format``.
        """
        if mol_format not in ('smi', 'sdf'):
            raise ValueError(f'mol_format {mol_format} not supported')
        if mol_format == 'smi':
            self.ob_mol = pybel.readstring('smi', input_mol)
        else:
            # Only the first record of the file is used.
            self.ob_mol = next(pybel.readfile(mol_format, input_mol))

    def addH(self, polaronly=False, correctforph=True, PH=7):
        """Add hydrogens in place and dump the result to 'tmp_h.sdf'."""
        ob_molecule = self.ob_mol.OBMol
        ob_molecule.AddHydrogens(polaronly, correctforph, PH)
        obutils.writeMolecule(self.ob_mol.OBMol, 'tmp_h.sdf')

    def gen_conf(self):
        """Embed a 3D conformer with RDKit (ETKDGv3) and dump it to 'conf_h.sdf'."""
        rdkit_mol = Chem.MolFromMolBlock(self.ob_mol.write('sdf'), removeHs=False)
        AllChem.EmbedMolecule(rdkit_mol, Chem.rdDistGeom.ETKDGv3())
        embedded_block = Chem.MolToMolBlock(rdkit_mol)
        self.ob_mol = pybel.readstring('sdf', embedded_block)
        obutils.writeMolecule(self.ob_mol.OBMol, 'conf_h.sdf')

    @supress_stdout
    def get_pdbqt(self, lig_pdbqt=None):
        """Convert the ligand to PDBQT.

        Writes to ``lig_pdbqt`` and returns None when a path is given;
        otherwise returns the PDBQT content as a string.
        """
        preparator = MoleculePreparation()
        preparator.prepare(self.ob_mol.OBMol)
        if lig_pdbqt is None:
            return preparator.write_pdbqt_string()
        preparator.write_pdbqt_file(lig_pdbqt)
        return None
|
89 |
+
|
90 |
+
|
91 |
+
class PrepProt(object):
    """Receptor preparation helper: strip waters, protonate, convert to PDBQT."""

    def __init__(self, pdb_file):
        # Path of the structure file currently considered "the protein".
        self.prot = pdb_file

    def del_water(self, dry_pdb_file):  # optional
        """Write the ATOM/HETATM records that do not mention 'HOH' to ``dry_pdb_file``.

        Afterwards ``self.prot`` points at the new file.
        """
        with open(self.prot) as src:
            kept = [
                line for line in src
                if line.startswith(('ATOM', 'HETATM')) and 'HOH' not in line
            ]

        with open(dry_pdb_file, 'w') as dst:
            dst.write(''.join(kept))
        self.prot = dry_pdb_file

    def addH(self, prot_pqr):  # call pdb2pqr
        """Run ``pdb2pqr30`` (AMBER force field) on the protein, writing ``prot_pqr``.

        The tool's output is discarded and its exit status is not checked.
        """
        self.prot_pqr = prot_pqr
        command = ['pdb2pqr30', '--ff=AMBER', self.prot, self.prot_pqr]
        process = subprocess.Popen(command, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
        process.communicate()

    def get_pdbqt(self, prot_pdbqt):
        """Run AutoDockTools' ``prepare_receptor4.py`` on the PQR file from ``addH``.

        The script's output is discarded and its exit status is not checked.
        """
        script = os.path.join(AutoDockTools.__path__[0], 'Utilities24/prepare_receptor4.py')
        command = ['python3', script, '-r', self.prot_pqr, '-o', prot_pdbqt]
        process = subprocess.Popen(command, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
        process.communicate()
|
113 |
+
|
114 |
+
|
115 |
+
def calculate_vina(number, pro_path, lig_path):
    """Dock ligand ``<lig_path>/<number>.sdf`` against ``pro_path`` with AutoDock Vina.

    Intermediate PDBQT/PQR files are written under ``./data/tmp/``. The search
    box is centred on the mean ligand position and sized from the ligand's
    extent (scaled by 1.2 plus a 5.0 margin). Scores before/after local
    minimisation and after docking are printed.

    :return: energy of the best docked pose, as reported by ``v.energies``.
    """
    size_factor = 1.2
    buffer = 5.

    lig_path = os.path.join(lig_path, str(number) + '.sdf')
    mol = Chem.MolFromMolFile(lig_path, sanitize=True)
    pos = mol.GetConformer(0).GetPositions()
    center = np.mean(pos, 0)

    tmp_prefix = './data/tmp/' + str(number)
    ligand_pdbqt = tmp_prefix + '_lig.pdbqt'
    protein_pqr = tmp_prefix + '_pro.pqr'
    protein_pdbqt = tmp_prefix + '_pro.pdbqt'

    # Ligand: add hydrogens, then export to PDBQT.
    ligand = PrepLig(lig_path, 'sdf')
    ligand.addH()
    ligand.get_pdbqt(ligand_pdbqt)

    # Receptor: protonate via pdb2pqr, then export to PDBQT.
    receptor = PrepProt(pro_path)
    receptor.addH(protein_pqr)
    receptor.get_pdbqt(protein_pdbqt)

    v = Vina(sf_name='vina', seed=0, verbosity=0)
    v.set_receptor(protein_pdbqt)
    v.set_ligand_from_file(ligand_pdbqt)
    box_x, box_y, box_z = (pos.max(0) - pos.min(0)) * size_factor + buffer
    v.compute_vina_maps(center=center, box_size=[box_x, box_y, box_z])

    energy = v.score()
    print('Score before minimization: %.3f (kcal/mol)' % energy[0])
    energy_minimized = v.optimize()
    print('Score after minimization : %.3f (kcal/mol)' % energy_minimized[0])

    v.dock(exhaustiveness=64, n_poses=32)
    score = v.energies(n_poses=1)[0][0]
    print('Score after docking : %.3f (kcal/mol)' % score)

    return score
|
149 |
+
|
150 |
+
|
151 |
+
def get_feat(mol):
    """Build the per-atom feature matrix of an RDKit molecule.

    Columns are a one-hot encoding over the elements C, N, O, F, P, S, Cl
    followed by one 0/1 flag per entry of ``ATOM_FAMILIES``. The molecule is
    sanitised in place.

    :return: float tensor with one row per atom.
    """
    fdef_path = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    factory = ChemicalFeatures.BuildFeatureFactory(fdef_path)
    atomic_numbers = torch.LongTensor([6, 7, 8, 9, 15, 16, 17])  # C N O F P S Cl
    ptable = Chem.GetPeriodicTable()
    Chem.SanitizeMol(mol)

    # Flag every atom taking part in a detected chemical feature.
    family_flags = np.zeros([mol.GetNumAtoms(), len(ATOM_FAMILIES)], dtype=np.int_)
    for feature in factory.GetFeaturesForMol(mol):
        column = ATOM_FAMILIES_ID[feature.GetFamily()]
        family_flags[feature.GetAtomIds(), column] = 1

    ligand_element = torch.tensor([ptable.GetAtomicNumber(atom.GetSymbol()) for atom in mol.GetAtoms()])
    one_hot = ligand_element.view(-1, 1) == atomic_numbers.view(1, -1)  # (N_atoms, N_elements)
    combined = torch.cat([one_hot, torch.tensor(family_flags)], dim=-1)
    return combined.float()
|
163 |
+
|
164 |
+
|
165 |
+
def find_reference(protein_pos, focal_id):
    """Pick the four protein atoms closest to the focal atom.

    The focal atom is at distance zero from itself, so it is normally one of
    the four.

    :param protein_pos: atom coordinates, one row per atom.
    :param focal_id: row index of the focal atom.
    :return: ``(reference_pos, reference_idx)``, ordered by increasing distance.
    """
    offsets = protein_pos - protein_pos[focal_id]
    distances = offsets.norm(dim=1)
    reference_idx = torch.topk(distances, k=4, largest=False).indices
    return protein_pos[reference_idx], reference_idx
|
171 |
+
|
172 |
+
|
173 |
+
def SetAtomNum(mol, atoms):
    """Mark the atoms whose index is in ``atoms`` with atom-map number 1, all others with 0.

    The molecule is modified in place and also returned.
    """
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(1 if atom.GetIdx() in atoms else 0)
    return mol
|
180 |
+
|
181 |
+
|
182 |
+
def SetMolPos(mol_list, pos_list):
    """Copy predicted coordinates into each molecule's conformer, then UFF-relax it.

    For every entry of ``pos_list`` the first conformer of the matching
    molecule is overwritten when the atom counts agree; otherwise the
    coordinates are left as they are. The UFF optimisation is best-effort: a
    molecule whose optimisation fails is still returned.

    :param mol_list: RDKit molecules, each with at least one conformer.
    :param pos_list: one coordinate tensor per molecule.
    :return: list with one molecule per entry of ``pos_list`` (the same
        objects as in ``mol_list``, modified in place).
    """
    new_mol_list = []
    for i, pos_tensor in enumerate(pos_list):
        mol = mol_list[i]
        conf = mol.GetConformer(0)
        pos = pos_tensor.cpu().double().numpy()
        if mol.GetNumAtoms() == len(pos):
            for node, (x, y, z) in enumerate(pos):
                conf.SetAtomPosition(node, Point3D(x, y, z))
        try:
            AllChem.UFFOptimizeMolecule(mol)
        except Exception:
            # Deliberately best-effort: keep the unoptimised molecule. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            pass
        new_mol_list.append(mol)
    return new_mol_list
|
198 |
+
|
199 |
+
|
200 |
+
def lipinski(mol):
    """Count how many of five Lipinski-style drug-likeness criteria ``mol`` satisfies.

    Criteria: logP <= 5, H-bond donors <= 5, H-bond acceptors <= 10,
    exact molecular weight <= 500 and rotatable bonds <= 5.

    :return: number of satisfied criteria (0-5).
    """
    # The logP rule must use the Crippen logP estimate. The previous check,
    # qed(mol) <= 5, was always true because QED lies in [0, 1].
    rules = (
        MolLogP(mol) <= 5,
        Chem.Lipinski.NumHDonors(mol) <= 5,
        Chem.Lipinski.NumHAcceptors(mol) <= 10,
        Chem.Descriptors.ExactMolWt(mol) <= 500,
        Chem.Lipinski.NumRotatableBonds(mol) <= 5,
    )
    return sum(1 for satisfied in rules if satisfied)
|
213 |
+
|
214 |
+
|
215 |
+
def refine_pos(ligand_pos, protein_pos, h_ctx_ligand, h_ctx_protein, model, batch, repeats, protein_batch,
               ligand_batch):
    """Nudge ligand atoms with the model's learned pairwise "forces".

    For every sample, each of its ligand atoms is paired with each protein
    atom flagged in ``batch['alpha_carbon_indicator']`` and with each ligand
    atom of the same sample. Pairs within 10.0 of each other contribute a
    predicted scalar times the unit direction between the two atoms; the
    contributions are averaged per ligand atom and added to its position.

    NOTE: ``ligand_pos`` is updated in place.

    :param repeats: number of ligand atoms of each sample.
    :param protein_batch: sample index of every protein atom.
    :param ligand_batch: sample index of every ligand atom.
    :return: list with the refined ligand coordinates of each sample (float).
    """
    # Start offset of every sample inside the flattened protein/ligand tensors.
    # new_zeros keeps the leading zero on the same device as the counts; a bare
    # torch.tensor([0]) is CPU-only and makes torch.cat fail on other devices.
    protein_counts = protein_batch.bincount()
    ligand_counts = ligand_batch.bincount()
    protein_offsets = torch.cat([protein_counts.new_zeros(1), torch.cumsum(protein_counts, dim=0)])
    ligand_offsets = torch.cat([ligand_counts.new_zeros(1), torch.cumsum(ligand_counts, dim=0)])
    index_device = ligand_batch.device

    sr_ligand_idx, sr_protein_idx = [], []
    sr_ligand_idx0, sr_ligand_idx1 = [], []
    for i in range(len(repeats)):
        alpha_index = batch['alpha_carbon_indicator'][protein_batch == i].nonzero().reshape(-1)
        ligand_atom_index = torch.arange(repeats[i], device=index_device)

        # ligand atom x flagged protein atom pairs
        p_idx, q_idx = torch.cartesian_prod(ligand_atom_index, torch.arange(len(alpha_index), device=index_device)).chunk(2, dim=-1)
        p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
        sr_ligand_idx.append(ligand_atom_index[p_idx] + ligand_offsets[i])
        sr_protein_idx.append(alpha_index[q_idx] + protein_offsets[i])

        # ligand atom x ligand atom pairs
        p_idx, q_idx = torch.cartesian_prod(ligand_atom_index, ligand_atom_index).chunk(2, dim=-1)
        p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
        sr_ligand_idx0.append(ligand_atom_index[p_idx] + ligand_offsets[i])
        sr_ligand_idx1.append(ligand_atom_index[q_idx] + ligand_offsets[i])
    sr_ligand_idx, sr_protein_idx = torch.cat(sr_ligand_idx).long(), torch.cat(sr_protein_idx).long()
    sr_ligand_idx0, sr_ligand_idx1 = torch.cat(sr_ligand_idx0).long(), torch.cat(sr_ligand_idx1).long()

    input_dir_alpha = ligand_pos[sr_ligand_idx] - protein_pos[sr_protein_idx]
    input_dir_intra = ligand_pos[sr_ligand_idx0] - ligand_pos[sr_ligand_idx1]
    dist_alpha = torch.norm(input_dir_alpha, dim=1)
    dist_intra = torch.norm(input_dir_intra, dim=1)
    distance_emb1 = model.encoder.distance_expansion(dist_alpha)
    distance_emb2 = model.encoder.distance_expansion(dist_intra)

    # distance cut_off
    near_alpha = dist_alpha <= 10.0
    near_intra = dist_intra <= 10.0
    input1 = torch.cat([h_ctx_ligand[sr_ligand_idx], h_ctx_protein[sr_protein_idx], distance_emb1], dim=-1)[near_alpha]
    input2 = torch.cat([h_ctx_ligand[sr_ligand_idx0], h_ctx_ligand[sr_ligand_idx1], distance_emb2], dim=-1)[near_intra]
    norm_dir1 = F.normalize(input_dir_alpha, p=2, dim=1)[near_alpha]
    norm_dir2 = F.normalize(input_dir_intra, p=2, dim=1)[near_intra]
    force1 = scatter_mean(model.refine_protein(input1) * norm_dir1, dim=0, index=sr_ligand_idx[near_alpha], dim_size=ligand_pos.size(0))
    force2 = scatter_mean(model.refine_ligand(input2) * norm_dir2, dim=0, index=sr_ligand_idx0[near_intra], dim_size=ligand_pos.size(0))
    ligand_pos += force1
    ligand_pos += force2

    ligand_pos = [ligand_pos[ligand_batch == k].float() for k in range(len(repeats))]
    return ligand_pos
|
257 |
+
|
258 |
+
|
259 |
+
def ligand_gen(batch, model, vocab, config, center, device, refinement=False):
    """Grow one ligand per pocket in ``batch`` by attaching motifs step by step.

    Step 0 selects, for every sample, the protein atom with the highest focal
    score, asks ``model.forward_motif`` for a first motif, embeds that motif
    with RDKit and places it next to the pocket: ``model.dist_mlp`` predicts
    protein-ligand distances, ``eig_coord_from_dist`` turns the combined
    squared-distance matrix into coordinates, ``kabsch_torch`` aligns them onto
    the reference protein atoms, and the result is shifted 80% of the way
    towards ``center``.

    Every later step samples a focal ligand atom, predicts the next motif,
    attaches it with ``model.forward_attach``, embeds the grown molecule,
    aligns the new atoms onto the already placed anchor atoms and, when the
    attachment is rotatable, rotates the new atoms by the angle predicted by
    ``model.forward_alpha``.

    All per-sample bookkeeping is sized with ``config.sample.batch_size``, so
    the batch is expected to hold exactly that many complexes.

    Args:
        batch: Mapping providing ``protein_pos``, ``protein_atom_feature``,
            ``protein_element_batch``, ``ligand_context_pos``,
            ``ligand_context_feature_full`` and
            ``ligand_context_element_batch``.
        model: Network exposing ``__call__``, ``forward_motif``,
            ``forward_attach``, ``forward_alpha`` and ``dist_mlp`` (plus what
            ``refine_pos`` needs when ``refinement`` is set).
        vocab: Motif vocabulary exposing ``get_smiles`` and ``size``.
        config: Config object; ``config.sample.batch_size`` and
            ``config.sample.max_steps`` are read.
        center: Tensor towards which the first motif is translated.
        device: Torch device used for tensors created here.
        refinement: When true, ligand positions are passed through
            ``refine_pos`` at the start of every step after the first.

    Returns:
        Tuple ``(mol_list, pos_list)``: the RDKit molecules and, per molecule,
        the tensor of atom positions.
    """
    batch_size = config.sample.batch_size
    pos_list = []
    feat_list = []
    # Bound up front so the final return is valid even when max_steps is 0.
    mol_list = []
    motif_id = [0 for _ in range(batch_size)]
    finished = torch.zeros(batch_size).bool()
    for i in range(config.sample.max_steps):
        print(i)
        print(finished)
        if torch.sum(finished) == batch_size:
            return mol_list, pos_list
        if i == 0:
            focal_pred, mask_protein, h_ctx = model(protein_pos=batch['protein_pos'],
                                                    protein_atom_feature=batch['protein_atom_feature'].float(),
                                                    ligand_pos=batch['ligand_context_pos'],
                                                    ligand_atom_feature=batch['ligand_context_feature_full'].float(),
                                                    batch_protein=batch['protein_element_batch'],
                                                    batch_ligand=batch['ligand_context_element_batch'])
            protein_atom_feature = batch['protein_atom_feature'].float()
            focal_protein = focal_pred[mask_protein]
            h_ctx_protein = h_ctx[mask_protein]
            focus_score = torch.sigmoid(focal_protein)
            # slice_idx[j]:slice_idx[j + 1] is the protein-atom range of sample j.
            slice_idx = torch.cat([torch.tensor([0]).to(device), torch.cumsum(batch['protein_element_batch'].bincount(), dim=0)])
            focal_id = []
            for j in range(len(slice_idx) - 1):
                focus = focus_score[slice_idx[j]:slice_idx[j + 1]]
                # Deterministic choice: best-scoring protein atom of this sample.
                focal_id.append(torch.argmax(focus.reshape(-1).float()).item() + slice_idx[j].item())
            focal_id = torch.tensor(focal_id, device=device)

            h_ctx_focal = h_ctx_protein[focal_id]
            # vocab.size() is passed as the "current motif" id because no motif exists yet.
            current_wid = torch.tensor([vocab.size()] * batch_size, device=device)
            next_motif_wid, _ = model.forward_motif(h_ctx_focal, current_wid, torch.arange(batch_size, device=device))
            mol_list = [Chem.MolFromSmiles(vocab.get_smiles(wid)) for wid in next_motif_wid]

            for j in range(batch_size):
                AllChem.EmbedMolecule(mol_list[j])
                AllChem.UFFOptimizeMolecule(mol_list[j])
                ligand_pos, ligand_feat = torch.tensor(mol_list[j].GetConformer().GetPositions(), device=device), get_feat(mol_list[j]).to(device)
                feat_list.append(ligand_feat)
                # set the initial positions with distance matrix
                reference_pos, reference_idx = find_reference(batch['protein_pos'][slice_idx[j]:slice_idx[j + 1]], focal_id[j] - slice_idx[j])

                # All (reference protein atom, ligand atom) pairs; 4 reference atoms are used.
                p_idx, l_idx = torch.cartesian_prod(torch.arange(4), torch.arange(len(ligand_pos))).chunk(2, dim=-1)
                p_idx = p_idx.squeeze(-1).to(device)
                l_idx = l_idx.squeeze(-1).to(device)
                d_m = model.dist_mlp(torch.cat([protein_atom_feature[reference_idx[p_idx]], ligand_feat[l_idx]], dim=-1)).reshape(4, len(ligand_pos))

                # Squared so it matches the squared intra-set distance blocks below.
                d_m = d_m ** 2
                p_d, l_d = self_square_dist(reference_pos), self_square_dist(ligand_pos)
                D = torch.cat([torch.cat([p_d, d_m], dim=1), torch.cat([d_m.permute(1, 0), l_d], dim=1)])
                coordinate = eig_coord_from_dist(D)
                new_pos, _, _ = kabsch_torch(coordinate[:len(reference_pos)], reference_pos,
                                             coordinate[len(reference_pos):])
                # Move the motif centroid 80% of the way towards the pocket center.
                new_pos += (center - torch.mean(new_pos, dim=0)) * .8
                pos_list.append(new_pos)

            atom_to_motif = [{} for _ in range(batch_size)]
            motif_to_atoms = [{} for _ in range(batch_size)]
            motif_wid = [{} for _ in range(batch_size)]
            for j in range(batch_size):
                for k in range(mol_list[j].GetNumAtoms()):
                    atom_to_motif[j][k] = 0
            for j in range(batch_size):
                motif_to_atoms[j][0] = list(np.arange(mol_list[j].GetNumAtoms()))
                motif_wid[j][0] = next_motif_wid[j].item()
        else:
            repeats = torch.tensor([len(pos) for pos in pos_list], device=device)
            ligand_batch = torch.repeat_interleave(torch.arange(batch_size, device=device), repeats)
            focal_pred, mask_protein, h_ctx = model(protein_pos=batch['protein_pos'].float(),
                                                    protein_atom_feature=batch['protein_atom_feature'].float(),
                                                    ligand_pos=torch.cat(pos_list, dim=0).float(),
                                                    ligand_atom_feature=torch.cat(feat_list, dim=0).float(),
                                                    batch_protein=batch['protein_element_batch'],
                                                    batch_ligand=ligand_batch)
            # structure refinement
            if refinement:
                pos_list = refine_pos(torch.cat(pos_list, dim=0).float(), batch['protein_pos'].float(),
                                      h_ctx[~mask_protein], h_ctx[mask_protein], model, batch, repeats.tolist(),
                                      batch['protein_element_batch'], ligand_batch)

            focal_ligand = focal_pred[~mask_protein]
            h_ctx_ligand = h_ctx[~mask_protein]
            focus_score = torch.sigmoid(focal_ligand)
            # A sigmoid output is only non-positive on float underflow, so this
            # threshold of 0 practically never rules an atom out.
            can_focus = focus_score > 0.
            slice_idx = torch.cat([torch.tensor([0], device=device), torch.cumsum(repeats, dim=0)])

            current_atoms_batch, current_atoms = [], []
            for j in range(len(slice_idx) - 1):
                focus = focus_score[slice_idx[j]:slice_idx[j + 1]]
                if torch.sum(can_focus[slice_idx[j]:slice_idx[j + 1]]) > 0 and not finished[j]:
                    # Stochastic choice of the focal atom, weighted by focal score.
                    sample_focal_atom = torch.multinomial(focus.reshape(-1).float(), 1)
                    focal_motif = atom_to_motif[j][sample_focal_atom.item()]
                    motif_id[j] = focal_motif
                else:
                    finished[j] = True

                # Finished samples still contribute their last motif so that the
                # batched tensors keep one entry group per sample.
                current_atoms.extend((np.array(motif_to_atoms[j][motif_id[j]]) + slice_idx[j].item()).tolist())
                current_atoms_batch.extend([j] * len(motif_to_atoms[j][motif_id[j]]))
                mol_list[j] = SetAtomNum(mol_list[j], motif_to_atoms[j][motif_id[j]])
            # second step: next motif prediction
            current_wid = [motif_wid[j][motif_id[j]] for j in range(len(mol_list))]
            next_motif_wid, _ = model.forward_motif(h_ctx_ligand[torch.tensor(current_atoms)],
                                                    torch.tensor(current_wid).to(device),
                                                    torch.tensor(current_atoms_batch).to(device))

            # assemble
            next_motif_smiles = [vocab.get_smiles(wid) for wid in next_motif_wid]
            new_mol_list, new_atoms, one_atom_attach, intersection, attach_fail = model.forward_attach(mol_list, next_motif_smiles, device)

            for j in range(len(mol_list)):
                # `not` rather than `~`: bitwise inversion of a plain Python bool
                # yields -1 / -2, both truthy, which would accept failed attachments.
                if not finished[j] and not attach_fail[j]:
                    mol_list[j] = new_mol_list[j]
            # Rotatable: focal motif has exactly two atoms, the new motif is
            # attached through a single atom, and the sample is still active.
            rotatable = torch.logical_and(torch.tensor(current_atoms_batch).bincount() == 2, torch.tensor(one_atom_attach))
            rotatable = torch.logical_and(rotatable, ~torch.tensor(attach_fail))
            rotatable = torch.logical_and(rotatable, ~finished).to(device)
            # update motif2atoms and atom2motif
            for j in range(len(mol_list)):
                if attach_fail[j] or finished[j]:
                    continue
                # The step index doubles as the id of the motif added in this step.
                motif_to_atoms[j][i] = new_atoms[j]
                motif_wid[j][i] = next_motif_wid[j].item()
                for k in new_atoms[j]:
                    atom_to_motif[j][k] = i

            # generate initial positions
            for j in range(len(mol_list)):
                if attach_fail[j] or finished[j]:
                    continue
                mol = mol_list[j]
                # Atom-map number 1 marks the anchor atoms, 2 the newly added
                # atoms -- presumably set by forward_attach; verify there.
                anchor = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 1]
                anchor_pos = deepcopy(pos_list[j][anchor]).to(device)
                Chem.SanitizeMol(mol)
                AllChem.EmbedMolecule(mol, useRandomCoords=True)
                try:
                    AllChem.UFFOptimizeMolecule(mol)
                except Exception:
                    # Best effort: keep the unoptimised embedding.
                    print('UFF error')
                anchor_pos_new = mol.GetConformer(0).GetPositions()[anchor]
                new_idx = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 2]
                new_pos = mol.GetConformer().GetPositions()[new_idx]
                # Map the freshly embedded new atoms into the existing frame by
                # aligning the embedded anchors onto their already placed positions.
                new_pos, _, _ = kabsch_torch(torch.tensor(anchor_pos_new, device=device), anchor_pos, torch.tensor(new_pos, device=device))

                conf = mol.GetConformer()
                # update curated parameters
                pos_list[j] = torch.cat([pos_list[j], new_pos])
                feat_list[j] = get_feat(mol_list[j]).to(device)
                for node in range(mol.GetNumAtoms()):
                    conf.SetAtomPosition(node, np.array(pos_list[j][node].cpu()))
                assert mol.GetNumAtoms() == len(pos_list[j])

            # predict alpha and rotate (only change the position)
            if torch.sum(rotatable) > 0 and i >= 2:
                repeats = torch.tensor([len(pos) for pos in pos_list])
                ligand_batch = torch.repeat_interleave(torch.arange(len(pos_list)), repeats).to(device)
                slice_idx = torch.cat([torch.tensor([0]), torch.cumsum(repeats, dim=0)])
                # Global (batch-concatenated) indices of the two focal-motif atoms.
                xy_index = [(np.array(motif_to_atoms[j][motif_id[j]]) + slice_idx[j].item()).tolist() for j in range(len(slice_idx) - 1) if rotatable[j]]

                alpha = model.forward_alpha(protein_pos=batch['protein_pos'].float(),
                                            protein_atom_feature=batch['protein_atom_feature'].float(),
                                            ligand_pos=torch.cat(pos_list, dim=0).float(),
                                            ligand_atom_feature=torch.cat(feat_list, dim=0).float(),
                                            batch_protein=batch['protein_element_batch'],
                                            batch_ligand=ligand_batch, xy_index=torch.tensor(xy_index, device=device),
                                            rotatable=rotatable)

                rotatable_id = [idx for idx in range(len(mol_list)) if rotatable[idx]]
                # Same atoms again, now as per-molecule (local) indices.
                xy_index = [motif_to_atoms[j][motif_id[j]] for j in range(len(slice_idx) - 1) if rotatable[j]]
                x_index = [intersection[j] for j in range(len(slice_idx) - 1) if rotatable[j]]
                # y is the focal-motif atom that is not the attachment atom x.
                y_index = [(set(xy_index[k]) - set(x_index[k])).pop() for k in range(len(x_index))]

                for j in range(len(alpha)):
                    sample = rotatable_id[j]
                    mol = mol_list[sample]
                    new_idx = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 2]
                    positions = deepcopy(pos_list[sample])

                    xn_pos = positions[new_idx].float()
                    axis_dir = (positions[x_index[j]] - positions[y_index[j]]).reshape(-1)
                    axis_ref = positions[x_index[j]].reshape(-1)
                    xn_pos = rand_rotate(axis_dir.to(device), axis_ref.to(device), xn_pos.to(device), alpha[j], device=device)
                    if xn_pos.shape[0] > 0:
                        # New atoms were appended last, so they occupy the tail of pos_list.
                        pos_list[sample][-len(xn_pos):] = xn_pos
                    conf = mol.GetConformer()
                    for node in range(mol.GetNumAtoms()):
                        conf.SetAtomPosition(node, np.array(pos_list[sample][node].cpu()))
                    assert mol.GetNumAtoms() == len(pos_list[sample])

    return mol_list, pos_list
|
459 |
+
|
460 |
+
|
461 |
+
def demo(data_id):
    """Generate a ligand for one test pocket on CPU and save it as an SDF file.

    Loads the motif vocabulary from ``vocab.txt`` and the sampling config from
    ``./configs/sample.yml``, restores the ``FLAG`` model from
    ``config.model.checkpoint``, runs ``ligand_gen`` on copies of test-set
    entry ``data_id % 100`` and writes the first generated molecule that has at
    least 12 atoms and ``MolLogP >= 0.60`` to ``./data/Ligand.sdf``.

    Args:
        data_id: Integer index; taken modulo 100 before indexing the test split.

    Returns:
        Path of the written SDF file, or ``None`` when no generated molecule
        passes the filter.
    """
    vocab_path = 'vocab.txt'
    device = 'cpu'
    config_path = './configs/sample.yml'
    filename = os.path.join('./data', 'Ligand.sdf')

    # Only the text before the first ':' of each vocabulary line is kept.
    with open(vocab_path) as vocab_file:
        vocab = Vocab([line.partition(':')[0] for line in vocab_file])

    # Load configs
    config = load_config(config_path)

    # Data
    protein_featurizer = FeaturizeProteinAtom()
    ligand_featurizer = FeaturizeLigandAtom()
    masking = LigandMaskAll(vocab)
    transform = Compose([
        LigandCountNeighbors(),
        protein_featurizer,
        ligand_featurizer,
        FeaturizeLigandBond(),
        masking,
    ])
    _, subsets = get_dataset(
        config=config.dataset,
        transform=transform,
    )
    testset = subsets['test']
    data = testset[data_id % 100]
    center = data['ligand_center'].to(device)
    # The same pocket is repeated so every sample in a batch targets it.
    test_set = [data for _ in range(config.sample.num_samples)]

    # Model (Main)
    ckpt = torch.load(config.model.checkpoint, map_location=device)
    model = FLAG(
        ckpt['config'].model,
        protein_atom_feature_dim=protein_featurizer.feature_dim,
        ligand_atom_feature_dim=ligand_featurizer.feature_dim,
        vocab=vocab,
    ).to(device)
    model.load_state_dict(ckpt['model'])

    sample_loader = DataLoader(test_set, batch_size=config.sample.batch_size,
                               shuffle=False, num_workers=config.sample.num_workers,
                               collate_fn=collate_mols)

    with torch.no_grad():
        model.eval()
        for batch in tqdm(sample_loader):
            for key in batch:
                batch[key] = batch[key].to(device)
            gen_data, pos_list = ligand_gen(batch, model, vocab, config, center, device)
            SetMolPos(gen_data, pos_list)
            for mol in gen_data:
                try:
                    AllChem.UFFOptimizeMolecule(mol)
                except Exception:
                    # Best effort: keep the unoptimised geometry.
                    print('UFF error')
            for mol in gen_data:
                if mol.GetNumAtoms() < 12 or MolLogP(mol) < 0.60:
                    continue
                writer = Chem.SDWriter(filename)
                try:
                    writer.write(mol, confId=0)
                finally:
                    writer.close()
                # NOTE(review): returns as soon as the first molecule passing the
                # filter is written; confirm this early exit is intended rather
                # than returning after all batches have been processed.
                return filename
    return None
|
532 |
+
|
533 |
+
|
534 |
+
|
535 |
+
# Command-line entry point: sample ligands for one test-set pocket, write the
# accepted molecules (SDF + SMILES) to a fresh log directory, print QED / LogP /
# SA / Lipinski averages and finally score the written ligands with Vina.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./configs/sample.yml')
    parser.add_argument('-i', '--data_id', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--outdir', type=str, default='./outputs')
    parser.add_argument('--vocab_path', type=str, default='vocab.txt')
    # Used for the Vina worker pool below; the DataLoader takes its worker
    # count from config.sample.num_workers instead.
    parser.add_argument('--num_workers', type=int, default=64)
    args = parser.parse_args()

    # Load vocab: only the text before the first ':' of each line is kept.
    with open(args.vocab_path) as vocab_file:
        vocab = Vocab([line.partition(':')[0] for line in vocab_file])

    # Load configs
    config = load_config(args.config)
    # splitext keeps the whole name when it has no extension (slicing with
    # rfind('.') == -1 would drop the last character).
    config_name = os.path.splitext(os.path.basename(args.config))[0]
    seed_all(config.sample.seed)

    # Logging
    log_dir = get_new_log_dir(args.outdir, prefix='%s-%d' % (config_name, args.data_id))
    logger = get_logger('sample', log_dir)
    logger.info(args)
    logger.info(config)
    shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))

    # Data
    logger.info('Loading data...')
    protein_featurizer = FeaturizeProteinAtom()
    ligand_featurizer = FeaturizeLigandAtom()
    masking = LigandMaskAll(vocab)
    transform = Compose([
        LigandCountNeighbors(),
        protein_featurizer,
        ligand_featurizer,
        FeaturizeLigandBond(),
        masking,
    ])
    dataset, subsets = get_dataset(
        config=config.dataset,
        transform=transform,
    )
    testset = subsets['test']
    data = testset[args.data_id]
    center = data['ligand_center'].to(args.device)
    # The same pocket is repeated so every sample in a batch targets it.
    test_set = [data for _ in range(config.sample.num_samples)]

    with open(os.path.join(log_dir, 'pocket_info.txt'), 'a') as f:
        f.write(data['protein_filename'] + '\n')

    # Model (Main)
    logger.info('Loading main model...')
    ckpt = torch.load(config.model.checkpoint, map_location=args.device)
    model = FLAG(
        ckpt['config'].model,
        protein_atom_feature_dim=protein_featurizer.feature_dim,
        ligand_atom_feature_dim=ligand_featurizer.feature_dim,
        vocab=vocab,
    ).to(args.device)
    model.load_state_dict(ckpt['model'])

    sample_loader = DataLoader(test_set, batch_size=config.sample.batch_size,
                               shuffle=False, num_workers=config.sample.num_workers,
                               collate_fn=collate_mols)
    data_list = []
    # Bound before the try block so the Vina stage below still works when
    # sampling is interrupted early.
    number = 0
    number_list = []
    try:
        with torch.no_grad():
            model.eval()
            for batch in tqdm(sample_loader):
                for key in batch:
                    batch[key] = batch[key].to(args.device)
                gen_data, pos_list = ligand_gen(batch, model, vocab, config, center, args.device)
                SetMolPos(gen_data, pos_list)
                for mol in gen_data:
                    try:
                        AllChem.UFFOptimizeMolecule(mol)
                    except Exception:
                        # Best effort: keep the unoptimised geometry.
                        print('UFF error')
                data_list.extend(gen_data)
                with open(os.path.join(log_dir, 'SMILES.txt'), 'a') as smiles_f:
                    for mol in gen_data:
                        # `number` counts every generated molecule (1-based) and
                        # names the SDF file of the ones that pass the filter.
                        number += 1
                        if mol.GetNumAtoms() < 12 or MolLogP(mol) < 0.60:
                            continue
                        smiles_f.write(Chem.MolToSmiles(mol) + '\n')
                        writer = Chem.SDWriter(os.path.join(log_dir, '%d.sdf' % number))
                        try:
                            writer.write(mol, confId=0)
                        finally:
                            writer.close()
                        number_list.append(number)

        # Calculate metrics
        print([Chem.MolToSmiles(mol) for mol in data_list])
        # Round-trip through SMILES; MolFromSmiles returns None on parse
        # failure, so those molecules are left out of the averages.
        roundtrip = [Chem.MolFromSmiles(Chem.MolToSmiles(mol)) for mol in data_list]
        smiles = [mol for mol in roundtrip if mol is not None]
        qed_list = [qed(mol) for mol in smiles if mol.GetNumAtoms() >= 8]
        logp_list = [MolLogP(mol) for mol in smiles]
        sa_list = [compute_sa_score(mol) for mol in smiles]
        Lip_list = [lipinski(mol) for mol in smiles]
        print('QED %.6f | LogP %.6f | SA %.6f | Lipinski %.6f \n' % (np.average(qed_list), np.average(logp_list), np.average(sa_list), np.average(Lip_list)))

    except KeyboardInterrupt:
        logger.info('Terminated. Generated molecules will be saved.')
        # NOTE(review): files here are named by the 0-based position in
        # data_list, whereas the normal path above uses the 1-based `number`
        # (and number_list, used for Vina, only holds the latter) -- confirm
        # which naming is intended.
        with open(os.path.join(log_dir, 'SMILES.txt'), 'a') as smiles_f:
            for i, mol in enumerate(data_list):
                if mol.GetNumAtoms() < 12 or MolLogP(mol) < 0.60:
                    continue
                smiles_f.write(Chem.MolToSmiles(mol) + '\n')
                writer = Chem.SDWriter(os.path.join(log_dir, '%d.sdf' % i))
                try:
                    writer.write(mol, confId=0)
                finally:
                    writer.close()

    pool = mp.Pool(args.num_workers)
    vina_list = []
    # NOTE(review): absolute, cluster-specific location of the pocket files.
    pro_path = '/n/holyscratch01/mzitnik_lab/zaixizhang/pdbbind_pocket10/' + os.path.join(data['pdbid'], data['pdbid']+'_pocket.pdb')
    try:
        for vina_score in tqdm(pool.imap_unordered(partial(calculate_vina, pro_path=pro_path, lig_path=log_dir), number_list), total=len(number_list)):
            if vina_score is not None:
                vina_list.append(vina_score)
    finally:
        # Always release the workers, even if a docking job raises.
        pool.close()
        pool.join()
    print('Vina: ', np.average(vina_list))
|
requirements.txt
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may be used to create an environment using:
|
2 |
+
# $ conda create --name <env> --file <this file>
|
3 |
+
# platform: linux-64
|
4 |
+
_libgcc_mutex=0.1=conda_forge
|
5 |
+
_openmp_mutex=4.5=2_kmp_llvm
|
6 |
+
abseil-cpp=20211102.0=hd4dd3e8_0
|
7 |
+
absl-py=1.4.0=py38h06a4308_0
|
8 |
+
aiofiles=23.2.1=pypi_0
|
9 |
+
aiohttp=3.8.5=py38h5eee18b_0
|
10 |
+
aiosignal=1.2.0=pyhd3eb1b0_0
|
11 |
+
altair=5.1.2=pypi_0
|
12 |
+
amberlite=22.0=pypi_0
|
13 |
+
ambertools=22.0=py38h6177452_1
|
14 |
+
amberutils=21.0=pypi_0
|
15 |
+
annotated-types=0.6.0=pypi_0
|
16 |
+
antlr4-python3-runtime=4.9.3=pypi_0
|
17 |
+
anyio=3.7.1=pypi_0
|
18 |
+
appdirs=1.4.4=pyhd3eb1b0_0
|
19 |
+
argon2-cffi=21.3.0=pyhd3eb1b0_0
|
20 |
+
argon2-cffi-bindings=21.2.0=py38h7f8727e_0
|
21 |
+
arpack=3.7.0=hdefa2d7_2
|
22 |
+
arrow-cpp=11.0.0=hda39474_2
|
23 |
+
asttokens=2.0.5=pyhd3eb1b0_0
|
24 |
+
astunparse=1.6.3=py_0
|
25 |
+
async-timeout=4.0.2=py38h06a4308_0
|
26 |
+
attrs=22.1.0=py38h06a4308_0
|
27 |
+
autodocktools-py3=1.5.7.post1+9.gda0c87c=pypi_0
|
28 |
+
aws-c-common=0.4.57=he6710b0_1
|
29 |
+
aws-c-event-stream=0.1.6=h2531618_5
|
30 |
+
aws-checksums=0.1.9=he6710b0_0
|
31 |
+
aws-sdk-cpp=1.8.185=hce553d0_0
|
32 |
+
backcall=0.2.0=pyhd3eb1b0_0
|
33 |
+
beautifulsoup4=4.12.2=py38h06a4308_0
|
34 |
+
bio=1.5.9=pypi_0
|
35 |
+
biopython=1.81=pypi_0
|
36 |
+
biothings-client=0.3.0=pypi_0
|
37 |
+
blas=1.0=mkl
|
38 |
+
bleach=4.1.0=pyhd3eb1b0_0
|
39 |
+
blinker=1.4=py38h06a4308_0
|
40 |
+
blosc=1.21.3=h6a678d5_0
|
41 |
+
boost=1.74.0=py38h2b96118_5
|
42 |
+
boost-cpp=1.74.0=h75c5d50_8
|
43 |
+
bottleneck=1.3.5=py38h7deecbd_0
|
44 |
+
brotli=1.0.9=h5eee18b_7
|
45 |
+
brotli-bin=1.0.9=h5eee18b_7
|
46 |
+
brotlipy=0.7.0=py38h27cfd23_1003
|
47 |
+
bzip2=1.0.8=h7b6447c_0
|
48 |
+
c-ares=1.19.1=h5eee18b_0
|
49 |
+
c-blosc2=2.8.0=h6a678d5_0
|
50 |
+
ca-certificates=2023.08.22=h06a4308_0
|
51 |
+
cachetools=4.2.2=pyhd3eb1b0_0
|
52 |
+
cairo=1.16.0=hb05425b_5
|
53 |
+
certifi=2023.7.22=py38h06a4308_0
|
54 |
+
cffi=1.15.1=py38h74dc2b5_0
|
55 |
+
charset-normalizer=2.0.4=pyhd3eb1b0_0
|
56 |
+
click=8.0.4=py38h06a4308_0
|
57 |
+
comm=0.1.2=py38h06a4308_0
|
58 |
+
contourpy=1.0.5=py38hdb19cb5_0
|
59 |
+
cryptography=41.0.3=py38h130f0dd_0
|
60 |
+
cuda-cudart=11.7.99=0
|
61 |
+
cuda-cupti=11.7.101=0
|
62 |
+
cuda-libraries=11.7.1=0
|
63 |
+
cuda-nvrtc=11.7.99=0
|
64 |
+
cuda-nvtx=11.7.91=0
|
65 |
+
cuda-runtime=11.7.1=0
|
66 |
+
cudatoolkit=11.8.0=h6a678d5_0
|
67 |
+
curl=8.2.1=h37d81fd_0
|
68 |
+
cycler=0.11.0=pyhd3eb1b0_0
|
69 |
+
cython=3.0.2=py38h17151c0_0
|
70 |
+
dataclasses=0.8=pyh6d0b6a4_7
|
71 |
+
datamol=0.11.4=pypi_0
|
72 |
+
datasets=2.12.0=py38h06a4308_0
|
73 |
+
debugpy=1.6.7=py38h6a678d5_0
|
74 |
+
decorator=5.1.1=pyhd3eb1b0_0
|
75 |
+
defusedxml=0.7.1=pyhd3eb1b0_0
|
76 |
+
dill=0.3.6=py38h06a4308_0
|
77 |
+
docutils=0.17.1=pypi_0
|
78 |
+
easydict=1.9=py_0
|
79 |
+
entrypoints=0.4=py38h06a4308_0
|
80 |
+
exceptiongroup=1.1.3=pypi_0
|
81 |
+
executing=0.8.3=pyhd3eb1b0_0
|
82 |
+
expat=2.5.0=h6a678d5_0
|
83 |
+
fair-esm=2.0.0=pypi_0
|
84 |
+
fastapi=0.104.0=pypi_0
|
85 |
+
fasteners=0.19=pypi_0
|
86 |
+
ffmpeg=4.2.2=h20bf706_0
|
87 |
+
ffmpy=0.3.1=pypi_0
|
88 |
+
fftw=3.3.10=nompi_hf0379b8_106
|
89 |
+
filelock=3.9.0=py38h06a4308_0
|
90 |
+
fontconfig=2.14.2=h14ed4e7_0
|
91 |
+
fonttools=4.25.0=pyhd3eb1b0_0
|
92 |
+
freetype=2.12.1=h4a9f257_0
|
93 |
+
frozenlist=1.3.3=py38h5eee18b_0
|
94 |
+
fsspec=2023.4.0=py38h06a4308_0
|
95 |
+
gflags=2.2.2=he6710b0_0
|
96 |
+
giflib=5.2.1=h5eee18b_3
|
97 |
+
glib=2.69.1=h4ff587b_1
|
98 |
+
glog=0.5.0=h2531618_0
|
99 |
+
gmp=6.2.1=h295c915_3
|
100 |
+
gnutls=3.6.15=he1e5248_0
|
101 |
+
google-auth=2.22.0=py38h06a4308_0
|
102 |
+
google-auth-oauthlib=0.5.2=py38h06a4308_0
|
103 |
+
gprofiler-official=1.0.0=pypi_0
|
104 |
+
gradio=3.50.2=pypi_0
|
105 |
+
gradio-client=0.6.1=pypi_0
|
106 |
+
greenlet=2.0.1=py38h6a678d5_0
|
107 |
+
griddataformats=1.0.1=pypi_0
|
108 |
+
grpc-cpp=1.48.2=h5bf31a4_0
|
109 |
+
grpcio=1.48.2=py38h5bf31a4_0
|
110 |
+
gsd=3.1.1=pypi_0
|
111 |
+
h11=0.14.0=pypi_0
|
112 |
+
hdf4=4.2.15=h9772cbc_5
|
113 |
+
hdf5=1.12.1=nompi_h2386368_104
|
114 |
+
httpcore=0.18.0=pypi_0
|
115 |
+
httpx=0.25.0=pypi_0
|
116 |
+
huggingface_hub=0.15.1=py38h06a4308_0
|
117 |
+
icu=70.1=h27087fc_0
|
118 |
+
idna=3.4=py38h06a4308_0
|
119 |
+
importlib-metadata=6.0.0=py38h06a4308_0
|
120 |
+
importlib_metadata=6.0.0=hd3eb1b0_0
|
121 |
+
importlib_resources=5.2.0=pyhd3eb1b0_1
|
122 |
+
intel-openmp=2021.4.0=h06a4308_3561
|
123 |
+
ipykernel=6.25.0=py38h2f386ee_0
|
124 |
+
ipython=8.12.2=py38h06a4308_0
|
125 |
+
ipython_genutils=0.2.0=pyhd3eb1b0_1
|
126 |
+
jedi=0.18.1=py38h06a4308_1
|
127 |
+
jinja2=3.1.2=py38h06a4308_0
|
128 |
+
joblib=1.2.0=py38h06a4308_0
|
129 |
+
jpeg=9e=h5eee18b_1
|
130 |
+
jsonschema=4.17.3=py38h06a4308_0
|
131 |
+
jupyter_client=7.4.9=py38h06a4308_0
|
132 |
+
jupyter_core=5.3.0=py38h06a4308_0
|
133 |
+
jupyter_server=1.23.4=py38h06a4308_0
|
134 |
+
jupyterlab_pygments=0.1.2=py_0
|
135 |
+
kiwisolver=1.4.4=py38h6a678d5_0
|
136 |
+
krb5=1.20.1=h568e23c_1
|
137 |
+
lame=3.100=h7b6447c_0
|
138 |
+
lcms2=2.12=h3be6417_0
|
139 |
+
ld_impl_linux-64=2.38=h1181459_1
|
140 |
+
lerc=3.0=h295c915_0
|
141 |
+
libblas=3.9.0=12_linux64_mkl
|
142 |
+
libbrotlicommon=1.0.9=h5eee18b_7
|
143 |
+
libbrotlidec=1.0.9=h5eee18b_7
|
144 |
+
libbrotlienc=1.0.9=h5eee18b_7
|
145 |
+
libcublas=11.10.3.66=0
|
146 |
+
libcufft=10.7.2.124=h4fbf590_0
|
147 |
+
libcufile=1.7.1.12=0
|
148 |
+
libcurand=10.3.3.129=0
|
149 |
+
libcurl=8.2.1=h91b91d3_0
|
150 |
+
libcusolver=11.4.0.1=0
|
151 |
+
libcusparse=11.7.4.91=0
|
152 |
+
libdeflate=1.17=h5eee18b_0
|
153 |
+
libedit=3.1.20221030=h5eee18b_0
|
154 |
+
libev=4.33=h7f8727e_1
|
155 |
+
libevent=2.1.12=h8f2d780_0
|
156 |
+
libffi=3.3=he6710b0_2
|
157 |
+
libgcc-ng=13.1.0=he5830b7_0
|
158 |
+
libgfortran-ng=11.2.0=h00389a5_1
|
159 |
+
libgfortran5=11.2.0=h1234567_1
|
160 |
+
libgomp=13.1.0=he5830b7_0
|
161 |
+
libiconv=1.16=h7f8727e_2
|
162 |
+
libidn2=2.3.4=h5eee18b_0
|
163 |
+
liblapack=3.9.0=12_linux64_mkl
|
164 |
+
libnetcdf=4.8.1=nompi_h329d8a1_102
|
165 |
+
libnghttp2=1.52.0=ha637b67_1
|
166 |
+
libnpp=11.7.4.75=0
|
167 |
+
libnsl=2.0.0=h7f98852_0
|
168 |
+
libnvjpeg=11.8.0.2=0
|
169 |
+
libopus=1.3.1=h7b6447c_0
|
170 |
+
libpng=1.6.39=h5eee18b_0
|
171 |
+
libprotobuf=3.20.3=he621ea3_0
|
172 |
+
libsodium=1.0.18=h7b6447c_0
|
173 |
+
libssh2=1.10.0=h37d81fd_2
|
174 |
+
libstdcxx-ng=13.1.0=hfd8a6a1_0
|
175 |
+
libtasn1=4.19.0=h5eee18b_0
|
176 |
+
libthrift=0.15.0=h0d84882_2
|
177 |
+
libtiff=4.5.1=h6a678d5_0
|
178 |
+
libunistring=0.9.10=h27cfd23_0
|
179 |
+
libuuid=2.38.1=h0b41bf4_0
|
180 |
+
libvpx=1.7.0=h439df22_0
|
181 |
+
libwebp=1.2.4=h11a3e52_1
|
182 |
+
libwebp-base=1.2.4=h5eee18b_1
|
183 |
+
libxcb=1.15=h7f8727e_0
|
184 |
+
libxml2=2.9.14=h22db469_4
|
185 |
+
libxslt=1.1.35=h8affb1d_0
|
186 |
+
libzip=1.9.2=hc869a4a_1
|
187 |
+
libzlib=1.2.13=hd590300_5
|
188 |
+
littleutils=0.2.2=pypi_0
|
189 |
+
llvm-openmp=14.0.6=h9e868ea_0
|
190 |
+
loguru=0.7.2=pypi_0
|
191 |
+
lxml=4.9.1=py38h1edc446_0
|
192 |
+
lz4-c=1.9.4=h6a678d5_0
|
193 |
+
lzo=2.10=h7b6447c_2
|
194 |
+
markdown=3.4.1=py38h06a4308_0
|
195 |
+
markupsafe=2.1.1=py38h7f8727e_0
|
196 |
+
matplotlib-base=3.7.2=py38h1128e8f_0
|
197 |
+
matplotlib-inline=0.1.6=py38h06a4308_0
|
198 |
+
mdanalysis=2.4.3=pypi_0
|
199 |
+
mdtraj=1.9.9=py38h028faf2_0
|
200 |
+
meeko=0.1.dev3=pypi_0
|
201 |
+
mistune=0.8.4=py38h7b6447c_1000
|
202 |
+
mkl=2021.4.0=h06a4308_640
|
203 |
+
mkl-service=2.4.0=py38h7f8727e_0
|
204 |
+
mkl_fft=1.3.1=py38hd3c417c_0
|
205 |
+
mkl_random=1.2.2=py38h51133e4_0
|
206 |
+
mmcif-pdbx=2.0.1=pypi_0
|
207 |
+
mmpbsa-py=16.0=pypi_0
|
208 |
+
mmtf-python=1.1.3=pypi_0
|
209 |
+
mrcfile=1.4.3=pypi_0
|
210 |
+
mscorefonts=0.0.1=3
|
211 |
+
msgpack=1.0.6=pypi_0
|
212 |
+
multidict=6.0.2=py38h5eee18b_0
|
213 |
+
multiprocess=0.70.14=py38h06a4308_0
|
214 |
+
munkres=1.1.4=py_0
|
215 |
+
mygene=3.2.2=pypi_0
|
216 |
+
nb_conda_kernels=2.3.1=py38h06a4308_0
|
217 |
+
nbclassic=0.5.5=py38h06a4308_0
|
218 |
+
nbclient=0.5.13=py38h06a4308_0
|
219 |
+
nbconvert=6.5.4=py38h06a4308_0
|
220 |
+
nbformat=5.9.2=py38h06a4308_0
|
221 |
+
ncurses=6.4=h6a678d5_0
|
222 |
+
nest-asyncio=1.5.6=py38h06a4308_0
|
223 |
+
netcdf-fortran=4.5.4=nompi_h2b6e579_100
|
224 |
+
nettle=3.7.3=hbbd107a_1
|
225 |
+
networkx=3.1=pyhd8ed1ab_0
|
226 |
+
notebook=6.5.4=py38h06a4308_1
|
227 |
+
notebook-shim=0.2.2=py38h06a4308_0
|
228 |
+
numexpr=2.8.4=py38he184ba9_0
|
229 |
+
numpy=1.24.3=py38h14f4228_0
|
230 |
+
numpy-base=1.24.3=py38h31eccc5_0
|
231 |
+
oauthlib=3.2.2=py38h06a4308_0
|
232 |
+
ocl-icd=2.3.1=h7f98852_0
|
233 |
+
ocl-icd-system=1.0.0=1
|
234 |
+
ogb=1.3.6=pypi_0
|
235 |
+
omegaconf=2.3.0=pypi_0
|
236 |
+
openbabel=3.1.1=py38hd2c4bc0_3
|
237 |
+
openff-forcefields=2023.08.0=pyh1a96a4e_0
|
238 |
+
openff-toolkit=0.10.7=pyhd8ed1ab_0
|
239 |
+
openff-toolkit-base=0.10.7=pyhd8ed1ab_0
|
240 |
+
openh264=2.1.1=h4ff587b_0
|
241 |
+
openmm=8.0.0=py38hd11a18e_1
|
242 |
+
openssl=1.1.1w=h7f8727e_0
|
243 |
+
opt-einsum=3.3.0=pypi_0
|
244 |
+
orc=1.7.4=hb3bc3d3_1
|
245 |
+
orjson=3.9.9=pypi_0
|
246 |
+
outdated=0.2.2=pypi_0
|
247 |
+
packaging=23.1=py38h06a4308_0
|
248 |
+
packmol=20.010=h86c2bf4_0
|
249 |
+
packmol-memgen=1.2.3rc0=pypi_0
|
250 |
+
pandas=2.0.3=py38h1128e8f_0
|
251 |
+
pandocfilters=1.5.0=pyhd3eb1b0_0
|
252 |
+
parmed=3.4.4=py38h8dc9893_0
|
253 |
+
parso=0.8.3=pyhd3eb1b0_0
|
254 |
+
pcre=8.45=h295c915_0
|
255 |
+
pdb2pqr=3.6.1=pypi_0
|
256 |
+
pdb4amber=22.0=pypi_0
|
257 |
+
pdbfixer=1.9=pyh1a96a4e_0
|
258 |
+
perl=5.32.1=4_hd590300_perl5
|
259 |
+
pexpect=4.8.0=pyhd3eb1b0_3
|
260 |
+
pickleshare=0.7.5=pyhd3eb1b0_1003
|
261 |
+
pillow=9.4.0=py38h6a678d5_0
|
262 |
+
pip=23.2.1=py38h06a4308_0
|
263 |
+
pixman=0.40.0=h7f8727e_1
|
264 |
+
pkgutil-resolve-name=1.3.10=py38h06a4308_0
|
265 |
+
platformdirs=3.10.0=py38h06a4308_0
|
266 |
+
pooch=1.4.0=pyhd3eb1b0_0
|
267 |
+
posebusters=0.2.6=pypi_0
|
268 |
+
posecheck=0.1=dev_0
|
269 |
+
prolif=2.0.0.post1=pypi_0
|
270 |
+
prometheus_client=0.14.1=py38h06a4308_0
|
271 |
+
prompt-toolkit=3.0.36=py38h06a4308_0
|
272 |
+
propka=3.5.0=pypi_0
|
273 |
+
protobuf=3.20.3=py38h6a678d5_0
|
274 |
+
psutil=5.9.0=py38h5eee18b_0
|
275 |
+
ptyprocess=0.7.0=pyhd3eb1b0_2
|
276 |
+
pure_eval=0.2.2=pyhd3eb1b0_0
|
277 |
+
py-cpuinfo=8.0.0=pyhd3eb1b0_1
|
278 |
+
py3dmol=2.0.4=pypi_0
|
279 |
+
pyarrow=11.0.0=py38h468efa6_1
|
280 |
+
pyasn1=0.4.8=pyhd3eb1b0_0
|
281 |
+
pyasn1-modules=0.2.8=py_0
|
282 |
+
pycairo=1.23.0=py38hd1222b9_0
|
283 |
+
pycparser=2.21=pyhd3eb1b0_0
|
284 |
+
pydantic=2.4.2=pypi_0
|
285 |
+
pydantic-core=2.10.1=pypi_0
|
286 |
+
pydub=0.25.1=pypi_0
|
287 |
+
pyg-lib=0.2.0+pt113cu117=pypi_0
|
288 |
+
pygments=2.15.1=py38h06a4308_1
|
289 |
+
pygsp=0.5.1=pypi_0
|
290 |
+
pyjwt=2.4.0=py38h06a4308_0
|
291 |
+
pyopenssl=23.2.0=py38h06a4308_0
|
292 |
+
pyparsing=3.0.9=py38h06a4308_0
|
293 |
+
pyrsistent=0.18.0=py38heee7806_0
|
294 |
+
pysocks=1.7.1=py38h06a4308_0
|
295 |
+
pytables=3.8.0=py38hb8ae3fc_3
|
296 |
+
python=3.8.11=h12debd9_0_cpython
|
297 |
+
python-constraint=1.4.0=py_0
|
298 |
+
python-dateutil=2.8.2=pyhd3eb1b0_0
|
299 |
+
python-fastjsonschema=2.16.2=py38h06a4308_0
|
300 |
+
python-lmdb=1.4.1=py38h6a678d5_0
|
301 |
+
python-multipart=0.0.6=pypi_0
|
302 |
+
python-tzdata=2023.3=pyhd3eb1b0_0
|
303 |
+
python-xxhash=2.0.2=py38h5eee18b_1
|
304 |
+
python_abi=3.8=3_cp38
|
305 |
+
pytorch=1.13.0=py3.8_cuda11.7_cudnn8.5.0_0
|
306 |
+
pytorch-cuda=11.7=h778d358_5
|
307 |
+
pytorch-mutex=1.0=cuda
|
308 |
+
pytraj=2.0.6=pypi_0
|
309 |
+
pytz=2022.7=py38h06a4308_0
|
310 |
+
pyyaml=6.0=py38h5eee18b_1
|
311 |
+
pyzmq=23.2.0=py38h6a678d5_0
|
312 |
+
qvina=2.1.0=h62396cd_2
|
313 |
+
rdkit=2023.3.3=pypi_0
|
314 |
+
re2=2022.04.01=h295c915_0
|
315 |
+
readline=8.2=h5eee18b_0
|
316 |
+
reduce=3.24=0
|
317 |
+
regex=2022.7.9=py38h5eee18b_0
|
318 |
+
reportlab=3.5.67=py38hfdd840d_1
|
319 |
+
requests=2.31.0=py38h06a4308_0
|
320 |
+
requests-oauthlib=1.3.0=py_0
|
321 |
+
responses=0.13.3=pyhd3eb1b0_0
|
322 |
+
rsa=4.7.2=pyhd3eb1b0_1
|
323 |
+
sacremoses=0.0.43=pyhd3eb1b0_0
|
324 |
+
safetensors=0.3.2=py38hb02cf49_0
|
325 |
+
sander=22.0=pypi_0
|
326 |
+
scikit-learn=1.3.0=py38h1128e8f_0
|
327 |
+
scipy=1.10.1=py38h14f4228_0
|
328 |
+
seaborn=0.12.2=pypi_0
|
329 |
+
selfies=2.1.1=pypi_0
|
330 |
+
semantic-version=2.10.0=pypi_0
|
331 |
+
send2trash=1.8.0=pyhd3eb1b0_1
|
332 |
+
setuptools=68.0.0=py38h06a4308_0
|
333 |
+
six=1.16.0=pyhd3eb1b0_1
|
334 |
+
smirnoff99frosst=1.1.0=pyh44b312d_0
|
335 |
+
snappy=1.1.9=h295c915_0
|
336 |
+
sniffio=1.2.0=py38h06a4308_1
|
337 |
+
sortedcontainers=2.4.0=pypi_0
|
338 |
+
soupsieve=2.4=py38h06a4308_0
|
339 |
+
sqlalchemy=1.4.39=py38h5eee18b_0
|
340 |
+
sqlite=3.41.2=h5eee18b_0
|
341 |
+
stack_data=0.2.0=pyhd3eb1b0_0
|
342 |
+
starlette=0.27.0=pypi_0
|
343 |
+
tensorboard=2.12.1=py38h06a4308_0
|
344 |
+
tensorboard-data-server=0.7.0=py38h52d8a92_0
|
345 |
+
tensorboard-plugin-wit=1.8.1=py38h06a4308_0
|
346 |
+
terminado=0.17.1=py38h06a4308_0
|
347 |
+
threadpoolctl=2.2.0=pyh0d69192_0
|
348 |
+
tinycss2=1.2.1=py38h06a4308_0
|
349 |
+
tk=8.6.12=h1ccaba5_0
|
350 |
+
tokenizers=0.13.2=py38he7d60b5_1
|
351 |
+
toolz=0.12.0=pypi_0
|
352 |
+
torch-cluster=1.6.1+pt113cu117=pypi_0
|
353 |
+
torch-geometric=2.3.1=pypi_0
|
354 |
+
torch-scatter=2.1.1+pt113cu117=pypi_0
|
355 |
+
torch-sparse=0.6.17+pt113cu117=pypi_0
|
356 |
+
torch-spline-conv=1.2.2+pt113cu117=pypi_0
|
357 |
+
tornado=6.3.2=py38h5eee18b_0
|
358 |
+
tqdm=4.65.0=py38hb070fc8_0
|
359 |
+
traitlets=5.7.1=py38h06a4308_0
|
360 |
+
transformers=4.31.0=pyhd8ed1ab_0
|
361 |
+
typing-extensions=4.8.0=pypi_0
|
362 |
+
urllib3=1.26.16=py38h06a4308_0
|
363 |
+
utf8proc=2.6.1=h27cfd23_0
|
364 |
+
uvicorn=0.23.2=pypi_0
|
365 |
+
vina=1.2.2=pypi_0
|
366 |
+
wcwidth=0.2.5=pyhd3eb1b0_0
|
367 |
+
webencodings=0.5.1=py38_1
|
368 |
+
websocket-client=0.58.0=py38h06a4308_4
|
369 |
+
websockets=11.0.3=pypi_0
|
370 |
+
werkzeug=2.2.3=py38h06a4308_0
|
371 |
+
wheel=0.38.4=py38h06a4308_0
|
372 |
+
x264=1!157.20191217=h7b6447c_0
|
373 |
+
xmltodict=0.13.0=pyhd8ed1ab_0
|
374 |
+
xorg-kbproto=1.0.7=h7f98852_1002
|
375 |
+
xorg-libice=1.1.1=hd590300_0
|
376 |
+
xorg-libsm=1.2.4=h7391055_0
|
377 |
+
xorg-libx11=1.8.6=h8ee46fc_0
|
378 |
+
xorg-libxext=1.3.4=h0b41bf4_2
|
379 |
+
xorg-libxt=1.3.0=hd590300_1
|
380 |
+
xorg-xextproto=7.3.0=h0b41bf4_1003
|
381 |
+
xorg-xproto=7.0.31=h7f98852_1007
|
382 |
+
xxhash=0.8.0=h7f8727e_3
|
383 |
+
xz=5.2.10=h5eee18b_1
|
384 |
+
yaml=0.2.5=h7b6447c_0
|
385 |
+
yarl=1.8.1=py38h5eee18b_0
|
386 |
+
zeromq=4.3.4=h2531618_0
|
387 |
+
zipp=3.11.0=py38h06a4308_0
|
388 |
+
zlib=1.2.13=hd590300_5
|
389 |
+
zlib-ng=2.0.7=h5eee18b_0
|
390 |
+
zstd=1.5.5=hc292b87_0
|
utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .dihedral_utils import batch_dihedrals, rotation_matrix_v2, von_Mises_loss
|
2 |
+
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, get_clique_mol_simple, assemble, mol_to_graph_data_obj_simple
|
3 |
+
from .dihedral_utils import rotation_matrix_v2, von_Mises_loss, batch_dihedrals
|
utils/chem.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
from io import BytesIO
|
4 |
+
from openbabel import openbabel
|
5 |
+
from torch_geometric.utils import to_networkx
|
6 |
+
from torch_geometric.data import Data
|
7 |
+
from torch_scatter import scatter
|
8 |
+
from rdkit import Chem
|
9 |
+
from rdkit.Chem.rdchem import Mol, HybridizationType, BondType
|
10 |
+
from rdkit.Chem.rdchem import BondType as BT
|
11 |
+
|
12 |
+
|
13 |
+
# Lookup from each entry of ``BT.names.values()`` to its position in that
# enumeration; ``rdmol_to_data`` uses it to turn ``bond.GetBondType()`` into
# an integer edge label.
BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
# Inverse direction keyed by the same positions: index -> entry of
# ``BT.names.keys()`` (presumably the bond-type name string -- confirm
# against the RDKit enum API).
BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())}
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
def rdmol_to_data(mol, smiles=None):
    """Convert an RDKit molecule with exactly one conformer into a ``Data`` graph.

    Args:
        mol: RDKit molecule; must carry exactly one conformer.
        smiles: Optional SMILES to store on the result. When ``None`` it is
            generated from a hydrogen-stripped copy of ``mol``.

    Returns:
        A ``torch_geometric`` ``Data`` object with
        ``atom_type`` (atomic numbers, long), ``pos`` (conformer coordinates,
        float32), ``edge_index`` / ``edge_type`` (every bond stored in both
        directions, labelled through ``BOND_TYPES``), ``rdmol`` (deep copy of
        ``mol``), ``smiles`` and ``nx`` (undirected networkx view).

    Raises:
        AssertionError: if ``mol`` does not have exactly one conformer.
    """
    assert mol.GetNumConformers() == 1
    N = mol.GetNumAtoms()

    pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)

    # Only the atomic number is stored on the graph. The aromaticity,
    # hybridisation and hydrogen-count features that used to be computed here
    # were never attached to ``data`` and have been dropped as dead code.
    z = torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.long)

    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Store each bond twice so the edge list is symmetric.
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(edge_type)

    # Sort edges by (source, target) so the ordering is deterministic.
    perm = (edge_index[0] * N + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    if smiles is None:
        smiles = Chem.MolToSmiles(Chem.RemoveHs(mol))

    data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type,
                rdmol=copy.deepcopy(mol), smiles=smiles)
    data.nx = to_networkx(data, to_undirected=True)

    return data
|
67 |
+
|
68 |
+
|
69 |
+
def generated_to_xyz(data):
    """Serialise the ligand context atoms of ``data`` as XYZ-format text.

    Reads ``data.ligand_context_element`` (atomic numbers) and
    ``data.ligand_context_pos`` (one coordinate triple per atom). The result
    starts with the atom count and an empty comment line, followed by one
    ``"<symbol> x y z"`` line per atom with eight decimals.
    """
    periodic_table = Chem.GetPeriodicTable()
    atom_count = data.ligand_context_element.size(0)

    chunks = ["%d\n\n" % (atom_count, )]
    for idx in range(atom_count):
        symbol = periodic_table.GetElementSymbol(data.ligand_context_element[idx].item())
        cx, cy, cz = data.ligand_context_pos[idx].clone().cpu().tolist()
        chunks.append("%s %.8f %.8f %.8f\n" % (symbol, cx, cy, cz))

    return "".join(chunks)
|
80 |
+
|
81 |
+
|
82 |
+
def generated_to_sdf(data):
    """Convert the ligand context of ``data`` to SDF text via Open Babel.

    The atoms are first written as XYZ (``generated_to_xyz``), read into an
    ``OBMol`` and written back out in SDF format.
    """
    converter = openbabel.OBConversion()
    converter.SetInAndOutFormats("xyz", "sdf")

    ob_mol = openbabel.OBMol()
    converter.ReadString(ob_mol, generated_to_xyz(data))
    return converter.WriteString(ob_mol)
|
91 |
+
|
92 |
+
|
93 |
+
def sdf_to_rdmol(sdf):
    """Return the first molecule parsed from the SDF text ``sdf``.

    Yields ``None`` when the supplier produces no entries.
    """
    supplier = Chem.ForwardSDMolSupplier(BytesIO(sdf.encode()))
    # Only the first record is of interest; later records are ignored.
    return next(iter(supplier), None)
|
99 |
+
|
100 |
+
def generated_to_rdmol(data):
    """Build an RDKit molecule from the ligand context of ``data``.

    Chains ``generated_to_sdf`` and ``sdf_to_rdmol``; the result is whatever
    ``sdf_to_rdmol`` returns for the generated SDF text.
    """
    return sdf_to_rdmol(generated_to_sdf(data))
|
103 |
+
|
104 |
+
|
105 |
+
def filter_rd_mol(rdmol):
    """Reject molecules in which two three-membered rings share an atom.

    Args:
        rdmol: object exposing ``GetRingInfo().AtomRings()`` (an RDKit mol).

    Returns:
        ``False`` if any pair of distinct 3-atom rings has at least one atom
        in common, ``True`` otherwise (including when there are no rings).
    """
    rings = [set(ring) for ring in rdmol.GetRingInfo().AtomRings()]
    # Only 3-rings can trigger the filter, so narrow the search up front.
    three_rings = [ring for ring in rings if len(ring) == 3]

    # 3-3 ring intersection: compare every unordered pair once.
    for idx, ring_a in enumerate(three_rings):
        for ring_b in three_rings[:idx]:
            if ring_a & ring_b:
                return False

    return True
|
utils/chemutils.py
ADDED
@@ -0,0 +1,597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import rdkit
|
2 |
+
import rdkit.Chem as Chem
|
3 |
+
from scipy.sparse import csr_matrix
|
4 |
+
from scipy.sparse.csgraph import minimum_spanning_tree
|
5 |
+
from collections import defaultdict
|
6 |
+
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
|
7 |
+
from rdkit.Chem.Descriptors import MolLogP, qed
|
8 |
+
from torch_geometric.data import Data, Batch
|
9 |
+
from random import sample
|
10 |
+
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
|
11 |
+
import numpy as np
|
12 |
+
from math import sqrt
|
13 |
+
import torch
|
14 |
+
from copy import deepcopy
|
15 |
+
# Offset used by tree_decomp: an edge whose clusters share ``v`` atoms gets
# weight ``MST_MAX_WEIGHT - v``, so a *minimum* spanning tree over these
# weights favours the largest overlaps.
MST_MAX_WEIGHT = 100
# NOTE(review): not referenced anywhere in the visible part of this module --
# presumably a cap on the number of attachment candidates; confirm or remove.
MAX_NCAND = 2000
|
17 |
+
|
18 |
+
|
19 |
+
def vina_score(mol, use_uff=True):
    """Prepare a ligand for scoring: add hydrogens and optionally relax with UFF.

    NOTE(review): despite its name, this function has never computed a
    docking score. The original body referenced an undefined ``use_uff``
    name (NameError on every call) and returned nothing. It now takes
    ``use_uff`` as a keyword with a default and returns the prepared
    molecule so the work it does is usable.

    Args:
        mol: RDKit molecule with 3D coordinates.
        use_uff: when true, run ``UFFOptimizeMolecule`` on the hydrogenated
            copy.

    Returns:
        The hydrogenated (and optionally UFF-optimised) copy of ``mol``.
    """
    ligand_rdmol = Chem.AddHs(mol, addCoords=True)
    if use_uff:
        UFFOptimizeMolecule(ligand_rdmol)
    return ligand_rdmol
|
23 |
+
|
24 |
+
def lipinski(mol):
    """Return ``True`` when ``mol`` passes the drug-likeness thresholds below.

    Thresholds: logP <= 5, H-bond donors <= 5, H-bond acceptors <= 10,
    exact molecular weight <= 500 and rotatable bonds <= 5.

    The first criterion previously tested ``qed(mol) <= 5``. QED is bounded
    by 1, so that test was always true and the lipophilicity rule was never
    enforced; the Crippen ``MolLogP`` descriptor (already imported by this
    module) is used instead.
    """
    return bool(
        MolLogP(mol) <= 5
        and Chem.Lipinski.NumHDonors(mol) <= 5
        and Chem.Lipinski.NumHAcceptors(mol) <= 10
        and Chem.Descriptors.ExactMolWt(mol) <= 500
        and Chem.Lipinski.NumRotatableBonds(mol) <= 5
    )
|
29 |
+
|
30 |
+
|
31 |
+
def list_filter(a, b):
    """Return the items of ``a`` that are also contained in ``b``.

    Order and duplicates of ``a`` are preserved. Membership is tested with
    ``in`` against ``b`` as given, so unhashable items keep working.
    """
    # A comprehension replaces the manual append loop and avoids the local
    # variable that used to shadow the builtin ``filter``.
    return [item for item in a if item in b]
|
37 |
+
|
38 |
+
|
39 |
+
def rand_rotate(dir, ref, pos, alpha=None, device=None):
    """Rotate ``pos`` by angle ``alpha`` about the axis through ``ref`` along ``dir``.

    Args:
        dir: axis direction (indexed as a 3-vector); normalised here.
        ref: a point on the rotation axis (indexed as a 3-vector).
        pos: points to rotate, one per row (rows are transposed into columns
            before the 4x4 transform is applied).
        alpha: rotation angle in radians as a tensor; when ``None`` a single
            value is drawn with ``torch.randn``.
        device: torch device for the constructed tensors; defaults to 'cpu'.

    Returns:
        The rotated points, same row layout as ``pos``.
    """
    if device is None:
        device = 'cpu'
    dir = dir/torch.norm(dir)
    if alpha is None:
        # NOTE(review): randn samples a standard normal, not a uniform angle
        # on [0, 2*pi) -- confirm this is the intended distribution.
        alpha = torch.randn(1).to(device)
    n_pos = pos.shape[0]
    sin, cos = torch.sin(alpha).to(device), torch.cos(alpha).to(device)
    K = 1 - cos
    # Projection of the reference point onto the (unit) axis direction.
    M = torch.dot(dir, ref)
    nx, ny, nz = dir[0], dir[1], dir[2]
    x0, y0, z0 = ref[0], ref[1], ref[2]
    # Homogeneous 4x4 transform: the upper-left 3x3 block is the Rodrigues
    # rotation about ``dir``; the fourth column is the translation that keeps
    # ``ref`` fixed, i.e. K*(ref - dir*M) - sin*(dir x ref).
    T = torch.tensor([nx ** 2 * K + cos, nx * ny * K - nz * sin, nx * nz * K + ny * sin,
                      (x0 - nx * M) * K + (nz * y0 - ny * z0) * sin,
                      nx * ny * K + nz * sin, ny ** 2 * K + cos, ny * nz * K - nx * sin,
                      (y0 - ny * M) * K + (nx * z0 - nz * x0) * sin,
                      nx * nz * K - ny * sin, ny * nz * K + nx * sin, nz ** 2 * K + cos,
                      (z0 - nz * M) * K + (ny * x0 - nx * y0) * sin,
                      0, 0, 0, 1], device=device).reshape(4, 4)
    # Append a row of ones so the points are in homogeneous coordinates (4 x n_pos).
    pos = torch.cat([pos.t(), torch.ones(n_pos, device=device).unsqueeze(0)], dim=0)
    # Drop the homogeneous row after applying the transform.
    rotated_pos = torch.mm(T, pos)[:3]
    return rotated_pos.t()
|
61 |
+
|
62 |
+
|
63 |
+
def kabsch(A, B):
    """Find the rigid transform that best maps point set ``B`` onto ``A``.

    Args:
        A: nominal points, N x 3 (array-like or ``np.matrix``).
        B: measured points, N x 3, in the same order as ``A``.

    Returns:
        ``(R, t)`` where ``R`` is the 3 x 3 rotation matrix (B to A) and ``t``
        is the 3 x 1 translation vector (B to A), so that
        ``(R @ B.T + t).T`` approximates ``A``. Both are plain ndarrays.

    Raises:
        AssertionError: if ``A`` and ``B`` have different lengths.
    """
    # The original used ``*`` for products, which is only a matrix product
    # for ``np.matrix`` operands. On ndarrays it is elementwise and either
    # fails to broadcast or silently yields a wrong result. Normalising the
    # inputs to ndarrays and using ``@`` is correct for both input kinds.
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    assert len(A) == len(B)

    centroid_A = A.mean(axis=0)
    centroid_B = B.mean(axis=0)
    # center the points
    AA = A - centroid_A
    BB = B - centroid_B

    # Cross-covariance between the centred sets.
    H = BB.T @ AA
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    # special reflection case: flip the last singular direction so that R is
    # a proper rotation (determinant +1).
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = (centroid_A - R @ centroid_B).reshape(3, 1)
    return R, t
|
86 |
+
|
87 |
+
|
88 |
+
def kabsch_torch(A, B, C):
    """Fit the rotation/translation taking ``A`` onto ``B`` and apply it to ``C``.

    All three inputs are cast to float64 and treated as one point per row.

    Returns:
        ``(C_aligned, R, t)``: the transformed copy of ``C``, the 3 x 3 matrix
        ``R`` and the 1 x 3 offset ``t`` such that ``C_aligned = C @ R.T + t``.

    NOTE(review): no determinant check is performed, so ``R`` can be a
    reflection for degenerate or mirrored inputs.
    """
    src, dst, moving = A.double(), B.double(), C.double()

    src_centroid = src.mean(dim=0, keepdim=True)
    dst_centroid = dst.mean(dim=0, keepdim=True)

    # Covariance matrix of the centred point sets (3 x 3).
    cov = torch.matmul((src - src_centroid).transpose(0, 1), dst - dst_centroid)
    U, _, V = torch.svd(cov)

    # Rotation matrix
    R = torch.matmul(V, U.transpose(0, 1))
    # Translation vector
    t = dst_centroid - torch.matmul(R, src_centroid.transpose(0, 1)).transpose(0, 1)

    C_aligned = torch.matmul(R, moving.transpose(0, 1)).transpose(0, 1) + t
    return C_aligned, R, t
|
105 |
+
|
106 |
+
|
107 |
+
def eig_coord_from_dist(D):
    """Recover 3D coordinates from a matrix of squared pairwise distances.

    Args:
        D: N x N tensor of squared distances (as produced by
            ``self_square_dist``).

    Returns:
        Detached N x 3 tensor of coordinates (fewer columns if N < 3) whose
        pairwise squared distances reproduce ``D`` when ``D`` is embeddable in
        three dimensions. The result is defined only up to a rigid
        motion/reflection.
    """
    # Gram matrix of the points relative to point 0.
    M = (D[:1, :] + D[:, :1] - D) / 2
    eigvals, eigvecs = torch.linalg.eigh(M)
    # eigh returns eigenvalues in ascending order. The previous code sorted
    # the eigenvalues descending but left the eigenvector columns untouched,
    # pairing the largest eigenvalues with the wrong eigenvectors. Reorder
    # both with the same permutation.
    order = torch.argsort(eigvals, descending=True)
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]
    # Scale each eigenvector column by sqrt(eigenvalue); negative eigenvalues
    # (numerical noise or non-Euclidean input) are clamped to zero.
    X = eigvecs * eigvals.clamp(min=0).sqrt()
    return X[:, :3].detach()
|
113 |
+
|
114 |
+
|
115 |
+
def self_square_dist(X):
    """Return the matrix of squared Euclidean distances between the rows of ``X``.

    Entry ``[i, j]`` is the squared distance between row ``i`` and row ``j``.
    """
    # Broadcast [1, N, d] against [N, 1, d] to get all pairwise differences.
    pairwise_diff = X.unsqueeze(0) - X.unsqueeze(1)
    return pairwise_diff.pow(2).sum(dim=-1)
|
119 |
+
|
120 |
+
|
121 |
+
def set_atommap(mol, num=0):
    """Assign atom-map number ``num`` to every atom of ``mol``, in place."""
    atoms = mol.GetAtoms()
    for current in atoms:
        current.SetAtomMapNum(num)
|
124 |
+
|
125 |
+
|
126 |
+
def get_mol(smiles):
    """Parse ``smiles`` into a Kekulized RDKit molecule.

    Returns ``None`` when RDKit cannot parse the string.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        # Kekulize works in place on the parsed molecule.
        Chem.Kekulize(parsed)
    return parsed
|
132 |
+
|
133 |
+
|
134 |
+
def get_smiles(mol):
    """Return the SMILES string of ``mol``, written with ``kekuleSmiles=False``."""
    smiles = Chem.MolToSmiles(mol, kekuleSmiles=False)
    return smiles
|
136 |
+
|
137 |
+
|
138 |
+
def decode_stereo(smiles2D):
    """Enumerate stereoisomer SMILES for ``smiles2D``.

    Returns a list of isomeric SMILES: one per enumerated stereoisomer and,
    if the first isomer contains chirally tagged nitrogens, one extra entry
    per isomer with those nitrogen tags cleared. The list is not
    de-duplicated.
    """
    mol = Chem.MolFromSmiles(smiles2D)
    dec_isomers = list(EnumerateStereoisomers(mol))

    # Round-trip each isomer through isomeric SMILES before writing it out
    # (presumably to normalise the molecule objects -- confirm intent).
    dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
    smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]

    # Indices of nitrogen atoms that carry a chiral tag in the first isomer;
    # the same indices are then applied to every isomer below.
    chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms() if
               int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
    if len(chiralN) > 0:
        for mol in dec_isomers:
            for idx in chiralN:
                mol.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
            # Add the variant with nitrogen chirality removed.
            smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))

    return smiles3D
|
154 |
+
|
155 |
+
|
156 |
+
def sanitize(mol):
    """Round-trip ``mol`` through SMILES and return the re-parsed molecule.

    Returns ``None`` if writing or re-parsing raises, or if ``get_mol``
    itself returns ``None``.
    """
    try:
        return get_mol(get_smiles(mol))
    except Exception:
        return None
|
163 |
+
|
164 |
+
|
165 |
+
def copy_atom(atom):
    """Create a new atom with the symbol, formal charge and atom-map number of ``atom``.

    No other atom properties are copied.
    """
    clone = Chem.Atom(atom.GetSymbol())
    clone.SetAtomMapNum(atom.GetAtomMapNum())
    clone.SetFormalCharge(atom.GetFormalCharge())
    return clone
|
170 |
+
|
171 |
+
|
172 |
+
def copy_edit_mol(mol):
    """Return an editable ``RWMol`` copy of ``mol``.

    Atoms are recreated through ``copy_atom`` and bonds are re-added with
    their original bond type, in the original atom order.
    """
    editable = Chem.RWMol(Chem.MolFromSmiles(''))
    for source_atom in mol.GetAtoms():
        editable.AddAtom(copy_atom(source_atom))
    for source_bond in mol.GetBonds():
        begin_idx = source_bond.GetBeginAtom().GetIdx()
        end_idx = source_bond.GetEndAtom().GetIdx()
        editable.AddBond(begin_idx, end_idx, source_bond.GetBondType())
    return editable
|
183 |
+
|
184 |
+
|
185 |
+
def get_submol(mol, idxs, mark=()):
    """Extract the sub-molecule of ``mol`` induced by the atom indices ``idxs``.

    Args:
        mol: source RDKit molecule.
        idxs: iterable of atom indices to keep.
        mark: iterable of atom indices whose copies get atom-map number 1;
            every other kept atom gets 0.

    Returns:
        A new molecule containing copies of the selected atoms and every bond
        of ``mol`` whose two end atoms were both selected.
    """
    # Sets give O(1) membership tests; the default is an immutable tuple
    # rather than a shared mutable list.
    keep = set(idxs)
    marked = set(mark)

    new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
    old_to_new = {}  # atom index in ``mol`` -> atom index in ``new_mol``
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        if idx not in keep:
            continue
        new_atom = copy_atom(atom)
        new_atom.SetAtomMapNum(1 if idx in marked else 0)
        old_to_new[idx] = new_mol.AddAtom(new_atom)

    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if a1 in keep and a2 in keep:
            new_mol.AddBond(old_to_new[a1], old_to_new[a2], bond.GetBondType())
    return new_mol.GetMol()
|
203 |
+
|
204 |
+
|
205 |
+
def get_clique_mol(mol, atoms):
    """Build a standalone, sanitized molecule for the atom subset ``atoms`` of ``mol``.

    The fragment is written as Kekulé SMILES, re-parsed without sanitization,
    copied through ``copy_edit_mol`` and finally passed through ``sanitize``.
    """
    fragment_smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    fragment = Chem.MolFromSmiles(fragment_smiles, sanitize=False)
    fragment_copy = copy_edit_mol(fragment).GetMol()
    # We assume this is not None
    return sanitize(fragment_copy)
|
211 |
+
|
212 |
+
|
213 |
+
def get_clique_mol_simple(mol, cluster):
    """Return the unsanitized fragment molecule for the atoms in ``cluster``.

    The fragment is written as canonical Kekulé SMILES and parsed back with
    ``sanitize=False``.
    """
    fragment_smiles = Chem.MolFragmentToSmiles(mol, cluster, canonical=True, kekuleSmiles=True)
    return Chem.MolFromSmiles(fragment_smiles, sanitize=False)
|
217 |
+
|
218 |
+
|
219 |
+
def tree_decomp(mol, reference_vocab=None):
    """Decompose ``mol`` into clusters (bonds and rings) joined as a spanning tree.

    Args:
        mol: RDKit molecule.
        reference_vocab: optional mapping from fragment SMILES to a count.
            When given, two rings sharing exactly two atoms are merged only if
            the merged fragment's count exceeds 100. It is indexed directly,
            so it must tolerate unseen keys (presumably a defaultdict or
            Counter -- confirm against callers).

    Returns:
        ``(clusters, edges)``: ``clusters`` is a list of atom-index sets and
        ``edges`` a list of ``(i, j)`` cluster-index pairs forming the tree
        (an empty list when no two clusters share an atom).
    """
    edges = defaultdict(int)
    n_atoms = mol.GetNumAtoms()
    clusters = []
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            clusters.append({a1, a2})
    # Every non-ring bond is now its own two-atom cluster (the original note
    # here read "extract rotatable bonds", but no rotatability test is made).

    ssr = [set(x) for x in Chem.GetSymmSSSR(mol)]
    # remove too large circles
    ssr = [x for x in ssr if len(x) <= 8]

    # Merge Rings with intersection >= 2 atoms
    # check the reference_vocab if it is not None
    for i in range(len(ssr)-1):
        if len(ssr[i]) <= 2:
            continue
        for j in range(i+1, len(ssr)):
            if len(ssr[j]) <= 2:
                continue
            inter = ssr[i] & ssr[j]
            if reference_vocab is not None:
                if len(inter) >= 2:
                    merge = ssr[i] | ssr[j]
                    smile_merge = Chem.MolFragmentToSmiles(mol, merge, canonical=True, kekuleSmiles=True)
                    # Rings sharing exactly two atoms stay separate unless the
                    # merged fragment's vocabulary count exceeds 100.
                    if reference_vocab[smile_merge] <= 100 and len(inter) == 2:
                        continue
                    ssr[i] = merge
                    ssr[j] = set()
            else:
                # Without a vocabulary only rings sharing more than two atoms
                # are merged.
                if len(inter) > 2:
                    merge = ssr[i] | ssr[j]
                    ssr[i] = merge
                    ssr[j] = set()

    # Drop rings that were absorbed into another ring.
    ssr = [c for c in ssr if len(c) > 0]
    clusters.extend(ssr)
    # nei_list[a] = indices of all clusters containing atom a.
    nei_list = [[] for _ in range(n_atoms)]
    for i in range(len(clusters)):
        for atom in clusters[i]:
            nei_list[atom].append(i)

    # Build edges
    for atom in range(n_atoms):
        if len(nei_list[atom]) <= 1:
            continue
        cnei = nei_list[atom]
        for i in range(len(cnei)):
            for j in range(i + 1, len(cnei)):
                c1, c2 = cnei[i], cnei[j]
                inter = set(clusters[c1]) & set(clusters[c2])
                # Edge value = number of atoms the two clusters share.
                if edges[(c1, c2)] < len(inter):
                    edges[(c1, c2)] = len(inter)  # cnei[i] < cnei[j] by construction

    # Invert the overlap so that larger overlaps become smaller weights.
    edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
    if len(edges) == 0:
        return clusters, edges

    # Compute Maximum Spanning Tree (a minimum spanning tree over the
    # inverted weights built above).
    row, col, data = zip(*edges)
    n_clique = len(clusters)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]
    return clusters, edges
|
288 |
+
|
289 |
+
|
290 |
+
def atom_equal(a1, a2):
    """Return ``True`` when both atoms have the same symbol and formal charge."""
    if a1.GetSymbol() != a2.GetSymbol():
        return False
    return a1.GetFormalCharge() == a2.GetFormalCharge()
|
292 |
+
|
293 |
+
|
294 |
+
# NOTE(review): an earlier comment here stated that bond type is not
# considered because all ring bonds are aromatic. The code below DOES require
# equal bond types, so a SINGLE bond will not match a DOUBLE bond.
def ring_bond_equal(bond1, bond2, reverse=False):
    """Return ``True`` when two bonds match atom-for-atom and in bond type.

    Atoms are compared with ``atom_equal``. With ``reverse=True`` the begin
    and end atoms of ``bond2`` are swapped before comparison.
    """
    b1 = (bond1.GetBeginAtom(), bond1.GetEndAtom())
    if reverse:
        b2 = (bond2.GetEndAtom(), bond2.GetBeginAtom())
    else:
        b2 = (bond2.GetBeginAtom(), bond2.GetEndAtom())
    return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1]) and bond1.GetBondType() == bond2.GetBondType()
|
302 |
+
|
303 |
+
|
304 |
+
def attach(ctr_mol, nei_mol, amap):
    """Graft ``nei_mol`` onto a copy of ``ctr_mol`` using the atom mapping ``amap``.

    Args:
        ctr_mol: molecule being extended; it is copied into an ``RWMol`` first.
        nei_mol: fragment to attach.
        amap: dict from atom index in ``nei_mol`` to atom index in
            ``ctr_mol`` for the atoms the two share. It is MUTATED: entries
            for the newly added atoms are inserted.

    Returns:
        ``(new_mol, amap)``: the merged molecule and the same, now extended,
        mapping object.
    """
    ctr_mol = Chem.RWMol(ctr_mol)
    for atom in nei_mol.GetAtoms():
        if atom.GetIdx() not in amap:
            new_atom = copy_atom(atom)
            # Newly added atoms are tagged with atom-map number 2.
            new_atom.SetAtomMapNum(2)
            amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)

    for bond in nei_mol.GetBonds():
        a1 = amap[bond.GetBeginAtom().GetIdx()]
        a2 = amap[bond.GetEndAtom().GetIdx()]
        # Only add bonds that are missing; existing bonds are left untouched.
        if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
            ctr_mol.AddBond(a1, a2, bond.GetBondType())

    return ctr_mol.GetMol(), amap
|
319 |
+
|
320 |
+
|
321 |
+
def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
    """Attach every node in ``prev_nodes + neighbors`` to the editable ``ctr_mol``.

    Args:
        ctr_mol: editable molecule, modified in place (``AddAtom``,
            ``AddBond`` and ``RemoveBond`` are called on it).
        neighbors: nodes to attach; each exposes ``nid`` and ``mol``.
        prev_nodes: previously placed nodes; their bonds take precedence.
        nei_amap: dict ``nid -> {neighbor atom idx: centre atom idx}``; the
            inner dicts are extended in place with the newly added atoms.

    Returns:
        The same ``ctr_mol`` object after modification.
    """
    prev_nids = [node.nid for node in prev_nodes]
    for nei_node in prev_nodes + neighbors:
        nei_id, nei_mol = nei_node.nid, nei_node.mol
        amap = nei_amap[nei_id]
        for atom in nei_mol.GetAtoms():
            if atom.GetIdx() not in amap:
                new_atom = copy_atom(atom)
                amap[atom.GetIdx()] = ctr_mol.AddAtom(new_atom)

        if nei_mol.GetNumBonds() == 0:
            # Bond-less neighbour: only its atom-map number is transferred
            # from atom 0 to the mapped centre atom.
            nei_atom = nei_mol.GetAtomWithIdx(0)
            ctr_atom = ctr_mol.GetAtomWithIdx(amap[0])
            ctr_atom.SetAtomMapNum(nei_atom.GetAtomMapNum())
        else:
            for bond in nei_mol.GetBonds():
                a1 = amap[bond.GetBeginAtom().GetIdx()]
                a2 = amap[bond.GetEndAtom().GetIdx()]
                if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
                elif nei_id in prev_nids:  # father node overrides
                    # Replace the existing bond with the previous node's bond type.
                    ctr_mol.RemoveBond(a1, a2)
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
    return ctr_mol
|
345 |
+
|
346 |
+
|
347 |
+
def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
    """Attach neighbouring nodes to a copy of ``ctr_mol`` using ``amap_list``.

    ``amap_list`` holds ``(neighbor_id, centre_atom_idx, neighbor_atom_idx)``
    triples. They are regrouped into one mapping per neighbour id and handed
    to ``attach_mols``; the merged molecule is returned as a plain mol.
    """
    editable = copy_edit_mol(ctr_mol)

    per_neighbor_map = {}
    for node in prev_nodes + neighbors:
        per_neighbor_map[node.nid] = {}
    for nei_id, ctr_atom, nei_atom in amap_list:
        per_neighbor_map[nei_id][nei_atom] = ctr_atom

    merged = attach_mols(editable, neighbors, prev_nodes, per_neighbor_map)
    return merged.GetMol()
|
356 |
+
|
357 |
+
|
358 |
+
# This version records idx mapping between ctr_mol and nei_mol
|
359 |
+
def enum_attach(ctr_mol, nei_mol):
    """Enumerate the ways ``nei_mol`` can be attached to ``ctr_mol``.

    Only atoms of ``ctr_mol`` with atom-map number 1 (and bonds whose two end
    atoms both have it) are considered as attachment sites. Both molecules
    are Kekulized in place.

    Returns:
        A list of dicts mapping atom indices in ``nei_mol`` to atom indices
        in ``ctr_mol``: one entry per shared atom, or two entries per shared
        bond. An empty list is returned when Kekulization fails.
    """
    try:
        Chem.Kekulize(ctr_mol)
        Chem.Kekulize(nei_mol)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        return []
    att_confs = []
    ctr_bonds = [bond for bond in ctr_mol.GetBonds()
                 if bond.GetBeginAtom().GetAtomMapNum() == 1 and bond.GetEndAtom().GetAtomMapNum() == 1]
    ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetAtomMapNum() == 1]
    if nei_mol.GetNumBonds() == 1:  # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        # GetBondTypeAsDouble yields the bond order (the raw enum value of
        # GetBondType was used previously).
        bond_val = int(bond.GetBondTypeAsDouble())
        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()

        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence)
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            if atom_equal(atom, b1):
                att_confs.append({b1.GetIdx(): atom.GetIdx()})
            elif atom_equal(atom, b2):
                att_confs.append({b2.GetIdx(): atom.GetIdx()})
    else:
        # intersection is an atom
        for a1 in ctr_atoms:
            for a2 in nei_mol.GetAtoms():
                if atom_equal(a1, a2):
                    # Optimize if atom is carbon (other atoms may change valence)
                    if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
                        continue
                    att_confs.append({a2.GetIdx(): a1.GetIdx()})

        # intersection is a bond
        if ctr_mol.GetNumBonds() > 1:
            for b1 in ctr_bonds:
                for b2 in nei_mol.GetBonds():
                    if ring_bond_equal(b1, b2):
                        att_confs.append({b2.GetBeginAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                                          b2.GetEndAtom().GetIdx(): b1.GetEndAtom().GetIdx()})

                    if ring_bond_equal(b1, b2, reverse=True):
                        att_confs.append({b2.GetEndAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                                          b2.GetBeginAtom().GetIdx(): b1.GetEndAtom().GetIdx()})
    return att_confs
|
411 |
+
|
412 |
+
|
413 |
+
def enumerate_assemble(mol, idxs, current, next):
    """Build a (negative, positive) candidate set for attaching ``next`` to the partial molecule.

    Args:
        mol: full RDKit molecule.
        idxs: atom indices of the partial molecule assembled so far.
        current: node whose ``clique`` atoms are marked as attachment sites.
        next: node to attach; its ``clique`` and ``mol`` attributes are read.

    Returns:
        ``(labels, cand_mols)``. If at least one valid alternative attachment
        exists, ``cand_mols`` is ``[random_alternative, ground_truth]`` with
        labels ``[0, 1]``; otherwise it is ``[ground_truth]`` with label
        ``[1]``.
    """
    ctr_mol = get_submol(mol, idxs, mark=current.clique)
    ground_truth = get_submol(mol, list(set(idxs) | set(next.clique)))
    # submol can also obtained with get_clique_mol, future exploration
    ground_truth_smiles = get_smiles(ground_truth)
    seen_smiles = set()  # set membership instead of repeated list scans
    cand_mols = []
    for amap in enum_attach(ctr_mol, next.mol):
        try:
            cand_mol, _ = attach(ctr_mol, next.mol, amap)
            cand_mol = sanitize(cand_mol)
        except Exception:
            # Best-effort: skip attachments RDKit cannot build. Narrowed from
            # a bare ``except:`` so interrupts still propagate.
            continue
        if cand_mol is None:
            continue
        smiles = get_smiles(cand_mol)
        # Skip duplicates and anything identical to the true attachment.
        if smiles in seen_smiles or smiles == ground_truth_smiles:
            continue
        seen_smiles.add(smiles)
        cand_mols.append(cand_mol)

    if len(cand_mols) >= 1:
        # Keep one random negative and pair it with the positive example.
        cand_mols = sample(cand_mols, 1)
        cand_mols.append(ground_truth)
        labels = torch.tensor([0, 1])
    else:
        cand_mols = [ground_truth]
        labels = torch.tensor([1])

    return labels, cand_mols
|
443 |
+
|
444 |
+
|
445 |
+
# Allowable node and edge features. Each list defines the vocabulary of one
# categorical feature; mol_to_graph_data_obj_simple encodes a value as its
# position in the corresponding list.
allowable_features = {
    'possible_atomic_num_list' : list(range(1, 119)),  # atomic numbers 1..118
    'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list' : [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list' : [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds' : [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs' : [ # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}
|
476 |
+
|
477 |
+
def mol_to_graph_data_obj_simple(mol):
    """
    Converts rdkit mol object to graph Data object required by the pytorch
    geometric package. NB: Uses simplified atom and bond features, and represent
    as indices
    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr

    Features are positions in the ``allowable_features`` lists, so a value
    missing from those lists (for example atomic number 0, or a bond type or
    bond direction not listed there) makes ``list.index`` raise ValueError.
    """
    # atoms
    num_atom_features = 2  # atom type, chirality tag (informational; not used below)
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = [allowable_features['possible_atomic_num_list'].index(
            atom.GetAtomicNum())] + [allowable_features[
            'possible_chirality_list'].index(atom.GetChiralTag())]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    num_bond_features = 2  # bond type, bond direction
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [allowable_features['possible_bonds'].index(
                bond.GetBondType())] + [allowable_features[
                'possible_bond_dirs'].index(
                bond.GetBondDir())]
            # Each bond is stored in both directions with the same features.
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:  # mol has no bonds
        # Empty tensors keep the attribute shapes consistent for bond-less mols.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data
|
525 |
+
|
526 |
+
|
527 |
+
# For inference
|
528 |
+
def assemble(mol_list, next_motif_smiles):
    """Enumerate attachment candidates for a batch of partial molecules.

    Args:
        mol_list: partial molecules, one per batch element.
        next_motif_smiles: SMILES of the motif to attach to each element.

    Returns:
        A 6-tuple ``(cand_mols, cand_batch, new_atoms, one_atom_attach,
        intersection, attach_fail)``:
          * ``cand_mols``: all candidate molecules, flattened over the batch.
            An element with no valid candidate contributes its unmodified
            molecule.
          * ``cand_batch``: long tensor with the batch index of each candidate.
          * ``new_atoms``: per candidate, all values of the mapping returned
            by ``attach`` (``[]`` on failure). ``attach`` extends the mapping
            it is given, so this includes the shared atoms too.
          * ``one_atom_attach``: bool tensor, true where the candidate's
            original mapping had exactly one entry.
          * ``intersection``: per candidate, the centre-molecule atom indices
            of the original mapping (``[]`` on failure).
          * ``attach_fail``: bool tensor per batch element, true when no
            valid candidate was found.
    """
    attach_fail = torch.zeros(len(mol_list)).bool()
    # NOTE(review): cand_smiles is filled below but never returned or read.
    cand_mols, cand_batch, new_atoms, cand_smiles, one_atom_attach, intersection = [], [], [], [], [], []
    for i in range(len(mol_list)):
        next = Chem.MolFromSmiles(next_motif_smiles[i])
        cand_amap = enum_attach(mol_list[i], next)
        if len(cand_amap) == 0:
            # No attachment site at all: emit a placeholder entry.
            attach_fail[i] = True
            cand_mols.append(mol_list[i])
            cand_batch.append(i)
            one_atom_attach.append(-1)
            intersection.append([])
            new_atoms.append([])
        else:
            valid_cand = 0
            for amap in cand_amap:
                # Record the mapping size and values BEFORE attach() extends amap.
                amap_len = len(amap)
                iter_atoms = [v for v in amap.values()]
                ctr_mol = deepcopy(mol_list[i])
                cand_mol, amap1 = attach(ctr_mol, next, amap)
                # Validate on a copy so the candidate itself is kept as built.
                if sanitize(deepcopy(cand_mol)) is None:
                    continue
                smiles = get_smiles(cand_mol)
                cand_smiles.append(smiles)
                cand_mols.append(cand_mol)
                cand_batch.append(i)
                new_atoms.append([v for v in amap1.values()])
                one_atom_attach.append(amap_len)
                intersection.append(iter_atoms)
                valid_cand+=1
            if valid_cand==0:
                # Every enumerated attachment failed sanitization.
                attach_fail[i] = True
                cand_mols.append(mol_list[i])
                cand_batch.append(i)
                one_atom_attach.append(-1)
                intersection.append([])
                new_atoms.append([])
    cand_batch = torch.tensor(cand_batch)
    # Mapping size 1 means the motif shares a single atom with the molecule.
    one_atom_attach = torch.tensor(one_atom_attach) == 1
    return cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail
|
568 |
+
|
569 |
+
|
570 |
+
# Ad-hoc smoke test, run only when this file is executed as a script.
if __name__ == "__main__":
    import sys
    from mol_tree import MolTree

    # Silence RDKit logging below CRITICAL.
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    # NOTE(review): this list of sample SMILES is never used below.
    smiles = ["O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1", "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2",
              "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3",
              "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1",
              'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br', 'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1',
              "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34", "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1"]
    mol_tree = MolTree("C")
    assert len(mol_tree.nodes) > 0


    def count():
        # Reads one molecule per line from stdin (first whitespace-separated
        # token is the SMILES) and accumulates candidate and node counts.
        cnt, n = 0, 0
        for s in sys.stdin:
            s = s.split()[0]
            tree = MolTree(s)
            tree.recover()
            tree.assemble()
            for node in tree.nodes:
                cnt += len(node.cands)
            n += len(tree.nodes)
            # print cnt * 1.0 / n
            # NOTE(review): the line above is a disabled Python 2 print, so
            # the accumulated statistics are currently never reported.
    count()
|
utils/data.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from torch_geometric.data import Data, Batch
|
5 |
+
# from torch_geometric.loader import DataLoader
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
|
8 |
+
# Attribute names, presumably meant for torch_geometric's ``follow_batch``
# argument. NOTE(review): not referenced in the visible part of this module;
# batch_from_data_list passes its own hard-coded list -- confirm which is
# intended.
FOLLOW_BATCH = ['protein_element', 'ligand_context_element', 'pos_real', 'pos_fake']
|
9 |
+
|
10 |
+
|
11 |
+
class ProteinLigandData(object):
    """Attribute container for one protein-ligand pair.

    Fields are stored as plain instance attributes and can also be read and
    written with item syntax (``data['protein_pos']``).

    The class derives from ``object``, so the original implementation could
    not work: forwarding ``**kwargs`` to ``object.__init__`` and assigning
    with ``instance[key] = value`` both raise TypeError. Keyword arguments
    are now stored as attributes and item access is supported, which makes
    ``from_protein_ligand_dicts`` usable without changing the base class.

    NOTE(review): if this class was meant to subclass ``torch_geometric``'s
    ``Data`` (imported by this module), restoring that base class is the
    alternative fix -- confirm against callers.
    """

    def __init__(self, *args, **kwargs):
        # Positional arguments were never supported; object.__init__ still
        # rejects them with TypeError.
        super().__init__(*args)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setitem__(self, key, value):
        """Store ``value`` as attribute ``key``."""
        setattr(self, key, value)

    def __getitem__(self, key):
        """Return the stored field ``key``; raises KeyError if it is absent."""
        return self.__dict__[key]

    def __contains__(self, key):
        """Return whether a field named ``key`` has been stored."""
        return key in self.__dict__

    @staticmethod
    def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs):
        """Build an instance from per-protein and per-ligand feature dicts.

        Protein keys are stored with a ``protein_`` prefix and ligand keys
        with a ``ligand_`` prefix, except ``'moltree'`` which keeps its name.
        Extra keyword arguments become attributes as-is.
        """
        instance = ProteinLigandData(**kwargs)

        if protein_dict is not None:
            for key, item in protein_dict.items():
                instance['protein_' + key] = item

        if ligand_dict is not None:
            for key, item in ligand_dict.items():
                if key == 'moltree':
                    instance['moltree'] = item
                else:
                    instance['ligand_' + key] = item

        # instance['ligand_nbh_list'] = {i.item():[j.item() for k, j in enumerate(instance.ligand_bond_index[1]) if instance.ligand_bond_index[0, k].item() == i] for i in instance.ligand_bond_index[0]}
        return instance
|
33 |
+
|
34 |
+
|
35 |
+
def batch_from_data_list(data_list):
    """Collate ``data_list`` into one torch_geometric ``Batch``.

    ``ligand_element`` and ``protein_element`` are passed as ``follow_batch``
    so that the batch carries per-field assignment vectors for them.
    """
    tracked_fields = ['ligand_element', 'protein_element']
    return Batch.from_data_list(data_list, follow_batch=tracked_fields)
|
37 |
+
|
38 |
+
|
39 |
+
def torchify_dict(data):
    """Return a new dict in which NumPy arrays are replaced by torch tensors.

    Values that are ``np.ndarray`` go through ``torch.from_numpy``; every
    other value is carried over unchanged. Key order is preserved.
    """
    return {
        name: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
        for name, value in data.items()
    }
|
47 |
+
|
48 |
+
|
49 |
+
def collate_mols(mol_dicts):
    """Collate a list of per-molecule dicts into a single batched dict.

    The result contains:

    * the listed per-molecule tensors concatenated along dim 0;
    * the torsion tensors (``xn_pos``, ``yn_pos``, ``ligand_torsion_xy_index``,
      ``y_pos``) stacked over the molecules that have them, or an empty
      tensor when none does;
    * ``<key>_batch`` vectors giving the molecule index of every row;
    * index pairs for distance-matrix prediction (``dm_*``) and structure
      refinement (``sr_*``), shifted into batch-wide numbering;
    * ``current_atoms`` / ``ligand_torsion_xy_index`` shifted into batch-wide
      numbering;
    * ``cand_mols``: a torch_geometric ``Batch`` of all candidate molecules,
      present only when at least one molecule has candidates.

    :param mol_dicts: list of dicts holding the keys read below.
    :return: dict of batched tensors.
    """
    data_batch = {}
    batch_size = len(mol_dicts)
    for key in ['protein_pos', 'protein_atom_feature', 'ligand_context_pos', 'ligand_context_feature_full',
                'ligand_frontier', 'num_atoms', 'next_wid', 'current_wid', 'current_atoms', 'cand_labels',
                'ligand_pos_torsion', 'ligand_feature_torsion', 'true_sin', 'true_cos', 'true_three_hop',
                'dihedral_mask', 'protein_contact', 'true_dm', 'alpha_carbon_indicator']:
        data_batch[key] = torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0)
    # unsqueeze dim0
    for key in ['xn_pos', 'yn_pos', 'ligand_torsion_xy_index', 'y_pos']:
        cat_list = [mol_dict[key].unsqueeze(0) for mol_dict in mol_dicts if len(mol_dict[key]) > 0]
        if len(cat_list) > 0:
            data_batch[key] = torch.cat(cat_list, dim=0)
        else:
            data_batch[key] = torch.tensor([])
    # follow batch
    for key in ['protein_element', 'ligand_context_element', 'current_atoms']:
        repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts])
        data_batch[key + '_batch'] = torch.repeat_interleave(torch.arange(batch_size), repeats)
    for key in ['ligand_element_torsion']:
        repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts if len(mol_dict[key]) > 0])
        if len(repeats) > 0:
            data_batch[key + '_batch'] = torch.repeat_interleave(torch.arange(len(repeats)), repeats)
        else:
            data_batch[key + '_batch'] = torch.tensor([])
    # distance matrix prediction
    p_idx, q_idx = torch.cartesian_prod(torch.arange(4), torch.arange(2)).chunk(2, dim=-1)
    p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
    # ``minlength`` keeps one count per molecule even when trailing molecules
    # contribute no atoms; without it the offset tables come out too short.
    protein_offsets = torch.cumsum(data_batch['protein_element_batch'].bincount(minlength=batch_size), dim=0)
    ligand_offsets = torch.cumsum(data_batch['ligand_context_element_batch'].bincount(minlength=batch_size), dim=0)
    protein_offsets, ligand_offsets = torch.cat([torch.tensor([0]), protein_offsets]), torch.cat([torch.tensor([0]), ligand_offsets])
    ligand_idx, protein_idx = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            protein_idx.append(mol_dict['dm_protein_idx'][p_idx] + protein_offsets[i])
            ligand_idx.append(mol_dict['dm_ligand_idx'][q_idx] + ligand_offsets[i])
    if len(ligand_idx) > 0:
        data_batch['dm_ligand_idx'], data_batch['dm_protein_idx'] = torch.cat(ligand_idx), torch.cat(protein_idx)

    # structure refinement (alpha carbon - ligand atom)
    sr_ligand_idx, sr_protein_idx = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            ligand_atom_index = torch.arange(len(mol_dict['ligand_context_pos']))
            p_idx, q_idx = torch.cartesian_prod(torch.arange(len(mol_dict['ligand_context_pos'])), torch.arange(len(mol_dict['protein_alpha_carbon_index']))).chunk(2, dim=-1)
            p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
            sr_ligand_idx.append(ligand_atom_index[p_idx] + ligand_offsets[i])
            sr_protein_idx.append(mol_dict['protein_alpha_carbon_index'][q_idx] + protein_offsets[i])
    if len(sr_ligand_idx) > 0:
        data_batch['sr_ligand_idx'], data_batch['sr_protein_idx'] = torch.cat(sr_ligand_idx).long(), torch.cat(sr_protein_idx).long()

    # structure refinement (ligand atom - ligand atom)
    sr_ligand_idx0, sr_ligand_idx1 = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            ligand_atom_index = torch.arange(len(mol_dict['ligand_context_pos']))
            p_idx, q_idx = torch.cartesian_prod(torch.arange(len(mol_dict['ligand_context_pos'])), torch.arange(len(mol_dict['ligand_context_pos']))).chunk(2, dim=-1)
            p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
            sr_ligand_idx0.append(ligand_atom_index[p_idx] + ligand_offsets[i])
            sr_ligand_idx1.append(ligand_atom_index[q_idx] + ligand_offsets[i])
    # Guard on the lists that are actually concatenated here (the previous
    # check looked at ``ligand_idx`` from the distance-matrix stage).
    if len(sr_ligand_idx0) > 0:
        data_batch['sr_ligand_idx0'], data_batch['sr_ligand_idx1'] = torch.cat(sr_ligand_idx0).long(), torch.cat(sr_ligand_idx1).long()
    # index
    if len(data_batch['y_pos']) > 0:
        repeats = torch.tensor([len(mol_dict['ligand_element_torsion']) for mol_dict in mol_dicts if len(mol_dict['ligand_element_torsion']) > 0])
        offsets = torch.cat([torch.tensor([0]), torch.cumsum(repeats, dim=0)])[:-1]
        data_batch['ligand_torsion_xy_index'] += offsets.unsqueeze(1)

    # assumes each molecule contributes exactly one ``num_atoms`` entry -- TODO confirm
    offsets1 = torch.cat([torch.tensor([0]), torch.cumsum(data_batch['num_atoms'], dim=0)])[:-1]
    data_batch['current_atoms'] += torch.repeat_interleave(
        offsets1, data_batch['current_atoms_batch'].bincount(minlength=batch_size))
    # cand mols: torch geometric Data
    cand_mol_list = []
    for data in mol_dicts:
        if len(data['cand_labels']) > 0:
            cand_mol_list.extend(data['cand_mols'])
    if len(cand_mol_list) > 0:
        data_batch['cand_mols'] = Batch.from_data_list(cand_mol_list)
    return data_batch
|
127 |
+
|
utils/datasets/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Subset
|
3 |
+
from .pl import PocketLigandPairDataset
|
4 |
+
import random
|
5 |
+
|
6 |
+
|
7 |
+
def get_dataset(config, *args, **kwargs):
    """Instantiate the dataset described by ``config``.

    ``config.name`` selects the dataset class (only ``'pl'`` is known) and
    ``config.path`` is its root; extra arguments go to the constructor.

    :return: the dataset alone, or ``(dataset, subsets)`` when ``config``
        contains ``'split'``. ``subsets`` maps each split name loaded from
        ``config.split`` to a ``Subset`` restricted to the names present in
        ``dataset.name2id``.
    :raises NotImplementedError: for an unknown ``config.name``.
    """
    name, root = config.name, config.path
    if name != 'pl':
        raise NotImplementedError('Unknown dataset: %s' % name)
    dataset = PocketLigandPairDataset(root, *args, **kwargs)

    if 'split' not in config:
        return dataset

    names_by_split = torch.load(config.split)
    subsets = {}
    for split_name, names in names_by_split.items():
        # Names missing from the dataset index are silently dropped.
        indices = [dataset.name2id[n] for n in names if n in dataset.name2id]
        subsets[split_name] = Subset(dataset, indices=indices)
    return dataset, subsets
|
utils/datasets/pl.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import lmdb
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from ..protein_ligand import PDBProtein, parse_sdf_file
|
10 |
+
from ..data import ProteinLigandData, torchify_dict
|
11 |
+
from ..mol_tree import MolTree
|
12 |
+
|
13 |
+
|
14 |
+
def reset_moltree_root(moltree, ligand_pos, protein_pos):
    """Move the tree node closest to the protein into the root slot.

    Squared ligand-protein distances are computed for all atom pairs. The
    node whose clique contains the ligand atom with the smallest such
    distance is swapped with ``moltree.nodes[0]`` (in place).

    :param moltree: object with a ``nodes`` list; each node has ``clique``
        (ligand atom indices).
    :param ligand_pos: array of ligand coordinates, one row per atom.
    :param protein_pos: array of protein coordinates, one row per atom.
    :return: ``(moltree, contact_protein, contact_idx)`` where
        ``contact_protein`` is a bool tensor marking protein atoms whose
        squared distance to some ligand atom is below ``4 ** 2`` and
        ``contact_idx`` is a one-element tensor with the protein atom nearest
        to the new root clique.
    """
    ligand_sq = np.square(ligand_pos).sum(axis=1, keepdims=True)
    protein_sq = np.square(protein_pos).sum(axis=1, keepdims=True)
    # |l - p|^2 = -2 l.p + |l|^2 + |p|^2, evaluated for every pair at once.
    sq_dist = -2 * np.dot(ligand_pos, protein_pos.T) + ligand_sq + protein_sq.T

    nearest_per_atom = sq_dist.min(axis=1)
    nearest_per_node = [np.min(nearest_per_atom[node.clique]) for node in moltree.nodes]
    root = np.argmin(nearest_per_node)
    if root > 0:
        nodes = moltree.nodes
        nodes[0], nodes[root] = nodes[root], nodes[0]

    root_rows = sq_dist[moltree.nodes[0].clique]
    contact_idx = np.argmin(root_rows.min(axis=0))
    contact_protein = torch.tensor(sq_dist.min(axis=0) < 4 ** 2)

    return moltree, contact_protein, torch.tensor([contact_idx])
|
29 |
+
|
30 |
+
|
31 |
+
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None):
    """Merge protein and ligand field dicts into one flat dict.

    Protein keys get a ``protein_`` prefix, ligand keys a ``ligand_`` prefix;
    the ligand key ``'moltree'`` is the one exception and is kept unprefixed.
    Either argument may be ``None``, in which case it contributes nothing.
    """
    merged = {}

    if protein_dict is not None:
        merged.update(('protein_' + key, item) for key, item in protein_dict.items())

    if ligand_dict is not None:
        for key, item in ligand_dict.items():
            target = 'moltree' if key == 'moltree' else 'ligand_' + key
            merged[target] = item

    return merged
|
45 |
+
|
46 |
+
|
47 |
+
class PocketLigandPairDataset(Dataset):
    """Pocket-ligand pairs backed by an LMDB file of pickled dicts.

    For a raw directory ``<root>`` the processed database lives next to it as
    ``<root>_processed.lmdb`` and the name lookup as ``<root>_name2id.pt``.
    Both are built on first use when the database file does not exist yet.
    """

    def __init__(self, raw_path, transform=None):
        """
        :param raw_path: directory holding ``index.pt`` and the per-complex
            sub-directories.
        :param transform: optional callable applied to every item returned
            by ``__getitem__``.
        """
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index.pt')
        self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                           os.path.basename(self.raw_path) + '_processed.lmdb')
        self.name2id_path = os.path.join(os.path.dirname(self.raw_path),
                                         os.path.basename(self.raw_path) + '_name2id.pt')
        self.transform = transform
        # The connection is opened lazily by ``__len__`` / ``__getitem__``.
        self.db = None

        self.keys = None

        if not os.path.exists(self.processed_path):
            self._process()
            self._precompute_name2id()

        self.name2id = torch.load(self.name2id_path)

    def _connect_db(self):
        """
            Establish read-only database connection
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        """Close the database connection if one is open and forget its keys."""
        if self.db is not None:
            self.db.close()
        self.db = None
        self.keys = None

    def _process(self):
        """Parse every indexed complex and store it in a new LMDB file.

        Complexes that fail to parse, or whose tree contains a cluster that
        is missing from ``./vocab.txt``, are counted and reported as skipped.
        """
        db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=True,
            subdir=False,
            readonly=False,  # Writable
        )
        try:
            index = torch.load(self.index_path)
            # A vocabulary entry is the text before the first ':' of a line.
            # A set gives constant-time membership tests in the loop below.
            with open('./vocab.txt') as vocab_file:
                vocab = {line.partition(':')[0] for line in vocab_file}

            num_skipped = 0
            with db.begin(write=True, buffers=True) as txn:
                for i, pdbid in enumerate(tqdm(index)):
                    if pdbid is None:
                        continue
                    # Bound up front so the skip message never hits an
                    # undefined (or stale) name.
                    ligand_fn = None
                    try:
                        ligand_fn = os.path.join(pdbid, pdbid + '_ligand.sdf')
                        pocket_fn = os.path.join(pdbid, pdbid + '_pocket.pdb')
                        pocket_dict = PDBProtein(os.path.join(self.raw_path, pocket_fn)).to_dict_atom()
                        ligand_dict = parse_sdf_file(os.path.join(self.raw_path, ligand_fn))
                        ligand_dict['moltree'], pocket_dict['contact'], pocket_dict['contact_idx'] = reset_moltree_root(
                            ligand_dict['moltree'],
                            ligand_dict['pos'],
                            pocket_dict['pos'])
                        data = from_protein_ligand_dicts(
                            protein_dict=torchify_dict(pocket_dict),
                            ligand_dict=torchify_dict(ligand_dict),
                        )
                        data['protein_filename'] = pocket_fn
                        data['ligand_filename'] = ligand_fn
                        data['pdbid'] = pdbid
                        txn.put(
                            key=str(i).encode(),
                            value=pickle.dumps(data)
                        )
                        # NOTE(review): this vocabulary check runs after the
                        # record has been written, so an out-of-vocabulary
                        # complex is reported as skipped but stays in the
                        # database. Left unchanged to keep the processed
                        # data identical -- confirm the intent.
                        for c in ligand_dict['moltree'].nodes:
                            smile_cluster = c.smiles
                            assert smile_cluster in vocab
                    except Exception:
                        # Best-effort processing: report and move on, but let
                        # KeyboardInterrupt / SystemExit propagate.
                        num_skipped += 1
                        print('Skipping (%d) %s' % (num_skipped, ligand_fn if ligand_fn is not None else pdbid,))
                        continue
        finally:
            db.close()

    def _precompute_name2id(self):
        """Build and save the ``pdbid -> dataset index`` lookup table."""
        name2id = {}
        for i in tqdm(range(len(self)), 'Indexing'):
            try:
                data = self[i]
            except AssertionError as e:
                print(i, e)
                continue
            name = data['pdbid']
            name2id[name] = i
        torch.save(name2id, self.name2id_path)

    def __len__(self):
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        """Load item ``idx``, tag it with ``'id'`` and apply the transform."""
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        # The context manager ends the read transaction instead of leaving it
        # open until garbage collection.
        with self.db.begin() as txn:
            raw = txn.get(key)
        # NOTE: pickle must only be used on trusted database files.
        data = pickle.loads(raw)
        data['id'] = idx
        assert data['protein_pos'].size(0) > 0
        if self.transform is not None:
            data = self.transform(data)
        return data
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == '__main__':
    # Script entry point: constructing the dataset triggers the processing
    # step when the processed database file does not exist yet.
    import argparse

    parser = argparse.ArgumentParser()
    # Positional argument passed to the dataset constructor as ``raw_path``.
    parser.add_argument('path', type=str)
    args = parser.parse_args()

    PocketLigandPairDataset(args.path)
|
utils/dihedral_utils.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch_geometric as tg
|
3 |
+
from torch_geometric.utils import degree
|
4 |
+
import networkx as nx
|
5 |
+
|
6 |
+
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Lookup table indexed by a neighbor count (0-4): row k marks which of the six
# pairs in ``angle_combos`` are kept. Row 2 keeps the first pair, row 3 the
# first three and row 4 all six.
angle_mask_ref = torch.LongTensor([[0, 0, 0, 0, 0, 0],
                                   [0, 0, 0, 0, 0, 0],
                                   [1, 0, 0, 0, 0, 0],
                                   [1, 1, 1, 0, 0, 0],
                                   [1, 1, 1, 1, 1, 1]]).to(device)

# The six unordered pairs of slots 0-3, in the column order used by
# ``angle_mask_ref``.
angle_combos = torch.LongTensor([[0, 1],
                                 [0, 2],
                                 [1, 2],
                                 [0, 3],
                                 [1, 3],
                                 [2, 3]]).to(device)
|
19 |
+
|
20 |
+
|
21 |
+
def get_neighbor_ids(data):
    """
    Merge the per-molecule neighbor dicts in ``data.neighbors`` into one dict
    that uses batch-wide atom indices.

    The first dict is removed from ``data.neighbors`` (the list is mutated)
    and returned after being updated in place. Keys and values of each
    remaining dict are shifted by a running sum of the per-molecule atom
    counts taken from ``data.batch.bincount()``.
    """
    merged = data.neighbors.pop(0)
    atoms_per_mol = data.batch.bincount()
    offset = 0

    for mol_idx, mol_neighbors in enumerate(data.neighbors):
        offset += atoms_per_mol[mol_idx].item()
        merged.update({atom + offset: nbrs + offset for atom, nbrs in mol_neighbors.items()})
    return merged
|
41 |
+
|
42 |
+
|
43 |
+
def get_neighbor_bonds(edge_index, bond_type):
    """
    Map each source atom that has more than one outgoing edge to the bond
    types of those edges.

    ``bond_type`` is cut into consecutive chunks sized by how often each
    unique source index occurs, so the edges are presumably expected to be
    grouped by ascending source index -- TODO confirm with callers.
    """
    start, _ = edge_index
    atoms, counts = torch.unique(start, return_counts=True)
    chunks = torch.split_with_sizes(bond_type, tuple(counts))

    bonds_by_atom = {}
    for atom, chunk in zip(atoms, chunks):
        if len(chunk) > 1:
            bonds_by_atom[atom.item()] = chunk
    return bonds_by_atom
|
52 |
+
|
53 |
+
|
54 |
+
def get_leaf_hydrogens(neighbors, x):
    """
    For every atom in ``neighbors`` return a boolean mask over its neighbors
    that is true where column 0 of the neighbor's row in ``x`` equals 1.

    Column 0 presumably holds the atomic number, so the mask flags hydrogen
    neighbors -- TODO confirm against the featurizer.

    :param neighbors: dict mapping atom index to an index tensor of neighbors
    :param x: 2-D atom feature tensor
    """
    is_hydrogen = x[:, 0] == 1
    return {atom: is_hydrogen[nbr_idx] for atom, nbr_idx in neighbors.items()}
|
71 |
+
|
72 |
+
|
73 |
+
def get_dihedral_pairs(edge_index, data):
    """
    Given edge indices, return pairs of indices that we must calculate dihedrals for

    :param edge_index: tensor with two rows (edge start and end atom indices)
    :param data: graph object converted with ``tg.utils.to_networkx`` to find cycles
    :return: tensor with two rows holding the selected atom pairs, on ``device``

    NOTE(review): ``get_current_cycle_indices`` is called below but is neither
    defined nor imported in this file, so any pair whose end atom lies on a
    cycle raises ``NameError``.
    NOTE(review): ``torch.stack`` fails on an empty list, so a graph without
    any qualifying edge raises at the final line.
    """
    start, end = edge_index
    degrees = degree(end)
    # Keep edges whose two endpoints both have degree greater than one.
    dihedral_pairs_true = torch.nonzero(torch.logical_and(degrees[start] > 1, degrees[end] > 1))
    dihedral_pairs = edge_index[:, dihedral_pairs_true].squeeze(-1)

    # # first method which removes one (pseudo) random edge from a cycle
    # Select the columns where sorting the (start, end) pair leaves row 0 in
    # place, i.e. one direction of each edge.
    dihedral_idxs = torch.nonzero(dihedral_pairs.sort(dim=0).indices[0, :] == 0).squeeze().detach().cpu().numpy()

    # prioritize rings for assigning dihedrals
    dihedral_pairs = dihedral_pairs.t()[dihedral_idxs]
    G = nx.to_undirected(tg.utils.to_networkx(data))
    cycles = nx.cycle_basis(G)
    # ``keep`` collects the output pairs; ``sorted_keep`` holds the same pairs
    # sorted, for the duplicate check below.
    keep, sorted_keep = [], []

    # A single selected edge comes back 1-D; restore the pair dimension.
    if len(dihedral_pairs.shape) == 1:
        dihedral_pairs = dihedral_pairs.unsqueeze(0)

    for pair in dihedral_pairs:
        x, y = pair

        # Skip pairs already recorded (in either orientation).
        if sorted(pair) in sorted_keep:
            continue

        y_cycle_check = [y in cycle for cycle in cycles]
        x_cycle_check = [x in cycle for cycle in cycles]

        if any(x_cycle_check) and any(y_cycle_check):  # both in new cycle
            cycle_indices = get_current_cycle_indices(cycles, x_cycle_check, x)
            keep.extend(cycle_indices)

            sorted_keep.extend([sorted(c) for c in cycle_indices])
            continue

        # Only y is on a cycle: keep this pair plus the pairs of y's cycle.
        if any(y_cycle_check):
            cycle_indices = get_current_cycle_indices(cycles, y_cycle_check, y)
            keep.append(pair)
            keep.extend(cycle_indices)

            sorted_keep.append(sorted(pair))
            sorted_keep.extend([sorted(c) for c in cycle_indices])
            continue

        keep.append(pair)

    keep = [t.to(device) for t in keep]
    return torch.stack(keep).t()
|
123 |
+
|
124 |
+
|
125 |
+
def batch_distance_metrics_from_coords(coords, mask):
    """
    Given coordinates of neighboring atoms, compute bond
    distances and 2-hop distances in local neighborhood

    :param coords: 4-D or 5-D coordinate tensor; the last dim is reduced as
        the spatial dim
    :param mask: neighbor mask; its outer product over dims 1 and 2 masks the
        pairwise distance matrix
    :return: ``(one_hop_ds, two_dop_d_mat)`` -- norms of ``coords`` relative
        to the origin, and masked pairwise distances between neighbors

    NOTE(review): for any other rank neither branch runs and the ``return``
    raises ``UnboundLocalError``.
    """
    d_mat_mask = mask.unsqueeze(1) * mask.unsqueeze(2)

    if coords.dim() == 4:
        # The 1e-10 shift keeps sqrt away from an exact zero on the diagonal.
        two_dop_d_mat = torch.square(coords.unsqueeze(1) - coords.unsqueeze(2) + 1e-10).sum(dim=-1).sqrt() * d_mat_mask.unsqueeze(-1)
        # Distance of every neighbor from the origin (a zero tensor).
        one_hop_ds = torch.linalg.norm(torch.zeros_like(coords[0]).unsqueeze(0) - coords, dim=-1)
    elif coords.dim() == 5:
        # Same computation with the neighbor dims shifted right by one.
        two_dop_d_mat = torch.square(coords.unsqueeze(2) - coords.unsqueeze(3) + 1e-10).sum(dim=-1).sqrt() * d_mat_mask.unsqueeze(-1).unsqueeze(1)
        one_hop_ds = torch.linalg.norm(torch.zeros_like(coords[0]).unsqueeze(0) - coords, dim=-1)

    return one_hop_ds, two_dop_d_mat
|
140 |
+
|
141 |
+
|
142 |
+
def batch_angle_between_vectors(a, b):
    """
    Return the cosine of the angle between corresponding vectors of ``a`` and
    ``b``, taken along the last dimension.
    """
    dot = torch.sum(a * b, dim=-1)
    norm_product = torch.linalg.norm(a, dim=-1) * torch.linalg.norm(b, dim=-1)

    # The small constant keeps the division finite for zero-length vectors.
    return dot / (norm_product + 1e-10)
|
157 |
+
|
158 |
+
|
159 |
+
def batch_angles_from_coords(coords, mask):
    """
    Given coordinates, compute all local neighborhood angles

    The values returned are cosines (the output of
    ``batch_angle_between_vectors``), one per neighbor pair in
    ``angle_combos``, zeroed for the pairs that ``angle_mask_ref`` disables
    for the neighbor count given by ``mask.sum(dim=1)``.

    NOTE(review): only 4-D and 5-D ``coords`` are handled; any other rank
    raises ``UnboundLocalError`` at the ``return``.
    """
    if coords.dim() == 4:
        # Gather both members of every neighbor pair along dim 1.
        all_possible_combos = coords[:, angle_combos]
        v_a, v_b = all_possible_combos.split(1, dim=2)  # does one of these need to be negative?
        angle_mask = angle_mask_ref[mask.sum(dim=1).long()]
        angles = batch_angle_between_vectors(v_a.squeeze(2), v_b.squeeze(2)) * angle_mask.unsqueeze(-1)
    elif coords.dim() == 5:
        # Same as above with the neighbor dim at position 2.
        all_possible_combos = coords[:, :, angle_combos]
        v_a, v_b = all_possible_combos.split(1, dim=3)  # does one of these need to be negative?
        angle_mask = angle_mask_ref[mask.sum(dim=1).long()]
        angles = batch_angle_between_vectors(v_a.squeeze(3), v_b.squeeze(3)) * angle_mask.unsqueeze(-1).unsqueeze(-1)

    return angles
|
175 |
+
|
176 |
+
|
177 |
+
def batch_local_stats_from_coords(coords, mask):
    """
    Compute the local-neighborhood statistics for neighbor coordinates that
    are expressed relative to a central atom at the origin.

    :return: ``(one_hop_ds, two_dop_d_mat, angles)`` -- the two results of
        ``batch_distance_metrics_from_coords`` followed by the result of
        ``batch_angles_from_coords``
    """
    bond_lengths, pairwise_dists = batch_distance_metrics_from_coords(coords, mask)
    return bond_lengths, pairwise_dists, batch_angles_from_coords(coords, mask)
|
186 |
+
|
187 |
+
|
188 |
+
def batch_dihedrals(p0, p1, p2, p3, angle=False):
    """
    Compute the dihedral defined by the point sequence ``p0 -> p1 -> p2 -> p3``
    (coordinates in the last dimension).

    :param angle: when true return the angle from ``atan2``; otherwise return
        the pair ``(sin, cos)`` normalised by the product of the two plane
        normal lengths
    """
    b1, b2, b3 = p1 - p0, p2 - p1, p3 - p2

    normal_12 = torch.cross(b1, b2, dim=-1)
    normal_23 = torch.cross(b2, b3, dim=-1)

    sin_term = torch.linalg.norm(b2, dim=-1) * torch.sum(b1 * normal_23, dim=-1)
    cos_term = torch.sum(normal_12 * normal_23, dim=-1)

    if angle:
        return torch.atan2(sin_term, cos_term + 1e-10)

    # Small constant guards against collinear points (zero-length normals).
    scale = torch.linalg.norm(normal_12, dim=-1) * torch.linalg.norm(normal_23, dim=-1) + 1e-10
    return sin_term / scale, cos_term / scale
|
203 |
+
|
204 |
+
|
205 |
+
def batch_vector_angles(xn, x, y, yn):
    """
    Return the cosine between the vectors ``xn - x`` and ``yn - y`` for every
    row, with the result reshaped to nine values per row.

    All inputs are flattened to ``(-1, 3)`` first, so the number of 3-vectors
    must be a multiple of nine.
    """
    flat_xn = xn.view(-1, 3)
    flat_x = x.view(-1, 3)
    flat_y = y.view(-1, 3)
    flat_yn = yn.view(-1, 3)

    x_side = flat_xn - flat_x
    y_side = flat_yn - flat_y

    # Row-wise dot product expressed as a batched (1x3) @ (3x1) product.
    dots = torch.bmm(x_side.view(-1, 1, 3), y_side.view(-1, 3, 1)).squeeze(-1).squeeze(-1)
    norms = torch.linalg.norm(x_side, dim=-1) * torch.linalg.norm(y_side, dim=-1) + 1e-10

    cosines = dots / norms
    return cosines.view(-1, 9)
|
218 |
+
|
219 |
+
|
220 |
+
def von_Mises_loss(a, b, a_sin=None, b_sin=None):
    """
    Combine two angles given by their cosines into ``cos_a * cos_b +
    sin_a * sin_b``.

    :param a: cos of first angle
    :param b: cos of second angle
    :param a_sin: sin of first angle; when it is not a tensor both sines are
        derived as ``sqrt(1 - cos**2 + 1e-5)`` and ``b_sin`` is ignored
    :param b_sin: sin of second angle
    :return: the combined value, same shape as the inputs
    """
    if not torch.is_tensor(a_sin):
        a_sin = torch.sqrt(1 - a ** 2 + 1e-5)
        b_sin = torch.sqrt(1 - b ** 2 + 1e-5)
    return a * b + a_sin * b_sin
|
231 |
+
|
232 |
+
|
233 |
+
def rotation_matrix(neighbor_coords, neighbor_mask, neighbor_map, mu=None):
    """
    Given predicted neighbor coordinates from model, return rotation matrix

    :param neighbor_coords: neighbor coordinates for each edge as defined by dihedral_pairs
        (n_dihedral_pairs, 4, n_generated_confs, 3)
    :param neighbor_mask: mask describing which atoms are present (n_dihedral_pairs, 4)
    :param neighbor_map: mask describing which neighbor corresponds to the other central dihedral atom
        (n_dihedral_pairs, 4) each entry in neighbor_map should have one TRUE entry with the rest as FALSE
    :param mu: optional precomputed reference vector; when it is not a tensor it is
        derived below from the neighbors not selected by neighbor_map
    :return: rotation matrix (n_dihedral_pairs, n_model_confs, 3, 3)

    NOTE(review): the trailing shape comments ending in ``10`` look inherited;
    given the documented input shape the last dim is presumably 3 -- TODO confirm.
    """

    if not torch.is_tensor(mu):
        # mu = neighbor_coords.sum(dim=1, keepdim=True) / (neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1).unsqueeze(-1) + 1e-10)
        # Sum the three neighbors that are NOT the other central atom ...
        mu_num = neighbor_coords[~neighbor_map.bool()].view(neighbor_coords.size(0), 3, neighbor_coords.size(2), -1).sum(dim=1)
        # ... and divide by the number of present neighbors minus one.
        mu_den = (neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1) - 1 + 1e-10)
        mu = mu_num / mu_den  # (n_dihedral_pairs, n_model_confs, 10)
        mu = mu.squeeze(1)  # (n_dihedral_pairs, n_model_confs, 10)

    # Coordinates of the neighbor flagged by neighbor_map.
    p_Y = neighbor_coords[neighbor_map.bool(), :]
    # First row: unit vector along p_Y.
    h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10)  # (n_dihedral_pairs, n_model_confs, 10)

    # Third row: unit normal of the plane spanned by p_Y and mu.
    h3_1 = torch.cross(p_Y, mu, dim=-1)
    h3 = h3_1 / (torch.linalg.norm(h3_1, dim=-1, keepdim=True) + 1e-10)  # (n_dihedral_pairs, n_model_confs, 10)

    # Second row: negated cross product of the first and third rows.
    h2 = -torch.cross(h1, h3, dim=-1)  # (n_dihedral_pairs, n_model_confs, 10)

    # Stack h1, h2, h3 as the rows of the matrix.
    H = torch.cat([h1.unsqueeze(-2),
                   h2.unsqueeze(-2),
                   h3.unsqueeze(-2)], dim=-2)

    return H
|
265 |
+
|
266 |
+
|
267 |
+
def rotation_matrix_v2(neighbor_coords):
    """
    Given predicted neighbor coordinates from model, return rotation matrix
    :param neighbor_coords: y or x coordinates for the x or y center node
        (n_dihedral_pairs, 3)
    :return: rotation matrix (n_dihedral_pairs, 3, 3)

    The second axis is built from a random vector (``torch.rand_like``), so
    the result differs between calls unless the torch RNG is seeded.
    NOTE(review): the trailing shape comments ending in ``10`` look inherited
    and do not match the shapes documented above -- TODO confirm.
    """

    p_Y = neighbor_coords

    # Random vector with its component along p_Y removed, then normalised.
    eta_1 = torch.rand_like(p_Y)
    eta_2 = eta_1 - torch.sum(eta_1 * p_Y, dim=-1, keepdim=True) / (torch.linalg.norm(p_Y, dim=-1, keepdim=True)**2 + 1e-10) * p_Y
    # NOTE(review): unlike the other normalisations here this one has no
    # epsilon, so a zero-length eta_2 would produce NaN.
    eta = eta_2 / torch.linalg.norm(eta_2, dim=-1, keepdim=True)

    # First row: unit vector along p_Y.
    h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10)  # (n_dihedral_pairs, n_model_confs, 10)

    # Third row: unit normal of the plane spanned by p_Y and eta.
    h3_1 = torch.cross(p_Y, eta, dim=-1)
    h3 = h3_1 / (torch.linalg.norm(h3_1, dim=-1, keepdim=True) + 1e-10)  # (n_dihedral_pairs, n_model_confs, 10)

    # Second row: negated cross product of the first and third rows.
    h2 = -torch.cross(h1, h3, dim=-1)  # (n_dihedral_pairs, n_model_confs, 10)

    # Stack h1, h2, h3 as the rows of the matrix.
    H = torch.cat([h1.unsqueeze(-2),
                   h2.unsqueeze(-2),
                   h3.unsqueeze(-2)], dim=-2)

    return H
|
293 |
+
|
294 |
+
|
295 |
+
def signed_volume(local_coords):
    """
    Return the sign of the volume spanned by the first three neighbors
    relative to the fourth one.

    :param local_coords: (n_tetrahedral_chiral_centers, 4, n_generated_confs, 3)
    :return: sign (-1, 0 or 1) of the scalar triple product for each
        tetrahedral center (n_tetrahedral_chiral_centers, n_generated_confs)
    """
    apex = local_coords[:, 3]
    edge_a = local_coords[:, 0] - apex
    edge_b = local_coords[:, 1] - apex
    edge_c = local_coords[:, 2] - apex
    # Scalar triple product a . (b x c).
    triple = torch.sum(edge_a * torch.cross(edge_b, edge_c, dim=-1), dim=-1)
    return torch.sign(triple)
|
308 |
+
|
309 |
+
|
310 |
+
def rotation_matrix_inf(neighbor_coords, neighbor_mask, neighbor_map):
    """
    Given predicted neighbor coordinates from model, return rotation matrix

    :param neighbor_coords: neighbor coordinates for each edge as defined by dihedral_pairs (4, n_model_confs, 3)
    :param neighbor_mask: mask describing which atoms are present (4)
    :param neighbor_map: mask describing which neighbor corresponds to the other central dihedral atom (4)
        each entry in neighbor_map should have one TRUE entry with the rest as FALSE
    :return: rotation matrix (3, 3)
    """

    # Sum of all neighbor coordinates divided by the number of present
    # neighbors (unlike ``rotation_matrix``, the mapped neighbor is included).
    mu = neighbor_coords.sum(dim=0, keepdim=True) / (neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1) + 1e-10)
    mu = mu.squeeze(0)
    # Coordinates of the neighbor flagged by neighbor_map.
    p_Y = neighbor_coords[neighbor_map.bool(), :].squeeze(0)

    # First row: unit vector along p_Y.
    h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10)

    # Third row: unit normal of the plane spanned by p_Y and mu.
    h3_1 = torch.cross(p_Y, mu, dim=-1)
    h3 = h3_1 / (torch.linalg.norm(h3_1, dim=-1, keepdim=True) + 1e-10)

    # Second row: negated cross product of the first and third rows.
    h2 = -torch.cross(h1, h3, dim=-1)

    # Stack h1, h2, h3 as the rows of the matrix.
    H = torch.cat([h1.unsqueeze(-2),
                   h2.unsqueeze(-2),
                   h3.unsqueeze(-2)], dim=-2)

    return H
|
337 |
+
|
338 |
+
|
339 |
+
def build_alpha_rotation_inf(alpha, n_model_confs):
    """
    Build ``n_model_confs`` 3x3 matrices that rotate by ``alpha`` about the
    first axis: ``[[1, 0, 0], [0, cos, -sin], [0, sin, cos]]``.

    :param alpha: angle tensor, assigned across the leading (conformer) dim
    :param n_model_confs: number of matrices to build
    :return: float tensor of shape (n_model_confs, 3, 3)
    """
    cos_a, sin_a = torch.cos(alpha), torch.sin(alpha)

    template = torch.FloatTensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]])
    H_alpha = template.repeat(n_model_confs, 1, 1)
    H_alpha[:, 1, 1] = cos_a
    H_alpha[:, 1, 2] = -sin_a
    H_alpha[:, 2, 1] = sin_a
    H_alpha[:, 2, 2] = cos_a

    return H_alpha
|
348 |
+
|
349 |
+
|
350 |
+
def random_rotation_matrix(dim):
    """
    Build rotation matrices from random yaw, pitch and roll angles.

    :param dim: size passed to ``torch.rand`` for each of the three angles
    :return: tensor whose last two dims are the 3x3 matrix entries

    NOTE(review): ``torch.rand`` draws from [0, 1), so the angles (used
    directly as radians by ``cos``/``sin``) cover less than a full turn --
    confirm this is intended.
    """
    yaw = torch.rand(dim)
    pitch = torch.rand(dim)
    roll = torch.rand(dim)

    # Each inner stack is one matrix row; the outer stack assembles the rows.
    R = torch.stack([torch.stack([torch.cos(yaw) * torch.cos(pitch),
                                  torch.cos(yaw) * torch.sin(pitch) * torch.sin(roll) - torch.sin(yaw) * torch.cos(
                                      roll),
                                  torch.cos(yaw) * torch.sin(pitch) * torch.cos(roll) + torch.sin(yaw) * torch.sin(
                                      roll)], dim=-1),
                     torch.stack([torch.sin(yaw) * torch.cos(pitch),
                                  torch.sin(yaw) * torch.sin(pitch) * torch.sin(roll) + torch.cos(yaw) * torch.cos(
                                      roll),
                                  torch.sin(yaw) * torch.sin(pitch) * torch.cos(roll) - torch.cos(yaw) * torch.sin(
                                      roll)], dim=-1),
                     torch.stack([-torch.sin(pitch),
                                  torch.cos(pitch) * torch.sin(roll),
                                  torch.cos(pitch) * torch.cos(roll)], dim=-1)], dim=-2)

    return R
|
370 |
+
|
371 |
+
|
372 |
+
def length_to_mask(length, max_len=None, dtype=None):
    """Turn a 1-D tensor of lengths (B) into a B x max_len mask.

    Position ``j`` of row ``i`` is true when ``j < length[i]``.
    If max_len is None (or 0), the maximum of ``length`` is used.
    If dtype is given, the boolean mask is converted to that dtype.
    """
    assert length.dim() == 1, 'Length shape should be 1 dimensional.'
    if not max_len:
        max_len = length.max().item()
    positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
    mask = positions.expand(len(length), max_len) < length.unsqueeze(1)
    if dtype is None:
        return mask
    return torch.as_tensor(mask, dtype=dtype, device=length.device)
|
utils/docking.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import random
|
4 |
+
import string
|
5 |
+
from easydict import EasyDict
|
6 |
+
from rdkit import Chem
|
7 |
+
from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
|
8 |
+
|
9 |
+
from .reconstruct import reconstruct_from_generated
|
10 |
+
|
11 |
+
|
12 |
+
def get_random_id(length=30):
    """Return a random identifier of ``length`` lowercase ASCII letters.

    Characters are drawn one at a time with ``random.choice``, so the result
    is reproducible after ``random.seed``.
    """
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(length)]
    return ''.join(picked)
|
15 |
+
|
16 |
+
|
17 |
+
def load_pdb(path):
    """Read the text file at ``path`` and return its whole content as one string."""
    with open(path, 'r') as handle:
        contents = handle.read()
    return contents
|
20 |
+
|
21 |
+
|
22 |
+
def parse_qvina_outputs(docked_sdf_path):
    """Collect the docked poses stored in an SDF file.

    For every readable molecule the first line of its ``REMARK`` property is
    split on whitespace; the tokens from the third one on are read as
    affinity, RMSD lower bound and RMSD upper bound.

    Args:
        docked_sdf_path: path of the SDF file to read.

    Returns:
        List of ``EasyDict`` with keys ``rdmol``, ``mode_id``, ``affinity``,
        ``rmsd_lb`` and ``rmsd_ub``. ``mode_id`` is the molecule's position
        in the file, so it keeps counting across unreadable entries.
    """
    poses = []
    supplier = Chem.SDMolSupplier(docked_sdf_path)
    for mode_id, rdmol in enumerate(supplier):
        if mol_is_unreadable := rdmol is None:
            continue
        remark_first_line = rdmol.GetProp('REMARK').splitlines()[0]
        scores = remark_first_line.split()[2:]
        pose = EasyDict({
            'rdmol': rdmol,
            'mode_id': mode_id,
            'affinity': float(scores[0]),
            'rmsd_lb': float(scores[1]),
            'rmsd_ub': float(scores[2]),
        })
        poses.append(pose)

    return poses
|
39 |
+
|
40 |
+
class BaseDockingTask(object):
    """Common base for docking jobs: stores the receptor and ligand inputs.

    Subclasses implement ``run`` (start the job) and ``get_results``
    (fetch its outcome); both raise ``NotImplementedError`` here.
    """

    def __init__(self, pdb_block, ligand_rdmol):
        super().__init__()
        self.pdb_block, self.ligand_rdmol = pdb_block, ligand_rdmol

    def run(self):
        """Start the docking job. Must be overridden."""
        raise NotImplementedError()

    def get_results(self):
        """Return the docking results. Must be overridden."""
        raise NotImplementedError()
|
52 |
+
|
53 |
+
|
54 |
+
class QVinaDockingTask(BaseDockingTask):
    """Docking job that drives ``qvina2.1`` through a bash subprocess.

    On construction the receptor PDB text and the ligand (hydrogens added,
    optionally UFF-optimised) are written into ``tmp_dir`` under a random
    task id. ``run`` starts the shell pipeline without waiting for it,
    ``get_results`` polls it, and ``run_sync`` blocks until it has finished.
    """

    @staticmethod
    def _read_receptor_block(data, protein_root):
        """Return the PDB text of the receptor that belongs to ``data.ligand_filename``."""
        # The receptor lives in the ligand's directory and is named after the
        # first 10 characters of the ligand file name.
        protein_fn = os.path.join(
            os.path.dirname(data.ligand_filename),
            os.path.basename(data.ligand_filename)[:10] + '.pdb'
        )
        protein_path = os.path.join(protein_root, protein_fn)
        with open(protein_path, 'r') as f:
            return f.read()

    @classmethod
    def from_generated_data(cls, data, protein_root='./data/crossdocked', **kwargs):
        """Build a task whose ligand is reconstructed from generated ``data``.

        Extra keyword arguments are forwarded to ``__init__``.
        """
        pdb_block = cls._read_receptor_block(data, protein_root)
        ligand_rdmol = reconstruct_from_generated(data)
        return cls(pdb_block, ligand_rdmol, **kwargs)

    @classmethod
    def from_original_data(cls, data, ligand_root='./data/crossdocked_pocket10', protein_root='./data/crossdocked', **kwargs):
        """Build a task whose ligand is the first molecule of the original SDF file.

        Extra keyword arguments are forwarded to ``__init__``.
        """
        pdb_block = cls._read_receptor_block(data, protein_root)

        ligand_path = os.path.join(ligand_root, data.ligand_filename)
        ligand_rdmol = next(iter(Chem.SDMolSupplier(ligand_path)))
        return cls(pdb_block, ligand_rdmol, **kwargs)

    def __init__(self, pdb_block, ligand_rdmol, conda_env='adt', tmp_dir='./tmp', use_uff=True, center=None):
        """Write the docking inputs to disk and choose the search-box centre.

        Args:
            pdb_block: receptor as PDB-format text.
            ligand_rdmol: ligand as an RDKit molecule with a conformer.
            conda_env: conda environment activated before the tools run.
            tmp_dir: working directory for input and output files
                (created if missing).
            use_uff: if true, the ligand is UFF-optimised after hydrogens
                are added.
            center: search-box centre (indexable with 0, 1, 2). If ``None``,
                the centre of the ligand's bounding box is used.
        """
        super().__init__(pdb_block, ligand_rdmol)
        self.conda_env = conda_env
        self.tmp_dir = os.path.realpath(tmp_dir)
        # Create the resolved directory, i.e. the one the files are written to.
        os.makedirs(self.tmp_dir, exist_ok=True)

        self.task_id = get_random_id()
        self.receptor_id = self.task_id + '_receptor'
        self.ligand_id = self.task_id + '_ligand'

        self.receptor_path = os.path.join(self.tmp_dir, self.receptor_id + '.pdb')
        self.ligand_path = os.path.join(self.tmp_dir, self.ligand_id + '.sdf')

        with open(self.receptor_path, 'w') as f:
            f.write(pdb_block)

        ligand_rdmol = Chem.AddHs(ligand_rdmol, addCoords=True)
        if use_uff:
            UFFOptimizeMolecule(ligand_rdmol)
        sdf_writer = Chem.SDWriter(self.ligand_path)
        try:
            sdf_writer.write(ligand_rdmol)
        finally:
            # Always flush/close the file, even if writing fails.
            sdf_writer.close()
        self.ligand_rdmol = ligand_rdmol

        pos = ligand_rdmol.GetConformer(0).GetPositions()
        if center is None:
            self.center = (pos.max(0) + pos.min(0)) / 2
        else:
            self.center = center

        self.proc = None
        self.results = None
        self.output = None
        self.docked_sdf_path = None

    def run(self, exhaustiveness=16):
        """Start the docking pipeline in a background bash process.

        The script activates the conda environment, converts receptor and
        ligand to PDBQT, runs ``qvina2.1`` with a 20 x 20 x 20 search box
        around ``self.center`` and converts the result back to SDF. The
        method returns immediately; use ``get_results`` or ``run_sync``.
        """
        # NOTE: the trailing backslashes are Python line continuations inside
        # a non-raw string, so the qvina2.1 invocation reaches bash as one line.
        commands = """
eval "$(conda shell.bash hook)"
conda activate {env}
cd {tmp}
# Prepare receptor (PDB->PDBQT)
prepare_receptor4.py -r {receptor_id}.pdb
# Prepare ligand
obabel {ligand_id}.sdf -O{ligand_id}.pdbqt
qvina2.1 \
--receptor {receptor_id}.pdbqt \
--ligand {ligand_id}.pdbqt \
--center_x {center_x:.4f} \
--center_y {center_y:.4f} \
--center_z {center_z:.4f} \
--size_x 20 --size_y 20 --size_z 20 \
--exhaustiveness {exhaust}
obabel {ligand_id}_out.pdbqt -O{ligand_id}_out.sdf -h
""".format(
            receptor_id=self.receptor_id,
            ligand_id=self.ligand_id,
            env=self.conda_env,
            tmp=self.tmp_dir,
            exhaust=exhaustiveness,
            center_x=self.center[0],
            center_y=self.center[1],
            center_z=self.center[2],
        )

        self.docked_sdf_path = os.path.join(self.tmp_dir, '%s_out.sdf' % self.ligand_id)

        self.proc = subprocess.Popen(
            '/bin/bash',
            shell=False,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )

        # Feed the script on stdin; closing it lets bash exit when done.
        self.proc.stdin.write(commands.encode('utf-8'))
        self.proc.stdin.close()

    def run_sync(self):
        """Run the docking pipeline and block until it has finished.

        Returns:
            The list produced by ``get_results`` (empty if the docked SDF
            could not be parsed).
        """
        self.run()
        # Read stdout to EOF instead of spinning on poll(): this avoids a
        # 100%-CPU busy loop and keeps the stdout pipe from filling up and
        # stalling the child.
        # NOTE(review): stderr is piped as well but never read; a very chatty
        # stderr could still block the child — confirm whether it is needed.
        self.output = self.proc.stdout.readlines()
        self.proc.wait()
        results = self.get_results()
        if results:
            # get_results() returns [] on failure, so guard the index.
            print('Best affinity:', results[0]['affinity'])
        return results

    def get_results(self):
        """Return the docking results if the subprocess has finished.

        Returns:
            ``None`` if the job has not been started or is still running;
            ``[]`` if the docked SDF could not be parsed; otherwise the list
            from ``parse_qvina_outputs``.
        """
        if self.proc is None:  # Not started
            return None
        if self.proc.poll() is None:  # In progress
            return None

        if self.output is None:
            self.output = self.proc.stdout.readlines()
        try:
            self.results = parse_qvina_outputs(self.docked_sdf_path)
        except Exception:
            # Best effort: report and signal "no poses". Unlike a bare
            # `except:`, this lets KeyboardInterrupt/SystemExit propagate.
            print('[Error] Vina output error: %s' % self.docked_sdf_path)
            return []
        return self.results
|
183 |
+
|
utils/fpscores.pkl.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
|
3 |
+
size 3848394
|
utils/misc.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import random
|
4 |
+
import logging
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import yaml
|
8 |
+
from easydict import EasyDict
|
9 |
+
from logging import Logger
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
|
12 |
+
|
13 |
+
class BlackHole(object):
    """Null object that absorbs every interaction.

    Attribute assignments are discarded, and both calling the object and
    reading an unknown attribute return the object itself, so arbitrary
    chains such as ``hole.a.b(1).c`` are valid no-ops.
    """

    def __getattr__(self, name):
        # Only reached for attributes not found by normal lookup.
        return self

    def __call__(self, *args, **kwargs):
        return self

    def __setattr__(self, name, value):
        # Intentionally store nothing.
        return None
|
20 |
+
|
21 |
+
|
22 |
+
def load_config(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and wrap it in an ``EasyDict``."""
    with open(path, 'r') as stream:
        parsed = yaml.safe_load(stream)
        config = EasyDict(parsed)
    return config
|
25 |
+
|
26 |
+
|
27 |
+
def get_logger(name, log_dir=None):
    """Return the DEBUG-level logger ``name`` with console (and file) output.

    The function may be called repeatedly for the same ``name``: a console
    handler is attached only once, and a file handler is attached only once
    per log file, so messages are not duplicated.

    Args:
        name: logger name passed to ``logging.getLogger``.
        log_dir: if given, messages are also written to
            ``<log_dir>/log.txt``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # FileHandler subclasses StreamHandler, so exclude it when looking for an
    # existing console handler.
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        log_path = os.path.join(log_dir, 'log.txt')
        # FileHandler stores the absolute path in `baseFilename`.
        wanted = os.path.abspath(log_path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == wanted
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
|
44 |
+
|
45 |
+
|
46 |
+
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a timestamped log directory under ``root``.

    The directory name is the local time formatted as
    ``YYYY_MM_DD__HH_MM_SS``, preceded by ``prefix_`` and followed by
    ``_tag`` when those are non-empty.

    Raises:
        FileExistsError: if the directory already exists (``os.makedirs``
            is called without ``exist_ok``).
    """
    stamp = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    dir_name = stamp if prefix == '' else prefix + '_' + stamp
    dir_name = dir_name if tag == '' else dir_name + '_' + tag
    log_dir = os.path.join(root, dir_name)
    os.makedirs(log_dir)
    return log_dir
|
55 |
+
|
56 |
+
|
57 |
+
def seed_all(seed):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
61 |
+
|
62 |
+
|
63 |
+
def log_hyperparams(writer, args):
    """Write the attributes of ``args`` as hyper-parameters to a TensorBoard writer.

    String values are logged unchanged; every other value is logged as its
    ``repr``. The three summaries returned by ``hparams`` are added to
    ``writer.file_writer`` in order.

    Args:
        writer: object exposing ``file_writer.add_summary``.
        args: object whose ``vars()`` holds the hyper-parameters.
    """
    from torch.utils.tensorboard.summary import hparams

    hparam_dict = {}
    for key, value in vars(args).items():
        hparam_dict[key] = value if isinstance(value, str) else repr(value)

    exp, ssi, sei = hparams(hparam_dict, {})
    for summary in (exp, ssi, sei):
        writer.file_writer.add_summary(summary)
|
70 |
+
|
71 |
+
|
72 |
+
def int_tuple(argstr):
    """Parse a comma-separated string such as ``"1,2,3"`` into a tuple of ints."""
    return tuple(int(piece) for piece in argstr.split(','))
|
74 |
+
|
75 |
+
|
76 |
+
def str_tuple(argstr):
    """Split a comma-separated string into a tuple of its (unstripped) pieces."""
    pieces = argstr.split(',')
    return (*pieces,)
|
78 |
+
|
utils/mol_tree.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
import rdkit
|
4 |
+
import rdkit.Chem as Chem
|
5 |
+
import copy
|
6 |
+
import pickle
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, get_clique_mol_simple
|
12 |
+
from collections import defaultdict
|
13 |
+
|
14 |
+
|
15 |
+
def get_slots(smiles):
    """Describe every atom of a SMILES string parsed without sanitization.

    Returns:
        One ``(symbol, formal_charge, total_num_hs)`` tuple per atom, in
        atom order.
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    slots = []
    for atom in mol.GetAtoms():
        slots.append((atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()))
    return slots
|
18 |
+
|
19 |
+
|
20 |
+
class Vocab(object):
    """Bidirectional mapping between motif SMILES strings and integer indices."""

    def __init__(self, smiles_list):
        """Index ``smiles_list``; position in the list is the motif's index."""
        self.vocab = smiles_list
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        # Atom slots are not precomputed here; get_slots() derives them on
        # demand unless a `slots` sequence is assigned to the instance.

    def get_index(self, smiles):
        """Return the index of ``smiles``, or 0 if it is not in the vocabulary.

        Note that 0 is also the index of the first vocabulary entry.
        """
        return self.vmap.get(smiles, 0)

    def get_smiles(self, idx):
        """Return the SMILES string stored at index ``idx``."""
        return self.vocab[idx]

    def get_slots(self, idx):
        """Return the atom slots of the motif at index ``idx``.

        If a ``slots`` sequence has been assigned to the instance, a deep
        copy of ``slots[idx]`` is returned. Otherwise the slots are computed
        from the motif's SMILES with the module-level ``get_slots``
        (previously this path raised ``AttributeError`` because ``slots``
        was never set).
        """
        precomputed = getattr(self, 'slots', None)
        if precomputed is not None:
            return copy.deepcopy(precomputed[idx])
        # Inside a method the bare name resolves to the module-level function.
        return get_slots(self.vocab[idx])

    def size(self):
        """Return the number of vocabulary entries."""
        return len(self.vocab)
|
41 |
+
|
42 |
+
|
43 |
+
class MolTreeNode(object):
    """One clique (motif) of a molecule's tree decomposition.

    Attributes set in ``__init__``: ``smiles`` (canonical SMILES of the
    clique molecule), ``mol`` (the clique molecule), ``clique`` (atom
    indices in the parent molecule), ``neighbors`` (adjacent tree nodes) and
    ``rotatable``.

    NOTE(review): ``recover`` reads ``self.nid`` and ``self.is_leaf``, which
    are not assigned in this class and must be set by the owner of the node
    before ``recover`` is called.
    """

    def __init__(self, mol, cmol, clique):
        """Create a node for ``clique`` of ``mol``, whose fragment molecule is ``cmol``."""
        self.smiles = Chem.MolToSmiles(cmol, canonical=True)
        self.mol = cmol
        self.clique = list(clique)  # private copy

        self.neighbors = []
        # A two-atom clique is flagged rotatable when both of its atoms have
        # degree >= 2 in the full molecule.
        # should restrict to single bond, but double bond is ok
        self.rotatable = (
            len(self.clique) == 2
            and all(mol.GetAtomWithIdx(idx).GetDegree() >= 2 for idx in self.clique)
        )

    def add_neighbor(self, nei_node):
        """Append ``nei_node`` to this node's neighbor list."""
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        """Compute the SMILES label of this node together with its neighbors.

        Atoms of non-leaf nodes are temporarily tagged with their node id as
        atom-map number on ``original_mol``, the merged fragment is
        extracted, stored in ``self.label`` / ``self.label_mol``, and the
        atom-map numbers of all involved atoms are reset to 0.

        Returns:
            The label SMILES (also stored in ``self.label``).
        """
        merged = list(self.clique)
        if not self.is_leaf:
            for atom_idx in self.clique:
                original_mol.GetAtomWithIdx(atom_idx).SetAtomMapNum(self.nid)

        for neighbor in self.neighbors:
            merged.extend(neighbor.clique)
            if neighbor.is_leaf:  # Leaf node, no need to mark
                continue
            # allow singleton node override the atom mapping
            is_singleton = len(neighbor.clique) == 1
            for atom_idx in neighbor.clique:
                if is_singleton or atom_idx not in self.clique:
                    original_mol.GetAtomWithIdx(atom_idx).SetAtomMapNum(neighbor.nid)

        merged = list(set(merged))
        fragment = get_clique_mol_simple(original_mol, merged)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(fragment)))
        self.label_mol = get_mol(self.label)

        # Undo the temporary atom-map tags.
        for atom_idx in merged:
            original_mol.GetAtomWithIdx(atom_idx).SetAtomMapNum(0)

        return self.label

    def assemble(self):
        """Enumerate attachment candidates and store them in ``cands`` / ``cand_mols``.

        Neighbors are passed to ``enum_assemble`` sorted by atom count,
        largest first.

        NOTE(review): ``enum_assemble`` is not among the names this module
        visibly imports from ``.chemutils`` — confirm it is in scope.
        """
        ordered = sorted(self.neighbors, key=lambda node: node.mol.GetNumAtoms(), reverse=True)

        candidates = enum_assemble(self, ordered)
        if len(candidates) > 0:
            cand_smiles, cand_mols, _ = zip(*candidates)
            self.cands = list(cand_smiles)
            self.cand_mols = list(cand_mols)
        else:
            self.cands = []
            self.cand_mols = []
|
101 |
+
|
102 |
+
|
103 |
+
class MolTree(object):
    """Tree decomposition of a molecule into motif nodes.

    Attributes:
        smiles: SMILES of the input molecule.
        mol: the input RDKit molecule.
        nodes: list of ``MolTreeNode``; the node whose clique contains atom
            0 is moved to position 0.
        num_rotatable_bond: number of nodes flagged ``rotatable``.
    """

    def __init__(self, mol):
        """Decompose ``mol`` and wire up the resulting nodes."""
        self.smiles = Chem.MolToSmiles(mol)
        self.mol = mol

        # A reference vocabulary with a threshold could be passed to
        # tree_decomp to control the vocabulary size; the vanilla
        # decomposition is used for simplicity.
        cliques, edges = tree_decomp(self.mol, reference_vocab=None)
        self.nodes = []
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol_simple(self.mol, c)
            self.nodes.append(MolTreeNode(self.mol, cmol, c))
            if min(c) == 0:
                root = i

        self.num_rotatable_bond = sum(1 for node in self.nodes if node.rotatable)

        for x, y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])

        # Put the clique containing atom 0 first.
        if root > 0:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for i, node in enumerate(self.nodes):
            node.nid = i + 1
            # MolTreeNode.recover() reads `is_leaf`; without this assignment
            # it fails with AttributeError. Atom-map marking of node.mol
            # (set_atommap) stays disabled, as before.
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        """Return the number of nodes in the tree."""
        return len(self.nodes)

    def recover(self):
        """Call ``recover`` on every node with the full molecule."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Call ``assemble`` on every node."""
        for node in self.nodes:
            node.assemble()
|
154 |
+
|
155 |
+
|
156 |
+
if __name__ == "__main__":
    # Build the motif vocabulary: count how often each clique SMILES occurs
    # over all ligands of the dataset and write the counts to ./vocab.txt.
    import os  # not imported at module level; only this script section needs it

    seed = 2023
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # NOTE: dataset location is hard-coded. The same loop can be run over the
    # CrossDocked index ('./data/crossdocked_pocket10/index.pkl') by iterating
    # its (pocket_fn, ligand_fn, _, rmsd_str) entries instead.
    data_root = '/n/holyscratch01/mzitnik_lab/zaixizhang/pdbbind_pocket10/'
    index = torch.load(os.path.join(data_root, 'index.pt'))

    vocab = {}
    cnt = 0      # molecules decomposed successfully
    rot = 0      # of those, molecules with at least one rotatable bond
    skipped = 0  # molecules that could not be read or decomposed

    for pdbid in tqdm(index):
        if pdbid is None:
            continue
        # Was built from an undefined name (`item`), which made every
        # iteration fail silently inside the bare `except`.
        ligand_path = os.path.join(data_root, pdbid, pdbid + '_ligand.sdf')
        try:
            mol = Chem.MolFromMolFile(ligand_path, sanitize=False)
            moltree = MolTree(mol)
        except Exception:
            # Best effort over a large dataset: skip unreadable ligands, but
            # keep count so a systematic failure does not go unnoticed.
            skipped += 1
            continue

        cnt += 1
        if moltree.num_rotatable_bond > 0:
            rot += 1

        for node in moltree.nodes:
            vocab[node.smiles] = vocab.get(node.smiles, 0) + 1

    # Most frequent motifs first; ties broken by SMILES, descending.
    vocab = dict(sorted(vocab.items(), key=lambda kv: (kv[1], kv[0]), reverse=True))
    with open('./vocab.txt', 'w') as vocab_file:
        for k, v in vocab.items():
            vocab_file.write(k + ':' + str(v))
            vocab_file.write('\n')

    # number of molecules and vocab
    print('Size of the motif vocab:', len(vocab))
    print('Total number of molecules', cnt)
    if skipped:
        print('Number of skipped molecules:', skipped)
    # Guard against an empty/unreadable dataset instead of dividing by zero.
    print('Percent of molecules with rotatable bonds:', rot / cnt if cnt else 0.0)
|
utils/protein_ligand.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from rdkit import Chem
|
6 |
+
from rdkit.Chem.rdchem import BondType
|
7 |
+
from rdkit.Chem import ChemicalFeatures
|
8 |
+
from rdkit import RDConfig
|
9 |
+
from .mol_tree import *
|
10 |
+
|
11 |
+
# Feature family names used as columns of the per-atom feature matrix
# (looked up by `feat.GetFamily()` in parse_sdf_file).
ATOM_FAMILIES = [
    'Acceptor',
    'Donor',
    'Aromatic',
    'Hydrophobe',
    'LumpedHydrophobe',
    'NegIonizable',
    'PosIonizable',
    'ZnBinder',
]
# Family name -> column index.
ATOM_FAMILIES_ID = {family: idx for idx, family in enumerate(ATOM_FAMILIES)}
# RDKit BondType value -> integer id, in the order of BondType.names.
BOND_TYPES = {bond_type: idx for idx, bond_type in enumerate(BondType.names.values())}
# Integer id -> BondType name (same order as BOND_TYPES).
BOND_NAMES = {idx: bond_name for idx, bond_name in enumerate(BondType.names.keys())}
|
15 |
+
|
16 |
+
|
17 |
+
class PDBProtein(object):
    """Parse the ATOM records of a PDB file into per-atom and per-residue data.

    Only ``ATOM`` and ``HEADER`` records are read, and reading stops at the
    first ``ENDMDL`` record (first model only). Residue names must be keys
    of ``AA_NAME_SYM``; any other name raises ``KeyError`` during parsing.
    """

    # Three-letter residue name -> one-letter symbol.
    AA_NAME_SYM = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H',
        'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q',
        'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }

    # Three-letter residue name -> integer id (order of AA_NAME_SYM).
    AA_NAME_NUMBER = {
        k: i for i, (k, _) in enumerate(AA_NAME_SYM.items())
    }

    BACKBONE_NAMES = ["CA", "C", "N", "O"]

    def __init__(self, data, mode='auto'):
        """Load and parse a protein.

        Args:
            data: a path to a PDB file or the PDB text itself.
            mode: ``'path'`` forces ``data`` to be read as a file path;
                ``'auto'`` does so only when ``data`` ends in ``.pdb``
                (case-insensitive). Otherwise ``data`` is used as PDB text.
        """
        super().__init__()
        if (data[-4:].lower() == '.pdb' and mode == 'auto') or mode == 'path':
            with open(data, 'r') as f:
                self.block = f.read()
        else:
            self.block = data

        self.ptable = Chem.GetPeriodicTable()

        # Molecule properties
        self.title = None
        # Atom properties
        self.atoms = []
        self.element = []
        self.atomic_weight = []
        self.pos = []
        self.atom_name = []
        self.is_backbone = []
        self.atom_to_aa_type = []
        # Residue properties
        self.residues = []
        self.amino_acid = []
        self.center_of_mass = []
        self.pos_CA = []
        self.pos_C = []
        self.pos_N = []
        self.pos_O = []

        self._parse()

    def _enum_formatted_atom_lines(self):
        """Yield one dict per ``ATOM`` / ``HEADER`` record of ``self.block``.

        Fields are taken from fixed column ranges of each line. Iteration
        stops at the first ``ENDMDL`` record.
        """
        for line in self.block.splitlines():
            record = line[0:6].strip()
            if record == 'ATOM':
                element_symb = line[76:78].strip().capitalize()
                if len(element_symb) == 0:
                    # No element column: fall back to one character of the
                    # atom-name field.
                    element_symb = line[13:14]
                yield {
                    'line': line,
                    'type': 'ATOM',
                    'atom_id': int(line[6:11]),
                    'atom_name': line[12:16].strip(),
                    'res_name': line[17:20].strip(),
                    'chain': line[21:22].strip(),
                    'res_id': int(line[22:26]),
                    'res_insert_id': line[26:27].strip(),
                    'x': float(line[30:38]),
                    'y': float(line[38:46]),
                    'z': float(line[46:54]),
                    'occupancy': float(line[54:60]),
                    'segment': line[72:76].strip(),
                    'element_symb': element_symb,
                    'charge': line[78:80].strip(),
                }
            elif record == 'HEADER':
                yield {
                    'type': 'HEADER',
                    'value': line[10:].strip()
                }
            elif record == 'ENDMDL':
                break   # Some PDBs have more than 1 model.

    def _parse(self):
        """Fill the per-atom and per-residue attributes from ``self.block``."""
        # Process atoms
        residues_tmp = {}
        for atom in self._enum_formatted_atom_lines():
            if atom['type'] == 'HEADER':
                self.title = atom['value'].lower()
                continue
            self.atoms.append(atom)
            atomic_number = self.ptable.GetAtomicNumber(atom['element_symb'])
            next_ptr = len(self.element)
            self.element.append(atomic_number)
            self.atomic_weight.append(self.ptable.GetAtomicWeight(atomic_number))
            self.pos.append(np.array([atom['x'], atom['y'], atom['z']], dtype=np.float32))
            self.atom_name.append(atom['atom_name'])
            self.is_backbone.append(atom['atom_name'] in self.BACKBONE_NAMES)
            self.atom_to_aa_type.append(self.AA_NAME_NUMBER[atom['res_name']])

            # Residues are keyed by chain, segment, residue number and
            # insertion code.
            chain_res_id = '%s_%s_%d_%s' % (atom['chain'], atom['segment'], atom['res_id'], atom['res_insert_id'])
            if chain_res_id not in residues_tmp:
                residues_tmp[chain_res_id] = {
                    'name': atom['res_name'],
                    'atoms': [next_ptr],
                    'chain': atom['chain'],
                    'segment': atom['segment'],
                }
            else:
                assert residues_tmp[chain_res_id]['name'] == atom['res_name']
                assert residues_tmp[chain_res_id]['chain'] == atom['chain']
                residues_tmp[chain_res_id]['atoms'].append(next_ptr)

        # Process residues
        self.residues = list(residues_tmp.values())
        for residue in self.residues:
            sum_pos = np.zeros([3], dtype=np.float32)
            sum_mass = 0.0
            for atom_idx in residue['atoms']:
                sum_pos += self.pos[atom_idx] * self.atomic_weight[atom_idx]
                sum_mass += self.atomic_weight[atom_idx]
                if self.atom_name[atom_idx] in self.BACKBONE_NAMES:
                    residue['pos_%s' % self.atom_name[atom_idx]] = self.pos[atom_idx]
            residue['center_of_mass'] = sum_pos / sum_mass

        # Process backbone atoms of residues
        for residue in self.residues:
            self.amino_acid.append(self.AA_NAME_NUMBER[residue['name']])
            self.center_of_mass.append(residue['center_of_mass'])
            for name in self.BACKBONE_NAMES:
                pos_key = 'pos_%s' % name   # pos_CA, pos_C, pos_N, pos_O
                if pos_key in residue:
                    getattr(self, pos_key).append(residue[pos_key])
                else:
                    # Missing backbone atom: use the residue's centre of mass.
                    getattr(self, pos_key).append(residue['center_of_mass'])

    def to_dict_atom(self):
        """Return the per-atom data as a dict of NumPy arrays and lists."""
        return {
            'element': np.array(self.element, dtype=np.int_),
            'molecule_name': self.title,
            'pos': np.array(self.pos, dtype=np.float32),
            'is_backbone': np.array(self.is_backbone, dtype=bool),
            'atom_name': self.atom_name,
            'atom_to_aa_type': np.array(self.atom_to_aa_type, dtype=np.int_)
        }

    def to_dict_residue(self):
        """Return the per-residue data as a dict of NumPy arrays."""
        return {
            'amino_acid': np.array(self.amino_acid, dtype=np.int_),
            'center_of_mass': np.array(self.center_of_mass, dtype=np.float32),
            'pos_CA': np.array(self.pos_CA, dtype=np.float32),
            'pos_C': np.array(self.pos_C, dtype=np.float32),
            'pos_N': np.array(self.pos_N, dtype=np.float32),
            'pos_O': np.array(self.pos_O, dtype=np.float32),
        }

    def query_residues_radius(self, center, radius, criterion='center_of_mass'):
        """Return the residues whose ``criterion`` point lies within ``radius`` of ``center``.

        Args:
            center: 3 coordinates (anything ``np.array(...).reshape(3)``
                accepts).
            radius: distance threshold (strictly less than).
            criterion: residue key holding the point to compare, e.g.
                ``'center_of_mass'``.
        """
        center = np.array(center).reshape(3)
        selected = []
        for residue in self.residues:
            distance = np.linalg.norm(residue[criterion] - center, ord=2)
            # (a leftover per-residue debug print was removed here)
            if distance < radius:
                selected.append(residue)
        return selected

    def query_residues_ligand(self, ligand, radius, criterion='center_of_mass'):
        """Return the residues within ``radius`` of any position in ``ligand['pos']``.

        Each residue appears at most once, in the order in which it is first
        matched (ligand positions in the outer loop, residues in the inner).
        """
        selected = []
        sel_idx = set()
        # The time-complexity is O(mn).
        for center in ligand['pos']:
            for i, residue in enumerate(self.residues):
                distance = np.linalg.norm(residue[criterion] - center, ord=2)
                if distance < radius and i not in sel_idx:
                    selected.append(residue)
                    sel_idx.add(i)
        return selected

    def residues_to_pdb_block(self, residues, name='POCKET'):
        """Serialise ``residues`` as PDB text.

        The output consists of a HEADER and a COMPND line carrying ``name``,
        the original ATOM lines of every residue, and a final END line.
        """
        lines = ["HEADER %s" % name, "COMPND %s" % name]
        for residue in residues:
            lines.extend(self.atoms[atom_idx]['line'] for atom_idx in residue['atoms'])
        lines.append("END")
        # Every record, including END, is newline-terminated.
        return "\n".join(lines) + "\n"
|
196 |
+
|
197 |
+
|
198 |
+
def parse_pdbbind_index_file(path):
    """Return the first whitespace-separated token of every entry line.

    Lines whose first token starts with ``#`` are treated as comments, and
    blank lines are skipped (they previously raised ``IndexError``).

    Args:
        path: path of the index file.

    Returns:
        List of tokens (PDB ids) in file order.
    """
    pdb_id = []
    with open(path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields or fields[0].startswith('#'):
                continue
            pdb_id.append(fields[0])
    return pdb_id
|
206 |
+
|
207 |
+
|
208 |
+
def parse_sdf_file(path):
    """Read a ligand file and return its atoms, bonds, features and motif tree.

    The file is read three times: by ``Chem.MolFromMolFile`` (sanitized; used
    for the motif tree and the neighbor lists), by ``Chem.SDMolSupplier``
    with hydrogens removed (used for the chemical features), and as plain
    text (used for elements, coordinates and bonds). The atom count on the
    counts line must equal the atom count of the hydrogen-free molecule.

    Returns:
        Dict with keys ``element``, ``pos``, ``bond_index``, ``bond_type``,
        ``center_of_mass``, ``atom_feature``, ``moltree`` and ``neighbors``.

    Raises:
        AssertionError: if the atom counts disagree.
        KeyError: if a bond type code other than 1-4 occurs.
    """
    mol = Chem.MolFromMolFile(path, sanitize=True)
    moltree = MolTree(mol)

    # Per-atom feature-family indicators from RDKit's base feature definitions.
    fdef_path = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    feature_factory = ChemicalFeatures.BuildFeatureFactory(fdef_path)
    rdmol = next(iter(Chem.SDMolSupplier(path, removeHs=True)))
    rd_num_atoms = rdmol.GetNumAtoms()
    feat_mat = np.zeros([rd_num_atoms, len(ATOM_FAMILIES)], dtype=np.int_)
    for feature in feature_factory.GetFeaturesForMol(rdmol):
        feat_mat[feature.GetAtomIds(), ATOM_FAMILIES_ID[feature.GetFamily()]] = 1

    with open(path, 'r') as f:
        sdf_lines = f.read().splitlines()

    # Counts line: first two 3-character fields are atom and bond counts.
    counts_line = sdf_lines[3]
    num_atoms, num_bonds = int(counts_line[0:3]), int(counts_line[3:6])
    assert num_atoms == rd_num_atoms

    atom_start = 4
    bond_start = atom_start + num_atoms

    # Atom block: whitespace-separated x, y, z and element symbol.
    ptable = Chem.GetPeriodicTable()
    element, pos = [], []
    weighted_pos = np.zeros(3, dtype=np.float32)
    total_mass = 0.0
    for raw_line in sdf_lines[atom_start:bond_start]:
        fields = raw_line.split()
        x, y, z = (float(value) for value in fields[:3])
        atomic_number = ptable.GetAtomicNumber(fields[3].capitalize())
        element.append(atomic_number)
        pos.append([x, y, z])

        weight = ptable.GetAtomicWeight(atomic_number)
        weighted_pos += np.array([x, y, z]) * weight
        total_mass += weight

    center_of_mass = np.array(weighted_pos / total_mass, dtype=np.float32)

    element = np.array(element, dtype=np.int_)
    pos = np.array(pos, dtype=np.float32)

    # Bond block: molfile bond code -> integer id of the RDKit bond type.
    bond_type_ids = {t: i for i, t in enumerate(BondType.names.values())}
    bond_type_map = {
        1: bond_type_ids[BondType.SINGLE],
        2: bond_type_ids[BondType.DOUBLE],
        3: bond_type_ids[BondType.TRIPLE],
        4: bond_type_ids[BondType.AROMATIC],
    }
    row, col, edge_type = [], [], []
    for bond_line in sdf_lines[bond_start:bond_start + num_bonds]:
        # Atom numbers in the file are 1-based.
        start = int(bond_line[0:3]) - 1
        end = int(bond_line[3:6]) - 1
        # Store every bond in both directions.
        row.extend([start, end])
        col.extend([end, start])
        edge_type.extend([bond_type_map[int(bond_line[6:9])]] * 2)

    edge_index = np.array([row, col], dtype=np.int_)
    edge_type = np.array(edge_type, dtype=np.int_)

    # Order edges by (source, target).
    perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    # used in rotation angle prediction
    neighbor_dict = {
        i: [nbr.GetIdx() for nbr in atom.GetNeighbors()]
        for i, atom in enumerate(mol.GetAtoms())
    }

    return {
        'element': element,
        'pos': pos,
        'bond_index': edge_index,
        'bond_type': edge_type,
        'center_of_mass': center_of_mass,
        'atom_feature': feat_mat,
        'moltree': moltree,
        'neighbors': neighbor_dict
    }
|
utils/reconstruct.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from rdkit.Chem import AllChem as Chem
|
3 |
+
from rdkit import Geometry
|
4 |
+
from openbabel import openbabel as ob
|
5 |
+
from openbabel import pybel
|
6 |
+
from scipy.spatial.distance import pdist
|
7 |
+
from scipy.spatial.distance import squareform
|
8 |
+
|
9 |
+
from .protein_ligand import ATOM_FAMILIES_ID
|
10 |
+
|
11 |
+
|
12 |
+
class MolReconsError(Exception):
    '''Signals that a molecule could not be reconstructed.

    Raised by convert_ob_mol_to_rd_mol when RDKit sanitization fails.
    '''
|
14 |
+
|
15 |
+
|
16 |
+
def reachable_r(a, b, seenbonds):
    '''Depth-first search helper for reachable().

    Walks outward from atom ``a`` over bonds whose index is not yet in
    ``seenbonds`` (the set is updated in place) and returns True as soon
    as atom ``b`` is encountered, False when the search is exhausted.
    '''
    for neighbour in ob.OBAtomAtomIter(a):
        bond_idx = a.GetBond(neighbour).GetIdx()
        if bond_idx in seenbonds:
            continue  # never cross the same bond twice
        seenbonds.add(bond_idx)
        if neighbour == b or reachable_r(neighbour, b, seenbonds):
            return True
    return False
|
28 |
+
|
29 |
+
|
30 |
+
def reachable(a, b):
    '''Tell whether ``b`` can be reached from ``a`` without their shared bond.

    Used to check that deleting the a-b bond would not split the molecule.
    '''
    # If either end has a single bond, that bond is the only connection.
    either_is_terminal = a.GetExplicitDegree() == 1 or b.GetExplicitDegree() == 1
    if either_is_terminal:
        return False
    # Pre-mark the direct bond as used so the traversal must find a detour.
    visited = {a.GetBond(b).GetIdx()}
    return reachable_r(a, b, visited)
|
37 |
+
|
38 |
+
|
39 |
+
def forms_small_angle(a, b, cutoff=45):
    '''Check whether bond a-b makes a tight angle with another bond on ``a``.

    Only neighbours of ``a`` are inspected. Returns True when the angle
    b-a-nbr is below ``cutoff`` degrees for any neighbour other than ``b``.
    '''
    return any(
        b.GetAngle(a, other) < cutoff
        for other in ob.OBAtomAtomIter(a)
        if other != b
    )
|
49 |
+
|
50 |
+
|
51 |
+
def make_obmol(xyz, atomic_numbers):
    '''Create an Open Babel molecule holding bare atoms (no bonds).

    ``xyz`` and ``atomic_numbers`` are iterated in lockstep; each pair adds
    one atom at that position. The molecule is returned with its
    BeginModify() still open, together with the list of created atoms.
    '''
    mol = ob.OBMol()
    mol.BeginModify()
    created = []
    for position, number in zip(xyz, atomic_numbers):
        px, py, pz = position
        new_atom = mol.NewAtom()
        new_atom.SetAtomicNum(number)
        new_atom.SetVector(px, py, pz)
        created.append(new_atom)
    return mol, created
|
63 |
+
|
64 |
+
|
65 |
+
def connect_the_dots(mol, atoms, indicators, maxbond=4):
    '''Custom implementation of ConnectTheDots.

    Similar to OpenBabel's version, but more willing to make long bonds
    (up to ``maxbond`` long) to keep the molecule connected. Aromatic
    flags are taken from ``indicators``: a bond is created aromatic when
    both of its atoms carry the 'Aromatic' family indicator.

    ``atoms`` and ``indicators`` must correspond in their order.
    Assumes no hydrogens or existing bonds. Modifies ``mol`` in place and
    returns None; does nothing when ``atoms`` is empty.
    '''
    pt = Chem.GetPeriodicTable()

    if len(atoms) == 0:
        return

    mol.BeginModify()

    # Just do n^2 comparisons; efficiency can be revisited later.
    coords = np.array([(a.GetX(), a.GetY(), a.GetZ()) for a in atoms])
    dists = squareform(pdist(coords))

    aromatic_col = ATOM_FAMILIES_ID['Aromatic']
    for i, a in enumerate(atoms):
        for j, b in enumerate(atoms):
            if a == b:
                break  # stop at the diagonal: each unordered pair is visited once
            if dists[i, j] < 0.01:  # reduced from 0.4
                continue  # don't bond atoms that are too close
            if dists[i, j] < maxbond:
                flag = 0
                if indicators[i][aromatic_col] and indicators[j][aromatic_col]:
                    flag = ob.OB_AROMATIC_BOND
                mol.AddBond(a.GetIdx(), b.GetIdx(), 1, flag)

    atom_maxb = {}
    for a in atoms:
        # Use the smaller of the maxima allowed by openbabel and rdkit,
        # since the molecule should be valid for both (rdkit is usually lower).
        maxb = ob.GetMaxBonds(a.GetAtomicNum())
        maxb = min(maxb, pt.GetDefaultValence(a.GetAtomicNum()))

        if a.GetAtomicNum() == 16:  # sulfone check
            if count_nbrs_of_elem(a, 8) >= 2:
                maxb = 6

        atom_maxb[a.GetIdx()] = maxb

    # Remove any impossible bonds between halogens. The bonds are
    # snapshotted first: deleting from the molecule while walking its live
    # bond iterator can skip entries.
    for bond in list(ob.OBMolBondIter(mol)):
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if atom_maxb[a1.GetIdx()] == 1 and atom_maxb[a2.GetIdx()] == 1:
            mol.DeleteBond(bond)

    def get_bond_info(biter):
        '''Return (stretch, length, bond) tuples, most stretched first.'''
        binfo = []
        for bond in list(biter):
            bdist = bond.GetLength()
            # how far the bond is from the sum of covalent radii
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            ideal = ob.GetCovalentRad(a1.GetAtomicNum()) + ob.GetCovalentRad(a2.GetAtomicNum())
            stretch = bdist - ideal
            binfo.append((stretch, bdist, bond))
        binfo.sort(reverse=True, key=lambda t: t[:2])  # most stretched bonds first
        return binfo

    # NOTE: a pass that pruned bonds on hypervalent atoms used to sit here
    # but was commented out; only the geometric pass below is active.

    binfo = get_bond_info(ob.OBMolBondIter(mol))
    # Now eliminate geometrically poor bonds.
    for stretch, bdist, bond in binfo:
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()

        # As long as we aren't disconnecting, remove things that are
        # excessively far away (0.45 from ConnectTheDots) and tight
        # angles, because that is what ConnectTheDots does.
        if stretch > 0.45 or forms_small_angle(a1, a2) or forms_small_angle(a2, a1):
            # don't fragment the molecule
            if not reachable(a1, a2):
                continue
            mol.DeleteBond(bond)

    mol.EndModify()
|
181 |
+
|
182 |
+
|
183 |
+
def convert_ob_mol_to_rd_mol(ob_mol,struct=None):
    '''Convert an OBMol into an RDKit molecule, fixing up issues on the way.

    Hydrogens are stripped from ``ob_mol`` (it is modified in place), heavy
    atoms, coordinates and bonds are copied, over-valent double/triple
    bonds are demoted, hydrogens are re-added with coordinates and the
    result is sanitized without kekulization.

    ``struct`` is accepted for interface compatibility and is not used.

    Returns the RDKit molecule (with explicit hydrogens).
    Raises MolReconsError if sanitization fails, and a plain Exception
    for an Open Babel bond order other than 1, 2 or 3.
    '''
    ob_mol.DeleteHydrogens()
    n_atoms = ob_mol.NumAtoms()
    rd_mol = Chem.RWMol()
    rd_conf = Chem.Conformer(n_atoms)

    for ob_atom in ob.OBMolAtomIter(ob_mol):
        rd_atom = Chem.Atom(ob_atom.GetAtomicNum())
        #TODO copy format charge
        if ob_atom.IsAromatic() and ob_atom.IsInRing() and ob_atom.MemberOfRingSize() <= 6:
            # don't commit to being aromatic unless rdkit will be okay with the ring status
            # (this can happen if the atoms aren't fit well enough)
            rd_atom.SetIsAromatic(True)
        i = rd_mol.AddAtom(rd_atom)
        ob_coords = ob_atom.GetVector()
        x = ob_coords.GetX()
        y = ob_coords.GetY()
        z = ob_coords.GetZ()
        rd_coords = Geometry.Point3D(x, y, z)
        rd_conf.SetAtomPosition(i, rd_coords)

    rd_mol.AddConformer(rd_conf)

    for ob_bond in ob.OBMolBondIter(ob_mol):
        # Open Babel atom indices are 1-based, RDKit's are 0-based.
        i = ob_bond.GetBeginAtomIdx()-1
        j = ob_bond.GetEndAtomIdx()-1
        bond_order = ob_bond.GetBondOrder()
        if bond_order == 1:
            rd_mol.AddBond(i, j, Chem.BondType.SINGLE)
        elif bond_order == 2:
            rd_mol.AddBond(i, j, Chem.BondType.DOUBLE)
        elif bond_order == 3:
            rd_mol.AddBond(i, j, Chem.BondType.TRIPLE)
        else:
            raise Exception('unknown bond order {}'.format(bond_order))

        if ob_bond.IsAromatic():
            bond = rd_mol.GetBondBetweenAtoms(i, j)
            bond.SetIsAromatic(True)

    rd_mol = Chem.RemoveHs(rd_mol, sanitize=False)

    pt = Chem.GetPeriodicTable()
    # If double/triple bonds are connected to hypervalent atoms, decrement
    # the order. Longest bonds are handled first.
    positions = rd_mol.GetConformer().GetPositions()
    nonsingles = []
    for bond in rd_mol.GetBonds():
        if bond.GetBondType() == Chem.BondType.DOUBLE or bond.GetBondType() == Chem.BondType.TRIPLE:
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            dist = np.linalg.norm(positions[i]-positions[j])
            nonsingles.append((dist, bond))
    nonsingles.sort(reverse=True, key=lambda t: t[0])

    for (d, bond) in nonsingles:
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()

        if calc_valence(a1) > pt.GetDefaultValence(a1.GetAtomicNum()) or \
           calc_valence(a2) > pt.GetDefaultValence(a2.GetAtomicNum()):
            btype = Chem.BondType.SINGLE
            if bond.GetBondType() == Chem.BondType.TRIPLE:
                btype = Chem.BondType.DOUBLE
            bond.SetBondType(btype)

    for atom in rd_mol.GetAtoms():
        # set nitrogens with 4 neighbors to have a charge
        if atom.GetAtomicNum() == 7 and atom.GetDegree() == 4:
            atom.SetFormalCharge(1)

    rd_mol = Chem.AddHs(rd_mol, addCoords=True)

    positions = rd_mol.GetConformer().GetPositions()
    center = np.mean(positions[np.all(np.isfinite(positions), axis=1)], axis=0)
    for atom in rd_mol.GetAtoms():
        i = atom.GetIdx()
        pos = positions[i]
        if not np.all(np.isfinite(pos)):
            # hydrogens on C fragment get set to nan (shouldn't, but they do);
            # park them at the centroid of the finite positions
            rd_mol.GetConformer().SetAtomPosition(i, center)

    try:
        Chem.SanitizeMol(rd_mol, Chem.SANITIZE_ALL^Chem.SANITIZE_KEKULIZE)
    except Exception as err:
        # Only ordinary errors are translated; KeyboardInterrupt/SystemExit
        # propagate untouched, and the cause is kept for debugging.
        raise MolReconsError() from err

    # At some point stop trying to enforce our aromaticity - openbabel and
    # rdkit have different aromaticity models so they won't always agree.
    # Remove any aromatic bonds to non-aromatic atoms, and mark bonds
    # between two aromatic atoms as aromatic.
    for bond in rd_mol.GetBonds():
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if bond.GetIsAromatic():
            if not a1.GetIsAromatic() or not a2.GetIsAromatic():
                bond.SetIsAromatic(False)
        elif a1.GetIsAromatic() and a2.GetIsAromatic():
            bond.SetIsAromatic(True)

    return rd_mol
|
295 |
+
|
296 |
+
|
297 |
+
def calc_valence(rdatom):
    '''Sum the bond orders of all bonds on ``rdatom``.

    Can call GetExplicitValence before sanitize, but this value is needed
    to fix up the molecule to prevent sanitization failures.
    '''
    return sum((b.GetBondTypeAsDouble() for b in rdatom.GetBonds()), 0.0)
|
304 |
+
|
305 |
+
|
306 |
+
def count_nbrs_of_elem(atom, atomic_num):
    '''
    Return how many neighbours of ``atom`` have
    atomic number ``atomic_num``.
    '''
    return sum(
        1 for nbr in ob.OBAtomAtomIter(atom)
        if nbr.GetAtomicNum() == atomic_num
    )
|
316 |
+
|
317 |
+
|
318 |
+
def fixup(atoms, mol, indicators):
    '''Re-apply the desired aromatic flags to the atoms.

    Called repeatedly between Open Babel operations to beat openbabel
    over the head with what we want to happen. ``indicators[i]`` holds
    the family flags of ``atoms[i]``.
    '''
    mol.SetAromaticPerceived(True)  # avoid perception
    for idx, atom in enumerate(atoms):
        flags = indicators[idx]

        if flags[ATOM_FAMILIES_ID['Aromatic']]:
            atom.SetAromatic(True)
            atom.SetHyb(2)

        # Only ring nitrogens / oxygens get the neighbour-based treatment.
        ring_n_or_o = (atom.GetAtomicNum() in (7, 8)) and atom.IsInRing()
        if not ring_n_or_o:
            continue

        # This is a little iffy: there are no aromatic types for these
        # elements, so an atom that sits in a ring next to more than one
        # aromatic atom is marked aromatic as well.
        aromatic_nbrs = sum(1 for nbr in ob.OBAtomAtomIter(atom) if nbr.IsAromatic())
        if aromatic_nbrs > 1:
            atom.SetAromatic(True)
|
352 |
+
|
353 |
+
|
354 |
+
def raw_obmol_from_generated(data):
    '''Build a bond-less Open Babel molecule from the generated ligand atoms.

    Reads ``data.ligand_context_pos`` and ``data.ligand_context_element``
    and returns the ``(mol, atoms)`` pair produced by make_obmol.
    '''
    positions = data.ligand_context_pos.clone().cpu().tolist()
    elements = data.ligand_context_element.clone().cpu().tolist()
    return make_obmol(positions, elements)
|
361 |
+
|
362 |
+
|
363 |
+
# Next-higher bond type for each bond type that can be upgraded
# (there is no entry for triple bonds).
UPGRADE_BOND_ORDER = {
    Chem.BondType.SINGLE: Chem.BondType.DOUBLE,
    Chem.BondType.DOUBLE: Chem.BondType.TRIPLE,
}
|
364 |
+
|
365 |
+
def postprocess_rd_mol_1(rdmol):
    '''Resolve radical electrons left over after reconstruction.

    Hydrogens are removed first. For every pair of bonded atoms that both
    carry radical electrons, the bond order is raised by one step (per
    UPGRADE_BOND_ORDER) and one radical electron is taken from each atom.
    Radical electrons that remain on an atom are converted into explicit
    hydrogens.

    Returns the hydrogen-stripped, modified molecule.
    '''
    rdmol = Chem.RemoveHs(rdmol)

    # Adjacency list (atom index -> bonded atom indices) from the bond table.
    nbh_list = {}
    for bond in rdmol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        nbh_list.setdefault(begin, []).append(end)
        nbh_list.setdefault(end, []).append(begin)

    # Fix missing bond-order
    for atom in rdmol.GetAtoms():
        idx = atom.GetIdx()
        num_radical = atom.GetNumRadicalElectrons()
        if num_radical > 0:
            # .get(): an atom without any bond has no adjacency entry.
            for j in nbh_list.get(idx, ()):
                if num_radical == 0:
                    # Nothing left to pair; going on would drive the count
                    # negative and pass it to SetNumRadicalElectrons.
                    break
                if j <= idx:
                    continue  # each pair is handled from its lower-index atom
                nb_atom = rdmol.GetAtomWithIdx(j)
                nb_radical = nb_atom.GetNumRadicalElectrons()
                if nb_radical <= 0:
                    continue
                bond = rdmol.GetBondBetweenAtoms(idx, j)
                upgraded = UPGRADE_BOND_ORDER.get(bond.GetBondType())
                if upgraded is None:
                    # Bond type has no higher order in the table; leave it
                    # and let the radicals become hydrogens below.
                    continue
                bond.SetBondType(upgraded)
                nb_atom.SetNumRadicalElectrons(nb_radical - 1)
                num_radical -= 1
            atom.SetNumRadicalElectrons(num_radical)

        # Whatever could not be paired is saturated with explicit hydrogens.
        num_radical = atom.GetNumRadicalElectrons()
        if num_radical > 0:
            atom.SetNumRadicalElectrons(0)
            num_hs = atom.GetNumExplicitHs()
            atom.SetNumExplicitHs(num_hs + num_radical)

    return rdmol
|
402 |
+
|
403 |
+
|
404 |
+
def postprocess_rd_mol_2(rdmol):
    '''Break up implausible three-membered rings and clear positive charges.

    For every 3-atom ring: if exactly two ring atoms are not carbon, the
    bond between them is removed; if the ring holds two oxygens, the O-O
    bond is removed and each oxygen gains one explicit hydrogen. Finally,
    every positive formal charge is reset to zero.

    Returns a new molecule; ring membership is read from the input.
    '''
    editable = Chem.RWMol(rdmol)

    ring_sets = [set(members) for members in rdmol.GetRingInfo().AtomRings()]
    for ring in ring_sets:
        if len(ring) != 3:
            continue

        hetero = []
        by_symbol = {}
        for atom_idx in ring:
            symbol = rdmol.GetAtomWithIdx(atom_idx).GetSymbol()
            if symbol != 'C':
                hetero.append(atom_idx)
            by_symbol.setdefault(symbol, []).append(atom_idx)

        if len(hetero) == 2:
            editable.RemoveBond(*hetero)

        oxygens = by_symbol.get('O', [])
        if len(oxygens) == 2:
            editable.RemoveBond(*oxygens)
            # each oxygen gets a hydrogen in place of the removed bond
            for o_idx in oxygens:
                o_atom = editable.GetAtomWithIdx(o_idx)
                o_atom.SetNumExplicitHs(o_atom.GetNumExplicitHs() + 1)

    result = editable.GetMol()

    for atom in result.GetAtoms():
        if atom.GetFormalCharge() > 0:
            atom.SetFormalCharge(0)

    return result
|
439 |
+
|
440 |
+
|
441 |
+
def reconstruct_from_generated(data):
    '''Rebuild an RDKit molecule from the generated ligand atoms in ``data``.

    Reads ``ligand_context_pos``, ``ligand_context_element`` and the last
    ``len(ATOM_FAMILIES_ID)`` columns of ``ligand_context_feature_full``
    (used as boolean family indicators). Bonds are inferred with
    connect_the_dots (maximum bond length 2), hydrogens and bond orders are
    assigned by Open Babel, and the result is converted to RDKit and
    post-processed.

    fixup() is re-run after every Open Babel step on purpose, so the call
    order below must not be changed.
    '''
    positions = data.ligand_context_pos.clone().cpu().tolist()
    elements = data.ligand_context_element.clone().cpu().tolist()
    n_families = len(ATOM_FAMILIES_ID)
    indicators = data.ligand_context_feature_full[:, -n_families:].clone().cpu().bool().tolist()

    mol, atoms = make_obmol(positions, elements)
    fixup(atoms, mol, indicators)

    connect_the_dots(mol, atoms, indicators, 2)
    fixup(atoms, mol, indicators)
    mol.EndModify()

    fixup(atoms, mol, indicators)

    mol.AddPolarHydrogens()
    mol.PerceiveBondOrders()
    fixup(atoms, mol, indicators)

    for ob_atom in atoms:
        ob.OBAtomAssignTypicalImplicitHydrogens(ob_atom)
    fixup(atoms, mol, indicators)

    mol.AddHydrogens()
    fixup(atoms, mol, indicators)

    # make rings all aromatic if majority of carbons are aromatic
    for ring in ob.OBMolRingIter(mol):
        if not 5 <= ring.Size() <= 6:
            continue
        ring_atoms = [mol.GetAtom(ai) for ai in ring._path]
        ring_carbons = [ra for ra in ring_atoms if ra.GetAtomicNum() == 6]
        aromatic_ccnt = sum(1 for rc in ring_carbons if rc.IsAromatic())
        if aromatic_ccnt >= len(ring_carbons)/2 and aromatic_ccnt != ring.Size():
            # set all ring atoms to be aromatic
            for ra in ring_atoms:
                ra.SetAromatic(True)

    # bonds must be marked aromatic for smiles to match
    for bond in ob.OBMolBondIter(mol):
        if bond.GetBeginAtom().IsAromatic() and bond.GetEndAtom().IsAromatic():
            bond.SetAromatic(True)

    mol.PerceiveBondOrders()

    # Conversion followed by the two post-processing passes.
    converted = convert_ob_mol_to_rd_mol(mol)
    converted = postprocess_rd_mol_1(converted)
    return postprocess_rd_mol_2(converted)
|
utils/sascorer.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
|
3 |
+
from rdkit import Chem
|
4 |
+
from rdkit.Chem import rdMolDescriptors
|
5 |
+
from rdkit.six.moves import cPickle
|
6 |
+
from rdkit.six import iteritems
|
7 |
+
|
8 |
+
import math
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
import os.path as op
|
12 |
+
|
13 |
+
_fscores = None
|
14 |
+
|
15 |
+
|
16 |
+
def readFragmentScores(name='fpscores'):
    """Load the fragment-score table into the module-level ``_fscores``.

    ``name`` is the path of the gzipped pickle without its ``.pkl.gz``
    suffix; the default ``'fpscores'`` is resolved next to this module.
    Each pickled record is ``[score, id_1, id_2, ...]`` and is flattened
    into a ``{fragment_id: float(score)}`` dictionary.
    """
    import gzip
    import pickle
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # NOTE(security): pickle executes arbitrary code on load; only point
    # this at trusted score files.
    # The context manager closes the gzip handle, which used to be leaked.
    with gzip.open('%s.pkl.gz' % name) as handle:
        records = pickle.load(handle)
    _fscores = {
        fragment_id: float(record[0])
        for record in records
        for fragment_id in record[1:]
    }
|
28 |
+
|
29 |
+
|
30 |
+
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return ``(bridgehead_count, spiro_count)`` for ``mol``.

    ``ri`` is accepted but not used.
    """
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return bridgehead_count, spiro_count
|
34 |
+
|
35 |
+
|
36 |
+
def calculateScore(m):
    """Compute the synthetic-accessibility score of molecule ``m``.

    The raw score is the sum of a fragment score (mean per-fragment
    contribution over the radius-2 Morgan fingerprint, -4 for unknown
    fragments), a complexity penalty (size, stereo centres, spiro and
    bridgehead atoms, macrocycles) and a fingerprint-density correction.
    It is then mapped to the range 1..10 and clamped.

    The fragment table is loaded lazily on first use.
    Raises ZeroDivisionError when the fingerprint has no non-zero
    elements.
    """
    if _fscores is None:
        readFragmentScores()

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  #<- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        score1 += _fscores.get(bitId, -4) * v
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthesise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (named bounds instead of shadowing the builtins min/max)
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
|
98 |
+
|
99 |
+
|
100 |
+
def processMols(mols):
    """Print a tab-separated SMILES / name / SA-score table for ``mols``.

    ``None`` entries are skipped. Each molecule must carry a ``_Name``
    property.
    """
    print('smiles\tName\tsa_score')
    for mol in mols:
        if mol is None:
            continue

        score = calculateScore(mol)
        smiles = Chem.MolToSmiles(mol)
        mol_name = mol.GetProp('_Name')
        print(smiles + "\t" + mol_name + "\t%3f" % score)
|
110 |
+
|
111 |
+
|
112 |
+
if __name__ == '__main__':
    # Command-line use: score every molecule of the SMILES file given as
    # the first argument and report timings on stderr.
    import sys
    import time

    read_start = time.time()
    readFragmentScores("fpscores")
    read_end = time.time()

    supplier = Chem.SmilesMolSupplier(sys.argv[1])
    calc_start = time.time()
    processMols(supplier)
    calc_end = time.time()

    read_seconds = read_end - read_start
    calc_seconds = calc_end - calc_start
    print('Reading took %.2f seconds. Calculating took %.2f seconds' % (read_seconds, calc_seconds),
          file=sys.stderr)
|
126 |
+
|
127 |
+
#
|
128 |
+
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
|
129 |
+
# All rights reserved.
|
130 |
+
#
|
131 |
+
# Redistribution and use in source and binary forms, with or without
|
132 |
+
# modification, are permitted provided that the following conditions are
|
133 |
+
# met:
|
134 |
+
#
|
135 |
+
# * Redistributions of source code must retain the above copyright
|
136 |
+
# notice, this list of conditions and the following disclaimer.
|
137 |
+
# * Redistributions in binary form must reproduce the above
|
138 |
+
# copyright notice, this list of conditions and the following
|
139 |
+
# disclaimer in the documentation and/or other materials provided
|
140 |
+
# with the distribution.
|
141 |
+
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
|
142 |
+
# nor the names of its contributors may be used to endorse or promote
|
143 |
+
# products derived from this software without specific prior written permission.
|
144 |
+
#
|
145 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
146 |
+
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
147 |
+
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
148 |
+
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
149 |
+
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
150 |
+
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
151 |
+
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
152 |
+
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
153 |
+
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
154 |
+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
155 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
156 |
+
#
|
157 |
+
|
158 |
+
def compute_sa_score(rdmol):
    """Return a normalised synthetic-accessibility score for ``rdmol``.

    The molecule is round-tripped through SMILES before scoring. The
    1..10 result of calculateScore is mapped to ``(10 - sa) / 9`` and
    rounded to two decimals, so higher means easier to synthesise.
    """
    canonical_smiles = Chem.MolToSmiles(rdmol)
    reparsed = Chem.MolFromSmiles(canonical_smiles)
    raw_score = calculateScore(reparsed)
    return round((10 - raw_score) / 9, 2)
|
163 |
+
|
utils/similarity.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from rdkit import Chem, DataStructs
|
3 |
+
|
4 |
+
|
5 |
+
def tanimoto_sim(mol, ref):
    """Tanimoto similarity between the RDKit fingerprints of ``mol`` and ``ref``."""
    ref_fp = Chem.RDKFingerprint(ref)
    mol_fp = Chem.RDKFingerprint(mol)
    similarity = DataStructs.TanimotoSimilarity(ref_fp, mol_fp)
    return similarity
|
9 |
+
|
10 |
+
|
11 |
+
def tanimoto_sim_N_to_1(mols, ref):
    """Return the list of Tanimoto similarities of each molecule in ``mols`` to ``ref``."""
    similarities = []
    for candidate in mols:
        similarities.append(tanimoto_sim(candidate, ref))
    return similarities
|
14 |
+
|
15 |
+
|
16 |
+
def batched_number_of_rings(mols):
    """Return a NumPy array with the ring count of every molecule in ``mols``."""
    ring_counts = [Chem.rdMolDescriptors.CalcNumRings(mol) for mol in mols]
    return np.array(ring_counts)
|
utils/train.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import warnings
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch_geometric.data import Data, Batch
|
7 |
+
|
8 |
+
from .warmup import GradualWarmupScheduler
|
9 |
+
|
10 |
+
|
11 |
+
#customize exp lr scheduler with min lr
|
12 |
+
class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
    """Exponential learning-rate decay that never drops below ``min_lr``.

    Every step multiplies each parameter group's learning rate by
    ``gamma`` and clamps the result from below at ``min_lr``.
    """

    def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
        # Both attributes must exist before the base constructor runs,
        # because it performs an initial step that calls get_lr().
        self.gamma = gamma
        self.min_lr = min_lr
        # `verbose` is deprecated in the base scheduler: forward it only
        # when the caller actually asked for it, so the default path does
        # not depend on the base class still accepting the argument.
        extra = {'verbose': verbose} if verbose else {}
        super().__init__(optimizer, gamma, last_epoch, **extra)

    def get_lr(self):
        """Return the next learning rate of every parameter group."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return self.base_lrs
        return [max(group['lr'] * self.gamma, self.min_lr)
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Closed-form counterpart of get_lr(), computed from the base rates."""
        return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
                for base_lr in self.base_lrs]
|
31 |
+
|
32 |
+
|
33 |
+
def repeat_data(data: Data, num_repeat) -> Batch:
    """Build a batch made of ``num_repeat`` independent deep copies of ``data``."""
    copies = []
    for _ in range(num_repeat):
        copies.append(copy.deepcopy(data))
    return Batch.from_data_list(copies)
|
36 |
+
|
37 |
+
|
38 |
+
def repeat_batch(batch: Batch, num_repeat) -> Batch:
    """Build a batch that contains the items of ``batch`` ``num_repeat`` times.

    The item list is deep-copied once per repetition, so repetitions do
    not share objects with each other.
    """
    originals = batch.to_data_list()
    repeated = []
    for _ in range(num_repeat):
        repeated.extend(copy.deepcopy(originals))
    return Batch.from_data_list(repeated)
|
44 |
+
|
45 |
+
|
46 |
+
def inf_iterator(iterable):
    """Yield the items of ``iterable`` forever, restarting it when exhausted.

    A fresh iterator is requested from ``iterable`` for every pass.

    Raises:
        ValueError: if a pass produces no items at all (empty iterable, or
            a one-shot iterator that cannot be restarted). Without this
            check the generator would spin forever without yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError('inf_iterator: iterable produced no items; cannot cycle')
|
53 |
+
|
54 |
+
|
55 |
+
def get_optimizer(cfg, model):
    """Create the optimizer described by ``cfg`` for ``model``'s parameters.

    Only ``cfg.type == 'adam'`` is supported; it reads ``lr``,
    ``weight_decay``, ``beta1`` and ``beta2`` from ``cfg``.

    Raises:
        NotImplementedError: for any other optimizer type.
    """
    if cfg.type != 'adam':
        raise NotImplementedError('Optimizer not supported: %s' % cfg.type)

    return torch.optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        betas=(cfg.beta1, cfg.beta2),
    )
|
65 |
+
|
66 |
+
|
67 |
+
def get_scheduler(cfg, optimizer):
    """Create the learning-rate scheduler selected by ``cfg.type``.

    Supported types:
        * ``'plateau'``: ``ReduceLROnPlateau`` using ``cfg.factor``,
          ``cfg.patience`` and ``cfg.min_lr``.
        * ``'warmup_plateau'``: ``GradualWarmupScheduler`` (``cfg.multiplier``,
          ``cfg.total_epoch``) followed by the plateau scheduler above.
        * ``'expmin'``: ``ExponentialLR_with_minLr`` with ``gamma=cfg.factor``.
        * ``'expmin_milestone'``: same scheduler with a per-step gamma of
          ``cfg.factor ** (1 / cfg.milestone)``.

    Raises:
        NotImplementedError: For any other ``cfg.type``.
    """
    kind = cfg.type

    if kind == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=cfg.factor, patience=cfg.patience, min_lr=cfg.min_lr
        )

    if kind == 'warmup_plateau':
        return GradualWarmupScheduler(
            optimizer,
            multiplier=cfg.multiplier,
            total_epoch=cfg.total_epoch,
            after_scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, factor=cfg.factor, patience=cfg.patience, min_lr=cfg.min_lr
            ),
        )

    if kind == 'expmin':
        return ExponentialLR_with_minLr(optimizer, gamma=cfg.factor, min_lr=cfg.min_lr)

    if kind == 'expmin_milestone':
        # exp(log(factor) / milestone) == factor ** (1 / milestone)
        per_step_gamma = np.exp(np.log(cfg.factor) / cfg.milestone)
        return ExponentialLR_with_minLr(optimizer, gamma=per_step_gamma, min_lr=cfg.min_lr)

    raise NotImplementedError('Scheduler not supported: %s' % kind)
|
102 |
+
|
utils/transforms.py
ADDED
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append("..")
|
3 |
+
import copy
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import numpy as np
|
9 |
+
from copy import deepcopy
|
10 |
+
from torch_geometric.transforms import Compose
|
11 |
+
from torch_geometric.nn.pool import knn_graph
|
12 |
+
from torch_geometric.utils.subgraph import subgraph
|
13 |
+
from torch_geometric.utils.num_nodes import maybe_num_nodes
|
14 |
+
from torch_geometric.data import Data, Batch
|
15 |
+
from torch_scatter import scatter_add
|
16 |
+
from rdkit import Chem
|
17 |
+
from rdkit.Chem import Descriptors
|
18 |
+
from rdkit.Chem import AllChem
|
19 |
+
|
20 |
+
from .data import ProteinLigandData
|
21 |
+
from .protein_ligand import ATOM_FAMILIES
|
22 |
+
from .chemutils import enumerate_assemble, list_filter, rand_rotate
|
23 |
+
from .dihedral_utils import batch_dihedrals
|
24 |
+
|
25 |
+
# allowable node and edge features
# Lookup tables for categorical atom/bond attributes. A feature is encoded as
# the position of its value inside the matching list (see the ``.index()``
# calls in ``mol_to_graph_data_obj_simple``), so the order of the entries is
# significant and must not be changed.
allowable_features = {
    'possible_atomic_num_list': list(range(1, 119)),  # atomic numbers 1..118
    'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs': [  # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}
|
56 |
+
|
57 |
+
|
58 |
+
def mol_to_graph_data_obj_simple(mol):
    """
    Turn an RDKit molecule into a PyTorch Geometric ``Data`` graph whose
    features are category indices taken from ``allowable_features``.

    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr.
        ``x`` holds (atomic-number index, chirality index) per atom;
        ``edge_attr`` holds (bond-type index, bond-direction index) per
        directed edge. Every bond contributes two directed edges.
    """
    atomic_nums = allowable_features['possible_atomic_num_list']
    chiral_tags = allowable_features['possible_chirality_list']
    bond_types = allowable_features['possible_bonds']
    bond_dirs = allowable_features['possible_bond_dirs']

    # atoms: one [atom type, chirality tag] row per atom
    node_rows = [
        [atomic_nums.index(atom.GetAtomicNum()), chiral_tags.index(atom.GetChiralTag())]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(np.array(node_rows), dtype=torch.long)

    # bonds: one [bond type, bond direction] row per directed edge
    num_bond_features = 2
    if not len(mol.GetBonds()) > 0:
        # molecule without bonds: empty connectivity and attributes
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)
    else:
        pairs = []
        pair_rows = []
        for bond in mol.GetBonds():
            begin = bond.GetBeginAtomIdx()
            end = bond.GetEndAtomIdx()
            row = [bond_types.index(bond.GetBondType()), bond_dirs.index(bond.GetBondDir())]
            # store both directions, sharing the same feature row
            pairs.extend([(begin, end), (end, begin)])
            pair_rows.extend([row, row])

        # COO connectivity with shape [2, num_edges]
        edge_index = torch.tensor(np.array(pairs).T, dtype=torch.long)
        # edge features with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(pair_rows), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
106 |
+
|
107 |
+
|
108 |
+
class RefineData(object):
    """Transform that removes hydrogen atoms from the pocket and the ligand.

    Hydrogens are the atoms whose element number equals 1. For the ligand,
    the neighbour list and the bond index are re-indexed so they refer to the
    remaining heavy atoms only.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        """Strip hydrogens from ``data`` in place and return it.

        Reads and rewrites ``protein_atom_name``, ``protein_atom_to_aa_type``,
        ``protein_element``, ``protein_is_backbone``, ``protein_pos``,
        ``ligand_atom_feature``, ``ligand_element``, ``ligand_pos``,
        ``ligand_nbh_list``, ``ligand_bond_index`` and ``ligand_bond_type``.
        """
        # delete H atom of pocket
        is_H_protein = (data.protein_element == 1)
        if torch.sum(is_H_protein) > 0:
            not_H_protein = ~is_H_protein
            # Plain filtering instead of itertools.compress: no import of
            # `compress` is visible at the top of this module, so the
            # previous call raised NameError.
            data.protein_atom_name = [
                name
                for name, keep in zip(data.protein_atom_name, not_H_protein.tolist())
                if keep
            ]
            data.protein_atom_to_aa_type = data.protein_atom_to_aa_type[not_H_protein]
            data.protein_element = data.protein_element[not_H_protein]
            data.protein_is_backbone = data.protein_is_backbone[not_H_protein]
            data.protein_pos = data.protein_pos[not_H_protein]

        # delete H atom of ligand
        is_H_ligand = (data.ligand_element == 1)
        if torch.sum(is_H_ligand) > 0:
            not_H_ligand = ~is_H_ligand
            data.ligand_atom_feature = data.ligand_atom_feature[not_H_ligand]
            data.ligand_element = data.ligand_element[not_H_ligand]
            data.ligand_pos = data.ligand_pos[not_H_ligand]

            # nbh: map old atom indices to new ones (-1 marks a removed H)
            keep_mask = not_H_ligand.cpu().numpy()
            hydrogen_ids = set(torch.nonzero(is_H_ligand)[:, 0].tolist())
            index_changer = -np.ones(len(keep_mask), dtype=np.int64)
            index_changer[keep_mask] = np.arange(int(keep_mask.sum()))
            # assumes ligand_nbh_list iterates in atom-index order - TODO confirm
            kept_nbh = [
                nbh for keep, nbh in zip(keep_mask, data.ligand_nbh_list.values()) if keep
            ]
            data.ligand_nbh_list = {
                i: [index_changer[node] for node in nbh if int(node) not in hydrogen_ids]
                for i, nbh in enumerate(kept_nbh)
            }

            # bond: drop every bond touching a hydrogen, then re-index.
            # Vectorised so that an empty bond list is handled as well.
            bond_index = data.ligand_bond_index
            bond_without_H = ~(is_H_ligand[bond_index[0]] | is_H_ligand[bond_index[1]])
            data.ligand_bond_index = torch.tensor(index_changer)[bond_index[:, bond_without_H]]
            data.ligand_bond_type = data.ligand_bond_type[bond_without_H]

        return data
|
147 |
+
|
148 |
+
|
149 |
+
class FocalBuilder(object):
    """Transform that derives focal atoms and generation targets for a sample.

    With a ligand context, the focal atoms are the context ends of the bonds
    that bridge a masked atom and a context atom. Without any context, the
    focal atoms are protein atoms lying within 4.0 of a masked ligand atom
    (falling back to the single closest protein/ligand pair).

    ``close_threshold`` and ``max_bond_length`` are stored on the instance but
    are not used by the active code of ``__call__``.
    """

    def __init__(self, close_threshold=0.8, max_bond_length=2.4):
        self.close_threshold = close_threshold
        self.max_bond_length = max_bond_length
        super().__init__()

    def __call__(self, data: "ProteinLigandData"):
        """Annotate ``data`` with focal / generated indices and return it.

        Writes ``idx_generated_in_ligand_masked``, ``pos_generate``,
        ``idx_focal_in_compose``, ``idx_protein_all_mask`` and
        ``y_protein_frontier``.
        """
        ligand_masked_pos = data.ligand_masked_pos
        protein_pos = data.protein_pos
        context_idx = data.context_idx
        masked_idx = data.masked_idx
        old_bond_index = data.ligand_bond_index

        has_unmask_atoms = context_idx.nelement() > 0
        if has_unmask_atoms:
            # Bridge bonds (mask -> context): row 0 must be a masked atom and
            # row 1 a context atom. Computed by broadcasting instead of a
            # per-bond Python loop.
            mask_nodes, context_nodes = old_bond_index[0], old_bond_index[1]
            is_bridge = (
                (context_nodes.view(-1, 1) == context_idx.view(1, -1)).any(dim=-1)
                & (mask_nodes.view(-1, 1) == masked_idx.view(1, -1)).any(dim=-1)
            )
            bridge_bond_index = old_bond_index[:, is_bridge]
            idx_generated_in_whole_ligand = bridge_bond_index[0]
            idx_focal_in_whole_ligand = bridge_bond_index[1]

            # whole-ligand index -> position inside the masked subset
            index_changer_masked = torch.zeros(masked_idx.max() + 1, dtype=torch.int64)
            index_changer_masked[masked_idx] = torch.arange(len(masked_idx))
            idx_generated_in_ligand_masked = index_changer_masked[idx_generated_in_whole_ligand]
            pos_generate = ligand_masked_pos[idx_generated_in_ligand_masked]

            data.idx_generated_in_ligand_masked = idx_generated_in_ligand_masked
            data.pos_generate = pos_generate

            # whole-ligand index -> position inside the context subset
            index_changer_context = torch.zeros(context_idx.max() + 1, dtype=torch.int64)
            index_changer_context[context_idx] = torch.arange(len(context_idx))
            idx_focal_in_ligand_context = index_changer_context[idx_focal_in_whole_ligand]
            # Only valid while the ligand context precedes the protein in the
            # composed atom list.
            data.idx_focal_in_compose = idx_focal_in_ligand_context

            data.idx_protein_all_mask = torch.empty(0, dtype=torch.long)  # no use if has context
            data.y_protein_frontier = torch.empty(0, dtype=torch.bool)  # no use if has context

        else:  # the initial atom: surface atoms between ligand and protein
            # No import of `radius` is visible at the top of this module, so
            # it is imported here to keep this branch from raising NameError.
            from torch_geometric.nn.pool import radius

            assign_index = radius(x=ligand_masked_pos, y=protein_pos, r=4., num_workers=16)
            if assign_index.size(1) == 0:
                # Nothing within the radius: use the closest protein/ligand pair.
                dist = torch.norm(data.protein_pos.unsqueeze(1) - data.ligand_masked_pos.unsqueeze(0), p=2, dim=-1)
                assign_index = torch.nonzero(dist <= torch.min(dist) + 1e-5)[0:1].transpose(0, 1)
            idx_focal_in_protein = assign_index[0]
            data.idx_focal_in_compose = idx_focal_in_protein  # no ligand context, so all composes are protein atoms
            data.pos_generate = ligand_masked_pos[assign_index[1]]
            data.idx_generated_in_ligand_masked = torch.unique(assign_index[1])

            data.idx_protein_all_mask = data.idx_protein_in_compose  # input of initial frontier prediction
            y_protein_frontier = torch.zeros_like(data.idx_protein_all_mask,
                                                  dtype=torch.bool)  # label of initial frontier prediction
            y_protein_frontier[torch.unique(idx_focal_in_protein)] = True
            data.y_protein_frontier = y_protein_frontier

        return data
|
222 |
+
|
223 |
+
|
224 |
+
class AtomComposer(object):
    """Transform that merges ligand-context atoms and protein atoms.

    The merged ("compose") atom list places the ligand context first and the
    protein atoms after it. A k-nearest-neighbour graph is then built over the
    merged positions and annotated with ligand bond information.
    """

    def __init__(self, protein_dim, ligand_dim, knn):
        super().__init__()
        self.protein_dim = protein_dim  # width of the protein atom feature vectors
        self.ligand_dim = ligand_dim  # width of the ligand atom feature vectors
        self.knn = knn  # knn of compose atoms

    def __call__(self, data: ProteinLigandData):
        """Add the ``compose_*`` and ``idx_*_in_compose`` entries to ``data``."""
        # fetch ligand context and protein from data
        ligand_context_pos = data['ligand_context_pos']
        ligand_context_feature_full = data['ligand_context_feature_full']
        protein_pos = data['protein_pos']
        protein_atom_feature = data['protein_atom_feature']
        len_ligand_ctx = len(ligand_context_pos)
        len_protein = len(protein_pos)

        # compose ligand context and protein. save idx of them in compose
        data['compose_pos'] = torch.cat([ligand_context_pos, protein_pos], dim=0)
        len_compose = len_ligand_ctx + len_protein
        # Right-pad the ligand features with zero columns so both feature
        # matrices can be stacked (assumes ligand features have `ligand_dim`
        # columns and ligand_dim <= protein_dim - TODO confirm).
        ligand_context_feature_full_expand = torch.cat([
            ligand_context_feature_full,
            torch.zeros([len_ligand_ctx, self.protein_dim - self.ligand_dim], dtype=torch.long)
        ], dim=1)
        data['compose_feature'] = torch.cat([ligand_context_feature_full_expand, protein_atom_feature], dim=0)
        data['idx_ligand_ctx_in_compose'] = torch.arange(len_ligand_ctx, dtype=torch.long)  # can be delete
        data['idx_protein_in_compose'] = torch.arange(len_protein, dtype=torch.long) + len_ligand_ctx  # can be delete

        # build knn graph and bond type
        data = self.get_knn_graph(data, self.knn, len_ligand_ctx, len_compose, num_workers=16)
        return data

    @staticmethod
    def get_knn_graph(data: ProteinLigandData, knn, len_ligand_ctx, len_compose, num_workers=1, ):
        """Build the k-NN graph over ``compose_pos`` and label bonded edges.

        Writes ``compose_knn_edge_index``, ``compose_knn_edge_type`` and
        ``compose_knn_edge_feature`` into ``data`` and returns it.

        NOTE(review): this reads ``data['ligand_context_bond_type']``, while
        the masking transforms visible in this file store
        ``ligand_context_bond_feature`` - confirm the key is supplied upstream.
        """
        data['compose_knn_edge_index'] = knn_graph(data['compose_pos'], knn, flow='target_to_source', num_workers=num_workers)

        # Encode each (row 0, row 1) pair as one integer so k-NN edges and
        # ligand bonds can be compared. Only the first len_ligand_ctx * knn
        # edges are considered (presumably the edges that start at ligand
        # context atoms, assuming exactly `knn` edges per node in node order -
        # TODO confirm).
        id_compose_edge = data['compose_knn_edge_index'][0,
                          :len_ligand_ctx * knn] * len_compose + data['compose_knn_edge_index'][1, :len_ligand_ctx * knn]
        id_ligand_ctx_edge = data['ligand_context_bond_index'][0] * len_compose + data['ligand_context_bond_index'][1]
        # For every ligand bond: its position among the k-NN edges, or -1
        # when the bond is not part of the k-NN graph.
        idx_edge = [torch.nonzero(id_compose_edge == id_) for id_ in id_ligand_ctx_edge]
        idx_edge = torch.tensor([a.squeeze() if len(a) > 0 else torch.tensor(-1) for a in idx_edge], dtype=torch.long)
        data['compose_knn_edge_type'] = torch.zeros(len(data['compose_knn_edge_index'][0]),
                                                    dtype=torch.long)  # for encoder edge embedding
        # Edges that coincide with a ligand bond take that bond's type; all
        # other edges keep type 0.
        data['compose_knn_edge_type'][idx_edge[idx_edge >= 0]] = data['ligand_context_bond_type'][idx_edge >= 0]
        # Default feature row is [1, 0, 0, 0]; bonded edges are overwritten
        # with the 4-class one-hot encoding of their bond type.
        data['compose_knn_edge_feature'] = torch.cat([
            torch.ones([len(data['compose_knn_edge_index'][0]), 1], dtype=torch.long),
            torch.zeros([len(data['compose_knn_edge_index'][0]), 3], dtype=torch.long),
        ], dim=-1)
        data['compose_knn_edge_feature'][idx_edge[idx_edge >= 0]] = F.one_hot(data['ligand_context_bond_type'][idx_edge >= 0],
                                                                              num_classes=4)  # 0 (1,2,3)-onehot
        return data
|
275 |
+
|
276 |
+
|
277 |
+
class FeaturizeProteinAtom(object):
    """Transform that builds the per-atom feature matrix of the protein.

    Each row concatenates an element indicator over ``atomic_numbers``, a
    one-hot amino-acid type with ``max_num_aa`` classes and a backbone flag.
    """

    def __init__(self):
        super().__init__()
        # Hydrogen (1) is deliberately not part of the element vocabulary.
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 16, 34])  # C, N, O, S, Se
        self.max_num_aa = 20

    @property
    def feature_dim(self):
        """Number of feature columns: elements + amino-acid classes + backbone flag."""
        num_elements = self.atomic_numbers.size(0)
        return num_elements + self.max_num_aa + 1

    def __call__(self, data: ProteinLigandData):
        """Store the feature matrix under ``data['protein_atom_feature']``."""
        known_elements = self.atomic_numbers.view(1, -1)
        element_flags = data['protein_element'].view(-1, 1) == known_elements  # (N_atoms, N_elements)
        residue_onehot = F.one_hot(data['protein_atom_to_aa_type'], num_classes=self.max_num_aa)
        backbone_flag = data['protein_is_backbone'].view(-1, 1).long()
        data['protein_atom_feature'] = torch.cat([element_flags, residue_onehot, backbone_flag], dim=-1)
        return data
|
296 |
+
|
297 |
+
|
298 |
+
class FeaturizeLigandAtom(object):
    """Transform that builds the full per-atom feature matrix of the ligand.

    Each row concatenates an element indicator over ``atomic_numbers`` with
    the existing ``ligand_atom_feature`` columns.
    """

    def __init__(self):
        super().__init__()
        # Hydrogen (1) is deliberately not part of the element vocabulary.
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 9, 15, 16, 17])  # C N O F P S Cl

    @property
    def num_properties(self):
        """Number of atom-family properties (``len(ATOM_FAMILIES)``)."""
        return len(ATOM_FAMILIES)

    @property
    def feature_dim(self):
        """Number of feature columns: elements + atom-family properties."""
        num_elements = self.atomic_numbers.size(0)
        return num_elements + len(ATOM_FAMILIES)

    def __call__(self, data: ProteinLigandData):
        """Store the feature matrix under ``data['ligand_atom_feature_full']``."""
        known_elements = self.atomic_numbers.view(1, -1)
        element_flags = data['ligand_element'].view(-1, 1) == known_elements  # (N_atoms, N_elements)
        data['ligand_atom_feature_full'] = torch.cat([element_flags, data['ligand_atom_feature']], dim=-1)
        return data
|
318 |
+
|
319 |
+
|
320 |
+
class FeaturizeLigandBond(object):
    """Transform that one-hot encodes ligand bonds and records atom neighbours."""

    def __init__(self):
        super().__init__()

    def __call__(self, data: ProteinLigandData):
        """Write ``ligand_bond_feature`` and ``ligand_neighbors`` into ``data``."""
        # Shift bond types (1,2,3) to classes (0,1,2); the modulo folds any
        # other value into the same three classes.
        bond_class = (data['ligand_bond_type'] - 1) % 3
        data['ligand_bond_feature'] = F.one_hot(bond_class, num_classes=3)

        # Atom index -> indices of its bonded atoms, used in rotation angle
        # prediction.
        mol = data['moltree'].mol
        data['ligand_neighbors'] = {
            i: [nbr.GetIdx() for nbr in atom.GetNeighbors()]
            for i, atom in enumerate(mol.GetAtoms())
        }
        return data
|
335 |
+
|
336 |
+
|
337 |
+
class LigandCountNeighbors(object):
    """Transform that stores per-atom neighbour counts and valences of the ligand."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def count_neighbors(edge_index, symmetry, valence=None, num_nodes=None):
        """Sum a per-edge weight over the source node (row 0) of every edge.

        Args:
            edge_index: Edge index whose row 0 is used as the scatter target.
            symmetry: Must be ``True``; only symmetrical edges are supported.
            valence: Optional per-edge weights; defaults to 1 for every edge,
                which yields plain neighbour counts.
            num_nodes: Number of nodes; inferred from ``edge_index`` if omitted.

        Returns:
            Long tensor with one accumulated value per node.
        """
        assert symmetry == True, 'Only support symmetrical edges.'

        num_edges = edge_index.size(1)
        if num_nodes is None:
            num_nodes = maybe_num_nodes(edge_index)

        weights = valence
        if weights is None:
            weights = torch.ones([num_edges], device=edge_index.device)
        weights = weights.view(num_edges)

        totals = scatter_add(weights, index=edge_index[0], dim=0, dim_size=num_nodes)
        return totals.long()

    def __call__(self, data):
        """Write ``ligand_num_neighbors`` and ``ligand_atom_valence`` into ``data``."""
        bond_index = data['ligand_bond_index']
        num_atoms = data['ligand_element'].size(0)

        data['ligand_num_neighbors'] = self.count_neighbors(
            bond_index, symmetry=True, num_nodes=num_atoms
        )
        data['ligand_atom_valence'] = self.count_neighbors(
            bond_index, symmetry=True, valence=data['ligand_bond_type'], num_nodes=num_atoms
        )
        return data
|
368 |
+
|
369 |
+
|
370 |
+
class LigandRandomMask(object):
    """Transform that masks a uniformly random subset of the ligand atoms.

    The masked fraction is drawn uniformly from ``[min_ratio, max_ratio]`` and
    clipped to ``[0, 1]``; ``min_num_masked`` and ``min_num_unmasked`` are then
    enforced in that order.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0):
        super().__init__()
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.min_num_masked = min_num_masked
        self.min_num_unmasked = min_num_unmasked

    def __call__(self, data: ProteinLigandData):
        """Split the ligand into masked / context parts and annotate ``data``."""
        ratio = np.clip(random.uniform(self.min_ratio, self.max_ratio), 0.0, 1.0)
        num_atoms = data.ligand_element.size(0)

        # Enforce the lower bound on masked atoms first, then the lower bound
        # on unmasked atoms.
        num_masked = max(int(num_atoms * ratio), self.min_num_masked)
        num_masked = min(num_masked, num_atoms - self.min_num_unmasked)

        # Random atom order: the first `num_masked` atoms are masked.
        order = np.arange(num_atoms)
        np.random.shuffle(order)
        order = torch.LongTensor(order)
        masked_idx, context_idx = order[:num_masked], order[num_masked:]

        # masked part (prediction targets)
        data.ligand_masked_element = data.ligand_element[masked_idx]
        data.ligand_masked_feature = data.ligand_atom_feature[masked_idx]
        data.ligand_masked_pos = data.ligand_pos[masked_idx]

        # context part (model input)
        data.ligand_context_element = data.ligand_element[context_idx]
        data.ligand_context_feature_full = data.ligand_atom_feature_full[context_idx]
        data.ligand_context_pos = data.ligand_pos[context_idx]

        # bonds between context atoms, re-labelled to context-local indices
        context_bonds = subgraph(
            context_idx,
            data.ligand_bond_index,
            edge_attr=data.ligand_bond_feature,
            relabel_nodes=True,
        )
        data.ligand_context_bond_index, data.ligand_context_bond_feature = context_bonds
        data.ligand_context_num_neighbors = LigandCountNeighbors.count_neighbors(
            data.ligand_context_bond_index,
            symmetry=True,
            num_nodes=context_idx.size(0),
        )

        # A context atom is on the frontier when it has fewer neighbours in
        # the context than in the complete ligand.
        full_neighbor_count = data.ligand_num_neighbors[context_idx]
        data.ligand_frontier = data.ligand_context_num_neighbors < full_neighbor_count

        data._mask = 'random'

        return data
|
432 |
+
|
433 |
+
|
434 |
+
class LigandBFSMask(object):
    """Transform that masks the tail of a randomised BFS order over ligand motifs.

    The motif tree stored in ``data['moltree']`` is traversed breadth-first
    from node 0. The last ``num_masked`` motifs of that order are masked, the
    remaining ones form the ligand context. The transform also prepares
    targets for next-motif, distance-matrix, assembly and torsion prediction.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, vocab=None):
        super().__init__()
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.min_num_masked = min_num_masked
        self.min_num_unmasked = min_num_unmasked
        self.vocab = vocab
        # NOTE(review): despite the `vocab=None` default, a vocabulary is
        # required here - `None.size()` raises AttributeError.
        self.vocab_size = vocab.size()

    @staticmethod
    def get_bfs_perm_motif(moltree, vocab):
        """Return a randomised BFS order of the motifs and their focal motifs.

        Side effect: assigns ``nid`` (position in ``moltree.nodes``) and
        ``wid`` (``vocab.get_index(node.smiles)``) on every node.

        Returns:
            ``(bfs_perm, bfs_focal)`` where ``bfs_perm`` lists node ids in
            visiting order and ``bfs_focal[k]`` is the node from which
            ``bfs_perm[k + 1]`` was discovered.
        """
        for i, node in enumerate(moltree.nodes):
            node.nid = i
            node.wid = vocab.get_index(node.smiles)
        # num_motifs = len(moltree.nodes)
        bfs_queue = [0]
        bfs_perm = []
        bfs_focal = []
        visited = {bfs_queue[0]}
        while len(bfs_queue) > 0:
            current = bfs_queue.pop(0)
            bfs_perm.append(current)
            next_candid = []
            for motif in moltree.nodes[current].neighbors:
                if motif.nid in visited: continue
                next_candid.append(motif.nid)
                visited.add(motif.nid)
                bfs_focal.append(current)

            # Siblings are shuffled; they all share the same focal motif, so
            # bfs_perm and bfs_focal stay aligned.
            random.shuffle(next_candid)
            bfs_queue += next_candid

        return bfs_perm, bfs_focal

    def __call__(self, data):
        """Mask the ligand of ``data`` in BFS order and attach training targets.

        Note: ``data['ligand_pos']`` is modified in place when a rotatable
        focal motif is found (see the random rotation below).
        """
        bfs_perm, bfs_focal = self.get_bfs_perm_motif(data['moltree'], self.vocab)
        ratio = np.clip(random.uniform(self.min_ratio, self.max_ratio), 0.0, 1.0)
        num_motifs = len(bfs_perm)
        num_masked = int(num_motifs * ratio)
        if num_masked < self.min_num_masked:
            num_masked = self.min_num_masked
        if (num_motifs - num_masked) < self.min_num_unmasked:
            num_masked = num_motifs - self.min_num_unmasked
        num_unmasked = num_motifs - num_masked

        # NOTE(review): if num_masked ends up as 0, `bfs_perm[:-0]` is an
        # empty list, i.e. there would be no context motif at all.
        context_motif_ids = bfs_perm[:-num_masked]
        # Union of the atom indices of all context motifs.
        context_idx = set()
        for i in context_motif_ids:
            context_idx = context_idx | set(data['moltree'].nodes[i].clique)
        context_idx = torch.LongTensor(list(context_idx))

        if num_masked == num_motifs:
            # Nothing is unmasked yet: the current motif id is set to
            # vocab_size (presumably a start token) and the current atom is
            # the protein contact index.
            data['current_wid'] = torch.tensor([self.vocab_size])
            data['current_atoms'] = torch.tensor([data['protein_contact_idx']])
            data['next_wid'] = torch.tensor([data['moltree'].nodes[bfs_perm[-num_masked]].wid])
        else:
            data['current_wid'] = torch.tensor([data['moltree'].nodes[bfs_focal[-num_masked]].wid])
            data['next_wid'] = torch.tensor([data['moltree'].nodes[bfs_perm[-num_masked]].wid])  # For Prediction
            current_atoms = data['moltree'].nodes[bfs_focal[-num_masked]].clique
            # Positions of the focal atoms inside the context, shifted by the
            # number of protein atoms.
            # NOTE(review): this assumes protein atoms precede the ligand
            # context in the downstream atom ordering - confirm against the
            # composer that is actually used.
            data['current_atoms'] = torch.cat([torch.where(context_idx == i)[0] for i in current_atoms]) + len(data['protein_pos'])

        data['ligand_context_element'] = data['ligand_element'][context_idx]
        data['ligand_context_feature_full'] = data['ligand_atom_feature_full'][context_idx]  # For Input
        data['ligand_context_pos'] = data['ligand_pos'][context_idx]
        data['ligand_center'] = torch.mean(data['ligand_pos'], dim=0)
        data['num_atoms'] = torch.tensor([len(context_idx) + len(data['protein_pos'])])
        # distance matrix prediction
        if len(data['ligand_context_pos']) > 0:
            # Two atoms of the first motif; random.sample raises ValueError
            # if that motif has fewer than two atoms.
            sample_idx = random.sample(data['moltree'].nodes[bfs_perm[0]].clique, 2)
            data['dm_ligand_idx'] = torch.cat([torch.where(context_idx == i)[0] for i in sample_idx])
            # The four protein atoms closest to the first sampled ligand atom.
            data['dm_protein_idx'] = torch.sort(torch.norm(data['protein_pos'] - data['ligand_context_pos'][data['dm_ligand_idx'][0]], dim=-1)).indices[:4]
            data['true_dm'] = torch.norm(data['protein_pos'][data['dm_protein_idx']].unsqueeze(1) - data['ligand_context_pos'][data['dm_ligand_idx']].unsqueeze(0), dim=-1).reshape(-1)
        else:
            data['true_dm'] = torch.tensor([])

        data['protein_alpha_carbon_index'] = torch.tensor([i for i, name in enumerate(data['protein_atom_name']) if name =="CA"])
        data['alpha_carbon_indicator'] = torch.tensor([True if name =="CA" else False for name in data['protein_atom_name']])

        # assemble prediction
        data['protein_contact'] = torch.tensor(data['protein_contact'])
        if len(context_motif_ids) > 0:
            # Candidate attachments of the next motif onto the focal motif.
            cand_labels, cand_mols = enumerate_assemble(data['moltree'].mol, context_idx.tolist(),
                                                        data['moltree'].nodes[bfs_focal[-num_masked]],
                                                        data['moltree'].nodes[bfs_perm[-num_masked]])
            data['cand_labels'] = cand_labels
            data['cand_mols'] = [mol_to_graph_data_obj_simple(mol) for mol in cand_mols]
        else:
            data['cand_labels'], data['cand_mols'] = torch.tensor([]), []

        # Bonds between context atoms, re-labelled to context-local indices.
        data['ligand_context_bond_index'], data['ligand_context_bond_feature'] = subgraph(
            context_idx,
            data['ligand_bond_index'],
            edge_attr=data['ligand_bond_feature'],
            relabel_nodes=True,
        )
        data['ligand_context_num_neighbors'] = LigandCountNeighbors.count_neighbors(
            data['ligand_context_bond_index'],
            symmetry=True,
            num_nodes=context_idx.size(0),
        )
        # Frontier: context atoms with fewer neighbours in the context than
        # in the complete ligand.
        data['ligand_frontier'] = data['ligand_context_num_neighbors'] < data['ligand_num_neighbors'][context_idx]
        data['_mask'] = 'bfs'

        # find a rotatable bond as the current motif
        rotatable_ids = []
        for i, id in enumerate(bfs_focal):
            if data['moltree'].nodes[id].rotatable:
                rotatable_ids.append(i)
        if len(rotatable_ids) == 0:
            # assign empty tensor
            data['ligand_torsion_xy_index'] = torch.tensor([])
            data['dihedral_mask'] = torch.tensor([]).bool()
            data['ligand_element_torsion'] = torch.tensor([])
            data['ligand_pos_torsion'] = torch.tensor([])
            data['ligand_feature_torsion'] = torch.tensor([])
            data['true_sin'], data['true_cos'], data['true_three_hop'] = torch.tensor([]), torch.tensor([]), torch.tensor([])
            data['xn_pos'], data['yn_pos'], data['y_pos'] = torch.tensor([]), torch.tensor([]), torch.tensor([])
        else:
            # `num_unmasked` is reused here as a position in bfs_focal: the
            # randomly chosen rotatable focal motif (independent of the
            # masking split above).
            num_unmasked = random.sample(rotatable_ids, 1)[0]
            current_idx = torch.LongTensor(data['moltree'].nodes[bfs_focal[num_unmasked]].clique)
            next_idx = torch.LongTensor(data['moltree'].nodes[bfs_perm[num_unmasked + 1]].clique)
            current_idx_set = set(data['moltree'].nodes[bfs_focal[num_unmasked]].clique)
            next_idx_set = set(data['moltree'].nodes[bfs_perm[num_unmasked + 1]].clique)
            # All atoms of the motifs visited up to and including the next motif.
            all_idx = set()
            for i in bfs_perm[:num_unmasked + 2]:
                all_idx = all_idx | set(data['moltree'].nodes[i].clique)
            all_idx = list(all_idx)
            # x is an atom shared by the focal and the next motif, y is
            # another atom of the focal motif (presumably the rotatable motif
            # is a two-atom bond - TODO confirm).
            x_id = current_idx_set.intersection(next_idx_set).pop()
            y_id = (current_idx_set - {x_id}).pop()
            data['ligand_torsion_xy_index'] = torch.cat([torch.where(torch.LongTensor(all_idx) == i)[0] for i in [x_id, y_id]])

            x_pos, y_pos = deepcopy(data['ligand_pos'][x_id]), deepcopy(data['ligand_pos'][y_id])
            # remove x, y, and non-generated elements
            xn, yn = deepcopy(data['ligand_neighbors'][x_id]), deepcopy(data['ligand_neighbors'][y_id])
            xn.remove(y_id)
            yn.remove(x_id)
            xn, yn = xn[:3], yn[:3]
            # debug
            xn, yn = list_filter(xn, all_idx), list_filter(yn, all_idx)
            # Up to three neighbours per side, zero-padded to a fixed 3 x 3.
            xn_pos, yn_pos = torch.zeros(3, 3), torch.zeros(3, 3)
            xn_pos[:len(xn)], yn_pos[:len(yn)] = deepcopy(data['ligand_pos'][xn]), deepcopy(data['ligand_pos'][yn])
            # All 3 x 3 combinations of an x-neighbour with a y-neighbour.
            xn_idx, yn_idx = torch.cartesian_prod(torch.arange(3), torch.arange(3)).chunk(2, dim=-1)
            xn_idx = xn_idx.squeeze(-1)
            yn_idx = yn_idx.squeeze(-1)
            dihedral_x, dihedral_y = torch.zeros(3), torch.zeros(3)
            dihedral_x[:len(xn)] = 1
            dihedral_y[:len(yn)] = 1
            # True only for combinations in which both neighbours exist.
            data['dihedral_mask'] = torch.matmul(dihedral_x.view(3, 1), dihedral_y.view(1, 3)).view(-1).bool()
            data['true_sin'], data['true_cos'] = batch_dihedrals(xn_pos[xn_idx], x_pos.repeat(9, 1), y_pos.repeat(9, 1),
                                                                 yn_pos[yn_idx])
            data['true_three_hop'] = torch.linalg.norm(xn_pos[xn_idx] - yn_pos[yn_idx], dim=-1)[data['dihedral_mask']]

            # random rotate to simulate the inference situation
            dir = data['ligand_pos'][current_idx[0]] - data['ligand_pos'][current_idx[1]]
            ref = deepcopy(data['ligand_pos'][current_idx[0]])
            next_motif_pos = deepcopy(data['ligand_pos'][next_idx])
            # Overwrites the stored positions of the next motif in place.
            data['ligand_pos'][next_idx] = rand_rotate(dir, ref, next_motif_pos)

            data['ligand_element_torsion'] = data['ligand_element'][all_idx]
            data['ligand_pos_torsion'] = data['ligand_pos'][all_idx]
            data['ligand_feature_torsion'] = data['ligand_atom_feature_full'][all_idx]

            # Positions relative to x, read after the random rotation.
            x_pos = deepcopy(data['ligand_pos'][x_id])
            data['y_pos'] = data['ligand_pos'][y_id] - x_pos
            data['xn_pos'], data['yn_pos'] = torch.zeros(3, 3), torch.zeros(3, 3)
            data['xn_pos'][:len(xn)], data['yn_pos'][:len(yn)] = data['ligand_pos'][xn] - x_pos, data['ligand_pos'][yn] - x_pos

        return data
|
604 |
+
|
605 |
+
|
606 |
+
class LigandMaskAll(LigandBFSMask):
    """``LigandBFSMask`` configured with ``min_ratio`` fixed at 1.0.

    NOTE(review): presumably this masks the entire ligand; confirm against
    the ``LigandBFSMask`` implementation.

    Args:
        vocab: Vocabulary object forwarded unchanged to ``LigandBFSMask``.
    """

    def __init__(self, vocab):
        super(LigandMaskAll, self).__init__(vocab=vocab, min_ratio=1.0)
|
610 |
+
|
611 |
+
|
612 |
+
class LigandMixedMask:
    """Apply one of three ligand masking transforms, chosen at random per call.

    The candidates are a random mask, a BFS mask and an inverse BFS mask, all
    built from the same ratio / count settings.

    Args:
        min_ratio, max_ratio, min_num_masked, min_num_unmasked: Forwarded
            positionally to every underlying mask transform.
        p_random: Sampling weight of the ``LigandRandomMask`` transform.
        p_bfs: Sampling weight of the forward ``LigandBFSMask`` transform.
        p_invbfs: Sampling weight of the inverse ``LigandBFSMask`` transform.
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, p_random=0.5, p_bfs=0.25,
                 p_invbfs=0.25):
        super().__init__()

        # Settings shared by all three candidate transforms.
        shared = (min_ratio, max_ratio, min_num_masked, min_num_unmasked)
        random_mask = LigandRandomMask(*shared)
        bfs_mask = LigandBFSMask(*shared, inverse=False)
        inverse_bfs_mask = LigandBFSMask(*shared, inverse=True)

        # ``self.t`` and ``self.p`` are index-aligned: transform i has weight i.
        self.t = [random_mask, bfs_mask, inverse_bfs_mask]
        self.p = [p_random, p_bfs, p_invbfs]

    def __call__(self, data):
        """Draw one transform according to ``self.p`` and apply it to ``data``."""
        (chosen,) = random.choices(self.t, weights=self.p, k=1)
        return chosen(data)
|
628 |
+
|
629 |
+
|
630 |
+
def get_mask(cfg, vocab):
    """Build the ligand masking transform selected by ``cfg.type``.

    Only the attributes needed by the selected branch are read from ``cfg``,
    so e.g. an ``'all'`` config does not need ratio fields.

    Args:
        cfg: Config object with a ``type`` attribute in
            ``{'bfs', 'random', 'mixed', 'all'}``. ``'bfs'``, ``'random'`` and
            ``'mixed'`` additionally read ``min_ratio``, ``max_ratio``,
            ``min_num_masked`` and ``min_num_unmasked``; ``'mixed'`` also reads
            ``p_random``, ``p_bfs`` and ``p_invbfs``.
        vocab: Vocabulary passed to the ``'bfs'`` and ``'all'`` transforms
            (the ``'random'`` and ``'mixed'`` transforms do not receive it).

    Returns:
        The constructed mask transform.

    Raises:
        NotImplementedError: If ``cfg.type`` is not one of the known kinds.
    """
    if cfg.type == 'bfs':
        return LigandBFSMask(
            min_ratio=cfg.min_ratio,
            max_ratio=cfg.max_ratio,
            min_num_masked=cfg.min_num_masked,
            min_num_unmasked=cfg.min_num_unmasked,
            vocab=vocab
        )
    elif cfg.type == 'random':
        return LigandRandomMask(
            min_ratio=cfg.min_ratio,
            max_ratio=cfg.max_ratio,
            min_num_masked=cfg.min_num_masked,
            min_num_unmasked=cfg.min_num_unmasked,
        )
    elif cfg.type == 'mixed':
        return LigandMixedMask(
            min_ratio=cfg.min_ratio,
            max_ratio=cfg.max_ratio,
            min_num_masked=cfg.min_num_masked,
            min_num_unmasked=cfg.min_num_unmasked,
            p_random=cfg.p_random,
            p_bfs=cfg.p_bfs,
            p_invbfs=cfg.p_invbfs,
        )
    elif cfg.type == 'all':
        # LigandMaskAll.__init__ takes a required ``vocab`` argument; calling
        # it with no arguments raised TypeError.
        return LigandMaskAll(vocab)
    else:
        raise NotImplementedError('Unknown mask: %s' % cfg.type)
|
660 |
+
|
661 |
+
|
662 |
+
def kabsch(A, B):
    """Estimate the rigid transform (rotation + translation) mapping B onto A.

    Uses the SVD of the cross-covariance of the centred point sets (Kabsch
    algorithm), with the usual sign flip so that a proper rotation
    (``det(R) > 0``) is returned instead of a reflection.

    Args:
        A: Nominal points, an N x 3 ``numpy.ndarray`` (or ``numpy.matrix``).
        B: Measured points, N x 3, where row ``i`` corresponds to row ``i``
            of ``A``.

    Returns:
        Tuple ``(R, t)``:
            R: 3 x 3 rotation matrix (B to A).
            t: translation vector (B to A), so that ``R @ b + t`` approximates
               the matching point of ``A``. Shape ``(3,)`` for ndarray inputs
               and 3 x 1 for ``numpy.matrix`` inputs.

    Raises:
        AssertionError: If ``A`` and ``B`` contain a different number of points.
    """
    assert len(A) == len(B)
    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)
    # Center the points (broadcasting subtracts the centroid from every row).
    AA = A - centroid_A
    BB = B - centroid_B
    # ``@`` is a true matrix product for both ndarray and np.matrix; the
    # previous ``*`` was element-wise on ndarrays and gave wrong results.
    H = np.transpose(BB) @ AA
    U, _, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    # Special reflection case: flip the last singular direction.
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T
    t = -R @ centroid_B.T + centroid_A.T
    return R, t
|
utils/warmup.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
MIT License
|
3 |
+
|
4 |
+
Copyright (c) 2019 Ildoo Kim
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
of this software and associated documentation files (the "Software"), to deal
|
8 |
+
in the Software without restriction, including without limitation the rights
|
9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
copies of the Software, and to permit persons to whom the Software is
|
11 |
+
furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
23 |
+
"""
|
24 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
25 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
26 |
+
|
27 |
+
|
28 |
+
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate of an optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: Must be >= 1.0. If ``multiplier > 1.0`` the learning rate
            grows linearly from ``base_lr`` to ``base_lr * multiplier``. If
            ``multiplier == 1.0`` it grows linearly from 0 to ``base_lr``.
        total_epoch: Epoch at which the target learning rate is reached.
        after_scheduler: Scheduler used once ``total_epoch`` has passed
            (e.g. ``ReduceLROnPlateau``). Optional.

    Raises:
        ValueError: If ``multiplier`` is smaller than 1.0.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Set to True the first time control is handed to ``after_scheduler``.
        self.finished = False
        # The base-class constructor performs an initial ``step()``, so every
        # attribute used by ``step``/``get_lr`` must be assigned before it.
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def _warmup_lrs(self):
        """Return the warm-up learning rates for the current ``last_epoch``."""
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def get_lr(self):
        """Compute the learning rates for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand over: the follow-up scheduler starts from the
                    # warmed-up (target) learning rates.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return self._warmup_lrs()

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Step variant used when ``after_scheduler`` is a ``ReduceLROnPlateau``.

        During warm-up the optimizer's learning rates are written directly;
        afterwards ``metrics`` is forwarded to ``after_scheduler``.

        Args:
            metrics: Monitored value passed to ``after_scheduler.step``.
            epoch: Optional explicit epoch index. When omitted, the internal
                counter is advanced by one.
        """
        # Remember whether the caller supplied an epoch *before* defaulting
        # it; previously the later ``epoch is None`` check could never be true.
        explicit_epoch = epoch is not None
        if not explicit_epoch:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            # Same formula as ``get_lr`` so that ``multiplier == 1.0`` also
            # ramps up from 0 on this code path.
            for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lrs()):
                param_group['lr'] = lr
        elif explicit_epoch:
            self.after_scheduler.step(metrics, epoch - self.total_epoch)
        else:
            # No epoch given: let the plateau scheduler keep its own counter
            # instead of passing it the deprecated ``epoch`` argument.
            self.after_scheduler.step(metrics)

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one step.

        Args:
            epoch: Optional explicit epoch index.
            metrics: Monitored value; only used when ``after_scheduler`` is a
                ``ReduceLROnPlateau``.
        """
        if isinstance(self.after_scheduler, ReduceLROnPlateau):
            self.step_ReduceLROnPlateau(metrics, epoch)
            return None

        if self.finished and self.after_scheduler:
            # Warm-up is over: delegate to the follow-up scheduler, shifting
            # an explicit epoch so that it counts from the end of warm-up.
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
            self._last_lr = self.after_scheduler.get_last_lr()
            return None

        return super(GradualWarmupScheduler, self).step(epoch)
|
vocab.txt
ADDED
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CC:108150
|
2 |
+
CN:59667
|
3 |
+
CO:39300
|
4 |
+
C=O:36148
|
5 |
+
C1=CC=CC=C1:17649
|
6 |
+
OP:7954
|
7 |
+
O=S:5180
|
8 |
+
CF:4607
|
9 |
+
CS:4246
|
10 |
+
C[NH3+]:3561
|
11 |
+
O=P:3006
|
12 |
+
CCl:2484
|
13 |
+
C[NH+]:2321
|
14 |
+
C=N:2305
|
15 |
+
C1CCNC1:2115
|
16 |
+
[H]N:2073
|
17 |
+
C1CCOCC1:1957
|
18 |
+
C1=CC=NC=C1:1892
|
19 |
+
C1=CN=CN=C1:1875
|
20 |
+
NS:1824
|
21 |
+
C1CCOC1:1578
|
22 |
+
C[NH2+]:1291
|
23 |
+
C1CCCCC1:1209
|
24 |
+
C=C:1202
|
25 |
+
C1=CNC=N1:1184
|
26 |
+
C[N+]:1066
|
27 |
+
C1=CNN=C1:676
|
28 |
+
C1CCNCC1:670
|
29 |
+
CP:662
|
30 |
+
C1=CCCC=C1:628
|
31 |
+
OS:624
|
32 |
+
C1=CSC=C1:614
|
33 |
+
C1CCCC1:531
|
34 |
+
C#N:481
|
35 |
+
NO:477
|
36 |
+
C1=CSC=N1:463
|
37 |
+
CBr:437
|
38 |
+
C1=CNCNC1:436
|
39 |
+
C1CC1:434
|
40 |
+
C1=CCCCC1:423
|
41 |
+
C1=CNC=C1:391
|
42 |
+
C1CC[NH+]CC1:352
|
43 |
+
C1=CCNC=C1:319
|
44 |
+
[N+]=O:315
|
45 |
+
C1=CN=CNC1:313
|
46 |
+
C1=NCCN1:311
|
47 |
+
[N+][O-]:310
|
48 |
+
C1=CNCC1:300
|
49 |
+
C1CC[NH2+]CC1:274
|
50 |
+
C1CC[NH2+]C1:264
|
51 |
+
BO:260
|
52 |
+
C1C[NH+]CCN1:247
|
53 |
+
C1=CNN=N1:226
|
54 |
+
C1=COC=C1:218
|
55 |
+
C1CC[NH+]C1:210
|
56 |
+
C#C:208
|
57 |
+
C1COCCN1:206
|
58 |
+
C1=CN=CC=N1:184
|
59 |
+
C1=CCCN=C1:179
|
60 |
+
C1=CCNCC1:176
|
61 |
+
CI:175
|
62 |
+
C1CNCN1:171
|
63 |
+
C1CNCCN1:165
|
64 |
+
C1=COCCC1:165
|
65 |
+
C1=CON=C1:162
|
66 |
+
C1COCC[NH+]1:138
|
67 |
+
BC:137
|
68 |
+
C1=CCNC1:131
|
69 |
+
C1CNCNC1:129
|
70 |
+
C1=CNCN=C1:127
|
71 |
+
NP:124
|
72 |
+
C1=NN=CN1:124
|
73 |
+
C1=NC=NN1:123
|
74 |
+
C1C[NH+]CC[NH+]1:118
|
75 |
+
C=S:117
|
76 |
+
C1=CCCC1:112
|
77 |
+
C1=NCCCN1:107
|
78 |
+
C1=COCC1:107
|
79 |
+
C1=NCCS1:102
|
80 |
+
C1=NNCC1:100
|
81 |
+
C1CSCN1:98
|
82 |
+
NN:95
|
83 |
+
C1=CNCCC1:94
|
84 |
+
C1=CCNN=C1:90
|
85 |
+
C1=COC=N1:81
|
86 |
+
C1COCO1:78
|
87 |
+
C1=COCO1:78
|
88 |
+
C1=CCOCC1:75
|
89 |
+
C1=NC=NCC1:73
|
90 |
+
C1=CNCC=N1:72
|
91 |
+
C1=NN=CS1:70
|
92 |
+
C1CCSC1:66
|
93 |
+
C1CCC1:65
|
94 |
+
C1C[NH2+]CCN1:63
|
95 |
+
C1=CSCC1:62
|
96 |
+
C1=COC=CC1:62
|
97 |
+
C1=NC=NC=N1:60
|
98 |
+
C1COCN1:59
|
99 |
+
N=N:57
|
100 |
+
C1=CNCN1:55
|
101 |
+
C1=CNCCN1:55
|
102 |
+
C1=CNC=NC1:54
|
103 |
+
C1CCCNCC1:51
|
104 |
+
C1=CN=CCC1:51
|
105 |
+
C1=CC1:51
|
106 |
+
C1C2CC3CC1CC(C2)C3:50
|
107 |
+
C1=CCC=CC1:50
|
108 |
+
C:N:49
|
109 |
+
C1=CCOC=C1:48
|
110 |
+
[NH2+]O:47
|
111 |
+
C1=NCNCC1:46
|
112 |
+
C1CNC1:44
|
113 |
+
C1CNCCNC1:42
|
114 |
+
C1=CNC=CC1:40
|
115 |
+
C1CO1:39
|
116 |
+
C1=CC[NH+]CC1:39
|
117 |
+
N=O:37
|
118 |
+
C1=CSCN1:35
|
119 |
+
[NH+]N:34
|
120 |
+
C1=NC=NO1:34
|
121 |
+
C1=CNCCNC1:32
|
122 |
+
C1=CC[NH2+]CC1:30
|
123 |
+
[N+]=[N-]:29
|
124 |
+
[N+]=N:29
|
125 |
+
O[PH]:29
|
126 |
+
C1CNCSC1:29
|
127 |
+
C1CCCCCC1:29
|
128 |
+
C1=NN=NN1:27
|
129 |
+
C1COPOC1:26
|
130 |
+
C1CCSCC1:26
|
131 |
+
C1=NNCCC1:26
|
132 |
+
C1=CCCCCC1:26
|
133 |
+
C1C[Fe]1:25
|
134 |
+
C1=NN=CO1:25
|
135 |
+
C1=NCCN=C1:25
|
136 |
+
C1=CSC=[N+]1:25
|
137 |
+
PS:24
|
138 |
+
C1CNCOC1:24
|
139 |
+
C1CCC[NH2+]CC1:24
|
140 |
+
C1=NCCO1:24
|
141 |
+
C1=COCCN1:24
|
142 |
+
C1=CNC=[N+]1:24
|
143 |
+
C1=CCN=CC1:24
|
144 |
+
C[PH]:23
|
145 |
+
C1COCOC1:23
|
146 |
+
C1CNCC[NH2+]C1:23
|
147 |
+
C1=CONC1:23
|
148 |
+
C1CNCCOC1:22
|
149 |
+
C1=NON=C1:21
|
150 |
+
C1=NNCN1:21
|
151 |
+
O:S:20
|
152 |
+
C[Si]:20
|
153 |
+
C1C[NH2+]CC[NH+]1:19
|
154 |
+
C1C[NH+]C1:19
|
155 |
+
C1=[N+]CCC1:19
|
156 |
+
C1=NOCC1:19
|
157 |
+
C1NO1:18
|
158 |
+
C1=CNCCN=C1:18
|
159 |
+
SS:17
|
160 |
+
C1CSC[NH+]1:17
|
161 |
+
C1COCCO1:17
|
162 |
+
C1=CCOC1:17
|
163 |
+
C[Se]:16
|
164 |
+
C[AsH]:16
|
165 |
+
C1CSC[NH2+]1:16
|
166 |
+
C1CSCCN1:16
|
167 |
+
C1CCCCCCC1:16
|
168 |
+
C1=NC=NCN1:16
|
169 |
+
C1=CSN=C1:16
|
170 |
+
C1=CSCCC1:16
|
171 |
+
C1=COCCO1:16
|
172 |
+
C1=CCSC1:16
|
173 |
+
C1=CCN=C1:16
|
174 |
+
C1COC1:15
|
175 |
+
C1=NCCNC1:15
|
176 |
+
C1=NC=NNC1:15
|
177 |
+
C1=C[Fe]1:15
|
178 |
+
C1=CSNCC1:15
|
179 |
+
C1=CN[C@H]2CCCC(N1)O2:15
|
180 |
+
C1=CC=[N+]C=C1:15
|
181 |
+
[NH+]O:14
|
182 |
+
C=[N+]:14
|
183 |
+
C1CCNCNC1:14
|
184 |
+
C1=NNN=C1:14
|
185 |
+
C1=CNNC1:14
|
186 |
+
C1=CN=NC=C1:14
|
187 |
+
C1=CC=NN=C1:14
|
188 |
+
C1C[C@@H]2CC[C@H](C1)[NH+]2:13
|
189 |
+
C1CNN=N1:13
|
190 |
+
C1=NCCCC1:13
|
191 |
+
C1C[Ru]1:12
|
192 |
+
C1=NNN=N1:12
|
193 |
+
[N+]N:11
|
194 |
+
C1CNSC1:11
|
195 |
+
C1CC[NH2+]NC1:11
|
196 |
+
C1CC2NCCN[C@@H](C1)O2:11
|
197 |
+
C1=NCNC1:11
|
198 |
+
C1=NC=NC1:11
|
199 |
+
C1=CCCCC=C1:11
|
200 |
+
C1=CCC=C1:11
|
201 |
+
O=[SH]:10
|
202 |
+
C1CSNCN1:10
|
203 |
+
C1=CSCCN1:10
|
204 |
+
C1=CC=CCC=C1:10
|
205 |
+
C1C[NH2+]C1:9
|
206 |
+
C1C[NH+]2CCC1CC2:9
|
207 |
+
C1CC[SH]C1:9
|
208 |
+
C1=[N+]CCN1:9
|
209 |
+
C1=CSNC1:9
|
210 |
+
C1CN[Fe]NC1:8
|
211 |
+
C1CNCC[NH+]C1:8
|
212 |
+
C1=NNCCN1:8
|
213 |
+
C1=CSN=N1:8
|
214 |
+
C1=CNCCCN1:8
|
215 |
+
C1=CNC=CN1:8
|
216 |
+
C1=CC[NH+]C1:8
|
217 |
+
C1=CCSCC1:8
|
218 |
+
C1=CCNNC1:8
|
219 |
+
C1NCON1:7
|
220 |
+
C1COCC[NH2+]1:7
|
221 |
+
C1CNSN1:7
|
222 |
+
C1CNNC1:7
|
223 |
+
C1CCC[NH+]CC1:7
|
224 |
+
C1=[N+]O[Fe]O1:7
|
225 |
+
C1=C[Ru]1:7
|
226 |
+
C1=C[N+][Co][N+]=C1:7
|
227 |
+
C1=CNN=CC1:7
|
228 |
+
[O-]P:6
|
229 |
+
O:P:6
|
230 |
+
N[NH3+]:6
|
231 |
+
N[NH2+]:6
|
232 |
+
FS:6
|
233 |
+
C[SeH]:6
|
234 |
+
C1NO[Fe]O1:6
|
235 |
+
C1C[C@@H]2CC[C@H](C1)N2:6
|
236 |
+
C1CSCCO1:6
|
237 |
+
C1CC[N+]C1:6
|
238 |
+
C1=NSN=C1:6
|
239 |
+
C1=NCCC1:6
|
240 |
+
C1=COCN1:6
|
241 |
+
C1=CNNCC1:6
|
242 |
+
C1=CNCSC1:6
|
243 |
+
C1=CNC=CNC1:6
|
244 |
+
C1=CN=NC1:6
|
245 |
+
C1=CC[N+]=C1:6
|
246 |
+
[N+]O:5
|
247 |
+
P=S:5
|
248 |
+
C1[NH+]O1:5
|
249 |
+
C1C[Rh]1:5
|
250 |
+
C1C[C@@H]2CC[C@H](C1)[N+]2:5
|
251 |
+
C1C[C@@H]2CCNC[C@H](C1)N2:5
|
252 |
+
C1CN[NH2+]C1:5
|
253 |
+
C1CC[SH]CC1:5
|
254 |
+
C1CCNSCC1:5
|
255 |
+
C1CCNNC1:5
|
256 |
+
C1=N[N+]=CS1:5
|
257 |
+
C1=NSCCN1:5
|
258 |
+
C1=NCCNCC1:5
|
259 |
+
C1=NC=NS1:5
|
260 |
+
C1=C\CCCCCC/1:5
|
261 |
+
C1=C[Fe]C1:5
|
262 |
+
C1=CSNCN1:5
|
263 |
+
C1=COCCCN1:5
|
264 |
+
C1=CN[Co][N+]=C1:5
|
265 |
+
C1=CNCCCC1:5
|
266 |
+
C1=CN=NC=N1:5
|
267 |
+
C1=CCCNCC1:5
|
268 |
+
[NH3+]O:4
|
269 |
+
[NH+][NH2+]:4
|
270 |
+
O[Si]:4
|
271 |
+
C1N[NH2+]CS1:4
|
272 |
+
C1NCN[Fe]N1:4
|
273 |
+
C1NCNN1:4
|
274 |
+
C1C[NH2+]1:4
|
275 |
+
C1C[NH+]CC[NH+]C1:4
|
276 |
+
C1C[C@H]2CC[C@@H]1C2:4
|
277 |
+
C1C[C@@H]2CC[C@H](C1)[NH2+]2:4
|
278 |
+
C1COCC[N+]1:4
|
279 |
+
C1CNOC1:4
|
280 |
+
C1CCO[NH2+]C1:4
|
281 |
+
C1CC2CCC1C2:4
|
282 |
+
C1=[N+]NCC1:4
|
283 |
+
C1=NNCNC1:4
|
284 |
+
C1=NCCSN1:4
|
285 |
+
C1=C[N+][Mg][N+]=C1:4
|
286 |
+
C1=CSNC=N1:4
|
287 |
+
C1=CN=NCC1:4
|
288 |
+
C1=CCNCCC1:4
|
289 |
+
C1=CCCOCC1:4
|
290 |
+
C1=CCCOC=C1:4
|
291 |
+
C1=CCC1:4
|
292 |
+
C1=C2C[N@@H+]3CC[C@]45CCN6CC[C@H](OC1)[C@@H]([C@H]64)[C@H]2C[C@H]35:4
|
293 |
+
B1CCC=CO1:4
|
294 |
+
O=[V]:3
|
295 |
+
O=[Sb]:3
|
296 |
+
N=S:3
|
297 |
+
C1NC[C@H]2CNC[C@@H]1C2:3
|
298 |
+
C1C[C@@H]2CC[C@H](C1)C2:3
|
299 |
+
C1COSO1:3
|
300 |
+
C1COC[NH2+]1:3
|
301 |
+
C1CN[NH+]C1:3
|
302 |
+
C1CN[Co][N+]1:3
|
303 |
+
C1CNCC[N+]1:3
|
304 |
+
C1CNCCSC1:3
|
305 |
+
C1CN=[N+]C1:3
|
306 |
+
C1CC[NH+]NC1:3
|
307 |
+
C1CC[NH+]CCNC1:3
|
308 |
+
C1CCSNC1:3
|
309 |
+
C1CCNSNC1:3
|
310 |
+
C1CC2C[NH+][C@@H]3C4CCC[C@]3(C1)[C@@H]2C4:3
|
311 |
+
C1CC2CCC1CC2:3
|
312 |
+
C1=[N+]CCS1:3
|
313 |
+
C1=[N+]CCCN1:3
|
314 |
+
C1=[N+]CCCCC1:3
|
315 |
+
C1=NSCC1:3
|
316 |
+
C1=NNCSC1:3
|
317 |
+
C1=NNCS1:3
|
318 |
+
C1=NCNN=C1:3
|
319 |
+
C1=NCCCCC1:3
|
320 |
+
C1=C[Ru]C1:3
|
321 |
+
C1=C[Rh]1:3
|
322 |
+
C1=C[N+]=CC1:3
|
323 |
+
C1=C[C@H]2COCC(C1)C2:3
|
324 |
+
C1=C[C@H]2CC[C@@H]1N2:3
|
325 |
+
C1=C[C@H]2CC[C@@H]1C2:3
|
326 |
+
C1=C[C@@H]2CC=C[C@H](C1)C2:3
|
327 |
+
C1=COCOC1:3
|
328 |
+
C1=COCCNC1:3
|
329 |
+
C1=CNSCCN1:3
|
330 |
+
C1=CNSC=C1:3
|
331 |
+
C1=CNSC1:3
|
332 |
+
C1=CNC=CCC1:3
|
333 |
+
C1=CCCCN=C1:3
|
334 |
+
C1=CCCCCCC1:3
|
335 |
+
C1=CCCC=CC1:3
|
336 |
+
C1=CC2CCO[C@@H](C1)C2:3
|
337 |
+
B1C=CCO1:3
|
338 |
+
[O-]S:2
|
339 |
+
OO:2
|
340 |
+
C1OCO1:2
|
341 |
+
C1C[NH2+]C[NH2+]C1:2
|
342 |
+
C1C[NH2+]CN1:2
|
343 |
+
C1C[NH+][NH2+]C1:2
|
344 |
+
C1C[C@H]2C[C@H](CCO2)N1:2
|
345 |
+
C1C[C@H]2COC[C@@H]1C2:2
|
346 |
+
C1C[C@@H]2C[C@H]1CN2:2
|
347 |
+
C1C[C@@H]2CNC[C@H](C1)N2:2
|
348 |
+
C1CSN=N1:2
|
349 |
+
C1CNC[NH2+]C1:2
|
350 |
+
C1CN1:2
|
351 |
+
C1CC[NH2+][NH2+]C1:2
|
352 |
+
C1CC[N+]CC1:2
|
353 |
+
C1CCSCNC1:2
|
354 |
+
C1CCCNCCC1:2
|
355 |
+
C1CC2OC[C@@H](C1)O2:2
|
356 |
+
C1CC2COC(C1)C2:2
|
357 |
+
C1CC2CC[C@H](C1)[NH2+]2:2
|
358 |
+
C1CC2CCO[Fe](O1)OC2:2
|
359 |
+
C1CC2CCC(O1)O2:2
|
360 |
+
C1CC2CCC(C1)C2:2
|
361 |
+
C1=[N+]CNN1:2
|
362 |
+
C1=N[N+]=CCC1:2
|
363 |
+
C1=NSNC1:2
|
364 |
+
C1=NNCO1:2
|
365 |
+
C1=NN=CCC1:2
|
366 |
+
C1=NN=CC1:2
|
367 |
+
C1=NCSCC1:2
|
368 |
+
C1=NCC=[N+]1:2
|
369 |
+
C1=NC=NN=C1:2
|
370 |
+
C1=C[N+][Zn]NC1:2
|
371 |
+
C1=C[N+]C=CC1:2
|
372 |
+
C1=C[C@H]2C[NH2+]C[C@@H]1C2:2
|
373 |
+
C1=C[C@H]2C[NH2+]C[C@@H](C1)[NH2+]2:2
|
374 |
+
C1=C[C@H]2CC=CC(C1)C2:2
|
375 |
+
C1=CSNCCO1:2
|
376 |
+
C1=CSC=CN1:2
|
377 |
+
C1=COSNC1:2
|
378 |
+
C1=COCCCNC1:2
|
379 |
+
C1=COCC=N1:2
|
380 |
+
C1=CN[Zn][N+]=C1:2
|
381 |
+
C1=CN[C@H]2CCC(N1)O2:2
|
382 |
+
C1=CNSNC1:2
|
383 |
+
C1=CNSCCC1:2
|
384 |
+
C1=CNNC=C1:2
|
385 |
+
C1=CNC[NH2+]C1:2
|
386 |
+
C1=CNCCOC1:2
|
387 |
+
C1=CNCCCNC1:2
|
388 |
+
C1=CN=[N+]C1:2
|
389 |
+
C1=CN=CNCC1:2
|
390 |
+
C1=CN=CCN=C1:2
|
391 |
+
C1=CN=CC1:2
|
392 |
+
C1=CC[NH2+]NC1:2
|
393 |
+
C1=CC[NH2+]C1:2
|
394 |
+
C1=CC[NH+]NC1:2
|
395 |
+
C1=CC[NH+]CCC1:2
|
396 |
+
C1=CC[N+]CC1:2
|
397 |
+
C1=CCOC=CC1:2
|
398 |
+
C1=CCN=CCC1:2
|
399 |
+
C1=CCCNC=C1:2
|
400 |
+
C1=CC2CC(C1)CO2:2
|
401 |
+
B1OCCO1:2
|
402 |
+
B1CCCO1:2
|
403 |
+
[N+][NH+]:1
|
404 |
+
O[Fe]:1
|
405 |
+
O1POPOPOP1:1
|
406 |
+
N1OO1:1
|
407 |
+
C[TeH]:1
|
408 |
+
C[Sb]:1
|
409 |
+
C[Ru]:1
|
410 |
+
C[O-]:1
|
411 |
+
C[AsH2]:1
|
412 |
+
C1O[C@@H]2COC1O2:1
|
413 |
+
C1N[C@@H]2CN[C@H]1C2:1
|
414 |
+
C1NN=NN1:1
|
415 |
+
C1NC[C@H]2C[NH2+]C[C@@H]1C2:1
|
416 |
+
C1NC[C@H]2C[NH+]C[C@@H]1C2:1
|
417 |
+
C1NC[C@H]2CC[C@H](C2)N1:1
|
418 |
+
C1NC[C@@H]2OC[C@H]1O2:1
|
419 |
+
C1NCO[NH2+]1:1
|
420 |
+
C1NCNCN1:1
|
421 |
+
C1NC2COC[C@H](C2)S1:1
|
422 |
+
C1N=[N+]CS1:1
|
423 |
+
C1N=NCO1:1
|
424 |
+
C1N=N1:1
|
425 |
+
C1C[NH2+][Pt][NH2+]1:1
|
426 |
+
C1C[NH2+]OC1:1
|
427 |
+
C1C[NH2+]C[NH+]1:1
|
428 |
+
C1C[NH2+]CSC1:1
|
429 |
+
C1C[NH2+]CC[NH2+]1:1
|
430 |
+
C1C[NH2+]CCSC1:1
|
431 |
+
C1C[NH2+]CCOC1:1
|
432 |
+
C1C[NH+][NH2+]N1:1
|
433 |
+
C1C[NH+]C[NH+]C1:1
|
434 |
+
C1C[NH+]CSC1:1
|
435 |
+
C1C[N+][Co][N+]1:1
|
436 |
+
C1C[N+]CNC1:1
|
437 |
+
C1C[C@H]2C[NH+]C[C@@H]1N2:1
|
438 |
+
C1C[C@H]2C[C@@H]1CN2:1
|
439 |
+
C1C[C@H]2CC[C@@H]1O2:1
|
440 |
+
C1C[C@H]2CCC[C@@H](C1)[NH+]2:1
|
441 |
+
C1C[C@@H]2C[C@@H](CCO2)N1:1
|
442 |
+
C1C[C@@H]2CO[C@H](C1)N2:1
|
443 |
+
C1C[C@@H]2CC[C@H](C1)O2:1
|
444 |
+
C1CSCS1:1
|
445 |
+
C1CSCN[NH2+]1:1
|
446 |
+
C1CSCC[NH+]1:1
|
447 |
+
C1CSC1:1
|
448 |
+
C1CO[V]O1:1
|
449 |
+
C1CO[SH]C1:1
|
450 |
+
C1COSN1:1
|
451 |
+
C1COPO1:1
|
452 |
+
C1COCCOC1:1
|
453 |
+
C1CN[C@H]2CCC(N1)O2:1
|
454 |
+
C1CNSCCN1:1
|
455 |
+
C1CNNCN1:1
|
456 |
+
C1CNCN[NH+]C1:1
|
457 |
+
C1CN=NC1:1
|
458 |
+
C1CCSSC1:1
|
459 |
+
C1CCONC1:1
|
460 |
+
C1CCOCOC1:1
|
461 |
+
C1CCOCCOC1:1
|
462 |
+
C1CCNNCC1:1
|
463 |
+
C1CC2CO[C@@H](C2)[NH+]1:1
|
464 |
+
C1CC2COC(C1)OO2:1
|
465 |
+
C1CC2CC[NH+][C@H](C1)C2:1
|
466 |
+
C1CC2CC[C@H](C1)C2:1
|
467 |
+
C1CC2CCC1C[NH+]2:1
|
468 |
+
C1CC2CC1CO2:1
|
469 |
+
C1C2C[C@@H]3C[C@H](C2)C1O3:1
|
470 |
+
C1C2CC1C2:1
|
471 |
+
C1=[N+]CON1:1
|
472 |
+
C1=[N+]CCOC1:1
|
473 |
+
C1=[N+]CCNC1:1
|
474 |
+
C1=NSN=CC1:1
|
475 |
+
C1=NO[N+]=C1:1
|
476 |
+
C1=NOCN1:1
|
477 |
+
C1=NNNC1:1
|
478 |
+
C1=NNC=[N+]1:1
|
479 |
+
C1=NN=[N+]C1:1
|
480 |
+
C1=NC[NH2+]CC1:1
|
481 |
+
C1=NC[NH2+]C1:1
|
482 |
+
C1=NCON1:1
|
483 |
+
C1=NCOC1:1
|
484 |
+
C1=NCNN1:1
|
485 |
+
C1=NCNCN1:1
|
486 |
+
C1=NCN=CN1:1
|
487 |
+
C1=NCC[N+]=C1:1
|
488 |
+
C1=NCCC=[N+]1:1
|
489 |
+
C1=C\CCOCCC/1:1
|
490 |
+
C1=C[Se]C=N1:1
|
491 |
+
C1=C[SH]CCCN1:1
|
492 |
+
C1=C[Ru]NC1:1
|
493 |
+
C1=C[Rh]C1:1
|
494 |
+
C1=C[N+]CC1:1
|
495 |
+
C1=C[N+]C=NC1:1
|
496 |
+
C1=C[N+]=CNC1:1
|
497 |
+
C1=C[C@H]2C[C@H](CCO2)N1:1
|
498 |
+
C1=C[C@H]2C[C@H](C2)NC1:1
|
499 |
+
C1=C[C@H]2CC[C@@H]1O2:1
|
500 |
+
C1=C[C@H]2CCC[C@@H]1[N+]2:1
|
501 |
+
C1=C[C@H]2CCCC1[NH2+]2:1
|
502 |
+
C1=C[C@@H]2C[NH2+]C[C@H](C1)[NH2+]2:1
|
503 |
+
C1=C[C@@H]2CN(C1)CN2:1
|
504 |
+
C1=C[C@@H]2CCOC(C1)O2:1
|
505 |
+
C1=C[C@@H]2CC=CC(C1)C2:1
|
506 |
+
C1=CSOCNC1:1
|
507 |
+
C1=CSOC1:1
|
508 |
+
C1=CSCO1:1
|
509 |
+
C1=CSCCCN1:1
|
510 |
+
C1=CSCCC=N1:1
|
511 |
+
C1=CSC=CC1:1
|
512 |
+
C1=CPNCN1:1
|
513 |
+
C1=CO[NH2+]C1:1
|
514 |
+
C1=COCNC1:1
|
515 |
+
C1=COCC[N+]1:1
|
516 |
+
C1=COCCCO1:1
|
517 |
+
C1=COCCCC1:1
|
518 |
+
C1=COC=CN1:1
|
519 |
+
C1=CN[NH2+]N1:1
|
520 |
+
C1=CN[NH+]C1:1
|
521 |
+
C1=CN[C@@H]2CC[C@H](N1)O2:1
|
522 |
+
C1=CNOC1:1
|
523 |
+
C1=CNCOC1:1
|
524 |
+
C1=CNCNN=C1:1
|
525 |
+
C1=CNCN=N1:1
|
526 |
+
C1=CNCCC=N1:1
|
527 |
+
C1=CNCC=[N+]1:1
|
528 |
+
C1=CNC=[N+]C1:1
|
529 |
+
C1=CN=CCSC1:1
|
530 |
+
C1=CN=CCCC1:1
|
531 |
+
C1=CC[NH2+]CCC1:1
|
532 |
+
C1=CC[N+]C=C1:1
|
533 |
+
C1=CCOCCC1:1
|
534 |
+
C1=CCNCC=C1:1
|
535 |
+
C1=CCN=NC1:1
|
536 |
+
C1=CCC[NH+]CC1:1
|
537 |
+
C1=CCCSC=C1:1
|
538 |
+
C1=CC=NCC=C1:1
|
539 |
+
C1=CC=COC=C1:1
|
540 |
+
C1=CC=CNC=C1:1
|
541 |
+
C1=CC2CC[NH+][C@@H](C1)C2:1
|
542 |
+
C1=CC2CCCC(C1)C2:1
|
543 |
+
C1=CC2CCC(C1)O2:1
|
544 |
+
C1=CC2CC=C[C@H](C2)OC1:1
|
545 |
+
C1=CC2CC(CN2)[NH2+]C1:1
|
546 |
+
C1=CC2C3C[NH2+][C@H]2CC1C3:1
|
547 |
+
C1#CCCCCCC1:1
|
548 |
+
B1OCCCO1:1
|
549 |
+
B1CCCCO1:1
|