Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at

Model-related functionality
"""
from .utils import top_k_accuracy | |
from .plotutils import plot_loss, plot_topk, plot_nte | |
from .molutils import convert_smiles_to_fp | |
import os | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from collections import defaultdict | |
from scipy import sparse | |
import logging | |
from tqdm import tqdm | |
import wandb | |
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
class ChemRXNDataset(torch.utils.data.Dataset):
    """Torch Dataset for ChemRXN.

    Each item is a ``(mol_fp, target, label)`` tuple.

    Args:
        Xs: inputs as a numpy array (cast to float32), or SMILES when
            ``is_smiles`` is True (converted to fingerprints on access).
        target: the target molecules in the same form as ``Xs``, or ``None``.
        ys: the labels; indexable per sample, or a ``scipy.sparse`` CSR matrix
            whose rows are densified on access.
        is_smiles: whether ``Xs``/``target`` hold SMILES strings.
        fp_size, fingerprint_type: fingerprint settings used when ``is_smiles``.
    """
    def __init__(self, Xs, target, ys, is_smiles=False, fp_size=2048, fingerprint_type='morgan'):
        self.is_smiles = is_smiles
        if is_smiles:
            self.Xs = Xs
            self.target = target
            self.fp_size = fp_size
            self.fingerprint_type = fingerprint_type
        else:
            self.Xs = Xs.astype(np.float32)
            # target is optional (``__getitem__`` handles None), so don't call
            # ``.astype`` on a missing target
            self.target = None if target is None else target.astype(np.float32)
        self.ys = ys
        # ``sparse.csr_matrix`` is the public name; ``sparse.csr.csr_matrix``
        # lives in a deprecated namespace
        self.ys_is_sparse = isinstance(self.ys, sparse.csr_matrix)

    def __getitem__(self, k):
        mol_fp = self.Xs[k]
        if self.is_smiles:
            mol_fp = convert_smiles_to_fp(mol_fp, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)
        target = None if self.target is None else self.target[k]
        # explicit None check: truth-testing a numpy array of targets raises
        # "truth value of an array is ambiguous"
        if self.is_smiles and target is not None:
            target = convert_smiles_to_fp(target, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)
        label = self.ys[k]
        if self.ys_is_sparse:
            # a CSR row is a 1 x N matrix; return it as a dense 1-D array
            label = label.toarray()[0]
        return (mol_fp, target, label)

    def __len__(self):
        return len(self.Xs)
class ModelConfig(object):
    """Hyper-parameter container for MHN and the baselines.

    Every known option is taken from ``kwargs`` (or its default below) and
    stored as an attribute, in a fixed order. The experimental options
    ``hopf_n_layers``, ``mol_encoder_layers``, ``temp_encoder_layers`` and
    ``encoder_af`` only become attributes when passed explicitly. Any other
    keyword is stored verbatim as an attribute too.
    """
    def __init__(self, **kwargs):
        default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        known_options = (
            ("fingerprint_type", 'morgan'),
            ("template_fp_type", 'rdk'),
            ("num_templates", 401),
            ("fp_size", 2048),
            ("fp_radius", 4),
            ("device", default_device),
            ("batch_size", 32),
            ("pooling_operation_state_embedding", 'mean'),
            ("pooling_operation_head", 'max'),
            ("dropout", 0.0),
            ("lr", 1e-4),
            ("optimizer", "Adam"),
            ("activation_function", 'ReLU'),
            ("verbose", False),  # set to True for debugging / additional warnings
            ("hopf_input_size", 2048),
            ("hopf_output_size", 768),
            ("hopf_num_heads", 1),
            ("hopf_asso_dim", 768),
            ("hopf_association_activation", None),
            ("hopf_beta", 0.125),  # scales the association logits (cf. 1/sqrt(d_k) in attention)
            ("norm_input", False),
            ("norm_asso", False),
        )
        for option, fallback in known_options:
            setattr(self, option, kwargs.pop(option, fallback))

        # experimental hyper-parameters: only present when given explicitly
        experimental_options = ('hopf_n_layers', 'mol_encoder_layers', 'temp_encoder_layers', 'encoder_af')
        for option in experimental_options:
            if option in kwargs:
                setattr(self, option, kwargs.pop(option))

        # whatever is left over is stored as-is
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                log.error(f"Can't set {key} with value {value} for {self}")
                raise err
class Encoder(nn.Module):
    """Plain feed-forward network (FFNN).

    Stacks ``num_layers`` linear layers (input_size -> output_size, then
    output_size -> output_size). Each layer's output goes through dropout. The
    activation ``af_name`` is applied after every layer except the last.
    Optional parameter-free LayerNorms normalise the input and the output.
    """
    def __init__(self, input_size: int = 2048, output_size: int = 1024,
                 num_layers: int = 1, dropout: float = 0.3, af_name: str ='None',
                 norm_in: bool = False, norm_out: bool = False):
        super().__init__()
        self.ws = []
        self.setup_af(af_name)
        identity = lambda k: k
        self.norm_in = torch.nn.LayerNorm(input_size, elementwise_affine=False) if norm_in else identity
        self.norm_out = torch.nn.LayerNorm(output_size, elementwise_affine=False) if norm_out else identity
        self.setup_ff(input_size, output_size, num_layers)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        hidden = self.norm_in(x)
        last_idx = len(self.ws) - 1
        for idx, layer in enumerate(self.ws):
            pre = layer(hidden)
            # the final layer stays linear (no activation)
            if idx != last_idx:
                pre = self.af(pre)
            hidden = self.dropout(pre)
        return self.norm_out(hidden)

    def setup_ff(self, input_size:int, output_size:int, num_layers=1):
        """Create the linear layers ``W_0 .. W_{num_layers-1}`` and collect them in ``self.ws``."""
        for layer_idx in range(num_layers):
            fan_in = input_size if layer_idx == 0 else output_size
            linear = nn.Linear(fan_in, output_size)
            # kaiming_normal_ with the 'linear' gain == LeCun-normal init
            torch.nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='linear')
            # registered as a named attribute so nn.Module tracks its parameters
            setattr(self, f'W_{layer_idx}', linear)
            self.ws.append(linear)

    def setup_af(self, af_name : str):
        """Set ``self.af``: identity for None/'None', otherwise ``torch.nn.<af_name>()``."""
        wants_identity = af_name is None or af_name == 'None'
        if not wants_identity:
            try:
                self.af = getattr(nn, af_name)()
            except AttributeError as err:
                log.error(f"Can't find activation-function {af_name} in torch.nn")
                raise err
            return
        self.af = lambda k: k
class MoleculeEncoder(Encoder):
    """
    Placeholder base for a molecule encoder: any module mapping SMILES to a
    vector (preferably differentiable ;)

    Only the bare ``nn.Module`` machinery is initialised here. Encoder's default
    feed-forward layers are deliberately not built, so subclasses can register
    their own submodules.
    """
    def __init__(self, config):
        # Initialise nn.Module directly (skipping Encoder.__init__, whose default
        # layers are not wanted here). Without this, assigning any nn.Module or
        # Parameter attribute in a subclass raises
        # "cannot assign module before Module.__init__() call".
        nn.Module.__init__(self)
        self.config = config
class FPMolEncoder(Encoder):
    """
    Fingerprint-based molecular encoder.

    A feed-forward ``Encoder`` mapping a molecule fingerprint of size
    ``hopf_input_size * hopf_num_heads`` to a vector of size
    ``hopf_asso_dim * hopf_num_heads``.

    Config attributes used: hopf_input_size, hopf_num_heads, hopf_asso_dim,
    dropout, norm_input, norm_asso, fp_size, fingerprint_type, fp_radius and
    device. The optional ``mol_encoder_layers`` (default 1) sets the number of
    layers, and the optional ``encoder_af`` (default 'ReLU') sets the
    activation function.
    """
    def __init__(self, config):
        super().__init__(input_size = config.hopf_input_size*config.hopf_num_heads,
                         output_size = config.hopf_asso_dim*config.hopf_num_heads,
                         # ModelConfig only sets these when passed explicitly; fall
                         # back to the defaults ModelConfig uses for them
                         num_layers = getattr(config, 'mol_encoder_layers', 1),
                         dropout = config.dropout,
                         af_name = getattr(config, 'encoder_af', 'ReLU'),
                         norm_in = config.norm_input,
                         norm_out = config.norm_asso,
                         )
        self.config = config

    def forward_smiles(self, list_of_smiles: list):
        """Fingerprint ``list_of_smiles`` and encode the result."""
        fp_tensor = self.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(fp_tensor)

    def convert_smiles_to_tensor(self, list_of_smiles):
        """Fingerprint SMILES with the config's settings; returns a float32 tensor on ``config.device``."""
        fps = convert_smiles_to_fp(list_of_smiles, fp_size=self.config.fp_size,
                                   which=self.config.fingerprint_type, radius=self.config.fp_radius)
        # ``np.float`` (an alias of float64) was removed in NumPy 1.24
        fps_tensor = torch.from_numpy(fps.astype(np.float64)).to(dtype=torch.float).to(self.config.device)
        return fps_tensor
