Commit
•
fac5c2a
1
Parent(s):
cd2a3ae
Update README.md
Browse files
README.md
CHANGED
@@ -19,4 +19,56 @@ epoch 3:
|
|
19 |
'eval_auc': 0.9730582015280457
|
20 |
```
|
21 |
|
|
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
'eval_auc': 0.9730582015280457
|
20 |
```
|
21 |
|
22 |
+
## Using the Model
|
23 |
|
24 |
+
To use the model, try running:
|
25 |
+
```python
|
26 |
+
import torch
|
27 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
28 |
+
|
29 |
+
def predict_binding_sites(model_path, protein_sequences):
    """
    Predict per-token binding-site labels for a collection of protein sequences.

    Parameters:
    - model_path (str): Path or hub identifier of the saved model; it is passed
      to ``from_pretrained`` for both the tokenizer and the model.
    - protein_sequences (List[str]): List of protein sequences.

    Returns:
    - List[List[str]]: Predicted labels for each sequence. Every inner list
      has the padded batch length, i.e. it contains one label per token
      position, including padding positions (and presumably any special
      tokens the tokenizer inserts -- confirm against the tokenizer in use).
      An empty input yields an empty list.
    """
    # Nothing to predict: return early so an empty batch never reaches the
    # tokenizer and no model weights are loaded for nothing.
    if not protein_sequences:
        return []

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)

    # Ensure model is in evaluation mode
    model.eval()

    # Tokenize sequences as one padded batch; truncation=True lets the
    # tokenizer cut sequences that exceed its configured maximum length.
    inputs = tokenizer(
        protein_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Move the input tensors to the same device as the model so the forward
    # pass also works when the model does not live on the CPU.
    inputs = inputs.to(model.device)

    # Obtain logits without tracking gradients (inference only)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Most likely class id per token position, as nested lists of plain ints
    predicted_ids = torch.argmax(logits, dim=-1).cpu().tolist()

    # Convert label IDs to human-readable labels
    id2label = model.config.id2label
    return [
        [id2label[label_id] for label_id in sequence_ids]
        for sequence_ids in predicted_ids
    ]
|
63 |
+
|
64 |
+
# Usage: run predict_binding_sites on a few example sequences.
# model_path is the first argument to predict_binding_sites below.
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites" # Replace with your model's path
# Example inputs: four sequences of similar length that differ from one
# another at only a handful of positions.
unseen_proteins = [
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD"
] # Replace with your unseen protein sequences
# Pass the whole list in a single call; the result is bound to predictions.
predictions = predict_binding_sites(model_path, unseen_proteins)
# Bare expression: a notebook or REPL displays the value; when run as a
# script this line has no visible effect (wrap it in print() to see it).
predictions
|
74 |
+
```
|