# RobroKools's picture
# Upload 44 files
# e59f78e verified
import torch
import os
import scanpy as sc
import numpy as np
import json
import scipy
from celldreamer.data.download import collect_data
from celldreamer.data.process import process
from celldreamer.data.plots import validate
from celldreamer.data.class_celldreamerDataset import CellDreamerDataset
def create_data():
    """Run the data pipeline and persist the train/val/test datasets.

    Calls ``collect_data``, ``process`` and ``validate`` in that order, wraps
    each split's pairs file in a ``CellDreamerDataset``, and saves the three
    dataset objects with ``torch.save`` under ``celldreamer/data/datasets``.
    """
    for pipeline_step in (collect_data, process, validate):
        pipeline_step()

    # (pairs file consumed by the dataset, file the dataset object is saved to)
    split_files = (
        ("celldreamer/data/processed/train_pairs.npy",
         "celldreamer/data/datasets/train.pt"),
        ("celldreamer/data/processed/val_pairs.npy",
         "celldreamer/data/datasets/val.pt"),
        ("celldreamer/data/processed/test_pairs.npy",
         "celldreamer/data/datasets/test.pt"),
    )

    # Build every dataset before touching the output directory so the order of
    # side effects is: construct all three, create the directory, save.
    built = [
        (CellDreamerDataset(pairs_path=pairs_file), target_file)
        for pairs_file, target_file in split_files
    ]

    os.makedirs("celldreamer/data/datasets", exist_ok=True)
    for dataset, target_file in built:
        torch.save(dataset, target_file)
def _column_mean_std(matrix):
    """Return per-column mean and std of ``matrix``; zero stds become 1.0."""
    if scipy.sparse.issparse(matrix):
        matrix = matrix.toarray()
    mean = np.mean(matrix, axis=0)
    std = np.std(matrix, axis=0)
    # Constant columns would otherwise give a zero scale factor.
    std[std == 0] = 1.0
    return mean, std


def _background_points(adata, n_background_points):
    """Build the list of UMAP points (id, x, y, t and optional label)."""
    total_cells = adata.shape[0]
    if total_cells > n_background_points:
        indices = np.random.choice(total_cells, n_background_points, replace=False)
        indices.sort()
    else:
        indices = np.arange(total_cells)

    # Pull the columns out once instead of doing a pandas .iloc per point.
    umap_coords = adata.obsm['X_umap']
    pseudotime = adata.obs['dpt_pseudotime'].to_numpy()
    labels = adata.obs['celltype'].to_numpy() if 'celltype' in adata.obs else None

    points = []
    for idx in indices.tolist():
        point = {
            "id": idx,
            "x": round(float(umap_coords[idx, 0]), 3),
            "y": round(float(umap_coords[idx, 1]), 3),
            "t": round(float(pseudotime[idx]), 3),
        }
        if labels is not None:
            point["label"] = str(labels[idx])
        points.append(point)
    return points


def _mean_ductal_profile(adata):
    """Return the mean row of ``adata.X`` over cells labelled 'ductal'.

    Falls back to the mean over all cells when there is no ``celltype``
    column or no label contains 'ductal' (case-insensitive).
    """
    matrix = adata.X
    if 'celltype' in adata.obs:
        # na=False: cells with a missing label are simply not ductal. A plain
        # NumPy bool mask indexes dense and sparse matrices alike.
        ductal_mask = (
            adata.obs['celltype']
            .str.contains('ductal', case=False, na=False)
            .to_numpy(dtype=bool)
        )
        if ductal_mask.any():
            matrix = matrix[ductal_mask]
    # Sparse means may come back 2-D (np.matrix); flatten to 1-D in every case.
    return np.asarray(matrix.mean(axis=0)).ravel()


def get_data_stats(n_background_points=5000):
    """Compute per-gene statistics and export artifacts for the front end.

    Reads ``celldreamer/data/processed/cleaned.h5ad`` and writes:

    * ``celldreamer/data/stats/stats.pt`` - dict with ``mean`` and ``std``
      tensors, taken from ``adata.raw`` (restricted to ``adata.var_names``)
      when present, otherwise from ``adata.X``.
    * ``celldreamer/data/artifacts/gene_map.json`` - gene names and a
      name -> column index map.
    * ``celldreamer/data/artifacts/background_map.json`` - up to
      ``n_background_points`` randomly chosen cells with UMAP coordinates,
      pseudotime and (if available) cell type label.
    * ``celldreamer/data/artifacts/default_stem_cell.json`` - mean ductal
      expression vector, mapped through ``std`` and ``mean`` and clipped at 0.

    Args:
        n_background_points: Maximum number of cells written to the
            background map. All cells are used when there are fewer.

    Raises:
        KeyError: If ``adata.obs`` has no ``dpt_pseudotime`` column.
    """
    data_path = "celldreamer/data/processed/cleaned.h5ad"
    adata = sc.read(data_path)

    if adata.raw is not None:
        stats_source = adata.raw[:, adata.var_names].X
    else:
        stats_source = adata.X
    mean, std = _column_mean_std(stats_source)

    stats = {
        "mean": torch.tensor(mean),
        "std": torch.tensor(std),
    }
    os.makedirs("celldreamer/data/stats", exist_ok=True)
    torch.save(stats, "celldreamer/data/stats/stats.pt")

    # Create useful data for the React application.
    output_dir = "celldreamer/data/artifacts"
    os.makedirs(output_dir, exist_ok=True)

    # Index <-> gene name map.
    gene_names = adata.var_names.tolist()
    gene_map_payload = {
        "gene_names": gene_names,  # dropdown
        "indices": {name: i for i, name in enumerate(gene_names)},  # model gene perturbation
    }
    with open(f"{output_dir}/gene_map.json", "w") as f:
        json.dump(gene_map_payload, f)

    # Random sample of coordinates for showing cell type clusters.
    if 'X_umap' not in adata.obsm:
        if 'neighbors' not in adata.uns:
            sc.pp.neighbors(adata)
        sc.tl.umap(adata)
    background_payload = _background_points(adata, n_background_points)
    with open(f"{output_dir}/background_map.json", "w") as f:
        json.dump(background_payload, f)

    # Mean ductal cell, used as a starting point for users to perturb.
    mean_stem_z_score = _mean_ductal_profile(adata)
    # Un-scale so the UI gets usable numbers (not -1.7).
    # NOTE(review): assumes adata.X was standardised with this mean/std - confirm.
    usable_stem_vector = np.maximum((mean_stem_z_score * std) + mean, 0.0)
    with open(f"{output_dir}/default_stem_cell.json", "w") as f:
        json.dump(usable_stem_vector.tolist(), f)
# When executed as a script, run dataset creation and then the statistics /
# artifact export, in that order.
if __name__ == "__main__":
    for pipeline_stage in (create_data, get_data_stats):
        pipeline_stage()