Umair Khan commited on
Commit
6f0c8ad
·
1 Parent(s): 3e136de

add more robust error checking

Browse files
Files changed (1) hide show
  1. app.py +119 -47
app.py CHANGED
@@ -16,6 +16,7 @@ import numpy as np
16
  import scanpy as sc
17
  import pyarrow as pa
18
  import pyarrow.parquet as pq
 
19
  from pathlib import Path
20
  from composer import Trainer, Callback
21
  from tahoex.model.model import ComposerTX
@@ -45,6 +46,10 @@ with open("./symbol-to-ensembl.json", "r") as f:
45
  PARQUET_INDEX_COL = "index"
46
  PARQUET_EMB_COL = "tx1-70m"
47
 
 
 
 
 
48
  # helper to read AnnData header
49
  def read_anndata_header(fileobj):
50
  adata = sc.read_h5ad(fileobj.name, backed="r")
@@ -105,21 +110,130 @@ def ensure_dropdowns(fileobj):
105
  return (
106
  gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
107
  gr.Dropdown(choices=[], value=None),
108
- gr.Dropdown(choices=[], value=None),
109
  )
110
  try:
111
- layers, var_cols, obs_cols = read_anndata_header(fileobj)
112
  return (
113
  gr.Dropdown(choices=["<use .X>"] + layers, value="<use .X>"),
114
  gr.Dropdown(choices=var_cols, value=(var_cols[0] if var_cols else None)),
115
- gr.Dropdown(choices=obs_cols, value=None),
116
  )
117
  except Exception:
118
  return (
119
  gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
120
  gr.Dropdown(choices=[], value=None),
121
- gr.Dropdown(choices=[], value=None),
122
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  # custom callback to report progress to Gradio
125
  class GradioProgressCallback(Callback):
@@ -337,47 +451,6 @@ def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
337
  # return embeddings and metadata
338
  return cell_array, layers, var_cols, obs_cols, adata_persisted
339
 
340
- # recolor UMAP given obs column
341
- def recolor_umap(obs_col, coords, h5ad_path):
342
-
343
- # make sure inputs are valid
344
- if not obs_col:
345
- raise gr.Error("Pick a .obs column to color by.")
346
- if coords is None or h5ad_path is None:
347
- raise gr.Error("Compute embeddings before selecting a .obs column.")
348
-
349
- # read obs column
350
- adata = sc.read_h5ad(h5ad_path, backed="r")
351
- if obs_col not in adata.obs.columns:
352
- raise gr.Error(f"Column '{obs_col}' not found in .obs.")
353
-
354
- # construct plot
355
- color_series = adata.obs[obs_col]
356
- import matplotlib.pyplot as plt
357
- fig = plt.figure(figsize=(5.5, 5.0))
358
- ax = fig.add_subplot(111)
359
- if pd.api.types.is_numeric_dtype(color_series):
360
- scatt = ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.85, c=color_series.values)
361
- fig.colorbar(scatt, ax=ax, shrink=0.7, label=obs_col)
362
- ax.set_title(f"UMAP colored by {obs_col}")
363
- else:
364
- cs = color_series.astype(str).values
365
- for cat in sorted(pd.unique(cs)):
366
- mask = (cs == cat)
367
- ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, label=str(cat))
368
- ax.legend(markerscale=3, fontsize=8, loc="best", frameon=True)
369
- ax.set_title(f"UMAP colored by {obs_col}")
370
-
371
- # finalize and save plot
372
- ax.set_xlabel("UMAP1")
373
- ax.set_ylabel("UMAP2")
374
- fig.tight_layout()
375
- out_png = _unique_output("umap.png")
376
- fig.savefig(out_png, dpi=160)
377
- plt.close(fig)
378
- return str(out_png.resolve())
379
-
380
-
381
  # processing pipeline given user inputs
382
  def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Progress(track_tqdm=False)):
383
 
@@ -415,7 +488,6 @@ def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Pro
415
 
416
  # plot UMAP (no coloring by default)
417
  progress(0.90, desc="plotting UMAP")
418
- import matplotlib.pyplot as plt
419
  fig = plt.figure(figsize=(5.5, 5.0))
