# Commit 6faad95 by andreped: "Moved chainer imports inside functions [no ci]"
import numpy as np
import os, sys
from tqdm import tqdm
import nibabel as nib
from nibabel.processing import resample_to_output, resample_from_to
from scipy.ndimage import zoom
from tensorflow.python.keras.models import load_model
import gdown
from skimage.morphology import remove_small_holes, binary_dilation, binary_erosion, ball
from skimage.measure import label, regionprops
import warnings
import argparse
import pkg_resources
import tensorflow as tf
import logging as log
import math
from .yaml_utils import Config
import yaml
from tensorflow.keras import backend as K
from numba import cuda
import multiprocessing as mp
def intensity_normalization(volume, intensity_clipping_range):
    """Clip a volume to an intensity window, then min-max rescale it.

    Values below ``intensity_clipping_range[0]`` are raised to that bound and
    values above ``intensity_clipping_range[1]`` are lowered to it. The work
    is done on a copy, so the input array is not modified. The clipped copy is
    then linearly rescaled so its minimum maps to 0 and its maximum to 1. If
    the clipped copy is constant, it is returned clipped but not rescaled
    (there is no range to divide by).

    Args:
        volume: numpy array of intensities.
        intensity_clipping_range: sequence whose first two items are the
            lower and upper clipping bounds.

    Returns:
        A new numpy array with the same shape as ``volume``.
    """
    lower = intensity_clipping_range[0]
    upper = intensity_clipping_range[1]

    clipped = np.copy(volume)
    # Both masks come from the untouched input, not from the partially clipped copy.
    clipped[volume < lower] = lower
    clipped[volume > upper] = upper

    lowest = np.amin(clipped)
    highest = np.amax(clipped)
    span = highest - lowest
    if span == 0:
        # Constant volume: nothing to rescale, and avoid dividing by zero.
        return clipped
    return (clipped - lowest) / span
def liver_segmenter_wrapper(curr, output, cpu, verbose, multiple_flag, name, extension, mp_enabled=True):
    """Run ``liver_segmenter``, optionally inside a single-use child process.

    Args:
        curr, output, cpu, verbose, multiple_flag, name, extension: forwarded
            unchanged (packed into one tuple) to ``liver_segmenter``.
        mp_enabled: when true (default), inference runs in a one-worker pool
            created with the 'spawn' start method; otherwise it runs in the
            current process.

    Returns:
        Whatever ``liver_segmenter`` returns for these parameters.
    """
    params = (curr, output, cpu, verbose, multiple_flag, name, extension)
    if not mp_enabled:
        return liver_segmenter(params=params)

    # Run inference in a different process. A dedicated 'spawn' context gives
    # the worker a fresh interpreter without overriding the process-wide
    # default start method, which mp.set_start_method(..., force=True) would do
    # for every other multiprocessing user in this process.
    # NOTE(review): the isolation is presumably there so framework/GPU state is
    # released when the worker exits; confirm.
    ctx = mp.get_context('spawn')
    with ctx.Pool(processes=1, maxtasksperchild=1) as pool:
        result = pool.map_async(liver_segmenter, (params,))
        log.info("getting result from process...")
        # get() blocks until the single task finishes; the pool is torn down on exit.
        return result.get()[0]
