kaveh commited on
Commit
bcea26a
·
1 Parent(s): a181e0e

updated for clarity

Browse files
Files changed (1) hide show
  1. S2FApp/app.py +127 -34
S2FApp/app.py CHANGED
@@ -27,7 +27,7 @@ st.markdown("""
27
  </style>
28
  """, unsafe_allow_html=True)
29
  st.title("🔬 Shape2Force (S2F)")
30
- st.caption("Predict force maps from bright field microscopy images")
31
 
32
  # Folders: checkpoints in subfolders by model type (single_cell / spheroid)
33
  ckp_base = os.path.join(S2F_ROOT, "ckp")
@@ -111,24 +111,17 @@ with st.sidebar:
111
  except FileNotFoundError:
112
  st.error("config/substrate_settings.json not found")
113
 
114
- st.divider()
115
- st.subheader("Display")
116
- display_size = st.slider("Image size (px)", min_value=200, max_value=800, value=350, step=50,
117
- help="Adjust display size. Drag to pan, scroll to zoom.")
118
-
119
- st.divider()
120
-
121
  # Main area: image input
122
- img_source = st.radio("Image source", ["Upload", "Sample"], horizontal=True, label_visibility="collapsed")
123
  img = None
124
  uploaded = None
125
  selected_sample = None
126
 
127
  if img_source == "Upload":
128
  uploaded = st.file_uploader(
129
- "Upload bright field image",
130
  type=["tif", "tiff", "png", "jpg", "jpeg"],
131
- help="Bright field microscopy image (grayscale or RGB)",
132
  )
133
  if uploaded:
134
  bytes_data = uploaded.read()
@@ -141,7 +134,7 @@ else:
141
  sample_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"
142
  if sample_files:
143
  selected_sample = st.selectbox(
144
- "Select sample image",
145
  sample_files,
146
  format_func=lambda x: x,
147
  key=f"sample_{model_type}",
@@ -149,8 +142,8 @@ else:
149
  if selected_sample:
150
  sample_path = os.path.join(sample_folder, selected_sample)
151
  img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
152
- # Show sample thumbnails (filtered by model type)
153
- st.caption(f"Sample images from `samples/{sample_subfolder_name}/`")
154
  n_cols = min(5, len(sample_files))
155
  cols = st.columns(n_cols)
156
  for i, fname in enumerate(sample_files[:8]): # show up to 8
@@ -160,12 +153,24 @@ else:
160
  if sample_img is not None:
161
  st.image(sample_img, caption=fname, width='content')
162
  else:
163
- st.info(f"No sample images in samples/{sample_subfolder_name}/. Add images or use Upload.")
164
 
165
  run = st.button("Run prediction", type="primary")
166
  has_image = img is not None
167
 
168
- if run and checkpoint and has_image:
 
 
 
 
 
 
 
 
 
 
 
 
169
  st.markdown(f"**Using checkpoint:** `ckp/{ckp_subfolder_name}/{checkpoint}`")
170
  with st.spinner("Loading model and predicting..."):
171
  try:
@@ -185,23 +190,18 @@ if run and checkpoint and has_image:
185
 
186
  st.success("Prediction complete!")
187
 
188
- # Metrics
189
- col1, col2, col3, col4 = st.columns(4)
190
- with col1:
191
- st.metric("Sum of all pixels", f"{pixel_sum:.2f}")
192
- with col2:
193
- st.metric("Cell force (scaled)", f"{force:.2f}")
194
- with col3:
195
- st.metric("Heatmap max", f"{np.max(heatmap):.4f}")
196
- with col4:
197
- st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}")
198
-
199
- # Visualization - Plotly with zoom/pan
200
- fig_pl = make_subplots(rows=1, cols=2, subplot_titles=["", ""])
201
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
202
- fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True), row=1, col=2)
 
