AmelieSchreiber committed
Commit 0c083d0
1 Parent(s): 1f043b2

Update README.md

Files changed (1): README.md +51 -0
README.md CHANGED
@@ -44,6 +44,57 @@ Test:
 
 ## Using the Model
 
+ ### Using on your Protein Sequences
+
+ To use the model on one of your protein sequences, try running the following:
+
+ ```python
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+ from peft import PeftModel
+ import torch
+
+ # Path to the saved LoRA model
+ model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"
+ # ESM2 base model
+ base_model_path = "facebook/esm2_t12_35M_UR50D"
+
+ # Load the base model and apply the LoRA adapter
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
+ loaded_model = PeftModel.from_pretrained(base_model, model_path)
+
+ # Ensure the model is in evaluation mode
+ loaded_model.eval()
+
+ # Load the tokenizer
+ loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+ # Protein sequence for inference
+ protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence
+
+ # Tokenize the sequence
+ inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
+
+ # Run the model
+ with torch.no_grad():
+     logits = loaded_model(**inputs).logits
+
+ # Get predictions
+ tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
+ predictions = torch.argmax(logits, dim=2)
+
+ # Define labels
+ id2label = {
+     0: "No binding site",
+     1: "Binding site"
+ }
+
+ # Print the predicted label for each residue
+ for token, prediction in zip(tokens, predictions[0].numpy()):
+     if token not in ['<pad>', '<cls>', '<eos>']:
+         print((token, id2label[prediction]))
+ ```
+
+ ### Getting the Train/Test Metrics
 Head over to [here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family)
 to download the dataset first. Once you have the pickle files downloaded locally, run the following:
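
Building on the inference snippet added above, here is a minimal follow-up sketch for collecting the positions of predicted binding sites; it assumes the `tokens` and `predictions` variables (and the ESM2 special tokens) from that snippet:

```python
# A minimal sketch, reusing `tokens` and `predictions` from the snippet above:
# collect the 1-based positions of residues predicted as binding sites.
special_tokens = {"<pad>", "<cls>", "<eos>"}

# Keep only real residue tokens, in sequence order
residues = [
    (token, int(pred))
    for token, pred in zip(tokens, predictions[0].numpy())
    if token not in special_tokens
]

# 1-based indices of residues labeled 1 ("Binding site")
binding_positions = [pos for pos, (_, pred) in enumerate(residues, start=1) if pred == 1]
print("Predicted binding-site positions:", binding_positions)
```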