420
  ax = fig.add_subplot(111)
421
  ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
@@ -428,7 +500,7 @@ def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Pro
428
  plt.close(fig)
429
 
430
  # enable coloring dropdown
431
- update_obs_dd = gr.Dropdown(choices=obs_cols, value=None, interactive=True)
432
 
433
  # save other outputs and return paths
434
  progress(0.95, desc="saving outputs")
 
16
  import scanpy as sc
17
  import pyarrow as pa
18
  import pyarrow.parquet as pq
19
+ import matplotlib.pyplot as plt
20
  from pathlib import Path
21
  from composer import Trainer, Callback
22
  from tahoex.model.model import ComposerTX
 
46
  PARQUET_INDEX_COL = "index"
47
  PARQUET_EMB_COL = "tx1-70m"
48
 
49
+ # constants for UMAP recoloring
50
+ OBS_NONE_OPTION = "(none)"
51
+ MAX_CATEGORIES = 50
52
+
53
  # helper to read AnnData header
54
  def read_anndata_header(fileobj):
55
  adata = sc.read_h5ad(fileobj.name, backed="r")
 
110
  return (
111
  gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
112
  gr.Dropdown(choices=[], value=None),
 
113
  )
114
  try:
115
+ layers, var_cols, _ = read_anndata_header(fileobj)
116
  return (
117
  gr.Dropdown(choices=["<use .X>"] + layers, value="<use .X>"),
118
  gr.Dropdown(choices=var_cols, value=(var_cols[0] if var_cols else None)),
 
119
  )
120
  except Exception:
121
  return (
122
  gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
123
  gr.Dropdown(choices=[], value=None),
 
124
  )