def liver_segmenter(params):
    """Segment the liver slice-by-slice with a Keras model and save the mask.

    Args:
        params: tuple ``(curr, output, cpu, verbose, multiple_flag, name,
            extension)``, packed into one argument so the function can be
            dispatched through ``Pool.map_async``:

            - curr: path of the input volume, read with ``nib.load``.
            - output: output directory when ``multiple_flag`` is set,
              otherwise an output path prefix.
            - cpu: unused by this function.
            - verbose: show a tqdm progress bar over slices.
            - multiple_flag: when true, save to
              ``<output>/<input stem>-livermask<extension>``; otherwise to
              ``<output>-livermask<extension>``.
            - name: model path passed to ``load_model``.
            - extension: file extension appended to the saved mask name.

    Returns:
        The post-processed binary mask (uint8, values 0/1) on the
        ``[1, 1, 1]``-spacing grid produced by ``resample_to_output``. This is
        the mask before it is resampled back to the input geometry for saving.

    Raises:
        KeyboardInterrupt: re-raised (after logging) if the worker is interrupted.
    """
    try:
        curr, output, cpu, verbose, multiple_flag, name, extension = params

        # load model (compile=False: the model is only used for prediction here)
        model = load_model(name, compile=False)

        log.info("preprocessing...")
        nib_volume = nib.load(curr)
        new_spacing = [1., 1., 1.]
        resampled_volume = resample_to_output(nib_volume, new_spacing, order=1)
        # np.asanyarray(img.dataobj) is the documented replacement for the
        # deprecated img.get_data() (removed in nibabel 5.0) and yields the same array.
        data = np.asanyarray(resampled_volume.dataobj).astype('float32')
        curr_shape = data.shape

        # resize the first two axes to 512x512; the last (slice) axis is unchanged
        img_size = 512
        data = zoom(data, [img_size / data.shape[0], img_size / data.shape[1], 1.0], order=1)

        # intensity normalization: clip to [-150, 250] (HU limits from Pravdaray's configs), rescale to [0, 1]
        intensity_clipping_range = [-150, 250]
        data = intensity_normalization(volume=data, intensity_clipping_range=intensity_clipping_range)

        # fix orientation
        data = np.rot90(data, k=1, axes=(0, 1))
        data = np.flip(data, axis=0)

        log.info("predicting...")
        pred = np.zeros_like(data).astype(np.float32)
        for i in tqdm(range(data.shape[-1]), "pred: ", disable=not verbose):
            # slice i (H, W) is fed as a (1, 1, H, W, 1) tensor
            model_input = np.expand_dims(np.expand_dims(np.expand_dims(data[..., i], axis=0), axis=-1), axis=0)
            # keep channel 1 of the model output as the foreground score
            pred[..., i] = model.predict(model_input)[0, ..., 1]
        del data

        # threshold
        pred = (pred >= 0.4).astype(int)

        # fix orientation back (inverse of the transforms above, in reverse order)
        pred = np.flip(pred, axis=0)
        pred = np.rot90(pred, k=-1, axes=(0, 1))

        log.info("resize back...")
        # resize back from 512x512 to the resampled in-plane shape
        pred = zoom(pred, [curr_shape[0] / img_size, curr_shape[1] / img_size, 1.0], order=1)
        pred = (pred >= 0.5).astype(bool)

        log.info("morphological post-processing...")
        # 1) first erode
        pred = binary_erosion(pred, ball(3)).astype(np.int32)

        # 2) keep only the largest connected component
        labels = label(pred)
        nb_uniques = len(np.unique(labels))  # note: includes background 0
        if nb_uniques > 2:  # with a single component there is nothing to filter
            # max() returns the first region of maximal area, like np.argmax did
            largest = max(regionprops(labels), key=lambda region: region.area)
            pred = (labels == largest.label).astype(np.int32)
            del labels, largest

        if nb_uniques > 1:  # if nothing was segmented, skip the remaining steps
            # 3) dilate
            pred = binary_dilation(pred.astype(bool), ball(3))
            # 4) fill holes smaller than 0.1% of the total volume size
            pred = remove_small_holes(pred.astype(bool), area_threshold=0.001 * np.prod(pred.shape))

        log.info("saving...")
        pred = pred.astype(np.uint8)
        img = nib.Nifti1Image(pred, affine=resampled_volume.affine)
        # nearest-neighbour (order=0) resampling back onto the input volume's grid
        resampled_lab = resample_from_to(img, nib_volume, order=0)
        if multiple_flag:
            nib.save(resampled_lab, output + "/" + curr.split("/")[-1].split(".")[0] + "-livermask" + extension)
        else:
            nib.save(resampled_lab, output + "-livermask" + extension)
        return pred
    except KeyboardInterrupt:
        # Raising a plain str is itself a TypeError in Python 3, so log the
        # message and propagate the genuine interrupt instead.
        log.info("Caught KeyboardInterrupt, terminating worker")
        raise
