kaveh commited on
Commit
0e1a816
·
1 Parent(s): 4b93ba5

updated UI

Browse files
.gitignore CHANGED
@@ -5,7 +5,9 @@ dataset/
5
 
6
  *_pycache*/
7
 
 
8
  test/
9
  S2FApp/ckp/spheroid/ckp_spheroid_FN.pth
10
  S2FApp/ckp/single_cell/ckp_singlecell_GN.pth
11
  S2FApp/ckp/spheroid/ckp_spheroid_GN.pth
 
 
5
 
6
  *_pycache*/
7
 
8
+ S2FApp/.s2f_env/
9
  test/
10
  S2FApp/ckp/spheroid/ckp_spheroid_FN.pth
11
  S2FApp/ckp/single_cell/ckp_singlecell_GN.pth
12
  S2FApp/ckp/spheroid/ckp_spheroid_GN.pth
13
+ S2FApp/*.docx
S2FApp/.gitignore CHANGED
@@ -1,6 +1,7 @@
1
  __pycache__
2
  *.py[cod]
3
- .venv
4
  venv
 
5
  .DS_Store
6
  ckp/*.pth
 
 
1
  __pycache__
2
  *.py[cod]
 
3
  venv
4
+ .venv
5
  .DS_Store
6
  ckp/*.pth
7
+ *.docx
S2FApp/.streamlit/config.toml CHANGED
@@ -1,3 +1,13 @@
1
  [server]
2
  # Required for file uploads on Hugging Face Spaces (iframe blocks XSRF cookies)
3
  enableXsrfProtection = false
 
 
 
 
 
 
 
 
 
 
 
1
  [server]
2
  # Required for file uploads on Hugging Face Spaces (iframe blocks XSRF cookies)
3
  enableXsrfProtection = false
4
+
5
+ [theme]
6
+ primaryColor = "#0d9488"
7
+ backgroundColor = "#ffffff"
8
+ secondaryBackgroundColor = "#f8fafc"
9
+ textColor = "#1e293b"
10
+ font = "sans serif"
11
+
12
+ [logger]
13
+ level = "error"
S2FApp/Dockerfile CHANGED
@@ -13,15 +13,17 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
13
  && rm -rf /var/lib/apt/lists/*
14
 
15
  # Copy requirements first for better caching
16
- COPY requirements.txt .
17
 
18
  # Install Python dependencies - CPU-only PyTorch to fit Space memory limits (avoids OOM)
19
- RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
20
- pip install --no-cache-dir numpy opencv-python scipy scikit-image streamlit matplotlib Pillow plotly huggingface_hub reportlab streamlit-drawable-canvas-fix psutil
 
21
 
22
  # Copy app code (chown for HF Spaces permissions)
23
  COPY --chown=user:user app.py predictor.py download_ckp.py ./
24
  COPY --chown=user:user .streamlit/ .streamlit/
 
25
  COPY --chown=user:user models/ models/
26
  COPY --chown=user:user ui/ ui/
27
  COPY --chown=user:user utils/ utils/
 
13
  && rm -rf /var/lib/apt/lists/*
14
 
15
  # Copy requirements first for better caching
16
+ COPY requirements.txt requirements-docker.txt ./
17
 
18
  # Install Python dependencies - CPU-only PyTorch to fit Space memory limits (avoids OOM)
19
+ # PyTorch 2.2 + torchvision 0.17 (CPU) - match requirements.txt torch>=2.0
20
+ RUN pip install --no-cache-dir torch==2.2.0 torchvision==0.17.0 --index-url https://download.pytorch.org/whl/cpu && \
21
+ pip install --no-cache-dir -r requirements-docker.txt
22
 
23
  # Copy app code (chown for HF Spaces permissions)
24
  COPY --chown=user:user app.py predictor.py download_ckp.py ./
25
  COPY --chown=user:user .streamlit/ .streamlit/
26
+ COPY --chown=user:user static/ static/
27
  COPY --chown=user:user models/ models/
28
  COPY --chown=user:user ui/ ui/
29
  COPY --chown=user:user utils/ utils/
S2FApp/app.py CHANGED
@@ -22,6 +22,7 @@ from config.constants import (
22
  MODEL_TYPE_LABELS,
23
  SAMPLE_EXTENSIONS,
24
  SAMPLE_THUMBNAIL_LIMIT,
 
25
  )
26
  from utils.paths import get_ckp_base, get_ckp_folder, get_sample_folder, list_files_in_folder, model_subfolder
27
  from utils.segmentation import estimate_cell_mask
@@ -40,7 +41,7 @@ from ui.components import (
40
 
41
  CITATION = (
42
  "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
43
- "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
44
  )
45
 
46
  # Measure tool dialog: defined early so it exists before render_result_display uses it
@@ -58,6 +59,7 @@ if HAS_DRAWABLE_CANVAS and ST_DIALOG:
58
  max_percentile=st.session_state.get("measure_max_percentile", 100),
59
  clip_min=st.session_state.get("measure_clip_min", 0),
60
  clip_max=st.session_state.get("measure_clip_max", 1),
 
61
  )
62
  bf_img = st.session_state.get("measure_bf_img")
63
  original_vals = st.session_state.get("measure_original_vals")
@@ -82,7 +84,8 @@ def _get_measure_dialog_fn():
82
 
83
  def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
84
  display_mode, auto_cell_boundary, cell_mask=None,
85
- min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
 
86
  """Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it."""
87
  if cell_mask is None and auto_cell_boundary:
88
  cell_mask = estimate_cell_mask(heatmap)
@@ -92,6 +95,7 @@ def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, col
92
  st.session_state["measure_max_percentile"] = max_percentile
93
  st.session_state["measure_clip_min"] = clip_min
94
  st.session_state["measure_clip_max"] = clip_max
 
95
  st.session_state["measure_bf_img"] = img.copy()
96
  st.session_state["measure_input_filename"] = key_img or "image"
97
  st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
@@ -101,57 +105,27 @@ def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, col
101
  st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
102
 
103
 
104
- st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
105
 
106
- # Theme CSS (inject based on sidebar selection)
107
- def _inject_theme_css(theme):
108
- if theme == "Dark":
109
- st.markdown("""
110
- <style>
111
- .stApp { background-color: #0e1117 !important; }
112
- .stApp header { background-color: #0e1117 !important; }
113
- section[data-testid="stSidebar"] { background-color: #1a1a2e !important; }
114
- section[data-testid="stSidebar"] .stMarkdown { color: #fafafa !important; }
115
- section[data-testid="stSidebar"] [data-testid="stWidgetLabel"] { color: #e2e8f0 !important; }
116
- h1, h2, h3 { color: #fafafa !important; }
117
- p { color: #e2e8f0 !important; }
118
- .stCaption { color: #94a3b8 !important; }
119
- </style>
120
- """, unsafe_allow_html=True)
121
 
 
 
 
 
122
 
123
  st.markdown("""
124
- <style>
125
- section[data-testid="stSidebar"] { width: 380px !important; }
126
- @media (max-width: 768px) {
127
- section[data-testid="stSidebar"] { width: 100% !important; max-width: 100% !important; }
128
- }
129
- section[data-testid="stSidebar"] h2 {
130
- font-size: 1.25rem !important;
131
- font-weight: 600 !important;
132
- }
133
- section[data-testid="stSidebar"] [data-testid="stWidgetLabel"],
134
- section[data-testid="stSidebar"] [data-testid="stWidgetLabel"] p {
135
- font-size: 0.95rem !important;
136
- font-weight: 500 !important;
137
- }
138
- div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div {
139
- flex: 1 1 0 !important; min-width: 0 !important;
140
- }
141
- div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) button {
142
- width: 100% !important; min-width: 100px !important; white-space: nowrap !important;
143
- }
144
- div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button {
145
- background-color: #0d9488 !important; color: white !important; border-color: #0d9488 !important;
146
- }
147
- div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button:hover {
148
- background-color: #0f766e !important; border-color: #0f766e !important; color: white !important;
149
- }
150
- </style>
151
  """, unsafe_allow_html=True)
152
 
153
- st.title("🦠 Shape2Force (S2F)")
154
- st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
 
 
 
155
 
156
  # Folders
157
  ckp_base = get_ckp_base(S2F_ROOT)
@@ -173,9 +147,67 @@ def get_cached_sample_thumbnails(model_type, sample_folder, sample_files):
173
  return cache[cache_key]
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  # Sidebar
177
  with st.sidebar:
178
- st.header("Settings")
 
 
 
 
 
 
179
 
180
  model_type = st.radio(
181
  "Model type",
@@ -193,6 +225,7 @@ with st.sidebar:
193
  checkpoint = st.selectbox(
194
  "Checkpoint",
195
  ckp_files,
 
196
  help=f"Select a .pth file from ckp/{ckp_subfolder_name}/",
197
  )
198
  else:
@@ -232,11 +265,7 @@ with st.sidebar:
232
  except FileNotFoundError:
233
  st.error("config/substrate_settings.json not found")
234
 
235
- auto_cell_boundary = st.toggle(
236
- "Auto boundary",
237
- value=False,
238
- help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
239
- )
240
 
241
  batch_mode = st.toggle(
242
  "Batch mode",
@@ -244,45 +273,73 @@ with st.sidebar:
244
  help=f"Process up to {BATCH_MAX_IMAGES} images at once. Upload multiple files or select multiple examples.",
245
  )
246
 
247
- display_mode = st.radio(
248
- "Heatmap display",
249
- ["Default", "Percentile", "Range"],
250
- horizontal=True,
251
- help="Default: full 0–1 range. Percentile: map a percentile range to improve contrast when few bright pixels dominate. Range: show only values in [min, max]; others hidden (black).",
252
  )
253
- min_percentile, max_percentile = 0, 100
254
- clip_min, clip_max = 0.0, 1.0
255
- if display_mode == "Percentile":
256
- col_pmin, col_pmax = st.columns(2)
257
- with col_pmin:
258
- min_percentile = st.slider("Min percentile", 0, 100, 2, 1, help="Values below this percentile → black")
259
- with col_pmax:
260
- max_percentile = st.slider("Max percentile", 0, 100, 99, 1, help="Values above this percentile → white")
261
- if min_percentile >= max_percentile:
262
- st.warning("Min percentile must be less than max. Using min=0, max=100.")
263
- min_percentile, max_percentile = 0, 100
264
- elif display_mode == "Range":
265
- col_cmin, col_cmax = st.columns(2)
266
- with col_cmin:
267
- clip_min = st.number_input("Min", value=0.0, min_value=0.0, max_value=1.0, step=0.01, format="%.3f",
268
- help="Values below this range → hidden (black)")
269
- with col_cmax:
270
- clip_max = st.number_input("Max", value=1.0, min_value=0.0, max_value=1.0, step=0.01, format="%.3f",
271
- help="Values above this range → hidden (black)")
272
- if clip_min >= clip_max:
273
- st.warning("Min must be less than max. Using min=0, max=1.")
274
- clip_min, clip_max = 0.0, 1.0
275
- colormap_name = st.selectbox(
276
- "Heatmap colormap",
277
- list(COLORMAPS.keys()),
278
- help="Color scheme for the force map. Viridis is often preferred for accessibility.",
279
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
- theme = st.radio("Theme", ["Light", "Dark"], horizontal=True, key="theme_selector")
282
- _inject_theme_css(theme)
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  # Main area: image input
285
- img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
286
  img = None
287
  imgs_batch = [] # list of (img, key_img) for batch mode
288
  uploaded = None
@@ -309,32 +366,7 @@ if batch_mode:
309
  imgs_batch.append((decoded, u.name))
310
  u.seek(0)
311
  else:
312
- sample_folder = get_sample_folder(S2F_ROOT, model_type)
313
- sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
314
- sample_subfolder_name = model_subfolder(model_type)
315
- if sample_files:
316
- selected_samples = st.multiselect(
317
- f"Select example images (max {BATCH_MAX_IMAGES})",
318
- sample_files,
319
- default=None,
320
- max_selections=BATCH_MAX_IMAGES,
321
- key=f"sample_batch_{model_type}",
322
- )
323
- if selected_samples:
324
- for fname in selected_samples[:BATCH_MAX_IMAGES]:
325
- sample_path = os.path.join(sample_folder, fname)
326
- loaded = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
327
- if loaded is not None:
328
- imgs_batch.append((loaded, fname))
329
- thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
330
- n_cols = min(5, len(thumbnails))
331
- cols = st.columns(n_cols)
332
- for i, (fname, sample_img) in enumerate(thumbnails):
333
- if sample_img is not None:
334
- with cols[i % n_cols]:
335
- st.image(sample_img, caption=fname, width=120)
336
- else:
337
- st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
338
  else:
339
  # Single image mode
340
  if img_source == "Upload":
@@ -349,37 +381,20 @@ else:
349
  img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
350
  uploaded.seek(0)
351
  else:
352
- sample_folder = get_sample_folder(S2F_ROOT, model_type)
353
- sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
354
- sample_subfolder_name = model_subfolder(model_type)
355
- if sample_files:
356
- selected_sample = st.selectbox(
357
- f"Select example image (from `samples/{sample_subfolder_name}/`)",
358
- sample_files,
359
- format_func=lambda x: x,
360
- key=f"sample_{model_type}",
361
- )
362
- if selected_sample:
363
- sample_path = os.path.join(sample_folder, selected_sample)
364
- img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
365
- thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
366
- n_cols = min(5, len(thumbnails))
367
- cols = st.columns(n_cols)
368
- for i, (fname, sample_img) in enumerate(thumbnails):
369
- if sample_img is not None:
370
- with cols[i % n_cols]:
371
- st.image(sample_img, caption=fname, width=120)
372
- else:
373
- st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
374
-
375
- col_btn, col_model, col_path = st.columns([1, 1, 3])
376
  with col_btn:
377
- run = st.button("Run prediction", type="primary")
378
- with col_model:
379
- st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>{MODEL_TYPE_LABELS[model_type]}</span>", unsafe_allow_html=True)
380
- with col_path:
381
  ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
382
- st.markdown(f"<span style='display: inline-flex; align-items: center; height: 38px;'>Checkpoint: <code>{ckp_path}</code></span>", unsafe_allow_html=True)
 
 
 
 
 
383
 
384
  has_image = img is not None
385
  has_batch = len(imgs_batch) > 0
@@ -400,18 +415,50 @@ just_ran = run and checkpoint and has_image and not batch_mode
400
  just_ran_batch = run and checkpoint and has_batch and batch_mode
401
 
402
 
403
- def get_or_create_predictor(model_type, checkpoint, ckp_folder):
404
- """Cache predictor in session state. Invalidate when model/checkpoint changes."""
405
- cache_key = (model_type, checkpoint)
406
- if "predictor" not in st.session_state or st.session_state.get("predictor_key") != cache_key:
407
- from predictor import S2FPredictor
408
- st.session_state["predictor"] = S2FPredictor(
409
- model_type=model_type,
410
- checkpoint_path=checkpoint,
411
- ckp_folder=ckp_folder,
412
- )
413
- st.session_state["predictor_key"] = cache_key
414
- return st.session_state["predictor"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
 
417
  if just_ran_batch:
@@ -419,27 +466,24 @@ if just_ran_batch:
419
  st.session_state["batch_results"] = None
420
  with st.spinner("Loading model and predicting..."):
421
  try:
422
- predictor = get_or_create_predictor(model_type, checkpoint, ckp_folder)
423
  sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
424
- batch_results = []
425
- progress_bar = st.progress(0, text="Processing images...")
426
- for idx, (img_b, key_b) in enumerate(imgs_batch):
427
- progress_bar.progress((idx + 1) / len(imgs_batch), text=f"Processing {key_b}...")
428
- heatmap, force, pixel_sum = predictor.predict(
429
- image_array=img_b,
430
- substrate=sub_val,
431
- substrate_config=substrate_config if model_type == "single_cell" else None,
432
- )
433
- cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
434
- batch_results.append({
435
  "img": img_b.copy(),
436
  "heatmap": heatmap.copy(),
437
  "force": force,
438
  "pixel_sum": pixel_sum,
439
  "key_img": key_b,
440
- "cell_mask": cell_mask,
441
- })
442
- progress_bar.empty()
 
443
  st.session_state["batch_results"] = batch_results
444
  st.success(f"Prediction complete for {len(batch_results)} image(s)!")
445
  render_batch_results(
@@ -451,6 +495,7 @@ if just_ran_batch:
451
  clip_min=clip_min,
452
  clip_max=clip_max,
453
  auto_cell_boundary=auto_cell_boundary,
 
454
  )
455
  except Exception as e:
456
  st.error(f"Prediction failed: {e}")
@@ -467,88 +512,44 @@ elif batch_mode and st.session_state.get("batch_results"):
467
  clip_min=clip_min,
468
  clip_max=clip_max,
469
  auto_cell_boundary=auto_cell_boundary,
 
470
  )
471
 
472
  elif just_ran:
473
  st.session_state["prediction_result"] = None
474
  with st.spinner("Loading model and predicting..."):
475
  try:
476
- predictor = get_or_create_predictor(model_type, checkpoint, ckp_folder)
477
  sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
478
  heatmap, force, pixel_sum = predictor.predict(
479
  image_array=img,
480
  substrate=sub_val,
481
  substrate_config=substrate_config if model_type == "single_cell" else None,
482
  )
483
-
484
- st.success("Prediction complete!")
485
-
486
- display_heatmap = apply_display_scale(
487
- heatmap, display_mode,
488
- min_percentile=min_percentile,
489
- max_percentile=max_percentile,
490
- clip_min=clip_min,
491
- clip_max=clip_max,
492
- )
493
-
494
  cache_key = (model_type, checkpoint, key_img)
495
- st.session_state["prediction_result"] = {
496
  "img": img.copy(),
497
  "heatmap": heatmap.copy(),
498
  "force": force,
499
  "pixel_sum": pixel_sum,
500
  "cache_key": cache_key,
501
  }
502
- cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
503
- _populate_measure_session_state(
504
- heatmap, img, pixel_sum, force, key_img, colormap_name,
505
- display_mode, auto_cell_boundary, cell_mask=cell_mask,
506
- min_percentile=min_percentile, max_percentile=max_percentile,
507
- clip_min=clip_min, clip_max=clip_max,
508
  )
509
- render_result_display(
510
- img, heatmap, display_heatmap, pixel_sum, force, key_img,
511
- colormap_name=colormap_name,
512
- display_mode=display_mode,
513
- measure_region_dialog=_get_measure_dialog_fn(),
514
- auto_cell_boundary=auto_cell_boundary,
515
- cell_mask=cell_mask,
516
- )
517
-
518
  except Exception as e:
519
  st.error(f"Prediction failed: {e}")
520
  st.code(traceback.format_exc())
521
 
522
  elif has_cached:
523
  r = st.session_state["prediction_result"]
524
- img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
525
- display_heatmap = apply_display_scale(
526
- heatmap, display_mode,
527
- min_percentile=min_percentile,
528
- max_percentile=max_percentile,
529
- clip_min=clip_min,
530
- clip_max=clip_max,
531
- )
532
- cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
533
- _populate_measure_session_state(
534
- heatmap, img, pixel_sum, force, key_img, colormap_name,
535
- display_mode, auto_cell_boundary, cell_mask=cell_mask,
536
- min_percentile=min_percentile, max_percentile=max_percentile,
537
- clip_min=clip_min, clip_max=clip_max,
538
- )
539
-
540
- if st.session_state.pop("open_measure_dialog", False):
541
- measure_region_dialog()
542
-
543
- st.success("Prediction complete!")
544
- render_result_display(
545
- img, heatmap, display_heatmap, pixel_sum, force, key_img,
546
- download_key_suffix="_cached",
547
- colormap_name=colormap_name,
548
- display_mode=display_mode,
549
- measure_region_dialog=_get_measure_dialog_fn(),
550
- auto_cell_boundary=auto_cell_boundary,
551
- cell_mask=cell_mask,
552
  )
553
 
554
  elif run and not checkpoint:
@@ -558,6 +559,5 @@ elif run and not has_image and not has_batch:
558
  elif run and batch_mode and not has_batch:
559
  st.warning(f"Please upload or select 1–{BATCH_MAX_IMAGES} images for batch processing.")
560
 
561
- st.sidebar.divider()
562
  render_system_status()
563
- st.sidebar.caption("<br>If you find this software useful, please cite:<br>" + CITATION, unsafe_allow_html=True)
 
22
  MODEL_TYPE_LABELS,
23
  SAMPLE_EXTENSIONS,
24
  SAMPLE_THUMBNAIL_LIMIT,
25
+ THEMES,
26
  )
27
  from utils.paths import get_ckp_base, get_ckp_folder, get_sample_folder, list_files_in_folder, model_subfolder
28
  from utils.segmentation import estimate_cell_mask
 
41
 
42
  CITATION = (
43
  "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
44
+ "<b>\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"</b>, 2026."
45
  )
46
 
47
  # Measure tool dialog: defined early so it exists before render_result_display uses it
 
59
  max_percentile=st.session_state.get("measure_max_percentile", 100),
60
  clip_min=st.session_state.get("measure_clip_min", 0),
61
  clip_max=st.session_state.get("measure_clip_max", 1),
62
+ clip_bounds=st.session_state.get("measure_clip_bounds", False),
63
  )
64
  bf_img = st.session_state.get("measure_bf_img")
65
  original_vals = st.session_state.get("measure_original_vals")
 
84
 
85
  def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
86
  display_mode, auto_cell_boundary, cell_mask=None,
87
+ min_percentile=0, max_percentile=100, clip_min=0, clip_max=1,
88
+ clip_bounds=False):
89
  """Populate session state for the measure tool. If cell_mask is None and auto_cell_boundary, computes it."""
90
  if cell_mask is None and auto_cell_boundary:
91
  cell_mask = estimate_cell_mask(heatmap)
 
95
  st.session_state["measure_max_percentile"] = max_percentile
96
  st.session_state["measure_clip_min"] = clip_min
97
  st.session_state["measure_clip_max"] = clip_max
98
+ st.session_state["measure_clip_bounds"] = clip_bounds
99
  st.session_state["measure_bf_img"] = img.copy()
100
  st.session_state["measure_input_filename"] = key_img or "image"
101
  st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
 
105
  st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
106
 
107
 
108
+ st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="wide")
109
 
110
+ st.markdown('<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
+ _css_path = os.path.join(S2F_ROOT, "static", "s2f_styles.css")
113
+ if os.path.exists(_css_path):
114
+ with open(_css_path, "r") as f:
115
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
116
 
117
  st.markdown("""
118
+ <div class="s2f-header">
119
+ <h1>🦠 Shape2Force (S2F)</h1>
120
+ <p>Predict traction force maps from bright-field microscopy images of cells or spheroids</p>
121
+ </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  """, unsafe_allow_html=True)
123
 
124
+ st.markdown(f"""
125
+ <div class="footer-citation">
126
+ <span>If you find this software useful, please cite: {CITATION}</span>
127
+ </div>
128
+ """, unsafe_allow_html=True)
129
 
130
  # Folders
131
  ckp_base = get_ckp_base(S2F_ROOT)
 
147
  return cache[cache_key]
148
 
149
 
150
+ def _render_sample_selector(model_type, batch_mode):
151
+ """
152
+ Render sample image selector (Example mode). Returns (img, imgs_batch, selected_sample, selected_samples).
153
+ For single mode: img is set, imgs_batch=[]. For batch: img=None, imgs_batch=list of (img, key).
154
+ """
155
+ sample_folder = get_sample_folder(S2F_ROOT, model_type)
156
+ sample_files = list_files_in_folder(sample_folder, SAMPLE_EXTENSIONS)
157
+ sample_subfolder_name = model_subfolder(model_type)
158
+ img = None
159
+ imgs_batch = []
160
+ selected_sample = None
161
+ selected_samples = []
162
+
163
+ if not sample_files:
164
+ st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
165
+ return img, imgs_batch, selected_sample, selected_samples
166
+
167
+ if batch_mode:
168
+ selected_samples = st.multiselect(
169
+ f"Select example images (max {BATCH_MAX_IMAGES})",
170
+ sample_files,
171
+ default=None,
172
+ max_selections=BATCH_MAX_IMAGES,
173
+ key=f"sample_batch_{model_type}",
174
+ )
175
+ if selected_samples:
176
+ for fname in selected_samples[:BATCH_MAX_IMAGES]:
177
+ sample_path = os.path.join(sample_folder, fname)
178
+ loaded = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
179
+ if loaded is not None:
180
+ imgs_batch.append((loaded, fname))
181
+ else:
182
+ selected_sample = st.selectbox(
183
+ f"Select example image (from `samples/{sample_subfolder_name}/`)",
184
+ sample_files,
185
+ format_func=lambda x: x,
186
+ key=f"sample_{model_type}",
187
+ )
188
+ if selected_sample:
189
+ sample_path = os.path.join(sample_folder, selected_sample)
190
+ img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
191
+
192
+ thumbnails = get_cached_sample_thumbnails(model_type, sample_folder, sample_files)
193
+ n_cols = min(5, len(thumbnails))
194
+ cols = st.columns(n_cols)
195
+ for i, (fname, sample_img) in enumerate(thumbnails):
196
+ if sample_img is not None:
197
+ with cols[i % n_cols]:
198
+ st.image(sample_img, caption=fname, width=120)
199
+ return img, imgs_batch, selected_sample, selected_samples
200
+
201
+
202
  # Sidebar
203
  with st.sidebar:
204
+ st.markdown("""
205
+ <div class="sidebar-brand">
206
+ <span class="brand-text">Shape2Force</span>
207
+ </div>
208
+ """, unsafe_allow_html=True)
209
+
210
+ st.markdown('<div class="sidebar-section"><span class="section-title">Model</span></div>', unsafe_allow_html=True)
211
 
212
  model_type = st.radio(
213
  "Model type",
 
225
  checkpoint = st.selectbox(
226
  "Checkpoint",
227
  ckp_files,
228
+ key=f"checkpoint_{model_type}",
229
  help=f"Select a .pth file from ckp/{ckp_subfolder_name}/",
230
  )
231
  else:
 
265
  except FileNotFoundError:
266
  st.error("config/substrate_settings.json not found")
267
 
268
+ st.markdown('<div class="sidebar-section"><span class="section-title">Analysis</span></div>', unsafe_allow_html=True)
 
 
 
 
269
 
270
  batch_mode = st.toggle(
271
  "Batch mode",
 
273
  help=f"Process up to {BATCH_MAX_IMAGES} images at once. Upload multiple files or select multiple examples.",
274
  )
275
 
276
+ auto_cell_boundary = st.toggle(
277
+ "Auto boundary",
278
+ value=False,
279
+ help="When on: estimate cell region from force map and use it for metrics (red contour). When off: use entire map.",
 
280
  )
281
+
282
+ clip_min, clip_max = st.slider(
283
+ "Force Range",
284
+ min_value=0.0,
285
+ max_value=1.0,
286
+ value=(0.0, 1.0),
287
+ step=0.01,
288
+ format="%.2f",
289
+ help="Min–max range for force values. Values outside are set to 0; inside are rescaled so max shows as red.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  )
291
+ if clip_min >= clip_max:
292
+ clip_min, clip_max = 0.0, 1.0
293
+ display_mode = "Range" if (clip_min != 0.0 or clip_max != 1.0) else "Default"
294
+ clip_bounds = False if display_mode == "Range" else True
295
+ min_percentile, max_percentile = 0, 100
296
+
297
+ st.markdown('<div class="sidebar-section"><span class="section-title">Display</span></div>', unsafe_allow_html=True)
298
+
299
+ cm_col_lbl, cm_col_sb = st.columns([1, 2])
300
+ with cm_col_lbl:
301
+ st.markdown('<p class="selectbox-label">Colormap</p>', unsafe_allow_html=True)
302
+ with cm_col_sb:
303
+ colormap_name = st.selectbox(
304
+ "Colormap",
305
+ list(COLORMAPS.keys()),
306
+ key="s2f_colormap",
307
+ label_visibility="collapsed",
308
+ help="Color scheme for the force map. Viridis is often preferred for accessibility.",
309
+ )
310
+
311
+ th_col_lbl, th_col_sb = st.columns([1, 2])
312
+ with th_col_lbl:
313
+ st.markdown('<p class="selectbox-label">Theme</p>', unsafe_allow_html=True)
314
+ with th_col_sb:
315
+ theme_name = st.selectbox(
316
+ "Theme",
317
+ list(THEMES.keys()),
318
+ index=0,
319
+ key="s2f_theme",
320
+ label_visibility="collapsed",
321
+ help="App accent color theme.",
322
+ )
323
+
324
 
325
+ # Inject theme CSS (main area so it applies globally)
326
+ primary, primary_dark, primary_darker, primary_rgb = THEMES[theme_name]
327
+ st.markdown(
328
+ f"""
329
+ <style>
330
+ :root {{
331
+ --s2f-primary: {primary};
332
+ --s2f-primary-dark: {primary_dark};
333
+ --s2f-primary-darker: {primary_darker};
334
+ --s2f-primary-rgb: {primary_rgb};
335
+ }}
336
+ </style>
337
+ """,
338
+ unsafe_allow_html=True,
339
+ )
340
 
341
  # Main area: image input
342
+ img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed", key="s2f_img_source")
343
  img = None
344
  imgs_batch = [] # list of (img, key_img) for batch mode
345
  uploaded = None
 
366
  imgs_batch.append((decoded, u.name))
367
  u.seek(0)
368
  else:
369
+ img, imgs_batch, selected_sample, selected_samples = _render_sample_selector(model_type, batch_mode=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  else:
371
  # Single image mode
372
  if img_source == "Upload":
 
381
  img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
382
  uploaded.seek(0)
383
  else:
384
+ img, imgs_batch, selected_sample, selected_samples = _render_sample_selector(model_type, batch_mode=False)
385
+
386
+ st.markdown("")
387
+ col_btn, col_info = st.columns([1, 3])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  with col_btn:
389
+ run = st.button("Run prediction", type="primary", use_container_width=True)
390
+ with col_info:
 
 
391
  ckp_path = f"ckp/{ckp_subfolder_name}/{checkpoint}" if checkpoint else f"ckp/{ckp_subfolder_name}/"
392
+ st.markdown(f"""
393
+ <div class="run-info">
394
+ <span class="run-info-tag">{MODEL_TYPE_LABELS[model_type]}</span>
395
+ <code>{ckp_path}</code>
396
+ </div>
397
+ """, unsafe_allow_html=True)
398
 
399
  has_image = img is not None
400
  has_batch = len(imgs_batch) > 0
 
415
  just_ran_batch = run and checkpoint and has_batch and batch_mode
416
 
417
 
418
+ @st.cache_resource
419
+ def _load_predictor(model_type, checkpoint, ckp_folder):
420
+ """Load and cache predictor. Invalidated when model_type or checkpoint changes."""
421
+ from predictor import S2FPredictor
422
+ return S2FPredictor(
423
+ model_type=model_type,
424
+ checkpoint_path=checkpoint,
425
+ ckp_folder=ckp_folder,
426
+ )
427
+
428
+
429
+ def _prepare_and_render_cached_result(r, key_img, colormap_name, display_mode, auto_cell_boundary,
430
+ min_percentile, max_percentile, clip_min, clip_max, clip_bounds,
431
+ download_key_suffix="", check_measure_dialog=False):
432
+ """Prepare display from cached result and render. Used by both just_ran and has_cached paths."""
433
+ img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
434
+ display_heatmap = apply_display_scale(
435
+ heatmap, display_mode,
436
+ min_percentile=min_percentile,
437
+ max_percentile=max_percentile,
438
+ clip_min=clip_min,
439
+ clip_max=clip_max,
440
+ clip_bounds=clip_bounds,
441
+ )
442
+ cell_mask = estimate_cell_mask(heatmap) if auto_cell_boundary else None
443
+ _populate_measure_session_state(
444
+ heatmap, img, pixel_sum, force, key_img, colormap_name,
445
+ display_mode, auto_cell_boundary, cell_mask=cell_mask,
446
+ min_percentile=min_percentile, max_percentile=max_percentile,
447
+ clip_min=clip_min, clip_max=clip_max, clip_bounds=clip_bounds,
448
+ )
449
+ if check_measure_dialog and st.session_state.pop("open_measure_dialog", False):
450
+ measure_region_dialog()
451
+ st.success("Prediction complete!")
452
+ render_result_display(
453
+ img, heatmap, display_heatmap, pixel_sum, force, key_img,
454
+ download_key_suffix=download_key_suffix,
455
+ colormap_name=colormap_name,
456
+ display_mode=display_mode,
457
+ measure_region_dialog=_get_measure_dialog_fn(),
458
+ auto_cell_boundary=auto_cell_boundary,
459
+ cell_mask=cell_mask,
460
+ clip_min=clip_min, clip_max=clip_max, clip_bounds=clip_bounds,
461
+ )
462
 
463
 
464
  if just_ran_batch:
 
466
  st.session_state["batch_results"] = None
467
  with st.spinner("Loading model and predicting..."):
468
  try:
469
+ predictor = _load_predictor(model_type, checkpoint, ckp_folder)
470
  sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
471
+ pred_results = predictor.predict_batch(
472
+ imgs_batch,
473
+ substrate=sub_val,
474
+ substrate_config=substrate_config if model_type == "single_cell" else None,
475
+ )
476
+ batch_results = [
477
+ {
 
 
 
 
478
  "img": img_b.copy(),
479
  "heatmap": heatmap.copy(),
480
  "force": force,
481
  "pixel_sum": pixel_sum,
482
  "key_img": key_b,
483
+ "cell_mask": estimate_cell_mask(heatmap) if auto_cell_boundary else None,
484
+ }
485
+ for (img_b, key_b), (heatmap, force, pixel_sum) in zip(imgs_batch, pred_results)
486
+ ]
487
  st.session_state["batch_results"] = batch_results
488
  st.success(f"Prediction complete for {len(batch_results)} image(s)!")
489
  render_batch_results(
 
495
  clip_min=clip_min,
496
  clip_max=clip_max,
497
  auto_cell_boundary=auto_cell_boundary,
498
+ clip_bounds=clip_bounds,
499
  )
500
  except Exception as e:
501
  st.error(f"Prediction failed: {e}")
 
512
  clip_min=clip_min,
513
  clip_max=clip_max,
514
  auto_cell_boundary=auto_cell_boundary,
515
+ clip_bounds=clip_bounds,
516
  )
517
 
518
  elif just_ran:
519
  st.session_state["prediction_result"] = None
520
  with st.spinner("Loading model and predicting..."):
521
  try:
522
+ predictor = _load_predictor(model_type, checkpoint, ckp_folder)
523
  sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
524
  heatmap, force, pixel_sum = predictor.predict(
525
  image_array=img,
526
  substrate=sub_val,
527
  substrate_config=substrate_config if model_type == "single_cell" else None,
528
  )
 
 
 
 
 
 
 
 
 
 
 
529
  cache_key = (model_type, checkpoint, key_img)
530
+ r = {
531
  "img": img.copy(),
532
  "heatmap": heatmap.copy(),
533
  "force": force,
534
  "pixel_sum": pixel_sum,
535
  "cache_key": cache_key,
536
  }
537
+ st.session_state["prediction_result"] = r
538
+ _prepare_and_render_cached_result(
539
+ r, key_img, colormap_name, display_mode, auto_cell_boundary,
540
+ min_percentile, max_percentile, clip_min, clip_max, clip_bounds,
541
+ download_key_suffix="", check_measure_dialog=False,
 
542
  )
 
 
 
 
 
 
 
 
 
543
  except Exception as e:
544
  st.error(f"Prediction failed: {e}")
545
  st.code(traceback.format_exc())
546
 
547
  elif has_cached:
548
  r = st.session_state["prediction_result"]
549
+ _prepare_and_render_cached_result(
550
+ r, key_img, colormap_name, display_mode, auto_cell_boundary,
551
+ min_percentile, max_percentile, clip_min, clip_max, clip_bounds,
552
+ download_key_suffix="_cached", check_measure_dialog=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  )
554
 
555
  elif run and not checkpoint:
 
559
  elif run and batch_mode and not has_batch:
560
  st.warning(f"Please upload or select 1–{BATCH_MAX_IMAGES} images for batch processing.")
561
 
562
+ st.sidebar.markdown('<div class="sidebar-section"><span class="section-title"></span></div>', unsafe_allow_html=True)
563
  render_system_status()
 
S2FApp/config/constants.py CHANGED
@@ -25,6 +25,16 @@ TOOL_LABELS = {"polygon": "Polygon", "rect": "Rectangle", "circle": "Circle"}
25
  # File extensions
26
  SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
27
 
 
 
 
 
 
 
 
 
 
 
28
  # Colormaps (OpenCV)
29
  COLORMAPS = {
30
  "Jet": cv2.COLORMAP_JET,
 
25
  # File extensions
26
  SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
27
 
28
+ # UI themes: primary, primary-dark, primary-darker, rgb (for rgba)
29
+ THEMES = {
30
+ "Teal": ("#0d9488", "#0f766e", "#115e59", "13, 148, 136"),
31
+ "Blue": ("#2563eb", "#1d4ed8", "#1e40af", "37, 99, 235"),
32
+ "Indigo": ("#6366f1", "#4f46e5", "#4338ca", "99, 102, 241"),
33
+ "Purple": ("#7c3aed", "#6d28d9", "#5b21b6", "124, 58, 237"),
34
+ "Amber": ("#f59e0b", "#d97706", "#b45309", "245, 158, 11"),
35
+ "Emerald": ("#10b981", "#059669", "#047857", "16, 185, 129"),
36
+ }
37
+
38
  # Colormaps (OpenCV)
39
  COLORMAPS = {
40
  "Jet": cv2.COLORMAP_JET,
S2FApp/predictor.py CHANGED
@@ -177,3 +177,50 @@ class S2FPredictor:
177
  pixel_sum = float(np.sum(heatmap))
178
 
179
  return heatmap, force, pixel_sum
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  pixel_sum = float(np.sum(heatmap))
178
 
179
  return heatmap, force, pixel_sum
180
+
181
+ def predict_batch(self, images, substrate=None, substrate_config=None):
182
+ """
183
+ Run prediction on a batch of images (single forward pass).
184
+
185
+ Args:
186
+ images: List of (img_array, key) or list of img arrays. img_array: (H, W) or (H, W, C).
187
+ substrate: Substrate name for single-cell mode (same for all images).
188
+ substrate_config: Optional dict with 'pixelsize' and 'young' (same for all).
189
+
190
+ Returns:
191
+ List of (heatmap, force, pixel_sum) tuples.
192
+ """
193
+ imgs = []
194
+ for item in images:
195
+ img = item[0] if isinstance(item, tuple) else item
196
+ img = np.asarray(img, dtype=np.float32)
197
+ if img.ndim == 3:
198
+ img = img[:, :, 0] if img.shape[-1] >= 1 else img
199
+ if img.max() > 1.0:
200
+ img = img / 255.0
201
+ img = cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
202
+ imgs.append(img)
203
+ x = torch.from_numpy(np.stack(imgs)).float().unsqueeze(1).to(self.device) # [B, 1, H, W]
204
+
205
+ if self.model_type == "single_cell" and self.norm_params is not None:
206
+ sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
207
+ settings_ch = create_settings_channels_single(
208
+ sub, self.device, x.shape[2], x.shape[3],
209
+ config_path=self.config_path, substrate_config=substrate_config
210
+ )
211
+ settings_batch = settings_ch.expand(x.shape[0], -1, -1, -1)
212
+ x = torch.cat([x, settings_batch], dim=1) # [B, 3, H, W]
213
+
214
+ with torch.no_grad():
215
+ pred = self.generator(x)
216
+
217
+ if self._use_tanh_output:
218
+ pred = (pred + 1.0) / 2.0
219
+
220
+ results = []
221
+ for i in range(pred.shape[0]):
222
+ heatmap = pred[i, 0].cpu().numpy()
223
+ force = sum_force_map(pred[i : i + 1]).item()
224
+ pixel_sum = float(np.sum(heatmap))
225
+ results.append((heatmap, force, pixel_sum))
226
+ return results
S2FApp/requirements-docker.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Shape2Force App - deps for Docker (torch/torchvision installed separately from CPU wheel)
2
+ numpy>=1.20.0
3
+ opencv-python>=4.5.0
4
+ scipy>=1.7.0
5
+ scikit-image>=0.19.0
6
+ streamlit>=1.28.0
7
+ streamlit-drawable-canvas-fix>=0.9.8
8
+ matplotlib>=3.5.0
9
+ Pillow>=9.0.0
10
+ plotly>=5.14.0
11
+ huggingface_hub>=0.20.0
12
+ reportlab>=4.0.0
13
+ psutil>=5.9.0
S2FApp/static/s2f_styles.css ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* === Theme variables (overridden by theme selector) === */
2
+ :root {
3
+ --s2f-primary: #0d9488;
4
+ --s2f-primary-dark: #0f766e;
5
+ --s2f-primary-darker: #115e59;
6
+ --s2f-primary-rgb: 13, 148, 136;
7
+ }
8
+
9
+ /* === Typography === */
10
+ html, body, .stApp {
11
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
12
+ }
13
+
14
+ /* === Header banner === */
15
+ .s2f-header {
16
+ background: linear-gradient(135deg, var(--s2f-primary) 0%, var(--s2f-primary-dark) 40%, var(--s2f-primary-darker) 100%);
17
+ padding: 1.1rem 1.5rem 1rem;
18
+ border-radius: 12px;
19
+ margin-bottom: 1.2rem;
20
+ color: white;
21
+ position: relative;
22
+ overflow: hidden;
23
+ box-shadow: 0 4px 20px rgba(var(--s2f-primary-rgb), 0.25);
24
+ }
25
+ .s2f-header::before {
26
+ content: '';
27
+ position: absolute;
28
+ top: -50%;
29
+ right: -15%;
30
+ width: 300px;
31
+ height: 300px;
32
+ background: radial-gradient(circle, rgba(255,255,255,0.08) 0%, transparent 70%);
33
+ border-radius: 50%;
34
+ }
35
+ .s2f-header::after {
36
+ content: '';
37
+ position: absolute;
38
+ bottom: -30%;
39
+ left: 10%;
40
+ width: 200px;
41
+ height: 200px;
42
+ background: radial-gradient(circle, rgba(255,255,255,0.05) 0%, transparent 70%);
43
+ border-radius: 50%;
44
+ }
45
+ .s2f-header h1 {
46
+ font-size: 1.85rem !important;
47
+ font-weight: 700 !important;
48
+ margin: 0 0 0.35rem !important;
49
+ color: white !important;
50
+ letter-spacing: -0.02em;
51
+ position: relative;
52
+ z-index: 1;
53
+ }
54
+ .s2f-header p {
55
+ font-size: 0.95rem !important;
56
+ color: rgba(255,255,255,0.85) !important;
57
+ margin: 0 !important;
58
+ font-weight: 400;
59
+ position: relative;
60
+ z-index: 1;
61
+ }
62
+
63
+ /* === Sidebar === */
64
+ section[data-testid="stSidebar"] {
65
+ width: 360px !important;
66
+ background: linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%) !important;
67
+ border-right: 1px solid #e2e8f0 !important;
68
+ }
69
+ @media (max-width: 768px) {
70
+ section[data-testid="stSidebar"] { width: 100% !important; max-width: 100% !important; }
71
+ }
72
+ section[data-testid="stSidebar"] [data-testid="stWidgetLabel"],
73
+ section[data-testid="stSidebar"] [data-testid="stWidgetLabel"] p {
74
+ font-size: 0.9rem !important;
75
+ font-weight: 500 !important;
76
+ color: #334155 !important;
77
+ }
78
+ .sidebar-section {
79
+ display: flex;
80
+ align-items: center;
81
+ gap: 8px;
82
+ padding: 0.5rem 0 0.25rem;
83
+ margin-top: 0.6rem;
84
+ border-bottom: 2px solid #94a3b8;
85
+ margin-bottom: 0.75rem;
86
+ }
87
+ .sidebar-section .section-title {
88
+ font-size: 0.78rem;
89
+ font-weight: 700;
90
+ color: var(--s2f-primary-dark);
91
+ text-transform: uppercase;
92
+ letter-spacing: 0.06em;
93
+ }
94
+ .sidebar-brand {
95
+ display: flex;
96
+ align-items: center;
97
+ gap: 10px;
98
+ padding-bottom: 0.8rem;
99
+ margin-bottom: 0.5rem;
100
+ border-bottom: 1px solid #e2e8f0;
101
+ }
102
+ .sidebar-brand .brand-icon {
103
+ font-size: 1.6rem;
104
+ line-height: 1;
105
+ }
106
+ .sidebar-brand .brand-text {
107
+ font-size: 1.1rem;
108
+ font-weight: 700;
109
+ color: var(--s2f-primary-dark);
110
+ letter-spacing: -0.01em;
111
+ }
112
+
113
+ /* === Metric cards === */
114
+ [data-testid="stMetric"] {
115
+ background: linear-gradient(145deg, #ffffff 0%, #f8fafc 100%);
116
+ border: 1px solid #e2e8f0;
117
+ border-radius: 12px;
118
+ padding: 0.85rem 1rem !important;
119
+ box-shadow: 0 1px 4px rgba(0,0,0,0.06);
120
+ transition: box-shadow 0.2s ease, transform 0.2s ease;
121
+ }
122
+ [data-testid="stMetric"]:hover {
123
+ box-shadow: 0 4px 14px rgba(0,0,0,0.1);
124
+ transform: translateY(-1px);
125
+ }
126
+ [data-testid="stMetric"] label {
127
+ font-size: 0.75rem !important;
128
+ font-weight: 600 !important;
129
+ color: #64748b !important;
130
+ text-transform: uppercase;
131
+ letter-spacing: 0.04em;
132
+ }
133
+ [data-testid="stMetric"] [data-testid="stMetricValue"] {
134
+ font-size: 1.25rem !important;
135
+ font-weight: 700 !important;
136
+ color: var(--s2f-primary-dark) !important;
137
+ }
138
+
139
+ /* === Buttons === */
140
+ .stButton > button[kind="primary"], button[kind="primary"] {
141
+ background: linear-gradient(135deg, var(--s2f-primary), var(--s2f-primary-dark)) !important;
142
+ border: none !important;
143
+ border-radius: 10px !important;
144
+ font-weight: 600 !important;
145
+ letter-spacing: 0.02em !important;
146
+ box-shadow: 0 2px 10px rgba(var(--s2f-primary-rgb), 0.3) !important;
147
+ transition: all 0.2s ease !important;
148
+ }
149
+ .stButton > button[kind="primary"]:hover, button[kind="primary"]:hover {
150
+ background: linear-gradient(135deg, var(--s2f-primary-dark), var(--s2f-primary-darker)) !important;
151
+ box-shadow: 0 4px 18px rgba(var(--s2f-primary-rgb), 0.4) !important;
152
+ transform: translateY(-1px) !important;
153
+ }
154
+ .stButton > button:not([kind="primary"]) {
155
+ border-radius: 10px !important;
156
+ font-weight: 500 !important;
157
+ border: 1px solid #cbd5e1 !important;
158
+ transition: all 0.2s ease !important;
159
+ }
160
+ .stButton > button:not([kind="primary"]):hover {
161
+ border-color: var(--s2f-primary) !important;
162
+ color: var(--s2f-primary) !important;
163
+ background: rgba(var(--s2f-primary-rgb), 0.04) !important;
164
+ }
165
+ [data-testid="stDownloadButton"] button {
166
+ border-radius: 10px !important;
167
+ font-weight: 500 !important;
168
+ transition: all 0.2s ease !important;
169
+ }
170
+ [data-testid="stDownloadButton"] button:hover {
171
+ border-color: var(--s2f-primary) !important;
172
+ color: var(--s2f-primary) !important;
173
+ }
174
+
175
+ /* === Action buttons row (measure + downloads) === */
176
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) {
177
+ background: #f8fafc;
178
+ border: 1px solid #e2e8f0;
179
+ border-radius: 14px;
180
+ padding: 0.65rem 0.6rem !important;
181
+ margin-top: 0.5rem;
182
+ gap: 0.5rem !important;
183
+ }
184
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div {
185
+ flex: 1 1 0 !important; min-width: 0 !important;
186
+ }
187
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) button {
188
+ width: 100% !important;
189
+ min-width: 0 !important;
190
+ white-space: nowrap !important;
191
+ border-radius: 10px !important;
192
+ font-size: 0.82rem !important;
193
+ font-weight: 600 !important;
194
+ padding: 0.55rem 0.8rem !important;
195
+ letter-spacing: 0.01em !important;
196
+ }
197
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button {
198
+ background: linear-gradient(135deg, var(--s2f-primary), var(--s2f-primary-dark)) !important;
199
+ color: white !important;
200
+ border-color: transparent !important;
201
+ box-shadow: 0 2px 8px rgba(var(--s2f-primary-rgb), 0.25) !important;
202
+ }
203
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) > div:nth-child(1) button:hover {
204
+ background: linear-gradient(135deg, var(--s2f-primary-dark), var(--s2f-primary-darker)) !important;
205
+ box-shadow: 0 4px 14px rgba(var(--s2f-primary-rgb), 0.35) !important;
206
+ transform: translateY(-1px) !important;
207
+ }
208
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) [data-testid="stDownloadButton"] button {
209
+ background: white !important;
210
+ border: 1px solid #e2e8f0 !important;
211
+ color: #334155 !important;
212
+ box-shadow: 0 1px 2px rgba(0,0,0,0.04) !important;
213
+ }
214
+ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([data-testid="stButton"]) [data-testid="stDownloadButton"] button:hover {
215
+ background: rgba(var(--s2f-primary-rgb), 0.06) !important;
216
+ border-color: var(--s2f-primary) !important;
217
+ color: var(--s2f-primary-dark) !important;
218
+ box-shadow: 0 2px 6px rgba(var(--s2f-primary-rgb), 0.12) !important;
219
+ transform: translateY(-1px) !important;
220
+ }
221
+
222
+ /* === Expanders === */
223
+ .stExpander {
224
+ border: 1px solid #e2e8f0 !important;
225
+ border-radius: 12px !important;
226
+ overflow: hidden;
227
+ box-shadow: 0 1px 3px rgba(0,0,0,0.04);
228
+ }
229
+
230
+ /* === File uploader === */
231
+ [data-testid="stFileUploader"] section {
232
+ border: 2px dashed #cbd5e1 !important;
233
+ border-radius: 12px !important;
234
+ transition: border-color 0.2s ease;
235
+ }
236
+ [data-testid="stFileUploader"] section:hover {
237
+ border-color: var(--s2f-primary) !important;
238
+ }
239
+
240
+ /* === Result labels === */
241
+ .result-label {
242
+ display: flex;
243
+ align-items: center;
244
+ gap: 8px;
245
+ font-size: 0.92rem;
246
+ font-weight: 600;
247
+ color: #334155;
248
+ padding: 0.4rem 0;
249
+ }
250
+ .result-badge {
251
+ font-size: 0.68rem;
252
+ font-weight: 700;
253
+ padding: 2px 8px;
254
+ border-radius: 4px;
255
+ letter-spacing: 0.06em;
256
+ }
257
+ .result-badge.input {
258
+ background: #e2e8f0;
259
+ color: #475569;
260
+ }
261
+ .result-badge.output {
262
+ background: rgba(var(--s2f-primary-rgb), 0.15);
263
+ color: var(--s2f-primary-dark);
264
+ }
265
+
266
+ /* === Scale visualization === */
267
+ .scale-viz {
268
+ margin: 0.3rem 0 0.5rem;
269
+ font-size: 0.78rem;
270
+ color: #64748b;
271
+ }
272
+ .sv-track {
273
+ display: flex;
274
+ align-items: center;
275
+ gap: 6px;
276
+ }
277
+ .sv-end {
278
+ font-weight: 600;
279
+ font-size: 0.72rem;
280
+ color: #94a3b8;
281
+ min-width: 14px;
282
+ text-align: center;
283
+ }
284
+ .sv-bar {
285
+ flex: 1;
286
+ height: 10px;
287
+ background: #e2e8f0;
288
+ border-radius: 5px;
289
+ position: relative;
290
+ overflow: visible;
291
+ }
292
+ .sv-active {
293
+ position: absolute;
294
+ top: 0;
295
+ height: 100%;
296
+ background: linear-gradient(90deg, var(--s2f-primary), var(--s2f-primary-dark));
297
+ border-radius: 5px;
298
+ box-shadow: 0 1px 4px rgba(var(--s2f-primary-rgb), 0.3);
299
+ }
300
+ .sv-lbl {
301
+ position: absolute;
302
+ top: 14px;
303
+ font-size: 0.7rem;
304
+ font-weight: 700;
305
+ color: var(--s2f-primary-dark);
306
+ white-space: nowrap;
307
+ }
308
+ .sv-lbl-l { left: 0; }
309
+ .sv-lbl-r { right: 0; }
310
+ .sv-note {
311
+ display: flex;
312
+ align-items: center;
313
+ gap: 4px;
314
+ margin-top: 10px;
315
+ font-size: 0.75rem;
316
+ color: #64748b;
317
+ }
318
+ .sv-pill {
319
+ background: rgba(var(--s2f-primary-rgb), 0.15);
320
+ color: var(--s2f-primary-dark);
321
+ font-weight: 700;
322
+ padding: 1px 6px;
323
+ border-radius: 4px;
324
+ font-size: 0.72rem;
325
+ }
326
+
327
+ /* === Run prediction info bar === */
328
+ .run-info {
329
+ display: flex;
330
+ align-items: center;
331
+ gap: 10px;
332
+ height: 42px;
333
+ font-size: 0.85rem;
334
+ color: #64748b;
335
+ }
336
+ .run-info-tag {
337
+ background: rgba(var(--s2f-primary-rgb), 0.1);
338
+ color: var(--s2f-primary-dark);
339
+ font-weight: 600;
340
+ font-size: 0.78rem;
341
+ padding: 3px 10px;
342
+ border-radius: 6px;
343
+ }
344
+ .run-info code {
345
+ background: #f1f5f9;
346
+ padding: 2px 8px;
347
+ border-radius: 4px;
348
+ font-size: 0.8rem;
349
+ color: #475569;
350
+ }
351
+
352
+ /* === Messages === */
353
+ .stSuccess {
354
+ background: linear-gradient(135deg, #ecfdf5 0%, #d1fae5 100%) !important;
355
+ border-left: 4px solid #10b981 !important;
356
+ border-radius: 8px !important;
357
+ }
358
+ .stWarning { border-radius: 8px !important; }
359
+ .stInfo { border-radius: 8px !important; }
360
+
361
+ /* === Selectbox, toggle === */
362
+ [data-testid="stSelectbox"] > div > div { border-radius: 8px !important; }
363
+ .selectbox-label {
364
+ margin: 0;
365
+ padding-top: 0.4rem;
366
+ font-size: 0.9rem;
367
+ font-weight: 500;
368
+ color: #334155;
369
+ line-height: 1.2;
370
+ }
371
+
372
+ /* === Dataframe === */
373
+ [data-testid="stDataFrame"] {
374
+ border-radius: 10px !important;
375
+ overflow: hidden;
376
+ box-shadow: 0 1px 3px rgba(0,0,0,0.06);
377
+ }
378
+
379
+ /* === Batch colorbar (minimal, no box) === */
380
+ .colorbar-table-header {
381
+ width: 100%;
382
+ margin-bottom: 0.5rem;
383
+ padding: 0;
384
+ background: transparent;
385
+ border: none;
386
+ box-shadow: none;
387
+ }
388
+ .colorbar-ticks {
389
+ display: flex;
390
+ justify-content: space-between;
391
+ align-items: center;
392
+ margin-bottom: 5px;
393
+ padding: 0 1px;
394
+ font-size: 0.65rem;
395
+ font-weight: 600;
396
+ color: #64748b;
397
+ letter-spacing: 0.03em;
398
+ }
399
+ .colorbar-ticks .cb-tick {
400
+ font-variant-numeric: tabular-nums;
401
+ }
402
+ .colorbar-bar {
403
+ width: 100%;
404
+ height: 6px;
405
+ background-size: 100% 100%;
406
+ background-repeat: no-repeat;
407
+ background-position: center;
408
+ border-radius: 3px;
409
+ box-shadow: inset 0 1px 1px rgba(0,0,0,0.05);
410
+ }
411
+ /* === Divider === */
412
+ hr { border-color: #cbd5e1 !important; opacity: 0.7; }
413
+
414
+ /* === Plotly chart === */
415
+ .stPlotlyChart { border-radius: 12px; overflow: hidden; }
416
+
417
+ /* === System status === */
418
+ .system-status {
419
+ font-size: 0.78rem;
420
+ margin-top: 0.5rem;
421
+ padding: 8px 12px;
422
+ border-radius: 8px;
423
+ border: 1px solid rgba(148, 163, 184, 0.25);
424
+ background: rgba(148, 163, 184, 0.08);
425
+ color: inherit;
426
+ display: flex;
427
+ align-items: center;
428
+ gap: 6px;
429
+ }
430
+ .system-status .status-dot {
431
+ width: 6px;
432
+ height: 6px;
433
+ border-radius: 50%;
434
+ background: #10b981;
435
+ display: inline-block;
436
+ flex-shrink: 0;
437
+ }
438
+
439
+ /* === Footer citation === */
440
+ .footer-citation {
441
+ position: fixed;
442
+ bottom: 0;
443
+ left: 360px;
444
+ right: 0;
445
+ z-index: 999;
446
+ padding: 0.45rem 1rem;
447
+ background: #f1f5f9;
448
+ border-top: 1px solid #e2e8f0;
449
+ font-size: 0.7rem;
450
+ color: #64748b;
451
+ text-align: center;
452
+ line-height: 1.4;
453
+ }
454
+ .block-container {
455
+ padding-bottom: 2.5rem !important;
456
+ max-width: 1050px !important;
457
+ }
458
+ section[data-testid="stSidebar"] > div:first-child {
459
+ padding-top: 1rem !important;
460
+ }
461
+
462
+ /* === Responsive === */
463
+ @media (max-width: 768px) {
464
+ .s2f-header {
465
+ padding: 1.5rem;
466
+ border-radius: 12px;
467
+ }
468
+ .s2f-header h1 {
469
+ font-size: 1.4rem !important;
470
+ }
471
+ }
S2FApp/ui/components.py CHANGED
@@ -1,837 +1,29 @@
1
- """UI components for S2F App."""
2
- import csv
3
- import html
4
- import io
5
- import os
6
- import zipfile
7
-
8
- import cv2
9
- import numpy as np
10
  import streamlit as st
11
- from PIL import Image
12
- import plotly.graph_objects as go
13
- from plotly.subplots import make_subplots
14
-
15
- from config.constants import (
16
- CANVAS_SIZE,
17
- COLORMAPS,
18
- DRAW_TOOLS,
19
- TOOL_LABELS,
20
- )
21
- from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale
22
- from utils.report import (
23
- heatmap_to_rgb,
24
- heatmap_to_rgb_with_contour,
25
- heatmap_to_png_bytes,
26
- create_pdf_report,
27
- create_measure_pdf_report,
28
- )
29
- from utils.segmentation import estimate_cell_mask
30
 
31
- try:
32
- from streamlit_drawable_canvas import st_canvas
33
- HAS_DRAWABLE_CANVAS = True
34
- except (ImportError, AttributeError):
35
- HAS_DRAWABLE_CANVAS = False
36
-
37
- try:
38
- import psutil
39
- HAS_PSUTIL = True
40
- except ImportError:
41
- HAS_PSUTIL = False
42
-
43
- # Resolve st.dialog early to fix ordering bug (used in _render_result_display)
44
  ST_DIALOG = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None)
45
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- def _get_container_memory():
48
- """
49
- Read memory from cgroups when running in a container (Docker, HF Spaces).
50
- psutil reports host memory in containers, which can be misleading (e.g. 128 GB vs 16 GB limit).
51
- Returns (used_bytes, total_bytes) or None to fall back to psutil.
52
- """
53
- try:
54
- # cgroup v2 (modern Docker, HF Spaces)
55
- for base in ("/sys/fs/cgroup", "/sys/fs/cgroup/self"):
56
- try:
57
- with open(f"{base}/memory.max", "r") as f:
58
- max_val = f.read().strip()
59
- if max_val == "max":
60
- return None # No limit, use psutil
61
- total = int(max_val)
62
- with open(f"{base}/memory.current", "r") as f:
63
- used = int(f.read().strip())
64
- return (used, total)
65
- except (FileNotFoundError, ValueError):
66
- continue
67
- # cgroup v1
68
- try:
69
- with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
70
- total = int(f.read().strip())
71
- with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
72
- used = int(f.read().strip())
73
- if total > 2**50: # Often 9223372036854771712 when unlimited
74
- return None
75
- return (used, total)
76
- except (FileNotFoundError, ValueError):
77
- pass
78
- except Exception:
79
- pass
80
- return None
81
-
82
-
83
- def render_system_status():
84
- """Render a small live CPU/memory status panel in the sidebar."""
85
- if not HAS_PSUTIL:
86
- return
87
- try:
88
- cpu = psutil.cpu_percent(interval=0.1)
89
- # Prefer cgroup memory in containers (Docker, HF Spaces); psutil shows host memory
90
- container_mem = _get_container_memory()
91
- if container_mem is not None:
92
- used_bytes, total_bytes = container_mem
93
- mem_used_gb = used_bytes / (1024**3)
94
- mem_total_gb = total_bytes / (1024**3)
95
- mem_pct = 100 * used_bytes / total_bytes if total_bytes > 0 else 0
96
- else:
97
- mem = psutil.virtual_memory()
98
- mem_used_gb = mem.used / (1024**3)
99
- mem_total_gb = mem.total / (1024**3)
100
- mem_pct = mem.percent
101
- st.sidebar.markdown(
102
- f"""
103
- <div style="
104
- font-size: 0.8rem; margin-top: 0.5rem; padding: 6px 10px;
105
- border-radius: 6px;
106
- border: 1px solid rgba(148, 163, 184, 0.3);
107
- background: rgba(148, 163, 184, 0.1);
108
- color: inherit;
109
- ">
110
- <strong>System</strong> CPU {cpu:.0f}% · Mem {mem_pct:.0f}% ({mem_used_gb:.1f}/{mem_total_gb:.1f} GB)
111
- </div>
112
- """,
113
- unsafe_allow_html=True,
114
- )
115
- except Exception:
116
- pass
117
-
118
-
119
- def render_batch_results(batch_results, colormap_name="Jet", display_mode="Default",
120
- min_percentile=0, max_percentile=100, clip_min=0, clip_max=1,
121
- auto_cell_boundary=False):
122
- """
123
- Render batch prediction results: summary table, bright-field row, heatmap row, and bulk download.
124
- batch_results: list of dicts with img, heatmap, force, pixel_sum, key_img, cell_mask.
125
- cell_mask is computed on-the-fly when auto_cell_boundary is True and not stored.
126
- """
127
- if not batch_results:
128
- return
129
- st.markdown("### Batch results")
130
- # Resolve cell_mask for each result (compute if needed when auto_cell_boundary toggled on)
131
- for r in batch_results:
132
- if auto_cell_boundary and (r.get("cell_mask") is None or not np.any(r.get("cell_mask", 0) > 0)):
133
- r["_cell_mask"] = estimate_cell_mask(r["heatmap"])
134
- else:
135
- r["_cell_mask"] = r.get("cell_mask") if auto_cell_boundary else None
136
- # Build table rows - consistent column names for both modes
137
- headers = ["Image", "Force", "Sum", "Max", "Mean"]
138
- rows = []
139
- csv_rows = [["image"] + headers[1:]]
140
- for r in batch_results:
141
- heatmap = r["heatmap"]
142
- cell_mask = r.get("_cell_mask")
143
- key = r["key_img"] or "image"
144
- if auto_cell_boundary and cell_mask is not None and np.any(cell_mask > 0):
145
- vals = heatmap[cell_mask > 0]
146
- cell_pixel_sum = float(np.sum(vals))
147
- cell_force = cell_pixel_sum * (r["force"] / r["pixel_sum"]) if r["pixel_sum"] > 0 else cell_pixel_sum
148
- cell_mean = cell_pixel_sum / np.sum(cell_mask) if np.sum(cell_mask) > 0 else 0
149
- row = [key, f"{cell_force:.2f}", f"{cell_pixel_sum:.2f}",
150
- f"{np.max(heatmap):.4f}", f"{cell_mean:.4f}"]
151
- else:
152
- row = [key, f"{r['force']:.2f}", f"{r['pixel_sum']:.2f}",
153
- f"{np.max(heatmap):.4f}", f"{np.mean(heatmap):.4f}"]
154
- rows.append(row)
155
- csv_rows.append([os.path.splitext(key)[0]] + row[1:])
156
- # Bright-field row
157
- st.markdown("**Input: Bright-field images**")
158
- n_cols = min(5, len(batch_results))
159
- bf_cols = st.columns(n_cols)
160
- for i, r in enumerate(batch_results):
161
- img = r["img"]
162
- if img.ndim == 2:
163
- img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
164
- else:
165
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
166
- with bf_cols[i % n_cols]:
167
- st.image(img_rgb, caption=r["key_img"], use_container_width=True)
168
- # Heatmap row
169
- st.markdown("**Output: Predicted force maps**")
170
- hm_cols = st.columns(n_cols)
171
- for i, r in enumerate(batch_results):
172
- display_heatmap = apply_display_scale(
173
- r["heatmap"], display_mode,
174
- min_percentile=min_percentile, max_percentile=max_percentile,
175
- clip_min=clip_min, clip_max=clip_max,
176
- )
177
- hm_rgb = heatmap_to_rgb_with_contour(
178
- display_heatmap, colormap_name,
179
- r.get("_cell_mask") if auto_cell_boundary else None,
180
- )
181
- with hm_cols[i % n_cols]:
182
- st.image(hm_rgb, caption=r["key_img"], use_container_width=True)
183
- # Table
184
- st.dataframe(
185
- {h: [r[i] for r in rows] for i, h in enumerate(headers)},
186
- use_container_width=True,
187
- hide_index=True,
188
- )
189
- # Histograms in accordion (one per row for visibility)
190
- with st.expander("Force distribution (histograms)", expanded=False):
191
- for i, r in enumerate(batch_results):
192
- heatmap = r["heatmap"]
193
- cell_mask = r.get("_cell_mask")
194
- vals = heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and auto_cell_boundary) else heatmap.flatten()
195
- vals = vals[vals > 0] if np.any(vals > 0) else vals
196
- st.markdown(f"**{r['key_img']}**")
197
- if len(vals) > 0:
198
- fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
199
- fig.update_layout(
200
- height=220, margin=dict(l=40, r=20, t=10, b=40),
201
- xaxis_title="Force value", yaxis_title="Count",
202
- showlegend=False,
203
- )
204
- st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
205
- else:
206
- st.caption("No data")
207
- if i < len(batch_results) - 1:
208
- st.divider()
209
- # Bulk downloads: CSV and heatmaps (zip)
210
- buf_csv = io.StringIO()
211
- csv.writer(buf_csv).writerows(csv_rows)
212
- zip_buf = io.BytesIO()
213
- with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
214
- for r in batch_results:
215
- display_heatmap = apply_display_scale(
216
- r["heatmap"], display_mode,
217
- min_percentile=min_percentile, max_percentile=max_percentile,
218
- clip_min=clip_min, clip_max=clip_max,
219
- )
220
- hm_bytes = heatmap_to_png_bytes(
221
- display_heatmap, colormap_name,
222
- r.get("_cell_mask") if auto_cell_boundary else None,
223
- )
224
- base = os.path.splitext(r["key_img"] or "image")[0]
225
- zf.writestr(f"{base}_heatmap.png", hm_bytes.getvalue())
226
- zip_buf.seek(0)
227
- dl_col1, dl_col2 = st.columns(2)
228
- with dl_col1:
229
- st.download_button(
230
- "Download all as CSV",
231
- data=buf_csv.getvalue(),
232
- file_name="s2f_batch_results.csv",
233
- mime="text/csv",
234
- key="download_batch_csv",
235
- icon=":material/download:",
236
- )
237
- with dl_col2:
238
- st.download_button(
239
- "Download all heatmaps",
240
- data=zip_buf.getvalue(),
241
- file_name="s2f_batch_heatmaps.zip",
242
- mime="application/zip",
243
- key="download_batch_heatmaps",
244
- icon=":material/image:",
245
- )
246
-
247
-
248
- # Distinct colors for each region (RGB - heatmap_rgb is RGB)
249
- _REGION_COLORS = [
250
- (255, 102, 0), # orange
251
- (255, 165, 0), # orange-red
252
- (255, 255, 0), # yellow
253
- (255, 0, 255), # magenta
254
- (0, 255, 127), # spring green
255
- (0, 128, 255), # blue
256
- (203, 192, 255), # lavender
257
- (255, 215, 0), # gold
258
  ]
