saicharan2804 committed on
Commit
36173e1
1 Parent(s): af1e58a

Added manual implementation of metrics

Browse files
Files changed (5) hide show
  1. fcd.py +102 -0
  2. metrics.py +344 -0
  3. molgen_metric.py +0 -2
  4. utils.py +316 -0
  5. utils2.py +271 -0
fcd.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fcd_torch.utils import SmilesDataset, \
2
+ calculate_frechet_distance, \
3
+ todevice, \
4
+ load_imported_model
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ import os
8
+ import numpy as np
9
+ import warnings
10
+
11
+
12
class FCD:
    """
    Computes Frechet ChemNet Distance on PyTorch.
        * You can precalculate mean and sigma for further usage,
          e.g. if you use the statistics from the same dataset
          multiple times.
        * Supports GPU and selection of GPU index
        * Multithread SMILES parsing

    Example 1:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        fcd(smiles_list, smiles_list)

    Example 2:
        fcd = FCD(device='cuda:0', n_jobs=8)
        smiles_list = ['CCC', 'CCNC']
        pgen = fcd.precalc(smiles_list)
        fcd(smiles_list, pgen=pgen)
    """
    def __init__(self, device='cpu', n_jobs=1,
                 batch_size=512,
                 model_path=None,
                 canonize=True):
        """
        Loads ChemNet on device
        params:
            device: cpu for CPU, cuda:0 for GPU 0, etc.
            n_jobs: number of workers to parse SMILES
            batch_size: batch size for processing SMILES
            model_path: path to ChemNet_v0.13_pretrained.pt; when None the
                file is looked up next to this module
            canonize: forwarded to SmilesDataset
        """
        if model_path is None:
            model_dir = os.path.split(__file__)[0]
            model_path = os.path.join(model_dir, 'ChemNet_v0.13_pretrained.pt')

        self.device = device
        # A single job is mapped to 0 DataLoader workers.
        self.n_jobs = n_jobs if n_jobs != 1 else 0
        self.batch_size = batch_size
        keras_config = torch.load(model_path)
        self.model = load_imported_model(keras_config)
        self.model.eval()
        self.canonize = canonize

    def get_predictions(self, smiles_list):
        """
        Runs the model over smiles_list and returns the stacked activations
        as a numpy array; an empty input yields an array of shape (0, 512).
        """
        if len(smiles_list) == 0:
            return np.zeros((0, 512))
        dataloader = DataLoader(
            SmilesDataset(smiles_list, canonize=self.canonize),
            batch_size=self.batch_size,
            num_workers=self.n_jobs
        )
        with todevice(self.model, self.device), torch.no_grad():
            chemnet_activations = []
            for batch in dataloader:
                chemnet_activations.append(
                    self.model(
                        batch.transpose(1, 2).float().to(self.device)
                    ).to('cpu').detach().numpy()
                )
        # np.vstack is the non-deprecated spelling of np.row_stack.
        return np.vstack(chemnet_activations)

    def precalc(self, smiles_list):
        """
        Returns {'mu': mean, 'sigma': covariance} of the activations, or an
        empty dict (with a warning) when fewer than 2 molecules are given.
        """
        if len(smiles_list) < 2:
            warnings.warn("Can't compute FCD for less than 2 molecules "
                          "({} given)".format(len(smiles_list)))
            return {}
        chemnet_activations = self.get_predictions(smiles_list)
        mu = chemnet_activations.mean(0)
        sigma = np.cov(chemnet_activations.T)
        return {'mu': mu, 'sigma': sigma}

    def metric(self, pref, pgen):
        """
        Computes the Frechet distance between two precalculated statistics
        dicts. Returns np.nan (with a warning) if either dict lacks 'mu' or
        'sigma', e.g. because precalc received too few molecules.
        """
        # Both keys must be checked on pref itself; checking 'sigma' on pgen
        # here would let an incomplete reference slip through to a KeyError.
        if 'mu' not in pref or 'sigma' not in pref:
            warnings.warn("Failed to compute FCD (check ref)")
            return np.nan
        if 'mu' not in pgen or 'sigma' not in pgen:
            warnings.warn("Failed to compute FCD (check gen)")
            return np.nan
        return calculate_frechet_distance(
            pref['mu'], pref['sigma'], pgen['mu'], pgen['sigma']
        )

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Computes FCD; exactly one of ref/pref and exactly one of gen/pgen
        must be supplied (raw SMILES or precalculated statistics).
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)
metrics.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from multiprocessing import Pool
3
+ import numpy as np
4
+ from scipy.spatial.distance import cosine as cos_distance
5
+ from fcd_torch import FCD as FCDMetric
6
+ from scipy.stats import wasserstein_distance
7
+
8
+ from moses.dataset import get_dataset, get_statistics
9
+ from moses.utils import mapper
10
+ from moses.utils import disable_rdkit_log, enable_rdkit_log
11
+ from .utils import compute_fragments, average_agg_tanimoto, \
12
+ compute_scaffolds, fingerprints, \
13
+ get_mol, canonic_smiles, mol_passes_filters, \
14
+ logP, QED, SA, weight
15
+
16
+
17
def get_all_metrics(gen, k=None, n_jobs=1,
                    device='cpu', batch_size=512, pool=None,
                    test=None, test_scaffolds=None,
                    ptest=None, ptest_scaffolds=None,
                    train=None):
    """
    Computes all available metrics between test (scaffold test)
    and generated sets of SMILES.
    Parameters:
        gen: list of generated SMILES
        k: int or list with values for unique@k. Will calculate number of
            unique molecules in the first k molecules. Default [1000, 10000]
        n_jobs: number of workers for parallel processing
        device: 'cpu' or 'cuda:n', where n is GPU device number
        batch_size: batch size for FCD metric
        pool: optional multiprocessing pool to use for parallelization

        test (None or list): test SMILES. If None, will load
            a default test set
        test_scaffolds (None or list): scaffold test SMILES. If None, will
            load a default scaffold test set
        ptest (None or dict): precalculated statistics of the test set. If
            None, will load default test statistics. If you specified a custom
            test set, default test statistics will be ignored
        ptest_scaffolds (None or dict): precalculated statistics of the
            scaffold test set If None, will load default scaffold test
            statistics. If you specified a custom test set, default test
            statistics will be ignored
        train (None or list): train SMILES. If None (or empty), will load a
            default train set
    Returns:
        dict mapping metric name to value
    Available metrics:
        * %valid
        * %unique@k
        * Frechet ChemNet Distance (FCD)
        * Fragment similarity (Frag)
        * Scaffold similarity (Scaf)
        * Similarity to nearest neighbour (SNN)
        * Internal diversity (IntDiv)
        * Internal diversity 2: using square root of mean squared
            Tanimoto similarity (IntDiv2)
        * %passes filters (Filters)
        * Distribution difference for logP, SA, QED, weight
        * Novelty (molecules not present in train)
    """
    if test is None:
        if ptest is not None:
            raise ValueError(
                "You cannot specify custom test "
                "statistics for default test set")
        test = get_dataset('test')
        ptest = get_statistics('test')

    if test_scaffolds is None:
        if ptest_scaffolds is not None:
            raise ValueError(
                "You cannot specify custom scaffold test "
                "statistics for default scaffold test set")
        test_scaffolds = get_dataset('test_scaffolds')
        ptest_scaffolds = get_statistics('test_scaffolds')

    train = train or get_dataset('train')

    if k is None:
        k = [1000, 10000]
    if not isinstance(k, (list, tuple)):
        k = [k]

    metrics = {}
    close_pool = False
    if pool is None:
        if n_jobs != 1:
            pool = Pool(n_jobs)
            close_pool = True
        else:
            pool = 1

    disable_rdkit_log()
    # try/finally guarantees that RDKit logging is re-enabled and a pool
    # created here is released even if one of the metrics raises.
    try:
        metrics['valid'] = fraction_valid(gen, n_jobs=pool)
        gen = remove_invalid(gen, canonize=True, n_jobs=pool)
        for _k in k:
            metrics['unique@{}'.format(_k)] = fraction_unique(gen, _k, pool)

        if ptest is None:
            ptest = compute_intermediate_statistics(test, n_jobs=n_jobs,
                                                    device=device,
                                                    batch_size=batch_size,
                                                    pool=pool)
        if test_scaffolds is not None and ptest_scaffolds is None:
            ptest_scaffolds = compute_intermediate_statistics(
                test_scaffolds, n_jobs=n_jobs,
                device=device, batch_size=batch_size,
                pool=pool
            )
        mols = mapper(pool)(get_mol, gen)
        # FCD manages its own workers and therefore receives the integer
        # n_jobs; the other metrics share the pool.
        kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
        kwargs_fcd = {'n_jobs': n_jobs, 'device': device,
                      'batch_size': batch_size}
        metrics['FCD/Test'] = FCDMetric(**kwargs_fcd)(
            gen=gen, pref=ptest['FCD'])
        metrics['SNN/Test'] = SNNMetric(**kwargs)(
            gen=mols, pref=ptest['SNN'])
        metrics['Frag/Test'] = FragMetric(**kwargs)(
            gen=mols, pref=ptest['Frag'])
        metrics['Scaf/Test'] = ScafMetric(**kwargs)(
            gen=mols, pref=ptest['Scaf'])
        if ptest_scaffolds is not None:
            metrics['FCD/TestSF'] = FCDMetric(**kwargs_fcd)(
                gen=gen, pref=ptest_scaffolds['FCD']
            )
            metrics['SNN/TestSF'] = SNNMetric(**kwargs)(
                gen=mols, pref=ptest_scaffolds['SNN']
            )
            metrics['Frag/TestSF'] = FragMetric(**kwargs)(
                gen=mols, pref=ptest_scaffolds['Frag']
            )
            metrics['Scaf/TestSF'] = ScafMetric(**kwargs)(
                gen=mols, pref=ptest_scaffolds['Scaf']
            )

        metrics['IntDiv'] = internal_diversity(mols, pool, device=device)
        metrics['IntDiv2'] = internal_diversity(mols, pool, device=device,
                                                p=2)
        metrics['Filters'] = fraction_passes_filters(mols, pool)

        # Properties
        for name, func in [('logP', logP), ('SA', SA),
                           ('QED', QED),
                           ('weight', weight)]:
            metrics[name] = WassersteinMetric(func, **kwargs)(
                gen=mols, pref=ptest[name])

        if train is not None:
            metrics['Novelty'] = novelty(mols, train, pool)
    finally:
        enable_rdkit_log()
        if close_pool:
            pool.close()
            pool.join()
    return metrics
