AmelieSchreiber committed on
Commit
06f7792
1 Parent(s): 1de9930

Upload qlora_ptm_v2.py

Files changed (1)
  1. qlora_ptm_v2.py +253 -0
qlora_ptm_v2.py ADDED
@@ -0,0 +1,253 @@
+ import os
+ import wandb
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from datetime import datetime
+ from sklearn.model_selection import train_test_split
+ from sklearn.utils.class_weight import compute_class_weight
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
+ from transformers import (
+     AutoModelForTokenClassification,
+     AutoTokenizer,
+     DataCollatorForTokenClassification,
+     TrainingArguments,
+     Trainer,
+     BitsAndBytesConfig
+ )
+ from datasets import Dataset
+ from accelerate import Accelerator
+ from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
+ import pickle
+
+ # Initialize accelerator and Weights & Biases
+ accelerator = Accelerator()
+ os.environ["WANDB_NOTEBOOK_NAME"] = 'qlora_ptm_v2.py'
+ wandb.init(project='ptm_site_prediction')
+
+ # Helper Functions and Data Preparation
+ #-----------------------------------------------------------------------------
+ # Added this first function in
+ def print_trainable_parameters(model):
+     """
+     Prints the number of trainable parameters in the model.
+     """
+     trainable_params = 0
+     all_param = 0
+     for _, param in model.named_parameters():
+         all_param += param.numel()
+         if param.requires_grad:
+             trainable_params += param.numel()
+     print(
+         f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+     )
+ #-----------------------------------------------------------------------------
+
+ def save_config_to_txt(config, filename):
+     """Save the configuration dictionary to a text file."""
+     with open(filename, 'w') as f:
+         for key, value in config.items():
+             f.write(f"{key}: {value}\n")
+
+ def truncate_labels(labels, max_length):
+     return [label[:max_length] for label in labels]
+
+ def compute_metrics(p):
+     predictions, labels = p
+     predictions = np.argmax(predictions, axis=2)
+     predictions = predictions[labels != -100].flatten()
+     labels = labels[labels != -100].flatten()
+     accuracy = accuracy_score(labels, predictions)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
+     auc = roc_auc_score(labels, predictions)
+     mcc = matthews_corrcoef(labels, predictions)
+     return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
+
+ def compute_loss(model, logits, inputs):
+     # logits = model(**inputs).logits
+     labels = inputs["labels"]
+     loss_fct = nn.CrossEntropyLoss(weight=class_weights)
+     active_loss = inputs["attention_mask"].view(-1) == 1
+     active_logits = logits.view(-1, model.config.num_labels)
+     active_labels = torch.where(
+         active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+     )
+     loss = loss_fct(active_logits, active_labels)
+     return loss
+
+
+ # Load data from pickle files
+ with open("2100K_ptm_data/train_sequences_chunked_by_family.pkl", "rb") as f:
+     train_sequences = pickle.load(f)
+
+ with open("2100K_ptm_data/test_sequences_chunked_by_family.pkl", "rb") as f:
+     test_sequences = pickle.load(f)
+
+ with open("2100K_ptm_data/train_labels_chunked_by_family.pkl", "rb") as f:
+     train_labels = pickle.load(f)
+
+ with open("2100K_ptm_data/test_labels_chunked_by_family.pkl", "rb") as f:
+     test_labels = pickle.load(f)
+
+ # Tokenization
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
+
+ # Maximum sequence length used for tokenization and label truncation
+ max_sequence_length = 1024
+
+ train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False, add_special_tokens=False)
+ test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False, add_special_tokens=False)
+
+ # Directly truncate the entire list of labels
+ train_labels = truncate_labels(train_labels, max_sequence_length)
+ test_labels = truncate_labels(test_labels, max_sequence_length)
+
+ train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
+ test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
+
+ # Compute Class Weights
+ classes = [0, 1]
+ flat_train_labels = [label for sublist in train_labels for label in sublist]
+ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
+ class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
+
114
+ # Define Custom Trainer Class
115
+ class WeightedTrainer(Trainer):
116
+ def compute_loss(self, model, inputs, return_outputs=False):
117
+ outputs = model(**inputs)
118
+ logits = outputs.logits
119
+ loss = compute_loss(model, logits, inputs)
120
+ return (loss, outputs) if return_outputs else loss
121
+
122
+ # Configure the quantization settings
123
+ bnb_config = BitsAndBytesConfig(
124
+ load_in_4bit=True,
125
+ bnb_4bit_use_double_quant=True,
126
+ bnb_4bit_quant_type="nf4",
127
+ bnb_4bit_compute_dtype=torch.bfloat16
128
+ )
129
+
+ def train_function_no_sweeps(train_dataset, test_dataset):
+
+     # Directly set the config
+     config = {
+         "lora_alpha": 1,
+         "lora_dropout": 0.5,
+         "lr": 3.701568055793089e-04,
+         "lr_scheduler_type": "cosine",
+         "max_grad_norm": 0.5,
+         "num_train_epochs": 1,
+         "per_device_train_batch_size": 36,
+         "r": 2,
+         "weight_decay": 0.3,
+         # Add other hyperparameters as needed
+     }
+
+     # Log the config to W&B
+     wandb.config.update(config)
+
+     # Save the config to a text file
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+     config_filename = f"esm2_t30_150M_qlora_ptm_config_{timestamp}.txt"
+     save_config_to_txt(config, config_filename)
+
+     model_checkpoint = "facebook/esm2_t30_150M_UR50D"
+
+     # Define labels and model
+     id2label = {0: "No ptm site", 1: "ptm site"}
+     label2id = {v: k for k, v in id2label.items()}
+
+     model = AutoModelForTokenClassification.from_pretrained(
+         model_checkpoint,
+         num_labels=len(id2label),
+         id2label=id2label,
+         label2id=label2id,
+         quantization_config=bnb_config  # Apply quantization here
+     )
+
+     # Prepare the model for 4-bit quantization training
+     model.gradient_checkpointing_enable()
+     model = prepare_model_for_kbit_training(model)
+
+     # Convert the model into a PeftModel
+     peft_config = LoraConfig(
+         task_type=TaskType.TOKEN_CLS,
+         inference_mode=False,
+         r=config["r"],
+         lora_alpha=config["lora_alpha"],
+         target_modules=[
+             "query",
+             "key",
+             "value",
+             "EsmSelfOutput.dense",
+             "EsmIntermediate.dense",
+             "EsmOutput.dense",
+             "EsmContactPredictionHead.regression",
+             "classifier"
+         ],
+         lora_dropout=config["lora_dropout"],
+         bias="none",  # or "all" or "lora_only"
+         # modules_to_save=["classifier"]
+     )
+     model = get_peft_model(model, peft_config)
+     print_trainable_parameters(model)  # added this in
+
+     # Use the accelerator
+     model = accelerator.prepare(model)
+     train_dataset = accelerator.prepare(train_dataset)
+     test_dataset = accelerator.prepare(test_dataset)
+
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+     # Training setup
+     training_args = TrainingArguments(
+         output_dir=f"esm2_t30_150M_qlora_ptm_sites_{timestamp}",
+         learning_rate=config["lr"],
+         lr_scheduler_type=config["lr_scheduler_type"],
+         gradient_accumulation_steps=1,
+         # warmup_steps=2, # added this in
+         max_grad_norm=config["max_grad_norm"],
+         per_device_train_batch_size=config["per_device_train_batch_size"],
+         per_device_eval_batch_size=config["per_device_train_batch_size"],
+         num_train_epochs=config["num_train_epochs"],
+         weight_decay=config["weight_decay"],
+         evaluation_strategy="epoch",
+         save_strategy="epoch",
+         load_best_model_at_end=True,
+         metric_for_best_model="f1",
+         greater_is_better=True,
+         push_to_hub=False,
+         logging_dir=None,
+         logging_first_step=False,
+         logging_steps=200,
+         save_total_limit=7,
+         no_cuda=False,
+         seed=8893,
+         fp16=True,
+         report_to='wandb',
+         optim="paged_adamw_8bit"  # added this in
+     )
+
+     # Initialize Trainer
+     trainer = WeightedTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=test_dataset,
+         tokenizer=tokenizer,
+         data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
+         compute_metrics=compute_metrics
+     )
+
+     # Train and Save Model
+     trainer.train()
+     save_path = os.path.join("qlora_ptm_sites", f"best_model_esm2_t30_150M_qlora_{timestamp}")
+     trainer.save_model(save_path)
+     tokenizer.save_pretrained(save_path)
+
+ # Call the training function
+ if __name__ == "__main__":
+     train_function_no_sweeps(train_dataset, test_dataset)
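
After training, the LoRA adapter weights and tokenizer are written to save_path. As a rough guide to reusing them, the sketch below reloads the base ESM-2 checkpoint and attaches the saved adapter with PEFT for token-level inference. It is a minimal sketch, not part of this commit: the adapter directory name (the timestamp placeholder) and the example protein sequence are assumptions to be replaced with real values.

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

# Hypothetical paths: substitute the actual timestamped directory written by trainer.save_model
adapter_path = "qlora_ptm_sites/best_model_esm2_t30_150M_qlora_<timestamp>"
base_checkpoint = "facebook/esm2_t30_150M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(adapter_path)
base_model = AutoModelForTokenClassification.from_pretrained(
    base_checkpoint,
    num_labels=2,
    id2label={0: "No ptm site", 1: "ptm site"},
    label2id={"No ptm site": 0, "ptm site": 1},
)
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Placeholder protein sequence; tokenization mirrors the training script (no special tokens)
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
    logits = model(**inputs).logits
predicted = logits.argmax(dim=-1).squeeze(0).tolist()  # 1 marks a predicted PTM site per residue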