|
import os |
|
import os.path |
|
import sys |
|
import json |
|
import torch |
|
|
|
|
|
# Make the sibling helper modules imported below (dicom_to_nii, nii_to_dicom,
# preprocessing, postprocessing, utils) importable no matter which directory
# the script is launched from.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# NOTE(review): ".." is resolved against the current working directory, not
# against this script's location — confirm that is the intended parent dir.
sys.path.append(os.path.abspath(".."))
|
|
|
|
|
from dicom_to_nii import convert_ct_dicom_to_nii |
|
from nii_to_dicom import convert_nii_to_dicom |
|
|
|
|
|
|
|
from monai.networks.nets import UNet |
|
from monai.networks.layers import Norm |
|
from monai.inferers import sliding_window_inference |
|
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch, NibabelReader |
|
from monai.utils import first |
|
from monai.transforms import ( |
|
EnsureChannelFirstd, |
|
Compose, |
|
CropForegroundd, |
|
ScaleIntensityRanged, |
|
Invertd, |
|
AsDiscreted, |
|
ThresholdIntensityd, |
|
RemoveSmallObjectsd, |
|
KeepLargestConnectedComponentd, |
|
Activationsd |
|
) |
|
|
|
from preprocessing import LoadImaged |
|
|
|
from postprocessing import SaveImaged, add_contours_exist |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from utils import * |
|
|
|
|
|
def predict(tempPath, patient_id, ctSeriesInstanceUID, runInterpreter):
    """Segment the body contour of one CT series and upload it as an RTSTRUCT.

    Pipeline as implemented below:
      1. download the CT series from Orthanc into ``<tempPath>/<patient_id>/ct_dicom``;
      2. convert the DICOM series to ``ct_nii/ct.nii.gz``;
      3. run a 3D MONAI UNet with sliding-window inference on the volume;
      4. post-process the prediction (softmax, invert the pre-transforms,
         argmax + one-hot, keep the largest connected component) and save the
         mask into ``prediction_nii``;
      5. convert the mask to a DICOM RTSTRUCT with a single ``BODY`` structure
         and upload it to Orthanc.

    Args:
        tempPath: Base working directory; per-patient sub-directories are
            created under it.
        patient_id: Patient identifier; used as the sub-directory name and
            passed to ``LoadImaged``.
        ctSeriesInstanceUID: SeriesInstanceUID of the CT series to segment.
        runInterpreter: Interpreter string supplied by the server; only logged.

    Exits the process via ``sys.exit`` when ``patient_id`` or
    ``ctSeriesInstanceUID`` is empty/None, or when the trained model file
    cannot be found.
    """
    if not patient_id:
        sys.exit("No Patient dataset loaded: Load the patient dataset in Study Management.")

    if not ctSeriesInstanceUID:
        sys.exit("No CT series instance UID to load the CT images. Check for CT data in your study")

    print("+++ tempath: ", tempPath)
    print("+++ patient_id: ", patient_id)
    print("+++ CT SeriesInstanceUID: ", ctSeriesInstanceUID)
    print("+++ runInterpreter", runInterpreter)

    # Working layout: <tempPath>/<patient_id>/{ct_dicom, ct_nii, prediction_nii, prediction_dicom}
    dir_base = os.path.join(tempPath, patient_id)
    createdir(dir_base)
    dir_ct_dicom = os.path.join(dir_base, 'ct_dicom')
    createdir(dir_ct_dicom)
    dir_ct_nii = os.path.join(dir_base, "ct_nii")
    createdir(dir_ct_nii)
    dir_prediction_nii = os.path.join(dir_base, 'prediction_nii')
    createdir(dir_prediction_nii)
    dir_prediction_dicom = os.path.join(dir_base, 'prediction_dicom')
    createdir(dir_prediction_dicom)

    # NOTE(review): assumes the post-processing steps (SaveImaged /
    # add_contours_exist) leave the mask under this file name — TODO confirm.
    predictedNiiFile = os.path.join(dir_prediction_nii, 'RTStruct.nii.gz')
    predictedDicomFile = os.path.join(dir_prediction_dicom, 'predicted_rtstruct.dcm')

    model_path = r'best_metric_model.pth'
    if not os.path.exists(model_path):
        # The relative path is resolved against the current working directory;
        # also look next to this script so the model is found when the server
        # launches the script from another directory.
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'best_metric_model.pth')
    if not os.path.exists(model_path):
        sys.exit("Not found the trained model")

    print('** Use python interpreter: ', runInterpreter)
    print('** Patient name: ', patient_id)
    print('** CT Serial instance UID: ', ctSeriesInstanceUID)

    downloadSeriesInstanceByModality(ctSeriesInstanceUID, dir_ct_dicom, "CT")
    print("Loading CT from Orthanc done")

    # refCT is reused later to write the prediction back in the CT's geometry.
    refCT = convert_ct_dicom_to_nii(dir_dicom=dir_ct_dicom, dir_nii=dir_ct_nii, outputname='ct.nii.gz', newvoxelsize=None)
    print("Conversion DICOM to nii done")

    test_Data = [{'image': os.path.join(dir_ct_nii, 'ct.nii.gz')}]

    test_pretransforms = Compose(
        [
            LoadImaged(keys=["image"], reader=NibabelReader(), patientname=patient_id),
            EnsureChannelFirstd(keys=["image"]),
            # Replace values at/above 1560 with 1560 (upper clamp).
            ThresholdIntensityd(keys=["image"], threshold=1560, above=False, cval=1560),
            # Replace values at/below -50 with -1000 (background level).
            ThresholdIntensityd(keys=["image"], threshold=-50, above=True, cval=-1000),
            # Map [-1000, 1560] linearly onto [0, 1]; background becomes 0.
            ScaleIntensityRanged(
                keys=["image"], a_min=-1000, a_max=1560,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            # Crop to the non-background region; undone later by Invertd.
            CropForegroundd(keys=["image"], source_key="image"),
        ]
    )
    test_posttransforms = Compose(
        [
            Activationsd(keys="pred", softmax=True),
            # Undo the pre-transforms (crop etc.) so the mask aligns with the input CT.
            Invertd(
                keys="pred",
                transform=test_pretransforms,
                orig_keys="image",
                nearest_interp=False,
                to_tensor=True,
            ),
            AsDiscreted(keys="pred", argmax=True, to_onehot=2, threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", is_onehot=True),
            SaveImaged(keys="pred", output_postfix='', separate_folder=False, output_dir=dir_prediction_nii, resample=False),
        ]
    )

    test_ds = CacheDataset(data=test_Data, transform=test_pretransforms)
    # Inference on a single volume: shuffling is meaningless, keep order deterministic.
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_param = dict(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,  # background / body
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    model = UNet(**model_param)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()

    d = first(test_loader)
    images = d["image"].to(device)
    # No gradients are needed for inference; disabling autograd avoids
    # keeping the activation graph of every sliding window in memory.
    with torch.no_grad():
        d['pred'] = sliding_window_inference(inputs=images, roi_size=(96, 96, 64), sw_batch_size=1, predictor=model)
    # Post-transforms run per item; SaveImaged writes the mask to disk as a side effect.
    d['pred'] = [test_posttransforms(i) for i in decollate_batch(d)]

    add_contours_exist(preddir=dir_prediction_nii, refCT=refCT)

    convert_nii_to_dicom(dicomctdir=dir_ct_dicom, predictedNiiFile=predictedNiiFile, predictedDicomFile=predictedDicomFile, predicted_structures=['BODY'], rtstruct_colors=[[255, 0, 0]], refCT=refCT)
    print("Conversion nii to DICOM done")

    uploadDicomToOrthanc(predictedDicomFile)
    print("Upload predicted result to Orthanc done")

    print("Body Segmentation prediction done")
|
|
|
''' |
|
Prediction parameters provided by the server. Select the parameters to be used for prediction: |
|
[1] tempPath: Working directory provided by the server (where predict.py is stored); intermediate files are written under tempPath/<patientname>,
|
[2] patientname: ID of the loaded patient dataset,
|
[3] ctSeriesInstanceUID: Series instance UID for the data set with modality = CT. To predict on 'MR' modality data, the corresponding CT UID must be retrieved in code (see Precision Code),
|
[4] rtStructSeriesInstanceUID: Series instance UID for modality = RTSTRUCT,
|
[5] regSeriesInstanceUID: Series instance UID for modality = REG, |
|
[6] runInterpreter: The python version for the python environment |
|
[7] oarList: Only for dose prediction. For contour prediction, oarList = []
|
[8] tvList: Only for dose prediction. For contour prediction, tvList = []
|
''' |
|
|
|
if __name__ == '__main__':
    # The server passes positional arguments. This script reads argv[1]
    # (tempPath), argv[2] (patient id), argv[3] (CT SeriesInstanceUID) and
    # argv[6] (interpreter); the others are accepted but unused here.
    # Fail with a readable message instead of an IndexError traceback.
    if len(sys.argv) < 7:
        sys.exit(
            f"Usage: {os.path.basename(sys.argv[0])} tempPath patientname ctSeriesInstanceUID "
            "rtStructSeriesInstanceUID regSeriesInstanceUID runInterpreter [oarList] [tvList]"
        )
    predict(tempPath=sys.argv[1], patient_id=sys.argv[2], ctSeriesInstanceUID=sys.argv[3], runInterpreter=sys.argv[6])
|
|
|
|
|
|