# Fine-tuning ESM-2 Models for CAFA-5

## Fine-tune an ESM-2 Model

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, EsmForSequenceClassification
from accelerate import Accelerator
from sklearn.model_selection import train_test_split
from torchmetrics.classification import MultilabelF1Score
from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score
import datetime
import pandas as pd

# --- Hyper-parameters ---------------------------------------------------------
NUM_EPOCHS = 3   # passes over the training split
LR = 5e-4        # AdamW learning rate
BATCH_SIZE = 2   # sequences per optimisation step

# --- Input table (must provide 'sequence' and 'term' columns, used below) -----
data = pd.read_csv(
    "C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv",
    sep="\t",
)
# Quick smoke test: uncomment to keep only the first 100 rows.
# data = data.head(100)

# --- Runtime -------------------------------------------------------------------
# The device is chosen by Accelerate; every tensor below is moved onto it.
accelerator = Accelerator()
device = accelerator.device

# --- Tokenizer -----------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# Every sequence is padded/truncated to this length by ProteinDataset.
MAX_LENGTH = tokenizer.model_max_length

class ProteinDataset(Dataset):
    """Map-style dataset that tokenizes one protein sequence per item.

    Uses the module-level ``tokenizer`` and ``MAX_LENGTH``. Each item is a dict
    with ``input_ids`` and ``attention_mask`` (flattened tokenizer output,
    padded/truncated to ``MAX_LENGTH``) and ``labels`` (the label row as a
    float tensor).
    """

    def __init__(self, sequences, labels):
        # Positional access (``[idx]``) is used, so pandas inputs should have a
        # 0..n-1 index.
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        enc = tokenizer(
            self.sequences[idx],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=MAX_LENGTH,
        )
        return {
            'input_ids': enc['input_ids'].flatten(),
            'attention_mask': enc['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.float),
        }

def encode_labels(go_terms, unique_terms):
    """Multi-hot encode each sample's terms against the ``unique_terms`` vocabulary.

    Returns a list of lists: row ``i`` holds 1 at position ``j`` when
    ``unique_terms[j] in go_terms[i]`` and 0 otherwise. Membership uses ``in``,
    so if a row is a plain string it is matched by substring.
    """
    return [[int(term in terms) for term in unique_terms] for terms in go_terms]

# Hold out 10% of the rows for validation. A fixed random_state makes the split
# reproducible from run to run (the split is also pickled further down for reuse).
train_sequences, val_sequences, train_labels, val_labels = train_test_split(
    data['sequence'], data['term'], test_size=0.1, random_state=42
)

# train_test_split keeps the original DataFrame index; reset it so that the
# positional lookups in ProteinDataset (self.sequences[idx]) hit every row.
train_sequences = train_sequences.reset_index(drop=True)
val_sequences = val_sequences.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)
val_labels = val_labels.reset_index(drop=True)

# Label vocabulary = output columns of the classifier. Sorted, because the
# iteration order of a set of strings is hash-randomised per interpreter run,
# which would silently permute the output columns between training runs.
unique_terms = sorted(set(term for sublist in data['term'] for term in sublist))
train_labels_encoded = encode_labels(train_labels, unique_terms)
val_labels_encoded = encode_labels(val_labels, unique_terms)

train_dataset = ProteinDataset(train_sequences, train_labels_encoded)
val_dataset = ProteinDataset(val_sequences, val_labels_encoded)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Model Training
model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=len(unique_terms),
    problem_type="multi_label_classification",
)
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
optimizer, model = accelerator.prepare(optimizer, model)

# Initialize metrics
f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)
f1_metric = f1_metric.to(device)

num_epochs = NUM_EPOCHS
threshold = 0.5  # probability cut-off shared by all binarised metrics