class TemplateEncoder(Encoder):
    """
    Template encoder: maps a template representation vector (size
    ``hopf_input_size * hopf_num_heads``) to a vector of size
    ``hopf_asso_dim * hopf_num_heads``. Any module mapping a SMARTS reaction
    to a vector (preferably differentiable ;) could take its place.

    The optional config attributes ``temp_encoder_layers`` (default 1) and
    ``encoder_af`` (default 'ReLU') set the depth and the activation. With 0
    layers there is no key projection, and the raw template representation is
    used as the key.
    """
    def __init__(self, config):
        # ModelConfig only sets temp_encoder_layers when passed explicitly
        num_layers = getattr(config, 'temp_encoder_layers', 1)
        super().__init__(input_size = config.hopf_input_size*config.hopf_num_heads,
                         output_size = config.hopf_asso_dim*config.hopf_num_heads,
                         num_layers = num_layers,
                         dropout = config.dropout,
                         af_name = getattr(config, 'encoder_af', 'ReLU'),
                         norm_in = config.norm_input,
                         norm_out = config.norm_asso,
                         )
        self.config = config
        if num_layers == 0:
            print('No Key-Projection = Static Key/Templates')
            # without a projection the keys keep the fingerprint width, so it
            # must match the association dimension
            assert self.config.hopf_asso_dim == self.config.fp_size, \
                'temp_encoder_layers == 0 requires hopf_asso_dim == fp_size'
        self.wks = []
