SmartHeal commited on
Commit
8ba3ba0
Β·
verified Β·
1 Parent(s): 4e605cd

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +167 -58
src/ai_processor.py CHANGED
@@ -232,6 +232,108 @@ initialize_cpu_models()
232
  setup_knowledge_base()
233
 
234
  # ---------- Calibration helpers ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
236
  out = {}
237
  try:
@@ -326,95 +428,102 @@ _last_seg_debug: Dict[str, object] = {}
326
 
327
  def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndarray, Dict[str, object]]:
328
  """
329
- Attempts TF segmentation first; falls back to KMeans if needed.
 
330
  Returns (mask_uint8_0_255, debug_dict)
331
  """
332
- global _last_seg_debug
333
- _last_seg_debug = {}
334
 
335
  seg_model = models_cache.get("seg", None)
336
- used = "fallback_kmeans"
337
- reason = "no_model"
338
- heatmap_path = None
339
- saw_roi_path = None
340
 
 
341
  if seg_model is not None:
342
  try:
343
  ishape = getattr(seg_model, "input_shape", None)
344
  if not ishape or len(ishape) < 4:
345
  raise ValueError(f"Bad seg input_shape: {ishape}")
346
  th, tw = int(ishape[1]), int(ishape[2])
 
 
347
  x = _preprocess_for_seg(image_bgr, (th, tw))
348
- saw_roi = (cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) if SEG_EXPECTS_RGB else image_bgr)
 
349
  if SMARTHEAL_DEBUG:
350
- saw_roi_path = os.path.join(out_dir, f"roi_for_seg_{ts}.png")
351
- cv2.imwrite(saw_roi_path, (cv2.cvtColor(saw_roi, cv2.COLOR_RGB2BGR) if SEG_EXPECTS_RGB else saw_roi))
352
 
353
- # Inference
354
  pred = seg_model.predict(x, verbose=0)
355
- if isinstance(pred, (list, tuple)):
356
- pred = pred[0]
357
- p = _to_prob(pred) # HxW
358
- p = cv2.resize(p, (image_bgr.shape[1], image_bgr.shape[0])) # back to ROI size
359
-
360
- # Debug stats
361
- pmin, pmax, pmean = float(p.min()), float(p.max()), float(p.mean())
362
- _log_kv("SEG_PROB_STATS", {"min": pmin, "max": pmax, "mean": pmean})
363
 
 
 
364
  if SMARTHEAL_DEBUG:
365
  hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
366
  heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
367
  heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
368
  cv2.imwrite(heatmap_path, heat)
369
 
370
- # Threshold
371
- thr = SEG_THRESH
372
- mask = (p >= thr).astype(np.uint8) # 0/1
373
- pos = int(mask.sum())
374
- frac = pos / float(mask.size)
375
- logging.info(f"SegModel USED | thr={thr} pos_px={pos} pos_frac={frac:.4f} ex_rgb={SEG_EXPECTS_RGB} norm={SEG_NORM}")
376
-
377
- used = "tf_model"
378
- reason = "ok"
379
-
380
- _last_seg_debug = {
381
- "used": used,
382
- "reason": reason,
383
- "input_shape": ishape,
384
- "prob_min": pmin, "prob_max": pmax, "prob_mean": pmean,
385
- "threshold": thr,
386
- "positive_fraction": frac,
 
 
 
 
 
 
 
 
 
 
387
  "heatmap_path": heatmap_path,
388
- "roi_seen_by_model": saw_roi_path,
389
- }
390
- return (mask * 255).astype(np.uint8), _last_seg_debug
391
 
392
  except Exception as e:
393
- reason = f"model_failed: {e}"
394
- logging.warning(f"⚠️ Segmentation model prediction failed β†’ fallback. Reason: {e}")
395
 
396
- # --- Fallback: KMeans (k=2), pick 'reddest' cluster in Lab a* ---
397
  Z = image_bgr.reshape((-1, 3)).astype(np.float32)
398
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
399
  _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
400
  centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
401
  centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
402
- wound_idx = int(np.argmax(centers_lab[:, 1])) # maximize a* (redness)
403
- mask = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8)
404
-
405
- pos = int(mask.sum()); frac = pos / float(mask.size)
406
- logging.info(f"KMeans USED | pos_px={pos} pos_frac={frac:.4f}")
407
-
408
- _last_seg_debug = {
409
- "used": used,
410
- "reason": reason,
411
- "kmeans_centers_bgr": centers.tolist(),
412
- "kmeans_centers_lab": centers_lab.astype(float).tolist(),
413
- "positive_fraction": frac,
414
- "heatmap_path": heatmap_path,
415
- "roi_seen_by_model": saw_roi_path,
416
- }
417
- return (mask * 255).astype(np.uint8), _last_seg_debug
418
 
