DiGuaQiu commited on
Commit
aa2b9a2
1 Parent(s): c8ffaae

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +201 -0
utils.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code copied and modified from https://huggingface.co/spaces/BAAI/SegVol/blob/main/utils.py
2
+
3
+ from pathlib import Path
4
+
5
+ import matplotlib as mpl
6
+ import matplotlib.pyplot as plt
7
+ import nibabel as nib
8
+ import numpy as np
9
+ import torch
10
+ from monai.transforms import LoadImage
11
+ from mrsegmentator import inference
12
+ from mrsegmentator.utils import add_postfix
13
+ from PIL import Image, ImageColor, ImageDraw, ImageEnhance
14
+ from scipy import ndimage
15
+ from monai.transforms import LoadImage, Orientation, Spacing
16
+ import SimpleITK as sitk
17
+
18
+ import streamlit as st
19
+
20
+ initial_rectangle = {
21
+ "version": "4.4.0",
22
+ "objects": [
23
+ {
24
+ "type": "rect",
25
+ "version": "4.4.0",
26
+ "originX": "left",
27
+ "originY": "top",
28
+ "left": 50,
29
+ "top": 50,
30
+ "width": 100,
31
+ "height": 100,
32
+ "fill": "rgba(255, 165, 0, 0.3)",
33
+ "stroke": "#2909F1",
34
+ "strokeWidth": 3,
35
+ "strokeDashArray": None,
36
+ "strokeLineCap": "butt",
37
+ "strokeDashOffset": 0,
38
+ "strokeLineJoin": "miter",
39
+ "strokeUniform": True,
40
+ "strokeMiterLimit": 4,
41
+ "scaleX": 1,
42
+ "scaleY": 1,
43
+ "angle": 0,
44
+ "flipX": False,
45
+ "flipY": False,
46
+ "opacity": 1,
47
+ "shadow": None,
48
+ "visible": True,
49
+ "backgroundColor": "",
50
+ "fillRule": "nonzero",
51
+ "paintFirst": "fill",
52
+ "globalCompositeOperation": "source-over",
53
+ "skewX": 0,
54
+ "skewY": 0,
55
+ "rx": 0,
56
+ "ry": 0,
57
+ }
58
+ ],
59
+ }
60
+
61
+
62
+ def run(tmpdirname):
63
+ if st.session_state.option is not None:
64
+ image = Path(__file__).parent / str(st.session_state.option)
65
+
66
+ inference.infer([image], tmpdirname, [0], split_level=1)
67
+ seg_name = add_postfix(image.name, "seg")
68
+ preds_path = tmpdirname + "/" + seg_name
69
+ st.session_state.preds_3D = read_image(preds_path)
70
+ st.session_state.preds_3D_ori = sitk.ReadImage(preds_path)
71
+
72
+
73
+ def reflect_box_into_model(box_3d):
74
+ z1, y1, x1, z2, y2, x2 = box_3d
75
+ x1_prompt = int(x1 * 256.0 / 325.0)
76
+ y1_prompt = int(y1 * 256.0 / 325.0)
77
+ z1_prompt = int(z1 * 32.0 / 325.0)
78
+ x2_prompt = int(x2 * 256.0 / 325.0)
79
+ y2_prompt = int(y2 * 256.0 / 325.0)
80
+ z2_prompt = int(z2 * 32.0 / 325.0)
81
+ return torch.tensor(
82
+ np.array([z1_prompt, y1_prompt, x1_prompt, z2_prompt, y2_prompt, x2_prompt])
83
+ )
84
+
85
+
86
+ def reflect_json_data_to_3D_box(json_data, view):
87
+ if view == "xy":
88
+ st.session_state.rectangle_3Dbox[1] = json_data["objects"][0]["top"]
89
+ st.session_state.rectangle_3Dbox[2] = json_data["objects"][0]["left"]
90
+ st.session_state.rectangle_3Dbox[4] = (
91
+ json_data["objects"][0]["top"]
92
+ + json_data["objects"][0]["height"] * json_data["objects"][0]["scaleY"]
93
+ )
94
+ st.session_state.rectangle_3Dbox[5] = (
95
+ json_data["objects"][0]["left"]
96
+ + json_data["objects"][0]["width"] * json_data["objects"][0]["scaleX"]
97
+ )
98
+ print(st.session_state.rectangle_3Dbox)
99
+
100
+
101
+ def make_fig(image, preds, px_range = (10, 400), transparency=0.5):
102
+
103
+ fig, ax = plt.subplots(1, 1, figsize=(4,4))
104
+ image_slice = image.clip(*px_range)
105
+
106
+ ax.imshow(
107
+ image_slice,
108
+ cmap="Greys_r",
109
+ vmin=px_range[0],
110
+ vmax=px_range[1],
111
+ )
112
+
113
+ if preds is not None:
114
+ image_slice = np.array(preds)
115
+ alpha = np.zeros(image_slice.shape)
116
+ alpha[image_slice > 0.1] = transparency
117
+ ax.imshow(
118
+ image_slice,
119
+ cmap="jet",
120
+ alpha=alpha,
121
+ vmin=0,
122
+ vmax=40,
123
+ )
124
+
125
+ # plot edges
126
+ edge_slice = np.zeros(image_slice.shape, dtype=int)
127
+
128
+ for i in np.unique(image_slice):
129
+ _slice = image_slice.copy()
130
+ _slice[_slice != i] = 0
131
+ edges = ndimage.laplace(_slice)
132
+ edge_slice[edges != 0] = i
133
+
134
+ cmap = mpl.cm.jet(np.linspace(0, 1, int(preds.max())))
135
+ cmap -= 0.4
136
+ cmap = cmap.clip(0, 1)
137
+ cmap = mpl.colors.ListedColormap(cmap)
138
+
139
+ alpha = np.zeros(edge_slice.shape)
140
+ alpha[edge_slice > 0.01] = 0.9
141
+
142
+ ax.imshow(
143
+ edge_slice,
144
+ alpha=alpha,
145
+ cmap=cmap,
146
+ vmin=0,
147
+ vmax=40,
148
+ )
149
+
150
+ plt.axis("off")
151
+ ax.set_xticks([])
152
+ ax.set_yticks([])
153
+
154
+ fig.canvas.draw()
155
+
156
+ # transform to image
157
+ return Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
158
+
159
+
160
+ #######################################
161
+
162
+
163
+ def make_isotropic(image, interpolator = sitk.sitkLinear, spacing = None):
164
+ '''
165
+ Many file formats (e.g. jpg, png,...) expect the pixels to be isotropic, same
166
+ spacing for all axes. Saving non-isotropic data in these formats will result in
167
+ distorted images. This function makes an image isotropic via resampling, if needed.
168
+ Args:
169
+ image (SimpleITK.Image): Input image.
170
+ interpolator: By default the function uses a linear interpolator. For
171
+ label images one should use the sitkNearestNeighbor interpolator
172
+ so as not to introduce non-existant labels.
173
+ spacing (float): Desired spacing. If none given then use the smallest spacing from
174
+ the original image.
175
+ Returns:
176
+ SimpleITK.Image with isotropic spacing which occupies the same region in space as
177
+ the input image.
178
+ '''
179
+ original_spacing = image.GetSpacing()
180
+ # Image is already isotropic, just return a copy.
181
+ if all(spc == original_spacing[0] for spc in original_spacing):
182
+ return sitk.Image(image)
183
+ # Make image isotropic via resampling.
184
+ original_size = image.GetSize()
185
+ if spacing is None:
186
+ spacing = min(original_spacing)
187
+ new_spacing = [spacing]*image.GetDimension()
188
+ new_size = [int(round(osz*ospc/spacing)) for osz, ospc in zip(original_size, original_spacing)]
189
+ return sitk.Resample(image, new_size, sitk.Transform(), interpolator,
190
+ image.GetOrigin(), new_spacing, image.GetDirection(), 0, # default pixel value
191
+ image.GetPixelID())
192
+
193
+
194
+ def read_image(path):
195
+
196
+ img = sitk.ReadImage(path)
197
+ img = sitk.DICOMOrient(img, "LPS")
198
+ img = make_isotropic(img)
199
+ img = sitk.GetArrayFromImage(img)
200
+
201
+ return img