class MHN(nn.Module):
    """
    MHN - modern Hopfield Network -- for Template relevance prediction.

    Molecules are encoded by ``FPMolEncoder`` and templates by
    ``TemplateEncoder``. The ``hopf_beta``-scaled dot products between the two
    encodings are returned as logits ranking the templates for each molecule.
    Further Hopfield layers are stacked recursively when ``config.hopf_n_layers``
    is set.

    Args:
        config: a ``ModelConfig``; a default ``ModelConfig()`` is used when falsy.
            NOTE: if ``hopf_n_layers`` is present, it is decremented in place on
            this config object.
        layer2weight: weight ``w`` of the next stacked layer's logits in
            ``out * (1 - w) + out2 * w``.
        use_template_encoder: whether to build a ``TemplateEncoder``.
    """
    def __init__(self, config=None, layer2weight=0.05, use_template_encoder=True):
        super().__init__()
        if config:
            self.config = config
        else:
            self.config = ModelConfig()
        self.beta = self.config.hopf_beta
        self.mol_encoder = FPMolEncoder(self.config)
        if use_template_encoder:
            self.template_encoder = TemplateEncoder(self.config)
        self.W_v = None
        self.layer2weight = layer2weight
        # more MHN layers -- added recursively
        if hasattr(self.config, 'hopf_n_layers'):
            di = self.config.__dict__
            # the counter on the config tracks how many layers remain below this one
            di['hopf_n_layers'] -= 1
            if di['hopf_n_layers'] > 0:
                conf_wo_hopf_nlayers = ModelConfig(**di)
                self.layer = MHN(conf_wo_hopf_nlayers)
                # W_v feeds the retrieved pattern to the next layer; it is only
                # created together with that layer (forward calls self.layer
                # whenever W_v exists)
                self.W_v = nn.Linear(self.config.hopf_asso_dim, self.config.hopf_input_size)
                torch.nn.init.kaiming_normal_(self.W_v.weight, mode='fan_in', nonlinearity='linear') # eqiv to LeCun init
        self.softmax = torch.nn.Softmax(dim=1)
        self.lossfunction = nn.CrossEntropyLoss(reduction='none')#, weight=class_weights)
        self.pretrain_lossfunction = nn.BCEWithLogitsLoss(reduction='none')#, weight=class_weights)
        self.lr = self.config.lr
        if self.config.hopf_association_activation is None or (self.config.hopf_association_activation.lower()=='none'):
            self.af = lambda k: k
        else:
            self.af = getattr(nn, self.config.hopf_association_activation)()
        self.pooling_operation_head = getattr(torch, self.config.pooling_operation_head)
        self.X = None # templates projected to Hopfield Layer (not populated by this class)
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)
        self.to(self.config.device)

    def set_templates(self, template_list, which='rdk', fp_size=None, radius=2, learnable=False, njobs=1, only_templates_in_batch=False):
        """
        Compute and store the representation of ``template_list``.

        Lists with 100000 or more templates are processed in chunks of 30000 and
        assembled into ``self.templates`` on ``config.device``. Afterwards the
        templates are propagated to stacked Hopfield layers.

        Args:
            template_list: the templates (converted with ``str()``, fingerprinted
                as SMARTS); a ``.copy()`` is kept in ``self.template_list``.
            which, radius: fingerprint type / radius for the fingerprinting.
            fp_size: fingerprint size; defaults to ``config.fp_size``.
            learnable: store the templates as a trainable ``nn.Parameter``.
            njobs: number of jobs passed to the fingerprinting.
            only_templates_in_batch: keep the representation as numpy
                (``self.templates_np``) instead of a device tensor (non-chunked path).
        """
        self.template_list = template_list.copy()
        if fp_size is None:
            fp_size = self.config.fp_size
        if len(template_list)>=100000:
            import math
            print('batch-wise template_calculation')
            bs = 30000
            final_temp_emb = torch.zeros((len(template_list), fp_size)).float().to(self.config.device)
            # ceil(n / bs) chunks; the former ceil(n // bs) + 1 produced an extra,
            # empty chunk whenever n was a multiple of bs
            for b in range(math.ceil(len(template_list) / bs)):
                # update_template_embedding reads self.template_list
                self.template_list = template_list[bs*b:min(bs*(b+1), len(template_list))]
                templ_emb = self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
                final_temp_emb[bs*b:min(bs*(b+1), len(template_list))] = torch.from_numpy(templ_emb)
            # restore the full list that the chunk loop overwrote
            self.template_list = template_list.copy()
            if learnable:
                # a registered Parameter can only be replaced by another Parameter
                self.templates = nn.Parameter(final_temp_emb, requires_grad=True)
            else:
                self.templates = final_temp_emb
        else:
            self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
        self.set_templates_recursively()

    def set_templates_recursively(self):
        """Share ``self.templates`` with the stacked sub-layer(s), if any."""
        if 'hopf_n_layers' in self.config.__dict__.keys():
            if self.config.hopf_n_layers >0:
                self.layer.templates = self.templates
                self.layer.set_templates_recursively()

    def update_template_embedding(self,fp_size=2048, radius=4, which='rdk', learnable=False, njobs=1, only_templates_in_batch=False):
        """
        Fingerprint ``self.template_list`` and store the template representation.

        The part of each template before the first '>' and the part after the
        last '>' are fingerprinted separately (as SMARTS, '.'-split), and the
        representation is ``left_fp - 0.5 * right_fp``.

        It is stored as an ``nn.Parameter`` (``learnable``), as ``self.templates_np``
        (``only_templates_in_batch``), or as a float tensor ``self.templates``
        on ``config.device``.

        Returns:
            the representation as a numpy array.
        """
        print('updating template-embedding; (just computing the template-fingerprint and using that)')
        split_template_list = [str(t).split('>')[0].split('.') for t in self.template_list]
        templates_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)
        split_template_list = [str(t).split('>')[-1].split('.') for t in self.template_list]
        reactants_np = convert_smiles_to_fp(split_template_list, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)
        template_representation = templates_np-(reactants_np*0.5)
        if learnable:
            # Move to the device *before* wrapping: Parameter(...).to(cuda) returns a
            # plain non-leaf tensor that register_parameter rejects. Assigning a
            # Parameter to a Module attribute registers it.
            self.templates = torch.nn.Parameter(
                torch.from_numpy(template_representation).float().to(self.config.device),
                requires_grad=True)
        else:
            if only_templates_in_batch:
                self.templates_np = template_representation
            else:
                self.templates = torch.from_numpy(template_representation).float().to(self.config.device)
        return template_representation

    def np_fp_to_tensor(self, np_fp):
        """Convert a numpy fingerprint array to a float32 tensor on ``config.device``."""
        return torch.from_numpy(np_fp.astype(np.float64)).to(self.config.device).float()

    def masked_loss_fun(self, loss_fun, h_out, ys_batch):
        """
        Average ``loss_fun`` over the entries whose label is not -1.

        ``loss_fun`` is expected to use ``reduction='none'`` so that its output
        can be masked element-wise. For ``nn.BCEWithLogitsLoss`` the labels are
        cast to float. Returns 0 when every label is -1.
        """
        if isinstance(loss_fun, nn.BCEWithLogitsLoss):
            mask = (ys_batch != -1).float()
            ys_batch = ys_batch.float()
        else:
            mask = (ys_batch.long() != -1).long()
        mask_sum = int(mask.sum().cpu().numpy())
        if mask_sum == 0:
            return 0
        # zero out the -1 placeholders so the loss function accepts them
        ys_batch = ys_batch * mask
        loss = (loss_fun(h_out, ys_batch) * mask.float()).sum() / mask_sum # only mean from non -1
        return loss

    def compute_losses(self, out, ys_batch, head_loss_weight=None):
        """
        Mean loss of the logits ``out`` against ``ys_batch``.

        2-D labels with ``num_templates`` columns use BCE-with-logits
        (pretraining mode). Other 2-D labels use cross entropy on column 2
        (legacy policy-network layout). 1-D labels use cross entropy.
        """
        if len(ys_batch.shape)==2:
            if ys_batch.shape[1]==self.config.num_templates: # it is in pretraining_mode
                loss = self.pretrain_lossfunction(out, ys_batch.float()).mean()
            else:
                # legacy from policyNN
                loss = self.lossfunction(out, ys_batch[:, 2]).mean() # WARNING: HEAD4 Reaction Template is ys[:,2]
        else:
            loss = self.lossfunction(out, ys_batch).mean()
        return loss

    def forward_smiles(self, list_of_smiles, templates=None):
        """Fingerprint ``list_of_smiles`` and return the template logits."""
        state_tensor = self.mol_encoder.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(state_tensor, templates=templates)

    def forward(self, m, templates=None):
        """
        m: molecule in the form batch x fingerprint
        templates: None (use ``self.templates`` from ``set_templates``) or
            newly given templates
        returns logits ranking the templates for each molecule, [bs, T]

        The template encoding is recomputed on every call. The former
        "reuse self.X" branch could only be reached while ``self.X`` was None,
        so it always failed.
        """
        bs = m.shape[0] #batch_size
        if templates is None:
            templates = getattr(self, 'templates', None)
        if templates is None:
            raise Exception('Either pass in templates, or init templates by runnting clf.set_templates')
        n_temp = len(templates)
        X = self.template_encoder(templates)
        Xi = self.mol_encoder(m)
        Xi = Xi.view(bs, self.config.hopf_num_heads, self.config.hopf_asso_dim) # [bs, H, A]
        X = X.view(1, n_temp, self.config.hopf_asso_dim, self.config.hopf_num_heads) #[1, T, A, H]
        # NOTE(review): this contraction pairs Xi's head axis with X's leading
        # singleton axis, so it appears to support only hopf_num_heads == 1 --
        # TODO confirm before using multiple heads.
        XXi = torch.tensordot(Xi, X, dims=[(2,1), (2,0)]) # AxA -> [bs, T, H]
        # pooling over heads
        if self.config.hopf_num_heads<=1:
            XXi = XXi[:,:,0]
        else:
            XXi = self.pooling_operation_head(XXi, dim=2) # default is max pooling over H [bs, T]
            if (self.config.pooling_operation_head =='max') or (self.config.pooling_operation_head =='min'):
                XXi = XXi[0] # torch.max/min with dim also return the indices
        out = self.beta*XXi # [bs, T]; softmax over dim=1 gives the template distribution
        self.xinew = self.softmax(out)@X.view(n_temp, self.config.hopf_asso_dim) # [bs,T]@[T,emb] -> [bs,emb]
        if self.W_v is not None:
            # call layers recursive
            hopfout = self.W_v(self.xinew) # [bs,emb]@[emb,hopf_inp] --> [bs, hopf_inp]
            # TODO check if using x_pooled or if not going through mol_encoder again
            hopfout = hopfout + m # skip-connection
            # give it to the next layer
            out2 = self.layer.forward(hopfout)
            out = out*(1-self.layer2weight)+out2*self.layer2weight
        return out

    def train_from_np(self, Xs, targets, ys, is_smiles=False, epochs=2, lr=0.001, bs=32,
                      permute_batches=False, shuffle=True, optimizer=None,
                      use_dataloader=True, verbose=False,
                      wandb=None, scheduler=None, only_templates_in_batch=False):
        """
        Train on numpy inputs for ``epochs`` epochs and return the optimizer used.

        Xs in the form sample x states (or SMILES when ``is_smiles``)
        targets: passed through ChemRXNDataset (not used by the loss)
        ys in the form sample x [y_h1, y_h2, y_h3, y_h4]

        Every ``min(100, len(Xs)//bs)`` steps (at least 1) the running loss is
        appended to ``self.hist`` and, if given, logged to ``wandb``. When
        ``optimizer`` is None, a fresh ``config.optimizer`` is created with
        learning rate ``lr`` (``self.lr`` if ``lr`` is None).
        ``permute_batches``, ``use_dataloader`` and ``only_templates_in_batch``
        have no effect on training.
        """
        self.train()
        if optimizer is None:
            try:
                self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr if lr is None else lr)
            except AttributeError as err:
                log.error(f"Can't find optimizer {self.config.optimizer} in torch.optim")
                raise err
            optimizer = self.optimizer
        dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                 fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                 batch_sampler=None, num_workers=0, collate_fn=None,
                                                 pin_memory=False, drop_last=False, timeout=0,
                                                 worker_init_fn=None)
        # reporting/logging interval; at least 1 so that a dataset smaller than
        # one batch does not cause a modulo by zero
        rs = max(1, min(100, len(Xs)//bs))
        for epoch in range(epochs): # loop over the dataset multiple times
            running_loss = 0.0
            running_loss_dict = defaultdict(int)
            # NOTE: batch_order is not consumed by the DataLoader (permute_batches
            # has no effect); kept because the permutation advances NumPy's global RNG
            batch_order = range(0, len(Xs), bs)
            if permute_batches:
                batch_order = np.random.permutation(batch_order)
            for step, s in tqdm(enumerate(dataloader),mininterval=2):
                batch = [b.to(self.config.device, non_blocking=True) for b in s]
                Xs_batch, target_batch, ys_batch = batch
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward + backward + optimize
                out = self.forward(Xs_batch)
                total_loss = self.compute_losses(out, ys_batch)
                loss_dict = {'CE_loss': total_loss}
                total_loss.backward()
                optimizer.step()
                if scheduler:
                    scheduler.step()
                self.steps += 1
                # print statistics
                for k in loss_dict:
                    running_loss_dict[k] += loss_dict[k].item()
                running_loss += total_loss.item()
                if step % rs == (rs-1):
                    if verbose: print('[%d, %5d] loss: %.3f' %
                                      (epoch + 1, step + 1, running_loss / rs))
                    self.hist['step'].append(self.steps)
                    self.hist['loss'].append(running_loss/rs)
                    self.hist['trianing_running_loss'].append(running_loss/rs)
                    for k in running_loss_dict:
                        self.hist[k].append(running_loss_dict[k]/rs)
                    if wandb:
                        wandb.log({'trianing_running_loss': running_loss / rs})
                    running_loss = 0.0
                    running_loss_dict = defaultdict(int)
        if verbose: print('Finished Training')
        return optimizer

    def evaluate(self, Xs, targets, ys, split='test', is_smiles=False, bs = 32, shuffle=False, wandb=None, only_loss=False):
        """
        Evaluate on numpy inputs and record metrics in ``self.hist``.

        Appends the following to ``self.hist``, and logs them to ``wandb`` if
        given: the running average of the top-k accuracies (``t{k}_acc_{split}``),
        ``meanrank_{split}`` (the second value returned by ``top_k_accuracy``;
        overwritten per batch, so the last batch's value is kept),
        ``steps_{split}``, and the mean batch loss ``loss_{split}``.
        Accuracies are skipped when ``only_loss``.

        Returns:
            float16 array of shape (len(ys), config.num_templates) with the
            softmax scores (sigmoid in multi-label pretraining mode), in
            dataloader order (NOTE: not sample order when ``shuffle=True``).
        """
        self.eval()
        y_preds = np.zeros( (ys.shape[0], self.config.num_templates), dtype=np.float16)
        loss_metrics = defaultdict(int)
        new_hist = defaultdict(float)
        n_batches = 0
        with torch.no_grad():
            dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                     fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                     batch_sampler=None, num_workers=0, collate_fn=None,
                                                     pin_memory=False, drop_last=False, timeout=0,
                                                     worker_init_fn=None)
            for step, batch in enumerate(dataloader):
                batch = [b.to(self.config.device, non_blocking=True) for b in batch]
                ys_batch = batch[2]
                if hasattr(self, 'templates_np'):
                    # Templates are kept on the host: score them in up to 10
                    # chunks and join the logits along the template axis (dim=1).
                    # np.array_split covers every template, including the remainder
                    # that fixed tlen//10-sized chunks would drop.
                    outputs = []
                    for chunk in np.array_split(self.templates_np, 10):
                        if len(chunk) == 0:
                            continue
                        templates = torch.from_numpy(chunk).float().to(self.config.device)
                        outputs.append(self.forward(batch[0], templates=templates))
                    outputs = torch.cat(outputs, dim=1) # [bs, T]
                else:
                    outputs = self.forward(batch[0])
                loss = self.compute_losses(outputs, ys_batch, None)
                loss_metrics['loss'] += (loss.item())
                if len(ys.shape)>1:
                    outputs = self.softmax(outputs) if not (ys.shape[1]==self.config.num_templates) else torch.sigmoid(outputs)
                else:
                    outputs = self.softmax(outputs)
                outputs_np = outputs.to('cpu').numpy().astype(np.float16)
                if not only_loss:
                    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
                    topkacc, mrocc = top_k_accuracy(ys_batch, outputs, k=ks, ret_arocc=True, ret_mrocc=False)
                    for k, tkacc in zip(ks, topkacc):
                        # iterative average update
                        new_hist[f't{k}_acc_{split}'] += (tkacc-new_hist[f't{k}_acc_{split}']) / (step+1)
                        # todo weight by batch-size
                    new_hist[f'meanrank_{split}'] = mrocc
                y_preds[step*bs : min((step+1)*bs,len(y_preds))] = outputs_np
                n_batches = step + 1
        new_hist[f'steps_{split}'] = (self.steps)
        # max(..., 1) guards against an empty split
        new_hist[f'loss_{split}'] = (loss_metrics['loss'] / max(n_batches, 1))
        # loss_{split} is part of new_hist, so it is appended exactly once here
        for k in new_hist:
            self.hist[k].append(new_hist[k])
        if wandb:
            wandb.log(new_hist)
        return y_preds

    def save_hist(self, prefix='', postfix=''):
        """Write ``dict(self.hist)`` (as a Python dict repr) to data/hist/<prefix><postfix>.csv and return the path."""
        HIST_PATH = 'data/hist/'
        # makedirs also creates a missing 'data/' parent (os.mkdir would raise)
        os.makedirs(HIST_PATH, exist_ok=True)
        fn_hist = HIST_PATH+prefix+postfix+'.csv'
        with open(fn_hist, 'w') as fh:
            print(dict(self.hist), file=fh)
        return fn_hist

    def save_model(self, prefix='', postfix='', name_as_conf=False):
        """Save the state_dict to data/model/ and return the path; optionally embed the config values in the name."""
        MODEL_PATH = 'data/model/'
        os.makedirs(MODEL_PATH, exist_ok=True)
        if name_as_conf:
            confi_str = str(self.config.__dict__.values()).replace("'","").replace(': ','_').replace(', ',';')
        else:
            confi_str = ''
        model_name = prefix+confi_str+postfix+'.pt'
        torch.save(self.state_dict(), MODEL_PATH+model_name)
        return MODEL_PATH+model_name

    def plot_loss(self):
        plot_loss(self.hist)

    def plot_topk(self, sets=None, with_last = 2):
        # None instead of a shared mutable list default
        if sets is None:
            sets = ['train', 'valid', 'test']
        plot_topk(self.hist, sets=sets, with_last = with_last)

    def plot_nte(self, last_cpt=1, dataset='Sm', include_bar=True):
        plot_nte(self.hist, dataset=dataset, last_cpt=last_cpt, include_bar=include_bar)
class SeglerBaseline(MHN):
    """FFNN - only the Molecule Encoder + an output projection onto ``num_templates`` classes.

    No template encoder is built. NOTE: ``template_fp_type`` and
    ``temp_encoder_layers`` are overwritten on the given config object.
    """
    def __init__(self, config=None):
        # the advertised default used to dereference None
        if config is None:
            config = ModelConfig()
        config.template_fp_type = 'none'
        config.temp_encoder_layers = 0
        super().__init__(config, use_template_encoder=False)
        self.W_out = torch.nn.Linear(config.hopf_asso_dim, config.num_templates)
        # rebuilt so that the optimizer also covers W_out
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)
        self.to(self.config.device)

    def forward(self, m, templates=None):
        """
        m: molecule in the form batch x fingerprint
        templates: won't be used in this case
        returns logits ranking the templates for each molecule, [bs, num_templates]
        """
        Xi = self.mol_encoder(m)
        Xi = self.mol_encoder.af(Xi) # is not applied in encoder for last layer
        out = self.W_out(Xi) # [bs, T] # softmax over dim=1
        return out
