---
library_name: peft
license: mit
language:
- en
tags:
- transformers
- biology
- esm
- esm2
- protein
- protein language model
---

# ESM-2 RNA Binding Site LoRA

This is a Parameter Efficient Fine Tuning (PEFT) Low Rank Adaptation (LoRA) of the [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) model for the (binary) token classification task of predicting RNA binding sites of proteins. You can also find a version of this model that was fine-tuned without LoRA [here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor).

## Training procedure

This is a Low Rank Adaptation (LoRA) of `esm2_t12_35M_UR50D`, trained on `166` protein sequences in the [RNA binding sites dataset](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites) using an `85/15` train/test split. The model was trained with class weighting due to the imbalanced nature of the RNA binding site dataset (there are far fewer binding sites than non-binding sites). It has slightly improved precision, recall, and F1 score over [AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding](https://huggingface.co/AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding), but may suffer from mild overfitting, as indicated by the training loss being slightly lower than the eval loss (see the metrics below). If you are searching for binding sites and are not worried about false positives, the higher recall may make this model preferable to the other RNA binding site predictors.

You can train your own version using [this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding/blob/main/LoRA_binding_sites_no_sweeps_v2.ipynb)! You just need the RNA `binding_sites.xml` file [found here](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites). You may also need to run some `pip install` statements at the beginning of the script. If you are running in Colab, run:

```python
!pip install transformers[torch] datasets peft -q
```

```python
!pip install accelerate -U -q
```

Try to improve upon these metrics by adjusting the hyperparameters:

```
{'eval_loss': 0.500779926776886, 'eval_precision': 0.1708695652173913, 'eval_recall': 0.8397435897435898, 'eval_f1': 0.2839595375722543, 'eval_auc': 0.771835775620126, 'epoch': 11.0}
{'loss': 0.4171, 'learning_rate': 0.00032491416877500004, 'epoch': 11.43}
```

A similar model can also be trained using the GitHub repository [found here](https://github.com/Amelie-Schreiber/esm2_LoRA_binding_sites/tree/main), which includes a training script and a conda environment YAML. That version uses wandb sweeps for hyperparameter search; however, it does not use class weighting.
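One common way to combine a LoRA adapter with class weighting is to subclass `Trainer` and override `compute_loss`. The sketch below illustrates that pattern only; the LoRA rank, alpha, dropout, target modules, and class weights are illustrative placeholders, not the values used to train this checkpoint (see the linked notebook for the actual training code).

```python
from transformers import AutoModelForTokenClassification, Trainer
from peft import LoraConfig, get_peft_model
import torch

# Wrap the ESM-2 base model with a LoRA adapter.
# Rank, alpha, dropout, and target modules are placeholders.
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=2
)
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=16,                                      # placeholder rank
    lora_alpha=16,                             # placeholder scaling
    lora_dropout=0.1,                          # placeholder dropout
    target_modules=["query", "key", "value"],  # illustrative choice of attention projections
)
model = get_peft_model(base_model, lora_config)

# Class-weighted loss: up-weight the rare "binding site" class.
# These weights are placeholders; in practice they would be derived
# from the label distribution of the training set.
class_weights = torch.tensor([1.0, 10.0])

class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = torch.nn.CrossEntropyLoss(
            weight=class_weights.to(logits.device), ignore_index=-100
        )
        loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
```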
### Framework versions

- PEFT 0.4.0

## Using the Model

To use the model, try running the following pip install statement:

```python
!pip install transformers peft -q
```

then try running:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_UR50D_RNA_LoRA_weighted"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token, skipping special tokens
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
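If you want residue positions rather than a per-token printout, the per-token predictions can be reduced to a list of indices. A minimal continuation of the snippet above (it reuses the `tokens` and `predictions` variables defined there):

```python
# Continuing from the snippet above: collect 1-based residue positions
# predicted as binding sites, skipping special tokens.
special_tokens = {'<pad>', '<cls>', '<eos>'}
residues = [(t, p) for t, p in zip(tokens, predictions[0].numpy()) if t not in special_tokens]
binding_positions = [i for i, (token, pred) in enumerate(residues, start=1) if pred == 1]
print("Predicted binding site positions:", binding_positions)
```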