AmelieSchreiber committed
Commit
216d13a
1 Parent(s): 790f81b

Upload LoRA_binding_sites_no_sweeps_v2.ipynb

Files changed (1)
  1. LoRA_binding_sites_no_sweeps_v2.ipynb +199 -0
LoRA_binding_sites_no_sweeps_v2.ipynb ADDED
@@ -0,0 +1,199 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 1. Imports\n",
+ "import os\n",
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
+ "import numpy as np\n",
+ "import xml.etree.ElementTree as ET\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from datetime import datetime\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.utils.class_weight import compute_class_weight\n",
+ "from sklearn.metrics import precision_recall_fscore_support, roc_auc_score\n",
+ "from transformers import (\n",
+ "    AutoModelForTokenClassification,\n",
+ "    AutoTokenizer,\n",
+ "    DataCollatorForTokenClassification,\n",
+ "    TrainingArguments,\n",
+ "    Trainer,\n",
+ ")\n",
+ "from datasets import Dataset\n",
+ "from accelerate import Accelerator\n",
+ "\n",
+ "# Imports specific to the custom model\n",
+ "from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType\n",
+ "\n",
+ "# 2. Setup Environment Variables and Accelerator\n",
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
+ "accelerator = Accelerator()\n",
+ "\n",
+ "# 3. Helper Functions\n",
+ "def convert_binding_string_to_labels(binding_string):\n",
+ "    \"\"\"Convert 'proBnd' strings into label arrays.\"\"\"\n",
+ "    return [1 if char == '+' else 0 for char in binding_string]\n",
+ "\n",
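+ "# For example, a proBnd string such as \"-++-\" maps to [0, 1, 1, 0]:\n",
+ "# every '+' marks a binding-site residue, everything else is a non-site\n",
+ "\n",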
+ "def truncate_labels(labels, max_length):\n",
+ "    \"\"\"Truncate labels to the specified max_length.\"\"\"\n",
+ "    return [label[:max_length] for label in labels]\n",
+ "\n",
+ "def compute_metrics(p):\n",
+ "    \"\"\"Compute metrics for evaluation.\"\"\"\n",
+ "    predictions, labels = p\n",
+ "    predictions = np.argmax(predictions, axis=2)\n",
+ "    # Keep only the positions that carry a real label (padding is marked -100)\n",
+ "    predictions = predictions[labels != -100].flatten()\n",
+ "    labels = labels[labels != -100].flatten()\n",
+ "    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')\n",
+ "    # Note: the AUC below is computed from hard argmax predictions, not class probabilities\n",
+ "    auc = roc_auc_score(labels, predictions)\n",
+ "    return {'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}\n",
+ "\n",
+ "def compute_loss(outputs, inputs, num_labels):\n",
+ "    \"\"\"Class-weighted cross-entropy over the non-padding tokens of a batch.\"\"\"\n",
+ "    logits = outputs.logits\n",
+ "    labels = inputs[\"labels\"]\n",
+ "    loss_fct = nn.CrossEntropyLoss(weight=class_weights)\n",
+ "    active_loss = inputs[\"attention_mask\"].view(-1) == 1\n",
+ "    active_logits = logits.view(-1, num_labels)\n",
+ "    active_labels = torch.where(\n",
+ "        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n",
+ "    )\n",
+ "    loss = loss_fct(active_logits, active_labels)\n",
+ "    return loss\n",
+ "\n",
+ "# 4. Parse XML and Extract Data\n",
+ "tree = ET.parse('binding_sites.xml')\n",
+ "root = tree.getroot()\n",
+ "all_sequences = [partner.find(\".//proSeq\").text for partner in root.findall(\".//BindPartner\")]\n",
+ "all_labels = [convert_binding_string_to_labels(partner.find(\".//proBnd\").text) for partner in root.findall(\".//BindPartner\")]\n",
+ "\n",
+ "# 5. Data Splitting and Tokenization\n",
+ "train_sequences, test_sequences, train_labels, test_labels = train_test_split(all_sequences, all_labels, test_size=0.20, shuffle=True)\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
+ "max_sequence_length = 1291\n",
+ "train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n",
+ "test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n",
+ "\n",
+ "train_labels = truncate_labels(train_labels, max_sequence_length)\n",
+ "test_labels = truncate_labels(test_labels, max_sequence_length)\n",
+ "train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column(\"labels\", train_labels)\n",
+ "test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column(\"labels\", test_labels)\n",
+ "\n",
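+ "# Optional sanity check: the split should keep sequences and labels paired one-to-one\n",
+ "assert len(train_sequences) == len(train_labels) and len(test_sequences) == len(test_labels)\n",
+ "\n",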
+ "# 6. Compute Class Weights\n",
+ "classes = [0, 1]\n",
+ "flat_train_labels = [label for sublist in train_labels for label in sublist]\n",
+ "class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)\n",
+ "class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)\n",
+ "\n",
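+ "# The 'balanced' weighting upweights the rarer binding-site class; printing the\n",
+ "# weights gives a quick read on how imbalanced the training labels are\n",
+ "print(f\"class weights (non-site, site): {class_weights.tolist()}\")\n",
+ "\n",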
+ "# 7. Define Custom Trainer Class\n",
+ "class WeightedTrainer(Trainer):\n",
+ "    def compute_loss(self, model, inputs, return_outputs=False):\n",
+ "        # Run the forward pass once and reuse it for the weighted loss\n",
+ "        outputs = model(**inputs)\n",
+ "        loss = compute_loss(outputs, inputs, model.config.num_labels)\n",
+ "        return (loss, outputs) if return_outputs else loss\n",
+ "\n",
+ "# 8. Training Setup\n",
+ "model_checkpoint = \"facebook/esm2_t6_8M_UR50D\"\n",
+ "lr = 0.0005437551839696541\n",
+ "batch_size = 4\n",
+ "num_epochs = 15\n",
+ "\n",
+ "# Define labels and model\n",
+ "id2label = {0: \"No binding site\", 1: \"Binding site\"}\n",
+ "label2id = {v: k for k, v in id2label.items()}\n",
+ "model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)\n",
+ "\n",
+ "# Convert the model into a PeftModel\n",
+ "peft_config = LoraConfig(\n",
+ "    task_type=TaskType.TOKEN_CLS,\n",
+ "    inference_mode=False,\n",
+ "    r=16,\n",
+ "    lora_alpha=16,\n",
+ "    target_modules=[\"query\", \"key\", \"value\"],\n",
+ "    lora_dropout=0.1,\n",
+ "    bias=\"all\"\n",
+ ")\n",
+ "model = get_peft_model(model, peft_config)\n",
+ "\n",
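+ "# Only a small fraction of the weights stays trainable after the LoRA wrapping;\n",
+ "# print_trainable_parameters() reports exactly how many\n",
+ "model.print_trainable_parameters()\n",
+ "\n",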
+ "# Use the accelerator\n",
+ "model = accelerator.prepare(model)\n",
+ "train_dataset = accelerator.prepare(train_dataset)\n",
+ "test_dataset = accelerator.prepare(test_dataset)\n",
+ "\n",
+ "timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n",
+ "# Training setup\n",
+ "training_args = TrainingArguments(\n",
+ "    output_dir=f\"esm2_t6_8M-lora-binding-site-classification_{timestamp}\",\n",
+ "    learning_rate=lr,\n",
+ "\n",
+ "    # Learning Rate Scheduling\n",
+ "    lr_scheduler_type=\"linear\",\n",
+ "    warmup_steps=500,  # Number of warm-up steps; adjust based on your observations\n",
+ "\n",
+ "    # Gradient Clipping\n",
+ "    gradient_accumulation_steps=1,\n",
+ "    max_grad_norm=1.0,  # Common value, but can be adjusted based on your observations\n",
+ "\n",
+ "    # Batch Size\n",
+ "    per_device_train_batch_size=batch_size,\n",
+ "    per_device_eval_batch_size=batch_size,\n",
+ "\n",
+ "    # Number of Epochs\n",
+ "    num_train_epochs=num_epochs,\n",
+ "\n",
+ "    # Weight Decay\n",
+ "    weight_decay=0.025,  # Adjust this value based on your observations, e.g., 0.01 or 0.05\n",
+ "\n",
+ "    # Evaluation and Model Selection\n",
+ "    evaluation_strategy=\"epoch\",\n",
+ "    save_strategy=\"epoch\",\n",
+ "    load_best_model_at_end=True,\n",
+ "    metric_for_best_model=\"f1\",  # \"eval_loss\" or \"eval_auc\" also work; flip greater_is_better for losses\n",
+ "    greater_is_better=True,\n",
+ "    # Early stopping is not a TrainingArguments option; pass EarlyStoppingCallback(early_stopping_patience=4)\n",
+ "    # to the Trainer's callbacks to stop after 4 evaluations without improvement\n",
+ "\n",
+ "    # Additional default arguments\n",
+ "    push_to_hub=False,  # Set to True if you want to push the model to the HuggingFace Hub\n",
+ "    logging_dir=None,  # Directory for storing logs\n",
+ "    logging_first_step=False,\n",
+ "    logging_steps=200,  # Log every 200 steps\n",
+ "    save_total_limit=4,  # Only the last 4 checkpoints are kept, which saves disk space\n",
+ "    no_cuda=False,  # If True, will not use CUDA even if it's available\n",
+ "    seed=42,  # Random seed for reproducibility\n",
+ "    fp16=True,  # Half-precision training: faster and lighter on memory, at a small cost in numerical precision\n",
+ "    # dataloader_num_workers=4,  # Number of CPU processes for data loading\n",
+ ")\n",
+ "\n",
+ "# Initialize Trainer\n",
+ "trainer = WeightedTrainer(\n",
+ "    model=model,\n",
+ "    args=training_args,\n",
+ "    train_dataset=train_dataset,\n",
+ "    eval_dataset=test_dataset,\n",
+ "    tokenizer=tokenizer,\n",
+ "    data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),\n",
+ "    compute_metrics=compute_metrics\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# 9. Train and Save Model\n",
+ "trainer.train()\n",
+ "save_path = os.path.join(\"lora_binding_sites\", f\"best_model_esm2_t6_8M_UR50D_{timestamp}\")\n",
+ "trainer.save_model(save_path)\n",
+ "tokenizer.save_pretrained(save_path)\n",
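+ "\n",
+ "# To reuse the fine-tuned adapter later, it can be reloaded on top of the base model,\n",
+ "# e.g. (sketch, assuming the save_path written above):\n",
+ "# base = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)\n",
+ "# lora_model = PeftModel.from_pretrained(base, save_path)\n"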
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }