kaveh commited on
Commit
6727da5
·
1 Parent(s): 67b648a

Implemented batch processing for predictions to optimize memory usage and added progress tracking. Updated constants for batch inference size and adjusted contour drawing parameters for better visibility in UI elements.

Browse files
S2FApp/app.py CHANGED
@@ -18,6 +18,7 @@ if S2F_ROOT not in sys.path:
18
  sys.path.insert(0, S2F_ROOT)
19
 
20
  from config.constants import (
 
21
  BATCH_MAX_IMAGES,
22
  COLORMAPS,
23
  DEFAULT_SUBSTRATE,
@@ -464,11 +465,19 @@ if just_ran_batch:
464
  try:
465
  predictor = _load_predictor(model_type, checkpoint, ckp_folder)
466
  sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
467
- pred_results = predictor.predict_batch(
468
- imgs_batch,
469
- substrate=sub_val,
470
- substrate_config=substrate_config if model_type == "single_cell" else None,
471
- )
 
 
 
 
 
 
 
 
472
  batch_results = [
473
  {
474
  "img": img_b.copy(),
 
18
  sys.path.insert(0, S2F_ROOT)
19
 
20
  from config.constants import (
21
+ BATCH_INFERENCE_SIZE,
22
  BATCH_MAX_IMAGES,
23
  COLORMAPS,
24
  DEFAULT_SUBSTRATE,
 
465
  try:
466
  predictor = _load_predictor(model_type, checkpoint, ckp_folder)
467
  sub_val = substrate_val if model_type == "single_cell" and not use_manual else DEFAULT_SUBSTRATE
468
+ n_images = len(imgs_batch)
469
+ progress_bar = st.progress(0, text=f"Predicting 0 / {n_images} images")
470
+ pred_results = []
471
+ for start in range(0, n_images, BATCH_INFERENCE_SIZE):
472
+ chunk = imgs_batch[start : start + BATCH_INFERENCE_SIZE]
473
+ chunk_results = predictor.predict_batch(
474
+ chunk,
475
+ substrate=sub_val,
476
+ substrate_config=substrate_config if model_type == "single_cell" else None,
477
+ )
478
+ pred_results.extend(chunk_results)
479
+ progress_bar.progress(min(start + len(chunk), n_images) / n_images,
480
+ text=f"Predicting {len(pred_results)} / {n_images} images")
481
  batch_results = [
482
  {
483
  "img": img_b.copy(),
S2FApp/config/constants.py CHANGED
@@ -13,6 +13,8 @@ DEFAULT_SUBSTRATE = "Fibroblasts_Fibronectin_6KPa"
13
  CANVAS_SIZE = 320
14
  SAMPLE_THUMBNAIL_LIMIT = 8
15
  BATCH_MAX_IMAGES = 5
 
 
16
  COLORMAP_N_SAMPLES = 64
17
 
18
  # Model type labels
 
13
  CANVAS_SIZE = 320
14
  SAMPLE_THUMBNAIL_LIMIT = 8
15
  BATCH_MAX_IMAGES = 5
16
+ # Max images per model forward pass (avoids OOM on Hugging Face free tier)
17
+ BATCH_INFERENCE_SIZE = 2
18
  COLORMAP_N_SAMPLES = 64
19
 
20
  # Model type labels
S2FApp/predictor.py CHANGED
@@ -13,7 +13,7 @@ S2F_ROOT = os.path.dirname(os.path.abspath(__file__))
13
  if S2F_ROOT not in sys.path:
14
  sys.path.insert(0, S2F_ROOT)
15
 
16
- from config.constants import DEFAULT_SUBSTRATE, MODEL_INPUT_SIZE
17
  from models.s2f_model import create_s2f_model
18
  from utils.paths import get_ckp_base, model_subfolder
19
  from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
@@ -178,18 +178,23 @@ class S2FPredictor:
178
 
179
  return heatmap, force, pixel_sum
180
 
181
- def predict_batch(self, images, substrate=None, substrate_config=None):
 
182
  """
183
- Run prediction on a batch of images (single forward pass).
 
184
 
185
  Args:
186
  images: List of (img_array, key) or list of img arrays. img_array: (H, W) or (H, W, C).
187
  substrate: Substrate name for single-cell mode (same for all images).
188
  substrate_config: Optional dict with 'pixelsize' and 'young' (same for all).
 
 
189
 
190
  Returns:
191
  List of (heatmap, force, pixel_sum) tuples.
192
  """
 
193
  imgs = []
194
  for item in images:
195
  img = item[0] if isinstance(item, tuple) else item
@@ -200,27 +205,34 @@ class S2FPredictor:
200
  img = img / 255.0
201
  img = cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
202
  imgs.append(img)
203
- x = torch.from_numpy(np.stack(imgs)).float().unsqueeze(1).to(self.device) # [B, 1, H, W]
204
-
205
- if self.model_type == "single_cell" and self.norm_params is not None:
206
- sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
207
- settings_ch = create_settings_channels_single(
208
- sub, self.device, x.shape[2], x.shape[3],
209
- config_path=self.config_path, substrate_config=substrate_config
210
- )
211
- settings_batch = settings_ch.expand(x.shape[0], -1, -1, -1)
212
- x = torch.cat([x, settings_batch], dim=1) # [B, 3, H, W]
213
-
214
- with torch.no_grad():
215
- pred = self.generator(x)
216
-
217
- if self._use_tanh_output:
218
- pred = (pred + 1.0) / 2.0
219
 
220
  results = []
221
- for i in range(pred.shape[0]):
222
- heatmap = pred[i, 0].cpu().numpy()
223
- force = sum_force_map(pred[i : i + 1]).item()
224
- pixel_sum = float(np.sum(heatmap))
225
- results.append((heatmap, force, pixel_sum))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  return results
 
13
  if S2F_ROOT not in sys.path:
14
  sys.path.insert(0, S2F_ROOT)
15
 
16
+ from config.constants import BATCH_INFERENCE_SIZE, DEFAULT_SUBSTRATE, MODEL_INPUT_SIZE
17
  from models.s2f_model import create_s2f_model
18
  from utils.paths import get_ckp_base, model_subfolder
19
  from utils.substrate_settings import get_settings_of_category, compute_settings_normalization
 
178
 
179
  return heatmap, force, pixel_sum
180
 
181
+ def predict_batch(self, images, substrate=None, substrate_config=None, batch_size=None,
182
+ on_progress=None):
183
  """
184
+ Run prediction on a batch of images. Processes in chunks to avoid OOM on
185
+ memory-constrained environments (e.g. Hugging Face free tier).
186
 
187
  Args:
188
  images: List of (img_array, key) or list of img arrays. img_array: (H, W) or (H, W, C).
189
  substrate: Substrate name for single-cell mode (same for all images).
190
  substrate_config: Optional dict with 'pixelsize' and 'young' (same for all).
191
+ batch_size: Max images per forward pass (default: BATCH_INFERENCE_SIZE). Use 1 for minimal memory.
192
+ on_progress: Optional callback(processed: int, total: int) called after each forward pass.
193
 
194
  Returns:
195
  List of (heatmap, force, pixel_sum) tuples.
196
  """
197
+ batch_size = batch_size if batch_size is not None else BATCH_INFERENCE_SIZE
198
  imgs = []
199
  for item in images:
200
  img = item[0] if isinstance(item, tuple) else item
 
205
  img = img / 255.0
206
  img = cv2.resize(img, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
207
  imgs.append(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  results = []
210
+ for start in range(0, len(imgs), batch_size):
211
+ chunk = imgs[start : start + batch_size]
212
+ x = torch.from_numpy(np.stack(chunk)).float().unsqueeze(1).to(self.device) # [B, 1, H, W]
213
+
214
+ if self.model_type == "single_cell" and self.norm_params is not None:
215
+ sub = substrate if substrate is not None else DEFAULT_SUBSTRATE
216
+ settings_ch = create_settings_channels_single(
217
+ sub, self.device, x.shape[2], x.shape[3],
218
+ config_path=self.config_path, substrate_config=substrate_config
219
+ )
220
+ settings_batch = settings_ch.expand(x.shape[0], -1, -1, -1)
221
+ x = torch.cat([x, settings_batch], dim=1) # [B, 3, H, W]
222
+
223
+ with torch.no_grad():
224
+ pred = self.generator(x)
225
+
226
+ if self._use_tanh_output:
227
+ pred = (pred + 1.0) / 2.0
228
+
229
+ for i in range(pred.shape[0]):
230
+ heatmap = pred[i, 0].cpu().numpy()
231
+ force = sum_force_map(pred[i : i + 1]).item()
232
+ pixel_sum = float(np.sum(heatmap))
233
+ results.append((heatmap, force, pixel_sum))
234
+
235
+ if on_progress is not None:
236
+ on_progress(len(results), len(imgs))
237
+
238
  return results
S2FApp/static/s2f_styles.css CHANGED
@@ -99,6 +99,9 @@ section[data-testid="stSidebar"] {
99
  /* === Sidebar === */
100
  section[data-testid="stSidebar"] {
101
  width: 360px !important;
 
 
 
102
  background:
103
  linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%),
104
  /* Subtle vertical rhythm */
@@ -403,8 +406,26 @@ div[data-testid="stHorizontalBlock"]:has([data-testid="stDownloadButton"]):has([
403
  .stWarning { border-radius: 8px !important; }
404
  .stInfo { border-radius: 8px !important; }
405
 
406
- /* === Selectbox, toggle === */
407
- [data-testid="stSelectbox"] > div > div { border-radius: 8px !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  .selectbox-label {
409
  margin: 0;
410
  padding-top: 0.4rem;
 
99
  /* === Sidebar === */
100
  section[data-testid="stSidebar"] {
101
  width: 360px !important;
102
+ max-height: 100vh !important;
103
+ overflow-x: hidden !important;
104
+ overflow-y: auto !important;
105
  background:
106
  linear-gradient(180deg, #f8fafc 0%, #f1f5f9 100%),
107
  /* Subtle vertical rhythm */
 
406
  .stWarning { border-radius: 8px !important; }
407
  .stInfo { border-radius: 8px !important; }
408
 
409
+ /* === Selectbox, multiselect, toggle === */
410
+ [data-testid="stSelectbox"] > div > div,
411
+ [data-testid="stSelectbox"] input,
412
+ [data-testid="stMultiSelect"] > div > div,
413
+ [data-testid="stMultiSelect"] input {
414
+ color: #1e293b !important;
415
+ background-color: #ffffff !important;
416
+ }
417
+ [data-testid="stSelectbox"] > div > div,
418
+ [data-testid="stMultiSelect"] > div > div {
419
+ border-radius: 8px !important;
420
+ border: 1px solid #cbd5e1 !important;
421
+ }
422
+ /* Dropdown options: ensure visible text on light background */
423
+ [role="listbox"] [role="option"],
424
+ [data-baseweb="menu"] li,
425
+ ul[role="listbox"] li {
426
+ color: #1e293b !important;
427
+ background-color: #ffffff !important;
428
+ }
429
  .selectbox-label {
430
  margin: 0;
431
  padding-top: 0.4rem;
S2FApp/ui/heatmaps.py CHANGED
@@ -76,7 +76,7 @@ def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=N
76
  annotated = heatmap_rgb.copy()
77
  if cell_mask is not None and np.any(cell_mask > 0):
78
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79
- cv2.drawContours(annotated, contours, -1, (255, 0, 0), 5)
80
  for i, mask in enumerate(masks):
81
  color = _REGION_COLORS[i % len(_REGION_COLORS)]
82
  _draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
@@ -113,6 +113,6 @@ def add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
113
  x.append(x[0])
114
  y.append(y[0])
115
  fig_pl.add_trace(
116
- go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=4), showlegend=False),
117
  row=row, col=col
118
  )
 
76
  annotated = heatmap_rgb.copy()
77
  if cell_mask is not None and np.any(cell_mask > 0):
78
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79
+ cv2.drawContours(annotated, contours, -1, (255, 0, 0), 3)
80
  for i, mask in enumerate(masks):
81
  color = _REGION_COLORS[i % len(_REGION_COLORS)]
82
  _draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)
 
113
  x.append(x[0])
114
  y.append(y[0])
115
  fig_pl.add_trace(
116
+ go.Scatter(x=x, y=y, mode="lines", line=dict(color="red", width=3), showlegend=False),
117
  row=row, col=col
118
  )
S2FApp/ui/measure_tool.py CHANGED
@@ -156,7 +156,7 @@ def compute_region_metrics(raw_heatmap, mask, original_vals=None):
156
  }
157
 
158
 
159
- def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=5):
160
  """Draw contour from mask on RGB image. Resizes mask to match img if needed."""
161
  h, w = img_rgb.shape[:2]
162
  if mask.shape[:2] != (h, w):
@@ -304,7 +304,7 @@ def render_region_canvas(display_heatmap, raw_heatmap=None, bf_img=None, origina
304
  st.caption("Bright-field")
305
  bf_display = bf_rgb.copy()
306
  if cell_mask is not None and np.any(cell_mask > 0):
307
- bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=5)
308
  st.image(bf_display, width=CANVAS_SIZE)
309
  else:
310
  st.markdown("**Draw a region** on the heatmap.")
 
156
  }
157
 
158
 
159
+ def _draw_contour_on_image(img_rgb, mask, stroke_color=(255, 0, 0), stroke_width=3):
160
  """Draw contour from mask on RGB image. Resizes mask to match img if needed."""
161
  h, w = img_rgb.shape[:2]
162
  if mask.shape[:2] != (h, w):
 
304
  st.caption("Bright-field")
305
  bf_display = bf_rgb.copy()
306
  if cell_mask is not None and np.any(cell_mask > 0):
307
+ bf_display = _draw_contour_on_image(bf_display, cell_mask, stroke_color=(255, 0, 0), stroke_width=3)
308
  st.image(bf_display, width=CANVAS_SIZE)
309
  else:
310
  st.markdown("**Draw a region** on the heatmap.")
S2FApp/utils/report.py CHANGED
@@ -68,7 +68,7 @@ def heatmap_to_rgb_with_contour(scaled_heatmap, colormap_name="Jet", cell_mask=N
68
  if cell_mask is not None and np.any(cell_mask > 0):
69
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
70
  if contours:
71
- cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 5)
72
  return heatmap_rgb
73
 
74
 
 
68
  if cell_mask is not None and np.any(cell_mask > 0):
69
  contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
70
  if contours:
71
+ cv2.drawContours(heatmap_rgb, contours, -1, (255, 0, 0), 3)
72
  return heatmap_rgb
73
 
74