---
library_name: peft
license: mit
language:
- en
tags:
- transformers
- biology
- esm
- esm2
- protein
- protein language model
---
# ESM-2 RNA Binding Site LoRA

This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of 
the [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) model for the binary token classification task of 
predicting RNA binding sites of proteins. A version of this model 
fine-tuned without LoRA is also available [here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor). 

## Training procedure

This is a Low Rank Adaptation (LoRA) of `esm2_t12_35M_UR50D`, 
trained on `166` protein sequences in the [RNA binding sites dataset](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites)
using an `85/15` train/test split. The model was trained with class weighting because the dataset is imbalanced: 
it contains far fewer binding sites than non-binding sites. This model has slightly improved 
precision, recall, and F1 score over [AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding](https://huggingface.co/AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding)
but may suffer from mild overfitting, as indicated by the training loss being slightly lower than the eval loss (see metrics below). 
If you are searching for binding sites and aren't worried about false positives, the higher recall may make this model 
preferable to the other RNA binding site predictors. 
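
Class weighting can be sketched as follows. This card does not state which weighting scheme was used, so the inverse-frequency ("balanced") formula below is an assumption, and `class_weights` and the toy labels are hypothetical names and data used only for illustration.

```python
from collections import Counter


def class_weights(labels, num_classes=2):
    """Inverse-frequency weights: n_samples / (num_classes * count_of_class).

    Assumes every class appears at least once in `labels`.
    """
    counts = Counter(labels)
    total = len(labels)
    return [total / (num_classes * counts[c]) for c in range(num_classes)]


# Toy per-residue labels: 1 = binding site, 0 = non-binding (illustrative only)
toy_labels = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
weights = class_weights(toy_labels)
print(weights)  # the rarer class (index 1) receives the larger weight
```

Weights like these could then be passed to a weighted loss such as `torch.nn.CrossEntropyLoss(weight=torch.tensor(weights))`, so that errors on the rarer binding-site class cost more during training.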

You can train your own version 
using [this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding/blob/main/LoRA_binding_sites_no_sweeps_v2.ipynb)! 
You just need the RNA `binding_sites.xml` file [found here](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites). 
You may also need to run some `pip install` statements at the beginning of the notebook. If you are running in Colab, run:

```python
!pip install transformers[torch] datasets peft -q
```
```python
!pip install accelerate -U -q
```
Try to improve upon these metrics by adjusting the hyperparameters:
```
{'eval_loss': 0.500779926776886,
'eval_precision': 0.1708695652173913,
'eval_recall': 0.8397435897435898,
'eval_f1': 0.2839595375722543,
'eval_auc': 0.771835775620126,
'epoch': 11.0}
{'loss': 0.4171,
'learning_rate': 0.00032491416877500004,
'epoch': 11.43}
```
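
As a quick sanity check on the metrics above, F1 is the harmonic mean of precision and recall. The short snippet below recomputes it from the reported `eval_precision` and `eval_recall`; the variable names are illustrative.

```python
# Values copied from the evaluation metrics reported above
precision = 0.1708695652173913
recall = 0.8397435897435898

# F1 = 2 * P * R / (P + R)
f1 = 2 * precision * recall / (precision + recall)
print(f"F1 = {f1:.4f}")
```

The result agrees with the reported `eval_f1`, and the gap between the high recall and low precision shows that the model flags most true binding sites at the cost of many false positives.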

A similar model can also be trained using the GitHub repository, which provides a training script and a conda environment YAML file. It can be
[found here](https://github.com/Amelie-Schreiber/esm2_LoRA_binding_sites/tree/main). That version uses wandb sweeps for hyperparameter search. 
However, it does not use class weighting. 


### Framework versions

- PEFT 0.4.0

## Using the Model

To use the model, try running the following pip install statements:
```python
!pip install transformers peft -q
```
then try running:
```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_UR50D_RNA_LoRA_weighted"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))

```
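
Taking the `argmax` over two classes is equivalent to a 0.5 cutoff on the binding-site probability, so the precision/recall trade-off mentioned earlier can be adjusted by thresholding that probability instead. The following is a minimal sketch in plain Python; `binding_probability`, `call_sites`, the threshold values, and the toy logits are hypothetical and are not part of the original model card.

```python
import math


def binding_probability(logit_pair):
    """Softmax probability of class 1 (binding site) from a [class0, class1] logit pair."""
    m = max(logit_pair)  # subtract the max for numerical stability
    e0 = math.exp(logit_pair[0] - m)
    e1 = math.exp(logit_pair[1] - m)
    return e1 / (e0 + e1)


def call_sites(logits, threshold=0.5):
    """Label a residue 1 when its binding-site probability reaches the threshold."""
    return [int(binding_probability(pair) >= threshold) for pair in logits]


# Toy logits for three residues (illustrative values, not model output)
toy_logits = [[2.0, -1.0], [0.2, 0.4], [-1.5, 3.0]]
default_calls = call_sites(toy_logits)               # same decisions as argmax
strict_calls = call_sites(toy_logits, threshold=0.9) # fewer, more confident calls
print(default_calls, strict_calls)
```

With the real model, `logits[0].tolist()` yields one logit pair per token and can be passed to `call_sites`; special tokens should still be filtered out as in the loop above. Raising the threshold should reduce false positives at the expense of recall.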