|
''' |
|
This is a training script for finetuning ESM. |
|
I am going to freeze the parameters in the head and unfreeze the last N layers in the model. |
|
''' |
|
|
|
import os |
|
import fuson_plm.training.config as config |
|
|
|
|
|
# Export run settings from config.py into the process environment. This happens
# before torch and wandb are imported below, presumably so both libraries see
# these values when they first read them.
os.environ.update(
    WANDB_API_KEY=config.WANDB_API_KEY,
    CUDA_VISIBLE_DEVICES=config.CUDA_VISIBLE_DEVICES,
)
|
|
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import tqdm |
|
from datetime import datetime |
|
import wandb |
|
import pytz |
|
import sys |
|
|
|
from transformers import AdamW |
|
|
|
from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update |
|
from fuson_plm.training.model import FusOnpLM |
|
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler |
|
from fuson_plm.training.plot import make_train_val_test_bd_plot |
|
|
|
def _count_params(module, trainable_only=False):
    """Return the number of scalar parameters in `module` (only those with requires_grad if `trainable_only`)."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad or not trainable_only)


def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    """
    Freeze the whole model, then unfreeze its last `n_unfrozen_layers` encoder layers.

    Parameter counts are logged before freezing, after freezing, and after
    unfreezing (including the names of every trainable parameter), so the log
    records exactly which weights finetuning will update. The model is modified
    in place; nothing is returned.

    Args:
        model: FusOnpLM wrapper. Must expose count_encoder_layers(), freeze_model(),
            unfreeze_last_n_layers(), named_parameters(), and `.lm_head` / `.esm` submodules.
        n_unfrozen_layers (int): number of final encoder layers to unfreeze.
        unfreeze_query, unfreeze_key, unfreeze_value (bool): forwarded unchanged to
            model.unfreeze_last_n_layers; their exact effect is defined there.
    """
    # Initial state, before any freezing.
    n_layers = model.count_encoder_layers()
    log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelforMaskedLM model: {_count_params(model)}')
    log_update(f'\tTotal parameters in the MLM Head ONLY: {_count_params(model.lm_head)}')

    # Freeze everything. Report the layer count computed above rather than a
    # model attribute this function never established.
    model.freeze_model()
    log_update(f'Froze all {n_layers} model layers')
    log_update(f'\tTrainable params: {_count_params(model, trainable_only=True)}')

    # Selectively unfreeze the last N encoder layers.
    model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    trainable_params = '\n\t\t'.join(name for name, param in model.named_parameters() if param.requires_grad)
    log_update(f'Unfroze final {n_unfrozen_layers} layers')
    log_update(f'\tTrainable params: {_count_params(model, trainable_only=True)}\n\t\t{trainable_params}')
    log_update(f"\tTrainable parameters in the lm_head: {_count_params(model.lm_head, trainable_only=True)}")
    log_update(f"\tTrainable params in the ESM part: {_count_params(model.esm, trainable_only=True)}")
|
|
|
def _append_stats_csv(stats, csv_path):
    """
    Append one row of stats to `csv_path`.

    The header is written only when the file does not exist yet, so the same CSV
    keeps growing across epochs and across resumed runs.

    Args:
        stats (dict): column name -> scalar value; column order follows dict order.
        csv_path (str): destination CSV file.
    """
    file_exists = os.path.exists(csv_path)
    stats_df = pd.DataFrame(data={column: [value] for column, value in stats.items()})
    stats_df.to_csv(csv_path, index=False, header=not file_exists, mode='a' if file_exists else 'w')


def _safe_div(total, count):
    """Return total / count, or NaN when count is 0 (empty loader, or no masked tokens)."""
    return total / count if count else float('nan')


def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints', val_mask_percentage=0.15):
    """
    Train the model for `n_epochs` epochs, checkpointing and validating after each one.

    Each epoch:
      1. One pass over `train_loader`. Tokens are masked with
         batch_sample_mask_tokens_with_probabilities at `mask_percentage`, or at the
         rate reported by `mask_rate_scheduler`. The scheduler is reset at the start
         of every epoch and stepped once per batch.
      2. Save a checkpoint via model.save_model(<checkpoint_dir>/checkpoint_epoch_<epoch>,
         optimizer=optimizer).
      3. Log training stats to W&B and append them to <checkpoint_dir>/train_curve.csv.
      4. Evaluate on `val_loader` (no gradients) at `val_mask_percentage`. Log the stats
         to W&B and append them to <checkpoint_dir>/val_curve.csv.

    "Weighted" losses weight each batch's loss by its number of masked tokens, so the
    averaged value is per masked token. Perplexity is exp() of that average. When no
    tokens were masked (or a loader is empty) the affected averages are NaN instead
    of raising ZeroDivisionError.

    Args:
        model: FusOnpLM model (already on `device`); called as model(**masked_inputs)
            and expected to return an object with a `.loss`.
        tokenizer: tokenizer whose `mask_token_id` identifies masked positions.
        optimizer: torch optimizer over the trainable parameters.
        train_loader, val_loader: loaders yielding (inputs dict of tensors, prob tensor).
        n_epochs (int): number of epochs to run.
        start_epoch (int): number of the first epoch (used for naming when resuming).
        mask_percentage (float): training masking rate when no scheduler is given.
        mask_rate_scheduler: optional scheduler with reset()/step()/get_masking_rate().
        device: device the batches are moved to.
        checkpoint_dir (str): directory for checkpoints and curve CSVs (must exist).
        val_mask_percentage (float): validation masking rate. Kept fixed (default 0.15,
            the value previously hard-coded) so validation losses stay comparable
            across epochs even when the training rate is scheduled.
    """
    log_update("\n")

    for epoch in range(start_epoch, start_epoch + n_epochs):
        # ------------------------------ Training ------------------------------
        if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset()

        model.train()
        total_train_loss = 0
        total_weighted_train_loss = 0
        total_train_masked_tokens = 0

        log_update(f"Epoch {epoch}")

        with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
            for batch_idx, (inputs, prob) in pbar:
                masking_rate = mask_percentage
                if mask_rate_scheduler is not None:
                    mask_rate_scheduler.step()
                    masking_rate = mask_rate_scheduler.get_masking_rate()
                    log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")

                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)

                optimizer.zero_grad()
                outputs = model(**masked_inputs)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                # NOTE(review): only positions replaced by the mask token are counted; if the
                # masking helper also uses random/unchanged replacements, this undercounts the
                # positions the loss covers. Confirm against batch_sample_mask_tokens_with_probabilities.
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                batch_loss = loss.item()
                total_train_loss += batch_loss
                total_weighted_train_loss += batch_loss * num_masked_tokens
                total_train_masked_tokens += num_masked_tokens
                wandb.log({"batch_loss": batch_loss})

        # Checkpoint before validation so a failure there does not lose the epoch.
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
        model.save_model(checkpoint_path, optimizer=optimizer)
        log_update(f'\nSaved checkpoint to {checkpoint_path}')

        n_train_batches = len(train_loader)
        avg_train_loss = _safe_div(total_train_loss, n_train_batches)
        avg_weighted_train_loss = _safe_div(total_weighted_train_loss, total_train_masked_tokens)
        train_perplexity = np.exp(avg_weighted_train_loss)
        train_stats = {
            "epoch": epoch,
            "total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
            "avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
            "train_perplexity": train_perplexity,
        }
        wandb.log(dict(train_stats))
        _append_stats_csv(train_stats, os.path.join(checkpoint_dir, "train_curve.csv"))

        # ----------------------------- Validation -----------------------------
        model.eval()
        total_val_loss = 0
        total_weighted_val_loss = 0
        total_val_masked_tokens = 0

        with torch.no_grad(), tqdm.tqdm(val_loader, total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
            for inputs, prob in vbar:
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=val_mask_percentage)

                val_loss = model(**masked_inputs).loss.item()
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                total_val_loss += val_loss
                total_weighted_val_loss += val_loss * num_masked_tokens
                total_val_masked_tokens += num_masked_tokens

        n_val_batches = len(val_loader)
        avg_val_loss = _safe_div(total_val_loss, n_val_batches)
        avg_weighted_val_loss = _safe_div(total_weighted_val_loss, total_val_masked_tokens)
        val_perplexity = np.exp(avg_weighted_val_loss)
        val_stats = {
            "epoch": epoch,
            "total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
            "avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
            "val_perplexity": val_perplexity,
        }
        wandb.log(dict(val_stats))
        _append_stats_csv(val_stats, os.path.join(checkpoint_dir, "val_curve.csv"))

        log_update(f"Epoch: {epoch}")
        log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
        log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")
|
|
|
def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate the model on `test_loader` and save a summary of the results.

    Each batch is masked with batch_sample_mask_tokens_with_probabilities at
    `mask_percentage`, and the MLM loss is accumulated two ways:
      - per batch;
      - weighted by the batch's number of masked tokens. Perplexity is exp() of
        the token-weighted average.
    When no tokens were masked (or the loader is empty), the affected averages
    are NaN instead of raising ZeroDivisionError. The summary is logged and
    written to <checkpoint_dir>/test_results.csv, overwriting any previous file.

    Args:
        model: FusOnpLM model; moved to `device` and put in eval mode here.
        tokenizer: tokenizer whose `mask_token_id` identifies masked positions.
        test_loader: loader yielding (inputs dict of tensors, prob tensor).
        mask_percentage (float): masking rate for the test batches.
        device: device the model and batches are moved to.
        checkpoint_dir (str): existing directory that receives test_results.csv.
    """
    model.to(device)
    model.eval()
    total_test_loss = 0
    total_weighted_test_loss = 0
    total_test_masked_tokens = 0

    with torch.no_grad(), tqdm.tqdm(test_loader, total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
        for inputs, prob in tbar:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            prob = prob.to(device)

            # Honour the caller's masking rate (it used to be ignored in favour of a hard-coded 0.15).
            masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

            test_loss = model(**masked_inputs).loss.item()
            num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

            total_test_loss += test_loss
            total_weighted_test_loss += test_loss * num_masked_tokens
            total_test_masked_tokens += num_masked_tokens

    n_test_batches = len(test_loader)
    avg_test_loss = total_test_loss / n_test_batches if n_test_batches else float('nan')
    avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens if total_test_masked_tokens else float('nan')
    test_perplexity = np.exp(avg_weighted_test_loss)

    log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f}, Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")

    test_stats_df = pd.DataFrame(data={
        "total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
        "avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
        "test_perplexity": [test_perplexity],
    })
    test_stats_df.to_csv(os.path.join(checkpoint_dir, "test_results.csv"), index=False)
