A-PolarBear committed
Commit 6f0ae7f
Parent: 5779bf4

Update app.py

Files changed (1): app.py +154 -0
app.py CHANGED
@@ -0,0 +1,154 @@
+ import os
+ import gradio as gr
+ import torch
+ from monai import bundle
+ from monai.transforms import (
+     Compose,
+     LoadImaged,
+     EnsureChannelFirstd,
+     Orientationd,
+     NormalizeIntensityd,
+     Activationsd,
+     AsDiscreted,
+     ScaleIntensityd,
+ )
+
+ # Define the bundle name and path for downloading
+ BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
+ BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
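+ # torch.hub.get_dir() defaults to ~/.cache/torch/hub (configurable via the
+ # TORCH_HOME environment variable), so the bundle is cached between runs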
+
+ # Title and description
+ title = '<h1 style="text-align: center;">Segment Brain Tumors with MONAI! 🧠</h1>'
+ description = """
+ ## 🚀 To run
+ Upload a brain MRI image file, or try out one of the examples below!
+ If you want to see a different slice, update the slider.
+ More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)
+ ## ⚠️ Disclaimer
+ This is an example, not to be used for diagnostic purposes.
+ """
+
+ references = """
+ ## 👀 References
+ 1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654
+ 2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)." IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015). DOI: 10.1109/TMI.2014.2377694
+ 3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features." Nature Scientific Data 4:170117 (2017). DOI: 10.1038/sdata.2017.117
+ """
+
+ examples = [
+     ['examples/BRATS_485.nii.gz', 65],
+     ['examples/BRATS_486.nii.gz', 80],
+ ]
+
+ # Load the MONAI pretrained model from the Hugging Face Hub
+ model, _, _ = bundle.load(
+     name=BUNDLE_NAME,
+     source='huggingface_hub',
+     repo='katielink/brats_mri_segmentation_v0.1.0',
+     load_ts_module=True,
+ )
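+ # With load_ts_module=True, bundle.load returns the TorchScript model together
+ # with its metadata and extra files; only the network itself is needed here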
+
+ # Use GPU if available
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ # Load the parser from the MONAI bundle's inference config
+ parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
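+ # inference.json is the bundle's standard config file; entries such as
+ # 'inferer' can be instantiated on demand through the returned parser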
+
+ # Compose the preprocessing transforms
+ preproc_transforms = Compose(
+     [
+         LoadImaged(keys=["image"]),
+         EnsureChannelFirstd(keys="image"),
+         Orientationd(keys=["image"], axcodes="RAS"),
+         NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
+     ]
+ )
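+ # LoadImaged reads the NIfTI volume, EnsureChannelFirstd puts the four MR
+ # sequences into the channel dimension, Orientationd reorients the volume to
+ # RAS, and NormalizeIntensityd z-scores each channel over its nonzero voxels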
+
+ # Get the inferer from the bundle's inference config
+ inferer = parser.get_parsed_content(
+     'inferer',
+     lazy=True, eval_expr=True, instantiate=True,
+ )
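+ # For this bundle, 'inferer' is expected to resolve to a sliding-window
+ # inferer defined in inference.json, so whole volumes fit in memory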
+
+ # Compose the postprocessing transforms
+ post_transforms = Compose(
+     [
+         Activationsd(keys='pred', sigmoid=True),
+         AsDiscreted(keys='pred', threshold=0.5),
+         ScaleIntensityd(keys='image', minv=0., maxv=1.),
+     ]
+ )
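+ # The BraTS sub-regions overlap, so each output channel gets an independent
+ # sigmoid and a 0.5 threshold (multi-label rather than softmax/argmax); the
+ # image is rescaled to [0, 1] so the slices display well in the galleries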
+
+
+ # Define the predict function for the demo
+ def predict(input_file, z_axis, model=model, device=device):
+     # Load and process data in MONAI format
+     data = {'image': [input_file.name]}
+     data = preproc_transforms(data)
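+     # gr.File passes a temp-file wrapper, so .name above yields the file path
+     # that LoadImaged reads from disk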
+
+     # Run inference and post-process predicted labels
+     model.to(device)
+     model.eval()
+     with torch.no_grad():
+         inputs = data['image'].to(device)
+         data['pred'] = inferer(inputs=inputs[None, ...], network=model)
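+         # [None, ...] above adds a batch axis, (C, H, W, D) -> (1, C, H, W, D),
+         # since the network and inferer expect batched input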
+     data = post_transforms(data)
+
+     # Convert tensors back to numpy arrays
+     data['image'] = data['image'].numpy()
+     data['pred'] = data['pred'].cpu().detach().numpy()
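+     # data['image'] now has shape (4, H, W, D) and data['pred'] has shape
+     # (1, 3, H, W, D); the slicing below depends on this layout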
+
+     # Clamp the requested slice so an out-of-range slider value cannot raise an IndexError
+     z_axis = min(int(z_axis), data['image'].shape[-1] - 1)
+
+     # Magnetic resonance imaging sequences
+     t1c = data['image'][0, :, :, z_axis]    # T1-weighted, post contrast
+     t1 = data['image'][1, :, :, z_axis]     # T1-weighted, pre contrast
+     t2 = data['image'][2, :, :, z_axis]     # T2-weighted
+     flair = data['image'][3, :, :, z_axis]  # FLAIR
+
+     # BraTS labels
+     tc = data['pred'][0, 0, :, :, z_axis]   # Tumor core
+     wt = data['pred'][0, 1, :, :, z_axis]   # Whole tumor
+     et = data['pred'][0, 2, :, :, z_axis]   # Enhancing tumor
+
+     return [t1c, t1, t2, flair], [tc, wt, et]
+
+
+ # Use blocks to set up a more complex demo
+ with gr.Blocks() as demo:
+     # Render the title and description defined above
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Row():
+         # Get the input file and slice slider as inputs
+         input_file = gr.File(label='input file')
+         z_axis = gr.Slider(0, 200, label='slice', value=50)
+
+     with gr.Row():
+         # Show the button with a custom label
+         button = gr.Button("Segment Tumor!")
+
+     with gr.Row():
+         with gr.Column():
+             # Show the input image with the different MR sequences
+             input_image = gr.Gallery(label='input MRI sequences (T1+, T1, T2, FLAIR)')
+
+         with gr.Column():
+             # Show the segmentation labels
+             output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')
+
+
+     # Run prediction on button click
+     button.click(
+         predict,
+         inputs=[input_file, z_axis],
+         outputs=[input_image, output_segmentation],
+     )
+
+     # Provide some examples for the user to try out
+     examples = gr.Examples(
+         examples=examples,
+         inputs=[input_file, z_axis],
+         outputs=[input_image, output_segmentation],
+         fn=predict,
+         cache_examples=False,
+     )
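+     # With cache_examples=False, selecting an example only fills in the inputs;
+     # pressing the button then runs the segmentation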
+
+     # Render the references at the bottom of the page
+     gr.Markdown(references)
+
+ # Launch the demo
+ demo.launch()