kaveh commited on
Commit
5a1821b
·
1 Parent(s): 1cc2797

cleaned app.py

Browse files
S2FApp/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Config package
S2FApp/config/constants.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Centralized constants for S2F App.
3
+ """
4
+ import cv2
5
+
6
+ # Model & paths
7
+ MODEL_INPUT_SIZE = 1024
8
+
9
+ # UI
10
+ CANVAS_SIZE = 320
11
+ COLORMAP_N_SAMPLES = 64
12
+
13
+ # Model type labels
14
+ MODEL_TYPE_LABELS = {"single_cell": "Single cell", "spheroid": "Spheroid LS174T"}
15
+
16
+ # Drawing tools
17
+ DRAW_TOOLS = ["polygon", "rect", "circle"]
18
+ TOOL_LABELS = {"polygon": "Polygon", "rect": "Rectangle", "circle": "Circle"}
19
+
20
+ # File extensions
21
+ SAMPLE_EXTENSIONS = (".tif", ".tiff", ".png", ".jpg", ".jpeg")
22
+
23
+ # Colormaps (OpenCV)
24
+ COLORMAPS = {
25
+ "Jet": cv2.COLORMAP_JET,
26
+ "Viridis": cv2.COLORMAP_VIRIDIS,
27
+ "Plasma": cv2.COLORMAP_PLASMA,
28
+ "Inferno": cv2.COLORMAP_INFERNO,
29
+ "Magma": cv2.COLORMAP_MAGMA,
30
+ }
S2FApp/ui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # UI components package
S2FApp/ui/components.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """UI components for S2F App."""
2
+ import csv
3
+ import io
4
+ import os
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import streamlit as st
9
+ from PIL import Image
10
+ import plotly.graph_objects as go
11
+ from plotly.subplots import make_subplots
12
+
13
+ from config.constants import (
14
+ CANVAS_SIZE,
15
+ COLORMAPS,
16
+ DRAW_TOOLS,
17
+ TOOL_LABELS,
18
+ )
19
+ from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale
20
+ from utils.report import heatmap_to_rgb, heatmap_to_png_bytes, create_pdf_report
21
+ from utils.segmentation import estimate_cell_mask
22
+
23
+ try:
24
+ from streamlit_drawable_canvas import st_canvas
25
+ HAS_DRAWABLE_CANVAS = True
26
+ except (ImportError, AttributeError):
27
+ HAS_DRAWABLE_CANVAS = False
28
+
29
+ # Resolve st.dialog early to fix ordering bug (used in _render_result_display)
30
+ ST_DIALOG = getattr(st, "dialog", None) or getattr(st, "experimental_dialog", None)
31
+
32
+
33
+ def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(255, 102, 0), stroke_width=2):
34
+ """Composite heatmap with drawn region overlay."""
35
+ annotated = heatmap_rgb.copy()
36
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
37
+ overlay = annotated.copy()
38
+ cv2.fillPoly(overlay, contours, stroke_color)
39
+ mask_3d = np.stack([mask] * 3, axis=-1).astype(bool)
40
+ annotated[mask_3d] = (
41
+ (1 - fill_alpha) * annotated[mask_3d].astype(np.float32)
42
+ + fill_alpha * overlay[mask_3d].astype(np.float32)
43
+ ).astype(np.uint8)
44
+ cv2.drawContours(annotated, contours, -1, stroke_color, stroke_width)
45
+ return annotated
46
+
47
+
48
+ def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
49
+ """Convert a single canvas object to polygon points in heatmap coords. Returns None if invalid."""
50
+ obj_type = obj.get("type", "")
51
+ pts = []
52
+ if obj_type == "rect":
53
+ left = obj.get("left", 0)
54
+ top = obj.get("top", 0)
55
+ w = obj.get("width", 0)
56
+ h = obj.get("height", 0)
57
+ pts = np.array([
58
+ [left, top], [left + w, top], [left + w, top + h], [left, top + h]
59
+ ], dtype=np.float32)
60
+ elif obj_type == "circle" or obj_type == "ellipse":
61
+ left = obj.get("left", 0)
62
+ top = obj.get("top", 0)
63
+ width = obj.get("width", 0)
64
+ height = obj.get("height", 0)
65
+ radius = obj.get("radius", 0)
66
+ angle_deg = obj.get("angle", 0)
67
+ if radius > 0:
68
+ rx = ry = radius
69
+ angle_rad = np.deg2rad(angle_deg)
70
+ cx = left + radius * np.cos(angle_rad)
71
+ cy = top + radius * np.sin(angle_rad)
72
+ else:
73
+ rx = width / 2 if width > 0 else 0
74
+ ry = height / 2 if height > 0 else 0
75
+ if rx <= 0 or ry <= 0:
76
+ return None
77
+ cx = left + rx
78
+ cy = top + ry
79
+ if rx <= 0 or ry <= 0:
80
+ return None
81
+ n = 32
82
+ angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
83
+ pts = np.column_stack([cx + rx * np.cos(angles), cy + ry * np.sin(angles)]).astype(np.float32)
84
+ elif obj_type == "path":
85
+ path = obj.get("path", [])
86
+ for cmd in path:
87
+ if isinstance(cmd, (list, tuple)) and len(cmd) >= 3:
88
+ if cmd[0] in ("M", "L"):
89
+ pts.append([float(cmd[1]), float(cmd[2])])
90
+ elif cmd[0] == "Q" and len(cmd) >= 5:
91
+ pts.append([float(cmd[3]), float(cmd[4])])
92
+ elif cmd[0] == "C" and len(cmd) >= 7:
93
+ pts.append([float(cmd[5]), float(cmd[6])])
94
+ if len(pts) < 3:
95
+ return None
96
+ pts = np.array(pts, dtype=np.float32)
97
+ else:
98
+ return None
99
+ pts[:, 0] *= scale_x
100
+ pts[:, 1] *= scale_y
101
+ pts = np.clip(pts, 0, [heatmap_w - 1, heatmap_h - 1]).astype(np.int32)
102
+ return pts
103
+
104
+
105
+ def parse_canvas_shapes_to_mask(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
106
+ """Parse drawn shapes from streamlit-drawable-canvas json_data and create binary mask (combined)."""
107
+ masks, _ = parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w)
108
+ if not masks:
109
+ return None, 0
110
+ combined = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
111
+ for m in masks:
112
+ combined = np.maximum(combined, m)
113
+ return combined, len(masks)
114
+
115
+
116
+ def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
117
+ """Parse drawn shapes and return a list of individual masks (one per shape)."""
118
+ if not json_data or "objects" not in json_data or not json_data["objects"]:
119
+ return [], 0
120
+ scale_x = heatmap_w / canvas_w
121
+ scale_y = heatmap_h / canvas_h
122
+ masks = []
123
+ for obj in json_data["objects"]:
124
+ pts = _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h)
125
+ if pts is None:
126
+ continue
127
+ mask = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
128
+ cv2.fillPoly(mask, [pts], 1)
129
+ masks.append(mask)
130
+ return masks, len(masks)
131
+
132
+
133
+ def build_original_vals(raw_heatmap, pixel_sum, force):
134
+ """Build original_vals dict for measure tool (full map)."""
135
+ return {
136
+ "pixel_sum": pixel_sum,
137
+ "force": force,
138
+ "max": float(np.max(raw_heatmap)),
139
+ "mean": float(np.mean(raw_heatmap)),
140
+ }
141
+
142
+
143
+ def build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force):
144
+ """Build cell_vals dict for measure tool (estimated cell area). Returns None if invalid."""
145
+ cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force)
146
+ if cell_pixel_sum is None:
147
+ return None
148
+ region_values = raw_heatmap * cell_mask
149
+ region_nonzero = region_values[cell_mask > 0]
150
+ cell_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0
151
+ return {
152
+ "pixel_sum": cell_pixel_sum,
153
+ "force": cell_force,
154
+ "max": cell_max,
155
+ "mean": cell_mean,
156
+ }
157
+
158
+
159
+ def compute_region_metrics(raw_heatmap, mask, original_vals=None):
160
+ """Compute region metrics from mask."""
161
+ area_px = int(np.sum(mask))
162
+ region_values = raw_heatmap * mask
163
+ region_nonzero = region_values[mask > 0]
164
+ force_sum = float(np.sum(region_values))
165
+ density = force_sum / area_px if area_px > 0 else 0
166
+ region_max = float(np.max(region_nonzero)) if len(region_nonzero) > 0 else 0
167
+ region_mean = float(np.mean(region_nonzero)) if len(region_nonzero) > 0 else 0
168
+ region_force_scaled = (
169
+ force_sum * (original_vals["force"] / original_vals["pixel_sum"])
170
+ if original_vals and original_vals.get("pixel_sum", 0) > 0
171
+ else force_sum
172
+ )
173
+ return {
174
+ "area_px": area_px,
175
+ "force_sum": force_sum,
176
+ "density": density,
177
+ "max": region_max,
178
+ "mean": region_mean,
179
+ "force_scaled": region_force_scaled,
180
+ }
181
+
182
+
183
+ def render_region_metrics_and_downloads(metrics_list, heatmap_rgb, combined_mask, input_filename, key_suffix, has_original_vals,
184
+ first_region_label=None):
185
+ """Render per-shape metrics table and download buttons. first_region_label: custom label for first row (e.g. 'Auto boundary')."""
186
+ base_name = os.path.splitext(input_filename or "image")[0]
187
+ st.markdown("**Regions (each selection = one row)**")
188
+ if has_original_vals:
189
+ headers = ["Region", "Area", "F.sum", "Force", "Max", "Mean"]
190
+ csv_rows = [["image", "region"] + headers[1:]]
191
+ else:
192
+ headers = ["Region", "Area (px²)", "Force sum", "Mean"]
193
+ csv_rows = [["image", "region", "Area", "Force sum", "Mean"]]
194
+ table_rows = [headers]
195
+ for i, metrics in enumerate(metrics_list, 1):
196
+ region_label = first_region_label if (i == 1 and first_region_label) else f"Region {i - (1 if first_region_label else 0)}"
197
+ if has_original_vals:
198
+ row = [region_label, str(metrics["area_px"]), f"{metrics['force_sum']:.3f}", f"{metrics['force_scaled']:.1f}",
199
+ f"{metrics['max']:.3f}", f"{metrics['mean']:.4f}"]
200
+ csv_rows.append([base_name, region_label, metrics["area_px"], f"{metrics['force_sum']:.3f}",
201
+ f"{metrics['force_scaled']:.1f}", f"{metrics['max']:.3f}", f"{metrics['mean']:.4f}"])
202
+ else:
203
+ row = [region_label, str(metrics["area_px"]), f"{metrics['force_sum']:.4f}", f"{metrics['mean']:.6f}"]
204
+ csv_rows.append([base_name, region_label, metrics["area_px"], f"{metrics['force_sum']:.4f}",
205
+ f"{metrics['mean']:.6f}"])
206
+ table_rows.append(row)
207
+ st.table(table_rows)
208
+ buf_csv = io.StringIO()
209
+ csv.writer(buf_csv).writerows(csv_rows)
210
+ buf_img = io.BytesIO()
211
+ Image.fromarray(make_annotated_heatmap(heatmap_rgb, combined_mask)).save(buf_img, format="PNG")
212
+ buf_img.seek(0)
213
+ dl_col1, dl_col2 = st.columns(2)
214
+ with dl_col1:
215
+ st.download_button("Download all regions", data=buf_csv.getvalue(),
216
+ file_name=f"{base_name}_all_regions.csv", mime="text/csv",
217
+ key=f"download_all_regions_{key_suffix}", icon=":material/download:")
218
+ with dl_col2:
219
+ st.download_button("Download annotated heatmap", data=buf_img.getvalue(),
220
+ file_name=f"{base_name}_annotated_heatmap.png", mime="image/png",
221
+ key=f"download_annotated_{key_suffix}", icon=":material/image:")
222
+
223
+
224
+ def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=2):
225
+ """Draw contour from mask on RGB image. Resizes mask to match img if needed."""
226
+ h, w = img_rgb.shape[:2]
227
+ if mask.shape[:2] != (h, w):
228
+ mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)
229
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
230
+ if contours:
231
+ cv2.drawContours(img_rgb, contours, -1, stroke_color, stroke_width)
232
+ return img_rgb
233
+
234
+
235
+ def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, original_vals=None, cell_vals=None,
236
+ cell_mask=None, key_suffix="", input_filename=None, colormap_name="Jet"):
237
+ """Render drawable canvas and region metrics. When cell_vals: show cell area (replaces Full map). Else: show Full map."""
238
+ raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
239
+ h, w = display_heatmap.shape
240
+ heatmap_rgb = heatmap_to_rgb(display_heatmap, colormap_name)
241
+ if cell_mask is not None and np.any(cell_mask > 0):
242
+ heatmap_rgb = _draw_contour_on_image(heatmap_rgb.copy(), cell_mask, stroke_color=(255, 0, 0), stroke_width=2)
243
+ pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
244
+
245
+ st.markdown("""
246
+ <style>
247
+ [data-testid="stDialog"] [data-testid="stSelectbox"], [data-testid="stExpander"] [data-testid="stSelectbox"],
248
+ [data-testid="stDialog"] [data-testid="stSelectbox"] > div, [data-testid="stExpander"] [data-testid="stSelectbox"] > div {
249
+ width: 100% !important; max-width: 100% !important;
250
+ }
251
+ [data-testid="stDialog"] [data-testid="stMetric"] label, [data-testid="stDialog"] [data-testid="stMetric"] [data-testid="stMetricValue"],
252
+ [data-testid="stExpander"] [data-testid="stMetric"] label, [data-testid="stExpander"] [data-testid="stMetric"] [data-testid="stMetricValue"] {
253
+ font-size: 0.95rem !important;
254
+ }
255
+ [data-testid="stDialog"] img, [data-testid="stExpander"] img { border-radius: 0 !important; }
256
+ </style>
257
+ """, unsafe_allow_html=True)
258
+
259
+ if bf_img is not None:
260
+ bf_resized = cv2.resize(bf_img, (CANVAS_SIZE, CANVAS_SIZE))
261
+ bf_rgb = cv2.cvtColor(bf_resized, cv2.COLOR_GRAY2RGB) if bf_img.ndim == 2 else cv2.cvtColor(bf_resized, cv2.COLOR_BGR2RGB)
262
+ left_col, right_col = st.columns(2, gap=None)
263
+ with left_col:
264
+ draw_mode = st.selectbox("Tool", DRAW_TOOLS, format_func=lambda x: TOOL_LABELS[x], key=f"draw_mode_region_{key_suffix}")
265
+ st.caption("Left-click add, right-click close. \nForce map (draw region)")
266
+ canvas_result = st_canvas(
267
+ fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600",
268
+ background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
269
+ height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
270
+ key=f"region_measure_canvas_{key_suffix}",
271
+ )
272
+ with right_col:
273
+ vals = cell_vals if cell_vals else original_vals
274
+ if vals:
275
+ label = "Cell area" if cell_vals else "Full map"
276
+ st.markdown(f'<p style="font-weight: 400; color: #334155; font-size: 0.95rem; margin: 0 20px 4px 4px;">{label}</p>', unsafe_allow_html=True)
277
+ st.markdown(f"""
278
+ <div style="width: 100%; box-sizing: border-box; border: 1px solid #e2e8f0; border-radius: 10px;
279
+ padding: 10px 12px; margin: 0 10px 20px 10px; background: linear-gradient(145deg, #f8fafc 0%, #f1f5f9 100%);
280
+ box-shadow: 0 1px 3px rgba(0,0,0,0.06);">
281
+ <div style="display: flex; flex-wrap: wrap; gap: 5px; font-size: 0.9rem;">
282
+ <span><strong>Sum:</strong> {vals['pixel_sum']:.1f}</span>
283
+ <span><strong>Force:</strong> {vals['force']:.1f}</span>
284
+ <span><strong>Max:</strong> {vals['max']:.3f}</span>
285
+ <span><strong>Mean:</strong> {vals['mean']:.3f}</span>
286
+ </div>
287
+ </div>
288
+ """, unsafe_allow_html=True)
289
+ st.caption("Bright-field")
290
+ bf_display = bf_rgb.copy()
291
+ if cell_mask is not None and np.any(cell_mask > 0):
292
+ bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=2)
293
+ st.image(bf_display, width=CANVAS_SIZE)
294
+ else:
295
+ st.markdown("**Draw a region** on the heatmap.")
296
+ draw_mode = st.selectbox("Drawing tool", DRAW_TOOLS,
297
+ format_func=lambda x: "Polygon (free shape)" if x == "polygon" else TOOL_LABELS[x],
298
+ key=f"draw_mode_region_{key_suffix}")
299
+ st.caption("Polygon: left-click to add points, right-click to close.")
300
+ canvas_result = st_canvas(
301
+ fill_color="rgba(255, 165, 0, 0.3)", stroke_width=2, stroke_color="#ff6600",
302
+ background_image=pil_bg, drawing_mode=draw_mode, update_streamlit=True,
303
+ height=CANVAS_SIZE, width=CANVAS_SIZE, display_toolbar=True,
304
+ key=f"region_measure_canvas_{key_suffix}",
305
+ )
306
+
307
+ if canvas_result.json_data:
308
+ masks, n = parse_canvas_shapes_to_masks(canvas_result.json_data, CANVAS_SIZE, CANVAS_SIZE, h, w)
309
+ if masks and n > 0:
310
+ metrics_list = [compute_region_metrics(raw_heatmap, m, original_vals) for m in masks]
311
+ if cell_mask is not None and np.any(cell_mask > 0):
312
+ cell_metrics = compute_region_metrics(raw_heatmap, cell_mask, original_vals)
313
+ metrics_list = [cell_metrics] + metrics_list
314
+ combined_mask = masks[0].copy()
315
+ for m in masks[1:]:
316
+ combined_mask = np.maximum(combined_mask, m)
317
+ render_region_metrics_and_downloads(
318
+ metrics_list, heatmap_rgb, combined_mask, input_filename, key_suffix, original_vals is not None,
319
+ first_region_label="Auto boundary" if (cell_mask is not None and np.any(cell_mask > 0)) else None,
320
+ )
321
+
322
+
323
+ def _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force):
324
+ """Compute metrics over estimated cell area only."""
325
+ area_px = int(np.sum(cell_mask))
326
+ if area_px == 0:
327
+ return None, None, None
328
+ region_values = raw_heatmap * cell_mask
329
+ cell_pixel_sum = float(np.sum(region_values))
330
+ cell_force = cell_pixel_sum * (force / pixel_sum) if pixel_sum > 0 else cell_pixel_sum
331
+ cell_mean = cell_pixel_sum / area_px
332
+ return cell_pixel_sum, cell_force, cell_mean
333
+
334
+
335
+ def _add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
336
+ """Add red contour overlay to Plotly heatmap subplot."""
337
+ if cell_mask is None or not np.any(cell_mask > 0):
338
+ return
339
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
340
+ if not contours:
341
+ return
342
+ # Use largest contour
343
+ cnt = max(contours, key=cv2.contourArea)
344
+ pts = cnt.squeeze()
345
+ if pts.ndim == 1:
346
+ pts = pts.reshape(1, 2)
347
+ x, y = pts[:, 0].tolist(), pts[:, 1].tolist()
348
+ if x[0] != x[-1] or y[0] != y[-1]:
349
+ x.append(x[0])
350
+ y.append(y[0])
351
+ fig_pl.add_trace(
352
+ go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=2), showlegend=False),
353
+ row=row, col=col
354
+ )
355
+
356
+
357
+ def render_result_display(img, raw_heatmap, display_heatmap, pixel_sum, force, key_img, download_key_suffix="",
358
+ colormap_name="Jet", display_mode="Auto", measure_region_dialog=None, auto_cell_boundary=True):
359
+ """
360
+ Render prediction result: plot, metrics, expander, and download/measure buttons.
361
+ measure_region_dialog: callable to open measure dialog (when ST_DIALOG available).
362
+ auto_cell_boundary: when True, use estimated cell area for metrics; when False, use entire map.
363
+ """
364
+ cell_mask = estimate_cell_mask(raw_heatmap) if auto_cell_boundary else None
365
+ cell_pixel_sum, cell_force, cell_mean = _compute_cell_metrics(raw_heatmap, cell_mask, pixel_sum, force) if cell_mask is not None else (None, None, None)
366
+ use_cell_metrics = auto_cell_boundary and cell_pixel_sum is not None and cell_force is not None and cell_mean is not None
367
+
368
+ base_name = os.path.splitext(key_img or "image")[0]
369
+ if use_cell_metrics:
370
+ main_csv_rows = [
371
+ ["image", "Cell sum", "Cell force (scaled)", "Heatmap max", "Cell mean"],
372
+ [base_name, f"{cell_pixel_sum:.2f}", f"{cell_force:.2f}",
373
+ f"{np.max(raw_heatmap):.4f}", f"{cell_mean:.4f}"],
374
+ ]
375
+ else:
376
+ main_csv_rows = [
377
+ ["image", "Sum of all pixels", "Cell force (scaled)", "Heatmap max", "Heatmap mean"],
378
+ [base_name, f"{pixel_sum:.2f}", f"{force:.2f}",
379
+ f"{np.max(raw_heatmap):.4f}", f"{np.mean(raw_heatmap):.4f}"],
380
+ ]
381
+ buf_main_csv = io.StringIO()
382
+ csv.writer(buf_main_csv).writerows(main_csv_rows)
383
+
384
+ buf_hm = heatmap_to_png_bytes(display_heatmap, colormap_name, cell_mask=cell_mask)
385
+
386
+ tit1, tit2 = st.columns(2)
387
+ with tit1:
388
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Input: Bright-field image</p>', unsafe_allow_html=True)
389
+ with tit2:
390
+ st.markdown('<p style="font-size: 1.1rem; color: black; font-weight: 600;">Output: Predicted traction force map</p>', unsafe_allow_html=True)
391
+ fig_pl = make_subplots(rows=1, cols=2)
392
+ fig_pl.add_trace(go.Heatmap(z=img, colorscale="gray", showscale=False), row=1, col=1)
393
+ plotly_colorscale = cv_colormap_to_plotly_colorscale(colormap_name)
394
+ zmin, zmax = 0.0, 1.0
395
+ fig_pl.add_trace(go.Heatmap(z=display_heatmap, colorscale=plotly_colorscale, zmin=zmin, zmax=zmax, showscale=True,
396
+ colorbar=dict(len=0.4, thickness=12)), row=1, col=2)
397
+ _add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2)
398
+ fig_pl.update_layout(
399
+ height=400,
400
+ margin=dict(l=10, r=10, t=10, b=10),
401
+ xaxis=dict(scaleanchor="y", scaleratio=1),
402
+ xaxis2=dict(scaleanchor="y2", scaleratio=1),
403
+ )
404
+ fig_pl.update_xaxes(showticklabels=False)
405
+ fig_pl.update_yaxes(showticklabels=False, autorange="reversed")
406
+ st.plotly_chart(fig_pl, use_container_width=True, config={"displayModeBar": True, "responsive": True})
407
+
408
+ col1, col2, col3, col4 = st.columns(4)
409
+ if use_cell_metrics:
410
+ with col1:
411
+ st.metric("Cell sum", f"{cell_pixel_sum:.2f}", help="Sum over estimated cell area (background excluded)")
412
+ with col2:
413
+ st.metric("Cell force (scaled)", f"{cell_force:.2f}", help="Total traction force in physical units")
414
+ with col3:
415
+ st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
416
+ with col4:
417
+ st.metric("Cell mean", f"{cell_mean:.4f}", help="Mean force over estimated cell area")
418
+ else:
419
+ with col1:
420
+ st.metric("Sum of all pixels", f"{pixel_sum:.2f}", help="Raw sum of all pixel values in the force map")
421
+ with col2:
422
+ st.metric("Cell force (scaled)", f"{force:.2f}", help="Total traction force in physical units")
423
+ with col3:
424
+ st.metric("Heatmap max", f"{np.max(raw_heatmap):.4f}", help="Peak force intensity in the map")
425
+ with col4:
426
+ st.metric("Heatmap mean", f"{np.mean(raw_heatmap):.4f}", help="Average force intensity (full FOV)")
427
+
428
+ with st.expander("How to read the results"):
429
+ if use_cell_metrics:
430
+ st.markdown("""
431
+ **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
432
+ This is the raw image you provided—it shows cell shape but not forces.
433
+
434
+ **Output (right):** Predicted traction force map.
435
+ - **Color** indicates force magnitude: blue = low, red = high
436
+ - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
437
+ - **Red border = estimated cell area** (background excluded from metrics)
438
+ - Values are normalized to [0, 1] for visualization
439
+
440
+ **Metrics (auto cell boundary on):**
441
+ - **Cell sum:** Sum over estimated cell area (background excluded)
442
+ - **Cell force (scaled):** Total traction force in physical units
443
+ - **Heatmap max:** Peak force intensity in the map
444
+ - **Cell mean:** Mean force over the estimated cell area
445
+ """)
446
+ else:
447
+ st.markdown("""
448
+ **Input (left):** Bright-field microscopy image of a cell or spheroid on a substrate.
449
+ This is the raw image you provided—it shows cell shape but not forces.
450
+
451
+ **Output (right):** Predicted traction force map.
452
+ - **Color** indicates force magnitude: blue = low, red = high
453
+ - **Brighter/warmer colors** = stronger forces exerted by the cell on the substrate
454
+ - Values are normalized to [0, 1] for visualization
455
+
456
+ **Metrics (auto cell boundary off):**
457
+ - **Sum of all pixels:** Raw sum over entire map
458
+ - **Cell force (scaled):** Total traction force in physical units
459
+ - **Heatmap max/mean:** Peak and average force intensity (full field of view)
460
+ """)
461
+
462
+ original_vals = build_original_vals(raw_heatmap, pixel_sum, force)
463
+
464
+ pdf_bytes = create_pdf_report(
465
+ img, display_heatmap, raw_heatmap, pixel_sum, force, base_name, colormap_name,
466
+ cell_mask=cell_mask, cell_pixel_sum=cell_pixel_sum, cell_force=cell_force, cell_mean=cell_mean
467
+ )
468
+
469
+ btn_col1, btn_col2, btn_col3, btn_col4 = st.columns(4)
470
+ with btn_col1:
471
+ if HAS_DRAWABLE_CANVAS and measure_region_dialog is not None:
472
+ if st.button("Measure tool", key="open_measure", icon=":material/straighten:"):
473
+ st.session_state["open_measure_dialog"] = True
474
+ st.rerun()
475
+ elif HAS_DRAWABLE_CANVAS:
476
+ with st.expander("Measure tool"):
477
+ expander_cell_vals = build_cell_vals(raw_heatmap, cell_mask, pixel_sum, force) if (auto_cell_boundary and cell_mask is not None) else None
478
+ expander_cell_mask = cell_mask if auto_cell_boundary else None
479
+ render_region_canvas(
480
+ display_heatmap,
481
+ raw_heatmap=raw_heatmap,
482
+ bf_img=img,
483
+ original_vals=original_vals,
484
+ cell_vals=expander_cell_vals,
485
+ cell_mask=expander_cell_mask,
486
+ key_suffix="expander",
487
+ input_filename=key_img,
488
+ colormap_name=colormap_name,
489
+ )
490
+ else:
491
+ st.caption("Install `streamlit-drawable-canvas-fix` for region measurement: `pip install streamlit-drawable-canvas-fix`")
492
+ with btn_col2:
493
+ st.download_button(
494
+ "Download heatmap",
495
+ width="stretch",
496
+ data=buf_hm.getvalue(),
497
+ file_name="s2f_heatmap.png",
498
+ mime="image/png",
499
+ key=f"download_heatmap{download_key_suffix}",
500
+ icon=":material/download:",
501
+ )
502
+ with btn_col3:
503
+ st.download_button(
504
+ "Download values",
505
+ width="stretch",
506
+ data=buf_main_csv.getvalue(),
507
+ file_name=f"{base_name}_main_values.csv",
508
+ mime="text/csv",
509
+ key=f"download_main_values{download_key_suffix}",
510
+ icon=":material/download:",
511
+ )
512
+ with btn_col4:
513
+ st.download_button(
514
+ "Download report",
515
+ width="stretch",
516
+ data=pdf_bytes,
517
+ file_name=f"{base_name}_report.pdf",
518
+ mime="application/pdf",
519
+ key=f"download_pdf{download_key_suffix}",
520
+ icon=":material/picture_as_pdf:",
521
+ )
S2FApp/utils/display.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Display utilities for heatmaps and colormaps."""
2
+ import numpy as np
3
+ import cv2
4
+
5
+ from config.constants import COLORMAPS, COLORMAP_N_SAMPLES
6
+
7
+
8
+ def cv_colormap_to_plotly_colorscale(colormap_name, n_samples=None):
9
+ """Build a Plotly colorscale from OpenCV colormap so UI matches download/PDF exactly."""
10
+ n = n_samples or COLORMAP_N_SAMPLES
11
+ cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
12
+ gradient = np.linspace(0, 255, n, dtype=np.uint8).reshape(1, -1)
13
+ rgb = cv2.applyColorMap(gradient, cv2_cmap)
14
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
15
+ scale = []
16
+ for i in range(n):
17
+ r, g, b = rgb[0, i]
18
+ scale.append([i / (n - 1), f"rgb({r},{g},{b})"])
19
+ return scale
20
+
21
+
22
+ def apply_display_scale(heatmap, mode):
23
+ """
24
+ Apply display scaling (Fiji-style). Display only—does not change underlying values.
25
+ - Auto: map data min..max to 0..1 (full color range)
26
+ - Fixed: use 0-1 range as-is
27
+ """
28
+ if mode == "Fixed":
29
+ return np.clip(heatmap, 0, 1).astype(np.float32)
30
+ hmin, hmax = float(np.min(heatmap)), float(np.max(heatmap))
31
+ if hmax > hmin:
32
+ out = (heatmap.astype(np.float32) - hmin) / (hmax - hmin)
33
+ return np.clip(out, 0, 1).astype(np.float32)
34
+ return np.clip(heatmap, 0, 1).astype(np.float32)
S2FApp/utils/metrics.py CHANGED
@@ -237,7 +237,7 @@ def evaluate_metrics_on_dataset(generator, data_loader, device=None, description
237
 
238
  if use_settings and normalization_params is not None:
239
  from models.s2f_model import create_settings_channels
240
- meta = metadata if has_metadata else {'substrate': [substrate_override or 'fibroblasts_PDMS'] * images.size(0)}
241
  settings_ch = create_settings_channels(meta, normalization_params, device, images.shape, config_path=config_path)
242
  images = torch.cat([images, settings_ch], dim=1)
243
 
@@ -420,7 +420,7 @@ def plot_predictions(loader, generator, n_samples, device, threshold=0.0,
420
  bf_batch = torch.stack(bf_list[:n]).to(device, dtype=torch.float32)
421
  if use_settings and normalization_params:
422
  from models.s2f_model import create_settings_channels
423
- sub = substrate_override or 'fibroblasts_PDMS'
424
  meta_dict = {'substrate': [sub] * n}
425
  settings_ch = create_settings_channels(meta_dict, normalization_params, device, bf_batch.shape, config_path=config_path)
426
  bf_batch = torch.cat([bf_batch, settings_ch], dim=1)
 
237
 
238
  if use_settings and normalization_params is not None:
239
  from models.s2f_model import create_settings_channels
240
+ meta = metadata if has_metadata else {'substrate': [substrate_override or 'Fibroblasts_Fibronectin_6KPa'] * images.size(0)}
241
  settings_ch = create_settings_channels(meta, normalization_params, device, images.shape, config_path=config_path)
242
  images = torch.cat([images, settings_ch], dim=1)
243
 
 
420
  bf_batch = torch.stack(bf_list[:n]).to(device, dtype=torch.float32)
421
  if use_settings and normalization_params:
422
  from models.s2f_model import create_settings_channels
423
+ sub = substrate_override or 'Fibroblasts_Fibronectin_6KPa'
424
  meta_dict = {'substrate': [sub] * n}
425
  settings_ch = create_settings_channels(meta_dict, normalization_params, device, bf_batch.shape, config_path=config_path)
426
  bf_batch = torch.cat([bf_batch, settings_ch], dim=1)
S2FApp/utils/report.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Report and heatmap export utilities."""
2
+ import io
3
+ from datetime import datetime
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ from reportlab.lib.pagesizes import A4
9
+ from reportlab.lib.units import inch
10
+ from reportlab.lib.utils import ImageReader
11
+ from reportlab.pdfgen import canvas
12
+
13
+ from config.constants import COLORMAPS
14
+
15
+
16
+ def heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
17
+ """Convert scaled heatmap (float 0-1) to RGB array using the given colormap."""
18
+ heatmap_uint8 = (np.clip(scaled_heatmap, 0, 1) * 255).astype(np.uint8)
19
+ cv2_colormap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
20
+ heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_uint8, cv2_colormap), cv2.COLOR_BGR2RGB)
21
+ return heatmap_rgb
22
+
23
+
24
+ def heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet", cell_mask=None):
25
+ """Convert scaled heatmap (float 0-1) to PNG bytes buffer. Optionally draw red cell contour."""
26
+ heatmap_rgb = heatmap_to_rgb(scaled_heatmap, colormap_name)
27
+ if cell_mask is not None and np.any(cell_mask > 0):
28
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
29
+ if contours:
30
+ cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
31
+ buf = io.BytesIO()
32
+ Image.fromarray(heatmap_rgb).save(buf, format="PNG")
33
+ buf.seek(0)
34
+ return buf
35
+
36
+
37
+ def create_pdf_report(img, display_heatmap, raw_heatmap, pixel_sum, force, base_name, colormap_name="Jet",
38
+ cell_mask=None, cell_pixel_sum=None, cell_force=None, cell_mean=None):
39
+ """Create a PDF report with input image, heatmap, and metrics."""
40
+ buf = io.BytesIO()
41
+ c = canvas.Canvas(buf, pagesize=A4)
42
+ c.setTitle("Shape2Force")
43
+ c.setAuthor("Angione-Lab")
44
+ h = A4[1]
45
+ img_w, img_h = 2.5 * inch, 2.5 * inch
46
+
47
+ footer_y = 40
48
+ c.setFont("Helvetica", 8)
49
+ c.setFillColorRGB(0.4, 0.4, 0.4)
50
+ gen_date = datetime.now().strftime("%Y-%m-%d %H:%M")
51
+ c.drawString(72, footer_y, f"Generated by Shape2Force (S2F) on {gen_date}")
52
+ c.drawString(72, footer_y - 12, "Model: https://huggingface.co/Angione-Lab/Shape2Force")
53
+ c.drawString(72, footer_y - 24, "Web app: https://huggingface.co/spaces/Angione-Lab/Shape2force")
54
+ c.setFillColorRGB(0, 0, 0)
55
+
56
+ img_top = h - 70
57
+ img_pil = Image.fromarray(img) if img.ndim == 2 else Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
58
+ img_buf = io.BytesIO()
59
+ img_pil.save(img_buf, format="PNG")
60
+ img_buf.seek(0)
61
+ c.drawImage(ImageReader(img_buf), 72, img_top - img_h, width=img_w, height=img_h, preserveAspectRatio=True)
62
+ c.setFont("Helvetica", 9)
63
+ c.drawString(72, img_top - img_h - 12, "Input: Bright-field")
64
+
65
+ heatmap_rgb = heatmap_to_rgb(display_heatmap, colormap_name)
66
+ if cell_mask is not None and np.any(cell_mask > 0):
67
+ contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
68
+ if contours:
69
+ cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
70
+ hm_buf = io.BytesIO()
71
+ Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
72
+ hm_buf.seek(0)
73
+ c.drawImage(ImageReader(hm_buf), 72 + img_w + 20, img_top - img_h, width=img_w, height=img_h, preserveAspectRatio=True)
74
+ c.drawString(72 + img_w + 20, img_top - img_h - 12, "Output: Force map (red = estimated cell)")
75
+
76
+ c.setFont("Helvetica-Bold", 16)
77
+ c.drawString(72, img_top + 25, "Shape2Force (S2F) - Prediction Report")
78
+ c.setFont("Helvetica", 10)
79
+ c.drawString(72, img_top + 8, f"Image: {base_name}")
80
+
81
+ y = img_top - img_h - 45
82
+ c.setFont("Helvetica-Bold", 10)
83
+ c.drawString(72, y, "Metrics")
84
+ c.setFont("Helvetica", 9)
85
+ y -= 18
86
+ if cell_pixel_sum is not None and cell_force is not None and cell_mean is not None:
87
+ metrics = [
88
+ ("Cell sum (estimated cell area)", f"{cell_pixel_sum:.2f}"),
89
+ ("Cell force (scaled)", f"{cell_force:.2f}"),
90
+ ("Heatmap max", f"{np.max(raw_heatmap):.4f}"),
91
+ ("Cell mean (estimated cell area)", f"{cell_mean:.4f}"),
92
+ ]
93
+ else:
94
+ metrics = [
95
+ ("Sum of all pixels", f"{pixel_sum:.2f}"),
96
+ ("Cell force (scaled)", f"{force:.2f}"),
97
+ ("Heatmap max", f"{np.max(raw_heatmap):.4f}"),
98
+ ("Heatmap mean", f"{np.mean(raw_heatmap):.4f}"),
99
+ ]
100
+ for label, val in metrics:
101
+ c.drawString(72, y, f"{label}: {val}")
102
+ y -= 16
103
+
104
+ c.save()
105
+ buf.seek(0)
106
+ return buf.getvalue()