147
+
148
+
149
def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu',
                                    batch_size=512, pool=None):
    """
    The function precomputes statistics such as mean and variance for FCD, etc.
    It is useful to compute the statistics for test and scaffold test sets to
    speedup metrics calculation.

    Parameters:
        smiles: list of SMILES
        n_jobs: number of workers; also used to create a pool when none is
            given and n_jobs != 1
        device: device forwarded to the metric objects
        batch_size: batch size forwarded to the metric objects
        pool: optional pool (or 1) to reuse; it is not terminated here
    Returns:
        dict with keys 'FCD', 'SNN', 'Frag', 'Scaf', 'logP', 'SA', 'QED',
        'weight' holding the respective precalc() results
    """
    close_pool = False
    if pool is None:
        if n_jobs != 1:
            pool = Pool(n_jobs)
            close_pool = True
        else:
            pool = 1
    statistics = {}
    # A pool created here is always terminated, even if a precalc raises.
    try:
        mols = mapper(pool)(get_mol, smiles)
        kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
        kwargs_fcd = {'n_jobs': n_jobs, 'device': device,
                      'batch_size': batch_size}
        statistics['FCD'] = FCDMetric(**kwargs_fcd).precalc(smiles)
        statistics['SNN'] = SNNMetric(**kwargs).precalc(mols)
        statistics['Frag'] = FragMetric(**kwargs).precalc(mols)
        statistics['Scaf'] = ScafMetric(**kwargs).precalc(mols)
        for name, func in [('logP', logP), ('SA', SA),
                           ('QED', QED),
                           ('weight', weight)]:
            statistics[name] = WassersteinMetric(func, **kwargs).precalc(mols)
    finally:
        if close_pool:
            pool.terminate()
    return statistics