259
-
260
-
261
- def _draw_region_overlay(annotated, mask, color, fill_alpha=0.3, stroke_width=2):
262
- """Draw single region overlay on annotated heatmap (fill + alpha blend + contour). Modifies annotated in place."""
263
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
264
- overlay = annotated.copy()
265
- cv2.fillPoly(overlay, contours, color)
266
- mask_3d = np.stack([mask] * 3, axis=-1).astype(bool)
267
- annotated[mask_3d] = (
268
- (1 - fill_alpha) * annotated[mask_3d].astype(np.float32)
269
- + fill_alpha * overlay[mask_3d].astype(np.float32)
270
- ).astype(np.uint8)
271
- cv2.drawContours(annotated, contours, -1, color, stroke_width)
272
-
273
-
274
- def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(255, 102, 0), stroke_width=2):
275
- """Composite heatmap with drawn region overlay."""
276
- annotated = heatmap_rgb.copy()
277
- _draw_region_overlay(annotated, mask, stroke_color, fill_alpha, stroke_width)
278
- return annotated
279
-
280
-
281
- def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=None, fill_alpha=0.3):
282
- """Draw each region separately with distinct color and label (R1, R2, ...). No merging."""
283
- annotated = heatmap_rgb.copy()
284
- if cell_mask is not None and np.any(cell_mask > 0):
285
- contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
286
- cv2.drawContours(annotated, contours, -1, (255, 0, 0), 2)
287
- for i, mask in enumerate(masks):
288
- color = _REGION_COLORS[i % len(_REGION_COLORS)]
289
- _draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
290
- # Label at centroid
291
- M = cv2.moments(mask)
292
- if M["m00"] > 0:
293
- cx = int(M["m10"] / M["m00"])
294
- cy = int(M["m01"] / M["m00"])
295
- label = labels[i] if i < len(labels) else f"R{i + 1}"
296
- cv2.putText(
297
- annotated, label, (cx - 12, cy + 5),
298
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA
299
- )
300
- cv2.putText(
301
- annotated, label, (cx - 12, cy + 5),
302
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1, cv2.LINE_AA
303
- )
304
- return annotated
305
-
306
-
307
- def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
308
- """Convert a single canvas object to polygon points in heatmap coords. Returns None if invalid."""
309
- obj_type = obj.get("type", "")
310
- pts = []
311
- if obj_type == "rect":
312
- left = obj.get("left", 0)
313
- top = obj.get("top", 0)
314
- w = obj.get("width", 0)
315
- h = obj.get("height", 0)
316
- pts = np.array([
317
- [left, top], [left + w, top], [left + w, top + h], [left, top + h]
318
- ], dtype=np.float32)
319
- elif obj_type == "circle" or obj_type == "ellipse":
320
- left = obj.get("left", 0)
321
- top = obj.get("top", 0)
322
- width = obj.get("width", 0)
323
- height = obj.get("height", 0)
324
- radius = obj.get("radius", 0)
325
- angle_deg = obj.get("angle", 0)
326
- if radius > 0:
327
- rx = ry = radius
328
- angle_rad = np.deg2rad(angle_deg)
329
- cx = left + radius * np.cos(angle_rad)
330
- cy = top + radius * np.sin(angle_rad)
331
- else:
332
- rx = width / 2 if width > 0 else 0
333
- ry = height / 2 if height > 0 else 0
334
- if rx <= 0 or ry <= 0:
335
- return None
336
- cx = left + rx
337
- cy = top + ry
338
- if rx <= 0 or ry <= 0:
339
- return None
340
- n = 32
341
- angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
342
- pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32)
343
- elif obj_type == "path":
344
- path = obj.get("path", [])
345
- for cmd in path:
346
- if isinstance(cmd, (list, tuple)) and len(cmd) >= 3:
347
- if cmd[0] in ("M", "L"):
348
- pts.append([float(cmd[1]), float(cmd[2])])
349
- elif cmd[0] == "Q" and len(cmd) >= 5:
350
- pts.append([float(cmd[3]), float(cmd[4])])
351
- elif cmd[0] == "C" and len(cmd) >= 7:
352
- pts.append([float(cmd[5]), float(cmd[6])])
353
- if len(pts) < 3:
354
- return None
355
- pts = np.array(pts, dtype=np.float32)
356
- else:
357
- return None
358
- pts[:, 0] *= scale_x
359
- pts[:, 1] *= scale_y
360
- pts = np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32)
361
- return pts
362
-
363
-
364
- def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
365
- """Parse drawn shapes and return a list of individual masks (one per shape)."""
366
- if not json_data or "objects" not in json_data or not json_data["objects"]:
367
- return []
368
- scale_x = heatmap_w / canvas_w
369
- scale_y = heatmap_h / canvas_h
370
- masks = []
371
- for obj in json_data["objects"]:
372
- pts = _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h)
373
- if pts is None:
374
- continue
375
- mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
376
- cv2.fillPoly(mask, [pts], 1)
377
- masks.append(mask)
378
- return masks
379
-
380
-
381
- def build_original_vals(raw_heatmap, pixel_sum, force):
382
- """Build original_vals dict for measure tool (full map)."""
383
- return {
384
- "pixel_sum": pixel_sum,
385
- "force": force,
386
- "max": float(np.max(raw_heatmap)),
387
- "mean": float(np.mean(raw_heatmap)),
388
- }
389
-
390
-
391
- def build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force):
392
- """Build cell_vals dict for measure tool (estimated cell area). Returns None if invalid."""
393
- cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force)
394
- if cell_pixel_sum is None:
395
- return None
396
- region_values = raw_heatmap * cell_mask
397
- region_nonzero = region_values[cell_mask > 0]
398
- cell_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0
399
- return {
400
- "pixel_sum": cell_pixel_sum,
401
- "force": cell_force,
402
- "max": cell_max,
403
- "mean": cell_mean,
404
- }
405
-
406
-
407
- def compute_region_metrics(raw_heatmap, mask, original_vals=None):
408
- """Compute region metrics from mask."""
409
- area_px = int(np.sum(mask))
410
- region_values = raw_heatmap * mask
411
- region_nonzero = region_values[mask > 0]
412
- force_sum = float(np.sum(region_values))
413
- density = force_sum / area_px if area_px > 0 else 0
414
- region_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0
415
- region_mean = float(np.mean(region_nonzero)) if len(region_nonzero) > 0 else 0
416
- region_force_scaled = (
417
- force_sum * (original_vals["force"] / original_vals["pixel_sum"])
418
- if original_vals and original_vals.get("pixel_sum", 0) > 0
419
- else force_sum
420
- )
421
- return {
422
- "area_px": area_px,
423
- "force_sum": force_sum,
424
- "density": density,
425
- "max": region_max,
426
- "mean": region_mean,
427
- "force_scaled": region_force_scaled,
428
- }
429
-
430
-
431
- def render_region_metrics_and_downloads(metrics_list, masks, heatmap_rgb, input_filename, key_suffix, has_original_vals,
432
- first_region_label=None, bf_img=None, cell_mask=None, colormap_name="Jet"):
433
- """
434
- Render per-shape metrics table and download buttons.
435
- first_region_label: custom label for first row (e.g. 'Auto boundary').
436
- masks: list of region masks (user-drawn only; used for labeled heatmap with R1, R2...).
437
- """
438
- base_name = os.path.splitext(input_filename or "image")[0]
439
- st.markdown("**Regions (each selection = one row)**")
440
- if has_original_vals:
441
- headers = ["Region", "Area", "F.sum", "Force", "Max", "Mean"]
442
- csv_rows = [["image", "region"] + headers[1:]]
443
- else:
444
- headers = ["Region", "Area (px²)", "Force sum", "Mean"]
445
- csv_rows = [["image", "region", "Area", "Force sum", "Mean"]]
446
- table_rows = [headers]
447
- for i, metrics in enumerate(metrics_list, 1):
448
- region_label = first_region_label if (i == 1 and first_region_label) else f"Region {i - (1 if first_region_label else 0)}"
449
- if has_original_vals:
450
- row = [region_label, str(metrics["area_px"]), f"{metrics['force_sum']:.3f}", f"{metrics['force_scaled']:.1f}",
451
- f"{metrics['max']:.3f}", f"{metrics['mean']:.4f}"]
452
- csv_rows.append([base_name, region_label, metrics["area_px"], f"{metrics['force_sum']:.3f}",
453
- f"{metrics['force_scaled']:.1f}", f"{metrics['max']:.3f}", f"{metrics['mean']:.4f}"])
454
- else:
455
- row = [region_label, str(metrics["area_px"]), f"{metrics['force_sum']:.4f}", f"{metrics['mean']:.6f}"]
456
- csv_rows.append([base_name, region_label, metrics["area_px"], f"{metrics['force_sum']:.4f}",
457
- f"{metrics['mean']:.6f}"])
458
- table_rows.append(row)
459
- # Render as HTML table to avoid Streamlit's default row/column indices
460
- header = table_rows[0]
461
- body = table_rows[1:]
462
- th_cells = "".join(
463
- f'<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">{html.escape(str(h))}</th>'
464
- for h in header
465
- )
466
- rows_html = [
467
- "<tr>"
468
- + "".join(
469
- f'<td style="border: 1px solid #ddd; padding: 8px;">{html.escape(str(c))}</td>'
470
- for c in row
471
- )
472
- + "</tr>"
473
- for row in body
474
- ]
475
- table_html = (
476
- f'<table style="border-collapse: collapse; width: 100%;">'
477
- f"<thead><tr>{th_cells}</tr></thead>"
478
- f"<tbody>{''.join(rows_html)}</tbody></table>"
479
- )
480
- st.markdown(table_html, unsafe_allow_html=True)
481
- buf_csv = io.StringIO()
482
- csv.writer(buf_csv).writerows(csv_rows)
483
- # Annotated heatmap: each region separate with R1, R2 labels (no merging)
484
- # heatmap_rgb already has cell contour if applicable
485
- region_labels = [f"R{i + 1}" for i in range(len(masks))]
486
- heatmap_labeled = make_annotated_heatmap_multi_regions(heatmap_rgb.copy(), masks, region_labels, cell_mask=None)
487
- buf_img = io.BytesIO()
488
- Image.fromarray(heatmap_labeled).save(buf_img, format="PNG")
489
- buf_img.seek(0)
490
- # PDF report (requires bf_img)
491
- pdf_bytes = None
492
- if bf_img is not None:
493
- pdf_bytes = create_measure_pdf_report(bf_img, heatmap_labeled, table_rows, base_name)
494
- n_cols = 3 if pdf_bytes is not None else 2
495
- dl_cols = st.columns(n_cols)
496
- with dl_cols[0]:
497
- st.download_button("Download all regions", data=buf_csv.getvalue(),
498
- file_name=f"{base_name}_all_regions.csv", mime="text/csv",
499
- key=f"download_all_regions_{key_suffix}", icon=":material/download:")
500
- with dl_cols[1]:
501
- st.download_button("Download heatmap", data=buf_img.getvalue(),
502
- file_name=f"{base_name}_annotated_heatmap.png", mime="image/png",
503
- key=f"download_annotated_{key_suffix}", icon=":material/image:")
504
- if pdf_bytes is not None:
505
- with dl_cols[2]:
506
- st.download_button("Download report", data=pdf_bytes,
507
- file_name=f"{base_name}_measure_report.pdf", mime="application/pdf",
508
- key=f"download_measure_pdf_{key_suffix}", icon=":material/picture_as_pdf:")
509
-
510
-
511
- def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=2):
512
- """Draw contour from mask on RGB image. Resizes mask to match img if needed."""
513
- h, w = img_rgb.shape[:2]
514
- if mask.shape[:2] != (h, w):
515
- mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)
516
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
517
- if contours:
518
- cv2.drawContours(img_rgb, contours, -1, stroke_color, stroke_width)
519
- return img_rgb
520
-
521
-
522
- def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, original_vals=None, cell_vals=None,
523
- cell_mask=None, key_suffix="", input_filename=None, colormap_name="Jet"):
524
- """Render drawable canvas and region metrics. When cell_vals: show cell area (replaces Full map). Else: show Full map."""
525
- raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
526
- h, w = display_heatmap.shape
527
- heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
528
- pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
529
-
530
- st.markdown("""
531
- <style>
532
- [data-testid="stDialog"] [data-testid="stSelectbox"], [data-testid="stExpander"] [data-testid="stSelectbox"],
533
- [data-testid="stDialog"] [data-testid="stSelectbox"] > div, [data-testid="stExpander"] [data-testid="stSelectbox"] > div {
534
- width: 100% !important; max-width: 100% !important;
535
- }
536
- [data-testid="stDialog"] [data-testid="stMetric"] label, [data-testid="stDialog"] [data-testid="stMetric"] [data-testid="stMetricValue"],
537
- [data-testid="stExpander"] [data-testid="stMetric"] label, [data-testid="stExpander"] [data-testid="stMetric"] [data-testid="stMetricValue"] {
538
- font-size: 0.95rem !important;
539
- }
540
- [data-testid="stDialog"] img, [data-testid="stExpander"] img { border-radius: 0 !important; }
541
- </style>
542
- """, unsafe_allow_html=True)
543
-
544
- if bf_img is not None:
545
- bf_resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE))
546
- bf_rgb = cv2.cvtColor(bf_resized, cv2.COLOR_GRAY2RGB) if bf_img.ndim == 2 else cv2.cvtColor(bf_resized, cv2.COLOR_BGR2RGB)
547
- left_col, right_col = st.columns(2, gap=None)
548
- with left_col:
549
- draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x], key=f"draw_mode_region_{key_suffix}")
550
- st.caption("Left-click add, right-click close. \nForce map (draw region)")
551
- canvas_result = st_canvas(
552
- fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600",
553
- background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
554
- height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
555
- key=f"region_measure_canvas_{key_suffix}",
556
- )
557
- with right_col:
558
- vals = cell_vals if cell_vals else original_vals
559
- if vals:
560
- label = "Cell area" if cell_vals else "Full map"
561
- st.markdown(f'<p style="font-weight: 400; color: #334155; font-size: 0.95rem; margin: 0 20px 4px 4px;">{label}</p>', unsafe_allow_html=True)
562
- st.markdown(f"""
563
- <div style="width: 100%; box-sizing: border-box; border: 1px solid #e2e8f0; border-radius: 10px;
564
- padding: 10px 12px; margin: 0 10px 20px 10px; background: linear-gradient(145deg, #f8fafc 0%, #f1f5f9 100%);
565
- box-shadow: 0 1px 3px rgba(0,0,0,0.06);">
566
- <div style="display: flex; flex-wrap: wrap; gap: 5px; font-size: 0.9rem;">
567
- <span><strong>Sum:</strong> {vals['pixel_sum']:.1f}</span>
568
- <span><strong>Force:</strong> {vals['force']:.1f}</span>
569
- <span><strong>Max:</strong> {vals['max']:.3f}</span>
570
- <span><strong>Mean:</strong> {vals['mean']:.3f}</span>
571
- </div>
572
- </div>
573
- """, unsafe_allow_html=True)
574
- st.caption("Bright-field")
575
- bf_display = bf_rgb.copy()
576
- if cell_mask is not None and np.any(cell_mask > 0):
577
- bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=2)
578
- st.image(bf_display, width=CANVAS_SIZE)
579
- else:
580
- st.markdown("**Draw a region** on the heatmap.")
581
- draw_mode = st.selectbox("Drawing tool", DRAW_TOOLS,
582
- format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x],
583
- key=f"draw_mode_region_{key_suffix}")
584
- st.caption("Polygon: left-click to add points, right-click to close.")
585
- canvas_result = st_canvas(
586
- fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600",
587
- background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
588
- height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
589
- key=f"region_measure_canvas_{key_suffix}",
590
- )
591
-
592
- if canvas_result.json_data:
593
- masks = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
594
- if masks:
595
- metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
596
- if cell_mask is not None and np.any(cell_mask > 0):
597
- cell_metrics = compute_region_metrics(raw_heatmap, cell_mask, original_vals)
598
- metrics_list = [cell_metrics] + metrics_list
599
- render_region_metrics_and_downloads(
600
- metrics_list, masks, heatmap_rgb, input_filename, key_suffix, original_vals is not None,
601
- first_region_label="Auto boundary" if (cell_mask is not None and np.any(cell_mask > 0)) else None,
602
- bf_img=bf_img, cell_mask=cell_mask, colormap_name=colormap_name,
603
- )
604
-
605
-
606
- def _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force):
607
- """Compute metrics over estimated cell area only."""
608
- area_px = int(np.sum(cell_mask))
609
- if area_px == 0:
610
- return None, None, None
611
- region_values = raw_heatmap * cell_mask
612
- cell_pixel_sum = float(np.sum(region_values))
613
- cell_force = cell_pixel_sum * (force / pixel_sum) if pixel_sum > 0 else cell_pixel_sum
614
- cell_mean = cell_pixel_sum / area_px
615
- return cell_pixel_sum, cell_force, cell_mean
616
-
617
-
618
- def _add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
619
- """Add red contour overlay to Plotly heatmap subplot."""
620
- if cell_mask is None or not np.any(cell_mask > 0):
621
- return
622
- contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
623
- if not contours:
624
- return
625
- # Use largest contour
626
- cnt = max(contours, key=cv2.contourArea)
627
- pts = cnt.squeeze()
628
- if pts.ndim == 1:
629
- pts = pts.reshape(1, 2)
630
- x, y = pts[:, 0].tolist(), pts[:, 1].tolist()
631
- if x[0] != x[-1] or y[0] != y[-1]:
632
- x.append(x[0])
633
- y.append(y[0])
634
- fig_pl.add_trace(
635
- go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=2), showlegend=False),
636
- row=row, col=col
637
- )
638
-
639
-
640
- def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
641
- colormap_name="Jet", display_mode="Default", measure_region_dialog=None, auto_cell_boundary=True,
642
- cell_mask=None):
643
- """
644
- Render prediction result: plot, metrics, expander, and download/measure buttons.
645
- measure_region_dialog: callable to open measure dialog (when ST_DIALOG available).
646
- auto_cell_boundary: when True, use estimated cell area for metrics; when False, use entire map.
647
- cell_mask: optional precomputed cell mask; if None and auto_cell_boundary, will be computed.
648
- """
649
- if cell_mask is None and auto_cell_boundary:
650
- cell_mask = estimate_cell_mask(raw_heatmap)
651
- elif not auto_cell_boundary:
652
- cell_mask = None
653
- cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force) if cell_mask is not None else (None, None, None)
654
- use_cell_metrics = auto_cell_boundary and cell_pixel_sum is not None and cell_force is not None and cell_mean is not None
655
-
656
- base_name = os.path.splitext(key_img or "image")[0]
657
- if use_cell_metrics:
658
- main_csv_rows = [
659
- ["image", "Cell sum", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
660
- [base_name, f"{cell_pixel_sum:.2f}", f"{cell_force:.2f}",
661
- f"{np.max(raw_heatmap):.4f}", f"{cell_mean:.4f}"],
662
- ]
663
- else:
664
- main_csv_rows = [
665
- ["image", "Sum of all pixels", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
666
- [base_name, f"{pixel_sum:.2f}", f"{force:.2f}",
667
- f"{np.max(raw_heatmap):.4f}", f"{np.mean(raw_heatmap):.4f}"],
668
- ]
669
- buf_main_csv = io.StringIO()
670
- csv.writer(buf_main_csv).writerows(main_csv_rows)
671
-
672
- buf_hm = heatmap_to_png_bytes(display_heatmap, colormap_name, cell_mask=cell_mask)
673
-
674
- tit1, tit2 = st.columns(2)
675
- with tit1:
676
- st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
677
- with tit2:
678
- st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
679
- fig_pl = make_subplots(rows=1, cols=2)
680
- fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
681
- plotly_colorscale = cv_colormap_to_plotly_colorscale(colormap_name)
682
- zmin, zmax = 0.0, 1.0
683
- fig_pl.add_trace(go.Heatmap(z=display_heatmap, colorscale=plotly_colorscale, zmin=zmin, zmax=zmax, showscale=True,
684
- colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
685
- _add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2)
686
- fig_pl.update_layout(
687
- height=400,
688
- margin=dict(l=10, r=10, t=10, b=10),
689
- xaxis=dict(scaleanchor="y", scaleratio=1),
690
- xaxis2=dict(scaleanchor="y2", scaleratio=1),
691
- )
692
- fig_pl.update_xaxes(showticklabels=False)
693
- fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
694
- st.plotly_chart(fig_pl, use_container_width=True, config={"displayModeBar": True, "responsive": True})
695
-
696
- col1, col2, col3, col4 = st.columns(4)
697
- if use_cell_metrics:
698
- with col1:
699
- st.metric("Cell sum", f"{cell_pixel_sum:.2f}", help="Sum over estimated cell area (background excluded)")
700
- with col2:
701
- st.metric("Cell force (scaled)", f"{cell_force:.2f}", help="Total traction force in physical units")
702
- with col3:
703
- st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
704
- with col4:
705
- st.metric("Heatmap mean", f"{cell_mean:.4f}", help="Mean force over estimated cell area")
706
- else:
707
- with col1:
708
- st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
709
- with col2:
710
- st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
711
- with col3:
712
- st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
713
- with col4:
714
- st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
715
-
716
- # Statistics panel (mean, std, percentiles, histogram)
717
- with st.expander("Statistics"):
718
- vals = raw_heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and use_cell_metrics) else raw_heatmap.flatten()
719
- if len(vals) > 0:
720
- st.markdown("**Summary**")
721
- stat_col1, stat_col2, stat_col3 = st.columns(3)
722
- with stat_col1:
723
- st.metric("Mean", f"{float(np.mean(vals)):.4f}")
724
- st.metric("Std", f"{float(np.std(vals)):.4f}")
725
- with stat_col2:
726
- p25, p50, p75 = float(np.percentile(vals, 25)), float(np.percentile(vals, 50)), float(np.percentile(vals, 75))
727
- st.metric("P25", f"{p25:.4f}")
728
- st.metric("P50 (median)", f"{p50:.4f}")
729
- st.metric("P75", f"{p75:.4f}")
730
- with stat_col3:
731
- p90 = float(np.percentile(vals, 90))
732
- st.metric("P90", f"{p90:.4f}")
733
- st.markdown("**Histogram**")
734
- hist_fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
735
- hist_fig.update_layout(
736
- height=220, margin=dict(l=40, r=20, t=20, b=40),
737
- xaxis_title="Force value", yaxis_title="Count",
738
- showlegend=False,
739
- )
740
- st.plotly_chart(hist_fig, use_container_width=True, config={"displayModeBar": False})
741
- else:
742
- st.caption("No nonzero values to compute statistics.")
743
-
744
- with st.expander("How to read the results"):
745
- if use_cell_metrics:
746
- st.markdown("""
747
- **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
748
- This is the raw image you provided—it shows cell shape but not forces.
749
-
750
- **Output (right):** Predicted traction force map.
751
- - **Color** indicates force magnitude: blue = low, red = high
752
- - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
753
- - **Red border = estimated cell area** (background excluded from metrics)
754
- - Values are normalized to [0, 1] for visualization
755
-
756
- **Metrics (auto cell boundary on):**
757
- - **Cell sum:** Sum over estimated cell area (background excluded)
758
- - **Cell force (scaled):** Total traction force in physical units
759
- - **Heatmap max:** Peak force intensity in the map
760
- - **Heatmap mean:** Mean force over the estimated cell area
761
- """)
762
- else:
763
- st.markdown("""
764
- **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
765
- This is the raw image you provided—it shows cell shape but not forces.
766
-
767
- **Output (right):** Predicted traction force map.
768
- - **Color** indicates force magnitude: blue = low, red = high
769
- - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
770
- - Values are normalized to [0, 1] for visualization
771
-
772
- **Metrics (auto cell boundary off):**
773
- - **Sum of all pixels:** Raw sum over entire map
774
- - **Cell force (scaled):** Total traction force in physical units
775
- - **Heatmap max/mean:** Peak and average force intensity (full field of view)
776
- """)
777
-
778
- original_vals = build_original_vals(raw_heatmap, pixel_sum, force)
779
-
780
- pdf_bytes = create_pdf_report(
781
- img, display_heatmap, raw_heatmap, pixel_sum, force, base_name, colormap_name,
782
- cell_mask=cell_mask, cell_pixel_sum=cell_pixel_sum, cell_force=cell_force, cell_mean=cell_mean
783
- )
784
-
785
- btn_col1, btn_col2, btn_col3, btn_col4 = st.columns(4)
786
- with btn_col1:
787
- if HAS_DRAWABLE_CANVAS and measure_region_dialog is not None:
788
- if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
789
- st.session_state["open_measure_dialog"] = True
790
- st.rerun()
791
- elif HAS_DRAWABLE_CANVAS:
792
- with st.expander("Measure tool"):
793
- expander_cell_vals = build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force) if (auto_cell_boundary and cell_mask is not None) else None
794
- expander_cell_mask = cell_mask if auto_cell_boundary else None
795
- render_region_canvas(
796
- display_heatmap,
797
- raw_heatmap=raw_heatmap,
798
- bf_img=img,
799
- original_vals=original_vals,
800
- cell_vals=expander_cell_vals,
801
- cell_mask=expander_cell_mask,
802
- key_suffix="expander",
803
- input_filename=key_img,
804
- colormap_name=colormap_name,
805
- )
806
- else:
807
- st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
808
- with btn_col2:
809
- st.download_button(
810
- "Download heatmap",
811
- width="stretch",
812
- data=buf_hm.getvalue(),
813
- file_name="s2f_heatmap.png",
814
- mime="image/png",
815
- key=f"download_heatmap{download_key_suffix}",
816
- icon=":material/download:",
817
- )
818
- with btn_col3:
819
- st.download_button(
820
- "Download values",
821
- width="stretch",
822
- data=buf_main_csv.getvalue(),
823
- file_name=f"{base_name}_main_values.csv",
824
- mime="text/csv",
825
- key=f"download_main_values{download_key_suffix}",
826
- icon=":material/download:",
827
- )
828
- with btn_col4:
829
- st.download_button(
830
- "Download report",
831
- width="stretch",
832
- data=pdf_bytes,
833
- file_name=f"{base_name}_report.pdf",
834
- mime="application/pdf",
835
- key=f"download_pdf{download_key_suffix}",
836
- icon=":material/picture_as_pdf:",
837
- )
 
1
+ """UI components for S2F App. Re-exports from submodules for backward compatibility."""
 
 
 
 
 
 
 
 
2
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ # Resolve st.dialog early to fix ordering bug (used in measure dialog)
 
 
 
 
 
 
 
 
 
 
 
 
5
  ST_DIALOG = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None)
