Christina Theodoris commited on
Commit
2f25aea
1 Parent(s): fd93ebf

Add functions for extracting gene embeddings, move state_embs_dict outside isp, fix bugs in isp

Browse files
examples/in_silico_perturbation.ipynb CHANGED
@@ -8,21 +8,62 @@
8
  "outputs": [],
9
  "source": [
10
  "from geneformer import InSilicoPerturber\n",
11
- "from geneformer import InSilicoPerturberStats"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ]
13
  },
14
  {
15
  "cell_type": "code",
16
  "execution_count": null,
17
- "id": "67b44366-f255-4415-a865-6a27a8ffcce7",
18
  "metadata": {
19
  "tags": []
20
  },
21
  "outputs": [],
22
  "source": [
23
- "# in silico perturbation in deletion mode to determine genes whose \n",
24
- "# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
25
- "# the embedding towards non-failing (nf) state\n",
26
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
27
  " perturb_rank_shift=None,\n",
28
  " genes_to_perturb=\"all\",\n",
@@ -32,11 +73,9 @@
32
  " num_classes=3,\n",
33
  " emb_mode=\"cell\",\n",
34
  " cell_emb_style=\"mean_pool\",\n",
35
- " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
- " cell_states_to_model={'state_key': 'disease', \n",
37
- " 'start_state': 'dcm', \n",
38
- " 'goal_state': 'nf', \n",
39
- " 'alt_states': ['hcm']},\n",
40
  " max_ncells=2000,\n",
41
  " emb_layer=0,\n",
42
  " forward_batch_size=400,\n",
@@ -68,7 +107,7 @@
68
  " genes_perturbed=\"all\",\n",
69
  " combos=0,\n",
70
  " anchor_gene=None,\n",
71
- " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])})"
72
  ]
73
  },
74
  {
8
  "outputs": [],
9
  "source": [
10
  "from geneformer import InSilicoPerturber\n",
11
+ "from geneformer import InSilicoPerturberStats\n",
12
+ "from geneformer import EmbExtractor"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "id": "cbd6851c-060e-4967-b816-e605ffe58b23",
18
+ "metadata": {
19
+ "tags": []
20
+ },
21
+ "source": [
22
+ "### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "c53e98cd-c603-4878-82ba-db471181bb55",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# first obtain start, goal, and alt embedding positions\n",
33
+ "# this function was changed to be separate from perturb_data\n",
34
+ "# to avoid repeating calcuations when parallelizing perturb_data\n",
35
+ "cell_states_to_model={\"state_key\": \"disease\", \n",
36
+ " \"start_state\": \"dcm\", \n",
37
+ " \"goal_state\": \"nf\", \n",
38
+ " \"alt_states\": [\"hcm\"]}\n",
39
+ "\n",
40
+ "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
+ "\n",
42
+ "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
43
+ " num_classes=3,\n",
44
+ " filter_data=filter_data_dict,\n",
45
+ " max_ncells=1000,\n",
46
+ " emb_layer=0,\n",
47
+ " summary_stat=\"exact_mean\",\n",
48
+ " forward_batch_size=256,\n",
49
+ " nproc=16)\n",
50
+ "\n",
51
+ "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
52
+ " \"path/to/model\",\n",
53
+ " \"path/to/input_data\",\n",
54
+ " \"path/to/output_directory\",\n",
55
+ " \"output_prefix\")"
56
  ]
57
  },
58
  {
59
  "cell_type": "code",
60
  "execution_count": null,
61
+ "id": "981e1190-62da-4543-b7d3-6e2a2d6a6d56",
62
  "metadata": {
63
  "tags": []
64
  },
65
  "outputs": [],
66
  "source": [
 
 
 
67
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
68
  " perturb_rank_shift=None,\n",
69
  " genes_to_perturb=\"all\",\n",
73
  " num_classes=3,\n",
74
  " emb_mode=\"cell\",\n",
75
  " cell_emb_style=\"mean_pool\",\n",
76
+ " filter_data=filter_data_dict,\n",
77
+ " cell_states_to_model=cell_states_to_model,\n",
78
+ " state_embs_dict=state_embs_dict,\n",
 
 
79
  " max_ncells=2000,\n",
80
  " emb_layer=0,\n",
81
  " forward_batch_size=400,\n",
107
  " genes_perturbed=\"all\",\n",
108
  " combos=0,\n",
109
  " anchor_gene=None,\n",
110
+ " cell_states_to_model=cell_states_to_model)"
111
  ]
112
  },
113
  {
geneformer/emb_extractor.py CHANGED
@@ -7,66 +7,62 @@ Usage:
7
  num_classes=3,
8
  emb_mode="cell",
9
  cell_emb_style="mean_pool",
 
10
  filter_data={"cell_type":["cardiomyocyte"]},
11
  max_ncells=1000,
12
  max_ncells_to_plot=1000,
13
  emb_layer=-1,
14
  emb_label=["disease","cell_type"],
15
  labels_to_plot=["disease","cell_type"],
16
- forward_batch_size=100,
17
  nproc=16,
18
  summary_stat=None)
19
  embs = embex.extract_embs("path/to/model",
20
  "path/to/input_data",
21
  "path/to/output_directory",
22
  "output_prefix")
23
- embex.plot_embs(embs=embs,
24
  plot_style="heatmap",
25
  output_directory="path/to/output_directory",
26
  output_prefix="output_prefix")
27
-
28
  """
29
 
30
  # imports
31
  import logging
 
 
 
 
32
  import anndata
33
  import matplotlib.pyplot as plt
34
  import numpy as np
35
  import pandas as pd
36
- import pickle
37
- from tdigest import TDigest
38
  import scanpy as sc
39
  import seaborn as sns
40
  import torch
41
- from collections import Counter
42
- from pathlib import Path
43
  from tqdm.auto import trange
44
- from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
45
 
 
46
  from .tokenizer import TOKEN_DICTIONARY_FILE
47
 
48
- from .in_silico_perturber import downsample_and_sort, \
49
- gen_attention_mask, \
50
- get_model_input_size, \
51
- load_and_filter, \
52
- load_model, \
53
- mean_nonpadding_embs, \
54
- pad_tensor_list, \
55
- quant_layers
56
-
57
  logger = logging.getLogger(__name__)
58
 
 
59
  # extract embeddings
60
- def get_embs(model,
61
- filtered_input_data,
62
- emb_mode,
63
- layer_to_quant,
64
- pad_token_id,
65
- forward_batch_size,
66
- summary_stat):
67
- model_input_size = get_model_input_size(model)
 
 
 
68
  total_batch_length = len(filtered_input_data)
69
-
70
  if summary_stat is None:
71
  embs_list = []
72
  elif summary_stat is not None:
@@ -74,69 +70,173 @@ def get_embs(model,
74
  example = filtered_input_data.select([i for i in range(1)])
75
  example.set_format(type="torch")
76
  emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
77
- # initiate tdigests for # of emb dims
78
- embs_tdigests = [TDigest() for _ in range(emb_dims)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- for i in trange(0, total_batch_length, forward_batch_size):
81
- max_range = min(i+forward_batch_size, total_batch_length)
82
 
83
  minibatch = filtered_input_data.select([i for i in range(i, max_range)])
84
- max_len = max(minibatch["length"])
85
- original_lens = torch.tensor(minibatch["length"]).to("cuda")
 
86
  minibatch.set_format(type="torch")
87
 
88
  input_data_minibatch = minibatch["input_ids"]
89
- input_data_minibatch = pad_tensor_list(input_data_minibatch,
90
- max_len,
91
- pad_token_id,
92
- model_input_size)
93
-
94
  with torch.no_grad():
95
  outputs = model(
96
- input_ids = input_data_minibatch.to("cuda"),
97
- attention_mask = gen_attention_mask(minibatch)
98
  )
99
 
100
  embs_i = outputs.hidden_states[layer_to_quant]
101
-
102
  if emb_mode == "cell":
103
- mean_embs = mean_nonpadding_embs(embs_i, original_lens)
104
  if summary_stat is None:
105
- embs_list += [mean_embs]
106
  elif summary_stat is not None:
107
  # update tdigests with current batch for each emb dim
108
- # note: tdigest batch update known to be slow so updating serially
109
- [embs_tdigests[j].update(mean_embs[i,j].item()) for i in range(mean_embs.size(0)) for j in range(emb_dims)]
110
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  del outputs
112
  del minibatch
113
  del input_data_minibatch
114
  del embs_i
115
- del mean_embs
116
- torch.cuda.empty_cache()
117
-
118
  if summary_stat is None:
119
- embs_stack = torch.cat(embs_list)
 
 
 
 
 
 
 
 
 
 
 
120
  # calculate summary stat embs from approximated tdigests
121
  elif summary_stat is not None:
122
- if summary_stat == "mean":
123
- summary_emb_list = [embs_tdigests[i].trimmed_mean(0,100) for i in range(emb_dims)]
124
- elif summary_stat == "median":
125
- summary_emb_list = [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
126
- embs_stack = torch.tensor(summary_emb_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  return embs_stack
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def test_emb(model, example, layer_to_quant):
131
  with torch.no_grad():
132
- outputs = model(
133
- input_ids = example.to("cuda")
134
- )
135
 
136
  embs_test = outputs.hidden_states[layer_to_quant]
137
  return embs_test.size()[2]
138
 
139
- def label_embs(embs, downsampled_data, emb_labels):
 
140
  embs_df = pd.DataFrame(embs.cpu().numpy())
141
  if emb_labels is not None:
142
  for label in emb_labels:
@@ -144,94 +244,145 @@ def label_embs(embs, downsampled_data, emb_labels):
144
  embs_df[label] = emb_label
145
  return embs_df
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
148
- only_embs_df = embs_df.iloc[:,:emb_dims]
149
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
150
- only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(str)
 
 
151
  vars_dict = {"embs": only_embs_df.columns}
152
- obs_dict = {"cell_id": list(only_embs_df.index),
153
- f"{label}": list(embs_df[label])}
154
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
155
- sc.tl.pca(adata, svd_solver='arpack')
156
  sc.pp.neighbors(adata)
157
  sc.tl.umap(adata)
158
- sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3)
159
  sns.set_style("white")
160
- default_kwargs_dict = {"palette":"Set2", "size":200}
161
  if kwargs_dict is not None:
162
  default_kwargs_dict.update(kwargs_dict)
163
-
164
  sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
165
 
 
166
  def gen_heatmap_class_colors(labels, df):
167
- pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
 
 
 
 
 
 
 
 
168
  lut = dict(zip(map(str, Counter(labels).keys()), pal))
169
  colors = pd.Series(labels, index=df.index).map(lut)
170
  return colors
171
-
 
172
  def gen_heatmap_class_dict(classes, label_colors_series):
173
- class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series})
 
 
174
  class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
175
- return dict(zip(class_color_dict_df["classes"],class_color_dict_df["color"]))
176
-
177
- def make_colorbar(embs_df, label):
178
 
 
 
179
  labels = list(embs_df[label])
180
-
181
  cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
182
  label_colors = pd.DataFrame(cell_type_colors, columns=[label])
183
 
184
- for i,row in label_colors.iterrows():
185
- colors=row[0]
186
- if len(colors)!=3 or any(np.isnan(colors)):
187
- print(i,colors)
188
 
189
  label_colors.isna().sum()
190
-
191
  # create dictionary for colors and classes
192
  label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
193
  return label_colors, label_color_dict
194
-
 
195
  def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
196
  sns.set_style("white")
197
  sns.set(font_scale=2)
198
  plt.figure(figsize=(15, 15), dpi=150)
199
  label_colors, label_color_dict = make_colorbar(embs_df, label)
200
-
201
- default_kwargs_dict = {"row_cluster": True,
202
- "col_cluster": True,
203
- "row_colors": label_colors,
204
- "standard_scale": 1,
205
- "linewidths": 0,
206
- "xticklabels": False,
207
- "yticklabels": False,
208
- "figsize": (15,15),
209
- "center": 0,
210
- "cmap": "magma"}
211
-
 
 
212
  if kwargs_dict is not None:
213
  default_kwargs_dict.update(kwargs_dict)
214
- g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)
 
 
215
 
216
  plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
217
 
218
  for label_color in list(label_color_dict.keys()):
219
- g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0)
 
 
 
 
 
 
 
 
 
 
220
 
221
- l1 = g.ax_col_dendrogram.legend(title=f"{label}",
222
- loc="lower center",
223
- ncol=4,
224
- bbox_to_anchor=(0.5, 1),
225
- facecolor="white")
226
 
227
- plt.savefig(output_file, bbox_inches='tight')
228
 
229
  class EmbExtractor:
230
  valid_option_dict = {
231
- "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
232
  "num_classes": {int},
233
- "emb_mode": {"cell","gene"},
234
  "cell_emb_style": {"mean_pool"},
 
235
  "filter_data": {None, dict},
236
  "max_ncells": {None, int},
237
  "emb_layer": {-1, 0},
@@ -239,14 +390,16 @@ class EmbExtractor:
239
  "labels_to_plot": {None, list},
240
  "forward_batch_size": {int},
241
  "nproc": {int},
242
- "summary_stat": {None, "mean", "median"},
243
  }
 
244
  def __init__(
245
  self,
246
  model_type="Pretrained",
247
  num_classes=0,
248
  emb_mode="cell",
249
  cell_emb_style="mean_pool",
 
250
  filter_data=None,
251
  max_ncells=1000,
252
  emb_layer=-1,
@@ -272,6 +425,9 @@ class EmbExtractor:
272
  cell_emb_style : "mean_pool"
273
  Method for summarizing cell embeddings.
274
  Currently only option is mean pooling of gene embeddings for given cell.
 
 
 
275
  filter_data : None, dict
276
  Default is to extract embeddings from all input data.
277
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
@@ -296,10 +452,11 @@ class EmbExtractor:
296
  Batch size for forward pass.
297
  nproc : int
298
  Number of CPU processes to use.
299
- summary_stat : {None, "mean", "median"}
300
- If not None, outputs only approximated mean or median embedding of input data.
301
- Recommended if encountering memory constraints while generating goal embedding positions.
302
- Slower but more memory-efficient.
 
303
  token_dictionary_file : Path
304
  Path to pickle file containing token dictionary (Ensembl ID:token).
305
  """
@@ -308,6 +465,7 @@ class EmbExtractor:
308
  self.num_classes = num_classes
309
  self.emb_mode = emb_mode
310
  self.cell_emb_style = cell_emb_style
 
311
  self.filter_data = filter_data
312
  self.max_ncells = max_ncells
313
  self.emb_layer = emb_layer
@@ -315,7 +473,12 @@ class EmbExtractor:
315
  self.labels_to_plot = labels_to_plot
316
  self.forward_batch_size = forward_batch_size
317
  self.nproc = nproc
318
- self.summary_stat = summary_stat
 
 
 
 
 
319
 
320
  self.validate_options()
321
 
@@ -323,51 +486,49 @@ class EmbExtractor:
323
  with open(token_dictionary_file, "rb") as f:
324
  self.gene_token_dict = pickle.load(f)
325
 
 
326
  self.pad_token_id = self.gene_token_dict.get("<pad>")
327
-
328
-
329
  def validate_options(self):
330
- # first disallow options under development
331
- if self.emb_mode == "gene":
332
- logger.error(
333
- "Extraction and plotting of gene-level embeddings currently under development. " \
334
- "Current valid option for 'emb_mode': 'cell'"
335
- )
336
- raise
337
-
338
  # confirm arguments are within valid options and compatible with each other
339
- for attr_name,valid_options in self.valid_option_dict.items():
340
  attr_value = self.__dict__[attr_name]
341
- if type(attr_value) not in {list, dict}:
342
  if attr_value in valid_options:
343
  continue
344
  valid_type = False
345
  for option in valid_options:
346
- if (option in [int,list,dict]) and isinstance(attr_value, option):
 
 
347
  valid_type = True
348
  break
349
  if valid_type:
350
  continue
351
  logger.error(
352
- f"Invalid option for {attr_name}. " \
353
  f"Valid options for {attr_name}: {valid_options}"
354
  )
355
  raise
356
-
357
  if self.filter_data is not None:
358
- for key,value in self.filter_data.items():
359
- if type(value) != list:
360
  self.filter_data[key] = [value]
361
  logger.warning(
362
- "Values in filter_data dict must be lists. " \
363
- f"Changing {key} value to list ([{value}]).")
364
-
365
- def extract_embs(self,
366
- model_directory,
367
- input_data_file,
368
- output_directory,
369
- output_prefix,
370
- output_torch_embs=False):
 
 
 
 
371
  """
372
  Extract embeddings from input data and save as results in output_directory.
373
 
@@ -384,42 +545,165 @@ class EmbExtractor:
384
  output_torch_embs : bool
385
  Whether or not to also output the embeddings as a tensor.
386
  Note, if true, will output embeddings as both dataframe and tensor.
 
 
387
  """
388
 
389
- filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
390
- downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
391
- model = load_model(self.model_type, self.num_classes, model_directory)
392
- layer_to_quant = quant_layers(model)+self.emb_layer
393
- embs = get_embs(model,
394
- downsampled_data,
395
- self.emb_mode,
396
- layer_to_quant,
397
- self.pad_token_id,
398
- self.forward_batch_size,
399
- self.summary_stat)
400
-
401
- if self.summary_stat is None:
402
- embs_df = label_embs(embs, downsampled_data, self.emb_label)
403
- elif self.summary_stat is not None:
404
- embs_df = pd.DataFrame(embs.cpu().numpy()).T
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
  # save embeddings to output_path
407
- output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
408
- embs_df.to_csv(output_path)
409
-
410
- if output_torch_embs == True:
411
- return embs_df, embs
 
 
 
 
 
 
 
 
 
 
 
 
412
  else:
413
- return embs_df
414
-
415
- def plot_embs(self,
416
- embs,
417
- plot_style,
418
- output_directory,
419
- output_prefix,
420
- max_ncells_to_plot=1000,
421
- kwargs_dict=None):
422
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  """
424
  Plot embeddings, coloring by provided labels.
425
 
@@ -440,60 +724,63 @@ class EmbExtractor:
440
  kwargs_dict : dict
441
  Dictionary of kwargs to pass to plotting function.
442
  """
443
-
444
- if plot_style not in ["heatmap","umap"]:
445
  logger.error(
446
- "Invalid option for 'plot_style'. " \
447
- "Valid options: {'heatmap','umap'}"
448
  )
449
  raise
450
-
451
  if (plot_style == "umap") and (self.labels_to_plot is None):
452
- logger.error(
453
- "Plotting UMAP requires 'labels_to_plot'. "
454
- )
455
  raise
456
-
457
  if max_ncells_to_plot > self.max_ncells:
458
  max_ncells_to_plot = self.max_ncells
459
  logger.warning(
460
- "max_ncells_to_plot must be <= max_ncells. " \
461
- f"Changing max_ncells_to_plot to {self.max_ncells}.")
462
-
463
- if (max_ncells_to_plot is not None) \
464
- and (max_ncells_to_plot < self.max_ncells):
465
  embs = embs.sample(max_ncells_to_plot, axis=0)
466
-
467
  if self.emb_label is None:
468
  label_len = 0
469
  else:
470
  label_len = len(self.emb_label)
471
-
472
  emb_dims = embs.shape[1] - label_len
473
-
474
  if self.emb_label is None:
475
  emb_labels = None
476
  else:
477
  emb_labels = embs.columns[emb_dims:]
478
-
479
  if plot_style == "umap":
480
  for label in self.labels_to_plot:
481
  if label not in emb_labels:
482
  logger.warning(
483
- f"Label {label} from labels_to_plot " \
484
- f"not present in provided embeddings dataframe.")
 
485
  continue
486
  output_prefix_label = "_" + output_prefix + f"_umap_{label}"
487
- output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
 
 
488
  plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
489
-
490
  if plot_style == "heatmap":
491
  for label in self.labels_to_plot:
492
  if label not in emb_labels:
493
  logger.warning(
494
- f"Label {label} from labels_to_plot " \
495
- f"not present in provided embeddings dataframe.")
 
496
  continue
497
  output_prefix_label = output_prefix + f"_heatmap_{label}"
498
- output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
499
- plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
 
 
7
  num_classes=3,
8
  emb_mode="cell",
9
  cell_emb_style="mean_pool",
10
+ gene_emb_style="mean_pool",
11
  filter_data={"cell_type":["cardiomyocyte"]},
12
  max_ncells=1000,
13
  max_ncells_to_plot=1000,
14
  emb_layer=-1,
15
  emb_label=["disease","cell_type"],
16
  labels_to_plot=["disease","cell_type"],
 
17
  nproc=16,
18
  summary_stat=None)
19
  embs = embex.extract_embs("path/to/model",
20
  "path/to/input_data",
21
  "path/to/output_directory",
22
  "output_prefix")
23
+ embex.plot_embs(embs=embs,
24
  plot_style="heatmap",
25
  output_directory="path/to/output_directory",
26
  output_prefix="output_prefix")
