jinbo1129 committed
Commit
4e4b58c
1 Parent(s): 4d69e3e

Delete emb_extractor.py

Files changed (1)
  1. emb_extractor.py +0 -493
emb_extractor.py DELETED
@@ -1,493 +0,0 @@
"""
Geneformer embedding extractor.

Usage:
  from geneformer import EmbExtractor
  embex = EmbExtractor(model_type="CellClassifier",
                       num_classes=3,
                       emb_mode="cell",
                       cell_emb_style="mean_pool",
                       filter_data={"cell_type": ["cardiomyocyte"]},
                       max_ncells=1000,
                       emb_layer=-1,
                       emb_label=["disease", "cell_type"],
                       labels_to_plot=["disease", "cell_type"],
                       forward_batch_size=100,
                       nproc=16,
                       summary_stat=None)
  embs = embex.extract_embs("path/to/model",
                            "path/to/input_data",
                            "path/to/output_directory",
                            "output_prefix")
  embex.plot_embs(embs=embs,
                  plot_style="heatmap",
                  output_directory="path/to/output_directory",
                  output_prefix="output_prefix",
                  max_ncells_to_plot=1000)
"""

# imports
import logging
import pickle
from collections import Counter
from pathlib import Path

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
from tdigest import TDigest
from tqdm.notebook import trange
from transformers import (BertForMaskedLM,
                          BertForSequenceClassification,
                          BertForTokenClassification)

from .tokenizer import TOKEN_DICTIONARY_FILE
from .in_silico_perturber import (downsample_and_sort,
                                  gen_attention_mask,
                                  get_model_input_size,
                                  load_and_filter,
                                  load_model,
                                  mean_nonpadding_embs,
                                  pad_tensor_list,
                                  quant_layers)

logger = logging.getLogger(__name__)

# extract embeddings
def get_embs(model,
             filtered_input_data,
             emb_mode,
             layer_to_quant,
             pad_token_id,
             forward_batch_size,
             summary_stat):

    model_input_size = get_model_input_size(model)
    total_batch_length = len(filtered_input_data)

    if summary_stat is None:
        embs_list = []
    else:
        # run embedding extraction on one example cell to get # of emb dims
        example = filtered_input_data.select([0])
        example.set_format(type="torch")
        emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
        # initiate one tdigest per emb dim
        embs_tdigests = [TDigest() for _ in range(emb_dims)]

    for i in trange(0, total_batch_length, forward_batch_size):
        max_range = min(i + forward_batch_size, total_batch_length)

        minibatch = filtered_input_data.select(list(range(i, max_range)))
        max_len = max(minibatch["length"])
        original_lens = torch.tensor(minibatch["length"]).to("cuda")
        minibatch.set_format(type="torch")

        input_data_minibatch = minibatch["input_ids"]
        input_data_minibatch = pad_tensor_list(input_data_minibatch,
                                               max_len,
                                               pad_token_id,
                                               model_input_size)

        with torch.no_grad():
            outputs = model(
                input_ids=input_data_minibatch.to("cuda"),
                attention_mask=gen_attention_mask(minibatch)
            )

        embs_i = outputs.hidden_states[layer_to_quant]

        if emb_mode == "cell":
            # mean-pool gene embeddings over non-padding positions only
            mean_embs = mean_nonpadding_embs(embs_i, original_lens)
            if summary_stat is None:
                embs_list.append(mean_embs)
            else:
                # update tdigests with current batch for each emb dim
                # note: tdigest batch update is known to be slow, so update serially
                for cell in range(mean_embs.size(0)):
                    for dim in range(emb_dims):
                        embs_tdigests[dim].update(mean_embs[cell, dim].item())

        del outputs
        del minibatch
        del input_data_minibatch
        del embs_i
        del mean_embs
        torch.cuda.empty_cache()

    if summary_stat is None:
        embs_stack = torch.cat(embs_list)
    else:
        # calculate summary stat embs from the approximating tdigests
        if summary_stat == "mean":
            summary_emb_list = [embs_tdigests[dim].trimmed_mean(0, 100)
                                for dim in range(emb_dims)]
        elif summary_stat == "median":
            summary_emb_list = [embs_tdigests[dim].percentile(50)
                                for dim in range(emb_dims)]
        embs_stack = torch.tensor(summary_emb_list)

    return embs_stack
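
# Illustrative sketch (comments only, not part of the module's API): the
# per-dimension TDigests above let get_embs stream arbitrarily many cells
# while keeping memory usage flat. Conceptually:
#   digests = [TDigest() for _ in range(emb_dims)]
#   for batch in batches:                # each batch: (n_cells, emb_dims)
#       for cell_emb in batch:
#           for dim, value in enumerate(cell_emb):
#               digests[dim].update(value)
#   median_emb = [d.percentile(50) for d in digests]  # approximate median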

def test_emb(model, example, layer_to_quant):
    with torch.no_grad():
        outputs = model(
            input_ids=example.to("cuda")
        )

    embs_test = outputs.hidden_states[layer_to_quant]
    return embs_test.size()[2]
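
# Note: for Hugging Face BERT models run with output_hidden_states enabled,
# each element of outputs.hidden_states has shape (batch, seq_len, hidden_dim),
# so test_emb returns the model's embedding dimension.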

def label_embs(embs, downsampled_data, emb_labels):
    embs_df = pd.DataFrame(embs.cpu())
    if emb_labels is not None:
        for label in emb_labels:
            emb_label = downsampled_data[label]
            embs_df[label] = emb_label
    return embs_df
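
# The resulting dataframe has the embedding dimensions as integer columns
# 0..(emb_dims - 1) followed by one column per label, e.g. (illustrative):
#   0, 1, ..., 255, disease, cell_type
# plot_embs relies on this layout to split the columns back into embedding
# dimensions and labels.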

def plot_umap(embs_df, emb_dims, label, output_suffix, kwargs_dict):
    only_embs_df = embs_df.iloc[:, :emb_dims]
    only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
    only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(str)
    vars_dict = {"embs": only_embs_df.columns}
    obs_dict = {"cell_id": list(only_embs_df.index),
                f"{label}": list(embs_df[label])}
    adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
    sc.tl.pca(adata, svd_solver="arpack")
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
    sns.set_style("white")
    default_kwargs_dict = {"palette": "Set2", "size": 200}
    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)

    # scanpy appends the string passed to `save` to its default filename and
    # writes the figure under sc.settings.figdir
    sc.pl.umap(adata, color=label, save=output_suffix, **default_kwargs_dict)

def gen_heatmap_class_colors(labels, df):
    pal = sns.cubehelix_palette(len(Counter(labels).keys()), light=0.9, dark=0.1,
                                hue=1, reverse=True, start=1, rot=-2)
    lut = dict(zip(map(str, Counter(labels).keys()), pal))
    colors = pd.Series(labels, index=df.index).map(lut)
    return colors

def gen_heatmap_class_dict(classes, label_colors_series):
    class_color_dict_df = pd.DataFrame({"classes": classes, "color": label_colors_series})
    class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
    return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))

def make_colorbar(embs_df, label):
    labels = list(embs_df[label])

    cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
    label_colors = pd.DataFrame(cell_type_colors, columns=[label])

    # sanity check: report any rows whose RGB color is malformed or missing
    for i, row in label_colors.iterrows():
        colors = row[0]
        if len(colors) != 3 or any(np.isnan(colors)):
            print(i, colors)

    # create dictionary for colors and classes
    label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
    return label_colors, label_color_dict

def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
    # call sns.set before sns.set_style so the white style is not overridden
    sns.set(font_scale=2)
    sns.set_style("white")
    plt.figure(figsize=(15, 15), dpi=150)
    label_colors, label_color_dict = make_colorbar(embs_df, label)

    default_kwargs_dict = {"row_cluster": True,
                           "col_cluster": True,
                           "row_colors": label_colors,
                           "standard_scale": 1,
                           "linewidths": 0,
                           "xticklabels": False,
                           "yticklabels": False,
                           "figsize": (15, 15),
                           "center": 0,
                           "cmap": "magma"}

    if kwargs_dict is not None:
        default_kwargs_dict.update(kwargs_dict)
    g = sns.clustermap(embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict)

    plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")

    # add one invisible bar per class so the legend shows the color mapping
    for label_color in list(label_color_dict.keys()):
        g.ax_col_dendrogram.bar(0, 0, color=label_color_dict[label_color],
                                label=label_color, linewidth=0)

    g.ax_col_dendrogram.legend(title=f"{label}",
                               loc="lower center",
                               ncol=4,
                               bbox_to_anchor=(0.5, 1),
                               facecolor="white")

    plt.savefig(output_file, bbox_inches="tight")

class EmbExtractor:
    valid_option_dict = {
        "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
        "num_classes": {int},
        "emb_mode": {"cell", "gene"},
        "cell_emb_style": {"mean_pool"},
        "filter_data": {None, dict},
        "max_ncells": {None, int},
        "emb_layer": {-1, 0},
        "emb_label": {None, list},
        "labels_to_plot": {None, list},
        "forward_batch_size": {int},
        "nproc": {int},
        "summary_stat": {None, "mean", "median"},
    }
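
    # Note: the sets above intentionally mix literal values (e.g. "cell", -1,
    # None) with types (e.g. int, list, dict). validate_options first checks
    # membership, then falls back to isinstance for the type entries, so e.g.
    # num_classes=3 passes via isinstance(3, int) while emb_mode="tissue"
    # fails outright.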
    def __init__(
        self,
        model_type="Pretrained",
        num_classes=0,
        emb_mode="cell",
        cell_emb_style="mean_pool",
        filter_data=None,
        max_ncells=1000,
        emb_layer=-1,
        emb_label=None,
        labels_to_plot=None,
        forward_batch_size=100,
        nproc=4,
        summary_stat=None,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize embedding extractor.

        Parameters
        ----------
        model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
            Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
        num_classes : int
            If model is a gene or cell classifier, specify number of classes it was trained to classify.
            For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
        emb_mode : {"cell", "gene"}
            Whether to output cell or gene embeddings.
        cell_emb_style : "mean_pool"
            Method for summarizing cell embeddings.
            Currently, the only option is mean pooling of gene embeddings for a given cell.
        filter_data : None, dict
            Default is to extract embeddings from all input data.
            Otherwise, dictionary specifying .dataset column name and list of values to filter by.
        max_ncells : None, int
            Maximum number of cells to extract embeddings from.
            Default is 1000 cells randomly sampled from input data.
            If None, will extract embeddings from all cells.
        emb_layer : {-1, 0}
            Embedding layer to extract.
            The last layer is most specifically weighted to optimize the given learning objective.
            Generally, it is best to extract the 2nd to last layer to get a more general representation.
            -1: 2nd to last layer
            0: last layer
        emb_label : None, list
            List of column name(s) in .dataset to add as labels to embedding output.
        labels_to_plot : None, list
            Cell labels to plot.
            Shown as color bar in heatmap.
            Shown as cell color in umap.
            Plotting umap requires labels to plot.
        forward_batch_size : int
            Batch size for forward pass.
        nproc : int
            Number of CPU processes to use.
        summary_stat : {None, "mean", "median"}
            If not None, outputs only an approximated mean or median embedding of the input data.
            Recommended if encountering memory constraints while generating goal embedding positions.
            Slower but more memory-efficient.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl ID: token).
        """
        self.model_type = model_type
        self.num_classes = num_classes
        self.emb_mode = emb_mode
        self.cell_emb_style = cell_emb_style
        self.filter_data = filter_data
        self.max_ncells = max_ncells
        self.emb_layer = emb_layer
        self.emb_label = emb_label
        self.labels_to_plot = labels_to_plot
        self.forward_batch_size = forward_batch_size
        self.nproc = nproc
        self.summary_stat = summary_stat

        self.validate_options()

        # load token dictionary (Ensembl IDs: token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        self.pad_token_id = self.gene_token_dict.get("<pad>")
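
    # The token dictionary maps Ensembl gene IDs and special tokens to integer
    # token IDs, along the lines of (illustrative entries):
    #   {"<pad>": 0, "<mask>": 1, "ENSG00000141510": 1234, ...}
    # Only the "<pad>" ID is needed here, for padding minibatches.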

    def validate_options(self):
        # first disallow options under development
        if self.emb_mode == "gene":
            logger.error(
                "Extraction and plotting of gene-level embeddings currently under development. "
                "Current valid option for 'emb_mode': 'cell'"
            )
            raise ValueError("Invalid option for 'emb_mode'.")

        # confirm arguments are within valid options and compatible with each other
        for attr_name, valid_options in self.valid_option_dict.items():
            attr_value = self.__dict__[attr_name]
            if type(attr_value) not in {list, dict}:
                if attr_value in valid_options:
                    continue
            valid_type = False
            for option in valid_options:
                if (option in [int, list, dict]) and isinstance(attr_value, option):
                    valid_type = True
                    break
            if valid_type:
                continue
            logger.error(
                f"Invalid option for {attr_name}. "
                f"Valid options for {attr_name}: {valid_options}"
            )
            raise ValueError(f"Invalid option for {attr_name}.")

        if self.filter_data is not None:
            for key, value in self.filter_data.items():
                if not isinstance(value, list):
                    self.filter_data[key] = [value]
                    logger.warning(
                        "Values in filter_data dict must be lists. "
                        f"Changing {key} value to list ([{value}]).")
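
    # Example of the normalization above (illustrative values): constructing
    # EmbExtractor(filter_data={"cell_type": "cardiomyocyte"}, ...) logs a
    # warning and rewrites filter_data to {"cell_type": ["cardiomyocyte"]},
    # so downstream filtering can always iterate over a list of values.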

    def extract_embs(self,
                     model_directory,
                     input_data_file,
                     output_directory,
                     output_prefix):
        """
        Extract embeddings from input data and save as results in output_directory.

        Parameters
        ----------
        model_directory : Path
            Path to directory containing model
        input_data_file : Path
            Path to directory containing .dataset inputs
        output_directory : Path
            Path to directory where embedding data will be saved as csv
        output_prefix : str
            Prefix for output file
        """

        filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
        downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
        model = load_model(self.model_type, self.num_classes, model_directory)
        # emb_layer is -1 (2nd to last layer) or 0 (last layer), so this
        # resolves to the index of the desired hidden layer
        layer_to_quant = quant_layers(model) + self.emb_layer
        embs = get_embs(model,
                        downsampled_data,
                        self.emb_mode,
                        layer_to_quant,
                        self.pad_token_id,
                        self.forward_batch_size,
                        self.summary_stat)

        if self.summary_stat is None:
            embs_df = label_embs(embs, downsampled_data, self.emb_label)
        else:
            # a single summary embedding: one row, one column per dimension
            embs_df = pd.DataFrame(embs.cpu()).T

        # save embeddings to output_path
        output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
        embs_df.to_csv(output_path)

        return embs_df

    def plot_embs(self,
                  embs,
                  plot_style,
                  output_directory,
                  output_prefix,
                  max_ncells_to_plot=1000,
                  kwargs_dict=None):
        """
        Plot embeddings, coloring by provided labels.

        Parameters
        ----------
        embs : pandas.core.frame.DataFrame
            Pandas dataframe containing embeddings output from extract_embs
        plot_style : str
            Style of plot: "heatmap" or "umap"
        output_directory : Path
            Path to directory where plots will be saved as pdf
            (umap plots are written under scanpy's figure directory)
        output_prefix : str
            Prefix for output file
        max_ncells_to_plot : None, int
            Maximum number of cells to plot.
            Default is 1000 cells randomly sampled from embeddings.
            If None, will plot embeddings from all cells.
        kwargs_dict : dict
            Dictionary of kwargs to pass to plotting function.
        """

        if plot_style not in ["heatmap", "umap"]:
            logger.error(
                "Invalid option for 'plot_style'. "
                "Valid options: {'heatmap', 'umap'}"
            )
            raise ValueError("Invalid option for 'plot_style'.")

        # both plot styles color by label, so labels_to_plot is required
        if self.labels_to_plot is None:
            logger.error(
                "Plotting embeddings requires 'labels_to_plot'. "
            )
            raise ValueError("Plotting embeddings requires 'labels_to_plot'.")

        # cap the number of cells to plot, guarding against None on either side
        if (max_ncells_to_plot is not None) \
                and (self.max_ncells is not None) \
                and (max_ncells_to_plot > self.max_ncells):
            max_ncells_to_plot = self.max_ncells
            logger.warning(
                "max_ncells_to_plot must be <= max_ncells. "
                f"Changing max_ncells_to_plot to {self.max_ncells}.")

        if (max_ncells_to_plot is not None) \
                and (max_ncells_to_plot < len(embs)):
            embs = embs.sample(max_ncells_to_plot, axis=0)

        if self.emb_label is None:
            label_len = 0
        else:
            label_len = len(self.emb_label)

        emb_dims = embs.shape[1] - label_len

        if self.emb_label is None:
            emb_labels = None
        else:
            emb_labels = embs.columns[emb_dims:]

        if plot_style == "umap":
            for label in self.labels_to_plot:
                if (emb_labels is None) or (label not in emb_labels):
                    logger.warning(
                        f"Label {label} from labels_to_plot "
                        f"not present in provided embeddings dataframe.")
                    continue
                output_prefix_label = "_" + output_prefix + f"_umap_{label}"
                # scanpy appends this suffix to its default filename and saves
                # the figure under sc.settings.figdir
                plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)

        if plot_style == "heatmap":
            for label in self.labels_to_plot:
                if (emb_labels is None) or (label not in emb_labels):
                    logger.warning(
                        f"Label {label} from labels_to_plot "
                        f"not present in provided embeddings dataframe.")
                    continue
                output_prefix_label = output_prefix + f"_heatmap_{label}"
                output_file = (Path(output_directory) / output_prefix_label).with_suffix(".pdf")
                plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
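
# Illustrative usage of plot_embs (paths and labels are placeholders; see the
# module docstring for the full workflow):
#   embex.plot_embs(embs=embs,
#                   plot_style="umap",
#                   output_directory="path/to/output_directory",
#                   output_prefix="output_prefix",
#                   kwargs_dict={"palette": "tab20"})  # forwarded to sc.pl.umap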