|
|
|
def check_env_variables():
    """
    Log the environment variables and CUDA devices this run depends on.

    WANDB_API_KEY is a credential, and main() routes log_update output to
    training_log.txt, so the key is redacted. Only whether it is set is logged,
    plus its last 4 characters for keys longer than 8 characters.
    """
    log_update("\nChecking on environment variables...")

    api_key = os.environ.get('WANDB_API_KEY')
    if not api_key:
        shown_key = api_key  # None or empty string: nothing secret to hide
    elif len(api_key) > 8:
        shown_key = f"<set, ends with ...{api_key[-4:]}>"
    else:
        shown_key = "<set>"
    log_update(f"\tWANDB_API_KEY: {shown_key}")

    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

    n_devices = torch.cuda.device_count()
    log_update(f"\ttorch.cuda.device_count(): {n_devices}")
    for i in range(n_devices):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
|
|
|
def intialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
    """
    Initialize the FusOn-pLM model and its AdamW optimizer.

    Model construction depends on the mode:
      - From scratch: the model is built with FusOnpLM() defaults (ESM-2-650M).
      - Resuming: the model is loaded from `path_to_starting_ckpt` with its MLM
        head, and the optimizer state is restored from
        <path_to_starting_ckpt>/optimizer.pt.
    In both cases the model is then moved to `device` and passed through
    prepare_model: all weights are frozen, then the last `n_unfrozen_layers`
    encoder layers are unfrozen. The optimizer only receives parameters that
    are still trainable after that.

    Args:
        finetune_from_scratch (bool): True to start from the base model, False to
            resume from a previous checkpoint.
        device: device the model is moved to; also the map_location for the
            optimizer state.
        path_to_starting_ckpt (str | None): checkpoint directory to resume from.
            Required when finetune_from_scratch is False.
        learning_rate (float): AdamW learning rate. When resuming, the learning
            rate stored in the loaded optimizer state replaces it (load_state_dict
            restores param-group hyperparameters).
        n_unfrozen_layers, unfreeze_query, unfreeze_key, unfreeze_value: forwarded
            to prepare_model.

    Returns:
        tuple: (model, optimizer)

    Raises:
        FileNotFoundError: when resuming and `path_to_starting_ckpt` is None or
            does not exist. This is a subclass of Exception, so existing
            `except Exception` handlers still catch it.
    """
    resuming = not finetune_from_scratch
    # Check for None explicitly: os.path.exists(None) raises a TypeError that hides the real problem.
    if resuming and (path_to_starting_ckpt is None or not os.path.exists(path_to_starting_ckpt)):
        raise FileNotFoundError(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")

    if resuming:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
        model = FusOnpLM(ckpt_path=path_to_starting_ckpt, mlm_head=True)
    else:
        log_update("\nInitializing FusOn-pLM model to be finetuned from scratch")
        model = FusOnpLM()

    model.to(device)
    prepare_model(model, n_unfrozen_layers,
                  unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

    # Only hand the optimizer the parameters prepare_model left trainable.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

    if resuming:
        log_update("Loading optimizer state_dict from previous checkpoint")
        # torch.load unpickles the file: only point this at checkpoints you produced yourself.
        optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))

    return model, optimizer
