# esm_mlmppi_ph50 / clustered_ppi_train.py
# Uploaded by AmelieSchreiber ("Upload clustered_ppi_train.py", e3768bd)
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, TrainerCallback, EsmConfig
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
import random
import datetime
class ProteinDataset(Dataset):
    """Masked-language-modelling dataset over (protein, peptide) sequence pairs.

    Each item masks a random subset of characters in the protein and in the
    peptide independently, concatenates the two masked strings, tokenizes the
    result, and builds labels that hold the original token ids at masked
    positions and -100 everywhere else (-100 is the default ``ignore_index``
    of PyTorch's cross-entropy loss).

    Args:
        proteins: Indexable collection of sequence strings.
        peptides: Indexable collection of partner sequence strings, aligned
            index-by-index with ``proteins``.
        tokenizer: Object exposing ``mask_token``, ``mask_token_id`` and a
            call interface accepting the keyword arguments used in
            ``_encode`` (presumably a Hugging Face tokenizer).
        mask_percentage: Fraction of characters to mask in each sequence.
            It is read on every item access, so it may be changed between
            epochs.
        max_length: Token length every example is padded/truncated to.
    """

    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30, max_length=1024):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides
        self.mask_percentage = mask_percentage
        self.max_length = max_length

    def __len__(self):
        """Return the number of protein entries."""
        return len(self.proteins)

    def mask_sequence(self, sequence):
        """Return ``sequence`` with random characters replaced by the mask token.

        ``int(len(sequence) * self.mask_percentage)`` distinct positions are
        drawn without replacement via ``random.sample``.
        """
        num_to_mask = int(len(sequence) * self.mask_percentage)
        # A set gives O(1) membership tests in the join below; the positions
        # drawn (and the RNG state consumed) are the same as with a list.
        masked_positions = set(random.sample(range(len(sequence)), num_to_mask))
        mask_token = self.tokenizer.mask_token
        return ''.join(
            mask_token if position in masked_positions else residue
            for position, residue in enumerate(sequence)
        )

    def _encode(self, text):
        """Tokenize ``text`` with fixed-length padding/truncation and no special tokens."""
        return self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for pair ``idx``."""
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        # Protein is masked before peptide; both are masked independently.
        masked_protein = self.mask_sequence(protein_seq)
        masked_peptide = self.mask_sequence(peptide_seq)

        complex_input = self._encode(masked_protein + masked_peptide)
        # squeeze(0) drops only the batch axis, so a length-1 sequence still
        # yields a 1-D tensor.
        input_ids = complex_input["input_ids"].squeeze(0)
        attention_mask = complex_input["attention_mask"].squeeze(0)

        # NOTE(review): assumes the tokenizer maps every residue character and
        # every mask token to exactly one token, so masked and unmasked
        # encodings line up position by position - confirm for the tokenizer
        # in use.
        labels = self._encode(protein_seq + peptide_seq)["input_ids"].squeeze(0)
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
# Callback to update mask percentage after each epoch
class DynamicMaskingCallback(TrainerCallback):
    """Raise a dataset's ``mask_percentage`` by a fixed step at every epoch end.

    Args:
        dataset: Object whose ``mask_percentage`` attribute is updated in place.
        increment: Amount added per epoch; the result is capped at 1.0.
    """

    def __init__(self, dataset, increment=0.10):
        self.increment = increment
        self.dataset = dataset

    def on_epoch_end(self, args, state, control, **kwargs):
        """Bump the mask fraction (never above 1.0) and report the new value."""
        raised = self.dataset.mask_percentage + self.increment
        self.dataset.mask_percentage = raised if raised < 1.0 else 1.0
        print(f"Updated mask percentage to: {self.dataset.mask_percentage * 100}%")
# Load the clustered protein-pair table (tab-separated).
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')

# Fraction of all rows belonging to each cluster, computed in a single pass
# instead of rescanning the whole column once per cluster inside the loop.
cluster_fractions = (data['Cluster'].value_counts() / len(data)).to_dict()

# Cluster-level hold-out split: cluster 0 is always in the test set, then
# randomly ordered clusters are added while the test share stays <= 20%.
test_clusters = [0]
# Shuffle a plain Python list; random.shuffle is specified for mutable
# sequences, and the resulting permutation is the same as for the array.
remaining_clusters = data[data['Cluster'] != 0]['Cluster'].unique().tolist()
random.shuffle(remaining_clusters)

# Share of the dataset taken by cluster 0 (0.0 if the cluster is absent).
cluster_0_size = cluster_fractions.get(0, 0.0)

test_size = cluster_0_size
for cluster in remaining_clusters:
    cluster_size = cluster_fractions.get(cluster, 0.0)
    # NOTE: selection stops at the first cluster that would push the test
    # share above 20%, even if a later, smaller cluster would still fit.
    if test_size + cluster_size > 0.20:
        break
    test_clusters.append(cluster)
    test_size += cluster_size

# Rows go to test or train according to their cluster membership, so no
# cluster is shared between the two splits.
in_test = data['Cluster'].isin(test_clusters)
test_data = data[in_test]
train_data = data[~in_test]

proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()
# Tokenizer, configuration and weights are all loaded from the same
# pretrained identifier.
model_name = "esm2_t33_650M_UR50D"
pretrained_id = "facebook/" + model_name
tokenizer = AutoTokenizer.from_pretrained(pretrained_id)

# The configuration is loaded separately so dropout rates can be overridden
# before the model is built; the overrides below are currently disabled.
config = EsmConfig.from_pretrained(pretrained_id)
# config.hidden_dropout_prob = 0.1  # Adjust hidden layer dropout
# config.attention_probs_dropout_prob = 0.1  # Adjust attention dropout
model = AutoModelForMaskedLM.from_pretrained(pretrained_id, config=config)

# Each run writes to its own directory, named after the launch time.
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
output_dir = f'./interact_output_{timestamp}/'
# Schedule length: whole optimizer updates that fit in one pass over the
# training pairs (floor division), times the number of epochs.
num_train_epochs = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
samples_per_update = per_device_train_batch_size * gradient_accumulation_steps
total_steps = (len(proteins_train) // samples_per_update) * num_train_epochs

# Training configuration: cosine learning-rate schedule, gradient clipping,
# per-epoch evaluation and checkpointing.
training_args = TrainingArguments(
    # Output and logging
    output_dir=output_dir,
    logging_dir='./logs',
    logging_steps=10,
    # Run length and batching
    num_train_epochs=num_train_epochs,
    max_steps=total_steps,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # Learning-rate schedule
    lr_scheduler_type='cosine',
    warmup_steps=10,
    # Evaluation and checkpoint selection
    evaluation_strategy="epoch",
    save_strategy='epoch',
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    # Memory and stability
    gradient_checkpointing=True,  # Enable gradient checkpointing for memory optimization
    max_grad_norm=1.0,  # Gradient clipping
)
# Train and held-out datasets share the tokenizer; both start at the
# dataset's default mask percentage.
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)

# Only the training dataset is handed to the callback, so only its mask
# percentage is raised at the end of each epoch.
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)

# AdamW with weight decay for regularization.
optimizer = AdamW(
    model.parameters(),
    lr=0.0007984276816171436,
    weight_decay=0.03,
)

# The optimizer is supplied explicitly; the scheduler slot is left as None
# (presumably so the Trainer builds the scheduler named in training_args -
# confirm against the Trainer documentation).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    callbacks=[dynamic_masking_callback],
    optimizers=(optimizer, None),
)

# Start training
trainer.train()