---
license: mit
language:
- en
library_name: peft
tags:
- ESM-2
- QLoRA
- Binding Sites
- biology
---
# ESM-2 QLoRA
These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them similarly to the LoRA models.
This is the smallest ESM-2 model, `esm2_t6_8M_UR50D`, so the metrics aren't great; scaling to larger models for better metrics is in progress.
These checkpoints were trained using [the 600K dataset](https://huggingface.co/datasets/AmelieSchreiber/600K_data). To replicate the QLoRA training
for ESM-2 models, you can use the `conda-environment.yml` file. However, as of 28/09/2023 (and for roughly the next week or two) you will need to
uninstall transformers and install it from source instead:
```
pip install --upgrade git+https://github.com/huggingface/transformers.git
```
In a couple of weeks, once the transformers library has been updated, you should be able to simply use the latest release:
gradient checkpointing will be fully enabled, and QLoRA compatibility should be fully integrated into the ESM-2 models.
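As a quick illustration, here is a minimal sketch of loading one of these checkpoints for inference with `peft`. The adapter path and example sequence are illustrative assumptions; if the adapter files live in a checkpoint subfolder of this repo, point the path there instead.
```python
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel

base_model_name = "facebook/esm2_t6_8M_UR50D"
adapter_path = "AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0"  # or a checkpoint subfolder

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForTokenClassification.from_pretrained(base_model_name, num_labels=2)
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()

# Arbitrary example sequence; per-residue predictions: 1 = binding site, 0 = non-binding.
# Note that the first and last positions correspond to special tokens added by the tokenizer.
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predictions = logits.argmax(dim=-1).squeeze().tolist()
```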
## Data Curation and Preprocessing
To create your own datasets and perform the same data preprocessing as was used for this project, you will need to download a TSV file
from UniProt with the following columns: Protein families, Binding sites, Active sites, Protein sequence. You can then use
[this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0/blob/main/data_processing_v1.ipynb) to:

- separate out the test sequences by choosing random families to hold out (including all sequences in each chosen family, with no overlap with the training data),
- filter out proteins with incomplete annotations,
- merge the binding and active sites and convert them to binary labels (`0` for non-binding sites, `1` for binding sites),
- split the sequences into non-overlapping chunks of 1000 residues or less to accommodate the 1022-token context window of the ESM-2 models (see the sketch after this list).

The notebook also lets you reduce the size of your dataset at the end. Note that this step is not currently ideal: it selects proteins to keep at
random from the train and test datasets, without accounting for the fact that proteins from small families are less likely to be chosen, which
biases the models towards larger families. Due to this shortcoming in our data preprocessing step, smaller models trained on smaller datasets are
likely biased towards larger families. Perhaps an approach that is biased towards smaller families would be better.
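As a rough illustration of the label-merging and chunking steps described above, here is a small sketch. These are hypothetical helpers, not the exact code in the notebook, and site positions are assumed to be 0-indexed.
```python
from typing import List, Tuple

MAX_CHUNK = 1000  # keep chunks under ESM-2's 1022-token context window

def sites_to_binary_labels(seq_len: int, binding_sites: List[int], active_sites: List[int]) -> List[int]:
    """Merge binding and active site positions into per-residue binary labels."""
    labels = [0] * seq_len
    for pos in set(binding_sites) | set(active_sites):
        labels[pos] = 1
    return labels

def chunk_sequence(sequence: str, labels: List[int], max_len: int = MAX_CHUNK) -> List[Tuple[str, List[int]]]:
    """Split a sequence and its labels into non-overlapping chunks of at most max_len residues."""
    return [
        (sequence[i:i + max_len], labels[i:i + max_len])
        for i in range(0, len(sequence), max_len)
    ]
```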
## QLoRA Info
Note, we are only training 0.58% of the parameters, applying LoRA adapters to only the query, key, and value weight matrices:
```
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
```
It was shown in the QLoRA paper that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter to
adjust is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, such as the scaling
factor alpha, did not seem to matter much. An important thing to investigate next is therefore whether this finding transfers to protein language
models as well. A general pattern is also emerging in which overfitting is reduced by adding adapters for more of the weight matrices, so more
adapter layers seem to be better in that regard as well.
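For reference, a minimal sketch of how such a QLoRA setup might be configured with `peft` and `bitsandbytes`; the rank, alpha, and dropout values here are illustrative assumptions, not necessarily the ones used for these checkpoints.
```python
import torch
from transformers import AutoModelForTokenClassification, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the base ESM-2 model in 4-bit NF4 precision (the "Q" in QLoRA).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels=2,
    quantization_config=bnb_config,
)
base_model = prepare_model_for_kbit_training(base_model)

# Attach LoRA adapters to the query, key, and value projections only.
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=8,                     # illustrative rank
    lora_alpha=16,           # illustrative scaling factor
    lora_dropout=0.05,       # illustrative dropout
    target_modules=["query", "key", "value"],
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # prints trainable/all parameter counts like those shown above
```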
## Testing for Overfitting
### Checkpoint 1
Train/Test Split from 600K dataset:
```python
Train metrics:
{'eval_loss': 0.31757092475891113,
'eval_accuracy': 0.8666164527145709,
'eval_precision': 0.12977997642311132,
'eval_recall': 0.8907064653559833,
'eval_f1': 0.2265505142278714,
'eval_auc': 0.8783913689919987,
'eval_mcc': 0.30996745466311043}
Test metrics:
{'eval_loss': 0.3398605287075043,
'eval_accuracy': 0.8557050926566265,
'eval_precision': 0.10792930844408741,
'eval_recall': 0.7726298654561553,
'eval_f1': 0.18940102955847055,
'eval_auc': 0.8150939843855006,
'eval_mcc': 0.2535956911257298}
```
Metrics for this checkpoint for [these datasets](https://github.com/hamzagamouh/pt-lm-gnn) can be
[found here](https://huggingface.co/AmelieSchreiber/esm2_t6_8m_qlora_binding_sites_v0/blob/main/pdb_struct_metrics.txt).
### Checkpoint 4
```python
Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}
Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
```
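The metric names above come from a Hugging Face `Trainer` evaluation. As a hedged sketch (an assumption about how such metrics could be computed, not necessarily the exact code used here), a `compute_metrics` function producing per-residue metrics like these, ignoring padded positions labeled `-100`, might look like:
```python
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, matthews_corrcoef,
)

def compute_metrics(eval_pred):
    """Per-residue binding-site metrics, ignoring padded positions (label -100)."""
    logits, labels = eval_pred
    # Numerically stable softmax to get the probability of the positive (binding-site) class.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    preds = logits.argmax(axis=-1)

    mask = labels != -100
    labels, preds, pos_probs = labels[mask], preds[mask], probs[..., 1][mask]

    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": precision_score(labels, preds),
        "recall": recall_score(labels, preds),
        "f1": f1_score(labels, preds),
        "auc": roc_auc_score(labels, pos_probs),
        "mcc": matthews_corrcoef(labels, preds),
    }
```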