Fill-Mask
Transformers
PyTorch
esm
Inference Endpoints
svincoff committed on
Commit
9a73cb0
1 Parent(s): 4f08905

uploaded training code and model weights

fuson_plm/training/README.md ADDED
@@ -0,0 +1,60 @@
1
+ ## Training Script
2
+
3
+ This folder holds code for training the model (`train.py`), defining the model architecture (`model.py`), and defining utility functions, including masking-rate schedulers and dataloaders (`utils.py`). There is also a script for running ESM-2 on the test data (`test_esm2.py`).
4
+
5
+ The weights and other necessary files for loading FusOn-pLM are stored in `checkpoints/best/ckpt`. Results on the test set are stored in `checkpoints/best/test_results.csv`.
6
+
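+ To sanity-check the weights, the minimal sketch below loads this checkpoint and extracts per-residue embeddings. It mirrors `demo.py` and assumes `checkpoints/best/ckpt` is a standard Hugging Face checkpoint directory loadable with `AutoTokenizer`/`AutoModel`; the example sequence is arbitrary.
+
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ ckpt_dir = "checkpoints/best/ckpt"  # location given above; adjust if you move the weights
+ tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
+ model = AutoModel.from_pretrained(ckpt_dir).eval()
+
+ sequence = "MVSSDRPVSLEDEVSHSMKEMIGG"  # any amino-acid sequence
+ inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=2000)
+ with torch.no_grad():
+     embeddings = model(**inputs).last_hidden_state.squeeze(0)[1:-1]  # drop BOS/EOS tokens
+ print(embeddings.shape)  # (len(sequence), hidden_dim)
+ ```
+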
7
+ ### Usage
8
+ #### Configs
9
+ The `config.py` script holds configurations for **training** and **plotting**.
10
+
11
+ ```python
12
+ # Model parameters
13
+ EPOCHS = 30
14
+ BATCH_SIZE = 8
15
+ MAX_LENGTH = 2000
16
+ LEARNING_RATE = 3e-4
17
+ N_UNFROZEN_LAYERS = 8
18
+ UNFREEZE_QUERY = True
19
+ UNFREEZE_KEY = True
20
+ UNFREEZE_VALUE = True
21
+
22
+ ### Masking parameters - must use either variable or fixed masking rate
23
+ # var masking rate (choice 1)
24
+ VAR_MASK_RATE = True # if True, use a variable masking rate (settings below); if False, use the fixed MASK_PERCENTAGE
25
+ MASK_LOW = 0.15
26
+ MASK_HIGH = 0.40
27
+ MASK_STEPS = 20
28
+ MASK_SCHEDULER = "cosine" # specify the type of scheduler to use. options are: "cosine","loglinear","stepwise"
29
+ # fixed masking rate (choice 2)
30
+ MASK_PERCENTAGE = 0.15 # if VAR_MASK_RATE = False, code will use fixed masking rate
31
+
32
+ # To continue training a model you already started, fill in the following parameters
33
+ FINETUNE_FROM_SCRATCH = True # Set to False if you want to finetune from a checkpoint
34
+ PATH_TO_STARTING_CKPT = '' # only set the path if FINETUNE_FROM_SCRATCH = False
35
+
36
+ # File paths - do not change unless you move the training data
37
+ TRAIN_PATH = '../data/splits/train_df.csv'
38
+ VAL_PATH = '../data/splits/val_df.csv'
39
+ TEST_PATH = '../data/splits/test_df.csv'
40
+
41
+ # WandB parameters
42
+ # Fill these in with your own WandB account info
43
+ WANDB_PROJECT = ''
44
+ WANDB_ENTITY = ''
45
+ WANDB_API_KEY=''
46
+
47
+ # GPU parameters
48
+ CUDA_VISIBLE_DEVICES = "0"
49
+ ```
50
+
51
+ #### Training
52
+ The `train.py` script trains a fusion-aware ESM model according to the settings specified in `config.py`.
53
+
54
+ To run, enter in terminal:
55
+ ```bash
56
+ python train.py
57
+ ```
58
+ or, to run the (long) training process in the background:
59
+ ```bash
60
+ nohup python train.py > train.out 2> train.err &
+ ```
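+
+ Each run writes its outputs to a timestamped folder under `checkpoints/`: per-epoch checkpoints (`checkpoint_epoch_N`), `train_curve.csv`, `val_curve.csv`, `test_results.csv`, and `training_log.txt`. A quick way to follow a background run is sketched below (the run folder name depends on your config and timestamp):
+
+ ```bash
+ ls -t checkpoints/                                # list runs, newest first
+ tail -f checkpoints/<run_name>/training_log.txt   # follow a run's training log
+ ```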
fuson_plm/training/__init__.py ADDED
File without changes
fuson_plm/training/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes).
fuson_plm/training/__pycache__/config.cpython-310.pyc ADDED
Binary file (898 Bytes).
fuson_plm/training/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.88 kB).
fuson_plm/training/__pycache__/plot.cpython-310.pyc ADDED
Binary file (4.02 kB).
fuson_plm/training/__pycache__/train.cpython-310.pyc ADDED
Binary file (12.2 kB).
fuson_plm/training/__pycache__/utils.cpython-310.pyc ADDED
Binary file (11.3 kB).
fuson_plm/training/config.py ADDED
@@ -0,0 +1,38 @@
1
+ ###### TRAINING
2
+ # Model parameters
3
+ EPOCHS = 30
4
+ BATCH_SIZE = 8
5
+ MAX_LENGTH = 2000
6
+ LEARNING_RATE = 3e-4
7
+ N_UNFROZEN_LAYERS = 8
8
+ UNFREEZE_QUERY = True
9
+ UNFREEZE_KEY = True
10
+ UNFREEZE_VALUE = True
11
+
12
+ ### Masking parameters - must use either variable or fixed masking rate
13
+ # var masking rate (choice 1)
14
+ VAR_MASK_RATE = True # if True, use a variable masking rate (settings below); if False, use the fixed MASK_PERCENTAGE
15
+ MASK_LOW = 0.15
16
+ MASK_HIGH = 0.40
17
+ MASK_STEPS = 20
18
+ MASK_SCHEDULER = "cosine" # specify the type of scheduler to use. options are: "cosine","loglinear","stepwise"
19
+ # fixed masking rate (choice 2)
20
+ MASK_PERCENTAGE = 0.15 # if VAR_MASK_RATE = False, code will use fixed masking rate
21
+
22
+ # To continue training a model you already started, fill in the following parameters
23
+ FINETUNE_FROM_SCRATCH = True # Set to False if you want to finetune from a checkpoint
24
+ PATH_TO_STARTING_CKPT = '' # only set the path if FINETUNE_FROM_SCRATCH = False
25
+
26
+ # File paths - do not change unless you move the training data
27
+ TRAIN_PATH = '../data/splits/train_df.csv'
28
+ VAL_PATH = '../data/splits/val_df.csv'
29
+ TEST_PATH = '../data/splits/test_df.csv'
30
+
31
+ # WandB parameters
32
+ # Fill these in with your own WandB account info
33
+ WANDB_PROJECT = ''
34
+ WANDB_ENTITY = ''
35
+ WANDB_API_KEY=''
36
+
37
+ # GPU parameters
38
+ CUDA_VISIBLE_DEVICES = "0"
fuson_plm/training/demo.py ADDED
@@ -0,0 +1,46 @@
1
+ from fuson_plm.training.model import FusOnpLM
2
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
3
+ import logging
4
+ import torch
5
+ import os
6
+
7
+ os.environ['CUDA_VISIBLE_DEVICES'] = "1"
8
+
9
+ # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
10
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
11
+
12
+ # Set device
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ print(f"Using device: {device}")
15
+
16
+ # Load the tokenizer and model
17
+ model_name = 'checkpoints/old_splits_snp_2000_ft_11layers_Q_b8_lr5e-05_mask0.15-08-12-2024-12:42:48/checkpoint_epoch_1.pth'
18
+ model = AutoModel.from_pretrained(model_name) # initialize model
19
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
20
+ model.eval()
21
+ model.to(device)
22
+
23
+ # Example fusion oncoprotein sequence: MLLT10:PICALM, associated with Acute Myeloid Leukemia (LAML)
24
+ # Amino acids 1-80 are derived from the head gene, MLLT10
25
+ # Amino acids 81-119 are derived from the tail gene, PICALM
26
+ sequence = "MVSSDRPVSLEDEVSHSMKEMIGGCCVCSDERGWAENPLVYCDGHGCSVAVHQACYGIVQVPTGPWFCRKCESQERAARVPPQMGSVPVMTQPTLIYSQPVMRPPNPFGPVSGAQIQFM"
27
+
28
+ # Tokenize the input sequence
29
+ inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True,max_length=2000)
30
+ inputs = {k: v.to(device) for k, v in inputs.items()}
31
+
32
+ # Get the embeddings
33
+ with torch.no_grad():
34
+ outputs = model(**inputs)
35
+ # The embeddings are in the last_hidden_state tensor
36
+ embeddings = outputs.last_hidden_state
37
+ # remove extra dimension
38
+ embeddings = embeddings.squeeze(0)
39
+ # remove BOS and EOS tokens
40
+ embeddings = embeddings[1:-1, :]
41
+
42
+ # Convert embeddings to numpy array (if needed)
43
+ embeddings = embeddings.cpu().numpy()
44
+
45
+ print("Sequence length: ", len(sequence))
46
+ print("Per-residue embeddings shape:", embeddings.shape)
fuson_plm/training/model.py ADDED
@@ -0,0 +1,119 @@
1
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
2
+ import torch
3
+ import os
4
+
5
+ class FusOnTokenizer:
6
+ """
7
+ FusOnTokenizer class: a wrapper around AutoTokenizer
8
+ """
9
+ def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D'):
10
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
11
+
12
+ def __getattr__(self, name):
13
+ """
14
+ Delegate attribute access to the underlying tokenizer.
15
+ This allows calls like .tokenize(), .train(), and .eval() to be forwarded to the tokenizer.
16
+ """
17
+ return getattr(self.tokenizer, name)
18
+
19
+ def __call__(self, *args, **kwargs):
20
+ """
21
+ Make the FusOnTokenizer object callable, delegating to the tokenizer's __call__ method.
22
+ """
23
+ return self.tokenizer(*args, **kwargs)
24
+
25
+ def save_tokenizer(self, save_directory):
26
+ self.tokenizer.save_pretrained(save_directory)
27
+
28
+ def load_tokenizer(self, load_directory):
29
+ self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
30
+
31
+ class FusOnpLM:
32
+ """
33
+ FusOn-pLM class: a wrapper around AutoModelForMaskedLM
34
+ """
35
+ def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D', ckpt_path = None, mlm_head=False):
36
+ if not(ckpt_path is None):
37
+ self.load_model(ckpt_path, mlm_head)
38
+ else:
39
+ # Load the pre-trained model and tokenizer
40
+ self.model = AutoModelForMaskedLM.from_pretrained(pretrained_path)
41
+ self.tokenizer = FusOnTokenizer(pretrained_path)
42
+
43
+ self.n_layers = self.count_encoder_layers()
44
+
45
+ def __getattr__(self, name):
46
+ """
47
+ Delegate attribute access to the underlying model.
48
+ This allows calls like .to(), .train(), and .eval() to be forwarded to the model.
49
+ """
50
+ return getattr(self.model, name)
51
+
52
+ def __call__(self, *args, **kwargs):
53
+ """
54
+ Make the FusOnpLM object callable, delegating to the model's __call__ method.
55
+ """
56
+ return self.model(*args, **kwargs)
57
+
58
+ def freeze_model(self):
59
+ """
60
+ Freezes all parameters in the model
61
+ """
62
+ for param in self.model.parameters():
63
+ param.requires_grad = False
64
+
65
+ def unfreeze_last_n_layers(self, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
66
+ """
67
+ Unfreezes specific parts of the final n layers in the model's encoder.
68
+
69
+ Args:
70
+ n_unfrozen_layers (int): Number of final layers to unfreeze.
71
+ unfreeze_query (bool): Whether to unfreeze the query projections. Default is True.
72
+ unfreeze_key (bool): Whether to unfreeze the key projections. Default is True.
73
+ unfreeze_value (bool): Whether to unfreeze the value projections. Default is True.
74
+ """
75
+ for i, layer in enumerate(self.model.esm.encoder.layer):
76
+ if (self.n_layers - i) <= n_unfrozen_layers: # Only the last n layers
77
+ if unfreeze_query:
78
+ self._unfreeze_parameters(layer.attention.self.query)
79
+ if unfreeze_key:
80
+ self._unfreeze_parameters(layer.attention.self.key)
81
+ if unfreeze_value:
82
+ self._unfreeze_parameters(layer.attention.self.value)
83
+
84
+ def _unfreeze_parameters(self, module):
85
+ """
86
+ Helper method to unfreeze parameters in a given module.
87
+
88
+ Args:
89
+ module (nn.Module): The module whose parameters are to be unfrozen.
90
+ """
91
+ for param in module.parameters():
92
+ param.requires_grad = True
93
+
94
+
95
+ def count_encoder_layers(self):
96
+ """
97
+ Count the number of encoder layers in the model.
98
+ """
99
+ return len(self.model.esm.encoder.layer)
100
+
101
+ def save_model(self, save_directory, optimizer=None):
102
+ # Save the model and tokenizer
103
+ self.model.save_pretrained(save_directory)
104
+ self.tokenizer.save_pretrained(save_directory)
105
+
106
+ # If an optimizer is provided, save its state dict
107
+ if optimizer is not None:
108
+ optimizer_path = os.path.join(save_directory, "optimizer.pt")
109
+ torch.save(optimizer.state_dict(), optimizer_path)
110
+
111
+ def load_model(self, load_directory, mlm_head):
112
+ # Load a checkpoint of the model either with or without an MLM head
113
+ if mlm_head:
114
+ self.model = AutoModelForMaskedLM.from_pretrained(load_directory)
115
+ else:
116
+ # Load the model and tokenizer from a directory
117
+ self.model = AutoModel.from_pretrained(load_directory)
118
+ self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
119
+
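A minimal usage sketch of this wrapper (not part of the commit): it mirrors what `prepare_model` in `train.py` does, and assumes the `fuson_plm` package is importable and `facebook/esm2_t33_650M_UR50D` can be downloaded.

```python
import torch
from fuson_plm.training.model import FusOnpLM

# Load ESM-2-650M with its MLM head, freeze everything, then unfreeze only the
# query/key/value projections of the last 8 encoder layers (the defaults in config.py).
model = FusOnpLM()
model.freeze_model()
model.unfreeze_last_n_layers(8, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")

# __getattr__ delegates to the underlying Hugging Face model, so .to(), .train(), .eval() work as usual.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
tokenizer = model.tokenizer  # FusOnTokenizer wrapper around the ESM-2 tokenizer
```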
fuson_plm/training/test_esm2.py ADDED
@@ -0,0 +1,122 @@
1
+ ### Run ESM2 on the validation and test set. Get val and test losses.
2
+ import os
3
+ import fuson_plm.training.config as config
4
+ # Set the WANDB_API_KEY environment variable
5
+ os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
6
+ os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
7
+
8
+ import torch
9
+ import tqdm
10
+ import numpy as np
11
+ import pandas as pd
12
+ import logging
13
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
14
+ from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
15
+ from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders
16
+ from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders
17
+ from fuson_plm.training.train import test
18
+
19
+ def load_esm2_maskedlm(esm_type, device=None):
20
+ """
21
+ Loads the specified version of ESM-2 (e.g. esm2_t33_650M_UR50D)
22
+ """
23
+ # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
24
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
25
+
26
+ if device is None:
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ print(f"Using device: {device}")
29
+
30
+ model = AutoModelForMaskedLM.from_pretrained(f"facebook/{esm_type}")
31
+ tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")
32
+
33
+ model.to(device)
34
+ model.eval() # disables dropout for deterministic results
35
+
36
+ return model, tokenizer, device
37
+
38
+
39
+ def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
40
+ """
41
+ Same as the test method in train.py, but run on the validation set
42
+ """
43
+ model.to(device)
44
+ model.eval()
45
+ total_val_loss = 0
46
+ total_weighted_val_loss = 0
47
+ total_val_masked_tokens = 0
48
+
49
+ with torch.no_grad(): # No gradients needed
50
+ # Loop over val data with progress bar
51
+ with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
52
+ for batch_idx, (inputs, prob) in tbar:
53
+ # Move tensors
54
+ inputs = {k: v.to(device) for k, v in inputs.items()}
55
+ prob = prob.to(device)
56
+
57
+ # Mask based on probability vectors
58
+ masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)
59
+
60
+ # Forward pass
61
+ outputs = model(**masked_inputs)
62
+ val_loss = outputs.loss
63
+
64
+ # Number of masked tokens
65
+ num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
66
+
67
+ # Loss calculations
68
+ total_val_loss += val_loss.item()
69
+ total_weighted_val_loss += val_loss.item() * num_masked_tokens # Multiply loss by number of masked tokens
70
+ total_val_masked_tokens += num_masked_tokens
71
+
72
+ # Compute and log avg. loss and perplexity
73
+ n_val_batches = len(val_loader)
74
+ avg_val_loss = total_val_loss / n_val_batches
75
+ avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens
76
+ val_perplexity = np.exp(avg_weighted_val_loss)
77
+
78
+ log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")
79
+
80
+ # Save to dataframe for plotting
81
+ val_stats_df = pd.DataFrame(data={
82
+ "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
83
+ "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
84
+ "val_perplexity": [val_perplexity]
85
+ })
86
+ val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv",index=False) # overwrite old file no matter what; should only be one val eval
87
+
88
+ def main():
89
+ # Load the ESM-2 model
90
+ model, tokenizer, device = load_esm2_maskedlm("esm2_t33_650M_UR50D")
91
+
92
+ checkpoint_dir = f"checkpoints/esm2_t33_650M_UR50D_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
93
+ os.makedirs(checkpoint_dir,exist_ok=True)
94
+
95
+ with open_logfile(f"{checkpoint_dir}/evaluate_val_test_esm.txt"):
96
+ # Print configurations
97
+ print_configpy(config)
98
+
99
+ ##### Validation
100
+ val_loader = get_dataloader(config.VAL_PATH, tokenizer,
101
+ probability_type=config.PROBABILITY_TYPE,
102
+ batch_size=config.BATCH_SIZE,
103
+ max_length=config.MAX_LENGTH, shuffle=False)
104
+
105
+ # Validation
106
+ val(model, tokenizer, val_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)
107
+
108
+
109
+ ##### Test
110
+ # Create dataloader
111
+ test_loader = get_dataloader(config.TEST_PATH,
112
+ tokenizer,
113
+ probability_type=config.PROBABILITY_TYPE,
114
+ batch_size=config.BATCH_SIZE,
115
+ max_length=config.MAX_LENGTH, shuffle=False)
116
+
117
+
118
+ # Test the model
119
+ test(model, tokenizer, test_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)
120
+
121
+ if __name__ == "__main__":
122
+ main()
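Note: `main()` above reads `config.PROBABILITY_TYPE`, which is not defined in `config.py`; `train.py` sets it at runtime to `"uniform"`. To run this evaluation standalone, one possible workaround (a hypothetical snippet, not part of the repository) is:

```python
import fuson_plm.training.config as config
import fuson_plm.training.test_esm2 as test_esm2

config.PROBABILITY_TYPE = "uniform"  # mirror what train.py sets at runtime
test_esm2.main()
```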
fuson_plm/training/train.py ADDED
@@ -0,0 +1,388 @@
1
+ '''
2
+ This is a training script for finetuning ESM.
3
+ All parameters are frozen except the query/key/value projections in the last N encoder layers; the MLM head stays frozen.
4
+ '''
5
+
6
+ import os
7
+ import fuson_plm.training.config as config
8
+
9
+ # Set the WANDB_API_KEY environment variable
10
+ os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
11
+ os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
12
+
13
+ import torch
14
+ import numpy as np
15
+ import pandas as pd
16
+ import tqdm
17
+ from datetime import datetime
18
+ import wandb
19
+ import pytz
20
+ import sys
21
+
22
+ from transformers import AdamW
23
+
24
+ from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
25
+ from fuson_plm.training.model import FusOnpLM
26
+ from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
27
+ from fuson_plm.training.plot import make_train_val_test_bd_plot
28
+
29
+ def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
30
+ # Log the model's initial state
31
+ n_layers = model.count_encoder_layers()
32
+ total_params = sum(p.numel() for p in model.parameters())
33
+ total_head_params = sum(p.numel() for p in model.lm_head.parameters())
34
+ log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
35
+ log_update(f'\tTotal parameters in the AutoModelforMaskedLM model: {total_params}')
36
+ log_update(f'\tTotal parameters in the MLM Head ONLY: {total_head_params}')
37
+
38
+ # Freeze the model to start
39
+ model.freeze_model()
40
+ n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
41
+ log_update(f'Froze all {model.n_layers} model layers')
42
+ log_update(f'\tTrainable params: {n_trainable_params}')
43
+
44
+ # Unfreeze the last n layers
45
+ model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
46
+ n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
47
+ trainable_params = '\n\t\t'.join([name for name, param in model.named_parameters() if param.requires_grad])
48
+ num_trainable_params_lm_head = sum(p.numel() for p in model.lm_head.parameters() if p.requires_grad)
49
+ num_trainable_params_esm = sum(p.numel() for p in model.esm.parameters() if p.requires_grad)
50
+ log_update(f'Unfroze final {n_unfrozen_layers} layers')
51
+ log_update(f'\tTrainable params: {n_trainable_params}\n\t\t{trainable_params}')
52
+ log_update(f"\tTrainable parameters in the lm_head: {num_trainable_params_lm_head}")
53
+ log_update(f"\tTrainable params in the ESM part: {num_trainable_params_esm}")
54
+
55
+ def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
56
+ """
57
+ Train the model
58
+ """
59
+ # Loop over epochs
60
+ log_update("\n")
61
+
62
+ for epoch in range(start_epoch, start_epoch+n_epochs):
63
+ if mask_rate_scheduler is not None:
64
+ mask_rate_scheduler.reset() # resetting because we want to ramp it up again every epoch
65
+
66
+ model.train()
67
+ total_train_loss = 0
68
+ total_weighted_train_loss = 0
69
+ total_train_masked_tokens = 0
70
+
71
+ log_update(f"Epoch {epoch}")
72
+ # Loop over train data with progress bar
73
+ with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
74
+ for batch_idx, (inputs, prob) in pbar:
75
+ # Take a step with the mask rate scheduler, if there is one.
76
+ masking_rate = mask_percentage
77
+ if mask_rate_scheduler is not None:
78
+ mask_rate_scheduler.step()
79
+ masking_rate = mask_rate_scheduler.get_masking_rate()
80
+ log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")
81
+
82
+ # Move tensors
83
+ inputs = {k: v.to(device) for k, v in inputs.items()}
84
+ prob = prob.to(device)
85
+
86
+ # Mask based on probability vectors
87
+ masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer,mask_percentage=masking_rate)
88
+
89
+ # Forward pass and update
90
+ optimizer.zero_grad()
91
+ outputs = model(**masked_inputs)
92
+ loss = outputs.loss
93
+ loss.backward()
94
+ optimizer.step()
95
+
96
+ # Number of masked tokens
97
+ num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
98
+
99
+ # Loss calculations and wandb log
100
+ total_train_loss += loss.item()
101
+ total_weighted_train_loss += loss.item() * num_masked_tokens # Multiply loss by number of masked tokens
102
+ total_train_masked_tokens += num_masked_tokens
103
+ wandb.log({"batch_loss": loss.item()})
104
+
105
+ # Save a checkpoint at the end of each epoch
106
+ checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
107
+ model.save_model(checkpoint_path, optimizer=optimizer)
108
+ log_update(f'\nSaved checkpoint to {checkpoint_path}')
109
+
110
+ # Calculate and log average training loss on wandb
111
+ n_train_batches = len(train_loader)
112
+ avg_train_loss = total_train_loss / n_train_batches
113
+ avg_weighted_train_loss = total_weighted_train_loss / total_train_masked_tokens
114
+ train_perplexity = np.exp(avg_weighted_train_loss)
115
+ wandb.log({"epoch": epoch,
116
+ "total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
117
+ "avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
118
+ "train_perplexity": train_perplexity})
119
+
120
+ # Track curve stats for easy re-plotting of training curves later
121
+ train_stats_df = pd.DataFrame(data={
122
+ "epoch": [epoch],
123
+ "total_train_loss": [total_train_loss], "weighted_train_loss": [total_weighted_train_loss],
124
+ "avg_train_loss": [avg_train_loss], "avg_weighted_train_loss": [avg_weighted_train_loss],
125
+ "train_perplexity": [train_perplexity]
126
+ })
127
+ if os.path.exists(f"{checkpoint_dir}/train_curve.csv"): # add to file if necessary
128
+ train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv",index=False,header=False,mode='a')
129
+ else: # make new file if necessary
130
+ train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv",index=False)
131
+
132
+ # Validation loop
133
+ model.eval()
134
+ total_val_loss = 0
135
+ total_weighted_val_loss = 0
136
+ total_val_masked_tokens = 0
137
+
138
+ with torch.no_grad(): # No gradients needed
139
+ # Loop over val data with progress bar
140
+ with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
141
+ for batch_idx, (inputs, prob) in vbar:
142
+ # Move tensors
143
+ inputs = {k: v.to(device) for k, v in inputs.items()}
144
+ prob = prob.to(device)
145
+
146
+ # Mask based on probability vectors
147
+ ## FIXED 15% masking for the validation set
148
+ masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer,mask_percentage=0.15)
149
+
150
+ # Forward pass
151
+ outputs = model(**masked_inputs)
152
+ val_loss = outputs.loss
153
+
154
+ # Number of masked tokens
155
+ num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
156
+
157
+ # Loss calculations
158
+ total_val_loss += val_loss.item()
159
+ total_weighted_val_loss += val_loss.item() * num_masked_tokens # Multiply loss by number of masked tokens
160
+ total_val_masked_tokens += num_masked_tokens
161
+
162
+ # Calculate and log avg. loss and perplexity (wandb and locally)
163
+ n_val_batches = len(val_loader)
164
+ avg_val_loss = total_val_loss / n_val_batches # avg per batch
165
+ avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens # avg per masked token
166
+ val_perplexity = np.exp(avg_weighted_val_loss)
167
+ wandb.log({"epoch": epoch,
168
+ "total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
169
+ "avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
170
+ "val_perplexity": val_perplexity})
171
+
172
+ # Track curve stats for easy re-plotting of training curves later
173
+ val_stats_df = pd.DataFrame(data={
174
+ "epoch": [epoch],
175
+ "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
176
+ "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
177
+ "val_perplexity": [val_perplexity]
178
+ })
179
+ if os.path.exists(f"{checkpoint_dir}/val_curve.csv"): # add to file if necessary
180
+ val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv",index=False,header=False,mode='a')
181
+ else: # make new file if necessary
182
+ val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv",index=False)
183
+
184
+ log_update(f"Epoch: {epoch}")
185
+ log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
186
+ log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")
187
+
188
+ def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
189
+ """
+ Evaluate the model on the test set with a fixed 15% masking rate; logs summary metrics and saves them to test_results.csv in checkpoint_dir.
+ """
191
+ model.to(device)
192
+ model.eval()
193
+ total_test_loss = 0
194
+ total_weighted_test_loss = 0
195
+ total_test_masked_tokens = 0
196
+
197
+ with torch.no_grad(): # No gradients needed
198
+ # Loop over test data (no progress bar)
199
+ with tqdm.tqdm(enumerate(test_loader), total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
200
+ for batch_idx, (inputs, prob) in tbar:
201
+ # Move tensors
202
+ inputs = {k: v.to(device) for k, v in inputs.items()}
203
+ prob = prob.to(device)
204
+
205
+ # Mask based on probability vectors
206
+ ### FIXED 15% masking for the testing set
207
+ masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)
208
+
209
+ # Forward pass
210
+ outputs = model(**masked_inputs)
211
+ test_loss = outputs.loss
212
+
213
+ # Number of masked tokens
214
+ num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()
215
+
216
+ # Loss calculations
217
+ total_test_loss += test_loss.item()
218
+ total_weighted_test_loss += test_loss.item() * num_masked_tokens # Multiply loss by number of masked tokens
219
+ total_test_masked_tokens += num_masked_tokens
220
+
221
+ # Compute and log avg. loss and perplexity
222
+ n_test_batches = len(test_loader)
223
+ avg_test_loss = total_test_loss / n_test_batches
224
+ avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens
225
+ test_perplexity = np.exp(avg_weighted_test_loss)
226
+
227
+ log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f}, Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")
228
+
229
+ # Save to dataframe for plotting
230
+ test_stats_df = pd.DataFrame(data={
231
+ "total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
232
+ "avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
233
+ "test_perplexity": [test_perplexity]
234
+ })
235
+ test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv",index=False) # overwrite old file no matter what; should only be one test eval
236
+
237
+ def check_env_variables():
238
+ log_update("\nChecking on environment variables...")
239
+ log_update(f"\tWANDB_API_KEY: {os.environ.get('WANDB_API_KEY')}")
240
+ log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
241
+ log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
242
+ for i in range(torch.cuda.device_count()):
243
+ log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
244
+
245
+ def initialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
246
+ """
247
+ Initializes the model, either from ESM-2-650M if finetuning from scratch, or from a prior checkpoint if not finetuning from scratch.
248
+ Also prepares the model (freezing all layers, then unfreezing the requested attention projections) and initializes the optimizer.
249
+
250
+ Args:
251
+ finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt
252
+ path_to_starting_ckpt (str): path to starting ckpt for finetuning (optional)
253
+ """
254
+ if not(finetune_from_scratch) and not(os.path.exists(path_to_starting_ckpt)):
255
+ raise Exception(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")
256
+
257
+ # if finetuning from scratch, initialize from scratch
258
+ if finetune_from_scratch:
259
+ log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
260
+ model = FusOnpLM() # because of __getattr__, we can use FusOnpLM() to get the model. It also contains the tokenizer.
261
+ model.to(device)
262
+ prepare_model(model, n_unfrozen_layers,
263
+ unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
264
+
265
+ # Set the optimizer here, change it if we are finetuning from an old checkpoint
266
+ optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
267
+
268
+ return model, optimizer
269
+
270
+ # if not, initialize from starting ckpt
271
+ else:
272
+ log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
273
+ model = FusOnpLM(ckpt_path = path_to_starting_ckpt, mlm_head=True)
274
+ model.to(device)
275
+ prepare_model(model, n_unfrozen_layers,
276
+ unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
277
+
278
+ log_update(f"Loading optimizer state_dict from previous checkpoint")
279
+ optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()))
280
+ optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))
281
+
282
+ return model, optimizer
283
+
284
+ def main():
285
+ # Set probability type to uniform; only option
286
+ config.PROBABILITY_TYPE = "uniform"
287
+
288
+ # Set run name (WANDB_NAME)
289
+ kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
290
+ timestamp = get_local_time()
291
+ # make a mask tag _mask{config.MASK_PERCENTAGE}
292
+ mask_tag = f"mask{config.MASK_PERCENTAGE}"
293
+ if config.VAR_MASK_RATE: # if variable masking rate, change the tag to reflect this
294
+ mask_tag=f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"
295
+
296
+ # Define the train settings string and wandb name from this
297
+ TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
298
+ WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'
299
+
300
+ # Create directory for model checkpoints
301
+ checkpoint_dir = f'checkpoints/{WANDB_NAME}'
302
+ start_epoch = 1
303
+
304
+ # Determine if we're adding to an old log file or opening a new one
305
+ logmode='w'
306
+
307
+ # If we're finetuning from a checkpoint, save to the same folder instead, and keep track of which epoch to start on
308
+ # Also, load the optimizer from here
309
+ if not(config.FINETUNE_FROM_SCRATCH):
310
+ logmode='a'
311
+ path_to_starting_ckpt = config.PATH_TO_STARTING_CKPT
312
+ checkpoint_dir = path_to_starting_ckpt[0:path_to_starting_ckpt.rindex('/')]
313
+ START_MODEL_TRAIN_SETTINGS_STRING = checkpoint_dir[checkpoint_dir.index('checkpoints/')+len('checkpoints/'):checkpoint_dir.index('-')]
314
+ start_epoch = int(path_to_starting_ckpt.split('/checkpoint_epoch_')[1])+1
315
+
316
+ os.makedirs(f'checkpoints', exist_ok=True)
317
+ os.makedirs(checkpoint_dir, exist_ok=True)
318
+
319
+ # Open log file
320
+ LOG_PATH = f'{checkpoint_dir}/training_log.txt'
321
+ ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
322
+ with open_logfile(LOG_PATH,mode=logmode), open_errfile(ERR_PATH,mode=logmode):
323
+ if not(config.FINETUNE_FROM_SCRATCH):
324
+ log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
325
+ log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")
326
+ # ONLY proceed with training if we're using the same settings, otherwise we are not finetuning the model we think we are!
327
+ assert START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING
328
+
329
+ # Print configurations
330
+ print_configpy(config)
331
+
332
+ # Verify that the environment variables are set correctly
333
+ check_env_variables()
334
+
335
+ # Check CUDA availability and set device
336
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
337
+ log_update(f"\nUsing device: {device}")
338
+
339
+ # Init wandb
340
+ wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME , config={
341
+ "batch_size": config.BATCH_SIZE,
342
+ "epochs": config.EPOCHS,
343
+ "learning_rate": config.LEARNING_RATE,
344
+ })
345
+
346
+ # Initialize model and prepare it (freeze/unfreeze proper layers). Initialize optimizer as well. Details depend on whether we are finetuning from scratch.
347
+ model, optimizer = initialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
348
+ path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
349
+ learning_rate=config.LEARNING_RATE,
350
+ n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
351
+ unfreeze_query=config.UNFREEZE_QUERY,
352
+ unfreeze_key=config.UNFREEZE_KEY,
353
+ unfreeze_value=config.UNFREEZE_VALUE)
354
+
355
+ # Initialize the tokenizer (independent of starting model for finetuning)
356
+ tokenizer = model.tokenizer
357
+
358
+ # Create DataLoader instances and perform sanity checks on them
359
+ train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True) # shuffle must be True for training; set to False only when debugging
360
+ val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
361
+ test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
362
+
363
+ # If we're continuing to finetune an old ckpt, store the old batch diversity plot before we overwrite it
364
+ check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)
365
+
366
+ # Set up a masking rate scheduler, if one is needed
367
+ mask_rate_scheduler = None
368
+ if config.VAR_MASK_RATE:
369
+ mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
370
+ min_masking_rate=config.MASK_LOW,
371
+ max_masking_rate=config.MASK_HIGH,
372
+ total_batches=len(train_loader),
373
+ total_steps=config.MASK_STEPS)
374
+
375
+ # Train the model
376
+ train(model, tokenizer, optimizer, train_loader, val_loader,
377
+ n_epochs=config.EPOCHS,
378
+ start_epoch = start_epoch,
379
+ device=device,
380
+ mask_rate_scheduler=mask_rate_scheduler,
381
+ mask_percentage=config.MASK_PERCENTAGE,
382
+ checkpoint_dir=checkpoint_dir)
383
+
384
+ # Test the model
385
+ test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)
386
+
387
+ if __name__ == "__main__":
388
+ main()
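To resume training from a saved epoch, the relevant `config.py` settings are sketched below (the checkpoint path is illustrative; the other settings must match the original run or the assertion above will fail):

```python
# config.py (illustrative values)
FINETUNE_FROM_SCRATCH = False
PATH_TO_STARTING_CKPT = 'checkpoints/<run_name>/checkpoint_epoch_10'
# train.py will then log into the same run folder, load optimizer.pt from this
# checkpoint, and resume at epoch 11 (parsed from the path).
```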
fuson_plm/training/utils.py ADDED
@@ -0,0 +1,312 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ import os
5
+ from torch.nn.functional import softmax
6
+ from fuson_plm.utils.logging import log_update
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from transformers import AutoTokenizer
9
+ from abc import ABC, abstractmethod
10
+
11
+ #----------------------------------------------------------------------------------------------------------------------------------------------------
12
+ #### Masking Rate Scheduler base class and sub classes
13
+ # abstract base class
14
+ class MaskingRateScheduler(ABC):
15
+ def __init__(self, total_steps, min_masking_rate, max_masking_rate, last_step=-1):
16
+ self.total_steps = total_steps
17
+ self.min_masking_rate = min_masking_rate
18
+ self.max_masking_rate = max_masking_rate
19
+ self.current_step = last_step
20
+
21
+ def step(self):
22
+ self.current_step += 1
23
+
24
+ def reset(self):
25
+ """Reset the scheduler to its initial state."""
26
+ self.current_step = -1
27
+
28
+ def get_masking_rate(self):
29
+ progress = self.current_step / self.total_steps
30
+ return self.compute_masking_rate(progress)
31
+
32
+ @abstractmethod
33
+ def compute_masking_rate(self, progress):
34
+ """To be implemented by subclasses for specific increase functions."""
35
+ raise NotImplementedError("Subclasses must implement this method.")
36
+
37
+
38
+ class CosineIncreaseMaskingRateScheduler(MaskingRateScheduler):
39
+ def compute_masking_rate(self, progress):
40
+ # Use a cosine increase function
41
+ cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
42
+ return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * cosine_increase
43
+
44
+ class LogLinearIncreaseMaskingRateScheduler(MaskingRateScheduler):
45
+ def compute_masking_rate(self, progress):
46
+ # Avoid log(0) by clamping progress to a minimum of a small positive number
47
+ progress = max(progress, 1e-10)
48
+ log_linear_increase = np.log1p(progress) / np.log1p(1) # Normalizing to keep range in [0, 1]
49
+ return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * log_linear_increase
50
+
51
+ class StepwiseIncreaseMaskingRateScheduler(MaskingRateScheduler):
52
+ def __init__(self, total_batches, min_masking_rate, max_masking_rate, num_steps):
53
+ super().__init__(total_steps=total_batches, min_masking_rate=min_masking_rate, max_masking_rate=max_masking_rate)
54
+ self.num_steps = num_steps
55
+ self.batch_interval = total_batches // (num_steps) # Adjusting to ensure max rate is included
56
+ self.rate_increment = (max_masking_rate - min_masking_rate) / (num_steps - 1) # Include end rate in the steps
57
+
58
+ def compute_masking_rate(self, progress):
59
+ # Determine the current step based on the number of completed batches
60
+ current_step = int(self.current_step / self.batch_interval)
61
+ # Cap the step number to `num_steps - 1` to include the max rate at the final step
62
+ current_step = min(current_step, self.num_steps - 1)
63
+ # Calculate the masking rate for the current step
64
+ masking_rate = self.min_masking_rate + current_step * self.rate_increment
65
+ return masking_rate
66
+
67
+ def get_mask_rate_scheduler(scheduler_type="cosine",min_masking_rate=0.15,max_masking_rate=0.40,total_batches=100,total_steps=20):
68
+ """
69
+ Initialize the mask rate scheduler and return it
70
+ """
71
+ if scheduler_type=="cosine":
72
+ return CosineIncreaseMaskingRateScheduler(total_steps=total_batches,
73
+ min_masking_rate=min_masking_rate,
74
+ max_masking_rate=max_masking_rate)
75
+ elif scheduler_type=="loglinear":
76
+ return LogLinearIncreaseMaskingRateScheduler(total_steps=total_batches,
77
+ min_masking_rate=min_masking_rate,
78
+ max_masking_rate=max_masking_rate)
79
+ elif scheduler_type=="stepwise":
80
+ return StepwiseIncreaseMaskingRateScheduler(total_batches=total_batches,
81
+ num_steps=total_steps,
82
+ min_masking_rate=min_masking_rate,
83
+ max_masking_rate=max_masking_rate)
84
+ else:
85
+ raise Exception("Must specify valid scheduler_type: cosine, loglinear, stepwise")
86
+
87
+ # Adjusted Dataloader for the sequences and probability vectors
88
+ class ProteinDataset(Dataset):
89
+ def __init__(self, data_path, tokenizer, probability_type, max_length=512):
90
+ self.dataframe = pd.read_csv(data_path)
91
+ self.tokenizer = tokenizer
92
+ self.probability_type=probability_type
93
+ self.max_length = max_length
94
+
95
+ self.set_probabilities()
96
+
97
+ def __len__(self):
98
+ return len(self.dataframe)
99
+
100
+ def set_probabilities(self):
101
+ if self.probability_type=="snp":
102
+ self.dataframe = self.dataframe.rename(columns={'snp_probabilities':'probabilities'})
103
+ if self.probability_type=="uniform":
104
+ self.dataframe['probabilities'] = self.dataframe['sequence'].apply(len).apply(lambda x: ('1,'*x)[0:-1])
105
+
106
+ # make probabilities into numbers if they aren't already
107
+ if type(self.dataframe['probabilities'][0]) == str:
108
+ self.dataframe['probabilities'] = self.dataframe['probabilities'].apply(
109
+ lambda x: np.array([float(i) for i in x.split(',')])
110
+ )
111
+
112
+ def get_padded_probabilities(self, idx):
113
+ '''
114
+ Pads probabilities to max_length if they're too short; truncate them if they're too long
115
+ '''
116
+ no_mask_value = int(-1e9) # will be used to make sure CLS and PAD aren't masked
117
+
118
+ # add a no-mask slot for <CLS>
119
+ prob = np.concatenate((
120
+ np.array([no_mask_value]),
121
+ self.dataframe.iloc[idx]['probabilities']
122
+ )
123
+ )
124
+
125
+ # Pad with no_mask_value for everything after the probability vector ends
126
+ if len(prob) < self.max_length:
127
+ return np.pad(
128
+ prob,
129
+ (0, self.max_length - len(prob)),
130
+ 'constant', constant_values=(0,no_mask_value))
131
+
132
+ # If it's too long, we need to truncate, but we also need to change the last token to an <EOS>.
133
+ prob = prob[0:self.max_length-1]
134
+ prob = np.concatenate((
135
+ prob,
136
+ np.array([no_mask_value]),
137
+ )
138
+ )
139
+ return prob
140
+
141
+ def __getitem__(self, idx):
142
+ sequence = self.dataframe.iloc[idx]['sequence']
143
+ probability = self.get_padded_probabilities(idx) # extract them
144
+ inputs = self.tokenizer(sequence, return_tensors="pt", padding='max_length', truncation=True, max_length=self.max_length) # pad/truncate to self.max_length
145
+ inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()} # Remove batch dimension
146
+ return inputs, probability
147
+
148
+ def get_dataloader(data_path, tokenizer, probability_type='snp', max_length=512, batch_size=8, shuffle=True):
149
+ """
150
+ Creates a DataLoader for the dataset.
151
+ Args:
152
+ data_path (str): Path to the CSV file (train, val, or test).
153
+ batch_size (int): Batch size.
154
+ shuffle (bool): Whether to shuffle the data.
155
+ tokenizer (Tokenizer): tokenizer object for data tokenization
156
+ Returns:
157
+ DataLoader: DataLoader object.
158
+ """
159
+ dataset = ProteinDataset(data_path, tokenizer, probability_type, max_length=max_length)
160
+ return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
161
+
162
+ def check_dataloaders(train_loader, val_loader, test_loader, max_length=512, checkpoint_dir=''):
163
+ log_update(f'\nBuilt train, validation, and test dataloaders')
164
+ log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
165
+ log_update(f"\tNumber of sequences in the Validation DataLoader: {len(val_loader.dataset)}")
166
+ log_update(f"\tNumber of sequences in the Test DataLoader: {len(test_loader.dataset)}")
167
+ dataloader_overlaps = check_dataloader_overlap(train_loader, val_loader, test_loader)
168
+ if len(dataloader_overlaps)==0: log_update("\tDataloaders are clean (no overlaps)")
169
+ else: log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")
170
+
171
+ # write length ranges to a text file
172
+ if not(os.path.exists(f'{checkpoint_dir}/batch_diversity')):
173
+ os.mkdir(f'{checkpoint_dir}/batch_diversity')
174
+
175
+ max_length_violators = []
176
+ for name, dataloader in {'train':train_loader, 'val':val_loader, 'test':test_loader}.items():
177
+ max_length_followed, length_ranges = check_max_length_and_length_diversity(dataloader, max_length)
178
+ if max_length_followed == False:
179
+ max_length_violators.append(name)
180
+
181
+ with open(f'{checkpoint_dir}/batch_diversity/{name}_batch_length_ranges.txt','w') as f:
182
+ for tup in length_ranges:
183
+ f.write(f'{tup[0]}\t{tup[1]}\n')
184
+
185
+ if len(max_length_violators)==0: log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
186
+ else: log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")
187
+
188
+ def check_dataloader_overlap(train_loader, val_loader, test_loader):
189
+ """
190
+ Check the data that's about to go into the model. Make sure there is no overlap between train, test, and val
191
+
192
+ Returns:
+ list: strings describing any overlap between train/val/test (empty list if the splits are clean)
193
+ """
194
+ train_protein_seqs = set(train_loader.dataset.dataframe['sequence'].unique())
195
+ val_protein_seqs = set(val_loader.dataset.dataframe['sequence'].unique())
196
+ test_protein_seqs = set(test_loader.dataset.dataframe['sequence'].unique())
197
+
198
+ tr_va = len(train_protein_seqs.intersection(val_protein_seqs))
199
+ tr_te = len(train_protein_seqs.intersection(test_protein_seqs))
200
+ va_te = len(val_protein_seqs.intersection(test_protein_seqs))
201
+
202
+ overlaps = []
203
+ if tr_va==tr_te==va_te==0:
204
+ return overlaps # data is clean
205
+ else:
206
+ if tr_va > 0: overlaps.append(f"Train-Val Overlap={tr_va}")
207
+ if tr_te > 0: overlaps.append(f"Train-Test Overlap={tr_te}")
208
+ if va_te > 0: overlaps.append(f"Val-Test Overlap={va_te}")
209
+ return overlaps
210
+
211
+ def check_max_length_and_length_diversity(dataloader, max_length):
212
+ """
213
+ Check if all sequences in the DataLoader conform to the specified max_length,
214
+ and return the sequence length ranges within each batch.
215
+
216
+ Args:
217
+ dataloader (DataLoader): The DataLoader object to check.
218
+ max_length (int): The maximum allowed sequence length.
219
+
220
+ Returns:
221
+ bool: True if all sequences are within the max_length, False otherwise.
222
+ list: A list of tuples representing the min and max sequence lengths in each batch.
223
+ """
224
+ length_ranges = []
225
+ all_within_max_length = True
226
+
227
+ for batch_idx, (inputs, _) in enumerate(dataloader):
228
+ input_ids = inputs['input_ids']
229
+
230
+ # Calculate the actual lengths of sequences in this batch
231
+ actual_lengths = (input_ids != dataloader.dataset.tokenizer.pad_token_id).sum(dim=1)
232
+ min_length = actual_lengths.min().item()
233
+ max_length_in_batch = actual_lengths.max().item()
234
+
235
+ # Check for max length violation
236
+ if max_length_in_batch > max_length:
237
+ #print(f"Error: Sequence exceeds max_length of {max_length} at batch {batch_idx + 1}. Max length found: {max_length_in_batch}")
238
+ all_within_max_length = False
239
+
240
+ # Store the length range for this batch
241
+ length_ranges.append((min_length, max_length_in_batch))
242
+
243
+ #print(f"All sequences in the DataLoader conform to the max_length of {max_length}.") if all_within_max_length else None
244
+ #print(f"Sequence length ranges per batch: {length_ranges}")
245
+
246
+ return all_within_max_length, length_ranges
247
+
248
+
249
+ def check_max_length_in_dataloader(dataloader, max_length):
250
+ """
251
+ Check if all sequences in the DataLoader conform to the specified max_length.
252
+
253
+ Args:
254
+ dataloader (DataLoader): The DataLoader object to check.
255
+ max_length (int): The maximum allowed sequence length.
256
+
257
+ Returns:
258
+ bool: True if all sequences are within the max_length, False otherwise.
259
+ """
260
+ for batch_idx, (inputs, _) in enumerate(dataloader):
261
+ input_ids = inputs['input_ids']
262
+
263
+ # Check if any sequence length exceeds max_length
264
+ if input_ids.size(1) > max_length:
265
+ return False
266
+
267
+ return True
268
+
269
+
270
+ def batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer: AutoTokenizer, mask_percentage=0.15):
271
+ """
+ Masks a fraction (mask_percentage) of the residues in each sequence of the batch, sampling positions
+ from the softmax-normalized per-residue probabilities. Returns inputs with masked input_ids, an updated
+ attention_mask (1 through <eos>, 0 for padding), and labels set at masked positions (-100 elsewhere).
+ """
274
+ labels = inputs["input_ids"].detach().clone()
275
+ labels[labels != tokenizer.mask_token_id] = -100 # Set labels for unmasked tokens to -100
276
+
277
+ # Iterate over each sequence and its corresponding probabilities in the batch
278
+ for idx in range(inputs["input_ids"].size(0)): # Assuming the first dimension is batch size
279
+ input_ids = inputs["input_ids"][idx]
280
+ prob = probabilities[idx]
281
+
282
+ cls_token_index = (input_ids == 0).nonzero(as_tuple=False)[0].item()
283
+ eos_token_index = (input_ids == 2).nonzero(as_tuple=False)[0].item()
284
+ seq_length = eos_token_index - (cls_token_index+1)
285
+
286
+ assert prob.shape[0] == input_ids.shape[0]
287
+
288
+ # Normalize probabilities using softmax
289
+ prob = softmax(prob, dim=0).cpu().numpy() # move to CPU for numpy
290
+ assert 1 - sum(prob) < 1e-6
291
+
292
+ # Calculate the number of tokens to mask
293
+ num_tokens_to_mask = int(mask_percentage * seq_length)
294
+
295
+ # Choose indices to mask based on the probability distribution
296
+ mask_indices = np.random.choice(input_ids.shape[0], size=num_tokens_to_mask, replace=False, p=prob)
297
+ attention_mask_1_indices = np.arange(0, eos_token_index+1, 1)
298
+
299
+ # Mask the selected indices and set the corresponding labels
300
+ labels[idx, mask_indices] = input_ids[mask_indices].detach().clone()
301
+ input_ids[mask_indices] = tokenizer.mask_token_id
302
+
303
+ inputs["attention_mask"][idx] = torch.zeros_like(input_ids)
304
+ inputs["attention_mask"][idx][attention_mask_1_indices] = 1 # attend to all tokens through <eos>; padding positions remain 0
305
+
306
+ # Update the input_ids in the inputs dictionary
307
+ inputs["input_ids"][idx] = input_ids
308
+
309
+ inputs["labels"] = labels
310
+ return inputs
311
+
312
+
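For intuition about the variable masking rate, here is a small sketch of the scheduler interface defined above (assuming the `fuson_plm` package is importable). It simulates one epoch of 100 batches with the cosine schedule and the default rates from `config.py`:

```python
from fuson_plm.training.utils import get_mask_rate_scheduler

scheduler = get_mask_rate_scheduler(scheduler_type="cosine",
                                    min_masking_rate=0.15,
                                    max_masking_rate=0.40,
                                    total_batches=100,
                                    total_steps=20)  # total_steps is only used by the "stepwise" scheduler

rates = []
for _ in range(100):  # one scheduler step per training batch, as in train.py
    scheduler.step()
    rates.append(scheduler.get_masking_rate())

print(rates[0], rates[len(rates) // 2], rates[-1])  # ramps from 0.15 toward 0.40 across the epoch
# train.py calls scheduler.reset() at the start of each epoch so the ramp repeats.
```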