TDHarshithReddy commited on
Commit
3e09c97
·
0 Parent(s):

Initial commit: MONAI WholeBody CT Segmentation Space

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.nii.gz filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MONAI WholeBody CT Segmentation
3
+ emoji: 🏥
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ tags:
12
+ - medical-imaging
13
+ - segmentation
14
+ - ct-scan
15
+ - monai
16
+ - 3d-segmentation
17
+ ---
18
+
19
+ # 🏥 MONAI WholeBody CT Segmentation
20
+
21
+ **Automatic 3D segmentation of 104 anatomical structures from CT scans**
22
+
23
+ ## Overview
24
+
25
+ This application uses MONAI's pre-trained **SegResNet** model trained on the [TotalSegmentator](https://github.com/wasserth/TotalSegmentator) dataset to automatically segment 104 different anatomical structures from whole-body CT scans.
26
+
27
+ ## Features
28
+
29
+ - 🔬 **104 Anatomical Structures**: Segments organs, bones, muscles, and vessels
30
+ - 📊 **Interactive Visualization**: Navigate through axial, coronal, and sagittal views
31
+ - 🎨 **Color-coded Overlay**: Each structure has a distinct color for easy identification
32
+ - ⚡ **GPU Accelerated**: Uses CUDA when available for faster inference
33
+
34
+ ## Supported Structures
35
+
36
+ | Category | Structures |
37
+ |----------|------------|
38
+ | **Major Organs** | Liver, Spleen, Kidneys, Pancreas, Gallbladder, Stomach, Bladder |
39
+ | **Cardiovascular** | Heart (4 chambers), Aorta, Vena Cava, Portal Vein, Iliac vessels |
40
+ | **Respiratory** | Lung lobes (5), Trachea, Esophagus |
41
+ | **Skeletal** | Vertebrae (C1-L5), 24 Ribs, Hip bones, Femur, Humerus, Scapula, Clavicle |
42
+ | **Muscles** | Gluteal muscles, Iliopsoas, Autochthon |
43
+ | **Other** | Brain, Face, Adrenal glands, Small/Large bowel |
44
+
45
+ ## Usage
46
+
47
+ 1. **Upload** a CT scan in NIfTI format (`.nii` or `.nii.gz`)
48
+ 2. Click **Run Segmentation** and wait for processing (1-5 minutes)
49
+ 3. **Explore** the results using the slice sliders and view controls
50
+ 4. Check the **Detected Structures** panel to see all identified anatomy
51
+
52
+ ## Model Details
53
+
54
+ - **Architecture**: SegResNet (MONAI)
55
+ - **Resolution**: 3.0mm isotropic (low-resolution model)
56
+ - **Training Data**: TotalSegmentator dataset
57
+ - **Output**: 105 channels (background + 104 structures)
58
+
59
+ ## References
60
+
61
+ - [MONAI Model Zoo](https://monai.io/model-zoo.html)
62
+ - [TotalSegmentator Paper](https://pubs.rsna.org/doi/10.1148/ryai.230024)
63
+ - [TotalSegmentator GitHub](https://github.com/wasserth/TotalSegmentator)
64
+
65
+ ## License
66
+
67
+ This model is released under the Apache 2.0 License. The TotalSegmentator dataset is released under CC BY 4.0.
68
+
69
+ ## Citation
70
+
71
+ If you use this model, please cite:
72
+
73
+ ```bibtex
74
+ @article{wasserthal2023totalsegmentator,
75
+ title={TotalSegmentator: robust segmentation of 104 anatomical structures in CT images},
76
+ author={Wasserthal, Jakob and others},
77
+ journal={Radiology: Artificial Intelligence},
78
+ year={2023}
79
+ }
80
+ ```
app.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MONAI WholeBody CT Segmentation - Hugging Face Space
3
+ Segments 104 anatomical structures from CT scans using MONAI's SegResNet model
4
+ """
5
+
6
+ import os
7
+ import tempfile
8
+ import numpy as np
9
+ import gradio as gr
10
+ import torch
11
+ import nibabel as nib
12
+ import matplotlib.pyplot as plt
13
+ from matplotlib.colors import ListedColormap
14
+ from huggingface_hub import hf_hub_download
15
+ from monai.networks.nets import SegResNet
16
+ from monai.transforms import (
17
+ Compose,
18
+ LoadImage,
19
+ EnsureChannelFirst,
20
+ Orientation,
21
+ Spacing,
22
+ ScaleIntensityRange,
23
+ CropForeground,
24
+ Activations,
25
+ AsDiscrete,
26
+ )
27
+ from monai.inferers import sliding_window_inference
28
+ from labels import LABEL_NAMES, get_color_map, get_label_name, get_organ_categories
29
+
30
+ # Constants
31
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ MODEL_REPO = "MONAI/wholeBody_ct_segmentation"
33
+ SPATIAL_SIZE = (96, 96, 96)
34
+ PIXDIM = (3.0, 3.0, 3.0) # Low-res model spacing
35
+
36
+ # Global model variable
37
+ model = None
38
+
39
+
40
+ def load_model():
41
+ """Download and load the MONAI SegResNet model"""
42
+ global model
43
+ if model is not None:
44
+ return model
45
+
46
+ print("Downloading model weights...")
47
+ try:
48
+ model_path = hf_hub_download(
49
+ repo_id=MODEL_REPO,
50
+ filename="models/model_lowres.pt",
51
+ )
52
+ except Exception as e:
53
+ print(f"Failed to download from HF, trying alternative: {e}")
54
+ # Fallback: try to download from MONAI model zoo
55
+ model_path = hf_hub_download(
56
+ repo_id=MODEL_REPO,
57
+ filename="models/model.pt",
58
+ )
59
+
60
+ print(f"Loading model from {model_path}...")
61
+
62
+ # Initialize SegResNet with 105 output channels (background + 104 classes)
63
+ model = SegResNet(
64
+ blocks_down=[1, 2, 2, 4],
65
+ blocks_up=[1, 1, 1],
66
+ init_filters=16,
67
+ in_channels=1,
68
+ out_channels=105,
69
+ dropout_prob=0.2,
70
+ )
71
+
72
+ # Load weights
73
+ checkpoint = torch.load(model_path, map_location=DEVICE)
74
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
75
+ model.load_state_dict(checkpoint["state_dict"])
76
+ else:
77
+ model.load_state_dict(checkpoint)
78
+
79
+ model.to(DEVICE)
80
+ model.eval()
81
+ print(f"Model loaded successfully on {DEVICE}")
82
+
83
+ return model
84
+
85
+
86
+ def get_preprocessing_transforms():
87
+ """Get MONAI preprocessing transforms"""
88
+ return Compose([
89
+ LoadImage(image_only=True),
90
+ EnsureChannelFirst(),
91
+ Orientation(axcodes="RAS"),
92
+ Spacing(pixdim=PIXDIM, mode="bilinear"),
93
+ ScaleIntensityRange(
94
+ a_min=-1024, a_max=1024,
95
+ b_min=0.0, b_max=1.0,
96
+ clip=True
97
+ ),
98
+ ])
99
+
100
+
101
+ def get_postprocessing_transforms():
102
+ """Get MONAI postprocessing transforms"""
103
+ return Compose([
104
+ Activations(softmax=True),
105
+ AsDiscrete(argmax=True),
106
+ ])
107
+
108
+
109
+ def run_inference(image_path: str, progress=gr.Progress()):
110
+ """Run segmentation inference on a CT image"""
111
+ progress(0.1, desc="Loading model...")
112
+ model = load_model()
113
+
114
+ progress(0.2, desc="Preprocessing image...")
115
+ preprocess = get_preprocessing_transforms()
116
+ postprocess = get_postprocessing_transforms()
117
+
118
+ # Load and preprocess
119
+ image = preprocess(image_path)
120
+ image = image.unsqueeze(0).to(DEVICE) # Add batch dimension
121
+
122
+ progress(0.4, desc="Running segmentation (this may take a few minutes)...")
123
+
124
+ with torch.no_grad():
125
+ # Use sliding window inference for large volumes
126
+ outputs = sliding_window_inference(
127
+ image,
128
+ roi_size=SPATIAL_SIZE,
129
+ sw_batch_size=4,
130
+ predictor=model,
131
+ overlap=0.5,
132
+ )
133
+
134
+ progress(0.8, desc="Post-processing...")
135
+
136
+ # Post-process
137
+ outputs = postprocess(outputs)
138
+ segmentation = outputs.squeeze().cpu().numpy().astype(np.uint8)
139
+
140
+ # Load original image for visualization
141
+ original_nib = nib.load(image_path)
142
+ original_data = original_nib.get_fdata()
143
+
144
+ progress(1.0, desc="Complete!")
145
+
146
+ return original_data, segmentation
147
+
148
+
149
+ def create_slice_visualization(ct_data, seg_data, axis, slice_idx, alpha=0.5, show_overlay=True):
150
+ """Create a visualization of a CT slice with segmentation overlay"""
151
+
152
+ # Get the slice based on axis
153
+ if axis == "Axial":
154
+ ct_slice = ct_data[:, :, slice_idx]
155
+ seg_slice = seg_data[:, :, slice_idx] if seg_data is not None else None
156
+ elif axis == "Coronal":
157
+ ct_slice = ct_data[:, slice_idx, :]
158
+ seg_slice = seg_data[:, slice_idx, :] if seg_data is not None else None
159
+ else: # Sagittal
160
+ ct_slice = ct_data[slice_idx, :, :]
161
+ seg_slice = seg_data[slice_idx, :, :] if seg_data is not None else None
162
+
163
+ # Create figure
164
+ fig, ax = plt.subplots(1, 1, figsize=(8, 8))
165
+
166
+ # Normalize CT for display
167
+ ct_normalized = np.clip(ct_slice, -1024, 1024)
168
+ ct_normalized = (ct_normalized - ct_normalized.min()) / (ct_normalized.max() - ct_normalized.min() + 1e-8)
169
+
170
+ # Display CT
171
+ ax.imshow(ct_normalized.T, cmap='gray', origin='lower')
172
+
173
+ # Overlay segmentation
174
+ if show_overlay and seg_slice is not None and np.any(seg_slice > 0):
175
+ colors = get_color_map() / 255.0
176
+ colors[0] = [0, 0, 0, 0] # Make background transparent
177
+
178
+ # Create RGBA overlay
179
+ seg_rgba = colors[seg_slice.astype(int)]
180
+ seg_rgba = np.concatenate([seg_rgba, np.ones((*seg_slice.shape, 1)) * alpha], axis=-1)
181
+ seg_rgba[seg_slice == 0, 3] = 0 # Transparent background
182
+
183
+ ax.imshow(seg_rgba.transpose(1, 0, 2), origin='lower')
184
+
185
+ ax.axis('off')
186
+ ax.set_title(f"{axis} View - Slice {slice_idx}")
187
+
188
+ plt.tight_layout()
189
+ return fig
190
+
191
+
192
+ def get_detected_structures(seg_data):
193
+ """Get list of detected anatomical structures"""
194
+ unique_labels = np.unique(seg_data)
195
+ unique_labels = unique_labels[unique_labels > 0] # Exclude background
196
+
197
+ structures = []
198
+ for label in unique_labels:
199
+ name = get_label_name(label)
200
+ count = np.sum(seg_data == label)
201
+ structures.append(f"• {name} (Label {label})")
202
+
203
+ return "\n".join(structures) if structures else "No structures detected"
204
+
205
+
206
+ # Global state for current visualization
207
+ current_ct_data = None
208
+ current_seg_data = None
209
+
210
+
211
+ def process_upload(file_path, progress=gr.Progress()):
212
+ """Process uploaded CT file and run segmentation"""
213
+ global current_ct_data, current_seg_data
214
+
215
+ if file_path is None:
216
+ return None, "Please upload a NIfTI file", gr.update(maximum=1), gr.update(maximum=1), gr.update(maximum=1)
217
+
218
+ try:
219
+ ct_data, seg_data = run_inference(file_path, progress)
220
+ current_ct_data = ct_data
221
+ current_seg_data = seg_data
222
+
223
+ # Get initial visualization
224
+ mid_axial = ct_data.shape[2] // 2
225
+ mid_coronal = ct_data.shape[1] // 2
226
+ mid_sagittal = ct_data.shape[0] // 2
227
+
228
+ fig = create_slice_visualization(ct_data, seg_data, "Axial", mid_axial)
229
+ structures = get_detected_structures(seg_data)
230
+
231
+ return (
232
+ fig,
233
+ structures,
234
+ gr.update(maximum=ct_data.shape[2] - 1, value=mid_axial),
235
+ gr.update(maximum=ct_data.shape[1] - 1, value=mid_coronal),
236
+ gr.update(maximum=ct_data.shape[0] - 1, value=mid_sagittal),
237
+ )
238
+ except Exception as e:
239
+ return None, f"Error processing file: {str(e)}", gr.update(), gr.update(), gr.update()
240
+
241
+
242
+ def update_visualization(axis, slice_idx, alpha, show_overlay):
243
+ """Update the visualization based on slider changes"""
244
+ global current_ct_data, current_seg_data
245
+
246
+ if current_ct_data is None:
247
+ return None
248
+
249
+ fig = create_slice_visualization(
250
+ current_ct_data,
251
+ current_seg_data,
252
+ axis,
253
+ int(slice_idx),
254
+ alpha,
255
+ show_overlay
256
+ )
257
+ return fig
258
+
259
+
260
+ def load_example(example_name):
261
+ """Load a bundled example CT scan"""
262
+ example_dir = os.path.join(os.path.dirname(__file__), "examples")
263
+ example_path = os.path.join(example_dir, example_name)
264
+
265
+ if os.path.exists(example_path):
266
+ return example_path
267
+ return None
268
+
269
+
270
+ # Create Gradio interface
271
+ with gr.Blocks(
272
+ title="MONAI WholeBody CT Segmentation",
273
+ theme=gr.themes.Soft(),
274
+ css="""
275
+ .gradio-container {max-width: 1200px !important}
276
+ .output-image {min-height: 500px}
277
+ """
278
+ ) as demo:
279
+ gr.Markdown("""
280
+ # 🏥 MONAI WholeBody CT Segmentation
281
+
282
+ **Automatic segmentation of 104 anatomical structures from CT scans**
283
+
284
+ This application uses MONAI's pre-trained SegResNet model trained on the TotalSegmentator dataset.
285
+ Upload a CT scan in NIfTI format (.nii or .nii.gz) to get started.
286
+
287
+ > ⚡ **Note**: Processing may take 1-5 minutes depending on the CT volume size.
288
+ """)
289
+
290
+ with gr.Row():
291
+ with gr.Column(scale=1):
292
+ # Input section
293
+ gr.Markdown("### 📤 Upload CT Scan")
294
+ file_input = gr.File(
295
+ label="Upload NIfTI file (.nii, .nii.gz)",
296
+ file_types=[".nii", ".nii.gz", ".gz"],
297
+ type="filepath"
298
+ )
299
+
300
+ # Example files
301
+ gr.Markdown("### 📁 Example Files")
302
+ example_gallery = gr.Examples(
303
+ examples=[
304
+ ["examples/sample_ct_chest.nii.gz"],
305
+ ["examples/sample_ct_abdomen.nii.gz"],
306
+ ],
307
+ inputs=[file_input],
308
+ label="Click to load example"
309
+ )
310
+
311
+ process_btn = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")
312
+
313
+ # Visualization controls
314
+ gr.Markdown("### 🎛️ Visualization Controls")
315
+
316
+ view_axis = gr.Radio(
317
+ choices=["Axial", "Coronal", "Sagittal"],
318
+ value="Axial",
319
+ label="View Axis"
320
+ )
321
+
322
+ with gr.Row():
323
+ axial_slider = gr.Slider(0, 100, value=50, step=1, label="Axial Slice")
324
+ coronal_slider = gr.Slider(0, 100, value=50, step=1, label="Coronal Slice")
325
+ sagittal_slider = gr.Slider(0, 100, value=50, step=1, label="Sagittal Slice")
326
+
327
+ alpha_slider = gr.Slider(0, 1, value=0.5, step=0.1, label="Overlay Opacity")
328
+ show_overlay = gr.Checkbox(value=True, label="Show Segmentation Overlay")
329
+
330
+ with gr.Column(scale=2):
331
+ # Output section
332
+ gr.Markdown("### 🖼️ Segmentation Result")
333
+ output_image = gr.Plot(label="CT with Segmentation Overlay")
334
+
335
+ gr.Markdown("### 📋 Detected Structures")
336
+ structures_output = gr.Textbox(
337
+ label="Anatomical Structures Found",
338
+ lines=10,
339
+ max_lines=20
340
+ )
341
+
342
+ # Model info section
343
+ with gr.Accordion("ℹ️ Model Information", open=False):
344
+ gr.Markdown("""
345
+ ### About the Model
346
+
347
+ This model is based on **SegResNet** architecture from MONAI, trained on the **TotalSegmentator** dataset.
348
+
349
+ **Capabilities:**
350
+ - Segments 104 distinct anatomical structures
351
+ - Works on whole-body CT scans
352
+ - Uses 3.0mm isotropic spacing (low-resolution model for faster inference)
353
+
354
+ **Segmented Structures include:**
355
+ - **Major Organs**: Liver, Spleen, Kidneys, Pancreas, Gallbladder, Stomach, Bladder
356
+ - **Cardiovascular**: Heart chambers, Aorta, Vena Cava, Portal Vein
357
+ - **Respiratory**: Lung lobes, Trachea
358
+ - **Skeletal**: Vertebrae (C1-L5), Ribs, Hip bones, Femur, Humerus, Scapula
359
+ - **Muscles**: Gluteal muscles, Iliopsoas
360
+ - And many more...
361
+
362
+ **References:**
363
+ - [MONAI Model Zoo](https://monai.io/model-zoo.html)
364
+ - [TotalSegmentator Paper](https://pubs.rsna.org/doi/10.1148/ryai.230024)
365
+ """)
366
+
367
+ # Event handlers
368
+ process_btn.click(
369
+ fn=process_upload,
370
+ inputs=[file_input],
371
+ outputs=[output_image, structures_output, axial_slider, coronal_slider, sagittal_slider]
372
+ )
373
+
374
+ # Update visualization when controls change
375
+ for control in [view_axis, alpha_slider, show_overlay]:
376
+ control.change(
377
+ fn=lambda axis, alpha, overlay, ax_s, cor_s, sag_s: update_visualization(
378
+ axis,
379
+ ax_s if axis == "Axial" else (cor_s if axis == "Coronal" else sag_s),
380
+ alpha,
381
+ overlay
382
+ ),
383
+ inputs=[view_axis, alpha_slider, show_overlay, axial_slider, coronal_slider, sagittal_slider],
384
+ outputs=[output_image]
385
+ )
386
+
387
+ # Update when sliders change
388
+ axial_slider.change(
389
+ fn=lambda s, alpha, overlay: update_visualization("Axial", s, alpha, overlay),
390
+ inputs=[axial_slider, alpha_slider, show_overlay],
391
+ outputs=[output_image]
392
+ )
393
+
394
+ coronal_slider.change(
395
+ fn=lambda s, alpha, overlay: update_visualization("Coronal", s, alpha, overlay),
396
+ inputs=[coronal_slider, alpha_slider, show_overlay],
397
+ outputs=[output_image]
398
+ )
399
+
400
+ sagittal_slider.change(
401
+ fn=lambda s, alpha, overlay: update_visualization("Sagittal", s, alpha, overlay),
402
+ inputs=[sagittal_slider, alpha_slider, show_overlay],
403
+ outputs=[output_image]
404
+ )
405
+
406
+
407
+ if __name__ == "__main__":
408
+ # Ensure examples directory exists
409
+ os.makedirs("examples", exist_ok=True)
410
+
411
+ # Launch the app
412
+ demo.launch()
create_samples.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Create synthetic CT phantom samples for HuggingFace Space demo
3
+ """
4
+ import numpy as np
5
+ import nibabel as nib
6
+ import os
7
+
8
+ def create_synthetic_samples():
9
+ examples_dir = 'examples'
10
+ os.makedirs(examples_dir, exist_ok=True)
11
+
12
+ np.random.seed(42)
13
+
14
+ # Sample 1: Chest phantom
15
+ print('Creating sample_ct_chest.nii.gz...')
16
+ shape = (192, 192, 128)
17
+ data = np.ones(shape, dtype=np.float32) * -1000 # Air background
18
+
19
+ cx, cy, cz = shape[0]//2, shape[1]//2, shape[2]//2
20
+
21
+ # Create coordinate grids
22
+ X, Y, Z = np.meshgrid(
23
+ np.arange(shape[0]) - cx,
24
+ np.arange(shape[1]) - cy,
25
+ np.arange(shape[2]) - cz,
26
+ indexing='ij'
27
+ )
28
+
29
+ # Body contour (ellipsoid)
30
+ body = (X**2 / 70**2 + Y**2 / 50**2 + Z**2 / 55**2) < 1
31
+ data[body] = 40 + np.random.randn(*shape)[body] * 30
32
+
33
+ # Spine
34
+ spine = (X**2 + (Y + 35)**2) < 100
35
+ data[spine] = 400 + np.random.randn(*shape)[spine] * 50
36
+
37
+ # Liver
38
+ liver = ((X - 30)**2 / 25**2 + (Y + 10)**2 / 30**2 + (Z + 20)**2 / 35**2) < 1
39
+ data[liver] = 60 + np.random.randn(*shape)[liver] * 15
40
+
41
+ # Lungs
42
+ lung_l = ((X + 25)**2 / 20**2 + (Y - 20)**2 / 25**2 + Z**2 / 40**2) < 1
43
+ lung_r = ((X - 25)**2 / 20**2 + (Y - 20)**2 / 25**2 + Z**2 / 40**2) < 1
44
+ data[lung_l | lung_r] = -700 + np.random.randn(*shape)[lung_l | lung_r] * 100
45
+
46
+ # Heart
47
+ heart = (X**2 / 20**2 + (Y - 10)**2 / 18**2 + (Z + 10)**2 / 25**2) < 1
48
+ data[heart] = 50 + np.random.randn(*shape)[heart] * 20
49
+
50
+ # Kidneys
51
+ kidney_l = ((X + 35)**2 / 8**2 + (Y + 15)**2 / 12**2 + (Z + 10)**2 / 18**2) < 1
52
+ kidney_r = ((X - 35)**2 / 8**2 + (Y + 15)**2 / 12**2 + (Z + 10)**2 / 18**2) < 1
53
+ data[kidney_l | kidney_r] = 30 + np.random.randn(*shape)[kidney_l | kidney_r] * 10
54
+
55
+ # Ribs
56
+ for i in range(-5, 6):
57
+ rib_z = i * 10
58
+ rib_mask = ((X - 45)**2 + (Y - 15)**2 < 20) & (np.abs(Z - rib_z) < 3)
59
+ data[rib_mask] = 350 + np.random.randn(*shape)[rib_mask] * 40
60
+ rib_mask = ((X + 45)**2 + (Y - 15)**2 < 20) & (np.abs(Z - rib_z) < 3)
61
+ data[rib_mask] = 350 + np.random.randn(*shape)[rib_mask] * 40
62
+
63
+ affine = np.diag([1.5, 1.5, 1.5, 1.0])
64
+ img = nib.Nifti1Image(data.astype(np.int16), affine)
65
+ nib.save(img, os.path.join(examples_dir, 'sample_ct_chest.nii.gz'))
66
+ print(' ✓ Created sample_ct_chest.nii.gz')
67
+
68
+ # Sample 2: Abdomen phantom
69
+ print('Creating sample_ct_abdomen.nii.gz...')
70
+ shape2 = (160, 160, 96)
71
+ data2 = np.ones(shape2, dtype=np.float32) * -1000
72
+
73
+ cx2, cy2, cz2 = shape2[0]//2, shape2[1]//2, shape2[2]//2
74
+
75
+ X2, Y2, Z2 = np.meshgrid(
76
+ np.arange(shape2[0]) - cx2,
77
+ np.arange(shape2[1]) - cy2,
78
+ np.arange(shape2[2]) - cz2,
79
+ indexing='ij'
80
+ )
81
+
82
+ # Body
83
+ body2 = (X2**2 / 60**2 + Y2**2 / 45**2) < 1
84
+ data2[body2] = 35 + np.random.randn(*shape2)[body2] * 25
85
+
86
+ # Spine
87
+ spine2 = (X2**2 + (Y2 + 30)**2) < 80
88
+ data2[spine2] = 380 + np.random.randn(*shape2)[spine2] * 45
89
+
90
+ # Liver
91
+ liver2 = ((X2 - 25)**2 / 30**2 + (Y2 + 5)**2 / 28**2 + (Z2 + 15)**2 / 30**2) < 1
92
+ data2[liver2] = 55 + np.random.randn(*shape2)[liver2] * 12
93
+
94
+ # Spleen
95
+ spleen = ((X2 + 35)**2 / 15**2 + (Y2 + 10)**2 / 18**2 + (Z2 + 5)**2 / 20**2) < 1
96
+ data2[spleen] = 45 + np.random.randn(*shape2)[spleen] * 10
97
+
98
+ # Kidneys
99
+ kidney_l2 = ((X2 + 28)**2 / 10**2 + (Y2 + 12)**2 / 14**2 + Z2**2 / 18**2) < 1
100
+ kidney_r2 = ((X2 - 28)**2 / 10**2 + (Y2 + 12)**2 / 14**2 + Z2**2 / 18**2) < 1
101
+ data2[kidney_l2 | kidney_r2] = 32 + np.random.randn(*shape2)[kidney_l2 | kidney_r2] * 8
102
+
103
+ # Pancreas
104
+ pancreas = (X2**2 / 25**2 + (Y2 + 8)**2 / 6**2 + (Z2 - 5)**2 / 10**2) < 1
105
+ data2[pancreas] = 40 + np.random.randn(*shape2)[pancreas] * 10
106
+
107
+ # Aorta
108
+ aorta = (X2**2 + (Y2 + 20)**2) < 36
109
+ data2[aorta] = 120 + np.random.randn(*shape2)[aorta] * 25
110
+
111
+ # Stomach
112
+ stomach = ((X2 + 10)**2 / 20**2 + (Y2 - 5)**2 / 15**2 + (Z2 + 25)**2 / 18**2) < 1
113
+ data2[stomach] = -50 + np.random.randn(*shape2)[stomach] * 30
114
+
115
+ affine2 = np.diag([2.0, 2.0, 2.0, 1.0])
116
+ img2 = nib.Nifti1Image(data2.astype(np.int16), affine2)
117
+ nib.save(img2, os.path.join(examples_dir, 'sample_ct_abdomen.nii.gz'))
118
+ print(' ✓ Created sample_ct_abdomen.nii.gz')
119
+
120
+ print('\nDone! Synthetic CT samples created.')
121
+ return ['sample_ct_chest.nii.gz', 'sample_ct_abdomen.nii.gz']
122
+
123
+ if __name__ == '__main__':
124
+ create_synthetic_samples()
download_samples.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download sample CT scans for the Hugging Face Space demo.
3
+ Uses publicly available data from Zenodo (TotalSegmentator sample subset).
4
+ """
5
+
6
+ import os
7
+ import urllib.request
8
+ import zipfile
9
+ import shutil
10
+ import gzip
11
+
12
+ # Sample CT scans from public sources
13
+ SAMPLE_URLS = [
14
+ # TotalSegmentator sample from Zenodo (small sample subset)
15
+ {
16
+ "url": "https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_small_v201.zip?download=1",
17
+ "filename": "ts_sample.zip",
18
+ "type": "zip",
19
+ "description": "TotalSegmentator sample dataset (102 subjects)"
20
+ }
21
+ ]
22
+
23
+ # Alternative: Direct links to individual sample scans from open datasets
24
+ DIRECT_SAMPLES = [
25
+ # These would be actual CT scan URLs if available
26
+ # For now, we'll create a placeholder that downloads on first run
27
+ ]
28
+
29
+
30
+ def download_file(url: str, dest_path: str, description: str = ""):
31
+ """Download a file with progress indication"""
32
+ print(f"Downloading {description}...")
33
+ print(f" URL: {url}")
34
+ print(f" Destination: {dest_path}")
35
+
36
+ try:
37
+ urllib.request.urlretrieve(url, dest_path)
38
+ print(f" ✓ Downloaded successfully")
39
+ return True
40
+ except Exception as e:
41
+ print(f" ✗ Failed: {e}")
42
+ return False
43
+
44
+
45
+ def extract_sample_from_zip(zip_path: str, examples_dir: str, max_samples: int = 2):
46
+ """Extract a few sample CT scans from the dataset zip"""
47
+ print(f"Extracting samples from {zip_path}...")
48
+
49
+ with zipfile.ZipFile(zip_path, 'r') as zf:
50
+ # List all files to find CT scans
51
+ all_files = zf.namelist()
52
+
53
+ # Find .nii.gz files (CT images)
54
+ ct_files = [f for f in all_files if f.endswith('.nii.gz') and 'ct.' in f.lower()]
55
+
56
+ if not ct_files:
57
+ # Try alternative patterns
58
+ ct_files = [f for f in all_files if f.endswith('.nii.gz') and 'image' in f.lower()]
59
+
60
+ if not ct_files:
61
+ # Just get any .nii.gz files
62
+ ct_files = [f for f in all_files if f.endswith('.nii.gz')][:max_samples * 2]
63
+
64
+ # Extract a few samples
65
+ extracted = 0
66
+ for ct_file in ct_files[:max_samples]:
67
+ try:
68
+ # Extract to examples directory
69
+ output_name = f"sample_ct_{extracted + 1}.nii.gz"
70
+ output_path = os.path.join(examples_dir, output_name)
71
+
72
+ # Extract the file
73
+ with zf.open(ct_file) as source:
74
+ with open(output_path, 'wb') as target:
75
+ target.write(source.read())
76
+
77
+ print(f" ✓ Extracted: {output_name}")
78
+ extracted += 1
79
+
80
+ if extracted >= max_samples:
81
+ break
82
+ except Exception as e:
83
+ print(f" ✗ Failed to extract {ct_file}: {e}")
84
+
85
+ return extracted
86
+
87
+
88
+ def create_synthetic_sample(examples_dir: str):
89
+ """Create a small synthetic NIfTI file for testing"""
90
+ import numpy as np
91
+ import nibabel as nib
92
+
93
+ print("Creating synthetic sample CT for testing...")
94
+
95
+ # Create a simple 3D volume (small for demo)
96
+ shape = (128, 128, 64)
97
+ data = np.random.randn(*shape).astype(np.float32) * 100 - 500
98
+
99
+ # Add some structure (spheres to simulate organs)
100
+ center = np.array(shape) // 2
101
+
102
+ for i in range(3):
103
+ offset = np.array([20 * (i - 1), 0, 0])
104
+ pos = center + offset
105
+ x, y, z = np.ogrid[:shape[0], :shape[1], :shape[2]]
106
+ mask = ((x - pos[0])**2 + (y - pos[1])**2 + (z - pos[2])**2) < 400
107
+ data[mask] = 50 + i * 30 # Different intensities
108
+
109
+ # Create NIfTI
110
+ affine = np.diag([3.0, 3.0, 3.0, 1.0]) # 3mm spacing
111
+ img = nib.Nifti1Image(data, affine)
112
+
113
+ output_path = os.path.join(examples_dir, "sample_synthetic.nii.gz")
114
+ nib.save(img, output_path)
115
+ print(f" ✓ Created: {output_path}")
116
+
117
+ return output_path
118
+
119
+
120
+ def setup_examples():
121
+ """Download and set up example CT scans"""
122
+ examples_dir = os.path.join(os.path.dirname(__file__), "examples")
123
+ os.makedirs(examples_dir, exist_ok=True)
124
+
125
+ # Check if examples already exist
126
+ existing = [f for f in os.listdir(examples_dir) if f.endswith('.nii.gz')]
127
+ if len(existing) >= 2:
128
+ print(f"Examples already exist: {existing}")
129
+ return existing
130
+
131
+ print("Setting up example CT scans...")
132
+ print("=" * 50)
133
+
134
+ # Try to download from Zenodo
135
+ temp_dir = os.path.join(examples_dir, "temp")
136
+ os.makedirs(temp_dir, exist_ok=True)
137
+
138
+ downloaded = False
139
+ for sample in SAMPLE_URLS:
140
+ zip_path = os.path.join(temp_dir, sample["filename"])
141
+ if download_file(sample["url"], zip_path, sample["description"]):
142
+ if sample["type"] == "zip":
143
+ extracted = extract_sample_from_zip(zip_path, examples_dir)
144
+ if extracted > 0:
145
+ downloaded = True
146
+ break
147
+
148
+ # Clean up temp files
149
+ if os.path.exists(temp_dir):
150
+ shutil.rmtree(temp_dir)
151
+
152
+ # If download failed, create synthetic sample
153
+ if not downloaded:
154
+ print("\nDownload failed, creating synthetic sample for testing...")
155
+ try:
156
+ create_synthetic_sample(examples_dir)
157
+ except ImportError:
158
+ print(" nibabel not available, skipping synthetic sample creation")
159
+
160
+ # List final examples
161
+ final_examples = [f for f in os.listdir(examples_dir) if f.endswith('.nii.gz')]
162
+ print(f"\nFinal examples: {final_examples}")
163
+ return final_examples
164
+
165
+
166
+ if __name__ == "__main__":
167
+ setup_examples()
examples/sample_ct_abdomen.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:529f536c2c150332ac1a6a190d155cf193ff7f7d14d5980879a27fc0a30a6ac1
3
+ size 1078619
examples/sample_ct_chest.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e865ff9bc7929dfca1d81dc1c19ff3cadbb725740321e84fd3ad616951a1a3c5
3
+ size 1237226
labels.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 104 Anatomical Structure Labels for TotalSegmentator / MONAI wholeBody_ct_segmentation
2
+ # Based on TotalSegmentator v1 class definitions
3
+
4
+ LABEL_NAMES = {
5
+ 0: "background",
6
+ 1: "spleen",
7
+ 2: "kidney_right",
8
+ 3: "kidney_left",
9
+ 4: "gallbladder",
10
+ 5: "liver",
11
+ 6: "stomach",
12
+ 7: "aorta",
13
+ 8: "inferior_vena_cava",
14
+ 9: "portal_vein_and_splenic_vein",
15
+ 10: "pancreas",
16
+ 11: "adrenal_gland_right",
17
+ 12: "adrenal_gland_left",
18
+ 13: "lung_upper_lobe_left",
19
+ 14: "lung_lower_lobe_left",
20
+ 15: "lung_upper_lobe_right",
21
+ 16: "lung_middle_lobe_right",
22
+ 17: "lung_lower_lobe_right",
23
+ 18: "vertebrae_L5",
24
+ 19: "vertebrae_L4",
25
+ 20: "vertebrae_L3",
26
+ 21: "vertebrae_L2",
27
+ 22: "vertebrae_L1",
28
+ 23: "vertebrae_T12",
29
+ 24: "vertebrae_T11",
30
+ 25: "vertebrae_T10",
31
+ 26: "vertebrae_T9",
32
+ 27: "vertebrae_T8",
33
+ 28: "vertebrae_T7",
34
+ 29: "vertebrae_T6",
35
+ 30: "vertebrae_T5",
36
+ 31: "vertebrae_T4",
37
+ 32: "vertebrae_T3",
38
+ 33: "vertebrae_T2",
39
+ 34: "vertebrae_T1",
40
+ 35: "vertebrae_C7",
41
+ 36: "vertebrae_C6",
42
+ 37: "vertebrae_C5",
43
+ 38: "vertebrae_C4",
44
+ 39: "vertebrae_C3",
45
+ 40: "vertebrae_C2",
46
+ 41: "vertebrae_C1",
47
+ 42: "esophagus",
48
+ 43: "trachea",
49
+ 44: "heart_myocardium",
50
+ 45: "heart_atrium_left",
51
+ 46: "heart_ventricle_left",
52
+ 47: "heart_atrium_right",
53
+ 48: "heart_ventricle_right",
54
+ 49: "pulmonary_artery",
55
+ 50: "brain",
56
+ 51: "iliac_artery_left",
57
+ 52: "iliac_artery_right",
58
+ 53: "iliac_vena_left",
59
+ 54: "iliac_vena_right",
60
+ 55: "small_bowel",
61
+ 56: "duodenum",
62
+ 57: "colon",
63
+ 58: "rib_left_1",
64
+ 59: "rib_left_2",
65
+ 60: "rib_left_3",
66
+ 61: "rib_left_4",
67
+ 62: "rib_left_5",
68
+ 63: "rib_left_6",
69
+ 64: "rib_left_7",
70
+ 65: "rib_left_8",
71
+ 66: "rib_left_9",
72
+ 67: "rib_left_10",
73
+ 68: "rib_left_11",
74
+ 69: "rib_left_12",
75
+ 70: "rib_right_1",
76
+ 71: "rib_right_2",
77
+ 72: "rib_right_3",
78
+ 73: "rib_right_4",
79
+ 74: "rib_right_5",
80
+ 75: "rib_right_6",
81
+ 76: "rib_right_7",
82
+ 77: "rib_right_8",
83
+ 78: "rib_right_9",
84
+ 79: "rib_right_10",
85
+ 80: "rib_right_11",
86
+ 81: "rib_right_12",
87
+ 82: "humerus_left",
88
+ 83: "humerus_right",
89
+ 84: "scapula_left",
90
+ 85: "scapula_right",
91
+ 86: "clavicula_left",
92
+ 87: "clavicula_right",
93
+ 88: "femur_left",
94
+ 89: "femur_right",
95
+ 90: "hip_left",
96
+ 91: "hip_right",
97
+ 92: "sacrum",
98
+ 93: "face",
99
+ 94: "gluteus_maximus_left",
100
+ 95: "gluteus_maximus_right",
101
+ 96: "gluteus_medius_left",
102
+ 97: "gluteus_medius_right",
103
+ 98: "gluteus_minimus_left",
104
+ 99: "gluteus_minimus_right",
105
+ 100: "autochthon_left",
106
+ 101: "autochthon_right",
107
+ 102: "iliopsoas_left",
108
+ 103: "iliopsoas_right",
109
+ 104: "urinary_bladder",
110
+ }
111
+
112
+ # Color map for visualization (RGB values)
113
+ # Using a custom colormap for better visualization
114
+ import numpy as np
115
+
116
+ def get_color_map():
117
+ """Generate a color map for 105 classes (background + 104 structures)"""
118
+ np.random.seed(42) # For reproducibility
119
+ colors = np.zeros((105, 3), dtype=np.uint8)
120
+
121
+ # Background is black
122
+ colors[0] = [0, 0, 0]
123
+
124
+ # Assign distinct colors to different organ categories
125
+ # Organs (1-12): Warm colors
126
+ organ_colors = [
127
+ [255, 99, 71], # spleen - tomato
128
+ [255, 165, 0], # kidney_right - orange
129
+ [255, 140, 0], # kidney_left - dark orange
130
+ [50, 205, 50], # gallbladder - lime green
131
+ [139, 69, 19], # liver - saddle brown
132
+ [255, 192, 203], # stomach - pink
133
+ [220, 20, 60], # aorta - crimson
134
+ [0, 0, 139], # inferior_vena_cava - dark blue
135
+ [138, 43, 226], # portal_vein_and_splenic_vein - blue violet
136
+ [255, 215, 0], # pancreas - gold
137
+ [255, 255, 0], # adrenal_gland_right - yellow
138
+ [255, 255, 0], # adrenal_gland_left - yellow
139
+ ]
140
+ colors[1:13] = organ_colors
141
+
142
+ # Lungs (13-17): Light blue shades
143
+ colors[13:18] = [[135, 206, 235], [100, 149, 237], [30, 144, 255], [0, 191, 255], [70, 130, 180]]
144
+
145
+ # Vertebrae (18-41): Gradient from red to purple
146
+ for i in range(18, 42):
147
+ colors[i] = [200 - (i-18)*5, 100, 150 + (i-18)*3]
148
+
149
+ # Other structures (42-57): Various colors
150
+ colors[42] = [255, 182, 193] # esophagus
151
+ colors[43] = [176, 224, 230] # trachea
152
+ colors[44:49] = [[220, 20, 60], [255, 105, 180], [255, 20, 147], [255, 182, 193], [199, 21, 133]] # heart
153
+ colors[49] = [148, 0, 211] # pulmonary_artery
154
+ colors[50] = [255, 218, 185] # brain
155
+ colors[51:55] = [[178, 34, 34], [178, 34, 34], [70, 130, 180], [70, 130, 180]] # iliac vessels
156
+ colors[55:58] = [[222, 184, 135], [210, 180, 140], [188, 143, 143]] # bowels
157
+
158
+ # Ribs (58-81): Bone color variations
159
+ for i in range(58, 82):
160
+ colors[i] = [255, 250, 205 + (i-58) % 50]
161
+
162
+ # Bones (82-92): Gray/white shades
163
+ for i in range(82, 93):
164
+ colors[i] = [220 + (i-82)*2, 220 + (i-82)*2, 220]
165
+
166
+ # Face and muscles (93-103): Skin and muscle tones
167
+ colors[93] = [255, 228, 196] # face
168
+ for i in range(94, 104):
169
+ colors[i] = [205, 92, 92 + (i-94)*5] # muscles - indian red variations
170
+
171
+ # Bladder
172
+ colors[104] = [255, 255, 0] # urinary_bladder - yellow
173
+
174
+ return colors
175
+
176
+ def get_label_name(label_id: int) -> str:
177
+ """Get human-readable name for a label ID"""
178
+ return LABEL_NAMES.get(label_id, f"unknown_{label_id}").replace("_", " ").title()
179
+
180
+ def get_organ_categories():
181
+ """Group organs by category for UI display"""
182
+ return {
183
+ "Major Organs": [1, 2, 3, 4, 5, 6, 10, 104],
184
+ "Cardiovascular": [7, 8, 9, 44, 45, 46, 47, 48, 49, 51, 52, 53, 54],
185
+ "Respiratory": [13, 14, 15, 16, 17, 43],
186
+ "Digestive": [42, 55, 56, 57],
187
+ "Vertebrae": list(range(18, 42)),
188
+ "Ribs": list(range(58, 82)),
189
+ "Upper Body Bones": [82, 83, 84, 85, 86, 87],
190
+ "Lower Body Bones": [88, 89, 90, 91, 92],
191
+ "Muscles": list(range(94, 104)),
192
+ "Other": [11, 12, 50, 93],
193
+ }
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ monai[nibabel]>=1.3.0
2
+ torch>=2.0.0
3
+ gradio>=4.0.0
4
+ nibabel>=5.0.0
5
+ numpy
6
+ matplotlib
7
+ huggingface_hub
8
+ scikit-image
9
+ gdown