Fill-Mask
Transformers
PyTorch
esm
Inference Endpoints
svincoff's picture
uploaded training code and model weights
9a73cb0
'''
This is a training script for finetuning ESM.
I am going to freeze the parameters in the head and unfreeze the last N layers in the model.
'''
import os
import fuson_plm.training.config as config
# Set the WANDB_API_KEY environment variable (read by wandb when wandb.init is called later)
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
# Restrict which GPUs are visible. This is deliberately executed before `import torch`
# below, presumably so CUDA device enumeration honours it — keep it above that import.
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
import torch
import numpy as np
import pandas as pd
import tqdm
from datetime import datetime
import wandb
import pytz
import sys
from transformers import AdamW
from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
from fuson_plm.training.model import FusOnpLM
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
from fuson_plm.training.plot import make_train_val_test_bd_plot
def _count_params(module, trainable_only=False):
    """Return the number of scalar parameters in `module` (only those with requires_grad if trainable_only)."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad or not trainable_only)


def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    """
    Freeze the whole model, then unfreeze its final `n_unfrozen_layers` layers, logging parameter counts.

    Freezing/unfreezing is delegated to `model.freeze_model()` and
    `model.unfreeze_last_n_layers(...)`; this function only orchestrates the calls
    and logs the state before and after each step. The model is modified in place.

    Args:
        model: FusOnpLM wrapper exposing `count_encoder_layers`, `freeze_model`,
            `unfreeze_last_n_layers`, and the `lm_head` / `esm` submodules.
        n_unfrozen_layers (int): number of final layers to unfreeze.
        unfreeze_query, unfreeze_key, unfreeze_value (bool): forwarded to
            `unfreeze_last_n_layers` to control which attention projections are unfrozen.
    """
    # Log the model's initial state
    n_layers = model.count_encoder_layers()
    log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelforMaskedLM model: {_count_params(model)}')
    log_update(f'\tTotal parameters in the MLM Head ONLY: {_count_params(model.lm_head)}')

    # Freeze the model to start.
    # Reuse the layer count computed above rather than a separate `model.n_layers`
    # attribute, so the two log lines cannot disagree.
    model.freeze_model()
    log_update(f'Froze all {n_layers} model layers')
    log_update(f'\tTrainable params: {_count_params(model, trainable_only=True)}')

    # Unfreeze the last n layers
    model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query,
                                 unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    trainable_names = '\n\t\t'.join(name for name, param in model.named_parameters() if param.requires_grad)
    log_update(f'Unfroze final {n_unfrozen_layers} layers')
    log_update(f'\tTrainable params: {_count_params(model, trainable_only=True)}\n\t\t{trainable_names}')
    log_update(f"\tTrainable parameters in the lm_head: {_count_params(model.lm_head, trainable_only=True)}")
    log_update(f"\tTrainable params in the ESM part: {_count_params(model.esm, trainable_only=True)}")
def _append_stats_csv(stats_df, csv_path):
    """Append `stats_df` to `csv_path`, writing the column header only when the file does not exist yet."""
    if os.path.exists(csv_path):
        stats_df.to_csv(csv_path, index=False, header=False, mode='a')
    else:
        stats_df.to_csv(csv_path, index=False)


def _summarize_losses(total_loss, total_weighted_loss, total_masked_tokens, n_batches):
    """Return (avg batch loss, avg masked-token-weighted loss, perplexity).

    A zero denominator (empty loader, or no tokens were masked) yields NaN rather
    than raising ZeroDivisionError partway through a run.
    """
    avg_loss = total_loss / n_batches if n_batches else float('nan')
    avg_weighted_loss = total_weighted_loss / total_masked_tokens if total_masked_tokens else float('nan')
    return avg_loss, avg_weighted_loss, np.exp(avg_weighted_loss)


def _run_train_epoch(model, tokenizer, optimizer, train_loader, mask_percentage, mask_rate_scheduler, device):
    """Run one optimisation pass over `train_loader`.

    Returns (sum of batch losses, sum of batch loss * masked-token count, total masked tokens).
    """
    model.train()
    total_loss = 0
    total_weighted_loss = 0
    total_masked_tokens = 0
    with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
        for batch_idx, (inputs, prob) in pbar:
            # Fixed masking rate unless a scheduler is ramping it within the epoch
            masking_rate = mask_percentage
            if mask_rate_scheduler is not None:
                mask_rate_scheduler.step()
                masking_rate = mask_rate_scheduler.get_masking_rate()
                log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")
            # Move tensors
            inputs = {k: v.to(device) for k, v in inputs.items()}
            prob = prob.to(device)
            # Mask based on probability vectors
            masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)
            # Forward pass and update
            optimizer.zero_grad()
            loss = model(**masked_inputs).loss
            loss.backward()
            optimizer.step()
            # Weight each batch loss by its number of masked tokens so the epoch
            # average is per masked token, not per batch
            num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
            batch_loss = loss.item()
            total_loss += batch_loss
            total_weighted_loss += batch_loss * num_masked_tokens
            total_masked_tokens += num_masked_tokens
            wandb.log({"batch_loss": batch_loss})
    return total_loss, total_weighted_loss, total_masked_tokens


def _run_val_epoch(model, tokenizer, val_loader, device):
    """Evaluate on `val_loader` with a FIXED 15% masking rate (no gradients).

    Returns (sum of batch losses, sum of batch loss * masked-token count, total masked tokens).
    """
    model.eval()
    total_loss = 0
    total_weighted_loss = 0
    total_masked_tokens = 0
    with torch.no_grad():  # No gradients needed
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
            for _, (inputs, prob) in vbar:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)
                ## FIXED 15% masking for the validation set
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)
                val_loss = model(**masked_inputs).loss.item()
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
                total_loss += val_loss
                total_weighted_loss += val_loss * num_masked_tokens
                total_masked_tokens += num_masked_tokens
    return total_loss, total_weighted_loss, total_masked_tokens


def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Train the model for `n_epochs` epochs, validating after each one.

    Per epoch (numbered start_epoch .. start_epoch + n_epochs - 1):
      1. reset the masking-rate scheduler (if any) and run a training pass,
      2. save a checkpoint to `<checkpoint_dir>/checkpoint_epoch_<epoch>`,
      3. log train loss/perplexity to wandb and append them to `train_curve.csv`,
      4. run a validation pass at a fixed 15% masking rate and log/append to `val_curve.csv`.

    Args:
        model: model wrapper whose forward returns an object with `.loss`; must provide `save_model`.
        tokenizer: tokenizer providing `mask_token_id`.
        optimizer: torch optimizer over the trainable parameters.
        train_loader, val_loader: iterables yielding `(inputs_dict, prob)` batches.
        n_epochs (int): number of epochs to run.
        start_epoch (int): number of the first epoch (used for checkpoint names when resuming).
        mask_percentage (float): masking rate used when no scheduler is given.
        mask_rate_scheduler: optional object with `reset`, `step`, `get_masking_rate`.
        device: device the batches are moved to.
        checkpoint_dir (str): directory for checkpoints and curve CSVs.

    Averages whose denominator is zero (empty loader / no masked tokens) are logged as NaN.
    """
    log_update("\n")
    for epoch in range(start_epoch, start_epoch + n_epochs):
        if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset()  # ramp the masking rate up again every epoch
        log_update(f"Epoch {epoch}")

        # Training pass
        total_train_loss, total_weighted_train_loss, total_train_masked_tokens = _run_train_epoch(
            model, tokenizer, optimizer, train_loader, mask_percentage, mask_rate_scheduler, device)

        # Save a checkpoint at the end of each epoch (before validation)
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
        model.save_model(checkpoint_path, optimizer=optimizer)
        log_update(f'\nSaved checkpoint to {checkpoint_path}')

        # Calculate and log average training loss (wandb and locally)
        n_train_batches = len(train_loader)
        avg_train_loss, avg_weighted_train_loss, train_perplexity = _summarize_losses(
            total_train_loss, total_weighted_train_loss, total_train_masked_tokens, n_train_batches)
        wandb.log({"epoch": epoch,
                   "total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
                   "avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
                   "train_perplexity": train_perplexity})
        _append_stats_csv(pd.DataFrame(data={
            "epoch": [epoch],
            "total_train_loss": [total_train_loss], "weighted_train_loss": [total_weighted_train_loss],
            "avg_train_loss": [avg_train_loss], "avg_weighted_train_loss": [avg_weighted_train_loss],
            "train_perplexity": [train_perplexity]
        }), f"{checkpoint_dir}/train_curve.csv")

        # Validation pass
        total_val_loss, total_weighted_val_loss, total_val_masked_tokens = _run_val_epoch(
            model, tokenizer, val_loader, device)
        n_val_batches = len(val_loader)
        avg_val_loss, avg_weighted_val_loss, val_perplexity = _summarize_losses(
            total_val_loss, total_weighted_val_loss, total_val_masked_tokens, n_val_batches)
        wandb.log({"epoch": epoch,
                   "total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
                   "avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
                   "val_perplexity": val_perplexity})
        _append_stats_csv(pd.DataFrame(data={
            "epoch": [epoch],
            "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
            "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
            "val_perplexity": [val_perplexity]
        }), f"{checkpoint_dir}/val_curve.csv")

        log_update(f"Epoch: {epoch}")
        log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
        log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")
