kaveh commited on
Commit
a43edf0
·
1 Parent(s): c7451f4

added colormap

Browse files
Files changed (1) hide show
  1. S2FApp/app.py +139 -36
S2FApp/app.py CHANGED
@@ -34,6 +34,29 @@ DRAW_TOOLS = ["polygon", "rect", "circle"]
34
  TOOL_LABELS = {"polygon": "Polygon", "rect": "Rectangle", "circle": "Circle"}
35
  CANVAS_SIZE = 320
36
  SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  CITATION = (
38
  "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
39
  "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
@@ -129,17 +152,79 @@ def _parse_canvas_shapes_to_mask(json_data, canvas_h, canvas_w, heatmap_h, heatm
129
  return mask, count
130
 
131
 
132
- def _heatmap_to_png_bytes(scaled_heatmap):
133
- """Convert scaled heatmap (float 0-1) to PNG bytes buffer."""
134
  heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
135
- heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
136
- heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
 
 
 
 
 
 
137
  buf = io.BytesIO()
138
  Image.fromarray(heatmap_rgb).save(buf, format="PNG")
139
  buf.seek(0)
140
  return buf
141
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale):
144
  """Build original_vals dict for measure tool."""
145
  return {
@@ -150,9 +235,9 @@ def _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale):
150
  }
151
 
152
 
153
- def _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix=""):
154
  """Render prediction result: plot, metrics, expander, and download/measure buttons."""
155
- buf_hm = _heatmap_to_png_bytes(scaled_heatmap)
156
  base_name = os.path.splitext(key_img or "image")[0]
157
  main_csv_rows = [
158
  ["image", "Sum of all pixels", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
@@ -169,7 +254,8 @@ def _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, k
169
  st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
170
  fig_pl = make_subplots(rows=1, cols=2)
171
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
172
- fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
 
173
  colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
174
  fig_pl.update_layout(
175
  height=400,
@@ -208,7 +294,8 @@ This is the raw image you provided—it shows cell shape but not forces.
208
  """)
209
 
210
  original_vals = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
211
- btn_col1, btn_col2, btn_col3 = st.columns(3)
 
212
  with btn_col1:
213
  if HAS_DRAWABLE_CANVAS and st_dialog:
214
  if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
@@ -222,6 +309,7 @@ This is the raw image you provided—it shows cell shape but not forces.
222
  original_vals=original_vals,
223
  key_suffix="expander",
224
  input_filename=key_img,
 
225
  )
226
  else:
227
  st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
@@ -245,6 +333,16 @@ This is the raw image you provided—it shows cell shape but not forces.
245
  key=f"download_main_values{download_key_suffix}",
246
  icon=":material/download:",
247
  )
 
 
 
 
 
 
 
 
 
 
248
 
249
 
250
  def _compute_region_metrics(scaled_heatmap, mask, original_vals=None):
@@ -322,11 +420,10 @@ def _render_region_metrics_and_downloads(metrics, heatmap_rgb, mask, input_filen
322
  key=f"download_annotated_{key_suffix}", icon=":material/image:")
323
 
324
 
325
- def _render_region_canvas(scaled_heatmap, bf_img=None, original_vals=None, key_suffix="", input_filename=None):
326
  """Render drawable canvas and region metrics. Used in dialog or expander."""
327
  h, w = scaled_heatmap.shape
328
- heatmap_display = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
329
- heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_display, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
330
  pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
331
 
332
  st.markdown("""
@@ -404,7 +501,8 @@ if HAS_DRAWABLE_CANVAS and st_dialog:
404
  bf_img = st.session_state.get("measure_bf_img")
405
  original_vals = st.session_state.get("measure_original_vals")
406
  input_filename = st.session_state.get("measure_input_filename", "image")
407
- _render_region_canvas(scaled_heatmap, bf_img=bf_img, original_vals=original_vals, key_suffix="dialog", input_filename=input_filename)
 
408
  else:
409
  def measure_region_dialog():
410
  pass # no-op when canvas or dialog not available
@@ -519,6 +617,11 @@ with st.sidebar:
519
  format="%.2f",
520
  help="Scale the displayed force values. 1 = full intensity, 0.5 = half the pixel values.",
521
  )
 
 
 
 
 
522
 
523
  # Main area: image input
524
  img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
@@ -594,33 +697,32 @@ if just_ran:
594
  checkpoint_path=checkpoint,
595
  ckp_folder=ckp_folder,
596
  )
597
- if img is not None:
598
- sub_val = substrate_val if model_type == "single_cell" and not use_manual else "fibroblasts_PDMS"
599
- heatmap, force, pixel_sum = predictor.predict(
600
- image_array=img,
601
- substrate=sub_val,
602
- substrate_config=substrate_config if model_type == "single_cell" else None,
603
- )
604
 
605
- st.success("Prediction complete!")
606
 
607
- scaled_heatmap = heatmap * force_scale
608
 
609
- # Store result and measure data before rendering (Measure click survives rerun)
610
- cache_key = (model_type, checkpoint, key_img)
611
- st.session_state["prediction_result"] = {
612
- "img": img.copy(),
613
- "heatmap": heatmap.copy(),
614
- "force": force,
615
- "pixel_sum": pixel_sum,
616
- "cache_key": cache_key,
617
- }
618
- st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy()
619
- st.session_state["measure_bf_img"] = img.copy()
620
- st.session_state["measure_input_filename"] = key_img or "image"
621
- st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
622
 
623
- _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img)
624
 
