# MHN-React / mhnreact/model.py
# (repository page residue: "uragankatrrin's picture", "Upload 12 files", commit 2956799)
# -*- coding: utf-8 -*-
"""
Author: Philipp Seidl
ELLIS Unit Linz, LIT AI Lab, Institute for Machine Learning
Johannes Kepler University Linz
Contact: seidl@ml.jku.at
Model related functionality
"""
from .utils import top_k_accuracy
from .plotutils import plot_loss, plot_topk, plot_nte
from .molutils import convert_smiles_to_fp
import os
import numpy as np
import torch
import torch.nn as nn
from collections import defaultdict
from scipy import sparse
import logging
from tqdm import tqdm
import wandb
log = logging.getLogger(__name__)
class ChemRXNDataset(torch.utils.data.Dataset):
    """Torch dataset for ChemRXN.

    Args:
        Xs: the inputs. When ``is_smiles`` is false they are cast to float32
            via ``.astype``; otherwise they are kept as given and converted
            to fingerprints per item.
        target: the target molecules, or ``None`` when there are none.
        ys: the labels; may be a ``scipy.sparse.csr_matrix``, in which case
            each row is densified on access.
        is_smiles: if true, items are converted on the fly with
            ``convert_smiles_to_fp``.
        fp_size: fingerprint size passed to ``convert_smiles_to_fp``.
        fingerprint_type: passed as ``which`` to ``convert_smiles_to_fp``.
    """

    def __init__(self, Xs, target, ys, is_smiles=False, fp_size=2048, fingerprint_type='morgan'):
        self.is_smiles = is_smiles
        if is_smiles:
            self.Xs = Xs
            self.target = target
            self.fp_size = fp_size
            self.fingerprint_type = fingerprint_type
        else:
            self.Xs = Xs.astype(np.float32)
            # ``target`` is documented as optional, so do not call .astype on None
            self.target = None if target is None else target.astype(np.float32)
        self.ys = ys
        # ``sparse.csr.csr_matrix`` is a deprecated private namespace; use the public name
        self.ys_is_sparse = isinstance(self.ys, sparse.csr_matrix)

    def __getitem__(self, k):
        """Return the tuple ``(mol_fp, target, label)`` for index ``k``."""
        mol_fp = self.Xs[k]
        if self.is_smiles:
            mol_fp = convert_smiles_to_fp(mol_fp, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)

        target = None if self.target is None else self.target[k]
        # explicit None check: truthiness of an array/list of targets is ambiguous
        if self.is_smiles and target is not None:
            target = convert_smiles_to_fp(target, fp_size=self.fp_size, which=self.fingerprint_type).astype(np.float32)

        label = self.ys[k]
        if isinstance(self.ys, sparse.csr_matrix):
            # a 1 x n sparse row -> dense 1-D array
            label = label.toarray()[0]
        return (mol_fp, target, label)

    def __len__(self):
        return len(self.Xs)
