RobroKools commited on
Commit
e59f78e
·
verified ·
1 Parent(s): 06ac391

Upload 44 files

Browse files
Files changed (44) hide show
  1. app.py +99 -0
  2. celldreamer/__init__.py +0 -0
  3. celldreamer/__pycache__/__init__.cpython-310.pyc +0 -0
  4. celldreamer/__pycache__/__init__.cpython-313.pyc +0 -0
  5. celldreamer/checkpoints/best.pth +3 -0
  6. celldreamer/checkpoints/last.pth +3 -0
  7. celldreamer/config/evaluate_config.yml +29 -0
  8. celldreamer/config/train_config.yml +30 -0
  9. celldreamer/data/__init__.py +133 -0
  10. celldreamer/data/__pycache__/__init__.cpython-310.pyc +0 -0
  11. celldreamer/data/__pycache__/class_celldreamerDataset.cpython-310.pyc +0 -0
  12. celldreamer/data/__pycache__/download.cpython-310.pyc +0 -0
  13. celldreamer/data/__pycache__/plots.cpython-310.pyc +0 -0
  14. celldreamer/data/__pycache__/process.cpython-310.pyc +0 -0
  15. celldreamer/data/class_celldreamerDataset.py +48 -0
  16. celldreamer/data/download.py +17 -0
  17. celldreamer/data/plots.py +33 -0
  18. celldreamer/data/process.py +59 -0
  19. celldreamer/data/stats/stats.pt +3 -0
  20. celldreamer/environments/environment_cpu.yml +25 -0
  21. celldreamer/environments/environment_gpu.yml +29 -0
  22. celldreamer/logs/CellDreamer_V1_Panc8_20260124-172947/events.out.tfevents.1769304587.wifi-10-45-214-157.wifi.berkeley.edu.83075.0 +3 -0
  23. celldreamer/logs/CellDreamer_V1_Panc8_20260124-173010/events.out.tfevents.1769304610.wifi-10-45-214-157.wifi.berkeley.edu.83336.0 +3 -0
  24. celldreamer/logs/CellDreamer_V1_Panc8_20260125-131802/events.out.tfevents.1769375882.wifi-10-45-214-157.wifi.berkeley.edu.13242.0 +3 -0
  25. celldreamer/models/__init__.py +10 -0
  26. celldreamer/models/__pycache__/__init__.cpython-310.pyc +0 -0
  27. celldreamer/models/__pycache__/__init__.cpython-313.pyc +0 -0
  28. celldreamer/models/__pycache__/class_celldreamer.cpython-310.pyc +0 -0
  29. celldreamer/models/__pycache__/evaluate.cpython-310.pyc +0 -0
  30. celldreamer/models/__pycache__/least_squares_umap.cpython-310.pyc +0 -0
  31. celldreamer/models/__pycache__/networks.cpython-310.pyc +0 -0
  32. celldreamer/models/__pycache__/train.cpython-310.pyc +0 -0
  33. celldreamer/models/class_celldreamer.py +94 -0
  34. celldreamer/models/evaluate.py +145 -0
  35. celldreamer/models/least_squares_umap.py +56 -0
  36. celldreamer/models/networks.py +162 -0
  37. celldreamer/models/train.py +170 -0
  38. celldreamer/results/latent_umap.png +0 -0
  39. celldreamer/results/test_metrics.json +11 -0
  40. celldreamer/scripts/data.sh +3 -0
  41. celldreamer/scripts/evaluate.sh +3 -0
  42. celldreamer/scripts/train.sh +5 -0
  43. master.ipynb +241 -0
  44. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import sys
4
+ import os
5
+
6
+ sys.path.append(os.getcwd())
7
+ from celldreamer.models.class_celldreamer import ClassCellDreamer
8
+ from celldreamer.models import load_config
9
+
10
# Paths to the model configuration, checkpoint, and normalization stats.
CONFIG_PATH = "celldreamer/config/evaluate_config.yml"
CHECKPOINT_PATH = "celldreamer/checkpoints/best.pth"
STATS_PATH = "celldreamer/data/stats/stats.pt"
RNN_DIM = 32  # must match rnn_dim in the training/evaluation config

# Pre-declare module state so a failed initialization leaves well-defined
# values. Fix: predict_api checks `model_wrapper is None`, but previously a
# failure inside the try-block left the name undefined entirely, so the
# check raised NameError instead of returning the intended error payload.
model_wrapper = None
args = None
train_mean = None
train_std = None
STATS_LOADED = False

try:
    args = load_config(CONFIG_PATH)
    args.device = "cpu"  # CPU-only inference for the hosted demo

    model_wrapper = ClassCellDreamer(args)
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    model_wrapper.model.load_state_dict(state_dict)
    model_wrapper.model.eval()
    model_wrapper.model.encoder.eval()
    model_wrapper.model.decoder.eval()
    print("Model loaded successfully.")

    stats = torch.load(STATS_PATH, map_location="cpu")
    train_mean = stats["mean"].view(1, -1)
    train_std = stats["std"].view(1, -1)
    STATS_LOADED = True
    print("Normalization stats loaded.")

except Exception as e:
    print(f"Critical Error during initialization: {e}")
    model_wrapper = None
    STATS_LOADED = False
37
+
38
def normalize_input(x_raw):
    """Apply the training-time preprocessing to a raw expression tensor.

    log1p-transform, z-score with the stored training statistics when they
    were loaded, then clamp at the training clip value of 10.0.
    """
    logged = torch.log1p(x_raw)
    scaled = (logged - train_mean) / train_std if STATS_LOADED else logged
    return torch.clamp(scaled, max=10.0)
47
+
48
def predict_api(input_data):
    """Gradio endpoint: roll the latent dynamics forward from a gene vector.

    input_data: JSON dict with
        genes: list[float] of length args.num_genes (raw counts)
        steps: optional int, number of rollout steps (default 10)

    Returns {"status": "success", "trajectory": [...]} where trajectory is
    the list of latent vectors visited, or {"error": ...} on any failure.
    """
    # Validation — initialization may have failed at import time.
    if model_wrapper is None:
        return {"error": "Model not loaded"}

    try:
        genes = input_data.get("genes")
        steps = int(input_data.get("steps", 10))  # coerce JSON numbers/strings

        x_t = torch.tensor(genes, dtype=torch.float32)
        if x_t.dim() == 1:
            x_t = x_t.unsqueeze(0)

        if x_t.shape[1] != args.num_genes:
            return {"error": f"Gene count mismatch. Expected {args.num_genes}, got {x_t.shape[1]}"}

        x_norm = normalize_input(x_t)

        with torch.no_grad():
            z_mean, z_std = model_wrapper.model.encoder(x_norm)

            z_current = z_mean
            hidden_state = torch.zeros(z_current.size(0), RNN_DIM)

            trajectory = []
            for _ in range(steps):
                trajectory.append(z_current[0].tolist())
                # Fix: carry the recurrent state forward. The original
                # discarded the returned hidden state and passed the
                # all-zero tensor on every step, so the RSSM memory was
                # never used during the rollout.
                hidden_state, velocity_mean, velocity_std = model_wrapper.model.rssm(
                    z_current, hidden_state
                )
                z_current = z_current + velocity_mean

        return {
            "status": "success",
            "trajectory": trajectory
        }

    except Exception as e:
        return {"error": str(e)}
89
+
90
+
91
# JSON-in / JSON-out Gradio interface exposing the latent rollout endpoint.
demo = gr.Interface(
    fn=predict_api,
    inputs=gr.JSON(label="Input Gene Vector"),
    outputs=gr.JSON(label="Output"),
    title="CellDreamer API"
)

if __name__ == "__main__":
    demo.launch()
celldreamer/__init__.py ADDED
File without changes
celldreamer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes). View file
 
celldreamer/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (194 Bytes). View file
 
