---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---

# ESM-2 QLoRA

These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them similarly to the LoRA models. 
They were trained on the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics aren't great; scaling to larger models 
for better metrics is in progress. The checkpoints were trained using 
[the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data). To replicate the training of QLoRA for ESM-2 models, 
you can use the `conda-environment.yml` file. However, as of 28/09/2023 and for the next week or two, you will need to uninstall 
transformers and install the development version instead:

```
pip install --upgrade git+https://github.com/huggingface/transformers.git
```

In a couple of weeks, once the transformers library is updated, you should be able to simply use the latest release: 
gradient checkpointing will be fully enabled, and QLoRA compatibility will be fully integrated into ESM-2 models. 
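
For reference, here is a minimal sketch of loading one of these checkpoints with PEFT, in the same way you would load the LoRA models. The adapter path and example sequence are illustrative; point `adapter_path` at whichever checkpoint folder you want to use.

```python
# Minimal loading sketch (assumes transformers, peft, and torch are installed).
# The adapter path below is illustrative; use whichever checkpoint folder you downloaded.
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel

base_model_id = "facebook/esm2_t6_8M_UR50D"
adapter_path = "AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0"  # or a local checkpoint directory

tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base_model = AutoModelForTokenClassification.from_pretrained(base_model_id, num_labels=2)

# Attach the QLoRA adapter weights on top of the (frozen) base model.
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Predict per-residue binding-site labels for a single sequence.
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # replace with your protein sequence
inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1022)
with torch.no_grad():
    logits = model(**inputs).logits
predictions = logits.argmax(dim=-1).squeeze().tolist()  # 0 = non-binding, 1 = binding
```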

## Data Curation and Preprocessing

To create your own datasets and perform the same data preprocessing as was used for this project, download a TSV file 
from UniProt with the following columns: Protein families, Binding sites, Active sites, and Protein sequence. You can then use 
[this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0/blob/main/data_processing_v1.ipynb) to:

- separate out the test sequences by choosing random families (including all sequences in each chosen family, with no overlap with the training data),
- filter out proteins with incomplete annotations,
- merge the binding and active sites and convert them to binary labels (`0` for non-binding sites, `1` for binding sites), and
- split the sequences into non-overlapping chunks of at most 1000 residues to accommodate the 1022-token context window of ESM-2 models.

The notebook also lets you reduce the size of your dataset at the end. Note that this step is not currently ideal: it selects proteins 
to keep at random from the train and test datasets, so proteins from small families are less likely to be chosen, biasing the models 
towards larger families. Due to this shortcoming in the preprocessing step, smaller models trained on smaller datasets are likely 
biased towards larger families; an approach that instead favors smaller families might be better. 
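
As a concrete illustration of the chunking step above, here is a minimal sketch; the function name and exact handling are illustrative, not taken from the notebook.

```python
# Split a protein sequence and its per-residue binary labels into non-overlapping chunks
# of at most 1000 residues, so each chunk fits within ESM-2's 1022-token context window
# (leaving room for special tokens).
def chunk_sequence(sequence: str, labels: list, max_len: int = 1000):
    assert len(sequence) == len(labels), "expected one label per residue"
    return [
        (sequence[start:start + max_len], labels[start:start + max_len])
        for start in range(0, len(sequence), max_len)
    ]

# Example: a 2500-residue protein yields chunks of 1000, 1000, and 500 residues.
# chunks = chunk_sequence(seq, per_residue_labels)
```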

## QLoRA Info

Note that we train only 0.58% of the parameters, applying LoRA adapters only to the query, key, and value weight matrices. 

```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
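
Below is a sketch of a QLoRA setup consistent with the parameter counts above. The rank, alpha, and dropout values are illustrative assumptions rather than the exact training hyperparameters; only the choice of target modules (query, key, value) comes from this card.

```python
# Hedged sketch of a QLoRA setup for ESM-2 token classification.
# r / lora_alpha / lora_dropout below are assumptions; target_modules matches the card.
import torch
from transformers import AutoModelForTokenClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=2,
    quantization_config=bnb_config,
)
model = prepare_model_for_kbit_training(model)  # standard prep for training a k-bit quantized model

lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=2,                 # illustrative rank
    lora_alpha=8,        # illustrative scaling factor
    lora_dropout=0.05,   # illustrative dropout
    target_modules=["query", "key", "value"],  # only the Q, K, V weight matrices
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# e.g. trainable params: 23682 || all params: 4075265 || trainable%: 0.58
```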

The QLoRA paper showed that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter 
to adjust is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, 
such as the scaling factor alpha, did not seem to matter much. An important next step is to check whether this finding transfers to 
protein language models as well. A general pattern is also emerging in which adding adapters for more of the weight matrices reduces 
overfitting, so more adapter layers seem to be better in that regard too. 

## Testing for Overfitting

### Checkpoint 1

Train/Test Split from 600K dataset:

```python
Train metrics:
{'eval_loss': 0.31757092475891113,
'eval_accuracy': 0.8666164527145709,
'eval_precision': 0.12977997642311132,
'eval_recall': 0.8907064653559833,
'eval_f1': 0.2265505142278714,
'eval_auc': 0.8783913689919987,
'eval_mcc': 0.30996745466311043}

Test metrics:
{'eval_loss': 0.3398605287075043,
'eval_accuracy': 0.8557050926566265,
'eval_precision': 0.10792930844408741,
'eval_recall': 0.7726298654561553,
'eval_f1': 0.18940102955847055,
'eval_auc': 0.8150939843855006,
'eval_mcc': 0.2535956911257298}
```

Metrics for this checkpoint on [these datasets](https://github.com/hamzagamouh/pt-lm-gnn) can be 
[found here](https://huggingface.co/AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0/blob/main/pdb_struct_metrics.txt). 

### Checkpoint 4

```python
Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}

Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
```
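
For completeness, metric dictionaries like the ones above can be produced with a `compute_metrics` function along the following lines. This is a standard sklearn-based sketch, not necessarily the exact evaluation code used for these checkpoints.

```python
# Hedged sketch of a token-classification compute_metrics function producing the metrics above
# (accuracy, precision, recall, F1, AUC, MCC); pass it to a transformers Trainer and call
# trainer.evaluate(...) on the train and test splits.
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, matthews_corrcoef,
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Flatten and drop padded/special-token positions, labeled -100 by convention.
    mask = labels != -100
    y_true = labels[mask]
    y_pred = predictions[mask]

    # Probability of the positive (binding) class for the AUC, via a numerically stable softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    y_score = probs[..., 1][mask]

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_score),
        "mcc": matthews_corrcoef(y_true, y_pred),
    }
```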