### Run ESM2 on the validation and test set. Get val and test losses.
import os

import fuson_plm.training.config as config

# These env vars must be set BEFORE torch initializes CUDA so that the
# configured GPU visibility (CUDA_VISIBLE_DEVICES) takes effect.
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import logging

import numpy as np
import pandas as pd
import torch
import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer

from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders
# NOTE(review): the import below re-binds get_dataloader and check_dataloaders,
# shadowing the caid.utils versions imported just above — the training.utils
# versions are the ones actually in effect everywhere in this script. Confirm
# the caid.utils import is intentional; it is kept here to avoid breaking
# anything that may rely on its side effects.
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders
from fuson_plm.training.train import test


def load_esm2_maskedlm(esm_type, device=None):
    """
    Load a pretrained ESM-2 masked-LM of a specified version (e.g. esm2_t33_650M_UR50D).

    Args:
        esm_type (str): HuggingFace model suffix; resolved as "facebook/{esm_type}".
        device (torch.device, optional): target device. Defaults to CUDA when
            available, else CPU.

    Returns:
        tuple: (model, tokenizer, device) with the model moved to `device` and
        put in eval mode.
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias',
    # 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = AutoModelForMaskedLM.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")
    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device


def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate masked-LM loss on the validation set (same procedure as `test`,
    just run on the validation loader) and write a one-row results CSV to
    {checkpoint_dir}/val_results.csv.

    Args:
        model: masked-LM model; moved to `device` and set to eval mode here.
        tokenizer: matching tokenizer; its mask_token_id is used to count
            masked tokens per batch.
        val_loader: DataLoader yielding (inputs, prob) pairs, where `inputs`
            is a dict of tensors and `prob` holds per-token masking
            probabilities.
        mask_percentage (float): fraction of tokens to mask per sequence.
        device: device to run evaluation on.
        checkpoint_dir (str): output directory for val_results.csv (must exist).
    """
    model.to(device)
    model.eval()

    total_val_loss = 0
    total_weighted_val_loss = 0
    total_val_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        # Loop over val data with a progress bar
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors to the evaluation device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(
                    inputs, prob, tokenizer, mask_percentage=mask_percentage
                )

                # Forward pass
                outputs = model(**masked_inputs)
                val_loss = outputs.loss

                # Number of masked tokens in this batch
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Accumulate the plain batch loss and the masked-token-weighted
                # loss (loss multiplied by number of masked tokens).
                total_val_loss += val_loss.item()
                total_weighted_val_loss += val_loss.item() * num_masked_tokens
                total_val_masked_tokens += num_masked_tokens

    # Compute and log avg. loss and perplexity. Guard the divisions: an empty
    # loader or zero masked tokens previously raised ZeroDivisionError.
    n_val_batches = len(val_loader)
    avg_val_loss = total_val_loss / n_val_batches if n_val_batches else float('nan')
    avg_weighted_val_loss = (
        total_weighted_val_loss / total_val_masked_tokens if total_val_masked_tokens else float('nan')
    )
    val_perplexity = np.exp(avg_weighted_val_loss)

    log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

    # Save to dataframe for plotting
    val_stats_df = pd.DataFrame(data={
        "total_val_loss": [total_val_loss],
        "weighted_val_loss": [total_weighted_val_loss],
        "avg_val_loss": [avg_val_loss],
        "avg_weighted_val_loss": [avg_weighted_val_loss],
        "val_perplexity": [val_perplexity]
    })
    # Overwrite old file no matter what; should only be one val eval
    val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv", index=False)


def main():
    """Evaluate pretrained ESM-2 on the validation and test sets; log results
    and write per-split CSVs under the run's checkpoint directory."""
    # Load the ESM-2 model
    model, tokenizer, device = load_esm2_maskedlm("esm2_t33_650M_UR50D")

    checkpoint_dir = f"checkpoints/esm2_t33_650M_UR50D_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open_logfile(f"{checkpoint_dir}/evaluate_val_test_esm.txt"):
        # Print configurations
        print_configpy(config)

        ##### Validation
        # Create dataloader (no shuffling so results are reproducible)
        val_loader = get_dataloader(config.VAL_PATH, tokenizer,
                                    probability_type=config.PROBABILITY_TYPE,
                                    batch_size=config.BATCH_SIZE,
                                    max_length=config.MAX_LENGTH,
                                    shuffle=False)
        # Validate the model
        val(model, tokenizer, val_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)

        ##### Test
        # Create dataloader
        test_loader = get_dataloader(config.TEST_PATH, tokenizer,
                                     probability_type=config.PROBABILITY_TYPE,
                                     batch_size=config.BATCH_SIZE,
                                     max_length=config.MAX_LENGTH,
                                     shuffle=False)
        # Test the model
        test(model, tokenizer, test_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)


if __name__ == "__main__":
    main()