{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 1. Imports\n", "import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "import numpy as np\n", "import xml.etree.ElementTree as ET\n", "import torch\n", "import torch.nn as nn\n", "from datetime import datetime\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.utils.class_weight import compute_class_weight\n", "from sklearn.metrics import precision_recall_fscore_support, roc_auc_score\n", "from transformers import (\n", " AutoModelForTokenClassification,\n", " AutoTokenizer,\n", " DataCollatorForTokenClassification,\n", " TrainingArguments,\n", " Trainer,\n", ")\n", "from datasets import Dataset\n", "from accelerate import Accelerator\n", "\n", "# Imports specific to the custom model\n", "from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType\n", "\n", "# 2. Setup Environment Variables and Accelerator\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "accelerator = Accelerator()\n", "\n", "# 3. Helper Functions\n", "def convert_binding_string_to_labels(binding_string):\n", " \"\"\"Convert 'proBnd' strings into label arrays.\"\"\"\n", " return [1 if char == '+' else 0 for char in binding_string]\n", "\n", "def truncate_labels(labels, max_length):\n", " \"\"\"Truncate labels to the specified max_length.\"\"\"\n", " return [label[:max_length] for label in labels]\n", "\n", "def compute_metrics(p):\n", " \"\"\"Compute metrics for evaluation.\"\"\"\n", " predictions, labels = p\n", " predictions = np.argmax(predictions, axis=2)\n", " predictions = predictions[labels != -100].flatten()\n", " labels = labels[labels != -100].flatten()\n", " precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')\n", " auc = roc_auc_score(labels, predictions)\n", " return {'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}\n", "\n", "def compute_loss(model, inputs):\n", " \"\"\"Custom compute_loss function.\"\"\"\n", " logits = model(**inputs).logits\n", " labels = inputs[\"labels\"]\n", " loss_fct = nn.CrossEntropyLoss(weight=class_weights)\n", " active_loss = inputs[\"attention_mask\"].view(-1) == 1\n", " active_logits = logits.view(-1, model.config.num_labels)\n", " active_labels = torch.where(\n", " active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n", " )\n", " loss = loss_fct(active_logits, active_labels)\n", " return loss\n", "\n", "# 4. Parse XML and Extract Data\n", "tree = ET.parse('binding_sites.xml')\n", "root = tree.getroot()\n", "all_sequences = [partner.find(\".//proSeq\").text for partner in root.findall(\".//BindPartner\")]\n", "all_labels = [convert_binding_string_to_labels(partner.find(\".//proBnd\").text) for partner in root.findall(\".//BindPartner\")]\n", "\n", "# 5. Data Splitting and Tokenization\n", "train_sequences, test_sequences, train_labels, test_labels = train_test_split(all_sequences, all_labels, test_size=0.20, shuffle=True)\n", "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n", "max_sequence_length = 1291\n", "train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n", "test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n", "\n", "train_labels = truncate_labels(train_labels, max_sequence_length)\n", "test_labels = truncate_labels(test_labels, max_sequence_length)\n", "train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column(\"labels\", train_labels)\n", "test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column(\"labels\", test_labels)\n", "\n", "# 6. Compute Class Weights\n", "classes = [0, 1] \n", "flat_train_labels = [label for sublist in train_labels for label in sublist]\n", "class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)\n", "class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)\n", "\n", "# 7. Define Custom Trainer Class\n", "class WeightedTrainer(Trainer):\n", " def compute_loss(self, model, inputs, return_outputs=False):\n", " outputs = model(**inputs)\n", " loss = compute_loss(model, inputs)\n", " return (loss, outputs) if return_outputs else loss\n", "\n", "# 8. Training Setup\n", "model_checkpoint = \"facebook/esm2_t6_8M_UR50D\"\n", "lr = 0.0005437551839696541\n", "batch_size = 4\n", "num_epochs = 15\n", "\n", "# Define labels and model\n", "id2label = {0: \"No binding site\", 1: \"Binding site\"}\n", "label2id = {v: k for k, v in id2label.items()}\n", "model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)\n", "\n", "# Convert the model into a PeftModel\n", "peft_config = LoraConfig(\n", " task_type=TaskType.TOKEN_CLS, \n", " inference_mode=False, \n", " r=16, \n", " lora_alpha=16, \n", " target_modules=[\"query\", \"key\", \"value\"],\n", " lora_dropout=0.1, \n", " bias=\"all\"\n", ")\n", "model = get_peft_model(model, peft_config)\n", "\n", "# Use the accelerator\n", "model = accelerator.prepare(model)\n", "train_dataset = accelerator.prepare(train_dataset)\n", "test_dataset = accelerator.prepare(test_dataset)\n", "\n", "timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n", "# Training setup\n", "training_args = TrainingArguments(\n", " output_dir=f\"esm2_t6_8M-lora-binding-site-classification_{timestamp}\",\n", " learning_rate=lr,\n", " \n", " # Learning Rate Scheduling\n", " lr_scheduler_type=\"linear\",\n", " warmup_steps=500, # Number of warm-up steps; adjust based on your observations\n", " \n", " # Gradient Clipping\n", " gradient_accumulation_steps=1,\n", " max_grad_norm=1.0, # Common value, but can be adjusted based on your observations\n", " \n", " # Batch Size\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " \n", " # Number of Epochs\n", " num_train_epochs=num_epochs,\n", " \n", " # Weight Decay\n", " weight_decay=0.025, # Adjust this value based on your observations, e.g., 0.01 or 0.05\n", " \n", " # Early Stopping\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"f1\", # You can also use \"eval_loss\" or \"eval_auc\" based on your preference\n", " greater_is_better=True,\n", " # early_stopping_patience=4, # Stops after 3 evaluations without improvement\n", " \n", " # Additional default arguments\n", " push_to_hub=False, # Set to True if you want to push the model to the HuggingFace Hub\n", " logging_dir=None, # Directory for storing logs\n", " logging_first_step=False,\n", " logging_steps=200, # Log every 200 steps\n", " save_total_limit=4, # Only the last 4 models are saved. Helps in saving disk space.\n", " no_cuda=False, # If True, will not use CUDA even if it's available\n", " seed=42, # Random seed for reproducibility\n", " fp16=True, # If True, uses half precision for training, which is faster and requires less memory but might be less accurate\n", " # dataloader_num_workers=4, # Number of CPU processes for data loading\n", ")\n", "\n", "# Initialize Trainer\n", "trainer = WeightedTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=test_dataset,\n", " tokenizer=tokenizer,\n", " data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),\n", " compute_metrics=compute_metrics\n", ")\n", "\n", "# 9. Train and Save Model\n", "trainer.train()\n", "save_path = os.path.join(\"lora_binding_sites\", f\"best_model_esm2_t6_8M_UR50D_{timestamp}\")\n", "trainer.save_model(save_path)\n", "tokenizer.save_pretrained(save_path)\n" ] } ], "metadata": { "language_info": { "name": "python" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }