Umair Khan commited on
Commit
529341f
·
1 Parent(s): 8482f48

first pass of app

Browse files
Files changed (1) hide show
  1. app.py +293 -16
app.py CHANGED
@@ -1,26 +1,303 @@
1
- # import spaces
2
  import spaces
3
 
4
- # wrap package installation in decoration
5
  @spaces.GPU
6
  def install_custom():
7
- import os
8
- os.system("pip install --no-deps ./mosaicfm-0.1.2-py3-none-any.whl")
9
-
10
- # install custom package(s)
11
  install_custom()
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  import gradio as gr
14
- import torch
15
- import mosaicfm
 
 
16
 
17
- zero = torch.Tensor([0]).cuda()
18
- print(zero.device) # <-- 'cpu' 🤔
 
 
 
 
 
 
 
 
19
 
20
- @spaces.GPU
21
- def greet(n):
22
- print(zero.device) # <-- 'cuda:0' 🤗
23
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
26
- demo.launch()
 
1
+ # custom package installation
2
  import spaces
3
 
 
4
  @spaces.GPU
5
  def install_custom():
6
+ import os
7
+ os.system("pip install --no-deps ./mosaicfm-0.1.2-py3-none-any.whl")
8
+
 
9
  install_custom()
10
 
11
+ # app.py
12
+ # ZeroGPU-friendly Gradio Space for MosaicFM-70M embeddings + UMAP
13
+ # - Upload .h5ad
14
+ # - Compute embeddings via mosaicfm (GPU)
15
+ # - Show UMAP
16
+ # - Download embeddings (.parquet) + adata with obsm["X_mosaicfm_70m"] (.h5ad)
17
+
18
+ from __future__ import annotations
19
+
20
+ import gc
21
+ import io
22
+ import os
23
+ import tempfile
24
+ from pathlib import Path
25
+ from typing import Optional, Tuple
26
+
27
  import gradio as gr
28
+ import anndata as ad
29
+ import pandas as pd
30
+ import numpy as np
31
+ import scanpy as sc
32
 
33
+ # -----------------------------
34
+ # Config
35
+ # -----------------------------
36
+ EMB_KEY = "X_mosaicfm_70m"
37
+ DEFAULT_BATCH_SIZE = 64
38
+ APP_TITLE = "MosaicFM-70M Embeddings"
39
+ APP_DESC = """
40
+ Upload an `.h5ad` (AnnData), compute MosaicFM-70M embeddings (on GPU via ZeroGPU),
41
+ preview a UMAP, and download the results.
42
+ """
43
 
