SmartHeal commited on
Commit
4e605cd
·
verified ·
1 Parent(s): 860a048

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +79 -43
src/ai_processor.py CHANGED
@@ -292,29 +292,35 @@ def _imagenet_norm(arr: np.ndarray) -> np.ndarray:
292
 
293
  def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
294
  H, W = target_hw
295
- # Resize first
296
  resized = cv2.resize(bgr_roi, (W, H), interpolation=cv2.INTER_LINEAR)
297
- # Convert to RGB if required
298
  if SEG_EXPECTS_RGB:
299
  resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
300
- # Normalize
301
  if SEG_NORM.lower() == "imagenet":
302
  x = _imagenet_norm(resized)
303
  else:
304
  x = resized.astype(np.float32) / 255.0
305
- # Add batch dim
306
  x = np.expand_dims(x, axis=0) # (1,H,W,3)
307
  return x
308
 
309
  def _to_prob(pred: np.ndarray) -> np.ndarray:
310
- # Pred could be (1,H,W,1), (H,W,1), (1,H,W), (H,W), or logits
311
  p = np.squeeze(pred)
312
- # If values look like logits, apply sigmoid
313
  pmin, pmax = float(p.min()), float(p.max())
314
  if pmax > 1.0 or pmin < 0.0:
315
  p = 1.0 / (1.0 + np.exp(-p))
316
  return p.astype(np.float32)
317
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  # Global last debug dict (per-process) to attach into results
319
  _last_seg_debug: Dict[str, object] = {}
320
 
@@ -356,7 +362,6 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
356
  _log_kv("SEG_PROB_STATS", {"min": pmin, "max": pmax, "mean": pmean})
357
 
358
  if SMARTHEAL_DEBUG:
359
- # save heatmap (0..255)
360
  hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
361
  heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
362
  heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
@@ -364,8 +369,8 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
364
 
365
  # Threshold
366
  thr = SEG_THRESH
367
- mask = (p >= thr).astype(np.uint8) * 255
368
- pos = int((mask > 0).sum())
369
  frac = pos / float(mask.size)
370
  logging.info(f"SegModel USED | thr={thr} pos_px={pos} pos_frac={frac:.4f} ex_rgb={SEG_EXPECTS_RGB} norm={SEG_NORM}")
371
 
@@ -382,7 +387,7 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
382
  "heatmap_path": heatmap_path,
383
  "roi_seen_by_model": saw_roi_path,
384
  }
385
- return mask.astype(np.uint8), _last_seg_debug
386
 
387
  except Exception as e:
388
  reason = f"model_failed: {e}"
@@ -395,9 +400,9 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
395
  centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
396
  centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
397
  wound_idx = int(np.argmax(centers_lab[:, 1])) # maximize a* (redness)
398
- mask = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8) * 255
399
 
400
- pos = int((mask > 0).sum()); frac = pos / float(mask.size)
401
  logging.info(f"KMeans USED | pos_px={pos} pos_frac={frac:.4f}")
402
 
403
  _last_seg_debug = {
@@ -409,7 +414,7 @@ def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndar
409
  "heatmap_path": heatmap_path,
410
  "roi_seen_by_model": saw_roi_path,
411
  }
412
- return mask.astype(np.uint8), _last_seg_debug
413
 
414
  # ---------- Measurement + overlay helpers ----------
415
  def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
@@ -422,6 +427,17 @@ def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.nd
422
  largest_idx = 1 + int(np.argmax(areas))
423
  return (labels == largest_idx).astype(np.uint8)
424
 
 
 
 
 
 
 
 
 
 
 
 
425
  def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
