# In [None]:
# 1. Imports
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import numpy as np
import xml.etree.ElementTree as ET
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from transformers import (
 AutoModelForTokenClassification,
 AutoTokenizer,
 DataCollatorForTokenClassification,
 TrainingArguments,
 Trainer,
)
from datasets import Dataset
from accelerate import Accelerator

# Imports specific to the custom model
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

# 2. Setup Accelerator
# CUDA_VISIBLE_DEVICES is already set to "0" right after `import os` at the top of
# the file, before torch is imported; assigning the same value again here was redundant.
accelerator = Accelerator()

# 3. Helper Functions
def convert_binding_string_to_labels(binding_string):
    """Map a 'proBnd' annotation string to per-residue integer labels.

    Every '+' character becomes 1 (binding residue); any other character becomes 0.
    """
    return [int(symbol == '+') for symbol in binding_string]

def truncate_labels(labels, max_length):
    """Clip every per-sequence label list to at most ``max_length`` entries.

    Returns a new outer list of sliced copies; the input lists are not modified.
    """
    head = slice(None, max_length)
    return [seq_labels[head] for seq_labels in labels]

def compute_metrics(p):
    """Compute token-level binary-classification metrics for evaluation.

    Args:
        p: ``(predictions, labels)`` pair as handed over by ``Trainer``.
            ``predictions`` are the raw logits, indexed ``[batch, token, class]``;
            ``labels`` are the per-token targets, where ``-100`` marks positions
            that must not be scored.

    Returns:
        dict with ``precision``, ``recall`` and ``f1`` for the positive class
        (label 1), computed from argmax predictions, and ``auc``. ``auc`` is the
        ROC AUC computed from the softmax probability of class 1. It is ``nan``
        when the scored tokens contain only one class, because AUC is undefined there.
    """
    logits, labels = p
    # float64 keeps the softmax stable even if evaluation produced fp16 logits.
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)

    # Boolean-mask indexing keeps only scored tokens and already yields flat arrays:
    # token_logits is (n_tokens, n_classes), token_labels is (n_tokens,).
    scored = labels != -100
    token_logits = logits[scored]
    token_labels = labels[scored]
    token_predictions = token_logits.argmax(axis=-1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        token_labels, token_predictions, average='binary', zero_division=0
    )

    # ROC AUC needs a continuous score; hard 0/1 predictions collapse the curve to
    # a single operating point. Use the softmax probability of class 1.
    shifted = token_logits - token_logits.max(axis=-1, keepdims=True)
    probabilities = np.exp(shifted)
    probabilities /= probabilities.sum(axis=-1, keepdims=True)
    if np.unique(token_labels).size < 2:
        auc = float('nan')  # roc_auc_score raises when only one class is present
    else:
        auc = roc_auc_score(token_labels, probabilities[:, 1])

    return {'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}

def compute_loss(model, inputs, weight=None, logits=None):
    """Class-weighted cross-entropy over token positions.

    Args:
        model: Called as ``model(**inputs)`` to obtain ``.logits`` when ``logits``
            is not supplied. It is not used otherwise.
        inputs: Mapping holding ``"labels"`` and ``"attention_mask"`` tensors of the
            same shape, plus whatever else ``model`` accepts.
        weight: Per-class weight tensor for ``nn.CrossEntropyLoss``. Defaults to the
            module-level ``class_weights``.
        logits: Logits from a forward pass the caller already ran. Passing them avoids
            running the model a second time.

    Returns:
        Scalar loss tensor. Positions whose attention mask is not 1, or whose label
        is -100, do not contribute.
    """
    if logits is None:
        logits = model(**inputs).logits
    if weight is None:
        weight = class_weights  # module-level global computed from the training labels

    # Compute the loss in float32 so fp16 logits and float32 weights always agree.
    logits = logits.float()
    loss_fct = nn.CrossEntropyLoss(weight=weight.to(device=logits.device, dtype=logits.dtype))

    # Send non-attended (padding) positions to ignore_index so they are skipped.
    labels = inputs["labels"].masked_fill(inputs["attention_mask"] != 1, loss_fct.ignore_index)

    # The class count comes from the logits themselves, so wrapped models work too.
    num_labels = logits.size(-1)
    return loss_fct(logits.reshape(-1, num_labels), labels.reshape(-1))

# 4. Parse XML and Extract Data
tree = ET.parse('binding_sites.xml')
root = tree.getroot()
all_sequences = []
all_labels = []
# Single pass over the BindPartner elements. Sequence and annotation are read from
# the same element, so the two lists stay paired.
for partner_index, partner in enumerate(root.findall(".//BindPartner")):
    sequence = partner.findtext(".//proSeq")
    binding = partner.findtext(".//proBnd")
    if sequence is None or binding is None:
        raise ValueError(f"BindPartner #{partner_index} is missing <proSeq> or <proBnd>")
    # One label per residue is required; a mismatch would silently shift labels.
    if len(sequence) != len(binding):
        raise ValueError(
            f"BindPartner #{partner_index}: sequence length {len(sequence)} != "
            f"binding annotation length {len(binding)}"
        )
    all_sequences.append(sequence)
    all_labels.append(convert_binding_string_to_labels(binding))

# 5. Data Splitting and Tokenization
# random_state pins the split. TrainingArguments(seed=...) does not control sklearn.
train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    all_sequences, all_labels, test_size=0.20, shuffle=True, random_state=42
)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# Maximum tokenized length, special tokens included.
# NOTE(review): the origin of 1291 is not visible here — confirm it is intended.
max_sequence_length = 1291


