# Geneformer / examples / hyperparam_optimiz_for_disease_classifier.py
# Author: Christina Theodoris
# Commit 79a0c41: "Add example for hyperparameter optimization for disease classifier"
#!/usr/bin/env python
# coding: utf-8
# hyperparameter optimization with raytune for disease classification
# imports
import os
import subprocess
# GPU ids exposed to this process (and to the ray workers started below).
# All of these environment variables are set before ray is imported.
# NOTE(review): the glibc override value and library paths look like
# site-specific placeholders -- adjust for your cluster.
GPU_NUMBER = [0, 1, 2, 3]
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": ",".join(map(str, GPU_NUMBER)),
        "NCCL_DEBUG": "INFO",
        "CONDA_OVERRIDE_GLIBC": "2.56",
        "LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib",
    }
)
# initiate runtime environment for raytune
import pyarrow # must occur prior to ray import
import ray
from ray import tune
from ray.tune import ExperimentAnalysis
from ray.tune.suggest.hyperopt import HyperOptSearch
# Ray workers run in the "base" conda env with the same library search path
# as this driver process.
ray_env_vars = {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}
runtime_env = {"conda": "base", "env_vars": ray_env_vars}
ray.init(runtime_env=runtime_env)
import datetime
import numpy as np
import pandas as pd
import random
import seaborn as sns; sns.set()
from collections import Counter
from datasets import load_from_disk
from scipy.stats import ranksums
from sklearn.metrics import accuracy_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
# Number of CPU worker processes used by the datasets filter/map calls below.
num_proc = 30
# Load the training dataset from disk. Expected columns:
#   cell_type  - annotation of each cell's type
#   disease    - healthy or disease state (renamed to "label" further down)
#   individual - unique ID for each patient
#   length     - length of that cell's rank value encoding
train_dataset = load_from_disk(dataset_path="/path/to/disease_train_data.dataset")
# filter predicate: keep only cells of the chosen cell type
def if_cell_type(example):
    """Return True when the example's ``cell_type`` begins with "Cardiomyocyte"."""
    annotated_type = example["cell_type"]
    return annotated_type.startswith("Cardiomyocyte")
trainset_v2 = train_dataset.filter(function=if_cell_type, num_proc=num_proc)  # keep only the selected cell type
# Disease states -> integer label ids (each id is the state's position in the list).
target_names = ["healthy", "disease1", "disease2"]
target_name_id_dict = {name: label_id for label_id, name in enumerate(target_names)}
# The disease state becomes the classification target column "label".
trainset_v3 = trainset_v2.rename_column(original_column_name="disease", new_column_name="label")


def classes_to_ids(example):
    """Replace the disease-state string in ``example["label"]`` with its integer id.

    Mutates and returns *example*; raises KeyError for a state that is not in
    ``target_name_id_dict``.
    """
    disease_state = example["label"]
    example["label"] = target_name_id_dict[disease_state]
    return example


trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
# Separate into train, validation and test sets *by individual*, so that all
# cells from one patient end up in exactly one split.
# The "individual" column has one entry per cell, so deduplicate it first.
# Sampling 70% of the raw per-cell list weights patients by cell count and
# almost surely draws every patient at least once, which would leave the
# validation/test sets (nearly) empty.
# dict.fromkeys keeps first-seen order, so the split is reproducible for the
# fixed seed. (A set would not be: random.sample needs a sequence, and set
# order can vary between runs.)
indiv_list = list(dict.fromkeys(trainset_v4["individual"]))
random.seed(42)
train_indiv = random.sample(indiv_list, round(0.7 * len(indiv_list)))
# Sets give O(1) membership tests here and in the per-cell filters below.
train_indiv_set = set(train_indiv)
eval_indiv = [indiv for indiv in indiv_list if indiv not in train_indiv_set]
valid_indiv = random.sample(eval_indiv, round(0.5 * len(eval_indiv)))
valid_indiv_set = set(valid_indiv)
# Held-out individuals; not used by the hyperparameter search below.
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv_set]


def if_train(example):
    """Return True if the example's individual was assigned to the training split."""
    return example["individual"] in train_indiv_set


classifier_trainset = trainset_v4.filter(if_train, num_proc=num_proc).shuffle(seed=42)


def if_valid(example):
    """Return True if the example's individual was assigned to the validation split."""
    return example["individual"] in valid_indiv_set


classifier_validset = trainset_v4.filter(if_valid, num_proc=num_proc).shuffle(seed=42)
# Define a dated output directory: /path/to/models/<yymmdd>_geneformer_DiseaseClassifier/
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d")
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"

# Ensure we are not overwriting a model previously saved to this directory.
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
if os.path.isfile(saved_model_test):
    raise FileExistsError("Model already saved to this directory.")

