jinbo1129 committed on
Commit
7e129eb
1 Parent(s): 4e4b58c

Upload emb_extractor.py

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +493 -0
geneformer/emb_extractor.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer embedding extractor.
3
+
4
+ Usage:
5
+ from geneformer import EmbExtractor
6
+ embex = EmbExtractor(model_type="CellClassifier",
7
+ num_classes=3,
8
+ emb_mode="cell",
9
+ cell_emb_style="mean_pool",
10
+ filter_data={"cell_type":["cardiomyocyte"]},
11
+ max_ncells=1000,
12
+ emb_layer=-1,
13
+ emb_label=["disease","cell_type"],
14
+ labels_to_plot=["disease","cell_type"],
15
+ forward_batch_size=100,
16
+ nproc=16,
17
+ summary_stat=None)
18
+ embs = embex.extract_embs("path/to/model",
19
+ "path/to/input_data",
20
+ "path/to/output_directory",
21
+ "output_prefix")
22
+ embex.plot_embs(embs=embs,
23
+ plot_style="heatmap",
24
+ output_directory="path/to/output_directory",
25
+ output_prefix="output_prefix",
26
+ max_ncells_to_plot=1000)
27
+
28
+ """
29
+
30
+ # imports
31
+ import logging
32
+ import anndata
33
+ import matplotlib.pyplot as plt
34
+ import numpy as np
35
+ import pandas as pd
36
+ import pickle
37
+ from tdigest import TDigest
38
+ import scanpy as sc
39
+ import seaborn as sns
40
+ import torch
41
+ from collections import Counter
42
+ from pathlib import Path
43
+ from tqdm.notebook import trange
44
+ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
45
+
46
+ from .tokenizer import TOKEN_DICTIONARY_FILE
47
+
48
+ from .in_silico_perturber import downsample_and_sort, \
49
+ gen_attention_mask, \
50
+ get_model_input_size, \
51
+ load_and_filter, \
52
+ load_model, \
53
+ mean_nonpadding_embs, \
54
+ pad_tensor_list, \
55
+ quant_layers
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+ # extract embeddings
60
def get_embs(model,
             filtered_input_data,
             emb_mode,
             layer_to_quant,
             pad_token_id,
             forward_batch_size,
             summary_stat):
    """
    Run the model over the input data in minibatches and collect cell embeddings.

    Parameters
    ----------
    model
        Model whose forward pass returns an object exposing ``hidden_states``.
    filtered_input_data
        Dataset supporting ``len``, ``select`` and ``set_format`` with
        "input_ids" and "length" columns.
    emb_mode : {"cell"}
        Kind of embedding to extract. Only "cell" is implemented.
    layer_to_quant : int
        Index into ``outputs.hidden_states`` of the layer to extract.
    pad_token_id
        Token id used to pad each minibatch.
    forward_batch_size : int
        Number of cells per forward pass.
    summary_stat : {None, "mean", "median"}
        If None, embeddings of every cell are returned.
        Otherwise only the approximated mean or median over all cells is returned.

    Returns
    -------
    torch.Tensor
        If summary_stat is None, the per-batch mean embeddings concatenated along dim 0.
        Otherwise a 1-D tensor holding one summary value per embedding dimension.

    Raises
    ------
    ValueError
        If emb_mode or summary_stat is not a supported option.
    """

    # fail early with a clear message instead of crashing mid-run on an
    # unbound variable after the first forward pass has already been paid for
    if emb_mode != "cell":
        raise ValueError(
            f"Unsupported emb_mode: {emb_mode!r}. Current valid option for 'emb_mode': 'cell'"
        )
    if summary_stat not in (None, "mean", "median"):
        raise ValueError(
            f"Unsupported summary_stat: {summary_stat!r}. Valid options: None, 'mean', 'median'"
        )

    model_input_size = get_model_input_size(model)
    total_batch_length = len(filtered_input_data)

    if summary_stat is None:
        embs_list = []
    else:
        # test embedding extraction for example cell and extract # emb dims
        example = filtered_input_data.select([0])
        example.set_format(type="torch")
        emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
        # initiate tdigests for # of emb dims
        embs_tdigests = [TDigest() for _ in range(emb_dims)]

    for batch_start in trange(0, total_batch_length, forward_batch_size):
        batch_end = min(batch_start + forward_batch_size, total_batch_length)

        minibatch = filtered_input_data.select(list(range(batch_start, batch_end)))
        # lengths are read before switching the format to torch
        max_len = max(minibatch["length"])
        original_lens = torch.tensor(minibatch["length"]).to("cuda")
        minibatch.set_format(type="torch")

        input_data_minibatch = pad_tensor_list(minibatch["input_ids"],
                                               max_len,
                                               pad_token_id,
                                               model_input_size)

        with torch.no_grad():
            outputs = model(
                input_ids=input_data_minibatch.to("cuda"),
                attention_mask=gen_attention_mask(minibatch)
            )

        embs_i = outputs.hidden_states[layer_to_quant]
        mean_embs = mean_nonpadding_embs(embs_i, original_lens)

        if summary_stat is None:
            embs_list.append(mean_embs)
        else:
            # update tdigests with current batch for each emb dim
            # note: tdigest batch update known to be slow so updating serially;
            # the batch is copied to the host once rather than once per scalar
            for cell_emb in mean_embs.cpu().tolist():
                for tdigest, value in zip(embs_tdigests, cell_emb):
                    tdigest.update(value)

        # drop references to per-batch tensors before releasing cached GPU memory
        del outputs
        del minibatch
        del input_data_minibatch
        del embs_i
        del mean_embs
        torch.cuda.empty_cache()

    if summary_stat is None:
        return torch.cat(embs_list)

    # calculate summary stat embs from approximated tdigests
    if summary_stat == "mean":
        summary_emb_list = [tdigest.trimmed_mean(0, 100) for tdigest in embs_tdigests]
    else:
        summary_emb_list = [tdigest.percentile(50) for tdigest in embs_tdigests]
    return torch.tensor(summary_emb_list)
130
+
131
def test_emb(model, example, layer_to_quant):
    """
    Run one forward pass on an example and report the size of its hidden state.

    Parameters
    ----------
    model
        Model whose forward pass returns an object exposing ``hidden_states``.
    example
        Input ids for the example; moved to "cuda" before the forward pass.
    layer_to_quant : int
        Index into ``outputs.hidden_states`` of the layer to inspect.

    Returns
    -------
    int
        Size of dimension 2 of the selected hidden state
        (used by get_embs as the number of embedding dimensions).
    """
    ids_on_gpu = example.to("cuda")

    # inference only, so no gradients are tracked
    with torch.no_grad():
        model_output = model(input_ids=ids_on_gpu)

    hidden_state = model_output.hidden_states[layer_to_quant]
    return hidden_state.size()[2]
139
+
140
def label_embs(embs, downsampled_data, emb_labels):
    """
    Build a dataframe of embeddings and append the requested label columns.

    Parameters
    ----------
    embs
        Embeddings tensor; moved to CPU before conversion to a dataframe.
    downsampled_data
        Mapping-like object indexed by label name to obtain each label column.
    emb_labels : None, list
        Names of the label columns to append. If None, no labels are added.

    Returns
    -------
    pandas.DataFrame
        Embedding columns followed by one column per requested label.
    """
    labeled_df = pd.DataFrame(embs.cpu())

    # nothing to append when no labels were requested
    if emb_labels is None:
        return labeled_df

    for column_name in emb_labels:
        labeled_df[column_name] = downsampled_data[column_name]
    return labeled_df
147
+
148
def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """
    Compute a UMAP of the embeddings and plot it colored by the given label.

    Parameters
    ----------
    embs_df : pandas.DataFrame
        Dataframe whose first ``emb_dims`` columns are embeddings and which
        also contains the ``label`` column.
    emb_dims : int
        Number of leading columns in embs_df holding embedding values.
    label : str
        Column of embs_df used to color the cells.
    output_file
        Value forwarded as ``save`` to ``sc.pl.umap``.
    kwargs_dict : None, dict
        Extra keyword arguments for ``sc.pl.umap``; they override the defaults.
    """
    # keep only the embedding columns and name rows/columns with strings
    emb_matrix = embs_df.iloc[:, :emb_dims]
    n_cells, n_dims = emb_matrix.shape
    emb_matrix.index = pd.RangeIndex(0, n_cells, name=None).astype(str)
    emb_matrix.columns = pd.RangeIndex(0, n_dims, name=None).astype(str)

    var_annotations = {"embs": emb_matrix.columns}
    obs_annotations = {"cell_id": list(emb_matrix.index),
                       f"{label}": list(embs_df[label])}
    adata = anndata.AnnData(X=emb_matrix, obs=obs_annotations, var=var_annotations)

    # PCA -> neighborhood graph -> UMAP
    sc.tl.pca(adata, svd_solver='arpack')
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)

    sns.set(rc={'figure.figsize':(10,10)}, font_scale=2.3)
    sns.set_style("white")

    # caller-supplied kwargs take precedence over the defaults
    umap_kwargs = {"palette":"Set2", "size":200}
    if kwargs_dict is not None:
        umap_kwargs.update(kwargs_dict)

    sc.pl.umap(adata, color=label, save=output_file, **umap_kwargs)
166
+
167
def gen_heatmap_class_colors(labels, df):
    """
    Assign one palette color to every distinct label.

    Parameters
    ----------
    labels : list
        Label of each row, in the same order as the rows of df.
    df : pandas.DataFrame
        Dataframe whose index is used for the returned series.

    Returns
    -------
    pandas.Series
        Color of each row, indexed like df.
    """
    # distinct labels in order of first appearance
    unique_labels = list(dict.fromkeys(labels))
    pal = sns.cubehelix_palette(len(unique_labels), light=0.9, dark=0.1, hue=1, reverse=True, start=1, rot=-2)
    # key the lookup by the original label values: keying by str(label) made
    # every non-string label (e.g. an integer class id) map to NaN
    lut = dict(zip(unique_labels, pal))
    return pd.Series(labels, index=df.index).map(lut)
172
+
173
def gen_heatmap_class_dict(classes, label_colors_series):
    """
    Build a dictionary mapping each class to its color.

    Parameters
    ----------
    classes
        Class of each row.
    label_colors_series
        Color of each row.

    Returns
    -------
    dict
        One entry per distinct class; the color of its first occurrence is kept.
    """
    unique_pairs = pd.DataFrame(
        {"classes": classes, "color": label_colors_series}
    ).drop_duplicates(subset=["classes"])
    return {
        class_name: color
        for class_name, color in zip(unique_pairs["classes"], unique_pairs["color"])
    }
177
+
178
def make_colorbar(embs_df, label):
    """
    Create the row colors and the class-to-color dictionary for a heatmap.

    Parameters
    ----------
    embs_df : pandas.DataFrame
        Dataframe containing the ``label`` column.
    label : str
        Name of the column whose values determine the colors.

    Returns
    -------
    label_colors : pandas.DataFrame
        Single-column dataframe (named ``label``) with the color of each row.
    label_color_dict : dict
        Mapping of each distinct class to its color.
    """

    labels = list(embs_df[label])

    cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
    label_colors = pd.DataFrame(cell_type_colors, columns=[label])

    # sanity check: every row should carry a 3-value color without NaNs
    for row_id, color in label_colors[label].items():
        try:
            malformed = len(color) != 3 or any(np.isnan(color))
        except TypeError:
            # a scalar (e.g. NaN for a label without a color) has no len()
            malformed = True
        if malformed:
            logger.warning(f"Unexpected color for row {row_id}: {color}")

    # create dictionary for colors and classes
    label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
    return label_colors, label_color_dict
195
+
196
def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
    """
    Plot a clustered heatmap of the embeddings with a color bar for the label.

    Parameters
    ----------
    embs_df : pandas.DataFrame
        Dataframe whose first ``emb_dims`` columns are embeddings and which
        also contains the ``label`` column.
    emb_dims : int
        Number of leading columns in embs_df holding embedding values.
    label : str
        Column of embs_df shown as row colors and in the legend.
    output_file
        Path the figure is saved to.
    kwargs_dict : None, dict
        Extra keyword arguments for ``sns.clustermap``; they override the defaults.
    """
    sns.set_style("white")
    sns.set(font_scale=2)
    # no explicit plt.figure() here: clustermap creates its own figure, so a
    # figure opened beforehand would stay empty and never be closed
    label_colors, label_color_dict = make_colorbar(embs_df, label)

    default_kwargs_dict = {"row_cluster": True,
                           "col_cluster": True,
                           "row_colors": label_colors,
                           "standard_scale": 1,
                           "linewidths": 0,
                           "xticklabels": False,
                           "yticklabels": False,
                           "figsize": (15,15),
                           "center": 0,
                           "cmap": "magma"}

    # caller-supplied kwargs take precedence over the defaults
    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)
    g = sns.clustermap(embs_df.iloc[:,0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # zero-size bars only serve to register one legend entry per class
    for class_name, class_color in label_color_dict.items():
        g.ax_col_dendrogram.bar(0, 0, color=class_color, label=class_name, linewidth=0)

    g.ax_col_dendrogram.legend(title=f"{label}",
                               loc="lower center",
                               ncol=4,
                               bbox_to_anchor=(0.5, 1),
                               facecolor="white")

    plt.savefig(output_file, bbox_inches='tight')
229
+
230
class EmbExtractor:
    """
    Extract cell embeddings with a Geneformer model and plot them.
    """

    # Valid options per attribute. A type entry (int, list, dict) accepts any
    # instance of that type; every other entry must match the value exactly.
    valid_option_dict = {
        "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
        "num_classes": {int},
        "emb_mode": {"cell","gene"},
        "cell_emb_style": {"mean_pool"},
        "filter_data": {None, dict},
        "max_ncells": {None, int},
        "emb_layer": {-1, 0},
        "emb_label": {None, list},
        "labels_to_plot": {None, list},
        "forward_batch_size": {int},
        "nproc": {int},
        "summary_stat": {None, "mean", "median"},
    }

    def __init__(
        self,
        model_type="Pretrained",
        num_classes=0,
        emb_mode="cell",
        cell_emb_style="mean_pool",
        filter_data=None,
        max_ncells=1000,
        emb_layer=-1,
        emb_label=None,
        labels_to_plot=None,
        forward_batch_size=100,
        nproc=4,
        summary_stat=None,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize embedding extractor.

        Parameters
        ----------
        model_type : {"Pretrained","GeneClassifier","CellClassifier"}
            Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
        num_classes : int
            If model is a gene or cell classifier, specify number of classes it was trained to classify.
            For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
        emb_mode : {"cell","gene"}
            Whether to output cell or gene embeddings.
            Gene embeddings are under development and currently rejected.
        cell_emb_style : "mean_pool"
            Method for summarizing cell embeddings.
            Currently only option is mean pooling of gene embeddings for given cell.
        filter_data : None, dict
            Default is to extract embeddings from all input data.
            Otherwise, dictionary specifying .dataset column name and list of values to filter by.
        max_ncells : None, int
            Maximum number of cells to extract embeddings from.
            Default is 1000 cells randomly sampled from input data.
            If None, will extract embeddings from all cells.
        emb_layer : {-1, 0}
            Embedding layer to extract.
            The last layer is most specifically weighted to optimize the given learning objective.
            Generally, it is best to extract the 2nd to last layer to get a more general representation.
            -1: 2nd to last layer
            0: last layer
        emb_label : None, list
            List of column name(s) in .dataset to add as labels to embedding output.
        labels_to_plot : None, list
            Cell labels to plot.
            Shown as color bar in heatmap.
            Shown as cell color in umap.
            Plotting requires labels to plot.
        forward_batch_size : int
            Batch size for forward pass.
        nproc : int
            Number of CPU processes to use.
        summary_stat : {None, "mean", "median"}
            If not None, outputs only approximated mean or median embedding of input data.
            Recommended if encountering memory constraints while generating goal embedding positions.
            Slower but more memory-efficient.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl ID:token).

        Raises
        ------
        ValueError
            If any option is invalid.
        """

        self.model_type = model_type
        self.num_classes = num_classes
        self.emb_mode = emb_mode
        self.cell_emb_style = cell_emb_style
        self.filter_data = filter_data
        self.max_ncells = max_ncells
        self.emb_layer = emb_layer
        self.emb_label = emb_label
        self.labels_to_plot = labels_to_plot
        self.forward_batch_size = forward_batch_size
        self.nproc = nproc
        self.summary_stat = summary_stat

        self.validate_options()

        # load token dictionary (Ensembl IDs:token)
        # NOTE: unpickling can execute arbitrary code; only pass a trusted file here
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        self.pad_token_id = self.gene_token_dict.get("<pad>")

    def validate_options(self):
        """
        Check that all options are valid and normalize filter_data values to lists.

        Raises
        ------
        ValueError
            If an option is under development or not among its valid options.
        """
        # first disallow options under development
        if self.emb_mode == "gene":
            msg = (
                "Extraction and plotting of gene-level embeddings currently under development. "
                "Current valid option for 'emb_mode': 'cell'"
            )
            logger.error(msg)
            # a bare `raise` has no active exception to re-raise here
            raise ValueError(msg)

        # confirm arguments are within valid options and compatible with each other
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = getattr(self, attr_name)

            # lists and dicts are unhashable, so they are only checked by type below
            if not isinstance(attr_value, (list, dict)):
                try:
                    if attr_value in valid_options:
                        continue
                except TypeError:
                    # other unhashable values cannot be a literal option
                    pass

            valid_type = any(
                (option in [int, list, dict]) and isinstance(attr_value, option)
                for option in valid_options
            )
            if valid_type:
                continue

            msg = (
                f"Invalid option for {attr_name}. "
                f"Valid options for {attr_name}: {valid_options}"
            )
            logger.error(msg)
            raise ValueError(msg)

        if self.filter_data is not None:
            # only values of existing keys are replaced, so iterating items() is safe
            for key, value in self.filter_data.items():
                if not isinstance(value, list):
                    self.filter_data[key] = [value]
                    logger.warning(
                        "Values in filter_data dict must be lists. "
                        f"Changing {key} value to list ([{value}]).")

    def extract_embs(self,
                     model_directory,
                     input_data_file,
                     output_directory,
                     output_prefix):
        """
        Extract embeddings from input data and save as results in output_directory.

        Parameters
        ----------
        model_directory : Path
            Path to directory containing model
        input_data_file : Path
            Path to directory containing .dataset inputs
        output_directory : Path
            Path to directory where embedding data will be saved as csv
        output_prefix : str
            Prefix for output file

        Returns
        -------
        pandas.DataFrame
            Embeddings (with label columns if summary_stat is None and emb_label was given),
            identical to what is written to the csv file.
        """

        filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
        downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
        model = load_model(self.model_type, self.num_classes, model_directory)
        layer_to_quant = quant_layers(model) + self.emb_layer
        embs = get_embs(model,
                        downsampled_data,
                        self.emb_mode,
                        layer_to_quant,
                        self.pad_token_id,
                        self.forward_batch_size,
                        self.summary_stat)

        if self.summary_stat is None:
            embs_df = label_embs(embs, downsampled_data, self.emb_label)
        else:
            # summary embedding is 1-D; transpose so it is saved as a single row
            embs_df = pd.DataFrame(embs.cpu()).T

        # save embeddings to output_path
        # NOTE: with_suffix replaces anything after the last "." in output_prefix
        output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
        embs_df.to_csv(output_path)

        return embs_df

    def plot_embs(self,
                  embs,
                  plot_style,
                  output_directory,
                  output_prefix,
                  max_ncells_to_plot=1000,
                  kwargs_dict=None):

        """
        Plot embeddings, coloring by provided labels.

        Parameters
        ----------
        embs : pandas.core.frame.DataFrame
            Pandas dataframe containing embeddings output from extract_embs
        plot_style : str
            Style of plot: "heatmap" or "umap"
        output_directory : Path
            Path to directory where plots will be saved as pdf
        output_prefix : str
            Prefix for output file
        max_ncells_to_plot : None, int
            Maximum number of cells to plot.
            Default is 1000 cells randomly sampled from embeddings.
            If None, will plot embeddings from all cells.
        kwargs_dict : dict
            Dictionary of kwargs to pass to plotting function.

        Raises
        ------
        ValueError
            If plot_style is invalid or labels_to_plot was not provided.
        """

        if plot_style not in ["heatmap","umap"]:
            msg = (
                "Invalid option for 'plot_style'. "
                "Valid options: {'heatmap','umap'}"
            )
            logger.error(msg)
            raise ValueError(msg)

        # both plot styles iterate over labels_to_plot, so it must be provided
        if self.labels_to_plot is None:
            if plot_style == "umap":
                msg = "Plotting UMAP requires 'labels_to_plot'. "
            else:
                msg = "Plotting heatmap requires 'labels_to_plot'. "
            logger.error(msg)
            raise ValueError(msg)

        # None means "no limit" for either value, so only compare real numbers
        if (max_ncells_to_plot is not None) \
            and (self.max_ncells is not None) \
            and (max_ncells_to_plot > self.max_ncells):
            max_ncells_to_plot = self.max_ncells
            logger.warning(
                "max_ncells_to_plot must be <= max_ncells. "
                f"Changing max_ncells_to_plot to {self.max_ncells}.")

        # downsample only when there are more rows than requested;
        # sampling more rows than available would raise
        if (max_ncells_to_plot is not None) \
            and (max_ncells_to_plot < embs.shape[0]):
            embs = embs.sample(max_ncells_to_plot, axis=0)

        if self.emb_label is None:
            label_len = 0
        else:
            label_len = len(self.emb_label)

        # label columns follow the embedding columns in the dataframe
        emb_dims = embs.shape[1] - label_len

        if self.emb_label is None:
            emb_labels = []
        else:
            emb_labels = embs.columns[emb_dims:]

        for label in self.labels_to_plot:
            if label not in emb_labels:
                logger.warning(
                    f"Label {label} from labels_to_plot "
                    f"not present in provided embeddings dataframe.")
                continue

            if plot_style == "umap":
                # NOTE(review): the prefix string (not a path under output_directory)
                # is what gets handed to plot_umap, which forwards it as `save` to
                # sc.pl.umap; output_directory is therefore not used for umap plots.
                output_prefix_label = "_" + output_prefix + f"_umap_{label}"
                plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
            else:
                output_prefix_label = output_prefix + f"_heatmap_{label}"
                output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
                plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)