celldreamer/checkpoints/best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea01e526ec38112a805fe698dfd7f41073a9644bb3db2c369da4ff941c669532
3
+ size 5453065
celldreamer/checkpoints/last.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76bd8aa65b1a7b9193217bc2475b1979e85c72f8ff5bd11d18d477db77baac98
3
+ size 5453065
celldreamer/config/evaluate_config.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run_name: "Eval_CellDreamer_V1"
2
+ model_type: "celldreamer"
3
+ device: "mps"
4
+
5
+
6
+ data_path: "celldreamer/data/datasets"
7
+ checkpoint_path: "celldreamer/checkpoints/best.pth"
8
+ output_dir: "celldreamer/results"
9
+ output_filename: "test_metrics.json"
10
+
11
+ batch_size: 128
12
+ kl_scale: 0.01 # updated to match train_config to prevent posterior collapse
13
+
14
+
15
+ # MUST BE SAME AS TRAINIG CONFIG
16
+ num_genes: 2446
17
+ latent_dim: 50
18
+ rnn_dim: 32
19
+ learning_rate: 25e-6
20
+
21
+ enc_hidden_dims:
22
+ - 256
23
+ - 128
24
+
25
+ dec_hidden_dims:
26
+ - 128
27
+ - 256
28
+
29
+ weight_decay: 1e-3
celldreamer/config/train_config.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run_name: "CellDreamer_V1_Panc8"
2
+ model_type: "celldreamer"
3
+ device: "cuda"
4
+
5
+ data_path: "celldreamer/data/datasets"
6
+ save_dir: "celldreamer/checkpoints"
7
+ log_dir: "celldreamer/logs"
8
+
9
+ epochs: 30
10
+ batch_size: 128 # dreamer uses higher batch sizes to reduce noise from affecting learning
11
+ learning_rate: 25e-6
12
+ log_interval: 10
13
+ save_freq: 10
14
+
15
+ num_genes: 2446
16
+ latent_dim: 50 # z (embedding)
17
+ rnn_dim: 32 # h (memory)
18
+
19
+ # [Input -> 256 -> 128 -> Latent]
20
+ enc_hidden_dims:
21
+ - 256
22
+ - 128
23
+
24
+ # [Latent+RNN -> 128 -> 256 -> Output]
25
+ dec_hidden_dims:
26
+ - 128
27
+ - 256
28
+
29
+ weight_decay: 1e-3
30
+ kl_scale: 0.01 # increased from 0.00001 to prevent posterior collapse. Lower = more dream, higher = more physics emphasis
celldreamer/data/__init__.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import scanpy as sc
4
+ import numpy as np
5
+ import json
6
+ import scipy
7
+
8
+ from celldreamer.data.download import collect_data
9
+ from celldreamer.data.process import process
10
+ from celldreamer.data.plots import validate
11
+ from celldreamer.data.class_celldreamerDataset import CellDreamerDataset
12
+
13
def create_data():
    """Run the full data pipeline and serialize the three dataset splits.

    Downloads the raw data, processes it into pseudotime pairs, renders the
    validation plot, then wraps each split in a CellDreamerDataset and saves
    it under celldreamer/data/datasets/.
    """
    collect_data()
    process()
    validate()

    splits = {
        "train": "celldreamer/data/processed/train_pairs.npy",
        "val": "celldreamer/data/processed/val_pairs.npy",
        "test": "celldreamer/data/processed/test_pairs.npy",
    }

    os.makedirs("celldreamer/data/datasets", exist_ok=True)
    for split_name, pairs_path in splits.items():
        dataset = CellDreamerDataset(pairs_path=pairs_path)
        torch.save(dataset, f"celldreamer/data/datasets/{split_name}.pt")
26
+
27
+
28
def get_data_stats(n_background_points=5000):
    """Compute normalization statistics and export front-end artifacts.

    Saves per-gene mean/std (preferring raw counts when available) to
    celldreamer/data/stats/stats.pt, plus three JSON artifacts for the
    React application: a gene-name index map, a subsampled UMAP background
    scatter, and a default "stem" (ductal) expression vector.

    n_background_points: max number of cells in the background scatter.
    """
    data_path = "celldreamer/data/processed/cleaned.h5ad"
    adata = sc.read(data_path)

    # Prefer raw counts when present; fall back to the processed matrix.
    # (The two branches of the original computed identical mean/std — merged.)
    if adata.raw is not None:
        X_source = adata.raw[:, adata.var_names].X
    else:
        X_source = adata.X
    if scipy.sparse.issparse(X_source):
        X_source = X_source.toarray()

    mean = np.mean(X_source, axis=0)
    std = np.std(X_source, axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant genes

    stats = {
        "mean": torch.tensor(mean),
        "std": torch.tensor(std)
    }
    os.makedirs("celldreamer/data/stats", exist_ok=True)
    torch.save(stats, "celldreamer/data/stats/stats.pt")

    # Artifacts consumed by the React application.
    output_dir = "celldreamer/data/artifacts"
    os.makedirs(output_dir, exist_ok=True)

    # Gene name list (dropdown) and name -> column index map (perturbation).
    gene_names = adata.var_names.tolist()
    gene_indices = {name: i for i, name in enumerate(gene_names)}
    gene_map_payload = {
        "gene_names": gene_names,
        "indices": gene_indices
    }

    with open(f"{output_dir}/gene_map.json", "w") as f:
        json.dump(gene_map_payload, f)

    # Random subset of UMAP coordinates for the cell-type cluster background.
    if 'X_umap' not in adata.obsm:
        if 'neighbors' not in adata.uns:
            sc.pp.neighbors(adata)
        sc.tl.umap(adata)

    total_cells = adata.shape[0]
    if total_cells > n_background_points:
        indices = np.random.choice(total_cells, n_background_points, replace=False)
        indices.sort()
    else:
        indices = np.arange(total_cells)

    umap_coords = adata.obsm['X_umap']
    background_payload = []
    has_celltype = 'celltype' in adata.obs

    for idx in indices:
        idx = int(idx)

        point = {
            "id": idx,
            "x": round(float(umap_coords[idx, 0]), 3),
            "y": round(float(umap_coords[idx, 1]), 3),
            "t": round(float(adata.obs['dpt_pseudotime'].iloc[idx]), 3)
        }

        if has_celltype:
            point["label"] = str(adata.obs['celltype'].iloc[idx])

        background_payload.append(point)

    with open(f"{output_dir}/background_map.json", "w") as f:
        json.dump(background_payload, f)

    # Mean ductal-cell expression as the default perturbation starting point.
    # Fix: guard the 'celltype' lookup — the background loop above already
    # treats the column as optional, but the original accessed it
    # unconditionally here and would raise KeyError when it is absent.
    if has_celltype:
        stem_mask = adata.obs['celltype'].str.contains('ductal', case=False)
    else:
        stem_mask = None

    if stem_mask is None or stem_mask.sum() == 0:
        stem_data = adata.X
    else:
        stem_data = adata.X[stem_mask]

    if scipy.sparse.issparse(stem_data):
        mean_stem_z_score = stem_data.mean(axis=0).A1
    else:
        mean_stem_z_score = stem_data.mean(axis=0)

    # Un-scale so the UI gets raw-count-like numbers (not z-scores like -1.7).
    # NOTE(review): assumes adata.X is z-scored with these same stats —
    # confirm against the processing pipeline.
    usable_stem_vector = (mean_stem_z_score * std) + mean
    usable_stem_vector = np.maximum(usable_stem_vector, 0.0)

    with open(f"{output_dir}/default_stem_cell.json", "w") as f:
        json.dump(usable_stem_vector.tolist(), f)


if __name__ == "__main__":
    create_data()
    get_data_stats()
celldreamer/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.33 kB). View file
 
celldreamer/data/__pycache__/class_celldreamerDataset.cpython-310.pyc ADDED
Binary file (1.66 kB). View file
 
celldreamer/data/__pycache__/download.cpython-310.pyc ADDED
Binary file (693 Bytes). View file
 