6
 
7
+ from ui.system_status import render_system_status
8
+ from ui.result_display import render_batch_results, render_result_display
9
+ from ui.measure_tool import (
10
+ build_original_vals,
11
+ build_cell_vals,
12
+ render_region_canvas,
13
+ parse_canvas_shapes_to_masks,
14
+ compute_region_metrics,
15
+ HAS_DRAWABLE_CANVAS,
16
+ )
17
 
18
+ __all__ = [
19
+ "ST_DIALOG",
20
+ "HAS_DRAWABLE_CANVAS",
21
+ "render_system_status",
22
+ "render_batch_results",
23
+ "render_result_display",
24
+ "build_original_vals",
25
+ "build_cell_vals",
26
+ "render_region_canvas",
27
+ "parse_canvas_shapes_to_masks",
28
+ "compute_region_metrics",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
S2FApp/ui/heatmaps.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Heatmap visualization utilities (colorbar, overlays, Plotly)."""
2
+ import base64
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import streamlit as st
7
+ import plotly.graph_objects as go
8
+
9
+ from config.constants import COLORMAPS
10
+
11
+
12
+ def _colormap_gradient_base64(colormap_name, width=512):
13
+ """Generate a horizontal gradient bar as base64 PNG for the given colormap."""
14
+ cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
15
+ gradient = np.linspace(0, 255, width, dtype=np.uint8).reshape(1, -1)
16
+ rgb = cv2.cvtColor(cv2.applyColorMap(gradient, cv2_cmap), cv2.COLOR_BGR2RGB)
17
+ bar = np.repeat(rgb, 6, axis=0)
18
+ _, buf = cv2.imencode(".png", cv2.cvtColor(bar, cv2.COLOR_RGB2BGR))
19
+ return base64.b64encode(buf.tobytes()).decode("utf-8")
20
+
21
+ # Distinct colors for each region (RGB - heatmap_rgb is RGB)
22
+ _REGION_COLORS = [
23
+ (0, 188, 212), # cyan (matches drawing tool)
24
+ (0, 230, 118), # green
25
+ (255, 235, 59), # yellow
26
+ (171, 71, 188), # purple
27
+ (0, 150, 255), # blue
28
+ (255, 167, 38), # amber
29
+ (124, 179, 66), # light green
30
+ (233, 30, 99), # pink
31
+ ]
32
+
33
+
34
+ def _draw_region_overlay(annotated, mask, color, fill_alpha=0.3, stroke_width=2):
35
+ """Draw single region overlay on annotated heatmap (fill + alpha blend + contour). Modifies annotated in place."""
36
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
37
+ overlay = annotated.copy()
38
+ cv2.fillPoly(overlay, contours, color)
39
+ mask_3d = np.stack([mask] * 3, axis=-1).astype(bool)
40
+ annotated[mask_3d] = (
41
+ (1 - fill_alpha) * annotated[mask_3d].astype(np.float32)
42
+ + fill_alpha * overlay[mask_3d].astype(np.float32)
43
+ ).astype(np.uint8)
44
+ cv2.drawContours(annotated, contours, -1, color, stroke_width)
45
+
46
+
47
+ def render_horizontal_colorbar(colormap_name, clip_min=0, clip_max=1, is_rescale=False):
48
+ """Render a compact horizontal colorbar for batch mode, anchored above the table."""
49
+ ticks = [0, 0.25, 0.5, 0.75, 1]
50
+ if is_rescale:
51
+ rng = clip_max - clip_min
52
+ labels = [f"{clip_min + t * rng:.2f}" for t in ticks]
53
+ else:
54
+ labels = [f"{t:.2f}" for t in ticks]
55
+
56
+ data_url = _colormap_gradient_base64(colormap_name)
57
+ labels_html = "".join(f'<span class="cb-tick">{l}</span>' for l in labels)
58
+ html = f"""
59
+ <div class="colorbar-table-header">
60
+ <div class="colorbar-ticks">{labels_html}</div>
61
+ <div class="colorbar-bar" style="background-image: url(data:image/png;base64,{data_url});"></div>
62
+ </div>
63
+ """
64
+ st.markdown(html, unsafe_allow_html=True)
65
+
66
+
67
+ def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(0, 188, 212), stroke_width=2):
68
+ """Composite heatmap with drawn region overlay."""
69
+ annotated = heatmap_rgb.copy()
70
+ _draw_region_overlay(annotated, mask, stroke_color, fill_alpha, stroke_width)
71
+ return annotated
72
+
73
+
74
+ def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=None, fill_alpha=0.3):
75
+ """Draw each region separately with distinct color and label (R1, R2, ...). No merging."""
76
+ annotated = heatmap_rgb.copy()
77
+ if cell_mask is not None and np.any(cell_mask > 0):
78
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79
+ cv2.drawContours(annotated, contours, -1, (255, 0, 0), 5)
80
+ for i, mask in enumerate(masks):
81
+ color = _REGION_COLORS[i % len(_REGION_COLORS)]
82
+ _draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
83
+ # Label at centroid
84
+ M = cv2.moments(mask)
85
+ if M["m00"] > 0:
86
+ cx = int(M["m10"] / M["m00"])
87
+ cy = int(M["m01"] / M["m00"])
88
+ label = labels[i] if i < len(labels) else f"R{i + 1}"
89
+ cv2.putText(
90
+ annotated, label, (cx - 12, cy + 5),
91
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA
92
+ )
93
+ cv2.putText(
94
+ annotated, label, (cx - 12, cy + 5),
95
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1, cv2.LINE_AA
96
+ )
97
+ return annotated
98
+
99
+
100
+ def add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
101
+ """Add red contour overlay to Plotly heatmap subplot."""
102
+ if cell_mask is None or not np.any(cell_mask > 0):
103
+ return
104
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
105
+ if not contours:
106
+ return
107
+ # Use largest contour
108
+ cnt = max(contours, key=cv2.contourArea)
109
+ pts = cnt.squeeze()
110
+ if pts.ndim == 1:
111
+ pts = pts.reshape(1, 2)
112
+ x, y = pts[:, 0].tolist(), pts[:, 1].tolist()
113
+ if x[0] != x[-1] or y[0] != y[-1]:
114
+ x.append(x[0])
115
+ y.append(y[0])
116
+ fig_pl.add_trace(
117
+ go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=4), showlegend=False),
118
+ row=row, col=col
119
+ )
S2FApp/ui/measure_tool.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Measure tool: drawable canvas, region metrics, and downloads."""
2
+ import csv
3
+ import html
4
+ import io
5
+ import os
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import streamlit as st
10
+ from PIL import Image
11
+
12
+ from config.constants import CANVAS_SIZE, DRAW_TOOLS, TOOL_LABELS
13
+ from utils.report import heatmap_to_rgb_with_contour, create_measure_pdf_report
14
+ from ui.heatmaps import make_annotated_heatmap_multi_regions
15
+
16
+ try:
17
+ from streamlit_drawable_canvas import st_canvas
18
+ HAS_DRAWABLE_CANVAS = True
19
+ except (ImportError, AttributeError):
20
+ HAS_DRAWABLE_CANVAS = False
21
+
22
+
23
+ def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
24
+ """Convert a single canvas object to polygon points in heatmap coords. Returns None if invalid."""
25
+ obj_type = obj.get("type", "")
26
+ pts = []
27
+ if obj_type == "rect":
28
+ left = obj.get("left", 0)
29
+ top = obj.get("top", 0)
30
+ w = obj.get("width", 0)
31
+ h = obj.get("height", 0)
32
+ pts = np.array([
33
+ [left, top], [left + w, top], [left + w, top + h], [left, top + h]
34
+ ], dtype=np.float32)
35
+ elif obj_type == "circle" or obj_type == "ellipse":
36
+ left = obj.get("left", 0)
37
+ top = obj.get("top", 0)
38
+ width = obj.get("width", 0)
39
+ height = obj.get("height", 0)
40
+ radius = obj.get("radius", 0)
41
+ angle_deg = obj.get("angle", 0)
42
+ if radius > 0:
43
+ rx = ry = radius
44
+ angle_rad = np.deg2rad(angle_deg)
45
+ cx = left + radius * np.cos(angle_rad)
46
+ cy = top + radius * np.sin(angle_rad)
47
+ else:
48
+ rx = width / 2 if width > 0 else 0
49
+ ry = height / 2 if height > 0 else 0
50
+ if rx <= 0 or ry <= 0:
51
+ return None
52
+ cx = left + rx
53
+ cy = top + ry
54
+ if rx <= 0 or ry <= 0:
55
+ return None
56
+ n = 32
57
+ angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
58
+ pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32)
59
+ elif obj_type == "path":
60
+ path = obj.get("path", [])
61
+ for cmd in path:
62
+ if isinstance(cmd, (list, tuple)) and len(cmd) >= 3:
63
+ if cmd[0] in ("M", "L"):
64
+ pts.append([float(cmd[1]), float(cmd[2])])
65
+ elif cmd[0] == "Q" and len(cmd) >= 5:
66
+ pts.append([float(cmd[3]), float(cmd[4])])
67
+ elif cmd[0] == "C" and len(cmd) >= 7:
68
+ pts.append([float(cmd[5]), float(cmd[6])])
69
+ if len(pts) < 3:
70
+ return None
71
+ pts = np.array(pts, dtype=np.float32)
72
+ else:
73
+ return None
74
+ pts[:, 0] *= scale_x
75
+ pts[:, 1] *= scale_y
76
+ pts = np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32)
77
+ return pts
78
+
79
+
80
+ def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
81
+ """Parse drawn shapes and return a list of individual masks (one per shape)."""
82
+ if not json_data or "objects" not in json_data or not json_data["objects"]:
83
+ return []
84
+ scale_x = heatmap_w / canvas_w
85
+ scale_y = heatmap_h / canvas_h
86
+ masks = []
87
+ for obj in json_data["objects"]:
88
+ pts = _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h)
89
+ if pts is None:
90
+ continue
91
+ mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
92
+ cv2.fillPoly(mask, [pts], 1)
93
+ masks.append(mask)
94
+ return masks
95
+
96
+
97
+ def build_original_vals(raw_heatmap, pixel_sum, force):
98
+ """Build original_vals dict for measure tool (full map)."""
99
+ return {
100
+ "pixel_sum": pixel_sum,
101
+ "force": force,
102
+ "max": float(np.max(raw_heatmap)),
103
+ "mean": float(np.mean(raw_heatmap)),
104
+ }
105
+
106
+
107
+ def _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force):
108
+ """Compute metrics over estimated cell area only."""
109
+ area_px = int(np.sum(cell_mask))
110
+ if area_px == 0:
111
+ return None, None, None
112
+ region_values = raw_heatmap * cell_mask
113
+ cell_pixel_sum = float(np.sum(region_values))
114
+ cell_force = cell_pixel_sum * (force / pixel_sum) if pixel_sum > 0 else cell_pixel_sum
115
+ cell_mean = cell_pixel_sum / area_px
116
+ return cell_pixel_sum, cell_force, cell_mean
117
+
118
+
119
+ def build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force):
120
+ """Build cell_vals dict for measure tool (estimated cell area). Returns None if invalid."""
121
+ cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force)
122
+ if cell_pixel_sum is None:
123
+ return None
124
+ region_values = raw_heatmap * cell_mask
125
+ region_nonzero = region_values[cell_mask > 0]
126
+ cell_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0
127
+ return {
128
+ "pixel_sum": cell_pixel_sum,
129
+ "force": cell_force,
130
+ "max": cell_max,
131
+ "mean": cell_mean,
132
+ }
133
+
134
+
135
+ def compute_region_metrics(raw_heatmap, mask, original_vals=None):
136
+ """Compute region metrics from mask."""
137
+ area_px = int(np.sum(mask))
138
+ region_values = raw_heatmap * mask
139
+ region_nonzero = region_values[mask > 0]
140
+ force_sum = float(np.sum(region_values))
141
+ density = force_sum / area_px if area_px > 0 else 0
142
+ region_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0
143
+ region_mean = float(np.mean(region_nonzero)) if len(region_nonzero) > 0 else 0
144
+ region_force_scaled = (
145
+ force_sum * (original_vals["force"] / original_vals["pixel_sum"])
146
+ if original_vals and original_vals.get("pixel_sum", 0) > 0
147
+ else force_sum
148
+ )
149
+ return {
150
+ "area_px": area_px,
151
+ "force_sum": force_sum,
152
+ "density": density,
153
+ "max": region_max,
154
+ "mean": region_mean,
155
+ "force_scaled": region_force_scaled,
156
+ }
157
+
158
+
159
+ def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=5):
160
+ """Draw contour from mask on RGB image. Resizes mask to match img if needed."""
161
+ h, w = img_rgb.shape[:2]
162
+ if mask.shape[:2] != (h, w):
163
+ mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)
164
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
165
+ if contours:
166
+ cv2.drawContours(img_rgb, contours, -1, stroke_color, stroke_width)
167
+ return img_rgb
168
+
169
+
170
+ def render_region_metrics_and_downloads(metrics_list, masks, heatmap_rgb, input_filename, key_suffix, has_original_vals,
171
+ first_region_label=None, bf_img=None, cell_mask=None, colormap_name="Jet"):
172
+ """
173
+ Render per-shape metrics table and download buttons.
174
+ first_region_label: custom label for first row (e.g. 'Auto boundary').
175
+ masks: list of region masks (user-drawn only; used for labeled heatmap with R1, R2...).
176
+ """
177
+ base_name = os.path.splitext(input_filename or "image")[0]
178
+ st.markdown("**Regions (each selection = one row)**")
179
+ if has_original_vals:
180
+ headers = ["Region", "Area", "F.sum", "Force", "Max", "Mean"]
181
+ csv_rows = [["image", "region"] + headers[1:]]
182
+ else:
183
+ headers = ["Region", "Area (px²)", "Force sum", "Mean"]
184
+ csv_rows = [["image", "region", "Area", "Force sum", "Mean"]]
185
+ table_rows = [headers]
186
+ for i, metrics in enumerate(metrics_list, 1):
187
+ region_label = first_region_label if (i == 1 and first_region_label) else f"Region {i - (1 if first_region_label else 0)}"
188
+ if has_original_vals:
189
+ row = [region_label, str(metrics["area_px"]), f"{metrics['force_sum']:.3f}", f"{metrics['force_scaled']:.1f}",
190
+ f"{metrics['max']:.3f}", f"{metrics['mean']:.4f}"]
191
+ csv_rows.append([base_name, region_label, metrics["area_px"], f"{metrics['force_sum']:.3f}",
192
+ f"{metrics['force_scaled']:.1f}", f"{metrics['max']:.3f}", f"{metrics['mean']:.4f}"])
193
+ else:
194
+ row = [region_label, str(metrics["area_px"]), f"{metrics['force_sum']:.4f}", f"{metrics['mean']:.6f}"]
195
+ csv_rows.append([base_name, region_label, metrics["area_px"], f"{metrics['force_sum']:.4f}",
196
+ f"{metrics['mean']:.6f}"])
197
+ table_rows.append(row)
198
+ # Render as HTML table to avoid Streamlit's default row/column indices
199
+ header = table_rows[0]
200
+ body = table_rows[1:]
201
+ th_cells = "".join(
202
+ f'<th style="border: 1px solid #ddd; padding: 8px; text-align: left;">{html.escape(str(h))}</th>'
203
+ for h in header
204
+ )
205
+ rows_html = [
206
+ "<tr>"
207
+ + "".join(
208
+ f'<td style="border: 1px solid #ddd; padding: 8px;">{html.escape(str(c))}</td>'
209
+ for c in row
210
+ )
211
+ + "</tr>"
212
+ for row in body
213
+ ]
214
+ table_html = (
215
+ f'<table style="border-collapse: collapse; width: 100%;">'
216
+ f"<thead><tr>{th_cells}</tr></thead>"
217
+ f"<tbody>{''.join(rows_html)}</tbody></table>"
218
+ )
219
+ st.markdown(table_html, unsafe_allow_html=True)
220
+ buf_csv = io.StringIO()
221
+ csv.writer(buf_csv).writerows(csv_rows)
222
+ # Annotated heatmap: each region separate with R1, R2 labels (no merging)
223
+ region_labels = [f"R{i + 1}" for i in range(len(masks))]
224
+ heatmap_labeled = make_annotated_heatmap_multi_regions(heatmap_rgb.copy(), masks, region_labels, cell_mask=None)
225
+ buf_img = io.BytesIO()
226
+ Image.fromarray(heatmap_labeled).save(buf_img, format="PNG")
227
+ buf_img.seek(0)
228
+ # PDF report (requires bf_img)
229
+ pdf_bytes = None
230
+ if bf_img is not None:
231
+ pdf_bytes = create_measure_pdf_report(bf_img, heatmap_labeled, table_rows, base_name)
232
+ n_cols = 3 if pdf_bytes is not None else 2
233
+ dl_cols = st.columns(n_cols)
234
+ with dl_cols[0]:
235
+ st.download_button("Download all regions", data=buf_csv.getvalue(),
236
+ file_name=f"{base_name}_all_regions.csv", mime="text/csv",
237
+ key=f"download_all_regions_{key_suffix}", icon=":material/download:")
238
+ with dl_cols[1]:
239
+ st.download_button("Download heatmap", data=buf_img.getvalue(),
240
+ file_name=f"{base_name}_annotated_heatmap.png", mime="image/png",
241
+ key=f"download_annotated_{key_suffix}", icon=":material/image:")
242
+ if pdf_bytes is not None:
243
+ with dl_cols[2]:
244
+ st.download_button("Download report", data=pdf_bytes,
245
+ file_name=f"{base_name}_measure_report.pdf", mime="application/pdf",
246
+ key=f"download_measure_pdf_{key_suffix}", icon=":material/picture_as_pdf:")
247
+
248
+
249
+ def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, original_vals=None, cell_vals=None,
250
+ cell_mask=None, key_suffix="", input_filename=None, colormap_name="Jet"):
251
+ """Render drawable canvas and region metrics. When cell_vals: show cell area (replaces Full map). Else: show Full map."""
252
+ if not HAS_DRAWABLE_CANVAS:
253
+ st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
254
+ return
255
+ raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
256
+ h, w = display_heatmap.shape
257
+ heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
258
+ pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
259
+
260
+ st.markdown("""
261
+ <style>
262
+ [data-testid="stDialog"] [data-testid="stSelectbox"], [data-testid="stExpander"] [data-testid="stSelectbox"],
263
+ [data-testid="stDialog"] [data-testid="stSelectbox"] > div, [data-testid="stExpander"] [data-testid="stSelectbox"] > div {
264
+ width: 100% !important; max-width: 100% !important;
265
+ }
266
+ [data-testid="stDialog"] [data-testid="stMetric"] label, [data-testid="stDialog"] [data-testid="stMetric"] [data-testid="stMetricValue"],
267
+ [data-testid="stExpander"] [data-testid="stMetric"] label, [data-testid="stExpander"] [data-testid="stMetric"] [data-testid="stMetricValue"] {
268
+ font-size: 0.95rem !important;
269
+ }
270
+ [data-testid="stDialog"] img, [data-testid="stExpander"] img { border-radius: 0 !important; }
271
+ </style>
272
+ """, unsafe_allow_html=True)
273
+
274
+ if bf_img is not None:
275
+ bf_resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE))
276
+ bf_rgb = cv2.cvtColor(bf_resized, cv2.COLOR_GRAY2RGB) if bf_img.ndim == 2 else cv2.cvtColor(bf_resized, cv2.COLOR_BGR2RGB)
277
+ left_col, right_col = st.columns(2, gap=None)
278
+ with left_col:
279
+ draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x], key=f"draw_mode_region_{key_suffix}")
280
+ st.caption("Left-click add, right-click close. \nForce map (draw region)")
281
+ canvas_result = st_canvas(
282
+ fill_color="rgba(0, 188, 212, 0.25)", stroke_width=2, stroke_color="#00bcd4",
283
+ background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
284
+ height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
285
+ key=f"region_measure_canvas_{key_suffix}",
286
+ )
287
+ with right_col:
288
+ vals = cell_vals if cell_vals else original_vals
289
+ if vals:
290
+ label = "Cell area" if cell_vals else "Full map"
291
+ st.markdown(f'<p style="font-weight: 400; color: #334155; font-size: 0.95rem; margin: 0 20px 4px 4px;">{label}</p>', unsafe_allow_html=True)
292
+ st.markdown(f"""
293
+ <div style="width: 100%; box-sizing: border-box; border: 1px solid #e2e8f0; border-radius: 10px;
294
+ padding: 10px 12px; margin: 0 10px 20px 10px; background: linear-gradient(145deg, #f8fafc 0%, #f1f5f9 100%);
295
+ box-shadow: 0 1px 3px rgba(0,0,0,0.06);">
296
+ <div style="display: flex; flex-wrap: wrap; gap: 5px; font-size: 0.9rem;">
297
+ <span><strong>Sum:</strong> {vals['pixel_sum']:.1f}</span>
298
+ <span><strong>Force:</strong> {vals['force']:.1f}</span>
299
+ <span><strong>Max:</strong> {vals['max']:.3f}</span>
300
+ <span><strong>Mean:</strong> {vals['mean']:.3f}</span>
301
+ </div>
302
+ </div>
303
+ """, unsafe_allow_html=True)
304
+ st.caption("Bright-field")
305
+ bf_display = bf_rgb.copy()
306
+ if cell_mask is not None and np.any(cell_mask > 0):
307
+ bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=5)
308
+ st.image(bf_display, width=CANVAS_SIZE)
309
+ else:
310
+ st.markdown("**Draw a region** on the heatmap.")
311
+ draw_mode = st.selectbox("Drawing tool", DRAW_TOOLS,
312
+ format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x],
313
+ key=f"draw_mode_region_{key_suffix}")
314
+ st.caption("Polygon: left-click to add points, right-click to close.")
315
+ canvas_result = st_canvas(
316
+ fill_color="rgba(0, 188, 212, 0.25)", stroke_width=2, stroke_color="#00bcd4",
317
+ background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
318
+ height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
319
+ key=f"region_measure_canvas_{key_suffix}",
320
+ )
321
+
322
+ if canvas_result.json_data:
323
+ masks = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
324
+ if masks:
325
+ metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
326
+ if cell_mask is not None and np.any(cell_mask > 0):
327
+ cell_metrics = compute_region_metrics(raw_heatmap, cell_mask, original_vals)
328
+ metrics_list = [cell_metrics] + metrics_list
329
+ render_region_metrics_and_downloads(
330
+ metrics_list, masks, heatmap_rgb, input_filename, key_suffix, original_vals is not None,
331
+ first_region_label="Auto boundary" if (cell_mask is not None and np.any(cell_mask > 0)) else None,
332
+ bf_img=bf_img, cell_mask=cell_mask, colormap_name=colormap_name,
333
+ )
S2FApp/ui/result_display.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Result display: single and batch prediction views."""
2
+ import csv
3
+ import io
4
+ import os
5
+ import zipfile
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import streamlit as st
10
+ import plotly.graph_objects as go
11
+ from plotly.subplots import make_subplots
12
+
13
+ from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale
14
+ from utils.report import heatmap_to_rgb_with_contour, heatmap_to_png_bytes, create_pdf_report
15
+ from utils.segmentation import estimate_cell_mask
16
+ from ui.heatmaps import render_horizontal_colorbar, add_cell_contour_to_fig
17
+ from ui.measure_tool import (
18
+ build_original_vals,
19
+ build_cell_vals,
20
+ render_region_canvas,
21
+ _compute_cell_metrics,
22
+ HAS_DRAWABLE_CANVAS,
23
+ )
24
+
25
+
26
+ def render_batch_results(batch_results, colormap_name="Jet", display_mode="Default",
27
+ min_percentile=0, max_percentile=100, clip_min=0, clip_max=1,
28
+ auto_cell_boundary=False, clip_bounds=False):
29
+ """
30
+ Render batch prediction results: summary table, bright-field row, heatmap row, and bulk download.
31
+ batch_results: list of dicts with img, heatmap, force, pixel_sum, key_img, cell_mask.
32
+ cell_mask is computed on-the-fly when auto_cell_boundary is True and not stored.
33
+ """
34
+ if not batch_results:
35
+ return
36
+
37
+ # Resolve cell_mask and precompute display_heatmap for each result
38
+ for r in batch_results:
39
+ if auto_cell_boundary and (r.get("cell_mask") is None or not np.any(r.get("cell_mask", 0) > 0)):
40
+ r["_cell_mask"] = estimate_cell_mask(r["heatmap"])
41
+ else:
42
+ r["_cell_mask"] = r.get("cell_mask") if auto_cell_boundary else None
43
+ r["_display_heatmap"] = apply_display_scale(
44
+ r["heatmap"], display_mode,
45
+ min_percentile=min_percentile, max_percentile=max_percentile,
46
+ clip_min=clip_min, clip_max=clip_max, clip_bounds=clip_bounds,
47
+ )
48
+ # Build table rows - consistent column names for both modes
49
+ headers = ["Image", "Force", "Sum", "Max", "Mean"]
50
+ rows = []
51
+ csv_rows = [["image"] + headers[1:]]
52
+ for r in batch_results:
53
+ heatmap = r["heatmap"]
54
+ cell_mask = r.get("_cell_mask")
55
+ key = r["key_img"] or "image"
56
+ if auto_cell_boundary and cell_mask is not None and np.any(cell_mask > 0):
57
+ vals = heatmap[cell_mask > 0]
58
+ cell_pixel_sum = float(np.sum(vals))
59
+ cell_force = cell_pixel_sum * (r["force"] / r["pixel_sum"]) if r["pixel_sum"] > 0 else cell_pixel_sum
60
+ cell_mean = cell_pixel_sum / np.sum(cell_mask) if np.sum(cell_mask) > 0 else 0
61
+ row = [key, f"{cell_force:.2f}", f"{cell_pixel_sum:.2f}",
62
+ f"{np.max(heatmap):.4f}", f"{cell_mean:.4f}"]
63
+ else:
64
+ row = [key, f"{r['force']:.2f}", f"{r['pixel_sum']:.2f}",
65
+ f"{np.max(heatmap):.4f}", f"{np.mean(heatmap):.4f}"]
66
+ rows.append(row)
67
+ csv_rows.append([os.path.splitext(key)[0]] + row[1:])
68
+ st.markdown('<div class="result-label"><span class="result-badge input">INPUT</span> Bright-field images</div>', unsafe_allow_html=True)
69
+ n_cols = min(5, len(batch_results))
70
+ bf_cols = st.columns(n_cols)
71
+ for i, r in enumerate(batch_results):
72
+ img = r["img"]
73
+ if img.ndim == 2:
74
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
75
+ else:
76
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
77
+ with bf_cols[i % n_cols]:
78
+ st.image(img_rgb, caption=r["key_img"], use_container_width=True)
79
+ is_rescale_b = display_mode == "Range" and clip_max > clip_min and not (clip_min == 0 and clip_max == 1)
80
+ st.markdown('<div class="result-label"><span class="result-badge output">OUTPUT</span> Predicted force maps</div>', unsafe_allow_html=True)
81
+ hm_cols = st.columns(n_cols)
82
+ for i, r in enumerate(batch_results):
83
+ hm_rgb = heatmap_to_rgb_with_contour(
84
+ r["_display_heatmap"], colormap_name,
85
+ r.get("_cell_mask") if auto_cell_boundary else None,
86
+ )
87
+ with hm_cols[i % n_cols]:
88
+ st.image(hm_rgb, caption=r["key_img"], use_container_width=True)
89
+ render_horizontal_colorbar(colormap_name, clip_min, clip_max, is_rescale_b)
90
+ # Table
91
+ st.dataframe(
92
+ {h: [r[i] for r in rows] for i, h in enumerate(headers)},
93
+ use_container_width=True,
94
+ hide_index=True,
95
+ )
96
+ # Histograms in accordion (one per row for visibility)
97
+ with st.expander("Force distribution (histograms)", expanded=False):
98
+ for i, r in enumerate(batch_results):
99
+ heatmap = r["heatmap"]
100
+ cell_mask = r.get("_cell_mask")
101
+ vals = heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and auto_cell_boundary) else heatmap.flatten()
102
+ vals = vals[vals > 0] if np.any(vals > 0) else vals
103
+ st.markdown(f"**{r['key_img']}**")
104
+ if len(vals) > 0:
105
+ fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
106
+ fig.update_layout(
107
+ height=220, margin=dict(l=40, r=20, t=10, b=40),
108
+ xaxis_title="Force value", yaxis_title="Count",
109
+ showlegend=False,
110
+ )
111
+ st.plotly_chart(fig, use_container_width=True, config={"displayModeBar": False})
112
+ else:
113
+ st.caption("No data")
114
+ if i < len(batch_results) - 1:
115
+ st.divider()
116
+ # Bulk downloads: CSV and heatmaps (zip)
117
+ buf_csv = io.StringIO()
118
+ csv.writer(buf_csv).writerows(csv_rows)
119
+ zip_buf = io.BytesIO()
120
+ with zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as zf:
121
+ for r in batch_results:
122
+ hm_bytes = heatmap_to_png_bytes(
123
+ r["_display_heatmap"], colormap_name,
124
+ r.get("_cell_mask") if auto_cell_boundary else None,
125
+ )
126
+ base = os.path.splitext(r["key_img"] or "image")[0]
127
+ zf.writestr(f"{base}_heatmap.png", hm_bytes.getvalue())
128
+ zip_buf.seek(0)
129
+ dl_col1, dl_col2 = st.columns(2)
130
+ with dl_col1:
131
+ st.download_button(
132
+ "Download all as CSV",
133
+ data=buf_csv.getvalue(),
134
+ file_name="s2f_batch_results.csv",
135
+ mime="text/csv",
136
+ key="download_batch_csv",
137
+ icon=":material/download:",
138
+ )
139
+ with dl_col2:
140
+ st.download_button(
141
+ "Download all heatmaps",
142
+ data=zip_buf.getvalue(),
143
+ file_name="s2f_batch_heatmaps.zip",
144
+ mime="application/zip",
145
+ key="download_batch_heatmaps",
146
+ icon=":material/image:",
147
+ )
148
+
149
+
150
+ def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
151
+ colormap_name="Jet", display_mode="Default", measure_region_dialog=None, auto_cell_boundary=True,
152
+ cell_mask=None, clip_min=0.0, clip_max=1.0, clip_bounds=False):
153
+ """
154
+ Render prediction result: plot, metrics, expander, and download/measure buttons.
155
+ measure_region_dialog: callable to open measure dialog (when ST_DIALOG available).
156
+ auto_cell_boundary: when True, use estimated cell area for metrics; when False, use entire map.
157
+ cell_mask: optional precomputed cell mask; if None and auto_cell_boundary, will be computed.
158
+ """
159
+ if cell_mask is None and auto_cell_boundary:
160
+ cell_mask = estimate_cell_mask(raw_heatmap)
161
+ elif not auto_cell_boundary:
162
+ cell_mask = None
163
+ cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force) if cell_mask is not None else (None, None, None)
164
+ use_cell_metrics = auto_cell_boundary and cell_pixel_sum is not None and cell_force is not None and cell_mean is not None
165
+
166
+ base_name = os.path.splitext(key_img or "image")[0]
167
+ if use_cell_metrics:
168
+ main_csv_rows = [
169
+ ["image", "Cell sum", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
170
+ [base_name, f"{cell_pixel_sum:.2f}", f"{cell_force:.2f}",
171
+ f"{np.max(raw_heatmap):.4f}", f"{cell_mean:.4f}"],
172
+ ]
173
+ else:
174
+ main_csv_rows = [
175
+ ["image", "Sum of all pixels", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
176
+ [base_name, f"{pixel_sum:.2f}", f"{force:.2f}",
177
+ f"{np.max(raw_heatmap):.4f}", f"{np.mean(raw_heatmap):.4f}"],
178
+ ]
179
+ buf_main_csv = io.StringIO()
180
+ csv.writer(buf_main_csv).writerows(main_csv_rows)
181
+
182
+ buf_hm = heatmap_to_png_bytes(display_heatmap, colormap_name, cell_mask=cell_mask)
183
+
184
+ is_rescale = display_mode == "Range" and clip_max > clip_min and not (clip_min == 0.0 and clip_max == 1.0)
185
+
186
+ tit1, tit2 = st.columns(2)
187
+ with tit1:
188
+ st.markdown('<div class="result-label"><span class="result-badge input">INPUT</span> Bright-field image</div>', unsafe_allow_html=True)
189
+ with tit2:
190
+ st.markdown('<div class="result-label"><span class="result-badge output">OUTPUT</span> Predicted force map</div>', unsafe_allow_html=True)
191
+ fig_pl = make_subplots(rows=1, cols=2)
192
+ fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
193
+ plotly_colorscale = cv_colormap_to_plotly_colorscale(colormap_name)
194
+ colorbar_cfg = dict(len=0.4, thickness=12, tickmode="array")
195
+ tick_positions = [0, 0.25, 0.5, 0.75, 1]
196
+ if is_rescale:
197
+ rng = clip_max - clip_min
198
+ colorbar_cfg["tickvals"] = tick_positions
199
+ colorbar_cfg["ticktext"] = [f"{clip_min + t * rng:.2f}" for t in tick_positions]
200
+ else:
201
+ colorbar_cfg["tickvals"] = tick_positions
202
+ colorbar_cfg["ticktext"] = [f"{t:.2f}" for t in tick_positions]
203
+ fig_pl.add_trace(go.Heatmap(z=display_heatmap, colorscale=plotly_colorscale, zmin=0.0, zmax=1.0, showscale=True,
204
+ colorbar=colorbar_cfg), row=1, col=2)
205
+ add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2)
206
+ fig_pl.update_layout(
207
+ height=400,
208
+ margin=dict(l=10, r=10, t=10, b=10),
209
+ xaxis=dict(scaleanchor="y", scaleratio=1),
210
+ xaxis2=dict(scaleanchor="y2", scaleratio=1),
211
+ )
212
+ fig_pl.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
213
+ fig_pl.update_yaxes(showticklabels=False, autorange="reversed", showgrid=False, zeroline=False)
214
+ st.plotly_chart(fig_pl, use_container_width=True, config={"displayModeBar": True, "responsive": True})
215
+
216
+ col1, col2, col3, col4 = st.columns(4)
217
+ if use_cell_metrics:
218
+ with col1:
219
+ st.metric("Cell sum", f"{cell_pixel_sum:.2f}", help="Sum over estimated cell area (background excluded)")
220
+ with col2:
221
+ st.metric("Cell force (scaled)", f"{cell_force:.2f}", help="Total traction force in physical units")
222
+ with col3:
223
+ st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
224
+ with col4:
225
+ st.metric("Heatmap mean", f"{cell_mean:.4f}", help="Mean force over estimated cell area")
226
+ else:
227
+ with col1:
228
+ st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
229
+ with col2:
230
+ st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
231
+ with col3:
232
+ st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
233
+ with col4:
234
+ st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
235
+
236
+ # Statistics panel (mean, std, percentiles, histogram)
237
+ with st.expander("Statistics"):
238
+ vals = raw_heatmap[cell_mask > 0] if (cell_mask is not None and np.any(cell_mask > 0) and use_cell_metrics) else raw_heatmap.flatten()
239
+ if len(vals) > 0:
240
+ stat_col1, stat_col2, stat_col3 = st.columns(3)
241
+ p25, p50, p75, p90 = (
242
+ float(np.percentile(vals, 25)), float(np.percentile(vals, 50)),
243
+ float(np.percentile(vals, 75)), float(np.percentile(vals, 90)),
244
+ )
245
+ with stat_col1:
246
+ st.metric("Mean", f"{float(np.mean(vals)):.4f}")
247
+ st.metric("Std", f"{float(np.std(vals)):.4f}")
248
+ with stat_col2:
249
+ st.metric("P25", f"{p25:.4f}")
250
+ st.metric("P50 (median)", f"{p50:.4f}")
251
+ with stat_col3:
252
+ st.metric("P75", f"{p75:.4f}")
253
+ st.metric("P90", f"{p90:.4f}")
254
+ st.markdown("**Histogram**")
255
+ hist_fig = go.Figure(data=[go.Histogram(x=vals, nbinsx=50, marker_color="#0d9488")])
256
+ hist_fig.update_layout(
257
+ height=220, margin=dict(l=40, r=20, t=20, b=40),
258
+ xaxis_title="Force value", yaxis_title="Count",
259
+ showlegend=False,
260
+ )
261
+ st.plotly_chart(hist_fig, use_container_width=True, config={"displayModeBar": False})
262
+ else:
263
+ st.caption("No nonzero values to compute statistics.")
264
+
265
+ with st.expander("How to read the results"):
266
+ if use_cell_metrics:
267
+ st.markdown("""
268
+ **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
269
+ This is the raw image you provided—it shows cell shape but not forces.
270
+
271
+ **Output (right):** Predicted traction force map.
272
+ - **Color** indicates force magnitude: blue = low, red = high
273
+ - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
274
+ - **Red border = estimated cell area** (background excluded from metrics)
275
+ - Values are normalized to [0, 1] for visualization
276
+
277
+ **Metrics (auto cell boundary on):**
278
+ - **Cell sum:** Sum over estimated cell area (background excluded)
279
+ - **Cell force (scaled):** Total traction force in physical units
280
+ - **Heatmap max:** Peak force intensity in the map
281
+ - **Heatmap mean:** Mean force over the estimated cell area
282
+ """)
283
+ else:
284
+ st.markdown("""
285
+ **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
286
+ This is the raw image you provided—it shows cell shape but not forces.
287
+
288
+ **Output (right):** Predicted traction force map.
289
+ - **Color** indicates force magnitude: blue = low, red = high
290
+ - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
291
+ - Values are normalized to [0, 1] for visualization
292
+
293
+ **Metrics (auto cell boundary off):**
294
+ - **Sum of all pixels:** Raw sum over entire map
295
+ - **Cell force (scaled):** Total traction force in physical units
296
+ - **Heatmap max/mean:** Peak and average force intensity (full field of view)
297
+ """)
298
+
299
+ original_vals = build_original_vals(raw_heatmap, pixel_sum, force)
300
+
301
+ pdf_bytes = create_pdf_report(
302
+ img, display_heatmap, raw_heatmap, pixel_sum, force, base_name, colormap_name,
303
+ cell_mask=cell_mask, cell_pixel_sum=cell_pixel_sum, cell_force=cell_force, cell_mean=cell_mean
304
+ )
305
+
306
+ btn_col1, btn_col2, btn_col3, btn_col4 = st.columns(4)
307
+ with btn_col1:
308
+ if HAS_DRAWABLE_CANVAS and measure_region_dialog is not None:
309
+ if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
310
+ st.session_state["open_measure_dialog"] = True
311
+ st.rerun()
312
+ elif HAS_DRAWABLE_CANVAS:
313
+ with st.expander("Measure tool"):
314
+ expander_cell_vals = build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force) if (auto_cell_boundary and cell_mask is not None) else None
315
+ expander_cell_mask = cell_mask if auto_cell_boundary else None
316
+ render_region_canvas(
317
+ display_heatmap,
318
+ raw_heatmap=raw_heatmap,
319
+ bf_img=img,
320
+ original_vals=original_vals,
321
+ cell_vals=expander_cell_vals,
322
+ cell_mask=expander_cell_mask,
323
+ key_suffix="expander",
324
+ input_filename=key_img,
325
+ colormap_name=colormap_name,
326
+ )
327
+ else:
328
+ st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
329
+ with btn_col2:
330
+ st.download_button(
331
+ "Download heatmap",
332
+ width="stretch",
333
+ data=buf_hm.getvalue(),
334
+ file_name="s2f_heatmap.png",
335
+ mime="image/png",
336
+ key=f"download_heatmap{download_key_suffix}",
337
+ icon=":material/download:",
338
+ )
339
+ with btn_col3:
340
+ st.download_button(
341
+ "Download values",
342
+ width="stretch",
343
+ data=buf_main_csv.getvalue(),
344
+ file_name=f"{base_name}_main_values.csv",
345
+ mime="text/csv",
346
+ key=f"download_main_values{download_key_suffix}",
347
+ icon=":material/download:",
348
+ )
349
+ with btn_col4:
350
+ st.download_button(
351
+ "Download report",
352
+ width="stretch",
353
+ data=pdf_bytes,
354
+ file_name=f"{base_name}_report.pdf",
355
+ mime="application/pdf",
356
+ key=f"download_pdf{download_key_suffix}",
357
+ icon=":material/picture_as_pdf:",
358
+ )
S2FApp/ui/system_status.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """System status UI component (CPU/memory)."""
2
+ import streamlit as st
3
+
4
+ try:
5
+ import psutil
6
+ HAS_PSUTIL = True
7
+ except ImportError:
8
+ HAS_PSUTIL = False
9
+
10
+
11
+ def _get_container_memory():
12
+ """
13
+ Read memory from cgroups when running in a container (Docker, HF Spaces).
14
+ psutil reports host memory in containers, which can be misleading (e.g. 128 GB vs 16 GB limit).
15
+ Returns (used_bytes, total_bytes) or None to fall back to psutil.
16
+ """
17
+ try:
18
+ # cgroup v2 (modern Docker, HF Spaces)
19
+ for base in ("/sys/fs/cgroup", "/sys/fs/cgroup/self"):
20
+ try:
21
+ with open(f"{base}/memory.max", "r") as f:
22
+ max_val = f.read().strip()
23
+ if max_val == "max":
24
+ return None # No limit, use psutil
25
+ total = int(max_val)
26
+ with open(f"{base}/memory.current", "r") as f:
27
+ used = int(f.read().strip())
28
+ return (used, total)
29
+ except (FileNotFoundError, ValueError):
30
+ continue
31
+ # cgroup v1
32
+ try:
33
+ with open("/sys/fs/cgroup/memory/memory.limit_in_bytes", "r") as f:
34
+ total = int(f.read().strip())
35
+ with open("/sys/fs/cgroup/memory/memory.usage_in_bytes", "r") as f:
36
+ used = int(f.read().strip())
37
+ if total > 2**50: # Often 9223372036854771712 when unlimited
38
+ return None
39
+ return (used, total)
40
+ except (FileNotFoundError, ValueError):
41
+ pass
42
+ except Exception:
43
+ pass
44
+ return None
45
+
46
+
47
+ def render_system_status():
48
+ """Render CPU/memory status in the sidebar (always visible)."""
49
+ if not HAS_PSUTIL:
50
+ return
51
+ try:
52
+ cpu = psutil.cpu_percent(interval=0.1)
53
+ container_mem = _get_container_memory()
54
+ if container_mem is not None:
55
+ used_bytes, total_bytes = container_mem
56
+ mem_used_gb = used_bytes / (1024**3)
57
+ mem_total_gb = total_bytes / (1024**3)
58
+ mem_pct = 100 * used_bytes / total_bytes if total_bytes > 0 else 0
59
+ else:
60
+ mem = psutil.virtual_memory()
61
+ mem_used_gb = mem.used / (1024**3)
62
+ mem_total_gb = mem.total / (1024**3)
63
+ mem_pct = mem.percent
64
+ st.sidebar.markdown(
65
+ f"""
66
+ <div class="system-status">
67
+ <span class="status-dot"></span>
68
+ <span><strong>System</strong>&ensp;CPU {cpu:.0f}%&ensp;·&ensp;Mem {mem_pct:.0f}% ({mem_used_gb:.1f}/{mem_total_gb:.1f} GB)</span>
69
+ </div>
70
+ """,
71
+ unsafe_allow_html=True,
72
+ )
73
+ except Exception:
74
+ pass
S2FApp/utils/display.py CHANGED
@@ -19,12 +19,50 @@ def cv_colormap_to_plotly_colorscale(colormap_name, n_samples=None):
19
  return scale
20
 
21
 
22
- def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, clip_min=0, clip_max=1):
23
  """
