"""
Geneformer tokenizer.

Input data:
Required format: raw counts scRNAseq data without feature selection as .loom file.
Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene.
Required col (cell) attribute: "n_counts"; total read counts in that cell.
Optional col (cell) attribute: "filter_pass"; binary indicator of whether the cell should be tokenized based on user-defined filtering criteria.
Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.

Usage:
  from geneformer import TranscriptomeTokenizer
  tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
"""

import logging
import pickle
import warnings
from pathlib import Path

# silence the numba "nopython" deprecation warning triggered via loompy's dependencies
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")

import loompy as lp
import numpy as np
from datasets import Dataset

logger = logging.getLogger(__name__)

GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"


def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # mask genes with zero (undetected) expression
    nonzero_mask = np.nonzero(gene_vector)[0]
    # rank the detected genes by descending normalized expression
    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
    # map the ranked genes to their token ids
    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
    return sentence_tokens
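# Worked toy example for tokenize_cell() above (hypothetical values, for
# illustration only): with gene_vector = np.array([0.0, 2.5, 1.0]) and
# gene_tokens = np.array([10, 11, 12]), the nonzero entries are at indices
# 1 and 2; ranking them by descending value yields the token sequence
# np.array([11, 12]).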


class TranscriptomeTokenizer:
    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        Parameters
        ----------
        custom_attr_name_dict : None, dict
            Dictionary of custom attributes to be added to the dataset.
            Keys are the names of the attributes in the loom file.
            Values are the names of the attributes in the dataset.
        nproc : int
            Number of processes to use for dataset mapping.
        gene_median_file : Path
            Path to pickle file containing dictionary of non-zero median
            gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl ID: token).
        """
        # dictionary of custom attributes {loom attribute name: dataset attribute name}
        self.custom_attr_name_dict = custom_attr_name_dict

        # number of processes for dataset mapping
        self.nproc = nproc

        # load dictionary of gene normalization factors
        # (non-zero median expression value of each gene across Genecorpus-30M)
        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl ID: token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # lookup dictionary used to select the vocabulary genes from each .loom file
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
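        # Shape of the two pickled dictionaries loaded above (values are
        # hypothetical, for illustration only):
        #   gene_median_dict: {"ENSG00000000003": 1.7, ...}  # Ensembl ID -> non-zero median expression
        #   gene_token_dict:  {"ENSG00000000003": 42, ...}   # Ensembl ID -> integer token id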

    def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
        """
        Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.

        Parameters
        ----------
        loom_data_directory : Path
            Path to directory containing .loom files
        output_directory : Path
            Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            Prefix for output .dataset
        """
        tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory))
        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)

        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)

    def tokenize_files(self, loom_data_directory):
        tokenized_cells = []
        if self.custom_attr_name_dict is not None:
            loom_cell_attr = list(self.custom_attr_name_dict.keys())
            cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}

        # loop through directory to tokenize .loom files
        file_found = False
        for loom_file_path in loom_data_directory.glob("*.loom"):
            file_found = True
            print(f"Tokenizing {loom_file_path}")
            file_tokenized_cells, file_cell_metadata = self.tokenize_file(loom_file_path)
            tokenized_cells += file_tokenized_cells
            if self.custom_attr_name_dict is not None:
                for k in loom_cell_attr:
                    cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
            else:
                cell_metadata = None

        if not file_found:
            logger.error(f"No .loom files found in directory {loom_data_directory}.")
            raise FileNotFoundError(
                f"No .loom files found in directory {loom_data_directory}."
            )
        return tokenized_cells, cell_metadata

    def tokenize_file(self, loom_file_path):
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }

        with lp.connect(str(loom_file_path)) as data:
            # locate genes that are in the model vocabulary and collect their
            # normalization factors and token ids
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
                    for i in data.ra["ensembl_id"][coding_miRNA_loc]
                ]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # restrict tokenization to cells passing the optional "filter_pass" attribute
            try:
                data.ca["filter_pass"]
            except AttributeError:
                var_exists = False
            else:
                var_exists = True

            if var_exists:
                filter_pass_loc = np.where(data.ca["filter_pass"] == 1)[0]
            else:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.arange(data.shape[1])

            # scan through the .loom file and tokenize cells
            tokenized_cells = []
            for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
                # select rows (genes) that are in the vocabulary
                subview = view.view[coding_miRNA_loc, :]

                # normalize each cell by its total counts, scale to 10,000,
                # then divide each gene by its non-zero median expression
                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * 10_000
                    / norm_factor_vector[:, None]
                )
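                # Worked toy example of the normalization above (hypothetical
                # numbers, for illustration only): a gene with 5 raw counts in a
                # cell with n_counts = 10,000 and a non-zero median of 2.5 becomes
                # 5 / 10000 * 10000 / 2.5 = 2.0 before ranking.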

                # tokenize each cell (column) in the view
                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
                    for i in range(subview_norm_array.shape[1])
                ]

                # carry the requested custom cell attributes over from the view
                if self.custom_attr_name_dict is not None:
                    for k in file_cell_metadata.keys():
                        file_cell_metadata[k] += subview.ca[k].tolist()
                else:
                    file_cell_metadata = None

        return tokenized_cells, file_cell_metadata

    def create_dataset(self, tokenized_cells, cell_metadata):
        # assemble the dict used to build the Hugging Face dataset
        dataset_dict = {"input_ids": tokenized_cells}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(cell_metadata)

        # create dataset
        output_dataset = Dataset.from_dict(dataset_dict)

        # truncate each rank value encoding to the 2048-token model input size
        def truncate(example):
            example["input_ids"] = example["input_ids"][0:2048]
            return example

        output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)

        # record the length of each (truncated) rank value encoding
        def measure_length(example):
            example["length"] = len(example["input_ids"])
            return example

        output_dataset_truncated_w_length = output_dataset_truncated.map(
            measure_length, num_proc=self.nproc
        )

        return output_dataset_truncated_w_length
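
# The .dataset directory written by TranscriptomeTokenizer.tokenize_data can be
# reloaded later with Hugging Face datasets, e.g. (the path is a placeholder
# mirroring the usage example in the module docstring):
#
#     from datasets import load_from_disk
#     dataset = load_from_disk("output_directory/output_prefix.dataset")
#     print(dataset)  # columns: input_ids, length, plus any custom attributes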