27
+
28
  """
29
 
30
  # imports
31
  import logging
32
+ import pickle
33
+ from collections import Counter
34
+ from pathlib import Path
35
+
36
  import anndata
37
  import matplotlib.pyplot as plt
38
  import numpy as np
39
  import pandas as pd
 
 
40
  import scanpy as sc
41
  import seaborn as sns
42
  import torch
43
+ from tdigest import TDigest
 
44
  from tqdm.auto import trange
 
45
 
46
+ from . import perturber_utils as pu
47
  from .tokenizer import TOKEN_DICTIONARY_FILE
48
 
 
 
 
 
 
 
 
 
 
49
  logger = logging.getLogger(__name__)
50
 
51
+
52
  # extract embeddings
53
+ def get_embs(
54
+ model,
55
+ filtered_input_data,
56
+ emb_mode,
57
+ layer_to_quant,
58
+ pad_token_id,
59
+ forward_batch_size,
60
+ summary_stat=None,
61
+ silent=False,
62
+ ):
63
+ model_input_size = pu.get_model_input_size(model)
64
  total_batch_length = len(filtered_input_data)
65
+
66
  if summary_stat is None:
67
  embs_list = []
68
  elif summary_stat is not None:
70
  example = filtered_input_data.select([i for i in range(1)])
71
  example.set_format(type="torch")
72
  emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
73
+ if emb_mode == "cell":
74
+ # initiate tdigests for # of emb dims
75
+ embs_tdigests = [TDigest() for _ in range(emb_dims)]
76
+ if emb_mode == "gene":
77
+ gene_set = list(
78
+ {
79
+ element
80
+ for sublist in filtered_input_data["input_ids"]
81
+ for element in sublist
82
+ }
83
+ )
84
+ # initiate dict with genes as keys and tdigests for # of emb dims as values
85
+ embs_tdigests_dict = {
86
+ k: [TDigest() for _ in range(emb_dims)] for k in gene_set
87
+ }
88
+
89
+ overall_max_len = 0
90
 
91
+ for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
92
+ max_range = min(i + forward_batch_size, total_batch_length)
93
 
94
  minibatch = filtered_input_data.select([i for i in range(i, max_range)])
95
+
96
+ max_len = int(max(minibatch["length"]))
97
+ original_lens = torch.tensor(minibatch["length"], device="cuda")
98
  minibatch.set_format(type="torch")
99
 
100
  input_data_minibatch = minibatch["input_ids"]
101
+ input_data_minibatch = pu.pad_tensor_list(
102
+ input_data_minibatch, max_len, pad_token_id, model_input_size
103
+ )
104
+
 
105
  with torch.no_grad():
106
  outputs = model(
107
+ input_ids=input_data_minibatch.to("cuda"),
108
+ attention_mask=pu.gen_attention_mask(minibatch),
109
  )
110
 
111
  embs_i = outputs.hidden_states[layer_to_quant]
112
+
113
  if emb_mode == "cell":
114
+ mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
115
  if summary_stat is None:
116
+ embs_list.append(mean_embs)
117
  elif summary_stat is not None:
118
  # update tdigests with current batch for each emb dim
119
+ accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
120
+ del mean_embs
121
+ elif emb_mode == "gene":
122
+ if summary_stat is None:
123
+ embs_list.append(embs_i)
124
+ elif summary_stat is not None:
125
+ for h in trange(len(minibatch)):
126
+ length_h = minibatch[h]["length"]
127
+ input_ids_h = minibatch[h]["input_ids"][0:length_h]
128
+
129
+ # double check dimensions before unsqueezing
130
+ embs_i_dim = embs_i.dim()
131
+ if embs_i_dim != 3:
132
+ logger.error(
133
+ f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
134
+ )
135
+ raise
136
+
137
+ embs_h = embs_i[h, :, :].unsqueeze(dim=1)
138
+ dict_h = dict(zip(input_ids_h, embs_h))
139
+ for k in dict_h.keys():
140
+ accumulate_tdigests(
141
+ embs_tdigests_dict[int(k)], dict_h[k], emb_dims
142
+ )
143
+
144
+ overall_max_len = max(overall_max_len, max_len)
145
  del outputs
146
  del minibatch
147
  del input_data_minibatch
148
  del embs_i
149
+
150
+ torch.cuda.empty_cache()
151
+
152
  if summary_stat is None:
153
+ if emb_mode == "cell":
154
+ embs_stack = torch.cat(embs_list, dim=0)
155
+ elif emb_mode == "gene":
156
+ embs_stack = pu.pad_tensor_list(
157
+ embs_list,
158
+ overall_max_len,
159
+ pad_token_id,
160
+ model_input_size,
161
+ 1,
162
+ pu.pad_3d_tensor,
163
+ )
164
+
165
  # calculate summary stat embs from approximated tdigests
166
  elif summary_stat is not None:
167
+ if emb_mode == "cell":
168
+ if summary_stat == "mean":
169
+ summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
170
+ elif summary_stat == "median":
171
+ summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
172
+ embs_stack = torch.tensor(summary_emb_list)
173
+ elif emb_mode == "gene":
174
+ if summary_stat == "mean":
175
+ [
176
+ update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
177
+ for gene in embs_tdigests_dict.keys()
178
+ ]
179
+ elif summary_stat == "median":
180
+ [
181
+ update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
182
+ for gene in embs_tdigests_dict.keys()
183
+ ]
184
+ return embs_tdigests_dict
185
 
186
  return embs_stack
187
 
188
+
189
+ def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
190
+ # note: tdigest batch update known to be slow so updating serially
191
+ [
192
+ embs_tdigests[j].update(mean_embs[i, j].item())
193
+ for i in range(mean_embs.size(0))
194
+ for j in range(emb_dims)
195
+ ]
196
+
197
+
198
+ def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
199
+ embs_tdigests_dict[gene] = accumulate_tdigests(
200
+ embs_tdigests_dict[gene], gene_embs, emb_dims
201
+ )
202
+
203
+
204
+ def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
205
+ embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
206
+
207
+
208
+ def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
209
+ embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
210
+
211
+
212
+ def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
213
+ length_h = minibatch[h]["length"]
214
+ input_ids_h = minibatch[h]["input_ids"][0:length_h]
215
+ embs_h = embs_i[h, :, :].unsqueeze(dim=1)
216
+ dict_h = dict(zip(input_ids_h, embs_h))
217
+ [
218
+ update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
219
+ for k in dict_h.keys()
220
+ ]
221
+
222
+
223
+ def tdigest_mean(embs_tdigests, emb_dims):
224
+ return [embs_tdigests[i].trimmed_mean(0, 100) for i in range(emb_dims)]
225
+
226
+
227
+ def tdigest_median(embs_tdigests, emb_dims):
228
+ return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
229
+
230
+
231
  def test_emb(model, example, layer_to_quant):
232
  with torch.no_grad():
233
+ outputs = model(input_ids=example.to("cuda"))
 
 
234
 
235
  embs_test = outputs.hidden_states[layer_to_quant]
236
  return embs_test.size()[2]
237
 
238
+
239
+ def label_cell_embs(embs, downsampled_data, emb_labels):
240
  embs_df = pd.DataFrame(embs.cpu().numpy())
241
  if emb_labels is not None:
242
  for label in emb_labels:
244
  embs_df[label] = emb_label
245
  return embs_df
246
 
247
+
248
+ def label_gene_embs(embs, downsampled_data, token_gene_dict):
249
+ gene_set = {
250
+ element for sublist in downsampled_data["input_ids"] for element in sublist
251
+ }
252
+ gene_emb_dict = {k: [] for k in gene_set}
253
+ for i in range(embs.size()[0]):
254
+ length = downsampled_data[i]["length"]
255
+ dict_i = dict(
256
+ zip(
257
+ downsampled_data[i]["input_ids"][0:length],
258
+ embs[i, :, :].unsqueeze(dim=1),
259
+ )
260
+ )
261
+ for k in dict_i.keys():
262
+ gene_emb_dict[k].append(dict_i[k])
263
+ for k in gene_emb_dict.keys():
264
+ gene_emb_dict[k] = (
265
+ torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
266
+ .cpu()
267
+ .numpy()
268
+ )
269
+ embs_df = pd.DataFrame(gene_emb_dict).T
270
+ embs_df.index = [token_gene_dict[token] for token in embs_df.index]
271
+ return embs_df
272
+
273
+
274
  def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
275
+ only_embs_df = embs_df.iloc[:, :emb_dims]
276
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
277
+ only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
278
+ str
279
+ )
280
  vars_dict = {"embs": only_embs_df.columns}
281
+ obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
 
282
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
283
+ sc.tl.pca(adata, svd_solver="arpack")
284
  sc.pp.neighbors(adata)
285
  sc.tl.umap(adata)
286
+ sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
287
  sns.set_style("white")
288
+ default_kwargs_dict = {"palette": "Set2", "size": 200}
289
  if kwargs_dict is not None:
290
  default_kwargs_dict.update(kwargs_dict)
291
+
292
  sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
293
 
294
+
295
  def gen_heatmap_class_colors(labels, df):
296
+ pal = sns.cubehelix_palette(
297
+ len(Counter(labels).keys()),
298
+ light=0.9,
299
+ dark=0.1,
300
+ hue=1,
301
+ reverse=True,
302
+ start=1,
303
+ rot=-2,
304
+ )
305
  lut = dict(zip(map(str, Counter(labels).keys()), pal))
306
  colors = pd.Series(labels, index=df.index).map(lut)
307
  return colors
308
+
309
+
310
  def gen_heatmap_class_dict(classes, label_colors_series):
311
+ class_color_dict_df = pd.DataFrame(
312
+ {"classes": classes, "color": label_colors_series}
313
+ )
314
  class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
315
+ return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
 
 
316
 
317
+
318
+ def make_colorbar(embs_df, label):
319
  labels = list(embs_df[label])
320
+
321
  cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
322
  label_colors = pd.DataFrame(cell_type_colors, columns=[label])
323
 
324
+ for i, row in label_colors.iterrows():
325
+ colors = row[0]
326
+ if len(colors) != 3 or any(np.isnan(colors)):
327
+ print(i, colors)
328
 
329
  label_colors.isna().sum()
330
+
331
  # create dictionary for colors and classes
332
  label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
333
  return label_colors, label_color_dict
334
+
335
+
336
  def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
337
  sns.set_style("white")
338
  sns.set(font_scale=2)
339
  plt.figure(figsize=(15, 15), dpi=150)
340
  label_colors, label_color_dict = make_colorbar(embs_df, label)
341
+
342
+ default_kwargs_dict = {
343
+ "row_cluster": True,
344
+ "col_cluster": True,
345
+ "row_colors": label_colors,
346
+ "standard_scale": 1,
347
+ "linewidths": 0,
348
+ "xticklabels": False,
349
+ "yticklabels": False,
350
+ "figsize": (15, 15),
351
+ "center": 0,
352
+ "cmap": "magma",
353
+ }
354
+
355
  if kwargs_dict is not None:
356
  default_kwargs_dict.update(kwargs_dict)
357
+ g = sns.clustermap(
358
+ embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
359
+ )
360
 
361
  plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
362
 
363
  for label_color in list(label_color_dict.keys()):
364
+ g.ax_col_dendrogram.bar(
365
+ 0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
366
+ )
367
+
368
+ g.ax_col_dendrogram.legend(
369
+ title=f"{label}",
370
+ loc="lower center",
371
+ ncol=4,
372
+ bbox_to_anchor=(0.5, 1),
373
+ facecolor="white",
374
+ )
375
 
376
+ plt.savefig(output_file, bbox_inches="tight")
 
 
 
 
377
 
 
378
 
379
  class EmbExtractor:
380
  valid_option_dict = {
381
+ "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
382
  "num_classes": {int},
383
+ "emb_mode": {"cell", "gene"},
384
  "cell_emb_style": {"mean_pool"},
385
+ "gene_emb_style": {"mean_pool"},
386
  "filter_data": {None, dict},
387
  "max_ncells": {None, int},
388
  "emb_layer": {-1, 0},
390
  "labels_to_plot": {None, list},
391
  "forward_batch_size": {int},
392
  "nproc": {int},
393
+ "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
394
  }
395
+
396
  def __init__(
397
  self,
398
  model_type="Pretrained",
399
  num_classes=0,
400
  emb_mode="cell",
401
  cell_emb_style="mean_pool",
402
+ gene_emb_style="mean_pool",
403
  filter_data=None,
404
  max_ncells=1000,
405
  emb_layer=-1,
425
  cell_emb_style : "mean_pool"
426
  Method for summarizing cell embeddings.
427
  Currently only option is mean pooling of gene embeddings for given cell.
428
+ gene_emb_style : "mean_pool"
429
+ Method for summarizing gene embeddings.
430
+ Currently only option is mean pooling of contextual gene embeddings for given gene.
431
  filter_data : None, dict
432
  Default is to extract embeddings from all input data.
433
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
452
  Batch size for forward pass.
453
  nproc : int
454
  Number of CPU processes to use.
455
+ summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
456
+ If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
457
+ If mean or median, outputs only approximated mean or median embedding of input data.
458
+ Non-exact recommended if encountering memory constraints while generating goal embedding positions.
459
+ Non-exact is slower but more memory-efficient.
460
  token_dictionary_file : Path
461
  Path to pickle file containing token dictionary (Ensembl ID:token).
462
  """
465
  self.num_classes = num_classes
466
  self.emb_mode = emb_mode
467
  self.cell_emb_style = cell_emb_style
468
+ self.gene_emb_style = gene_emb_style
469
  self.filter_data = filter_data
470
  self.max_ncells = max_ncells
471
  self.emb_layer = emb_layer
473
  self.labels_to_plot = labels_to_plot
474
  self.forward_batch_size = forward_batch_size
475
  self.nproc = nproc
476
+ if (summary_stat is not None) and ("exact" in summary_stat):
477
+ self.summary_stat = None
478
+ self.exact_summary_stat = summary_stat
479
+ else:
480
+ self.summary_stat = summary_stat
481
+ self.exact_summary_stat = None
482
 
483
  self.validate_options()
484
 
486
  with open(token_dictionary_file, "rb") as f:
487
  self.gene_token_dict = pickle.load(f)
488
 
489
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
490
  self.pad_token_id = self.gene_token_dict.get("<pad>")
491
+
 
492
  def validate_options(self):
 
 
 
 
 
 
 
 
493
  # confirm arguments are within valid options and compatible with each other
494
+ for attr_name, valid_options in self.valid_option_dict.items():
495
  attr_value = self.__dict__[attr_name]
496
+ if not isinstance(attr_value, (list, dict)):
497
  if attr_value in valid_options:
498
  continue
499
  valid_type = False
500
  for option in valid_options:
501
+ if (option in [int, list, dict, bool]) and isinstance(
502
+ attr_value, option
503
+ ):
504
  valid_type = True
505
  break
506
  if valid_type:
507
  continue
508
  logger.error(
509
+ f"Invalid option for {attr_name}. "
510
  f"Valid options for {attr_name}: {valid_options}"
511
  )
512
  raise
513
+
514
  if self.filter_data is not None:
515
+ for key, value in self.filter_data.items():
516
+ if not isinstance(value, list):
517
  self.filter_data[key] = [value]
518
  logger.warning(
519
+ "Values in filter_data dict must be lists. "
520
+ f"Changing {key} value to list ([{value}])."
521
+ )
522
+
523
+ def extract_embs(
524
+ self,
525
+ model_directory,
526
+ input_data_file,
527
+ output_directory,
528
+ output_prefix,
529
+ output_torch_embs=False,
530
+ cell_state=None,
531
+ ):
532
  """
533
  Extract embeddings from input data and save as results in output_directory.
534
 
545
  output_torch_embs : bool
546
  Whether or not to also output the embeddings as a tensor.
547
  Note, if true, will output embeddings as both dataframe and tensor.
548
+ cell_state : dict
549
+ Cell state key and value for state embedding extraction.
550
  """
551
 
552
+ filtered_input_data = pu.load_and_filter(
553
+ self.filter_data, self.nproc, input_data_file
554
+ )
555
+ if cell_state is not None:
556
+ filtered_input_data = pu.filter_by_dict(
557
+ filtered_input_data, cell_state, self.nproc
558
+ )
559
+ downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
560
+ model = pu.load_model(self.model_type, self.num_classes, model_directory)
561
+ layer_to_quant = pu.quant_layers(model) + self.emb_layer
562
+ embs = get_embs(
563
+ model,
564
+ downsampled_data,
565
+ self.emb_mode,
566
+ layer_to_quant,
567
+ self.pad_token_id,
568
+ self.forward_batch_size,
569
+ self.summary_stat,
570
+ )
571
+
572
+ if self.emb_mode == "cell":
573
+ if self.summary_stat is None:
574
+ embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
575
+ elif self.summary_stat is not None:
576
+ embs_df = pd.DataFrame(embs.cpu().numpy()).T
577
+ elif self.emb_mode == "gene":
578
+ if self.summary_stat is None:
579
+ embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
580
+ elif self.summary_stat is not None:
581
+ embs_df = pd.DataFrame(embs).T
582
+ embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
583
 
584
  # save embeddings to output_path
585
+ if cell_state is None:
586
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
587
+ embs_df.to_csv(output_path)
588
+
589
+ if self.exact_summary_stat == "exact_mean":
590
+ embs = embs.mean(dim=0)
591
+ embs_df = pd.DataFrame(
592
+ embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
593
+ ).T
594
+ elif self.exact_summary_stat == "exact_median":
595
+ embs = torch.median(embs, dim=0)[0]
596
+ embs_df = pd.DataFrame(
597
+ embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
598
+ ).T
599
+
600
+ if cell_state is not None:
601
+ return embs
602
  else:
603
+ if output_torch_embs:
604
+ return embs_df, embs
605
+ else:
606
+ return embs_df
607
+
608
+ def get_state_embs(
609
+ self,
610
+ cell_states_to_model,
611
+ model_directory,
612
+ input_data_file,
613
+ output_directory,
614
+ output_prefix,
615
+ output_torch_embs=True,
616
+ ):
617
+ """
618
+ Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
619
+
620
+ Parameters
621
+ ----------
622
+ cell_states_to_model : None, dict
623
+ Cell states to model if testing perturbations that achieve goal state change.
624
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
625
+ state_key: key specifying name of column in .dataset that defines the start/goal states
626
+ start_state: value in the state_key column that specifies the start state
627
+ goal_state: value in the state_key column taht specifies the goal end state
628
+ alt_states: list of values in the state_key column that specify the alternate end states
629
+ For example: {"state_key": "disease",
630
+ "start_state": "dcm",
631
+ "goal_state": "nf",
632
+ "alt_states": ["hcm", "other1", "other2"]}
633
+ model_directory : Path
634
+ Path to directory containing model
635
+ input_data_file : Path
636
+ Path to directory containing .dataset inputs
637
+ output_directory : Path
638
+ Path to directory where embedding data will be saved as csv
639
+ output_prefix : str
640
+ Prefix for output file
641
+ output_torch_embs : bool
642
+ Whether or not to also output the embeddings as a tensor.
643
+ Note, if true, will output embeddings as both dataframe and tensor.
644
+
645
+ Outputs
646
+ ----------
647
+ Outputs state_embs_dict for use with in silico perturber.
648
+ Format is dictionary of embedding positions of each cell state to model shifts from/towards.
649
+ Keys specify each possible cell state to model.
650
+ Values are target embedding positions as torch.tensor.
651
+ For example: {"nf": emb_nf,
652
+ "hcm": emb_hcm,
653
+ "dcm": emb_dcm,
654
+ "other1": emb_other1,
655
+ "other2": emb_other2}
656
+ """
657
+
658
+ pu.validate_cell_states_to_model(cell_states_to_model)
659
+ valid_summary_stats = ["exact_mean", "exact_median"]
660
+ if self.exact_summary_stat not in valid_summary_stats:
661
+ logger.error(
662
+ "For extracting state embs, summary_stat in EmbExtractor "
663
+ f"must be set to option in {valid_summary_stats}"
664
+ )
665
+ raise
666
+
667
+ state_embs_dict = dict()
668
+ state_key = cell_states_to_model["state_key"]
669
+ for k, v in cell_states_to_model.items():
670
+ if k == "state_key":
671
+ continue
672
+ elif (k == "start_state") or (k == "goal_state"):
673
+ state_embs_dict[v] = self.extract_embs(
674
+ model_directory,
675
+ input_data_file,
676
+ output_directory,
677
+ output_prefix,
678
+ output_torch_embs,
679
+ cell_state={state_key: v},
680
+ )
681
+ else: # k == "alt_states"
682
+ for alt_state in v:
683
+ state_embs_dict[alt_state] = self.extract_embs(
684
+ model_directory,
685
+ input_data_file,
686
+ output_directory,
687
+ output_prefix,
688
+ output_torch_embs,
689
+ cell_state={state_key: alt_state},
690
+ )
691
+
692
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
693
+ with open(output_path, "wb") as fp:
694
+ pickle.dump(state_embs_dict, fp)
695
+
696
+ return state_embs_dict
697
+
698
+ def plot_embs(
699
+ self,
700
+ embs,
701
+ plot_style,
702
+ output_directory,
703
+ output_prefix,
704
+ max_ncells_to_plot=1000,
705
+ kwargs_dict=None,
706
+ ):
707
  """
708
  Plot embeddings, coloring by provided labels.
709
 
724
  kwargs_dict : dict
725
  Dictionary of kwargs to pass to plotting function.
726
  """
727
+
728
+ if plot_style not in ["heatmap", "umap"]:
729
  logger.error(
730
+ "Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
 
731
  )
732
  raise
733
+
734
  if (plot_style == "umap") and (self.labels_to_plot is None):
735
+ logger.error("Plotting UMAP requires 'labels_to_plot'. ")
 
 
736
  raise
737
+
738
  if max_ncells_to_plot > self.max_ncells:
739
  max_ncells_to_plot = self.max_ncells
740
  logger.warning(
741
+ "max_ncells_to_plot must be <= max_ncells. "
742
+ f"Changing max_ncells_to_plot to {self.max_ncells}."
743
+ )
744
+
745
+ if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells):
746
  embs = embs.sample(max_ncells_to_plot, axis=0)
747
+
748
  if self.emb_label is None:
749
  label_len = 0
750
  else:
751
  label_len = len(self.emb_label)
752
+
753
  emb_dims = embs.shape[1] - label_len
754
+
755
  if self.emb_label is None:
756
  emb_labels = None
757
  else:
758
  emb_labels = embs.columns[emb_dims:]
759
+
760
  if plot_style == "umap":
761
  for label in self.labels_to_plot:
762
  if label not in emb_labels:
763
  logger.warning(
764
+ f"Label {label} from labels_to_plot "
765
+ f"not present in provided embeddings dataframe."
766
+ )
767
  continue
768
  output_prefix_label = "_" + output_prefix + f"_umap_{label}"
769
+ output_file = (
770
+ Path(output_directory) / output_prefix_label
771
+ ).with_suffix(".pdf")
772
  plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
773
+
774
  if plot_style == "heatmap":
775
  for label in self.labels_to_plot:
776
  if label not in emb_labels:
777
  logger.warning(
778
+ f"Label {label} from labels_to_plot "
779
+ f"not present in provided embeddings dataframe."
780
+ )
781
  continue
782
  output_prefix_label = output_prefix + f"_heatmap_{label}"
783
+ output_file = (
784
+ Path(output_directory) / output_prefix_label
785
+ ).with_suffix(".pdf")
786
+ plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
geneformer/in_silico_perturber.py CHANGED
@@ -8,614 +8,66 @@ Usage:
8
  genes_to_perturb="all",
9
  combos=0,
10
  anchor_gene=None,
11
- model_type="Pretrained",
12
  num_classes=0,
13
  emb_mode="cell",
14
  cell_emb_style="mean_pool",
15
  filter_data={"cell_type":["cardiomyocyte"]},
16
  cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
 
17
  max_ncells=None,
18
- emb_layer=-1,
19
  forward_batch_size=100,
20
- nproc=4)
21
  isp.perturb_data("path/to/model",
22
  "path/to/input_data",
23
  "path/to/output_directory",
24
  "output_prefix")
