Data_Engineering / all /clean_artifact_planes.py
maxmo2009's picture
Initial upload: data cleanup pipeline for 12 medical imaging datasets
da9fb1e verified
"""
Detect and remove constant-value artifact planes at volume boundaries.
Interpolation during preprocessing can introduce planes filled with a single
non-zero constant value (e.g. 8.0 for CT) at the start or end of each spatial
axis. This script:
1. Scans all .nii.gz files under --image_dir (and optionally --label_dir).
2. For each image, identifies contiguous boundary planes that hold a single
constant value which is foreign to the nearest non-constant plane (it matches
at most 1% of that plane's voxels). Constant planes of genuine background,
whose value continues into the volume, are left alone.
3. Crops the artifact planes from image AND matching label (if present),
shifting the image origin so the remaining voxels stay in the same physical
coordinate space.
4. Overwrites in-place (use --dry-run to preview without writing).
Usage:
# Dry-run: report artifacts without modifying files
python clean_artifact_planes.py \
--image_dir /path/to/MSD_processed/images \
--label_dir /path/to/MSD_processed/labels \
--dry-run
# Actually clean:
python clean_artifact_planes.py \
--image_dir /path/to/MSD_processed/images \
--label_dir /path/to/MSD_processed/labels
"""
import os
import glob
import argparse
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm
def _get_plane(arr, axis, idx):
    """Return the (ndim - 1)-dimensional plane of ``arr`` at ``idx`` on ``axis``.

    Every other axis is taken in full, so the result has the shape of ``arr``
    with ``axis`` removed.
    """
    selector = [slice(None) for _ in range(arr.ndim)]
    selector[axis] = idx
    return arr[tuple(selector)]
