andreped commited on
Commit
c5a5e6c
2 Parent(s): 06664ab 4e18454

Merge pull request #24 from andreped/improved-demo-ui

Browse files
Dockerfile CHANGED
@@ -22,6 +22,8 @@ WORKDIR /code
22
  RUN apt-get update -y
23
  RUN apt install git --fix-missing -y
24
 
 
 
25
  # install dependencies
26
  COPY ./demo/requirements.txt /code/demo/requirements.txt
27
  RUN python3.7 -m pip install --no-cache-dir --upgrade -r /code/demo/requirements.txt
@@ -32,8 +34,6 @@ RUN python3.7 -m pip install --force-reinstall typing_extensions==4.0.0
32
  # Install wget
33
  RUN apt install wget -y
34
 
35
- RUN ls -la
36
-
37
  # Set up a new user named "user" with user ID 1000
38
  RUN useradd -m -u 1000 user
39
 
 
22
  RUN apt-get update -y
23
  RUN apt install git --fix-missing -y
24
 
25
+ RUN ls -la
26
+
27
  # install dependencies
28
  COPY ./demo/requirements.txt /code/demo/requirements.txt
29
  RUN python3.7 -m pip install --no-cache-dir --upgrade -r /code/demo/requirements.txt
 
34
  # Install wget
35
  RUN apt install wget -y
36
 
 
 
37
  # Set up a new user named "user" with user ID 1000
38
  RUN useradd -m -u 1000 user
39
 
demo/README.md CHANGED
@@ -40,6 +40,15 @@ of the predicted liver parenchyma 3D volume when finished processing.
40
  Analysis process can be monitored from the `Logs` tab next to the `Running` button
41
  in the Hugging Face `livermask` space.
42
 
 
 
 
 
 
 
 
 
 
43
  Natural future TODOs include:
44
  - [ ] Add gallery widget to enable scrolling through 2D slices
45
  - [ ] Render segmentation for individual 2D slices as overlays
 
40
  Analysis process can be monitored from the `Logs` tab next to the `Running` button
41
  in the Hugging Face `livermask` space.
42
 
43
+ It is also possible to build the app as a docker image and deploy it. To do so follow these steps:
44
+
45
+ ```
46
+ docker build -t livermask ..
47
+ docker run -it -p 7860:7860 livermask
48
+ ```
49
+
50
+ Then open `http://127.0.0.1:7860` in your favourite internet browser to view the demo.
51
+
52
  Natural future TODOs include:
53
  - [ ] Add gallery widget to enable scrolling through 2D slices
54
  - [ ] Render segmentation for individual 2D slices as overlays
demo/app.py CHANGED
@@ -1,53 +1,16 @@
1
- import gradio as gr
2
- import subprocess as sp
3
- from skimage.measure import marching_cubes
4
- import nibabel as nib
5
- from nibabel.processing import resample_to_output
6
 
7
 
8
- def nifti_to_glb(path):
9
- # load NIFTI into numpy array
10
- image = nib.load(path)
11
- resampled = resample_to_output(image, [1, 1, 1], order=1)
12
- data = resampled.get_fdata().astype("uint8")
13
-
14
- # extract surface
15
- verts, faces, normals, values = marching_cubes(data, 0)
16
- faces += 1
17
-
18
- with open('prediction.obj', 'w') as thefile:
19
- for item in verts:
20
- thefile.write("v {0} {1} {2}\n".format(item[0],item[1],item[2]))
21
-
22
- for item in normals:
23
- thefile.write("vn {0} {1} {2}\n".format(item[0],item[1],item[2]))
24
-
25
- for item in faces:
26
- thefile.write("f {0}//{0} {1}//{1} {2}//{2}\n".format(item[0],item[1],item[2]))
27
-
28
-
29
- def run_model(input_path):
30
- from livermask.utils.run import run_analysis
31
-
32
- run_analysis(cpu=True, extension='.nii', path=input_path, output='prediction', verbose=True, vessels=False, name="/home/user/app/model.h5", mp_enabled=False)
33
 
 
 
34
 
35
- def load_mesh(mesh_file_name):
36
- path = mesh_file_name.name
37
- run_model(path)
38
- nifti_to_glb("prediction-livermask.nii")
39
- return "./prediction.obj"
40
 
41
 
42
  if __name__ == "__main__":
43
- print("Launching demo...")
44
- demo = gr.Interface(
45
- fn=load_mesh,
46
- inputs=gr.UploadButton(label="Click to Upload a File", file_types=[".nii", ".nii.nz"], file_count="single"),
47
- outputs=gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
48
- title="livermask: Automatic Liver Parenchyma segmentation in CT",
49
- description="Using pretrained deep learning model trained on the LiTS17 dataset",
50
- )
51
- # sharing app publicly -> share=True: https://gradio.app/sharing-your-app/
52
- # inference times > 60 seconds -> need queue(): https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
53
- demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
+ from src.gui import WebUI
 
 
 
 
2
 
