AmelieSchreiber committed
Commit: d24656f (parent: 16bb25e)

Upload qlora_train_v2.py

Files changed (1)
  1. qlora_train_v2.py +283 -0
qlora_train_v2.py ADDED
@@ -0,0 +1,283 @@
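# Example launch (assumes the `accelerate` CLI has been configured for this machine):
#   accelerate launch qlora_train_v2.py
# A plain `python qlora_train_v2.py` also works for a single-GPU run.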
import os
import wandb
import numpy as np
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from datasets import Dataset
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import pickle
import gc
from tqdm import tqdm

# Initialize accelerator and Weights & Biases
accelerator = Accelerator()
os.environ["WANDB_NOTEBOOK_NAME"] = 'qlora_train.py'
wandb.init(project='binding_site_prediction')

# Helper Functions and Data Preparation
#-----------------------------------------------------------------------------
# Added this helper function to report how many parameters LoRA actually trains.
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
#-----------------------------------------------------------------------------

def save_config_to_txt(config, filename):
    """Save the configuration dictionary to a text file."""
    with open(filename, 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")

def truncate_labels(labels, max_length):
    """Truncate each per-residue label list to the tokenizer's maximum input length."""
    return [label[:max_length] for label in labels]

def compute_metrics(p):
    """Compute token-level metrics, ignoring positions labeled -100 (padding/special tokens)."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, logits, inputs):
    """Class-weighted cross-entropy, restricted to positions where attention_mask == 1."""
    # logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

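# Note on data format (inferred from how the objects are used below): each sequences pickle is
# assumed to hold a list of protein sequence strings, and each labels pickle a matching list of
# per-residue binary labels (1 = binding site, 0 = non-binding site).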
# Load data from pickle files
with open("data/16M_data_small/v3_train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("data/16M_data_small/v3_test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("data/16M_data_small/v3_train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("data/16M_data_small/v3_test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)


# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")

# Set max_sequence_length to the tokenizer's max input length
max_sequence_length = tokenizer.model_max_length

# Function to tokenize in batches, show a progress bar with tqdm, and free up memory
def batch_tokenize(sequences, batch_size=10000):  # Reduced batch size for better memory management
    tokenized_outputs = {'input_ids': [], 'attention_mask': []}  # Initialize empty lists for outputs

    for i in tqdm(range(0, len(sequences), batch_size)):
        batch = sequences[i:i+batch_size]
        tokenized_batch = tokenizer(batch, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

        # Append the tokenized batch data to the outputs
        for key in tokenized_outputs:
            tokenized_outputs[key].append(tokenized_batch[key])

        # Free up memory
        del batch
        del tokenized_batch
        gc.collect()

    # Convert lists to tensors. Note: torch.cat assumes every batch was padded to the same width
    # (e.g., sequences pre-chunked to a fixed length); use padding='max_length' above if that
    # assumption does not hold for your data.
    for key in tokenized_outputs:
        tokenized_outputs[key] = torch.cat(tokenized_outputs[key])

    return tokenized_outputs

train_tokenized = batch_tokenize(train_sequences)
test_tokenized = batch_tokenize(test_sequences)

# Free memory
del train_sequences
del test_sequences
gc.collect()

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.bfloat16).to(accelerator.device)

# Define Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        loss = compute_loss(model, logits, inputs)
        return (loss, outputs) if return_outputs else loss

# Configure the quantization settings
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

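# QLoRA training entry point: the ESM-2 base model is loaded with 4-bit NF4 weights (bnb_config
# above) and frozen, and only the low-rank LoRA adapters injected into the listed target modules
# (plus anything in modules_to_save, if enabled) are updated during training.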
def train_function_no_sweeps(train_dataset, test_dataset):

    # Directly set the config
    config = {
        "lora_alpha": 1,
        "lora_dropout": 0.5,
        "lr": 1.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 5,
        "per_device_train_batch_size": 40,
        "r": 2,
        "weight_decay": 0.3,
        # Add other hyperparameters as needed
    }

    # Log the config to W&B
    wandb.config.update(config)

    # Save the config to a text file
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config_filename = f"esm2_t30_150M_qlora_config_{timestamp}.txt"
    save_config_to_txt(config, config_filename)

    model_checkpoint = "facebook/esm2_t30_150M_UR50D"

    # Define labels and model
    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}

    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        quantization_config=bnb_config  # Apply quantization here
    )

    # Prepare the model for 4-bit quantization training
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=[
            "query",
            "key",
            "value",
            "EsmSelfOutput.dense",
            "EsmIntermediate.dense",
            "EsmOutput.dense",
            "EsmContactPredictionHead.regression",
            "classifier"
        ],
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
        # modules_to_save=["classifier"]
    )
    model = get_peft_model(model, peft_config)
    print_trainable_parameters(model)  # added this in

    # Use the accelerator
    model = accelerator.prepare(model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

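    # With per_device_train_batch_size=40 and gradient_accumulation_steps=4 (set below), each
    # optimizer step sees an effective batch of 40 * 4 = 160 sequences per device.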
    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t30_150M_qlora_binding_sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=4,  # changed from 1 to 4
        # warmup_steps=2,  # added this in
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        fp16=True,
        report_to='wandb',
        optim="paged_adamw_8bit"  # added this in
    )

    # Initialize Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics
    )

    # Train and Save Model
    trainer.train()
    save_path = os.path.join("qlora_binding_sites", f"best_model_esm2_t30_150M_qlora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)

# Call the training function
if __name__ == "__main__":
    train_function_no_sweeps(train_dataset, test_dataset)