Fill-Mask
Transformers
PyTorch
esm
Inference Endpoints
File size: 5,621 Bytes
9a73cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
### Run ESM2 on the validation and test set. Get val and test losses. 
import os
import fuson_plm.training.config as config
# Set the WANDB_API_KEY environment variable
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import tqdm
import numpy as np
import pandas as pd
import logging
from transformers import AutoModelForMaskedLM, AutoTokenizer
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders
from fuson_plm.training.train import test

def load_esm2_maskedlm(esm_type, device=None):
    """
    Load a pretrained ESM-2 masked-LM checkpoint from the HuggingFace hub.

    Args:
        esm_type: checkpoint name under the "facebook/" namespace
            (e.g. "esm2_t33_650M_UR50D").
        device: torch device to place the model on; when None, CUDA is used
            if available, else CPU.

    Returns:
        Tuple of (model, tokenizer, device); the model is in eval mode.
    """
    # Suppress warnings about newly initialized pooler layers
    # ('esm.pooler.dense.*') — they are not used for MLM evaluation.
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    hub_id = f"facebook/{esm_type}"
    model = AutoModelForMaskedLM.from_pretrained(hub_id).to(device)
    tokenizer = AutoTokenizer.from_pretrained(hub_id)

    # Disable dropout so evaluation is deterministic.
    model.eval()

    return model, tokenizer, device


def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate a masked-LM model on the validation set.

    Mirrors `test` from fuson_plm.training.train, but runs on the validation
    split. For each batch, tokens are masked according to per-token
    probability vectors, a forward pass computes the MLM loss, and two
    aggregates are tracked: the plain batch-averaged loss and a loss weighted
    by the number of masked tokens (used for perplexity). Results are logged
    and written to {checkpoint_dir}/val_results.csv.

    Args:
        model: HuggingFace masked-LM model.
        tokenizer: matching tokenizer (supplies mask_token_id).
        val_loader: DataLoader yielding (inputs dict, probability tensor) pairs.
        mask_percentage: fraction of tokens to mask per sequence.
        device: device to run evaluation on.
        checkpoint_dir: directory where val_results.csv is written.
    """
    model.to(device)
    model.eval()
    total_val_loss = 0.0
    total_weighted_val_loss = 0.0
    total_val_masked_tokens = 0

    with torch.no_grad():  # inference only; no gradients needed
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors to the evaluation device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on per-token probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(
                    inputs, prob, tokenizer, mask_percentage=mask_percentage)

                # Forward pass; HF computes the MLM loss over masked positions
                outputs = model(**masked_inputs)
                val_loss = outputs.loss

                # Count masked tokens so the loss can be token-weighted
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                total_val_loss += val_loss.item()
                total_weighted_val_loss += val_loss.item() * num_masked_tokens
                total_val_masked_tokens += num_masked_tokens

    # Summary statistics. Guard against an empty loader / zero masked tokens,
    # which would otherwise raise ZeroDivisionError.
    n_val_batches = len(val_loader)
    avg_val_loss = total_val_loss / n_val_batches if n_val_batches else float('nan')
    avg_weighted_val_loss = (total_weighted_val_loss / total_val_masked_tokens
                             if total_val_masked_tokens else float('nan'))
    val_perplexity = np.exp(avg_weighted_val_loss)

    log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f},  Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

    # Save to dataframe for plotting
    val_stats_df = pd.DataFrame(data={
            "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
            "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
            "val_perplexity": [val_perplexity]
        })
    val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv",index=False)    # overwrite old file no matter what; should only be one val eval

def main():
    """Entry point: evaluate ESM-2 (650M) on the validation and test splits."""
    model, tokenizer, device = load_esm2_maskedlm("esm2_t33_650M_UR50D")

    # One results directory per masking configuration.
    checkpoint_dir = f"checkpoints/esm2_t33_650M_UR50D_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
    os.makedirs(checkpoint_dir,exist_ok=True)

    with open_logfile(f"{checkpoint_dir}/evaluate_val_test_esm.txt"):
        # Record the run configuration in the log.
        print_configpy(config)

        # ----- Validation split -----
        val_loader = get_dataloader(
            config.VAL_PATH,
            tokenizer,
            probability_type=config.PROBABILITY_TYPE,
            batch_size=config.BATCH_SIZE,
            max_length=config.MAX_LENGTH,
            shuffle=False,
        )
        val(model, tokenizer, val_loader, config.MASK_PERCENTAGE,
            device=device, checkpoint_dir=checkpoint_dir)

        # ----- Test split -----
        test_loader = get_dataloader(
            config.TEST_PATH,
            tokenizer,
            probability_type=config.PROBABILITY_TYPE,
            batch_size=config.BATCH_SIZE,
            max_length=config.MAX_LENGTH,
            shuffle=False,
        )
        test(model, tokenizer, test_loader, config.MASK_PERCENTAGE,
             device=device, checkpoint_dir=checkpoint_dir)


if __name__ == "__main__":
    main()