AmelieSchreiber committed on
Commit 81cfd05
1 Parent(s): 79ba33b

Update README.md

Files changed (1)
  1. README.md +86 -1
README.md CHANGED
@@ -1,9 +1,94 @@
  ---
  library_name: peft
+ license: mit
+ language:
+ - en
+ tags:
+ - transformers
+ - biology
+ - esm
+ - esm2
+ - protein
+ - protein language model
  ---
+ # ESM-2 RNA Binding Site LoRA
+
+ This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of the
+ [esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D) model for the binary token classification task of
+ predicting RNA binding sites of proteins. The GitHub repository with the training script and conda environment YAML can be
+ [found here](https://github.com/Amelie-Schreiber/esm2_LoRA_binding_sites/tree/main). You can also find a version of this model
+ that was fine-tuned without LoRA [here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor).
+
  ## Training procedure

- ### Framework versions
+ This is a Low-Rank Adaptation (LoRA) of `esm2_t6_8M_UR50D`,
+ trained on `166` protein sequences from the [RNA binding sites dataset](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites)
+ using an `80/20` train/test split. The model was trained with class weighting because the dataset is imbalanced
+ (there are far fewer binding-site residues than non-binding-site residues). You can train your own version
+ using [this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding/blob/main/LoRA_binding_sites_no_sweeps_v2.ipynb)!

+ The evaluation metrics after `15` epochs of training were:
+
+ ```
+ {'eval_loss': 0.49476009607315063,
+  'eval_precision': 0.14372964169381108,
+  'eval_recall': 0.7526652452025586,
+  'eval_f1': 0.24136752136752138,
+  'eval_auc': 0.7710141129858947,
+  'eval_runtime': 3.5601,
+  'eval_samples_per_second': 9.27,
+  'eval_steps_per_second': 2.528,
+  'epoch': 15.0}
+ ```
+
+ ### Framework versions

  - PEFT 0.4.0
+
+ ## Using the Model
+
+ To use the model for token-level inference, try running:
+ ```python
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+ from peft import PeftModel
+ import torch
+
+ # Path to the saved LoRA adapter
+ model_path = "AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding"
+ # ESM-2 base model
+ base_model_path = "facebook/esm2_t6_8M_UR50D"
+
+ # Load the base model and wrap it with the LoRA adapter
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
+ loaded_model = PeftModel.from_pretrained(base_model, model_path)
+
+ # Ensure the model is in evaluation mode
+ loaded_model.eval()
+
+ # Load the tokenizer
+ loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+ # Protein sequence for inference
+ protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence
+
+ # Tokenize the sequence
+ inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
+
+ # Run the model
+ with torch.no_grad():
+     logits = loaded_model(**inputs).logits
+
+ # Get per-token predictions
+ tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
+ predictions = torch.argmax(logits, dim=2)
+
+ # Define labels
+ id2label = {
+     0: "No binding site",
+     1: "Binding site"
+ }
+
+ # Print the predicted label for each residue, skipping special tokens and padding
+ for token, prediction in zip(tokens, predictions[0].numpy()):
+     if token not in ['<pad>', '<cls>', '<eos>']:
+         print((token, id2label[prediction]))
+ ```
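
The training procedure above mentions class weighting for the imbalanced binding-site labels. The exact implementation lives in the linked notebook; the following is only a rough sketch of the usual pattern, a custom `Trainer` overriding `compute_loss` with weighted cross-entropy. The weight values are hypothetical placeholders, and the `compute_loss` signature assumes a `transformers` release contemporary with PEFT 0.4.0.

```python
import torch
from torch import nn
from transformers import Trainer

# Hypothetical class weights: up-weight the rare "binding site" class.
# The actual weights for this model are computed in the linked training notebook.
CLASS_WEIGHTS = torch.tensor([1.0, 5.0])

class WeightedLossTrainer(Trainer):
    """Trainer that applies class-weighted cross-entropy for token classification."""

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits  # shape: (batch, seq_len, num_labels)
        loss_fct = nn.CrossEntropyLoss(
            weight=CLASS_WEIGHTS.to(logits.device),
            ignore_index=-100,  # special tokens / padding are labeled -100 and ignored
        )
        num_labels = logits.shape[-1]
        loss = loss_fct(logits.reshape(-1, num_labels), labels.reshape(-1))
        return (loss, outputs) if return_outputs else loss
```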
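The evaluation dictionary in the README reports token-level precision, recall, F1, and AUC. A minimal sketch of a `compute_metrics` callback that produces such numbers with scikit-learn, assuming labels use `-100` for ignored positions (an illustration, not necessarily the repository's implementation):

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

def compute_metrics(eval_pred):
    """Token-level metrics for binary binding-site prediction."""
    logits, labels = eval_pred          # numpy arrays supplied by the HF Trainer
    predictions = np.argmax(logits, axis=-1)

    # Keep only real residues (ignored positions are labeled -100)
    mask = labels != -100
    y_true = labels[mask]
    y_pred = predictions[mask]

    # Softmax probability of the positive ("binding site") class for AUC
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    y_score = probs[..., 1][mask]

    return {
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_score),
    }
```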
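If you prefer a standalone checkpoint without the PEFT wrapper at inference time, PEFT lets you fold the LoRA weights into the base model with `merge_and_unload()`. The output directory below is just an example path:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForTokenClassification.from_pretrained("facebook/esm2_t6_8M_UR50D")
lora_model = PeftModel.from_pretrained(base_model, "AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding")

# Fold the LoRA deltas into the base weights and drop the adapter wrapper
merged_model = lora_model.merge_and_unload()

# Save as a regular transformers checkpoint (example directory name)
merged_model.save_pretrained("esm2_rna_binding_merged")
AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D").save_pretrained("esm2_rna_binding_merged")
```

The merged checkpoint can then be reloaded with `AutoModelForTokenClassification.from_pretrained("esm2_rna_binding_merged")` without importing `peft`.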