for epoch in range(num_epochs):
    # Enter training mode at the start of EVERY epoch: the validation pass
    # below calls model.eval(), so without this every epoch after the first
    # would train with dropout disabled.
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Passing labels makes the model return its own loss (outputs.loss).
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch {epoch + 1}/{num_epochs}, Training loss: {total_loss/len(train_loader)}')

    # ---- Validation ----
    model.eval()
    predictions = []
    true_labels_list = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions.append(torch.sigmoid(outputs.logits))
            true_labels_list.append(labels)

    probs = torch.cat(predictions, dim=0)
    targets = torch.cat(true_labels_list, dim=0)
    predictions_tensor = probs.cpu().numpy()
    true_labels_tensor = targets.cpu().numpy()
    predictions_bin = (predictions_tensor > threshold).astype(int)

    # F1: pass integer {0,1} targets (torchmetrics documents `target` as an int
    # tensor) and reset so no state carries over from earlier epochs.
    f1_metric.reset()
    val_f1 = f1_metric(probs, targets.int())
    # Element-wise accuracy over every (protein, term) cell.
    val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())
    # Micro-average over the 2-D multi-label matrices so precision/recall count
    # positive labels only. On flattened 0/1 vectors, average='micro' averages
    # over both the 0 and the 1 class and reduces to plain accuracy.
    val_precision = precision_score(true_labels_tensor, predictions_bin, average='micro', zero_division=0)
    val_recall = recall_score(true_labels_tensor, predictions_bin, average='micro', zero_division=0)
    # average_precision_score is the area under the precision-recall curve,
    # not ROC AUC (variable name kept for compatibility).
    val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')

    # Print metrics
    print(f'Validation F1 Score: {val_f1}')
    print(f'Validation Accuracy: {val_accuracy}')
    print(f'Validation Precision: {val_precision}')
    print(f'Validation Recall: {val_recall}')
    print(f'Validation Average Precision (AUPRC): {val_auc}')

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    model_path = f'./esm2_t6_8M_finetuned_cafa5_{timestamp}'
    # Save the underlying model rather than any wrapper added by accelerator.prepare.
    accelerator.unwrap_model(model).save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)

    print(f'Model checkpoint saved to {model_path}')


## Save the Train/Validation Split Data

In [None]:
import pickle

# Persist the train/validation split produced by the training cell, so a later
# session can reload exactly the same rows instead of re-splitting.
data_splits = {
    "train_sequences": train_sequences,
    "val_sequences": val_sequences,
    "train_labels": train_labels,
    "val_labels": val_labels,
}

with open('data_splits.pkl', 'wb') as split_file:
    pickle.dump(data_splits, split_file)


## Reload the Data Later

In [None]:
import pickle

# Restore the split written to data_splits.pkl.
# NOTE: unpickling can execute arbitrary code; only load files you produced.
with open('data_splits.pkl', 'rb') as split_file:
    data_splits = pickle.load(split_file)

train_sequences, val_sequences, train_labels, val_labels = (
    data_splits[key]
    for key in ("train_sequences", "val_sequences", "train_labels", "val_labels")
)

# The remaining cells can run unchanged, using the train and validation sets
# restored from the pickle file.

## Data Preprocessing

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, EsmForSequenceClassification
from accelerate import Accelerator
from sklearn.model_selection import train_test_split
from torchmetrics.classification import MultilabelF1Score
from sklearn.metrics import accuracy_score, precision_score, recall_score, average_precision_score
import datetime
import pandas as pd

# Load the merged protein table, keeping only its first 100 rows.
# NOTE(review): the train/val splits used in this notebook section are reloaded
# from data_splits.pkl, not derived from this truncated frame; anything computed
# from `data` here sees only these 100 rows.
data = pd.read_csv(
    "C:/Users/OWO/Desktop/amelie_vscode/cafa5/data/merged_protein_data.tsv",
    sep="\t",
).head(100)

# Device selection is delegated to Accelerate.
accelerator = Accelerator()
device = accelerator.device

# Tokenizer; sequences are padded/truncated to MAX_LENGTH by ProteinDataset.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
MAX_LENGTH = tokenizer.model_max_length

