AmelieSchreiber committed
Commit e3768bd
1 Parent(s): 40a9af8

Upload clustered_ppi_train.py

Files changed (1)
  1. clustered_ppi_train.py +160 -0
clustered_ppi_train.py ADDED
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, TrainerCallback, EsmConfig
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
import random
import datetime


class ProteinDataset(Dataset):
    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides
        self.mask_percentage = mask_percentage

    def __len__(self):
        return len(self.proteins)

    def mask_sequence(self, sequence):
        # Replace a random mask_percentage of residues with the tokenizer's mask token
        mask_indices = random.sample(range(len(sequence)), int(len(sequence) * self.mask_percentage))
        return ''.join([self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence)])

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        # Mask each chain independently, then concatenate into a single complex sequence
        masked_protein = self.mask_sequence(protein_seq)
        masked_peptide = self.mask_sequence(peptide_seq)
        complex_seq = masked_protein + masked_peptide

        complex_input = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )

        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()

        # Tokenize the unmasked complex as labels; loss is computed only at masked positions
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )["input_ids"].squeeze()

        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

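# Note (added annotation): unlike the usual MLM pipeline that masks token IDs after encoding
# (e.g., with DataCollatorForLanguageModeling), this dataset splices the tokenizer's mask token
# string into the raw sequences before encoding. Since ESM tokenizes one residue per token,
# the masked input and the unmasked label sequence stay position-aligned.
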
# Callback to raise the mask percentage after each epoch
class DynamicMaskingCallback(TrainerCallback):
    def __init__(self, dataset, increment=0.10):
        self.dataset = dataset
        self.increment = increment

    def on_epoch_end(self, args, state, control, **kwargs):
        self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
        print(f"Updated mask percentage to: {self.dataset.mask_percentage * 100:.0f}%")

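# Note (added annotation): with the default mask_percentage=0.30, increment=0.10, and the
# 4 epochs configured below, successive epochs train at roughly 30%, 40%, 50%, and 60%
# masking, since on_epoch_end raises the rate used by the *next* epoch (capped at 100%).
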
# Load the dataset
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')

# Split the data by cluster so train and test share no cluster, starting with cluster 0
test_clusters = [0]
remaining_clusters = data[data['Cluster'] != 0]['Cluster'].unique()
random.shuffle(remaining_clusters)  # Shuffle the remaining clusters in place

# Fraction of the dataset that cluster 0 occupies
cluster_0_size = (data['Cluster'] == 0).mean()

# Greedily add whole clusters until the test split reaches approximately 20% of the dataset
test_size = cluster_0_size
for cluster in remaining_clusters:
    cluster_size = (data['Cluster'] == cluster).mean()
    if test_size + cluster_size > 0.20:
        break
    test_clusters.append(cluster)
    test_size += cluster_size

# Create the test and train splits from the selected clusters
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]

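# Quick sanity check on the cluster-held-out split (added diagnostic, not in the original script)
print(f"Test clusters: {sorted(test_clusters)}")
print(f"Train/test rows: {len(train_data)}/{len(test_data)} "
      f"({len(test_data) / len(data):.1%} held out)")
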
proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()

# Load tokenizer and model
model_name = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)

# Load the model configuration; dropout rates can be adjusted here if desired
config = EsmConfig.from_pretrained("facebook/" + model_name)
# config.hidden_dropout_prob = 0.1  # Adjust hidden layer dropout
# config.attention_probs_dropout_prob = 0.1  # Adjust attention dropout
model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)

# Generate a timestamped output directory
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
output_dir = f'./interact_output_{timestamp}/'

# Calculate the total number of optimizer steps (single-device estimate:
# steps per epoch = examples // (batch size * gradient accumulation))
num_train_epochs = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs

# Training arguments with a cosine learning rate schedule and gradient clipping
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=8,
    warmup_steps=10,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy='epoch',
    metric_for_best_model='eval_loss',
    save_total_limit=3,
    gradient_accumulation_steps=gradient_accumulation_steps,
    lr_scheduler_type='cosine',
    max_steps=total_steps,  # Overrides num_train_epochs; computed above to match 4 epochs
    gradient_checkpointing=True,  # Enable gradient checkpointing for memory optimization
    max_grad_norm=1.0  # Gradient clipping
)

# Optimizer with weight decay for regularization
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)

# Instantiate the ProteinDataset for training and testing
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)

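# Sanity-check one example before training (added diagnostic, not in the original script):
# each tensor should have shape (1024,), and labels should be -100 everywhere except
# at the masked positions.
example = train_dataset[0]
print({k: tuple(v.shape) for k, v in example.items()})
print("Masked positions:", int((example["labels"] != -100).sum()))
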
# Initialize the dynamic masking callback on the training set
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)

# Trainer with the dynamic masking callback and the custom optimizer
# (passing None as the scheduler lets the Trainer build the cosine schedule itself)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, None),
    callbacks=[dynamic_masking_callback]
)

# Start training
trainer.train()
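
# Optionally persist the fine-tuned weights and tokenizer once training finishes
# (added step using the standard save_model/save_pretrained APIs; not in the original script)
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)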