'''
Training script for finetuning ESM.
The parameters in the LM head are frozen, and the last N encoder layers are unfrozen.
All run settings are read from fuson_plm.training.config.
'''
import os
import fuson_plm.training.config as config
# Set the WANDB_API_KEY environment variable
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
import torch
import numpy as np
import pandas as pd
import tqdm
from datetime import datetime
import wandb
import pytz
import sys
from transformers import AdamW
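# NOTE: transformers.AdamW is deprecated in recent transformers releases;
# torch.optim.AdamW is a drop-in replacement if this import fails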
from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
from fuson_plm.training.model import FusOnpLM
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
from fuson_plm.training.plot import make_train_val_test_bd_plot
def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
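    """
    Freeze the entire model, then unfreeze the last n_unfrozen_layers encoder layers
    (optionally their query/key/value projections) and log trainable parameter counts.
    Modifies the model in place; returns nothing.
    """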
# Log the model's initial state
n_layers = model.count_encoder_layers()
total_params = sum(p.numel() for p in model.parameters())
total_head_params = sum(p.numel() for p in model.lm_head.parameters())
log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelForMaskedLM model: {total_params}')
log_update(f'\tTotal parameters in the MLM Head ONLY: {total_head_params}')
# Freeze the model to start
model.freeze_model()
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_update(f'Froze all {n_layers} model layers')
log_update(f'\tTrainable params: {n_trainable_params}')
# Unfreeze the last n layers
model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
trainable_params = '\n\t\t'.join([name for name, param in model.named_parameters() if param.requires_grad])
num_trainable_params_lm_head = sum(p.numel() for p in model.lm_head.parameters() if p.requires_grad)
num_trainable_params_esm = sum(p.numel() for p in model.esm.parameters() if p.requires_grad)
log_update(f'Unfroze final {n_unfrozen_layers} layers')
log_update(f'\tTrainable params: {n_trainable_params}\n\t\t{trainable_params}')
log_update(f"\tTrainable parameters in the lm_head: {num_trainable_params_lm_head}")
log_update(f"\tTrainable params in the ESM part: {num_trainable_params_esm}")
def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
"""
Train the model
"""
# Loop over epochs
log_update("\n")
for epoch in range(start_epoch, start_epoch+n_epochs):
if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset() # resetting because we want to ramp it up again every epoch
model.train()
total_train_loss = 0
total_weighted_train_loss = 0
total_train_masked_tokens = 0
log_update(f"Epoch {epoch}")
# Loop over train data with progress bar
with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
for batch_idx, (inputs, prob) in pbar:
# Take a step with the mask rate scheduler, if there is one.
masking_rate = mask_percentage
if mask_rate_scheduler is not None:
mask_rate_scheduler.step()
masking_rate = mask_rate_scheduler.get_masking_rate()
log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")
# Move tensors
inputs = {k: v.to(device) for k, v in inputs.items()}
prob = prob.to(device)
# Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)
# Forward pass and update
optimizer.zero_grad()
outputs = model(**masked_inputs)
loss = outputs.loss
loss.backward()
optimizer.step()
# Number of masked tokens
num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
# Loss calculations and wandb log
total_train_loss += loss.item()
total_weighted_train_loss += loss.item() * num_masked_tokens # Multiply loss by number of masked tokens
total_train_masked_tokens += num_masked_tokens
wandb.log({"batch_loss": loss.item()})
# Save a checkpoint at the end of each epoch
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
model.save_model(checkpoint_path, optimizer=optimizer)
log_update(f'\nSaved checkpoint to {checkpoint_path}')
# Calculate and log average training loss on wandb
n_train_batches = len(train_loader)
avg_train_loss = total_train_loss / n_train_batches
avg_weighted_train_loss = total_weighted_train_loss / total_train_masked_tokens
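        # Perplexity = exp(average cross-entropy per masked token); the token-weighted
        # average is used so batches with more masked tokens count proportionally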
train_perplexity = np.exp(avg_weighted_train_loss)
wandb.log({"epoch": epoch,
"total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
"avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
"train_perplexity": train_perplexity})
# Track curve stats for easy re-plotting of training curves later
train_stats_df = pd.DataFrame(data={
"epoch": [epoch],
"total_train_loss": [total_train_loss], "weighted_train_loss": [total_weighted_train_loss],
"avg_train_loss": [avg_train_loss], "avg_weighted_train_loss": [avg_weighted_train_loss],
"train_perplexity": [train_perplexity]
})
if os.path.exists(f"{checkpoint_dir}/train_curve.csv"): # add to file if necessary
train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv",index=False,header=False,mode='a')
else: # make new file if necessary
train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv",index=False)
# Validation loop
model.eval()
total_val_loss = 0
total_weighted_val_loss = 0
total_val_masked_tokens = 0
with torch.no_grad(): # No gradients needed
# Loop over val data with progress bar
with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
for batch_idx, (inputs, prob) in vbar:
# Move tensors
inputs = {k: v.to(device) for k, v in inputs.items()}
prob = prob.to(device)
# Mask based on probability vectors
## FIXED 15% masking for the validation set
                    masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)
# Forward pass
outputs = model(**masked_inputs)
val_loss = outputs.loss
# Number of masked tokens
num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
# Loss calculations
total_val_loss += val_loss.item()
total_weighted_val_loss += val_loss.item() * num_masked_tokens # Multiply loss by number of masked tokens
total_val_masked_tokens += num_masked_tokens
# Calculate and log avg. loss and perplexity (wandb and locally)
n_val_batches = len(val_loader)
avg_val_loss = total_val_loss / n_val_batches # avg per batch
avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens # avg per masked token
val_perplexity = np.exp(avg_weighted_val_loss)
wandb.log({"epoch": epoch,
"total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
"avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
"val_perplexity": val_perplexity})
# Track curve stats for easy re-plotting of training curves later
val_stats_df = pd.DataFrame(data={
"epoch": [epoch],
"total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
"avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
"val_perplexity": [val_perplexity]
})
if os.path.exists(f"{checkpoint_dir}/val_curve.csv"): # add to file if necessary
val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv",index=False,header=False,mode='a')
else: # make new file if necessary
val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv",index=False)
log_update(f"Epoch: {epoch}")
log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")
def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
"""
"""
model.to(device)
model.eval()
total_test_loss = 0
total_weighted_test_loss = 0
total_test_masked_tokens = 0
with torch.no_grad(): # No gradients needed
# Loop over test data (no progress bar)
with tqdm.tqdm(enumerate(test_loader), total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
for batch_idx, (inputs, prob) in tbar:
# Move tensors
inputs = {k: v.to(device) for k, v in inputs.items()}
prob = prob.to(device)
# Mask based on probability vectors
                # Fixed masking rate (mask_percentage, 15% by default) for the test set
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)
# Forward pass
outputs = model(**masked_inputs)
test_loss = outputs.loss
# Number of masked tokens
num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
# Loss calculations
total_test_loss += test_loss.item()
total_weighted_test_loss += test_loss.item() * num_masked_tokens # Multiply loss by number of masked tokens
total_test_masked_tokens += num_masked_tokens
# Compute and log avg. loss and perplexity
n_test_batches = len(test_loader)
avg_test_loss = total_test_loss / n_test_batches
avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens
test_perplexity = np.exp(avg_weighted_test_loss)
log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f}, Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")
# Save to dataframe for plotting
test_stats_df = pd.DataFrame(data={
"total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
"avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
"test_perplexity": [test_perplexity]
})
test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv",index=False) # overwrite old file no matter what; should only be one test eval
def check_env_variables():
log_update("\nChecking on environment variables...")
log_update(f"\tWANDB_API_KEY: {os.environ.get('WANDB_API_KEY')}")
log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
def initialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
"""
Initializes the model, either from ESM-2-650M if finetuning from scratch, or from a prior checkpoint if not finetuning from scratch.
Also prepares
Args:
finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt
path_to_starting_ckpt (str): path to starting ckpt for finetuning (optional)
"""
    if not finetune_from_scratch and (path_to_starting_ckpt is None or not os.path.exists(path_to_starting_ckpt)):
        raise ValueError(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")
# if finetuning from scratch, initialize from scratch
if finetune_from_scratch:
log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
model = FusOnpLM() # because of __getattr__, we can use FusOnpLM() to get the model. It also contains the tokenizer.
model.to(device)
prepare_model(model, n_unfrozen_layers,
unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
# Set the optimizer here, change it if we are finetuning from an old checkpoint
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
return model, optimizer
# if not, initialize from starting ckpt
else:
log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
model = FusOnpLM(ckpt_path = path_to_starting_ckpt, mlm_head=True)
model.to(device)
prepare_model(model, n_unfrozen_layers,
unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
log_update(f"Loading optimizer state_dict from previous checkpoint")
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()))
optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))
return model, optimizer
def main():
# Set probability type to uniform; only option
config.PROBABILITY_TYPE = "uniform"
# Set run name (WANDB_NAME)
kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
timestamp = get_local_time()
# make a mask tag _mask{config.MASK_PERCENTAGE}
mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE: # if using a variable masking rate, change the tag to reflect this
        mask_tag = f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"
# Define the train settings string and wandb name from this
TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'
# Create directory for model checkpoints
checkpoint_dir = f'checkpoints/{WANDB_NAME}'
start_epoch = 1
# Determine if we're adding to an old log file or opening a new one
logmode='w'
# If we're finetuning from a checkpoint, save to the same folder instead, and keep track of which epoch to start on
# Also, load the optimizer from here
if not(config.FINETUNE_FROM_SCRATCH):
logmode='a'
path_to_starting_ckpt = config.PATH_TO_STARTING_CKPT
        checkpoint_dir = os.path.dirname(path_to_starting_ckpt)
        START_MODEL_TRAIN_SETTINGS_STRING = checkpoint_dir[checkpoint_dir.index('checkpoints/')+len('checkpoints/'):checkpoint_dir.index('-')]
        # The checkpoint folder name encodes its epoch, e.g. '.../checkpoint_epoch_7' -> resume at epoch 8
        start_epoch = int(path_to_starting_ckpt.split('/checkpoint_epoch_')[1])+1
    os.makedirs('checkpoints', exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
# Open log file
LOG_PATH = f'{checkpoint_dir}/training_log.txt'
ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
with open_logfile(LOG_PATH,mode=logmode), open_errfile(ERR_PATH,mode=logmode):
if not(config.FINETUNE_FROM_SCRATCH):
log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")
# ONLY proceed with training if we're using the same settings, otherwise we are not finetuning the model we think we are!
assert START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING
# Print configurations
print_configpy(config)
# Verify that the environment variables are set correctly
check_env_variables()
# Check CUDA availability and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log_update(f"\nUsing device: {device}")
# Init wandb
        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME, config={
"batch_size": config.BATCH_SIZE,
"epochs": config.EPOCHS,
"learning_rate": config.LEARNING_RATE,
})
# Initialize model and prepare it (freeze/unfreeze proper layers). Initialize optimizer as well. Details depend on whether we are finetuning from scratch.
        model, optimizer = initialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
learning_rate=config.LEARNING_RATE,
n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
unfreeze_query=config.UNFREEZE_QUERY,
unfreeze_key=config.UNFREEZE_KEY,
unfreeze_value=config.UNFREEZE_VALUE)
# Initialize the tokenizer (independent of starting model for finetuning)
tokenizer = model.tokenizer
# Create DataLoader instances and perform sanity checks on them
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True) # shuffle=True for training; set to False only when debugging
val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        # Sanity-check the dataloaders. Note: if we're continuing to finetune an old ckpt, this overwrites the old batch diversity plot
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)
# Set up a masking rate scheduler, if one is needed
mask_rate_scheduler = None
if config.VAR_MASK_RATE:
mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
min_masking_rate=config.MASK_LOW,
max_masking_rate=config.MASK_HIGH,
total_batches=len(train_loader),
total_steps=config.MASK_STEPS)
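        # The mask rate scheduler is stepped once per training batch and reset at the start of each epoch (see train())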
# Train the model
train(model, tokenizer, optimizer, train_loader, val_loader,
n_epochs=config.EPOCHS,
start_epoch = start_epoch,
device=device,
mask_rate_scheduler=mask_rate_scheduler,
mask_percentage=config.MASK_PERCENTAGE,
checkpoint_dir=checkpoint_dir)
# Test the model
test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)
if __name__ == "__main__":
    main()