def _align_labels_to_tokens(special_tokens_masks, residue_labels, ignore_index=-100):
    """Spread per-residue labels over token positions.

    Special-token positions (mask value 1) get ``ignore_index`` so loss and metrics
    skip them. Every other position takes the next residue label in order. Positions
    left over when a label list is shorter than the non-special token count also get
    ``ignore_index``.
    """
    aligned = []
    for mask, labels in zip(special_tokens_masks, residue_labels):
        remaining = iter(labels)
        aligned.append(
            [ignore_index if is_special else next(remaining, ignore_index) for is_special in mask]
        )
    return aligned


# No padding here: DataCollatorForTokenClassification pads each batch dynamically
# and pads "labels" with -100 to the same length.
train_tokenized = tokenizer(
    train_sequences, truncation=True, max_length=max_sequence_length,
    return_special_tokens_mask=True,
)
test_tokenized = tokenizer(
    test_sequences, truncation=True, max_length=max_sequence_length,
    return_special_tokens_mask=True,
)
# The special-token mask is only needed for alignment; keep it out of the model inputs.
train_special = train_tokenized.pop("special_tokens_mask")
test_special = test_tokenized.pop("special_tokens_mask")

# The tokenizer spends some of max_sequence_length on special tokens, so only this
# many residues can survive truncation.
# NOTE(review): assumes one token per residue (ESM-2 vocabulary) — verify.
residue_budget = max_sequence_length - tokenizer.num_special_tokens_to_add()
# train_labels / test_labels keep plain 0/1 per-residue labels (used for class weights).
train_labels = truncate_labels(train_labels, residue_budget)
test_labels = truncate_labels(test_labels, residue_budget)

# Dataset labels are token-aligned: special tokens carry -100 instead of stealing
# the first residue's label.
train_dataset = Dataset.from_dict(
    {**train_tokenized, "labels": _align_labels_to_tokens(train_special, train_labels)}
)
test_dataset = Dataset.from_dict(
    {**test_tokenized, "labels": _align_labels_to_tokens(test_special, test_labels)}
)

# 6. Compute Class Weights
# 'balanced' weighting over every residue label in the training split, so the rarer
# class gets the larger weight in the loss.
classes = [0, 1]
flat_train_labels = [
    residue_label
    for labels_per_sequence in train_labels
    for residue_label in labels_per_sequence
]
balanced_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
# float32 tensor on the accelerator's device, ready for nn.CrossEntropyLoss(weight=...).
class_weights = torch.as_tensor(balanced_weights, dtype=torch.float32, device=accelerator.device)

