AmelieSchreiber committed
Commit
a928b59
1 Parent(s): d24656f

Upload qlora_train_v4.py

Files changed (1)
  1. qlora_train_v4.py +336 -0
qlora_train_v4.py ADDED
@@ -0,0 +1,336 @@
+ import os
+ import wandb
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from datetime import datetime
+ from sklearn.utils.class_weight import compute_class_weight
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
+ from transformers import (
+     AutoModelForTokenClassification,
+     AutoTokenizer,
+     DataCollatorForTokenClassification,
+     TrainingArguments,
+     Trainer,
+     BitsAndBytesConfig,
+     default_data_collator
+ )
+ from torch.utils.data import Dataset as TorchDataset
+ from accelerate import Accelerator
+ from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
+ import pickle
+ import gc
+ from tqdm import tqdm
+
+ # Initialize accelerator and Weights & Biases
+ accelerator = Accelerator()
+ os.environ["WANDB_NOTEBOOK_NAME"] = 'qlora_train_v4.py'
+ wandb.init(project='binding_site_prediction')
+
+ # Helper Functions and Data Preparation
+ # -----------------------------------------------------------------------------
+
+ def print_trainable_parameters(model):
+     """
+     Prints the number of trainable parameters in the model.
+     """
+     trainable_params = 0
+     all_param = 0
+     for _, param in model.named_parameters():
+         all_param += param.numel()
+         if param.requires_grad:
+             trainable_params += param.numel()
+     print(
+         f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+     )
+
+ def save_config_to_txt(config, filename):
+     """Save the configuration dictionary to a text file."""
+     with open(filename, 'w') as f:
+         for key, value in config.items():
+             f.write(f"{key}: {value}\n")
+
+ def truncate_labels(labels, max_length):
+     return [label[:max_length] for label in labels]
+
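+ # compute_metrics evaluates only token positions whose label is not -100 (the
+ # padding / ignore index), flattening predictions and labels first. Note that
+ # roc_auc_score is given the hard argmax predictions rather than class
+ # probabilities, so the reported AUC is only a coarse estimate.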
+ def compute_metrics(p):
+     predictions, labels = p
+     predictions = np.argmax(predictions, axis=2)
+     predictions = predictions[labels != -100].flatten()
+     labels = labels[labels != -100].flatten()
+     accuracy = accuracy_score(labels, predictions)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
+     auc = roc_auc_score(labels, predictions)
+     mcc = matthews_corrcoef(labels, predictions)
+     return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
+
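+ # compute_loss applies class-weighted cross-entropy over the token positions where
+ # attention_mask == 1; padded positions are mapped to the loss function's
+ # ignore_index. It relies on the module-level `class_weights` tensor computed below.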
+ def compute_loss(model, logits, inputs):
+     # print("Shape of input_ids:", inputs["input_ids"].shape)
+     labels = inputs["labels"]
+     loss_fct = nn.CrossEntropyLoss(weight=class_weights)
+     active_loss = inputs["attention_mask"].view(-1) == 1
+     active_logits = logits.view(-1, model.config.num_labels)
+     active_labels = torch.where(
+         active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+     )
+     loss = loss_fct(active_logits, active_labels)
+     return loss
+
+ # Load data from pickle files
+ with open("data/16M_data_big/v2_train_sequences_chunked_by_family.pkl", "rb") as f:
+     train_sequences = pickle.load(f)
+ del f
+ gc.collect()
+
+ with open("data/16M_data_big/v2_test_sequences_chunked_by_family.pkl", "rb") as f:
+     test_sequences = pickle.load(f)
+ del f
+ gc.collect()
+
+ with open("data/16M_data_big/v2_train_labels_chunked_by_family.pkl", "rb") as f:
+     train_labels = pickle.load(f)
+ del f
+ gc.collect()
+
+ with open("data/16M_data_big/v2_test_labels_chunked_by_family.pkl", "rb") as f:
+     test_labels = pickle.load(f)
+ del f
+ gc.collect()
+
+ # Adjust max_sequence_length for special tokens
+ desired_length = 1022
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ sample_sequence = "A"
+ tokenized_sample = tokenizer(sample_sequence)
+
+ # Debugging print statements
+ print(f"Sample Sequence: {sample_sequence}")
+ print(f"Tokenized Sample: {tokenized_sample}")
+ print(f"Number of tokens in tokenized sample: {len(tokenized_sample['input_ids'])}")
+
+ num_special_tokens = len(tokenized_sample["input_ids"]) - 1
+ print(f"Number of special tokens: {num_special_tokens}")
+
+ effective_length = desired_length - num_special_tokens
+ print(f"Effective sequence length (accounting for special tokens): {effective_length}")
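+ # With the ESM-2 tokenizer a single residue tokenizes to three tokens (CLS + residue
+ # + EOS), so num_special_tokens evaluates to 2 here and effective_length to 1020.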
+
+ # Custom Dataset for on-the-fly tokenization
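+ # Each item is tokenized lazily and padded/truncated to a fixed length of
+ # `effective_length` tokens; the per-residue labels are truncated to the same length
+ # and attached as a tensor. Shorter label lists are padded with -100 by the
+ # DataCollatorForTokenClassification used at batching time, matching the ignore
+ # index assumed in compute_metrics and compute_loss.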
+ class CustomDataset(TorchDataset):
+     def __init__(self, sequences, labels, tokenizer, max_length):
+         self.sequences = sequences
+         self.labels = labels
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+
+     def __len__(self):
+         return len(self.sequences)
+
+     def __getitem__(self, idx):
+         sequence = self.sequences[idx]
+         label = self.labels[idx][:self.max_length]
+
+         tokenized = self.tokenizer(sequence, padding='max_length', truncation=True, max_length=effective_length, return_tensors="pt", is_split_into_words=False)
+
+         # Remove batch dimension
+         for key, value in tokenized.items():
+             tokenized[key] = value[0]
+
+         tokenized['labels'] = torch.tensor(label, dtype=torch.long)
+
+         # Diagnostics: Print the shape of the input_ids (or any other key you're interested in)
+         # print("Shape of input_ids:", tokenized["input_ids"].shape)
+
+         # Delete variables that are not needed anymore and collect garbage
+         del sequence, label
+         gc.collect()
+
+         return tokenized
+
+
+ train_dataset = CustomDataset(train_sequences, train_labels, tokenizer, effective_length)
+ test_dataset = CustomDataset(test_sequences, test_labels, tokenizer, effective_length)
+
+
+ # Compute Class Weights
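+ # The weights are estimated with sklearn's 'balanced' heuristic on fixed-size
+ # batches of label lists and then averaged across batches, which approximates the
+ # global weights without flattening every label into memory at once.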
+ classes = [0, 1]
+ # flat_train_labels = [label for sublist in train_labels for label in sublist]
+ # Not used below (class weights are computed batch-wise by compute_average_class_weight),
+ # so the fully flattened label array stays commented out to avoid holding it in memory:
+ # flat_train_labels_gen = (label for sublist in tqdm(train_labels, desc="Flattening labels") for label in sublist)
+ # flat_train_labels = np.fromiter(flat_train_labels_gen, dtype=np.int8)
+
+ del train_sequences, test_sequences, test_labels
+ gc.collect()
+
+ def compute_average_class_weight(train_labels, classes, batch_size):
+     num_batches = len(train_labels) // batch_size + (len(train_labels) % batch_size != 0)
+     total_weights = np.zeros(len(classes))
+
+     for i in tqdm(range(num_batches), desc="Computing class weights in batches"):
+         start_idx = i * batch_size
+         end_idx = start_idx + batch_size
+
+         batch_labels = train_labels[start_idx:end_idx]
+         flat_labels = np.array([label for sublist in batch_labels for label in sublist], dtype=np.int8)
+
+         weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_labels)
+         total_weights += weights
+
+         # Clear memory
+         del batch_labels, flat_labels, weights
+         gc.collect()
+
+     # Average the weights
+     average_weights = total_weights / num_batches
+     return average_weights
+
+ batch_size = 100000 # You can adjust this based on your memory capacity
+ class_weights = compute_average_class_weight(train_labels, classes, batch_size)
+ class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
+
+ del train_labels
+ gc.collect()
+
+ # class_weights = torch.tensor(class_weights, dtype=np.int8).to(accelerator.device)
+
+ # Define Custom Trainer Class
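+ # WeightedTrainer overrides Trainer.compute_loss so the class-weighted,
+ # attention-masked loss defined above is used in place of the model's default loss.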
+ class WeightedTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False):
+         outputs = model(**inputs)
+         logits = outputs.logits
+         loss = compute_loss(model, logits, inputs)
+         return (loss, outputs) if return_outputs else loss
+
+
+ # Configure the quantization settings
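+ # 4-bit NF4 weights with double quantization and bfloat16 compute, i.e. the standard
+ # bitsandbytes QLoRA-style setup.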
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+
+ def train_function_no_sweeps(train_dataset, test_dataset):
+
+     # Directly set the config
+     config = {
+         "lora_alpha": 1,
+         "lora_dropout": 0.5,
+         "lr": 1.701568055793089e-04,
+         "lr_scheduler_type": "cosine",
+         "max_grad_norm": 0.5,
+         "num_train_epochs": 4,
+         "per_device_train_batch_size": 60,
+         "r": 2,
+         "weight_decay": 0.3,
+         # Add other hyperparameters as needed
+     }
+
+     # Log the config to W&B
+     wandb.config.update(config)
+
+     # Save the config to a text file
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+     config_filename = f"esm2_t33_650M_qlora_config_{timestamp}.txt"
+     save_config_to_txt(config, config_filename)
+
+     model_checkpoint = "facebook/esm2_t33_650M_UR50D"
+
+     # Define labels and model
+     id2label = {0: "No binding site", 1: "Binding site"}
+     label2id = {v: k for k, v in id2label.items()}
+
+     model = AutoModelForTokenClassification.from_pretrained(
+         model_checkpoint,
+         num_labels=len(id2label),
+         id2label=id2label,
+         label2id=label2id,
+         quantization_config=bnb_config # Apply quantization here
+     )
+
+     # Prepare the model for 4-bit quantization training
+     model.gradient_checkpointing_enable()
+     model = prepare_model_for_kbit_training(model)
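+     # prepare_model_for_kbit_training freezes the quantized base weights and casts a
+     # few modules (e.g. layer norms) to fp32 for more stable k-bit fine-tuning.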
+
+     # Convert the model into a PeftModel
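+     # Note: PEFT resolves target_modules by matching suffixes of module paths (e.g.
+     # "attention.self.query"); class-qualified names such as "EsmSelfOutput.dense" may
+     # not correspond to any module path in ESM-2 and, if so, are silently skipped,
+     # leaving LoRA adapters only on the query/key/value projections and the classifier.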
+     peft_config = LoraConfig(
+         task_type=TaskType.TOKEN_CLS,
+         inference_mode=False,
+         r=config["r"],
+         lora_alpha=config["lora_alpha"],
+         target_modules=[
+             "query",
+             "key",
+             "value",
+             "EsmSelfOutput.dense",
+             "EsmIntermediate.dense",
+             "EsmOutput.dense",
+             "EsmContactPredictionHead.regression",
+             "classifier"
+         ],
+         lora_dropout=config["lora_dropout"],
+         bias="none", # or "all" or "lora_only"
+         # modules_to_save=["classifier"]
+     )
+     model = get_peft_model(model, peft_config)
+     print_trainable_parameters(model) # added this in
+
+     # Use the accelerator
+     model = accelerator.prepare(model)
+     train_dataset = accelerator.prepare(train_dataset)
+     test_dataset = accelerator.prepare(test_dataset)
+
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+     # Training setup
+     training_args = TrainingArguments(
+         output_dir=f"esm2_t33_650M_qlora_binding_sites_{timestamp}",
+         learning_rate=config["lr"],
+         lr_scheduler_type=config["lr_scheduler_type"],
+         gradient_accumulation_steps=4, # changed from 1 to 4
+         # warmup_steps=2, # added this in
+         max_grad_norm=config["max_grad_norm"],
+         per_device_train_batch_size=config["per_device_train_batch_size"],
+         per_device_eval_batch_size=config["per_device_train_batch_size"],
+         num_train_epochs=config["num_train_epochs"],
+         weight_decay=config["weight_decay"],
+         evaluation_strategy="epoch",
+         save_strategy="epoch",
+         load_best_model_at_end=True,
+         metric_for_best_model="f1",
+         greater_is_better=True,
+         push_to_hub=False,
+         logging_dir=None,
+         logging_first_step=False,
+         logging_steps=200,
+         save_total_limit=7,
+         no_cuda=False,
+         seed=8893,
+         fp16=True,
+         report_to='wandb',
+         optim="paged_adamw_8bit" # added this in
+
+     )
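+     # With per_device_train_batch_size=60 and gradient_accumulation_steps=4, the
+     # effective batch size is 240 sequences per device per optimizer step; the paged
+     # 8-bit AdamW optimizer from bitsandbytes keeps optimizer-state memory low.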
+
+     # Initialize Trainer
+     trainer = WeightedTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=test_dataset,
+         tokenizer=tokenizer,
+         data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
+         compute_metrics=compute_metrics
+     )
+
+     # Train and Save Model
+     trainer.train()
+     save_path = os.path.join("qlora_binding_sites", f"best_model_esm2_t33_650M_qlora_{timestamp}")
+     trainer.save_model(save_path)
+     tokenizer.save_pretrained(save_path)
+
+ # Call the training function
+ if __name__ == "__main__":
+     train_function_no_sweeps(train_dataset, test_dataset)
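+ # A minimal sketch (not executed here) of how the saved LoRA adapter could be
+ # reloaded for inference afterwards, assuming the same `save_path` that
+ # trainer.save_model used above:
+ #
+ # base_model = AutoModelForTokenClassification.from_pretrained(
+ #     "facebook/esm2_t33_650M_UR50D", num_labels=2,
+ #     id2label={0: "No binding site", 1: "Binding site"},
+ #     label2id={"No binding site": 0, "Binding site": 1},
+ # )
+ # model = PeftModel.from_pretrained(base_model, save_path)
+ # tokenizer = AutoTokenizer.from_pretrained(save_path)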
+
+
+