178
+
179
+
180
def fraction_passes_filters(gen, n_jobs=1):
    """
    Returns the mean of mol_passes_filters over gen, i.e. the share of
    molecules accepted by the filters:
        * MCF
        * PAINS
        * Only allowed atoms ('C','N','S','O','F','Cl','Br','H')
        * No charges
    """
    run = mapper(n_jobs)
    verdicts = run(mol_passes_filters, gen)
    return np.mean(verdicts)
190
+
191
+
192
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
                       gen_fps=None, p=1):
    """
    Computes internal diversity as:
    1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))

    Fingerprints are computed from gen unless gen_fps is supplied.
    """
    fps = gen_fps
    if fps is None:
        fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
    similarity = average_agg_tanimoto(fps, fps,
                                      agg='mean', device=device, p=p)
    return 1 - similarity.mean()
202
+
203
+
204
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Computes the fraction of unique molecules
    Parameters:
        gen: list of SMILES
        k: compute unique@k
        n_jobs: number of threads for calculation
        check_validity: raises ValueError if invalid molecules are present
    Returns:
        number of distinct canonical SMILES divided by the number of
        molecules considered; np.nan (with a warning) if there are none
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn(
                "Can't compute unique@{}. ".format(k) +
                "gen contains only {} molecules".format(len(gen))
            )
        gen = gen[:k]
    # Guard against division by zero, e.g. when every generated molecule
    # was invalid and got filtered out upstream.
    if len(gen) == 0:
        warnings.warn("Can't compute uniqueness of an empty set")
        return np.nan
    canonic = set(mapper(n_jobs)(canonic_smiles, gen))
    if None in canonic and check_validity:
        raise ValueError("Invalid molecule passed to unique@k")
    return len(canonic) / len(gen)
224
+
225
+
226
def fraction_valid(gen, n_jobs=1):
    """
    Computes the fraction of valid molecules
    Parameters:
        gen: list of SMILES
        n_jobs: number of threads for calculation
    Returns:
        share of entries for which get_mol does not return None;
        np.nan (with a warning) for an empty input
    """
    # Guard against division by zero on an empty input.
    if len(gen) == 0:
        warnings.warn("Can't compute validity of an empty set")
        return np.nan
    gen = mapper(n_jobs)(get_mol, gen)
    return 1 - gen.count(None) / len(gen)
235
+
236
+
237
def novelty(gen, train, n_jobs=1):
    """
    Computes the fraction of distinct generated molecules (by canonical
    SMILES) that do not occur in train.
    Parameters:
        gen: generated molecules or SMILES
        train: iterable of train SMILES, compared as given
        n_jobs: number of threads for calculation
    Returns:
        novelty fraction; np.nan (with a warning) if gen contains no
        valid molecule
    """
    gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
    gen_smiles_set = set(gen_smiles) - {None}
    # Guard against division by zero when nothing valid was generated.
    if not gen_smiles_set:
        warnings.warn("Can't compute novelty of an empty set")
        return np.nan
    train_set = set(train)
    return len(gen_smiles_set - train_set) / len(gen_smiles_set)
242
+
243
+
244
def remove_invalid(gen, canonize=True, n_jobs=1):
    """
    Drops invalid molecules from gen.

    With canonize=True the surviving entries are returned as canonical
    SMILES; otherwise the original strings are kept.
    """
    run = mapper(n_jobs)
    if canonize:
        canonical = run(canonic_smiles, gen)
        return [smi for smi in canonical if smi is not None]
    parsed = run(get_mol, gen)
    return [smi for smi, mol in zip(gen, parsed) if mol is not None]
253
+
254
+
255
class Metric:
    """
    Base class for metrics comparing a reference set with a generated set.

    Subclasses implement precalc (raw data -> intermediate statistics) and
    metric (two intermediate statistics -> value).
    """

    def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
        """
        Stores the common settings; any extra keyword argument is set as an
        attribute of the same name.
        """
        self.n_jobs = n_jobs
        self.device = device
        self.batch_size = batch_size
        # items(), not values(): each extra option is a (name, value) pair.
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Computes the metric; exactly one of ref/pref and exactly one of
        gen/pgen must be supplied (raw data or precalculated statistics).
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        if pref is None:
            pref = self.precalc(ref)
        if pgen is None:
            pgen = self.precalc(gen)
        return self.metric(pref, pgen)

    def precalc(self, moleclues):
        """Computes intermediate statistics; must be overridden."""
        raise NotImplementedError

    def metric(self, pref, pgen):
        """Computes the metric from two statistics; must be overridden."""
        raise NotImplementedError
277
+
278
+
279
class SNNMetric(Metric):
    """
    Similarity to nearest neighbour: averages, over gen, the aggregated
    Tanimoto similarity to ref as computed by average_agg_tanimoto.
    """

    def __init__(self, fp_type='morgan', **kwargs):
        """Remembers the fingerprint type and forwards the rest to Metric."""
        self.fp_type = fp_type
        super().__init__(**kwargs)

    def precalc(self, mols):
        """Returns {'fps': fingerprints of mols}."""
        fps = fingerprints(mols, n_jobs=self.n_jobs, fp_type=self.fp_type)
        return {'fps': fps}

    def metric(self, pref, pgen):
        """Compares the precalculated reference and generated fingerprints."""
        ref_fps = pref['fps']
        gen_fps = pgen['fps']
        return average_agg_tanimoto(ref_fps, gen_fps, device=self.device)
295
+
296
+
297
def cos_similarity(ref_counts, gen_counts):
    """
    Cosine similarity of two {name: count} dictionaries; a name missing
    from one dictionary counts as zero there:

     sim = <r, g> / ||r|| / ||g||

    Returns np.nan if either dictionary is empty.
    """
    if not len(ref_counts) or not len(gen_counts):
        return np.nan
    names = sorted(set(ref_counts.keys()) | set(gen_counts.keys()))
    ref_vec = np.array([ref_counts.get(name, 0) for name in names])
    gen_vec = np.array([gen_counts.get(name, 0) for name in names])
    distance = cos_distance(ref_vec, gen_vec)
    return 1 - distance
311
+
312
+
313
class FragMetric(Metric):
    """Cosine similarity between fragment counts of two molecule sets."""

    def precalc(self, mols):
        """Returns {'frag': fragment counts of mols}."""
        counts = compute_fragments(mols, n_jobs=self.n_jobs)
        return {'frag': counts}

    def metric(self, pref, pgen):
        """Compares the precalculated fragment counts."""
        ref_counts, gen_counts = pref['frag'], pgen['frag']
        return cos_similarity(ref_counts, gen_counts)
319
+
320
+
321
class ScafMetric(Metric):
    """Cosine similarity between scaffold counts of two molecule sets."""

    def precalc(self, mols):
        """Returns {'scaf': scaffold counts of mols}."""
        counts = compute_scaffolds(mols, n_jobs=self.n_jobs)
        return {'scaf': counts}

    def metric(self, pref, pgen):
        """Compares the precalculated scaffold counts."""
        ref_counts, gen_counts = pref['scaf'], pgen['scaf']
        return cos_similarity(ref_counts, gen_counts)
327
+
328
+
329
class WassersteinMetric(Metric):
    """
    Wasserstein distance between the distributions of a per-molecule
    property in two sets.
    """

    def __init__(self, func=None, **kwargs):
        """
        func maps a molecule to a number; with func=None the inputs to
        precalc are used as the values directly.
        """
        self.func = func
        super().__init__(**kwargs)

    def precalc(self, mols):
        """Returns {'values': property values of mols}."""
        if self.func is None:
            return {'values': mols}
        return {'values': mapper(self.n_jobs)(self.func, mols)}

    def metric(self, pref, pgen):
        """Compares the precalculated value lists."""
        ref_values = pref['values']
        gen_values = pgen['values']
        return wasserstein_distance(ref_values, gen_values)
molgen_metric.py CHANGED
@@ -94,8 +94,6 @@ class molgen_metric(evaluate.Measurement):
94
 
95
  def _compute(self, generated_smiles, train_smiles = None):
96
 
97
-
98
-
99
  Results = metrics.get_all_metrics(gen = generated_smiles, train= train_smiles)
100
 
101
  generated_smiles = [s for s in generated_smiles if s != '']
 
94
 
95
  def _compute(self, generated_smiles, train_smiles = None):
96
 
 
 
97
  Results = metrics.get_all_metrics(gen = generated_smiles, train= train_smiles)
98
 
99
  generated_smiles = [s for s in generated_smiles if s != '']
utils.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from multiprocessing import Pool
3
+ from collections import UserList, defaultdict
4
+ import numpy as np
5
+ import pandas as pd
6
+ from matplotlib import pyplot as plt
7
+ import torch
8
+ from rdkit import rdBase
9
+ from rdkit import Chem
10
+
11
+
12
+ # https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
13
def set_torch_seed_to_all_gens(_):
    """
    Seeds Python's random and numpy.random from torch's initial seed
    (reduced modulo 2**32 - 1). The argument is ignored.
    """
    derived_seed = torch.initial_seed() % (2**32 - 1)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(derived_seed)
