from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, TrainerCallback, EsmConfig
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
import random
import datetime


class ProteinDataset(Dataset):
    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides
        self.mask_percentage = mask_percentage

    def __len__(self):
        return len(self.proteins)

    def mask_sequence(self, sequence):
        # Randomly replace a fraction of residues with the mask token
        mask_indices = random.sample(range(len(sequence)), int(len(sequence) * self.mask_percentage))
        return ''.join([self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence)])

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        masked_protein = self.mask_sequence(protein_seq)
        masked_peptide = self.mask_sequence(peptide_seq)
        complex_seq = masked_protein + masked_peptide

        complex_input = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )

        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()

        # Tokenize the unmasked sequence to serve as labels
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False
        )["input_ids"].squeeze()

        # Compute the loss only on masked positions; everything else is ignored (-100)
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Callback to update mask percentage after each epoch
class DynamicMaskingCallback(TrainerCallback):
    def __init__(self, dataset, increment=0.10):
        self.dataset = dataset
        self.increment = increment

    def on_epoch_end(self, args, state, control, **kwargs):
        self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
        print(f"Updated mask percentage to: {self.dataset.mask_percentage * 100}%")


# Loading the dataset
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')

# Splitting the data based on clusters, starting with cluster 0
test_clusters = [0]  # Start with cluster 0
remaining_clusters = data[data['Cluster'] != 0]['Cluster'].unique().tolist()
random.shuffle(remaining_clusters)  # Shuffle the remaining clusters

# Determine the fraction of the dataset that cluster 0 occupies
cluster_0_size = (data['Cluster'] == 0).mean()

# Add more clusters until the test split reaches approximately 20% of the dataset
test_size = cluster_0_size
for cluster in remaining_clusters:
    cluster_size = (data['Cluster'] == cluster).mean()
    if test_size + cluster_size > 0.20:
        break
    test_clusters.append(cluster)
    test_size += cluster_size

# Creating test and train data based on the selected clusters
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]

proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()

# Load tokenizer and model
model_name = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)

# Load model configuration (dropout rates can be adjusted here if desired)
config = EsmConfig.from_pretrained("facebook/" + model_name)
# config.hidden_dropout_prob = 0.1  # Adjust hidden layer dropout
# config.attention_probs_dropout_prob = 0.1  # Adjust attention dropout
model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)
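
# Optional sanity check (illustrative addition, not part of the original pipeline):
# the character-level masking in ProteinDataset relies on ESM-2 tokenizing one token
# per residue and recognizing the embedded mask token, so a masked sequence and its
# unmasked label sequence should tokenize to the same length. The modulo-based mask
# pattern below is an arbitrary stand-in used only for this check.
_example_seq = proteins_train[0] + peptides_train[0]
_example_masked = ''.join(
    tokenizer.mask_token if i % 7 == 0 else ch for i, ch in enumerate(_example_seq)
)
_masked_ids = tokenizer(_example_masked, add_special_tokens=False)["input_ids"]
_label_ids = tokenizer(_example_seq, add_special_tokens=False)["input_ids"]
assert len(_masked_ids) == len(_label_ids), "Masked and unmasked tokenizations are misaligned"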

# Generate a timestamp for the output directory
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
output_dir = f'./interact_output_{timestamp}/'

# Calculate the total number of training steps
num_train_epochs = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs

# Training arguments with cosine learning rate scheduler and gradient clipping
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=8,
    warmup_steps=10,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy='epoch',
    metric_for_best_model='eval_loss',
    save_total_limit=3,
    gradient_accumulation_steps=gradient_accumulation_steps,
    lr_scheduler_type='cosine',
    max_steps=total_steps,
    gradient_checkpointing=True,  # Enable gradient checkpointing for memory optimization
    max_grad_norm=1.0  # Gradient clipping
)

# Optimizer with weight decay for regularization
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)

# Instantiate the ProteinDataset for training and testing
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)

# Initialize DynamicMaskingCallback
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)

# Trainer with the custom optimizer and the dynamic masking callback
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, None),
    callbacks=[dynamic_masking_callback]
)

# Start training
trainer.train()
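
# --- Optional usage sketch (illustrative addition, not part of the original script) ---
# A minimal example of querying the fine-tuned model after training, assuming
# load_best_model_at_end=True leaves the best checkpoint loaded in trainer.model.
# The residue masked below (position 10) is an arbitrary choice for illustration.
trainer.save_model(output_dir)  # persist the fine-tuned weights

eval_model = trainer.model.eval()
example_seq = proteins_test[0] + peptides_test[0]
masked_seq = example_seq[:10] + tokenizer.mask_token + example_seq[11:]

inputs = tokenizer(masked_seq, return_tensors="pt", add_special_tokens=False)
inputs = {k: v.to(eval_model.device) for k, v in inputs.items()}
with torch.no_grad():
    logits = eval_model(**inputs).logits

# Recover the predicted residue at the masked position
mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print("True residue:", example_seq[10])
print("Predicted residue:", tokenizer.decode(predicted_ids))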