|
|
|
def _parse_resume_checkpoint(path_to_starting_ckpt, timestamp):
    """
    Split a resume path of the form .../<settings>-<timestamp>/checkpoint_epoch_<k>.

    Returns:
        tuple: (checkpoint_dir, settings_string, next_epoch), where next_epoch is k + 1.

    The settings string is recovered by stripping the trailing '-<timestamp>' from
    the run directory name. The split is done from the right because the settings
    string can itself contain '-' (e.g. "lr5e-05"). The number of pieces to strip
    comes from a timestamp produced by the same get_local_time() call; this assumes
    its format always contains the same number of '-'.
    """
    ckpt = os.path.normpath(path_to_starting_ckpt)  # also tolerates a trailing '/'
    checkpoint_dir = os.path.dirname(ckpt)
    run_name = os.path.basename(checkpoint_dir)
    settings_string = run_name.rsplit('-', str(timestamp).count('-') + 1)[0]
    next_epoch = int(os.path.basename(ckpt).split('checkpoint_epoch_')[1]) + 1
    return checkpoint_dir, settings_string, next_epoch


def main():
    """
    Configure a finetuning run from fuson_plm.training.config, then train and test.

    When config.FINETUNE_FROM_SCRATCH is False, the run resumes inside the
    original run directory (the parent of config.PATH_TO_STARTING_CKPT). In that
    case logs are appended, and the settings derived from the current config
    must match the settings encoded in that directory's name.
    """
    # This script always uses uniform masking probabilities, overriding config.py.
    config.PROBABILITY_TYPE = "uniform"

    # Build the run name: settings tag + timestamp.
    kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
    timestamp = get_local_time()

    mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE:
        mask_tag = f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"

    TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
    WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'

    checkpoint_dir = f'checkpoints/{WANDB_NAME}'
    start_epoch = 1
    logmode = 'w'

    # When resuming, reuse the original run directory and append to its logs.
    if not config.FINETUNE_FROM_SCRATCH:
        logmode = 'a'
        checkpoint_dir, START_MODEL_TRAIN_SETTINGS_STRING, start_epoch = _parse_resume_checkpoint(config.PATH_TO_STARTING_CKPT, timestamp)

    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    LOG_PATH = f'{checkpoint_dir}/training_log.txt'
    ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
    with open_logfile(LOG_PATH, mode=logmode), open_errfile(ERR_PATH, mode=logmode):
        if not config.FINETUNE_FROM_SCRATCH:
            log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
            log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if START_MODEL_TRAIN_SETTINGS_STRING != TRAIN_SETTINGS_STRING:
                raise ValueError(f"Config settings ({TRAIN_SETTINGS_STRING}) do not match the settings of the checkpoint being resumed ({START_MODEL_TRAIN_SETTINGS_STRING})")

        print_configpy(config)
        check_env_variables()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"\nUsing device: {device}")

        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME, config={
            "batch_size": config.BATCH_SIZE,
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
        })

        model, optimizer = intialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
                                                         path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
                                                         learning_rate=config.LEARNING_RATE,
                                                         n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
                                                         unfreeze_query=config.UNFREEZE_QUERY,
                                                         unfreeze_key=config.UNFREEZE_KEY,
                                                         unfreeze_value=config.UNFREEZE_VALUE)
        tokenizer = model.tokenizer

        # Data: only the training set is shuffled.
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True)
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)

        # Optional variable masking-rate schedule (training only).
        mask_rate_scheduler = None
        if config.VAR_MASK_RATE:
            mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
                                                          min_masking_rate=config.MASK_LOW,
                                                          max_masking_rate=config.MASK_HIGH,
                                                          total_batches=len(train_loader),
                                                          total_steps=config.MASK_STEPS)

        train(model, tokenizer, optimizer, train_loader, val_loader,
              n_epochs=config.EPOCHS,
              start_epoch=start_epoch,
              device=device,
              mask_rate_scheduler=mask_rate_scheduler,
              mask_percentage=config.MASK_PERCENTAGE,
              checkpoint_dir=checkpoint_dir)

        test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)
|
|
|
if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    main()