25
  """
26
 
27
- # imports
28
- import itertools as it
29
  import logging
30
- import numpy as np
 
 
31
  import pickle
32
- import re
33
- import seaborn as sns; sns.set()
34
- import torch
35
  from collections import defaultdict
36
- from datasets import Dataset, load_from_disk
 
 
 
37
  from tqdm.auto import trange
38
- from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
39
 
 
 
40
  from .tokenizer import TOKEN_DICTIONARY_FILE
41
 
42
- logger = logging.getLogger(__name__)
43
-
44
-
45
- # load data and filter by defined criteria
46
- def load_and_filter(filter_data, nproc, input_data_file):
47
- data = load_from_disk(input_data_file)
48
- if filter_data is not None:
49
- for key,value in filter_data.items():
50
- def filter_data_by_criteria(example):
51
- return example[key] in value
52
- data = data.filter(filter_data_by_criteria, num_proc=nproc)
53
- if len(data) == 0:
54
- logger.error(
55
- "No cells remain after filtering. Check filtering criteria.")
56
- raise
57
- data_shuffled = data.shuffle(seed=42)
58
- return data_shuffled
59
-
60
- # load model to GPU
61
- def load_model(model_type, num_classes, model_directory):
62
- if model_type == "Pretrained":
63
- model = BertForMaskedLM.from_pretrained(model_directory,
64
- output_hidden_states=True,
65
- output_attentions=False)
66
- elif model_type == "GeneClassifier":
67
- model = BertForTokenClassification.from_pretrained(model_directory,
68
- num_labels=num_classes,
69
- output_hidden_states=True,
70
- output_attentions=False)
71
- elif model_type == "CellClassifier":
72
- model = BertForSequenceClassification.from_pretrained(model_directory,
73
- num_labels=num_classes,
74
- output_hidden_states=True,
75
- output_attentions=False)
76
- # put the model in eval mode for fwd pass
77
- model.eval()
78
- model = model.to("cuda:0")
79
- return model
80
-
81
- def quant_layers(model):
82
- layer_nums = []
83
- for name, parameter in model.named_parameters():
84
- if "layer" in name:
85
- layer_nums += [int(name.split("layer.")[1].split(".")[0])]
86
- return int(max(layer_nums))+1
87
-
88
- def get_model_input_size(model):
89
- return int(re.split("\(|,",str(model.bert.embeddings.position_embeddings))[1])
90
-
91
- def flatten_list(megalist):
92
- return [item for sublist in megalist for item in sublist]
93
-
94
- def measure_length(example):
95
- example["length"] = len(example["input_ids"])
96
- return example
97
-
98
- def downsample_and_sort(data_shuffled, max_ncells):
99
- num_cells = len(data_shuffled)
100
- # if max number of cells is defined, then subsample to this max number
101
- if max_ncells != None:
102
- num_cells = min(max_ncells,num_cells)
103
- data_subset = data_shuffled.select([i for i in range(num_cells)])
104
- # sort dataset with largest cell first to encounter any memory errors earlier
105
- data_sorted = data_subset.sort("length",reverse=True)
106
- return data_sorted
107
-
108
- def get_possible_states(cell_states_to_model):
109
- possible_states = []
110
- for key in ["start_state","goal_state"]:
111
- possible_states += [cell_states_to_model[key]]
112
- possible_states += cell_states_to_model.get("alt_states",[])
113
- return possible_states
114
-
115
- def forward_pass_single_cell(model, example_cell, layer_to_quant):
116
- example_cell.set_format(type="torch")
117
- input_data = example_cell["input_ids"]
118
- with torch.no_grad():
119
- outputs = model(
120
- input_ids = input_data.to("cuda")
121
- )
122
- emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
123
- del outputs
124
- return emb
125
-
126
- def perturb_emb_by_index(emb, indices):
127
- mask = torch.ones(emb.numel(), dtype=torch.bool)
128
- mask[indices] = False
129
- return emb[mask]
130
-
131
- def delete_indices(example):
132
- indices = example["perturb_index"]
133
- if any(isinstance(el, list) for el in indices):
134
- indices = flatten_list(indices)
135
- for index in sorted(indices, reverse=True):
136
- del example["input_ids"][index]
137
- return example
138
-
139
- # for genes_to_perturb = "all" where only genes within cell are overexpressed
140
- def overexpress_indices(example):
141
- indices = example["perturb_index"]
142
- if any(isinstance(el, list) for el in indices):
143
- indices = flatten_list(indices)
144
- for index in sorted(indices, reverse=True):
145
- example["input_ids"].insert(0, example["input_ids"].pop(index))
146
- return example
147
-
148
- # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
149
- def overexpress_tokens(example):
150
- # -100 indicates tokens to overexpress are not present in rank value encoding
151
- if example["perturb_index"] != [-100]:
152
- example = delete_indices(example)
153
- [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
154
-
155
- return example
156
-
157
- def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
158
- # indices_to_remove is list of indices to remove
159
- indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
160
- num_dims = emb.dim()
161
- emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
162
- sliced_emb = emb[emb_slice]
163
- return sliced_emb
164
-
165
- def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
166
- output_batch = torch.stack([
167
- remove_indices_from_emb(emb_batch[i, :, :], idxs, gene_dim-1) for
168
- i, idxs in enumerate(list_of_indices_to_remove)
169
- ])
170
- return output_batch
171
-
172
- def make_perturbation_batch(example_cell,
173
- perturb_type,
174
- tokens_to_perturb,
175
- anchor_token,
176
- combo_lvl,
177
- num_proc):
178
- if tokens_to_perturb == "all":
179
- if perturb_type in ["overexpress","activate"]:
180
- range_start = 1
181
- elif perturb_type in ["delete","inhibit"]:
182
- range_start = 0
183
- indices_to_perturb = [[i] for i in range(range_start, example_cell["length"][0])]
184
- elif combo_lvl>0 and (anchor_token is not None):
185
- example_input_ids = example_cell["input_ids "][0]
186
- anchor_index = example_input_ids.index(anchor_token[0])
187
- indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
188
- indices_to_perturb = [item for item in indices_to_perturb if item is not None]
189
- else:
190
- example_input_ids = example_cell["input_ids"][0]
191
- indices_to_perturb = [[example_input_ids.index(token)] if token in example_input_ids else None for token in tokens_to_perturb]
192
- indices_to_perturb = [item for item in indices_to_perturb if item is not None]
193
-
194
- # create all permutations of combo_lvl of modifiers from tokens_to_perturb
195
- if combo_lvl>0 and (anchor_token is None):
196
- if tokens_to_perturb != "all":
197
- if len(tokens_to_perturb) == combo_lvl+1:
198
- indices_to_perturb = [list(x) for x in it.combinations(indices_to_perturb, combo_lvl+1)]
199
- else:
200
- all_indices = [[i] for i in range(example_cell["length"][0])]
201
- all_indices = [index for index in all_indices if index not in indices_to_perturb]
202
- indices_to_perturb = [[[j for i in indices_to_perturb for j in i], x] for x in all_indices]
203
- length = len(indices_to_perturb)
204
- perturbation_dataset = Dataset.from_dict({"input_ids": example_cell["input_ids"]*length,
205
- "perturb_index": indices_to_perturb})
206
- if length<400:
207
- num_proc_i = 1
208
- else:
209
- num_proc_i = num_proc
210
- if perturb_type == "delete":
211
- perturbation_dataset = perturbation_dataset.map(delete_indices, num_proc=num_proc_i)
212
- elif perturb_type == "overexpress":
213
- perturbation_dataset = perturbation_dataset.map(overexpress_indices, num_proc=num_proc_i)
214
- return perturbation_dataset, indices_to_perturb
215
-
216
- # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
217
- # so that only non-perturbed gene embeddings are compared to each other
218
- # in original or perturbed context
219
- def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
220
- all_embs_list = []
221
-
222
- # if making comparison batch for multiple perturbations in single cell
223
- if perturb_group == False:
224
- original_emb_list = [original_emb_batch]*len(indices_to_perturb)
225
- # if making comparison batch for single perturbation in multiple cells
226
- elif perturb_group == True:
227
- original_emb_list = original_emb_batch
228
-
229
-
230
- for i in range(len(original_emb_list)):
231
- original_emb = original_emb_list[i]
232
- indices = indices_to_perturb[i]
233
- if indices == [-100]:
234
- all_embs_list += [original_emb[:]]
235
- continue
236
- emb_list = []
237
- start = 0
238
- if any(isinstance(el, list) for el in indices):
239
- indices = flatten_list(indices)
240
- for i in sorted(indices):
241
- emb_list += [original_emb[start:i]]
242
- start = i+1
243
- emb_list += [original_emb[start:]]
244
- all_embs_list += [torch.cat(emb_list)]
245
- len_set = set([emb.size()[0] for emb in all_embs_list])
246
- if len(len_set) > 1:
247
- max_len = max(len_set)
248
- all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
249
- return torch.stack(all_embs_list)
250
-
251
- # average embedding position of goal cell states
252
- def get_cell_state_avg_embs(model,
253
- filtered_input_data,
254
- cell_states_to_model,
255
- layer_to_quant,
256
- pad_token_id,
257
- forward_batch_size,
258
- num_proc):
259
-
260
- model_input_size = get_model_input_size(model)
261
- possible_states = get_possible_states(cell_states_to_model)
262
- state_embs_dict = dict()
263
- for possible_state in possible_states:
264
- state_embs_list = []
265
- original_lens = []
266
-
267
- def filter_states(example):
268
- state_key = cell_states_to_model["state_key"]
269
- return example[state_key] in [possible_state]
270
- filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
271
- total_batch_length = len(filtered_input_data_state)
272
- if ((total_batch_length-1)/forward_batch_size).is_integer():
273
- forward_batch_size = forward_batch_size-1
274
- max_len = max(filtered_input_data_state["length"])
275
- for i in range(0, total_batch_length, forward_batch_size):
276
- max_range = min(i+forward_batch_size, total_batch_length)
277
-
278
- state_minibatch = filtered_input_data_state.select([i for i in range(i, max_range)])
279
- state_minibatch.set_format(type="torch")
280
-
281
- input_data_minibatch = state_minibatch["input_ids"]
282
- original_lens += state_minibatch["length"]
283
- input_data_minibatch = pad_tensor_list(input_data_minibatch,
284
- max_len,
285
- pad_token_id,
286
- model_input_size)
287
- attention_mask = gen_attention_mask(state_minibatch, max_len)
288
-
289
- with torch.no_grad():
290
- outputs = model(
291
- input_ids = input_data_minibatch.to("cuda"),
292
- attention_mask = attention_mask
293
- )
294
-
295
- state_embs_i = outputs.hidden_states[layer_to_quant]
296
- state_embs_list += [state_embs_i]
297
- del outputs
298
- del state_minibatch
299
- del input_data_minibatch
300
- del attention_mask
301
- del state_embs_i
302
- torch.cuda.empty_cache()
303
-
304
- state_embs = torch.cat(state_embs_list)
305
- avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
306
- avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
307
- state_embs_dict[possible_state] = avg_state_emb
308
- return state_embs_dict
309
-
310
- # quantify cosine similarity of perturbed vs original or alternate states
311
- def quant_cos_sims(model,
312
- perturb_type,
313
- perturbation_batch,
314
- forward_batch_size,
315
- layer_to_quant,
316
- original_emb,
317
- tokens_to_perturb,
318
- indices_to_perturb,
319
- perturb_group,
320
- cell_states_to_model,
321
- state_embs_dict,
322
- pad_token_id,
323
- model_input_size,
324
- nproc):
325
- cos = torch.nn.CosineSimilarity(dim=2)
326
- total_batch_length = len(perturbation_batch)
327
-
328
- if ((total_batch_length-1)/forward_batch_size).is_integer():
329
- forward_batch_size = forward_batch_size-1
330
-
331
- if perturb_group == False:
332
- comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
333
-
334
- if cell_states_to_model is None:
335
- cos_sims = []
336
- else:
337
- possible_states = get_possible_states(cell_states_to_model)
338
- cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for _ in range(len(possible_states))]))
339
-
340
- # measure length of each element in perturbation_batch
341
- perturbation_batch = perturbation_batch.map(
342
- measure_length, num_proc=nproc
343
- )
344
 
345
- def compute_batch_embeddings(minibatch, _max_len = None):
346
- minibatch_lengths = minibatch["length"]
347
- minibatch_length_set = set(minibatch_lengths)
348
- max_len = model_input_size
349
 
350
- if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > max_len):
351
- needs_pad_or_trunc = True
352
- else:
353
- needs_pad_or_trunc = False
354
- max_len = max(minibatch_length_set)
355
-
356
-
357
- if needs_pad_or_trunc == True:
358
- if _max_len is None:
359
- max_len = min(max(minibatch_length_set), max_len)
360
- else:
361
- max_len = _max_len
362
- def pad_or_trunc_example(example):
363
- example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
364
- pad_token_id,
365
- max_len)
366
- return example
367
- minibatch = minibatch.map(pad_or_trunc_example, num_proc=nproc)
368
-
369
- minibatch.set_format(type="torch")
370
-
371
- input_data_minibatch = minibatch["input_ids"]
372
- attention_mask = gen_attention_mask(minibatch, max_len)
373
-
374
- # extract embeddings for perturbation minibatch
375
- with torch.no_grad():
376
- outputs = model(
377
- input_ids = input_data_minibatch.to("cuda"),
378
- attention_mask = attention_mask
379
- )
380
 
381
- return outputs, max_len
382
-
383
- for i in range(0, total_batch_length, forward_batch_size):
384
- max_range = min(i+forward_batch_size, total_batch_length)
385
- perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
386
- outputs, mini_max_len = compute_batch_embeddings(perturbation_minibatch)
387
-
388
- if len(indices_to_perturb)>1:
389
- minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
390
- else:
391
- minibatch_emb = outputs.hidden_states[layer_to_quant]
392
-
393
- if perturb_type == "overexpress":
394
- # remove overexpressed genes to quantify effect on remaining genes
395
- if perturb_group == False:
396
- overexpressed_to_remove = 1
397
- if perturb_group == True:
398
- overexpressed_to_remove = len(tokens_to_perturb)
399
- minibatch_emb = minibatch_emb[:, overexpressed_to_remove: ,:]
400
-
401
-
402
- # if quantifying single perturbation in multiple different cells, pad original batch and extract embs
403
- if perturb_group == True:
404
- # pad minibatch of original batch to extract embeddings
405
- # truncate to the (model input size - # tokens to overexpress) to ensure comparability
406
- # since max input size of perturb batch will be reduced by # tokens to overexpress
407
- original_minibatch = original_emb.select([i for i in range(i, max_range)])
408
- original_outputs, orig_max_len = compute_batch_embeddings(original_minibatch, mini_max_len)
409
-
410
- if len(indices_to_perturb)>1:
411
- original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
412
- else:
413
- original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
414
-
415
- # if we overexpress genes that aren't already expressed,
416
- # we need to remove genes to make sure the embeddings are of a consistent size
417
- # get rid of the bottom n genes/padding since those will get truncated anyways
418
- # multiple perturbations is more complicated because if 1 out of n perturbed genes is expressed
419
- # the idxs will still not be [-100]
420
- if len(tokens_to_perturb) == 1:
421
- indices_to_perturb_minibatch = [idx if idx != [-100] else [orig_max_len - 1]
422
- for idx in indices_to_perturb[i:max_range]]
423
- else:
424
- num_perturbed = len(tokens_to_perturb)
425
- indices_to_perturb_minibatch = []
426
- end_range = [i for i in range(orig_max_len - tokens_to_perturb, orig_max_len)]
427
- for idx in indices_to_perturb[i:i+max_range]:
428
- if idx == [-100]:
429
- indices_to_perturb_minibatch.append(end_range)
430
- elif len(idx) < len(tokens_to_perturb):
431
- indices_to_perturb_minibatch.append(idx + end_range[-num_perturbed:])
432
- else:
433
- indices_to_perturb_minibatch.append(idx)
434
-
435
- original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
436
- indices_to_perturb_minibatch,
437
- gene_dim=1)
438
-
439
- # cosine similarity between original emb and batch items
440
- if cell_states_to_model is None:
441
- if perturb_group == False:
442
- minibatch_comparison = comparison_batch[i:max_range]
443
- elif perturb_group == True:
444
- minibatch_comparison = original_minibatch_emb
445
- cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
446
- elif cell_states_to_model is not None:
447
- if perturb_group == False:
448
- original_emb = comparison_batch[i:max_range]
449
- else:
450
- original_minibatch_lengths = torch.tensor(original_minibatch["length"], device="cuda")
451
- minibatch_lengths = torch.tensor(perturbation_minibatch["length"], device="cuda")
452
- for state in possible_states:
453
- if perturb_group == False:
454
- cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb,
455
- minibatch_emb,
456
- state_embs_dict[state],
457
- perturb_group)
458
- elif perturb_group == True:
459
- cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
460
- minibatch_emb,
461
- state_embs_dict[state],
462
- perturb_group,
463
- original_minibatch_lengths,
464
- minibatch_lengths)
465
- del outputs
466
- del minibatch_emb
467
- if cell_states_to_model is None:
468
- del minibatch_comparison
469
- if perturb_group == True:
470
- del original_minibatch_emb
471
- torch.cuda.empty_cache()
472
- if cell_states_to_model is None:
473
- cos_sims_stack = torch.cat(cos_sims)
474
- return cos_sims_stack
475
- else:
476
- for state in possible_states:
477
- cos_sims_vs_alt_dict[state] = torch.cat(cos_sims_vs_alt_dict[state])
478
- return cos_sims_vs_alt_dict
479
-
480
-
481
- # calculate cos sim shift of perturbation with respect to origin and alternative cell
482
- def cos_sim_shift(original_emb,
483
- minibatch_emb,
484
- end_emb,
485
- perturb_group,
486
- original_minibatch_lengths = None,
487
- minibatch_lengths = None):
488
- cos = torch.nn.CosineSimilarity(dim=2)
489
- if original_emb.size() != minibatch_emb.size():
490
- logger.error(
491
- f"Embeddings are not the same dimensions. " \
492
- f"original_emb is {original_emb.size()}. " \
493
- f"minibatch_emb is {minibatch_emb.size()}. "
494
- )
495
- raise
496
- if not perturb_group:
497
- original_emb = torch.mean(original_emb,dim=1,keepdim=True)
498
- origin_v_end = torch.squeeze(cos(original_emb, end_emb))
499
- else:
500
- if original_minibatch_lengths is not None:
501
- original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
502
- # else:
503
- # original_emb = torch.mean(original_emb,dim=1,keepdim=True)
504
-
505
- end_emb = torch.unsqueeze(end_emb, 1)
506
- origin_v_end = torch.squeeze(cos(original_emb, end_emb))
507
- if minibatch_lengths is not None:
508
- perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
509
- else:
510
- perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
511
- perturb_v_end = cos(perturb_emb, end_emb)
512
- perturb_v_end = torch.squeeze(perturb_v_end)
513
- if (perturb_v_end-origin_v_end).numel() == 1:
514
- return [([perturb_v_end-origin_v_end]).to("cpu")]
515
- return [(perturb_v_end-origin_v_end).to("cpu")]
516
-
517
- def pad_list(input_ids, pad_token_id, max_len):
518
- input_ids = np.pad(input_ids,
519
- (0, max_len-len(input_ids)),
520
- mode='constant', constant_values=pad_token_id)
521
- return input_ids
522
-
523
- def pad_tensor(tensor, pad_token_id, max_len):
524
- tensor = torch.nn.functional.pad(tensor, pad=(0,
525
- max_len - tensor.numel()),
526
- mode='constant',
527
- value=pad_token_id)
528
- return tensor
529
-
530
- def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
531
- if dim == 0:
532
- pad = (0, 0, 0, max_len - tensor.size()[dim])
533
- elif dim == 1:
534
- pad = (0, max_len - tensor.size()[dim], 0, 0)
535
- tensor = torch.nn.functional.pad(tensor, pad=pad,
536
- mode='constant',
537
- value=pad_token_id)
538
- return tensor
539
-
540
- def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
541
- if isinstance(encoding, torch.Tensor):
542
- encoding_len = tensor.size()[0]
543
- elif isinstance(encoding, list):
544
- encoding_len = len(encoding)
545
- if encoding_len > max_len:
546
- encoding = encoding[0:max_len]
547
- elif encoding_len < max_len:
548
- if isinstance(encoding, torch.Tensor):
549
- encoding = pad_tensor(encoding, pad_token_id, max_len)
550
- elif isinstance(encoding, list):
551
- encoding = pad_list(encoding, pad_token_id, max_len)
552
- return encoding
553
-
554
- # pad list of tensors and convert to tensor
555
- def pad_tensor_list(tensor_list, dynamic_or_constant, pad_token_id, model_input_size):
556
-
557
- # Determine maximum tensor length
558
- if dynamic_or_constant == "dynamic":
559
- max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
560
- elif type(dynamic_or_constant) == int:
561
- max_len = dynamic_or_constant
562
- else:
563
- max_len = model_input_size
564
- logger.warning(
565
- "If padding style is constant, must provide integer value. " \
566
- f"Setting padding to max input size {model_input_size}.")
567
-
568
- # pad all tensors to maximum length
569
- tensor_list = [pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list]
570
-
571
- # return stacked tensors
572
- return torch.stack(tensor_list)
573
-
574
- def gen_attention_mask(minibatch_encoding, max_len = None):
575
- if max_len == None:
576
- max_len = max(minibatch_encoding["length"])
577
- original_lens = minibatch_encoding["length"]
578
- attention_mask = [[1]*original_len
579
- +[0]*(max_len - original_len)
580
- if original_len <= max_len
581
- else [1]*max_len
582
- for original_len in original_lens]
583
- return torch.tensor(attention_mask).to("cuda")
584
-
585
- # get cell embeddings excluding padding
586
- def mean_nonpadding_embs(embs, original_lens):
587
- # mask based on padding lengths
588
- mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
589
-
590
- # extend mask dimensions to match the embeddings tensor
591
- mask = mask.unsqueeze(2).expand_as(embs)
592
-
593
- # use the mask to zero out the embeddings in padded areas
594
- masked_embs = embs * mask.float()
595
-
596
- # sum and divide by the lengths to get the mean of non-padding embs
597
- mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
598
- return mean_embs
599
 
600
  class InSilicoPerturber:
601
  valid_option_dict = {
602
- "perturb_type": {"delete","overexpress","inhibit","activate"},
603
  "perturb_rank_shift": {None, 1, 2, 3},
604
  "genes_to_perturb": {"all", list},
605
  "combos": {0, 1},
606
  "anchor_gene": {None, str},
607
- "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
608
  "num_classes": {int},
609
- "emb_mode": {"cell","cell_and_gene"},
610
  "cell_emb_style": {"mean_pool"},
611
  "filter_data": {None, dict},
612
  "cell_states_to_model": {None, dict},
 
613
  "max_ncells": {None, int},
614
  "cell_inds_to_perturb": {"all", dict},
615
  "emb_layer": {-1, 0},
616
  "forward_batch_size": {int},
617
  "nproc": {int},
618
  }
 
619
  def __init__(
620
  self,
621
  perturb_type="delete",
@@ -629,6 +81,7 @@ class InSilicoPerturber:
629
  cell_emb_style="mean_pool",
630
  filter_data=None,
631
  cell_states_to_model=None,
 
632
  max_ncells=None,
633
  cell_inds_to_perturb="all",
634
  emb_layer=-1,
@@ -676,13 +129,14 @@ class InSilicoPerturber:
676
  For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
677
  emb_mode : {"cell","cell_and_gene"}
678
  Whether to output impact of perturbation on cell and/or gene embeddings.
 
679
  cell_emb_style : "mean_pool"
680
  Method for summarizing cell embeddings.
681
  Currently only option is mean pooling of gene embeddings for given cell.
682
  filter_data : None, dict
683
  Default is to use all input data for in silico perturbation study.
684
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
685
- cell_states_to_model: None, dict
686
  Cell states to model if testing perturbations that achieve goal state change.
687
  Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
688
  state_key: key specifying name of column in .dataset that defines the start/goal states
@@ -693,6 +147,15 @@ class InSilicoPerturber:
693
  "start_state": "dcm",
694
  "goal_state": "nf",
695
  "alt_states": ["hcm", "other1", "other2"]}
 
 
 
 
 
 
 
 
 
696
  max_ncells : None, int
697
  Maximum number of cells to test.
698
  If None, will test all cells.
@@ -705,8 +168,8 @@ class InSilicoPerturber:
705
  Useful for splitting extremely large datasets across separate GPUs.
706
  emb_layer : {-1, 0}
707
  Embedding layer to use for quantification.
708
- -1: 2nd to last layer (recommended for pretrained Geneformer)
709
- 0: last layer (recommended for cell classifier fine-tuned for disease state)
710
  forward_batch_size : int
711
  Batch size for forward pass.
712
  nproc : int
@@ -721,23 +184,25 @@ class InSilicoPerturber:
721
  self.combos = combos
722
  self.anchor_gene = anchor_gene
723
  if self.genes_to_perturb == "all":
724
- self.perturb_group = False
725
  else:
726
  self.perturb_group = True
727
- if (self.anchor_gene != None) or (self.combos != 0):
728
  self.anchor_gene = None
729
  self.combos = 0
730
  logger.warning(
731
- "anchor_gene set to None and combos set to 0. " \
732
- "If providing list of genes to perturb, " \
733
- "list of genes_to_perturb will be perturbed together, "\
734
- "without anchor gene or combinations.")
 
735
  self.model_type = model_type
736
  self.num_classes = num_classes
737
  self.emb_mode = emb_mode
738
  self.cell_emb_style = cell_emb_style
739
  self.filter_data = filter_data
740
  self.cell_states_to_model = cell_states_to_model
 
741
  self.max_ncells = max_ncells
742
  self.cell_inds_to_perturb = cell_inds_to_perturb
743
  self.emb_layer = emb_layer
@@ -758,36 +223,47 @@ class InSilicoPerturber:
758
  try:
759
  self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
760
  except KeyError:
761
- logger.error(
762
- f"Anchor gene {self.anchor_gene} not in token dictionary."
763
- )
764
  raise
765
 
766
  if self.genes_to_perturb == "all":
767
  self.tokens_to_perturb = "all"
768
  else:
769
- missing_genes = [gene for gene in self.genes_to_perturb if gene not in self.gene_token_dict.keys()]
 
 
 
 
770
  if len(missing_genes) == len(self.genes_to_perturb):
771
  logger.error(
772
  "None of the provided genes to perturb are in token dictionary."
773
  )
774
  raise
775
- elif len(missing_genes)>0:
776
  logger.warning(
777
- f"Genes to perturb {missing_genes} are not in token dictionary.")
778
- self.tokens_to_perturb = [self.gene_token_dict.get(gene) for gene in self.genes_to_perturb]
 
 
 
779
 
780
  def validate_options(self):
781
  # first disallow options under development
782
  if self.perturb_type in ["inhibit", "activate"]:
783
  logger.error(
784
- "In silico inhibition and activation currently under development. " \
785
  "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
786
  )
787
  raise
788
-
 
 
 
 
 
 
789
  # confirm arguments are within valid options and compatible with each other
790
- for attr_name,valid_options in self.valid_option_dict.items():
791
  attr_value = self.__dict__[attr_name]
792
  if type(attr_value) not in {list, dict}:
793
  if attr_value in valid_options:
@@ -797,141 +273,120 @@ class InSilicoPerturber:
797
  continue
798
  valid_type = False
799
  for option in valid_options:
800
- if (option in [int,list,dict]) and isinstance(attr_value, option):
 
 
801
  valid_type = True
802
  break
803
  if valid_type:
804
  continue
805
  logger.error(
806
- f"Invalid option for {attr_name}. " \
807
  f"Valid options for {attr_name}: {valid_options}"
808
  )
809
  raise
810
-
811
- if self.perturb_type in ["delete","overexpress"]:
812
  if self.perturb_rank_shift is not None:
813
  if self.perturb_type == "delete":
814
  logger.warning(
815
- "perturb_rank_shift set to None. " \
816
- "If perturb type is delete then gene is deleted entirely " \
817
- "rather than shifted by quartile")
 
818
  elif self.perturb_type == "overexpress":
819
  logger.warning(
820
- "perturb_rank_shift set to None. " \
821
- "If perturb type is overexpress then gene is moved to front " \
822
- "of rank value encoding rather than shifted by quartile")
 
823
  self.perturb_rank_shift = None
824
-
825
  if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
826
  self.emb_mode = "cell"
827
  logger.warning(
828
- "emb_mode set to 'cell'. " \
829
- "Currently, analysis with anchor gene " \
830
- "only outputs effect on cell embeddings.")
831
-
 
832
  if self.cell_states_to_model is not None:
833
- if len(self.cell_states_to_model.items()) == 1:
 
 
 
834
  logger.warning(
835
- "The single value dictionary for cell_states_to_model will be " \
836
- "replaced with a dictionary with named keys for start, goal, and alternate states. " \
837
- "Please specify state_key, start_state, goal_state, and alt_states " \
838
- "in the cell_states_to_model dictionary for future use. " \
839
- "For example, cell_states_to_model={" \
840
- "'state_key': 'disease', " \
841
- "'start_state': 'dcm', " \
842
- "'goal_state': 'nf', " \
843
- "'alt_states': ['hcm', 'other1', 'other2']}"
844
  )
845
- for key,value in self.cell_states_to_model.items():
846
- if (len(value) == 3) and isinstance(value, tuple):
847
- if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
848
- if len(value[0]) == 1 and len(value[1]) == 1:
849
- all_values = value[0]+value[1]+value[2]
850
- if len(all_values) == len(set(all_values)):
851
- continue
852
- # reformat to the new named key format
853
- state_values = flatten_list(list(self.cell_states_to_model.values()))
854
- self.cell_states_to_model = {
855
- "state_key": list(self.cell_states_to_model.keys())[0],
856
- "start_state": state_values[0][0],
857
- "goal_state": state_values[1][0],
858
- "alt_states": state_values[2:][0]
859
- }
860
- elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
861
- if (self.cell_states_to_model["state_key"] is None) \
862
- or (self.cell_states_to_model["start_state"] is None) \
863
- or (self.cell_states_to_model["goal_state"] is None):
864
- logger.error(
865
- "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
866
- raise
867
-
868
- if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
869
  logger.error(
870
- "All states must be unique.")
 
871
  raise
872
 
873
- if self.cell_states_to_model["alt_states"] is not None:
874
- if type(self.cell_states_to_model["alt_states"]) is not list:
875
- logger.error(
876
- "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
877
- )
878
- raise
879
- if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
880
- logger.error(
881
- "All states must be unique.")
882
- raise
883
-
884
- else:
885
  logger.error(
886
- "cell_states_to_model must only have the following four keys: " \
887
- "'state_key', 'start_state', 'goal_state', 'alt_states'." \
888
- "For example, cell_states_to_model={" \
889
- "'state_key': 'disease', " \
890
- "'start_state': 'dcm', " \
891
- "'goal_state': 'nf', " \
892
- "'alt_states': ['hcm', 'other1', 'other2']}"
893
  )
894
  raise
895
 
896
- if self.anchor_gene is not None:
897
- self.anchor_gene = None
898
- logger.warning(
899
- "anchor_gene set to None. " \
900
- "Currently, anchor gene not available " \
901
- "when modeling multiple cell states.")
902
-
903
- if self.perturb_type in ["inhibit","activate"]:
904
  if self.perturb_rank_shift is None:
905
  logger.error(
906
- "If perturb_type is inhibit or activate then " \
907
- "quartile to shift by must be specified.")
 
908
  raise
909
-
910
  if self.filter_data is not None:
911
- for key,value in self.filter_data.items():
912
- if type(value) != list:
913
  self.filter_data[key] = [value]
914
  logger.warning(
915
- "Values in filter_data dict must be lists. " \
916
- f"Changing {key} value to list ([{value}]).")
917
-
 
918
  if self.cell_inds_to_perturb != "all":
919
  if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
920
  logger.error(
921
  "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
922
  )
923
  raise
924
- if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
925
- logger.error(
926
- 'cell_inds_to_perturb must be positive.'
927
- )
 
928
  raise
929
 
930
- def perturb_data(self,
931
- model_directory,
932
- input_data_file,
933
- output_directory,
934
- output_prefix):
935
  """
936
  Perturb genes in input data and save as results in output_directory.
937
 
@@ -947,365 +402,506 @@ class InSilicoPerturber:
947
  Prefix for output files
