import os

import esm
import hydra
import torch
import torch.nn.functional as F
from accelerate import Accelerator, DistributedType
from accelerate.utils import DistributedDataParallelKwargs, set_seed
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from protein_lm.modeling.getters.collate import (
    DataCollatorWithPadding,
    SequenceLengthSampler,
)
from protein_lm.modeling.getters.dataset import get_dataset
from protein_lm.modeling.getters.log import TrainLogger
from protein_lm.modeling.getters.mask import Masker
from protein_lm.modeling.getters.scheduler import Esm2LRScheduler
from protein_lm.modeling.models.mamba.lm import MambaLMHeadModel
from protein_lm.tokenizer.tokenizer import PTMTokenizer


def mlm_loss(outputs, input_ids, mask):
    """Cross-entropy computed over the masked positions only."""
    return F.cross_entropy(
        outputs[mask],
        input_ids[mask],
    )


@torch.no_grad()
def compute_esm_embedding(tokenizer, esm_model, batch_converter, masked_input_ids):
    """Embed the (masked) sequences with a frozen ESM-2 model."""
    device = masked_input_ids.device
    esm_model = esm_model.to(device)
    # Decode each row back to a string so it can be re-tokenized with the ESM alphabet.
    inputs = [
        (i, tokenizer.decode(input_id.detach().cpu().tolist()))
        for i, input_id in enumerate(masked_input_ids)
    ]
    _, _, batch_tokens = batch_converter(inputs)
    # Drop the BOS/EOS tokens added by the ESM batch converter so positions
    # line up with the PTM tokenizer's sequence.
    batch_tokens = batch_tokens[..., 1:-1].to(device)
    out = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
    # Final-layer (33) per-residue representations of esm2_t33_650M.
    embedding = out["representations"][33]
    return embedding


@hydra.main(version_base=None, config_path="../../configs")
def main(config_dict: DictConfig):
    config_dict = config_dict["train"]
    data_config = config_dict["dataset"]
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    device = accelerator.device
    train_args = config_dict["training_arguments"]
    set_seed(config_dict["seed"])
    if "wandb" in config_dict.report_to:
        import wandb

        if accelerator.is_local_main_process:
            wandb.init(
                project="PTM-Mamba", config=dict(config_dict), name=train_args.save_dir
            )
        logger = wandb
    else:
        logger = TrainLogger()

    tokenizer = PTMTokenizer()
    dataset = get_dataset(
        config_dict=data_config,
        tokenizer=tokenizer,
    )

    if train_args.resume_from_checkpoint:
        model = load_ckpt(train_args.resume_from_checkpoint, tokenizer, device)
        accelerator.print(f"Model loaded from {train_args.resume_from_checkpoint}")
    else:
        config_dict.model.vocab_size = tokenizer.get_vocab_size()
        model = MambaLMHeadModel(config=config_dict.model, device=device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    accelerator.print(f"Number of parameters: {num_params:,}")

    # Training batches are packed by a token budget; validation uses plain
    # fixed-size batches.
    sampler = SequenceLengthSampler(
        dataset["train"], train_args.sort_by_seq, train_args.sample_len_ascending
    )
    train_loader = DataLoader(
        dataset["train"],
        batch_size=train_args.per_device_train_batch_size,
        sampler=sampler,
        collate_fn=DataCollatorWithPadding(
            max_tokens=train_args.max_tokens_per_batch,
            tokenizer=tokenizer,
            batch_by_tokens=True,
            max_seq_len=train_args.max_sequence_length,
        ),
        num_workers=0,
        pin_memory=True,
    )
    val_loader = DataLoader(
        dataset["val"],
        batch_size=train_args.per_device_train_batch_size // 2,
        collate_fn=DataCollatorWithPadding(
            max_tokens=train_args.max_tokens_per_batch,
            tokenizer=tokenizer,
            batch_by_tokens=False,
            max_seq_len=train_args.max_sequence_length,
        ),
        num_workers=0,
    )
    optimizer = torch.optim.AdamW(
        model.parameters(), train_args.lr, betas=(0.9, 0.98), weight_decay=0.01
    )
    scheduler = Esm2LRScheduler(
        optimizer, last_epoch=-1, init_lr=train_args.lr, on_use=False
    )

    masker = Masker(tokenizer)
    model, optimizer, scheduler, train_loader, val_loader = accelerator.prepare(
        model, optimizer, scheduler, train_loader, val_loader
    )

    train(
        config_dict=config_dict,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        tokenizer=tokenizer,
        masker=masker,
        logger=logger,
        accelerator=accelerator,
    )


def load_ckpt(ckpt_path, tokenizer, device):
    """Rebuild a MambaLMHeadModel from a checkpoint saved by train()."""
    # Load to CPU first; load_state_dict copies the weights onto `device`.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model_state_dict = ckpt["model"]
    model_config = ckpt["config"]
    model_config.vocab_size = tokenizer.get_vocab_size()
    model = MambaLMHeadModel(config=model_config, device=device)
    msg = model.load_state_dict(model_state_dict, strict=True)
    print(msg)
    return model


def train(
    config_dict: DictConfig,
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    tokenizer,
    masker: Masker,
    logger,
    accelerator: Accelerator,
):
    train_args = config_dict["training_arguments"]
    save_dir = train_args.save_dir
    device = accelerator.device
    os.makedirs(save_dir, exist_ok=True)
    best_ckpt_path = os.path.join(save_dir, "best.ckpt")
    last_ckpt_path = os.path.join(save_dir, "last.ckpt")
    best_loss = float("inf")

    if train_args.use_esm:
        # Frozen ESM-2 650M encoder, used only to provide input embeddings.
        esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        esm_model.eval()
        for param in esm_model.parameters():
            param.requires_grad = False
    model_to_save = (
        model if accelerator.distributed_type == DistributedType.NO else model.module
    )
    masking_fn = masker.random_or_random_and_ptm_mask
    total_steps = 0
    for epoch in range(train_args.num_train_epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"]
            pad_mask = batch["pad_mask"]
            esm_input_ids = make_esm_input_ids(input_ids, tokenizer)

            masked_input_ids, pred_mask = masking_fn(input_ids)
            masked_esm_input_ids = esm_input_ids.clone()
            masked_esm_input_ids[pred_mask] = tokenizer.mask_token_id
            if train_args.use_esm:
                embedding = compute_esm_embedding(
                    tokenizer, esm_model, batch_converter, masked_esm_input_ids
                )
            else:
                embedding = None
            outputs = model(masked_input_ids, embedding=embedding)
            logits = outputs.logits
            loss = mlm_loss(logits, input_ids, pred_mask)
            accelerator.backward(loss)
            perplexity = torch.exp(loss)
            acc = (logits.argmax(dim=-1) == input_ids)[pred_mask].float().mean()
            # Accuracy restricted to masked positions that hold PTM tokens.
            ptm_acc = (
                (logits.argmax(dim=-1) == input_ids)[
                    pred_mask & tokenizer.is_ptm_token(input_ids).to(device)
                ]
                .float()
                .mean()
            )
            if accelerator.is_local_main_process:
                logger.log(
                    {
                        "train_loss": loss.item(),
                        "train_perplexity": perplexity.item(),
                        "train_acc": acc.item(),
                        "train_ptm_acc": ptm_acc.item(),
                        "act_bs": input_ids.shape[0],
                        "act_seq_len": input_ids.shape[1],
                    }
                )
            optimizer.step()
            scheduler.step()
            total_steps += 1
            if total_steps % train_args.log_steps == 0:
                model.eval()
                for val_batch in val_loader:
                    with torch.no_grad():
                        input_ids = val_batch["input_ids"]
                        pad_mask = val_batch["pad_mask"]
                        esm_input_ids = make_esm_input_ids(input_ids, tokenizer)
                        masked_input_ids, pred_mask = masking_fn(input_ids)
                        masked_esm_input_ids = esm_input_ids.clone()
                        masked_esm_input_ids[pred_mask] = tokenizer.mask_token_id
                        if train_args.use_esm:
                            embedding = compute_esm_embedding(
                                tokenizer, esm_model, batch_converter, masked_esm_input_ids
                            )
                        else:
                            embedding = None
                        outputs = model(masked_input_ids, embedding=embedding)
                        logits = outputs.logits
                        loss = mlm_loss(logits, input_ids, pred_mask)
                        perplexity = torch.exp(loss)
                        acc = (logits.argmax(dim=-1) == input_ids)[pred_mask].float().mean()
                        ptm_acc = (
                            (logits.argmax(dim=-1) == input_ids)[
                                pred_mask & tokenizer.is_ptm_token(input_ids).to(device)
                            ]
                            .float()
                            .mean()
                        )

                if accelerator.is_local_main_process:
                    logger.log(
                        {
                            "Epoch": epoch,
                            "val_loss": loss.item(),
                            "val_perplexity": perplexity.item(),
                            "val_acc": acc.item(),
                            "val_ptm_acc": ptm_acc.item(),
                        }
                    )
                if loss < best_loss:
                    best_loss = loss
                    torch.save(
                        {"model": model_to_save.state_dict(), "config": config_dict.model},
                        best_ckpt_path,
                    )
                if accelerator.is_local_main_process:
                    torch.save(
                        {"model": model_to_save.state_dict(), "config": config_dict.model},
                        last_ckpt_path,
                    )
                # Restore training mode after the validation pass.
                model.train()
                accelerator.print(f"Epoch {epoch}, Step {total_steps} finished!")

    if accelerator.is_local_main_process:
        torch.save(
            {"model": model_to_save.state_dict(), "config": config_dict.model},
            last_ckpt_path,
        )
    accelerator.print("Training completed!")


def make_esm_input_ids(input_ids, tokenizer):
    """Replace PTM tokens with the mask token so the sequence is valid ESM input."""
    device = input_ids.device
    is_ptm_mask = tokenizer.is_ptm_token(input_ids).to(device)
    esm_input_ids = input_ids.clone()
    esm_input_ids[is_ptm_mask] = tokenizer.mask_token_id
    return esm_input_ids
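
# Example launch (a sketch; the exact entry-point path and config name depend on
# this repo's layout and are not given in the script itself). Since @hydra.main
# above does not set config_name, Hydra expects one on the command line, e.g.:
#   accelerate launch <path/to/this/script.py> --config-name <config with a "train" key>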


if __name__ == "__main__":
    main()