24
- Apply display scaling (Fiji/ImageJ style). Display only—does not change underlying values.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  - Default: full 0–1 range as-is.
26
- - Percentile: map min..max percentiles to 0..1.
27
- - Range: show only values in [clip_min, clip_max]; others hidden (black).
 
 
28
  """
29
  if mode == "Default" or mode == "Auto" or mode == "Full":
30
  return np.clip(heatmap, 0, 1).astype(np.float32)
@@ -36,9 +74,24 @@ def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100, cli
36
  return np.clip(out, 0, 1).astype(np.float32)
37
  return np.clip(heatmap, 0, 1).astype(np.float32)
38
  if mode == "Range" or mode == "Filter" or mode == "Threshold":
 
 
 
 
 
 
 
 
 
 
 
 
39
  vmin, vmax = float(clip_min), float(clip_max)
40
  if vmax > vmin:
41
  h = heatmap.astype(np.float32)
 
 
 
42
  mask = (h >= vmin) & (h <= vmax)
43
  out = np.zeros_like(h)
44
  out[mask] = (h[mask] - vmin) / (vmax - vmin)
 
19
  return scale
20
 
21
 
22
+ def build_range_colorscale(colormap_name, clip_min, clip_max, n_range_samples=32):
23
  """
24
+ Build a Plotly colorscale for Range mode: normal gradient in [clip_min, clip_max],
25
+ the "zero" color everywhere else (0 → clip_min and clip_max → 1).
26
+ """
27
+ cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
28
+
29
+ zero_px = np.array([[0]], dtype=np.uint8)
30
+ zero_rgb = cv2.applyColorMap(zero_px, cv2_cmap)
31
+ zero_rgb = cv2.cvtColor(zero_rgb, cv2.COLOR_BGR2RGB)
32
+ zr, zg, zb = zero_rgb[0, 0]
33
+ zero_color = f"rgb({zr},{zg},{zb})"
34
+
35
+ eps = 0.0005
36
+ scale = []
37
+
38
+ scale.append([0.0, zero_color])
39
+ if clip_min > eps:
40
+ scale.append([clip_min - eps, zero_color])
41
+
42
+ positions = np.linspace(clip_min, clip_max, n_range_samples)
43
+ pixel_vals = np.clip((positions * 255).astype(np.uint8), 0, 255).reshape(1, -1)
44
+ rgb = cv2.applyColorMap(pixel_vals, cv2_cmap)
45
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
46
+ for i, pos in enumerate(positions):
47
+ r, g, b = rgb[0, i]
48
+ scale.append([float(pos), f"rgb({r},{g},{b})"])
49
+
50
+ if clip_max < 1.0 - eps:
51
+ scale.append([clip_max + eps, zero_color])
52
+ scale.append([1.0, zero_color])
53
+
54
+ return scale
55
+
56
+
57
+ def apply_display_scale(heatmap, mode, min_percentile=0, max_percentile=100,
58
+ clip_min=0, clip_max=1, clip_bounds=False):
59
+ """
60
+ Apply display scaling. Display only—does not change underlying values.
61
  - Default: full 0–1 range as-is.