celldreamer/data/__pycache__/plots.cpython-310.pyc ADDED
Binary file (1.19 kB). View file
 
celldreamer/data/__pycache__/process.cpython-310.pyc ADDED
Binary file (2.01 kB). View file
 
celldreamer/data/class_celldreamerDataset.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import scanpy as sc
4
+ import numpy as np
5
+ import scipy.sparse
6
+
7
class CellDreamerDataset(Dataset):
    """Pairs of (cell, later-cell) expression vectors ordered by pseudotime.

    Each item is a dict with the current state, the next state, their
    difference, and the pseudotime gap between them.
    """

    def __init__(
        self,
        data_path="celldreamer/data/processed/cleaned.h5ad",
        pairs_path="celldreamer/data/processed/train_pairs.npy",
        normalize=False
    ):
        adata = sc.read(data_path)

        # Sanity print of the value range before any optional normalization.
        data_min = adata.X.min()
        data_max = adata.X.max()
        print(f"min: {data_min:.4f}, max: {data_max:.4f}")

        if normalize:
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)

        self.pairs = np.load(pairs_path)

        matrix = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
        self.data = torch.tensor(matrix, dtype=torch.float32)

        self.times = torch.tensor(adata.obs['dpt_pseudotime'].values, dtype=torch.float32)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        i, j = self.pairs[idx]
        x_t, x_next = self.data[i], self.data[j]
        return {
            "x_t": x_t,
            "x_next": x_next,
            "delta": x_next - x_t,
            "dt": self.times[j] - self.times[i]
        }
celldreamer/data/download.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+ import scanpy as sc
4
+
5
+
6
def collect_data():
    """Download the pancreas tutorial dataset into celldreamer/data/raw.

    Source: https://scanpy-tutorials.readthedocs.io/en/latest/integrating-data-using-ingest.html
    """
    os.makedirs("celldreamer/data/raw", exist_ok=True)

    url = "https://www.dropbox.com/s/qj1jlm9w10wmt0u/pancreas.h5ad?dl=1"
    save_path = "celldreamer/data/raw/panc8_raw.h5ad"
    urllib.request.urlretrieve(url, save_path)

    # Quick shape report to confirm the download is intact.
    adata = sc.read(save_path)
    print(f"{adata.shape[0]} cells x {adata.shape[1]} genes")
celldreamer/data/plots.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scanpy as sc
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+
6
def validate():
    """Render a sanity-check figure of pseudotime and sampled pair arrows.

    Left panel: UMAP colored by pseudotime (expected: blue-to-red gradient
    from early to late). Right panel: UMAP colored by cell type with arrows
    for a random sample of (t, t+1) pairs; many very long arrows indicate
    pairs jumping across UMAP space. Saves the figure under
    celldreamer/data/processed/.
    """
    adata = sc.read("celldreamer/data/processed/cleaned.h5ad")
    pairs = np.load("celldreamer/data/processed/full_set.npy")

    sc.tl.umap(adata)  # compute the UMAP embedding

    fig, axs = plt.subplots(1, 2, figsize=(15, 6))

    sc.pl.umap(adata, color='dpt_pseudotime', ax=axs[0], show=False, title="Pseudotime (Time)")
    sc.pl.umap(adata, color='celltype', ax=axs[1], show=False, title="Pairs (Arrows)")

    umap_coords = adata.obsm['X_umap']

    # Spot-check a random sample; if these look right we assume the rest do.
    # Fix: cap the sample size at the number of pairs — np.random.choice
    # with replace=False raises ValueError when asked for more than exist.
    n_samples = min(100, len(pairs))
    sample_indices = np.random.choice(len(pairs), n_samples, replace=False)
    for idx in sample_indices:
        i, j = pairs[idx]
        start = umap_coords[i]
        end = umap_coords[j]

        axs[1].arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                     head_width=0.3, length_includes_head=True, color='black', alpha=0.5)

    plt.tight_layout()
    plt.savefig("celldreamer/data/processed/dataset_cell_futures.png")
celldreamer/data/process.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scanpy as sc
2
+ import numpy as np
3
+ from sklearn.model_selection import train_test_split
4
+ import os
5
+ import warnings
6
+
7
+ warnings.filterwarnings("ignore", category=FutureWarning, module="anndata")
8
+ warnings.filterwarnings("ignore", message="Moving element from .uns")
9
+
10
def process():
    """Clean the raw data, compute pseudotime, and build (t, t+1) pairs.

    Filters cells/genes, computes a diffusion pseudotime rooted at a ductal
    cell, extracts forward-in-time neighbor pairs, and writes the
    train/val/test pair splits, the full pair set, and the cleaned AnnData
    to celldreamer/data/processed/.
    """
    os.makedirs("celldreamer/data/processed", exist_ok=True)

    adata = sc.read("celldreamer/data/raw/panc8_raw.h5ad")
    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_genes(adata, min_cells=3)
    print(f"cleaned Shape: {adata.shape}")

    print("getting K-nearest neighbors")  # fix: "nieghbors" typo in log line
    sc.pp.pca(adata, n_comps=50)
    sc.pp.neighbors(adata, n_neighbors=30, n_pcs=20)
    sc.tl.diffmap(adata)

    # Pick a ductal cell as the pseudotime root (step-0 "stem" cell);
    # fall back to cell 0 when the lookup fails (e.g. missing column).
    try:
        root_candidates = np.where(adata.obs['celltype'].str.contains('ductal', case=False))[0]
        adata.uns['iroot'] = root_candidates[0] if len(root_candidates) > 0 else 0
    except Exception:  # fix: bare except also swallowed SystemExit/KeyboardInterrupt
        adata.uns['iroot'] = 0

    sc.tl.dpt(adata)

    # Build (t, t+1) pairs from neighbor-graph edges that move forward in time.
    print("creating pairs")
    graph = adata.obsp['connectivities']
    times = adata.obs['dpt_pseudotime'].values
    pairs = []

    rows, cols = graph.nonzero()
    for i, j in zip(rows, cols):
        t_i, t_j = times[i], times[j]

        # forward in time, max gap 0.1 so step sizes stay comparable
        if t_j > t_i and (t_j - t_i) < 0.1:
            pairs.append([i, j])

    pairs = np.array(pairs)
    # Fix: fail with a clear message instead of a cryptic sklearn error
    # inside train_test_split when no pairs were produced.
    if len(pairs) == 0:
        raise RuntimeError("No forward-in-time neighbor pairs found; cannot build splits.")

    train, temp = train_test_split(pairs, test_size=0.2, random_state=42)
    val, test = train_test_split(temp, test_size=0.5, random_state=42)

    np.save("celldreamer/data/processed/train_pairs.npy", train)
    np.save("celldreamer/data/processed/val_pairs.npy", val)
    np.save("celldreamer/data/processed/test_pairs.npy", test)
    print(f"Train({len(train)}), Val({len(val)}), Test({len(test)})")

    adata.write("celldreamer/data/processed/cleaned.h5ad")
    np.save("celldreamer/data/processed/full_set.npy", pairs)