17
+
18
+
19
class SpecialTokens:
    """String forms of the special vocabulary symbols."""
    bos = '<bos>'
    eos = '<eos>'
    pad = '<pad>'
    unk = '<unk>'


class CharVocab:
    """
    Character-level vocabulary: maps characters to integer ids and back,
    with the special tokens appended after the sorted characters.
    """

    @classmethod
    def from_data(cls, data, *args, **kwargs):
        """Builds a vocabulary from every character found in data."""
        chars = set()
        for string in data:
            chars.update(string)

        return cls(chars, *args, **kwargs)

    def __init__(self, chars, ss=SpecialTokens):
        """
        Arguments:
            chars: iterable of characters
            ss: holder of the bos/eos/pad/unk token strings
        Raises:
            ValueError: if any special token is already in chars
        """
        if (ss.bos in chars) or (ss.eos in chars) or \
                (ss.pad in chars) or (ss.unk in chars):
            raise ValueError('SpecialTokens in chars')

        all_syms = sorted(list(chars)) + [ss.bos, ss.eos, ss.pad, ss.unk]

        self.ss = ss
        self.c2i = {c: i for i, c in enumerate(all_syms)}
        self.i2c = {i: c for i, c in enumerate(all_syms)}

    def __len__(self):
        return len(self.c2i)

    @property
    def bos(self):
        return self.c2i[self.ss.bos]

    @property
    def eos(self):
        return self.c2i[self.ss.eos]

    @property
    def pad(self):
        return self.c2i[self.ss.pad]

    @property
    def unk(self):
        return self.c2i[self.ss.unk]

    def char2id(self, char):
        """Returns the id of char, or the unk id for unknown characters."""
        if char not in self.c2i:
            return self.unk

        return self.c2i[char]

    def id2char(self, id):
        """Returns the symbol for id, or the unk token for unknown ids."""
        if id not in self.i2c:
            return self.ss.unk

        return self.i2c[id]

    def string2ids(self, string, add_bos=False, add_eos=False):
        """Encodes string as a list of ids, optionally wrapped in bos/eos."""
        ids = [self.char2id(c) for c in string]

        if add_bos:
            ids = [self.bos] + ids
        if add_eos:
            ids = ids + [self.eos]

        return ids

    def ids2string(self, ids, rem_bos=True, rem_eos=True):
        """Decodes ids into a string, optionally stripping bos/eos."""
        if len(ids) == 0:
            return ''
        if rem_bos and ids[0] == self.bos:
            ids = ids[1:]
        # Stripping bos may leave nothing (e.g. ids == [bos]); check the
        # length again before looking at the last element.
        if rem_eos and len(ids) > 0 and ids[-1] == self.eos:
            ids = ids[:-1]

        string = ''.join([self.id2char(i) for i in ids])

        return string
