AmelieSchreiber committed
Commit 3de8059
1 Parent(s): bf2d7db

Upload metrics.py

Files changed (1)
  1. metrics.py +116 -0
metrics.py ADDED
@@ -0,0 +1,116 @@
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from peft import PeftModel, get_peft_config, PeftConfig, get_peft_model, LoraConfig, TaskType
from accelerate import Accelerator
from tqdm import tqdm

# Initialize the Accelerator
accelerator = Accelerator()

class ProteinDataset(TorchDataset):
    def __init__(self, sequences_path, labels_path, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(sequences_path, "rb") as f:
            self.sequences = pickle.load(f)

        with open(labels_path, "rb") as f:
            self.labels = pickle.load(f)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]

        tokenized = self.tokenizer(
            sequence,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            is_split_into_words=False,
            add_special_tokens=False,
        )

        # Remove the extra batch dimension added by return_tensors="pt"
        for key in tokenized:
            tokenized[key] = tokenized[key].squeeze(0)

        # Pad/truncate labels to match the tokenized input, using -100 as the ignore index
        label_padded = [-100] * self.max_length
        label_padded[:len(label)] = label[:self.max_length]

        tokenized["labels"] = torch.tensor(label_padded)

        return tokenized

def compute_metrics(p):
    predictions, labels = p.predictions, p.label_ids
    predictions = np.argmax(predictions, axis=2)

    # Drop padded positions (label == -100) before scoring
    mask = labels != -100
    predictions = predictions[mask].flatten()
    labels = labels[mask].flatten()

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def evaluate_in_chunks(dataset, trainer, chunk_percentage=0.2):
    chunk_size = int(len(dataset) * chunk_percentage)
    all_results = []

    # Wrap the loop with tqdm for a progress bar
    for i in tqdm(range(0, len(dataset), chunk_size), desc="Evaluating chunks"):
        chunk = [dataset[j] for j in range(i, min(i + chunk_size, len(dataset)))]
        chunk_results = trainer.evaluate(chunk)
        # Record the chunk size so results can be sample-weighted when aggregating
        # (Trainer.evaluate does not add this key on its own)
        chunk_results["eval_samples"] = len(chunk)
        print(f"Results for chunk starting at index {i}: {chunk_results}")

        # Save the chunk results to disk
        with open(f"results_chunk_{i}.pkl", "wb") as f:
            pickle.dump(chunk_results, f)

        all_results.append(chunk_results)

    return all_results

def aggregate_results(results_list):
    # Weight each chunk's metrics by the number of samples in that chunk
    total_samples = sum(res["eval_samples"] for res in results_list)
    aggregated_results = {}

    for key in results_list[0].keys():
        if key == "eval_samples":
            continue
        aggregated_results[key] = sum(res[key] * res["eval_samples"] for res in results_list) / total_samples

    return aggregated_results

# Initialize tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
train_dataset = ProteinDataset("data/12M_data/512_train_sequences_chunked_by_family.pkl", "data/12M_data/512_train_labels_chunked_by_family.pkl", tokenizer, 512)
test_dataset = ProteinDataset("data/12M_data/512_test_sequences_chunked_by_family.pkl", "data/12M_data/512_test_labels_chunked_by_family.pkl", tokenizer, 512)

# Load the pre-trained LoRA model
base_model_path = "facebook/esm2_t33_650M_UR50D"
lora_model_path = "qlora_binding_sites/best_model_esm2_t33_650M_qlora_2023-10-18_02-14-48"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    compute_metrics=compute_metrics
)

# Evaluate the model on chunks of the training dataset
train_results = evaluate_in_chunks(train_dataset, trainer)
aggregated_train_results = aggregate_results(train_results)
print(f"Aggregated Training Results: {aggregated_train_results}")

# Evaluate the model on chunks of the test dataset
test_results = evaluate_in_chunks(test_dataset, trainer)
aggregated_test_results = aggregate_results(test_results)
print(f"Aggregated Test Results: {aggregated_test_results}")
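Note that roc_auc_score above is given the hard argmax predictions, so for this binary labelling the reported AUC reduces to something close to balanced accuracy rather than a threshold-free ROC-AUC. A probability-based AUC could instead be computed from the raw logits; the following is a minimal sketch under the assumption that the token-classification head has two classes with index 1 as the positive (binding-site) class. The helper name compute_auc_from_logits and the use of scipy.special.softmax are illustrative choices, not part of the committed file.

import numpy as np
from scipy.special import softmax
from sklearn.metrics import roc_auc_score

def compute_auc_from_logits(logits, labels):
    """Token-level ROC-AUC from raw logits, ignoring padded (-100) positions."""
    # logits: (batch, seq_len, 2) raw scores; labels: (batch, seq_len) with -100 at padding
    probs = softmax(logits, axis=-1)[..., 1]  # probability of the positive class
    mask = labels != -100
    return roc_auc_score(labels[mask].flatten(), probs[mask].flatten())

If adopted, this would be called inside compute_metrics on p.predictions before the argmax step, leaving the other metrics unchanged.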