tigerdeF committed on
Commit
e562c0c
1 Parent(s): dbfac41

Upload 15 files


Deployable version of GeneFormer gene/cell classification and embedding extraction in a single function. Function parameters are explained in the markdown file, and example usage is at the bottom of each Python file. Let me know if anything is needed or if there are unresolved issues, and I can get to fixing them!
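For reference, a minimal sketch of how the two entry points might be called once the files sit next to a tokenized dataset; the dataset/CSV paths and output directory below are placeholders, not files included in this commit.

    from Cell_classifier import finetune_cells
    from Gene_classifier import classify_genes

    # fine-tune a cell classifier, then extract embeddings and class-to-class similarities
    sims = finetune_cells(dataset = 'path/to/cells.dataset', label = 'cell_type', epochs = 1,
                          output_dir = 'my_model', model_location = 'my_model', emb_extract = True)

    # train and evaluate a gene classifier from a CSV of labeled gene columns
    classify_genes(genes = 'path/to/gene_labels.csv', corpus_30M = 'path/to/genecorpus.dataset', epochs = 1)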

Cell_classifier.py ADDED
@@ -0,0 +1,861 @@
1
+ # Package Imports
2
+ import tqdm
3
+ import sys
4
+ import polars as pl
5
+ import pysam
6
+ import os
7
+ from datasets import Dataset
8
+ from collections import Counter
9
+ import random
10
+ import datetime
11
+ from pathlib import Path
12
+ import subprocess
13
+ import seaborn as sns; sns.set()
14
+ from datasets import load_from_disk
15
+ import fastcluster
16
+ from sklearn.metrics import accuracy_score, f1_score
17
+ from transformers import BertForSequenceClassification
18
+ from transformers import Trainer
19
+ from transformers.training_args import TrainingArguments
20
+ from geneformer import DataCollatorForCellClassification, EmbExtractor
21
+ import pickle
22
+ from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve
23
+ from sklearn.metrics import auc as precision_auc
24
+ from sklearn.preprocessing import label_binarize
25
+ import pyarrow as pa
26
+ import concurrent.futures
27
+ from matplotlib import pyplot as plt
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from scipy.stats import ranksums
31
+ import ray
32
+ import ast
33
+ from ray import tune
34
+ from ray.tune import ExperimentAnalysis
35
+ from ray.tune.search.hyperopt import HyperOptSearch
36
+ import numpy as np
+ import requests # used for the Ensembl REST lookups in tokenize_dataset
+ import mygene # used for the GO-term lookups in tokenize_dataset
+ from multiprocessing.pool import ThreadPool as Pool # thread pool, since tokenize_dataset maps nested helper functions over network-bound requests
37
+
38
+ # Properly sets up NCCL environment
39
+ GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
40
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
41
+ os.environ["NCCL_DEBUG"] = "INFO"
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+
44
+ # Function for generating a ROC curve from data
45
+ def ROC(prediction, truth, type = 'GeneFormer', label = ''):
46
+
47
+ fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
48
+ auc = roc_auc_score(truth, prediction[:, 1])
49
+ print(f'{type} AUC: {auc}')
50
+ plt.plot(fpr,tpr, label="AUC="+str(auc))
51
+ plt.ylabel('True Positive Rate')
52
+ plt.xlabel('False Positive Rate')
53
+ plt.title(f'{label} ROC Curve')
54
+ plt.legend(loc=4)
55
+ plt.savefig('ROC.png')
56
+
57
+ return tpr, fpr, auc
58
+
59
+ # Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
+ def similarity(tensor1, tensor2, cosine = False):
+
+     if cosine == False:
+         # flatten both tensors so the dot product is well-defined for any input shape
+         tensor1 = tensor1.view(-1)
+         tensor2 = tensor2.view(-1)
+         dot_product = torch.matmul(tensor1, tensor2)
+         norm_tensor1 = torch.norm(tensor1)
+         norm_tensor2 = torch.norm(tensor2)
+         epsilon = 1e-8
+         similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+         # rescale from [-1, 1] to [0, 1]
+         similarity = (similarity + 1) / 2
+     else:
+         if tensor1.shape != tensor2.shape:
+             raise ValueError("Input tensors must have the same shape.")
+
+         # Compute cosine similarity using PyTorch's dot product function
+         dot_product = torch.dot(tensor1, tensor2)
+         norm_tensor1 = torch.norm(tensor1)
+         norm_tensor2 = torch.norm(tensor2)
+
+         # Avoid division by zero by adding a small epsilon
+         epsilon = 1e-8
+         similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+
+     return similarity.item()
87
+
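As a quick sanity check of the similarity scale described above (an illustrative aside, not part of the uploaded file; printed values are approximate):

    a = torch.tensor([1.0, 0.0])
    b = torch.tensor([0.0, 1.0])
    print(similarity(a, a, cosine = True))  # ~1.0, identical embeddings
    print(similarity(a, b, cosine = True))  # ~0.0, orthogonal embeddings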
88
+ # Plots heatmap between different classes/labels
89
+ def plot_similarity_heatmap(similarities):
90
+ classes = list(similarities.keys())
91
+ classlen = len(classes)
92
+ arr = np.zeros((classlen, classlen))
93
+ for i, c in enumerate(classes):
94
+ for j, cc in enumerate(classes):
95
+ if cc == c:
96
+ val = 1.0
97
+ else:
98
+ val = similarities[c][cc]
99
+ arr[i][j] = val
100
+
101
+ plt.figure(figsize=(8, 6))
102
+ plt.imshow(arr, cmap='inferno', vmin=0, vmax=1)
103
+ plt.colorbar()
104
+ plt.xticks(np.arange(classlen), classes, rotation = 45, ha = 'right')
105
+ plt.yticks(np.arange(classlen), classes)
106
+ plt.title("Similarity Heatmap")
107
+ plt.savefig("similarity_heatmap.png")
108
+
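For illustration, the heatmap function expects a nested dict of pairwise class similarities like the one finetune_cells returns; the class names and values here are made up:

    similarities = {'hcm': {'dcm': 0.93, 'nf': 0.88},
                    'dcm': {'hcm': 0.93, 'nf': 0.90},
                    'nf':  {'hcm': 0.88, 'dcm': 0.90}}
    plot_similarity_heatmap(similarities)  # writes similarity_heatmap.png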
109
+ # Function for tokenizing genes into ranked-value encodings from Geneformer
110
+ def tokenize_dataset(gene_set, type = None, token_set = 'token_dictionary.pkl', species = 'human'):
111
+ token_dataset = open(token_set, 'rb')
112
+ token_dict = pickle.load(token_dataset)
113
+ wrap = True
114
+
115
+ if isinstance(gene_set[0], list) == False:
116
+ gene_set = [gene_set]
117
+ wrap = False
118
+
119
+ pool = Pool()
120
+ converted_set = []
121
+
122
+ def process_gene(gene):
123
+ api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{gene}?object_type=gene"
124
+ response = requests.get(api_url, headers={"Content-Type": "application/json"})
125
+ try:
126
+ data = response.json()
127
+ gene = data[0]['id']
128
+ except:
129
+ gene = None
130
+ return gene
131
+
132
+ def process_hgnc(gene):
+     api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{gene}?object_type=gene"
+     response = requests.get(api_url, headers={"Content-Type": "application/json"})
+     try:
+         data = response.json()
+         gene = data[0]['id']
+     except:
+         gene = None
+     return gene
142
+
143
+ def process_go(gene):
144
+ mg = mygene.MyGeneInfo()
145
+ results = mg.query(gene, scopes="go", species=species, fields="ensembl.gene")
146
+
147
+ ensembl_ids = []
148
+ max_score = 0
149
+ for hit_num, hit in enumerate(results["hits"]):
150
+ if hit['_score'] > max_score:
151
+ max_score = hit['_score']
152
+ chosen_hit = hit
153
+ try:
154
+ try:
155
+ gene = chosen_hit["ensembl"]["gene"]
156
+ except:
157
+ gene = chosen_hit["ensembl"][0]["gene"]
158
+ except:
159
+ gene = None
160
+ return gene
161
+
162
+ if type == None or type.upper() == 'ENSEMBL':
163
+ converted_set = gene_set
164
+ elif type.upper() == 'GENE':
165
+ for genes in gene_set:
166
+ converted_genes = []
167
+ for result in tqdm.tqdm(pool.imap(process_gene, genes), total = len(genes)):
168
+ converted_genes.append(result)
169
+ converted_set.append(converted_genes)
170
+ elif type.upper() == 'GO':
171
+ for genes in gene_set:
172
+ converted_genes = []
173
+ for result in tqdm.tqdm(pool.imap(process_go, genes), total = len(genes)):
174
+ converted_genes.append(result)
175
+ converted_set.append(converted_genes)
176
+ elif type.upper() == 'HGNC':
177
+ for genes in gene_set:
178
+ converted_genes = []
179
+ for result in tqdm.tqdm(pool.imap(process_hgnc, genes), total = len(genes)):
180
+ converted_genes.append(result)
181
+ converted_set.append(converted_genes)
182
+
183
+ Chembl = []
184
+ for set_num, set in enumerate(converted_set):
185
+ Chembl.append([])
186
+ for gene in set:
187
+ if gene == None:
188
+ Chembl[set_num].append(None)
189
+ else:
190
+ try:
191
+ Chembl[set_num].append(token_dict[gene])
192
+ except:
193
+ print(f'{gene} not found in tokenized dataset!')
194
+ Chembl[set_num].append(None)
195
+
196
+ if wrap == False:
197
+ Chembl = Chembl[0]
198
+
199
+ return Chembl
200
+
201
+
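A sketch of calling the tokenizer helper above; the gene symbols are placeholders, and the call assumes a local token_dictionary.pkl plus network access for the Ensembl lookups:

    tokens = tokenize_dataset(['TP53', 'BRCA1'], type = 'GENE')
    print(tokens)  # Geneformer token ids, with None for genes missing from the dictionary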
202
+ # '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/'
203
+ # '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/'
204
+ '''
205
+ ======================================================
206
+
207
+ PRIMARY CELL CLASSIFIER AND EMBEDDING EXTRACTOR FUNCTION
208
+
209
+ +++++++++++++++++++++++++++++++++++++++++++++++++++++++
210
+
211
+ Runs cell-level classification and embedding extraction with Geneformer
212
+
213
+ '''
214
+
215
+ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
216
+ dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/', dataset_split = None, filter_cells = .005, epochs = 1, cpu_cores = os.cpu_count(), geneformer_batch_size = 12, optimizer = 'adamw', max_lr = 5e-5, num_gpus = torch.cuda.device_count(), max_input_size = 2 ** 11, lr_schedule_fn = "linear", warmup_steps = 500, freeze_layers = 0, emb_extract = False, max_cells = 1000, emb_layer = 0, emb_filter = None, emb_dir = 'embeddings', overwrite = True, label = "cell_type", data_filter = None,
217
+ forward_batch = 200, model_location = None, skip_training = False, sample_data = 1, inference = False, optimize_hyperparameters = False, output_dir = None):
218
+
219
+ '''
220
+ Primary Parameters
221
+ -------------------
222
+ dataset: path
223
+ Path to fine-tuning/testing dataset for training
224
+
225
+ model_location: path
226
+ Path to location of existing model to use for inference and embedding extraction
227
+
228
+ pretrained_model: path
229
+ Path to pretrained GeneFormer 30M model before fine-tuning
230
+
231
+ inference: bool
232
+ Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False
233
+
234
+ skip_training: bool
235
+ Chooses whether to skip training the model. Defaults to False
236
+
237
+ emb_extract: bool
238
+ Choose whether to extract embeddings and calculate similarities. Defaults to False
239
+
240
+ optimize_hyperparameters: bool
241
+ Choose whether to optimize model hyperparameters. Defaults to False
242
+
243
+
244
+ Customization Parameters
245
+ -------------------
246
+
247
+ dataset_split: str
248
+ How the dataset should be partitioned (if at all), and what ID should be used for partitioning
249
+
250
+ data_filter: list
251
+ (For embeddings and inference) Runs analysis subsets of the dataset by the ID defined by dataset_split
252
+
253
+ label: str
254
+ What feature should be read as a classification label
255
+
256
+ emb_layer: int
257
+ What layer embeddings should be extracted and compared from.
258
+
259
+ emb_filter: ['cell1', 'cell2'...]
260
+ Allows user to narrow down range of cells that embeddings will be extracted from.
261
+
262
+ max_cells: int
263
+ How many embeddings from cells should be extracted.
264
+
265
+ freeze_layers: int
266
+ Number of layers that should be permanently frozen during fine-tuning (starting from the first layer; 4 brings it up to the pretrained model).
267
+
268
+ sample_data: float
269
+ What proportion of the HF dataset should be used
270
+
271
+ '''
272
+
273
+ dataset_list = []
274
+ evalset_list = []
275
+ split_list = []
276
+ target_dict_list = []
277
+
278
+ '''
279
+ For loading and pretraining with custom median expressions and/or custom gene conversions
280
+ -------------------------------------------------------------
281
+
282
+ token set: path
283
+ Path to token conversion dictionary
284
+
285
+ median set: path
286
+ Path to median gene dictionary (ensembl IDs as the keys)
287
+
288
+
289
+ median_data = pickle.load(open(median_set, 'rb'))
290
+ median_data['<pad>'] = None
291
+ median_data['<mask>'] = None
292
+
293
+ token_set = pickle.load(open(token_set, 'rb'))
294
+ median_dict = {key:median_data[key] for key in list(token_set.keys())}
295
+ '''
296
+
297
+ train_dataset = load_from_disk(dataset)
298
+ num_samples = int(len(train_dataset) * sample_data)
299
+ random_indices = random.sample(range(len(train_dataset)), num_samples)
300
+ train_dataset = train_dataset.select(random_indices)
301
+
305
+
306
+ def if_not_rare_celltype(example):
307
+ return example[label] in cells_to_keep
308
+
309
+ # change labels to numerical ids
310
+ def classes_to_ids(example):
311
+ example["label"] = target_name_id_dict[example["label"]]
312
+ return example
313
+
314
+ def if_trained_label(example):
315
+ return example["label"] in trained_labels
316
+
317
+ if skip_training != True:
318
+ def compute_metrics(pred):
319
+ labels = pred.label_ids
320
+ preds = pred.predictions.argmax(-1)
321
+ # calculate accuracy and macro f1 using sklearn's function
322
+ acc = accuracy_score(labels, preds)
323
+ macro_f1 = f1_score(labels, preds, average='macro')
324
+ return {
325
+ 'accuracy': acc,
326
+ 'macro_f1': macro_f1
327
+ }
328
+
329
+ # Defines custom exceptions for collecting labels (default excluded)
330
+ excep = {"bone_marrow":"immune"}
331
+
332
+ if dataset_split != None:
333
+ if data_filter != None:
334
+ split_iter = [data_filter]
335
+ else:
336
+ split_iter = Counter(train_dataset[dataset_split]).keys()
337
+ for lab in split_iter:
338
+
339
+ # collect list of tissues for fine-tuning (immune and bone marrow are included together)
340
+ if lab in list(excep.keys()):
341
+ continue
342
+ elif lab == list(excep.values()):
343
+ split_ids = [excep.keys(),excep.values()]
344
+ split_list += [excep.values()]
345
+ else:
346
+ split_ids = [lab]
347
+ split_list += [lab]
348
+
349
+ # filter datasets for given organ
350
+ def if_label(example):
351
+ return example[dataset_split] == lab
352
+
353
+ trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
354
+ label_counter = Counter(trainset_label[label])
355
+ total_cells = sum(label_counter.values())
356
+
357
+ # Throws out cells with a low proportion in the dataset (drop cell types representing <0.5% of cells per deepsort published method)
358
+ cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
359
+ trainset_label_subset = trainset_label.filter(if_not_rare_celltype, num_proc=cpu_cores)
360
+
361
+ # shuffle datasets and rename columns
362
+ trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
363
+ trainset_label_shuffled = trainset_label_shuffled.rename_column(label,"label")
364
+ trainset_label_shuffled = trainset_label_shuffled.remove_columns(dataset_split)
365
+
366
+ # create dictionary of cell types : label ids
367
+ target_names = list(Counter(trainset_label_shuffled["label"]).keys())
368
+ target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
369
+ target_dict_list += [target_name_id_dict]
370
+
371
+ labeled_trainset = trainset_label_shuffled.map(classes_to_ids, num_proc=cpu_cores)
372
+
373
+ # create 80/20 train/eval splits
374
+ labeled_train_split = labeled_trainset.select([i for i in range(0, round(len(labeled_trainset)*0.8))])
+ labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8), len(labeled_trainset))])
376
+
377
+ # filter dataset for cell types in corresponding training set
378
+ trained_labels = list(Counter(labeled_train_split["label"]).keys())
379
+
380
+ labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=cpu_cores)
381
+
382
+ dataset_list += [labeled_train_split]
383
+ evalset_list += [labeled_eval_split_subset]
384
+
385
+ trainset_dict = dict(zip(split_list,dataset_list))
386
+ traintargetdict_dict = dict(zip(split_list,target_dict_list))
387
+ evalset_dict = dict(zip(split_list,evalset_list))
388
+
389
+ for lab in split_list:
390
+ label_trainset = trainset_dict[lab]
391
+ label_evalset = evalset_dict[lab]
392
+ label_dict = traintargetdict_dict[lab]
393
+
394
+ # set logging steps
395
+ logging_steps = round(len(label_trainset)/geneformer_batch_size/10)
396
+ if logging_steps == 0:
397
+ logging_steps = 1
398
+
399
+ # reload pretrained model
400
+ model = BertForSequenceClassification.from_pretrained(pretrained_model,
401
+ num_labels=len(label_dict.keys()),
402
+ output_attentions = False,
403
+ output_hidden_states = False).to(device)
404
+
405
+ # define output directory path
406
+ current_date = datetime.datetime.now()
407
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
408
+
409
+ if output_dir == None:
410
+ output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
411
+
412
+ # ensure not overwriting previously saved model
413
+ saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
414
+
415
+ if os.path.isfile(saved_model_test) == True and overwrite == False:
416
+ raise Exception("Model already saved to this directory.")
417
+
418
+ # make output directory
419
+ subprocess.call(f'mkdir -p {output_dir}', shell=True)
420
+
421
+ # set training arguments
422
+ training_args = {
423
+ "learning_rate": max_lr,
424
+ "do_train": True,
425
+ "do_eval": True,
426
+ "evaluation_strategy": "epoch",
427
+ "save_strategy": "epoch",
428
+ "logging_steps": logging_steps,
429
+ "group_by_length": True,
430
+ "length_column_name": "length",
431
+ "disable_tqdm": False,
432
+ "lr_scheduler_type": lr_schedule_fn,
433
+ "warmup_steps": warmup_steps,
434
+ "weight_decay": 0.001,
435
+ "per_device_train_batch_size": geneformer_batch_size,
436
+ "per_device_eval_batch_size": geneformer_batch_size,
437
+ "num_train_epochs": epochs,
438
+ "load_best_model_at_end": True,
439
+ "output_dir": output_dir,
440
+ }
441
+
442
+
443
+ training_args_init = TrainingArguments(**training_args)
444
+ true_labels = label_evalset['label']
445
+
446
+
447
+ if optimize_hyperparameters == False:
448
+ # create the trainer
449
+ trainer = Trainer(
450
+ model=model,
451
+ args=training_args_init,
452
+ data_collator=DataCollatorForCellClassification(),
453
+ train_dataset=label_trainset,
454
+ eval_dataset=label_evalset,
455
+ compute_metrics=compute_metrics
456
+ )
457
+
458
+ # train the cell type classifier
459
+ trainer.train()
460
+ predictions = trainer.predict(label_evalset)
461
+ predicted_labels = predictions.predictions.argmax(-1)
+ print(f'accuracy: {accuracy_score(label_evalset["label"], predicted_labels)}')
462
+
463
+ tpr, fpr, auc = ROC(predictions.predictions, true_labels)
464
+
465
+ metrics = compute_metrics(predictions)
466
+ with open(f"{output_dir}predictions.pickle", "wb") as fp:
467
+ pickle.dump(predictions, fp)
468
+
469
+ trainer.save_metrics("eval",predictions.metrics)
470
+
471
+ with open(f'{output_dir}/targets.txt', 'w') as f:
472
+ if len(target_dict_list) == 1:
473
+ f.write(str(target_dict_list[0]))
474
+ else:
475
+ f.write(str(target_dict_list))
476
+
477
+ try:
478
+
479
+ precision, recall, _ = precision_recall_curve(true_labels, predictions.predictions[:, 1])
480
+ pr_auc = precision_auc(recall, precision)
481
+
482
+ print(f'AUC: {pr_auc}')
483
+ return recall, precision, pr_auc
484
+ except:
485
+ pass
486
+
487
+ trainer.save_model(output_dir)
488
+ else:
489
+
490
+ def model_init():
491
+ model = BertForSequenceClassification.from_pretrained(pretrained_model,
492
+ num_labels=len(label_dict.keys()),
493
+ output_attentions = False,
494
+ output_hidden_states = False)
495
+ if freeze_layers is not None:
496
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
497
+ for module in modules_to_freeze:
498
+ for param in module.parameters():
499
+ param.requires_grad = False
500
+ model = model.to(device)
501
+ return model
502
+
503
+ trainer = Trainer(
504
+ model_init=model_init,
505
+ args=training_args_init,
506
+ data_collator=DataCollatorForCellClassification(),
507
+ train_dataset=label_trainset,
508
+ eval_dataset=label_evalset,
509
+ compute_metrics=compute_metrics
510
+ )
511
+ # specify raytune hyperparameter search space
512
+ ray_config = {
513
+ "num_train_epochs": tune.choice([epochs]),
514
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
515
+ "weight_decay": tune.uniform(0.0, 0.3),
516
+ "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
517
+ "warmup_steps": tune.uniform(100, 2000),
518
+ "seed": tune.uniform(0,100),
519
+ "per_device_train_batch_size": tune.choice([geneformer_batch_size])
520
+ }
521
+
522
+ hyperopt_search = HyperOptSearch(
523
+ metric="eval_accuracy", mode="max")
524
+
525
+ if device.type == 'cuda':
+     resources_per_trial = {"cpu": 8, "gpu": 1}
+ else:
+     resources_per_trial = {"cpu": 8}
529
+
530
+ # optimize hyperparameters
531
+ best_trial = trainer.hyperparameter_search(
532
+ direction="maximize",
533
+ backend="ray",
534
+ resources_per_trial = resources_per_trial,
535
+ hp_space=lambda _: ray_config,
536
+ search_alg=hyperopt_search,
537
+ n_trials=10, # number of trials
538
+ progress_reporter=tune.CLIReporter(max_report_frequency=600,
539
+ sort_by_metric=True,
540
+ max_progress_rows=100,
541
+ mode="max",
542
+ metric="eval_accuracy",
543
+ metric_columns=["loss", "eval_loss", "eval_accuracy"]))
544
+ best_hyperparameters = best_trial.hyperparameters
545
+
546
+ print("Best Hyperparameters:")
547
+ print(best_hyperparameters)
548
+
549
+
550
+
551
+ else:
552
+ trainset_label = train_dataset
553
+ label_counter = Counter(trainset_label[label])
554
+ total_cells = sum(label_counter.values())
555
+
556
+ # Throws out cells with a low proportion in the dataset
557
+ cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
558
+ trainset_label_subset = trainset_label.filter(if_not_rare_celltype, num_proc=cpu_cores)
559
+
560
+ # shuffle datasets and rename columns
561
+ trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
562
+ trainset_label_shuffled = trainset_label_shuffled.rename_column(label,"label")
563
+
564
+ # create dictionary of cell types : label ids
565
+ target_names = list(Counter(trainset_label_shuffled["label"]).keys())
566
+ target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
567
+ target_dict_list = target_name_id_dict
568
+
569
+ labeled_trainset = trainset_label_shuffled.map(classes_to_ids, num_proc=cpu_cores)
570
+
571
+ # create 80/20 train/eval splits
572
+ labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
573
+ labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
574
+
575
+ # filter dataset for cell types in corresponding training set
576
+ trained_labels = list(Counter(labeled_train_split["label"]).keys())
577
+ labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=cpu_cores)
578
+
579
+ # set logging steps
580
+ logging_steps = round(len(trainset_label)/geneformer_batch_size/10)
581
+
582
+ # reload pretrained model
583
+ model = BertForSequenceClassification.from_pretrained(pretrained_model,
584
+ num_labels=len(target_dict_list.keys()),
585
+ output_attentions = False,
586
+ output_hidden_states = False).to(device)
587
+ # define output directory path
588
+ current_date = datetime.datetime.now()
589
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
590
+
591
+ if output_dir == None:
592
+ output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
593
+
594
+ # ensure not overwriting previously saved model
595
+ saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
596
+ if os.path.isfile(saved_model_test) == True and overwrite == False:
597
+ raise Exception("Model already saved to this directory.")
598
+
599
+ # make output directory
600
+ subprocess.call(f'mkdir -p {output_dir}', shell=True)
601
+
602
+ # set training arguments
603
+ training_args = {
604
+ "learning_rate": max_lr,
605
+ "do_train": True,
606
+ "do_eval": True,
607
+ "evaluation_strategy": "epoch",
608
+ "save_strategy": "epoch",
609
+ "logging_steps": logging_steps,
610
+ "group_by_length": True,
611
+ "length_column_name": "length",
612
+ "disable_tqdm": False,
613
+ "lr_scheduler_type": lr_schedule_fn,
614
+ "warmup_steps": warmup_steps,
615
+ "weight_decay": 0.001,
616
+ "per_device_train_batch_size": geneformer_batch_size,
617
+ "per_device_eval_batch_size": geneformer_batch_size,
618
+ "num_train_epochs": epochs,
619
+ "load_best_model_at_end": True,
620
+ "output_dir": output_dir,}
621
+
622
+ training_args_init = TrainingArguments(**training_args)
623
+ true_labels = labeled_eval_split_subset['label']
624
+
625
+ if optimize_hyperparameters == False:
626
+
627
+ # create the trainer
628
+ trainer = Trainer(
629
+ model=model,
630
+ args=training_args_init,
631
+ data_collator=DataCollatorForCellClassification(),
632
+ train_dataset=labeled_train_split,
633
+ eval_dataset=labeled_eval_split_subset,
634
+ compute_metrics=compute_metrics
635
+ )
636
+
637
+ # train the cell type classifier
638
+ trainer.train()
639
+ predictions = trainer.predict(labeled_eval_split_subset)
640
+ predictions_tensor = torch.Tensor(predictions.predictions)
641
+ predicted_labels = torch.argmax(predictions_tensor, dim=1)
642
+ print(f'accuracy: {accuracy_score(predicted_labels, labeled_eval_split_subset["label"])}')
643
+ metrics = compute_metrics(predictions)
644
+
645
+ with open(f"{output_dir}predictions.pickle", "wb") as fp:
646
+ pickle.dump(predictions.predictions.argmax(-1), fp)
647
+
648
+ trainer.save_metrics("eval",predictions.metrics)
649
+ trainer.save_model(output_dir)
650
+
651
+ # Saves label conversion dictionary to output directory
652
+ with open(f'{output_dir}/targets.txt', 'w') as f:
653
+ f.write(str(target_dict_list))
654
+
655
+ try:
656
+
657
+ precision, recall, _ = precision_recall_curve(true_labels, predictions.predictions[:, 1])
658
+ pr_auc = precision_auc(recall, precision)
659
+
660
+ print(f'AUC: {pr_auc}')
661
+ return recall, precision, pr_auc
662
+ except:
663
+ pass
664
+
665
+ else:
666
+ # Optimizes hyperparameters
667
+
668
+ num_classes = len(list(set(labeled_train_split['label'])))
669
+ def model_init():
670
+ model = BertForSequenceClassification.from_pretrained(pretrained_model,
671
+ num_labels=num_classes,
672
+ output_attentions = False,
673
+ output_hidden_states = False)
674
+
675
+ if freeze_layers is not None:
676
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
677
+ for module in modules_to_freeze:
678
+ for param in module.parameters():
679
+ param.requires_grad = False
680
+ model = model.to(device)
681
+ return model
682
+
683
+
684
+ # create the trainer
685
+ trainer = Trainer(
686
+ model_init=model_init,
687
+ args=training_args_init,
688
+ data_collator=DataCollatorForCellClassification(),
689
+ train_dataset=labeled_train_split,
690
+ eval_dataset=labeled_eval_split_subset,
691
+ compute_metrics=compute_metrics
692
+ )
693
+
694
+ # specify raytune hyperparameter search space
695
+ ray_config = {
696
+ "num_train_epochs": tune.choice([epochs]),
697
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
698
+ "weight_decay": tune.uniform(0.0, 0.3),
699
+ "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
700
+ "warmup_steps": tune.uniform(100, 2000),
701
+ "seed": tune.uniform(0,100),
702
+ "per_device_train_batch_size": tune.choice([geneformer_batch_size])
703
+ }
704
+
705
+ hyperopt_search = HyperOptSearch(
706
+ metric="eval_accuracy", mode="max")
707
+
708
+ if device.type == 'cuda':
+     resources_per_trial = {"cpu": 8, "gpu": 1}
+ else:
+     resources_per_trial = {"cpu": 8}
712
+
713
+ # optimize hyperparameters
714
+ best_trial = trainer.hyperparameter_search(
715
+ direction="maximize",
716
+ backend="ray",
717
+ resources_per_trial = resources_per_trial,
718
+ hp_space=lambda _: ray_config,
719
+ search_alg=hyperopt_search,
720
+ n_trials=10, # number of trials
721
+ progress_reporter=tune.CLIReporter(max_report_frequency=600,
722
+ sort_by_metric=True,
723
+ max_progress_rows=100,
724
+ mode="max",
725
+ metric="eval_accuracy",
726
+ metric_columns=["loss", "eval_loss", "eval_accuracy"]))
727
+ best_hyperparameters = best_trial.hyperparameters
728
+
729
+ print("Best Hyperparameters:")
730
+ print(best_hyperparameters)
731
+
732
+
733
+ # Performs Inference with model
734
+ if inference == True:
735
+ if dataset_split != None and data_filter != None:
736
+ def if_label(example):
737
+ return example[dataset_split] == data_filter
738
+
739
+ train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
740
+
741
+ trainset_label_shuffled = train_dataset
742
+ total_cells = len(trainset_label_shuffled)
743
+
744
+ # loads dictionary of all cell labels model was trained on
745
+ with open(Path(model_location) / 'targets.txt', 'r') as f:
746
+ data = ast.literal_eval(f.read())
747
+ if dataset_split != None and data_filter == None:
748
+ indexer = dataset_split.index(data_filter)
749
+ data = data[indexer]
750
+
751
+ target_dict_list = {key:value for key, value in enumerate(data)}
752
+
753
+ # set logging steps
754
+ logging_steps = round(len(trainset_label_shuffled)/geneformer_batch_size/20)
755
+
756
+ # reload pretrained model
757
+ input_ids = trainset_label_shuffled["input_ids"]
758
+ inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
759
+ attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
760
+
761
+ for i, sentence in enumerate(input_ids):
762
+ sentence_length = len(sentence)
763
+ if sentence_length <= max_input_size:
764
+ inputs[i, :sentence_length] = torch.tensor(sentence)
765
+ attention[i, :sentence_length] = torch.ones(sentence_length)
766
+ else:
767
+ inputs[i, :] = torch.tensor(sentence[:max_input_size])
768
+ attention[i, :] = torch.ones(max_input_size)
769
+
770
+ model = BertForSequenceClassification.from_pretrained(model_location, num_labels=len(target_dict_list)).to(device)
771
+ model_outputs = model(inputs.to(device), attention_mask = attention.to(device))["logits"]
772
+ predictions = F.softmax(model_outputs, dim = -1).argmax(-1)
773
+
774
+ predictions = [target_dict_list[int(pred)] for pred in predictions]
775
+
776
+ return predictions
777
+
778
+ # Extracts embeddings from labelled data
779
+ if emb_extract == True:
780
+ if emb_filter == None:
781
+ with open(f'{model_location}/targets.txt', 'r') as f:
782
+ data = ast.literal_eval(f.read())
783
+ if dataset_split != None and data_filter == None:
784
+ indexer = dataset_split.index(data_filter)
785
+ data = data[indexer]
786
+
787
+ target_dict_list = {key:value for key, value in enumerate(data)}
788
+ total_filter = None
789
+ else:
790
+ total_filter = emb_filter
791
+
792
+ train_dataset = load_from_disk(dataset)
793
+ if dataset_split != None:
794
+ def if_label(example):
795
+ return example[dataset_split] == data_filter
796
+
797
+ train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
798
+
799
+ label_counter = Counter(train_dataset[label])
800
+ total_cells = sum(label_counter.values())
801
+ cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
802
+
803
+ def if_not_rare(example):
804
+ return example[label] in cells_to_keep
805
+
806
+ train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)
807
+
808
+ true_labels = train_dataset[label]
809
+ num_classes = len(list(set(true_labels)))
810
+
811
+ embex = EmbExtractor(model_type="CellClassifier", num_classes=num_classes,
812
+ filter_data=total_filter, max_ncells=max_cells, emb_layer=emb_layer,
813
+ emb_label=[dataset_split,label], labels_to_plot=[label], forward_batch_size=forward_batch, nproc=cpu_cores)
814
+
815
+ # example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
816
+ subprocess.call(f'mkdir -p {emb_dir}', shell = True)
817
+
818
+ embs = embex.extract_embs(model_directory = model_location, input_data_file = dataset, output_directory = emb_dir, output_prefix = f"{label}_embeddings")
819
+ true_labels = embex.filtered_input_data[label]
820
+
821
+ emb_dict = {label:[] for label in list(set(true_labels))}
822
+ for num, emb in embs.iterrows():
823
+ key = emb[label]
824
+ selection = emb.iloc[:255]
825
+ emb = torch.Tensor(selection)
826
+ emb_dict[key].append(emb)
827
+
828
+ for key in list(emb_dict.keys()):
829
+ stack = torch.stack(emb_dict[key], dim = 0)
830
+ emb_dict[key] = torch.mean(stack, dim=0)
831
+ similarities = {key:{} for key in list(emb_dict.keys())}
832
+
833
+ for key in list(emb_dict.keys()):
834
+ remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
835
+ for k in remaining_keys:
836
+ embedding = emb_dict[k]
837
+ sim = similarity(emb_dict[key], embedding, cosine = True)
838
+
839
+ similarities[key][k] = sim
840
+
841
+ plot_similarity_heatmap(similarities)
842
+
843
+ embex.plot_embs(embs=embs,
844
+ plot_style="umap",
845
+ output_directory=emb_dir,
846
+ output_prefix="emb_plot")
847
+
848
+
849
+ embex.plot_embs(embs=embs,
850
+ plot_style="heatmap",
851
+ output_directory=emb_dir,
852
+ output_prefix="emb_plot")
853
+
854
+
855
+ return similarities
856
+
857
+ if __name__ == '__main__':
858
+ predictions = finetune_cells(skip_training = False, dataset_split = None, label = "disease", sample_data = .5, data_filter = 'hcm', epochs = 10, output_dir = 'hcm_model', model_location = 'hcm_model',
859
+ emb_extract = True, geneformer_batch_size = 12, inference = False, dataset = "/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/")
860
+
861
+
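Beyond the training example above, a sketch of a run that skips training, only extracts embeddings from an already fine-tuned model, and inspects the returned similarities (paths are placeholders):

    sims = finetune_cells(skip_training = True, emb_extract = True, label = 'disease',
                          model_location = 'hcm_model', dataset = 'path/to/data.dataset')
    for cell_type, others in sims.items():
        closest = max(others, key=others.get)
        print(f"{cell_type}: most similar to {closest} ({others[closest]:.3f})")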
Gene_classifier.py ADDED
@@ -0,0 +1,746 @@
1
+ import os
2
+ import sys
3
+ GPU_NUMBER = [0] # CHANGE WITH MULTIGPU
4
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
5
+ os.environ["NCCL_DEBUG"] = "INFO"
6
+
7
+ # imports
8
+ from sklearn.model_selection import train_test_split
9
+ import datetime
10
+ import subprocess
11
+ from pathlib import Path
12
+ import math
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import pickle
16
+ import pandas as pd
17
+ from datasets import load_from_disk, Dataset
18
+ from sklearn import preprocessing
19
+ from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, f1_score, roc_curve
20
+ from sklearn.model_selection import StratifiedKFold
21
+ import torch
22
+ from transformers import BertForTokenClassification
23
+ from transformers import Trainer
24
+ from transformers.training_args import TrainingArguments
25
+ from tqdm.notebook import tqdm
26
+ from sklearn.metrics import roc_curve, roc_auc_score
27
+ from geneformer import DataCollatorForGeneClassification, EmbExtractor
28
+ from geneformer.pretrainer import token_dictionary
29
+ import ast
30
+ import torch.nn.functional as F
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ from geneformer import TranscriptomeTokenizer
33
+
34
+ def vote(logit_pair):
35
+ a, b = logit_pair
36
+ if a > b:
37
+ return 0
38
+ elif b > a:
39
+ return 1
40
+ elif a == b:
41
+ return "tie"
42
+
43
+ def py_softmax(vector):
44
+ e = np.exp(vector)
45
+ return e / e.sum()
46
+
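A quick illustration of the two helpers above (values approximate):

    logits = [2.0, 0.5]
    print(vote(logits))                  # 0, since the first class has the larger logit
    print(py_softmax(np.array(logits)))  # ~[0.82, 0.18], per-class probabilities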
47
+ # Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
+ def similarity(tensor1, tensor2, cosine = True):
+     if cosine == False:
+         # flatten both tensors so the dot product is well-defined for any input shape
+         tensor1 = tensor1.view(-1)
+         tensor2 = tensor2.view(-1)
+         dot_product = torch.matmul(tensor1, tensor2)
+         norm_tensor1 = torch.norm(tensor1)
+         norm_tensor2 = torch.norm(tensor2)
+         epsilon = 1e-8
+         similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+         # rescale from [-1, 1] to [0, 1]
+         similarity = (similarity + 1) / 2
+     else:
+         if tensor1.shape != tensor2.shape:
+             raise ValueError("Input tensors must have the same shape.")
+
+         # Compute cosine similarity using PyTorch's dot product function
+         dot_product = torch.dot(tensor1, tensor2)
+         norm_tensor1 = torch.norm(tensor1)
+         norm_tensor2 = torch.norm(tensor2)
+
+         # Avoid division by zero by adding a small epsilon
+         epsilon = 1e-8
+         similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+
+     return similarity.item()
74
+
75
+ # Plots heatmap between different classes/labels
76
+ def plot_similarity_heatmap(similarities):
77
+ classes = list(similarities.keys())
78
+ classlen = len(classes)
79
+ arr = np.zeros((classlen, classlen))
80
+ for i, c in enumerate(classes):
81
+ for j, cc in enumerate(classes):
82
+ if cc == c:
83
+ val = 1.0
84
+ else:
85
+ val = similarities[c][cc]
86
+ arr[i][j] = val
87
+
88
+ plt.figure(figsize=(8, 6))
89
+ plt.imshow(arr, cmap='inferno', vmin=0, vmax=1)
90
+ plt.colorbar()
91
+ plt.xticks(np.arange(classlen), classes, rotation = 45, ha = 'right')
92
+ plt.yticks(np.arange(classlen), classes)
93
+ plt.title("Similarity Heatmap")
94
+ plt.savefig("similarity_heatmap.png")
95
+
96
+ # get cross-validated mean and sd metrics
97
+ def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
98
+ wts = [count/sum(all_tpr_wt) for count in all_tpr_wt]
99
+
100
+ all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]
101
+ mean_tpr = np.sum(all_weighted_tpr, axis=0)
102
+ mean_tpr[-1] = 1.0
103
+ all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]
104
+ roc_auc = np.sum(all_weighted_roc_auc)
105
+ roc_auc_sd = math.sqrt(np.average((np.array(all_roc_auc) - roc_auc)**2, weights=wts))
106
+ return mean_tpr, roc_auc, roc_auc_sd
107
+
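A small worked example of the weighting, assuming two folds with AUCs 0.80 and 0.90 evaluated on 100 and 300 examples:

    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(
        [np.linspace(0, 0.8, 100), np.linspace(0, 0.9, 100)],  # per-fold interpolated TPR curves
        [0.80, 0.90],                                          # per-fold AUCs
        [100, 300])                                            # per-fold eval-example counts
    print(round(roc_auc, 3))  # weighted mean: 0.25*0.80 + 0.75*0.90 = 0.875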
108
+ def validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
109
+ # initiate eval metrics to return
110
+ num_classes = len(set(labels))
111
+ mean_fpr = np.linspace(0, 1, 100)
112
+
113
+ # create 75/25 train/eval splits
+ targets_train, targets_eval, labels_train, labels_eval = train_test_split(targets, labels, test_size=0.25, shuffle=True)
115
+ label_dict_train = dict(zip(targets_train, labels_train))
116
+ label_dict_eval = dict(zip(targets_eval, labels_eval))
117
+
118
+ # function to filter by whether contains train or eval labels
119
+ def if_contains_train_label(example):
120
+ a = label_dict_train.keys()
121
+ b = example['input_ids']
122
+ return not set(a).isdisjoint(b)
123
+
124
+ def if_contains_eval_label(example):
125
+ a = label_dict_eval.keys()
126
+ b = example['input_ids']
127
+ return not set(a).isdisjoint(b)
128
+
129
+ # filter dataset for examples containing classes for this split
130
+ print(f"Filtering training data")
131
+ trainset = data.filter(if_contains_train_label, num_proc=num_proc)
132
+ print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
133
+ print("Filtering evaluation data")
134
+ evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
135
+ print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
136
+
137
+ # minimize to smaller training sample
138
+ training_size = min(subsample_size, len(trainset))
139
+ trainset_min = trainset.select([i for i in range(training_size)])
140
+ eval_size = min(training_size, len(evalset))
141
+ half_training_size = round(eval_size/2)
142
+ evalset_train_min = evalset.select([i for i in range(half_training_size)])
143
+ evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
144
+
145
+ # label conversion functions
146
+ def generate_train_labels(example):
147
+ example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
148
+ return example
149
+
150
+ def generate_eval_labels(example):
151
+ example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
152
+ return example
153
+
154
+ # label datasets
155
+ print(f"Labeling training data")
156
+ trainset_labeled = trainset_min.map(generate_train_labels)
157
+ print(f"Labeling evaluation data")
158
+ evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
159
+ print(f"Labeling evaluation OOS data")
160
+ evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
161
+
162
+ # load model
163
+ model = BertForTokenClassification.from_pretrained(
164
+ pre_model,
165
+ num_labels=num_labels,
166
+ output_attentions = False,
167
+ output_hidden_states = False,
168
+ )
169
+ if freeze_layers is not None:
170
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
171
+ for module in modules_to_freeze:
172
+ for param in module.parameters():
173
+ param.requires_grad = False
174
+
175
+ model = model.to(device)
176
+
177
+ # add output directory to training args and initiate
178
+ training_args["output_dir"] = output_dir
179
+ training_args_init = TrainingArguments(**training_args)
180
+
181
+ # create the trainer
182
+ trainer = Trainer(
183
+ model=model,
184
+ args=training_args_init,
185
+ data_collator=DataCollatorForGeneClassification(),
186
+ train_dataset=trainset_labeled,
187
+ eval_dataset=evalset_train_labeled,
188
+ )
189
+
190
+ # train the gene classifier
191
+ trainer.train()
192
+ trainer.save_model(output_dir)
193
+
194
+ fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
195
+ auc_score = auc(fpr, tpr)
196
+
197
+ return fpr, tpr, auc_score
198
+
199
+ # cross-validate gene classifier
200
+ def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
201
+ # check if output directory already written to
202
+ # ensure not overwriting previously saved model
203
+ model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
204
+ #if os.path.isfile(model_dir_test) == True:
205
+ # raise Exception("Model already saved to this directory.")
206
+
207
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
208
+ # initiate eval metrics to return
209
+ num_classes = len(set(labels))
210
+ mean_fpr = np.linspace(0, 1, 100)
211
+ all_tpr = []
212
+ all_roc_auc = []
213
+ all_tpr_wt = []
214
+ label_dicts = []
215
+ confusion = np.zeros((num_classes,num_classes))
216
+
217
+ # set up cross-validation splits
218
+ skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
219
+ # train and evaluate
220
+ iteration_num = 0
221
+ for train_index, eval_index in tqdm(skf.split(targets, labels)):
222
+ if len(labels) > 500:
223
+ print("early stopping activated due to large # of training examples")
224
+ if iteration_num == 3:
225
+ break
226
+
227
+ print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
228
+
229
+ # generate cross-validation splits
230
+ targets_train, targets_eval = targets[train_index], targets[eval_index]
231
+ labels_train, labels_eval = labels[train_index], labels[eval_index]
232
+ label_dict_train = dict(zip(targets_train, labels_train))
233
+ label_dict_eval = dict(zip(targets_eval, labels_eval))
234
+ label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)
235
+
236
+ # function to filter by whether contains train or eval labels
237
+ def if_contains_train_label(example):
238
+ a = label_dict_train.keys()
239
+ b = example['input_ids']
240
+
241
+ return not set(a).isdisjoint(b)
242
+
243
+ def if_contains_eval_label(example):
244
+ a = label_dict_eval.keys()
245
+ b = example['input_ids']
246
+
247
+ return not set(a).isdisjoint(b)
248
+
249
+ # filter dataset for examples containing classes for this split
250
+ print(f"Filtering training data")
251
+ trainset = data.filter(if_contains_train_label, num_proc=num_proc)
252
+ print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
253
+ print("Filtering evaluation data")
254
+ evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
255
+ print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
256
+
257
+ # minimize to smaller training sample
258
+ training_size = min(subsample_size, len(trainset))
259
+ trainset_min = trainset.select([i for i in range(training_size)])
260
+ eval_size = min(training_size, len(evalset))
261
+ half_training_size = round(eval_size/2)
262
+ evalset_train_min = evalset.select([i for i in range(half_training_size)])
263
+ evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
264
+
265
+ # label conversion functions
266
+ def generate_train_labels(example):
267
+ example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
268
+ return example
269
+
270
+ def generate_eval_labels(example):
271
+ example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
272
+ return example
273
+
274
+ # label datasets
275
+ print(f"Labeling training data")
276
+ trainset_labeled = trainset_min.map(generate_train_labels)
277
+ print(f"Labeling evaluation data")
278
+ evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
279
+ print(f"Labeling evaluation OOS data")
280
+ evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
281
+
282
+ # create output directories
283
+ ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
284
+ ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")
285
+
286
+ # ensure not overwriting previously saved model
287
+ model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
288
+ #if os.path.isfile(model_output_file) == True:
289
+ # raise Exception("Model already saved to this directory.")
290
+
291
+ # make training and model output directories
292
+ subprocess.call(f'mkdir -p {ksplit_output_dir}', shell=True)
293
+ subprocess.call(f'mkdir -p {ksplit_model_dir}', shell=True)
294
+
295
+ # load model
296
+ model = BertForTokenClassification.from_pretrained(
297
+ pre_model,
298
+ num_labels=num_labels,
299
+ output_attentions = False,
300
+ output_hidden_states = False,
301
+ )
302
+ if freeze_layers is not None:
303
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
304
+ for module in modules_to_freeze:
305
+ for param in module.parameters():
306
+ param.requires_grad = False
307
+
308
+ model = model.to(device)
309
+
310
+ # add output directory to training args and initiate
311
+ training_args["output_dir"] = ksplit_output_dir
312
+ training_args_init = TrainingArguments(**training_args)
313
+
314
+ # create the trainer
315
+ trainer = Trainer(
316
+ model=model,
317
+ args=training_args_init,
318
+ data_collator=DataCollatorForGeneClassification(),
319
+ train_dataset=trainset_labeled,
320
+ eval_dataset=evalset_train_labeled
321
+ )
322
+
323
+ # train the gene classifier
324
+ trainer.train()
325
+
326
+ # save model
327
+ trainer.save_model(ksplit_model_dir)
328
+
329
+ # evaluate model
330
+ fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
331
+
332
+ # append to tpr and roc lists
333
+ confusion = confusion + conf_mat
334
+ all_tpr.append(interp_tpr)
335
+ all_roc_auc.append(auc(fpr, tpr))
336
+ # append number of eval examples by which to weight tpr in averaged graphs
337
+ all_tpr_wt.append(len(tpr))
338
+
339
+ iteration_num = iteration_num + 1
340
+
341
+ # get overall metrics for cross-validation
342
+ mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
343
+ return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
344
+
345
+ # Computes metrics
346
+ def compute_metrics(pred):
347
+ labels = pred.label_ids
348
+ preds = pred.predictions.argmax(-1)
349
+ # calculate accuracy and macro f1 using sklearn's function
350
+ acc = accuracy_score(labels, preds)
351
+ macro_f1 = f1_score(labels, preds, average='macro')
352
+
353
+ return {
354
+ 'accuracy': acc,
355
+ 'macro_f1': macro_f1
356
+ }
357
+
358
+ # plot ROC curve
359
+ def plot_ROC(bundled_data, title):
360
+ plt.figure()
361
+ lw = 2
362
+ for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
363
+ plt.plot(mean_fpr, mean_tpr, color=color,
364
+ lw=lw, label="{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
365
+
366
+ plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
367
+ plt.xlim([0.0, 1.0])
368
+ plt.ylim([0.0, 1.05])
369
+ plt.xlabel('False Positive Rate')
370
+ plt.ylabel('True Positive Rate')
371
+ plt.title(title)
372
+ plt.legend(loc="lower right")
373
+ plt.savefig("ROC.png")
374
+
375
+ return mean_fpr, mean_tpr, roc_auc
376
+
377
+ # plot confusion matrix
378
+ def plot_confusion_matrix(classes_list, conf_mat, title):
379
+ display_labels = []
380
+ i = 0
381
+ for label in classes_list:
382
+ display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:,i]))]
383
+ i = i + 1
384
+ display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
385
+ display_labels=display_labels)
386
+ display.plot(cmap="Blues",values_format=".2g")
387
+ plt.title(title)
388
+ plt.savefig("CM.png")
389
+
390
+ # Function to find the largest number smaller
391
+ # than or equal to N that is divisible by k
392
+ def find_largest_div(N, K):
393
+ rem = N % K
394
+ if(rem == 0):
395
+ return N
396
+ else:
397
+ return N - rem
398
+
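For example (this is what classifier_predict below uses to avoid a trailing batch of a single example):

    print(find_largest_div(1001, 200))  # 1000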
399
+ def preprocess_classifier_batch(cell_batch, max_len):
400
+ if max_len == None:
401
+ max_len = max([len(i) for i in cell_batch["input_ids"]])
402
+ def pad_label_example(example):
403
+ example["labels"] = np.pad(example["labels"],
404
+ (0, max_len-len(example["input_ids"])),
405
+ mode='constant', constant_values=-100)
406
+ example["input_ids"] = np.pad(example["input_ids"],
407
+ (0, max_len-len(example["input_ids"])),
408
+ mode='constant', constant_values=token_dictionary.get("<pad>"))
409
+ example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
410
+ return example
411
+ padded_batch = cell_batch.map(pad_label_example)
412
+ return padded_batch
413
+
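To make the padding conventions concrete, a hypothetical cell with two tokens padded to max_len 4 would come out as:

    # labels          [1, 0]  ->  [1, 0, -100, -100]   (-100 positions are ignored by the loss)
    # input_ids       [5, 9]  ->  [5, 9, <pad>, <pad>]
    # attention_mask          ->  [1, 1, 0, 0]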
414
+ # forward batch size is batch size for model inference (e.g. 200)
415
+ def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
416
+ predict_logits = []
417
+ predict_labels = []
418
+ model.to('cpu')
419
+ model.eval()
420
+
421
+ # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
422
+ evalset_len = len(evalset)
423
+ max_divisible = find_largest_div(evalset_len, forward_batch_size)
424
+ if len(evalset) - max_divisible == 1:
425
+ evalset_len = max_divisible
426
+
427
+ max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
428
+
429
+ for i in range(0, evalset_len, forward_batch_size):
430
+ max_range = min(i+forward_batch_size, evalset_len)
431
+ batch_evalset = evalset.select([i for i in range(i, max_range)])
432
+ padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
433
+ padded_batch.set_format(type="torch")
434
+
435
+ input_data_batch = padded_batch["input_ids"]
436
+ attn_msk_batch = padded_batch["attention_mask"]
437
+ label_batch = padded_batch["labels"]
438
+ with torch.no_grad():
439
+ input_ids = input_data_batch
440
+ attn_mask = attn_msk_batch
441
+ labels = label_batch
442
+ outputs = model(
443
+
444
+ input_ids = input_ids,
445
+ attention_mask = attn_mask,
446
+ labels = labels
447
+ )
448
+ predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
449
+ predict_labels += [torch.squeeze(label_batch.to("cpu"))]
450
+
451
+ logits_by_cell = torch.cat(predict_logits)
452
+ all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
453
+ labels_by_cell = torch.cat(predict_labels)
454
+ all_labels = torch.flatten(labels_by_cell)
455
+ logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
456
+ y_pred = [vote(item[0]) for item in logit_label_paired]
457
+ y_true = [item[1] for item in logit_label_paired]
458
+ logits_list = [item[0] for item in logit_label_paired]
459
+ # probability of class 1
460
+ y_score = [py_softmax(item)[1] for item in logits_list]
461
+ conf_mat = confusion_matrix(y_true, y_pred)
462
+ fpr, tpr, _ = roc_curve(y_true, y_score)
463
+ # plot roc_curve for this split
464
+ plt.plot(fpr, tpr)
465
+ plt.xlim([0.0, 1.0])
466
+ plt.ylim([0.0, 1.05])
467
+ plt.xlabel('False Positive Rate')
468
+ plt.ylabel('True Positive Rate')
469
+ plt.title('ROC')
470
+ plt.show()
471
+ # interpolate to graph
472
+ interp_tpr = np.interp(mean_fpr, fpr, tpr)
473
+ interp_tpr[0] = 0.0
474
+ return fpr, tpr, interp_tpr, conf_mat
475
+
476
+ def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
477
+ corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
478
+ max_input_size = 2 ** 11, max_lr = 5e-5, freeze_layers = 4, num_gpus = 1, num_proc = os.cpu_count(), geneformer_batch_size = 9, epochs = 1, filter_dataset = 50_000,
479
+ emb_extract = True, emb_layer = 0, forward_batch = 200, filter_data = None, inference = False, k_validate = True, model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/", skip_training = False, emb_dir = 'gene_emb', output_dir = None, max_cells = 1000, num_cpus = os.cpu_count()):
480
+
481
+
482
+ """
483
+ Primary Parameters
484
+ -----------
485
+
486
+ gene_info: path
487
+ Path to gene mappings
488
+
489
+ corpus_30M: path
490
+ Path to 30M Gene Corpus
491
+
492
+ model: path
493
+ Path to pretrained GeneFormer model
494
+
495
+ genes: path
496
+ Path to csv file containing different columns of genes and the column labels
497
+
498
+ inference: bool
499
+ Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
500
+
501
+ k_validate: bool
502
+ Whether the model should run k-fold cross-validation or a single regular train/evaluate split. Defaults to True
503
+
504
+ skip_training: bool
505
+ Whether the model should skip the training portion. Defaults to False
506
+
507
+ emb_extract: bool
508
+ Whether the model should extract embeddings for a given gene (WIP)
509
+
510
+
511
+ Customization Parameters
512
+ -----------
513
+
514
+ freeze_layers: int
515
+ Freezes the first x layers of the model. Default is 4 (leaving 2 trainable layers)
516
+
517
+ filter_dataset: int
518
+ Number of cells to filter from 30M dataset. Default is 50_000
519
+
520
+ emb_layer: int
521
+ Which layer embeddings are extracted from. Default is 0
522
+
523
+ filter_data: str, list
524
+ Filters down embeddings to a single category. Default is None
525
+
526
+
527
+ """
528
+
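# Expected layout of the `genes` CSV: each column header is a class label and the
# column holds the Ensembl IDs belonging to that class (columns may have different
# lengths; blanks are dropped below). Column names and IDs here are placeholders,
# not real values from the dataset:
#
#     dosage_sensitive,dosage_insensitive
#     ENSG00000000001,ENSG00000000005
#     ENSG00000000002,ENSG00000000006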
529
+ # table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
530
+ gene_info = pd.read_csv(gene_info, index_col=0)
531
+ labels = gene_info.columns
532
+
533
+ # create dictionaries for corresponding attributes
534
+ gene_id_type_dict = dict(zip(gene_info["ensembl_id"],gene_info["gene_type"]))
535
+ gene_name_id_dict = dict(zip(gene_info["gene_name"],gene_info["ensembl_id"]))
536
+ gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}
537
+
538
+ # function for preparing targets and labels
539
+ def prep_inputs(label_store, id_type):
540
+ target_list = []
541
+ if id_type == "gene_name":
542
+ for key in list(label_store.keys()):
543
+ targets = [gene_name_id_dict[gene] for gene in label_store[key] if gene_name_id_dict.get(gene) in token_dictionary]
544
+ targets_id = [token_dictionary[gene] for gene in targets]
545
+ target_list.append(targets_id)
546
+ elif id_type == "ensembl_id":
547
+ for key in list(label_store.keys()):
548
+ targets = [gene for gene in label_store[key] if gene in token_dictionary]
549
+ targets_id = [token_dictionary[gene] for gene in targets]
550
+ target_list.append(targets_id)
551
+
552
+ targets, labels = [], []
553
+ for targ in target_list:
554
+ targets = targets + targ
555
+ targets = np.array(targets)
556
+ for num, targ in enumerate(target_list):
557
+ label = [num]*len(targ)
558
+ labels = labels + label
559
+ labels = np.array(labels)
560
+ unique_labels = num + 1
561
+
562
+ nsplits = min(5, min([len(targ) for targ in target_list])-1)
563
+ assert nsplits > 2
564
+
565
+ return targets, labels, nsplits, unique_labels
566
+
567
+ if skip_training == False:
568
+ # preparing targets and labels for dosage sensitive vs insensitive TFs
569
+ gene_classes = pd.read_csv(genes, header=0)
570
+ if filter_data == None:
571
+ labels = gene_classes.columns
572
+ else:
573
+ if isinstance(filter_data, list):
574
+ labels = filter_data
575
+ else:
576
+ labels = [filter_data]
577
+ label_store = {}
578
+
579
+ # Dictionary for decoding labels
580
+ decode = {i:labels[i] for i in range(len(labels))}
581
+
582
+ for label in labels:
583
+ label_store[label] = gene_classes[label].dropna()
584
+
585
+ targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")
586
+
587
+
588
+
589
+ # load training dataset
590
+ train_dataset=load_from_disk(corpus_30M)
591
+ shuffled_train_dataset = train_dataset.shuffle(seed=42)
592
+ subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(filter_dataset)])
593
+ lr_schedule_fn = "linear"
594
+ warmup_steps = 500
595
+ optimizer = "adamw"
596
+ subsample_size = 10_000
597
+
598
+ training_args = {
599
+ "learning_rate": max_lr,
600
+ "do_train": True,
601
+ "evaluation_strategy": "no",
602
+ "save_strategy": "epoch",
603
+ "logging_steps": 10,
604
+ "group_by_length": True,
605
+ "length_column_name": "length",
606
+ "disable_tqdm": False,
607
+ "lr_scheduler_type": lr_schedule_fn,
608
+ "warmup_steps": warmup_steps,
609
+ "weight_decay": 0.001,
610
+ "per_device_train_batch_size": geneformer_batch_size,
611
+ "per_device_eval_batch_size": geneformer_batch_size,
612
+ "num_train_epochs": epochs,
613
+ }
614
+
615
+ # define output directory path
616
+ current_date = datetime.datetime.now()
617
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
618
+
619
+ if output_dir == None:
620
+ training_output_dir = Path(f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/")
621
+ else:
622
+ training_output_dir = Path(output_dir)
623
+
624
+ # make output directory
625
+ subprocess.call(f'mkdir -p {training_output_dir}', shell=True)
626
+
627
+ # Save the number of classes and the label-decoding dictionary to the output directory
628
+ num_classes = len(set(labels))
629
+ info_list = [num_classes, decode]
630
+
631
+ with open(training_output_dir / 'classes.txt', 'w') as f:
632
+ f.write(str(info_list))
633
+
634
+ subsampled_train_dataset.save_to_disk(training_output_dir / 'dataset')  # training_output_dir handles the default output_dir=None case
635
+
636
+ if k_validate == True:
637
+ ksplit_model ="ksplit0/models"
638
+ ksplit_model_test = os.path.join(training_output_dir, ksplit_model)
639
+ #if os.path.isfile(ksplit_model_test) == True:
640
+ # raise Exception("Model already saved to this directory.")
641
+ # cross-validate gene classifier
642
+ all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
643
+
644
+ bundled_data = []
645
+ bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
646
+ graph_title = " ".join([i + ' vs' if count < len(label_store) - 1 else i for count, i in enumerate(label_store)])
647
+ fpr, tpr, auc = plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')
648
+ print(auc)
649
+ # plot confusion matrix
650
+ plot_confusion_matrix(label_store, confusion, "Geneformer")
651
+ else:
652
+ fpr, tpr, auc = validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
653
+ print(auc)
654
+
655
+ if inference == True:
656
+ # preparing targets and labels for dosage sensitive vs insensitive TFs
657
+ gene_classes = pd.read_csv(genes, header=0)
658
+ targets = []
659
+ for column in gene_classes.columns:
660
+ targets += list(gene_classes[column])
661
+ tokens = []
662
+ for target in targets:
663
+ try:
664
+ tokens.append(token_dictionary[target])
665
+ except KeyError:  # gene not in token_dictionary; insert 0 as a placeholder
666
+ tokens.append(0)
667
+
668
+ targets = torch.LongTensor([tokens])
669
+
670
+
671
+ with open(f'{model_location}/classes.txt', 'r') as f:
672
+ info_list = ast.literal_eval(f.read())
673
+ num_classes = info_list[0]
674
+ labels = info_list[1]
675
+
676
+ model = BertForTokenClassification.from_pretrained(
677
+ model_location,
678
+ num_labels=num_classes,
679
+ output_attentions = False,
680
+ output_hidden_states = False,
681
+ local_files_only = True
682
+ )
683
+ if freeze_layers is not None:
684
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
685
+ for module in modules_to_freeze:
686
+ for param in module.parameters():
687
+ param.requires_grad = False
688
+
689
+ model = model.to(device)
690
+
691
+ # evaluate model
692
+ predictions = F.softmax(model(targets.to(device))["logits"], dim = -1).argmax(-1)[0]
693
+ predictions = [labels[int(pred)] for pred in predictions]
694
+
695
+ return predictions
696
+
697
+ # Extracts aggregate gene embeddings for each label
698
+ if emb_extract == True:
699
+ with open(f'{model_location}/classes.txt', 'r') as f:
700
+ data = ast.literal_eval(f.read())
701
+ num_classes = data[0]
702
+ decode = data[1]
703
+
704
+ gene_classes = pd.read_csv(genes, header=0)
705
+ labels = gene_classes.columns
706
+ tokenize = TranscriptomeTokenizer()
707
+
708
+ label_dict = {}
709
+ for label in labels:
710
+ genes = gene_classes[label]
711
+ tokenized_genes = []
712
+ for gene in genes:
713
+ try:
714
+ tokenized_genes.append(tokenize.gene_token_dict[gene])
715
+ except KeyError:  # gene not in the tokenizer vocabulary
716
+ continue
717
+ label_dict[label] = tokenized_genes
718
+
719
+ embex = EmbExtractor(model_type="GeneClassifier", num_classes=num_classes, emb_mode = "gene",
720
+ filter_data=None, max_ncells=max_cells, emb_layer=emb_layer,
721
+ emb_label=label_dict, labels_to_plot=list(labels), forward_batch_size=forward_batch, nproc=num_cpus)
722
+
723
+
724
+ subprocess.call(f'mkdir -p {emb_dir}', shell = True)
725
+
726
+ embs = embex.extract_embs(model_directory = model_location, input_data_file = model_location / 'dataset', output_directory = emb_dir, output_prefix = f"{label}_embeddings")
727
+
728
+ emb_dict = {label:[] for label in list(set(labels))}
729
+ similarities = {key:{} for key in list(emb_dict.keys())}
730
+
731
+ for column in embs.columns:
732
+ remaining_cols = [k for k in embs.columns if k != column]
733
+ for k in remaining_cols:
734
+ embedding = torch.Tensor(embs[k])
735
+ sim = similarity(torch.Tensor(embs[column]), embedding, cosine = True)
736
+ similarities[column][k] = sim
737
+
738
+ plot_similarity_heatmap(similarities)
739
+ print(similarities)
740
+
741
+ return similarities
742
+
743
+ if __name__ == '__main__':
744
+ classify_genes(k_validate = False, inference = False, skip_training = False, emb_extract = True, output_dir = Path('gene_emb'), model_location = Path('gene_emb'), epochs = 5, gene_info = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", corpus_30M = "../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/")
745
+
746
+
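For reference, an inference-only call would look roughly like the sketch below. The gene CSV path and the fine-tuned model directory are assumptions (the directory must contain the classes.txt file and weights written by a previous training run); with inference=True the function returns the predicted class label for each gene in the CSV.

    predictions = classify_genes(
        skip_training = True,                    # no fine-tuning pass
        k_validate = False,
        inference = True,
        genes = "my_genes_to_classify.csv",      # assumed CSV of Ensembl IDs
        model_location = "gene_emb/",            # assumed fine-tuned model directory
    )
    print(predictions[:10])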
Immune_modelpredictions.pickle ADDED
Binary file (99.1 kB).
 
Modular_usage.md ADDED
@@ -0,0 +1,156 @@
1
+ # Cell classifier
2
+ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
3
+ dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
4
+ dataset_split = None,
5
+ filter_cells = .005,
6
+ epochs = 1,
7
+ cpu_cores = os.cpu_count(),
8
+ geneformer_batch_size = 12,
9
+ optimizer = 'adamw',
10
+ max_lr = 5e-5,
11
+ num_gpus = torch.cuda.device_count(),
12
+ max_input_size = 2 ** 11,
13
+ lr_schedule_fn = "linear",
14
+ warmup_steps = 500,
15
+ freeze_layers = 0,
16
+ emb_extract = False,
17
+ max_cells = 1000,
18
+ emb_layer = 0,
19
+ emb_filter = None,
20
+ emb_dir = 'embeddings',
21
+ overwrite = True,
22
+ label = "cell_type",
23
+ data_filter = None,
24
+ forward_batch = 200, model_location = None,
25
+ skip_training = False,
26
+ sample_data = 1,
27
+ inference = False,
28
+ optimize_hyperparameters = False,
29
+ output_dir = None):
30
+
31
+ '''
32
+ Primary Parameters
33
+ -------------------
34
+ dataset: path
35
+ Path to fine-tuning/testing dataset for training
36
+
37
+ model_location: path
38
+ Path to location of existing model to use for inference and embedding extraction
39
+
40
+ pretrained_model: path
41
+ Path to pretrained GeneFormer 30M model before fine-tuning
42
+
43
+ inference: bool
44
+ Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False
45
+
46
+ skip_training: bool
47
+ Chooses whether to skip training the model. Defaults to False
48
+
49
+ emb_extract: bool
50
+ Choose whether to extract embeddings and calculate similarities. Defaults to False
51
+
52
+ optimize_hyperparameters: bool
53
+ Choose whether to optimize model hyperparameters. Defaults to False
54
+ label: string
55
+ The column in the formatted dataset that contains the true class labels. Defaults to "cell_type"
56
+
57
+ Customization Parameters
58
+ -------------------
59
+
60
+ dataset_split: str
61
+ How the dataset should be partitioned (if at all), and what ID should be used for partitioning
62
+
63
+ data_filter: list
64
+ (For embeddings and inference) Runs analysis subsets of the dataset by the ID defined by dataset_split
65
+
66
+ label: str
67
+ What feature should be read as a classification label
68
+
69
+ emb_layer: int
70
+ Which layer embeddings should be extracted from and compared.
71
+
72
+ emb_filter: ['cell1', 'cell2'...]
73
+ Allows the user to narrow down the set of cells that embeddings will be extracted from.
74
+
75
+ max_cells: int
76
+ Maximum number of cells to extract embeddings from.
77
+
78
+ freeze_layers: int
79
+ Number of layers that should be kept frozen during fine-tuning, counting from the first layer (e.g. 4 leaves 2 trainable layers).
80
+
81
+ sample_data: float
82
+ What proportion of the HF dataset should be used
83
+
84
+ '''
85
+
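A minimal training call for the cell classifier might look like the sketch below; the dataset path is the documented default above, while the output directory name is an assumption:

    finetune_cells(
        dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
        pretrained_model = ".",                  # directory holding the pretrained GeneFormer weights
        label = "cell_type",
        epochs = 1,
        geneformer_batch_size = 12,
        output_dir = 'cell_classifier_out',      # assumed output directory
    )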
86
+ # Gene Classifier
87
+ def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv",
88
+ genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
89
+ corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
90
+ max_input_size = 2 ** 11,
91
+ max_lr = 5e-5,
92
+ freeze_layers = 4,
93
+ num_gpus = 1,
94
+ num_proc = os.cpu_count(),
95
+ geneformer_batch_size = 9,
96
+ epochs = 1,
97
+ filter_dataset = 50_000,
98
+ emb_extract = True,
99
+ emb_layer = 0,
100
+ forward_batch = 200,
101
+ filter_data = None,
102
+ inference = False,
103
+ k_validate = True,
104
+ model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
105
+ skip_training = False,
106
+ emb_dir = 'gene_emb',
107
+ output_dir = None,
108
+ max_cells = 1000,
109
+ num_cpus = os.cpu_count()):
110
+
111
+ """"
112
+ Primary Parameters
113
+ -----------
114
+
115
+ gene_info: path
116
+ Path to gene mappings
117
+
118
+ corpus_30M: path
119
+ Path to 30M Gene Corpus
120
+
121
+ model: path
122
+ Path to pretrained GeneFormer model
123
+
124
+ genes: path
125
+ Path to csv file containing different columns of genes and the column labels
126
+
127
+ inference: bool
128
+ Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
129
+
130
+ k_validate: bool
131
+ Whether the model should run k-fold cross-validation or a single regular train/evaluate split. Defaults to True
132
+
133
+ skip_training: bool
134
+ Whether the model should skip the training portion. Defaults to False
135
+
136
+ emb_extract: bool
137
+ Whether the model should extract embeddings for a given gene (WIP)
138
+
139
+
140
+ Customization Parameters
141
+ -----------
142
+
143
+ freeze_layers: int
144
+ Freezes the first x layers of the model. Default is 4 (leaving 2 trainable layers)
145
+
146
+ filter_dataset: int
147
+ Number of cells to filter from 30M dataset. Default is 50_000
148
+
149
+ emb_layer: int
150
+ Which layer embeddings are extracted from. Default is 0
151
+
152
+ filter_data: str, list
153
+ Filters down embeddings to a single category. Default is None
154
+
155
+
156
+ """
gene_embclasses.txt ADDED
@@ -0,0 +1 @@
1
+ [2, {0: 0, 1: 0}]
gene_embdataset.pk ADDED
Binary file (1.76 kB).