--- license: mit --- # ESM-2 for General Protein Binding Site Prediction This model is trained to predict general binding sites of proteins using on the sequence. This is a finetuned version of `esm2_t6_8M_UR50D`, trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). The data is not filtered by family, and thus the model may be overfit to some degree. ## Training ``` epoch 3: 'eval_loss': 0.08215777575969696, 'eval_precision': 0.4673852829840273, 'eval_recall': 0.9587594696969697, 'eval_f1': 0.6284215753212091, 'eval_auc': 0.9730582015280457 ``` ## Using the Model Try pasting a protein sequence into the cell on the right and clicking on "Compute". For example, try ``` MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD ``` To use the model, try running: ```python import torch from transformers import AutoModelForTokenClassification, AutoTokenizer def predict_binding_sites(model_path, protein_sequences): """ Predict binding sites for a collection of protein sequences. Parameters: - model_path (str): Path to the saved model. - protein_sequences (List[str]): List of protein sequences. Returns: - List[List[str]]: Predicted labels for each sequence. """ # Load tokenizer and model tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForTokenClassification.from_pretrained(model_path) # Ensure model is in evaluation mode model.eval() # Tokenize sequences inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True) # Move to the same device as model and obtain logits with torch.no_grad(): logits = model(**inputs).logits # Obtain predicted labels predicted_labels = torch.argmax(logits, dim=-1).cpu().numpy() # Convert label IDs to human-readable labels id2label = model.config.id2label human_readable_labels = [[id2label[label_id] for label_id in sequence] for sequence in predicted_labels] return human_readable_labels # Usage: model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites" # Replace with your model's path unseen_proteins = [ "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD", "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD", "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD", "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD" ] # Replace with your unseen protein sequences predictions = predict_binding_sites(model_path, unseen_proteins) predictions ```