def vessel_segmenter(curr, output, cpu, verbose, multiple_flag, liver_mask, name_vessel, extension):
    """Segment vessels patch-wise with a chainer model and save the mask.

    The input is resampled to ``[1, 1, 1]`` spacing and clipped to [80, 220].
    It is then processed in cubic patches of side ``config.patch['patchside']``.
    Patches whose mean ``liver_mask`` value is below 0.25 are skipped, and any
    prediction outside ``liver_mask`` is zeroed.

    Args:
        curr: path of the input volume, read with ``nib.load``.
        output: output directory when ``multiple_flag`` is set, otherwise an
            output path prefix.
        cpu: run the model on CPU. Forced to True when cupy cannot be imported.
        verbose: show a tqdm progress bar over patches.
        multiple_flag: when true, save to
            ``<output>/<input stem>-vessels<extension>``; otherwise to
            ``<output>-vessels<extension>``.
        liver_mask: array with the same shape as the resampled volume; zero
            means "outside the liver".
        name_vessel: model identifier passed to ``load_vessel_model``.
        extension: file extension appended to the saved mask name.
    """
    # only import chainer stuff inside here, to avoid unnecessary imports
    import chainer
    from .unet3d import UNet3D
    from .utils import load_vessel_model

    # Probe for cupy. ImportError also covers an installed-but-unusable cupy
    # (it is the base class of ModuleNotFoundError), so fall back to CPU then too.
    try:
        import cupy
    except ImportError as e:
        log.info(e)
        log.info("cupy is not available. Setting cpu=True")
        cpu = True

    # load model
    unet, xp = load_vessel_model(name_vessel, cpu)

    # read config; the context manager closes the file handle promptly
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "configs", "base.yml")
    with open(config_path) as config_file:
        config = Config(yaml.safe_load(config_file))
    patchside = config.patch['patchside']

    log.info("preprocessing...")
    nib_volume = nib.load(curr)
    new_spacing = [1., 1., 1.]
    resampled_volume = resample_to_output(nib_volume, new_spacing, order=1)
    # np.asanyarray(img.dataobj) replaces the deprecated img.get_data() (removed in nibabel 5.0)
    org = np.asanyarray(resampled_volume.dataobj).astype('float32')

    # intensity clipping to [80, 220], in place
    intensity_clipping_range = [80, 220]
    np.clip(org, intensity_clipping_range[0], intensity_clipping_range[1], out=org)

    # number of patches needed along each axis to cover the volume
    ze, ye, xe = org.shape
    zm = int(math.ceil(float(ze) / float(patchside)))
    ym = int(math.ceil(float(ye) / float(patchside)))
    xm = int(math.ceil(float(xe) / float(patchside)))

    # Pad one full patch at the far end of each axis (edge replication), so
    # every patch slice below has the full patchside extent.
    org = np.pad(org, ((0, patchside),) * 3, 'edge')
    org = chainer.Variable(xp.array(org[np.newaxis, np.newaxis, :], dtype=xp.float32))

    # Per-voxel argmax labels over the padded volume (cropped after the loop).
    # uint8 is enough as long as number_of_label <= 256; the result is cast to uint8 anyway.
    prediction_map = np.zeros((ze + patchside, ye + patchside, xe + patchside), dtype=np.uint8)

    log.info("predicting...")
    for s in tqdm(range(xm * ym * zm), 'Patch loop', disable=not verbose):
        # unravel the flat patch index s (x varies fastest) into patch origins
        zi = (s // (ym * xm)) * patchside
        yi = ((s % (ym * xm)) // xm) * patchside
        xi = (s % xm) * patchside
        window = (slice(zi, zi + patchside), slice(yi, yi + patchside), slice(xi, xi + patchside))

        # skip patches that lie mostly outside the liver mask
        if np.mean(liver_mask[window]) < 0.25:
            continue

        # extract patch from the padded image
        patch = org[(slice(None), slice(None)) + window]
        with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
            probability_patch = unet(patch).data
        if not cpu:
            probability_patch = chainer.cuda.to_cpu(probability_patch)

        # argmax over the class axis (axis 1), first batch element
        prediction_map[window] = np.argmax(probability_patch, axis=1)[0]

    # Crop away the padding, then filter out vessels outside the predicted liver parenchyma.
    pred = prediction_map[:ze, :ye, :xe].astype(np.uint8)
    pred[liver_mask == 0] = 0

    log.info("saving...")
    img = nib.Nifti1Image(pred, affine=resampled_volume.affine)
    # nearest-neighbour (order=0) resampling back onto the input volume's grid
    resampled_lab = resample_from_to(img, nib_volume, order=0)
    if multiple_flag:
        nib.save(resampled_lab, output + "/" + curr.split("/")[-1].split(".")[0] + "-vessels" + extension)
    else:
        nib.save(resampled_lab, output + "-vessels" + extension)