# BarlowDTI/model/barlow_twins.py
import torch
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
from torch import nn
import numpy as np
from typing import Optional, Tuple
from datetime import datetime
import os
import pickle
import inspect
from tqdm.auto import trange
import spaces
from model.base_model import BaseModel
class BarlowTwins(BaseModel):
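    """
    Barlow Twins model for drug-target interaction (DTI) representation learning.

    Two MLP encoders map molecular fingerprints (enc_mol) and amino-acid/protein
    embeddings (enc_aa) into a shared latent space, and a shared MLP projector
    (proj) feeds the Barlow Twins cross-correlation loss, which pushes the
    diagonal of the batch cross-correlation matrix towards 1 and its
    off-diagonal entries towards 0.
    """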
def __init__(
self,
n_bits: int = 1024,
aa_emb_size: int = 1024,
enc_n_neurons: int = 512,
enc_n_layers: int = 2,
proj_n_neurons: int = 2048,
proj_n_layers: int = 2,
embedding_dim: int = 512,
act_function: str = "relu",
loss_weight: float = 0.005,
batch_size: int = 512,
optimizer: str = "adamw",
momentum: float = 0.9,
learning_rate: float = 0.0001,
betas: tuple = (0.9, 0.999),
weight_decay: float = 1e-3,
step_size: int = 10,
gamma: float = 0.1,
verbose: bool = True,
):
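        """
        Store all hyperparameters in self.param_dict, build the encoder and
        projector networks, and set up the optimizer and learning-rate
        scheduler (see the inline comments on param_dict for the meaning of
        each argument).
        """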
super().__init__()
self.enc_aa = None
self.enc_mol = None
self.proj = None
self.scheduler = None
self.optimizer = None
# store input in dict
self.param_dict = {
"act_function": self.activation_dict[
act_function
], # which activation function to use among dict options
"loss_weight": loss_weight, # off-diagonal cross correlation loss weight
"batch_size": batch_size, # samples per gradient step
"learning_rate": learning_rate, # update step magnitude when training
"betas": betas, # momentum hyperparameter for adam-like optimizers
"step_size": step_size, # decay period for the learning rate
"gamma": gamma, # decay coefficient for the learning rate
"optimizer": self.optimizer_dict[
optimizer
], # which optimizer to use among dict options
"momentum": momentum, # momentum hyperparameter for SGD
"enc_n_neurons": enc_n_neurons, # neurons to use for the mlp encoder
"enc_n_layers": enc_n_layers, # number of hidden layers in the mlp encoder
"proj_n_neurons": proj_n_neurons, # neurons to use for the mlp projector
"proj_n_layers": proj_n_layers, # number of hidden layers in the mlp projector
"embedding_dim": embedding_dim, # latent space dim for downstream tasks
"weight_decay": weight_decay, # l2 regularization for linear layers
"verbose": verbose, # whether to print feedback
"radius": "Not defined yet", # fingerprint radius
"n_bits": n_bits, # fingerprint bit size
"aa_emb_size": aa_emb_size, # aa embedding size
}
# create history dictionary
self.history = {
"train_loss": [],
"on_diag_loss": [],
"off_diag_loss": [],
"validation_loss": [],
}
# run NN architecture construction method
self.construct_model()
# run scheduler construction method
self.construct_scheduler()
# print if necessary
if self.param_dict["verbose"] is True:
self.print_config()
@staticmethod
def __validate_inputs(locals_dict) -> None:
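        """
        Check that every argument passed to __init__ is a known parameter and
        matches its annotated type.
        """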
# get signature types from __init__
init_signature = inspect.signature(BarlowTwins.__init__)
# loop over all chosen arguments
for param_name, param_value in locals_dict.items():
# skip self
if param_name != "self":
# check that parameter exists
if param_name in init_signature.parameters:
# check that param is correct type
expected_type = init_signature.parameters[param_name].annotation
assert isinstance(
param_value, expected_type
), f"[BT]: Type mismatch for parameter '{param_name}'"
else:
raise ValueError(f"[BT]: Unexpected parameter '{param_name}'")
def construct_mlp(self, input_units, layer_units, n_layers, output_units) -> nn.Sequential:
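        """
        Build an MLP with n_layers hidden blocks of layer_units neurons
        (linear -> batchnorm -> activation) followed by a final linear layer
        mapping to output_units.
        """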
        # make empty list to fill
        mlp_list = []
        # list of layer sizes: input_units followed by n_layers hidden layers
        units = [input_units] + [layer_units] * n_layers
        # add hidden layer stack (linear -> batchnorm -> activation)
        for i in range(len(units) - 1):
            mlp_list.append(nn.Linear(units[i], units[i + 1]))
            mlp_list.append(nn.BatchNorm1d(units[i + 1]))
            mlp_list.append(self.param_dict["act_function"]())
        # add final linear layer mapping to output_units
        mlp_list.append(nn.Linear(units[-1], output_units))
        return nn.Sequential(*mlp_list)
def construct_model(self) -> None:
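        """
        Build the fingerprint encoder (enc_mol), the amino-acid encoder
        (enc_aa) and the shared projector (proj) as MLPs.
        """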
# create fingerprint transformer
self.enc_mol = self.construct_mlp(
self.param_dict["n_bits"],
self.param_dict["enc_n_neurons"],
self.param_dict["enc_n_layers"],
self.param_dict["embedding_dim"],
)
# create aa transformer
self.enc_aa = self.construct_mlp(
self.param_dict["aa_emb_size"],
self.param_dict["enc_n_neurons"],
self.param_dict["enc_n_layers"],
self.param_dict["embedding_dim"],
)
# create mlp projector
self.proj = self.construct_mlp(
self.param_dict["embedding_dim"],
self.param_dict["proj_n_neurons"],
self.param_dict["proj_n_layers"],
self.param_dict["proj_n_neurons"],
)
# print if necessary
if self.param_dict["verbose"] is True:
print("[BT]: Model constructed successfully")
def construct_scheduler(self):
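        """
        Create the optimizer over all encoder and projector parameters and
        wrap it in a StepLR learning-rate scheduler.
        """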
# make optimizer
self.optimizer = self.param_dict["optimizer"](
list(self.enc_mol.parameters())
+ list(self.enc_aa.parameters())
+ list(self.proj.parameters()),
lr=self.param_dict["learning_rate"],
betas=self.param_dict["betas"],
# momentum=self.param_dict["momentum"],
weight_decay=self.param_dict["weight_decay"],
)
# wrap optimizer in scheduler
"""
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
T_max=self.param_dict["step_size"], # T_0
# eta_min=1e-7,
verbose=True
)
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
patience=self.param_dict["step_size"],
verbose=True
)
"""
self.scheduler = torch.optim.lr_scheduler.StepLR(
self.optimizer,
step_size=self.param_dict["step_size"],
gamma=self.param_dict["gamma"],
)
# print if necessary
if self.param_dict["verbose"] is True:
print("[BT]: Optimizer constructed successfully")
def switch_mode(self, is_training: bool):
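        """Put all submodules into training or evaluation mode."""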
if is_training:
self.enc_mol.train()
self.enc_aa.train()
self.proj.train()
else:
self.enc_mol.eval()
self.enc_aa.eval()
self.proj.eval()
    @staticmethod
    def normalize_projection(tensor: torch.Tensor) -> torch.Tensor:
        # standardize each projection dimension across the batch
        # (zero mean, unit standard deviation)
        means = torch.mean(tensor, dim=0)
        std = torch.std(tensor, dim=0)
        centered = torch.add(tensor, -means)
        scaled = torch.div(centered, std)
        return scaled
    def compute_loss(
        self,
        mol_embedding: torch.Tensor,
        aa_embedding: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
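        """
        Barlow Twins loss on the empirical cross-correlation matrix C between
        the normalized molecule and protein projections:

            on_diag  = sum_i (C_ii - 1)^2                  (invariance term)
            off_diag = loss_weight * sum_{i != j} C_ij^2   (redundancy reduction term)

        The two terms are returned separately so they can be logged individually.
        """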
        # empirical cross-correlation matrix between the two projections
        mol_embedding = self.normalize_projection(mol_embedding).T
        aa_embedding = self.normalize_projection(aa_embedding)
        c = mol_embedding @ aa_embedding
        # normalize by number of samples
        c.div_(self.param_dict["batch_size"])
        # invariance term: push diagonal elements towards 1
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        # redundancy reduction term: push off-diagonal elements towards 0;
        # the flatten/view trick selects every off-diagonal entry of the square matrix c
        n, m = c.shape
        off_diag = c.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
        off_diag = off_diag.pow_(2).sum() * self.param_dict["loss_weight"]
        return on_diag, off_diag
    def forward(
        self, mol_data: torch.Tensor, aa_data: torch.Tensor, is_training: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
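        """
        Encode a batch of fingerprint/protein pairs, project both embeddings
        with the shared projector, and return the on-diagonal and off-diagonal
        Barlow Twins loss terms.
        """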
# switch according to input
self.switch_mode(is_training)
# get embeddings
mol_embeddings = self.enc_mol(mol_data)
aa_embeddings = self.enc_aa(aa_data)
# get projections
mol_proj = self.proj(mol_embeddings)
aa_proj = self.proj(aa_embeddings)
# compute loss
on_diag, off_diag = self.compute_loss(mol_proj, aa_proj)
return on_diag, off_diag
    def train(
        self,
        train_data: torch.utils.data.DataLoader,
        val_data: Optional[torch.utils.data.DataLoader] = None,
        num_epochs: int = 20,
        patience: Optional[int] = None,
    ):
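        """
        Train for num_epochs epochs on train_data and, if val_data is given,
        track a validation loss and check the inherited early_stopping
        criterion each epoch; when patience is None it defaults to twice the
        scheduler step size.
        """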
if self.param_dict["verbose"] is True:
print("[BT]: Training started")
if patience is None:
patience = 2 * self.param_dict["step_size"]
pbar = trange(num_epochs, desc="[BT]: Epochs", leave=False, colour="blue")
for epoch in pbar:
# initialize loss containers
train_loss = 0.0
on_diag_loss = 0.0
off_diag_loss = 0.0
val_loss = 0.0
# loop over training set
for _, (mol_data, aa_data) in enumerate(train_data):
# reset grad
self.optimizer.zero_grad()
# compute train loss for batch
on_diag, off_diag = self.forward(mol_data, aa_data, is_training=True)
t_loss = on_diag + off_diag
# backpropagation and optimization
t_loss.backward()
"""
nn.utils.clip_grad_norm_(
list(self.enc_mol.parameters()) +
list(self.enc_aa.parameters()) +
list(self.proj.parameters()),
1
)
"""
self.optimizer.step()
# add i-th loss to training container
train_loss += t_loss.item()
on_diag_loss += on_diag.item()
off_diag_loss += off_diag.item()
# add mean epoch loss for train data to history dictionary
self.history["train_loss"].append(train_loss / len(train_data))
self.history["on_diag_loss"].append(on_diag_loss / len(train_data))
self.history["off_diag_loss"].append(off_diag_loss / len(train_data))
# define msg to be printed
msg = (
f"[BT]: Epoch [{epoch + 1}/{num_epochs}], "
f"Train loss: {train_loss / len(train_data):.3f}, "
f"On diagonal: {on_diag_loss / len(train_data):.3f}, "
f"Off diagonal: {off_diag_loss / len(train_data):.3f} "
)
# loop over validation set (if present)
if val_data is not None:
for _, (mol_data, aa_data) in enumerate(val_data):
# compute val loss for batch
on_diag_v_loss, off_diag_v_loss = self.forward(
mol_data, aa_data, is_training=False
)
# add i-th loss to val container
v_loss = on_diag_v_loss + off_diag_v_loss
val_loss += v_loss.item()
# add mean epoc loss for val data to history dictionary
self.history["validation_loss"].append(val_loss / len(val_data))
# add val loss to msg
msg += f", Val loss: {val_loss / len(val_data):.3f}"
# early stopping
if self.early_stopping(patience=patience):
break
pbar.set_postfix(
{
"train loss": train_loss / len(train_data),
"val loss": val_loss / len(val_data),
}
)
else:
pbar.set_postfix({"train loss": train_loss / len(train_data)})
# update scheduler
self.scheduler.step() # val_loss / len(val_data)
if self.param_dict["verbose"] is True:
print(msg)
if self.param_dict["verbose"] is True:
print("[BT]: Training finished")
@spaces.GPU
def encode(
self, vector: np.ndarray, mode: str = "embedding", normalize: bool = True, encoder: str = "mol"
) -> np.ndarray:
"""
Encodes a given vector using the Barlow Twins model.
Args:
- vector (np.ndarray): the input vector to encode
- mode (str): the mode to use for encoding, either "embedding" or "projection"
- normalize (bool): whether to L2 normalize the output vector
Returns:
- np.ndarray: the encoded vector
"""
# set mol encoder to eval mode
self.switch_mode(is_training=False)
# convert from numpy to tensor
if type(vector) is not torch.Tensor:
vector = torch.from_numpy(vector)
# if oly one molecule pair is passed, add a batch dimension
if len(vector.shape) == 1:
vector = vector.unsqueeze(0)
# get representation
if encoder == "mol":
embedding = self.enc_mol(vector)
if mode == "projection":
embedding = self.proj(embedding)
elif encoder == "aa":
embedding = self.enc_aa(vector)
if mode == "projection":
embedding = self.proj(embedding)
else:
raise ValueError("[BT]: Encoder not recognized")
# L2 normalize (optional)
if normalize:
embedding = torch.nn.functional.normalize(embedding)
# convert back to numpy
return embedding.cpu().detach().numpy()
@spaces.GPU
def zero_shot(
self, mol_vector: np.ndarray, aa_vector: np.ndarray, l2_norm: bool = True, device: str = "cpu"
) -> np.ndarray:
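        """
        Encode a molecule/protein pair with the two encoders and return the
        concatenation of the two embeddings, which serves as the joint
        representation for zero-shot prediction.
        """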
# disable training
self.switch_mode(is_training=False)
        # cast mol and aa vectors to numpy arrays, forcing single precision
        mol_vector = np.array(mol_vector, dtype=np.float32)
        aa_vector = np.array(aa_vector, dtype=np.float32)
# convert to tensors
mol_vector = torch.from_numpy(mol_vector).to(device)
aa_vector = torch.from_numpy(aa_vector).to(device)
# get embeddings
mol_embedding = self.encode(mol_vector, normalize=l2_norm, encoder="mol")
aa_embedding = self.encode(aa_vector, normalize=l2_norm, encoder="aa")
# concat mol and aa embeddings
concat = np.concatenate((mol_embedding, aa_embedding), axis=1)
return concat
    def zero_shot_explain(
        self, mol_vector, aa_vector, l2_norm: bool = True, device: str = "cpu"
    ):
        # disable training
        self.switch_mode(is_training=False)
        mol_embedding = self.encode(mol_vector, normalize=l2_norm, encoder="mol")
        aa_embedding = self.encode(aa_vector, normalize=l2_norm, encoder="aa")
        # encode() returns numpy arrays, so convert back to tensors before
        # concatenating along the feature dimension
        return torch.cat(
            (torch.from_numpy(mol_embedding), torch.from_numpy(aa_embedding)), dim=1
        )
def consume_preprocessor(self, preprocessor) -> None:
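        """Copy the fingerprint settings (radius, n_bits) of a preprocessor into param_dict."""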
# save attributes related to fingerprint generation from
# preprocessor object
self.param_dict["radius"] = preprocessor.radius
self.param_dict["n_bits"] = preprocessor.n_bits
def save_model(self, path: str) -> None:
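        """
        Save encoder/projector weights, the parameter dictionary and the
        training history to a new timestamped subfolder of `path`.
        """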
# get current date and time for the filename
now = datetime.now()
formatted_date = now.strftime("%d%m%Y")
formatted_time = now.strftime("%H%M")
folder_name = f"{formatted_date}_{formatted_time}"
# make full path string and folder
folder_path = path + "/" + folder_name
os.makedirs(folder_path)
# make paths for weights, config and history
weight_path = folder_path + "/weights.pt"
param_path = folder_path + "/params.pkl"
history_path = folder_path + "/history.json"
# save each Sequential state dict in one object to the path
torch.save(
{
"enc_mol": self.enc_mol.state_dict(),
"enc_aa": self.enc_aa.state_dict(),
"proj": self.proj.state_dict(),
},
weight_path,
)
# dump params in pkl
with open(param_path, "wb") as file:
pickle.dump(self.param_dict, file)
        # dump history with pickle (note: the file keeps the .json extension)
        with open(history_path, "wb") as file:
            pickle.dump(self.history, file)
# print if verbose is True
if self.param_dict["verbose"] is True:
print(f"[BT]: Model saved at {folder_path}")
@spaces.GPU
def load_model(self, path: str) -> None:
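        """
        Load weights, parameters and training history produced by save_model
        from `path`, rebuild the networks, and recreate optimizer and
        scheduler so they reference the loaded weights.
        """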
# make weights, config and history paths
weights_path = path + "/weights.pt"
param_path = path + "/params.pkl"
history_path = path + "/history.json"
# load weights, history and params
checkpoint = torch.load(weights_path, map_location=self.device)
with open(param_path, "rb") as file:
param_dict = pickle.load(file)
with open(history_path, "rb") as file:
history = pickle.load(file)
# construct model again, overriding old verbose key with new instance
verbose = self.param_dict["verbose"]
self.param_dict = param_dict
self.param_dict["verbose"] = verbose
self.history = history
self.construct_model()
# set weights in Sequential models
self.enc_mol.load_state_dict(checkpoint["enc_mol"])
self.enc_aa.load_state_dict(checkpoint["enc_aa"])
self.proj.load_state_dict(checkpoint["proj"])
# recreate scheduler and optimizer in order to add new weights
# to graph
self.construct_scheduler()
# print if verbose is True
if self.param_dict["verbose"] is True:
print(f"[BT]: Model loaded from {path}")
print("[BT]: Loaded parameters:")
print(self.param_dict)
@spaces.GPU
def move_to_device(self, device) -> None:
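        """Move all submodules to `device` and remember it for later use (e.g. in load_model)."""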
# move each Sequential model to device
self.enc_mol.to(device)
self.enc_aa.to(device)
self.proj.to(device)
self.device = device