AmelieSchreiber committed on
Commit 97a01f0
1 Parent(s): 9d54d52

Upload 2 files

Files changed (2)
  1. finetune.py +170 -0
  2. metrics (2).py +96 -0
finetune.py ADDED
@@ -0,0 +1,170 @@
+ import os
+ import wandb
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from datetime import datetime
+ from sklearn.model_selection import train_test_split
+ from sklearn.utils.class_weight import compute_class_weight
+ from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, accuracy_score, matthews_corrcoef
+ from transformers import (
+     AutoModelForTokenClassification,
+     AutoTokenizer,
+     DataCollatorForTokenClassification,
+     TrainingArguments,
+     Trainer
+ )
+ from datasets import Dataset
+ from accelerate import Accelerator
+ import pickle
+
+ # Initialize Weights & Biases logging
+ os.environ["WANDB_NOTEBOOK_NAME"] = 'esm2_t6_8M_finetune_600K.ipynb'
+ wandb.init(project='binding_site_prediction')
+
+ # Helper Functions
+ def truncate_labels(labels, max_length):
+     """Truncate labels to the specified max_length."""
+     return [label[:max_length] for label in labels]
+
+ def compute_metrics(p):
+     """Compute metrics for evaluation."""
+     predictions, labels = p
+     predictions = np.argmax(predictions, axis=2)
+     predictions = predictions[labels != -100].flatten()
+     labels = labels[labels != -100].flatten()
+     accuracy = accuracy_score(labels, predictions)
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
+     auc = roc_auc_score(labels, predictions)
+     mcc = matthews_corrcoef(labels, predictions)
+     return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
+
+ def compute_loss(model, inputs):
+     """Class-weighted cross-entropy loss over non-padded tokens (uses the global class_weights tensor)."""
+     logits = model(**inputs).logits
+     labels = inputs["labels"]
+     loss_fct = nn.CrossEntropyLoss(weight=class_weights)
+     active_loss = inputs["attention_mask"].view(-1) == 1
+     active_logits = logits.view(-1, model.config.num_labels)
+     active_labels = torch.where(
+         active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+     )
+     loss = loss_fct(active_logits, active_labels)
+     return loss
+
+ # Custom Trainer Class
+ class WeightedTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False):
+         outputs = model(**inputs)
+         loss = compute_loss(model, inputs)
+         return (loss, outputs) if return_outputs else loss
+
+ # Load data
+ with open("600K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
+     train_sequences = pickle.load(f)
+
+ with open("600K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
+     test_sequences = pickle.load(f)
+
+ with open("600K_data/train_labels_chunked_by_family.pkl", "rb") as f:
+     train_labels = pickle.load(f)
+
+ with open("600K_data/test_labels_chunked_by_family.pkl", "rb") as f:
+     test_labels = pickle.load(f)
+
+ # Tokenization
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+ max_sequence_length = 1000
+ train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+ test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+ train_labels = truncate_labels(train_labels, max_sequence_length)
+ test_labels = truncate_labels(test_labels, max_sequence_length)
+ train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
+ test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
+
+ # Compute Class Weights
+ classes = [0, 1]
+ flat_train_labels = [label for sublist in train_labels for label in sublist]
+ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
+ accelerator = Accelerator()
+ class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
+
+ # Training Function
+ def train_function_no_sweeps(train_dataset, test_dataset):
+     # Initialize wandb
+     wandb.init()
+
+     # Configurations
+     config = {
+         "lr": 5.701568055793089e-04,
+         "lr_scheduler_type": "cosine",
+         "max_grad_norm": 0.5,
+         "num_train_epochs": 1,
+         "per_device_train_batch_size": 12,
+         "weight_decay": 0.2
+     }
+
+     # Model Setup
+     model_checkpoint = "facebook/esm2_t6_8M_UR50D"
+     id2label = {0: "No binding site", 1: "Binding site"}
+     label2id = {v: k for k, v in id2label.items()}
+     model = AutoModelForTokenClassification.from_pretrained(
+         model_checkpoint,
+         num_labels=len(id2label),
+         id2label=id2label,
+         label2id=label2id,
+         hidden_dropout_prob=0.5,  # hidden-state dropout
+         attention_probs_dropout_prob=0.5  # attention dropout
+     )
+     model = accelerator.prepare(model)
+     train_dataset = accelerator.prepare(train_dataset)
+     test_dataset = accelerator.prepare(test_dataset)
+     timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
+     # Training setup
+     training_args = TrainingArguments(
+         output_dir=f"esm2_t6_8M_finetune_{timestamp}",
+         learning_rate=config["lr"],
+         lr_scheduler_type=config["lr_scheduler_type"],
+         gradient_accumulation_steps=1,
+         max_grad_norm=config["max_grad_norm"],
+         per_device_train_batch_size=config["per_device_train_batch_size"],
+         per_device_eval_batch_size=config["per_device_train_batch_size"],
+         num_train_epochs=config["num_train_epochs"],
+         weight_decay=config["weight_decay"],
+         evaluation_strategy="epoch",
+         save_strategy="epoch",
+         load_best_model_at_end=True,
+         metric_for_best_model="f1",
+         greater_is_better=True,
+         push_to_hub=False,
+         logging_dir=None,
+         logging_first_step=False,
+         logging_steps=200,
+         save_total_limit=7,
+         no_cuda=False,
+         seed=42,
+         fp16=True,
+         report_to='wandb'
+     )
+
+     # Initialize Trainer
+     trainer = WeightedTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=test_dataset,
+         tokenizer=tokenizer,
+         data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
+         compute_metrics=compute_metrics
+     )
+
+     # Train and Save Model
+     trainer.train()
+     save_path = os.path.join("binding_sites", f"best_model_esm2_t6_8M_{timestamp}")
+     trainer.save_model(save_path)
+     tokenizer.save_pretrained(save_path)
+
+ # Call the training function
+ if __name__ == "__main__":
+     train_function_no_sweeps(train_dataset, test_dataset)
metrics (2).py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import wandb
+ import numpy as np
+ import pickle
+ import torch
+ import torch.nn as nn
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer
+ from datasets import Dataset
+ from accelerate import Accelerator
+ from peft import PeftModel
+
+ # Helper functions and data preparation
+ def truncate_labels(labels, max_length):
+     """Truncate labels to the specified max_length."""
+     return [label[:max_length] for label in labels]
+
+ def compute_metrics(p):
+     """Compute metrics for evaluation."""
+     predictions, labels = p
+     predictions = np.argmax(predictions, axis=2)
+
+     # Remove padding (-100 labels)
+     predictions = predictions[labels != -100].flatten()
+     labels = labels[labels != -100].flatten()
+
+     # Compute accuracy
+     accuracy = accuracy_score(labels, predictions)
+
+     # Compute precision, recall, F1 score, and AUC
+     precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
+     auc = roc_auc_score(labels, predictions)
+
+     # Compute MCC
+     mcc = matthews_corrcoef(labels, predictions)
+
+     return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
+
+ # Mirrors the masked-loss Trainer from finetune.py (unweighted here); the evaluation below uses the stock Trainer.
+ class WeightedTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False):
+         """Cross-entropy loss over non-padded tokens."""
+         outputs = model(**inputs)
+         loss_fct = nn.CrossEntropyLoss()
+         active_loss = inputs["attention_mask"].view(-1) == 1
+         active_logits = outputs.logits.view(-1, model.config.num_labels)
+         active_labels = torch.where(
+             active_loss, inputs["labels"].view(-1), torch.tensor(loss_fct.ignore_index).type_as(inputs["labels"])
+         )
+         loss = loss_fct(active_logits, active_labels)
+         return (loss, outputs) if return_outputs else loss
+
+ if __name__ == "__main__":
+     # Environment setup
+     accelerator = Accelerator()
+     # wandb.init(project='binding_site_prediction')
+
+     # Load data and labels
+     with open("600K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
+         train_sequences = pickle.load(f)
+     with open("600K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
+         test_sequences = pickle.load(f)
+     with open("600K_data/train_labels_chunked_by_family.pkl", "rb") as f:
+         train_labels = pickle.load(f)
+     with open("600K_data/test_labels_chunked_by_family.pkl", "rb") as f:
+         test_labels = pickle.load(f)
+
+     # Tokenization and dataset creation
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
+     max_sequence_length = tokenizer.model_max_length
+     train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+     test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
+     train_labels = truncate_labels(train_labels, max_sequence_length)
+     test_labels = truncate_labels(test_labels, max_sequence_length)
+     train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
+     test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
+
+     # Load the fine-tuned model (LoRA loading shown commented out below)
+     base_model_path = "esm2_t6_8M_finetune_2023-10-08_00-58-24/checkpoint-42015"
+     # lora_model_path = "AmelieSchreiber/esm2_t12_35M_qlora_binding_2600K_cp1" # Replace with the correct path to your LoRA model
+     # base_model = AutoModelForTokenClassification.from_pretrained(base_model_path) # use this for LoRA
+     model = AutoModelForTokenClassification.from_pretrained(base_model_path) # remove this for LoRA
+     # model = PeftModel.from_pretrained(base_model, lora_model_path) # use this for LoRA
+     model = accelerator.prepare(model)
+
+     # Evaluate with a Trainer to get the train/test metrics
+     data_collator = DataCollatorForTokenClassification(tokenizer)
+     trainer = Trainer(model=model, data_collator=data_collator, compute_metrics=compute_metrics)
+     train_metrics = trainer.evaluate(train_dataset)
+     test_metrics = trainer.evaluate(test_dataset)
+
+     # Print the metrics
+     print(f"Train metrics: {train_metrics}")
+     print(f"Test metrics: {test_metrics}")
+
+     # Log metrics to W&B
+     # wandb.log({"Train metrics": train_metrics, "Test metrics": test_metrics})