class ModelConfig(object):
    """Container for model hyper-parameters.

    Every known hyper-parameter is taken from ``kwargs`` (falling back to the
    default shown below); any remaining keyword is stored as an attribute
    unchanged.
    """

    def __init__(self, **kwargs):
        self.fingerprint_type = kwargs.pop("fingerprint_type", 'morgan')
        self.template_fp_type = kwargs.pop("template_fp_type", 'rdk')
        self.num_templates = kwargs.pop("num_templates", 401)
        self.fp_size = kwargs.pop("fp_size", 2048)
        self.fp_radius = kwargs.pop("fp_radius", 4)

        self.device = kwargs.pop("device", 'cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = kwargs.pop("batch_size", 32)
        self.pooling_operation_state_embedding = kwargs.pop('pooling_operation_state_embedding', 'mean')
        self.pooling_operation_head = kwargs.pop('pooling_operation_head', 'max')

        self.dropout = kwargs.pop('dropout', 0.0)

        self.lr = kwargs.pop('lr', 1e-4)
        self.optimizer = kwargs.pop("optimizer", "Adam")

        self.activation_function = kwargs.pop('activation_function', 'ReLU')
        # set to True for debugging / printing additional warnings and information
        self.verbose = kwargs.pop("verbose", False)

        self.hopf_input_size = kwargs.pop('hopf_input_size', 2048)
        self.hopf_output_size = kwargs.pop("hopf_output_size", 768)
        self.hopf_num_heads = kwargs.pop("hopf_num_heads", 1)
        self.hopf_asso_dim = kwargs.pop("hopf_asso_dim", 768)
        self.hopf_association_activation = kwargs.pop("hopf_association_activation", None)
        # original note: 1/(self.hopf_asso_dim**(1/2)), i.e. 1/sqrt(d_k)
        self.hopf_beta = kwargs.pop("hopf_beta", 0.125)
        self.norm_input = kwargs.pop("norm_input", False)
        self.norm_asso = kwargs.pop("norm_asso", False)

        # additional experimental hyperparams
        # ``hopf_n_layers`` stays optional: MHN checks for it with ``hasattr``
        # and only then builds stacked layers.
        if 'hopf_n_layers' in kwargs:
            self.hopf_n_layers = kwargs.pop('hopf_n_layers')
        # The encoders read these three unconditionally, so a default config
        # must define them (previously they only existed when passed in).
        self.mol_encoder_layers = kwargs.pop('mol_encoder_layers', 1)
        self.temp_encoder_layers = kwargs.pop('temp_encoder_layers', 1)
        self.encoder_af = kwargs.pop('encoder_af', 'ReLU')

        # additional kwargs
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                log.error(f"Can't set {key} with value {value} for {self}")
                raise err
class Encoder(nn.Module):
    """Simple feed-forward network.

    ``num_layers`` linear layers are stacked; the activation is applied after
    every layer except the last one, dropout after every layer. Optional
    (non-affine) layer normalisation can be applied to the input and output.

    Args:
        input_size: feature size of the input.
        output_size: feature size of every layer's output.
        num_layers: number of linear layers (0 means no projection at all).
        dropout: dropout probability.
        af_name: name of an activation class in ``torch.nn``, or ``None`` /
            ``'None'`` for no activation.
        norm_in: apply ``LayerNorm`` to the input.
        norm_out: apply ``LayerNorm`` to the output.
    """

    def __init__(self, input_size: int = 2048, output_size: int = 1024,
                 num_layers: int = 1, dropout: float = 0.3, af_name: str = 'None',
                 norm_in: bool = False, norm_out: bool = False):
        super().__init__()
        # plain list on purpose: the layers are registered as ``W_{n}``
        # attributes, so the state_dict keys stay ``W_0.weight`` etc.
        self.ws = []
        self.setup_af(af_name)
        # nn.Identity instead of a lambda keeps the module picklable and adds
        # no parameters or buffers to the state_dict.
        self.norm_in = torch.nn.LayerNorm(input_size, elementwise_affine=False) if norm_in else nn.Identity()
        self.norm_out = torch.nn.LayerNorm(output_size, elementwise_affine=False) if norm_out else nn.Identity()
        self.setup_ff(input_size, output_size, num_layers)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the (optionally normalised) layer stack."""
        x = self.norm_in(x)
        last = len(self.ws) - 1
        for i, w in enumerate(self.ws):
            if i == last:
                # all layers except the last one have the activation function
                x = self.dropout(w(x))
            else:
                x = self.dropout(self.af(w(x)))
        x = self.norm_out(x)
        return x

    def setup_ff(self, input_size: int, output_size: int, num_layers=1):
        """Set up the feed-forward NN with ``num_layers`` linear layers."""
        for n in range(0, num_layers):
            w = nn.Linear(input_size if n == 0 else output_size, output_size)
            torch.nn.init.kaiming_normal_(w.weight, mode='fan_in', nonlinearity='linear')  # equiv. to LeCun init
            setattr(self, f'W_{n}', w)  # consider doing a step-wise reduction
            self.ws.append(getattr(self, f'W_{n}'))

    def setup_af(self, af_name: str):
        """Set the activation function from its ``torch.nn`` class name."""
        if af_name is None or (af_name == 'None'):
            self.af = nn.Identity()
        else:
            try:
                self.af = getattr(nn, af_name)()
            except AttributeError as err:
                log.error(f"Can't find activation-function {af_name} in torch.nn")
                raise err
class MoleculeEncoder(Encoder):
    """
    Class for Molecule encoder: can be any class mapping Smiles to a Vector (preferably differentiable ;)

    Only stores ``config``; no layers are created here.
    """
    def __init__(self, config):
        # Initialise the nn.Module machinery (without building Encoder's
        # default layers) so the instance is a valid Module: previously
        # nn.Module.__init__ was never called.
        nn.Module.__init__(self)
        self.config = config
class FPMolEncoder(Encoder):
    """
    Fingerprint Based Molecular encoder

    A feed-forward ``Encoder`` whose sizes, depth, dropout, activation and
    normalisation flags are all read from ``config``.
    """
    def __init__(self, config):
        super().__init__(input_size = config.hopf_input_size*config.hopf_num_heads,
                         output_size = config.hopf_asso_dim*config.hopf_num_heads,
                         num_layers = config.mol_encoder_layers,
                         dropout = config.dropout,
                         af_name = config.encoder_af,
                         norm_in = config.norm_input,
                         norm_out = config.norm_asso,
                        )
        self.config = config

    def forward_smiles(self, list_of_smiles: list):
        """Fingerprint ``list_of_smiles`` and encode the result."""
        fp_tensor = self.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(fp_tensor)

    def convert_smiles_to_tensor(self, list_of_smiles):
        """Return the fingerprints of ``list_of_smiles`` as a float tensor on ``config.device``."""
        fps = convert_smiles_to_fp(list_of_smiles, fp_size=self.config.fp_size,
                                   which=self.config.fingerprint_type, radius=self.config.fp_radius)
        # ``np.float`` was removed in NumPy 1.24; it was an alias for the
        # builtin float, i.e. float64.
        fps_tensor = torch.from_numpy(fps.astype(np.float64)).to(dtype=torch.float).to(self.config.device)
        return fps_tensor
class TemplateEncoder(Encoder):
    """
    Class for Template encoder: can be any class mapping a Smarts-Reaction to a Vector (preferably differentiable ;)

    Built as a feed-forward ``Encoder`` parameterised from ``config``. With
    ``config.temp_encoder_layers == 0`` no projection layer is created, which
    requires the association dimension to equal the fingerprint size.
    """
    def __init__(self, config):
        n_heads = config.hopf_num_heads
        encoder_kwargs = dict(
            input_size=config.hopf_input_size * n_heads,
            output_size=config.hopf_asso_dim * n_heads,
            num_layers=config.temp_encoder_layers,
            dropout=config.dropout,
            af_name=config.encoder_af,
            norm_in=config.norm_input,
            norm_out=config.norm_asso,
        )
        super().__init__(**encoder_kwargs)
        self.config = config

        static_keys = (config.temp_encoder_layers == 0)
        if static_keys:
            print('No Key-Projection = Static Key/Templates')
            assert self.config.hopf_asso_dim==self.config.fp_size
        self.wks = []
class MHN(nn.Module):
    """
    MHN - modern Hopfield Network -- for Template relevance prediction

    Molecules are encoded by ``self.mol_encoder``, templates by
    ``self.template_encoder``; the scaled dot products between both are the
    logits that rank the templates for each molecule.
    """

    def __init__(self, config=None, layer2weight=0.05, use_template_encoder=True):
        """
        config: ModelConfig (a default ``ModelConfig()`` is created when None)
        layer2weight: weight of the stacked layer's logits in the final output
        use_template_encoder: if False no ``template_encoder`` is created
        """
        super().__init__()
        self.config = config if config is not None else ModelConfig()
        self.beta = self.config.hopf_beta
        self.mol_encoder = FPMolEncoder(self.config)
        if use_template_encoder:
            self.template_encoder = TemplateEncoder(self.config)

        self.W_v = None
        self.layer2weight = layer2weight

        # more MHN layers -- added recursively
        if hasattr(self.config, 'hopf_n_layers'):
            # NOTE: this decrements the value on the config object itself;
            # set_templates_recursively relies on the decremented value.
            di = self.config.__dict__
            di['hopf_n_layers'] -= 1
            if di['hopf_n_layers'] > 0:
                conf_wo_hopf_nlayers = ModelConfig(**di)
                self.layer = MHN(conf_wo_hopf_nlayers)
                # W_v is only created together with self.layer: forward()
                # calls self.layer whenever W_v exists.
                self.W_v = nn.Linear(self.config.hopf_asso_dim, self.config.hopf_input_size)
                torch.nn.init.kaiming_normal_(self.W_v.weight, mode='fan_in', nonlinearity='linear')  # equiv. to LeCun init

        self.softmax = torch.nn.Softmax(dim=1)

        self.lossfunction = nn.CrossEntropyLoss(reduction='none')
        self.pretrain_lossfunction = nn.BCEWithLogitsLoss(reduction='none')

        self.lr = self.config.lr

        assoc_af = self.config.hopf_association_activation
        if assoc_af is None or (assoc_af.lower() == 'none'):
            self.af = nn.Identity()
        else:
            self.af = getattr(nn, assoc_af)()

        self.pooling_operation_head = getattr(torch, self.config.pooling_operation_head)

        self.X = None  # templates projected to Hopfield Layer
        # initialised so forward() can test ``self.templates is None``
        self.templates = None

        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)
        self.to(self.config.device)

    def _store_templates(self, tensor, learnable=False):
        """Assign ``self.templates``; registered as a Parameter when ``learnable``."""
        # Drop an earlier Parameter registration first, otherwise nn.Module
        # refuses to assign a plain tensor under the same name.
        if 'templates' in self._parameters:
            del self._parameters['templates']
        if learnable:
            # nn.Module.__setattr__ registers Parameters automatically
            self.templates = torch.nn.Parameter(tensor, requires_grad=True)
        else:
            self.templates = tensor

    def set_templates(self, template_list, which='rdk', fp_size=None, radius=2, learnable=False, njobs=1, only_templates_in_batch=False):
        """
        Compute and store the template representation for ``template_list``.

        Lists with 100000 or more templates are processed in batches of 30000;
        in that case the result is stored as a plain tensor.
        ``fp_size`` defaults to ``config.fp_size``.
        """
        self.template_list = template_list.copy()
        if fp_size is None:
            fp_size = self.config.fp_size
        n_templates = len(template_list)
        if n_templates >= 100000:
            print('batch-wise template_calculation')
            bs = 30000
            final_temp_emb = torch.zeros((n_templates, fp_size)).float().to(self.config.device)
            full_template_list = self.template_list
            try:
                # range() with a step never yields an empty trailing batch
                for start in range(0, n_templates, bs):
                    stop = min(start + bs, n_templates)
                    # update_template_embedding reads self.template_list
                    self.template_list = template_list[start:stop]
                    templ_emb = self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)
                    final_temp_emb[start:stop] = torch.from_numpy(templ_emb)
            finally:
                # restore the full list (it was left as the last batch before)
                self.template_list = full_template_list
            self._store_templates(final_temp_emb, learnable=False)
        else:
            self.update_template_embedding(which=which, fp_size=fp_size, radius=radius, learnable=learnable, njobs=njobs, only_templates_in_batch=only_templates_in_batch)

        self.set_templates_recursively()

    def set_templates_recursively(self):
        """Hand ``self.templates`` down to the stacked layers, if there are any."""
        if 'hopf_n_layers' in self.config.__dict__:
            if self.config.hopf_n_layers > 0:
                self.layer.templates = self.templates
                self.layer.set_templates_recursively()

    def update_template_embedding(self, fp_size=2048, radius=4, which='rdk', learnable=False, njobs=1, only_templates_in_batch=False):
        """
        Fingerprint ``self.template_list`` and store the result.

        The representation is the fingerprint of the part before the first
        '>' minus 0.5 times the fingerprint of the part after the last '>'.
        It is stored in ``self.templates`` (as a Parameter if ``learnable``),
        or, if ``only_templates_in_batch`` and not ``learnable``, as a numpy
        array in ``self.templates_np``. Returns the numpy representation.
        """
        print('updating template-embedding; (just computing the template-fingerprint and using that)')

        lhs_parts = [str(t).split('>')[0].split('.') for t in self.template_list]
        templates_np = convert_smiles_to_fp(lhs_parts, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)

        rhs_parts = [str(t).split('>')[-1].split('.') for t in self.template_list]
        reactants_np = convert_smiles_to_fp(rhs_parts, is_smarts=True, fp_size=fp_size, radius=radius, which=which, njobs=njobs)

        template_representation = templates_np - (reactants_np * 0.5)

        if learnable:
            # Move to the device *before* wrapping: ``Parameter(...).to(device)``
            # returns a plain tensor on a device change, which could then not
            # be registered as a parameter.
            on_device = torch.from_numpy(template_representation).float().to(self.config.device)
            self._store_templates(on_device, learnable=True)
        elif only_templates_in_batch:
            self.templates_np = template_representation
        else:
            on_device = torch.from_numpy(template_representation).float().to(self.config.device)
            self._store_templates(on_device, learnable=False)

        return template_representation

    def np_fp_to_tensor(self, np_fp):
        """Convert a numpy fingerprint array to a float tensor on ``config.device``."""
        return torch.from_numpy(np_fp.astype(np.float64)).to(self.config.device).float()

    def masked_loss_fun(self, loss_fun, h_out, ys_batch):
        """
        Apply ``loss_fun`` (expected to use reduction='none') and average over
        the entries whose label is not -1. Returns a zero if nothing is valid.
        """
        # previously compared against the non-existent self.BCEWithLogitsLoss
        if isinstance(loss_fun, nn.BCEWithLogitsLoss):
            mask = (ys_batch != -1).float()
            ys_batch = ys_batch.float()
        else:
            mask = (ys_batch.long() != -1).long()
        mask_sum = int(mask.sum().cpu().numpy())
        if mask_sum == 0:
            # zero that is still a tensor attached to the graph, so callers
            # can treat the result uniformly (.item(), .backward())
            return h_out.sum() * 0.0

        ys_batch = ys_batch * mask  # masked labels become 0 (their loss is zeroed below)

        loss = (loss_fun(h_out, ys_batch * mask) * mask.float()).sum() / mask_sum  # only mean from non -1
        return loss

    def compute_losses(self, out, ys_batch, head_loss_weight=None):
        """
        Mean loss of the logits ``out`` against ``ys_batch``.

        2-D labels with ``config.num_templates`` columns use the BCE
        (pre-training) loss; other 2-D labels use column 2 as class index;
        1-D labels are class indices. ``head_loss_weight`` is unused.
        """
        if len(ys_batch.shape) == 2:
            if ys_batch.shape[1] == self.config.num_templates:  # it is in pretraining_mode
                loss = self.pretrain_lossfunction(out, ys_batch.float()).mean()
            else:
                # legacy from policyNN
                loss = self.lossfunction(out, ys_batch[:, 2]).mean()  # WARNING: HEAD4 Reaction Template is ys[:,2]
        else:
            loss = self.lossfunction(out, ys_batch).mean()
        return loss

    def forward_smiles(self, list_of_smiles, templates=None):
        """Fingerprint ``list_of_smiles`` with the molecule encoder's settings and run forward."""
        state_tensor = self.mol_encoder.convert_smiles_to_tensor(list_of_smiles)
        return self.forward(state_tensor, templates=templates)

    def forward(self, m, templates=None):
        """
        m: molecule in the form batch x fingerprint
        templates: None or newly given templates if not instantiated
        returns logits ranking the templates for each molecule
        """
        bs = m.shape[0]  # batch_size
        if (templates is None) and (self.X is None) and (self.templates is None):
            raise Exception('Either pass in templates, or init templates by runnting clf.set_templates')

        # The cached projection is only valid in eval mode and when the caller
        # did not pass templates. (Before, explicit templates in eval mode
        # selected self.X, which is None unless set from outside.)
        use_cached = (not self.training) and (templates is None) and (self.X is not None)
        if use_cached:
            X = self.X  # precomputed projection
        else:
            templates = templates if templates is not None else self.templates
            X = self.template_encoder(templates)
        n_temp = X.shape[0]

        Xi = self.mol_encoder(m)

        # NOTE(review): the view/tensordot below contract the head axis of Xi
        # with a size-1 axis of X, which only matches for hopf_num_heads == 1
        # -- confirm before using several heads.
        Xi = Xi.view(bs, self.config.hopf_num_heads, self.config.hopf_asso_dim)  # [bs, H, A]
        X = X.view(1, n_temp, self.config.hopf_asso_dim, self.config.hopf_num_heads)  # [1, T, A, H]

        XXi = torch.tensordot(Xi, X, dims=[(2, 1), (2, 0)])  # AxA -> [bs, T, H]

        # pooling over heads
        if self.config.hopf_num_heads <= 1:
            XXi = XXi[:, :, 0]
        else:
            XXi = self.pooling_operation_head(XXi, dim=2)  # default is max pooling over H [bs, T]
            if (self.config.pooling_operation_head == 'max') or (self.config.pooling_operation_head == 'min'):
                XXi = XXi[0]  # max and min also return the indices

        out = self.beta * XXi  # [bs, T]; softmax over dim=1

        self.xinew = self.softmax(out) @ X.view(n_temp, self.config.hopf_asso_dim)  # [bs,T]@[T,emb] -> [bs,emb]

        if self.W_v is not None:
            # call layers recursively
            hopfout = self.W_v(self.xinew)  # [bs,emb]@[emb,hopf_inp] --> [bs, hopf_inp]
            # TODO check if using x_pooled or if not going through mol_encoder again
            hopfout = hopfout + m  # skip-connection
            # give it to the next layer
            out2 = self.layer.forward(hopfout)
            out = out * (1 - self.layer2weight) + out2 * self.layer2weight

        return out

    def train_from_np(self, Xs, targets, ys, is_smiles=False, epochs=2, lr=0.001, bs=32,
                      permute_batches=False, shuffle=True, optimizer=None,
                      use_dataloader=True, verbose=False,
                      wandb=None, scheduler=None, only_templates_in_batch=False):
        """
        Xs in the form sample x states
        targets
        ys in the form sample x [y_h1, y_h2, y_h3, y_h4]

        Trains for ``epochs`` epochs with batch size ``bs`` and returns the
        optimizer. If ``optimizer`` is None a new one is built from
        ``config.optimizer`` with learning rate ``lr``. ``scheduler`` (if
        given) is stepped after every batch; ``wandb`` (if given) receives
        the running loss. ``permute_batches``, ``use_dataloader`` and
        ``only_templates_in_batch`` are accepted but not used.
        """
        self.train()
        if optimizer is None:
            try:
                self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr if lr is None else lr)
            except AttributeError as err:
                # was ``config.optimizer``: an undefined name that masked the real error
                log.error(f"Can't find optimizer {self.config.optimizer} in torch.optim")
                raise err
            optimizer = self.optimizer

        dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                 fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)

        dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                 batch_sampler=None, num_workers=0, collate_fn=None,
                                                 pin_memory=False, drop_last=False, timeout=0,
                                                 worker_init_fn=None)

        # reporting/logging interval in steps; at least 1 so the modulo below
        # cannot divide by zero when there are fewer samples than ``bs``
        rs = max(1, min(100, len(Xs) // bs))

        for epoch in range(epochs):  # loop over the dataset multiple times
            running_loss = 0.0
            running_loss_dict = defaultdict(int)
            for step, s in tqdm(enumerate(dataloader), mininterval=2):
                batch = [b.to(self.config.device, non_blocking=True) for b in s]
                Xs_batch, target_batch, ys_batch = batch

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                out = self.forward(Xs_batch)
                total_loss = self.compute_losses(out, ys_batch)

                loss_dict = {'CE_loss': total_loss}

                total_loss.backward()

                optimizer.step()
                if scheduler:
                    scheduler.step()
                self.steps += 1

                # accumulate statistics
                for k in loss_dict:
                    running_loss_dict[k] += loss_dict[k].item()
                running_loss += total_loss.item()

                if step % rs == (rs - 1):
                    if verbose: print('[%d, %5d] loss: %.3f' %
                                      (epoch + 1, step + 1, running_loss / rs))
                    self.hist['step'].append(self.steps)
                    self.hist['loss'].append(running_loss / rs)
                    self.hist['trianing_running_loss'].append(running_loss / rs)

                    for k in running_loss_dict:
                        self.hist[k].append(running_loss_dict[k] / rs)

                    if wandb:
                        wandb.log({'trianing_running_loss': running_loss / rs})

                    running_loss = 0.0
                    running_loss_dict = defaultdict(int)

        if verbose: print('Finished Training')
        return optimizer

    def evaluate(self, Xs, targets, ys, split='test', is_smiles=False, bs=32, shuffle=False, wandb=None, only_loss=False):
        """
        Evaluate on ``Xs``/``ys`` and return the predictions as a float16
        array of shape (len(ys), config.num_templates).

        The mean batch loss and (unless ``only_loss``) the top-k accuracies
        are appended to ``self.hist`` under keys suffixed with ``split`` and
        are logged to ``wandb`` if given.
        """
        self.eval()
        y_preds = np.zeros((ys.shape[0], self.config.num_templates), dtype=np.float16)

        loss_metrics = defaultdict(int)
        new_hist = defaultdict(float)
        n_batches = 0
        with torch.no_grad():
            dataset = ChemRXNDataset(Xs, targets, ys, is_smiles=is_smiles,
                                     fp_size=self.config.fp_size, fingerprint_type=self.config.fingerprint_type)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=shuffle, sampler=None,
                                                     batch_sampler=None, num_workers=0, collate_fn=None,
                                                     pin_memory=False, drop_last=False, timeout=0,
                                                     worker_init_fn=None)

            for step, batch in enumerate(dataloader):
                n_batches = step + 1
                batch = [b.to(self.config.device, non_blocking=True) for b in batch]
                ys_batch = batch[2]

                if hasattr(self, 'templates_np'):
                    # Templates are kept as numpy only: score them in (up to)
                    # 10 chunks. array_split covers every template, including
                    # the remainder that ``tlen // 10`` slicing dropped.
                    outputs = []
                    for chunk in np.array_split(self.templates_np, 10):
                        if len(chunk) == 0:
                            continue
                        templates = torch.from_numpy(chunk).float().to(self.config.device)
                        outputs.append(self.forward(batch[0], templates=templates))
                    # every chunk yields [bs, T_chunk]; join along the template axis
                    outputs = torch.cat(outputs, dim=1)
                else:
                    outputs = self.forward(batch[0])

                loss = self.compute_losses(outputs, ys_batch, None)

                # not quite right because in every batch there might be different number of valid samples
                loss_metrics['loss'] += (loss.item())

                if len(ys.shape) > 1 and ys.shape[1] == self.config.num_templates:
                    outputs = torch.sigmoid(outputs)
                else:
                    outputs = self.softmax(outputs)

                outputs_np = outputs.to('cpu').numpy().astype(np.float16)

                if not only_loss:
                    ks = [1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
                    topkacc, mrocc = top_k_accuracy(ys_batch, outputs, k=ks, ret_arocc=True, ret_mrocc=False)
                    # mrocc -- median rank of correct choice
                    for k, tkacc in zip(ks, topkacc):
                        # iterative average update
                        key = f't{k}_acc_{split}'
                        new_hist[key] += (tkacc - new_hist[key]) / (step + 1)
                        # todo weight by batch-size
                    new_hist[f'meanrank_{split}'] = mrocc

                y_preds[step * bs: min((step + 1) * bs, len(y_preds))] = outputs_np

        # guard against an empty dataloader (no batch -> no division by zero)
        mean_loss = loss_metrics['loss'] / max(1, n_batches)
        new_hist[f'steps_{split}'] = (self.steps)
        new_hist[f'loss_{split}'] = mean_loss

        for k in new_hist:
            self.hist[k].append(new_hist[k])

        if wandb:
            wandb.log(new_hist)

        # NOTE(review): the loss was already appended via new_hist above, so
        # it ends up in the history twice per call; kept as in the original,
        # confirm whether the plotting code relies on it.
        self.hist[f'loss_{split}'].append(mean_loss)

        return y_preds

    def save_hist(self, prefix='', postfix=''):
        """Write ``self.hist`` (printed as a dict) to data/hist/<prefix><postfix>.csv; return the path."""
        HIST_PATH = 'data/hist/'
        # makedirs also creates the missing 'data/' parent and tolerates an existing directory
        os.makedirs(HIST_PATH, exist_ok=True)
        fn_hist = HIST_PATH + prefix + postfix + '.csv'
        with open(fn_hist, 'w') as fh:
            print(dict(self.hist), file=fh)
        return fn_hist

    def save_model(self, prefix='', postfix='', name_as_conf=False):
        """Save the state_dict to data/model/<prefix>[<config values>]<postfix>.pt; return the path."""
        MODEL_PATH = 'data/model/'
        os.makedirs(MODEL_PATH, exist_ok=True)
        if name_as_conf:
            confi_str = str(self.config.__dict__.values()).replace("'", "").replace(': ', '_').replace(', ', ';')
        else:
            confi_str = ''
        model_name = prefix + confi_str + postfix + '.pt'
        torch.save(self.state_dict(), MODEL_PATH + model_name)
        return MODEL_PATH + model_name

    def plot_loss(self):
        plot_loss(self.hist)

    def plot_topk(self, sets=['train', 'valid', 'test'], with_last = 2):
        plot_topk(self.hist, sets=sets, with_last = with_last)

    def plot_nte(self, last_cpt=1, dataset='Sm', include_bar=True):
        plot_nte(self.hist, dataset=dataset, last_cpt=last_cpt, include_bar=include_bar)
class SeglerBaseline(MHN):
    """FFNN - only the Molecule Encoder + an output projection"""

    def __init__(self, config=None):
        """
        config: ModelConfig; a default one is created when None. Note that
        ``template_fp_type`` and ``temp_encoder_layers`` are overwritten on
        the passed object.
        """
        if config is None:
            # The signature allows None, but the attributes below were set on
            # it directly. The two keywords are the ones the molecule encoder
            # reads; they are passed explicitly so the default is usable.
            config = ModelConfig(mol_encoder_layers=1, encoder_af='ReLU')
        config.template_fp_type = 'none'
        config.temp_encoder_layers = 0
        super().__init__(config, use_template_encoder=False)
        # NOTE(review): the molecule encoder outputs
        # hopf_asso_dim * hopf_num_heads features, so this projection only
        # matches for hopf_num_heads == 1 -- confirm.
        self.W_out = torch.nn.Linear(config.hopf_asso_dim, config.num_templates)
        # rebuilt so that the optimizer also covers W_out
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.parameters(), lr=self.lr)
        self.steps = 0
        self.hist = defaultdict(list)
        self.to(self.config.device)

    def forward(self, m, templates=None):
        """
        m: molecule in the form batch x fingerprint
        templates: won't be used in this case
        returns logits ranking the templates for each molecule
        """
        Xi = self.mol_encoder(m)
        Xi = self.mol_encoder.af(Xi)  # is not applied in encoder for last layer
        out = self.W_out(Xi)  # [bs, T] # softmax over dim=1
        return out
class StaticQK(MHN):
    """ Static QK baseline - beware to have the same fingerprint for mol_encoder as for the template_encoder (fp2048 r4 rdk by default)"""

    def __init__(self, config=None):
        if config is None:
            config = ModelConfig()
        super().__init__(config)

        self.fp_size = 2048
        self.fingerprint_type = 'rdk'
        self.beta = 1

    def update_template_embedding(self, which='rdk', fp_size=2048, radius=4, learnable=False,
                                  njobs=1, only_templates_in_batch=False):
        """
        Fingerprint the part before '>>' of every template in
        ``self.template_list`` and store it in ``self.templates``.

        ``njobs`` and ``only_templates_in_batch`` are accepted because
        ``MHN.set_templates`` passes them (the call raised a TypeError
        before); ``learnable`` and ``only_templates_in_batch`` have no effect
        here. Returns the numpy representation.
        """
        split_template_list = [t.split('>>')[0].split('.') for t in self.template_list]
        fps = convert_smiles_to_fp(split_template_list,
                                   is_smarts=True, fp_size=fp_size,
                                   radius=radius, which=which, njobs=njobs).max(1)
        self.templates = torch.from_numpy(fps).float().to(self.config.device)
        return fps

    def forward(self, m, templates=None):
        """
        m: molecule fingerprints [bs, emb]
        templates: tensor whose row sums normalise the scores; defaults to
            ``self.templates`` (the argument was dereferenced even when None)
        returns the normalised dot products with ``self.templates`` [bs, T]
        """
        bs = m.shape[0]  # batch_size

        Xi = m  # [bs, emb]
        X = self.templates  # [T, emb]

        XXi = Xi @ X.T  # [bs, T]

        # normalize
        norm_source = X if templates is None else templates
        t_sum = norm_source.sum(1)  # [T]
        t_sum = t_sum.view(1, -1).expand(bs, -1)  # [bs, T]
        XXi = XXi / t_sum

        # beta is not strictly necessary because the model is not trained
        out = self.beta * XXi  # [bs, T] # softmax over dim=1
        return out
class Retrosim(StaticQK):
    """ Retrosim-like baseline only for template relevance prediction """

    def fit_with_train(self, X_fp_train, y_train):
        """
        Store the training fingerprints as "templates" and build the matrix
        that maps a training sample to its template.

        X_fp_train: numpy array of training fingerprints
        y_train: numpy array of template indices (one per training sample)
        """
        self.templates = torch.from_numpy(X_fp_train).float().to(self.config.device)
        # one_hot only accepts int64 indices, so cast (e.g. int32 arrays)
        labels = torch.from_numpy(y_train).long()
        # train_samples, num_templates
        self.sample2acttemplate = torch.nn.functional.one_hot(labels, self.config.num_templates).float()
        tmpnorm = self.sample2acttemplate.sum(0)
        tmpnorm[tmpnorm == 0] = 1  # avoid dividing by zero for unused templates
        self.sample2acttemplate = (self.sample2acttemplate / tmpnorm).to(self.config.device)  # results in an average after dot product

    def forward(self, m, templates=None):
        """
        Score ``m`` against the stored training samples and map those scores
        to templates via ``self.sample2acttemplate``.
        """
        out = super().forward(m, templates=templates)
        # bs, train_samples
        # map out to actual templates
        out = out @ self.sample2acttemplate
        return out