Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import scanpy as sc | |
| import numpy as np | |
| import json | |
| import scipy | |
| from celldreamer.data.download import collect_data | |
| from celldreamer.data.process import process | |
| from celldreamer.data.plots import validate | |
| from celldreamer.data.class_celldreamerDataset import CellDreamerDataset | |
def create_data():
    """Run the data pipeline and serialize the train/val/test datasets.

    Calls ``collect_data``, ``process`` and ``validate`` in that order, then
    wraps each split's pairs file in a ``CellDreamerDataset`` and stores the
    resulting objects with ``torch.save`` under ``celldreamer/data/datasets``.
    """
    for stage in (collect_data, process, validate):
        stage()

    # (pairs file to read, dataset file to write) for each split.
    split_files = (
        (
            "celldreamer/data/processed/train_pairs.npy",
            "celldreamer/data/datasets/train.pt",
        ),
        (
            "celldreamer/data/processed/val_pairs.npy",
            "celldreamer/data/datasets/val.pt",
        ),
        (
            "celldreamer/data/processed/test_pairs.npy",
            "celldreamer/data/datasets/test.pt",
        ),
    )

    # Build every dataset before touching the output directory, so a failure
    # while loading any split leaves nothing half-written.
    built = [
        (CellDreamerDataset(pairs_path=pairs_file), target_file)
        for pairs_file, target_file in split_files
    ]

    os.makedirs("celldreamer/data/datasets", exist_ok=True)
    for dataset, target_file in built:
        torch.save(dataset, target_file)
def get_data_stats(n_background_points=5000):
    """Compute per-gene statistics and export JSON artifacts for the web UI.

    Reads ``celldreamer/data/processed/cleaned.h5ad`` and writes:

    * ``celldreamer/data/stats/stats.pt`` -- dict with per-gene ``mean`` and
      ``std`` tensors. They are computed from ``adata.raw`` (restricted to
      ``adata.var_names``) when a raw layer exists, otherwise from
      ``adata.X``. Zero standard deviations are replaced by 1.0.
    * ``celldreamer/data/artifacts/gene_map.json`` -- the gene names and a
      name -> column-index mapping.
    * ``celldreamer/data/artifacts/background_map.json`` -- a random sample
      of cells with their UMAP coordinates, pseudotime and, when the
      ``celltype`` column exists, their label.
    * ``celldreamer/data/artifacts/default_stem_cell.json`` -- the mean
      expression vector of the cells whose ``celltype`` contains "ductal"
      (all cells if none match or the column is absent), mapped back with
      the saved mean/std and clipped at zero.

    Args:
        n_background_points: Maximum number of cells written to
            ``background_map.json``. If the dataset has more cells, this
            many are sampled without replacement.

    Raises:
        KeyError: If ``adata.obs`` has no ``dpt_pseudotime`` column.
    """
    adata = sc.read("celldreamer/data/processed/cleaned.h5ad")

    # --- per-gene mean / std -------------------------------------------
    # Prefer the raw layer, restricted to the genes kept in adata.var_names.
    if adata.raw is not None:
        X_source = adata.raw[:, adata.var_names].X
    else:
        X_source = adata.X
    if scipy.sparse.issparse(X_source):
        X_source = X_source.toarray()
    mean = np.mean(X_source, axis=0)
    std = np.std(X_source, axis=0)
    # Constant genes would otherwise cause a division by zero when scaling.
    std[std == 0] = 1.0

    stats = {
        "mean": torch.tensor(mean),
        "std": torch.tensor(std),
    }
    os.makedirs("celldreamer/data/stats", exist_ok=True)
    torch.save(stats, "celldreamer/data/stats/stats.pt")

    # --- artifacts consumed by the React application ---------------------
    output_dir = "celldreamer/data/artifacts"
    os.makedirs(output_dir, exist_ok=True)

    # Gene name <-> index map.
    gene_names = adata.var_names.tolist()
    gene_map_payload = {
        "gene_names": gene_names,  # dropdown
        "indices": {name: i for i, name in enumerate(gene_names)},  # model gene perturbation
    }
    with open(f"{output_dir}/gene_map.json", "w") as f:
        json.dump(gene_map_payload, f)

    # Make sure a UMAP embedding exists; build the neighbor graph only when
    # it is missing as well.
    if 'X_umap' not in adata.obsm:
        if 'neighbors' not in adata.uns:
            sc.pp.neighbors(adata)
        sc.tl.umap(adata)

    # Random subset of cells used to draw the cell-type clusters.
    total_cells = adata.shape[0]
    if total_cells > n_background_points:
        indices = np.sort(
            np.random.choice(total_cells, n_background_points, replace=False)
        )
    else:
        indices = np.arange(total_cells)

    # Pull the columns out once instead of re-indexing adata.obs per cell.
    umap_coords = adata.obsm['X_umap']
    pseudotime = adata.obs['dpt_pseudotime'].to_numpy()
    has_celltype = 'celltype' in adata.obs
    labels = adata.obs['celltype'].to_numpy() if has_celltype else None

    # NOTE(review): a non-finite pseudotime would be serialized as
    # NaN/Infinity, which strict JSON parsers reject -- confirm the input
    # never contains such values.
    background_payload = []
    for idx in indices.tolist():
        point = {
            "id": idx,
            "x": round(float(umap_coords[idx, 0]), 3),
            "y": round(float(umap_coords[idx, 1]), 3),
            "t": round(float(pseudotime[idx]), 3),
        }
        if has_celltype:
            point["label"] = str(labels[idx])
        background_payload.append(point)
    with open(f"{output_dir}/background_map.json", "w") as f:
        json.dump(background_payload, f)

    # --- default starting cell for perturbation --------------------------
    # Mean ductal cell; fall back to the mean over all cells when there is
    # no celltype annotation or no ductal cell.
    stem_data = adata.X
    if has_celltype:
        # astype(str) keeps the mask strictly boolean even with missing or
        # non-string labels; to_numpy() gives a plain array that both dense
        # and sparse matrices accept as a row mask.
        stem_mask = (
            adata.obs['celltype']
            .astype(str)
            .str.contains('ductal', case=False)
            .to_numpy()
        )
        if stem_mask.any():
            stem_data = adata.X[stem_mask]

    # Sparse matrices return a 2-D (1, n_genes) result; flatten to 1-D.
    mean_stem_z_score = np.asarray(stem_data.mean(axis=0)).ravel()

    # Un-scale the data so the UI gets usable numbers (not -1.7).
    # Assumes adata.X holds values standardized with the mean/std above --
    # TODO confirm against the processing step.
    usable_stem_vector = (mean_stem_z_score * std) + mean
    usable_stem_vector = np.maximum(usable_stem_vector, 0.0)
    with open(f"{output_dir}/default_stem_cell.json", "w") as f:
        json.dump(usable_stem_vector.tolist(), f)
# Script entry point: build the datasets first, then derive the statistics
# and UI artifacts from the processed data.
if __name__ == "__main__":
    for step in (create_data, get_data_stats):
        step()