948
  """
949
 
950
- filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
951
- model = load_model(self.model_type, self.num_classes, model_directory)
952
- layer_to_quant = quant_layers(model)+self.emb_layer
953
-
954
- if self.cell_states_to_model is None:
955
- state_embs_dict = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
956
  else:
957
- # confirm that all states are valid to prevent futile filtering
958
- state_name = self.cell_states_to_model["state_key"]
959
- state_values = filtered_input_data[state_name]
960
- for value in get_possible_states(self.cell_states_to_model):
961
- if value not in state_values:
962
- logger.error(
963
- f"{value} is not present in the dataset's {state_name} attribute.")
964
- raise
965
- # get dictionary of average cell state embeddings for comparison
966
- downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
967
- state_embs_dict = get_cell_state_avg_embs(model,
968
- downsampled_data,
969
- self.cell_states_to_model,
970
- layer_to_quant,
971
- self.pad_token_id,
972
- self.forward_batch_size,
973
- self.nproc)
974
- # filter for start state cells
975
- start_state = self.cell_states_to_model["start_state"]
976
- def filter_for_origin(example):
977
- return example[state_name] in [start_state]
978
-
979
- filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
980
-
981
- self.in_silico_perturb(model,
982
- filtered_input_data,
983
- layer_to_quant,
984
- state_embs_dict,
985
- output_directory,
986
- output_prefix)
987
-
988
- # determine effect of perturbation on other genes
989
- def in_silico_perturb(self,
990
- model,
991
- filtered_input_data,
992
- layer_to_quant,
993
- state_embs_dict,
994
- output_directory,
995
- output_prefix):
996
-
997
- output_path_prefix = f"{output_directory}in_silico_{self.perturb_type}_{output_prefix}_dict_1Kbatch"
998
- model_input_size = get_model_input_size(model)
999
-
1000
- # filter dataset for cells that have tokens to be perturbed
1001
- if self.anchor_token is not None:
1002
- def if_has_tokens_to_perturb(example):
1003
- return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
1004
- filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
1005
- if len(filtered_input_data) == 0:
1006
- logger.error(
1007
- "No cells in dataset contain anchor gene.")
1008
- raise
1009
- else:
1010
- logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
1011
-
1012
  if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
1013
- # minimum # genes needed for perturbation test
1014
- min_genes = len(self.tokens_to_perturb)
1015
-
1016
- def if_has_tokens_to_perturb(example):
1017
- return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>=min_genes)
1018
- filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
1019
- if len(filtered_input_data) == 0:
1020
- logger.error(
1021
- "No cells in dataset contain all genes to perturb as a group.")
1022
- raise
1023
-
1024
- cos_sims_dict = defaultdict(list)
1025
- pickle_batch = -1
1026
- filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
 
 
 
 
 
 
1027
  if self.cell_inds_to_perturb != "all":
1028
- if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
1029
- logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
1030
- raise
1031
- if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
1032
- logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
1033
- Setting to the end of the filtered dataset.")
1034
- self.cell_inds_to_perturb["end"] = len(filtered_input_data)
1035
- filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
1036
-
1037
- # make perturbation batch w/ single perturbation in multiple cells
1038
- if self.perturb_group == True:
1039
-
1040
- def make_group_perturbation_batch(example):
1041
- example_input_ids = example["input_ids"]
1042
- example["tokens_to_perturb"] = self.tokens_to_perturb
1043
- indices_to_perturb = [example_input_ids.index(token) if token in example_input_ids else None for token in self.tokens_to_perturb]
1044
- indices_to_perturb = [item for item in indices_to_perturb if item is not None]
1045
- if len(indices_to_perturb) > 0:
1046
- example["perturb_index"] = indices_to_perturb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1047
  else:
1048
- # -100 indicates tokens to overexpress are not present in rank value encoding
1049
- example["perturb_index"] = [-100]
1050
- if self.perturb_type == "delete":
1051
- example = delete_indices(example)
1052
- elif self.perturb_type == "overexpress":
1053
- example = overexpress_tokens(example)
1054
- return example
1055
-
1056
- perturbation_batch = filtered_input_data.map(make_group_perturbation_batch, num_proc=self.nproc)
1057
- indices_to_perturb = perturbation_batch["perturb_index"]
1058
-
1059
- cos_sims_data = quant_cos_sims(model,
1060
- self.perturb_type,
1061
- perturbation_batch,
1062
- self.forward_batch_size,
1063
- layer_to_quant,
1064
- filtered_input_data,
1065
- self.tokens_to_perturb,
1066
- indices_to_perturb,
1067
- self.perturb_group,
1068
- self.cell_states_to_model,
1069
- state_embs_dict,
1070
- self.pad_token_id,
1071
- model_input_size,
1072
- self.nproc)
1073
-
1074
- perturbed_genes = tuple(self.tokens_to_perturb)
1075
- original_lengths = filtered_input_data["length"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1076
  if self.cell_states_to_model is None:
1077
- # update cos sims dict
1078
- # key is tuple of (perturbed_gene, affected_gene)
1079
- # or (perturbed_genes, "cell_emb") for avg cell emb change
1080
- cos_sims_data = cos_sims_data.to("cuda")
1081
- max_padded_len = cos_sims_data.shape[1]
1082
- for j in range(cos_sims_data.shape[0]):
1083
- # remove padding before mean pooling cell embedding
1084
- original_length = original_lengths[j]
1085
- gene_list = filtered_input_data[j]["input_ids"]
1086
- indices_removed = indices_to_perturb[j]
1087
- padding_to_remove = max_padded_len - (original_length \
1088
- - len(self.tokens_to_perturb) \
1089
- - len(indices_removed))
1090
- nonpadding_cos_sims_data = cos_sims_data[j][:-padding_to_remove]
1091
- cell_cos_sim = torch.mean(nonpadding_cos_sims_data).item()
1092
- cos_sims_dict[(perturbed_genes, "cell_emb")] += [cell_cos_sim]
1093
-
1094
- if self.emb_mode == "cell_and_gene":
1095
- for k in range(cos_sims_data.shape[1]):
1096
- cos_sim_value = nonpadding_cos_sims_data[k]
1097
- affected_gene = gene_list[k].item()
1098
- cos_sims_dict[(perturbed_genes, affected_gene)] += [cos_sim_value.item()]
1099
  else:
1100
- # update cos sims dict
1101
- # key is tuple of (perturbed_genes, "cell_emb")
1102
- # value is list of tuples of cos sims for cell_states_to_model
1103
- origin_state_key = self.cell_states_to_model["start_state"]
1104
- cos_sims_origin = cos_sims_data[origin_state_key]
1105
- for j in range(cos_sims_origin.shape[0]):
1106
- data_list = []
1107
- for data in list(cos_sims_data.values()):
1108
- data_item = data.to("cuda")
1109
- data_list += [data_item[j].item()]
1110
- cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
1111
-
1112
- with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
1113
- pickle.dump(cos_sims_dict, fp)
1114
-
1115
- # make perturbation batch w/ multiple perturbations in single cell
1116
- if self.perturb_group == False:
1117
-
1118
- for i in trange(len(filtered_input_data)):
1119
- example_cell = filtered_input_data.select([i])
1120
- original_emb = forward_pass_single_cell(model, example_cell, layer_to_quant)
1121
- gene_list = torch.squeeze(example_cell["input_ids"])
1122
-
1123
- # reset to original type to prevent downstream issues due to forward_pass_single_cell modifying as torch format in place
1124
- example_cell = filtered_input_data.select([i])
1125
-
1126
- if self.anchor_token is None:
1127
- for combo_lvl in range(self.combos+1):
1128
- perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
1129
- self.perturb_type,
1130
- self.tokens_to_perturb,
1131
- self.anchor_token,
1132
- combo_lvl,
1133
- self.nproc)
1134
- cos_sims_data = quant_cos_sims(model,
1135
- self.perturb_type,
1136
- perturbation_batch,
1137
- self.forward_batch_size,
1138
- layer_to_quant,
1139
- original_emb,
1140
- self.tokens_to_perturb,
1141
- indices_to_perturb,
1142
- self.perturb_group,
1143
- self.cell_states_to_model,
1144
- state_embs_dict,
1145
- self.pad_token_id,
1146
- model_input_size,
1147
- self.nproc)
1148
-
1149
- if self.cell_states_to_model is None:
1150
- # update cos sims dict
1151
- # key is tuple of (perturbed_gene, affected_gene)
1152
- # or (perturbed_gene, "cell_emb") for avg cell emb change
1153
- cos_sims_data = cos_sims_data.to("cuda")
1154
- for j in range(cos_sims_data.shape[0]):
1155
- if self.tokens_to_perturb != "all":
1156
- j_index = torch.tensor(indices_to_perturb[j])
1157
- if j_index.shape[0]>1:
1158
- j_index = torch.squeeze(j_index)
1159
- else:
1160
- j_index = torch.tensor([j])
1161
-
1162
- if self.perturb_type in ("overexpress", "activate"):
1163
- perturbed_gene = torch.index_select(gene_list, 0, j_index + 1)
1164
- else:
1165
- perturbed_gene = torch.index_select(gene_list, 0, j_index)
1166
-
1167
- if perturbed_gene.shape[0]==1:
1168
- perturbed_gene = perturbed_gene.item()
1169
- elif perturbed_gene.shape[0]>1:
1170
- perturbed_gene = tuple(perturbed_gene.tolist())
1171
-
1172
- cell_cos_sim = torch.mean(cos_sims_data[j]).item()
1173
- cos_sims_dict[(perturbed_gene, "cell_emb")] += [cell_cos_sim]
1174
-
1175
- # not_j_index = list(set(i for i in range(gene_list.shape[0])).difference(j_index))
1176
- # gene_list_j = torch.index_select(gene_list, 0, j_index)
1177
- if self.emb_mode == "cell_and_gene":
1178
- for k in range(cos_sims_data.shape[1]):
1179
- cos_sim_value = cos_sims_data[j][k]
1180
- affected_gene = gene_list[k].item()
1181
- cos_sims_dict[(perturbed_gene, affected_gene)] += [cos_sim_value.item()]
1182
- else:
1183
- # update cos sims dict
1184
- # key is tuple of (perturbed_gene, "cell_emb")
1185
- # value is list of tuples of cos sims for cell_states_to_model
1186
- origin_state_key = self.cell_states_to_model["start_state"]
1187
- cos_sims_origin = cos_sims_data[origin_state_key]
1188
-
1189
- for j in range(cos_sims_origin.shape[0]):
1190
- if (self.tokens_to_perturb != "all") or (combo_lvl>0):
1191
- j_index = torch.tensor(indices_to_perturb[j])
1192
- if j_index.shape[0]>1:
1193
- j_index = torch.squeeze(j_index)
1194
- else:
1195
- j_index = torch.tensor([j])
1196
-
1197
- if self.perturb_type in ("overexpress", "activate"):
1198
- perturbed_gene = torch.index_select(gene_list, 0, j_index + 1)
1199
- else:
1200
- perturbed_gene = torch.index_select(gene_list, 0, j_index)
1201
-
1202
- if perturbed_gene.shape[0]==1:
1203
- perturbed_gene = perturbed_gene.item()
1204
- elif perturbed_gene.shape[0]>1:
1205
- perturbed_gene = tuple(perturbed_gene.tolist())
1206
-
1207
- data_list = []
1208
- for data in list(cos_sims_data.values()):
1209
- data_item = data.to("cuda")
1210
- cell_data = torch.mean(data_item[j]).item()
1211
- data_list += [cell_data]
1212
- cos_sims_dict[(perturbed_gene, "cell_emb")] += [tuple(data_list)]
1213
-
1214
- elif self.anchor_token is not None:
1215
- perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
1216
- self.perturb_type,
1217
- self.tokens_to_perturb,
1218
- None, # first run without anchor token to test individual gene perturbations
1219
- 0,
1220
- self.nproc)
1221
- cos_sims_data = quant_cos_sims(model,
1222
- self.perturb_type,
1223
- perturbation_batch,
1224
- self.forward_batch_size,
1225
- layer_to_quant,
1226
- original_emb,
1227
- self.tokens_to_perturb,
1228
- indices_to_perturb,
1229
- self.perturb_group,
1230
- self.cell_states_to_model,
1231
- state_embs_dict,
1232
- self.pad_token_id,
1233
- model_input_size,
1234
- self.nproc)
1235
- cos_sims_data = cos_sims_data.to("cuda")
1236
-
1237
- combo_perturbation_batch, combo_indices_to_perturb = make_perturbation_batch(example_cell,
1238
- self.perturb_type,
1239
- self.tokens_to_perturb,
1240
- self.anchor_token,
1241
- 1,
1242
- self.nproc)
1243
- combo_cos_sims_data = quant_cos_sims(model,
1244
- self.perturb_type,
1245
- combo_perturbation_batch,
1246
- self.forward_batch_size,
1247
- layer_to_quant,
1248
- original_emb,
1249
- self.tokens_to_perturb,
1250
- combo_indices_to_perturb,
1251
- self.perturb_group,
1252
- self.cell_states_to_model,
1253
- state_embs_dict,
1254
- self.pad_token_id,
1255
- model_input_size,
1256
- self.nproc)
1257
- combo_cos_sims_data = combo_cos_sims_data.to("cuda")
1258
-
1259
- # update cos sims dict
1260
- # key is tuple of (perturbed_gene, "cell_emb") for avg cell emb change
1261
- anchor_index = example_cell["input_ids"][0].index(self.anchor_token[0])
1262
- anchor_cell_cos_sim = torch.mean(cos_sims_data[anchor_index]).item()
1263
- non_anchor_indices = [k for k in range(cos_sims_data.shape[0]) if k != anchor_index]
1264
- cos_sims_data = cos_sims_data[non_anchor_indices,:]
1265
-
1266
- for j in range(cos_sims_data.shape[0]):
1267
-
1268
- if j<anchor_index:
1269
- j_index = torch.tensor([j])
1270
- else:
1271
- j_index = torch.tensor([j+1])
1272
-
1273
- perturbed_gene = torch.index_select(gene_list, 0, j_index)
1274
- perturbed_gene = perturbed_gene.item()
1275
-
1276
- cell_cos_sim = torch.mean(cos_sims_data[j]).item()
1277
- combo_cos_sim = torch.mean(combo_cos_sims_data[j]).item()
1278
- cos_sims_dict[(perturbed_gene, "cell_emb")] += [(anchor_cell_cos_sim, # cos sim anchor gene alone
1279
- cell_cos_sim, # cos sim deleted gene alone
1280
- combo_cos_sim)] # cos sim anchor gene + deleted gene
1281
-
1282
- # save dict to disk every 100 cells
1283
- if (i/100).is_integer():
1284
- with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1285
- pickle.dump(cos_sims_dict, fp)
1286
- # reset and clear memory every 1000 cells
1287
- if (i/1000).is_integer():
1288
- pickle_batch = pickle_batch+1
1289
- # clear memory
1290
- del perturbed_gene
1291
- del cos_sims_data
1292
- if self.cell_states_to_model is None:
1293
- del cell_cos_sim
1294
- if self.cell_states_to_model is not None:
1295
- del cell_data
1296
- del data_list
1297
- elif self.anchor_token is None:
1298
- if self.emb_mode == "cell_and_gene":
1299
- del affected_gene
1300
- del cos_sim_value
1301
- else:
1302
- del combo_cos_sim
1303
- del combo_cos_sims_data
1304
- # reset dict
1305
- del cos_sims_dict
1306
  cos_sims_dict = defaultdict(list)
1307
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1308
 
1309
- # save remainder cells
1310
- with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1311
- pickle.dump(cos_sims_dict, fp)
8
  genes_to_perturb="all",
9
  combos=0,
10
  anchor_gene=None,
11
+ model_type="CellClassifier",
12
  num_classes=0,
13
  emb_mode="cell",
14
  cell_emb_style="mean_pool",
15
  filter_data={"cell_type":["cardiomyocyte"]},
16
  cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
17
+ state_embs_dict ={"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2},
18
  max_ncells=None,
19
+ emb_layer=0,
20
  forward_batch_size=100,
21
+ nproc=16)
22
  isp.perturb_data("path/to/model",
23
  "path/to/input_data",
24
  "path/to/output_directory",
25
  "output_prefix")
26
  """
27
 
 
 
28
  import logging
29
+
30
+ # imports
31
+ import os
32
  import pickle
 
 
 
33
  from collections import defaultdict
34
+
35
+ import seaborn as sns
36
+ import torch
37
+ from datasets import Dataset
38
  from tqdm.auto import trange
 
39
 
40
+ from . import perturber_utils as pu
41
+ from .emb_extractor import get_embs
42
  from .tokenizer import TOKEN_DICTIONARY_FILE
43
 
44
+ sns.set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
 
 
 
 
46
 
47
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  class InSilicoPerturber:
51
  valid_option_dict = {
52
+ "perturb_type": {"delete", "overexpress", "inhibit", "activate"},
53
  "perturb_rank_shift": {None, 1, 2, 3},
54
  "genes_to_perturb": {"all", list},
55
  "combos": {0, 1},
56
  "anchor_gene": {None, str},
57
+ "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
58
  "num_classes": {int},
59
+ "emb_mode": {"cell", "cell_and_gene"},
60
  "cell_emb_style": {"mean_pool"},
61
  "filter_data": {None, dict},
62
  "cell_states_to_model": {None, dict},
63
+ "state_embs_dict": {None, dict},
64
  "max_ncells": {None, int},
65
  "cell_inds_to_perturb": {"all", dict},
66
  "emb_layer": {-1, 0},
67
  "forward_batch_size": {int},
68
  "nproc": {int},
69
  }
70
+
71
  def __init__(
72
  self,
73
  perturb_type="delete",
81
  cell_emb_style="mean_pool",
82
  filter_data=None,
83
  cell_states_to_model=None,
84
+ state_embs_dict=None,
85
  max_ncells=None,
86
  cell_inds_to_perturb="all",
87
  emb_layer=-1,
129
  For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
130
  emb_mode : {"cell","cell_and_gene"}
131
  Whether to output impact of perturbation on cell and/or gene embeddings.
132
+ Gene embedding shifts only available as compared to original cell, not comparing to goal state.
133
  cell_emb_style : "mean_pool"
134
  Method for summarizing cell embeddings.
135
  Currently only option is mean pooling of gene embeddings for given cell.
136
  filter_data : None, dict
137
  Default is to use all input data for in silico perturbation study.
138
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
139
+ cell_states_to_model : None, dict
140
  Cell states to model if testing perturbations that achieve goal state change.
141
  Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
142
  state_key: key specifying name of column in .dataset that defines the start/goal states
147
  "start_state": "dcm",
148
  "goal_state": "nf",
149
  "alt_states": ["hcm", "other1", "other2"]}
150
+ state_embs_dict : None, dict
151
+ Embedding positions of each cell state to model shifts from/towards (e.g. mean or median).
152
+ Dictionary with keys specifying each possible cell state to model.
153
+ Values are target embedding positions as torch.tensor.
154
+ For example: {"nf": emb_nf,
155
+ "hcm": emb_hcm,
156
+ "dcm": emb_dcm,
157
+ "other1": emb_other1,
158
+ "other2": emb_other2}
159
  max_ncells : None, int
160
  Maximum number of cells to test.
161
  If None, will test all cells.
168
  Useful for splitting extremely large datasets across separate GPUs.
169
  emb_layer : {-1, 0}
170
  Embedding layer to use for quantification.