98
+
99
+
100
class OneHotVocab(CharVocab):
    """CharVocab that additionally stores an identity matrix in vectors."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        vocab_size = len(self.c2i)
        self.vectors = torch.eye(vocab_size)
104
+
105
+
106
def mapper(n_jobs):
    '''
    Returns function for map call.
    If n_jobs == 1, will use standard map
    If n_jobs > 1, will use multiprocessing pool
    If n_jobs is a pool object, will return its map function
    '''
    if n_jobs == 1:
        def _mapper(*args, **kwargs):
            return list(map(*args, **kwargs))

        return _mapper
    if isinstance(n_jobs, int):
        def _mapper(*args, **kwargs):
            # The pool is created per call and released by the context
            # manager, so nothing leaks if the mapper is never invoked and
            # the returned function can be called more than once.
            with Pool(n_jobs) as pool:
                return pool.map(*args, **kwargs)

        return _mapper
    return n_jobs.map
130
+
131
+
132
class Logger(UserList):
    """
    List of per-step dicts that also keeps, for every key, the list of its
    values across steps (in sdata).
    """

    def __init__(self, data=None):
        super().__init__()
        self.sdata = defaultdict(list)
        for step in (data or []):
            self.append(step)

    def __getitem__(self, key):
        """
        int   -> the step dict at that position
        slice -> Logger over the selected steps
        other -> values recorded under that key; wrapped in a Logger when
                 the values are dicts themselves

        Raises:
            KeyError: if nothing was recorded under key
        """
        if isinstance(key, int):
            return self.data[key]
        if isinstance(key, slice):
            return Logger(self.data[key])
        # Check membership first: indexing the defaultdict would insert an
        # empty list for an unknown key and then fail on ldata[0].
        if key not in self.sdata:
            raise KeyError(key)
        ldata = self.sdata[key]
        if isinstance(ldata[0], dict):
            return Logger(ldata)
        return ldata

    def append(self, step_dict):
        """Appends a step and records each of its values under its key."""
        super().append(step_dict)
        for k, v in step_dict.items():
            self.sdata[k].append(v)

    def save(self, path):
        """Writes the steps to path as CSV without an index column."""
        df = pd.DataFrame(list(self))
        df.to_csv(path, index=None)
157
+
158
+
159
class LogPlotter:
    """Plots series stored in a Logger-like object."""

    def __init__(self, log):
        self.log = log

    def line(self, ax, name):
        """
        Plots the series called name on ax. If the first step holds a dict
        under name, one labelled line per sub-key is drawn.
        """
        if isinstance(self.log[0][name], dict):
            for k in self.log[0][name]:
                ax.plot(self.log[name][k], label=k)
            ax.legend()
        else:
            ax.plot(self.log[name])

        ax.set_ylabel('value')
        ax.set_xlabel('epoch')
        ax.set_title(name)

    def grid(self, names, size=7):
        """
        Draws one plot per name on a two-column grid; does nothing for an
        empty list of names.
        """
        if len(names) == 0:
            return
        # Round up so that an odd number of names still gets a full grid
        # (floor division dropped the last plot and failed for one name).
        n_rows = (len(names) + 1) // 2
        _, axs = plt.subplots(nrows=n_rows, ncols=2,
                              figsize=(size * 2, size * n_rows),
                              squeeze=False)
        flat_axes = axs.flatten()

        for ax, name in zip(flat_axes, names):
            self.line(ax, name)
        # Hide the unused cell when the number of names is odd.
        for ax in flat_axes[len(names):]:
            ax.set_visible(False)
181
+
182
+
183
class CircularBuffer:
    """Fixed-capacity numeric buffer that overwrites its oldest slot."""

    def __init__(self, size):
        self.max_size = size
        self.data = np.zeros(self.max_size)
        self.size = 0
        self.pointer = -1

    def add(self, element):
        """Stores element in the next slot (wrapping around), returns it."""
        capacity = self.max_size
        if self.size + 1 < capacity:
            self.size = self.size + 1
        else:
            self.size = capacity
        slot = (self.pointer + 1) % capacity
        self.pointer = slot
        self.data[slot] = element
        return element

    def last(self):
        """Returns the most recently added element."""
        assert self.pointer != -1, "Can't get an element from an empty buffer!"
        return self.data[self.pointer]

    def mean(self):
        """Returns the mean of the stored elements, 0.0 when empty."""
        if self.size <= 0:
            return 0.0
        filled = self.data[:self.size]
        return filled.mean()
204
+
205
+
206
def disable_rdkit_log():
    """Disables every RDKit log matching 'rdApp.*'."""
    log_spec = 'rdApp.*'
    rdBase.DisableLog(log_spec)
208
+
209
+
210
def enable_rdkit_log():
    """Enables every RDKit log matching 'rdApp.*'."""
    log_spec = 'rdApp.*'
    rdBase.EnableLog(log_spec)
212
+
213
+
214
def get_mol(smiles_or_mol):
    '''
    Converts a SMILES string into an RDKit molecule.

    Non-string inputs are returned unchanged. Returns None for an empty
    string, for a string RDKit cannot parse, or if sanitization raises
    ValueError.
    '''
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if len(smiles_or_mol) == 0:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
230
+
231
+
232
class StringDataset:
    """Dataset of SMILES strings tokenized with a vocabulary."""

    def __init__(self, vocab, data):
        """
        Tokenizes every string up front.

        Arguments:
            vocab: CharVocab instance for tokenization
            data (list): SMILES strings for the dataset
        """
        self.vocab = vocab
        self.tokens = [vocab.string2ids(smiles) for smiles in data]
        self.data = data
        self.bos = vocab.bos
        self.eos = vocab.eos

    def __len__(self):
        """
        Number of tokenized strings in the dataset
        """
        return len(self.tokens)

    def __getitem__(self, index):
        """
        Builds the tensors for one SMILES.

        Arguments:
            index (int): index of SMILES in the original dataset

        Returns:
            A tuple (with_bos, with_eos, smiles), where
            * with_bos is a torch.long tensor of the tokens preceded by
                the BOS (beginning of a sentence) token
            * with_eos is a torch.long tensor of the tokens followed by
                the EOS (end of a sentence) token
            * smiles is the original SMILES from the dataset
        """
        ids = self.tokens[index]
        prefixed = torch.tensor([self.bos] + ids, dtype=torch.long)
        suffixed = torch.tensor(ids + [self.eos], dtype=torch.long)
        return prefixed, suffixed, self.data[index]

    def default_collate(self, batch, return_data=False):
        """
        Collate function: joins a batch of items from StringDataset into
        padded tensors, using vocab.pad as the padding value.

        Arguments:
            batch: list of objects from StringDataset
            return_data: if True, will return SMILES used in a batch

        Returns:
            with_bos, with_eos, lengths [, data] where
            * with_bos: padded sequence with BOS in the beginning
            * with_eos: padded sequence with EOS in the end
            * lengths: list of sequence lengths in the batch
            * data: SMILES in the batch

        Note: output batch is sorted with respect to SMILES lengths in
            decreasing order, since this is a default format for torch
            RNN implementations
        """
        bos_seqs, eos_seqs, data = list(zip(*batch))
        sizes = [len(seq) for seq in bos_seqs]
        order = np.argsort(sizes)[::-1]
        pad_id = self.vocab.pad
        with_bos = torch.nn.utils.rnn.pad_sequence(
            [bos_seqs[idx] for idx in order], padding_value=pad_id
        )
        with_eos = torch.nn.utils.rnn.pad_sequence(
            [eos_seqs[idx] for idx in order], padding_value=pad_id
        )
        lengths = [sizes[idx] for idx in order]
        if not return_data:
            return with_bos, with_eos, lengths
        data = np.array(data)[order]
        return with_bos, with_eos, lengths, data
310
+
311
+
312
def batch_to_device(batch, device):
    """
    Returns a list with every torch.Tensor of batch moved to device;
    other items are passed through unchanged.
    """
    moved = []
    for item in batch:
        if isinstance(item, torch.Tensor):
            item = item.to(device)
        moved.append(item)
    return moved
utils2.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import Counter
3
+ from functools import partial
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scipy.sparse
7
+ import torch
8
+ from rdkit import Chem
9
+ from rdkit.Chem import AllChem
10
+ from rdkit.Chem import MACCSkeys
11
+ from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
12
+ from rdkit.Chem.QED import qed
13
+ from rdkit.Chem.Scaffolds import MurckoScaffold
14
+ from rdkit.Chem import Descriptors
15
+ from moses.metrics.SA_Score import sascorer
16
+ from moses.metrics.NP_Score import npscorer
17
+ from moses.utils import mapper, get_mol
18
+
19
# Substructure filters (MCF and PAINS SMARTS) loaded from CSV files that
# live next to this module.
_base_dir = os.path.split(__file__)[0]
_mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv'))
_pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'),
                     names=['smarts', 'names'])
# DataFrame.append was removed in pandas 2.0; pd.concat stacks the two
# tables the same way.
_filters = [Chem.MolFromSmarts(x) for x in
            pd.concat([_mcf, _pains], sort=True)['smarts'].values]
25
+
26
+
27
def canonic_smiles(smiles_or_mol):
    """
    Returns Chem.MolToSmiles of the parsed input, or None if get_mol
    could not produce a molecule.
    """
    parsed = get_mol(smiles_or_mol)
    return None if parsed is None else Chem.MolToSmiles(parsed)
32
+
33
+
34
def logP(mol):
    """
    Returns RDKit's Crippen logP estimate for mol
    """
    value = Chem.Crippen.MolLogP(mol)
    return value
39
+
40
+
41
def SA(mol):
    """
    Returns the Synthetic Accessibility score of mol as computed by
    sascorer.calculateScore
    """
    score = sascorer.calculateScore(mol)
    return score
46
+
47
+
48
def NP(mol):
    """
    Returns the Natural Product-likeness score of mol as computed by
    npscorer.scoreMol
    """
    score = npscorer.scoreMol(mol)
    return score
53
+
54
+
55
def QED(mol):
    """
    Returns RDKit's QED score of mol
    """
    score = qed(mol)
    return score
60
+
61
+
62
def weight(mol):
    """
    Returns the molecular weight of mol as given by Descriptors.MolWt.
    """
    mol_weight = Descriptors.MolWt(mol)
    return mol_weight
68
+
69
+
70
def get_n_rings(mol):
    """
    Returns the ring count reported by the molecule's ring info
    """
    ring_info = mol.GetRingInfo()
    return ring_info.NumRings()
75
+
76
+
77
def fragmenter(mol):
    """
    Splits mol on its BRICS bonds and returns the fragments as a list of
    SMILES strings
    """
    parsed = get_mol(mol)
    broken = AllChem.FragmentOnBRICSBonds(parsed)
    joined_smiles = Chem.MolToSmiles(broken)
    return joined_smiles.split(".")
84
+
85
+
86
def compute_fragments(mol_list, n_jobs=1):
    """
    Fragments every molecule of mol_list with fragmenter and returns a
    Counter of fragment SMILES over the whole list
    """
    counts = Counter()
    per_molecule = mapper(n_jobs)(fragmenter, mol_list)
    for pieces in per_molecule:
        counts.update(pieces)
    return counts
94
+
95
+
96
def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
    """
    Returns a Counter of scaffold SMILES over mol_list; molecules for
    which compute_scaffold yields None are not counted
    """
    extract = partial(compute_scaffold, min_rings=min_rings)
    counts = Counter(mapper(n_jobs)(extract, mol_list))
    counts.pop(None, None)
    return counts
107
+
108
+
109
def compute_scaffold(mol, min_rings=2):
    """
    Extracts the Murcko scaffold of a molecule as SMILES.

    Parameters:
        mol: SMILES string or RDKit molecule
        min_rings: minimal number of rings the scaffold must contain
    Returns:
        scaffold SMILES, or None if the molecule cannot be parsed, the
        scaffold cannot be extracted, the scaffold is empty, or it has
        fewer than min_rings rings
    """
    mol = get_mol(mol)
    # get_mol returns None for unparsable input; report "no scaffold"
    # instead of handing None to RDKit.
    if mol is None:
        return None
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    except (ValueError, RuntimeError):
        return None
    n_rings = get_n_rings(scaffold)
    scaffold_smiles = Chem.MolToSmiles(scaffold)
    if scaffold_smiles == '' or n_rings < min_rings:
        return None
    return scaffold_smiles
120
+
121
+
122
def average_agg_tanimoto(stock_vecs, gen_vecs,
                         batch_size=5000, agg='max',
                         device='cpu', p=1):
    """
    For every vector in gen_vecs aggregates (max or mean) its Tanimoto
    similarity to the vectors in stock_vecs, then averages the aggregated
    values over gen_vecs.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: number of rows of each array processed at once
        agg: max or mean
        device: torch device used for the pairwise computation
        p: power for averaging: (mean x^p)^(1/p)
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_gen = len(gen_vecs)
    scores = np.zeros(n_gen)
    counts = np.zeros(n_gen)
    for stock_start in range(0, stock_vecs.shape[0], batch_size):
        stock_batch = torch.tensor(
            stock_vecs[stock_start:stock_start + batch_size]
        ).to(device).float()
        stock_bits = stock_batch.sum(1, keepdim=True)
        for gen_start in range(0, gen_vecs.shape[0], batch_size):
            gen_batch = torch.tensor(
                gen_vecs[gen_start:gen_start + batch_size]
            ).to(device).float()
            gen_batch = gen_batch.transpose(0, 1)
            common = torch.mm(stock_batch, gen_batch)
            union = stock_bits + gen_batch.sum(0, keepdim=True) - common
            sim = (common / union).cpu().numpy()
            # 0/0 (two all-zero vectors) is treated as similarity 1
            sim[np.isnan(sim)] = 1
            if p != 1:
                sim = sim**p
            cols = slice(gen_start, gen_start + gen_batch.shape[1])
            if agg == 'max':
                scores[cols] = np.maximum(scores[cols], sim.max(0))
            elif agg == 'mean':
                scores[cols] += sim.sum(0)
                counts[cols] += sim.shape[0]
    if agg == 'mean':
        scores /= counts
    if p != 1:
        scores = (scores)**(1/p)
    return np.mean(scores)
