File size: 3,486 Bytes
a3290d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from pathlib import Path
from time import time
from typing import Union

from totalsegmentator.libs import (
    download_pretrained_weights,
    nostdout,
    setup_nnunet,
)

from comp2comp.contrast_phase.contrast_inf import predict_phase
from comp2comp.inference_class_base import InferenceClass


class ContrastPhaseDetection(InferenceClass):
    """Contrast Phase Detection.

    Segments ``converted_dcm.nii.gz`` (in the pipeline's ``segmentations/``
    output folder) with TotalSegmentator's nnU-Net (task 251). It then calls
    ``predict_phase`` with the segmentation path, the image path and the
    pipeline output directory.
    """

    def __init__(self, input_path):
        """Store the input path.

        Args:
            input_path: Path of the study to process. It is kept on the
                instance; ``__call__`` does not read it and segments the
                ``converted_dcm.nii.gz`` file in the segmentations folder
                instead.
        """
        super().__init__()
        self.input_path = input_path

    def __call__(self, inference_pipeline):
        """Segment the converted image and run contrast-phase prediction.

        Args:
            inference_pipeline: Pipeline object whose ``output_dir`` and
                ``model_dir`` attributes are read.

        Returns:
            dict: Always an empty dict.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.output_dir_segmentations, exist_ok=True)
        self.model_dir = inference_pipeline.model_dir

        # NOTE(review): assumes an earlier pipeline step wrote
        # converted_dcm.nii.gz into this folder — confirm against the pipeline.
        img_nifti_path = os.path.join(
            self.output_dir_segmentations, "converted_dcm.nii.gz"
        )
        seg_nifti_path = os.path.join(self.output_dir_segmentations, "s01.nii.gz")

        # The in-memory results are not needed here: predict_phase is handed
        # the file paths, and the segmentation is written to seg_nifti_path.
        self.run_segmentation(img_nifti_path, seg_nifti_path, self.model_dir)

        predict_phase(seg_nifti_path, img_nifti_path, outputPath=self.output_dir)

        return {}

    def run_segmentation(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run TotalSegmentator's nnU-Net (task 251) on a NIfTI image.

        Args:
            input_path (Union[str, Path]): Input image path.
            output_path (Union[str, Path]): Output path for the multilabel
                segmentation.
            model_dir (Union[str, Path]): Exported as the ``SCRATCH``
                environment variable before ``setup_nnunet()`` runs.
                Presumably this is where the nnU-Net weights live; confirm
                against the totalsegmentator fork in use.

        Returns:
            tuple: ``(seg, img)``, the two objects returned by
            ``nnUNet_predict_image``, in reverse of the order it returns them.
        """
        print("Segmenting...")
        st = time()
        # Use the argument (not self.model_dir) so the method also works when
        # called outside __call__; os.environ only accepts str values.
        os.environ["SCRATCH"] = os.fspath(model_dir)

        # nnU-Net configuration
        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        # A single int task id: these weights are downloaded, then used below.
        task_id = 251

        setup_nnunet()
        download_pretrained_weights(task_id)

        # Imported only after setup_nnunet() has run, as in the original flow.
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag=None,
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for segmentation
        print(f"Total time for segmentation: {end-st:.2f}s.")

        return seg, img

    def convertNibToNumpy(self, TSNib, ImageNib):
        """Convert nifti to numpy array.

        Args:
            TSNib (nibabel.nifti1.Nifti1Image): TotalSegmentator output.
            ImageNib (nibabel.nifti1.Nifti1Image): Input image.

        Returns:
            numpy.ndarray: TotalSegmentator output, via ``get_fdata()``.
            numpy.ndarray: Input image, via ``get_fdata()``.
        """
        TS_array = TSNib.get_fdata()
        img_array = ImageNib.get_fdata()
        return TS_array, img_array