171
+ 0: last layer (recommended for questions closely tied to model's training objective)
172
+ -1: 2nd to last layer (recommended for questions requiring more general representations)
173
  forward_batch_size : int
174
  Batch size for forward pass.
175
  nproc : int
184
  self.combos = combos
185
  self.anchor_gene = anchor_gene
186
  if self.genes_to_perturb == "all":
187
+ self.perturb_group = False
188
  else:
189
  self.perturb_group = True
190
+ if (self.anchor_gene is not None) or (self.combos != 0):
191
  self.anchor_gene = None
192
  self.combos = 0
193
  logger.warning(
194
+ "anchor_gene set to None and combos set to 0. "
195
+ "If providing list of genes to perturb, "
196
+ "list of genes_to_perturb will be perturbed together, "
197
+ "without anchor gene or combinations."
198
+ )
199
  self.model_type = model_type
200
  self.num_classes = num_classes
201
  self.emb_mode = emb_mode
202
  self.cell_emb_style = cell_emb_style
203
  self.filter_data = filter_data
204
  self.cell_states_to_model = cell_states_to_model
205
+ self.state_embs_dict = state_embs_dict
206
  self.max_ncells = max_ncells
207
  self.cell_inds_to_perturb = cell_inds_to_perturb
208
  self.emb_layer = emb_layer
223
  try:
224
  self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
225
  except KeyError:
226
+ logger.error(f"Anchor gene {self.anchor_gene} not in token dictionary.")
 
 
227
  raise
228
 
229
  if self.genes_to_perturb == "all":
230
  self.tokens_to_perturb = "all"
231
  else:
232
+ missing_genes = [
233
+ gene
234
+ for gene in self.genes_to_perturb
235
+ if gene not in self.gene_token_dict.keys()
236
+ ]
237
  if len(missing_genes) == len(self.genes_to_perturb):
238
  logger.error(
239
  "None of the provided genes to perturb are in token dictionary."
240
  )
241
  raise
242
+ elif len(missing_genes) > 0:
243
  logger.warning(
244
+ f"Genes to perturb {missing_genes} are not in token dictionary."
245
+ )
246
+ self.tokens_to_perturb = [
247
+ self.gene_token_dict.get(gene) for gene in self.genes_to_perturb
248
+ ]
249
 
250
  def validate_options(self):
251
  # first disallow options under development
252
  if self.perturb_type in ["inhibit", "activate"]:
253
  logger.error(
254
+ "In silico inhibition and activation currently under development. "
255
  "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
256
  )
257
  raise
258
+ if (self.combos > 0) and (self.anchor_token is None):
259
+ logger.error(
260
+ "Combination perturbation without anchor gene is currently under development. "
261
+ "Currently, must provide anchor gene for combination perturbation."
262
+ )
263
+ raise
264
+
265
  # confirm arguments are within valid options and compatible with each other
266
+ for attr_name, valid_options in self.valid_option_dict.items():
267
  attr_value = self.__dict__[attr_name]
268
  if type(attr_value) not in {list, dict}:
269
  if attr_value in valid_options:
273
  continue
274
  valid_type = False
275
  for option in valid_options:
276
+ if (option in [bool, int, list, dict]) and isinstance(
277
+ attr_value, option
278
+ ):
279
  valid_type = True
280
  break
281
  if valid_type:
282
  continue
283
  logger.error(
284
+ f"Invalid option for {attr_name}. "
285
  f"Valid options for {attr_name}: {valid_options}"
286
  )
287
  raise
288
+
289
+ if self.perturb_type in ["delete", "overexpress"]:
290
  if self.perturb_rank_shift is not None:
291
  if self.perturb_type == "delete":
292
  logger.warning(
293
+ "perturb_rank_shift set to None. "
294
+ "If perturb type is delete then gene is deleted entirely "
295
+ "rather than shifted by quartile"
296
+ )
297
  elif self.perturb_type == "overexpress":
298
  logger.warning(
299
+ "perturb_rank_shift set to None. "
300
+ "If perturb type is overexpress then gene is moved to front "
301
+ "of rank value encoding rather than shifted by quartile"
302
+ )
303
  self.perturb_rank_shift = None
304
+
305
  if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
306
  self.emb_mode = "cell"
307
  logger.warning(
308
+ "emb_mode set to 'cell'. "
309
+ "Currently, analysis with anchor gene "
310
+ "only outputs effect on cell embeddings."
311
+ )
312
+
313
  if self.cell_states_to_model is not None:
314
+ pu.validate_cell_states_to_model(self.cell_states_to_model)
315
+
316
+ if self.anchor_gene is not None:
317
+ self.anchor_gene = None
318
  logger.warning(
319
+ "anchor_gene set to None. "
320
+ "Currently, anchor gene not available "
321
+ "when modeling multiple cell states."
 
 
 
 
 
 
322
  )
323
+
324
+ if self.state_embs_dict is None:
325
+ logger.error(
326
+ "state_embs_dict must be provided for mode with cell_states_to_model. "
327
+ "Format is dictionary with keys specifying each possible cell state to model. "
328
+ "Values are target embedding positions as torch.tensor."
329
+ )
330
+ raise
331
+
332
+ for state_emb in self.state_embs_dict.values():
333
+ if not torch.is_tensor(state_emb):
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  logger.error(
335
+ "state_embs_dict must be dictionary with values being torch.tensor."
336
+ )
337
  raise
338
 
339
+ keys_absent = []
340
+ for k, v in self.cell_states_to_model.items():
341
+ if (k == "start_state") or (k == "goal_state"):
342
+ if v not in self.state_embs_dict.keys():
343
+ keys_absent.append(v)
344
+ if k == "alt_states":
345
+ for state in v:
346
+ if state not in self.state_embs_dict.keys():
347
+ keys_absent.append(state)
348
+ if len(keys_absent) > 0:
 
 
349
  logger.error(
350
+ "Each start_state, goal_state, and alt_states in cell_states_to_model "
351
+ "must be a key in state_embs_dict with the value being "
352
+ "the state's embedding position as torch.tensor. "
353
+ f"Missing keys: {keys_absent}"
 
 
 
354
  )
355
  raise
356
 
357
+ if self.perturb_type in ["inhibit", "activate"]:
 
 
 
 
 
 
 
358
  if self.perturb_rank_shift is None:
359
  logger.error(
360
+ "If perturb_type is inhibit or activate then "
361
+ "quartile to shift by must be specified."
362
+ )
363
  raise
364
+
365
  if self.filter_data is not None:
366
+ for key, value in self.filter_data.items():
367
+ if not isinstance(value, list):
368
  self.filter_data[key] = [value]
369
  logger.warning(
370
+ "Values in filter_data dict must be lists. "
371
+ f"Changing {key} value to list ([{value}])."
372
+ )
373
+
374
  if self.cell_inds_to_perturb != "all":
375
  if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
376
  logger.error(
377
  "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
378
  )
379
  raise
380
+ if (
381
+ self.cell_inds_to_perturb["start"] < 0
382
+ or self.cell_inds_to_perturb["end"] < 0
383
+ ):
384
+ logger.error("cell_inds_to_perturb must be positive.")
385
  raise
386
 
387
+ def perturb_data(
388
+ self, model_directory, input_data_file, output_directory, output_prefix
389
+ ):
 
 
390
  """
391
  Perturb genes in input data and save as results in output_directory.
392
 
402
  Prefix for output files
403
  """
404
 
405
+ ### format output path ###
406
+ output_path_prefix = os.path.join(
407
+ output_directory, f"in_silico_{self.perturb_type}_{output_prefix}"
408
+ )
409
+
410
+ ### load model and define parameters ###
411
+ model = pu.load_model(self.model_type, self.num_classes, model_directory)
412
+ self.max_len = pu.get_model_input_size(model)
413
+ layer_to_quant = pu.quant_layers(model) + self.emb_layer
414
+
415
+ ### filter input data ###
416
+ # general filtering of input data based on filter_data argument
417
+ filtered_input_data = pu.load_and_filter(
418
+ self.filter_data, self.nproc, input_data_file
419
+ )
420
+ filtered_input_data = self.apply_additional_filters(filtered_input_data)
421
+
422
+ if self.perturb_group is True:
423
+ self.isp_perturb_set(
424
+ model, filtered_input_data, layer_to_quant, output_path_prefix
425
+ )
426
  else:
427
+ self.isp_perturb_all(
428
+ model, filtered_input_data, layer_to_quant, output_path_prefix
429
+ )
430
+
431
+ def apply_additional_filters(self, filtered_input_data):
432
+ # additional filtering of input data dependent on isp mode
433
+ if self.cell_states_to_model is not None:
434
+ # filter for cells with start_state and log result
435
+ filtered_input_data = pu.filter_data_by_start_state(
436
+ filtered_input_data, self.cell_states_to_model, self.nproc
437
+ )
438
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
440
+ # filter for cells with tokens_to_perturb and log result
441
+ filtered_input_data = pu.filter_data_by_tokens_and_log(
442
+ filtered_input_data,
443
+ self.tokens_to_perturb,
444
+ self.nproc,
445
+ "genes_to_perturb",
446
+ )
447
+
448
+ if self.anchor_token is not None:
449
+ # filter for cells with anchor gene and log result
450
+ filtered_input_data = pu.filter_data_by_tokens_and_log(
451
+ filtered_input_data, self.anchor_token, self.nproc, "anchor_gene"
452
+ )
453
+
454
+ # downsample and sort largest to smallest to encounter memory constraints earlier
455
+ filtered_input_data = pu.downsample_and_sort(
456
+ filtered_input_data, self.max_ncells
457
+ )
458
+
459
+ # slice dataset if cells_inds_to_perturb is not "all"
460
  if self.cell_inds_to_perturb != "all":
461
+ filtered_input_data = pu.slice_by_inds_to_perturb(
462
+ filtered_input_data, self.cell_inds_to_perturb
463
+ )
464
+
465
+ return filtered_input_data
466
+
467
+ def isp_perturb_set(
468
+ self,
469
+ model,
470
+ filtered_input_data: Dataset,
471
+ layer_to_quant: int,
472
+ output_path_prefix: str,
473
+ ):
474
+ def make_group_perturbation_batch(example):
475
+ example_input_ids = example["input_ids"]
476
+ example["tokens_to_perturb"] = self.tokens_to_perturb
477
+ indices_to_perturb = [
478
+ example_input_ids.index(token) if token in example_input_ids else None
479
+ for token in self.tokens_to_perturb
480
+ ]
481
+ indices_to_perturb = [
482
+ item for item in indices_to_perturb if item is not None
483
+ ]
484
+ if len(indices_to_perturb) > 0:
485
+ example["perturb_index"] = indices_to_perturb
486
+ else:
487
+ # -100 indicates tokens to overexpress are not present in rank value encoding
488
+ example["perturb_index"] = [-100]
489
+ if self.perturb_type == "delete":
490
+ example = pu.delete_indices(example)
491
+ elif self.perturb_type == "overexpress":
492
+ example = pu.overexpress_tokens(example, self.max_len)
493
+ example["n_overflow"] = pu.calc_n_overflow(
494
+ self.max_len,
495
+ example["length"],
496
+ self.tokens_to_perturb,
497
+ indices_to_perturb,
498
+ )
499
+ return example
500
+
501
+ total_batch_length = len(filtered_input_data)
502
+ if self.cell_states_to_model is None:
503
+ cos_sims_dict = defaultdict(list)
504
+ else:
505
+ cos_sims_dict = {
506
+ state: defaultdict(list)
507
+ for state in pu.get_possible_states(self.cell_states_to_model)
508
+ }
509
+
510
+ perturbed_data = filtered_input_data.map(
511
+ make_group_perturbation_batch, num_proc=self.nproc
512
+ )
513
+ if self.perturb_type == "overexpress":
514
+ filtered_input_data = filtered_input_data.add_column(
515
+ "n_overflow", perturbed_data["n_overflow"]
516
+ )
517
+ # remove overflow genes from original data so that embeddings are comparable
518
+ # i.e. if original cell has genes 0:2047 and you want to overexpress new gene 2048,
519
+ # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
520
+ # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
521
+ # rather than only adding 2048)
522
+ filtered_input_data = filtered_input_data.map(
523
+ pu.truncate_by_n_overflow, num_proc=self.nproc
524
+ )
525
+
526
+ if self.emb_mode == "cell_and_gene":
527
+ stored_gene_embs_dict = defaultdict(list)
528
+
529
+ # iterate through batches
530
+ for i in trange(0, total_batch_length, self.forward_batch_size):
531
+ max_range = min(i + self.forward_batch_size, total_batch_length)
532
+ inds_select = [i for i in range(i, max_range)]
533
+
534
+ minibatch = filtered_input_data.select(inds_select)
535
+ perturbation_batch = perturbed_data.select(inds_select)
536
+
537
+ if self.cell_emb_style == "mean_pool":
538
+ full_original_emb = get_embs(
539
+ model,
540
+ minibatch,
541
+ "gene",
542
+ layer_to_quant,
543
+ self.pad_token_id,
544
+ self.forward_batch_size,
545
+ summary_stat=None,
546
+ silent=True,
547
+ )
548
+ indices_to_perturb = perturbation_batch["perturb_index"]
549
+ # remove indices that were perturbed
550
+ original_emb = pu.remove_perturbed_indices_set(
551
+ full_original_emb,
552
+ self.perturb_type,
553
+ indices_to_perturb,
554
+ self.tokens_to_perturb,
555
+ minibatch["length"],
556
+ )
557
+ full_perturbation_emb = get_embs(
558
+ model,
559
+ perturbation_batch,
560
+ "gene",
561
+ layer_to_quant,
562
+ self.pad_token_id,
563
+ self.forward_batch_size,
564
+ summary_stat=None,
565
+ silent=True,
566
+ )
567
+
568
+ # remove overexpressed genes
569
+ if self.perturb_type == "overexpress":
570
+ perturbation_emb = full_perturbation_emb[
571
+ :, len(self.tokens_to_perturb) :, :
572
+ ]
573
+
574
+ elif self.perturb_type == "delete":
575
+ perturbation_emb = full_perturbation_emb[
576
+ :, : max(perturbation_batch["length"]), :
577
+ ]
578
+
579
+ n_perturbation_genes = perturbation_emb.size()[1]
580
+
581
+ # if no goal states, the cosine similarties are the mean of gene cosine similarities
582
+ if (
583
+ self.cell_states_to_model is None
584
+ or self.emb_mode == "cell_and_gene"
585
+ ):
586
+ gene_cos_sims = pu.quant_cos_sims(
587
+ perturbation_emb,
588
+ original_emb,
589
+ self.cell_states_to_model,
590
+ self.state_embs_dict,
591
+ emb_mode="gene",
592
+ )
593
+
594
+ # if there are goal states, the cosine similarities are the cell cosine similarities
595
+ if self.cell_states_to_model is not None:
596
+ original_cell_emb = pu.mean_nonpadding_embs(
597
+ full_original_emb,
598
+ torch.tensor(minibatch["length"], device="cuda"),
599
+ dim=1,
600
+ )
601
+ perturbation_cell_emb = pu.mean_nonpadding_embs(
602
+ full_perturbation_emb,
603
+ torch.tensor(perturbation_batch["length"], device="cuda"),
604
+ dim=1,
605
+ )
606
+ cell_cos_sims = pu.quant_cos_sims(
607
+ perturbation_cell_emb,
608
+ original_cell_emb,
609
+ self.cell_states_to_model,
610
+ self.state_embs_dict,
611
+ emb_mode="cell",
612
+ )
613
+
614
+ # get cosine similarities in gene embeddings
615
+ # if getting gene embeddings, need gene names
616
+ if self.emb_mode == "cell_and_gene":
617
+ gene_list = minibatch["input_ids"]
618
+ # need to truncate gene_list
619
+ gene_list = [
620
+ [g for g in genes if g not in self.tokens_to_perturb][
621
+ :n_perturbation_genes
622
+ ]
623
+ for genes in gene_list
624
+ ]
625
+
626
+ for cell_i, genes in enumerate(gene_list):
627
+ for gene_j, affected_gene in enumerate(genes):
628
+ if len(self.genes_to_perturb) > 1:
629
+ tokens_to_perturb = tuple(self.tokens_to_perturb)
630
+ else:
631
+ tokens_to_perturb = self.tokens_to_perturb
632
+
633
+ # fill in the gene cosine similarities
634
+ try:
635
+ stored_gene_embs_dict[
636
+ (tokens_to_perturb, affected_gene)
637
+ ].append(gene_cos_sims[cell_i, gene_j].item())
638
+ except KeyError:
639
+ stored_gene_embs_dict[
640
+ (tokens_to_perturb, affected_gene)
641
+ ] = gene_cos_sims[cell_i, gene_j].item()
642
  else:
643
+ gene_list = None
644
+
645
+ if self.cell_states_to_model is None:
646
+ # calculate the mean of the gene cosine similarities for cell shift
647
+ # tensor of nonpadding lengths for each cell
648
+ if self.perturb_type == "overexpress":
649
+ # subtract number of genes that were overexpressed
650
+ # since they are removed before getting cos sims
651
+ n_overexpressed = len(self.tokens_to_perturb)
652
+ nonpadding_lens = [
653
+ x - n_overexpressed for x in perturbation_batch["length"]
654
+ ]
655
+ else:
656
+ nonpadding_lens = perturbation_batch["length"]
657
+ cos_sims_data = pu.mean_nonpadding_embs(
658
+ gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
659
+ )
660
+ cos_sims_dict = self.update_perturbation_dictionary(
661
+ cos_sims_dict,
662
+ cos_sims_data,
663
+ filtered_input_data,
664
+ indices_to_perturb,
665
+ gene_list,
666
+ )
667
+ else:
668
+ cos_sims_data = cell_cos_sims
669
+ for state in cos_sims_dict.keys():
670
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
671
+ cos_sims_dict[state],
672
+ cos_sims_data[state],
673
+ filtered_input_data,
674
+ indices_to_perturb,
675
+ gene_list,
676
+ )
677
+ del minibatch
678
+ del perturbation_batch
679
+ del original_emb
680
+ del perturbation_emb
681
+ del cos_sims_data
682
+
683
+ torch.cuda.empty_cache()
684
+
685
+ pu.write_perturbation_dictionary(
686
+ cos_sims_dict,
687
+ f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
688
+ )
689
+
690
+ if self.emb_mode == "cell_and_gene":
691
+ pu.write_perturbation_dictionary(
692
+ stored_gene_embs_dict,
693
+ f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
694
+ )
695
+
696
+ def isp_perturb_all(
697
+ self,
698
+ model,
699
+ filtered_input_data: Dataset,
700
+ layer_to_quant: int,
701
+ output_path_prefix: str,
702
+ ):
703
+ pickle_batch = -1
704
+ if self.cell_states_to_model is None:
705
+ cos_sims_dict = defaultdict(list)
706
+ else:
707
+ cos_sims_dict = {
708
+ state: defaultdict(list)
709
+ for state in pu.get_possible_states(self.cell_states_to_model)
710
+ }
711
+
712
+ if self.emb_mode == "cell_and_gene":
713
+ stored_gene_embs_dict = defaultdict(list)
714
+ for i in trange(len(filtered_input_data)):
715
+ example_cell = filtered_input_data.select([i])
716
+ full_original_emb = get_embs(
717
+ model,
718
+ example_cell,
719
+ "gene",
720
+ layer_to_quant,
721
+ self.pad_token_id,
722
+ self.forward_batch_size,
723
+ summary_stat=None,
724
+ silent=True,
725
+ )
726
+
727
+ # gene_list is used to assign cos sims back to genes
728
+ # need to remove the anchor gene
729
+ gene_list = example_cell["input_ids"][0][:]
730
+ if self.anchor_token is not None:
731
+ for token in self.anchor_token:
732
+ gene_list.remove(token)
733
+
734
+ perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
735
+ example_cell,
736
+ self.perturb_type,
737
+ self.tokens_to_perturb,
738
+ self.anchor_token,
739
+ self.combos,
740
+ self.nproc,
741
+ )
742
+
743
+ full_perturbation_emb = get_embs(
744
+ model,
745
+ perturbation_batch,
746
+ "gene",
747
+ layer_to_quant,
748
+ self.pad_token_id,
749
+ self.forward_batch_size,
750
+ summary_stat=None,
751
+ silent=True,
752
+ )
753
+
754
+ num_inds_perturbed = 1 + self.combos
755
+ # need to remove overexpressed gene to quantify cosine shifts
756
+ if self.perturb_type == "overexpress":
757
+ perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
758
+ gene_list = gene_list[
759
+ num_inds_perturbed:
760
+ ] # index 0 is not overexpressed
761
+
762
+ elif self.perturb_type == "delete":
763
+ perturbation_emb = full_perturbation_emb
764
+
765
+ original_batch = pu.make_comparison_batch(
766
+ full_original_emb, indices_to_perturb, perturb_group=False
767
+ )
768
+
769
+ if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
770
+ gene_cos_sims = pu.quant_cos_sims(
771
+ perturbation_emb,
772
+ original_batch,
773
+ self.cell_states_to_model,
774
+ self.state_embs_dict,
775
+ emb_mode="gene",
776
+ )
777
+ if self.cell_states_to_model is not None:
778
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
779
+ full_original_emb, "mean_pool"
780
+ )
781
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
782
+ full_perturbation_emb, "mean_pool"
783
+ )
784
+
785
+ cell_cos_sims = pu.quant_cos_sims(
786
+ perturbation_cell_emb,
787
+ original_cell_emb,
788
+ self.cell_states_to_model,
789
+ self.state_embs_dict,
790
+ emb_mode="cell",
791
+ )
792
+
793
+ if self.emb_mode == "cell_and_gene":
794
+ # remove perturbed index for gene list
795
+ perturbed_gene_dict = {
796
+ gene: gene_list[:i] + gene_list[i + 1 :]
797
+ for i, gene in enumerate(gene_list)
798
+ }
799
+
800
+ for perturbation_i, perturbed_gene in enumerate(gene_list):
801
+ for gene_j, affected_gene in enumerate(
802
+ perturbed_gene_dict[perturbed_gene]
803
+ ):
804
+ try:
805
+ stored_gene_embs_dict[
806
+ (perturbed_gene, affected_gene)
807
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
808
+ except KeyError:
809
+ stored_gene_embs_dict[
810
+ (perturbed_gene, affected_gene)
811
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
812
+
813
  if self.cell_states_to_model is None:
814
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
815
+ cos_sims_dict = self.update_perturbation_dictionary(
816
+ cos_sims_dict,
817
+ cos_sims_data,
818
+ filtered_input_data,
819
+ indices_to_perturb,
820
+ gene_list,
821
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
822
  else:
823
+ cos_sims_data = cell_cos_sims
824
+ for state in cos_sims_dict.keys():
825
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
826
+ cos_sims_dict[state],
827
+ cos_sims_data[state],
828
+ filtered_input_data,
829
+ indices_to_perturb,
830
+ gene_list,
831
+ )
832
+
833
+ # save dict to disk every 100 cells
834
+ if i % 100 == 0:
835
+ pu.write_perturbation_dictionary(
836
+ cos_sims_dict,
837
+ f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
838
+ )
839
+ if self.emb_mode == "cell_and_gene":
840
+ pu.write_perturbation_dictionary(
841
+ stored_gene_embs_dict,
842
+ f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
843
+ )
844
+
845
+ # reset and clear memory every 1000 cells
846
+ if i % 1000 == 0:
847
+ pickle_batch += 1
848
+ if self.cell_states_to_model is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
849
  cos_sims_dict = defaultdict(list)
850
+ else:
851
+ cos_sims_dict = {
852
+ state: defaultdict(list)
853
+ for state in pu.get_possible_states(self.cell_states_to_model)
854
+ }
855
+
856
+ if self.emb_mode == "cell_and_gene":
857
+ stored_gene_embs_dict = defaultdict(list)
858
+
859
+ torch.cuda.empty_cache()
860
+
861
+ pu.write_perturbation_dictionary(
862
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
863
+ )
864
+
865
+ if self.emb_mode == "cell_and_gene":
866
+ pu.write_perturbation_dictionary(
867
+ stored_gene_embs_dict,
868
+ f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
869
+ )
870
+
871
+ def update_perturbation_dictionary(
872
+ self,
873
+ cos_sims_dict: defaultdict,
874
+ cos_sims_data: torch.Tensor,
875
+ filtered_input_data: Dataset,
876
+ indices_to_perturb: list[list[int]],
877
+ gene_list=None,
878
+ ):
879
+ if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
880
+ logger.error(
881
+ f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
882
+ cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
883
+ len(gene_list) = {len(gene_list)}."
884
+ )
885
+ raise
886
+
887
+ if self.perturb_group is True:
888
+ if len(self.tokens_to_perturb) > 1:
889
+ perturbed_genes = tuple(self.tokens_to_perturb)
890
+ else:
891
+ perturbed_genes = self.tokens_to_perturb[0]
892
+
893
+ # if cell embeddings, can just append
894
+ # shape will be (batch size, 1)
895
+ cos_sims_data = torch.squeeze(cos_sims_data).tolist()
896
+
897
+ # handle case of single cell left
898
+ if not isinstance(cos_sims_data, list):
899
+ cos_sims_data = [cos_sims_data]
900
+
901
+ cos_sims_dict[(perturbed_genes, "cell_emb")] += cos_sims_data
902
+
903
+ else:
904
+ for i, cos in enumerate(cos_sims_data.tolist()):
905
+ cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
906
 
907
+ return cos_sims_dict
 
 
geneformer/in_silico_perturber_stats.py CHANGED
@@ -6,9 +6,9 @@ Usage:
6
  ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
  combos=0,
8
  anchor_gene=None,
9
- cell_states_to_model={"state_key": "disease",
10
- "start_state": "dcm",
11
- "goal_state": "nf",
12
  "alt_states": ["hcm", "other1", "other2"]})
13
  ispstats.get_stats("path/to/input_data",
14
  None,
@@ -17,88 +17,157 @@ Usage:
17
  """
18
 
19
 
20
- import os
21
  import logging
22
- import numpy as np
23
- import pandas as pd
24
  import pickle
25
  import random
26
- import statsmodels.stats.multitest as smt
27
  from pathlib import Path
 
 
 
 
28
  from scipy.stats import ranksums
29
  from sklearn.mixture import GaussianMixture
30
- from tqdm.auto import trange, tqdm
31
-
32
- from .in_silico_perturber import flatten_list
33
 
 
34
  from .tokenizer import TOKEN_DICTIONARY_FILE
35
 
36
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
37
 
38
  logger = logging.getLogger(__name__)
39
 
 
40
  # invert dictionary keys/values
41
  def invert_dict(dictionary):
42
  return {v: k for k, v in dictionary.items()}
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # read raw dictionary files
45
- def read_dictionaries(input_data_directory, cell_or_gene_emb, anchor_token):
46
- file_found = 0
 
 
 
 
 
 
47
  file_path_list = []
48
- dict_list = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  for file in os.listdir(input_data_directory):
50
- # process only _raw.pickle files
51
- if file.endswith("_raw.pickle"):
52
- file_found = 1
53
  file_path_list += [f"{input_data_directory}/{file}"]
54
  for file_path in tqdm(file_path_list):
55
  with open(file_path, "rb") as fp:
56
  cos_sims_dict = pickle.load(fp)
57
- if cell_or_gene_emb == "cell":
58
- cell_emb_dict = {k: v for k,
59
- v in cos_sims_dict.items() if v and "cell_emb" in k}
60
- dict_list += [cell_emb_dict]
61
- elif cell_or_gene_emb == "gene":
62
- gene_emb_dict = {k: v for k,
63
- v in cos_sims_dict.items() if v and anchor_token == k[0]}
64
- dict_list += [gene_emb_dict]
65
- if file_found == 0:
 
 
 
 
66
  logger.error(
67
- "No raw data for processing found within provided directory. " \
68
- "Please ensure data files end with '_raw.pickle'.")
 
69
  raise
70
- return dict_list
 
 
 
 
71
 
72
  # get complete gene list
73
- def get_gene_list(dict_list,mode):
74
  if mode == "cell":
75
  position = 0
76
  elif mode == "gene":
77
  position = 1
78
  gene_set = set()
79
- for dict_i in dict_list:
80
- gene_set.update([k[position] for k, v in dict_i.items() if v])
 
 
 
 
 
 
 
 
 
 
81
  gene_list = list(gene_set)
82
  if mode == "gene":
83
  gene_list.remove("cell_emb")
84
  gene_list.sort()
85
  return gene_list
86
 
 
87
  def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
88
- return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
 
 
 
 
89
 
90
  def n_detections(token, dict_list, mode, anchor_token):
91
  cos_sim_megalist = []
92
  for dict_i in dict_list:
93
  if mode == "cell":
94
- cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
95
  elif mode == "gene":
96
- cos_sim_megalist += dict_i.get((anchor_token, token),[])
97
  return len(cos_sim_megalist)
98
 
 
99
  def get_fdr(pvalues):
100
  return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
101
 
 
102
  def get_impact_component(test_value, gaussian_mixture_model):
103
  impact_border = gaussian_mixture_model.means_[0][0]
104
  nonimpact_border = gaussian_mixture_model.means_[1][0]
@@ -114,236 +183,356 @@ def get_impact_component(test_value, gaussian_mixture_model):
114
  impact_component = 1
115
  return impact_component
116
 
 
117
  # aggregate data for single perturbation in multiple cells
118
- def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
119
- names=["Cosine_shift"]
120
  cos_sims_full_df = pd.DataFrame(columns=names)
121
 
122
  cos_shift_data = []
123
  token = cos_sims_df["Gene"][0]
124
  for dict_i in dict_list:
125
- cos_shift_data += dict_i.get((token, "cell_emb"),[])
126
  cos_sims_full_df["Cosine_shift"] = cos_shift_data
127
- return cos_sims_full_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
130
- def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
131
- cell_state_key = cell_states_to_model["start_state"]
132
- if ("alt_states" not in cell_states_to_model.keys()) \
133
- or (len(cell_states_to_model["alt_states"]) == 0) \
134
- or (cell_states_to_model["alt_states"] == [None]):
 
 
 
135
  alt_end_state_exists = False
136
- elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]):
 
 
137
  alt_end_state_exists = True
138
-
139
  # for single perturbation in multiple cells, there are no random perturbations to compare to
140
  if genes_perturbed != "all":
141
- names=["Shift_to_goal_end",
142
- "Shift_to_alt_end"]
143
- if alt_end_state_exists == False:
144
- names.remove("Shift_to_alt_end")
145
- cos_sims_full_df = pd.DataFrame(columns=names)
146
-
147
- cos_shift_data = []
148
  token = cos_sims_df["Gene"][0]
149
- for dict_i in dict_list:
150
- cos_shift_data += dict_i.get((token, "cell_emb"),[])
151
- if alt_end_state_exists == False:
152
- cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end in cos_shift_data]
153
- if alt_end_state_exists == True:
154
- cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
155
- cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
156
-
 
 
 
 
 
 
157
  # sort by shift to desired state
158
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end"],
159
- ascending=[False])
160
- return cos_sims_full_df
161
-
 
162
  elif genes_perturbed == "all":
163
- random_tuples = []
 
 
 
 
164
  for i in trange(cos_sims_df.shape[0]):
165
  token = cos_sims_df["Gene"][i]
166
- for dict_i in dict_list:
167
- random_tuples += dict_i.get((token, "cell_emb"),[])
168
-
169
- if alt_end_state_exists == False:
170
- goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
171
- elif alt_end_state_exists == True:
172
- goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
173
- alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
174
 
175
  # downsample to improve speed of ranksums
176
  if len(goal_end_random_megalist) > 100_000:
177
  random.seed(42)
178
- goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
179
- if alt_end_state_exists == True:
180
- if len(alt_end_random_megalist) > 100_000:
181
- random.seed(42)
182
- alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
183
-
184
- names=["Gene",
185
- "Gene_name",
186
- "Ensembl_ID",
187
- "Shift_to_goal_end",
188
- "Shift_to_alt_end",
189
- "Goal_end_vs_random_pval",
190
- "Alt_end_vs_random_pval"]
191
- if alt_end_state_exists == False:
192
- names.remove("Shift_to_alt_end")
193
- names.remove("Alt_end_vs_random_pval")
 
 
 
 
 
 
 
 
 
 
 
 
194
  cos_sims_full_df = pd.DataFrame(columns=names)
195
 
 
196
  for i in trange(cos_sims_df.shape[0]):
197
  token = cos_sims_df["Gene"][i]
198
  name = cos_sims_df["Gene_name"][i]
199
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
200
- cos_shift_data = []
 
 
 
 
 
 
 
201
 
202
- for dict_i in dict_list:
203
- cos_shift_data += dict_i.get((token, "cell_emb"),[])
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- if alt_end_state_exists == False:
206
- goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
207
- elif alt_end_state_exists == True:
208
- goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
209
- alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
210
- mean_alt_end = np.mean(alt_end_cos_sim_megalist)
211
- pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
 
 
 
 
 
 
 
212
 
213
- mean_goal_end = np.mean(goal_end_cos_sim_megalist)
214
- pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
215
-
216
- if alt_end_state_exists == False:
217
- data_i = [token,
218
- name,
219
- ensembl_id,
220
- mean_goal_end,
221
- pval_goal_end]
222
- elif alt_end_state_exists == True:
223
- data_i = [token,
224
- name,
225
- ensembl_id,
226
- mean_goal_end,
227
- mean_alt_end,
228
- pval_goal_end,
229
- pval_alt_end]
230
-
231
- cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
232
- cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
233
-
234
- cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
235
- if alt_end_state_exists == True:
236
- cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
237
 
238
  # quantify number of detections of each gene
239
- cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
240
-
241
- # sort by shift to desired state\
242
- cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]]
243
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
244
- "Shift_to_goal_end",
245
- "Goal_end_FDR"],
246
- ascending=[False,False,True])
247
-
 
 
 
 
248
  return cos_sims_full_df
249
 
 
250
  # stats comparing cos sim shifts of test perturbations vs null distribution
251
  def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
252
  cos_sims_full_df = cos_sims_df.copy()
253
 
254
  cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
255
  cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
256
- cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
 
 
257
  cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
258
  cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
259
- cos_sims_full_df["N_Detections_test"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
260
- cos_sims_full_df["N_Detections_null"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
261
-
 
 
 
 
262
  for i in trange(cos_sims_df.shape[0]):
263
  token = cos_sims_df["Gene"][i]
264
  test_shifts = []
265
  null_shifts = []
266
-
267
  for dict_i in dict_list:
268
- test_shifts += dict_i.get((token, "cell_emb"),[])
269
 
270
  for dict_i in null_dict_list:
271
- null_shifts += dict_i.get((token, "cell_emb"),[])
272
-
273
  cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
274
  cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
275
- cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(test_shifts)-np.mean(null_shifts)
276
- cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(test_shifts,
277
- null_shifts, nan_policy="omit").pvalue
278
-
 
 
 
 
 
 
 
 
279
  cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
280
  cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
281
 
282
- cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
283
-
284
- cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]]
285
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
286
- "Test_vs_null_avg_shift",
287
- "Test_vs_null_FDR"],
288
- ascending=[False,False,True])
 
 
 
 
289
  return cos_sims_full_df
290
 
 
291
  # stats for identifying perturbations with largest effect within a given set of cells
292
  # fits a mixture model to 2 components (impact vs. non-impact) and
293
  # reports the most likely component for each test perturbation
294
  # Note: because assumes given perturbation has a consistent effect in the cells tested,
295
  # we recommend only using the mixture model strategy with uniform cell populations
296
  def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
297
-
298
- names=["Gene",
299
- "Gene_name",
300
- "Ensembl_ID"]
301
-
302
  if combos == 0:
303
  names += ["Test_avg_shift"]
304
  elif combos == 1:
305
- names += ["Anchor_shift",
306
- "Test_token_shift",
307
- "Sum_of_indiv_shifts",
308
- "Combo_shift",
309
- "Combo_minus_sum_shift"]
310
-
311
- names += ["Impact_component",
312
- "Impact_component_percent"]
 
313
 
314
  cos_sims_full_df = pd.DataFrame(columns=names)
315
  avg_values = []
316
  gene_names = []
317
-
318
  for i in trange(cos_sims_df.shape[0]):
319
  token = cos_sims_df["Gene"][i]
320
  name = cos_sims_df["Gene_name"][i]
321
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
322
  cos_shift_data = []
323
-
324
  for dict_i in dict_list:
325
  if (combos == 0) and (anchor_token is not None):
326
- cos_shift_data += dict_i.get((anchor_token, token),[])
327
  else:
328
- cos_shift_data += dict_i.get((token, "cell_emb"),[])
329
-
330
  # Extract values for current gene
331
  if combos == 0:
332
  test_values = cos_shift_data
333
  elif combos == 1:
334
  test_values = []
335
  for tup in cos_shift_data:
336
- test_values.append(tup[2])
337
-
338
  if len(test_values) > 0:
339
  avg_value = np.mean(test_values)
340
  avg_values.append(avg_value)
341
  gene_names.append(name)
342
-
343
  # fit Gaussian mixture model to dataset of mean for each gene
344
  avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
345
  gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
346
-
347
  for i in trange(cos_sims_df.shape[0]):
348
  token = cos_sims_df["Gene"][i]
349
  name = cos_sims_df["Gene_name"][i]
@@ -352,71 +541,95 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
352
 
353
  for dict_i in dict_list:
354
  if (combos == 0) and (anchor_token is not None):
355
- cos_shift_data += dict_i.get((anchor_token, token),[])
356
  else:
357
- cos_shift_data += dict_i.get((token, "cell_emb"),[])
358
-
359
  if combos == 0:
360
  mean_test = np.mean(cos_shift_data)
361
- impact_components = [get_impact_component(value,gm) for value in cos_shift_data]
 
 
362
  elif combos == 1:
363
- anchor_cos_sim_megalist = [anchor for anchor,token,combo in cos_shift_data]
364
- token_cos_sim_megalist = [token for anchor,token,combo in cos_shift_data]
365
- anchor_plus_token_cos_sim_megalist = [1-((1-anchor)+(1-token)) for anchor,token,combo in cos_shift_data]
366
- combo_anchor_token_cos_sim_megalist = [combo for anchor,token,combo in cos_shift_data]
367
- combo_minus_sum_cos_sim_megalist = [combo-(1-((1-anchor)+(1-token))) for anchor,token,combo in cos_shift_data]
 
 
 
 
 
 
 
 
 
 
368
 
369
  mean_anchor = np.mean(anchor_cos_sim_megalist)
370
  mean_token = np.mean(token_cos_sim_megalist)
371
  mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
372
  mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
373
  mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
374
-
375
- impact_components = [get_impact_component(value,gm) for value in combo_anchor_token_cos_sim_megalist]
376
-
377
- impact_component = get_impact_component(mean_test,gm)
378
- impact_component_percent = np.mean(impact_components)*100
379
-
380
- data_i = [token,
381
- name,
382
- ensembl_id]
 
383
  if combos == 0:
384
  data_i += [mean_test]
385
  elif combos == 1:
386
- data_i += [mean_anchor,
387
- mean_token,
388
- mean_sum,
389
- mean_test,
390
- mean_combo_minus_sum]
391
- data_i += [impact_component,
392
- impact_component_percent]
393
-
394
- cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
395
- cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
396
-
 
397
  # quantify number of detections of each gene
398
- cos_sims_full_df["N_Detections"] = [n_detections(i,
399
- dict_list,
400
- "gene",
401
- anchor_token) for i in cos_sims_full_df["Gene"]]
402
-
403
  if combos == 0:
404
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
405
- "Test_avg_shift"],
406
- ascending=[False,True])
407
  elif combos == 1:
408
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
409
- "Combo_minus_sum_shift"],
410
- ascending=[False,True])
411
  return cos_sims_full_df
412
 
 
413
  class InSilicoPerturberStats:
414
  valid_option_dict = {
415
- "mode": {"goal_state_shift","vs_null","mixture_model","aggregate_data"},
416
- "combos": {0,1},
 
 
 
 
 
 
 
417
  "anchor_gene": {None, str},
418
  "cell_states_to_model": {None, dict},
 
419
  }
 
420
  def __init__(
421
  self,
422
  mode="mixture_model",
@@ -424,6 +637,7 @@ class InSilicoPerturberStats:
424
  combos=0,
425
  anchor_gene=None,
426
  cell_states_to_model=None,
 
427
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
428
  gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
429
  ):
@@ -432,12 +646,13 @@ class InSilicoPerturberStats:
432
 
433
  Parameters
434
  ----------
435
- mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"}
436
  Type of stats.
437
  "goal_state_shift": perturbation vs. random for desired cell state shift
438
  "vs_null": perturbation vs. null from provided null distribution dataset
439
  "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
440
  "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
 
441
  genes_perturbed : "all", list
442
  Genes perturbed in isp experiment.
443
  Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
@@ -472,13 +687,14 @@ class InSilicoPerturberStats:
472
  self.combos = combos
473
  self.anchor_gene = anchor_gene
474
  self.cell_states_to_model = cell_states_to_model
475
-
 
476
  self.validate_options()
477
 
478
  # load token dictionary (Ensembl IDs:token)
479
  with open(token_dictionary_file, "rb") as f:
480
  self.gene_token_dict = pickle.load(f)
481
-
482
  # load gene name dictionary (gene name:Ensembl ID)
483
  with open(gene_name_id_dictionary_file, "rb") as f:
484
  self.gene_name_id_dict = pickle.load(f)
@@ -489,7 +705,7 @@ class InSilicoPerturberStats:
489
  self.anchor_token = self.gene_token_dict[self.anchor_gene]
490
 
491
  def validate_options(self):
492
- for attr_name,valid_options in self.valid_option_dict.items():
493
  attr_value = self.__dict__[attr_name]
494
  if type(attr_value) not in {list, dict}:
495
  if attr_name in {"anchor_gene"}:
@@ -498,35 +714,40 @@ class InSilicoPerturberStats:
498
  continue
499
  valid_type = False
500
  for option in valid_options:
501
- if (option in [int,list,dict]) and isinstance(attr_value, option):
 
 
502
  valid_type = True
503
  break
504
- if valid_type:
505
- continue
506
- logger.error(
507
- f"Invalid option for {attr_name}. " \
508
- f"Valid options for {attr_name}: {valid_options}"
509
- )
510
- raise
511
-
512
  if self.cell_states_to_model is not None:
513
  if len(self.cell_states_to_model.items()) == 1:
514
  logger.warning(
515
- "The single value dictionary for cell_states_to_model will be " \
516
- "replaced with a dictionary with named keys for start, goal, and alternate states. " \
517
- "Please specify state_key, start_state, goal_state, and alt_states " \
518
- "in the cell_states_to_model dictionary for future use. " \
519
- "For example, cell_states_to_model={" \
520
- "'state_key': 'disease', " \
521
- "'start_state': 'dcm', " \
522
- "'goal_state': 'nf', " \
523
- "'alt_states': ['hcm', 'other1', 'other2']}"
524
  )
525
- for key,value in self.cell_states_to_model.items():
526
  if (len(value) == 3) and isinstance(value, tuple):
527
- if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
 
 
 
528
  if len(value[0]) == 1 and len(value[1]) == 1:
529
- all_values = value[0]+value[1]+value[2]
530
  if len(all_values) == len(set(all_values)):
531
  continue
532
  # reformat to the new named key format
@@ -535,75 +756,93 @@ class InSilicoPerturberStats:
535
  "state_key": list(self.cell_states_to_model.keys())[0],
536
  "start_state": state_values[0][0],
537
  "goal_state": state_values[1][0],
538
- "alt_states": state_values[2:][0]
539
  }
540
- elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
541
- if (self.cell_states_to_model["state_key"] is None) \
542
- or (self.cell_states_to_model["start_state"] is None) \
543
- or (self.cell_states_to_model["goal_state"] is None):
 
 
 
 
 
 
 
544
  logger.error(
545
- "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
 
546
  raise
547
-
548
- if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
549
- logger.error(
550
- "All states must be unique.")
 
 
551
  raise
552
 
553
  if self.cell_states_to_model["alt_states"] is not None:
554
- if type(self.cell_states_to_model["alt_states"]) is not list:
555
  logger.error(
556
  "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
557
  )
558
  raise
559
- if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
560
- logger.error(
561
- "All states must be unique.")
 
562
  raise
563
 
564
  else:
565
  logger.error(
566
- "cell_states_to_model must only have the following four keys: " \
567
- "'state_key', 'start_state', 'goal_state', 'alt_states'." \
568
- "For example, cell_states_to_model={" \
569
- "'state_key': 'disease', " \
570
- "'start_state': 'dcm', " \
571
- "'goal_state': 'nf', " \
572
- "'alt_states': ['hcm', 'other1', 'other2']}"
573
  )
574
  raise
575
 
576
  if self.anchor_gene is not None:
577
  self.anchor_gene = None
578
  logger.warning(
579
- "anchor_gene set to None. " \
580
- "Currently, anchor gene not available " \
581
- "when modeling multiple cell states.")
582
-
 
583
  if self.combos > 0:
584
  if self.anchor_gene is None:
585
  logger.error(
586
- "Currently, stats are only supported for combination " \
587
- "in silico perturbation run with anchor gene. Please add " \
588
- "anchor gene when using with combos > 0. ")
 
589
  raise
590
-
591
  if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
592
  logger.error(
593
- "Mixture model mode requires multiple gene perturbations to fit model " \
594
- "so is incompatible with a single grouped perturbation.")
 
595
  raise
596
  if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
597
  logger.error(
598
- "Simple data aggregation mode is for single perturbation in multiple cells " \
599
- "so is incompatible with a genes_perturbed being 'all'.")
600
- raise
601
-
602
- def get_stats(self,
603
- input_data_directory,
604
- null_dist_data_directory,
605
- output_directory,
606
- output_prefix):
 
 
 
 
607
  """
608
  Get stats for in silico perturbation data and save as results in output_directory.
609
 
@@ -617,20 +856,22 @@ class InSilicoPerturberStats:
617
  Path to directory where perturbation data will be saved as .csv
618
  output_prefix : str
619
  Prefix for output .csv
620
-
 
 
621
  Outputs
622
  ----------
623
  Definition of possible columns in .csv output file.
624
-
625
  Of note, not all columns will be present in all output files.
626
  Some columns are specific to particular perturbation modes.
627
-
628
  "Gene": gene token
629
  "Gene_name": gene name
630
  "Ensembl_ID": gene Ensembl ID
631
  "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
632
  "Sig": 1 if FDR<0.05, otherwise 0
633
-
634
  "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
635
  "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
636
  "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
@@ -639,7 +880,7 @@ class InSilicoPerturberStats:
639
  pvalue compares shift caused by perturbing given gene compared to random genes
640
  "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
641
  "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
642
-
643
  "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
644
  "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
645
  "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
@@ -648,7 +889,7 @@ class InSilicoPerturberStats:
648
  "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
649
  "N_Detections_test": "N_Detections" in cells from test distribution
650
  "N_Detections_null": "N_Detections" in cells from null distribution
651
-
652
  "Anchor_shift": cosine shift in response to given perturbation of anchor gene
653
  "Test_token_shift": cosine shift in response to given perturbation of test gene
654
  "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
@@ -658,13 +899,27 @@ class InSilicoPerturberStats:
658
  "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
659
  1: within impact component; 0: not within impact component
660
  "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
 
 
 
 
 
 
661
  """
662
 
663
- if self.mode not in ["goal_state_shift", "vs_null", "mixture_model","aggregate_data"]:
 
 
 
 
 
 
664
  logger.error(
665
- "Currently, only modes available are stats for goal_state_shift, " \
666
- "vs_null (comparing to null distribution), and " \
667
- "mixture_model (fitting mixture model for perturbations with or without impact.")
 
 
668
  raise
669
 
670
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
@@ -673,44 +928,107 @@ class InSilicoPerturberStats:
673
  # obtain total gene list
674
  if (self.combos == 0) and (self.anchor_token is not None):
675
  # cos sim data for effect of gene perturbation on the embedding of each other gene
676
- dict_list = read_dictionaries(input_data_directory, "gene", self.anchor_token)
 
 
 
 
 
 
677
  gene_list = get_gene_list(dict_list, "gene")
 
 
 
 
 
 
 
 
 
 
 
 
 
678
  else:
679
  # cos sim data for effect of gene perturbation on the embedding of each cell
680
- dict_list = read_dictionaries(input_data_directory, "cell", self.anchor_token)
 
 
 
 
 
 
681
  gene_list = get_gene_list(dict_list, "cell")
682
-
683
  # initiate results dataframe
684
- cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
685
- "Gene_name": [self.token_to_gene_name(item) \
686
- for item in gene_list], \
687
- "Ensembl_ID": [token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict) \
688
- if self.genes_perturbed != "all" else \
689
- self.gene_token_id_dict[genes[1]] \
690
- if isinstance(genes,tuple) else \
691
- self.gene_token_id_dict[genes] \
692
- for genes in gene_list]}, \
693
- index=[i for i in range(len(gene_list))])
 
 
 
 
 
694
 
695
  if self.mode == "goal_state_shift":
696
- cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model, self.genes_perturbed)
697
-
 
 
 
 
 
698
  elif self.mode == "vs_null":
699
- null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
700
- cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
 
 
 
 
 
 
 
 
 
701
 
702
  elif self.mode == "mixture_model":
703
- cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
704
-
 
 
705
  elif self.mode == "aggregate_data":
706
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
707
 
 
 
 
 
 
 
 
 
708
  # save perturbation stats to output_path
709
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
710
  cos_sims_df.to_csv(output_path)
711
 
712
  def token_to_gene_name(self, item):
713
- if isinstance(item,int):
714
- return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan)
715
- if isinstance(item,tuple):
716
- return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item])
 
 
 
 
 
 
 
 
 
6
  ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
  combos=0,
8
  anchor_gene=None,
9
+ cell_states_to_model={"state_key": "disease",
10
+ "start_state": "dcm",
11
+ "goal_state": "nf",
12
  "alt_states": ["hcm", "other1", "other2"]})
13
  ispstats.get_stats("path/to/input_data",
14
  None,
17
  """
18
 
19
 
 
20
  import logging
21
+ import os
 
22
  import pickle
23
  import random
 
24
  from pathlib import Path
25
+
26
+ import numpy as np
27
+ import pandas as pd
28
+ import statsmodels.stats.multitest as smt
29
  from scipy.stats import ranksums
30
  from sklearn.mixture import GaussianMixture
31
+ from tqdm.auto import tqdm, trange
 
 
32
 
33
+ from .perturber_utils import flatten_list, validate_cell_states_to_model
34
  from .tokenizer import TOKEN_DICTIONARY_FILE
35
 
36
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
37
 
38
  logger = logging.getLogger(__name__)
39
 
40
+
41
  # invert dictionary keys/values
42
  def invert_dict(dictionary):
43
  return {v: k for k, v in dictionary.items()}
44
 
45
+
46
+ def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
47
+ if cell_or_gene_emb == "cell":
48
+ cell_emb_dict = {
49
+ k: v for k, v in cos_sims_dict.items() if v and "cell_emb" in k
50
+ }
51
+ return [cell_emb_dict]
52
+ elif cell_or_gene_emb == "gene":
53
+ if anchor_token is None:
54
+ gene_emb_dict = {k: v for k, v in cos_sims_dict.items() if v}
55
+ else:
56
+ gene_emb_dict = {
57
+ k: v for k, v in cos_sims_dict.items() if v and anchor_token == k[0]
58
+ }
59
+ return [gene_emb_dict]
60
+
61
+
62
  # read raw dictionary files
63
+ def read_dictionaries(
64
+ input_data_directory,
65
+ cell_or_gene_emb,
66
+ anchor_token,
67
+ cell_states_to_model,
68
+ pickle_suffix,
69
+ ):
70
+ file_found = False
71
  file_path_list = []
72
+ if cell_states_to_model is None:
73
+ dict_list = []
74
+ else:
75
+ validate_cell_states_to_model(cell_states_to_model)
76
+ cell_states_to_model_valid = {
77
+ state: value
78
+ for state, value in cell_states_to_model.items()
79
+ if state != "state_key"
80
+ and cell_states_to_model[state] is not None
81
+ and cell_states_to_model[state] != []
82
+ }
83
+ cell_states_list = []
84
+ # flatten all state values into list
85
+ for state in cell_states_to_model_valid:
86
+ value = cell_states_to_model_valid[state]
87
+ if isinstance(value, list):
88
+ cell_states_list += value
89
+ else:
90
+ cell_states_list.append(value)
91
+ state_dict = {state_value: dict() for state_value in cell_states_list}
92
  for file in os.listdir(input_data_directory):
93
+ # process only files with given suffix (e.g. "_raw.pickle")
94
+ if file.endswith(pickle_suffix):
95
+ file_found = True
96
  file_path_list += [f"{input_data_directory}/{file}"]
97
  for file_path in tqdm(file_path_list):
98
  with open(file_path, "rb") as fp:
99
  cos_sims_dict = pickle.load(fp)
100
+ if cell_states_to_model is None:
101
+ dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
102
+ else:
103
+ for state_value in cell_states_list:
104
+ new_dict = read_dict(
105
+ cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
106
+ )[0]
107
+ for key in new_dict:
108
+ try:
109
+ state_dict[state_value][key] += new_dict[key]
110
+ except KeyError:
111
+ state_dict[state_value][key] = new_dict[key]
112
+ if not file_found:
113
  logger.error(
114
+ "No raw data for processing found within provided directory. "
115
+ "Please ensure data files end with '{pickle_suffix}'."
116
+ )
117
  raise
118
+ if cell_states_to_model is None:
119
+ return dict_list
120
+ else:
121
+ return state_dict
122
+
123
 
124
  # get complete gene list
125
+ def get_gene_list(dict_list, mode):
126
  if mode == "cell":
127
  position = 0
128
  elif mode == "gene":
129
  position = 1
130
  gene_set = set()
131
+ if isinstance(dict_list, list):
132
+ for dict_i in dict_list:
133
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
134
+ elif isinstance(dict_list, dict):
135
+ for state, dict_i in dict_list.items():
136
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
137
+ else:
138
+ logger.error(
139
+ "dict_list should be a list, or if modeling shift to goal states, a dict. "
140
+ f"{type(dict_list)} is not the correct format."
141
+ )
142
+ raise
143
  gene_list = list(gene_set)
144
  if mode == "gene":
145
  gene_list.remove("cell_emb")
146
  gene_list.sort()
147
  return gene_list
148
 
149
+
150
  def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
151
+ try:
152
+ return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
153
+ except TypeError:
154
+ return tuple(gene_token_id_dict.get(token_tuple, np.nan))
155
+
156
 
157
  def n_detections(token, dict_list, mode, anchor_token):
158
  cos_sim_megalist = []
159
  for dict_i in dict_list:
160
  if mode == "cell":
161
+ cos_sim_megalist += dict_i.get((token, "cell_emb"), [])
162
  elif mode == "gene":
163
+ cos_sim_megalist += dict_i.get((anchor_token, token), [])
164
  return len(cos_sim_megalist)
165
 
166
+
167
  def get_fdr(pvalues):
168
  return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
169
 
170
+
171
  def get_impact_component(test_value, gaussian_mixture_model):
172
  impact_border = gaussian_mixture_model.means_[0][0]
173
  nonimpact_border = gaussian_mixture_model.means_[1][0]
183
  impact_component = 1
184
  return impact_component
185
 
186
+
187
  # aggregate data for single perturbation in multiple cells
188
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
189
+ names = ["Cosine_shift"]
190
  cos_sims_full_df = pd.DataFrame(columns=names)
191
 
192
  cos_shift_data = []
193
  token = cos_sims_df["Gene"][0]
194
  for dict_i in dict_list:
195
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
196
  cos_sims_full_df["Cosine_shift"] = cos_shift_data
197
+ return cos_sims_full_df
198
+
199
+
200
+ def find(variable, x):
201
+ try:
202
+ if x in variable: # Test if variable is iterable and contains x
203
+ return True
204
+ except TypeError:
205
+ return x == variable # Test if variable is x if non-iterable
206
+
207
+
208
+ def isp_aggregate_gene_shifts(
209
+ cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
210
+ ):
211
+ cos_shift_data = dict()
212
+ for i in trange(cos_sims_df.shape[0]):
213
+ token = cos_sims_df["Gene"][i]
214
+ for dict_i in dict_list:
215
+ affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
216
+ for key in affected_pairs:
217
+ if key in cos_shift_data.keys():
218
+ cos_shift_data[key] += dict_i.get(key, [])
219
+ else:
220
+ cos_shift_data[key] = dict_i.get(key, [])
221
+
222
+ cos_data_mean = {
223
+ k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
224
+ }
225
+ cos_sims_full_df = pd.DataFrame()
226
+ cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
227
+ cos_sims_full_df["Gene_name"] = [
228
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
229
+ for k, v in cos_data_mean.items()
230
+ ]
231
+ cos_sims_full_df["Ensembl_ID"] = [
232
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
233
+ for k, v in cos_data_mean.items()
234
+ ]
235
+ cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
236
+ cos_sims_full_df["Affected_Gene_name"] = [
237
+ gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
238
+ for token in cos_sims_full_df["Affected"]
239
+ ]
240
+ cos_sims_full_df["Affected_Ensembl_ID"] = [
241
+ gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
242
+ ]
243
+ cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
244
+ cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
245
+ cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
246
+
247
+ specific_val = "cell_emb"
248
+ cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
249
+ # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
250
+ cos_sims_full_df = cos_sims_full_df.sort_values(
251
+ by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
252
+ ).drop("temp", axis=1)
253
+
254
+ return cos_sims_full_df
255
+
256
 
257
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
258
+ def isp_stats_to_goal_state(
259
+ cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
260
+ ):
261
+ if (
262
+ ("alt_states" not in cell_states_to_model.keys())
263
+ or (len(cell_states_to_model["alt_states"]) == 0)
264
+ or (cell_states_to_model["alt_states"] == [None])
265
+ ):
266
  alt_end_state_exists = False
267
+ elif (len(cell_states_to_model["alt_states"]) > 0) and (
268
+ cell_states_to_model["alt_states"] != [None]
269
+ ):
270
  alt_end_state_exists = True
271
+
272
  # for single perturbation in multiple cells, there are no random perturbations to compare to
273
  if genes_perturbed != "all":
274
+ cos_sims_full_df = pd.DataFrame()
275
+
276
+ cos_shift_data_end = []
 
 
 
 
277
  token = cos_sims_df["Gene"][0]
278
+ cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
279
+ (token, "cell_emb"), []
280
+ )
281
+ cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
282
+ if alt_end_state_exists is True:
283
+ for alt_state in cell_states_to_model["alt_states"]:
284
+ cos_shift_data_alt_state = []
285
+ cos_shift_data_alt_state += result_dict.get(alt_state).get(
286
+ (token, "cell_emb"), []
287
+ )
288
+ cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
289
+ np.mean(cos_shift_data_alt_state)
290
+ ]
291
+
292
  # sort by shift to desired state
