Upload tokenizer.py

#99
by giovp - opened
Files changed (1)
  1. tokenizer.py +239 -0
tokenizer.py ADDED
@@ -0,0 +1,239 @@
+"""
+Geneformer tokenizer.
+
+Input data:
+Required format: raw counts scRNAseq data without feature selection as .loom or .h5ad file
+Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene
+Required col (cell) attribute: "n_counts"; total read counts in that cell
+Optional col (cell) attribute: "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria
+Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below
+
+Usage:
+  from geneformer import TranscriptomeTokenizer
+  tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
+  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
+"""
+
+
+from __future__ import annotations
+from typing import Literal
+import pickle
+from pathlib import Path
+
+import loompy as lp
+import numpy as np
+from datasets import Dataset
+
+GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
+TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
+
+
+def tokenize_cell(gene_vector, gene_tokens):
+    """
+    Convert normalized gene expression vector to tokenized rank value encoding.
+    """
+    # create array of gene vector with token indices
+    # mask undetected genes
+    nonzero_mask = np.nonzero(gene_vector)[0]
+    # sort by median-scaled gene values (descending)
+    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
+    # tokenize
+    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
+    return sentence_tokens
+
+
+class TranscriptomeTokenizer:
+    def __init__(
+        self,
+        custom_attr_name_dict,
+        nproc=1,
+        gene_median_file=GENE_MEDIAN_FILE,
+        token_dictionary_file=TOKEN_DICTIONARY_FILE,
+    ):
+        """
+        Initialize tokenizer.
+
+        Parameters
+        ----------
+        custom_attr_name_dict : dict
+            Dictionary of custom attributes to be added to the dataset.
+            Keys are the names of the attributes in the loom file.
+            Values are the names of the attributes in the dataset.
+        nproc : int
+            Number of processes to use for dataset mapping.
+        gene_median_file : Path
+            Path to pickle file containing dictionary of non-zero median
+            gene expression values across Genecorpus-30M.
+        token_dictionary_file : Path
+            Path to pickle file containing token dictionary (Ensembl IDs:token).
+        """
+        # dictionary of custom attributes {input .loom column name: output dataset column name}
+        self.custom_attr_name_dict = custom_attr_name_dict
+
+        # number of processes for dataset mapping
+        self.nproc = nproc
+
+        # load dictionary of gene normalization factors
+        # (non-zero median value of expression across Genecorpus-30M)
+        with open(gene_median_file, "rb") as f:
+            self.gene_median_dict = pickle.load(f)
+
+        # load token dictionary (Ensembl IDs:token)
+        with open(token_dictionary_file, "rb") as f:
+            self.gene_token_dict = pickle.load(f)
+
+        # gene keys for full vocabulary
+        self.gene_keys = list(self.gene_median_dict.keys())
+
+        # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
+        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
+
+    def tokenize_data(
+        self,
+        data_directory: Path | str,
+        output_directory: Path | str,
+        output_prefix: str,
+        file_format: Literal["loom", "h5ad"] = "loom",
+    ):
+        """
+        Tokenize .loom or .h5ad files in data_directory and save as tokenized .dataset in output_directory.
+
+        Parameters
+        ----------
+        data_directory : Path
+            Path to directory containing loom files or anndata files
+        output_directory : Path
+            Path to directory where tokenized data will be saved as .dataset
+        output_prefix : str
+            Prefix for output .dataset
+        file_format : str
+            Format of input files. Can be "loom" or "h5ad".
+        """
+        tokenized_cells, cell_metadata = self.tokenize_files(Path(data_directory), file_format)
+        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
+
+        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
+        tokenized_dataset.save_to_disk(output_path)
+
+    def tokenize_files(self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"):
+        tokenized_cells = []
+        loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
+        cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
+
+        # loops through directories to tokenize .loom or .h5ad files
+        tokenize_file_fn = self.tokenize_file if file_format == "loom" else self.tokenize_anndata
+        for file_path in data_directory.glob("*.{}".format(file_format)):
+            print(f"Tokenizing {file_path}")
+            file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
+            tokenized_cells += file_tokenized_cells
+            for k in loom_cell_attr:
+                cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
+
+        return tokenized_cells, cell_metadata
+
+    def tokenize_anndata(self, adata_file_path):
+        import anndata as ad
+
+        adata = ad.read(adata_file_path)
+        file_cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.keys()}
+
+        # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
+        coding_miRNA_loc = np.where([self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]])[0]
+        norm_factor_vector = np.array([self.gene_median_dict[i] for i in adata.var["ensembl_id"][coding_miRNA_loc]])
+        coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
+        coding_miRNA_tokens = np.array([self.gene_token_dict[i] for i in coding_miRNA_ids])
+
+        # define coordinates of cells passing filters for inclusion (e.g. QC)
+        try:
+            adata.obs["filter_pass"]
+        except KeyError:
+            var_exists = False
+        else:
+            var_exists = True
+
+        if var_exists is True:
+            filter_pass_loc = np.where([True if i == 1 else False for i in adata.obs["filter_pass"]])[0]
+        elif var_exists is False:
+            print(f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells.")
+            filter_pass_loc = np.array([i for i in range(adata.shape[0])])
+
+        tokenized_cells = []
+        # subset to cells passing filters (rows) and detected protein-coding/miRNA genes (columns)
+        adata_filter = adata[filter_pass_loc, coding_miRNA_loc]
+        # total counts per cell across all genes (anndata equivalent of the loom "n_counts" attribute)
+        n_counts = np.asarray(adata[filter_pass_loc, :].X.sum(axis=1)).flatten()
+        # normalize by total counts per cell, multiply by 10,000 to allocate bits to precision,
+        # and normalize by gene normalization factors
+        X_norm = np.asarray(adata_filter.X / n_counts[:, None] * 10_000 / norm_factor_vector)
+
+        # tokenize cell gene vectors
+        tokenized_cells += [
+            tokenize_cell(X_norm[i, :], coding_miRNA_tokens) for i in range(X_norm.shape[0])
+        ]
+
+        # add custom attributes for subview to dict
+        for k in file_cell_metadata.keys():
+            file_cell_metadata[k] += adata_filter.obs[k].tolist()
+
+        return tokenized_cells, file_cell_metadata
+
+    def tokenize_file(self, loom_file_path):
+        file_cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.keys()}
+
+        with lp.connect(str(loom_file_path)) as data:
+            # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
+            coding_miRNA_loc = np.where([self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]])[0]
+            norm_factor_vector = np.array([self.gene_median_dict[i] for i in data.ra["ensembl_id"][coding_miRNA_loc]])
+            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
+            coding_miRNA_tokens = np.array([self.gene_token_dict[i] for i in coding_miRNA_ids])
+
+            # define coordinates of cells passing filters for inclusion (e.g. QC)
+            try:
+                data.ca["filter_pass"]
+            except AttributeError:
+                var_exists = False
+            else:
+                var_exists = True
+
+            if var_exists is True:
+                filter_pass_loc = np.where([True if i == 1 else False for i in data.ca["filter_pass"]])[0]
+            elif var_exists is False:
+                print(f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells.")
+                filter_pass_loc = np.array([i for i in range(data.shape[1])])
+
+            # scan through .loom files and tokenize cells
+            tokenized_cells = []
+            for _ix, _selection, view in data.scan(items=filter_pass_loc, axis=1):
+                # select subview with protein-coding and miRNA genes
+                subview = view.view[coding_miRNA_loc, :]
+
+                # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
+                # and normalize by gene normalization factors
+                subview_norm_array = subview[:, :] / subview.ca.n_counts * 10_000 / norm_factor_vector[:, None]
+                # tokenize subview gene vectors
+                tokenized_cells += [
+                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
+                    for i in range(subview_norm_array.shape[1])
+                ]
+
+                # add custom attributes for subview to dict
+                for k in file_cell_metadata.keys():
+                    file_cell_metadata[k] += subview.ca[k].tolist()
+
+        return tokenized_cells, file_cell_metadata
+
+    def create_dataset(self, tokenized_cells, cell_metadata):
+        # create dict for dataset creation
+        dataset_dict = {"input_ids": tokenized_cells}
+        dataset_dict.update(cell_metadata)
+
+        # create dataset
+        output_dataset = Dataset.from_dict(dataset_dict)
+
+        # truncate dataset
+        def truncate(example):
+            example["input_ids"] = example["input_ids"][0:2048]
+            return example
+
+        output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
+
+        # measure lengths of dataset
+        def measure_length(example):
+            example["length"] = len(example["input_ids"])
+            return example
+
+        output_dataset_truncated_w_length = output_dataset_truncated.map(measure_length, num_proc=self.nproc)
+
+        return output_dataset_truncated_w_length
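
For reference, a minimal sketch of preparing an input .loom file that satisfies the attributes listed in the module docstring and then running the tokenizer. The AnnData object, its column names (`passed_qc`, `cell_type`), and all file paths are hypothetical placeholders; the raw-counts matrix is assumed to be cells x genes as is conventional for AnnData.

```python
import loompy as lp
import numpy as np
import scanpy as sc

from geneformer import TranscriptomeTokenizer

# hypothetical raw-counts AnnData (cells x genes) with Ensembl IDs in .var
adata = sc.read_h5ad("my_raw_counts.h5ad")

row_attrs = {"ensembl_id": adata.var["ensembl_id"].values}  # required gene attribute
col_attrs = {
    "n_counts": np.asarray(adata.X.sum(axis=1)).flatten(),  # required: total read counts per cell
    "filter_pass": adata.obs["passed_qc"].astype(int).values,  # optional binary QC indicator
    "cell_type": adata.obs["cell_type"].values,  # optional metadata to carry into the dataset
}

# loompy stores genes x cells, hence the transpose of the AnnData matrix
lp.create("loom_data_directory/my_data.loom", adata.X.T, row_attrs, col_attrs)

# map the loom attribute "cell_type" to a "cell_type" column in the output dataset
tk = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=4)
tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
```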
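
And a small toy illustration of the rank value encoding produced by `tokenize_cell`, followed by loading the saved output with `datasets`. The import path `geneformer.tokenizer` is an assumption based on this file's name, the output path mirrors the defaults used above, and the numbers are made up.

```python
import numpy as np
from datasets import load_from_disk

from geneformer.tokenizer import tokenize_cell  # assumes the module is importable under this name

# toy median-normalized expression values for 4 genes and their token ids
gene_vector = np.array([0.0, 2.5, 0.1, 1.7])
gene_tokens = np.array([10, 11, 12, 13])

# the undetected gene (index 0) is dropped; the rest are ordered by descending value
print(tokenize_cell(gene_vector, gene_tokens))  # -> [11 13 12]

# the tokenized corpus is a Hugging Face dataset with "input_ids", "length", and any custom columns
ds = load_from_disk("output_directory/output_prefix.dataset")
print(ds[0]["input_ids"][:10], ds[0]["length"])
```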