celldreamer/data/stats/stats.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:542bb1069a0d55ba11cc26ffea8ab5e0b94f84198e9614fb68bedc5ddb38b267
3
+ size 20876
celldreamer/environments/environment_cpu.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: celldreamer
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - python=3.10
8
+ - pytorch
9
+ - torchvision
10
+ - torchaudio
11
+ - cpuonly
12
+ - numpy<2.0
13
+ - pandas
14
+ - scipy
15
+ - scikit-learn
16
+ - matplotlib
17
+ - seaborn
18
+ - scanpy
19
+ - python-igraph
20
+ - leidenalg
21
+ - tqdm
22
+ - jupyterlab
23
+ - pip
24
+ - pip:
25
+ - umap-learn
celldreamer/environments/environment_gpu.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: celldreamer
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - python=3.10
9
+ - pytorch
10
+ - torchvision
11
+ - torchaudio
12
+ - pytorch-cuda=11.8 # 12.1 for 40xx card
13
+ - numpy<2.0
14
+ - pandas
15
+ - scipy
16
+ - scikit-learn
17
+ - matplotlib
18
+ - seaborn
19
+ - scanpy
20
+ - python-igraph
21
+ - leidenalg
22
+ - tqdm
23
+ - jupyterlab
24
+ - pip
25
+ - tensorboard
26
+ - pip:
27
+ - umap-learn
28
+ - python-box
29
+ - yaml
celldreamer/logs/CellDreamer_V1_Panc8_20260124-172947/events.out.tfevents.1769304587.wifi-10-45-214-157.wifi.berkeley.edu.83075.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba268787e132a3ee99092028bcae8b0cc2a6737f3e37b428a101632ed03cf2e8
3
+ size 88
celldreamer/logs/CellDreamer_V1_Panc8_20260124-173010/events.out.tfevents.1769304610.wifi-10-45-214-157.wifi.berkeley.edu.83336.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82a58fcb7f8cac67888636752fdd96097d97f2ebd42843efb67bfe6e17ff11eb
3
+ size 84568
celldreamer/logs/CellDreamer_V1_Panc8_20260125-131802/events.out.tfevents.1769375882.wifi-10-45-214-157.wifi.berkeley.edu.13242.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42ffdfd63227943975a940d4386c61e4c4c84e454fe3976984dff168a790b4b0
3
+ size 88
celldreamer/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from box import Box
3
+
4
def load_config(path):
    """Load a YAML config file into an attribute-accessible Box.

    learning_rate and weight_decay are coerced to float explicitly —
    presumably because values like "25e-6" are parsed as strings by the
    YAML loader (no digit after the decimal point); TODO confirm.
    """
    with open(path, 'r') as f:
        config = Box(yaml.safe_load(f))

    config.learning_rate = float(config.learning_rate)
    config.weight_decay = float(config.weight_decay)
    return config
celldreamer/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (521 Bytes). View file
 
celldreamer/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (786 Bytes). View file
 
celldreamer/models/__pycache__/class_celldreamer.cpython-310.pyc ADDED
Binary file (2.73 kB). View file
 
celldreamer/models/__pycache__/evaluate.cpython-310.pyc ADDED
Binary file (3.45 kB). View file
 
celldreamer/models/__pycache__/least_squares_umap.cpython-310.pyc ADDED
Binary file (1.64 kB). View file
 
celldreamer/models/__pycache__/networks.cpython-310.pyc ADDED
Binary file (3.85 kB). View file
 
celldreamer/models/__pycache__/train.cpython-310.pyc ADDED
Binary file (3.52 kB). View file
 
celldreamer/models/class_celldreamer.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from celldreamer.models.networks import CellDreamer
5
+
6
class ClassCellDreamer:
    """Training wrapper around the CellDreamer network.

    Owns the model, its Adam optimizer, and the loss computation
    (reconstruction MSE + dynamics KL + posterior KL with free bits).
    """

    def __init__(self, args):
        self.args = args
        self.device = args.device

        self.model = CellDreamer(
            device=torch.device(args.device),
            latent_dim=args.latent_dim,
            rnn_dim=args.rnn_dim,
            enc_hidden_dims=args.enc_hidden_dims,
            dec_hidden_dims=args.dec_hidden_dims,
            num_genes=args.num_genes
        )
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay
        )
        self.kl_scale = getattr(args, 'kl_scale', 0.1)  # default 0.1

    def get_kl_loss(self, mean1, std1, mean2, std2):
        """Batch mean of KL(N(mean1, std1) || N(mean2, std2)), summed over latent dims."""
        dist1 = torch.distributions.Normal(mean1, std1)
        dist2 = torch.distributions.Normal(mean2, std2)
        return torch.distributions.kl_divergence(dist1, dist2).sum(dim=1).mean()

    def train_step(self, x_t, x_next, current_epoch, total_epochs):
        """Run one optimization step on a (x_t, x_next) batch.

        Returns a dict of scalar loss components plus the effective KL weight.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Linear KL warm-up over the first half of training.
        # Fix: guard against total_epochs < 2 — `total_epochs // 2` was 0
        # and the division raised ZeroDivisionError.
        warmup_period = max(1, total_epochs // 2)
        kl_weight = min(1.0, current_epoch / warmup_period)

        effective_kl = self.kl_scale * kl_weight

        outputs = self.model(x_t)
        # Target posterior for the next state; no gradients flow through it.
        with torch.no_grad():
            target_mean, target_std = self.model.encoder(x_next)

        recon_loss = F.mse_loss(outputs["recon_x"], x_t)

        # Dynamics KL: KL(posterior(x_next) || prior_next)
        dynamics_loss = self.get_kl_loss(
            target_mean, target_std,
            outputs["prior_next_mean"], outputs["prior_next_std"]
        )

        # Posterior-prior KL against N(0,1) — standard VAE regularization
        # added to prevent posterior collapse.
        zeros = torch.zeros_like(outputs["post_mean"])
        ones = torch.ones_like(outputs["post_std"])
        posterior_kl = self.get_kl_loss(
            outputs["post_mean"], outputs["post_std"],
            zeros, ones
        )

        # Free bits: enforce a minimum KL per dimension so the model keeps
        # at least some information capacity in the latent.
        free_bits_per_dim = 0.1  # minimum nats per dimension
        min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
        posterior_kl = torch.clamp(posterior_kl, min=min_kl)
        dynamics_loss = torch.clamp(dynamics_loss, min=min_kl)

        # Total: reconstruction + dynamics KL + posterior regularization.
        total_loss = recon_loss + (effective_kl * dynamics_loss) + (effective_kl * posterior_kl)

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {
            "loss": total_loss.item(),
            "recon_loss": recon_loss.item(),
            "dynamics_loss": dynamics_loss.item(),
            "posterior_kl": posterior_kl.item(),
            "kl_weight": effective_kl
        }

    def save(self, path):
        """Serialize the model weights to *path*."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load model weights from *path* onto the wrapper's device."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))
