{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Imports\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "import numpy as np\n",
    "import xml.etree.ElementTree as ET\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from datetime import datetime\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "from sklearn.metrics import precision_recall_fscore_support, roc_auc_score\n",
    "from transformers import (\n",
    "    AutoModelForTokenClassification,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForTokenClassification,\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    ")\n",
    "from datasets import Dataset\n",
    "from accelerate import Accelerator\n",
    "\n",
    "# Imports specific to the custom model\n",
    "from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType\n",
    "\n",
    "# 2. Setup Environment Variables and Accelerator\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "accelerator = Accelerator()\n",
    "\n",
    "# 3. Helper Functions\n",
    "def convert_binding_string_to_labels(binding_string):\n",
    "    \"\"\"Convert 'proBnd' strings into label arrays.\"\"\"\n",
    "    return [1 if char == '+' else 0 for char in binding_string]\n",
    "\n",
    "def truncate_labels(labels, max_length):\n",
    "    \"\"\"Truncate labels to the specified max_length.\"\"\"\n",
    "    return [label[:max_length] for label in labels]\n",
    "\n",
    "def compute_metrics(p):\n",
    "    \"\"\"Compute metrics for evaluation.\"\"\"\n",
    "    predictions, labels = p\n",
    "    predictions = np.argmax(predictions, axis=2)\n",
    "    predictions = predictions[labels != -100].flatten()\n",
    "    labels = labels[labels != -100].flatten()\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')\n",
    "    auc = roc_auc_score(labels, predictions)\n",
    "    return {'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}\n",
    "\n",
    "def compute_loss(model, inputs):\n",
    "    \"\"\"Custom compute_loss function.\"\"\"\n",
    "    logits = model(**inputs).logits\n",
    "    labels = inputs[\"labels\"]\n",
    "    loss_fct = nn.CrossEntropyLoss(weight=class_weights)\n",
    "    active_loss = inputs[\"attention_mask\"].view(-1) == 1\n",
    "    active_logits = logits.view(-1, model.config.num_labels)\n",
    "    active_labels = torch.where(\n",
    "        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n",
    "    )\n",
    "    loss = loss_fct(active_logits, active_labels)\n",
    "    return loss\n",
    "\n",
    "# 4. Parse XML and Extract Data\n",
    "tree = ET.parse('binding_sites.xml')\n",
    "root = tree.getroot()\n",
    "all_sequences = [partner.find(\".//proSeq\").text for partner in root.findall(\".//BindPartner\")]\n",
    "all_labels = [convert_binding_string_to_labels(partner.find(\".//proBnd\").text) for partner in root.findall(\".//BindPartner\")]\n",
    "\n",
    "# 5. Data Splitting and Tokenization\n",
    "train_sequences, test_sequences, train_labels, test_labels = train_test_split(all_sequences, all_labels, test_size=0.20, shuffle=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
    "max_sequence_length = 1291\n",
    "train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n",
    "test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n",
    "\n",
    "train_labels = truncate_labels(train_labels, max_sequence_length)\n",
    "test_labels = truncate_labels(test_labels, max_sequence_length)\n",
    "train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column(\"labels\", train_labels)\n",
    "test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column(\"labels\", test_labels)\n",
    "\n",
    "# 6. Compute Class Weights\n",
    "classes = [0, 1]  \n",
    "flat_train_labels = [label for sublist in train_labels for label in sublist]\n",
    "class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)\n",
    "class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)\n",
    "\n",
    "# 7. Define Custom Trainer Class\n",
    "class WeightedTrainer(Trainer):\n",
    "    def compute_loss(self, model, inputs, return_outputs=False):\n",
    "        outputs = model(**inputs)\n",
    "        loss = compute_loss(model, inputs)\n",
    "        return (loss, outputs) if return_outputs else loss\n",
    "\n",
    "# 8. Training Setup\n",
    "model_checkpoint = \"facebook/esm2_t6_8M_UR50D\"\n",
    "lr = 0.0005437551839696541\n",
    "batch_size = 4\n",
    "num_epochs = 15\n",
    "\n",
    "# Define labels and model\n",
    "id2label = {0: \"No binding site\", 1: \"Binding site\"}\n",
    "label2id = {v: k for k, v in id2label.items()}\n",
    "model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)\n",
    "\n",
    "# Convert the model into a PeftModel\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.TOKEN_CLS, \n",
    "    inference_mode=False, \n",
    "    r=16, \n",
    "    lora_alpha=16, \n",
    "    target_modules=[\"query\", \"key\", \"value\"],\n",
    "    lora_dropout=0.1, \n",
    "    bias=\"all\"\n",
    ")\n",
    "model = get_peft_model(model, peft_config)\n",
    "\n",
    "# Use the accelerator\n",
    "model = accelerator.prepare(model)\n",
    "train_dataset = accelerator.prepare(train_dataset)\n",
    "test_dataset = accelerator.prepare(test_dataset)\n",
    "\n",
    "timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n",
    "# Training setup\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"esm2_t6_8M-lora-binding-site-classification_{timestamp}\",\n",
    "    learning_rate=lr,\n",
    "    \n",
    "    # Learning Rate Scheduling\n",
    "    lr_scheduler_type=\"linear\",\n",
    "    warmup_steps=500,   # Number of warm-up steps; adjust based on your observations\n",
    "    \n",
    "    # Gradient Clipping\n",
    "    gradient_accumulation_steps=1,\n",
    "    max_grad_norm=1.0,   # Common value, but can be adjusted based on your observations\n",
    "    \n",
    "    # Batch Size\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    \n",
    "    # Number of Epochs\n",
    "    num_train_epochs=num_epochs,\n",
    "    \n",
    "    # Weight Decay\n",
    "    weight_decay=0.025, # Adjust this value based on your observations, e.g., 0.01 or 0.05\n",
    "    \n",
    "    # Early Stopping\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"f1\",   # You can also use \"eval_loss\" or \"eval_auc\" based on your preference\n",
    "    greater_is_better=True,\n",
    "    # early_stopping_patience=4,   # Stops after 3 evaluations without improvement\n",
    "    \n",
    "    # Additional default arguments\n",
    "    push_to_hub=False,   # Set to True if you want to push the model to the HuggingFace Hub\n",
    "    logging_dir=None,    # Directory for storing logs\n",
    "    logging_first_step=False,\n",
    "    logging_steps=200,   # Log every 200 steps\n",
    "    save_total_limit=4,  # Only the last 4 models are saved. Helps in saving disk space.\n",
    "    no_cuda=False,       # If True, will not use CUDA even if it's available\n",
    "    seed=42,             # Random seed for reproducibility\n",
    "    fp16=True,           # If True, uses half precision for training, which is faster and requires less memory but might be less accurate\n",
    "    # dataloader_num_workers=4,  # Number of CPU processes for data loading\n",
    ")\n",
    "\n",
    "# Initialize Trainer\n",
    "trainer = WeightedTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=test_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),\n",
    "    compute_metrics=compute_metrics\n",
    ")\n",
    "\n",
    "# 9. Train and Save Model\n",
    "trainer.train()\n",
    "save_path = os.path.join(\"lora_binding_sites\", f\"best_model_esm2_t6_8M_UR50D_{timestamp}\")\n",
    "trainer.save_model(save_path)\n",
    "tokenizer.save_pretrained(save_path)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}