AmelieSchreiber committed
Commit ebf710f
1 Parent(s): a2760f2

Update README.md

Files changed (1)
  1. README.md +73 -1
README.md CHANGED
---
library_name: peft
license: mit
language:
- en
tags:
- transformers
- biology
- esm
- esm2
- protein
- protein language model
---
# ESM-2 RNA Binding Site LoRA

This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation ([LoRA](https://huggingface.co/docs/peft/task_guides/token-classification-lora)) of
the [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) model for the (binary) token classification task of
predicting RNA binding sites of proteins. The GitHub repository with the training script and conda environment YAML can be
[found here](https://github.com/Amelie-Schreiber/esm2_LoRA_binding_sites/tree/main). You can also find a version of this model
that was fine-tuned without LoRA [here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor).
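
To make the task concrete: each residue in a protein sequence receives a binary label, `1` for an RNA binding site residue and `0` otherwise. The toy example below only illustrates this labeling scheme and is not taken from the dataset.

```python
# Toy illustration of the per-residue binary labeling scheme (not real data)
sequence = "MKTAYIAKQR"
labels = [0, 0, 1, 1, 0, 0, 0, 0, 1, 0]  # 1 = RNA binding site residue, 0 = otherwise
assert len(labels) == len(sequence)  # one label per residue
```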

## Training procedure

This is a Low Rank Adaptation (LoRA) of `esm2_t6_8M_UR50D`,
trained on `166` protein sequences in the [RNA binding sites dataset](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites)
using a `75/25` train/test split. It achieves an evaluation loss of `0.18801096081733704`.
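
For context, the sketch below shows one way such a LoRA token-classification fine-tune can be set up with `peft`; the rank, alpha, dropout, and target modules are illustrative assumptions, not the exact settings used for this checkpoint.

```python
from transformers import AutoModelForTokenClassification
from peft import LoraConfig, TaskType, get_peft_model

# Binary token-classification head on top of the ESM-2 base model
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=2
)

# Illustrative LoRA settings (assumptions, not the values used for this adapter)
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],  # ESM-2 self-attention projections
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only LoRA weights and the classifier head train
```

The wrapped `model` can then be fine-tuned with the usual `transformers` `Trainer` on per-residue labels.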

### Framework versions

- PEFT 0.4.0

## Using the Model

To use the model, try running:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA adapter
model_path = "AmelieSchreiber/esm2_t30_150M_LoRA_RNA_binding"
# ESM-2 base model
base_model_path = "facebook/esm2_t30_150M_UR50D"

# Load the base model and apply the LoRA adapter
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input IDs back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted label for each residue
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
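
If you want binding-site positions rather than a per-token printout, a small post-processing sketch continuing from the variables in the block above could look like this:

```python
# Drop special tokens so indices line up with the input sequence,
# then collect 1-based residue positions predicted as binding sites.
residue_preds = [
    (token, int(pred))
    for token, pred in zip(tokens, predictions[0].numpy())
    if token not in ['<pad>', '<cls>', '<eos>']
]
binding_positions = [i for i, (_, pred) in enumerate(residue_preds, start=1) if pred == 1]
print(binding_positions)
```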