# Create the output directory, including any missing parents, without invoking
# a shell. Errors such as missing permissions are raised rather than ignored.
os.makedirs(output_dir, exist_ok=True)
# Training parameters.
# Number of pretrained encoder layers to freeze: the first `freeze_layers`
# entries of model.bert.encoder.layer. None disables freezing (see model_init).
freeze_layers = 2
# Per-device batch size for both training and evaluation.
geneformer_batch_size = 12
# Number of training epochs.
epochs = 1
# Log (and evaluate) roughly 10 times per epoch. The value is clamped to at
# least 1: for training sets of about 60 cells or fewer, the rounded value
# would be 0, and TrainingArguments rejects a zero step interval for
# "steps" evaluation.
logging_steps = max(1, round(len(classifier_trainset) / geneformer_batch_size / 10))
# define function to initiate model
def model_init():
    """Build the classifier from the pretrained checkpoint.

    Loads ``BertForSequenceClassification`` with one output per entry in
    ``target_names``, without attention or hidden-state outputs. Unless
    ``freeze_layers`` is None, it sets ``requires_grad=False`` on the
    parameters of the first ``freeze_layers`` encoder layers. It then moves the
    model to ``cuda:0`` and returns it.

    It is passed to ``Trainer(model_init=...)``, which calls it to create a
    model instance instead of receiving a prebuilt model.
    """
    model = BertForSequenceClassification.from_pretrained(
        "/path/to/pretrained_model/",
        num_labels=len(target_names),
        output_attentions=False,
        output_hidden_states=False,
    )
    if freeze_layers is not None:
        frozen_params = (
            weight
            for layer in model.bert.encoder.layer[:freeze_layers]
            for weight in layer.parameters()
        )
        for weight in frozen_params:
            weight.requires_grad = False
    return model.to("cuda:0")
# define metrics
def compute_metrics(pred):
    """Return ``{"accuracy": ...}`` for a Trainer prediction output.

    The predicted class is the argmax over the last axis of
    ``pred.predictions`` (presumably per-class logits). It is scored against
    ``pred.label_ids`` with sklearn's ``accuracy_score``.
    """
    predicted_ids = pred.predictions.argmax(-1)
    return {"accuracy": accuracy_score(pred.label_ids, predicted_ids)}
# Set training arguments (shared by every hyperparameter-search trial).
training_args = {
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "eval_steps": logging_steps,
    # load_best_model_at_end needs a checkpoint at evaluation points:
    # TrainingArguments raises unless save_steps is a round multiple of
    # eval_steps. The default save_steps (500) generally is not a multiple of
    # the data-dependent logging_steps, so save on the same interval.
    "save_steps": logging_steps,
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": True,
    "skip_memory_metrics": True,  # memory tracker causes errors in raytune
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)
# Create the trainer. It is given model_init (a factory) rather than a model
# instance; Trainer.hyperparameter_search requires a model_init.
trainer = Trainer(
    args=training_args_init,
    model_init=model_init,
    train_dataset=classifier_trainset,
    eval_dataset=classifier_validset,
    data_collator=DataCollatorForCellClassification(),
    compute_metrics=compute_metrics,
)
# Ray Tune search space. Epochs and training batch size are pinned (a single
# choice each) to the values set above; the remaining entries are sampled.
# NOTE(review): warmup_steps and seed are sampled as floats; Trainer's ray
# backend appears to cast them to the int type of the matching
# TrainingArguments field -- confirm for your transformers version.
ray_config = dict(
    num_train_epochs=tune.choice([epochs]),
    learning_rate=tune.loguniform(1e-6, 1e-3),
    weight_decay=tune.uniform(0.0, 0.3),
    lr_scheduler_type=tune.choice(["linear", "cosine", "polynomial"]),
    warmup_steps=tune.uniform(100, 2000),
    seed=tune.uniform(0, 100),
    per_device_train_batch_size=tune.choice([geneformer_batch_size]),
)

# HyperOpt search maximizing eval_accuracy (compute_metrics' "accuracy",
# reported by Trainer with the "eval_" prefix).
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
# Console progress table: refreshed at most every 600 s, up to 100 rows,
# sorted by eval_accuracy (highest first).
progress_reporter = tune.CLIReporter(
    max_report_frequency=600,
    sort_by_metric=True,
    max_progress_rows=100,
    mode="max",
    metric="eval_accuracy",
    metric_columns=["loss", "eval_loss", "eval_accuracy"],
)

# Optimize hyperparameters: 100 Ray Tune trials, each given 8 CPUs and 1 GPU.
trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu": 8, "gpu": 1},
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,  # number of trials
    progress_reporter=progress_reporter,
)