625
  except Exception as e:
626
  st.error(f"Prediction failed: {e}")
@@ -635,12 +737,13 @@ elif has_cached:
635
  st.session_state["measure_bf_img"] = img.copy()
636
  st.session_state["measure_input_filename"] = key_img or "image"
637
  st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
 
638
 
639
  if st.session_state.pop("open_measure_dialog", False):
640
  measure_region_dialog()
641
 
642
  st.success("Prediction complete!")
643
- _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="_cached")
644
 
645
  elif run and not checkpoint:
646
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")
 
34
  TOOL_LABELS = {"polygon": "Polygon", "rect": "Rectangle", "circle": "Circle"}
35
  CANVAS_SIZE = 320
36
  SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
37
+ COLORMAPS = {
38
+ "Jet": cv2.COLORMAP_JET,
39
+ "Viridis": cv2.COLORMAP_VIRIDIS,
40
+ "Plasma": cv2.COLORMAP_PLASMA,
41
+ "Inferno": cv2.COLORMAP_INFERNO,
42
+ "Magma": cv2.COLORMAP_MAGMA,
43
+ }
44
+
45
+
46
+ def _cv_colormap_to_plotly_colorscale(colormap_name, n_samples=64):
47
+ """Build a Plotly colorscale from OpenCV colormap so UI matches download/PDF exactly."""
48
+ cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
49
+ gradient = np.linspace(0, 255, n_samples, dtype=np.uint8).reshape(1, -1)
50
+ rgb = cv2.applyColorMap(gradient, cv2_cmap)
51
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
52
+ # Plotly colorscale: [[position 0..1, 'rgb(r,g,b)'], ...]
53
+ scale = []
54
+ for i in range(n_samples):
55
+ r, g, b = rgb[0, i]
56
+ scale.append([i / (n_samples - 1), f"rgb({r},{g},{b})"])
57
+ return scale
58
+
59
+
60
  CITATION = (
61
  "Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
62
  "**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
 
152
  return mask, count
153
 
154
 
155
+ def _heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
156
+ """Convert scaled heatmap (float 0-1) to RGB array using the given colormap."""
157
  heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
158
+ cv2_colormap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
159
+ heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_uint8, cv2_colormap), cv2.COLOR_BGR2RGB)
160
+ return heatmap_rgb
161
+
162
+
163
+ def _heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet"):
164
+ """Convert scaled heatmap (float 0-1) to PNG bytes buffer."""
165
+ heatmap_rgb = _heatmap_to_rgb(scaled_heatmap, colormap_name)
166
  buf = io.BytesIO()
167
  Image.fromarray(heatmap_rgb).save(buf, format="PNG")
168
  buf.seek(0)
169
  return buf
170
 
171
 
