'''
This is a training script for finetuning ESM.
It freezes the parameters in the head and unfreezes the last N layers of the model.
'''

import os
import fuson_plm.training.config as config 

# Set the WANDB_API_KEY environment variable
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import numpy as np
import pandas as pd
import tqdm
from datetime import datetime
import wandb
import pytz
import sys

from transformers import AdamW  # note: deprecated in recent transformers releases; torch.optim.AdamW is the drop-in replacement

from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
from fuson_plm.training.model import FusOnpLM
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
from fuson_plm.training.plot import make_train_val_test_bd_plot

def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    # Log the model's initial state
    n_layers = model.count_encoder_layers()
    total_params = sum(p.numel() for p in model.parameters())
    total_head_params = sum(p.numel() for p in model.lm_head.parameters())
    log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelForMaskedLM model: {total_params}')
    log_update(f'\tTotal parameters in the MLM Head ONLY: {total_head_params}')
    
    # Freeze the model to start
    model.freeze_model()
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_update(f'Froze all {n_layers} model layers')
    log_update(f'\tTrainable params: {n_trainable_params}')
    
    # Unfreeze the last n layers
    model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    trainable_params = '\n\t\t'.join([name for name, param in model.named_parameters() if param.requires_grad])
    num_trainable_params_lm_head = sum(p.numel() for p in model.lm_head.parameters() if p.requires_grad)
    num_trainable_params_esm = sum(p.numel() for p in model.esm.parameters() if p.requires_grad)
    log_update(f'Unfroze final {n_unfrozen_layers} layers')
    log_update(f'\tTrainable params: {n_trainable_params}\n\t\t{trainable_params}')
    log_update(f"\tTrainable parameters in the lm_head: {num_trainable_params_lm_head}")
    log_update(f"\tTrainable params in the ESM part: {num_trainable_params_esm}") 
    
def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Train the model
    """
    # Loop over epochs
    log_update("\n")
   
    for epoch in range(start_epoch, start_epoch+n_epochs):
        if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset() # reset so the masking rate ramps up again every epoch
        
        model.train()
        total_train_loss = 0
        total_weighted_train_loss = 0
        total_train_masked_tokens = 0

        log_update(f"Epoch {epoch}")
        # Loop over train data with progress bar
        with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
            for batch_idx, (inputs, prob) in pbar:
                # Take a step with the mask rate scheduler, if there is one. 
                masking_rate = mask_percentage
                if mask_rate_scheduler is not None:
                    mask_rate_scheduler.step()
                    masking_rate = mask_rate_scheduler.get_masking_rate()
                    log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")
                
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()} 
                prob = prob.to(device) 
                
                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)
                
                # Forward pass and update
                optimizer.zero_grad()
                outputs = model(**masked_inputs)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                
                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item() 
            
                # Loss calculations and wandb log
                total_train_loss += loss.item()
                total_weighted_train_loss += loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_train_masked_tokens += num_masked_tokens
                wandb.log({"batch_loss": loss.item()})

            # Save a checkpoint at the end of each epoch
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
            model.save_model(checkpoint_path, optimizer=optimizer)
            log_update(f'\nSaved checkpoint to {checkpoint_path}')

            # Calculate and log average training loss on wandb
            n_train_batches = len(train_loader)
            avg_train_loss = total_train_loss / n_train_batches
            avg_weighted_train_loss = total_weighted_train_loss / total_train_masked_tokens
            train_perplexity = np.exp(avg_weighted_train_loss)
            wandb.log({"epoch": epoch,
                       "total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
                       "avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
                       "train_perplexity": train_perplexity})
            
            # Track curve stats for easy re-plotting of training curves later
            train_stats_df = pd.DataFrame(data={
                    "epoch": [epoch],
                    "total_train_loss": [total_train_loss], "weighted_train_loss": [total_weighted_train_loss],
                    "avg_train_loss": [avg_train_loss], "avg_weighted_train_loss": [avg_weighted_train_loss],
                    "train_perplexity": [train_perplexity]
                    })
            if os.path.exists(f"{checkpoint_dir}/train_curve.csv"):   # add to file if necessary 
                train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv",index=False,header=False,mode='a')
            else:   # make new file if necessary
                train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv",index=False)

            # Validation loop
            model.eval()
            total_val_loss = 0
            total_weighted_val_loss = 0
            total_val_masked_tokens = 0
            
            with torch.no_grad(): # No gradients needed
                # Loop over val data with progress bar
                with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
                    for batch_idx, (inputs, prob) in vbar:
                        # Move tensors
                        inputs = {k: v.to(device) for k, v in inputs.items()}  
                        prob = prob.to(device) 
                        
                        # Mask based on probability vectors
                        ## FIXED 15% masking for the validation set
                        masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer,mask_percentage=0.15)
                        
                        # Forward pass
                        outputs = model(**masked_inputs)
                        val_loss = outputs.loss
                        
                        # Number of masked tokens
                        num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item() 
                        
                        # Loss calculations
                        total_val_loss += val_loss.item()
                        total_weighted_val_loss += val_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                        total_val_masked_tokens += num_masked_tokens

                # Calculate and log avg. loss and perplexity (wandb and locally)
                n_val_batches = len(val_loader)
                avg_val_loss = total_val_loss / n_val_batches                                # avg per batch
                avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens       # avg per masked token
                val_perplexity = np.exp(avg_weighted_val_loss)
                wandb.log({"epoch": epoch, 
                           "total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
                           "avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
                           "val_perplexity": val_perplexity})
                
                # Track curve stats for easy re-plotting of training curves later
                val_stats_df = pd.DataFrame(data={
                    "epoch": [epoch], 
                    "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
                    "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
                    "val_perplexity": [val_perplexity]
                })
                if os.path.exists(f"{checkpoint_dir}/val_curve.csv"):   # add to file if necessary 
                    val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv",index=False,header=False,mode='a')
                else:   # make new file if necessary
                    val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv",index=False)
                    
                log_update(f"Epoch: {epoch}")
                log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
                log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f},  Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate the model on the test set with a fixed masking rate, then log and save the results.
    """
    model.to(device)
    model.eval()  
    total_test_loss = 0
    total_weighted_test_loss = 0
    total_test_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        # Loop over test data with progress bar
        with tqdm.tqdm(enumerate(test_loader), total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()} 
                prob = prob.to(device) 
                
                # Mask based on probability vectors
                ### FIXED 15% masking for the testing set (mask_percentage defaults to 0.15)
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)
                
                # Forward pass
                outputs = model(**masked_inputs)
                test_loss = outputs.loss
                
                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item() 

                # Loss calculations
                total_test_loss += test_loss.item()
                total_weighted_test_loss += test_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_test_masked_tokens += num_masked_tokens

            # Compute and log avg. loss and perplexity
            n_test_batches = len(test_loader)
            avg_test_loss = total_test_loss / n_test_batches
            avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens
            test_perplexity = np.exp(avg_weighted_test_loss)

            log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f},  Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")
            
            # Save to dataframe for plotting
            test_stats_df = pd.DataFrame(data={
                    "total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
                    "avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
                    "test_perplexity": [test_perplexity]
                })
            test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv",index=False)    # overwrite old file no matter what; should only be one test eval
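
# Note on the loss bookkeeping in train() and test(): the plain average is a mean over
# batches, while the "weighted" average re-weights each batch loss by its masked-token
# count, giving the average cross-entropy per masked token; perplexity is exp() of that
# per-token loss. A minimal sketch of the same arithmetic:
def _sketch_weighted_perplexity(batch_losses, batch_masked_counts):
    # e.g. losses [2.0, 1.0] with counts [10, 30] -> per-token loss 1.25, perplexity ~3.49
    weighted_sum = sum(l * n for l, n in zip(batch_losses, batch_masked_counts))
    per_token_loss = weighted_sum / sum(batch_masked_counts)
    return float(np.exp(per_token_loss))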

def check_env_variables():
    log_update("\nChecking on environment variables...")
    log_update(f"\tWANDB_API_KEY: {os.environ.get('WANDB_API_KEY')}")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")
        
def initialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
    """
    Initializes the model, either from ESM-2-650M if finetuning from scratch, or from a prior checkpoint if not.
    Also prepares the model (freezes/unfreezes the proper layers) and builds the optimizer.
    
    Args:
        finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt
        device (torch.device): device to move the model onto
        path_to_starting_ckpt (str): path to starting ckpt for finetuning (required when finetune_from_scratch is False)
        learning_rate (float): learning rate for the AdamW optimizer
        n_unfrozen_layers (int): number of final encoder layers to unfreeze
        unfreeze_query, unfreeze_key, unfreeze_value (bool): whether to unfreeze the query/key/value projections
    """
    if not(finetune_from_scratch) and (path_to_starting_ckpt is None or not os.path.exists(path_to_starting_ckpt)):
        raise Exception(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")
    
    # if finetuning from scratch, initialize from scratch
    if finetune_from_scratch:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
        model = FusOnpLM()                  # because of __getattr__, we can use FusOnpLM() to get the model. It also contains the tokenizer. 
        model.to(device)
        prepare_model(model, n_unfrozen_layers,  
                unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    
        # Create the optimizer over the trainable parameters only; the checkpoint branch below loads its own optimizer state
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
        
        return model, optimizer

    # if not, initialize from starting ckpt
    else:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
        model = FusOnpLM(ckpt_path = path_to_starting_ckpt, mlm_head=True)
        model.to(device)
        prepare_model(model, n_unfrozen_layers,  
                unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
        
        log_update(f"Loading optimizer state_dict from previous checkpoint")
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))
        
        return model, optimizer
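
# A minimal sketch of a mask-rate scheduler matching the interface train() expects
# (reset() at the start of each epoch, step() once per batch, get_masking_rate()).
# The real schedulers come from fuson_plm.training.utils.get_mask_rate_scheduler;
# the linear ramp below is an illustrative assumption, not the actual implementation.
class _SketchLinearMaskRateScheduler:
    def __init__(self, min_masking_rate=0.15, max_masking_rate=0.40, total_steps=100):
        self.min_rate = min_masking_rate
        self.max_rate = max_masking_rate
        self.total_steps = total_steps
        self.current_step = 0
    def reset(self):
        self.current_step = 0
    def step(self):
        self.current_step = min(self.current_step + 1, self.total_steps)
    def get_masking_rate(self):
        frac = self.current_step / self.total_steps   # ramps 0 -> 1 over total_steps
        return self.min_rate + frac * (self.max_rate - self.min_rate)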
        
def main():
    # Set probability type to uniform; currently the only supported option
    config.PROBABILITY_TYPE = "uniform"
    
    # Set run name (WANDB_NAME)
    kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
    timestamp = get_local_time()
    # make a mask tag _mask{config.MASK_PERCENTAGE}
    mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE:    # if using a variable masking rate, change the tag to reflect this
        mask_tag=f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"
        
    # Define the train settings string and wandb name from this
    TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
    WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'
    
    # Create directory for model checkpoints
    checkpoint_dir = f'checkpoints/{WANDB_NAME}'
    start_epoch = 1
    
    # Determine if we're adding to an old log file or opening a new one
    logmode='w'
    
    # If we're finetuning from a checkpoint, save to the same folder instead, and keep track of which epoch to start on
    # Also, load the optimizer from here
    if not(config.FINETUNE_FROM_SCRATCH):
        logmode='a'
        path_to_starting_ckpt = config.PATH_TO_STARTING_CKPT
        checkpoint_dir = os.path.dirname(path_to_starting_ckpt)
        START_MODEL_TRAIN_SETTINGS_STRING = checkpoint_dir[checkpoint_dir.index('checkpoints/')+len('checkpoints/'):checkpoint_dir.index('-')]
        start_epoch = int(path_to_starting_ckpt.split('/checkpoint_epoch_')[1])+1
        
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Open log file
    LOG_PATH = f'{checkpoint_dir}/training_log.txt'
    ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
    with open_logfile(LOG_PATH,mode=logmode), open_errfile(ERR_PATH,mode=logmode):
        if not(config.FINETUNE_FROM_SCRATCH): 
            log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
            log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")
            # ONLY proceed with training if we're using the same settings, otherwise we are not finetuning the model we think we are! 
            assert START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING, "Config settings do not match the checkpoint being resumed; refusing to continue finetuning."
        
        # Print configurations
        print_configpy(config)
        
        # Verify that the environment variables are set correctly 
        check_env_variables()
        
        # Check CUDA availability and set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"\nUsing device: {device}")
            
        # Init wandb
        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME , config={
            "batch_size": config.BATCH_SIZE,
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
        })
        
        # Initialize model and prepare it (freeze/unfreeze proper layers). Initialize optimizer as well. Details depend on whether we are finetuning from scratch.
        model, optimizer = initialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
                                                         path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT, 
                                                         learning_rate=config.LEARNING_RATE, 
                                                         n_unfrozen_layers=config.N_UNFROZEN_LAYERS, 
                                                         unfreeze_query=config.UNFREEZE_QUERY, 
                                                         unfreeze_key=config.UNFREEZE_KEY, 
                                                         unfreeze_value=config.UNFREEZE_VALUE)
        
        # Initialize the tokenizer (independent of starting model for finetuning)
        tokenizer = model.tokenizer

        # Create DataLoader instances and perform sanity checks on them
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True) # shuffle=True for training; set to False only when debugging
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)

        # If we're continuing to finetune an old ckpt, store the old batch diversity plot before we overwrite it
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)
        
        # Set up a masking rate scheduler, if one is needed
        mask_rate_scheduler = None
        if config.VAR_MASK_RATE:
            mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
                                                          min_masking_rate=config.MASK_LOW,
                                                          max_masking_rate=config.MASK_HIGH,
                                                          total_batches=len(train_loader),
                                                          total_steps=config.MASK_STEPS)
        
        # Train the model
        train(model, tokenizer, optimizer, train_loader, val_loader,
            n_epochs=config.EPOCHS,
            start_epoch = start_epoch,
            device=device,
            mask_rate_scheduler=mask_rate_scheduler,
            mask_percentage=config.MASK_PERCENTAGE,
            checkpoint_dir=checkpoint_dir)

        # Test the model
        test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)
        
if __name__ == "__main__":
    main()