AmelieSchreiber committed
Commit 8551e6c
1 Parent(s): 4d0ba55

Upload 3 files

Files changed (3)
  1. lora_train.py +195 -0
  2. metrics_2.py +95 -0
  3. qlora_train.py +245 -0
lora_train.py ADDED
@@ -0,0 +1,195 @@
import os
import wandb
import numpy as np
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer
from datasets import Dataset
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
import pickle

# Initialize accelerator and Weights & Biases
accelerator = Accelerator()
os.environ["WANDB_NOTEBOOK_NAME"] = 'lora_train.py'
wandb.init(project='binding_site_prediction')

# Helper Functions and Data Preparation
def save_config_to_txt(config, filename):
    """Save the configuration dictionary to a text file."""
    with open(filename, 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")

def truncate_labels(labels, max_length):
    """Truncate per-residue label lists to the tokenizer's maximum length."""
    return [label[:max_length] for label in labels]

def compute_metrics(p):
    """Compute token-level metrics, ignoring padded positions (label -100)."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, logits, inputs):
    """Class-weighted cross-entropy over non-padded (attention-masked) positions."""
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

# Load data from pickle files
with open("770K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("770K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("770K_data/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("770K_data/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Set max_sequence_length to the tokenizer's max input length
max_sequence_length = tokenizer.model_max_length

train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)

# Define Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        loss = compute_loss(model, logits, inputs)
        return (loss, outputs) if return_outputs else loss

# Define and run training function
def train_function_no_sweeps(train_dataset, test_dataset):

    # Directly set the config
    config = {
        "lora_alpha": 1,
        "lora_dropout": 0.5,
        "lr": 3.701568055793089e-04,
        "lr_scheduler_type": "cosine_with_restarts",
        "max_grad_norm": 0.5,
        "num_train_epochs": 3,
        "per_device_train_batch_size": 6,
        "r": 2,
        "weight_decay": 0.2,
        # Add other hyperparameters as needed
    }

    # Log the config to W&B
    wandb.config.update(config)

    # Save the config to a text file
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config_filename = f"esm2_t12_35M_lora_config_{timestamp}.txt"
    save_config_to_txt(config, config_filename)

    model_checkpoint = "facebook/esm2_t12_35M_UR50D"

    # Define labels and model
    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}
    model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)

    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=["query", "key", "value"],
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
        modules_to_save=["classifier"]
    )
    model = get_peft_model(model, peft_config)

    # Use the accelerator
    model = accelerator.prepare(model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t12_35M_lora_binding_sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        fp16=True,
        report_to='wandb'
    )

    # Initialize Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics
    )

    # Train and Save Model
    trainer.train()
    save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)

# Call the training function
if __name__ == "__main__":
    train_function_no_sweeps(train_dataset, test_dataset)
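
For reference, the adapter saved by lora_train.py can be reloaded for per-residue prediction on a single sequence. The following is a minimal sketch, not one of the uploaded files; the adapter path and the example sequence are placeholders, and it assumes the same base checkpoint and label scheme as above.

# Minimal inference sketch (not part of the commit): load the saved LoRA adapter
# and predict binding sites for one sequence. The adapter path is a placeholder.
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

base_checkpoint = "facebook/esm2_t12_35M_UR50D"
adapter_path = "lora_binding_sites/best_model_esm2_t12_35M_lora_<timestamp>"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
base_model = AutoModelForTokenClassification.from_pretrained(base_checkpoint, num_labels=2)
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # example protein sequence
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, seq_len, 2)
predictions = logits.argmax(dim=-1).squeeze().tolist()
# 0 = "No binding site", 1 = "Binding site"; positions for the special tokens
# added by the ESM-2 tokenizer appear at the ends of the list.
print(predictions)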
metrics_2.py ADDED
@@ -0,0 +1,95 @@
import os
import wandb
import numpy as np
import pickle
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer
from datasets import Dataset
from accelerate import Accelerator
from peft import PeftModel

# Helper functions and data preparation
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

def compute_metrics(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()

    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)

    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)

    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        """Custom compute_loss function."""
        outputs = model(**inputs)
        loss_fct = nn.CrossEntropyLoss()
        active_loss = inputs["attention_mask"].view(-1) == 1
        active_logits = outputs.logits.view(-1, model.config.num_labels)
        active_labels = torch.where(
            active_loss, inputs["labels"].view(-1), torch.tensor(loss_fct.ignore_index).type_as(inputs["labels"])
        )
        loss = loss_fct(active_logits, active_labels)
        return (loss, outputs) if return_outputs else loss

if __name__ == "__main__":
    # Environment setup
    accelerator = Accelerator()
    wandb.init(project='binding_site_prediction')

    # Load data and labels
    with open("1111K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
        train_sequences = pickle.load(f)
    with open("1111K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
        test_sequences = pickle.load(f)
    with open("1111K_data/train_labels_chunked_by_family.pkl", "rb") as f:
        train_labels = pickle.load(f)
    with open("1111K_data/test_labels_chunked_by_family.pkl", "rb") as f:
        test_labels = pickle.load(f)

    # Tokenization and dataset creation
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
    max_sequence_length = tokenizer.model_max_length
    train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    train_labels = truncate_labels(train_labels, max_sequence_length)
    test_labels = truncate_labels(test_labels, max_sequence_length)
    train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
    test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

    # Load the pre-trained LoRA model
    base_model_path = "facebook/esm2_t12_35M_UR50D"
    lora_model_path = "esm2_t12_35M_lora_binding_sites_2023-09-23_03-04-43/checkpoint-102604"  # Replace with the correct path to your LoRA model
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
    model = PeftModel.from_pretrained(base_model, lora_model_path)
    model = accelerator.prepare(model)

    # Evaluate on the train and test sets
    data_collator = DataCollatorForTokenClassification(tokenizer)
    trainer = Trainer(model=model, data_collator=data_collator, compute_metrics=compute_metrics)
    train_metrics = trainer.evaluate(train_dataset)
    test_metrics = trainer.evaluate(test_dataset)

    # Print the metrics
    print(f"Train metrics: {train_metrics}")
    print(f"Test metrics: {test_metrics}")

    # Log metrics to W&B
    wandb.log({"Train metrics": train_metrics, "Test metrics": test_metrics})
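
Note that compute_metrics passes hard argmax predictions to roc_auc_score, so the reported AUC reflects a single decision threshold. If a threshold-independent AUC is wanted, the class-1 probabilities can be used instead. A possible variant, sketched under the same -100 masking convention (the function name is illustrative, not part of the uploaded files):

# Sketch of a probability-based AUC (optional variant, not in the commit):
# score with the softmax probability of the "Binding site" class instead of argmax labels.
import numpy as np
from scipy.special import softmax
from sklearn.metrics import roc_auc_score

def compute_auc_from_logits(p):
    predictions, labels = p
    probs = softmax(predictions, axis=2)[:, :, 1]   # P(binding site) per token
    mask = labels != -100                           # drop padded positions
    return {"auc": roc_auc_score(labels[mask].flatten(), probs[mask].flatten())}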
qlora_train.py ADDED
@@ -0,0 +1,245 @@
import os
import wandb
import numpy as np
import torch
import torch.nn as nn
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from datasets import Dataset
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import pickle

# Initialize accelerator and Weights & Biases
accelerator = Accelerator()
os.environ["WANDB_NOTEBOOK_NAME"] = 'qlora_train.py'
wandb.init(project='binding_site_prediction')

# Helper Functions and Data Preparation
#-----------------------------------------------------------------------------
# Added this first function in
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
#-----------------------------------------------------------------------------

def save_config_to_txt(config, filename):
    """Save the configuration dictionary to a text file."""
    with open(filename, 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")

def truncate_labels(labels, max_length):
    """Truncate per-residue label lists to the tokenizer's maximum length."""
    return [label[:max_length] for label in labels]

def compute_metrics(p):
    """Compute token-level metrics, ignoring padded positions (label -100)."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def compute_loss(model, logits, inputs):
    """Class-weighted cross-entropy over non-padded (attention-masked) positions."""
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss


# Load data from pickle files
with open("data/600K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("data/600K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("data/600K_data/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("data/600K_data/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Set max_sequence_length to the tokenizer's max input length
max_sequence_length = tokenizer.model_max_length

train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)

# Define Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        loss = compute_loss(model, logits, inputs)
        return (loss, outputs) if return_outputs else loss

# Configure the quantization settings
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

def train_function_no_sweeps(train_dataset, test_dataset):

    # Directly set the config
    config = {
        "lora_alpha": 1,
        "lora_dropout": 0.5,
        "lr": 3.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 4,
        "per_device_train_batch_size": 64,
        "r": 2,
        "weight_decay": 0.2,
        # Add other hyperparameters as needed
    }

    # Log the config to W&B
    wandb.config.update(config)

    # Save the config to a text file
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config_filename = f"esm2_t6_8M_qlora_config_{timestamp}.txt"
    save_config_to_txt(config, config_filename)

    model_checkpoint = "facebook/esm2_t6_8M_UR50D"

    # Define labels and model
    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}

    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        quantization_config=bnb_config  # Apply quantization here
    )

    # Prepare the model for 4-bit quantization training
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=["query", "key", "value"],
        lora_dropout=config["lora_dropout"],
        bias="none",  # or "all" or "lora_only"
        modules_to_save=["classifier"]
    )
    model = get_peft_model(model, peft_config)
    print_trainable_parameters(model)  # added this in

    # Use the accelerator
    model = accelerator.prepare(model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"esm2_t6_8M_qlora_binding_sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=2,  # changed from 1 to 2
        # warmup_steps=2,  # added this in
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        fp16=True,
        report_to='wandb',
        optim="paged_adamw_8bit"  # added this in
    )

    # Initialize Trainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics
    )

    # Train and Save Model
    trainer.train()
    save_path = os.path.join("qlora_binding_sites", f"best_model_esm2_t6_8M_qlora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)

# Call the training function
if __name__ == "__main__":
    train_function_no_sweeps(train_dataset, test_dataset)
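
For evaluating the QLoRA run, the saved adapter can be attached to the 4-bit quantized base model in the same way metrics_2.py does for the LoRA run. A minimal sketch follows, not one of the uploaded files; the adapter path is a placeholder.

# Evaluation sketch for the QLoRA adapter (not part of the commit):
# reload the 4-bit quantized base and attach the saved adapter, mirroring metrics_2.py.
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_checkpoint = "facebook/esm2_t6_8M_UR50D"
adapter_path = "qlora_binding_sites/best_model_esm2_t6_8M_qlora_<timestamp>"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
base_model = AutoModelForTokenClassification.from_pretrained(
    base_checkpoint, num_labels=2, quantization_config=bnb_config
)
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()
# The resulting model can be passed to a Trainer with compute_metrics, as in metrics_2.py.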