def find_artifact_slices(arr, axis, max_search=20):
    """Count the constant-value artifact planes at both ends of ``axis``.

    A boundary plane is an artifact when it holds exactly one unique value
    and that value is foreign to the reference plane of its side, i.e. it
    matches (within 1e-6) at most 1% of the reference plane's voxels.  The
    reference is the first plane with more than one unique value found
    within ``max_search + 5`` planes of the boundary; if there is none, the
    boundary plane itself is used.  Comparing against the reference keeps
    legitimate background (e.g. -300 for CT air) that continues into the
    volume from being trimmed.

    Each side is scanned inwards and stops at the first non-artifact plane.
    At most ``max_search`` planes are examined per side and the scan never
    crosses the middle of the axis.

    Returns (n_start, n_end): number of planes to trim from the start and
    from the end of ``axis``.
    """
    length = arr.shape[axis]

    def first_varying(begin, stop, step):
        # Walk inwards until a plane with real content shows up.
        for pos in range(begin, stop, step):
            if len(np.unique(_get_plane(arr, axis, pos))) > 1:
                return pos
        return begin  # nothing varied inside the window

    def is_foreign_constant(candidate, reference):
        values = np.unique(_get_plane(arr, axis, candidate))
        if len(values) != 1:
            return False
        constant = float(values[0])
        ref_plane = _get_plane(arr, axis, reference)
        # Fraction of reference voxels equal to the candidate's value; more
        # than 1% means the value is connected background, not an artifact.
        share = np.mean(np.abs(ref_plane - constant) < 1e-6)
        return not (share > 0.01)

    ref_low = first_varying(0, min(max_search + 5, length), 1)
    ref_high = first_varying(length - 1, max(length - 1 - max_search - 5, -1), -1)

    leading = 0
    for pos in range(min(max_search, length // 2)):
        if not is_foreign_constant(pos, ref_low):
            break
        leading = pos + 1

    trailing = 0
    for pos in range(length - 1, max(length - 1 - max_search, length // 2), -1):
        if not is_foreign_constant(pos, ref_high):
            break
        trailing = length - pos

    return leading, trailing
def detect_artifacts(arr, max_search=20):
    """Run the boundary-artifact search over every spatial axis of ``arr``.

    A 4D array (e.g. BRATS with shape [C, D, H, W]) is assumed to carry its
    channels on axis 0, which is never examined; for any other rank all axes
    are treated as spatial.

    Returns a dict ``{axis: (n_start, n_end)}`` containing only the axes on
    which at least one plane has to be trimmed.
    """
    first_spatial = 1 if arr.ndim == 4 else 0

    found = {}
    for axis in range(first_spatial, arr.ndim):
        lead, trail = find_artifact_slices(arr, axis, max_search=max_search)
        if lead <= 0 and trail <= 0:
            continue
        found[axis] = (lead, trail)
    return found
def build_crop_slices(ndim, crops):
    """Turn ``crops`` into an indexing tuple for an ``ndim``-dimensional array.

    ``crops`` maps an axis to ``(n_start, n_end)``, the number of planes to
    drop at the start and at the end of that axis.  Axes that are not listed
    are kept whole.
    """
    selection = [slice(None) for _ in range(ndim)]
    for axis, (lead, trail) in crops.items():
        # A stop of -0 would empty the axis, so "drop nothing" must be None.
        stop = -trail if trail != 0 else None
        selection[axis] = slice(lead, stop)
    return tuple(selection)
def crop_sitk_image(sitk_img, crops):
    """Crop a SimpleITK image according to the detected artifact planes.

    ``crops`` maps a numpy array axis (as returned by
    ``sitk.GetArrayFromImage``) to ``(n_start, n_end)``, the number of planes
    to remove at the start and end of that axis.

    Spacing, direction and metadata are copied from ``sitk_img``.  The origin
    is moved by the planes removed at the start of each axis, so the voxels
    that remain keep their physical coordinates.

    Returns a new SimpleITK image; ``sitk_img`` is not modified.
    """
    arr = sitk.GetArrayFromImage(sitk_img)
    cropped_arr = arr[build_crop_slices(arr.ndim, crops)]

    # By default GetImageFromArray treats every 4D array as a 3D *vector*
    # image, which would change the dimension of a 4D scalar volume and make
    # the SetSpacing/SetDirection calls below fail.  Tell it explicitly.
    is_vector = sitk_img.GetNumberOfComponentsPerPixel() > 1
    cropped_img = sitk.GetImageFromArray(cropped_arr, isVector=is_vector)
    cropped_img.SetSpacing(sitk_img.GetSpacing())
    cropped_img.SetDirection(sitk_img.GetDirection())

    ndim_phys = sitk_img.GetDimension()  # physical dimensions (3 for 3D, 4 for 4D)
    origin = list(sitk_img.GetOrigin())
    spacing = sitk_img.GetSpacing()
    direction = np.array(sitk_img.GetDirection()).reshape(ndim_phys, ndim_phys)

    for axis, (n_start, _) in crops.items():
        if n_start <= 0:
            continue
        # Array axes are the physical axes in reverse order (ZYX vs XYZ).
        # For vector images the trailing component axis maps to -1: it is not
        # a spatial axis and therefore never moves the origin.
        phys_axis = ndim_phys - 1 - axis
        if phys_axis < 0:
            continue
        # Physical position of index n_start along phys_axis:
        # origin + direction[:, phys_axis] * spacing[phys_axis] * n_start
        shift = n_start * spacing[phys_axis]
        for i in range(ndim_phys):
            origin[i] += shift * direction[i, phys_axis]
    cropped_img.SetOrigin(origin)

    for key in sitk_img.GetMetaDataKeys():
        cropped_img.SetMetaData(key, sitk_img.GetMetaData(key))
    return cropped_img
def _crop_label_in_place(label_path, crops, image_shape):
    """Apply the image's ``crops`` to the label file at ``label_path``.

    Does nothing when the file does not exist.  ``crops`` is keyed by the
    axes of the image array (shape ``image_shape``); when the image has extra
    leading axes (e.g. a 4D multi-channel image with a 3D label) the axes are
    shifted to the label's numbering.  A label whose shape does not match the
    trailing axes of the image is reported and left untouched, because
    trimming it by the image's plane counts would misalign it.
    """
    if not os.path.isfile(label_path):
        return
    sitk_lbl = sitk.ReadImage(label_path)
    # GetSize() is in XYZ order; reverse it to compare with numpy (ZYX) shapes.
    label_shape = tuple(reversed(sitk_lbl.GetSize()))
    lead = len(image_shape) - len(label_shape)
    if lead < 0 or tuple(image_shape[lead:]) != label_shape:
        print(f"  WARNING: label {label_path} has shape {label_shape}, which does not "
              f"match image shape {tuple(image_shape)}; label left untouched")
        return
    label_crops = {axis - lead: trims for axis, trims in crops.items() if axis >= lead}
    if label_crops:
        sitk.WriteImage(crop_sitk_image(sitk_lbl, label_crops), label_path)


def main():
    """Command-line entry point: scan, report and (unless --dry-run) crop.

    Every image found under --image_dir is checked for artifact planes.
    Affected images are reported and, unless --dry-run is given, overwritten
    with the cropped volume.  The matching label is cropped the same way: it
    is looked up in --label_dir in flat mode, or at
    ``<image folder>/segmentation/<filename>`` in --recursive mode.
    A summary is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Detect and remove constant-value artifact planes at volume boundaries.")
    parser.add_argument("--image_dir", type=str, required=True,
                        help="Directory containing .nii.gz image files.")
    parser.add_argument("--label_dir", type=str, default=None,
                        help="Directory containing matching .nii.gz label files (same filenames). "
                             "In recursive mode, labels are found at {subject_dir}/segmentation/{filename}.")
    parser.add_argument("--recursive", action="store_true",
                        help="Recursively search for .nii.gz files, excluding segmentation/ subdirs.")
    parser.add_argument("--max_search", type=int, default=20,
                        help="Max number of boundary slices to check per side (default: 20).")
    parser.add_argument("--dry-run", action="store_true",
                        help="Report artifacts without modifying any files.")
    args = parser.parse_args()

    if args.recursive:
        pattern = os.path.join(args.image_dir, "**", "*.nii.gz")
        image_files = []
        for path in sorted(glob.glob(pattern, recursive=True)):
            # Normalise separators so the label filter also works where the
            # OS path separator is not "/".
            posix_path = path.replace(os.sep, "/")
            if "/segmentation/" in posix_path or "/label" in posix_path.lower():
                continue
            image_files.append(path)
    else:
        image_files = sorted(glob.glob(os.path.join(args.image_dir, "*.nii.gz")))

    print(f"Found {len(image_files)} images in {args.image_dir}{' (recursive)' if args.recursive else ''}")
    if args.label_dir:
        print(f"Label dir: {args.label_dir}")
    if args.dry_run:
        print("*** DRY RUN — no files will be modified ***")

    total_artifacts = 0
    total_clean = 0
    total_slices_removed = 0

    for img_path in tqdm(image_files, desc="Scanning"):
        filename = os.path.basename(img_path)
        sitk_img = sitk.ReadImage(img_path)
        arr = sitk.GetArrayFromImage(sitk_img)

        crops = detect_artifacts(arr, max_search=args.max_search)
        if not crops:
            total_clean += 1
            continue

        total_artifacts += 1
        slices_removed = sum(s + e for s, e in crops.values())
        total_slices_removed += slices_removed
        detail = ", ".join(
            f"axis{ax}: -{s} start, -{e} end"
            for ax, (s, e) in sorted(crops.items())
        )

        # Report the artifact value, sampled from one trimmed boundary plane
        # of the first affected axis.
        first_axis, (first_start, _) = next(iter(crops.items()))
        boundary_idx = 0 if first_start > 0 else arr.shape[first_axis] - 1
        val = _get_plane(arr, first_axis, boundary_idx).flat[0]
        print(f" {filename}: {arr.shape} -> trim {slices_removed} planes, val={val} ({detail})")

        if args.dry_run:
            continue

        # Crop and save image
        cropped_img = crop_sitk_image(sitk_img, crops)
        sitk.WriteImage(cropped_img, img_path)

        # Crop matching label if present
        if args.recursive:
            # In recursive mode, look for label at {parent}/segmentation/{filename}
            label_path = os.path.join(os.path.dirname(img_path), 'segmentation', filename)
        elif args.label_dir:
            label_path = os.path.join(args.label_dir, filename)
        else:
            label_path = None
        if label_path is not None:
            _crop_label_in_place(label_path, crops, arr.shape)

    print("\nSummary:")
    print(f" Total images: {len(image_files)}")
    print(f" With artifacts: {total_artifacts}")
    print(f" Clean: {total_clean}")
    print(f" Planes removed: {total_slices_removed}")
    if args.dry_run:
        print(" (dry-run — nothing was modified)")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()