172
+ def _create_pdf_report(img, scaled_heatmap, pixel_sum, force, force_scale, base_name, colormap_name="Jet"):
173
+ """Create a PDF report with input image, heatmap, and metrics."""
174
+ from reportlab.lib.pagesizes import A4
175
+ from reportlab.lib.units import inch
176
+ from reportlab.pdfgen import canvas
177
+ from reportlab.lib.utils import ImageReader
178
+
179
+ buf = io.BytesIO()
180
+ c = canvas.Canvas(buf, pagesize=A4)
181
+ w, h = A4
182
+ img_w, img_h = 2.5 * inch, 2.5 * inch
183
+
184
+ # Images first (drawn lower so title can go on top)
185
+ img_top = h - 70
186
+ img_pil = Image.fromarray(img) if img.ndim == 2 else Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
187
+ img_buf = io.BytesIO()
188
+ img_pil.save(img_buf, format="PNG")
189
+ img_buf.seek(0)
190
+ c.drawImage(ImageReader(img_buf), 72, img_top - img_h, width=img_w, height=img_h, preserveAspectRatio=True)
191
+ c.setFont("Helvetica", 9)
192
+ c.drawString(72, img_top - img_h - 12, "Input: Bright-field")
193
+
194
+ heatmap_rgb = _heatmap_to_rgb(scaled_heatmap, colormap_name)
195
+ hm_buf = io.BytesIO()
196
+ Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
197
+ hm_buf.seek(0)
198
+ c.drawImage(ImageReader(hm_buf), 72 + img_w + 20, img_top - img_h, width=img_w, height=img_h, preserveAspectRatio=True)
199
+ c.drawString(72 + img_w + 20, img_top - img_h - 12, "Output: Force map")
200
+
201
+ # Title above images
202
+ c.setFont("Helvetica-Bold", 16)
203
+ c.drawString(72, img_top + 25, "Shape2Force (S2F) - Prediction Report")
204
+ c.setFont("Helvetica", 10)
205
+ c.drawString(72, img_top + 8, f"Image: {base_name}")
206
+
207
+ # Metrics table below images
208
+ y = img_top - img_h - 45
209
+ c.setFont("Helvetica-Bold", 10)
210
+ c.drawString(72, y, "Metrics")
211
+ c.setFont("Helvetica", 9)
212
+ y -= 18
213
+ metrics = [
214
+ ("Sum of all pixels", f"{pixel_sum * force_scale:.2f}"),
215
+ ("Cell force (scaled)", f"{force * force_scale:.2f}"),
216
+ ("Heatmap max", f"{np.max(scaled_heatmap):.4f}"),
217
+ ("Heatmap mean", f"{np.mean(scaled_heatmap):.4f}"),
218
+ ]
219
+ for label, val in metrics:
220
+ c.drawString(72, y, f"{label}: {val}")
221
+ y -= 16
222
+
223
+ c.save()
224
+ buf.seek(0)
225
+ return buf.getvalue()
226
+
227
+
228
  def _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale):
229
  """Build original_vals dict for measure tool."""
230
  return {
 
235
  }
236
 
237
 
238
+ def _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="", colormap_name="Jet"):
239
  """Render prediction result: plot, metrics, expander, and download/measure buttons."""
240
+ buf_hm = _heatmap_to_png_bytes(scaled_heatmap, colormap_name)
241
  base_name = os.path.splitext(key_img or "image")[0]