celldreamer/models/evaluate.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from tqdm import tqdm
4
+ import os
5
+ import numpy as np
6
+ import json
7
+ import argparse
8
+ import sys
9
+ import umap
10
+ import matplotlib.pyplot as plt
11
+
12
+ from celldreamer.models.class_celldreamer import ClassCellDreamer
13
+ from celldreamer.models import load_config
14
+
15
+
16
def evaluate(args):
    """Evaluate a trained CellDreamer checkpoint on the held-out test split.

    Computes the same loss terms as training (reconstruction MSE, dynamics
    KL, posterior KL with free bits), writes aggregate metrics to a JSON
    file under args.output_dir, and saves a UMAP plot of the test-set
    posterior means.

    args: config object with device, data_path, batch_size, model_type,
        checkpoint_path, kl_scale, output_dir, output_filename.
    Raises FileNotFoundError when the serialized test dataset is missing.
    """

    device = torch.device(args.device)

    os.makedirs(args.output_dir, exist_ok=True)

    test_path = f"{args.data_path}/test.pt"
    print(f"Loading test dataset from {test_path}...")

    if not os.path.exists(test_path):
        raise FileNotFoundError(f"Test dataset not found at {test_path}")

    # weights_only=False: the file stores a pickled Dataset object, not a
    # plain state dict. NOTE(review): unpickling is only safe for trusted
    # local files.
    test_ds = torch.load(test_path, weights_only=False)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=2)

    print(f"Test Size: {len(test_ds)} samples")

    print(f"Initializing Model: {args.model_type}")

    if args.model_type.lower() == "celldreamer":
        model_wrapper = ClassCellDreamer(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    model_wrapper.load(args.checkpoint_path)
    model_wrapper.model.eval()

    # Per-batch scalar losses, averaged at the end.
    test_recon_losses = []
    test_dynamics_losses = []
    test_posterior_kl_losses = []
    test_total_losses = []

    # Posterior means collected for the latent-space UMAP plot.
    all_latents = []

    print("Running inference...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)

            outputs = model_wrapper.model(x_t)

            # Target posterior for the next state, same as in training.
            target_mean, target_std = model_wrapper.model.encoder(x_next)
            recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)

            # Dynamics KL: KL(posterior(x_next) || predicted prior)
            dyn_loss = model_wrapper.get_kl_loss(
                target_mean, target_std,
                outputs["prior_next_mean"], outputs["prior_next_std"]
            )

            # Add posterior KL for consistency with training
            zeros = torch.zeros_like(outputs["post_mean"])
            ones = torch.ones_like(outputs["post_std"])
            post_kl = model_wrapper.get_kl_loss(
                outputs["post_mean"], outputs["post_std"],
                zeros, ones
            )

            # Apply same free bits constraint as training
            free_bits_per_dim = 0.1
            min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
            post_kl = torch.clamp(post_kl, min=min_kl)
            dyn_loss = torch.clamp(dyn_loss, min=min_kl)

            # Use same loss computation as training
            total_loss = recon_loss + (args.kl_scale * dyn_loss) + (args.kl_scale * post_kl)

            test_recon_losses.append(recon_loss.item())
            test_dynamics_losses.append(dyn_loss.item())
            test_posterior_kl_losses.append(post_kl.item())
            test_total_losses.append(total_loss.item())

            all_latents.append(outputs["post_mean"].cpu())

    metrics = {
        "model": args.model_type,
        "checkpoint": args.checkpoint_path,
        "test_samples": len(test_ds),
        "metrics": {
            "avg_total_loss": float(np.mean(test_total_losses)),
            "avg_recon_loss_mse": float(np.mean(test_recon_losses)),
            "avg_dynamics_loss_kl": float(np.mean(test_dynamics_losses)),
            "avg_posterior_kl": float(np.mean(test_posterior_kl_losses)),
            "std_total_loss": float(np.std(test_total_losses))
        }
    }

    print("Results:")
    print(f"MSE (Rec): {metrics['metrics']['avg_recon_loss_mse']:.6f}")
    print(f"KL (Dynamics/Dream): {metrics['metrics']['avg_dynamics_loss_kl']:.6f}")
    print(f"KL (Posterior): {metrics['metrics']['avg_posterior_kl']:.6f}")
    print(f"Total Loss: {metrics['metrics']['avg_total_loss']:.6f}")

    output_file_path = os.path.join(args.output_dir, args.output_filename)
    with open(output_file_path, 'w') as f:
        json.dump(metrics, f, indent=4)

    print(f"\nResults saved to: {output_file_path}")

    # 2-D UMAP of the posterior means for a qualitative look at the latent space.
    print("Generating UMAP visualization...")
    latents_tensor = torch.cat(all_latents)

    reducer = umap.UMAP(n_components=2)
    coords = reducer.fit_transform(latents_tensor.numpy())

    plt.figure(figsize=(10, 8))
    plt.scatter(coords[:, 0], coords[:, 1], s=1, alpha=0.5)
    plt.title("Latent Space Visualization")

    umap_path = os.path.join(args.output_dir, "latent_umap.png")
    plt.savefig(umap_path)
    plt.close()

    print(f"UMAP plot saved to {umap_path}")
130
+
131
+
132
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Evaluation script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        # Fix: the repository ships celldreamer/config/evaluate_config.yml;
        # the old default ("eval_config.yml") pointed at a non-existent file,
        # so running without --config always failed.
        default="celldreamer/config/evaluate_config.yml",
        help="Path to the YAML configuration file"
    )

    args = parser.parse_args()
    config = load_config(args.config)

    evaluate(config)
celldreamer/models/least_squares_umap.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import scanpy as sc
3
+ import os
4
+
5
+ from celldreamer.models.class_celldreamer import ClassCellDreamer
6
+ from celldreamer.models import load_config
7
+
8
def solve_projector():
    """Fit a linear map from the model's latent space to UMAP coordinates.

    Solves the least-squares system Z @ B = Y_umap and stores the result as
    a Linear-layer-style state dict (weight + zero bias) under
    celldreamer/data/artifacts/.
    """
    # Load data, normalization stats, and the trained model.
    adata = sc.read("celldreamer/data/processed/cleaned.h5ad")
    stats = torch.load("celldreamer/data/stats/stats.pt", weights_only=False)

    args = load_config("celldreamer/config/evaluate_config.yml")
    args.device = "cpu"
    wrapper = ClassCellDreamer(args)
    wrapper.model.load_state_dict(
        torch.load("celldreamer/checkpoints/best.pth", map_location="cpu", weights_only=True)
    )
    wrapper.model.eval()

    if 'X_umap' not in adata.obsm:
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)

    Y_umap = torch.tensor(adata.obsm['X_umap'], dtype=torch.float32)

    # Prefer raw counts when available, otherwise the processed matrix.
    data = adata.raw[:, adata.var_names].X if adata.raw is not None else adata.X
    if hasattr(data, "toarray"):
        data = data.toarray()

    # Same preprocessing as training: log1p, z-score, clamp at 10.
    x_in = torch.log1p(torch.tensor(data, dtype=torch.float32))
    x_in = torch.clamp((x_in - stats["mean"]) / stats["std"], max=10.0)

    with torch.no_grad():
        Z_latent, _ = wrapper.model.encoder(x_in)

    # Least-squares solve of Z @ B = Y (normal equations X^T X b = X^T y).
    solution = torch.linalg.lstsq(Z_latent, Y_umap).solution

    state_dict = {
        "weight": solution.T,
        "bias": torch.zeros(2)  # intercept intentionally unused
    }

    os.makedirs("celldreamer/data/artifacts", exist_ok=True)
    torch.save(state_dict, "celldreamer/data/artifacts/projector_weights.pth")

if __name__ == "__main__":
    solve_projector()
celldreamer/models/networks.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ # define a mlp encoder
7
+ # inputs: batch x num_genes (2446)
8
+ # outputs: batch x ecoding_dim
9
# MLP encoder
# inputs: batch x num_genes (2446)
# outputs: batch x latent_dim posterior parameters
class Encoder(nn.Module):
    """Encode gene-expression vectors into a diagonal Gaussian posterior.

    Forward maps a (batch, num_genes) tensor to a (mean, std) pair, each of
    shape (batch, latent_dim).
    """

    def __init__(self, latent_dim, hidden_dims, num_genes=2446):
        super().__init__()

        blocks = []
        in_features = num_genes
        for width in hidden_dims:
            blocks.extend([
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ELU(),
                nn.Dropout(0.4),
            ])
            in_features = width

        self.enc_net = nn.Sequential(*blocks)

        # Two parallel heads produce the posterior parameters.
        self.fc_mean = nn.Linear(in_features, latent_dim)
        self.fc_std = nn.Linear(in_features, latent_dim)

    def forward(self, x_t):
        """Return (mean, std) of the approximate posterior for batch x_t."""
        features = self.enc_net(x_t)

        # softplus keeps std positive; the 1e-3 floor prevents it from
        # collapsing to near-zero (posterior collapse).
        return (
            self.fc_mean(features),
            F.softplus(self.fc_std(features)) + 1e-3,
        )
41
+
42
+
43
+ # define a corresponding mlp decoder
44
+ # input: batch x ecoding_dim + rnn_hidden_dim
45
# MLP decoder, mirror of the encoder
# input: batch x (latent_dim + rnn_hidden_dim)
class Decoder(nn.Module):
    """Reconstruct gene expression from the latent sample concatenated with
    the recurrent hidden state."""

    def __init__(self, latent_dim, rnn_hidden_dim, hidden_dims, num_genes=2446):
        super().__init__()

        modules = []
        width_in = latent_dim + rnn_hidden_dim
        for width_out in hidden_dims:
            modules.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ELU(),
                nn.Dropout(0.4),
            ])
            width_in = width_out

        # Final linear head back to gene space (no output activation).
        modules.append(nn.Linear(width_in, num_genes))
        self.dec_net = nn.Sequential(*modules)

    def forward(self, z, h):
        """Decode the concatenation [z, h] into a (batch, num_genes) tensor."""
        return self.dec_net(torch.cat([z, h], dim=1))
