katielink committed
Commit b5b43af • 1 Parent(s): a2cbc95

Formatting and examples

Files changed (3):
  1. app.py +78 -34
  2. examples/BRATS_486.nii.gz +3 -0
  3. examples/log.csv +0 -2
app.py CHANGED
@@ -13,22 +13,26 @@ from monai.transforms import (
     ScaleIntensityd,
 )
 
+# Define the bundle name and path for downloading
 BUNDLE_NAME = 'spleen_ct_segmentation_v0.1.0'
 BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
 
-title = "Segment Brain Tumors with MONAI!"
+# Title and description
+title = "# Segment Brain Tumors with MONAI! 🧠"
 description = """
-## Brain Tumor Segmentation 🧠
-A pre-trained model for volumetric (3D) segmentation of brain tumor subregions from multimodal MRIs based on BraTS 2018 data.
-
 ## To run 🚀
 
-Upload a image file in the format: 4 channel MRI (4 aligned MRIs T1c, T1, T2, FLAIR at 1x1x1 mm)
+Upload an image file in the format: 4 channel MRI (4 aligned MRIs T1c, T1, T2, FLAIR at 1x1x1 mm), or try out one of the examples below!
+If you want to see a different slice, update the slider and click the button.
+
+More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)
 
 ## Disclaimer ⚠️
 
 This is an example, not to be used for diagnostic purposes.
+"""
 
+references = """
 ## References 👀
 
 1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
@@ -36,8 +40,12 @@ This is an example, not to be used for diagnostic purposes.
 3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
 """
 
-#examples = 'examples/'
+examples = [
+    ['examples/BRATS_485.nii.gz', 100],
+    ['examples/BRATS_486.nii.gz', 100]
+]
 
+# Load the MONAI pretrained model from Hugging Face Hub
 model, _, _ = bundle.load(
     name = BUNDLE_NAME,
     source = 'huggingface_hub',
@@ -45,10 +53,13 @@ model, _, _ = bundle.load(
     load_ts_module=True,
 )
 
+# Use GPU if available
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
+# Load the parser from the MONAI bundle's inference config
 parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
 
+# Compose the preprocessing transforms
 preproc_transforms = Compose(
     [
         LoadImaged(keys=["image"]),
@@ -57,7 +68,14 @@ preproc_transforms = Compose(
         NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
     ]
 )
-inferer = parser.get_parsed_content('inferer', lazy=True, eval_expr=True, instantiate=True)
+
+# Get the inferer from the bundle's inference config
+inferer = parser.get_parsed_content(
+    'inferer',
+    lazy=True, eval_expr=True, instantiate=True
+)
+
+# Compose the postprocessing transforms
 post_transforms = Compose(
     [
         Activationsd(keys='pred', sigmoid=True),
@@ -66,10 +84,13 @@ post_transforms = Compose(
     ]
 )
 
+# Define the predict function for the demo
 def predict(input_file, z_axis, model=model, device=device):
+    # Load and process data in MONAI format
    data = {'image': [input_file.name]}
    data = preproc_transforms(data)
 
+    # Run inference and post-process predicted labels
    model.to(device)
    model.eval()
    with torch.no_grad():
@@ -77,34 +98,57 @@ def predict(input_file, z_axis, model=model, device=device):
         data['pred'] = inferer(inputs=inputs[None,...], network=model)
     data = post_transforms(data)
 
-    input_image = data['image'].numpy()
-    pred_image = data['pred'].cpu().detach().numpy()
+    # Convert tensors back to numpy arrays
+    data['image'] = data['image'].numpy()
+    data['pred'] = data['pred'].cpu().detach().numpy()
 
-    input_t1c_image = input_image[0, :, :, z_axis]
-    #input_t1_image = input_image[1, :, :, z_axis]
-    #input_t2_image = input_image[2, :, :, z_axis]
-    #input_flair_image = input_image[3, :, :, z_axis]
+    # Magnetic resonance imaging sequences
+    t1c = data['image'][0, :, :, z_axis]  # T1-weighted, post contrast
+    t1 = data['image'][1, :, :, z_axis]  # T1-weighted, pre contrast
+    t2 = data['image'][2, :, :, z_axis]  # T2-weighted
+    flair = data['image'][3, :, :, z_axis]  # FLAIR
 
-    pred_tc_image = pred_image[0, 0, :, :, z_axis]
-    #pred_et_image = pred_image[0, 1, :, :, z_axis]
-    #pred_wt_image = pred_image[0, 2, :, :, z_axis]
+    # BraTS labels
+    tc = data['pred'][0, 0, :, :, z_axis]  # Tumor core
+    wt = data['pred'][0, 1, :, :, z_axis]  # Whole tumor
+    et = data['pred'][0, 2, :, :, z_axis]  # Enhancing tumor
 
-    return input_t1c_image, pred_tc_image,
-
-
-iface = gr.Interface(
-    fn=predict,
-    inputs=[
-        gr.File(label='Input file'),
-        gr.Slider(0, 200, label='z-axis', value=100)
-    ],
-    outputs=[
-        gr.Image(label='T1C image'),
-        gr.Image(label='Segmentation'),
-    ],
-    title=title,
-    description=description,
-    #examples=examples,
-)
+    return [t1c, t1, t2, flair], [tc, wt, et]
+
+# Use blocks to set up a more complex demo
+with gr.Blocks() as demo:
+
+    # Show title and description
+    gr.Markdown(title)
+    gr.Markdown(description)
+
+    # Get the input file and slice slider as inputs
+    input_file = gr.File(label='input file')
+    z_axis = gr.Slider(0, 200, label='z-axis', value=50)
+
+    # Show the button with custom label
+    button = gr.Button("Segment Tumor!")
+
+    # Show examples for the user to try
+    gr.Markdown("Try some examples from MONAI's Decathlon Dataset:")
+    examples = gr.Examples(
+        examples=examples,
+        inputs=[input_file, z_axis]
+    )
+
+    # Show the input image with different MR sequences
+    input_image = gr.Gallery(label='input MRI sequences (T1c, T1, T2, FLAIR)')
+    output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')
+
+    # Run prediction on button click
+    button.click(
+        predict,
+        inputs=[input_file, z_axis],
+        outputs=[input_image, output_segmentation]
+    )
+
+    # Show references at the bottom of the demo
+    gr.Markdown(references)
 
-iface.launch()
+# Launch the demo
+demo.launch()
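For context on the reworked `predict`: it now returns two lists of 2D slices (four MR sequences, three BraTS labels) rather than a single image pair, which makes it easy to smoke-test outside the Gradio UI. A minimal sketch, not part of the commit, assuming a local checkout with the example volumes fetched from LFS; the `SimpleNamespace` merely mimics the `.name`-carrying file object that `gr.File` passes in:

```python
from types import SimpleNamespace

# Hypothetical smoke test: predict() only reads input_file.name,
# so any object with a .name attribute works in place of gr.File's value.
dummy_file = SimpleNamespace(name='examples/BRATS_485.nii.gz')

sequences, labels = predict(dummy_file, z_axis=100)
assert len(sequences) == 4  # T1c, T1, T2, FLAIR slices
assert len(labels) == 3     # TC, WT, ET slices
```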
examples/BRATS_486.nii.gz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8957d67a50b39afd8210f3ca51a20c77ef1c92642800f91b50f16b27778f2b2
+size 11111216
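The example volume is added as a Git LFS pointer, so only the `oid` and `size` above live in the repository. A hedged sketch, not part of the commit, for checking that a locally fetched copy matches the pointer; the local path is an assumption:

```python
import hashlib

# Compare a local copy of the example file against the LFS pointer's oid
# (the path is assumed; adjust to wherever the file was checked out).
sha256 = hashlib.sha256()
with open('examples/BRATS_486.nii.gz', 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):
        sha256.update(chunk)

assert sha256.hexdigest() == 'e8957d67a50b39afd8210f3ca51a20c77ef1c92642800f91b50f16b27778f2b2'
```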
examples/log.csv DELETED
@@ -1,2 +0,0 @@
-input_file
-BRATS_485.nii.gz