|
--- |
|
widget: |
|
- text: "MEPLDDLDLLLLEEDSGAEAVPRMEILQKKADAFFAETVLSRGVDNRYLVLAVETKLNERGAEEKHLLITVSQEGEQEVLCILRNGWSSVPVEPGDIIHIEGDCTSEPWIVDDDFGYFILSPDMLISGTSVASSIRCLRRAVLSETFRVSDTATRQMLIGTILHEVFQKAISESFAPEKLQELALQTLREVRHLKEMYRLNLSQDEVRCEVEEYLPSFSKWADEFMHKGTKAEFPQMHLSLPSDSSDRSSPCNIEVVKSLDIEESIWSPRFGLKGKIDVTVGVKIHRDCKTKYKIMPLELKTGKESNSIEHRGQVILYTLLSQERREDPEAGWLLYLKTGQMYPVPANHLDKRELLKLRNQLAFSLLHRVSRAAAGEEARLLALPQIIEEEKTCKYCSQMGNCALYSRAVEQVHDTSIPEGMRSKIQEGTQHLTRAHLKYFSLWCLMLTLESQSKDTKKSHQSIWLTPASKLEESGNCIGSLVRTEPVKRVCDGHYLHNFQRKNGPMPATNLMAGDRIILSGEERKLFALSKGYVKRIDTAAVTCLLDRNLSTLPETTLFRLDREEKHGDINTPLGNLSKLMENTDSSKRLRELIIDFKEPQFIAYLSSVLPHDAKDTVANILKGLNKPQRQAMKKVLLSKDYTLIVGMPGTGKTTTICALVRILSACGFSVLLTSYTHSAVDNILLKLAKFKIGFLRLGQSHKVHPDIQKFTEEEMCRLRSIASLAHLEELYNSHPVVATTCMGISHPMFSRKTFDFCIVDEASQISQPICLGPLFFSRRFVLVGDHKQLPPLVLNREARALGMSESLFKRLERNESAVVQLTIQYRMNRKIMSLSNKLTYEGKLECGSDRVANAVITLPNLKDVRLEFYADYSDNPWLAGVFEPDNPVCFLNTDKVPAPEQIENGGVSNVTEARLIVFLTSTFIKAGCSPSDIGIIAPYRQQLRTITDLLARSSVGMVEVNTVDKYQGRDKSLILVSFVRSNEDGTLGELLKDWRRLNVAITRAKHKLILLGSVSSLKRF" |
|
example_title: "Protein Sequence 1" |
|
- text: "MNSVTVSHAPYYIVYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPPKFFIQLKQMLRNKRVCVCGILPYPIDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINELLELDNKVPINWAQGFIY" |
|
example_title: "Protein Sequence 2" |
|
- text: "MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY" |
|
example_title: "Protein Sequence 3" |
|
license: mit |
|
language: |
|
- en |
|
metrics: |
|
- f1 |
|
- accuracy |
|
- precision |
|
- recall |
|
- matthews_correlation |
|
- roc_auc |
|
library_name: peft |
|
tags: |
|
- ESM-2 |
|
- protein language model |
|
- biology |
|
- binding sites |
|
--- |
|
|
|
## Training: |
|
|
|
Training reports are available [here](https://api.wandb.ai/links/amelie-schreiber-math/84t5gsfm) and
|
[here](https://wandb.ai/amelie-schreiber-math/huggingface/reports/ESM-2-Binding-Sites-Predictor-Scaling-Up--Vmlldzo1Mzc3MTAz?accessToken=cbl9v3bvuq65j5t4qo9l0bhccm3hrse8nt01t3dka6h6zb0azzakahnxdxfrb28m). |
|
|
|
|
|
## Metrics: |
|
|
|
```python |
|
Train: |
|
({'accuracy': 0.9406146072672105, |
|
'precision': 0.2947122459102886, |
|
'recall': 0.952624323712029, |
|
'f1': 0.4501592605994876, |
|
'auc': 0.9464622170085311, |
|
'mcc': 0.5118390407598565}, |
|
Test: |
|
{'accuracy': 0.9266827008067329, |
|
'precision': 0.22378953253253775, |
|
'recall': 0.7790246675002842, |
|
'f1': 0.3476966444342296, |
|
'auc': 0.8547531675185658, |
|
'mcc': 0.3930283737012391}) |
|
``` |
|
|
|
## Using the Model |
|
|
|
### Using on your Protein Sequences |
|
|
|
To run the model on one of your own protein sequences, use the following:
|
|
|
```python |
|
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA adapter
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"
# ESM-2 base model the adapter is applied to
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model and wrap it with the LoRA adapter
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Put the model in evaluation mode (disables dropout)
loaded_model.eval()

# Load the tokenizer of the base model
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence. A single sequence needs no padding; padding it to
# 1024 tokens only adds compute for positions that are skipped below anyway.
# max_length counts the special tokens the tokenizer adds, so very long
# sequences are silently truncated.
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024)

# Run the model without tracking gradients
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Convert input ids back to tokens and pick the highest-scoring class per token
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site",
}

# Special tokens do not correspond to residues, so they are not printed
special_tokens = {"<pad>", "<cls>", "<eos>"}

# Print the predicted label for each residue token (.tolist() yields plain ints)
for token, prediction in zip(tokens, predictions[0].tolist()):
    if token not in special_tokens:
        print((token, id2label[prediction]))
|
``` |
|
|
|
### Getting the Train/Test Metrics: |
|
First, download the [binding_sites_random_split_by_family dataset](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family).

Once the pickle files are available locally, run the following:
|
|
|
```python |
|
import pickle

from datasets import Dataset
from transformers import AutoTokenizer

# Tokenizer for the ESM-2 base checkpoint (facebook/esm2_t12_35M_UR50D)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
|
|
|
# Helper: clip each sequence's label list to a maximum length
def truncate_labels(labels, max_length):
    """Return a new list with every label sequence cut to at most ``max_length`` items.

    Each inner sequence is sliced, so it keeps its original type (list, tuple, str, ...).
    """
    clip = slice(max_length)
    return [per_sequence[clip] for per_sequence in labels]
|
|
|
# Set the maximum sequence length (in tokens)
max_sequence_length = 1000


def _load_pickle(path):
    """Load and return the single object pickled in ``path``.

    SECURITY: pickle.load can execute arbitrary code embedded in the file.
    Only load pickle files that come from a source you trust.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Load the data from pickle files
train_sequences = _load_pickle("train_sequences_chunked_by_family.pkl")
test_sequences = _load_pickle("test_sequences_chunked_by_family.pkl")
train_labels = _load_pickle("train_labels_chunked_by_family.pkl")
test_labels = _load_pickle("test_labels_chunked_by_family.pkl")

# Tokenize the sequences; padding=True pads each split to its longest sequence
tokenize_kwargs = dict(padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt")
train_tokenized = tokenizer(train_sequences, **tokenize_kwargs)
test_tokenized = tokenizer(test_sequences, **tokenize_kwargs)

# Truncate the labels to match the tokenized sequence lengths.
# NOTE(review): labels are attached below without a leading -100 for <cls>, so
# if each label list has one entry per residue, label i lines up with token i
# (i.e. residue i - 1). Also, max_length counts <cls>/<eos>, so at most
# max_sequence_length - 2 residues fit, yet labels are clipped to
# max_sequence_length. Presumably the training pipeline used the same
# convention (kept as-is to reproduce the reported metrics) — confirm.
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Create train and test datasets
train_dataset = Dataset.from_dict(dict(train_tokenized)).add_column("labels", train_labels)
test_dataset = Dataset.from_dict(dict(test_tokenized)).add_column("labels", test_labels)
|
``` |
|
|
|
Then run the following to get the train/test metrics: |
|
|
|
```python |
|
from accelerate import Accelerator
from peft import PeftModel
from sklearn.metrics import (
    accuracy_score,
    matthews_corrcoef,
    precision_recall_fscore_support,
    roc_auc_score,
)
from transformers import (
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    Trainer,
)

# The accelerator takes care of placing the model on the available device
accelerator = Accelerator()

# Base ESM-2 checkpoint and the LoRA adapter trained on top of it
# (replace lora_model_path with the path to your own LoRA model if needed)
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"

# Wrap the base model with the LoRA adapter, then let the accelerator prepare it
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = accelerator.prepare(PeftModel.from_pretrained(base_model, lora_model_path))

# Label mappings in both directions
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {name: idx for idx, name in id2label.items()}

# Batches examples and pads labels (with -100 by default); reuses the
# tokenizer loaded in the dataset-preparation step
data_collator = DataCollatorForTokenClassification(tokenizer)
|
|
|
# Define a function to compute the metrics
def compute_metrics(dataset):
    """Run the LoRA model over ``dataset`` and score its token-level predictions.

    Args:
        dataset: A tokenized dataset with a ``labels`` column, as built in the
            dataset-preparation step.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``f1``, ``auc`` and
        ``mcc``, computed over every position whose label is not -100.
    """
    # Imported locally so this snippet is self-contained (np.argmax is needed below).
    import numpy as np

    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Drop padded / ignored positions, which carry the label -100
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average="binary")
    # NOTE: roc_auc_score receives hard 0/1 predictions rather than class
    # probabilities, so this "auc" equals balanced accuracy ((TPR + TNR) / 2),
    # not a threshold-free ROC AUC. Kept so results stay comparable with the
    # reported metrics.
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}
|
|
|
# Score the model on both splits (train first, then test)
train_metrics, test_metrics = compute_metrics(train_dataset), compute_metrics(test_dataset)

# Display both result dictionaries
train_metrics, test_metrics
|
``` |