293
+ cos_sims_full_df = cos_sims_full_df.sort_values(
294
+ by=["Shift_to_goal_end"], ascending=[False]
295
+ )
296
+ return cos_sims_full_df
297
+
298
  elif genes_perturbed == "all":
299
+ goal_end_random_megalist = []
300
+ if alt_end_state_exists is True:
301
+ alt_end_state_random_dict = {
302
+ alt_state: [] for alt_state in cell_states_to_model["alt_states"]
303
+ }
304
  for i in trange(cos_sims_df.shape[0]):
305
  token = cos_sims_df["Gene"][i]
306
+ goal_end_random_megalist += result_dict[
307
+ cell_states_to_model["goal_state"]
308
+ ].get((token, "cell_emb"), [])
309
+ if alt_end_state_exists is True:
310
+ for alt_state in cell_states_to_model["alt_states"]:
311
+ alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
312
+ (token, "cell_emb"), []
313
+ )
314
 
315
  # downsample to improve speed of ranksums
316
  if len(goal_end_random_megalist) > 100_000:
317
  random.seed(42)
318
+ goal_end_random_megalist = random.sample(
319
+ goal_end_random_megalist, k=100_000
320
+ )
321
+ if alt_end_state_exists is True:
322
+ for alt_state in cell_states_to_model["alt_states"]:
323
+ if len(alt_end_state_random_dict[alt_state]) > 100_000:
324
+ random.seed(42)
325
+ alt_end_state_random_dict[alt_state] = random.sample(
326
+ alt_end_state_random_dict[alt_state], k=100_000
327
+ )
328
+
329
+ names = [
330
+ "Gene",
331
+ "Gene_name",
332
+ "Ensembl_ID",
333
+ "Shift_to_goal_end",
334
+ "Goal_end_vs_random_pval",
335
+ ]
336
+ if alt_end_state_exists is True:
337
+ [
338
+ names.append(f"Shift_to_alt_end_{alt_state}")
339
+ for alt_state in cell_states_to_model["alt_states"]
340
+ ]
341
+ names.append(names.pop(names.index("Goal_end_vs_random_pval")))
342
+ [
343
+ names.append(f"Alt_end_vs_random_pval_{alt_state}")
344
+ for alt_state in cell_states_to_model["alt_states"]
345
+ ]
346
  cos_sims_full_df = pd.DataFrame(columns=names)
