Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from monai.transforms import ( | |
Activations, AsDiscreteD, AsDiscrete, Compose, ToTensorD, | |
GaussianSmoothD, LoadImageD, TransposeD, OrientationD, ScaleIntensityRangeD, | |
ToTensor, FillHoles, KeepLargestConnectedComponent, NormalizeIntensityD | |
) | |
from nrrd import read | |
from visualization import visualize_results | |
from data_preparation import get_patient_dictionaries | |
from monai.data import Dataset, DataLoader | |
import os | |
from data_transforms import ConvertMaskValues, MaskOutNonliver | |
from pipeline import build_model, evaluate | |
def _collate_volumes(batch, max_inplane_size=512):
    """Collate a list of per-patient sample dicts into one batch dict.

    Defined at module level (rather than as a closure) so that it can be
    pickled when the ``DataLoader`` starts worker processes with the
    ``spawn`` start method.

    Args:
        batch: list of sample dictionaries, each holding at least the keys
            ``"image"`` and ``"mask"`` and optionally ``"pred_liver"``.
            The values are indexed with four axes, so they are assumed to
            be 4-D tensors (presumably channel, H, W, D -- TODO confirm).
        max_inplane_size: upper bound applied to the second and third axes
            of every image and mask before stacking.

    Returns:
        ``{"image": ..., "mask": ...}`` with the tensors stacked along a new
        leading axis, plus ``"pred_liver"`` taken from the LAST sample of the
        batch when that sample carries the key.  If the tensors cannot be
        stacked (e.g. differing shapes, empty batch) a warning is printed and
        the tuple ``(None, None)`` is returned.
    """
    # Crop axes 1 and 2 so that oversized volumes can still be stacked.
    images = [s["image"][:, :max_inplane_size, :max_inplane_size, :] for s in batch]
    masks = [s["mask"][:, :max_inplane_size, :max_inplane_size, :] for s in batch]

    try:
        stacked_images = torch.stack(images, dim=0)
        stacked_masks = torch.stack(masks, dim=0)
    except Exception as err:
        # Deliberately best-effort: report every shape instead of indexing
        # element [1], which does not exist for a batch of one sample.
        print(
            f"WARNING: not all images/masks are {max_inplane_size} by {max_inplane_size}. Please check. ",
            [getattr(t, "shape", None) for t in images],
            [getattr(t, "shape", None) for t in masks],
            f"({err})",
        )
        # NOTE(review): the failure value is a 2-tuple while the success value
        # is a dict; kept as-is because the consumer (`evaluate`) is not
        # visible here -- confirm how it handles this case.
        return None, None

    collated = {"image": stacked_images, "mask": stacked_masks}
    # NOTE(review): only the last sample's liver prediction is forwarded,
    # which is only complete for a batch size of 1 -- confirm with `evaluate`.
    last_sample = batch[-1]
    if "pred_liver" in last_sample:
        collated["pred_liver"] = last_sample["pred_liver"]
    return collated


