zaixizhang commited on
Commit
10efe81
1 Parent(s): 465c18c
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.lmdb filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,7 +1,51 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ from rdkit import Chem
4
+ from motif_sample import demo
5
+ import py3Dmol
6
+ import tempfile
7
+ from rdkit.Chem import AllChem
8
+ import numpy as np
9
+ from PIL import Image
10
+ import io
11
+ from rdkit.Chem import Draw
12
 
 
 
13
 
14
+ # Function to serve the file via Gradio
15
def create_and_return_sdf(Protein_index: int):
    """Generate a ligand for one protein and return its picture and SDF path.

    Args:
        Protein_index: Index of the target protein. The Gradio text input
            hands it over as a string, so it is coerced with ``int()``.

    Returns:
        tuple: ``(np_image, sdf_filename)`` where ``np_image`` is the rendered
        molecule as a NumPy array limited to its first three channels and
        ``sdf_filename`` is whatever ``demo`` returned.

    Raises:
        ValueError: If ``Protein_index`` cannot be converted to ``int``, or
            if no usable molecule can be read from the generated SDF file.
    """
    number = int(Protein_index)

    # Run the sampler; it returns the path of the SDF file it wrote.
    sdf_filename = demo(number)

    # The supplier is exhausted immediately for an empty file and yields None
    # for a record it cannot parse; report both cases explicitly instead of
    # failing later with StopIteration / an opaque RDKit argument error.
    suppl = Chem.SDMolSupplier(sdf_filename)
    mol = next(suppl, None)
    if mol is None:
        raise ValueError(
            "No readable molecule in generated SDF file: %s" % sdf_filename
        )

    # Rebuild the molecule from its SMILES string before drawing.
    # NOTE(review): presumably done to get a clean 2D depiction without the
    # generated 3D coordinates - confirm.
    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if mol is None:
        raise ValueError(
            "Generated molecule could not be rebuilt from SMILES: %s"
            % sdf_filename
        )

    # Clear atom-map numbers so they are not rendered as atom labels.
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)

    mol_image = Draw.MolToImage(mol)

    # Keep only the first three channels of the rendered image.
    np_image = np.array(mol_image)[:, :, :3]

    return np_image, sdf_filename
37
+
38
+
39
+ # Define Gradio interface
40
# Gradio front end: a single text box in, the molecule image plus a
# downloadable SDF file out. live=False means the callback only runs when
# the user submits the form.
iface = gr.Interface(
    create_and_return_sdf,
    "text",
    [
        gr.outputs.Image(type="numpy", label="Molecule Image"),
        gr.outputs.File(label="Download SDF"),
    ],
    live=False,
)

# Start serving the interface (with a public share link).
iface.launch(share=True)
checkpoints/pretrained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a61f9bee0a6ce3101d8df5b55d71831d8428c4d6ff82ab81ada8bb5277babd42
3
+ size 44147405
configs/sample.yml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: pl
3
+ path: ./data
4
+ split: ./data/split_by_name.pt
5
+
6
+ model:
7
+ checkpoint: ./checkpoints/pretrained.pt
8
+ hidden_channels: 256
9
+ random_alpha: False
10
+
11
+ sample:
12
+ seed: 2024
13
+ num_samples: 100
14
+ num_retry: 5
15
+ max_steps: 12
16
+ batch_size: 10
17
+ num_workers: 4
18
+ n_samples: 5
data/index.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c095461a584f03838af99d6411483040f641659d1521dc5247bcefd72e36ca8e
3
+ size 226859
data/pdbbind_pocket10_name2id.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0aa8a46e2cd77abb5b59221ece57e607ba30427d964d9d5b08c0d02e3449399c
3
+ size 237265
data/pdbbind_pocket10_processed.lmdb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a21c6a2231e3394d3c8e27b1e547a9376f4ee0bdc1c246fbbf32a8fc076710eb
3
+ size 386912256
data/pdbbind_pocket10_processed.lmdb-lock ADDED
Binary file (8.19 kB). View file
 
data/split_by_name.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d406b7bd5189c35cbe73943b5c8e0d61ae1fdc79b40d24dbb45c474915892b56
3
+ size 227451
evaluation/prepare_receptor4.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ #
3
+ #
4
+ #
5
+ # $Header: /opt/cvs/python/packages/share1.5/AutoDockTools/Utilities24/prepare_receptor4.py,v 1.11 2007/11/28 22:40:22 rhuey Exp $
6
+ #
7
+ import os
8
+
9
+ from MolKit import Read
10
+ import MolKit.molecule
11
+ import MolKit.protein
12
+ from AutoDockTools.MoleculePreparation import AD4ReceptorPreparation
13
+
14
+
15
if __name__ == '__main__':
    import sys
    import getopt

    # Every print below passes exactly one pre-formatted string in
    # parentheses, so the script parses under both Python 2 and Python 3.

    USAGE_LINES = (
        "Usage: prepare_receptor4.py -r filename",
        "",
        " Description of command...",
        " -r receptor_filename ",
        " supported file types include pdb,mol2,pdbq,pdbqs,pdbqt, possibly pqr,cif",
        " Optional parameters:",
        " [-v] verbose output (default is minimal output)",
        " [-o pdbqt_filename] (default is 'molecule_name.pdbqt')",
        " [-A] type(s) of repairs to make: ",
        " 'bonds_hydrogens': build bonds and add hydrogens ",
        " 'bonds': build a single bond from each atom with no bonds to its closest neighbor",
        " 'hydrogens': add hydrogens",
        " 'checkhydrogens': add hydrogens only if there are none already",
        " 'None': do not make any repairs ",
        " (default is 'checkhydrogens')",
        " [-C] preserve all input charges ie do not add new charges ",
        " (default is addition of gasteiger charges)",
        " [-p] preserve input charges on specific atom types, eg -p Zn -p Fe",
        " [-U] cleanup type:",
        " 'nphs': merge charges and remove non-polar hydrogens",
        " 'lps': merge charges and remove lone pairs",
        " 'waters': remove water residues",
        " 'nonstdres': remove chains composed entirely of residues of",
        " types other than the standard 20 amino acids",
        " 'deleteAltB': remove XX@B atoms and rename XX@A atoms->XX",
        " (default is 'nphs_lps_waters_nonstdres') ",
        " [-e] delete every nonstd residue from any chain",
        " 'True': any residue whose name is not in this list:",
        " ['CYS','ILE','SER','VAL','GLN','LYS','ASN', ",
        " 'PRO','THR','PHE','ALA','HIS','GLY','ASP', ",
        " 'LEU', 'ARG', 'TRP', 'GLU', 'TYR','MET', ",
        " 'HID', 'HSP', 'HIE', 'HIP', 'CYX', 'CSS']",
        " will be deleted from any chain. ",
        " NB: there are no nucleic acid residue names at all ",
        " in the list and no metals. ",
        " (default is False which means not to do this)",
        " [-M] interactive ",
        " (default is 'automatic': outputfile is written with no further user input)",
    )

    def usage():
        "Print helpful, accurate usage statement to stdout."
        print('\n'.join(USAGE_LINES))

    # process command arguments
    try:
        # 'h' is part of the option spec so that -h reaches the help branch
        # in the loop below instead of being rejected as an unknown option.
        opt_list, args = getopt.getopt(sys.argv[1:], 'r:vo:A:Cp:U:eM:h')
    except getopt.GetoptError as msg:
        print('prepare_receptor4.py: %s' % msg)
        usage()
        sys.exit(2)

    # initialize required parameters
    # -r: receptor
    receptor_filename = None

    # optional parameters
    verbose = None
    # -A: repairs to make: add bonds and/or hydrogens or checkhydrogens
    repairs = ''
    # -C default: add gasteiger charges
    charges_to_add = 'gasteiger'
    # -p preserve charges on specific atom types
    preserve_charge_types = None
    # -U: cleanup by merging nphs_lps, nphs, lps, waters, nonstdres
    cleanup = "nphs_lps_waters_nonstdres"
    # -o outputfilename
    outputfilename = None
    # -M mode
    mode = 'automatic'
    # -e delete every nonstd residue from each chain
    delete_single_nonstd_residues = None

    def log(*parts):
        "Print parts separated by single spaces, but only in verbose mode."
        # 'verbose' is looked up at call time, so messages only appear for
        # options processed after -v was seen.
        if verbose:
            print(' '.join(str(part) for part in parts))

    for o, a in opt_list:
        if o in ('-r', '--r'):
            receptor_filename = a
            log('set receptor_filename to ', a)
        elif o in ('-v', '--v'):
            verbose = True
            log('set verbose to ', True)
        elif o in ('-o', '--o'):
            outputfilename = a
            log('set outputfilename to ', a)
        elif o in ('-A', '--A'):
            repairs = a
            log('set repairs to ', a)
        elif o in ('-C', '--C'):
            charges_to_add = None
            log('do not add charges')
        elif o in ('-p', '--p'):
            # repeated -p options accumulate into one comma-separated string
            if not preserve_charge_types:
                preserve_charge_types = a
            else:
                preserve_charge_types = preserve_charge_types + ',' + a
            log('preserve initial charges on ', preserve_charge_types)
        elif o in ('-U', '--U'):
            cleanup = a
            log('set cleanup to ', a)
        elif o in ('-e', '--e'):
            delete_single_nonstd_residues = True
            log('set delete_single_nonstd_residues to True')
        elif o in ('-M', '--M'):
            mode = a
            log('set mode to ', a)
        elif o in ('-h', '--'):
            usage()
            sys.exit()

    if not receptor_filename:
        print('prepare_receptor4: receptor filename must be specified.')
        usage()
        sys.exit()

    # what about nucleic acids???

    mols = Read(receptor_filename)
    log('read ', receptor_filename)
    mol = mols[0]

    # Choose the molecule to prepare BEFORE collecting preserved charges, so
    # the charges recorded below belong to the molecule that is actually
    # handed to AD4ReceptorPreparation.
    if len(mols) > 1:
        log("more than one molecule in file")
        # use the molecule with the most atoms
        ctr = 1
        for m in mols[1:]:
            ctr += 1
            if len(m.allAtoms) > len(mol.allAtoms):
                mol = m
                log("mol set to ", ctr, "th molecule with", len(mol.allAtoms), "atoms")

    # Remember the current charge of every atom whose autodock_element was
    # requested with -p, keyed by atom, as [chargeSet, charge].
    preserved = {}
    if charges_to_add is not None and preserve_charge_types is not None:
        preserved_types = preserve_charge_types.split(',')
        log("preserved_types=", preserved_types)
        for t in preserved_types:
            log('preserving charges on type->', t)
            if not len(t):
                continue
            ats = mol.allAtoms.get(lambda x: x.autodock_element == t)
            if verbose:
                log("preserving charges on ", ats.name)
            for a in ats:
                if a.chargeSet is not None:
                    preserved[a] = [a.chargeSet, a.charge]

    mol.buildBondsByDistance()

    log("setting up RPO with mode=", mode, "and outputfilename= ", outputfilename)
    log("charges_to_add=", charges_to_add)
    log("delete_single_nonstd_residues=", delete_single_nonstd_residues)

    RPO = AD4ReceptorPreparation(mol, mode, repairs, charges_to_add,
                                 cleanup, outputfilename=outputfilename,
                                 preserved=preserved,
                                 delete_single_nonstd_residues=delete_single_nonstd_residues)

    if charges_to_add is not None:
        # restore any previous charges
        for atom, chargeList in preserved.items():
            atom._charges[chargeList[0]] = chargeList[1]
            atom.chargeSet = chargeList[0]
179
+
180
+
181
+ # To execute this command type:
182
+ # prepare_receptor4.py -r pdb_file -o outputfilename -A checkhydrogens
183
+
evaluation/vina_score.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vina import Vina
2
+ from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
3
+ from rdkit import Chem
4
+ import numpy as np
5
+ import os
6
+
7
# For every generated ligand file ./0.sdf ... ./99.sdf that exists: centre a
# 20 x 20 x 20 box on the molecule, then score, locally minimise and dock with
# AutoDock Vina.
for i in range(100):
    path = './' + str(i) + '.sdf'
    if not os.path.exists(path):
        continue
    print(path)

    # Parse the molecule first: MolFromMolFile returns None when RDKit cannot
    # read or sanitize the file, and AddHs would then fail with an opaque
    # error.  Skip such files and carry on with the remaining ones.
    mol = Chem.MolFromMolFile(path, sanitize=True)
    if mol is None:
        print('Skipping %s: RDKit could not parse the molecule' % path)
        continue

    v = Vina(sf_name='vina')
    v.set_receptor('2rma_protein.pdbqt')
    # NOTE(review): the ligand is always read from the same PDBQT file, no
    # matter which ./<i>.sdf is being processed - confirm this is intended.
    v.set_ligand_from_file('2rma_ligand' + '.pdbqt')

    # Calculate the docking center as the mean atom position of the
    # hydrogen-completed, UFF-optimized molecule.
    mol = Chem.AddHs(mol, addCoords=True)
    UFFOptimizeMolecule(mol)
    pos = mol.GetConformer(0).GetPositions()
    center = np.mean(pos, 0)

    v.compute_vina_maps(center=center, box_size=[20, 20, 20])

    # Score the current pose
    energy = v.score()
    print('Score before minimization: %.3f (kcal/mol)' % energy[0])

    # Minimized locally the current pose
    energy_minimized = v.optimize()
    print('Score after minimization : %.3f (kcal/mol)' % energy_minimized[0])
    # NOTE(review): both output files are overwritten on every iteration, so
    # only the results of the last processed ligand survive - confirm.
    v.write_pose('ligand_minimized.pdbqt', overwrite=True)

    # Dock the ligand
    v.dock(exhaustiveness=64, n_poses=30)
    v.write_poses('out.pdbqt', n_poses=5, overwrite=True)
models/common.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.modules.loss import _WeightedLoss
6
+ from torch_scatter import scatter_mean, scatter_add
7
+
8
+
9
def split_tensor_by_batch(x, batch, num_graphs=None):
    """
    Split ``x`` into one tensor per graph index found in ``batch``.

    Args:
        x: tensor whose first dimension is selected with a boolean mask
            built from ``batch``.
        batch: tensor of graph indices, compared element-wise against
            ``0 .. num_graphs - 1``.
        num_graphs: number of pieces to produce; defaults to
            ``batch.max() + 1``.
    Returns:
        List of ``num_graphs`` tensors; entry ``i`` is ``x[batch == i]``.
    """
    graph_count = batch.max().item() + 1 if num_graphs is None else num_graphs
    return [x[batch == graph_id] for graph_id in range(graph_count)]
24
+
25
+
26
def concat_tensors_to_batch(x_split):
    """
    Concatenate a list of tensors along dim 0 and build the matching index.

    Returns:
        ``(x, batch)`` where ``x`` is the concatenation and ``batch`` holds,
        for every row of ``x``, the position in ``x_split`` of the tensor it
        came from (placed on the same device as ``x``).
    """
    x = torch.cat(x_split, dim=0)
    piece_sizes = torch.LongTensor([piece.size(0) for piece in x_split])
    piece_ids = torch.arange(len(x_split))
    batch = torch.repeat_interleave(piece_ids, repeats=piece_sizes)
    return x, batch.to(device=x.device)
33
+
34
+
35
def split_tensor_to_segments(x, segsize):
    """
    Cut ``x`` along dim 0 into consecutive slices of at most ``segsize`` rows.

    The last slice is shorter when ``x.size(0)`` is not a multiple of
    ``segsize``; an empty ``x`` gives an empty list.
    """
    segment_count = math.ceil(x.size(0) / segsize)
    return [x[k * segsize:(k + 1) * segsize] for k in range(segment_count)]
41
+
42
+
43
def split_tensor_by_lengths(x, lengths):
    """
    Peel consecutive leading slices off ``x``, one per entry of ``lengths``.

    Each slice takes the first ``n`` rows of whatever is left; rows beyond
    the sum of ``lengths`` are dropped.
    """
    pieces = []
    remaining = x
    for n in lengths:
        head, remaining = remaining[:n], remaining[n:]
        pieces.append(head)
    return pieces
49
+
50
+
51
def batch_intersection_mask(batch, batch_filter):
    """
    Mark the entries of ``batch`` whose value occurs in ``batch_filter``.

    Returns:
        Boolean tensor with one element per entry of ``batch``.
    """
    wanted = batch_filter.unique()
    # Compare every batch entry (rows) against every wanted value (columns).
    hits = batch.view(-1, 1) == wanted.view(1, -1)
    return hits.any(dim=1)
55
+
56
+
57
class MeanReadout(nn.Module):
    """Mean readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Average the rows of ``input`` that share the same ``batch`` index.

        Parameters:
            input (Tensor): node representations
            batch (Tensor): index passed to ``scatter_mean`` for each row of ``input``
            num_graphs (int): size of the output along dim 0
        Returns:
            Tensor: graph representations
        """
        return scatter_mean(input, batch, dim=0, dim_size=num_graphs)
71
+
72
+
73
class SumReadout(nn.Module):
    """Sum readout operator over graphs with variadic sizes."""

    def forward(self, input, batch, num_graphs):
        """
        Sum the rows of ``input`` that share the same ``batch`` index.

        Parameters:
            input (Tensor): node representations
            batch (Tensor): index passed to ``scatter_add`` for each row of ``input``
            num_graphs (int): size of the output along dim 0
        Returns:
            Tensor: graph representations
        """
        return scatter_add(input, batch, dim=0, dim_size=num_graphs)
87
+
88
+
89
class MultiLayerPerceptron(nn.Module):
    """
    Multi-layer Perceptron.
    Note there is no activation or dropout in the last layer.

    Parameters:
        input_dim (int): input dimension
        hidden_dims (list or tuple of int): output dimension of each linear
            layer, in order; the last entry is the output dimension
        activation (str or function, optional): name of a function in
            ``torch.nn.functional`` or a callable applied between layers;
            ``None`` disables the activation
        dropout (float, optional): dropout rate; a falsy value disables dropout
    """

    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
        super(MultiLayerPerceptron, self).__init__()

        # list() lets callers pass a tuple (or any sequence) of dimensions.
        self.dims = [input_dim] + list(hidden_dims)
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            # A callable is used as given (it used to be discarded, leaving a
            # purely linear network); None still means "no activation".
            self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))

    def forward(self, input):
        """Apply the layers in order; activation and dropout follow every layer except the last."""
        x = input
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                if self.activation is not None:
                    x = self.activation(x)
                if self.dropout is not None:
                    x = self.dropout(x)
        return x
128
+
129
+
130
class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy against label-smoothed one-hot targets.

    ``weight``, when given, multiplies the log-probabilities class-wise;
    ``reduction`` may be ``'mean'``, ``'sum'`` or anything else, in which case
    the per-sample losses are returned unreduced.
    """

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.weight = weight
        self.reduction = reduction
        self.smoothing = smoothing

    @staticmethod
    def _smooth_one_hot(targets:torch.Tensor, n_classes:int, smoothing=0.0):
        """Return an ``(N, n_classes)`` table holding ``1 - smoothing`` at the
        target class and ``smoothing / (n_classes - 1)`` everywhere else."""
        assert 0 <= smoothing < 1
        with torch.no_grad():
            table = torch.empty(size=(targets.size(0), n_classes),
                                device=targets.device)
            table.fill_(smoothing / (n_classes - 1))
            table.scatter_(1, targets.data.unsqueeze(1), 1. - smoothing)
        return table

    def forward(self, inputs, targets):
        """Compute the loss from logits ``inputs`` and class indices ``targets``."""
        soft_targets = SmoothCrossEntropyLoss._smooth_one_hot(
            targets, inputs.size(-1), self.smoothing)
        log_probs = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_probs = log_probs * self.weight.unsqueeze(0)

        per_sample = -(soft_targets * log_probs).sum(-1)

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample
163
+
164
+
165
class GaussianSmearing(nn.Module):
    """Expand distances onto a set of evenly spaced Gaussian basis functions.

    The centers are ``num_gaussians`` points from ``start`` to ``stop``; the
    exponent coefficient is ``-0.5 / spacing**2`` with ``spacing`` the gap
    between the first two centers.
    """

    def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        self.coeff = -0.5 / spacing ** 2
        self.register_buffer('offset', centers)

    def forward(self, dist):
        """Flatten ``dist`` and return one row per distance, one column per center."""
        delta = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(delta, 2))
175
+
176
+
177
class ShiftedSoftplus(nn.Module):
    """Softplus lowered by ``log(2)``, so that an input of zero maps to zero."""

    def __init__(self):
        super().__init__()
        two = torch.tensor(2.0)
        self.shift = torch.log(two).item()

    def forward(self, x):
        """Return ``softplus(x) - log(2)`` element-wise."""
        softplus = F.softplus(x)
        return softplus - self.shift
184
+
185
+
186
def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """
    Merge protein and ligand nodes into one context, grouped by batch index.

    Protein and ligand tensors are concatenated (protein first) and then
    reordered with ``argsort`` of the combined batch index, so that all nodes
    of the same graph are contiguous.

    Args:
        h_protein, h_ligand: node features of protein / ligand.
        pos_protein, pos_ligand: node positions of protein / ligand.
        batch_protein, batch_ligand: batch index of every protein / ligand node.
    Returns:
        ``(h_ctx, pos_ctx, batch_ctx)``: merged features, positions and batch
        index, all in the same sorted order.
    """
    batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
    sort_idx = batch_ctx.argsort()

    # The protein/ligand mask that used to be built here was never used or
    # returned, so it is no longer computed.  compose_context_stable is the
    # variant that also returns such a mask.
    batch_ctx = batch_ctx[sort_idx]
    h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx]        # (N_protein+N_ligand, H)
    pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx]  # (N_protein+N_ligand, 3)

    return h_ctx, pos_ctx, batch_ctx
200
+
201
+
202
def get_complete_graph(batch):
    """
    Build, for every graph in ``batch``, all ordered node pairs ``(i, j)``
    with ``i != j``.

    Args:
        batch: Batch index.
    Returns:
        edge_index: (2, E) global node indices of the pairs, where E is the
            sum over graphs of ``N_i * N_i - N_i`` and N_i is the number of
            nodes of the i-th graph.
        neighbors: (B, ), number of edges per graph.
    """
    # Number of nodes per graph and number of (i, j) slots per graph.
    node_counts = scatter_add(torch.ones_like(batch), index=batch, dim=0)
    pair_counts = (node_counts ** 2).long()
    total_pairs = torch.sum(pair_counts)

    # Per-slot copies of: the graph's node count, the global index of the
    # graph's first node, and the index of the graph's first slot.
    counts_per_slot = torch.repeat_interleave(node_counts, pair_counts)
    first_node = torch.cumsum(node_counts, dim=0) - node_counts
    first_node_per_slot = torch.repeat_interleave(first_node, pair_counts)
    first_slot = torch.cumsum(pair_counts, dim=0) - pair_counts
    first_slot_per_slot = torch.repeat_interleave(first_slot, pair_counts)

    # Position of each slot inside its own graph's N_i x N_i grid.
    local_slot = torch.arange(total_pairs, device=total_pairs.device) - first_slot_per_slot

    # Row / column of the grid become the two endpoints of the pair.
    index1 = (local_slot // counts_per_slot).long() + first_node_per_slot
    index2 = (local_slot % counts_per_slot).long() + first_node_per_slot

    # Drop the diagonal (self pairs).
    keep = torch.logical_not(index1 == index2)
    edge_index = torch.cat([index1.view(1, -1), index2.view(1, -1)])[:, keep]

    num_edges = pair_counts - node_counts  # Number of edges per graph

    return edge_index, num_edges
233
+
234
+
235
def compose_context_stable(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
    """
    Merge protein and ligand nodes graph by graph, keeping input order.

    For each graph index ``0 .. batch_protein.max()`` the protein nodes of
    that graph are emitted first, followed by its ligand nodes; within each
    group the original order is kept.

    Returns:
        ``(h_ctx, pos_ctx, batch_ctx, mask_protein)`` where ``mask_protein``
        is a boolean tensor that is True at protein rows and False at ligand
        rows of the merged tensors.
    """
    num_graphs = batch_protein.max().item() + 1

    batch_parts, h_parts, pos_parts, mask_parts = [], [], [], []

    for graph_id in range(num_graphs):
        sel_p = batch_protein == graph_id
        sel_l = batch_ligand == graph_id
        ids_p = batch_protein[sel_p]
        ids_l = batch_ligand[sel_l]

        batch_parts.extend((ids_p, ids_l))
        h_parts.extend((h_protein[sel_p], h_ligand[sel_l]))
        pos_parts.extend((pos_protein[sel_p], pos_ligand[sel_l]))
        mask_parts.append(torch.ones([ids_p.size(0)], device=ids_p.device, dtype=torch.bool))
        mask_parts.append(torch.zeros([ids_l.size(0)], device=ids_l.device, dtype=torch.bool))

    batch_ctx = torch.cat(batch_parts, dim=0)
    h_ctx = torch.cat(h_parts, dim=0)
    pos_ctx = torch.cat(pos_parts, dim=0)
    mask_protein = torch.cat(mask_parts, dim=0)

    return h_ctx, pos_ctx, batch_ctx, mask_protein
261
+
262
if __name__ == '__main__':
    # Self-check for compose_context_stable: three graphs with 10/20/30
    # protein nodes and 11 ligand nodes each.  Protein positions are clamped
    # to be non-negative and ligand positions to be non-positive.
    h_protein = torch.randn([60, 64])
    h_ligand = -torch.randn([33, 64])
    pos_protein = torch.clamp(torch.randn([60, 3]), 0, float('inf'))
    pos_ligand = torch.clamp(torch.randn([33, 3]), float('-inf'), 0)
    batch_protein = torch.LongTensor([0] * 10 + [1] * 20 + [2] * 30)
    batch_ligand = torch.LongTensor([0] * 11 + [1] * 11 + [2] * 11)

    h_ctx, pos_ctx, batch_ctx, mask_protein = compose_context_stable(
        h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand)
    mask_ligand = torch.logical_not(mask_protein)

    # Selecting with the mask must give back the inputs in their original order.
    assert (batch_ctx[mask_protein] == batch_protein).all()
    assert (batch_ctx[mask_ligand] == batch_ligand).all()

    assert torch.allclose(h_ctx[mask_ligand], h_ligand)
    assert torch.allclose(h_ctx[mask_protein], h_protein)

    assert torch.allclose(pos_ctx[mask_ligand], pos_ligand)
    assert torch.allclose(pos_ctx[mask_protein], pos_protein)
280
+
281
+
282
+
models/encoders/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .schnet import SchNetEncoder
2
+ from .tf import TransformerEncoder
3
+ from .gnn import GNN_graphpred, MLP
4
+
5
+
6
def get_encoder(config):
    """
    Build the encoder selected by ``config.name``.

    ``'schnet'`` yields a ``SchNetEncoder`` and ``'tf'`` a
    ``TransformerEncoder``; the constructor arguments are read from
    attributes of ``config``.

    Raises:
        NotImplementedError: for any other ``config.name``.
    """
    kind = config.name

    if kind == 'schnet':
        return SchNetEncoder(
            hidden_channels=config.hidden_channels,
            num_filters=config.num_filters,
            num_interactions=config.num_interactions,
            edge_channels=config.edge_channels,
            cutoff=config.cutoff,
        )

    if kind == 'tf':
        return TransformerEncoder(
            hidden_channels=config.hidden_channels,
            edge_channels=config.edge_channels,
            key_channels=config.key_channels,
            num_heads=config.num_heads,
            num_interactions=config.num_interactions,
            k=config.knn,
            cutoff=config.cutoff,
        )

    raise NotImplementedError('Unknown encoder: %s' % kind)
models/encoders/gnn.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn import MessagePassing
4
+ from torch_geometric.utils import add_self_loops, degree, softmax
5
+ from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
6
+ import torch.nn.functional as F
7
+ from torch_scatter import scatter_add
8
+ from torch_geometric.nn.inits import glorot, zeros
9
+
10
+ num_atom_type = 120 #including the extra mask tokens
11
+ num_chirality_tag = 3
12
+
13
+ num_bond_type = 6 #including aromatic and self-loop edge, and extra masked tokens
14
+ num_bond_direction = 3
15
+
16
class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information.

    Each message is the source-node feature plus the sum of two learned edge
    embeddings (looked up from columns 0 and 1 of ``edge_attr``); the
    aggregated messages are passed through a two-layer MLP.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbor aggregation scheme handed to ``MessagePassing``.

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        # Pass the aggregation scheme to the base class: recent
        # torch_geometric versions build the aggregation module inside
        # MessagePassing.__init__, so assigning self.aggr afterwards has no
        # effect on how messages are reduced.
        super(GINConv, self).__init__(aggr=aggr)
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        """Run one message-passing step over the graph plus self loops."""
        #add self loops in the edge space
        # add_self_loops returns a tuple; element 0 is the edge index used below.
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        """Message = source-node feature + edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """Transform the aggregated messages with the MLP."""
        return self.mlp(aggr_out)
57
+
58
+
59
class GCNConv(MessagePassing):
    """
    GCN-style layer with edge embeddings.

    Node features go through a linear layer; each message is the source
    feature plus the edge embedding, scaled by the symmetric degree
    normalization computed in ``norm``.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbor aggregation scheme handed to ``MessagePassing``.
    """

    def __init__(self, emb_dim, aggr = "add"):
        # Pass the aggregation scheme to the base class: recent
        # torch_geometric versions build the aggregation module inside
        # MessagePassing.__init__, so assigning self.aggr afterwards has no
        # effect on how messages are reduced.
        super(GCNConv, self).__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def norm(self, edge_index, num_nodes, dtype):
        """Return ``deg(row)^-1/2 * deg(col)^-1/2`` for every edge.

        ``edge_index`` is the tuple returned by ``add_self_loops``; only its
        first element is used.
        """
        ### assuming that self-loops have been already added in edge_index
        edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # nodes with degree 0 give inf; zero them out
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        """Run one message-passing step over the graph plus self loops."""
        #add self loops in the edge space
        # add_self_loops returns a tuple; element 0 is the edge index used below.
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr, norm):
        """Message = normalization * (source-node feature + edge embedding)."""
        return norm.view(-1, 1) * (x_j + edge_attr)
107
+
108
+
109
class GATConv(MessagePassing):
    """
    Multi-head attention layer with edge embeddings.

    Node features are projected to ``heads * emb_dim`` channels; messages are
    the source features plus the edge embedding, weighted by a softmax
    attention score; the heads are averaged and a bias is added in ``update``.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        heads (int): number of attention heads.
        negative_slope (float): slope of the leaky ReLU applied to the scores.
        aggr (str): neighbor aggregation scheme handed to ``MessagePassing``.
    """
    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr = "add"):
        # Pass the aggregation scheme to the base class: recent
        # torch_geometric versions build the aggregation module inside
        # MessagePassing.__init__, so assigning self.aggr afterwards has no
        # effect on how messages are reduced.
        super(GATConv, self).__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.heads = heads
        self.negative_slope = negative_slope

        self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
        self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))

        self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))

        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        self.reset_parameters()

    def norm(self, edge_index, num_nodes, dtype):
        """Return ``deg(row)^-1/2 * deg(col)^-1/2`` for every edge.

        ``edge_index`` is the tuple returned by ``add_self_loops``; only its
        first element is used.
        """
        ### assuming that self-loops have been already added in edge_index
        edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # nodes with degree 0 give inf; zero them out
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def reset_parameters(self):
        """Glorot-initialize the attention vector and zero the bias."""
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_attr):
        """Run one attention message-passing step over the graph plus self loops."""
        #add self loops in the edge space
        # add_self_loops returns a tuple; element 0 is the edge index used below.
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        # NOTE(review): norm is forwarded to propagate but message() below
        # does not take it - confirm whether it is meant to be used.
        norm = self.norm(edge_index, x.size(0), x.dtype)

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, edge_index, x_i, x_j, edge_attr):
        """Attention-weighted message, one row per edge and head."""
        edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
        # Out-of-place add: do not mutate the tensor handed in by propagate.
        x_j = x_j + edge_attr

        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index[0])

        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        """Average over the heads and add the bias."""
        aggr_out = aggr_out.mean(dim=1)
        aggr_out = aggr_out + self.bias

        return aggr_out
183
+
184
+
185
class GraphSAGEConv(MessagePassing):
    """
    GraphSAGE-style layer with edge embeddings.

    Node features go through a linear layer; each message is the source
    feature plus the edge embedding; the aggregated result is L2-normalized
    along the last dimension.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        aggr (str): neighbor aggregation scheme handed to ``MessagePassing``
            (``"mean"`` by default).
    """
    def __init__(self, emb_dim, aggr = "mean"):
        # Pass the aggregation scheme to the base class: recent
        # torch_geometric versions build the aggregation module inside
        # MessagePassing.__init__, so assigning self.aggr afterwards would
        # leave this layer on the base-class default instead of "mean".
        super(GraphSAGEConv, self).__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def norm(self, edge_index, num_nodes, dtype):
        """Return ``deg(row)^-1/2 * deg(col)^-1/2`` for every edge.

        ``edge_index`` is the tuple returned by ``add_self_loops``; only its
        first element is used.
        """
        ### assuming that self-loops have been already added in edge_index
        edge_index = edge_index[0]
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # nodes with degree 0 give inf; zero them out
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        """Run one message-passing step over the graph plus self loops."""
        #add self loops in the edge space
        # add_self_loops returns a tuple; element 0 is the edge index used below.
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        # NOTE(review): norm is forwarded to propagate but message() below
        # does not take it - confirm whether it is meant to be used.
        norm = self.norm(edge_index, x.size(0), x.dtype)

        x = self.linear(x)

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)

    def message(self, x_j, edge_attr):
        """Message = source-node feature + edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """L2-normalize the aggregated messages along the last dimension."""
        return F.normalize(aggr_out, p = 2, dim = -1)
234
+
235
+
236
+
237
class GNN(torch.nn.Module):
    """
    Stack of message-passing layers producing node representations.

    Args:
        num_layer (int): the number of GNN layers (must be at least 2)
        emb_dim (int): dimensionality of embeddings
        JK (str): how per-layer outputs are combined: last, concat, max or sum.
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    Raises:
        ValueError: if ``num_layer < 2`` or ``gnn_type`` is not recognised
            (constructor), or if the argument count or ``JK`` is invalid
            (``forward``).
    """
    def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        ###List of message-passing layers
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr = "add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))
            elif gnn_type == "gat":
                self.gnns.append(GATConv(emb_dim))
            elif gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(emb_dim))
            else:
                # Previously an unknown type left self.gnns empty and the
                # failure only surfaced as an IndexError inside forward().
                raise ValueError("Invalid gnn_type: %s" % (gnn_type,))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    #def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        """Compute node representations.

        Accepts either ``(x, edge_index, edge_attr)`` or a single data object
        exposing those three attributes.
        """
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            # torch.max with dim returns (values, indices); keep the values.
            node_representation = torch.stack(h_list, dim = 0).max(dim = 0)[0]
        elif self.JK == "sum":
            # torch.sum returns a plain tensor, so it must NOT be indexed with
            # [0] (that would keep only the first node's representation).
            node_representation = torch.stack(h_list, dim = 0).sum(dim = 0)
        else:
            raise ValueError("Invalid JK mode: %s" % (self.JK,))

        return node_representation
322
+
323
+
324
class GNN_graphpred(torch.nn.Module):
    """
    Graph-level prediction head: a GNN node encoder, a graph pooling step and
    a final linear layer.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, or set2set followed by
            a single digit giving the number of Set2Set processing steps
            (e.g. "set2set2")
        gnn_type: gin, gcn, graphsage, gat

    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """
    def __init__(self, num_layer, emb_dim, num_tasks, JK = "last", drop_ratio = 0, graph_pooling = "mean", gnn_type = "gin"):
        super(GNN_graphpred, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type = gnn_type)

        # With JK == "concat" the encoder concatenates the input embedding and
        # every layer output, hence (num_layer + 1) * emb_dim features per node.
        if self.JK == "concat":
            node_dim = (self.num_layer + 1) * emb_dim
        else:
            node_dim = emb_dim

        uses_set2set = graph_pooling[:-1] == "set2set"

        #Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Linear(node_dim, 1))
        elif uses_set2set:
            set2set_iter = int(graph_pooling[-1])
            self.pool = Set2Set(node_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        #For graph-level binary classification
        # The Set2Set branch sizes the prediction layer for twice the node width.
        self.mult = 2 if uses_set2set else 1

        self.graph_pred_linear = torch.nn.Linear(self.mult * node_dim, self.num_tasks)

    def from_pretrained(self, model_file):
        """Load encoder weights from ``model_file`` into ``self.gnn``.

        The checkpoint is mapped to CPU first so that weights saved on a GPU
        can be loaded on a CPU-only host; ``load_state_dict`` then copies the
        values into the existing parameters on whatever device they live on.

        NOTE(review): ``torch.load`` unpickles the file -- only load
        checkpoints from a trusted source.
        """
        state_dict = torch.load(model_file, map_location = "cpu")
        self.gnn.load_state_dict(state_dict)

    def forward(self, *argv):
        """Predict graph-level outputs.

        Accepts either ``(x, edge_index, edge_attr, batch)`` or a single data
        object exposing those four attributes.

        Raises:
            ValueError: for any other number of positional arguments.
        """
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.gnn(x, edge_index, edge_attr)

        return self.graph_pred_linear(self.pool(node_representation, batch))
401
+
402
+
403
class MLP(nn.Module):
    """
    Feed-forward network assembled in an nn.ModuleList.

    The hidden width equals ``in_dim`` when ``out_dim < 10`` and ``out_dim``
    otherwise. ``num_layers`` hidden stages (Linear, optional LayerNorm,
    optional BatchNorm1d, activation) are followed by one final Linear that
    maps to ``out_dim``.

    Inputs:
        in_dim (int): number of features contained in the input layer.
        out_dim (int): number of features produced by the output layer.
        num_layers (int): number of hidden stages before the output layer
        activation (torch module): activation applied after every hidden stage
        layer_norm (bool): insert nn.LayerNorm after each hidden Linear
        batch_norm (bool): insert nn.BatchNorm1d after each hidden Linear
    """

    def __init__(self, in_dim, out_dim, num_layers, activation=torch.nn.ReLU(), layer_norm=False, batch_norm=False):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()

        h_dim = in_dim if out_dim < 10 else out_dim

        # Only the first hidden Linear sees in_dim; later ones are h_dim -> h_dim.
        fan_in = in_dim
        for _ in range(num_layers):
            self.layers.append(nn.Linear(fan_in, h_dim))
            if layer_norm:
                self.layers.append(nn.LayerNorm(h_dim))
            if batch_norm:
                self.layers.append(nn.BatchNorm1d(h_dim))
            self.layers.append(activation)
            fan_in = h_dim
        self.layers.append(nn.Linear(h_dim, out_dim))

    def forward(self, x):
        """Apply every stored layer to ``x`` in order."""
        for module in self.layers:
            x = module(x)
        return x
437
+
438
+
439
if __name__ == "__main__":
    # Nothing is executed when this module is run directly.
    pass
441
+
models/encoders/schnet.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.nn import Module, Sequential, ModuleList, Linear
4
+ from torch_geometric.nn import MessagePassing, radius_graph
5
+ from math import pi as PI
6
+
7
+ from ..common import GaussianSmearing, ShiftedSoftplus
8
+
9
+
10
class CFConv(MessagePassing):
    """Continuous-filter convolution with sum aggregation.

    A small network turns per-edge attributes into filter weights; the
    weights are optionally damped by a cosine cutoff on the edge length and
    multiplied element-wise with the projected neighbour features.

    Args:
        in_channels (int): input node feature size
        out_channels (int): output node feature size
        num_filters (int): width of the filter space
        edge_channels (int): size of the per-edge attribute vector
        cutoff (float or None): cutoff distance; ``None`` disables the envelope
    """

    def __init__(self, in_channels, out_channels, num_filters, edge_channels, cutoff=10.0):
        super().__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = Sequential(
            Linear(edge_channels, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )  # Network for generating filter weights
        self.cutoff = cutoff
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all weights and zero every bias."""
        # Indices 0 and 2 are the two Linear layers of the filter network.
        # The second bias used to be skipped because index 0 was reset twice.
        for idx in (0, 2):
            torch.nn.init.xavier_uniform_(self.nn[idx].weight)
            self.nn[idx].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_length, edge_attr):
        """Apply the convolution and return the updated node features."""
        W = self.nn(edge_attr)

        if self.cutoff is not None:
            # Cosine envelope, zeroed outside [0, cutoff].
            C = 0.5 * (torch.cos(edge_length * PI / self.cutoff) + 1.0)
            C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0)  # Modification: cutoff
            W = W * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j, W):
        """Message of an edge: neighbour feature times the edge filter."""
        return x_j * W
48
+
49
+
50
class InteractionBlock(Module):
    """A CFConv layer followed by a shifted-softplus activation and a Linear layer."""

    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super(InteractionBlock, self).__init__()
        # num_gaussians is passed as the edge-attribute width of the convolution.
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, num_gaussians, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution and the output Linear layer."""
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_length, edge_attr):
        """Return ``lin(act(conv(x, edge_index, edge_length, edge_attr)))``."""
        convolved = self.conv(x, edge_index, edge_length, edge_attr)
        return self.lin(self.act(convolved))
69
+
70
+
71
class SchNetEncoder(Module):
    """Encoder made of residual InteractionBlocks over a radius graph.

    Args:
        hidden_channels (int): node feature size (input and output)
        num_filters (int): filter width used inside each block
        num_interactions (int): number of stacked blocks
        edge_channels (int): number of Gaussians used to expand distances
        cutoff (float): radius used to build the graph and stop the expansion
    """

    def __init__(self, hidden_channels=128, num_filters=128,
                 num_interactions=6, edge_channels=64, cutoff=10.0):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
        self.cutoff = cutoff

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            self.interactions.append(
                InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff)
            )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every interaction block."""
        for block in self.interactions:
            block.reset_parameters()

    @property
    def out_channels(self):
        """Size of the node features returned by ``forward``."""
        return self.hidden_channels

    def forward(self, node_attr, pos, batch):
        """Encode nodes given their features, positions and batch assignment."""
        edge_index = radius_graph(pos, self.cutoff, batch=batch, loop=False)
        src, dst = edge_index[0], edge_index[1]
        edge_length = torch.norm(pos[src] - pos[dst], dim=1)
        edge_attr = self.distance_expansion(edge_length)

        h = node_attr
        for block in self.interactions:
            # Residual connection around every block.
            h = h + block(h, edge_index, edge_length, edge_attr)
        return h
models/encoders/tf.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from torch.nn import Module, Sequential, ModuleList, Linear, Conv1d, LeakyReLU
5
+ from torch_geometric.nn import radius_graph, knn_graph
6
+ from torch_scatter import scatter_sum, scatter_softmax
7
+ import math
8
+ from math import pi as PI
9
+
10
+ from ..common import GaussianSmearing, ShiftedSoftplus
11
+
12
+
13
class AttentionInteractionBlock(Module):
    """Multi-head attention message passing with edge-conditioned keys and values.

    Args:
        hidden_channels (int): node feature size; must be divisible by ``num_heads``
        edge_channels (int): size of the per-edge attribute vector
        key_channels (int): total key/query size; must be divisible by ``num_heads``
        num_heads (int): number of attention heads
    """

    def __init__(self, hidden_channels, edge_channels, key_channels, num_heads=1):
        super().__init__()

        assert hidden_channels % num_heads == 0
        assert key_channels % num_heads == 0

        self.hidden_channels = hidden_channels
        self.key_channels = key_channels
        self.num_heads = num_heads

        key_per_head = key_channels // num_heads
        hidden_per_head = hidden_channels // num_heads

        # Grouped 1x1 convolutions act as per-head linear projections.
        self.k_lin = Conv1d(hidden_channels, key_channels, 1, groups=num_heads, bias=False)
        self.q_lin = Conv1d(hidden_channels, key_channels, 1, groups=num_heads, bias=False)
        self.v_lin = Conv1d(hidden_channels, hidden_channels, 1, groups=num_heads, bias=False)

        self.weight_k_net = Sequential(
            Linear(edge_channels, key_per_head),
            LeakyReLU(),
            Linear(key_per_head, key_per_head),
        )
        self.weight_k_lin = Linear(key_per_head, key_per_head)

        self.weight_v_net = Sequential(
            Linear(edge_channels, hidden_per_head),
            LeakyReLU(),
            Linear(hidden_per_head, hidden_per_head),
        )
        self.weight_v_lin = Linear(hidden_per_head, hidden_per_head)

        self.centroid_lin = Linear(hidden_channels, hidden_channels)
        self.act = LeakyReLU()
        self.out_transform = Linear(hidden_channels, hidden_channels)
        self.layernorm_ffn = nn.LayerNorm(hidden_channels)

    def forward(self, x, edge_index, edge_attr):
        """
        Args:
            x: Node features, (N, H).
            edge_index: (2, E).
            edge_attr: (E, edge_channels)
        """
        n_nodes = x.size(0)
        heads = self.num_heads
        row, col = edge_index   # (E,) , (E,)

        # Project to multiple key, query and value spaces
        x_col = x.unsqueeze(-1)
        h_keys = self.k_lin(x_col).view(n_nodes, heads, -1)     # (N, heads, K_per_head)
        h_queries = self.q_lin(x_col).view(n_nodes, heads, -1)  # (N, heads, K_per_head)
        h_values = self.v_lin(x_col).view(n_nodes, heads, -1)   # (N, heads, H_per_head)

        # Compute keys and queries
        W_k = self.weight_k_net(edge_attr)  # (E, K_per_head)
        keys_j = self.weight_k_lin(W_k.unsqueeze(1) * h_keys[col])  # (E, heads, K_per_head)
        queries_i = h_queries[row]          # (E, heads, K_per_head)

        # Compute attention weights (alphas)
        # The scale uses hidden_channels per head, as in the original code.
        scale = math.sqrt(self.hidden_channels // heads)
        qk_ij = (queries_i * keys_j).sum(-1) / scale  # (E, heads)
        alpha = scatter_softmax(qk_ij, row, dim=0)

        # Compose messages
        W_v = self.weight_v_net(edge_attr)  # (E, H_per_head)
        msg_j = self.weight_v_lin(W_v.unsqueeze(1) * h_values[col])  # (E, heads, H_per_head)
        msg_j = alpha.unsqueeze(-1) * msg_j  # (E, heads, H_per_head)

        # Aggregate messages
        aggr_msg = scatter_sum(msg_j, row, dim=0, dim_size=n_nodes).view(n_nodes, -1)  # (N, heads*H_per_head)
        out = self.layernorm_ffn(self.centroid_lin(x) + aggr_msg)
        return self.out_transform(self.act(out))
84
+
85
+
86
class TransformerEncoder(Module):
    """Encoder made of residual AttentionInteractionBlocks over a k-NN graph.

    Args:
        hidden_channels (int): node feature size (input and output)
        edge_channels (int): number of Gaussians used to expand distances
        key_channels (int): total key/query size of each block
        num_heads (int): attention heads per block
        num_interactions (int): number of stacked blocks
        k (int): neighbours per node in the k-NN graph
        cutoff (float): upper bound of the distance expansion
    """

    def __init__(self, hidden_channels=256, edge_channels=64, key_channels=128, num_heads=4, num_interactions=6, k=32,
                 cutoff=10.0):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.edge_channels = edge_channels
        self.key_channels = key_channels
        self.num_heads = num_heads
        self.num_interactions = num_interactions
        self.k = k
        self.cutoff = cutoff

        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
        self.interactions = ModuleList()
        for _ in range(num_interactions):
            self.interactions.append(
                AttentionInteractionBlock(
                    hidden_channels=hidden_channels,
                    edge_channels=edge_channels,
                    key_channels=key_channels,
                    num_heads=num_heads,
                )
            )

    @property
    def out_channels(self):
        """Size of the node features returned by ``forward``."""
        return self.hidden_channels

    def forward(self, node_attr, pos, batch):
        """Encode nodes given their features, positions and batch assignment."""
        # The graph is built from k nearest neighbours, not from a radius.
        edge_index = knn_graph(pos, k=self.k, batch=batch, flow='target_to_source')
        src, dst = edge_index[0], edge_index[1]
        edge_length = torch.norm(pos[src] - pos[dst], dim=1)
        edge_attr = self.distance_expansion(edge_length)

        h = node_attr
        for block in self.interactions:
            # Residual connection around every block.
            h = h + block(h, edge_index, edge_attr)
        return h
125
+
126
+
127
if __name__ == '__main__':
    # Smoke test: encode a batch of three random point clouds and print the result.
    from torch_geometric.data import Data, Batch

    hidden_channels = 64
    edge_channels = 48
    key_channels = 32
    num_heads = 4

    data_list = [
        Data(
            x=torch.randn([num_nodes, hidden_channels]),
            pos=torch.randn([num_nodes, 3]) * 2,
        )
        for num_nodes in [11, 13, 15]
    ]
    batch = Batch.from_data_list(data_list)

    # The encoder defined in this module is TransformerEncoder; the previous
    # name (CFTransformerEncoder) does not exist here and raised NameError.
    model = TransformerEncoder(
        hidden_channels=hidden_channels,
        edge_channels=edge_channels,
        key_channels=key_channels,
        num_heads=num_heads,
    )
    out = model(batch.x, batch.pos, batch.batch)

    print(out)
    print(out.size())
models/flag.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("..")
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import Module, Linear, Embedding
6
+ from torch.nn import functional as F
7
+ from torch_scatter import scatter_add, scatter_mean
8
+ from torch_geometric.data import Data, Batch
9
+ from copy import deepcopy
10
+
11
+ from .encoders import get_encoder, GNN_graphpred, MLP
12
+ from .common import *
13
+ from utils import dihedral_utils, chemutils
14
+
15
+
16
class FLAG(Module):
    """Fragment-based ligand generation model.

    Bundles the shared context encoder with the heads used by the sampling
    loop (focal atom, next motif, attachment, torsion angle) and the combined
    training loss.

    Args:
        config: configuration object; the attributes read here are
            ``hidden_channels``, ``encoder``, ``random_alpha`` and ``refinement``
        protein_atom_feature_dim (int): size of a protein atom feature vector
        ligand_atom_feature_dim (int): size of a ligand atom feature vector
        vocab: motif vocabulary exposing ``size()``
    """

    def __init__(self, config, protein_atom_feature_dim, ligand_atom_feature_dim, vocab):
        super().__init__()
        self.config = config
        self.vocab = vocab
        self.protein_atom_emb = Linear(protein_atom_feature_dim, config.hidden_channels)
        self.ligand_atom_emb = Linear(ligand_atom_feature_dim, config.hidden_channels)
        self.embedding = nn.Embedding(vocab.size() + 1, config.hidden_channels)
        self.W = nn.Linear(2 * config.hidden_channels, config.hidden_channels)
        self.W_o = nn.Linear(config.hidden_channels, self.vocab.size())
        self.encoder = get_encoder(config.encoder)
        self.comb_head = GNN_graphpred(num_layer=3, emb_dim=config.hidden_channels, num_tasks=1, JK='last',
                                       drop_ratio=0.5, graph_pooling='mean', gnn_type='gin')
        # With random_alpha an extra noise vector of hidden size is appended
        # to the torsion head input, hence 4 instead of 3 blocks.
        if config.random_alpha:
            self.alpha_mlp = MLP(in_dim=config.hidden_channels * 4, out_dim=1, num_layers=2)
        else:
            self.alpha_mlp = MLP(in_dim=config.hidden_channels * 3, out_dim=1, num_layers=2)
        self.focal_mlp_ligand = MLP(in_dim=config.hidden_channels, out_dim=1, num_layers=1)
        self.focal_mlp_protein = MLP(in_dim=config.hidden_channels, out_dim=1, num_layers=1)
        self.dist_mlp = MLP(in_dim=protein_atom_feature_dim + ligand_atom_feature_dim, out_dim=1, num_layers=2)
        if config.refinement:
            self.refine_protein = MLP(in_dim=config.hidden_channels * 2 + config.encoder.edge_channels, out_dim=1, num_layers=2)
            self.refine_ligand = MLP(in_dim=config.hidden_channels * 2 + config.encoder.edge_channels, out_dim=1, num_layers=2)

        self.smooth_cross_entropy = SmoothCrossEntropyLoss(reduction='mean', smoothing=0.1)
        self.pred_loss = nn.CrossEntropyLoss()
        self.comb_loss = nn.BCEWithLogitsLoss()
        self.three_hop_loss = torch.nn.MSELoss()
        self.focal_loss = nn.BCEWithLogitsLoss()
        self.dist_loss = torch.nn.MSELoss(reduction='mean')

    def _encode_context(self, h_protein, h_ligand, protein_pos, ligand_pos, batch_protein, batch_ligand):
        """Merge protein and ligand atoms into one context and encode it.

        Returns:
            (h_ctx, protein_mask): encoded features of all context atoms and
            the boolean mask selecting the protein rows.
        """
        h_ctx, pos_ctx, batch_ctx, protein_mask = compose_context_stable(h_protein=h_protein, h_ligand=h_ligand,
                                                                         pos_protein=protein_pos, pos_ligand=ligand_pos,
                                                                         batch_protein=batch_protein,
                                                                         batch_ligand=batch_ligand)
        h_ctx = self.encoder(node_attr=h_ctx, pos=pos_ctx, batch=batch_ctx)  # (N_p+N_l, H)
        return h_ctx, protein_mask

    def forward(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein, batch_ligand):
        """Encode the context and score every atom as a focal candidate.

        Returns:
            (focal_pred, protein_mask, h_ctx); ``focal_pred`` lists the
            protein atom scores first, followed by the ligand atom scores.
        """
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)

        h_ctx, protein_mask = self._encode_context(h_protein, h_ligand, protein_pos, ligand_pos,
                                                   batch_protein, batch_ligand)
        focal_pred = torch.cat([self.focal_mlp_protein(h_ctx[protein_mask]), self.focal_mlp_ligand(h_ctx[~protein_mask])], dim=0)

        return focal_pred, protein_mask, h_ctx

    def forward_motif(self, h_ctx_focal, current_wid, current_atoms_batch, n_samples=1):
        """Sample ``n_samples`` next-motif ids per query from the top-scoring motifs.

        Returns:
            (preds, prob): flattened sampled motif ids and their softmax
            probabilities, ``n_samples`` consecutive entries per query.
        """
        node_hiddens = scatter_add(h_ctx_focal, dim=0, index=current_atoms_batch)
        motif_hiddens = self.embedding(current_wid)
        pred_vecs = torch.cat([node_hiddens, motif_hiddens], dim=1)
        pred_vecs = F.relu(self.W(pred_vecs))
        pred_scores = F.softmax(self.W_o(pred_vecs), dim=-1)

        # random select n_samples in topk (with replacement); the pool size is
        # capped at the number of motifs so that topk cannot fail on a small
        # vocabulary.
        k = min(5 * n_samples, pred_scores.size(1))
        select_pool = torch.topk(pred_scores, k, dim=1)[1]
        index = torch.randint(k, (select_pool.shape[0], n_samples))
        preds = select_pool.gather(1, index.to(select_pool.device)).reshape(-1)

        idx_parent = torch.repeat_interleave(torch.arange(pred_scores.shape[0]), n_samples, dim=0).to(pred_scores.device)
        prob = pred_scores[idx_parent, preds]
        return preds, prob

    def forward_attach(self, mol_list, next_motif_smiles, device):
        """Pick, for every molecule, the best-scoring way to attach the next motif.

        Candidates come from ``chemutils.assemble``; each one is scored by
        ``self.comb_head`` and the arg-max within each molecule's group of
        candidates is kept.

        Returns:
            (select_mols, new_atoms, one_atom_attach, intersection, attach_fail)
        """
        cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail = chemutils.assemble(mol_list, next_motif_smiles)
        graph_data = Batch.from_data_list([chemutils.mol_to_graph_data_obj_simple(mol) for mol in cand_mols]).to(device)
        comb_pred = self.comb_head(graph_data.x, graph_data.edge_index, graph_data.edge_attr, graph_data.batch).reshape(-1)

        # Boundaries of each molecule's slice in the flat candidate list. The
        # leading zero is created on the same device as the counts.
        counts = cand_batch.bincount()
        bounds = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)], dim=0).tolist()
        # Greedy choice; sampling from the scores instead would be an alternative.
        select = [start + torch.argmax(comb_pred[start:stop]).item()
                  for start, stop in zip(bounds[:-1], bounds[1:])]

        select_mols = [cand_mols[i] for i in select]
        new_atoms = [new_atoms[i] for i in select]
        one_atom_attach = [one_atom_attach[i] for i in select]
        intersection = [intersection[i] for i in select]
        return select_mols, new_atoms, one_atom_attach, intersection, attach_fail

    def forward_alpha(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, batch_protein,
                      batch_ligand, xy_index, rotatable):
        """Predict the torsion parameter alpha for each bond listed in ``xy_index``.

        ``rotatable`` selects which per-molecule summaries are used.
        """
        # encode again
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)

        h_ctx, protein_mask = self._encode_context(h_protein, h_ligand, protein_pos, ligand_pos,
                                                   batch_protein, batch_ligand)
        h_ctx_ligand = h_ctx[~protein_mask]
        hx, hy = h_ctx_ligand[xy_index[:, 0]], h_ctx_ligand[xy_index[:, 1]]
        h_mol = scatter_add(h_ctx_ligand, dim=0, index=batch_ligand)
        h_mol = h_mol[rotatable]
        if self.config.random_alpha:
            rand_dist = torch.distributions.normal.Normal(loc=0, scale=1)
            rand_alpha = rand_dist.sample(hx.shape).to(hx.device)
            alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol, rand_alpha], dim=-1))
        else:
            alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol], dim=-1))
        return alpha

    def get_loss(self, protein_pos, protein_atom_feature, ligand_pos, ligand_atom_feature, ligand_pos_torsion,
                 ligand_atom_feature_torsion, batch_protein, batch_ligand, batch_ligand_torsion, batch):
        """Compute the combined training loss.

        Returns:
            (loss, loss_list): the summed loss and a list of Python floats
            with the individual terms -- [0] next motif, [1] attachment,
            [2] focal, [3] distance matrix, [4] torsion, [5] structure
            refinement. Terms that are skipped for this batch stay 0.
        """
        self.device = protein_pos.device
        h_protein = self.protein_atom_emb(protein_atom_feature)
        h_ligand = self.ligand_atom_emb(ligand_atom_feature)

        loss_list = [0, 0, 0, 0, 0, 0]

        # Encode for motif prediction
        h_ctx, mask_protein = self._encode_context(h_protein, h_ligand, protein_pos, ligand_pos,
                                                   batch_protein, batch_ligand)
        h_ctx_ligand = h_ctx[~mask_protein]
        h_ctx_protein = h_ctx[mask_protein]
        h_ctx_focal = h_ctx[batch['current_atoms']]

        # Encode for torsion prediction
        has_torsion = len(batch['y_pos']) > 0
        if has_torsion:
            h_ligand_torsion = self.ligand_atom_emb(ligand_atom_feature_torsion)
            h_ctx_torsion, mask_protein_torsion = self._encode_context(h_protein, h_ligand_torsion, protein_pos,
                                                                       ligand_pos_torsion, batch_protein,
                                                                       batch_ligand_torsion)
            h_ctx_ligand_torsion = h_ctx_torsion[~mask_protein_torsion]

        # next motif prediction
        node_hiddens = scatter_add(h_ctx_focal, dim=0, index=batch['current_atoms_batch'])
        motif_hiddens = self.embedding(batch['current_wid'])
        pred_vecs = torch.cat([node_hiddens, motif_hiddens], dim=1)
        pred_vecs = F.relu(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_loss = self.pred_loss(pred_scores, batch['next_wid'])
        loss_list[0] = pred_loss.item()

        # attachment prediction
        if len(batch['cand_labels']) > 0:
            cand_mols = batch['cand_mols']
            comb_pred = self.comb_head(cand_mols.x, cand_mols.edge_index, cand_mols.edge_attr, cand_mols.batch)
            comb_loss = self.comb_loss(comb_pred, batch['cand_labels'].view(comb_pred.shape).float())
            loss_list[1] = comb_loss.item()
        else:
            comb_loss = 0

        # focal prediction
        focal_ligand_pred, focal_protein_pred = self.focal_mlp_ligand(h_ctx_ligand), self.focal_mlp_protein(h_ctx_protein)
        focal_loss = self.focal_loss(focal_ligand_pred.reshape(-1), batch['ligand_frontier'].float()) +\
                     self.focal_loss(focal_protein_pred.reshape(-1), batch['protein_contact'].float())
        loss_list[2] = focal_loss.item()

        # distance matrix prediction
        if len(batch['true_dm']) > 0:
            dist_input = torch.cat([protein_atom_feature[batch['dm_protein_idx']], ligand_atom_feature[batch['dm_ligand_idx']]], dim=-1)
            pred_dist = self.dist_mlp(dist_input)
            dm_target = batch['true_dm'].unsqueeze(-1)
            dm_loss = self.dist_loss(pred_dist, dm_target)
            loss_list[3] = dm_loss.item()
        else:
            dm_loss = 0

        # structure refinement loss
        if self.config.refinement and len(batch['true_dm']) > 0:
            true_distance_alpha = torch.norm(batch['ligand_context_pos'][batch['sr_ligand_idx']] - batch['protein_pos'][batch['sr_protein_idx']], dim=1)
            true_distance_intra = torch.norm(batch['ligand_context_pos'][batch['sr_ligand_idx0']] - batch['ligand_context_pos'][batch['sr_ligand_idx1']], dim=1)
            input_distance_alpha = ligand_pos[batch['sr_ligand_idx']] - protein_pos[batch['sr_protein_idx']]
            input_distance_intra = ligand_pos[batch['sr_ligand_idx0']] - ligand_pos[batch['sr_ligand_idx1']]
            distance_emb1 = self.encoder.distance_expansion(torch.norm(input_distance_alpha, dim=1))
            distance_emb2 = self.encoder.distance_expansion(torch.norm(input_distance_intra, dim=1))

            #distance cut_off: only pairs whose true distance is at most 10.0 contribute a force
            keep_alpha = true_distance_alpha <= 10.0
            keep_intra = true_distance_intra <= 10.0
            input1 = torch.cat([h_ctx_ligand[batch['sr_ligand_idx']], h_ctx_protein[batch['sr_protein_idx']], distance_emb1], dim=-1)[keep_alpha]
            input2 = torch.cat([h_ctx_ligand[batch['sr_ligand_idx0']], h_ctx_ligand[batch['sr_ligand_idx1']], distance_emb2], dim=-1)[keep_intra]
            norm_dir1 = F.normalize(input_distance_alpha, p=2, dim=1)[keep_alpha]
            norm_dir2 = F.normalize(input_distance_intra, p=2, dim=1)[keep_intra]
            force1 = scatter_mean(self.refine_protein(input1)*norm_dir1, dim=0, index=batch['sr_ligand_idx'][keep_alpha], dim_size=ligand_pos.size(0))
            force2 = scatter_mean(self.refine_ligand(input2)*norm_dir2, dim=0, index=batch['sr_ligand_idx0'][keep_intra], dim_size=ligand_pos.size(0))
            new_ligand_pos = deepcopy(ligand_pos)
            new_ligand_pos += force1
            new_ligand_pos += force2
            refine_dist1 = torch.norm(new_ligand_pos[batch['sr_ligand_idx']] - protein_pos[batch['sr_protein_idx']], dim=1)
            refine_dist2 = torch.norm(new_ligand_pos[batch['sr_ligand_idx0']] - new_ligand_pos[batch['sr_ligand_idx1']], dim=1)
            sr_loss = (self.dist_loss(refine_dist1, true_distance_alpha) + self.dist_loss(refine_dist2, true_distance_intra))
            loss_list[5] = sr_loss.item()
        else:
            sr_loss = 0

        # torsion prediction
        if has_torsion:
            Hx = dihedral_utils.rotation_matrix_v2(batch['y_pos'])
            xn_pos = torch.matmul(Hx, batch['xn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
            yn_pos = torch.matmul(Hx, batch['yn_pos'].permute(0, 2, 1)).permute(0, 2, 1)
            y_pos = torch.matmul(Hx, batch['y_pos'].unsqueeze(1).permute(0, 2, 1)).squeeze(-1)

            hx, hy = h_ctx_ligand_torsion[batch['ligand_torsion_xy_index'][:, 0]], h_ctx_ligand_torsion[batch['ligand_torsion_xy_index'][:, 1]]
            h_mol = scatter_add(h_ctx_ligand_torsion, dim=0, index=batch['ligand_element_torsion_batch'])
            if self.config.random_alpha:
                rand_dist = torch.distributions.normal.Normal(loc=0, scale=1)
                rand_alpha = rand_dist.sample(hx.shape).to(self.device)
                alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol, rand_alpha], dim=-1))
            else:
                alpha = self.alpha_mlp(torch.cat([hx, hy, h_mol], dim=-1))
            # rotate xn
            R_alpha = self.build_alpha_rotation(torch.sin(alpha).squeeze(-1), torch.cos(alpha).squeeze(-1))
            xn_pos = torch.matmul(R_alpha, xn_pos.permute(0, 2, 1)).permute(0, 2, 1)

            # All 3 x 3 pairings of the neighbours supplied for x and for y.
            p_idx, q_idx = torch.cartesian_prod(torch.arange(3), torch.arange(3)).chunk(2, dim=-1)
            p_idx, q_idx = p_idx.squeeze(-1), q_idx.squeeze(-1)
            pred_sin, pred_cos = dihedral_utils.batch_dihedrals(xn_pos[:, p_idx],
                                                                torch.zeros_like(y_pos).unsqueeze(1).repeat(1, 9, 1),
                                                                y_pos.unsqueeze(1).repeat(1, 9, 1),
                                                                yn_pos[:, q_idx])
            # The predicted sine must be paired with the true sine; the
            # previous code passed pred_cos twice and left pred_sin unused.
            dihedral_loss = torch.mean(dihedral_utils.von_Mises_loss(batch['true_cos'], pred_cos.reshape(-1), batch['true_sin'], pred_sin.reshape(-1))[batch['dihedral_mask']])
            torsion_loss = -dihedral_loss
            loss_list[4] = torsion_loss.item()
        else:
            torsion_loss = 0

        # dm: distance matrix
        loss = pred_loss + comb_loss + focal_loss + dm_loss + torsion_loss + sr_loss

        return loss, loss_list

    def build_alpha_rotation(self, alpha, alpha_cos=None):
        """
        Builds the alpha rotation matrix (rotation about the first axis)

        :param alpha: predicted torsion angles (n_dihedral_pairs), or their
            sines when ``alpha_cos`` is given
        :param alpha_cos: optional cosines of the angles; when it is a tensor,
            ``alpha`` is used directly as the sine
        :return: alpha rotation matrix (n_dihedral_pairs, 3, 3), float32, on
            the device of ``alpha``
        """
        # Built on alpha's own device so the method also works when get_loss
        # (the only place that sets self.device) has never been called.
        H_alpha = torch.zeros(alpha.shape[0], 3, 3, dtype=torch.float32, device=alpha.device)
        H_alpha[:, 0, 0] = 1

        if torch.is_tensor(alpha_cos):
            sin_a, cos_a = alpha, alpha_cos
        else:
            sin_a, cos_a = torch.sin(alpha), torch.cos(alpha)

        H_alpha[:, 1, 1] = cos_a
        H_alpha[:, 1, 2] = -sin_a
        H_alpha[:, 2, 1] = sin_a
        H_alpha[:, 2, 2] = cos_a

        return H_alpha
268
+
motif_sample.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import argparse
4
+ import random
5
+ import torch
6
+ import numpy as np
7
+ import math
8
+ from vina import Vina
9
+ from openbabel import pybel
10
+ import subprocess
11
+ import multiprocessing as mp
12
+ from functools import partial
13
+ from torch_geometric.data import Batch
14
+ from tqdm.auto import tqdm
15
+ from rdkit import Chem
16
+ from rdkit.Geometry import Point3D
17
+ from torch.utils.data import DataLoader
18
+ from rdkit.Chem.rdchem import BondType
19
+ from rdkit.Chem import ChemicalFeatures, rdMolDescriptors
20
+ from rdkit import RDConfig
21
+ from rdkit.Chem.Descriptors import MolLogP, qed
22
+ from copy import deepcopy
23
+ import tempfile
24
+ import AutoDockTools
25
+ import contextlib
26
+ from torch_scatter import scatter_add, scatter_mean
27
+ from rdkit.Geometry import Point3D
28
+ from meeko import MoleculePreparation
29
+ from meeko import obutils
30
+ from models.flag import FLAG
31
+ from utils.transforms import *
32
+ from utils.datasets import get_dataset
33
+ from utils.misc import *
34
+ from utils.data import *
35
+ from utils.mol_tree import *
36
+ from utils.chemutils import *
37
+ from utils.dihedral_utils import *
38
+ from utils.sascorer import compute_sa_score
39
+ from rdkit.Chem import AllChem
40
+
41
_fscores = None

# Feature-family names; their order fixes the column layout of the per-atom
# family flags built in get_feat().
ATOM_FAMILIES = [
    'Acceptor',
    'Donor',
    'Aromatic',
    'Hydrophobe',
    'LumpedHydrophobe',
    'NegIonizable',
    'PosIonizable',
    'ZnBinder',
]
# Family name -> column index.
ATOM_FAMILIES_ID = dict(zip(ATOM_FAMILIES, range(len(ATOM_FAMILIES))))

# Status labels.
STATUS_RUNNING = 'running'
STATUS_FINISHED = 'finished'
STATUS_FAILED = 'failed'
50
+
51
+
52
def supress_stdout(func):
    """Decorator that discards everything ``func`` writes to ``sys.stdout``.

    The wrapped callable keeps its return value and its metadata
    (``__name__``, ``__doc__``, ...); only Python-level stdout is redirected
    to ``os.devnull`` for the duration of the call.

    :param func: callable to wrap
    :return: wrapper with the same call signature as ``func``
    """
    # Local import: this file only imports ``partial`` from functools.
    import functools

    @functools.wraps(func)
    def wrapper(*a, **ka):
        with open(os.devnull, 'w') as devnull:
            with contextlib.redirect_stdout(devnull):
                return func(*a, **ka)
    return wrapper
58
+
59
+
60
class PrepLig(object):
    """Ligand preparation helper built around a pybel molecule (``self.ob_mol``)."""

    def __init__(self, input_mol, mol_format):
        """Load the ligand.

        :param input_mol: a SMILES string when ``mol_format`` is ``'smi'``, or a
            file path when it is ``'sdf'`` (only the first record is read)
        :param mol_format: ``'smi'`` or ``'sdf'``
        :raises ValueError: for any other format
        """
        if mol_format not in ('smi', 'sdf'):
            raise ValueError(f'mol_format {mol_format} not supported')
        if mol_format == 'smi':
            self.ob_mol = pybel.readstring('smi', input_mol)
        else:
            self.ob_mol = next(pybel.readfile(mol_format, input_mol))

    def addH(self, polaronly=False, correctforph=True, PH=7):
        """Add hydrogens in place and dump the result to ``tmp_h.sdf`` (relative path)."""
        ob_molecule = self.ob_mol.OBMol
        ob_molecule.AddHydrogens(polaronly, correctforph, PH)
        obutils.writeMolecule(self.ob_mol.OBMol, 'tmp_h.sdf')

    def gen_conf(self):
        """Embed a conformer with RDKit (ETKDGv3) and dump it to ``conf_h.sdf``."""
        rdkit_mol = Chem.MolFromMolBlock(self.ob_mol.write('sdf'), removeHs=False)
        embed_params = Chem.rdDistGeom.ETKDGv3()
        AllChem.EmbedMolecule(rdkit_mol, embed_params)
        # Round-trip back through a mol block so the pybel molecule carries the new conformer.
        embedded_block = Chem.MolToMolBlock(rdkit_mol)
        self.ob_mol = pybel.readstring('sdf', embedded_block)
        obutils.writeMolecule(self.ob_mol.OBMol, 'conf_h.sdf')

    @supress_stdout
    def get_pdbqt(self, lig_pdbqt=None):
        """Convert the ligand to PDBQT with meeko.

        :param lig_pdbqt: output path; when ``None`` the PDBQT text is returned instead
        :return: the PDBQT string if no path was given, otherwise ``None``
        """
        preparator = MoleculePreparation()
        preparator.prepare(self.ob_mol.OBMol)
        if lig_pdbqt is None:
            return preparator.write_pdbqt_string()
        preparator.write_pdbqt_file(lig_pdbqt)
        return None
89
+
90
+
91
class PrepProt(object):
    """Receptor preparation helper; ``self.prot`` is the path of the current protein file."""

    def __init__(self, pdb_file):
        self.prot = pdb_file

    def del_water(self, dry_pdb_file):  # optional
        """Write the ATOM/HETATM records that do not contain ``HOH`` to ``dry_pdb_file``.

        Afterwards ``self.prot`` points at the new file.
        """
        with open(self.prot) as src:
            kept = [
                record for record in src
                if record.startswith(('ATOM', 'HETATM')) and 'HOH' not in record
            ]

        with open(dry_pdb_file, 'w') as dst:
            dst.write(''.join(kept))
        self.prot = dry_pdb_file

    def addH(self, prot_pqr):  # call pdb2pqr
        """Run ``pdb2pqr30 --ff=AMBER`` on the protein, writing ``prot_pqr``.

        The tool's output is discarded and its exit status is not checked.
        """
        self.prot_pqr = prot_pqr
        command = ['pdb2pqr30', '--ff=AMBER', self.prot, self.prot_pqr]
        subprocess.run(command, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)

    def get_pdbqt(self, prot_pdbqt):
        """Convert the PQR file produced by :meth:`addH` to PDBQT.

        Uses AutoDockTools' ``prepare_receptor4.py`` through ``python3``; requires
        ``addH`` to have been called first because it reads ``self.prot_pqr``.
        """
        script = os.path.join(AutoDockTools.__path__[0], 'Utilities24/prepare_receptor4.py')
        command = ['python3', script, '-r', self.prot_pqr, '-o', prot_pdbqt]
        subprocess.run(command, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
113
+
114
+
115
def calculate_vina(number, pro_path, lig_path):
    """Dock one generated ligand with AutoDock Vina and return its score.

    :param number: identifier of the ligand; the file ``<lig_path>/<number>.sdf`` is read
    :param pro_path: path of the receptor PDB file
    :param lig_path: directory that holds the ligand SDF files
    :return: first energy value of the top-ranked pose after docking, or ``None``
        when RDKit cannot parse the ligand file (the caller skips ``None`` scores)
    """
    lig_path = os.path.join(lig_path, str(number) + '.sdf')
    size_factor = 1.2
    buffer = 5.
    mol = Chem.MolFromMolFile(lig_path, sanitize=True)
    if mol is None:
        # RDKit returns None for unparsable / unsanitizable molecules; report
        # "no score" instead of crashing the worker with an AttributeError.
        return None
    pos = mol.GetConformer(0).GetPositions()
    center = np.mean(pos, 0)

    # Intermediate files are keyed by ``number`` so parallel workers do not collide.
    tmp_dir = './data/tmp/'
    os.makedirs(tmp_dir, exist_ok=True)
    ligand_pdbqt = tmp_dir + str(number) + '_lig.pdbqt'
    protein_pqr = tmp_dir + str(number) + '_pro.pqr'
    protein_pdbqt = tmp_dir + str(number) + '_pro.pdbqt'

    lig = PrepLig(lig_path, 'sdf')
    lig.addH()
    lig.get_pdbqt(ligand_pdbqt)

    prot = PrepProt(pro_path)
    prot.addH(protein_pqr)
    prot.get_pdbqt(protein_pdbqt)

    v = Vina(sf_name='vina', seed=0, verbosity=0)
    v.set_receptor(protein_pdbqt)
    v.set_ligand_from_file(ligand_pdbqt)
    # Search box: ligand extent scaled by ``size_factor`` plus a fixed margin,
    # centred on the mean ligand coordinate.
    x, y, z = (pos.max(0) - pos.min(0)) * size_factor + buffer
    v.compute_vina_maps(center=center, box_size=[x, y, z])
    energy = v.score()
    print('Score before minimization: %.3f (kcal/mol)' % energy[0])
    energy_minimized = v.optimize()
    print('Score after minimization : %.3f (kcal/mol)' % energy_minimized[0])
    v.dock(exhaustiveness=64, n_poses=32)
    score = v.energies(n_poses=1)[0][0]
    print('Score after docking : %.3f (kcal/mol)' % score)

    return score
149
+
150
+
151
def get_feat(mol):
    """Build the per-atom feature matrix of an RDKit molecule.

    The molecule is sanitized in place. Each row concatenates a one-hot element
    encoding over C, N, O, F, P, S, Cl with one 0/1 flag per entry of
    ``ATOM_FAMILIES`` (features from RDKit's ``BaseFeatures.fdef``).

    :param mol: RDKit molecule
    :return: float tensor of shape (num_atoms, 7 + len(ATOM_FAMILIES))
    """
    feature_factory = ChemicalFeatures.BuildFeatureFactory(
        os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef'))
    supported_elements = torch.LongTensor([6, 7, 8, 9, 15, 16, 17])  # C N O F P S Cl
    periodic_table = Chem.GetPeriodicTable()
    Chem.SanitizeMol(mol)

    family_flags = np.zeros([mol.GetNumAtoms(), len(ATOM_FAMILIES)], dtype=np.int_)
    for feature in feature_factory.GetFeaturesForMol(mol):
        family_column = ATOM_FAMILIES_ID[feature.GetFamily()]
        family_flags[feature.GetAtomIds(), family_column] = 1

    atom_numbers = torch.tensor(
        [periodic_table.GetAtomicNumber(atom.GetSymbol()) for atom in mol.GetAtoms()])
    # Broadcast compare: (N_atoms, 1) against (1, N_elements) -> (N_atoms, N_elements)
    element_one_hot = atom_numbers.view(-1, 1) == supported_elements.view(1, -1)
    return torch.cat([element_one_hot, torch.tensor(family_flags)], dim=-1).float()
163
+
164
+
165
def find_reference(protein_pos, focal_id):
    """Pick the four protein atoms closest to the focal atom.

    The focal atom itself is at distance zero, so it is among the four returned
    atoms. Results are ordered by increasing distance.

    :param protein_pos: tensor of atom coordinates, indexed along dim 0
    :param focal_id: index of the focal atom within ``protein_pos``
    :return: ``(reference_pos, reference_idx)`` - coordinates and indices of the
        four selected atoms
    """
    offsets = protein_pos - protein_pos[focal_id]
    distances = offsets.norm(dim=1)
    nearest = torch.topk(distances, k=4, largest=False)
    reference_idx = nearest.indices
    return protein_pos[reference_idx], reference_idx
171
+
172
+
173
def SetAtomNum(mol, atoms):
    """Set the atom-map number to 1 for atoms whose index is in ``atoms`` and to 0 otherwise.

    :param mol: molecule exposing ``GetAtoms()``; modified in place
    :param atoms: container of atom indices to flag
    :return: the same ``mol`` object
    """
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(1 if atom.GetIdx() in atoms else 0)
    return mol
180
+
181
+
182
def SetMolPos(mol_list, pos_list):
    """Copy predicted coordinates into conformer 0 of each molecule, then relax with UFF.

    Molecules whose atom count differs from the number of supplied positions are
    left untouched and are not included in the result.

    :param mol_list: RDKit molecules, each with at least one conformer; modified in place
    :param pos_list: per-molecule coordinate tensors, aligned with ``mol_list``
    :return: list of the molecules whose coordinates were updated
    """
    new_mol_list = []
    for i, pos_tensor in enumerate(pos_list):
        mol = mol_list[i]
        conf = mol.GetConformer(0)
        pos = pos_tensor.cpu().double().numpy()
        if mol.GetNumAtoms() != len(pos):
            continue
        for node, (x, y, z) in enumerate(pos):
            conf.SetAtomPosition(node, Point3D(x, y, z))
        try:
            AllChem.UFFOptimizeMolecule(mol)
        except Exception:
            # Best effort: keep the unrelaxed geometry. ``Exception`` rather than a
            # bare ``except`` so Ctrl-C still reaches the script's KeyboardInterrupt
            # handler.
            print('UFF error')
        new_mol_list.append(mol)
    return new_mol_list
198
+
199
+
200
def lipinski(mol):
    """Count how many of five drug-likeness criteria ``mol`` satisfies.

    Criteria: logP <= 5, H-bond donors <= 5, H-bond acceptors <= 10,
    exact molecular weight <= 500 and rotatable bonds <= 5.

    :param mol: RDKit molecule
    :return: number of satisfied criteria, 0 to 5
    """
    checks = (
        # The partition coefficient is the rule-of-five quantity; QED is bounded by 1,
        # so comparing it against 5 would always pass.
        MolLogP(mol) <= 5,
        Chem.Lipinski.NumHDonors(mol) <= 5,
        Chem.Lipinski.NumHAcceptors(mol) <= 10,
        Chem.Descriptors.ExactMolWt(mol) <= 500,
        Chem.Lipinski.NumRotatableBonds(mol) <= 5,
    )
    return int(sum(checks))
213
+
214
+
215
def refine_pos(ligand_pos, protein_pos, h_ctx_ligand, h_ctx_protein, model, batch, repeats, protein_batch,
               ligand_batch):
    """Shift ligand atoms by learned pairwise "forces" and split them per sample.

    Two sets of atom pairs are built for every sample: (ligand atom, protein atom
    flagged by ``batch['alpha_carbon_indicator']``) and (ligand atom, ligand atom).
    Pairs no further apart than 10.0 are scored by ``model.refine_protein`` /
    ``model.refine_ligand``; each score scales the unit direction of its pair and
    the results are averaged per ligand atom and added to ``ligand_pos`` in place.

    :param ligand_pos: concatenated ligand coordinates; updated in place
    :param protein_pos: concatenated protein coordinates
    :param h_ctx_ligand: per-ligand-atom features, aligned with ``ligand_pos``
    :param h_ctx_protein: per-protein-atom features, aligned with ``protein_pos``
    :param model: network exposing ``encoder.distance_expansion``, ``refine_protein``
        and ``refine_ligand``
    :param batch: mapping that provides ``'alpha_carbon_indicator'``
    :param repeats: number of ligand atoms of every sample
    :param protein_batch: sample index of every protein atom
    :param ligand_batch: sample index of every ligand atom
    :return: list with one float coordinate tensor per sample
    """
    # NOTE(review): ``F`` is not imported explicitly in this file; presumably it
    # arrives through one of the wildcard ``utils`` imports - confirm.
    # Prefix sums give the first flat index of every sample.
    protein_offsets = torch.cat([torch.tensor([0]), torch.cumsum(protein_batch.bincount(), dim=0)])
    ligand_offsets = torch.cat([torch.tensor([0]), torch.cumsum(ligand_batch.bincount(), dim=0)])

    lp_ligand, lp_protein = [], []
    ll_first, ll_second = [], []
    for sample in range(len(repeats)):
        alpha_index = batch['alpha_carbon_indicator'][protein_batch == sample].nonzero().reshape(-1)
        ligand_atom_index = torch.arange(repeats[sample])

        # every (ligand atom, flagged protein atom) combination of this sample
        cross_pairs = torch.cartesian_prod(ligand_atom_index, torch.arange(len(alpha_index)))
        lig_sel, prot_sel = (part.squeeze(-1) for part in cross_pairs.chunk(2, dim=-1))
        lp_ligand.append(ligand_atom_index[lig_sel] + ligand_offsets[sample])
        lp_protein.append(alpha_index[prot_sel] + protein_offsets[sample])

        # every ordered (ligand atom, ligand atom) combination of this sample
        self_pairs = torch.cartesian_prod(ligand_atom_index, ligand_atom_index)
        first_sel, second_sel = (part.squeeze(-1) for part in self_pairs.chunk(2, dim=-1))
        ll_first.append(ligand_atom_index[first_sel] + ligand_offsets[sample])
        ll_second.append(ligand_atom_index[second_sel] + ligand_offsets[sample])

    sr_ligand_idx = torch.cat(lp_ligand).long()
    sr_protein_idx = torch.cat(lp_protein).long()
    sr_ligand_idx0 = torch.cat(ll_first).long()
    sr_ligand_idx1 = torch.cat(ll_second).long()

    dir_cross = ligand_pos[sr_ligand_idx] - protein_pos[sr_protein_idx]
    dir_intra = ligand_pos[sr_ligand_idx0] - ligand_pos[sr_ligand_idx1]
    # distance cut_off
    near_cross = torch.norm(dir_cross, dim=1) <= 10.0
    near_intra = torch.norm(dir_intra, dim=1) <= 10.0

    emb_cross = model.encoder.distance_expansion(torch.norm(dir_cross, dim=1))
    emb_intra = model.encoder.distance_expansion(torch.norm(dir_intra, dim=1))
    cross_input = torch.cat(
        [h_ctx_ligand[sr_ligand_idx], h_ctx_protein[sr_protein_idx], emb_cross], dim=-1)[near_cross]
    intra_input = torch.cat(
        [h_ctx_ligand[sr_ligand_idx0], h_ctx_ligand[sr_ligand_idx1], emb_intra], dim=-1)[near_intra]

    unit_cross = F.normalize(dir_cross, p=2, dim=1)[near_cross]
    unit_intra = F.normalize(dir_intra, p=2, dim=1)[near_intra]

    n_ligand_atoms = ligand_pos.size(0)
    force_cross = scatter_mean(model.refine_protein(cross_input) * unit_cross, dim=0,
                               index=sr_ligand_idx[near_cross], dim_size=n_ligand_atoms)
    force_intra = scatter_mean(model.refine_ligand(intra_input) * unit_intra, dim=0,
                               index=sr_ligand_idx0[near_intra], dim_size=n_ligand_atoms)
    ligand_pos += force_cross
    ligand_pos += force_intra

    return [ligand_pos[ligand_batch == k].float() for k in range(len(repeats))]
257
+
258
+
259
def ligand_gen(batch, model, vocab, config, center, device, refinement=False):
    """Grow one ligand per batch entry, motif by motif.

    Runs at most ``config.sample.max_steps`` iterations. Step 0 picks a focal
    protein atom per sample, predicts the first motif and places it; every later
    step samples a focal ligand atom, predicts and attaches the next motif,
    embeds the new atoms and optionally rotates them by a predicted angle.

    :param batch: mapping of batched tensors (keys read here: ``protein_pos``,
        ``protein_atom_feature``, ``protein_element_batch``, ``ligand_context_pos``,
        ``ligand_context_feature_full``, ``ligand_context_element_batch``)
    :param model: network exposing ``__call__``, ``forward_motif``, ``forward_attach``,
        ``forward_alpha`` and ``dist_mlp``
    :param vocab: motif vocabulary exposing ``size()`` and ``get_smiles()``
    :param config: configuration; ``config.sample.batch_size`` and
        ``config.sample.max_steps`` are read
    :param center: point the first motif is pulled towards
    :param device: torch device used for newly created tensors
    :param refinement: when true, positions are passed through ``refine_pos`` at
        every step after the first
    :return: ``(mol_list, pos_list)`` - the molecules and their coordinate tensors
    """
    # NOTE(review): the original indentation was lost in this copy of the file;
    # the nesting below is reconstructed from the data flow - confirm against the
    # repository.
    pos_list = []
    feat_list = []
    # motif currently in focus for each sample, and per-sample termination flags
    motif_id = [0 for _ in range(config.sample.batch_size)]
    finished = torch.zeros(config.sample.batch_size).bool()
    for i in range(config.sample.max_steps):
        # progress output
        print(i)
        print(finished)
        if torch.sum(finished) == config.sample.batch_size:
            # mol_list = SetMolPos(mol_list, pos_list)
            return mol_list, pos_list
        if i == 0:
            # ---- first step: choose a focal protein atom and place the first motif ----
            focal_pred, mask_protein, h_ctx = model(protein_pos=batch['protein_pos'],
                                                    protein_atom_feature=batch['protein_atom_feature'].float(),
                                                    ligand_pos=batch['ligand_context_pos'],
                                                    ligand_atom_feature=batch['ligand_context_feature_full'].float(),
                                                    batch_protein=batch['protein_element_batch'],
                                                    batch_ligand=batch['ligand_context_element_batch'])
            protein_atom_feature = batch['protein_atom_feature'].float()
            focal_protein = focal_pred[mask_protein]
            h_ctx_protein = h_ctx[mask_protein]
            focus_score = torch.sigmoid(focal_protein)
            #can_focus = focus_score > 0.5
            # boundaries of every sample inside the flattened protein tensors
            slice_idx = torch.cat([torch.tensor([0]).to(device), torch.cumsum(batch['protein_element_batch'].bincount(), dim=0)])
            focal_id = []
            for j in range(len(slice_idx) - 1):
                # highest-scoring protein atom of sample j, as a flat index
                focus = focus_score[slice_idx[j]:slice_idx[j + 1]]
                focal_id.append(torch.argmax(focus.reshape(-1).float()).item() + slice_idx[j].item())
            focal_id = torch.tensor(focal_id, device=device)

            h_ctx_focal = h_ctx_protein[focal_id]
            # vocab.size() is passed as the "current motif" id for the first prediction
            current_wid = torch.tensor([vocab.size()] * config.sample.batch_size, device=device)
            next_motif_wid, motif_prob = model.forward_motif(h_ctx_focal, current_wid, torch.arange(config.sample.batch_size, device=device).to(device))
            mol_list = [Chem.MolFromSmiles(vocab.get_smiles(id)) for id in next_motif_wid]

            for j in range(config.sample.batch_size):
                AllChem.EmbedMolecule(mol_list[j])
                AllChem.UFFOptimizeMolecule(mol_list[j])
                ligand_pos, ligand_feat = torch.tensor(mol_list[j].GetConformer().GetPositions(), device=device), get_feat(mol_list[j]).to(device)
                feat_list.append(ligand_feat)
                # set the initial positions with distance matrix
                reference_pos, reference_idx = find_reference(batch['protein_pos'][slice_idx[j]:slice_idx[j + 1]], focal_id[j] - slice_idx[j])

                # all (reference atom, ligand atom) index pairs: 4 x n_ligand_atoms
                p_idx, l_idx = torch.cartesian_prod(torch.arange(4), torch.arange(len(ligand_pos))).chunk(2, dim=-1)
                p_idx = p_idx.squeeze(-1).to(device)
                l_idx = l_idx.squeeze(-1).to(device)
                # NOTE(review): reference_idx is relative to sample j's slice, but it
                # indexes the full protein_atom_feature tensor here - confirm this is
                # intended for batch entries other than the first.
                d_m = model.dist_mlp(torch.cat([protein_atom_feature[reference_idx[p_idx]], ligand_feat[l_idx]], dim=-1)).reshape(4,len(ligand_pos))

                # squared predicted distances, assembled with the intra-set squared
                # distances into one block matrix D
                d_m = d_m ** 2
                p_d, l_d = self_square_dist(reference_pos), self_square_dist(ligand_pos)
                D = torch.cat([torch.cat([p_d, d_m], dim=1), torch.cat([d_m.permute(1, 0), l_d], dim=1)])
                coordinate = eig_coord_from_dist(D)
                # align the recovered reference coordinates onto the real ones and
                # apply the same transform to the ligand part
                new_pos, _, _ = kabsch_torch(coordinate[:len(reference_pos)], reference_pos,
                                             coordinate[len(reference_pos):])
                # new_pos += (center*0.8+torch.mean(reference_pos, dim=0)*0.2) - torch.mean(new_pos, dim=0)
                # move the motif 80% of the way from its centroid towards ``center``
                new_pos += (center - torch.mean(new_pos, dim=0)) * .8
                pos_list.append(new_pos)

            # bookkeeping per sample: atom index -> motif id, motif id -> atom indices,
            # motif id -> vocabulary id
            atom_to_motif = [{} for _ in range(config.sample.batch_size)]
            motif_to_atoms = [{} for _ in range(config.sample.batch_size)]
            motif_wid = [{} for _ in range(config.sample.batch_size)]
            for j in range(config.sample.batch_size):
                for k in range(mol_list[j].GetNumAtoms()):
                    atom_to_motif[j][k] = 0
            for j in range(config.sample.batch_size):
                motif_to_atoms[j][0] = list(np.arange(mol_list[j].GetNumAtoms()))
                motif_wid[j][0] = next_motif_wid[j].item()
        else:
            # ---- later steps: extend the current molecules by one motif ----
            repeats = torch.tensor([len(pos) for pos in pos_list], device=device)
            ligand_batch = torch.repeat_interleave(torch.arange(config.sample.batch_size, device=device), repeats)
            focal_pred, mask_protein, h_ctx = model(protein_pos=batch['protein_pos'].float(),
                                                    protein_atom_feature=batch['protein_atom_feature'].float(),
                                                    ligand_pos=torch.cat(pos_list, dim=0).float(),
                                                    ligand_atom_feature=torch.cat(feat_list, dim=0).float(),
                                                    batch_protein=batch['protein_element_batch'],
                                                    batch_ligand=ligand_batch)
            # structure refinement
            if refinement:
                pos_list = refine_pos(torch.cat(pos_list, dim=0).float(), batch['protein_pos'].float(),
                                      h_ctx[~mask_protein], h_ctx[mask_protein], model, batch, repeats.tolist(),
                                      batch['protein_element_batch'], ligand_batch)

            focal_ligand = focal_pred[~mask_protein]
            h_ctx_ligand = h_ctx[~mask_protein]
            focus_score = torch.sigmoid(focal_ligand)
            # NOTE(review): a sigmoid output is positive, so this threshold looks like
            # it rarely if ever excludes an atom - confirm the intended cut-off.
            can_focus = focus_score > 0.
            slice_idx = torch.cat([torch.tensor([0], device=device), torch.cumsum(repeats, dim=0)])

            current_atoms_batch, current_atoms = [], []
            for j in range(len(slice_idx) - 1):
                focus = focus_score[slice_idx[j]:slice_idx[j + 1]]
                if torch.sum(can_focus[slice_idx[j]:slice_idx[j + 1]]) > 0 and ~finished[j]:
                    # sample a focal atom with probability proportional to its score
                    sample_focal_atom = torch.multinomial(focus.reshape(-1).float(), 1)
                    focal_motif = atom_to_motif[j][sample_focal_atom.item()]
                    motif_id[j] = focal_motif
                else:
                    finished[j] = True

                # flat indices of the atoms of the focused motif, plus their sample id
                current_atoms.extend((np.array(motif_to_atoms[j][motif_id[j]]) + slice_idx[j].item()).tolist())
                current_atoms_batch.extend([j] * len(motif_to_atoms[j][motif_id[j]]))
                # atoms of the focused motif get atom-map number 1, all others 0
                mol_list[j] = SetAtomNum(mol_list[j], motif_to_atoms[j][motif_id[j]])
            # second step: next motif prediction
            current_wid = [motif_wid[j][motif_id[j]] for j in range(len(mol_list))]
            next_motif_wid, motif_prob = model.forward_motif(h_ctx_ligand[torch.tensor(current_atoms)],
                                                             torch.tensor(current_wid).to(device),
                                                             torch.tensor(current_atoms_batch).to(device))

            # assemble
            next_motif_smiles = [vocab.get_smiles(id) for id in next_motif_wid]
            new_mol_list, new_atoms, one_atom_attach, intersection, attach_fail = model.forward_attach(mol_list, next_motif_smiles, device)

            for j in range(len(mol_list)):
                # NOTE(review): if attach_fail[j] is a plain Python bool, ``~`` yields
                # -1/-2, which are both truthy - verify the type returned by
                # forward_attach.
                if ~finished[j] and ~attach_fail[j]:
                    # num_new_atoms
                    mol_list[j] = new_mol_list[j]
            # a sample is rotatable when its focused motif has exactly two atoms, the
            # attachment is flagged in one_atom_attach, it did not fail and the sample
            # is not finished
            rotatable = torch.logical_and(torch.tensor(current_atoms_batch).bincount() == 2, torch.tensor(one_atom_attach))
            rotatable = torch.logical_and(rotatable, ~torch.tensor(attach_fail))
            rotatable = torch.logical_and(rotatable, ~finished).to(device)
            # update motif2atoms and atom2motif
            for j in range(len(mol_list)):
                if attach_fail[j] or finished[j]:
                    continue
                # the motif added at step i is registered under id i
                motif_to_atoms[j][i] = new_atoms[j]
                motif_wid[j][i] = next_motif_wid[j]
                for k in new_atoms[j]:
                    atom_to_motif[j][k] = i
                    '''
                    if k in atom_to_motif[j]:
                        continue
                    else:
                        atom_to_motif[j][k] = i'''

            # generate initial positions
            for j in range(len(mol_list)):
                if attach_fail[j] or finished[j]:
                    continue
                mol = mol_list[j]
                # anchor atoms carry atom-map number 1 (set by SetAtomNum above)
                anchor = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 1]
                # positions = mol.GetConformer().GetPositions()
                anchor_pos = deepcopy(pos_list[j][anchor]).to(device)
                Chem.SanitizeMol(mol)
                AllChem.EmbedMolecule(mol, useRandomCoords=True)
                try:
                    AllChem.UFFOptimizeMolecule(mol)
                except:
                    print('UFF error')
                anchor_pos_new = mol.GetConformer(0).GetPositions()[anchor]
                # atoms with atom-map number 2 are treated as the newly added ones
                # (presumably marked by forward_attach - confirm)
                new_idx = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 2]
                '''
                R, T = kabsch(np.matrix(anchor_pos), np.matrix(anchor_pos_new))
                new_pos = R * np.matrix(mol.GetConformer().GetPositions()[new_idx]).T + np.tile(T, (1, len(new_idx)))
                new_pos = np.array(new_pos.T)'''
                new_pos = mol.GetConformer().GetPositions()[new_idx]
                # map the freshly embedded new atoms into the existing frame by
                # aligning the embedded anchor atoms onto their stored positions
                new_pos, _, _ = kabsch_torch(torch.tensor(anchor_pos_new, device=device), anchor_pos, torch.tensor(new_pos, device=device))

                conf = mol.GetConformer()
                # update curated parameters
                pos_list[j] = torch.cat([pos_list[j], new_pos])
                feat_list[j] = get_feat(mol_list[j]).to(device)
                # write the stored coordinates back into the RDKit conformer
                for node in range(mol.GetNumAtoms()):
                    conf.SetAtomPosition(node, np.array(pos_list[j][node].cpu()))
                assert mol.GetNumAtoms() == len(pos_list[j])

            # predict alpha and rotate (only change the position)
            if torch.sum(rotatable) > 0 and i >= 2:
                repeats = torch.tensor([len(pos) for pos in pos_list])
                ligand_batch = torch.repeat_interleave(torch.arange(len(pos_list)), repeats).to(device)
                slice_idx = torch.cat([torch.tensor([0]), torch.cumsum(repeats, dim=0)])
                # flat atom indices of the focused (two-atom) motif of every rotatable sample
                xy_index = [(np.array(motif_to_atoms[j][motif_id[j]]) + slice_idx[j].item()).tolist() for j in range(len(slice_idx) - 1) if rotatable[j]]

                alpha = model.forward_alpha(protein_pos=batch['protein_pos'].float(),
                                            protein_atom_feature=batch['protein_atom_feature'].float(),
                                            ligand_pos=torch.cat(pos_list, dim=0).float(),
                                            ligand_atom_feature=torch.cat(feat_list, dim=0).float(),
                                            batch_protein=batch['protein_element_batch'],
                                            batch_ligand=ligand_batch, xy_index=torch.tensor(xy_index, device=device),
                                            rotatable=rotatable)

                rotatable_id = [id for id in range(len(mol_list)) if rotatable[id]]
                # same motif atoms as above, this time as per-molecule indices
                xy_index = [motif_to_atoms[j][motif_id[j]] for j in range(len(slice_idx) - 1) if rotatable[j]]
                x_index = [intersection[j] for j in range(len(slice_idx) - 1) if rotatable[j]]
                # y is the motif atom that is not in the intersection
                y_index = [(set(xy_index[k]) - set(x_index[k])).pop() for k in range(len(x_index))]

                for j in range(len(alpha)):
                    mol = mol_list[rotatable_id[j]]
                    new_idx = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 2]
                    positions = deepcopy(pos_list[rotatable_id[j]])

                    xn_pos = positions[new_idx].float()
                    # direction from y to x and the point x itself are handed to rand_rotate
                    dir=(positions[x_index[j]] - positions[y_index[j]]).reshape(-1)
                    ref=positions[x_index[j]].reshape(-1)
                    xn_pos = rand_rotate(dir.to(device), ref.to(device), xn_pos.to(device), alpha[j], device=device)
                    if xn_pos.shape[0] > 0:
                        # the new atoms occupy the tail of the position tensor
                        pos_list[rotatable_id[j]][-len(xn_pos):] = xn_pos
                    conf = mol.GetConformer()
                    for node in range(mol.GetNumAtoms()):
                        conf.SetAtomPosition(node, np.array(pos_list[rotatable_id[j]][node].cpu()))
                    assert mol.GetNumAtoms() == len(pos_list[rotatable_id[j]])

    return mol_list, pos_list
459
+
460
+
461
def demo(data_id):
    """Generate a ligand for one test-set pocket and save it as an SDF file.

    Everything runs on the CPU with the settings from ``./configs/sample.yml``
    and the vocabulary in ``vocab.txt``. The first generated molecule with at
    least 12 atoms and logP >= 0.60 is written to ``./data/Ligand.sdf``.

    :param data_id: index of the test-set entry; taken modulo 100
    :return: path of the written SDF file, or ``None`` when no generated
        molecule passes the filter
    """
    vocab_path = 'vocab.txt'
    device = 'cpu'
    config = './configs/sample.yml'
    # One motif per line, everything before the first ':'; the context manager
    # makes sure the file handle is closed.
    with open(vocab_path) as vocab_file:
        vocab = Vocab([line.partition(':')[0] for line in vocab_file])

    # Load configs
    config = load_config(config)

    # Data
    protein_featurizer = FeaturizeProteinAtom()
    ligand_featurizer = FeaturizeLigandAtom()
    masking = LigandMaskAll(vocab)
    transform = Compose([
        LigandCountNeighbors(),
        protein_featurizer,
        ligand_featurizer,
        FeaturizeLigandBond(),
        masking,
    ])
    _, subsets = get_dataset(
        config=config.dataset,
        transform=transform,
    )
    testset = subsets['test']
    data = testset[data_id % 100]
    center = data['ligand_center'].to(device)
    # the same pocket is repeated once per requested sample
    test_set = [data for _ in range(config.sample.num_samples)]

    # Model (Main)
    ckpt = torch.load(config.model.checkpoint, map_location=device)
    model = FLAG(
        ckpt['config'].model,
        protein_atom_feature_dim=protein_featurizer.feature_dim,
        ligand_atom_feature_dim=ligand_featurizer.feature_dim,
        vocab=vocab,
    ).to(device)
    model.load_state_dict(ckpt['model'])

    sample_loader = DataLoader(test_set, batch_size=config.sample.batch_size,
                               shuffle=False, num_workers=config.sample.num_workers,
                               collate_fn=collate_mols)

    with torch.no_grad():
        model.eval()
        for batch in tqdm(sample_loader):
            for key in batch:
                batch[key] = batch[key].to(device)
            gen_data, pos_list = ligand_gen(batch, model, vocab, config, center, device)
            SetMolPos(gen_data, pos_list)
            for mol in gen_data:
                try:
                    AllChem.UFFOptimizeMolecule(mol)
                except Exception:
                    print('UFF error')
            for mol in gen_data:
                if mol.GetNumAtoms() < 12 or MolLogP(mol) < 0.60:
                    continue
                filename = os.path.join('./data', 'Ligand.sdf')
                writer = Chem.SDWriter(filename)
                # writer.SetKekulize(False)
                writer.write(mol, confId=0)
                writer.close()
                # NOTE(review): indentation was lost in this copy of the file; the
                # return is taken to sit inside the loop, so the first molecule
                # that passes the filter ends the run - confirm.
                return filename
    # No generated molecule passed the filter.
    return None
532
+
533
+
534
+
535
if __name__ == '__main__':
    # Command-line entry point: sample ligands for one test-set pocket, save them
    # to a fresh log directory, print drug-likeness metrics and dock the saved
    # molecules with Vina.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./configs/sample.yml')
    parser.add_argument('-i', '--data_id', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--outdir', type=str, default='./outputs')
    parser.add_argument('--vocab_path', type=str, default='vocab.txt')
    parser.add_argument('--num_workers', type=int, default=64)
    args = parser.parse_args()

    # Load vocab (one motif per line, everything before the first ':')
    with open(args.vocab_path) as vocab_file:
        vocab = Vocab([line.partition(':')[0] for line in vocab_file])

    # Load configs
    config = load_config(args.config)
    config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
    seed_all(config.sample.seed)

    # Logging
    log_dir = get_new_log_dir(args.outdir, prefix='%s-%d' % (config_name, args.data_id))
    logger = get_logger('sample', log_dir)
    logger.info(args)
    logger.info(config)
    shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))

    # Data
    logger.info('Loading data...')
    protein_featurizer = FeaturizeProteinAtom()
    ligand_featurizer = FeaturizeLigandAtom()
    masking = LigandMaskAll(vocab)
    transform = Compose([
        LigandCountNeighbors(),
        protein_featurizer,
        ligand_featurizer,
        FeaturizeLigandBond(),
        masking,
    ])
    dataset, subsets = get_dataset(
        config=config.dataset,
        transform=transform,
    )
    testset = subsets['test']
    data = testset[args.data_id]
    center = data['ligand_center'].to(args.device)
    # the same pocket is repeated once per requested sample
    test_set = [data for _ in range(config.sample.num_samples)]

    with open(os.path.join(log_dir, 'pocket_info.txt'), 'a') as f:
        f.write(data['protein_filename'] + '\n')

    # Model (Main)
    logger.info('Loading main model...')
    ckpt = torch.load(config.model.checkpoint, map_location=args.device)
    model = FLAG(
        ckpt['config'].model,
        protein_atom_feature_dim=protein_featurizer.feature_dim,
        ligand_atom_feature_dim=ligand_featurizer.feature_dim,
        vocab=vocab,
    ).to(args.device)
    model.load_state_dict(ckpt['model'])

    sample_loader = DataLoader(test_set, batch_size=config.sample.batch_size,
                               shuffle=False, num_workers=config.sample.num_workers,
                               collate_fn=collate_mols)
    data_list = []
    # Initialised before the ``try`` so the docking stage below always finds them,
    # even when sampling is interrupted immediately.
    number = 0
    number_list = []
    try:
        with torch.no_grad():
            model.eval()
            for batch in tqdm(sample_loader):
                for key in batch:
                    batch[key] = batch[key].to(args.device)
                gen_data, pos_list = ligand_gen(batch, model, vocab, config, center, args.device)
                SetMolPos(gen_data, pos_list)
                for mol in gen_data:
                    try:
                        AllChem.UFFOptimizeMolecule(mol)
                    except Exception:
                        print('UFF error')
                data_list.extend(gen_data)
                with open(os.path.join(log_dir, 'SMILES.txt'), 'a') as smiles_f:
                    for mol in gen_data:
                        # ``number`` counts every generated molecule, kept or not
                        number += 1
                        if mol.GetNumAtoms() < 12 or MolLogP(mol) < 0.60:
                            continue
                        smiles_f.write(Chem.MolToSmiles(mol) + '\n')
                        writer = Chem.SDWriter(os.path.join(log_dir, '%d.sdf' % number))
                        # writer.SetKekulize(False)
                        writer.write(mol, confId=0)
                        writer.close()
                        number_list.append(number)

        # Calculate metrics
        # NOTE(review): indentation was lost in this copy of the file; the metrics
        # are taken to run once, after all batches - confirm.
        print([Chem.MolToSmiles(mol) for mol in data_list])
        smiles = [Chem.MolFromSmiles(Chem.MolToSmiles(mol)) for mol in data_list]
        qed_list = [qed(mol) for mol in smiles if mol.GetNumAtoms() >= 8]
        logp_list = [MolLogP(mol) for mol in smiles]
        sa_list = [compute_sa_score(mol) for mol in smiles]
        Lip_list = [lipinski(mol) for mol in smiles]
        print('QED %.6f | LogP %.6f | SA %.6f | Lipinski %.6f \n' % (np.average(qed_list), np.average(logp_list), np.average(sa_list), np.average(Lip_list)))

    except KeyboardInterrupt:
        logger.info('Terminated. Generated molecules will be saved.')
        # These files are named by position in ``data_list``, unlike the files
        # written above, which are named by the running ``number`` counter.
        with open(os.path.join(log_dir, 'SMILES.txt'), 'a') as smiles_f:
            for i, mol in enumerate(data_list):
                if mol.GetNumAtoms() < 12 or MolLogP(mol) < 0.60:
                    continue
                smiles_f.write(Chem.MolToSmiles(mol) + '\n')
                writer = Chem.SDWriter(os.path.join(log_dir, '%d.sdf' % i))
                # writer.SetKekulize(False)
                writer.write(mol, confId=0)
                writer.close()

    # Dock every saved molecule in parallel.
    pool = mp.Pool(args.num_workers)
    vina_list = []
    # NOTE(review): hard-coded, cluster-specific location of the pocket files.
    pro_path = '/n/holyscratch01/mzitnik_lab/zaixizhang/pdbbind_pocket10/' + os.path.join(data['pdbid'], data['pdbid']+'_pocket.pdb')
    for vina_score in tqdm(pool.imap_unordered(partial(calculate_vina, pro_path=pro_path, lig_path=log_dir), number_list), total=len(number_list)):
        if vina_score is not None:
            vina_list.append(vina_score)
    pool.close()
    pool.join()
    print('Vina: ', np.average(vina_list))
requirements.txt ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file may be used to create an environment using:
2
+ # $ conda create --name <env> --file <this file>
3
+ # platform: linux-64
4
+ _libgcc_mutex=0.1=conda_forge
5
+ _openmp_mutex=4.5=2_kmp_llvm
6
+ abseil-cpp=20211102.0=hd4dd3e8_0
7
+ absl-py=1.4.0=py38h06a4308_0
8
+ aiofiles=23.2.1=pypi_0
9
+ aiohttp=3.8.5=py38h5eee18b_0
10
+ aiosignal=1.2.0=pyhd3eb1b0_0
11
+ altair=5.1.2=pypi_0
12
+ amberlite=22.0=pypi_0
13
+ ambertools=22.0=py38h6177452_1
14
+ amberutils=21.0=pypi_0
15
+ annotated-types=0.6.0=pypi_0
16
+ antlr4-python3-runtime=4.9.3=pypi_0
17
+ anyio=3.7.1=pypi_0
18
+ appdirs=1.4.4=pyhd3eb1b0_0
19
+ argon2-cffi=21.3.0=pyhd3eb1b0_0
20
+ argon2-cffi-bindings=21.2.0=py38h7f8727e_0
21
+ arpack=3.7.0=hdefa2d7_2
22
+ arrow-cpp=11.0.0=hda39474_2
23
+ asttokens=2.0.5=pyhd3eb1b0_0
24
+ astunparse=1.6.3=py_0
25
+ async-timeout=4.0.2=py38h06a4308_0
26
+ attrs=22.1.0=py38h06a4308_0
27
+ autodocktools-py3=1.5.7.post1+9.gda0c87c=pypi_0
28
+ aws-c-common=0.4.57=he6710b0_1
29
+ aws-c-event-stream=0.1.6=h2531618_5
30
+ aws-checksums=0.1.9=he6710b0_0
31
+ aws-sdk-cpp=1.8.185=hce553d0_0
32
+ backcall=0.2.0=pyhd3eb1b0_0
33
+ beautifulsoup4=4.12.2=py38h06a4308_0
34
+ bio=1.5.9=pypi_0
35
+ biopython=1.81=pypi_0
36
+ biothings-client=0.3.0=pypi_0
37
+ blas=1.0=mkl
38
+ bleach=4.1.0=pyhd3eb1b0_0
39
+ blinker=1.4=py38h06a4308_0
40
+ blosc=1.21.3=h6a678d5_0
41
+ boost=1.74.0=py38h2b96118_5
42
+ boost-cpp=1.74.0=h75c5d50_8
43
+ bottleneck=1.3.5=py38h7deecbd_0
44
+ brotli=1.0.9=h5eee18b_7
45
+ brotli-bin=1.0.9=h5eee18b_7
46
+ brotlipy=0.7.0=py38h27cfd23_1003
47
+ bzip2=1.0.8=h7b6447c_0
48
+ c-ares=1.19.1=h5eee18b_0
49
+ c-blosc2=2.8.0=h6a678d5_0
50
+ ca-certificates=2023.08.22=h06a4308_0
51
+ cachetools=4.2.2=pyhd3eb1b0_0
52
+ cairo=1.16.0=hb05425b_5
53
+ certifi=2023.7.22=py38h06a4308_0
54
+ cffi=1.15.1=py38h74dc2b5_0
55
+ charset-normalizer=2.0.4=pyhd3eb1b0_0
56
+ click=8.0.4=py38h06a4308_0
57
+ comm=0.1.2=py38h06a4308_0
58
+ contourpy=1.0.5=py38hdb19cb5_0
59
+ cryptography=41.0.3=py38h130f0dd_0
60
+ cuda-cudart=11.7.99=0
61
+ cuda-cupti=11.7.101=0
62
+ cuda-libraries=11.7.1=0
63
+ cuda-nvrtc=11.7.99=0
64
+ cuda-nvtx=11.7.91=0
65
+ cuda-runtime=11.7.1=0
66
+ cudatoolkit=11.8.0=h6a678d5_0
67
+ curl=8.2.1=h37d81fd_0
68
+ cycler=0.11.0=pyhd3eb1b0_0
69
+ cython=3.0.2=py38h17151c0_0
70
+ dataclasses=0.8=pyh6d0b6a4_7
71
+ datamol=0.11.4=pypi_0
72
+ datasets=2.12.0=py38h06a4308_0
73
+ debugpy=1.6.7=py38h6a678d5_0
74
+ decorator=5.1.1=pyhd3eb1b0_0
75
+ defusedxml=0.7.1=pyhd3eb1b0_0
76
+ dill=0.3.6=py38h06a4308_0
77
+ docutils=0.17.1=pypi_0
78
+ easydict=1.9=py_0
79
+ entrypoints=0.4=py38h06a4308_0
80
+ exceptiongroup=1.1.3=pypi_0
81
+ executing=0.8.3=pyhd3eb1b0_0
82
+ expat=2.5.0=h6a678d5_0
83
+ fair-esm=2.0.0=pypi_0
84
+ fastapi=0.104.0=pypi_0
85
+ fasteners=0.19=pypi_0
86
+ ffmpeg=4.2.2=h20bf706_0
87
+ ffmpy=0.3.1=pypi_0
88
+ fftw=3.3.10=nompi_hf0379b8_106
89
+ filelock=3.9.0=py38h06a4308_0
90
+ fontconfig=2.14.2=h14ed4e7_0
91
+ fonttools=4.25.0=pyhd3eb1b0_0
92
+ freetype=2.12.1=h4a9f257_0
93
+ frozenlist=1.3.3=py38h5eee18b_0
94
+ fsspec=2023.4.0=py38h06a4308_0
95
+ gflags=2.2.2=he6710b0_0
96
+ giflib=5.2.1=h5eee18b_3
97
+ glib=2.69.1=h4ff587b_1
98
+ glog=0.5.0=h2531618_0
99
+ gmp=6.2.1=h295c915_3
100
+ gnutls=3.6.15=he1e5248_0
101
+ google-auth=2.22.0=py38h06a4308_0
102
+ google-auth-oauthlib=0.5.2=py38h06a4308_0
103
+ gprofiler-official=1.0.0=pypi_0
104
+ gradio=3.50.2=pypi_0
105
+ gradio-client=0.6.1=pypi_0
106
+ greenlet=2.0.1=py38h6a678d5_0
107
+ griddataformats=1.0.1=pypi_0
108
+ grpc-cpp=1.48.2=h5bf31a4_0
109
+ grpcio=1.48.2=py38h5bf31a4_0
110
+ gsd=3.1.1=pypi_0
111
+ h11=0.14.0=pypi_0
112
+ hdf4=4.2.15=h9772cbc_5
113
+ hdf5=1.12.1=nompi_h2386368_104
114
+ httpcore=0.18.0=pypi_0
115
+ httpx=0.25.0=pypi_0
116
+ huggingface_hub=0.15.1=py38h06a4308_0
117
+ icu=70.1=h27087fc_0
118
+ idna=3.4=py38h06a4308_0
119
+ importlib-metadata=6.0.0=py38h06a4308_0
120
+ importlib_metadata=6.0.0=hd3eb1b0_0
121
+ importlib_resources=5.2.0=pyhd3eb1b0_1
122
+ intel-openmp=2021.4.0=h06a4308_3561
123
+ ipykernel=6.25.0=py38h2f386ee_0
124
+ ipython=8.12.2=py38h06a4308_0
125
+ ipython_genutils=0.2.0=pyhd3eb1b0_1
126
+ jedi=0.18.1=py38h06a4308_1
127
+ jinja2=3.1.2=py38h06a4308_0
128
+ joblib=1.2.0=py38h06a4308_0
129
+ jpeg=9e=h5eee18b_1
130
+ jsonschema=4.17.3=py38h06a4308_0
131
+ jupyter_client=7.4.9=py38h06a4308_0
132
+ jupyter_core=5.3.0=py38h06a4308_0
133
+ jupyter_server=1.23.4=py38h06a4308_0
134
+ jupyterlab_pygments=0.1.2=py_0
135
+ kiwisolver=1.4.4=py38h6a678d5_0
136
+ krb5=1.20.1=h568e23c_1
137
+ lame=3.100=h7b6447c_0
138
+ lcms2=2.12=h3be6417_0
139
+ ld_impl_linux-64=2.38=h1181459_1
140
+ lerc=3.0=h295c915_0
141
+ libblas=3.9.0=12_linux64_mkl
142
+ libbrotlicommon=1.0.9=h5eee18b_7
143
+ libbrotlidec=1.0.9=h5eee18b_7
144
+ libbrotlienc=1.0.9=h5eee18b_7
145
+ libcublas=11.10.3.66=0
146
+ libcufft=10.7.2.124=h4fbf590_0
147
+ libcufile=1.7.1.12=0
148
+ libcurand=10.3.3.129=0
149
+ libcurl=8.2.1=h91b91d3_0
150
+ libcusolver=11.4.0.1=0
151
+ libcusparse=11.7.4.91=0
152
+ libdeflate=1.17=h5eee18b_0
153
+ libedit=3.1.20221030=h5eee18b_0
154
+ libev=4.33=h7f8727e_1
155
+ libevent=2.1.12=h8f2d780_0
156
+ libffi=3.3=he6710b0_2
157
+ libgcc-ng=13.1.0=he5830b7_0
158
+ libgfortran-ng=11.2.0=h00389a5_1
159
+ libgfortran5=11.2.0=h1234567_1
160
+ libgomp=13.1.0=he5830b7_0
161
+ libiconv=1.16=h7f8727e_2
162
+ libidn2=2.3.4=h5eee18b_0
163
+ liblapack=3.9.0=12_linux64_mkl
164
+ libnetcdf=4.8.1=nompi_h329d8a1_102
165
+ libnghttp2=1.52.0=ha637b67_1
166
+ libnpp=11.7.4.75=0
167
+ libnsl=2.0.0=h7f98852_0
168
+ libnvjpeg=11.8.0.2=0
169
+ libopus=1.3.1=h7b6447c_0
170
+ libpng=1.6.39=h5eee18b_0
171
+ libprotobuf=3.20.3=he621ea3_0
172
+ libsodium=1.0.18=h7b6447c_0
173
+ libssh2=1.10.0=h37d81fd_2
174
+ libstdcxx-ng=13.1.0=hfd8a6a1_0
175
+ libtasn1=4.19.0=h5eee18b_0
176
+ libthrift=0.15.0=h0d84882_2
177
+ libtiff=4.5.1=h6a678d5_0
178
+ libunistring=0.9.10=h27cfd23_0
179
+ libuuid=2.38.1=h0b41bf4_0
180
+ libvpx=1.7.0=h439df22_0
181
+ libwebp=1.2.4=h11a3e52_1
182
+ libwebp-base=1.2.4=h5eee18b_1
183
+ libxcb=1.15=h7f8727e_0
184
+ libxml2=2.9.14=h22db469_4
185
+ libxslt=1.1.35=h8affb1d_0
186
+ libzip=1.9.2=hc869a4a_1
187
+ libzlib=1.2.13=hd590300_5
188
+ littleutils=0.2.2=pypi_0
189
+ llvm-openmp=14.0.6=h9e868ea_0
190
+ loguru=0.7.2=pypi_0
191
+ lxml=4.9.1=py38h1edc446_0
192
+ lz4-c=1.9.4=h6a678d5_0
193
+ lzo=2.10=h7b6447c_2
194
+ markdown=3.4.1=py38h06a4308_0
195
+ markupsafe=2.1.1=py38h7f8727e_0
196
+ matplotlib-base=3.7.2=py38h1128e8f_0
197
+ matplotlib-inline=0.1.6=py38h06a4308_0
198
+ mdanalysis=2.4.3=pypi_0
199
+ mdtraj=1.9.9=py38h028faf2_0
200
+ meeko=0.1.dev3=pypi_0
201
+ mistune=0.8.4=py38h7b6447c_1000
202
+ mkl=2021.4.0=h06a4308_640
203
+ mkl-service=2.4.0=py38h7f8727e_0
204
+ mkl_fft=1.3.1=py38hd3c417c_0
205
+ mkl_random=1.2.2=py38h51133e4_0
206
+ mmcif-pdbx=2.0.1=pypi_0
207
+ mmpbsa-py=16.0=pypi_0
208
+ mmtf-python=1.1.3=pypi_0
209
+ mrcfile=1.4.3=pypi_0
210
+ mscorefonts=0.0.1=3
211
+ msgpack=1.0.6=pypi_0
212
+ multidict=6.0.2=py38h5eee18b_0
213
+ multiprocess=0.70.14=py38h06a4308_0
214
+ munkres=1.1.4=py_0
215
+ mygene=3.2.2=pypi_0
216
+ nb_conda_kernels=2.3.1=py38h06a4308_0
217
+ nbclassic=0.5.5=py38h06a4308_0
218
+ nbclient=0.5.13=py38h06a4308_0
219
+ nbconvert=6.5.4=py38h06a4308_0
220
+ nbformat=5.9.2=py38h06a4308_0
221
+ ncurses=6.4=h6a678d5_0
222
+ nest-asyncio=1.5.6=py38h06a4308_0
223
+ netcdf-fortran=4.5.4=nompi_h2b6e579_100
224
+ nettle=3.7.3=hbbd107a_1
225
+ networkx=3.1=pyhd8ed1ab_0
226
+ notebook=6.5.4=py38h06a4308_1
227
+ notebook-shim=0.2.2=py38h06a4308_0
228
+ numexpr=2.8.4=py38he184ba9_0
229
+ numpy=1.24.3=py38h14f4228_0
230
+ numpy-base=1.24.3=py38h31eccc5_0
231
+ oauthlib=3.2.2=py38h06a4308_0
232
+ ocl-icd=2.3.1=h7f98852_0
233
+ ocl-icd-system=1.0.0=1
234
+ ogb=1.3.6=pypi_0
235
+ omegaconf=2.3.0=pypi_0
236
+ openbabel=3.1.1=py38hd2c4bc0_3
237
+ openff-forcefields=2023.08.0=pyh1a96a4e_0
238
+ openff-toolkit=0.10.7=pyhd8ed1ab_0
239
+ openff-toolkit-base=0.10.7=pyhd8ed1ab_0
240
+ openh264=2.1.1=h4ff587b_0
241
+ openmm=8.0.0=py38hd11a18e_1
242
+ openssl=1.1.1w=h7f8727e_0
243
+ opt-einsum=3.3.0=pypi_0
244
+ orc=1.7.4=hb3bc3d3_1
245
+ orjson=3.9.9=pypi_0
246
+ outdated=0.2.2=pypi_0
247
+ packaging=23.1=py38h06a4308_0
248
+ packmol=20.010=h86c2bf4_0
249
+ packmol-memgen=1.2.3rc0=pypi_0
250
+ pandas=2.0.3=py38h1128e8f_0
251
+ pandocfilters=1.5.0=pyhd3eb1b0_0
252
+ parmed=3.4.4=py38h8dc9893_0
253
+ parso=0.8.3=pyhd3eb1b0_0
254
+ pcre=8.45=h295c915_0
255
+ pdb2pqr=3.6.1=pypi_0
256
+ pdb4amber=22.0=pypi_0
257
+ pdbfixer=1.9=pyh1a96a4e_0
258
+ perl=5.32.1=4_hd590300_perl5
259
+ pexpect=4.8.0=pyhd3eb1b0_3
260
+ pickleshare=0.7.5=pyhd3eb1b0_1003
261
+ pillow=9.4.0=py38h6a678d5_0
262
+ pip=23.2.1=py38h06a4308_0
263
+ pixman=0.40.0=h7f8727e_1
264
+ pkgutil-resolve-name=1.3.10=py38h06a4308_0
265
+ platformdirs=3.10.0=py38h06a4308_0
266
+ pooch=1.4.0=pyhd3eb1b0_0
267
+ posebusters=0.2.6=pypi_0
268
+ posecheck=0.1=dev_0
269
+ prolif=2.0.0.post1=pypi_0
270
+ prometheus_client=0.14.1=py38h06a4308_0
271
+ prompt-toolkit=3.0.36=py38h06a4308_0
272
+ propka=3.5.0=pypi_0
273
+ protobuf=3.20.3=py38h6a678d5_0
274
+ psutil=5.9.0=py38h5eee18b_0
275
+ ptyprocess=0.7.0=pyhd3eb1b0_2
276
+ pure_eval=0.2.2=pyhd3eb1b0_0
277
+ py-cpuinfo=8.0.0=pyhd3eb1b0_1
278
+ py3dmol=2.0.4=pypi_0
279
+ pyarrow=11.0.0=py38h468efa6_1
280
+ pyasn1=0.4.8=pyhd3eb1b0_0
281
+ pyasn1-modules=0.2.8=py_0
282
+ pycairo=1.23.0=py38hd1222b9_0
283
+ pycparser=2.21=pyhd3eb1b0_0
284
+ pydantic=2.4.2=pypi_0
285
+ pydantic-core=2.10.1=pypi_0
286
+ pydub=0.25.1=pypi_0
287
+ pyg-lib=0.2.0+pt113cu117=pypi_0
288
+ pygments=2.15.1=py38h06a4308_1
289
+ pygsp=0.5.1=pypi_0
290
+ pyjwt=2.4.0=py38h06a4308_0
291
+ pyopenssl=23.2.0=py38h06a4308_0
292
+ pyparsing=3.0.9=py38h06a4308_0
293
+ pyrsistent=0.18.0=py38heee7806_0
294
+ pysocks=1.7.1=py38h06a4308_0
295
+ pytables=3.8.0=py38hb8ae3fc_3
296
+ python=3.8.11=h12debd9_0_cpython
297
+ python-constraint=1.4.0=py_0
298
+ python-dateutil=2.8.2=pyhd3eb1b0_0
299
+ python-fastjsonschema=2.16.2=py38h06a4308_0
300
+ python-lmdb=1.4.1=py38h6a678d5_0
301
+ python-multipart=0.0.6=pypi_0
302
+ python-tzdata=2023.3=pyhd3eb1b0_0
303
+ python-xxhash=2.0.2=py38h5eee18b_1
304
+ python_abi=3.8=3_cp38
305
+ pytorch=1.13.0=py3.8_cuda11.7_cudnn8.5.0_0
306
+ pytorch-cuda=11.7=h778d358_5
307
+ pytorch-mutex=1.0=cuda
308
+ pytraj=2.0.6=pypi_0
309
+ pytz=2022.7=py38h06a4308_0
310
+ pyyaml=6.0=py38h5eee18b_1
311
+ pyzmq=23.2.0=py38h6a678d5_0
312
+ qvina=2.1.0=h62396cd_2
313
+ rdkit=2023.3.3=pypi_0
314
+ re2=2022.04.01=h295c915_0
315
+ readline=8.2=h5eee18b_0
316
+ reduce=3.24=0
317
+ regex=2022.7.9=py38h5eee18b_0
318
+ reportlab=3.5.67=py38hfdd840d_1
319
+ requests=2.31.0=py38h06a4308_0
320
+ requests-oauthlib=1.3.0=py_0
321
+ responses=0.13.3=pyhd3eb1b0_0
322
+ rsa=4.7.2=pyhd3eb1b0_1
323
+ sacremoses=0.0.43=pyhd3eb1b0_0
324
+ safetensors=0.3.2=py38hb02cf49_0
325
+ sander=22.0=pypi_0
326
+ scikit-learn=1.3.0=py38h1128e8f_0
327
+ scipy=1.10.1=py38h14f4228_0
328
+ seaborn=0.12.2=pypi_0
329
+ selfies=2.1.1=pypi_0
330
+ semantic-version=2.10.0=pypi_0
331
+ send2trash=1.8.0=pyhd3eb1b0_1
332
+ setuptools=68.0.0=py38h06a4308_0
333
+ six=1.16.0=pyhd3eb1b0_1
334
+ smirnoff99frosst=1.1.0=pyh44b312d_0
335
+ snappy=1.1.9=h295c915_0
336
+ sniffio=1.2.0=py38h06a4308_1
337
+ sortedcontainers=2.4.0=pypi_0
338
+ soupsieve=2.4=py38h06a4308_0
339
+ sqlalchemy=1.4.39=py38h5eee18b_0
340
+ sqlite=3.41.2=h5eee18b_0
341
+ stack_data=0.2.0=pyhd3eb1b0_0
342
+ starlette=0.27.0=pypi_0
343
+ tensorboard=2.12.1=py38h06a4308_0
344
+ tensorboard-data-server=0.7.0=py38h52d8a92_0
345
+ tensorboard-plugin-wit=1.8.1=py38h06a4308_0
346
+ terminado=0.17.1=py38h06a4308_0
347
+ threadpoolctl=2.2.0=pyh0d69192_0
348
+ tinycss2=1.2.1=py38h06a4308_0
349
+ tk=8.6.12=h1ccaba5_0
350
+ tokenizers=0.13.2=py38he7d60b5_1
351
+ toolz=0.12.0=pypi_0
352
+ torch-cluster=1.6.1+pt113cu117=pypi_0
353
+ torch-geometric=2.3.1=pypi_0
354
+ torch-scatter=2.1.1+pt113cu117=pypi_0
355
+ torch-sparse=0.6.17+pt113cu117=pypi_0
356
+ torch-spline-conv=1.2.2+pt113cu117=pypi_0
357
+ tornado=6.3.2=py38h5eee18b_0
358
+ tqdm=4.65.0=py38hb070fc8_0
359
+ traitlets=5.7.1=py38h06a4308_0
360
+ transformers=4.31.0=pyhd8ed1ab_0
361
+ typing-extensions=4.8.0=pypi_0
362
+ urllib3=1.26.16=py38h06a4308_0
363
+ utf8proc=2.6.1=h27cfd23_0
364
+ uvicorn=0.23.2=pypi_0
365
+ vina=1.2.2=pypi_0
366
+ wcwidth=0.2.5=pyhd3eb1b0_0
367
+ webencodings=0.5.1=py38_1
368
+ websocket-client=0.58.0=py38h06a4308_4
369
+ websockets=11.0.3=pypi_0
370
+ werkzeug=2.2.3=py38h06a4308_0
371
+ wheel=0.38.4=py38h06a4308_0
372
+ x264=1!157.20191217=h7b6447c_0
373
+ xmltodict=0.13.0=pyhd8ed1ab_0
374
+ xorg-kbproto=1.0.7=h7f98852_1002
375
+ xorg-libice=1.1.1=hd590300_0
376
+ xorg-libsm=1.2.4=h7391055_0
377
+ xorg-libx11=1.8.6=h8ee46fc_0
378
+ xorg-libxext=1.3.4=h0b41bf4_2
379
+ xorg-libxt=1.3.0=hd590300_1
380
+ xorg-xextproto=7.3.0=h0b41bf4_1003
381
+ xorg-xproto=7.0.31=h7f98852_1007
382
+ xxhash=0.8.0=h7f8727e_3
383
+ xz=5.2.10=h5eee18b_1
384
+ yaml=0.2.5=h7b6447c_0
385
+ yarl=1.8.1=py38h5eee18b_0
386
+ zeromq=4.3.4=h2531618_0
387
+ zipp=3.11.0=py38h06a4308_0
388
+ zlib=1.2.13=hd590300_5
389
+ zlib-ng=2.0.7=h5eee18b_0
390
+ zstd=1.5.5=hc292b87_0
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .dihedral_utils import batch_dihedrals, rotation_matrix_v2, von_Mises_loss
2
+ from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, get_clique_mol_simple, assemble, mol_to_graph_data_obj_simple
3
+ from .dihedral_utils import rotation_matrix_v2, von_Mises_loss, batch_dihedrals
utils/chem.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ from io import BytesIO
4
+ from openbabel import openbabel
5
+ from torch_geometric.utils import to_networkx
6
+ from torch_geometric.data import Data
7
+ from torch_scatter import scatter
8
+ from rdkit import Chem
9
+ from rdkit.Chem.rdchem import Mol, HybridizationType, BondType
10
+ from rdkit.Chem.rdchem import BondType as BT
11
+
12
+
13
# Lookup tables between RDKit bond types and integer labels; the label of a
# bond type is its position in ``BondType.names``.
BOND_TYPES = {bond_type: index for index, bond_type in enumerate(BT.names.values())}
BOND_NAMES = dict(enumerate(BT.names.keys()))
15
+
16
+
17
+
18
def rdmol_to_data(mol, smiles=None):
    """Convert an RDKit molecule with exactly one conformer into a PyG ``Data``.

    Args:
        mol: RDKit molecule carrying a single conformer.
        smiles: Optional SMILES string to store; when omitted it is derived
            from ``mol`` after removing hydrogens.

    Returns:
        A ``Data`` object with ``atom_type`` (atomic numbers), ``pos``
        (conformer coordinates as float32), ``edge_index`` / ``edge_type``
        (every bond stored in both directions, ordered by source index and
        then target index), ``rdmol`` (a deep copy of ``mol``), ``smiles``,
        and ``nx`` (an undirected networkx view of the graph).

    Raises:
        AssertionError: if ``mol`` does not have exactly one conformer.
    """
    assert mol.GetNumConformers() == 1
    N = mol.GetNumAtoms()

    pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)
    z = torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.long)

    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    # Explicit dtype: an empty list would otherwise produce a float tensor.
    edge_type = torch.tensor(edge_type, dtype=torch.long)

    # Sort the directed edges by (source, target).
    perm = (edge_index[0] * N + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    if smiles is None:
        smiles = Chem.MolToSmiles(Chem.RemoveHs(mol))

    data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type,
                rdmol=copy.deepcopy(mol), smiles=smiles)
    data.nx = to_networkx(data, to_undirected=True)

    return data
67
+
68
+
69
def generated_to_xyz(data):
    """Render the ligand context stored on ``data`` as XYZ-format text.

    Reads ``data.ligand_context_element`` (one element number per atom) and
    ``data.ligand_context_pos`` (one coordinate triple per atom) and returns
    the atom count, a blank comment line, and one ``symbol x y z`` line per
    atom.
    """
    periodic_table = Chem.GetPeriodicTable()

    atom_count = data.ligand_context_element.size(0)
    chunks = ["%d\n\n" % (atom_count, )]
    for idx in range(atom_count):
        symbol = periodic_table.GetElementSymbol(data.ligand_context_element[idx].item())
        x, y, z = data.ligand_context_pos[idx].clone().cpu().tolist()
        chunks.append("%s %.8f %.8f %.8f\n" % (symbol, x, y, z))

    return "".join(chunks)
80
+
81
+
82
def generated_to_sdf(data):
    """Convert the generated ligand on ``data`` to SDF text via Open Babel.

    The ligand is first rendered as XYZ text and then read into an ``OBMol``
    and written back out in SDF format.
    """
    converter = openbabel.OBConversion()
    converter.SetInAndOutFormats("xyz", "sdf")

    ob_mol = openbabel.OBMol()
    converter.ReadString(ob_mol, generated_to_xyz(data))
    return converter.WriteString(ob_mol)
91
+
92
+
93
def sdf_to_rdmol(sdf):
    """Parse SDF text and return its first record as an RDKit molecule.

    Returns ``None`` when the text contains no records.
    """
    supplier = Chem.ForwardSDMolSupplier(BytesIO(sdf.encode()))
    return next(iter(supplier), None)
99
+
100
def generated_to_rdmol(data):
    """Build an RDKit molecule from generated data by going through SDF text."""
    return sdf_to_rdmol(generated_to_sdf(data))
103
+
104
+
105
def filter_rd_mol(rdmol):
    """Return ``False`` when two distinct three-membered rings share an atom.

    Every other molecule, including one without rings, passes and yields
    ``True``.
    """
    atom_sets = [set(ring) for ring in rdmol.GetRingInfo().AtomRings()]
    triangles = [atoms for atoms in atom_sets if len(atoms) == 3]

    # 3-3 ring intersection: compare each triangle with the ones before it.
    for position, current in enumerate(triangles):
        for earlier in triangles[:position]:
            if len(current & earlier) > 0:
                return False

    return True
utils/chemutils.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import rdkit
2
+ import rdkit.Chem as Chem
3
+ from scipy.sparse import csr_matrix
4
+ from scipy.sparse.csgraph import minimum_spanning_tree
5
+ from collections import defaultdict
6
+ from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
7
+ from rdkit.Chem.Descriptors import MolLogP, qed
8
+ from torch_geometric.data import Data, Batch
9
+ from random import sample
10
+ from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
11
+ import numpy as np
12
+ from math import sqrt
13
+ import torch
14
+ from copy import deepcopy
15
+ MST_MAX_WEIGHT = 100
16
+ MAX_NCAND = 2000
17
+
18
+
19
def vina_score(mol, use_uff=True):
    """Prepare a ligand by adding hydrogens and optionally relaxing it with UFF.

    Args:
        mol: RDKit molecule with 3D coordinates.
        use_uff: When true, run a UFF optimisation on the hydrogenated copy.
            This used to be read from an undefined name, which made every
            call fail with ``NameError``.

    Returns:
        The hydrogenated (and optionally optimised) copy of ``mol``.

    NOTE(review): despite its name, this function does not compute a docking
    score; it only performs the ligand preparation step.
    """
    ligand_rdmol = Chem.AddHs(mol, addCoords=True)
    if use_uff:
        UFFOptimizeMolecule(ligand_rdmol)
    return ligand_rdmol
23
+
24
def lipinski(mol):
    """Return ``True`` when ``mol`` satisfies all drug-likeness thresholds.

    The thresholds are: logP <= 5, H-bond donors <= 5, H-bond acceptors <= 10,
    exact molecular weight <= 500 and rotatable bonds <= 5.

    The lipophilicity criterion is evaluated with ``MolLogP``. It was
    previously evaluated with ``qed``, whose value never exceeds 1, so that
    condition could not fail.
    """
    return bool(
        MolLogP(mol) <= 5
        and Chem.Lipinski.NumHDonors(mol) <= 5
        and Chem.Lipinski.NumHAcceptors(mol) <= 10
        and Chem.Descriptors.ExactMolWt(mol) <= 500
        and Chem.Lipinski.NumRotatableBonds(mol) <= 5
    )
29
+
30
+
31
def list_filter(a,b):
    """Return the items of ``a`` that also occur in ``b``.

    The order of ``a`` is kept and repeated items are repeated in the result.
    """
    return [item for item in a if item in b]
37
+
38
+
39
def rand_rotate(dir, ref, pos, alpha=None, device=None):
    """Rotate points about the axis with direction ``dir`` through point ``ref``.

    Args:
        dir: 3-vector giving the axis direction; it is normalised here.
        ref: 3-vector, a point the rotation axis passes through.
        pos: Tensor of points with one row per point and three columns.
        alpha: Rotation angle in radians. When omitted, a single angle is
            drawn with ``torch.randn``.
        device: Device for the tensors created here; defaults to ``'cpu'``.

    Returns:
        The rotated points, one row per point.
    """
    device = 'cpu' if device is None else device
    axis = dir / torch.norm(dir)
    angle = torch.randn(1).to(device) if alpha is None else alpha
    num_points = pos.shape[0]

    s = torch.sin(angle).to(device)
    c = torch.cos(angle).to(device)
    k = 1 - c
    proj = torch.dot(axis, ref)
    ax, ay, az = axis[0], axis[1], axis[2]
    px, py, pz = ref[0], ref[1], ref[2]

    # Row-major entries of the 4x4 homogeneous transform: the 3x3 rotation in
    # the left columns and the translation induced by ``ref`` in the last one.
    entries = [
        ax ** 2 * k + c, ax * ay * k - az * s, ax * az * k + ay * s,
        (px - ax * proj) * k + (az * py - ay * pz) * s,
        ax * ay * k + az * s, ay ** 2 * k + c, ay * az * k - ax * s,
        (py - ay * proj) * k + (ax * pz - az * px) * s,
        ax * az * k - ay * s, ay * az * k + ax * s, az ** 2 * k + c,
        (pz - az * proj) * k + (ay * px - ax * py) * s,
        0, 0, 0, 1,
    ]
    transform = torch.tensor(entries, device=device).reshape(4, 4)

    ones_row = torch.ones(num_points, device=device).unsqueeze(0)
    homogeneous = torch.cat([pos.t(), ones_row], dim=0)
    return torch.mm(transform, homogeneous)[:3].t()
61
+
62
+
63
def kabsch(A, B):
    """Find the rigid transform that best maps point set ``B`` onto ``A``.

    Args:
        A: Nominal points, N x 3.
        B: Measured points, N x 3, in the same order as ``A``.

    Returns:
        ``(R, t)`` where ``R`` is the 3 x 3 rotation and ``t`` the translation
        such that ``R @ b + t`` approximates the matching point of ``A``.

    Raises:
        AssertionError: if ``A`` and ``B`` differ in length.

    Matrix products are written with ``@`` so that plain ``ndarray`` inputs
    work; the previous ``*`` was only a matrix product for ``np.matrix``.
    """
    assert len(A) == len(B)
    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)
    # center the points
    AA = A - centroid_A
    BB = B - centroid_B
    H = np.transpose(BB) @ AA
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    # special reflection case: flip the last singular direction to keep det(R) = +1
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T
    t = -R @ centroid_B.T + centroid_A.T
    return R, t
86
+
87
+
88
def kabsch_torch(A, B, C):
    """Align ``A`` onto ``B`` with a rigid transform and apply it to ``C``.

    Args:
        A: Source points, one row per point with three columns.
        B: Target points matching ``A`` row by row.
        C: Points to transform with the fitted rotation and translation.

    Returns:
        ``(C_aligned, R, t)``: the transformed ``C``, the 3 x 3 rotation and
        the 1 x 3 translation, all in double precision.

    The rotation is forced to be proper (determinant +1): when the
    unconstrained optimum is a reflection, the last singular direction is
    flipped.
    """
    A = A.double()
    B = B.double()
    C = C.double()
    a_mean = A.mean(dim=0, keepdim=True)
    b_mean = B.mean(dim=0, keepdim=True)
    A_c = A - a_mean
    B_c = B - b_mean
    # Covariance matrix, 3 x 3
    H = torch.matmul(A_c.transpose(0, 1), B_c)
    U, S, Vh = torch.linalg.svd(H)
    V = Vh.transpose(0, 1)
    # Rotation matrix
    R = torch.matmul(V, U.transpose(0, 1))
    if torch.det(R) < 0:
        # Reflection case: negate the column of the smallest singular value.
        V = torch.cat([V[:, :-1], -V[:, -1:]], dim=1)
        R = torch.matmul(V, U.transpose(0, 1))
    # Translation vector
    t = b_mean - torch.matmul(R, a_mean.transpose(0, 1)).transpose(0, 1)
    C_aligned = torch.matmul(R, C.transpose(0, 1)).transpose(0, 1) + t
    return C_aligned, R, t
105
+
106
+
107
def eig_coord_from_dist(D):
    """Recover 3D coordinates from a matrix of squared pairwise distances.

    Args:
        D: Square tensor whose entry ``(i, j)`` is the squared distance
            between points ``i`` and ``j``.

    Returns:
        Detached tensor with one row per point and three columns, holding
        coordinates relative to the first point, up to rotation/reflection.

    The eigenvectors are reordered together with the eigenvalues, so each
    column is scaled by its own eigenvalue. Previously only the eigenvalues
    were sorted.
    """
    # Gram matrix of the points relative to the first point.
    M = (D[:1, :] + D[:, :1] - D) / 2
    L, V = torch.linalg.eigh(M)
    # eigh returns ascending eigenvalues; keep the largest ones first.
    L, order = torch.sort(L, descending=True)
    V = V[:, order]
    # Scale column k by sqrt(L[k]); negative round-off eigenvalues are clamped.
    X = V * L.clamp(min=0).sqrt()
    return X[:, :3].detach()
113
+
114
+
115
def self_square_dist(X):
    """Return the matrix of squared distances between all rows of ``X``."""
    # Broadcasting [1, N, 3] against [N, 1, 3] yields every pairwise difference.
    diff = X[None] - X[:, None]
    return diff.pow(2).sum(dim=-1)
119
+
120
+
121
def set_atommap(mol, num=0):
    """Assign the atom-map number ``num`` to every atom of ``mol`` in place."""
    for each_atom in mol.GetAtoms():
        each_atom.SetAtomMapNum(num)
124
+
125
+
126
def get_mol(smiles):
    """Parse ``smiles`` into a kekulized RDKit molecule.

    Returns ``None`` when RDKit cannot parse the string.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        Chem.Kekulize(parsed)
    return parsed
132
+
133
+
134
def get_smiles(mol):
    """Return the SMILES string of ``mol`` without forcing a Kekule form."""
    smiles = Chem.MolToSmiles(mol, kekuleSmiles=False)
    return smiles
136
+
137
+
138
def decode_stereo(smiles2D):
    """List the isomeric SMILES of every stereoisomer of ``smiles2D``.

    When the first isomer has chiral nitrogen atoms, each isomer is emitted a
    second time with the chirality of those nitrogens cleared.
    """
    parent = Chem.MolFromSmiles(smiles2D)
    enumerated = list(EnumerateStereoisomers(parent))

    # Round-trip every isomer through isomeric SMILES.
    isomers = []
    for candidate in enumerated:
        isomers.append(Chem.MolFromSmiles(Chem.MolToSmiles(candidate, isomericSmiles=True)))
    smiles3D = [Chem.MolToSmiles(isomer, isomericSmiles=True) for isomer in isomers]

    chiral_nitrogens = []
    for atom in isomers[0].GetAtoms():
        if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N":
            chiral_nitrogens.append(atom.GetIdx())

    if chiral_nitrogens:
        for isomer in isomers:
            for idx in chiral_nitrogens:
                isomer.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
            smiles3D.append(Chem.MolToSmiles(isomer, isomericSmiles=True))

    return smiles3D
154
+
155
+
156
def sanitize(mol):
    """Round-trip ``mol`` through SMILES and return the re-parsed molecule.

    Returns ``None`` when the conversion or the re-parsing raises.
    """
    try:
        return get_mol(get_smiles(mol))
    except Exception:
        return None
163
+
164
+
165
def copy_atom(atom):
    """Create a new atom with the symbol, formal charge and map number of ``atom``."""
    duplicate = Chem.Atom(atom.GetSymbol())
    duplicate.SetAtomMapNum(atom.GetAtomMapNum())
    duplicate.SetFormalCharge(atom.GetFormalCharge())
    return duplicate
170
+
171
+
172
def copy_edit_mol(mol):
    """Return an editable ``RWMol`` holding copies of the atoms and bonds of ``mol``."""
    editable = Chem.RWMol(Chem.MolFromSmiles(''))
    for atom in mol.GetAtoms():
        editable.AddAtom(copy_atom(atom))
    for bond in mol.GetBonds():
        begin_idx = bond.GetBeginAtom().GetIdx()
        end_idx = bond.GetEndAtom().GetIdx()
        editable.AddBond(begin_idx, end_idx, bond.GetBondType())
    return editable
183
+
184
+
185
def get_submol(mol, idxs, mark=()):
    """Extract the sub-molecule of ``mol`` induced by the atom indices ``idxs``.

    Args:
        mol: Source RDKit molecule.
        idxs: Indices of the atoms to keep.
        mark: Indices of atoms that receive atom-map number 1 in the result;
            every other kept atom receives 0.

    Returns:
        A new molecule containing copies of the kept atoms and of every bond
        whose two end atoms are both kept.
    """
    # Sets give constant-time membership tests inside the loops below.
    kept = set(idxs)
    marked = set(mark)

    new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
    index_map = {}  # atom index in ``mol`` -> atom index in ``new_mol``
    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        if idx not in kept:
            continue
        new_atom = copy_atom(atom)
        new_atom.SetAtomMapNum(1 if idx in marked else 0)
        index_map[idx] = new_mol.AddAtom(new_atom)

    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if a1 in kept and a2 in kept:
            new_mol.AddBond(index_map[a1], index_map[a2], bond.GetBondType())
    return new_mol.GetMol()
203
+
204
+
205
def get_clique_mol(mol, atoms):
    """Build a sanitized molecule for the fragment of ``mol`` spanned by ``atoms``.

    The fragment is written as Kekule SMILES, re-read without sanitization,
    copied atom by atom and finally sanitized.
    """
    fragment_smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    fragment = Chem.MolFromSmiles(fragment_smiles, sanitize=False)
    fragment = copy_edit_mol(fragment).GetMol()
    # We assume the sanitized result is not None.
    return sanitize(fragment)
211
+
212
+
213
def get_clique_mol_simple(mol, cluster):
    """Return the unsanitized molecule for the fragment of ``mol`` given by ``cluster``."""
    fragment_smiles = Chem.MolFragmentToSmiles(mol, cluster, canonical=True, kekuleSmiles=True)
    return Chem.MolFromSmiles(fragment_smiles, sanitize=False)
217
+
218
+
219
def tree_decomp(mol, reference_vocab=None):
    """Decompose ``mol`` into bond/ring clusters connected as a tree.

    Args:
        mol: RDKit molecule to decompose.
        reference_vocab: Optional mapping from fragment SMILES to a count.
            When given, two rings sharing exactly two atoms are merged only
            if the merged fragment's count exceeds 100.

    Returns:
        ``(clusters, edges)``: ``clusters`` is a list of atom-index sets and
        ``edges`` a list of cluster-index pairs forming the spanning tree
        (an empty list when no two clusters share an atom).
    """
    n_atoms = mol.GetNumAtoms()
    overlap_by_pair = defaultdict(int)

    # Every bond outside a ring is its own two-atom cluster.
    clusters = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtom().GetIdx()
        end = bond.GetEndAtom().GetIdx()
        if bond.IsInRing():
            continue
        clusters.append({begin, end})

    rings = [set(ring) for ring in Chem.GetSymmSSSR(mol)]
    # remove too large circles
    rings = [ring for ring in rings if len(ring) <= 8]

    # Merge rings that overlap; consult reference_vocab when it is provided.
    for i in range(len(rings) - 1):
        if len(rings[i]) <= 2:
            continue
        for j in range(i + 1, len(rings)):
            if len(rings[j]) <= 2:
                continue
            shared = rings[i] & rings[j]
            if reference_vocab is None:
                if len(shared) <= 2:
                    continue
                merged = rings[i] | rings[j]
            else:
                if len(shared) < 2:
                    continue
                merged = rings[i] | rings[j]
                smile_merge = Chem.MolFragmentToSmiles(mol, merged, canonical=True, kekuleSmiles=True)
                if reference_vocab[smile_merge] <= 100 and len(shared) == 2:
                    continue
            rings[i] = merged
            rings[j] = set()

    rings = [ring for ring in rings if len(ring) > 0]
    clusters.extend(rings)

    # For each atom, the ids of the clusters that contain it.
    memberships = [[] for _ in range(n_atoms)]
    for cluster_id, cluster in enumerate(clusters):
        for atom in cluster:
            memberships[atom].append(cluster_id)

    # Build edges: clusters sharing an atom are linked, weighted by overlap size.
    for owners in memberships:
        if len(owners) <= 1:
            continue
        for position, c1 in enumerate(owners):
            for c2 in owners[position + 1:]:
                overlap = len(set(clusters[c1]) & set(clusters[c2]))
                if overlap_by_pair[(c1, c2)] < overlap:
                    overlap_by_pair[(c1, c2)] = overlap  # c1 < c2 by construction

    edges = [pair + (MST_MAX_WEIGHT - overlap,) for pair, overlap in overlap_by_pair.items()]
    if len(edges) == 0:
        return clusters, edges

    # A minimum spanning tree over (MST_MAX_WEIGHT - overlap) is a maximum
    # spanning tree over the overlaps.
    row, col, data = zip(*edges)
    n_clique = len(clusters)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    row, col = junc_tree.nonzero()
    return clusters, [(row[k], col[k]) for k in range(len(row))]
288
+
289
+
290
def atom_equal(a1, a2):
    """Return whether two atoms have the same symbol and formal charge."""
    same_element = a1.GetSymbol() == a2.GetSymbol()
    return same_element and a1.GetFormalCharge() == a2.GetFormalCharge()
292
+
293
+
294
# Two ring bonds match when their end atoms match pairwise and the bond types are equal.
def ring_bond_equal(bond1, bond2, reverse=False):
    """Compare two bonds by their end atoms and bond type.

    With ``reverse=True`` the begin atom of ``bond1`` is compared with the end
    atom of ``bond2`` and vice versa.
    """
    begin1, end1 = bond1.GetBeginAtom(), bond1.GetEndAtom()
    begin2, end2 = bond2.GetBeginAtom(), bond2.GetEndAtom()
    if reverse:
        begin2, end2 = end2, begin2
    atoms_match = atom_equal(begin1, begin2) and atom_equal(end1, end2)
    return atoms_match and bond1.GetBondType() == bond2.GetBondType()
302
+
303
+
304
def attach(ctr_mol, nei_mol, amap):
    """Graft ``nei_mol`` onto a copy of ``ctr_mol`` following the atom map.

    Args:
        ctr_mol: Centre molecule; it is copied, not modified.
        nei_mol: Neighbour molecule to graft on.
        amap: Mapping from atom indices of ``nei_mol`` to atom indices of the
            centre molecule. Entries for newly added atoms are inserted into
            this dictionary in place.

    Returns:
        ``(merged_mol, amap)``. Newly added atoms carry atom-map number 2.
    """
    merged = Chem.RWMol(ctr_mol)
    for atom in nei_mol.GetAtoms():
        nei_idx = atom.GetIdx()
        if nei_idx in amap:
            continue
        added = copy_atom(atom)
        added.SetAtomMapNum(2)
        amap[nei_idx] = merged.AddAtom(added)

    for bond in nei_mol.GetBonds():
        begin = amap[bond.GetBeginAtom().GetIdx()]
        end = amap[bond.GetEndAtom().GetIdx()]
        if merged.GetBondBetweenAtoms(begin, end) is None:
            merged.AddBond(begin, end, bond.GetBondType())

    return merged.GetMol(), amap
319
+
320
+
321
def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
    """Attach every node in ``prev_nodes + neighbors`` to ``ctr_mol`` in place.

    Args:
        ctr_mol: Editable centre molecule that receives the new atoms/bonds.
        neighbors: Nodes (with ``nid`` and ``mol``) to attach.
        prev_nodes: Previously placed nodes; their bonds override existing
            bonds between the same atoms.
        nei_amap: For each node id, a mapping from the node's atom indices to
            centre-molecule atom indices; it is extended in place.

    Returns:
        The same ``ctr_mol`` object after modification.
    """
    prev_nids = [node.nid for node in prev_nodes]
    for node in prev_nodes + neighbors:
        node_id = node.nid
        node_mol = node.mol
        amap = nei_amap[node_id]

        # Add the atoms that are not yet mapped onto the centre molecule.
        for atom in node_mol.GetAtoms():
            atom_idx = atom.GetIdx()
            if atom_idx not in amap:
                amap[atom_idx] = ctr_mol.AddAtom(copy_atom(atom))

        if node_mol.GetNumBonds() == 0:
            # Single-atom node: only carry its atom-map number over.
            source_atom = node_mol.GetAtomWithIdx(0)
            target_atom = ctr_mol.GetAtomWithIdx(amap[0])
            target_atom.SetAtomMapNum(source_atom.GetAtomMapNum())
            continue

        for bond in node_mol.GetBonds():
            begin = amap[bond.GetBeginAtom().GetIdx()]
            end = amap[bond.GetEndAtom().GetIdx()]
            if ctr_mol.GetBondBetweenAtoms(begin, end) is None:
                ctr_mol.AddBond(begin, end, bond.GetBondType())
            elif node_id in prev_nids:  # father node overrides
                ctr_mol.RemoveBond(begin, end)
                ctr_mol.AddBond(begin, end, bond.GetBondType())
    return ctr_mol
345
+
346
+
347
def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
    """Assemble a copy of ``ctr_mol`` with its neighbour nodes attached.

    ``amap_list`` holds ``(node_id, ctr_atom_idx, nei_atom_idx)`` triples that
    seed the per-node atom maps before the nodes are attached.
    """
    editable = copy_edit_mol(ctr_mol)

    nei_amap = {}
    for node in prev_nodes + neighbors:
        nei_amap[node.nid] = {}
    for node_id, ctr_atom, nei_atom in amap_list:
        nei_amap[node_id][nei_atom] = ctr_atom

    return attach_mols(editable, neighbors, prev_nodes, nei_amap).GetMol()
356
+
357
+
358
+ # This version records idx mapping between ctr_mol and nei_mol
359
def enum_attach(ctr_mol, nei_mol):
    """Enumerate the ways ``nei_mol`` can be attached to ``ctr_mol``.

    Only atoms of ``ctr_mol`` with atom-map number 1, and bonds whose two end
    atoms both have it, are considered as attachment sites.

    Args:
        ctr_mol: Centre molecule; it is kekulized in place.
        nei_mol: Neighbour molecule; it is kekulized in place.

    Returns:
        A list of dictionaries mapping atom indices of ``nei_mol`` to atom
        indices of ``ctr_mol``; one dictionary per candidate attachment.
        An empty list is returned when kekulization fails.
    """
    try:
        Chem.Kekulize(ctr_mol)
        Chem.Kekulize(nei_mol)
    except Exception:
        # Was a bare ``except``; narrowed so interrupts and exits propagate.
        return []

    att_confs = []
    ctr_bonds = [bond for bond in ctr_mol.GetBonds()
                 if bond.GetBeginAtom().GetAtomMapNum() == 1 and bond.GetEndAtom().GetAtomMapNum() == 1]
    ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetAtomMapNum() == 1]

    if nei_mol.GetNumBonds() == 1:  # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        bond_val = int(bond.GetBondTypeAsDouble())
        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()

        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence)
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            if atom_equal(atom, b1):
                att_confs.append({b1.GetIdx(): atom.GetIdx()})
            elif atom_equal(atom, b2):
                att_confs.append({b2.GetIdx(): atom.GetIdx()})
    else:
        # intersection is an atom
        for a1 in ctr_atoms:
            for a2 in nei_mol.GetAtoms():
                if atom_equal(a1, a2):
                    # Optimize if atom is carbon (other atoms may change valence)
                    if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
                        continue
                    att_confs.append({a2.GetIdx(): a1.GetIdx()})

        # intersection is a bond
        if ctr_mol.GetNumBonds() > 1:
            for b1 in ctr_bonds:
                for b2 in nei_mol.GetBonds():
                    if ring_bond_equal(b1, b2):
                        amap = {b2.GetBeginAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                                b2.GetEndAtom().GetIdx(): b1.GetEndAtom().GetIdx()}
                        att_confs.append(amap)

                    if ring_bond_equal(b1, b2, reverse=True):
                        amap = {b2.GetEndAtom().GetIdx(): b1.GetBeginAtom().GetIdx(),
                                b2.GetBeginAtom().GetIdx(): b1.GetEndAtom().GetIdx()}
                        att_confs.append(amap)
    return att_confs
411
+
412
+
413
def enumerate_assemble(mol, idxs, current, next):
    """Build (negative, positive) candidate molecules for attaching ``next``.

    The positive example is the sub-molecule of ``mol`` covering ``idxs`` plus
    ``next.clique``.  Negatives are alternative attachments of ``next.mol`` to
    the current sub-molecule whose SMILES differ from the positive and from
    each other; at most one negative (picked with ``sample``) is returned.

    :param mol: full molecule
    :param idxs: atom indices of the part assembled so far
    :param current: current motif node (its ``clique`` marks attachment atoms)
    :param next: motif node to attach (``clique`` and ``mol`` are read);
        the name shadows the builtin but is kept for caller compatibility
    :return: ``(labels, cand_mols)`` where ``labels`` is ``tensor([0, 1])``
        with ``cand_mols == [negative, ground_truth]`` when a negative exists,
        otherwise ``tensor([1])`` with ``cand_mols == [ground_truth]``
    """
    ctr_mol = get_submol(mol, idxs, mark=current.clique)
    ground_truth = get_submol(mol, list(set(idxs) | set(next.clique)))
    # submol can also be obtained with get_clique_mol, future exploration
    ground_truth_smiles = get_smiles(ground_truth)

    seen_smiles = {ground_truth_smiles}  # set: O(1) duplicate check
    cand_mols = []
    for amap in enum_attach(ctr_mol, next.mol):
        try:
            cand_mol, _ = attach(ctr_mol, next.mol, amap)
            cand_mol = sanitize(cand_mol)
        except Exception:
            # Best effort: an attachment that cannot be built or sanitized is
            # simply not a candidate.  (Narrowed from a bare ``except``.)
            continue
        if cand_mol is None:
            continue
        smiles = get_smiles(cand_mol)
        if smiles in seen_smiles:
            continue
        seen_smiles.add(smiles)
        cand_mols.append(cand_mol)

    if cand_mols:
        cand_mols = sample(cand_mols, 1)
        cand_mols.append(ground_truth)
        labels = torch.tensor([0, 1])
    else:
        cand_mols = [ground_truth]
        labels = torch.tensor([1])

    return labels, cand_mols
443
+
444
+
445
# Allowed values for node and edge features.  Features are encoded elsewhere
# as the position of a value inside the matching list, so list ORDER matters.
allowable_features = {
    'possible_atomic_num_list': list(range(1, 119)),
    'possible_formal_charge_list': list(range(-5, 6)),
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED,
    ],
    'possible_numH_list': list(range(9)),
    'possible_implicit_valence_list': list(range(7)),
    'possible_degree_list': list(range(11)),
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    # only for double bond stereo information
    'possible_bond_dirs': [
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT,
    ],
}
476
+
477
def mol_to_graph_data_obj_simple(mol):
    """
    Convert an rdkit mol into a torch_geometric ``Data`` graph using a reduced
    feature set in which every feature is stored as an index into the
    matching ``allowable_features`` list.

    Node features (2 columns): atomic-number index, chirality-tag index.
    Edge features (2 columns): bond-type index, bond-direction index.
    Each bond contributes two directed edges (i->j and j->i) that share the
    same feature row.

    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    :raises ValueError: from ``list.index`` when an atom or bond carries a
        value that is absent from ``allowable_features``
    """
    # atoms
    atomic_nums = allowable_features['possible_atomic_num_list']
    chiral_tags = allowable_features['possible_chirality_list']
    node_rows = [
        [atomic_nums.index(atom.GetAtomicNum()), chiral_tags.index(atom.GetChiralTag())]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(np.array(node_rows), dtype=torch.long)

    # bonds
    num_bond_features = 2  # bond type, bond direction
    if len(mol.GetBonds()) > 0:  # mol has bonds
        bond_types = allowable_features['possible_bonds']
        bond_dirs = allowable_features['possible_bond_dirs']
        directed_pairs, edge_rows = [], []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feature = [bond_types.index(bond.GetBondType()),
                       bond_dirs.index(bond.GetBondDir())]
            directed_pairs += [(begin, end), (end, begin)]
            edge_rows += [feature, feature]

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(directed_pairs).T, dtype=torch.long)
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_rows), dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
525
+
526
+
527
# For inference
def assemble(mol_list, next_motif_smiles):
    """Enumerate valid attachments of the next motif for a batch of molecules.

    For molecule ``i`` the motif ``next_motif_smiles[i]`` is parsed and every
    mapping proposed by ``enum_attach`` is tried on a deep copy of the
    molecule.  Candidates that ``sanitize`` rejects are dropped.  When a
    molecule ends up with no valid candidate (no mapping at all, or all
    rejected) the unmodified molecule is emitted once as a placeholder and
    ``attach_fail[i]`` is set.

    Note: ``enum_attach`` kekulizes ``mol_list[i]`` in place.

    :param mol_list: sequence of molecules to extend
    :param next_motif_smiles: motif SMILES, indexed like ``mol_list``
    :return: tuple ``(cand_mols, cand_batch, new_atoms, one_atom_attach,
        intersection, attach_fail)`` where the first five are aligned per
        candidate: ``cand_batch`` (tensor) holds the source molecule index,
        ``new_atoms`` the values of the map returned by ``attach``,
        ``one_atom_attach`` (bool tensor) is True when exactly one atom is
        shared, ``intersection`` lists the shared centre-atom indices, and
        ``attach_fail`` is a bool tensor with one entry per input molecule
    """
    attach_fail = torch.zeros(len(mol_list)).bool()
    cand_mols, cand_batch, new_atoms, one_atom_attach, intersection = [], [], [], [], []

    for i, base_mol in enumerate(mol_list):
        # ``motif`` instead of ``next`` to avoid shadowing the builtin.
        motif = Chem.MolFromSmiles(next_motif_smiles[i])

        valid_cand = 0
        for amap in enum_attach(base_mol, motif):
            # Read the mapping before ``attach`` gets a chance to touch it.
            amap_len = len(amap)
            shared_atoms = list(amap.values())

            cand_mol, added_amap = attach(deepcopy(base_mol), motif, amap)
            # Sanitize a throw-away copy: only validity matters here.
            if sanitize(deepcopy(cand_mol)) is None:
                continue

            cand_mols.append(cand_mol)
            cand_batch.append(i)
            new_atoms.append(list(added_amap.values()))
            one_atom_attach.append(amap_len)
            intersection.append(shared_atoms)
            valid_cand += 1

        if valid_cand == 0:
            # Single fallback for "no mapping" and "no sanitizable mapping".
            attach_fail[i] = True
            cand_mols.append(base_mol)
            cand_batch.append(i)
            one_atom_attach.append(-1)
            intersection.append([])
            new_atoms.append([])

    cand_batch = torch.tensor(cand_batch)
    one_atom_attach = torch.tensor(one_atom_attach) == 1
    return cand_mols, cand_batch, new_atoms, one_atom_attach, intersection, attach_fail
568
+
569
+
570
if __name__ == "__main__":
    import sys
    from mol_tree import MolTree

    # Silence everything below CRITICAL from RDKit.
    rd_logger = rdkit.RDLogger.logger()
    rd_logger.setLevel(rdkit.RDLogger.CRITICAL)

    # Sample molecules kept for manual experiments (not used below).
    smiles = [
        "O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1",
        "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2",
        "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3",
        "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1",
        'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br',
        'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1',
        "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34",
        "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1",
    ]

    # Smoke test: the simplest molecule must yield at least one tree node.
    mol_tree = MolTree("C")
    assert len(mol_tree.nodes) > 0

    def count():
        """Read SMILES from stdin (first token per line) and tally candidates.

        The totals are accumulated but not reported; the original print
        statement is disabled.
        """
        total_cands, total_nodes = 0, 0
        for line in sys.stdin:
            tree = MolTree(line.split()[0])
            tree.recover()
            tree.assemble()
            total_cands += sum(len(node.cands) for node in tree.nodes)
            total_nodes += len(tree.nodes)
        # print cnt * 1.0 / n

    count()
utils/data.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ import numpy as np
4
+ from torch_geometric.data import Data, Batch
5
+ # from torch_geometric.loader import DataLoader
6
+ from torch.utils.data import Dataset
7
+
8
# Attribute names for which per-element batch index vectors are tracked.
FOLLOW_BATCH = [
    'protein_element',
    'ligand_context_element',
    'pos_real',
    'pos_fake',
]
9
+
10
+
11
class ProteinLigandData(object):
    """Plain container for the fields of one protein-ligand pair.

    Fields are stored as instance attributes and can be read or written either
    as attributes (``data.protein_pos``) or with item syntax
    (``data['protein_pos']``).  Item access is required by
    :meth:`from_protein_ligand_dicts`; a bare ``object`` subclass does not
    support it, so it is implemented here explicitly.
    """

    def __init__(self, *args, **kwargs):
        # ``object.__init__`` accepts no arguments, so positional arguments
        # are still rejected, while keyword arguments become fields.
        if args:
            raise TypeError(
                'ProteinLigandData() takes no positional arguments (%d given)' % len(args))
        super().__init__()
        for key, value in kwargs.items():
            self[key] = value

    def __setitem__(self, key, value):
        """Store ``value`` under field name ``key``."""
        self.__dict__[key] = value

    def __getitem__(self, key):
        """Return the field named ``key``; raises ``KeyError`` if absent."""
        return self.__dict__[key]

    def __contains__(self, key):
        """Return True when a field named ``key`` has been set."""
        return key in self.__dict__

    @staticmethod
    def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs):
        """Build an instance from separate protein and ligand field dicts.

        Protein keys are stored with a ``protein_`` prefix and ligand keys
        with a ``ligand_`` prefix, except the ligand key ``'moltree'`` which
        is stored unprefixed.

        :param protein_dict: mapping of protein field name to value, or None
        :param ligand_dict: mapping of ligand field name to value, or None
        :param kwargs: extra fields set on the instance before the dicts
        :return: the populated ``ProteinLigandData``
        """
        instance = ProteinLigandData(**kwargs)

        if protein_dict is not None:
            for key, item in protein_dict.items():
                instance['protein_' + key] = item

        if ligand_dict is not None:
            for key, item in ligand_dict.items():
                if key == 'moltree':
                    instance['moltree'] = item
                else:
                    instance['ligand_' + key] = item

        # NOTE: a per-atom neighbour list (``ligand_nbh_list``) derived from
        # ``ligand_bond_index`` used to be attached here; it is disabled.
        return instance
33
+
34
+
35
def batch_from_data_list(data_list):
    """Collate ``data_list`` with ``Batch.from_data_list``.

    Batch index vectors are additionally requested for the
    ``ligand_element`` and ``protein_element`` attributes.

    :param data_list: list of data objects accepted by ``Batch.from_data_list``
    :return: the resulting ``Batch``
    """
    tracked_keys = ['ligand_element', 'protein_element']
    return Batch.from_data_list(data_list, follow_batch=tracked_keys)
37
+
38
+
39
def torchify_dict(data):
    """Return a shallow copy of ``data`` with numpy arrays turned into tensors.

    Values that are ``np.ndarray`` are converted with ``torch.from_numpy``
    (which shares memory with the array); every other value is passed through
    unchanged.  Key order is preserved.

    :param data: mapping of field name to value
    :return: new dict with the same keys
    """
    return {
        key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
        for key, value in data.items()
    }
47
+
48
+
49
def collate_mols(mol_dicts):
    """Collate a list of per-sample dicts into one batch dict.

    Index-valued fields are shifted by cumulative atom counts so that they
    address the concatenated protein / ligand-context tensors.

    :param mol_dicts: list of dicts; each must provide every key read below
    :return: dict with concatenated tensors, ``*_batch`` assignment vectors,
        offset-corrected index tensors and, when any sample has candidate
        labels, ``cand_mols`` as a torch_geometric ``Batch``.  The keys
        ``dm_*_idx`` / ``sr_*_idx*`` are only present when at least one
        sample has a non-empty ``true_dm``.
    """
    data_batch = {}
    batch_size = len(mol_dicts)

    # Fields that are simply concatenated along dim 0.
    for key in ['protein_pos', 'protein_atom_feature', 'ligand_context_pos', 'ligand_context_feature_full',
                'ligand_frontier', 'num_atoms', 'next_wid', 'current_wid', 'current_atoms', 'cand_labels',
                'ligand_pos_torsion', 'ligand_feature_torsion', 'true_sin', 'true_cos', 'true_three_hop',
                'dihedral_mask', 'protein_contact', 'true_dm', 'alpha_carbon_indicator']:
        data_batch[key] = torch.cat([mol_dict[key] for mol_dict in mol_dicts], dim=0)

    # Fields stacked along a new leading dim; empty samples are left out.
    for key in ['xn_pos', 'yn_pos', 'ligand_torsion_xy_index', 'y_pos']:
        rows = [mol_dict[key].unsqueeze(0) for mol_dict in mol_dicts if len(mol_dict[key]) > 0]
        data_batch[key] = torch.cat(rows, dim=0) if len(rows) > 0 else torch.tensor([])

    # Batch-assignment vectors: element j belongs to sample ``*_batch[j]``.
    for key in ['protein_element', 'ligand_context_element', 'current_atoms']:
        repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts])
        data_batch[key + '_batch'] = torch.repeat_interleave(torch.arange(batch_size), repeats)
    for key in ['ligand_element_torsion']:
        # Numbered over the non-empty samples only, matching the stacked
        # torsion fields above.
        repeats = torch.tensor([len(mol_dict[key]) for mol_dict in mol_dicts if len(mol_dict[key]) > 0])
        if len(repeats) > 0:
            data_batch[key + '_batch'] = torch.repeat_interleave(torch.arange(len(repeats)), repeats)
        else:
            data_batch[key + '_batch'] = torch.tensor([])

    # Start offset of every sample inside the concatenated atom tensors.
    zero = torch.tensor([0])
    protein_offsets = torch.cat(
        [zero, torch.cumsum(data_batch['protein_element_batch'].bincount(), dim=0)])
    ligand_offsets = torch.cat(
        [zero, torch.cumsum(data_batch['ligand_context_element_batch'].bincount(), dim=0)])

    # distance matrix prediction: all (4 protein) x (2 ligand) index pairs
    dm_p, dm_q = torch.cartesian_prod(torch.arange(4), torch.arange(2)).chunk(2, dim=-1)
    dm_p, dm_q = dm_p.squeeze(-1), dm_q.squeeze(-1)
    ligand_idx, protein_idx = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            protein_idx.append(mol_dict['dm_protein_idx'][dm_p] + protein_offsets[i])
            ligand_idx.append(mol_dict['dm_ligand_idx'][dm_q] + ligand_offsets[i])
    if len(ligand_idx) > 0:
        data_batch['dm_ligand_idx'], data_batch['dm_protein_idx'] = torch.cat(ligand_idx), torch.cat(protein_idx)

    # structure refinement (alpha carbon - ligand atom)
    sr_ligand_idx, sr_protein_idx = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            n_ligand = len(mol_dict['ligand_context_pos'])
            ligand_atom_index = torch.arange(n_ligand)
            lig_sel, ca_sel = torch.cartesian_prod(
                torch.arange(n_ligand),
                torch.arange(len(mol_dict['protein_alpha_carbon_index']))).chunk(2, dim=-1)
            lig_sel, ca_sel = lig_sel.squeeze(-1), ca_sel.squeeze(-1)
            sr_ligand_idx.append(ligand_atom_index[lig_sel] + ligand_offsets[i])
            sr_protein_idx.append(mol_dict['protein_alpha_carbon_index'][ca_sel] + protein_offsets[i])
    if len(sr_ligand_idx) > 0:
        data_batch['sr_ligand_idx'], data_batch['sr_protein_idx'] = torch.cat(sr_ligand_idx).long(), torch.cat(sr_protein_idx).long()

    # structure refinement (ligand atom - ligand atom)
    sr_ligand_idx0, sr_ligand_idx1 = [], []
    for i, mol_dict in enumerate(mol_dicts):
        if len(mol_dict['true_dm']) > 0:
            n_ligand = len(mol_dict['ligand_context_pos'])
            ligand_atom_index = torch.arange(n_ligand)
            first_sel, second_sel = torch.cartesian_prod(
                torch.arange(n_ligand), torch.arange(n_ligand)).chunk(2, dim=-1)
            first_sel, second_sel = first_sel.squeeze(-1), second_sel.squeeze(-1)
            sr_ligand_idx0.append(ligand_atom_index[first_sel] + ligand_offsets[i])
            sr_ligand_idx1.append(ligand_atom_index[second_sel] + ligand_offsets[i])
    # Guard on the list that is actually concatenated (the original tested
    # ``ligand_idx`` from the distance-matrix section).
    if len(sr_ligand_idx0) > 0:
        data_batch['sr_ligand_idx0'], data_batch['sr_ligand_idx1'] = torch.cat(sr_ligand_idx0).long(), torch.cat(sr_ligand_idx1).long()

    # index: shift torsion x/y indices by the number of torsion atoms in the
    # preceding non-empty samples
    if len(data_batch['y_pos']) > 0:
        repeats = torch.tensor([len(mol_dict['ligand_element_torsion']) for mol_dict in mol_dicts if len(mol_dict['ligand_element_torsion']) > 0])
        offsets = torch.cat([torch.tensor([0]), torch.cumsum(repeats, dim=0)])[:-1]
        data_batch['ligand_torsion_xy_index'] += offsets.unsqueeze(1)

    # Shift ``current_atoms`` by the cumulative ``num_atoms`` of earlier samples.
    offsets1 = torch.cat([torch.tensor([0]), torch.cumsum(data_batch['num_atoms'], dim=0)])[:-1]
    data_batch['current_atoms'] += torch.repeat_interleave(offsets1, data_batch['current_atoms_batch'].bincount())

    # cand mols: torch geometric Data
    cand_mol_list = []
    for data in mol_dicts:
        if len(data['cand_labels']) > 0:
            cand_mol_list.extend(data['cand_mols'])
    if len(cand_mol_list) > 0:
        data_batch['cand_mols'] = Batch.from_data_list(cand_mol_list)
    return data_batch
127
+
utils/datasets/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Subset
3
+ from .pl import PocketLigandPairDataset
4
+ import random
5
+
6
+
7
def get_dataset(config, *args, **kwargs):
    """Create the dataset described by ``config`` and optionally split it.

    :param config: object exposing ``name`` and ``path`` and supporting
        ``'split' in config``; ``config.split`` is a file read by ``torch.load``
        that maps split name -> list of sample names
    :param args: forwarded to the dataset constructor
    :param kwargs: forwarded to the dataset constructor
    :return: ``dataset`` when the config has no split, otherwise
        ``(dataset, subsets)`` with ``subsets`` mapping split name to a
        ``Subset``; names missing from ``dataset.name2id`` are dropped silently
    :raises NotImplementedError: for any dataset name other than ``'pl'``
    """
    name = config.name
    root = config.path
    if name != 'pl':
        raise NotImplementedError('Unknown dataset: %s' % name)
    dataset = PocketLigandPairDataset(root, *args, **kwargs)

    if 'split' not in config:
        return dataset

    name_to_index = dataset.name2id
    subsets = {}
    for split_name, sample_names in torch.load(config.split).items():
        indices = [name_to_index[n] for n in sample_names if n in name_to_index]
        subsets[split_name] = Subset(dataset, indices=indices)
    return dataset, subsets
utils/datasets/pl.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import lmdb
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from tqdm.auto import tqdm
7
+ import numpy as np
8
+
9
+ from ..protein_ligand import PDBProtein, parse_sdf_file
10
+ from ..data import ProteinLigandData, torchify_dict
11
+ from ..mol_tree import MolTree
12
+
13
+
14
def reset_moltree_root(moltree, ligand_pos, protein_pos):
    """Make the motif nearest to the protein the first node of ``moltree``.

    Squared ligand-protein distances are computed via the expansion
    ``|l|^2 - 2 l.p + |p|^2``.  The node whose clique contains the ligand atom
    with the smallest distance to any protein atom is swapped into
    ``moltree.nodes[0]`` (in place).

    :param moltree: object with a ``nodes`` list; each node has a ``clique``
        of ligand atom indices
    :param ligand_pos: ligand coordinates, one row per atom
    :param protein_pos: protein coordinates, one row per atom
    :return: ``(moltree, contact_protein, contact_idx)`` where
        ``contact_protein`` is a bool tensor marking protein atoms whose
        squared distance to some ligand atom is below ``4 ** 2`` and
        ``contact_idx`` is a one-element tensor holding the index of the
        protein atom closest to the (new) root clique
    """
    lig_sq = np.sum(np.square(ligand_pos), 1, keepdims=True)
    prot_sq = np.sum(np.square(protein_pos), 1, keepdims=True)
    cross_term = -2 * np.dot(ligand_pos, protein_pos.T)
    dist = np.add(np.add(cross_term, lig_sq), prot_sq.T)  # squared distances

    # Per ligand atom: squared distance to its nearest protein atom.
    nearest_per_atom = np.min(dist, 1)
    nearest_per_node = [np.min(nearest_per_atom[node.clique]) for node in moltree.nodes]

    root = np.argmin(nearest_per_node)
    if root > 0:
        nodes = moltree.nodes
        nodes[0], nodes[root] = nodes[root], nodes[0]

    root_clique = moltree.nodes[0].clique
    contact_idx = np.argmin(np.min(dist[root_clique], 0))
    contact_protein = torch.tensor(np.min(dist, 0) < 4 ** 2)

    return moltree, contact_protein, torch.tensor([contact_idx])
29
+
30
+
31
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None):
    """Merge protein and ligand field dicts into one flat dict.

    Protein keys get a ``protein_`` prefix and ligand keys a ``ligand_``
    prefix, except the ligand key ``'moltree'`` which is kept as is.

    :param protein_dict: mapping of protein field name to value, or None
    :param ligand_dict: mapping of ligand field name to value, or None
    :return: new dict; protein entries come first, then ligand entries
    """
    merged = {}

    if protein_dict is not None:
        merged.update(('protein_' + key, item) for key, item in protein_dict.items())

    if ligand_dict is not None:
        merged.update(
            ('moltree' if key == 'moltree' else 'ligand_' + key, item)
            for key, item in ligand_dict.items()
        )
    return merged
45
+
46
+
47
class PocketLigandPairDataset(Dataset):
    """Pocket/ligand pairs stored in an LMDB file next to the raw directory.

    For a raw directory ``<parent>/<name>`` the processed data lives in
    ``<parent>/<name>_processed.lmdb`` and the pdbid -> index table in
    ``<parent>/<name>_name2id.pt``.  Both are built on first use.
    """

    def __init__(self, raw_path, transform=None):
        """
        :param raw_path: directory holding ``index.pt`` and the per-complex
            sub-directories
        :param transform: optional callable applied to every sample dict
        """
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index.pt')
        parent_dir = os.path.dirname(self.raw_path)
        base_name = os.path.basename(self.raw_path)
        self.processed_path = os.path.join(parent_dir, base_name + '_processed.lmdb')
        self.name2id_path = os.path.join(parent_dir, base_name + '_name2id.pt')
        self.transform = transform
        self.db = None

        self.keys = None

        if not os.path.exists(self.processed_path):
            self._process()
            self._precompute_name2id()
        elif not os.path.exists(self.name2id_path):
            # The LMDB is present but its lookup table is missing: rebuild
            # the table instead of failing in ``torch.load`` below.
            self._precompute_name2id()

        self.name2id = torch.load(self.name2id_path)

    def _connect_db(self):
        """
        Establish read-only database connection and cache the list of keys
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        """Close the connection and forget the cached keys."""
        self.db.close()
        self.db = None
        self.keys = None

    def _process(self):
        """Parse every indexed complex and write it to the LMDB file.

        Complexes that fail anywhere in the pipeline are reported on stdout
        and skipped.
        """
        # Read the inputs BEFORE creating the LMDB: a missing index or
        # vocabulary must not leave an empty database behind, because its
        # mere existence would suppress processing on the next run.
        # (An earlier revision unpickled the index file directly.)
        index = torch.load(self.index_path)
        with open('./vocab.txt') as vocab_file:
            # Text before the first ':' of each line; set for O(1) lookups.
            vocab = {line.partition(':')[0] for line in vocab_file}

        db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10GB
            create=True,
            subdir=False,
            readonly=False,  # Writable
        )
        try:
            num_skipped = 0
            with db.begin(write=True, buffers=True) as txn:
                for i, pdbid in enumerate(tqdm(index)):
                    if pdbid is None:
                        continue
                    # Placeholder so the skip message below always has a
                    # name, even if building the file names fails.
                    ligand_fn = str(pdbid)
                    try:
                        ligand_fn = os.path.join(pdbid, pdbid + '_ligand.sdf')
                        pocket_fn = os.path.join(pdbid, pdbid + '_pocket.pdb')
                        pocket_dict = PDBProtein(os.path.join(self.raw_path, pocket_fn)).to_dict_atom()
                        ligand_dict = parse_sdf_file(os.path.join(self.raw_path, ligand_fn))
                        ligand_dict['moltree'], pocket_dict['contact'], pocket_dict['contact_idx'] = reset_moltree_root(
                            ligand_dict['moltree'],
                            ligand_dict['pos'],
                            pocket_dict['pos'])
                        data = from_protein_ligand_dicts(
                            protein_dict=torchify_dict(pocket_dict),
                            ligand_dict=torchify_dict(ligand_dict),
                        )
                        data['protein_filename'] = pocket_fn
                        data['ligand_filename'] = ligand_fn
                        data['pdbid'] = pdbid
                        txn.put(
                            key=str(i).encode(),
                            value=pickle.dumps(data)
                        )
                        # NOTE(review): this vocabulary check runs after the
                        # record has been written, so a complex reported as
                        # skipped here is still stored.  Order kept as in the
                        # original - confirm whether that is intended.
                        for c in ligand_dict['moltree'].nodes:
                            smile_cluster = c.smiles
                            assert smile_cluster in vocab
                    except Exception:
                        # Best effort: one bad complex must not abort the
                        # whole run.  (Narrowed from a bare ``except``.)
                        num_skipped += 1
                        print('Skipping (%d) %s' % (num_skipped, ligand_fn,))
                        continue
        finally:
            db.close()

    def _precompute_name2id(self):
        """Build and save the pdbid -> dataset index table.

        Samples whose ``__getitem__`` assertion fails are reported and left
        out of the table.
        """
        name2id = {}
        for i in tqdm(range(self.__len__()), 'Indexing'):
            try:
                data = self.__getitem__(i)
            except AssertionError as e:
                print(i, e)
                continue
            name = data['pdbid']
            name2id[name] = i
        torch.save(name2id, self.name2id_path)

    def __len__(self):
        """Number of records in the LMDB (connects lazily)."""
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        """Load, validate and optionally transform sample ``idx``."""
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        # Scope the read transaction so it is released right away.
        with self.db.begin() as txn:
            raw = txn.get(key)
        # SECURITY: pickle executes arbitrary code on load; only open
        # processed LMDB files that come from a trusted source.
        data = pickle.loads(raw)
        data['id'] = idx
        assert data['protein_pos'].size(0) > 0
        if self.transform is not None:
            data = self.transform(data)
        return data
167
+
168
+
169
if __name__ == '__main__':
    import argparse

    # Instantiating the dataset triggers preprocessing when the processed
    # LMDB file does not exist yet.
    cli = argparse.ArgumentParser()
    cli.add_argument('path', type=str)
    cli_args = cli.parse_args()

    PocketLigandPairDataset(cli_args.path)
utils/dihedral_utils.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch_geometric as tg
3
+ from torch_geometric.utils import degree
4
+ import networkx as nx
5
+
6
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Row k flags which of the six neighbour pairs in ``angle_combos`` are used
# when k neighbours are present: none for k < 2, 1 for k = 2, 3 for k = 3,
# all 6 for k = 4.
angle_mask_ref = torch.tensor(
    [
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1],
    ],
    dtype=torch.long,
).to(device)

# The six unordered pairs that can be formed from neighbour slots 0..3.
angle_combos = torch.tensor(
    [
        [0, 1],
        [0, 2],
        [1, 2],
        [0, 3],
        [1, 3],
        [2, 3],
    ],
    dtype=torch.long,
).to(device)
19
+
20
+
21
def get_neighbor_ids(data):
    """
    Merge the per-molecule neighbor dictionaries in ``data.neighbors`` into
    one dictionary keyed by batch-level atom index.

    The first dictionary is popped from ``data.neighbors`` and extended in
    place, so ``data.neighbors`` is mutated.  Keys and neighbor values of the
    k-th remaining dictionary are shifted by the number of atoms (taken from
    ``data.batch.bincount()``) in all molecules before it.

    :param data: object with ``neighbors`` (list of dicts: atom index ->
        neighbor indices) and ``batch`` (integer tensor assigning atoms to
        molecules)
    :return: the merged dictionary
    """
    merged = data.neighbors.pop(0)
    atoms_per_mol = data.batch.bincount()
    offset = 0

    for mol_idx, local_neighbors in enumerate(data.neighbors):
        # After the pop, entry ``mol_idx`` belongs to molecule ``mol_idx + 1``,
        # hence the size of molecule ``mol_idx`` is added first.
        offset += atoms_per_mol[mol_idx].item()
        merged.update({
            atom + offset: nbrs + offset
            for atom, nbrs in local_neighbors.items()
        })
    return merged
41
+
42
+
43
def get_neighbor_bonds(edge_index, bond_type):
    """
    Map each source atom to the bond types of its outgoing edges.

    ``bond_type`` is cut into consecutive chunks sized by how often each
    unique source index occurs, so edges are presumably expected to be
    grouped by ascending source index - TODO confirm with callers.
    Note: this only includes atoms with degree > 1

    :param edge_index: pair of index tensors ``(start, end)``
    :param bond_type: tensor with one entry per edge
    :return: dict ``{atom index: bond types of its edges}``
    """
    sources, _ = edge_index
    atom_ids, counts = torch.unique(sources, return_counts=True)
    chunks = torch.split_with_sizes(bond_type, tuple(counts))

    per_atom = {}
    for atom_id, chunk in zip(atom_ids, chunks):
        if len(chunk) > 1:
            per_atom[atom_id.item()] = chunk
    return per_atom
52
+
53
+
54
def get_leaf_hydrogens(neighbors, x):
    """
    For every atom in ``neighbors`` return a boolean mask over its neighbors
    that is True where the neighbor's first feature column equals 1.

    Note: column 0 of ``x`` is presumably the atomic number, making the mask
    a "neighbor is hydrogen" flag - confirm against the featurizer.

    :param neighbors: dict ``{atom index: tensor of neighbor indices}``
    :param x: 2-D atom feature tensor
    :return: dict ``{atom index: bool tensor aligned with its neighbors}``
    """
    is_one = x[:, 0] == 1
    return {atom: is_one[neighbors[atom]] for atom in neighbors}
71
+
72
+
73
def get_dihedral_pairs(edge_index, data):
    """
    Given edge indices, return pairs of indices that we must calculate dihedrals for

    Candidate pairs are edges whose two end atoms both have degree > 1.
    Pairs touching a cycle are expanded through ``get_current_cycle_indices``
    so that ring pairs are handled together.

    :param edge_index: tensor of edges, first row = start, second row = end
    :param data: graph object accepted by ``torch_geometric.utils.to_networkx``
    :return: tensor of selected pairs, one pair per column, on ``device``
    """
    start, end = edge_index
    node_degree = degree(end)
    both_internal = torch.logical_and(node_degree[start] > 1, node_degree[end] > 1)
    candidate_cols = torch.nonzero(both_internal)
    dihedral_pairs = edge_index[:, candidate_cols].squeeze(-1)

    # first method which removes one (pseudo) random edge from a cycle:
    # keep the columns whose first row sorts into position 0.
    row_order = dihedral_pairs.sort(dim=0).indices
    dihedral_idxs = torch.nonzero(row_order[0, :] == 0).squeeze().detach().cpu().numpy()

    # prioritize rings for assigning dihedrals
    dihedral_pairs = dihedral_pairs.t()[dihedral_idxs]
    graph = nx.to_undirected(tg.utils.to_networkx(data))
    cycles = nx.cycle_basis(graph)
    keep, sorted_keep = [], []

    # A single surviving pair comes back 1-D; restore the pair dimension.
    if len(dihedral_pairs.shape) == 1:
        dihedral_pairs = dihedral_pairs.unsqueeze(0)

    for pair in dihedral_pairs:
        x, y = pair

        if sorted(pair) in sorted_keep:
            continue

        y_cycle_check = [y in cycle for cycle in cycles]
        x_cycle_check = [x in cycle for cycle in cycles]
        x_in_cycle = any(x_cycle_check)
        y_in_cycle = any(y_cycle_check)

        if x_in_cycle and y_in_cycle:  # both in new cycle
            cycle_indices = get_current_cycle_indices(cycles, x_cycle_check, x)
            keep.extend(cycle_indices)
            sorted_keep.extend([sorted(c) for c in cycle_indices])
        elif y_in_cycle:
            cycle_indices = get_current_cycle_indices(cycles, y_cycle_check, y)
            keep.append(pair)
            keep.extend(cycle_indices)
            sorted_keep.append(sorted(pair))
            sorted_keep.extend([sorted(c) for c in cycle_indices])
        else:
            keep.append(pair)

    keep = [t.to(device) for t in keep]
    return torch.stack(keep).t()
123
+
124
+
125
def batch_distance_metrics_from_coords(coords, mask):
    """
    Given coordinates of neighboring atoms, compute bond
    distances and 2-hop distances in local neighborhood

    The central atom is taken to sit at the origin, so the one-hop distance
    of a neighbor is the norm of its coordinate vector.

    :param coords: 4-D or 5-D tensor whose last dim holds the coordinates;
        for 4-D input dim 1 indexes neighbors, for 5-D input dim 2 does
    :param mask: 2-D tensor flagging which neighbors are present
    :return: ``(one_hop_ds, two_hop_d_mat)`` - per-neighbor distances to the
        origin and masked pairwise neighbor-neighbor distances
    :raises ValueError: if ``coords`` is neither 4-D nor 5-D (previously this
        surfaced as an ``UnboundLocalError`` at the return statement)
    """
    n_dims = coords.dim()
    if n_dims not in (4, 5):
        raise ValueError('coords must be 4-D or 5-D, got %d-D' % n_dims)

    # Pair (i, j) is kept only when both neighbors are present.
    d_mat_mask = mask.unsqueeze(1) * mask.unsqueeze(2)

    if n_dims == 4:
        diff = coords.unsqueeze(1) - coords.unsqueeze(2)
        pair_mask = d_mat_mask.unsqueeze(-1)
    else:
        diff = coords.unsqueeze(2) - coords.unsqueeze(3)
        pair_mask = d_mat_mask.unsqueeze(-1).unsqueeze(1)

    # The 1e-10 keeps sqrt differentiable at zero distance.
    two_hop_d_mat = torch.square(diff + 1e-10).sum(dim=-1).sqrt() * pair_mask
    # Distance to the origin; same value as norm(0 - coords) in both branches
    # of the original.
    one_hop_ds = torch.linalg.norm(coords, dim=-1)

    return one_hop_ds, two_hop_d_mat
140
+
141
+
142
def batch_angle_between_vectors(a, b):
    """
    Compute the cosine of the angle between two batches of input vectors

    Despite the name, the returned value is the cosine, not the angle itself.

    :param a: tensor of vectors along the last dim
    :param b: tensor of vectors along the last dim, broadcastable with ``a``
    :return: ``<a, b> / (|a| * |b| + 1e-10)`` with the last dim reduced
    """
    dot = (a * b).sum(dim=-1)
    # protect denominator during division
    norm_product = torch.linalg.norm(a, dim=-1) * torch.linalg.norm(b, dim=-1) + 1e-10
    return dot / norm_product
157
+
158
+
159
def batch_angles_from_coords(coords, mask):
    """
    Given coordinates, compute all local neighborhood angles

    For each of the six neighbor pairs listed in ``angle_combos`` the cosine
    between the two neighbor vectors is computed; pairs that do not exist for
    the given neighbor count are zeroed through ``angle_mask_ref``.

    :param coords: 4-D tensor (neighbors on dim 1) or 5-D tensor (neighbors
        on dim 2); last dim holds coordinates
    :param mask: 2-D tensor whose row sum is the number of present neighbors
    :return: masked cosines, one per neighbor pair
    :raises ValueError: if ``coords`` is neither 4-D nor 5-D (previously this
        surfaced as an ``UnboundLocalError`` at the return statement)
    """
    n_dims = coords.dim()
    if n_dims not in (4, 5):
        raise ValueError('coords must be 4-D or 5-D, got %d-D' % n_dims)

    # Row of the mask table selected by the neighbor count.
    angle_mask = angle_mask_ref[mask.sum(dim=1).long()]

    if n_dims == 4:
        all_possible_combos = coords[:, angle_combos]
        v_a, v_b = all_possible_combos.split(1, dim=2)  # does one of these need to be negative?
        angles = batch_angle_between_vectors(v_a.squeeze(2), v_b.squeeze(2)) * angle_mask.unsqueeze(-1)
    else:
        all_possible_combos = coords[:, :, angle_combos]
        v_a, v_b = all_possible_combos.split(1, dim=3)  # does one of these need to be negative?
        angles = batch_angle_between_vectors(v_a.squeeze(3), v_b.squeeze(3)) * angle_mask.unsqueeze(-1).unsqueeze(-1)

    return angles
175
+
176
+
177
def batch_local_stats_from_coords(coords, mask):
    """
    Given neighborhood neighbor coordinates, compute bond distances,
    2-hop distances, and angles in local neighborhood (this assumes
    the central atom has coordinates at the origin)

    :param coords: neighbor coordinates, passed on unchanged
    :param mask: neighbor presence mask, passed on unchanged
    :return: ``(one_hop_ds, two_hop_d_mat, angles)``
    """
    distances = batch_distance_metrics_from_coords(coords, mask)
    one_hop_ds, two_hop_d_mat = distances
    return one_hop_ds, two_hop_d_mat, batch_angles_from_coords(coords, mask)
186
+
187
+
188
def batch_dihedrals(p0, p1, p2, p3, angle=False):
    """
    Compute the dihedral defined by the point chain p0 - p1 - p2 - p3.

    :param p0: tensor of points, coordinates on the last dim
    :param p1: tensor of points, same layout as ``p0``
    :param p2: tensor of points, same layout as ``p0``
    :param p3: tensor of points, same layout as ``p0``
    :param angle: when True return the angle itself (via ``atan2``)
    :return: the dihedral angle if ``angle`` is True, otherwise the tuple
        ``(sin, cos)`` of the dihedral
    """
    bond_a = p1 - p0
    bond_b = p2 - p1
    bond_c = p3 - p2

    # Normals of the two planes spanned by consecutive bonds.
    normal_ab = torch.cross(bond_a, bond_b, dim=-1)
    normal_bc = torch.cross(bond_b, bond_c, dim=-1)

    sin_raw = torch.linalg.norm(bond_b, dim=-1) * torch.sum(bond_a * normal_bc, dim=-1)
    cos_raw = torch.sum(normal_ab * normal_bc, dim=-1)

    if angle:
        return torch.atan2(sin_raw, cos_raw + 1e-10)

    # Dividing by the product of the normal lengths yields sin and cos.
    scale = torch.linalg.norm(normal_ab, dim=-1) * torch.linalg.norm(normal_bc, dim=-1) + 1e-10
    return sin_raw / scale, cos_raw / scale
203
+
204
+
205
def batch_vector_angles(xn, x, y, yn):
    """
    Cosine of the angle between the vectors ``xn - x`` and ``yn - y``.

    All inputs are flattened to rows of 3 coordinates and the result is
    regrouped into rows of 9 values, so the number of vectors must be a
    multiple of 9.

    :param xn: end points of the first vectors
    :param x: start points of the first vectors
    :param y: start points of the second vectors
    :param yn: end points of the second vectors
    :return: tensor of cosines with shape ``(-1, 9)``
    """
    first = xn.view(-1, 3) - x.view(-1, 3)
    second = yn.view(-1, 3) - y.view(-1, 3)

    # Row-wise dot product expressed as a batched (1x3) @ (3x1) product.
    dot = torch.bmm(first.view(-1, 1, 3), second.view(-1, 3, 1)).squeeze(-1).squeeze(-1)
    norm_product = torch.linalg.norm(first, dim=-1) * torch.linalg.norm(second, dim=-1) + 1e-10

    cosines = dot / norm_product
    return cosines.view(-1, 9)
218
+
219
+
220
def von_Mises_loss(a, b, a_sin=None, b_sin=None):
    """
    Cosine of the difference of two angles, ``cos(A) cos(B) + sin(A) sin(B)``.

    When ``a_sin`` is not a tensor, both sines are reconstructed as
    ``sqrt(1 - cos^2 + 1e-5)``, i.e. they are taken to be non-negative.

    :param a: cos of first angle
    :param b: cos of second angle
    :param a_sin: optional sin of first angle
    :param b_sin: optional sin of second angle (used only with ``a_sin``)
    :return: ``a * b + sin_a * sin_b``
    """
    if not torch.is_tensor(a_sin):
        # The small epsilon keeps the square root away from zero.
        a_sin = torch.sqrt(1 - a ** 2 + 1e-5)
        b_sin = torch.sqrt(1 - b ** 2 + 1e-5)
    return a * b + a_sin * b_sin
231
+
232
+
233
def rotation_matrix(neighbor_coords, neighbor_mask, neighbor_map, mu=None):
    """
    Given predicted neighbor coordinates from model, return rotation matrix

    The rows of the returned matrix are three unit vectors: ``h1`` along the
    neighbor selected by ``neighbor_map``, ``h3`` along the cross product of
    that neighbor with ``mu``, and ``h2 = -(h1 x h3)``.

    :param neighbor_coords: neighbor coordinates for each edge as defined by dihedral_pairs
        (n_dihedral_pairs, 4, n_generated_confs, 3)
    :param neighbor_mask: mask describing which atoms are present (n_dihedral_pairs, 4)
    :param neighbor_map: mask describing which neighbor corresponds to the other central dihedral atom
        (n_dihedral_pairs, 4) each entry in neighbor_map should have one TRUE entry with the rest as FALSE
    :param mu: optional reference vector; when it is not a tensor it is
        computed as the sum of the three non-selected neighbor slots divided
        by (number of present neighbors - 1)
    :return: rotation matrix (n_dihedral_pairs, n_model_confs, 3, 3)
    """
    selected = neighbor_map.bool()

    if not torch.is_tensor(mu):
        n_pairs, n_confs = neighbor_coords.size(0), neighbor_coords.size(2)
        others = neighbor_coords[~selected].view(n_pairs, 3, n_confs, -1)
        n_others = neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1) - 1 + 1e-10
        mu = others.sum(dim=1) / n_others
        # Drops the conformer dim when it has size 1.
        mu = mu.squeeze(1)

    p_Y = neighbor_coords[selected, :]
    h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10)

    h3_raw = torch.cross(p_Y, mu, dim=-1)
    h3 = h3_raw / (torch.linalg.norm(h3_raw, dim=-1, keepdim=True) + 1e-10)

    h2 = -torch.cross(h1, h3, dim=-1)

    # Stack the three axes as matrix rows.
    return torch.stack([h1, h2, h3], dim=-2)
265
+
266
+
267
def rotation_matrix_v2(neighbor_coords):
    """
    Given predicted neighbor coordinates from model, return rotation matrix

    The frame's first axis points along the input vector; the remaining two
    axes are derived from a random helper vector (``torch.rand_like``), so the
    result differs between calls unless the torch RNG is seeded.

    :param neighbor_coords: y or x coordinates for the x or y center node
        (n_dihedral_pairs, 3)
    :return: rotation matrix (n_dihedral_pairs, 3, 3)
    """
    p_Y = neighbor_coords

    # Random vector with its component along p_Y projected out, then normalized.
    noise = torch.rand_like(p_Y)
    sq_len = torch.linalg.norm(p_Y, dim=-1, keepdim=True) ** 2 + 1e-10
    orthogonal = noise - torch.sum(noise * p_Y, dim=-1, keepdim=True) / sq_len * p_Y
    eta = orthogonal / torch.linalg.norm(orthogonal, dim=-1, keepdim=True)

    h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10)

    normal = torch.cross(p_Y, eta, dim=-1)
    h3 = normal / (torch.linalg.norm(normal, dim=-1, keepdim=True) + 1e-10)

    h2 = -torch.cross(h1, h3, dim=-1)

    # Rows of the returned matrices are h1, h2, h3.
    return torch.stack([h1, h2, h3], dim=-2)
293
+
294
+
295
def signed_volume(local_coords):
    """
    Compute signed volume given ordered neighbor local coordinates

    Only the sign (-1, 0 or +1) of the triple product is returned.

    :param local_coords: (n_tetrahedral_chiral_centers, 4, n_generated_confs, 3)
    :return: signed volume of each tetrahedral center (n_tetrahedral_chiral_centers, n_generated_confs)
    """
    # Edge vectors from the fourth neighbor to the first three.
    origin = local_coords[:, 3]
    e1 = local_coords[:, 0] - origin
    e2 = local_coords[:, 1] - origin
    e3 = local_coords[:, 2] - origin

    # Scalar triple product e1 . (e2 x e3).
    triple = (e1 * torch.cross(e2, e3, dim=-1)).sum(dim=-1)
    return torch.sign(triple)
308
+
309
+
310
def rotation_matrix_inf(neighbor_coords, neighbor_mask, neighbor_map):
    """
    Given predicted neighbor coordinates from model, return rotation matrix

    :param neighbor_coords: neighbor coordinates for each edge as defined by dihedral_pairs (4, n_model_confs, 3)
    :param neighbor_mask: mask describing which atoms are present (4)
    :param neighbor_map: mask describing which neighbor corresponds to the other central dihedral atom (4)
        each entry in neighbor_map should have one TRUE entry with the rest as FALSE
    :return: rotation matrix, one 3x3 frame per conformer
    """
    # Mean neighbor position: sum over all four slots divided by the mask count.
    n_present = neighbor_mask.sum(dim=-1, keepdim=True).unsqueeze(-1) + 1e-10
    mu = (neighbor_coords.sum(dim=0, keepdim=True) / n_present).squeeze(0)

    p_Y = neighbor_coords[neighbor_map.bool(), :].squeeze(0)

    h1 = p_Y / (torch.linalg.norm(p_Y, dim=-1, keepdim=True) + 1e-10)

    normal = torch.cross(p_Y, mu, dim=-1)
    h3 = normal / (torch.linalg.norm(normal, dim=-1, keepdim=True) + 1e-10)

    h2 = -torch.cross(h1, h3, dim=-1)

    # Rows of the returned matrices are h1, h2, h3.
    return torch.stack([h1, h2, h3], dim=-2)
337
+
338
+
339
def build_alpha_rotation_inf(alpha, n_model_confs):
    """
    Build rotation matrices about the first axis by angle ``alpha``.

    :param alpha: rotation angle(s) in radians; must broadcast to n_model_confs
    :param n_model_confs: number of matrices to build
    :return: float32 CPU tensor (n_model_confs, 3, 3) of the form
        [[1, 0, 0], [0, cos, -sin], [0, sin, cos]]
    """
    cos_a = torch.cos(alpha)
    sin_a = torch.sin(alpha)

    H_alpha = torch.zeros(n_model_confs, 3, 3, dtype=torch.float32, device='cpu')
    H_alpha[:, 0, 0] = 1
    H_alpha[:, 1, 1] = cos_a
    H_alpha[:, 1, 2] = -sin_a
    H_alpha[:, 2, 1] = sin_a
    H_alpha[:, 2, 2] = cos_a

    return H_alpha
348
+
349
+
350
def random_rotation_matrix(dim):
    """
    Sample rotation matrices from random yaw / pitch / roll angles.

    Each angle is drawn with ``torch.rand``, i.e. uniformly from [0, 1) radians.

    :param dim: size passed to ``torch.rand`` (batch shape of the result)
    :return: tensor of shape (*dim, 3, 3)
    """
    # Draw order (yaw, pitch, roll) matters for reproducibility under a seed.
    yaw = torch.rand(dim)
    pitch = torch.rand(dim)
    roll = torch.rand(dim)

    cy, sy = torch.cos(yaw), torch.sin(yaw)
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cr, sr = torch.cos(roll), torch.sin(roll)

    row0 = torch.stack([cy * cp,
                        cy * sp * sr - sy * cr,
                        cy * sp * cr + sy * sr], dim=-1)
    row1 = torch.stack([sy * cp,
                        sy * sp * sr + cy * cr,
                        sy * sp * cr - cy * sr], dim=-1)
    row2 = torch.stack([-sp,
                        cp * sr,
                        cp * cr], dim=-1)

    return torch.stack([row0, row1, row2], dim=-2)
370
+
371
+
372
def length_to_mask(length, max_len=None, dtype=None):
    """Turn a 1-D tensor of lengths into a padding mask.

    length: B.
    return B x max_len; entry [i, j] is true when j < length[i].
    If max_len is None (or 0), then max of length will be used.
    If dtype is given, the boolean mask is converted to that dtype.
    """
    assert len(length.shape) == 1, 'Length shape should be 1 dimensional.'
    if not max_len:
        max_len = length.max().item()

    positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
    mask = positions.expand(len(length), max_len) < length.unsqueeze(1)

    if dtype is None:
        return mask
    return torch.as_tensor(mask, dtype=dtype, device=length.device)
utils/docking.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import random
4
+ import string
5
+ from easydict import EasyDict
6
+ from rdkit import Chem
7
+ from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule
8
+
9
+ from .reconstruct import reconstruct_from_generated
10
+
11
+
12
def get_random_id(length=30):
    """Return a string of ``length`` random lowercase ASCII letters."""
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(length)]
    return ''.join(picked)
15
+
16
+
17
def load_pdb(path):
    """Return the full text content of the file at ``path``."""
    with open(path, 'r') as handle:
        contents = handle.read()
    return contents
20
+
21
+
22
def parse_qvina_outputs(docked_sdf_path):
    """
    Read docked poses from an SDF file and extract their scores.

    For every readable molecule the first line of its ``REMARK`` property is
    split on whitespace; tokens 2, 3 and 4 are read as affinity, RMSD lower
    bound and RMSD upper bound.

    :param docked_sdf_path: path of the SDF file holding the docked poses
    :return: list of EasyDict with keys rdmol, mode_id, affinity, rmsd_lb, rmsd_ub;
        mode_id is the position in the file (unreadable entries keep their slot)
    """
    results = []
    for mode_id, mol in enumerate(Chem.SDMolSupplier(docked_sdf_path)):
        if mol is None:
            # The supplier yields None for records it cannot parse.
            continue
        first_remark = mol.GetProp('REMARK').splitlines()[0]
        scores = first_remark.split()[2:]
        results.append(EasyDict({
            'rdmol': mol,
            'mode_id': mode_id,
            'affinity': float(scores[0]),
            'rmsd_lb': float(scores[1]),
            'rmsd_ub': float(scores[2]),
        }))

    return results
39
+
40
class BaseDockingTask(object):
    """Interface of a docking job: holds the inputs, subclasses do the work."""

    def __init__(self, pdb_block, ligand_rdmol):
        """
        :param pdb_block: receptor as PDB text
        :param ligand_rdmol: ligand molecule object
        """
        super().__init__()
        self.ligand_rdmol = ligand_rdmol
        self.pdb_block = pdb_block

    def run(self):
        """Start the docking job; must be overridden."""
        raise NotImplementedError

    def get_results(self):
        """Return the docking results; must be overridden."""
        raise NotImplementedError
52
+
53
+
54
class QVinaDockingTask(BaseDockingTask):
    """
    Docking job that shells out to QuickVina (``qvina2.1``) inside a conda env.

    The receptor and ligand are written to ``tmp_dir`` under a random task id;
    ``run`` starts a bash child process, ``get_results`` polls it, and
    ``run_sync`` blocks until it has finished.
    """

    @staticmethod
    def _read_protein_block(data, protein_root):
        """Read the receptor PDB located next to ``data.ligand_filename``.

        The file name is the first 10 characters of the ligand file name plus
        ``.pdb``, inside the ligand's directory under ``protein_root``.
        """
        protein_fn = os.path.join(
            os.path.dirname(data.ligand_filename),
            os.path.basename(data.ligand_filename)[:10] + '.pdb'
        )
        protein_path = os.path.join(protein_root, protein_fn)
        with open(protein_path, 'r') as f:
            return f.read()

    @classmethod
    def from_generated_data(cls, data, protein_root='./data/crossdocked', **kwargs):
        """Build a task whose ligand is reconstructed from generated ``data``."""
        pdb_block = cls._read_protein_block(data, protein_root)
        ligand_rdmol = reconstruct_from_generated(data)
        return cls(pdb_block, ligand_rdmol, **kwargs)

    @classmethod
    def from_original_data(cls, data, ligand_root='./data/crossdocked_pocket10', protein_root='./data/crossdocked', **kwargs):
        """Build a task whose ligand is the first molecule of the reference SDF."""
        pdb_block = cls._read_protein_block(data, protein_root)
        ligand_path = os.path.join(ligand_root, data.ligand_filename)
        ligand_rdmol = next(iter(Chem.SDMolSupplier(ligand_path)))
        return cls(pdb_block, ligand_rdmol, **kwargs)

    def __init__(self, pdb_block, ligand_rdmol, conda_env='adt', tmp_dir='./tmp', use_uff=True, center=None):
        """
        :param pdb_block: receptor as PDB text
        :param ligand_rdmol: ligand molecule; hydrogens are added before docking
        :param conda_env: conda environment activated before the tools are called
        :param tmp_dir: working directory for all intermediate files (created if missing)
        :param use_uff: run ``UFFOptimizeMolecule`` on the hydrogenated ligand
        :param center: docking box center; defaults to the midpoint of the
            ligand's coordinate bounding box
        """
        super().__init__(pdb_block, ligand_rdmol)
        self.conda_env = conda_env
        self.tmp_dir = os.path.realpath(tmp_dir)
        os.makedirs(self.tmp_dir, exist_ok=True)

        self.task_id = get_random_id()
        self.receptor_id = self.task_id + '_receptor'
        self.ligand_id = self.task_id + '_ligand'

        self.receptor_path = os.path.join(self.tmp_dir, self.receptor_id + '.pdb')
        self.ligand_path = os.path.join(self.tmp_dir, self.ligand_id + '.sdf')

        with open(self.receptor_path, 'w') as f:
            f.write(pdb_block)

        ligand_rdmol = Chem.AddHs(ligand_rdmol, addCoords=True)
        if use_uff:
            UFFOptimizeMolecule(ligand_rdmol)
        sdf_writer = Chem.SDWriter(self.ligand_path)
        sdf_writer.write(ligand_rdmol)
        sdf_writer.close()
        self.ligand_rdmol = ligand_rdmol

        pos = ligand_rdmol.GetConformer(0).GetPositions()
        if center is None:
            self.center = (pos.max(0) + pos.min(0)) / 2
        else:
            self.center = center

        self.proc = None
        self.results = None
        self.output = None
        self.docked_sdf_path = None

    def run(self, exhaustiveness=16):
        """Start the docking pipeline in a bash child process (non-blocking).

        The script is fed through the child's stdin; progress is observed with
        ``get_results``.
        """
        commands = """
eval "$(conda shell.bash hook)"
conda activate {env}
cd {tmp}
# Prepare receptor (PDB->PDBQT)
prepare_receptor4.py -r {receptor_id}.pdb
# Prepare ligand
obabel {ligand_id}.sdf -O{ligand_id}.pdbqt
qvina2.1 \
    --receptor {receptor_id}.pdbqt \
    --ligand {ligand_id}.pdbqt \
    --center_x {center_x:.4f} \
    --center_y {center_y:.4f} \
    --center_z {center_z:.4f} \
    --size_x 20 --size_y 20 --size_z 20 \
    --exhaustiveness {exhaust}
obabel {ligand_id}_out.pdbqt -O{ligand_id}_out.sdf -h
        """.format(
            receptor_id=self.receptor_id,
            ligand_id=self.ligand_id,
            env=self.conda_env,
            tmp=self.tmp_dir,
            exhaust=exhaustiveness,
            center_x=self.center[0],
            center_y=self.center[1],
            center_z=self.center[2],
        )

        self.docked_sdf_path = os.path.join(self.tmp_dir, '%s_out.sdf' % self.ligand_id)

        self.proc = subprocess.Popen(
            '/bin/bash',
            shell=False,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )

        self.proc.stdin.write(commands.encode('utf-8'))
        self.proc.stdin.close()

    def run_sync(self):
        """Run the docking job and block until it has finished.

        :return: list of parsed poses, or ``[]`` when the output could not be parsed
        """
        self.run()
        # Drain stdout until EOF, then reap the child: this blocks without
        # spinning on poll() and keeps a chatty child from stalling on a full
        # stdout pipe.
        # NOTE(review): stderr is still an unread pipe; a very large stderr
        # volume could stall the child.
        self.output = self.proc.stdout.readlines()
        self.proc.wait()

        results = self.get_results()
        if results:
            print('Best affinity:', results[0]['affinity'])
        return results

    def get_results(self):
        """Return the parsed poses if the job has finished.

        :return: ``None`` while the job is not started or still running,
            ``[]`` when the docked SDF cannot be parsed, otherwise the list
            produced by ``parse_qvina_outputs`` (cached after the first success)
        """
        if self.proc is None:  # Not started
            return None
        if self.proc.poll() is None:  # In progress
            return None

        if self.output is None:
            self.output = self.proc.stdout.readlines()
        if self.results is None:
            try:
                self.results = parse_qvina_outputs(self.docked_sdf_path)
            except Exception:
                # Best effort: a failed docking run is reported, not raised.
                print('[Error] Vina output error: %s' % self.docked_sdf_path)
                return []
        return self.results
183
+
utils/fpscores.pkl.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73
3
+ size 3848394
utils/misc.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import random
4
+ import logging
5
+ import torch
6
+ import numpy as np
7
+ import yaml
8
+ from easydict import EasyDict
9
+ from logging import Logger
10
+ from tqdm.auto import tqdm
11
+
12
+
13
class BlackHole(object):
    """Null object: swallows attribute writes and answers everything with itself."""

    def __setattr__(self, name, value):
        """Discard the assignment."""
        return None

    def __call__(self, *args, **kwargs):
        """Accept any arguments and return this same object."""
        return self

    def __getattr__(self, name):
        """Any attribute that is not found resolves to this same object."""
        return self
20
+
21
+
22
def load_config(path):
    """Parse the YAML file at ``path`` (safe loader) and wrap it in an EasyDict."""
    with open(path, 'r') as stream:
        raw = yaml.safe_load(stream)
    return EasyDict(raw)
25
+
26
+
27
def get_logger(name, log_dir=None):
    """
    Return the DEBUG-level logger ``name`` with console (and optional file) output.

    ``logging.getLogger`` returns the same object for the same name, so
    handlers are only attached when an equivalent one is not already present;
    calling this function repeatedly no longer duplicates every log line.

    :param name: logger name
    :param log_dir: if given, messages are also appended to ``<log_dir>/log.txt``
    :return: the configured ``logging.Logger``
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # FileHandler subclasses StreamHandler, hence the exact type check.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        # FileHandler stores the absolute path in ``baseFilename``.
        log_path = os.path.abspath(os.path.join(log_dir, 'log.txt'))
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
44
+
45
+
46
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """
    Create and return a fresh, timestamp-named log directory under ``root``.

    The directory name is ``[prefix_]YYYY_MM_DD__HH_MM_SS[_tag]`` (local time).
    ``os.makedirs`` is called without ``exist_ok``, so an already existing
    directory raises.

    :return: path of the created directory
    """
    pieces = [prefix] if prefix != '' else []
    pieces.append(time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()))
    if tag != '':
        pieces.append(tag)

    log_dir = os.path.join(root, '_'.join(pieces))
    os.makedirs(log_dir)
    return log_dir
55
+
56
+
57
def seed_all(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
61
+
62
+
63
def log_hyperparams(writer, args):
    """
    Write the attributes of ``args`` as hyper-parameters through ``writer``.

    Non-string values are converted with ``repr`` before being handed to
    ``torch.utils.tensorboard.summary.hparams``; the three summaries it returns
    are added to ``writer.file_writer`` in order.

    :param writer: object exposing ``file_writer.add_summary``
    :param args: object whose ``vars()`` are the hyper-parameters
    """
    from torch.utils.tensorboard.summary import hparams

    printable = {}
    for key, value in vars(args).items():
        printable[key] = value if isinstance(value, str) else repr(value)

    exp, ssi, sei = hparams(printable, {})
    for summary in (exp, ssi, sei):
        writer.file_writer.add_summary(summary)
70
+
71
+
72
def int_tuple(argstr):
    """Parse a comma-separated string such as ``"1,2,3"`` into a tuple of ints."""
    return tuple(int(token) for token in argstr.split(','))
74
+
75
+
76
def str_tuple(argstr):
    """Split a comma-separated string into a tuple of its (unstripped) parts."""
    parts = argstr.split(',')
    return tuple(parts)
78
+
utils/mol_tree.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("..")
3
+ import rdkit
4
+ import rdkit.Chem as Chem
5
+ import copy
6
+ import pickle
7
+ from tqdm.auto import tqdm
8
+ import numpy as np
9
+ import torch
10
+ import random
11
+ from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, get_clique_mol_simple
12
+ from collections import defaultdict
13
+
14
+
15
def get_slots(smiles):
    """
    Describe every atom of an (unsanitized) SMILES.

    :param smiles: SMILES string, parsed with ``sanitize=False``
    :return: list of (symbol, formal charge, total number of Hs), one per atom
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    slots = []
    for atom in mol.GetAtoms():
        slots.append((atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()))
    return slots
18
+
19
+
20
class Vocab(object):
    """Bidirectional mapping between motif SMILES strings and integer indices."""

    def __init__(self, smiles_list):
        """
        :param smiles_list: sequence of SMILES; position in the list is the index
        """
        self.vocab = smiles_list
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        # Atom slots are only needed by get_slots(); they are computed on first
        # use so that building a vocabulary does not parse every SMILES.
        self.slots = None

    def get_index(self, smiles):
        """Return the index of ``smiles``; unknown SMILES map to index 0."""
        return self.vmap.get(smiles, 0)

    def get_smiles(self, idx):
        """Return the SMILES stored at ``idx``."""
        return self.vocab[idx]

    def get_slots(self, idx):
        """Return a private copy of the atom slots of the motif at ``idx``.

        Uses the module-level ``get_slots`` helper to build the slot table on
        the first call.
        """
        if self.slots is None:
            self.slots = [get_slots(smiles) for smiles in self.vocab]
        return copy.deepcopy(self.slots[idx])

    def size(self):
        """Return the number of vocabulary entries."""
        return len(self.vocab)
41
+
42
+
43
class MolTreeNode(object):
    """One clique (motif) of a molecule's tree decomposition."""

    def __init__(self, mol, cmol, clique):
        """
        :param mol: the full molecule the clique belongs to
        :param cmol: the molecule fragment of this clique
        :param clique: atom indices (in ``mol``) that form this clique
        """
        self.smiles = Chem.MolToSmiles(cmol, canonical=True)
        self.mol = cmol
        self.clique = list(clique)  # copy

        self.neighbors = []
        # recover() reads is_leaf on this node and on its neighbors; it is
        # kept up to date by add_neighbor() (leaf == exactly one neighbor).
        self.is_leaf = False
        self.rotatable = False
        if len(self.clique) == 2:
            first, second = self.clique
            if mol.GetAtomWithIdx(first).GetDegree() >= 2 and mol.GetAtomWithIdx(second).GetDegree() >= 2:
                self.rotatable = True
                # should restrict to single bond, but double bond is ok

    def add_neighbor(self, nei_node):
        """Register ``nei_node`` as adjacent and refresh the leaf flag."""
        self.neighbors.append(nei_node)
        self.is_leaf = (len(self.neighbors) == 1)

    def recover(self, original_mol):
        """
        Build the label SMILES of this node together with its neighbors.

        Atoms are temporarily tagged with the ``nid`` of the node they belong
        to (``nid`` is assigned by the owning tree); all tags on the touched
        atoms are reset to 0 before returning.

        :param original_mol: the full molecule; its atom map numbers are modified
            during the call
        :return: the label SMILES (also stored in ``self.label``)
        """
        clique = []
        clique.extend(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf:  # Leaf node, no need to mark
                continue
            for cidx in nei_node.clique:
                # allow singleton node override the atom mapping
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol_simple(original_mol, clique)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
        self.label_mol = get_mol(self.label)

        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)

        return self.label

    def assemble(self):
        """
        Enumerate attachment candidates of this node with its neighbors.

        Neighbors are tried largest fragment first. Results are stored in
        ``self.cands`` and ``self.cand_mols`` (both empty lists when nothing
        was found).
        """
        neighbors = sorted(self.neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)

        # NOTE(review): enum_assemble is not among the names imported from
        # .chemutils in this file's import line - confirm it is in scope.
        cands = enum_assemble(self, neighbors)
        if len(cands) > 0:
            self.cands, self.cand_mols, _ = zip(*cands)
            self.cands = list(self.cands)
            self.cand_mols = list(self.cand_mols)
        else:
            self.cands = []
            self.cand_mols = []
101
+
102
+
103
class MolTree(object):
    """Tree decomposition of a molecule into ``MolTreeNode`` cliques."""

    def __init__(self, mol):
        """
        Decompose ``mol`` and wire up the resulting nodes.

        The clique containing atom 0 is moved to position 0 of ``self.nodes``;
        node ids (``nid``) are the 1-based positions in that list.

        :param mol: molecule to decompose
        """
        self.smiles = Chem.MolToSmiles(mol)
        self.mol = mol

        # use vanilla tree decomposition for simplicity
        cliques, edges = tree_decomp(self.mol, reference_vocab=None)
        self.nodes = []
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol_simple(self.mol, c)
            self.nodes.append(MolTreeNode(self.mol, cmol, c))
            if min(c) == 0:
                root = i

        self.num_rotatable_bond = sum(1 for node in self.nodes if node.rotatable)

        for x, y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])

        if root > 0:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for i, node in enumerate(self.nodes):
            node.nid = i + 1
            # MolTreeNode.recover() reads is_leaf; a leaf has exactly one neighbor.
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        """Return the number of nodes in the tree."""
        return len(self.nodes)

    def recover(self):
        """Compute the label of every node against the full molecule."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Enumerate the attachment candidates of every node."""
        for node in self.nodes:
            node.assemble()
154
+
155
+
156
if __name__ == "__main__":
    # Build the motif vocabulary of the PDBbind ligands and write it to
    # ./vocab.txt as "<smiles>:<count>" lines, most frequent first.
    # (An earlier variant iterated ./data/crossdocked_pocket10/index.pkl instead.)
    import os  # needed for the ligand paths; not imported at file level

    seed = 2023
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    vocab = {}
    cnt = 0
    rot = 0

    path = '/n/holyscratch01/mzitnik_lab/zaixizhang/pdbbind_pocket10/'
    index = torch.load('/n/holyscratch01/mzitnik_lab/zaixizhang/pdbbind_pocket10/index.pt')
    for pdbid in tqdm(index):
        if pdbid is None:
            continue
        try:
            ligand_path = os.path.join(path, os.path.join(pdbid, pdbid + '_ligand.sdf'))
            mol = Chem.MolFromMolFile(ligand_path, sanitize=False)
            moltree = MolTree(mol)
        except Exception:
            # Best effort: ligands that cannot be read or decomposed are skipped.
            continue

        cnt += 1
        if moltree.num_rotatable_bond > 0:
            rot += 1

        for c in moltree.nodes:
            vocab[c.smiles] = vocab.get(c.smiles, 0) + 1

    vocab = dict(sorted(vocab.items(), key=lambda kv: (kv[1], kv[0]), reverse=True))
    with open('./vocab.txt', 'w') as vocab_file:
        for k, v in vocab.items():
            vocab_file.write(k + ':' + str(v))
            vocab_file.write('\n')

    # number of molecules and vocab
    print('Size of the motif vocab:', len(vocab))
    print('Total number of molecules', cnt)
    # Guard against an empty or fully unreadable index.
    print('Percent of molecules with rotatable bonds:', rot / cnt if cnt else 0.0)
utils/protein_ligand.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("..")
3
+ import os
4
+ import numpy as np
5
+ from rdkit import Chem
6
+ from rdkit.Chem.rdchem import BondType
7
+ from rdkit.Chem import ChemicalFeatures
8
+ from rdkit import RDConfig
9
+ from .mol_tree import *
10
+
11
# Chemical feature families; a family's position is its column index.
ATOM_FAMILIES = [
    'Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe',
    'NegIonizable', 'PosIonizable', 'ZnBinder',
]
ATOM_FAMILIES_ID = {family: idx for idx, family in enumerate(ATOM_FAMILIES)}
# RDKit bond type -> integer code, and integer code -> bond type name.
BOND_TYPES = {bond_type: idx for idx, bond_type in enumerate(BondType.names.values())}
BOND_NAMES = dict(enumerate(BondType.names.keys()))
15
+
16
+
17
class PDBProtein(object):
    """
    Minimal PDB reader: parses ATOM records into per-atom and per-residue arrays.

    Only the first model is read (parsing stops at ``ENDMDL``) and only the 20
    residue names listed in ``AA_NAME_SYM`` are supported; any other residue
    name raises ``KeyError`` during parsing.
    """

    AA_NAME_SYM = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H',
        'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q',
        'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }

    # Residue name -> integer id, in the declaration order of AA_NAME_SYM.
    AA_NAME_NUMBER = {
        k: i for i, (k, _) in enumerate(AA_NAME_SYM.items())
    }

    BACKBONE_NAMES = ["CA", "C", "N", "O"]

    def __init__(self, data, mode='auto'):
        """
        :param data: path of a PDB file, or the PDB text itself
        :param mode: ``'path'`` forces ``data`` to be read as a file; with
            ``'auto'`` it is read as a file only when it ends in ``.pdb``;
            otherwise ``data`` is used as the PDB text
        """
        super().__init__()
        if (data[-4:].lower() == '.pdb' and mode == 'auto') or mode == 'path':
            with open(data, 'r') as f:
                self.block = f.read()
        else:
            self.block = data

        self.ptable = Chem.GetPeriodicTable()

        # Molecule properties
        self.title = None
        # Atom properties
        self.atoms = []
        self.element = []
        self.atomic_weight = []
        self.pos = []
        self.atom_name = []
        self.is_backbone = []
        self.atom_to_aa_type = []
        # Residue properties
        self.residues = []
        self.amino_acid = []
        self.center_of_mass = []
        self.pos_CA = []
        self.pos_C = []
        self.pos_N = []
        self.pos_O = []

        self._parse()

    def _enum_formatted_atom_lines(self):
        """Yield one dict per ATOM / HEADER record, using fixed PDB columns."""
        for line in self.block.splitlines():
            record = line[0:6].strip()
            if record == 'ATOM':
                element_symb = line[76:78].strip().capitalize()
                if len(element_symb) == 0:
                    # Element columns are empty: fall back to the atom name column.
                    element_symb = line[13:14]
                yield {
                    'line': line,
                    'type': 'ATOM',
                    'atom_id': int(line[6:11]),
                    'atom_name': line[12:16].strip(),
                    'res_name': line[17:20].strip(),
                    'chain': line[21:22].strip(),
                    'res_id': int(line[22:26]),
                    'res_insert_id': line[26:27].strip(),
                    'x': float(line[30:38]),
                    'y': float(line[38:46]),
                    'z': float(line[46:54]),
                    'occupancy': float(line[54:60]),
                    'segment': line[72:76].strip(),
                    'element_symb': element_symb,
                    'charge': line[78:80].strip(),
                }
            elif record == 'HEADER':
                yield {
                    'type': 'HEADER',
                    'value': line[10:].strip()
                }
            elif record == 'ENDMDL':
                break  # Some PDBs have more than 1 model.

    def _parse(self):
        """Fill the atom- and residue-level lists from ``self.block``."""
        # Process atoms
        residues_tmp = {}
        for atom in self._enum_formatted_atom_lines():
            if atom['type'] == 'HEADER':
                self.title = atom['value'].lower()
                continue
            self.atoms.append(atom)
            atomic_number = self.ptable.GetAtomicNumber(atom['element_symb'])
            next_ptr = len(self.element)
            self.element.append(atomic_number)
            self.atomic_weight.append(self.ptable.GetAtomicWeight(atomic_number))
            self.pos.append(np.array([atom['x'], atom['y'], atom['z']], dtype=np.float32))
            self.atom_name.append(atom['atom_name'])
            self.is_backbone.append(atom['atom_name'] in self.BACKBONE_NAMES)
            self.atom_to_aa_type.append(self.AA_NAME_NUMBER[atom['res_name']])

            # Residues are keyed by chain, segment, residue number and insertion code.
            chain_res_id = '%s_%s_%d_%s' % (atom['chain'], atom['segment'], atom['res_id'], atom['res_insert_id'])
            if chain_res_id not in residues_tmp:
                residues_tmp[chain_res_id] = {
                    'name': atom['res_name'],
                    'atoms': [next_ptr],
                    'chain': atom['chain'],
                    'segment': atom['segment'],
                }
            else:
                assert residues_tmp[chain_res_id]['name'] == atom['res_name']
                assert residues_tmp[chain_res_id]['chain'] == atom['chain']
                residues_tmp[chain_res_id]['atoms'].append(next_ptr)

        # Process residues
        self.residues = list(residues_tmp.values())
        for residue in self.residues:
            sum_pos = np.zeros([3], dtype=np.float32)
            sum_mass = 0.0
            for atom_idx in residue['atoms']:
                sum_pos += self.pos[atom_idx] * self.atomic_weight[atom_idx]
                sum_mass += self.atomic_weight[atom_idx]
                if self.atom_name[atom_idx] in self.BACKBONE_NAMES:
                    residue['pos_%s' % self.atom_name[atom_idx]] = self.pos[atom_idx]
            residue['center_of_mass'] = sum_pos / sum_mass

        # Process backbone atoms of residues
        for residue in self.residues:
            self.amino_acid.append(self.AA_NAME_NUMBER[residue['name']])
            self.center_of_mass.append(residue['center_of_mass'])
            for name in self.BACKBONE_NAMES:
                pos_key = 'pos_%s' % name  # pos_CA, pos_C, pos_N, pos_O
                # A missing backbone atom is replaced by the residue's center of mass.
                backbone_pos = residue[pos_key] if pos_key in residue else residue['center_of_mass']
                getattr(self, pos_key).append(backbone_pos)

    def to_dict_atom(self):
        """Return the atom-level data as a dict of arrays / lists."""
        return {
            'element': np.array(self.element, dtype=np.int_),
            'molecule_name': self.title,
            'pos': np.array(self.pos, dtype=np.float32),
            'is_backbone': np.array(self.is_backbone, dtype=bool),
            'atom_name': self.atom_name,
            'atom_to_aa_type': np.array(self.atom_to_aa_type, dtype=np.int_)
        }

    def to_dict_residue(self):
        """Return the residue-level data as a dict of arrays."""
        return {
            'amino_acid': np.array(self.amino_acid, dtype=np.int_),
            'center_of_mass': np.array(self.center_of_mass, dtype=np.float32),
            'pos_CA': np.array(self.pos_CA, dtype=np.float32),
            'pos_C': np.array(self.pos_C, dtype=np.float32),
            'pos_N': np.array(self.pos_N, dtype=np.float32),
            'pos_O': np.array(self.pos_O, dtype=np.float32),
        }

    def query_residues_radius(self, center, radius, criterion='center_of_mass'):
        """
        Select the residues whose ``criterion`` point lies within ``radius`` of ``center``.

        :param center: 3 coordinates
        :param radius: strict upper bound on the Euclidean distance
        :param criterion: residue key holding the point to measure from
        :return: list of residue dicts, in residue order
        """
        center = np.array(center).reshape(3)
        selected = []
        for residue in self.residues:
            distance = np.linalg.norm(residue[criterion] - center, ord=2)
            if distance < radius:
                selected.append(residue)
        return selected

    def query_residues_ligand(self, ligand, radius, criterion='center_of_mass'):
        """
        Select the residues within ``radius`` of any position in ``ligand['pos']``.

        :return: list of residue dicts, each at most once, in the order they
            are first hit while iterating the ligand positions
        """
        selected = []
        sel_idx = set()
        # The time-complexity is O(mn).
        for center in ligand['pos']:
            for i, residue in enumerate(self.residues):
                if i in sel_idx:
                    continue
                distance = np.linalg.norm(residue[criterion] - center, ord=2)
                if distance < radius:
                    selected.append(residue)
                    sel_idx.add(i)
        return selected

    def residues_to_pdb_block(self, residues, name='POCKET'):
        """
        Serialize ``residues`` back to PDB text using the original ATOM lines.

        The name starts at column 11 of the HEADER / COMPND records, which is
        where ``_enum_formatted_atom_lines`` reads the title from.
        """
        lines = ["HEADER    %s" % name, "COMPND    %s" % name]
        for residue in residues:
            lines.extend(self.atoms[atom_idx]['line'] for atom_idx in residue['atoms'])
        lines.append("END")
        return "\n".join(lines) + "\n"
196
+
197
+
198
def parse_pdbbind_index_file(path):
    """
    Read the PDB ids from an index file.

    Lines starting with ``#`` and blank lines are skipped; the id is the first
    whitespace-separated token of every remaining line.

    :param path: path of the index file
    :return: list of ids in file order
    """
    pdb_id = []
    with open(path, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            fields = line.split()
            if not fields:
                # Blank (or whitespace-only) line: nothing to record.
                continue
            pdb_id.append(fields[0])
    return pdb_id
206
+
207
+
208
def parse_sdf_file(path):
    """
    Load a ligand from an SDF/MOL file into a dict of numpy arrays.

    The file is read three times: with RDKit for the motif tree and neighbor
    lists, with an ``SDMolSupplier`` for the chemical feature matrix, and as
    plain text for elements, coordinates and bonds (fixed-width counts line,
    whitespace-split atom lines).

    :param path: path of the SDF/MOL file
    :return: dict with keys element, pos, bond_index, bond_type,
        center_of_mass, atom_feature, moltree, neighbors
    """
    mol = Chem.MolFromMolFile(path, sanitize=True)
    moltree = MolTree(mol)

    # One-hot chemical feature families per atom.
    fdef_path = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
    factory = ChemicalFeatures.BuildFeatureFactory(fdef_path)
    rdmol = next(iter(Chem.SDMolSupplier(path, removeHs=True)))
    rd_num_atoms = rdmol.GetNumAtoms()
    feat_mat = np.zeros([rd_num_atoms, len(ATOM_FAMILIES)], dtype=np.int_)
    for feat in factory.GetFeaturesForMol(rdmol):
        feat_mat[feat.GetAtomIds(), ATOM_FAMILIES_ID[feat.GetFamily()]] = 1

    with open(path, 'r') as f:
        sdf = f.read().splitlines()

    # Counts line: atom and bond counts in the first two 3-character columns.
    counts_line = sdf[3]
    num_atoms, num_bonds = int(counts_line[0:3]), int(counts_line[3:6])
    assert num_atoms == rd_num_atoms

    ptable = Chem.GetPeriodicTable()
    element, pos = [], []
    accum_pos = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    accum_mass = 0.0
    for raw_atom in sdf[4:4 + num_atoms]:
        fields = raw_atom.split()
        x, y, z = map(float, fields[:3])
        atomic_number = ptable.GetAtomicNumber(fields[3].capitalize())
        element.append(atomic_number)
        pos.append([x, y, z])

        atomic_weight = ptable.GetAtomicWeight(atomic_number)
        accum_pos += np.array([x, y, z]) * atomic_weight
        accum_mass += atomic_weight

    center_of_mass = np.array(accum_pos / accum_mass, dtype=np.float32)

    element = np.array(element, dtype=np.int_)
    pos = np.array(pos, dtype=np.float32)

    # Bond order code in the file -> integer code of the RDKit bond type.
    type_index = {t: i for i, t in enumerate(BondType.names.values())}
    bond_type_map = {
        1: type_index[BondType.SINGLE],
        2: type_index[BondType.DOUBLE],
        3: type_index[BondType.TRIPLE],
        4: type_index[BondType.AROMATIC],
    }
    row, col, edge_type = [], [], []
    for bond_line in sdf[4 + num_atoms:4 + num_atoms + num_bonds]:
        start = int(bond_line[0:3]) - 1
        end = int(bond_line[3:6]) - 1
        code = bond_type_map[int(bond_line[6:9])]
        # Every bond is stored in both directions.
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [code]

    edge_index = np.array([row, col], dtype=np.int_)
    edge_type = np.array(edge_type, dtype=np.int_)

    # Sort the directed edges by (source, target).
    perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    # used in rotation angle prediction
    neighbor_dict = {
        i: [n.GetIdx() for n in atom.GetNeighbors()]
        for i, atom in enumerate(mol.GetAtoms())
    }

    return {
        'element': element,
        'pos': pos,
        'bond_index': edge_index,
        'bond_type': edge_type,
        'center_of_mass': center_of_mass,
        'atom_feature': feat_mat,
        'moltree': moltree,
        'neighbors': neighbor_dict
    }
utils/reconstruct.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from rdkit.Chem import AllChem as Chem
3
+ from rdkit import Geometry
4
+ from openbabel import openbabel as ob
5
+ from openbabel import pybel
6
+ from scipy.spatial.distance import pdist
7
+ from scipy.spatial.distance import squareform
8
+
9
+ from .protein_ligand import ATOM_FAMILIES_ID
10
+
11
+
12
class MolReconsError(Exception):
    '''Signals that a molecule could not be reconstructed (raised in this
    module when RDKit sanitization of the rebuilt molecule fails).'''
14
+
15
+
16
def reachable_r(a, b, seenbonds):
    '''Depth-first helper for reachable().

    Walks outward from atom ``a`` over bonds whose index is not yet in
    ``seenbonds`` (the set is updated in place) and reports whether atom
    ``b`` is encountered.
    '''
    for neighbor in ob.OBAtomAtomIter(a):
        bond_idx = a.GetBond(neighbor).GetIdx()
        if bond_idx in seenbonds:
            continue
        seenbonds.add(bond_idx)
        if neighbor == b or reachable_r(neighbor, b, seenbonds):
            return True
    return False
28
+
29
+
30
def reachable(a, b):
    '''Tell whether atom b can still be reached from atom a when the bond
    directly joining them is not used.'''
    # If either end has a single explicit connection, the a-b bond is
    # the only way in or out of that atom.
    if a.GetExplicitDegree() == 1:
        return False
    if b.GetExplicitDegree() == 1:
        return False
    # Pre-mark the direct bond as visited so the traversal cannot use it.
    visited = {a.GetBond(b).GetIdx()}
    return reachable_r(a, b, visited)
37
+
38
+
39
def forms_small_angle(a, b, cutoff=45):
    '''Report whether the a-b bond makes an angle below ``cutoff`` with a
    bond from ``a`` to any of its other neighbors.

    Only neighbors of ``a`` are inspected; call again with the arguments
    swapped to check the other end of the bond.
    '''
    return any(
        b.GetAngle(a, other) < cutoff
        for other in ob.OBAtomAtomIter(a)
        if other != b
    )
49
+
50
+
51
def make_obmol(xyz, atomic_numbers):
    '''Create an OBMol holding one atom per (coordinate, atomic number) pair.

    Returns ``(mol, atoms)`` where ``atoms`` lists the created OBAtoms in
    input order.  BeginModify() is called on the molecule and no matching
    EndModify() is issued here; that is left to the caller.
    '''
    mol = ob.OBMol()
    mol.BeginModify()
    atoms = []
    for (x, y, z), atomic_num in zip(xyz, atomic_numbers):
        new_atom = mol.NewAtom()
        new_atom.SetAtomicNum(atomic_num)
        new_atom.SetVector(x, y, z)
        atoms.append(new_atom)
    return mol, atoms
63
+
64
+
65
def connect_the_dots(mol, atoms, indicators, maxbond=4):
    '''Custom implementation of ConnectTheDots.

    This is similar to OpenBabel's version, but is more willing to make
    long bonds (up to ``maxbond`` long) to keep the molecule connected.
    It also marks a bond aromatic when both of its atoms carry the
    'Aromatic' indicator.

    mol        -- OBMol that owns ``atoms``; modified in place.
    atoms      -- list of OBAtoms; ``indicators[i]`` must describe ``atoms[i]``.
    indicators -- per-atom sequences indexed by ATOM_FAMILIES_ID.
    maxbond    -- upper bound on the length of a candidate bond.

    Assumes no hydrogens or existing bonds.  Returns None; does nothing
    when ``atoms`` is empty.
    '''
    pt = Chem.GetPeriodicTable()

    if len(atoms) == 0:
        return

    mol.BeginModify()

    # just going to do n^2 comparisons, can worry about efficiency later
    coords = np.array([(a.GetX(), a.GetY(), a.GetZ()) for a in atoms])
    dists = squareform(pdist(coords))

    aromatic_col = ATOM_FAMILIES_ID['Aromatic']
    for i, a in enumerate(atoms):
        for j, b in enumerate(atoms):
            if a == b:
                break  # each unordered pair is visited once (j < i)
            if dists[i, j] < 0.01:  # reduce from 0.4
                continue  # don't bond too close atoms
            if dists[i, j] < maxbond:
                flag = 0
                if indicators[i][aromatic_col] and indicators[j][aromatic_col]:
                    flag = ob.OB_AROMATIC_BOND
                mol.AddBond(a.GetIdx(), b.GetIdx(), 1, flag)

    atom_maxb = {}
    for a in atoms:
        # set max valence to the smallest max allowed by openbabel or rdkit
        # since we want the molecule to be valid for both (rdkit is usually lower)
        maxb = ob.GetMaxBonds(a.GetAtomicNum())
        maxb = min(maxb, pt.GetDefaultValence(a.GetAtomicNum()))

        if a.GetAtomicNum() == 16:  # sulfone check
            if count_nbrs_of_elem(a, 8) >= 2:
                maxb = 6

        atom_maxb[a.GetIdx()] = maxb

    # Remove any impossible bonds between halogens (both ends allow a
    # single bond only).  Snapshot the bonds first: deleting from the
    # molecule while walking its live bond iterator can skip entries.
    for bond in list(ob.OBMolBondIter(mol)):
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if atom_maxb[a1.GetIdx()] == 1 and atom_maxb[a2.GetIdx()] == 1:
            mol.DeleteBond(bond)

    def get_bond_info(biter):
        '''Return (stretch, length, bond) tuples, most stretched first.'''
        bonds = [b for b in biter]
        binfo = []
        for bond in bonds:
            bdist = bond.GetLength()
            # compute how far away from optimal we are
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            ideal = ob.GetCovalentRad(a1.GetAtomicNum()) + ob.GetCovalentRad(a2.GetAtomicNum())
            stretch = bdist - ideal
            binfo.append((stretch, bdist, bond))
        binfo.sort(reverse=True, key=lambda t: t[:2])  # most stretched bonds first
        return binfo

    # NOTE: a pass that pruned bonds on hypervalent atoms (using atom_maxb)
    # used to sit here and was disabled; over-valent atoms are therefore
    # not trimmed by this function.

    binfo = get_bond_info(ob.OBMolBondIter(mol))
    # now eliminate geometrically poor bonds
    for stretch, bdist, bond in binfo:
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()

        # as long as we aren't disconnecting, remove things that are
        # excessively far away (0.45 from ConnectTheDots); also remove
        # tight angles, because that is what ConnectTheDots does
        if stretch > 0.45 or forms_small_angle(a1, a2) or forms_small_angle(a2, a1):
            # don't fragment the molecule
            if not reachable(a1, a2):
                continue
            mol.DeleteBond(bond)

    mol.EndModify()
181
+
182
+
183
def convert_ob_mol_to_rd_mol(ob_mol, struct=None):
    '''Convert OBMol to RDKit mol, fixing up issues.

    ob_mol -- OpenBabel molecule; its hydrogens are deleted in place.
    struct -- unused; kept so existing call sites keep working.

    Returns an RDKit molecule with hydrogens added and a single conformer
    copied from ``ob_mol``.

    Raises MolReconsError when RDKit sanitization (everything except
    kekulization) fails, and Exception for a bond order other than 1-3.
    '''
    ob_mol.DeleteHydrogens()
    n_atoms = ob_mol.NumAtoms()
    rd_mol = Chem.RWMol()
    rd_conf = Chem.Conformer(n_atoms)

    for ob_atom in ob.OBMolAtomIter(ob_mol):
        rd_atom = Chem.Atom(ob_atom.GetAtomicNum())
        # TODO copy formal charge
        if ob_atom.IsAromatic() and ob_atom.IsInRing() and ob_atom.MemberOfRingSize() <= 6:
            # don't commit to being aromatic unless rdkit will be okay with the ring status
            # (this can happen if the atoms aren't fit well enough)
            rd_atom.SetIsAromatic(True)
        i = rd_mol.AddAtom(rd_atom)
        ob_coords = ob_atom.GetVector()
        rd_coords = Geometry.Point3D(ob_coords.GetX(), ob_coords.GetY(), ob_coords.GetZ())
        rd_conf.SetAtomPosition(i, rd_coords)

    rd_mol.AddConformer(rd_conf)

    order_to_type = {
        1: Chem.BondType.SINGLE,
        2: Chem.BondType.DOUBLE,
        3: Chem.BondType.TRIPLE,
    }
    for ob_bond in ob.OBMolBondIter(ob_mol):
        # OpenBabel atom indices are 1-based, RDKit's are 0-based.
        i = ob_bond.GetBeginAtomIdx() - 1
        j = ob_bond.GetEndAtomIdx() - 1
        bond_order = ob_bond.GetBondOrder()
        if bond_order not in order_to_type:
            raise Exception('unknown bond order {}'.format(bond_order))
        rd_mol.AddBond(i, j, order_to_type[bond_order])

        if ob_bond.IsAromatic():
            rd_mol.GetBondBetweenAtoms(i, j).SetIsAromatic(True)

    rd_mol = Chem.RemoveHs(rd_mol, sanitize=False)

    pt = Chem.GetPeriodicTable()
    # if double/triple bonds are connected to hypervalent atoms, decrement
    # the order, starting with the longest such bond
    positions = rd_mol.GetConformer().GetPositions()
    nonsingles = []
    for bond in rd_mol.GetBonds():
        if bond.GetBondType() in (Chem.BondType.DOUBLE, Chem.BondType.TRIPLE):
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            dist = np.linalg.norm(positions[i] - positions[j])
            nonsingles.append((dist, bond))
    nonsingles.sort(reverse=True, key=lambda t: t[0])

    for _, bond in nonsingles:
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()

        if calc_valence(a1) > pt.GetDefaultValence(a1.GetAtomicNum()) or \
                calc_valence(a2) > pt.GetDefaultValence(a2.GetAtomicNum()):
            btype = Chem.BondType.SINGLE
            if bond.GetBondType() == Chem.BondType.TRIPLE:
                btype = Chem.BondType.DOUBLE
            bond.SetBondType(btype)

    for atom in rd_mol.GetAtoms():
        # set nitrogens with 4 neighbors to have a charge
        if atom.GetAtomicNum() == 7 and atom.GetDegree() == 4:
            atom.SetFormalCharge(1)

    rd_mol = Chem.AddHs(rd_mol, addCoords=True)

    positions = rd_mol.GetConformer().GetPositions()
    center = np.mean(positions[np.all(np.isfinite(positions), axis=1)], axis=0)
    for atom in rd_mol.GetAtoms():
        i = atom.GetIdx()
        pos = positions[i]
        if not np.all(np.isfinite(pos)):
            # hydrogens on C fragment get set to nan (shouldn't, but they do)
            rd_mol.GetConformer().SetAtomPosition(i, center)

    try:
        Chem.SanitizeMol(rd_mol, Chem.SANITIZE_ALL ^ Chem.SANITIZE_KEKULIZE)
    except Exception as err:
        # Narrower than a bare except (lets KeyboardInterrupt/SystemExit
        # through) and keeps the RDKit error as the cause.
        raise MolReconsError() from err

    # but at some point stop trying to enforce our aromaticity -
    # openbabel and rdkit have different aromaticity models so they
    # won't always agree. Remove any aromatic bonds to non-aromatic atoms
    for bond in rd_mol.GetBonds():
        a1 = bond.GetBeginAtom()
        a2 = bond.GetEndAtom()
        if bond.GetIsAromatic():
            if not a1.GetIsAromatic() or not a2.GetIsAromatic():
                bond.SetIsAromatic(False)
        elif a1.GetIsAromatic() and a2.GetIsAromatic():
            bond.SetIsAromatic(True)

    return rd_mol
295
+
296
+
297
def calc_valence(rdatom):
    '''Sum the bond orders (as doubles) of every bond on ``rdatom``.

    Used to estimate an atom's valence from its bonds alone while fixing
    up a molecule ahead of sanitization.  Returns a float, 0.0 for an
    atom without bonds.
    '''
    return sum((bond.GetBondTypeAsDouble() for bond in rdatom.GetBonds()), 0.0)
304
+
305
+
306
def count_nbrs_of_elem(atom, atomic_num):
    '''
    Return how many neighbors of ``atom`` have
    atomic number ``atomic_num``.
    '''
    return sum(
        1
        for nbr in ob.OBAtomAtomIter(atom)
        if nbr.GetAtomicNum() == atomic_num
    )
316
+
317
+
318
def fixup(atoms, mol, indicators):
    '''Re-apply the aromaticity hints from ``indicators`` to ``atoms``.

    Called repeatedly during reconstruction so that OpenBabel's own
    processing does not override the wanted atom properties.
    ``indicators[i]`` must describe ``atoms[i]``.
    '''
    mol.SetAromaticPerceived(True)  # avoid perception
    for idx, atom in enumerate(atoms):
        flags = indicators[idx]

        if flags[ATOM_FAMILIES_ID['Aromatic']]:
            atom.SetAromatic(True)
            atom.SetHyb(2)

        # Only ring nitrogens / oxygens get the neighbor-based treatment below.
        if atom.GetAtomicNum() not in (7, 8) or not atom.IsInRing():
            continue

        # There are no aromatic types for these elements, so a ring N/O with
        # more than one aromatic neighbor is marked aromatic as well.
        aromatic_nbrs = 0
        for nbr in ob.OBAtomAtomIter(atom):
            if nbr.IsAromatic():
                aromatic_nbrs += 1
        if aromatic_nbrs > 1:
            atom.SetAromatic(True)
352
+
353
+
354
def raw_obmol_from_generated(data):
    '''Build a bond-free OBMol from the ligand context stored in ``data``.

    Reads ``data.ligand_context_pos`` and ``data.ligand_context_element``
    and returns the ``(mol, atoms)`` pair produced by make_obmol().
    '''
    positions = data.ligand_context_pos.clone().cpu().tolist()
    elements = data.ligand_context_element.clone().cpu().tolist()
    return make_obmol(positions, elements)
361
+
362
+
363
# Next-higher bond order for the bond types that can be upgraded.
UPGRADE_BOND_ORDER = {
    Chem.BondType.SINGLE: Chem.BondType.DOUBLE,
    Chem.BondType.DOUBLE: Chem.BondType.TRIPLE,
}


def postprocess_rd_mol_1(rdmol):
    '''Resolve radical electrons left over after reconstruction.

    Hydrogens are removed first.  For every pair of bonded atoms that both
    carry radical electrons, the bond between them is raised by one order
    and one radical is consumed on each side.  Radicals still present on
    an atom afterwards are converted to explicit hydrogens.

    Returns the processed molecule (a new object from RemoveHs).
    '''
    rdmol = Chem.RemoveHs(rdmol)

    # Construct bond nbh list
    nbh_list = {}
    for bond in rdmol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        nbh_list.setdefault(begin, []).append(end)
        nbh_list.setdefault(end, []).append(begin)

    # Fix missing bond-order
    for atom in rdmol.GetAtoms():
        idx = atom.GetIdx()
        num_radical = atom.GetNumRadicalElectrons()
        if num_radical > 0:
            # .get(): an atom with radicals but no bonds has no entry.
            for j in nbh_list.get(idx, ()):
                if num_radical == 0:
                    break  # nothing left to pair; never go negative
                if j <= idx:
                    continue  # each bonded pair is handled from its lower index
                nb_atom = rdmol.GetAtomWithIdx(j)
                nb_radical = nb_atom.GetNumRadicalElectrons()
                if nb_radical <= 0:
                    continue
                bond = rdmol.GetBondBetweenAtoms(idx, j)
                upgraded = UPGRADE_BOND_ORDER.get(bond.GetBondType())
                if upgraded is None:
                    continue  # e.g. already triple or aromatic: cannot be raised
                bond.SetBondType(upgraded)
                nb_atom.SetNumRadicalElectrons(nb_radical - 1)
                num_radical -= 1
                atom.SetNumRadicalElectrons(num_radical)

        # Whatever could not be paired becomes explicit hydrogens.
        num_radical = atom.GetNumRadicalElectrons()
        if num_radical > 0:
            atom.SetNumRadicalElectrons(0)
            num_hs = atom.GetNumExplicitHs()
            atom.SetNumExplicitHs(num_hs + num_radical)

    return rdmol
402
+
403
+
404
def postprocess_rd_mol_2(rdmol):
    '''Open up three-membered rings containing two hetero atoms and clear
    positive formal charges.

    For each 3-atom ring: if exactly two ring atoms are not carbon, the
    bond between them is removed; if the ring holds exactly two oxygens,
    their bond is removed and each oxygen gains one explicit hydrogen.
    Finally every positive formal charge is reset to 0.
    '''
    editable = Chem.RWMol(rdmol)

    rings = [set(ring) for ring in rdmol.GetRingInfo().AtomRings()]
    for ring in rings:
        if len(ring) != 3:
            continue

        hetero = []
        by_symbol = {}
        for atom_idx in ring:
            symbol = rdmol.GetAtomWithIdx(atom_idx).GetSymbol()
            if symbol != 'C':
                hetero.append(atom_idx)
            by_symbol.setdefault(symbol, []).append(atom_idx)

        if len(hetero) == 2:
            editable.RemoveBond(*hetero)

        oxygens = by_symbol.get('O', [])
        if len(oxygens) == 2:
            editable.RemoveBond(*oxygens)
            for o_idx in oxygens:
                editable.GetAtomWithIdx(o_idx).SetNumExplicitHs(
                    editable.GetAtomWithIdx(o_idx).GetNumExplicitHs() + 1
                )

    rdmol = editable.GetMol()

    for atom in rdmol.GetAtoms():
        if atom.GetFormalCharge() > 0:
            atom.SetFormalCharge(0)

    return rdmol
439
+
440
+
441
def reconstruct_from_generated(data):
    '''Rebuild an RDKit molecule from generated ligand atoms.

    Reads positions, elements and the trailing ATOM_FAMILIES_ID columns of
    ``data.ligand_context_feature_full``, infers bonds with
    connect_the_dots(), lets OpenBabel perceive bond orders and add
    hydrogens (re-applying fixup() after each step), converts the result
    to RDKit and runs both post-processing passes.
    '''
    positions = data.ligand_context_pos.clone().cpu().tolist()
    elements = data.ligand_context_element.clone().cpu().tolist()
    n_families = len(ATOM_FAMILIES_ID)
    indicators = data.ligand_context_feature_full[:, -n_families:].clone().cpu().bool().tolist()

    mol, atoms = make_obmol(positions, elements)
    fixup(atoms, mol, indicators)

    connect_the_dots(mol, atoms, indicators, 2)
    fixup(atoms, mol, indicators)
    mol.EndModify()

    fixup(atoms, mol, indicators)

    mol.AddPolarHydrogens()
    mol.PerceiveBondOrders()
    fixup(atoms, mol, indicators)

    for atom in atoms:
        ob.OBAtomAssignTypicalImplicitHydrogens(atom)
    fixup(atoms, mol, indicators)

    mol.AddHydrogens()
    fixup(atoms, mol, indicators)

    # make rings all aromatic if majority of carbons are aromatic
    for ring in ob.OBMolRingIter(mol):
        if not (5 <= ring.Size() <= 6):
            continue
        carbons = [
            ring_atom
            for ring_atom in map(mol.GetAtom, ring._path)
            if ring_atom.GetAtomicNum() == 6
        ]
        aromatic_carbons = sum(1 for carbon in carbons if carbon.IsAromatic())
        if aromatic_carbons >= len(carbons) / 2 and aromatic_carbons != ring.Size():
            # set all ring atoms to be aromatic
            for atom_id in ring._path:
                mol.GetAtom(atom_id).SetAromatic(True)

    # bonds must be marked aromatic for smiles to match
    for bond in ob.OBMolBondIter(mol):
        if bond.GetBeginAtom().IsAromatic() and bond.GetEndAtom().IsAromatic():
            bond.SetAromatic(True)

    mol.PerceiveBondOrders()

    rd_mol = convert_ob_mol_to_rd_mol(mol)

    # Post-processing
    return postprocess_rd_mol_2(postprocess_rd_mol_1(rd_mol))
utils/sascorer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+
3
+ from rdkit import Chem
4
+ from rdkit.Chem import rdMolDescriptors
5
+ from rdkit.six.moves import cPickle
6
+ from rdkit.six import iteritems
7
+
8
+ import math
9
+ from collections import defaultdict
10
+
11
+ import os.path as op
12
+
13
# Fragment-id -> score lookup; filled lazily by readFragmentScores().
_fscores = None


def readFragmentScores(name='fpscores'):
    """Load the fragment score table into the module-level ``_fscores``.

    ``name`` is the path of the gzip-compressed pickle without its
    ``.pkl.gz`` suffix.  The default ``'fpscores'`` is resolved relative
    to this module's directory.  Each pickled entry is a sequence whose
    first item is the score and whose remaining items are the fragment
    ids sharing that score.
    """
    import gzip
    import pickle
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with gzip.open('%s.pkl.gz' % name) as handle:  # closed even on failure
        entries = pickle.load(handle)
    _fscores = {
        fragment_id: float(entry[0])
        for entry in entries
        for fragment_id in entry[1:]
    }
28
+
29
+
30
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return ``(bridgehead atom count, spiro atom count)`` for ``mol``.

    ``ri`` is accepted for call compatibility and is not used.
    """
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return (bridgehead_count, spiro_count)
34
+
35
+
36
def calculateScore(m):
    """Compute the synthetic-accessibility score of RDKit molecule ``m``.

    The raw score is the sum of a fragment contribution (Morgan radius-2
    fragments looked up in ``_fscores``, unknown fragments count as -4),
    complexity penalties (size, stereo centres, spiro atoms, bridgeheads,
    macrocycles) and a fingerprint-density correction.  It is then mapped
    onto the range 1..10 and returned as a float.

    The fragment table is loaded on first use.  A molecule that yields no
    fingerprint fragments raises ZeroDivisionError.
    """
    if _fscores is None:
        readFragmentScores()

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        score1 += _fscores.get(bitId, -4) * v
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = sum(1 for ring in ri.AtomRings() if len(ring) > 8)

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthesise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (named raw_min/raw_max so the builtins min/max are not shadowed)
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
98
+
99
+
100
def processMols(mols):
    """Print a tab-separated table (SMILES, name, SA score) for ``mols``.

    ``None`` entries are skipped.  Each molecule must carry a ``_Name``
    property.
    """
    print('smiles\tName\tsa_score')
    for mol in (candidate for candidate in mols if candidate is not None):
        score = calculateScore(mol)
        smiles = Chem.MolToSmiles(mol)
        print(smiles + "\t" + mol.GetProp('_Name') + "\t%3f" % score)
110
+
111
+
112
if __name__ == '__main__':
    # Command-line use: score every molecule of the SMILES file given as
    # the first argument and report timings on stderr.
    import sys
    import time

    load_start = time.time()
    readFragmentScores("fpscores")
    load_end = time.time()

    suppl = Chem.SmilesMolSupplier(sys.argv[1])
    calc_start = time.time()
    processMols(suppl)
    calc_end = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds' %
          ((load_end - load_start), (calc_end - calc_start)),
          file=sys.stderr)
126
+
127
+ #
128
+ # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
129
+ # All rights reserved.
130
+ #
131
+ # Redistribution and use in source and binary forms, with or without
132
+ # modification, are permitted provided that the following conditions are
133
+ # met:
134
+ #
135
+ # * Redistributions of source code must retain the above copyright
136
+ # notice, this list of conditions and the following disclaimer.
137
+ # * Redistributions in binary form must reproduce the above
138
+ # copyright notice, this list of conditions and the following
139
+ # disclaimer in the documentation and/or other materials provided
140
+ # with the distribution.
141
+ # * Neither the name of Novartis Institutes for BioMedical Research Inc.
142
+ # nor the names of its contributors may be used to endorse or promote
143
+ # products derived from this software without specific prior written permission.
144
+ #
145
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
146
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
147
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
148
+ # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
149
+ # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
150
+ # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
151
+ # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
152
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
153
+ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
154
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
155
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
156
+ #
157
+
158
def compute_sa_score(rdmol):
    """Return ``(10 - score) / 9`` rounded to two decimals, where ``score``
    is calculateScore() of the molecule after a SMILES round trip."""
    round_tripped = Chem.MolFromSmiles(Chem.MolToSmiles(rdmol))
    raw_score = calculateScore(round_tripped)
    return round((10 - raw_score) / 9, 2)
163
+
utils/similarity.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from rdkit import Chem, DataStructs
3
+
4
+
5
def tanimoto_sim(mol, ref):
    """Tanimoto similarity between the RDKit fingerprints of ``mol`` and ``ref``."""
    ref_fp = Chem.RDKFingerprint(ref)
    mol_fp = Chem.RDKFingerprint(mol)
    return DataStructs.TanimotoSimilarity(ref_fp, mol_fp)
9
+
10
+
11
def tanimoto_sim_N_to_1(mols, ref):
    """Return a list with the Tanimoto similarity of each molecule in
    ``mols`` to ``ref``, in input order."""
    similarities = []
    for candidate in mols:
        similarities.append(tanimoto_sim(candidate, ref))
    return similarities
14
+
15
+
16
def batched_number_of_rings(mols):
    """Return a NumPy array holding the ring count of every molecule in ``mols``."""
    return np.array([Chem.rdMolDescriptors.CalcNumRings(mol) for mol in mols])
utils/train.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import warnings
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch_geometric.data import Data, Batch
7
+
8
+ from .warmup import GradualWarmupScheduler
9
+
10
+
11
+ #customize exp lr scheduler with min lr
12
class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
    """Exponential learning-rate decay with a lower bound.

    Every step multiplies each parameter group's learning rate by
    ``gamma`` and clamps the result to at least ``min_lr``.
    """

    def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
        # Set before delegating so both values are already available if
        # the base initialiser consults get_lr().
        self.gamma = gamma
        self.min_lr = min_lr
        # Newer PyTorch releases deprecate ``verbose``; forward it only
        # when it was actually requested so the default path stays
        # warning-free and keeps working where the argument is rejected.
        extra = {'verbose': verbose} if verbose else {}
        super().__init__(optimizer, gamma, last_epoch, **extra)

    def get_lr(self):
        """Return the next learning rate of every parameter group."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            # First call: start from the configured base rates unchanged.
            return self.base_lrs
        return [max(group['lr'] * self.gamma, self.min_lr)
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return the clamped learning rates computed directly from ``last_epoch``."""
        return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
                for base_lr in self.base_lrs]
31
+
32
+
33
def repeat_data(data: Data, num_repeat) -> Batch:
    """Return a Batch made of ``num_repeat`` independent deep copies of ``data``."""
    return Batch.from_data_list([copy.deepcopy(data) for _ in range(num_repeat)])
36
+
37
+
38
def repeat_batch(batch: Batch, num_repeat) -> Batch:
    """Return a Batch containing the items of ``batch`` repeated
    ``num_repeat`` times; every repetition is a fresh deep copy of the
    whole item list."""
    originals = batch.to_data_list()
    repeated = [
        item
        for _ in range(num_repeat)
        for item in copy.deepcopy(originals)
    ]
    return Batch.from_data_list(repeated)
44
+
45
+
46
def inf_iterator(iterable):
    """Yield the items of ``iterable`` over and over.

    ``iterable`` is re-iterated each time it is exhausted.  If a full
    pass produces no items (an empty container, or a one-shot iterator
    that is already used up) the generator stops instead of spinning
    forever without ever yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            return
53
+
54
+
55
def get_optimizer(cfg, model):
    """Build the optimizer described by ``cfg`` for ``model``'s parameters.

    Only ``cfg.type == 'adam'`` is supported; it reads ``cfg.lr``,
    ``cfg.weight_decay``, ``cfg.beta1`` and ``cfg.beta2``.

    Raises NotImplementedError for any other type.
    """
    if cfg.type != 'adam':
        raise NotImplementedError('Optimizer not supported: %s' % cfg.type)

    params = model.parameters()
    adam_options = {
        'lr': cfg.lr,
        'weight_decay': cfg.weight_decay,
        'betas': (cfg.beta1, cfg.beta2),
    }
    return torch.optim.Adam(params, **adam_options)
65
+
66
+
67
def get_scheduler(cfg, optimizer):
    """Build the learning-rate scheduler selected by ``cfg.type``.

    Supported types:
      'plateau'          -- ReduceLROnPlateau (factor, patience, min_lr)
      'warmup_plateau'   -- GradualWarmupScheduler (multiplier, total_epoch)
                            followed by ReduceLROnPlateau
      'expmin'           -- ExponentialLR_with_minLr with gamma = cfg.factor
      'expmin_milestone' -- ExponentialLR_with_minLr with gamma chosen so
                            that ``cfg.milestone`` steps decay by cfg.factor

    Raises NotImplementedError for any other type.
    """
    def make_plateau():
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.factor,
            patience=cfg.patience,
            min_lr=cfg.min_lr
        )

    if cfg.type == 'plateau':
        return make_plateau()

    if cfg.type == 'warmup_plateau':
        return GradualWarmupScheduler(
            optimizer,
            multiplier=cfg.multiplier,
            total_epoch=cfg.total_epoch,
            after_scheduler=make_plateau()
        )

    if cfg.type == 'expmin':
        return ExponentialLR_with_minLr(
            optimizer,
            gamma=cfg.factor,
            min_lr=cfg.min_lr,
        )

    if cfg.type == 'expmin_milestone':
        # gamma ** milestone == factor
        per_step_gamma = np.exp(np.log(cfg.factor) / cfg.milestone)
        return ExponentialLR_with_minLr(
            optimizer,
            gamma=per_step_gamma,
            min_lr=cfg.min_lr,
        )

    raise NotImplementedError('Scheduler not supported: %s' % cfg.type)
102
+
utils/transforms.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("..")
3
+ import copy
4
+ import os
5
+ import random
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from copy import deepcopy
10
+ from torch_geometric.transforms import Compose
11
+ from torch_geometric.nn.pool import knn_graph
12
+ from torch_geometric.utils.subgraph import subgraph
13
+ from torch_geometric.utils.num_nodes import maybe_num_nodes
14
+ from torch_geometric.data import Data, Batch
15
+ from torch_scatter import scatter_add
16
+ from rdkit import Chem
17
+ from rdkit.Chem import Descriptors
18
+ from rdkit.Chem import AllChem
19
+
20
+ from .data import ProteinLigandData
21
+ from .protein_ligand import ATOM_FAMILIES
22
+ from .chemutils import enumerate_assemble, list_filter, rand_rotate
23
+ from .dihedral_utils import batch_dihedrals
24
+
25
+ # allowable node and edge features
26
# Lookup tables: a feature value is encoded as its position in the
# corresponding list.
allowable_features = {
    'possible_atomic_num_list': list(range(1, 119)),
    'possible_formal_charge_list': list(range(-5, 6)),
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED,
    ],
    'possible_numH_list': list(range(9)),
    'possible_implicit_valence_list': list(range(7)),
    'possible_degree_list': list(range(11)),
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    # only for double bond stereo information
    'possible_bond_dirs': [
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT,
    ],
}
56
+
57
+
58
def mol_to_graph_data_obj_simple(mol):
    """
    Turn an rdkit mol into a pytorch geometric Data object whose atom and
    bond features are stored as indices into ``allowable_features``.

    Atom features: [atomic number index, chirality tag index].
    Bond features: [bond type index, bond direction index]; every bond is
    stored twice, once per direction.

    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    # atoms
    atomic_nums = allowable_features['possible_atomic_num_list']
    chirality_tags = allowable_features['possible_chirality_list']
    atom_rows = [
        [atomic_nums.index(atom.GetAtomicNum()),
         chirality_tags.index(atom.GetChiralTag())]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(np.array(atom_rows), dtype=torch.long)

    # bonds
    num_bond_features = 2  # bond type, bond direction
    if len(mol.GetBonds()) > 0:  # mol has bonds
        bond_types = allowable_features['possible_bonds']
        bond_dirs = allowable_features['possible_bond_dirs']
        edge_pairs = []
        edge_rows = []
        for bond in mol.GetBonds():
            begin = bond.GetBeginAtomIdx()
            end = bond.GetEndAtomIdx()
            feature = [bond_types.index(bond.GetBondType()),
                       bond_dirs.index(bond.GetBondDir())]
            # one entry per direction, sharing the same feature row
            edge_pairs.extend([(begin, end), (end, begin)])
            edge_rows.extend([feature, feature])

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edge_pairs).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_rows), dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
106
+
107
+
108
class RefineData(object):
    """Transform that removes hydrogen atoms (element == 1) from a sample.

    Pocket atoms and ligand atoms are filtered independently. For the ligand,
    the neighbour list and the bond index/type are filtered as well and the
    surviving atom indices are renumbered to be contiguous.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        """Drop hydrogens from ``data`` in place and return ``data``."""
        # ---- pocket ----
        protein_is_H = (data.protein_element == 1)
        if torch.sum(protein_is_H) > 0:
            keep_protein = ~protein_is_H
            # atom names live in a plain list, so they are filtered with compress
            data.protein_atom_name = list(compress(data.protein_atom_name, keep_protein))
            for field in ('protein_atom_to_aa_type', 'protein_element',
                          'protein_is_backbone', 'protein_pos'):
                setattr(data, field, getattr(data, field)[keep_protein])

        # ---- ligand ----
        ligand_is_H = (data.ligand_element == 1)
        if torch.sum(ligand_is_H) == 0:
            return data

        keep_ligand = ~ligand_is_H
        for field in ('ligand_atom_feature', 'ligand_element', 'ligand_pos'):
            setattr(data, field, getattr(data, field)[keep_ligand])

        # old atom index -> new atom index (-1 for removed hydrogens)
        hydrogen_ids = torch.nonzero(ligand_is_H)[:, 0]
        remap = -np.ones(len(keep_ligand), dtype=np.int64)
        remap[keep_ligand] = np.arange(torch.sum(keep_ligand))

        # neighbour lists: keep heavy-atom entries, drop hydrogen neighbours
        kept_neighbors = [
            nbrs for keep, nbrs in zip(keep_ligand, data.ligand_nbh_list.values()) if keep
        ]
        data.ligand_nbh_list = {
            new_id: [remap[node] for node in nbrs if node not in hydrogen_ids]
            for new_id, nbrs in enumerate(kept_neighbors)
        }

        # bonds: discard every bond touching a hydrogen, renumber the rest
        bond_has_H = np.array([
            (src in hydrogen_ids) | (dst in hydrogen_ids)
            for src, dst in zip(*data.ligand_bond_index)
        ])
        bond_keep = ~bond_has_H
        surviving = data.ligand_bond_index[:, bond_keep]
        data.ligand_bond_index = torch.tensor(remap)[surviving]
        data.ligand_bond_type = data.ligand_bond_type[bond_keep]

        return data
147
+
148
+
149
class FocalBuilder(object):
    """Transform that derives focal / generated-atom targets for a masked sample.

    With a non-empty ligand context the focal atoms are the context ends of
    "bridge" bonds (bonds joining a masked atom to a context atom). Without any
    context the focal atoms are protein atoms within 4.0 of a masked ligand
    atom (or the single closest pair when none is that near).
    """

    def __init__(self, close_threshold=0.8, max_bond_length=2.4):
        super().__init__()
        # Stored but not read by __call__ in this version (the negative-position
        # sampling that used them is disabled).
        self.close_threshold = close_threshold
        self.max_bond_length = max_bond_length

    def __call__(self, data: ProteinLigandData):
        """Populate the focal-related fields of ``data`` and return it.

        Sets ``idx_focal_in_compose``, ``pos_generate``,
        ``idx_generated_in_ligand_masked``, ``idx_protein_all_mask`` and
        ``y_protein_frontier``.
        """
        masked_pos = data.ligand_masked_pos
        pocket_pos = data.protein_pos
        ctx_ids = data.context_idx
        msk_ids = data.masked_idx
        bonds = data.ligand_bond_index

        if ctx_ids.nelement() == 0:
            # Initial atom: no ligand context yet, so focal atoms are protein
            # atoms on the surface between ligand and protein.
            pairs = radius(x=masked_pos, y=pocket_pos, r=4., num_workers=16)
            if pairs.size(1) == 0:
                # nothing within the radius: fall back to the closest pair
                gaps = torch.norm(pocket_pos.unsqueeze(1) - masked_pos.unsqueeze(0), p=2, dim=-1)
                pairs = torch.nonzero(gaps <= torch.min(gaps) + 1e-5)[0:1].transpose(0, 1)
            focal_in_protein = pairs[0]
            # no ligand context, so every compose atom is a protein atom
            data.idx_focal_in_compose = focal_in_protein
            data.pos_generate = masked_pos[pairs[1]]
            data.idx_generated_in_ligand_masked = torch.unique(pairs[1])

            # input and label of the initial frontier prediction
            data.idx_protein_all_mask = data.idx_protein_in_compose
            frontier = torch.zeros_like(data.idx_protein_all_mask, dtype=torch.bool)
            frontier[torch.unique(focal_in_protein)] = True
            data.y_protein_frontier = frontier
            return data

        # Bridge bonds: row 0 holds the masked end, row 1 the context end.
        is_bridge = [
            (ctx_node in ctx_ids) and (msk_node in msk_ids)
            for msk_node, ctx_node in zip(*bonds)
        ]
        bridge = bonds[:, is_bridge]
        generated_global = bridge[0]
        focal_global = bridge[1]

        # whole-ligand index -> position inside the masked subset
        to_masked = torch.zeros(msk_ids.max() + 1, dtype=torch.int64)
        to_masked[msk_ids] = torch.arange(len(msk_ids))
        generated_local = to_masked[generated_global]
        data.idx_generated_in_ligand_masked = generated_local
        data.pos_generate = masked_pos[generated_local]

        # whole-ligand index -> position inside the context subset
        to_context = torch.zeros(ctx_ids.max() + 1, dtype=torch.int64)
        to_context[ctx_ids] = torch.arange(len(ctx_ids))
        # Only valid while ligand context atoms precede protein atoms in the
        # compose ordering.
        data.idx_focal_in_compose = to_context[focal_global]

        # unused when a context exists
        data.idx_protein_all_mask = torch.empty(0, dtype=torch.long)
        data.y_protein_frontier = torch.empty(0, dtype=torch.bool)
        return data
222
+
223
+
224
class AtomComposer(object):
    """Transform that merges ligand-context atoms and protein atoms into one
    "compose" point set and builds its kNN graph."""

    def __init__(self, protein_dim, ligand_dim, knn):
        super().__init__()
        self.protein_dim = protein_dim
        self.ligand_dim = ligand_dim
        self.knn = knn  # number of neighbours per compose atom

    def __call__(self, data: ProteinLigandData):
        """Write the ``compose_*`` fields into ``data`` and return it."""
        ctx_pos = data['ligand_context_pos']
        ctx_feat = data['ligand_context_feature_full']
        pocket_pos = data['protein_pos']
        pocket_feat = data['protein_atom_feature']
        n_ctx, n_pocket = len(ctx_pos), len(pocket_pos)

        # Ligand context first, protein second.
        data['compose_pos'] = torch.cat([ctx_pos, pocket_pos], dim=0)
        # Zero-pad ligand features on the right by (protein_dim - ligand_dim)
        # columns so they can be stacked on top of the protein features.
        padding = torch.zeros([n_ctx, self.protein_dim - self.ligand_dim], dtype=torch.long)
        ctx_feat_padded = torch.cat([ctx_feat, padding], dim=1)
        data['compose_feature'] = torch.cat([ctx_feat_padded, pocket_feat], dim=0)
        data['idx_ligand_ctx_in_compose'] = torch.arange(n_ctx, dtype=torch.long)
        data['idx_protein_in_compose'] = torch.arange(n_pocket, dtype=torch.long) + n_ctx

        # kNN graph plus bond type / feature of every kNN edge
        return self.get_knn_graph(data, self.knn, n_ctx, n_ctx + n_pocket, num_workers=16)

    @staticmethod
    def get_knn_graph(data: ProteinLigandData, knn, len_ligand_ctx, len_compose, num_workers=1, ):
        """Build the compose kNN graph and label edges that are ligand bonds.

        Sets ``compose_knn_edge_index``, ``compose_knn_edge_type`` (0 unless the
        edge matches a ligand-context bond) and ``compose_knn_edge_feature``
        (4 columns: column 0 set for non-bond edges, otherwise the one-hot of
        the bond type).
        """
        data['compose_knn_edge_index'] = knn_graph(data['compose_pos'], knn, flow='target_to_source',
                                                   num_workers=num_workers)
        knn_edges = data['compose_knn_edge_index']
        num_edges = len(knn_edges[0])

        # Encode each (src, dst) pair as one integer so pairs compare with ==.
        # Only the first len_ligand_ctx * knn edges are searched.
        span = len_ligand_ctx * knn
        knn_keys = knn_edges[0, :span] * len_compose + knn_edges[1, :span]
        ctx_bonds = data['ligand_context_bond_index']
        bond_keys = ctx_bonds[0] * len_compose + ctx_bonds[1]

        # Position of every ligand bond inside the kNN edge list, -1 if absent.
        hits = [torch.nonzero(knn_keys == key) for key in bond_keys]
        idx_edge = torch.tensor([hit.squeeze() if len(hit) > 0 else torch.tensor(-1) for hit in hits],
                                dtype=torch.long)
        found = idx_edge >= 0

        # for encoder edge embedding
        data['compose_knn_edge_type'] = torch.zeros(num_edges, dtype=torch.long)
        data['compose_knn_edge_type'][idx_edge[found]] = data['ligand_context_bond_type'][found]

        data['compose_knn_edge_feature'] = torch.cat([
            torch.ones([num_edges, 1], dtype=torch.long),
            torch.zeros([num_edges, 3], dtype=torch.long),
        ], dim=-1)
        # 0 (1,2,3)-onehot
        data['compose_knn_edge_feature'][idx_edge[found]] = F.one_hot(
            data['ligand_context_bond_type'][found], num_classes=4)
        return data
275
+
276
+
277
class FeaturizeProteinAtom(object):
    """Transform that builds the per-atom protein feature matrix."""

    def __init__(self):
        super().__init__()
        # C, N, O, S, Se -- hydrogen (1) is not part of the table
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 16, 34])
        self.max_num_aa = 20

    @property
    def feature_dim(self):
        """Number of feature columns: elements + amino-acid classes + backbone flag."""
        num_elements = self.atomic_numbers.size(0)
        return num_elements + self.max_num_aa + 1

    def __call__(self, data: ProteinLigandData):
        """Store ``protein_atom_feature`` in ``data`` and return ``data``."""
        # (N_atoms, N_elements) boolean match against the element table
        atom_column = data['protein_element'].view(-1, 1)
        element_match = atom_column == self.atomic_numbers.view(1, -1)
        residue_onehot = F.one_hot(data['protein_atom_to_aa_type'], num_classes=self.max_num_aa)
        backbone_flag = data['protein_is_backbone'].view(-1, 1).long()
        data['protein_atom_feature'] = torch.cat([element_match, residue_onehot, backbone_flag], dim=-1)
        return data
296
+
297
+
298
class FeaturizeLigandAtom(object):
    """Transform that builds the full per-atom ligand feature matrix."""

    def __init__(self):
        super().__init__()
        # C N O F P S Cl -- hydrogen (1) is not part of the table
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 9, 15, 16, 17])

    @property
    def num_properties(self):
        """Number of atom-family property columns."""
        return len(ATOM_FAMILIES)

    @property
    def feature_dim(self):
        """Total number of feature columns: elements + atom families."""
        num_elements = self.atomic_numbers.size(0)
        return num_elements + len(ATOM_FAMILIES)

    def __call__(self, data: ProteinLigandData):
        """Store ``ligand_atom_feature_full`` in ``data`` and return ``data``."""
        # (N_atoms, N_elements) boolean match against the element table
        atom_column = data['ligand_element'].view(-1, 1)
        element_match = atom_column == self.atomic_numbers.view(1, -1)
        data['ligand_atom_feature_full'] = torch.cat([element_match, data['ligand_atom_feature']], dim=-1)
        return data
318
+
319
+
320
class FeaturizeLigandBond(object):
    """Transform that one-hot encodes ligand bond types and records, for each
    ligand atom, the indices of its bonded neighbours."""

    def __init__(self):
        super().__init__()

    def __call__(self, data: ProteinLigandData):
        """Store ``ligand_bond_feature`` and ``ligand_neighbors``; return ``data``."""
        # (1,2,3) to (0,1,2)-onehot; the modulo wraps any other type into 0..2
        shifted_types = (data['ligand_bond_type'] - 1) % 3
        data['ligand_bond_feature'] = F.one_hot(shifted_types, num_classes=3)

        # atom position -> neighbour atom indices (used in rotation angle prediction)
        mol = data['moltree'].mol
        data['ligand_neighbors'] = {
            position: [nbr.GetIdx() for nbr in atom.GetNeighbors()]
            for position, atom in enumerate(mol.GetAtoms())
        }
        return data
335
+
336
+
337
class LigandCountNeighbors(object):
    """Transform that stores per-atom neighbour counts and summed bond orders."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def count_neighbors(edge_index, symmetry, valence=None, num_nodes=None):
        """Sum a per-edge weight over the source node of every edge.

        :param edge_index: edge index tensor; row 0 is used as the scatter index
        :param symmetry: must compare equal to True (only symmetric edges are supported)
        :param valence: optional per-edge weights; defaults to 1 for every edge
        :param num_nodes: number of nodes; inferred from ``edge_index`` if omitted
        :return: long tensor with one accumulated value per node
        """
        assert symmetry == True, 'Only support symmetrical edges.'

        num_edges = edge_index.size(1)
        if num_nodes is None:
            num_nodes = maybe_num_nodes(edge_index)

        if valence is None:
            # unweighted: every edge counts once
            valence = torch.ones([num_edges], device=edge_index.device)
        per_edge = valence.view(num_edges)

        totals = scatter_add(per_edge, index=edge_index[0], dim=0, dim_size=num_nodes)
        return totals.long()

    def __call__(self, data):
        """Store ``ligand_num_neighbors`` and ``ligand_atom_valence``; return ``data``."""
        num_atoms = data['ligand_element'].size(0)
        data['ligand_num_neighbors'] = self.count_neighbors(
            data['ligand_bond_index'],
            symmetry=True,
            num_nodes=num_atoms,
        )
        # same accumulation, weighted by bond type
        data['ligand_atom_valence'] = self.count_neighbors(
            data['ligand_bond_index'],
            symmetry=True,
            valence=data['ligand_bond_type'],
            num_nodes=num_atoms,
        )
        return data
368
+
369
+
370
class LigandRandomMask(object):
    """Transform that masks a uniformly random subset of ligand atoms."""

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0):
        super().__init__()
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.min_num_masked = min_num_masked
        self.min_num_unmasked = min_num_unmasked

    def __call__(self, data: ProteinLigandData):
        """Split the ligand into masked and context atoms and return ``data``.

        Writes the ``ligand_masked_*`` and ``ligand_context_*`` fields,
        ``ligand_frontier`` and ``_mask = 'random'``.
        """
        # The ratio is drawn from [min_ratio, max_ratio] and clipped into [0, 1].
        ratio = np.clip(random.uniform(self.min_ratio, self.max_ratio), 0.0, 1.0)
        num_atoms = data.ligand_element.size(0)
        num_masked = max(int(num_atoms * ratio), self.min_num_masked)
        if (num_atoms - num_masked) < self.min_num_unmasked:
            num_masked = num_atoms - self.min_num_unmasked

        # random permutation: the first num_masked atoms are masked
        order = np.arange(num_atoms)
        np.random.shuffle(order)
        order = torch.LongTensor(order)
        masked_idx, context_idx = order[:num_masked], order[num_masked:]

        data.ligand_masked_element = data.ligand_element[masked_idx]
        data.ligand_masked_feature = data.ligand_atom_feature[masked_idx]  # For Prediction
        data.ligand_masked_pos = data.ligand_pos[masked_idx]

        data.ligand_context_element = data.ligand_element[context_idx]
        data.ligand_context_feature_full = data.ligand_atom_feature_full[context_idx]  # For Input
        data.ligand_context_pos = data.ligand_pos[context_idx]

        # bonds restricted to context atoms, renumbered to context positions
        data.ligand_context_bond_index, data.ligand_context_bond_feature = subgraph(
            context_idx,
            data.ligand_bond_index,
            edge_attr=data.ligand_bond_feature,
            relabel_nodes=True,
        )
        data.ligand_context_num_neighbors = LigandCountNeighbors.count_neighbors(
            data.ligand_context_bond_index,
            symmetry=True,
            num_nodes=context_idx.size(0),
        )

        # A context atom is on the frontier when it has fewer neighbours inside
        # the context than in the full ligand.
        data.ligand_frontier = data.ligand_context_num_neighbors < data.ligand_num_neighbors[context_idx]

        data._mask = 'random'
        return data
432
+
433
+
434
class LigandBFSMask(object):
    """Transform that masks ligand motifs along a randomised BFS order of the
    motif tree and attaches the training targets derived from that split.

    :param min_ratio: lower bound of the sampled mask ratio
    :param max_ratio: upper bound of the sampled mask ratio (clipped to 1.0)
    :param min_num_masked: minimum number of masked motifs
    :param min_num_unmasked: minimum number of unmasked (context) motifs
    :param vocab: motif vocabulary; must provide ``size()`` and ``get_index(smiles)``
    :raises ValueError: if ``vocab`` is None
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, vocab=None):
        super().__init__()
        if vocab is None:
            # Fail with a clear message instead of an AttributeError on None.size().
            raise ValueError('LigandBFSMask requires a motif vocabulary (vocab must not be None).')
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.min_num_masked = min_num_masked
        self.min_num_unmasked = min_num_unmasked
        self.vocab = vocab
        self.vocab_size = vocab.size()

    @staticmethod
    def get_bfs_perm_motif(moltree, vocab):
        """Return a randomised BFS order of the motif tree, starting at node 0.

        Side effect: assigns ``nid`` (position) and ``wid`` (vocabulary index of
        ``node.smiles``) to every node of ``moltree``.

        :return: ``(bfs_perm, bfs_focal)`` where ``bfs_perm`` lists node ids in
            visiting order and ``bfs_focal[i]`` is the node from which
            ``bfs_perm[i + 1]`` was discovered
        """
        from collections import deque  # local import: O(1) pops from the left

        for i, node in enumerate(moltree.nodes):
            node.nid = i
            node.wid = vocab.get_index(node.smiles)

        bfs_queue = deque([0])
        bfs_perm = []
        bfs_focal = []
        visited = {0}
        while bfs_queue:
            current = bfs_queue.popleft()
            bfs_perm.append(current)
            next_candid = []
            for motif in moltree.nodes[current].neighbors:
                if motif.nid in visited:
                    continue
                next_candid.append(motif.nid)
                visited.add(motif.nid)
                bfs_focal.append(current)

            # siblings are visited in random order
            random.shuffle(next_candid)
            bfs_queue.extend(next_candid)

        return bfs_perm, bfs_focal

    def __call__(self, data):
        """Mask the trailing motifs of a BFS order and fill ``data`` with targets.

        Writes motif prediction targets (``current_wid``, ``next_wid``,
        ``current_atoms``), the ligand context fields, distance-matrix targets,
        attachment candidates (``cand_labels``, ``cand_mols``) and torsion
        targets. Note that ``data['ligand_pos']`` is modified in place when a
        rotatable focal motif exists (the next motif is randomly rotated).
        """
        moltree = data['moltree']
        bfs_perm, bfs_focal = self.get_bfs_perm_motif(moltree, self.vocab)
        ratio = np.clip(random.uniform(self.min_ratio, self.max_ratio), 0.0, 1.0)
        num_motifs = len(bfs_perm)
        num_masked = int(num_motifs * ratio)
        if num_masked < self.min_num_masked:
            num_masked = self.min_num_masked
        if (num_motifs - num_masked) < self.min_num_unmasked:
            num_masked = num_motifs - self.min_num_unmasked

        # Context = atoms of all motifs before the masked tail of the BFS order.
        context_motif_ids = bfs_perm[:-num_masked]
        context_idx = set()
        for i in context_motif_ids:
            context_idx = context_idx | set(moltree.nodes[i].clique)
        context_idx = torch.LongTensor(list(context_idx))

        if num_masked == num_motifs:
            # Nothing generated yet: the "current motif" is the extra class
            # vocab_size and the current atoms are the protein contact index.
            data['current_wid'] = torch.tensor([self.vocab_size])
            data['current_atoms'] = torch.tensor([data['protein_contact_idx']])
            data['next_wid'] = torch.tensor([moltree.nodes[bfs_perm[-num_masked]].wid])
        else:
            data['current_wid'] = torch.tensor([moltree.nodes[bfs_focal[-num_masked]].wid])
            data['next_wid'] = torch.tensor([moltree.nodes[bfs_perm[-num_masked]].wid])  # For Prediction
            current_atoms = moltree.nodes[bfs_focal[-num_masked]].clique
            # positions inside the context, offset by the number of protein atoms
            data['current_atoms'] = torch.cat([torch.where(context_idx == i)[0] for i in current_atoms]) + len(data['protein_pos'])

        data['ligand_context_element'] = data['ligand_element'][context_idx]
        data['ligand_context_feature_full'] = data['ligand_atom_feature_full'][context_idx]  # For Input
        data['ligand_context_pos'] = data['ligand_pos'][context_idx]
        data['ligand_center'] = torch.mean(data['ligand_pos'], dim=0)
        data['num_atoms'] = torch.tensor([len(context_idx) + len(data['protein_pos'])])

        # distance matrix prediction: 2 atoms of the first motif vs. the 4
        # protein atoms closest to the first sampled atom
        if len(data['ligand_context_pos']) > 0:
            sample_idx = random.sample(moltree.nodes[bfs_perm[0]].clique, 2)
            data['dm_ligand_idx'] = torch.cat([torch.where(context_idx == i)[0] for i in sample_idx])
            data['dm_protein_idx'] = torch.sort(torch.norm(data['protein_pos'] - data['ligand_context_pos'][data['dm_ligand_idx'][0]], dim=-1)).indices[:4]
            data['true_dm'] = torch.norm(data['protein_pos'][data['dm_protein_idx']].unsqueeze(1) - data['ligand_context_pos'][data['dm_ligand_idx']].unsqueeze(0), dim=-1).reshape(-1)
        else:
            data['true_dm'] = torch.tensor([])

        data['protein_alpha_carbon_index'] = torch.tensor([i for i, name in enumerate(data['protein_atom_name']) if name == "CA"])
        data['alpha_carbon_indicator'] = torch.tensor([True if name == "CA" else False for name in data['protein_atom_name']])

        # assemble prediction
        data['protein_contact'] = torch.tensor(data['protein_contact'])
        if len(context_motif_ids) > 0:
            cand_labels, cand_mols = enumerate_assemble(moltree.mol, context_idx.tolist(),
                                                        moltree.nodes[bfs_focal[-num_masked]],
                                                        moltree.nodes[bfs_perm[-num_masked]])
            data['cand_labels'] = cand_labels
            data['cand_mols'] = [mol_to_graph_data_obj_simple(mol) for mol in cand_mols]
        else:
            data['cand_labels'], data['cand_mols'] = torch.tensor([]), []

        data['ligand_context_bond_index'], data['ligand_context_bond_feature'] = subgraph(
            context_idx,
            data['ligand_bond_index'],
            edge_attr=data['ligand_bond_feature'],
            relabel_nodes=True,
        )
        data['ligand_context_num_neighbors'] = LigandCountNeighbors.count_neighbors(
            data['ligand_context_bond_index'],
            symmetry=True,
            num_nodes=context_idx.size(0),
        )
        data['ligand_frontier'] = data['ligand_context_num_neighbors'] < data['ligand_num_neighbors'][context_idx]
        data['_mask'] = 'bfs'

        # find a rotatable bond as the current motif
        rotatable_ids = [slot for slot, motif_id in enumerate(bfs_focal) if moltree.nodes[motif_id].rotatable]
        if len(rotatable_ids) == 0:
            # assign empty tensors
            data['ligand_torsion_xy_index'] = torch.tensor([])
            data['dihedral_mask'] = torch.tensor([]).bool()
            data['ligand_element_torsion'] = torch.tensor([])
            data['ligand_pos_torsion'] = torch.tensor([])
            data['ligand_feature_torsion'] = torch.tensor([])
            data['true_sin'], data['true_cos'], data['true_three_hop'] = torch.tensor([]), torch.tensor([]), torch.tensor([])
            data['xn_pos'], data['yn_pos'], data['y_pos'] = torch.tensor([]), torch.tensor([]), torch.tensor([])
        else:
            # focal_slot indexes bfs_focal; the motif attached there is bfs_perm[focal_slot + 1]
            focal_slot = random.sample(rotatable_ids, 1)[0]
            focal_clique = moltree.nodes[bfs_focal[focal_slot]].clique
            next_clique = moltree.nodes[bfs_perm[focal_slot + 1]].clique
            current_idx = torch.LongTensor(focal_clique)
            next_idx = torch.LongTensor(next_clique)
            current_idx_set = set(focal_clique)
            next_idx_set = set(next_clique)
            all_idx = set()
            for i in bfs_perm[:focal_slot + 2]:
                all_idx = all_idx | set(moltree.nodes[i].clique)
            all_idx = list(all_idx)
            # x: atom shared by the focal and next motif; y: another focal atom
            x_id = current_idx_set.intersection(next_idx_set).pop()
            y_id = (current_idx_set - {x_id}).pop()
            data['ligand_torsion_xy_index'] = torch.cat([torch.where(torch.LongTensor(all_idx) == i)[0] for i in [x_id, y_id]])

            x_pos, y_pos = deepcopy(data['ligand_pos'][x_id]), deepcopy(data['ligand_pos'][y_id])
            # remove x, y, and non-generated elements
            xn, yn = deepcopy(data['ligand_neighbors'][x_id]), deepcopy(data['ligand_neighbors'][y_id])
            xn.remove(y_id)
            yn.remove(x_id)
            xn, yn = xn[:3], yn[:3]
            xn, yn = list_filter(xn, all_idx), list_filter(yn, all_idx)
            # up to 3 neighbours per side, zero-padded; dihedral_mask marks real pairs
            xn_pos, yn_pos = torch.zeros(3, 3), torch.zeros(3, 3)
            xn_pos[:len(xn)], yn_pos[:len(yn)] = deepcopy(data['ligand_pos'][xn]), deepcopy(data['ligand_pos'][yn])
            xn_idx, yn_idx = torch.cartesian_prod(torch.arange(3), torch.arange(3)).chunk(2, dim=-1)
            xn_idx = xn_idx.squeeze(-1)
            yn_idx = yn_idx.squeeze(-1)
            dihedral_x, dihedral_y = torch.zeros(3), torch.zeros(3)
            dihedral_x[:len(xn)] = 1
            dihedral_y[:len(yn)] = 1
            data['dihedral_mask'] = torch.matmul(dihedral_x.view(3, 1), dihedral_y.view(1, 3)).view(-1).bool()
            data['true_sin'], data['true_cos'] = batch_dihedrals(xn_pos[xn_idx], x_pos.repeat(9, 1), y_pos.repeat(9, 1),
                                                                 yn_pos[yn_idx])
            data['true_three_hop'] = torch.linalg.norm(xn_pos[xn_idx] - yn_pos[yn_idx], dim=-1)[data['dihedral_mask']]

            # random rotate to simulate the inference situation
            axis = data['ligand_pos'][current_idx[0]] - data['ligand_pos'][current_idx[1]]
            ref = deepcopy(data['ligand_pos'][current_idx[0]])
            next_motif_pos = deepcopy(data['ligand_pos'][next_idx])
            data['ligand_pos'][next_idx] = rand_rotate(axis, ref, next_motif_pos)

            data['ligand_element_torsion'] = data['ligand_element'][all_idx]
            data['ligand_pos_torsion'] = data['ligand_pos'][all_idx]
            data['ligand_feature_torsion'] = data['ligand_atom_feature_full'][all_idx]

            # positions relative to x, taken after the random rotation
            x_pos = deepcopy(data['ligand_pos'][x_id])
            data['y_pos'] = data['ligand_pos'][y_id] - x_pos
            data['xn_pos'], data['yn_pos'] = torch.zeros(3, 3), torch.zeros(3, 3)
            data['xn_pos'][:len(xn)], data['yn_pos'][:len(yn)] = data['ligand_pos'][xn] - x_pos, data['ligand_pos'][yn] - x_pos

        return data
604
+
605
+
606
class LigandMaskAll(LigandBFSMask):
    """BFS mask whose sampled ratio is always 1.0.

    With ``min_ratio=1.0`` the ratio drawn by the parent is clipped to 1.0.
    """

    def __init__(self, vocab):
        super().__init__(vocab=vocab, min_ratio=1.0)
610
+
611
+
612
class LigandMixedMask(object):
    """Transform that applies one of several masking strategies, picked at random.

    :param min_ratio, max_ratio, min_num_masked, min_num_unmasked: forwarded to
        every wrapped mask
    :param p_random: sampling weight of the random atom mask
    :param p_bfs: sampling weight of the BFS motif mask
    :param p_invbfs: sampling weight of the third slot. ``LigandBFSMask`` has no
        inverse mode in this file, so this slot holds the same BFS mask and the
        weight simply adds to ``p_bfs``.
    :param vocab: motif vocabulary handed to ``LigandBFSMask``
    """

    def __init__(self, min_ratio=0.0, max_ratio=1.2, min_num_masked=1, min_num_unmasked=0, p_random=0.5, p_bfs=0.25,
                 p_invbfs=0.25, vocab=None):
        super().__init__()

        # LigandBFSMask does not accept an `inverse` keyword; passing it raised
        # TypeError on every construction. It also needs the vocabulary.
        bfs_mask = LigandBFSMask(min_ratio, max_ratio, min_num_masked, min_num_unmasked, vocab=vocab)
        self.t = [
            LigandRandomMask(min_ratio, max_ratio, min_num_masked, min_num_unmasked),
            bfs_mask,
            bfs_mask,
        ]
        self.p = [p_random, p_bfs, p_invbfs]

    def __call__(self, data):
        """Apply one wrapped mask, chosen with weights ``self.p``, to ``data``."""
        f = random.choices(self.t, k=1, weights=self.p)[0]
        return f(data)


def get_mask(cfg, vocab):
    """Build the masking transform described by ``cfg``.

    :param cfg: config object with a ``type`` attribute ('bfs', 'random',
        'mixed' or 'all') plus the attributes read by the selected branch
    :param vocab: motif vocabulary, passed to every mask that uses BFS over motifs
    :raises NotImplementedError: for an unknown ``cfg.type``
    """
    if cfg.type == 'bfs':
        return LigandBFSMask(
            min_ratio=cfg.min_ratio,
            max_ratio=cfg.max_ratio,
            min_num_masked=cfg.min_num_masked,
            min_num_unmasked=cfg.min_num_unmasked,
            vocab=vocab
        )
    elif cfg.type == 'random':
        return LigandRandomMask(
            min_ratio=cfg.min_ratio,
            max_ratio=cfg.max_ratio,
            min_num_masked=cfg.min_num_masked,
            min_num_unmasked=cfg.min_num_unmasked,
        )
    elif cfg.type == 'mixed':
        return LigandMixedMask(
            min_ratio=cfg.min_ratio,
            max_ratio=cfg.max_ratio,
            min_num_masked=cfg.min_num_masked,
            min_num_unmasked=cfg.min_num_unmasked,
            p_random=cfg.p_random,
            p_bfs=cfg.p_bfs,
            p_invbfs=cfg.p_invbfs,
            vocab=vocab,
        )
    elif cfg.type == 'all':
        # LigandMaskAll requires the vocabulary; calling it without arguments
        # raised TypeError.
        return LigandMaskAll(vocab)
    else:
        raise NotImplementedError('Unknown mask: %s' % cfg.type)
660
+
661
+
662
def kabsch(A, B):
    """Find the rigid transform (rotation + translation) that maps B onto A.

    Solves the Kabsch problem: the returned ``R`` and ``t`` minimise the
    squared distance between ``A`` and ``R @ B.T + t`` (points as columns).

    The original code used ``*`` as a matrix product, which is only correct for
    ``np.matrix`` inputs. Explicit matrix products are used here, so plain
    ndarrays work as well; ``np.matrix`` inputs still get ``np.matrix`` results.

    :param A: nominal points, N x 3
    :param B: measured points, N x 3
    :return: ``(R, t)`` with ``R`` the 3 x 3 rotation (B to A) and ``t`` the
        3 x 1 translation (B to A)
    :raises AssertionError: if A and B hold a different number of points
    """
    assert len(A) == len(B)
    as_matrix = isinstance(A, np.matrix) or isinstance(B, np.matrix)
    points_A = np.asarray(A)
    points_B = np.asarray(B)

    centroid_A = points_A.mean(axis=0)
    centroid_B = points_B.mean(axis=0)
    # center the points
    AA = points_A - centroid_A
    BB = points_B - centroid_B

    # covariance of the centered point sets and its SVD
    H = BB.T @ AA
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    # special reflection case: flip the last singular direction so det(R) = +1
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = (centroid_A - R @ centroid_B).reshape(-1, 1)
    if as_matrix:
        return np.matrix(R), np.matrix(t)
    return R, t
utils/warmup.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MIT License
3
+
4
+ Copyright (c) 2019 Ildoo Kim
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+ """
24
+ from torch.optim.lr_scheduler import _LRScheduler
25
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
26
+
27
+
28
class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    Raises:
        ValueError: if multiplier is smaller than 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # becomes True once the base lrs have been handed to after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            # warm-up is over
            if self.after_scheduler:
                if not self.finished:
                    # the follow-up scheduler starts from the warmed-up rate
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # ramp from 0 up to base_lr
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # ramp from base_lr up to base_lr * multiplier
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Step variant used when ``after_scheduler`` is a ReduceLROnPlateau."""
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # epoch is never None here (it was filled in above), so the plateau
            # scheduler always receives the epoch relative to the end of warm-up
            self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one step.

        ``metrics`` is only used when ``after_scheduler`` is a ReduceLROnPlateau.
        """
        # isinstance (rather than an exact type comparison) also routes
        # subclasses of ReduceLROnPlateau to the metrics-aware path
        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
vocab.txt ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CC:108150
2
+ CN:59667
3
+ CO:39300
4
+ C=O:36148
5
+ C1=CC=CC=C1:17649
6
+ OP:7954
7
+ O=S:5180
8
+ CF:4607
9
+ CS:4246
10
+ C[NH3+]:3561
11
+ O=P:3006
12
+ CCl:2484
13
+ C[NH+]:2321
14
+ C=N:2305
15
+ C1CCNC1:2115
16
+ [H]N:2073
17
+ C1CCOCC1:1957
18
+ C1=CC=NC=C1:1892
19
+ C1=CN=CN=C1:1875
20
+ NS:1824
21
+ C1CCOC1:1578
22
+ C[NH2+]:1291
23
+ C1CCCCC1:1209
24
+ C=C:1202
25
+ C1=CNC=N1:1184
26
+ C[N+]:1066
27
+ C1=CNN=C1:676
28
+ C1CCNCC1:670
29
+ CP:662
30
+ C1=CCCC=C1:628
31
+ OS:624
32
+ C1=CSC=C1:614
33
+ C1CCCC1:531
34
+ C#N:481
35
+ NO:477
36
+ C1=CSC=N1:463
37
+ CBr:437
38
+ C1=CNCNC1:436
39
+ C1CC1:434
40
+ C1=CCCCC1:423
41
+ C1=CNC=C1:391
42
+ C1CC[NH+]CC1:352
43
+ C1=CCNC=C1:319
44
+ [N+]=O:315
45
+ C1=CN=CNC1:313
46
+ C1=NCCN1:311
47
+ [N+][O-]:310
48
+ C1=CNCC1:300
49
+ C1CC[NH2+]CC1:274
50
+ C1CC[NH2+]C1:264
51
+ BO:260
52
+ C1C[NH+]CCN1:247
53
+ C1=CNN=N1:226
54
+ C1=COC=C1:218
55
+ C1CC[NH+]C1:210
56
+ C#C:208
57
+ C1COCCN1:206
58
+ C1=CN=CC=N1:184
59
+ C1=CCCN=C1:179
60
+ C1=CCNCC1:176
61
+ CI:175
62
+ C1CNCN1:171
63
+ C1CNCCN1:165
64
+ C1=COCCC1:165
65
+ C1=CON=C1:162
66
+ C1COCC[NH+]1:138
67
+ BC:137
68
+ C1=CCNC1:131
69
+ C1CNCNC1:129
70
+ C1=CNCN=C1:127
71
+ NP:124
72
+ C1=NN=CN1:124
73
+ C1=NC=NN1:123
74
+ C1C[NH+]CC[NH+]1:118
75
+ C=S:117
76
+ C1=CCCC1:112
77
+ C1=NCCCN1:107
78
+ C1=COCC1:107
79
+ C1=NCCS1:102
80
+ C1=NNCC1:100
81
+ C1CSCN1:98
82
+ NN:95
83
+ C1=CNCCC1:94
84
+ C1=CCNN=C1:90
85
+ C1=COC=N1:81
86
+ C1COCO1:78
87
+ C1=COCO1:78
88
+ C1=CCOCC1:75
89
+ C1=NC=NCC1:73
90
+ C1=CNCC=N1:72
91
+ C1=NN=CS1:70
92
+ C1CCSC1:66
93
+ C1CCC1:65
94
+ C1C[NH2+]CCN1:63
95
+ C1=CSCC1:62
96
+ C1=COC=CC1:62
97
+ C1=NC=NC=N1:60
98
+ C1COCN1:59
99
+ N=N:57
100
+ C1=CNCN1:55
101
+ C1=CNCCN1:55
102
+ C1=CNC=NC1:54
103
+ C1CCCNCC1:51
104
+ C1=CN=CCC1:51
105
+ C1=CC1:51
106
+ C1C2CC3CC1CC(C2)C3:50
107
+ C1=CCC=CC1:50
108
+ C:N:49
109
+ C1=CCOC=C1:48
110
+ [NH2+]O:47
111
+ C1=NCNCC1:46
112
+ C1CNC1:44
113
+ C1CNCCNC1:42
114
+ C1=CNC=CC1:40
115
+ C1CO1:39
116
+ C1=CC[NH+]CC1:39
117
+ N=O:37
118
+ C1=CSCN1:35
119
+ [NH+]N:34
120
+ C1=NC=NO1:34
121
+ C1=CNCCNC1:32
122
+ C1=CC[NH2+]CC1:30
123
+ [N+]=[N-]:29
124
+ [N+]=N:29
125
+ O[PH]:29
126
+ C1CNCSC1:29
127
+ C1CCCCCC1:29
128
+ C1=NN=NN1:27
129
+ C1COPOC1:26
130
+ C1CCSCC1:26
131
+ C1=NNCCC1:26
132
+ C1=CCCCCC1:26
133
+ C1C[Fe]1:25
134
+ C1=NN=CO1:25
135
+ C1=NCCN=C1:25
136
+ C1=CSC=[N+]1:25
137
+ PS:24
138
+ C1CNCOC1:24
139
+ C1CCC[NH2+]CC1:24
140
+ C1=NCCO1:24
141
+ C1=COCCN1:24
142
+ C1=CNC=[N+]1:24
143
+ C1=CCN=CC1:24
144
+ C[PH]:23
145
+ C1COCOC1:23
146
+ C1CNCC[NH2+]C1:23
147
+ C1=CONC1:23
148
+ C1CNCCOC1:22
149
+ C1=NON=C1:21
150
+ C1=NNCN1:21
151
+ O:S:20
152
+ C[Si]:20
153
+ C1C[NH2+]CC[NH+]1:19
154
+ C1C[NH+]C1:19
155
+ C1=[N+]CCC1:19
156
+ C1=NOCC1:19
157
+ C1NO1:18
158
+ C1=CNCCN=C1:18
159
+ SS:17
160
+ C1CSC[NH+]1:17
161
+ C1COCCO1:17
162
+ C1=CCOC1:17
163
+ C[Se]:16
164
+ C[AsH]:16
165
+ C1CSC[NH2+]1:16
166
+ C1CSCCN1:16
167
+ C1CCCCCCC1:16
168
+ C1=NC=NCN1:16
169
+ C1=CSN=C1:16
170
+ C1=CSCCC1:16
171
+ C1=COCCO1:16
172
+ C1=CCSC1:16
173
+ C1=CCN=C1:16
174
+ C1COC1:15
175
+ C1=NCCNC1:15
176
+ C1=NC=NNC1:15
177
+ C1=C[Fe]1:15
178
+ C1=CSNCC1:15
179
+ C1=CN[C@H]2CCCC(N1)O2:15
180
+ C1=CC=[N+]C=C1:15
181
+ [NH+]O:14
182
+ C=[N+]:14
183
+ C1CCNCNC1:14
184
+ C1=NNN=C1:14
185
+ C1=CNNC1:14
186
+ C1=CN=NC=C1:14
187
+ C1=CC=NN=C1:14
188
+ C1C[C@@H]2CC[C@H](C1)[NH+]2:13
189
+ C1CNN=N1:13
190
+ C1=NCCCC1:13
191
+ C1C[Ru]1:12
192
+ C1=NNN=N1:12
193
+ [N+]N:11
194
+ C1CNSC1:11
195
+ C1CC[NH2+]NC1:11
196
+ C1CC2NCCN[C@@H](C1)O2:11
197
+ C1=NCNC1:11
198
+ C1=NC=NC1:11
199
+ C1=CCCCC=C1:11
200
+ C1=CCC=C1:11
201
+ O=[SH]:10
202
+ C1CSNCN1:10
203
+ C1=CSCCN1:10
204
+ C1=CC=CCC=C1:10
205
+ C1C[NH2+]C1:9
206
+ C1C[NH+]2CCC1CC2:9
207
+ C1CC[SH]C1:9
208
+ C1=[N+]CCN1:9
209
+ C1=CSNC1:9
210
+ C1CN[Fe]NC1:8
211
+ C1CNCC[NH+]C1:8
212
+ C1=NNCCN1:8
213
+ C1=CSN=N1:8
214
+ C1=CNCCCN1:8
215
+ C1=CNC=CN1:8
216
+ C1=CC[NH+]C1:8
217
+ C1=CCSCC1:8
218
+ C1=CCNNC1:8
219
+ C1NCON1:7
220
+ C1COCC[NH2+]1:7
221
+ C1CNSN1:7
222
+ C1CNNC1:7
223
+ C1CCC[NH+]CC1:7
224
+ C1=[N+]O[Fe]O1:7
225
+ C1=C[Ru]1:7
226
+ C1=C[N+][Co][N+]=C1:7
227
+ C1=CNN=CC1:7
228
+ [O-]P:6
229
+ O:P:6
230
+ N[NH3+]:6
231
+ N[NH2+]:6
232
+ FS:6
233
+ C[SeH]:6
234
+ C1NO[Fe]O1:6
235
+ C1C[C@@H]2CC[C@H](C1)N2:6
236
+ C1CSCCO1:6
237
+ C1CC[N+]C1:6
238
+ C1=NSN=C1:6
239
+ C1=NCCC1:6
240
+ C1=COCN1:6
241
+ C1=CNNCC1:6
242
+ C1=CNCSC1:6
243
+ C1=CNC=CNC1:6
244
+ C1=CN=NC1:6
245
+ C1=CC[N+]=C1:6
246
+ [N+]O:5
247
+ P=S:5
248
+ C1[NH+]O1:5
249
+ C1C[Rh]1:5
250
+ C1C[C@@H]2CC[C@H](C1)[N+]2:5
251
+ C1C[C@@H]2CCNC[C@H](C1)N2:5
252
+ C1CN[NH2+]C1:5
253
+ C1CC[SH]CC1:5
254
+ C1CCNSCC1:5
255
+ C1CCNNC1:5
256
+ C1=N[N+]=CS1:5
257
+ C1=NSCCN1:5
258
+ C1=NCCNCC1:5
259
+ C1=NC=NS1:5
260
+ C1=C\CCCCCC/1:5
261
+ C1=C[Fe]C1:5
262
+ C1=CSNCN1:5
263
+ C1=COCCCN1:5
264
+ C1=CN[Co][N+]=C1:5
265
+ C1=CNCCCC1:5
266
+ C1=CN=NC=N1:5
267
+ C1=CCCNCC1:5
268
+ [NH3+]O:4
269
+ [NH+][NH2+]:4
270
+ O[Si]:4
271
+ C1N[NH2+]CS1:4
272
+ C1NCN[Fe]N1:4
273
+ C1NCNN1:4
274
+ C1C[NH2+]1:4
275
+ C1C[NH+]CC[NH+]C1:4
276
+ C1C[C@H]2CC[C@@H]1C2:4
277
+ C1C[C@@H]2CC[C@H](C1)[NH2+]2:4
278
+ C1COCC[N+]1:4
279
+ C1CNOC1:4
280
+ C1CCO[NH2+]C1:4
281
+ C1CC2CCC1C2:4
282
+ C1=[N+]NCC1:4
283
+ C1=NNCNC1:4
284
+ C1=NCCSN1:4
285
+ C1=C[N+][Mg][N+]=C1:4
286
+ C1=CSNC=N1:4
287
+ C1=CN=NCC1:4
288
+ C1=CCNCCC1:4
289
+ C1=CCCOCC1:4
290
+ C1=CCCOC=C1:4
291
+ C1=CCC1:4
292
+ C1=C2C[N@@H+]3CC[C@]45CCN6CC[C@H](OC1)[C@@H]([C@H]64)[C@H]2C[C@H]35:4
293
+ B1CCC=CO1:4
294
+ O=[V]:3
295
+ O=[Sb]:3
296
+ N=S:3
297
+ C1NC[C@H]2CNC[C@@H]1C2:3
298
+ C1C[C@@H]2CC[C@H](C1)C2:3
299
+ C1COSO1:3
300
+ C1COC[NH2+]1:3
301
+ C1CN[NH+]C1:3
302
+ C1CN[Co][N+]1:3
303
+ C1CNCC[N+]1:3
304
+ C1CNCCSC1:3
305
+ C1CN=[N+]C1:3
306
+ C1CC[NH+]NC1:3
307
+ C1CC[NH+]CCNC1:3
308
+ C1CCSNC1:3
309
+ C1CCNSNC1:3
310
+ C1CC2C[NH+][C@@H]3C4CCC[C@]3(C1)[C@@H]2C4:3
311
+ C1CC2CCC1CC2:3
312
+ C1=[N+]CCS1:3
313
+ C1=[N+]CCCN1:3
314
+ C1=[N+]CCCCC1:3
315
+ C1=NSCC1:3
316
+ C1=NNCSC1:3
317
+ C1=NNCS1:3
318
+ C1=NCNN=C1:3
319
+ C1=NCCCCC1:3
320
+ C1=C[Ru]C1:3
321
+ C1=C[Rh]1:3
322
+ C1=C[N+]=CC1:3
323
+ C1=C[C@H]2COCC(C1)C2:3
324
+ C1=C[C@H]2CC[C@@H]1N2:3
325
+ C1=C[C@H]2CC[C@@H]1C2:3
326
+ C1=C[C@@H]2CC=C[C@H](C1)C2:3
327
+ C1=COCOC1:3
328
+ C1=COCCNC1:3
329
+ C1=CNSCCN1:3
330
+ C1=CNSC=C1:3
331
+ C1=CNSC1:3
332
+ C1=CNC=CCC1:3
333
+ C1=CCCCN=C1:3
334
+ C1=CCCCCCC1:3
335
+ C1=CCCC=CC1:3
336
+ C1=CC2CCO[C@@H](C1)C2:3
337
+ B1C=CCO1:3
338
+ [O-]S:2
339
+ OO:2
340
+ C1OCO1:2
341
+ C1C[NH2+]C[NH2+]C1:2
342
+ C1C[NH2+]CN1:2
343
+ C1C[NH+][NH2+]C1:2
344
+ C1C[C@H]2C[C@H](CCO2)N1:2
345
+ C1C[C@H]2COC[C@@H]1C2:2
346
+ C1C[C@@H]2C[C@H]1CN2:2
347
+ C1C[C@@H]2CNC[C@H](C1)N2:2
348
+ C1CSN=N1:2
349
+ C1CNC[NH2+]C1:2
350
+ C1CN1:2
351
+ C1CC[NH2+][NH2+]C1:2
352
+ C1CC[N+]CC1:2
353
+ C1CCSCNC1:2
354
+ C1CCCNCCC1:2
355
+ C1CC2OC[C@@H](C1)O2:2
356
+ C1CC2COC(C1)C2:2
357
+ C1CC2CC[C@H](C1)[NH2+]2:2
358
+ C1CC2CCO[Fe](O1)OC2:2
359
+ C1CC2CCC(O1)O2:2
360
+ C1CC2CCC(C1)C2:2
361
+ C1=[N+]CNN1:2
362
+ C1=N[N+]=CCC1:2
363
+ C1=NSNC1:2
364
+ C1=NNCO1:2
365
+ C1=NN=CCC1:2
366
+ C1=NN=CC1:2
367
+ C1=NCSCC1:2
368
+ C1=NCC=[N+]1:2
369
+ C1=NC=NN=C1:2
370
+ C1=C[N+][Zn]NC1:2
371
+ C1=C[N+]C=CC1:2
372
+ C1=C[C@H]2C[NH2+]C[C@@H]1C2:2
373
+ C1=C[C@H]2C[NH2+]C[C@@H](C1)[NH2+]2:2
374
+ C1=C[C@H]2CC=CC(C1)C2:2
375
+ C1=CSNCCO1:2
376
+ C1=CSC=CN1:2
377
+ C1=COSNC1:2
378
+ C1=COCCCNC1:2
379
+ C1=COCC=N1:2
380
+ C1=CN[Zn][N+]=C1:2
381
+ C1=CN[C@H]2CCC(N1)O2:2
382
+ C1=CNSNC1:2
383
+ C1=CNSCCC1:2
384
+ C1=CNNC=C1:2
385
+ C1=CNC[NH2+]C1:2
386
+ C1=CNCCOC1:2
387
+ C1=CNCCCNC1:2
388
+ C1=CN=[N+]C1:2
389
+ C1=CN=CNCC1:2
390
+ C1=CN=CCN=C1:2
391
+ C1=CN=CC1:2
392
+ C1=CC[NH2+]NC1:2
393
+ C1=CC[NH2+]C1:2
394
+ C1=CC[NH+]NC1:2
395
+ C1=CC[NH+]CCC1:2
396
+ C1=CC[N+]CC1:2
397
+ C1=CCOC=CC1:2
398
+ C1=CCN=CCC1:2
399
+ C1=CCCNC=C1:2
400
+ C1=CC2CC(C1)CO2:2
401
+ B1OCCO1:2
402
+ B1CCCO1:2
403
+ [N+][NH+]:1
404
+ O[Fe]:1
405
+ O1POPOPOP1:1
406
+ N1OO1:1
407
+ C[TeH]:1
408
+ C[Sb]:1
409
+ C[Ru]:1
410
+ C[O-]:1
411
+ C[AsH2]:1
412
+ C1O[C@@H]2COC1O2:1
413
+ C1N[C@@H]2CN[C@H]1C2:1
414
+ C1NN=NN1:1
415
+ C1NC[C@H]2C[NH2+]C[C@@H]1C2:1
416
+ C1NC[C@H]2C[NH+]C[C@@H]1C2:1
417
+ C1NC[C@H]2CC[C@H](C2)N1:1
418
+ C1NC[C@@H]2OC[C@H]1O2:1
419
+ C1NCO[NH2+]1:1
420
+ C1NCNCN1:1
421
+ C1NC2COC[C@H](C2)S1:1
422
+ C1N=[N+]CS1:1
423
+ C1N=NCO1:1
424
+ C1N=N1:1
425
+ C1C[NH2+][Pt][NH2+]1:1
426
+ C1C[NH2+]OC1:1
427
+ C1C[NH2+]C[NH+]1:1
428
+ C1C[NH2+]CSC1:1
429
+ C1C[NH2+]CC[NH2+]1:1
430
+ C1C[NH2+]CCSC1:1
431
+ C1C[NH2+]CCOC1:1
432
+ C1C[NH+][NH2+]N1:1
433
+ C1C[NH+]C[NH+]C1:1
434
+ C1C[NH+]CSC1:1
435
+ C1C[N+][Co][N+]1:1
436
+ C1C[N+]CNC1:1
437
+ C1C[C@H]2C[NH+]C[C@@H]1N2:1
438
+ C1C[C@H]2C[C@@H]1CN2:1
439
+ C1C[C@H]2CC[C@@H]1O2:1
440
+ C1C[C@H]2CCC[C@@H](C1)[NH+]2:1
441
+ C1C[C@@H]2C[C@@H](CCO2)N1:1
442
+ C1C[C@@H]2CO[C@H](C1)N2:1
443
+ C1C[C@@H]2CC[C@H](C1)O2:1
444
+ C1CSCS1:1
445
+ C1CSCN[NH2+]1:1
446
+ C1CSCC[NH+]1:1
447
+ C1CSC1:1
448
+ C1CO[V]O1:1
449
+ C1CO[SH]C1:1
450
+ C1COSN1:1
451
+ C1COPO1:1
452
+ C1COCCOC1:1
453
+ C1CN[C@H]2CCC(N1)O2:1
454
+ C1CNSCCN1:1
455
+ C1CNNCN1:1
456
+ C1CNCN[NH+]C1:1
457
+ C1CN=NC1:1
458
+ C1CCSSC1:1
459
+ C1CCONC1:1
460
+ C1CCOCOC1:1
461
+ C1CCOCCOC1:1
462
+ C1CCNNCC1:1
463
+ C1CC2CO[C@@H](C2)[NH+]1:1
464
+ C1CC2COC(C1)OO2:1
465
+ C1CC2CC[NH+][C@H](C1)C2:1
466
+ C1CC2CC[C@H](C1)C2:1
467
+ C1CC2CCC1C[NH+]2:1
468
+ C1CC2CC1CO2:1
469
+ C1C2C[C@@H]3C[C@H](C2)C1O3:1
470
+ C1C2CC1C2:1
471
+ C1=[N+]CON1:1
472
+ C1=[N+]CCOC1:1
473
+ C1=[N+]CCNC1:1
474
+ C1=NSN=CC1:1
475
+ C1=NO[N+]=C1:1
476
+ C1=NOCN1:1
477
+ C1=NNNC1:1
478
+ C1=NNC=[N+]1:1
479
+ C1=NN=[N+]C1:1
480
+ C1=NC[NH2+]CC1:1
481
+ C1=NC[NH2+]C1:1
482
+ C1=NCON1:1
483
+ C1=NCOC1:1
484
+ C1=NCNN1:1
485
+ C1=NCNCN1:1
486
+ C1=NCN=CN1:1
487
+ C1=NCC[N+]=C1:1
488
+ C1=NCCC=[N+]1:1
489
+ C1=C\CCOCCC/1:1
490
+ C1=C[Se]C=N1:1
491
+ C1=C[SH]CCCN1:1
492
+ C1=C[Ru]NC1:1
493
+ C1=C[Rh]C1:1
494
+ C1=C[N+]CC1:1
495
+ C1=C[N+]C=NC1:1
496
+ C1=C[N+]=CNC1:1
497
+ C1=C[C@H]2C[C@H](CCO2)N1:1
498
+ C1=C[C@H]2C[C@H](C2)NC1:1
499
+ C1=C[C@H]2CC[C@@H]1O2:1
500
+ C1=C[C@H]2CCC[C@@H]1[N+]2:1
501
+ C1=C[C@H]2CCCC1[NH2+]2:1
502
+ C1=C[C@@H]2C[NH2+]C[C@H](C1)[NH2+]2:1
503
+ C1=C[C@@H]2CN(C1)CN2:1
504
+ C1=C[C@@H]2CCOC(C1)O2:1
505
+ C1=C[C@@H]2CC=CC(C1)C2:1
506
+ C1=CSOCNC1:1
507
+ C1=CSOC1:1
508
+ C1=CSCO1:1
509
+ C1=CSCCCN1:1
510
+ C1=CSCCC=N1:1
511
+ C1=CSC=CC1:1
512
+ C1=CPNCN1:1
513
+ C1=CO[NH2+]C1:1
514
+ C1=COCNC1:1
515
+ C1=COCC[N+]1:1
516
+ C1=COCCCO1:1
517
+ C1=COCCCC1:1
518
+ C1=COC=CN1:1
519
+ C1=CN[NH2+]N1:1
520
+ C1=CN[NH+]C1:1
521
+ C1=CN[C@@H]2CC[C@H](N1)O2:1
522
+ C1=CNOC1:1
523
+ C1=CNCOC1:1
524
+ C1=CNCNN=C1:1
525
+ C1=CNCN=N1:1
526
+ C1=CNCCC=N1:1
527
+ C1=CNCC=[N+]1:1
528
+ C1=CNC=[N+]C1:1
529
+ C1=CN=CCSC1:1
530
+ C1=CN=CCCC1:1
531
+ C1=CC[NH2+]CCC1:1
532
+ C1=CC[N+]C=C1:1
533
+ C1=CCOCCC1:1
534
+ C1=CCNCC=C1:1
535
+ C1=CCN=NC1:1
536
+ C1=CCC[NH+]CC1:1
537
+ C1=CCCSC=C1:1
538
+ C1=CC=NCC=C1:1
539
+ C1=CC=COC=C1:1
540
+ C1=CC=CNC=C1:1
541
+ C1=CC2CC[NH+][C@@H](C1)C2:1
542
+ C1=CC2CCCC(C1)C2:1
543
+ C1=CC2CCC(C1)O2:1
544
+ C1=CC2CC=C[C@H](C2)OC1:1
545
+ C1=CC2CC(CN2)[NH2+]C1:1
546
+ C1=CC2C3C[NH2+][C@H]2CC1C3:1
547
+ C1#CCCCCCC1:1
548
+ B1OCCCO1:1
549
+ B1CCCCO1:1