"""
Geneformer cell classifier.
Usage:
from geneformer import classify_cells
classify_cells(
token_set=Path("geneformer/token_dictionary.pkl"),
median_set=Path("geneformer/gene_median_dictionary.pkl"),
pretrained_model=".",
dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
dataset_split=None,
filter_cells=0.005,
epochs=1,
cpu_cores=os.cpu_count(),
geneformer_batch_size=12,
optimizer="adamw",
max_lr=5e-5,
num_gpus=torch.cuda.device_count(),
max_input_size=2**11,
lr_schedule_fn="linear",
warmup_steps=500,
freeze_layers=0,
emb_extract=False,
max_cells=1000,
emb_layer=0,
emb_filter=None,
emb_dir="embeddings",
overwrite=True,
label="cell_type",
data_filter=None,
forward_batch=200,
model_location=None,
skip_training=False,
sample_data=1,
inference=False,
optimize_hyperparameters=False,
output_dir=None,
)
"""
import ast
import datetime
import os
import pickle
import random
import subprocess
from collections import Counter
from pathlib import Path
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from datasets import load_from_disk
from matplotlib import pyplot as plt
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch
from sklearn.metrics import accuracy_score
from sklearn.metrics import auc as precision_auc
from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score, roc_curve
from transformers import BertForSequenceClassification, Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification, EmbExtractor
sns.set()
# Make all local GPUs visible and enable NCCL debug logging for multi-GPU training
GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Generates and saves an ROC curve; assumes a binary task with class-1 scores in prediction[:, 1]
def ROC(prediction, truth, type="GeneFormer", label=""):
fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
auc = roc_auc_score(truth, prediction[:, 1])
print(f"{type} AUC: {auc}")
plt.plot(fpr, tpr, label="AUC=" + str(auc))
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
plt.title(f"{label} ROC Curve")
plt.legend(loc=4)
plt.savefig("ROC.png")
return tpr, fpr, auc
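# Example (hypothetical arrays), assuming a binary task where `logits` has shape (N, 2)
# and `y_true` holds the matching 0/1 labels for the same N cells:
#   tpr, fpr, auc = ROC(logits, y_true, type="Geneformer", label="cell_type")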
# Computes the cosine similarity between two embeddings.
# With cosine=False the score is rescaled to [0, 1] (0 = dissimilar, 1 = identical);
# with cosine=True the raw cosine similarity in [-1, 1] is returned.
def similarity(tensor1, tensor2, cosine=False):
    if cosine is False:
        # flatten to 1-D so the dot product is well defined
        if tensor1.ndimension() > 1:
            tensor1 = tensor1.view(-1)
        if tensor2.ndimension() > 1:
            tensor2 = tensor2.view(-1)
        dot_product = torch.dot(tensor1, tensor2)
        norm_tensor1 = torch.norm(tensor1)
        norm_tensor2 = torch.norm(tensor2)
        # avoid division by zero by adding a small epsilon
        epsilon = 1e-8
        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
        # rescale from [-1, 1] to [0, 1]
        similarity = (similarity + 1) / 2
    else:
        if tensor1.shape != tensor2.shape:
            raise ValueError("Input tensors must have the same shape.")
        # compute cosine similarity using the dot product of the 1-D embeddings
        dot_product = torch.dot(tensor1, tensor2)
        norm_tensor1 = torch.norm(tensor1)
        norm_tensor2 = torch.norm(tensor2)
        # avoid division by zero by adding a small epsilon
        epsilon = 1e-8
        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
    return similarity.item()
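# Example (hypothetical tensors): compare two 1-D cell embeddings on the raw cosine scale
#   sim = similarity(torch.randn(256), torch.randn(256), cosine=True)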
# Plots heatmap between different classes/labels
def plot_similarity_heatmap(similarities):
classes = list(similarities.keys())
classlen = len(classes)
arr = np.zeros((classlen, classlen))
for i, c in enumerate(classes):
for j, cc in enumerate(classes):
if cc == c:
val = 1.0
else:
val = similarities[c][cc]
arr[i][j] = val
plt.figure(figsize=(8, 6))
plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
plt.colorbar()
plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
plt.yticks(np.arange(classlen), classes)
plt.title("Similarity Heatmap")
plt.savefig("similarity_heatmap.png")
def classify_cells(
token_set=Path("./token_dictionary.pkl"),
median_set=Path("./gene_median_dictionary.pkl"),
pretrained_model="../",
dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
dataset_split=None,
filter_cells=0.005,
epochs=1,
cpu_cores=os.cpu_count(),
training_batch_size=12,
optimizer="adamw",
max_lr=5e-5,
num_gpus=torch.cuda.device_count(),
max_input_size=2**11,
lr_schedule_fn="linear",
warmup_steps=500,
freeze_layers=0,
emb_extract=False,
max_cells=None,
emb_layer=-1,
emb_filter=None,
emb_dir="embeddings",
overwrite=False,
label="cell_type",
data_filter=None,
inference_batch_size=200,
finetuned_model=None,
skip_training=False,
sample_data=1,
inference=False,
optimize_hyperparameters=True,
output_dir=None,
):
"""
Primary Parameters
-------------------
dataset: path
Path to fine-tuning dataset for training
finetuned_model: path
Path to location of fine-tuned model to use for inference and embedding extraction
pretrained_model: path
Path to pretrained Geneformer model
inference: bool
Indicates whether to perform inference and return a list of similarities. Defaults to False.
skip_training: bool
Indicates whether to skip training the model. Defaults to False.
    emb_extract: bool
        Indicates whether to extract embeddings and calculate similarities. Defaults to False.
    optimize_hyperparameters: bool
        Indicates whether to optimize model hyperparameters with Ray Tune. Defaults to True.
Customization Parameters
-------------------
dataset_split: str
Indicates how the dataset should be partitioned (if at all), and what ID should be used for partitioning
data_filter: list
(For embeddings and inference) Runs analysis on subsets of the dataset based on the ID defined by dataset_split
    label: str
        Column in the dataset to read as the classification label.
    emb_layer: int
        Which model layer embeddings are extracted from and compared.
    emb_filter: ['cell1', 'cell2'...]
        Restricts embedding extraction to the listed cells.
    max_cells: int
        Maximum number of cells to use for embedding extraction.
    freeze_layers: int
        Number of encoder layers frozen during fine-tuning.
    sample_data: float
        Proportion of the dataset to use.
"""
dataset_list = []
evalset_list = []
split_list = []
target_dict_list = []
    # load the dataset and subsample it to the fraction given by sample_data
    train_dataset = load_from_disk(dataset)
    num_samples = int(len(train_dataset) * sample_data)
    random_indices = random.sample(range(len(train_dataset)), num_samples)
    train_dataset = train_dataset.select(random_indices)
def if_not_rare_cell_state(example):
return example[label] in cells_to_keep
# change labels to numerical ids
def classes_to_ids(example):
example["label"] = target_name_id_dict[example["label"]]
return example
def if_trained_label(example):
return example["label"] in trained_labels
if skip_training is not True:
def compute_metrics(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
# calculate accuracy and macro f1 using sklearn's function
acc = accuracy_score(labels, preds)
macro_f1 = f1_score(labels, preds, average="macro")
return {"accuracy": acc, "macro_f1": macro_f1}
        # tissue labels that are merged into another split rather than fine-tuned separately
        # (bone marrow cells are grouped with the immune split)
        excep = {"bone_marrow": "immune"}
if dataset_split is not None:
if data_filter is not None:
split_iter = [data_filter]
else:
split_iter = Counter(train_dataset[dataset_split]).keys()
            for lab in split_iter:
                # collect list of tissues for fine-tuning (immune and bone marrow are trained together)
                if lab in excep.keys():
                    continue
                elif lab in excep.values():
                    # include the merged tissue(s) together with this label in a single split
                    split_ids = [k for k, v in excep.items() if v == lab] + [lab]
                    split_list += [lab]
                else:
                    split_ids = [lab]
                    split_list += [lab]

                # filter dataset for the given organ/split
                def if_label(example):
                    return example[dataset_split] in split_ids
trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
label_counter = Counter(trainset_label[label])
total_cells = sum(label_counter.values())
# excludes cells with a low proportion in the dataset
cells_to_keep = [
k
for k, v in label_counter.items()
if v > (filter_cells * total_cells)
]
trainset_label_subset = trainset_label.filter(
if_not_rare_cell_state, num_proc=cpu_cores
)
# shuffle datasets and rename columns
trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
trainset_label_shuffled = trainset_label_shuffled.rename_column(
label, "label"
)
trainset_label_shuffled = trainset_label_shuffled.remove_columns(
dataset_split
)
# create dictionary of cell types : label ids
target_names = list(Counter(trainset_label_shuffled["label"]).keys())
target_name_id_dict = dict(
zip(target_names, [i for i in range(len(target_names))])
)
target_dict_list += [target_name_id_dict]
labeled_trainset = trainset_label_shuffled.map(
classes_to_ids, num_proc=cpu_cores
)
                # create 80/20 train/eval splits from the label-encoded dataset
                labeled_train_split = labeled_trainset.select(
                    [i for i in range(0, round(len(labeled_trainset) * 0.8))]
                )
                labeled_eval_split = labeled_trainset.select(
                    [
                        i
                        for i in range(
                            round(len(labeled_trainset) * 0.8), len(labeled_trainset)
                        )
                    ]
                )
# filter dataset for cell types in corresponding training set
trained_labels = list(Counter(labeled_train_split["label"]).keys())
labeled_eval_split_subset = labeled_eval_split.filter(
if_trained_label, num_proc=cpu_cores
)
dataset_list += [labeled_train_split]
evalset_list += [labeled_eval_split_subset]
trainset_dict = dict(zip(split_list, dataset_list))
traintargetdict_dict = dict(zip(split_list, target_dict_list))
evalset_dict = dict(zip(split_list, evalset_list))
for lab in split_list:
label_trainset = trainset_dict[lab]
label_evalset = evalset_dict[lab]
label_dict = traintargetdict_dict[lab]
# set logging steps
logging_steps = round(len(label_trainset) / training_batch_size / 10)
if logging_steps == 0:
logging_steps = 1
# load pretrained model
model = BertForSequenceClassification.from_pretrained(
pretrained_model,
num_labels=len(label_dict.keys()),
output_attentions=False,
output_hidden_states=False,
).to(device)
# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
if output_dir is None:
output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
# ensure not overwriting previously saved model
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
if os.path.isfile(saved_model_test) is True and overwrite is False:
raise Exception("Model already saved to this directory.")
# make output directory
subprocess.call(f"mkdir -p {output_dir}", shell=True)
# set training arguments
training_args = {
"learning_rate": max_lr,
"do_train": True,
"do_eval": True,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"logging_steps": logging_steps,
"group_by_length": True,
"length_column_name": "length",
"disable_tqdm": False,
"lr_scheduler_type": lr_schedule_fn,
"warmup_steps": warmup_steps,
"weight_decay": 0.001,
"per_device_train_batch_size": training_batch_size,
"per_device_eval_batch_size": training_batch_size,
"num_train_epochs": epochs,
"load_best_model_at_end": True,
"output_dir": output_dir,
}
training_args_init = TrainingArguments(**training_args)
true_labels = label_evalset["label"]
if optimize_hyperparameters is False:
# create the trainer
trainer = Trainer(
model=model,
args=training_args_init,
data_collator=DataCollatorForCellClassification(),
train_dataset=label_trainset,
eval_dataset=label_evalset,
compute_metrics=compute_metrics,
)
# train the cell type classifier
trainer.train()
                    predictions = trainer.predict(label_evalset)
                    predicted_labels = predictions.predictions.argmax(-1)
                    print(
                        f'accuracy: {accuracy_score(label_evalset["label"], predicted_labels)}'
                    )
                    metrics = compute_metrics(predictions)
                    with open(os.path.join(output_dir, "predictions.pickle"), "wb") as fp:
                        pickle.dump(predictions, fp)
                    trainer.save_metrics("eval", predictions.metrics)
                    trainer.save_model(output_dir)
                    # save the label conversion dictionary to the output directory
                    with open(os.path.join(output_dir, "targets.txt"), "w") as f:
                        if len(target_dict_list) == 1:
                            f.write(str(target_dict_list[0]))
                        else:
                            f.write(str(target_dict_list))
                    # ROC and precision-recall curves are only defined for binary classification
                    if len(label_dict) == 2:
                        tpr, fpr, auc = ROC(predictions.predictions, true_labels)
                        try:
                            precision, recall, _ = precision_recall_curve(
                                true_labels, predictions.predictions[:, 1]
                            )
                            pr_auc = precision_auc(recall, precision)
                            print(f"AUC: {pr_auc}")
                            return recall, precision, pr_auc
                        except Exception:
                            pass
else:
def model_init():
model = BertForSequenceClassification.from_pretrained(
pretrained_model,
num_labels=len(label_dict.keys()),
output_attentions=False,
output_hidden_states=False,
)
if freeze_layers is not None:
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
for module in modules_to_freeze:
for param in module.parameters():
param.requires_grad = False
model = model.to(device)
return model
trainer = Trainer(
model_init=model_init,
args=training_args_init,
data_collator=DataCollatorForCellClassification(),
train_dataset=label_trainset,
eval_dataset=label_evalset,
compute_metrics=compute_metrics,
)
# specify raytune hyperparameter search space
ray_config = {
"num_train_epochs": tune.choice([epochs]),
"learning_rate": tune.loguniform(1e-6, 1e-3),
"weight_decay": tune.uniform(0.0, 0.3),
"lr_scheduler_type": tune.choice(
["linear", "cosine", "polynomial"]
),
"warmup_steps": tune.uniform(100, 2000),
"seed": tune.uniform(0, 100),
"per_device_train_batch_size": tune.choice(
[training_batch_size]
),
}
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
                    if torch.cuda.is_available():
                        resources_per_trial = {"cpu": 8, "gpu": 1}
                    else:
                        resources_per_trial = {"cpu": 8}
# optimize hyperparameters
best_trial = trainer.hyperparameter_search(
direction="maximize",
backend="ray",
resources_per_trial=resources_per_trial,
hp_space=lambda _: ray_config,
search_alg=hyperopt_search,
n_trials=100, # number of trials
progress_reporter=tune.CLIReporter(
max_report_frequency=600,
sort_by_metric=True,
max_progress_rows=100,
mode="max",
metric="eval_accuracy",
metric_columns=["loss", "eval_loss", "eval_accuracy"],
),
)
best_hyperparameters = best_trial.hyperparameters
print("Best Hyperparameters:")
print(best_hyperparameters)
else:
trainset_label = train_dataset
label_counter = Counter(trainset_label[label])
total_cells = sum(label_counter.values())
# Excludes cells with a low proportion in the dataset
cells_to_keep = [
k for k, v in label_counter.items() if v > (filter_cells * total_cells)
]
trainset_label_subset = trainset_label.filter(
if_not_rare_cell_state, num_proc=cpu_cores
)
# shuffle datasets and rename columns
trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
trainset_label_shuffled = trainset_label_shuffled.rename_column(
label, "label"
)
# create dictionary of cell types : label ids
target_names = list(Counter(trainset_label_shuffled["label"]).keys())
target_name_id_dict = dict(
zip(target_names, [i for i in range(len(target_names))])
)
target_dict_list = target_name_id_dict
labeled_trainset = trainset_label_shuffled.map(
classes_to_ids, num_proc=cpu_cores
)
# create 80/20 train/eval splits
labeled_train_split = labeled_trainset.select(
[i for i in range(0, round(len(labeled_trainset) * 0.8))]
)
labeled_eval_split = labeled_trainset.select(
[
i
for i in range(
round(len(labeled_trainset) * 0.8), len(labeled_trainset)
)
]
)
# filter dataset for cell types in corresponding training set
trained_labels = list(Counter(labeled_train_split["label"]).keys())
labeled_eval_split_subset = labeled_eval_split.filter(
if_trained_label, num_proc=cpu_cores
)
            # set logging steps (at least 1 so small datasets still log)
            logging_steps = round(len(trainset_label) / training_batch_size / 10)
            if logging_steps == 0:
                logging_steps = 1
# load pretrained model
model = BertForSequenceClassification.from_pretrained(
pretrained_model,
num_labels=len(target_dict_list.keys()),
output_attentions=False,
output_hidden_states=False,
).to(device)
# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
if output_dir is None:
output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
# ensure not overwriting previously saved model
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
if os.path.isfile(saved_model_test) is True and overwrite is False:
raise Exception("Model already saved to this directory.")
# make output directory
subprocess.call(f"mkdir -p {output_dir}", shell=True)
# set training arguments
training_args = {
"learning_rate": max_lr,
"do_train": True,
"do_eval": True,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"logging_steps": logging_steps,
"group_by_length": True,
"length_column_name": "length",
"disable_tqdm": False,
"lr_scheduler_type": lr_schedule_fn,
"warmup_steps": warmup_steps,
"weight_decay": 0.001,
"per_device_train_batch_size": training_batch_size,
"per_device_eval_batch_size": training_batch_size,
"num_train_epochs": epochs,
"load_best_model_at_end": True,
"output_dir": output_dir,
}
training_args_init = TrainingArguments(**training_args)
true_labels = labeled_eval_split_subset["label"]
if optimize_hyperparameters is False:
# create the trainer
trainer = Trainer(
model=model,
args=training_args_init,
data_collator=DataCollatorForCellClassification(),
train_dataset=labeled_train_split,
eval_dataset=labeled_eval_split_subset,
compute_metrics=compute_metrics,
)
# train the cell type classifier
trainer.train()
predictions = trainer.predict(labeled_eval_split_subset)
predictions_tensor = torch.Tensor(predictions.predictions)
predicted_labels = torch.argmax(predictions_tensor, dim=1)
print(
f'accuracy: {accuracy_score(predicted_labels, labeled_eval_split_subset["label"])}'
)
metrics = compute_metrics(predictions)
with open(f"{output_dir}predictions.pickle", "wb") as fp:
pickle.dump(predictions.predictions.argmax(-1), fp)
trainer.save_metrics("eval", predictions.metrics)
trainer.save_model(output_dir)
# Saves label conversion dictionary to output directory
with open(f"{output_dir}/targets.txt", "w") as f:
f.write(str(target_dict_list))
try:
precision, recall, _ = precision_recall_curve(
true_labels, predictions.predictions[:, 1]
)
pr_auc = precision_auc(recall, precision)
print(f"AUC: {pr_auc}")
return recall, precision, pr_auc
                except Exception:
                    pass
else:
# Optimizes hyperparameters
num_classes = len(list(set(labeled_train_split["label"])))
def model_init():
model = BertForSequenceClassification.from_pretrained(
pretrained_model,
num_labels=num_classes,
output_attentions=False,
output_hidden_states=False,
)
if freeze_layers is not None:
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
for module in modules_to_freeze:
for param in module.parameters():
param.requires_grad = False
model = model.to(device)
return model
# create the trainer
trainer = Trainer(
model_init=model_init,
args=training_args_init,
data_collator=DataCollatorForCellClassification(),
train_dataset=labeled_train_split,
eval_dataset=labeled_eval_split_subset,
compute_metrics=compute_metrics,
)
# specify raytune hyperparameter search space
ray_config = {
"num_train_epochs": tune.choice([epochs]),
"learning_rate": tune.loguniform(1e-6, 1e-3),
"weight_decay": tune.uniform(0.0, 0.3),
"lr_scheduler_type": tune.choice(
["linear", "cosine", "polynomial"]
),
"warmup_steps": tune.uniform(100, 2000),
"seed": tune.uniform(0, 100),
"per_device_train_batch_size": tune.choice([training_batch_size]),
}
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
                if torch.cuda.is_available():
                    resources_per_trial = {"cpu": 8, "gpu": 1}
                else:
                    resources_per_trial = {"cpu": 8}
# optimize hyperparameters
best_trial = trainer.hyperparameter_search(
direction="maximize",
backend="ray",
resources_per_trial=resources_per_trial,
hp_space=lambda _: ray_config,
search_alg=hyperopt_search,
n_trials=100, # number of trials
progress_reporter=tune.CLIReporter(
max_report_frequency=600,
sort_by_metric=True,
max_progress_rows=100,
mode="max",
metric="eval_accuracy",
metric_columns=["loss", "eval_loss", "eval_accuracy"],
),
)
best_hyperparameters = best_trial.hyperparameters
print("Best Hyperparameters:")
print(best_hyperparameters)
# Performs Inference with model
if inference is True:
if dataset_split is not None and data_filter is not None:
def if_label(example):
return example[dataset_split] == data_filter
train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
trainset_label_shuffled = train_dataset
total_cells = len(trainset_label_shuffled)
# loads dictionary of all cell labels model was trained on
with open(Path(finetuned_model) / "targets.txt", "r") as f:
data = ast.literal_eval(f.read())
if dataset_split is not None and data_filter is None:
indexer = dataset_split.index(data_filter)
data = data[indexer]
        # targets.txt stores a {cell_type: id} dictionary; invert it to map predicted ids back to names
        target_dict_list = {value: key for key, value in data.items()}
# set logging steps
logging_steps = round(len(trainset_label_shuffled) / training_batch_size / 20)
        # pad/truncate each tokenized cell to max_input_size and build the attention mask
        input_ids = trainset_label_shuffled["input_ids"]
        inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
        attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
        for i, sentence in enumerate(input_ids):
            sentence_length = len(sentence)
            if sentence_length <= max_input_size:
                inputs[i, :sentence_length] = torch.tensor(sentence)
                attention[i, :sentence_length] = torch.ones(sentence_length)
            else:
                inputs[i, :] = torch.tensor(sentence[:max_input_size])
                attention[i, :] = torch.ones(max_input_size)
        # load the fine-tuned model and classify cells in batches of inference_batch_size
        model = BertForSequenceClassification.from_pretrained(
            finetuned_model, num_labels=len(target_dict_list)
        ).to(device)
        model.eval()
        predictions = []
        with torch.no_grad():
            for start in range(0, len(inputs), inference_batch_size):
                batch_inputs = inputs[start : start + inference_batch_size].to(device)
                batch_attention = attention[start : start + inference_batch_size].to(device)
                logits = model(batch_inputs, attention_mask=batch_attention)["logits"]
                predictions.extend(F.softmax(logits, dim=-1).argmax(-1).tolist())
        # convert numeric predictions back to cell-type names
        predictions = [target_dict_list[int(pred)] for pred in predictions]
        return predictions
# Extracts embeddings from labeled data
if emb_extract is True:
if emb_filter is None:
with open(f"{finetuned_model}/targets.txt", "r") as f:
data = ast.literal_eval(f.read())
if dataset_split is not None and data_filter is None:
indexer = dataset_split.index(data_filter)
data = data[indexer]
            # targets.txt stores a {cell_type: id} dictionary; invert it to map ids back to names
            target_dict_list = {value: key for key, value in data.items()}
total_filter = None
else:
total_filter = emb_filter
train_dataset = load_from_disk(dataset)
if dataset_split is not None:
def if_label(example):
return example[dataset_split] == data_filter
train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
label_counter = Counter(train_dataset[label])
total_cells = sum(label_counter.values())
cells_to_keep = [
k for k, v in label_counter.items() if v > (filter_cells * total_cells)
]
def if_not_rare(example):
return example[label] in cells_to_keep
train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)
true_labels = train_dataset[label]
num_classes = len(list(set(true_labels)))
embex = EmbExtractor(
model_type="CellClassifier",
num_classes=num_classes,
filter_data=total_filter,
max_ncells=max_cells,
emb_layer=emb_layer,
emb_label=[dataset_split, label],
labels_to_plot=[label],
forward_batch_size=inference_batch_size,
nproc=cpu_cores,
)
# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
subprocess.call(f"mkdir -p {emb_dir}", shell=True)
embs = embex.extract_embs(
model_directory=finetuned_model,
input_data_file=dataset,
output_directory=emb_dir,
output_prefix=f"{label}_embeddings",
)
true_labels = embex.filtered_input_data[label]
emb_dict = {label: [] for label in list(set(true_labels))}
for num, emb in embs.iterrows():
key = emb[label]
selection = emb.iloc[:255]
emb = torch.Tensor(selection)
emb_dict[key].append(emb)
for key in list(emb_dict.keys()):
stack = torch.stack(emb_dict[key], dim=0)
emb_dict[key] = torch.mean(stack, dim=0)
similarities = {key: {} for key in list(emb_dict.keys())}
for key in list(emb_dict.keys()):
remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
for k in remaining_keys:
embedding = emb_dict[k]
sim = similarity(emb_dict[key], embedding, cosine=True)
similarities[key][k] = sim
plot_similarity_heatmap(similarities)
embex.plot_embs(
embs=embs,
plot_style="umap",
output_directory=emb_dir,
output_prefix="emb_plot",
)
embex.plot_embs(
embs=embs,
plot_style="heatmap",
output_directory=emb_dir,
output_prefix="emb_plot",
)
return similarities
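# Minimal, hypothetical invocation sketch mirroring the module docstring above; the dataset path,
# pretrained model location, and output directory are placeholders to be replaced with real paths.
if __name__ == "__main__":
    classify_cells(
        dataset="cell_type_train_data.dataset/",
        pretrained_model=".",
        epochs=1,
        training_batch_size=12,
        cpu_cores=os.cpu_count(),
        optimize_hyperparameters=False,
        output_dir="cell_classifier_output/",
    )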