|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
import subprocess
|
|
# Indices of the GPUs exposed to this process through CUDA_VISIBLE_DEVICES.
GPU_NUMBER = [0, 1, 2, 3]

# Publish the GPU list and toolchain settings through the process environment.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, GPU_NUMBER))
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
# NOTE(review): these paths look like placeholders -- confirm before running.
os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
|
|
|
|
|
|
import pyarrow
|
|
import ray
|
|
from ray import tune
|
|
from ray.tune import ExperimentAnalysis
|
|
from ray.tune.suggest.hyperopt import HyperOptSearch
|
|
# Shut down any Ray session attached to this process before starting a new one.
ray.shutdown()

# Runtime environment handed to ray.init: the conda environment name plus the
# environment variables to set.
# NOTE(review): the library paths look like placeholders -- confirm before running.
runtime_env = {
    "conda": "base",
    "env_vars": {
        "LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib",
    },
}
ray.init(runtime_env=runtime_env)
|
|
|
|
def initialize_ray_with_check(ip_address):
    """
    Initialize Ray with a specified IP address and check its status and accessibility.

    Prints the node list and the dashboard URL on success; prints the error
    message on failure. No exception propagates to the caller.

    Args:
    - ip_address (str): The IP address (with port) to initialize Ray.

    Returns:
    - bool: True if initialization was successful and dashboard is accessible, False otherwise.
    """
    try:
        ray.init(address=ip_address)
        print(ray.nodes())

        services = ray.get_webui_url()
        if not services:
            # Routed through the handler below so every failure is reported
            # the same way.
            raise RuntimeError("Ray dashboard is not accessible.")

        print(f"Ray dashboard is accessible at: {services}")
        return True
    except Exception as e:
        # Deliberately broad: any failure during init or the dashboard check
        # is reported and mapped to False.
        print(f"Error initializing Ray: {e}")
        return False
|
|
|
|
|
|
# Address ("host:port") passed to initialize_ray_with_check.
# NOTE(review): the value looks like a placeholder -- replace before running.
# NOTE(review): ray.init() was already called earlier in this script; confirm
# that a second initialization against `ip` is intended.
ip = 'your_ip:xxxx'
if initialize_ray_with_check(ip):
    print("Ray initialized successfully.")
else:
    print("Error during Ray initialization.")
|
|
|
|
import datetime
|
|
import numpy as np
|
|
import pandas as pd
|
|
import random
|
|
import seaborn as sns; sns.set()
|
|
from collections import Counter
|
|
from datasets import load_from_disk
|
|
from scipy.stats import ranksums
|
|
from sklearn.metrics import accuracy_score
|
|
from transformers import BertForSequenceClassification
|
|
from transformers import Trainer
|
|
from transformers.training_args import TrainingArguments
|
|
|
|
from geneformer import DataCollatorForCellClassification
|
|
|
|
|
|
# Number of worker processes passed to the dataset filter/map calls.
num_proc = 30

# Load the training dataset from disk. Code further down reads its
# "cell_type", "disease" and "individual" columns, and the training arguments
# name "length" as the length column.
# NOTE(review): the path looks like a placeholder -- confirm before running.
train_dataset = load_from_disk("/path/to/disease_train_data.dataset")
|
|
|
|
|
|
def if_cell_type(example):
    """Return True when the example's "cell_type" starts with "Cardiomyocyte".

    Args:
        example: Mapping with a string "cell_type" entry.

    Returns:
        bool: Whether the cell type has the "Cardiomyocyte" prefix.
    """
    return example["cell_type"].startswith("Cardiomyocyte")
|
|
|
|
# Keep only the examples accepted by if_cell_type.
trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)

# Map each class name to its integer id (its position in target_names).
target_names = ["healthy", "disease1", "disease2"]
target_name_id_dict = {name: idx for idx, name in enumerate(target_names)}

# The classifier code below expects the class column to be called "label".
trainset_v3 = trainset_v2.rename_column("disease", "label")
|
|
|
|
|
|
def classes_to_ids(example):
    """Replace the example's "label" class name with its integer id.

    The id is looked up in the module-level ``target_name_id_dict``; the
    example is modified in place and also returned.

    Args:
        example: Mapping whose "label" entry is a key of ``target_name_id_dict``.

    Returns:
        The same example with "label" replaced by the integer id.

    Raises:
        KeyError: If the label is not a key of ``target_name_id_dict``.
    """
    example["label"] = target_name_id_dict[example["label"]]
    return example
|
|
|
|
# Convert the string class labels to integer ids.
trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)

# Split the individuals (not the cells) into train / validation / test groups:
# 70% train, and the remainder halved into validation and test.
indiv_set = set(trainset_v4["individual"])

# random.sample() needs a sequence: sets are rejected on Python >= 3.11, and
# set iteration order is not stable between runs, which would defeat the fixed
# seed. Sampling from a sorted list makes the split reproducible.
# Assumes the individual ids are mutually comparable (e.g. all str or all int).
indiv_pool = sorted(indiv_set)

