---
license: mit
language:
- en
library_name: transformers
tags:
- esm
- esm-2
- sequence classifier
- proteins
- protein language model
pipeline_tag: zero-shot-classification
---

# ESM-2 Sequence Classifier

This is a small sequence classifier trained on synthetic data generated by GPT-4. It classifies protein sequences into three categories: `enzymes` (class `0`), `receptor_proteins` (class `1`), and `structural_proteins` (class `2`). It was trained using [facebook/esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D), one of the [ESM-2 models](https://huggingface.co/docs/transformers/model_doc/esm). This model is not well tested and is intended for experimental and educational purposes only. Use with caution.

## Using the Model

To use the model, try running:

```python
import torch
from transformers import AutoTokenizer, EsmForSequenceClassification

# Load the trained model and tokenizer
model = EsmForSequenceClassification.from_pretrained("AmelieSchreiber/esm2_t6_8M_UR50D_sequence_classifier_v1")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Suppose these are your new sequences that you want to classify
# Additional Family 0: Enzymes
new_sequences_0 = [
    "ACGYLKTPKLADPPVLRGDSSVTKAICKPDPVLEK",
    "GVALDECKALDYLPGKPLPMDGKVCQCGSKTPLRP",
    "VLPGYTCGELDCKPGKPLPKCGADKTQVATPFLRG",
    "TCGALVQYPSCADPPVLRGSDSSVKACKKLDPQDK",
    "GALCEECKLCPGADYKPMDGDRLPAAATSKTRPVG",
    "PAVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYG",
    "VLGYTCGALDCKPGKPLPKCGADKTQVATPFLRGA",
    "CGALVQYPSCADPPVLRGSDSSVKACKKLDPQDKT",
    "ALCEECKLCPGADYKPMDGDRLPAAATSKTRPVGK",
    "AVDCKKALVYLPKPLPMDGKVCRGSKTPKTRPYGR",
]

# Additional Family 1: Receptor Proteins
new_sequences_1 = [
    "VGQRFYGGRQKNRHCELSPLPSACRGSVQGALYTD",
    "KDQVLTVPTYACRCCPKMDSKGRVPSTLRVKSARS",
    "PLAGVACGRGLDYRCPRKMVPGDLQVTPATQRPYG",
    "CGVRLGYPGCADVPLRGRSSFAPRACMKKDPRVTR",
    "RKGVAYLYECRKLRCRADYKPRGMDGRRLPKASTT",
    "RPTGAVNCKQAKVYRGLPLPMMGKVPRVCRSRRPY",
    "RLDGGYTCGQALDCKPGRKPPKMGCADLKSTVATP",
    "LGTCRKLVRYPQCADPPVMGRSSFRPKACCRQDPV",
    "RVGYAMCSPKLCSCRADYKPPMGDGDRLPKAATSK",
    "QPKAVNCRKAMVYRPKPLPMDKGVPVCRSKRPRPY",
]

# Additional Family 2: Structural Proteins
new_sequences_2 = [
    "VGKGFRYGSSQKRYLHCQKSALPPSCRRGKGQGSAT",
    "KDPTVMTVGTYSCQCPKQDSRGSVQPTSRVKTSRSK",
    "PLVGKACGRSSDYKCPGQMVSGGSKQTPASQRPSYD",
    "CGKKLVGYPSSKADVPLQGRSSFSPKACKKDPQMTS",
    "RKGVASLYCSSKLSCKAQYSKGMSDGRSPKASSTTS",
    "RPKSAASCEQAKSYRSLSLPSMKGKVPSKCSRSKRP",
    "RSDVSYTSCSQSKDCKPSKPPKMSGSKDSSTVATPS",
    "LSTCSKKVAYPSSKADPPSSGRSSFSMKACKKQDPPV",
    "RVGSASSEPKSSCSVQSYSKPSMSGDSSPKASSTSK",
    "QPSASNCEKMSSYRPSLPSMSKGVPSSRSKSSPPYQ",
]

# Merge all sequences
new_sequences = new_sequences_0 + new_sequences_1 + new_sequences_2

# Tokenize the sequences and convert to tensors
inputs = tokenizer(new_sequences, return_tensors="pt", padding=True, truncation=True)

# Use the model to get the logits
with torch.no_grad():
    logits = model(**inputs).logits

# Get the predicted class for each sequence
predicted_class_ids = torch.argmax(logits, dim=-1)

# Print the predicted class for each sequence
for sequence, predicted_class in zip(new_sequences, predicted_class_ids):
    print(f"Sequence: {sequence}, Predicted class: {predicted_class.item()}")
```
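
If you also want human-readable labels and per-class probabilities rather than raw class IDs, a minimal follow-up sketch (continuing from the snippet above, and assuming the class-to-label mapping described earlier: `0 = enzymes`, `1 = receptor_proteins`, `2 = structural_proteins`) could look like:

```python
import torch.nn.functional as F

# Assumed label mapping, taken from the class descriptions in this card
id2label = {0: "enzymes", 1: "receptor_proteins", 2: "structural_proteins"}

# Convert the logits from the previous snippet into per-class probabilities
probabilities = F.softmax(logits, dim=-1)

for sequence, probs in zip(new_sequences, probabilities):
    class_id = int(torch.argmax(probs))
    print(f"Sequence: {sequence}, Predicted: {id2label[class_id]} (p={probs[class_id].item():.3f})")
```

The softmax probabilities are only as reliable as the synthetic training data, so treat them as rough scores rather than calibrated confidences.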