426
  contours, _ = cv2.findContours(mask01.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
427
  if not contours:
@@ -447,45 +463,65 @@ def draw_measurement_overlay(
447
  breadth_cm: float,
448
  thickness: int = 2
449
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
450
  overlay = base_bgr.copy()
451
 
452
- # Strong overlay + contour
 
 
453
  red = np.zeros_like(overlay); red[:] = (0, 0, 255)
454
  alpha = 0.55
455
  tinted = cv2.addWeighted(overlay, 1 - alpha, red, alpha, 0)
456
- m3 = cv2.merge([mask01 * 255] * 3).astype("uint8")
457
- overlay = np.where(m3 > 0, tinted, overlay)
458
 
459
- # Draw contour
460
- cnts, _ = cv2.findContours((mask01 * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
461
  if cnts:
462
  cv2.drawContours(overlay, cnts, -1, (255, 255, 255), 2)
463
 
464
  if rect_box is not None:
465
  cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
466
  pts = rect_box.reshape(-1, 2)
467
- def midpoint(a, b): return ((a[0] + b[0]) // 2, (a[1] + b[1]) // 2)
468
- mids = [midpoint(pts[i], pts[(i+1) % 4]) for i in range(4)]
469
- e_lens = [np.linalg.norm(pts[i] - pts[(i+1) % 4]) for i in range(4)]
470
- long_pair = (0, 2) if e_lens[0] + e_lens[2] >= e_lens[1] + e_lens[3] else (1, 3)
471
- short_pair = (1, 3) if long_pair == (0, 2) else (0, 2)
472
 
473
- def draw_arrow(img, p1, p2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
475
  cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
476
  cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
477
  cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)
478
 
479
- def put_label(text, org):
480
- cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
481
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
482
- cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
483
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
484
 
485
- draw_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
486
- draw_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
 
487
  put_label(f"Length: {length_cm:.2f} cm", mids[long_pair[0]])
488
- put_label(f"Breadth: {breadth_cm:.2f} cm", mids[short_pair[0]])
 
489
  return overlay
490
 
491
  # ---------- AI PROCESSOR ----------
@@ -504,7 +540,7 @@ class AIProcessor:
504
 
505
  def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
506
  """
507
- YOLO detect → crop ROI → segment_wound(ROI) → largest component
508
  minAreaRect measurement (cm) using EXIF px/cm → save outputs.
509
  """
510
  try:
@@ -542,24 +578,23 @@ class AIProcessor:
542
  mask_u8_255, seg_debug = segment_wound(roi, ts, out_dir)
543
  mask01 = (mask_u8_255 > 127).astype(np.uint8)
544
 
545
- # Post-processing + metrics
546
  if mask01.any():
547
- mask_before = mask01.sum()
548
- mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
549
- mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
550
- mask01 = largest_component_mask(mask01, min_area_px=30)
551
- logging.debug(f"Mask postproc: px_before={mask_before} px_after={int(mask01.sum())}")
552
 
553
  # --- Measurement ---
554
  if mask01.any():
555
  length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
556
  surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
 
557
  anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
558
  segmentation_empty = False
559
  else:
 
560
  h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
561
- length_cm = round(h_px / px_per_cm, 2)
562
- breadth_cm = round(w_px / px_per_cm, 2)
563
  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
564
  anno_roi = roi.copy()
565
  cv2.rectangle(anno_roi, (2, 2), (anno_roi.shape[1]-3, anno_roi.shape[0]-3), (0, 0, 255), 3)
@@ -580,7 +615,7 @@ class AIProcessor:
580
  roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
581
  cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))
582
 
583
- # ROI overlay (very clear)
584
  mask255 = (mask01 * 255).astype(np.uint8)
585
  mask3 = cv2.merge([mask255, mask255, mask255])
586
  red = np.zeros_like(roi); red[:] = (0, 0, 255)
@@ -591,7 +626,7 @@ class AIProcessor:
591
  cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
592
  cv2.drawContours(roi_overlay, cnts, -1, (255, 255, 255), 2)
593
  else:
594
- roi_overlay = anno_roi # already marked X
595
 
596
  seg_full = image_cv.copy()
597
  seg_full[y1:y2, x1:x2] = roi_overlay
@@ -601,6 +636,7 @@ class AIProcessor:
601
  segmentation_roi_path = os.path.join(out_dir, f"segmentation_roi_{ts}.png")
602
  cv2.imwrite(segmentation_roi_path, roi_overlay)
603
 
 
604
  anno_full = image_cv.copy()
605
  anno_full[y1:y2, x1:x2] = anno_roi
606
  annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
 
292
 
293
  def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
294
  H, W = target_hw
 
295
  resized = cv2.resize(bgr_roi, (W, H), interpolation=cv2.INTER_LINEAR)
 
296
  if SEG_EXPECTS_RGB:
297
  resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
 
298
  if SEG_NORM.lower() == "imagenet":
299
  x = _imagenet_norm(resized)
300
  else:
301
  x = resized.astype(np.float32) / 255.0
 
302
  x = np.expand_dims(x, axis=0) # (1,H,W,3)
303
  return x
304
 
305
  def _to_prob(pred: np.ndarray) -> np.ndarray:
 
306
  p = np.squeeze(pred)
 
307
  pmin, pmax = float(p.min()), float(p.max())
308
  if pmax > 1.0 or pmin < 0.0:
309
  p = 1.0 / (1.0 + np.exp(-p))
310
  return p.astype(np.float32)
311
 
312
+ # ---- Robust mask post-processing (for "proper" masking) ----
313
+ def _fill_holes(mask01: np.ndarray) -> np.ndarray:
314
+ # Flood-fill from border, then invert
315
+ h, w = mask01.shape[:2]
316
+ ff = np.zeros((h + 2, w + 2), np.uint8)
317
+ m = (mask01 * 255).astype(np.uint8).copy()
318
+ cv2.floodFill(m, ff, (0, 0), 255)
319
+ m_inv = cv2.bitwise_not(m)
320
+ # Combine original with filled holes
321
+ out = ((mask01 * 255) | m_inv) // 255
322
+ return out.astype(np.uint8)
323
+
324
  # Global last debug dict (per-process) to attach into results
325
  _last_seg_debug: Dict[str, object] = {}
326
 
 
362
  _log_kv("SEG_PROB_STATS", {"min": pmin, "max": pmax, "mean": pmean})
363
 
364
  if SMARTHEAL_DEBUG:
 
365
  hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
366
  heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
367
  heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
 
369
 
370
  # Threshold
371
  thr = SEG_THRESH
372
+ mask = (p >= thr).astype(np.uint8) # 0/1
373
+ pos = int(mask.sum())
374
  frac = pos / float(mask.size)
375
  logging.info(f"SegModel USED | thr={thr} pos_px={pos} pos_frac={frac:.4f} ex_rgb={SEG_EXPECTS_RGB} norm={SEG_NORM}")
376
 
 
387
  "heatmap_path": heatmap_path,
388
  "roi_seen_by_model": saw_roi_path,
389
  }
390
+ return (mask * 255).astype(np.uint8), _last_seg_debug
391
 
392
  except Exception as e:
393
  reason = f"model_failed: {e}"
 
400
  centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
401
  centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
402
  wound_idx = int(np.argmax(centers_lab[:, 1])) # maximize a* (redness)
403
+ mask = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8)
404
 
405
+ pos = int(mask.sum()); frac = pos / float(mask.size)
406
  logging.info(f"KMeans USED | pos_px={pos} pos_frac={frac:.4f}")
407
 
408
  _last_seg_debug = {
 
414
  "heatmap_path": heatmap_path,
415
  "roi_seen_by_model": saw_roi_path,
416
  }
417
+ return (mask * 255).astype(np.uint8), _last_seg_debug
418
 
419
  # ---------- Measurement + overlay helpers ----------
420
  def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
 
427
  largest_idx = 1 + int(np.argmax(areas))
428
  return (labels == largest_idx).astype(np.uint8)
429
 
430
+ def _clean_mask(mask01: np.ndarray) -> np.ndarray:
431
+ """Open→Close→Fill holes→Largest component."""
432
+ if mask01.dtype != np.uint8:
433
+ mask01 = mask01.astype(np.uint8)
434
+ k = np.ones((3, 3), np.uint8)
435
+ mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, k, iterations=1)
436
+ mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, k, iterations=2)
437
+ mask01 = _fill_holes(mask01)
438
+ mask01 = largest_component_mask(mask01, min_area_px=30)
439
+ return (mask01 > 0).astype(np.uint8)
440
+
441
  def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
