--- widget: - text: "MEPLDDLDLLLLEEDSGAEAVPRMEILQKKADAFFAETVLSRGVDNRYLVLAVETKLNERGAEEKHLLITVSQEGEQEVLCILRNGWSSVPVEPGDIIHIEGDCTSEPWIVDDDFGYFILSPDMLISGTSVASSIRCLRRAVLSETFRVSDTATRQMLIGTILHEVFQKAISESFAPEKLQELALQTLREVRHLKEMYRLNLSQDEVRCEVEEYLPSFSKWADEFMHKGTKAEFPQMHLSLPSDSSDRSSPCNIEVVKSLDIEESIWSPRFGLKGKIDVTVGVKIHRDCKTKYKIMPLELKTGKESNSIEHRGQVILYTLLSQERREDPEAGWLLYLKTGQMYPVPANHLDKRELLKLRNQLAFSLLHRVSRAAAGEEARLLALPQIIEEEKTCKYCSQMGNCALYSRAVEQVHDTSIPEGMRSKIQEGTQHLTRAHLKYFSLWCLMLTLESQSKDTKKSHQSIWLTPASKLEESGNCIGSLVRTEPVKRVCDGHYLHNFQRKNGPMPATNLMAGDRIILSGEERKLFALSKGYVKRIDTAAVTCLLDRNLSTLPETTLFRLDREEKHGDINTPLGNLSKLMENTDSSKRLRELIIDFKEPQFIAYLSSVLPHDAKDTVANILKGLNKPQRQAMKKVLLSKDYTLIVGMPGTGKTTTICALVRILSACGFSVLLTSYTHSAVDNILLKLAKFKIGFLRLGQSHKVHPDIQKFTEEEMCRLRSIASLAHLEELYNSHPVVATTCMGISHPMFSRKTFDFCIVDEASQISQPICLGPLFFSRRFVLVGDHKQLPPLVLNREARALGMSESLFKRLERNESAVVQLTIQYRMNRKIMSLSNKLTYEGKLECGSDRVANAVITLPNLKDVRLEFYADYSDNPWLAGVFEPDNPVCFLNTDKVPAPEQIENGGVSNVTEARLIVFLTSTFIKAGCSPSDIGIIAPYRQQLRTITDLLARSSVGMVEVNTVDKYQGRDKSLILVSFVRSNEDGTLGELLKDWRRLNVAITRAKHKLILLGSVSSLKRF" example_title: "Protein Sequence 1" - text: "MNSVTVSHAPYYIVYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPPKFFIQLKQMLRNKRVCVCGILPYPIDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINELLELDNKVPINWAQGFIY" example_title: "Protein Sequence 2" - text: "MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY" example_title: "Protein Sequence 3" license: mit language: - en metrics: - f1 - accuracy - precision - recall - matthews_correlation - roc_auc library_name: peft tags: - ESM-2 - protein language model - biology - binding sites --- ## Training: For a report on the training [please see here](https://api.wandb.ai/links/amelie-schreiber-math/84t5gsfm) and [here](https://wandb.ai/amelie-schreiber-math/huggingface/reports/ESM-2-Binding-Sites-Predictor-Scaling-Up--Vmlldzo1Mzc3MTAz?accessToken=cbl9v3bvuq65j5t4qo9l0bhccm3hrse8nt01t3dka6h6zb0azzakahnxdxfrb28m). ## Metrics: ```python Train: ({'accuracy': 0.9406146072672105, 'precision': 0.2947122459102886, 'recall': 0.952624323712029, 'f1': 0.4501592605994876, 'auc': 0.9464622170085311, 'mcc': 0.5118390407598565}, Test: {'accuracy': 0.9266827008067329, 'precision': 0.22378953253253775, 'recall': 0.7790246675002842, 'f1': 0.3476966444342296, 'auc': 0.8547531675185658, 'mcc': 0.3930283737012391}) ``` ## Using the Model ### Using on your Protein Sequences To use the model on one of your protein sequences try running the following: ```python from transformers import AutoModelForTokenClassification, AutoTokenizer from peft import PeftModel import torch # Path to the saved LoRA model model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1" # ESM2 base model base_model_path = "facebook/esm2_t12_35M_UR50D" # Load the model base_model = AutoModelForTokenClassification.from_pretrained(base_model_path) loaded_model = PeftModel.from_pretrained(base_model, model_path) # Ensure the model is in evaluation mode loaded_model.eval() # Load the tokenizer loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path) # Protein sequence for inference protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence # Tokenize the sequence inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length') # Run the model with torch.no_grad(): logits = loaded_model(**inputs).logits # Get predictions tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens predictions = torch.argmax(logits, dim=2) # Define labels id2label = { 0: "No binding site", 1: "Binding site" } # Print the predicted labels for each token for token, prediction in zip(tokens, predictions[0].numpy()): if token not in ['', '', '']: print((token, id2label[prediction])) ``` ### Getting the Train/Test Metrics: Head over to [here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family) to download the dataset first. Once you have the pickle files downloaded locally, run the following: ```python from datasets import Dataset from transformers import AutoTokenizer import pickle # Load tokenizer tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") # Function to truncate labels def truncate_labels(labels, max_length): """Truncate labels to the specified max_length.""" return [label[:max_length] for label in labels] # Set the maximum sequence length max_sequence_length = 1000 # Load the data from pickle files with open("train_sequences_chunked_by_family.pkl", "rb") as f: train_sequences = pickle.load(f) with open("test_sequences_chunked_by_family.pkl", "rb") as f: test_sequences = pickle.load(f) with open("train_labels_chunked_by_family.pkl", "rb") as f: train_labels = pickle.load(f) with open("test_labels_chunked_by_family.pkl", "rb") as f: test_labels = pickle.load(f) # Tokenize the sequences train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False) test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False) # Truncate the labels to match the tokenized sequence lengths train_labels = truncate_labels(train_labels, max_sequence_length) test_labels = truncate_labels(test_labels, max_sequence_length) # Create train and test datasets train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels) test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels) ``` Then run the following to get the train/test metrics: ```python from sklearn.metrics import( matthews_corrcoef, accuracy_score, precision_recall_fscore_support, roc_auc_score ) from peft import PeftModel from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification from transformers import Trainer from accelerate import Accelerator # Instantiate the accelerator accelerator = Accelerator() # Define paths to the LoRA and base models base_model_path = "facebook/esm2_t12_35M_UR50D" lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1" # "path/to/your/lora/model" Replace with the correct path to your LoRA model # Load the base model base_model = AutoModelForTokenClassification.from_pretrained(base_model_path) # Load the LoRA model model = PeftModel.from_pretrained(base_model, lora_model_path) model = accelerator.prepare(model) # Prepare the model using the accelerator # Define label mappings id2label = {0: "No binding site", 1: "Binding site"} label2id = {v: k for k, v in id2label.items()} # Create a data collator data_collator = DataCollatorForTokenClassification(tokenizer) # Define a function to compute the metrics def compute_metrics(dataset): # Get the predictions using the trained model trainer = Trainer(model=model, data_collator=data_collator) predictions, labels, _ = trainer.predict(test_dataset=dataset) # Remove padding and special tokens mask = labels != -100 true_labels = labels[mask].flatten() flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist() # Compute the metrics accuracy = accuracy_score(true_labels, flat_predictions) precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary') auc = roc_auc_score(true_labels, flat_predictions) mcc = matthews_corrcoef(true_labels, flat_predictions) # Compute the MCC return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc} # Include the MCC in the returned dictionary # Get the metrics for the training and test datasets train_metrics = compute_metrics(train_dataset) test_metrics = compute_metrics(test_dataset) train_metrics, test_metrics ```