70
+
71
+ # define a gru-based rssm
72
+ # input: batch x ecoding_dim at t=0
73
+ # output: batch x 2*encoding_dim at t = 1 to get the mean and standard deviation
74
+
75
# GRU-based recurrent state-space model
# one step: (prev latent, prev hidden) -> (new hidden, mean, std)
class RSSM(nn.Module):
    """Predict the distribution over the next latent change from the
    recurrent state, advancing a GRU cell by one step per call."""

    def __init__(self, latent_dim, rnn_hidden_dim):
        super().__init__()

        self.latent_dim = latent_dim
        self.hidden_dim = rnn_hidden_dim

        self.gru = nn.GRUCell(latent_dim, rnn_hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim),
            nn.LayerNorm(rnn_hidden_dim),
            nn.ELU(),
            nn.Linear(rnn_hidden_dim, 2 * latent_dim)
        )

        # Small-gain Xavier init on the output head keeps early predictions
        # near zero while preserving gradient flow.
        nn.init.xavier_uniform_(self.mlp[3].weight, gain=0.1)
        nn.init.zeros_(self.mlp[3].bias)

    def forward(self, prev_r, prev_h):
        """Advance the recurrent state one step and emit prior statistics."""
        new_h = self.gru(prev_r, prev_h)

        stats = self.mlp(new_h)
        mean, raw_std = torch.chunk(stats, 2, dim=1)

        # softplus plus a 1e-3 floor keeps the std strictly positive.
        return new_h, mean, F.softplus(raw_std) + 1e-3
108
+
109
+
110
+ # create joint training architecture for dreamer
111
# create joint training architecture for dreamer
class CellDreamer(nn.Module):
    """World-model-style architecture: an Encoder (posterior over the latent
    state), an RSSM (one-step latent dynamics prior) and a Decoder
    (reconstruction of gene expression).
    """

    def __init__(
        self,
        device,
        latent_dim = 20,
        rnn_dim = 64,
        # Tuples instead of lists: immutable defaults avoid the shared
        # mutable-default-argument pitfall. Callers passing lists still work.
        enc_hidden_dims = (128, 64, 32),
        dec_hidden_dims = (32, 64, 128),
        num_genes = 2446
    ):
        super().__init__()

        self.encoder = Encoder(latent_dim, enc_hidden_dims, num_genes)
        self.decoder = Decoder(latent_dim, rnn_dim, dec_hidden_dims, num_genes)
        self.rssm = RSSM(latent_dim, rnn_dim)

        self.rnn_dim = rnn_dim
        self.latent_dim = latent_dim
        self.input_dim = num_genes
        self.device = device

    def reparametrize(self, mean, std):
        """Sample z = mean + eps * std with eps ~ N(0, I) (reparameterization trick)."""
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x_t):
        """Encode x_t, roll the latent one step through the RSSM and decode.

        Returns a dict with the reconstruction, posterior stats, the
        one-step-ahead prior stats, the sampled latent and the GRU state.
        """
        post_mean, post_std = self.encoder(x_t)
        z_t = self.reparametrize(post_mean, post_std)

        # Fresh hidden state per batch; allocate directly on the target
        # device instead of allocating on CPU and copying with .to().
        h_prev = torch.zeros(x_t.size(0), self.rnn_dim, device=self.device)

        h_next, velocity_mean, velocity_std = self.rssm(z_t, h_prev)
        # The RSSM predicts a velocity; the prior over the next latent is a
        # residual update around the current sample.
        prior_next_mean = z_t + velocity_mean
        prior_next_std = velocity_std

        rec_x = self.decoder(z_t, h_next)

        return {
            "recon_x": rec_x,
            "post_mean": post_mean,
            "post_std": post_std,
            "prior_next_mean": prior_next_mean,
            "prior_next_std": prior_next_std,
            "z_t": z_t,
            "h_next": h_next
        }
