Spaces:
Runtime error
Runtime error
saicharan2804
commited on
Commit
•
36173e1
1
Parent(s):
af1e58a
Added manual implementation of metrics
Browse files- fcd.py +102 -0
- metrics.py +344 -0
- molgen_metric.py +0 -2
- utils.py +316 -0
- utils2.py +271 -0
fcd.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fcd_torch.utils import SmilesDataset, \
|
2 |
+
calculate_frechet_distance, \
|
3 |
+
todevice, \
|
4 |
+
load_imported_model
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import os
|
8 |
+
import numpy as np
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
|
12 |
+
class FCD:
|
13 |
+
"""
|
14 |
+
Computes Frechet ChemNet Distance on PyTorch.
|
15 |
+
* You can precalculate mean and sigma for further usage,
|
16 |
+
e.g. if you use the statistics from the same dataset
|
17 |
+
multiple times.
|
18 |
+
* Supports GPU and selection of GPU index
|
19 |
+
* Multithread SMILES parsing
|
20 |
+
|
21 |
+
Example 1:
|
22 |
+
fcd = FCD(device='cuda:0', n_jobs=8)
|
23 |
+
smiles_list = ['CCC', 'CCNC']
|
24 |
+
fcd(smiles_list, smiles_list)
|
25 |
+
|
26 |
+
Example 2:
|
27 |
+
fcd = FCD(device='cuda:0', n_jobs=8)
|
28 |
+
smiles_list = ['CCC', 'CCNC']
|
29 |
+
pgen = fcd.precalc(smiles_list)
|
30 |
+
fcd(smiles_list, pgen=pgen)
|
31 |
+
"""
|
32 |
+
def __init__(self, device='cpu', n_jobs=1,
|
33 |
+
batch_size=512,
|
34 |
+
model_path=None,
|
35 |
+
canonize=True):
|
36 |
+
"""
|
37 |
+
Loads ChemNet on device
|
38 |
+
params:
|
39 |
+
device: cpu for CPU, cuda:0 for GPU 0, etc.
|
40 |
+
n_jobs: number of workers to parse SMILES
|
41 |
+
batch_size: batch size for processing SMILES
|
42 |
+
model_path: path to ChemNet_v0.13_pretrained.pt
|
43 |
+
"""
|
44 |
+
if model_path is None:
|
45 |
+
model_dir = os.path.split(__file__)[0]
|
46 |
+
model_path = os.path.join(model_dir, 'ChemNet_v0.13_pretrained.pt')
|
47 |
+
|
48 |
+
self.device = device
|
49 |
+
self.n_jobs = n_jobs if n_jobs != 1 else 0
|
50 |
+
self.batch_size = batch_size
|
51 |
+
keras_config = torch.load(model_path)
|
52 |
+
self.model = load_imported_model(keras_config)
|
53 |
+
self.model.eval()
|
54 |
+
self.canonize = canonize
|
55 |
+
|
56 |
+
def get_predictions(self, smiles_list):
|
57 |
+
if len(smiles_list) == 0:
|
58 |
+
return np.zeros((0, 512))
|
59 |
+
dataloader = DataLoader(
|
60 |
+
SmilesDataset(smiles_list, canonize=self.canonize),
|
61 |
+
batch_size=self.batch_size,
|
62 |
+
num_workers=self.n_jobs
|
63 |
+
)
|
64 |
+
with todevice(self.model, self.device), torch.no_grad():
|
65 |
+
chemnet_activations = []
|
66 |
+
for batch in dataloader:
|
67 |
+
chemnet_activations.append(
|
68 |
+
self.model(
|
69 |
+
batch.transpose(1, 2).float().to(self.device)
|
70 |
+
).to('cpu').detach().numpy()
|
71 |
+
)
|
72 |
+
return np.row_stack(chemnet_activations)
|
73 |
+
|
74 |
+
def precalc(self, smiles_list):
|
75 |
+
if len(smiles_list) < 2:
|
76 |
+
warnings.warn("Can't compute FCD for less than 2 molecules"
|
77 |
+
"({} given)".format(len(smiles_list)))
|
78 |
+
return {}
|
79 |
+
chemnet_activations = self.get_predictions(smiles_list)
|
80 |
+
mu = chemnet_activations.mean(0)
|
81 |
+
sigma = np.cov(chemnet_activations.T)
|
82 |
+
return {'mu': mu, 'sigma': sigma}
|
83 |
+
|
84 |
+
def metric(self, pref, pgen):
|
85 |
+
if 'mu' not in pref or 'sigma' not in pgen:
|
86 |
+
warnings.warn("Failed to compute FCD (check ref)")
|
87 |
+
return np.nan
|
88 |
+
if 'mu' not in pgen or 'sigma' not in pgen:
|
89 |
+
warnings.warn("Failed to compute FCD (check gen)")
|
90 |
+
return np.nan
|
91 |
+
return calculate_frechet_distance(
|
92 |
+
pref['mu'], pref['sigma'], pgen['mu'], pgen['sigma']
|
93 |
+
)
|
94 |
+
|
95 |
+
def __call__(self, ref=None, gen=None, pref=None, pgen=None):
|
96 |
+
assert (ref is None) != (pref is None), "specify ref xor pref"
|
97 |
+
assert (gen is None) != (pgen is None), "specify gen xor pgen"
|
98 |
+
if pref is None:
|
99 |
+
pref = self.precalc(ref)
|
100 |
+
if pgen is None:
|
101 |
+
pgen = self.precalc(gen)
|
102 |
+
return self.metric(pref, pgen)
|
metrics.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from multiprocessing import Pool
|
3 |
+
import numpy as np
|
4 |
+
from scipy.spatial.distance import cosine as cos_distance
|
5 |
+
from fcd_torch import FCD as FCDMetric
|
6 |
+
from scipy.stats import wasserstein_distance
|
7 |
+
|
8 |
+
from moses.dataset import get_dataset, get_statistics
|
9 |
+
from moses.utils import mapper
|
10 |
+
from moses.utils import disable_rdkit_log, enable_rdkit_log
|
11 |
+
from .utils import compute_fragments, average_agg_tanimoto, \
|
12 |
+
compute_scaffolds, fingerprints, \
|
13 |
+
get_mol, canonic_smiles, mol_passes_filters, \
|
14 |
+
logP, QED, SA, weight
|
15 |
+
|
16 |
+
|
17 |
+
def get_all_metrics(gen, k=None, n_jobs=1,
|
18 |
+
device='cpu', batch_size=512, pool=None,
|
19 |
+
test=None, test_scaffolds=None,
|
20 |
+
ptest=None, ptest_scaffolds=None,
|
21 |
+
train=None):
|
22 |
+
"""
|
23 |
+
Computes all available metrics between test (scaffold test)
|
24 |
+
and generated sets of SMILES.
|
25 |
+
Parameters:
|
26 |
+
gen: list of generated SMILES
|
27 |
+
k: int or list with values for unique@k. Will calculate number of
|
28 |
+
unique molecules in the first k molecules. Default [1000, 10000]
|
29 |
+
n_jobs: number of workers for parallel processing
|
30 |
+
device: 'cpu' or 'cuda:n', where n is GPU device number
|
31 |
+
batch_size: batch size for FCD metric
|
32 |
+
pool: optional multiprocessing pool to use for parallelization
|
33 |
+
|
34 |
+
test (None or list): test SMILES. If None, will load
|
35 |
+
a default test set
|
36 |
+
test_scaffolds (None or list): scaffold test SMILES. If None, will
|
37 |
+
load a default scaffold test set
|
38 |
+
ptest (None or dict): precalculated statistics of the test set. If
|
39 |
+
None, will load default test statistics. If you specified a custom
|
40 |
+
test set, default test statistics will be ignored
|
41 |
+
ptest_scaffolds (None or dict): precalculated statistics of the
|
42 |
+
scaffold test set If None, will load default scaffold test
|
43 |
+
statistics. If you specified a custom test set, default test
|
44 |
+
statistics will be ignored
|
45 |
+
train (None or list): train SMILES. If None, will load a default
|
46 |
+
train set
|
47 |
+
Available metrics:
|
48 |
+
* %valid
|
49 |
+
* %unique@k
|
50 |
+
* Frechet ChemNet Distance (FCD)
|
51 |
+
* Fragment similarity (Frag)
|
52 |
+
* Scaffold similarity (Scaf)
|
53 |
+
* Similarity to nearest neighbour (SNN)
|
54 |
+
* Internal diversity (IntDiv)
|
55 |
+
* Internal diversity 2: using square root of mean squared
|
56 |
+
Tanimoto similarity (IntDiv2)
|
57 |
+
* %passes filters (Filters)
|
58 |
+
* Distribution difference for logP, SA, QED, weight
|
59 |
+
* Novelty (molecules not present in train)
|
60 |
+
"""
|
61 |
+
if test is None:
|
62 |
+
if ptest is not None:
|
63 |
+
raise ValueError(
|
64 |
+
"You cannot specify custom test "
|
65 |
+
"statistics for default test set")
|
66 |
+
test = get_dataset('test')
|
67 |
+
ptest = get_statistics('test')
|
68 |
+
|
69 |
+
if test_scaffolds is None:
|
70 |
+
if ptest_scaffolds is not None:
|
71 |
+
raise ValueError(
|
72 |
+
"You cannot specify custom scaffold test "
|
73 |
+
"statistics for default scaffold test set")
|
74 |
+
test_scaffolds = get_dataset('test_scaffolds')
|
75 |
+
ptest_scaffolds = get_statistics('test_scaffolds')
|
76 |
+
|
77 |
+
train = train or get_dataset('train')
|
78 |
+
|
79 |
+
if k is None:
|
80 |
+
k = [1000, 10000]
|
81 |
+
disable_rdkit_log()
|
82 |
+
metrics = {}
|
83 |
+
close_pool = False
|
84 |
+
if pool is None:
|
85 |
+
if n_jobs != 1:
|
86 |
+
pool = Pool(n_jobs)
|
87 |
+
close_pool = True
|
88 |
+
else:
|
89 |
+
pool = 1
|
90 |
+
metrics['valid'] = fraction_valid(gen, n_jobs=pool)
|
91 |
+
gen = remove_invalid(gen, canonize=True)
|
92 |
+
if not isinstance(k, (list, tuple)):
|
93 |
+
k = [k]
|
94 |
+
for _k in k:
|
95 |
+
metrics['unique@{}'.format(_k)] = fraction_unique(gen, _k, pool)
|
96 |
+
|
97 |
+
if ptest is None:
|
98 |
+
ptest = compute_intermediate_statistics(test, n_jobs=n_jobs,
|
99 |
+
device=device,
|
100 |
+
batch_size=batch_size,
|
101 |
+
pool=pool)
|
102 |
+
if test_scaffolds is not None and ptest_scaffolds is None:
|
103 |
+
ptest_scaffolds = compute_intermediate_statistics(
|
104 |
+
test_scaffolds, n_jobs=n_jobs,
|
105 |
+
device=device, batch_size=batch_size,
|
106 |
+
pool=pool
|
107 |
+
)
|
108 |
+
mols = mapper(pool)(get_mol, gen)
|
109 |
+
kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
|
110 |
+
kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size}
|
111 |
+
metrics['FCD/Test'] = FCDMetric(**kwargs_fcd)(gen=gen, pref=ptest['FCD'])
|
112 |
+
metrics['SNN/Test'] = SNNMetric(**kwargs)(gen=mols, pref=ptest['SNN'])
|
113 |
+
metrics['Frag/Test'] = FragMetric(**kwargs)(gen=mols, pref=ptest['Frag'])
|
114 |
+
metrics['Scaf/Test'] = ScafMetric(**kwargs)(gen=mols, pref=ptest['Scaf'])
|
115 |
+
if ptest_scaffolds is not None:
|
116 |
+
metrics['FCD/TestSF'] = FCDMetric(**kwargs_fcd)(
|
117 |
+
gen=gen, pref=ptest_scaffolds['FCD']
|
118 |
+
)
|
119 |
+
metrics['SNN/TestSF'] = SNNMetric(**kwargs)(
|
120 |
+
gen=mols, pref=ptest_scaffolds['SNN']
|
121 |
+
)
|
122 |
+
metrics['Frag/TestSF'] = FragMetric(**kwargs)(
|
123 |
+
gen=mols, pref=ptest_scaffolds['Frag']
|
124 |
+
)
|
125 |
+
metrics['Scaf/TestSF'] = ScafMetric(**kwargs)(
|
126 |
+
gen=mols, pref=ptest_scaffolds['Scaf']
|
127 |
+
)
|
128 |
+
|
129 |
+
metrics['IntDiv'] = internal_diversity(mols, pool, device=device)
|
130 |
+
metrics['IntDiv2'] = internal_diversity(mols, pool, device=device, p=2)
|
131 |
+
metrics['Filters'] = fraction_passes_filters(mols, pool)
|
132 |
+
|
133 |
+
# Properties
|
134 |
+
for name, func in [('logP', logP), ('SA', SA),
|
135 |
+
('QED', QED),
|
136 |
+
('weight', weight)]:
|
137 |
+
metrics[name] = WassersteinMetric(func, **kwargs)(
|
138 |
+
gen=mols, pref=ptest[name])
|
139 |
+
|
140 |
+
if train is not None:
|
141 |
+
metrics['Novelty'] = novelty(mols, train, pool)
|
142 |
+
enable_rdkit_log()
|
143 |
+
if close_pool:
|
144 |
+
pool.close()
|
145 |
+
pool.join()
|
146 |
+
return metrics
|
147 |
+
|
148 |
+
|
149 |
+
def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu',
|
150 |
+
batch_size=512, pool=None):
|
151 |
+
"""
|
152 |
+
The function precomputes statistics such as mean and variance for FCD, etc.
|
153 |
+
It is useful to compute the statistics for test and scaffold test sets to
|
154 |
+
speedup metrics calculation.
|
155 |
+
"""
|
156 |
+
close_pool = False
|
157 |
+
if pool is None:
|
158 |
+
if n_jobs != 1:
|
159 |
+
pool = Pool(n_jobs)
|
160 |
+
close_pool = True
|
161 |
+
else:
|
162 |
+
pool = 1
|
163 |
+
statistics = {}
|
164 |
+
mols = mapper(pool)(get_mol, smiles)
|
165 |
+
kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
|
166 |
+
kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size}
|
167 |
+
statistics['FCD'] = FCDMetric(**kwargs_fcd).precalc(smiles)
|
168 |
+
statistics['SNN'] = SNNMetric(**kwargs).precalc(mols)
|
169 |
+
statistics['Frag'] = FragMetric(**kwargs).precalc(mols)
|
170 |
+
statistics['Scaf'] = ScafMetric(**kwargs).precalc(mols)
|
171 |
+
for name, func in [('logP', logP), ('SA', SA),
|
172 |
+
('QED', QED),
|
173 |
+
('weight', weight)]:
|
174 |
+
statistics[name] = WassersteinMetric(func, **kwargs).precalc(mols)
|
175 |
+
if close_pool:
|
176 |
+
pool.terminate()
|
177 |
+
return statistics
|
178 |
+
|
179 |
+
|
180 |
+
def fraction_passes_filters(gen, n_jobs=1):
|
181 |
+
"""
|
182 |
+
Computes the fraction of molecules that pass filters:
|
183 |
+
* MCF
|
184 |
+
* PAINS
|
185 |
+
* Only allowed atoms ('C','N','S','O','F','Cl','Br','H')
|
186 |
+
* No charges
|
187 |
+
"""
|
188 |
+
passes = mapper(n_jobs)(mol_passes_filters, gen)
|
189 |
+
return np.mean(passes)
|
190 |
+
|
191 |
+
|
192 |
+
def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
|
193 |
+
gen_fps=None, p=1):
|
194 |
+
"""
|
195 |
+
Computes internal diversity as:
|
196 |
+
1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
|
197 |
+
"""
|
198 |
+
if gen_fps is None:
|
199 |
+
gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
|
200 |
+
return 1 - (average_agg_tanimoto(gen_fps, gen_fps,
|
201 |
+
agg='mean', device=device, p=p)).mean()
|
202 |
+
|
203 |
+
|
204 |
+
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
|
205 |
+
"""
|
206 |
+
Computes a number of unique molecules
|
207 |
+
Parameters:
|
208 |
+
gen: list of SMILES
|
209 |
+
k: compute unique@k
|
210 |
+
n_jobs: number of threads for calculation
|
211 |
+
check_validity: raises ValueError if invalid molecules are present
|
212 |
+
"""
|
213 |
+
if k is not None:
|
214 |
+
if len(gen) < k:
|
215 |
+
warnings.warn(
|
216 |
+
"Can't compute unique@{}.".format(k) +
|
217 |
+
"gen contains only {} molecules".format(len(gen))
|
218 |
+
)
|
219 |
+
gen = gen[:k]
|
220 |
+
canonic = set(mapper(n_jobs)(canonic_smiles, gen))
|
221 |
+
if None in canonic and check_validity:
|
222 |
+
raise ValueError("Invalid molecule passed to unique@k")
|
223 |
+
return len(canonic) / len(gen)
|
224 |
+
|
225 |
+
|
226 |
+
def fraction_valid(gen, n_jobs=1):
|
227 |
+
"""
|
228 |
+
Computes a number of valid molecules
|
229 |
+
Parameters:
|
230 |
+
gen: list of SMILES
|
231 |
+
n_jobs: number of threads for calculation
|
232 |
+
"""
|
233 |
+
gen = mapper(n_jobs)(get_mol, gen)
|
234 |
+
return 1 - gen.count(None) / len(gen)
|
235 |
+
|
236 |
+
|
237 |
+
def novelty(gen, train, n_jobs=1):
|
238 |
+
gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
|
239 |
+
gen_smiles_set = set(gen_smiles) - {None}
|
240 |
+
train_set = set(train)
|
241 |
+
return len(gen_smiles_set - train_set) / len(gen_smiles_set)
|
242 |
+
|
243 |
+
|
244 |
+
def remove_invalid(gen, canonize=True, n_jobs=1):
|
245 |
+
"""
|
246 |
+
Removes invalid molecules from the dataset
|
247 |
+
"""
|
248 |
+
if not canonize:
|
249 |
+
mols = mapper(n_jobs)(get_mol, gen)
|
250 |
+
return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
|
251 |
+
return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
|
252 |
+
x is not None]
|
253 |
+
|
254 |
+
|
255 |
+
class Metric:
|
256 |
+
def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
|
257 |
+
self.n_jobs = n_jobs
|
258 |
+
self.device = device
|
259 |
+
self.batch_size = batch_size
|
260 |
+
for k, v in kwargs.values():
|
261 |
+
setattr(self, k, v)
|
262 |
+
|
263 |
+
def __call__(self, ref=None, gen=None, pref=None, pgen=None):
|
264 |
+
assert (ref is None) != (pref is None), "specify ref xor pref"
|
265 |
+
assert (gen is None) != (pgen is None), "specify gen xor pgen"
|
266 |
+
if pref is None:
|
267 |
+
pref = self.precalc(ref)
|
268 |
+
if pgen is None:
|
269 |
+
pgen = self.precalc(gen)
|
270 |
+
return self.metric(pref, pgen)
|
271 |
+
|
272 |
+
def precalc(self, moleclues):
|
273 |
+
raise NotImplementedError
|
274 |
+
|
275 |
+
def metric(self, pref, pgen):
|
276 |
+
raise NotImplementedError
|
277 |
+
|
278 |
+
|
279 |
+
class SNNMetric(Metric):
|
280 |
+
"""
|
281 |
+
Computes average max similarities of gen SMILES to ref SMILES
|
282 |
+
"""
|
283 |
+
|
284 |
+
def __init__(self, fp_type='morgan', **kwargs):
|
285 |
+
self.fp_type = fp_type
|
286 |
+
super().__init__(**kwargs)
|
287 |
+
|
288 |
+
def precalc(self, mols):
|
289 |
+
return {'fps': fingerprints(mols, n_jobs=self.n_jobs,
|
290 |
+
fp_type=self.fp_type)}
|
291 |
+
|
292 |
+
def metric(self, pref, pgen):
|
293 |
+
return average_agg_tanimoto(pref['fps'], pgen['fps'],
|
294 |
+
device=self.device)
|
295 |
+
|
296 |
+
|
297 |
+
def cos_similarity(ref_counts, gen_counts):
|
298 |
+
"""
|
299 |
+
Computes cosine similarity between
|
300 |
+
dictionaries of form {name: count}. Non-present
|
301 |
+
elements are considered zero:
|
302 |
+
|
303 |
+
sim = <r, g> / ||r|| / ||g||
|
304 |
+
"""
|
305 |
+
if len(ref_counts) == 0 or len(gen_counts) == 0:
|
306 |
+
return np.nan
|
307 |
+
keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
|
308 |
+
ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
|
309 |
+
gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
|
310 |
+
return 1 - cos_distance(ref_vec, gen_vec)
|
311 |
+
|
312 |
+
|
313 |
+
class FragMetric(Metric):
|
314 |
+
def precalc(self, mols):
|
315 |
+
return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}
|
316 |
+
|
317 |
+
def metric(self, pref, pgen):
|
318 |
+
return cos_similarity(pref['frag'], pgen['frag'])
|
319 |
+
|
320 |
+
|
321 |
+
class ScafMetric(Metric):
|
322 |
+
def precalc(self, mols):
|
323 |
+
return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}
|
324 |
+
|
325 |
+
def metric(self, pref, pgen):
|
326 |
+
return cos_similarity(pref['scaf'], pgen['scaf'])
|
327 |
+
|
328 |
+
|
329 |
+
class WassersteinMetric(Metric):
|
330 |
+
def __init__(self, func=None, **kwargs):
|
331 |
+
self.func = func
|
332 |
+
super().__init__(**kwargs)
|
333 |
+
|
334 |
+
def precalc(self, mols):
|
335 |
+
if self.func is not None:
|
336 |
+
values = mapper(self.n_jobs)(self.func, mols)
|
337 |
+
else:
|
338 |
+
values = mols
|
339 |
+
return {'values': values}
|
340 |
+
|
341 |
+
def metric(self, pref, pgen):
|
342 |
+
return wasserstein_distance(
|
343 |
+
pref['values'], pgen['values']
|
344 |
+
)
|
molgen_metric.py
CHANGED
@@ -94,8 +94,6 @@ class molgen_metric(evaluate.Measurement):
|
|
94 |
|
95 |
def _compute(self, generated_smiles, train_smiles = None):
|
96 |
|
97 |
-
|
98 |
-
|
99 |
Results = metrics.get_all_metrics(gen = generated_smiles, train= train_smiles)
|
100 |
|
101 |
generated_smiles = [s for s in generated_smiles if s != '']
|
|
|
94 |
|
95 |
def _compute(self, generated_smiles, train_smiles = None):
|
96 |
|
|
|
|
|
97 |
Results = metrics.get_all_metrics(gen = generated_smiles, train= train_smiles)
|
98 |
|
99 |
generated_smiles = [s for s in generated_smiles if s != '']
|
utils.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from multiprocessing import Pool
|
3 |
+
from collections import UserList, defaultdict
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
import torch
|
8 |
+
from rdkit import rdBase
|
9 |
+
from rdkit import Chem
|
10 |
+
|
11 |
+
|
12 |
+
# https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
|
13 |
+
def set_torch_seed_to_all_gens(_):
|
14 |
+
seed = torch.initial_seed() % (2**32 - 1)
|
15 |
+
random.seed(seed)
|
16 |
+
np.random.seed(seed)
|
17 |
+
|
18 |
+
|
19 |
+
class SpecialTokens:
|
20 |
+
bos = '<bos>'
|
21 |
+
eos = '<eos>'
|
22 |
+
pad = '<pad>'
|
23 |
+
unk = '<unk>'
|
24 |
+
|
25 |
+
|
26 |
+
class CharVocab:
|
27 |
+
@classmethod
|
28 |
+
def from_data(cls, data, *args, **kwargs):
|
29 |
+
chars = set()
|
30 |
+
for string in data:
|
31 |
+
chars.update(string)
|
32 |
+
|
33 |
+
return cls(chars, *args, **kwargs)
|
34 |
+
|
35 |
+
def __init__(self, chars, ss=SpecialTokens):
|
36 |
+
if (ss.bos in chars) or (ss.eos in chars) or \
|
37 |
+
(ss.pad in chars) or (ss.unk in chars):
|
38 |
+
raise ValueError('SpecialTokens in chars')
|
39 |
+
|
40 |
+
all_syms = sorted(list(chars)) + [ss.bos, ss.eos, ss.pad, ss.unk]
|
41 |
+
|
42 |
+
self.ss = ss
|
43 |
+
self.c2i = {c: i for i, c in enumerate(all_syms)}
|
44 |
+
self.i2c = {i: c for i, c in enumerate(all_syms)}
|
45 |
+
|
46 |
+
def __len__(self):
|
47 |
+
return len(self.c2i)
|
48 |
+
|
49 |
+
@property
|
50 |
+
def bos(self):
|
51 |
+
return self.c2i[self.ss.bos]
|
52 |
+
|
53 |
+
@property
|
54 |
+
def eos(self):
|
55 |
+
return self.c2i[self.ss.eos]
|
56 |
+
|
57 |
+
@property
|
58 |
+
def pad(self):
|
59 |
+
return self.c2i[self.ss.pad]
|
60 |
+
|
61 |
+
@property
|
62 |
+
def unk(self):
|
63 |
+
return self.c2i[self.ss.unk]
|
64 |
+
|
65 |
+
def char2id(self, char):
|
66 |
+
if char not in self.c2i:
|
67 |
+
return self.unk
|
68 |
+
|
69 |
+
return self.c2i[char]
|
70 |
+
|
71 |
+
def id2char(self, id):
|
72 |
+
if id not in self.i2c:
|
73 |
+
return self.ss.unk
|
74 |
+
|
75 |
+
return self.i2c[id]
|
76 |
+
|
77 |
+
def string2ids(self, string, add_bos=False, add_eos=False):
|
78 |
+
ids = [self.char2id(c) for c in string]
|
79 |
+
|
80 |
+
if add_bos:
|
81 |
+
ids = [self.bos] + ids
|
82 |
+
if add_eos:
|
83 |
+
ids = ids + [self.eos]
|
84 |
+
|
85 |
+
return ids
|
86 |
+
|
87 |
+
def ids2string(self, ids, rem_bos=True, rem_eos=True):
|
88 |
+
if len(ids) == 0:
|
89 |
+
return ''
|
90 |
+
if rem_bos and ids[0] == self.bos:
|
91 |
+
ids = ids[1:]
|
92 |
+
if rem_eos and ids[-1] == self.eos:
|
93 |
+
ids = ids[:-1]
|
94 |
+
|
95 |
+
string = ''.join([self.id2char(id) for id in ids])
|
96 |
+
|
97 |
+
return string
|
98 |
+
|
99 |
+
|
100 |
+
class OneHotVocab(CharVocab):
|
101 |
+
def __init__(self, *args, **kwargs):
|
102 |
+
super(OneHotVocab, self).__init__(*args, **kwargs)
|
103 |
+
self.vectors = torch.eye(len(self.c2i))
|
104 |
+
|
105 |
+
|
106 |
+
def mapper(n_jobs):
|
107 |
+
'''
|
108 |
+
Returns function for map call.
|
109 |
+
If n_jobs == 1, will use standard map
|
110 |
+
If n_jobs > 1, will use multiprocessing pool
|
111 |
+
If n_jobs is a pool object, will return its map function
|
112 |
+
'''
|
113 |
+
if n_jobs == 1:
|
114 |
+
def _mapper(*args, **kwargs):
|
115 |
+
return list(map(*args, **kwargs))
|
116 |
+
|
117 |
+
return _mapper
|
118 |
+
if isinstance(n_jobs, int):
|
119 |
+
pool = Pool(n_jobs)
|
120 |
+
|
121 |
+
def _mapper(*args, **kwargs):
|
122 |
+
try:
|
123 |
+
result = pool.map(*args, **kwargs)
|
124 |
+
finally:
|
125 |
+
pool.terminate()
|
126 |
+
return result
|
127 |
+
|
128 |
+
return _mapper
|
129 |
+
return n_jobs.map
|
130 |
+
|
131 |
+
|
132 |
+
class Logger(UserList):
|
133 |
+
def __init__(self, data=None):
|
134 |
+
super().__init__()
|
135 |
+
self.sdata = defaultdict(list)
|
136 |
+
for step in (data or []):
|
137 |
+
self.append(step)
|
138 |
+
|
139 |
+
def __getitem__(self, key):
|
140 |
+
if isinstance(key, int):
|
141 |
+
return self.data[key]
|
142 |
+
if isinstance(key, slice):
|
143 |
+
return Logger(self.data[key])
|
144 |
+
ldata = self.sdata[key]
|
145 |
+
if isinstance(ldata[0], dict):
|
146 |
+
return Logger(ldata)
|
147 |
+
return ldata
|
148 |
+
|
149 |
+
def append(self, step_dict):
|
150 |
+
super().append(step_dict)
|
151 |
+
for k, v in step_dict.items():
|
152 |
+
self.sdata[k].append(v)
|
153 |
+
|
154 |
+
def save(self, path):
|
155 |
+
df = pd.DataFrame(list(self))
|
156 |
+
df.to_csv(path, index=None)
|
157 |
+
|
158 |
+
|
159 |
+
class LogPlotter:
|
160 |
+
def __init__(self, log):
|
161 |
+
self.log = log
|
162 |
+
|
163 |
+
def line(self, ax, name):
|
164 |
+
if isinstance(self.log[0][name], dict):
|
165 |
+
for k in self.log[0][name]:
|
166 |
+
ax.plot(self.log[name][k], label=k)
|
167 |
+
ax.legend()
|
168 |
+
else:
|
169 |
+
ax.plot(self.log[name])
|
170 |
+
|
171 |
+
ax.set_ylabel('value')
|
172 |
+
ax.set_xlabel('epoch')
|
173 |
+
ax.set_title(name)
|
174 |
+
|
175 |
+
def grid(self, names, size=7):
|
176 |
+
_, axs = plt.subplots(nrows=len(names) // 2, ncols=2,
|
177 |
+
figsize=(size * 2, size * (len(names) // 2)))
|
178 |
+
|
179 |
+
for ax, name in zip(axs.flatten(), names):
|
180 |
+
self.line(ax, name)
|
181 |
+
|
182 |
+
|
183 |
+
class CircularBuffer:
|
184 |
+
def __init__(self, size):
|
185 |
+
self.max_size = size
|
186 |
+
self.data = np.zeros(self.max_size)
|
187 |
+
self.size = 0
|
188 |
+
self.pointer = -1
|
189 |
+
|
190 |
+
def add(self, element):
|
191 |
+
self.size = min(self.size + 1, self.max_size)
|
192 |
+
self.pointer = (self.pointer + 1) % self.max_size
|
193 |
+
self.data[self.pointer] = element
|
194 |
+
return element
|
195 |
+
|
196 |
+
def last(self):
|
197 |
+
assert self.pointer != -1, "Can't get an element from an empty buffer!"
|
198 |
+
return self.data[self.pointer]
|
199 |
+
|
200 |
+
def mean(self):
|
201 |
+
if self.size > 0:
|
202 |
+
return self.data[:self.size].mean()
|
203 |
+
return 0.0
|
204 |
+
|
205 |
+
|
206 |
+
def disable_rdkit_log():
|
207 |
+
rdBase.DisableLog('rdApp.*')
|
208 |
+
|
209 |
+
|
210 |
+
def enable_rdkit_log():
|
211 |
+
rdBase.EnableLog('rdApp.*')
|
212 |
+
|
213 |
+
|
214 |
+
def get_mol(smiles_or_mol):
|
215 |
+
'''
|
216 |
+
Loads SMILES/molecule into RDKit's object
|
217 |
+
'''
|
218 |
+
if isinstance(smiles_or_mol, str):
|
219 |
+
if len(smiles_or_mol) == 0:
|
220 |
+
return None
|
221 |
+
mol = Chem.MolFromSmiles(smiles_or_mol)
|
222 |
+
if mol is None:
|
223 |
+
return None
|
224 |
+
try:
|
225 |
+
Chem.SanitizeMol(mol)
|
226 |
+
except ValueError:
|
227 |
+
return None
|
228 |
+
return mol
|
229 |
+
return smiles_or_mol
|
230 |
+
|
231 |
+
|
232 |
+
class StringDataset:
|
233 |
+
def __init__(self, vocab, data):
|
234 |
+
"""
|
235 |
+
Creates a convenient Dataset with SMILES tokinization
|
236 |
+
|
237 |
+
Arguments:
|
238 |
+
vocab: CharVocab instance for tokenization
|
239 |
+
data (list): SMILES strings for the dataset
|
240 |
+
"""
|
241 |
+
self.vocab = vocab
|
242 |
+
self.tokens = [vocab.string2ids(s) for s in data]
|
243 |
+
self.data = data
|
244 |
+
self.bos = vocab.bos
|
245 |
+
self.eos = vocab.eos
|
246 |
+
|
247 |
+
def __len__(self):
|
248 |
+
"""
|
249 |
+
Computes a number of objects in the dataset
|
250 |
+
"""
|
251 |
+
return len(self.tokens)
|
252 |
+
|
253 |
+
def __getitem__(self, index):
|
254 |
+
"""
|
255 |
+
Prepares torch tensors with a given SMILES.
|
256 |
+
|
257 |
+
Arguments:
|
258 |
+
index (int): index of SMILES in the original dataset
|
259 |
+
|
260 |
+
Returns:
|
261 |
+
A tuple (with_bos, with_eos, smiles), where
|
262 |
+
* with_bos is a torch.long tensor of SMILES tokens with
|
263 |
+
BOS (beginning of a sentence) token
|
264 |
+
* with_eos is a torch.long tensor of SMILES tokens with
|
265 |
+
EOS (end of a sentence) token
|
266 |
+
* smiles is an original SMILES from the dataset
|
267 |
+
"""
|
268 |
+
tokens = self.tokens[index]
|
269 |
+
with_bos = torch.tensor([self.bos] + tokens, dtype=torch.long)
|
270 |
+
with_eos = torch.tensor(tokens + [self.eos], dtype=torch.long)
|
271 |
+
return with_bos, with_eos, self.data[index]
|
272 |
+
|
273 |
+
def default_collate(self, batch, return_data=False):
|
274 |
+
"""
|
275 |
+
Simple collate function for SMILES dataset. Joins a
|
276 |
+
batch of objects from StringDataset into a batch
|
277 |
+
|
278 |
+
Arguments:
|
279 |
+
batch: list of objects from StringDataset
|
280 |
+
pad: padding symbol, usually equals to vocab.pad
|
281 |
+
return_data: if True, will return SMILES used in a batch
|
282 |
+
|
283 |
+
Returns:
|
284 |
+
with_bos, with_eos, lengths [, data] where
|
285 |
+
* with_bos: padded sequence with BOS in the beginning
|
286 |
+
* with_eos: padded sequence with EOS in the end
|
287 |
+
* lengths: array with SMILES lengths in the batch
|
288 |
+
* data: SMILES in the batch
|
289 |
+
|
290 |
+
Note: output batch is sorted with respect to SMILES lengths in
|
291 |
+
decreasing order, since this is a default format for torch
|
292 |
+
RNN implementations
|
293 |
+
"""
|
294 |
+
with_bos, with_eos, data = list(zip(*batch))
|
295 |
+
lengths = [len(x) for x in with_bos]
|
296 |
+
order = np.argsort(lengths)[::-1]
|
297 |
+
with_bos = [with_bos[i] for i in order]
|
298 |
+
with_eos = [with_eos[i] for i in order]
|
299 |
+
lengths = [lengths[i] for i in order]
|
300 |
+
with_bos = torch.nn.utils.rnn.pad_sequence(
|
301 |
+
with_bos, padding_value=self.vocab.pad
|
302 |
+
)
|
303 |
+
with_eos = torch.nn.utils.rnn.pad_sequence(
|
304 |
+
with_eos, padding_value=self.vocab.pad
|
305 |
+
)
|
306 |
+
if return_data:
|
307 |
+
data = np.array(data)[order]
|
308 |
+
return with_bos, with_eos, lengths, data
|
309 |
+
return with_bos, with_eos, lengths
|
310 |
+
|
311 |
+
|
312 |
+
def batch_to_device(batch, device):
|
313 |
+
return [
|
314 |
+
x.to(device) if isinstance(x, torch.Tensor) else x
|
315 |
+
for x in batch
|
316 |
+
]
|
utils2.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import Counter
|
3 |
+
from functools import partial
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import scipy.sparse
|
7 |
+
import torch
|
8 |
+
from rdkit import Chem
|
9 |
+
from rdkit.Chem import AllChem
|
10 |
+
from rdkit.Chem import MACCSkeys
|
11 |
+
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
|
12 |
+
from rdkit.Chem.QED import qed
|
13 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
14 |
+
from rdkit.Chem import Descriptors
|
15 |
+
from moses.metrics.SA_Score import sascorer
|
16 |
+
from moses.metrics.NP_Score import npscorer
|
17 |
+
from moses.utils import mapper, get_mol
|
18 |
+
|
19 |
+
_base_dir = os.path.split(__file__)[0]
|
20 |
+
_mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv'))
|
21 |
+
_pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'),
|
22 |
+
names=['smarts', 'names'])
|
23 |
+
_filters = [Chem.MolFromSmarts(x) for x in
|
24 |
+
_mcf.append(_pains, sort=True)['smarts'].values]
|
25 |
+
|
26 |
+
|
27 |
+
def canonic_smiles(smiles_or_mol):
|
28 |
+
mol = get_mol(smiles_or_mol)
|
29 |
+
if mol is None:
|
30 |
+
return None
|
31 |
+
return Chem.MolToSmiles(mol)
|
32 |
+
|
33 |
+
|
34 |
+
def logP(mol):
|
35 |
+
"""
|
36 |
+
Computes RDKit's logP
|
37 |
+
"""
|
38 |
+
return Chem.Crippen.MolLogP(mol)
|
39 |
+
|
40 |
+
|
41 |
+
def SA(mol):
|
42 |
+
"""
|
43 |
+
Computes RDKit's Synthetic Accessibility score
|
44 |
+
"""
|
45 |
+
return sascorer.calculateScore(mol)
|
46 |
+
|
47 |
+
|
48 |
+
def NP(mol):
|
49 |
+
"""
|
50 |
+
Computes RDKit's Natural Product-likeness score
|
51 |
+
"""
|
52 |
+
return npscorer.scoreMol(mol)
|
53 |
+
|
54 |
+
|
55 |
+
def QED(mol):
|
56 |
+
"""
|
57 |
+
Computes RDKit's QED score
|
58 |
+
"""
|
59 |
+
return qed(mol)
|
60 |
+
|
61 |
+
|
62 |
+
def weight(mol):
|
63 |
+
"""
|
64 |
+
Computes molecular weight for given molecule.
|
65 |
+
Returns float,
|
66 |
+
"""
|
67 |
+
return Descriptors.MolWt(mol)
|
68 |
+
|
69 |
+
|
70 |
+
def get_n_rings(mol):
|
71 |
+
"""
|
72 |
+
Computes the number of rings in a molecule
|
73 |
+
"""
|
74 |
+
return mol.GetRingInfo().NumRings()
|
75 |
+
|
76 |
+
|
77 |
+
def fragmenter(mol):
|
78 |
+
"""
|
79 |
+
fragment mol using BRICS and return smiles list
|
80 |
+
"""
|
81 |
+
fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
|
82 |
+
fgs_smi = Chem.MolToSmiles(fgs).split(".")
|
83 |
+
return fgs_smi
|
84 |
+
|
85 |
+
|
86 |
+
def compute_fragments(mol_list, n_jobs=1):
|
87 |
+
"""
|
88 |
+
fragment list of mols using BRICS and return smiles list
|
89 |
+
"""
|
90 |
+
fragments = Counter()
|
91 |
+
for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
|
92 |
+
fragments.update(mol_frag)
|
93 |
+
return fragments
|
94 |
+
|
95 |
+
|
96 |
+
def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
|
97 |
+
"""
|
98 |
+
Extracts a scafold from a molecule in a form of a canonic SMILES
|
99 |
+
"""
|
100 |
+
scaffolds = Counter()
|
101 |
+
map_ = mapper(n_jobs)
|
102 |
+
scaffolds = Counter(
|
103 |
+
map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
|
104 |
+
if None in scaffolds:
|
105 |
+
scaffolds.pop(None)
|
106 |
+
return scaffolds
|
107 |
+
|
108 |
+
|
109 |
+
def compute_scaffold(mol, min_rings=2):
|
110 |
+
mol = get_mol(mol)
|
111 |
+
try:
|
112 |
+
scaffold = MurckoScaffold.GetScaffoldForMol(mol)
|
113 |
+
except (ValueError, RuntimeError):
|
114 |
+
return None
|
115 |
+
n_rings = get_n_rings(scaffold)
|
116 |
+
scaffold_smiles = Chem.MolToSmiles(scaffold)
|
117 |
+
if scaffold_smiles == '' or n_rings < min_rings:
|
118 |
+
return None
|
119 |
+
return scaffold_smiles
|
120 |
+
|
121 |
+
|
122 |
+
def average_agg_tanimoto(stock_vecs, gen_vecs,
|
123 |
+
batch_size=5000, agg='max',
|
124 |
+
device='cpu', p=1):
|
125 |
+
"""
|
126 |
+
For each molecule in gen_vecs finds closest molecule in stock_vecs.
|
127 |
+
Returns average tanimoto score for between these molecules
|
128 |
+
|
129 |
+
Parameters:
|
130 |
+
stock_vecs: numpy array <n_vectors x dim>
|
131 |
+
gen_vecs: numpy array <n_vectors' x dim>
|
132 |
+
agg: max or mean
|
133 |
+
p: power for averaging: (mean x^p)^(1/p)
|
134 |
+
"""
|
135 |
+
assert agg in ['max', 'mean'], "Can aggregate only max or mean"
|
136 |
+
agg_tanimoto = np.zeros(len(gen_vecs))
|
137 |
+
total = np.zeros(len(gen_vecs))
|
138 |
+
for j in range(0, stock_vecs.shape[0], batch_size):
|
139 |
+
x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
|
140 |
+
for i in range(0, gen_vecs.shape[0], batch_size):
|
141 |
+
y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
|
142 |
+
y_gen = y_gen.transpose(0, 1)
|
143 |
+
tp = torch.mm(x_stock, y_gen)
|
144 |
+
jac = (tp / (x_stock.sum(1, keepdim=True) +
|
145 |
+
y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
|
146 |
+
jac[np.isnan(jac)] = 1
|
147 |
+
if p != 1:
|
148 |
+
jac = jac**p
|
149 |
+
if agg == 'max':
|
150 |
+
agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
|
151 |
+
agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
|
152 |
+
elif agg == 'mean':
|
153 |
+
agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
|
154 |
+
total[i:i + y_gen.shape[1]] += jac.shape[0]
|
155 |
+
if agg == 'mean':
|
156 |
+
agg_tanimoto /= total
|
157 |
+
if p != 1:
|
158 |
+
agg_tanimoto = (agg_tanimoto)**(1/p)
|
159 |
+
return np.mean(agg_tanimoto)
|
160 |
+
|
161 |
+
|
162 |
+
def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
|
163 |
+
morgan__n=1024, *args, **kwargs):
|
164 |
+
"""
|
165 |
+
Generates fingerprint for SMILES
|
166 |
+
If smiles is invalid, returns None
|
167 |
+
Returns numpy array of fingerprint bits
|
168 |
+
|
169 |
+
Parameters:
|
170 |
+
smiles: SMILES string
|
171 |
+
type: type of fingerprint: [MACCS|morgan]
|
172 |
+
dtype: if not None, specifies the dtype of returned array
|
173 |
+
"""
|
174 |
+
fp_type = fp_type.lower()
|
175 |
+
molecule = get_mol(smiles_or_mol, *args, **kwargs)
|
176 |
+
if molecule is None:
|
177 |
+
return None
|
178 |
+
if fp_type == 'maccs':
|
179 |
+
keys = MACCSkeys.GenMACCSKeys(molecule)
|
180 |
+
keys = np.array(keys.GetOnBits())
|
181 |
+
fingerprint = np.zeros(166, dtype='uint8')
|
182 |
+
if len(keys) != 0:
|
183 |
+
fingerprint[keys - 1] = 1 # We drop 0-th key that is always zero
|
184 |
+
elif fp_type == 'morgan':
|
185 |
+
fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
|
186 |
+
dtype='uint8')
|
187 |
+
else:
|
188 |
+
raise ValueError("Unknown fingerprint type {}".format(fp_type))
|
189 |
+
if dtype is not None:
|
190 |
+
fingerprint = fingerprint.astype(dtype)
|
191 |
+
return fingerprint
|
192 |
+
|
193 |
+
|
194 |
+
def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
|
195 |
+
**kwargs):
|
196 |
+
'''
|
197 |
+
Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers
|
198 |
+
e.g.fingerprints(smiles_mols_array, type='morgan', n_jobs=10)
|
199 |
+
Inserts np.NaN to rows corresponding to incorrect smiles.
|
200 |
+
IMPORTANT: if there is at least one np.NaN, the dtype would be float
|
201 |
+
Parameters:
|
202 |
+
smiles_mols_array: list/array/pd.Series of smiles or already computed
|
203 |
+
RDKit molecules
|
204 |
+
n_jobs: number of parralel workers to execute
|
205 |
+
already_unique: flag for performance reasons, if smiles array is big
|
206 |
+
and already unique. Its value is set to True if smiles_mols_array
|
207 |
+
contain RDKit molecules already.
|
208 |
+
'''
|
209 |
+
if isinstance(smiles_mols_array, pd.Series):
|
210 |
+
smiles_mols_array = smiles_mols_array.values
|
211 |
+
else:
|
212 |
+
smiles_mols_array = np.asarray(smiles_mols_array)
|
213 |
+
if not isinstance(smiles_mols_array[0], str):
|
214 |
+
already_unique = True
|
215 |
+
|
216 |
+
if not already_unique:
|
217 |
+
smiles_mols_array, inv_index = np.unique(smiles_mols_array,
|
218 |
+
return_inverse=True)
|
219 |
+
|
220 |
+
fps = mapper(n_jobs)(
|
221 |
+
partial(fingerprint, *args, **kwargs), smiles_mols_array
|
222 |
+
)
|
223 |
+
|
224 |
+
length = 1
|
225 |
+
for fp in fps:
|
226 |
+
if fp is not None:
|
227 |
+
length = fp.shape[-1]
|
228 |
+
first_fp = fp
|
229 |
+
break
|
230 |
+
fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :]
|
231 |
+
for fp in fps]
|
232 |
+
if scipy.sparse.issparse(first_fp):
|
233 |
+
fps = scipy.sparse.vstack(fps).tocsr()
|
234 |
+
else:
|
235 |
+
fps = np.vstack(fps)
|
236 |
+
if not already_unique:
|
237 |
+
return fps[inv_index]
|
238 |
+
return fps
|
239 |
+
|
240 |
+
|
241 |
+
def mol_passes_filters(mol,
|
242 |
+
allowed=None,
|
243 |
+
isomericSmiles=False):
|
244 |
+
"""
|
245 |
+
Checks if mol
|
246 |
+
* passes MCF and PAINS filters,
|
247 |
+
* has only allowed atoms
|
248 |
+
* is not charged
|
249 |
+
"""
|
250 |
+
allowed = allowed or {'C', 'N', 'S', 'O', 'F', 'Cl', 'Br', 'H'}
|
251 |
+
mol = get_mol(mol)
|
252 |
+
if mol is None:
|
253 |
+
return False
|
254 |
+
ring_info = mol.GetRingInfo()
|
255 |
+
if ring_info.NumRings() != 0 and any(
|
256 |
+
len(x) >= 8 for x in ring_info.AtomRings()
|
257 |
+
):
|
258 |
+
return False
|
259 |
+
h_mol = Chem.AddHs(mol)
|
260 |
+
if any(atom.GetFormalCharge() != 0 for atom in mol.GetAtoms()):
|
261 |
+
return False
|
262 |
+
if any(atom.GetSymbol() not in allowed for atom in mol.GetAtoms()):
|
263 |
+
return False
|
264 |
+
if any(h_mol.HasSubstructMatch(smarts) for smarts in _filters):
|
265 |
+
return False
|
266 |
+
smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles)
|
267 |
+
if smiles is None or len(smiles) == 0:
|
268 |
+
return False
|
269 |
+
if Chem.MolFromSmiles(smiles) is None:
|
270 |
+
return False
|
271 |
+
return True
|