class StaticQK(MHN):
    """ Static QK baseline - beware to have the same fingerprint for mol_encoder as for the template_encoder (fp2048 r4 rdk by default)

    Scores are the raw fingerprint overlap ``m @ X.T``. Each template column is
    divided by that template's bit sum (zero sums are treated as 1). No
    encoder is applied.
    """
    def __init__(self, config=None):
        # MHN.__init__ falls back to ModelConfig() when config is falsy
        super().__init__(config)
        self.fp_size = 2048
        self.fingerprint_type = 'rdk'
        self.beta = 1

    def update_template_embedding(self, which='rdk', fp_size=2048, radius=4, learnable=False,
                                  njobs=1, only_templates_in_batch=False):
        """
        Fingerprint the part of each template before '>>' and take ``.max(1)``
        over the result. The representation is stored in ``self.templates``
        and returned as numpy.

        ``njobs`` and ``only_templates_in_batch`` are accepted because
        ``MHN.set_templates`` always passes them. ``only_templates_in_batch``
        and ``learnable`` have no effect here.
        """
        split_template_list = [t.split('>>')[0].split('.') for t in self.template_list]
        templates_np = convert_smiles_to_fp(split_template_list,
                                            is_smarts=True, fp_size=fp_size,
                                            radius=radius, which=which, njobs=njobs).max(1)
        self.templates = torch.from_numpy(templates_np).float().to(self.config.device)
        # returned for MHN.set_templates' chunked (>= 100000 templates) path
        return templates_np

    def forward(self, m, templates=None):
        """
        m: molecule fingerprints [bs, emb]
        templates: optional template fingerprints [T, emb]; defaults to ``self.templates``
        returns [bs, T]: overlap with each template divided by that template's bit sum, times beta
        """
        X = templates if templates is not None else getattr(self, 'templates', None) # [T, emb]
        if X is None:
            raise Exception('Either pass in templates, or init templates by running clf.set_templates')
        XXi = m @ X.T # [bs, T]
        # normalize by the template's bit count; empty templates would give 0/0
        t_sum = X.sum(1) # [T]
        t_sum = t_sum.masked_fill(t_sum == 0, 1)
        XXi = XXi / t_sum.view(1, -1) # broadcast over the batch
        # not neccecaire because it is not trained
        out = self.beta*XXi # [bs, T] # softmax over dim=1
        return out
