import os
import pickle
import numpy as np
from scipy import stats
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import AutoModelForTokenClassification, Trainer, AutoTokenizer, DataCollatorForTokenClassification
from datasets import Dataset
from accelerate import Accelerator
from peft import PeftModel
# Step 1: Load train/test data and labels from pickle files
with open("/kaggle/input/550k-dataset/train_sequences_chunked_by_family.pkl", "rb") as f:
train_sequences = pickle.load(f)
with open("/kaggle/input/550k-dataset/test_sequences_chunked_by_family.pkl", "rb") as f:
test_sequences = pickle.load(f)
with open("/kaggle/input/550k-dataset/train_labels_chunked_by_family.pkl", "rb") as f:
train_labels = pickle.load(f)
with open("/kaggle/input/550k-dataset/test_labels_chunked_by_family.pkl", "rb") as f:
test_labels = pickle.load(f)
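# Quick sanity check (illustrative; safe to remove): the labels are expected
# to be per-residue 0/1 lists aligned with their sequences.
print(f"{len(train_sequences)} train / {len(test_sequences)} test sequences")
assert len(train_sequences[0]) == len(train_labels[0]), "sequence/label length mismatch"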
# Step 2: Define the Tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
max_sequence_length = tokenizer.model_max_length
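# Illustration (safe to remove; assumes raw amino-acid strings): the ESM-2
# tokenizer is character-level and adds <cls> and <eos>, so a 10-residue
# sequence yields 12 tokens.
_demo = tokenizer("MKTAYIAKQR", return_tensors="pt")
print(_demo["input_ids"].shape)  # torch.Size([1, 12])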
# Step 3: Define a `compute_metrics_for_batch` function.
def compute_metrics_for_batch(sequences_batch, labels_batch, models, voting='hard'):
    # Tokenize the batch; padding=True pads every sequence to the batch maximum.
    batch_tokenized = tokenizer(
        sequences_batch,
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
        return_tensors="pt",
        is_split_into_words=False,
    )
    batch_dataset = Dataset.from_dict(dict(batch_tokenized))
    batch_dataset = batch_dataset.add_column("labels", labels_batch[:len(batch_dataset)])

    # Pad (or truncate) each label list to the tokenized sequence length so
    # that the mask below lines up with the per-token predictions.
    seq_len = batch_tokenized["input_ids"].shape[1]
    labels_array = np.array([
        np.pad(label[:seq_len], (0, max(0, seq_len - len(label))), constant_values=-100)
        for label in batch_dataset["labels"]
    ])

    # Initialize a trainer for each model
    data_collator = DataCollatorForTokenClassification(tokenizer)
    trainers = [Trainer(model=model, data_collator=data_collator) for model in models]

    # Get the logits from each model; each array has shape (batch, seq_len, num_labels)
    all_predictions = [trainer.predict(test_dataset=batch_dataset)[0] for trainer in trainers]

    if voting == 'hard':
        # Hard voting: per-token majority over each model's argmax predictions.
        # keepdims=True preserves the pre-SciPy-1.9 output shape (requires SciPy >= 1.9).
        hard_predictions = [np.argmax(predictions, axis=2) for predictions in all_predictions]
        ensemble_predictions = stats.mode(hard_predictions, axis=0, keepdims=True)[0][0]
    elif voting == 'soft':
        # Soft voting: average the logits across models, then take the argmax.
        avg_predictions = np.mean(all_predictions, axis=0)
        ensemble_predictions = np.argmax(avg_predictions, axis=2)
    else:
        raise ValueError("Voting must be either 'hard' or 'soft'")

    # Mask out padding and special-token positions (label -100).
    mask_2d = labels_array != -100
    true_labels = np.concatenate([label[mask_2d[idx]] for idx, label in enumerate(labels_array)])
    flat_predictions = np.concatenate([
        ensemble_predictions[idx][mask_2d[idx]] for idx in range(ensemble_predictions.shape[0])
    ]).tolist()

    # Compute the metrics; note that AUC is computed from the hard class
    # predictions here rather than from probabilities.
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}
# Step 4: Evaluate in batches
def evaluate_in_batches(sequences, labels, models, dataset_name, voting, batch_size=1000, print_first_n=5):
    num_batches = len(sequences) // batch_size + int(len(sequences) % batch_size != 0)
    metrics_list = []
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        batch_metrics = compute_metrics_for_batch(sequences[start_idx:end_idx], labels[start_idx:end_idx], models, voting)
        # Print metrics for the first few batches of each dataset
        if i < print_first_n:
            print(f"{dataset_name} - Batch {i+1}/{num_batches} metrics: {batch_metrics}")
        metrics_list.append(batch_metrics)
    # Average each metric over all batches; every batch is weighted equally,
    # including a short final batch.
    avg_metrics = {key: np.mean([metrics[key] for metrics in metrics_list]) for key in metrics_list[0]}
    return avg_metrics
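# Illustration (hypothetical numbers, safe to remove): equally weighted batch
# means can diverge from metrics pooled over all predictions when the final
# batch is short.
_correct, _total = np.array([900, 5]), np.array([1000, 10])
print("batch-mean accuracy:", (_correct / _total).mean())   # 0.7
print("pooled accuracy:", _correct.sum() / _total.sum())    # ~0.896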
# Step 5: Load pre-trained base model and fine-tuned LoRA models
accelerator = Accelerator()
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_paths = [
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1",
]
# Load a fresh copy of the base model for each adapter; reusing a single
# base-model instance across PeftModel.from_pretrained calls mutates the
# shared module and can let the adapters interfere with one another.
models = [
    PeftModel.from_pretrained(AutoModelForTokenClassification.from_pretrained(base_model_path), path)
    for path in lora_model_paths
]
models = [accelerator.prepare(model) for model in models]
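# Optional sanity check (assumes a single-process run, where prepare()
# returns the PeftModel unwrapped): each model should report only the LoRA
# parameters as trainable.
for m in models:
    m.print_trainable_parameters()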
# Step 6: Compute and print the metrics
test_metrics_soft = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='soft')
train_metrics_soft = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='soft')
test_metrics_hard = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='hard')
train_metrics_hard = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='hard')
print("Test metrics (soft voting):", test_metrics_soft)
print("Train metrics (soft voting):", train_metrics_soft)
print("Test metrics (hard voting):", test_metrics_hard)
print("Train metrics (hard voting):", train_metrics_hard)