AmelieSchreiber committed 1de9930
1 parent: e60a6b0

Update README.md

Files changed (1): README.md +54 -0
README.md CHANGED
@@ -2,6 +2,10 @@
@@ -10,4 +14,54 @@ loss 0.25427
 license: mit
 ---
 
+ # ESM-2 for Post-Translational Modification
+
+ ## Metrics
+
 ```python
 accuracy 0.97515
 auc 0.91801
 
 mcc 0.27049
 precision 0.08791
 recall 0.86056
+ ```
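
The combination above (high recall, low precision, modest MCC on a very accurate classifier) is typical of heavily imbalanced token-level tasks. A minimal sketch of how these metrics relate to confusion counts, using made-up counts chosen only to illustrate the formulas, not this model's actual confusion matrix:

```python
import math

# Hypothetical confusion counts for an imbalanced binary token task.
# Illustrative numbers only, not the evaluation results of this model.
tp, fp, fn, tn = 860, 8920, 139, 990081

precision = tp / (tp + fp)  # few flagged tokens are true sites
recall = tp / (tp + fn)     # but most true sites are flagged
mcc = (tp * tn - fp * fn) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)

print(f"precision {precision:.5f}")
print(f"recall    {recall:.5f}")
print(f"mcc       {mcc:.5f}")
```

With counts in this regime, accuracy and recall stay high while precision stays low, the same pattern as the table above.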
+
+ ## Using the Model
+
+ To use the model, run:
+
+ ```python
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+ from peft import PeftModel
+ import torch
+
+ # Path to the saved LoRA adapter
+ model_path = "AmelieSchreiber/esm2_t30_150M_ptm_qlora_2100K_1000"
+ # ESM-2 base model
+ base_model_path = "facebook/esm2_t30_150M_UR50D"
+
+ # Load the base model and apply the LoRA adapter
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
+ loaded_model = PeftModel.from_pretrained(base_model, model_path)
+
+ # Ensure the model is in evaluation mode
+ loaded_model.eval()
+
+ # Load the tokenizer
+ loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+ # Protein sequence for inference
+ protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence
+
+ # Tokenize the sequence
+ inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
+
+ # Run the model without gradient tracking
+ with torch.no_grad():
+     logits = loaded_model(**inputs).logits
+
+ # Get per-token predictions
+ tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
+ predictions = torch.argmax(logits, dim=2)
+
+ # Define labels
+ id2label = {
+     0: "No ptm site",
+     1: "ptm site"
+ }
+
+ # Print the predicted label for each token, skipping special tokens
+ for token, prediction in zip(tokens, predictions[0].numpy()):
+     if token not in ['<pad>', '<cls>', '<eos>']:
+         print((token, id2label[prediction]))
 ```
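
The loop above prints one label per residue. To collect the predicted PTM site positions instead, one option is the following sketch, shown here with toy tokens and predictions standing in for the real model output:

```python
# Toy stand-ins for `tokens` and `predictions[0].numpy()` from the example above.
tokens = ["<cls>", "M", "A", "V", "S", "T", "<eos>", "<pad>"]
predictions = [0, 0, 0, 0, 1, 1, 0, 0]

special_tokens = {"<pad>", "<cls>", "<eos>"}

# Keep only real residues, preserving their order in the sequence.
residues = [(t, p) for t, p in zip(tokens, predictions) if t not in special_tokens]

# 1-based positions of residues predicted as PTM sites.
ptm_sites = [i + 1 for i, (_, p) in enumerate(residues) if p == 1]
print(ptm_sites)  # -> [4, 5]
```

Because the special tokens are filtered out before enumeration, the reported positions line up with residue numbering in the input sequence.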