44
+ # If your wheel expects an environment variable for model path, set it here:
45
+ # os.environ.setdefault("MOSAICFM_MODEL_DIR", "/home/user/app/model-70m") # example
46
+
47
+ # -----------------------------
48
+ # Lightweight helpers (CPU)
49
+ # -----------------------------
50
+
51
+ def read_anndata_header(fileobj) -> Tuple[list[str], list[str]]:
52
+ """Return (layers, obs_columns) without doing heavy work."""
53
+ adata = sc.read_h5ad(fileobj.name, backed=None)
54
+ layers = ["<use .X>"] + list(adata.layers.keys())
55
+ obs_cols = list(adata.obs.columns)
56
+ del adata
57
+ gc.collect()
58
+ return layers, obs_cols
59
+
60
+
61
+ def _pick_layer(adata: ad.AnnData, layer_name: Optional[str]) -> np.ndarray:
62
+ X = adata.layers[layer_name] if layer_name else adata.X
63
+ if not isinstance(X, np.ndarray):
64
+ X = X.toarray()
65
+ return X
66
+
67
+
68
+ def _compute_umap_from_emb(emb: np.ndarray, color: Optional[pd.Series] = None) -> Tuple[np.ndarray, Optional[pd.Series]]:
69
+ """Compute UMAP (CPU, via scanpy) given embeddings (cells x d)."""
70
+ ad_umap = ad.AnnData(X=emb)
71
+ sc.pp.neighbors(ad_umap, use_rep=None, n_neighbors=15)
72
+ sc.tl.umap(ad_umap, min_dist=0.4)
73
+ coords = np.asarray(ad_umap.obsm["X_umap"])
74
+ # Return color Series (unaltered) for plotting
75
+ del ad_umap
76
+ return coords, color
77
+
78
+
79
+ def _save_outputs(adata: ad.AnnData, E: np.ndarray) -> Tuple[str, str]:
80
+ """Save embeddings parquet and the .h5ad (with obsm set)."""
81
+ tmpdir = Path(tempfile.mkdtemp())
82
+ # embeddings parquet
83
+ emb_df = pd.DataFrame(E, index=adata.obs_names)
84
+ parquet_path = tmpdir / "mosaicfm70m_embeddings.parquet"
85
+ emb_df.to_parquet(parquet_path)
86
+ # adata with obsm
87
+ out_h5ad = tmpdir / "adata_with_mosaicfm70m.h5ad"
88
+ adata.write(out_h5ad, compression="gzip")
89
+ return str(parquet_path), str(out_h5ad)
90
+
91
+ # -----------------------------
92
+ # GPU-bound work
93
+ # -----------------------------
94
+
95
+ @spaces.GPU # ZeroGPU will spin up a GPU for this call
96
+ def _gpu_embed(adata_bytes: bytes, layer_name: Optional[str], batch_size: int) -> Tuple[np.ndarray, list[str], list[str]]:
97
+ """
98
+ Runs on GPU. We read adata from bytes inside the GPU context so that any
99
+ preprocessing that could leverage torch stays here if needed.
100
+ Returns (embeddings, layers_list, obs_cols_list).
101
+ """
102
+ # Lazy imports inside GPU scope
103
+ import torch
104
+
105
+ # Import mosaicfm lazily; if unavailable or no helper, fallback to PCA
106
+ try:
107
+ import mosaicfm # noqa: F401
108
+ except Exception as e:
109
+ mosaicfm = None # fallback path below
110
+
111
+ # Rehydrate AnnData from bytes
112
+ with tempfile.TemporaryDirectory() as td:
113
+ fpath = Path(td) / "input.h5ad"
114
+ with open(fpath, "wb") as f:
115
+ f.write(adata_bytes)
116
+ adata = sc.read_h5ad(str(fpath), backed=None)
117
+
118
+ # Validate layer
119
+ if layer_name and layer_name not in adata.layers:
120
+ raise gr.Error(f"Layer '{layer_name}' not found. Available: {list(adata.layers.keys())}")
121
+
122
+ # Try calling a helper from your package if it exists.
123
+ # Adjust the import path/names to your wheel's API if needed.
124
+ E = None
125
+ used_helper = False
126
+ if mosaicfm is not None:
127
+ # Try a few likely helper locations
128
+ helpers = [
129
+ "mosaicfm.tasks.embeddings.embed_anndata",
130
+ "mosaicfm.inference.embed_anndata",
131
+ "mosaicfm.embed_anndata",
132
+ ]
133
+ for dotted in helpers:
134
+ try:
135
+ mod_path, fn_name = dotted.rsplit(".", 1)
136
+ mod = __import__(mod_path, fromlist=[fn_name])
137
+ embed_fn = getattr(mod, fn_name)
138
+ # Expected signature: (adata, layer=None, batch_size=..., device="cuda", out_key=None) -> np.ndarray or writes to adata
139
+ device = "cuda" if torch.cuda.is_available() else "cpu"
140
+ result = embed_fn(
141
+ adata=adata,
142
+ layer=(None if (layer_name in [None, "", "<use .X>"]) else layer_name),
143
+ batch_size=int(batch_size),
144
+ device=device,
145
+ out_key=EMB_KEY,
146
+ )
147
+ if isinstance(result, np.ndarray):
148
+ E = result
149
+ else:
150
+ # If helper writes into adata.obsm[EMB_KEY]
151
+ E = np.asarray(adata.obsm.get(EMB_KEY))
152
+ used_helper = True
153
+ break
154
+ except Exception:
155
+ continue
156
+
157
+ # Fallback: simple PCA to keep the app functional even without a helper
158
+ if E is None:
159
+ X = _pick_layer(adata, None if (layer_name in [None, "", "<use .X>"]) else layer_name)
160
+ from sklearn.decomposition import PCA
161
+ n_comp = min(50, X.shape[1])
162
+ E = PCA(n_components=n_comp).fit_transform(X)
163
+
164
+ # Attach to adata and return bytes for CPU-side saving/UMAP
165
+ adata.obsm[EMB_KEY] = E
166
+
167
+ # Hand back small metadata for UI refresh
168
+ layers = list(adata.layers.keys())
169
+ obs_cols = list(adata.obs.columns)
170
+
171
+ # Serialize adata (with embeddings) back to bytes for CPU side
172
+ with tempfile.TemporaryDirectory() as td2:
173
+ outp = Path(td2) / "tmp.h5ad"
174
+ adata.write(outp, compression="gzip")
175
+ with open(outp, "rb") as f:
176
+ adata_persisted = f.read()
177
+
178
+ # Free
179
+ del adata
180
+ torch.cuda.empty_cache()
181
+ gc.collect()
182
+
183
+ # Return embeddings plus metadata; caller will rebuild AnnData from bytes when needed
184
+ return E, layers, obs_cols, adata_persisted
185
+
186
+
187
+ # -----------------------------
188
+ # Orchestration (CPU + GPU)
189
+ # -----------------------------
190
+
191
+ def run_pipeline(fileobj, layer_choice, color_key, batch_size):
192
+ """
193
+ CPU entrypoint invoked by Gradio:
194
+ - reads file to bytes
195
+ - calls GPU embedding
196
+ - rebuilds AnnData (with obsm) on CPU
197
+ - computes UMAP
198
+ - saves outputs and returns UI payloads
199
+ """
200
+ if fileobj is None:
201
+ raise gr.Error("Please upload an .h5ad file.")
202
+
203
+ # Read upload to bytes so the GPU function can load it in its context
204
+ with open(fileobj.name, "rb") as f:
205
+ adata_bytes = f.read()
206
+
207
+ # Compute embeddings on GPU
208
+ E, layers, obs_cols, adata_with_emb_bytes = _gpu_embed(
209
+ adata_bytes=adata_bytes,
210
+ layer_name=(None if layer_choice in [None, "", "<use .X>"] else layer_choice),
211
+ batch_size=int(batch_size),
212
+ )
213
+
214
+ # Rebuild AnnData (with obsm) on CPU
215
+ with tempfile.TemporaryDirectory() as td:
216
+ tmp_in = Path(td) / "with_emb.h5ad"
217
+ with open(tmp_in, "wb") as f:
218
+ f.write(adata_with_emb_bytes)
219
+ adata = sc.read_h5ad(tmp_in, backed=None)
220
+
221
+ # Compute UMAP on CPU
222
+ color_series = adata.obs[color_key] if (color_key and color_key in adata.obs) else None
223
+ coords, color_series = _compute_umap_from_emb(E, color_series)
224
+
225
+ # Make UMAP figure
226
+ import matplotlib.pyplot as plt
227
+ fig = plt.figure(figsize=(5.5, 5.0))
228
+ ax = fig.add_subplot(111)
229
+
230
+ if color_series is None:
231
+ ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
232
+ ax.set_title("UMAP of MosaicFM-70M embeddings")
233
+ else:
234
+ if pd.api.types.is_numeric_dtype(color_series):
235
+ scatt = ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.85, c=color_series.values)
236
+ fig.colorbar(scatt, ax=ax, shrink=0.7, label=color_key)
237
+ ax.set_title(f"UMAP colored by {color_key}")
238
+ else:
239
+ for cat in sorted(color_series.astype(str).unique()):
240
+ mask = (color_series.astype(str).values == str(cat))
241
+ ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, label=str(cat))
242
+ ax.legend(markerscale=3, fontsize=8, loc="best", frameon=True)
243
+ ax.set_title(f"UMAP colored by {color_key}")
244
+
245
+ ax.set_xlabel("UMAP1")
246
+ ax.set_ylabel("UMAP2")
247
+ fig.tight_layout()
248
+
249
+ tmpdir = Path(tempfile.mkdtemp())
250
+ umap_png = tmpdir / "umap.png"
251
+ fig.savefig(umap_png, dpi=160)
252
+ plt.close(fig)
253
+
254
+ # Save outputs
255
+ parquet_path, h5ad_path = _save_outputs(adata, E)
256
+
257
+ # Return outputs + refreshed dropdown choices (layers/obs)
258
+ return str(umap_png), parquet_path, h5ad_path, ["<use .X>"] + layers, obs_cols
259
+
260
+
261
+ def refresh_after_upload(fileobj):
262
+ if fileobj is None:
263
+ return gr.Dropdown(choices=["<use .X>"], value="<use .X>"), gr.Dropdown(choices=[], value=None)
264
+ try:
265
+ layers, obs_cols = read_anndata_header(fileobj)
266
+ return gr.Dropdown(choices=layers, value=layers[0]), gr.Dropdown(choices=obs_cols, value=None)
267
+ except Exception:
268
+ return gr.Dropdown(choices=["<use .X>"], value="<use .X>"), gr.Dropdown(choices=[], value=None)
269
+
270
+ # -----------------------------
271
+ # UI
272
+ # -----------------------------
273
+
274
+ with gr.Blocks(title=APP_TITLE) as demo:
275
+ gr.Markdown(f"# {APP_TITLE}\n{APP_DESC}")
276
+
277
+ with gr.Row():
278
+ f_in = gr.File(label="Upload .h5ad", file_types=[".h5ad"], type="file")
279
+ batch = gr.Number(value=DEFAULT_BATCH_SIZE, precision=0, label="Batch size")
280
+
281
+ with gr.Row():
282
+ layer_dd = gr.Dropdown(choices=["<use .X>"], value="<use .X>", label="Layer (optional)")
283
+ color_dd = gr.Dropdown(choices=[], value=None, label="UMAP color (obs column, optional)")
284
+
285
+ run_btn = gr.Button("Compute Embeddings + UMAP", variant="primary")
286
+
287
+ with gr.Row():
288
+ umap_img = gr.Image(label="UMAP preview", interactive=False)
289
+
290
+ with gr.Row():
291
+ emb_parquet = gr.File(label="Download embeddings (.parquet)")
292
+ adata_with_emb = gr.File(label="Download AnnData with obsm['X_mosaicfm_70m'] (.h5ad)")
293
+
294
+ # Wire events
295
+ f_in.change(refresh_after_upload, inputs=[f_in], outputs=[layer_dd, color_dd], queue=False)
296
+ run_btn.click(
297
+ run_pipeline,
298
+ inputs=[f_in, layer_dd, color_dd, batch],
299
+ outputs=[umap_img, emb_parquet, adata_with_emb, layer_dd, color_dd],
300
+ )
301
 
302
+ if __name__ == "__main__":
303
+ demo.launch()