"""WoundNetB7 multiclass segmentation model — 4 classes (bg, foot, perilesion, ulcer).

Architecture: EfficientNet-B7 encoder + ASPP + CBAM + TAM + UNet decoder.
Checkpoint: Track B multiclass, ulcer Dice = 0.927 (bootstrap CI: [0.917, 0.936]).

Note: the default pipeline consumes the raw model output, which works best for
the majority of images; post-processing (``postprocess_segmentation``) is opt-in.
"""
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
IMG_SIZE = 512
MEAN = np.array([0.485, 0.456, 0.406])  # ImageNet normalization stats
STD = np.array([0.229, 0.224, 0.225])

CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}
CLASS_COLORS = {
    0: (0, 0, 0),
    1: (0, 255, 0),
    2: (255, 165, 0),
    3: (255, 0, 0),
}
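
# CLASS_COLORS is not consumed anywhere else in this module. A minimal sketch of a
# visualization helper that uses it, assuming the tuples are RGB (the name
# `colorize_classmap` is ours, not part of the original pipeline):
def colorize_classmap(classmap: np.ndarray) -> np.ndarray:
    """Map an (H, W) class-index array to an (H, W, 3) RGB image via CLASS_COLORS."""
    out = np.zeros((*classmap.shape, 3), dtype=np.uint8)
    for cid, color in CLASS_COLORS.items():
        out[classmap == cid] = color
    return out
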
# ---------------------------------------------------------------------------
# Architecture modules (match checkpoint weights exactly)
# ---------------------------------------------------------------------------
class ChannelAttention(nn.Module):
    """Channel attention: a shared MLP over global avg- and max-pooled features."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
        )

    def forward(self, x):
        avg_out = self.mlp(x.mean(dim=[2, 3]))
        max_out = self.mlp(x.amax(dim=[2, 3]))
        attn = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
        return x * attn


class SpatialAttention(nn.Module):
    """Spatial attention: a conv over channel-wise mean and max maps."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg_out = x.mean(dim=1, keepdim=True)
        max_out = x.amax(dim=1, keepdim=True)
        attn = torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return x * attn


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.sa(self.ca(x))
class DifferentiableFractalDimension(nn.Module):
    """Soft box-counting estimate of fractal dimension (fully differentiable)."""

    def __init__(self, scales=None):
        super().__init__()
        self.scales = scales or [2, 4, 8, 16, 32]

    def forward(self, x):
        B, C, H, W = x.shape
        counts = []
        for s in self.scales:
            if s >= H or s >= W:
                continue
            pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
            # Soft occupancy: sigmoid approximates the "box is occupied" indicator.
            n_boxes = torch.sigmoid(10.0 * (pooled - 0.1)).sum(dim=[2, 3])
            counts.append(n_boxes)
        if len(counts) < 2:
            # Not enough scales for a regression; return a neutral (B, 1, 1, 1) value
            # so the caller's expand(B, 1, H, W) still works.
            return torch.ones(B, 1, 1, 1, device=x.device)
        log_s = torch.log(torch.tensor([float(s) for s in self.scales[: len(counts)]], device=x.device))
        log_c = torch.stack([torch.log(c + 1) for c in counts], dim=-1)
        # Least-squares slope of log(count) vs log(scale):
        #   slope = (n*Sxy - Sx*Sy) / (n*Sxx - Sx^2)
        n = log_s.shape[0]
        sx, sxx = log_s.sum(), (log_s**2).sum()
        sy = log_c.sum(dim=-1)
        sxy = (log_c * log_s.unsqueeze(0).unsqueeze(0)).sum(dim=-1)
        slope = (n * sxy - sx * sy) / (n * sxx - sx**2 + 1e-8)
        # Fractal dimension is the negative slope; average over channels -> (B, 1, 1, 1).
        return -slope.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)
class DifferentiableEulerCharacteristic(nn.Module):
    """Soft Euler characteristic chi = V - E + F of a binarized activation map."""

    def forward(self, x):
        B, C, H, W = x.shape
        b = torch.sigmoid(10.0 * (torch.sigmoid(x) - 0.5))  # soft binarization
        V = b.sum(dim=[2, 3])                                       # vertices (pixels)
        E_h = (b[:, :, :, :-1] * b[:, :, :, 1:]).sum(dim=[2, 3])    # horizontal edges
        E_v = (b[:, :, :-1, :] * b[:, :, 1:, :]).sum(dim=[2, 3])    # vertical edges
        F_val = (b[:, :, :-1, :-1] * b[:, :, :-1, 1:] * b[:, :, 1:, :-1] * b[:, :, 1:, 1:]).sum(dim=[2, 3])  # unit squares
        euler = V - E_h - E_v + F_val
        return euler.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) / (H * W)
class TopologicalAttentionModule(nn.Module):
    """TAM: modulates features with global fractal-dimension and Euler-characteristic cues."""

    def __init__(self, in_channels):
        super().__init__()
        self.fractal = DifferentiableFractalDimension()
        self.euler = DifferentiableEulerCharacteristic()
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(1.0))
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + 2, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        fm = self.fractal(x).expand(B, 1, H, W)  # broadcast the scalar cue to a full map
        em = self.euler(x).expand(B, 1, H, W)
        attn = self.conv(torch.cat([x, self.alpha * fm, self.beta * em], dim=1))
        return x * attn + x  # residual connection
class ASPP(nn.Module):
    def __init__(self, in_ch, out_ch, rates=None):
        super().__init__()
        rates = rates or [6, 12, 18]
        self.conv1x1 = nn.Sequential(nn.Conv2d(in_ch, out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True))
        self.atrous = nn.ModuleList(
            [nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r), nn.BatchNorm2d(out_ch), nn.ReLU(True)) for r in rates]
        )
        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_ch, out_ch, 1), nn.ReLU(True))
        self.project = nn.Sequential(
            nn.Conv2d(out_ch * (2 + len(rates)), out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True), nn.Dropout(0.5)
        )

    def forward(self, x):
        size = x.shape[2:]
        feats = [self.conv1x1(x)] + [a(x) for a in self.atrous]
        feats.append(F.interpolate(self.pool(x), size=size, mode="bilinear", align_corners=False))
        return self.project(torch.cat(feats, dim=1))
class WoundNetB7(nn.Module):
    """WoundNetB7 matching the Track B checkpoint structure."""

    NUM_CLASSES = 4

    def __init__(self, num_classes=4):
        super().__init__()
        self.backbone = smp.Unet(encoder_name="efficientnet-b7", encoder_weights=None, in_channels=3, classes=num_classes)
        enc_ch = self.backbone.encoder.out_channels[-1]
        self.aspp = ASPP(enc_ch, enc_ch)
        self.cbam = CBAM(num_classes, reduction=max(1, num_classes // 2))
        self.tam = TopologicalAttentionModule(num_classes)
        self.diffusion_weight = nn.Parameter(torch.tensor(0.01))

    def forward(self, x):
        features = list(self.backbone.encoder(x))
        features[-1] = self.aspp(features[-1])
        try:
            dec = self.backbone.decoder(features)  # newer smp versions take the feature list
        except TypeError:
            dec = self.backbone.decoder(*features)  # older smp versions unpack the features
        seg = self.backbone.segmentation_head(dec)
        seg = self.cbam(seg)
        seg = self.tam(seg)
        return seg
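
# Quick shape sanity check (a sketch, not part of the original file; runs with
# random weights, no checkpoint needed):
#
#     model = WoundNetB7().eval()
#     with torch.no_grad():
#         out = model(torch.randn(1, 3, IMG_SIZE, IMG_SIZE))
#     assert out.shape == (1, 4, IMG_SIZE, IMG_SIZE)
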
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def preprocess(img_bgr: np.ndarray) -> torch.Tensor:
    """BGR image -> normalized CHW tensor (1, 3, 512, 512)."""
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    img = (img.astype(np.float32) / 255.0 - MEAN) / STD
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()
def tta_inference(model: nn.Module, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """6-fold TTA -> averaged softmax probabilities (1, C, H, W)."""
    transforms = [
        lambda x: x,
        lambda x: torch.flip(x, [3]),  # horizontal flip
        lambda x: torch.flip(x, [2]),  # vertical flip
        lambda x: torch.rot90(x, 1, [2, 3]),
        lambda x: torch.rot90(x, 2, [2, 3]),
        lambda x: torch.rot90(x, 3, [2, 3]),
    ]
    # Each inverse undoes the transform at the same index (flips are involutions;
    # rot90 by k and by 4-k cancel).
    inverse = [
        lambda x: x,
        lambda x: torch.flip(x, [3]),
        lambda x: torch.flip(x, [2]),
        lambda x: torch.rot90(x, 3, [2, 3]),
        lambda x: torch.rot90(x, 2, [2, 3]),
        lambda x: torch.rot90(x, 1, [2, 3]),
    ]
    probs_sum = None
    with torch.no_grad():
        for tfm, inv in zip(transforms, inverse):
            out = model(tfm(img_tensor).to(device))
            if isinstance(out, (tuple, list)):
                out = out[0]
            if isinstance(out, dict):
                out = out["seg"]
            p = inv(F.softmax(out, dim=1))
            probs_sum = p if probs_sum is None else probs_sum + p
    return probs_sum / len(transforms)
def load_segmentation_model(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Load WoundNetB7 from checkpoint."""
    model = WoundNetB7(num_classes=4)
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Remove PWAT head keys if present
    state = {k: v for k, v in state.items() if not k.startswith("pwat_head.")}
    model.load_state_dict(state, strict=False)
    model.to(device).eval()
    return model
def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta: bool = True) -> dict:
    """Run segmentation on a BGR image.

    Returns dict with:
        classmap: (H, W) uint8 with class indices 0-3
        masks: dict of per-class binary masks {cls_name: (H, W) bool}
        probs: (4, H, W) float32 softmax probabilities
    """
    h, w = img_bgr.shape[:2]
    tensor = preprocess(img_bgr)
    if use_tta:
        probs = tta_inference(model, tensor, device)
    else:
        with torch.no_grad():
            out = model(tensor.to(device))
            if isinstance(out, (tuple, list)):
                out = out[0]
            if isinstance(out, dict):
                out = out["seg"]
            probs = F.softmax(out, dim=1)
    probs_np = probs[0].cpu().numpy()
    probs_resized = np.stack([cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(4)])
    classmap = probs_resized.argmax(axis=0).astype(np.uint8)
    masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
    return {"classmap": classmap, "masks": masks, "probs": probs_resized}
def postprocess_segmentation(
    classmap: np.ndarray,
    img_bgr: np.ndarray,
    min_foot_ratio: float = 0.01,
    dark_l_threshold: float = 15.0,
) -> np.ndarray:
    """Post-process segmentation with necrotic tissue recovery.

    Steps:
        1. Keep only the largest connected component of foreground.
        2. Exclude dark pixels NOT in the main connected component.
        3. RECOVER necrotic tissue: dark regions adjacent to the detected foot
           that the model missed are reclassified as ulcer (class 3).
        4. Light morphological closing to smooth edges (no opening — preserves
           thin structures like toes).
    """
    h, w = classmap.shape
    cleaned = classmap.copy()

    # Step 1: Largest connected component of foreground
    foreground = (cleaned > 0).astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(foreground, connectivity=8)
    main_component_mask = np.zeros((h, w), dtype=bool)
    if num_labels > 1:
        areas = stats[1:, cv2.CC_STAT_AREA]
        largest_label = np.argmax(areas) + 1
        main_component_mask = labels == largest_label
        min_area = h * w * min_foot_ratio
        for label_id in range(1, num_labels):
            if label_id == largest_label:
                continue
            if stats[label_id, cv2.CC_STAT_AREA] < min_area:
                cleaned[labels == label_id] = 0
    else:
        main_component_mask = foreground.astype(bool)

    # Step 2: Dark pixel exclusion — ONLY for disconnected blobs
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    l_channel = lab[:, :, 0] * (100.0 / 255.0)  # rescale OpenCV's 0-255 L to the 0-100 L* range
    a_channel = lab[:, :, 1] - 128.0            # recenter a* around 0
    s_channel = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)[:, :, 1]
    dark_mask = l_channel < dark_l_threshold
    is_foreground = cleaned > 0
    dark_isolated = dark_mask & is_foreground & ~main_component_mask
    cleaned[dark_isolated] = 0

    # Step 3: Necrotic tissue recovery
    # Dark skin-like regions adjacent to detected foot → reclassify as ulcer
    cleaned = recover_necrotic_tissue(cleaned, img_bgr, l_channel, a_channel, s_channel)

    # Step 4: Light morphological closing (fills small gaps, does NOT erode thin structures)
    kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    for cid in [1, 2, 3]:
        class_mask = (cleaned == cid).astype(np.uint8)
        if np.sum(class_mask) < 50:
            continue
        closed = cv2.morphologyEx(class_mask, cv2.MORPH_CLOSE, kernel_close)
        # Only ADD pixels (fill gaps), never remove
        new_pixels = (closed > 0) & (class_mask == 0) & (cleaned == 0)
        cleaned[new_pixels] = cid
    return cleaned
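
# Per the module note above, `segment` returns raw model output and does not call
# `postprocess_segmentation`; post-processing is opt-in for the caller. A minimal
# sketch (variable names are illustrative):
#
#     res = segment(model, img_bgr, device)
#     cleaned = postprocess_segmentation(res["classmap"], img_bgr)
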
def recover_necrotic_tissue(
    classmap: np.ndarray,
    img_bgr: np.ndarray,
    l_channel: np.ndarray,
    a_channel: np.ndarray,
    s_channel: np.ndarray,
    necrotic_l_max: float = 45.0,
    necrotic_s_max: float = 120.0,
    min_region_px: int = 100,
) -> np.ndarray:
    """Recover dark necrotic tissue regions adjacent to detected foreground.

    Necrotic tissue (eschar, gangrene, dry/wet gangrene on toes) is very dark
    and the model often misclassifies it as background. This function uses
    iterative dilation to progressively recover necrotic regions connected
    to the foot, even when there's a gap between the detected foot and the toes.

    Detection criteria for necrotic candidate pixels:
        - L* < 45 (dark tissue — covers eschar, gangrene, necrotic toes)
        - Saturation < 120 (not vividly colored — rules out green/blue backgrounds)
        - Currently classified as background (class 0)

    Iterative approach: dilate the foreground progressively (3 rounds x 30 px),
    recovering necrotic candidates at each step. This bridges gaps between
    the detected foot and disconnected necrotic regions like toes.
    """
    h, w = classmap.shape
    recovered = classmap.copy()

    # Candidate necrotic pixels: dark, not vivid, currently background
    is_background = recovered == 0
    necrotic_candidates = (
        is_background
        & (l_channel < necrotic_l_max)
        & (s_channel < necrotic_s_max)
    )
    if not np.any(necrotic_candidates):
        return recovered

    # Iterative recovery: progressively expand from detected foreground.
    # Each round dilates 30 px and recovers adjacent necrotic tissue,
    # then the recovered tissue becomes part of the foreground for the next round.
    # 3 rounds × 30 px = up to 90 px reach from the original foreground edge.
    dilation_step = 30
    num_rounds = 3
    current_foreground = (recovered > 0).astype(np.uint8)
    for _ in range(num_rounds):
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_step, dilation_step))
        fg_dilated = cv2.dilate(current_foreground, kernel).astype(bool)
        # Candidates that are within reach this round
        adjacent = necrotic_candidates & fg_dilated & (recovered == 0)
        if not np.any(adjacent):
            break
        # Connected-component filtering: drop speckle regions below min_region_px
        adjacent_u8 = adjacent.astype(np.uint8)
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(adjacent_u8, connectivity=8)
        recovered_any = False
        for label_id in range(1, num_labels):
            area = stats[label_id, cv2.CC_STAT_AREA]
            if area < min_region_px:
                continue
            region_mask = labels == label_id
            # Verify it touches current foreground
            region_dilated = cv2.dilate(
                region_mask.astype(np.uint8),
                cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)),
            )
            if np.any((region_dilated > 0) & (current_foreground > 0)):
                recovered[region_mask] = 3  # Ulcer (necrotic)
                recovered_any = True
        if not recovered_any:
            break
        # Update foreground for next round (include newly recovered tissue)
        current_foreground = (recovered > 0).astype(np.uint8)
    return recovered
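
# Minimal end-to-end usage sketch. The checkpoint and image paths below are
# placeholders (assumptions), not files shipped with this module; the overlay
# helper `colorize_classmap` is the sketch defined near the top of the file.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path("checkpoints/woundnet_b7_track_b.pth")  # hypothetical path
    model = load_segmentation_model(str(ckpt), device)
    img_bgr = cv2.imread("example_foot.jpg")  # hypothetical test image
    assert img_bgr is not None, "test image not found"
    result = segment(model, img_bgr, device, use_tta=True)
    print("ulcer pixels:", int(result["masks"]["ulcer"].sum()))
    overlay = colorize_classmap(result["classmap"])
    cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))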