license: mit
language:
  - en
library_name: peft
tags:
  - ESM-2
  - QLoRA
  - Binding Sites
  - biology

ESM-2 QLoRA

These are the checkpoints for the first ever QLoRA for ESM-2! You can load and use them in the same way as the LoRA models. They use the smallest model, esm2_t6_8M_UR50D, so the metrics aren't great; scaling to larger models for better metrics is in progress. These checkpoints were trained on the 600K dataset. To replicate the training of QLoRA for ESM-2 models, you can use the conda-environment.yml file. However, as of 28/09/2023 and for the following week or two, you will need to uninstall transformers and install it from source instead:

pip install --upgrade git+https://github.com/huggingface/transformers.git

In a couple of weeks, once the transformers library is updated, you should be able to simply use the latest release of transformers. At that point gradient checkpointing should be fully enabled and QLoRA compatibility fully integrated for ESM-2 models.

Data Curation and Preprocessing

To create your own datasets and apply the same preprocessing used for this project, download a TSV file from UniProt with the following columns: Protein families, Binding sites, Active sites, and Protein sequence. You can then use this notebook to separate out the test sequences by choosing random families (all sequences in a chosen family go to the test set, so there is no family overlap with the training data), filter out proteins with incomplete annotations, merge the binding and active sites, convert them to binary labels (0 for non-binding sites, 1 for binding sites), and split the sequences into non-overlapping chunks of 1000 residues or fewer to accommodate the 1022-residue context window of ESM-2 models.

The notebook also lets you reduce the size of your dataset at the end. Note that this step is not currently ideal: it selects proteins to keep at random from the train and test datasets without accounting for family size, so proteins from small families are less likely to be chosen. Because of this shortcoming, smaller models trained on the reduced datasets are likely biased towards larger families. A sampling approach that favors smaller families might work better.
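The labeling, chunking, and family-based splitting steps described above can be sketched in plain Python. This is a minimal illustration, not the notebook's actual code: the function names, the `family` record key, the `test_fraction` default, and the assumption that site positions have already been parsed from the UniProt columns into 1-based integers are all hypothetical.

```python
import random
from collections import defaultdict


def binary_labels(seq_len, binding_sites, active_sites):
    """Merge binding and active site positions (assumed 1-based) into 0/1 labels."""
    positives = set(binding_sites) | set(active_sites)
    return [1 if pos in positives else 0 for pos in range(1, seq_len + 1)]


def chunk(sequence, labels, max_len=1000):
    """Split a sequence and its labels into non-overlapping chunks of at most max_len."""
    return [
        (sequence[i:i + max_len], labels[i:i + max_len])
        for i in range(0, len(sequence), max_len)
    ]


def split_by_family(records, test_fraction=0.2, seed=0):
    """Hold out whole families so that no family appears in both splits."""
    by_family = defaultdict(list)
    for rec in records:
        by_family[rec["family"]].append(rec)
    families = sorted(by_family)
    random.Random(seed).shuffle(families)
    # test_fraction is a hypothetical default, not a value from this project
    n_test = max(1, round(test_fraction * len(families)))
    test = [r for f in families[:n_test] for r in by_family[f]]
    train = [r for f in families[n_test:] for r in by_family[f]]
    return train, test
```

For example, `chunk("A" * 2500, [0] * 2500)` yields three pieces of 1000, 1000, and 500 residues, each of which fits inside the ESM-2 context window.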

QLoRA Info

Note, we are only training 0.58% of the parameters, using only the query, key, and value weight matrices.

trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
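As a sanity check on the line above, the percentage can be recomputed from the two counts. The sketch also shows one parameter breakdown that happens to reproduce the trainable count. Note the assumptions: the hidden size of 320 and 6 layers are the published dimensions of esm2_t6_8M_UR50D, but the LoRA rank of 2 and a trained 2-label classification head are guesses that this card does not state.

```python
trainable, total = 23682, 4075265
pct = 100 * trainable / total
print(f"trainable%: {pct:.4f}")


def lora_params(hidden, layers, rank, n_targets=3):
    """Parameters added by LoRA on n_targets square (hidden x hidden) matrices per layer.

    Each adapter holds A (rank x hidden) and B (hidden x rank).
    """
    return layers * n_targets * 2 * hidden * rank


# ASSUMPTIONS: rank and num_labels are not stated in this card.
hidden, layers, rank, num_labels = 320, 6, 2, 2
adapters = lora_params(hidden, layers, rank)
head = hidden * num_labels + num_labels  # linear classifier weights + bias
print(adapters, head, adapters + head)
```

Under these assumptions the adapters contribute 23,040 parameters and the head 642, which sums to the reported 23,682. That agreement is suggestive only and does not confirm the configuration the checkpoints were actually trained with.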

The QLoRA paper showed that, to obtain performance comparable to or better than full finetuning, the most important hyperparameter to adjust is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, such as the scaling factor alpha, did not seem to matter. An important next step is therefore to check whether this transfers to protein language models as well. A general pattern is emerging in which adding adapters to more of the weight matrices reduces overfitting, so more adapter layers seem to be better in that regard as well.

Testing for Overfitting

Checkpoint 1

Train/Test Split from 600K dataset:

Train metrics:
{'eval_loss': 0.31757092475891113,
'eval_accuracy': 0.8666164527145709,
'eval_precision': 0.12977997642311132,
'eval_recall': 0.8907064653559833,
'eval_f1': 0.2265505142278714,
'eval_auc': 0.8783913689919987,
'eval_mcc': 0.30996745466311043}

Test metrics:
{'eval_loss': 0.3398605287075043,
'eval_accuracy': 0.8557050926566265,
'eval_precision': 0.10792930844408741,
'eval_recall': 0.7726298654561553,
'eval_f1': 0.18940102955847055,
'eval_auc': 0.8150939843855006,
'eval_mcc': 0.2535956911257298}

Metrics for this checkpoint for these datasets can be found here.

Checkpoint 4

Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}

Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
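
One simple way to read the numbers above for overfitting is to compare the train-minus-test gap per metric across checkpoints. The sketch below does this for F1, AUC, and MCC using the values reported in this section; the `gaps` helper and the choice of metrics are illustrative, not part of the original evaluation code.

```python
checkpoints = {
    "Checkpoint 1": {
        "train": {"eval_f1": 0.2265505142278714,
                  "eval_auc": 0.8783913689919987,
                  "eval_mcc": 0.30996745466311043},
        "test": {"eval_f1": 0.18940102955847055,
                 "eval_auc": 0.8150939843855006,
                 "eval_mcc": 0.2535956911257298},
    },
    "Checkpoint 4": {
        "train": {"eval_f1": 0.27889357183237473,
                  "eval_auc": 0.8839390799308487,
                  "eval_mcc": 0.3536803490333407},
        "test": {"eval_f1": 0.219811797752809,
                 "eval_auc": 0.8013943890942485,
                 "eval_mcc": 0.2721459410994918},
    },
}


def gaps(metrics):
    """Train-minus-test difference for every metric; larger means more overfitting."""
    return {k: metrics["train"][k] - metrics["test"][k] for k in metrics["train"]}


for name, metrics in checkpoints.items():
    print(name, {k: round(v, 3) for k, v in gaps(metrics).items()})
```

Between checkpoint 1 and checkpoint 4 the gap widens on all three metrics (F1 from about 0.037 to 0.059, AUC from about 0.063 to 0.083, MCC from about 0.056 to 0.082). Test F1 and MCC still improve slightly, while test AUC drops from about 0.815 to 0.801.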