---
license: mit
---

# ESM-2 QLoRA for Predicting Binding Sites

This model is the ESM-2 model [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) fine-tuned with QLoRA on 
[this dataset](https://huggingface.co/datasets/AmelieSchreiber/2600K_binding_sites) of 2.6M protein sequences with binding-site and 
active-site annotations from UniProt. The model size and dataset size were scaled up together in a one-to-one ratio (following the 
Chinchilla paper) from the smaller QLoRA adaptations of the `esm2_t6_8M_UR50D` model, which were trained on 600K proteins. Since this 
model is 4.375 times larger, a dataset approximately 4.375 times larger is needed if Chinchilla scaling laws hold for QLoRA fine-tuning 
of protein language models. Determining whether such scaling laws hold is part of this project, so performance metrics are being tracked 
over 3 epochs, and each epoch is being checked for signs of overfitting.
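The scaling arithmetic can be checked with a few lines of Python. This is a minimal sketch that, like the paragraph above, measures data in number of proteins rather than in training tokens (the Chinchilla paper itself scales tokens), and takes the parameter counts from the checkpoint names, so it is only an approximation:

```python
# Approximate parameter counts, taken from the checkpoint names
small_params = 8e6    # esm2_t6_8M_UR50D
large_params = 35e6   # esm2_t12_35M_UR50D

# Dataset size used for the smaller QLoRA models (number of proteins, not tokens)
small_dataset = 600_000

# Chinchilla-style 1:1 scaling: grow the data by the same factor as the model
scale = large_params / small_params
target_dataset = small_dataset * scale

print(f"model scale factor: {scale}")                 # model scale factor: 4.375
print(f"target dataset size: {target_dataset:,.0f}")  # target dataset size: 2,625,000
```

The 2.6M-protein dataset used here is about 4.33 times larger than the 600K set, close to the 4.375 target.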


## QLoRA Info

```
trainable params: 71046 || all params: 17246053 || trainable%: 0.41195512967517844
```
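The `trainable params` line looks like the output of PEFT's `print_trainable_parameters()`. The percentage is simply the ratio of the two reported counts, as this small sketch shows:

```python
# Counts copied from the line above
trainable = 71_046
total = 17_246_053

# Share of parameters updated by the QLoRA adapters, in percent
trainable_pct = 100 * trainable / total
print(f"trainable%: {trainable_pct:.4f}")  # trainable%: 0.4120
```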

```python
{
    'eval_loss': 0.6011912822723389,
    'eval_accuracy': 0.9297529150299436,
    'eval_precision': 0.22835223718675476,
    'eval_recall': 0.697386656717114,
    'eval_f1': 0.3440490710592986,
    'eval_auc': 0.8167222019799886,
    'eval_mcc': 0.3730152153022164
}
```
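As a sanity check, F1 is the harmonic mean of precision and recall, so it can be recomputed from the two reported values. The sketch below does only that; it does not re-evaluate the model:

```python
# Values copied from the evaluation metrics above
precision = 0.22835223718675476
recall = 0.697386656717114

# F1 = harmonic mean of precision and recall
f1 = 2 * precision * recall / (precision + recall)
print(f"F1 from precision/recall: {f1:.4f}")  # F1 from precision/recall: 0.3440
```

The combination of high accuracy (about 0.93) with low precision (about 0.23) and fairly high recall implies that binding-site residues make up only a small fraction of the evaluated tokens, so F1, AUC, and MCC are likely more informative than accuracy for this task.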

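To use this model, first install the dependencies (the leading `!` is notebook syntax for Jupyter or Colab; drop it in a terminal):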

```
!pip install transformers -q
!pip install peft -q
```

Then run inference on a protein sequence:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Hugging Face Hub ID of the QLoRA adapter weights
model_path = "AmelieSchreiber/esm2_t12_35M_qlora_binding_2600K_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024)  # a single sequence needs no padding

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
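The loop above prints one `(token, label)` tuple per residue. To report contiguous binding-site regions instead, the per-residue labels can be grouped into spans. `binding_site_spans` below is a hypothetical helper (not part of this model or of `transformers`), sketched in plain Python:

```python
from typing import List, Optional, Tuple


def binding_site_spans(labels: List[int]) -> List[Tuple[int, int]]:
    """Group per-residue labels (1 = binding site) into 1-based, inclusive (start, end) spans.

    Hypothetical helper for illustration only.
    """
    spans: List[Tuple[int, int]] = []
    start: Optional[int] = None
    for i, label in enumerate(labels, start=1):
        if label == 1 and start is None:
            start = i  # a new binding-site run begins at residue i
        elif label != 1 and start is not None:
            spans.append((start, i - 1))  # the run ended at the previous residue
            start = None
    if start is not None:
        spans.append((start, len(labels)))  # the run extends to the last residue
    return spans


# Made-up labels for illustration (not real model output)
example_labels = [0, 1, 1, 0, 0, 1, 0, 1, 1, 1]
print(binding_site_spans(example_labels))  # [(2, 3), (6, 6), (8, 10)]
```

With the inference code above, the per-residue labels could be taken as `predictions[0][1:len(protein_sequence) + 1].tolist()`, assuming the ESM-2 tokenizer emits exactly one token per residue after the leading `<cls>` token and the sequence is not truncated.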