AmelieSchreiber's picture
Rename metrics (2).py to metrics.py
0def08a
raw
history blame contribute delete
No virus
4.61 kB
import os
import wandb
import numpy as np
import pickle
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer
from datasets import Dataset
from accelerate import Accelerator
from peft import PeftModel
# Helper functions and data preparation
def truncate_labels(labels, max_length):
    """Clip each per-sequence label list to its first ``max_length`` entries.

    Args:
        labels: Iterable of sliceable label sequences, one per input sequence.
        max_length: Maximum number of leading labels to keep for each sequence.

    Returns:
        A new list holding each element of ``labels`` sliced to ``[:max_length]``.
    """
    keep_prefix = slice(None, max_length)
    return [seq_labels[keep_prefix] for seq_labels in labels]
def compute_metrics(p):
    """Compute token-level metrics for a binary token-classification task.

    Args:
        p: A ``(predictions, labels)`` pair (e.g. a ``transformers.EvalPrediction``).
            ``predictions`` holds per-token class scores with the class axis last
            (argmax is taken over it); ``labels`` holds integer labels with the
            same leading shape, where ``-100`` marks positions to ignore.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall``, ``f1``, ``auc`` and ``mcc``
        computed over the non-ignored tokens. ``auc`` is the ROC AUC of the
        softmax probability of class 1, or ``nan`` when only one class is present
        among the scored tokens.
    """
    logits, labels = p
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    # Drop padding / ignored positions (-100 labels); boolean-mask indexing
    # leaves one row of class scores per kept token.
    mask = labels != -100
    token_logits = logits[mask].astype(np.float64)
    y_true = labels[mask]
    y_pred = np.argmax(token_logits, axis=-1)

    accuracy = accuracy_score(y_true, y_pred)
    # zero_division=0 reproduces the default value (0.0) without emitting warnings.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )

    # ROC AUC must be computed from a continuous score: hard 0/1 predictions
    # collapse the curve to a single operating point. Use the positive-class
    # (label 1) softmax probability instead.
    if np.unique(y_true).size < 2:
        # roc_auc_score raises ValueError when y_true contains a single class.
        auc = float("nan")
    else:
        shifted = token_logits - token_logits.max(axis=-1, keepdims=True)  # numerically stable softmax
        exp_scores = np.exp(shifted)
        positive_prob = exp_scores[:, 1] / exp_scores.sum(axis=-1)
        auc = roc_auc_score(y_true, positive_prob)

    mcc = matthews_corrcoef(y_true, y_pred)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
class WeightedTrainer(Trainer):
    """``Trainer`` subclass with a custom token-classification loss.

    The loss is a plain ``nn.CrossEntropyLoss`` over every token. Labels at
    positions whose ``attention_mask`` is not 1 are replaced by the loss's
    ``ignore_index``, so padding never contributes.

    NOTE(review): despite the class name, no class weights are applied. Pass
    ``weight=...`` to ``nn.CrossEntropyLoss`` if weighting is intended.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the attention-masked cross-entropy loss for one batch.

        Args:
            model: The model being trained/evaluated; called as ``model(**inputs)``.
            inputs: Batch mapping; must contain ``labels`` and ``attention_mask``.
            return_outputs: If True, also return the model outputs.
            num_items_in_batch: Accepted for compatibility with ``transformers``
                releases whose ``Trainer`` passes it to ``compute_loss``. It is
                unused here because the loss stays a per-batch mean.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        outputs = model(**inputs)
        loss_fct = nn.CrossEntropyLoss()
        flat_labels = inputs["labels"].reshape(-1)
        active_loss = inputs["attention_mask"].reshape(-1) == 1
        active_logits = outputs.logits.reshape(-1, model.config.num_labels)
        # full_like keeps the fill tensor on the labels' device and dtype.
        active_labels = torch.where(
            active_loss, flat_labels, torch.full_like(flat_labels, loss_fct.ignore_index)
        )
        loss = loss_fct(active_logits, active_labels)
        return (loss, outputs) if return_outputs else loss
if __name__ == "__main__":
    # Accelerate wrapper; used below to prepare the model for the available device(s).
    accelerator = Accelerator()
    # wandb.init(project='binding_site_prediction')

    def _load_pickle(path):
        """Return the object stored in the pickle file at ``path``.

        NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        """
        with open(path, "rb") as handle:
            return pickle.load(handle)

    # Family-chunked sequences and their per-residue labels.
    train_sequences = _load_pickle("600K_data/train_sequences_chunked_by_family.pkl")
    test_sequences = _load_pickle("600K_data/test_sequences_chunked_by_family.pkl")
    train_labels = _load_pickle("600K_data/train_labels_chunked_by_family.pkl")
    test_labels = _load_pickle("600K_data/test_labels_chunked_by_family.pkl")

    # Tokenize both splits with identical settings.
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
    max_sequence_length = tokenizer.model_max_length
    encode_kwargs = dict(
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
        return_tensors="pt",
        is_split_into_words=False,
    )
    train_tokenized = tokenizer(train_sequences, **encode_kwargs)
    test_tokenized = tokenizer(test_sequences, **encode_kwargs)

    # NOTE(review): labels are attached unshifted. If the tokenizer prepends a
    # special token (ESM tokenizers typically add <cls>), label i lines up with
    # token i rather than residue i. Confirm this matches how the checkpoint was trained.
    train_labels = truncate_labels(train_labels, max_sequence_length)
    test_labels = truncate_labels(test_labels, max_sequence_length)
    train_dataset = Dataset.from_dict(dict(train_tokenized.items())).add_column("labels", train_labels)
    test_dataset = Dataset.from_dict(dict(test_tokenized.items())).add_column("labels", test_labels)

    # Fine-tuned checkpoint to evaluate. For a LoRA adapter instead, load the base
    # model with AutoModelForTokenClassification.from_pretrained(base_model_path)
    # and wrap it with PeftModel.from_pretrained(base_model, lora_model_path),
    # e.g. lora_model_path = "AmelieSchreiber/esm2_t12_35M_qlora_binding_2600K_cp1".
    base_model_path = "esm2_t6_8M_finetune_2023-10-08_00-58-24/checkpoint-42015"
    model = accelerator.prepare(AutoModelForTokenClassification.from_pretrained(base_model_path))

    # Evaluation-only Trainer; the collator pads each batch (labels included).
    trainer = Trainer(
        model=model,
        data_collator=DataCollatorForTokenClassification(tokenizer),
        compute_metrics=compute_metrics,
    )
    train_metrics = trainer.evaluate(train_dataset)
    test_metrics = trainer.evaluate(test_dataset)

    print(f"Train metrics: {train_metrics}")
    print(f"Test metrics: {test_metrics}")
    # wandb.log({"Train metrics": train_metrics, "Test metrics": test_metrics})