"""
Geneformer cell classifier.

Usage:
  from geneformer import classify_cells

  classify_cells(
      token_set=Path("geneformer/token_dictionary.pkl"),
      median_set=Path("geneformer/gene_median_dictionary.pkl"),
      pretrained_model=".",
      dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
      dataset_split=None,
      filter_cells=0.005,
      epochs=1,
      cpu_cores=os.cpu_count(),
      training_batch_size=12,
      optimizer="adamw",
      max_lr=5e-5,
      num_gpus=torch.cuda.device_count(),
      max_input_size=2**11,
      lr_schedule_fn="linear",
      warmup_steps=500,
      freeze_layers=0,
      emb_extract=False,
      max_cells=1000,
      emb_layer=0,
      emb_filter=None,
      emb_dir="embeddings",
      overwrite=True,
      label="cell_type",
      data_filter=None,
      inference_batch_size=200,
      finetuned_model=None,
      skip_training=False,
      sample_data=1,
      inference=False,
      optimize_hyperparameters=False,
      output_dir=None,
  )
"""

import ast
import datetime
import os
import pickle
import random
from collections import Counter
from pathlib import Path

import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from datasets import load_from_disk
from matplotlib import pyplot as plt
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch
from sklearn.metrics import accuracy_score
from sklearn.metrics import auc as precision_auc
from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score, roc_curve
from transformers import BertForSequenceClassification, Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification, EmbExtractor

sns.set()


# Make all local GPUs visible and log NCCL activity
GPU_NUMBER = list(range(torch.cuda.device_count()))
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(s) for s in GPU_NUMBER)
os.environ["NCCL_DEBUG"] = "INFO"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def ROC(prediction, truth, model_type="GeneFormer", label=""):
    """Plot a binary ROC curve and save it to ROC.png."""
    fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
    auc = roc_auc_score(truth, prediction[:, 1])
    print(f"{model_type} AUC: {auc}")
    plt.plot(fpr, tpr, label="AUC=" + str(auc))
    plt.ylabel("True Positive Rate")
    plt.xlabel("False Positive Rate")
    plt.title(f"{label} ROC Curve")
    plt.legend(loc=4)
    plt.savefig("ROC.png")

    return tpr, fpr, auc
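
# Example (illustrative): `prediction` comes from Trainer.predict on a binary
# classifier, so the positive-class scores live in column 1, e.g.
#     tpr, fpr, auc = ROC(predictions.predictions, true_labels, label="cell_type")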


def similarity(tensor1, tensor2, cosine=False):
    """Cosine similarity between two tensors.

    Returns the raw cosine similarity when ``cosine`` is True; otherwise
    the value is rescaled from [-1, 1] to [0, 1].
    """
    tensor1 = tensor1.flatten()
    tensor2 = tensor2.flatten()
    if tensor1.shape != tensor2.shape:
        raise ValueError("Input tensors must have the same shape.")

    epsilon = 1e-8  # guards against division by zero for zero vectors
    dot_product = torch.dot(tensor1, tensor2)
    sim = dot_product / (torch.norm(tensor1) * torch.norm(tensor2) + epsilon)

    if cosine:
        return sim.item()
    return (sim.item() + 1) / 2
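
# Example (illustrative): compare two embedding vectors; cosine=True returns
# the raw cosine similarity in [-1, 1], otherwise it is rescaled to [0, 1]:
#     sim = similarity(torch.randn(256), torch.randn(256), cosine=True)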


def plot_similarity_heatmap(similarities):
    classes = list(similarities.keys())
    classlen = len(classes)
    arr = np.zeros((classlen, classlen))
    for i, c in enumerate(classes):
        for j, cc in enumerate(classes):
            arr[i, j] = 1.0 if cc == c else similarities[c][cc]

    plt.figure(figsize=(8, 6))
    plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
    plt.colorbar()
    plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
    plt.yticks(np.arange(classlen), classes)
    plt.title("Similarity Heatmap")
    plt.savefig("similarity_heatmap.png")
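
# Example (illustrative): `similarities` maps each class to its similarity
# with every other class, as produced at the end of classify_cells, e.g.
#     plot_similarity_heatmap({"B": {"T": 0.83}, "T": {"B": 0.83}})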


def classify_cells(
    token_set=Path("./token_dictionary.pkl"),
    median_set=Path("./gene_median_dictionary.pkl"),
    pretrained_model="../",
    dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
    dataset_split=None,
    filter_cells=0.005,
    epochs=1,
    cpu_cores=os.cpu_count(),
    training_batch_size=12,
    optimizer="adamw",
    max_lr=5e-5,
    num_gpus=torch.cuda.device_count(),
    max_input_size=2**11,
    lr_schedule_fn="linear",
    warmup_steps=500,
    freeze_layers=0,
    emb_extract=False,
    max_cells=None,
    emb_layer=-1,
    emb_filter=None,
    emb_dir="embeddings",
    overwrite=False,
    label="cell_type",
    data_filter=None,
    inference_batch_size=200,
    finetuned_model=None,
    skip_training=False,
    sample_data=1,
    inference=False,
    optimize_hyperparameters=True,
    output_dir=None,
):
""" |
|
Primary Parameters |
|
------------------- |
|
dataset: path |
|
Path to fine-tuning dataset for training |
|
|
|
finetuned_model: path |
|
Path to location of fine-tuned model to use for inference and embedding extraction |
|
|
|
pretrained_model: path |
|
Path to pretrained Geneformer model |
|
|
|
inference: bool |
|
Indicates whether to perform inference and return a list of similarities. Defaults to False. |
|
|
|
skip_training: bool |
|
Indicates whether to skip training the model. Defaults to False. |
|
|
|
emb_extract: bool |
|
Indicates whether to extract embeddings and calculate similarities. Defaults to True. |
|
|
|
optimize_hyperparameters: bool |
|
Indicates whether to optimize model hyperparamters. Defaults to False. |
|
|
|
|
|
Customization Parameters |
|
------------------- |
|
|
|
dataset_split: str |
|
Indicates how the dataset should be partitioned (if at all), and what ID should be used for partitioning |
|
|
|
data_filter: list |
|
(For embeddings and inference) Runs analysis on subsets of the dataset based on the ID defined by dataset_split |
|
|
|
label: str |
|
Feature to read as a classification label. |
|
|
|
emb_layer: int |
|
What layer embeddings should be extracted and compared. |
|
|
|
emb_filter: ['cell1', 'cell2'...] |
|
Allows user to narrow down range of cells that embeddings will be extracted from. |
|
|
|
max_cells: int |
|
Max number of cells to use for embedding extraction. |
|
|
|
freeze_layers: int |
|
Number of layers that should be frozen during fine-tuning. |
|
|
|
sample_data: float |
|
Proportion of the dataset that should be used. |
|
|
|
""" |
    dataset_list = []
    evalset_list = []
    split_list = []
    target_dict_list = []

    # Subsample the dataset to the requested proportion
    train_dataset = load_from_disk(dataset)
    num_samples = int(len(train_dataset) * sample_data)
    random_indices = random.sample(range(len(train_dataset)), num_samples)
    train_dataset = train_dataset.select(random_indices)

    # These filters close over names assigned later (cells_to_keep,
    # target_name_id_dict, trained_labels); each is defined before use
    def if_not_rare_cell_state(example):
        return example[label] in cells_to_keep

    def classes_to_ids(example):
        example["label"] = target_name_id_dict[example["label"]]
        return example

    def if_trained_label(example):
        return example["label"] in trained_labels

    if skip_training is not True:

        def compute_metrics(pred):
            labels = pred.label_ids
            preds = pred.predictions.argmax(-1)

            acc = accuracy_score(labels, preds)
            macro_f1 = f1_score(labels, preds, average="macro")
            return {"accuracy": acc, "macro_f1": macro_f1}

        # Splits listed as keys are merged into the split named by their
        # value (e.g. bone_marrow cells are folded into immune)
        excep = {"bone_marrow": "immune"}

        if dataset_split is not None:
            if data_filter is not None:
                split_iter = [data_filter]
            else:
                split_iter = Counter(train_dataset[dataset_split]).keys()
            for lab in split_iter:
                if lab in excep.keys():
                    # Merged into the split named by its exception value
                    continue
                elif lab in excep.values():
                    split_ids = [k for k, v in excep.items() if v == lab] + [lab]
                    split_list += [lab]
                else:
                    split_ids = [lab]
                    split_list += [lab]

                def if_label(example):
                    return example[dataset_split] in split_ids

                trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
                label_counter = Counter(trainset_label[label])
                total_cells = sum(label_counter.values())

                # Drop rare cell states below the filter_cells fraction
                cells_to_keep = [
                    k
                    for k, v in label_counter.items()
                    if v > (filter_cells * total_cells)
                ]
                trainset_label_subset = trainset_label.filter(
                    if_not_rare_cell_state, num_proc=cpu_cores
                )

                trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
                trainset_label_shuffled = trainset_label_shuffled.rename_column(
                    label, "label"
                )
                trainset_label_shuffled = trainset_label_shuffled.remove_columns(
                    dataset_split
                )

                # Map class names to integer ids
                target_names = list(Counter(trainset_label_shuffled["label"]).keys())
                target_name_id_dict = dict(
                    zip(target_names, range(len(target_names)))
                )
                target_dict_list += [target_name_id_dict]

                labeled_trainset = trainset_label_shuffled.map(
                    classes_to_ids, num_proc=cpu_cores
                )

                # 80/20 train/eval split over the id-labeled dataset
                labeled_train_split = labeled_trainset.select(
                    range(0, round(len(labeled_trainset) * 0.8))
                )
                labeled_eval_split = labeled_trainset.select(
                    range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))
                )

                # Keep only eval examples whose label appears in training
                trained_labels = list(Counter(labeled_train_split["label"]).keys())
                labeled_eval_split_subset = labeled_eval_split.filter(
                    if_trained_label, num_proc=cpu_cores
                )

                dataset_list += [labeled_train_split]
                evalset_list += [labeled_eval_split_subset]

            trainset_dict = dict(zip(split_list, dataset_list))
            traintargetdict_dict = dict(zip(split_list, target_dict_list))
            evalset_dict = dict(zip(split_list, evalset_list))

            for lab in split_list:
                label_trainset = trainset_dict[lab]
                label_evalset = evalset_dict[lab]
                label_dict = traintargetdict_dict[lab]

                # Log roughly ten times per epoch
                logging_steps = max(
                    1, round(len(label_trainset) / training_batch_size / 10)
                )

                model = BertForSequenceClassification.from_pretrained(
                    pretrained_model,
                    num_labels=len(label_dict.keys()),
                    output_attentions=False,
                    output_hidden_states=False,
                ).to(device)

                current_date = datetime.datetime.now()
                datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"

                if output_dir is None:
                    output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

                saved_model_test = os.path.join(output_dir, "pytorch_model.bin")

                if os.path.isfile(saved_model_test) is True and overwrite is False:
                    raise Exception("Model already saved to this directory.")

                os.makedirs(output_dir, exist_ok=True)

                training_args = {
                    "learning_rate": max_lr,
                    "do_train": True,
                    "do_eval": True,
                    "evaluation_strategy": "epoch",
                    "save_strategy": "epoch",
                    "logging_steps": logging_steps,
                    "group_by_length": True,
                    "length_column_name": "length",
                    "disable_tqdm": False,
                    "lr_scheduler_type": lr_schedule_fn,
                    "warmup_steps": warmup_steps,
                    "weight_decay": 0.001,
                    "per_device_train_batch_size": training_batch_size,
                    "per_device_eval_batch_size": training_batch_size,
                    "num_train_epochs": epochs,
                    "load_best_model_at_end": True,
                    "output_dir": output_dir,
                }

                training_args_init = TrainingArguments(**training_args)
                true_labels = label_evalset["label"]

                if optimize_hyperparameters is False:
                    trainer = Trainer(
                        model=model,
                        args=training_args_init,
                        data_collator=DataCollatorForCellClassification(),
                        train_dataset=label_trainset,
                        eval_dataset=label_evalset,
                        compute_metrics=compute_metrics,
                    )

                    trainer.train()
                    predictions = trainer.predict(label_evalset)
                    predicted_labels = predictions.predictions.argmax(-1)
                    print(f"accuracy: {accuracy_score(true_labels, predicted_labels)}")

                    metrics = compute_metrics(predictions)
                    with open(
                        os.path.join(output_dir, "predictions.pickle"), "wb"
                    ) as fp:
                        pickle.dump(predictions, fp)

                    trainer.save_metrics("eval", predictions.metrics)
                    trainer.save_model(output_dir)

                    with open(os.path.join(output_dir, "targets.txt"), "w") as f:
                        if len(target_dict_list) == 1:
                            f.write(str(target_dict_list[0]))
                        else:
                            f.write(str(target_dict_list))

                    # ROC and precision-recall curves are only defined for
                    # binary classification; skip them for multiclass labels
                    try:
                        tpr, fpr, auc = ROC(predictions.predictions, true_labels)
                        precision, recall, _ = precision_recall_curve(
                            true_labels, predictions.predictions[:, 1]
                        )
                        pr_auc = precision_auc(recall, precision)
                        print(f"AUC: {pr_auc}")
                        return recall, precision, pr_auc
                    except ValueError:
                        pass
                else:

                    def model_init():
                        model = BertForSequenceClassification.from_pretrained(
                            pretrained_model,
                            num_labels=len(label_dict.keys()),
                            output_attentions=False,
                            output_hidden_states=False,
                        )
                        if freeze_layers is not None:
                            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
                            for module in modules_to_freeze:
                                for param in module.parameters():
                                    param.requires_grad = False
                        model = model.to(device)
                        return model

                    trainer = Trainer(
                        model_init=model_init,
                        args=training_args_init,
                        data_collator=DataCollatorForCellClassification(),
                        train_dataset=label_trainset,
                        eval_dataset=label_evalset,
                        compute_metrics=compute_metrics,
                    )

                    ray_config = {
                        "num_train_epochs": tune.choice([epochs]),
                        "learning_rate": tune.loguniform(1e-6, 1e-3),
                        "weight_decay": tune.uniform(0.0, 0.3),
                        "lr_scheduler_type": tune.choice(
                            ["linear", "cosine", "polynomial"]
                        ),
                        # warmup steps and seeds must be integers
                        "warmup_steps": tune.randint(100, 2000),
                        "seed": tune.randint(0, 100),
                        "per_device_train_batch_size": tune.choice(
                            [training_batch_size]
                        ),
                    }

                    hyperopt_search = HyperOptSearch(
                        metric="eval_accuracy", mode="max"
                    )

                    # Reserve one GPU per trial when CUDA is available
                    if torch.cuda.is_available():
                        resources_per_trial = {"cpu": 8, "gpu": 1}
                    else:
                        resources_per_trial = {"cpu": 8}

                    best_trial = trainer.hyperparameter_search(
                        direction="maximize",
                        backend="ray",
                        resources_per_trial=resources_per_trial,
                        hp_space=lambda _: ray_config,
                        search_alg=hyperopt_search,
                        n_trials=100,
                        progress_reporter=tune.CLIReporter(
                            max_report_frequency=600,
                            sort_by_metric=True,
                            max_progress_rows=100,
                            mode="max",
                            metric="eval_accuracy",
                            metric_columns=["loss", "eval_loss", "eval_accuracy"],
                        ),
                    )
                    best_hyperparameters = best_trial.hyperparameters

                    print("Best Hyperparameters:")
                    print(best_hyperparameters)
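
                    # Illustrative follow-up (an assumption, not part of the
                    # original flow): apply the winning configuration before
                    # a final training run, e.g.
                    #     for k, v in best_hyperparameters.items():
                    #         setattr(trainer.args, k, v)
                    #     trainer.train()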
        else:
            trainset_label = train_dataset
            label_counter = Counter(trainset_label[label])
            total_cells = sum(label_counter.values())

            # Drop rare cell states below the filter_cells fraction
            cells_to_keep = [
                k for k, v in label_counter.items() if v > (filter_cells * total_cells)
            ]
            trainset_label_subset = trainset_label.filter(
                if_not_rare_cell_state, num_proc=cpu_cores
            )

            trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
            trainset_label_shuffled = trainset_label_shuffled.rename_column(
                label, "label"
            )

            # Map class names to integer ids
            target_names = list(Counter(trainset_label_shuffled["label"]).keys())
            target_name_id_dict = dict(zip(target_names, range(len(target_names))))
            target_dict_list = target_name_id_dict

            labeled_trainset = trainset_label_shuffled.map(
                classes_to_ids, num_proc=cpu_cores
            )

            # 80/20 train/eval split over the id-labeled dataset
            labeled_train_split = labeled_trainset.select(
                range(0, round(len(labeled_trainset) * 0.8))
            )
            labeled_eval_split = labeled_trainset.select(
                range(round(len(labeled_trainset) * 0.8), len(labeled_trainset))
            )

            # Keep only eval examples whose label appears in training
            trained_labels = list(Counter(labeled_train_split["label"]).keys())
            labeled_eval_split_subset = labeled_eval_split.filter(
                if_trained_label, num_proc=cpu_cores
            )

            # Log roughly ten times per epoch
            logging_steps = max(
                1, round(len(trainset_label) / training_batch_size / 10)
            )

            model = BertForSequenceClassification.from_pretrained(
                pretrained_model,
                num_labels=len(target_dict_list.keys()),
                output_attentions=False,
                output_hidden_states=False,
            ).to(device)

            current_date = datetime.datetime.now()
            datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"

            if output_dir is None:
                output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"

            saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
            if os.path.isfile(saved_model_test) is True and overwrite is False:
                raise Exception("Model already saved to this directory.")

            os.makedirs(output_dir, exist_ok=True)

            training_args = {
                "learning_rate": max_lr,
                "do_train": True,
                "do_eval": True,
                "evaluation_strategy": "epoch",
                "save_strategy": "epoch",
                "logging_steps": logging_steps,
                "group_by_length": True,
                "length_column_name": "length",
                "disable_tqdm": False,
                "lr_scheduler_type": lr_schedule_fn,
                "warmup_steps": warmup_steps,
                "weight_decay": 0.001,
                "per_device_train_batch_size": training_batch_size,
                "per_device_eval_batch_size": training_batch_size,
                "num_train_epochs": epochs,
                "load_best_model_at_end": True,
                "output_dir": output_dir,
            }

            training_args_init = TrainingArguments(**training_args)
            true_labels = labeled_eval_split_subset["label"]

            if optimize_hyperparameters is False:
                trainer = Trainer(
                    model=model,
                    args=training_args_init,
                    data_collator=DataCollatorForCellClassification(),
                    train_dataset=labeled_train_split,
                    eval_dataset=labeled_eval_split_subset,
                    compute_metrics=compute_metrics,
                )

                trainer.train()
                predictions = trainer.predict(labeled_eval_split_subset)
                predictions_tensor = torch.Tensor(predictions.predictions)
                predicted_labels = torch.argmax(predictions_tensor, dim=1)
                print(f"accuracy: {accuracy_score(true_labels, predicted_labels)}")
                metrics = compute_metrics(predictions)

                with open(
                    os.path.join(output_dir, "predictions.pickle"), "wb"
                ) as fp:
                    pickle.dump(predictions.predictions.argmax(-1), fp)

                trainer.save_metrics("eval", predictions.metrics)
                trainer.save_model(output_dir)

                with open(os.path.join(output_dir, "targets.txt"), "w") as f:
                    f.write(str(target_dict_list))

                # The precision-recall curve is only defined for binary
                # classification; skip it for multiclass labels
                try:
                    precision, recall, _ = precision_recall_curve(
                        true_labels, predictions.predictions[:, 1]
                    )
                    pr_auc = precision_auc(recall, precision)
                    print(f"AUC: {pr_auc}")
                    return recall, precision, pr_auc
                except ValueError:
                    pass
            else:
                num_classes = len(list(set(labeled_train_split["label"])))

                def model_init():
                    model = BertForSequenceClassification.from_pretrained(
                        pretrained_model,
                        num_labels=num_classes,
                        output_attentions=False,
                        output_hidden_states=False,
                    )
                    if freeze_layers is not None:
                        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
                        for module in modules_to_freeze:
                            for param in module.parameters():
                                param.requires_grad = False
                    model = model.to(device)
                    return model

                trainer = Trainer(
                    model_init=model_init,
                    args=training_args_init,
                    data_collator=DataCollatorForCellClassification(),
                    train_dataset=labeled_train_split,
                    eval_dataset=labeled_eval_split_subset,
                    compute_metrics=compute_metrics,
                )

                ray_config = {
                    "num_train_epochs": tune.choice([epochs]),
                    "learning_rate": tune.loguniform(1e-6, 1e-3),
                    "weight_decay": tune.uniform(0.0, 0.3),
                    "lr_scheduler_type": tune.choice(
                        ["linear", "cosine", "polynomial"]
                    ),
                    # warmup steps and seeds must be integers
                    "warmup_steps": tune.randint(100, 2000),
                    "seed": tune.randint(0, 100),
                    "per_device_train_batch_size": tune.choice([training_batch_size]),
                }

                hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")

                # Reserve one GPU per trial when CUDA is available
                if torch.cuda.is_available():
                    resources_per_trial = {"cpu": 8, "gpu": 1}
                else:
                    resources_per_trial = {"cpu": 8}

                best_trial = trainer.hyperparameter_search(
                    direction="maximize",
                    backend="ray",
                    resources_per_trial=resources_per_trial,
                    hp_space=lambda _: ray_config,
                    search_alg=hyperopt_search,
                    n_trials=100,
                    progress_reporter=tune.CLIReporter(
                        max_report_frequency=600,
                        sort_by_metric=True,
                        max_progress_rows=100,
                        mode="max",
                        metric="eval_accuracy",
                        metric_columns=["loss", "eval_loss", "eval_accuracy"],
                    ),
                )
                best_hyperparameters = best_trial.hyperparameters

                print("Best Hyperparameters:")
                print(best_hyperparameters)

    if inference is True:
        if dataset_split is not None and data_filter is not None:

            def if_label(example):
                return example[dataset_split] == data_filter

            train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)

        trainset_label_shuffled = train_dataset
        total_cells = len(trainset_label_shuffled)

        # Recover the class-name -> id mapping saved during training. When
        # the model was trained on multiple splits, targets.txt holds one
        # dict per split, in the order recorded in split_list.
        with open(Path(finetuned_model) / "targets.txt", "r") as f:
            data = ast.literal_eval(f.read())
        if isinstance(data, list) and data_filter is not None:
            indexer = split_list.index(data_filter) if data_filter in split_list else 0
            data = data[indexer]

        # Invert the mapping so predicted ids can be decoded to class names
        target_dict_list = {key: value for key, value in enumerate(data)}

        # Pad (or truncate) each tokenized cell to max_input_size and build
        # the matching attention masks
        input_ids = trainset_label_shuffled["input_ids"]
        inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
        attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)

        for i, sentence in enumerate(input_ids):
            sentence_length = min(len(sentence), max_input_size)
            inputs[i, :sentence_length] = torch.tensor(sentence[:sentence_length])
            attention[i, :sentence_length] = 1

        model = BertForSequenceClassification.from_pretrained(
            finetuned_model, num_labels=len(target_dict_list)
        ).to(device)

        # Run inference in batches to avoid exhausting device memory
        model.eval()
        predicted_ids = []
        with torch.no_grad():
            for start in range(0, len(inputs), inference_batch_size):
                logits = model(
                    inputs[start : start + inference_batch_size].to(device),
                    attention_mask=attention[start : start + inference_batch_size].to(
                        device
                    ),
                )["logits"]
                predicted_ids.extend(F.softmax(logits, dim=-1).argmax(-1).tolist())

        predictions = [target_dict_list[int(pred)] for pred in predicted_ids]

        return predictions

    if emb_extract is True:
        if emb_filter is None:
            # Recover the saved class-name -> id mapping (one dict per
            # split when the model was trained with dataset_split)
            with open(f"{finetuned_model}/targets.txt", "r") as f:
                data = ast.literal_eval(f.read())
            if isinstance(data, list) and data_filter is not None:
                indexer = (
                    split_list.index(data_filter) if data_filter in split_list else 0
                )
                data = data[indexer]

            target_dict_list = {key: value for key, value in enumerate(data)}
            total_filter = None
        else:
            total_filter = emb_filter

        train_dataset = load_from_disk(dataset)
        if dataset_split is not None:

            def if_label(example):
                return example[dataset_split] == data_filter

            train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)

        # Drop rare cell states below the filter_cells fraction
        label_counter = Counter(train_dataset[label])
        total_cells = sum(label_counter.values())
        cells_to_keep = [
            k for k, v in label_counter.items() if v > (filter_cells * total_cells)
        ]

        def if_not_rare(example):
            return example[label] in cells_to_keep

        train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)

        true_labels = train_dataset[label]
        num_classes = len(list(set(true_labels)))

        embex = EmbExtractor(
            model_type="CellClassifier",
            num_classes=num_classes,
            filter_data=total_filter,
            max_ncells=max_cells,
            emb_layer=emb_layer,
            emb_label=[dataset_split, label] if dataset_split is not None else [label],
            labels_to_plot=[label],
            forward_batch_size=inference_batch_size,
            nproc=cpu_cores,
        )

        os.makedirs(emb_dir, exist_ok=True)

        embs = embex.extract_embs(
            model_directory=finetuned_model,
            input_data_file=dataset,
            output_directory=emb_dir,
            output_prefix=f"{label}_embeddings",
        )
        true_labels = embex.filtered_input_data[label]

        # Average the embeddings of each class into a single vector
        emb_dict = {lab: [] for lab in set(true_labels)}
        for num, emb in embs.iterrows():
            key = emb[label]
            # Keep only the embedding dimensions; the trailing columns of
            # each row hold the appended label metadata
            selection = emb.iloc[:255]
            emb = torch.Tensor(selection)
            emb_dict[key].append(emb)

        for key in list(emb_dict.keys()):
            stack = torch.stack(emb_dict[key], dim=0)
            emb_dict[key] = torch.mean(stack, dim=0)
        similarities = {key: {} for key in emb_dict.keys()}

        # Pairwise cosine similarity between class-average embeddings
        for key in list(emb_dict.keys()):
            remaining_keys = [k for k in emb_dict.keys() if k != key]
            for k in remaining_keys:
                embedding = emb_dict[k]
                sim = similarity(emb_dict[key], embedding, cosine=True)
                similarities[key][k] = sim

        plot_similarity_heatmap(similarities)

        embex.plot_embs(
            embs=embs,
            plot_style="umap",
            output_directory=emb_dir,
            output_prefix="emb_plot",
        )

        embex.plot_embs(
            embs=embs,
            plot_style="heatmap",
            output_directory=emb_dir,
            output_prefix="emb_plot",
        )

        return similarities
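

if __name__ == "__main__":
    # Minimal sketch of a training run (illustrative assumptions: the
    # dataset path must point at a tokenized .dataset directory, and
    # pretrained_model at a pretrained Geneformer checkpoint; the output
    # directory name is hypothetical).
    classify_cells(
        dataset="cell_type_train_data.dataset/",
        pretrained_model=".",
        epochs=1,
        optimize_hyperparameters=False,
        output_dir="cell_classifier_test/",
    )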
|
|