AmelieSchreiber committed
Commit • e3768bd
1 Parent(s): 40a9af8
Upload clustered_ppi_train.py

clustered_ppi_train.py ADDED (+160 -0)
@@ -0,0 +1,160 @@
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, TrainerCallback, EsmConfig
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
import random
import datetime

class ProteinDataset(Dataset):
    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides
        self.mask_percentage = mask_percentage

    def __len__(self):
        return len(self.proteins)

    def mask_sequence(self, sequence):
        # Randomly pick mask_percentage of the positions and replace them with the mask token
        mask_indices = random.sample(range(len(sequence)), int(len(sequence) * self.mask_percentage))
        return ''.join([self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence)])

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        # Mask both partners independently, then concatenate them into one "complex" sequence
        masked_protein = self.mask_sequence(protein_seq)
        masked_peptide = self.mask_sequence(peptide_seq)
        complex_seq = masked_protein + masked_peptide

        complex_input = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )

        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()

        # Labels come from the unmasked concatenation of the same two sequences
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )["input_ids"].squeeze()

        # Compute the loss only at masked positions; all other positions are ignored (-100)
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
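
# Minimal usage sketch (illustrative only, not part of the original training flow):
# assuming `tokenizer` has been loaded as further below, the hypothetical sequences
# here show what a single item looks like.
#
#   demo_ds = ProteinDataset(["MKTAYIAKQR"], ["GSHM"], tokenizer)
#   item = demo_ds[0]
#   # item["labels"] holds the true token id at masked positions and -100 everywhere
#   # else, so the masked-language-modeling loss is computed only on masked residues.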

# Callback to update the mask percentage after each epoch
class DynamicMaskingCallback(TrainerCallback):
    def __init__(self, dataset, increment=0.10):
        self.dataset = dataset
        self.increment = increment

    def on_epoch_end(self, args, state, control, **kwargs):
        self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
        print(f"Updated mask percentage to: {self.dataset.mask_percentage * 100}%")
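
# With the defaults above, masking starts at 30% and grows by 10 percentage points
# at the end of each epoch (30% -> 40% -> 50% -> 60% across the 4 epochs configured
# below), capped at 100%.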

# Loading the dataset
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')

# Splitting the data based on clusters, starting with cluster 0
test_clusters = [0]  # Start with cluster 0
remaining_clusters = list(data[data['Cluster'] != 0]['Cluster'].unique())  # plain list so random.shuffle applies cleanly
random.shuffle(remaining_clusters)  # Shuffle the remaining clusters

# Determine the fraction of the dataset that belongs to cluster 0
cluster_0_size = (data['Cluster'] == 0).mean()

# Add more clusters until the test set reaches approximately 20% of the dataset
test_size = cluster_0_size
for cluster in remaining_clusters:
    cluster_size = (data['Cluster'] == cluster).mean()
    if test_size + cluster_size > 0.20:
        break
    test_clusters.append(cluster)
    test_size += cluster_size
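
# Splitting by whole clusters (rather than by random rows) is meant to keep similar
# protein pairs from appearing in both train and test, so the held-out clusters give
# a fairer estimate of generalization to unseen pairs.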

# Creating test and train data based on the selected clusters
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]

proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()
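
# Sanity-check sketch (not in the original script): confirm the cluster split is disjoint.
# assert set(test_data['Cluster']).isdisjoint(set(train_data['Cluster']))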

# Load tokenizer and model
model_name = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)

# Load model configuration and (optionally) modify dropout rates
config = EsmConfig.from_pretrained("facebook/" + model_name)
# config.hidden_dropout_prob = 0.1  # Adjust hidden layer dropout
# config.attention_probs_dropout_prob = 0.1  # Adjust attention dropout
model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)
106 |
+
|
107 |
+
# Generate a timestamp for the output directory
|
108 |
+
current_time = datetime.datetime.now()
|
109 |
+
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
|
110 |
+
output_dir = f'./interact_output_{timestamp}/'
|
111 |
+
|
112 |
+
# Calculate the total number of training steps
|
113 |
+
num_train_epochs = 4
|
114 |
+
per_device_train_batch_size = 8
|
115 |
+
gradient_accumulation_steps = 4
|
116 |
+
total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs
|
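
# With these settings the effective batch size is 8 * 4 = 32 pairs per optimizer step
# (per device). For example, a hypothetical 10,000 training pairs would give
# 10000 // 32 = 312 steps per epoch and 312 * 4 = 1248 total steps; the actual count
# comes from len(proteins_train).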
117 |
+
|
118 |
+
# Training arguments with cosine learning rate scheduler and gradient clipping
|
119 |
+
training_args = TrainingArguments(
|
120 |
+
output_dir=output_dir,
|
121 |
+
num_train_epochs=num_train_epochs,
|
122 |
+
per_device_train_batch_size=per_device_train_batch_size,
|
123 |
+
per_device_eval_batch_size=8,
|
124 |
+
warmup_steps=10,
|
125 |
+
logging_dir='./logs',
|
126 |
+
logging_steps=10,
|
127 |
+
evaluation_strategy="epoch",
|
128 |
+
load_best_model_at_end=True,
|
129 |
+
save_strategy='epoch',
|
130 |
+
metric_for_best_model='eval_loss',
|
131 |
+
save_total_limit=3,
|
132 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
133 |
+
lr_scheduler_type='cosine',
|
134 |
+
max_steps=total_steps, # Corrected: Added comma here
|
135 |
+
gradient_checkpointing=True, # Enable gradient checkpointing for memory optimization
|
136 |
+
max_grad_norm=1.0 # Gradient clipping
|
137 |
+
)
|
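
# Note: load_best_model_at_end requires the evaluation and save strategies to match;
# both are "epoch" here, so the checkpoint with the lowest eval_loss is restored at the end.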

# Optimizer with added weight decay for regularization
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)

# Instantiate the ProteinDataset for training and testing
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)

# Initialize DynamicMaskingCallback
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)

# Trainer with the custom optimizer and the dynamic masking callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, None),  # None: let the Trainer build the cosine scheduler
    callbacks=[dynamic_masking_callback]
)

# Start training
trainer.train()
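
# A possible follow-up (not in the original script): save the final model and tokenizer
# so the fine-tuned weights can be reloaded later with from_pretrained.
# trainer.save_model(output_dir)
# tokenizer.save_pretrained(output_dir)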