3
 
4
+ def main():
5
+ print("Launching demo...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ model_name = "/home/user/app/model.h5" # "/Users/andreped/workspace/livermask/model.h5"
8
+ class_name = "parenchyma"
9
 
10
+ # initialize and run app
11
+ app = WebUI(model_name=model_name, class_name=class_name)
12
+ app.run()
 
 
13
 
14
 
15
  if __name__ == "__main__":
16
+ main()
 
 
 
 
 
 
 
 
 
 
demo/src/__init__.py ADDED
File without changes
demo/src/compute.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+
3
+ def run_model(input_path, model_name="/home/user/app/model.h5"):
4
+ from livermask.utils.run import run_analysis
5
+ run_analysis(cpu=True, extension='.nii', path=input_path, output='prediction', verbose=True, vessels=False, name=model_name, mp_enabled=False)
6
+
demo/src/convert.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nibabel as nib
2
+ from nibabel.processing import resample_to_output
3
+ from skimage.measure import marching_cubes
4
+
5
+
6
+ def nifti_to_glb(path, output="prediction.obj"):
7
+ # load NIFTI into numpy array
8
+ image = nib.load(path)
9
+ resampled = resample_to_output(image, [1, 1, 1], order=1)
10
+ data = resampled.get_fdata().astype("uint8")
11
+
12
+ # extract surface
13
+ verts, faces, normals, values = marching_cubes(data, 0)
14
+ faces += 1
15
+
16
+ with open(output, 'w') as thefile:
17
+ for item in verts:
18
+ thefile.write("v {0} {1} {2}\n".format(item[0],item[1],item[2]))
19
+
20
+ for item in normals:
21
+ thefile.write("vn {0} {1} {2}\n".format(item[0],item[1],item[2]))
22
+
23
+ for item in faces:
24
+ thefile.write("f {0}//{0} {1}//{1} {2}//{2}\n".format(item[0],item[1],item[2]))
demo/src/gui.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from .utils import load_ct_to_numpy, load_pred_volume_to_numpy
3
+ from .compute import run_model
4
+ from .convert import nifti_to_glb
5
+
6
+
7
+ class WebUI:
8
+ def __init__(self, model_name, class_name):
9
+ # global states
10
+ self.images = []
11
+ self.pred_images = []
12
+
13
+ self.nb_slider_items = 100
14
+
15
+ self.model_name = model_name
16
+ self.class_name = class_name
17
+
18
+ # define widgets not to be rendered immediantly, but later on
19
+ self.slider = gr.Slider(1, self.nb_slider_items, value=1, step=1, label="Which 2D slice to show")
20
+ self.volume_renderer = gr.Model3D(
21
+ clear_color=[0.0, 0.0, 0.0, 0.0],
22
+ label="3D Model",
23
+ visible=True
24
+ ).style(height=512)
25
+
26
+ def combine_ct_and_seg(self, img, pred):
27
+ return (img, [(pred, self.class_name)])
28
+
29
+ def upload_file(self, file):
30
+ return file.name
31
+
32
+ def load_mesh(self, mesh_file_name, model_name="/home/user/app/model.h5"):
33
+ path = mesh_file_name.name
34
+ run_model(path, model_name)
35
+ nifti_to_glb("prediction-livermask.nii")
36
+ self.images = load_ct_to_numpy("./files/test_ct.nii")
37
+ self.pred_images = load_pred_volume_to_numpy("./prediction-livermask.nii")
38
+ self.slider = self.slider.update(value=2)
39
+ return "./prediction.obj"
40
+
41
+ def get_img_pred_pair(self, k):
42
+ k = int(k) - 1
43
+ out = [gr.AnnotatedImage.update(visible=False)] * self.nb_slider_items
44
+ out[k] = gr.AnnotatedImage.update(self.combine_ct_and_seg(self.images[k], self.pred_images[k]), visible=True)
45
+ return out
46
+
47
+ def run(self):
48
+ with gr.Blocks() as demo:
49
+
50
+ with gr.Row().style(equal_height=True):
51
+ file_output = gr.File(file_types=[".nii", ".nii.nz"], file_count="single").style(full_width=False, size="sm")
52
+ file_output.upload(self.upload_file, file_output, file_output)
53
+
54
+ run_btn = gr.Button("Run analysis").style(full_width=False, size="sm")
55
+ run_btn.click(fn=lambda x: self.load_mesh(x, model_name=self.model_name), inputs=file_output, outputs=self.volume_renderer)
56
+
57
+ with gr.Row().style(equal_height=True):
58
+ with gr.Box():
59
+ image_boxes = []
60
+ for i in range(self.nb_slider_items):
61
+ visibility = True if i == 1 else False
62
+ t = gr.AnnotatedImage(visible=visibility)\
63
+ .style(color_map={self.class_name: "#ffae00"}, height=512, width=512)
64
+ image_boxes.append(t)
65
+
66
+ self.slider.change(self.get_img_pred_pair, self.slider, image_boxes)
67
+
68
+ with gr.Box():
69
+ self.volume_renderer.render()
70
+
71
+ with gr.Row():
72
+ self.slider.render()
73
+
74
+ # sharing app publicly -> share=True: https://gradio.app/sharing-your-app/
75
+ # inference times > 60 seconds -> need queue(): https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
76
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
demo/src/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nibabel as nib
2
+ import numpy as np
3
+
4
+
5
+ def load_ct_to_numpy(data_path):
6
+ if type(data_path) != str:
7
+ data_path = data_path.name
8
+
9
+ image = nib.load(data_path)
10
+ data = image.get_fdata()
11
+
12
+ data = np.rot90(data, k=1, axes=(0, 1))
13
+
14
+ data[data < -150] = -150
15
+ data[data > 250] = 250
16
+
17
+ data = data - np.amin(data)
18
+ data = data / np.amax(data) * 255
19
+ data = data.astype("uint8")
20
+
21
+ print(data.shape)
22
+ return [data[..., i] for i in range(data.shape[-1])]
23
+
24
+
25
+ def load_pred_volume_to_numpy(data_path):
26
+ if type(data_path) != str:
27
+ data_path = data_path.name
28
+
29
+ image = nib.load(data_path)
30
+ data = image.get_fdata()
31
+
32
+ data = np.rot90(data, k=1, axes=(0, 1))
33
+
34
+ data[data > 0] = 1
35
+ data = data.astype("uint8")
36
+
37
+ print(data.shape)
38
+ return [data[..., i] for i in range(data.shape[-1])]
livermask/utils/process.py CHANGED
@@ -13,14 +13,11 @@ import argparse
13
  import pkg_resources
14
  import tensorflow as tf
15
  import logging as log
16
- import chainer
17
  import math
18
- from .unet3d import UNet3D
19
  from .yaml_utils import Config
20
  import yaml
21
  from tensorflow.keras import backend as K
22
  from numba import cuda
23
- from .utils import load_vessel_model
24
  import multiprocessing as mp
25
 
26
 
@@ -139,6 +136,11 @@ def liver_segmenter(params):
139
 
140
 
141
  def vessel_segmenter(curr, output, cpu, verbose, multiple_flag, liver_mask, name_vessel, extension):
 
 
 
 
 
142
  # check if cupy is available, if not, set cpu=True
143
  try:
144
  import cupy
@@ -157,7 +159,6 @@ def vessel_segmenter(curr, output, cpu, verbose, multiple_flag, liver_mask, name
157
  nib_volume = nib.load(curr)
158
  new_spacing = [1., 1., 1.]
159
  resampled_volume = resample_to_output(nib_volume, new_spacing, order=1)
160
- # resampled_volume = nib_volume
161
  org = resampled_volume.get_data().astype('float32')
162
 
163
  # HU clipping
 
13
  import pkg_resources
14
  import tensorflow as tf
15
  import logging as log
 
16
  import math
 
17
  from .yaml_utils import Config
18
  import yaml
19
  from tensorflow.keras import backend as K
20
  from numba import cuda
 
21
  import multiprocessing as mp
22
 
23
 
 
136
 
137
 
138
  def vessel_segmenter(curr, output, cpu, verbose, multiple_flag, liver_mask, name_vessel, extension):
139
+ # only import chainer stuff inside here, to avoid unnecessary imports
140
+ import chainer
141
+ from .unet3d import UNet3D
142
+ from .utils import load_vessel_model
143
+
144
  # check if cupy is available, if not, set cpu=True
145
  try:
146
  import cupy
 
159
  nib_volume = nib.load(curr)
160
  new_spacing = [1., 1., 1.]
161
  resampled_volume = resample_to_output(nib_volume, new_spacing, order=1)
 
162
  org = resampled_volume.get_data().astype('float32')
163
 
164
  # HU clipping
livermask/utils/utils.py CHANGED
@@ -1,6 +1,5 @@
1
  import gdown
2
  import logging as log
3
- import chainer
4
  from .unet3d import UNet3D
5
  from .fetch import download
6
  import os
@@ -29,6 +28,7 @@ def get_vessel_model(output):
29
 
30
 
31
  def load_vessel_model(path, cpu):
 
32
  unet = UNet3D(num_of_label=2)
33
  chainer.serializers.load_npz(path, unet)
34
  if not cpu:
 
1
  import gdown
2
  import logging as log
 
3
  from .unet3d import UNet3D
4
  from .fetch import download
5
  import os
 
28
 
29
 
30
  def load_vessel_model(path, cpu):
31
+ import chainer
32
  unet = UNet3D(num_of_label=2)
33
  chainer.serializers.load_npz(path, unet)
34
  if not cpu: