---
license: mit
---
# ESM-2 for Predicting Binding Sites

This is the 650M-parameter version of ESM-2, fine-tuned with QLoRA to predict protein binding sites from the single sequence alone. No multiple sequence alignment or structure is required. The embeddings from this model can also be used in structural models. The model was trained on approximately 12M protein sequences from UniProt, with an 80/20 train/test split.

## Metrics

### Train Metrics (Based on a 40% sample)

```python
'eval_loss': 0.05597764626145363,
'eval_accuracy': 0.9829392036087405,
'eval_precision': 0.5626191259397847,
'eval_recall': 0.9488112528941492,
'eval_f1': 0.7063763773187873,
'eval_auc': 0.9662524626230765,
'eval_mcc': 0.7235838533979579
```

### Test Metrics

Due to the size of the dataset, we computed the test metrics in chunks and aggregated them. To see the metrics for each chunk, [refer to this text file](https://huggingface.co/AmelieSchreiber/esm2_t33_650M_qlora_binding_12M/blob/main/test_metrics.txt).

```python
'eval_loss': 0.16281947493553162,
'eval_accuracy': 0.9569658774883986,
'eval_precision': 0.3209956738348438,
'eval_recall': 0.7883697002335764,
'eval_f1': 0.4562306866120791,
'eval_auc': 0.8746433990040084,
'eval_mcc': 0.48648765699020435
```

The metrics for the earlier checkpoints are not reported here yet.

## Using the Model

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA adapter
model_path = "AmelieSchreiber/esm2_t33_650M_qlora_binding_12M"
# ESM-2 base model
base_model_path = "facebook/esm2_t33_650M_UR50D"

# Load the base model and apply the LoRA adapter
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input IDs back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted label for each residue, skipping special tokens
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<cls>', '<pad>', '<eos>']:
        print((token, id2label[prediction]))
```
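
If you want the positions of the predicted binding residues rather than a per-token printout, you can filter out the tokenizer's special tokens and collect the residue indices. This is a minimal sketch that continues from the snippet above, reusing `tokens` and `predictions`; the `binding_residues` name is just for illustration.

```python
# Collect 1-based residue positions predicted as binding sites.
# Reuses `loaded_tokenizer`, `tokens`, and `predictions` from the snippet above.
special_tokens = set(loaded_tokenizer.all_special_tokens)  # e.g. <cls>, <pad>, <eos>

binding_residues = []  # list of (position, residue) pairs; name is illustrative
position = 0
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token in special_tokens:
        continue
    position += 1
    if prediction == 1:
        binding_residues.append((position, token))

print(binding_residues)
```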
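
Since the per-residue embeddings can be fed into downstream structural models, here is a minimal sketch of how you might extract them, again continuing from the snippet above. It assumes the standard `output_hidden_states=True` argument of the Transformers forward pass; the last entry of `hidden_states` is the final-layer representation (hidden size 1280 for the 650M model).

```python
# Per-residue embeddings for use in downstream (e.g. structural) models.
# Continues from the snippet above (`loaded_model` and `inputs` already defined).
with torch.no_grad():
    outputs = loaded_model(**inputs, output_hidden_states=True)

# hidden_states is a tuple of tensors with shape
# (batch_size, sequence_length, hidden_size); take the final layer.
residue_embeddings = outputs.hidden_states[-1]
print(residue_embeddings.shape)
```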