class ProteinDataset(Dataset):
    """Map-style dataset that tokenizes one protein sequence per item.

    Uses the module-level ``tokenizer`` and ``MAX_LENGTH``. Each item is a dict
    with ``input_ids`` and ``attention_mask`` (flattened tokenizer output,
    padded/truncated to ``MAX_LENGTH``) and ``labels`` (the label row as a
    float tensor).
    """

    def __init__(self, sequences, labels):
        # Positional access (``[idx]``) is used, so pandas inputs should have a
        # 0..n-1 index.
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        enc = tokenizer(
            self.sequences[idx],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=MAX_LENGTH,
        )
        return {
            'input_ids': enc['input_ids'].flatten(),
            'attention_mask': enc['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.float),
        }

def encode_labels(go_terms, unique_terms):
    """Multi-hot encode each sample's terms against the ``unique_terms`` vocabulary.

    Returns a list of lists: row ``i`` holds 1 at position ``j`` when
    ``unique_terms[j] in go_terms[i]`` and 0 otherwise. Membership uses ``in``,
    so if a row is a plain string it is matched by substring.
    """
    return [[int(term in terms) for term in unique_terms] for terms in go_terms]

# The train/validation splits are not recomputed here: train_sequences,
# val_sequences, train_labels and val_labels come from data_splits.pkl (written
# after their indices were reset), so they are used as-is.

# Build the label vocabulary from the splits that are actually encoded below.
# Building it from `data` (which this section may truncate) would silently drop
# every term that appears in the splits but not in `data`. Sorted so the
# classifier's output-column order is stable across interpreter runs (set
# iteration order of strings is hash-randomised).
unique_terms = sorted(
    {term for split in (train_labels, val_labels) for sublist in split for term in sublist}
)
train_labels_encoded = encode_labels(train_labels, unique_terms)
val_labels_encoded = encode_labels(val_labels, unique_terms)

train_dataset = ProteinDataset(train_sequences, train_labels_encoded)
val_dataset = ProteinDataset(val_sequences, val_labels_encoded)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

## Fine-tune with LoRA

In [None]:
from collections import Counter
from peft import get_peft_config, get_peft_model, LoraConfig
import datetime
from sklearn.metrics import accuracy_score, precision_score, recall_score, hamming_loss, average_precision_score
from torchmetrics.classification import MultilabelF1Score

# Constants
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"  # or the directory of a checkpoint saved by the first training cell
BATCH_SIZE = 4
NUM_EPOCHS = 7
LR = 3e-5

# Rebuild the loaders so BATCH_SIZE above actually takes effect; the Data
# Preprocessing cell builds them with a hard-coded batch size of 2.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Initialize model with LoRA adapters on the modules named query/key/value.
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    inference_mode=False,
    r=16,
    bias="none",
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],
)

base_model = EsmForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(unique_terms),
    problem_type="multi_label_classification",
)
model = get_peft_model(base_model, peft_config)
model = model.to(accelerator.device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
optimizer, model = accelerator.prepare(optimizer, model)

f1_metric = MultilabelF1Score(num_labels=len(unique_terms), threshold=0.5)
f1_metric = f1_metric.to(device)

# Compute Class Weights
def compute_class_weights(terms, term_to_id, max_weight=None):
    """Return per-class ``pos_weight`` values for ``BCEWithLogitsLoss``.

    Each class gets ``negatives / positives`` counted over the samples in
    ``terms``. This is the ratio the PyTorch documentation recommends for
    ``pos_weight`` (e.g. 100 positives and 300 negatives -> 3.0), so rare terms
    are up-weighted. The previous version normalised inverse frequencies to sum
    to 1, which made every pos_weight far below 1 and down-weighted positives.

    Args:
        terms: iterable of per-sample term collections. A term is counted at
            most once per sample.
        term_to_id: mapping term -> output column index (0..len-1).
        max_weight: optional upper bound applied to every weight, to keep very
            rare terms from dominating the loss.

    Returns:
        1-D float tensor of length ``len(term_to_id)``, indexed by column id.
        Terms with no positive sample get weight 1.0 (unweighted). A term
        present in every sample gets 0.0, per the formula.
    """
    terms = list(terms)
    num_samples = len(terms)
    # set(): pos_weight is defined per example, so duplicates within one
    # sample must not be double-counted.
    positives = Counter(term for terms_list in terms for term in set(terms_list))
    weights = torch.ones(len(term_to_id), dtype=torch.float)
    for term, idx in term_to_id.items():
        count = positives.get(term, 0)
        if count:
            weights[idx] = (num_samples - count) / count
    if max_weight is not None:
        weights = weights.clamp(max=max_weight)
    return weights

term_to_id = {term: idx for idx, term in enumerate(unique_terms)}
# Use the training split only, so validation labels do not shape the training
# objective.
weights = compute_class_weights(train_labels.tolist(), term_to_id)
weights = weights.to(accelerator.device)
loss_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)