160
+
161
+
162
def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
                morgan__n=1024, *args, **kwargs):
    """
    Computes a fingerprint for one molecule.
    Returns None if get_mol cannot produce a molecule, otherwise a numpy
    array of fingerprint bits.

    Parameters:
        smiles_or_mol: SMILES string or RDKit molecule
        fp_type: type of fingerprint: [MACCS|morgan] (case-insensitive)
        dtype: if not None, specifies the dtype of returned array
        morgan__r: radius passed to the Morgan fingerprint
        morgan__n: number of bits of the Morgan fingerprint
    Raises:
        ValueError: for an unknown fp_type
    """
    kind = fp_type.lower()
    mol = get_mol(smiles_or_mol, *args, **kwargs)
    if mol is None:
        return None
    if kind == 'maccs':
        maccs = MACCSkeys.GenMACCSKeys(mol)
        on_bits = np.array(maccs.GetOnBits())
        bits = np.zeros(166, dtype='uint8')
        if len(on_bits) != 0:
            bits[on_bits - 1] = 1  # We drop 0-th key that is always zero
    elif kind == 'morgan':
        bits = np.asarray(Morgan(mol, morgan__r, nBits=morgan__n),
                          dtype='uint8')
    else:
        raise ValueError("Unknown fingerprint type {}".format(kind))
    if dtype is None:
        return bits
    return bits.astype(dtype)
