import pickle
import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from peft import PeftModel
from accelerate import Accelerator
from tqdm import tqdm

# Initialize the Accelerator
accelerator = Accelerator()

class ProteinDataset(TorchDataset):
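    """Per-residue (token-classification) dataset: loads pickled protein sequences
    and matching per-residue label lists, tokenizing each sequence on the fly."""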
    def __init__(self, sequences_path, labels_path, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        with open(sequences_path, "rb") as f:
            self.sequences = pickle.load(f)
        
        with open(labels_path, "rb") as f:
            self.labels = pickle.load(f)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
    
        tokenized = self.tokenizer(
            sequence,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            is_split_into_words=False,
            add_special_tokens=False,
        )
    
        # Remove the extra batch dimension
        for key in tokenized:
            tokenized[key] = tokenized[key].squeeze(0)
    
        # Ensure labels are also padded/truncated to match tokenized input
        label_padded = [-100] * self.max_length  # Using -100 as the ignore index
        label_padded[:len(label)] = label[:self.max_length]
        
        tokenized["labels"] = torch.tensor(label_padded)
    
        return tokenized

def compute_metrics(p):
    predictions, labels = p.predictions, p.label_ids
    predictions = np.argmax(predictions, axis=2)
    
    mask = labels != -100
    predictions = predictions[mask].flatten()
    labels = labels[mask].flatten()
    
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
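    # Note: AUC is computed here from the hard argmax predictions; scoring the
    # positive-class probabilities instead would give a finer-grained ROC estimate.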
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)
    
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

def evaluate_in_chunks(dataset, trainer, chunk_percentage=0.2):
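    """Evaluate the dataset in chunks of roughly chunk_percentage of its length,
    saving each chunk's metrics to disk before aggregation."""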
    chunk_size = int(len(dataset) * chunk_percentage)
    all_results = []
    
    # Wrap the loop with tqdm for progress bar
    for i in tqdm(range(0, len(dataset), chunk_size), desc="Evaluating chunks"):
        chunk = [dataset[j] for j in range(i, min(i + chunk_size, len(dataset)))]
        chunk_results = trainer.evaluate(chunk)
        # Record the chunk size explicitly; aggregate_results weights by "eval_samples",
        # which is not guaranteed to appear in the Trainer's metrics on every version.
        chunk_results.setdefault("eval_samples", len(chunk))
        print(f"Results for chunk starting at index {i}: {chunk_results}")
        
        # Save the chunk results to disk
        with open(f"results_chunk_{i}.pkl", "wb") as f:
            pickle.dump(chunk_results, f)
        
        all_results.append(chunk_results)

    return all_results

def aggregate_results(results_list):
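    # Weighted average of each metric across chunks, weighted by the number of
    # evaluated samples ("eval_samples") recorded for each chunk.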
    total_samples = sum([res["eval_samples"] for res in results_list])
    aggregated_results = {}

    for key in results_list[0].keys():
        if key == "eval_samples":
            continue
        aggregated_results[key] = sum([res[key] * res["eval_samples"] for res in results_list]) / total_samples

    return aggregated_results

# Initialize tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
train_dataset = ProteinDataset("data/12M_data/512_train_sequences_chunked_by_family.pkl", "data/12M_data/512_train_labels_chunked_by_family.pkl", tokenizer, 512)
test_dataset = ProteinDataset("data/12M_data/512_test_sequences_chunked_by_family.pkl", "data/12M_data/512_test_labels_chunked_by_family.pkl", tokenizer, 512)

# Load the pre-trained LoRA model
base_model_path = "facebook/esm2_t33_650M_UR50D"
lora_model_path = "qlora_binding_sites/best_model_esm2_t33_650M_qlora_2023-10-18_02-14-48"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(base_model, lora_model_path)
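# PeftModel.from_pretrained attaches the trained LoRA adapter weights on top of the frozen base model.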
model = accelerator.prepare(model)

# Initialize the Trainer
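# No TrainingArguments are given, so the Trainer falls back to its defaults; only evaluate() is called below.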
trainer = Trainer(
    model=model,
    compute_metrics=compute_metrics
)

# Evaluate the model on chunks of the training dataset
train_results = evaluate_in_chunks(train_dataset, trainer)
aggregated_train_results = aggregate_results(train_results)
print(f"Aggregated Training Results: {aggregated_train_results}")

# Evaluate the model on chunks of the test dataset
test_results = evaluate_in_chunks(test_dataset, trainer)
aggregated_test_results = aggregate_results(test_results)
print(f"Aggregated Test Results: {aggregated_test_results}")