def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate masked-LM loss and perplexity on the test set and save them to CSV.

    Moves the model to `device`, puts it in eval mode, masks each batch at
    `mask_percentage` (default: the fixed 15% used for validation), and writes the
    aggregated results to `<checkpoint_dir>/test_results.csv`, overwriting any
    previous file.

    Args:
        model: model whose forward returns an object with `.loss`.
        tokenizer: tokenizer providing `mask_token_id`.
        test_loader: iterable yielding `(inputs_dict, prob)` batches.
        mask_percentage (float): masking rate applied to every test batch.
        device: device the model and batches are moved to.
        checkpoint_dir (str): directory where `test_results.csv` is written.

    Averages whose denominator is zero (empty loader / no masked tokens) are reported as NaN.
    """
    model.to(device)
    model.eval()
    total_test_loss = 0
    total_weighted_test_loss = 0
    total_test_masked_tokens = 0
    with torch.no_grad():  # No gradients needed
        # Loop over test data with a progress bar
        with tqdm.tqdm(enumerate(test_loader), total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
            for _, (inputs, prob) in tbar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)
                # Mask based on probability vectors at the requested (default fixed 15%) rate
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)
                # Forward pass
                test_loss = model(**masked_inputs).loss.item()
                # Weight each batch loss by its number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
                total_test_loss += test_loss
                total_weighted_test_loss += test_loss * num_masked_tokens
                total_test_masked_tokens += num_masked_tokens

    # Compute and log avg. loss and perplexity (guarding against empty denominators)
    n_test_batches = len(test_loader)
    avg_test_loss = total_test_loss / n_test_batches if n_test_batches else float('nan')
    avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens if total_test_masked_tokens else float('nan')
    test_perplexity = np.exp(avg_weighted_test_loss)
    log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f}, Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")

    # Save to dataframe for plotting
    test_stats_df = pd.DataFrame(data={
        "total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
        "avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
        "test_perplexity": [test_perplexity]
    })
    test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv", index=False)  # overwrite old file no matter what; should only be one test eval
def check_env_variables():
    """
    Log the environment variables and CUDA devices this run depends on.

    The W&B API key is a secret and this output goes to the training log file, so
    only whether it is set is logged — never its value.
    """
    log_update("\nChecking on environment variables...")
    api_key = os.environ.get('WANDB_API_KEY')
    api_key_status = "None" if api_key is None else "<set, redacted>"
    log_update(f"\tWANDB_API_KEY: {api_key_status}")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    n_devices = torch.cuda.device_count()
    log_update(f"\ttorch.cuda.device_count(): {n_devices}")
    for i in range(n_devices):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
def intialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
    """
    Build the FusOn-pLM model and its AdamW optimizer.

    When finetuning from scratch the model is created with `FusOnpLM()`; otherwise it
    is loaded from `path_to_starting_ckpt` (with its MLM head) and the optimizer state
    is restored from `<path_to_starting_ckpt>/optimizer.pt`. In both cases the model is
    moved to `device` and frozen/unfrozen via `prepare_model`, and the optimizer only
    covers parameters with `requires_grad=True`.

    Args:
        finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt.
        device: device to move the model (and loaded optimizer state) to.
        path_to_starting_ckpt (str): path to starting ckpt for finetuning (required when not from scratch).
        learning_rate (float): learning rate for a fresh optimizer. When resuming, the
            learning rate stored in the loaded optimizer state is used instead.
        n_unfrozen_layers (int), unfreeze_query/key/value (bool): forwarded to `prepare_model`.

    Returns:
        (model, optimizer)

    Raises:
        FileNotFoundError: if resuming and `path_to_starting_ckpt` is None or does not exist.
    """
    resuming = not finetune_from_scratch
    # Check for None explicitly: os.path.exists(None) raises TypeError, not False.
    if resuming and (path_to_starting_ckpt is None or not os.path.exists(path_to_starting_ckpt)):
        raise FileNotFoundError(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")

    if resuming:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
        model = FusOnpLM(ckpt_path=path_to_starting_ckpt, mlm_head=True)
    else:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
        model = FusOnpLM()  # because of __getattr__, we can use FusOnpLM() to get the model. It also contains the tokenizer.
    model.to(device)
    prepare_model(model, n_unfrozen_layers,
                  unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    if not resuming:
        return model, AdamW(trainable_params, lr=learning_rate)

    log_update(f"Loading optimizer state_dict from previous checkpoint")
    # No lr here: load_state_dict restores the hyperparameters saved with the checkpoint.
    optimizer = AdamW(trainable_params)
    optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))
    return model, optimizer
def _parse_resume_checkpoint(path_to_starting_ckpt):
    """Split a '<run dir>/checkpoint_epoch_<N>' path into (run directory, N); tolerates a trailing slash."""
    ckpt_path = os.path.normpath(path_to_starting_ckpt)
    run_dir, ckpt_name = os.path.split(ckpt_path)
    prefix = 'checkpoint_epoch_'
    epoch_str = ckpt_name[len(prefix):]
    if not ckpt_name.startswith(prefix) or not epoch_str.isdigit():
        raise ValueError(f"Expected PATH_TO_STARTING_CKPT to end in '{prefix}<epoch>', got {path_to_starting_ckpt}")
    return run_dir, int(epoch_str)


def main():
    """
    Run the full finetune -> validate -> test pipeline configured by `fuson_plm.training.config`.

    Builds the run name from the config, creates (or, when resuming, reuses) the
    checkpoint directory and log files, initializes wandb, the model/optimizer and the
    dataloaders, trains, and finally evaluates on the test set.

    Raises:
        ValueError: when resuming and PATH_TO_STARTING_CKPT is malformed, or the starting
            run was trained with different settings than the current config.
    """
    # Set probability type to uniform; only option
    config.PROBABILITY_TYPE = "uniform"
    # Set run name (WANDB_NAME)
    kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
    timestamp = get_local_time()
    # make a mask tag _mask{config.MASK_PERCENTAGE}
    mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE:  # if variable masking rate, change the tag to reflect this
        mask_tag = f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"
    # Define the train settings string and wandb name from this
    TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
    WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'

    # Directory for model checkpoints, first epoch number, and log-file mode for a fresh run
    checkpoint_dir = f'checkpoints/{WANDB_NAME}'
    start_epoch = 1
    logmode = 'w'
    resuming = not config.FINETUNE_FROM_SCRATCH
    # If we're finetuning from a checkpoint, save to the same folder instead, append to
    # its logs, and continue numbering epochs after the checkpoint's epoch.
    if resuming:
        logmode = 'a'
        checkpoint_dir, last_epoch = _parse_resume_checkpoint(config.PATH_TO_STARTING_CKPT)
        start_epoch = last_epoch + 1
        # The run directory is named '<settings>-<timestamp>'. The settings string can
        # itself contain '-' (e.g. 'lr5e-05'), so test for the current settings as a
        # prefix instead of cutting at the first '-'.
        run_name = os.path.basename(checkpoint_dir)
        if run_name.startswith(f"{TRAIN_SETTINGS_STRING}-"):
            START_MODEL_TRAIN_SETTINGS_STRING = TRAIN_SETTINGS_STRING
        else:
            START_MODEL_TRAIN_SETTINGS_STRING = run_name
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Open log file
    LOG_PATH = f'{checkpoint_dir}/training_log.txt'
    ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
    with open_logfile(LOG_PATH, mode=logmode), open_errfile(ERR_PATH, mode=logmode):
        if resuming:
            log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
            log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")
            # ONLY proceed if we're using the same settings, otherwise we are not finetuning the
            # model we think we are. Explicit raise (not assert) so it survives `python -O`.
            if START_MODEL_TRAIN_SETTINGS_STRING != TRAIN_SETTINGS_STRING:
                raise ValueError(f"Starting checkpoint settings ({START_MODEL_TRAIN_SETTINGS_STRING}) do not match current config settings ({TRAIN_SETTINGS_STRING})")

        # Print configurations
        print_configpy(config)
        # Verify that the environment variables are set correctly
        check_env_variables()
        # Check CUDA availability and set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"\nUsing device: {device}")

        # Init wandb
        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME, config={
            "batch_size": config.BATCH_SIZE,
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
        })

        # Initialize model and prepare it (freeze/unfreeze proper layers). Initialize optimizer as well.
        # Details depend on whether we are finetuning from scratch.
        model, optimizer = intialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
                                                         path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
                                                         learning_rate=config.LEARNING_RATE,
                                                         n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
                                                         unfreeze_query=config.UNFREEZE_QUERY,
                                                         unfreeze_key=config.UNFREEZE_KEY,
                                                         unfreeze_value=config.UNFREEZE_VALUE)
        # The tokenizer comes from the model wrapper (independent of starting model for finetuning)
        tokenizer = model.tokenizer

        # Create DataLoader instances and perform sanity checks on them
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True)  ## FOR DEBUGGING ONLY, change shuffle to False. Otherwise, True!!
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)

        # Set up a masking rate scheduler, if one is needed
        mask_rate_scheduler = None
        if config.VAR_MASK_RATE:
            mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
                                                          min_masking_rate=config.MASK_LOW,
                                                          max_masking_rate=config.MASK_HIGH,
                                                          total_batches=len(train_loader),
                                                          total_steps=config.MASK_STEPS)

        # Train the model
        train(model, tokenizer, optimizer, train_loader, val_loader,
              n_epochs=config.EPOCHS,
              start_epoch=start_epoch,
              device=device,
              mask_rate_scheduler=mask_rate_scheduler,
              mask_percentage=config.MASK_PERCENTAGE,
              checkpoint_dir=checkpoint_dir)

        # Test the model (fixed 15% masking)
        test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()