from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, TrainerCallback, EsmConfig
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
import random
import datetime

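# ProteinDataset pairs each protein with its peptide partner, masks a fraction of
# residues in both sequences, and returns masked-LM inputs, attention masks, and
# labels suitable for ESM-2 fine-tuning.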
class ProteinDataset(Dataset):
    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides
        self.mask_percentage = mask_percentage

    def __len__(self):
        return len(self.proteins)

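    # Randomly replace ~mask_percentage of the residues with the tokenizer's mask
    # token, e.g. "MKTAYI" -> "M<mask>TAYI" at 30% masking.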
    def mask_sequence(self, sequence):
        mask_indices = random.sample(range(len(sequence)), int(len(sequence) * self.mask_percentage))
        return ''.join([self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence)])

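    # Build one training example: mask both sequences, concatenate them into a
    # single "complex" sequence, and derive labels from the unmasked concatenation.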
    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        masked_protein = self.mask_sequence(protein_seq)
        masked_peptide = self.mask_sequence(peptide_seq)
        complex_seq = masked_protein + masked_peptide

        complex_input = self.tokenizer(
            complex_seq, 
            return_tensors="pt", 
            padding="max_length", 
            max_length=1024, 
            truncation=True, 
            add_special_tokens=False
        )

        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()

        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq, 
            return_tensors="pt", 
            padding="max_length", 
            max_length=1024, 
            truncation=True, 
            add_special_tokens=False
        )["input_ids"].squeeze()

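        # Keep labels only at masked positions; all other positions are set to -100,
        # which the cross-entropy loss ignores.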
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Callback to update mask percentage after each epoch
class DynamicMaskingCallback(TrainerCallback):
    def __init__(self, dataset, increment=0.10):
        self.dataset = dataset
        self.increment = increment

    def on_epoch_end(self, args, state, control, **kwargs):
        self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
        print(f"Updated mask percentage to: {self.dataset.mask_percentage * 100}%")

# Loading the dataset
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')

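# Hold out whole clusters (rather than individual rows) so that similar protein
# pairs do not leak between train and test; the test set targets roughly 20% of the data.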
# Splitting the data based on clusters, starting with cluster 0
test_clusters = [0]  # Always include cluster 0 in the test set
remaining_clusters = list(data[data['Cluster'] != 0]['Cluster'].unique())
random.shuffle(remaining_clusters)  # Shuffle the remaining clusters

# Determine the size of cluster 0 in the dataset
cluster_0_size = (data['Cluster'] == 0).mean()

# Add more clusters until reaching approximately 20% of the dataset
test_size = cluster_0_size
for cluster in remaining_clusters:
    cluster_size = (data['Cluster'] == cluster).mean()
    if test_size + cluster_size > 0.20:
        break
    test_clusters.append(cluster)
    test_size += cluster_size

# Creating test and train data based on the selected clusters
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]

proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()

# Load tokenizer and model
model_name = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)

# Load model configuration and modify dropout rates
config = EsmConfig.from_pretrained("facebook/" + model_name)
# config.hidden_dropout_prob = 0.1  # Adjust hidden layer dropout
# config.attention_probs_dropout_prob = 0.1  # Adjust attention dropout
model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)

# Generate a timestamp for the output directory
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
output_dir = f'./interact_output_{timestamp}/'

# Calculate the total number of training steps
num_train_epochs = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs
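# Note: total_steps is a single-device estimate; passing it as max_steps makes the
# Trainer stop at that step count rather than after num_train_epochs.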

# Training arguments with cosine learning rate scheduler and gradient clipping
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=8,
    warmup_steps=10,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy='epoch',
    metric_for_best_model='eval_loss',
    save_total_limit=3,
    gradient_accumulation_steps=gradient_accumulation_steps,
    lr_scheduler_type='cosine',
    max_steps=total_steps,  # explicit step budget; overrides num_train_epochs when > 0
    gradient_checkpointing=True,  # Enable gradient checkpointing for memory optimization
    max_grad_norm=1.0  # Gradient clipping
)

# Optimizer with added weight decay for regularization
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)

# Instantiate the ProteinDataset for training and testing
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)

# Initialize DynamicMaskingCallback
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)

# Trainer with callbacks for dynamic masking and gradient clipping
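# Passing (optimizer, None) supplies the custom AdamW optimizer while letting the
# Trainer build the cosine schedule from lr_scheduler_type, warmup_steps, and max_steps.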
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, None),
    callbacks=[dynamic_masking_callback]
)

# Start training
trainer.train()
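
# Optionally persist the fine-tuned weights and tokenizer for later use, e.g.:
# trainer.save_model(output_dir)
# tokenizer.save_pretrained(output_dir)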