347
 
348
+ n_detections_dict = dict()
349
  for i in trange(cos_sims_df.shape[0]):
350
  token = cos_sims_df["Gene"][i]
351
  name = cos_sims_df["Gene_name"][i]
352
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
353
+ goal_end_cos_sim_megalist = result_dict[
354
+ cell_states_to_model["goal_state"]
355
+ ].get((token, "cell_emb"), [])
356
+ n_detections_dict[token] = len(goal_end_cos_sim_megalist)
357
+ mean_goal_end = np.mean(goal_end_cos_sim_megalist)
358
+ pval_goal_end = ranksums(
359
+ goal_end_random_megalist, goal_end_cos_sim_megalist
360
+ ).pvalue
361
 
362
+ if alt_end_state_exists is True:
363
+ alt_end_state_dict = {
364
+ alt_state: [] for alt_state in cell_states_to_model["alt_states"]
365
+ }
366
+ for alt_state in cell_states_to_model["alt_states"]:
367
+ alt_end_state_dict[alt_state] = result_dict[alt_state].get(
368
+ (token, "cell_emb"), []
369
+ )
370
+ alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
371
+ alt_end_state_dict[alt_state]
372
+ )
373
+ alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
374
+ alt_end_state_random_dict[alt_state],
375
+ alt_end_state_dict[alt_state],
376
+ ).pvalue
377
 
378
+ results_dict = dict()
379
+ results_dict["Gene"] = token
380
+ results_dict["Gene_name"] = name
381
+ results_dict["Ensembl_ID"] = ensembl_id
382
+ results_dict["Shift_to_goal_end"] = mean_goal_end
383
+ results_dict["Goal_end_vs_random_pval"] = pval_goal_end
384
+ if alt_end_state_exists is True:
385
+ for alt_state in cell_states_to_model["alt_states"]:
386
+ results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
387
+ f"{alt_state}_mean"
388
+ ]
389
+ results_dict[
390
+ f"Alt_end_vs_random_pval_{alt_state}"
391
+ ] = alt_end_state_dict[f"{alt_state}_pval"]
392
 
393
+ cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
394
+ cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
395
+
396
+ cos_sims_full_df["Goal_end_FDR"] = get_fdr(
397
+ list(cos_sims_full_df["Goal_end_vs_random_pval"])
398
+ )
399
+ if alt_end_state_exists is True:
400
+ for alt_state in cell_states_to_model["alt_states"]:
401
+ cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
402
+ list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
403
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
  # quantify number of detections of each gene
406
+ cos_sims_full_df["N_Detections"] = [
407
+ n_detections_dict[token] for token in cos_sims_full_df["Gene"]
408
+ ]
409
+
410
+ # sort by shift to desired state
411
+ cos_sims_full_df["Sig"] = [
412
+ 1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
413
+ ]
414
+ cos_sims_full_df = cos_sims_full_df.sort_values(
415
+ by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
416
+ ascending=[False, False, True],
417
+ )
418
+
419
  return cos_sims_full_df
420
 
421
+
422
  # stats comparing cos sim shifts of test perturbations vs null distribution
423
  def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
424
  cos_sims_full_df = cos_sims_df.copy()
425
 
426
  cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
427
  cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
428
+ cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
429
+ cos_sims_df.shape[0], dtype=float
430
+ )
431
  cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
432
  cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
433
+ cos_sims_full_df["N_Detections_test"] = np.zeros(
434
+ cos_sims_df.shape[0], dtype="uint32"
435
+ )
436
+ cos_sims_full_df["N_Detections_null"] = np.zeros(
437
+ cos_sims_df.shape[0], dtype="uint32"
438
+ )
439
+
440
  for i in trange(cos_sims_df.shape[0]):
441
  token = cos_sims_df["Gene"][i]
442
  test_shifts = []
443
  null_shifts = []
444
+
445
  for dict_i in dict_list:
446
+ test_shifts += dict_i.get((token, "cell_emb"), [])
447
 
448
  for dict_i in null_dict_list:
449
+ null_shifts += dict_i.get((token, "cell_emb"), [])
450
+
451
  cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
452
  cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
453
+ cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
454
+ test_shifts
455
+ ) - np.mean(null_shifts)
456
+ cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
457
+ test_shifts, null_shifts, nan_policy="omit"
458
+ ).pvalue
459
+ # remove nan values
460
+ cos_sims_full_df.Test_vs_null_pval = np.where(
461
+ np.isnan(cos_sims_full_df.Test_vs_null_pval),
462
+ 1,
463
+ cos_sims_full_df.Test_vs_null_pval,
464
+ )
465
  cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
466
  cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
467
 
468
+ cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
469
+ cos_sims_full_df["Test_vs_null_pval"]
470
+ )
471
+
472
+ cos_sims_full_df["Sig"] = [
473
+ 1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
474
+ ]
475
+ cos_sims_full_df = cos_sims_full_df.sort_values(
476
+ by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
477
+ ascending=[False, False, True],
478
+ )
479
  return cos_sims_full_df
480
 
481
+
482
  # stats for identifying perturbations with largest effect within a given set of cells
483
  # fits a mixture model to 2 components (impact vs. non-impact) and
484
  # reports the most likely component for each test perturbation
485
  # Note: because assumes given perturbation has a consistent effect in the cells tested,
486
  # we recommend only using the mixture model strategy with uniform cell populations
487
  def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
488
+ names = ["Gene", "Gene_name", "Ensembl_ID"]
489
+
 
 
 
490
  if combos == 0:
491
  names += ["Test_avg_shift"]
492
  elif combos == 1:
493
+ names += [
494
+ "Anchor_shift",
495
+ "Test_token_shift",
496
+ "Sum_of_indiv_shifts",
497
+ "Combo_shift",
498
+ "Combo_minus_sum_shift",
499
+ ]
500
+
501
+ names += ["Impact_component", "Impact_component_percent"]
502
 
503
  cos_sims_full_df = pd.DataFrame(columns=names)
504
  avg_values = []
505
  gene_names = []
506
+
507
  for i in trange(cos_sims_df.shape[0]):
508
  token = cos_sims_df["Gene"][i]
509
  name = cos_sims_df["Gene_name"][i]
510
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
511
  cos_shift_data = []
512
+
513
  for dict_i in dict_list:
514
  if (combos == 0) and (anchor_token is not None):
515
+ cos_shift_data += dict_i.get((anchor_token, token), [])
516
  else:
517
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
518
+
519
  # Extract values for current gene
520
  if combos == 0:
521
  test_values = cos_shift_data
522
  elif combos == 1:
523
  test_values = []
524
  for tup in cos_shift_data:
525
+ test_values.append(tup[2])
526
+
527
  if len(test_values) > 0:
528
  avg_value = np.mean(test_values)
529
  avg_values.append(avg_value)
530
  gene_names.append(name)
531
+
532
  # fit Gaussian mixture model to dataset of mean for each gene
533
  avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
534
  gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
535
+
536
  for i in trange(cos_sims_df.shape[0]):
537
  token = cos_sims_df["Gene"][i]
538
  name = cos_sims_df["Gene_name"][i]
541
 
542
  for dict_i in dict_list:
543
  if (combos == 0) and (anchor_token is not None):
544
+ cos_shift_data += dict_i.get((anchor_token, token), [])
545
  else:
546
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
547
+
548
  if combos == 0:
549
  mean_test = np.mean(cos_shift_data)
550
+ impact_components = [
551
+ get_impact_component(value, gm) for value in cos_shift_data
552
+ ]
553
  elif combos == 1:
554
+ anchor_cos_sim_megalist = [
555
+ anchor for anchor, token, combo in cos_shift_data
556
+ ]
557
+ token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
558
+ anchor_plus_token_cos_sim_megalist = [
559
+ 1 - ((1 - anchor) + (1 - token))
560
+ for anchor, token, combo in cos_shift_data
561
+ ]
562
+ combo_anchor_token_cos_sim_megalist = [
563
+ combo for anchor, token, combo in cos_shift_data
564
+ ]
565
+ combo_minus_sum_cos_sim_megalist = [
566
+ combo - (1 - ((1 - anchor) + (1 - token)))
567
+ for anchor, token, combo in cos_shift_data
568
+ ]
569
 
570
  mean_anchor = np.mean(anchor_cos_sim_megalist)
571
  mean_token = np.mean(token_cos_sim_megalist)
572
  mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
573
  mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
574
  mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
575
+
576
+ impact_components = [
577
+ get_impact_component(value, gm)
578
+ for value in combo_anchor_token_cos_sim_megalist
579
+ ]
580
+
581
+ impact_component = get_impact_component(mean_test, gm)
582
+ impact_component_percent = np.mean(impact_components) * 100
583
+
584
+ data_i = [token, name, ensembl_id]
585
  if combos == 0:
586
  data_i += [mean_test]
587
  elif combos == 1:
588
+ data_i += [
589
+ mean_anchor,
590
+ mean_token,
591
+ mean_sum,
592
+ mean_test,
593
+ mean_combo_minus_sum,
594
+ ]
595
+ data_i += [impact_component, impact_component_percent]
596
+
597
+ cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
598
+ cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
599
+
600
  # quantify number of detections of each gene
601
+ cos_sims_full_df["N_Detections"] = [
602
+ n_detections(i, dict_list, "gene", anchor_token)
603
+ for i in cos_sims_full_df["Gene"]
604
+ ]
605
+
606
  if combos == 0:
607
+ cos_sims_full_df = cos_sims_full_df.sort_values(
608
+ by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
609
+ )
610
  elif combos == 1:
611
+ cos_sims_full_df = cos_sims_full_df.sort_values(
612
+ by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
613
+ )
614
  return cos_sims_full_df
615
 
616
+
617
  class InSilicoPerturberStats:
618
  valid_option_dict = {
619
+ "mode": {
620
+ "goal_state_shift",
621
+ "vs_null",
622
+ "mixture_model",
623
+ "aggregate_data",
624
+ "aggregate_gene_shifts",
625
+ },
626
+ "genes_perturbed": {"all", list},
627
+ "combos": {0, 1},
628
  "anchor_gene": {None, str},
629
  "cell_states_to_model": {None, dict},
630
+ "pickle_suffix": {None, str},
631
  }
632
+
633
  def __init__(
634
  self,
635
  mode="mixture_model",
637
  combos=0,
638
  anchor_gene=None,
639
  cell_states_to_model=None,
640
+ pickle_suffix="_raw.pickle",
641
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
642
  gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
643
  ):
646
 
647
  Parameters
648
  ----------
649
+ mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data","aggregate_gene_shifts"}
650
  Type of stats.
651
  "goal_state_shift": perturbation vs. random for desired cell state shift
652
  "vs_null": perturbation vs. null from provided null distribution dataset
653
  "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
654
  "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
655
+ "aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
656
  genes_perturbed : "all", list
657
  Genes perturbed in isp experiment.
658
  Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
687
  self.combos = combos
688
  self.anchor_gene = anchor_gene
689
  self.cell_states_to_model = cell_states_to_model
690
+ self.pickle_suffix = pickle_suffix
691
+
692
  self.validate_options()
693
 
694
  # load token dictionary (Ensembl IDs:token)
695
  with open(token_dictionary_file, "rb") as f:
696
  self.gene_token_dict = pickle.load(f)
697
+
698
  # load gene name dictionary (gene name:Ensembl ID)
699
  with open(gene_name_id_dictionary_file, "rb") as f:
700
  self.gene_name_id_dict = pickle.load(f)
705
  self.anchor_token = self.gene_token_dict[self.anchor_gene]
706
 
707
  def validate_options(self):
708
+ for attr_name, valid_options in self.valid_option_dict.items():
709
  attr_value = self.__dict__[attr_name]
710
  if type(attr_value) not in {list, dict}:
711
  if attr_name in {"anchor_gene"}:
714
  continue
715
  valid_type = False
716
  for option in valid_options:
717
+ if (option in [str, int, list, dict]) and isinstance(
718
+ attr_value, option
719
+ ):
720
  valid_type = True
721
  break
722
+ if not valid_type:
723
+ logger.error(
724
+ f"Invalid option for {attr_name}. "
725
+ f"Valid options for {attr_name}: {valid_options}"
726
+ )
727
+ raise
728
+
 
729
  if self.cell_states_to_model is not None:
730
  if len(self.cell_states_to_model.items()) == 1:
731
  logger.warning(
732
+ "The single value dictionary for cell_states_to_model will be "
733
+ "replaced with a dictionary with named keys for start, goal, and alternate states. "
734
+ "Please specify state_key, start_state, goal_state, and alt_states "
735
+ "in the cell_states_to_model dictionary for future use. "
736
+ "For example, cell_states_to_model={"
737
+ "'state_key': 'disease', "
738
+ "'start_state': 'dcm', "
739
+ "'goal_state': 'nf', "
740
+ "'alt_states': ['hcm', 'other1', 'other2']}"
741
  )
742
+ for key, value in self.cell_states_to_model.items():
743
  if (len(value) == 3) and isinstance(value, tuple):
744
+ if (
745
+ isinstance(value[0], list)
746
+ and isinstance(value[1], list)
747
+ and isinstance(value[2], list)
748
+ ):
749
  if len(value[0]) == 1 and len(value[1]) == 1:
750
+ all_values = value[0] + value[1] + value[2]
751
  if len(all_values) == len(set(all_values)):
752
  continue
753
  # reformat to the new named key format
756
  "state_key": list(self.cell_states_to_model.keys())[0],
757
  "start_state": state_values[0][0],
758
  "goal_state": state_values[1][0],
759
+ "alt_states": state_values[2:][0],
760
  }
761
+ elif set(self.cell_states_to_model.keys()) == {
762
+ "state_key",
763
+ "start_state",
764
+ "goal_state",
765
+ "alt_states",
766
+ }:
767
+ if (
768
+ (self.cell_states_to_model["state_key"] is None)
769
+ or (self.cell_states_to_model["start_state"] is None)
770
+ or (self.cell_states_to_model["goal_state"] is None)
771
+ ):
772
  logger.error(
773
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
774
+ )
775
  raise
776
+
777
+ if (
778
+ self.cell_states_to_model["start_state"]
779
+ == self.cell_states_to_model["goal_state"]
780
+ ):
781
+ logger.error("All states must be unique.")
782
  raise
783
 
784
  if self.cell_states_to_model["alt_states"] is not None:
785
+ if not isinstance(self.cell_states_to_model["alt_states"], list):
786
  logger.error(
787
  "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
788
  )
789
  raise
790
+ if len(self.cell_states_to_model["alt_states"]) != len(
791
+ set(self.cell_states_to_model["alt_states"])
792
+ ):
793
+ logger.error("All states must be unique.")
794
  raise
795
 
796
  else:
797
  logger.error(
798
+ "cell_states_to_model must only have the following four keys: "
799
+ "'state_key', 'start_state', 'goal_state', 'alt_states'."
800
+ "For example, cell_states_to_model={"
801
+ "'state_key': 'disease', "
802
+ "'start_state': 'dcm', "
803
+ "'goal_state': 'nf', "
804
+ "'alt_states': ['hcm', 'other1', 'other2']}"
805
  )
806
  raise
807
 
808
  if self.anchor_gene is not None:
809
  self.anchor_gene = None
810
  logger.warning(
811
+ "anchor_gene set to None. "
812
+ "Currently, anchor gene not available "
813
+ "when modeling multiple cell states."
814
+ )
815
+
816
  if self.combos > 0:
817
  if self.anchor_gene is None:
818
  logger.error(
819
+ "Currently, stats are only supported for combination "
820
+ "in silico perturbation run with anchor gene. Please add "
821
+ "anchor gene when using with combos > 0. "
822
+ )
823
  raise
824
+
825
  if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
826
  logger.error(
827
+ "Mixture model mode requires multiple gene perturbations to fit model "
828
+ "so is incompatible with a single grouped perturbation."
829
+ )
830
  raise
831
  if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
832
  logger.error(
833
+ "Simple data aggregation mode is for single perturbation in multiple cells "
834
+ "so is incompatible with a genes_perturbed being 'all'."
835
+ )
836
+ raise
837
+
838
+ def get_stats(
839
+ self,
840
+ input_data_directory,
841
+ null_dist_data_directory,
842
+ output_directory,
843
+ output_prefix,
844
+ null_dict_list=None,
845
+ ):
846
  """
847
  Get stats for in silico perturbation data and save as results in output_directory.
848
 
856
  Path to directory where perturbation data will be saved as .csv
857
  output_prefix : str
858
  Prefix for output .csv
859
+ null_dict_list: dict
860
+ List of loaded null distribtion dictionary if more than one comparison vs. the null is to be performed
861
+
862
  Outputs
863
  ----------
864
  Definition of possible columns in .csv output file.
865
+
866
  Of note, not all columns will be present in all output files.
867
  Some columns are specific to particular perturbation modes.
868
+
869
  "Gene": gene token
870
  "Gene_name": gene name
871
  "Ensembl_ID": gene Ensembl ID
872
  "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
873
  "Sig": 1 if FDR<0.05, otherwise 0
874
+
875
  "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
876
  "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
877
  "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
880
  pvalue compares shift caused by perturbing given gene compared to random genes
881
  "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
882
  "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
883
+
884
  "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
885
  "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
886
  "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
889
  "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
890
  "N_Detections_test": "N_Detections" in cells from test distribution
891
  "N_Detections_null": "N_Detections" in cells from null distribution
892
+
893
  "Anchor_shift": cosine shift in response to given perturbation of anchor gene
894
  "Test_token_shift": cosine shift in response to given perturbation of test gene
895
  "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
899
  "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
900
  1: within impact component; 0: not within impact component
901
  "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
902
+
903
+ In case of aggregating gene shifts:
904
+ "Perturbed": ID(s) of gene(s) being perturbed
905
+ "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
906
+ "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
907
+ "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
908
  """
909
 
910
+ if self.mode not in [
911
+ "goal_state_shift",
912
+ "vs_null",
913
+ "mixture_model",
914
+ "aggregate_data",
915
+ "aggregate_gene_shifts",
916
+ ]:
917
  logger.error(
918
+ "Currently, only modes available are stats for goal_state_shift, "
919
+ "vs_null (comparing to null distribution), "
920
+ "mixture_model (fitting mixture model for perturbations with or without impact), "
921
+ "and aggregating data for single perturbations or for gene embedding shifts."
922
+ )
923
  raise
924
 
925
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
928
  # obtain total gene list
929
  if (self.combos == 0) and (self.anchor_token is not None):
930
  # cos sim data for effect of gene perturbation on the embedding of each other gene
931
+ dict_list = read_dictionaries(
932
+ input_data_directory,
933
+ "gene",
934
+ self.anchor_token,
935
+ self.cell_states_to_model,
936
+ self.pickle_suffix,
937
+ )
938
  gene_list = get_gene_list(dict_list, "gene")
939
+ elif (
940
+ (self.combos == 0)
941
+ and (self.anchor_token is None)
942
+ and (self.mode == "aggregate_gene_shifts")
943
+ ):
944
+ dict_list = read_dictionaries(
945
+ input_data_directory,
946
+ "gene",
947
+ self.anchor_token,
948
+ self.cell_states_to_model,
949
+ self.pickle_suffix,
950
+ )
951
+ gene_list = get_gene_list(dict_list, "cell")
952
  else:
953
  # cos sim data for effect of gene perturbation on the embedding of each cell
954
+ dict_list = read_dictionaries(
955
+ input_data_directory,
956
+ "cell",
957
+ self.anchor_token,
958
+ self.cell_states_to_model,
959
+ self.pickle_suffix,
960
+ )
961
  gene_list = get_gene_list(dict_list, "cell")
962
+
963
  # initiate results dataframe
964
+ cos_sims_df_initial = pd.DataFrame(
965
+ {
966
+ "Gene": gene_list,
967
+ "Gene_name": [self.token_to_gene_name(item) for item in gene_list],
968
+ "Ensembl_ID": [
969
+ token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
970
+ if self.genes_perturbed != "all"
971
+ else self.gene_token_id_dict[genes[1]]
972
+ if isinstance(genes, tuple)
973
+ else self.gene_token_id_dict[genes]
974
+ for genes in gene_list
975
+ ],
976
+ },
977
+ index=[i for i in range(len(gene_list))],
978
+ )
979
 
980
  if self.mode == "goal_state_shift":
981
+ cos_sims_df = isp_stats_to_goal_state(
982
+ cos_sims_df_initial,
983
+ dict_list,
984
+ self.cell_states_to_model,
985
+ self.genes_perturbed,
986
+ )
987
+
988
  elif self.mode == "vs_null":
989
+ if null_dict_list is None:
990
+ null_dict_list = read_dictionaries(
991
+ null_dist_data_directory,
992
+ "cell",
993
+ self.anchor_token,
994
+ self.cell_states_to_model,
995
+ self.pickle_suffix,
996
+ )
997
+ cos_sims_df = isp_stats_vs_null(
998
+ cos_sims_df_initial, dict_list, null_dict_list
999
+ )
1000
 
1001
  elif self.mode == "mixture_model":
1002
+ cos_sims_df = isp_stats_mixture_model(
1003
+ cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1004
+ )
1005
+
1006
  elif self.mode == "aggregate_data":
1007
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
1008
 
1009
+ elif self.mode == "aggregate_gene_shifts":
1010
+ cos_sims_df = isp_aggregate_gene_shifts(
1011
+ cos_sims_df_initial,
1012
+ dict_list,
1013
+ self.gene_token_id_dict,
1014
+ self.gene_id_name_dict,
1015
+ )
1016
+
1017
  # save perturbation stats to output_path
1018
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
1019
  cos_sims_df.to_csv(output_path)
1020
 
1021
  def token_to_gene_name(self, item):