125
+
126
+ # draw an uncolored UMAP
127
+ def draw_uncolored(coords, title_suffix=None):
128
+ fig = plt.figure(figsize=(5.5, 5.0))
129
+ ax = fig.add_subplot(111)
130
+ ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
131
+ ttl = "UMAP of Tx1-70M embeddings"
132
+ if title_suffix:
133
+ ttl += f" ({title_suffix})"
134
+ ax.set_title(ttl)
135
+ ax.set_xlabel("UMAP1")
136
+ ax.set_ylabel("UMAP2")
137
+ fig.tight_layout()
138
+ out_png = _unique_output("umap.png")
139
+ fig.savefig(out_png, dpi=160)
140
+ plt.close(fig)
141
+ return out_png
142
+
143
+ # recolor UMAP given obs column
144
+ def recolor_umap(obs_col, coords, h5ad_path):
145
+
146
+ # make sure inputs are valid
147
+ if coords is None or h5ad_path is None:
148
+ raise gr.Error("Run embeddings first to compute UMAP.")
149
+ coords = np.asarray(coords)
150
+ if coords.ndim != 2 or coords.shape[1] != 2:
151
+ raise gr.Error(f"UMAP coordinates look wrong, shape = {coords.shape}. Please recompute.")
152
+
153
+ # handle no-coloring option
154
+ if obs_col == OBS_NONE_OPTION:
155
+ out_png = draw_uncolored(coords)
156
+ return str(out_png.resolve())
157
+
158
+ # read obs column
159
+ try:
160
+ adata = sc.read_h5ad(h5ad_path, backed="r")
161
+ series = adata.obs[obs_col]
162
+ n = series.shape[0]
163
+ if n != coords.shape[0]:
164
+ gr.Warning(f"Length mismatch: obs has {n} rows, UMAP has {coords.shape[0]}. Using minimum length.")
165
+ m = min(n, coords.shape[0])
166
+ series = series.iloc[:m]
167
+ coords = coords[:m]
168
+ except Exception as e:
169
+ raise gr.Error(f"Failed to read .obs column '{obs_col}': {e}")
170
+
171
+ # sanitize values
172
+ s = series.copy()
173
+ numeric_candidate = pd.to_numeric(s, errors="coerce")
174
+ n_numeric_valid = int(np.isfinite(numeric_candidate.astype(float)).sum())
175
+ n_total = int(len(s))
176
+
177
+ # plot numerical labels
178
+ if n_numeric_valid >= max(5, 0.5 * n_total):
179
+
180
+ # check for sufficient numeric values or constant values
181
+ vals = pd.to_numeric(s, errors="coerce").astype(float).values
182
+ mask = np.isfinite(vals)
183
+ if mask.sum() < max(10, 0.1 * len(vals)):
184
+ gr.Warning(f"Too few finite numeric values in '{obs_col}'. Showing uncolored UMAP.")
185
+ return draw_uncolored(f"{obs_col}: insufficient numeric values")
186
+ if np.nanmax(vals[mask]) == np.nanmin(vals[mask]):
187
+ gr.Info(f"'{obs_col}' is constant. Showing uncolored UMAP.")
188
+ return draw_uncolored(f"{obs_col}: constant")
189
+
190
+ # draw with colorbar
191
+ fig = plt.figure(figsize=(5.5, 5.0))
192
+ ax = fig.add_subplot(111)
193
+ scatt = ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, c=vals[mask])
194
+ fig.colorbar(scatt, ax=ax, shrink=0.7, label=obs_col)
195
+ if (~mask).any():
196
+ ax.scatter(coords[~mask, 0], coords[~mask, 1], s=3, alpha=0.25)
197
+ ax.set_title(f"UMAP colored by {obs_col}")
198
+ ax.set_xlabel("UMAP1")
199
+ ax.set_ylabel("UMAP2")
200
+ fig.tight_layout()
201
+ out_png = _unique_output("umap.png")
202
+ fig.savefig(out_png, dpi=160)
203
+ plt.close(fig)
204
+ return str(out_png.resolve())
205
+
206
+ # categorical coloring
207
+ else:
208
+
209
+ # check for too many or too few categories
210
+ cats = s.astype(str).fillna("NA").values
211
+ uniq = pd.unique(cats)
212
+ n_cat = len(uniq)
213
+ if n_cat > MAX_CATEGORIES:
214
+ gr.Warning(f"'{obs_col}' has too many categories. Showing uncolored UMAP.")
215
+ out_png = draw_uncolored(coords, f"{obs_col}: {n_cat} categories")
216
+ return str(out_png.resolve())
217
+ if n_cat <= 1:
218
+ gr.Info(f"'{obs_col}' has a single category. Showing uncolored UMAP.")
219
+ out_png = draw_uncolored(f"{obs_col}: 1 category")
220
+ return str(out_png.resolve())
221
+
222
+ # draw with legend
223
+ fig = plt.figure(figsize=(5.5, 5.0))
224
+ ax = fig.add_subplot(111)
225
+ for cat in sorted(map(str, uniq)):
226
+ mask = (cats == cat)
227
+ ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, label=cat)
228
+ ax.legend(markerscale=3, fontsize=8, loc="best", frameon=True, ncol=1)
229
+ ax.set_title(f"UMAP colored by {obs_col}")
230
+ ax.set_xlabel("UMAP1")
231
+ ax.set_ylabel("UMAP2")
232
+ fig.tight_layout()
233
+ out_png = _unique_output("umap.png")
234
+ fig.savefig(out_png, dpi=160)
235
+ plt.close(fig)
236
+ return str(out_png.resolve())
237
 
238
  # custom callback to report progress to Gradio
239
  class GradioProgressCallback(Callback):
 
451
  # return embeddings and metadata
452
  return cell_array, layers, var_cols, obs_cols, adata_persisted
453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  # processing pipeline given user inputs
455
  def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Progress(track_tqdm=False)):
456
 
 
488
 
489
  # plot UMAP (no coloring by default)
490
  progress(0.90, desc="plotting UMAP")
 
491
  fig = plt.figure(figsize=(5.5, 5.0))
492
  ax = fig.add_subplot(111)
493
  ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
 
500
  plt.close(fig)
501
 
502
  # enable coloring dropdown
503
+ update_obs_dd = gr.Dropdown(choices=[OBS_NONE_OPTION] + obs_cols, value=OBS_NONE_OPTION, interactive=True)
504
 
505
  # save other outputs and return paths
506
  progress(0.95, desc="saving outputs")