"""
Module: preprocess.py

This module provides a preprocessing pipeline for single-cell RNA sequencing (scRNA-seq) data
stored in AnnData format. It includes functions for loading data, filtering cells and genes,
normalizing and scaling data, and saving processed results. The pipeline is designed to be
configurable via hyperparameters and supports various preprocessing steps such as mitochondrial
gene filtering, highly variable gene selection, and log transformation.

Main Features:
- Load and preprocess scRNA-seq data in AnnData format.
- Filter cells and genes based on various criteria.
- Normalize, scale, and log-transform data.
- Save processed data and metadata to disk.
- Configurable via JSON-based hyperparameters.

Dependencies:
- anndata, numpy, pandas, scanpy, scipy, sklearn

Usage:
- Run this script as a standalone program with a configuration file specifying the hyperparameters.
- Import the `preprocess` function and call it with the data path, metadata path, and hyperparameters.
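
Example (illustrative only; the file names are placeholders):

    # Command line:
    #   python preprocess.py --data_path data.h5ad --metadata_path data.json --config_path config.json

    # Programmatic use:
    from preprocess import preprocess
    ok, _ = preprocess(
        data_path="data.h5ad",
        metadata_path="data.json",
        hyperparameters=hyperparameters,  # dict loaded from a JSON config file
    )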
"""

import gc
import json
import logging
import os
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Sequence, Union

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
from anndata import ImplicitModificationWarning
from scipy.sparse import csr_matrix, issparse
from sklearn.utils import sparsefuncs, sparsefuncs_fast

from teddy.data_processing.utils.gene_mapping.gene_mapper import (
    map_mouse_human,
    map_mouse_human2,
)

_logger = logging.getLogger(__name__)


_HUMAN_MITO_ENSEMBL = {
    "ENSG00000211459", "ENSG00000210082",

    "ENSG00000210049", "ENSG00000210077", "ENSG00000209082",
    "ENSG00000210100", "ENSG00000210107", "ENSG00000210112",
    "ENSG00000210119", "ENSG00000210122", "ENSG00000210116",
    "ENSG00000210117", "ENSG00000210118", "ENSG00000210124",
    "ENSG00000210126", "ENSG00000210134", "ENSG00000210135",
    "ENSG00000210142", "ENSG00000210144", "ENSG00000210148",
    "ENSG00000210150", "ENSG00000210155", "ENSG00000210196",
    "ENSG00000210151",

    "ENSG00000198888", "ENSG00000198763", "ENSG00000198840",
    "ENSG00000198886", "ENSG00000212907", "ENSG00000198786",
    "ENSG00000198695", "ENSG00000198804", "ENSG00000198712",
    "ENSG00000198938", "ENSG00000198899", "ENSG00000228253",
    "ENSG00000198727",
}

_HUMAN_MITO_SYMBOLS = {
    "MT-RNR1", "MT-RNR2", "MT-TF", "MT-TV", "MT-TL1", "MT-TI", "MT-TQ",
    "MT-TM", "MT-TW", "MT-TA", "MT-TN", "MT-TC", "MT-TY", "MT-TD", "MT-TK",
    "MT-TG", "MT-TR", "MT-TH", "MT-TS2", "MT-TL2", "MT-TT", "MT-TE", "MT-TP",
    "MT-TS1", "MT-ND1", "MT-ND2", "MT-ND3", "MT-ND4", "MT-ND4L", "MT-ND5",
    "MT-ND6", "MT-CO1", "MT-CO2", "MT-CO3", "MT-ATP6", "MT-ATP8", "MT-CYB",
}


def load_data_and_metadata(data_path: str, metadata_path: str):
    """
    Load an AnnData .h5ad file (data) and a JSON file (metadata).
    """
    data = ad.read_h5ad(data_path)
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    return data, metadata


def set_raw_if_necessary(data: ad.AnnData):
    """
    If data.raw is None, check whether the first ~64 cells of data.layers['counts']
    (if present) or data.X contain only integer values. If so, set data.raw to that
    matrix. Otherwise return None to signal that the dataset should be skipped.
    """
    if data.raw is not None:
        return data

    def _is_integer_sample(matrix) -> bool:
        # Check the first ~64 rows only, to avoid densifying the whole matrix.
        sample = matrix[:64].toarray() if issparse(matrix) else np.asarray(matrix[:64])
        return bool(np.all(np.equal(np.mod(sample, 1), 0)))

    if "counts" in data.layers and _is_integer_sample(data.layers["counts"]):
        data.raw = ad.AnnData(X=data.layers["counts"], var=data.var.copy())
        return data

    if _is_integer_sample(data.X):
        data.raw = data
        return data

    print("No integer-valued matrix found")
    return None


def initialize_processed_layer(data: ad.AnnData):
    """
    If the 'processed' layer is missing, copy it from data.raw.X.
    """
    if "processed" not in data.layers:
        data.layers["processed"] = data.raw.X.astype("float32")
    return data


def filter_reference_id(data: ad.AnnData, hyperparameters: dict):
    """
    Map gene identifiers to a common (human) reference_id, drop genes without a
    mapping, and collapse duplicated reference_ids by keeping the element-wise
    maximum across the duplicated columns of the 'processed' layer.
    """
    human_map = pd.read_csv("teddy/data_processing/utils/gene_mapping/data/human_mapping.txt", sep="\t")
    mouse_map = pd.read_csv("teddy/data_processing/utils/gene_mapping/data/2407_mouse_gene_mapping.txt", sep="\t")
    orthologs = pd.read_csv(
        "teddy/data_processing/utils/gene_mapping/data/mouse_to_human_orthologs.one2one.txt", sep="\t"
    )

    mapper = map_mouse_human2 if hyperparameters.get("mouse_nonorthologs", False) else map_mouse_human
    reference_id = mapper(
        data_frame=data.var,
        query_column=None,
        human_map_db=human_map,
        mouse_map_db=mouse_map,
        orthology_db=orthologs,
    )["reference_id"]

    # Keep only genes that received a non-empty reference_id.
    valid_mask = reference_id != ""
    data = data[:, valid_mask].copy()
    reference_id = reference_id[valid_mask].reset_index(drop=True)

    if not isinstance(data.layers["processed"], np.ndarray):
        corrected = data.layers["processed"].toarray()
    else:
        corrected = data.layers["processed"]

    # Collapse duplicated reference_ids: keep the first occurrence and store the
    # element-wise maximum over all duplicated columns.
    unique_ids = reference_id.unique()
    vars_to_keep = []
    for rid in unique_ids:
        repeated_idx = np.where(reference_id == rid)[0]
        vars_to_keep.append(repeated_idx[0])
        if len(repeated_idx) > 1:
            corrected[:, repeated_idx[0]] = corrected[:, repeated_idx].max(axis=1)

    vars_to_keep = sorted(vars_to_keep)
    corrected = corrected[:, vars_to_keep]
    data = data[:, vars_to_keep]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ImplicitModificationWarning)
        data.layers["processed"] = csr_matrix(corrected)
        data.var["reference_id"] = list(reference_id[vars_to_keep])

    gc.collect()
    return data


def remove_assays(data: ad.AnnData, assays_to_remove: list):
    """
    Remove observations whose 'assay' value is in assays_to_remove ('assay' must be in data.obs).
    """
    data = data[~data.obs.assay.isin(assays_to_remove)].copy()
    gc.collect()
    return data


def filter_cells_by_gene_counts(data: ad.AnnData, min_count: int):
    """
    Remove cells (observations) whose total counts in the 'processed' layer are below min_count.
    """
    # sc.pp.filter_cells returns (cell_mask, counts_per_cell) when given a matrix.
    mask = sc.pp.filter_cells(data.layers["processed"], min_counts=min_count)[0]
    data = data[mask].copy()
    del mask
    gc.collect()
    return data


def filter_cells_by_mitochondrial_fraction(data: ad.AnnData, max_mito_prop: float):
    """
    Remove low-quality cells whose mitochondrial read fraction exceeds *max_mito_prop*.

    DO NOT RUN THIS IN ANY PREPROCESSING PIPELINE UNTIL YOU HAVE SET RAW COUNTS.

    Parameters
    ----------
    data
        `AnnData` object containing counts. Works with dense or sparse matrices.
    max_mito_prop
        Threshold above which cells are discarded.

    Returns
    -------
    AnnData
        A **copy** of `data` with poor-quality cells removed and two new
        columns added to ``.obs``:

        - **mito_prop** – per-cell mitochondrial fraction
        - **poor_quality_mito** – boolean flag marking dropped cells
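
    Example
    -------
    Illustrative numbers only: a cell with 1,000 total counts of which 150 fall on
    mitochondrial genes gets ``mito_prop = 0.15`` and is dropped whenever
    ``max_mito_prop`` is below 0.15, e.g.::

        data = filter_cells_by_mitochondrial_fraction(data, max_mito_prop=0.1)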
    """
    counts = data.X
    var_index = data.var_names
    # Choose the reference set based on whether gene names are Ensembl IDs or symbols.
    if var_index[0].startswith("ENSG"):
        ref = _HUMAN_MITO_ENSEMBL
    else:
        ref = _HUMAN_MITO_SYMBOLS
    mito_idx = np.flatnonzero(var_index.isin(ref))
    if mito_idx.size == 0:
        _logger.info("No mitochondrial genes found, returning data")
        return data

    if sp.issparse(counts):
        total = counts.sum(axis=1).A1
        mito = counts[:, mito_idx].sum(axis=1).A1
    else:
        total = counts.sum(axis=1)
        mito = counts[:, mito_idx].sum(axis=1)

    mito_prop = mito / np.maximum(total, 1)
    data.obs["mito_prop"] = mito_prop
    data.obs["poor_quality_mito"] = mito_prop > max_mito_prop
    filtered = data[~data.obs["poor_quality_mito"]].copy()
    gc.collect()
    return filtered


def filter_highly_variable_genes(data: ad.AnnData, method: str):
    """
    Filter genes to those that are highly variable using scanpy.
    method must be "seurat_v3" or "cell_ranger".
    """
    if "highly_variable" not in data.var:
        sc.pp.highly_variable_genes(data, flavor=method, n_top_genes=10000)
    data = data[:, data.var["highly_variable"]]
    gc.collect()
    return data


def normalize_data_inplace(matrix_csr: csr_matrix, norm_value: float):
    """
    In-place row normalization + scale. matrix_csr must be a CSR matrix. Each row is
    L1-normalized and then multiplied by norm_value, so every row sums to norm_value.
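
    Example (illustrative values chosen for clarity)::

        m = csr_matrix(np.array([[1.0, 3.0], [2.0, 2.0]]))
        normalize_data_inplace(m, 10_000)
        # m.toarray() -> [[2500., 7500.], [5000., 5000.]]  (each row sums to 10_000)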
    """
    sparsefuncs_fast.inplace_csr_row_normalize_l1(matrix_csr)

    scale_factors = np.full(matrix_csr.shape[0], norm_value, dtype=np.float64)
    sparsefuncs.inplace_row_scale(matrix_csr, scale_factors)
    gc.collect()


def scale_columns_by_median_dict(layer: csr_matrix, data: ad.AnnData, median_dict_path: str, median_column: str):
    """
    Read a JSON median_dict and scale each column by 1/median. The lookup key is either
    data.var.index or data.var[median_column]; genes missing from the dict are left unscaled.
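
    Example (illustrative shape of the JSON median file; gene IDs and values are
    placeholders)::

        {"GENE_A": 2.0, "GENE_B": 0.5}

    With this file, the column for GENE_A is multiplied by 1/2.0 and the column for
    GENE_B by 1/0.5; genes absent from the file keep a factor of 1.0.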
    """
    with open(median_dict_path) as f:
        median_dict = json.load(f)

    if median_column == "index":
        median_var = data.var.index
    else:
        median_var = data.var[median_column]

    factors = []
    for g in median_var:
        if g in median_dict:
            factors.append(1.0 / median_dict[g])
        else:
            factors.append(1.0)
    factors = np.array(factors)

    sparsefuncs.inplace_csr_column_scale(layer, factors)


def log_transform_layer(data: ad.AnnData, layer_name: str = "processed"):
    """
    Apply sc.pp.log1p in place to data.layers[layer_name].
    """
    sc.pp.log1p(data, layer=layer_name, copy=False)


def compute_and_save_medians(data: ad.AnnData, data_path: str, hyperparameters: dict):
    """
    Convert zeros to NaN, compute per-gene (column) medians ignoring NaN, and save the
    result as a JSON file alongside the processed output.
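
    Example
    -------
    Illustrative numbers: a gene whose column in the 'processed' layer is
    ``[0, 2, 4, 0, 6]`` has its zeros ignored, so its stored median is ``4``
    (the median of ``[2, 4, 6]``). Genes with no non-zero values are omitted
    from the output JSON.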
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")

        mat = data.layers["processed"].toarray()
        mat[mat == 0] = np.nan
        medians = np.nanmedian(mat, axis=0)

    if hyperparameters["median_column"] == "index":
        median_var = data.var.index.copy()
        if not isinstance(median_var, pd.Series):
            median_var = pd.Series(median_var)
    else:
        median_var = data.var[hyperparameters["median_column"]].copy()

    valid_idxs = np.where(~np.isnan(medians))[0]
    median_values = {median_var.iloc[k]: medians[k].item() for k in valid_idxs}

    save_path = data_path.replace(hyperparameters["load_dir"], hyperparameters["save_dir"])
    save_path = save_path.replace(".h5ad", "_medians.json")
    with open(save_path, "w") as f:
        json.dump(median_values, f, indent=4)


def update_metadata(metadata: dict, data: ad.AnnData, hyperparameters: dict):
    """
    Update metadata with cell_count and track processing arguments.
    """
    metadata["cell_count"] = data.n_obs
    if "processing_args" in metadata:
        metadata["processing_args"] = [metadata["processing_args"]] + [hyperparameters]
    else:
        metadata["processing_args"] = [hyperparameters]
    return metadata


def save_and_cleanup(data: ad.AnnData, metadata: dict, data_path: str, metadata_path: str, hyperparameters: dict):
    """
    Write the processed data and metadata to disk, then run garbage collection.
    Returns (None, None) if there are no cells left to save, otherwise (True, True).
    """
    save_dir = hyperparameters["save_dir"]
    data_filename = os.path.basename(data_path)
    metadata_filename = os.path.basename(metadata_path)

    save_processed_path = os.path.join(save_dir, data_filename)
    save_metadata_path = os.path.join(save_dir, metadata_filename)

    os.makedirs(os.path.dirname(save_processed_path), exist_ok=True)
    os.makedirs(os.path.dirname(save_metadata_path), exist_ok=True)

    if data.n_obs == 0:
        return None, None

    # h5ad writing expects sparse matrices in CSR form.
    if not isinstance(data.raw.X, csr_matrix):
        data.raw.X = csr_matrix(data.raw.X)
    if not isinstance(data.X, csr_matrix):
        data.X = csr_matrix(data.X)
    if "processed" in data.layers and not isinstance(data.layers["processed"], csr_matrix):
        data.layers["processed"] = csr_matrix(data.layers["processed"])

    try:
        data.write_h5ad(save_processed_path, compression="gzip")
    except Exception:
        # Writing can fail when the obs index name collides with an obs column;
        # drop the duplicated column and retry.
        if data.obs.index.name in data.obs.columns:
            del data.obs[data.obs.index.name]
        data.write_h5ad(save_processed_path, compression="gzip")

    del data
    gc.collect()

    with open(save_metadata_path, "w") as f:
        json.dump(metadata, f, indent=4)

    return True, True


def preprocess(data_path: str, metadata_path: str, hyperparameters: dict):
    """
    Pipeline steps (an illustrative hyperparameter configuration follows the list):

    1. Load data & metadata
    2. Ensure data.raw is set if counts are integer-valued
    3. Initialize 'processed' layer
    4. Filter genes by reference_id
    5. Remove assays
    6. Filter cells (min gene counts)
    7. Filter cells (max mito fraction)
    8. HVG filtering
    9. Normalize total
    10. Median-based column scaling
    11. Log transform
    12. Compute medians (optional)
    13. Update metadata and save
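
    Illustrative configuration (the keys are those read in this function; the values
    are placeholders, not recommendations)::

        {
            "load_dir": "raw/", "save_dir": "processed/",
            "reference_id_only": true, "mouse_nonorthologs": false,
            "remove_assays": [], "min_gene_counts": 200,
            "max_mitochondrial_prop": 0.2, "hvg_method": "none",
            "normalized_total": 10000, "median_dict": null,
            "median_column": "index", "log1p": true, "compute_medians": false
        }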
    """
    data, metadata = load_data_and_metadata(data_path, metadata_path)

    data = set_raw_if_necessary(data)
    if data is None:
        return None, None

    data = initialize_processed_layer(data)

    if hyperparameters["reference_id_only"]:
        data = filter_reference_id(data, hyperparameters)

    if "assay" in data.obs and hyperparameters["remove_assays"]:
        data = remove_assays(data, hyperparameters["remove_assays"])

    if hyperparameters["min_gene_counts"]:
        data = filter_cells_by_gene_counts(data, hyperparameters["min_gene_counts"])

    if hyperparameters["max_mitochondrial_prop"]:
        data = filter_cells_by_mitochondrial_fraction(
            data, hyperparameters["max_mitochondrial_prop"])

    if hyperparameters["hvg_method"] in ["seurat_v3", "cell_ranger"]:
        data = filter_highly_variable_genes(data, hyperparameters["hvg_method"])

    if hyperparameters["normalized_total"]:
        if not isinstance(data.layers["processed"], csr_matrix):
            data.layers["processed"] = csr_matrix(data.layers["processed"])
        normalize_data_inplace(data.layers["processed"], hyperparameters["normalized_total"])

    if hyperparameters["median_dict"]:
        scale_columns_by_median_dict(
            data.layers["processed"], data, hyperparameters["median_dict"], hyperparameters["median_column"]
        )

    if hyperparameters["log1p"]:
        log_transform_layer(data, "processed")

    if hyperparameters["compute_medians"]:
        compute_and_save_medians(data, data_path, hyperparameters)

    metadata = update_metadata(metadata, data, hyperparameters)
    return save_and_cleanup(data, metadata, data_path, metadata_path, hyperparameters)


if __name__ == "__main__":
    parser = ArgumentParser(description="Preprocess scRNA-seq data stored in AnnData format.")
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the input .h5ad file."
    )
    parser.add_argument(
        "--metadata_path",
        type=str,
        required=True,
        help="Path to the input metadata JSON file."
    )
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the JSON configuration file containing hyperparameters."
    )

    args = parser.parse_args()

    with open(args.config_path, "r") as f:
        hyperparameters = json.load(f)

    success, _ = preprocess(
        data_path=args.data_path,
        metadata_path=args.metadata_path,
        hyperparameters=hyperparameters
    )

    if success:
        print("Preprocessing completed successfully.")
    else:
        print("Preprocessing returned no data (0 cells), no file saved.")