1022
+ if isinstance(item, int):
1023
+ return self.gene_id_name_dict.get(
1024
+ self.gene_token_id_dict.get(item, np.nan), np.nan
1025
+ )
1026
+ if isinstance(item, tuple):
1027
+ return tuple(
1028
+ [
1029
+ self.gene_id_name_dict.get(
1030
+ self.gene_token_id_dict.get(i, np.nan), np.nan
1031
+ )
1032
+ for i in item
1033
+ ]
1034
+ )
geneformer/perturber_utils.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools as it
2
+ import logging
3
+ import pickle
4
+ import re
5
+ from collections import defaultdict
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import seaborn as sns
10
+ import torch
11
+ from datasets import Dataset, load_from_disk
12
+ from transformers import (
13
+ BertForMaskedLM,
14
+ BertForSequenceClassification,
15
+ BertForTokenClassification,
16
+ )
17
+
18
+ sns.set()
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ # load data and filter by defined criteria
24
+ def load_and_filter(filter_data, nproc, input_data_file):
25
+ data = load_from_disk(input_data_file)
26
+ if filter_data is not None:
27
+ data = filter_by_dict(data, filter_data, nproc)
28
+ return data
29
+
30
+
31
+ def filter_by_dict(data, filter_data, nproc):
32
+ for key, value in filter_data.items():
33
+
34
+ def filter_data_by_criteria(example):
35
+ return example[key] in value
36
+
37
+ data = data.filter(filter_data_by_criteria, num_proc=nproc)
38
+ if len(data) == 0:
39
+ logger.error("No cells remain after filtering. Check filtering criteria.")
40
+ raise
41
+ return data
42
+
43
+
44
+ def filter_data_by_tokens(filtered_input_data, tokens, nproc):
45
+ def if_has_tokens(example):
46
+ return len(set(example["input_ids"]).intersection(tokens)) == len(tokens)
47
+
48
+ filtered_input_data = filtered_input_data.filter(if_has_tokens, num_proc=nproc)
49
+ return filtered_input_data
50
+
51
+
52
+ def logging_filtered_data_len(filtered_input_data, filtered_tokens_categ):
53
+ if len(filtered_input_data) == 0:
54
+ logger.error(f"No cells in dataset contain {filtered_tokens_categ}.")
55
+ raise
56
+ else:
57
+ logger.info(f"# cells with {filtered_tokens_categ}: {len(filtered_input_data)}")
58
+
59
+
60
+ def filter_data_by_tokens_and_log(
61
+ filtered_input_data, tokens, nproc, filtered_tokens_categ
62
+ ):
63
+ # filter for cells with anchor gene
64
+ filtered_input_data = filter_data_by_tokens(filtered_input_data, tokens, nproc)
65
+ # logging length of filtered data
66
+ logging_filtered_data_len(filtered_input_data, filtered_tokens_categ)
67
+
68
+ return filtered_input_data
69
+
70
+
71
+ def filter_data_by_start_state(filtered_input_data, cell_states_to_model, nproc):
72
+ # confirm that start state is valid to prevent futile filtering
73
+ state_key = cell_states_to_model["state_key"]
74
+ state_values = filtered_input_data[state_key]
75
+ start_state = cell_states_to_model["start_state"]
76
+ if start_state not in state_values:
77
+ logger.error(
78
+ f"Start state {start_state} is not present "
79
+ f"in the dataset's {state_key} attribute."
80
+ )
81
+ raise
82
+
83
+ # filter for start state cells
84
+ def filter_for_origin(example):
85
+ return example[state_key] in [start_state]
86
+
87
+ filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=nproc)
88
+ return filtered_input_data
89
+
90
+
91
+ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
92
+ if cell_inds_to_perturb["start"] >= len(filtered_input_data):
93
+ logger.error(
94
+ "cell_inds_to_perturb['start'] is larger than the filtered dataset."
95
+ )
96
+ raise
97
+ if cell_inds_to_perturb["end"] > len(filtered_input_data):
98
+ logger.warning(
99
+ "cell_inds_to_perturb['end'] is larger than the filtered dataset. \
100
+ Setting to the end of the filtered dataset."
101
+ )
102
+ cell_inds_to_perturb["end"] = len(filtered_input_data)
103
+ filtered_input_data = filtered_input_data.select(
104
+ [i for i in range(cell_inds_to_perturb["start"], cell_inds_to_perturb["end"])]
105
+ )
106
+ return filtered_input_data
107
+
108
+
109
+ # load model to GPU
110
+ def load_model(model_type, num_classes, model_directory):
111
+ if model_type == "Pretrained":
112
+ model = BertForMaskedLM.from_pretrained(
113
+ model_directory, output_hidden_states=True, output_attentions=False
114
+ )
115
+ elif model_type == "GeneClassifier":
116
+ model = BertForTokenClassification.from_pretrained(
117
+ model_directory,
118
+ num_labels=num_classes,
119
+ output_hidden_states=True,
120
+ output_attentions=False,
121
+ )
122
+ elif model_type == "CellClassifier":
123
+ model = BertForSequenceClassification.from_pretrained(
124
+ model_directory,
125
+ num_labels=num_classes,
126
+ output_hidden_states=True,
127
+ output_attentions=False,
128
+ )
129
+ # put the model in eval mode for fwd pass
130
+ model.eval()
131
+ model = model.to("cuda:0")
132
+ return model
133
+
134
+
135
+ def quant_layers(model):
136
+ layer_nums = []
137
+ for name, parameter in model.named_parameters():
138
+ if "layer" in name:
139
+ layer_nums += [int(name.split("layer.")[1].split(".")[0])]
140
+ return int(max(layer_nums)) + 1
141
+
142
+
143
+ def get_model_input_size(model):
144
+ return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
145
+
146
+
147
+ def flatten_list(megalist):
148
+ return [item for sublist in megalist for item in sublist]
149
+
150
+
151
+ def measure_length(example):
152
+ example["length"] = len(example["input_ids"])
153
+ return example
154
+
155
+
156
+ def downsample_and_sort(data, max_ncells):
157
+ num_cells = len(data)
158
+ # if max number of cells is defined, then shuffle and subsample to this max number
159
+ if max_ncells is not None:
160
+ if num_cells > max_ncells:
161
+ data = data.shuffle(seed=42)
162
+ num_cells = max_ncells
163
+ data_subset = data.select([i for i in range(num_cells)])
164
+ # sort dataset with largest cell first to encounter any memory errors earlier
165
+ data_sorted = data_subset.sort("length", reverse=True)
166
+ return data_sorted
167
+
168
+
169
+ def get_possible_states(cell_states_to_model):
170
+ possible_states = []
171
+ for key in ["start_state", "goal_state"]:
172
+ possible_states += [cell_states_to_model[key]]
173
+ possible_states += cell_states_to_model.get("alt_states", [])
174
+ return possible_states
175
+
176
+
177
+ def forward_pass_single_cell(model, example_cell, layer_to_quant):
178
+ example_cell.set_format(type="torch")
179
+ input_data = example_cell["input_ids"]
180
+ with torch.no_grad():
181
+ outputs = model(input_ids=input_data.to("cuda"))
182
+ emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
183
+ del outputs
184
+ return emb
185
+
186
+
187
+ def perturb_emb_by_index(emb, indices):
188
+ mask = torch.ones(emb.numel(), dtype=torch.bool)
189
+ mask[indices] = False
190
+ return emb[mask]
191
+
192
+
193
+ def delete_indices(example):
194
+ indices = example["perturb_index"]
195
+ if any(isinstance(el, list) for el in indices):
196
+ indices = flatten_list(indices)
197
+ for index in sorted(indices, reverse=True):
198
+ del example["input_ids"][index]
199
+
200
+ example["length"] = len(example["input_ids"])
201
+ return example
202
+
203
+
204
+ # for genes_to_perturb = "all" where only genes within cell are overexpressed
205
+ def overexpress_indices(example):
206
+ indices = example["perturb_index"]
207
+ if any(isinstance(el, list) for el in indices):
208
+ indices = flatten_list(indices)
209
+ for index in sorted(indices, reverse=True):
210
+ example["input_ids"].insert(0, example["input_ids"].pop(index))
211
+
212
+ example["length"] = len(example["input_ids"])
213
+ return example
214
+
215
+
216
+ # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
217
+ def overexpress_tokens(example, max_len):
218
+ # -100 indicates tokens to overexpress are not present in rank value encoding
219
+ if example["perturb_index"] != [-100]:
220
+ example = delete_indices(example)
221
+ [
222
+ example["input_ids"].insert(0, token)
223
+ for token in example["tokens_to_perturb"][::-1]
224
+ ]
225
+
226
+ # truncate to max input size, must also truncate original emb to be comparable
227
+ if len(example["input_ids"]) > max_len:
228
+ example["input_ids"] = example["input_ids"][0:max_len]
229
+
230
+ example["length"] = len(example["input_ids"])
231
+ return example
232
+
233
+
234
+ def calc_n_overflow(max_len, example_len, tokens_to_perturb, indices_to_perturb):
235
+ n_to_add = len(tokens_to_perturb) - len(indices_to_perturb)
236
+ n_overflow = example_len + n_to_add - max_len
237
+ return n_overflow
238
+
239
+
240
+ def truncate_by_n_overflow(example):
241
+ new_max_len = example["length"] - example["n_overflow"]
242
+ example["input_ids"] = example["input_ids"][0:new_max_len]
243
+ example["length"] = len(example["input_ids"])
244
+ return example
245
+
246
+
247
+ def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
248
+ # indices_to_remove is list of indices to remove
249
+ indices_to_keep = [
250
+ i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove
251
+ ]
252
+ num_dims = emb.dim()
253
+ emb_slice = [
254
+ slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)
255
+ ]
256
+ sliced_emb = emb[emb_slice]
257
+ return sliced_emb
258
+
259
+
260
+ def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
261
+ output_batch_list = [
262
+ remove_indices_from_emb(emb_batch[i, :, :], idxes, gene_dim - 1)
263
+ for i, idxes in enumerate(list_of_indices_to_remove)
264
+ ]
265
+ # add padding given genes are sometimes added that are or are not in original cell
266
+ batch_max = max([emb.size()[gene_dim - 1] for emb in output_batch_list])
267
+ output_batch_list_padded = [
268
+ pad_xd_tensor(emb, 0.000, batch_max, gene_dim - 1) for emb in output_batch_list
269
+ ]
270
+ return torch.stack(output_batch_list_padded)
271
+
272
+
273
+ # removes perturbed indices
274
+ # need to handle the various cases where a set of genes is overexpressed
275
+ def remove_perturbed_indices_set(
276
+ emb,
277
+ perturb_type: str,
278
+ indices_to_perturb: list[list],
279
+ tokens_to_perturb: list[list],
280
+ original_lengths: list[int],
281
+ input_ids=None,
282
+ ):
283
+ if perturb_type == "overexpress":
284
+ num_perturbed = len(tokens_to_perturb)
285
+ if num_perturbed == 1:
286
+ indices_to_perturb_orig = [
287
+ idx if idx != [-100] else [None] for idx in indices_to_perturb
288
+ ]
289
+ if all(v is [None] for v in indices_to_perturb_orig):
290
+ return emb
291
+ else:
292
+ indices_to_perturb_orig = []
293
+
294
+ for idx_list in indices_to_perturb:
295
+ indices_to_perturb_orig.append(
296
+ [idx if idx != [-100] else [None] for idx in idx_list]
297
+ )
298
+
299
+ else:
300
+ indices_to_perturb_orig = indices_to_perturb
301
+
302
+ emb = remove_indices_from_emb_batch(emb, indices_to_perturb_orig, gene_dim=1)
303
+
304
+ return emb
305
+
306
+
307
+ def make_perturbation_batch(
308
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
309
+ ) -> tuple[Dataset, list[int]]:
310
+ if combo_lvl == 0 and tokens_to_perturb == "all":
311
+ if perturb_type in ["overexpress", "activate"]:
312
+ range_start = 1
313
+ elif perturb_type in ["delete", "inhibit"]:
314
+ range_start = 0
315
+ indices_to_perturb = [
316
+ [i] for i in range(range_start, example_cell["length"][0])
317
+ ]
318
+ # elif combo_lvl > 0 and anchor_token is None:
319
+ ## to implement
320
+ elif combo_lvl > 0 and (anchor_token is not None):
321
+ example_input_ids = example_cell["input_ids"][0]
322
+ anchor_index = example_input_ids.index(anchor_token[0])
323
+ indices_to_perturb = [
324
+ sorted([anchor_index, i]) if i != anchor_index else None
325
+ for i in range(example_cell["length"][0])
326
+ ]
327
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
328
+ else:
329
+ example_input_ids = example_cell["input_ids"][0]
330
+ indices_to_perturb = [
331
+ [example_input_ids.index(token)] if token in example_input_ids else None
332
+ for token in tokens_to_perturb
333
+ ]
334
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
335
+
336
+ # create all permutations of combo_lvl of modifiers from tokens_to_perturb
337
+ if combo_lvl > 0 and (anchor_token is None):
338
+ if tokens_to_perturb != "all":
339
+ if len(tokens_to_perturb) == combo_lvl + 1:
340
+ indices_to_perturb = [
341
+ list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
342
+ ]
343
+ else:
344
+ all_indices = [[i] for i in range(example_cell["length"][0])]
345
+ all_indices = [
346
+ index for index in all_indices if index not in indices_to_perturb
347
+ ]
348
+ indices_to_perturb = [
349
+ [[j for i in indices_to_perturb for j in i], x] for x in all_indices
350
+ ]
351
+
352
+ length = len(indices_to_perturb)
353
+ perturbation_dataset = Dataset.from_dict(
354
+ {
355
+ "input_ids": example_cell["input_ids"] * length,
356
+ "perturb_index": indices_to_perturb,
357
+ }
358
+ )
359
+
360
+ if length < 400:
361
+ num_proc_i = 1
362
+ else:
363
+ num_proc_i = num_proc
364
+
365
+ if perturb_type == "delete":
366
+ perturbation_dataset = perturbation_dataset.map(
367
+ delete_indices, num_proc=num_proc_i
368
+ )
369
+ elif perturb_type == "overexpress":
370
+ perturbation_dataset = perturbation_dataset.map(
371
+ overexpress_indices, num_proc=num_proc_i
372
+ )
373
+
374
+ perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
375
+
376
+ return perturbation_dataset, indices_to_perturb
377
+
378
+
379
+ # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
380
+ # so that only non-perturbed gene embeddings are compared to each other
381
+ # in original or perturbed context
382
+ def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
383
+ all_embs_list = []
384
+
385
+ # if making comparison batch for multiple perturbations in single cell
386
+ if perturb_group is False:
387
+ # squeeze if single cell
388
+ if original_emb_batch.ndim == 3 and original_emb_batch.size()[0] == 1:
389
+ original_emb_batch = torch.squeeze(original_emb_batch)
390
+ original_emb_list = [original_emb_batch] * len(indices_to_perturb)
391
+ # if making comparison batch for single perturbation in multiple cells
392
+ elif perturb_group is True:
393
+ original_emb_list = original_emb_batch
394
+
395
+ for original_emb, indices in zip(original_emb_list, indices_to_perturb):
396
+ if indices == [-100]:
397
+ all_embs_list += [original_emb[:]]
398
+ continue
399
+
400
+ emb_list = []
401
+ start = 0
402
+ if any(isinstance(el, list) for el in indices):
403
+ indices = flatten_list(indices)
404
+
405
+ # removes indices that were perturbed from the original embedding
406
+ for i in sorted(indices):
407
+ emb_list += [original_emb[start:i]]
408
+ start = i + 1
409
+
410
+ emb_list += [original_emb[start:]]
411
+ all_embs_list += [torch.cat(emb_list)]
412
+
413
+ len_set = set([emb.size()[0] for emb in all_embs_list])
414
+ if len(len_set) > 1:
415
+ max_len = max(len_set)
416
+ all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
417
+ return torch.stack(all_embs_list)
418
+
419
+
420
+ def pad_list(input_ids, pad_token_id, max_len):
421
+ input_ids = np.pad(
422
+ input_ids,
423
+ (0, max_len - len(input_ids)),
424
+ mode="constant",
425
+ constant_values=pad_token_id,
426
+ )
427
+ return input_ids
428
+
429
+
430
+ def pad_xd_tensor(tensor, pad_token_id, max_len, dim):
431
+ padding_length = max_len - tensor.size()[dim]
432
+ # Construct a padding configuration where all padding values are 0, except for the padding dimension
433
+ # 2 * number of dimensions (padding before and after for every dimension)
434
+ pad_config = [0] * 2 * tensor.dim()
435
+ # Set the padding after the desired dimension to the calculated padding length
436
+ pad_config[-2 * dim - 1] = padding_length
437
+ return torch.nn.functional.pad(
438
+ tensor, pad=pad_config, mode="constant", value=pad_token_id
439
+ )
440
+
441
+
442
+ def pad_tensor(tensor, pad_token_id, max_len):
443
+ tensor = torch.nn.functional.pad(
444
+ tensor, pad=(0, max_len - tensor.numel()), mode="constant", value=pad_token_id
445
+ )
446
+
447
+ return tensor
448
+
449
+
450
+ def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
451
+ if dim == 0:
452
+ pad = (0, 0, 0, max_len - tensor.size()[dim])
453
+ elif dim == 1:
454
+ pad = (0, max_len - tensor.size()[dim], 0, 0)
455
+ tensor = torch.nn.functional.pad(
456
+ tensor, pad=pad, mode="constant", value=pad_token_id
457
+ )
458
+ return tensor
459
+
460
+
461
+ def pad_3d_tensor(tensor, pad_token_id, max_len, dim):
462
+ if dim == 0:
463
+ raise Exception("dim 0 usually does not need to be padded.")
464
+ if dim == 1:
465
+ pad = (0, 0, 0, max_len - tensor.size()[dim])
466
+ elif dim == 2:
467
+ pad = (0, max_len - tensor.size()[dim], 0, 0)
468
+ tensor = torch.nn.functional.pad(
469
+ tensor, pad=pad, mode="constant", value=pad_token_id
470
+ )
471
+ return tensor
472
+
473
+
474
+ def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
475
+ if isinstance(encoding, torch.Tensor):
476
+ encoding_len = encoding.size()[0]
477
+ elif isinstance(encoding, list):
478
+ encoding_len = len(encoding)
479
+ if encoding_len > max_len:
480
+ encoding = encoding[0:max_len]
481
+ elif encoding_len < max_len:
482
+ if isinstance(encoding, torch.Tensor):
483
+ encoding = pad_tensor(encoding, pad_token_id, max_len)
484
+ elif isinstance(encoding, list):
485
+ encoding = pad_list(encoding, pad_token_id, max_len)
486
+ return encoding
487
+
488
+
489
+ # pad list of tensors and convert to tensor
490
+ def pad_tensor_list(
491
+ tensor_list,
492
+ dynamic_or_constant,
493
+ pad_token_id,
494
+ model_input_size,
495
+ dim=None,
496
+ padding_func=None,
497
+ ):
498
+ # determine maximum tensor length
499
+ if dynamic_or_constant == "dynamic":
500
+ max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
501
+ elif isinstance(dynamic_or_constant, int):
502
+ max_len = dynamic_or_constant
503
+ else:
504
+ max_len = model_input_size
505
+ logger.warning(
506
+ "If padding style is constant, must provide integer value. "
507
+ f"Setting padding to max input size {model_input_size}."
508
+ )
509
+
510
+ # pad all tensors to maximum length
511
+ if dim is None:
512
+ tensor_list = [
513
+ pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list
514
+ ]
515
+ else:
516
+ tensor_list = [
517
+ padding_func(tensor, pad_token_id, max_len, dim) for tensor in tensor_list
518
+ ]
519
+ # return stacked tensors
520
+ if padding_func != pad_3d_tensor:
521
+ return torch.stack(tensor_list)
522
+ else:
523
+ return torch.cat(tensor_list, 0)
524
+
525
+
526
+ def gen_attention_mask(minibatch_encoding, max_len=None):
527
+ if max_len is None:
528
+ max_len = max(minibatch_encoding["length"])
529
+ original_lens = minibatch_encoding["length"]
530
+ attention_mask = [
531
+ [1] * original_len + [0] * (max_len - original_len)
532
+ if original_len <= max_len
533
+ else [1] * max_len
534
+ for original_len in original_lens
535
+ ]
536
+ return torch.tensor(attention_mask, device="cuda")
537
+
538
+
539
+ # get cell embeddings excluding padding
540
+ def mean_nonpadding_embs(embs, original_lens, dim=1):
541
+ # create a mask tensor based on padding lengths
542
+ mask = torch.arange(embs.size(dim), device=embs.device) < original_lens.unsqueeze(1)
543
+ if embs.dim() == 3:
544
+ # fill the masked positions in embs with zeros
545
+ masked_embs = embs.masked_fill(~mask.unsqueeze(2), 0.0)
546
+
547
+ # compute the mean across the non-padding dimensions
548
+ mean_embs = masked_embs.sum(dim) / original_lens.view(-1, 1).float()
549
+
550
+ elif embs.dim() == 2:
551
+ masked_embs = embs.masked_fill(~mask, 0.0)
552
+ mean_embs = masked_embs.sum(dim) / original_lens.float()
553
+ return mean_embs
554
+
555
+
556
+ # get cell embeddings when there is no padding
557
+ def compute_nonpadded_cell_embedding(embs, cell_emb_style):
558
+ if cell_emb_style == "mean_pool":
559
+ return torch.mean(embs, dim=embs.ndim - 2)
560
+
561
+
562
+ # quantify shifts for a set of genes
563
+ def quant_cos_sims(
564
+ perturbation_emb,
565
+ original_emb,
566
+ cell_states_to_model,
567
+ state_embs_dict,
568
+ emb_mode="gene",
569
+ ):
570
+ if emb_mode == "gene":
571
+ cos = torch.nn.CosineSimilarity(dim=2)
572
+ elif emb_mode == "cell":
573
+ cos = torch.nn.CosineSimilarity(dim=1)
574
+
575
+ if cell_states_to_model is None:
576
+ cos_sims = cos(perturbation_emb, original_emb).to("cuda")
577
+ else:
578
+ possible_states = get_possible_states(cell_states_to_model)
579
+ cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
580
+ for state in possible_states:
581
+ cos_sims[state] = cos_sim_shift(
582
+ original_emb,
583
+ perturbation_emb,
584
+ state_embs_dict[state].to("cuda"), # required to move to cuda here
585
+ cos,
586
+ )
587
+
588
+ return cos_sims
589
+
590
+
591
+ # calculate cos sim shift of perturbation with respect to origin and alternative cell
592
+ def cos_sim_shift(original_emb, perturbed_emb, end_emb, cos):
593
+ origin_v_end = cos(original_emb, end_emb)
594
+ perturb_v_end = cos(perturbed_emb, end_emb)
595
+
596
+ return perturb_v_end - origin_v_end
597
+
598
+
599
+ def concatenate_cos_sims(cos_sims):
600
+ if isinstance(cos_sims, list):
601
+ return torch.cat(cos_sims)
602
+ else:
603
+ for state in cos_sims.keys():
604
+ cos_sims[state] = torch.cat(cos_sims[state])
605
+ return cos_sims
606
+
607
+
608
+ def write_perturbation_dictionary(cos_sims_dict: defaultdict, output_path_prefix: str):
609
+ with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
610
+ pickle.dump(cos_sims_dict, fp)
611
+
612
+
613
+ def tensor_list_to_pd(tensor_list):
614
+ tensor = torch.cat(tensor_list).cpu().numpy()
615
+ df = pd.DataFrame(tensor)
616
+ return df
617
+
618
+
619
+ def validate_cell_states_to_model(cell_states_to_model):
620
+ if cell_states_to_model is not None:
621
+ if len(cell_states_to_model.items()) == 1:
622
+ logger.warning(
623
+ "The single value dictionary for cell_states_to_model will be "
624
+ "replaced with a dictionary with named keys for start, goal, and alternate states. "
625
+ "Please specify state_key, start_state, goal_state, and alt_states "
626
+ "in the cell_states_to_model dictionary for future use. "
627
+ "For example, cell_states_to_model={"
628
+ "'state_key': 'disease', "
629
+ "'start_state': 'dcm', "
630
+ "'goal_state': 'nf', "
631
+ "'alt_states': ['hcm', 'other1', 'other2']}"
632
+ )
633
+ for key, value in cell_states_to_model.items():
634
+ if (len(value) == 3) and isinstance(value, tuple):
635
+ if (
636
+ isinstance(value[0], list)
637
+ and isinstance(value[1], list)
638
+ and isinstance(value[2], list)
639
+ ):
640
+ if len(value[0]) == 1 and len(value[1]) == 1:
641
+ all_values = value[0] + value[1] + value[2]
642
+ if len(all_values) == len(set(all_values)):
643
+ continue
644
+ # reformat to the new named key format
645
+ state_values = flatten_list(list(cell_states_to_model.values()))
646
+
647
+ cell_states_to_model = {
648
+ "state_key": list(cell_states_to_model.keys())[0],
649
+ "start_state": state_values[0][0],
650
+ "goal_state": state_values[1][0],
651
+ "alt_states": state_values[2:][0],
652
+ }
653
+ elif set(cell_states_to_model.keys()).issuperset(
654
+ {"state_key", "start_state", "goal_state"}
655
+ ):
656
+ if (
657
+ (cell_states_to_model["state_key"] is None)
658
+ or (cell_states_to_model["start_state"] is None)
659
+ or (cell_states_to_model["goal_state"] is None)
660
+ ):
661
+ logger.error(
662
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
663
+ )
664
+ raise
665
+
666
+ if (
667
+ cell_states_to_model["start_state"]
668
+ == cell_states_to_model["goal_state"]
669
+ ):
670
+ logger.error("All states must be unique.")
671
+ raise
672
+
673
+ if "alt_states" in set(cell_states_to_model.keys()):
674
+ if cell_states_to_model["alt_states"] is not None:
675
+ if not isinstance(cell_states_to_model["alt_states"], list):
676
+ logger.error(
677
+ "cell_states_to_model['alt_states'] must be a list (even if it is one element)."
678
+ )
679
+ raise
680
+ if len(cell_states_to_model["alt_states"]) != len(
681
+ set(cell_states_to_model["alt_states"])
682
+ ):
683
+ logger.error("All states must be unique.")
684
+ raise
685
+ else:
686
+ cell_states_to_model["alt_states"] = []
687
+
688
+ else:
689
+ logger.error(
690
+ "cell_states_to_model must only have the following four keys: "
691
+ "'state_key', 'start_state', 'goal_state', 'alt_states'."
692
+ "For example, cell_states_to_model={"
693
+ "'state_key': 'disease', "
694
+ "'start_state': 'dcm', "
695
+ "'goal_state': 'nf', "
696
+ "'alt_states': ['hcm', 'other1', 'other2']}"
697
+ )
698
+ raise
setup.py CHANGED
@@ -2,7 +2,7 @@ from setuptools import setup
2
 
3
  setup(
4
  name="geneformer",
5
- version="0.0.1",
6
  author="Christina Theodoris",
7
  author_email="christina.theodoris@gladstone.ucsf.edu",
8
  description="Geneformer is a transformer model pretrained \
2
 
3
  setup(
4
  name="geneformer",
5
+ version="0.1.0",
6
  author="Christina Theodoris",
7
  author_email="christina.theodoris@gladstone.ucsf.edu",
8
  description="Geneformer is a transformer model pretrained \