# esm_mlmppi_ph50_v3 / clustered_ppi_train.py
import datetime
import random

import pandas as pd
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset
from transformers import (Trainer, TrainingArguments, AutoTokenizer,
                          AutoModelForMaskedLM, TrainerCallback, EsmConfig)
class ProteinDataset(Dataset):
    """Masked-LM dataset over concatenated protein-peptide pairs with a tunable mask rate."""
    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
self.tokenizer = tokenizer
self.proteins = proteins
self.peptides = peptides
self.mask_percentage = mask_percentage
def __len__(self):
return len(self.proteins)
    def mask_sequence(self, sequence):
        # Mask a fixed fraction of residues at the character level; the tokenizer
        # recognizes the literal mask token string when the sequence is tokenized
        mask_indices = set(random.sample(range(len(sequence)), int(len(sequence) * self.mask_percentage)))
        return ''.join(self.tokenizer.mask_token if i in mask_indices else char for i, char in enumerate(sequence))
def __getitem__(self, idx):
protein_seq = self.proteins[idx]
peptide_seq = self.peptides[idx]
masked_protein = self.mask_sequence(protein_seq)
masked_peptide = self.mask_sequence(peptide_seq)
        # Concatenate the pair into one "complex" sequence (no separator or special tokens)
        complex_seq = masked_protein + masked_peptide
complex_input = self.tokenizer(
complex_seq,
return_tensors="pt",
padding="max_length",
max_length=1024,
truncation=True,
add_special_tokens=False
)
input_ids = complex_input["input_ids"].squeeze()
attention_mask = complex_input["attention_mask"].squeeze()
label_seq = protein_seq + peptide_seq
labels = self.tokenizer(
label_seq,
return_tensors="pt",
padding="max_length",
max_length=1024,
truncation=True,
add_special_tokens=False
)["input_ids"].squeeze()
        # Supervise only the masked positions; all other positions are ignored (-100)
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
# Callback to update mask percentage after each epoch
class DynamicMaskingCallback(TrainerCallback):
def __init__(self, dataset, increment=0.10):
self.dataset = dataset
self.increment = increment
    def on_epoch_end(self, args, state, control, **kwargs):
        # Raise the masking rate by `increment` after each epoch, capped at 100%
        self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
        print(f"Updated mask percentage to: {self.dataset.mask_percentage:.0%}")
# Loading the dataset
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')
# Splitting the data based on clusters, starting with cluster 0
test_clusters = [0]  # Start with cluster 0
remaining_clusters = list(data[data['Cluster'] != 0]['Cluster'].unique())
random.shuffle(remaining_clusters)  # Shuffle the remaining clusters in place
# Fraction of the dataset that belongs to cluster 0
cluster_0_size = (data['Cluster'] == 0).mean()
# Add more clusters until reaching approximately 20% of the dataset
test_size = cluster_0_size
for cluster in remaining_clusters:
cluster_size = (data['Cluster'] == cluster).mean()
if test_size + cluster_size > 0.20:
break
test_clusters.append(cluster)
test_size += cluster_size
# Creating test and train data based on the selected clusters
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]
proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()
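# Optional sanity check on the cluster-level split: train and test clusters
# should be disjoint. The reported fraction is informational only.
assert set(test_data['Cluster']).isdisjoint(set(train_data['Cluster']))
print(f"Train pairs: {len(proteins_train)}, test pairs: {len(proteins_test)} "
      f"({len(proteins_test) / len(data):.1%} of all pairs)")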
# Load tokenizer and model
model_name = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)
# Load model configuration and modify dropout rates
config = EsmConfig.from_pretrained("facebook/" + model_name)
# config.hidden_dropout_prob = 0.1 # Adjust hidden layer dropout
# config.attention_probs_dropout_prob = 0.1 # Adjust attention dropout
model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)
# Generate a timestamp for the output directory
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
output_dir = f'./interact_output_{timestamp}/'
# Calculate the total number of training steps
num_train_epochs = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs
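# Note (single-device assumption): the effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 32 sequence pairs
# per optimizer step, so total_steps = (len(proteins_train) // 32) * num_train_epochs here.
print(f"Planned optimizer steps: {total_steps}")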
# Training arguments with cosine learning rate scheduler and gradient clipping
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=per_device_train_batch_size,
per_device_eval_batch_size=8,
warmup_steps=10,
logging_dir='./logs',
logging_steps=10,
evaluation_strategy="epoch",
load_best_model_at_end=True,
save_strategy='epoch',
metric_for_best_model='eval_loss',
save_total_limit=3,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type='cosine',
    max_steps=total_steps,  # overrides num_train_epochs when set
gradient_checkpointing=True, # Enable gradient checkpointing for memory optimization
max_grad_norm=1.0 # Gradient clipping
)
# Optimizer with weight decay for regularization; note that, unlike the Trainer's
# default optimizer, this applies weight decay to all parameters (including biases
# and LayerNorm weights)
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)
# Instantiate the ProteinDataset for training and testing
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)
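# To inspect one training example with the helper defined above, uncomment:
# _inspect_masked_example(train_dataset)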
# Initialize DynamicMaskingCallback
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)
# Trainer with the dynamic masking callback; gradient clipping is handled by
# max_grad_norm in the training arguments above
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
    optimizers=(optimizer, None),  # scheduler is created by the Trainer (cosine, per lr_scheduler_type)
callbacks=[dynamic_masking_callback]
)
# Start training
trainer.train()
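# Optionally persist the final weights and tokenizer next to the epoch checkpoints;
# the "final" subdirectory name below is illustrative, not part of the original run.
# trainer.save_model(output_dir + "final")
# tokenizer.save_pretrained(output_dir + "final")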