Implemented batch processing for predictions to optimize memory usage and added progress tracking. Updated constants for batch inference size and adjusted contour drawing parameters for better visibility in UI elements.
Browse files
- S2FApp/app.py +14 -5
- S2FApp/config/constants.py +2 -0
- S2FApp/predictor.py +36 -24
- S2FApp/static/s2f_styles.css +23 -2
- S2FApp/ui/heatmaps.py +2 -2
- S2FApp/ui/measure_tool.py +2 -2
- S2FApp/utils/report.py +1 -1
S2FApp/app.py
CHANGED
|
@@ -18,6 +18,7 @@ if S2F_ROOT not in sys.path:
|
|
| 18 |
sys.path.insert(0, S2F_ROOT)
|
| 19 |
|
| 20 |
from config.constants import (
|
|
|
|
| 21 |
BATCH_MAX_IMAGES,
|
| 22 |
COLORMAPS,
|
| 23 |
DEFAULT_SUBSTRATE,
|
|
@@ -464,11 +465,19 @@ if just_ran_batch:
|
|
| 464 |
try:
|
| 465 |
predictor = _load_predictor(model_type, checkpoint, ckp_folder)
|
| 466 |
sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
batch_results = [
|
| 473 |
{
|
| 474 |
"img": img_b.copy(),
|
|
|
|
| 18 |
sys.path.insert(0, S2F_ROOT)
|
| 19 |
|
| 20 |
from config.constants import (
|
| 21 |
+
BATCH_INFERENCE_SIZE,
|
| 22 |
BATCH_MAX_IMAGES,
|
| 23 |
COLORMAPS,
|
| 24 |
DEFAULT_SUBSTRATE,
|
|
|
|
| 465 |
try:
|
| 466 |
predictor = _load_predictor(model_type, checkpoint, ckp_folder)
|
| 467 |
sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
|
| 468 |
+
n_images = len(imgs_batch)
|
| 469 |
+
progress_bar = st.progress(0, text=f"Predicting 0 / {n_images} images")
|
| 470 |
+
pred_results = []
|
| 471 |
+
for start in range(0, n_images, BATCH_INFERENCE_SIZE):
|
| 472 |
+
chunk = imgs_batch[start : start + BATCH_INFERENCE_SIZE]
|
| 473 |
+
chunk_results = predictor.predict_batch(
|
| 474 |
+
chunk,
|
| 475 |
+
substrate=sub_val,
|
| 476 |
+
substrate_config=substrate_config if model_type == "single_cell" else None,
|
| 477 |
+
)
|
| 478 |
+
pred_results.extend(chunk_results)
|
| 479 |
+
progress_bar.progress(min(start + len(chunk), n_images) / n_images,
|
| 480 |
+
text=f"Predicting {len(pred_results)} / {n_images} images")
|
| 481 |
batch_results = [
|
| 482 |
{
|
| 483 |
"img": img_b.copy(),
|
S2FApp/config/constants.py
CHANGED
|
@@ -13,6 +13,8 @@ DEFAULT_SUBSTRATE = "Fibroblasts_Fibronectin_6KPa"
|
|
| 13 |
CANVAS_SIZE = 320
|
| 14 |
SAMPLE_THUMBNAIL_LIMIT = 8
|
| 15 |
BATCH_MAX_IMAGES = 5
|
|
|
|
|
|
|
| 16 |
COLORMAP_N_SAMPLES = 64
|
| 17 |
|
| 18 |
# Model type labels
|
|
|
|
| 13 |
CANVAS_SIZE = 320
|
| 14 |
SAMPLE_THUMBNAIL_LIMIT = 8
|
| 15 |
BATCH_MAX_IMAGES = 5
|
| 16 |
+
# Max images per model forward pass (avoids OOM on Hugging Face free tier)
|
| 17 |
+
BATCH_INFERENCE_SIZE = 2
|
| 18 |
COLORMAP_N_SAMPLES = 64
|
| 19 |
|
| 20 |
# Model type labels
|
S2FApp/predictor.py
CHANGED
|
@@ -13,7 +13,7 @@ S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
|
|
| 13 |
if S2F_ROOT not in sys.path:
|
| 14 |
sys.path.insert(0, S2F_ROOT)
|
| 15 |
|
| 16 |
-
from config.constants import DEFAULT_SUBSTRATE, MODEL_INPUT_SIZE
|
| 17 |
from models.s2f_model import create_s2f_model
|
| 18 |
from utils.paths import get_ckp_base, model_subfolder
|
| 19 |
from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
|
|
@@ -178,18 +178,23 @@ class S2FPredictor:
|
|
| 178 |
|
| 179 |
return heatmap, force, pixel_sum
|
| 180 |
|
| 181 |
-
def predict_batch(self, images, substrate=None, substrate_config=None
|
|
|
|
| 182 |
"""
|
| 183 |
-
Run prediction on a batch of images
|
|
|
|
| 184 |
|
| 185 |
Args:
|
| 186 |
images: List of (img_array, key) or list of img arrays. img_array: (H, W) or (H, W, C).
|
| 187 |
substrate: Substrate name for single-cell mode (same for all images).
|
| 188 |
substrate_config: Optional dict with 'pixelsize' and 'young' (same for all).
|
|
|
|
|
|
|
| 189 |
|
| 190 |
Returns:
|
| 191 |
List of (heatmap, force, pixel_sum) tuples.
|
| 192 |
"""
|
|
|
|
| 193 |
imgs = []
|
| 194 |
for item in images:
|
| 195 |
img = item[0] if isinstance(item, tuple) else item
|
|
@@ -200,27 +205,34 @@ class S2FPredictor:
|
|
| 200 |
img = img / 255.0
|
| 201 |
img = cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
|
| 202 |
imgs.append(img)
|
| 203 |
-
x = torch.from_numpy(np.stack(imgs)).float().unsqueeze(1).to(self.device) # [B, 1, H, W]
|
| 204 |
-
|
| 205 |
-
if self.model_type == "single_cell" and self.norm_params is not None:
|
| 206 |
-
sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
|
| 207 |
-
settings_ch = create_settings_channels_single(
|
| 208 |
-
sub, self.device, x.shape[2], x.shape[3],
|
| 209 |
-
config_path=self.config_path, substrate_config=substrate_config
|
| 210 |
-
)
|
| 211 |
-
settings_batch = settings_ch.expand(x.shape[0], -1, -1, -1)
|
| 212 |
-
x = torch.cat([x, settings_batch], dim=1) # [B, 3, H, W]
|
| 213 |
-
|
| 214 |
-
with torch.no_grad():
|
| 215 |
-
pred = self.generator(x)
|
| 216 |
-
|
| 217 |
-
if self._use_tanh_output:
|
| 218 |
-
pred = (pred + 1.0) / 2.0
|
| 219 |
|
| 220 |
results = []
|
| 221 |
-
for
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
return results
|
|
|
|
| 13 |
if S2F_ROOT not in sys.path:
|
| 14 |
sys.path.insert(0, S2F_ROOT)
|
| 15 |
|
| 16 |
+
from config.constants import BATCH_INFERENCE_SIZE, DEFAULT_SUBSTRATE, MODEL_INPUT_SIZE
|
| 17 |
from models.s2f_model import create_s2f_model
|
| 18 |
from utils.paths import get_ckp_base, model_subfolder
|
| 19 |
from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
|
|
|
|
| 178 |
|
| 179 |
return heatmap, force, pixel_sum
|
| 180 |
|
| 181 |
+
def predict_batch(self, images, substrate=None, substrate_config=None, batch_size=None,
|
| 182 |
+
on_progress=None):
|
| 183 |
"""
|
| 184 |
+
Run prediction on a batch of images. Processes in chunks to avoid OOM on
|
| 185 |
+
memory-constrained environments (e.g. Hugging Face free tier).
|
| 186 |
|
| 187 |
Args:
|
| 188 |
images: List of (img_array, key) or list of img arrays. img_array: (H, W) or (H, W, C).
|
| 189 |
substrate: Substrate name for single-cell mode (same for all images).
|
| 190 |
substrate_config: Optional dict with 'pixelsize' and 'young' (same for all).
|
| 191 |
+
batch_size: Max images per forward pass (default: BATCH_INFERENCE_SIZE). Use 1 for minimal memory.
|
| 192 |
+
on_progress: Optional callback(processed: int, total: int) called after each forward pass.
|
| 193 |
|
| 194 |
Returns:
|
| 195 |
List of (heatmap, force, pixel_sum) tuples.
|
| 196 |
"""
|
| 197 |
+
batch_size = batch_size if batch_size is not None else BATCH_INFERENCE_SIZE
|
| 198 |
imgs = []
|
| 199 |
for item in images:
|
| 200 |
img = item[0] if isinstance(item, tuple) else item
|
|
|
|
| 205 |
img = img / 255.0
|
| 206 |
img = cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
|
| 207 |
imgs.append(img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
results = []
|
| 210 |
+
for start in range(0, len(imgs), batch_size):
|
| 211 |
+
chunk = imgs[start : start + batch_size]
|
| 212 |
+
x = torch.from_numpy(np.stack(chunk)).float().unsqueeze(1).to(self.device) # [B, 1, H, W]
|
| 213 |
+
|
| 214 |
+
if self.model_type == "single_cell" and self.norm_params is not None:
|
| 215 |
+
sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
|
| 216 |
+
settings_ch = create_settings_channels_single(
|
| 217 |
+
sub, self.device, x.shape[2], x.shape[3],
|
| 218 |
+
config_path=self.config_path, substrate_config=substrate_config
|
| 219 |
+
)
|
| 220 |
+
settings_batch = settings_ch.expand(x.shape[0], -1, -1, -1)
|
| 221 |
+
x = torch.cat([x, settings_batch], dim=1) # [B, 3, H, W]
|
| 222 |
+
|
| 223 |
+
with torch.no_grad():
|
| 224 |
+
pred = self.generator(x)
|
| 225 |
+
|
| 226 |
+
if self._use_tanh_output:
|
| 227 |
+
pred = (pred + 1.0) / 2.0
|
| 228 |
+
|
| 229 |
+
for i in range(pred.shape[0]):
|
| 230 |
+
heatmap = pred[i, 0].cpu().numpy()
|
| 231 |
+
force = sum_force_map(pred[i : i + 1]).item()
|
| 232 |
+
pixel_sum = float(np.sum(heatmap))
|
| 233 |
+
results.append((heatmap, force, pixel_sum))
|
| 234 |
+
|
| 235 |
+
if on_progress is not None:
|
| 236 |
+
on_progress(len(results), len(imgs))
|
| 237 |
+
|
| 238 |
return results
|
S2FApp/static/s2f_styles.css
CHANGED
|
@@ -99,6 +99,9 @@ section[data-testid="stSidebar"] {
|
|
| 99 |
/* === Sidebar === */
|
| 100 |
section[data-testid="stSidebar"] {
|
| 101 |
width: 360px !important;
|
|
|
|
|
|
|
|
|
|
| 102 |
background:
|
| 103 |
linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%),
|
| 104 |
/* Subtle vertical rhythm */
|
|
@@ -403,8 +406,26 @@ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([
|
|
| 403 |
.stWarning { border-radius: 8px !important; }
|
| 404 |
.stInfo { border-radius: 8px !important; }
|
| 405 |
|
| 406 |
-
/* === Selectbox, toggle === */
|
| 407 |
-
[data-testid="stSelectbox"] > div > div
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
.selectbox-label {
|
| 409 |
margin: 0;
|
| 410 |
padding-top: 0.4rem;
|
|
|
|
| 99 |
/* === Sidebar === */
|
| 100 |
section[data-testid="stSidebar"] {
|
| 101 |
width: 360px !important;
|
| 102 |
+
max-height: 100vh !important;
|
| 103 |
+
overflow-x: hidden !important;
|
| 104 |
+
overflow-y: auto !important;
|
| 105 |
background:
|
| 106 |
linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%),
|
| 107 |
/* Subtle vertical rhythm */
|
|
|
|
| 406 |
.stWarning { border-radius: 8px !important; }
|
| 407 |
.stInfo { border-radius: 8px !important; }
|
| 408 |
|
| 409 |
+
/* === Selectbox, multiselect, toggle === */
|
| 410 |
+
[data-testid="stSelectbox"] > div > div,
|
| 411 |
+
[data-testid="stSelectbox"] input,
|
| 412 |
+
[data-testid="stMultiSelect"] > div > div,
|
| 413 |
+
[data-testid="stMultiSelect"] input {
|
| 414 |
+
color: #1e293b !important;
|
| 415 |
+
background-color: #ffffff !important;
|
| 416 |
+
}
|
| 417 |
+
[data-testid="stSelectbox"] > div > div,
|
| 418 |
+
[data-testid="stMultiSelect"] > div > div {
|
| 419 |
+
border-radius: 8px !important;
|
| 420 |
+
border: 1px solid #cbd5e1 !important;
|
| 421 |
+
}
|
| 422 |
+
/* Dropdown options: ensure visible text on light background */
|
| 423 |
+
[role="listbox"] [role="option"],
|
| 424 |
+
[data-baseweb="menu"] li,
|
| 425 |
+
ul[role="listbox"] li {
|
| 426 |
+
color: #1e293b !important;
|
| 427 |
+
background-color: #ffffff !important;
|
| 428 |
+
}
|
| 429 |
.selectbox-label {
|
| 430 |
margin: 0;
|
| 431 |
padding-top: 0.4rem;
|
S2FApp/ui/heatmaps.py
CHANGED
|
@@ -76,7 +76,7 @@ def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=N
|
|
| 76 |
annotated = heatmap_rgb.copy()
|
| 77 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 78 |
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 79 |
-
cv2.drawContours(annotated, contours, -1, (255, 0, 0),
|
| 80 |
for i, mask in enumerate(masks):
|
| 81 |
color = _REGION_COLORS[i % len(_REGION_COLORS)]
|
| 82 |
_draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
|
|
@@ -113,6 +113,6 @@ def add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
|
|
| 113 |
x.append(x[0])
|
| 114 |
y.append(y[0])
|
| 115 |
fig_pl.add_trace(
|
| 116 |
-
go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=
|
| 117 |
row=row, col=col
|
| 118 |
)
|
|
|
|
| 76 |
annotated = heatmap_rgb.copy()
|
| 77 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 78 |
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 79 |
+
cv2.drawContours(annotated, contours, -1, (255, 0, 0), 3)
|
| 80 |
for i, mask in enumerate(masks):
|
| 81 |
color = _REGION_COLORS[i % len(_REGION_COLORS)]
|
| 82 |
_draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
|
|
|
|
| 113 |
x.append(x[0])
|
| 114 |
y.append(y[0])
|
| 115 |
fig_pl.add_trace(
|
| 116 |
+
go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=3), showlegend=False),
|
| 117 |
row=row, col=col
|
| 118 |
)
|
S2FApp/ui/measure_tool.py
CHANGED
|
@@ -156,7 +156,7 @@ def compute_region_metrics(raw_heatmap, mask, original_vals=None):
|
|
| 156 |
}
|
| 157 |
|
| 158 |
|
| 159 |
-
def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=
|
| 160 |
"""Draw contour from mask on RGB image. Resizes mask to match img if needed."""
|
| 161 |
h, w = img_rgb.shape[:2]
|
| 162 |
if mask.shape[:2] != (h, w):
|
|
@@ -304,7 +304,7 @@ def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, origina
|
|
| 304 |
st.caption("Bright-field")
|
| 305 |
bf_display = bf_rgb.copy()
|
| 306 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 307 |
-
bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=
|
| 308 |
st.image(bf_display, width=CANVAS_SIZE)
|
| 309 |
else:
|
| 310 |
st.markdown("**Draw a region** on the heatmap.")
|
|
|
|
| 156 |
}
|
| 157 |
|
| 158 |
|
| 159 |
+
def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=3):
|
| 160 |
"""Draw contour from mask on RGB image. Resizes mask to match img if needed."""
|
| 161 |
h, w = img_rgb.shape[:2]
|
| 162 |
if mask.shape[:2] != (h, w):
|
|
|
|
| 304 |
st.caption("Bright-field")
|
| 305 |
bf_display = bf_rgb.copy()
|
| 306 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 307 |
+
bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=3)
|
| 308 |
st.image(bf_display, width=CANVAS_SIZE)
|
| 309 |
else:
|
| 310 |
st.markdown("**Draw a region** on the heatmap.")
|
S2FApp/utils/report.py
CHANGED
|
@@ -68,7 +68,7 @@ def heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name="Jet", cell_mask=N
|
|
| 68 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 69 |
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 70 |
if contours:
|
| 71 |
-
cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0),
|
| 72 |
return heatmap_rgb
|
| 73 |
|
| 74 |
|
|
|
|
| 68 |
if cell_mask is not None and np.any(cell_mask > 0):
|
| 69 |
contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 70 |
if contours:
|
| 71 |
+
cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 3)
|
| 72 |
return heatmap_rgb
|
| 73 |
|
| 74 |
|