419
  # ---------- Measurement + overlay helpers ----------
420
  def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
 
232
  setup_knowledge_base()
233
 
234
  # ---------- Calibration helpers ----------
235
+ # ---- Adaptive thresholding for model prob map ----
236
+ def _adaptive_prob_threshold(p: np.ndarray) -> float:
237
+ """
238
+ Pick a threshold that avoids tiny blobs while not swallowing skin.
239
+ Strategy:
240
+ - try Otsu on the prob map
241
+ - clamp to a reasonable band [0.25, 0.65]
242
+ - also consider percentile cut (p90) and take the "best" by area heuristic
243
+ """
244
+ p01 = np.clip(p.astype(np.float32), 0, 1)
245
+ p255 = (p01 * 255).astype(np.uint8)
246
+
247
+ # Otsu
248
+ _, thr_otsu = cv2.threshold(p255, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
249
+ thr_otsu = np.clip(thr_otsu / 255.0, 0.25, 0.65)
250
+
251
+ # Percentile (90th)
252
+ thr_pctl = float(np.clip(np.percentile(p01, 90), 0.25, 0.65))
253
+
254
+ # Prefer the threshold that yields an area fraction in [0.005..0.20]
255
+ def area_frac(thr):
256
+ return float((p01 >= thr).sum()) / float(p01.size)
257
+
258
+ af_otsu = area_frac(thr_otsu)
259
+ af_pctl = area_frac(thr_pctl)
260
+
261
+ # Score: closeness to a target area fraction (aim ~3–10%)
262
+ def score(af):
263
+ target_low, target_high = 0.03, 0.10
264
+ if af < target_low: return abs(af - target_low) * 3.0
265
+ if af > target_high: return abs(af - target_high) * 1.5
266
+ return 0.0
267
+
268
+ return thr_otsu if score(af_otsu) <= score(af_pctl) else thr_pctl
269
+
270
+
271
+ def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.ndarray:
272
+ """
273
+ Use OpenCV GrabCut to grow from a confident core into low-contrast margins.
274
+ seed01: 1=probable FG core, 0=unknown/other
275
+ """
276
+ h, w = bgr.shape[:2]
277
+ # Build GC mask: start with "unknown"
278
+ gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
279
+ # definite FG = dilated seed; probable FG = seed
280
+ k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
281
+ seed_dil = cv2.dilate(seed01, k, iterations=1)
282
+ gc[seed01.astype(bool)] = cv2.GC_PR_FGD
283
+ gc[seed_dil.astype(bool)] = cv2.GC_FGD
284
+ # border is probable background
285
+ gc[0, :], gc[-1, :], gc[:, 0], gc[:, -1] = cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD
286
+
287
+ bgdModel = np.zeros((1, 65), np.float64)
288
+ fgdModel = np.zeros((1, 65), np.float64)
289
+ cv2.grabCut(bgr, gc, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
290
+
291
+ # FG = definite or probable foreground
292
+ mask01 = np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 1, 0).astype(np.uint8)
293
+ return mask01
294
+
295
+
296
+ def _fill_holes(mask01: np.ndarray) -> np.ndarray:
297
+ h, w = mask01.shape[:2]
298
+ ff = np.zeros((h + 2, w + 2), np.uint8)
299
+ m = (mask01 * 255).astype(np.uint8).copy()
300
+ cv2.floodFill(m, ff, (0, 0), 255)
301
+ m_inv = cv2.bitwise_not(m)
302
+ out = ((mask01 * 255) | m_inv) // 255
303
+ return out.astype(np.uint8)
304
+
305
+
306
+ def _clean_mask(mask01: np.ndarray) -> np.ndarray:
307
+ """Open β†’ Close β†’ Fill holes β†’ Largest component β†’ light smooth."""
308
+ mask01 = (mask01 > 0).astype(np.uint8)
309
+ k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
310
+ k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
311
+ mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, k3, iterations=1)
312
+ mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, k5, iterations=2)
313
+ mask01 = _fill_holes(mask01)
314
+
315
+ # keep largest component
316
+ num, labels, stats, _ = cv2.connectedComponentsWithStats(mask01, 8)
317
+ if num > 1:
318
+ areas = stats[1:, cv2.CC_STAT_AREA]
319
+ if areas.size:
320
+ largest_idx = 1 + int(np.argmax(areas))
321
+ mask01 = (labels == largest_idx).astype(np.uint8)
322
+
323
+ # tiny masks β†’ gentle grow (distance transform based)
324
+ area = int(mask01.sum())
325
+ if area > 0:
326
+ grow = 1 if area < 2000 else 0
327
+ if grow:
328
+ k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
329
+ mask01 = cv2.dilate(mask01, k, iterations=1)
330
+
331
+ return (mask01 > 0).astype(np.uint8)
332
+
333
+
334
+
335
+
336
+
337
  def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
