uragankatrrin committed on
Commit: 2956799
1 Parent(s): 6dd21b5

Upload 12 files

mhnreact/.gitkeep ADDED
File without changes
mhnreact/__init__.py ADDED
@@ -0,0 +1 @@
__version__ = "0.0.1"
mhnreact/data.py ADDED
@@ -0,0 +1,338 @@
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at

File contains functions that help prepare and download USPTO-related datasets
"""

import os
import gzip
import pickle
import requests
import subprocess
import pandas as pd
import numpy as np
from scipy import sparse
import json

def download_temprel_repo(save_path='data/temprel-fortunato', chunk_size=128):
    "downloads the template-relevance master branch"
    url = "https://gitlab.com/mefortunato/template-relevance/-/archive/master/template-relevance-master.zip"
    r = requests.get(url, stream=True)
    with open(save_path, 'wb') as fd:
        for chunk in r.iter_content(chunk_size=chunk_size):
            fd.write(chunk)

def unzip(path):
    "unzips a file given a path"
    import zipfile
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(path.replace('.zip', ''))


def download_file(url, output_path=None):
    """
    streams a file from a URL to disk
    (code from Fortunato; one could also use temprel.data.download.get_uspto_50k, but this is slightly altered)
    """
    if not output_path:
        output_path = url.split('/')[-1]
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(output_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

def get_uspto_480k():
    if not os.path.exists('data'):
        os.mkdir('data')
    if not os.path.exists('data/raw'):
        os.mkdir('data/raw')
    os.chdir('data/raw')
    download_file(
        'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/train.txt.tar.gz',
        'train.txt.tar.gz'
    )
    subprocess.run(['tar', 'zxf', 'train.txt.tar.gz'])
    download_file(
        'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/valid.txt.tar.gz',
        'valid.txt.tar.gz'
    )
    subprocess.run(['tar', 'zxf', 'valid.txt.tar.gz'])
    download_file(
        'https://github.com/connorcoley/rexgen_direct/raw/master/rexgen_direct/data/test.txt.tar.gz',
        'test.txt.tar.gz'
    )
    subprocess.run(['tar', 'zxf', 'test.txt.tar.gz'])

    with open('train.txt') as f:
        train = [
            {
                'reaction_smiles': line.strip(),
                'split': 'train'
            }
            for line in f.readlines()
        ]
    with open('valid.txt') as f:
        valid = [
            {
                'reaction_smiles': line.strip(),
                'split': 'valid'
            }
            for line in f.readlines()
        ]
    with open('test.txt') as f:
        test = [
            {
                'reaction_smiles': line.strip(),
                'split': 'test'
            }
            for line in f.readlines()
        ]

    df = pd.concat([
        pd.DataFrame(train),
        pd.DataFrame(valid),
        pd.DataFrame(test)
    ]).reset_index()
    df.to_json('uspto_lg_reactions.json.gz', compression='gzip')
    os.chdir('..')
    os.chdir('..')
    return df

def get_uspto_50k():
    '''
    get the SI from:
    Schneider, N.; Lowe, D. M.; Sayle, R. A.; Landrum, G. A.
    J. Chem. Inf. Model. 2015, 55 (1), 39-53
    '''
    if not os.path.exists('data'):
        os.mkdir('data')
    if not os.path.exists('data/raw'):
        os.mkdir('data/raw')
    os.chdir('data/raw')
    subprocess.run(['wget', 'https://pubs.acs.org/doi/suppl/10.1021/ci5006614/suppl_file/ci5006614_si_002.zip'])
    subprocess.run(['unzip', '-o', 'ci5006614_si_002.zip'])
    data = []
    with gzip.open('ChemReactionClassification/data/training_test_set_patent_data.pkl.gz') as f:
        while True:
            try:
                data.append(pickle.load(f))
            except EOFError:
                break
    reaction_smiles = [d[0] for d in data]
    reaction_reference = [d[1] for d in data]
    reaction_class = [d[2] for d in data]
    df = pd.DataFrame()
    df['reaction_smiles'] = reaction_smiles
    df['reaction_reference'] = reaction_reference
    df['reaction_class'] = reaction_class
    df.to_json('uspto_sm_reactions.json.gz', compression='gzip')
    os.chdir('..')
    os.chdir('..')
    return df

def get_uspto_golden():
    """ get uspto golden and convert it to a smiles dataframe, from
    Lin, Arkadii; Dyubankova, Natalia; Madzhidov, Timur; Nugmanov, Ramil;
    Rakhimbekova, Assima; Ibragimova, Zarina; Akhmetshin, Tagir; Gimadiev,
    Timur; Suleymanov, Rail; Verhoeven, Jonas; Wegner, Jörg Kurt;
    Ceulemans, Hugo; Varnek, Alexandre (2020):
    Atom-to-Atom Mapping: A Benchmarking Study of Popular Mapping Algorithms and Consensus Strategies.
    ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.13012679.v1
    """
    if os.path.exists('data/raw/uspto_golden.json.gz'):
        print('loading precomputed')
        return pd.read_json('data/raw/uspto_golden.json.gz', compression='gzip')
    if not os.path.exists('data'):
        os.mkdir('data')
    if not os.path.exists('data/raw'):
        os.mkdir('data/raw')
    os.chdir('data/raw')
    subprocess.run(['wget', 'https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning/raw/master/data/golden_dataset.zip'])
    subprocess.run(['unzip', '-o', 'golden_dataset.zip'])  # yields golden_dataset.rdf

    from CGRtools.files import RDFRead
    import CGRtools
    from rdkit.Chem import AllChem
    def cgr2rxnsmiles(cgr_rx):
        smiles_rx = '.'.join([AllChem.MolToSmiles(CGRtools.to_rdkit_molecule(m)) for m in cgr_rx.reactants])
        smiles_rx += '>>' + '.'.join([AllChem.MolToSmiles(CGRtools.to_rdkit_molecule(m)) for m in cgr_rx.products])
        return smiles_rx

    data = {}
    input_file = 'golden_dataset.rdf'
    do_basic_standardization = True
    print('reading and converting the rdf-file')
    with RDFRead(input_file) as f:
        while True:
            try:
                r = next(f)
                key = r.meta['Reaction_ID']
                if do_basic_standardization:
                    r.thiele()
                    r.standardize()
                data[key] = cgr2rxnsmiles(r)
            except StopIteration:
                break

    print('saving as a dataframe to data/uspto_golden.json.gz')
    df = pd.DataFrame([data], index=['reaction_smiles']).T
    df['reaction_reference'] = df.index
    df.index = range(len(df))  # reindex
    df.to_json('uspto_golden.json.gz', compression='gzip')

    os.chdir('..')
    os.chdir('..')
    return df

def load_USPTO_fortu(path='data/processed', which='uspto_sm_', is_appl_matrix=False):
    """
    loads the Fortunato-preprocessed data as a
    dict X containing X['train'], X['valid'], and X['test'],
    as well as the labels y containing the corresponding splits
    returns X, y
    """

    X = {}
    y = {}

    for split in ['train', 'valid', 'test']:
        tmp = np.load(f'{path}/{which}{split}.input.smiles.npy', allow_pickle=True)
        X[split] = []
        for ii in range(len(tmp)):
            X[split].append(tmp[ii].split('.'))

        if is_appl_matrix:
            y[split] = sparse.load_npz(f'{path}/{which}{split}.appl_matrix.npz')
        else:
            y[split] = np.load(f'{path}/{which}{split}.labels.classes.npy', allow_pickle=True)
        print(split, y[split].shape[0], 'samples (', y[split].max() if not is_appl_matrix else y[split].shape[1], 'max label)')
    return X, y

# TODO one should load in this file pd.read_json('uspto_R_retro.templates.uspto_R_.json.gz')
# this only holds the templates.. the other holds everything
def load_templates_sm(path='data/processed/uspto_sm_templates.df.json.gz', get_complete_df=False):
    "returns a dict mapping from class index to mapped reaction_smarts from the templates_df"
    df = pd.read_json(path)
    if get_complete_df: return df
    template_dict = {}
    for row in range(len(df)):
        template_dict[df.iloc[row]['index']] = df.iloc[row].reaction_smarts
    return template_dict

def load_templates_lg(path='data/processed/uspto_lg_templates.df.json.gz', get_complete_df=False):
    return load_templates_sm(path=path, get_complete_df=get_complete_df)

def load_USPTO_sm():
    "loads the default (small) dataset"
    return load_USPTO_fortu(which='uspto_sm_')

def load_USPTO_lg():
    "loads the default (large) dataset"
    return load_USPTO_fortu(which='uspto_lg_')

def load_USPTO_sm_pretraining():
    "loads the default small dataset with the applicability-matrix labels"
    return load_USPTO_fortu(which='uspto_sm_', is_appl_matrix=True)
def load_USPTO_lg_pretraining():
    "loads the default large dataset with the applicability-matrix labels"
    return load_USPTO_fortu(which='uspto_lg_', is_appl_matrix=True)

def load_USPTO_df_sm():
    "loads the USPTO small (Sm) dataset dataframe"
    return pd.read_json('data/raw/uspto_sm_reactions.json.gz')

def load_USPTO_df_lg():
    "loads the USPTO large (Lg) dataset dataframe"
    return pd.read_json('data/raw/uspto_lg_reactions.json.gz')

def load_USPTO_golden():
    "loads the golden USPTO dataset"
    return load_USPTO_fortu(which='uspto_golden_', is_appl_matrix=False)

def load_USPTO(which='sm', is_appl_matrix=False):
    return load_USPTO_fortu(which=f'uspto_{which}_', is_appl_matrix=is_appl_matrix)

def load_templates(which='sm', fdir='data/processed', get_complete_df=False):
    return load_templates_sm(path=f'{fdir}/uspto_{which}_templates.df.json.gz', get_complete_df=get_complete_df)

def load_data(dataset, path):
    splits = ['train', 'valid', 'test']
    split2smiles = {}
    split2label = {}
    split2reactants = {}
    split2appl = {}
    split2prod_idx_reactants = {}

    for split in splits:
        label_fn = os.path.join(path, f'{dataset}_{split}.labels.classes.npy')
        split2label[split] = np.load(label_fn, allow_pickle=True)

        smiles_fn = os.path.join(path, f'{dataset}_{split}.input.smiles.npy')
        split2smiles[split] = np.load(smiles_fn, allow_pickle=True)

        reactants_fn = os.path.join(path, f'uspto_R_{split}.reactants.canonical.npy')
        split2reactants[split] = np.load(reactants_fn, allow_pickle=True)

        split2appl[split] = np.load(os.path.join(path, f'{dataset}_{split}.applicability.npy'))

        pir_fn = os.path.join(path, f'{dataset}_{split}.prod.idx.reactants.p')
        if os.path.isfile(pir_fn):
            with open(pir_fn, 'rb') as f:
                split2prod_idx_reactants[split] = pickle.load(f)

    if len(split2prod_idx_reactants) == 0:
        split2prod_idx_reactants = None

    with open(os.path.join(path, f'{dataset}_templates.json'), 'r') as f:
        label2template = json.load(f)
        label2template = {int(k): v for k, v in label2template.items()}

    return split2smiles, split2label, split2reactants, split2appl, split2prod_idx_reactants, label2template


def load_dataset_from_csv(csv_path='', split_col='split', input_col='prod_smiles', ssretroeval=False, reactants_col='reactants_can', ret_df=False, **kwargs):
    """loads the dataset from a CSV file containing a split-column and an input-column (both configurable),
    as well as a 'reaction_smarts' column with the extracted template and a 'label' column (the index of the template)
    :returns X, y, template_list, test_reactants_can (None unless ssretroeval), and optionally the dataframe (if ret_df)
    """
    print('loading X, y from csv')
    df = pd.read_csv(csv_path)
    X = {}
    y = {}

    for spli in set(df[split_col]):
        X[spli] = list(df[df[split_col] == spli][input_col].apply(lambda k: [k]))
        y[spli] = (df[df[split_col] == spli]['label']).values
        print(spli, len(X[spli]), 'samples')

    # templates to dict
    tmp = df[['reaction_smarts', 'label']].drop_duplicates(subset=['reaction_smarts', 'label']).sort_values('label')
    tmp.index = tmp.label
    template_list = tmp['reaction_smarts'].to_dict()
    print(len(template_list), 'templates')

    if ssretroeval:
        # setup for the single-step retrosynthesis evaluation on the test split
        test_reactants_can = list(df[df[split_col] == 'test'][reactants_col])

        only_in_test = set(y['test']) - set(y['train']).union(set(y['valid']))
        print('obfuscating', len(only_in_test), 'templates because they only occur in test')
        for ii in only_in_test:
            template_list[ii] = 'CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCCC>>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC.CCCCCCCCCCCCCCCCCCCCC'  # obfuscate them
        if ret_df:
            return X, y, template_list, test_reactants_can, df
        return X, y, template_list, test_reactants_can

    if ret_df:
        return X, y, template_list, None, df
    return X, y, template_list, None
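
A minimal usage sketch for the loaders above (not part of the upload; the file paths are the
defaults the functions expect, and 'uspto_sm.csv' is a hypothetical filename):

    from mhnreact.data import load_USPTO, load_templates, load_dataset_from_csv

    # dicts keyed by 'train'/'valid'/'test'; X holds lists of SMILES, y the template labels
    X, y = load_USPTO(which='sm')
    template_dict = load_templates(which='sm')  # {label_index: reaction_smarts}

    # or everything from one CSV with 'split', 'prod_smiles', 'label', 'reaction_smarts' columns
    X, y, template_list, test_reactants = load_dataset_from_csv('uspto_sm.csv', ssretroeval=True)
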
mhnreact/inference.py ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at

File re-exports the model classes needed for inference
"""

# Cell
from .model import ModelConfig, MHN
import torch
mhnreact/inspect.py ADDED
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at

File contains functions that help inspect trained models and visualize templates
"""

from . import model
import torch
import os

MODEL_PATH = 'data/model/'

def smarts2svg(smarts, useSmiles=True, highlightByReactant=True, save_to=''):
    """
    draws the smiles of a smarts to an SVG and displays it in the notebook;
    optionally it can be saved to a file `save_to`
    adapted from https://www.kesci.com/mw/project/5c7685191ce0af002b556cc5
    """
    from rdkit import RDConfig
    from rdkit import Chem
    from rdkit.Chem import Draw, AllChem
    from rdkit.Chem.Draw import rdMolDraw2D
    from rdkit import Geometry
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import matplotlib
    from IPython.display import SVG, display

    rxn = AllChem.ReactionFromSmarts(smarts, useSmiles=useSmiles)
    d = Draw.MolDraw2DSVG(900, 100)

    # rxn = AllChem.ReactionFromSmarts('[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][NH2:6]>CC(O)C.[Pt]>[CH3:1][C:2](=[O:3])[NH:6][CH3:5].[OH2:4]',useSmiles=True)
    colors = [(0.3, 0.7, 0.9), (0.9, 0.7, 0.9), (0.6, 0.9, 0.3), (0.9, 0.9, 0.1)]
    svg2 = ''
    try:
        d.DrawReaction(rxn, highlightByReactant=highlightByReactant)
        d.FinishDrawing()

        svg = d.GetDrawingText()
        svg2 = svg.replace('svg:', '')
        svg3 = SVG(svg2)
        display(svg3)

        if save_to != '':
            with open(save_to, 'w') as f_handle:
                f_handle.write(svg3.data)
    except Exception:
        print('Error drawing')

    return svg2

def list_models(model_path=MODEL_PATH):
    """returns a dict of loadable models (files ending in .pt)"""
    return dict(enumerate(list(filter(lambda k: str(k)[-3:] == '.pt', os.listdir(model_path)))))

def load_clf(model_fn='', model_path=MODEL_PATH, device='cpu', model_type='mhn'):
    """returns the model with loaded weights given a filename"""
    import json
    config_fn = '_'.join(model_fn.split('_')[-2:]).split('.pt')[0]
    conf_dict = json.load(open(f"{model_path}{config_fn}_config.json"))

    # specify the config the saved model had
    conf = model.ModelConfig(**conf_dict)
    conf.device = device
    print(conf.__dict__)

    if model_type == 'staticQK':
        clf = model.StaticQK(conf)
    elif model_type == 'mhn':
        clf = model.MHN(conf)
    elif model_type == 'segler':
        clf = model.SeglerBaseline(conf)
    elif model_type == 'fortunato':
        clf = model.SeglerBaseline(conf)
    else:
        raise NotImplementedError(f"model_type {model_type} not found")

    # load the model weights (always onto CPU first)
    PATH = model_path + model_fn
    params = torch.load(PATH, map_location=torch.device('cpu'))
    clf.load_state_dict(params, strict=False)
    if 'templates+noise' in params.keys():
        print('loading templates+noise')
        clf.templates = params['templates+noise']
        #clf.templates.to(clf.config.device)
    return clf
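
A short sketch of how these helpers fit together (hypothetical checkpoint names; assumes
models were saved under data/model/ with the naming scheme load_clf expects):

    from mhnreact.inspect import list_models, load_clf, smarts2svg

    print(list_models())              # e.g. {0: 'mhn_model_1337_config.pt', ...}
    clf = load_clf(list_models()[0])  # rebuilds the ModelConfig and loads the weights on CPU
    smarts2svg('[C:1]=[O:2]>>[C:1][O:2]', useSmiles=False)  # renders the template in a notebook
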
mhnreact/model.py ADDED
@@ -0,0 +1,660 @@
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at

Model related functionality
"""
from .utils import top_k_accuracy
from .plotutils import plot_loss, plot_topk, plot_nte
from .molutils import convert_smiles_to_fp
import os
import numpy as np
import torch
import torch.nn as nn
from collections import defaultdict
from scipy import sparse
import logging
from tqdm import tqdm
import wandb

log = logging.getLogger(__name__)

class ChemRXNDataset(torch.utils.data.Dataset):
    "Torch Dataset for ChemRXN containing Xs: the input as np array, target: the target molecules (or nothing), and ys: the label"
    def __init__(self, Xs, target, ys, is_smiles=False, fp_size=2048, fingerprint_type='morgan'):
        self.is_smiles = is_smiles
        if is_smiles:
            self.Xs = Xs
            self.target = target
            self.fp_size = fp_size
            self.fingerprint_type = fingerprint_type
        else:
            self.Xs = Xs.astype(np.float32)
            self.target = target.astype(np.float32)
        self.ys = ys
        self.ys_is_sparse = isinstance(self.ys, sparse.csr.csr_matrix)

    def __getitem__(self, k):
        mol_fp = self.Xs[k]
        if self.is_smiles:
            mol_fp = convert_smiles_to_fp(mol_fp, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)

        target = None if self.target is None else self.target[k]
        if self.is_smiles and self.target:
            target = convert_smiles_to_fp(target, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)

        label = self.ys[k]
        if isinstance(self.ys, sparse.csr.csr_matrix):
            label = label.toarray()[0]

        return (mol_fp, target, label)

    def __len__(self):
        return len(self.Xs)

class ModelConfig(object):
    def __init__(self, **kwargs):
        self.fingerprint_type = kwargs.pop("fingerprint_type", 'morgan')
        self.template_fp_type = kwargs.pop("template_fp_type", 'rdk')
        self.num_templates = kwargs.pop("num_templates", 401)
        self.fp_size = kwargs.pop("fp_size", 2048)
        self.fp_radius = kwargs.pop("fp_radius", 4)

        self.device = kwargs.pop("device", 'cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = kwargs.pop("batch_size", 32)
        self.pooling_operation_state_embedding = kwargs.pop('pooling_operation_state_embedding', 'mean')
        self.pooling_operation_head = kwargs.pop('pooling_operation_head', 'max')

        self.dropout = kwargs.pop('dropout', 0.0)

        self.lr = kwargs.pop('lr', 1e-4)
        self.optimizer = kwargs.pop("optimizer", "Adam")

        self.activation_function = kwargs.pop('activation_function', 'ReLU')
        self.verbose = kwargs.pop("verbose", False)  # set to True for debugging / additional warnings and information

        self.hopf_input_size = kwargs.pop('hopf_input_size', 2048)
        self.hopf_output_size = kwargs.pop("hopf_output_size", 768)
        self.hopf_num_heads = kwargs.pop("hopf_num_heads", 1)
        self.hopf_asso_dim = kwargs.pop("hopf_asso_dim", 768)
        self.hopf_association_activation = kwargs.pop("hopf_association_activation", None)
        self.hopf_beta = kwargs.pop("hopf_beta", 0.125)  # attention-style scaling, cf. 1/sqrt(d_k)
        self.norm_input = kwargs.pop("norm_input", False)
        self.norm_asso = kwargs.pop("norm_asso", False)

        # additional experimental hyperparams
        if 'hopf_n_layers' in kwargs.keys():
            self.hopf_n_layers = kwargs.pop('hopf_n_layers', 0)
        if 'mol_encoder_layers' in kwargs.keys():
            self.mol_encoder_layers = kwargs.pop('mol_encoder_layers', 1)
        if 'temp_encoder_layers' in kwargs.keys():
            self.temp_encoder_layers = kwargs.pop('temp_encoder_layers', 1)
        if 'encoder_af' in kwargs.keys():
            self.encoder_af = kwargs.pop('encoder_af', 'ReLU')

        # additional kwargs
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                log.error(f"Can't set {key} with value {value} for {self}")
                raise err


class Encoder(nn.Module):
    """Simple FFNN"""
    def __init__(self, input_size: int = 2048, output_size: int = 1024,
                 num_layers: int = 1, dropout: float = 0.3, af_name: str = 'None',
                 norm_in: bool = False, norm_out: bool = False):
        super().__init__()
        self.ws = []
        self.setup_af(af_name)
        self.norm_in = (lambda k: k) if not norm_in else torch.nn.LayerNorm(input_size, elementwise_affine=False)
        self.norm_out = (lambda k: k) if not norm_out else torch.nn.LayerNorm(output_size, elementwise_affine=False)
        self.setup_ff(input_size, output_size, num_layers)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        x = self.norm_in(x)
        for i, w in enumerate(self.ws):
            if i == (len(self.ws) - 1):
                x = self.dropout(w(x))  # no activation on the last layer
            else:
                x = self.dropout(self.af(w(x)))
        x = self.norm_out(x)
        return x

    def setup_ff(self, input_size: int, output_size: int, num_layers=1):
        """setup feed-forward NN with n layers"""
        for n in range(0, num_layers):
            w = nn.Linear(input_size if n == 0 else output_size, output_size)
            torch.nn.init.kaiming_normal_(w.weight, mode='fan_in', nonlinearity='linear')  # equiv. to LeCun init
            setattr(self, f'W_{n}', w)  # consider doing a step-wise reduction
            self.ws.append(getattr(self, f'W_{n}'))

    def setup_af(self, af_name: str):
        """set activation function"""
        if af_name is None or (af_name == 'None'):
            self.af = lambda k: k
        else:
            try:
                self.af = getattr(nn, af_name)()
            except AttributeError as err:
                log.error(f"Can't find activation-function {af_name} in torch.nn")
                raise err


class MoleculeEncoder(Encoder):
    """
    Class for Molecule encoder: can be any class mapping SMILES to a vector (preferably differentiable ;)
    """
    def __init__(self, config):
        self.config = config

class FPMolEncoder(Encoder):
    """
    Fingerprint based molecular encoder
    """
    def __init__(self, config):
        super().__init__(input_size=config.hopf_input_size * config.hopf_num_heads,
                         output_size=config.hopf_asso_dim * config.hopf_num_heads,
                         num_layers=config.mol_encoder_layers,
                         dropout=config.dropout,
                         af_name=config.encoder_af,
                         norm_in=config.norm_input,
                         norm_out=config.norm_asso,
                         )
        # number of layers = self.config.mol_encoder_layers
        # layer-dimension = self.config.hopf_asso_dim
        # activation-function = self.config.af

        self.config = config

    def forward_smiles(self, list_of_smiles: list):
        fp_tensor = self.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(fp_tensor)

    def convert_smiles_to_tensor(self, list_of_smiles):
        fps = convert_smiles_to_fp(list_of_smiles, fp_size=self.config.fp_size,
                                   which=self.config.fingerprint_type, radius=self.config.fp_radius)
        fps_tensor = torch.from_numpy(fps.astype(np.float32)).to(dtype=torch.float).to(self.config.device)
        return fps_tensor

class TemplateEncoder(Encoder):
    """
    Class for Template encoder: can be any class mapping a SMARTS reaction to a vector (preferably differentiable ;)
    """
    def __init__(self, config):
        super().__init__(input_size=config.hopf_input_size * config.hopf_num_heads,
                         output_size=config.hopf_asso_dim * config.hopf_num_heads,
                         num_layers=config.temp_encoder_layers,
                         dropout=config.dropout,
                         af_name=config.encoder_af,
                         norm_in=config.norm_input,
                         norm_out=config.norm_asso,
                         )
        self.config = config
        # number of layers
        # template fingerprint type
        # random template threshold
        # reactant pooling
        if config.temp_encoder_layers == 0:
            print('No Key-Projection = Static Key/Templates')
            assert self.config.hopf_asso_dim == self.config.fp_size
            self.wks = []


class MHN(nn.Module):
    """
    MHN - modern Hopfield network -- for template relevance prediction
    """
    def __init__(self, config=None, layer2weight=0.05, use_template_encoder=True):
        super().__init__()
        if config:
            self.config = config
        else:
            self.config = ModelConfig()
        self.beta = self.config.hopf_beta
        # hopf_num_heads
        self.mol_encoder = FPMolEncoder(self.config)
        if use_template_encoder:
            self.template_encoder = TemplateEncoder(self.config)

        self.W_v = None
        self.layer2weight = layer2weight

        # more MHN layers -- added recursively
        if hasattr(self.config, 'hopf_n_layers'):
            di = self.config.__dict__
            di['hopf_n_layers'] -= 1
            if di['hopf_n_layers'] > 0:
                conf_wo_hopf_nlayers = ModelConfig(**di)
                self.layer = MHN(conf_wo_hopf_nlayers)
                if di['hopf_n_layers'] != 0:
                    self.W_v = nn.Linear(self.config.hopf_asso_dim, self.config.hopf_input_size)
                    torch.nn.init.kaiming_normal_(self.W_v.weight, mode='fan_in', nonlinearity='linear')  # equiv. to LeCun init

        self.softmax = torch.nn.Softmax(dim=1)

        self.lossfunction = nn.CrossEntropyLoss(reduction='none')  # optionally: weight=class_weights
        self.pretrain_lossfunction = nn.BCEWithLogitsLoss(reduction='none')  # optionally: weight=class_weights

        self.lr = self.config.lr

        if self.config.hopf_association_activation is None or (self.config.hopf_association_activation.lower() == 'none'):
            self.af = lambda k: k
        else:
            self.af = getattr(nn, self.config.hopf_association_activation)()

        self.pooling_operation_head = getattr(torch, self.config.pooling_operation_head)

        self.X = None  # templates projected to the Hopfield layer

        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)
        self.to(self.config.device)

    def set_templates(self, template_list, which='rdk', fp_size=None, radius=2, learnable=False, njobs=1, only_templates_in_batch=False):
        self.template_list = template_list.copy()
        if fp_size is None:
            fp_size = self.config.fp_size
        if len(template_list) >= 100000:
            import math
            print('batch-wise template_calculation')
            bs = 30000
            final_temp_emb = torch.zeros((len(template_list), fp_size)).float().to(self.config.device)
            for b in range(math.ceil(len(template_list) / bs)):
                self.template_list = template_list[bs * b:min(bs * (b + 1), len(template_list))]
                templ_emb = self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
                final_temp_emb[bs * b:min(bs * (b + 1), len(template_list))] = torch.from_numpy(templ_emb)
            self.templates = final_temp_emb
            self.template_list = template_list.copy()  # restore the full list
        else:
            self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)

        self.set_templates_recursively()

    def set_templates_recursively(self):
        if 'hopf_n_layers' in self.config.__dict__.keys():
            if self.config.hopf_n_layers > 0:
                self.layer.templates = self.templates
                self.layer.set_templates_recursively()

    def update_template_embedding(self, fp_size=2048, radius=4, which='rdk', learnable=False, njobs=1, only_templates_in_batch=False):
        print('updating template-embedding; (just computing the template-fingerprint and using that)')
        bs = self.config.batch_size

        split_template_list = [str(t).split('>')[0].split('.') for t in self.template_list]
        templates_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)

        split_template_list = [str(t).split('>')[-1].split('.') for t in self.template_list]
        reactants_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)

        template_representation = templates_np - (reactants_np * 0.5)
        if learnable:
            self.templates = torch.nn.Parameter(torch.from_numpy(template_representation).float(), requires_grad=True).to(self.config.device)
            self.register_parameter(name='templates', param=self.templates)
        else:
            if only_templates_in_batch:
                self.templates_np = template_representation
            else:
                self.templates = torch.from_numpy(template_representation).float().to(self.config.device)

        return template_representation


    def np_fp_to_tensor(self, np_fp):
        return torch.from_numpy(np_fp.astype(np.float64)).to(self.config.device).float()

    def masked_loss_fun(self, loss_fun, h_out, ys_batch):
        if isinstance(loss_fun, nn.BCEWithLogitsLoss):
            mask = (ys_batch != -1).float()
            ys_batch = ys_batch.float()
        else:
            mask = (ys_batch.long() != -1).long()
        mask_sum = int(mask.sum().cpu().numpy())
        if mask_sum == 0:
            return 0

        ys_batch = ys_batch * mask

        loss = (loss_fun(h_out, ys_batch * mask) * mask.float()).sum() / mask_sum  # mean over the non -1 entries only
        return loss

    def compute_losses(self, out, ys_batch, head_loss_weight=None):

        if len(ys_batch.shape) == 2:
            if ys_batch.shape[1] == self.config.num_templates:  # it is in pretraining mode
                loss = self.pretrain_lossfunction(out, ys_batch.float()).mean()
            else:
                # legacy from policyNN
                loss = self.lossfunction(out, ys_batch[:, 2]).mean()  # WARNING: HEAD4 Reaction Template is ys[:,2]
        else:
            loss = self.lossfunction(out, ys_batch).mean()
        return loss

    def forward_smiles(self, list_of_smiles, templates=None):
        state_tensor = self.mol_encoder.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(state_tensor, templates=templates)

    def forward(self, m, templates=None):
        """
        m: molecules in the form batch x fingerprint
        templates: None, or newly given templates if not instantiated
        returns logits ranking the templates for each molecule
        """
        bs = m.shape[0]  # batch_size
        if (templates is None) and (self.X is None) and (self.templates is None):
            raise Exception('Either pass in templates, or init templates by running clf.set_templates')
        n_temp = len(templates) if templates is not None else len(self.templates)
        if self.training or (templates is None) or (self.X is not None):
            templates = templates if templates is not None else self.templates
            X = self.template_encoder(templates)
        else:
            X = self.X  # precomputed from the last forward run

        Xi = self.mol_encoder(m)

        Xi = Xi.view(bs, self.config.hopf_num_heads, self.config.hopf_asso_dim)  # [bs, H, A]
        X = X.view(1, n_temp, self.config.hopf_asso_dim, self.config.hopf_num_heads)  # [1, T, A, H]

        XXi = torch.tensordot(Xi, X, dims=[(2, 1), (2, 0)])  # AxA -> [bs, T, H]

        # pooling over heads
        if self.config.hopf_num_heads <= 1:
            XXi = XXi[:, :, 0]  # squeeze the head dimension
        else:
            XXi = self.pooling_operation_head(XXi, dim=2)  # default is max pooling over H -> [bs, T]
            if (self.config.pooling_operation_head == 'max') or (self.config.pooling_operation_head == 'min'):
                XXi = XXi[0]  # max and min also return the indices

        out = self.beta * XXi  # [bs, T] # softmax over dim=1

        self.xinew = self.softmax(out) @ X.view(n_temp, self.config.hopf_asso_dim)  # [bs,T]@[T,emb] -> [bs,emb]

        if self.W_v:
            # call layers recursively
            hopfout = self.W_v(self.xinew)  # [bs,emb]@[emb,hopf_inp] --> [bs, hopf_inp]
            # TODO check if using x_pooled or if not going through mol_encoder again
            hopfout = hopfout + m  # skip-connection
            # give it to the next layer
            out2 = self.layer.forward(hopfout)  # templates=self.W_v(self.K)
            out = out * (1 - self.layer2weight) + out2 * self.layer2weight

        return out

    def train_from_np(self, Xs, targets, ys, is_smiles=False, epochs=2, lr=0.001, bs=32,
                      permute_batches=False, shuffle=True, optimizer=None,
                      use_dataloader=True, verbose=False,
                      wandb=None, scheduler=None, only_templates_in_batch=False):
        """
        Xs in the form sample x states
        targets
        ys in the form sample x [y_h1, y_h2, y_h3, y_h4]
        """
        self.train()
        if optimizer is None:
            try:
                self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr if lr is None else lr)
            except AttributeError as err:
                log.error(f"Can't find optimizer {self.config.optimizer} in torch.optim")
                raise err
            optimizer = self.optimizer

        dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                 fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)

        dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                 batch_sampler=None, num_workers=0, collate_fn=None,
                                                 pin_memory=False, drop_last=False, timeout=0,
                                                 worker_init_fn=None)

        for epoch in range(epochs):  # loop over the dataset multiple times
            running_loss = 0.0
            running_loss_dict = defaultdict(int)
            batch_order = range(0, len(Xs), bs)
            if permute_batches:
                batch_order = np.random.permutation(batch_order)

            for step, s in tqdm(enumerate(dataloader), mininterval=2):
                batch = [b.to(self.config.device, non_blocking=True) for b in s]
                Xs_batch, target_batch, ys_batch = batch

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                out = self.forward(Xs_batch)
                total_loss = self.compute_losses(out, ys_batch)

                loss_dict = {'CE_loss': total_loss}

                total_loss.backward()

                optimizer.step()
                if scheduler:
                    scheduler.step()
                self.steps += 1

                # print statistics
                for k in loss_dict:
                    running_loss_dict[k] += loss_dict[k].item()
                try:
                    running_loss += total_loss.item()
                except AttributeError:  # masked losses may return a plain 0
                    running_loss += 0

                rs = min(100, len(Xs) // bs)  # reporting/logging steps
                if step % rs == (rs - 1):  # log every rs mini-batches
                    if verbose:
                        print('[%d, %5d] loss: %.3f' %
                              (epoch + 1, step + 1, running_loss / rs))
                    self.hist['step'].append(self.steps)
                    self.hist['loss'].append(running_loss / rs)
                    self.hist['training_running_loss'].append(running_loss / rs)

                    [self.hist[k].append(running_loss_dict[k] / rs) for k in running_loss_dict]

                    if wandb:
                        wandb.log({'training_running_loss': running_loss / rs})

                    running_loss = 0.0
                    running_loss_dict = defaultdict(int)

        if verbose:
            print('Finished Training')
        return optimizer

    def evaluate(self, Xs, targets, ys, split='test', is_smiles=False, bs=32, shuffle=False, wandb=None, only_loss=False):
        self.eval()
        y_preds = np.zeros((ys.shape[0], self.config.num_templates), dtype=np.float16)

        loss_metrics = defaultdict(int)
        new_hist = defaultdict(float)
        with torch.no_grad():
            dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                     fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                     batch_sampler=None, num_workers=0, collate_fn=None,
                                                     pin_memory=False, drop_last=False, timeout=0,
                                                     worker_init_fn=None)

            for step, batch in enumerate(dataloader):
                batch = [b.to(self.config.device, non_blocking=True) for b in batch]
                ys_batch = batch[2]

                if hasattr(self, 'templates_np'):
                    # score the templates in chunks to keep memory in check
                    outputs = []
                    for ii in range(10):
                        tlen = len(self.templates_np)
                        i_tlen = tlen // 10
                        templates = torch.from_numpy(self.templates_np[(i_tlen * ii):min(i_tlen * (ii + 1), tlen)]).float().to(self.config.device)
                        outputs.append(self.forward(batch[0], templates=templates))
                    outputs = torch.cat(outputs, dim=1)  # concatenate along the template dimension

                else:
                    outputs = self.forward(batch[0])

                loss = self.compute_losses(outputs, ys_batch, None)

                # not quite right, because every batch might contain a different number of valid samples
                weight = 1 / len(batch[0])  # (currently unused)

                loss_metrics['loss'] += (loss.item())

                if len(ys.shape) > 1:
                    outputs = self.softmax(outputs) if not (ys.shape[1] == self.config.num_templates) else torch.sigmoid(outputs)
                else:
                    outputs = self.softmax(outputs)

                outputs_np = [None if o is None else o.to('cpu').numpy().astype(np.float16) for o in outputs]

                if not only_loss:
                    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
                    topkacc, mrocc = top_k_accuracy(ys_batch, outputs, k=ks, ret_arocc=True, ret_mrocc=False)
                    # mrocc -- median rank of correct choice
                    for k, tkacc in zip(ks, topkacc):
                        # iterative average update
                        new_hist[f't{k}_acc_{split}'] += (tkacc - new_hist[f't{k}_acc_{split}']) / (step + 1)
                        # todo weight by batch-size
                    new_hist[f'meanrank_{split}'] = mrocc

                y_preds[step * bs:min((step + 1) * bs, len(y_preds))] = outputs_np


        new_hist[f'steps_{split}'] = (self.steps)
        new_hist[f'loss_{split}'] = (loss_metrics['loss'] / (step + 1))

        for k in new_hist:
            self.hist[k].append(new_hist[k])

        if wandb:
            wandb.log(new_hist)

        return y_preds

    def save_hist(self, prefix='', postfix=''):
        HIST_PATH = 'data/hist/'
        if not os.path.exists(HIST_PATH):
            os.mkdir(HIST_PATH)
        fn_hist = HIST_PATH + prefix + postfix + '.csv'
        with open(fn_hist, 'w') as fh:
            print(dict(self.hist), file=fh)
        return fn_hist

    def save_model(self, prefix='', postfix='', name_as_conf=False):
        MODEL_PATH = 'data/model/'
        if not os.path.exists(MODEL_PATH):
            os.mkdir(MODEL_PATH)
        if name_as_conf:
            confi_str = str(self.config.__dict__.values()).replace("'", "").replace(': ', '_').replace(', ', ';')
        else:
            confi_str = ''
        model_name = prefix + confi_str + postfix + '.pt'
        torch.save(self.state_dict(), MODEL_PATH + model_name)
        return MODEL_PATH + model_name

    def plot_loss(self):
        plot_loss(self.hist)

    def plot_topk(self, sets=['train', 'valid', 'test'], with_last=2):
        plot_topk(self.hist, sets=sets, with_last=with_last)

    def plot_nte(self, last_cpt=1, dataset='Sm', include_bar=True):
        plot_nte(self.hist, dataset=dataset, last_cpt=last_cpt, include_bar=include_bar)


class SeglerBaseline(MHN):
    """FFNN - only the molecule encoder + an output projection"""
    def __init__(self, config=None):
        config.template_fp_type = 'none'
        config.temp_encoder_layers = 0
        super().__init__(config, use_template_encoder=False)
        self.W_out = torch.nn.Linear(config.hopf_asso_dim, config.num_templates)
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)
        self.to(self.config.device)

    def forward(self, m, templates=None):
        """
        m: molecules in the form batch x fingerprint
        templates: won't be used in this case
        returns logits ranking the templates for each molecule
        """
        bs = m.shape[0]  # batch_size
        Xi = self.mol_encoder(m)
        Xi = self.mol_encoder.af(Xi)  # the encoder does not apply the activation to its last layer
        out = self.W_out(Xi)  # [bs, T] # softmax over dim=1
        return out

class StaticQK(MHN):
    """Static QK baseline - beware to use the same fingerprint for the mol_encoder as for the template_encoder (fp2048 r4 rdk by default)"""
    def __init__(self, config=None):
        if config:
            self.config = config
        else:
            self.config = ModelConfig()
        super().__init__(config)

        self.fp_size = 2048
        self.fingerprint_type = 'rdk'
        self.beta = 1

    def update_template_embedding(self, which='rdk', fp_size=2048, radius=4, learnable=False):
        bs = self.config.batch_size
        split_template_list = [t.split('>>')[0].split('.') for t in self.template_list]
        self.templates = torch.from_numpy(convert_smiles_to_fp(split_template_list,
                                                               is_smarts=True, fp_size=fp_size,
                                                               radius=radius, which=which).max(1)).float().to(self.config.device)


    def forward(self, m, templates=None):
        """
        m: molecule fingerprints in the form batch x fingerprint
        returns similarity-based logits ranking the templates for each molecule
        """
        bs = m.shape[0]  # batch_size

        Xi = m  # [bs, emb]
        X = self.templates  # [T, emb]

        XXi = Xi @ X.T  # [bs, T]

        # normalize by the number of bits set per template
        t_sum = X.sum(1)  # [T]
        t_sum = t_sum.view(1, -1).expand(bs, -1)  # [bs, T]
        XXi = XXi / t_sum

        # beta-scaling is not strictly necessary here because the model is not trained
        out = self.beta * XXi  # [bs, T] # softmax over dim=1
        return out

class Retrosim(StaticQK):
    """Retrosim-like baseline, only for template relevance prediction"""
    def fit_with_train(self, X_fp_train, y_train):
        self.templates = torch.from_numpy(X_fp_train).float().to(self.config.device)
        # train_samples, num_templates
        self.sample2acttemplate = torch.nn.functional.one_hot(torch.from_numpy(y_train), self.config.num_templates).float()
        tmpnorm = self.sample2acttemplate.sum(0)
        tmpnorm[tmpnorm == 0] = 1
        self.sample2acttemplate = (self.sample2acttemplate / tmpnorm).to(self.config.device)  # results in an average after the dot product

    def forward(self, m, templates=None):
        """
        scores by similarity to the training reactions, then maps those scores to the actual templates
        """
        out = super().forward(m, templates=templates)
        # bs, train_samples

        # map out to actual templates
        out = out @ self.sample2acttemplate

        return out
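
A minimal smoke-test sketch for the MHN class (toy templates and molecules chosen purely for
illustration; assumes rdkit, torch, tqdm, and wandb are installed, since model.py imports them):

    import numpy as np
    from mhnreact.model import ModelConfig, MHN
    from mhnreact.molutils import convert_smiles_to_fp

    config = ModelConfig(device='cpu', num_templates=3,
                         mol_encoder_layers=1, temp_encoder_layers=1, encoder_af='ReLU')
    clf = MHN(config)
    clf.set_templates(['[C:1]=[O:2]>>[C:1][O:2]',
                       '[C:1][OH:2]>>[C:1]=[O:2]',
                       '[N:1][C:2]=[O:3]>>[N:1].[C:2]=[O:3]'], which='rdk')

    fps = convert_smiles_to_fp(['CCO', 'CC=O', 'CCNC(C)=O'],
                               fp_size=2048, which='morgan').astype(np.float32)
    ys = np.array([1, 0, 2])
    clf.train_from_np(fps, fps, ys, epochs=1, bs=2)     # targets are unused by MHN.forward
    preds = clf.evaluate(fps, fps, ys, only_loss=True)  # [n_samples, num_templates] scores
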
mhnreact/molutils.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Author: Philipp Seidl, Philipp Renz
4
+ ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
5
+ Johannes Kepler University Linz
6
+ Contact: seidl@ml.jku.at
7
+
8
+ Molutils contains functions that aid in handling molecules or templates
9
+ """
10
+
11
+ import logging
12
+ import re
13
+ import warnings
14
+ from itertools import product, permutations
15
+
16
+ from multiprocessing import Pool
17
+ from tqdm.contrib.concurrent import process_map
18
+ from tqdm.notebook import tqdm
19
+ import swifter
20
+
21
+ import rdkit.RDLogger as rkl
22
+ from rdkit import Chem
23
+ from rdkit.Chem import AllChem
24
+ from rdkit.Chem.rdMolDescriptors import GetMorganFingerprint
25
+ from rdkit.Chem.rdmolops import FastFindRings
26
+ from rdkit.Chem.rdMHFPFingerprint import MHFPEncoder
27
+
28
+ from scipy import sparse
29
+ from sklearn.feature_extraction import DictVectorizer
30
+
31
+ import warnings
32
+ import rdkit.RDLogger as rkl
33
+ import numpy as np
34
+
35
+ log = logging.getLogger(__name__)
36
+ logger = rkl.logger()
37
+
38
+ def remove_attom_mapping(smiles):
39
+ """ removes a number after a ':' """
40
+ return re.sub(r':\d+', '', str(smiles))
41
+
42
+
43
+ def canonicalize_smi(smi, is_smarts=False, remove_atom_mapping=True):
44
+ r"""
45
+ Canonicalize SMARTS from https://github.com/rxn4chemistry/rxnfp/blob/master/rxnfp/tokenization.py#L249
46
+ """
47
+ mol = Chem.MolFromSmarts(smi)
48
+ if not mol:
49
+ raise ValueError("Molecule not canonicalizable")
50
+ if remove_atom_mapping:
51
+ for atom in mol.GetAtoms():
52
+ if atom.HasProp("molAtomMapNumber"):
53
+ atom.ClearProp("molAtomMapNumber")
54
+ return Chem.MolToSmiles(mol)
55
+
56
+
57
+ def canonicalize_template(smarts):
58
+ smarts = str(smarts)
59
+ # remove attom-mapping
60
+ #smarts = remove_attom_mapping(smarts)
61
+
62
+ # order the list of smiles + canonicalize it
63
+ results = []
64
+ for part in smarts.split('>>'):
65
+ a = part.split('.')
66
+ a = [canonicalize_smi(x, is_smarts=True, remove_atom_mapping=True) for x in a]
67
+ #a = [remove_attom_mapping(x) for x in a]
68
+ a.sort()
69
+ results.append( '.'.join(a) )
70
+ return '>>'.join(results)
71
+
72
+ def ebv2np(ebv):
73
+ """Explicit bit vector returned by rdkit to numpy array. """
74
+ return np.frombuffer(bytes(ebv.ToBitString(), 'utf-8'), 'u1') - ord('0')
75
+
76
+ def smiles2morgan(smiles, radius=2):
77
+ """ computes ecfp from smiles """
78
+ return GetMorganFingerprint(smiles, radius)
79
+
80
+
81
+ def getFingerprint(smiles, fp_size=4096, radius=2, is_smarts=False, which='morgan', sanitize=True):
82
+ """maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc"""
83
+ if isinstance(smiles, list):
84
+ return np.array([getFingerprint(smi, fp_size, radius, is_smarts, which) for smi in smiles]).max(0) # max pooling if it's list of lists
85
+
86
+ if is_smarts:
87
+ mol = Chem.MolFromSmarts(str(smiles), mergeHs=False)
88
+ #mol.UpdatePropertyCache() #Correcting valence info
89
+ #FastFindRings(mol) #Providing ring info
90
+ else:
91
+ mol = Chem.MolFromSmiles(str(smiles), sanitize=False)
92
+
93
+ if mol is None:
94
+ msg = f"{smiles} couldn't be converted to a fingerprint using 0's instead"
95
+ logger.warning(msg)
96
+ #warnings.warn(msg)
97
+ return np.zeros(fp_size).astype(np.bool)
98
+
99
+ if sanitize:
100
+ faild_op = Chem.SanitizeMol(mol, catchErrors=True)
101
+ FastFindRings(mol) #Providing ring info
102
+
103
+ mol.UpdatePropertyCache(strict=False) #Correcting valence info # important operation
104
+
105
+ def mol2np(mol, which, fp_size):
106
+ is_dict = False
107
+ if which=='morgan':
108
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=fp_size, useFeatures=False, useChirality=True)
109
+ elif which=='rdk':
110
+ fp = Chem.RDKFingerprint(mol, fpSize=fp_size, maxPath=6)
111
+ elif which=='rdkc':
112
+ # https://greglandrum.github.io/rdkit-blog/similarity/reference/2021/05/26/similarity-threshold-observations1.html
113
+ # -- maxPath 6 found to be better for retrieval in databases
114
+ fp = AllChem.UnfoldedRDKFingerprintCountBased(mol, maxPath=6).GetNonzeroElements()
115
+ is_dict = True
116
+ elif which=='morganc':
117
+ fp = AllChem.GetMorganFingerprint(mol, radius, useChirality=True, useBondTypes=True, useFeatures=True, useCounts=True).GetNonzeroElements()
118
+ is_dict = True
119
+ elif which=='topologicaltorsion':
120
+ fp = AllChem.GetTopologicalTorsionFingerprint(mol).GetNonzeroElements()
121
+ is_dict = True
122
+ elif which=='maccs':
123
+ fp = AllChem.GetMACCSKeysFingerprint(mol)
124
+ elif which=='erg':
125
+ v = AllChem.GetErGFingerprint(mol)
126
+ fp = {idx:v[idx] for idx in np.nonzero(v)[0]}
127
+ is_dict = True
128
+ elif which=='atompair':
129
+ fp = AllChem.GetAtomPairFingerprint(mol).GetNonzeroElements()
130
+ is_dict = True
131
+ elif which=='pattern':
132
+ fp = Chem.PatternFingerprint(mol, fpSize=fp_size)
133
+ elif which=='ecfp4':
134
+ # roughly equivalent to ECFP4
135
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=fp_size, useFeatures=False, useChirality=True)
136
+ elif which=='layered':
137
+ fp = AllChem.LayeredFingerprint(mol, fpSize=fp_size, maxPath=7)
138
+ elif which=='mhfp':
139
+ #TODO check if one can avoid instantiating the MHFP encoder
140
+ fp = MHFPEncoder().EncodeMol(mol, radius=radius, rings=True, isomeric=False, kekulize=False, min_radius=1)
141
+ fp = {f:1 for f in fp}
142
+ is_dict = True
143
+ elif not (type(which)==str):
144
+ fp = which(mol)
145
+
146
+ if is_dict:
147
+ nd = np.zeros(fp_size)
148
+ for k in fp:
149
+ nk = k%fp_size #remainder
150
+ #print(nk, k, fp_size)
151
+ #3160 36322170 3730
152
+ #print(nd[nk], fp[k])
153
+ if nd[nk]!=0:
154
+ #print('c',end='')
155
+ nd[nk] = nd[nk]+fp[k] #pooling colisions
156
+ nd[nk] = fp[k]
157
+
158
+ return nd #np.log(1+nd) # discussion with segler
159
+
160
+ return ebv2np(fp)
161
+
162
+ """ + for folding * for concat """
163
+ cc_symb = '*'
164
+ if ('+' in which) or (cc_symb in which):
165
+ concat = False
166
+ split_sym = '+'
167
+ if cc_symb in which:
168
+ concat=True
169
+ split_sym = '*'
170
+
171
+ np_fp = np.zeros(fp_size)
172
+
173
+ remaining_fps = (which.count(split_sym)+1)
174
+ fp_length_remain = fp_size
175
+
176
+ for fp_type in which.split(split_sym):
177
+ if concat:
178
+ fpp = mol2np(mol, fp_type, fp_length_remain//remaining_fps)
179
+ np_fp[(fp_size-fp_length_remain):(fp_size-fp_length_remain+len(fpp))] += fpp
180
+ fp_length_remain -= len(fpp)
181
+ remaining_fps -=1
182
+ else:
183
+ try:
184
+ fpp = mol2np(mol, fp_type, fp_size)
185
+ np_fp[:len(fpp)] += fpp
186
+ except:
187
+ pass
188
+ #print(fp_type,end='')
189
+
190
+ return np.log(1 + np_fp)
191
+ else:
192
+ return mol2np(mol, which, fp_size)
193
+
194
+
195
+ def _getFingerprint(inp):
196
+ return getFingerprint(inp[0], inp[1], inp[2], inp[3], inp[4])
197
+
198
+
199
+ def disable_rdkit_logging():
200
+ """
201
+ Disables RDKit whiny logging.
202
+ """
203
+ import rdkit.rdBase as rkrb
204
+ import rdkit.RDLogger as rkl
205
+ logger.setLevel(rkl.ERROR)
206
+ rkrb.DisableLog('rdApp.error')
207
+
208
+
209
+ def convert_smiles_to_fp(list_of_smiles, fp_size=2048, is_smarts=False, which='morgan', radius=2, njobs=1, verbose=False):
210
+ """
211
+ list of smiles can be list of lists, than the resulting array will pe badded to the max list len
212
+ which: morgan, rdk, ecfp4, or object
213
+ NOTE: morgan or ecfp4 throws error for is_smarts
214
+ """
215
+
216
+ inp = [(smi, fp_size, radius, is_smarts, which) for smi in list_of_smiles]
217
+ #print(inp)
218
+ if verbose: print(f'starting pool with {njobs} workers')
219
+ if njobs>1:
220
+ #with Pool(njobs) as pool:
221
+ # fps = pool.map(_getFingerprint, inp)
222
+ fps = process_map(_getFingerprint, inp, max_workers=njobs, chunksize=1, mininterval=0)
223
+ else:
224
+ fps = [getFingerprint(smi, fp_size=fp_size, radius=radius, is_smarts=is_smarts, which=which) for smi in list_of_smiles]
225
+ return np.array(fps)
226
+
227
+
228
+ def convert_smartes_to_fp(list_of_smarts, fp_size=2048):
229
+ if isinstance(list_of_smarts, np.ndarray):
230
+ list_of_smarts = list_of_smarts.tolist()
231
+ if isinstance(list_of_smarts, list):
232
+ if isinstance(list_of_smarts[0], list):
233
+ pad = len(max(list_of_smarts, key=len))
234
+ fps = [[getTemplateFingerprint(smarts, fp_size=fp_size) for smarts in sample]
235
+ + [np.zeros(fp_size, dtype=np.bool)] * (pad - len(sample)) # zero padding
236
+ for sample in list_of_smarts]
237
+ else:
238
+ fps = [[getTemplateFingerprint(smarts, fp_size=fp_size) for smarts in list_of_smarts]]
239
+ return np.asarray(fps)
240
+
241
+
242
+ def get_reactants_from_smarts(smarts):
243
+ """
244
+ from a (forward-)reaction given as a smart, only returns the reactants (not e.g. solvents or reagents)
245
+ returns list of smiles or empty list
246
+ """
247
+ from rdkit.Chem import RDConfig
248
+ import sys
249
+ sys.path.append(RDConfig.RDContribDir)
250
+ from RxnRoleAssignment import identifyReactants
251
+ try:
252
+ rdk_reaction = AllChem.ReactionFromSmarts(smarts)
253
+ rx_idx = identifyReactants.identifyReactants(rdk_reaction)[0][0]
254
+ except ValueError:
255
+ return []
256
+ # TODO what if a product is recognized as a reactanat.. is that possible??
257
+ return [Chem.MolToSmiles(rdk_reaction.GetReactants()[i]) for i in rx_idx]
258
+
259
+
260
+ def smarts2rdkfp(smart, fp_size=2048):
+     mol = Chem.MolFromSmarts(str(smart))
+     if mol is None: return np.zeros(fp_size).astype(np.bool)
+     return AllChem.RDKFingerprint(mol, fpSize=fp_size)
+     # fp = np.asarray(fp).astype(np.bool)  # takes ages =/
+
+
+ def smiles2rdkfp(smiles, fp_size=2048):
+     mol = Chem.MolFromSmiles(str(smiles))
+     if mol is None: return np.zeros(fp_size).astype(np.bool)
+     return AllChem.RDKFingerprint(mol, fpSize=fp_size)
+
+
+ def mol2morganfp(mol, radius=2, fp_size=2048):
+     try:
+         Chem.SanitizeMol(mol)  # due to error --> see https://sourceforge.net/p/rdkit/mailman/message/34828604/
+     except:
+         pass
+         # print(mol)
+         # return np.zeros(fp_size).astype(np.bool)
+     return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=fp_size)
+
+
+ def smarts2morganfp(smart, fp_size=2048, radius=2):
+     mol = Chem.MolFromSmarts(str(smart))
+     if mol is None: return np.zeros(fp_size).astype(np.bool)
+     return mol2morganfp(mol, radius=radius, fp_size=fp_size)
+
+
+ def smiles2morganfp(smiles, fp_size=2048, radius=2):
+     mol = Chem.MolFromSmiles(str(smiles))
+     if mol is None: return np.zeros(fp_size).astype(np.bool)
+     return mol2morganfp(mol, radius=radius, fp_size=fp_size)
+
+
+ def smarts2fp(smart, which='morgan', fp_size=2048, radius=2):
+     if which == 'rdk':
+         return smarts2rdkfp(smart, fp_size=fp_size)
+     else:
+         return smarts2morganfp(smart, fp_size=fp_size, radius=radius)
+
+
+ def smiles2fp(smiles, which='morgan', fp_size=2048, radius=2):
+     if which == 'rdk':
+         return smiles2rdkfp(smiles, fp_size=fp_size)
+     else:
+         return smiles2morganfp(smiles, fp_size=fp_size, radius=radius)
+
+
+ class FP_featurizer():
+     "FP_featurizer: Fingerprint featurizer"
+     def __init__(self,
+                  fp_types=['MACCS', 'Morgan2CBF', 'Morgan6CBF', 'ErG', 'AtomPair', 'TopologicalTorsion', 'RDK', 'ECFP6'],
+                  max_features=4096, counts=True, log_scale=True, folding=None, collision_pooling='max'):
+         self.v = DictVectorizer(sparse=True, dtype=np.uint16)
+         self.max_features = max_features
+         self.idx_col = None
+         self.counts = counts
+         self.fp_types = [fp_types] if isinstance(fp_types, str) else fp_types
+         self.log_scale = log_scale  # from discussion with segler
+         self.folding = folding
+         self.collision_pooling = collision_pooling
+
+     def compute_fp_list(self, smiles_list, is_smarts=False):
+         fp_list = []
+         for smiles in smiles_list:
+             try:
+                 if isinstance(smiles, list):
+                     smiles = smiles[0]
+                 if is_smarts:
+                     mol = Chem.MolFromSmarts(smiles)
+                 else:
+                     mol = Chem.MolFromSmiles(smiles)  # TODO small hack only applicable here!!!
+                 fp_dict = {}
+                 for fp_type in self.fp_types:
+                     fp_dict.update(fingerprintTypes[fp_type](mol))  # returns a dict
+                 fp_list.append(fp_dict)
+             except:
+                 fp_list.append({})
+         return fp_list
+
+     def fit(self, x_train, is_smarts=False):
+         fp_list = self.compute_fp_list(x_train, is_smarts=is_smarts)
+         Xraw = self.v.fit_transform(fp_list)
+         # compute the column-wise variance of a csr_matrix: E[x**2] - E[x]**2
+         axis = 0
+         Xraw_sqrd = Xraw.copy()
+         Xraw_sqrd.data **= 2
+         var_col = Xraw_sqrd.mean(axis) - np.square(Xraw.mean(axis))
+         #idx_col = (-np.array((Xraw>0).var(axis=0)).argpartition(self.max_features))
+         #idx_col = np.array((Xraw>0).sum(axis=0)>=self.min_fragm_occur).flatten()
+         # keep the max_features highest-variance columns
+         self.idx_col = (-np.array(var_col)).flatten().argpartition(min(self.max_features, Xraw.shape[1]-1))[:min(self.max_features, Xraw.shape[1])]
+         print(f'from {var_col.shape[1]} to {len(self.idx_col)}')
+         return self.scale(Xraw[:, self.idx_col].toarray())
+
+     def transform(self, x_test, is_smarts=False):
+         fp_list = self.compute_fp_list(x_test, is_smarts=is_smarts)
+         X_raw = self.v.transform(fp_list)
+         return self.scale(X_raw[:, self.idx_col].toarray())
+
+     def scale(self, X):
+         if self.log_scale:
+             return np.log(1 + X)
+         return X
+
+     def save(self, path='data/fpfeat.pkl'):
+         import pickle
+         with open(path, 'wb') as output:
+             pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)
+
+     def load(self, path='data/fpfeat.pkl'):
+         import pickle
+         with open(path, 'rb') as inp:
+             loaded = pickle.load(inp)
+         # rebinding `self` would not affect the caller's object, so copy the state over
+         self.__dict__.update(loaded.__dict__)
+         return self
+
+
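+ # Usage sketch (hypothetical SMILES): `fit` keeps the `max_features`
+ # highest-variance fragment features of the training set; `transform` reuses
+ # exactly those columns:
+ #   feat = FP_featurizer(fp_types=['MACCS', 'Morgan2CBF'], max_features=1024)
+ #   X_train = feat.fit(['CCO', 'CCN', 'c1ccccc1O'])
+ #   X_test = feat.transform(['CCCl'])
+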
+ def getTemplateFingerprintOnBits(smarts, fp_size=2048):
+     rxn = AllChem.ReactionFromSmarts(str(smarts))
+     # construct a structural fingerprint for a ChemicalReaction by concatenating the reactant fingerprint and the product fingerprint
+     return (AllChem.CreateStructuralFingerprintForReaction(rxn)).GetOnBits()
+
+
+ def calc_template_fingerprint_group_mapping(template_list, fp_size, save_path=''):
+     """
+     Calculates the mapping from old template indices to new (fingerprint-grouped) indices.
+     Returns a tuple of a numpy array with the mapping and the indices to take.
+     """
+     templ_df = pd.DataFrame()
+     templ_df['smarts'] = template_list
+     templ_df['templ_emb'] = templ_df['smarts'].swifter.apply(lambda smarts: str(list(getTemplateFingerprintOnBits(smarts, fp_size))))
+     templ_df['idx_orig'] = [ii for ii in range(len(templ_df))]
+
+     grouped_templ = templ_df.groupby('templ_emb')
+     grouped_templ = grouped_templ.min().sort_values('idx_orig')
+     grouped_templ['new_idx'] = range(len(grouped_templ))
+
+     new_templ_df = templ_df.join(grouped_templ, on='templ_emb', how='right', lsuffix='_l', rsuffix='_r').sort_values('idx_orig_l')
+
+     map_orig2new = new_templ_df['new_idx'].values
+     take_those_indices_from_orig = grouped_templ.idx_orig.values
+     if save_path != '':
+         suffix_maporig2new = '_maporig2new.npy'
+         suffix_takethose = '_tfp_take_idxs.npy'
+         np.save(f'{save_path}{suffix_maporig2new}', map_orig2new, allow_pickle=False)
+         np.save(f'{save_path}{suffix_takethose}', take_those_indices_from_orig, allow_pickle=False)
+     return (map_orig2new, take_those_indices_from_orig)
+
+
+ class ECFC_featurizer():
+     def __init__(self, radius=6, min_fragm_occur=50, useChirality=True, useFeatures=False):
+         self.v = DictVectorizer(sparse=True, dtype=np.uint16)
+         self.min_fragm_occur = min_fragm_occur
+         self.idx_col = None
+         self.radius = radius
+         self.useChirality = useChirality
+         self.useFeatures = useFeatures
+
+     def compute_fp_list(self, smiles_list):
+         fp_list = []
+         for smiles in smiles_list:
+             try:
+                 if isinstance(smiles, list):
+                     smiles = smiles[0]
+                 mol = Chem.MolFromSmiles(smiles)  # TODO small hack only applicable here!!!
+                 fp_list.append(AllChem.GetMorganFingerprint(mol, self.radius, useChirality=self.useChirality,
+                                                             useFeatures=self.useFeatures).GetNonzeroElements())  # returns a dict
+             except:
+                 fp_list.append({})
+         return fp_list
+
+     def fit(self, x_train):
+         fp_list = self.compute_fp_list(x_train)
+         Xraw = self.v.fit_transform(fp_list)
+         # keep only fragments that occur in at least min_fragm_occur training molecules
+         idx_col = np.array((Xraw > 0).sum(axis=0) >= self.min_fragm_occur).flatten()
+         self.idx_col = idx_col
+         return Xraw[:, self.idx_col].toarray()
+
+     def transform(self, x_test):
+         fp_list = self.compute_fp_list(x_test)
+         X_raw = self.v.transform(fp_list)
+         return X_raw[:, self.idx_col].toarray()
+
+
+ def ecfp2dict(mol, radius=3):
+     # SECFP (SMILES Extended Connectivity Fingerprint)
+     from mhfp.encoder import MHFPEncoder
+     v = MHFPEncoder.secfp_from_mol(mol, length=4068, radius=radius, rings=True, kekulize=True, min_radius=1)
+     return {f'ECFP{radius*2}_'+str(idx): 1 for idx in np.nonzero(v)[0]}
+
+
+ def erg2dict(mol):
+     v = AllChem.GetErGFingerprint(mol)
+     return {'erg'+str(idx): v[idx] for idx in np.nonzero(v)[0]}
+
+
+ def morgan2dict(mol, radius=2, useChirality=True, useBondTypes=True, useFeatures=True, useCounts=True):
+     mdic = AllChem.GetMorganFingerprint(mol, radius=radius, useChirality=useChirality, useBondTypes=useBondTypes,
+                                         useFeatures=useFeatures, useCounts=useCounts).GetNonzeroElements()
+     return {f'm{radius}{useChirality}{useBondTypes}{useFeatures}'+str(kk): mdic[kk] for kk in mdic}
+
+
+ def atompair2dict(mol):
+     mdic = AllChem.GetAtomPairFingerprint(mol).GetNonzeroElements()
+     return {'ap'+str(kk): mdic[kk] for kk in mdic}
+
+
+ def tt2dict(mol):
+     mdic = AllChem.GetTopologicalTorsionFingerprint(mol).GetNonzeroElements()
+     return {'tt'+str(kk): mdic[kk] for kk in mdic}
+
+
+ def rdk2dict(mol):
+     mdic = AllChem.UnfoldedRDKFingerprintCountBased(mol).GetNonzeroElements()
+     return {'rdk'+str(kk): mdic[kk] for kk in mdic}
+
+
+ def pattern2dict(mol):
+     mdic = AllChem.PatternFingerprint(mol, fpSize=16384).GetOnBits()
+     return {'pt'+str(kk): 1 for kk in mdic}
+
+
+ fingerprintTypes = {
+     'MACCS': lambda k: {'MCCS'+str(ob): 1 for ob in AllChem.GetMACCSKeysFingerprint(k).GetOnBits()},
+     'Morgan2CBF': lambda mol: morgan2dict(mol, 2, True, True, True, True),
+     'Morgan4CBF': lambda mol: morgan2dict(mol, 4, True, True, True, True),
+     'Morgan6CBF': lambda mol: morgan2dict(mol, 6, True, True, True, True),
+     'ErG': erg2dict,
+     'AtomPair': atompair2dict,
+     'TopologicalTorsion': tt2dict,
+     #'RDK': lambda k: {'MCCS'+str(ob): 1 for ob in AllChem.RDKFingerprint(k).GetOnBits()},
+     'RDK': rdk2dict,
+     'ECFP6': lambda mol: ecfp2dict(mol, radius=3),
+     'Pattern': pattern2dict,
+ }
+
+
+ def smarts2appl(product_smarts, template_product_smarts, fpsize=2048, v=False, use_tqdm=False, njobs=1, nsplits=1):
+     """Takes a list of product SMILES (misnamed in the code) and a list of product sides
+     of templates, and calculates which templates are applicable to which product.
+     This is basically a substructure search; faster versions may exist.
+
+     Args:
+         product_smarts: list of SMILES of the molecules to check.
+         template_product_smarts: list of substructures to check.
+         fpsize: fingerprint size used in the screening step.
+         v: if True, progress information is printed.
+         use_tqdm: if True, a progress bar is displayed, but this slows down the computation.
+         njobs: how many worker processes to run in parallel.
+         nsplits: how many splits to make along the product_smarts list; useful to avoid
+             memory explosion.
+     Returns: tuple of arrays (i, j) indicating that product i contains substructure j.
+     """
+     if v: print("Calculating template molecules")
+     template_mols = [Chem.MolFromSmarts(s) for s in template_product_smarts]
+     if v: print("Calculating template fingerprints")
+     template_ebvs = [Chem.PatternFingerprint(m, fpSize=fpsize) for m in template_mols]
+     if v: print(f'Building template ints: [{len(template_mols)}, {fpsize}]')
+     template_ints = [int(e.ToBitString(), base=2) for e in template_ebvs]
+     del template_ebvs
+
+     if njobs == 1 and nsplits == 1:
+         return _smarts2appl(product_smarts, template_product_smarts, template_ints, fpsize, v, use_tqdm)
+     elif nsplits == 1:
+         nsplits = njobs
+
+     # split products into batches
+     product_splits = np.array_split(np.array(product_smarts), nsplits)
+     ioffsets = [0] + list(np.cumsum([p.shape[0] for p in product_splits[:-1]]))
+     inps = [(ps, template_product_smarts, template_ints, fpsize, v, use_tqdm, ioff, 0) for ps, ioff in zip(product_splits, ioffsets)]
+
+     if v: print("Creating workers")
+     #results = process_map(__smarts2appl, inps, max_workers=njobs, chunksize=1)
+     with Pool(njobs) as pool:
+         results = pool.starmap(_smarts2appl, inps)
+     imatch = np.concatenate([r[0] for r in results])
+     jmatch = np.concatenate([r[1] for r in results])
+     return imatch, jmatch
+
+
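+ # Usage sketch (tiny hypothetical inputs; returned indices refer to positions
+ # in the two input lists):
+ #   i, j = smarts2appl(['CCO', 'CC(=O)O'], ['[OH]', '[C](=[O])[OH]'])
+ #   # e.g. product 0 matches template 0; product 1 matches templates 0 and 1
+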
+ def __smarts2appl(inp):
+     return _smarts2appl(*inp)
+
+
+ def _smarts2appl(product_smarts, template_product_smarts, template_ints, fpsize=2048, v=False, use_tqdm=True, ioffset=0, joffset=0):
+     """See smarts2appl for a description"""
+     if v: print("Calculating product molecules")
+     product_mols = [Chem.MolFromSmiles(s) for s in product_smarts]
+     if v: print("Calculating product fingerprints")
+     product_ebvs = [Chem.PatternFingerprint(m, fpSize=fpsize) for m in product_mols]
+     if v: print(f'Building product ints: [{len(product_mols)}, {fpsize}]')
+     # This loads each fingerprint into a python integer on which we can use bitwise operations.
+     product_ints = [int(e.ToBitString(), base=2) for e in product_ebvs]
+     del product_ebvs
+
+     # product_mols = {i: m for i,m in enumerate(product_mols)}
+
+     if v: print('Checking symbolically')
+     # buffer for template molecules; these are handed over as SMARTS since they are slow to pickle
+     template_mols = {}
+
+     # create iterator and add a progress bar if use_tqdm is True
+     iterator = product(enumerate(product_ints), enumerate(template_ints))
+     if use_tqdm:
+         nelem = len(product_ints) * len(template_ints)
+         iterator = tqdm(iterator, total=nelem, miniters=1_000_000)
+
+     imatch = []
+     jmatch = []
+     for (i, p_int), (j, t_int) in iterator:
+         if (p_int & t_int) == t_int:  # fingerprint-based screen
+             p = product_mols[i]
+             t = template_mols.get(j, False)
+             if not t:
+                 t = Chem.MolFromSmarts(template_product_smarts[j])
+                 template_mols[j] = t
+             if p.HasSubstructMatch(t):
+                 imatch.append(i)
+                 jmatch.append(j)
+     if v: print("Finished loop")
+     return np.array(imatch) + ioffset, np.array(jmatch) + joffset
+
+
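+ # Screening bit-trick sketch: a template can only match if every bit set in its
+ # pattern fingerprint is also set in the product's, i.e. t's bits are a subset
+ # of p's:
+ #   p, t = 0b1011, 0b0011
+ #   (p & t) == t  # -> True, so run the (expensive) substructure match
+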
+ def extract_from_reaction(reaction, radius=1, verbose=False):
+     """adapted from the rdchiral package"""
+     from rdchiral.template_extractor import mols_from_smiles_list, replace_deuterated, get_fragments_for_changed_atoms, expand_changed_atom_tags, canonicalize_transform, get_changed_atoms
+     reactants = mols_from_smiles_list(replace_deuterated(reaction['reactants']).split('.'))
+     products = mols_from_smiles_list(replace_deuterated(reaction['products']).split('.'))
+
+     # if rdkit can't understand a molecule, return
+     if None in reactants: return {'reaction_id': reaction['_id']}
+     if None in products: return {'reaction_id': reaction['_id']}
+
+     # try to sanitize molecules
+     try:
+         #for i in range(len(reactants)):
+         #    reactants[i] = AllChem.RemoveHs(reactants[i]) # *might* not be safe
+         #for i in range(len(products)):
+         #    products[i] = AllChem.RemoveHs(products[i]) # *might* not be safe
+         #[Chem.SanitizeMol(mol) for mol in reactants + products] # redundant w/ RemoveHs
+         for mol in reactants + products:
+             Chem.SanitizeMol(mol, catchErrors=True)
+             FastFindRings(mol)  # providing ring info
+             mol.UpdatePropertyCache(strict=False)  # correcting valence info # important operation
+     except Exception as e:
+         # can't sanitize -> skip
+         print(e)
+         print('Could not load SMILES or sanitize')
+         print('ID: {}'.format(reaction['_id']))
+         return {'reaction_id': reaction['_id']}
+
+     are_unmapped_product_atoms = False
+     extra_reactant_fragment = ''
+     for product in products:
+         prod_atoms = product.GetAtoms()
+         if sum([a.HasProp('molAtomMapNumber') for a in prod_atoms]) < len(prod_atoms):
+             if verbose: print('Not all product atoms have atom mapping')
+             if verbose: print('ID: {}'.format(reaction['_id']))
+             are_unmapped_product_atoms = True
+
+     if are_unmapped_product_atoms:  # add fragment to template
+         for product in products:
+             prod_atoms = product.GetAtoms()
+             # get unmapped atoms
+             unmapped_ids = [
+                 a.GetIdx() for a in prod_atoms if not a.HasProp('molAtomMapNumber')
+             ]
+             if len(unmapped_ids) > MAXIMUM_NUMBER_UNMAPPED_PRODUCT_ATOMS:
+                 # skip this example - too many unmapped product atoms!
+                 return
+             # define new atom symbols for the fragment with atom maps, generalizing fully
+             atom_symbols = ['[{}]'.format(a.GetSymbol()) for a in prod_atoms]
+             # and bond symbols...
+             bond_symbols = ['~' for b in product.GetBonds()]
+             if unmapped_ids:
+                 extra_reactant_fragment += AllChem.MolFragmentToSmiles(
+                     product, unmapped_ids,
+                     allHsExplicit=False, isomericSmiles=USE_STEREOCHEMISTRY,
+                     atomSymbols=atom_symbols, bondSymbols=bond_symbols
+                 ) + '.'
+         if extra_reactant_fragment:
+             extra_reactant_fragment = extra_reactant_fragment[:-1]
+             if verbose: print(' extra reactant fragment: {}'.format(extra_reactant_fragment))
+
+         # consolidate repeated fragments (stoichiometry)
+         extra_reactant_fragment = '.'.join(sorted(list(set(extra_reactant_fragment.split('.')))))
+
+     if None in reactants + products:
+         print('Could not parse all molecules in reaction, skipping')
+         print('ID: {}'.format(reaction['_id']))
+         return {'reaction_id': reaction['_id']}
+
+     # calculate changed atoms
+     changed_atoms, changed_atom_tags, err = get_changed_atoms(reactants, products)
+     if err:
+         if verbose:
+             print('Could not get changed atoms')
+             print('ID: {}'.format(reaction['_id']))
+         return
+     if not changed_atom_tags:
+         if verbose:
+             print('No atoms changed?')
+             print('ID: {}'.format(reaction['_id']))
+         # print('Reaction SMILES: {}'.format(example_doc['RXN_SMILES']))
+         return {'reaction_id': reaction['_id']}
+
+     try:
+         # get fragments for reactants
+         reactant_fragments, intra_only, dimer_only = get_fragments_for_changed_atoms(reactants, changed_atom_tags,
+             radius=radius, expansion=[], category='reactants')
+         # get fragments for products
+         # (WITHOUT matching groups but WITH the addition of reactant fragments)
+         product_fragments, _, _ = get_fragments_for_changed_atoms(products, changed_atom_tags,
+             radius=radius-1, expansion=expand_changed_atom_tags(changed_atom_tags, reactant_fragments),
+             category='products')
+     except ValueError as e:
+         if verbose:
+             print(e)
+             print(reaction['_id'])
+         return {'reaction_id': reaction['_id']}
+
+     # put together and canonicalize (as best as possible)
+     rxn_string = '{}>>{}'.format(reactant_fragments, product_fragments)
+     rxn_canonical = canonicalize_transform(rxn_string)
+     # change from inter-molecular to intra-molecular
+     rxn_canonical_split = rxn_canonical.split('>>')
+     rxn_canonical = rxn_canonical_split[0][1:-1].replace(').(', '.') + \
+         '>>' + rxn_canonical_split[1][1:-1].replace(').(', '.')
+
+     reactants_string = rxn_canonical.split('>>')[0]
+     products_string = rxn_canonical.split('>>')[1]
+
+     retro_canonical = products_string + '>>' + reactants_string
+
+     # load into RDKit
+     rxn = AllChem.ReactionFromSmarts(retro_canonical)
+     # edited
+     #if rxn.Validate()[1] != 0:
+     #    print('Could not validate reaction successfully')
+     #    print('ID: {}'.format(reaction['_id']))
+     #    print('retro_canonical: {}'.format(retro_canonical))
+     #    if VERBOSE: raw_input('Pausing...')
+     #    return {'reaction_id': reaction['_id']}
+     n_warning, n_errors = rxn.Validate()
+     if n_errors:
+         # round-tripping through SMILES resolves some errors
+         rxn = AllChem.ReactionFromSmarts(AllChem.ReactionToSmiles(rxn))
+         n_warning, n_errors = rxn.Validate()
+
+     template = {
+         'products': products_string,
+         'reactants': reactants_string,
+         'reaction_smarts': retro_canonical,
+         'intra_only': intra_only,
+         'dimer_only': dimer_only,
+         'reaction_id': reaction['_id'],
+         'necessary_reagent': extra_reactant_fragment,
+         'num_errors': n_errors,
+         'num_warnings': n_warning,
+     }
+
+     return template
+
+
+ def extract_template(rxn_smi, radius=1):
+     if isinstance(rxn_smi, str):
+         reaction = {
+             'reactants': rxn_smi.split('>')[0],
+             'products': rxn_smi.split('>')[-1],
+             'id': rxn_smi,
+             '_id': rxn_smi
+         }
+     else:
+         reaction = rxn_smi
+     try:
+         res = extract_from_reaction(reaction, radius=radius)
+         return res['reaction_smarts']  # returns a retro-template
+     except:
+         msg = f'failed to extract template from "{rxn_smi}"'
+         log.warning(msg)
+         return None
+
+
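+ # Usage sketch (assumes a fully atom-mapped reaction SMILES; the exact SMARTS
+ # returned may vary with the rdchiral version):
+ #   rxn = '[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][OH:6]>>[CH3:1][C:2](=[O:3])[O:6][CH3:5]'
+ #   extract_template(rxn, radius=1)  # -> retro-template of the form 'product>>reactants'
+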
+ def getTemplateFingerprint(smarts, fp_size=4096):
+     """ CreateStructuralFingerprintForReaction """
+     if isinstance(smarts, (list,)):
+         return np.vstack([getTemplateFingerprint(sm) for sm in smarts])
+
+     rxn = AllChem.ReactionFromSmarts(str(smarts))
+     if rxn is None:
+         msg = f"{smarts} couldn't be converted to a reaction; using 0's instead"
+         log.warning(msg)
+         #warnings.warn(msg)
+         return np.zeros(fp_size).astype(np.bool)
+
+     return np.array(list(AllChem.CreateStructuralFingerprintForReaction(rxn)), dtype=np.bool)
mhnreact/plotutils.py ADDED
@@ -0,0 +1,158 @@
+ # -*- coding: utf-8 -*-
+ """
+ Author: Philipp Seidl
+ ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
+ Johannes Kepler University Linz
+ Contact: seidl@ml.jku.at
+
+ Plot utils
+ """
+
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+
+ plt.style.use('default')
+
+
19
+ def normal_approx_interval(p_hat, n, z=1.96):
20
+ """ approximating the distribution of error about a binomially-distributed observation, {\hat {p)), with a normal distribution
21
+ z = 1.96 --> alpha =0.05
22
+ z = 1 --> std
23
+ https://www.wikiwand.com/en/Binomial_proportion_confidence_interval"""
24
+ return z*((p_hat*(1-p_hat))/n)**(1/2)
25
+
26
+
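+ # Worked example: p_hat=0.8 over n=100 samples at z=1.96 gives
+ # 1.96 * sqrt(0.8 * 0.2 / 100) ~= 0.078, i.e. roughly 0.8 +/- 0.08.
+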
+ our_colors = {
+     "lightblue": (  0/255, 132/255, 187/255),
+     "red":       (217/255,  92/255,  76/255),
+     "blue":      (  0/255, 132/255, 187/255),
+     "green":     ( 91/255, 167/255,  85/255),
+     "yellow":    (241/255, 188/255,  63/255),
+     "cyan":      ( 79/255, 176/255, 191/255),
+     "grey":      (125/255, 130/255, 140/255),
+     "lightgreen":(191/255, 206/255,  82/255),
+     "violett":   (174/255,  97/255, 157/255),
+ }
+
+
+ def plot_std(p_hats, n_samples, z=1.96, color=our_colors['red'], alpha=0.2, xs=None):
+     p_hats = np.array(p_hats)
+     stds = np.array([normal_approx_interval(p_hats[ii], n_samples[ii], z=z) for ii in range(len(p_hats))])
+     xs = range(len(p_hats)) if xs is None else xs
+     plt.fill_between(xs, p_hats - stds, p_hats + stds, color=color, alpha=alpha)
+     #plt.errorbar(range(13), asdf, [normal_approx_interval(asdf[ii], n_samples[ii], z=z) for ii in range(len(asdf))],
+     #             c=our_colors['red'], linestyle='None', marker='.', ecolor=our_colors['red'])
+
+
+ def plot_loss(hist):
+     plt.plot(hist['step'], hist['loss'])
+     plt.plot(hist['steps_valid'], np.array(hist['loss_valid']))
+     plt.legend(['train', 'validation'])
+     plt.xlabel('update-step')
+     plt.ylabel('loss (categorical cross-entropy)')
+
+
+ def plot_topk(hist, sets=['train', 'valid', 'test'], with_last=2):
+     ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
+     baseline_val_res = {1: 0.4061, 10: 0.6827, 50: 0.7883, 100: 0.8400}
+     plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')
+     for i in range(1, with_last):
+         for s in sets:
+             plt.plot(ks, [hist[f't{k}_acc_{s}'][-i] for k in ks], '.--', alpha=1/i)
+     plt.xlabel('top-k')
+     plt.ylabel('Accuracy')
+     plt.legend(sets)
+     plt.title('Hopfield-NN')
+     plt.ylim([-0.02, 1])
+
+
+ def plot_nte(hist, dataset='Sm', last_cpt=1, include_bar=True, model_legend='MHN (ours)',
+              draw_std=True, z=1.96, n_samples=None, group_by_template_fp=False, schwaller_hist=None, fortunato_hist=None):  # z=1.96 for a 95% CI
+     markers = ['.']*4  # ['1','2','3','4'] # ['8','P','p','*']
+     lw = 2
+     ms = 8
+     k = 100
+     ntes = range(13)
+     if dataset == 'Sm':
+         basel_values = [0.        , 0.38424785, 0.66807858, 0.7916149 , 0.9051132 ,
+                         0.92531258, 0.87295875, 0.94865587, 0.91830721, 0.95993717,
+                         0.97215858, 0.9896713 , 0.99917817]  # old basel_values = [0.0, 0.3882, 0.674, 0.7925, 0.9023, 0.9272, 0.874, 0.947, 0.9185, 0.959, 0.9717, 0.9927, 1.0]
+         pretr_values = [0.08439423, 0.70743412, 0.85555528, 0.95200267, 0.96513376,
+                         0.96976397, 0.98373613, 0.99960286, 0.98683919, 0.96684724,
+                         0.95907246, 0.9839079 , 0.98683919]  # old [0.094, 0.711, 0.8584, 0.952, 0.9683, 0.9717, 0.988, 1.0, 1.0, 0.984, 0.9717, 1.0, 1.0]
+         staticQK = [0.2096, 0.1992, 0.2291, 0.1787, 0.2301, 0.1753, 0.2142, 0.2693, 0.2651, 0.1786, 0.2834, 0.5366, 0.6636]
+         if group_by_template_fp:
+             staticQK = [0.2651, 0.2617, 0.261 , 0.2181, 0.2622, 0.2393, 0.2157, 0.2184, 0.2   , 0.225 , 0.2039, 0.4568, 0.5293]
+     if dataset == 'Lg':
+         pretr_values = [0.03410448, 0.65397054, 0.7254572 , 0.78969294, 0.81329924,
+                         0.8651173 , 0.86775655, 0.8593128 , 0.88184124, 0.87764794,
+                         0.89734215, 0.93328846, 0.99531597]
+         basel_values = [0.        , 0.62478044, 0.68784314, 0.75089511, 0.77044644,
+                         0.81229423, 0.82968149, 0.82965544, 0.83778338, 0.83049176,
+                         0.8662873 , 0.92308414, 1.00042408]
+         #staticQK = [0.03638, 0.0339 , 0.03732, 0.03506, 0.03717, 0.0331 , 0.03003, 0.03613, 0.0304 , 0.02109, 0.0297 , 0.02632, 0.02217] # on 90k templates
+         staticQK = [0.006416, 0.00686, 0.00616, 0.00825, 0.005085, 0.006718, 0.01041, 0.0015335, 0.006668, 0.004673, 0.001706, 0.02551, 0.04074]
+     if dataset == 'Golden':
+         staticQK = [0]*13
+         pretr_values = [0]*13
+         basel_values = [0]*13
+
+     if schwaller_hist:
+         midx = np.argmin(schwaller_hist['loss_valid'])
+         basel_values = ([schwaller_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']])
+     if fortunato_hist:
+         midx = np.argmin(fortunato_hist['loss_valid'])
+         pretr_values = ([fortunato_hist[f't100_acc_nte_{k}'][midx] for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, '>10', '>49']])
+
+     #hand_val = [0.0 , 0.4, 0.68, 0.79, 0.89, 0.91, 0.86, 0.9, 0.88, 0.9, 0.93]
+
+     if include_bar:
+         if dataset == 'Sm':
+             if n_samples is None:
+                 n_samples = [610, 1699, 287, 180, 143, 105, 70, 48, 124, 86, 68, 2539, 1648]
+                 if group_by_template_fp:
+                     n_samples = [460, 993, 433, 243, 183, 117, 102, 87, 110, 80, 103, 3048, 2203]
+         if dataset == 'Lg':
+             if n_samples is None:
+                 n_samples = [18861, 32226, 4220, 2546, 1573, 1191, 865, 652, 1350, 642, 586, 11638, 4958]  # new
+                 if group_by_template_fp:
+                     n_samples = [13923, 17709, 7637, 4322, 2936, 2137, 1586, 1260, 1272, 1044, 829, 21695, 10559]
+                 #[5169, 15904, 2814, 1853, 1238, 966, 766, 609, 1316, 664, 640, 30699, 21471]
+                 #[13424, 17246, 7681, 4332, 2844, 2129, 1698, 1269, 1336, 1067, 833, 22491, 11202] # grouped fp
+         plt.bar(range(11+2), np.array(n_samples)/sum(n_samples[:-1]), alpha=0.4, color=our_colors['grey'])
+
+     xti = [*[str(i) for i in range(11)], '>10', '>49']
+     asdf = []
+     for nte in xti:
+         try:
+             asdf.append(hist[f't{k}_acc_nte_{nte}'][-last_cpt])
+         except:
+             asdf.append(None)
+
+     plt.plot(range(13), asdf, f'{markers[3]}--', markersize=ms, c=our_colors['red'], linewidth=lw, alpha=1)
+     plt.plot(ntes, pretr_values, f'{markers[1]}--', c=our_colors['green'],
+              linewidth=lw, alpha=1, markersize=ms)  # old [0.08, 0.7, 0.85, 0.9, 0.91, 0.95, 0.98, 0.97, 0.98, 1, 1]
+     plt.plot(ntes, basel_values, f'{markers[0]}--', linewidth=lw,
+              c=our_colors['blue'], markersize=ms, alpha=1)
+     plt.plot(range(len(staticQK)), staticQK, f'{markers[2]}--', markersize=ms, c=our_colors['yellow'], linewidth=lw, alpha=1)
+
+     plt.title(f'USPTO-{dataset}')
+     plt.xlabel('number of training examples')
+     plt.ylabel('top-100 test-accuracy')
+     plt.legend([model_legend, 'Fortunato et al.', 'FNN baseline', "FPM baseline",  # static${\\xi X}: \\dfrac{|{\\xi} \\cap {X}|}{|{X}|}$
+                 'test sample proportion'])
+
+     if draw_std:
+         alpha = 0.2
+         plot_std(asdf, n_samples, z=z, color=our_colors['red'], alpha=alpha)
+         plot_std(pretr_values, n_samples, z=z, color=our_colors['green'], alpha=alpha)
+         plot_std(basel_values, n_samples, z=z, color=our_colors['blue'], alpha=alpha)
+         plot_std(staticQK, n_samples, z=z, color=our_colors['yellow'], alpha=alpha)
+
+     plt.xticks(range(13), xti)
+     plt.yticks(np.arange(0, 1.05, 0.1))
+     plt.grid('on', alpha=0.3)
mhnreact/retroeval.py ADDED
@@ -0,0 +1,240 @@
+ # -*- coding: utf-8 -*-
+ """
+ Author: Philipp Seidl, Philipp Renz
+ ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
+ Johannes Kepler University Linz
+ Contact: seidl@ml.jku.at
+
+ Evaluation functions for single-step retrosynthesis
+ """
+ import sys
+
+ import rdchiral
+ from rdchiral.main import rdchiralRun, rdchiralReaction, rdchiralReactants
+ import hashlib
+ from rdkit import Chem
+
+ import torch
+ import numpy as np
+ import pandas as pd
+ from collections import defaultdict
+ from copy import deepcopy
+ from glob import glob
+ import os
+ import pickle
+
+ from multiprocessing import Pool
+ import logging
+
+ #import timeout_decorator
+
+
+ def _cont_hash(fn):
+     with open(fn, 'rb') as f:
+         return hashlib.md5(f.read()).hexdigest()
+
+ def load_templates_only(path, cache_dir='/tmp'):
+     arg_hash_base = 'load_templates_only' + path
+     arg_hash = hashlib.md5(arg_hash_base.encode()).hexdigest()
+     matches = glob(os.path.join(cache_dir, arg_hash + '*'))
+
+     if len(matches) > 1:
+         raise RuntimeError('Too many matches')
+     elif len(matches) == 1:
+         fn = matches[0]
+         content_hash = _cont_hash(path)
+         content_hash_file = os.path.basename(fn).split('_')[1].split('.')[0]
+         if content_hash_file == content_hash:
+             with open(fn, 'rb') as f:
+                 return pickle.load(f)
+
+     df = pd.read_json(path)
+     template_dict = {}
+     for row in range(len(df)):
+         template_dict[df.iloc[row]['index']] = df.iloc[row].reaction_smarts
+
+     # cache the file
+     content_hash = _cont_hash(path)
+     fn = os.path.join(cache_dir, f"{arg_hash}_{content_hash}.p")
+     with open(fn, 'wb') as f:
+         pickle.dump(template_dict, f)
+     return template_dict
+
+ def load_templates_v2(path, get_complete_df=False):
+     if get_complete_df:
+         df = pd.read_json(path)
+         return df
+
+     return load_templates_only(path)
+
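+ # Usage sketch (hypothetical path; assumes a template JSON with 'index' and
+ # 'reaction_smarts' columns, as read by load_templates_only):
+ #   template_dict = load_templates_v2('data/processed/uspto_sm_templates.df.json')
+ #   template_dict[0]  # -> retro-template SMARTS for template index 0
+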
+ def canonicalize_reactants(smiles, can_steps=2):
+     if can_steps == 0:
+         return smiles
+
+     mol = Chem.MolFromSmiles(smiles)
+     for a in mol.GetAtoms():
+         a.ClearProp('molAtomMapNumber')
+
+     smiles = Chem.MolToSmiles(mol, True)
+     if can_steps == 1:
+         return smiles
+
+     smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), True)
+     if can_steps == 2:
+         return smiles
+
+     raise ValueError("Invalid can_steps")
+
+
+ def load_test_set(fn):
+     df = pd.read_csv(fn, index_col=0)
+     test = df[df.dataset == 'test']
+
+     test_product_smarts = list(test.prod_smiles)  # we make predictions for these
+     for s in test_product_smarts:
+         assert len(s.split('.')) == 1
+         assert '>' not in s
+
+     test_reactants = []  # we want to predict these
+     for rs in list(test.rxn_smiles):
+         rs = rs.split('>>')
+         assert len(rs) == 2
+         reactants_ori, products = rs
+         reactants = reactants_ori.split('.')
+         products = products.split('.')
+         assert len(reactants) >= 1
+         assert len(products) == 1
+
+         test_reactants.append(reactants_ori)
+
+     return test_product_smarts, test_reactants
+
+
+ #@timeout_decorator.timeout(1, use_signals=False)
+ def time_out_rdchiralRun(temp, prod_rct, combine_enantiomers=False):
+     rxn = rdchiralReaction(temp)
+     return rdchiralRun(rxn, prod_rct, combine_enantiomers=combine_enantiomers)
+
+ def _run_templates_rdchiral(prod_appl):
+     prod, applicable_templates = prod_appl
+     prod_rct = rdchiralReactants(prod)  # preprocess reactants with rdchiral
+
+     results = {}
+     for idx, temp in applicable_templates:
+         temp = str(temp)
+         try:
+             results[(idx, temp)] = time_out_rdchiralRun(temp, prod_rct, combine_enantiomers=False)
+         except:
+             pass
+
+     return results
+
+ def _run_templates_rdchiral_original(prod_appl):
+     prod, applicable_templates = prod_appl
+     prod_rct = rdchiralReactants(prod)  # preprocess reactants with rdchiral
+
+     results = {}
+     rxn_cache = {}
+     for idx, temp in applicable_templates:
+         temp = str(temp)
+         if temp in rxn_cache:
+             rxn = rxn_cache[temp]
+         else:
+             try:
+                 rxn = rdchiralReaction(temp)
+                 rxn_cache[temp] = rxn
+             except:
+                 rxn = None
+                 rxn_cache[temp] = None
+                 msg = temp + ' error converting to rdchiralReaction'
+                 logging.debug(msg)
+         if rxn is None:
+             continue
+         try:
+             res = rdchiralRun(rxn, prod_rct, combine_enantiomers=False)
+             results[(idx, temp)] = res
+         except:
+             pass
+
+     return results
+
+ def run_templates(test_product_smarts, templates, appl, njobs=32, cache_dir='/tmp'):
+     appl_dict = defaultdict(list)
+     for i, j in zip(*appl):
+         appl_dict[i].append(j)
+
+     prod_appl_list = []
+     for prod_idx, prod in enumerate(test_product_smarts):
+         applicable_templates = [(idx, templates[idx]) for idx in appl_dict[prod_idx]]
+         prod_appl_list.append((prod, applicable_templates))
+
+     arg_hash = hashlib.md5(pickle.dumps(prod_appl_list)).hexdigest()
+     cache_file = os.path.join(cache_dir, arg_hash + '.p')
+
+     if os.path.isfile(cache_file):
+         with open(cache_file, 'rb') as f:
+             print('loading results from file', f)
+             all_results = pickle.load(f)
+         # find /tmp -type f \( ! -user root \) -atime +3 -delete
+         # to delete the tmp files that haven't been accessed for 3 days
+     else:
+         #with Pool(njobs) as pool:
+         #    all_results = pool.map(_run_templates_rdchiral, prod_appl_list)
+         from tqdm.contrib.concurrent import process_map
+         all_results = process_map(_run_templates_rdchiral, prod_appl_list, max_workers=njobs, chunksize=1, mininterval=2)
+
+         #with open(cache_file, 'wb') as f:
+         #    print('saving applicable_templates to cache', cache_file)
+         #    pickle.dump(all_results, f)
+
+     prod_idx_reactants = []
+     prod_temp_reactants = []
+
+     for prod, idx_temp_reactants in zip(test_product_smarts, all_results):
+         prod_idx_reactants.append({idx_temp[0]: r for idx_temp, r in idx_temp_reactants.items()})
+         prod_temp_reactants.append({idx_temp[1]: r for idx_temp, r in idx_temp_reactants.items()})
+
+     return prod_idx_reactants, prod_temp_reactants
+
+ def sort_by_template(template_scores, prod_idx_reactants):
+     sorted_results = []
+     for i, predictions in enumerate(prod_idx_reactants):
+         score_row = template_scores[i]
+         appl_idxs = np.array(list(predictions.keys()))
+         if len(appl_idxs) == 0:
+             sorted_results.append([])
+             continue
+         scores = score_row[appl_idxs]
+         sorted_idxs = appl_idxs[np.argsort(scores)][::-1]
+         sorted_reactants = [predictions[idx] for idx in sorted_idxs]
+         sorted_results.append(sorted_reactants)
+     return sorted_results
+
+ def no_dup_same_order(l):
+     return list({r: 0 for r in l}.keys())
+
+ def flatten_per_product(sorted_results, remove_duplicates=True):
+     flat_results = [sum((r for r in row), []) for row in sorted_results]
+     if remove_duplicates:
+         flat_results = [no_dup_same_order(row) for row in flat_results]
+     return flat_results
+
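+ # `no_dup_same_order` relies on dicts preserving insertion order (Python 3.7+):
+ #   no_dup_same_order(['b', 'a', 'b'])  # -> ['b', 'a']
+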
+ def topkaccuracy(test_reactants, predicted_reactants, ks=[1], ret_ranks=False):
+     ks = [k if k is not None else 1e10 for k in ks]
+     ranks = []
+     for true, pred in zip(test_reactants, predicted_reactants):
+         try:
+             rank = pred.index(true) + 1
+         except ValueError:
+             rank = 1e15
+         ranks.append(rank)
+     ranks = np.array(ranks)
+     if ret_ranks:
+         return ranks
+
+     return [np.mean(ranks <= k) for k in ks]
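+
+ # Usage sketch (hypothetical inputs): the true reactants 'CCO' rank 2nd in the
+ # prediction list, so top-1 accuracy is 0.0 and top-2 accuracy is 1.0:
+ #   topkaccuracy(['CCO'], [['CCN', 'CCO']], ks=[1, 2])  # -> [0.0, 1.0]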
mhnreact/train.py ADDED
@@ -0,0 +1,804 @@
+ # -*- coding: utf-8 -*-
+ """
+ Author: Philipp Seidl
+ ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
+ Johannes Kepler University Linz
+ Contact: seidl@ml.jku.at
+
+ Training
+ """
+
+ from .utils import str2bool, lgamma, multinom_gk, top_k_accuracy
+ from .data import load_templates, load_dataset_from_csv, load_USPTO
+ from .model import ModelConfig, MHN, StaticQK, SeglerBaseline, Retrosim
+ from .molutils import convert_smiles_to_fp, FP_featurizer, smarts2appl, getTemplateFingerprint, disable_rdkit_logging
+ from collections import defaultdict
+ import argparse
+ import os
+ import numpy as np
+ import pandas as pd
+ import datetime
+ import sys
+ from time import time
+ import matplotlib.pyplot as plt
+ import torch
+ import multiprocessing
+ import warnings
+ from joblib import Memory
+
+ cachedir = 'data/cache/'
+ memory = Memory(cachedir, verbose=0, bytes_limit=80e9)
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train MHNreact.",
+                                      epilog="--", prog="Train")
+     parser.add_argument('-f', type=str)
+     parser.add_argument('--model_type', type=str, default='mhn',
+                         help="Model-type: choose from 'segler', 'fortunato', 'mhn' or 'staticQK', default: 'mhn'")
+     parser.add_argument("--exp_name", type=str, default='', help="experiment name (added as postfix to the file-names)")
+     parser.add_argument("-d", "--dataset_type", type=str, default='sm',
+                         help="input dataset: 'sm' for Schneider-USPTO-50k, 'lg' for USPTO-large, 'golden', or use '--csv_path' to specify an input file, default: 'sm'")
+     parser.add_argument("--csv_path", default=None, type=str, help="path to a preprocessed training file incl. split columns, default: None")
+     parser.add_argument("--split_col", default='split', type=str, help="split column of the csv, default: 'split'")
+     parser.add_argument("--input_col", default='prod_smiles', type=str, help="input column of the csv, default: 'prod_smiles'")
+     parser.add_argument("--reactants_col", default='reactants_can', type=str, help="reactants column of the csv, default: 'reactants_can'")
+
+     parser.add_argument("--fp_type", type=str, default='morganc',
+                         help="Fingerprint type for the input only!: e.g. 'morganc', 'rdk', 'ECFP', 'ECFC', 'MxFP', 'Morgan2CBF', or a combination of fingerprints with '+' for folding (pooling) and '*' for concatenation, e.g. maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp, default: 'morganc'")
+     parser.add_argument("--template_fp_type", type=str, default='rdk',
+                         help="Fingerprint type for the template fingerprint, default: 'rdk'")
+     parser.add_argument("--device", type=str, default='best',
+                         help="Device to run the model on, preferably 'cuda:0', default: 'best' (takes the gpu with most RAM)")
+     parser.add_argument("--fp_size", type=int, default=4096,
+                         help="fingerprint size used for templates as well as for inputs, default: 4096")
+     parser.add_argument("--fp_radius", type=int, default=2, help="fingerprint radius (if applicable to the fingerprint type), default: 2")
+     parser.add_argument("--epochs", type=int, default=10, help="number of epochs, default: 10")
+
+     parser.add_argument("--pretrain_epochs", type=int, default=0,
+                         help="applicability-matrix pretraining epochs if applicable (e.g. for the fortunato model_type), default: 0")
+     parser.add_argument("--save_model", type=str2bool, default=False, help="save the model, default: False")
+
+     parser.add_argument("--dropout", type=float, default=0.2, help="dropout rate for encoders, default: 0.2")
+     parser.add_argument("--lr", type=float, default=5e-4, help="learning rate, default: 5e-4")
+     parser.add_argument("--hopf_beta", type=float, default=0.05, help="hopfield beta parameter, default: 0.05")
+     parser.add_argument("--hopf_asso_dim", type=int, default=512, help="association dimension, default: 512")
+     parser.add_argument("--hopf_num_heads", type=int, default=1, help="hopfield number of heads, default: 1")
+     parser.add_argument("--hopf_association_activation", type=str, default='None',
+                         help="hopfield association activation function, recommended: 'Tanh' or 'None'; others: 'ReLU', 'SeLU', 'GeLU' (see torch.nn for more), default: 'None'")
+
+     parser.add_argument("--norm_input", default=True, type=str2bool,
+                         help="input normalization, default: True")
+     parser.add_argument("--norm_asso", default=True, type=str2bool,
+                         help="association normalization, default: True")
+
+     # additional experimental hyperparams
+     parser.add_argument("--hopf_n_layers", default=1, type=int, help="number of hopfield layers, default: 1")
+     parser.add_argument("--mol_encoder_layers", default=1, type=int, help="number of molecule-encoder layers, default: 1")
+     parser.add_argument("--temp_encoder_layers", default=1, type=int, help="number of template-encoder layers, default: 1")
+     parser.add_argument("--encoder_af", default='ReLU', type=str,
+                         help="encoder-NN intermediate activation function (before the association_activation function), default: 'ReLU'")
+     parser.add_argument("--hopf_pooling_operation_head", default='mean', type=str, help="pooling operation over heads ('max', 'min', 'mean', ...), default: 'mean'")
+
+     parser.add_argument("--splitting_scheme", default=None, type=str, help="splitting scheme for non-csv input, default: None; other options: 'class-freq', 'random'")
+
+     parser.add_argument("--concat_rand_template_thresh", default=-1, type=int, help="concatenates a random vector to the template fingerprint for all templates with num_training_samples > this threshold; -1 (default) means deactivated")
+     parser.add_argument("--repl_quotient", default=10, type=float, help="only if --concat_rand_template_thresh >= 0 - quotient of how much of the template embedding should be replaced by the random vector (default: 10)")
+     parser.add_argument("--verbose", default=False, type=str2bool, help="if verbose, will print out more stuff, default: False")
+     parser.add_argument("--batch_size", default=128, type=int, help="training batch size, default: 128")
+     parser.add_argument("--eval_every_n_epochs", default=1, type=int, help="evaluate every n epochs (evaluation is costly for USPTO-Lg), default: 1")
+     parser.add_argument("--save_preds", default=False, type=str2bool, help="save predictions for the test split at the end of training, default: False")
+     parser.add_argument("--wandb", default=False, type=str2bool, help="log to wandb (login required), default: False")
+     parser.add_argument("--seed", default=None, type=int, help="seed your run to make it reproducible, default: None")
+
+     parser.add_argument("--template_fp_type2", default=None, type=str, help="experimental template_fp_type for layer 2, default: None")
+     parser.add_argument("--layer2weight", default=0.2, type=float, help="weight of p for hopfield layer 2, default: 0.2")
+
+     parser.add_argument("--reactant_pooling", default='max', type=str, help="reactant pooling operation over the template fingerprint, default: 'max'; options: 'min', 'mean', 'lgamma'")
+
+     parser.add_argument("--ssretroeval", default=False, type=str2bool, help="single-step retrosynthesis evaluation, default: False")
+     parser.add_argument("--addval2train", default=False, type=str2bool, help="adds the validation set to the training set, default: False")
+     parser.add_argument("--njobs", default=-1, type=int, help="number of jobs, default: -1 -> uses all available cores")
+
+     parser.add_argument("--eval_only_loss", default=False, type=str2bool, help="if only the loss should be evaluated (computing top-k accuracy may be time consuming), default: False")
+     parser.add_argument("--only_templates_in_batch", default=False, type=str2bool, help="while training, only forward templates that are in the batch, default: False")
+
+     parser.add_argument("--plot_res", default=False, type=str2bool, help="plot results for USPTO-sm/lg, default: False")
+     args = parser.parse_args()
+
+     if args.njobs == -1:
+         args.njobs = int(multiprocessing.cpu_count())
+
+     if args.device == 'best':
+         from .utils import get_best_gpu
+         try:
+             args.device = get_best_gpu()
+         except:
+             print("couldn't get the best gpu, using cpu instead")
+             args.device = 'cpu'
+
+     # some sanity checks on the model type
+     if (args.model_type == 'segler') & (args.pretrain_epochs >= 1):
+         print('changing model type to fortunato because pretrain_epochs>0')
+         args.model_type = 'fortunato'
+     if ((args.model_type == 'staticQK') or (args.model_type == 'retrosim')) & (args.epochs > 1):
+         print('changing epochs to 1 (StaticQK is not learnable ;)')
+         args.epochs = 1
+     if args.template_fp_type != args.fp_type:
+         print('fp_type must be the same as template_fp_type --> setting template_fp_type to fp_type')
+         args.template_fp_type = args.fp_type
+     if args.save_model & (args.fp_type == 'MxFP'):
+         warnings.warn('Currently MxFP is not recommended when saving the model parameters (the fragment dicts would need to be saved or computed again, which is currently not implemented)')
+
+     return args
+
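+ # Example invocation (sketch; flag names as defined in parse_args above):
+ #   python -m mhnreact.train --model_type mhn --dataset_type sm --epochs 10 \
+ #       --fp_size 4096 --fp_type morganc --device best
+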
+ @memory.cache(ignore=['njobs'])
+ def featurize_smiles(X, fp_type='morgan', fp_size=4096, fp_radius=2, njobs=1, verbose=False):
+     X_fp = {}
+
+     if fp_type in ['MxFP', 'MACCS', 'Morgan2CBF', 'Morgan4CBF', 'Morgan6CBF', 'ErG', 'AtomPair', 'TopologicalTorsion', 'RDK']:
+         print('computing', fp_type)
+         if fp_type == 'MxFP':
+             fp_types = ['MACCS', 'Morgan2CBF', 'Morgan4CBF', 'Morgan6CBF', 'ErG', 'AtomPair', 'TopologicalTorsion', 'RDK']
+         else:
+             fp_types = [fp_type]
+
+         remaining = int(fp_size)
+         for fp_type in fp_types:
+             print(fp_type, end=' ')
+             feat = FP_featurizer(fp_types=fp_type,
+                                  max_features=(fp_size // len(fp_types)) if (fp_type != fp_types[-1]) else remaining)
+             X_fp[f'train_{fp_type}'] = feat.fit(X['train'])
+             X_fp[f'valid_{fp_type}'] = feat.transform(X['valid'])
+             X_fp[f'test_{fp_type}'] = feat.transform(X['test'])
+
+             remaining -= X_fp[f'train_{fp_type}'].shape[1]
+         #X_fp['train'].shape, X_fp['test'].shape
+         X_fp['train'] = np.hstack([X_fp[f'train_{fp_type}'] for fp_type in fp_types])
+         X_fp['valid'] = np.hstack([X_fp[f'valid_{fp_type}'] for fp_type in fp_types])
+         X_fp['test'] = np.hstack([X_fp[f'test_{fp_type}'] for fp_type in fp_types])
+
+     else:  # fp_type in ['rdk','morgan','ecfp4','pattern','morganc','rdkc']:
+         if verbose: print('computing', fp_type, 'folded')
+         for split in X.keys():
+             X_fp[split] = convert_smiles_to_fp(X[split], fp_size=fp_size, which=fp_type, radius=fp_radius, njobs=njobs, verbose=verbose)
+
+     return X_fp
+
+
+ def compute_template_fp(fp_len=2048, reactant_pooling='max', do_log=True):
+     """Pre-computes the template fingerprint (uses the globals template_list, temp_part_to_fp and templates_fp)."""
+     # combine product- and reactant-side fingerprints into one fingerprint
+     comb_template_fp = np.zeros((max(template_list.keys())+1, fp_len if reactant_pooling != 'concat' else fp_len*6))
+     for i in template_list:
+         tpl = template_list[i]
+         try:
+             pr, rea = str(tpl).split('>>')
+             idxx = temp_part_to_fp[pr]
+             prod_fp = templates_fp['fp'][idxx]
+         except:
+             print('err', tpl, end='\r')
+             prod_fp = np.zeros(fp_len)
+
+         rea_fp = templates_fp['fp'][[temp_part_to_fp[r] for r in str(rea).split('.')]]  # fingerprints of all reactant parts
+
+         if reactant_pooling == 'only_product':
+             rea_fp = np.zeros(fp_len)
+         if reactant_pooling == 'max':
+             rea_fp = np.log(1 + rea_fp.max(0))
+         elif reactant_pooling == 'mean':
+             rea_fp = np.log(1 + rea_fp.mean(0))
+         elif reactant_pooling == 'sum':
+             rea_fp = np.log(1 + rea_fp.sum(0))
+         elif reactant_pooling == 'lgamma':
+             rea_fp = multinom_gk(rea_fp, axis=0)
+         elif reactant_pooling == 'concat':
+             rs = str(rea).split('.')
+             rs.sort()
+             for ii, r in enumerate(rs):
+                 idx = temp_part_to_fp[r]
+                 rea_fp = templates_fp['fp'][idx]
+                 comb_template_fp[i, (fp_len*(ii+1)):(fp_len*(ii+2))] = np.log(1 + rea_fp)
+
+         comb_template_fp[i, :prod_fp.shape[0]] = np.log(1 + prod_fp)  # - rea_fp*0.5
+         if reactant_pooling != 'concat':
+             #comb_template_fp[i] = multinom_gk(np.stack([np.log(1+prod_fp), rea_fp]))
+             #comb_template_fp[i, fp_len:] = rea_fp
+             comb_template_fp[i, :rea_fp.shape[0]] = comb_template_fp[i, :rea_fp.shape[0]] - rea_fp*0.5
+
+     return comb_template_fp
+
+
+ def set_up_model(args, template_list=None):
+     hpn_config = ModelConfig(num_templates=int(max(template_list.keys()))+1,
+                              #len(template_list.values()), #env.num_templates, #
+                              dropout=args.dropout,
+                              fingerprint_type=args.fp_type,
+                              template_fp_type=args.template_fp_type,
+                              fp_size=args.fp_size,
+                              fp_radius=args.fp_radius,
+                              device=args.device,
+                              lr=args.lr,
+                              hopf_beta=args.hopf_beta,  # 1/(128**0.5), # 1/(2048**0.5),
+                              hopf_input_size=args.fp_size,
+                              hopf_output_size=None,
+                              hopf_num_heads=args.hopf_num_heads,
+                              hopf_asso_dim=args.hopf_asso_dim,
+                              hopf_association_activation=args.hopf_association_activation,  # Tanh works better; or ReLU, SELU, GELU
+                              norm_input=args.norm_input,
+                              norm_asso=args.norm_asso,
+                              hopf_n_layers=args.hopf_n_layers,
+                              mol_encoder_layers=args.mol_encoder_layers,
+                              temp_encoder_layers=args.temp_encoder_layers,
+                              encoder_af=args.encoder_af,
+                              hopf_pooling_operation_head=args.hopf_pooling_operation_head,
+                              batch_size=args.batch_size,
+                              )
+     print(hpn_config.__dict__)
+
+     if args.model_type == 'segler':  # baseline
+         clf = SeglerBaseline(hpn_config)
+     elif args.model_type == 'mhn':
+         clf = MHN(hpn_config, layer2weight=args.layer2weight)
+     elif args.model_type == 'fortunato':  # pretraining with applicability-matrix
+         clf = SeglerBaseline(hpn_config)
+     elif args.model_type == 'staticQK':
+         clf = StaticQK(hpn_config)
+     elif args.model_type == 'retrosim':
+         clf = Retrosim(hpn_config)
+     else:
+         raise NotImplementedError
+
+     return clf, hpn_config
+
257
+ def set_up_template_encoder(args, clf, label_to_n_train_samples=None, template_list=None):
258
+
259
+ if isinstance(clf, SeglerBaseline):
260
+ clf.templates = []
261
+ elif args.model_type=='staticQK':
262
+ clf.template_list = list(template_list.values())
263
+ clf.update_template_embedding(which=args.template_fp_type, fp_size=args.fp_size, radius=args.fp_radius, njobs=args.njobs)
264
+ elif args.model_type=='retrosim':
265
+ #clf.template_list = list(X['train'].values())
266
+ clf.fit_with_train(X_fp['train'], y['train'])
267
+ else:
268
+ import hashlib
269
+ PATH = './data/cache/'
270
+ if not os.path.exists(PATH):
271
+ os.mkdir(PATH)
272
+ fn_templ_emb = f'{PATH}templ_emb_{args.fp_size}_{args.template_fp_type}{args.fp_radius}_{len(template_list)}_{int(hashlib.sha512((str(template_list)).encode()).hexdigest(), 16)}.npy'
273
+ if (os.path.exists(fn_templ_emb)): # load the template embedding
274
+ print(f'loading tfp from file {fn_templ_emb}')
275
+ templ_emb = np.load(fn_templ_emb)
276
+ # !!! beware of different fingerprint types
277
+ clf.template_list = list(template_list.values())
278
+
279
+ if args.only_templates_in_batch:
280
+ clf.templates_np = templ_emb
281
+ clf.templates = None
282
+ else:
283
+ clf.templates = torch.from_numpy(templ_emb).float().to(clf.config.device)
284
+ else:
285
+ if args.template_fp_type=='MxFP':
286
+ clf.template_list = list(template_list.values())
287
+ clf.templates = torch.from_numpy(comb_template_fp).float().to(clf.config.device)
288
+ clf.set_templates_recursively()
289
+ elif args.template_fp_type=='Tfidf':
290
+ clf.template_list = list(template_list.values())
291
+ clf.templates = torch.from_numpy(tfidf_template_fp).float().to(clf.config.device)
292
+ clf.set_templates_recursively()
293
+ elif args.template_fp_type=='random':
294
+ clf.template_list = list(template_list.values())
295
+ clf.templates = torch.from_numpy(np.random.rand(len(template_list),args.fp_size)).float().to(clf.config.device)
296
+ clf.set_templates_recursively()
297
+ else:
298
+ clf.set_templates(list(template_list.values()), which=args.template_fp_type, fp_size=args.fp_size,
299
+ radius=args.fp_radius, learnable=False, njobs=args.njobs, only_templates_in_batch=args.only_templates_in_batch)
300
+ #if len(template_list)<100000:
301
+ np.save(fn_templ_emb, clf.templates_np if args.only_templates_in_batch else clf.templates.detach().cpu().numpy().astype(np.float16))
302
+
+         # concatenate the current fingerprint with a random fingerprint for templates with enough training samples
+         if (args.concat_rand_template_thresh != -1) and (args.repl_quotient > 0):
+             REPLACE_FACTOR = int(args.repl_quotient)  # default was 8
+ 
+             # fold the original fingerprint
+             pre_comp_templates = clf.templates_np if args.only_templates_in_batch else clf.templates.detach().cpu().numpy()
+ 
+             # mask of labels with at least args.concat_rand_template_thresh training samples
+             l_mask = np.array([label_to_n_train_samples[k]>=args.concat_rand_template_thresh for k in template_list])
+             print(f'Num of templates with added rand-vect of size {pre_comp_templates.shape[1]//REPLACE_FACTOR} due to >=thresh ({args.concat_rand_template_thresh}):', l_mask.sum())
+ 
+             # remove the bits with the lowest variance
+             v = pre_comp_templates.var(0)
+             idx_lowest_var_half = v.argsort()[:(pre_comp_templates.shape[1]//REPLACE_FACTOR)]
+ 
+             # the new zero-initialized vectors
+             pre = np.zeros([pre_comp_templates.shape[0], pre_comp_templates.shape[1]//REPLACE_FACTOR]).astype(float)
+             print(pre.shape, l_mask.shape, l_mask.sum())  # e.g. (616, 1700) (11790,) 519
+             print(pre_comp_templates.shape, len(template_list))  # e.g. (616, 17000) 616
+             # only the templates above the threshold receive a random vector
+             pre[l_mask] = np.random.rand(l_mask.sum(), pre.shape[1])
+ 
+             pre_comp_templates[:, idx_lowest_var_half] = pre
+ 
+             #clf.templates = torch.from_numpy(pre_comp_templates).float().to(clf.config.device)
+             if pre_comp_templates.shape[0] < 100000:
+                 print('adding template_matrix to params')
+                 param = torch.nn.Parameter(torch.from_numpy(pre_comp_templates).float(), requires_grad=False)
+                 clf.register_parameter(name='templates+noise', param=param)
+                 clf.templates = param.to(clf.config.device)
+                 clf.set_templates_recursively()
+             else:  # otherwise this might cause memory issues
+                 print('more than 100k templates')
+                 if args.only_templates_in_batch:
+                     clf.templates = None
+                     clf.templates_np = pre_comp_templates
+                 else:
+                     clf.templates = torch.from_numpy(pre_comp_templates).float()
+                     clf.set_templates_recursively()
+ 
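In effect, the fingerprint bits with the lowest variance across templates (the least discriminative ones) are sacrificed and overwritten: templates with enough training samples get a random vector there, rare templates get zeros. A self-contained sketch of the trick with made-up shapes (illustrative, not part of this commit):

# Illustrative sketch (not part of the commit): overwrite the lowest-variance
# fingerprint bits with random noise for "frequent" templates only.
import numpy as np

rng = np.random.default_rng(0)
fp = rng.random((6, 16))                     # hypothetical template fingerprints
n_train = np.array([0, 1, 60, 3, 80, 120])   # hypothetical samples per template
thresh, replace_factor = 50, 4

frequent = n_train >= thresh                                    # templates that get noise
low_var = fp.var(0).argsort()[:fp.shape[1] // replace_factor]   # least informative bits
block = np.zeros((fp.shape[0], len(low_var)))
block[frequent] = rng.random((frequent.sum(), len(low_var)))
fp[:, low_var] = block                                          # rare templates get zeros there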
+ 
+         # sets this for the first layer!
+         if args.template_fp_type2=='MxFP':
+             print('first_layer template_fingerprint is set to MxFP')
+             clf.templates = torch.from_numpy(comb_template_fp).float().to(clf.config.device)
+         elif args.template_fp_type2=='Tfidf':
+             print('first_layer template_fingerprint is set to Tfidf')
+             clf.templates = torch.from_numpy(tfidf_template_fp).float().to(clf.config.device)
+         elif args.template_fp_type2=='random':
+             print('first_layer template_fingerprint is set to random')
+             clf.templates = torch.from_numpy(np.random.rand(len(template_list), args.fp_size)).float().to(clf.config.device)
+         elif args.template_fp_type2=='stfp':
+             print('first_layer template_fingerprint is set to stfp! only works with 4096 fp_size')
+             tfp = getTemplateFingerprint(list(template_list.values()))
+             clf.templates = torch.from_numpy(tfp).float().to(clf.config.device)
+ 
+     return clf
+ 
+ 
+ if __name__ == '__main__':
+     args = parse_args()
+ 
+     run_id = str(time()).split('.')[0]
+     fn_postfix = str(args.exp_name) + '_' + run_id
+ 
+     if args.wandb:
+         import wandb
+         wandb.init(project='mhn-react', entity='phseidl', name=args.dataset_type+'_'+args.model_type+'_'+fn_postfix, config=args.__dict__)
+     else:
+         wandb = None
+ 
+     if not args.verbose:
+         disable_rdkit_logging()
+ 
+     if args.seed is not None:
+         from .utils import seed_everything
+         seed_everything(args.seed)
+         print('seeded with', args.seed)
+ 
+     # load data from csv or from the pre-packaged USPTO files
+     if args.csv_path is None:
+         X, y = load_USPTO(which=args.dataset_type)
+         template_list = load_templates(which=args.dataset_type)
+     else:
+         X, y, template_list, test_reactants_can = load_dataset_from_csv(**vars(args))
+ 
+     if args.addval2train:
+         print('adding val to train')
+         X['train'] = [*X['train'], *X['valid']]
+         y['train'] = np.concatenate([y['train'], y['valid']])
+ 
+     splits = ['train', 'valid', 'test']
+ 
+     #TODO split up into a separate class
+     if args.splitting_scheme == 'class-freq':
+         X_all = np.concatenate([X[split] for split in splits], axis=0)
+         y_all = np.concatenate([y[split] for split in splits])
+ 
+         # sort classes by frequency / assumes the class index is ordered by frequency (which is mildly violated)
+         res = y_all.argsort()
+ 
+         # use the same split proportions
+         cum_split_lens = np.cumsum([len(y[split]) for split in splits])  # cumulative split lengths
+ 
+         X['train'] = X_all[res[0:cum_split_lens[0]]]
+         y['train'] = y_all[res[0:cum_split_lens[0]]]
+ 
+         X['valid'] = X_all[res[cum_split_lens[0]:cum_split_lens[1]]]
+         y['valid'] = y_all[res[cum_split_lens[0]:cum_split_lens[1]]]
+ 
+         X['test'] = X_all[res[cum_split_lens[1]:]]
+         y['test'] = y_all[res[cum_split_lens[1]:]]
+         for split in splits:
+             print(split, y[split].shape[0], 'samples (', y[split].max(), 'max label)')
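Assuming the class indices are (roughly) ordered by frequency, `argsort` on the labels routes the most frequent classes into train and the rarest into test while preserving the original split sizes. A toy example of the mechanics (illustrative, not part of this commit):

# Illustrative sketch (not part of the commit): 'class-freq' splitting on toy labels.
import numpy as np

y_all = np.array([2, 0, 1, 0, 2, 1, 0, 3])    # hypothetical labels, 0 = most frequent class
order = y_all.argsort()                        # low (frequent) labels first
split_lens = np.cumsum([4, 2, 2])              # keep the original train/valid/test sizes
train_idx = order[:split_lens[0]]              # frequent classes end up in train
valid_idx = order[split_lens[0]:split_lens[1]]
test_idx = order[split_lens[1]:]               # rare classes end up in test
print(y_all[train_idx], y_all[valid_idx], y_all[test_idx])  # [0 0 0 1] [1 2] [2 3]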
+ 
+     if args.splitting_scheme == 'remove_once_in_train_and_not_in_test':
+         print('remove_once_in_train')
+         from collections import Counter
+         cc = Counter()
+         cc.update(y['train'])
+         classes_set_only_once_in_train = set(np.array(list(cc.keys()))[(np.array(list(cc.values())))==1])
+         not_in_test = set(y['train']).union(y['valid']) - (set(y['test']))
+         classes_set_only_once_in_train = (classes_set_only_once_in_train.intersection(not_in_test))
+         remove_those_mask = np.array([yii in classes_set_only_once_in_train for yii in y['train']])
+         X['train'] = np.array(X['train'])[~remove_those_mask]
+         y['train'] = np.array(y['train'])[~remove_those_mask]
+         print(f'{remove_those_mask.mean():.4f} of train removed ({remove_those_mask.sum()} samples)')
+ 
+     if args.splitting_scheme == 'random':
+         print('random-splitting-scheme: 8-1-1')
+         if args.ssretroeval:
+             print('ssretroeval not available')
+             raise NotImplementedError
+         from sklearn.model_selection import train_test_split
+ 
+         def _unpack(lod):
+             r = []
+             for k, v in lod.items():
+                 r.extend(v)
+             return r
+ 
+         X_all = _unpack(X)
+         y_all = np.array(_unpack(y))
+ 
+         X['train'], X['test'], y['train'], y['test'] = train_test_split(X_all, y_all, test_size=0.2, random_state=70135)
+         X['test'], X['valid'], y['test'], y['valid'] = train_test_split(X['test'], y['test'], test_size=0.5, random_state=70135)
+ 
+         zero_shot = set(y['test']).difference(set(y['train']).union(set(y['valid'])))
+         zero_shot_mask = np.array([yi in zero_shot for yi in y['test']])
+         print('zero-shot samples in test:', sum(zero_shot_mask))
+         #y['test'][zero_shot_mask] = list(zero_shot)[0]  # not right but quick
+ 
+     if args.model_type=='staticQK' or args.model_type=='retrosim':
+         print('staticQK model: caution: use pattern or rdk fingerprint-embedding')
+ 
+     fp_size = args.fp_size
+     radius = args.fp_radius  # quite important ;)
+     fp_embedding = args.fp_type
+ 
+     X_fp = featurize_smiles(X, fp_type=args.fp_type, fp_size=args.fp_size, fp_radius=args.fp_radius, njobs=args.njobs)
+ 
+     if args.template_fp_type=='MxFP' or (args.template_fp_type2=='MxFP'):
+         temp_part_to_fp = {}
+         for i in template_list:
+             tpl = template_list[i]
+             for part in str(tpl).split('>>'):
+                 for p in str(part).split('.'):
+                     temp_part_to_fp[p] = None
+         for i, k in enumerate(temp_part_to_fp):
+             temp_part_to_fp[k] = i
+ 
+         fp_types = ['Morgan2CBF', 'Morgan4CBF', 'Morgan6CBF', 'AtomPair', 'TopologicalTorsion', 'Pattern', 'RDK']
+         # MACCS and ErG don't work --> errors with explicit / implicit valence
+         templates_fp = {}
+         remaining = args.fp_size
+         for fp_type in fp_types:
+             # if it's the last fingerprint type, use up the remaining bits
+             te_feat = FP_featurizer(fp_types=fp_type,
+                                     max_features=(args.fp_size//len(fp_types)) if (fp_type != fp_types[-1]) else remaining,
+                                     log_scale=False)
+             templates_fp[fp_type] = te_feat.fit(list(temp_part_to_fp.keys())[:], is_smarts=True)
+             remaining -= templates_fp[fp_type].shape[1]
+         templates_fp['fp'] = np.hstack([templates_fp[f'{fp_type}'] for fp_type in fp_types])
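The loop splits the total fingerprint bit budget evenly over the fingerprint types and lets the last type absorb whatever is left, since a featurizer may return fewer bits than requested. A sketch of that bookkeeping with a stubbed featurizer (illustrative, not part of this commit):

# Illustrative sketch (not part of the commit): distributing a fingerprint bit budget.
fp_types = ['Morgan2CBF', 'AtomPair', 'RDK']   # subset for illustration
budget = 1000
remaining = budget
widths = {}
for fp_type in fp_types:
    asked = budget // len(fp_types) if fp_type != fp_types[-1] else remaining
    got = min(asked, 350)                      # stub: a featurizer may return fewer bits than asked
    widths[fp_type] = got
    remaining -= got
print(widths, 'total:', sum(widths.values()))  # {'Morgan2CBF': 333, 'AtomPair': 333, 'RDK': 334} total: 1000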
+ 
+     if args.template_fp_type=='MxFP' or (args.template_fp_type2=='MxFP'):
+         comb_template_fp = compute_template_fp(fp_len=args.fp_size, reactant_pooling=args.reactant_pooling)
+ 
+     if args.template_fp_type=='Tfidf' or (args.template_fp_type2 == 'Tfidf'):
+         print('using tfidf template-fingerprint')
+         from sklearn.feature_extraction.text import TfidfVectorizer
+         corpus = list(template_list.values())
+         vectorizer = TfidfVectorizer(analyzer='char', ngram_range=(1, 12), max_features=args.fp_size)
+         tfidf_template_fp = vectorizer.fit_transform(corpus).toarray()
+ 
+     actual_fp_size = X_fp['train'].shape[1]
+     if actual_fp_size != args.fp_size:
+         args.fp_size = int(X_fp['train'].shape[1])
+         print('Warning: fp-size has changed to', actual_fp_size)
+ 
+     label_to_n_train_samples = {}
+     n_train_samples_to_label = defaultdict(list)
+     n_templates = max(template_list.keys())+1  # max(max(y['train']), max(y['test']), max(y['valid']))
+     for i in range(n_templates):
+         n_train_samples = (y['train']==i).sum()
+         label_to_n_train_samples[i] = n_train_samples
+         n_train_samples_to_label[n_train_samples].append(i)
+ 
+     up_to = 11
+     n_samples = []
+     masks = []
+     ntes = range(up_to)
+     mask_dict = {}
+ 
+     for nte in ntes:  # number of training examples
+         split = f'nte_{nte}'
+         mask = np.zeros(y['test'].shape)
+         if isinstance(nte, int):
+             for label_with_nte in n_train_samples_to_label[nte]:
+                 mask += (y['test'] == label_with_nte)
+         mask = mask>=1
+         masks.append(mask)
+         mask_dict[str(nte)] = mask
+         n_samples.append(mask.sum())
+ 
+     # for more than 10 training examples  # >10
+     n_samples.append((np.array(masks).max(0)==0).sum())
+     mask_dict['>10'] = (np.array(masks).max(0)==0)
+ 
+     ntes = range(50)  # up to 49
+     for nte in ntes:  # number of training examples
+         split = f'nte_{nte}'
+         mask = np.zeros(y['test'].shape)
+         for label_with_nte in n_train_samples_to_label[nte]:
+             mask += (y['test'] == label_with_nte)
+         mask = mask>=1
+         masks.append(mask)
+     # for more than 49 training examples  # >49
+     n_samples.append((np.array(masks).max(0)==0).sum())
+     mask_dict['>49'] = np.array(masks).max(0)==0
+ 
+     print(n_samples)
+ 
+     clf, hpn_config = set_up_model(args, template_list=template_list)
+     clf = set_up_template_encoder(args, clf, label_to_n_train_samples=label_to_n_train_samples, template_list=template_list)
+ 
+     if args.verbose:
+         print(clf.config.__dict__)
+         print(clf)
+ 
+     wda = torch.optim.AdamW(clf.parameters(), lr=args.lr, weight_decay=1e-2)
+ 
+     if args.wandb:
+         wandb.watch(clf)
+ 
+     # pretraining with the applicability matrix, if applicable
+     if args.model_type == 'fortunato' or args.pretrain_epochs>1:
+         print('pretraining on applicability-matrix -- loading the matrix')
+         _, y_appl = load_USPTO(args.dataset_type, is_appl_matrix=True)
+         if args.splitting_scheme == 'remove_once_in_train_and_not_in_test':
+             y_appl['train'] = y_appl['train'][~remove_those_mask]
+ 
+         # spot-check on random samples that the applicability matrix is consistent with y
+         splt = 'train'
+         for i in range(500):
+             i = np.random.randint(len(y[splt]))
+             #assert (y_appl[splt][i].indices == y[splt][i]).sum()==1
+ 
+         print('pre-training (BCE-loss)')
+         for epoch in range(args.pretrain_epochs):
+             clf.train_from_np(X_fp['train'], X_fp['train'], y_appl['train'], use_dataloader=True, is_smiles=False,
+                               epochs=1, wandb=wandb, verbose=args.verbose, bs=args.batch_size,
+                               permute_batches=True, shuffle=True, optimizer=wda,
+                               only_templates_in_batch=args.only_templates_in_batch)
+             y_pred = clf.evaluate(X_fp['valid'], X_fp['valid'], y_appl['valid'],
+                                   split='pretrain_valid', is_smiles=False, only_loss=True,
+                                   bs=args.batch_size, wandb=wandb)
+             appl_acc = ((y_appl['valid'].toarray()) == (y_pred>0.5)).mean()
+             print(f'{epoch:2.0f} -- train_loss: {clf.hist["loss"][-1]:1.3f}, loss_valid: {clf.hist["loss_pretrain_valid"][-1]:1.3f}, train_acc: {appl_acc:1.5f}')
+ 
+     fn_hist = None
+     y_preds = None
+ 
+     for epoch in range(round(args.epochs / args.eval_every_n_epochs)):
+         if not isinstance(clf, StaticQK):
+             now = time()
+             clf.train_from_np(X_fp['train'], X_fp['train'], y['train'], use_dataloader=True, is_smiles=False,
+                               epochs=args.eval_every_n_epochs, wandb=wandb, verbose=args.verbose, bs=args.batch_size,
+                               permute_batches=True, shuffle=True, optimizer=wda, only_templates_in_batch=args.only_templates_in_batch)
+             if args.verbose: print(f'training took {(time()-now)/60:3.1f} min for {args.eval_every_n_epochs} epochs')
+         for split in ['valid', 'test']:
+             print(split, 'evaluating', end='\r')
+             now = time()
+             #only_loss = ((epoch%5)==4) if args.dataset_type=='lg' else True
+             y_preds = clf.evaluate(X_fp[split], X_fp[split], y[split], is_smiles=False, split=split, bs=args.batch_size, only_loss=args.eval_only_loss, wandb=wandb)
+             if args.verbose: print(f'eval {split} took {(time()-now)/60:3.1f} min')
+         if not isinstance(clf, StaticQK):
+             try:
+                 print(f'{epoch:2.0f} -- train_loss: {clf.hist["loss"][-1]:1.3f}, loss_valid: {clf.hist["loss_valid"][-1]:1.3f}, val_t1acc: {clf.hist["t1_acc_valid"][-1]:1.3f}, val_t100acc: {clf.hist["t100_acc_valid"][-1]:1.3f}')
+             except:
+                 pass
+ 
+         now = time()
+         ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
+         for nte in mask_dict:  # number of training examples
+             split = f'nte_{nte}'
+             mask = mask_dict[nte]
+             topkacc = top_k_accuracy(np.array(y['test'])[mask], y_preds[mask, :], k=ks, ret_arocc=False)
+             new_hist = {}
+             for k, tkacc in zip(ks, topkacc):
+                 new_hist[f't{k}_acc_{split}'] = tkacc
+             #new_hist[f'arocc_{split}'] = arocc
+             new_hist[f'steps_{split}'] = clf.steps
+             for k in new_hist:
+                 clf.hist[k].append(new_hist[k])
+         if args.verbose: print(f'eval nte-test took {(time()-now)/60:3.1f} min')
+ 
+         fn_hist = clf.save_hist(prefix=f'USPTO_{args.dataset_type}_{args.model_type}_', postfix=fn_postfix)
+ 
+         if args.save_preds:
+             PATH = './data/preds/'
+             if not os.path.exists(PATH):
+                 os.mkdir(PATH)
+             pred_fn = f'{PATH}USPTO_{args.dataset_type}_test_{args.model_type}_{fn_postfix}.npy'
+             print('saving predictions to', pred_fn)
+             np.save(pred_fn, y_preds)
+             args.save_preds = pred_fn
+ 
+     if args.save_model:
+         model_save_path = clf.save_model(prefix=f'USPTO_{args.dataset_type}_{args.model_type}_valloss{clf.hist.get("loss_valid",[-1])[-1]:1.3f}_', name_as_conf=False, postfix=fn_postfix)
+ 
+         # serialize the run configuration
+         import json
+         json.dump(args.__dict__, open(f"data/model/{fn_postfix}_args.json", 'w'))
+         json.dump(hpn_config.__dict__,
+                   open(f"data/model/{fn_postfix}_config.json", 'w'))
+ 
+         print('model saved to', model_save_path)
+ 
+     print(min(clf.hist.get('loss_valid', [-1])))
+ 
+     if args.plot_res:
+         from plotutils import plot_topk, plot_nte
+ 
+         plt.figure()
+         clf.plot_loss()
+         plt.draw()
+ 
+         plt.figure()
+         plot_topk(clf.hist, sets=['valid'])
+         if args.dataset_type=='sm':
+             baseline_val_res = {1: 0.4061, 10: 0.6827, 50: 0.7883, 100: 0.8400}
+             plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')
+         plt.draw()
+         plt.figure()
+ 
+         best_cpt = np.array(clf.hist['loss_valid'])[::-1].argmin()+1
+         print(best_cpt)
+         try:
+             best_cpt = np.array(clf.hist['t10_acc_valid'])[::-1].argmax()+1
+             print(best_cpt)
+         except:
+             print('err with t10_acc_valid')
+         plot_nte(clf.hist, dataset=args.dataset_type.capitalize(), last_cpt=best_cpt, include_bar=True, model_legend=args.exp_name,
+                  n_samples=n_samples, z=1.96)
+         if os.path.exists('data/figs/'):
+             try:
+                 os.mkdir(f'data/figs/{args.exp_name}/')
+             except:
+                 pass
+             plt.savefig(f'data/figs/{args.exp_name}/training_examples_vs_top100_acc_{args.dataset_type}_{hash(str(args))}.pdf')
+         plt.draw()
+         fn_hist = clf.save_hist(prefix=f'USPTO_{args.dataset_type}_{args.model_type}_', postfix=fn_postfix)
+ 
+     if args.ssretroeval:
+         print('testing on the real test set ;)')
+         from .data import load_templates
+         from .retroeval import run_templates, topkaccuracy
+         from .utils import sort_by_template_and_flatten
+ 
+         a = list(template_list.keys())
+         #assert list(range(len(a))) == a
+         templates = list(template_list.values())
+         #templates = [*templates, *expert_templates]
+         template_product_smarts = [str(s).split('>')[0] for s in templates]
+ 
+         # execute all templates
+         print('execute all templates')
+         test_product_smarts = [xi[0] for xi in X['test']]  # added later
+         smarts2appl = memory.cache(smarts2appl, ignore=['njobs', 'nsplits', 'use_tqdm'])
+         appl = smarts2appl(test_product_smarts, template_product_smarts, njobs=args.njobs)
+         n_pairs = len(test_product_smarts) * len(template_product_smarts)
+         n_appl = len(appl[0])
+         print(n_pairs, n_appl, n_appl/n_pairs)
+ 
+         # forward pass
+         split = 'test'
+         print('len(X_fp[test]):', len(X_fp[split]))
+         y[split] = np.zeros(len(X[split])).astype(int)
+         clf.eval()
+         if y_preds is None:
+             y_preds = clf.evaluate(X_fp[split], X_fp[split], y[split], is_smiles=False,
+                                    split='ttest', bs=args.batch_size, only_loss=True, wandb=None)
+ 
+         template_scores = y_preds  # this should already be the test split
+ 
+         ####
+         if y_preds.shape[1] > 100000:
+             kth = 200
+             print(f'only evaluating top {kth} applicable predicted templates')
+             # only take the top kth and multiply by the applicability matrix
+             appl_mtrx = np.zeros_like(y_preds, dtype=bool)
+             appl_mtrx[appl[0], appl[1]] = 1
+ 
+             appl_and_topkth = ([], [])
+             for row in range(len(y_preds)):
+                 argpreds = np.argpartition(-(y_preds[row]*appl_mtrx[row]), kth, axis=0)[:kth]
+                 # if there are fewer than kth applicable, keep only the applicable ones
+                 mask = appl_mtrx[row][argpreds]
+                 argpreds = argpreds[mask]
+                 #if len(argpreds)!=kth:
+                 #    print('changed to ', len(argpreds))
+                 appl_and_topkth[0].extend([row for _ in range(len(argpreds))])
+                 appl_and_topkth[1].extend(list(argpreds))
+ 
+             appl = appl_and_topkth
+         ####
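`np.argpartition` returns the indices of the kth-largest scores without a full sort, which is what keeps this pre-filter cheap even with >100k templates; the subsequent mask drops candidates that are not actually applicable. A small sketch of the pattern (illustrative, not part of this commit):

# Illustrative sketch (not part of the commit): top-k selection with
# np.argpartition, restricted to "applicable" entries.
import numpy as np

scores = np.array([0.1, 0.9, 0.3, 0.7, 0.2, 0.8])
applicable = np.array([True, False, True, True, False, True])
kth = 3
idx = np.argpartition(-(scores * applicable), kth)[:kth]  # unsorted top-kth candidates
idx = idx[applicable[idx]]     # keep only genuinely applicable ones
print(idx, scores[idx])        # indices of 0.8, 0.7, 0.3 in some (unsorted) order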
+ 
+         print('running the templates')
+         run_templates = run_templates  # memory.cache(...) ... already cached to tmp
+         prod_idx_reactants, prod_temp_reactants = run_templates(test_product_smarts, templates, appl, njobs=args.njobs)
+         #sorted_results = sort_by_template(template_scores, prod_idx_reactants)
+         #flat_results = flatten_per_product(sorted_results, remove_duplicates=True)
+         # now agglomerates over identical outcomes
+         flat_results = sort_by_template_and_flatten(y_preds, prod_idx_reactants, agglo_fun=sum)
+         accs = topkaccuracy(test_reactants_can, flat_results, [*list(range(1, 101)), 100000])
+ 
+         mtrcs2 = {f't{k}acc_ttest': accs[k-1] for k in [1, 2, 3, 5, 10, 20, 50, 100, 101]}
+         if wandb:
+             wandb.log(mtrcs2)
+         print('Single-step retrosynthesis-evaluation, results on ttest:')
+         for k in mtrcs2.keys():
+             print(k[:-6], end='\t')
+         print()
+         for k, v in mtrcs2.items():
+             print(f'{v*100:2.2f}', end='\t')
+ 
+     # save the history of this experiment
+     EXP_DIR = 'data/experiments/'
+ 
+     df = pd.DataFrame([args.__dict__])
+     df['min_loss_valid'] = min(clf.hist.get('loss_valid', [-1]))
+     df['min_loss_train'] = 0 if ((args.model_type=='staticQK') or (args.model_type=='retrosim')) else min(clf.hist.get('loss', [-1]))
+     try:
+         df['max_t1_acc_valid'] = max(clf.hist.get('t1_acc_valid', [0]))
+         df['max_t100_acc_valid'] = max(clf.hist.get('t100_acc_valid', [0]))
+     except:
+         pass
+     df['hist'] = [clf.hist]
+     df['n_samples'] = [n_samples]
+ 
+     df['fn_hist'] = fn_hist if fn_hist else None
+     df['fn_model'] = '' if not args.save_model else model_save_path
+     df['date'] = str(datetime.datetime.fromtimestamp(time()))
+     df['cmd'] = ' '.join(sys.argv[:])
+ 
+     if not os.path.exists(EXP_DIR):
+         os.mkdir(EXP_DIR)
+ 
+     df.to_csv(f'{EXP_DIR}{run_id}.tsv', sep='\t')
mhnreact/utils.py ADDED
@@ -0,0 +1,126 @@
+ # -*- coding: utf-8 -*-
+ """
+ Author: Philipp Seidl
+ ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
+ Johannes Kepler University Linz
+ Contact: seidl@ml.jku.at
+ 
+ General utility functions
+ """
+ 
+ import argparse
+ from collections import defaultdict
+ import numpy as np
+ import pandas as pd
+ import math
+ import torch
+ 
+ # used and fastest version
+ def top_k_accuracy(y_true, y_pred, k=5, ret_arocc=False, ret_mrocc=False, verbose=False, count_equal_as_correct=False, eps_noise=0):
+     """ partly from http://stephantul.github.io/python/pytorch/2020/09/18/fast_topk/
+     count_equal_as_correct ... counts equal values as being a correct choice, e.g. all preds = 0 --> T1acc = 1
+     ret_arocc ... also return the average rank of the correct choice
+     ret_mrocc ... also return the median rank of the correct choice
+     eps_noise ... if >0 adds noise*eps to y_pred; recommended e.g. 1e-10
+     """
+     if eps_noise > 0:
+         if torch.is_tensor(y_pred):
+             y_pred = y_pred + torch.rand(y_pred.shape)*eps_noise
+         else:
+             y_pred = y_pred + np.random.rand(*y_pred.shape)*eps_noise
+ 
+     if count_equal_as_correct:
+         greater = (y_pred > y_pred[range(len(y_pred)), y_true][:, None]).sum(1)   # how many are strictly bigger
+     else:
+         greater = (y_pred >= y_pred[range(len(y_pred)), y_true][:, None]).sum(1)  # how many are bigger or equal
+     if torch.is_tensor(y_pred):
+         greater = greater.long()
+     if isinstance(k, int): k = [k]  # pack it into a list
+     tkaccs = []
+     for ki in k:
+         if count_equal_as_correct:
+             tkacc = (greater <= (ki-1))
+         else:
+             tkacc = (greater <= ki)
+         if torch.is_tensor(y_pred):
+             tkacc = tkacc.float().mean().detach().cpu().numpy()
+         else:
+             tkacc = tkacc.mean()
+         tkaccs.append(tkacc)
+         if verbose: print('Top', ki, 'acc:\t', str(tkacc)[:6])
+ 
+     if ret_arocc:
+         arocc = greater.float().mean()+1 if torch.is_tensor(greater) else greater.mean()+1
+         if torch.is_tensor(arocc):
+             arocc = arocc.detach().cpu().numpy()
+         return (tkaccs[0], arocc) if len(tkaccs) == 1 else (tkaccs, arocc)
+     if ret_mrocc:
+         mrocc = greater.median()+1 if torch.is_tensor(greater) else np.median(greater)+1
+         if torch.is_tensor(mrocc):
+             mrocc = mrocc.float().detach().cpu().numpy()
+         return (tkaccs[0], mrocc) if len(tkaccs) == 1 else (tkaccs, mrocc)
+ 
+     return tkaccs[0] if len(tkaccs) == 1 else tkaccs
+ 
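A quick worked example of the rank logic, assuming the function above: in the first row the true class has the 2nd-highest score, in the second row the highest, so top-1 accuracy is 0.5 and top-2 accuracy 1.0 (illustrative, not part of this commit):

# Illustrative usage sketch (not part of the commit) for top_k_accuracy.
import numpy as np

y_true = np.array([2, 0])
y_pred = np.array([[0.1, 0.5, 0.4],     # true class 2 ranks 2nd
                   [0.9, 0.05, 0.05]])  # true class 0 ranks 1st
print(top_k_accuracy(y_true, y_pred, k=1))       # 0.5
print(top_k_accuracy(y_true, y_pred, k=[1, 2]))  # [0.5, 1.0]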
+ 
+ def seed_everything(seed=70135):
+     """ does what it says ;) - from https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335 """
+     import numpy as np
+     import random
+     import os
+     import torch
+ 
+     random.seed(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.backends.cudnn.deterministic = True
+ 
+ def get_best_gpu():
+     '''Get the gpu with the most RAM on the machine. From P. Neves'''
+     import torch
+     if torch.cuda.is_available():
+         gpus_ram = []
+         for ind in range(torch.cuda.device_count()):
+             gpus_ram.append(torch.cuda.get_device_properties(ind).total_memory/1e9)
+         return f"cuda:{gpus_ram.index(max(gpus_ram))}"
+     else:
+         raise ValueError("No gpus were detected in this machine.")
+ 
+ def sort_by_template_and_flatten(template_scores, prod_idx_reactants, agglo_fun=sum):
+     """Aggregates template scores per predicted reactant set and returns, for each product,
+     the reactant candidates sorted by their aggregated score (best first)."""
+     flat_results = []
+     for ii in range(len(template_scores)):
+         idx_prod_reactants = defaultdict(list)
+         for k, v in prod_idx_reactants[ii].items():
+             for iv in v:
+                 idx_prod_reactants[iv].append(template_scores[ii, k])
+         d2 = {k: agglo_fun(v) for k, v in idx_prod_reactants.items()}
+         if len(d2) == 0:
+             flat_results.append([])
+         else:
+             flat_results.append(pd.DataFrame.from_dict(d2, orient='index').sort_values(0, ascending=False).index.values.tolist())
+     return flat_results
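Since several templates can produce the same reactant set, the same candidate may appear under multiple template indices; a tiny sketch of how the scores are pooled with the default sum (illustrative, not part of this commit):

# Illustrative usage sketch (not part of the commit) for sort_by_template_and_flatten.
import numpy as np

scores = np.array([[0.7, 0.2, 0.1]])  # one product, three template scores
# templates 0 and 2 both yield reactant 'CCO'; template 1 yields 'CCN'
prod_idx_reactants = [{0: ['CCO'], 1: ['CCN'], 2: ['CCO']}]
print(sort_by_template_and_flatten(scores, prod_idx_reactants))
# [['CCO', 'CCN']]  -> 'CCO' pooled to 0.8, 'CCN' to 0.2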
+ 
+ def str2bool(v):
+     """adapted from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse"""
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ('yes', 'true', 't', 'y', '1', '', ' '):
+         return True
+     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+         return False
+     else:
+         raise argparse.ArgumentTypeError('Boolean value expected.')
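Typical use is as an argparse `type`, so flags accept forms like `--wandb true` or `--wandb 0` (illustrative, not part of this commit; the flag name is hypothetical):

# Illustrative usage sketch (not part of the commit) for str2bool with argparse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--wandb', type=str2bool, default=False)
print(parser.parse_args(['--wandb', 'yes']).wandb)  # True
print(parser.parse_args(['--wandb', '0']).wandb)    # False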
+ 
+ @np.vectorize
+ def lgamma(x):
+     return math.lgamma(x)
+ 
+ def multinom_gk(array, axis=0):
+     """Multinomial lgamma pooling over a given axis"""
+     res = lgamma(np.sum(array, axis=axis)+2) - np.sum(lgamma(array+1), axis=axis)
+     return res
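Written out, `multinom_gk` computes log((sum(x)+1)!) - sum(log(x_i!)), a multinomial-coefficient-style pooling in log space. A worked example (illustrative, not part of this commit):

# Illustrative usage sketch (not part of the commit) for multinom_gk.
import numpy as np

x = np.array([1., 2., 3.])
# sum = 6 -> lgamma(8) = log(7!) = log(5040); parts: log(1!) + log(2!) + log(3!) = log(12)
print(multinom_gk(x))   # log(5040/12) = log(420) ~ 6.0403
print(np.log(5040/12))  # cross-check ~ 6.0403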
mhnreact/view.py ADDED
@@ -0,0 +1,60 @@
+ # -*- coding: utf-8 -*-
+ """
+ Author: Philipp Seidl
+ ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
+ Johannes Kepler University Linz
+ Contact: seidl@ml.jku.at
+ 
+ Loading log-files from training
+ """
+ 
+ from pathlib import Path
+ import os
+ import datetime
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ 
+ def load_experiments(EXP_DIR=Path('data/experiments/')):
+     dfs = []
+     for fn in os.listdir(EXP_DIR):
+         print(fn, end='\r')
+         if fn.split('.')[-1]=='tsv':
+             df = pd.read_csv(EXP_DIR/fn, sep='\t', index_col=0)
+             try:
+                 with open(df['fn_hist'][0]) as f:
+                     hist = eval(f.readlines()[0])
+                 df['hist'] = [hist]
+                 df['fn'] = fn
+             except:
+                 print('err')
+                 #print(df['fn_hist'])
+             dfs.append(df)
+     df = pd.concat(dfs, ignore_index=True)
+     return df
+ 
+ def get_x(k, kw, operation='max', index=None):
+     operation = getattr(np, operation)
+     try:
+         if index is not None:
+             return k[kw][index]
+         return operation(k[kw])
+     except:
+         return 0
+ 
+ def get_min_val_loss_idx(k):
+     return get_x(k, 'loss_valid', 'argmin')  # changed from argmax to argmin!!
+ 
+ def get_tauc(hist):
+     idx = get_min_val_loss_idx(hist)
+     # takes the max over epochs; TODO: take the value at idx instead
+     return np.mean([get_x(hist, f't100_acc_nte_{nt}') for nt in [*range(11), '>10']])
+ 
+ def get_stats_from_hist(df):
+     df['0shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_0'))
+     df['1shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_1'))
+     df['>49shot_acc'] = df['hist'].apply(lambda k: get_x(k, 't100_acc_nte_>49'))
+     df['min_loss_valid'] = df['hist'].apply(lambda k: get_x(k, 'loss_valid', 'min'))
+     return df
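A typical analysis session, assuming experiments were written by the training script above (illustrative, not part of this commit; the selected columns come from the saved args and history):

# Illustrative usage sketch (not part of the commit): summarize logged runs.
df = load_experiments()  # reads data/experiments/*.tsv and attaches 'hist'
df = get_stats_from_hist(df)
cols = ['exp_name', 'min_loss_valid', '0shot_acc', '1shot_acc', '>49shot_acc']
print(df[cols].sort_values('min_loss_valid').head())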