"""Evaluate a LoRA-finetuned ESM-2 token-classification model on protein
binding-site data, chunk by chunk, and aggregate the per-chunk metrics."""

import pickle

import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef,
)
from peft import PeftModel
from accelerate import Accelerator
from tqdm import tqdm

accelerator = Accelerator()

class ProteinDataset(TorchDataset):
    """Pickled protein sequences paired with per-residue binding-site labels."""

    def __init__(self, sequences_path, labels_path, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(sequences_path, "rb") as f:
            self.sequences = pickle.load(f)

        with open(labels_path, "rb") as f:
            self.labels = pickle.load(f)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]

        tokenized = self.tokenizer(
            sequence,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            is_split_into_words=False,
            add_special_tokens=False,
        )

        # Drop the batch dimension added by return_tensors="pt".
        for key in tokenized:
            tokenized[key] = tokenized[key].squeeze(0)

        # Truncate the label to max_length, then pad with -100 so padded
        # positions are ignored by the loss and by compute_metrics.
        label = label[:self.max_length]
        label_padded = [-100] * self.max_length
        label_padded[:len(label)] = label

        tokenized["labels"] = torch.tensor(label_padded)

        return tokenized

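# Illustrative note (not in the original script): once a dataset is built
# below, each item is a dict of fixed-length tensors, e.g. for max_length=512:
#   item = train_dataset[0]
#   item["input_ids"].shape  # torch.Size([512])
#   item["labels"].shape     # torch.Size([512])
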
def compute_metrics(p):
    predictions, labels = p.predictions, p.label_ids
    predictions = np.argmax(predictions, axis=2)

    # Keep only real residues; -100 marks padding.
    mask = labels != -100
    predictions = predictions[mask].flatten()
    labels = labels[mask].flatten()

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
    # Note: this AUC is computed from hard argmax predictions rather than class
    # probabilities, so it is only a coarse approximation of the true ROC AUC.
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

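# Aside (assumption, not part of the original script): a probability-based
# ROC AUC could be computed from the raw logits instead of argmax predictions.
# `auc_from_logits` is a hypothetical helper, assuming binary logits of shape
# (batch, seq_len, 2) and labels of shape (batch, seq_len).
def auc_from_logits(logits, labels):
    mask = labels != -100
    # Softmax over the class axis, keeping the positive-class probability.
    shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = shifted[..., 1] / shifted.sum(axis=-1)
    return roc_auc_score(labels[mask].flatten(), probs[mask].flatten())
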
def evaluate_in_chunks(dataset, trainer, chunk_percentage=0.2):
    """Evaluate the dataset in chunks to keep memory usage bounded."""
    chunk_size = max(1, int(len(dataset) * chunk_percentage))
    all_results = []

    for i in tqdm(range(0, len(dataset), chunk_size), desc="Evaluating chunks"):
        chunk = [dataset[j] for j in range(i, min(i + chunk_size, len(dataset)))]
        chunk_results = trainer.evaluate(chunk)
        # Trainer.evaluate does not report a sample count by default; record it
        # here so the chunk metrics can be weighted during aggregation.
        chunk_results["eval_samples"] = len(chunk)
        print(f"Results for chunk starting at index {i}: {chunk_results}")

        with open(f"results_chunk_{i}.pkl", "wb") as f:
            pickle.dump(chunk_results, f)

        all_results.append(chunk_results)

    return all_results

def aggregate_results(results_list):
    """Combine per-chunk metrics into a single sample-weighted average."""
    total_samples = sum(res["eval_samples"] for res in results_list)
    aggregated_results = {}

    for key in results_list[0].keys():
        if key == "eval_samples":
            continue
        aggregated_results[key] = sum(res[key] * res["eval_samples"] for res in results_list) / total_samples

    return aggregated_results

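# Worked example of the weighting: chunks with f1 = 0.50 over 100 samples and
# f1 = 0.70 over 300 samples aggregate to (0.50*100 + 0.70*300) / 400 = 0.65.
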
# Tokenizer and datasets: sequences pre-chunked to 512 residues by family.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
train_dataset = ProteinDataset("data/12M_data/512_train_sequences_chunked_by_family.pkl", "data/12M_data/512_train_labels_chunked_by_family.pkl", tokenizer, 512)
test_dataset = ProteinDataset("data/12M_data/512_test_sequences_chunked_by_family.pkl", "data/12M_data/512_test_labels_chunked_by_family.pkl", tokenizer, 512)

# Load the base ESM-2 model and apply the trained LoRA adapter on top of it.
base_model_path = "facebook/esm2_t33_650M_UR50D"
lora_model_path = "qlora_binding_sites/best_model_esm2_t33_650M_qlora_2023-10-18_02-14-48"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)

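# Optional (assumption, not in the original script): for inference-only use,
# the LoRA weights can typically be folded into the base model to remove
# adapter overhead, e.g. model = model.merge_and_unload() after loading.
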
# No TrainingArguments are passed, so Trainer falls back to its defaults;
# that is sufficient here since the Trainer is only used for evaluation.
trainer = Trainer(
    model=model,
    compute_metrics=compute_metrics,
)

# Evaluate the model on chunks of the training dataset.
train_results = evaluate_in_chunks(train_dataset, trainer)
aggregated_train_results = aggregate_results(train_results)
print(f"Aggregated Training Results: {aggregated_train_results}")

# Evaluate the model on chunks of the test dataset.
test_results = evaluate_in_chunks(test_dataset, trainer)
aggregated_test_results = aggregate_results(test_results)
print(f"Aggregated Test Results: {aggregated_test_results}")