Christina Theodoris committed
Commit • 025e1b8
Parent(s): e562c0c

update cell classifier module

Files changed:
- .pre-commit-config.yaml +25 -0
- Immune_modelpredictions.pickle +0 -0
- gene_embclasses.txt +0 -1
- gene_embdataset.pk +0 -0
- Cell_classifier.py → geneformer/cell_classifier.py +612 -599
- Gene_classifier.py → geneformer/gene_classifier.py +448 -259
- Modular_usage.md → geneformer/modular_classifier_usage.md +76 -76
.pre-commit-config.yaml ADDED
@@ -0,0 +1,25 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+      - id: check-merge-conflict
+      - id: mixed-line-ending
+      - id: check-docstring-first
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.1.4
+    hooks:
+      # Run the Ruff linter.
+      - id: ruff
+      # Run the Ruff formatter.
+      - id: ruff-format
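
Once installed, these hooks run on every git commit; to exercise the whole configuration up front, a minimal sketch (assuming the pre-commit package is available in the environment):

import subprocess

# Register the git hook, then run every configured hook against all
# tracked files, mirroring a first-time setup or a CI check.
subprocess.run(["pre-commit", "install"], check=True)
subprocess.run(["pre-commit", "run", "--all-files"], check=False)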
Immune_modelpredictions.pickle DELETED
Binary file (99.1 kB)

gene_embclasses.txt DELETED
@@ -1 +0,0 @@
-[2, {0: 0, 1: 0}]

gene_embdataset.pk DELETED
Binary file (1.76 kB)
Cell_classifier.py → geneformer/cell_classifier.py RENAMED

The module moves into the geneformer package, its imports are re-sorted, and the main entry point is renamed from finetune_cells(...) to classify_cells(...); the new file follows below. Removed outright in this commit and not carried over into the new module:

from sklearn.preprocessing import label_binarize
import pyarrow as pa
import concurrent.futures
from ray.tune import ExperimentAnalysis

# Function for tokenizing genes into ranked-value encodings from Geneformer
def tokenize_dataset(gene_set, type = None, token_set = 'token_dictionary.pkl', species = 'human'):
    token_dataset = open(token_set, 'rb')
    token_dict = pickle.load(token_dataset)
    wrap = True

    if isinstance(gene_set[0], list) == False:
        gene_set = [gene_set]
        wrap = False

    pool = Pool()
    converted_set = []

    def process_gene(gene):
        api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{gene}?object_type=gene"
        response = requests.get(api_url, headers={"Content-Type": "application/json"})
        try:
            data = response.json()
            gene = data[0]['id']
        except:
            gene = None
        return gene

    def process_hgnc(gene):
        for gene in tqdm.tqdm(genes, total = len(genes)):
            api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{hgnc_id}?object_type=gene"
            response = requests.get(api_url, headers={"Content-Type": "application/json"})
            try:
                data = response.json()
                gene = data[0]['id']
            except:
                gene = None
            return gene

    def process_go(gene):
        mg = mygene.MyGeneInfo()
        results = mg.query(gene, scopes="go", species=species, fields="ensembl.gene")

        ensembl_ids = []
        max_score = 0
        for hit_num, hit in enumerate(results["hits"]):
            if hit['_score'] > max_score:
                max_score = hit['_score']
                chosen_hit = hit
        try:
            try:
                gene = chosen_hit["ensembl"]["gene"]
            except:
                gene = chosen_hit["ensembl"][0]["gene"]
        except:
            gene = None
        return gene

    if type == None or type.upper() == 'ENSEMBL':
        converted_set = gene_set
    elif type.upper() == 'GENE':
        for genes in gene_set:
            converted_genes = []
            for result in tqdm.tqdm(pool.imap(process_gene, genes), total = len(genes)):
                converted_genes.append(result)
            converted_set.append(converted_genes)
    elif type.upper() == 'GO':
        for genes in gene_set:
            converted_genes = []
            for result in tqdm.tqdm(pool.imap(process_go, genes), total = len(genes)):
                converted_genes.append(result)
            converted_set.append(converted_genes)
    elif type.upper() == 'HGNC':
        for genes in gene_set:
            converted_genes = []
            for result in tqdm.tqdm(pool.imap(process_hgnc, genes), total = len(genes)):
                converted_genes.append(result)
            converted_set.append(converted_genes)

    Chembl = []
    for set_num, set in enumerate(converted_set):
        Chembl.append([])
        for gene in set:
            if gene == None:
                Chembl[set_num].append(None)
            else:
                try:
                    Chembl[set_num].append(token_dict[gene])
                except:
                    print(f'{gene} not found in tokenized dataset!')
                    Chembl[set_num].append(None)

    if wrap == False:
        Chembl = Chembl[0]

    return Chembl


# '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/'
# '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/'
'''
======================================================
PRIMARY CELL - CLASSIFIER AND EMBEDDING EXTRACTOR CLASS
+++++++++++++++++++++++++++++++++++++++++++++++++++++++
'''

'''
For loading and pretraining with custom median expressions and/or custom gene conversions
-------------------------------------------------------------

token set: path
    Path to token conversion dictionary

median set: path
    Path to median gene dictionary (ensembl IDs as the keys)


median_data = pickle.load(open(median_set, 'rb'))
median_data['<pad>'] = None
median_data['<mask>'] = None

token_set = pickle.load(open(token_set, 'rb'))
median_dict = {key:median_data[key] for key in list(token_set.keys())}
'''

if __name__ == '__main__':
    predictions = finetune_cells(skip_training = False, dataset_split = None, label = "disease", sample_data = .5, data_filter = 'hcm', epochs = 10, output_dir = 'hcm_model', model_location = 'hcm_model',
                                 emb_extract = True, geneformer_batch_size = 12, inference = False, dataset = "/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/")
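
The removed tokenize_dataset resolved gene symbols through the Ensembl REST xrefs endpoint seen above. For reference, a minimal self-contained sketch of that lookup (symbol_to_ensembl is a hypothetical helper name; assumes network access and the requests package, with TP53/human as arbitrary example inputs):

import requests

def symbol_to_ensembl(gene, species="human"):
    # Ask Ensembl for gene-type cross-references matching a symbol.
    url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{gene}?object_type=gene"
    response = requests.get(url, headers={"Content-Type": "application/json"})
    try:
        return response.json()[0]["id"]
    except (ValueError, IndexError, KeyError):
        return None

print(symbol_to_ensembl("TP53"))  # ENSG00000141510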
1 |
+
"""
|
2 |
+
Geneformer cell classifier.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
from geneformer import classify_cells
|
6 |
+
classify_cells(
|
7 |
+
token_set=Path("geneformer/token_dictionary.pkl"),
|
8 |
+
median_set=Path("geneformer/gene_median_dictionary.pkl"),
|
9 |
+
pretrained_model=".",
|
10 |
+
dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
|
11 |
+
dataset_split=None,
|
12 |
+
filter_cells=0.005,
|
13 |
+
epochs=1,
|
14 |
+
cpu_cores=os.cpu_count(),
|
15 |
+
geneformer_batch_size=12,
|
16 |
+
optimizer="adamw",
|
17 |
+
max_lr=5e-5,
|
18 |
+
num_gpus=torch.cuda.device_count(),
|
19 |
+
max_input_size=2**11,
|
20 |
+
lr_schedule_fn="linear",
|
21 |
+
warmup_steps=500,
|
22 |
+
freeze_layers=0,
|
23 |
+
emb_extract=False,
|
24 |
+
max_cells=1000,
|
25 |
+
emb_layer=0,
|
26 |
+
emb_filter=None,
|
27 |
+
emb_dir="embeddings",
|
28 |
+
overwrite=True,
|
29 |
+
label="cell_type",
|
30 |
+
data_filter=None,
|
31 |
+
forward_batch=200,
|
32 |
+
model_location=None,
|
33 |
+
skip_training=False,
|
34 |
+
sample_data=1,
|
35 |
+
inference=False,
|
36 |
+
optimize_hyperparameters=False,
|
37 |
+
output_dir=None,
|
38 |
+
)
|
39 |
+
"""
|
40 |
+
|
41 |
+
import ast
|
42 |
+
import datetime
|
43 |
import os
|
44 |
+
import pickle
|
|
|
45 |
import random
|
|
|
|
|
46 |
import subprocess
|
47 |
+
from collections import Counter
|
48 |
+
from pathlib import Path
|
49 |
+
|
50 |
+
import numpy as np
|
51 |
+
import seaborn as sns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
import torch
|
53 |
import torch.nn.functional as F
|
54 |
+
from datasets import load_from_disk
|
55 |
+
from matplotlib import pyplot as plt
|
|
|
56 |
from ray import tune
|
|
|
57 |
from ray.tune.search.hyperopt import HyperOptSearch
|
58 |
+
from sklearn.metrics import accuracy_score
|
59 |
+
from sklearn.metrics import auc as precision_auc
|
60 |
+
from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score, roc_curve
|
61 |
+
from transformers import BertForSequenceClassification, Trainer
|
62 |
+
from transformers.training_args import TrainingArguments
|
63 |
+
|
64 |
+
from geneformer import DataCollatorForCellClassification, EmbExtractor
|
65 |
+
|
66 |
+
sns.set()
|
67 |
|
68 |
# Properly sets up NCCV environment
|
69 |
+
GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
|
70 |
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
|
71 |
os.environ["NCCL_DEBUG"] = "INFO"
|
72 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
73 |
|
74 |
+
|
75 |
+
# Function for generating an ROC curve from data
|
76 |
+
def ROC(prediction, truth, type="GeneFormer", label=""):
|
77 |
fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
|
78 |
auc = roc_auc_score(truth, prediction[:, 1])
|
79 |
+
print(f"{type} AUC: {auc}")
|
80 |
+
plt.plot(fpr, tpr, label="AUC=" + str(auc))
|
81 |
+
plt.ylabel("True Positive Rate")
|
82 |
+
plt.xlabel("False Positive Rate")
|
83 |
+
plt.title(f"{label} ROC Curve")
|
84 |
plt.legend(loc=4)
|
85 |
+
plt.savefig("ROC.png")
|
86 |
+
|
87 |
+
return tpr, fpr, auc
|
88 |
+
|
89 |
+
|
90 |
# Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
|
91 |
+
def similarity(tensor1, tensor2, cosine=False):
|
92 |
+
if cosine is False:
|
|
|
93 |
if tensor1.ndimension() > 1:
|
94 |
tensor1 = tensor1.view(1, -1)
|
95 |
if tensor2.ndimension() > 1:
|
|
|
99 |
norm_tensor2 = torch.norm(tensor2)
|
100 |
epsilon = 1e-8
|
101 |
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
|
102 |
+
similarity = (similarity.item() + 1) / 2
|
103 |
else:
|
104 |
if tensor1.shape != tensor2.shape:
|
105 |
raise ValueError("Input tensors must have the same shape.")
|
|
|
108 |
dot_product = torch.dot(tensor1, tensor2)
|
109 |
norm_tensor1 = torch.norm(tensor1)
|
110 |
norm_tensor2 = torch.norm(tensor2)
|
111 |
+
|
112 |
# Avoid division by zero by adding a small epsilon
|
113 |
epsilon = 1e-8
|
114 |
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
|
115 |
+
|
116 |
return similarity.item()
|
117 |
+
|
118 |
+
|
119 |
# Plots heatmap between different classes/labels
|
120 |
def plot_similarity_heatmap(similarities):
|
121 |
classes = list(similarities.keys())
|
|
|
128 |
else:
|
129 |
val = similarities[c][cc]
|
130 |
arr[i][j] = val
|
131 |
+
|
132 |
plt.figure(figsize=(8, 6))
|
133 |
+
plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
|
134 |
plt.colorbar()
|
135 |
+
plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
|
136 |
plt.yticks(np.arange(classlen), classes)
|
137 |
plt.title("Similarity Heatmap")
|
138 |
plt.savefig("similarity_heatmap.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
|
|
|
|
|
|
140 |
|
141 |
+
def classify_cells(
|
142 |
+
token_set=Path("./token_dictionary.pkl"),
|
143 |
+
median_set=Path("./gene_median_dictionary.pkl"),
|
144 |
+
pretrained_model="../",
|
145 |
+
dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
|
146 |
+
dataset_split=None,
|
147 |
+
filter_cells=0.005,
|
148 |
+
epochs=1,
|
149 |
+
cpu_cores=os.cpu_count(),
|
150 |
+
training_batch_size=12,
|
151 |
+
optimizer="adamw",
|
152 |
+
max_lr=5e-5,
|
153 |
+
num_gpus=torch.cuda.device_count(),
|
154 |
+
max_input_size=2**11,
|
155 |
+
lr_schedule_fn="linear",
|
156 |
+
warmup_steps=500,
|
157 |
+
freeze_layers=0,
|
158 |
+
emb_extract=False,
|
159 |
+
max_cells=None,
|
160 |
+
emb_layer=-1,
|
161 |
+
emb_filter=None,
|
162 |
+
emb_dir="embeddings",
|
163 |
+
overwrite=False,
|
164 |
+
label="cell_type",
|
165 |
+
data_filter=None,
|
166 |
+
inference_batch_size=200,
|
167 |
+
finetuned_model=None,
|
168 |
+
skip_training=False,
|
169 |
+
sample_data=1,
|
170 |
+
inference=False,
|
171 |
+
optimize_hyperparameters=True,
|
172 |
+
output_dir=None,
|
173 |
+
):
|
174 |
+
"""
|
175 |
Primary Parameters
|
176 |
-------------------
|
177 |
dataset: path
|
178 |
+
Path to fine-tuning dataset for training
|
179 |
+
|
180 |
+
finetuned_model: path
|
181 |
+
Path to location of fine-tuned model to use for inference and embedding extraction
|
182 |
|
|
|
|
|
|
|
183 |
pretrained_model: path
|
184 |
+
Path to pretrained Geneformer model
|
185 |
+
|
186 |
+
inference: bool
|
187 |
+
Indicates whether to perform inference and return a list of similarities. Defaults to False.
|
188 |
+
|
189 |
+
skip_training: bool
|
190 |
+
Indicates whether to skip training the model. Defaults to False.
|
191 |
+
|
192 |
emb_extract: bool
|
193 |
+
Indicates whether to extract embeddings and calculate similarities. Defaults to True.
|
194 |
+
|
195 |
optimize_hyperparameters: bool
|
196 |
+
Indicates whether to optimize model hyperparamters. Defaults to False.
|
197 |
+
|
198 |
+
|
199 |
Customization Parameters
|
200 |
-------------------
|
201 |
+
|
202 |
dataset_split: str
|
203 |
+
Indicates how the dataset should be partitioned (if at all), and what ID should be used for partitioning
|
204 |
+
|
205 |
data_filter: list
|
206 |
+
(For embeddings and inference) Runs analysis on subsets of the dataset based on the ID defined by dataset_split
|
207 |
+
|
208 |
label: str
|
209 |
+
Feature to read as a classification label.
|
210 |
+
|
211 |
emb_layer: int
|
212 |
+
What layer embeddings should be extracted and compared.
|
213 |
+
|
214 |
emb_filter: ['cell1', 'cell2'...]
|
215 |
Allows user to narrow down range of cells that embeddings will be extracted from.
|
216 |
+
|
217 |
max_cells: int
|
218 |
+
Max number of cells to use for embedding extraction.
|
219 |
+
|
220 |
freeze_layers: int
|
221 |
+
Number of layers that should be frozen during fine-tuning.
|
222 |
+
|
223 |
sample_data: float
|
224 |
+
Proportion of the dataset that should be used.
|
225 |
+
|
226 |
+
"""
|
227 |
+
|
228 |
dataset_list = []
|
229 |
evalset_list = []
|
230 |
split_list = []
|
231 |
target_dict_list = []
|
232 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
train_dataset = load_from_disk(dataset)
|
234 |
num_samples = int(len(train_dataset) * sample_data)
|
235 |
random_indices = random.sample(range(len(train_dataset)), num_samples)
|
236 |
train_dataset = train_dataset.select(random_indices)
|
237 |
+
|
238 |
+
sample = int(sample_data * len(train_dataset))
|
239 |
sample_indices = random.sample(range(len(train_dataset)), sample)
|
240 |
train_dataset = train_dataset.select(sample_indices)
|
241 |
|
242 |
+
def if_not_rare_cell_state(example):
|
243 |
return example[label] in cells_to_keep
|
244 |
+
|
245 |
# change labels to numerical ids
|
246 |
def classes_to_ids(example):
|
247 |
example["label"] = target_name_id_dict[example["label"]]
|
248 |
return example
|
249 |
+
|
250 |
def if_trained_label(example):
|
251 |
+
return example["label"] in trained_labels
|
252 |
+
|
253 |
+
if skip_training is not True:
|
254 |
+
|
255 |
def compute_metrics(pred):
|
256 |
labels = pred.label_ids
|
257 |
preds = pred.predictions.argmax(-1)
|
258 |
# calculate accuracy and macro f1 using sklearn's function
|
259 |
acc = accuracy_score(labels, preds)
|
260 |
+
macro_f1 = f1_score(labels, preds, average="macro")
|
261 |
+
return {"accuracy": acc, "macro_f1": macro_f1}
|
262 |
+
|
|
|
|
|
|
|
263 |
# Defines custom exceptions for collecting labels (default excluded)
|
264 |
+
excep = {"bone_marrow": "immune"}
|
265 |
+
|
266 |
+
if dataset_split is not None:
|
267 |
+
if data_filter is not None:
|
268 |
+
split_iter = [data_filter]
|
269 |
else:
|
270 |
+
split_iter = Counter(train_dataset[dataset_split]).keys()
|
271 |
for lab in split_iter:
|
|
|
272 |
# collect list of tissues for fine-tuning (immune and bone marrow are included together)
|
273 |
if lab in list(excep.keys()):
|
274 |
continue
|
275 |
elif lab == list(excep.values()):
|
276 |
+
split_ids = [excep.keys(), excep.values()]
|
277 |
split_list += [excep.values()]
|
278 |
else:
|
279 |
split_ids = [lab]
|
280 |
split_list += [lab]
|
281 |
+
|
282 |
# filter datasets for given organ
|
283 |
def if_label(example):
|
284 |
return example[dataset_split] == lab
|
285 |
+
|
286 |
trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
|
287 |
label_counter = Counter(trainset_label[label])
|
288 |
total_cells = sum(label_counter.values())
|
289 |
+
|
290 |
+
# excludes cells with a low proportion in the dataset
|
291 |
+
cells_to_keep = [
|
292 |
+
k
|
293 |
+
for k, v in label_counter.items()
|
294 |
+
if v > (filter_cells * total_cells)
|
295 |
+
]
|
296 |
+
trainset_label_subset = trainset_label.filter(
|
297 |
+
if_not_rare_cell_state, num_proc=cpu_cores
|
298 |
+
)
|
299 |
+
|
300 |
# shuffle datasets and rename columns
|
301 |
trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
|
302 |
+
trainset_label_shuffled = trainset_label_shuffled.rename_column(
|
303 |
+
label, "label"
|
304 |
+
)
|
305 |
+
trainset_label_shuffled = trainset_label_shuffled.remove_columns(
|
306 |
+
dataset_split
|
307 |
+
)
|
308 |
+
|
309 |
# create dictionary of cell types : label ids
|
310 |
target_names = list(Counter(trainset_label_shuffled["label"]).keys())
|
311 |
+
target_name_id_dict = dict(
|
312 |
+
zip(target_names, [i for i in range(len(target_names))])
|
313 |
+
)
|
314 |
target_dict_list += [target_name_id_dict]
|
315 |
+
|
316 |
+
labeled_trainset = trainset_label_shuffled.map(
|
317 |
+
classes_to_ids, num_proc=cpu_cores
|
318 |
+
)
|
319 |
+
|
320 |
# create 80/20 train/eval splits
|
321 |
+
labeled_train_split = trainset_label_shuffled.select(
|
322 |
+
[i for i in range(0, round(len(labeled_trainset) * 0.8))]
|
323 |
+
)
|
324 |
+
labeled_eval_split = trainset_label_shuffled.select(
|
325 |
+
[
|
326 |
+
i
|
327 |
+
for i in range(
|
328 |
+
round(len(labeled_trainset) * 0.8), len(labeled_trainset)
|
329 |
+
)
|
330 |
+
]
|
331 |
+
)
|
332 |
+
|
333 |
# filter dataset for cell types in corresponding training set
|
334 |
trained_labels = list(Counter(labeled_train_split["label"]).keys())
|
335 |
+
|
336 |
+
labeled_eval_split_subset = labeled_eval_split.filter(
|
337 |
+
if_trained_label, num_proc=cpu_cores
|
338 |
+
)
|
339 |
+
|
340 |
dataset_list += [labeled_train_split]
|
341 |
evalset_list += [labeled_eval_split_subset]
|
342 |
+
|
343 |
+
trainset_dict = dict(zip(split_list, dataset_list))
|
344 |
+
traintargetdict_dict = dict(zip(split_list, target_dict_list))
|
345 |
+
evalset_dict = dict(zip(split_list, evalset_list))
|
346 |
+
|
347 |
for lab in split_list:
|
348 |
label_trainset = trainset_dict[lab]
|
349 |
label_evalset = evalset_dict[lab]
|
350 |
label_dict = traintargetdict_dict[lab]
|
351 |
+
|
352 |
# set logging steps
|
353 |
+
logging_steps = round(len(label_trainset) / training_batch_size / 10)
|
354 |
if logging_steps == 0:
|
355 |
logging_steps = 1
|
356 |
+
|
357 |
+
# load pretrained model
|
358 |
+
model = BertForSequenceClassification.from_pretrained(
|
359 |
+
pretrained_model,
|
360 |
+
num_labels=len(label_dict.keys()),
|
361 |
+
output_attentions=False,
|
362 |
+
output_hidden_states=False,
|
363 |
+
).to(device)
|
364 |
+
|
365 |
# define output directory path
|
366 |
current_date = datetime.datetime.now()
|
367 |
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
368 |
+
|
369 |
+
if output_dir is None:
|
370 |
+
output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
|
371 |
+
|
372 |
# ensure not overwriting previously saved model
|
373 |
+
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
|
374 |
+
|
375 |
+
if os.path.isfile(saved_model_test) is True and overwrite is False:
|
376 |
raise Exception("Model already saved to this directory.")
|
377 |
+
|
378 |
# make output directory
|
379 |
+
subprocess.call(f"mkdir -p {output_dir}", shell=True)
|
380 |
+
|
381 |
# set training arguments
|
382 |
training_args = {
|
383 |
"learning_rate": max_lr,
|
|
|
392 |
"lr_scheduler_type": lr_schedule_fn,
|
393 |
"warmup_steps": warmup_steps,
|
394 |
"weight_decay": 0.001,
|
395 |
+
"per_device_train_batch_size": training_batch_size,
|
396 |
+
"per_device_eval_batch_size": training_batch_size,
|
397 |
"num_train_epochs": epochs,
|
398 |
"load_best_model_at_end": True,
|
399 |
"output_dir": output_dir,
|
400 |
}
|
401 |
+
|
|
|
402 |
training_args_init = TrainingArguments(**training_args)
|
403 |
+
true_labels = label_evalset["label"]
|
404 |
+
|
405 |
+
if optimize_hyperparameters is False:
|
|
|
406 |
# create the trainer
|
407 |
trainer = Trainer(
|
408 |
+
model=model,
|
409 |
+
args=training_args_init,
|
410 |
+
data_collator=DataCollatorForCellClassification(),
|
411 |
+
train_dataset=label_trainset,
|
412 |
+
eval_dataset=label_evalset,
|
413 |
+
compute_metrics=compute_metrics,
|
414 |
)
|
415 |
+
|
416 |
# train the cell type classifier
|
417 |
trainer.train()
|
418 |
predictions = trainer.predict(label_evalset)
|
419 |
+
print(
|
420 |
+
f'accuracy: {accuracy_score(predictions.argmax(), label_evalset["labels"])}'
|
421 |
+
)
|
422 |
+
|
423 |
tpr, fpr, auc = ROC(predictions.predictions, true_labels)
|
424 |
+
|
425 |
metrics = compute_metrics(predictions)
|
426 |
with open(f"{output_dir}predictions.pickle", "wb") as fp:
|
427 |
pickle.dump(predictions, fp)
|
428 |
+
|
429 |
+
trainer.save_metrics("eval", predictions.metrics)
|
430 |
+
|
431 |
+
with open(f"{output_dir}/targets.txt", "w") as f:
|
432 |
if len(target_dict_list) == 1:
|
433 |
f.write(str(target_dict_list[0]))
|
434 |
else:
|
435 |
f.write(str(target_dict_list))
|
436 |
+
|
437 |
+
try:
|
438 |
+
precision, recall, _ = precision_recall_curve(
|
439 |
+
true_labels, predictions.predictions[:, 1]
|
440 |
+
)
|
441 |
+
pr_auc = precision_auc(recall, precision)
|
442 |
+
|
443 |
+
print(f"AUC: {pr_auc}")
|
444 |
+
return recall, precision, pr_auc
|
445 |
+
except:
|
446 |
+
pass
|
447 |
+
|
448 |
trainer.save_model(output_dir)
|
449 |
else:
|
450 |
+
|
451 |
def model_init():
|
452 |
+
model = BertForSequenceClassification.from_pretrained(
|
453 |
+
pretrained_model,
|
454 |
+
num_labels=len(label_dict.keys()),
|
455 |
+
output_attentions=False,
|
456 |
+
output_hidden_states=False,
|
457 |
+
)
|
458 |
if freeze_layers is not None:
|
459 |
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
460 |
for module in modules_to_freeze:
|
461 |
for param in module.parameters():
|
462 |
param.requires_grad = False
|
463 |
+
model = model.to(device)
|
464 |
return model
|
465 |
+
|
466 |
trainer = Trainer(
|
467 |
model_init=model_init,
|
468 |
args=training_args_init,
|
469 |
data_collator=DataCollatorForCellClassification(),
|
470 |
train_dataset=label_trainset,
|
471 |
eval_dataset=label_evalset,
|
472 |
+
compute_metrics=compute_metrics,
|
473 |
)
|
474 |
# specify raytune hyperparameter search space
|
475 |
ray_config = {
|
476 |
"num_train_epochs": tune.choice([epochs]),
|
477 |
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
478 |
"weight_decay": tune.uniform(0.0, 0.3),
|
479 |
+
"lr_scheduler_type": tune.choice(
|
480 |
+
["linear", "cosine", "polynomial"]
|
481 |
+
),
|
482 |
"warmup_steps": tune.uniform(100, 2000),
|
483 |
+
"seed": tune.uniform(0, 100),
|
484 |
+
"per_device_train_batch_size": tune.choice(
|
485 |
+
[training_batch_size]
|
486 |
+
),
|
487 |
}
|
488 |
+
|
489 |
+
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
|
490 |
+
|
491 |
+
if torch.device == "cuda":
|
492 |
+
resources_per_trial = ({"cpu": 8, "gpu": 1},)
|
|
|
493 |
else:
|
494 |
+
resources_per_trial = {"cpu": 8}
|
495 |
+
|
496 |
# optimize hyperparameters
|
497 |
best_trial = trainer.hyperparameter_search(
|
498 |
direction="maximize",
|
499 |
backend="ray",
|
500 |
+
resources_per_trial=resources_per_trial,
|
501 |
hp_space=lambda _: ray_config,
|
502 |
search_alg=hyperopt_search,
|
503 |
+
n_trials=100, # number of trials
|
504 |
+
progress_reporter=tune.CLIReporter(
|
505 |
+
max_report_frequency=600,
|
506 |
+
sort_by_metric=True,
|
507 |
+
max_progress_rows=100,
|
508 |
+
mode="max",
|
509 |
+
metric="eval_accuracy",
|
510 |
+
metric_columns=["loss", "eval_loss", "eval_accuracy"],
|
511 |
+
),
|
512 |
+
)
|
513 |
best_hyperparameters = best_trial.hyperparameters
|
514 |
+
|
515 |
print("Best Hyperparameters:")
|
516 |
print(best_hyperparameters)
|
517 |
+
|
|
|
|
|
518 |
else:
|
519 |
+
trainset_label = train_dataset
|
520 |
+
label_counter = Counter(trainset_label[label])
|
521 |
+
total_cells = sum(label_counter.values())
|
522 |
+
|
523 |
+
# Excludes cells with a low proportion in the dataset
|
524 |
+
cells_to_keep = [
|
525 |
+
k for k, v in label_counter.items() if v > (filter_cells * total_cells)
|
526 |
+
]
|
527 |
+
trainset_label_subset = trainset_label.filter(
|
528 |
+
if_not_rare_cell_state, num_proc=cpu_cores
|
529 |
+
)
|
530 |
+
|
531 |
+
# shuffle datasets and rename columns
|
532 |
+
trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
|
533 |
+
trainset_label_shuffled = trainset_label_shuffled.rename_column(
|
534 |
+
label, "label"
|
535 |
+
)
|
536 |
+
|
537 |
+
# create dictionary of cell types : label ids
|
538 |
+
target_names = list(Counter(trainset_label_shuffled["label"]).keys())
|
539 |
+
target_name_id_dict = dict(
|
540 |
+
zip(target_names, [i for i in range(len(target_names))])
|
541 |
+
)
|
542 |
+
target_dict_list = target_name_id_dict
|
543 |
+
|
544 |
+
labeled_trainset = trainset_label_shuffled.map(
|
545 |
+
classes_to_ids, num_proc=cpu_cores
|
546 |
+
)
|
547 |
+
|
548 |
+
# create 80/20 train/eval splits
|
549 |
+
labeled_train_split = labeled_trainset.select(
|
550 |
+
[i for i in range(0, round(len(labeled_trainset) * 0.8))]
|
551 |
+
)
|
552 |
+
labeled_eval_split = labeled_trainset.select(
|
553 |
+
[
|
554 |
+
i
|
555 |
+
for i in range(
|
556 |
+
round(len(labeled_trainset) * 0.8), len(labeled_trainset)
|
557 |
+
)
|
558 |
+
]
|
559 |
+
)
|
560 |
+
|
561 |
+
# filter dataset for cell types in corresponding training set
|
562 |
+
trained_labels = list(Counter(labeled_train_split["label"]).keys())
|
563 |
+
labeled_eval_split_subset = labeled_eval_split.filter(
|
564 |
+
if_trained_label, num_proc=cpu_cores
|
565 |
+
)
|
566 |
+
|
567 |
+
# set logging steps
|
568 |
+
logging_steps = round(len(trainset_label) / training_batch_size / 10)
|
569 |
+
|
570 |
+
# load pretrained model
|
571 |
+
model = BertForSequenceClassification.from_pretrained(
|
572 |
+
pretrained_model,
|
573 |
+
num_labels=len(target_dict_list.keys()),
|
574 |
+
output_attentions=False,
|
575 |
+
output_hidden_states=False,
|
576 |
+
).to(device)
|
577 |
+
# define output directory path
|
578 |
+
current_date = datetime.datetime.now()
|
579 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
580 |
+
|
581 |
+
if output_dir is None:
|
582 |
+
output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
|
583 |
+
|
584 |
+
# ensure not overwriting previously saved model
|
585 |
+
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
|
586 |
+
if os.path.isfile(saved_model_test) is True and overwrite is False:
|
587 |
+
raise Exception("Model already saved to this directory.")
|
588 |
+
|
589 |
+
# make output directory
|
590 |
+
subprocess.call(f"mkdir -p {output_dir}", shell=True)
|
591 |
+
|
592 |
+
# set training arguments
|
593 |
+
training_args = {
|
594 |
+
"learning_rate": max_lr,
|
595 |
+
"do_train": True,
|
596 |
+
"do_eval": True,
|
597 |
+
"evaluation_strategy": "epoch",
|
598 |
+
"save_strategy": "epoch",
|
599 |
+
"logging_steps": logging_steps,
|
600 |
+
"group_by_length": True,
|
601 |
+
"length_column_name": "length",
|
602 |
+
"disable_tqdm": False,
|
603 |
+
"lr_scheduler_type": lr_schedule_fn,
|
604 |
+
"warmup_steps": warmup_steps,
|
605 |
+
"weight_decay": 0.001,
|
606 |
+
"per_device_train_batch_size": training_batch_size,
|
607 |
+
"per_device_eval_batch_size": training_batch_size,
|
608 |
+
"num_train_epochs": epochs,
|
609 |
+
"load_best_model_at_end": True,
|
610 |
+
"output_dir": output_dir,
|
611 |
+
}
|
612 |
+
|
613 |
+
training_args_init = TrainingArguments(**training_args)
|
614 |
+
true_labels = labeled_eval_split_subset["label"]
|
615 |
+
|
616 |
+
if optimize_hyperparameters is False:
|
617 |
+
# create the trainer
|
618 |
+
trainer = Trainer(
|
619 |
+
model=model,
|
620 |
+
args=training_args_init,
|
621 |
+
data_collator=DataCollatorForCellClassification(),
|
622 |
+
train_dataset=labeled_train_split,
|
623 |
+
eval_dataset=labeled_eval_split_subset,
|
624 |
+
compute_metrics=compute_metrics,
|
625 |
+
)
|
626 |
+
|
627 |
+
# train the cell type classifier
|
628 |
+
trainer.train()
|
629 |
+
predictions = trainer.predict(labeled_eval_split_subset)
|
630 |
+
predictions_tensor = torch.Tensor(predictions.predictions)
|
631 |
+
predicted_labels = torch.argmax(predictions_tensor, dim=1)
|
632 |
+
print(
|
633 |
+
f'accuracy: {accuracy_score(predicted_labels, labeled_eval_split_subset["label"])}'
|
634 |
+
)
|
635 |
+
metrics = compute_metrics(predictions)
|
636 |
+
|
637 |
+
with open(f"{output_dir}predictions.pickle", "wb") as fp:
|
638 |
+
pickle.dump(predictions.predictions.argmax(-1), fp)
|
639 |
+
|
640 |
+
trainer.save_metrics("eval", predictions.metrics)
|
641 |
+
trainer.save_model(output_dir)
|
642 |
+
|
643 |
+
# Saves label conversion dictionary to output directory
|
644 |
+
with open(f"{output_dir}/targets.txt", "w") as f:
|
645 |
+
f.write(str(target_dict_list))
|
646 |
+
|
647 |
+
try:
|
648 |
+
precision, recall, _ = precision_recall_curve(
|
649 |
+
true_labels, predictions.predictions[:, 1]
|
650 |
+
)
|
651 |
+
pr_auc = precision_auc(recall, precision)
|
652 |
+
|
653 |
+
print(f"AUC: {pr_auc}")
|
654 |
+
return recall, precision, pr_auc
|
655 |
+
except:
|
656 |
+
pass
|
657 |
+
|
658 |
+
else:
|
659 |
+
# Optimizes hyperparameters
|
660 |
+
|
661 |
+
num_classes = len(list(set(labeled_train_split["label"])))
|
662 |
+
|
663 |
+
def model_init():
|
664 |
+
model = BertForSequenceClassification.from_pretrained(
|
665 |
+
pretrained_model,
|
666 |
+
num_labels=num_classes,
|
667 |
+
output_attentions=False,
|
668 |
+
output_hidden_states=False,
|
669 |
+
)
|
670 |
+
|
671 |
+
if freeze_layers is not None:
|
672 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
673 |
+
for module in modules_to_freeze:
|
674 |
+
for param in module.parameters():
|
675 |
+
param.requires_grad = False
|
676 |
+
model = model.to(device)
|
677 |
+
return model
|
678 |
+
|
679 |
+
# create the trainer
|
680 |
+
trainer = Trainer(
|
681 |
+
model_init=model_init,
|
682 |
+
args=training_args_init,
|
683 |
+
data_collator=DataCollatorForCellClassification(),
|
684 |
+
train_dataset=labeled_train_split,
|
685 |
+
eval_dataset=labeled_eval_split_subset,
|
686 |
+
compute_metrics=compute_metrics,
|
687 |
+
)
|
688 |
+
|
689 |
+
# specify raytune hyperparameter search space
|
690 |
+
ray_config = {
|
691 |
+
"num_train_epochs": tune.choice([epochs]),
|
692 |
+
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
693 |
+
"weight_decay": tune.uniform(0.0, 0.3),
|
694 |
+
"lr_scheduler_type": tune.choice(
|
695 |
+
["linear", "cosine", "polynomial"]
|
696 |
+
),
|
697 |
+
"warmup_steps": tune.uniform(100, 2000),
|
698 |
+
"seed": tune.uniform(0, 100),
|
699 |
+
"per_device_train_batch_size": tune.choice([training_batch_size]),
|
700 |
+
}
|
701 |
+
|
702 |
+
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
|
703 |
+
|
704 |
+
if torch.device == "cuda":
|
705 |
+
resources_per_trial = ({"cpu": 8, "gpu": 1},)
|
706 |
+
else:
|
707 |
+
resources_per_trial = {"cpu": 8}
|
708 |
+
|
709 |
+
# optimize hyperparameters
|
710 |
+
best_trial = trainer.hyperparameter_search(
|
711 |
+
direction="maximize",
|
712 |
+
backend="ray",
|
713 |
+
resources_per_trial=resources_per_trial,
|
714 |
+
hp_space=lambda _: ray_config,
|
715 |
+
search_alg=hyperopt_search,
|
716 |
+
n_trials=100, # number of trials
|
717 |
+
progress_reporter=tune.CLIReporter(
|
718 |
+
max_report_frequency=600,
|
719 |
+
sort_by_metric=True,
|
720 |
+
max_progress_rows=100,
|
721 |
+
mode="max",
|
722 |
+
metric="eval_accuracy",
|
723 |
+
metric_columns=["loss", "eval_loss", "eval_accuracy"],
|
724 |
+
),
|
725 |
+
)
|
726 |
+
best_hyperparameters = best_trial.hyperparameters
|
727 |
+
|
728 |
+
print("Best Hyperparameters:")
|
729 |
+
print(best_hyperparameters)
|
730 |
+
|
731 |
# Performs Inference with model
|
732 |
+
if inference is True:
|
733 |
+
if dataset_split is not None and data_filter is not None:
|
734 |
+
|
735 |
def if_label(example):
|
736 |
+
return example[dataset_split] == data_filter
|
737 |
+
|
738 |
train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
|
739 |
+
|
740 |
trainset_label_shuffled = train_dataset
|
741 |
total_cells = len(trainset_label_shuffled)
|
742 |
+
|
743 |
# loads dictionary of all cell labels model was trained on
|
744 |
+
with open(Path(finetuned_model) / "targets.txt", "r") as f:
|
745 |
data = ast.literal_eval(f.read())
|
746 |
+
if dataset_split is not None and data_filter is None:
|
747 |
indexer = dataset_split.index(data_filter)
|
748 |
data = data[indexer]
|
749 |
+
|
750 |
+
target_dict_list = {key: value for key, value in enumerate(data)}
|
751 |
+
|
752 |
# set logging steps
|
753 |
+
logging_steps = round(len(trainset_label_shuffled) / training_batch_size / 20)
|
754 |
+
|
755 |
+
# load pretrained model
|
756 |
input_ids = trainset_label_shuffled["input_ids"]
|
757 |
inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
|
758 |
attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
|
759 |
+
|
760 |
for i, sentence in enumerate(input_ids):
|
761 |
sentence_length = len(sentence)
|
762 |
if sentence_length <= max_input_size:
|
|
|
765 |
else:
|
766 |
inputs[i, :] = torch.tensor(sentence[:max_input_size])
|
767 |
attention[i, :] = torch.ones(max_input_size)
|
768 |
+
|
769 |
+
model = BertForSequenceClassification.from_pretrained(
|
770 |
+
finetuned_model, num_labels=len(target_dict_list)
|
771 |
+
).to(device)
|
772 |
+
model_outputs = model(inputs.to(device), attention_mask=attention)["logits"]
|
773 |
+
predictions = F.softmax(model_outputs, dim=-1).argmax(-1)
|
774 |
|
775 |
predictions = [target_dict_list[int(pred)] for pred in predictions]
|
776 |
|
777 |
return predictions
|
    # Extracts embeddings from labeled data
    if emb_extract is True:
        if emb_filter is None:
            with open(f"{finetuned_model}/targets.txt", "r") as f:
                data = ast.literal_eval(f.read())
            if dataset_split is not None and data_filter is not None:
                indexer = dataset_split.index(data_filter)
                data = data[indexer]

            target_dict_list = {key: value for key, value in enumerate(data)}
            total_filter = None
        else:
            total_filter = emb_filter

        train_dataset = load_from_disk(dataset)
        if dataset_split is not None:

            def if_label(example):
                return example[dataset_split] == data_filter

            train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)

        # drop classes rarer than the filter_cells fraction of all cells
        label_counter = Counter(train_dataset[label])
        total_cells = sum(label_counter.values())
        cells_to_keep = [
            k for k, v in label_counter.items() if v > (filter_cells * total_cells)
        ]

        def if_not_rare(example):
            return example[label] in cells_to_keep

        train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)

        true_labels = train_dataset[label]
        num_classes = len(list(set(true_labels)))

        embex = EmbExtractor(
            model_type="CellClassifier",
            num_classes=num_classes,
            filter_data=total_filter,
            max_ncells=max_cells,
            emb_layer=emb_layer,
            emb_label=[dataset_split, label],
            labels_to_plot=[label],
            forward_batch_size=inference_batch_size,
            nproc=cpu_cores,
        )

        # example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
        subprocess.call(f"mkdir -p {emb_dir}", shell=True)

        embs = embex.extract_embs(
            model_directory=finetuned_model,
            input_data_file=dataset,
            output_directory=emb_dir,
            output_prefix=f"{label}_embeddings",
        )
        true_labels = embex.filtered_input_data[label]

        # group each cell's embedding by its true class label
        emb_dict = {cell_label: [] for cell_label in list(set(true_labels))}
        for num, emb in embs.iterrows():
            key = emb[label]
            selection = emb.iloc[:255]
            emb = torch.Tensor(selection)
            emb_dict[key].append(emb)

        # average each class's embeddings into a single mean embedding
        for key in list(emb_dict.keys()):
            stack = torch.stack(emb_dict[key], dim=0)
            emb_dict[key] = torch.mean(stack, dim=0)
        similarities = {key: {} for key in list(emb_dict.keys())}

        for key in list(emb_dict.keys()):
            remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
            for k in remaining_keys:
                embedding = emb_dict[k]
                sim = similarity(emb_dict[key], embedding, cosine=True)

                similarities[key][k] = sim

        plot_similarity_heatmap(similarities)

        embex.plot_embs(
            embs=embs,
            plot_style="umap",
            output_directory=emb_dir,
            output_prefix="emb_plot",
        )

        embex.plot_embs(
            embs=embs,
            plot_style="heatmap",
            output_directory=emb_dir,
            output_prefix="emb_plot",
        )

        return similarities
Gene_classifier.py → geneformer/gene_classifier.py
RENAMED
The pre-rename side of this diff mirrors the updated module below almost entirely through import regrouping and line reflowing; the one substantive removal with no counterpart in the new file is the module-level entry point:

if __name__ == '__main__':
    classify_genes(k_validate = False, inference = False, skip_training = False, emb_extract = True, output_dir = Path('gene_emb'), model_location = Path('gene_emb'), epochs = 5, gene_info = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", corpus_30M = "../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/")

The updated geneformer/gene_classifier.py:
import os
import sys

GPU_NUMBER = [0]  # CHANGE WITH MULTIGPU
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"

import ast
import datetime
import math
import pickle
import subprocess
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import Dataset, load_from_disk
from sklearn import preprocessing
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    auc,
    confusion_matrix,
    f1_score,  # used by compute_metrics below; missing from the original import list
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import StratifiedKFold, train_test_split
from tqdm.notebook import tqdm
from transformers import BertForTokenClassification, Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForGeneClassification, EmbExtractor, TranscriptomeTokenizer
from geneformer.pretrainer import token_dictionary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def vote(logit_pair):
    a, b = logit_pair
    if a > b:
        return 0
    elif b > a:
        return 1
    elif a == b:
        return "tie"


def py_softmax(vector):
    e = np.exp(vector)
    return e / e.sum()


# Computes similarity between two embeddings: 0 is perfectly dissimilar and 1 is perfectly similar
def similarity(tensor1, tensor2, cosine=True):
    if cosine == False:
        if tensor1.ndimension() > 1:
            tensor1 = tensor1.view(1, -1)
        if tensor2.ndimension() > 1:
            tensor2 = tensor2.view(1, -1)
        # dot product of the flattened tensors
        dot_product = torch.sum(tensor1 * tensor2)
        norm_tensor1 = torch.norm(tensor1)
        norm_tensor2 = torch.norm(tensor2)
        epsilon = 1e-8
        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
        # rescale from [-1, 1] to [0, 1] and return a plain float
        return (similarity.item() + 1) / 2
    else:
        if tensor1.shape != tensor2.shape:
            raise ValueError("Input tensors must have the same shape.")

        dot_product = torch.dot(tensor1, tensor2)
        norm_tensor1 = torch.norm(tensor1)
        norm_tensor2 = torch.norm(tensor2)

        # Avoid division by zero by adding a small epsilon
        epsilon = 1e-8
        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)

        return similarity.item()


# Plots heatmap between different classes/labels
def plot_similarity_heatmap(similarities):
    classes = list(similarities.keys())
    classlen = len(classes)
    arr = np.zeros((classlen, classlen))
    for i, c in enumerate(classes):
        for j, cc in enumerate(classes):
            if c == cc:
                val = 1.0
            else:
                val = similarities[c][cc]
            arr[i][j] = val

    plt.figure(figsize=(8, 6))
    plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
    plt.colorbar()
    plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
    plt.yticks(np.arange(classlen), classes)
    plt.title("Similarity Heatmap")
    plt.savefig("similarity_heatmap.png")


# get cross-validated mean and sd metrics
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]

    all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
    mean_tpr = np.sum(all_weighted_tpr, axis=0)
    mean_tpr[-1] = 1.0
    all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
    roc_auc = np.sum(all_weighted_roc_auc)
    roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
    return mean_tpr, roc_auc, roc_auc_sd

+
|
124 |
+
def validate(
|
125 |
+
data,
|
126 |
+
targets,
|
127 |
+
labels,
|
128 |
+
nsplits,
|
129 |
+
subsample_size,
|
130 |
+
training_args,
|
131 |
+
freeze_layers,
|
132 |
+
output_dir,
|
133 |
+
num_proc,
|
134 |
+
num_labels,
|
135 |
+
pre_model,
|
136 |
+
):
|
137 |
# initiate eval metrics to return
|
138 |
num_classes = len(set(labels))
|
139 |
mean_fpr = np.linspace(0, 1, 100)
|
140 |
+
|
141 |
# create 80/20 train/eval splits
|
142 |
+
targets_train, targets_eval, labels_train, labels_eval = train_test_split(
|
143 |
+
targets, labels, test_size=0.25, shuffle=True
|
144 |
+
)
|
145 |
label_dict_train = dict(zip(targets_train, labels_train))
|
146 |
label_dict_eval = dict(zip(targets_eval, labels_eval))
|
147 |
+
|
148 |
# function to filter by whether contains train or eval labels
|
149 |
def if_contains_train_label(example):
|
150 |
a = label_dict_train.keys()
|
151 |
+
b = example["input_ids"]
|
152 |
return not set(a).isdisjoint(b)
|
153 |
|
154 |
def if_contains_eval_label(example):
|
155 |
a = label_dict_eval.keys()
|
156 |
+
b = example["input_ids"]
|
157 |
return not set(a).isdisjoint(b)
|
158 |
+
|
159 |
# filter dataset for examples containing classes for this split
|
160 |
print(f"Filtering training data")
|
161 |
trainset = data.filter(if_contains_train_label, num_proc=num_proc)
|
162 |
+
print(
|
163 |
+
f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n"
|
164 |
+
)
|
165 |
print(f"Filtering evalation data")
|
166 |
evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
|
167 |
print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
|
168 |
+
|
169 |
# minimize to smaller training sample
|
170 |
training_size = min(subsample_size, len(trainset))
|
171 |
trainset_min = trainset.select([i for i in range(training_size)])
|
172 |
eval_size = min(training_size, len(evalset))
|
173 |
+
half_training_size = round(eval_size / 2)
|
174 |
evalset_train_min = evalset.select([i for i in range(half_training_size)])
|
175 |
evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
|
176 |
+
|
177 |
# label conversion functions
|
178 |
def generate_train_labels(example):
|
179 |
+
example["labels"] = [
|
180 |
+
label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
|
181 |
+
]
|
182 |
return example
|
183 |
|
184 |
def generate_eval_labels(example):
|
185 |
+
example["labels"] = [
|
186 |
+
label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
|
187 |
+
]
|
188 |
return example
|
189 |
+
|
190 |
+
# label datasets
|
191 |
print(f"Labeling training data")
|
192 |
trainset_labeled = trainset_min.map(generate_train_labels)
|
193 |
print(f"Labeling evaluation data")
|
194 |
evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
|
195 |
print(f"Labeling evaluation OOS data")
|
196 |
evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
|
197 |
+
|
198 |
# load model
|
199 |
model = BertForTokenClassification.from_pretrained(
|
200 |
+
pre_model,
|
201 |
+
num_labels=num_labels,
|
202 |
+
output_attentions=False,
|
203 |
+
output_hidden_states=False,
|
204 |
)
|
205 |
if freeze_layers is not None:
|
206 |
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
207 |
for module in modules_to_freeze:
|
208 |
for param in module.parameters():
|
209 |
param.requires_grad = False
|
210 |
+
|
211 |
model = model.to(device)
|
212 |
+
|
213 |
# add output directory to training args and initiate
|
214 |
training_args["output_dir"] = output_dir
|
215 |
training_args_init = TrainingArguments(**training_args)
|
216 |
+
|
217 |
# create the trainer
|
218 |
trainer = Trainer(
|
219 |
model=model,
|
|
|
222 |
train_dataset=trainset_labeled,
|
223 |
eval_dataset=evalset_train_labeled,
|
224 |
)
|
225 |
+
|
226 |
# train the gene classifier
|
227 |
trainer.train()
|
228 |
trainer.save_model(output_dir)
|
229 |
+
|
230 |
+
fpr, tpr, interp_tpr, conf_mat = classifier_predict(
|
231 |
+
trainer.model, evalset_oos_labeled, 200, mean_fpr
|
232 |
+
)
|
233 |
auc_score = auc(fpr, tpr)
|
234 |
+
|
235 |
return fpr, tpr, auc_score
|
236 |
+
|
237 |
+
|
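The -100 sentinel assigned by generate_train_labels/generate_eval_labels above is the index that PyTorch's cross-entropy loss skips, which is how non-target tokens are excluded from the token-classification objective. A minimal illustration:

import torch
import torch.nn.functional as F

logits = torch.randn(5, 2)                        # 5 tokens, 2 gene classes
labels = torch.tensor([0, -100, 1, -100, -100])   # only tokens 0 and 2 are targets

loss = F.cross_entropy(logits, labels, ignore_index=-100)
# identical to computing the loss over the two labeled tokens only
loss_manual = F.cross_entropy(logits[[0, 2]], labels[[0, 2]])
print(torch.isclose(loss, loss_manual))  # tensor(True)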
# cross-validate gene classifier
def cross_validate(
    data,
    targets,
    labels,
    nsplits,
    subsample_size,
    training_args,
    freeze_layers,
    output_dir,
    num_proc,
    num_labels,
    pre_model,
):
    # check if output directory already written to
    # ensure not overwriting previously saved model
    model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
    # if os.path.isfile(model_dir_test) == True:
    #     raise Exception("Model already saved to this directory.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # initiate eval metrics to return
    num_classes = len(set(labels))
    mean_fpr = np.linspace(0, 1, 100)
    all_tpr = []
    all_roc_auc = []
    all_tpr_wt = []
    label_dicts = []
    confusion = np.zeros((num_classes, num_classes))

    # set up cross-validation splits
    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
    # train and evaluate
    iteration_num = 0
    for train_index, eval_index in tqdm(skf.split(targets, labels)):
        if len(labels) > 500:
            print("early stopping activated due to large # of training examples")
            if iteration_num == 3:
                break

        print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")

        # generate cross-validation splits
        targets_train, targets_eval = targets[train_index], targets[eval_index]
        labels_train, labels_eval = labels[train_index], labels[eval_index]
        label_dict_train = dict(zip(targets_train, labels_train))
        label_dict_eval = dict(zip(targets_eval, labels_eval))
        # record this split's train/eval assignments (flattened into the list)
        label_dicts += (
            iteration_num,
            targets_train,
            targets_eval,
            labels_train,
            labels_eval,
        )

        # functions to filter by whether an example contains train or eval labels
        def if_contains_train_label(example):
            a = label_dict_train.keys()
            b = example["input_ids"]

            return not set(a).isdisjoint(b)

        def if_contains_eval_label(example):
            a = label_dict_eval.keys()
            b = example["input_ids"]

            return not set(a).isdisjoint(b)

        # filter dataset for examples containing classes for this split
        print("Filtering training data")
        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
        print(
            f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n"
        )
        print("Filtering evaluation data")
        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
        print(
            f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n"
        )

        # minimize to smaller training sample
        training_size = min(subsample_size, len(trainset))
        trainset_min = trainset.select([i for i in range(training_size)])
        eval_size = min(training_size, len(evalset))
        half_training_size = round(eval_size / 2)
        evalset_train_min = evalset.select([i for i in range(half_training_size)])
        evalset_oos_min = evalset.select(
            [i for i in range(half_training_size, eval_size)]
        )

        # label conversion functions
        def generate_train_labels(example):
            example["labels"] = [
                label_dict_train.get(token_id, -100)
                for token_id in example["input_ids"]
            ]
            return example

        def generate_eval_labels(example):
            example["labels"] = [
                label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
            ]
            return example

        # label datasets
        print("Labeling training data")
        trainset_labeled = trainset_min.map(generate_train_labels)
        print("Labeling evaluation data")
        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
        print("Labeling evaluation OOS data")
        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

        # create output directories
        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")

        # ensure not overwriting previously saved model
        model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
        # if os.path.isfile(model_output_file) == True:
        #     raise Exception("Model already saved to this directory.")

        # make training and model output directories
        subprocess.call(f"mkdir -p {ksplit_output_dir}", shell=True)
        subprocess.call(f"mkdir -p {ksplit_model_dir}", shell=True)

        # load model
        model = BertForTokenClassification.from_pretrained(
            pre_model,
            num_labels=num_labels,
            output_attentions=False,
            output_hidden_states=False,
        )
        if freeze_layers is not None:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False

        model = model.to(device)

        # add output directory to training args and initiate
        training_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**training_args)

        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled,
        )

        # train the gene classifier
        trainer.train()

        # save model
        trainer.save_model(ksplit_model_dir)

        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(
            trainer.model, evalset_oos_labeled, 200, mean_fpr
        )

        # append to tpr and roc lists
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # append number of eval examples by which to weight tpr in averaged graphs
        all_tpr_wt.append(len(tpr))

        iteration_num = iteration_num + 1

    # get overall metrics for cross-validation
    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(
        all_tpr, all_roc_auc, all_tpr_wt
    )
    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts

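cross_validate relies on StratifiedKFold, so every fold preserves the class balance of the gene labels. A small sketch of the same call pattern on toy data:

import numpy as np
from sklearn.model_selection import StratifiedKFold

targets = np.arange(12)                                    # token ids standing in for genes
labels = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1])    # 2:1 class imbalance

skf = StratifiedKFold(n_splits=4, random_state=0, shuffle=True)
for train_index, eval_index in skf.split(targets, labels):
    # each eval fold holds 2 genes of class 0 and 1 of class 1
    print(labels[eval_index])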
# Computes metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's functions
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average="macro")

    return {"accuracy": acc, "macro_f1": macro_f1}

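compute_metrics receives an object shaped like Hugging Face's EvalPrediction, with logits in .predictions and true ids in .label_ids. A self-contained check with a stand-in namespace object:

import numpy as np
from types import SimpleNamespace
from sklearn.metrics import accuracy_score, f1_score

pred = SimpleNamespace(
    predictions=np.array([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0]]),  # made-up logits
    label_ids=np.array([0, 1, 1]),
)
preds = pred.predictions.argmax(-1)           # -> [0, 1, 0]
print(accuracy_score(pred.label_ids, preds))  # 0.666...
print(f1_score(pred.label_ids, preds, average="macro"))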
# plot ROC curve
def plot_ROC(bundled_data, title):
    plt.figure()
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        plt.plot(
            mean_fpr,
            mean_tpr,
            color=color,
            lw=lw,
            label=r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(
                sample, roc_auc, roc_auc_sd
            ),
        )

    plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend(loc="lower right")
    plt.savefig("ROC.png")

    return mean_fpr, mean_tpr, roc_auc


# plot confusion matrix
def plot_confusion_matrix(classes_list, conf_mat, title):
    display_labels = []
    i = 0
    for label in classes_list:
        display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:, i]))]
        i = i + 1
    display = ConfusionMatrixDisplay(
        confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
        display_labels=display_labels,
    )
    display.plot(cmap="Blues", values_format=".2g")
    plt.title(title)
    plt.savefig("CM.png")


# Find the largest number smaller than or equal to N that is divisible by K
def find_largest_div(N, K):
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem

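Usage of the helper above: trimming an evaluation set of 1001 examples at a forward batch size of 200 drops the lone trailing example that would otherwise form a size-1 batch.

assert find_largest_div(1001, 200) == 1000  # 1001 % 200 == 1, so trim to 1000
assert find_largest_div(1000, 200) == 1000  # already divisible, unchanged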
def preprocess_classifier_batch(cell_batch, max_len):
    if max_len == None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])

    def pad_label_example(example):
        example["labels"] = np.pad(
            example["labels"],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=-100,
        )
        example["input_ids"] = np.pad(
            example["input_ids"],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=token_dictionary.get("<pad>"),
        )
        example["attention_mask"] = (
            example["input_ids"] != token_dictionary.get("<pad>")
        ).astype(int)
        return example

    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch


# forward batch size is batch size for model inference (e.g. 200)
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    predict_logits = []
    predict_labels = []
    model.to("cpu")
    model.eval()

    # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible

    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])

    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i + forward_batch_size, evalset_len)
        batch_evalset = evalset.select([i for i in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")

        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            input_ids = input_data_batch
            attn_mask = attn_msk_batch
            labels = label_batch
            outputs = model(
                input_ids=input_ids, attention_mask=attn_mask, labels=labels
            )
        predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
        predict_labels += [torch.squeeze(label_batch.to("cpu"))]

    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    # keep only positions with real labels (padding and non-targets carry -100)
    logit_label_paired = [
        item
        for item in list(zip(all_logits.tolist(), all_labels.tolist()))
        if item[1] != -100
    ]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    # probability of class 1 for the ROC curve
    y_score = [py_softmax(item)[1] for item in logits_list]
    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC")
    plt.show()
    # interpolate to graph
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat

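After the forward passes, classifier_predict reduces each kept token to a hard prediction via vote and to a class-1 score via a softmax. A self-contained sketch of that post-processing on made-up logit pairs (the helper definitions mirror the ones above):

import numpy as np

def vote(logit_pair):
    a, b = logit_pair
    if a > b:
        return 0
    elif b > a:
        return 1
    elif a == b:
        return "tie"

def py_softmax(vector):
    e = np.exp(vector)
    return e / e.sum()

logits = [[2.0, -1.0], [0.3, 1.1]]
print([vote(p) for p in logits])                     # [0, 1]
print([round(py_softmax(p)[1], 3) for p in logits])  # [0.047, 0.69]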
572 |
+
def classify_genes(
|
573 |
+
gene_info="Genecorpus-30M/example_input_files/gene_info_table.csv",
|
574 |
+
genes="Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
|
575 |
+
corpus_30M="Genecorpus-30M/genecorpus_30M_2048.dataset/",
|
576 |
+
model=".",
|
577 |
+
max_input_size=2**11,
|
578 |
+
max_lr=5e-5,
|
579 |
+
freeze_layers=4,
|
580 |
+
num_gpus=1,
|
581 |
+
num_proc=os.cpu_count(),
|
582 |
+
geneformer_batch_size=9,
|
583 |
+
epochs=1,
|
584 |
+
filter_dataset=50_000,
|
585 |
+
emb_extract=True,
|
586 |
+
emb_layer=0,
|
587 |
+
forward_batch=200,
|
588 |
+
filter_data=None,
|
589 |
+
inference=False,
|
590 |
+
k_validate=True,
|
591 |
+
model_location="230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
|
592 |
+
skip_training=False,
|
593 |
+
emb_dir="gene_emb",
|
594 |
+
output_dir=None,
|
595 |
+
max_cells=1000,
|
596 |
+
num_cpus=os.cpu_count(),
|
597 |
+
):
|
598 |
+
""" "
|
599 |
Primary Parameters
|
600 |
-----------
|
601 |
+
|
602 |
gene_info: path
|
603 |
Path to gene mappings
|
604 |
+
|
605 |
corpus_30M: path
|
606 |
Path to 30M Gene Corpus
|
607 |
+
|
608 |
model: path
|
609 |
Path to pretrained GeneFormer model
|
610 |
+
|
611 |
genes: path
|
612 |
Path to csv file containing different columns of genes and the column labels
|
613 |
+
|
614 |
inference: bool
|
615 |
Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
|
616 |
+
|
617 |
k_validate: bool
|
618 |
Whether the model should run k-fold validation or simply perform regular training/evaluate. Defaults to True
|
619 |
+
|
620 |
skip_training: bool
|
621 |
Whether the model should skip the training portion. Defaults to False
|
622 |
+
|
623 |
emb_extract: bool
|
624 |
WHether the model should extract embeddings for a given gene (WIP)
|
625 |
+
|
626 |
+
|
627 |
Customization Parameters
|
628 |
-----------
|
629 |
+
|
630 |
freeze_layers: int
|
631 |
Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)
|
632 |
+
|
633 |
filter_dataset: int
|
634 |
Number of cells to filter from 30M dataset. Default is 50_000
|
635 |
+
|
636 |
emb_layer: int
|
637 |
What layer embeddings are extracted from. Default is 4
|
638 |
+
|
639 |
filter_data: str, list
|
640 |
Filters down embeddings to a single category. Default is None
|
641 |
+
|
642 |
+
|
643 |
"""
|
644 |
+
|
645 |
# table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
|
646 |
gene_info = pd.read_csv(gene_info, index_col=0)
|
647 |
labels = gene_info.columns
|
648 |
|
649 |
# create dictionaries for corresponding attributes
|
650 |
+
gene_id_type_dict = dict(zip(gene_info["ensembl_id"], gene_info["gene_type"]))
|
651 |
+
gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"]))
|
652 |
+
gene_id_name_dict = {v: k for k, v in gene_name_id_dict.items()}
|
653 |
|
654 |
# function for preparing targets and labels
|
655 |
def prep_inputs(label_store, id_type):
|
656 |
target_list = []
|
657 |
if id_type == "gene_name":
|
658 |
for key in list(label_store.keys()):
|
659 |
+
targets = [
|
660 |
+
gene_name_id_dict[gene]
|
661 |
+
for gene in label_store[key]
|
662 |
+
if gene_name_id_dict.get(gene) in token_dictionary
|
663 |
+
]
|
664 |
targets_id = [token_dictionary[gene] for gene in targets]
|
665 |
target_list.append(targets_id)
|
666 |
elif id_type == "ensembl_id":
|
667 |
for key in list(label_store.keys()):
|
668 |
+
targets = [
|
669 |
+
gene for gene in label_store[key] if gene in token_dictionary
|
670 |
+
]
|
671 |
targets_id = [token_dictionary[gene] for gene in targets]
|
672 |
target_list.append(targets_id)
|
673 |
+
|
674 |
targets, labels = [], []
|
675 |
for targ in target_list:
|
676 |
targets = targets + targ
|
677 |
targets = np.array(targets)
|
678 |
for num, targ in enumerate(target_list):
|
679 |
+
label = [num] * len(targ)
|
680 |
labels = labels + label
|
681 |
labels = np.array(labels)
|
682 |
unique_labels = num + 1
|
683 |
+
|
684 |
+
nsplits = min(5, min([len(targ) for targ in target_list]) - 1)
|
685 |
assert nsplits > 2
|
686 |
+
|
687 |
return targets, labels, nsplits, unique_labels
|
688 |
+
|
689 |
if skip_training == False:
|
690 |
# preparing targets and labels for dosage sensitive vs insensitive TFs
|
691 |
gene_classes = pd.read_csv(genes, header=0)
|
|
|
697 |
else:
|
698 |
labels = [filter_data]
|
699 |
label_store = {}
|
700 |
+
|
701 |
# Dictionary for decoding labels
|
702 |
+
decode = {i: labels[i] for i in range(len(labels))}
|
703 |
+
|
704 |
for label in labels:
|
705 |
label_store[label] = gene_classes[label].dropna()
|
706 |
+
|
707 |
targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")
|
708 |
+
|
|
|
|
|
709 |
# load training dataset
|
710 |
+
train_dataset = load_from_disk(corpus_30M)
|
711 |
shuffled_train_dataset = train_dataset.shuffle(seed=42)
|
712 |
+
subsampled_train_dataset = shuffled_train_dataset.select(
|
713 |
+
[i for i in range(filter_dataset)]
|
714 |
+
)
|
715 |
lr_schedule_fn = "linear"
|
716 |
warmup_steps = 500
|
717 |
optimizer = "adamw"
|
718 |
subsample_size = 10_000
|
719 |
+
|
720 |
training_args = {
|
721 |
"learning_rate": max_lr,
|
722 |
"do_train": True,
|
|
|
733 |
"per_device_eval_batch_size": geneformer_batch_size,
|
734 |
"num_train_epochs": epochs,
|
735 |
}
|
736 |
+
|
737 |
# define output directory path
|
738 |
current_date = datetime.datetime.now()
|
739 |
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
740 |
+
|
741 |
if output_dir == None:
|
742 |
+
training_output_dir = Path(
|
743 |
+
f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/"
|
744 |
+
)
|
745 |
else:
|
746 |
training_output_dir = Path(output_dir)
|
747 |
+
|
748 |
# make output directory
|
749 |
+
subprocess.call(f"mkdir -p {training_output_dir}", shell=True)
|
750 |
+
|
751 |
# Places number of classes + in directory
|
752 |
num_classes = len(set(labels))
|
753 |
info_list = [num_classes, decode]
|
754 |
+
|
755 |
+
with open(training_output_dir / "classes.txt", "w") as f:
|
756 |
f.write(str(info_list))
|
757 |
+
|
758 |
+
subsampled_train_dataset.save_to_disk(output_dir / "dataset")
|
759 |
+
|
760 |
if k_validate == True:
|
761 |
+
ksplit_model = "ksplit0/models"
|
762 |
ksplit_model_test = os.path.join(training_output_dir, ksplit_model)
|
763 |
+
# if os.path.isfile(ksplit_model_test) == True:
|
764 |
# raise Exception("Model already saved to this directory.")
|
765 |
+
# cross-validate gene classifier
|
766 |
+
(
|
767 |
+
all_roc_auc,
|
768 |
+
roc_auc,
|
769 |
+
roc_auc_sd,
|
770 |
+
mean_fpr,
|
771 |
+
mean_tpr,
|
772 |
+
confusion,
|
773 |
+
label_dicts,
|
774 |
+
) = cross_validate(
|
775 |
+
subsampled_train_dataset,
|
776 |
+
targets,
|
777 |
+
labels,
|
778 |
+
nsplits,
|
779 |
+
subsample_size,
|
780 |
+
training_args,
|
781 |
+
freeze_layers,
|
782 |
+
training_output_dir,
|
783 |
+
1,
|
784 |
+
unique_labels,
|
785 |
+
model,
|
786 |
+
)
|
787 |
+
|
788 |
bundled_data = []
|
789 |
+
bundled_data += [
|
790 |
+
(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")
|
791 |
+
]
|
792 |
+
graph_title = " ".join(
|
793 |
+
[
|
794 |
+
i + " vs" if count < len(label_store) - 1 else i
|
795 |
+
for count, i in enumerate(label_store)
|
796 |
+
]
|
797 |
+
)
|
798 |
+
fpr, tpr, auc = plot_ROC(
|
799 |
+
bundled_data, "Dosage Sensitive vs Insensitive TFs"
|
800 |
+
)
|
801 |
print(auc)
|
802 |
# plot confusion matrix
|
803 |
plot_confusion_matrix(label_store, confusion, "Geneformer")
|
    else:
        fpr, tpr, auc = validate(
            subsampled_train_dataset,
            targets,
            labels,
            nsplits,
            subsample_size,
            training_args,
            freeze_layers,
            training_output_dir,
            1,
            unique_labels,
            model,
        )
        print(auc)

    if inference:
        # preparing targets and labels for dosage sensitive vs insensitive TFs
        gene_classes = pd.read_csv(genes, header=0)
        targets = []
        for column in gene_classes.columns:
            targets += list(gene_classes[column])
        tokens = []
        for target in targets:
            try:
                tokens.append(token_dictionary[target])
            except KeyError:
                tokens.append(0)

        targets = torch.LongTensor([tokens])
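        # targets is a single "batch" of shape (1, n_genes); genes missing from the
        # token dictionary fall back to token id 0 (assumed to be the pad token).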

        with open(f"{model_location}/classes.txt", "r") as f:
            info_list = ast.literal_eval(f.read())
        num_classes = info_list[0]
        labels = info_list[1]

        model = BertForTokenClassification.from_pretrained(
            model_location,
            num_labels=num_classes,
            output_attentions=False,
            output_hidden_states=False,
            local_files_only=True,
        )
        if freeze_layers is not None:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False

        model = model.to(device)

        # evaluate model
        predictions = F.softmax(model(targets.to(device))["logits"], dim=-1).argmax(-1)[
            0
        ]
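        # BertForTokenClassification logits have shape (batch, seq_len, num_labels);
        # softmax + argmax yields one predicted class per input gene token, and the
        # trailing [0] drops the singleton batch dimension.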
        predictions = [labels[int(pred)] for pred in predictions]

        return predictions

    # Extracts aggregate gene embeddings for each label
    if emb_extract:
        with open(f"{model_location}/classes.txt", "r") as f:
            data = ast.literal_eval(f.read())
        num_classes = data[0]
        decode = data[1]

        gene_classes = pd.read_csv(genes, header=0)
        labels = gene_classes.columns
        tokenize = TranscriptomeTokenizer()

        label_dict = {}
        for label in labels:
            genes = gene_classes[label]

            except:
                continue
            label_dict[label] = tokenized_genes

        embex = EmbExtractor(
            model_type="GeneClassifier",
            num_classes=num_classes,
            emb_mode="gene",
            filter_data=None,
            max_ncells=max_cells,
            emb_layer=emb_layer,
            emb_label=label_dict,
            labels_to_plot=list(labels),
            forward_batch_size=forward_batch,
            nproc=num_cpus,
        )
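        # EmbExtractor comes from the geneformer package; with emb_mode="gene" and
        # emb_label=label_dict it is assumed to return one aggregate embedding
        # column per label, which is how embs is indexed below.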

        subprocess.call(f"mkdir -p {emb_dir}", shell=True)

        embs = embex.extract_embs(
            model_directory=model_location,
            input_data_file=model_location / "dataset",
            output_directory=emb_dir,
            output_prefix=f"{label}_embeddings",
        )

        emb_dict = {label: [] for label in list(set(labels))}
        similarities = {key: {} for key in list(emb_dict.keys())}

        for column in embs.columns:
            remaining_cols = [k for k in embs.columns if k != column]
            for k in remaining_cols:
                embedding = torch.Tensor(embs[k])
                sim = similarity(torch.Tensor(embs[column]), embedding, cosine=True)
                similarities[column][k] = sim

        plot_similarity_heatmap(similarities)
        print(similarities)

        return similarities
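
# `similarity` above is assumed to collapse two embedding tensors into one score;
# a minimal cosine version (a sketch, not necessarily this module's helper):
#
#     def similarity(a, b, cosine=True):
#         if cosine:
#             return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
#         return -torch.dist(a, b).item()  # negated euclidean distance otherwise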


if __name__ == "__main__":
    classify_genes(
        k_validate=False,
        inference=False,
        skip_training=False,
        emb_extract=True,
        output_dir=Path("gene_emb"),
        model_location=Path("gene_emb"),
        epochs=5,
        gene_info="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv",
        genes="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
        corpus_30M="../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/",
    )
Modular_usage.md → geneformer/modular_classifier_usage.md
RENAMED
@@ -1,33 +1,33 @@
# Cell classifier
def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
                   dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
                   dataset_split = None,
                   filter_cells = .005,
                   epochs = 1,
                   cpu_cores = os.cpu_count(),
                   geneformer_batch_size = 12,
                   optimizer = 'adamw',
                   max_lr = 5e-5,
                   num_gpus = torch.cuda.device_count(),
                   max_input_size = 2 ** 11,
                   lr_schedule_fn = "linear",
                   warmup_steps = 500,
                   freeze_layers = 0,
                   emb_extract = False,
                   max_cells = 1000,
                   emb_layer = 0,
                   emb_filter = None,
                   emb_dir = 'embeddings',
                   overwrite = True,
                   label = "cell_type",
                   data_filter = None,
                   forward_batch = 200, model_location = None,
                   skip_training = False,
                   sample_data = 1,
                   inference = False,
                   optimize_hyperparameters = False,
                   output_dir = None):

    '''
    Primary Parameters
    -------------------
@@ -36,121 +36,121 @@ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_s

    model_location: path
        Path to the location of an existing model to use for inference and embedding extraction

    pretrained_model: path
        Path to the pretrained GeneFormer 30M model before fine-tuning

    inference: bool
        Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False

    skip_training: bool
        Chooses whether to skip training the model. Defaults to False

    emb_extract: bool
        Chooses whether to extract embeddings and calculate similarities. Defaults to True

    optimize_hyperparameters: bool
        Chooses whether to optimize model hyperparameters. Defaults to False

    label: string
        The label string in the formatted dataset that contains the true class labels. Defaults to "label"

    Customization Parameters
    -------------------

    dataset_split: str
        How the dataset should be partitioned (if at all), and what ID should be used for partitioning

    data_filter: list
        (For embeddings and inference) Runs the analysis on subsets of the dataset, selected by the ID defined by dataset_split

    label: str
        What feature should be read as a classification label

    emb_layer: int
        What layer embeddings should be extracted and compared from.

    emb_filter: ['cell1', 'cell2'...]
        Allows the user to narrow down the range of cells that embeddings will be extracted from.

    max_cells: int
        How many cell embeddings should be extracted.

    freeze_layers: int
        Number of layers that should be permanently frozen during fine-tuning (starting from the first layer; 4 brings it up to the pretrained model).

    sample_data: float
        What proportion of the HF dataset should be used

    '''
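
A minimal training call (a sketch: the dataset path below is the repository's example dataset, and all other arguments keep their defaults):

    finetune_cells(
        dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
        label = "cell_type",
        epochs = 1,
        output_dir = 'cell_ft',
    )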

# Gene Classifier
def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv",
                   genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
                   corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
                   max_input_size = 2 ** 11,
                   max_lr = 5e-5,
                   freeze_layers = 4,
                   num_gpus = 1,
                   num_proc = os.cpu_count(),
                   geneformer_batch_size = 9,
                   epochs = 1,
                   filter_dataset = 50_000,
                   emb_extract = True,
                   emb_layer = 0,
                   forward_batch = 200,
                   filter_data = None,
                   inference = False,
                   k_validate = True,
                   model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
                   skip_training = False,
                   emb_dir = 'gene_emb',
                   output_dir = None,
                   max_cells = 1000,
                   num_cpus = os.cpu_count()):

    """
    Primary Parameters
    -----------

    gene_info: path
        Path to gene mappings

    corpus_30M: path
        Path to the 30M gene corpus

    model: path
        Path to the pretrained GeneFormer model

    genes: path
        Path to a csv file containing different columns of genes and the column labels

    inference: bool
        Whether the model should be used to run inference. If False, the model will train with labeled data instead. Defaults to False

    k_validate: bool
        Whether the model should run k-fold validation or simply perform regular training/evaluation. Defaults to True

    skip_training: bool
        Whether the model should skip the training portion. Defaults to False

    emb_extract: bool
        Whether the model should extract embeddings for a given gene (WIP)

    Customization Parameters
    -----------

    freeze_layers: int
        Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)

    filter_dataset: int
        Number of cells to filter from the 30M dataset. Default is 50_000

    emb_layer: int
        What layer embeddings are extracted from. Default is 0

    filter_data: str, list
        Filters down embeddings to a single category. Default is None

    """
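
Example call, mirroring the `__main__` block of geneformer/gene_classifier.py (the default Genecorpus-30M paths are assumed):

    classify_genes(
        k_validate = False,
        skip_training = False,
        emb_extract = True,
        output_dir = Path("gene_emb"),
        model_location = Path("gene_emb"),
        epochs = 5,
    )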