AmelieSchreiber
committed on
Commit
·
a2760f2
1
Parent(s):
4e63f79
Upload 7 files
Browse files
- README.md +1 -72
- adapter_config.json +22 -0
- adapter_model.bin +3 -0
- special_tokens_map.json +7 -0
- tokenizer_config.json +5 -0
- training_args.bin +3 -0
- vocab.txt +33 -0
README.md
CHANGED
@@ -1,80 +1,9 @@
|
|
1 |
---
|
2 |
library_name: peft
|
3 |
-
license: mit
|
4 |
-
language:
|
5 |
-
- en
|
6 |
-
tags:
|
7 |
-
- transformers
|
8 |
-
- biology
|
9 |
-
- esm
|
10 |
-
- esm2
|
11 |
-
- protein
|
12 |
-
- protein language model
|
13 |
---
|
14 |
-
# ESM-2 RNA Binding Site LoRA
|
15 |
-
|
16 |
-
This is a Parameter Efficient Fine Tuning (PEFT) Low Rank Adaptation ([LoRA](https://huggingface.co/docs/peft/task_guides/token-classification-lora)) of
|
17 |
-
the [esm2_t30_150M_UR50D](https://huggingface.co/facebook/esm2_t30_150M_UR50D) model for the (binary) token classification task of
|
18 |
-
predicting RNA binding sites of proteins. The GitHub repository with the training script and conda env YAML can be
|
19 |
-
[found here](https://github.com/Amelie-Schreiber/esm2_LoRA_binding_sites/tree/main). You can also find a version of this model
|
20 |
-
that was fine-tuned without LoRA [here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor).
|
21 |
-
|
22 |
## Training procedure
|
23 |
|
24 |
-
This is a Low Rank Adaptation (LoRA) of `esm2_t30_150M_UR50D`,
|
25 |
-
trained on `166` protein sequences in the [RNA binding sites dataset](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites)
|
26 |
-
using a `75/25` train/test split. It achieves an evaluation loss of `0.17312709987163544`.
|
27 |
-
|
28 |
### Framework versions
|
29 |
|
30 |
-
- PEFT 0.4.0
|
31 |
-
|
32 |
-
## Using the Model
|
33 |
-
|
34 |
-
To use, try running:
|
35 |
-
```python
|
36 |
-
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
37 |
-
from peft import PeftModel
|
38 |
-
import torch
|
39 |
-
|
40 |
-
# Path to the saved LoRA model
|
41 |
-
model_path = "AmelieSchreiber/esm2_t30_150M_UR50D_LoRA_RNA_binding"
|
42 |
-
# ESM2 base model
|
43 |
-
base_model_path = "facebook/esm2_t30_150M_UR50D"
|
44 |
-
|
45 |
-
# Load the model
|
46 |
-
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
|
47 |
-
loaded_model = PeftModel.from_pretrained(base_model, model_path)
|
48 |
-
|
49 |
-
# Ensure the model is in evaluation mode
|
50 |
-
loaded_model.eval()
|
51 |
|
52 |
-
|
53 |
-
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
54 |
-
|
55 |
-
# Protein sequence for inference
|
56 |
-
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
|
57 |
-
|
58 |
-
# Tokenize the sequence
|
59 |
-
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
|
60 |
-
|
61 |
-
# Run the model
|
62 |
-
with torch.no_grad():
|
63 |
-
logits = loaded_model(**inputs).logits
|
64 |
-
|
65 |
-
# Get predictions
|
66 |
-
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
67 |
-
predictions = torch.argmax(logits, dim=2)
|
68 |
-
|
69 |
-
# Define labels
|
70 |
-
id2label = {
|
71 |
-
0: "No binding site",
|
72 |
-
1: "Binding site"
|
73 |
-
}
|
74 |
-
|
75 |
-
# Print the predicted labels for each token
|
76 |
-
for token, prediction in zip(tokens, predictions[0].numpy()):
|
77 |
-
if token not in ['<pad>', '<cls>', '<eos>']:
|
78 |
-
print((token, id2label[prediction]))
|
79 |
-
|
80 |
-
```
|
|
|
1 |
---
|
2 |
library_name: peft
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
## Training procedure
|
5 |
|
|
|
|
|
|
|
|
|
6 |
### Framework versions
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
- PEFT 0.4.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
adapter_config.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_mapping": null,
|
3 |
+
"base_model_name_or_path": "facebook/esm2_t30_150M_UR50D",
|
4 |
+
"bias": "all",
|
5 |
+
"fan_in_fan_out": false,
|
6 |
+
"inference_mode": true,
|
7 |
+
"init_lora_weights": true,
|
8 |
+
"layers_pattern": null,
|
9 |
+
"layers_to_transform": null,
|
10 |
+
"lora_alpha": 16,
|
11 |
+
"lora_dropout": 0.1,
|
12 |
+
"modules_to_save": null,
|
13 |
+
"peft_type": "LORA",
|
14 |
+
"r": 32,
|
15 |
+
"revision": null,
|
16 |
+
"target_modules": [
|
17 |
+
"query",
|
18 |
+
"key",
|
19 |
+
"value"
|
20 |
+
],
|
21 |
+
"task_type": "TOKEN_CLS"
|
22 |
+
}
|
adapter_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:54594596755c5a6be53f40096bdaa80d79320473cc84ec25b0cc41db622aa2cd
|
3 |
+
size 15750833
|
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "<cls>",
|
3 |
+
"eos_token": "<eos>",
|
4 |
+
"mask_token": "<mask>",
|
5 |
+
"pad_token": "<pad>",
|
6 |
+
"unk_token": "<unk>"
|
7 |
+
}
|
tokenizer_config.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"clean_up_tokenization_spaces": true,
|
3 |
+
"model_max_length": 1000000000000000019884624838656,
|
4 |
+
"tokenizer_class": "EsmTokenizer"
|
5 |
+
}
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8fffc4fc4da24eef90002cdc8ff78928507824695c61fdce3f26df01aff254ec
|
3 |
+
size 4091
|
vocab.txt
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<cls>
|
2 |
+
<pad>
|
3 |
+
<eos>
|
4 |
+
<unk>
|
5 |
+
L
|
6 |
+
A
|
7 |
+
G
|
8 |
+
V
|
9 |
+
S
|
10 |
+
E
|
11 |
+
R
|
12 |
+
T
|
13 |
+
I
|
14 |
+
D
|
15 |
+
P
|
16 |
+
K
|
17 |
+
Q
|
18 |
+
N
|
19 |
+
F
|
20 |
+
Y
|
21 |
+
M
|
22 |
+
H
|
23 |
+
W
|
24 |
+
C
|
25 |
+
X
|
26 |
+
B
|
27 |
+
U
|
28 |
+
Z
|
29 |
+
O
|
30 |
+
.
|
31 |
+
-
|
32 |
+
<null_1>
|
33 |
+
<mask>
|