File size: 5,621 Bytes
9a73cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
### Run ESM2 on the validation and test set. Get val and test losses.
import os
import fuson_plm.training.config as config
# Set the WANDB_API_KEY and CUDA_VISIBLE_DEVICES environment variables from config (done before torch is imported below)
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
import torch
import tqdm
import numpy as np
import pandas as pd
import logging
from transformers import AutoModelForMaskedLM, AutoTokenizer
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders
from fuson_plm.training.train import test
def load_esm2_maskedlm(esm_type, device=None):
    """
    Load a pretrained ESM-2 masked-language model together with its tokenizer.

    Args:
        esm_type: name of the ESM-2 checkpoint under the "facebook/" namespace
            (e.g. esm2_t33_650M_UR50D).
        device: device to place the model on. When None, CUDA is picked if
            torch reports it as available, otherwise CPU.

    Returns:
        Tuple (model, tokenizer, device); the model has been moved to `device`
        and switched to eval mode.
    """
    # The 'esm.pooler.dense.bias' / 'esm.pooler.dense.weight' layers are reported as newly
    # initialized on load; they are not used here, so silence those warnings.
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    transformers_logger.setLevel(logging.ERROR)

    if device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(backend)
    print(f"Using device: {device}")

    # Model and tokenizer are fetched from the same checkpoint identifier
    hub_id = f"facebook/{esm_type}"
    model = AutoModelForMaskedLM.from_pretrained(hub_id)
    tokenizer = AutoTokenizer.from_pretrained(hub_id)

    model.to(device)
    model.eval()  # eval mode: disables dropout for deterministic results

    return model, tokenizer, device
def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate a masked-language model on the validation set.

    Intended as the validation-set counterpart of `test` (imported from
    fuson_plm.training.train, not shown in this file). Each batch is masked,
    passed through the model, and the losses are accumulated. Summary
    statistics are logged with `log_update` and written to
    `{checkpoint_dir}/val_results.csv`.

    Args:
        model: masked-language model; called as `model(**masked_inputs)` and
            expected to return an object with a `.loss` attribute.
        tokenizer: tokenizer providing `mask_token_id`; also passed to the masking helper.
        val_loader: sized iterable yielding `(inputs, prob)` pairs, where `inputs`
            is a dict of tensors and `prob` is a tensor.
        mask_percentage: forwarded to `batch_sample_mask_tokens_with_probabilities`.
        device: device the model and each batch are moved to.
        checkpoint_dir: directory receiving `val_results.csv`; created if missing.

    Returns:
        None. Results are logged and saved to CSV.
    """
    model.to(device)
    model.eval()

    total_val_loss = 0
    total_weighted_val_loss = 0
    total_val_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        # Loop over val data (with a progress bar)
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

                # Forward pass
                outputs = model(**masked_inputs)
                val_loss = outputs.loss

                # Number of masked tokens
                # NOTE(review): this counts positions equal to the mask token id; it assumes the
                # masking helper marks every selected position with the mask token - confirm.
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations
                batch_loss = val_loss.item()
                total_val_loss += batch_loss
                total_weighted_val_loss += batch_loss * num_masked_tokens  # Multiply loss by number of masked tokens
                total_val_masked_tokens += num_masked_tokens

    # Compute and log avg. loss and perplexity.
    # An empty loader or a run with no masked tokens would otherwise raise ZeroDivisionError;
    # report NaN in that case so the problem is visible in the log and the CSV.
    n_val_batches = len(val_loader)
    avg_val_loss = total_val_loss / n_val_batches if n_val_batches else float("nan")
    if total_val_masked_tokens:
        avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens
    else:
        avg_weighted_val_loss = float("nan")
    val_perplexity = np.exp(avg_weighted_val_loss)

    log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

    # Save to dataframe for plotting
    val_stats_df = pd.DataFrame(data={
        "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
        "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
        "val_perplexity": [val_perplexity]
    })

    # Make sure the output directory exists (the default './checkpoints' may not)
    os.makedirs(checkpoint_dir, exist_ok=True)
    val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv", index=False)  # overwrite old file no matter what; should only be one val eval
def main():
    """
    Evaluate the pretrained ESM-2 model (esm2_t33_650M_UR50D) on the validation and test sets.

    Builds a results directory named after the probability type and mask
    percentage from config, opens a log file inside it, prints the config,
    then runs `val` on config.VAL_PATH and `test` on config.TEST_PATH.
    """
    # Load the ESM-2 model
    model, tokenizer, device = load_esm2_maskedlm("esm2_t33_650M_UR50D")

    # Directory for the log file and result files
    checkpoint_dir = f"checkpoints/esm2_t33_650M_UR50D_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    log_path = f"{checkpoint_dir}/evaluate_val_test_esm.txt"
    with open_logfile(log_path):
        # Print configurations
        print_configpy(config)

        # Both dataloaders are built with the same settings; only the data path differs
        loader_settings = {
            "probability_type": config.PROBABILITY_TYPE,
            "batch_size": config.BATCH_SIZE,
            "max_length": config.MAX_LENGTH,
            "shuffle": False,
        }

        ##### Validation
        # Create dataloader
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, **loader_settings)
        # Evaluate on the validation set
        val(model, tokenizer, val_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)

        ##### Test
        # Create dataloader
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, **loader_settings)
        # Test the model
        test(model, tokenizer, test_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)
# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()