AmelieSchreiber committed
Commit e255d39
Parent(s): 27f69f3
Update README (2).md

README (2).md CHANGED (+126 -1)

---
library_name: peft
license: mit
---
## Training procedure

### Framework versions

- PEFT 0.5.0
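
The card does not record the LoRA configuration or training hyperparameters used for this checkpoint. Purely as an illustration, a PEFT 0.5.0 LoRA setup for ESM-2 token classification generally looks like the sketch below; every value shown (rank, alpha, dropout, target modules) is a placeholder, not the setting used for this model.

```python
# Illustrative only: a generic LoRA setup for ESM-2 token classification with peft 0.5.0.
# The rank, alpha, dropout, and target modules are placeholders, not this model's settings.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForTokenClassification

base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=2
)

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],  # attention projections in ESM-2
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
```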

## Metrics

```python
Train:
{'accuracy': 0.9406146072672105,
 'precision': 0.2947122459102886,
 'recall': 0.952624323712029,
 'f1': 0.4501592605994876,
 'auc': 0.9464622170085311,
 'mcc': 0.5118390407598565}

Test:
{'accuracy': 0.9266827008067329,
 'precision': 0.22378953253253775,
 'recall': 0.7790246675002842,
 'f1': 0.3476966444342296,
 'auc': 0.8547531675185658,
 'mcc': 0.3930283737012391}
```

## Using the Model

First, download the dataset from [AmelieSchreiber/binding_sites_random_split_by_family](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family). Once you have the pickle files downloaded locally, run the following:

```python
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

# Set the maximum sequence length
max_sequence_length = 1000

# Load the data from pickle files
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Truncate the labels to match the tokenized sequence lengths
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
```
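
As a quick sanity check that the datasets were built as expected, you can inspect their sizes and one example. This snippet is optional and only uses the `train_dataset` and `test_dataset` objects created above:

```python
# Optional sanity check on the datasets created above.
print(train_dataset)  # column names and number of rows
print(test_dataset)

example = train_dataset[0]
print(len(example["input_ids"]), len(example["labels"]))  # tokenized length vs. label length
print(example["labels"][:20])                             # first few binary binding-site labels
```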

Then run the following to get the train/test metrics:

```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
)
from peft import PeftModel
from transformers import (
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    Trainer,
)
from accelerate import Accelerator

# Instantiate the accelerator
accelerator = Accelerator()

# Define paths to the base model and the LoRA adapter
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"  # replace with the path to your LoRA model if needed

# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Load the LoRA model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)  # prepare the model using the accelerator

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator
data_collator = DataCollatorForTokenClassification(tokenizer)

# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Remove padding and special tokens (labeled -100 by the data collator)
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average="binary")
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Matthews correlation coefficient

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

train_metrics, test_metrics
```
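
To run the model on a single sequence instead of the full datasets, a minimal inference sketch is shown below. It assumes the `model`, `tokenizer`, and `id2label` objects from the snippets above are already loaded; the example sequence is only a placeholder.

```python
import torch

# Placeholder protein sequence; replace with your own.
sequence = "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKL"

# Tokenize and move the inputs to the same device as the model.
device = next(model.parameters()).device
inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1000)
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    logits = model(**inputs).logits

# Per-residue predictions; the first and last tokens are special tokens added by the ESM-2 tokenizer.
predicted_ids = logits.argmax(dim=-1)[0].tolist()[1:-1]
predicted_labels = [id2label[i] for i in predicted_ids]
print(list(zip(sequence, predicted_labels)))
```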