Spaces:
Running
Running
Create utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code copied and modified from https://huggingface.co/spaces/BAAI/SegVol/blob/main/utils.py
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import matplotlib as mpl
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import nibabel as nib
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from monai.transforms import LoadImage
|
11 |
+
from mrsegmentator import inference
|
12 |
+
from mrsegmentator.utils import add_postfix
|
13 |
+
from PIL import Image, ImageColor, ImageDraw, ImageEnhance
|
14 |
+
from scipy import ndimage
|
15 |
+
from monai.transforms import LoadImage, Orientation, Spacing
|
16 |
+
import SimpleITK as sitk
|
17 |
+
|
18 |
+
import streamlit as st
|
19 |
+
|
20 |
+
initial_rectangle = {
|
21 |
+
"version": "4.4.0",
|
22 |
+
"objects": [
|
23 |
+
{
|
24 |
+
"type": "rect",
|
25 |
+
"version": "4.4.0",
|
26 |
+
"originX": "left",
|
27 |
+
"originY": "top",
|
28 |
+
"left": 50,
|
29 |
+
"top": 50,
|
30 |
+
"width": 100,
|
31 |
+
"height": 100,
|
32 |
+
"fill": "rgba(255, 165, 0, 0.3)",
|
33 |
+
"stroke": "#2909F1",
|
34 |
+
"strokeWidth": 3,
|
35 |
+
"strokeDashArray": None,
|
36 |
+
"strokeLineCap": "butt",
|
37 |
+
"strokeDashOffset": 0,
|
38 |
+
"strokeLineJoin": "miter",
|
39 |
+
"strokeUniform": True,
|
40 |
+
"strokeMiterLimit": 4,
|
41 |
+
"scaleX": 1,
|
42 |
+
"scaleY": 1,
|
43 |
+
"angle": 0,
|
44 |
+
"flipX": False,
|
45 |
+
"flipY": False,
|
46 |
+
"opacity": 1,
|
47 |
+
"shadow": None,
|
48 |
+
"visible": True,
|
49 |
+
"backgroundColor": "",
|
50 |
+
"fillRule": "nonzero",
|
51 |
+
"paintFirst": "fill",
|
52 |
+
"globalCompositeOperation": "source-over",
|
53 |
+
"skewX": 0,
|
54 |
+
"skewY": 0,
|
55 |
+
"rx": 0,
|
56 |
+
"ry": 0,
|
57 |
+
}
|
58 |
+
],
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
def run(tmpdirname):
|
63 |
+
if st.session_state.option is not None:
|
64 |
+
image = Path(__file__).parent / str(st.session_state.option)
|
65 |
+
|
66 |
+
inference.infer([image], tmpdirname, [0], split_level=1)
|
67 |
+
seg_name = add_postfix(image.name, "seg")
|
68 |
+
preds_path = tmpdirname + "/" + seg_name
|
69 |
+
st.session_state.preds_3D = read_image(preds_path)
|
70 |
+
st.session_state.preds_3D_ori = sitk.ReadImage(preds_path)
|
71 |
+
|
72 |
+
|
73 |
+
def reflect_box_into_model(box_3d):
|
74 |
+
z1, y1, x1, z2, y2, x2 = box_3d
|
75 |
+
x1_prompt = int(x1 * 256.0 / 325.0)
|
76 |
+
y1_prompt = int(y1 * 256.0 / 325.0)
|
77 |
+
z1_prompt = int(z1 * 32.0 / 325.0)
|
78 |
+
x2_prompt = int(x2 * 256.0 / 325.0)
|
79 |
+
y2_prompt = int(y2 * 256.0 / 325.0)
|
80 |
+
z2_prompt = int(z2 * 32.0 / 325.0)
|
81 |
+
return torch.tensor(
|
82 |
+
np.array([z1_prompt, y1_prompt, x1_prompt, z2_prompt, y2_prompt, x2_prompt])
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
def reflect_json_data_to_3D_box(json_data, view):
|
87 |
+
if view == "xy":
|
88 |
+
st.session_state.rectangle_3Dbox[1] = json_data["objects"][0]["top"]
|
89 |
+
st.session_state.rectangle_3Dbox[2] = json_data["objects"][0]["left"]
|
90 |
+
st.session_state.rectangle_3Dbox[4] = (
|
91 |
+
json_data["objects"][0]["top"]
|
92 |
+
+ json_data["objects"][0]["height"] * json_data["objects"][0]["scaleY"]
|
93 |
+
)
|
94 |
+
st.session_state.rectangle_3Dbox[5] = (
|
95 |
+
json_data["objects"][0]["left"]
|
96 |
+
+ json_data["objects"][0]["width"] * json_data["objects"][0]["scaleX"]
|
97 |
+
)
|
98 |
+
print(st.session_state.rectangle_3Dbox)
|
99 |
+
|
100 |
+
|
101 |
+
def make_fig(image, preds, px_range = (10, 400), transparency=0.5):
|
102 |
+
|
103 |
+
fig, ax = plt.subplots(1, 1, figsize=(4,4))
|
104 |
+
image_slice = image.clip(*px_range)
|
105 |
+
|
106 |
+
ax.imshow(
|
107 |
+
image_slice,
|
108 |
+
cmap="Greys_r",
|
109 |
+
vmin=px_range[0],
|
110 |
+
vmax=px_range[1],
|
111 |
+
)
|
112 |
+
|
113 |
+
if preds is not None:
|
114 |
+
image_slice = np.array(preds)
|
115 |
+
alpha = np.zeros(image_slice.shape)
|
116 |
+
alpha[image_slice > 0.1] = transparency
|
117 |
+
ax.imshow(
|
118 |
+
image_slice,
|
119 |
+
cmap="jet",
|
120 |
+
alpha=alpha,
|
121 |
+
vmin=0,
|
122 |
+
vmax=40,
|
123 |
+
)
|
124 |
+
|
125 |
+
# plot edges
|
126 |
+
edge_slice = np.zeros(image_slice.shape, dtype=int)
|
127 |
+
|
128 |
+
for i in np.unique(image_slice):
|
129 |
+
_slice = image_slice.copy()
|
130 |
+
_slice[_slice != i] = 0
|
131 |
+
edges = ndimage.laplace(_slice)
|
132 |
+
edge_slice[edges != 0] = i
|
133 |
+
|
134 |
+
cmap = mpl.cm.jet(np.linspace(0, 1, int(preds.max())))
|
135 |
+
cmap -= 0.4
|
136 |
+
cmap = cmap.clip(0, 1)
|
137 |
+
cmap = mpl.colors.ListedColormap(cmap)
|
138 |
+
|
139 |
+
alpha = np.zeros(edge_slice.shape)
|
140 |
+
alpha[edge_slice > 0.01] = 0.9
|
141 |
+
|
142 |
+
ax.imshow(
|
143 |
+
edge_slice,
|
144 |
+
alpha=alpha,
|
145 |
+
cmap=cmap,
|
146 |
+
vmin=0,
|
147 |
+
vmax=40,
|
148 |
+
)
|
149 |
+
|
150 |
+
plt.axis("off")
|
151 |
+
ax.set_xticks([])
|
152 |
+
ax.set_yticks([])
|
153 |
+
|
154 |
+
fig.canvas.draw()
|
155 |
+
|
156 |
+
# transform to image
|
157 |
+
return Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
158 |
+
|
159 |
+
|
160 |
+
#######################################
|
161 |
+
|
162 |
+
|
163 |
+
def make_isotropic(image, interpolator = sitk.sitkLinear, spacing = None):
|
164 |
+
'''
|
165 |
+
Many file formats (e.g. jpg, png,...) expect the pixels to be isotropic, same
|
166 |
+
spacing for all axes. Saving non-isotropic data in these formats will result in
|
167 |
+
distorted images. This function makes an image isotropic via resampling, if needed.
|
168 |
+
Args:
|
169 |
+
image (SimpleITK.Image): Input image.
|
170 |
+
interpolator: By default the function uses a linear interpolator. For
|
171 |
+
label images one should use the sitkNearestNeighbor interpolator
|
172 |
+
so as not to introduce non-existant labels.
|
173 |
+
spacing (float): Desired spacing. If none given then use the smallest spacing from
|
174 |
+
the original image.
|
175 |
+
Returns:
|
176 |
+
SimpleITK.Image with isotropic spacing which occupies the same region in space as
|
177 |
+
the input image.
|
178 |
+
'''
|
179 |
+
original_spacing = image.GetSpacing()
|
180 |
+
# Image is already isotropic, just return a copy.
|
181 |
+
if all(spc == original_spacing[0] for spc in original_spacing):
|
182 |
+
return sitk.Image(image)
|
183 |
+
# Make image isotropic via resampling.
|
184 |
+
original_size = image.GetSize()
|
185 |
+
if spacing is None:
|
186 |
+
spacing = min(original_spacing)
|
187 |
+
new_spacing = [spacing]*image.GetDimension()
|
188 |
+
new_size = [int(round(osz*ospc/spacing)) for osz, ospc in zip(original_size, original_spacing)]
|
189 |
+
return sitk.Resample(image, new_size, sitk.Transform(), interpolator,
|
190 |
+
image.GetOrigin(), new_spacing, image.GetDirection(), 0, # default pixel value
|
191 |
+
image.GetPixelID())
|
192 |
+
|
193 |
+
|
194 |
+
def read_image(path):
|
195 |
+
|
196 |
+
img = sitk.ReadImage(path)
|
197 |
+
img = sitk.DICOMOrient(img, "LPS")
|
198 |
+
img = make_isotropic(img)
|
199 |
+
img = sitk.GetArrayFromImage(img)
|
200 |
+
|
201 |
+
return img
|