| """ |
| Detect and remove constant-value artifact planes at volume boundaries. |
| |
| Interpolation during preprocessing can introduce planes filled with a single |
| non-zero constant value (e.g. 8.0 for CT) at the start or end of each spatial |
| axis. This script: |
| |
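| 1. Scans all .nii.gz files in --image_dir (recursively with --recursive, |
| skipping segmentation/label files); matching labels are looked up in |
| --label_dir, or in a sibling segmentation/ directory with --recursive. |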
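| 2. For each image, identifies boundary planes that contain a single constant |
| value occurring in at most 1% of the voxels of the nearest non-constant |
| interior plane; constant planes whose value also fills the interior (e.g. |
| zero or air background) are kept. |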
| 3. Crops the artifact planes from image AND matching label (if present), |
| shifting the origin so the remaining voxels keep their physical positions. |
| 4. Overwrites in-place (use --dry-run to preview without writing). |
| |
| Usage: |
| # Dry-run: report artifacts without modifying files |
| python clean_artifact_planes.py \ |
| --image_dir /path/to/MSD_processed/images \ |
| --label_dir /path/to/MSD_processed/labels \ |
| --dry-run |
| |
| # Actually clean: |
| python clean_artifact_planes.py \ |
| --image_dir /path/to/MSD_processed/images \ |
| --label_dir /path/to/MSD_processed/labels |
| """ |
| import os |
| import glob |
| import argparse |
| import numpy as np |
| import SimpleITK as sitk |
| from tqdm import tqdm |
|
|
|
|
| def _get_plane(arr, axis, idx): |
| """Extract a single plane from the array along the given axis.""" |
| slc = [slice(None)] * arr.ndim |
| slc[axis] = idx |
| return arr[tuple(slc)] |
|
|
|
|
def find_artifact_slices(arr, axis, max_search=20, *, match_threshold=0.01, atol=1e-6):
    """Count contiguous constant-value artifact planes at both ends of ``axis``.

    A boundary plane counts as an artifact when:
      - it contains exactly one unique value, AND
      - that value is foreign to the reference plane for that end, i.e. it
        matches (within ``atol``) at most ``match_threshold`` of the
        reference plane's voxels.

    The reference plane is the first non-constant plane found within
    ``max_search + 5`` planes of that end (not necessarily the adjacent
    plane). This keeps legitimate background planes (e.g. -300 in CT air, or
    zero background) whose value also fills the interior. If no non-constant
    plane is found, the boundary plane itself is used as reference; it always
    matches itself, so nothing is trimmed on that side.

    Scanning from each end stops at the first non-artifact plane. At most
    ``max_search`` planes are tested per side and the plane at index
    ``n // 2`` is never tested, so at least one plane always remains.

    Args:
        arr: array to inspect.
        axis: axis along which planes are taken.
        max_search: maximum number of boundary planes tested per side.
        match_threshold: largest fraction of reference voxels that may equal
            the constant value for the plane to still count as an artifact.
        atol: absolute tolerance used when comparing voxel values.

    Returns:
        ``(n_start, n_end)``: number of planes to trim from the start and the
        end of ``axis``.
    """
    n = arr.shape[axis]

    def _is_artifact(idx, ref_idx):
        plane = _get_plane(arr, axis, idx)
        unique = np.unique(plane)
        if len(unique) != 1:
            return False
        val = float(unique[0])
        reference = _get_plane(arr, axis, ref_idx)
        # Fraction of reference voxels that carry the plane's constant value.
        match_ratio = np.mean(np.abs(reference - val) < atol)
        return bool(match_ratio <= match_threshold)

    def _find_reference(start, stop, step):
        # First non-constant plane in range(start, stop, step); fall back to
        # `start` itself, which makes _is_artifact() reject that boundary.
        for idx in range(start, stop, step):
            if len(np.unique(_get_plane(arr, axis, idx))) > 1:
                return idx
        return start

    ref_start = _find_reference(0, min(max_search + 5, n), 1)
    ref_end = _find_reference(n - 1, max(n - 1 - max_search - 5, -1), -1)

    n_start = 0
    for i in range(min(max_search, n // 2)):
        if not _is_artifact(i, ref_start):
            break
        n_start = i + 1

    n_end = 0
    for i in range(n - 1, max(n - 1 - max_search, n // 2), -1):
        if not _is_artifact(i, ref_end):
            break
        n_end = n - i

    return n_start, n_end
|
|
|
|
def detect_artifacts(arr, max_search=20):
    """Find boundary artifact planes along every spatial axis of ``arr``.

    A 4D array is treated as channel-first (e.g. BRATS ``[C, D, H, W]``), so
    its axis 0 is not searched; for any other rank every axis is searched.

    Returns:
        ``{axis: (n_start, n_end)}`` holding only the axes where at least one
        plane should be trimmed, in ascending axis order.
    """
    # Channel-first 4D volumes: skip the leading channel axis.
    first_spatial = 1 if arr.ndim == 4 else 0

    trims = {}
    for ax in range(first_spatial, arr.ndim):
        counts = find_artifact_slices(arr, ax, max_search=max_search)
        if max(counts) > 0:
            trims[ax] = counts
    return trims
|
|
|
|
def build_crop_slices(ndim, crops):
    """Turn ``{axis: (n_start, n_end)}`` into an index tuple for an ``ndim`` array.

    Axes not listed in ``crops`` are left whole; ``n_end == 0`` keeps the
    axis open-ended at the top.
    """
    bounds = [slice(None) for _ in range(ndim)]
    for axis, (lo, hi) in crops.items():
        bounds[axis] = slice(lo, -hi if hi else None)
    return tuple(bounds)
|
|
|
|
def crop_sitk_image(sitk_img, crops):
    """Crop a SimpleITK image by the given boundary-plane counts.

    Args:
        sitk_img: image to crop (not modified).
        crops: ``{numpy_axis: (n_start, n_end)}`` where axes refer to the
            array returned by ``sitk.GetArrayFromImage(sitk_img)``.

    Returns:
        A new image holding the cropped voxels with the original spacing,
        direction and metadata. Its origin is moved to the physical position
        of the first kept voxel, so the remaining voxels stay where they were.
    """
    arr = sitk.GetArrayFromImage(sitk_img)
    cropped_arr = arr[build_crop_slices(arr.ndim, crops)]

    ndim_phys = sitk_img.GetDimension()
    # Multi-component images carry their components on the LAST numpy axis;
    # without isVector, GetImageFromArray would build an (N+1)-D scalar image
    # and the N-D spacing/direction below would not fit it.
    is_vector = sitk_img.GetNumberOfComponentsPerPixel() > 1
    cropped_img = sitk.GetImageFromArray(cropped_arr, isVector=is_vector)
    cropped_img.SetSpacing(sitk_img.GetSpacing())
    cropped_img.SetDirection(sitk_img.GetDirection())

    # numpy axis `a` corresponds to ITK index axis (ndim_phys - 1 - a); a
    # vector image's component axis maps outside [0, ndim_phys) and is ignored.
    start_index = [0] * ndim_phys
    for axis, (n_start, _) in crops.items():
        itk_axis = ndim_phys - 1 - axis
        if 0 <= itk_axis < ndim_phys:
            start_index[itk_axis] = int(n_start)
    # The physical location of the first kept voxel becomes the new origin;
    # this accounts for spacing and the full direction matrix on every axis.
    cropped_img.SetOrigin(sitk_img.TransformIndexToPhysicalPoint(tuple(start_index)))

    for key in sitk_img.GetMetaDataKeys():
        cropped_img.SetMetaData(key, sitk_img.GetMetaData(key))

    return cropped_img
|
|
|
|
def _list_image_files(image_dir, recursive):
    """Return the sorted ``.nii.gz`` image paths found in ``image_dir``.

    In recursive mode, files inside a ``segmentation`` directory, or whose
    path below ``image_dir`` has a component starting with ``label``
    (case-insensitive), are treated as labels and excluded.
    """
    if not recursive:
        return sorted(glob.glob(os.path.join(image_dir, "*.nii.gz")))

    image_files = []
    pattern = os.path.join(image_dir, "**", "*.nii.gz")
    for path in sorted(glob.glob(pattern, recursive=True)):
        # Inspect only the part below image_dir, so a parent directory such
        # as /data/labelled_cohort/ does not exclude every file; splitting on
        # os.sep keeps this correct on Windows too.
        parts = os.path.relpath(path, image_dir).split(os.sep)
        if "segmentation" in parts[:-1]:
            continue
        if any(part.lower().startswith("label") for part in parts):
            continue
        image_files.append(path)
    return image_files


def _label_path_for(img_path, args):
    """Return the expected label path for ``img_path``, or None if no label source applies."""
    filename = os.path.basename(img_path)
    if args.recursive:
        # Recursive layout: {subject_dir}/segmentation/{filename}
        return os.path.join(os.path.dirname(img_path), "segmentation", filename)
    if args.label_dir:
        return os.path.join(args.label_dir, filename)
    return None


def _crops_for_label(image_shape, label_shape, crops):
    """Translate image-array crops onto a label array; None if they don't align.

    Axes are matched from the end, so crops for a 4D image ``[C, D, H, W]``
    map onto a 3D label ``[D, H, W]``. Returns None when the label's shape
    differs from the image's trailing shape, since cropping would misalign it.
    """
    offset = len(image_shape) - len(label_shape)
    if offset < 0 or tuple(image_shape[offset:]) != tuple(label_shape):
        return None
    return {axis - offset: span for axis, span in crops.items() if axis >= offset}


def _crop_and_write(img_path, sitk_img, image_shape, crops, label_path):
    """Crop the image (and its label, if the file exists) and overwrite them.

    Both outputs are built before anything is written, so an incompatible
    label leaves the image untouched too. Returns False (nothing written)
    when the label's shape does not match the image, True otherwise.
    """
    cropped_lbl = None
    if label_path is not None and os.path.isfile(label_path):
        sitk_lbl = sitk.ReadImage(label_path)
        label_shape = sitk.GetArrayViewFromImage(sitk_lbl).shape
        lbl_crops = _crops_for_label(image_shape, label_shape, crops)
        if lbl_crops is None:
            return False
        cropped_lbl = crop_sitk_image(sitk_lbl, lbl_crops)

    sitk.WriteImage(crop_sitk_image(sitk_img, crops), img_path)
    if cropped_lbl is not None:
        sitk.WriteImage(cropped_lbl, label_path)
    return True


def main():
    """Command-line entry point: scan images, report artifact planes, crop them in place."""
    parser = argparse.ArgumentParser(description="Detect and remove constant-value artifact planes at volume boundaries.")
    parser.add_argument("--image_dir", type=str, required=True,
                        help="Directory containing .nii.gz image files.")
    parser.add_argument("--label_dir", type=str, default=None,
                        help="Directory containing matching .nii.gz label files (same filenames). "
                             "In recursive mode, labels are found at {subject_dir}/segmentation/{filename}.")
    parser.add_argument("--recursive", action="store_true",
                        help="Recursively search for .nii.gz files, excluding segmentation/ subdirs.")
    parser.add_argument("--max_search", type=int, default=20,
                        help="Max number of boundary slices to check per side (default: 20).")
    parser.add_argument("--dry-run", action="store_true",
                        help="Report artifacts without modifying any files.")
    args = parser.parse_args()

    image_files = _list_image_files(args.image_dir, args.recursive)
    print(f"Found {len(image_files)} images in {args.image_dir}{' (recursive)' if args.recursive else ''}")
    if args.label_dir:
        print(f"Label dir: {args.label_dir}")
    if args.dry_run:
        print("*** DRY RUN — no files will be modified ***")

    total_artifacts = 0
    total_clean = 0
    total_slices_removed = 0
    total_skipped = 0

    for img_path in tqdm(image_files, desc="Scanning"):
        filename = os.path.basename(img_path)
        sitk_img = sitk.ReadImage(img_path)
        arr = sitk.GetArrayFromImage(sitk_img)

        crops = detect_artifacts(arr, max_search=args.max_search)
        if not crops:
            total_clean += 1
            continue

        total_artifacts += 1
        slices_removed = sum(s + e for s, e in crops.values())
        detail = ", ".join(
            f"axis{ax}: -{s} start, -{e} end"
            for ax, (s, e) in sorted(crops.items())
        )
        # Report the value of one trimmed boundary plane (first cropped axis).
        ax, (s, _) = next(iter(crops.items()))
        val = _get_plane(arr, ax, 0 if s > 0 else arr.shape[ax] - 1).flat[0]
        # tqdm.write prints without breaking the progress bar.
        tqdm.write(f" {filename}: {arr.shape} -> trim {slices_removed} planes, val={val} ({detail})")

        if not args.dry_run:
            label_path = _label_path_for(img_path, args)
            if not _crop_and_write(img_path, sitk_img, arr.shape, crops, label_path):
                tqdm.write(f" WARNING: {filename}: label {label_path} shape does not match image; "
                           "left both files unchanged")
                total_skipped += 1
                continue
        total_slices_removed += slices_removed

    print("\nSummary:")
    print(f" Total images: {len(image_files)}")
    print(f" With artifacts: {total_artifacts}")
    print(f" Clean: {total_clean}")
    print(f" Planes removed: {total_slices_removed}")
    if total_skipped:
        print(f" Skipped (label shape mismatch): {total_skipped}")
    if args.dry_run:
        print(" (dry-run — nothing was modified)")
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|