def run_sequential_inference(txt_file, config_liver, config_tumor, eval_metrics, output_dir, only_tumor=False, export=True):
    """Run liver segmentation followed by tumor segmentation on a test set.

    The liver stage (skipped when ``only_tumor`` is true) is evaluated first;
    the tumor stage then loads each patient's ``liver_segmentation.nrrd`` from
    ``output_dir`` and uses it to mask out non-liver voxels before running the
    tumor model.  Metrics of both stages are printed and one sample is passed
    to ``visualize_results``.

    Args:
        txt_file: forwarded to ``get_patient_dictionaries`` to select the
            test patients.
        config_liver: liver-stage configuration; the keys ``'DATA_DIR'``,
            ``'HU_RANGE'``, ``'BATCH_SIZE'`` and ``'NUM_WORKERS'`` are read
            here, and the mapping is also given to ``build_model`` and
            ``evaluate``.
        config_tumor: tumor-stage configuration; the keys ``'HU_RANGE'``,
            ``'BATCH_SIZE'`` and ``'NUM_WORKERS'`` are read here, and the
            mapping is also given to ``build_model`` and ``evaluate``.
        eval_metrics: forwarded unchanged to ``evaluate``.
        output_dir: root directory for results; one sub-folder per
            ``patient_id`` is created inside it.
        only_tumor: when true the liver stage is skipped, so the
            ``liver_segmentation.nrrd`` files must already exist in the
            patient folders.
        export: when false an empty list of export filenames is passed to
            the tumor-stage ``evaluate`` call.  The liver stage always
            receives its filenames because the tumor stage reads those
            files back from disk.

    Returns:
        None.
    """
    # Imported locally: only the NRRD header is needed for the export
    # metadata, so the voxel data does not have to be decoded here.
    from nrrd import read_header

    ### Model preparation
    print("")
    print("Loading models....")
    liver_model = build_model(config_liver)
    tumor_model = build_model(config_tumor)

    #### Data preparation
    print("")
    print("Loading test data....")
    test_data_dict = get_patient_dictionaries(txt_file=txt_file, data_dir=config_liver['DATA_DIR'])
    print("   Number of test patients:", len(test_data_dict))

    # assign output file names and paths
    export_file_metadata = []
    os.makedirs(output_dir, exist_ok=True)
    for patient_dict in test_data_dict:
        patient_folder = os.path.join(output_dir, patient_dict['patient_id'])
        os.makedirs(patient_folder, exist_ok=True)
        patient_dict['pred_liver'] = os.path.join(patient_folder, "liver_segmentation.nrrd")
        patient_dict['pred_tumor'] = os.path.join(patient_folder, "tumor_segmentation.nrrd")
        export_file_metadata.append(read_header(patient_dict['image']))

    #### Liver segmentation
    if not only_tumor:
        print("")
        print("Producing liver segmentations....")

        # define liver data loading and preprocessing
        liver_preprocessing = Compose([
            LoadImageD(keys=["image", "mask"], reader="NrrdReader", ensure_channel_first=True),
            OrientationD(keys=["image", "mask"], axcodes="PLI"),
            ScaleIntensityRangeD(
                keys=["image"],
                a_min=config_liver['HU_RANGE'][0],
                a_max=config_liver['HU_RANGE'][1],
                b_min=0.0, b_max=1.0, clip=True,
            ),
            ConvertMaskValues(keys=["mask"], keep_classes=["liver"]),
            ToTensorD(keys=["image", "mask"]),
        ])
        liver_postprocessing = Compose([
            Activations(sigmoid=True),
            AsDiscrete(argmax=True, to_onehot=None),
            KeepLargestConnectedComponent(applied_labels=[1]),
            FillHoles(applied_labels=[1]),
            ToTensor(),
        ])
        liver_dataset = Dataset(test_data_dict, transform=liver_preprocessing)
        liver_loader = DataLoader(
            liver_dataset,
            batch_size=config_liver['BATCH_SIZE'],
            collate_fn=_collate_volumes,
            shuffle=False,
            num_workers=config_liver['NUM_WORKERS'],
        )

        # produce liver model results; filenames are always supplied because
        # the tumor stage below loads the liver prediction from these paths
        test_metrics_liver, sample_output_liver = evaluate(
            liver_model,
            liver_loader,
            eval_metrics,
            config_liver,
            postprocessing_transforms=liver_postprocessing,
            export_filenames=[p['pred_liver'] for p in test_data_dict],
            export_file_metadata=export_file_metadata,
        )
        print("")
        print("==============================")
        print("Liver segmentation test performance ....")
        for key, value in test_metrics_liver.items():
            print(f'   {key.replace("_avg", "_liver")}: {value:.3f}')
        print("==============================")

    ##### Tumor segmentation
    print("")
    print("Producing tumor segmentations....")

    # define tumor loading and preprocessing
    tumor_preprocessing = Compose([
        LoadImageD(keys=["image", "mask", "pred_liver"], reader="NrrdReader", ensure_channel_first=True),
        OrientationD(keys=["image", "mask"], axcodes="PLI"),
        MaskOutNonliver(mask_key="pred_liver"),  # note that liver's predicted segmentation is used to crop to the liver region
        ScaleIntensityRangeD(
            keys=["image"],
            a_min=config_tumor['HU_RANGE'][0],
            a_max=config_tumor['HU_RANGE'][1],
            b_min=0.0, b_max=1.0, clip=True,
        ),
        ConvertMaskValues(keys=["mask"], keep_classes=["liver", "tumor"]),  # format mask for measuring test performance
        AsDiscreteD(keys=["mask"], to_onehot=3),  # format mask for measuring test performance
        ToTensorD(keys=["image", "mask", "pred_liver"]),
    ])
    tumor_postprocessing = Compose([
        Activations(sigmoid=True),
        AsDiscrete(argmax=True, to_onehot=3),
        ToTensor(),
    ])
    tumor_dataset = Dataset(test_data_dict, transform=tumor_preprocessing)
    tumor_loader = DataLoader(
        tumor_dataset,
        batch_size=config_tumor['BATCH_SIZE'],
        collate_fn=_collate_volumes,
        shuffle=False,
        num_workers=config_tumor['NUM_WORKERS'],
    )
    test_metrics_tumor, sample_output_tumor = evaluate(
        tumor_model,
        tumor_loader,
        eval_metrics,
        config_tumor,
        tumor_postprocessing,
        use_liver_seg=True,
        export_filenames=[p['pred_tumor'] for p in test_data_dict] if export else [],
        export_file_metadata=export_file_metadata,
    )
    print("")
    print("==============================")
    print("Tumor segmentation test performance ....")
    for key, value in test_metrics_tumor.items():
        # only the metrics of class index 2 are reported as the tumor metrics
        if "class2" in key:
            print(f'   {key.replace("_class2", "_tumor")}: {value:.3f}')
    print("==============================")
    print("")

    #### Visualization
    # NOTE(review): the sample outputs are indexed as [0], [1], [2]; these look
    # like (image, mask, prediction) batches -- confirm against `evaluate`.
    if only_tumor:
        display_image = sample_output_tumor[0][0]
    else:
        # combine liver and tumor segmentations into one segmentation output
        sample_output_tumor[2][0][1] = sample_output_liver[2][0][0]
        display_image = sample_output_liver[0][0]

    # visualization
    print("")
    visualize_results(display_image.cpu(), sample_output_tumor[1][0].cpu(), sample_output_tumor[2][0].cpu(), n_slices=5, title="")
    return