AmelieSchreiber's picture
Update README.md
7fe197f
|
raw
history blame
4.92 kB
metadata
license: mit

Training:

For reports on the training, please see here and here.

Metrics:

Train:

{'accuracy': 0.9406146072672105,
 'precision': 0.2947122459102886,
 'recall': 0.952624323712029,
 'f1': 0.4501592605994876,
 'auc': 0.9464622170085311,
 'mcc': 0.5118390407598565}

Test:

{'accuracy': 0.9266827008067329,
 'precision': 0.22378953253253775,
 'recall': 0.7790246675002842,
 'f1': 0.3476966444342296,
 'auc': 0.8547531675185658,
 'mcc': 0.3930283737012391}

Using the Model

First, download the dataset from here. Once the pickle files are saved locally, run the following:

from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Tokenizer for the ESM-2 checkpoint (the same checkpoint used as the base model
# when evaluating the LoRA adapter)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="facebook/esm2_t12_35M_UR50D"
)

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

# Maximum length passed to the tokenizer and used to clip the label lists.
# NOTE(review): the tokenizer's ``max_length`` also counts any special tokens it
# adds (for ESM-2 presumably a leading <cls> and trailing <eos>), whereas
# truncate_labels counts raw labels, so token i and label i may be offset --
# confirm this matches how the model was trained before changing it.
max_sequence_length = 1000


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only load pickle files obtained from a source you trust.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Load the sequences and label lists from the downloaded pickle files
train_sequences = _load_pickle("train_sequences_chunked_by_family.pkl")
test_sequences = _load_pickle("test_sequences_chunked_by_family.pkl")
train_labels = _load_pickle("train_labels_chunked_by_family.pkl")
test_labels = _load_pickle("test_labels_chunked_by_family.pkl")

# Shared tokenizer settings: pad to the longest sequence in each call, truncate
# to max_sequence_length, and return PyTorch tensors.
_tokenizer_kwargs = {
    "padding": True,
    "truncation": True,
    "max_length": max_sequence_length,
    "return_tensors": "pt",
}
train_tokenized = tokenizer(train_sequences, **_tokenizer_kwargs)
test_tokenized = tokenizer(test_sequences, **_tokenizer_kwargs)

# Clip the label lists to the same maximum length
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Build the datasets: tokenizer outputs as columns plus a "labels" column
train_dataset = Dataset.from_dict(dict(train_tokenized)).add_column("labels", train_labels)
test_dataset = Dataset.from_dict(dict(test_tokenized)).add_column("labels", test_labels)

Then run the following to get the train/test metrics:

from sklearn.metrics import(
    matthews_corrcoef, 
    accuracy_score, 
    precision_recall_fscore_support, 
    roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification
from transformers import Trainer
from accelerate import Accelerator

# Accelerator instance; used below to prepare the model
accelerator = Accelerator()

# Checkpoints: the ESM-2 base model and the LoRA adapter applied on top of it.
# To evaluate a different adapter, point lora_model_path at it instead.
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"

# Load the base token-classification model, wrap it with the LoRA adapter,
# and hand the wrapped model to the accelerator
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = accelerator.prepare(
    PeftModel.from_pretrained(model=base_model, model_id=lora_model_path)
)

# Mapping between class index and human-readable label, and its inverse
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {name: index for index, name in id2label.items()}

# Collator that pads batches for token classification
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

def compute_metrics(dataset):
    """Run the LoRA model on ``dataset`` and return token-level metrics.

    Predictions come from ``Trainer.predict`` using the module-level ``model``
    and ``data_collator``. Positions whose label is -100 (padding / ignored
    tokens) are dropped before scoring.

    Args:
        dataset: Dataset holding the model inputs and a ``labels`` column.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``auc`` and ``mcc``. Precision, recall and F1 use
        ``average='binary'``, i.e. they score the positive class (label 1).
    """
    # The original snippet used ``np`` without importing it (NameError).
    import numpy as np

    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Keep only positions with a real label; boolean-mask indexing already
    # yields flat 1-D arrays.
    mask = labels != -100
    true_labels = labels[mask]
    # argmax over the class axis gives the predicted class for every token
    flat_predictions = np.argmax(predictions, axis=-1)[mask]

    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, flat_predictions, average="binary"
    )
    # NOTE: AUC is computed from hard 0/1 predictions, which makes it equal to
    # balanced accuracy ((TPR + TNR) / 2). Kept as-is for comparability with
    # the reported metrics; pass the softmax probability of class 1 instead
    # for a threshold-free ROC AUC.
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
        "mcc": mcc,
    }

from pprint import pprint

# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

# A bare ``train_metrics, test_metrics`` expression only shows output as the
# last line of a notebook cell; pprint also reports the results when this is
# run as a plain script.
pprint((train_metrics, test_metrics))