203
  fig_pl.update_layout(
204
- height=display_size,
205
  margin=dict(l=10, r=10, t=10, b=10),
206
  xaxis=dict(scaleanchor="y", scaleratio=1),
207
  xaxis2=dict(scaleanchor="y2", scaleratio=1),
@@ -210,6 +210,34 @@ if run and checkpoint and has_image:
210
  fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
211
  st.plotly_chart(fig_pl, use_container_width=True)
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  # Download
214
  heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
215
  heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
@@ -219,18 +247,83 @@ if run and checkpoint and has_image:
219
  pil_heatmap.save(buf_hm, format="PNG")
220
  buf_hm.seek(0)
221
  st.download_button("Download Heatmap", data=buf_hm.getvalue(),
222
- file_name="s2f_heatmap.png", mime="image/png")
 
 
 
 
 
 
 
 
 
 
223
 
224
  except Exception as e:
225
  st.error(f"Prediction failed: {e}")
226
  import traceback
227
  st.code(traceback.format_exc())
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  elif run and not checkpoint:
230
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")
231
  elif run and not has_image:
232
- st.warning("Please upload an image or select a sample.")
233
 
234
  # Footer
235
  st.sidebar.divider()
236
- st.sidebar.caption("Checkpoints: ckp/single_cell/ and ckp/spheroid/. Samples: samples/single_cell/ and samples/spheroid/")
 
 
27
  </style>
28
  """, unsafe_allow_html=True)
29
  st.title("🔬 Shape2Force (S2F)")
30
+ st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
31
 
32
  # Folders: checkpoints in subfolders by model type (single_cell / spheroid)
33
  ckp_base = os.path.join(S2F_ROOT, "ckp")
 
111
  except FileNotFoundError:
112
  st.error("config/substrate_settings.json not found")
113
 
 
 
 
 
 
 
 
114
  # Main area: image input
115
+ img_source = st.radio("Image source", ["Upload", "Example"], horizontal=True, label_visibility="collapsed")
116
  img = None
117
  uploaded = None
118
  selected_sample = None
119
 
120
  if img_source == "Upload":
121
  uploaded = st.file_uploader(
122
+ "Upload bright-field image",
123
  type=["tif", "tiff", "png", "jpg", "jpeg"],
124
+ help="Bright-field microscopy image of a cell or spheroid on a substrate (grayscale or RGB). The model will predict traction forces from the cell shape.",
125
  )
126
  if uploaded:
127
  bytes_data = uploaded.read()
 
134
  sample_subfolder_name = "single_cell" if model_type == "single_cell" else "spheroid"
135
  if sample_files:
136
  selected_sample = st.selectbox(
137
+ "Select example image",
138
  sample_files,
139
  format_func=lambda x: x,
140
  key=f"sample_{model_type}",
 
142
  if selected_sample:
143
  sample_path = os.path.join(sample_folder, selected_sample)
144
  img = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)
145
+ # Show example thumbnails (filtered by model type)
146
+ st.caption(f"Example images from `samples/{sample_subfolder_name}/`")
147
  n_cols = min(5, len(sample_files))
148
  cols = st.columns(n_cols)
149
  for i, fname in enumerate(sample_files[:8]): # show up to 8
 
153
  if sample_img is not None:
154
  st.image(sample_img, caption=fname, width='content')
155
  else:
156
+ st.info(f"No example images in samples/{sample_subfolder_name}/. Add images or use Upload.")
157
 
158
  run = st.button("Run prediction", type="primary")
159
  has_image = img is not None
160
 
161
+ # Persist results in session state so they survive re-runs (e.g. when clicking Download)
162
+ if "prediction_result" not in st.session_state:
163
+ st.session_state["prediction_result"] = None
164
+
165
+ # Show results if we just ran prediction OR we have cached results from a previous run
166
+ just_ran = run and checkpoint and has_image
167
+ cached = st.session_state["prediction_result"]
168
+ key_img = (uploaded.name if uploaded else None) if img_source == "Upload" else selected_sample
169
+ current_key = (model_type, checkpoint, key_img)
170
+ has_cached = cached is not None and cached.get("cache_key") == current_key
171
+
172
+ if just_ran:
173
+ st.session_state["prediction_result"] = None # Clear before new run
174
  st.markdown(f"**Using checkpoint:** `ckp/{ckp_subfolder_name}/{checkpoint}`")
175
  with st.spinner("Loading model and predicting..."):
176
  try:
 
190
 
191
  st.success("Prediction complete!")
192
 
193
+ # Visualization - Plotly with zoom/pan, annotated (titles in Streamlit to avoid clipping)
194
+ tit1, tit2 = st.columns(2)
195
+ with tit1:
196
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
197
+ with tit2:
198
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
199
+ fig_pl = make_subplots(rows=1, cols=2)
 
 
 
 
 
 
200
  fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
201
+ fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
202
+ colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
203
  fig_pl.update_layout(
204
+ height=400,
205
  margin=dict(l=10, r=10, t=10, b=10),
206
  xaxis=dict(scaleanchor="y", scaleratio=1),
207
  xaxis2=dict(scaleanchor="y2", scaleratio=1),
 
210
  fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
211
  st.plotly_chart(fig_pl, use_container_width=True)
212
 
213
+ # Metrics with help (below plot)
214
+ col1, col2, col3, col4 = st.columns(4)
215
+ with col1:
216
+ st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
217
+ with col2:
218
+ st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
219
+ with col3:
220
+ st.metric("Heatmap max", f"{np.max(heatmap):.4f}", help="Peak force intensity in the map")
221
+ with col4:
222
+ st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}", help="Average force intensity")
223
+
224
+ # How to read (below numbers)
225
+ with st.expander("ℹ️ How to read the results"):
226
+ st.markdown("""
227
+ **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
228
+ This is the raw image you provided—it shows cell shape but not forces.
229
+
230
+ **Output (right):** Predicted traction force map.
231
+ - **Color** indicates force magnitude: blue = low, red = high
232
+ - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
233
+ - Values are normalized to [0, 1] for visualization
234
+
235
+ **Metrics:**
236
+ - **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
237
+ - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
238
+ - **Heatmap max/mean:** Peak and average force intensity in the map
239
+ """)
240
+
241
  # Download
242
  heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
243
  heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
 
247
  pil_heatmap.save(buf_hm, format="PNG")
248
  buf_hm.seek(0)
249
  st.download_button("Download Heatmap", data=buf_hm.getvalue(),
250
+ file_name="s2f_heatmap.png", mime="image/png", key="download_heatmap")
251
+
252
+ # Store in session state so results persist when user clicks Download
253
+ cache_key = (model_type, checkpoint, key_img)
254
+ st.session_state["prediction_result"] = {
255
+ "img": img.copy(),
256
+ "heatmap": heatmap.copy(),
257
+ "force": force,
258
+ "pixel_sum": pixel_sum,
259
+ "cache_key": cache_key,
260
+ }
261
 
262
  except Exception as e:
263
  st.error(f"Prediction failed: {e}")
264
  import traceback
265
  st.code(traceback.format_exc())
266
 
267
+ elif has_cached:
268
+ # Show cached results (e.g. after clicking Download)
269
+ r = st.session_state["prediction_result"]
270
+ img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
271
+ st.success("Prediction complete!")
272
+ tit1, tit2 = st.columns(2)
273
+ with tit1:
274
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
275
+ with tit2:
276
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
277
+ fig_pl = make_subplots(rows=1, cols=2)
278
+ fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
279
+ fig_pl.add_trace(go.Heatmap(z=heatmap, colorscale="Jet", zmin=0, zmax=1, showscale=True,
280
+ colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
281
+ fig_pl.update_layout(height=400, margin=dict(l=10, r=10, t=10, b=10),
282
+ xaxis=dict(scaleanchor="y", scaleratio=1),
283
+ xaxis2=dict(scaleanchor="y2", scaleratio=1))
284
+ fig_pl.update_xaxes(showticklabels=False)
285
+ fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
286
+ st.plotly_chart(fig_pl, use_container_width=True)
287
+ col1, col2, col3, col4 = st.columns(4)
288
+ with col1:
289
+ st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
290
+ with col2:
291
+ st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
292
+ with col3:
293
+ st.metric("Heatmap max", f"{np.max(heatmap):.4f}", help="Peak force intensity in the map")
294
+ with col4:
295
+ st.metric("Heatmap mean", f"{np.mean(heatmap):.4f}", help="Average force intensity")
296
+ with st.expander("ℹ️ How to read the results"):
297
+ st.markdown("""
298
+ **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
299
+ This is the raw image you provided—it shows cell shape but not forces.
300
+
301
+ **Output (right):** Predicted traction force map.
302
+ - **Color** indicates force magnitude: blue = low, red = high
303
+ - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
304
+ - Values are normalized to [0, 1] for visualization
305
+
306
+ **Metrics:**
307
+ - **Sum of all pixels:** Total force is the sum of all pixels in the force map. Each pixel represents the magnitude of force at that location; summing them gives the overall traction.
308
+ - **Cell force (scaled):** Total traction force in physical units (scaled by substrate stiffness)
309
+ - **Heatmap max/mean:** Peak and average force intensity in the map
310
+ """)
311
+ heatmap_uint8 = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
312
+ heatmap_rgb = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
313
+ heatmap_rgb = cv2.cvtColor(heatmap_rgb, cv2.COLOR_BGR2RGB)
314
+ pil_heatmap = Image.fromarray(heatmap_rgb)
315
+ buf_hm = io.BytesIO()
316
+ pil_heatmap.save(buf_hm, format="PNG")
317
+ buf_hm.seek(0)
318
+ st.download_button("Download Heatmap", data=buf_hm.getvalue(),
319
+ file_name="s2f_heatmap.png", mime="image/png", key="download_cached")
320
+
321
  elif run and not checkpoint:
322
  st.warning("Please add checkpoint files to the ckp/ folder and select one.")
323
  elif run and not has_image:
324
+ st.warning("Please upload an image or select an example.")
325
 
326
  # Footer
327
  st.sidebar.divider()
328
+ st.sidebar.caption(f"Checkpoint: `ckp/{ckp_subfolder_name}/`")
329
+ st.sidebar.caption(f"Examples: `samples/{ckp_subfolder_name}/`")