|
import os |
|
import sys |
|
|
|
GPU_NUMBER = [0] |
|
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER]) |
|
os.environ["NCCL_DEBUG"] = "INFO" |
|
|
|
import ast |
|
import datetime |
|
import math |
|
import pickle |
|
import subprocess |
|
from pathlib import Path |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torch.nn.functional as F |
|
from datasets import Dataset, load_from_disk |
|
from sklearn import preprocessing |
|
from sklearn.metrics import ( |
|
ConfusionMatrixDisplay, |
|
accuracy_score, |
|
auc, |
|
confusion_matrix, |
|
roc_auc_score, |
|
roc_curve, |
|
) |
|
|
|
|
|
from sklearn.model_selection import StratifiedKFold, train_test_split |
|
from tqdm.notebook import tqdm |
|
from transformers import BertForTokenClassification, Trainer |
|
from transformers.training_args import TrainingArguments |
|
|
|
from geneformer import DataCollatorForGeneClassification, EmbExtractor |
|
from geneformer.pretrainer import token_dictionary |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
from geneformer import TranscriptomeTokenizer |
|
|
|
|
|
def vote(logit_pair): |
|
a, b = logit_pair |
|
if a > b: |
|
return 0 |
|
elif b > a: |
|
return 1 |
|
elif a == b: |
|
return "tie" |
|
|
|
|
|
def py_softmax(vector): |
|
e = np.exp(vector) |
|
return e / e.sum() |
|
|
|
|
|
|
|
|
|
def similarity(tensor1, tensor2, cosine=True): |
|
if cosine == False: |
|
if tensor1.ndimension() > 1: |
|
tensor1 = tensor1.view(1, -1) |
|
if tensor2.ndimension() > 1: |
|
tensor2 = tensor2.view(1, -1) |
|
dot_product = torch.matmul(tensor1, tensor2) |
|
norm_tensor1 = torch.norm(tensor1) |
|
norm_tensor2 = torch.norm(tensor2) |
|
epsilon = 1e-8 |
|
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon) |
|
similarity = (similarity.item() + 1) / 2 |
|
else: |
|
if tensor1.shape != tensor2.shape: |
|
raise ValueError("Input tensors must have the same shape.") |
|
|
|
|
|
dot_product = torch.dot(tensor1, tensor2) |
|
norm_tensor1 = torch.norm(tensor1) |
|
norm_tensor2 = torch.norm(tensor2) |
|
|
|
|
|
epsilon = 1e-8 |
|
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon) |
|
|
|
return similarity.item() |
|
|
|
|
|
|
|
def plot_similarity_heatmap(similarities): |
|
classes = list(similarities.keys()) |
|
classlen = len(classes) |
|
arr = np.zeros((classlen, classlen)) |
|
for i, c in enumerate(classes): |
|
for j, cc in enumerate(classes): |
|
if cc == c: |
|
val = 1.0 |
|
else: |
|
val = similarities[c][cc] |
|
arr[i][j] = val |
|
|
|
plt.figure(figsize=(8, 6)) |
|
plt.imshow(arr, cmap="inferno", vmin=0, vmax=1) |
|
plt.colorbar() |
|
plt.xticks(np.arange(classlen), classes, rotation=45, ha="right") |
|
plt.yticks(np.arange(classlen), classes) |
|
plt.title("Similarity Heatmap") |
|
plt.savefig("similarity_heatmap.png") |
|
|
|
|
|
|
|
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt): |
|
wts = [count / sum(all_tpr_wt) for count in all_tpr_wt] |
|
|
|
all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)] |
|
mean_tpr = np.sum(all_weighted_tpr, axis=0) |
|
mean_tpr[-1] = 1.0 |
|
all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)] |
|
roc_auc = np.sum(all_weighted_roc_auc) |
|
roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts)) |
|
return mean_tpr, roc_auc, roc_auc_sd |
|
|
|
|
|
def validate( |
|
data, |
|
targets, |
|
labels, |
|
nsplits, |
|
subsample_size, |
|
training_args, |
|
freeze_layers, |
|
output_dir, |
|
num_proc, |
|
num_labels, |
|
pre_model, |
|
): |
|
|
|
num_classes = len(set(labels)) |
|
mean_fpr = np.linspace(0, 1, 100) |
|
|
|
|
|
targets_train, targets_eval, labels_train, labels_eval = train_test_split( |
|
targets, labels, test_size=0.25, shuffle=True |
|
) |
|
label_dict_train = dict(zip(targets_train, labels_train)) |
|
label_dict_eval = dict(zip(targets_eval, labels_eval)) |
|
|
|
|
|
def if_contains_train_label(example): |
|
a = label_dict_train.keys() |
|
b = example["input_ids"] |
|
return not set(a).isdisjoint(b) |
|
|
|
def if_contains_eval_label(example): |
|
a = label_dict_eval.keys() |
|
b = example["input_ids"] |
|
return not set(a).isdisjoint(b) |
|
|
|
|
|
print(f"Filtering training data") |
|
trainset = data.filter(if_contains_train_label, num_proc=num_proc) |
|
print( |
|
f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n" |
|
) |
|
print(f"Filtering evalation data") |
|
evalset = data.filter(if_contains_eval_label, num_proc=num_proc) |
|
print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n") |
|
|
|
|
|
training_size = min(subsample_size, len(trainset)) |
|
trainset_min = trainset.select([i for i in range(training_size)]) |
|
eval_size = min(training_size, len(evalset)) |
|
half_training_size = round(eval_size / 2) |
|
evalset_train_min = evalset.select([i for i in range(half_training_size)]) |
|
evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)]) |
|
|
|
|
|
def generate_train_labels(example): |
|
example["labels"] = [ |
|
label_dict_train.get(token_id, -100) for token_id in example["input_ids"] |
|
] |
|
return example |
|
|
|
def generate_eval_labels(example): |
|
example["labels"] = [ |
|
label_dict_eval.get(token_id, -100) for token_id in example["input_ids"] |
|
] |
|
return example |
|
|
|
|
|
print(f"Labeling training data") |
|
trainset_labeled = trainset_min.map(generate_train_labels) |
|
print(f"Labeling evaluation data") |
|
evalset_train_labeled = evalset_train_min.map(generate_eval_labels) |
|
print(f"Labeling evaluation OOS data") |
|
evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels) |
|
|
|
|
|
model = BertForTokenClassification.from_pretrained( |
|
pre_model, |
|
num_labels=num_labels, |
|
output_attentions=False, |
|
output_hidden_states=False, |
|
) |
|
if freeze_layers is not None: |
|
modules_to_freeze = model.bert.encoder.layer[:freeze_layers] |
|
for module in modules_to_freeze: |
|
for param in module.parameters(): |
|
param.requires_grad = False |
|
|
|
model = model.to(device) |
|
|
|
|
|
training_args["output_dir"] = output_dir |
|
training_args_init = TrainingArguments(**training_args) |
|
|
|
|
|
trainer = Trainer( |
|
model=model, |
|
args=training_args_init, |
|
data_collator=DataCollatorForGeneClassification(), |
|
train_dataset=trainset_labeled, |
|
eval_dataset=evalset_train_labeled, |
|
) |
|
|
|
|
|
trainer.train() |
|
trainer.save_model(output_dir) |
|
|
|
fpr, tpr, interp_tpr, conf_mat = classifier_predict( |
|
trainer.model, evalset_oos_labeled, 200, mean_fpr |
|
) |
|
auc_score = auc(fpr, tpr) |
|
|
|
return fpr, tpr, auc_score |
|
|
|
|
|
|
|
def cross_validate( |
|
data, |
|
targets, |
|
labels, |
|
nsplits, |
|
subsample_size, |
|
training_args, |
|
freeze_layers, |
|
output_dir, |
|
num_proc, |
|
num_labels, |
|
pre_model, |
|
): |
|
|
|
|
|
model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin") |
|
|
|
|
|
|
|
device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
num_classes = len(set(labels)) |
|
mean_fpr = np.linspace(0, 1, 100) |
|
all_tpr = [] |
|
all_roc_auc = [] |
|
all_tpr_wt = [] |
|
label_dicts = [] |
|
confusion = np.zeros((num_classes, num_classes)) |
|
|
|
|
|
skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True) |
|
|
|
iteration_num = 0 |
|
for train_index, eval_index in tqdm(skf.split(targets, labels)): |
|
if len(labels) > 500: |
|
print("early stopping activated due to large # of training examples") |
|
if iteration_num == 3: |
|
break |
|
|
|
print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n") |
|
|
|
|
|
targets_train, targets_eval = targets[train_index], targets[eval_index] |
|
labels_train, labels_eval = labels[train_index], labels[eval_index] |
|
label_dict_train = dict(zip(targets_train, labels_train)) |
|
label_dict_eval = dict(zip(targets_eval, labels_eval)) |
|
label_dicts += ( |
|
iteration_num, |
|
targets_train, |
|
targets_eval, |
|
labels_train, |
|
labels_eval, |
|
) |
|
|
|
|
|
def if_contains_train_label(example): |
|
a = label_dict_train.keys() |
|
b = example["input_ids"] |
|
|
|
return not set(a).isdisjoint(b) |
|
|
|
def if_contains_eval_label(example): |
|
a = label_dict_eval.keys() |
|
b = example["input_ids"] |
|
|
|
return not set(a).isdisjoint(b) |
|
|
|
|
|
print(f"Filtering training data") |
|
trainset = data.filter(if_contains_train_label, num_proc=num_proc) |
|
print( |
|
f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n" |
|
) |
|
print(f"Filtering evalation data") |
|
evalset = data.filter(if_contains_eval_label, num_proc=num_proc) |
|
print( |
|
f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n" |
|
) |
|
|
|
|
|
training_size = min(subsample_size, len(trainset)) |
|
trainset_min = trainset.select([i for i in range(training_size)]) |
|
eval_size = min(training_size, len(evalset)) |
|
half_training_size = round(eval_size / 2) |
|
evalset_train_min = evalset.select([i for i in range(half_training_size)]) |
|
evalset_oos_min = evalset.select( |
|
[i for i in range(half_training_size, eval_size)] |
|
) |
|
|
|
|
|
def generate_train_labels(example): |
|
example["labels"] = [ |
|
label_dict_train.get(token_id, -100) |
|
for token_id in example["input_ids"] |
|
] |
|
return example |
|
|
|
def generate_eval_labels(example): |
|
example["labels"] = [ |
|
label_dict_eval.get(token_id, -100) for token_id in example["input_ids"] |
|
] |
|
return example |
|
|
|
|
|
print(f"Labeling training data") |
|
trainset_labeled = trainset_min.map(generate_train_labels) |
|
print(f"Labeling evaluation data") |
|
evalset_train_labeled = evalset_train_min.map(generate_eval_labels) |
|
print(f"Labeling evaluation OOS data") |
|
evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels) |
|
|
|
|
|
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}") |
|
ksplit_model_dir = os.path.join(ksplit_output_dir, "models/") |
|
|
|
|
|
model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin") |
|
|
|
|
|
|
|
|
|
subprocess.call(f"mkdir -p {ksplit_output_dir}", shell=True) |
|
subprocess.call(f"mkdir -p {ksplit_model_dir}", shell=True) |
|
|
|
|
|
model = BertForTokenClassification.from_pretrained( |
|
pre_model, |
|
num_labels=num_labels, |
|
output_attentions=False, |
|
output_hidden_states=False, |
|
) |
|
if freeze_layers is not None: |
|
modules_to_freeze = model.bert.encoder.layer[:freeze_layers] |
|
for module in modules_to_freeze: |
|
for param in module.parameters(): |
|
param.requires_grad = False |
|
|
|
model = model.to(device) |
|
|
|
|
|
training_args["output_dir"] = ksplit_output_dir |
|
training_args_init = TrainingArguments(**training_args) |
|
|
|
|
|
trainer = Trainer( |
|
model=model, |
|
args=training_args_init, |
|
data_collator=DataCollatorForGeneClassification(), |
|
train_dataset=trainset_labeled, |
|
eval_dataset=evalset_train_labeled, |
|
) |
|
|
|
|
|
trainer.train() |
|
|
|
|
|
trainer.save_model(ksplit_model_dir) |
|
|
|
|
|
fpr, tpr, interp_tpr, conf_mat = classifier_predict( |
|
trainer.model, evalset_oos_labeled, 200, mean_fpr |
|
) |
|
|
|
|
|
confusion = confusion + conf_mat |
|
all_tpr.append(interp_tpr) |
|
all_roc_auc.append(auc(fpr, tpr)) |
|
|
|
all_tpr_wt.append(len(tpr)) |
|
|
|
iteration_num = iteration_num + 1 |
|
|
|
|
|
mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics( |
|
all_tpr, all_roc_auc, all_tpr_wt |
|
) |
|
return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts |
|
|
|
|
|
|
|
def compute_metrics(pred): |
|
labels = pred.label_ids |
|
preds = pred.predictions.argmax(-1) |
|
|
|
acc = accuracy_score(labels, preds) |
|
macro_f1 = f1_score(labels, preds, average="macro") |
|
|
|
return {"accuracy": acc, "macro_f1": macro_f1} |
|
|
|
|
|
|
|
def plot_ROC(bundled_data, title): |
|
plt.figure() |
|
lw = 2 |
|
for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data: |
|
plt.plot( |
|
mean_fpr, |
|
mean_tpr, |
|
color=color, |
|
lw=lw, |
|
label="{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format( |
|
sample, roc_auc, roc_auc_sd |
|
), |
|
) |
|
|
|
plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--") |
|
plt.xlim([0.0, 1.0]) |
|
plt.ylim([0.0, 1.05]) |
|
plt.xlabel("False Positive Rate") |
|
plt.ylabel("True Positive Rate") |
|
plt.title(title) |
|
plt.legend(loc="lower right") |
|
plt.savefig("ROC.png") |
|
|
|
return mean_fpr, mean_tpr, roc_auc |
|
|
|
|
|
|
|
def plot_confusion_matrix(classes_list, conf_mat, title): |
|
display_labels = [] |
|
i = 0 |
|
for label in classes_list: |
|
display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:, i]))] |
|
i = i + 1 |
|
display = ConfusionMatrixDisplay( |
|
confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"), |
|
display_labels=display_labels, |
|
) |
|
display.plot(cmap="Blues", values_format=".2g") |
|
plt.title(title) |
|
plt.savefig("CM.png") |
|
|
|
|
|
|
|
|
|
def find_largest_div(N, K): |
|
rem = N % K |
|
if rem == 0: |
|
return N |
|
else: |
|
return N - rem |
|
|
|
|
|
def preprocess_classifier_batch(cell_batch, max_len): |
|
if max_len == None: |
|
max_len = max([len(i) for i in cell_batch["input_ids"]]) |
|
|
|
def pad_label_example(example): |
|
example["labels"] = np.pad( |
|
example["labels"], |
|
(0, max_len - len(example["input_ids"])), |
|
mode="constant", |
|
constant_values=-100, |
|
) |
|
example["input_ids"] = np.pad( |
|
example["input_ids"], |
|
(0, max_len - len(example["input_ids"])), |
|
mode="constant", |
|
constant_values=token_dictionary.get("<pad>"), |
|
) |
|
example["attention_mask"] = ( |
|
example["input_ids"] != token_dictionary.get("<pad>") |
|
).astype(int) |
|
return example |
|
|
|
padded_batch = cell_batch.map(pad_label_example) |
|
return padded_batch |
|
|
|
|
|
|
|
def classifier_predict(model, evalset, forward_batch_size, mean_fpr): |
|
predict_logits = [] |
|
predict_labels = [] |
|
model.to("cpu") |
|
model.eval() |
|
|
|
|
|
evalset_len = len(evalset) |
|
max_divisible = find_largest_div(evalset_len, forward_batch_size) |
|
if len(evalset) - max_divisible == 1: |
|
evalset_len = max_divisible |
|
|
|
max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"]) |
|
|
|
for i in range(0, evalset_len, forward_batch_size): |
|
max_range = min(i + forward_batch_size, evalset_len) |
|
batch_evalset = evalset.select([i for i in range(i, max_range)]) |
|
padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len) |
|
padded_batch.set_format(type="torch") |
|
|
|
input_data_batch = padded_batch["input_ids"] |
|
attn_msk_batch = padded_batch["attention_mask"] |
|
label_batch = padded_batch["labels"] |
|
with torch.no_grad(): |
|
input_ids = input_data_batch |
|
attn_mask = attn_msk_batch |
|
labels = label_batch |
|
outputs = model( |
|
input_ids=input_ids, attention_mask=attn_mask, labels=labels |
|
) |
|
predict_logits += [torch.squeeze(outputs.logits.to("cpu"))] |
|
predict_labels += [torch.squeeze(label_batch.to("cpu"))] |
|
|
|
logits_by_cell = torch.cat(predict_logits) |
|
all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2]) |
|
labels_by_cell = torch.cat(predict_labels) |
|
all_labels = torch.flatten(labels_by_cell) |
|
logit_label_paired = [ |
|
item |
|
for item in list(zip(all_logits.tolist(), all_labels.tolist())) |
|
if item[1] != -100 |
|
] |
|
y_pred = [vote(item[0]) for item in logit_label_paired] |
|
y_true = [item[1] for item in logit_label_paired] |
|
logits_list = [item[0] for item in logit_label_paired] |
|
|
|
y_score = [py_softmax(item)[1] for item in logits_list] |
|
conf_mat = confusion_matrix(y_true, y_pred) |
|
fpr, tpr, _ = roc_curve(y_true, y_score) |
|
|
|
plt.plot(fpr, tpr) |
|
plt.xlim([0.0, 1.0]) |
|
plt.ylim([0.0, 1.05]) |
|
plt.xlabel("False Positive Rate") |
|
plt.ylabel("True Positive Rate") |
|
plt.title("ROC") |
|
plt.show() |
|
|
|
interp_tpr = np.interp(mean_fpr, fpr, tpr) |
|
interp_tpr[0] = 0.0 |
|
return fpr, tpr, interp_tpr, conf_mat |
|
|
|
|
|
def classify_genes( |
|
gene_info="Genecorpus-30M/example_input_files/gene_info_table.csv", |
|
genes="Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", |
|
corpus_30M="Genecorpus-30M/genecorpus_30M_2048.dataset/", |
|
model=".", |
|
max_input_size=2**11, |
|
max_lr=5e-5, |
|
freeze_layers=4, |
|
num_gpus=1, |
|
num_proc=os.cpu_count(), |
|
geneformer_batch_size=9, |
|
epochs=1, |
|
filter_dataset=50_000, |
|
emb_extract=True, |
|
emb_layer=0, |
|
forward_batch=200, |
|
filter_data=None, |
|
inference=False, |
|
k_validate=True, |
|
model_location="230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/", |
|
skip_training=False, |
|
emb_dir="gene_emb", |
|
output_dir=None, |
|
max_cells=1000, |
|
num_cpus=os.cpu_count(), |
|
): |
|
""" " |
|
Primary Parameters |
|
----------- |
|
|
|
gene_info: path |
|
Path to gene mappings |
|
|
|
corpus_30M: path |
|
Path to 30M Gene Corpus |
|
|
|
model: path |
|
Path to pretrained GeneFormer model |
|
|
|
genes: path |
|
Path to csv file containing different columns of genes and the column labels |
|
|
|
inference: bool |
|
Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False |
|
|
|
k_validate: bool |
|
Whether the model should run k-fold validation or simply perform regular training/evaluate. Defaults to True |
|
|
|
skip_training: bool |
|
Whether the model should skip the training portion. Defaults to False |
|
|
|
emb_extract: bool |
|
WHether the model should extract embeddings for a given gene (WIP) |
|
|
|
|
|
Customization Parameters |
|
----------- |
|
|
|
freeze_layers: int |
|
Freezes x number of layers from the model. Default is 4 (2 non-frozen layers) |
|
|
|
filter_dataset: int |
|
Number of cells to filter from 30M dataset. Default is 50_000 |
|
|
|
emb_layer: int |
|
What layer embeddings are extracted from. Default is 4 |
|
|
|
filter_data: str, list |
|
Filters down embeddings to a single category. Default is None |
|
|
|
|
|
""" |
|
|
|
|
|
gene_info = pd.read_csv(gene_info, index_col=0) |
|
labels = gene_info.columns |
|
|
|
|
|
gene_id_type_dict = dict(zip(gene_info["ensembl_id"], gene_info["gene_type"])) |
|
gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"])) |
|
gene_id_name_dict = {v: k for k, v in gene_name_id_dict.items()} |
|
|
|
|
|
def prep_inputs(label_store, id_type): |
|
target_list = [] |
|
if id_type == "gene_name": |
|
for key in list(label_store.keys()): |
|
targets = [ |
|
gene_name_id_dict[gene] |
|
for gene in label_store[key] |
|
if gene_name_id_dict.get(gene) in token_dictionary |
|
] |
|
targets_id = [token_dictionary[gene] for gene in targets] |
|
target_list.append(targets_id) |
|
elif id_type == "ensembl_id": |
|
for key in list(label_store.keys()): |
|
targets = [ |
|
gene for gene in label_store[key] if gene in token_dictionary |
|
] |
|
targets_id = [token_dictionary[gene] for gene in targets] |
|
target_list.append(targets_id) |
|
|
|
targets, labels = [], [] |
|
for targ in target_list: |
|
targets = targets + targ |
|
targets = np.array(targets) |
|
for num, targ in enumerate(target_list): |
|
label = [num] * len(targ) |
|
labels = labels + label |
|
labels = np.array(labels) |
|
unique_labels = num + 1 |
|
|
|
nsplits = min(5, min([len(targ) for targ in target_list]) - 1) |
|
assert nsplits > 2 |
|
|
|
return targets, labels, nsplits, unique_labels |
|
|
|
if skip_training == False: |
|
|
|
gene_classes = pd.read_csv(genes, header=0) |
|
if filter_data == None: |
|
labels = gene_classes.columns |
|
else: |
|
if isinstance(filter_data, list): |
|
labels = filter_data |
|
else: |
|
labels = [filter_data] |
|
label_store = {} |
|
|
|
|
|
decode = {i: labels[i] for i in range(len(labels))} |
|
|
|
for label in labels: |
|
label_store[label] = gene_classes[label].dropna() |
|
|
|
targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id") |
|
|
|
|
|
train_dataset = load_from_disk(corpus_30M) |
|
shuffled_train_dataset = train_dataset.shuffle(seed=42) |
|
subsampled_train_dataset = shuffled_train_dataset.select( |
|
[i for i in range(filter_dataset)] |
|
) |
|
lr_schedule_fn = "linear" |
|
warmup_steps = 500 |
|
optimizer = "adamw" |
|
subsample_size = 10_000 |
|
|
|
training_args = { |
|
"learning_rate": max_lr, |
|
"do_train": True, |
|
"evaluation_strategy": "no", |
|
"save_strategy": "epoch", |
|
"logging_steps": 10, |
|
"group_by_length": True, |
|
"length_column_name": "length", |
|
"disable_tqdm": False, |
|
"lr_scheduler_type": lr_schedule_fn, |
|
"warmup_steps": warmup_steps, |
|
"weight_decay": 0.001, |
|
"per_device_train_batch_size": geneformer_batch_size, |
|
"per_device_eval_batch_size": geneformer_batch_size, |
|
"num_train_epochs": epochs, |
|
} |
|
|
|
|
|
current_date = datetime.datetime.now() |
|
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}" |
|
|
|
if output_dir == None: |
|
training_output_dir = Path( |
|
f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/" |
|
) |
|
else: |
|
training_output_dir = Path(output_dir) |
|
|
|
|
|
subprocess.call(f"mkdir -p {training_output_dir}", shell=True) |
|
|
|
|
|
num_classes = len(set(labels)) |
|
info_list = [num_classes, decode] |
|
|
|
with open(training_output_dir / "classes.txt", "w") as f: |
|
f.write(str(info_list)) |
|
|
|
subsampled_train_dataset.save_to_disk(output_dir / "dataset") |
|
|
|
if k_validate == True: |
|
ksplit_model = "ksplit0/models" |
|
ksplit_model_test = os.path.join(training_output_dir, ksplit_model) |
|
|
|
|
|
|
|
( |
|
all_roc_auc, |
|
roc_auc, |
|
roc_auc_sd, |
|
mean_fpr, |
|
mean_tpr, |
|
confusion, |
|
label_dicts, |
|
) = cross_validate( |
|
subsampled_train_dataset, |
|
targets, |
|
labels, |
|
nsplits, |
|
subsample_size, |
|
training_args, |
|
freeze_layers, |
|
training_output_dir, |
|
1, |
|
unique_labels, |
|
model, |
|
) |
|
|
|
bundled_data = [] |
|
bundled_data += [ |
|
(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red") |
|
] |
|
graph_title = " ".join( |
|
[ |
|
i + " vs" if count < len(label_store) - 1 else i |
|
for count, i in enumerate(label_store) |
|
] |
|
) |
|
fpr, tpr, auc = plot_ROC( |
|
bundled_data, "Dosage Sensitive vs Insensitive TFs" |
|
) |
|
print(auc) |
|
|
|
plot_confusion_matrix(label_store, confusion, "Geneformer") |
|
else: |
|
fpr, tpr, auc = validate( |
|
subsampled_train_dataset, |
|
targets, |
|
labels, |
|
nsplits, |
|
subsample_size, |
|
training_args, |
|
freeze_layers, |
|
training_output_dir, |
|
1, |
|
unique_labels, |
|
model, |
|
) |
|
print(auc) |
|
|
|
if inference == True: |
|
|
|
gene_classes = pd.read_csv(genes, header=0) |
|
targets = [] |
|
for column in gene_classes.columns: |
|
targets += list(gene_classes[column]) |
|
tokens = [] |
|
for target in targets: |
|
try: |
|
tokens.append(token_dictionary[target]) |
|
except: |
|
tokens.append(0) |
|
|
|
targets = torch.LongTensor([tokens]) |
|
|
|
with open(f"{model_location}classes.txt", "r") as f: |
|
info_list = ast.literal_eval(f.read()) |
|
num_classes = info_list[0] |
|
labels = info_list[1] |
|
|
|
model = BertForTokenClassification.from_pretrained( |
|
model_location, |
|
num_labels=num_classes, |
|
output_attentions=False, |
|
output_hidden_states=False, |
|
local_files_only=True, |
|
) |
|
if freeze_layers is not None: |
|
modules_to_freeze = model.bert.encoder.layer[:freeze_layers] |
|
for module in modules_to_freeze: |
|
for param in module.parameters(): |
|
param.requires_grad = False |
|
|
|
model = model.to(device) |
|
|
|
|
|
predictions = F.softmax(model(targets.to(device))["logits"], dim=-1).argmax(-1)[ |
|
0 |
|
] |
|
predictions = [labels[int(pred)] for pred in predictions] |
|
|
|
return predictions |
|
|
|
|
|
if emb_extract == True: |
|
with open(f"{model_location}/classes.txt", "r") as f: |
|
data = ast.literal_eval(f.read()) |
|
num_classes = data[0] |
|
decode = data[1] |
|
|
|
gene_classes = pd.read_csv(genes, header=0) |
|
labels = gene_classes.columns |
|
tokenize = TranscriptomeTokenizer() |
|
|
|
label_dict = {} |
|
for label in labels: |
|
genes = gene_classes[label] |
|
tokenized_genes = [] |
|
for gene in genes: |
|
try: |
|
tokenized_genes.append(tokenize.gene_token_dict[gene]) |
|
except: |
|
continue |
|
label_dict[label] = tokenized_genes |
|
|
|
embex = EmbExtractor( |
|
model_type="GeneClassifier", |
|
num_classes=num_classes, |
|
emb_mode="gene", |
|
filter_data=None, |
|
max_ncells=max_cells, |
|
emb_layer=emb_layer, |
|
emb_label=label_dict, |
|
labels_to_plot=list(labels), |
|
forward_batch_size=forward_batch, |
|
nproc=num_cpus, |
|
) |
|
|
|
subprocess.call(f"mkdir -p {emb_dir}", shell=True) |
|
|
|
embs = embex.extract_embs( |
|
model_directory=model_location, |
|
input_data_file=model_location / "dataset", |
|
output_directory=emb_dir, |
|
output_prefix=f"{label}_embbeddings", |
|
) |
|
|
|
emb_dict = {label: [] for label in list(set(labels))} |
|
similarities = {key: {} for key in list(emb_dict.keys())} |
|
|
|
for column in embs.columns: |
|
remaining_cols = [k for k in embs.columns if k != column] |
|
for k in remaining_cols: |
|
embedding = torch.Tensor(embs[k]) |
|
sim = similarity(torch.Tensor(embs[column]), embedding, cosine=True) |
|
similarities[column][k] = sim |
|
|
|
plot_similarity_heatmap(similarities) |
|
print(similarities) |
|
|
|
return similarities |
|
|
|
|
|
if __name__ == "__main__": |
|
classify_genes( |
|
k_validate=False, |
|
inference=False, |
|
skip_training=False, |
|
emb_extract=True, |
|
output_dir=Path("gene_emb"), |
|
model_location=Path("gene_emb"), |
|
epochs=5, |
|
gene_info="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv", |
|
genes="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", |
|
corpus_30M="../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/", |
|
) |
|
|