QLoRA for ESM-2 and Post Translational Modification Site Prediction

Published November 11, 2023

In this post, we will show you how to train your own ESM-2 QLoRA model using data from UniProt on post-translational modification (PTM) sites, treating the problem as a binary token classification task. We begin with instructions for gathering the data from UniProt and creating a train/test split based on UniProt protein families, which helps avoid the overfitting that sequence similarity can cause in a standard random split. Once the training and test datasets are ready, we show how to fine-tune a QLoRA adapter for the protein language model ESM-2 to predict where in a protein sequence post-translational modifications are likely to occur.


What is ESM-2?

Protein language models such as ESM-2 (Evolutionary Scale Modeling) represent a significant advancement in computational biology. ESM-2 is a deep learning model designed to understand the "language" of proteins, that is, the patterns and rules governing the structure and function of amino acid sequences. It is loosely analogous to the way ChatGPT models human language, but it is trained with a masked language modeling objective rather than an autoregressive (causal LM) objective, which suits many protein tasks better. The model can be fine-tuned to predict general post-translational modification (PTM) sites by treating the problem as a binary token classification task, where each amino acid in a protein sequence is a token.

The fine-tuning process involves training the model on datasets of known PTM sites, enabling the model to learn the contextual patterns associated with these modifications. By doing so, ESM-2 can predict whether each amino acid (token) in a new, unseen protein sequence is likely to undergo a specific modification or not. This binary classification is crucial for identifying potential PTM sites in proteins, which can aid in understanding protein function and regulation in a more detailed manner.
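
To make the token classification framing concrete, here is a toy illustration (the sequence and labels below are made up, not drawn from the dataset used later): each residue in the sequence is one token, and it receives a 1 if it is an annotated PTM site and a 0 otherwise.

# Toy illustration of per-residue binary labels (made-up example)
sequence = "MKTAYIAKQR"                        # one character per amino acid residue
ptm_labels = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]    # 1 marks an annotated PTM site

for residue, label in zip(sequence, ptm_labels):
    print(residue, label)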

Introduction to Post-Translational Modification (PTM)

Post-translational modification (PTM) of proteins is a critical aspect of cellular biology, significantly influencing protein function and regulation. PTM refers to the chemical modification of a protein after its synthesis. These modifications typically occur following protein biosynthesis at the ribosome, where proteins are generated as linear chains of amino acids. The most common forms of PTM include phosphorylation, glycosylation, ubiquitination, nitrosylation, methylation, acetylation, lipidation, and proteolytic cleavage.

The importance of PTMs lies in their ability to diversify protein functions beyond what is dictated by gene sequence alone. They play a vital role in regulating protein activity, stability, localization, and interaction with other cellular molecules. PTMs can alter the physical and chemical properties of proteins, thereby affecting their folding, conformation, distribution, and interactions with other proteins and DNA. This is crucial for a myriad of cellular processes, including signal transduction, cell cycle control, metabolic pathways, and immune responses.

PTMs are used in various biological and medical applications. In drug discovery and development, understanding PTMs can lead to the identification of new drug targets and therapeutic strategies. Additionally, aberrant PTMs are often associated with diseases such as cancer, neurodegenerative disorders, and metabolic diseases, making them potential biomarkers for diagnosis and targets for treatment.

Data Curation and Preprocessing

First, head over to UniProt and click "Advanced" in the search bar. When the advanced search options appear, select "PTM/Processing", then "Modified Residue". Type * into the search field (after removing any extra search fields) and select "Search". This returns a list of proteins with modified amino acid residues. Next, select "Customize Columns" in the table view and keep only the protein sequence, the protein families, and the modified residues; be sure to include "Protein Families", since it is needed for the family-based train/test split. Download the result as a TSV with these columns, and then run the following data preprocessing steps to create your train/test split.

import pandas as pd

# Load the TSV file
file_path = 'PTM/uniprotkb_family_AND_ft_mod_res_AND_pro_2023_10_07.tsv'
data = pd.read_csv(file_path, sep='\t')

# Display the first few rows of the data
data.head()

This should display the first few rows of the DataFrame, including the "Sequence", "Protein families", and "Modified residue" columns.
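
The "Modified residue" column packs UniProt's feature annotations into a single string containing entries of the form MOD_RES <position>, which is exactly what the regular expression in the next cell extracts. As an illustration (not an actual UniProt row), an entry looks roughly like this:

# Illustrative (approximate) 'Modified residue' value and the positions extracted from it
import re
example = 'MOD_RES 2; /note="N-acetylserine"; MOD_RES 28; /note="Phosphoserine; by CK2"'
print(re.findall(r'MOD_RES (\d+)', example))  # ['2', '28']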

import re

def get_ptm_sites(row):
    # Extract the positions of modified residues from the 'Modified residue' column
    modified_positions = [int(i) for i in re.findall(r'MOD_RES (\d+)', row['Modified residue'])]
    
    # Create a list of zeros of length equal to the protein sequence
    ptm_sites = [0] * len(row['Sequence'])
    
    # Replace the zeros with ones at the positions of modified residues
    for position in modified_positions:
        # Subtracting 1 because positions are 1-indexed, but lists are 0-indexed
        ptm_sites[position - 1] = 1
    
    return ptm_sites

# Apply the function to each row in the DataFrame
data['PTM sites'] = data.apply(get_ptm_sites, axis=1)

# Display the first few rows of the updated DataFrame
data.head()
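
As a quick sanity check, you can run the function on a single made-up row (both the sequence and the annotation below are hypothetical):

# Sanity check on a hypothetical row: positions 3 and 8 should be flagged
example_row = pd.Series({
    'Sequence': 'MKTAYIAKQR',
    'Modified residue': 'MOD_RES 3; /note="Phosphothreonine"; MOD_RES 8; /note="N6-acetyllysine"'
})
print(get_ptm_sites(example_row))  # [0, 0, 1, 0, 0, 0, 0, 1, 0, 0]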

The next cell splits the longer protein sequences and their labels into non-overlapping chunks of length 512 or less, which fits comfortably within the 1024-token context window of the smaller ESM-2 models. Feel free to increase the chunk size if you like: most protein sequences average around 350 residues, so longer context windows are often unnecessary, although we have observed better performance with a context window of 1000. Keep in mind that this will affect training time and batch size.

# Function to split sequences and PTM sites into chunks
def split_into_chunks(row):
    sequence = row['Sequence']
    ptm_sites = row['PTM sites']
    chunk_size = 512
    
    # Calculate the number of chunks
    num_chunks = (len(sequence) + chunk_size - 1) // chunk_size
    
    # Split sequences and PTM sites into chunks
    sequence_chunks = [sequence[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)]
    ptm_sites_chunks = [ptm_sites[i * chunk_size: (i + 1) * chunk_size] for i in range(num_chunks)]
    
    # Create new rows for each chunk
    rows = []
    for i in range(num_chunks):
        new_row = row.copy()
        new_row['Sequence'] = sequence_chunks[i]
        new_row['PTM sites'] = ptm_sites_chunks[i]
        rows.append(new_row)
    
    return rows

# Create a new DataFrame to store the chunks
chunks_data = []

# Iterate through each row of the original DataFrame and split into chunks
for _, row in data.iterrows():
    chunks_data.extend(split_into_chunks(row))

# Convert the list of chunks into a DataFrame
chunks_df = pd.DataFrame(chunks_data)

# Reset the index of the DataFrame
chunks_df.reset_index(drop=True, inplace=True)

# Display the first few rows of the new DataFrame
chunks_df.head()
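
Optionally, you can verify that the chunking kept one label per residue and that no chunk exceeds the chosen size (a small sketch):

# Check that labels stay aligned with residues and that chunks respect the size limit
assert (chunks_df['Sequence'].str.len() == chunks_df['PTM sites'].apply(len)).all()
assert chunks_df['Sequence'].str.len().max() <= 512
print(f"{len(data)} proteins -> {len(chunks_df)} chunks")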

Next, we create the train/test split based on UniProt families.

from tqdm import tqdm
import numpy as np

# Function to split data into train and test based on families
def split_data(df):
    # Get a unique list of protein families
    unique_families = df['Protein families'].unique().tolist()
    np.random.shuffle(unique_families)  # Shuffle the list to randomize the order of families
    
    test_data = []
    test_families = []
    total_entries = len(df)
    total_families = len(unique_families)
    
    # Set up tqdm progress bar
    with tqdm(total=total_families) as pbar:
        for family in unique_families:
            # Separate out all proteins in the current family into the test data
            family_data = df[df['Protein families'] == family]
            test_data.append(family_data)
            
            # Update the list of test families
            test_families.append(family)
            
            # Remove the current family data from the original DataFrame
            df = df[df['Protein families'] != family]
            
            # Calculate the percentage of test data and the percentage of families in the test data
            percent_test_data = sum(len(data) for data in test_data) / total_entries * 100
            percent_test_families = len(test_families) / total_families * 100
            
            # Update tqdm progress bar with readout of percentages
            pbar.set_description(f'% Test Data: {percent_test_data:.2f}% | % Test Families: {percent_test_families:.2f}%')
            pbar.update(1)
            
            # Check if the 20% threshold for test data is crossed
            if percent_test_data >= 20:
                break
    
    # Concatenate the list of test data DataFrames into a single DataFrame
    test_df = pd.concat(test_data, ignore_index=True)
    
    return df, test_df  # Return the remaining data and the test data

# Split the data into train and test based on families
train_df, test_df = split_data(chunks_df)
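
Since the whole point of this split is to keep homologous sequences out of both sets, it is worth confirming that no protein family appears in both DataFrames and checking the resulting sizes (a quick sketch):

# Confirm that the family-based split is clean and inspect its size
train_families = set(train_df['Protein families'].dropna())
test_families = set(test_df['Protein families'].dropna())
assert train_families.isdisjoint(test_families), "A family leaked into both splits!"
print(f"Train chunks: {len(train_df)} | Test chunks: {len(test_df)} "
      f"({len(test_df) / (len(train_df) + len(test_df)):.1%} test)")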

If you want to reduce the size of your datasets while maintaining the train/test split, you can adjust the fraction below to something less than 100%.

import pandas as pd

# Assuming train_df and test_df are your dataframes
fraction = 1.00  # fraction of each split to keep (100% by default)

# Randomly sample the chosen fraction of the data
reduced_train_df = train_df.sample(frac=fraction, random_state=42)
reduced_test_df = test_df.sample(frac=fraction, random_state=42)

import os
import pickle

# Extract sequences and PTM site labels from the reduced train and test DataFrames
train_sequences_reduced = reduced_train_df['Sequence'].tolist()
train_labels_reduced = reduced_train_df['PTM sites'].tolist()
test_sequences_reduced = reduced_test_df['Sequence'].tolist()
test_labels_reduced = reduced_test_df['PTM sites'].tolist()

# Save the lists to the specified pickle files
pickle_file_path = "2100K_ptm_data_512/"
os.makedirs(pickle_file_path, exist_ok=True)  # make sure the output directory exists

with open(pickle_file_path + "train_sequences_chunked_by_family.pkl", "wb") as f:
    pickle.dump(train_sequences_reduced, f)

with open(pickle_file_path + "test_sequences_chunked_by_family.pkl", "wb") as f:
    pickle.dump(test_sequences_reduced, f)

with open(pickle_file_path + "train_labels_chunked_by_family.pkl", "wb") as f:
    pickle.dump(train_labels_reduced, f)

with open(pickle_file_path + "test_labels_chunked_by_family.pkl", "wb") as f:
    pickle.dump(test_labels_reduced, f)

# Return the paths to the saved pickle files
saved_files = [
    pickle_file_path + "train_sequences_chunked_by_family.pkl",
    pickle_file_path + "test_sequences_chunked_by_family.pkl",
    pickle_file_path + "train_labels_chunked_by_family.pkl",
    pickle_file_path + "test_labels_chunked_by_family.pkl"
]
saved_files

Training the QLoRA

Importing Libraries and Modules

The first cell imports necessary libraries and modules:

  • os and wandb for environment and experiment tracking.
  • numpy and torch for numerical and tensor operations.
  • Various modules from transformers, datasets, and accelerate for handling token classification and model acceleration.
  • peft for PEFT (Parameter-efficient Fine-tuning) configurations.
  • pickle for loading the dataset.

import os
import wandb
import numpy as np
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from datasets import Dataset
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import pickle

Initializing Accelerator and Weights & Biases

The second cell sets up the Accelerator for efficient training on available hardware and initializes the Weights & Biases (W&B) platform for experiment tracking.

# Initialize accelerator and Weights & Biases
accelerator = Accelerator()
os.environ["WANDB_NOTEBOOK_NAME"] = 'qlora_ptm_v2.py'
wandb.init(project='ptm_site_prediction')

Helper Functions and Data Preparation

The third cell defines several helper functions:

  • print_trainable_parameters: To display the number of trainable parameters.
  • save_config_to_txt: To save model configurations as a text file.
  • truncate_labels: To truncate labels for sequences longer than the maximum length.
  • compute_metrics: To calculate evaluation metrics like accuracy, precision, recall, F1 score, AUC, and MCC.
  • compute_loss: Custom loss computation considering class weights.

# Helper Functions and Data Preparation
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

def save_config_to_txt(config, filename):
    """Save the configuration dictionary to a text file."""
    with open(filename, 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")

def truncate_labels(labels, max_length):
    return [label[:max_length] for label in labels]

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, logits, inputs):
    # logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss
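
To see the shapes compute_metrics expects, here is a toy call with made-up logits for a batch containing one five-token sequence whose last position is padding (label -100):

# Toy call with made-up numbers: logits of shape (batch, seq_len, 2), labels padded with -100
toy_logits = np.array([[[2.0, 0.1], [0.2, 1.5], [1.0, 0.3], [0.1, 2.2], [0.0, 0.0]]])
toy_labels = np.array([[0, 1, 0, 1, -100]])
print(compute_metrics((toy_logits, toy_labels)))  # the padded position is excluded from every metric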

Loading Data

The fourth cell loads the training and testing datasets from pickle files, ensuring data is ready for processing and model training.

# Load data from pickle files (use the same directory you saved the pickles to above)
with open("2100K_ptm_data_512/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("2100K_ptm_data_512/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("2100K_ptm_data_512/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("2100K_ptm_data_512/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

Tokenization

The fifth cell involves tokenizing the protein sequences using the AutoTokenizer from the ESM-2 model. This process converts the sequences into a format suitable for the model, considering aspects like padding, truncation, and maximum sequence length.

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")

# Set max_sequence_length to the tokenizer's max input length
max_sequence_length = 1024

train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False, add_special_tokens=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False, add_special_tokens=False)

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)
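
Note that add_special_tokens=False matters for alignment: ESM-2's tokenizer emits exactly one token per residue, so the i-th label corresponds to the i-th token without any offset for <cls> or <eos>. A quick check on a toy sequence:

# With add_special_tokens=False there is one token per residue, so labels align one-to-one
toy = tokenizer("MKTAYIAKQR", add_special_tokens=False)
print(len(toy["input_ids"]))  # 10 tokens for 10 residues (no <cls>/<eos> added)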

Creating Datasets

The sixth cell creates Dataset objects for training and testing, incorporating the tokenized data and corresponding labels.

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

Computing Class Weights

The seventh cell calculates class weights to address class imbalance, which is essential for balanced training in this binary classification task. Because there are far fewer PTM sites than non-PTM sites, the weights keep the model from simply predicting the majority class, scoring a high accuracy, and calling it a day.

# Compute Class Weights
classes = [0, 1]  
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
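
With the 'balanced' setting, each class weight is n_samples / (n_classes * n_samples_in_class), so the rare PTM class is weighted up in the loss. A worked example with made-up counts:

# Worked example with made-up counts: 95 negatives and 5 positives out of 100 labels
toy_flat_labels = [0] * 95 + [1] * 5
print(compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=toy_flat_labels))
# -> [100 / (2 * 95), 100 / (2 * 5)] ~= [0.526, 10.0]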

Defining a Custom Trainer Class

The eighth cell introduces a custom Trainer class to incorporate the weighted loss function during model training.

# Define Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        loss = compute_loss(model, logits, inputs)
        return (loss, outputs) if return_outputs else loss

Configuring Quantization Settings

The ninth cell sets up the quantization settings for the model, which helps in reducing the model size and improving inference efficiency.

# Configure the quantization settings
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
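
As a rough, back-of-the-envelope illustration of why this helps (ignoring activations, optimizer state, and quantization overhead), storing the roughly 150M base parameters in 4-bit NF4 takes about an eighth of the memory of fp32:

# Rough memory estimate for the frozen base weights only (illustrative; ignores all overhead)
n_params = 150_000_000
print(f"fp32: ~{n_params * 4 / 1e6:.0f} MB, 4-bit: ~{n_params * 0.5 / 1e6:.0f} MB")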

Training Function Without Sweeps

The tenth cell defines the main training function:

  • Sets model configurations and logs them to W&B.
  • Initializes the ESM-2 model for token classification with specific labels and applies quantization.
  • Prepares the model for PEFT and 4-bit quantization training.
  • Configures training arguments like learning rate, batch size, epochs, etc.
  • Initializes the custom WeightedTrainer.
  • Executes the training process and saves the model.

def train_function_no_sweeps(train_dataset, test_dataset):
    
    # Directly set the config
    config = {
        "lora_alpha": 1, 
        "lora_dropout": 0.5,
        "lr": 3.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 36,
        "r": 2,
        "weight_decay": 0.3,
        # Add other hyperparameters as needed
    }

    # Log the config to W&B
    wandb.config.update(config)

    # Save the config to a text file
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config_filename = f"esm2_t30_150M_qlora_ptm_config_{timestamp}.txt"
    save_config_to_txt(config, config_filename)
    
        
    model_checkpoint = "facebook/esm2_t30_150M_UR50D"  
    
    # Define labels and model
    id2label = {0: "No ptm site", 1: "ptm site"}
    label2id = {v: k for k, v in id2label.items()}
    
    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        quantization_config=bnb_config  # Apply quantization here
    )

    # Prepare the model for 4-bit quantization training
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    
    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=[
            "query",
            "key",
            "value",
            "EsmSelfOutput.dense",
            "EsmIntermediate.dense",
            "EsmOutput.dense",
            "EsmContactPredictionHead.regression",
            "classifier"
        ],
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
        # modules_to_save=["classifier"]
    )
    model = get_peft_model(model, peft_config)
    print_trainable_parameters(model)  # report the fraction of parameters that LoRA leaves trainable

    # Use the accelerator
    model = accelerator.prepare(model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t30_150M_qlora_ptm_sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        # warmup_steps=2,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=3,
        no_cuda=False,
        seed=8893,
        fp16=True,
        report_to='wandb', 
        optim="paged_adamw_8bit"  # paged 8-bit AdamW optimizer from bitsandbytes
    )
    
    # Initialize Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics
    )

    # Train and Save Model
    trainer.train()
    save_path = os.path.join("qlora_ptm_sites", f"best_model_esm2_t30_150M_qlora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)
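
To get a feel for why print_trainable_parameters reports such a tiny trainable fraction, here is a rough count of the LoRA parameters added to the query/key/value projections alone, assuming this checkpoint's hidden size of 640 and its 30 layers (check model.config if in doubt); the other target modules and the classifier add a bit more:

# Rough LoRA parameter count for the q/k/v projections (assumes hidden_size=640, 30 layers, r=2)
r, hidden, layers = 2, 640, 30
per_matrix = r * (hidden + hidden)   # LoRA A is (r x hidden), LoRA B is (hidden x r)
print(3 * layers * per_matrix)       # query + key + value across all layers -> 230400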

Main Execution

The final cell is the entry point for the training script, calling the training function with the prepared datasets.

# Call the training function
if __name__ == "__main__":
    train_function_no_sweeps(train_dataset, test_dataset)
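
After training completes, you can load the saved LoRA adapter on top of the base ESM-2 model for inference. Below is a minimal sketch of how that might look with the peft API; the adapter_path is a placeholder for whatever directory trainer.save_model wrote for your run, and the example sequence is made up:

# Minimal inference sketch (adjust adapter_path to the directory saved by trainer.save_model)
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

adapter_path = "qlora_ptm_sites/best_model_esm2_t30_150M_qlora_<timestamp>"  # placeholder path
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t30_150M_UR50D",
    num_labels=2,
    id2label={0: "No ptm site", 1: "ptm site"},
    label2id={"No ptm site": 0, "ptm site": 1},
)
model = PeftModel.from_pretrained(base_model, adapter_path)
tokenizer = AutoTokenizer.from_pretrained(adapter_path)

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # example sequence
inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    logits = model(**inputs).logits
predictions = logits.argmax(dim=-1).squeeze().tolist()
print(list(zip(sequence, predictions)))  # per-residue 0/1 PTM predictions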

Conclusion

This notebook demonstrates a sophisticated approach to leveraging state-of-the-art protein language models for predicting post-translational modification sites in protein sequences. By fine-tuning ESM-2 with QLoRA, you can integrate deep learning into protein bioinformatics, paving the way for more advanced research into protein function and interactions. Once you have trained your ESM-2 model for predicting post-translational modifications, be sure to upload it to Hugging Face and share it!

To test out a version of this model, head over to Neurosnap and check out the various models they have available, or visit the Hugging Face collection ESM-PTM. You might also read the recent research on applying LoRA to protein language models, "Democratizing Protein Language Models with Parameter-Efficient Fine-Tuning" and "Exploring Post-Training Quantization of Protein Language Models", both of which were recently released.