| """Extract patch-level features from a single WSI. |
| |
| This script reads patch coordinates from an HDF5 file (typically produced by |
| `exaonepath.patches.patchfy`) and runs the EXAONE Path patch encoder on the |
| corresponding image regions to produce patch embeddings. |
| |
| Input H5 (`--coords_h5_path`) requirements |
| ----------------------------------------- |
| - Dataset: `coords` (shape: [N, 2]) containing (x, y) coordinates in **level-0** pixel space. |
| - Attribute on `coords`: `patch_size` (int). If missing, defaults to 256. |
| - Optional dataset: `contour_index` (shape: [N]). If missing, all patches are treated as contour 0. |
| |
| Output H5 (`--out_h5_path`) keys |
| -------------------------------- |
| - `features`: [N, C] patch embeddings |
| - `coords`: [N, 2] coordinates (copied from input) |
| - `contour_index`: [N] contour indices (copied from input or synthesized) |
| |
| Notes |
| ----- |
| - CUDA is required. |
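| - This implementation loads all extracted features into memory before writing. |
| - Patches that fail to read or transform are skipped (with a printed error), so the output may contain fewer than N rows. |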
| """ |
|
|
| import argparse |
| import os |
|
|
| import h5py |
| import numpy as np |
| import openslide |
| import torch |
| import torch.backends.cudnn as cudnn |
| from torch.amp import autocast |
| from torchvision import transforms |
|
|
| from exaonepath.models.patch_encoder_hf import PatchEncoder |
|
|
|
|
def _open_coords_h5(coords_h5_path):
    """Load patch coordinates and their metadata from a coordinates H5 file.

    Returns a ``(coords, patch_size, contour_index)`` tuple:

    - ``coords``: the ``coords`` dataset as a NumPy array.
    - ``patch_size``: the ``patch_size`` attribute of the ``coords`` dataset
      as an int, or 256 when the attribute is absent.
    - ``contour_index``: the ``contour_index`` dataset as a NumPy array, or an
      all-zero int32 array with one entry per coordinate when the file has no
      such dataset.

    When the file does not exist or has no ``coords`` dataset, a message is
    printed and ``(None, None, None)`` is returned instead of raising.
    """
    failure = (None, None, None)

    if not os.path.exists(coords_h5_path):
        print(f"Coords H5 file not found: {coords_h5_path}")
        return failure

    with h5py.File(coords_h5_path, "r") as h5_file:
        if "coords" not in h5_file:
            print(f"Invalid coords H5 (missing 'coords'): {coords_h5_path}")
            return failure

        coords_ds = h5_file["coords"]
        coords = np.array(coords_ds)
        patch_size = int(coords_ds.attrs.get("patch_size", 256))

        # Older/simpler coordinate files may not carry contour ids; treat every
        # patch as belonging to contour 0 in that case.
        contour_index = (
            np.array(h5_file["contour_index"])
            if "contour_index" in h5_file
            else np.zeros((len(coords),), dtype=np.int32)
        )

    return coords, patch_size, contour_index
|
|
|
|
def _save_slide_features(out_h5_path, feat_list, coord_list, contour_list):
    """Concatenate per-batch results and write them to one H5 file.

    Each of ``feat_list``, ``coord_list`` and ``contour_list`` holds one
    tensor per batch. The tensors are concatenated along dim 0, converted to
    NumPy arrays and stored under the keys ``features``, ``coords`` and
    ``contour_index``. The parent directory of ``out_h5_path`` is created when
    missing; the file itself is opened in ``"w"`` mode.
    """
    out_dir = os.path.dirname(os.path.abspath(out_h5_path))
    os.makedirs(out_dir, exist_ok=True)

    # Build every array before the file is opened, matching the original
    # order of operations (concatenate first, then write).
    arrays = {
        "features": torch.cat(feat_list, dim=0).numpy(),
        "coords": torch.cat(coord_list, dim=0).numpy(),
        "contour_index": torch.cat(contour_list, dim=0).numpy(),
    }

    with h5py.File(out_h5_path, "w") as h5_out:
        for key, array in arrays.items():
            h5_out.create_dataset(key, data=array)
|
|
|
|
def _encode_batch(model, batch_images, device):
    """Encode a list of transformed patch tensors; return float32 features on the CPU."""
    imgs_tensor = torch.stack(batch_images).to(device, non_blocking=True)
    with torch.inference_mode(), autocast("cuda", torch.bfloat16):
        features = model(imgs_tensor)
    # Cast back from the autocast dtype before moving off the GPU.
    return features.detach().float().cpu()


def process_single_slide(
    slide_path,
    coords_h5_path,
    out_h5_path,
    model,
    transform,
    batch_size_per_gpu=32,
):
    """Extract features for every patch of one slide and save them to an H5 file.

    Coordinates, patch size and contour indices are loaded from
    ``coords_h5_path``. Each patch is read from the slide at level 0 with size
    ``(patch_size, patch_size)``, converted to RGB, passed through
    ``transform`` and encoded by ``model`` in batches on the ``"cuda"``
    device. Results are written to ``out_h5_path`` when at least one patch
    was encoded.

    Args:
        slide_path: Path of the slide, opened with ``openslide.OpenSlide``.
        coords_h5_path: Path of the coordinates H5 file.
        out_h5_path: Destination H5 path for features, coords and contour ids.
        model: Callable mapping a stacked image tensor to a feature tensor.
        transform: Callable mapping a PIL image to a tensor.
        batch_size_per_gpu: Number of patches encoded per forward pass.

    Returns:
        None. The function returns early (after printing a message) when the
        coordinates cannot be loaded or are empty.

    Patches whose read or transform raises are reported and skipped, so the
    output may contain fewer rows than the input has coordinates.
    """
    device = "cuda"
    print(f"Processing slide: {slide_path}")
    print(f"Coords H5: {coords_h5_path}")
    print(f"Output H5: {out_h5_path}")

    coords, patch_size, contour_index = _open_coords_h5(coords_h5_path)
    if coords is None or contour_index is None:
        print("No coordinates loaded; aborting.")
        return

    total_patches = len(coords)
    if total_patches == 0:
        print("No coordinates found (N=0); nothing to do.")
        return

    batch_images, batch_coords, batch_contours = [], [], []
    feat_list, coord_list, contour_list = [], [], []

    def flush_batch():
        """Encode the pending patches, record the results and reset the batch."""
        feat_list.append(_encode_batch(model, batch_images, device))
        coord_list.append(torch.tensor(np.array(batch_coords)))
        contour_list.append(torch.tensor(np.array(batch_contours)))
        batch_images.clear()
        batch_coords.clear()
        batch_contours.clear()

    # The context manager closes the slide even if encoding raises.
    with openslide.OpenSlide(slide_path) as wsi:
        for j, (coord, contour_val) in enumerate(zip(coords, contour_index)):
            x, y = int(coord[0]), int(coord[1])

            try:
                patch = wsi.read_region((x, y), 0, (patch_size, patch_size)).convert("RGB")
                patch = transform(patch)
            except Exception as e:
                # Deliberate best effort: one unreadable patch must not abort
                # the whole slide, so report it and move on.
                print(
                    f"Error extracting patch at index {j} (coord: {coord}) from slide {slide_path}: {e}"
                )
            else:
                batch_images.append(patch)
                batch_coords.append(coord)
                batch_contours.append(contour_val)
                if len(batch_images) >= batch_size_per_gpu:
                    flush_batch()

            # Reported for skipped patches too, so the final count is always shown.
            if (j + 1) % 1000 == 0 or (j + 1) == total_patches:
                print(f"Processed {j + 1}/{total_patches} patches...")

    # Flush the tail batch here rather than on "last index" inside the loop:
    # if the last patch fails to read, the in-loop check would never fire and
    # the pending patches would be lost.
    if batch_images:
        flush_batch()

    if feat_list:
        _save_slide_features(out_h5_path, feat_list, coord_list, contour_list)
    else:
        print(f"No features extracted for {slide_path}")
|
|
|
|
def main(args):
    """Load the patch encoder and extract features for the slide named in ``args``.

    Reads ``repo_id``, ``slide_path``, ``coords_h5_path``, ``out_h5_path`` and
    ``batch_size_per_gpu`` from ``args``. When ``out_h5_path`` is empty the
    output path is the coordinates path with its extension replaced by
    ``_features.h5``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        raise RuntimeError(
            "CUDA is required for this script, but torch.cuda.is_available() is False"
        )

    device = "cuda"
    cudnn.benchmark = True

    print("Loading model...")
    model = PatchEncoder.from_pretrained(args.repo_id)
    model.to(device).eval()

    # Preprocessing applied to every patch before it is encoded.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # An empty --out_h5_path means "derive the name from the coords file".
    out_h5_path = args.out_h5_path or (
        os.path.splitext(args.coords_h5_path)[0] + "_features.h5"
    )

    print("Extracting patch features ...")
    process_single_slide(
        slide_path=args.slide_path,
        coords_h5_path=args.coords_h5_path,
        out_h5_path=out_h5_path,
        model=model,
        transform=transform,
        batch_size_per_gpu=args.batch_size_per_gpu,
    )
|
|
|
|
def _build_arg_parser():
    """Build the command-line parser for single-slide patch-feature extraction."""
    # The title is passed as ``description``; given positionally it would be
    # taken as ``prog`` and replace the script name in the usage line.
    parser = argparse.ArgumentParser(
        description="Single WSI patch-feature extraction",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--slide_path",
        type=str,
        required=True,
        help="Path to a whole-slide image file (.svs/.tif/.tiff/.ndpi/.mrxs/...).",
    )
    parser.add_argument(
        "--coords_h5_path",
        type=str,
        required=True,
        help=(
            "Path to the coordinates HDF5 produced by `patchfy` (must contain dataset 'coords' "
            "and ideally 'contour_index')."
        ),
    )
    parser.add_argument(
        "--out_h5_path",
        type=str,
        default="",
        help=(
            "Output HDF5 path for patch features. If empty, defaults to the coords H5 path "
            "with its extension replaced by '_features.h5'."
        ),
    )
    parser.add_argument(
        "--batch_size_per_gpu",
        type=int,
        default=32,
        help="Batch size for patch encoder inference on a single GPU.",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="anonymous-bio/EXAONE-Path-2.5",
        help="Hugging Face repo id containing the EXAONE Path 2.5 patch encoder.",
    )
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    main(args)
|
|