# 7. Define Custom Trainer Class
class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy over token positions.

    The weights come from the module-level ``class_weights`` tensor. Positions whose
    attention mask is not 1, or whose label is -100, are ignored.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the weighted loss, plus the model outputs if ``return_outputs``.

        The model runs exactly once. The same outputs feed the loss and are returned
        to the Trainer, so evaluation logits match the ones that were scored.
        (Running the model twice doubled the compute, and with LoRA dropout the
        two passes disagreed.)

        ``num_items_in_batch`` is accepted but unused: newer transformers releases
        pass it to ``compute_loss``. The loss here is a weighted per-token mean.
        """
        outputs = model(**inputs)
        # Compute the loss in float32 so fp16 logits and float32 weights always agree.
        logits = outputs.logits.float()
        loss_fct = nn.CrossEntropyLoss(
            weight=class_weights.to(device=logits.device, dtype=logits.dtype)
        )
        # Also exclude non-attended (padding) positions, not just labels already at -100.
        labels = inputs["labels"].masked_fill(inputs["attention_mask"] != 1, loss_fct.ignore_index)
        loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        return (loss, outputs) if return_outputs else loss

# 8. Training Setup
# Hyperparameters
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
lr = 0.0005437551839696541
batch_size = 4
num_epochs = 15

# Binary per-residue label scheme, given to the model config in both directions.
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {name: index for index, name in id2label.items()}
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

# Convert the model into a PeftModel: LoRA adapters on the attention
# "query"/"key"/"value" modules, with all bias terms trainable (bias="all").
peft_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    inference_mode=False,
    r=16,
    lora_alpha=16,
    target_modules=["query", "key", "value"],
    lora_dropout=0.1,
    bias="all",
)
model = get_peft_model(model, peft_config)

# No accelerator.prepare() here. Trainer uses Accelerate internally to place the
# model on its device (and wrap it for distributed runs), so preparing it with a
# second Accelerator is redundant and can double-wrap the model. Accelerator.prepare()
# also passes HF Dataset objects through unchanged, so preparing them did nothing.

timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# Training setup
training_args = TrainingArguments(
    output_dir=f"esm2_t6_8M-lora-binding-site-classification_{timestamp}",
    learning_rate=lr,

    # Learning-rate schedule: linear warm-up, then linear decay
    lr_scheduler_type="linear",
    warmup_steps=500,  # Number of warm-up steps; adjust based on your observations

    # Gradient accumulation: 1 means the optimizer steps after every batch
    gradient_accumulation_steps=1,
    # Gradient clipping: gradients are rescaled so their norm is at most this value
    max_grad_norm=1.0,  # Common value, but can be adjusted based on your observations

    # Batch size (per device)
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,

    # Number of epochs
    num_train_epochs=num_epochs,

    # Weight decay
    weight_decay=0.025,  # Adjust this value based on your observations, e.g., 0.01 or 0.05

    # Evaluation, checkpointing and best-model selection, once per epoch.
    # This alone does NOT stop training early. For early stopping, pass
    # transformers.EarlyStoppingCallback(early_stopping_patience=...) through
    # Trainer(callbacks=[...]); early_stopping_patience is not a TrainingArguments parameter.
    # NOTE(review): newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy`; update this keyword if you upgrade.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,  # After training, reload the best checkpoint by metric_for_best_model
    metric_for_best_model="f1",  # Key from compute_metrics; "auc" also works (for "eval_loss" set greater_is_better=False)
    greater_is_better=True,

    # Logging, saving and runtime
    push_to_hub=False,  # Set to True if you want to push the model to the HuggingFace Hub
    logging_dir=None,  # None lets TrainingArguments choose its default log directory
    logging_first_step=False,
    logging_steps=200,  # Log every 200 steps
    save_total_limit=4,  # Keep at most 4 checkpoints; with load_best_model_at_end the best one is also retained
    no_cuda=False,  # If True, will not use CUDA even if it's available
    seed=42,  # Seed used by the Trainer; it does not affect the sklearn train_test_split
    fp16=True,  # Mixed-precision training: faster and uses less memory, but may be slightly less accurate
    # dataloader_num_workers=4,  # Number of CPU processes for data loading
)

# Initialize Trainer
# The collator pads input_ids/attention_mask per batch and pads "labels" with -100.
token_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=token_collator,
    compute_metrics=compute_metrics,
)

# 9. Train and Save Model
trainer.train()

# load_best_model_at_end=True is set above, so the trainer now holds the best
# checkpoint by metric_for_best_model.
# NOTE(review): `model` is a PeftModel, so save_model presumably writes the LoRA
# adapter (plus modules_to_save) rather than full weights — confirm before deploying.
run_dir_name = f"best_model_esm2_t6_8M_UR50D_{timestamp}"
save_path = os.path.join("lora_binding_sites", run_dir_name)
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)  # store the tokenizer next to the model
