import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import hydra
from omegaconf import DictConfig, OmegaConf

import accelerate
from accelerate import Accelerator, DistributedType
from accelerate.local_sgd import LocalSGD
from accelerate.utils import (
    DistributedDataParallelKwargs,
    LoggerType,
    set_seed,
)

import esm

from protein_lm.modeling.getters.collate import (
    DataCollatorWithPadding,
    SequenceLengthSampler,
)
from protein_lm.modeling.getters.dataset import get_dataset
from protein_lm.modeling.getters.ptm_dataset import get_ptm_dataset
from protein_lm.modeling.getters.log import TrainLogger
from protein_lm.modeling.getters.mask import Masker
from protein_lm.modeling.getters.scheduler import Esm2LRScheduler
from protein_lm.modeling.models.mamba.lm import MambaLMHeadModel
from protein_lm.tokenizer.tokenizer import PTMTokenizer


def mlm_loss(outputs, input_ids, mask):
    """Cross-entropy loss evaluated only at the masked (prediction) positions."""
    return F.cross_entropy(
        outputs[mask],
        input_ids[mask],
    )

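# Shape conventions assumed in this file: logits are (batch, seq_len, vocab_size),
# input_ids and the boolean masks are (batch, seq_len); indexing logits/input_ids with
# a boolean mask flattens the selected positions before the cross-entropy is taken.
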
@torch.no_grad()
def compute_esm_embedding(tokenizer, esm_model, batch_converter, masked_input_ids):
    """Embed the masked sequences with a frozen ESM-2 model."""
    device = masked_input_ids.device
    esm_model = esm_model.to(device)
    # Decode ids back to strings so the ESM batch converter can re-tokenize them.
    inputs = [
        (i, tokenizer.decode(input_id.detach().cpu().tolist()))
        for i, input_id in enumerate(masked_input_ids)
    ]
    batch_labels, batch_strs, batch_tokens = batch_converter(inputs)
    # Drop the BOS/EOS tokens added by the batch converter so that positions align
    # with the original input ids.
    batch_tokens = batch_tokens[..., 1:-1].to(device)
    out = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
    embedding = out["representations"][33]
    return embedding

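# Layer 33 is the final layer of esm2_t33_650M_UR50D, so the returned embedding holds
# per-residue representations (hidden size 1280 for this checkpoint).
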
# hydra.main is given no config_name, so the primary config is expected to be chosen
# on the command line (e.g. via --config-name).
@hydra.main(
    version_base=None, config_path="../../configs",
)
def main(config_dict: DictConfig):
    config_dict = config_dict["train"]
    data_config = config_dict["dataset"]
    # find_unused_parameters=True lets DDP tolerate parameters that receive no gradient
    # on a given step (e.g. when the ESM-conditioning path is disabled).
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    device = accelerator.device
    train_args = config_dict["training_arguments"]
    set_seed(config_dict["seed"])

    # Log to wandb only from the local main process; every other case falls back to the
    # lightweight TrainLogger, so `logger` is always defined.
    if "wandb" in config_dict.report_to and accelerator.is_local_main_process:
        import wandb

        wandb.init(
            project="PTM-Mamba", config=dict(config_dict), name=train_args.save_dir
        )
        logger = wandb
    else:
        logger = TrainLogger()

    tokenizer = PTMTokenizer()
    dataset = get_dataset(
        config_dict=data_config,
        tokenizer=tokenizer,
    )

    if train_args.resume_from_checkpoint:
        model = load_ckpt(train_args.resume_from_checkpoint, tokenizer, device)
        accelerator.print(f"Model loaded from {train_args.resume_from_checkpoint}")
    else:
        config_dict.model.vocab_size = tokenizer.get_vocab_size()
        model = MambaLMHeadModel(config=config_dict.model, device=device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    accelerator.print(f"Number of parameters: {num_params:,}")

    sampler = SequenceLengthSampler(
        dataset["train"], train_args.sort_by_seq, train_args.sample_len_ascending
    )
    train_loader = DataLoader(
        dataset["train"],
        batch_size=train_args.per_device_train_batch_size,
        sampler=sampler,
        collate_fn=DataCollatorWithPadding(
            max_tokens=train_args.max_tokens_per_batch,
            tokenizer=tokenizer,
            batch_by_tokens=True,
            max_seq_len=train_args.max_sequence_length,
        ),
        num_workers=0,
        pin_memory=True,
    )
    val_loader = DataLoader(
        dataset["val"],
        batch_size=train_args.per_device_train_batch_size // 2,
        collate_fn=DataCollatorWithPadding(
            max_tokens=train_args.max_tokens_per_batch,
            tokenizer=tokenizer,
            batch_by_tokens=False,
            max_seq_len=train_args.max_sequence_length,
        ),
        num_workers=0,
    )
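    # The collator is assumed to cap each training batch by total token count
    # (batch_by_tokens=True with max_tokens_per_batch), so per_device_train_batch_size
    # acts as an upper bound on sequences per batch; validation uses plain fixed-size
    # batches (batch_by_tokens=False).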
    optimizer = torch.optim.AdamW(
        model.parameters(), train_args.lr, betas=(0.9, 0.98), weight_decay=0.01
    )
    scheduler = Esm2LRScheduler(
        optimizer, last_epoch=-1, init_lr=train_args.lr, on_use=False
    )

    masker = Masker(tokenizer)
    model, optimizer, scheduler, train_loader, val_loader = accelerator.prepare(
        model, optimizer, scheduler, train_loader, val_loader
    )

    train(
        config_dict=config_dict,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        tokenizer=tokenizer,
        masker=masker,
        logger=logger,
        accelerator=accelerator,
    )

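# Illustrative launch command (script path and config name depend on the repo layout):
#   accelerate launch <path/to/this/script>.py --config-name <train_config>
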
def load_ckpt(ckpt_path, tokenizer, device):
    """Rebuild a MambaLMHeadModel from a checkpoint saved by train()."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model_state_dict = ckpt["model"]
    model_config = ckpt["config"]
    model_config.vocab_size = tokenizer.get_vocab_size()
    model = MambaLMHeadModel(config=model_config, device=device)
    msg = model.load_state_dict(model_state_dict, strict=True)
    print(msg)
    return model

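# Training recipe: PTM tokens are hidden from ESM (see make_esm_input_ids), additional
# positions are masked by the Masker, a frozen ESM-2 embeds the masked sequence, and the
# Mamba LM is trained to recover the original tokens at the masked positions (mlm_loss).
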
def train(
    config_dict: DictConfig,
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    tokenizer,
    masker: Masker,
    logger,
    accelerator: Accelerator,
):
    train_args = config_dict["training_arguments"]
    save_dir = train_args.save_dir
    device = accelerator.device
    os.makedirs(save_dir, exist_ok=True)
    best_ckpt_path = os.path.join(save_dir, "best.ckpt")
    last_ckpt_path = os.path.join(save_dir, "last.ckpt")
    best_loss = float("inf")

    if train_args.use_esm:
        # ESM-2 is used only as a frozen feature extractor.
        esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        esm_model.eval()
        for param in esm_model.parameters():
            param.requires_grad = False

    # accelerator.prepare wraps the model in DDP when running distributed, so unwrap it
    # before saving state dicts.
    model_to_save = (
        model if accelerator.distributed_type == DistributedType.NO else model.module
    )
    masking_fn = masker.random_or_random_and_ptm_mask
    total_steps = 0
    for epoch in range(train_args.num_train_epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"]
            pad_mask = batch["pad_mask"]
            # Copy of the inputs with PTM tokens replaced by the mask token for ESM.
            esm_input_ids = make_esm_input_ids(input_ids, tokenizer)

            masked_input_ids, pred_mask = masking_fn(input_ids)
            masked_esm_input_ids = esm_input_ids.clone()
            masked_esm_input_ids[pred_mask] = tokenizer.mask_token_id
            if train_args.use_esm:
                embedding = compute_esm_embedding(
                    tokenizer, esm_model, batch_converter, masked_esm_input_ids
                )
            else:
                embedding = None
            outputs = model(masked_input_ids, embedding=embedding)
            logits = outputs.logits
            loss = mlm_loss(logits, input_ids, pred_mask)
            accelerator.backward(loss)
            perplexity = torch.exp(loss)
            acc = (logits.argmax(dim=-1) == input_ids)[pred_mask].float().mean()
            # Accuracy restricted to masked positions that hold PTM tokens.
            ptm_acc = (
                (logits.argmax(dim=-1) == input_ids)[
                    pred_mask & tokenizer.is_ptm_token(input_ids).to(device)
                ]
                .float()
                .mean()
            )
            if accelerator.is_local_main_process:
                logger.log(
                    {
                        "train_loss": loss.item(),
                        "train_perplexity": perplexity.item(),
                        "train_acc": acc.item(),
                        "train_ptm_acc": ptm_acc.item(),
                        "act_bs": input_ids.shape[0],
                        "act_seq_len": input_ids.shape[1],
                    }
                )
            optimizer.step()
            scheduler.step()
            total_steps += 1
            if total_steps % train_args.log_steps == 0:
                model.eval()
                for val_batch in val_loader:
                    with torch.no_grad():
                        input_ids = val_batch["input_ids"]
                        pad_mask = val_batch["pad_mask"]
                        esm_input_ids = make_esm_input_ids(input_ids, tokenizer)
                        masked_input_ids, pred_mask = masking_fn(input_ids)
                        masked_esm_input_ids = esm_input_ids.clone()
                        masked_esm_input_ids[pred_mask] = tokenizer.mask_token_id
                        if train_args.use_esm:
                            embedding = compute_esm_embedding(
                                tokenizer, esm_model, batch_converter, masked_esm_input_ids
                            )
                        else:
                            embedding = None
                        outputs = model(masked_input_ids, embedding=embedding)
                        logits = outputs.logits
                        loss = mlm_loss(logits, input_ids, pred_mask)
                        perplexity = torch.exp(loss)
                        acc = (
                            (logits.argmax(dim=-1) == input_ids)[pred_mask].float().mean()
                        )
                        ptm_acc = (
                            (logits.argmax(dim=-1) == input_ids)[
                                pred_mask & tokenizer.is_ptm_token(input_ids).to(device)
                            ]
                            .float()
                            .mean()
                        )

                    if accelerator.is_local_main_process:
                        logger.log(
                            {
                                "Epoch": epoch,
                                "val_loss": loss.item(),
                                "val_perplexity": perplexity.item(),
                                "val_acc": acc.item(),
                                "val_ptm_acc": ptm_acc.item(),
                            }
                        )
                        # Track the best per-batch validation loss and checkpoint on it.
                        if loss < best_loss:
                            best_loss = loss
                            torch.save(
                                {"model": model_to_save.state_dict(), "config": config_dict.model},
                                best_ckpt_path,
                            )
                if accelerator.is_local_main_process:
                    torch.save(
                        {"model": model_to_save.state_dict(), "config": config_dict.model},
                        last_ckpt_path,
                    )
                # Restore training mode after the validation pass.
                model.train()
                accelerator.print(f"Epoch {epoch}, Step {total_steps} finished!")

    if accelerator.is_local_main_process:
        torch.save(
            {"model": model_to_save.state_dict(), "config": config_dict.model},
            last_ckpt_path,
        )
    accelerator.print("Training completed!")


def make_esm_input_ids(input_ids, tokenizer):
    """Replace PTM tokens with the mask token so the sequence is valid ESM input."""
    device = input_ids.device
    is_ptm_mask = tokenizer.is_ptm_token(input_ids).to(device)
    esm_input_ids = input_ids.clone()
    esm_input_ids[is_ptm_mask] = tokenizer.mask_token_id
    return esm_input_ids

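# Illustrative example: if position k of input_ids holds a PTM token (a modified
# residue), the ESM copy sees tokenizer.mask_token_id at position k, while the Mamba
# model is still trained against the original PTM token at that position.
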
if __name__ == "__main__":
    main()