# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at

Training
"""
from .utils import str2bool, lgamma, multinom_gk, top_k_accuracy
from .data import load_templates, load_dataset_from_csv, load_USPTO
from .model import ModelConfig, MHN, StaticQK, SeglerBaseline, Retrosim
from .molutils import convert_smiles_to_fp, FP_featurizer, smarts2appl, getTemplateFingerprint, disable_rdkit_logging
from collections import defaultdict
import argparse
import os
import numpy as np
import pandas as pd
import datetime
import sys
from time import time
import matplotlib.pyplot as plt
import torch
import multiprocessing
import warnings
from joblib import Memory

cachedir = 'data/cache/'
memory = Memory(cachedir, verbose=0, bytes_limit=80e9)


def parse_args():
    parser = argparse.ArgumentParser(description="Train MHNreact.", epilog="--", prog="Train")
    parser.add_argument('-f', type=str)
    parser.add_argument("--model_type", type=str, default='mhn',
                        help="model type: choose from 'segler', 'fortunato', 'mhn' or 'staticQK', default: 'mhn'")
    parser.add_argument("--exp_name", type=str, default='',
                        help="experiment name (added as postfix to the file names)")
    parser.add_argument("-d", "--dataset_type", type=str, default='sm',
                        help="input dataset: 'sm' for Schneider USPTO-50k, 'lg' for USPTO-large, or 'golden'; use --csv_path to specify an input file instead, default: 'sm'")
    parser.add_argument("--csv_path", default=None, type=str,
                        help="path to a preprocessed training csv including split columns, default: None")
    parser.add_argument("--split_col", default='split', type=str, help="split column of the csv, default: 'split'")
    parser.add_argument("--input_col", default='prod_smiles', type=str, help="input column of the csv, default: 'prod_smiles'")
    parser.add_argument("--reactants_col", default='reactants_can', type=str, help="reactant column of the csv, default: 'reactants_can'")
    parser.add_argument("--fp_type", type=str, default='morganc',
                        help="fingerprint type for the input only: 'morganc' (default), 'rdk', 'ECFP', 'ECFC', 'MxFP', 'Morgan2CBF', "
                             "or a combination of fingerprints joined with '+' for max-pooling and '&' for concatenation, "
                             "e.g. maccs+morganc+topologicaltorsion+erg+atompair+pattern+rdkc+layered+mhfp")
    parser.add_argument("--template_fp_type", type=str, default='rdk', help="fingerprint type for the template fingerprint, default: 'rdk'")
    parser.add_argument("--device", type=str, default='best',
                        help="device to run the model on, preferably 'cuda:0'; default: 'best' (takes the GPU with the most free RAM)")
    parser.add_argument("--fp_size", type=int, default=4096, help="fingerprint size used for templates as well as for inputs, default: 4096")
    parser.add_argument("--fp_radius", type=int, default=2, help="fingerprint radius (if applicable to the fingerprint type), default: 2")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs, default: 10")
    parser.add_argument("--pretrain_epochs", type=int, default=0,
                        help="applicability-matrix pretraining epochs if applicable (e.g. fortunato model_type), default: 0")
    parser.add_argument("--save_model", type=str2bool, default=False, help="save the model, default: False")
    parser.add_argument("--dropout", type=float, default=0.2, help="dropout rate for encoders, default: 0.2")
    parser.add_argument("--lr", type=float, default=5e-4, help="learning rate, default: 5e-4")
    parser.add_argument("--hopf_beta", type=float, default=0.05, help="Hopfield beta parameter, default: 0.05")
    parser.add_argument("--hopf_asso_dim", type=int, default=512, help="association dimension, default: 512")
    parser.add_argument("--hopf_num_heads", type=int, default=1, help="number of Hopfield heads, default: 1")
    parser.add_argument("--hopf_association_activation", type=str, default='None',
                        help="Hopfield association activation function, recommended: 'Tanh' or 'None'; "
                             "others: 'ReLU', 'SeLU', 'GeLU' (see torch.nn for more), default: 'None'")
    parser.add_argument("--norm_input", default=True, type=str2bool, help="input normalization, default: True")
    parser.add_argument("--norm_asso", default=True, type=str2bool, help="association normalization, default: True")
    # additional experimental hyperparams
    parser.add_argument("--hopf_n_layers", default=1, type=int, help="number of Hopfield layers, default: 1")
    parser.add_argument("--mol_encoder_layers", default=1, type=int, help="number of molecule-encoder layers, default: 1")
    parser.add_argument("--temp_encoder_layers", default=1, type=int, help="number of template-encoder layers, default: 1")
    parser.add_argument("--encoder_af", default='ReLU', type=str,
                        help="encoder intermediate activation function (applied before the association activation), default: 'ReLU'")
    parser.add_argument("--hopf_pooling_operation_head", default='mean', type=str,
                        help="pooling operation over heads (max, min, mean, ...), default: 'mean'")
    parser.add_argument("--splitting_scheme", default=None, type=str,
                        help="splitting scheme for non-csv input, default: None; other options: 'class-freq', 'random'")
    parser.add_argument("--concat_rand_template_thresh", default=-1, type=int,
                        help="concatenates a random vector to the template fingerprint for all templates with num_training_samples >= this threshold; -1 (default) means deactivated")
    parser.add_argument("--repl_quotient", default=10, type=float,
                        help="only if --concat_rand_template_thresh >= 0: quotient of the template embedding that is replaced by random bits, default: 10")
    parser.add_argument("--verbose", default=False, type=str2bool, help="print out more information, default: False")
    parser.add_argument("--batch_size", default=128, type=int, help="training batch size, default: 128")
    parser.add_argument("--eval_every_n_epochs", default=1, type=int, help="evaluate every n epochs (evaluation is costly for USPTO-lg), default: 1")
    parser.add_argument("--save_preds", default=False, type=str2bool, help="save predictions for the test split at the end of training, default: False")
    parser.add_argument("--wandb", default=False, type=str2bool, help="log to wandb (login required), default: False")
    parser.add_argument("--seed", default=None, type=int, help="seed the run to make it reproducible, default: None")
    parser.add_argument("--template_fp_type2", default=None, type=str, help="experimental template_fp_type for layer 2, default: None")
    parser.add_argument("--layer2weight", default=0.2, type=float, help="weight of p for Hopfield layer 2, default: 0.2")
    parser.add_argument("--reactant_pooling", default='max', type=str,
                        help="reactant pooling operation over the template fingerprint, default: 'max', options: 'min', 'mean', 'lgamma'")
    parser.add_argument("--ssretroeval", default=False, type=str2bool, help="single-step retrosynthesis evaluation, default: False")
    parser.add_argument("--addval2train", default=False, type=str2bool, help="add the validation set to the training set, default: False")
    parser.add_argument("--njobs", default=-1, type=int, help="number of jobs, default: -1 (uses all available cores)")
    parser.add_argument("--eval_only_loss", default=False, type=str2bool, help="only evaluate the loss (top-k accuracy may be time consuming), default: False")
    parser.add_argument("--only_templates_in_batch", default=False, type=str2bool, help="during training only forward templates that are in the batch, default: False")
    parser.add_argument("--plot_res", default=False, type=str2bool, help="plot results for USPTO-sm/lg, default: False")
    args = parser.parse_args()

    if args.njobs == -1:
        args.njobs = int(multiprocessing.cpu_count())

    if args.device == 'best':
        from .utils import get_best_gpu
        try:
            args.device = get_best_gpu()
        except:
            print("couldn't get the best gpu, using cpu instead")
            args.device = 'cpu'

    # some safety checks on the model type
    if (args.model_type == 'segler') and (args.pretrain_epochs >= 1):
        print('changing model type to fortunato because pretrain_epochs > 0')
        args.model_type = 'fortunato'
    if ((args.model_type == 'staticQK') or (args.model_type == 'retrosim')) and (args.epochs > 1):
        print('changing epochs to 1 (StaticQK is not learnable ;)')
        args.epochs = 1
        if args.template_fp_type != args.fp_type:
            print('fp_type must be the same as template_fp_type --> setting template_fp_type to fp_type')
            args.template_fp_type = args.fp_type
    if args.save_model and (args.fp_type == 'MxFP'):
        warnings.warn('Currently MxFP is not recommended when saving the model parameters '
                      '(the fragment dicts would have to be saved or computed again, which is currently not implemented)')

    return args


@memory.cache(ignore=['njobs'])
def featurize_smiles(X, fp_type='morgan', fp_size=4096, fp_radius=2, njobs=1, verbose=False):
    """Compute fingerprints for a dict of SMILES lists (one list per split)."""
    X_fp = {}
    if fp_type in ['MxFP', 'MACCS', 'Morgan2CBF', 'Morgan4CBF', 'Morgan6CBF', 'ErG', 'AtomPair', 'TopologicalTorsion', 'RDK']:
        print('computing', fp_type)
        if fp_type == 'MxFP':
            fp_types = ['MACCS', 'Morgan2CBF', 'Morgan4CBF', 'Morgan6CBF', 'ErG', 'AtomPair', 'TopologicalTorsion', 'RDK']
        else:
            fp_types = [fp_type]
        remaining = int(fp_size)
        for fp_type in fp_types:
            print(fp_type, end=' ')
            # the last fingerprint type uses up the remaining bits
            feat = FP_featurizer(fp_types=fp_type,
                                 max_features=(fp_size // len(fp_types)) if (fp_type != fp_types[-1]) else remaining)
            X_fp[f'train_{fp_type}'] = feat.fit(X['train'])
            X_fp[f'valid_{fp_type}'] = feat.transform(X['valid'])
            X_fp[f'test_{fp_type}'] = feat.transform(X['test'])
            remaining -= X_fp[f'train_{fp_type}'].shape[1]

        X_fp['train'] = np.hstack([X_fp[f'train_{fp_type}'] for fp_type in fp_types])
        X_fp['valid'] = np.hstack([X_fp[f'valid_{fp_type}'] for fp_type in fp_types])
        X_fp['test'] = np.hstack([X_fp[f'test_{fp_type}'] for fp_type in fp_types])
    else:  # e.g. fp_type in ['rdk', 'morgan', 'ecfp4', 'pattern', 'morganc', 'rdkc']
        if verbose:
            print('computing', fp_type, 'folded')
        for split in X.keys():
            X_fp[split] = convert_smiles_to_fp(X[split], fp_size=fp_size, which=fp_type, radius=fp_radius,
                                               njobs=njobs, verbose=verbose)
    return X_fp
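
# Usage sketch (illustrative, mirroring the call in the __main__ block below):
# featurize_smiles expects a dict of SMILES lists keyed by split and returns a dict of
# fingerprint matrices under the same keys, e.g.
#   X = {'train': [...], 'valid': [...], 'test': [...]}
#   X_fp = featurize_smiles(X, fp_type='morganc', fp_size=4096, fp_radius=2, njobs=4)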


def compute_template_fp(fp_len=2048, reactant_pooling='max', do_log=True):
    """Pre-compute the template fingerprint from the product part and the pooled reactant parts."""
    comb_template_fp = np.zeros((max(template_list.keys()) + 1,
                                 fp_len if reactant_pooling != 'concat' else fp_len * 6))
    for i in template_list:
        tpl = template_list[i]
        try:
            pr, rea = str(tpl).split('>>')
            idxx = temp_part_to_fp[pr]
            prod_fp = templates_fp['fp'][idxx]
        except:
            print('err', pr, end='\r')
            prod_fp = np.zeros(fp_len)

        rea_fp = templates_fp['fp'][[temp_part_to_fp[r] for r in str(rea).split('.')]]

        # pooling over the reactant-part fingerprints
        if reactant_pooling == 'only_product':
            rea_fp = np.zeros(fp_len)
        if reactant_pooling == 'max':
            rea_fp = np.log(1 + rea_fp.max(0))
        elif reactant_pooling == 'mean':
            rea_fp = np.log(1 + rea_fp.mean(0))
        elif reactant_pooling == 'sum':
            rea_fp = np.log(1 + rea_fp.sum(0))
        elif reactant_pooling == 'lgamma':
            rea_fp = multinom_gk(rea_fp, axis=0)
        elif reactant_pooling == 'concat':
            rs = str(rea).split('.')
            rs.sort()
            for ii, r in enumerate(rs):
                idx = temp_part_to_fp[r]
                rea_fp = templates_fp['fp'][idx]
                comb_template_fp[i, (fp_len * (ii + 1)):(fp_len * (ii + 2))] = np.log(1 + rea_fp)

        comb_template_fp[i, :prod_fp.shape[0]] = np.log(1 + prod_fp)
        if reactant_pooling != 'concat':
            comb_template_fp[i, :rea_fp.shape[0]] = comb_template_fp[i, :rea_fp.shape[0]] - rea_fp * 0.5
    return comb_template_fp
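
# Usage sketch (illustrative): compute_template_fp relies on the module-level globals
# `template_list` (index -> reaction SMARTS), `temp_part_to_fp` and `templates_fp`,
# which are set up in the __main__ block below, e.g.
#   comb_template_fp = compute_template_fp(fp_len=args.fp_size, reactant_pooling=args.reactant_pooling)
# yielding an array of shape (num_templates, fp_len), or (num_templates, fp_len * 6) for 'concat' pooling.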


def set_up_model(args, template_list=None):
    hpn_config = ModelConfig(num_templates=int(max(template_list.keys())) + 1,
                             dropout=args.dropout,
                             fingerprint_type=args.fp_type,
                             template_fp_type=args.template_fp_type,
                             fp_size=args.fp_size,
                             fp_radius=args.fp_radius,
                             device=args.device,
                             lr=args.lr,
                             hopf_beta=args.hopf_beta,
                             hopf_input_size=args.fp_size,
                             hopf_output_size=None,
                             hopf_num_heads=args.hopf_num_heads,
                             hopf_asso_dim=args.hopf_asso_dim,
                             hopf_association_activation=args.hopf_association_activation,
                             norm_input=args.norm_input,
                             norm_asso=args.norm_asso,
                             hopf_n_layers=args.hopf_n_layers,
                             mol_encoder_layers=args.mol_encoder_layers,
                             temp_encoder_layers=args.temp_encoder_layers,
                             encoder_af=args.encoder_af,
                             hopf_pooling_operation_head=args.hopf_pooling_operation_head,
                             batch_size=args.batch_size,
                             )
    print(hpn_config.__dict__)

    if args.model_type == 'segler':  # baseline
        clf = SeglerBaseline(hpn_config)
    elif args.model_type == 'mhn':
        clf = MHN(hpn_config, layer2weight=args.layer2weight)
    elif args.model_type == 'fortunato':  # baseline with applicability-matrix pretraining
        clf = SeglerBaseline(hpn_config)
    elif args.model_type == 'staticQK':
        clf = StaticQK(hpn_config)
    elif args.model_type == 'retrosim':
        clf = Retrosim(hpn_config)
    else:
        raise NotImplementedError
    return clf, hpn_config


def set_up_template_encoder(args, clf, label_to_n_train_samples=None, template_list=None):
    if isinstance(clf, SeglerBaseline):
        clf.templates = []
    elif args.model_type == 'staticQK':
        clf.template_list = list(template_list.values())
        clf.update_template_embedding(which=args.template_fp_type, fp_size=args.fp_size,
                                      radius=args.fp_radius, njobs=args.njobs)
    elif args.model_type == 'retrosim':
        clf.fit_with_train(X_fp['train'], y['train'])
    else:
        import hashlib
        PATH = './data/cache/'
        if not os.path.exists(PATH):
            os.mkdir(PATH)
        fn_templ_emb = f'{PATH}templ_emb_{args.fp_size}_{args.template_fp_type}{args.fp_radius}_{len(template_list)}_{int(hashlib.sha512((str(template_list)).encode()).hexdigest(), 16)}.npy'
        if os.path.exists(fn_templ_emb):
            # load the pre-computed template embedding (beware of different fingerprint types)
            print(f'loading tfp from file {fn_templ_emb}')
            templ_emb = np.load(fn_templ_emb)
            clf.template_list = list(template_list.values())
            if args.only_templates_in_batch:
                clf.templates_np = templ_emb
                clf.templates = None
            else:
                clf.templates = torch.from_numpy(templ_emb).float().to(clf.config.device)
        else:
            if args.template_fp_type == 'MxFP':
                clf.template_list = list(template_list.values())
                clf.templates = torch.from_numpy(comb_template_fp).float().to(clf.config.device)
                clf.set_templates_recursively()
            elif args.template_fp_type == 'Tfidf':
                clf.template_list = list(template_list.values())
                clf.templates = torch.from_numpy(tfidf_template_fp).float().to(clf.config.device)
                clf.set_templates_recursively()
            elif args.template_fp_type == 'random':
                clf.template_list = list(template_list.values())
                clf.templates = torch.from_numpy(np.random.rand(len(template_list), args.fp_size)).float().to(clf.config.device)
                clf.set_templates_recursively()
            else:
                clf.set_templates(list(template_list.values()), which=args.template_fp_type, fp_size=args.fp_size,
                                  radius=args.fp_radius, learnable=False, njobs=args.njobs,
                                  only_templates_in_batch=args.only_templates_in_batch)
                np.save(fn_templ_emb, clf.templates_np if args.only_templates_in_batch
                        else clf.templates.detach().cpu().numpy().astype(np.float16))

        # concatenate a random vector to the template fingerprint for templates above the threshold
        if (args.concat_rand_template_thresh != -1) and (args.repl_quotient > 0):
            REPLACE_FACTOR = int(args.repl_quotient)

            # the original (folded) fingerprint matrix
            pre_comp_templates = clf.templates_np if args.only_templates_in_batch else clf.templates.detach().cpu().numpy()

            # mask of labels with at least concat_rand_template_thresh training samples
            l_mask = np.array([label_to_n_train_samples[k] >= args.concat_rand_template_thresh for k in template_list])
            print(f'Num of templates with added rand-vect of size {pre_comp_templates.shape[1] // REPLACE_FACTOR} '
                  f'due to >=thresh ({args.concat_rand_template_thresh}):', l_mask.sum())

            # overwrite the bits with the lowest variance
            v = pre_comp_templates.var(0)
            idx_lowest_var_half = v.argsort()[:(pre_comp_templates.shape[1] // REPLACE_FACTOR)]

            # the new zero-initialized vectors
            pre = np.zeros([pre_comp_templates.shape[0], pre_comp_templates.shape[1] // REPLACE_FACTOR]).astype(float)
            print(pre.shape, l_mask.shape, l_mask.sum())
            print(pre_comp_templates.shape, len(template_list))

            # only templates above the threshold receive a random vector
            pre[l_mask] = np.random.rand(l_mask.sum(), pre.shape[1])
            pre_comp_templates[:, idx_lowest_var_half] = pre

            if pre_comp_templates.shape[0] < 100000:
                print('adding template_matrix to params')
                param = torch.nn.Parameter(torch.from_numpy(pre_comp_templates).float(), requires_grad=False)
                clf.register_parameter(name='templates+noise', param=param)
                clf.templates = param.to(clf.config.device)
                clf.set_templates_recursively()
            else:  # otherwise this might cause memory issues
                print('more than 100k templates')
                if args.only_templates_in_batch:
                    clf.templates = None
                    clf.templates_np = pre_comp_templates
                else:
                    clf.templates = torch.from_numpy(pre_comp_templates).float()
                    clf.set_templates_recursively()
    # template_fp_type2 overrides the template fingerprint of the first layer
    if args.template_fp_type2 == 'MxFP':
        print('first_layer template_fingerprint is set to MxFP')
        clf.templates = torch.from_numpy(comb_template_fp).float().to(clf.config.device)
    elif args.template_fp_type2 == 'Tfidf':
        print('first_layer template_fingerprint is set to Tfidf')
        clf.templates = torch.from_numpy(tfidf_template_fp).float().to(clf.config.device)
    elif args.template_fp_type2 == 'random':
        print('first_layer template_fingerprint is set to random')
        clf.templates = torch.from_numpy(np.random.rand(len(template_list), args.fp_size)).float().to(clf.config.device)
    elif args.template_fp_type2 == 'stfp':
        print('first_layer template_fingerprint is set to stfp (only works with fp_size 4096)')
        tfp = getTemplateFingerprint(list(template_list.values()))
        clf.templates = torch.from_numpy(tfp).float().to(clf.config.device)

    return clf
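
# Example invocation (a sketch; assumes the package is importable as `mhnreact`):
#   python -m mhnreact.train --model_type mhn --dataset_type sm --fp_type morganc \
#       --template_fp_type rdk --epochs 10 --batch_size 128 --ssretroeval True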

if __name__ == '__main__':
    args = parse_args()
    run_id = str(time()).split('.')[0]
    fn_postfix = str(args.exp_name) + '_' + run_id

    if args.wandb:
        import wandb
        wandb.init(project='mhn-react', entity='phseidl',
                   name=args.dataset_type + '_' + args.model_type + '_' + fn_postfix, config=args.__dict__)
    else:
        wandb = None

    if not args.verbose:
        disable_rdkit_logging()

    if args.seed is not None:
        from .utils import seed_everything
        seed_everything(args.seed)
        print('seeded with', args.seed)

    # load csv or pre-defined dataset
    if args.csv_path is None:
        X, y = load_USPTO(which=args.dataset_type)
        template_list = load_templates(which=args.dataset_type)
    else:
        X, y, template_list, test_reactants_can = load_dataset_from_csv(**vars(args))

    if args.addval2train:
        print('adding val to train')
        X['train'] = [*X['train'], *X['valid']]
        y['train'] = np.concatenate([y['train'], y['valid']])

    splits = ['train', 'valid', 'test']

    # TODO: split up into a separate class
    if args.splitting_scheme == 'class-freq':
        X_all = np.concatenate([X[split] for split in splits], axis=0)
        y_all = np.concatenate([y[split] for split in splits])
        # sort classes by frequency / assumes the class index is ordered (which is mildly violated)
        res = y_all.argsort()
        # use the same split proportions
        cum_split_lens = np.cumsum([len(y[split]) for split in splits])  # cumulative split lengths
        X['train'] = X_all[res[0:cum_split_lens[0]]]
        y['train'] = y_all[res[0:cum_split_lens[0]]]
        X['valid'] = X_all[res[cum_split_lens[0]:cum_split_lens[1]]]
        y['valid'] = y_all[res[cum_split_lens[0]:cum_split_lens[1]]]
        X['test'] = X_all[res[cum_split_lens[1]:]]
        y['test'] = y_all[res[cum_split_lens[1]:]]
        for split in splits:
            print(split, y[split].shape[0], 'samples (', y[split].max(), 'max label)')

    if args.splitting_scheme == 'remove_once_in_train_and_not_in_test':
        print('remove_once_in_train')
        from collections import Counter
        cc = Counter()
        cc.update(y['train'])
        classes_set_only_once_in_train = set(np.array(list(cc.keys()))[(np.array(list(cc.values()))) == 1])
        not_in_test = set(y['train']).union(y['valid']) - set(y['test'])
        classes_set_only_once_in_train = classes_set_only_once_in_train.intersection(not_in_test)
        remove_those_mask = np.array([yii in classes_set_only_once_in_train for yii in y['train']])
        X['train'] = np.array(X['train'])[~remove_those_mask]
        y['train'] = np.array(y['train'])[~remove_those_mask]
        print(remove_those_mask.mean(), '%', remove_those_mask.sum(), 'samples removed')

    if args.splitting_scheme == 'random':
        print('random-splitting-scheme: 8-1-1')
        if args.ssretroeval:
            print('ssretroeval not available')
            raise NotImplementedError
        from sklearn.model_selection import train_test_split

        def _unpack(lod):
            r = []
            for k, v in lod.items():
                r.extend(v)
            return r

        X_all = _unpack(X)
        y_all = np.array(_unpack(y))
        X['train'], X['test'], y['train'], y['test'] = train_test_split(X_all, y_all, test_size=0.2, random_state=70135)
        X['test'], X['valid'], y['test'], y['valid'] = train_test_split(X['test'], y['test'], test_size=0.5, random_state=70135)
        zero_shot = set(y['test']).difference(set(y['train']).union(set(y['valid'])))
        zero_shot_mask = np.array([yi in zero_shot for yi in y['test']])
        print(sum(zero_shot_mask))

    if args.model_type == 'staticQK' or args.model_type == 'retrosim':
        print('staticQK model: caution: use a pattern or rdk fingerprint-embedding')

    fp_size = args.fp_size
    radius = args.fp_radius  # quite important ;)
    fp_embedding = args.fp_type

    X_fp = featurize_smiles(X, fp_type=args.fp_type, fp_size=args.fp_size, fp_radius=args.fp_radius, njobs=args.njobs)

    if args.template_fp_type == 'MxFP' or (args.template_fp_type2 == 'MxFP'):
        # enumerate all template parts (product and reactant fragments of each template)
        temp_part_to_fp = {}
        for i in template_list:
            tpl = template_list[i]
            for part in str(tpl).split('>>'):
                for p in str(part).split('.'):
                    temp_part_to_fp[p] = None
        for i, k in enumerate(temp_part_to_fp):
            temp_part_to_fp[k] = i

        # MACCS and ErG don't work --> errors with explicit / implicit valence
        fp_types = ['Morgan2CBF', 'Morgan4CBF', 'Morgan6CBF', 'AtomPair', 'TopologicalTorsion', 'Pattern', 'RDK']
        templates_fp = {}
        remaining = args.fp_size
        for fp_type in fp_types:
            # the last fingerprint type uses up the remaining bits
            te_feat = FP_featurizer(fp_types=fp_type,
                                    max_features=(args.fp_size // len(fp_types)) if (fp_type != fp_types[-1]) else remaining,
                                    log_scale=False)
            templates_fp[fp_type] = te_feat.fit(list(temp_part_to_fp.keys())[:], is_smarts=True)
            remaining -= templates_fp[fp_type].shape[1]
        templates_fp['fp'] = np.hstack([templates_fp[f'{fp_type}'] for fp_type in fp_types])

    if args.template_fp_type == 'MxFP' or (args.template_fp_type2 == 'MxFP'):
        comb_template_fp = compute_template_fp(fp_len=args.fp_size, reactant_pooling=args.reactant_pooling)

    if args.template_fp_type == 'Tfidf' or (args.template_fp_type2 == 'Tfidf'):
        print('using tfidf template-fingerprint')
        from sklearn.feature_extraction.text import TfidfVectorizer
        corpus = list(template_list.values())
        vectorizer = TfidfVectorizer(analyzer='char', ngram_range=(1, 12), max_features=args.fp_size)
        tfidf_template_fp = vectorizer.fit_transform(corpus).toarray()

    actual_fp_size = X_fp['train'].shape[1]
    if actual_fp_size != args.fp_size:
        args.fp_size = int(X_fp['train'].shape[1])
        print('Warning: fp-size has changed to', actual_fp_size)

    # count training samples per template
    label_to_n_train_samples = {}
    n_train_samples_to_label = defaultdict(list)
    n_templates = max(template_list.keys()) + 1
    for i in range(n_templates):
        n_train_samples = (y['train'] == i).sum()
        label_to_n_train_samples[i] = n_train_samples
        n_train_samples_to_label[n_train_samples].append(i)

    # masks over the test set, grouped by number of training examples per template (0..10 and >10)
    up_to = 11
    n_samples = []
    masks = []
    ntes = range(up_to)
    mask_dict = {}
    for nte in ntes:  # number of training examples
        split = f'nte_{nte}'
        mask = np.zeros(y['test'].shape)
        if isinstance(nte, int):
            for label_with_nte in n_train_samples_to_label[nte]:
                mask += (y['test'] == label_with_nte)
        mask = mask >= 1
        masks.append(mask)
        mask_dict[str(nte)] = mask
        n_samples.append(mask.sum())

    # all remaining test samples belong to templates with more than 10 training examples
    n_samples.append((np.array(masks).max(0) == 0).sum())
    mask_dict['>10'] = (np.array(masks).max(0) == 0)
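
    # mask_dict maps a number-of-training-examples bucket ('0'..'10', '>10', and '>49' added
    # below) to a boolean mask over the test set; the masks are used in the training loop to
    # report top-k accuracy per bucket.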
    ntes = range(50)  # up to 49
    for nte in ntes:  # number of training examples
        split = f'nte_{nte}'
        mask = np.zeros(y['test'].shape)
        for label_with_nte in n_train_samples_to_label[nte]:
            mask += (y['test'] == label_with_nte)
        mask = mask >= 1
        masks.append(mask)

    # all remaining test samples belong to templates with more than 49 training examples
    n_samples.append((np.array(masks).max(0) == 0).sum())
    mask_dict['>49'] = np.array(masks).max(0) == 0
    print(n_samples)

    clf, hpn_config = set_up_model(args, template_list=template_list)
    clf = set_up_template_encoder(args, clf, label_to_n_train_samples=label_to_n_train_samples, template_list=template_list)

    if args.verbose:
        print(clf.config.__dict__)
        print(clf)

    wda = torch.optim.AdamW(clf.parameters(), lr=args.lr, weight_decay=1e-2)

    if args.wandb:
        wandb.watch(clf)

    # pretraining with the applicability matrix, if applicable
    if args.model_type == 'fortunato' or args.pretrain_epochs > 1:
        print('pretraining on applicability-matrix -- loading the matrix')
        _, y_appl = load_USPTO(args.dataset_type, is_appl_matrix=True)
        if args.splitting_scheme == 'remove_once_in_train_and_not_in_test':
            y_appl['train'] = y_appl['train'][~remove_those_mask]

        print('pre-training (BCE-loss)')
        for epoch in range(args.pretrain_epochs):
            clf.train_from_np(X_fp['train'], X_fp['train'], y_appl['train'], use_dataloader=True, is_smiles=False,
                              epochs=1, wandb=wandb, verbose=args.verbose, bs=args.batch_size,
                              permute_batches=True, shuffle=True, optimizer=wda,
                              only_templates_in_batch=args.only_templates_in_batch)
            y_pred = clf.evaluate(X_fp['valid'], X_fp['valid'], y_appl['valid'], split='pretrain_valid',
                                  is_smiles=False, only_loss=True, bs=args.batch_size, wandb=wandb)
            appl_acc = ((y_appl['valid'].toarray()) == (y_pred > 0.5)).mean()
            print(f'{epoch:2.0f} -- train_loss: {clf.hist["loss"][-1]:1.3f}, '
                  f'loss_valid: {clf.hist["loss_pretrain_valid"][-1]:1.3f}, valid_appl_acc: {appl_acc:1.5f}')

    fn_hist = None
    y_preds = None
    for epoch in range(round(args.epochs / args.eval_every_n_epochs)):
        if not isinstance(clf, StaticQK):
            now = time()
            clf.train_from_np(X_fp['train'], X_fp['train'], y['train'], use_dataloader=True, is_smiles=False,
                              epochs=args.eval_every_n_epochs, wandb=wandb, verbose=args.verbose, bs=args.batch_size,
                              permute_batches=True, shuffle=True, optimizer=wda,
                              only_templates_in_batch=args.only_templates_in_batch)
            if args.verbose:
                print(f'training took {(time() - now) / 60:3.1f} min for {args.eval_every_n_epochs} epochs')

        for split in ['valid', 'test']:
            print(split, 'evaluating', end='\r')
            now = time()
            y_preds = clf.evaluate(X_fp[split], X_fp[split], y[split], is_smiles=False, split=split,
                                   bs=args.batch_size, only_loss=args.eval_only_loss, wandb=wandb)
            if args.verbose:
                print(f'eval {split} took', (time() - now) / 60, 'min')

        if not isinstance(clf, StaticQK):
            try:
                print(f'{epoch:2.0f} -- train_loss: {clf.hist["loss"][-1]:1.3f}, '
                      f'loss_valid: {clf.hist["loss_valid"][-1]:1.3f}, '
                      f'val_t1acc: {clf.hist["t1_acc_valid"][-1]:1.3f}, '
                      f'val_t100acc: {clf.hist["t100_acc_valid"][-1]:1.3f}')
            except:
                pass

        # top-k accuracy on the test set, grouped by number of training examples per template
        now = time()
        ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
        for nte in mask_dict:  # number of training examples
            split = f'nte_{nte}'
            mask = mask_dict[nte]
            topkacc = top_k_accuracy(np.array(y['test'])[mask], y_preds[mask, :], k=ks, ret_arocc=False)
            new_hist = {}
            for k, tkacc in zip(ks, topkacc):
                new_hist[f't{k}_acc_{split}'] = tkacc
            new_hist[f'steps_{split}'] = clf.steps
            for k in new_hist:
                clf.hist[k].append(new_hist[k])
        if args.verbose:
            print('eval nte-test took', (time() - now) / 60, 'min')

        fn_hist = clf.save_hist(prefix=f'USPTO_{args.dataset_type}_{args.model_type}_', postfix=fn_postfix)

    if args.save_preds:
        PATH = './data/preds/'
        if not os.path.exists(PATH):
            os.mkdir(PATH)
        pred_fn = f'{PATH}USPTO_{args.dataset_type}_test_{args.model_type}_{fn_postfix}.npy'
        print('saving predictions to', pred_fn)
        np.save(pred_fn, y_preds)
        args.save_preds = pred_fn

    if args.save_model:
        model_save_path = clf.save_model(
            prefix=f'USPTO_{args.dataset_type}_{args.model_type}_valloss{clf.hist.get("loss_valid", [-1])[-1]:1.3f}_',
            name_as_conf=False, postfix=fn_postfix)
        # serialize the run arguments and the model config
        import json
        json.dump(args.__dict__, open(f"data/model/{fn_postfix}_args.json", 'w'))
        json.dump(hpn_config.__dict__, open(f"data/model/{fn_postfix}_config.json", 'w'))
        print('model saved to', model_save_path)

    print(min(clf.hist.get('loss_valid', [-1])))

    if args.plot_res:
        from .plotutils import plot_topk, plot_nte
        plt.figure()
        clf.plot_loss()
        plt.draw()

        plt.figure()
        plot_topk(clf.hist, sets=['valid'])
        if args.dataset_type == 'sm':
            baseline_val_res = {1: 0.4061, 10: 0.6827, 50: 0.7883, 100: 0.8400}
            plt.plot(list(baseline_val_res.keys()), list(baseline_val_res.values()), 'k.--')
        plt.draw()

        plt.figure()
        best_cpt = np.array(clf.hist['loss_valid'])[::-1].argmin() + 1
        print(best_cpt)
        try:
            best_cpt = np.array(clf.hist['t10_acc_valid'])[::-1].argmax() + 1
            print(best_cpt)
        except:
            print('err with t10_acc_valid')
        plot_nte(clf.hist, dataset=args.dataset_type.capitalize(), last_cpt=best_cpt,
                 include_bar=True, model_legend=args.exp_name, n_samples=n_samples, z=1.96)
        if os.path.exists('data/figs/'):
            try:
                os.mkdir(f'data/figs/{args.exp_name}/')
            except:
                pass
            plt.savefig(f'data/figs/{args.exp_name}/training_examples_vs_top100_acc_{args.dataset_type}_{hash(str(args))}.pdf')
        plt.draw()

    fn_hist = clf.save_hist(prefix=f'USPTO_{args.dataset_type}_{args.model_type}_', postfix=fn_postfix)

    if args.ssretroeval:
        print('testing on the real test set ;)')

        from .data import load_templates
        from .retroeval import run_templates, topkaccuracy
        from .utils import sort_by_template_and_flatten

        templates = list(template_list.values())
        template_product_smarts = [str(s).split('>')[0] for s in templates]

        # compute the applicability of all templates on the test products
        print('execute all templates')
        test_product_smarts = [xi[0] for xi in X['test']]  # added later
        smarts2appl = memory.cache(smarts2appl, ignore=['njobs', 'nsplits', 'use_tqdm'])
        appl = smarts2appl(test_product_smarts, template_product_smarts, njobs=args.njobs)
        n_pairs = len(test_product_smarts) * len(template_product_smarts)
        n_appl = len(appl[0])
        print(n_pairs, n_appl, n_appl / n_pairs)

        # forward pass on the test split
        split = 'test'
        print('len(X_fp[test]):', len(X_fp[split]))
        y[split] = np.zeros(len(X[split])).astype(int)
        clf.eval()
        if y_preds is None:
            y_preds = clf.evaluate(X_fp[split], X_fp[split], y[split], is_smiles=False, split='ttest',
                                   bs=args.batch_size, only_loss=True, wandb=None)
        template_scores = y_preds  # this should already be the test split

        if y_preds.shape[1] > 100000:
            kth = 200
            print(f'only evaluating the top {kth} applicable predicted templates')
            # only keep the top-kth templates and intersect them with the applicability matrix
            appl_mtrx = np.zeros_like(y_preds, dtype=bool)
            appl_mtrx[appl[0], appl[1]] = 1

            appl_and_topkth = ([], [])
            for row in range(len(y_preds)):
                argpreds = np.argpartition(-(y_preds[row] * appl_mtrx[row]), kth, axis=0)[:kth]
                # there may be fewer than kth applicable templates
                mask = appl_mtrx[row][argpreds]
                argpreds = argpreds[mask]
                appl_and_topkth[0].extend([row for _ in range(len(argpreds))])
                appl_and_topkth[1].extend(list(argpreds))

            appl = appl_and_topkth

        print('running the templates')
        # run_templates is already cached to tmp
        prod_idx_reactants, prod_temp_reactants = run_templates(test_product_smarts, templates, appl, njobs=args.njobs)

        # sort outcomes by template score and flatten per product (aggregates over identical outcomes)
        flat_results = sort_by_template_and_flatten(y_preds, prod_idx_reactants, agglo_fun=sum)
        accs = topkaccuracy(test_reactants_can, flat_results, [*list(range(1, 101)), 100000])

        mtrcs2 = {f't{k}acc_ttest': accs[k - 1] for k in [1, 2, 3, 5, 10, 20, 50, 100, 101]}
        if wandb:
            wandb.log(mtrcs2)
        print('Single-step retrosynthesis evaluation, results on ttest:')
        [print(k[:-6], end='\t') for k in mtrcs2.keys()]
        print()
        for k, v in mtrcs2.items():
            print(f'{v * 100:2.2f}', end='\t')

    # save the history of this experiment
    EXP_DIR = 'data/experiments/'
    df = pd.DataFrame([args.__dict__])
    df['min_loss_valid'] = min(clf.hist.get('loss_valid', [-1]))
    df['min_loss_train'] = 0 if ((args.model_type == 'staticQK') or (args.model_type == 'retrosim')) else min(clf.hist.get('loss', [-1]))
    try:
        df['max_t1_acc_valid'] = max(clf.hist.get('t1_acc_valid', [0]))
        df['max_t100_acc_valid'] = max(clf.hist.get('t100_acc_valid', [0]))
    except:
        pass
    df['hist'] = [clf.hist]
    df['n_samples'] = [n_samples]
    df['fn_hist'] = fn_hist if fn_hist else None
    df['fn_model'] = '' if not args.save_model else model_save_path
    df['date'] = str(datetime.datetime.fromtimestamp(time()))
    df['cmd'] = ' '.join(sys.argv[:])

    if not os.path.exists(EXP_DIR):
        os.mkdir(EXP_DIR)
    df.to_csv(f'{EXP_DIR}{run_id}.tsv', sep='\t')