uploaded training code and model weights
- fuson_plm/training/README.md +60 -0
- fuson_plm/training/__init__.py +0 -0
- fuson_plm/training/__pycache__/__init__.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/config.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/model.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/plot.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/train.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/utils.cpython-310.pyc +0 -0
- fuson_plm/training/config.py +38 -0
- fuson_plm/training/demo.py +46 -0
- fuson_plm/training/model.py +119 -0
- fuson_plm/training/test_esm2.py +122 -0
- fuson_plm/training/train.py +388 -0
- fuson_plm/training/utils.py +312 -0
fuson_plm/training/README.md
ADDED
@@ -0,0 +1,60 @@
## Training Script

This folder holds code for training the model (`train.py`), defining the model architecture (`model.py`), and defining utility functions, including masking rate schedulers and dataloaders (`utils.py`). There is also a script for running ESM-2 on the test data (`test_esm2.py`).

The weights and other necessary files for loading FusOn-pLM are stored in `checkpoints/best/ckpt`. Results on the test set are stored in `checkpoints/best/test_results.csv`.
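As a quick reference, embeddings can be extracted from a saved checkpoint the same way `demo.py` does it. The sketch below is a minimal example, assuming the files in `checkpoints/best/ckpt` load directly with Hugging Face's `AutoModel`/`AutoTokenizer`; the example sequence is arbitrary.

```python
import torch
from transformers import AutoModel, AutoTokenizer

ckpt = "checkpoints/best/ckpt"  # assumed location of the released weights (see above)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModel.from_pretrained(ckpt)
model.eval()

sequence = "MVSSDRPVSLEDEVSHSMKEMIGG"  # any fusion oncoprotein sequence
inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=2000)
with torch.no_grad():
    out = model(**inputs)

# Per-residue embeddings: drop the BOS and EOS token positions
embeddings = out.last_hidden_state.squeeze(0)[1:-1]
print(embeddings.shape)  # (len(sequence), hidden_dim)
```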
### Usage

#### Configs

The `config.py` script holds configurations for **training** and **plotting**.

```python
# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

### Masking parameters - must use either variable or fixed masking rate
# var masking rate (choice 1)
VAR_MASK_RATE = True # if True, use a variable masking rate schedule
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine" # specify the type of scheduler to use. options are: "cosine","loglinear","stepwise"
# fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15 # if VAR_MASK_RATE = False, code will use fixed masking rate

# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True # Set to False if you want to finetune from a checkpoint
PATH_TO_STARTING_CKPT = '' # only set the path if FINETUNE_FROM_SCRATCH = False

# File paths - do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'

# WandB parameters
# Fill these in with your own WandB account info
WANDB_PROJECT = ''
WANDB_ENTITY = ''
WANDB_API_KEY = ''

# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
```

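For intuition on the variable masking rate options, the `"cosine"` scheduler implemented in `utils.py` ramps the rate from `MASK_LOW` to `MASK_HIGH` over the training batches of each epoch. Below is a standalone sketch of that curve; the batch count is illustrative (`train.py` passes `len(train_loader)`).

```python
import numpy as np

mask_low, mask_high = 0.15, 0.40
total_batches = 100  # illustrative; train.py passes len(train_loader)

for step in range(0, total_batches + 1, 25):
    progress = step / total_batches
    # cosine ramp used by CosineIncreaseMaskingRateScheduler in utils.py
    rate = mask_low + (mask_high - mask_low) * 0.5 * (1 - np.cos(np.pi * progress))
    print(f"batch {step:3d}: masking rate = {rate:.3f}")
```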
#### Training

The `train.py` script trains a fusion-aware ESM model according to the settings specified in `config.py`.

To run, enter in terminal:
```bash
python train.py
```
or, to run the (long) training process in the background:
```bash
nohup python train.py > train.out 2> train.err &
```
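To resume an interrupted run, `train.py` expects the config to point at one of the per-epoch checkpoints it saves (`checkpoints/<run_name>/checkpoint_epoch_<N>`); the run name below is hypothetical. The remaining settings must match the original run, since `train.py` asserts that the settings tags are identical before resuming.

```python
# config.py (excerpt) -- resuming from an existing run
FINETUNE_FROM_SCRATCH = False
PATH_TO_STARTING_CKPT = 'checkpoints/<run_name>/checkpoint_epoch_5'  # hypothetical path
```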
fuson_plm/training/__init__.py
ADDED
File without changes
fuson_plm/training/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes).

fuson_plm/training/__pycache__/config.cpython-310.pyc
ADDED
Binary file (898 Bytes).

fuson_plm/training/__pycache__/model.cpython-310.pyc
ADDED
Binary file (4.88 kB).

fuson_plm/training/__pycache__/plot.cpython-310.pyc
ADDED
Binary file (4.02 kB).

fuson_plm/training/__pycache__/train.cpython-310.pyc
ADDED
Binary file (12.2 kB).

fuson_plm/training/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (11.3 kB).
fuson_plm/training/config.py
ADDED
@@ -0,0 +1,38 @@
###### TRAINING
# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

### Masking parameters - must use either variable or fixed masking rate
# var masking rate (choice 1)
VAR_MASK_RATE = True # if True, use a variable masking rate schedule
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine" # specify the type of scheduler to use. options are: "cosine","loglinear","stepwise"
# fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15 # if VAR_MASK_RATE = False, code will use fixed masking rate

# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True # Set to False if you want to finetune from a checkpoint
PATH_TO_STARTING_CKPT = '' # only set the path if FINETUNE_FROM_SCRATCH = False

# File paths - do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'

# WandB parameters
# Fill these in with your own WandB account info
WANDB_PROJECT = ''
WANDB_ENTITY = ''
WANDB_API_KEY = ''

# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
fuson_plm/training/demo.py
ADDED
@@ -0,0 +1,46 @@
from fuson_plm.training.model import FusOnpLM
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import logging
import torch
import os

os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = 'checkpoints/old_splits_snp_2000_ft_11layers_Q_b8_lr5e-05_mask0.15-08-12-2024-12:42:48/checkpoint_epoch_1.pth'
model = AutoModel.from_pretrained(model_name)  # initialize model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
model.to(device)

# Example fusion oncoprotein sequence: MLLT10:PICALM, associated with Acute Myeloid Leukemia (LAML)
# Amino acids 1-80 are derived from the head gene, MLLT10
# Amino acids 81-119 are derived from the tail gene, PICALM
sequence = "MVSSDRPVSLEDEVSHSMKEMIGGCCVCSDERGWAENPLVYCDGHGCSVAVHQACYGIVQVPTGPWFCRKCESQERAARVPPQMGSVPVMTQPTLIYSQPVMRPPNPFGPVSGAQIQFM"

# Tokenize the input sequence
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=2000)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get the embeddings
with torch.no_grad():
    outputs = model(**inputs)
    # The embeddings are in the last_hidden_state tensor
    embeddings = outputs.last_hidden_state
    # remove extra dimension
    embeddings = embeddings.squeeze(0)
    # remove BOS and EOS tokens
    embeddings = embeddings[1:-1, :]

# Convert embeddings to numpy array (if needed)
embeddings = embeddings.cpu().numpy()

print("Sequence length: ", len(sequence))
print("Per-residue embeddings shape:", embeddings.shape)
fuson_plm/training/model.py
ADDED
@@ -0,0 +1,119 @@
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import torch
import os

class FusOnTokenizer:
    """
    FusOnTokenizer class: a wrapper around AutoTokenizer
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D'):
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying tokenizer.
        This allows calls like .tokenize(), .train(), and .eval() to be forwarded to the tokenizer.
        """
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnTokenizer object callable, delegating to the tokenizer's __call__ method.
        """
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_directory):
        self.tokenizer.save_pretrained(save_directory)

    def load_tokenizer(self, load_directory):
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)

class FusOnpLM:
    """
    FusOn-pLM class: a wrapper around AutoModelForMaskedLM
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D', ckpt_path=None, mlm_head=False):
        if ckpt_path is not None:
            self.load_model(ckpt_path, mlm_head)
        else:
            # Load the pre-trained model and tokenizer
            self.model = AutoModelForMaskedLM.from_pretrained(pretrained_path)
            self.tokenizer = FusOnTokenizer(pretrained_path)

        self.n_layers = self.count_encoder_layers()

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying model.
        This allows calls like .to(), .train(), and .eval() to be forwarded to the model.
        """
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnpLM object callable, delegating to the model's __call__ method.
        """
        return self.model(*args, **kwargs)

    def freeze_model(self):
        """
        Freezes all parameters in the model
        """
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
        """
        Unfreezes specific parts of the final n layers in the model's encoder.

        Args:
            n_unfrozen_layers (int): Number of final layers to unfreeze.
            unfreeze_query (bool): Whether to unfreeze the query projections. Default is True.
            unfreeze_key (bool): Whether to unfreeze the key projections. Default is True.
            unfreeze_value (bool): Whether to unfreeze the value projections. Default is True.
        """
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if (self.n_layers - i) <= n_unfrozen_layers:  # Only the last n layers
                if unfreeze_query:
                    self._unfreeze_parameters(layer.attention.self.query)
                if unfreeze_key:
                    self._unfreeze_parameters(layer.attention.self.key)
                if unfreeze_value:
                    self._unfreeze_parameters(layer.attention.self.value)

    def _unfreeze_parameters(self, module):
        """
        Helper method to unfreeze parameters in a given module.

        Args:
            module (nn.Module): The module whose parameters are to be unfrozen.
        """
        for param in module.parameters():
            param.requires_grad = True

    def count_encoder_layers(self):
        """
        Count the number of encoder layers in the model.
        """
        return len(self.model.esm.encoder.layer)

    def save_model(self, save_directory, optimizer=None):
        # Save the model and tokenizer
        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)

        # If an optimizer is provided, save its state dict
        if optimizer is not None:
            optimizer_path = os.path.join(save_directory, "optimizer.pt")
            torch.save(optimizer.state_dict(), optimizer_path)

    def load_model(self, load_directory, mlm_head):
        # Load a checkpoint of the model either with or without an MLM head
        if mlm_head:
            self.model = AutoModelForMaskedLM.from_pretrained(load_directory)
        else:
            # Load the model and tokenizer from a directory
            self.model = AutoModel.from_pretrained(load_directory)
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
fuson_plm/training/test_esm2.py
ADDED
@@ -0,0 +1,122 @@
### Run ESM2 on the validation and test set. Get val and test losses.
import os
import fuson_plm.training.config as config
# Set the WANDB_API_KEY environment variable
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import tqdm
import numpy as np
import pandas as pd
import logging
from transformers import AutoModelForMaskedLM, AutoTokenizer
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders
from fuson_plm.training.train import test

def load_esm2_maskedlm(esm_type, device=None):
    """
    Loads a specified version of ESM-2 (e.g. esm2_t33_650M_UR50D)
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    model = AutoModelForMaskedLM.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")

    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device


def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Same as the test method in train.py, just for running the val set
    """
    model.to(device)
    model.eval()
    total_val_loss = 0
    total_weighted_val_loss = 0
    total_val_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        # Loop over val data
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

                # Forward pass
                outputs = model(**masked_inputs)
                val_loss = outputs.loss

                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations
                total_val_loss += val_loss.item()
                total_weighted_val_loss += val_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_val_masked_tokens += num_masked_tokens

    # Compute and log avg. loss and perplexity
    n_val_batches = len(val_loader)
    avg_val_loss = total_val_loss / n_val_batches
    avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens
    val_perplexity = np.exp(avg_weighted_val_loss)

    log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

    # Save to dataframe for plotting
    val_stats_df = pd.DataFrame(data={
        "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
        "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
        "val_perplexity": [val_perplexity]
    })
    val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv", index=False)  # overwrite old file no matter what; should only be one val eval

def main():
    # Load the ESM-2 model
    model, tokenizer, device = load_esm2_maskedlm("esm2_t33_650M_UR50D")

    checkpoint_dir = f"checkpoints/esm2_t33_650M_UR50D_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open_logfile(f"{checkpoint_dir}/evaluate_val_test_esm.txt"):
        # Print configurations
        print_configpy(config)

        ##### Validation
        val_loader = get_dataloader(config.VAL_PATH, tokenizer,
                                    probability_type=config.PROBABILITY_TYPE,
                                    batch_size=config.BATCH_SIZE,
                                    max_length=config.MAX_LENGTH, shuffle=False)

        # Validation
        val(model, tokenizer, val_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)


        ##### Test
        # Create dataloader
        test_loader = get_dataloader(config.TEST_PATH,
                                     tokenizer,
                                     probability_type=config.PROBABILITY_TYPE,
                                     batch_size=config.BATCH_SIZE,
                                     max_length=config.MAX_LENGTH, shuffle=False)

        # Test the model
        test(model, tokenizer, test_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)

if __name__ == "__main__":
    main()
fuson_plm/training/train.py
ADDED
@@ -0,0 +1,388 @@
'''
This is a training script for finetuning ESM.
I am going to freeze the parameters in the head and unfreeze the last N layers in the model.
'''

import os
import fuson_plm.training.config as config

# Set the WANDB_API_KEY environment variable
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import numpy as np
import pandas as pd
import tqdm
from datetime import datetime
import wandb
import pytz
import sys

from transformers import AdamW

from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
from fuson_plm.training.model import FusOnpLM
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
from fuson_plm.training.plot import make_train_val_test_bd_plot

def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    # Log the model's initial state
    n_layers = model.count_encoder_layers()
    total_params = sum(p.numel() for p in model.parameters())
    total_head_params = sum(p.numel() for p in model.lm_head.parameters())
    log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelforMaskedLM model: {total_params}')
    log_update(f'\tTotal parameters in the MLM Head ONLY: {total_head_params}')

    # Freeze the model to start
    model.freeze_model()
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_update(f'Froze all {model.n_layers} model layers')
    log_update(f'\tTrainable params: {n_trainable_params}')

    # Unfreeze the last n layers
    model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    trainable_params = '\n\t\t'.join([name for name, param in model.named_parameters() if param.requires_grad])
    num_trainable_params_lm_head = sum(p.numel() for p in model.lm_head.parameters() if p.requires_grad)
    num_trainable_params_esm = sum(p.numel() for p in model.esm.parameters() if p.requires_grad)
    log_update(f'Unfroze final {n_unfrozen_layers} layers')
    log_update(f'\tTrainable params: {n_trainable_params}\n\t\t{trainable_params}')
    log_update(f"\tTrainable parameters in the lm_head: {num_trainable_params_lm_head}")
    log_update(f"\tTrainable params in the ESM part: {num_trainable_params_esm}")

def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Train the model
    """
    # Loop over epochs
    log_update("\n")

    for epoch in range(start_epoch, start_epoch+n_epochs):
        if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset()  # resetting because we want to ramp it up again every epoch

        model.train()
        total_train_loss = 0
        total_weighted_train_loss = 0
        total_train_masked_tokens = 0

        log_update(f"Epoch {epoch}")
        # Loop over train data with progress bar
        with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
            for batch_idx, (inputs, prob) in pbar:
                # Take a step with the mask rate scheduler, if there is one.
                masking_rate = mask_percentage
                if mask_rate_scheduler is not None:
                    mask_rate_scheduler.step()
                    masking_rate = mask_rate_scheduler.get_masking_rate()
                    log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")

                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)

                # Forward pass and update
                optimizer.zero_grad()
                outputs = model(**masked_inputs)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations and wandb log
                total_train_loss += loss.item()
                total_weighted_train_loss += loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_train_masked_tokens += num_masked_tokens
                wandb.log({"batch_loss": loss.item()})

        # Save a checkpoint at the end of each epoch
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
        model.save_model(checkpoint_path, optimizer=optimizer)
        log_update(f'\nSaved checkpoint to {checkpoint_path}')

        # Calculate and log average training loss on wandb
        n_train_batches = len(train_loader)
        avg_train_loss = total_train_loss / n_train_batches
        avg_weighted_train_loss = total_weighted_train_loss / total_train_masked_tokens
        train_perplexity = np.exp(avg_weighted_train_loss)
        wandb.log({"epoch": epoch,
                   "total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
                   "avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
                   "train_perplexity": train_perplexity})

        # Track curve stats for easy re-plotting of training curves later
        train_stats_df = pd.DataFrame(data={
            "epoch": [epoch],
            "total_train_loss": [total_train_loss], "weighted_train_loss": [total_weighted_train_loss],
            "avg_train_loss": [avg_train_loss], "avg_weighted_train_loss": [avg_weighted_train_loss],
            "train_perplexity": [train_perplexity]
        })
        if os.path.exists(f"{checkpoint_dir}/train_curve.csv"):  # add to file if necessary
            train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv", index=False, header=False, mode='a')
        else:  # make new file if necessary
            train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv", index=False)

        # Validation loop
        model.eval()
        total_val_loss = 0
        total_weighted_val_loss = 0
        total_val_masked_tokens = 0

        with torch.no_grad():  # No gradients needed
            # Loop over val data with progress bar
            with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
                for batch_idx, (inputs, prob) in vbar:
                    # Move tensors
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                    prob = prob.to(device)

                    # Mask based on probability vectors
                    ## FIXED 15% masking for the validation set
                    masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)

                    # Forward pass
                    outputs = model(**masked_inputs)
                    val_loss = outputs.loss

                    # Number of masked tokens
                    num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                    # Loss calculations
                    total_val_loss += val_loss.item()
                    total_weighted_val_loss += val_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                    total_val_masked_tokens += num_masked_tokens

        # Calculate and log avg. loss and perplexity (wandb and locally)
        n_val_batches = len(val_loader)
        avg_val_loss = total_val_loss / n_val_batches  # avg per batch
        avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens  # avg per masked token
        val_perplexity = np.exp(avg_weighted_val_loss)
        wandb.log({"epoch": epoch,
                   "total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
                   "avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
                   "val_perplexity": val_perplexity})

        # Track curve stats for easy re-plotting of training curves later
        val_stats_df = pd.DataFrame(data={
            "epoch": [epoch],
            "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
            "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
            "val_perplexity": [val_perplexity]
        })
        if os.path.exists(f"{checkpoint_dir}/val_curve.csv"):  # add to file if necessary
            val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv", index=False, header=False, mode='a')
        else:  # make new file if necessary
            val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv", index=False)

        log_update(f"Epoch: {epoch}")
        log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
        log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate the model on the test set and save results to test_results.csv
    """
    model.to(device)
    model.eval()
    total_test_loss = 0
    total_weighted_test_loss = 0
    total_test_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        # Loop over test data
        with tqdm.tqdm(enumerate(test_loader), total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                ### FIXED 15% masking for the testing set
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)

                # Forward pass
                outputs = model(**masked_inputs)
                test_loss = outputs.loss

                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations
                total_test_loss += test_loss.item()
                total_weighted_test_loss += test_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_test_masked_tokens += num_masked_tokens

    # Compute and log avg. loss and perplexity
    n_test_batches = len(test_loader)
    avg_test_loss = total_test_loss / n_test_batches
    avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens
    test_perplexity = np.exp(avg_weighted_test_loss)

    log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f}, Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")

    # Save to dataframe for plotting
    test_stats_df = pd.DataFrame(data={
        "total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
        "avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
        "test_perplexity": [test_perplexity]
    })
    test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv", index=False)  # overwrite old file no matter what; should only be one test eval

def check_env_variables():
    log_update("\nChecking on environment variables...")
    log_update(f"\tWANDB_API_KEY: {os.environ.get('WANDB_API_KEY')}")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")

def intialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
    """
    Initializes the model, either from ESM-2-650M if finetuning from scratch, or from a prior checkpoint if not finetuning from scratch.
    Also prepares the model (freezes/unfreezes the proper layers) and initializes the optimizer.

    Args:
        finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt
        path_to_starting_ckpt (str): path to starting ckpt for finetuning (optional)
    """
    if not(finetune_from_scratch) and not(os.path.exists(path_to_starting_ckpt)):
        raise Exception(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")

    # if finetuning from scratch, initialize from scratch
    if finetune_from_scratch:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
        model = FusOnpLM()  # because of __getattr__, we can use FusOnpLM() to get the model. It also contains the tokenizer.
        model.to(device)
        prepare_model(model, n_unfrozen_layers,
                      unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

        # Set the optimizer here, change it if we are finetuning from an old checkpoint
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

        return model, optimizer

    # if not, initialize from starting ckpt
    else:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
        model = FusOnpLM(ckpt_path=path_to_starting_ckpt, mlm_head=True)
        model.to(device)
        prepare_model(model, n_unfrozen_layers,
                      unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

        log_update(f"Loading optimizer state_dict from previous checkpoint")
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))

        return model, optimizer

def main():
    # Set probability type to uniform; only option
    config.PROBABILITY_TYPE = "uniform"

    # Set run name (WANDB_NAME)
    kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
    timestamp = get_local_time()
    # make a mask tag _mask{config.MASK_PERCENTAGE}
    mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE:  # if variable masking rate, change the tag to reflect this
        mask_tag = f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"

    # Define the train settings string and wandb name from this
    TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
    WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'

    # Create directory for model checkpoints
    checkpoint_dir = f'checkpoints/{WANDB_NAME}'
    start_epoch = 1

    # Determine if we're adding to an old log file or opening a new one
    logmode = 'w'

    # If we're finetuning from a checkpoint, save to the same folder instead, and keep track of which epoch to start on
    # Also, load the optimizer from here
    if not(config.FINETUNE_FROM_SCRATCH):
        logmode = 'a'
        path_to_starting_ckpt = config.PATH_TO_STARTING_CKPT
        checkpoint_dir = path_to_starting_ckpt[0:path_to_starting_ckpt.rindex('/')]
        START_MODEL_TRAIN_SETTINGS_STRING = checkpoint_dir[checkpoint_dir.index('checkpoints/')+len('checkpoints/'):checkpoint_dir.index('-')]
        start_epoch = int(path_to_starting_ckpt.split('/checkpoint_epoch_')[1])+1

    os.makedirs(f'checkpoints', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Open log file
    LOG_PATH = f'{checkpoint_dir}/training_log.txt'
    ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
    with open_logfile(LOG_PATH, mode=logmode), open_errfile(ERR_PATH, mode=logmode):
        if not(config.FINETUNE_FROM_SCRATCH):
            log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
            log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")
            # ONLY proceed with training if we're using the same settings, otherwise we are not finetuning the model we think we are!
            assert START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING

        # Print configurations
        print_configpy(config)

        # Verify that the environment variables are set correctly
        check_env_variables()

        # Check CUDA availability and set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"\nUsing device: {device}")

        # Init wandb
        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME, config={
            "batch_size": config.BATCH_SIZE,
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
        })

        # Initialize model and prepare it (freeze/unfreeze proper layers). Initialize optimizer as well. Details depend on whether we are finetuning from scratch.
        model, optimizer = intialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
                                                         path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
                                                         learning_rate=config.LEARNING_RATE,
                                                         n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
                                                         unfreeze_query=config.UNFREEZE_QUERY,
                                                         unfreeze_key=config.UNFREEZE_KEY,
                                                         unfreeze_value=config.UNFREEZE_VALUE)

        # Initialize the tokenizer (independent of starting model for finetuning)
        tokenizer = model.tokenizer

        # Create DataLoader instances and perform sanity checks on them
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True)  ## FOR DEBUGGING ONLY, change shuffle to False. Otherwise, True!!
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)

        # If we're continuing to finetune an old ckpt, store the old batch diversity plot before we overwrite it
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)

        # Set up a masking rate scheduler, if one is needed
        mask_rate_scheduler = None
        if config.VAR_MASK_RATE:
            mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
                                                          min_masking_rate=config.MASK_LOW,
                                                          max_masking_rate=config.MASK_HIGH,
                                                          total_batches=len(train_loader),
                                                          total_steps=config.MASK_STEPS)

        # Train the model
        train(model, tokenizer, optimizer, train_loader, val_loader,
              n_epochs=config.EPOCHS,
              start_epoch=start_epoch,
              device=device,
              mask_rate_scheduler=mask_rate_scheduler,
              mask_percentage=config.MASK_PERCENTAGE,
              checkpoint_dir=checkpoint_dir)

        # Test the model
        test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)

if __name__ == "__main__":
    main()
fuson_plm/training/utils.py
ADDED
@@ -0,0 +1,312 @@
import numpy as np
import pandas as pd
import torch
import os
from torch.nn.functional import softmax
from fuson_plm.utils.logging import log_update
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from abc import ABC, abstractmethod

#----------------------------------------------------------------------------------------------------------------------------------------------------
#### Masking Rate Scheduler base class and sub classes
# abstract base class
class MaskingRateScheduler(ABC):
    def __init__(self, total_steps, min_masking_rate, max_masking_rate, last_step=-1):
        self.total_steps = total_steps
        self.min_masking_rate = min_masking_rate
        self.max_masking_rate = max_masking_rate
        self.current_step = last_step

    def step(self):
        self.current_step += 1

    def reset(self):
        """Reset the scheduler to its initial state."""
        self.current_step = -1

    def get_masking_rate(self):
        progress = self.current_step / self.total_steps
        return self.compute_masking_rate(progress)

    @abstractmethod
    def compute_masking_rate(self, progress):
        """To be implemented by subclasses for specific increase functions."""
        raise NotImplementedError("Subclasses must implement this method.")


class CosineIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def compute_masking_rate(self, progress):
        # Use a cosine increase function
        cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
        return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * cosine_increase

class LogLinearIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def compute_masking_rate(self, progress):
        # Avoid log(0) by clamping progress to a minimum of a small positive number
        progress = max(progress, 1e-10)
        log_linear_increase = np.log1p(progress) / np.log1p(1)  # Normalizing to keep range in [0, 1]
        return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * log_linear_increase

class StepwiseIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def __init__(self, total_batches, min_masking_rate, max_masking_rate, num_steps):
        super().__init__(total_steps=total_batches, min_masking_rate=min_masking_rate, max_masking_rate=max_masking_rate)
        self.num_steps = num_steps
        self.batch_interval = total_batches // (num_steps)  # Adjusting to ensure max rate is included
        self.rate_increment = (max_masking_rate - min_masking_rate) / (num_steps - 1)  # Include end rate in the steps

    def compute_masking_rate(self, progress):
        # Determine the current step based on the number of completed batches
        current_step = int(self.current_step / self.batch_interval)
        # Cap the step number to `num_steps - 1` to include the max rate at the final step
        current_step = min(current_step, self.num_steps - 1)
        # Calculate the masking rate for the current step
        masking_rate = self.min_masking_rate + current_step * self.rate_increment
        return masking_rate

def get_mask_rate_scheduler(scheduler_type="cosine", min_masking_rate=0.15, max_masking_rate=0.40, total_batches=100, total_steps=20):
    """
    Initialize the mask rate scheduler and return it
    """
    if scheduler_type=="cosine":
        return CosineIncreaseMaskingRateScheduler(total_steps=total_batches,
                                                  min_masking_rate=min_masking_rate,
                                                  max_masking_rate=max_masking_rate)
    elif scheduler_type=="loglinear":
        return LogLinearIncreaseMaskingRateScheduler(total_steps=total_batches,
                                                     min_masking_rate=min_masking_rate,
                                                     max_masking_rate=max_masking_rate)
    elif scheduler_type=="stepwise":
        return StepwiseIncreaseMaskingRateScheduler(total_batches=total_batches,
                                                    num_steps=total_steps,
                                                    min_masking_rate=min_masking_rate,
                                                    max_masking_rate=max_masking_rate)
    else:
        raise Exception("Must specify valid scheduler_type: cosine, loglinear, stepwise")

# Adjusted Dataloader for the sequences and probability vectors
class ProteinDataset(Dataset):
    def __init__(self, data_path, tokenizer, probability_type, max_length=512):
        self.dataframe = pd.read_csv(data_path)
        self.tokenizer = tokenizer
        self.probability_type = probability_type
        self.max_length = max_length

        self.set_probabilities()

    def __len__(self):
        return len(self.dataframe)

    def set_probabilities(self):
        if self.probability_type=="snp":
            self.dataframe = self.dataframe.rename(columns={'snp_probabilities':'probabilities'})
        if self.probability_type=="uniform":
            self.dataframe['probabilities'] = self.dataframe['sequence'].apply(len).apply(lambda x: ('1,'*x)[0:-1])

        # make probabilities into numbers if they aren't already
        if type(self.dataframe['probabilities'][0]) == str:
            self.dataframe['probabilities'] = self.dataframe['probabilities'].apply(
                lambda x: np.array([float(i) for i in x.split(',')])
            )

    def get_padded_probabilities(self, idx):
        '''
        Pads probabilities to max_length if they're too short; truncates them if they're too long
        '''
        no_mask_value = int(-1e9)  # will be used to make sure CLS and PAD aren't masked

        # add a no-mask slot for <CLS>
        prob = np.concatenate((
            np.array([no_mask_value]),
            self.dataframe.iloc[idx]['probabilities']
        ))

        # Pad with no_mask_value for everything after the probability vector ends
        if len(prob) < self.max_length:
            return np.pad(
                prob,
                (0, self.max_length - len(prob)),
                'constant', constant_values=(0, no_mask_value))

        # If it's too long, we need to truncate, but we also need to change the last token to an <EOS>.
        prob = prob[0:self.max_length-1]
        prob = np.concatenate((
            prob,
            np.array([no_mask_value]),
        ))
        return prob

    def __getitem__(self, idx):
        sequence = self.dataframe.iloc[idx]['sequence']
        probability = self.get_padded_probabilities(idx)  # extract them
        inputs = self.tokenizer(sequence, return_tensors="pt", padding='max_length', truncation=True, max_length=self.max_length)  # does this have to be 512?
        inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()}  # Remove batch dimension
        return inputs, probability

def get_dataloader(data_path, tokenizer, probability_type='snp', max_length=512, batch_size=8, shuffle=True):
    """
    Creates a DataLoader for the dataset.
    Args:
        data_path (str): Path to the CSV file (train, val, or test).
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.
        tokenizer (Tokenizer): tokenizer object for data tokenization
    Returns:
        DataLoader: DataLoader object.
    """
    dataset = ProteinDataset(data_path, tokenizer, probability_type, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def check_dataloaders(train_loader, val_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update(f'\nBuilt train, validation, and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Validation DataLoader: {len(val_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Test DataLoader: {len(test_loader.dataset)}")
    dataloader_overlaps = check_dataloader_overlap(train_loader, val_loader, test_loader)
    if len(dataloader_overlaps)==0: log_update("\tDataloaders are clean (no overlaps)")
    else: log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")

    # write length ranges to a text file
    if not(os.path.exists(f'{checkpoint_dir}/batch_diversity')):
        os.mkdir(f'{checkpoint_dir}/batch_diversity')

    max_length_violators = []
    for name, dataloader in {'train':train_loader, 'val':val_loader, 'test':test_loader}.items():
        max_length_followed, length_ranges = check_max_length_and_length_diversity(dataloader, max_length)
        if max_length_followed == False:
            max_length_violators.append(name)

        with open(f'{checkpoint_dir}/batch_diversity/{name}_batch_length_ranges.txt','w') as f:
            for tup in length_ranges:
                f.write(f'{tup[0]}\t{tup[1]}\n')

    if len(max_length_violators)==0: log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else: log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")

def check_dataloader_overlap(train_loader, val_loader, test_loader):
    """
    Check the data that's about to go into the model. Make sure there is no overlap between train, test, and val

    Returns:
        list: descriptions of any train/val/test sequence overlaps (empty if the data is clean)
    """
    train_protein_seqs = set(train_loader.dataset.dataframe['sequence'].unique())
    val_protein_seqs = set(val_loader.dataset.dataframe['sequence'].unique())
    test_protein_seqs = set(test_loader.dataset.dataframe['sequence'].unique())

    tr_va = len(train_protein_seqs.intersection(val_protein_seqs))
    tr_te = len(train_protein_seqs.intersection(test_protein_seqs))
    va_te = len(val_protein_seqs.intersection(test_protein_seqs))

    overlaps = []
    if tr_va==tr_te==va_te==0:
        return overlaps  # data is clean
    else:
        if tr_va > 0: overlaps.append(f"Train-Val Overlap={tr_va}")
        if tr_te > 0: overlaps.append(f"Train-Test Overlap={tr_te}")
        if va_te > 0: overlaps.append(f"Val-Test Overlap={va_te}")
        return overlaps

def check_max_length_and_length_diversity(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length,
    and return the sequence length ranges within each batch.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
        list: A list of tuples representing the min and max sequence lengths in each batch.
    """
    length_ranges = []
    all_within_max_length = True

    for batch_idx, (inputs, _) in enumerate(dataloader):
        input_ids = inputs['input_ids']

        # Calculate the actual lengths of sequences in this batch
        actual_lengths = (input_ids != dataloader.dataset.tokenizer.pad_token_id).sum(dim=1)
        min_length = actual_lengths.min().item()
        max_length_in_batch = actual_lengths.max().item()

        # Check for max length violation
        if max_length_in_batch > max_length:
            #print(f"Error: Sequence exceeds max_length of {max_length} at batch {batch_idx + 1}. Max length found: {max_length_in_batch}")
            all_within_max_length = False

        # Store the length range for this batch
        length_ranges.append((min_length, max_length_in_batch))

    #print(f"All sequences in the DataLoader conform to the max_length of {max_length}.") if all_within_max_length else None
    #print(f"Sequence length ranges per batch: {length_ranges}")

    return all_within_max_length, length_ranges


def check_max_length_in_dataloader(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
    """
    for batch_idx, (inputs, _) in enumerate(dataloader):
        input_ids = inputs['input_ids']

        # Check if any sequence length exceeds max_length
        if input_ids.size(1) > max_length:
            return False

    return True


def batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer: AutoTokenizer, mask_percentage=0.15):
    """
    Mask tokens in each sequence of the batch by sampling masked positions from the provided per-residue probability vectors.
    """
    #print('the batch sample method was called!')
    labels = inputs["input_ids"].detach().clone()
    labels[labels != tokenizer.mask_token_id] = -100  # Set labels for unmasked tokens to -100

    # Iterate over each sequence and its corresponding probabilities in the batch
    for idx in range(inputs["input_ids"].size(0)):  # Assuming the first dimension is batch size
        input_ids = inputs["input_ids"][idx]
        prob = probabilities[idx]

        cls_token_index = (input_ids == 0).nonzero(as_tuple=False)[0].item()
        eos_token_index = (input_ids == 2).nonzero(as_tuple=False)[0].item()
        seq_length = eos_token_index - (cls_token_index+1)

        assert prob.shape[0] == input_ids.shape[0]

        # Normalize probabilities using softmax
        prob = softmax(prob, dim=0).cpu().numpy()  # move to CPU for numpy
        assert 1 - sum(prob) < 1e-6

        # Calculate the number of tokens to mask
        num_tokens_to_mask = int(mask_percentage * seq_length)

        # Choose indices to mask based on the probability distribution
        mask_indices = np.random.choice(input_ids.shape[0], size=num_tokens_to_mask, replace=False, p=prob)
        attention_mask_1_indices = np.arange(0, eos_token_index+1, 1)

        # Mask the selected indices and set the corresponding labels
        labels[idx, mask_indices] = input_ids[mask_indices].detach().clone()
        input_ids[mask_indices] = tokenizer.mask_token_id

        inputs["attention_mask"][idx] = torch.zeros_like(input_ids)
        inputs["attention_mask"][idx][attention_mask_1_indices] = 1  # just added this to try and update the attention mask....

        # Update the input_ids in the inputs dictionary
        inputs["input_ids"][idx] = input_ids

    inputs["labels"] = labels
    return inputs