zhangzhi's picture
init commit
a476bbf verified
raw
history blame
No virus
11.1 kB
import torch
import hydra
from omegaconf import DictConfig, OmegaConf
from transformers.trainer import DataLoader
import torch.nn.functional as F
import os
import accelerate
from accelerate import Accelerator, DistributedType
from accelerate.utils import DistributedDataParallelKwargs
from accelerate.utils import LoggerType
from accelerate.local_sgd import LocalSGD
import esm
from protein_lm.modeling.getters.collate import (
DataCollatorWithPadding,
SequenceLengthSampler,
)
from accelerate.utils import set_seed
from protein_lm.modeling.getters.dataset import get_dataset
from protein_lm.modeling.getters.ptm_dataset import get_ptm_dataset
from protein_lm.modeling.getters.log import TrainLogger
from protein_lm.modeling.getters.mask import Masker
from protein_lm.modeling.getters.scheduler import Esm2LRScheduler
from protein_lm.modeling.models.mamba.lm import MambaLMHeadModel
from protein_lm.tokenizer.tokenizer import PTMTokenizer
def mlm_loss(outputs, input_ids, mask):
return F.cross_entropy(
outputs[mask],
input_ids[mask],
)
@torch.no_grad()
def compute_esm_embedding(tokenizer, esm_model, batch_converter, masked_input_ids):
device = masked_input_ids.device
esm_model = esm_model.to(device)
inputs = [
(i, tokenizer.decode(input_id.detach().cpu().tolist()))
for i, input_id in enumerate(masked_input_ids)
]
batch_labels, batch_strs, batch_tokens = batch_converter(inputs)
batch_tokens = batch_tokens[..., 1:-1].to(
device
) # remove <cls> and <eos> from ESM encoding
out = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
embedding = out["representations"][33]
return embedding
@hydra.main(
version_base=None, config_path="../../configs",
)
def main(config_dict: DictConfig):
config_dict = config_dict['train']
data_config = config_dict["dataset"]
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[kwargs])
device = accelerator.device
train_args = config_dict["training_arguments"]
set_seed(config_dict["seed"])
if "wandb" in config_dict.report_to:
import wandb
if accelerator.is_local_main_process:
wandb.init(
project="PTM-Mamba", config=dict(config_dict), name=train_args.save_dir
)
logger = wandb
else:
logger = TrainLogger()
tokenizer = PTMTokenizer()
dataset = get_dataset(
config_dict=data_config,
tokenizer=tokenizer,
)
if train_args.resume_from_checkpoint:
model = load_ckpt(train_args.resume_from_checkpoint, tokenizer, device)
accelerator.print(f"Model loaded from {train_args.resume_from_checkpoint}")
else:
config_dict.model.vocab_size = tokenizer.get_vocab_size()
model = MambaLMHeadModel(config=config_dict.model, device=device)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
accelerator.print(f"Number of parameters: {num_params:,}")
sampler = SequenceLengthSampler(
dataset["train"], train_args.sort_by_seq, train_args.sample_len_ascending)
train_loader = DataLoader(
dataset["train"],
batch_size=train_args.per_device_train_batch_size,
sampler=sampler,
collate_fn=DataCollatorWithPadding(
max_tokens=train_args.max_tokens_per_batch,
tokenizer=tokenizer,
batch_by_tokens=True,
max_seq_len=train_args.max_sequence_length,
),
num_workers=0,
pin_memory=True,
)
val_loader = DataLoader(
dataset["val"],
batch_size=train_args.per_device_train_batch_size // 2,
collate_fn=DataCollatorWithPadding(
max_tokens=train_args.max_tokens_per_batch,
tokenizer=tokenizer,
batch_by_tokens=False,
max_seq_len=train_args.max_sequence_length,
),
num_workers=0,
)
optimizer = torch.optim.AdamW(
model.parameters(), train_args.lr, betas=(0.9, 0.98), weight_decay=0.01
)
scheduler = Esm2LRScheduler(
optimizer, last_epoch=-1, init_lr=train_args.lr, on_use=False
)
masker = Masker(tokenizer)
model, optimizer, scheduler, train_loader, val_loader = accelerator.prepare(
model, optimizer, scheduler, train_loader, val_loader
)
train(
config_dict=config_dict,
model=model,
train_loader=train_loader,
val_loader=val_loader,
optimizer=optimizer,
scheduler=scheduler,
tokenizer=tokenizer,
masker=masker,
logger=logger,
accelerator=accelerator,
)
def load_ckpt(ckpt_path, tokenizer, device):
ckpt = torch.load(ckpt_path)
model_state_dict = ckpt["model"]
model_config = ckpt["config"]
model_config.vocab_size = tokenizer.get_vocab_size()
model = MambaLMHeadModel(config=model_config, device=device)
msg = model.load_state_dict(model_state_dict, strict=True)
print(msg)
return model
def train(
config_dict: DictConfig,
model,
train_loader,
val_loader,
optimizer,
scheduler,
tokenizer,
masker: Masker,
logger,
accelerator: Accelerator,
):
train_args = config_dict["training_arguments"]
save_dir = train_args.save_dir
device = accelerator.device
os.makedirs(save_dir, exist_ok=True)
best_ckpt_path = os.path.join(save_dir, "best.ckpt")
last_ckpt_path = os.path.join(save_dir, "last.ckpt")
best_loss = float("inf")
if train_args.use_esm:
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()
for param in esm_model.parameters():
param.requires_grad = False
model_to_save = model if accelerator.distributed_type==DistributedType.NO else model.module
masking_fn = masker.random_or_random_and_ptm_mask
total_steps = 0
for epoch in range(train_args.num_train_epochs):
for batch in train_loader:
optimizer.zero_grad()
input_ids = batch["input_ids"]
pad_mask = batch["pad_mask"]
esm_input_ids = make_esm_input_ids(input_ids, tokenizer)
# masking
masked_input_ids, pred_mask = masking_fn(input_ids)
masked_esm_input_ids = esm_input_ids.clone()
masked_esm_input_ids[pred_mask] = tokenizer.mask_token_id
if train_args.use_esm:
embedding = compute_esm_embedding(
tokenizer, esm_model, batch_converter, masked_esm_input_ids
)
else:
embedding = None
outputs = model(masked_input_ids, embedding=embedding)
logits = outputs.logits
loss = mlm_loss(logits, input_ids, pred_mask)
accelerator.backward(loss)
preplexity = torch.exp(loss)
acc = (logits.argmax(dim=-1) == input_ids)[pred_mask].float().mean()
ptm_acc = (
(logits.argmax(dim=-1) == input_ids)[
pred_mask & tokenizer.is_ptm_token(input_ids).to(device)
]
.float()
.mean()
)
if accelerator.is_local_main_process:
logger.log(
{
"train_loss": loss.item(),
"train_preplexity": preplexity.item(),
"train_acc": acc.item(),
"train_ptm_acc": ptm_acc.item(),
"act_bs": input_ids.shape[0],
"act_seq_len": input_ids.shape[1],
}
)
optimizer.step()
scheduler.step()
total_steps += 1
if total_steps % train_args.log_steps == 0:
model.eval()
for val_batch in val_loader:
with torch.no_grad():
input_ids = val_batch["input_ids"]
pad_mask = val_batch["pad_mask"]
esm_input_ids = make_esm_input_ids(input_ids, tokenizer)
masked_input_ids, pred_mask = masking_fn(input_ids)
masked_esm_input_ids = esm_input_ids.clone()
masked_esm_input_ids[pred_mask] = tokenizer.mask_token_id
if train_args.use_esm:
embedding = compute_esm_embedding(
tokenizer, esm_model, batch_converter, masked_esm_input_ids
)
else:
embedding = None
outputs = model(masked_input_ids, embedding=embedding)
logits = outputs.logits
loss = mlm_loss(logits, input_ids, pred_mask)
preplexity = torch.exp(loss)
acc = (logits.argmax(dim=-1) == input_ids)[pred_mask].float().mean()
ptm_acc = (
(logits.argmax(dim=-1) == input_ids)[
pred_mask & tokenizer.is_ptm_token(input_ids).to(device)
]
.float()
.mean()
)
if accelerator.is_local_main_process:
logger.log(
{
"Epoch": epoch,
"val_loss": loss.item(),
"val_preplexity": preplexity.item(),
"val_acc": acc.item(),
"val_ptm_acc": ptm_acc.item(),
}
)
if loss < best_loss:
best_loss = loss
torch.save(
{"model": model_to_save.state_dict(), "config": config_dict.model},
best_ckpt_path,
)
if accelerator.is_local_main_process:
torch.save(
{"model": model_to_save.state_dict(), "config": config_dict.model},
last_ckpt_path,
)
accelerator.print(f"Epoch {epoch}, Step {total_steps} finished!")
if accelerator.is_local_main_process:
torch.save(
{"model": model_to_save.state_dict(), "config": config_dict.model},
last_ckpt_path,
)
accelerator.print(f"Training completed!")
def make_esm_input_ids(input_ids, tokenizer,):
"""
Replace PTM tokens with mask token for ESM input
"""
device = input_ids.device
is_ptm_mask = tokenizer.is_ptm_token(input_ids).to(device)
esm_input_ids = input_ids.clone()
esm_input_ids[is_ptm_mask] = tokenizer.mask_token_id
return esm_input_ids
if __name__ == "__main__":
main()