442
  contours, _ = cv2.findContours(mask01.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
443
  if not contours:
 
463
  breadth_cm: float,
464
  thickness: int = 2
465
  ) -> np.ndarray:
466
+ """
467
+ Draws:
468
+ 1) Strong red mask overlay with white contour.
469
+ 2) Min-area rectangle.
470
+ 3) Two double-headed arrows:
471
+ - 'Length' along the longer side.
472
+ - 'Width' along the shorter side.
473
+ """
474
  overlay = base_bgr.copy()
475
 
476
+ # --- Strong overlay from mask (tinted red where mask==1) ---
477
+ mask255 = (mask01 * 255).astype(np.uint8)
478
+ mask3 = cv2.merge([mask255, mask255, mask255])
479
  red = np.zeros_like(overlay); red[:] = (0, 0, 255)
480
  alpha = 0.55
481
  tinted = cv2.addWeighted(overlay, 1 - alpha, red, alpha, 0)
482
+ overlay = np.where(mask3 > 0, tinted, overlay)
 
483
 
484
+ # Draw wound contour
485
+ cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
486
  if cnts:
487
  cv2.drawContours(overlay, cnts, -1, (255, 255, 255), 2)
488
 
489
  if rect_box is not None:
490
  cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
491
  pts = rect_box.reshape(-1, 2)
 
 
 
 
 