338
  out = {}
339
  try:
 
428
 
429
  def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndarray, Dict[str, object]]:
430
  """
431
+ TF model β†’ adaptive threshold on prob β†’ (optional) GrabCut grow β†’ cleanup.
432
+ Falls back to KMeans-Lab when model missing/fails.
433
  Returns (mask_uint8_0_255, debug_dict)
434
  """
435
+ debug = {"used": None, "reason": None, "positive_fraction": 0.0,
436
+ "thr": None, "heatmap_path": None, "roi_seen_by_model": None}
437
 
438
  seg_model = models_cache.get("seg", None)
 
 
 
 
439
 
440
+ # --- Model path ---
441
  if seg_model is not None:
442
  try:
443
  ishape = getattr(seg_model, "input_shape", None)
444
  if not ishape or len(ishape) < 4:
445
  raise ValueError(f"Bad seg input_shape: {ishape}")
446
  th, tw = int(ishape[1]), int(ishape[2])
447
+
448
+ # preprocess
449
  x = _preprocess_for_seg(image_bgr, (th, tw))
450
+ rgb_for_view = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
451
+ roi_seen_path = None
452
  if SMARTHEAL_DEBUG:
453
+ roi_seen_path = os.path.join(out_dir, f"roi_for_seg_{ts}.png")
454
+ cv2.imwrite(roi_seen_path, cv2.cvtColor(rgb_for_view, cv2.COLOR_RGB2BGR))
455
 
456
+ # predict β†’ prob map back to ROI size
457
  pred = seg_model.predict(x, verbose=0)
458
+ if isinstance(pred, (list, tuple)): pred = pred[0]
459
+ p = _to_prob(pred)
460
+ p = cv2.resize(p, (image_bgr.shape[1], image_bgr.shape[0]), interpolation=cv2.INTER_LINEAR)
 
 
 
 
 
461
 
462
+ # visualization (optional)
463
+ heatmap_path = None
464
  if SMARTHEAL_DEBUG:
465
  hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
466
  heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
467
  heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
468
  cv2.imwrite(heatmap_path, heat)
469
 
470
+ # --- Adaptive threshold ---
471
+ thr = _adaptive_prob_threshold(p)
472
+ core01 = (p >= thr).astype(np.uint8)
473
+ core_frac = float(core01.sum()) / float(core01.size)
474
+
475
+ # If still too tiny, try a gentler threshold
476
+ if core_frac < 0.005:
477
+ thr2 = max(thr - 0.10, 0.15)
478
+ core01 = (p >= thr2).astype(np.uint8)
479
+ thr = thr2
480
+ core_frac = float(core01.sum()) / float(core01.size)
481
+
482
+ # --- Grow with GrabCut (only if some core exists) ---
483
+ if core01.any():
484
+ gc01 = _grabcut_refine(image_bgr, core01, iters=3)
485
+ mask01 = _clean_mask(gc01)
486
+ else:
487
+ mask01 = np.zeros(core01.shape, np.uint8)
488
+
489
+ pos_frac = float(mask01.sum()) / float(mask01.size)
490
+ logging.info(f"SegModel USED | thr={thr:.2f} core_frac={core_frac:.4f} final_frac={pos_frac:.4f}")
491
+
492
+ debug.update({
493
+ "used": "tf_model",
494
+ "reason": "ok",
495
+ "positive_fraction": pos_frac,
496
+ "thr": thr,
497
  "heatmap_path": heatmap_path,
498
+ "roi_seen_by_model": roi_seen_path
499
+ })
500
+ return (mask01 * 255).astype(np.uint8), debug
501
 
502
  except Exception as e:
503
+ logging.warning(f"⚠️ Segmentation model failed β†’ fallback. Reason: {e}")
504
+ debug.update({"used": "fallback_kmeans", "reason": f"model_failed: {e}"})
505
 
506
+ # --- Fallback: KMeans in Lab (reddest cluster as wound) ---
507
  Z = image_bgr.reshape((-1, 3)).astype(np.float32)
508
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
509
  _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
510
  centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
511
  centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
512
+ wound_idx = int(np.argmax(centers_lab[:, 1])) # maximize a* (red)
513
+ mask01 = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8)
514
+ mask01 = _clean_mask(mask01)
515
+
516
+ pos_frac = float(mask01.sum()) / float(mask01.size)
517
+ logging.info(f"KMeans USED | final_frac={pos_frac:.4f}")
518
+
519
+ debug.update({
520
+ "used": "fallback_kmeans",
521
+ "reason": debug.get("reason") or "no_model",
522
+ "positive_fraction": pos_frac,
523
+ "thr": None
524
+ })
525
+ return (mask01 * 255).astype(np.uint8), debug
526
+
 
527
 
528
  # ---------- Measurement + overlay helpers ----------
529
  def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray: