import os
import gc
import pickle
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm

import wandb
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef,
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)
from accelerate import Accelerator
from peft import (
    get_peft_config,
    PeftModel,
    PeftConfig,
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training,
)

# Define Desired Max Length
MAX_LENGTH = 512

# Initialize accelerator and Weights & Biases
accelerator = Accelerator()
os.environ["WANDB_NOTEBOOK_NAME"] = 'training.py'
wandb.init(project='binding_site_prediction')

# Helper Functions and Data Preparation
#-----------------------------------------------------------------------------

class ProteinDataset(TorchDataset):
    def __init__(self, sequences_path, labels_path, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(sequences_path, "rb") as f:
            self.sequences = pickle.load(f)

        with open(labels_path, "rb") as f:
            self.labels = pickle.load(f)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]

        tokenized = self.tokenizer(
            sequence,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            is_split_into_words=False,
            add_special_tokens=False,
        )

        # Remove the extra batch dimension
        for key in tokenized:
            tokenized[key] = tokenized[key].squeeze(0)

        # Ensure labels are also padded/truncated to match tokenized input
        label_padded = [-100] * self.max_length  # Using -100 as the ignore index
        label_padded[:len(label)] = label[:self.max_length]

        tokenized["labels"] = torch.tensor(label_padded)

        return tokenized
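# Illustrative sanity check (not part of the original pipeline): a small helper,
# assuming the pickled sequences/labels loaded below, that confirms a dataset item
# returns MAX_LENGTH-long input_ids and labels. It is only a sketch and is never
# called during training.
def _check_dataset_item(dataset, idx=0):
    item = dataset[idx]
    assert item["input_ids"].shape[0] == MAX_LENGTH
    assert item["labels"].shape[0] == MAX_LENGTH
    print({key: tuple(value.shape) for key, value in item.items()})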
""" trainable_params = 0 all_param = 0 for _, param in model.named_parameters(): all_param += param.numel() if param.requires_grad: trainable_params += param.numel() print( f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" ) def save_config_to_txt(config, filename): """Save the configuration dictionary to a text file.""" with open(filename, 'w') as f: for key, value in config.items(): f.write(f"{key}: {value}\n") def compute_metrics(p): predictions, labels = p predictions = np.argmax(predictions, axis=2) mask = labels != -100 predictions = predictions[mask].flatten() labels = labels[mask].flatten() accuracy = accuracy_score(labels, predictions) precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary') auc = roc_auc_score(labels, predictions) mcc = matthews_corrcoef(labels, predictions) # Explicitly delete numpy arrays and call the garbage collector del predictions del labels gc.collect() return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc} def compute_loss(model, logits, inputs): labels = inputs["labels"] loss_fct = nn.CrossEntropyLoss(weight=class_weights) active_loss = inputs["attention_mask"].view(-1) == 1 active_logits = logits.view(-1, model.config.num_labels) active_labels = torch.where( active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) ) loss = loss_fct(active_logits, active_labels) return loss tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") train_dataset = ProteinDataset("data/12M_data/512_train_sequences_chunked_by_family.pkl", "data/12M_data/512_train_labels_chunked_by_family.pkl", tokenizer, MAX_LENGTH) # Compute Class Weights # Sample a subset of labels for computing class weights (e.g., 100,000 sequences) SAMPLE_SIZE = 100000 with open("data/12M_data/512_train_labels_chunked_by_family.pkl", "rb") as f: all_train_labels = pickle.load(f) sample_labels = random.sample(all_train_labels, SAMPLE_SIZE) # Flatten the sampled labels flat_sample_labels = [label for sublist in sample_labels for label in sublist] # Compute class weights using the sampled labels classes = [0, 1] class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_sample_labels) class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device) # Define Custom Trainer Class class WeightedTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): outputs = model(**inputs) logits = outputs.logits loss = compute_loss(model, logits, inputs) return (loss, outputs) if return_outputs else loss # Configure the quantization settings bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) def train_function_no_sweeps(train_dataset): # Directly set the config config = { "lora_alpha": 1, "lora_dropout": 0.5, "lr": 1.701568055793089e-04, "lr_scheduler_type": "cosine", "max_grad_norm": 0.5, "num_train_epochs": 1, "per_device_train_batch_size": 200, # "per_device_test_batch_size": 40, "r": 2, "weight_decay": 0.3, # Add other hyperparameters as needed } # Log the config to W&B wandb.config.update(config) # Save the config to a text file timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') config_filename = f"esm2_t33_650M_qlora_config_{timestamp}.txt" save_config_to_txt(config, config_filename) model_checkpoint = "facebook/esm2_t33_650M_UR50D" # Define 
# Define Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        loss = compute_loss(model, logits, inputs)
        return (loss, outputs) if return_outputs else loss

# Configure the quantization settings
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

def train_function_no_sweeps(train_dataset):

    # Directly set the config
    config = {
        "lora_alpha": 1,
        "lora_dropout": 0.5,
        "lr": 1.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 200,
        # "per_device_test_batch_size": 40,
        "r": 2,
        "weight_decay": 0.3,
        # Add other hyperparameters as needed
    }

    # Log the config to W&B
    wandb.config.update(config)

    # Save the config to a text file
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config_filename = f"esm2_t33_650M_qlora_config_{timestamp}.txt"
    save_config_to_txt(config, config_filename)

    model_checkpoint = "facebook/esm2_t33_650M_UR50D"

    # Define labels and model
    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}

    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        quantization_config=bnb_config
    )

    # Prepare the model for 4-bit quantization training
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=[
            "query",
            "key",
            "value",
            "EsmSelfOutput.dense",
            "EsmIntermediate.dense",
            "EsmOutput.dense",
            # "EsmContactPredictionHead.regression",
            "classifier"
        ],
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
        # modules_to_save=["classifier"]
    )
    model = get_peft_model(model, peft_config)
    print_trainable_parameters(model)  # added this in

    # Use the accelerator
    model = accelerator.prepare(model)
    train_dataset = accelerator.prepare(train_dataset)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t33_650M_qlora_binding_sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        # per_device_eval_batch_size=config["per_device_test_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="no",
        save_strategy="steps",  # Save a checkpoint every `save_steps` steps
        save_steps=10000,
        load_best_model_at_end=False,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=100,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        fp16=True,
        report_to='wandb',
        optim="paged_adamw_8bit"  # added this in
    )

    # Initialize Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer)
    )

    # Train and Save Model
    trainer.train()
    save_path = os.path.join("qlora_binding_sites", f"best_model_esm2_t33_650M_qlora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)

# Call the training function
if __name__ == "__main__":
    train_function_no_sweeps(train_dataset)
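# Sketch (not part of the original script): one possible way to reload the saved
# LoRA adapter for inference later, assuming the `save_path` layout produced by
# trainer.save_model above. The adapter path and `<timestamp>` are illustrative
# placeholders, not values the script guarantees.
#
#   from peft import PeftModel
#   base = AutoModelForTokenClassification.from_pretrained(
#       "facebook/esm2_t33_650M_UR50D", num_labels=2
#   )
#   inference_model = PeftModel.from_pretrained(
#       base, "qlora_binding_sites/best_model_esm2_t33_650M_qlora_<timestamp>"
#   )
#   inference_model.eval()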