242
  main_csv_rows = [
243
  ["image", "Sum of all pixels", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
 
254
  st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
255
  fig_pl = make_subplots(rows=1, cols=2)
256
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
257
+ plotly_colorscale = _cv_colormap_to_plotly_colorscale(colormap_name)
258
+ fig_pl.add_trace(go.Heatmap(z=scaled_heatmap, colorscale=plotly_colorscale, zmin=0, zmax=1, showscale=True,
259
  colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
260
  fig_pl.update_layout(
261
  height=400,
 
294
  """)
295
 
296
  original_vals = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
297
+ pdf_bytes = _create_pdf_report(img, scaled_heatmap, pixel_sum, force, force_scale, base_name, colormap_name)
298
+ btn_col1, btn_col2, btn_col3, btn_col4 = st.columns(4)
299
  with btn_col1:
300
  if HAS_DRAWABLE_CANVAS and st_dialog:
301
  if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
 
309
  original_vals=original_vals,
310
  key_suffix="expander",
311
  input_filename=key_img,
312
+ colormap_name=colormap_name,
313
  )
314
  else:
315
  st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
 
333
  key=f"download_main_values{download_key_suffix}",
334
  icon=":material/download:",
335
  )
336
+ with btn_col4:
337
+ st.download_button(
338
+ "Download report",
339
+ width="stretch",
340
+ data=pdf_bytes,
341
+ file_name=f"{base_name}_report.pdf",
342
+ mime="application/pdf",
343
+ key=f"download_pdf{download_key_suffix}",
344
+ icon=":material/picture_as_pdf:",
345
+ )
346
 
347
 
348
  def _compute_region_metrics(scaled_heatmap, mask, original_vals=None):
 
420
  key=f"download_annotated_{key_suffix}", icon=":material/image:")
421
 
422
 
423
+ def _render_region_canvas(scaled_heatmap, bf_img=None, original_vals=None, key_suffix="", input_filename=None, colormap_name="Jet"):
424
  """Render drawable canvas and region metrics. Used in dialog or expander."""
425
  h, w = scaled_heatmap.shape
426
+ heatmap_rgb = _heatmap_to_rgb(scaled_heatmap, colormap_name)
 
427
  pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
428
 
429
  st.markdown("""
 
501
  bf_img = st.session_state.get("measure_bf_img")
502
  original_vals = st.session_state.get("measure_original_vals")
503
  input_filename = st.session_state.get("measure_input_filename", "image")
504
+ colormap_name = st.session_state.get("measure_colormap", "Jet")
505
+ _render_region_canvas(scaled_heatmap, bf_img=bf_img, original_vals=original_vals, key_suffix="dialog", input_filename=input_filename, colormap_name=colormap_name)
506
  else:
507
  def measure_region_dialog():
508
  pass # no-op when canvas or dialog not available
 
617
  format="%.2f",
618
  help="Scale the displayed force values. 1 = full intensity, 0.5 = half the pixel values.",
619
  )
620
+ colormap_name = st.selectbox(
621
+ "Heatmap colormap",
622
+ list(COLORMAPS.keys()),
623
+ help="Color scheme for the force map. Viridis is often preferred for accessibility.",
624
+ )
625
 
626
  # Main area: image input
627
  img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
 
697
  checkpoint_path=checkpoint,
698
  ckp_folder=ckp_folder,
699
  )
700
+ sub_val = substrate_val if model_type == "single_cell" and not use_manual else "fibroblasts_PDMS"
701
+ heatmap, force, pixel_sum = predictor.predict(
702
+ image_array=img,
703
+ substrate=sub_val,
704
+ substrate_config=substrate_config if model_type == "single_cell" else None,
705
+ )
 
706
 
707
+ st.success("Prediction complete!")
708
 
709
+ scaled_heatmap = heatmap * force_scale
710
 
711
+ cache_key = (model_type, checkpoint, key_img)
712
+ st.session_state["prediction_result"] = {
713
+ "img": img.copy(),
714
+ "heatmap": heatmap.copy(),
715
+ "force": force,
716
+ "pixel_sum": pixel_sum,
717
+ "cache_key": cache_key,
718
+ }
719
+ st.session_state["measure_scaled_heatmap"] = scaled_heatmap.copy()
720
+ st.session_state["measure_bf_img"] = img.copy()
721
+ st.session_state["measure_input_filename"] = key_img or "image"
722
+ st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
723
+ st.session_state["measure_colormap"] = colormap_name
724
 
725
+ _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, colormap_name=colormap_name)
726
 
727
  except Exception as e:
728
  st.error(f"Prediction failed: {e}")
 
737
  st.session_state["measure_bf_img"] = img.copy()
738
  st.session_state["measure_input_filename"] = key_img or "image"
739
  st.session_state["measure_original_vals"] = _build_original_vals(scaled_heatmap, pixel_sum, force, force_scale)
740
+ st.session_state["measure_colormap"] = colormap_name
741
 
742
  if st.session_state.pop("open_measure_dialog", False):
743
  measure_region_dialog()
744
 
745
  st.success("Prediction complete!")
746
+ _render_result_display(img, scaled_heatmap, pixel_sum, force, force_scale, key_img, download_key_suffix="_cached", colormap_name=colormap_name)
747
 
748
  elif run and not checkpoint:
749
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")