Spaces:
Building
Building
cleaned code
Browse files- app.py +27 -32
- predictor.py +3 -6
- ui/components.py +2 -15
- utils/paths.py +17 -0
- utils/report.py +9 -7
app.py
CHANGED
|
@@ -20,6 +20,7 @@ from config.constants import (
|
|
| 20 |
MODEL_TYPE_LABELS,
|
| 21 |
SAMPLE_EXTENSIONS,
|
| 22 |
)
|
|
|
|
| 23 |
from utils.segmentation import estimate_cell_mask
|
| 24 |
from utils.substrate_settings import list_substrates
|
| 25 |
from utils.display import apply_display_scale
|
|
@@ -32,11 +33,6 @@ from ui.components import (
|
|
| 32 |
HAS_DRAWABLE_CANVAS,
|
| 33 |
)
|
| 34 |
|
| 35 |
-
try:
|
| 36 |
-
from streamlit_drawable_canvas import st_canvas
|
| 37 |
-
except (ImportError, AttributeError):
|
| 38 |
-
pass # HAS_DRAWABLE_CANVAS from ui.components
|
| 39 |
-
|
| 40 |
CITATION = (
|
| 41 |
"Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
|
| 42 |
"**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
|
|
@@ -73,6 +69,21 @@ def _get_measure_dialog_fn():
|
|
| 73 |
return measure_region_dialog if (HAS_DRAWABLE_CANVAS and ST_DIALOG) else None
|
| 74 |
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
|
| 77 |
|
| 78 |
# Theme CSS (inject based on sidebar selection)
|
|
@@ -123,11 +134,7 @@ st.title("🦠 Shape2Force (S2F)")
|
|
| 123 |
st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
|
| 124 |
|
| 125 |
# Folders
|
| 126 |
-
ckp_base =
|
| 127 |
-
if not os.path.isdir(ckp_base):
|
| 128 |
-
project_root = os.path.dirname(S2F_ROOT)
|
| 129 |
-
if os.path.isdir(os.path.join(project_root, "ckp")):
|
| 130 |
-
ckp_base = os.path.join(project_root, "ckp")
|
| 131 |
ckp_single_cell = os.path.join(ckp_base, "single_cell")
|
| 132 |
ckp_spheroid = os.path.join(ckp_base, "spheroid")
|
| 133 |
sample_base = os.path.join(S2F_ROOT, "samples")
|
|
@@ -179,7 +186,7 @@ with st.sidebar:
|
|
| 179 |
|
| 180 |
ckp_files = get_ckp_files_for_model(model_type)
|
| 181 |
ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
|
| 182 |
-
ckp_subfolder_name =
|
| 183 |
|
| 184 |
if ckp_files:
|
| 185 |
checkpoint = st.selectbox(
|
|
@@ -265,7 +272,7 @@ if img_source == "Upload":
|
|
| 265 |
else:
|
| 266 |
sample_files = get_sample_files_for_model(model_type)
|
| 267 |
sample_folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
|
| 268 |
-
sample_subfolder_name =
|
| 269 |
if sample_files:
|
| 270 |
selected_sample = st.selectbox(
|
| 271 |
f"Select example image (from `samples/{sample_subfolder_name}/`)",
|
|
@@ -345,16 +352,10 @@ if just_ran:
|
|
| 345 |
"pixel_sum": pixel_sum,
|
| 346 |
"cache_key": cache_key,
|
| 347 |
}
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
|
| 353 |
-
st.session_state["measure_colormap"] = colormap_name
|
| 354 |
-
cell_mask = estimate_cell_mask(heatmap)
|
| 355 |
-
st.session_state["measure_auto_cell_on"] = auto_cell_boundary
|
| 356 |
-
st.session_state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force) if auto_cell_boundary else None
|
| 357 |
-
st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
|
| 358 |
|
| 359 |
render_result_display(
|
| 360 |
img, heatmap, display_heatmap, pixel_sum, force, key_img,
|
|
@@ -373,16 +374,10 @@ elif has_cached:
|
|
| 373 |
img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
|
| 374 |
display_heatmap = apply_display_scale(heatmap, display_mode)
|
| 375 |
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
|
| 381 |
-
st.session_state["measure_colormap"] = colormap_name
|
| 382 |
-
cell_mask = estimate_cell_mask(heatmap)
|
| 383 |
-
st.session_state["measure_auto_cell_on"] = auto_cell_boundary
|
| 384 |
-
st.session_state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force) if auto_cell_boundary else None
|
| 385 |
-
st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
|
| 386 |
|
| 387 |
if st.session_state.pop("open_measure_dialog", False):
|
| 388 |
measure_region_dialog()
|
|
|
|
| 20 |
MODEL_TYPE_LABELS,
|
| 21 |
SAMPLE_EXTENSIONS,
|
| 22 |
)
|
| 23 |
+
from utils.paths import get_ckp_base, model_subfolder
|
| 24 |
from utils.segmentation import estimate_cell_mask
|
| 25 |
from utils.substrate_settings import list_substrates
|
| 26 |
from utils.display import apply_display_scale
|
|
|
|
| 33 |
HAS_DRAWABLE_CANVAS,
|
| 34 |
)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
CITATION = (
|
| 37 |
"Lautaro Baro, Kaveh Shahhosseini, Amparo Andrés-Bordería, Claudio Angione, and Maria Angeles Juanes. "
|
| 38 |
"**\"Shape-to-force (S2F): Predicting Cell Traction Forces from LabelFree Imaging\"**, 2026."
|
|
|
|
| 69 |
return measure_region_dialog if (HAS_DRAWABLE_CANVAS and ST_DIALOG) else None
|
| 70 |
|
| 71 |
|
| 72 |
+
def _populate_measure_session_state(heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 73 |
+
display_mode, auto_cell_boundary):
|
| 74 |
+
"""Populate session state for the measure tool."""
|
| 75 |
+
cell_mask = estimate_cell_mask(heatmap)
|
| 76 |
+
st.session_state["measure_raw_heatmap"] = heatmap.copy()
|
| 77 |
+
st.session_state["measure_display_mode"] = display_mode
|
| 78 |
+
st.session_state["measure_bf_img"] = img.copy()
|
| 79 |
+
st.session_state["measure_input_filename"] = key_img or "image"
|
| 80 |
+
st.session_state["measure_original_vals"] = build_original_vals(heatmap, pixel_sum, force)
|
| 81 |
+
st.session_state["measure_colormap"] = colormap_name
|
| 82 |
+
st.session_state["measure_auto_cell_on"] = auto_cell_boundary
|
| 83 |
+
st.session_state["measure_cell_vals"] = build_cell_vals(heatmap, cell_mask, pixel_sum, force) if auto_cell_boundary else None
|
| 84 |
+
st.session_state["measure_cell_mask"] = cell_mask if auto_cell_boundary else None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
st.set_page_config(page_title="Shape2Force (S2F)", page_icon="🦠", layout="centered")
|
| 88 |
|
| 89 |
# Theme CSS (inject based on sidebar selection)
|
|
|
|
| 134 |
st.caption("Predict traction force maps from bright-field microscopy images of cells or spheroids")
|
| 135 |
|
| 136 |
# Folders
|
| 137 |
+
ckp_base = get_ckp_base(S2F_ROOT)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
ckp_single_cell = os.path.join(ckp_base, "single_cell")
|
| 139 |
ckp_spheroid = os.path.join(ckp_base, "spheroid")
|
| 140 |
sample_base = os.path.join(S2F_ROOT, "samples")
|
|
|
|
| 186 |
|
| 187 |
ckp_files = get_ckp_files_for_model(model_type)
|
| 188 |
ckp_folder = ckp_single_cell if model_type == "single_cell" else ckp_spheroid
|
| 189 |
+
ckp_subfolder_name = model_subfolder(model_type)
|
| 190 |
|
| 191 |
if ckp_files:
|
| 192 |
checkpoint = st.selectbox(
|
|
|
|
| 272 |
else:
|
| 273 |
sample_files = get_sample_files_for_model(model_type)
|
| 274 |
sample_folder = sample_single_cell if model_type == "single_cell" else sample_spheroid
|
| 275 |
+
sample_subfolder_name = model_subfolder(model_type)
|
| 276 |
if sample_files:
|
| 277 |
selected_sample = st.selectbox(
|
| 278 |
f"Select example image (from `samples/{sample_subfolder_name}/`)",
|
|
|
|
| 352 |
"pixel_sum": pixel_sum,
|
| 353 |
"cache_key": cache_key,
|
| 354 |
}
|
| 355 |
+
_populate_measure_session_state(
|
| 356 |
+
heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 357 |
+
display_mode, auto_cell_boundary,
|
| 358 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
render_result_display(
|
| 361 |
img, heatmap, display_heatmap, pixel_sum, force, key_img,
|
|
|
|
| 374 |
img, heatmap, force, pixel_sum = r["img"], r["heatmap"], r["force"], r["pixel_sum"]
|
| 375 |
display_heatmap = apply_display_scale(heatmap, display_mode)
|
| 376 |
|
| 377 |
+
_populate_measure_session_state(
|
| 378 |
+
heatmap, img, pixel_sum, force, key_img, colormap_name,
|
| 379 |
+
display_mode, auto_cell_boundary,
|
| 380 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
if st.session_state.pop("open_measure_dialog", False):
|
| 383 |
measure_region_dialog()
|
predictor.py
CHANGED
|
@@ -14,6 +14,7 @@ if S2F_ROOT not in sys.path:
|
|
| 14 |
sys.path.insert(0, S2F_ROOT)
|
| 15 |
|
| 16 |
from models.s2f_model import create_s2f_model
|
|
|
|
| 17 |
from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
|
| 18 |
from utils import config
|
| 19 |
|
|
@@ -89,12 +90,8 @@ class S2FPredictor:
|
|
| 89 |
"""
|
| 90 |
self.model_type = model_type
|
| 91 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 92 |
-
ckp_base =
|
| 93 |
-
|
| 94 |
-
project_root = os.path.dirname(S2F_ROOT)
|
| 95 |
-
if os.path.isdir(os.path.join(project_root, "ckp")):
|
| 96 |
-
ckp_base = os.path.join(project_root, "ckp")
|
| 97 |
-
subfolder = "single_cell" if model_type == "single_cell" else "spheroid"
|
| 98 |
ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
|
| 99 |
if not os.path.isdir(ckp_dir):
|
| 100 |
ckp_dir = ckp_base # fallback if subfolders not used
|
|
|
|
| 14 |
sys.path.insert(0, S2F_ROOT)
|
| 15 |
|
| 16 |
from models.s2f_model import create_s2f_model
|
| 17 |
+
from utils.paths import get_ckp_base, model_subfolder
|
| 18 |
from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
|
| 19 |
from utils import config
|
| 20 |
|
|
|
|
| 90 |
"""
|
| 91 |
self.model_type = model_type
|
| 92 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 93 |
+
ckp_base = get_ckp_base(S2F_ROOT)
|
| 94 |
+
subfolder = model_subfolder(model_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
|
| 96 |
if not os.path.isdir(ckp_dir):
|
| 97 |
ckp_dir = ckp_base # fallback if subfolders not used
|
ui/components.py
CHANGED
|
@@ -17,7 +17,7 @@ from config.constants import (
|
|
| 17 |
TOOL_LABELS,
|
| 18 |
)
|
| 19 |
from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale
|
| 20 |
-
from utils.report import heatmap_to_rgb, heatmap_to_png_bytes, create_pdf_report
|
| 21 |
from utils.segmentation import estimate_cell_mask
|
| 22 |
|
| 23 |
try:
|
|
@@ -102,17 +102,6 @@ def _obj_to_pts(obj, scale_x, scale_y, heatmap_w, heatmap_h):
|
|
| 102 |
return pts
|
| 103 |
|
| 104 |
|
| 105 |
-
def parse_canvas_shapes_to_mask(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
|
| 106 |
-
"""Parse drawn shapes from streamlit-drawable-canvas json_data and create binary mask (combined)."""
|
| 107 |
-
masks, _ = parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w)
|
| 108 |
-
if not masks:
|
| 109 |
-
return None, 0
|
| 110 |
-
combined = np.zeros((heatmap_h, heatmap_w), dtype=np.uint8)
|
| 111 |
-
for m in masks:
|
| 112 |
-
combined = np.maximum(combined, m)
|
| 113 |
-
return combined, len(masks)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
|
| 117 |
"""Parse drawn shapes and return a list of individual masks (one per shape)."""
|
| 118 |
if not json_data or "objects" not in json_data or not json_data["objects"]:
|
|
@@ -237,9 +226,7 @@ def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, origina
|
|
| 237 |
"""Render drawable canvas and region metrics. When cell_vals: show cell area (replaces Full map). Else: show Full map."""
|
| 238 |
raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
|
| 239 |
h, w = display_heatmap.shape
|
| 240 |
-
heatmap_rgb =
|
| 241 |
-
if cell_mask is not None and np.any(cell_mask > 0):
|
| 242 |
-
heatmap_rgb = _draw_contour_on_image(heatmap_rgb.copy(), cell_mask, stroke_color=(255, 0, 0), stroke_width=2)
|
| 243 |
pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
|
| 244 |
|
| 245 |
st.markdown("""
|
|
|
|
| 17 |
TOOL_LABELS,
|
| 18 |
)
|
| 19 |
from utils.display import apply_display_scale, cv_colormap_to_plotly_colorscale
|
| 20 |
+
from utils.report import heatmap_to_rgb, heatmap_to_rgb_with_contour, heatmap_to_png_bytes, create_pdf_report
|
| 21 |
from utils.segmentation import estimate_cell_mask
|
| 22 |
|
| 23 |
try:
|
|
|
|
| 102 |
return pts
|
| 103 |
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
def parse_canvas_shapes_to_masks(json_data, canvas_h, canvas_w, heatmap_h, heatmap_w):
|
| 106 |
"""Parse drawn shapes and return a list of individual masks (one per shape)."""
|
| 107 |
if not json_data or "objects" not in json_data or not json_data["objects"]:
|
|
|
|
| 226 |
"""Render drawable canvas and region metrics. When cell_vals: show cell area (replaces Full map). Else: show Full map."""
|
| 227 |
raw_heatmap = raw_heatmap if raw_heatmap is not None else display_heatmap
|
| 228 |
h, w = display_heatmap.shape
|
| 229 |
+
heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
|
|
|
|
|
|
|
| 230 |
pil_bg = Image.fromarray(heatmap_rgb).resize((CANVAS_SIZE, CANVAS_SIZE), Image.Resampling.LANCZOS)
|
| 231 |
|
| 232 |
st.markdown("""
|
utils/paths.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Path resolution utilities for S2F App."""
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_ckp_base(root):
|
| 6 |
+
"""Resolve checkpoint base directory (S2FApp/ckp or project/ckp)."""
|
| 7 |
+
ckp_base = os.path.join(root, "ckp")
|
| 8 |
+
if not os.path.isdir(ckp_base):
|
| 9 |
+
project_root = os.path.dirname(root)
|
| 10 |
+
if os.path.isdir(os.path.join(project_root, "ckp")):
|
| 11 |
+
ckp_base = os.path.join(project_root, "ckp")
|
| 12 |
+
return ckp_base
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def model_subfolder(model_type):
|
| 16 |
+
"""Return subfolder name for model type: 'single_cell' or 'spheroid'."""
|
| 17 |
+
return "single_cell" if model_type == "single_cell" else "spheroid"
|
utils/report.py
CHANGED
|
@@ -21,13 +21,19 @@ def heatmap_to_rgb(scaled_heatmap, colormap_name="Jet"):
|
|
| 21 |
return heatmap_rgb
|
| 22 |
|
| 23 |
|
| 24 |
-
def
|
| 25 |
-
"""Convert
|
| 26 |
heatmap_rgb = heatmap_to_rgb(scaled_heatmap, colormap_name)
|
| 27 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 28 |
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 29 |
if contours:
|
| 30 |
cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
buf = io.BytesIO()
|
| 32 |
Image.fromarray(heatmap_rgb).save(buf, format="PNG")
|
| 33 |
buf.seek(0)
|
|
@@ -62,11 +68,7 @@ def create_pdf_report(img, display_heatmap, raw_heatmap, pixel_sum, force, base_
|
|
| 62 |
c.setFont("Helvetica", 9)
|
| 63 |
c.drawString(72, img_top - img_h - 12, "Input: Bright-field")
|
| 64 |
|
| 65 |
-
heatmap_rgb =
|
| 66 |
-
if cell_mask is not None and np.any(cell_mask > 0):
|
| 67 |
-
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 68 |
-
if contours:
|
| 69 |
-
cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
|
| 70 |
hm_buf = io.BytesIO()
|
| 71 |
Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
|
| 72 |
hm_buf.seek(0)
|
|
|
|
| 21 |
return heatmap_rgb
|
| 22 |
|
| 23 |
|
| 24 |
+
def heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name="Jet", cell_mask=None):
|
| 25 |
+
"""Convert heatmap to RGB, optionally drawing red cell contour. Mask must match heatmap shape."""
|
| 26 |
heatmap_rgb = heatmap_to_rgb(scaled_heatmap, colormap_name)
|
| 27 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 28 |
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 29 |
if contours:
|
| 30 |
cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 2)
|
| 31 |
+
return heatmap_rgb
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def heatmap_to_png_bytes(scaled_heatmap, colormap_name="Jet", cell_mask=None):
|
| 35 |
+
"""Convert scaled heatmap (float 0-1) to PNG bytes buffer. Optionally draw red cell contour."""
|
| 36 |
+
heatmap_rgb = heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name, cell_mask)
|
| 37 |
buf = io.BytesIO()
|
| 38 |
Image.fromarray(heatmap_rgb).save(buf, format="PNG")
|
| 39 |
buf.seek(0)
|
|
|
|
| 68 |
c.setFont("Helvetica", 9)
|
| 69 |
c.drawString(72, img_top - img_h - 12, "Input: Bright-field")
|
| 70 |
|
| 71 |
+
heatmap_rgb = heatmap_to_rgb_with_contour(display_heatmap, colormap_name, cell_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
hm_buf = io.BytesIO()
|
| 73 |
Image.fromarray(heatmap_rgb).save(hm_buf, format="PNG")
|
| 74 |
hm_buf.seek(0)
|