import math
from typing import List, Sequence
import keras.utils as k_utils
import numpy as np
import pydicom
from keras.utils.data_utils import OrderedEnqueuer
from tqdm import tqdm
def parse_windows(windows):
"""Parse windows provided by the user.
These windows can either be strings corresponding to popular windowing
thresholds for CT or tuples of (upper, lower) bounds.
Args:
windows (list): List of strings or tuples.
Returns:
list: List of tuples of (upper, lower) bounds.
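
    Example (using the "soft" preset defined below, width 400 / level 50):
        parse_windows(["soft", (-100, 200)]) -> ((-150.0, 250.0), (-100, 200))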
"""
windowing = {
"soft": (400, 50),
"bone": (1800, 400),
"liver": (150, 30),
"spine": (250, 50),
"custom": (500, 50),
}
vals = []
for w in windows:
        if isinstance(w, Sequence) and not isinstance(w, str) and len(w) == 2:
            assert_msg = "Expected tuple of (lower, upper) bound"
            assert isinstance(w[0], (float, int)), assert_msg
            assert isinstance(w[1], (float, int)), assert_msg
            assert w[0] < w[1], assert_msg
            vals.append(w)
            continue
if w not in windowing:
raise KeyError("Window {} not found".format(w))
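        # Convert the preset's (window width, window level) into explicit
        # (lower, upper) HU bounds.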
window_width = windowing[w][0]
window_level = windowing[w][1]
upper = window_level + window_width / 2
lower = window_level - window_width / 2
vals.append((lower, upper))
return tuple(vals)
def _window(xs, bounds):
"""Apply windowing to an array of CT images.
Args:
xs (ndarray): NxHxW
bounds (tuple): (lower, upper) bounds
Returns:
ndarray: Windowed images.
"""
imgs = []
for lb, ub in bounds:
imgs.append(np.clip(xs, a_min=lb, a_max=ub))
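    # A single window keeps the input layout; multiple windows are stacked
    # (or concatenated) along the channel axis.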
if len(imgs) == 1:
return imgs[0]
elif xs.shape[-1] == 1:
return np.concatenate(imgs, axis=-1)
else:
return np.stack(imgs, axis=-1)
class Dataset(k_utils.Sequence):
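    """Keras Sequence that loads batches of DICOM slices for inference.

    Each batch is a tuple of (optionally windowed) pixel arrays and a list of
    per-slice parameter dicts containing the pixel spacing and raw HU image.
    """
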
def __init__(self, files: List[str], batch_size: int = 16, windows=None):
self._files = files
self._batch_size = batch_size
self.windows = windows
def __len__(self):
return math.ceil(len(self._files) / self._batch_size)
def __getitem__(self, idx):
files = self._files[idx * self._batch_size : (idx + 1) * self._batch_size]
        # Read the DICOM slices and shift pixel values by the rescale
        # intercept to approximate Hounsfield units.
        dcms = [pydicom.dcmread(f, force=True) for f in files]
        xs = [(x.pixel_array + int(x.RescaleIntercept)).astype("float32") for x in dcms]
params = [
{"spacing": header.PixelSpacing, "image": x} for header, x in zip(dcms, xs)
]
# Preprocess xs via windowing.
xs = np.stack(xs, axis=0)
if self.windows:
xs = _window(xs, parse_windows(self.windows))
else:
xs = xs[..., np.newaxis]
return xs, params
def _swap_muscle_imap(xs, ys, muscle_idx: int, imat_idx: int, threshold=-30.0):
"""
If pixel labeled as muscle but has HU < threshold, change label to imat.
Args:
xs (ndarray): NxHxWxC
ys (ndarray): NxHxWxC
muscle_idx (int): Index of the muscle label.
imat_idx (int): Index of the imat label.
threshold (float): Threshold for HU value.
Returns:
ndarray: Segmentation mask with swapped labels.
"""
labels = ys.copy()
    # Binary mask of pixels currently labeled as muscle.
    muscle_mask = (labels[..., muscle_idx] > 0.5).astype(int)
    imat_mask = labels[..., imat_idx]
    # Muscle pixels whose HU falls below the threshold are relabeled as imat.
    imat_mask[muscle_mask.astype(bool) & (xs < threshold)] = 1
    muscle_mask[xs < threshold] = 0
    labels[..., muscle_idx] = muscle_mask
    labels[..., imat_idx] = imat_mask
return labels
def postprocess(xs: np.ndarray, ys: np.ndarray):
"""Built-in post-processing.
TODO: Make this configurable.
Args:
xs (ndarray): NxHxW
ys (ndarray): NxHxWxC
params (dictionary): Post-processing parameters. Must contain
"categories".
Returns:
ndarray: Post-processed labels.
"""
# Add another channel full of zeros to ys
ys = np.concatenate([ys, np.zeros_like(ys[..., :1])], axis=-1)
    # If muscle HU is < -30, assume it is imat. Disabled until the
    # "categories" label mapping is made configurable (see TODO above).
"""
if "muscle" in categories and "imat" in categories:
ys = _swap_muscle_imap(
xs,
ys,
muscle_idx=categories["muscle"],
imat_idx=categories["imat"],
)
"""
return ys
def predict(
model,
dataset: Dataset,
batch_size: int = 16,
num_workers: int = 1,
max_queue_size: int = 10,
use_multiprocessing: bool = False,
):
"""Predict segmentation masks for a dataset.
Args:
model (keras.Model): Model to use for prediction.
dataset (Dataset): Dataset to predict on.
batch_size (int): Batch size.
num_workers (int): Number of workers.
max_queue_size (int): Maximum queue size.
use_multiprocessing (bool): Use multiprocessing.
use_postprocessing (bool): Use built-in post-processing.
postprocessing_params (dict): Post-processing parameters.
Returns:
List: List of segmentation masks.
"""
if num_workers > 0:
enqueuer = OrderedEnqueuer(
dataset, use_multiprocessing=use_multiprocessing, shuffle=False
)
enqueuer.start(workers=num_workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()
else:
output_generator = iter(dataset)
    num_batches = len(dataset)
    xs = []
    ys = []
    params = []
    for _ in tqdm(range(num_batches)):
        x, p_dicts = next(output_generator)
        y = model.predict(x, batch_size=batch_size)
        # Post-process using the raw HU images stored in the parameter dicts.
        image = np.stack([out["image"] for out in p_dicts], axis=0)
        y = postprocess(image, y)
        params.extend(p_dicts)
        xs.extend([x[i, ...] for i in range(len(x))])
        ys.extend([y[i, ...] for i in range(len(y))])
    if num_workers > 0:
        enqueuer.stop()
    return xs, ys, params
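

# Example usage (illustrative sketch only, not part of the original module).
# `segmentation_model` and `dicom_paths` are placeholder names for a trained
# Keras segmentation model and a list of DICOM file paths:
#
#     dataset = Dataset(dicom_paths, batch_size=16, windows=["soft", "bone"])
#     inputs, masks, params = predict(segmentation_model, dataset, num_workers=1)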