import os
import gc
import pickle
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.utils.data import Dataset as TorchDataset
from sklearn.metrics import (
    accuracy_score,
    matthews_corrcoef,
    precision_recall_fscore_support,
    roc_auc_score,
)
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)
from accelerate import Accelerator
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

MAX_LENGTH = 512

accelerator = Accelerator()
os.environ["WANDB_NOTEBOOK_NAME"] = 'training.py'
wandb.init(project='binding_site_prediction')


class ProteinDataset(TorchDataset):
    """Pickled protein sequences with per-residue binding-site labels, tokenized on the fly."""

    def __init__(self, sequences_path, labels_path, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # The pickles are expected to hold parallel lists: sequences[i] is an
        # amino-acid string and labels[i] a list of 0/1 ints of the same length.
        with open(sequences_path, "rb") as f:
            self.sequences = pickle.load(f)

        with open(labels_path, "rb") as f:
            self.labels = pickle.load(f)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]

        tokenized = self.tokenizer(
            sequence,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            is_split_into_words=False,
            add_special_tokens=False,
        )

        # Drop the batch dimension added by return_tensors="pt".
        for key in tokenized:
            tokenized[key] = tokenized[key].squeeze(0)

        # Pad labels with -100 (CrossEntropyLoss's ignore index) so padding
        # positions are excluded from the loss. Truncate before assigning:
        # slice-assigning a label longer than max_length would otherwise
        # shrink the padded list below max_length.
        label_padded = [-100] * self.max_length
        truncated_label = label[:self.max_length]
        label_padded[:len(truncated_label)] = truncated_label

        tokenized["labels"] = torch.tensor(label_padded)

        return tokenized
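
# Quick sanity check (illustrative): each item is a dict of 1-D tensors of
# length MAX_LENGTH, e.g. dataset[0]["input_ids"].shape == (512,) and
# dataset[0]["labels"].shape == (512,).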


def print_trainable_parameters(model):
    """Prints the number of trainable parameters in the model."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def save_config_to_txt(config, filename):
    """Save the configuration dictionary to a text file."""
    with open(filename, 'w') as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")


def compute_metrics(p):
    """Token-level metrics over positions with real labels (labels != -100)."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    mask = labels != -100
    predictions = predictions[mask].flatten()
    labels = labels[mask].flatten()

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)

    del predictions
    del labels
    gc.collect()

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}


def compute_loss(model, logits, inputs):
    """Class-weighted cross-entropy over positions the attention mask marks as real."""
    labels = inputs["labels"]
    # class_weights is computed at module scope below, before training starts.
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    # Positions outside the attention mask are set to ignore_index (-100),
    # so they drop out of the loss.
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss
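
# Note: roc_auc_score above is computed from hard 0/1 predictions, so it
# reduces to balanced accuracy; passing the class-1 probabilities (softmax of
# the logits) instead would give a threshold-independent AUC.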


tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
train_dataset = ProteinDataset(
    "data/12M_data/512_train_sequences_chunked_by_family.pkl",
    "data/12M_data/512_train_labels_chunked_by_family.pkl",
    tokenizer,
    MAX_LENGTH,
)


# Estimate class weights from a random sample of the training labels rather
# than all of them, to keep the computation cheap.
SAMPLE_SIZE = 100000

with open("data/12M_data/512_train_labels_chunked_by_family.pkl", "rb") as f:
    all_train_labels = pickle.load(f)

sample_labels = random.sample(all_train_labels, SAMPLE_SIZE)
flat_sample_labels = [label for sublist in sample_labels for label in sublist]

classes = np.array([0, 1])
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_sample_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
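
# With class_weight='balanced', weight_c = n_samples / (n_classes * count_c),
# so the rarer binding-site class (1) receives the larger weight in the loss.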


class WeightedTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
    # transformers versions pass to compute_loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        logits = outputs.logits
        loss = compute_loss(model, logits, inputs)
        return (loss, outputs) if return_outputs else loss


# QLoRA-style 4-bit quantization: NF4 storage, double quantization, and
# bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
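
# With 4-bit NF4 weights, the 650M-parameter backbone occupies roughly
# 650M * 0.5 bytes ≈ 0.33 GB (before activations, LoRA parameters, and
# optimizer state), which is what makes single-GPU fine-tuning feasible.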


def train_function_no_sweeps(train_dataset):

    # Fixed hyperparameters for this run (no sweep).
    config = {
        "lora_alpha": 1,
        "lora_dropout": 0.5,
        "lr": 1.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 200,
        "r": 2,
        "weight_decay": 0.3,
    }

    wandb.config.update(config)

    # Log the configuration alongside the run.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    config_filename = f"esm2_t33_650M_qlora_config_{timestamp}.txt"
    save_config_to_txt(config, config_filename)

    model_checkpoint = "facebook/esm2_t33_650M_UR50D"

    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}

    # Load the base model with 4-bit quantization applied.
    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        quantization_config=bnb_config,
    )

    # Trade compute for memory, then cast norms/embeddings and enable input
    # gradients as required for k-bit (QLoRA) training.
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    # PEFT matches target_modules against module-name suffixes, so the plain
    # names ("query", "key", "value", "classifier") match ESM-2's attention and
    # head layers. The class-qualified entries below do not correspond to any
    # module-path suffix and are silently skipped; adapting those dense layers
    # would need a name like "dense" (or a regex) instead.
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=[
            "query",
            "key",
            "value",
            "EsmSelfOutput.dense",
            "EsmIntermediate.dense",
            "EsmOutput.dense",
            "classifier",
        ],
        lora_dropout=config["lora_dropout"],
        bias="none",
    )
    model = get_peft_model(model, peft_config)
    print_trainable_parameters(model)

    # Let accelerate place/wrap the model. The Trainer builds its own
    # DataLoader, so the Dataset itself does not need prepare() (calling it on
    # a plain Dataset is a no-op).
    model = accelerator.prepare(model)

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    training_args = TrainingArguments(
        output_dir=f"esm2_t33_650M_qlora_binding_sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="no",
        save_strategy="steps",
        save_steps=10000,
        load_best_model_at_end=False,
        # The two settings below only take effect once evaluation and
        # load_best_model_at_end are enabled.
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=100,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        # bf16 matches bnb_4bit_compute_dtype above; mixing fp16 autocast with
        # bfloat16 compute can trigger dtype errors.
        bf16=True,
        report_to='wandb',
        optim="paged_adamw_8bit",
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        # Only invoked if an eval dataset and evaluation strategy are
        # configured; wired up here for when they are.
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # Save the LoRA adapter weights and tokenizer for later use.
    save_path = os.path.join("qlora_binding_sites", f"best_model_esm2_t33_650M_qlora_{timestamp}")
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)


if __name__ == "__main__":
    train_function_no_sweeps(train_dataset)
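
# Launch with accelerate, e.g.:
#   accelerate launch training.py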