import copy

# Training loop
for epoch in range(NUM_EPOCHS):
    # Training Phase
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(accelerator.device)
        attention_mask = batch['attention_mask'].to(accelerator.device)
        labels = batch['labels'].to(accelerator.device)

        # The weighted criterion defines the loss. Labels are not passed to the
        # model, which would only compute a second, unused loss.
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        loss = loss_criterion(logits, labels)
        accelerator.backward(loss)
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)

    # Validation Phase
    model.eval()
    total_val_loss = 0
    predictions = []
    true_labels = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(accelerator.device)
            attention_mask = batch['attention_mask'].to(accelerator.device)
            labels = batch['labels'].to(accelerator.device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = loss_criterion(logits, labels)

            total_val_loss += loss.item()
            predictions.append(torch.sigmoid(logits))
            true_labels.append(labels)

    avg_val_loss = total_val_loss / len(val_loader)

    probs = torch.cat(predictions, dim=0)
    targets = torch.cat(true_labels, dim=0)
    predictions_tensor = probs.cpu().numpy()
    true_labels_tensor = targets.cpu().numpy()

    threshold = 0.5
    predictions_bin = (predictions_tensor > threshold).astype(int)

    # F1: pass integer {0,1} targets (torchmetrics documents `target` as an int
    # tensor) and reset so no state carries over from earlier epochs.
    f1_metric.reset()
    val_f1 = f1_metric(probs, targets.int())
    # Element-wise accuracy over every (protein, term) cell.
    val_accuracy = accuracy_score(true_labels_tensor.flatten(), predictions_bin.flatten())
    # Micro-average over the 2-D multi-label matrices so precision/recall count
    # positive labels only. On flattened 0/1 vectors, average='micro' averages
    # over both the 0 and the 1 class and reduces to plain accuracy.
    val_precision = precision_score(true_labels_tensor, predictions_bin, average='micro', zero_division=0)
    val_recall = recall_score(true_labels_tensor, predictions_bin, average='micro', zero_division=0)
    # Area under the precision-recall curve (average precision), not ROC AUC.
    val_auc = average_precision_score(true_labels_tensor, predictions_tensor, average='micro')

    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - Training Loss: {avg_train_loss:.4f} - Validation Loss: {avg_val_loss:.4f}")
    print(f"Validation Metrics - Accuracy: {val_accuracy:.4f} - Precision (Micro): {val_precision:.4f} - Recall (Micro): {val_recall:.4f} - Average Precision (Micro): {val_auc:.4f} - F1 Score: {val_f1:.4f}")

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    model_path = f'./esm2_t6_8M_cafa5_lora_{timestamp}'
    # Strip any wrapper added by accelerator.prepare before saving.
    unwrapped = accelerator.unwrap_model(model)
    # PEFT model save: writes the adapter weights/config.
    unwrapped.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)
    # Standalone full checkpoint with the LoRA weights merged into the base
    # model. Saving model.base_model directly would store the LoRA-injected
    # module tree, whose weight names do not match the plain architecture.
    # merge_and_unload() modifies the model it is called on, so merge a copy and
    # keep training the original.
    merged = copy.deepcopy(unwrapped).merge_and_unload()
    merged_path = f'{model_path}/merged'
    merged.save_pretrained(merged_path)
    tokenizer.save_pretrained(merged_path)
    print(f'Model checkpoint saved to {model_path}')
