|
--- |
|
license: mit |
|
language: |
|
- en |
|
library_name: transformers |
|
tags: |
|
- esm |
|
- esm-2 |
|
- sequence classifier |
|
- proteins |
|
- protein language model |
|
pipeline_tag: text-classification
|
--- |
|
|
|
# ESM-2 Sequence Classifier |
|
This is a small sequence classifier trained on synthetic data generated by GPT-4. It classifies protein sequences into three categories: `enzymes` (class `0`), `receptor_proteins` (class `1`), and `structural_proteins` (class `2`).
|
It was trained by fine-tuning [facebook/esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D), one of the [ESM-2 models](https://huggingface.co/docs/transformers/model_doc/esm).
|
|
|
This model has not been thoroughly tested and is intended for experimental and educational purposes only. Use it with caution.
|
|
|
## Using the Model |
|
To use the model, try running: |
|
|
|
```python
import torch
from transformers import AutoTokenizer, EsmForSequenceClassification

# Load the trained classifier and the matching base-model tokenizer
model = EsmForSequenceClassification.from_pretrained("AmelieSchreiber/esm2_t6_8M_UR50D_sequence_classifier_v1")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
# Suppose these are your new sequences that you want to classify
# Additional Family 0: Enzymes
new_sequences_0 = [
    "ACGYLKTPKLADPPVLRGDSSVTKAICKPDPVLEK",
    "GVALDECKALDYLPGKPLPMDGKVCQCGSKTPLRP",
    "VLPGYTCGELDCKPGKPLPKCGADKTQVATPFLRG",
    "TCGALVQYPSCADPPVLRGSDSSVKACKKLDPQDK",
    "GALCEECKLCPGADYKPMDGDRLPAAATSKTRPVG",
    "PAVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYG",
    "VLGYTCGALDCKPGKPLPKCGADKTQVATPFLRGA",
    "CGALVQYPSCADPPVLRGSDSSVKACKKLDPQDKT",
    "ALCEECKLCPGADYKPMDGDRLPAAATSKTRPVGK",
    "AVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYGR",
]
|
# Additional Family 1: Receptor Proteins
new_sequences_1 = [
    "VGQRFYGGRQKNRHCELSPLPSACRGSVQGALYTD",
    "KDQVLTVPTYACRCCPKMDSKGRVPSTLRVKSARS",
    "PLAGVACGRGLDYRCPRKMVPGDLQVTPATQRPYG",
    "CGVRLGYPGCADVPLRGRSSFAPRACMKKDPRVTR",
    "RKGVAYLYECRKLRCRADYKPRGMDGRRLPKASTT",
    "RPTGAVNCKQAKVYRGLPLPMMGKVPRVCRSRRPY",
    "RLDGGYTCGQALDCKPGRKPPKMGCADLKSTVATP",
    "LGTCRKLVRYPQCADPPVMGRSSFRPKACCRQDPV",
    "RVGYAMCSPKLCSCRADYKPPMGDGDRLPKAATSK",
    "QPKAVNCRKAMVYRPKPLPMDKGVPVCRSKRPRPY",
]
|
# Additional Family 2: Structural Proteins
new_sequences_2 = [
    "VGKGFRYGSSQKRYLHCQKSALPPSCRRGKGQGSAT",
    "KDPTVMTVGTYSCQCPKQDSRGSVQPTSRVKTSRSK",
    "PLVGKACGRSSDYKCPGQMVSGGSKQTPASQRPSYD",
    "CGKKLVGYPSSKADVPLQGRSSFSPKACKKDPQMTS",
    "RKGVASLYCSSKLSCKAQYSKGMSDGRSPKASSTTS",
    "RPKSAASCEQAKSYRSLSLPSMKGKVPSKCSRSKRP",
    "RSDVSYTSCSQSKDCKPSKPPKMSGSKDSSTVATPS",
    "LSTCSKKVAYPSSKADPPSSGRSSFSMKACKKQDPPV",
    "RVGSASSEPKSSCSVQSYSKPSMSGDSSPKASSTSK",
    "QPSASNCEKMSSYRPSLPSMSKGVPSSRSKSSPPYQ",
]
|
# Merge all sequences
new_sequences = new_sequences_0 + new_sequences_1 + new_sequences_2

# Tokenize the sequences and convert to tensors
inputs = tokenizer(new_sequences, return_tensors="pt", padding=True, truncation=True)
|
# Use the model to get the logits (no gradient tracking needed for inference)
with torch.no_grad():
    logits = model(**inputs).logits
|
# Get the predicted class for each sequence
predicted_class_ids = torch.argmax(logits, dim=-1)
|
# Print the predicted class for each sequence
for sequence, predicted_class in zip(new_sequences, predicted_class_ids):
    print(f"Sequence: {sequence}, Predicted class: {predicted_class.item()}")
``` |
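
The checkpoint's config may not carry human-readable label names, so for readable output (and per-class probabilities rather than raw logits) you can apply the class mapping from this card yourself. A minimal sketch, reusing the `logits` and `new_sequences` variables from the example above; the `id2label` dictionary simply restates the mapping described at the top of this card and is not read from the model:

```python
import torch

# Class-index mapping as described in this card (an assumption, not read from the model config)
id2label = {0: "enzymes", 1: "receptor_proteins", 2: "structural_proteins"}

# Softmax turns the raw logits into per-class probabilities
probs = torch.softmax(logits, dim=-1)
for sequence, p in zip(new_sequences, probs):
    class_id = int(torch.argmax(p))
    print(f"{sequence} -> {id2label[class_id]} (p={p[class_id].item():.3f})")
```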
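
For a quick smoke test, the high-level `pipeline` API should also work, since this checkpoint is a standard `EsmForSequenceClassification` model. A minimal sketch; unless the hosted config defines an `id2label` mapping, the pipeline reports generic `LABEL_0`/`LABEL_1`/`LABEL_2` names corresponding to the class indices above:

```python
from transformers import pipeline

# Quick start via the text-classification pipeline;
# the tokenizer is taken from the ESM-2 base model, as in the full example above
classifier = pipeline(
    "text-classification",
    model="AmelieSchreiber/esm2_t6_8M_UR50D_sequence_classifier_v1",
    tokenizer="facebook/esm2_t6_8M_UR50D",
)
print(classifier("ACGYLKTPKLADPPVLRGDSSVTKAICKPDPVLEK"))
```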
|
|