File size: 8,021 Bytes
5e2c32d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import os.path
import sys
import json
import torch

# +++++++++++++ Conversion imports +++++++++++++++++
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.abspath(".."))
# +++++++++++++ Conversion imports +++++++++++++++++

from dicom_to_nii import convert_ct_dicom_to_nii
from nii_to_dicom import convert_nii_to_dicom


# AI MONAI libraries
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch, NibabelReader
from monai.utils import first
from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    ScaleIntensityRanged,
    Invertd,
    AsDiscreted,
    ThresholdIntensityd,
    RemoveSmallObjectsd,
    KeepLargestConnectedComponentd,
    Activationsd
)
# Preprocessing
from preprocessing import LoadImaged
# Postprocessing
from postprocessing import SaveImaged, add_contours_exist
import matplotlib.pyplot as plt
import numpy as np
from utils import *


def predict(tempPath, patient_id, ctSeriesInstanceUID, runInterpreter):
    """Segment the body contour of one CT series and upload it as an RTSTRUCT.

    Pipeline: download the CT series from Orthanc, convert it to NIfTI, run a
    3D UNet with sliding-window inference, post-process the prediction,
    convert it to a DICOM RTSTRUCT containing the structure ``BODY`` and
    upload that file back to Orthanc.

    Args:
        tempPath: Base working directory; the per-patient directory
            ``<tempPath>/<patient_id>`` and its sub-directories are created
            beneath it.
        patient_id: Patient identifier; names the working directory and is
            passed to the image loader.
        ctSeriesInstanceUID: SeriesInstanceUID of the CT series to segment.
        runInterpreter: Python interpreter reported by the server; it is
            only printed for logging.

    Raises:
        SystemExit: If ``patient_id`` or ``ctSeriesInstanceUID`` is empty,
            or if the trained model file cannot be found.
    """
    # Validate the input parameters. `not x` covers both None and "".
    if not patient_id:
        sys.exit("No Patient dataset loaded: Load the patient dataset in Study Management.")

    if not ctSeriesInstanceUID:
        sys.exit("No CT series instance UID to load the CT images. Check for CT data in your study")

    print("+++ tempath: ", tempPath)
    print("+++ patient_id: ", patient_id)
    print("+++ CT SeriesInstanceUID: ", ctSeriesInstanceUID)
    print("+++ runInterpreter", runInterpreter)

    # Locate the trained weights before creating any directories, so a missing
    # model fails fast. The relative path is tried against the current working
    # directory first (the original behaviour), then next to this script,
    # mirroring how the script already puts its own directory on sys.path.
    model_path = r'best_metric_model.pth'
    if not os.path.exists(model_path):
        script_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), model_path)
        if not os.path.exists(script_model_path):
            sys.exit("Not found the trained model")
        model_path = script_model_path

    # Per-patient working directories: <tempPath>/<patient_id>/<stage>
    dir_base = os.path.join(tempPath, patient_id)
    createdir(dir_base)
    dir_ct_dicom = os.path.join(dir_base, 'ct_dicom')
    createdir(dir_ct_dicom)
    dir_ct_nii = os.path.join(dir_base, "ct_nii")
    createdir(dir_ct_nii)
    dir_prediction_nii = os.path.join(dir_base, 'prediction_nii')
    createdir(dir_prediction_nii)
    dir_prediction_dicom = os.path.join(dir_base, 'prediction_dicom')
    createdir(dir_prediction_dicom)

    # Output files of the prediction
    predictedNiiFile = os.path.join(dir_prediction_nii, 'RTStruct.nii.gz')
    predictedDicomFile = os.path.join(dir_prediction_dicom, 'predicted_rtstruct.dcm')

    print('** Use python interpreter: ', runInterpreter)
    print('** Patient name: ', patient_id)
    print('** CT Serial instance UID: ', ctSeriesInstanceUID)

    # Fetch the CT series from Orthanc into dir_ct_dicom
    downloadSeriesInstanceByModality(ctSeriesInstanceUID, dir_ct_dicom, "CT")
    print("Loading CT from Orthanc done")

    # Convert DICOM -> NIfTI. refCT is passed on to the post-processing and to
    # the NIfTI -> DICOM conversion below.
    refCT = convert_ct_dicom_to_nii(dir_dicom=dir_ct_dicom, dir_nii=dir_ct_nii, outputname='ct.nii.gz', newvoxelsize=None)
    print("Conversion DICOM to nii done")

    # Single-sample dataset: the CT volume to segment
    test_data = [{'image': os.path.join(dir_ct_nii, 'ct.nii.gz')}]

    test_pretransforms = Compose(
        [
            LoadImaged(keys=["image"], reader=NibabelReader(), patientname=patient_id),
            EnsureChannelFirstd(keys=["image"]),
            # Keep values below 1560, replace the others with 1560
            ThresholdIntensityd(keys=["image"], threshold=1560, above=False, cval=1560),
            # Keep values above -50, replace the others with -1000
            ThresholdIntensityd(keys=["image"], threshold=-50, above=True, cval=-1000),
            # Map [-1000, 1560] linearly onto [0, 1], clipping outside values
            ScaleIntensityRanged(
                keys=["image"], a_min=-1000, a_max=1560,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image"], source_key="image"),
        ]
    )
    test_posttransforms = Compose(
        [
            Activationsd(keys="pred", softmax=True),
            # Undo the pre-transforms (e.g. the foreground crop) on `pred`, using
            # the transform history recorded on `image`. Interpolation is left
            # non-nearest for a smooth result; discretization happens afterwards.
            Invertd(
                keys="pred",
                transform=test_pretransforms,
                orig_keys="image",
                nearest_interp=False,
                to_tensor=True,
            ),
            # argmax over channels, then one-hot encode into 2 channels
            AsDiscreted(keys="pred", argmax=True, to_onehot=2, threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", is_onehot=True),
            # Writes the prediction NIfTI into dir_prediction_nii
            SaveImaged(keys="pred", output_postfix='', separate_folder=False, output_dir=dir_prediction_nii, resample=False),
        ]
    )

    test_ds = CacheDataset(data=test_data, transform=test_pretransforms)
    # Shuffling has no purpose when running inference on a single volume.
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_param = dict(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    model = UNet(**model_param)
    # Load the weights straight onto the device the model will run on.
    trained_model_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(trained_model_dict)
    model = model.to(device)
    model.eval()

    batch = first(test_loader)
    images = batch["image"].to(device)
    # Inference only: disable autograd so the sliding-window forward passes do
    # not build a computation graph and keep activations alive in memory.
    with torch.no_grad():
        batch['pred'] = sliding_window_inference(inputs=images, roi_size=(96, 96, 64), sw_batch_size=1, predictor=model)
    # Post-processing runs per sample; its SaveImaged step writes the NIfTI.
    batch['pred'] = [test_posttransforms(item) for item in decollate_batch(batch)]

    add_contours_exist(preddir=dir_prediction_nii, refCT=refCT)

    # Convert the NIfTI prediction into a DICOM RTSTRUCT with a single BODY structure
    convert_nii_to_dicom(dicomctdir=dir_ct_dicom, predictedNiiFile=predictedNiiFile, predictedDicomFile=predictedDicomFile, predicted_structures=['BODY'], rtstruct_colors=[[255, 0, 0]], refCT=refCT)
    print("Conversion nii to DICOM done")

    # Transfer the predicted DICOM to Orthanc
    uploadDicomToOrthanc(predictedDicomFile)
    print("Upload predicted result to Orthanc done")

    print("Body Segmentation prediction done")

'''
Prediction parameters provided by the server as positional command-line arguments (sys.argv). Select the parameters to be used for prediction:
    [1] tempPath: The path where the predict.py is stored; the per-patient working directory <tempPath>/<patientname> is created under it,
    [2] patientname: Patient name/ID of the loaded dataset,
    [3] ctSeriesInstanceUID: Series instance UID for data set with modality = CT. To predict 'MR' modality data, retrieve the CT UID by the code (see Precision Code)
    [4] rtStructSeriesInstanceUID: Series instance UID for modality = RTSTRUCT
    [5] regSeriesInstanceUID: Series instance UID for modality = REG,
    [6] runInterpreter: The python version for the python environment
    [7] oarList: only for dose prediction. For contour prediction oarList = []
    [8] tvList:  only for dose prediction. For contour prediction tvList = []
This body segmentation script uses only [1], [2], [3] and [6].
''' 

if __name__ == '__main__':
    # Positional arguments come from the server (see the list above); only
    # the temp path, patient id, CT series UID and interpreter are needed.
    argv = sys.argv
    predict(
        tempPath=argv[1],
        patient_id=argv[2],
        ctSeriesInstanceUID=argv[3],
        runInterpreter=argv[6],
    )