import os
import wandb
import numpy as np
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    precision_recall_fscore_support,
    roc_auc_score,
    accuracy_score,
    matthews_corrcoef,
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset
from accelerate import Accelerator
import pickle

# Initialize Weights & Biases logging
os.environ["WANDB_NOTEBOOK_NAME"] = 'esm2_t6_8M_finetune_600K.ipynb'
wandb.init(project='binding_site_prediction')


# Helper Functions
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]


def compute_metrics(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Mask out padded positions (label -100) before scoring
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}


def compute_loss(model, inputs):
    """Custom compute_loss function with class weighting."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss,
        labels.view(-1),
        torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss


# Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = compute_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss


# Load data
with open("600K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("600K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("600K_data/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("600K_data/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
max_sequence_length = 1000

train_tokenized = tokenizer(train_sequences, padding=True, truncation=True,
                            max_length=max_sequence_length, return_tensors="pt",
                            is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True,
                           max_length=max_sequence_length, return_tensors="pt",
                           is_split_into_words=False)

train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
accelerator = Accelerator()
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)


# Training Function
def train_function_no_sweeps(train_dataset, test_dataset):

    # Initialize wandb
    wandb.init()

    # Configurations
    config = {
        "lr": 5.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 12,
        "weight_decay": 0.2,
    }

    # Model Setup
    model_checkpoint = "facebook/esm2_t6_8M_UR50D"
    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}
    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        hidden_dropout_prob=0.5,           # hidden-layer dropout
        attention_probs_dropout_prob=0.5,  # attention dropout
    )
    model = accelerator.prepare(model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t6_8M_finetune_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        no_cuda=False,
        seed=42,
        fp16=True,
        report_to='wandb',
    )

    # Initialize Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    # Train and Save Model
    trainer.train()
    save_path = os.path.join("binding_sites", f"best_model_esm2_t6_8M_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)


# Call the training function
if __name__ == "__main__":
    train_function_no_sweeps(train_dataset, test_dataset)
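
# -----------------------------------------------------------------------------
# Optional usage sketch (not part of the training run above): load the saved
# checkpoint and predict per-residue binding sites for a single sequence.
# The model directory and the example sequence are placeholders; point them at
# the directory written by trainer.save_model() and a real protein sequence.
# -----------------------------------------------------------------------------
# def predict_binding_sites(sequence, model_dir):
#     tok = AutoTokenizer.from_pretrained(model_dir)
#     mdl = AutoModelForTokenClassification.from_pretrained(model_dir)
#     inputs = tok(sequence, return_tensors="pt", truncation=True, max_length=max_sequence_length)
#     with torch.no_grad():
#         logits = mdl(**inputs).logits
#     # Class index per token: 0 = "No binding site", 1 = "Binding site"
#     return logits.argmax(dim=-1).squeeze(0).tolist()
#
# predict_binding_sites("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
#                       "binding_sites/best_model_esm2_t6_8M_<timestamp>")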