osbm committed
Commit 2df1cff
1 Parent(s): 84f60c4

test this out

Files changed (2):
  1. gradio_app.py +4 -2
  2. inference.py +159 -0
gradio_app.py CHANGED
@@ -1,7 +1,9 @@
 import gradio as gr
+from inference import make_inference
 
 def monai_inference(input):
-    pass
+    data_dict = [{"t2": input.name}]
+    return make_inference(data_dict)
 
 demo = gr.Interface(
     fn=monai_inference,
@@ -10,5 +12,5 @@ demo = gr.Interface(
     title="Inference on monai model",
     description="You can upload either zip of dicom folder or .nii.gz file. In turn, you can download the mask as .nii.gz file.",
 )
-
+demo.queue()
 demo.launch()
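
Context on the change: Gradio hands the uploaded file to fn as a temp-file object whose .name attribute holds its server-side path, which is why the handler builds data_dict from input.name; the new demo.queue() call routes requests through a queue so the long-running sliding-window inference is not handled as a plain short request. A minimal sketch of the handler contract follows; SimpleNamespace and the sample path are stand-ins, and note that make_inference as committed also expects a "t2_anatomy_reader1" label key (see inference.py below), so a lone "t2" entry will not get through the transform pipeline yet.

from types import SimpleNamespace
from inference import make_inference

def monai_inference(input):
    # Gradio file uploads expose their server-side temp path via .name
    data_dict = [{"t2": input.name}]
    return make_inference(data_dict)

# Stand-in for the object Gradio passes to fn; with a label key supplied,
# monai_inference(upload) would return the path "predicted.nii.gz".
upload = SimpleNamespace(name="sample_t2.nii.gz")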
inference.py ADDED
@@ -0,0 +1,159 @@
+import monai
+import torch
+import pandas as pd
+import nibabel as nib
+import numpy as np
+from monai.data import DataLoader
+from monai.utils.enums import CommonKeys
+from scipy import ndimage
+from monai.data import Dataset
+from monai.inferers import sliding_window_inference
+from monai.metrics import DiceMetric
+from monai.transforms import (
+    Activationsd,
+    AsDiscreted,
+    Compose,
+    ConcatItemsd,
+    KeepLargestConnectedComponentd,
+    LoadImaged,
+    EnsureChannelFirstd,
+    EnsureTyped,
+    SaveImaged,
+    ScaleIntensityd,
+    NormalizeIntensityd,
+    Spacingd,
+    Orientationd,
+)
+
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# print("Using device:", device)
+
+# model = monai.networks.nets.UNet(
+#     in_channels=1,
+#     out_channels=3,
+#     spatial_dims=3,
+#     channels=[16, 32, 64, 128, 256, 512],
+#     strides=[2, 2, 2, 2, 2],
+#     num_res_units=4,
+#     act="PRELU",
+#     norm="BATCH",
+#     dropout=0.15,
+# )
+
+# model.load_state_dict(torch.load("anatomy.pt", map_location=device))
+
+keys = ("t2", "t2_anatomy_reader1")
+transforms = Compose(
+    [
+        LoadImaged(keys=keys, image_only=False),
+        EnsureChannelFirstd(keys=keys),
+        Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode=("bilinear", "nearest")),
+        Orientationd(keys=keys, axcodes="RAS"),
+        ScaleIntensityd(keys=keys, minv=0, maxv=1),
+        NormalizeIntensityd(keys=keys),
+        EnsureTyped(keys=keys),
+        ConcatItemsd(keys=("t2",), name=CommonKeys.IMAGE, dim=0),
+        ConcatItemsd(keys=("t2_anatomy_reader1",), name=CommonKeys.LABEL, dim=0),
+    ]
+)
+
+postprocessing = Compose(
+    [
+        EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
+        KeepLargestConnectedComponentd(
+            keys=CommonKeys.PRED,
+            applied_labels=list(range(1, 3)),
+        ),
+    ]
+)
+
+inferer = monai.inferers.SlidingWindowInferer(
+    roi_size=(96, 96, 96),
+    sw_batch_size=4,
+    overlap=0.5,
+)
+
+
+def resize_image(image: np.ndarray, target_shape: tuple):
+    depth_factor = target_shape[0] / image.shape[0]
+    width_factor = target_shape[1] / image.shape[1]
+    height_factor = target_shape[2] / image.shape[2]
+
+    return ndimage.zoom(image, (depth_factor, width_factor, height_factor), order=1)
+
+
+# model.eval()
+# with torch.no_grad():
+#     for i in range(len(test_ds)):
+#         example = test_ds[i]
+#         label = example["t2_anatomy_reader1"]
+#         input_tensor = example["t2"].unsqueeze(0)
+#         input_tensor = input_tensor.to(device)
+#         output_tensor = inferer(input_tensor, model)
+#         output_tensor = output_tensor.argmax(dim=1, keepdim=False)
+#         output_tensor = output_tensor.squeeze(0).to(torch.device("cpu"))
+
+#         output_tensor = postprocessing({"pred": output_tensor, "label": label})["pred"]
+#         output_tensor = output_tensor.numpy().astype(np.uint8)
+#         target_shape = example["t2_meta_dict"]["spatial_shape"]
+#         output_tensor = resize_image(output_tensor, target_shape)
+
+#         # flip first two dimensions
+#         output_tensor = np.flip(output_tensor, axis=0)
+#         output_tensor = np.flip(output_tensor, axis=1)
+
+#         new_image = nib.Nifti1Image(output_tensor, affine=example["t2_meta_dict"]["affine"])
+#         nib.save(new_image, f"test/{i+1:03}/predicted.nii.gz")
+
+#         print("Saved", i+1)
+
+
+def make_inference(data_dict: list) -> str:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print("Using device:", device)
+
+    model = monai.networks.nets.UNet(
+        in_channels=1,
+        out_channels=3,
+        spatial_dims=3,
+        channels=[16, 32, 64, 128, 256, 512],
+        strides=[2, 2, 2, 2, 2],
+        num_res_units=4,
+        act="PRELU",
+        norm="BATCH",
+        dropout=0.15,
+    )
+    model.load_state_dict(torch.load("anatomy.pt", map_location=device))
+
+    test_ds = Dataset(
+        data=data_dict,
+        transform=transforms,
+    )
+    model.eval()
+    with torch.no_grad():
+        example = test_ds[0]
+        label = example["t2_anatomy_reader1"]
+        input_tensor = example["t2"].unsqueeze(0)
+        input_tensor = input_tensor.to(device)
+        output_tensor = inferer(input_tensor, model)
+        output_tensor = output_tensor.argmax(dim=1, keepdim=False)
+        output_tensor = output_tensor.squeeze(0).to(torch.device("cpu"))

+        output_tensor = postprocessing({"pred": output_tensor, "label": label})["pred"]
+        output_tensor = output_tensor.numpy().astype(np.uint8)
+        target_shape = example["t2_meta_dict"]["spatial_shape"]
+        output_tensor = resize_image(output_tensor, target_shape)
+
+        # flip first two dimensions
+        output_tensor = np.flip(output_tensor, axis=0)
+        output_tensor = np.flip(output_tensor, axis=1)
+
+        new_image = nib.Nifti1Image(output_tensor, affine=example["t2_meta_dict"]["affine"])
+        nib.save(new_image, "predicted.nii.gz")
+    return "predicted.nii.gz"
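
Usage sketch for the new module (placeholder paths, not files in this repo): because keys = ("t2", "t2_anatomy_reader1") is what LoadImaged receives, the data dict currently has to supply a label volume alongside the T2 image, even though only "t2" feeds the prediction and the label is merely carried through postprocessing.

from inference import make_inference

# Hypothetical input paths; both keys are required because LoadImaged runs
# over the full module-level `keys` tuple.
data_dict = [{
    "t2": "patient01_t2.nii.gz",
    "t2_anatomy_reader1": "patient01_anatomy.nii.gz",
}]

mask_path = make_inference(data_dict)  # writes the mask and returns its path
print("saved:", mask_path)             # -> saved: predicted.nii.gz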