AmelieSchreiber committed on
Commit
bf2d7db
1 Parent(s): 33f746c

Upload qlora_eff_load_train_only.py

Files changed (1)
  1. qlora_eff_load_train_only.py +272 -0
qlora_eff_load_train_only.py ADDED
@@ -0,0 +1,272 @@
+ import os
+ import wandb
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset as TorchDataset
+ from datetime import datetime
+ import random
+ from sklearn.utils.class_weight import compute_class_weight
+ from sklearn.metrics import (
+     accuracy_score,
+     precision_recall_fscore_support,
+     roc_auc_score,
+     matthews_corrcoef,
+ )  # used by compute_metrics below
+ from transformers import (
+     AutoModelForTokenClassification,
+     AutoTokenizer,
+     DataCollatorForTokenClassification,
+     TrainingArguments,
+     Trainer,
+     BitsAndBytesConfig
+ )
+ from accelerate import Accelerator
+ from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
+ import pickle
+ import gc
+ from tqdm import tqdm
+
+
+ # Define Desired Max Length
+ MAX_LENGTH = 512
+
+ # Initialize accelerator and Weights & Biases
+ accelerator = Accelerator()
+ os.environ["WANDB_NOTEBOOK_NAME"] = 'training.py'
+ wandb.init(project='binding_site_prediction')
+
+ # Helper Functions and Data Preparation
+ # -----------------------------------------------------------------------------
+
+ class ProteinDataset(TorchDataset):
+     def __init__(self, sequences_path, labels_path, tokenizer, max_length):
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+
+         with open(sequences_path, "rb") as f:
+             self.sequences = pickle.load(f)
+
+         with open(labels_path, "rb") as f:
+             self.labels = pickle.load(f)
+
+     def __len__(self):
+         return len(self.sequences)
+
+     def __getitem__(self, idx):
+         sequence = self.sequences[idx]
+         label = self.labels[idx]
+
+         tokenized = self.tokenizer(sequence, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt", is_split_into_words=False, add_special_tokens=False)
+
+         # Remove the extra batch dimension
+         for key in tokenized:
+             tokenized[key] = tokenized[key].squeeze(0)
+
+         # Ensure labels are also padded/truncated to match the tokenized input
+         label_padded = [-100] * self.max_length  # Using -100 as the ignore index
+         label_padded[:len(label)] = label[:self.max_length]
+
+         tokenized["labels"] = torch.tensor(label_padded)
+
+         return tokenized
+
+ def print_trainable_parameters(model):
+     """
+     Prints the number of trainable parameters in the model.
+     """
+     trainable_params = 0
+     all_param = 0
+     for _, param in model.named_parameters():
+         all_param += param.numel()
+         if param.requires_grad:
+             trainable_params += param.numel()
+     print(
+         f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+     )
+
+ def save_config_to_txt(config, filename):
+     """Save the configuration dictionary to a text file."""
+     with open(filename, 'w') as f:
+         for key, value in config.items():
+             f.write(f"{key}: {value}\n")
+
+ def compute_metrics(p):
+     predictions, labels = p
+     predictions = np.argmax(predictions, axis=2)
+     mask = labels != -100
+     predictions = predictions[mask].flatten()
+     labels = labels[mask].flatten()
+
+     accuracy = accuracy_score(labels, predictions)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
+     auc = roc_auc_score(labels, predictions)
+     mcc = matthews_corrcoef(labels, predictions)
+
+     # Explicitly delete numpy arrays and call the garbage collector
+     del predictions
+     del labels
+     gc.collect()
+
+     return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
+
+ def compute_loss(model, logits, inputs):
+     labels = inputs["labels"]
+     loss_fct = nn.CrossEntropyLoss(weight=class_weights)
+     active_loss = inputs["attention_mask"].view(-1) == 1
+     active_logits = logits.view(-1, model.config.num_labels)
+     active_labels = torch.where(
+         active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+     )
+     loss = loss_fct(active_logits, active_labels)
+     return loss
+
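+ # Note: compute_loss above relies on the global class_weights (computed below) to up-weight
+ # the rare positive class, and maps padded positions (attention_mask == 0) to the loss
+ # function's ignore_index (-100) so padding does not contribute to the weighted loss.
+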
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ train_dataset = ProteinDataset("data/12M_data/512_train_sequences_chunked_by_family.pkl", "data/12M_data/512_train_labels_chunked_by_family.pkl", tokenizer, MAX_LENGTH)
+
+ # Compute Class Weights
+ # Sample a subset of labels for computing class weights (e.g., 100,000 sequences)
+ SAMPLE_SIZE = 100000
+
+ with open("data/12M_data/512_train_labels_chunked_by_family.pkl", "rb") as f:
+     all_train_labels = pickle.load(f)
+
+ sample_labels = random.sample(all_train_labels, SAMPLE_SIZE)
+
+ # Flatten the sampled labels
+ flat_sample_labels = [label for sublist in sample_labels for label in sublist]
+
+ # Compute class weights using the sampled labels
+ classes = [0, 1]
+ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_sample_labels)
+ class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
+
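+ # With the 'balanced' heuristic, each class weight is n_samples / (n_classes * class_count),
+ # so the rarer "Binding site" class (label 1) receives the larger weight in the loss.
+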
+ # Define Custom Trainer Class
+ class WeightedTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False):
+         outputs = model(**inputs)
+         logits = outputs.logits
+         loss = compute_loss(model, logits, inputs)
+         return (loss, outputs) if return_outputs else loss
+
+ # Configure the quantization settings
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
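+ # The quantization config above stores the ESM-2 base weights in 4-bit NF4 form (with
+ # double quantization of the quantization constants), while matrix multiplications are
+ # carried out in bfloat16 via bnb_4bit_compute_dtype.
+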
+ def train_function_no_sweeps(train_dataset):
+
+     # Directly set the config
+     config = {
+         "lora_alpha": 1,
+         "lora_dropout": 0.5,
+         "lr": 1.701568055793089e-04,
+         "lr_scheduler_type": "cosine",
+         "max_grad_norm": 0.5,
+         "num_train_epochs": 1,
+         "per_device_train_batch_size": 200,
+         # "per_device_test_batch_size": 40,
+         "r": 2,
+         "weight_decay": 0.3,
+         # Add other hyperparameters as needed
+     }
+
+     # Log the config to W&B
+     wandb.config.update(config)
+
+     # Save the config to a text file
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+     config_filename = f"esm2_t33_650M_qlora_config_{timestamp}.txt"
+     save_config_to_txt(config, config_filename)
+
+     model_checkpoint = "facebook/esm2_t33_650M_UR50D"
+
+     # Define labels and model
+     id2label = {0: "No binding site", 1: "Binding site"}
+     label2id = {v: k for k, v in id2label.items()}
+
+     model = AutoModelForTokenClassification.from_pretrained(
+         model_checkpoint,
+         num_labels=len(id2label),
+         id2label=id2label,
+         label2id=label2id,
+         quantization_config=bnb_config
+     )
+
+     # Prepare the model for 4-bit quantization training
+     model.gradient_checkpointing_enable()
+     model = prepare_model_for_kbit_training(model)
+
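+     # Note: prepare_model_for_kbit_training (PEFT) typically freezes the quantized base
+     # weights, upcasts layer norms and other non-quantized parameters to float32 for
+     # stability, and enables input gradients so gradient checkpointing can backpropagate
+     # through the frozen 4-bit backbone.
+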
+     # Convert the model into a PeftModel
+     peft_config = LoraConfig(
+         task_type=TaskType.TOKEN_CLS,
+         inference_mode=False,
+         r=config["r"],
+         lora_alpha=config["lora_alpha"],
+         target_modules=[
+             "query",
+             "key",
+             "value",
+             "EsmSelfOutput.dense",
+             "EsmIntermediate.dense",
+             "EsmOutput.dense",
+             # "EsmContactPredictionHead.regression",
+             "classifier"
+         ],
+         lora_dropout=config["lora_dropout"],
+         bias="none",  # or "all" or "lora_only"
+         # modules_to_save=["classifier"]
+     )
+     model = get_peft_model(model, peft_config)
+     print_trainable_parameters(model)  # added this in
+
+     # Use the accelerator
+     model = accelerator.prepare(model)
+     train_dataset = accelerator.prepare(train_dataset)
+
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+     # Training setup
+     training_args = TrainingArguments(
+         output_dir=f"esm2_t33_650M_qlora_binding_sites_{timestamp}",
+         learning_rate=config["lr"],
+         lr_scheduler_type=config["lr_scheduler_type"],
+         gradient_accumulation_steps=1,
+         max_grad_norm=config["max_grad_norm"],
+         per_device_train_batch_size=config["per_device_train_batch_size"],
+         # per_device_eval_batch_size=config["per_device_test_batch_size"],
+         num_train_epochs=config["num_train_epochs"],
+         weight_decay=config["weight_decay"],
+         evaluation_strategy="no",
+         save_strategy="steps",  # Save checkpoints by steps rather than at epoch ends
+         save_steps=10000,  # Save a checkpoint every 10000 steps
+         load_best_model_at_end=False,
+         metric_for_best_model="f1",
+         greater_is_better=True,
+         push_to_hub=False,
+         logging_dir=None,
+         logging_first_step=False,
+         logging_steps=100,
+         save_total_limit=7,
+         no_cuda=False,
+         seed=8893,
+         fp16=True,
+         report_to='wandb',
+         optim="paged_adamw_8bit"  # added this in
+     )
+
+     # Initialize Trainer
+     trainer = WeightedTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         tokenizer=tokenizer,
+         data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer)
+     )
+
+     # Train and Save Model
+     trainer.train()
+     save_path = os.path.join("qlora_binding_sites", f"best_model_esm2_t33_650M_qlora_{timestamp}")
+     trainer.save_model(save_path)
+     tokenizer.save_pretrained(save_path)
+
+ # Call the training function
+ if __name__ == "__main__":
+     train_function_no_sweeps(train_dataset)
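For reference, a minimal inference sketch (not part of the committed file), assuming the adapter and tokenizer were saved to the save_path produced by trainer.save_model above; the timestamped directory name is illustrative:

    import torch
    from transformers import AutoModelForTokenClassification, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel

    adapter_path = "qlora_binding_sites/best_model_esm2_t33_650M_qlora_<timestamp>"  # illustrative

    # Rebuild the 4-bit quantized ESM-2 base model the adapter was trained on
    base_model = AutoModelForTokenClassification.from_pretrained(
        "facebook/esm2_t33_650M_UR50D",
        num_labels=2,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )

    # Attach the trained LoRA adapter and the saved tokenizer, then switch to eval mode
    model = PeftModel.from_pretrained(base_model, adapter_path)
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    model.eval()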