AmelieSchreiber committed
Commit e255d39
1 Parent(s): 27f69f3

Update README (2).md

Files changed (1): README (2).md (+126 -1)
---
library_name: peft
license: mit
---

## Training procedure

### Framework versions

- PEFT 0.5.0

## Metrics

```python
Train:
({'accuracy': 0.9406146072672105,
  'precision': 0.2947122459102886,
  'recall': 0.952624323712029,
  'f1': 0.4501592605994876,
  'auc': 0.9464622170085311,
  'mcc': 0.5118390407598565},
Test:
 {'accuracy': 0.9266827008067329,
  'precision': 0.22378953253253775,
  'recall': 0.7790246675002842,
  'f1': 0.3476966444342296,
  'auc': 0.8547531675185658,
  'mcc': 0.3930283737012391})
```
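
As a quick consistency check, each reported F1 is the harmonic mean of the corresponding precision and recall; the large precision/recall gap mostly reflects the imbalance between binding and non-binding residues. For the train split:

```python
# Sanity check: F1 = 2PR / (P + R), using the reported train-split values
p, r = 0.2947122459102886, 0.952624323712029
print(2 * p * r / (p + r))  # ~0.45016, matching the reported train F1
```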

## Using the Model

Head over to [here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family) to download the dataset first. Once you have the pickle files downloaded locally, run the following to load and tokenize the data:

```python
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

# Set the maximum sequence length
max_sequence_length = 1000

# Load the data from pickle files
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Truncate the labels to match the tokenized sequence lengths
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
```
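
Before moving on, it can help to sanity-check the prepared datasets. A minimal sketch (not part of the original card):

```python
# Inspect the prepared dataset: expect input_ids, attention_mask, and labels columns
print(train_dataset)
example = train_dataset[0]
# Lengths can differ here: input_ids include special tokens and padding, while
# labels were only truncated; DataCollatorForTokenClassification pads labels
# with -100 at batch time, and those positions are masked out during evaluation.
print(len(example["input_ids"]), len(example["labels"]))
```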

Then run the following to get the train/test metrics:

```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification, Trainer
from accelerate import Accelerator

# Instantiate the accelerator
accelerator = Accelerator()

# Define paths to the LoRA and base models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"  # replace with the path to your own LoRA model if needed

# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Load the LoRA model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)  # Prepare the model using the accelerator

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator
data_collator = DataCollatorForTokenClassification(tokenizer)

# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Matthews correlation coefficient

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

train_metrics, test_metrics
```
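
The card stops at computing metrics; for predictions on a single new sequence, a minimal sketch along these lines should work, reusing `tokenizer`, `model`, `id2label`, and `max_sequence_length` from the snippets above (the example sequence is a made-up placeholder):

```python
import torch

# Made-up placeholder; substitute a real protein sequence
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKR"

inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=max_sequence_length)
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}

model.eval()
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, tokenized_length, 2)

# Argmax decoding, mirroring compute_metrics above; drop the CLS/EOS
# special tokens that the ESM-2 tokenizer adds at either end
predictions = torch.argmax(logits, dim=2)[0, 1:-1].tolist()
binding_sites = [i for i, p in enumerate(predictions) if id2label[p] == "Binding site"]
print("Predicted binding-site residue indices (0-based):", binding_sites)
```

Each residue is classified independently, so grouping consecutive predicted positions into contiguous sites is left as post-processing.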