class Retrosim(StaticQK):
    """ Retrosim-like baseline only for template relevance prediction

    The training molecules act as the keys. A query's normalized fingerprint
    overlap with every training molecule is averaged per template over the
    training samples labelled with that template.
    """
    def fit_with_train(self, X_fp_train, y_train):
        """
        Memorise the training set.

        Args:
            X_fp_train: training fingerprints [train_samples, emb] (array-like).
            y_train: template index per training sample (array-like of ints
                below ``config.num_templates``).
        """
        self.templates = torch.from_numpy(np.asarray(X_fp_train)).float().to(self.config.device)
        # one_hot requires int64 indices; numpy integer arrays may be int32
        labels = torch.from_numpy(np.asarray(y_train)).long()
        # train_samples, num_templates
        self.sample2acttemplate = torch.nn.functional.one_hot(labels, self.config.num_templates).float()
        tmpnorm = self.sample2acttemplate.sum(0)
        tmpnorm[tmpnorm==0] = 1
        self.sample2acttemplate = (self.sample2acttemplate / tmpnorm).to(self.config.device) # results in an average after dot product

    def forward(self, m, templates=None):
        """
        m: molecule fingerprints [bs, emb]
        returns [bs, num_templates]: per-template average of the similarity to its training samples
        """
        out = super().forward(m, templates=templates)
        # bs, train_samples
        # map out to actual templates
        out = out @ self.sample2acttemplate
        return out