|
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForMaskedLM,
    TrainerCallback,
    EsmConfig,
)
from torch.utils.data import Dataset
import pandas as pd
import torch
from torch.optim import AdamW
import random
import datetime
|
|
|
class ProteinDataset(Dataset):
    """Protein/peptide pairs, masked on the fly for masked-LM fine-tuning."""

    def __init__(self, proteins, peptides, tokenizer, mask_percentage=0.30):
        self.tokenizer = tokenizer
        self.proteins = proteins
        self.peptides = peptides
        self.mask_percentage = mask_percentage

    def __len__(self):
        return len(self.proteins)

    def mask_sequence(self, sequence):
        # Sample a fixed fraction of positions without replacement; a set
        # makes the per-position membership test O(1). The literal mask-token
        # string is recognized by the tokenizer and maps to a single token.
        num_to_mask = int(len(sequence) * self.mask_percentage)
        mask_indices = set(random.sample(range(len(sequence)), num_to_mask))
        return ''.join(
            self.tokenizer.mask_token if i in mask_indices else char
            for i, char in enumerate(sequence)
        )

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        # Mask protein and peptide independently, then concatenate them into
        # one "complex" sequence.
        masked_protein = self.mask_sequence(protein_seq)
        masked_peptide = self.mask_sequence(peptide_seq)
        complex_seq = masked_protein + masked_peptide

        complex_input = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False,
        )

        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()

        # Tokenize the unmasked concatenation identically so labels align
        # one-to-one with the masked input (each mask replaces one residue).
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True,
            add_special_tokens=False,
        )["input_ids"].squeeze()

        # Compute the loss only at masked positions; -100 is ignored by the
        # cross-entropy loss.
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
|
|
|
|
class DynamicMaskingCallback(TrainerCallback):
    """Raises the training set's mask percentage at the end of every epoch."""

    def __init__(self, dataset, increment=0.10):
        self.dataset = dataset
        self.increment = increment

    def on_epoch_end(self, args, state, control, **kwargs):
        self.dataset.mask_percentage = min(self.dataset.mask_percentage + self.increment, 1.0)
        print(f"Updated mask percentage to: {self.dataset.mask_percentage * 100:.0f}%")
|
|
|
|
|
# Expected columns in the TSV: Protein1, Protein2, Cluster.
file_path = "clustered_protein_pair_landscapes_l2_distances.tsv"
data = pd.read_csv(file_path, delimiter='\t')
|
|
|
|
|
# Hold out whole clusters so that similar pairs never straddle the
# train/test split. Cluster 0 is always held out; further clusters are
# appended in random order until the next one would push the held-out
# fraction past 20%.
test_clusters = [0]
remaining_clusters = data[data['Cluster'] != 0]['Cluster'].unique().tolist()
random.shuffle(remaining_clusters)

test_size = (data['Cluster'] == 0).mean()
for cluster in remaining_clusters:
    cluster_size = (data['Cluster'] == cluster).mean()
    if test_size + cluster_size > 0.20:
        break
    test_clusters.append(cluster)
    test_size += cluster_size
|
|
|
|
|
test_data = data[data['Cluster'].isin(test_clusters)]
train_data = data[~data['Cluster'].isin(test_clusters)]

# Protein1 supplies the protein sequences, Protein2 the peptide partners.
proteins_train = train_data["Protein1"].tolist()
peptides_train = train_data["Protein2"].tolist()
proteins_test = test_data["Protein1"].tolist()
peptides_test = test_data["Protein2"].tolist()
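# Sanity check: no cluster may appear in both splits.
assert set(test_data['Cluster']).isdisjoint(train_data['Cluster'])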
|
|
|
|
|
# ESM-2, 650M parameters, 33 layers (t33), with a masked-LM head.
model_name = "esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained("facebook/" + model_name)
config = EsmConfig.from_pretrained("facebook/" + model_name)
model = AutoModelForMaskedLM.from_pretrained("facebook/" + model_name, config=config)
|
|
|
|
|
# Timestamped output directory so repeated runs do not overwrite each other.
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%Y%m%d_%H%M%S")
output_dir = f'./interact_output_{timestamp}/'

num_train_epochs = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
# Optimizer steps per epoch = dataset size // effective batch size, where the
# effective batch size is per-device batch size x accumulation steps.
total_steps = (len(proteins_train) // (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs
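# Example of the arithmetic: with 10,000 training pairs, 10000 // (8 * 4)
# gives 312 optimizer steps per epoch, so total_steps = 1248 over four
# epochs. (10,000 is illustrative; the real size comes from the TSV.)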
|
|
|
|
|
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=8,
    warmup_steps=10,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy='epoch',
    metric_for_best_model='eval_loss',
    save_total_limit=3,
    gradient_accumulation_steps=gradient_accumulation_steps,
    lr_scheduler_type='cosine',
    max_steps=total_steps,        # a positive max_steps overrides num_train_epochs
    gradient_checkpointing=True,  # trade recompute for memory on the 650M model
    max_grad_norm=1.0,
)
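# load_best_model_at_end requires the evaluation and save strategies to match;
# both are per-epoch here, so the checkpoint with the lowest eval_loss is
# restored once training finishes.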
|
|
|
|
|
# AdamW with a tuned learning rate (the precision of the value suggests it
# came from a hyperparameter search) and mild weight decay.
optimizer = AdamW(model.parameters(), lr=0.0007984276816171436, weight_decay=0.03)
|
|
|
|
|
train_dataset = ProteinDataset(proteins_train, peptides_train, tokenizer)
test_dataset = ProteinDataset(proteins_test, peptides_test, tokenizer)

# Only the training set gets the dynamic masking schedule; evaluation keeps
# the fixed 30% default.
dynamic_masking_callback = DynamicMaskingCallback(train_dataset)
|
|
|
|
|
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, None),  # None lets the Trainer build the cosine scheduler
    callbacks=[dynamic_masking_callback],
)

trainer.train()
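# Afterwards the fine-tuned weights could be persisted alongside the
# tokenizer, e.g.:
#   trainer.save_model(output_dir)
#   tokenizer.save_pretrained(output_dir)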