Safetensors
soumyatghosh's picture
Upload folder using huggingface_hub
4527b5f verified
"""
Module: calculate_biostats.py
This module calculates and aggregates biological statistics from single-cell RNA sequencing (scRNA-seq) data
stored in AnnData format. It generates per-category statistics (e.g., disease, cell type, tissue, sex)
and computes the median expression values for genes across datasets. The results are saved as JSON and CSV files
for downstream analysis.
Main Features:
- Computes, per gene, the median of the non-zero expression values in the "processed" layer of AnnData files (zeros are treated as missing).
- Generates category-wise statistics (e.g., counts of diseases, cell types, tissues, and sexes).
- Aggregates statistics across multiple training datasets.
- Outputs results in JSON and CSV formats for easy integration with other tools.
Dependencies:
- anndata: For handling AnnData files.
- numpy: For numerical operations, including median calculations.
- pandas: For creating and exporting tabular data.
- tqdm: For progress visualization during processing.
- glob: For recursive file searching.
Usage:
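- Run this script as a standalone program with the following arguments:
- `--load_dir`: Directory that is searched recursively for training files named `train_*.h5ad`.
- `--stats_dict_name`: Path to save the aggregated statistics JSON file. One `<name>_<category>.csv` file per category is written next to it, and a `bio_stats.json` file is written into each directory that contains training files.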
"""
import json
import os
from argparse import ArgumentParser
from glob import glob
import anndata as ad
import numpy as np
import pandas as pd
from datasets.utils.logging import disable_progress_bar
from tqdm import tqdm
def make_median_list(file, out_file):
    """Write the per-gene median of non-zero expression values to a JSON file.

    The ``"processed"`` layer of the AnnData file is loaded fully into memory,
    zeros are treated as missing values, and the median of the remaining
    values is computed for every gene (column). Genes without a single
    non-missing value are left out of the result.

    Args:
        file: Path to the ``.h5ad`` file to read.
        out_file: Path of the JSON file to write. It maps each gene
            identifier (taken from ``data.var.index``) to its median.

    Raises:
        KeyError: If the AnnData object has no ``"processed"`` layer.
    """
    data = ad.read_h5ad(file)
    gene_index = data.var.index

    processed = data.layers["processed"]
    # A layer can be stored sparse or dense; only sparse matrices expose
    # ``toarray``, so calling it unconditionally breaks on dense layers.
    if hasattr(processed, "toarray"):
        processed = processed.toarray()
    all_X = np.asarray(processed)
    # NaN cannot be stored in integer/bool arrays, so promote those to float.
    # Floating-point layers keep their dtype so the medians are unchanged.
    if not np.issubdtype(all_X.dtype, np.floating):
        all_X = all_X.astype(float)

    # Zeros mean "not expressed" and must not pull the median down.
    all_X[all_X == 0] = np.nan

    # Restrict the median to genes with at least one value; this yields the
    # same result as filtering NaN medians afterwards but avoids NumPy's
    # "All-NaN slice encountered" warning for every empty gene.
    expressed = np.flatnonzero(~np.isnan(all_X).all(axis=0))
    median_dict = {}
    if expressed.size:
        median = np.nanmedian(all_X[:, expressed], axis=0)  # (expressed genes,)
        median_dict = {
            gene_index[k]: value.item() for k, value in zip(expressed, median)
        }

    with open(out_file, "w") as f:
        json.dump(median_dict, f, indent=4)
if __name__ == "__main__":
    # Command-line entry point.
    #   1. Find every ``train_*.h5ad`` file below ``--load_dir``.
    #   2. Count the values of each category column in ``.obs`` and write the
    #      counts to ``bio_stats.json`` in the directory of the training file.
    #   3. Sum the per-directory counts and write them to ``--stats_dict_name``
    #      (JSON) plus one CSV file per category next to it.
    parser = ArgumentParser()
    parser.add_argument("--load_dir", default="")
    parser.add_argument("--stats_dict_name", default="")
    args = parser.parse_args()
    disable_progress_bar()

    # Columns of ``.obs`` that are summarised.
    categories = ["disease", "cell_type", "tissue", "sex"]

    all_train = list(glob(args.load_dir + "/**/train_*.h5ad", recursive=True))

    print("Generating individual stats")
    # Counts are accumulated per directory: ``bio_stats.json`` is a
    # per-directory file, so several training files in one directory must be
    # merged instead of overwriting each other.
    dir_stats = {}
    for train in tqdm(all_train):
        # Only ``.obs`` is needed, so open the file backed and read-only
        # instead of loading the expression matrix or requesting write access.
        data = ad.read_h5ad(train, backed="r")
        try:
            file_stats = {
                cat: data.obs[cat].value_counts().to_dict() for cat in categories
            }
        finally:
            # Release the HDF5 handle before moving on to the next file.
            data.file.close()

        train_dir = os.path.dirname(train)
        stats = dir_stats.setdefault(train_dir, {cat: {} for cat in categories})
        for cat in categories:
            for k, count in file_stats[cat].items():
                stats[cat][k] = stats[cat].get(k, 0) + count
        # Written after every file so partial results survive an interruption.
        with open(os.path.join(train_dir, "bio_stats.json"), "w") as f:
            json.dump(stats, f, indent=4)

    print("Collecting stats")
    summary_dict = {cat: {} for cat in categories}
    # Visit every directory exactly once; iterating over the training files
    # would count a shared ``bio_stats.json`` once per file.
    for train_dir in tqdm(dir_stats):
        with open(os.path.join(train_dir, "bio_stats.json")) as f:
            stats = json.load(f)
        for cat in categories:
            for k, count in stats[cat].items():
                summary_dict[cat][k] = summary_dict[cat].get(k, 0) + count

    # ``os.makedirs("")`` raises, so only create the directory when the output
    # path actually has a directory component.
    out_dir = os.path.dirname(args.stats_dict_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.stats_dict_name, "w") as f:
        json.dump(summary_dict, f, indent=4)

    # Derive the CSV names from the JSON path by dropping a trailing ".json".
    # Unlike ``str.replace`` this never touches directory names, and it cannot
    # produce the JSON path itself when the name has no ".json" suffix.
    csv_stem = args.stats_dict_name
    if csv_stem.endswith(".json"):
        csv_stem = csv_stem[: -len(".json")]
    for cat in categories:
        df = pd.DataFrame.from_dict(summary_dict[cat], orient="index", columns=["Counts"])
        df.to_csv(f"{csv_stem}_{cat}.csv")