Umair Khan
commited on
Commit
·
6f0c8ad
1
Parent(s):
3e136de
add more robust error checking
Browse files
app.py
CHANGED
|
@@ -16,6 +16,7 @@ import numpy as np
|
|
| 16 |
import scanpy as sc
|
| 17 |
import pyarrow as pa
|
| 18 |
import pyarrow.parquet as pq
|
|
|
|
| 19 |
from pathlib import Path
|
| 20 |
from composer import Trainer, Callback
|
| 21 |
from tahoex.model.model import ComposerTX
|
|
@@ -45,6 +46,10 @@ with open("./symbol-to-ensembl.json", "r") as f:
|
|
| 45 |
PARQUET_INDEX_COL = "index"
|
| 46 |
PARQUET_EMB_COL = "tx1-70m"
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# helper to read AnnData header
|
| 49 |
def read_anndata_header(fileobj):
|
| 50 |
adata = sc.read_h5ad(fileobj.name, backed="r")
|
|
@@ -105,21 +110,130 @@ def ensure_dropdowns(fileobj):
|
|
| 105 |
return (
|
| 106 |
gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
|
| 107 |
gr.Dropdown(choices=[], value=None),
|
| 108 |
-
gr.Dropdown(choices=[], value=None),
|
| 109 |
)
|
| 110 |
try:
|
| 111 |
-
layers, var_cols,
|
| 112 |
return (
|
| 113 |
gr.Dropdown(choices=["<use .X>"] + layers, value="<use .X>"),
|
| 114 |
gr.Dropdown(choices=var_cols, value=(var_cols[0] if var_cols else None)),
|
| 115 |
-
gr.Dropdown(choices=obs_cols, value=None),
|
| 116 |
)
|
| 117 |
except Exception:
|
| 118 |
return (
|
| 119 |
gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
|
| 120 |
gr.Dropdown(choices=[], value=None),
|
| 121 |
-
gr.Dropdown(choices=[], value=None),
|
| 122 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
# custom callback to report progress to Gradio
|
| 125 |
class GradioProgressCallback(Callback):
|
|
@@ -337,47 +451,6 @@ def _embed(adata_bytes, layer_name, feature_col, use_symbols, progress):
|
|
| 337 |
# return embeddings and metadata
|
| 338 |
return cell_array, layers, var_cols, obs_cols, adata_persisted
|
| 339 |
|
| 340 |
-
# recolor UMAP given obs column
|
| 341 |
-
def recolor_umap(obs_col, coords, h5ad_path):
|
| 342 |
-
|
| 343 |
-
# make sure inputs are valid
|
| 344 |
-
if not obs_col:
|
| 345 |
-
raise gr.Error("Pick a .obs column to color by.")
|
| 346 |
-
if coords is None or h5ad_path is None:
|
| 347 |
-
raise gr.Error("Compute embeddings before selecting a .obs column.")
|
| 348 |
-
|
| 349 |
-
# read obs column
|
| 350 |
-
adata = sc.read_h5ad(h5ad_path, backed="r")
|
| 351 |
-
if obs_col not in adata.obs.columns:
|
| 352 |
-
raise gr.Error(f"Column '{obs_col}' not found in .obs.")
|
| 353 |
-
|
| 354 |
-
# construct plot
|
| 355 |
-
color_series = adata.obs[obs_col]
|
| 356 |
-
import matplotlib.pyplot as plt
|
| 357 |
-
fig = plt.figure(figsize=(5.5, 5.0))
|
| 358 |
-
ax = fig.add_subplot(111)
|
| 359 |
-
if pd.api.types.is_numeric_dtype(color_series):
|
| 360 |
-
scatt = ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.85, c=color_series.values)
|
| 361 |
-
fig.colorbar(scatt, ax=ax, shrink=0.7, label=obs_col)
|
| 362 |
-
ax.set_title(f"UMAP colored by {obs_col}")
|
| 363 |
-
else:
|
| 364 |
-
cs = color_series.astype(str).values
|
| 365 |
-
for cat in sorted(pd.unique(cs)):
|
| 366 |
-
mask = (cs == cat)
|
| 367 |
-
ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, label=str(cat))
|
| 368 |
-
ax.legend(markerscale=3, fontsize=8, loc="best", frameon=True)
|
| 369 |
-
ax.set_title(f"UMAP colored by {obs_col}")
|
| 370 |
-
|
| 371 |
-
# finalize and save plot
|
| 372 |
-
ax.set_xlabel("UMAP1")
|
| 373 |
-
ax.set_ylabel("UMAP2")
|
| 374 |
-
fig.tight_layout()
|
| 375 |
-
out_png = _unique_output("umap.png")
|
| 376 |
-
fig.savefig(out_png, dpi=160)
|
| 377 |
-
plt.close(fig)
|
| 378 |
-
return str(out_png.resolve())
|
| 379 |
-
|
| 380 |
-
|
| 381 |
# processing pipeline given user inputs
|
| 382 |
def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Progress(track_tqdm=False)):
|
| 383 |
|
|
@@ -415,7 +488,6 @@ def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Pro
|
|
| 415 |
|
| 416 |
# plot UMAP (no coloring by default)
|
| 417 |
progress(0.90, desc="plotting UMAP")
|
| 418 |
-
import matplotlib.pyplot as plt
|
| 419 |
fig = plt.figure(figsize=(5.5, 5.0))
|
| 420 |
ax = fig.add_subplot(111)
|
| 421 |
ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
|
|
@@ -428,7 +500,7 @@ def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Pro
|
|
| 428 |
plt.close(fig)
|
| 429 |
|
| 430 |
# enable coloring dropdown
|
| 431 |
-
update_obs_dd = gr.Dropdown(choices=obs_cols, value=
|
| 432 |
|
| 433 |
# save other outputs and return paths
|
| 434 |
progress(0.95, desc="saving outputs")
|
|
|
|
| 16 |
import scanpy as sc
|
| 17 |
import pyarrow as pa
|
| 18 |
import pyarrow.parquet as pq
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
from pathlib import Path
|
| 21 |
from composer import Trainer, Callback
|
| 22 |
from tahoex.model.model import ComposerTX
|
|
|
|
| 46 |
PARQUET_INDEX_COL = "index"
|
| 47 |
PARQUET_EMB_COL = "tx1-70m"
|
| 48 |
|
| 49 |
+
# constants for UMAP recoloring
|
| 50 |
+
OBS_NONE_OPTION = "(none)"
|
| 51 |
+
MAX_CATEGORIES = 50
|
| 52 |
+
|
| 53 |
# helper to read AnnData header
|
| 54 |
def read_anndata_header(fileobj):
|
| 55 |
adata = sc.read_h5ad(fileobj.name, backed="r")
|
|
|
|
| 110 |
return (
|
| 111 |
gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
|
| 112 |
gr.Dropdown(choices=[], value=None),
|
|
|
|
| 113 |
)
|
| 114 |
try:
|
| 115 |
+
layers, var_cols, _ = read_anndata_header(fileobj)
|
| 116 |
return (
|
| 117 |
gr.Dropdown(choices=["<use .X>"] + layers, value="<use .X>"),
|
| 118 |
gr.Dropdown(choices=var_cols, value=(var_cols[0] if var_cols else None)),
|
|
|
|
| 119 |
)
|
| 120 |
except Exception:
|
| 121 |
return (
|
| 122 |
gr.Dropdown(choices=["<use .X>"], value="<use .X>"),
|
| 123 |
gr.Dropdown(choices=[], value=None),
|
|
|
|
| 124 |
)
|
| 125 |
+
|
| 126 |
+
# draw an uncolored UMAP
|
| 127 |
+
def draw_uncolored(coords, title_suffix=None):
|
| 128 |
+
fig = plt.figure(figsize=(5.5, 5.0))
|
| 129 |
+
ax = fig.add_subplot(111)
|
| 130 |
+
ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
|
| 131 |
+
ttl = "UMAP of Tx1-70M embeddings"
|
| 132 |
+
if title_suffix:
|
| 133 |
+
ttl += f" ({title_suffix})"
|
| 134 |
+
ax.set_title(ttl)
|
| 135 |
+
ax.set_xlabel("UMAP1")
|
| 136 |
+
ax.set_ylabel("UMAP2")
|
| 137 |
+
fig.tight_layout()
|
| 138 |
+
out_png = _unique_output("umap.png")
|
| 139 |
+
fig.savefig(out_png, dpi=160)
|
| 140 |
+
plt.close(fig)
|
| 141 |
+
return out_png
|
| 142 |
+
|
| 143 |
+
# recolor UMAP given obs column
|
| 144 |
+
def recolor_umap(obs_col, coords, h5ad_path):
|
| 145 |
+
|
| 146 |
+
# make sure inputs are valid
|
| 147 |
+
if coords is None or h5ad_path is None:
|
| 148 |
+
raise gr.Error("Run embeddings first to compute UMAP.")
|
| 149 |
+
coords = np.asarray(coords)
|
| 150 |
+
if coords.ndim != 2 or coords.shape[1] != 2:
|
| 151 |
+
raise gr.Error(f"UMAP coordinates look wrong, shape = {coords.shape}. Please recompute.")
|
| 152 |
+
|
| 153 |
+
# handle no-coloring option
|
| 154 |
+
if obs_col == OBS_NONE_OPTION:
|
| 155 |
+
out_png = draw_uncolored(coords)
|
| 156 |
+
return str(out_png.resolve())
|
| 157 |
+
|
| 158 |
+
# read obs column
|
| 159 |
+
try:
|
| 160 |
+
adata = sc.read_h5ad(h5ad_path, backed="r")
|
| 161 |
+
series = adata.obs[obs_col]
|
| 162 |
+
n = series.shape[0]
|
| 163 |
+
if n != coords.shape[0]:
|
| 164 |
+
gr.Warning(f"Length mismatch: obs has {n} rows, UMAP has {coords.shape[0]}. Using minimum length.")
|
| 165 |
+
m = min(n, coords.shape[0])
|
| 166 |
+
series = series.iloc[:m]
|
| 167 |
+
coords = coords[:m]
|
| 168 |
+
except Exception as e:
|
| 169 |
+
raise gr.Error(f"Failed to read .obs column '{obs_col}': {e}")
|
| 170 |
+
|
| 171 |
+
# sanitize values
|
| 172 |
+
s = series.copy()
|
| 173 |
+
numeric_candidate = pd.to_numeric(s, errors="coerce")
|
| 174 |
+
n_numeric_valid = int(np.isfinite(numeric_candidate.astype(float)).sum())
|
| 175 |
+
n_total = int(len(s))
|
| 176 |
+
|
| 177 |
+
# plot numerical labels
|
| 178 |
+
if n_numeric_valid >= max(5, 0.5 * n_total):
|
| 179 |
+
|
| 180 |
+
# check for sufficient numeric values or constant values
|
| 181 |
+
vals = pd.to_numeric(s, errors="coerce").astype(float).values
|
| 182 |
+
mask = np.isfinite(vals)
|
| 183 |
+
if mask.sum() < max(10, 0.1 * len(vals)):
|
| 184 |
+
gr.Warning(f"Too few finite numeric values in '{obs_col}'. Showing uncolored UMAP.")
|
| 185 |
+
return draw_uncolored(f"{obs_col}: insufficient numeric values")
|
| 186 |
+
if np.nanmax(vals[mask]) == np.nanmin(vals[mask]):
|
| 187 |
+
gr.Info(f"'{obs_col}' is constant. Showing uncolored UMAP.")
|
| 188 |
+
return draw_uncolored(f"{obs_col}: constant")
|
| 189 |
+
|
| 190 |
+
# draw with colorbar
|
| 191 |
+
fig = plt.figure(figsize=(5.5, 5.0))
|
| 192 |
+
ax = fig.add_subplot(111)
|
| 193 |
+
scatt = ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, c=vals[mask])
|
| 194 |
+
fig.colorbar(scatt, ax=ax, shrink=0.7, label=obs_col)
|
| 195 |
+
if (~mask).any():
|
| 196 |
+
ax.scatter(coords[~mask, 0], coords[~mask, 1], s=3, alpha=0.25)
|
| 197 |
+
ax.set_title(f"UMAP colored by {obs_col}")
|
| 198 |
+
ax.set_xlabel("UMAP1")
|
| 199 |
+
ax.set_ylabel("UMAP2")
|
| 200 |
+
fig.tight_layout()
|
| 201 |
+
out_png = _unique_output("umap.png")
|
| 202 |
+
fig.savefig(out_png, dpi=160)
|
| 203 |
+
plt.close(fig)
|
| 204 |
+
return str(out_png.resolve())
|
| 205 |
+
|
| 206 |
+
# categorical coloring
|
| 207 |
+
else:
|
| 208 |
+
|
| 209 |
+
# check for too many or too few categories
|
| 210 |
+
cats = s.astype(str).fillna("NA").values
|
| 211 |
+
uniq = pd.unique(cats)
|
| 212 |
+
n_cat = len(uniq)
|
| 213 |
+
if n_cat > MAX_CATEGORIES:
|
| 214 |
+
gr.Warning(f"'{obs_col}' has too many categories. Showing uncolored UMAP.")
|
| 215 |
+
out_png = draw_uncolored(coords, f"{obs_col}: {n_cat} categories")
|
| 216 |
+
return str(out_png.resolve())
|
| 217 |
+
if n_cat <= 1:
|
| 218 |
+
gr.Info(f"'{obs_col}' has a single category. Showing uncolored UMAP.")
|
| 219 |
+
out_png = draw_uncolored(f"{obs_col}: 1 category")
|
| 220 |
+
return str(out_png.resolve())
|
| 221 |
+
|
| 222 |
+
# draw with legend
|
| 223 |
+
fig = plt.figure(figsize=(5.5, 5.0))
|
| 224 |
+
ax = fig.add_subplot(111)
|
| 225 |
+
for cat in sorted(map(str, uniq)):
|
| 226 |
+
mask = (cats == cat)
|
| 227 |
+
ax.scatter(coords[mask, 0], coords[mask, 1], s=3, alpha=0.85, label=cat)
|
| 228 |
+
ax.legend(markerscale=3, fontsize=8, loc="best", frameon=True, ncol=1)
|
| 229 |
+
ax.set_title(f"UMAP colored by {obs_col}")
|
| 230 |
+
ax.set_xlabel("UMAP1")
|
| 231 |
+
ax.set_ylabel("UMAP2")
|
| 232 |
+
fig.tight_layout()
|
| 233 |
+
out_png = _unique_output("umap.png")
|
| 234 |
+
fig.savefig(out_png, dpi=160)
|
| 235 |
+
plt.close(fig)
|
| 236 |
+
return str(out_png.resolve())
|
| 237 |
|
| 238 |
# custom callback to report progress to Gradio
|
| 239 |
class GradioProgressCallback(Callback):
|
|
|
|
| 451 |
# return embeddings and metadata
|
| 452 |
return cell_array, layers, var_cols, obs_cols, adata_persisted
|
| 453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
# processing pipeline given user inputs
|
| 455 |
def run_pipeline(fileobj, layer_choice, var_choice, use_symbols, progress=gr.Progress(track_tqdm=False)):
|
| 456 |
|
|
|
|
| 488 |
|
| 489 |
# plot UMAP (no coloring by default)
|
| 490 |
progress(0.90, desc="plotting UMAP")
|
|
|
|
| 491 |
fig = plt.figure(figsize=(5.5, 5.0))
|
| 492 |
ax = fig.add_subplot(111)
|
| 493 |
ax.scatter(coords[:, 0], coords[:, 1], s=3, alpha=0.75)
|
|
|
|
| 500 |
plt.close(fig)
|
| 501 |
|
| 502 |
# enable coloring dropdown
|
| 503 |
+
update_obs_dd = gr.Dropdown(choices=[OBS_NONE_OPTION] + obs_cols, value=OBS_NONE_OPTION, interactive=True)
|
| 504 |
|
| 505 |
# save other outputs and return paths
|
| 506 |
progress(0.95, desc="saving outputs")
|