random.seed(42)
train_indiv = random.sample(indiv_pool, round(0.7 * len(indiv_pool)))
train_lookup = set(train_indiv)  # O(1) membership for the comprehension below
eval_indiv = [indiv for indiv in indiv_pool if indiv not in train_lookup]
valid_indiv = random.sample(eval_indiv, round(0.5 * len(eval_indiv)))
valid_lookup = set(valid_indiv)
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_lookup]
|
|
|
|
def if_train(example):
    """Return True when the example's "individual" is in ``train_indiv``.

    Args:
        example: Mapping with an "individual" entry.

    Returns:
        bool: Whether the individual belongs to the module-level ``train_indiv``.
    """
    return example["individual"] in train_indiv
|
|
|
|
# Training split: the examples accepted by if_train, shuffled with a fixed seed.
classifier_trainset = (
    trainset_v4
    .filter(if_train, num_proc=num_proc)
    .shuffle(seed=42)
)
|
|
|
|
def if_valid(example):
    """Return True when the example's "individual" is in ``valid_indiv``.

    Args:
        example: Mapping with an "individual" entry.

    Returns:
        bool: Whether the individual belongs to the module-level ``valid_indiv``.
    """
    return example["individual"] in valid_indiv
|
|
|
|
# Validation split: the examples accepted by if_valid, shuffled with a fixed seed.
classifier_validset = (
    trainset_v4
    .filter(if_valid, num_proc=num_proc)
    .shuffle(seed=42)
)
|
|
|
|
|
|
# Build a date-stamped output directory path (YYMMDD prefix).
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d")
# NOTE(review): the base path looks like a placeholder -- confirm before running.
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"

# Refuse to continue if a model file already exists in the output directory.
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
if os.path.isfile(saved_model_test):
    raise Exception("Model already saved to this directory.")

# Create the directory without going through a shell: os.makedirs also creates
# missing parents, tolerates an existing directory, and raises OSError on a
# real failure instead of silently ignoring mkdir's exit status.
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
|
|
|
# Number of leading encoder layers model_init freezes (None disables freezing).
freeze_layers = 2

# Per-device batch size used for both training and evaluation.
geneformer_batch_size = 12

# Number of training epochs.
epochs = 1

# Log/evaluate about ten times over the training examples at this batch size;
# clamp to at least 1 so a small training set cannot yield a step interval of 0.
logging_steps = max(1, round(len(classifier_trainset) / geneformer_batch_size / 10))
|
|
|
|
|
|
def model_init():
    """Build a fresh sequence-classification model for one training run.

    Loads the pretrained weights with one output label per entry of
    ``target_names``, freezes the first ``freeze_layers`` encoder layers
    (skipped when ``freeze_layers`` is None), and moves the model to "cuda:0".

    Returns:
        The model returned by ``BertForSequenceClassification.from_pretrained``
        after freezing and the device move.
    """
    # NOTE(review): the checkpoint path looks like a placeholder -- confirm.
    model = BertForSequenceClassification.from_pretrained(
        "/path/to/pretrained_model/",
        num_labels=len(target_names),
        output_attentions=False,
        output_hidden_states=False,
    )

    if freeze_layers is not None:
        # Disable gradients for every parameter of the leading encoder layers.
        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    model = model.to("cuda:0")
    return model
|
|
|
|
|
|
|
|
def compute_metrics(pred):
    """Compute accuracy from a prediction object.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (scores whose argmax over the last axis is the predicted class).

    Returns:
        dict: ``{"accuracy": <accuracy_score of labels vs. predictions>}``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
    }
|
|
|
|
|
|
# Keyword arguments for TrainingArguments. Evaluation and logging both happen
# every `logging_steps` steps; batches are grouped by the "length" column.
# NOTE(review): load_best_model_at_end may require the checkpoint-saving
# schedule to line up with eval_steps in some transformers versions -- confirm.
training_args = {
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "eval_steps": logging_steps,
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": True,
    "skip_memory_metrics": True,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)
|
|
|
|
|
|
# Trainer used for the hyperparameter search. It receives `model_init` rather
# than a model instance, so the Trainer builds the model through that factory.
trainer = Trainer(
    model_init=model_init,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=classifier_trainset,
    eval_dataset=classifier_validset,
    compute_metrics=compute_metrics,
)
|
|
|
|
|
|
# Search space for Ray Tune. Epoch count and batch size are single-option
# choices, i.e. held fixed; the remaining entries are sampled per trial.
ray_config = {
    "num_train_epochs": tune.choice([epochs]),
    "learning_rate": tune.loguniform(1e-6, 1e-3),
    "weight_decay": tune.uniform(0.0, 0.3),
    "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
    "warmup_steps": tune.uniform(100, 2000),
    "seed": tune.uniform(0, 100),
    "per_device_train_batch_size": tune.choice([geneformer_batch_size]),
}

# HyperOpt search algorithm set to maximize the "eval_accuracy" metric.
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
|
|
|
|
|
|
# Run the hyperparameter search on the Ray backend: 100 trials, each given
# 8 CPUs and 1 GPU, maximizing the objective. The CLI reporter prints at most
# once every 600 seconds, sorted by eval_accuracy.
trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu": 8, "gpu": 1},
    # The search space is fixed, so the trial argument is ignored.
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,
    progress_reporter=tune.CLIReporter(
        max_report_frequency=600,
        sort_by_metric=True,
        max_progress_rows=100,
        mode="max",
        metric="eval_accuracy",
        metric_columns=["loss", "eval_loss", "eval_accuracy"],
    ),
)