Christina Theodoris commited on
Commit
efec1c4
1 Parent(s): 09276dd

add in silico perturbation module

Browse files
geneformer/__init__.py CHANGED
@@ -2,7 +2,11 @@ from . import tokenizer
2
  from . import pretrainer
3
  from . import collator_for_cell_classification
4
  from . import collator_for_gene_classification
 
 
5
  from .tokenizer import TranscriptomeTokenizer
6
  from .pretrainer import GeneformerPretrainer
7
  from .collator_for_gene_classification import DataCollatorForGeneClassification
8
  from .collator_for_cell_classification import DataCollatorForCellClassification
 
 
 
2
  from . import pretrainer
3
  from . import collator_for_cell_classification
4
  from . import collator_for_gene_classification
5
+ from . import in_silico_perturber
6
+ from . import in_silico_perturber_stats
7
  from .tokenizer import TranscriptomeTokenizer
8
  from .pretrainer import GeneformerPretrainer
9
  from .collator_for_gene_classification import DataCollatorForGeneClassification
10
  from .collator_for_cell_classification import DataCollatorForCellClassification
11
+ from .in_silico_perturber import InSilicoPerturber
12
+ from .in_silico_perturber_stats import InSilicoPerturberStats
geneformer/in_silico_perturber.py ADDED
@@ -0,0 +1,777 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer in silico perturber.
3
+
4
+ Usage:
5
+ from geneformer import InSilicoPerturber
6
+ isp = InSilicoPerturber(perturb_type="delete",
7
+ perturb_rank_shift=None,
8
+ genes_to_perturb="all",
9
+ combos=0,
10
+ anchor_gene=None,
11
+ model_type="Pretrained",
12
+ num_classes=0,
13
+ emb_mode="cell",
14
+ cell_emb_style="mean_pool",
15
+ filter_data={"cell_type":["cardiomyocyte"]},
16
+ cell_states_to_model={"disease":(["dcm"],["ctrl"],["hcm"])},
17
+ max_ncells=None,
18
+ emb_layer=-1,
19
+ forward_batch_size=100,
20
+ nproc=4,
21
+ save_raw_data=False)
22
+ isp.perturb_data("path/to/model",
23
+ "path/to/input_data",
24
+ "path/to/output_directory",
25
+ "output_prefix")
26
+ """
27
+
28
+ # imports
29
+ import itertools as it
30
+ import logging
31
+ import pickle
32
+ import seaborn as sns; sns.set()
33
+ import torch
34
+ from collections import defaultdict
35
+ from datasets import Dataset, load_from_disk
36
+ from tqdm.notebook import trange
37
+ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
38
+
39
+ from .tokenizer import TOKEN_DICTIONARY_FILE
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+ def quant_layers(model):
44
+ layer_nums = []
45
+ for name, parameter in model.named_parameters():
46
+ if "layer" in name:
47
+ layer_nums += [name.split("layer.")[1].split(".")[0]]
48
+ return int(max(layer_nums))+1
49
+
50
+ def flatten_list(megalist):
51
+ return [item for sublist in megalist for item in sublist]
52
+
53
+ def forward_pass_single_cell(model, example_cell, layer_to_quant):
54
+ example_cell.set_format(type="torch")
55
+ input_data = example_cell["input_ids"]
56
+ with torch.no_grad():
57
+ outputs = model(
58
+ input_ids = input_data.to("cuda")
59
+ )
60
+ emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
61
+ del outputs
62
+ return emb
63
+
64
+ def perturb_emb_by_index(emb, indices):
65
+ mask = torch.ones(emb.numel(), dtype=torch.bool)
66
+ mask[indices] = False
67
+ return emb[mask]
68
+
69
+ def delete_index(example):
70
+ indexes = example["perturb_index"]
71
+ if len(indexes)>1:
72
+ indexes = flatten_list(indexes)
73
+ for index in sorted(indexes, reverse=True):
74
+ del example["input_ids"][index]
75
+ return example
76
+
77
+ def overexpress_index(example):
78
+ indexes = example["perturb_index"]
79
+ if len(indexes)>1:
80
+ indexes = flatten_list(indexes)
81
+ for index in sorted(indexes, reverse=True):
82
+ example["input_ids"].insert(0, example["input_ids"].pop(index))
83
+ return example
84
+
85
+ def make_perturbation_batch(example_cell,
86
+ perturb_type,
87
+ tokens_to_perturb,
88
+ anchor_token,
89
+ combo_lvl,
90
+ num_proc):
91
+ if tokens_to_perturb == "all":
92
+ if perturb_type in ["overexpress","activate"]:
93
+ range_start = 1
94
+ elif perturb_type in ["delete","inhibit"]:
95
+ range_start = 0
96
+ indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
97
+ elif combo_lvl>0 and (anchor_token is not None):
98
+ example_input_ids = example_cell["input_ids "][0]
99
+ anchor_index = example_input_ids.index(anchor_token[0])
100
+ indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
101
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
102
+ else:
103
+ example_input_ids = example_cell["input_ids"][0]
104
+ indices_to_perturb = [[example_input_ids.index(token)] if token in example_input_ids else None for token in tokens_to_perturb]
105
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
106
+
107
+ # create all permutations of combo_lvl of modifiers from tokens_to_perturb
108
+ if combo_lvl>0 and (anchor_token is None):
109
+ if tokens_to_perturb != "all":
110
+ if len(tokens_to_perturb) == combo_lvl+1:
111
+ indices_to_perturb = [list(x) for x in it.combinations(indices_to_perturb, combo_lvl+1)]
112
+ else:
113
+ all_indices = [[i] for i in range(example_cell["length"][0])]
114
+ all_indices = [index for index in all_indices if index not in indices_to_perturb]
115
+ indices_to_perturb = [[[j for i in indices_to_perturb for j in i], x] for x in all_indices]
116
+ length = len(indices_to_perturb)
117
+ perturbation_dataset = Dataset.from_dict({"input_ids": example_cell["input_ids"]*length, "perturb_index": indices_to_perturb})
118
+ if length<400:
119
+ num_proc_i = 1
120
+ else:
121
+ num_proc_i = num_proc
122
+ if perturb_type == "delete":
123
+ perturbation_dataset = perturbation_dataset.map(delete_index, num_proc=num_proc_i)
124
+ elif perturb_type == "overexpress":
125
+ perturbation_dataset = perturbation_dataset.map(overexpress_index, num_proc=num_proc_i)
126
+ return perturbation_dataset, indices_to_perturb
127
+
128
+ # original cell emb removing the respective perturbed gene emb
129
+ def make_comparison_batch(original_emb, indices_to_perturb):
130
+ all_embs_list = []
131
+ for indices in indices_to_perturb:
132
+ emb_list = []
133
+ start = 0
134
+ if len(indices)>1 and isinstance(indices[0],list):
135
+ indices = flatten_list(indices)
136
+ for i in sorted(indices):
137
+ emb_list += [original_emb[start:i]]
138
+ start = i+1
139
+ emb_list += [original_emb[start:]]
140
+ all_embs_list += [torch.cat(emb_list)]
141
+ return torch.stack(all_embs_list)
142
+
143
+ # average embedding position of goal cell states
144
+ def get_cell_state_avg_embs(model,
145
+ filtered_input_data,
146
+ cell_states_to_model,
147
+ layer_to_quant,
148
+ token_dictionary,
149
+ forward_batch_size,
150
+ num_proc):
151
+ possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
152
+ state_embs_dict = dict()
153
+ for possible_state in possible_states:
154
+ state_embs_list = []
155
+
156
+ def filter_states(example):
157
+ return example[list(cell_states_to_model.keys())[0]] in [possible_state]
158
+ filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
159
+ total_batch_length = len(filtered_input_data_state)
160
+ if ((total_batch_length-1)/forward_batch_size).is_integer():
161
+ forward_batch_size = forward_batch_size-1
162
+ max_len = max(filtered_input_data_state["length"])
163
+ for i in range(0, total_batch_length, forward_batch_size):
164
+ max_range = min(i+forward_batch_size, total_batch_length)
165
+
166
+ state_minibatch = filtered_input_data_state.select([i for i in range(i, max_range)])
167
+ state_minibatch.set_format(type="torch")
168
+
169
+ input_data_minibatch = state_minibatch["input_ids"]
170
+ input_data_minibatch = pad_tensor_list(input_data_minibatch, max_len, token_dictionary)
171
+
172
+ with torch.no_grad():
173
+ outputs = model(
174
+ input_ids = input_data_minibatch.to("cuda")
175
+ )
176
+
177
+ state_embs_i = outputs.hidden_states[layer_to_quant]
178
+ state_embs_list += [state_embs_i]
179
+ del outputs
180
+ del state_minibatch
181
+ del input_data_minibatch
182
+ del state_embs_i
183
+ torch.cuda.empty_cache()
184
+ state_embs_stack = torch.cat(state_embs_list)
185
+ avg_state_emb = torch.mean(state_embs_stack,dim=[0,1],keepdim=True)
186
+ state_embs_dict[possible_state] = avg_state_emb
187
+ return state_embs_dict
188
+
189
+ # quantify cosine similarity of perturbed vs original or alternate states
190
+ def quant_cos_sims(model,
191
+ perturbation_batch,
192
+ forward_batch_size,
193
+ layer_to_quant,
194
+ original_emb,
195
+ indices_to_perturb,
196
+ cell_states_to_model,
197
+ state_embs_dict):
198
+ cos = torch.nn.CosineSimilarity(dim=2)
199
+ total_batch_length = len(perturbation_batch)
200
+ if ((total_batch_length-1)/forward_batch_size).is_integer():
201
+ forward_batch_size = forward_batch_size-1
202
+ if cell_states_to_model is None:
203
+ comparison_batch = make_comparison_batch(original_emb, indices_to_perturb)
204
+ cos_sims = []
205
+ else:
206
+ possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
207
+ cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
208
+ for i in range(0, total_batch_length, forward_batch_size):
209
+ max_range = min(i+forward_batch_size, total_batch_length)
210
+
211
+ perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
212
+ perturbation_minibatch.set_format(type="torch")
213
+
214
+ input_data_minibatch = perturbation_minibatch["input_ids"]
215
+
216
+ with torch.no_grad():
217
+ outputs = model(
218
+ input_ids = input_data_minibatch.to("cuda")
219
+ )
220
+ del input_data_minibatch
221
+ del perturbation_minibatch
222
+ # cosine similarity between original emb and batch items
223
+ if len(indices_to_perturb)>1:
224
+ minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
225
+ else:
226
+ minibatch_emb = outputs.hidden_states[layer_to_quant]
227
+ if cell_states_to_model is None:
228
+ minibatch_comparison = comparison_batch[i:max_range]
229
+ cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
230
+ else:
231
+ for state in possible_states:
232
+ cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
233
+ del outputs
234
+ del minibatch_emb
235
+ if cell_states_to_model is None:
236
+ del minibatch_comparison
237
+ torch.cuda.empty_cache()
238
+ if cell_states_to_model is None:
239
+ cos_sims_stack = torch.cat(cos_sims)
240
+ return cos_sims_stack
241
+ else:
242
+ for state in possible_states:
243
+ cos_sims_vs_alt_dict[state] = torch.cat(cos_sims_vs_alt_dict[state])
244
+ return cos_sims_vs_alt_dict
245
+
246
+ # calculate cos sim shift of perturbation with respect to origin and alternative cell
247
+ def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
248
+ cos = torch.nn.CosineSimilarity(dim=2)
249
+ original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
250
+ alt_emb = alt_emb[None, None, :]
251
+ origin_v_end = cos(original_emb,alt_emb)
252
+ perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
253
+ return [(perturb_v_end-origin_v_end).to("cpu")]
254
+
255
+ # pad list of tensors and convert to tensor
256
+ def pad_tensor_list(tensor_list, dynamic_or_constant, token_dictionary):
257
+
258
+ pad_token_id = token_dictionary.get("<pad>")
259
+
260
+ # Determine maximum tensor length
261
+ if dynamic_or_constant == "dynamic":
262
+ max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
263
+ elif type(dynamic_or_constant) == int:
264
+ max_len = dynamic_or_constant
265
+ else:
266
+ logger.warning(
267
+ "If padding style is constant, must provide integer value. " \
268
+ "Setting padding to max input size 2048.")
269
+
270
+ # pad all tensors to maximum length
271
+ tensor_list = [torch.nn.functional.pad(tensor, pad=(0,
272
+ max_len - tensor.numel()),
273
+ mode='constant',
274
+ value=pad_token_id) for tensor in tensor_list]
275
+
276
+ # return stacked tensors
277
+ return torch.stack(tensor_list)
278
+
279
+ class InSilicoPerturber:
280
+ valid_option_dict = {
281
+ "perturb_type": {"delete","overexpress","inhibit","activate"},
282
+ "perturb_rank_shift": {None, int},
283
+ "genes_to_perturb": {"all", list},
284
+ "combos": {0,1,2},
285
+ "anchor_gene": {None, str},
286
+ "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
287
+ "num_classes": {int},
288
+ "emb_mode": {"cell","cell_and_gene"},
289
+ "cell_emb_style": {"mean_pool"},
290
+ "filter_data": {None, dict},
291
+ "cell_states_to_model": {None, dict},
292
+ "max_ncells": {None, int},
293
+ "emb_layer": {-1, 0},
294
+ "forward_batch_size": {int},
295
+ "nproc": {int},
296
+ "save_raw_data": {False, True},
297
+ }
298
+ def __init__(
299
+ self,
300
+ perturb_type="delete",
301
+ perturb_rank_shift=None,
302
+ genes_to_perturb="all",
303
+ combos=0,
304
+ anchor_gene=None,
305
+ model_type="Pretrained",
306
+ num_classes=0,
307
+ emb_mode="cell",
308
+ cell_emb_style="mean_pool",
309
+ filter_data=None,
310
+ cell_states_to_model=None,
311
+ max_ncells=None,
312
+ emb_layer=-1,
313
+ forward_batch_size=100,
314
+ nproc=4,
315
+ save_raw_data=False,
316
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
317
+ ):
318
+ """
319
+ Initialize in silico perturber.
320
+
321
+ Parameters
322
+ ----------
323
+ perturb_type : {"delete","overexpress","inhibit","activate"}
324
+ Type of perturbation.
325
+ "delete": delete gene from rank value encoding
326
+ "overexpress": move gene to front of rank value encoding
327
+ "inhibit": move gene to lower quartile of rank value encoding
328
+ "activate": move gene to higher quartile of rank value encoding
329
+ perturb_rank_shift : None, int
330
+ Number of quartiles by which to shift rank of gene.
331
+ For example, if perturb_type="activate" and perturb_rank_shift=1:
332
+ genes in 4th quartile will move to middle of 3rd quartile.
333
+ genes in 3rd quartile will move to middle of 2nd quartile.
334
+ genes in 2nd quartile will move to middle of 1st quartile.
335
+ genes in 1st quartile will move to front of rank value encoding.
336
+ For example, if perturb_type="inhibit" and perturb_rank_shift=2:
337
+ genes in 1st quartile will move to middle of 3rd quartile.
338
+ genes in 2nd quartile will move to middle of 4th quartile.
339
+ genes in 3rd or 4th quartile will move to bottom of rank value encoding.
340
+ genes_to_perturb : "all", list
341
+ Default is perturbing each gene detected in each cell in the dataset.
342
+ Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
343
+ combos : {0,1,2}
344
+ Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
345
+ anchor_gene : None, str
346
+ ENSEMBL ID of gene to use as anchor in combination perturbations.
347
+ For example, if combos=1 and anchor_gene="ENSG00000148400":
348
+ anchor gene will be perturbed in combination with each other gene.
349
+ model_type : {"Pretrained","GeneClassifier","CellClassifier"}
350
+ Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
351
+ num_classes : int
352
+ If model is a gene or cell classifier, specify number of classes it was trained to classify.
353
+ For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
354
+ emb_mode : {"cell","cell_and_gene"}
355
+ Whether to output impact of perturbation on cell and/or gene embeddings.
356
+ cell_emb_style : "mean_pool"
357
+ Method for summarizing cell embeddings.
358
+ Currently only option is mean pooling of gene embeddings for given cell.
359
+ filter_data : None, dict
360
+ Default is to use all input data for in silico perturbation study.
361
+ Otherwise, dictionary specifying .dataset column name and list of values to filter by.
362
+ cell_states_to_model: None, dict
363
+ Cell states to model if testing perturbations that achieve goal state change.
364
+ Single-item dictionary with key being cell attribute (e.g. "disease").
365
+ Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
366
+ max_ncells : None, int
367
+ Maximum number of cells to test.
368
+ If None, will test all cells.
369
+ emb_layer : {-1, 0}
370
+ Embedding layer to use for quantification.
371
+ -1: 2nd to last layer (recommended for pretrained Geneformer)
372
+ 0: last layer (recommended for cell classifier fine-tuned for disease state)
373
+ forward_batch_size : int
374
+ Batch size for forward pass.
375
+ nproc : int
376
+ Number of CPU processes to use.
377
+ save_raw_data: {False,True}
378
+ Whether to save raw perturbation data for each gene/cell.
379
+ token_dictionary_file : Path
380
+ Path to pickle file containing token dictionary (Ensembl ID:token).
381
+ """
382
+
383
+ self.perturb_type = perturb_type
384
+ self.perturb_rank_shift = perturb_rank_shift
385
+ self.genes_to_perturb = genes_to_perturb
386
+ self.combos = combos
387
+ self.anchor_gene = anchor_gene
388
+ self.model_type = model_type
389
+ self.num_classes = num_classes
390
+ self.emb_mode = emb_mode
391
+ self.cell_emb_style = cell_emb_style
392
+ self.filter_data = filter_data
393
+ self.cell_states_to_model = cell_states_to_model
394
+ self.max_ncells = max_ncells
395
+ self.emb_layer = emb_layer
396
+ self.forward_batch_size = forward_batch_size
397
+ self.nproc = nproc
398
+ self.save_raw_data = save_raw_data
399
+
400
+ self.validate_options()
401
+
402
+ # load token dictionary (Ensembl IDs:token)
403
+ with open(token_dictionary_file, "rb") as f:
404
+ self.gene_token_dict = pickle.load(f)
405
+
406
+ if anchor_gene is None:
407
+ self.anchor_token = None
408
+ else:
409
+ self.anchor_token = self.gene_token_dict[self.anchor_gene]
410
+
411
+ if genes_to_perturb == "all":
412
+ self.tokens_to_perturb = "all"
413
+ else:
414
+ self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
415
+
416
+ def validate_options(self):
417
+ for attr_name,valid_options in self.valid_option_dict.items():
418
+ attr_value = self.__dict__[attr_name]
419
+ if type(attr_value) not in {list, dict}:
420
+ if attr_value in valid_options:
421
+ continue
422
+ valid_type = False
423
+ for option in valid_options:
424
+ if (option in [int,list,dict]) and isinstance(attr_value, option):
425
+ valid_type = True
426
+ break
427
+ if valid_type:
428
+ continue
429
+ logger.error(
430
+ f"Invalid option for {attr_name}. " \
431
+ f"Valid options for {attr_name}: {valid_options}"
432
+ )
433
+ raise
434
+
435
+ if self.perturb_type in ["delete","overexpress"]:
436
+ if self.perturb_rank_shift is not None:
437
+ if self.perturb_type == "delete":
438
+ logger.warning(
439
+ "perturb_rank_shift set to None. " \
440
+ "If perturb type is delete then gene is deleted entirely " \
441
+ "rather than shifted by quartile")
442
+ elif self.perturb_type == "overexpress":
443
+ logger.warning(
444
+ "perturb_rank_shift set to None. " \
445
+ "If perturb type is activate then gene is moved to front " \
446
+ "of rank value encoding rather than shifted by quartile")
447
+ self.perturb_rank_shift = None
448
+
449
+ if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
450
+ self.emb_mode = "cell"
451
+ logger.warning(
452
+ "emb_mode set to 'cell'. " \
453
+ "Currently, analysis with anchor gene " \
454
+ "only outputs effect on cell embeddings.")
455
+
456
+ if self.cell_states_to_model is not None:
457
+ if (len(self.cell_states_to_model.items()) == 1):
458
+ for key,value in self.cell_states_to_model.items():
459
+ if (len(value) == 3) and isinstance(value, tuple):
460
+ if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
461
+ if len(value[0]) == 1 and len(value[1]) == 1:
462
+ all_values = value[0]+value[1]+value[2]
463
+ if len(all_values) == len(set(all_values)):
464
+ continue
465
+ else:
466
+ logger.error(
467
+ "Cell states to model must be a single-item dictionary with " \
468
+ "key being cell attribute (e.g. 'disease') and value being " \
469
+ "tuple of three lists indicating start state, goal end state, and alternate possible end states. " \
470
+ "Values should all be unique. " \
471
+ "For example: {'disease':(['dcm'],['ctrl'],['hcm'])}")
472
+ raise
473
+ if self.anchor_gene is not None:
474
+ self.anchor_gene = None
475
+ logger.warning(
476
+ "anchor_gene set to None. " \
477
+ "Currently, anchor gene not available " \
478
+ "when modeling multiple cell states.")
479
+
480
+ if self.perturb_type in ["inhibit","activate"]:
481
+ if self.perturb_rank_shift is None:
482
+ logger.error(
483
+ "If perturb type is inhibit or activate then " \
484
+ "quartile to shift by must be specified.")
485
+ raise
486
+
487
+ for key,value in self.filter_data.items():
488
+ if type(value) != list:
489
+ self.filter_data[key] = [value]
490
+ logger.warning(
491
+ "Values in filter_data dict must be lists. " \
492
+ f"Changing {key} value to list ([{value}]).")
493
+
494
+ def perturb_data(self,
495
+ model_directory,
496
+ input_data_file,
497
+ output_directory,
498
+ output_prefix):
499
+ """
500
+ Perturb genes in input data and save as results in output_directory.
501
+
502
+ Parameters
503
+ ----------
504
+ model_directory : Path
505
+ Path to directory containing model
506
+ input_data_file : Path
507
+ Path to directory containing .dataset inputs
508
+ output_directory : Path
509
+ Path to directory where perturbation data will be saved as .csv
510
+ output_prefix : str
511
+ Prefix for output .dataset
512
+ """
513
+
514
+ filtered_input_data = self.load_and_filter(input_data_file)
515
+ model = self.load_model(model_directory)
516
+ layer_to_quant = quant_layers(model)+self.emb_layer
517
+
518
+ if self.cell_states_to_model is None:
519
+ state_embs_dict = None
520
+ else:
521
+ # get dictionary of average cell state embeddings for comparison
522
+ state_embs_dict = get_cell_state_avg_embs(model,
523
+ filtered_input_data,
524
+ self.cell_states_to_model,
525
+ layer_to_quant,
526
+ self.gene_token_dict,
527
+ self.forward_batch_size,
528
+ self.nproc)
529
+ self.in_silico_perturb(model,
530
+ filtered_input_data,
531
+ layer_to_quant,
532
+ state_embs_dict,
533
+ output_directory,
534
+ output_prefix)
535
+
536
+ # if self.save_raw_data is False:
537
+ # # delete intermediate dictionaries
538
+ # output_dir = os.listdir(output_directory)
539
+ # for output_file in output_dir:
540
+ # if output_file.endswith("_raw.pickle"):
541
+ # os.remove(os.path.join(output_directory, output_file))
542
+
543
+ # load data and filter by defined criteria
544
+ def load_and_filter(self, input_data_file):
545
+ data = load_from_disk(input_data_file)
546
+ for key,value in self.filter_data.items():
547
+ def filter_data(example):
548
+ return example[key] in value
549
+ data = data.filter(filter_data, num_proc=self.nproc)
550
+ if len(data) == 0:
551
+ logger.error(
552
+ "No cells remain after filtering. Check filtering criteria.")
553
+ raise
554
+ data_shuffled = data.shuffle(seed=42)
555
+ num_cells = len(data_shuffled)
556
+ # if max number of cells is defined, then subsample to this max number
557
+ if self.max_ncells != None:
558
+ num_cells = min(self.max_ncells,num_cells)
559
+ data_subset = data_shuffled.select([i for i in range(num_cells)])
560
+ # sort dataset with largest cell first to encounter any memory errors earlier
561
+ data_sorted = data_subset.sort("length",reverse=True)
562
+ return data_sorted
563
+
564
+ # load model to GPU
565
+ def load_model(self, model_directory):
566
+ if self.model_type == "Pretrained":
567
+ model = BertForMaskedLM.from_pretrained(model_directory,
568
+ output_hidden_states=True,
569
+ output_attentions=False)
570
+ elif self.model_type == "GeneClassifier":
571
+ model = BertForTokenClassification.from_pretrained(model_directory,
572
+ num_labels=self.num_classes,
573
+ output_hidden_states=True,
574
+ output_attentions=False)
575
+ elif self.model_type == "CellClassifier":
576
+ model = BertForSequenceClassification.from_pretrained(model_directory,
577
+ num_labels=self.num_classes,
578
+ output_hidden_states=True,
579
+ output_attentions=False)
580
+ # put the model in eval mode for fwd pass
581
+ model.eval()
582
+ model = model.to("cuda:0")
583
+ return model
584
+
585
+ # determine effect of perturbation on other genes
586
+ def in_silico_perturb(self,
587
+ model,
588
+ filtered_input_data,
589
+ layer_to_quant,
590
+ state_embs_dict,
591
+ output_directory,
592
+ output_prefix):
593
+
594
+ output_path_prefix = f"{output_directory}in_silico_{self.perturb_type}_{output_prefix}_dict_1Kbatch"
595
+
596
+ # filter dataset for cells that have tokens to be perturbed
597
+ if self.anchor_token is not None:
598
+ def if_has_tokens_to_perturb(example):
599
+ return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
600
+ filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
601
+ logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
602
+ if self.tokens_to_perturb != "all":
603
+ def if_has_tokens_to_perturb(example):
604
+ return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>self.combos)
605
+ filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
606
+
607
+ cos_sims_dict = defaultdict(list)
608
+ pickle_batch = -1
609
+
610
+ for i in trange(len(filtered_input_data)):
611
+ example_cell = filtered_input_data.select([i])
612
+ original_emb = forward_pass_single_cell(model, example_cell, layer_to_quant)
613
+ gene_list = torch.squeeze(example_cell["input_ids"])
614
+
615
+ # reset to original type to prevent downstream issues due to forward_pass_single_cell modifying as torch format in place
616
+ example_cell = filtered_input_data.select([i])
617
+
618
+ if self.anchor_token is None:
619
+ for combo_lvl in range(self.combos+1):
620
+ perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
621
+ self.perturb_type,
622
+ self.tokens_to_perturb,
623
+ self.anchor_token,
624
+ combo_lvl,
625
+ self.nproc)
626
+ cos_sims_data = quant_cos_sims(model,
627
+ perturbation_batch,
628
+ self.forward_batch_size,
629
+ layer_to_quant,
630
+ original_emb,
631
+ indices_to_perturb,
632
+ self.cell_states_to_model,
633
+ state_embs_dict)
634
+
635
+ if self.cell_states_to_model is None:
636
+ # update cos sims dict
637
+ # key is tuple of (perturbed_gene, affected_gene)
638
+ # or (perturbed_gene, "cell_emb") for avg cell emb change
639
+ cos_sims_data = cos_sims_data.to("cuda")
640
+ for j in range(cos_sims_data.shape[0]):
641
+ if self.genes_to_perturb != "all":
642
+ j_index = torch.tensor(indices_to_perturb[j])
643
+ if j_index.shape[0]>1:
644
+ j_index = torch.squeeze(j_index)
645
+ else:
646
+ j_index = torch.tensor([j])
647
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
648
+
649
+ if perturbed_gene.shape[0]==1:
650
+ perturbed_gene = perturbed_gene.item()
651
+ elif perturbed_gene.shape[0]>1:
652
+ perturbed_gene = tuple(perturbed_gene.tolist())
653
+
654
+ cell_cos_sim = torch.mean(cos_sims_data[j]).item()
655
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [cell_cos_sim]
656
+
657
+ # not_j_index = list(set(i for i in range(gene_list.shape[0])).difference(j_index))
658
+ # gene_list_j = torch.index_select(gene_list, 0, j_index)
659
+ if self.emb_mode == "cell_and_gene":
660
+ for k in range(cos_sims_data.shape[1]):
661
+ cos_sim_value = cos_sims_data[j][k]
662
+ affected_gene = gene_list[k].item()
663
+ cos_sims_dict[(perturbed_gene, affected_gene)] += [cos_sim_value.item()]
664
+ else:
665
+ # update cos sims dict
666
+ # key is tuple of (perturbed_gene, "cell_emb")
667
+ # value is list of tuples of cos sims for cell_states_to_model
668
+ origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
669
+ cos_sims_origin = cos_sims_data[origin_state_key]
670
+
671
+ for j in range(cos_sims_origin.shape[0]):
672
+ if (self.genes_to_perturb != "all") or (combo_lvl>0):
673
+ j_index = torch.tensor(indices_to_perturb[j])
674
+ if j_index.shape[0]>1:
675
+ j_index = torch.squeeze(j_index)
676
+ else:
677
+ j_index = torch.tensor([j])
678
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
679
+
680
+ if perturbed_gene.shape[0]==1:
681
+ perturbed_gene = perturbed_gene.item()
682
+ elif perturbed_gene.shape[0]>1:
683
+ perturbed_gene = tuple(perturbed_gene.tolist())
684
+
685
+ data_list = []
686
+ for data in list(cos_sims_data.values()):
687
+ data_item = data.to("cuda")
688
+ cell_data = torch.mean(data_item[j]).item()
689
+ data_list += [cell_data]
690
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [tuple(data_list)]
691
+
692
+ elif self.anchor_token is not None:
693
+ perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
694
+ self.perturb_type,
695
+ self.tokens_to_perturb,
696
+ None, # first run without anchor token to test individual gene perturbations
697
+ 0,
698
+ self.nproc)
699
+ cos_sims_data = quant_cos_sims(model,
700
+ perturbation_batch,
701
+ self.forward_batch_size,
702
+ layer_to_quant,
703
+ original_emb,
704
+ indices_to_perturb,
705
+ self.cell_states_to_model,
706
+ state_embs_dict)
707
+ cos_sims_data = cos_sims_data.to("cuda")
708
+
709
+ combo_perturbation_batch, combo_indices_to_perturb = make_perturbation_batch(example_cell,
710
+ self.perturb_type,
711
+ self.tokens_to_perturb,
712
+ self.anchor_token,
713
+ 1,
714
+ self.nproc)
715
+ combo_cos_sims_data = quant_cos_sims(model,
716
+ combo_perturbation_batch,
717
+ self.forward_batch_size,
718
+ layer_to_quant,
719
+ original_emb,
720
+ combo_indices_to_perturb,
721
+ self.cell_states_to_model,
722
+ state_embs_dict)
723
+ combo_cos_sims_data = combo_cos_sims_data.to("cuda")
724
+
725
+ # update cos sims dict
726
+ # key is tuple of (perturbed_gene, "cell_emb") for avg cell emb change
727
+ anchor_index = example_cell["input_ids"][0].index(self.anchor_token[0])
728
+ anchor_cell_cos_sim = torch.mean(cos_sims_data[anchor_index]).item()
729
+ non_anchor_indices = [k for k in range(cos_sims_data.shape[0]) if k != anchor_index]
730
+ cos_sims_data = cos_sims_data[non_anchor_indices,:]
731
+
732
+ for j in range(cos_sims_data.shape[0]):
733
+
734
+ if j<anchor_index:
735
+ j_index = torch.tensor([j])
736
+ else:
737
+ j_index = torch.tensor([j+1])
738
+
739
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
740
+ perturbed_gene = perturbed_gene.item()
741
+
742
+ cell_cos_sim = torch.mean(cos_sims_data[j]).item()
743
+ combo_cos_sim = torch.mean(combo_cos_sims_data[j]).item()
744
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [(anchor_cell_cos_sim, # cos sim anchor gene alone
745
+ cell_cos_sim, # cos sim deleted gene alone
746
+ combo_cos_sim)] # cos sim anchor gene + deleted gene
747
+
748
+ # save dict to disk every 100 cells
749
+ if (i/100).is_integer():
750
+ with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
751
+ pickle.dump(cos_sims_dict, fp)
752
+ # reset and clear memory every 1000 cells
753
+ if (i/1000).is_integer():
754
+ pickle_batch = pickle_batch+1
755
+ # clear memory
756
+ del perturbed_gene
757
+ del cos_sims_data
758
+ if self.cell_states_to_model is None:
759
+ del cell_cos_sim
760
+ if self.cell_states_to_model is not None:
761
+ del cell_data
762
+ del data_list
763
+ elif self.anchor_token is None:
764
+ del affected_gene
765
+ del cos_sim_value
766
+ else:
767
+ del combo_cos_sim
768
+ del combo_cos_sims_data
769
+ # reset dict
770
+ del cos_sims_dict
771
+ cos_sims_dict = defaultdict(list)
772
+ torch.cuda.empty_cache()
773
+
774
+ # save remainder cells
775
+ with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
776
+ pickle.dump(cos_sims_dict, fp)
777
+
geneformer/in_silico_perturber_stats.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer in silico perturber stats generator.
3
+
4
+ Usage:
5
+ from geneformer import InSilicoPerturberStats
6
+ ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
+ combos=0,
8
+ anchor_gene=None,
9
+ cell_states_to_model={"disease":(["dcm"],["ctrl"],["hcm"])})
10
+ ispstats.get_stats("path/to/input_data",
11
+ None,
12
+ "path/to/output_directory",
13
+ "output_prefix")
14
+ """
15
+
16
+
17
+ import os
18
+ import logging
19
+ import numpy as np
20
+ import pandas as pd
21
+ import pickle
22
+ import statsmodels.stats.multitest as smt
23
+ from pathlib import Path
24
+ from scipy.stats import ranksums
25
+ from tqdm.notebook import trange
26
+
27
+ from .tokenizer import TOKEN_DICTIONARY_FILE
28
+
29
+ GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # invert dictionary keys/values
34
+ def invert_dict(dictionary):
35
+ return {v: k for k, v in dictionary.items()}
36
+
37
+ # read raw dictionary files
38
+ def read_dictionaries(dir, cell_or_gene_emb):
39
+ dict_list = []
40
+ for file in os.listdir(dir):
41
+ # process only _raw.pickle files
42
+ if file.endswith("_raw.pickle"):
43
+ with open(f"{dir}/{file}", "rb") as fp:
44
+ cos_sims_dict = pickle.load(fp)
45
+ if cell_or_gene_emb == "cell":
46
+ cell_emb_dict = {k: v for k,
47
+ v in cos_sims_dict.items() if v and "cell_emb" in k}
48
+ dict_list += [cell_emb_dict]
49
+ return dict_list
50
+
51
+ # get complete gene list
52
+ def get_gene_list(dict_list):
53
+ gene_set = set()
54
+ for dict_i in dict_list:
55
+ gene_set.update([k[0] for k, v in dict_i.items() if v])
56
+ gene_list = list(gene_set)
57
+ gene_list.sort()
58
+ return gene_list
59
+
60
+ def n_detections(token, dict_list):
61
+ cos_sim_megalist = []
62
+ for dict_i in dict_list:
63
+ cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
64
+ return len(cos_sim_megalist)
65
+
66
+ def get_fdr(pvalues):
67
+ return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
68
+
69
+ def isp_stats(cos_sims_df, dict_list, cell_states_to_model):
70
+
71
+ random_tuples = []
72
+ for i in trange(cos_sims_df.shape[0]):
73
+ token = cos_sims_df["Gene"][i]
74
+ for dict_i in dict_list:
75
+ random_tuples += dict_i.get((token, "cell_emb"),[])
76
+ goal_end_random_megalist = [goal_end for goal_end,alt_end,start_state in random_tuples]
77
+ alt_end_random_megalist = [alt_end for goal_end,alt_end,start_state in random_tuples]
78
+ start_state_random_megalist = [start_state for goal_end,alt_end,start_state in random_tuples]
79
+
80
+ # downsample to improve speed of ranksums
81
+ if len(goal_end_random_megalist) > 100_000:
82
+ random.seed(42)
83
+ goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
84
+ if len(alt_end_random_megalist) > 100_000:
85
+ random.seed(42)
86
+ alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
87
+ if len(start_state_random_megalist) > 100_000:
88
+ random.seed(42)
89
+ start_state_random_megalist = random.sample(start_state_random_megalist, k=100_000)
90
+
91
+ names=["Gene",
92
+ "Gene_name",
93
+ "Ensembl_ID",
94
+ "Shift_from_goal_end",
95
+ "Shift_from_alt_end",
96
+ "Goal_end_vs_random_pval",
97
+ "Alt_end_vs_random_pval"]
98
+ cos_sims_full_df = pd.DataFrame(columns=names)
99
+
100
+ for i in trange(cos_sims_df.shape[0]):
101
+ token = cos_sims_df["Gene"][i]
102
+ name = cos_sims_df["Gene_name"][i]
103
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
104
+ token_tuples = []
105
+
106
+ for dict_i in dict_list:
107
+ token_tuples += dict_i.get((token, "cell_emb"),[])
108
+
109
+ goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end,start_state in token_tuples]
110
+ alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end,start_state in token_tuples]
111
+
112
+ mean_goal_end = np.mean(goal_end_cos_sim_megalist)
113
+ mean_alt_end = np.mean(alt_end_cos_sim_megalist)
114
+
115
+ pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
116
+ pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
117
+
118
+ data_i = [token,
119
+ name,
120
+ ensembl_id,
121
+ mean_goal_end,
122
+ mean_alt_end,
123
+ pval_goal_end,
124
+ pval_alt_end]
125
+
126
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
127
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
128
+
129
+ cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
130
+ cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
131
+
132
+ return cos_sims_full_df
133
+
134
+ class InSilicoPerturberStats:
135
+ valid_option_dict = {
136
+ "mode": {"goal_state_shift","vs_null","vs_random"},
137
+ "combos": {0,1,2},
138
+ "anchor_gene": {None, str},
139
+ "cell_states_to_model": {None, dict},
140
+ }
141
+ def __init__(
142
+ self,
143
+ mode="vs_random",
144
+ combos=0,
145
+ anchor_gene=None,
146
+ cell_states_to_model=None,
147
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
148
+ gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
149
+ ):
150
+ """
151
+ Initialize in silico perturber stats generator.
152
+
153
+ Parameters
154
+ ----------
155
+ mode : {"goal_state_shift","vs_null","vs_random"}
156
+ Type of stats.
157
+ "goal_state_shift": perturbation vs. random for desired cell state shift
158
+ "vs_null": perturbation vs. null from provided null distribution dataset
159
+ "vs_random": perturbation vs. random gene perturbations in that cell (no goal direction)
160
+ combos : {0,1,2}
161
+ Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
162
+ anchor_gene : None, str
163
+ ENSEMBL ID of gene to use as anchor in combination perturbations.
164
+ For example, if combos=1 and anchor_gene="ENSG00000148400":
165
+ anchor gene will be perturbed in combination with each other gene.
166
+ cell_states_to_model: None, dict
167
+ Cell states to model if testing perturbations that achieve goal state change.
168
+ Single-item dictionary with key being cell attribute (e.g. "disease").
169
+ Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
170
+ token_dictionary_file : Path
171
+ Path to pickle file containing token dictionary (Ensembl ID:token).
172
+ gene_name_id_dictionary_file : Path
173
+ Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
174
+ """
175
+
176
+ self.mode = mode
177
+ self.combos = combos
178
+ self.anchor_gene = anchor_gene
179
+ self.cell_states_to_model = cell_states_to_model
180
+
181
+ self.validate_options()
182
+
183
+ # load token dictionary (Ensembl IDs:token)
184
+ with open(token_dictionary_file, "rb") as f:
185
+ self.gene_token_dict = pickle.load(f)
186
+
187
+ # load gene name dictionary (gene name:Ensembl ID)
188
+ with open(gene_name_id_dictionary_file, "rb") as f:
189
+ self.gene_name_id_dict = pickle.load(f)
190
+
191
+ if anchor_gene is None:
192
+ self.anchor_token = None
193
+ else:
194
+ self.anchor_token = self.gene_token_dict[self.anchor_gene]
195
+
196
+ def validate_options(self):
197
+ for attr_name,valid_options in self.valid_option_dict.items():
198
+ attr_value = self.__dict__[attr_name]
199
+ if type(attr_value) not in {list, dict}:
200
+ if attr_value in valid_options:
201
+ continue
202
+ valid_type = False
203
+ for option in valid_options:
204
+ if (option in [int,list,dict]) and isinstance(attr_value, option):
205
+ valid_type = True
206
+ break
207
+ if valid_type:
208
+ continue
209
+ logger.error(
210
+ f"Invalid option for {attr_name}. " \
211
+ f"Valid options for {attr_name}: {valid_options}"
212
+ )
213
+ raise
214
+
215
+ if self.cell_states_to_model is not None:
216
+ if (len(self.cell_states_to_model.items()) == 1):
217
+ for key,value in self.cell_states_to_model.items():
218
+ if (len(value) == 3) and isinstance(value, tuple):
219
+ if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
220
+ if len(value[0]) == 1 and len(value[1]) == 1:
221
+ all_values = value[0]+value[1]+value[2]
222
+ if len(all_values) == len(set(all_values)):
223
+ continue
224
+ else:
225
+ logger.error(
226
+ "Cell states to model must be a single-item dictionary with " \
227
+ "key being cell attribute (e.g. 'disease') and value being " \
228
+ "tuple of three lists indicating start state, goal end state, and alternate possible end states. " \
229
+ "Values should all be unique. " \
230
+ "For example: {'disease':(['start_state'],['ctrl'],['alt_end'])}")
231
+ raise
232
+ if self.anchor_gene is not None:
233
+ self.anchor_gene = None
234
+ logger.warning(
235
+ "anchor_gene set to None. " \
236
+ "Currently, anchor gene not available " \
237
+ "when modeling multiple cell states.")
238
+
239
+ def get_stats(self,
240
+ input_data_directory,
241
+ null_dist_data_directory,
242
+ output_directory,
243
+ output_prefix):
244
+ """
245
+ Get stats for in silico perturbation data and save as results in output_directory.
246
+
247
+ Parameters
248
+ ----------
249
+ input_data_directory : Path
250
+ Path to directory containing cos_sim dictionary inputs
251
+ null_dist_data_directory : Path
252
+ Path to directory containing null distribution cos_sim dictionary inputs
253
+ output_directory : Path
254
+ Path to directory where perturbation data will be saved as .csv
255
+ output_prefix : str
256
+ Prefix for output .dataset
257
+ """
258
+
259
+ self.gene_token_id_dict = invert_dict(self.gene_token_dict)
260
+ self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
261
+
262
+ if self.mode == "goal_state_shift":
263
+ dict_list = read_dictionaries(input_data_directory,"cell")
264
+ else:
265
+ logger.error(
266
+ "Currently, only mode available is stats for goal_state_shift.")
267
+ raise
268
+
269
+ # obtain total gene list
270
+ gene_list = get_gene_list(dict_list)
271
+
272
+ # initiate results dataframe
273
+ cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
274
+ "Gene_name": [self.token_to_gene_name(item) \
275
+ for item in gene_list], \
276
+ "Ensembl_ID": [self.gene_token_id_dict[genes[1]] \
277
+ if isinstance(genes,tuple) else \
278
+ self.gene_token_id_dict[genes] \
279
+ for genes in gene_list]}, \
280
+ index=[i for i in range(len(gene_list))])
281
+
282
+ # # add ENSEMBL ID for genes
283
+ # cos_sims_df_initial["Ensembl_ID"] = [self.gene_token_id_dict[genes[1]] if isinstance(genes,tuple) else self.gene_token_id_dict[genes] for genes in list(cos_sims_df_initial["Gene"])]
284
+
285
+ cos_sims_df = isp_stats(cos_sims_df_initial, dict_list, self.cell_states_to_model)
286
+
287
+ # quantify number of detections of each gene
288
+ cos_sims_df["N_Detections"] = [n_detections(i, dict_list) for i in cos_sims_df["Gene"]]
289
+
290
+ # sort by shift to desired state
291
+ cos_sims_df = cos_sims_df.sort_values(by=["Shift_from_goal_end",
292
+ "Goal_end_FDR"])
293
+
294
+ # save perturbation stats to output_path
295
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
296
+ cos_sims_df.to_csv(output_path)
297
+
298
+ def token_to_gene_name(self, item):
299
+ if isinstance(item,int):
300
+ return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan)
301
+ if isinstance(item,tuple):
302
+ return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item])
geneformer/pretrainer.py CHANGED
@@ -377,7 +377,7 @@ class GeneformerPreCollator(SpecialTokensMixin):
377
  return_tensors = "tf" if return_tensors is None else return_tensors
378
  elif is_torch_available() and _is_torch(first_element):
379
  return_tensors = "pt" if return_tensors is None else return_tensors
380
- elif isinstance(first_element, np.ndarray):
381
  return_tensors = "np" if return_tensors is None else return_tensors
382
  else:
383
  raise ValueError(
@@ -387,6 +387,7 @@ class GeneformerPreCollator(SpecialTokensMixin):
387
 
388
  for key, value in encoded_inputs.items():
389
  encoded_inputs[key] = to_py_obj(value)
 
390
 
391
  # Convert padding_strategy in PaddingStrategy
392
  padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
 
377
  return_tensors = "tf" if return_tensors is None else return_tensors
378
  elif is_torch_available() and _is_torch(first_element):
379
  return_tensors = "pt" if return_tensors is None else return_tensors
380
+ if isinstance(first_element, np.ndarray):
381
  return_tensors = "np" if return_tensors is None else return_tensors
382
  else:
383
  raise ValueError(
 
387
 
388
  for key, value in encoded_inputs.items():
389
  encoded_inputs[key] = to_py_obj(value)
390
+
391
 
392
  # Convert padding_strategy in PaddingStrategy
393
  padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(