492
 
493
+ def midpoint(a, b):
494
+ return (int((a[0] + b[0]) / 2), int((a[1] + b[1]) / 2))
495
+
496
+ # Edge lengths
497
+ e = [np.linalg.norm(pts[i] - pts[(i + 1) % 4]) for i in range(4)]
498
+ long_edge_idx = int(np.argmax(e))
499
+ short_edge_idx = (long_edge_idx + 1) % 2 # 0/1 map for pairs below
500
+
501
+ # Midpoints of opposite edges for arrows
502
+ mids = [midpoint(pts[i], pts[(i + 1) % 4]) for i in range(4)]
503
+ # Long side uses edges long_edge_idx and the opposite edge (i+2)
504
+ long_pair = (long_edge_idx, (long_edge_idx + 2) % 4)
505
+ # Short side uses the other pair
506
+ short_pair = ((long_edge_idx + 1) % 4, (long_edge_idx + 3) % 4)
507
+
508
+ def draw_double_arrow(img, p1, p2):
509
  cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
510
  cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
511
  cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
512
  cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)
513
 
514
+ def put_label(text, anchor):
515
+ org = (anchor[0] + 6, anchor[1] - 6)
516
+ cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
517
+ cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
 
518
 
519
+ # Draw arrows and labels
520
+ draw_double_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
521
+ draw_double_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
522
  put_label(f"Length: {length_cm:.2f} cm", mids[long_pair[0]])
523
+ put_label(f"Width: {breadth_cm:.2f} cm", mids[short_pair[0]])
524
+
525
  return overlay
526
 
527
  # ---------- AI PROCESSOR ----------
 
540
 
541
  def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
542
  """
543
+ YOLO detect → crop ROI → segment_wound(ROI) → clean mask
544
  minAreaRect measurement (cm) using EXIF px/cm → save outputs.
545
  """
546
  try:
 
578
  mask_u8_255, seg_debug = segment_wound(roi, ts, out_dir)
579
  mask01 = (mask_u8_255 > 127).astype(np.uint8)
580
 
581
+ # Robust post-processing to ensure "proper" masking
582
  if mask01.any():
583
+ mask01 = _clean_mask(mask01)
584
+ logging.debug(f"Mask postproc: px_after={int(mask01.sum())}")
 
 
 
585
 
586
  # --- Measurement ---
587
  if mask01.any():
588
  length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
589
  surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
590
+ # Final annotated ROI with mask + arrows + labels
591
  anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
592
  segmentation_empty = False
593
  else:
594
+ # Graceful fallback if seg failed: use ROI box as bounds
595
  h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
596
+ length_cm = round(max(h_px, w_px) / px_per_cm, 2)
597
+ breadth_cm = round(min(h_px, w_px) / px_per_cm, 2)
598
  surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
599
  anno_roi = roi.copy()
600
  cv2.rectangle(anno_roi, (2, 2), (anno_roi.shape[1]-3, anno_roi.shape[0]-3), (0, 0, 255), 3)
 
615
  roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
616
  cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))
617
 
618
+ # ROI overlay (clear mask w/ white contour, no arrows)
619
  mask255 = (mask01 * 255).astype(np.uint8)
620
  mask3 = cv2.merge([mask255, mask255, mask255])
621
  red = np.zeros_like(roi); red[:] = (0, 0, 255)
 
626
  cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
627
  cv2.drawContours(roi_overlay, cnts, -1, (255, 255, 255), 2)
628
  else:
629
+ roi_overlay = anno_roi
630
 
631
  seg_full = image_cv.copy()
632
  seg_full[y1:y2, x1:x2] = roi_overlay
 
636
  segmentation_roi_path = os.path.join(out_dir, f"segmentation_roi_{ts}.png")
637
  cv2.imwrite(segmentation_roi_path, roi_overlay)
638
 
639
+ # Annotated (mask + arrows + labels) in full-frame
640
  anno_full = image_cv.copy()
641
  anno_full[y1:y2, x1:x2] = anno_roi
642
  annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")