160
+
161
+
162
+
celldreamer/models/train.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from torch.utils.tensorboard import SummaryWriter
4
+ from tqdm import tqdm
5
+ import os
6
+ import numpy as np
7
+ from datetime import datetime
8
+ import argparse
9
+
10
+ from celldreamer.models.class_celldreamer import ClassCellDreamer
11
+ from celldreamer.models import load_config
12
+
13
+
14
def train(args):
    """Train a CellDreamer model from a config namespace.

    Expects `args` to provide at least: device, save_dir, log_dir, run_name,
    data_path, batch_size, model_type, epochs, log_interval, save_freq.
    Side effects: writes TensorBoard logs under `log_dir` and checkpoints
    `last.pth` / `best.pth` under `save_dir`.
    """
    device = torch.device(args.device)

    # Make sure all output locations exist before any writes.
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    # Timestamped run directory so repeated runs never clobber each other.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    writer = SummaryWriter(f"{args.log_dir}/{args.run_name}_{timestamp}")

    print(f"Loading datasets from {args.data_path}")

    # Pre-serialized Dataset objects (pickled, hence weights_only=False).
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load dataset files produced by this project's own pipeline.
    train_ds = torch.load(f"{args.data_path}/train.pt", weights_only=False)
    val_ds = torch.load(f"{args.data_path}/val.pt", weights_only=False)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    print(f"Train Size: {len(train_ds)} samples")
    print(f"Val Size: {len(val_ds)} samples")
    print(f"Model: {args.model_type}")

    # Only one model family is wired up; fail loudly on anything else.
    if args.model_type.lower() == "celldreamer":
        model_wrapper = ClassCellDreamer(args)
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    global_step = 0
    best_val_loss = float('inf')
    best_val_mse = float('inf') # Track best validation MSE separately

    for epoch in range(1, args.epochs + 1):

        # --- TRAIN ---
        model_wrapper.model.train()
        train_mse = []
        train_kl = []
        train_posterior_kl = []
        train_total = []

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs} [Train]")

        for batch in loop:
            # Each batch carries a pair of consecutive time points.
            x_t = batch['x_t'].to(device)
            x_next = batch['x_next'].to(device)

            # Optimization step is delegated to the wrapper; it returns a
            # dict of scalar logs. NOTE(review): assumes 'loss',
            # 'recon_loss' and 'dynamics_loss' keys are always present.
            logs = model_wrapper.train_step(x_t, x_next, epoch, args.epochs)

            train_total.append(logs['loss'])
            train_mse.append(logs['recon_loss'])
            train_kl.append(logs['dynamics_loss'])
            # posterior_kl may be absent from older wrappers; default to 0.
            train_posterior_kl.append(logs.get('posterior_kl', 0))

            global_step += 1

            # Step-level TensorBoard logging at a configurable interval.
            if global_step % args.log_interval == 0:
                writer.add_scalar("Step/Total_Loss", logs['loss'], global_step)
                writer.add_scalar("Step/Recon_Loss", logs['recon_loss'], global_step)
                writer.add_scalar("Step/Dynamics_KL", logs['dynamics_loss'], global_step)
                writer.add_scalar("Step/Posterior_KL", logs.get('posterior_kl', 0), global_step)

            loop.set_postfix(loss=logs['loss'])

        # --- VALIDATION ---
        model_wrapper.model.eval()
        val_mse = []
        val_kl = []
        val_posterior_kl = []
        val_total = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch}/{args.epochs} [Val]  "):
                x_t = batch['x_t'].to(device)
                x_next = batch['x_next'].to(device)

                outputs = model_wrapper.model(x_t)
                # Posterior over the *next* observation acts as the KL target.
                target_mean, target_std = model_wrapper.model.encoder(x_next)

                recon_loss = torch.nn.functional.mse_loss(outputs["recon_x"], x_t)
                dyn_loss = model_wrapper.get_kl_loss(
                    target_mean, target_std,
                    outputs["prior_next_mean"], outputs["prior_next_std"]
                )

                # Add posterior KL for consistency with training
                zeros = torch.zeros_like(outputs["post_mean"])
                ones = torch.ones_like(outputs["post_std"])
                post_kl = model_wrapper.get_kl_loss(
                    outputs["post_mean"], outputs["post_std"],
                    zeros, ones
                )

                # Apply same free bits constraint as training
                free_bits_per_dim = 0.1
                min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
                post_kl = torch.clamp(post_kl, min=min_kl)
                dyn_loss = torch.clamp(dyn_loss, min=min_kl)

                # Compute KL weight same as training.
                # NOTE(review): warmup_period = epochs // 2 is 0 when
                # args.epochs < 2, making the division below raise
                # ZeroDivisionError — confirm configs guarantee epochs >= 2.
                warmup_period = args.epochs // 2
                kl_weight = min(1.0, (epoch / warmup_period))
                effective_kl = model_wrapper.kl_scale * kl_weight
                total_val_loss = recon_loss + (effective_kl * dyn_loss) + (effective_kl * post_kl)

                val_total.append(total_val_loss.item())
                val_mse.append(recon_loss.item())
                val_kl.append(dyn_loss.item())
                val_posterior_kl.append(post_kl.item())

        # --- STATS ---
        avg_train_loss = np.mean(train_total)
        avg_val_loss = np.mean(val_total)

        # Epoch-level train/val curves on shared TensorBoard charts.
        writer.add_scalars("Epoch/MSE", {'Train': np.mean(train_mse), 'Val': np.mean(val_mse)}, epoch)
        writer.add_scalars("Epoch/Dynamics_KL", {'Train': np.mean(train_kl), 'Val': np.mean(val_kl)}, epoch)
        writer.add_scalars("Epoch/Posterior_KL", {'Train': np.mean(train_posterior_kl), 'Val': np.mean(val_posterior_kl)}, epoch)

        # Calculate KL contribution to understand why validation loss isn't dropping
        # (recomputed here with the same warmup schedule as the inner loop).
        warmup_period = args.epochs // 2
        kl_weight = min(1.0, (epoch / warmup_period))
        effective_kl = model_wrapper.kl_scale * kl_weight
        val_kl_contribution = effective_kl * (np.mean(val_kl) + np.mean(val_posterior_kl))
        train_kl_contribution = effective_kl * (np.mean(train_kl) + np.mean(train_posterior_kl))

        print(f"Stats: Train MSE: {np.mean(train_mse):.4f} | Val MSE: {np.mean(val_mse):.4f} | Train Dyn KL: {np.mean(train_kl):.4f} | Val Dyn KL: {np.mean(val_kl):.4f} | Train Post KL: {np.mean(train_posterior_kl):.4f} | Val Post KL: {np.mean(val_posterior_kl):.4f}")
        print(f"Loss Breakdown: Train Total: {avg_train_loss:.4f} (MSE: {np.mean(train_mse):.4f} + KL: {train_kl_contribution:.4f}) | Val Total: {avg_val_loss:.4f} (MSE: {np.mean(val_mse):.4f} + KL: {val_kl_contribution:.4f}) | KL Weight: {effective_kl:.6f}")

        # Periodic rolling checkpoint (overwritten each time).
        if epoch % args.save_freq == 0:
            model_wrapper.save(f"{args.save_dir}/last.pth")

        avg_val_mse = np.mean(val_mse)
        if avg_val_loss < best_val_loss:
            print(f"Best Total Loss: ({best_val_loss:.4f} -> {avg_val_loss:.4f})")
            best_val_loss = avg_val_loss

        # Also track best validation MSE (more meaningful metric);
        # note: best.pth is gated on MSE, not on total loss.
        if avg_val_mse < best_val_mse:
            print(f"Best Val MSE: ({best_val_mse:.4f} -> {avg_val_mse:.4f}) - Saving best model")
            best_val_mse = avg_val_mse
            model_wrapper.save(f"{args.save_dir}/best.pth")

    writer.close()
155
+
156
+
157
if __name__ == "__main__":

    # Fixed typos in user-facing strings: "trainig" -> "training",
    # "YmML" -> "YAML".
    parser = argparse.ArgumentParser(description="training script for celldreamer")
    parser.add_argument(
        "--config",
        type=str,
        default="celldreamer/config/train_config.yml",
        help="Path to the YAML configuration file (default: celldreamer/config/train_config.yml)"
    )

    args = parser.parse_args()
    config = load_config(args.config)

    train(config)
celldreamer/results/latent_umap.png ADDED
celldreamer/results/test_metrics.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": "celldreamer",
3
+ "checkpoint": "celldreamer/checkpoints/best.pth",
4
+ "test_samples": 18253,
5
+ "metrics": {
6
+ "avg_total_loss": 0.6892188849982682,
7
+ "avg_recon_loss_mse": 0.6890018098837846,
8
+ "avg_dynamics_loss_kl": 21.70746588540244,
9
+ "std_total_loss": 0.03752287398763396
10
+ }
11
+ }
celldreamer/scripts/data.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
#!/bin/bash
# Build the processed datasets by running the data package's module code.
# Fail fast on any error or unset variable.
set -euo pipefail

# NOTE(review): `python -m pkg.__init__` imports the package twice and emits
# a RuntimeWarning; the clean fix is adding celldreamer/data/__main__.py and
# invoking `python -m celldreamer.data`. Kept as-is since no __main__.py
# exists yet.
python -m celldreamer.data.__init__
celldreamer/scripts/evaluate.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
#!/bin/bash
# Evaluate a trained model. Usage: evaluate.sh <config.yml>
# Fail fast on any error or unset variable ($1 missing now errors clearly).
set -euo pipefail

# Quote "$1" so config paths containing spaces survive word splitting.
python -m celldreamer.models.evaluate --config "$1"
celldreamer/scripts/train.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
#!/bin/bash
# Train a model, then fit the UMAP projector. Usage: train.sh <config.yml>
# set -e ensures the projector fit is skipped when training fails
# (previously it ran unconditionally on a possibly stale checkpoint).
set -euo pipefail

# Quote "$1" so config paths containing spaces survive word splitting.
python -m celldreamer.models.train --config "$1"

python -m celldreamer.models.least_squares_umap
master.ipynb ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "d6fc963a",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%load_ext autoreload"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "6cf002c0",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "%autoreload 2"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 15,
26
+ "id": "5e29d1c0",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "import torch\n",
31
+ "\n",
32
+ "ds = torch.load(\"/Users/rohitkulkarni/Documents/projects/CellDreamer/backend/celldreamer/data/datasets/train.pt\", weights_only=False)"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 16,
38
+ "id": "ebe6280f",
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "data": {
43
+ "text/plain": [
44
+ "torch.Size([2446])"
45
+ ]
46
+ },
47
+ "execution_count": 16,
48
+ "metadata": {},
49
+ "output_type": "execute_result"
50
+ }
51
+ ],
52
+ "source": [
53
+ "ds[0][\"x_t\"].shape"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 1,
59
+ "id": "f9454346",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "name": "stdout",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "Calculating stats from data matrix...\n"
67
+ ]
68
+ }
69
+ ],
70
+ "source": [
71
+ "from celldreamer.data import get_data_stats\n",
72
+ "\n",
73
+ "get_data_stats()"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 17,
79
+ "id": "8c8ff06c",
80
+ "metadata": {},
81
+ "outputs": [
82
+ {
83
+ "name": "stdout",
84
+ "output_type": "stream",
85
+ "text": [
86
+ "Loaded as API: https://robrokools-celldreamer-api.hf.space\n"
87
+ ]
88
+ },
89
+ {
90
+ "data": {
91
+ "text/plain": [
92
+ "array([[ 0.20221904, -0.10513306, -0.23988042, 0.1219071 , -0.31176904,\n",
93
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
94
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
95
+ " -0.00870946, -0.18495346, 0.0982306 , 0.19570428, 0.03290927,\n",
96
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
97
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
98
+ " 0.24255574, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
99
+ " 0.03532511, 0.0018872 , -0.07421678, -0.18519297, -0.09254473,\n",
100
+ " -0.18334997, -0.19211988, -0.07095522, 0.08980912, 0.09272885,\n",
101
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
102
+ " [ 0.20221904, -0.10513306, -0.23988041, 0.12190711, -0.31176903,\n",
103
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
104
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
105
+ " -0.00870946, -0.18495346, 0.0982306 , 0.19570431, 0.03290927,\n",
106
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
107
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
108
+ " 0.24255586, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
109
+ " 0.03532511, 0.0018872 , -0.0742168 , -0.18519297, -0.09254467,\n",
110
+ " -0.18334997, -0.19211988, -0.07095522, 0.08980912, 0.09272885,\n",
111
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
112
+ " [ 0.20221904, -0.10513306, -0.23988041, 0.12190713, -0.31176903,\n",
113
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
114
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
115
+ " -0.00870946, -0.18495346, 0.0982306 , 0.19570434, 0.03290927,\n",
116
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
117
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
118
+ " 0.24255598, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
119
+ " 0.03532511, 0.0018872 , -0.07421681, -0.18519297, -0.09254462,\n",
120
+ " -0.18334997, -0.19211989, -0.07095522, 0.08980912, 0.09272885,\n",
121
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
122
+ " [ 0.20221904, -0.10513306, -0.2398804 , 0.12190714, -0.31176902,\n",
123
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
124
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
125
+ " -0.00870946, -0.18495345, 0.0982306 , 0.19570437, 0.03290927,\n",
126
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
127
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
128
+ " 0.2425561 , 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
129
+ " 0.03532511, 0.0018872 , -0.07421683, -0.18519297, -0.09254456,\n",
130
+ " -0.18334997, -0.1921199 , -0.07095522, 0.08980912, 0.09272885,\n",
131
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
132
+ " [ 0.20221904, -0.10513306, -0.23988039, 0.12190716, -0.31176901,\n",
133
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
134
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
135
+ " -0.00870946, -0.18495345, 0.0982306 , 0.1957044 , 0.03290927,\n",
136
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
137
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
138
+ " 0.24255621, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
139
+ " 0.03532511, 0.0018872 , -0.07421684, -0.18519297, -0.0925445 ,\n",
140
+ " -0.18334997, -0.1921199 , -0.07095522, 0.08980912, 0.09272885,\n",
141
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
142
+ " [ 0.20221904, -0.10513306, -0.23988038, 0.12190717, -0.311769 ,\n",
143
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
144
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
145
+ " -0.00870946, -0.18495345, 0.0982306 , 0.19570443, 0.03290927,\n",
146
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
147
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
148
+ " 0.24255633, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
149
+ " 0.03532511, 0.0018872 , -0.07421686, -0.18519297, -0.09254444,\n",
150
+ " -0.18334997, -0.19211991, -0.07095522, 0.08980912, 0.09272885,\n",
151
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
152
+ " [ 0.20221904, -0.10513306, -0.23988038, 0.12190719, -0.311769 ,\n",
153
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
154
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
155
+ " -0.00870946, -0.18495344, 0.0982306 , 0.19570446, 0.03290927,\n",
156
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
157
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
158
+ " 0.24255645, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
159
+ " 0.03532511, 0.0018872 , -0.07421687, -0.18519297, -0.09254438,\n",
160
+ " -0.18334997, -0.19211992, -0.07095522, 0.08980912, 0.09272885,\n",
161
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
162
+ " [ 0.20221904, -0.10513306, -0.23988037, 0.1219072 , -0.31176899,\n",
163
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
164
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
165
+ " -0.00870946, -0.18495344, 0.0982306 , 0.19570449, 0.03290927,\n",
166
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
167
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
168
+ " 0.24255657, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
169
+ " 0.03532511, 0.0018872 , -0.07421689, -0.18519297, -0.09254432,\n",
170
+ " -0.18334997, -0.19211993, -0.07095522, 0.08980912, 0.09272885,\n",
171
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
172
+ " [ 0.20221904, -0.10513306, -0.23988036, 0.12190722, -0.31176898,\n",
173
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
174
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
175
+ " -0.00870946, -0.18495343, 0.0982306 , 0.19570452, 0.03290927,\n",
176
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
177
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
178
+ " 0.24255669, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
179
+ " 0.03532511, 0.0018872 , -0.0742169 , -0.18519297, -0.09254426,\n",
180
+ " -0.18334997, -0.19211993, -0.07095522, 0.08980912, 0.09272885,\n",
181
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041],\n",
182
+ " [ 0.20221904, -0.10513306, -0.23988035, 0.12190723, -0.31176898,\n",
183
+ " -0.07312202, -0.17483664, 0.34703633, -0.14286399, 0.01501414,\n",
184
+ " 0.24577391, -0.17025626, -0.01052079, -0.16482973, 0.01907933,\n",
185
+ " -0.00870946, -0.18495343, 0.0982306 , 0.19570455, 0.03290927,\n",
186
+ " -0.08225775, -0.14782619, -0.00959128, -0.04247084, -0.09117351,\n",
187
+ " 0.02470946, -0.0560773 , -0.0605984 , -0.18847048, 0.06813312,\n",
188
+ " 0.24255681, 0.15523338, 0.01986483, -0.23465055, -0.02495009,\n",
189
+ " 0.03532511, 0.0018872 , -0.07421692, -0.18519297, -0.0925442 ,\n",
190
+ " -0.18334997, -0.19211994, -0.07095522, 0.08980912, 0.09272885,\n",
191
+ " -0.00154805, -0.11791486, 0.3486139 , -0.21823978, 0.01764041]])"
192
+ ]
193
+ },
194
+ "execution_count": 17,
195
+ "metadata": {},
196
+ "output_type": "execute_result"
197
+ }
198
+ ],
199
+ "source": [
200
+ "from gradio_client import Client\n",
201
+ "import json\n",
202
+ "import numpy as np\n",
203
+ "\n",
204
+ "# 1. Connect to the Gradio Space\n",
205
+ "# Uses the same endpoint as your Flask app\n",
206
+ "client = Client(\"RobroKools/CellDreamer-API\")\n",
207
+ "\n",
208
+ "result_a = client.predict(\n",
209
+ " input_data={\"genes\": list(np.random.rand(2446)), \"steps\": 10} # Sending as list to be safe\n",
210
+ ")\n",
211
+ "\n",
212
+ "result_b = client.predict(\n",
213
+ " input_data={\"genes\": list(np.random.rand(2446)), \"steps\": 10}\n",
214
+ ")\n",
215
+ "\n",
216
+ "np.array(result_a[\"trajectory\"]) - np.array(result_b[\"trajectory\"])"
217
+ ]
218
+ }
219
+ ],
220
+ "metadata": {
221
+ "kernelspec": {
222
+ "display_name": "celldreamer",
223
+ "language": "python",
224
+ "name": "python3"
225
+ },
226
+ "language_info": {
227
+ "codemirror_mode": {
228
+ "name": "ipython",
229
+ "version": 3
230
+ },
231
+ "file_extension": ".py",
232
+ "mimetype": "text/x-python",
233
+ "name": "python",
234
+ "nbconvert_exporter": "python",
235
+ "pygments_lexer": "ipython3",
236
+ "version": "3.10.19"
237
+ }
238
+ },
239
+ "nbformat": 4,
240
+ "nbformat_minor": 5
241
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ numpy<2.0
4
+ python-box
5
+ pyyaml
6
+ pandas
7
+ scipy
8
+ scanpy