62
+ - Range: keep original values inside [clip_min, clip_max].
63
+ clip_bounds=False zero out outside. clip_bounds=True clamp to bounds.
64
+ - Rescale: map [clip_min, clip_max] → [0, 1].
65
+ clip_bounds=False → zero out outside. clip_bounds=True → clamp to bounds first.
66
  """
67
  if mode == "Default" or mode == "Auto" or mode == "Full":
68
  return np.clip(heatmap, 0, 1).astype(np.float32)
 
74
  return np.clip(out, 0, 1).astype(np.float32)
75
  return np.clip(heatmap, 0, 1).astype(np.float32)
76
  if mode == "Range" or mode == "Filter" or mode == "Threshold":
77
+ # Range: filter (discard outside) + rescale [clip_min, clip_max] → [0, 1] so max shows as red
78
+ vmin, vmax = float(clip_min), float(clip_max)
79
+ if vmax > vmin:
80
+ h = heatmap.astype(np.float32)
81
+ if clip_bounds:
82
+ return np.clip(h, vmin, vmax).astype(np.float32)
83
+ mask = (h >= vmin) & (h <= vmax)
84
+ out = np.zeros_like(h)
85
+ out[mask] = (h[mask] - vmin) / (vmax - vmin)
86
+ return np.clip(out, 0, 1).astype(np.float32)
87
+ return np.clip(heatmap, 0, 1).astype(np.float32)
88
+ if mode == "Rescale":
89
  vmin, vmax = float(clip_min), float(clip_max)
90
  if vmax > vmin:
91
  h = heatmap.astype(np.float32)
92
+ if clip_bounds:
93
+ clamped = np.clip(h, vmin, vmax)
94
+ return ((clamped - vmin) / (vmax - vmin)).astype(np.float32)
95
  mask = (h >= vmin) & (h <= vmax)
96
  out = np.zeros_like(h)
97
  out[mask] = (h[mask] - vmin) / (vmax - vmin)
S2FApp/utils/report.py CHANGED
@@ -48,27 +48,34 @@ def _pdf_image_layout(page_w_pt, page_h_pt, margin=72, n_images=2):
48
  }
49
 
50
 
51
- def heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
52
- """Convert scaled heatmap (float 0-1) to RGB array using the given colormap."""
53
- heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
 
 
 
 
 
 
54
  cv2_colormap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
55
  heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_uint8, cv2_colormap), cv2.COLOR_BGR2RGB)
56
  return heatmap_rgb
57
 
58
 
59
- def heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name="Jet", cell_mask=None):
60
  """Convert heatmap to RGB, optionally drawing red cell contour. Mask must match heatmap shape."""
61
- heatmap_rgb = heatmap_to_rgb(scaled_heatmap, colormap_name)
62
  if cell_mask is not None and np.any(cell_mask > 0):
63
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
64
  if contours:
65
- cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
66
  return heatmap_rgb
67
 
68
 
69
- def heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet", cell_mask=None):
70
- """Convert scaled heatmap (float 0-1) to PNG bytes buffer. Optionally draw red cell contour."""
71
- heatmap_rgb = heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name, cell_mask)
 
72
  buf = io.BytesIO()
73
  Image.fromarray(heatmap_rgb).save(buf, format="PNG")
74
  buf.seek(0)
@@ -76,7 +83,7 @@ def heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet", cell_mask=None):
76
 
77
 
78
  def create_pdf_report(img, display_heatmap, raw_heatmap, pixel_sum, force, base_name, colormap_name="Jet",
79
- cell_mask=None, cell_pixel_sum=None, cell_force=None, cell_mean=None):
80
  """Create a PDF report with input image, heatmap, and metrics."""
81
  buf = io.BytesIO()
82
  c = canvas.Canvas(buf, pagesize=A4)
@@ -108,7 +115,7 @@ def create_pdf_report(img, display_heatmap, raw_heatmap, pixel_sum, force, base_
108
  bf_label_w = c.stringWidth("Bright-field", "Helvetica", 9)
109
  c.drawString(bf_x + (img_w - bf_label_w) / 2, img_bottom - 14, "Bright-field")
110
 
111
- heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
112
  hm_buf = io.BytesIO()
113
  Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
114
  hm_buf.seek(0)
 
48
  }
49
 
50
 
51
+ def heatmap_to_rgb(scaled_heatmap, colormap_name="Jet", zmin=None, zmax=None):
52
+ """Convert scaled heatmap to RGB array using the given colormap.
53
+ If zmin, zmax are provided (e.g. for Range mode), map [zmin,zmax] to 0-1 for coloring."""
54
+ arr = np.asarray(scaled_heatmap, dtype=np.float32)
55
+ if zmin is not None and zmax is not None and zmax > zmin:
56
+ arr = np.clip((arr - zmin) / (zmax - zmin), 0, 1)
57
+ else:
58
+ arr = np.clip(arr, 0, 1)
59
+ heatmap_uint8 = (arr * 255).astype(np.uint8)
60
  cv2_colormap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
61
  heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_uint8, cv2_colormap), cv2.COLOR_BGR2RGB)
62
  return heatmap_rgb
63
 
64
 
65
+ def heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name="Jet", cell_mask=None, zmin=None, zmax=None):
66
  """Convert heatmap to RGB, optionally drawing red cell contour. Mask must match heatmap shape."""
67
+ heatmap_rgb = heatmap_to_rgb(scaled_heatmap, colormap_name, zmin=zmin, zmax=zmax)
68
  if cell_mask is not None and np.any(cell_mask > 0):
69
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
70
  if contours:
71
+ cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 5)
72
  return heatmap_rgb
73
 
74
 
75
+ def heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet", cell_mask=None, zmin=None, zmax=None):
76
+ """Convert scaled heatmap to PNG bytes buffer. Optionally draw red cell contour.
77
+ If zmin, zmax provided (Range mode), map that range to full colormap."""
78
+ heatmap_rgb = heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name, cell_mask, zmin=zmin, zmax=zmax)
79
  buf = io.BytesIO()
80
  Image.fromarray(heatmap_rgb).save(buf, format="PNG")
81
  buf.seek(0)
 
83
 
84
 
85
  def create_pdf_report(img, display_heatmap, raw_heatmap, pixel_sum, force, base_name, colormap_name="Jet",
86
+ cell_mask=None, cell_pixel_sum=None, cell_force=None, cell_mean=None, zmin=None, zmax=None):
87
  """Create a PDF report with input image, heatmap, and metrics."""
88
  buf = io.BytesIO()
89
  c = canvas.Canvas(buf, pagesize=A4)
 
115
  bf_label_w = c.stringWidth("Bright-field", "Helvetica", 9)
116
  c.drawString(bf_x + (img_w - bf_label_w) / 2, img_bottom - 14, "Bright-field")
117
 
118
+ heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask, zmin=zmin, zmax=zmax)
119
  hm_buf = io.BytesIO()
120
  Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
121
  hm_buf.seek(0)
S2FApp/utils/segmentation.py CHANGED
@@ -2,7 +2,7 @@
2
  import numpy as np
3
  from scipy.ndimage import gaussian_filter
4
  from skimage.filters import threshold_otsu
5
- from skimage.morphology import binary_closing, binary_opening, binary_dilation, remove_small_objects, disk
6
  from skimage.measure import label, regionprops
7
 
8
 
@@ -37,9 +37,9 @@ def estimate_cell_mask(heatmap, sigma=2, min_size=200, exclude_full_image=True,
37
  mask = (smoothed > thresh).astype(np.uint8)
38
 
39
  # Morphological cleanup
40
- mask = binary_closing(mask, disk(5)).astype(np.uint8)
41
- mask = binary_opening(mask, disk(3)).astype(np.uint8)
42
- mask = remove_small_objects(mask.astype(bool), min_size=min_size).astype(np.uint8)
43
 
44
  # Select component: second largest if largest is whole image
45
  labeled = label(mask)
@@ -60,6 +60,6 @@ def estimate_cell_mask(heatmap, sigma=2, min_size=200, exclude_full_image=True,
60
 
61
  # Dilate to include surrounding pixels
62
  if dilate_radius > 0:
63
- mask = binary_dilation(mask, disk(dilate_radius)).astype(np.uint8)
64
 
65
  return mask
 
2
  import numpy as np
3
  from scipy.ndimage import gaussian_filter
4
  from skimage.filters import threshold_otsu
5
+ from skimage.morphology import closing, opening, dilation, remove_small_objects, disk
6
  from skimage.measure import label, regionprops
7
 
8
 
 
37
  mask = (smoothed > thresh).astype(np.uint8)
38
 
39
  # Morphological cleanup
40
+ mask = closing(mask, disk(5)).astype(np.uint8)
41
+ mask = opening(mask, disk(3)).astype(np.uint8)
42
+ mask = remove_small_objects(mask.astype(bool), max_size=min_size - 1).astype(np.uint8)
43
 
44
  # Select component: second largest if largest is whole image
45
  labeled = label(mask)
 
60
 
61
  # Dilate to include surrounding pixels
62
  if dilate_radius > 0:
63
+ mask = dilation(mask, disk(dilate_radius)).astype(np.uint8)
64
 
65
  return mask