192
+
193
+
194
def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
                 **kwargs):
    '''
    Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
    e.g.fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)
    Inserts np.nan to rows corresponding to incorrect smiles.
    IMPORTANT: if there is at least one np.nan, the dtype would be float
    Parameters:
        smiles_mols_array: list/array/pd.Series of smiles or already computed
            RDKit molecules
        n_jobs: number of parallel workers to execute
        already_unique: flag for performance reasons, if smiles array is big
            and already unique. Its value is set to True if smiles_mols_array
            contain RDKit molecules already.
        *args, **kwargs: forwarded to fingerprint
    Returns:
        stacked fingerprints, one row per input element
    '''
    if isinstance(smiles_mols_array, pd.Series):
        smiles_mols_array = smiles_mols_array.values
    else:
        smiles_mols_array = np.asarray(smiles_mols_array)
    if not isinstance(smiles_mols_array[0], str):
        already_unique = True

    if not already_unique:
        # Fingerprint each distinct SMILES once; inv_index restores the
        # original order (and duplicates) at the end.
        smiles_mols_array, inv_index = np.unique(smiles_mols_array,
                                                 return_inverse=True)

    fps = mapper(n_jobs)(
        partial(fingerprint, *args, **kwargs), smiles_mols_array
    )

    # Take the row length from the first valid fingerprint. first_fp stays
    # None when every input was invalid, so the sparse check below is safe.
    length = 1
    first_fp = None
    for fp in fps:
        if fp is not None:
            length = fp.shape[-1]
            first_fp = fp
            break
    # np.nan instead of np.NaN: the latter alias was removed in NumPy 2.0.
    fps = [fp if fp is not None else np.array([np.nan]).repeat(length)[None, :]
           for fp in fps]
    if scipy.sparse.issparse(first_fp):
        fps = scipy.sparse.vstack(fps).tocsr()
    else:
        fps = np.vstack(fps)
    if not already_unique:
        return fps[inv_index]
    return fps
239
+
240
+
241
def mol_passes_filters(mol,
                       allowed=None,
                       isomericSmiles=False):
    """
    Returns True only if mol
        * can be parsed,
        * has no ring with 8 or more atoms,
        * is not charged,
        * has only allowed atoms,
        * matches none of the MCF and PAINS filters,
        * yields a non-empty SMILES that RDKit can parse again
    """
    allowed = allowed or {'C', 'N', 'S', 'O', 'F', 'Cl', 'Br', 'H'}
    mol = get_mol(mol)
    if mol is None:
        return False
    rings = mol.GetRingInfo()
    if rings.NumRings() != 0:
        for ring_atoms in rings.AtomRings():
            if len(ring_atoms) >= 8:
                return False
    with_hs = Chem.AddHs(mol)
    for atom in mol.GetAtoms():
        if atom.GetFormalCharge() != 0 or atom.GetSymbol() not in allowed:
            return False
    for pattern in _filters:
        if with_hs.HasSubstructMatch(pattern):
            return False
    smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles)
    if smiles is None or len(smiles) == 0:
        return False
    return Chem.MolFromSmiles(smiles) is not None