"""
Module: tokenization.py

This module provides a tokenization pipeline for preprocessed single-cell RNA
sequencing (scRNA-seq) data. It converts gene expression data stored in AnnData
format into tokenized sequences that can be used for downstream machine learning
tasks, such as masked language modeling or classification.

Main Features:
- Tokenizes gene expression data into integer tokens using a custom GeneTokenizer.
- Supports additional biological annotations (e.g., disease, tissue, cell type, sex).
- Handles both top-k and random gene selection for tokenization.
- Configurable via JSON-based hyperparameters or TokenizationArgs objects.
- Saves tokenized data in Hugging Face Dataset format for efficient processing.

Dependencies:
- anndata, numpy, torch, datasets, tqdm

Usage:
- Run this script as a standalone program with a configuration file specifying the
  hyperparameters.
- Import the `tokenize` function and call it with the data path, metadata path, and
  tokenization arguments, as in the example below.
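
Example (a minimal sketch; the module path, file paths, and argument values shown
here are illustrative assumptions, not a canonical configuration):

    from teddy.tokenizer.tokenization_args import TokenizationArgs
    from teddy.tokenizer.tokenization import tokenize  # assumed module path

    args = TokenizationArgs(
        load_dir="/data/raw",
        save_dir="/data/tokenized",
        tokenizer_name_or_path="/models/gene_tokenizer",
        max_seq_len=2048,
        add_cls=True,
    )
    tokenize(
        data_path="/data/raw/sample.h5ad",
        metadata_path="/data/raw/sample_metadata.json",
        tokenization_args=args,
    )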
"""

import gc
import json
import os
import random
import shutil
from argparse import ArgumentParser
from typing import Union

import anndata as ad
import numpy as np
import torch
from datasets import Dataset, load_from_disk
from tqdm import tqdm

from teddy.tokenizer.gene_tokenizer import GeneTokenizer
from teddy.tokenizer.tokenization_args import TokenizationArgs


def _bin_values(vals_list, tokenization_args, no_sorting=False):
    """
    Bins expression values into ``tokenization_args.bins`` bins, assigning bin 0 to
    non-expressed genes when `include_zero_genes` is True.

    no_sorting=False => "positional chunk" binning for top-k-sorted arrays; the input
        expression values are expected to be sorted in descending order by top-k.
    no_sorting=True  => plain bucketize binning that ignores the top-k order; the input
        values (e.g., labels) are not assumed to be sorted.
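
    Example (a minimal sketch; ``args`` stands for any tokenization-args object with
    ``bins=3``):

        vals = [torch.tensor([9.0, 7.0, 5.0, 3.0, 2.0, 1.0])]
        _bin_values(vals, args, no_sorting=False)[0]
        # -> tensor([2, 2, 1, 1, 0, 0]): the six top-k-sorted values are split into
        #    three positional chunks, with the highest-expressed positions in the
        #    highest bin.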
    """
    binned_vals = []
    for vals in vals_list:
        if isinstance(vals, np.ndarray):
            vals = torch.tensor(vals)

        vals_to_bin = vals

        if not no_sorting:
            # Positional binning: split the top-k-sorted values into `bins` equal
            # chunks and assign the highest bin to the highest-expressed positions.
            num_repetitions = max(1, len(vals_to_bin) // tokenization_args.bins)
            bin_pattern = torch.arange(0, tokenization_args.bins).unsqueeze(1).repeat(1, num_repetitions).flatten()

            if len(bin_pattern) > len(vals_to_bin):
                bin_pattern = bin_pattern[-len(vals_to_bin):]
            else:
                extra = len(vals_to_bin) - len(bin_pattern)
                if extra > 0:
                    # Pad with zeros so the leftover (lowest-expressed) positions
                    # end up in bin 0 after the flip below.
                    bin_pattern = torch.cat([torch.zeros(extra), bin_pattern])
            bin_pattern = bin_pattern.flip(0)

            binned_vals.append(bin_pattern)
        else:
            if len(vals_to_bin) > 0:
                # Value-based binning: bucketize into equally spaced edges between the
                # minimum and maximum, clamping to 1 so that bin 0 stays reserved.
                bin_edges = torch.linspace(vals_to_bin.min(), vals_to_bin.max(), steps=tokenization_args.bins + 1)
                binned_non_zero_vals = torch.bucketize(vals_to_bin, bin_edges)
                binned_non_zero_vals = torch.clamp(binned_non_zero_vals, min=1)
                binned_vals.append(binned_non_zero_vals.float())
            else:
                binned_vals.append(torch.zeros_like(vals_to_bin, dtype=torch.float))
    return binned_vals


def _rank_continuous(vals, tokenization_args):
    """
    Maps gene expression values onto evenly spaced ranks in the range [-1, 1], with the
    first (highest-expressed) value mapped to 1 and the last to -1. The
    ``tokenization_args`` parameter is accepted for interface consistency but unused.
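
    Example (a minimal sketch; only the number of values matters, since the input is
    assumed to be already sorted by expression):

        _rank_continuous(torch.tensor([9.0, 7.0, 5.0, 3.0, 1.0]), args)
        # -> tensor([1.0, 0.5, 0.0, -0.5, -1.0])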
    """
    if isinstance(vals, np.ndarray):
        vals = torch.tensor(vals)

    if len(vals) > 0:
        # Evenly spaced ranks from 1 (first / highest expression) down to -1 (last).
        ranked_vals = torch.linspace(-1, 1, steps=len(vals)).flip(0)
    else:
        ranked_vals = vals
    return ranked_vals


def _prepare_tokenizer_args(tokenization_args: Union[dict, TokenizationArgs]):
    """
    Prepares and validates tokenization arguments, and makes runs reproducible by
    seeding the Python, NumPy, and PyTorch RNGs when `gene_seed` is set.
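
    Example (a minimal sketch; apart from ``load_dir`` and ``save_dir``, the keys are
    assumptions about the TokenizationArgs schema):

        config = {"load_dir": "/data/raw", "save_dir": "/data/tokenized", "gene_seed": 0}
        token_args, load_dir, save_dir = _prepare_tokenizer_args(config)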
    """
    if isinstance(tokenization_args, dict):
        load_dir = tokenization_args["load_dir"]
        save_dir = tokenization_args["save_dir"]
        token_args_obj = TokenizationArgs(**tokenization_args)
    else:
        load_dir = tokenization_args.load_dir
        save_dir = tokenization_args.save_dir
        token_args_obj = tokenization_args

    if token_args_obj.gene_seed is not None:
        # Seed every RNG the pipeline may touch so random gene selection is reproducible.
        random.seed(token_args_obj.gene_seed)
        np.random.seed(token_args_obj.gene_seed)
        torch.manual_seed(token_args_obj.gene_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(token_args_obj.gene_seed)

    return token_args_obj, load_dir, save_dir


def _check_genes_in_tokenizer(data: ad.AnnData, gene_id_column: str, tokenizer: GeneTokenizer):
    """
    Checks which genes in the dataset are present in the tokenizer's vocabulary and
    raises an error if fewer than 10% of them are found.
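
    Example (a minimal sketch; ``adata`` and ``tokenizer`` are placeholders):

        gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(adata, "index", tokenizer)
        # gene_in_vocab: integer positions of the recognised genes in ``adata.var``
        # coding_genes:  the corresponding gene identifiers
        # ratio:         fraction of genes found in the tokenizer vocabulary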
    """
    if gene_id_column == "index":
        gene_index = data.var.index
    else:
        gene_index = data.var[gene_id_column]

    # Positions of genes whose IDs appear in the tokenizer vocabulary.
    gene_in_vocab = np.where([g in tokenizer.vocab for g in gene_index])[0]
    coding_genes = gene_index[gene_in_vocab]
    ratio = len(gene_in_vocab) / len(data.var)
    if ratio < 0.1:
        raise OSError(
            f"Only {ratio:.2%} of gene IDs found in tokenizer vocab. "
            "Check gene_id_column or vocab mismatch."
        )
    return gene_in_vocab, coding_genes, ratio


def _build_batch_tensors(X_batch: torch.Tensor, token_array: torch.Tensor, token_args, data=None, obs_indices=None):
    """
    Builds the top-k (or random-then-top-k) gene subset for each row in X_batch
    (batch_size x num_genes), optionally prepending a CLS token.

    Returns a tuple (gene_ids, top_vals, labels_list, decoder_vals); the last two are
    currently always None and are kept as placeholders.
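
    Example (a minimal sketch of the shapes, assuming ``max_seq_len=5`` and
    ``add_cls=True``; ``token_array`` and ``token_args`` are placeholders):

        X_batch = torch.rand(2, 100)  # 2 cells x 100 genes
        gene_ids, vals, _, _ = _build_batch_tensors(X_batch, token_array, token_args)
        # gene_ids.shape == (2, 5): a CLS column plus the 4 highest-expressed genes
        # vals.shape     == (2, 5): a column of ones for CLS plus the matching values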
    """
    batch_size = X_batch.shape[0]
    # Reserve one sequence position for the CLS token if it will be prepended.
    seq_tokens = token_args.max_seq_len - 1 if token_args.add_cls else token_args.max_seq_len

    if token_args.random_genes:
        # Sample a random subset of genes per cell, then sort that subset by expression.
        random_indices = torch.stack([torch.randperm(X_batch.shape[1])[:seq_tokens] for _ in range(batch_size)])
        random_vals = torch.gather(X_batch, 1, random_indices)
        top_vals, rel_indices = torch.topk(
            random_vals, k=min(seq_tokens, random_vals.shape[1]), largest=True, sorted=True
        )
        top_indices = torch.gather(random_indices, 1, rel_indices)
    else:
        # Keep the highest-expressed genes per cell, sorted in descending order.
        top_vals, top_indices = torch.topk(X_batch, k=min(seq_tokens, X_batch.shape[1]), largest=True, sorted=True)

    gene_ids = token_array[top_indices]

    if token_args.add_cls:
        cls_col = torch.tensor(token_args.cls_token_id).repeat(batch_size, 1)
        gene_ids = torch.cat([cls_col, gene_ids], dim=1)
        ones_col = torch.ones(batch_size, 1, dtype=top_vals.dtype)
        top_vals = torch.cat([ones_col, top_vals], dim=1)

    labels_list = None

    return gene_ids, top_vals, labels_list, None


def tokenize(data_path: str, metadata_path: str, tokenization_args: Union[dict, TokenizationArgs]):
    """
    Tokenizes gene expression data stored in AnnData format and saves the result as a
    Hugging Face Dataset under ``save_dir``, mirroring the input directory structure.

    Args:
        data_path (str): Path to the AnnData (.h5ad) file containing preprocessed gene expression data.
        metadata_path (str): Path to the metadata file in JSON format.
        tokenization_args (Union[dict, TokenizationArgs]): Configuration for tokenization.
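
    Example of loading the saved output (illustrative; the feature names match the
    columns built below, and the path is a placeholder):

        from datasets import load_from_disk
        ds = load_from_disk("<save_dir>/<data path relative to load_dir, without extension>")
        # ds.column_names -> ["gene_ids", "values"], plus "labels" and/or the
        # annotation columns ("disease", "tissue", "cell_type", "sex") when enabled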
    """
    token_args, load_dir, save_dir = _prepare_tokenizer_args(tokenization_args)

    tokenizer = GeneTokenizer.from_pretrained(token_args.tokenizer_name_or_path)
    if token_args.cls_token_id is None:
        token_args.cls_token_id = tokenizer.cls_token_id

    data = ad.read_h5ad(data_path)

    if "processed" not in data.layers:
        raise ValueError(f"Missing 'processed' layer in {data_path}")

    # Restrict tokenization to genes that the tokenizer actually knows about.
    gene_in_vocab, coding_genes, ratio = _check_genes_in_tokenizer(data, token_args.gene_id_column, tokenizer)
    print(f"{ratio:.2%} of genes found in tokenizer vocab")

    token_array = torch.tensor(tokenizer.encode(coding_genes.tolist(), add_special_tokens=False))

    X_matrix = data.layers["processed"].toarray()

    all_data = {"gene_ids": [], "values": []}

    BATCH_SIZE = 512
    n_obs = data.shape[0]

    for start_idx in tqdm(range(0, n_obs, BATCH_SIZE), desc="Tokenizing in batches"):
        end_idx = min(start_idx + BATCH_SIZE, n_obs)
        obs_indices = np.arange(start_idx, end_idx)

        X_batch = torch.tensor(X_matrix[obs_indices, :][:, gene_in_vocab], dtype=torch.float)
        gene_ids_batch, vals_batch, labels_batch, decoder_vals_batch = _build_batch_tensors(
            X_batch,
            token_array,
            token_args,
            data=None,
            obs_indices=None,
        )

        final_gene_list = []
        final_vals_list = []
        final_labels_list = []
        if "decoder_values" in data.layers:
            final_decoder_vals_list = []

        for row_idx in range(len(gene_ids_batch)):
            g_row = gene_ids_batch[row_idx]
            v_row = vals_batch[row_idx]

            lb_row = labels_batch[row_idx] if labels_batch is not None else None
            dec_v_row = decoder_vals_batch[row_idx] if decoder_vals_batch is not None else None

            if not token_args.include_zero_genes:
                # Drop genes with zero expression from the tokenized sequence.
                nonzero_mask = v_row != 0
                g_row = g_row[nonzero_mask]
                v_row = v_row[nonzero_mask]
                if lb_row is not None:
                    lb_row = lb_row[nonzero_mask]
                if dec_v_row is not None:
                    dec_v_row = dec_v_row[nonzero_mask]

            final_gene_list.append(g_row)
            final_vals_list.append(v_row)
            final_labels_list.append(lb_row)
            if "decoder_values" in data.layers:
                final_decoder_vals_list.append(dec_v_row)

        if token_args.bins and token_args.continuous_rank:
            raise ValueError("Should not use bins and continuous_rank simultaneously.")

        if token_args.bins:
            final_vals_list = _bin_values(final_vals_list, token_args, no_sorting=False)
        elif token_args.continuous_rank:
            for i, vals in enumerate(final_vals_list):
                final_vals_list[i] = _rank_continuous(vals, token_args)

        for row_idx in range(len(final_gene_list)):
            all_data["gene_ids"].append(final_gene_list[row_idx].tolist())
            all_data["values"].append(final_vals_list[row_idx].tolist())

    if token_args.label_column:
        all_data["labels"] = data.obs[token_args.label_column].cat.codes.values.tolist()

    if token_args.bio_annotations:
        with open(token_args.disease_mapping) as f:
            disease_mapping = json.load(f)
        with open(token_args.tissue_mapping) as f:
            tissue_mapping = json.load(f)
        with open(token_args.cell_mapping) as f:
            cell_mapping = json.load(f)
        with open(token_args.sex_mapping) as f:
            sex_mapping = json.load(f)

        # Fill in defaults for any missing annotation columns.
        if "disease" not in data.obs.columns:
            data.obs["disease"] = "normal"
        if "tissue" not in data.obs.columns:
            data.obs["tissue"] = "cultured cell"
        if "sex" not in data.obs.columns:
            data.obs["sex"] = "unknown"
        if "cell_type" not in data.obs.columns:
            data.obs["cell_type"] = "unknown"

        # Map raw annotation strings to tokenizer-vocabulary terms, then encode them.
        mapped_diseases = [disease_mapping[k] for k in data.obs["disease"].tolist()]
        mapped_tissues = [tissue_mapping[k] for k in data.obs["tissue"].tolist()]
        mapped_cell_types = [cell_mapping[k] for k in data.obs["cell_type"].tolist()]
        mapped_sexes = [sex_mapping[k] for k in data.obs["sex"].tolist()]

        all_data["disease"] = tokenizer.encode(mapped_diseases, add_special_tokens=False)
        all_data["tissue"] = tokenizer.encode(mapped_tissues, add_special_tokens=False)
        all_data["cell_type"] = tokenizer.encode(mapped_cell_types, add_special_tokens=False)
        all_data["sex"] = tokenizer.encode(mapped_sexes, add_special_tokens=False)

        if token_args.add_disease_annotation:
            all_data["labels"] = all_data["disease"]

    del data
    gc.collect()

    dataset = Dataset.from_dict(all_data)
    num_samples = len(dataset)
    if token_args.max_shard_samples:
        # Floor division: each shard holds roughly max_shard_samples samples.
        num_shards = num_samples // min(token_args.max_shard_samples, num_samples)
    else:
        num_shards = 1

    # Mirror the input directory structure (rooted at load_dir) under save_dir.
    relative_data_path = os.path.relpath(data_path, load_dir)
    relative_metadata_path = os.path.relpath(metadata_path, load_dir)

    no_extension_data_path = os.path.splitext(relative_data_path)[0]

    save_tokenized_data_path = os.path.join(save_dir, no_extension_data_path)
    save_metadata_path = os.path.join(save_dir, relative_metadata_path)

    dataset.save_to_disk(save_tokenized_data_path, num_shards=num_shards)
    shutil.copy(metadata_path, save_metadata_path)


def shard_hf_dataset(data_path: str, metadata_path: str, tokenization_args: Union[dict, TokenizationArgs]):
    """
    Shards a Hugging Face Dataset into smaller chunks for efficient storage and processing.
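
    Example (illustrative arithmetic): with 100_000 samples and
    ``max_shard_samples=30_000``, ``num_shards = 100_000 // 30_000 = 3``, so each shard
    holds roughly 33_334 samples; floor division means a shard can slightly exceed the
    nominal cap.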
    """
    if isinstance(tokenization_args, dict):
        load_dir = tokenization_args["load_dir"]
        save_dir = tokenization_args["save_dir"]
        token_args_obj = TokenizationArgs(**tokenization_args)
    else:
        load_dir = tokenization_args.load_dir
        save_dir = tokenization_args.save_dir
        token_args_obj = tokenization_args

    all_data = load_from_disk(data_path)
    num_samples = len(all_data)
    if token_args_obj.max_shard_samples:
        num_shards = num_samples // min(token_args_obj.max_shard_samples, num_samples)
    else:
        num_shards = 1

    # Write the re-sharded dataset (and its metadata) under save_dir, mirroring the
    # original layout rooted at load_dir.
    save_tokenized_data_path = data_path.replace(load_dir, save_dir)
    save_metadata_path = metadata_path.replace(load_dir, save_dir)
    all_data.save_to_disk(save_tokenized_data_path, num_shards=num_shards)
    shutil.copy(metadata_path, save_metadata_path)


if __name__ == "__main__":
    parser = ArgumentParser(description="Tokenize an AnnData file for downstream ML tasks.")
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the .h5ad file containing the preprocessed scRNA-seq data.",
    )
    parser.add_argument(
        "--metadata_path",
        type=str,
        required=True,
        help="Path to the JSON file containing metadata.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the JSON file specifying tokenization hyperparameters.",
    )

    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        tokenization_args = json.load(f)

    tokenize(
        data_path=args.data_path,
        metadata_path=args.metadata_path,
        tokenization_args=tokenization_args,
    )